mirror of https://github.com/alibaba/MNN.git
				
				
				
			[MNN:Sync] Sync Internal 2.6.3
This commit is contained in:
		
							parent
							
								
									c603c52955
								
							
						
					
					
						commit
						98ba00c2f3
					
				|  | @ -271,7 +271,7 @@ struct GemmBatchedIdentityThreadblockSwizzle { | |||
|     return GemmCoord( | ||||
|       (problem_size.m() + tile_size.m() - 1) / tile_size.m(), | ||||
|       (problem_size.n() + tile_size.n() - 1) / tile_size.n(), | ||||
|       batch_count % (1 << 16)); | ||||
|       batch_count >= 65536 ? 65535 : batch_count); | ||||
|   } | ||||
| 
 | ||||
|   /// Computes CUDA grid dimensions given a size in units of logical tiles
 | ||||
|  |  | |||
|  | @ -22,6 +22,9 @@ std::string OpenCLTarget::type() { | |||
| } | ||||
| std::string OpenCLTarget::macro() { | ||||
|     return | ||||
|     "#ifdef MNN_SUPPORT_FP16\n" | ||||
|     "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" | ||||
|     "#endif\n" | ||||
|     "#define OFFSET_CHECK\\\n" | ||||
|     "\tconst int c = get_global_id(0), w = get_global_id(1), hb = get_global_id(2);\\\n" | ||||
|     "\tif (c >= global_size_dim0 || w >= global_size_dim1 || hb >= global_size_dim2) { return; }\\\n" | ||||
|  | @ -113,61 +116,61 @@ std::string OpenCLTarget::codegen(std::vector<std::string>& inputs, const Comman | |||
|                     ss << inpName << "=" << operand << " * " << operand; | ||||
|                     break; | ||||
|                 case UnaryOpOperation_ERF: | ||||
|                     ss << inpName << "=erf(convert_float4(" << operand << "))"; | ||||
|                     ss << inpName << "=CONVERT_FLOAT4(erf(convert_float4(" << operand << ")))"; | ||||
|                     break; | ||||
|                 case UnaryOpOperation_ERFC: | ||||
|                     ss << inpName << "=erfc(convert_float4(" << operand << "))"; | ||||
|                     ss << inpName << "=CONVERT_FLOAT4(erfc(convert_float4(" << operand << ")))"; | ||||
|                     break; | ||||
|                 case UnaryOpOperation_SQRT: | ||||
|                     ss << inpName << "=sqrt(convert_float4(" << operand << "))"; | ||||
|                     ss << inpName << "=CONVERT_FLOAT4(sqrt(convert_float4(" << operand << ")))"; | ||||
|                     break; | ||||
|                 case UnaryOpOperation_RSQRT: | ||||
|                     ss << inpName << "=rsqrt(convert_float4(" << operand << "))"; | ||||
|                     ss << inpName << "=CONVERT_FLOAT4(rsqrt(convert_float4(" << operand << ")))"; | ||||
|                     break; | ||||
|                 case UnaryOpOperation_ABS: | ||||
|                     ss << inpName << "=fabs(convert_float4(" << operand << "))"; | ||||
|                     ss << inpName << "=CONVERT_FLOAT4(fabs(convert_float4(" << operand << ")))"; | ||||
|                     break; | ||||
|                 case UnaryOpOperation_SIN: | ||||
|                     ss << inpName << "=sin(convert_float4(" << operand << "))"; | ||||
|                     ss << inpName << "=CONVERT_FLOAT4(sin(convert_float4(" << operand << ")))"; | ||||
|                     break; | ||||
|                 case UnaryOpOperation_COS: | ||||
|                     ss << inpName << "=cos(convert_float4(" << operand << "))"; | ||||
|                     ss << inpName << "=CONVERT_FLOAT4(cos(convert_float4(" << operand << ")))"; | ||||
|                     break; | ||||
|                 case UnaryOpOperation_SIGN: | ||||
|                     ss << inpName << "=sign(convert_float4(" << operand << "))"; | ||||
|                     ss << inpName << "=CONVERT_FLOAT4(sign(convert_float4(" << operand << ")))"; | ||||
|                     break; | ||||
|                 case UnaryOpOperation_EXP: | ||||
|                     ss << inpName << "=exp(convert_float4(" << operand << "))"; | ||||
|                     ss << inpName << "=CONVERT_FLOAT4(exp(convert_float4(" << operand << ")))"; | ||||
|                     break; | ||||
|                 case UnaryOpOperation_NEG: | ||||
|                     ss << inpName << "=-(" << operand << ")"; | ||||
|                     break; | ||||
|                 case UnaryOpOperation_TAN: | ||||
|                     ss << inpName << "=tan(convert_float4(" << operand << "))"; | ||||
|                     ss << inpName << "=CONVERT_FLOAT4(tan(convert_float4(" << operand << ")))"; | ||||
|                     break; | ||||
|                 case UnaryOpOperation_CEIL: | ||||
|                     ss << inpName << "=ceil(convert_float4(" << operand << "))"; | ||||
|                     ss << inpName << "=CONVERT_FLOAT4(ceil(convert_float4(" << operand << ")))"; | ||||
|                     break; | ||||
|                 case UnaryOpOperation_LOG1P: | ||||
|                     ss << inpName << "=log1p(convert_float4(" << operand << "))"; | ||||
|                     ss << inpName << "=CONVERT_FLOAT4(log1p(convert_float4(" << operand << ")))"; | ||||
|                     break; | ||||
|                 case UnaryOpOperation_FLOOR: | ||||
|                     ss << inpName << "=floor(convert_float4(" << operand << "))"; | ||||
|                     ss << inpName << "=CONVERT_FLOAT4(floor(convert_float4(" << operand << ")))"; | ||||
|                     break; | ||||
|                 case UnaryOpOperation_ROUND: | ||||
|                     ss << inpName << "=round(convert_float4(" << operand << "))"; | ||||
|                     ss << inpName << "=CONVERT_FLOAT4(round(convert_float4(" << operand << ")))"; | ||||
|                     break; | ||||
|                 case UnaryOpOperation_SIGMOID: | ||||
|                     ss << inpName << "=native_recip((float4)1+native_exp(convert_float4(-" << operand << ")))"; | ||||
|                     ss << inpName << "=CONVERT_FLOAT4(native_recip((float4)1+native_exp(convert_float4(-" << operand << "))))"; | ||||
|                     break; | ||||
|                 case UnaryOpOperation_TANH: | ||||
|                     ss << inpName << "=tanh(convert_float4(" << operand << "))"; | ||||
|                     ss << inpName << "=CONVERT_FLOAT4(tanh(convert_float4(" << operand << ")))"; | ||||
|                     break; | ||||
|                 case UnaryOpOperation_RECIPROCAL: | ||||
|                     ss << inpName << "=native_recip(convert_float4(" << operand << "))"; | ||||
|                     ss << inpName << "=CONVERT_FLOAT4(native_recip(convert_float4(" << operand << ")))"; | ||||
|                     break; | ||||
|                 case UnaryOpOperation_LOG: | ||||
|                     ss << inpName << "=native_log(convert_float4(" << operand << "+(float4)((float)0.0000001)))"; | ||||
|                     ss << inpName << "=CONVERT_FLOAT4(native_log(convert_float4(" << operand << ")+(float4)((float)0.0000001)))"; | ||||
|                     break; | ||||
|                 default: | ||||
|                     MNN_ASSERT(false); | ||||
|  | @ -198,13 +201,13 @@ std::string OpenCLTarget::codegen(std::vector<std::string>& inputs, const Comman | |||
|     return ss.str(); | ||||
| } | ||||
| std::string OpenCLTarget::load(const std::string& base, const std::string& offset, const Command* cmd, std::string& inpName) { | ||||
|     return "FLOAT4 " + inpName + "=read_imagef(" + base + ", SAMPLER, " + offset + ")"; | ||||
|     return "FLOAT4 " + inpName + "=RI_F(" + base + ", SAMPLER, " + offset + ")"; | ||||
| } | ||||
| std::string OpenCLTarget::loadscalar(const std::string& base, std::string& inpName) { | ||||
|     return "FLOAT4 " + inpName + "=((float4)read_imagef(" + base + ", SAMPLER, (int2)(0, 0)).x)"; | ||||
|     return "FLOAT4 " + inpName + "=(RI_F(" + base + ", SAMPLER, (int2)(0, 0)).x)"; | ||||
| } | ||||
| std::string OpenCLTarget::store(const std::string base, const std::string& offset, const std::string& data) { | ||||
|     return "write_imagef(" + base + ", " + offset + ", " + data + ");\n"; | ||||
|     return "WI_F(" + base + ", " + offset + ", " + data + ");\n"; | ||||
| } | ||||
| 
 | ||||
| std::string OpenCLTarget::proto(const std::string& name, const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, bool hasSingleConvertRaster) { | ||||
|  |  | |||
|  | @ -22,7 +22,7 @@ private: | |||
|     std::string store(const std::string base, const std::string& offset, const std::string& data) override; | ||||
|     std::string proto(const std::string& name, const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, bool hasSingleConvertRaster = false) override; | ||||
|     template <typename T> | ||||
|     std::string numval(T t) { return "((float4)" + std::to_string(t) + ")"; } | ||||
|     std::string numval(T t) { return "((FLOAT4)" + std::to_string(t) + ")"; } | ||||
| }; | ||||
| 
 | ||||
| } | ||||
|  |  | |||
|  | @ -29,7 +29,59 @@ MNN在C++的基础上,增加了Python扩展。扩展单元包括两个部分 | |||
| ### MNNTools | ||||
| MNNTools提供目前主要是2个工具,用法可以参考[mnnconvert](../tools/python.html#mnnconvert)和[mnnquant](../tools/python.html#mnnquant) | ||||
| 
 | ||||
| ## 使用Python Session API | ||||
| ## 使用Python Module API | ||||
| ### 数据类型 | ||||
| Python中的`Module API`与C++中的函数名略有区别,用法相似。主要数据类型如下: | ||||
| - [_Module](../pymnn/_Module.md) 模型实例 | ||||
| - [Var](../pymnn/Var.md) 模型的输入输出 | ||||
| ### 推理流程 | ||||
| 基本推理流程如下: | ||||
| - [创建Module](../pymnn/nn.html#load-module-from-file-file-name-input-names-output-names-dynamic-shape-mutable-rearrange-backend-memory-mode-power-mode-precision-mode) | ||||
| - 创建输入: 使用`expr`或`numpy`函数创建`Var`即可作为输入 | ||||
| - [执行推理](../pymnn/_Module.html#forward-input) | ||||
| - 获取输出: 输出为`Var`类型,可以通过`expr`或`numpy`函数执行后处理 | ||||
| ### 示例 | ||||
| ```python | ||||
| import MNN.nn as nn | ||||
| import MNN.cv as cv | ||||
| import MNN.numpy as np | ||||
| import MNN.expr as expr | ||||
| 
 | ||||
| # 配置执行后端,线程数,精度等信息;key-vlaue请查看API介绍 | ||||
| config = {} | ||||
| config['precision'] = 'low' # 当硬件支持(armv8.2)时使用fp16推理 | ||||
| config['backend'] = 0       # CPU | ||||
| config['numThread'] = 4     # 线程数 | ||||
| 
 | ||||
| rt = nn.create_runtime_manager((config,)) | ||||
| # 加载模型创建_Module | ||||
| net = nn.load_module_from_file('mobilenet_v1.mnn', ['data'], ['prob'], runtime_manager=rt) | ||||
| 
 | ||||
| # 读取图片 | ||||
| image = cv.imread('cat.jpg') | ||||
| # 转换为float32, 形状为[224,224,3] | ||||
| image = cv.resize(image, (224, 224), mean=[103.94, 116.78, 123.68], norm=[0.017, 0.017, 0.017]) | ||||
| # 增加batch HWC to NHWC | ||||
| input_var = np.expand_dims(image, 0) | ||||
| # NHWC to NC4HW4 | ||||
| input_var = expr.convert(input_var, expr.NC4HW4) | ||||
| 
 | ||||
| # 执行推理 | ||||
| output_var = net.forward(input_var) | ||||
| 
 | ||||
| # NC4HW4 to NHWC | ||||
| output_var = expr.convert(output_var, expr.NHWC) | ||||
| # 打印出分类结果, 282为猫 | ||||
| print("output belong to class: {}".format(np.argmax(output_var))) | ||||
| # output belong to class: 282 | ||||
| ``` | ||||
| 其他示例可以参考[示例](../pymnn/RuntimeManager.html#example);也可以参考[示例工程](../start/demo.html#id5)。 | ||||
| 
 | ||||
| 
 | ||||
| ## 使用Python Session API *[deprecated]* | ||||
| 
 | ||||
| 不建议使用该API执行推理,建议使用Module API | ||||
| 
 | ||||
| ### 数据类型 | ||||
| Python中`Session API`的函数名与用法与C++基本一样。使用的主要数据类型如下: | ||||
| - [Interpreter](../pymnn/Interpreter.md) 解释器,持有模型资源 | ||||
|  | @ -118,107 +170,7 @@ print("output belong to class: {}".format(np.argmax(output_var, 1))) | |||
| # output belong to class: array([282, 385], dtype=int32) | ||||
| ``` | ||||
| 其他示例可以参考[示例](../pymnn/Interpreter.html#example);也可以参考[示例工程](../start/demo.html#session)。 | ||||
| ## 使用Python Module API | ||||
| ### 数据类型 | ||||
| Python中的`Module API`与C++中的函数名略有区别,用法相似。主要数据类型如下: | ||||
| - [_Module](../pymnn/_Module.md) 模型实例 | ||||
| - [Var](../pymnn/Var.md) 模型的输入输出 | ||||
| ### 推理流程 | ||||
| 基本推理流程如下: | ||||
| - [创建Module](../pymnn/nn.html#load-module-from-file-file-name-input-names-output-names-dynamic-shape-mutable-rearrange-backend-memory-mode-power-mode-precision-mode) | ||||
| - 创建输入: 使用`expr`或`numpy`函数创建`Var`即可作为输入 | ||||
| - [执行推理](../pymnn/_Module.html#forward-input) | ||||
| - 获取输出: 输出为`Var`类型,可以通过`expr`或`numpy`函数执行后处理 | ||||
| ### 示例 | ||||
| ```python | ||||
| import MNN.nn as nn | ||||
| import MNN.cv as cv | ||||
| import MNN.numpy as np | ||||
| import MNN.expr as expr | ||||
| 
 | ||||
| # 配置执行后端,线程数,精度等信息;key-vlaue请查看API介绍 | ||||
| config = {} | ||||
| config['precision'] = 'low' # 当硬件支持(armv8.2)时使用fp16推理 | ||||
| config['backend'] = 0       # CPU | ||||
| config['numThread'] = 4     # 线程数 | ||||
| 
 | ||||
| rt = nn.create_runtime_manager((config,)) | ||||
| # 加载模型创建_Module | ||||
| net = nn.load_module_from_file('mobilenet_v1.mnn', ['data'], ['prob'], runtime_manager=rt) | ||||
| 
 | ||||
| # 读取图片 | ||||
| image = cv.imread('cat.jpg') | ||||
| # 转换为float32, 形状为[224,224,3]         | ||||
| image = cv.resize(image, (224, 224), mean=[103.94, 116.78, 123.68], norm=[0.017, 0.017, 0.017]) | ||||
| # 增加batch HWC to NHWC | ||||
| input_var = np.expand_dims(image, 0) | ||||
| # NHWC to NC4HW4 | ||||
| input_var = expr.convert(input_var, expr.NC4HW4) | ||||
| 
 | ||||
| # 执行推理 | ||||
| output_var = net.forward(input_var) | ||||
| 
 | ||||
| # NC4HW4 to NHWC  | ||||
| output_var = expr.convert(output_var, expr.NHWC) | ||||
| # 打印出分类结果, 282为猫 | ||||
| print("output belong to class: {}".format(np.argmax(output_var))) | ||||
| # output belong to class: 282 | ||||
| ``` | ||||
| 其他示例可以参考[示例](../pymnn/RuntimeManager.html#example);也可以参考[示例工程](../start/demo.html#id5)。 | ||||
| 
 | ||||
| ## 使用Python Expr API | ||||
| ### 数据类型 | ||||
| Python的`Expr API`相比C++在命名和使用方式上略有区别,但是功能一致。主要数据类型如下: | ||||
| - [Var](../pymnn/Var.md) 表达式计算中的变量 | ||||
| ### 主要用法 | ||||
| 因为`Expr`不仅有模型推理的能力,还具备数值计算的能力。在实际使用中`Expr`被用作构图或者计算的情况更多,实际用来执行模型推理的情况并不多,当`Expr`用作模型推理时的主要流程如下: | ||||
| - [加载计算图](../pymnn/expr.html#load-as-dict-filename) | ||||
| - 获取输入输出:直接使用Python中的`dict`的方式获取,如:`net['input']` | ||||
| - [写入输入数据](../pymnn/Var.html#write-data) | ||||
| - [读取输出数据](../pymnn/Var.html#read):读取数据不限于`read`,尝试打印和使用都可能触发读取操作 | ||||
| ### 示例 | ||||
| `Expr`用作模型推理: | ||||
| ```python | ||||
| import MNN.cv as cv | ||||
| import MNN.numpy as np | ||||
| import MNN.expr as expr | ||||
| 
 | ||||
| net = expr.load_as_dict('mobilenet_v1.mnn') | ||||
| input_var = net['data'] | ||||
| output_var = net['prob'] | ||||
| 
 | ||||
| # 读取图片 | ||||
| image = cv.imread('cat.jpg') | ||||
| # 转换为float32, 形状为[224,224,3]         | ||||
| image = cv.resize(image, (224, 224), mean=[103.94, 116.78, 123.68], norm=[0.017, 0.017, 0.017]) | ||||
| # 增加batch HWC to NHWC | ||||
| input_data = np.expand_dims(image, 0) | ||||
| # NHWC to NC4HW4 | ||||
| input_data = expr.convert(input_data, expr.NC4HW4) | ||||
| 
 | ||||
| input_var.write(input_data.read_as_tuple()) | ||||
| 
 | ||||
| # 打印出分类结果, 282为猫 | ||||
| print("output belong to class: {}".format(np.argmax(output_var))) | ||||
| ``` | ||||
| `Expr`用于数值计算与数据存取: | ||||
| ```python | ||||
| import MNN.numpy as np | ||||
| import MNN.expr as expr | ||||
| 
 | ||||
| x = expr.range(0., 10., 1.) | ||||
| y = expr.fill([10], 3.1415) | ||||
| z = expr.sin(x * y + x / y) | ||||
| expr.save([z], 'z.mnn') | ||||
| a = expr.load_as_list('z.mnn')[0] | ||||
| print(a) | ||||
| ''' | ||||
| array([ 0.        , -0.31288275,  0.59434694, -0.8161286 ,  0.955958  , | ||||
|        -0.9997932 ,  0.943233  , -0.79195637,  0.561154  , -0.27400237], | ||||
|       dtype=float32) | ||||
| ''' | ||||
| ``` | ||||
| 其他示例可以参考[示例](../pymnn/Var.html#example);也可以参考[示例工程](../start/demo.html#id5)。 | ||||
| ## 使用cv/numpy API | ||||
| ### 数据类型 | ||||
| Python的`cv`和`numpy`接口,其中`cv`是对C++中`tools/cv`实现的封装;`numpy`则是对`expr`接口的封装;这两个接口主要为了提高MNN的易用性,与`opencv`与`numpy`做到了再接口上的部分兼容,在用法和思路上基本一致。主要数据类型如下: | ||||
|  |  | |||
|  | @ -1,5 +1,122 @@ | |||
| # 发布版本 | ||||
| ## 2.4.0 (`Latest`) | ||||
| ## 2.6.0 (`Latest`) | ||||
| #### 新特性 | ||||
| - 新增int8量化算子支持: | ||||
|   - Softmax | ||||
|   - Interp | ||||
|   - Binary | ||||
|   - Unary | ||||
|   - Scale | ||||
| - OpenCL 支持 Loop 算子特定情形; | ||||
|   - BatchMatMul | ||||
|   - Gather | ||||
| - x86_64支持Gelu-bf16; | ||||
| - CUDA支持bf16模型推理; | ||||
| - benchmark 工具支持直接测试模型量化后的性能(不需要先用量化工具量化模型) | ||||
| - Pymnn Tensor/Var使用Tuple创建时支持混合类型数据; | ||||
| - 权值量化模型支持低内存推理模式,计算时反量化; | ||||
| - 支持ChatGLM-6B模型推理内存占用3G; | ||||
| - 支持构建了ChatGLM-MNN Android app; | ||||
| #### 优化 | ||||
| - OpenCL支持高通reocrd queue ,以降低创建 GPU Command Buffer 所需的时间; | ||||
|   Oneplus 9 机型 Benchmark 测试结果如下 | ||||
| 
 | ||||
|   |Model	|unrecord	|record | | ||||
|   |-------|---------|-------| | ||||
|   |resnet-v2-50.mnn	|21.254	|20.160| | ||||
|   |MobileNetV2_224.mnn	|4.853	|4.186| | ||||
|   |mobilenet-v1-1.0.mnn	|6.424	|5.315| | ||||
|   |nasnet.mnn	|46.751	|20.260| | ||||
|   |SqueezeNetV1.0.mnn	|7.35	|6.832| | ||||
|   |squeezenetv1.1.mnn	|3.936	|3.693| | ||||
|   |mobilenetV3.mnn	|14.201	|6.743| | ||||
|   |inception-v3.mnn	|33.111	|32.032| | ||||
| 
 | ||||
| - 稀疏卷积内存优化,降低内存占用; | ||||
| - 减少异构(CPU低精度/GPU)运行 MNN 模型时的常量内存占用; | ||||
| - CUDA优化int8算子性能; | ||||
| - 减少Permute几何计算产生的region数量; | ||||
| - 重新调整ConvolutionInt8及im2col在AVX512-VNNI下的分块大小,提升性能20%-30%; | ||||
| - X86新增bilinear/nearest sample的SIMD实现,提升ImageProcess性能 50% 左右; | ||||
| #### Bugfix | ||||
| - 关联 Github Issue 解决 | ||||
|   - 修复CUDA Raster错误导致输出为0的问题;issue-2333 | ||||
|   - 修复OpenCL Gather算子出错的问题;issue-2424 | ||||
|   - 修复ImageProcess出错的问题;issue-2386 | ||||
|   - OpenCL支持用户选择device id; issue-2343 | ||||
| - 其他 Bugfix | ||||
|   - CUDA CMakeList对未支持架构增加报错信息; | ||||
|   - testMNNFromOnnx脚本在模型测试正确时不启用DEBUG模式; | ||||
|   - load_module_from_file中的shape_mutable默认改为True(存在子图的模型无法在False情形下运行); | ||||
|   - MNNConvert使用keepInputFormat选项时,也同时将输出Tensor的format转换为原始格式 | ||||
|   - 修复log记录时设备为空时Crash的情况; | ||||
|   - 修复BinaryOp单元测试在Windows下无法编译的问题; | ||||
|   - 修复MNN_SUPPORT_DEPRECATED_OP宏不控制OptimizedComputer的问题; | ||||
|   - 修复fp16多线程且分块方向为channel时convolution计算出错的问题; | ||||
|   - 修复deconvolutionInt8访存越界的问题; | ||||
|   - 修复TensorArrayWrite几何计算产生zero region的问题; | ||||
|   - 修复CUDA depthwise conv出错的问题; | ||||
|   - 修复一些文档格式、内容的错误; | ||||
|   - 修复多线程下createRuntime和setGlobalConfig出错的问题; | ||||
|   - 修复Vec.hpp中无用代码导致的编译失败问题; | ||||
|   - 修复OpenCL对gpuDevice的assert失败的问题; | ||||
|   - 修复OpenCL bianry mod出错的问题; | ||||
|   - 修复CUDA argmax出错的问题; | ||||
|   - 修复pymnn/example/mnn_numpy_cv_demo.py中形状不对的问题; | ||||
| ## 2.5.0 | ||||
| #### 新特性 | ||||
| - MNN OpenCV新增算子: | ||||
|   - erode | ||||
|   - convertMaps | ||||
|   - remap | ||||
|   - adaptiveThreshold | ||||
|   - bilateralFilter | ||||
|   - solve (MNN numpy新增solve) | ||||
|   - normalize | ||||
|   - split | ||||
|   - merge | ||||
|   - addWeight | ||||
| - 支持Tflite int8量化模型转换到MNN模型; | ||||
| - ARM CPU支持GELU-bf16 | ||||
| - CUDA 新增算子: | ||||
| - GridSampler | ||||
| - Multi-Input Convolution | ||||
| - Multi-Input Deconvolution | ||||
| - CUDA针对多卡推理,支持用户设置运行device_id | ||||
| - 支持Deconvolution-int8 | ||||
| - runSession/runSessionWithCallBack函数加锁,避免多线程调用出错 | ||||
| - 支持非拓扑序ONNX模型转换 | ||||
| - 支持ONNX多版本Softmax转换 | ||||
| #### 重构/优化 | ||||
| - 优化内存分配与回收时机,新增Session | ||||
| - 简化ONNX Slice算子模型转换 | ||||
| - Cuda性能优化 | ||||
| - Argmax针对dim size较大的情况性能优化 | ||||
| - Softmax在channel较大时性能优化 | ||||
| - MatMul算子预重排逻辑优化 | ||||
| - 优化后ChatGLM模型在A10显卡上性能优于Pytorch 2.0 | ||||
| - OpenCL优化,resnet测试优于OpenVINO | ||||
| - 使用intel subgroup扩展优化winogard算子,调整数据排布格式与准入条件 | ||||
| - 根据输入尺寸调整conv2d算子的数据排布格式,使用intel subgroup扩展优化 | ||||
| - 优化后ResNet18模型在intel UHD Graphics 630显卡上性能优于OpenVINO | ||||
| - GELU-bf16实现后性能提升 | ||||
| #### Bugfix | ||||
| - 关联 Github Issue 解决 | ||||
|   - 修复CPURaster 的 singleConvert 部分情况出错 issue-2264 | ||||
|   - 修复atan2计算错误的问题 | ||||
|   - 修复ONNX dynamic shape转换出错的问题 issue-2276 | ||||
|   - 修复i8mm时Im2col出错的问题 | ||||
|   - 修复CPUConvolutionDepthwise错误的问题 issue-2291 | ||||
|   - 修复CUDA int8编译失败的问题 issue-2321 | ||||
|   - 修复Onnx Loop 算子的 M 和 cond 为optional 时,转换失败的问题 issue-2267 | ||||
|   - 修复Raster中fastblit 中turnPackRegion 部分情况下出错的问题 issue-2337 | ||||
| - 其他 Bugfix | ||||
|   - 修复 onnx 子图中 identity 被优化导致 输出数和原始子图不一致的问题 | ||||
|   - 修复 Onnx sequense 相关算子转换问题 | ||||
|   - 修复 TensorArrayConcat 计算 newAxis = 1 时的问题(此时为 stack) | ||||
|   - 修复 TensorArray 计算 eleSize 时,axis < 0 时计算出错的问题 | ||||
|   - 修复低精度计算或者 GPU 无法运行 mnn 训练模型的问题 | ||||
| ## 2.4.0 | ||||
| #### 新特性 | ||||
| - NNAPI 支持int8 量化模型; | ||||
| - MNN OpenCL/Metal支持算子在线Fuse与代码生成; | ||||
|  |  | |||
|  | @ -1,5 +1,5 @@ | |||
| <!-- pymnn/CVImageProcess.md --> | ||||
| ## MNN.CVImageProcess | ||||
| ## MNN.CVImageProcess *[deprecated]* | ||||
| 
 | ||||
| ```python | ||||
| class CVImageProcess | ||||
|  | @ -10,6 +10,7 @@ CVImageProcess用于图像处理,该图像处理类提供了一下图像处理 | |||
| - 图像的仿射变换,类似于`cv2.resize`和`cv2.warpAffine`,通过设置[CVMatrix](CVMatrix.md) 来实现 | ||||
| - 对图像进行归一化,通过设置`mean`和`normal`来实现; `x = (x - mean) / normal` | ||||
| 
 | ||||
| *不建议使用该接口,请使用[cv](cv.md)代替* | ||||
| --- | ||||
| ### `MNN.CV_ImageFormat_*` | ||||
| 描述图像格式的数据类型,支持RBG,RGBA,BGR,BGRA,GRAY,YUV_NV21类型 | ||||
|  |  | |||
|  | @ -1,5 +1,5 @@ | |||
| <!-- pymnn/CVMatrix.md --> | ||||
| ## MNN.CVMatrix | ||||
| ## MNN.CVMatrix *[deprecated]* | ||||
| 
 | ||||
| ```python | ||||
| class CVMatrix | ||||
|  |  | |||
|  | @ -1,10 +1,12 @@ | |||
| ## MNN.Interpreter | ||||
| ## MNN.Interpreter *[deprecated]* | ||||
| 
 | ||||
| ```python | ||||
| class Interpreter | ||||
| ``` | ||||
| Interpreter是MNN V2接口中模型数据的持有者。使用MNN推理时,有两个层级的抽象,分别是解释器Interpreter和会话[Session](Session.md)。 | ||||
| 
 | ||||
| *不建议使用该接口,请使用[nn](nn.md)代替* | ||||
| 
 | ||||
| --- | ||||
| ### `Interpreter(model_path)` | ||||
| 加载`.mnn`模型文件创建一个MNN解释器,返回一个解释器对象 | ||||
|  |  | |||
|  | @ -1,10 +1,12 @@ | |||
| ## MNN.Session | ||||
| ## MNN.Session *[deprecated]* | ||||
| 
 | ||||
| ```python | ||||
| class Session | ||||
| ``` | ||||
| Session是MNN V2接口中推理数据的持有者。Session通过[Interpreter](Interpreter.md)创建;多个推理可以共用同一个模型,即,多个Session可以共用一个Interpreter。 | ||||
| 
 | ||||
| *不建议使用该接口,请使用[nn](nn.md)代替* | ||||
| 
 | ||||
| --- | ||||
| ### `Session()` | ||||
| 创建一个空Tensor | ||||
|  |  | |||
|  | @ -1,11 +1,13 @@ | |||
| <!-- pymnn/Tensor.md --> | ||||
| ## MNN.Tensor | ||||
| ## MNN.Tensor *[deprecated]* | ||||
| 
 | ||||
| ```python | ||||
| class Tensor | ||||
| ``` | ||||
| Tensor是MNN V2接口中的基础数据结构,是最基本的数据封装类型。Tensor存储了数据以及数据类型,形状等诸多信息,用户可以通过Tensor本身的函数获取这些信息。 | ||||
| 
 | ||||
| *不建议使用该接口,请使用[Var](Var.md)代替* | ||||
| 
 | ||||
| --- | ||||
| ### `MNN.Halide_Type_*` | ||||
| 描述Tensor的数据类型 | ||||
|  |  | |||
|  | @ -103,45 +103,6 @@ array([0., 1., 2., 3.], dtype=float32) | |||
|         [3., 4.]], dtype=float32)] | ||||
| ``` | ||||
| --- | ||||
| ### `load_as_dict(fileName)` | ||||
| 从文件中加载模型,并将模型转换为计算图,以`dict`的形式返回计算图的所有节点名称和`Var` | ||||
| 
 | ||||
| 参数: | ||||
| - `fileName:str` 模型文件路径 | ||||
| 
 | ||||
| 返回:加载的模型计算图,其`key`为`str`,`value`为`Var` | ||||
| 
 | ||||
| 返回类型:`dict` | ||||
| 
 | ||||
| 示例: | ||||
| 
 | ||||
| ```python | ||||
| >>> vars = expr.load_as_dict('mobilenet_v1.mnn') | ||||
| >>> vars.keys() | ||||
| dict_keys(['conv1', 'conv2_1/dw', 'conv2_1/sep', 'conv2_2/dw', 'conv2_2/sep', 'conv3_1/dw', 'conv3_1/sep', 'conv3_2/dw', 'conv3_2/sep', 'conv4_1/dw', 'conv4_1/sep', 'conv4_2/dw', 'conv4_2/sep', 'conv5_1/dw', 'conv5_1/sep', 'conv5_2/dw', 'conv5_2/sep', 'conv5_3/dw', 'conv5_3/sep', 'conv5_4/dw', 'conv5_4/sep', 'conv5_5/dw', 'conv5_5/sep', 'conv5_6/dw', 'conv5_6/sep', 'conv6/dw', 'conv6/sep', 'data', 'fc7', 'pool6', 'prob']) | ||||
| ``` | ||||
| --- | ||||
| ### `get_inputs_and_outputs(allVariable)` | ||||
| 获取`dict`形式计算图的输入输出节点,可以在使用V3接口时获取输入输出的信息 | ||||
| 
 | ||||
| 参数: | ||||
| - `allVariable:dict` 计算图的`dict`形式,其`key`为`str`,`value`为`Var` | ||||
| 
 | ||||
| 返回:计算图的输入输出,其中输入输出都为`dict`形式,其`key`为`str`,`value`为`Var` | ||||
| 
 | ||||
| 返回类型:`(dict, dict)` | ||||
| 
 | ||||
| 示例: | ||||
| 
 | ||||
| ```python | ||||
| >>> vars = expr.load_as_dict('mobilenet_v1.mnn') | ||||
| >>> inputs, outputs = expr.get_inputs_and_outputs(vars) | ||||
| >>> inputs.keys() | ||||
| dict_keys(['data']) | ||||
| >>> outputs.keys() | ||||
| dict_keys(['prob']) | ||||
| ``` | ||||
| --- | ||||
| ### `gc(full)` | ||||
| 手动回收内存,当在循环中调用MNN表达式求值时,常量部分数据不会在每次循环结束释放,当执行次数增加时会有内存增长现象,可以在每次循环结束时调用该函数回收常量内存 | ||||
| 
 | ||||
|  | @ -3050,3 +3011,47 @@ roialign | |||
| ```python | ||||
| TODO | ||||
| ``` | ||||
| --- | ||||
| **以下函数为框架开发者使用函数,普通用户不建议使用!** | ||||
| 
 | ||||
| --- | ||||
| ### `load_as_dict(fileName)` *[deprecated]* | ||||
| 从文件中加载模型,并将模型转换为计算图,以`dict`的形式返回计算图的所有节点名称和`Var` | ||||
| 
 | ||||
| *不建议使用该接口* | ||||
| 
 | ||||
| 参数: | ||||
| - `fileName:str` 模型文件路径 | ||||
| 
 | ||||
| 返回:加载的模型计算图,其`key`为`str`,`value`为`Var` | ||||
| 
 | ||||
| 返回类型:`dict` | ||||
| 
 | ||||
| 示例: | ||||
| 
 | ||||
| ```python | ||||
| >>> vars = expr.load_as_dict('mobilenet_v1.mnn') | ||||
| >>> vars.keys() | ||||
| dict_keys(['conv1', 'conv2_1/dw', 'conv2_1/sep', 'conv2_2/dw', 'conv2_2/sep', 'conv3_1/dw', 'conv3_1/sep', 'conv3_2/dw', 'conv3_2/sep', 'conv4_1/dw', 'conv4_1/sep', 'conv4_2/dw', 'conv4_2/sep', 'conv5_1/dw', 'conv5_1/sep', 'conv5_2/dw', 'conv5_2/sep', 'conv5_3/dw', 'conv5_3/sep', 'conv5_4/dw', 'conv5_4/sep', 'conv5_5/dw', 'conv5_5/sep', 'conv5_6/dw', 'conv5_6/sep', 'conv6/dw', 'conv6/sep', 'data', 'fc7', 'pool6', 'prob']) | ||||
| ``` | ||||
| --- | ||||
| ### `get_inputs_and_outputs(allVariable)` *[deprecated]* | ||||
| 获取`dict`形式计算图的输入输出节点,可以在使用V3接口时获取输入输出的信息 | ||||
| 
 | ||||
| 参数: | ||||
| - `allVariable:dict` 计算图的`dict`形式,其`key`为`str`,`value`为`Var` | ||||
| 
 | ||||
| 返回:计算图的输入输出,其中输入输出都为`dict`形式,其`key`为`str`,`value`为`Var` | ||||
| 
 | ||||
| 返回类型:`(dict, dict)` | ||||
| 
 | ||||
| 示例: | ||||
| 
 | ||||
| ```python | ||||
| >>> vars = expr.load_as_dict('mobilenet_v1.mnn') | ||||
| >>> inputs, outputs = expr.get_inputs_and_outputs(vars) | ||||
| >>> inputs.keys() | ||||
| dict_keys(['data']) | ||||
| >>> outputs.keys() | ||||
| dict_keys(['prob']) | ||||
| ``` | ||||
|  | @ -628,6 +628,8 @@ void Executor::_makeCache(const std::vector<EXPRP>& expr, bool forceCPU) { | |||
|         expr->inside()->mCache = cahce; | ||||
|     } | ||||
|     cahce->mCacheBuffers = std::move(opBuffers); | ||||
|     // Don't report error when use expr dynamic compute, which will be called in model convert
 | ||||
|     scheduleInfo.pipelineInfo[0].first.reportError = false; | ||||
|     scheduleInfo.pipelineInfo[0].first.info.numThread = 1; | ||||
|     if (forceCPU) { | ||||
|         scheduleInfo.pipelineInfo[0].first.info.type = MNN_FORWARD_CPU; | ||||
|  |  | |||
|  | @ -31,7 +31,8 @@ static Scope<ExecutorRef>* _getGlobalScope() { | |||
| #if TARGET_OS_IPHONE | ||||
|         pthread_key_create(&gKey, NULL); | ||||
| #else | ||||
|         g_executor_scope = new Scope<ExecutorRef>; | ||||
|         thread_local static Scope<ExecutorRef> initValue; | ||||
|         g_executor_scope = &initValue; | ||||
| #endif | ||||
|     }); | ||||
| #if TARGET_OS_IPHONE | ||||
|  |  | |||
|  | @ -69,6 +69,6 @@ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \ | |||
| #define STR(x) STR_IMP(x) | ||||
| #define MNN_VERSION_MAJOR 2 | ||||
| #define MNN_VERSION_MINOR 6 | ||||
| #define MNN_VERSION_PATCH 2 | ||||
| #define MNN_VERSION_PATCH 3 | ||||
| #define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." STR(MNN_VERSION_PATCH) | ||||
| #endif /* MNNDefine_h */ | ||||
|  |  | |||
|  | @ -63,6 +63,11 @@ typedef enum { | |||
|      then choose the better one according to performance*/ | ||||
|     MNN_GPU_MEMORY_BUFFER = 1 << 6,/* User assign mode */ | ||||
|     MNN_GPU_MEMORY_IMAGE  = 1 << 7,/* User assign mode */ | ||||
|     // choose one opencl memory mode Only, this mode Only support for Qualcomm gpu
 | ||||
|     /* User can try MNN_GPU_RECORD_OP and MNN_GPU_RECORD_KERNEL both,
 | ||||
|      then choose the better one according to performance*/ | ||||
|     MNN_GPU_RECORD_OP  = 1 << 8,/* the kernels in one op execution record into one recording */ | ||||
|     MNN_GPU_RECORD_BATCH  = 1 << 9,/* 10 kernels record into one recording */ | ||||
| } MNNGpuMode; | ||||
| 
 | ||||
| #ifdef __cplusplus | ||||
|  |  | |||
|  | @ -4,9 +4,11 @@ | |||
| #       |-- Release | ||||
| #             |--- Dynamic | ||||
| #             |      |--- MD | ||||
| #             |      |--- MT | ||||
| #             | | ||||
| #             |--- Static | ||||
| #                    |--- MD | ||||
| #                    |--- MT | ||||
| # | ||||
| Param( | ||||
|     [Parameter(Mandatory=$true)][String]$path, | ||||
|  | @ -27,7 +29,7 @@ Remove-Item -Path $PACKAGE_PATH/include -Recurse -ErrorAction Ignore | |||
| cp -r include $PACKAGE_PATH | ||||
| cp -r tools/cv/include/cv $PACKAGE_PATH/include | ||||
| pushd $PACKAGE_LIB_PATH | ||||
| mkdir -p Release\Dynamic\MD, Release\Static\MD | ||||
| mkdir -p Release\Dynamic\MT, Release\Dynamic\MD, Release\Static\MD, Release\Static\MT | ||||
| popd | ||||
| 
 | ||||
| $CMAKE_ARGS = "-DMNN_SEP_BUILD=OFF -DMNN_BUILD_TRAIN=ON -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON  -DMNN_OPENCL=ON -DMNN_VULKAN=ON -DMNN_AVX512=ON" | ||||
|  | @ -68,6 +70,14 @@ function Build([String]$cmake_cmd, [String]$ninja_cmd = "ninja MNN") { | |||
|     exit 1 | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| ##### Release/Dynamic/MT #### | ||||
| log "Release/Dynamic/MT" | ||||
| Remove-Item CMakeCache.txt -ErrorAction Ignore | ||||
| Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=ON .." | ||||
| cp MNN.lib, MNN.dll, MNN.pdb $PACKAGE_LIB_PATH\Release\Dynamic\MT | ||||
| rm MNN.* | ||||
| 
 | ||||
| ##### Release/Dynamic/MD #### | ||||
| log "Release/Dynamic/MD" | ||||
| Remove-Item CMakeCache.txt -ErrorAction Ignore | ||||
|  | @ -75,6 +85,12 @@ Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_M | |||
| cp MNN.lib, MNN.dll, MNN.pdb $PACKAGE_LIB_PATH\Release\Dynamic\MD | ||||
| rm MNN.* | ||||
| 
 | ||||
| ##### Release/Static/MT #### | ||||
| log "Release/Static/MT" | ||||
| Remove-Item CMakeCache.txt -ErrorAction Ignore | ||||
| Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=ON -DMNN_BUILD_SHARED_LIBS=OFF .." | ||||
| cp MNN.lib $PACKAGE_LIB_PATH\Release\Static\MT | ||||
| 
 | ||||
| ##### Release/Static/MD #### | ||||
| log "Release/Static/MD" | ||||
| Remove-Item CMakeCache.txt -ErrorAction Ignore | ||||
|  |  | |||
|  | @ -1839,6 +1839,7 @@ | |||
| 		488873A8215B639D0079B12E /* source */ = { | ||||
| 			isa = PBXGroup; | ||||
| 			children = ( | ||||
| 				CE482EF5288536DA007CD935 /* internal */, | ||||
| 				4DF87C482887D3560003E2D4 /* calib3d */, | ||||
| 				4D4CF4612760946500A36D9F /* imgproc */, | ||||
| 				4A5BEC6226AAB3D70032F6BD /* common */, | ||||
|  | @ -2873,15 +2874,18 @@ | |||
| 				CEA82BDC2A15F8AD002CBC95 /* IdstConvolutionInt8.hpp in Headers */, | ||||
| 				4DE4E82C275E307B0016A916 /* cv in Headers */, | ||||
| 				1F501F842397BA5B004E8721 /* ImageProcess.hpp in Headers */, | ||||
| 				CECF8C5D299CACFD00D3875B /* Log.hpp in Headers */, | ||||
| 				1F501F822397BA5B004E8721 /* Interpreter.hpp in Headers */, | ||||
| 				C4F906B327688C3A0026B847 /* NMSModule.hpp in Headers */, | ||||
| 				1F501F882397BA5B004E8721 /* Tensor.hpp in Headers */, | ||||
| 				1F501F872397BA5B004E8721 /* Matrix.h in Headers */, | ||||
| 				CECF8C5A299CACFD00D3875B /* WorkerThread.hpp in Headers */, | ||||
| 				48C84B85250F711700EE7666 /* IfModule.hpp in Headers */, | ||||
| 				4D9A937326255BDA00F9B43C /* CoreMLUnary.hpp in Headers */, | ||||
| 				48C84B98250F71E900EE7666 /* CPUSoftmax.hpp in Headers */, | ||||
| 				4882C8B8241A22B800DAC168 /* OpCommonUtils.hpp in Headers */, | ||||
| 				48608B54250632EC00CB1D71 /* GeometryComputer.hpp in Headers */, | ||||
| 				CECF8C7A299CAD9400D3875B /* sha1.h in Headers */, | ||||
| 				4894C6EC27016F7200D8BE79 /* CPUResizeCache.hpp in Headers */, | ||||
| 				92FF04A623AA0BFB00AC97F6 /* FileLoader.hpp in Headers */, | ||||
| 				48F34733273A7C8400C45394 /* ImageProcessFunction.hpp in Headers */, | ||||
|  | @ -2896,6 +2900,7 @@ | |||
| 				48925F352744AC0700919B37 /* CPUROIAlign.hpp in Headers */, | ||||
| 				92FF029623AA0B5A00AC97F6 /* CPUCast.hpp in Headers */, | ||||
| 				4D9A937826255BDA00F9B43C /* CoreMLBinary.hpp in Headers */, | ||||
| 				CECF8C85299CAD9400D3875B /* log_util.h in Headers */, | ||||
| 				489D7AB02550FDC900AD896A /* MetalDefine.h in Headers */, | ||||
| 				4D6D7FD52656896600F80814 /* DenseConvolutionTiledExecutor.hpp in Headers */, | ||||
| 				4D9A936626255BDA00F9B43C /* CoreMLExecutor.h in Headers */, | ||||
|  | @ -2905,6 +2910,7 @@ | |||
| 				1F501F802397BA5B004E8721 /* MNNDefine.h in Headers */, | ||||
| 				19D0FE76285C66F200B74B1A /* MetalLayerNorm.hpp in Headers */, | ||||
| 				489D7A682550FDC800AD896A /* MetalReduction.hpp in Headers */, | ||||
| 				CECF8C86299CAD9400D3875B /* sds.h in Headers */, | ||||
| 				1F501F7F2397BA5B004E8721 /* HalideRuntime.h in Headers */, | ||||
| 				92FF029E23AA0B5A00AC97F6 /* CPUDeconvolutionDepthwise.hpp in Headers */, | ||||
| 				4D9A935B26255BDA00F9B43C /* NeuralNetwork.pb-c.h in Headers */, | ||||
|  | @ -2925,8 +2931,10 @@ | |||
| 				481C2DEE25FE2CD6001ED6DF /* Arm82Functions.hpp in Headers */, | ||||
| 				4894C6EA27016F7200D8BE79 /* UnaryUtils.hpp in Headers */, | ||||
| 				EBD4842A2485FF650083CE95 /* Arm82Interp.hpp in Headers */, | ||||
| 				CECF8C81299CAD9400D3875B /* log_util_imp.h in Headers */, | ||||
| 				92FF037623AA0B5A00AC97F6 /* CPUBinary.hpp in Headers */, | ||||
| 				4D9A935826255BDA00F9B43C /* FeatureTypes.pb-c.h in Headers */, | ||||
| 				CECF8C7C299CAD9400D3875B /* hmac-sha.h in Headers */, | ||||
| 				48608B53250632EC00CB1D71 /* GeometryComputerUtils.hpp in Headers */, | ||||
| 				950B28F529F629A90002F454 /* CPUBinaryInt8.hpp in Headers */, | ||||
| 				489D7A732550FDC800AD896A /* MetalBackend.hpp in Headers */, | ||||
|  | @ -2948,6 +2956,7 @@ | |||
| 				4DF87C522887D3F20003E2D4 /* CPUSvd.hpp in Headers */, | ||||
| 				48747D4B245D9D24000B9709 /* RuntimeFactory.hpp in Headers */, | ||||
| 				92FF03B323AA0B5A00AC97F6 /* ConvolutionDepthwise3x3.hpp in Headers */, | ||||
| 				CECF8C77299CAD9400D3875B /* log_builder.h in Headers */, | ||||
| 				4D9A937226255BDA00F9B43C /* CoreMLConvolution.hpp in Headers */, | ||||
| 				92FF038B23AA0B5A00AC97F6 /* CPUUnravelIndex.hpp in Headers */, | ||||
| 				4AF4FB26269ED235005BA97B /* SparseConvInt8TiledExecutor.hpp in Headers */, | ||||
|  | @ -2983,6 +2992,7 @@ | |||
| 				92FF03CA23AA0B5A00AC97F6 /* CPUConvolutionDepthwise.hpp in Headers */, | ||||
| 				92FF04A923AA0BFB00AC97F6 /* Schedule.hpp in Headers */, | ||||
| 				489D7A9F2550FDC900AD896A /* MetalConvolutionCommon.hpp in Headers */, | ||||
| 				CECF8C80299CAD9400D3875B /* lz4.h in Headers */, | ||||
| 				92FF028623AA0B5A00AC97F6 /* CPUDeconvolution.hpp in Headers */, | ||||
| 				489D7A722550FDC800AD896A /* MetalReLU6.hpp in Headers */, | ||||
| 				92FF04B523AA0BFB00AC97F6 /* TensorUtils.hpp in Headers */, | ||||
|  | @ -3035,20 +3045,24 @@ | |||
| 				92FF03A623AA0B5A00AC97F6 /* ConvolutionTiledExecutor.hpp in Headers */, | ||||
| 				92FF036523AA0B5A00AC97F6 /* CPUResize.hpp in Headers */, | ||||
| 				92FF04B423AA0BFB00AC97F6 /* MNNMemoryUtils.h in Headers */, | ||||
| 				CECF8C88299CAD9400D3875B /* log_api.h in Headers */, | ||||
| 				4A224A0D27D0C2D9000A9260 /* ConvolutionPackWinograd.hpp in Headers */, | ||||
| 				4A224A0E27D0C2D9000A9260 /* ConvolutionPackFreeWinograd.hpp in Headers */, | ||||
| 				4D9A937426255BDA00F9B43C /* CoreMLReduction.hpp in Headers */, | ||||
| 				48C84B8B250F711700EE7666 /* PipelineModule.hpp in Headers */, | ||||
| 				F41497D7278D8A21004A363A /* RuntimeAttr.hpp in Headers */, | ||||
| 				CECF8C5B299CACFD00D3875B /* LogHelper.hpp in Headers */, | ||||
| 				92FF04C123AA0BFB00AC97F6 /* Backend.hpp in Headers */, | ||||
| 				482BFBCD28351BA1009210E4 /* ShaderMap.hpp in Headers */, | ||||
| 				489D7A812550FDC900AD896A /* MetalPooling.hpp in Headers */, | ||||
| 				CECF8C7F299CAD9400D3875B /* md5.h in Headers */, | ||||
| 				92FF02A623AA0B5A00AC97F6 /* CPUQuantizedMaxPool.hpp in Headers */, | ||||
| 				92FF028023AA0B5A00AC97F6 /* CPUFloatToInt8.hpp in Headers */, | ||||
| 				92FF028723AA0B5A00AC97F6 /* CPUFixedPoint.hpp in Headers */, | ||||
| 				C43C8227251894F400A0FF84 /* Vec.hpp in Headers */, | ||||
| 				4819FB1D24C138DF0050BD09 /* GeometryConvUtils.hpp in Headers */, | ||||
| 				489D7A952550FDC900AD896A /* MetalMatMul.hpp in Headers */, | ||||
| 				CECF8C83299CAD9400D3875B /* log_define.h in Headers */, | ||||
| 				C48CAE2628900C4A00271A6D /* ConvInt8Winograd.hpp in Headers */, | ||||
| 				48F34730273A7C7300C45394 /* CPUImageProcess.hpp in Headers */, | ||||
| 				489D7A702550FDC800AD896A /* MetalRaster.hpp in Headers */, | ||||
|  | @ -3272,6 +3286,7 @@ | |||
| 				489D7A8A2550FDC900AD896A /* MetalConvolutionDepthwise.mm in Sources */, | ||||
| 				48123003269EA83400EB7ABA /* ShapeUnique.cpp in Sources */, | ||||
| 				92FF037D23AA0B5A00AC97F6 /* CPURelu.cpp in Sources */, | ||||
| 				CECF8C5E299CACFD00D3875B /* WorkerThread.cpp in Sources */, | ||||
| 				489D7A842550FDC900AD896A /* MetalBinary.mm in Sources */, | ||||
| 				48747D6B245D9E33000B9709 /* GeometryFill.cpp in Sources */, | ||||
| 				4819FB1F24C138DF0050BD09 /* GeometryConvUtils.cpp in Sources */, | ||||
|  | @ -3370,6 +3385,7 @@ | |||
| 				48F34734273A7C8400C45394 /* ImageProcessFunction.cpp in Sources */, | ||||
| 				6A131E4025823349002EC3D6 /* PluginKernel.cpp in Sources */, | ||||
| 				48958781268EBA6F00EA01A7 /* CPUSegmentMean.cpp in Sources */, | ||||
| 				CECF8C7B299CAD9400D3875B /* sha1.c in Sources */, | ||||
| 				4D9A937026255BDA00F9B43C /* CoreMLUnary.cpp in Sources */, | ||||
| 				92FF04A823AA0BFB00AC97F6 /* AutoTime.cpp in Sources */, | ||||
| 				92FF04AE23AA0BFB00AC97F6 /* Backend.cpp in Sources */, | ||||
|  | @ -3424,6 +3440,7 @@ | |||
| 				92FF03CE23AA0B5A00AC97F6 /* CPUOPRegister.cpp in Sources */, | ||||
| 				92FF02B323AA0B5A00AC97F6 /* CPUInstanceNorm.cpp in Sources */, | ||||
| 				4819FB2C24C1396A0050BD09 /* GeometryPoolGrad.cpp in Sources */, | ||||
| 				CECF8C7E299CAD9400D3875B /* log_builder.cpp in Sources */, | ||||
| 				92FF042223AA0B7100AC97F6 /* ShapeConcat.cpp in Sources */, | ||||
| 				4D6D7FD12656891400F80814 /* MNNPackedSparseMatMulEpx4.S in Sources */, | ||||
| 				4D5662CC299B76ED0031C1A1 /* MNNMaxPoolInt8.S in Sources */, | ||||
|  | @ -3499,6 +3516,7 @@ | |||
| 				4D759B2C25FF89EE0037B0B6 /* GeometryShape.cpp in Sources */, | ||||
| 				11A01A07258785EA00745FA7 /* MNNVectorTop1Float.S in Sources */, | ||||
| 				48747D6E245D9E33000B9709 /* GeometrySlice.cpp in Sources */, | ||||
| 				CECF8C7D299CAD9400D3875B /* md5.c in Sources */, | ||||
| 				92FF041923AA0B7100AC97F6 /* ShapeQuantizedMaxPool.cpp in Sources */, | ||||
| 				92FF038A23AA0B5A00AC97F6 /* CPURange.cpp in Sources */, | ||||
| 				CE125CC92A52BF6B003698C9 /* MNNBilinearLineC8.S in Sources */, | ||||
|  | @ -3555,7 +3573,9 @@ | |||
| 				92FF042E23AA0B7100AC97F6 /* ShapeProposal.cpp in Sources */, | ||||
| 				92FF025923AA0B5A00AC97F6 /* CPUPoolInt8.cpp in Sources */, | ||||
| 				92FF045B23AA0B7100AC97F6 /* ShapeShape.cpp in Sources */, | ||||
| 				CECF8C87299CAD9400D3875B /* sds.c in Sources */, | ||||
| 				4D6D7FD72656896D00F80814 /* SparseConvolutionTiledExecutor.cpp in Sources */, | ||||
| 				CECF8C82299CAD9400D3875B /* log_api.cpp in Sources */, | ||||
| 				92FF03A823AA0B5A00AC97F6 /* WinogradOptFunction.cpp in Sources */, | ||||
| 				950B28E229F627E00002F454 /* MNNBinarySubInt8.S in Sources */, | ||||
| 				950B28F029F627F70002F454 /* MNNBinarySubInt8.S in Sources */, | ||||
|  | @ -3566,6 +3586,7 @@ | |||
| 				4D9A936026255BDA00F9B43C /* Model.pb-c.c in Sources */, | ||||
| 				CE9AFED628E54E3300566949 /* CPUInterp3D.cpp in Sources */, | ||||
| 				C4F906B427688C3A0026B847 /* NMSModule.cpp in Sources */, | ||||
| 				CECF8C64299CAD8400D3875B /* LogHelper.mm in Sources */, | ||||
| 				48FA474523AA127B00172C3B /* Executor.cpp in Sources */, | ||||
| 				92FF02EA23AA0B5A00AC97F6 /* MNNGemmInt8AddBiasScale_16x4_Unit.S in Sources */, | ||||
| 				48A8A61A21D101DE00C2B9A7 /* Matrix_CV.cpp in Sources */, | ||||
|  | @ -3590,6 +3611,7 @@ | |||
| 				92FF027F23AA0B5A00AC97F6 /* CPUDeconvolutionDepthwise.cpp in Sources */, | ||||
| 				EBECA3A724643D5D0062C7A3 /* MNNQuantizeFP16_UNIT4.S in Sources */, | ||||
| 				92FF04A423AA0BFB00AC97F6 /* Interpreter.cpp in Sources */, | ||||
| 				CECF8C5C299CACFD00D3875B /* Log.cpp in Sources */, | ||||
| 				92FF045623AA0B7100AC97F6 /* ShapeReshape.cpp in Sources */, | ||||
| 				92FF032523AA0B5A00AC97F6 /* MNNConvDwF23SourceTransUnit.S in Sources */, | ||||
| 				92FF044423AA0B7100AC97F6 /* ShapeLSTM.cpp in Sources */, | ||||
|  | @ -3625,6 +3647,7 @@ | |||
| 				92FF02B623AA0B5A00AC97F6 /* CPUUnary.cpp in Sources */, | ||||
| 				92FF032723AA0B5A00AC97F6 /* MNNDeconvRunForUnitDepthWise.S in Sources */, | ||||
| 				CE7DC00028E2DE6B00797689 /* ShapeConvTranspose3D.cpp in Sources */, | ||||
| 				CECF8C78299CAD9400D3875B /* log_util_imp.cpp in Sources */, | ||||
| 				92FF02CA23AA0B5A00AC97F6 /* MNNUnPackC4.S in Sources */, | ||||
| 				48925F372744AC2A00919B37 /* ShapeROIAlign.cpp in Sources */, | ||||
| 				92FF02E723AA0B5A00AC97F6 /* MNNDeconvRunForUnitDepthWise.S in Sources */, | ||||
|  | @ -3649,11 +3672,13 @@ | |||
| 				92FF02FF23AA0B5A00AC97F6 /* MNNFloat2Int8.S in Sources */, | ||||
| 				4D9A937926255BDA00F9B43C /* CoreMLRaster.cpp in Sources */, | ||||
| 				48417FF224D13BF50056D9A7 /* GeometrySelect.cpp in Sources */, | ||||
| 				CECF8C84299CAD9400D3875B /* lz4.c in Sources */, | ||||
| 				489D7A7E2550FDC900AD896A /* MNNMetalContext.mm in Sources */, | ||||
| 				92FF033423AA0B5A00AC97F6 /* MNNUInt8ToInt16WithOffsetC4Common.S in Sources */, | ||||
| 				92FF036B23AA0B5A00AC97F6 /* CPUResize.cpp in Sources */, | ||||
| 				92FF02C723AA0B5A00AC97F6 /* MNNCopyC4WithStride.S in Sources */, | ||||
| 				92FF030923AA0B5A00AC97F6 /* MNNNV21ToBGRUnit.S in Sources */, | ||||
| 				CECF8C79299CAD9400D3875B /* hmac-sha.cpp in Sources */, | ||||
| 				92FF032623AA0B5A00AC97F6 /* MNNWinogradMatrixProductLeft.S in Sources */, | ||||
| 				92FF04C023AA0BFB00AC97F6 /* Tensor.cpp in Sources */, | ||||
| 				CEE9B95B2A3AA4D4006438F2 /* MNNBilinearLineC8.S in Sources */, | ||||
|  | @ -3981,10 +4006,12 @@ | |||
| 			isa = XCBuildConfiguration; | ||||
| 			buildSettings = { | ||||
| 				CODE_SIGN_IDENTITY = "Apple Development"; | ||||
| 				"CODE_SIGN_IDENTITY[sdk=iphoneos*]" = ""; | ||||
| 				"CODE_SIGN_IDENTITY[sdk=macosx*]" = "Apple Development"; | ||||
| 				CODE_SIGN_STYLE = Automatic; | ||||
| 				DEAD_CODE_STRIPPING = YES; | ||||
| 				DEFINES_MODULE = YES; | ||||
| 				DEVELOPMENT_TEAM = 6G7464HHUS; | ||||
| 				DEVELOPMENT_TEAM = Q48UX93J22; | ||||
| 				DYLIB_COMPATIBILITY_VERSION = 1; | ||||
| 				DYLIB_CURRENT_VERSION = 1; | ||||
| 				DYLIB_INSTALL_NAME_BASE = "@rpath"; | ||||
|  | @ -4027,7 +4054,7 @@ | |||
| 				METAL_LIBRARY_FILE_BASE = mnn; | ||||
| 				ONLY_ACTIVE_ARCH = YES; | ||||
| 				OTHER_CFLAGS = ""; | ||||
| 				PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.abcd111ss; | ||||
| 				PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.abcd111; | ||||
| 				PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)"; | ||||
| 				PROVISIONING_PROFILE_SPECIFIER = ""; | ||||
| 				"PROVISIONING_PROFILE_SPECIFIER[sdk=macosx*]" = ""; | ||||
|  | @ -4048,7 +4075,7 @@ | |||
| 				CODE_SIGN_STYLE = Automatic; | ||||
| 				DEAD_CODE_STRIPPING = YES; | ||||
| 				DEFINES_MODULE = YES; | ||||
| 				DEVELOPMENT_TEAM = 6G7464HHUS; | ||||
| 				DEVELOPMENT_TEAM = Q48UX93J22; | ||||
| 				DYLIB_COMPATIBILITY_VERSION = 1; | ||||
| 				DYLIB_CURRENT_VERSION = 1; | ||||
| 				DYLIB_INSTALL_NAME_BASE = "@rpath"; | ||||
|  | @ -4089,7 +4116,7 @@ | |||
| 				MACH_O_TYPE = staticlib; | ||||
| 				METAL_LIBRARY_FILE_BASE = mnn; | ||||
| 				OTHER_CFLAGS = ""; | ||||
| 				PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.abcd111ss; | ||||
| 				PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.abcd111; | ||||
| 				PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)"; | ||||
| 				PROVISIONING_PROFILE_SPECIFIER = ""; | ||||
| 				"PROVISIONING_PROFILE_SPECIFIER[sdk=macosx*]" = ""; | ||||
|  | @ -4108,7 +4135,7 @@ | |||
| 				ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; | ||||
| 				ASSETCATALOG_COMPILER_LAUNCHIMAGE_NAME = LaunchImage; | ||||
| 				CODE_SIGN_STYLE = Automatic; | ||||
| 				DEVELOPMENT_TEAM = 6G7464HHUS; | ||||
| 				DEVELOPMENT_TEAM = Q48UX93J22; | ||||
| 				GCC_ENABLE_CPP_EXCEPTIONS = NO; | ||||
| 				GCC_ENABLE_CPP_RTTI = NO; | ||||
| 				HEADER_SEARCH_PATHS = ( | ||||
|  | @ -4121,7 +4148,7 @@ | |||
| 				IPHONEOS_DEPLOYMENT_TARGET = 9.0; | ||||
| 				LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; | ||||
| 				OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; | ||||
| 				PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.abcd111ss; | ||||
| 				PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.abcd111; | ||||
| 				PRODUCT_NAME = "$(TARGET_NAME)"; | ||||
| 				TARGETED_DEVICE_FAMILY = "1,2"; | ||||
| 			}; | ||||
|  | @ -4133,7 +4160,7 @@ | |||
| 				ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; | ||||
| 				ASSETCATALOG_COMPILER_LAUNCHIMAGE_NAME = LaunchImage; | ||||
| 				CODE_SIGN_STYLE = Automatic; | ||||
| 				DEVELOPMENT_TEAM = 6G7464HHUS; | ||||
| 				DEVELOPMENT_TEAM = Q48UX93J22; | ||||
| 				GCC_ENABLE_CPP_EXCEPTIONS = NO; | ||||
| 				GCC_ENABLE_CPP_RTTI = NO; | ||||
| 				HEADER_SEARCH_PATHS = ( | ||||
|  | @ -4146,7 +4173,7 @@ | |||
| 				IPHONEOS_DEPLOYMENT_TARGET = 9.0; | ||||
| 				LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; | ||||
| 				OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; | ||||
| 				PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.abcd111ss; | ||||
| 				PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.abcd111; | ||||
| 				PRODUCT_NAME = "$(TARGET_NAME)"; | ||||
| 				TARGETED_DEVICE_FAMILY = "1,2"; | ||||
| 			}; | ||||
|  | @ -4260,3 +4287,4 @@ | |||
| 	}; | ||||
| 	rootObject = 0F1465AE1FA18D1000F9860A /* Project object */; | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -76,9 +76,15 @@ endif() | |||
| if(PYMNN_IMGPROC_FILTER) | ||||
|     target_compile_definitions(mnnpybridge PRIVATE PYMNN_IMGPROC_FILTER) | ||||
| endif() | ||||
| if(PYMNN_IMGPROC_HISTOGRAMS) | ||||
|     target_compile_definitions(mnnpybridge PRIVATE PYMNN_IMGPROC_HISTOGRAMS) | ||||
| endif() | ||||
| if(PYMNN_CALIB3D) | ||||
|     target_compile_definitions(mnnpybridge PRIVATE PYMNN_CALIB3D) | ||||
| endif() | ||||
| if(PYMNN_CVCORE) | ||||
|     target_compile_definitions(mnnpybridge PRIVATE PYMNN_CVCORE) | ||||
| endif() | ||||
| 
 | ||||
| if(PYMNN_INTERNAL_SERVING) | ||||
|     message(STATUS "mnnpybridge define PYMNN_INTERNAL_SERVING") | ||||
|  |  | |||
|  | @ -1,5 +1,4 @@ | |||
| import MNN.expr as _F | ||||
| import MNN.cv as _cv | ||||
| 
 | ||||
| # Linear algebra | ||||
| def norm(x, ord=None, axis=None, keepdims=False): | ||||
|  | @ -48,4 +47,5 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False): | |||
|     return (u, w, vt) | ||||
| 
 | ||||
| def solve(a, b): | ||||
|     import _mnncengine.cv as _cv | ||||
|     return _cv.solve(a, b)[1] | ||||
|  | @ -44,12 +44,15 @@ def build_deps(): | |||
|         shutil.rmtree(cmake_build_dir) | ||||
|     os.makedirs(cmake_build_dir) | ||||
|     os.chdir(cmake_build_dir) | ||||
|     extra_opts = '-DMNN_LOW_MEMORY=ON' | ||||
|     extra_opts += ' -DMNN_VULKAN=ON -DMNN_VULKAN_IMAGE=OFF' | ||||
|     extra_opts += ' -DMNN_OPENCL=ON' | ||||
|     if IS_WINDOWS: | ||||
|         os.system('cmake -G "Ninja" -DMNN_BUILD_TRAIN=ON -DMNN_BUILD_CONVERTER=on -DMNN_BUILD_TORCH=OFF\ | ||||
|         os.system('cmake -G "Ninja" ' + extra_opts +' -DMNN_BUILD_TRAIN=ON -DMNN_BUILD_CONVERTER=on -DMNN_BUILD_TORCH=OFF\ | ||||
|             -DMNN_BUILD_SHARED_LIBS=OFF -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=ON\ | ||||
|             -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON -DMNN_AAPL_FMWK=OFF -DMNN_SEP_BUILD=OFF .. && ninja MNN MNNTrain MNNConvert') | ||||
|     elif IS_LINUX: | ||||
|         extra_opts = '-DMNN_TENSORRT=ON \ | ||||
|         extra_opts += '-DMNN_TENSORRT=ON \ | ||||
|         -DCMAKE_LIBRARY_PATH=/usr/local/cuda/lib64/stubs/ ' if USE_TRT else ' ' | ||||
|         extra_opts += ' -DMNN_INTERNAL=ON ' if IS_INTERNAL_BUILD else ' ' | ||||
|         extra_opts += ' -DMNN_BUILD_TORCH=ON ' if IS_BUILD_TORCH else ' ' | ||||
|  | @ -59,11 +62,11 @@ def build_deps(): | |||
|             -DMNN_BUILD_SHARED_LIBS=OFF -DMNN_AAPL_FMWK=OFF -DMNN_SEP_BUILD=OFF -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON \ | ||||
|             -DMNN_USE_THREAD_POOL=ON -DMNN_OPENMP=OFF .. && make MNN MNNTrain MNNConvert -j4') | ||||
|     else: | ||||
|         extra_opts = ' -DMNN_INTERNAL=ON ' if IS_INTERNAL_BUILD else ' ' | ||||
|         extra_opts += ' -DMNN_INTERNAL=ON ' if IS_INTERNAL_BUILD else ' ' | ||||
|         extra_opts += ' -DMNN_BUILD_TORCH=ON ' if IS_BUILD_TORCH else ' ' | ||||
|         print(extra_opts) | ||||
|         os.system('cmake ' + extra_opts + '-DMNN_BUILD_CONVERTER=on -DMNN_BUILD_TRAIN=ON -DCMAKE_BUILD_TYPE=Release \ | ||||
|             -DMNN_BUILD_SHARED_LIBS=OFF -DMNN_AAPL_FMWK=OFF -DMNN_SEP_BUILD=OFF -DMNN_EXPR_SHAPE_EAGER=ON -DMNN_TRAIN_DEBUG=ON\ | ||||
|             -DMNN_BUILD_SHARED_LIBS=OFF -DMNN_AAPL_FMWK=OFF -DMNN_SEP_BUILD=OFF\ | ||||
|             -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON \ | ||||
|             .. && make MNN MNNTrain MNNConvert -j4') | ||||
| ################################################################################ | ||||
|  |  | |||
|  | @ -21,19 +21,16 @@ using Vec4 = MNN::Math::Vec<float, 4>; | |||
| namespace MNN { | ||||
| 
 | ||||
| ErrorCode CPUBinary::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) { | ||||
|     const int input0DataCount = ((CPUBackend*)backend())->getTensorSize(inputs[0]); | ||||
|     const int input1DataCount = ((CPUBackend*)backend())->getTensorSize(inputs[1]); | ||||
|     auto input0DataCount = TensorUtils::getRawSize(inputs[0]); | ||||
|     auto input1DataCount = TensorUtils::getRawSize(inputs[1]); | ||||
|     if (input1DataCount == input0DataCount) { | ||||
|         mNeedBroadcastIndex = -1; | ||||
|         mTotalSize = input1DataCount; | ||||
|     } else if (input0DataCount == 1) { | ||||
|         mNeedBroadcastIndex = 0; | ||||
|         mTotalSize = input1DataCount; | ||||
|     } else { | ||||
|         mNeedBroadcastIndex = 1; | ||||
|         mTotalSize = input0DataCount; | ||||
|     } | ||||
|     MNN_ASSERT(mTotalSize == ((CPUBackend*)backend())->getTensorSize(outputs[0])); | ||||
|     mTotalSize = ((CPUBackend*)backend())->getTensorSize(outputs[0]); | ||||
|      | ||||
|     if(mActivationType == 1 && outputs[0]->getType().code == halide_type_float) { | ||||
|         mActivationExe.reset(new CPURelu(backend(), 0.0)); | ||||
|  |  | |||
|  | @ -19,19 +19,16 @@ | |||
| namespace MNN { | ||||
| 
 | ||||
| ErrorCode CPUBinaryInt8::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) { | ||||
|     const int input0DataCount = ((CPUBackend*)backend())->getTensorSize(inputs[0]); | ||||
|     const int input1DataCount = ((CPUBackend*)backend())->getTensorSize(inputs[1]); | ||||
|     auto input0DataCount = TensorUtils::getRawSize(inputs[0]); | ||||
|     auto input1DataCount = TensorUtils::getRawSize(inputs[1]); | ||||
|     if (input1DataCount == input0DataCount) { | ||||
|         mNeedBroadcastIndex = -1; | ||||
|         mTotalSize = input1DataCount; | ||||
|     } else if (input0DataCount == 1) { | ||||
|         mNeedBroadcastIndex = 0; | ||||
|         mTotalSize = input1DataCount; | ||||
|     } else { | ||||
|         mNeedBroadcastIndex = 1; | ||||
|         mTotalSize = input0DataCount; | ||||
|     } | ||||
|     MNN_ASSERT(mTotalSize == ((CPUBackend*)backend())->getTensorSize(outputs[0])); | ||||
|     mTotalSize = ((CPUBackend*)backend())->getTensorSize(outputs[0]); | ||||
| 
 | ||||
|     auto core = static_cast<CPUBackend*>(backend())->functions(); | ||||
| 
 | ||||
|  |  | |||
|  | @ -835,10 +835,12 @@ public: | |||
|                     auto dstIter = *(iter0 + iter0Stride * iter); | ||||
|                     auto srcOffset = srcIter * step1 + srcView->offset(); | ||||
|                     auto dstOffset = dstIter * step0 + dstView->offset(); | ||||
|                     if (srcOffset >= 0 && srcOffset < inputSize) { | ||||
|                         _blit(reg, bytes, input->host<uint8_t>() + bytes * srcOffset, output->host<uint8_t>() + bytes * dstOffset, proc); | ||||
|                     } else { | ||||
|                         _zero(reg, bytes, output->host<uint8_t>() + bytes * dstOffset); | ||||
|                     if (dstOffset >= 0) { | ||||
|                         if (srcOffset >= 0 && srcOffset < inputSize) { | ||||
|                             _blit(reg, bytes, input->host<uint8_t>() + bytes * srcOffset, output->host<uint8_t>() + bytes * dstOffset, proc); | ||||
|                         } else { | ||||
|                             _zero(reg, bytes, output->host<uint8_t>() + bytes * dstOffset); | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|                 return NO_ERROR; | ||||
|  |  | |||
|  | @ -16,6 +16,10 @@ | |||
| asm_function MNNFloat2Int8 | ||||
| //void MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, float* scale, size_t aMin, size_t aMax, size_t zeroPoint);
 | ||||
| //x0:src, x1:dst, x2:sizeQuad, x3:scale, x4:aMin, x5:aMax, x6:zeroPoint | ||||
| stp d14, d15, [sp, #-64]! | ||||
| stp d12, d13, [sp, #16] | ||||
| stp d10, d11, [sp, #32] | ||||
| stp d8,  d9,  [sp, #48] | ||||
| 
 | ||||
| ld1 {v31.4s}, [x3] | ||||
| 
 | ||||
|  | @ -23,12 +27,380 @@ dup v30.16b, w4 | |||
| dup v29.16b, w5 | ||||
| 
 | ||||
| // copy zero point | ||||
| mov v28.s[0], w6 | ||||
| mov v28.s[1], w6 | ||||
| mov v28.s[2], w6 | ||||
| mov v28.s[3], w6 | ||||
| dup v28.4s, w6 | ||||
| scvtf v28.4s, v28.4s | ||||
| 
 | ||||
| FL32: | ||||
| cmp x2, #32 | ||||
| ble FL16 | ||||
| 
 | ||||
| FLLoop32: | ||||
| ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 | ||||
| ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64 | ||||
| ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64 | ||||
| ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64 | ||||
| ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64 | ||||
| ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 | ||||
| // ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 | ||||
| // ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0], #64 | ||||
| fmul v0.4s, v0.4s, v31.4s | ||||
| fadd v0.4s, v0.4s, v28.4s | ||||
| fmul v1.4s, v1.4s, v31.4s | ||||
| fadd v1.4s, v1.4s, v28.4s | ||||
| fmul v2.4s, v2.4s, v31.4s | ||||
| fadd v2.4s, v2.4s, v28.4s | ||||
| fmul v3.4s, v3.4s, v31.4s | ||||
| fadd v3.4s, v3.4s, v28.4s | ||||
| 
 | ||||
| fmul v4.4s, v4.4s, v31.4s | ||||
| fadd v4.4s, v4.4s, v28.4s | ||||
| fmul v5.4s, v5.4s, v31.4s | ||||
| fadd v5.4s, v5.4s, v28.4s | ||||
| fmul v6.4s, v6.4s, v31.4s | ||||
| fadd v6.4s, v6.4s, v28.4s | ||||
| fmul v7.4s, v7.4s, v31.4s | ||||
| fadd v7.4s, v7.4s, v28.4s | ||||
| 
 | ||||
| fmul v8.4s, v8.4s, v31.4s | ||||
| fadd v8.4s, v8.4s, v28.4s | ||||
| fmul v9.4s, v9.4s, v31.4s | ||||
| fadd v9.4s, v9.4s, v28.4s | ||||
| fmul v10.4s, v10.4s, v31.4s | ||||
| fadd v10.4s, v10.4s, v28.4s | ||||
| fmul v11.4s, v11.4s, v31.4s | ||||
| fadd v11.4s, v11.4s, v28.4s | ||||
| 
 | ||||
| fmul v12.4s, v12.4s, v31.4s | ||||
| fadd v12.4s, v12.4s, v28.4s | ||||
| fmul v13.4s, v13.4s, v31.4s | ||||
| fadd v13.4s, v13.4s, v28.4s | ||||
| fmul v14.4s, v14.4s, v31.4s | ||||
| fadd v14.4s, v14.4s, v28.4s | ||||
| fmul v15.4s, v15.4s, v31.4s | ||||
| fadd v15.4s, v15.4s, v28.4s | ||||
| 
 | ||||
| 
 | ||||
| fmul v16.4s, v16.4s, v31.4s | ||||
| fadd v16.4s, v16.4s, v28.4s | ||||
| fmul v17.4s, v17.4s, v31.4s | ||||
| fadd v17.4s, v17.4s, v28.4s | ||||
| fmul v18.4s, v18.4s, v31.4s | ||||
| fadd v18.4s, v18.4s, v28.4s | ||||
| fmul v19.4s, v19.4s, v31.4s | ||||
| fadd v19.4s, v19.4s, v28.4s | ||||
| 
 | ||||
| fmul v20.4s, v20.4s, v31.4s | ||||
| fadd v20.4s, v20.4s, v28.4s | ||||
| fmul v21.4s, v21.4s, v31.4s | ||||
| fadd v21.4s, v21.4s, v28.4s | ||||
| fmul v22.4s, v22.4s, v31.4s | ||||
| fadd v22.4s, v22.4s, v28.4s | ||||
| fmul v23.4s, v23.4s, v31.4s | ||||
| fadd v23.4s, v23.4s, v28.4s | ||||
| 
 | ||||
| fcvtas v0.4s, v0.4s | ||||
| fcvtas v1.4s, v1.4s | ||||
| fcvtas v2.4s, v2.4s | ||||
| fcvtas v3.4s, v3.4s | ||||
| fcvtas v4.4s, v4.4s | ||||
| fcvtas v5.4s, v5.4s | ||||
| fcvtas v6.4s, v6.4s | ||||
| fcvtas v7.4s, v7.4s | ||||
| 
 | ||||
| fcvtas v8.4s, v8.4s | ||||
| fcvtas v9.4s, v9.4s | ||||
| fcvtas v10.4s, v10.4s | ||||
| fcvtas v11.4s, v11.4s | ||||
| fcvtas v12.4s, v12.4s | ||||
| fcvtas v13.4s, v13.4s | ||||
| fcvtas v14.4s, v14.4s | ||||
| fcvtas v15.4s, v15.4s | ||||
| 
 | ||||
| fcvtas v16.4s, v16.4s | ||||
| fcvtas v17.4s, v17.4s | ||||
| fcvtas v18.4s, v18.4s | ||||
| fcvtas v19.4s, v19.4s | ||||
| fcvtas v20.4s, v20.4s | ||||
| fcvtas v21.4s, v21.4s | ||||
| fcvtas v22.4s, v22.4s | ||||
| fcvtas v23.4s, v23.4s | ||||
| 
 | ||||
| 
 | ||||
| sqxtn v24.4h, v0.4s | ||||
| sqxtn2 v24.8h, v1.4s | ||||
| sqxtn v25.4h, v2.4s | ||||
| sqxtn2 v25.8h, v3.4s | ||||
| sqxtn v26.4h, v4.4s | ||||
| sqxtn2 v26.8h, v5.4s | ||||
| sqxtn v27.4h, v6.4s | ||||
| sqxtn2 v27.8h, v7.4s | ||||
| 
 | ||||
| sqxtn v0.4h, v8.4s | ||||
| sqxtn2 v0.8h, v9.4s | ||||
| sqxtn v1.4h, v10.4s | ||||
| sqxtn2 v1.8h, v11.4s | ||||
| sqxtn v2.4h, v12.4s | ||||
| sqxtn2 v2.8h, v13.4s | ||||
| sqxtn v3.4h, v14.4s | ||||
| sqxtn2 v3.8h, v15.4s | ||||
| 
 | ||||
| sqxtn v4.4h, v16.4s | ||||
| sqxtn2 v4.8h, v17.4s | ||||
| sqxtn v5.4h, v18.4s | ||||
| sqxtn2 v5.8h, v19.4s | ||||
| sqxtn v6.4h, v20.4s | ||||
| sqxtn2 v6.8h, v21.4s | ||||
| sqxtn v7.4h, v22.4s | ||||
| sqxtn2 v7.8h, v23.4s | ||||
| 
 | ||||
| ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64 | ||||
| ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64 | ||||
| 
 | ||||
| sqxtn v24.8b, v24.8h | ||||
| sqxtn2 v24.16b, v25.8h | ||||
| sqxtn v26.8b, v26.8h | ||||
| sqxtn2 v26.16b, v27.8h | ||||
| sqxtn v0.8b, v0.8h | ||||
| sqxtn2 v0.16b, v1.8h | ||||
| sqxtn v2.8b, v2.8h | ||||
| sqxtn2 v2.16b, v3.8h | ||||
| 
 | ||||
| sqxtn v4.8b, v4.8h | ||||
| sqxtn v6.8b, v6.8h | ||||
| sqxtn2 v4.16b, v5.8h | ||||
| sqxtn2 v6.16b, v7.8h | ||||
| 
 | ||||
| fmul v8.4s, v8.4s, v31.4s | ||||
| fadd v8.4s, v8.4s, v28.4s | ||||
| fmul v9.4s, v9.4s, v31.4s | ||||
| fadd v9.4s, v9.4s, v28.4s | ||||
| fmul v10.4s, v10.4s, v31.4s | ||||
| fadd v10.4s, v10.4s, v28.4s | ||||
| fmul v11.4s, v11.4s, v31.4s | ||||
| fadd v11.4s, v11.4s, v28.4s | ||||
| 
 | ||||
| fmul v12.4s, v12.4s, v31.4s | ||||
| fadd v12.4s, v12.4s, v28.4s | ||||
| fmul v13.4s, v13.4s, v31.4s | ||||
| fadd v13.4s, v13.4s, v28.4s | ||||
| fmul v14.4s, v14.4s, v31.4s | ||||
| fadd v14.4s, v14.4s, v28.4s | ||||
| fmul v15.4s, v15.4s, v31.4s | ||||
| fadd v15.4s, v15.4s, v28.4s | ||||
| 
 | ||||
| fcvtas v8.4s, v8.4s | ||||
| fcvtas v9.4s, v9.4s | ||||
| fcvtas v10.4s, v10.4s | ||||
| fcvtas v11.4s, v11.4s | ||||
| fcvtas v12.4s, v12.4s | ||||
| fcvtas v13.4s, v13.4s | ||||
| fcvtas v14.4s, v14.4s | ||||
| fcvtas v15.4s, v15.4s | ||||
| 
 | ||||
| sqxtn v16.4h, v8.4s | ||||
| sqxtn2 v16.8h, v9.4s | ||||
| sqxtn v17.4h, v10.4s | ||||
| sqxtn2 v17.8h, v11.4s | ||||
| sqxtn v18.4h, v12.4s | ||||
| sqxtn2 v18.8h, v13.4s | ||||
| sqxtn v19.4h, v14.4s | ||||
| sqxtn2 v19.8h, v15.4s | ||||
| 
 | ||||
| smin v24.16b, v24.16b, v29.16b | ||||
| smax v24.16b, v24.16b, v30.16b | ||||
| smin v25.16b, v26.16b, v29.16b | ||||
| smax v25.16b, v25.16b, v30.16b | ||||
| 
 | ||||
| sqxtn v20.8b, v16.8h | ||||
| sqxtn2 v20.16b, v17.8h | ||||
| sqxtn v21.8b, v18.8h | ||||
| sqxtn2 v21.16b, v19.8h | ||||
| 
 | ||||
| smin v26.16b, v0.16b, v29.16b | ||||
| smax v26.16b, v26.16b, v30.16b | ||||
| smin v27.16b, v2.16b, v29.16b | ||||
| smax v27.16b, v27.16b, v30.16b | ||||
| 
 | ||||
| smin v12.16b, v4.16b, v29.16b | ||||
| smax v12.16b, v12.16b, v30.16b | ||||
| smin v13.16b, v6.16b, v29.16b | ||||
| smax v13.16b, v13.16b, v30.16b | ||||
| 
 | ||||
| smin v14.16b, v20.16b, v29.16b | ||||
| smax v14.16b, v14.16b, v30.16b | ||||
| smin v15.16b, v21.16b, v29.16b | ||||
| smax v15.16b, v15.16b, v30.16b | ||||
| 
 | ||||
| st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x1], #64 | ||||
| st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64 | ||||
| 
 | ||||
| sub x2, x2, #32 | ||||
| cmp x2, #32 | ||||
| bge FLLoop32 | ||||
| 
 | ||||
| FL16: | ||||
| cmp x2, #16 | ||||
| ble FL8 | ||||
| 
 | ||||
| FLLoop16: | ||||
| ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 | ||||
| ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64 | ||||
| ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64 | ||||
| ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64 | ||||
| fmul v0.4s, v0.4s, v31.4s | ||||
| fadd v0.4s, v0.4s, v28.4s | ||||
| fmul v1.4s, v1.4s, v31.4s | ||||
| fadd v1.4s, v1.4s, v28.4s | ||||
| fmul v2.4s, v2.4s, v31.4s | ||||
| fadd v2.4s, v2.4s, v28.4s | ||||
| fmul v3.4s, v3.4s, v31.4s | ||||
| fadd v3.4s, v3.4s, v28.4s | ||||
| 
 | ||||
| fmul v4.4s, v4.4s, v31.4s | ||||
| fadd v4.4s, v4.4s, v28.4s | ||||
| fmul v5.4s, v5.4s, v31.4s | ||||
| fadd v5.4s, v5.4s, v28.4s | ||||
| fmul v6.4s, v6.4s, v31.4s | ||||
| fadd v6.4s, v6.4s, v28.4s | ||||
| fmul v7.4s, v7.4s, v31.4s | ||||
| fadd v7.4s, v7.4s, v28.4s | ||||
| 
 | ||||
| fmul v8.4s, v8.4s, v31.4s | ||||
| fadd v8.4s, v8.4s, v28.4s | ||||
| fmul v9.4s, v9.4s, v31.4s | ||||
| fadd v9.4s, v9.4s, v28.4s | ||||
| fmul v10.4s, v10.4s, v31.4s | ||||
| fadd v10.4s, v10.4s, v28.4s | ||||
| fmul v11.4s, v11.4s, v31.4s | ||||
| fadd v11.4s, v11.4s, v28.4s | ||||
| 
 | ||||
| fmul v12.4s, v12.4s, v31.4s | ||||
| fadd v12.4s, v12.4s, v28.4s | ||||
| fmul v13.4s, v13.4s, v31.4s | ||||
| fadd v13.4s, v13.4s, v28.4s | ||||
| fmul v14.4s, v14.4s, v31.4s | ||||
| fadd v14.4s, v14.4s, v28.4s | ||||
| fmul v15.4s, v15.4s, v31.4s | ||||
| fadd v15.4s, v15.4s, v28.4s | ||||
| 
 | ||||
| fcvtas v0.4s, v0.4s | ||||
| fcvtas v1.4s, v1.4s | ||||
| fcvtas v2.4s, v2.4s | ||||
| fcvtas v3.4s, v3.4s | ||||
| fcvtas v4.4s, v4.4s | ||||
| fcvtas v5.4s, v5.4s | ||||
| fcvtas v6.4s, v6.4s | ||||
| fcvtas v7.4s, v7.4s | ||||
| 
 | ||||
| fcvtas v8.4s, v8.4s | ||||
| fcvtas v9.4s, v9.4s | ||||
| fcvtas v10.4s, v10.4s | ||||
| fcvtas v11.4s, v11.4s | ||||
| fcvtas v12.4s, v12.4s | ||||
| fcvtas v13.4s, v13.4s | ||||
| fcvtas v14.4s, v14.4s | ||||
| fcvtas v15.4s, v15.4s | ||||
| 
 | ||||
| sqxtn v16.4h, v0.4s | ||||
| sqxtn2 v16.8h, v1.4s | ||||
| sqxtn v17.4h, v2.4s | ||||
| sqxtn2 v17.8h, v3.4s | ||||
| sqxtn v18.4h, v4.4s | ||||
| sqxtn2 v18.8h, v5.4s | ||||
| sqxtn v19.4h, v6.4s | ||||
| sqxtn2 v19.8h, v7.4s | ||||
| 
 | ||||
| sqxtn v20.4h, v8.4s | ||||
| sqxtn2 v20.8h, v9.4s | ||||
| sqxtn v21.4h, v10.4s | ||||
| sqxtn2 v21.8h, v11.4s | ||||
| sqxtn v22.4h, v12.4s | ||||
| sqxtn2 v22.8h, v13.4s | ||||
| sqxtn v23.4h, v14.4s | ||||
| sqxtn2 v23.8h, v15.4s | ||||
| 
 | ||||
| sqxtn v24.8b, v16.8h | ||||
| sqxtn2 v24.16b, v17.8h | ||||
| sqxtn v25.8b, v18.8h | ||||
| sqxtn2 v25.16b, v19.8h | ||||
| sqxtn v26.8b, v20.8h | ||||
| sqxtn2 v26.16b, v21.8h | ||||
| sqxtn v27.8b, v22.8h | ||||
| sqxtn2 v27.16b, v23.8h | ||||
| smin v24.16b, v24.16b, v29.16b | ||||
| smax v24.16b, v24.16b, v30.16b | ||||
| smin v25.16b, v25.16b, v29.16b | ||||
| smax v25.16b, v25.16b, v30.16b | ||||
| smin v26.16b, v26.16b, v29.16b | ||||
| smax v26.16b, v26.16b, v30.16b | ||||
| smin v27.16b, v27.16b, v29.16b | ||||
| smax v27.16b, v27.16b, v30.16b | ||||
| 
 | ||||
| st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x1], #64 | ||||
| 
 | ||||
| sub x2, x2, #16 | ||||
| cmp x2, #16 | ||||
| bge FLLoop16 | ||||
| 
 | ||||
| FL8: | ||||
| cmp x2, #8 | ||||
| ble FL4 | ||||
| 
 | ||||
| FLLoop8: | ||||
| ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 | ||||
| ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64 | ||||
| fmul v0.4s, v0.4s, v31.4s | ||||
| fadd v0.4s, v0.4s, v28.4s | ||||
| fmul v1.4s, v1.4s, v31.4s | ||||
| fadd v1.4s, v1.4s, v28.4s | ||||
| fmul v2.4s, v2.4s, v31.4s | ||||
| fadd v2.4s, v2.4s, v28.4s | ||||
| fmul v3.4s, v3.4s, v31.4s | ||||
| fadd v3.4s, v3.4s, v28.4s | ||||
| 
 | ||||
| fmul v4.4s, v4.4s, v31.4s | ||||
| fadd v4.4s, v4.4s, v28.4s | ||||
| fmul v5.4s, v5.4s, v31.4s | ||||
| fadd v5.4s, v5.4s, v28.4s | ||||
| fmul v6.4s, v6.4s, v31.4s | ||||
| fadd v6.4s, v6.4s, v28.4s | ||||
| fmul v7.4s, v7.4s, v31.4s | ||||
| fadd v7.4s, v7.4s, v28.4s | ||||
| 
 | ||||
| fcvtas v0.4s, v0.4s | ||||
| fcvtas v1.4s, v1.4s | ||||
| fcvtas v2.4s, v2.4s | ||||
| fcvtas v3.4s, v3.4s | ||||
| fcvtas v4.4s, v4.4s | ||||
| fcvtas v5.4s, v5.4s | ||||
| fcvtas v6.4s, v6.4s | ||||
| fcvtas v7.4s, v7.4s | ||||
| 
 | ||||
| sqxtn v8.4h, v0.4s | ||||
| sqxtn2 v8.8h, v1.4s | ||||
| sqxtn v9.4h, v2.4s | ||||
| sqxtn2 v9.8h, v3.4s | ||||
| sqxtn v10.4h, v4.4s | ||||
| sqxtn2 v10.8h, v5.4s | ||||
| sqxtn v11.4h, v6.4s | ||||
| sqxtn2 v11.8h, v7.4s | ||||
| 
 | ||||
| sqxtn v12.8b, v8.8h | ||||
| sqxtn2 v12.16b, v9.8h | ||||
| sqxtn v13.8b, v10.8h | ||||
| sqxtn2 v13.16b, v11.8h | ||||
| smin v12.16b, v12.16b, v29.16b | ||||
| smax v12.16b, v12.16b, v30.16b | ||||
| smin v13.16b, v13.16b, v29.16b | ||||
| smax v13.16b, v13.16b, v30.16b | ||||
| 
 | ||||
| st1 {v12.4s, v13.4s}, [x1], #32 | ||||
| 
 | ||||
| sub x2, x2, #8 | ||||
| cmp x2, #8 | ||||
| bge FLLoop8 | ||||
| 
 | ||||
| FL4: | ||||
| cmp x2, #3 | ||||
| ble FL1 | ||||
| 
 | ||||
|  | @ -89,6 +461,9 @@ subs x2, x2, #1 | |||
| bne FLLoop1 | ||||
| 
 | ||||
| FLEnd: | ||||
| 
 | ||||
| ldp d8,  d9,  [sp, #48] | ||||
| ldp d10, d11, [sp, #32] | ||||
| ldp d12, d13, [sp, #16] | ||||
| ldp d14, d15, [sp], #64 | ||||
| ret | ||||
| #endif | ||||
|  |  | |||
|  | @ -595,12 +595,13 @@ Tile1LoopEnd: | |||
|     bge LoopDz_TILE_1 | ||||
| 
 | ||||
| End: | ||||
| ldp x23, x24, [sp, #(16 * 6)] | ||||
| ldp x19, x20, [sp, #(16 * 5)] | ||||
| ldp x21, x22, [sp, #(16 * 4)] | ||||
| ldp d8,  d9,  [sp, #(16 * 3)] | ||||
| ldp d10, d11, [sp, #(16 * 2)] | ||||
| ldp d12, d13, [sp, #(16 * 1)] | ||||
| ldp d14, d15, [sp], #(16 * 6) | ||||
| ldp d14, d15, [sp], #(16 * 7) | ||||
| ret | ||||
| 
 | ||||
| #endif // __aarch64__ | ||||
|  |  | |||
|  | @ -18,6 +18,10 @@ asm_function MNNMatrixAdd | |||
| 
 | ||||
| //Auto: x0: C, x1:A, x2:B, x3:widthC4 | ||||
| //x4:cStride, x5:aStride, x6:bStride, x7:height | ||||
| stp d14, d15, [sp, #-64]! | ||||
| stp d12, d13, [sp, #16] | ||||
| stp d10, d11, [sp, #32] | ||||
| stp d8,  d9,  [sp, #48] | ||||
| 
 | ||||
| mov x12, #4 //sizeof(float) | ||||
| mul x4, x12, x4 | ||||
|  | @ -31,6 +35,72 @@ mov x10, x2 | |||
| 
 | ||||
| mov x11, x3 | ||||
| 
 | ||||
| L16: | ||||
| cmp x11, #16 | ||||
| blt L8 | ||||
| 
 | ||||
| L16Loop: | ||||
| ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 | ||||
| ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 | ||||
| ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 | ||||
| ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64 | ||||
| ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64 | ||||
| ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x1], #64 | ||||
| ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2], #64 | ||||
| ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 | ||||
| 
 | ||||
| fadd v0.4s, v0.4s, v8.4s | ||||
| fadd v1.4s, v1.4s, v9.4s | ||||
| fadd v2.4s, v2.4s, v10.4s | ||||
| fadd v3.4s, v3.4s, v11.4s | ||||
| fadd v4.4s, v4.4s, v12.4s | ||||
| fadd v5.4s, v5.4s, v13.4s | ||||
| fadd v6.4s, v6.4s, v14.4s | ||||
| fadd v7.4s, v7.4s, v15.4s | ||||
| 
 | ||||
| sub x11, x11, #16 | ||||
| 
 | ||||
| fadd v16.4s, v16.4s, v24.4s | ||||
| fadd v17.4s, v17.4s, v25.4s | ||||
| fadd v18.4s, v18.4s, v26.4s | ||||
| fadd v19.4s, v19.4s, v27.4s | ||||
| fadd v20.4s, v20.4s, v28.4s | ||||
| fadd v21.4s, v21.4s, v29.4s | ||||
| fadd v22.4s, v22.4s, v30.4s | ||||
| fadd v23.4s, v23.4s, v31.4s | ||||
| 
 | ||||
| st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 | ||||
| st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64 | ||||
| st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64 | ||||
| st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 | ||||
| cmp x11, #16 | ||||
| bge L16Loop | ||||
| 
 | ||||
| L8: | ||||
| cmp x11, #8 | ||||
| blt L4 | ||||
| 
 | ||||
| L8Loop: | ||||
| ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 | ||||
| ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 | ||||
| ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 | ||||
| ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64 | ||||
| 
 | ||||
| fadd v0.4s, v0.4s, v8.4s | ||||
| fadd v1.4s, v1.4s, v9.4s | ||||
| fadd v2.4s, v2.4s, v10.4s | ||||
| fadd v3.4s, v3.4s, v11.4s | ||||
| fadd v4.4s, v4.4s, v12.4s | ||||
| fadd v5.4s, v5.4s, v13.4s | ||||
| fadd v6.4s, v6.4s, v14.4s | ||||
| fadd v7.4s, v7.4s, v15.4s | ||||
| sub x11, x11, #8 | ||||
| 
 | ||||
| st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 | ||||
| st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64 | ||||
| cmp x11, #8 | ||||
| bge L8Loop | ||||
| 
 | ||||
| L4: | ||||
| cmp x11, #4 | ||||
| blt L1 | ||||
|  | @ -89,5 +159,10 @@ add x2, x10, x6 | |||
| subs x7, x7, #1 | ||||
| bne LoopY | ||||
| 
 | ||||
| End: | ||||
| ldp d8,  d9,  [sp, #48] | ||||
| ldp d10, d11, [sp, #32] | ||||
| ldp d12, d13, [sp, #16] | ||||
| ldp d14, d15, [sp], #64 | ||||
| ret | ||||
| #endif | ||||
|  |  | |||
|  | @ -18,6 +18,10 @@ asm_function MNNMatrixSub | |||
| 
 | ||||
| //Auto: x0: C, x1:A, x2:B, x3:widthC4 | ||||
| //x4:cStride, x5:aStride, x6:bStride, x7:height | ||||
| stp d14, d15, [sp, #-64]! | ||||
| stp d12, d13, [sp, #16] | ||||
| stp d10, d11, [sp, #32] | ||||
| stp d8,  d9,  [sp, #48] | ||||
| 
 | ||||
| mov x12, #4 //sizeof(float) | ||||
| mul x4, x12, x4 | ||||
|  | @ -31,6 +35,72 @@ mov x10, x2 | |||
| 
 | ||||
| mov x11, x3 | ||||
| 
 | ||||
| L16: | ||||
| cmp x11, #16 | ||||
| blt L8 | ||||
| 
 | ||||
| L16Loop: | ||||
| ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 | ||||
| ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 | ||||
| ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 | ||||
| ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64 | ||||
| ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64 | ||||
| ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x1], #64 | ||||
| ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2], #64 | ||||
| ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64 | ||||
| 
 | ||||
| fsub v0.4s, v0.4s, v8.4s | ||||
| fsub v1.4s, v1.4s, v9.4s | ||||
| fsub v2.4s, v2.4s, v10.4s | ||||
| fsub v3.4s, v3.4s, v11.4s | ||||
| fsub v4.4s, v4.4s, v12.4s | ||||
| fsub v5.4s, v5.4s, v13.4s | ||||
| fsub v6.4s, v6.4s, v14.4s | ||||
| fsub v7.4s, v7.4s, v15.4s | ||||
| 
 | ||||
| sub x11, x11, #16 | ||||
| 
 | ||||
| fsub v16.4s, v16.4s, v24.4s | ||||
| fsub v17.4s, v17.4s, v25.4s | ||||
| fsub v18.4s, v18.4s, v26.4s | ||||
| fsub v19.4s, v19.4s, v27.4s | ||||
| fsub v20.4s, v20.4s, v28.4s | ||||
| fsub v21.4s, v21.4s, v29.4s | ||||
| fsub v22.4s, v22.4s, v30.4s | ||||
| fsub v23.4s, v23.4s, v31.4s | ||||
| 
 | ||||
| st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 | ||||
| st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64 | ||||
| st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64 | ||||
| st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 | ||||
| cmp x11, #16 | ||||
| bge L16Loop | ||||
| 
 | ||||
| L8: | ||||
| cmp x11, #8 | ||||
| blt L4 | ||||
| 
 | ||||
| L8Loop: | ||||
| ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 | ||||
| ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 | ||||
| ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 | ||||
| ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64 | ||||
| 
 | ||||
| fsub v0.4s, v0.4s, v8.4s | ||||
| fsub v1.4s, v1.4s, v9.4s | ||||
| fsub v2.4s, v2.4s, v10.4s | ||||
| fsub v3.4s, v3.4s, v11.4s | ||||
| fsub v4.4s, v4.4s, v12.4s | ||||
| fsub v5.4s, v5.4s, v13.4s | ||||
| fsub v6.4s, v6.4s, v14.4s | ||||
| fsub v7.4s, v7.4s, v15.4s | ||||
| sub x11, x11, #8 | ||||
| 
 | ||||
| st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 | ||||
| st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64 | ||||
| cmp x11, #8 | ||||
| bge L8Loop | ||||
| 
 | ||||
| L4: | ||||
| cmp x11, #4 | ||||
| blt L1 | ||||
|  | @ -89,6 +159,11 @@ add x2, x10, x6 | |||
| subs x7, x7, #1 | ||||
| bne LoopY | ||||
| 
 | ||||
| End: | ||||
| ldp d8,  d9,  [sp, #48] | ||||
| ldp d10, d11, [sp, #32] | ||||
| ldp d12, d13, [sp, #16] | ||||
| ldp d14, d15, [sp], #64 | ||||
| ret | ||||
| 
 | ||||
| #endif | ||||
|  |  | |||
|  | @ -17,7 +17,10 @@ asm_function MNNReluWithSlopeChannel | |||
| 
 | ||||
| //Auto Load: | ||||
| //x0:dst, x1:src, x2:slope, x3:sizeQuad, x4:depthQuad | ||||
| 
 | ||||
| stp d14, d15, [sp, #-64]! | ||||
| stp d12, d13, [sp, #16] | ||||
| stp d10, d11, [sp, #32] | ||||
| stp d8,  d9,  [sp, #48] | ||||
| 
 | ||||
| cmp x4, #0 | ||||
| beq PReluEnd | ||||
|  | @ -26,48 +29,164 @@ beq PReluEnd | |||
| 
 | ||||
| 
 | ||||
| PReluZLoop: | ||||
| ld1 {v23.4s}, [x2], #16 | ||||
| ld1 {v31.4s}, [x2], #16 | ||||
| mov x5, x3 | ||||
| 
 | ||||
| PReluL16: | ||||
| cmp x5, #15 | ||||
| ble PReluL8 | ||||
| 
 | ||||
| PReluL16Loop: | ||||
| ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 | ||||
| ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 | ||||
| ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x1], #64 | ||||
| 
 | ||||
| fcmle v16.4s, v0.4s, #0 | ||||
| fcmle v17.4s, v1.4s, #0 | ||||
| fcmle v18.4s, v2.4s, #0 | ||||
| fcmle v19.4s, v3.4s, #0 | ||||
| fcmle v20.4s, v4.4s, #0 | ||||
| fcmle v21.4s, v5.4s, #0 | ||||
| fcmle v22.4s, v6.4s, #0 | ||||
| fcmle v23.4s, v7.4s, #0 | ||||
| 
 | ||||
| fmul v8.4s,  v0.4s, v31.4s | ||||
| fmul v9.4s,  v1.4s, v31.4s | ||||
| fmul v10.4s, v2.4s, v31.4s | ||||
| fmul v11.4s, v3.4s, v31.4s | ||||
| fmul v12.4s, v4.4s, v31.4s | ||||
| fmul v13.4s, v5.4s, v31.4s | ||||
| fmul v14.4s, v6.4s, v31.4s | ||||
| fmul v15.4s, v7.4s, v31.4s | ||||
| 
 | ||||
| fcmle v28.4s, v24.4s, #0 | ||||
| fcmle v29.4s, v25.4s, #0 | ||||
| fcmle v30.4s, v26.4s, #0 | ||||
| 
 | ||||
| bit v0.16b, v8.16b, v16.16b | ||||
| bit v1.16b, v9.16b, v17.16b | ||||
| bit v2.16b, v10.16b, v18.16b | ||||
| bit v3.16b, v11.16b, v19.16b | ||||
| bit v4.16b, v12.16b, v20.16b | ||||
| bit v5.16b, v13.16b, v21.16b | ||||
| bit v6.16b, v14.16b, v22.16b | ||||
| bit v7.16b, v15.16b, v23.16b | ||||
| 
 | ||||
| st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 | ||||
| st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64 | ||||
| fcmle v8.4s, v27.4s, #0 | ||||
| fmul v9.4s,  v24.4s, v31.4s | ||||
| fmul v10.4s, v25.4s, v31.4s | ||||
| fmul v11.4s, v26.4s, v31.4s | ||||
| fmul v12.4s, v27.4s, v31.4s | ||||
| ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64 | ||||
| 
 | ||||
| fcmle v13.4s, v16.4s, #0 | ||||
| fcmle v14.4s, v17.4s, #0 | ||||
| fcmle v15.4s, v18.4s, #0 | ||||
| fcmle v0.4s, v19.4s, #0 | ||||
| 
 | ||||
| fmul v20.4s, v16.4s, v31.4s | ||||
| fmul v21.4s, v17.4s, v31.4s | ||||
| fmul v22.4s, v18.4s, v31.4s | ||||
| fmul v23.4s, v19.4s, v31.4s | ||||
| 
 | ||||
| 
 | ||||
| bit v24.16b, v9.16b, v28.16b | ||||
| bit v25.16b, v10.16b, v29.16b | ||||
| bit v26.16b, v11.16b, v30.16b | ||||
| bit v27.16b, v12.16b, v8.16b | ||||
| bit v16.16b, v20.16b, v13.16b | ||||
| bit v17.16b, v21.16b, v14.16b | ||||
| bit v18.16b, v22.16b, v15.16b | ||||
| bit v19.16b, v23.16b, v0.16b | ||||
| 
 | ||||
| st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 | ||||
| st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64 | ||||
| 
 | ||||
| sub x5, x5, #16 | ||||
| cmp x5, #16 | ||||
| bge PReluL16Loop | ||||
| 
 | ||||
| PReluL8: | ||||
| cmp x5, #7 | ||||
| ble PReluL4 | ||||
| 
 | ||||
| PReluL8Loop: | ||||
| ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 | ||||
| ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 | ||||
| 
 | ||||
| fcmle v16.4s, v0.4s, #0 | ||||
| fcmle v17.4s, v1.4s, #0 | ||||
| fcmle v18.4s, v2.4s, #0 | ||||
| fcmle v19.4s, v3.4s, #0 | ||||
| fcmle v20.4s, v4.4s, #0 | ||||
| fcmle v21.4s, v5.4s, #0 | ||||
| fcmle v22.4s, v6.4s, #0 | ||||
| fcmle v23.4s, v7.4s, #0 | ||||
| 
 | ||||
| fmul v8.4s,  v0.4s, v31.4s | ||||
| fmul v9.4s,  v1.4s, v31.4s | ||||
| fmul v10.4s, v2.4s, v31.4s | ||||
| fmul v11.4s, v3.4s, v31.4s | ||||
| fmul v12.4s, v4.4s, v31.4s | ||||
| fmul v13.4s, v5.4s, v31.4s | ||||
| fmul v14.4s, v6.4s, v31.4s | ||||
| fmul v15.4s, v7.4s, v31.4s | ||||
| 
 | ||||
| 
 | ||||
| bit v0.16b, v8.16b, v16.16b | ||||
| bit v1.16b, v9.16b, v17.16b | ||||
| bit v2.16b, v10.16b, v18.16b | ||||
| bit v3.16b, v11.16b, v19.16b | ||||
| bit v4.16b, v12.16b, v20.16b | ||||
| bit v5.16b, v13.16b, v21.16b | ||||
| bit v6.16b, v14.16b, v22.16b | ||||
| bit v7.16b, v15.16b, v23.16b | ||||
| 
 | ||||
| st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 | ||||
| st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64 | ||||
| 
 | ||||
| sub x5, x5, #8 | ||||
| cmp x5, #8 | ||||
| bge PReluL8Loop | ||||
| 
 | ||||
| PReluL4: | ||||
| cmp x5, #3 | ||||
| ble PReluL1 | ||||
| 
 | ||||
| PReluL4Loop: | ||||
| ld1 {v0.4s, v1.4s}, [x1], #32 | ||||
| ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 | ||||
| 
 | ||||
| fcmle v20.4s, v0.4s, #0 | ||||
| fcmle v21.4s, v1.4s, #0 | ||||
| fcmle v8.4s, v0.4s, #0 | ||||
| fcmle v9.4s, v1.4s, #0 | ||||
| fcmle v10.4s, v2.4s, #0 | ||||
| fcmle v11.4s, v3.4s, #0 | ||||
| 
 | ||||
| ld1 {v2.4s, v3.4s}, [x1], #32 | ||||
| fmul v4.4s, v0.4s, v31.4s | ||||
| fmul v5.4s, v1.4s, v31.4s | ||||
| fmul v6.4s, v2.4s, v31.4s | ||||
| fmul v7.4s, v3.4s, v31.4s | ||||
| 
 | ||||
| fmul v16.4s, v0.4s, v23.4s | ||||
| fmul v17.4s, v1.4s, v23.4s | ||||
| bit v0.16b, v16.16b, v20.16b | ||||
| bit v1.16b, v17.16b, v21.16b | ||||
| bit v0.16b, v4.16b, v8.16b | ||||
| bit v1.16b, v5.16b, v9.16b | ||||
| bit v2.16b, v6.16b, v10.16b | ||||
| bit v3.16b, v7.16b, v11.16b | ||||
| 
 | ||||
| fmul v16.4s, v2.4s, v23.4s | ||||
| fmul v17.4s, v3.4s, v23.4s | ||||
| st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 | ||||
| 
 | ||||
| st1 {v0.4s, v1.4s}, [x0], #32 | ||||
| 
 | ||||
| fcmle v20.4s, v2.4s, #0 | ||||
| fcmle v21.4s, v3.4s, #0 | ||||
| bit v2.16b, v16.16b, v20.16b | ||||
| bit v3.16b, v17.16b, v21.16b | ||||
| 
 | ||||
| st1 {v2.4s, v3.4s}, [x0], #32 | ||||
| sub x5, x5, #4 | ||||
| cmp x5, #4 | ||||
| bge PReluL4Loop | ||||
| 
 | ||||
| PReluL1: | ||||
| cmp x5, #0 | ||||
| 
 | ||||
| beq PReluL1End | ||||
| 
 | ||||
| PReluL1Loop: | ||||
| ld1 {v0.4s}, [x1], #16 | ||||
| fcmle v2.4s, v0.4s, #0 | ||||
| fmul v1.4s, v0.4s, v23.4s | ||||
| fmul v1.4s, v0.4s, v31.4s | ||||
| bit v0.16b, v1.16b, v2.16b | ||||
| st1 {v0.4s}, [x0], #16 | ||||
| subs x5, x5, #1 | ||||
|  | @ -80,6 +199,10 @@ bne PReluZLoop | |||
| 
 | ||||
| 
 | ||||
| PReluEnd: | ||||
| ldp d8,  d9,  [sp, #48] | ||||
| ldp d10, d11, [sp, #32] | ||||
| ldp d12, d13, [sp, #16] | ||||
| ldp d14, d15, [sp], #64 | ||||
| 
 | ||||
| ret | ||||
| #endif | ||||
|  |  | |||
|  | @ -16,6 +16,10 @@ asm_function MNNScaleAndAddBias | |||
| //void MNNScaleAndAddBias(float* dst, const float* src, const float* bias, const float* alpha, size_t planeNumber, size_t biasNumber) | ||||
| 
 | ||||
| //Auto: x0:dst, x1:src, x2:bias, x3:alpha, x4:planeNumber, x5:biasNumber | ||||
| stp d14, d15, [sp, #-64]! | ||||
| stp d12, d13, [sp, #16] | ||||
| stp d10, d11, [sp, #32] | ||||
| stp d8,  d9,  [sp, #48] | ||||
| 
 | ||||
| cmp x4, #0 | ||||
| beq BSEnd | ||||
|  | @ -27,8 +31,158 @@ BSLoopZ: | |||
|     mov x6, x4 | ||||
|     ld1 {v31.4s}, [x2], #16 | ||||
|     ld1 {v30.4s}, [x3], #16 | ||||
| 
 | ||||
|     BSL32: | ||||
|     cmp x6, #31 | ||||
|     ble BSL16 | ||||
|     BSLoopP32: | ||||
|         ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 | ||||
|         ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 | ||||
|         ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64 | ||||
|         ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64 | ||||
|         ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64 | ||||
|         ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x1], #64 | ||||
|         ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x1], #64 | ||||
|         fmul v0.4s, v0.4s, v30.4s | ||||
|         fmul v1.4s, v1.4s, v30.4s | ||||
|         fmul v2.4s, v2.4s, v30.4s | ||||
|         fmul v3.4s, v3.4s, v30.4s | ||||
|         fmul v4.4s, v4.4s, v30.4s | ||||
|         fmul v5.4s, v5.4s, v30.4s | ||||
|         fmul v6.4s, v6.4s, v30.4s | ||||
|         fmul v7.4s, v7.4s, v30.4s | ||||
|         sub x6, x6, #32 | ||||
|         fmul v8.4s, v8.4s, v30.4s | ||||
|         fmul v9.4s, v9.4s, v30.4s | ||||
|         fmul v10.4s, v10.4s, v30.4s | ||||
|         fmul v11.4s, v11.4s, v30.4s | ||||
|         fmul v12.4s, v12.4s, v30.4s | ||||
|         fmul v13.4s, v13.4s, v30.4s | ||||
|         fmul v14.4s, v14.4s, v30.4s | ||||
|         fmul v15.4s, v15.4s, v30.4s | ||||
| 
 | ||||
|         fmul v16.4s, v16.4s, v30.4s | ||||
|         fmul v17.4s, v17.4s, v30.4s | ||||
|         fmul v18.4s, v18.4s, v30.4s | ||||
|         fmul v19.4s, v19.4s, v30.4s | ||||
|         fmul v20.4s, v20.4s, v30.4s | ||||
|         fmul v21.4s, v21.4s, v30.4s | ||||
|         fmul v22.4s, v22.4s, v30.4s | ||||
|         fmul v23.4s, v23.4s, v30.4s | ||||
| 
 | ||||
|         fmul v24.4s, v24.4s, v30.4s | ||||
|         fmul v25.4s, v25.4s, v30.4s | ||||
|         fmul v26.4s, v26.4s, v30.4s | ||||
|         fmul v27.4s, v27.4s, v30.4s | ||||
|      | ||||
|         fadd v0.4s, v0.4s, v31.4s | ||||
|         fadd v1.4s, v1.4s, v31.4s | ||||
|         fadd v2.4s, v2.4s, v31.4s | ||||
|         fadd v3.4s, v3.4s, v31.4s | ||||
|         fadd v4.4s, v4.4s, v31.4s | ||||
|         fadd v5.4s, v5.4s, v31.4s | ||||
|         fadd v6.4s, v6.4s, v31.4s | ||||
|         fadd v7.4s, v7.4s, v31.4s | ||||
|         cmp x6, #32 | ||||
|         fadd v8.4s, v8.4s, v31.4s | ||||
|         fadd v9.4s, v9.4s, v31.4s | ||||
|         fadd v10.4s, v10.4s, v31.4s | ||||
|         fadd v11.4s, v11.4s, v31.4s | ||||
|         fadd v12.4s, v12.4s, v31.4s | ||||
|         fadd v13.4s, v13.4s, v31.4s | ||||
|         fadd v14.4s, v14.4s, v31.4s | ||||
|         fadd v15.4s, v15.4s, v31.4s | ||||
|         st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 | ||||
| 
 | ||||
|         fadd v16.4s, v16.4s, v31.4s | ||||
|         fadd v17.4s, v17.4s, v31.4s | ||||
|         fadd v18.4s, v18.4s, v31.4s | ||||
|         fadd v19.4s, v19.4s, v31.4s | ||||
|         fadd v20.4s, v20.4s, v31.4s | ||||
|         fadd v21.4s, v21.4s, v31.4s | ||||
|         fadd v22.4s, v22.4s, v31.4s | ||||
|         fadd v23.4s, v23.4s, v31.4s | ||||
|         ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 | ||||
| 
 | ||||
|         fadd v24.4s, v24.4s, v31.4s | ||||
|         fadd v25.4s, v25.4s, v31.4s | ||||
|         fadd v26.4s, v26.4s, v31.4s | ||||
|         fadd v27.4s, v27.4s, v31.4s | ||||
| 
 | ||||
|         fmul v0.4s, v0.4s, v30.4s | ||||
|         fmul v1.4s, v1.4s, v30.4s | ||||
|         fmul v2.4s, v2.4s, v30.4s | ||||
|         fmul v3.4s, v3.4s, v30.4s | ||||
| 
 | ||||
|         st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64 | ||||
|         st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64 | ||||
|         st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64 | ||||
|         fadd v0.4s, v0.4s, v31.4s | ||||
|         fadd v1.4s, v1.4s, v31.4s | ||||
|         fadd v2.4s, v2.4s, v31.4s | ||||
|         fadd v3.4s, v3.4s, v31.4s | ||||
|         st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64 | ||||
|         st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 | ||||
|         st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 | ||||
|          | ||||
|         st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 | ||||
| 
 | ||||
|         bge BSLoopP32 | ||||
| 
 | ||||
|     BSL16: | ||||
|     cmp x6, #15 | ||||
|     ble BSL8_ | ||||
|     BSLoopP16: | ||||
|         ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 | ||||
|         ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 | ||||
|         ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64 | ||||
|         ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64 | ||||
|         fmul v0.4s, v0.4s, v30.4s | ||||
|         fmul v1.4s, v1.4s, v30.4s | ||||
|         fmul v2.4s, v2.4s, v30.4s | ||||
|         fmul v3.4s, v3.4s, v30.4s | ||||
|         fmul v4.4s, v4.4s, v30.4s | ||||
|         fmul v5.4s, v5.4s, v30.4s | ||||
|         fmul v6.4s, v6.4s, v30.4s | ||||
|         fmul v7.4s, v7.4s, v30.4s | ||||
|         sub x6, x6, #16 | ||||
|         fmul v8.4s, v8.4s, v30.4s | ||||
|         fmul v9.4s, v9.4s, v30.4s | ||||
|         fmul v10.4s, v10.4s, v30.4s | ||||
|         fmul v11.4s, v11.4s, v30.4s | ||||
|         fmul v12.4s, v12.4s, v30.4s | ||||
|         fmul v13.4s, v13.4s, v30.4s | ||||
|         fmul v14.4s, v14.4s, v30.4s | ||||
|         fmul v15.4s, v15.4s, v30.4s | ||||
|      | ||||
|         fadd v0.4s, v0.4s, v31.4s | ||||
|         fadd v1.4s, v1.4s, v31.4s | ||||
|         fadd v2.4s, v2.4s, v31.4s | ||||
|         fadd v3.4s, v3.4s, v31.4s | ||||
|         fadd v4.4s, v4.4s, v31.4s | ||||
|         fadd v5.4s, v5.4s, v31.4s | ||||
|         fadd v6.4s, v6.4s, v31.4s | ||||
|         fadd v7.4s, v7.4s, v31.4s | ||||
|         cmp x6, #16 | ||||
|         fadd v8.4s, v8.4s, v31.4s | ||||
|         fadd v9.4s, v9.4s, v31.4s | ||||
|         fadd v10.4s, v10.4s, v31.4s | ||||
|         fadd v11.4s, v11.4s, v31.4s | ||||
|         fadd v12.4s, v12.4s, v31.4s | ||||
|         fadd v13.4s, v13.4s, v31.4s | ||||
|         fadd v14.4s, v14.4s, v31.4s | ||||
|         fadd v15.4s, v15.4s, v31.4s | ||||
| 
 | ||||
|         st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 | ||||
|         st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64 | ||||
|         st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64 | ||||
|         st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64 | ||||
| 
 | ||||
|         bge BSLoopP16 | ||||
| 
 | ||||
|     BSL8_: | ||||
|     cmp x6, #7 | ||||
|     ble BSLoopP1 | ||||
|     ble BSL1 | ||||
|     BSLoopP8: | ||||
|         ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 | ||||
|         fmul v0.4s, v0.4s, v30.4s | ||||
|  | @ -53,7 +207,7 @@ BSLoopZ: | |||
|         st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64 | ||||
|         cmp x6, #8 | ||||
|         bge BSLoopP8 | ||||
| 
 | ||||
|     BSL1: | ||||
|     cmp x6, #0 | ||||
|     beq BSLoopPEnd | ||||
| 
 | ||||
|  | @ -71,7 +225,10 @@ BSLoopZ: | |||
| 
 | ||||
| 
 | ||||
| BSEnd: | ||||
| 
 | ||||
| ldp d8,  d9,  [sp, #48] | ||||
| ldp d10, d11, [sp, #32] | ||||
| ldp d12, d13, [sp, #16] | ||||
| ldp d14, d15, [sp], #64 | ||||
| 
 | ||||
| ret | ||||
| 
 | ||||
|  |  | |||
|  | @ -148,7 +148,7 @@ WinogradConfig ConvolutionPackWinograd::bestWinogradUnit(const Convolution2DComm | |||
|     int oc      = outputTensor->channel(); | ||||
|     int ePack, hPack, lPack; | ||||
|     core->MNNGetMatMulPackMode(&ePack, &lPack, &hPack); | ||||
|     int unit2   = UP_DIV(ow * oh, ePack * threadNumber); | ||||
|     int unit2   = UP_DIV(ow * oh, threadNumber); | ||||
|     int maxUnit = (int)::sqrtf((float)unit2); | ||||
|     maxUnit     = std::min(maxUnit, CONVOLUTION_WINOGRAD_MAX_UNIT); | ||||
|     maxUnit     = std::max(maxUnit, CONVOLUTION_WINOGRAD_MIN_UNIT); | ||||
|  |  | |||
|  | @ -257,7 +257,7 @@ public: | |||
|                     auto srcStride0 = cmd->view()->GetAs<View>(1)->stride()->data(); | ||||
|                     auto srcStride1 = cmd->view()->GetAs<View>(2)->stride()->data(); | ||||
|                     auto dstStride = cmd->view()->GetAs<View>(0)->stride()->data(); | ||||
|                     //MNN_PRINT("Binary Loop in optype:%d\n", opType); | ||||
|                     // MNN_PRINT("Binary Loop in optype:%d\n", opType); | ||||
|                     BinaryBlit((uint8_t*)dst, (const uint8_t*)src0, (const uint8_t*)src1, | ||||
|                         cmd->size()->data(), srcStride0, srcStride1, dstStride, type, runtime, opType); | ||||
| 
 | ||||
|  |  | |||
|  | @ -127,7 +127,7 @@ __global__ void FLOAT##Name(const T *input, T *output,\ | |||
| }\ | ||||
| 
 | ||||
| template<typename T> | ||||
| __global__ void blit_2(const T *input, T *output, | ||||
| __global__ void blit_2_float(const T *input, T *output, | ||||
|     int count, | ||||
|     DivModFast sizeZ, DivModFast sizeY, DivModFast sizeX, | ||||
|     int strideZ, int strideY, | ||||
|  | @ -143,6 +143,23 @@ __global__ void blit_2(const T *input, T *output, | |||
|         dstF[0] = ((int2 *)(input+srcOffset))[0]; | ||||
|     } | ||||
| } | ||||
| template<typename T> | ||||
| __global__ void blit_2_half(const T *input, T *output, | ||||
|     int count, | ||||
|     DivModFast sizeZ, DivModFast sizeY, DivModFast sizeX, | ||||
|     int strideZ, int strideY, | ||||
|     int dstStrideZ, int dstStrideY | ||||
|     ) {  | ||||
|     for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) { | ||||
|         int ix, tmp, iy, iz; | ||||
|         sizeX.divmod(i, tmp, ix); | ||||
|         sizeY.divmod(tmp, iz, iy); | ||||
|         int srcOffset = iz * strideZ + iy * strideY + (ix << 1); | ||||
|         int dstOffset = iz * dstStrideZ + iy * dstStrideY + (ix << 1); | ||||
|         int* dstF = (int *)(output+dstOffset); | ||||
|         dstF[0] = ((int *)(input+srcOffset))[0]; | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| struct Bytes512 { | ||||
|     int4 x[4]; | ||||
|  | @ -186,26 +203,73 @@ UNARY_FUNC(GELU_STANDARD, (erf(x*0.7071067932881648f)+1.f)*x*0.5); | |||
| void RasterBlit(uint8_t* output, const uint8_t* input, const int32_t* size, const int32_t* srcStride, const int32_t* dstStride, int bytes, CUDARuntime* runtime) { | ||||
|     int count = size[0] * size[1] * size[2]; | ||||
| 
 | ||||
|     // MNN_PRINT("blit info size:%d-%d-%d, srcStride:%d-%d-%d, dstStride:%d-%d-%d\n", size[0], size[1], size[2], srcStride[0], srcStride[1], srcStride[2], dstStride[0], dstStride[1], dstStride[2]); | ||||
|     bool isThirdSizeVector  = (size[2] % 2 == 0 && srcStride[2] == 1 && dstStride[2] == 1); | ||||
|     bool isSecondSizeVector = (size[1] % 2 == 0 && srcStride[1] == 1 && dstStride[1] == 1) && (size[2] == 1 && srcStride[2] == 1 && dstStride[2] == 1); | ||||
|     bool isFirstSizeVector  = (size[0] % 2 == 0 && srcStride[0] == 1 && dstStride[0] == 1) && (size[1] == 1 && srcStride[1] == 1 && dstStride[1] == 1) && (size[2] == 1 && srcStride[2] == 1 && dstStride[2] == 1); | ||||
|     bool isSizeVector = isThirdSizeVector || isSecondSizeVector || isFirstSizeVector; | ||||
|     if(count > 16384 && isSizeVector) { | ||||
|         int32_t newSize[3], newSrcStride[3], newDstStride[3]; | ||||
|         newSize[0] = size[0];  | ||||
|         newSize[1] = size[1];  | ||||
|         newSize[2] = size[2];  | ||||
|         newSrcStride[0] = srcStride[0];  | ||||
|         newSrcStride[1] = srcStride[1];  | ||||
|         newSrcStride[2] = srcStride[2];  | ||||
|         newDstStride[0] = dstStride[0];  | ||||
|         newDstStride[1] = dstStride[1];  | ||||
|         newDstStride[2] = dstStride[2];  | ||||
|         if(isSecondSizeVector) { | ||||
|             /*  size   : [size_0, size_1, 1]  srcStride   : [ss_0, 1, 1] dstStride   : [ds_0, 1, 1] | ||||
|             --> newSize: [1, size_0, size_1]  newSrcStride: [1, ss_0, 1] newDstStride: [1, ds_0, 1] | ||||
|             */ | ||||
|             newSize[2] = size[1]; | ||||
|             newSize[1] = size[0]; | ||||
|             newSize[0] = 1; | ||||
|             newSrcStride[1] = srcStride[0]; | ||||
|             newSrcStride[0] = 1; | ||||
|             newDstStride[1] = dstStride[0]; | ||||
|             newDstStride[0] = 1; | ||||
|         } | ||||
|         if(isFirstSizeVector) { | ||||
|             /*  size   : [size_0, 1, 1]  srcStride   : [1, 1, 1] dstStride   : [1, 1, 1] | ||||
|             --> newSize: [1, 1, size_0]  newSrcStride: [1, 1, 1] newDstStride: [1, 1, 1] | ||||
|             */ | ||||
|             newSize[2] = size[0]; | ||||
|             newSize[0] = 1; | ||||
|         } | ||||
| 
 | ||||
|         DivModFast new_sz(newSize[0]); | ||||
|         DivModFast new_sy(newSize[1]); | ||||
|         DivModFast new_sx(newSize[2]/2); | ||||
| 
 | ||||
|         int newCount = count / 2; | ||||
|         int block_num = runtime->blocks_num(newCount); | ||||
|         int threads_num = runtime->threads_num(); | ||||
| 
 | ||||
|         // Forbid addresss misalign | ||||
|         if(bytes == 4 && reinterpret_cast<std::uintptr_t>(input) % 8 == 0 && reinterpret_cast<std::uintptr_t>(output) % 8 == 0) { | ||||
|             blit_2_float<<<block_num, threads_num>>>((const float*)input, (float*)output,  | ||||
|                 newCount, | ||||
|                 new_sz, new_sy, new_sx, | ||||
|                 newSrcStride[0], newSrcStride[1], | ||||
|                 newDstStride[0], newDstStride[1]); | ||||
|             checkKernelErrors; | ||||
|             return; | ||||
|         } else if(bytes == 2 && reinterpret_cast<std::uintptr_t>(input) % 4 == 0 && reinterpret_cast<std::uintptr_t>(output) % 4 == 0) { | ||||
|             blit_2_half<<<block_num, threads_num>>>((const half*)input, (half*)output,  | ||||
|                 newCount, | ||||
|                 new_sz, new_sy, new_sx, | ||||
|                 newSrcStride[0], newSrcStride[1], | ||||
|                 newDstStride[0], newDstStride[1]); | ||||
|             checkKernelErrors; | ||||
|             return; | ||||
|         } | ||||
|     } | ||||
|      | ||||
|     DivModFast sz(size[0]); | ||||
|     DivModFast sy(size[1]); | ||||
|     DivModFast sx(size[2]); | ||||
| 
 | ||||
|     // MNN_PRINT("blit info size:%d-%d-%d, srcStride:%d-%d-%d, dstStride:%d-%d-%d\n", size[0], size[1], size[2], srcStride[0], srcStride[1], srcStride[2], dstStride[0], dstStride[1], dstStride[2]); | ||||
|     if(bytes == 4 && count > 16384 && size[2] % 2 == 0 && srcStride[2] == 1 && dstStride[2] == 1) { | ||||
|         //printf("%d-%d-%d, %d-%d-%d,-%d-%d-%d\n\n", size[0], size[1], size[2], srcStride[0], srcStride[1], srcStride[2], dstStride[0], dstStride[1], dstStride[2]); | ||||
|         count /= 2; | ||||
|         int block_num = runtime->blocks_num(count); | ||||
|         int threads_num = runtime->threads_num(); | ||||
|         DivModFast sx_2((size[2]/2)); | ||||
| 
 | ||||
|         blit_2<<<block_num, threads_num>>>((const float*)input, (float*)output,  | ||||
|             count, | ||||
|             sz, sy, sx_2, | ||||
|             srcStride[0], srcStride[1], | ||||
|             dstStride[0], dstStride[1]); | ||||
|         return; | ||||
|     } | ||||
|      | ||||
|     int block_num = runtime->blocks_num(count); | ||||
|     int threads_num = runtime->threads_num(); | ||||
| 
 | ||||
|  | @ -225,7 +289,7 @@ void RasterBlit(uint8_t* output, const uint8_t* input, const int32_t* size, cons | |||
|                 dstStride[0], dstStride[1], dstStride[2]); | ||||
|             break; | ||||
|         case 4: | ||||
|             blit<<<block_num, threads_num>>>((const float*)input, (float*)output, | ||||
|     	    blit<<<block_num, threads_num>>>((const float*)input, (float*)output, | ||||
|                 count, | ||||
|                 sz, sy, sx, | ||||
|                 srcStride[0], srcStride[1], srcStride[2], | ||||
|  | @ -248,6 +312,7 @@ void RasterBlit(uint8_t* output, const uint8_t* input, const int32_t* size, cons | |||
|         default: | ||||
|             break; | ||||
|     } | ||||
|     checkKernelErrors; | ||||
| } | ||||
| 
 | ||||
| template<typename T0, typename T1> | ||||
|  | @ -538,7 +603,6 @@ __global__ void Binary##Name(\ | |||
|     ) { \ | ||||
|     int count = sizeZ * sizeY * sizeX;\ | ||||
|     for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {\ | ||||
|         int total = sizeZ * sizeY * sizeX;\ | ||||
|         int ix = i % sizeX;\ | ||||
|         int tmp = i / sizeX;\ | ||||
|         int iy = tmp % sizeY;\ | ||||
|  | @ -563,15 +627,14 @@ __global__ void BinaryMid##Name(\ | |||
|     int sizeZ, int sizeY, int sizeX,\ | ||||
|     int strideZ, int strideY, int strideX,\ | ||||
|     int strideZ1, int strideY1, int strideX1,\ | ||||
|     int dstStrideZ, int dstStrideY, int dstStrideX, int activationType, int bytes\ | ||||
|     int dstStrideZ, int dstStrideY, int dstStrideX, int activationType,\ | ||||
|     DivModFast d_sizeY, DivModFast d_sizeX\ | ||||
|     ) { \ | ||||
|     int count = sizeZ * sizeY * sizeX;\ | ||||
|     for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {\ | ||||
|         int total = sizeZ * sizeY * sizeX;\ | ||||
|         int ix = i % sizeX;\ | ||||
|         int tmp = i / sizeX;\ | ||||
|         int iy = tmp % sizeY;\ | ||||
|         int iz = tmp / sizeY;\ | ||||
|         int ix, tmp, iy, iz;\ | ||||
|         d_sizeX.divmod(i, tmp, ix);\ | ||||
|         d_sizeY.divmod(tmp, iz, iy);\ | ||||
|         int srcOffset = iz * strideZ + iy * strideY + ix * strideX;\ | ||||
|         int srcOffset1 = iz * strideZ1 + iy * strideY1 + ix * strideX1;\ | ||||
|         int dstOffset = iz * dstStrideZ + iy * dstStrideY + ix * dstStrideX;\ | ||||
|  | @ -581,11 +644,95 @@ __global__ void BinaryMid##Name(\ | |||
|         if(activationType == 1) {\ | ||||
|             val = (val < 0.0f ? 0.0f : val);\ | ||||
|         }\ | ||||
|         if(bytes == 2) {\ | ||||
|             val = min(val, 65504.0f);\ | ||||
|             val = max(val, -65504.0f);\ | ||||
|         output[dstOffset] = val;\ | ||||
|     }\ | ||||
| }\ | ||||
| template<typename TIn, typename TOut>\ | ||||
| __global__ void BinaryMid4_##Name(\ | ||||
|     const TIn *input0, const TIn* input1, TOut *output,\ | ||||
|     int sizeZ, int sizeY, int sizeX,\ | ||||
|     int strideZ, int strideY,\ | ||||
|     int strideZ1, int strideY1,\ | ||||
|     int dstStrideZ, int dstStrideY, int activationType,\ | ||||
|     DivModFast d_sizeY, DivModFast d_sizeX,\ | ||||
|     bool inp0Broadcast, bool inp1Broadcast\ | ||||
|     ) { \ | ||||
|     int count = sizeZ * sizeY * sizeX;\ | ||||
|     for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {\ | ||||
|         int ix, tmp, iy, iz;\ | ||||
|         d_sizeX.divmod(i, tmp, ix);\ | ||||
|         d_sizeY.divmod(tmp, iz, iy);\ | ||||
|         ix = ix << 2;\ | ||||
|         int srcOffset = iz * strideZ + iy * strideY + ix;\ | ||||
|         int srcOffset1 = iz * strideZ1 + iy * strideY1 + ix;\ | ||||
|         int dstOffset = iz * dstStrideZ + iy * dstStrideY + ix;\ | ||||
|         float4 xx = inp0Broadcast ? make_float4(input0[srcOffset-ix],input0[srcOffset-ix], input0[srcOffset-ix], input0[srcOffset-ix]) : ((float4 *)(input0+srcOffset))[0];\ | ||||
|         float4 yy = inp1Broadcast ? make_float4(input1[srcOffset1-ix],input1[srcOffset1-ix], input1[srcOffset1-ix], input1[srcOffset1-ix]) :((float4 *)(input1+srcOffset1))[0];\ | ||||
|         float x = xx.x;\ | ||||
|         float y = yy.x;\ | ||||
|         float val = (float)(Func);\ | ||||
|         if(activationType == 1) {\ | ||||
|             val = (val < 0.0f ? 0.0f : val);\ | ||||
|         }\ | ||||
|         output[dstOffset] = val;\ | ||||
|         x = xx.y;\ | ||||
|         y = yy.y;\ | ||||
|         val = (float)(Func);\ | ||||
|         if(activationType == 1) {\ | ||||
|             val = (val < 0.0f ? 0.0f : val);\ | ||||
|         }\ | ||||
|         output[dstOffset+1] = val;\ | ||||
|         x = xx.z;\ | ||||
|         y = yy.z;\ | ||||
|         val = (float)(Func);\ | ||||
|         if(activationType == 1) {\ | ||||
|             val = (val < 0.0f ? 0.0f : val);\ | ||||
|         }\ | ||||
|         output[dstOffset+2] = val;\ | ||||
|         x = xx.w;\ | ||||
|         y = yy.w;\ | ||||
|         val = (float)(Func);\ | ||||
|         if(activationType == 1) {\ | ||||
|             val = (val < 0.0f ? 0.0f : val);\ | ||||
|         }\ | ||||
|         output[dstOffset+3] = val;\ | ||||
|     }\ | ||||
| }\ | ||||
| template<typename TIn, typename TOut>\ | ||||
| __global__ void BinaryMidHalf2_##Name(\ | ||||
|     const TIn *input0, const TIn* input1, TOut *output,\ | ||||
|     int sizeZ, int sizeY, int sizeX,\ | ||||
|     int strideZ, int strideY,\ | ||||
|     int strideZ1, int strideY1,\ | ||||
|     int dstStrideZ, int dstStrideY, int activationType,\ | ||||
|     DivModFast d_sizeY, DivModFast d_sizeX,\ | ||||
|     bool inp0Broadcast, bool inp1Broadcast\ | ||||
|     ) { \ | ||||
|     int count = sizeZ * sizeY * sizeX;\ | ||||
|     for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {\ | ||||
|         int ix, tmp, iy, iz;\ | ||||
|         d_sizeX.divmod(i, tmp, ix);\ | ||||
|         d_sizeY.divmod(tmp, iz, iy);\ | ||||
|         ix = ix << 1;\ | ||||
|         int srcOffset = iz * strideZ + iy * strideY + ix;\ | ||||
|         int srcOffset1 = iz * strideZ1 + iy * strideY1 + ix;\ | ||||
|         int dstOffset = iz * dstStrideZ + iy * dstStrideY + ix;\ | ||||
|         half2 xx = inp0Broadcast ? make_half2(input0[srcOffset-ix], input0[srcOffset-ix]) : ((half2 *)(input0+srcOffset))[0];\ | ||||
|         half2 yy = inp1Broadcast ? make_half2(input1[srcOffset1-ix], input1[srcOffset1-ix]) : ((half2 *)(input1+srcOffset1))[0];\ | ||||
|         float x = xx.x;\ | ||||
|         float y = yy.x;\ | ||||
|         float val = (float)(Func);\ | ||||
|         if(activationType == 1) {\ | ||||
|             val = (val < 0.0f ? 0.0f : val);\ | ||||
|         }\ | ||||
|         output[dstOffset] = val;\ | ||||
|         x = xx.y;\ | ||||
|         y = yy.y;\ | ||||
|         val = (float)(Func);\ | ||||
|         if(activationType == 1) {\ | ||||
|             val = (val < 0.0f ? 0.0f : val);\ | ||||
|         }\ | ||||
|         output[dstOffset+1] = val;\ | ||||
|     }\ | ||||
| }\ | ||||
| template<typename TIn, typename TOut>\ | ||||
|  | @ -595,8 +742,7 @@ __global__ void BinaryMidLinear##Name(\ | |||
|     int strideZ,\ | ||||
|     int strideZ1,\ | ||||
|     int dstStrideZ,\ | ||||
|     int activationType,\ | ||||
|     int bytes\ | ||||
|     int activationType\ | ||||
|     ) { \ | ||||
|     int count = sizeZ;\ | ||||
|     for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {\ | ||||
|  | @ -610,10 +756,6 @@ __global__ void BinaryMidLinear##Name(\ | |||
|         if(activationType == 1) {\ | ||||
|             val = (val < 0.0f ? 0.0f : val);\ | ||||
|         }\ | ||||
|         if(bytes == 2) {\ | ||||
|             val = min(val, 65504.0f);\ | ||||
|             val = max(val, -65504.0f);\ | ||||
|         }\ | ||||
|         output[dstOffset] = (TOut)val;\ | ||||
|     }\ | ||||
| }\ | ||||
|  | @ -622,15 +764,16 @@ __global__ void BinaryMidLinear##Name(\ | |||
| template<typename TIn, typename TOut>\ | ||||
| __global__ void BinaryMidLinear4_##Name(\ | ||||
|     const TIn *input0, const TIn* input1, TOut *output,\ | ||||
|     int count_4, int activationType\ | ||||
|     int count_4, int activationType,\ | ||||
|     bool inp0Broadcast, bool inp1Broadcast\ | ||||
|     ) { \ | ||||
|     for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count_4); i += blockDim.x * gridDim.x) {\ | ||||
|         int iz = i;\ | ||||
|         int srcOffset = iz << 2;\ | ||||
|         int srcOffset1 = iz << 2;\ | ||||
|         int dstOffset = iz << 2;\ | ||||
|         float4 xx = ((float4 *)(input0+srcOffset))[0];\ | ||||
|         float4 yy = ((float4 *)(input1+srcOffset1))[0];\ | ||||
|         float4 xx = inp0Broadcast ? make_float4(input0[0], input0[0], input0[0], input0[0]) : ((float4 *)(input0+srcOffset))[0];\ | ||||
|         float4 yy = inp1Broadcast ? make_float4(input1[0], input1[0], input1[0], input1[0]) : ((float4 *)(input1+srcOffset1))[0];\ | ||||
|         float x = xx.x;\ | ||||
|         float y = yy.x;\ | ||||
|         TOut val = (TOut)(Func);\ | ||||
|  | @ -664,23 +807,22 @@ __global__ void BinaryMidLinear4_##Name(\ | |||
| template<typename TIn, typename TOut>\ | ||||
| __global__ void BinaryMidLinearHalf4_##Name(\ | ||||
|     const TIn *input0, const TIn* input1, TOut *output,\ | ||||
|     int count_4, int activationType\ | ||||
|     int count_4, int activationType,\ | ||||
|     bool inp0Broadcast, bool inp1Broadcast\ | ||||
|     ) { \ | ||||
|     for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count_4); i += blockDim.x * gridDim.x) {\ | ||||
|         int iz = i;\ | ||||
|         int srcOffset = iz << 2;\ | ||||
|         int srcOffset1 = iz << 2;\ | ||||
|         int dstOffset = iz << 2;\ | ||||
|         half2 xx = ((half2 *)(input0+srcOffset))[0];\ | ||||
|         half2 yy = ((half2 *)(input1+srcOffset1))[0];\ | ||||
|         half2 xx = inp0Broadcast ? make_half2(input0[0], input0[0]) : ((half2 *)(input0+srcOffset))[0];\ | ||||
|         half2 yy = inp1Broadcast ? make_half2(input1[0], input1[0]) : ((half2 *)(input1+srcOffset1))[0];\ | ||||
|         float x = (float)xx.x;\ | ||||
|         float y = (float)yy.x;\ | ||||
|         float val = (float)(Func);\ | ||||
|         if(activationType == 1) {\ | ||||
|             val = (val < 0.0f ? 0.0f : val);\ | ||||
|         }\ | ||||
|         val = min(val, 65504.0f);\ | ||||
|         val = max(val, -65504.0f);\ | ||||
|         output[dstOffset] = (TOut)val;\ | ||||
|         x = (float)xx.y;\ | ||||
|         y = (float)yy.y;\ | ||||
|  | @ -688,19 +830,15 @@ __global__ void BinaryMidLinearHalf4_##Name(\ | |||
|         if(activationType == 1) {\ | ||||
|             val = (val < 0.0f ? 0.0f : val);\ | ||||
|         }\ | ||||
|         val = min(val, 65504.0f);\ | ||||
|         val = max(val, -65504.0f);\ | ||||
|         output[dstOffset+1] = (TOut)val;\ | ||||
|         xx = ((half2 *)(input0+srcOffset))[1];\ | ||||
|         yy = ((half2 *)(input1+srcOffset1))[1];\ | ||||
|         xx = inp0Broadcast ? make_half2(input0[0], input0[0]) : ((half2 *)(input0+srcOffset))[1];\ | ||||
|         yy = inp1Broadcast ? make_half2(input1[0], input1[0]) : ((half2 *)(input1+srcOffset1))[1];\ | ||||
|         x = (float)xx.x;\ | ||||
|         y = (float)yy.x;\ | ||||
|         val = (float)(Func);\ | ||||
|         if(activationType == 1) {\ | ||||
|             val = (val <  0.0f ? 0.0f  : val);\ | ||||
|         }\ | ||||
|         val = min(val, 65504.0f);\ | ||||
|         val = max(val, -65504.0f);\ | ||||
|         output[dstOffset+2] = (TOut)val;\ | ||||
|         x = (float)xx.y;\ | ||||
|         y = (float)yy.y;\ | ||||
|  | @ -708,8 +846,6 @@ __global__ void BinaryMidLinearHalf4_##Name(\ | |||
|         if(activationType == 1) {\ | ||||
|             val = (val < 0.0f ? 0.0f : val);\ | ||||
|         }\ | ||||
|         val = min(val, 65504.0f);\ | ||||
|         val = max(val, -65504.0f);\ | ||||
|         output[dstOffset+3] = (TOut)val;\ | ||||
|     }\ | ||||
| }\ | ||||
|  | @ -784,18 +920,21 @@ void BinaryBlitTemplateFloat(T* output, const T* input, const T* input1, const i | |||
|     int count = size[0] * size[1] * size[2]; | ||||
|     int block_num = runtime->blocks_num(count); | ||||
|     int threads_num = runtime->threads_num(); | ||||
|     // MNN_PRINT("binary :%d %d %d, %d %d %d, %d %d %d, %d %d %d, \n", size[0], size[1], size[2], srcStride[0], srcStride[1], srcStride[2], srcStride1[0], srcStride1[1], srcStride1[2], dstStride[0], dstStride[1], dstStride[2]); | ||||
|     #define COMPUTE_FLOAT(TYPE, TOut)\ | ||||
|         if (opType == MNN::BinaryOpOperation_##TYPE ) {\ | ||||
|             if (size[2] == count) {\ | ||||
|                 if(count % 4 == 0 && count > 16384 && srcStride[2] == 1 && srcStride1[2] == 1 && dstStride[2] == 1) {\ | ||||
|                 if(count % 4 == 0 && count > 16384 && (srcStride[2] == 0 || srcStride[2] == 1) && (srcStride1[2] == 0 || srcStride1[2] == 1) && dstStride[2] == 1) {\ | ||||
|                     block_num = runtime->blocks_num(count/4);\ | ||||
|                     threads_num = runtime->threads_num();\ | ||||
|                     bool srcBroadcast = srcStride[2] == 0;\ | ||||
|                     bool srcBroadcast1 = srcStride1[2] == 0;\ | ||||
|                     if(bytes == 4) {\ | ||||
|                         BinaryMidLinear4_##TYPE<<<block_num, threads_num>>>((const T*)input, (const T*)(input1), (TOut*)output,\ | ||||
|                             count/4, activationType);\ | ||||
|                             count/4, activationType, srcBroadcast, srcBroadcast1);\ | ||||
|                     } else {\ | ||||
|                         BinaryMidLinearHalf4_##TYPE<<<block_num, threads_num>>>((const T*)input, (const T*)(input1), (TOut*)output,\ | ||||
|                             count/4, activationType);\ | ||||
|                             count/4, activationType, srcBroadcast, srcBroadcast1);\ | ||||
|                     }\ | ||||
|                 } else {\ | ||||
|                     BinaryMidLinear##TYPE<<<block_num, threads_num>>>((const T*)input, (const T*)(input1), (TOut*)output,\ | ||||
|  | @ -803,14 +942,41 @@ void BinaryBlitTemplateFloat(T* output, const T* input, const T* input1, const i | |||
|                         srcStride[2],\ | ||||
|                         srcStride1[2],\ | ||||
|                         dstStride[2],\ | ||||
|                         activationType, bytes);\ | ||||
|                         activationType);\ | ||||
|                 }\ | ||||
|             } else {\ | ||||
|                 BinaryMid##TYPE<<<block_num, threads_num>>>((const T*)input, (const T*)(input1), (TOut*)output,\ | ||||
|                     size[0], size[1], size[2],\ | ||||
|                     srcStride[0], srcStride[1], srcStride[2],\ | ||||
|                     srcStride1[0], srcStride1[1], srcStride1[2],\ | ||||
|                     dstStride[0], dstStride[1], dstStride[2], activationType, bytes);\ | ||||
|                 bool isVectorSizeZ = (size[0] == 1 || ((srcStride[2] == 0 || srcStride[0] % bytes == 0) && (srcStride1[2] == 0 || srcStride1[0] % bytes == 0) && dstStride[0] % bytes == 0));\ | ||||
|                 bool isVectorSizeY = (size[1] == 1 || ((srcStride[2] == 0 || srcStride[1] % bytes == 0) && (srcStride1[2] == 0 || srcStride1[1] % bytes == 0) && dstStride[1] % bytes == 0));\ | ||||
|                 bool isVector4 = size[2] % bytes == 0 && isVectorSizeZ && isVectorSizeY;\                 | ||||
| 		        if(isVector4 && count > 16384 && (srcStride[2] == 0 || srcStride[2] == 1) && (srcStride1[2] == 0 || srcStride1[2] == 1) && dstStride[2] == 1) {\ | ||||
|                     block_num = runtime->blocks_num(count/bytes);\ | ||||
|                     threads_num = runtime->threads_num();\ | ||||
|                     DivModFast sy(size[1]);\ | ||||
|                     DivModFast sx(size[2]/bytes);\ | ||||
|                     bool srcBroadcast = srcStride[2] == 0;\ | ||||
|                     bool srcBroadcast1 = srcStride1[2] == 0;\ | ||||
|                     if(bytes == 4) {\ | ||||
|                         BinaryMid4_##TYPE<<<block_num, threads_num>>>((const T*)input, (const T*)(input1), (TOut*)output,\ | ||||
|                             size[0], size[1], size[2]/4,\ | ||||
|                             srcStride[0], srcStride[1],\ | ||||
|                             srcStride1[0], srcStride1[1],\ | ||||
|                             dstStride[0], dstStride[1], activationType, sy, sx, srcBroadcast, srcBroadcast1);\ | ||||
|                     } else {\ | ||||
|                         BinaryMidHalf2_##TYPE<<<block_num, threads_num>>>((const T*)input, (const T*)(input1), (TOut*)output,\ | ||||
|                             size[0], size[1], size[2]/2,\ | ||||
|                             srcStride[0], srcStride[1],\ | ||||
|                             srcStride1[0], srcStride1[1],\ | ||||
|                             dstStride[0], dstStride[1], activationType, sy, sx, srcBroadcast, srcBroadcast1);\ | ||||
|                     }\ | ||||
|                 } else {\ | ||||
|                     DivModFast sy(size[1]);\ | ||||
|                     DivModFast sx(size[2]);\ | ||||
|                     BinaryMid##TYPE<<<block_num, threads_num>>>((const T*)input, (const T*)(input1), (TOut*)output,\ | ||||
|                         size[0], size[1], size[2],\ | ||||
|                         srcStride[0], srcStride[1], srcStride[2],\ | ||||
|                         srcStride1[0], srcStride1[1], srcStride1[2],\ | ||||
|                         dstStride[0], dstStride[1], dstStride[2], activationType, sy, sx);\ | ||||
|                 }\ | ||||
|             }\ | ||||
|             return;\ | ||||
|         }\ | ||||
|  |  | |||
|  | @ -118,7 +118,7 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con | |||
|         auto& slice0 = des->regions[0]; | ||||
|         for (int i=0; i< des->regions.size(); ++i) { | ||||
|             auto& slice = des->regions[i]; | ||||
|             // MNN_PRINT("%d-%d-%d-%d-%d\n", ____inputs[0]->batch(), ____inputs[0]->height(), ____inputs[0]->width(), ____inputs[0]->channel(), outputs[0]->channel());
 | ||||
|             // MNN_PRINT("%d-%d-%d-%d-%d\n", ____inputs[i]->batch(), ____inputs[i]->height(), ____inputs[i]->width(), ____inputs[i]->channel(), outputs[0]->channel());
 | ||||
|             // MNN_PRINT("%d-%d-%d, %d-%d-%d, %d-%d-%d, %d-%d\n\n", slice.size[0], slice.size[1], slice.size[2], slice.src.stride[0], slice.src.stride[1], slice.src.stride[2], slice.dst.stride[0], slice.dst.stride[1], slice.dst.stride[2],  slice.src.offset,  slice.dst.offset);
 | ||||
|             if (TensorUtils::getDescribe(slice.origin)->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) { | ||||
|                 mFast = false; | ||||
|  | @ -132,13 +132,13 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con | |||
|                 mFast = false; | ||||
|                 break; | ||||
|             } | ||||
|         } | ||||
|         // MNN_PRINT("raster fast:%d regionNum:%d\n\n\n", mFast, des->regions.size());
 | ||||
| 	    } | ||||
|         //MNN_PRINT("raster fast:%d regionNum:%d\n\n\n", mFast, des->regions.size());
 | ||||
|         if (mFast) { | ||||
|             int srcStep = 1; | ||||
|             int dstStep = 1; | ||||
|             for (int i=0; i< des->regions.size(); ++i) { | ||||
|                  auto& slice = des->regions[i]; | ||||
|                 int srcStep = 1; | ||||
|                 int dstStep = 1; | ||||
| 	            auto& slice = des->regions[i]; | ||||
|                 if(slice.dst.offset / (slice.size[2] * slice.size[1]) >= 1) { | ||||
|                     int batchChannel = slice.dst.offset / (slice.size[1] * slice.size[2]) + 1; | ||||
|                     dstStep = dstStep > batchChannel ? dstStep : batchChannel; | ||||
|  | @ -155,11 +155,12 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con | |||
|                     int tmp = slice.src.stride[0] / slice.dst.stride[0]; | ||||
|                     srcStep = srcStep > tmp ? srcStep : tmp; | ||||
|                 } | ||||
|             } | ||||
| 		        if(____inputs[i]->channel() > slice.size[1]) { | ||||
|                     int tmp = ____inputs[i]->channel() / slice.size[1]; | ||||
|                     srcStep = srcStep > tmp ? srcStep : tmp; | ||||
| 		        } | ||||
| 		 | ||||
|             for (int i=0; i< des->regions.size(); ++i) { | ||||
|                 auto& slice = des->regions[i]; | ||||
|                 if (slice.origin == nullptr) { | ||||
| 		        if (slice.origin == nullptr) { | ||||
|                     continue; | ||||
|                 } | ||||
|                 Tensor::InsideDescribe::Region newRegion; | ||||
|  | @ -167,7 +168,7 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con | |||
|                 _turnToNewRegion(slice, newRegion, srcStep, dstStep); | ||||
|                 mFastBlit.emplace_back(std::make_pair(slice.origin, std::move(newRegion))); | ||||
|                 // MNN_PRINT("new step %d-%d:%d-%d-%d, %d-%d-%d, %d-%d-%d, %d-%d\n\n", srcStep, dstStep, newRegion.size[0], newRegion.size[1], newRegion.size[2], newRegion.src.stride[0], newRegion.src.stride[1], newRegion.src.stride[2], newRegion.dst.stride[0], newRegion.dst.stride[1], newRegion.dst.stride[2],  newRegion.src.offset,  newRegion.dst.offset);
 | ||||
|             } | ||||
| 	        } | ||||
|             return NO_ERROR; | ||||
|         } | ||||
|     } | ||||
|  | @ -299,19 +300,21 @@ void RasterExecution::executeFaster(const std::vector<Tensor *> &inputs, const s | |||
|     if (mNeedZero) { | ||||
|         auto size = static_cast<CUDABackend*>(backend())->realSize(output) * bytes; | ||||
|         cudaMemset((uint8_t*)output->deviceId(), 0, size); | ||||
|         checkKernelErrors; | ||||
|     } | ||||
|     // Use mFastBlit
 | ||||
|     for (auto& iter : mFastBlit) { | ||||
|         auto srcPtr = (uint8_t*)iter.first->deviceId() + iter.second.src.offset * bytes; | ||||
|         auto dstPtr = (uint8_t*)output->deviceId() + iter.second.dst.offset * bytes; | ||||
|         RasterBlit(dstPtr, srcPtr, iter.second.size, iter.second.src.stride, iter.second.dst.stride, bytes, runtime); | ||||
| 	    RasterBlit(dstPtr, srcPtr, iter.second.size, iter.second.src.stride, iter.second.dst.stride, bytes, runtime); | ||||
|         checkKernelErrors; | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| ErrorCode RasterExecution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) { | ||||
|     if (mFast) { | ||||
|         executeFaster(inputs, outputs); | ||||
|         return NO_ERROR; | ||||
| 	    return NO_ERROR; | ||||
|     } | ||||
|     auto bn = static_cast<CUDABackend*>(backend()); | ||||
|     auto input = outputs[0]; | ||||
|  | @ -346,12 +349,14 @@ ErrorCode RasterExecution::onExecute(const std::vector<Tensor *> &inputs, const | |||
|                 return NO_ERROR; | ||||
|             } | ||||
|             UnpackBuffer(dstPtr, srcPtr, &pack, bytes, runtime); | ||||
|             checkKernelErrors;           | ||||
|         } else { | ||||
|             if (output->dimensions() <= 1) { | ||||
|                 cudaMemcpy(dstPtr, srcPtr, bn->realSize(realInput) * bytes, cudaMemcpyDeviceToDevice); | ||||
|                 return NO_ERROR; | ||||
|             } | ||||
|             PackBuffer(dstPtr, srcPtr, &pack, bytes, runtime); | ||||
|             checkKernelErrors;          | ||||
|         } | ||||
|         return NO_ERROR; | ||||
|     } | ||||
|  | @ -359,9 +364,11 @@ ErrorCode RasterExecution::onExecute(const std::vector<Tensor *> &inputs, const | |||
|     if (mNeedZero) { | ||||
|         auto size = static_cast<CUDABackend*>(backend())->realSize(mOutputPtr) * bytes; | ||||
|         cudaMemset((uint8_t*)mOutputPtr->deviceId(), 0, size); | ||||
|         checkKernelErrors; | ||||
|     } | ||||
|     for (auto& iter : mTempInput) { | ||||
|         backend()->onCopyBuffer(iter.first, iter.second); | ||||
|         checkKernelErrors; | ||||
|     } | ||||
|     //MNN_PRINT("\n%d\n", mFuseRaster.first);
 | ||||
|     if(mFuseRaster.first > 0) { | ||||
|  | @ -373,16 +380,19 @@ ErrorCode RasterExecution::onExecute(const std::vector<Tensor *> &inputs, const | |||
|         //MNN_PRINT("fuseRaster:%p-%p\n", mSrcOffset, mDstOffset);
 | ||||
| 
 | ||||
|         FuseRasterBlit(dstPtr, srcPtr, slice.size, slice.src.stride, slice.dst.stride, mFuseRaster.second, mOffset, bytes, runtime, mFuseRaster.first); | ||||
|         checkKernelErrors; | ||||
|     } else { | ||||
|         for (auto& iter : mTempInputCopy) { | ||||
|             auto srcPtr = (uint8_t*)iter.first->deviceId() + iter.second->src.offset * bytes; | ||||
|             auto dstPtr = (uint8_t*)mOutputPtr->deviceId() + iter.second->dst.offset * bytes; | ||||
|             RasterBlit(dstPtr, srcPtr, iter.second->size, iter.second->src.stride, iter.second->dst.stride, bytes, runtime); | ||||
|             checkKernelErrors; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     if (nullptr != mTempOutput) { | ||||
|         backend()->onCopyBuffer(mTempOutput.get(), output); | ||||
|         checkKernelErrors; | ||||
|     } | ||||
|     return NO_ERROR; | ||||
| } | ||||
|  |  | |||
|  | @ -209,12 +209,12 @@ ErrorCode SoftmaxExecution::onExecute(const std::vector<Tensor *> &inputs, const | |||
|     int block_num = runtime->blocks_num(count); | ||||
|     int threads_num = runtime->threads_num(); | ||||
|     if (static_cast<CUDABackend*>(backend())->useFp16()) { | ||||
|         if(axis % 256 == 0 || axis >= 768) { | ||||
| 	if(axis % 256 == 0 || axis >= 768) { | ||||
|             block_num = count; | ||||
|             int calc_multi_num = (axis + 255) / 256; | ||||
|             SOFTMAX_AXIS_REDUCE<<<block_num, 256>>>((const half*)input, (half*)dst, inside, axis, 256, calc_multi_num, outside, count); | ||||
|             checkKernelErrors; | ||||
|         } else if(axis % 64 == 0 || axis >= 256) { | ||||
|         } else if(axis % 64 == 0 || axis > 32) { | ||||
|             block_num = count; | ||||
|             int calc_multi_num = (axis + 63) / 64; | ||||
|             SOFTMAX_AXIS_REDUCE<<<block_num, 64>>>((const half*)input, (half*)dst, inside, axis, 64, calc_multi_num, outside, count); | ||||
|  | @ -234,7 +234,7 @@ ErrorCode SoftmaxExecution::onExecute(const std::vector<Tensor *> &inputs, const | |||
|             int calc_multi_num = (axis + 255) / 256; | ||||
|             SOFTMAX_AXIS_REDUCE<<<block_num, 256>>>((const float*)input, (float*)dst, inside, axis, 256, calc_multi_num, outside, count); | ||||
|             checkKernelErrors; | ||||
|         } else if(axis % 64 == 0 || axis >= 256) { | ||||
|         } else if(axis % 64 == 0 || axis > 32) { | ||||
|             block_num = count; | ||||
|             int calc_multi_num = (axis + 63) / 64; | ||||
|             SOFTMAX_AXIS_REDUCE<<<block_num, 64>>>((const float*)input, (float*)dst, inside, axis, 64, calc_multi_num, outside, count); | ||||
|  |  | |||
|  | @ -0,0 +1,312 @@ | |||
| #include "TopKV2Execution.hpp" | ||||
| #include <memory> | ||||
| 
 | ||||
| 
 | ||||
| namespace MNN { | ||||
| namespace CUDA { | ||||
| 
 | ||||
| 
 | ||||
| // rank TopK in the corresponding thead | ||||
| template<typename indexT, typename valueT> | ||||
| __device__ void TopKInThread(const valueT * inputDevice, indexT * indicesThread, valueT * valuesThread, const int K, const int numElePerRow, const valueT minValue, const int descendFlag) { | ||||
|     for (int i = 0 ; i < K; i++) { | ||||
|         indicesThread[i] = -1; | ||||
|         valuesThread[i] = (valueT)(descendFlag) * minValue; | ||||
|     } | ||||
| 
 | ||||
|     int idxFirstEleInRow = threadIdx.x + blockIdx.x * blockDim.x; | ||||
| 
 | ||||
|     for (indexT i =  idxFirstEleInRow; i < numElePerRow; i += gridDim.x * blockDim.x) { | ||||
|         valueT data = inputDevice[i]; | ||||
|         if ((valueT)(descendFlag) * data <= (valueT)(descendFlag) * valuesThread[K - 1]) { | ||||
|             continue; | ||||
|         } else { | ||||
|             for (int j = K - 2; j >= 0; j--) { | ||||
|                 if ((valueT)(descendFlag) * data > (valueT)(descendFlag) * valuesThread[j]) { | ||||
|                     valuesThread[j + 1] = valuesThread[j]; | ||||
|                     indicesThread[j + 1] = indicesThread[j]; | ||||
|                     valuesThread[j] = data; | ||||
|                     indicesThread[j] = i; | ||||
|                 } else { | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     return; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| // reduce TopK results of two offsets | ||||
| template<typename indexT, typename valueT> | ||||
| __device__ void ReduceTopK(indexT * indicesArray, valueT * valuesArray, const int offset1, const int offset2, const int K, const int descendFlag) { | ||||
|     indexT idx1 = offset1 + K - 1; | ||||
|     indexT idx2 = offset2 + K - 1; | ||||
|     indexT idxVirtual = offset1 + 2 * K -1; | ||||
| 
 | ||||
|     while (idx2 >= offset2) { | ||||
|         if (idx1 < offset1) { | ||||
|             while (idxVirtual >= offset1) { | ||||
|                 indicesArray[idxVirtual] = indicesArray[offset2 + (idxVirtual - offset1)]; | ||||
|                 valuesArray[idxVirtual] = valuesArray[offset2 + (idxVirtual - offset1)]; | ||||
|                 idxVirtual --; | ||||
|             } | ||||
|             break; | ||||
|         } | ||||
| 
 | ||||
|         if ((valueT)(descendFlag) * valuesArray[idx1] <= (valueT)(descendFlag) * valuesArray[idx2]) { | ||||
|             if (idxVirtual <= offset1 + K - 1) { | ||||
|                 indicesArray[idxVirtual] = indicesArray[idx1]; | ||||
|                 valuesArray[idxVirtual] = valuesArray[idx1]; | ||||
|             } | ||||
|             idx1 --; | ||||
|         } else { | ||||
|             if (idxVirtual <= offset1 + K - 1) { | ||||
|                 indicesArray[idxVirtual] = indicesArray[idx2]; | ||||
|                 valuesArray[idxVirtual] = valuesArray[idx2]; | ||||
|             } | ||||
|             idx2 --; | ||||
|         } | ||||
|         idxVirtual --; | ||||
|     } | ||||
| 
 | ||||
|     return; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| // get results of all blocks' TopK in one row | ||||
| template<typename indexT, typename valueT> | ||||
| __device__ void TopKOneRow(const valueT * inputDevice, indexT * indicesBlock, valueT * valuesBlock, indexT * tempIndicesDevice, valueT * tempValuesDevice, const int K, const int lengthRow, valueT minValue, const int descendFlag) { | ||||
|     indexT * indicesThread = indicesBlock + threadIdx.x * K; | ||||
|     valueT * valuesThread = valuesBlock + threadIdx.x * K; | ||||
| 
 | ||||
|     // rank TopK | ||||
|     TopKInThread<indexT, valueT>(inputDevice, indicesThread, valuesThread, K, lengthRow, minValue, descendFlag); | ||||
| 
 | ||||
|     __syncthreads(); | ||||
| 
 | ||||
|     // reduce | ||||
|     for(int stride = (blockDim.x >> 1); stride > 0; stride >>= 1) { | ||||
|         if(threadIdx.x < stride) { | ||||
|             ReduceTopK<indexT, valueT>(indicesBlock, valuesBlock, threadIdx.x * K, (threadIdx.x + stride) * K, K, descendFlag); | ||||
|         } | ||||
|         __syncthreads(); | ||||
|     } | ||||
| 
 | ||||
|     // move data from block's smem to global memory(prepare for the next kernel function) | ||||
|     if (threadIdx.x == 0) { | ||||
|         for(int i = 0; i < K; i++) { | ||||
|             tempIndicesDevice[K * blockIdx.x + i] = indicesBlock[i]; | ||||
|             tempValuesDevice[K * blockIdx.x + i] = valuesBlock[i]; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     return; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| // get results of the final TopK from all block's TopK in a row | ||||
| template<typename indexT, typename valueT> | ||||
| __device__ void GetResultOneRow(indexT * outputIndicesDevice, valueT * outputValuesDevice, indexT * tempIndicesDevice, valueT * tempValuesDevice, indexT * finalIndices, valueT * finalValues, const int K, const int reduceLength, const int descendFlag) { | ||||
|     // move data from global memory to a block's smem | ||||
|     if (threadIdx.x < reduceLength) { | ||||
|         for (int i = 0; i < K; i++) { | ||||
|             finalIndices[threadIdx.x * K + i] = tempIndicesDevice[threadIdx.x * K + i]; | ||||
|             finalValues[threadIdx.x * K + i] = tempValuesDevice[threadIdx.x * K + i]; | ||||
|         } | ||||
|     } | ||||
|     __syncthreads(); | ||||
| 
 | ||||
|     // the first round of reducing needs special action | ||||
|     int stride = blockDim.x >> 1; | ||||
|     if ((threadIdx.x < stride) && (threadIdx.x + stride < reduceLength)) { | ||||
|         ReduceTopK<indexT, valueT>(finalIndices, finalValues, threadIdx.x * K, (threadIdx.x + stride) * K, K, descendFlag); | ||||
|     } | ||||
|     __syncthreads(); | ||||
|     stride >>= 1; | ||||
| 
 | ||||
|     // the remaining rounds of reducing | ||||
|     for (; stride > 0; stride >>= 1) { | ||||
|         if (threadIdx.x < stride) { | ||||
|             ReduceTopK<indexT, valueT>(finalIndices, finalValues, threadIdx.x * K, (threadIdx.x + stride) * K, K, descendFlag); | ||||
|         } | ||||
|         __syncthreads(); | ||||
|     } | ||||
| 
 | ||||
|     //move data from a block's smem to global memory | ||||
|     if (threadIdx.x == 0) { | ||||
|         for (int i = 0; i < K; i++) { | ||||
|             outputIndicesDevice[i] = finalIndices[i]; | ||||
|             outputValuesDevice[i] = finalValues[i]; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     return; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| // allocate addresses for each row and call <TopKOneRow> | ||||
| template<typename indexT, typename valueT> | ||||
| __global__ void TopKAllRows(const valueT * inputDevice, indexT * tempIndicesDevice, valueT * tempValuesDevice, const int K, const int lengthRow, valueT minValue, const int descendFlag) { | ||||
|     extern __shared__ char smem[]; | ||||
|     indexT * indicesBlock = reinterpret_cast<indexT *>(smem); | ||||
|     valueT * valuesBlock = reinterpret_cast<valueT *>(&smem[blockDim.x * K * sizeof(indexT)]); | ||||
| 
 | ||||
|     int idxRow = blockIdx.y; | ||||
| 
 | ||||
|     const valueT * inputDeviceThisRow = inputDevice + idxRow * lengthRow; | ||||
|     indexT * tempIndicesDeviceThisRow = tempIndicesDevice + idxRow * gridDim.x * K; | ||||
|     valueT * tempValuesDeviceThisRow = tempValuesDevice + idxRow * gridDim.x * K; | ||||
| 
 | ||||
|     TopKOneRow<indexT, valueT>(inputDeviceThisRow, indicesBlock, valuesBlock, tempIndicesDeviceThisRow, tempValuesDeviceThisRow, K, lengthRow, minValue, descendFlag); | ||||
|     __syncthreads(); | ||||
| 
 | ||||
|     return; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| // allocate addresses for each row and call <GetResultOneRow> | ||||
| // This kernel assumes that each row of data corresponds to one block. | ||||
| template<typename indexT, typename valueT> | ||||
| __global__ void GetResultAllRows(indexT * outputIndicesDevice, valueT * outputValuesDevice, indexT * tempIndicesDevice, valueT * tempValuesDevice, const int K, const int numBlockPerRow, const int descendFlag) { | ||||
|     extern __shared__ char smem[]; | ||||
|     indexT * finalIndices = reinterpret_cast<indexT *>(smem); | ||||
|     valueT * finalValues = reinterpret_cast<valueT *>(&smem[numBlockPerRow * K * sizeof(indexT)]); | ||||
| 
 | ||||
|     int idxRow = blockIdx.x; // each block corresponds to a row | ||||
| 
 | ||||
|     indexT * outputIndicesDeviceThisRow = outputIndicesDevice + idxRow * K; | ||||
|     valueT * outputValuesDeviceThisRow = outputValuesDevice + idxRow * K; | ||||
|     indexT * tempIndicesDeviceThisRow = tempIndicesDevice + idxRow * numBlockPerRow * K; | ||||
|     valueT * tempValuesDeviceThisRow = tempValuesDevice + idxRow * numBlockPerRow * K; | ||||
| 
 | ||||
|     GetResultOneRow<indexT, valueT>(outputIndicesDeviceThisRow, outputValuesDeviceThisRow, tempIndicesDeviceThisRow, tempValuesDeviceThisRow, finalIndices, finalValues, K, numBlockPerRow, descendFlag); | ||||
| 
 | ||||
|     return; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| int CalculateNumThreadPerBlock(const int K) { | ||||
|     int numThreadPerBlock; | ||||
|     if (K <= 48) { | ||||
|         numThreadPerBlock = 128; | ||||
|     } else if (K <= 96) { | ||||
|         numThreadPerBlock = 64; | ||||
|     } else { | ||||
|         numThreadPerBlock = 32; | ||||
|     } | ||||
| 
 | ||||
|     return numThreadPerBlock; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| TopKV2Execution::TopKV2Execution(const Op* op, Backend* backend) : Execution(backend) { | ||||
|     mOp = op; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| ErrorCode TopKV2Execution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) { | ||||
|     // prepare some params for the kernel function | ||||
|     Tensor * inputTensor = inputs[0]; | ||||
| 
 | ||||
|     int lengthRow = inputTensor->buffer().dim[inputTensor->buffer().dimensions - 1].extent; | ||||
|     int numRow = inputTensor->elementSize() / lengthRow; | ||||
| 
 | ||||
|     mParams.mLengthRow = lengthRow; | ||||
|     mParams.mNumRow = numRow; | ||||
| 
 | ||||
|     auto boolDescendFlag = mOp->main_as_TopKV2(); | ||||
|     if (boolDescendFlag != nullptr) { | ||||
|         mParams.mDescendFlag = boolDescendFlag ? 1 : -1; | ||||
|     } | ||||
| 
 | ||||
|     mParams.mNumElePerRow = mParams.mLengthRow; | ||||
|     mParams.mNumK = outputs[0]->buffer().dim[outputs[0]->buffer().dimensions-1].extent; | ||||
|     mParams.mNumElePerThread = mParams.mNumK; | ||||
|     mParams.mNumThreadPerBlock = CalculateNumThreadPerBlock(mParams.mNumK); | ||||
|     mParams.mNumElePerBlock = mParams.mNumElePerThread * mParams.mNumThreadPerBlock; | ||||
|     mParams.mNumBlockPerRow = (mParams.mNumElePerRow - 1 + mParams.mNumElePerBlock) / mParams.mNumElePerBlock; | ||||
|     mParams.mNumBlockFinal = mParams.mNumRow; | ||||
|     mParams.mNumThreadFinal = std::pow(2, (std::ceil(std::log2(mParams.mNumBlockPerRow)))); | ||||
|     mParams.mNumBlockTotal = mParams.mNumBlockPerRow * mParams.mNumRow; | ||||
| 
 | ||||
|     // prepare temp buffer | ||||
|     auto pool = static_cast<CUDABackend*>(backend())->getStaticBufferPool(); | ||||
| 
 | ||||
|     if (inputTensor->getType().code == halide_type_int && inputTensor->getType().bits == 32) { | ||||
|         std::pair<void*, int> bufferIndices = pool->alloc(mParams.mNumBlockTotal * mParams.mNumK * sizeof(int)); | ||||
|         mParams.mBufferIndices = (void*)((uint8_t*)bufferIndices.first + bufferIndices.second); | ||||
|         std::pair<void*, int> bufferValues = pool->alloc(mParams.mNumBlockTotal * mParams.mNumK * sizeof(int)); | ||||
|         mParams.mBufferValues = (void*)((uint8_t*)bufferValues.first + bufferValues.second); | ||||
|         pool->free(bufferIndices); | ||||
|         pool->free(bufferValues); | ||||
|     } else if (static_cast<CUDABackend*>(backend())->useFp16()) { | ||||
|         std::pair<void*, int> bufferIndices = pool->alloc(mParams.mNumBlockTotal * mParams.mNumK * sizeof(int)); | ||||
|         mParams.mBufferIndices = (void*)((uint8_t*)bufferIndices.first + bufferIndices.second); | ||||
|         std::pair<void*, int> bufferValues = pool->alloc(mParams.mNumBlockTotal * mParams.mNumK * sizeof(half)); | ||||
|         mParams.mBufferValues = (void*)((uint8_t*)bufferValues.first + bufferValues.second); | ||||
|         pool->free(bufferIndices); | ||||
|         pool->free(bufferValues); | ||||
|     } else { | ||||
|         std::pair<void*, int> bufferIndices = pool->alloc(mParams.mNumBlockTotal * mParams.mNumK * sizeof(int)); | ||||
|         mParams.mBufferIndices = (void*)((uint8_t*)bufferIndices.first + bufferIndices.second); | ||||
|         std::pair<void*, int> bufferValues = pool->alloc(mParams.mNumBlockTotal * mParams.mNumK * sizeof(float)); | ||||
|         mParams.mBufferValues = (void*)((uint8_t*)bufferValues.first + bufferValues.second); | ||||
|         pool->free(bufferIndices); | ||||
|         pool->free(bufferValues); | ||||
|     } | ||||
| 
 | ||||
|     return NO_ERROR; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| ErrorCode TopKV2Execution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) { | ||||
|     // get input and output pointers | ||||
|     void * inputDeviceAddr = reinterpret_cast<void *>(inputs[0]->deviceId()); | ||||
|     void * outputIndicesDeviceAddr = reinterpret_cast<void *>(outputs[1]->deviceId()); | ||||
|     void * outputValuesDeviceAddr = reinterpret_cast<void *>(outputs[0]->deviceId()); | ||||
| 
 | ||||
|     // configure threads | ||||
|     dim3 grid1 = {mParams.mNumBlockPerRow, mParams.mNumRow}; | ||||
|     dim3 block1 = {mParams.mNumThreadPerBlock, 1}; | ||||
|     int smemSize_1 = mParams.mNumThreadPerBlock * mParams.mNumK; | ||||
|     dim3 grid2 = {mParams.mNumBlockFinal}; | ||||
|     dim3 block2 = {mParams.mNumThreadFinal}; | ||||
|     int smemSize_2 = mParams.mNumBlockPerRow * mParams.mNumK; | ||||
| 
 | ||||
|     if (inputs[0]->getType().code == halide_type_int && inputs[0]->getType().bits == 32) { | ||||
|         TopKAllRows<int, int><<<grid1, block1, smemSize_1 * (sizeof(int) + sizeof(int))>>>(static_cast<const int *>(inputDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<int *>(mParams.mBufferValues), mParams.mNumK, mParams.mLengthRow, mParams.mMinInt, mParams.mDescendFlag); | ||||
|         checkKernelErrors; | ||||
|         GetResultAllRows<int, int><<<grid2, block2, smemSize_2 * (sizeof(int) + sizeof(int))>>>(static_cast<int *>(outputIndicesDeviceAddr), static_cast<int *>(outputValuesDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<int *>(mParams.mBufferValues), mParams.mNumK, mParams.mNumBlockPerRow, mParams.mDescendFlag); | ||||
|         checkKernelErrors; | ||||
|     } else if (static_cast<CUDABackend*>(backend())->useFp16()) { | ||||
|         TopKAllRows<int, half><<<grid1, block1, smemSize_1 * (sizeof(float) + sizeof(int))>>>(static_cast<const half *>(inputDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<half *>(mParams.mBufferValues), mParams.mNumK, mParams.mLengthRow, mParams.mMinHalf, mParams.mDescendFlag); | ||||
|         checkKernelErrors; | ||||
|         GetResultAllRows<int, half><<<grid2, block2, smemSize_2 * (sizeof(float) + sizeof(int))>>>(static_cast<int *>(outputIndicesDeviceAddr), static_cast<half *>(outputValuesDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<half *>(mParams.mBufferValues), mParams.mNumK, mParams.mNumBlockPerRow, mParams.mDescendFlag); | ||||
|         checkKernelErrors; | ||||
|     } else { | ||||
|         TopKAllRows<int, float><<<grid1, block1, smemSize_1 * (sizeof(float) + sizeof(int))>>>(static_cast<const float *>(inputDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<float *>(mParams.mBufferValues), mParams.mNumK, mParams.mLengthRow, mParams.mMinFloat, mParams.mDescendFlag); | ||||
|         checkKernelErrors; | ||||
|         GetResultAllRows<int, float><<<grid2, block2, smemSize_2 * (sizeof(float) + sizeof(int))>>>(static_cast<int *>(outputIndicesDeviceAddr), static_cast<float *>(outputValuesDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<float *>(mParams.mBufferValues), mParams.mNumK, mParams.mNumBlockPerRow, mParams.mDescendFlag); | ||||
|         checkKernelErrors; | ||||
|     } | ||||
| 
 | ||||
|     return NO_ERROR; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| class TopKV2Creator : public CUDABackend::Creator { | ||||
| public: | ||||
|     virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, | ||||
|                                 const MNN::Op* op, Backend* backend) const override { | ||||
|         return new TopKV2Execution(op, backend); | ||||
|     } | ||||
| }; | ||||
| 
 | ||||
| 
 | ||||
| static CUDACreatorRegister<TopKV2Creator> __init(OpType_TopKV2); | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| } | ||||
|  | @ -0,0 +1,68 @@ | |||
| //
 | ||||
| //  TopKV2Execution.hpp
 | ||||
| //  MNN
 | ||||
| //
 | ||||
| //  Created by MNN on 2023/07/19.
 | ||||
| //  Copyright © 2018, Alibaba Group Holding Limited
 | ||||
| //
 | ||||
| 
 | ||||
| 
 | ||||
| #ifndef TopKV2Execution_hpp | ||||
| #define TopKV2Execution_hpp | ||||
| 
 | ||||
| #include "core/Execution.hpp" | ||||
| #include "core/Macro.h" | ||||
| #include "backend/cuda/core/CUDABackend.hpp" | ||||
| #include <memory> | ||||
| #include <limits> | ||||
| #include "cuda_fp16.h" | ||||
| 
 | ||||
| namespace MNN { | ||||
| namespace CUDA { | ||||
| 
 | ||||
| 
 | ||||
| class TopKV2Execution : public Execution { | ||||
| public: | ||||
|     TopKV2Execution(const Op * op, Backend * backend); | ||||
|     virtual ~TopKV2Execution() = default; | ||||
| 
 | ||||
|     virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; | ||||
|     virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; | ||||
| 
 | ||||
| private: | ||||
|     struct TopKV2Params { | ||||
|         int mLengthRow; | ||||
|         int mNumRow; | ||||
|         int mDescendFlag = 1; | ||||
|         void * mBufferIndices; | ||||
|         void * mBufferValues; | ||||
| 
 | ||||
|         int mNumK; | ||||
|         int mNumElePerRow; | ||||
|         int mNumElePerThread; | ||||
|         int mNumThreadPerBlock; | ||||
|         int mNumElePerBlock; | ||||
|         int mNumBlockPerRow; | ||||
|         int mNumBlockTotal; | ||||
|         int mNumBlockFinal; | ||||
|         int mNumThreadFinal; | ||||
| 
 | ||||
|         float mMinFloat = std::numeric_limits<float>::lowest(); | ||||
|         half mMinHalf = __float2half(-65504.0f); | ||||
|         int mMinInt = -std::numeric_limits<int>::max(); | ||||
|     }; | ||||
| 
 | ||||
|     const Op * mOp; | ||||
|     TopKV2Params mParams; | ||||
| }; | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| } | ||||
| 
 | ||||
| #endif | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|  | @ -180,13 +180,13 @@ void UnpackBuffer(void* output, const void* input, const PackInfo* info, int byt | |||
|         int block_num = runtime->blocks_num(maxCount); | ||||
|         int block_size = runtime->threads_num(); | ||||
|         int axisAlign = UP_DIV(info->axis / 4, PACK_NUMBER / 4) * PACK_NUMBER / 4;; | ||||
| 	if(bytes == 4) {         | ||||
| 	    UNPACKCOMMON_4<<<block_num, block_size>>>((const int4*)input, (int4*)output,  | ||||
|                         maxCount, info->inside, info->axis / 4, info->outside, | ||||
|                         info->insideStride / 4, info->axisStride, axisAlign, is, cs); | ||||
| 	    checkKernelErrors; | ||||
| 	    return; | ||||
| 	}         | ||||
|         if(bytes == 4) {         | ||||
|             UNPACKCOMMON_4<<<block_num, block_size>>>((const int4*)input, (int4*)output,  | ||||
|                             maxCount, info->inside, info->axis / 4, info->outside, | ||||
|                             info->insideStride / 4, info->axisStride, axisAlign, is, cs); | ||||
|             checkKernelErrors; | ||||
|             return; | ||||
|         }         | ||||
|         if(bytes == 2) { | ||||
|             UNPACKCOMMON_4<<<block_num, block_size>>>((const int2*)input, (int2*)output, | ||||
|                         maxCount, info->inside, info->axis / 4, info->outside, | ||||
|  |  | |||
|  | @ -429,6 +429,7 @@ Execution* OpenCLBackend::onCreate(const std::vector<Tensor*>& inputs, const std | |||
|     auto iter      = creators->find(std::make_pair(op->type(), mOpenCLRuntime->getGpuMemType())); | ||||
| 
 | ||||
|     if (iter == creators->end()) { | ||||
|         mOpenCLRuntime->setDevideOpRecord(); | ||||
|         #if 0//close log
 | ||||
|         if (nullptr != op->name()) { | ||||
|             MNN_PRINT("Don't support type %s memObject:%d, %s\n", EnumNameOpType(op->type()), mOpenCLRuntime->getGpuMemType(), op->name()->c_str()); | ||||
|  | @ -462,6 +463,7 @@ Execution* OpenCLBackend::onCreate(const std::vector<Tensor*>& inputs, const std | |||
|         } | ||||
| 
 | ||||
|         if (!valid) { | ||||
|             mOpenCLRuntime->setDevideOpRecord(); | ||||
|             #if 0//close log
 | ||||
|             for (auto t : inputs) { | ||||
|                 auto tensorShape = OpenCL::tensorShapeFormat(t); | ||||
|  | @ -479,6 +481,7 @@ Execution* OpenCLBackend::onCreate(const std::vector<Tensor*>& inputs, const std | |||
| 
 | ||||
|     auto exe = iter->second->onCreate(inputs, outputs, op, this); | ||||
|     if (NULL == exe) { | ||||
|         mOpenCLRuntime->setDevideOpRecord(); | ||||
|         #if 0//close log
 | ||||
|         if (nullptr != op->name()) { | ||||
|             MNN_PRINT("The Creator Don't support type %s, memObject:%d, %s\n", MNN::EnumNameOpType(op->type()), mOpenCLRuntime->getGpuMemType(), op->name()->c_str()); | ||||
|  | @ -498,12 +501,14 @@ void OpenCLBackend::onResizeBegin() { | |||
| #ifndef ENABLE_OPENCL_TIME_PROFILER | ||||
|     mOpenCLRuntime->setCommandQueueProfileEnable(); | ||||
| #endif | ||||
|     mOpenCLRuntime->releaseRecord(); | ||||
| } | ||||
| 
 | ||||
| void OpenCLBackend::onResizeEnd() { | ||||
| #ifndef ENABLE_OPENCL_TIME_PROFILER | ||||
|     mOpenCLRuntime->setCommandQueueProfileDisable(); | ||||
| #endif | ||||
|     mOpenCLRuntime->endRecord(); | ||||
| } | ||||
| 
 | ||||
| void OpenCLBackend::onExecuteBegin() const { | ||||
|  | @ -515,6 +520,7 @@ void OpenCLBackend::onExecuteBegin() const { | |||
| void OpenCLBackend::onExecuteEnd() const { | ||||
|     mOpenCLRuntime->mQueueCount = 0; | ||||
|     mOpenCLRuntime->clearRecord(); | ||||
|     mOpenCLRuntime->enqeueRecord(); | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -569,11 +569,13 @@ void startRecord(OpenCLRuntime *runtime, cl_recording_qcom &recording){ | |||
|     MNN_PRINT("start startRecord !\n"); | ||||
| #endif | ||||
|     cl_int res = CL_SUCCESS; | ||||
|     if(recording != NULL){ | ||||
|         clReleaseRecordingQCOM(recording); | ||||
|     if(runtime->isDevideOpRecord()){ | ||||
|         if(recording != NULL){ | ||||
|             clReleaseRecordingQCOM(recording); | ||||
|         } | ||||
|         recording = runtime->recordableQueue().NewRecordingQCOM(&res); | ||||
|         MNN_CHECK_CL_SUCCESS(res, "clNewRecordingQCOM"); | ||||
|     } | ||||
|     recording = runtime->recordableQueue().NewRecordingQCOM(&res); | ||||
|     MNN_CHECK_CL_SUCCESS(res, "clNewRecordingQCOM"); | ||||
| #ifdef LOG_VERBOSE | ||||
|     MNN_PRINT("end startRecord !\n"); | ||||
| #endif | ||||
|  | @ -588,9 +590,11 @@ void endRecord(OpenCLRuntime *runtime, cl_recording_qcom &recording){ | |||
| #ifdef LOG_VERBOSE | ||||
|     MNN_PRINT("start endRecord !\n"); | ||||
| #endif | ||||
|     cl_int res = CL_SUCCESS; | ||||
|     res = clEndRecordingQCOM(recording); | ||||
|     MNN_CHECK_CL_SUCCESS(res, "clEndRecordingQCOM"); | ||||
|     if(runtime->isDevideOpRecord()){ | ||||
|         cl_int res = CL_SUCCESS; | ||||
|         res = clEndRecordingQCOM(recording); | ||||
|         MNN_CHECK_CL_SUCCESS(res, "clEndRecordingQCOM"); | ||||
|     } | ||||
| #ifdef LOG_VERBOSE | ||||
|     MNN_PRINT("end endRecord !\n"); | ||||
| #endif | ||||
|  | @ -607,6 +611,25 @@ void recordKernel2d(const ::cl::Kernel &kernel, const std::vector<uint32_t> &gws | |||
|     MNN_PRINT("start recordKernel !\n"); | ||||
| #endif | ||||
|     cl_int res = CL_SUCCESS; | ||||
|     if(!runtime->isDevideOpRecord()){ | ||||
|         auto RecordNum = runtime->getRecordNum(); | ||||
|         auto maxRecordNum = runtime->getUseRecordableQueueSize(); | ||||
|         if(RecordNum == 0){ | ||||
|             cl_recording_qcom recording = runtime->recordableQueue().NewRecordingQCOM(&res); | ||||
|             MNN_CHECK_CL_SUCCESS(res, "clNewRecordingQCOM"); | ||||
|             runtime->getRecordings()->emplace_back(recording); | ||||
|         }else if(RecordNum == maxRecordNum){ | ||||
|             res = clEndRecordingQCOM( runtime->getRecordings()->back()); | ||||
|             MNN_CHECK_CL_SUCCESS(res, "clEndRecordingQCOM"); | ||||
|             cl_recording_qcom recording = runtime->recordableQueue().NewRecordingQCOM(&res); | ||||
|             MNN_CHECK_CL_SUCCESS(res, "clNewRecordingQCOM"); | ||||
|             runtime->getRecordings()->emplace_back(recording); | ||||
|             RecordNum = 0; | ||||
|         } | ||||
|         RecordNum++; | ||||
|         runtime->setRecordNum(RecordNum); | ||||
|     } | ||||
|      | ||||
|     std::vector<uint32_t> internalGlobalWS = gws; | ||||
|     for (size_t i = 0; i < 2; ++i) { | ||||
|         internalGlobalWS[i] = ROUND_UP(gws[i], std::max((uint32_t)1, lws[i])); | ||||
|  | @ -642,7 +665,24 @@ void recordKernel3d(const ::cl::Kernel &kernel, const std::vector<uint32_t> &gws | |||
|     for (size_t i = 0; i < 3; ++i) { | ||||
|         internalGlobalWS[i] = ROUND_UP(gws[i], std::max((uint32_t)1, lws[i])); | ||||
|     } | ||||
| 
 | ||||
|     if(!runtime->isDevideOpRecord()){ | ||||
|         auto maxRecordNum = runtime->getUseRecordableQueueSize(); | ||||
|         auto RecordNum = runtime->getRecordNum(); | ||||
|         if(RecordNum == 0){ | ||||
|             cl_recording_qcom recording = runtime->recordableQueue().NewRecordingQCOM(&res); | ||||
|             MNN_CHECK_CL_SUCCESS(res, "clNewRecordingQCOM"); | ||||
|             runtime->getRecordings()->emplace_back(recording); | ||||
|         }else if(RecordNum == maxRecordNum){ | ||||
|             res = clEndRecordingQCOM( runtime->getRecordings()->back()); | ||||
|             MNN_CHECK_CL_SUCCESS(res, "clEndRecordingQCOM"); | ||||
|             cl_recording_qcom recording = runtime->recordableQueue().NewRecordingQCOM(&res); | ||||
|             MNN_CHECK_CL_SUCCESS(res, "clNewRecordingQCOM"); | ||||
|             runtime->getRecordings()->emplace_back(recording); | ||||
|             RecordNum = 0; | ||||
|         } | ||||
|         RecordNum++; | ||||
|         runtime->setRecordNum(RecordNum); | ||||
|     } | ||||
| 
 | ||||
|     if(lws[0]==0 || lws[1]==0 || lws[2]==0){ | ||||
|         res = runtime->recordableQueue().enqueueNDRangeKernel( | ||||
|  |  | |||
|  | @ -235,9 +235,12 @@ OpenCLRuntime::OpenCLRuntime(const BackendConfig::PrecisionMode precision, const | |||
| #if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER) | ||||
|             { | ||||
|                 if((false == OpenCLSymbolsOperator::getOpenclSymbolsPtr()->isQcomError()) && getDeviceSupportsExtension(*(mFirstGPUDevicePtr.get()), "cl_qcom_recordable_queues")){ | ||||
|                     mMaxRecordableQueueSize = mFirstGPUDevicePtr->getInfo<CL_DEVICE_RECORDABLE_QUEUE_MAX_SIZE>(); | ||||
|                     uint32_t MaxRecordableQueueSize = mFirstGPUDevicePtr->getInfo<CL_DEVICE_RECORDABLE_QUEUE_MAX_SIZE>(); | ||||
|                     cl_int err; | ||||
|                     if(mMaxRecordableQueueSize > 0){ | ||||
|                     if(MaxRecordableQueueSize > 0){ | ||||
|                         // TODO: Use setSessionHint to set the number of mUseRecordableQueueSize
 | ||||
|                         mUseRecordableQueueSize = 10; | ||||
|                         mUseRecordableQueueSize = MaxRecordableQueueSize < mUseRecordableQueueSize ? MaxRecordableQueueSize : mUseRecordableQueueSize; | ||||
|                         mUseRecordQueue = true; | ||||
|                         mRecordableQueuePtr = std::make_shared<cl::CommandQueue>(*mContext, *mFirstGPUDevicePtr, CL_QUEUE_RECORDABLE_QCOM, &err); | ||||
|                         if(err != CL_SUCCESS){ | ||||
|  | @ -309,6 +312,23 @@ void OpenCLRuntime::setGpuMode(const int cl_mode_num) { | |||
|     if(totalSet != 1) { | ||||
|         MNN_PRINT("set multi tuning mode is not permitted, please check cl_mode:%x!\n", cl_mode_num); | ||||
|     } | ||||
|      | ||||
|     totalSet = 0; | ||||
|     isSet = (cl_mode_num & MNN_GPU_RECORD_OP); | ||||
|     if(isSet) { | ||||
|         mDevideOpRecord = true; | ||||
|         totalSet++; | ||||
|     } | ||||
|      | ||||
|     isSet = (cl_mode_num & MNN_GPU_RECORD_BATCH); | ||||
|     if(isSet) { | ||||
|         mDevideOpRecord = false; | ||||
|         totalSet++; | ||||
|     } | ||||
|      | ||||
|     if(totalSet > 1) { | ||||
|         MNN_PRINT("set multi record kernel mode is not permitted, please check cl_mode:%x!\n", cl_mode_num); | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| void OpenCLRuntime::setCommandQueueProfileEnable() { | ||||
|  | @ -344,6 +364,7 @@ OpenCLRuntime::~OpenCLRuntime() { | |||
| #ifdef LOG_VERBOSE | ||||
|     MNN_PRINT("start ~OpenCLRuntime !\n"); | ||||
| #endif | ||||
|     releaseRecord(); | ||||
|     mBuildProgramMap.clear(); | ||||
|     mRecordings.clear(); | ||||
|     mCommandQueuePtr.reset(); | ||||
|  | @ -711,7 +732,7 @@ bool OpenCLRuntime::setCache(std::pair<const void*, size_t> cache) { | |||
| 
 | ||||
| void OpenCLRuntime::clearRecord(){ | ||||
| #if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER) | ||||
|     if(mUseRecordQueue){ | ||||
|     if(mUseRecordQueue && mDevideOpRecord){ | ||||
|         for(int i = 0; i < mRecordings.size(); ++i){ | ||||
|             cl_int res = mCommandQueuePtr->EnqueueRecordingQCOM(mRecordings[i], 0, nullptr, 0, nullptr, | ||||
|                   0, nullptr, 0, nullptr, 0, nullptr, nullptr); | ||||
|  | @ -722,4 +743,40 @@ void OpenCLRuntime::clearRecord(){ | |||
|     } | ||||
| #endif | ||||
| } | ||||
| 
 | ||||
| void OpenCLRuntime::enqeueRecord(){ | ||||
| #if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER) | ||||
|     if(mUseRecordQueue && !mDevideOpRecord){ | ||||
|         for(int i = 0; i < mRecordings.size(); ++i){ | ||||
|             cl_int res = mCommandQueuePtr->EnqueueRecordingQCOM(mRecordings[i], 0, nullptr, 0, nullptr, | ||||
|                   0, nullptr, 0, nullptr, 0, nullptr, nullptr); | ||||
|             MNN_CHECK_CL_SUCCESS(res, "EnqueueRecordingQCOM"); | ||||
|         } | ||||
|         mCommandQueuePtr->finish(); | ||||
|     } | ||||
| #endif | ||||
| } | ||||
| 
 | ||||
| void OpenCLRuntime::endRecord(){ | ||||
| #if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER) | ||||
|     if(mUseRecordQueue  && !mDevideOpRecord){ | ||||
|         if(!mRecordings.empty()){ | ||||
|             cl_int res = clEndRecordingQCOM(mRecordings.back()); | ||||
|             MNN_CHECK_CL_SUCCESS(res, "clEndRecordingQCOM"); | ||||
|         } | ||||
|     } | ||||
| #endif | ||||
| } | ||||
| 
 | ||||
| void OpenCLRuntime::releaseRecord(){ | ||||
| #if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER) | ||||
|     if(mUseRecordQueue  && !mDevideOpRecord){ | ||||
|         for(int i = 0; i < mRecordings.size(); ++i){ | ||||
|             cl_int res = clReleaseRecordingQCOM(mRecordings[i]); | ||||
|             MNN_CHECK_CL_SUCCESS(res, "clReleaseRecordingQCOM"); | ||||
|         } | ||||
|         mRecordings.clear(); | ||||
|     } | ||||
| #endif | ||||
| } | ||||
| } // namespace MNN
 | ||||
|  |  | |||
|  | @ -72,12 +72,24 @@ public: | |||
|     std::vector<cl_recording_qcom> *getRecordings(){ | ||||
|         return &mRecordings; | ||||
|     } | ||||
|     uint32_t getMaxRecordableQueueSize(){ | ||||
|         return mMaxRecordableQueueSize; | ||||
|     uint32_t getUseRecordableQueueSize(){ | ||||
|         return mUseRecordableQueueSize; | ||||
|     } | ||||
|     bool isUseRecordQueue(){ | ||||
|         return mUseRecordQueue; | ||||
|     } | ||||
|     bool isDevideOpRecord(){ | ||||
|         return mDevideOpRecord; | ||||
|     } | ||||
|     void setDevideOpRecord(){ | ||||
|         mDevideOpRecord = true; | ||||
|     } | ||||
|     void setRecordNum(int num){ | ||||
|         mRecordNums = num; | ||||
|     } | ||||
|     uint32_t getRecordNum(){ | ||||
|         return mRecordNums; | ||||
|     } | ||||
|     GpuType getGpuType() { | ||||
|         return mGpuType; | ||||
|     } | ||||
|  | @ -105,6 +117,9 @@ public: | |||
|     void setCommandQueueProfileEnable(); | ||||
|     void setCommandQueueProfileDisable(); | ||||
|     void clearRecord(); | ||||
|     void enqeueRecord(); | ||||
|     void endRecord(); | ||||
|     void releaseRecord(); | ||||
| 
 | ||||
|     unsigned int mQueueCount = 0; | ||||
|     unsigned int getQueueNum(); | ||||
|  | @ -153,8 +168,10 @@ private: | |||
|     uint64_t mMaxLocalMemSize; | ||||
|     uint32_t mMaxThreadsPerDevice; | ||||
|     uint32_t mMaxWorkGroupSize; | ||||
|     uint32_t mMaxRecordableQueueSize; | ||||
|     uint32_t mUseRecordableQueueSize; | ||||
|     uint32_t mRecordNums = 0; | ||||
|     bool mUseRecordQueue = false; | ||||
|     bool mDevideOpRecord = true; | ||||
|     bool mIsSupportedFP16     = false; | ||||
|     bool mIsDeviceSupportedFP16 = false; | ||||
|     bool mIsDeviceSupportedLowPower = false; | ||||
|  |  | |||
|  | @ -59,9 +59,9 @@ std::pair<std::vector<uint32_t>,  uint32_t> ConvBufCommonExecution::gws2dLwsTune | |||
|                         lws_prefer[1] = lws[1]; | ||||
|                     } | ||||
|                 } | ||||
|                 lws[0]++; | ||||
|                 lws[0]<<=1; | ||||
|             } | ||||
|             lws[1]++; | ||||
|             lws[1]<<=1; | ||||
|         } | ||||
|     } else if(runtime->getCLTuneLevel() == Wide) { | ||||
|         while(lws[1] <= gws[1] || lws[1] <= 6) { | ||||
|  | @ -88,12 +88,12 @@ std::pair<std::vector<uint32_t>,  uint32_t> ConvBufCommonExecution::gws2dLwsTune | |||
|                     } | ||||
|                 } | ||||
|                 do { | ||||
|                     lws[0]++; | ||||
|                     lws[0]<<=1; | ||||
|                 } | ||||
|                 while(((2*gws[0])%lws[0] > 1) && (lws[0] & (lws[0] - 1)) != 0 && (lws[0] <= gws[0]) && (lws[0] > 6));//divisible powOfTwo lessThanSix
 | ||||
|             } | ||||
|             do { | ||||
|                 lws[1]++; | ||||
|                 lws[1]<<=1; | ||||
|             } | ||||
|             while(((2*gws[1])%lws[1] > 1) && (lws[1] & (lws[1] - 1)) != 0 && (lws[1] <= gws[1]) && (lws[1] > 6));//divisible powOfTwo lessThanSix
 | ||||
|         } | ||||
|  | @ -122,12 +122,12 @@ std::pair<std::vector<uint32_t>,  uint32_t> ConvBufCommonExecution::gws2dLwsTune | |||
|                     } | ||||
|                 } | ||||
|                 do { | ||||
|                     lws[0]++; | ||||
|                     lws[0]<<=1; | ||||
|                 } | ||||
|                 while(((2*gws[0])%lws[0] > 1) && (lws[0] & (lws[0] - 1)) != 0 && (lws[0] <= gws[0]) && (lws[0] > 6));//divisible powOfTwo lessThanSix
 | ||||
|             } | ||||
|             do { | ||||
|                 lws[1]++; | ||||
|                 lws[1]<<=1; | ||||
|             } | ||||
|             while(((2*gws[1])%lws[1] > 1) && (lws[1] & (lws[1] - 1)) != 0 && (lws[1] <= gws[1]) && (lws[1] <= 6));//divisible powOfTwo lessThanSix
 | ||||
|         } | ||||
|  | @ -156,12 +156,12 @@ std::pair<std::vector<uint32_t>,  uint32_t> ConvBufCommonExecution::gws2dLwsTune | |||
|                     } | ||||
|                 } | ||||
|                 do { | ||||
|                     lws[0]++; | ||||
|                     lws[0]<<=1; | ||||
|                 } | ||||
|                 while(((2*gws[0])%lws[0] > 1) && (lws[0] & (lws[0] - 1)) != 0 && (lws[0] <= gws[0]) && (lws[0] <= 6));//divisible powOfTwo lessThanSix
 | ||||
|             } | ||||
|             do { | ||||
|                 lws[1]++; | ||||
|                 lws[1]<<=1; | ||||
|             } | ||||
|             while(((2*gws[1])%lws[1] > 1) && (lws[1] & (lws[1] - 1)) != 0 && (lws[1] <= gws[1]) && (lws[1] <= 6));//divisible powOfTwo lessThanSix
 | ||||
|         } | ||||
|  | @ -476,7 +476,14 @@ ErrorCode ConvBufExecution::onResize(const std::vector<Tensor *> &inputs, const | |||
|         std::vector<uint32_t> localWorkSize[total_kernel]; | ||||
|         std::pair<int, int> min_cost(INT_MAX, 0);//(min_time, min_index)
 | ||||
|         for(int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) { | ||||
|             kernel[knl_idx]        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_buf", kernelName[knl_idx], mBuildOptions); | ||||
|             std::set<std::string> buildOption = mBuildOptions; | ||||
|             if(outputShape.at(3) % itemC[knl_idx] != 0){ | ||||
|                 buildOption.emplace("-DCHANNEL_LEAVE"); | ||||
|             } | ||||
|             if((outputShape.at(2) % itemW[knl_idx]) != 0){ | ||||
|                 buildOption.emplace("-DBLOCK_LEAVE"); | ||||
|             } | ||||
|             kernel[knl_idx]        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_buf", kernelName[knl_idx], buildOption); | ||||
|             uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx])); | ||||
|              | ||||
|             uint32_t idx            = 0; | ||||
|  | @ -515,7 +522,14 @@ ErrorCode ConvBufExecution::onResize(const std::vector<Tensor *> &inputs, const | |||
|         } | ||||
|         mGlobalWorkSize = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]}; | ||||
|          | ||||
|         mKernel        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_buf", kernelName[min_index], mBuildOptions); | ||||
|         std::set<std::string> buildOption = mBuildOptions; | ||||
|         if(outputShape.at(3) % itemC[min_index] != 0){ | ||||
|             buildOption.emplace("-DCHANNEL_LEAVE"); | ||||
|         } | ||||
|         if((outputShape.at(2) % itemW[min_index]) != 0){ | ||||
|             buildOption.emplace("-DBLOCK_LEAVE"); | ||||
|         } | ||||
|         mKernel        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_buf", kernelName[min_index], buildOption); | ||||
|         uint32_t idx = 0; | ||||
|         cl_int ret = CL_SUCCESS; | ||||
| 
 | ||||
|  | @ -542,48 +556,29 @@ ErrorCode ConvBufExecution::onResize(const std::vector<Tensor *> &inputs, const | |||
|         int dilationShape[2]    = {mDilations[0], mDilations[1]}; | ||||
|          | ||||
|         // {"conv_2d_c4h1w2", "conv_2d_c4h1w1", "conv_2d_c8h1w1", "conv_2d_c4h1w4", "conv_2d_c8h2w1", "conv_2d_c4h4w1"};
 | ||||
|         const int total_kernel = 6; | ||||
|         std::string kernelName[total_kernel] = {"conv_2d_c4h1w1", "conv_2d_c4h1w2", "conv_2d_c4h4w1", "conv_2d_c8h2w1", "conv_2d_c8h4w1", "conv_2d_c4h1w4"}; | ||||
|         int itemC[total_kernel] = {4, 4, 4, 8, 8, 4}; | ||||
|         int itemH[total_kernel] = {1, 1, 4, 2, 4, 1}; | ||||
|         int itemW[total_kernel] = {1, 2, 1, 1, 1, 4}; | ||||
|         const int total_kernel = 7; | ||||
|         std::string kernelName[total_kernel] = {"conv_2d_c4h1w1", "conv_2d_c4h1w2", "conv_2d_c4h4w1", "conv_2d_c8h2w1", "conv_2d_c8h4w1", "conv_2d_c4h1w4", "conv_2d_c8h1w4"}; | ||||
|         int itemC[total_kernel] = {4, 4, 4, 8, 8, 4, 8}; | ||||
|         int itemH[total_kernel] = {1, 1, 4, 2, 4, 1, 1}; | ||||
|         int itemW[total_kernel] = {1, 2, 1, 1, 1, 4, 4}; | ||||
|          | ||||
|          | ||||
|         int actual_kernel = total_kernel; | ||||
|         if(mOpenCLBackend->getOpenCLRuntime()->getCLTuneLevel() == Normal) { | ||||
|             actual_kernel = 2; | ||||
|         } else if(mOpenCLBackend->getOpenCLRuntime()->getCLTuneLevel() == Fast || mOpenCLBackend->getOpenCLRuntime()->getCLTuneLevel() == None) { | ||||
|             actual_kernel = 1; | ||||
|         }else if(mOpenCLBackend->getOpenCLRuntime()->getCLTuneLevel() == Wide){ | ||||
|             actual_kernel = 4; | ||||
|             auto gpuType = mOpenCLBackend->getOpenCLRuntime()->getGpuType(); | ||||
|             auto maliArType = mOpenCLBackend->getOpenCLRuntime()->getMaliAr(); | ||||
|             if(gpuType == MNN::MALI && maliArType == MNN::VALHALL){ | ||||
|                 if(outputShape.at(3) <= 8){ | ||||
|                     kernelName[3] = "conv_2d_c4h1w4"; | ||||
|                     itemC[3]      = 4; | ||||
|                     itemH[3]      = 1; | ||||
|                     itemW[3]      = 4; | ||||
|                 }else{ | ||||
|                     kernelName[2] = "conv_2d_c8h2w1"; | ||||
|                     itemC[2]      = 8; | ||||
|                     itemH[2]      = 2; | ||||
|                     itemW[2]      = 1; | ||||
|          | ||||
|                     kernelName[3] = "conv_2d_c8h4w1"; | ||||
|                     itemC[3]      = 8; | ||||
|                     itemH[3]      = 4; | ||||
|                     itemW[3]      = 1; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|          | ||||
|         cl::Kernel kernel[total_kernel]; | ||||
|         std::vector<uint32_t> globalWorkSize[total_kernel]; | ||||
|         std::vector<uint32_t> localWorkSize[total_kernel]; | ||||
|         std::pair<int, int> min_cost(INT_MAX, 0);//(min_time, min_index)
 | ||||
|         for(int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) { | ||||
|             kernel[knl_idx]        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_buf", kernelName[knl_idx], mBuildOptions); | ||||
|             std::set<std::string> buildOption = mBuildOptions; | ||||
|             if(outputShape.at(3) % itemC[knl_idx] != 0){ | ||||
|                 buildOption.emplace("-DCHANNEL_LEAVE"); | ||||
|             } | ||||
|             if((outputShape.at(2) % itemW[knl_idx]) != 0 || (outputShape.at(1) % itemH[knl_idx]) != 0){ | ||||
|                 buildOption.emplace("-DBLOCK_LEAVE"); | ||||
|             } | ||||
|             kernel[knl_idx]        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_buf", kernelName[knl_idx], buildOption); | ||||
|             uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx])); | ||||
|              | ||||
|             globalWorkSize[knl_idx] = {static_cast<uint32_t>(UP_DIV(outputShape.at(3), itemC[knl_idx]) * UP_DIV(outputShape.at(2), itemW[knl_idx])), static_cast<uint32_t>(outputShape.at(0) * UP_DIV(outputShape.at(1), itemH[knl_idx]))}; | ||||
|  | @ -620,7 +615,14 @@ ErrorCode ConvBufExecution::onResize(const std::vector<Tensor *> &inputs, const | |||
|         int min_index  = min_cost.second; | ||||
|         mGlobalWorkSize = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]}; | ||||
|          | ||||
|         mKernel        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_buf", kernelName[min_index], mBuildOptions); | ||||
|         std::set<std::string> buildOption = mBuildOptions; | ||||
|         if(outputShape.at(3) % itemC[min_index] != 0){ | ||||
|             buildOption.emplace("-DCHANNEL_LEAVE"); | ||||
|         } | ||||
|         if((outputShape.at(2) % itemW[min_index]) != 0 || (outputShape.at(1) % itemH[min_index]) != 0){ | ||||
|             buildOption.emplace("-DBLOCK_LEAVE"); | ||||
|         } | ||||
|         mKernel        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_buf", kernelName[min_index], buildOption); | ||||
|          | ||||
|         uint32_t idx            = 0; | ||||
|         cl_int ret = CL_SUCCESS; | ||||
|  |  | |||
|  | @ -38,7 +38,7 @@ bool ConvBufWinograd::valid(const Convolution2DCommon* common, const Tensor* inp | |||
|     if(input->channel() < 32 || input->channel() > input_channel_limit){ | ||||
|         return false; | ||||
|     } | ||||
|     return (input->width() <= 16 && input->height() <= 16); | ||||
|     return (input->width() <= 32 && input->height() <= 32); | ||||
| } | ||||
| 
 | ||||
| ConvBufWinograd::ConvBufWinograd(const MNN::Convolution2D* op, Backend* backend) : Execution(backend) { | ||||
|  |  | |||
|  | @ -147,6 +147,7 @@ ErrorCode RasterBufExecution::onResize(const std::vector<Tensor *> &____inputs, | |||
|         return NO_ERROR; | ||||
|     } | ||||
| 
 | ||||
|     bool cancombine = CanCombine(outputs); | ||||
|     // Alloc Temp buffer
 | ||||
|     auto bufferPool     = ((OpenCLBackend *)backend())->getBufferPool(); | ||||
|     auto bufferUnitSize = runtime->isSupportedFP16() ? sizeof(half_float::half) : sizeof(float); | ||||
|  | @ -170,6 +171,9 @@ ErrorCode RasterBufExecution::onResize(const std::vector<Tensor *> &____inputs, | |||
|     bufferPool->recycle(mTempOutput); | ||||
|      | ||||
|     auto originNum = mTempInput.size(); | ||||
|     if(cancombine){ | ||||
|         regionNum = 1; | ||||
|     } | ||||
|     mUnits.resize(regionNum + originNum + 1); | ||||
|      | ||||
|     int kernel_idx = 0; | ||||
|  | @ -241,14 +245,23 @@ ErrorCode RasterBufExecution::onResize(const std::vector<Tensor *> &____inputs, | |||
|     } | ||||
|      | ||||
|     // buffer raster
 | ||||
|     for (auto& slice : des->regions) | ||||
|     { | ||||
|         Unit &unit          = mUnits[kernel_idx++]; | ||||
|         unit.kernel         = runtime->buildKernel("raster_buf", "raster_buffer", {}); | ||||
|     if(cancombine){ | ||||
|         auto regions = des->regions; | ||||
|         auto slice = regions[0]; | ||||
|         int nums = regions.size(); | ||||
|         int src_offset = regions[1].src.offset - slice.src.offset; | ||||
|         int dst_offset = regions[1].dst.offset - slice.dst.offset; | ||||
|          | ||||
|         const std::vector<uint32_t> gws =  {(uint32_t)slice.size[2], | ||||
|                                                 (uint32_t)slice.size[1], | ||||
|                                                 (uint32_t)slice.size[0]}; | ||||
|         Unit &unit          = mUnits[kernel_idx++]; | ||||
|         unit.kernel         = runtime->buildKernel("raster", "raster_buffer_combine", {}); | ||||
|          | ||||
|         unit.globalWorkSize = {(uint32_t)slice.size[2] * nums, | ||||
|             (uint32_t)slice.size[1], | ||||
|             (uint32_t)slice.size[0]}; | ||||
|          | ||||
|         const std::vector<uint32_t> gws =  {(uint32_t)slice.size[2] * nums, | ||||
|             (uint32_t)slice.size[1], | ||||
|             (uint32_t)slice.size[0]}; | ||||
|         uint32_t mMaxWorkGroupSize      = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(unit.kernel)); | ||||
|          | ||||
|         uint32_t idx   = 0; | ||||
|  | @ -258,27 +271,71 @@ ErrorCode RasterBufExecution::onResize(const std::vector<Tensor *> &____inputs, | |||
|         ret |= unit.kernel.setArg(idx++, gws[2]); | ||||
|         ret |= unit.kernel.setArg(idx++, *(mTempInput[slice.origin])); | ||||
|         ret |= unit.kernel.setArg(idx++, slice.src.offset); | ||||
|         ret |= unit.kernel.setArg(idx++, src_offset); | ||||
|         ret |= unit.kernel.setArg(idx++, slice.src.stride[0]); | ||||
|         ret |= unit.kernel.setArg(idx++, slice.src.stride[1]); | ||||
|         ret |= unit.kernel.setArg(idx++, slice.src.stride[2]); | ||||
|         ret |= unit.kernel.setArg(idx++, *mTempOutput); | ||||
|         ret |= unit.kernel.setArg(idx++, slice.dst.offset); | ||||
|         ret |= unit.kernel.setArg(idx++, dst_offset); | ||||
|         ret |= unit.kernel.setArg(idx++, slice.dst.stride[0]); | ||||
|         ret |= unit.kernel.setArg(idx++, slice.dst.stride[1]); | ||||
|         ret |= unit.kernel.setArg(idx++, slice.dst.stride[2]); | ||||
|         ret |= unit.kernel.setArg(idx++, slice.size[2]); | ||||
|         if(ret != CL_SUCCESS) | ||||
|         { | ||||
|             MNN_PRINT("setArg err %d\n", (int)ret); | ||||
|         } | ||||
|          | ||||
|         std::string name = "raster_buffer"; | ||||
|         std::string name = "rasterBuffer"; | ||||
|         const std::vector<uint32_t> lws = localWS3DDefault(gws, mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), name, unit.kernel).first; | ||||
|          | ||||
|         unit.localWorkSize = {lws[0], lws[1], lws[2]}; | ||||
|          | ||||
|         unit.globalWorkSize = {ROUND_UP(gws[0], std::max((uint32_t)1, lws[0])), | ||||
|                                ROUND_UP(gws[1], std::max((uint32_t)1, lws[1])), | ||||
|                                ROUND_UP(gws[2], std::max((uint32_t)1, lws[2]))}; | ||||
|             ROUND_UP(gws[1], std::max((uint32_t)1, lws[1])), | ||||
|             ROUND_UP(gws[2], std::max((uint32_t)1, lws[2]))}; | ||||
|         recordKernel3d(unit.kernel, gws, lws, runtime); | ||||
|     }else{ | ||||
|         for (auto& slice : des->regions) | ||||
|         { | ||||
|             Unit &unit          = mUnits[kernel_idx++]; | ||||
|             unit.kernel         = runtime->buildKernel("raster_buf", "raster_buffer", {}); | ||||
|              | ||||
|             const std::vector<uint32_t> gws =  {(uint32_t)slice.size[2], | ||||
|                 (uint32_t)slice.size[1], | ||||
|                 (uint32_t)slice.size[0]}; | ||||
|             uint32_t mMaxWorkGroupSize      = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(unit.kernel)); | ||||
|              | ||||
|             uint32_t idx   = 0; | ||||
|             cl_int ret = CL_SUCCESS; | ||||
|             ret |= unit.kernel.setArg(idx++, gws[0]); | ||||
|             ret |= unit.kernel.setArg(idx++, gws[1]); | ||||
|             ret |= unit.kernel.setArg(idx++, gws[2]); | ||||
|             ret |= unit.kernel.setArg(idx++, *(mTempInput[slice.origin])); | ||||
|             ret |= unit.kernel.setArg(idx++, slice.src.offset); | ||||
|             ret |= unit.kernel.setArg(idx++, slice.src.stride[0]); | ||||
|             ret |= unit.kernel.setArg(idx++, slice.src.stride[1]); | ||||
|             ret |= unit.kernel.setArg(idx++, slice.src.stride[2]); | ||||
|             ret |= unit.kernel.setArg(idx++, *mTempOutput); | ||||
|             ret |= unit.kernel.setArg(idx++, slice.dst.offset); | ||||
|             ret |= unit.kernel.setArg(idx++, slice.dst.stride[0]); | ||||
|             ret |= unit.kernel.setArg(idx++, slice.dst.stride[1]); | ||||
|             ret |= unit.kernel.setArg(idx++, slice.dst.stride[2]); | ||||
|             if(ret != CL_SUCCESS) | ||||
|             { | ||||
|                 MNN_PRINT("setArg err %d\n", (int)ret); | ||||
|             } | ||||
|              | ||||
|             std::string name = "raster_buffer"; | ||||
|             const std::vector<uint32_t> lws = localWS3DDefault(gws, mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), name, unit.kernel).first; | ||||
|              | ||||
|             unit.localWorkSize = {lws[0], lws[1], lws[2]}; | ||||
|              | ||||
|             unit.globalWorkSize = {ROUND_UP(gws[0], std::max((uint32_t)1, lws[0])), | ||||
|                 ROUND_UP(gws[1], std::max((uint32_t)1, lws[1])), | ||||
|                 ROUND_UP(gws[2], std::max((uint32_t)1, lws[2]))}; | ||||
|         } | ||||
|     } | ||||
|      | ||||
|     //buffer to nc4hw4 buffer
 | ||||
|  | @ -349,6 +406,44 @@ public: | |||
|     } | ||||
| }; | ||||
| 
 | ||||
| bool RasterBufExecution::CanCombine(const std::vector<Tensor *> &outputs){ | ||||
|     auto des = TensorUtils::getDescribe(outputs[0]); | ||||
|     auto regions = des->regions; | ||||
|     if(regions.size() < 2) | ||||
|         return false; | ||||
|     auto origin = regions[0].origin; | ||||
|     const int size0 = regions[0].size[0]; | ||||
|     const int size1 = regions[0].size[1]; | ||||
|     const int size2 = regions[0].size[2]; | ||||
|     const int src_offset = regions[1].src.offset - regions[0].src.offset; | ||||
|     const int dst_offset = regions[1].dst.offset - regions[0].dst.offset; | ||||
|     const int src_sride0 = regions[0].src.stride[0]; | ||||
|     const int src_sride1 = regions[0].src.stride[1]; | ||||
|     const int src_sride2 = regions[0].src.stride[2]; | ||||
|     const int dst_sride0 = regions[0].dst.stride[0]; | ||||
|     const int dst_sride1 = regions[0].dst.stride[1]; | ||||
|     const int dst_sride2 = regions[0].dst.stride[2]; | ||||
|     bool res = true; | ||||
|     for(int i = 1; i < regions.size(); ++i){ | ||||
|         res &= regions[i].origin == origin; | ||||
|         res &= regions[i].size[0] == size0; | ||||
|         res &= regions[i].size[1] == size1; | ||||
|         res &= regions[i].size[2] == size2; | ||||
|         res &= regions[i].src.stride[0] == src_sride0; | ||||
|         res &= regions[i].src.stride[1] == src_sride1; | ||||
|         res &= regions[i].src.stride[2] == src_sride2; | ||||
|         res &= regions[i].dst.stride[0] == dst_sride0; | ||||
|         res &= regions[i].dst.stride[1] == dst_sride1; | ||||
|         res &= regions[i].dst.stride[2] == dst_sride2; | ||||
|         res &= (regions[i].src.offset - regions[i - 1].src.offset) == src_offset; | ||||
|         res &= (regions[i].dst.offset - regions[i - 1].dst.offset) == dst_offset; | ||||
|         if(res == false){ | ||||
|             return res; | ||||
|         } | ||||
|     } | ||||
|     return res; | ||||
| } | ||||
| 
 | ||||
| OpenCLCreatorRegister<RasterCreator> __RasterBuf_op(OpType_Raster, BUFFER); | ||||
| } // namespace OpenCL
 | ||||
| } // namespace MNN
 | ||||
|  |  | |||
|  | @ -28,6 +28,7 @@ public: | |||
|     virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; | ||||
| 
 | ||||
| private: | ||||
|     bool CanCombine(const std::vector<Tensor *> &outputs); | ||||
|     std::map<Tensor*, cl::Buffer *> mTempInput; | ||||
|     cl::Buffer *mTempOutput; | ||||
|     OpenCLBackend *mOpenCLBackend; | ||||
|  |  | |||
|  | @ -285,13 +285,19 @@ __kernel | |||
| #if SET_ATTRIBUTE | ||||
| __attribute__((work_group_size_hint(16, 16, 1))) | ||||
| #endif | ||||
| void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only image2d_t weights, | ||||
| void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, | ||||
| #ifdef USE_BUFFER | ||||
| __global const FLOAT *weights, | ||||
| #else | ||||
| __read_only image2d_t weights, | ||||
| #endif | ||||
|                           __read_only image2d_t bias, | ||||
|                           __write_only image2d_t output, | ||||
|                           __private const int2 input_shape, | ||||
|                           __private const int in_channel_block, __private const int2 output_shape, | ||||
|                           __private const int2 stride_shape, | ||||
|                           __private const int output_width_4) { | ||||
|                           __private const int output_width_4, | ||||
|                           __private const int out_channel_blocks) { | ||||
| 
 | ||||
|     const int output_channel_width_idx = get_global_id(0); | ||||
|     const int output_batch_height_idx  = get_global_id(1); | ||||
|  | @ -305,6 +311,12 @@ void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only ima | |||
|     FLOAT4 out2 = out0; | ||||
|     FLOAT4 out3 = out0; | ||||
| 
 | ||||
| #ifdef MNN_CONV_S1D1 | ||||
|     int intput_width_idx0 = output_width_block_idx << 2; | ||||
|     int intput_width_idx1 = intput_width_idx0 + 1; | ||||
|     int intput_width_idx2 = intput_width_idx0 + 2; | ||||
|     int intput_width_idx3 = intput_width_idx0 + 3; | ||||
| #else | ||||
|     int intput_width_idx0 = mul24(output_width_block_idx, stride_shape.y*4); | ||||
|     int intput_width_idx1 = intput_width_idx0 + stride_shape.y; | ||||
|     int intput_width_idx2 = intput_width_idx1 + stride_shape.y; | ||||
|  | @ -314,7 +326,7 @@ void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only ima | |||
|     intput_width_idx1 = select(intput_width_idx1, INT_MIN, intput_width_idx1 >= input_shape.y); | ||||
|     intput_width_idx2 = select(intput_width_idx2, INT_MIN, intput_width_idx2 >= input_shape.y); | ||||
|     intput_width_idx3 = select(intput_width_idx3, INT_MIN, intput_width_idx3 >= input_shape.y); | ||||
| 
 | ||||
| #endif | ||||
|     int batch_index            = output_batch_height_idx / output_shape.x; | ||||
|     int input_height_block_idx = mul24((output_batch_height_idx % output_shape.x), stride_shape.x) + batch_index * input_shape.x; | ||||
| 
 | ||||
|  | @ -326,19 +338,26 @@ void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only ima | |||
|     FLOAT4 weights1; | ||||
|     FLOAT4 weights2; | ||||
|     FLOAT4 weights3; | ||||
|     int weight_offset = output_channel_block_idx * in_channel_block * 4 * 4; | ||||
| 
 | ||||
|     for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block; ++in_channel_block_idx) { | ||||
|         int input_width_base  = in_channel_block_idx * input_shape.y; | ||||
|         int weights_width_base = in_channel_block_idx << 2; | ||||
|         in0 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx0, input_height_block_idx)); | ||||
|         in1 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx1, input_height_block_idx)); | ||||
|         in2 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx2, input_height_block_idx)); | ||||
|         in3 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx3, input_height_block_idx)); | ||||
| 
 | ||||
| #ifdef USE_BUFFER | ||||
|         weights0 = vload4(weights_width_base, weights + weight_offset); | ||||
|         weights1 = vload4(weights_width_base + 1, weights + weight_offset); | ||||
|         weights2 = vload4(weights_width_base + 2, weights + weight_offset); | ||||
|         weights3 = vload4(weights_width_base + 3, weights + weight_offset); | ||||
| #else | ||||
|         weights0 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 0, output_channel_block_idx)); | ||||
|         weights1 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 1, output_channel_block_idx)); | ||||
|         weights2 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 2, output_channel_block_idx)); | ||||
|         weights3 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 3, output_channel_block_idx)); | ||||
| #endif | ||||
|         in0 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx0, input_height_block_idx)); | ||||
|         in1 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx1, input_height_block_idx)); | ||||
|         in2 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx2, input_height_block_idx)); | ||||
|         in3 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx3, input_height_block_idx)); | ||||
| 
 | ||||
|         CALCULATE_OUTPUT(0); | ||||
|         CALCULATE_OUTPUT(1); | ||||
|  | @ -386,7 +405,187 @@ __kernel | |||
| #if SET_ATTRIBUTE | ||||
| __attribute__((work_group_size_hint(16, 16, 1))) | ||||
| #endif | ||||
| void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only image2d_t weights, | ||||
| void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, | ||||
| #ifdef USE_BUFFER | ||||
| __global const FLOAT *weights, | ||||
| #else | ||||
| __read_only image2d_t weights, | ||||
| #endif | ||||
|                           __read_only image2d_t bias, | ||||
|                           __write_only image2d_t output, | ||||
|                           __private const int2 input_shape, | ||||
|                           __private const int in_channel_block, __private const int2 output_shape, | ||||
|                           __private const int2 stride_shape, | ||||
|                           __private const int output_width_4, | ||||
|                           __private const int out_channel_blocks) { | ||||
| 
 | ||||
|     const int output_channel_width_idx = get_global_id(0); | ||||
|     const int output_batch_height_idx  = get_global_id(1); | ||||
|     DEAL_NON_UNIFORM_DIM2(output_channel_width_idx, output_batch_height_idx); | ||||
| 
 | ||||
|     const int output_channel_block_idx = output_channel_width_idx / output_width_4; | ||||
|     const int output_width_block_idx   = output_channel_width_idx % output_width_4; | ||||
|     const int output_channel_idx = output_channel_block_idx << 1; | ||||
| 
 | ||||
|     FLOAT4 out0 = RI_F(bias, SAMPLER, (int2)(output_channel_idx, 0)); | ||||
|     FLOAT4 out1 = out0; | ||||
|     FLOAT4 out2 = out0; | ||||
|     FLOAT4 out3 = out0; | ||||
|      | ||||
|     FLOAT4 out4 = RI_F(bias, SAMPLER, (int2)(output_channel_idx + 1, 0)); | ||||
|     FLOAT4 out5 = out4; | ||||
|     FLOAT4 out6 = out4; | ||||
|     FLOAT4 out7 = out4; | ||||
| 
 | ||||
| #ifdef MNN_CONV_S1D1 | ||||
|     int intput_width_idx0 = output_width_block_idx << 2; | ||||
|     int intput_width_idx1 = intput_width_idx0 + 1; | ||||
|     int intput_width_idx2 = intput_width_idx0 + 2; | ||||
|     int intput_width_idx3 = intput_width_idx0 + 3; | ||||
| #else | ||||
|     int intput_width_idx0 = mul24(output_width_block_idx, stride_shape.y*4); | ||||
|     int intput_width_idx1 = intput_width_idx0 + stride_shape.y; | ||||
|     int intput_width_idx2 = intput_width_idx1 + stride_shape.y; | ||||
|     int intput_width_idx3 = intput_width_idx2 + stride_shape.y; | ||||
| 
 | ||||
|     intput_width_idx0 = select(intput_width_idx0, INT_MIN, intput_width_idx0 >= input_shape.y); | ||||
|     intput_width_idx1 = select(intput_width_idx1, INT_MIN, intput_width_idx1 >= input_shape.y); | ||||
|     intput_width_idx2 = select(intput_width_idx2, INT_MIN, intput_width_idx2 >= input_shape.y); | ||||
|     intput_width_idx3 = select(intput_width_idx3, INT_MIN, intput_width_idx3 >= input_shape.y); | ||||
| #endif | ||||
| 
 | ||||
|     int batch_index            = output_batch_height_idx / output_shape.x; | ||||
|     int input_height_block_idx = mul24((output_batch_height_idx % output_shape.x), stride_shape.x) + batch_index * input_shape.x; | ||||
| 
 | ||||
|     FLOAT4 in0; | ||||
|     FLOAT4 in1; | ||||
|     FLOAT4 in2; | ||||
|     FLOAT4 in3; | ||||
|     FLOAT4 weights0; | ||||
|     FLOAT4 weights1; | ||||
|     FLOAT4 weights2; | ||||
|     FLOAT4 weights3; | ||||
|     FLOAT4 weights4; | ||||
|     FLOAT4 weights5; | ||||
|     FLOAT4 weights6; | ||||
|     FLOAT4 weights7; | ||||
|     int weight_offset = output_channel_idx * in_channel_block * 4 * 4; | ||||
|     int weight_offset1 = weight_offset + in_channel_block * 4 * 4; | ||||
| 
 | ||||
|     for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block; ++in_channel_block_idx) { | ||||
|         int input_width_base  = in_channel_block_idx * input_shape.y; | ||||
|         int weights_width_base = in_channel_block_idx << 2; | ||||
|         in0 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx0, input_height_block_idx)); | ||||
|         in1 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx1, input_height_block_idx)); | ||||
|         in2 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx2, input_height_block_idx)); | ||||
|         in3 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx3, input_height_block_idx)); | ||||
| 
 | ||||
| #ifdef USE_BUFFER | ||||
|         weights0 = vload4(weights_width_base, weights + weight_offset); | ||||
|         weights1 = vload4(weights_width_base + 1, weights + weight_offset); | ||||
|         weights2 = vload4(weights_width_base + 2, weights + weight_offset); | ||||
|         weights3 = vload4(weights_width_base + 3, weights + weight_offset); | ||||
| 
 | ||||
|         weights4 = vload4(weights_width_base, weights + weight_offset1); | ||||
|         weights5 = vload4(weights_width_base + 1, weights + weight_offset1); | ||||
|         weights6 = vload4(weights_width_base + 2, weights + weight_offset1); | ||||
|         weights7 = vload4(weights_width_base + 3, weights + weight_offset1); | ||||
| #else | ||||
|         weights0 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 0, output_channel_idx)); | ||||
|         weights1 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 1, output_channel_idx)); | ||||
|         weights2 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 2, output_channel_idx)); | ||||
|         weights3 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 3, output_channel_idx)); | ||||
|          | ||||
|         weights4 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 0, output_channel_idx + 1)); | ||||
|         weights5 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 1, output_channel_idx + 1)); | ||||
|         weights6 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 2, output_channel_idx + 1)); | ||||
|         weights7 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 3, output_channel_idx + 1)); | ||||
| #endif | ||||
| 
 | ||||
|         CALCULATE_OUTPUT(0); | ||||
|         CALCULATE_OUTPUT(1); | ||||
|         CALCULATE_OUTPUT(2); | ||||
|         CALCULATE_OUTPUT(3); | ||||
|          | ||||
|         CALCULATE_OUTPUT_WEIGHTS4(4, 0); | ||||
|         CALCULATE_OUTPUT_WEIGHTS4(5, 1); | ||||
|         CALCULATE_OUTPUT_WEIGHTS4(6, 2); | ||||
|         CALCULATE_OUTPUT_WEIGHTS4(7, 3); | ||||
|     } | ||||
| 
 | ||||
| #ifdef RELU | ||||
|     out0 = fmax(out0, (FLOAT4)0); | ||||
|     out1 = fmax(out1, (FLOAT4)0); | ||||
|     out2 = fmax(out2, (FLOAT4)0); | ||||
|     out3 = fmax(out3, (FLOAT4)0); | ||||
|     out4 = fmax(out4, (FLOAT4)0); | ||||
|     out5 = fmax(out5, (FLOAT4)0); | ||||
|     out6 = fmax(out6, (FLOAT4)0); | ||||
|     out7 = fmax(out7, (FLOAT4)0); | ||||
| #endif | ||||
| 
 | ||||
| #ifdef RELU6 | ||||
|     out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6); | ||||
|     out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6); | ||||
|     out2 = clamp(out2, (FLOAT4)0, (FLOAT4)6); | ||||
|     out3 = clamp(out3, (FLOAT4)0, (FLOAT4)6); | ||||
|     out4 = clamp(out4, (FLOAT4)0, (FLOAT4)6); | ||||
|     out5 = clamp(out5, (FLOAT4)0, (FLOAT4)6); | ||||
|     out6 = clamp(out6, (FLOAT4)0, (FLOAT4)6); | ||||
|     out7 = clamp(out7, (FLOAT4)0, (FLOAT4)6); | ||||
| #endif | ||||
| 
 | ||||
|     const int out_x_base = mul24(output_channel_idx, output_shape.y); | ||||
|     int out_x_idx        = output_width_block_idx << 2; | ||||
| 
 | ||||
|     const int remain = output_shape.y - out_x_idx; | ||||
|     int output_idx   = out_x_base + out_x_idx; | ||||
|     if (remain >= 4) { | ||||
|         WI_F(output, (int2)(output_idx, output_batch_height_idx), out0); | ||||
|         WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out1); | ||||
|         WI_F(output, (int2)(output_idx + 2, output_batch_height_idx), out2); | ||||
|         WI_F(output, (int2)(output_idx + 3, output_batch_height_idx), out3); | ||||
|     } else if (remain == 3) { | ||||
|         WI_F(output, (int2)(output_idx, output_batch_height_idx), out0); | ||||
|         WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out1); | ||||
|         WI_F(output, (int2)(output_idx + 2, output_batch_height_idx), out2); | ||||
|     } else if (remain == 2) { | ||||
|         WI_F(output, (int2)(output_idx, output_batch_height_idx), out0); | ||||
|         WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out1); | ||||
|     } else if (remain == 1) { | ||||
|         WI_F(output, (int2)(output_idx, output_batch_height_idx), out0); | ||||
|     } | ||||
|      | ||||
|     if(output_channel_idx + 1 >= out_channel_blocks) | ||||
|         return; | ||||
|     output_idx += output_shape.y; | ||||
|     if (remain >= 4) { | ||||
|         WI_F(output, (int2)(output_idx, output_batch_height_idx), out4); | ||||
|         WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out5); | ||||
|         WI_F(output, (int2)(output_idx + 2, output_batch_height_idx), out6); | ||||
|         WI_F(output, (int2)(output_idx + 3, output_batch_height_idx), out7); | ||||
|     } else if (remain == 3) { | ||||
|         WI_F(output, (int2)(output_idx, output_batch_height_idx), out4); | ||||
|         WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out5); | ||||
|         WI_F(output, (int2)(output_idx + 2, output_batch_height_idx), out6); | ||||
|     } else if (remain == 2) { | ||||
|         WI_F(output, (int2)(output_idx, output_batch_height_idx), out4); | ||||
|         WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out5); | ||||
|     } else if (remain == 1) { | ||||
|         WI_F(output, (int2)(output_idx, output_batch_height_idx), out4); | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| __kernel | ||||
| #if SET_ATTRIBUTE | ||||
| __attribute__((work_group_size_hint(16, 16, 1))) | ||||
| #endif | ||||
| void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, | ||||
| #ifdef USE_BUFFER | ||||
| __global const FLOAT *weights, | ||||
| #else | ||||
| __read_only image2d_t weights, | ||||
| #endif | ||||
| #ifdef BIAS | ||||
|                       __read_only image2d_t bias, | ||||
| #endif | ||||
|  | @ -425,26 +624,36 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only | |||
|      | ||||
| #ifdef MNN_CONV_S1D1 | ||||
|     const int height_start = mad24((output_batch_height_idx % output_shape.x), 1, -padding_shape.x); | ||||
|     int in_height_start    = select(0, (-height_start), height_start < 0) + height_start; | ||||
|     const int kh_start = select(0, (-height_start), height_start < 0); | ||||
|     int in_height_start    = kh_start + height_start; | ||||
|     int in_height_end      = min(weights_shape.x + height_start, input_shape.x); | ||||
| 
 | ||||
|     const int batch_idx          = mul24((output_batch_height_idx / output_shape.x), input_shape.x); | ||||
|     const int weights_h_idx = mul24(out_channel_block_idx, mul24(weights_shape.y, weights_shape.x)) + mul24(select(0, (-height_start), height_start < 0), weights_shape.y); | ||||
| #else | ||||
|     const int height_start = mad24((output_batch_height_idx % output_shape.x), stride_shape.x, -padding_shape.x); | ||||
|     int in_height_start    = mad24(select(0, (-height_start + dilation_shape.x - 1) / dilation_shape.x, height_start < 0), dilation_shape.x, height_start); | ||||
|     const int kh_start = select(0, (-height_start + dilation_shape.x - 1) / dilation_shape.x, height_start < 0); | ||||
|     int in_height_start    = mad24(kh_start, dilation_shape.x, height_start); | ||||
|     int in_height_end      = min(mad24(weights_shape.x, dilation_shape.x, height_start), input_shape.x); | ||||
| 
 | ||||
|     const int batch_idx          = mul24((output_batch_height_idx / output_shape.x), input_shape.x); | ||||
|     const int weights_h_idx = mul24(out_channel_block_idx, mul24(weights_shape.y, weights_shape.x)) + mul24(select(0, (-height_start + dilation_shape.x - 1) / dilation_shape.x, height_start < 0), weights_shape.y); | ||||
| #endif | ||||
| 
 | ||||
| #ifdef USE_BUFFER | ||||
|     const int weight_oc_offset = out_channel_blocks * weights_shape.x * weights_shape.y * 4; | ||||
| #endif | ||||
| 
 | ||||
|     FLOAT4 in0, in1, in2, in3; | ||||
|     FLOAT4 weights0, weights1, weights2, weights3; | ||||
|     for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block_length; ++in_channel_block_idx) { | ||||
|         const int in_idx = mul24(in_channel_block_idx, input_shape.y); | ||||
| #ifdef USE_BUFFER | ||||
|         int weight_offset = ((((4*in_channel_block_idx+0)* out_channel_blocks + out_channel_block_idx) *weights_shape.x + kh_start)*weights_shape.y + 0) * 4; | ||||
| #else | ||||
|         int weights_x_idx = in_channel_block_idx << 2; | ||||
|         int weights_y_idx = weights_h_idx; | ||||
| #endif | ||||
|         for (int iy = in_height_start; iy < in_height_end; iy += dilation_shape.x) { | ||||
|             int in_hb_value = iy + batch_idx; | ||||
| #ifdef MNN_CONV_S1D1 | ||||
|  | @ -453,12 +662,18 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only | |||
|                 READ_INPUT_IMAGE(1, 0); | ||||
|                 READ_INPUT_IMAGE(2, 0); | ||||
|                 READ_INPUT_IMAGE(3, 0); | ||||
| 
 | ||||
| #ifdef USE_BUFFER | ||||
|                 weights0 = vload4(0, weights+weight_offset); | ||||
|                 weights1 = vload4(0, weights+weight_offset+weight_oc_offset); | ||||
|                 weights2 = vload4(0, weights+weight_offset+weight_oc_offset*2); | ||||
|                 weights3 = vload4(0, weights+weight_offset+weight_oc_offset*3); | ||||
|                 weight_offset += 4; | ||||
| #else | ||||
|                 weights0 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 0, weights_y_idx)); | ||||
|                 weights1 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 1, weights_y_idx)); | ||||
|                 weights2 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weights_y_idx)); | ||||
|                 weights3 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 3, weights_y_idx++)); | ||||
| 
 | ||||
| #endif | ||||
|                 CALCULATE_OUTPUT(0); | ||||
|                 CALCULATE_OUTPUT(1); | ||||
|                 CALCULATE_OUTPUT(2); | ||||
|  | @ -469,12 +684,18 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only | |||
|                 in1 = in2; | ||||
|                 in2 = in3; | ||||
|                 READ_INPUT_IMAGE(3, w); | ||||
| 
 | ||||
| #ifdef USE_BUFFER | ||||
|                 weights0 = vload4(0, weights+weight_offset); | ||||
|                 weights1 = vload4(0, weights+weight_offset+weight_oc_offset); | ||||
|                 weights2 = vload4(0, weights+weight_offset+weight_oc_offset*2); | ||||
|                 weights3 = vload4(0, weights+weight_offset+weight_oc_offset*3); | ||||
|                 weight_offset += 4; | ||||
| #else | ||||
|                 weights0 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 0, weights_y_idx)); | ||||
|                 weights1 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 1, weights_y_idx)); | ||||
|                 weights2 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weights_y_idx)); | ||||
|                 weights3 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 3, weights_y_idx++)); | ||||
| 
 | ||||
| #endif | ||||
|                 CALCULATE_OUTPUT(0); | ||||
|                 CALCULATE_OUTPUT(1); | ||||
|                 CALCULATE_OUTPUT(2); | ||||
|  | @ -487,12 +708,18 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only | |||
|                 READ_INPUT_IMAGE(1, input_width_base); | ||||
|                 READ_INPUT_IMAGE(2, input_width_base); | ||||
|                 READ_INPUT_IMAGE(3, input_width_base); | ||||
| 
 | ||||
| #ifdef USE_BUFFER | ||||
|                 weights0 = vload4(0, weights+weight_offset); | ||||
|                 weights1 = vload4(0, weights+weight_offset+weight_oc_offset); | ||||
|                 weights2 = vload4(0, weights+weight_offset+weight_oc_offset*2); | ||||
|                 weights3 = vload4(0, weights+weight_offset+weight_oc_offset*3); | ||||
|                 weight_offset += 4; | ||||
| #else | ||||
|                 weights0 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 0, weights_y_idx));  | ||||
|                 weights1 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 1, weights_y_idx));  | ||||
|                 weights2 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weights_y_idx));  | ||||
|                 weights3 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 3, weights_y_idx++)); | ||||
| 
 | ||||
| #endif | ||||
|                 CALCULATE_OUTPUT(0); | ||||
|                 CALCULATE_OUTPUT(1); | ||||
|                 CALCULATE_OUTPUT(2); | ||||
|  | @ -542,7 +769,12 @@ __kernel | |||
| #if SET_ATTRIBUTE | ||||
| __attribute__((work_group_size_hint(16, 16, 1))) | ||||
| #endif | ||||
| void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only image2d_t weights, | ||||
| void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, | ||||
| #ifdef USE_BUFFER | ||||
| __global const FLOAT *weights, | ||||
| #else | ||||
| __read_only image2d_t weights, | ||||
| #endif | ||||
| #ifdef BIAS | ||||
|                       __read_only image2d_t bias, | ||||
| #endif | ||||
|  | @ -581,6 +813,11 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only | |||
|     FLOAT4 out6 = out4; | ||||
|     FLOAT4 out7 = out4; | ||||
| 
 | ||||
| #ifdef USE_BUFFER | ||||
|     const int weight_oc_offset = weights_shape.x * weights_shape.y * 4; | ||||
|     const int weight_ic_offset = out_channel_blocks * weight_oc_offset; | ||||
| #endif | ||||
| 
 | ||||
|     int in_width0          = mad24(out_width_block_idx, stride_shape.y, -padding_shape.y); | ||||
|     int in_height0         = mad24(out_height_block_idx, stride_shape.x<<2, -padding_shape.x); | ||||
|     int in_height1         = in_height0 + stride_shape.x; | ||||
|  | @ -595,8 +832,12 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only | |||
|     FLOAT4 weights0, weights1, weights2, weights3, weights4, weights5, weights6, weights7; | ||||
|     for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block_length; ++in_channel_block_idx) { | ||||
|         const int in_idx = mul24(in_channel_block_idx, input_shape.y); | ||||
| #ifdef USE_BUFFER | ||||
|         int weight_offset = ((((4*in_channel_block_idx+0)* out_channel_blocks + out_channel_block_idx) *weights_shape.x + 0)*weights_shape.y + 0) * 4; | ||||
| #else | ||||
|         int weights_x_idx = in_channel_block_idx << 2; | ||||
|         int weights_y_idx = weights_h_idx; | ||||
| #endif | ||||
|         for (int iy = 0; iy < weights_shape.x * dilation_shape.x; iy += dilation_shape.x) { | ||||
|             int h0 =  select(in_height0 + iy + batch_idx, -1, (in_height0 + iy < 0 || in_height0 + iy  >= input_shape.x)); | ||||
|             int h1 =  select(in_height1 + iy + batch_idx, -1, (in_height1 + iy < 0 || in_height1 + iy  >= input_shape.x)); | ||||
|  | @ -610,6 +851,17 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only | |||
|                 in2 = RI_F(input, SAMPLER, (int2)(w0, h2)); | ||||
|                 in3 = RI_F(input, SAMPLER, (int2)(w0, h3)); | ||||
| 
 | ||||
| #ifdef USE_BUFFER | ||||
|                 weights0 = vload4(0, weights+weight_offset); | ||||
|                 weights1 = vload4(0, weights+weight_offset+weight_ic_offset); | ||||
|                 weights2 = vload4(0, weights+weight_offset+weight_ic_offset*2); | ||||
|                 weights3 = vload4(0, weights+weight_offset+weight_ic_offset*3); | ||||
|                 weights4 = vload4(0, weights+weight_offset + weight_oc_offset); | ||||
|                 weights5 = vload4(0, weights+weight_offset+weight_ic_offset + weight_oc_offset); | ||||
|                 weights6 = vload4(0, weights+weight_offset+weight_ic_offset*2 + weight_oc_offset); | ||||
|                 weights7 = vload4(0, weights+weight_offset+weight_ic_offset*3 + weight_oc_offset); | ||||
|                 weight_offset += 4; | ||||
| #else | ||||
|                 weights0 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 0, weights_y_idx)); | ||||
|                 weights1 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 1, weights_y_idx)); | ||||
|                 weights2 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weights_y_idx)); | ||||
|  | @ -618,6 +870,7 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only | |||
|                 weights5 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 1, weight_size + weights_y_idx)); | ||||
|                 weights6 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weight_size + weights_y_idx)); | ||||
|                 weights7 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 3, weight_size + weights_y_idx++)); | ||||
| #endif | ||||
|                  | ||||
|                 CALCULATE_OUTPUT(0); | ||||
|                 CALCULATE_OUTPUT(1); | ||||
|  | @ -703,7 +956,12 @@ __kernel | |||
| #if SET_ATTRIBUTE | ||||
| __attribute__((work_group_size_hint(16, 16, 1))) | ||||
| #endif | ||||
| void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only image2d_t weights, | ||||
| void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, | ||||
| #ifdef USE_BUFFER | ||||
| __global const FLOAT *weights, | ||||
| #else | ||||
| __read_only image2d_t weights, | ||||
| #endif | ||||
| #ifdef BIAS | ||||
|                       __read_only image2d_t bias, | ||||
| #endif | ||||
|  | @ -749,10 +1007,17 @@ void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only | |||
|      | ||||
|     FLOAT4 in0, in1, in2, in3; | ||||
|     FLOAT4 weights0, weights1, weights2, weights3; | ||||
| #ifdef USE_BUFFER | ||||
|     const int weight_oc_offset = out_channel_blocks * weights_shape.x * weights_shape.y * 4; | ||||
| #endif | ||||
|     for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block_length; ++in_channel_block_idx) { | ||||
|         const int in_idx = mul24(in_channel_block_idx, input_shape.y); | ||||
| #ifdef USE_BUFFER | ||||
|         int weight_offset = ((((4*in_channel_block_idx+0)* out_channel_blocks + out_channel_block_idx) *weights_shape.x + 0)*weights_shape.y + 0) * 4; | ||||
| #else | ||||
|         int weights_x_idx = in_channel_block_idx << 2; | ||||
|         int weights_y_idx = weights_h_idx; | ||||
| #endif | ||||
|         for (int iy = 0; iy < weights_shape.x * dilation_shape.x; iy += dilation_shape.x) { | ||||
|             int h0 =  select(in_height0 + iy + batch_idx, -1, (in_height0 + iy < 0 || in_height0 + iy  >= input_shape.x)); | ||||
|             int h1 =  select(in_height1 + iy + batch_idx, -1, (in_height1 + iy < 0 || in_height1 + iy  >= input_shape.x)); | ||||
|  | @ -765,11 +1030,18 @@ void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only | |||
|                 in1 = RI_F(input, SAMPLER, (int2)(w0, h1)); | ||||
|                 in2 = RI_F(input, SAMPLER, (int2)(w0, h2)); | ||||
|                 in3 = RI_F(input, SAMPLER, (int2)(w0, h3)); | ||||
| 
 | ||||
| #ifdef USE_BUFFER | ||||
|                 weights0 = vload4(0, weights+weight_offset); | ||||
|                 weights1 = vload4(0, weights+weight_offset+weight_oc_offset); | ||||
|                 weights2 = vload4(0, weights+weight_offset+weight_oc_offset*2); | ||||
|                 weights3 = vload4(0, weights+weight_offset+weight_oc_offset*3); | ||||
|                 weight_offset += 4; | ||||
| #else | ||||
|                 weights0 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 0, weights_y_idx)); | ||||
|                 weights1 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 1, weights_y_idx)); | ||||
|                 weights2 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weights_y_idx)); | ||||
|                 weights3 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 3, weights_y_idx++)); | ||||
| #endif | ||||
| 
 | ||||
|                 CALCULATE_OUTPUT(0); | ||||
|                 CALCULATE_OUTPUT(1); | ||||
|  |  | |||
|  | @ -186,9 +186,13 @@ void conv_2d_c4h1w2(GLOBAL_SIZE_2_DIMS | |||
| #endif | ||||
| 
 | ||||
|     const int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; | ||||
| #ifdef BLOCK_LEAVE | ||||
|     vstore4(out0, 0, output+out_offset); | ||||
|     if(out_w_idx + 1 >= out_hw.y) return; | ||||
|     vstore4(out1, 1, output+out_offset); | ||||
| #else | ||||
|     vstore8((FLOAT8)(out0, out1), 0, output+out_offset); | ||||
| #endif | ||||
| } | ||||
| 
 | ||||
| __kernel | ||||
|  | @ -298,13 +302,22 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS | |||
| #endif | ||||
| 
 | ||||
|     const int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; | ||||
|     vstore4(out0, 0, output+out_offset); | ||||
|     if(out_w_idx + 1 >= out_hw.y) return; | ||||
|     vstore4(out1, 1, output+out_offset); | ||||
|     if(out_w_idx + 2 >= out_hw.y) return; | ||||
|     vstore4(out2, 2, output+out_offset); | ||||
|     if(out_w_idx + 3 >= out_hw.y) return; | ||||
|     vstore4(out3, 3, output+out_offset); | ||||
| #ifdef BLOCK_LEAVE | ||||
|     const int remain = out_hw.y - out_w_idx; | ||||
| 
 | ||||
|     if (remain >= 4) { | ||||
|         vstore16((FLOAT16)(out0, out1, out2, out3), 0, output+out_offset); | ||||
|     }else if(remain == 3){ | ||||
|         vstore8((FLOAT8)(out0, out1), 0, output+out_offset); | ||||
|         vstore4(out2, 2, output+out_offset); | ||||
|     }else if(remain == 2){ | ||||
|         vstore8((FLOAT8)(out0, out1), 0, output+out_offset); | ||||
|     }else if(remain == 1){ | ||||
|         vstore4(out0, 0, output+out_offset); | ||||
|     } | ||||
| #else | ||||
|     vstore16((FLOAT16)(out0, out1, out2, out3), 0, output+out_offset); | ||||
| #endif | ||||
| } | ||||
| 
 | ||||
| __kernel | ||||
|  | @ -393,25 +406,21 @@ void conv_2d_1x1_c4h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, | |||
| #endif | ||||
| 
 | ||||
|     const int out_offset = (((out_b_idx*out_c_block + out_c_idx)*out_h + out_h_idx)* out_w + out_w4_idx)*4; | ||||
| 
 | ||||
| #ifdef BLOCK_LEAVE | ||||
|     const int remain = out_w - out_w4_idx; | ||||
| 
 | ||||
|     if (remain >= 4) { | ||||
|         vstore4(out0, 0, output+out_offset); | ||||
|         vstore4(out1, 1, output+out_offset); | ||||
|         vstore4(out2, 2, output+out_offset); | ||||
|         vstore4(out3, 3, output+out_offset); | ||||
|         vstore16((FLOAT16)(out0, out1, out2, out3), 0, output+out_offset); | ||||
|     } else if (remain == 3) { | ||||
|         vstore4(out0, 0, output+out_offset); | ||||
|         vstore4(out1, 1, output+out_offset); | ||||
|         vstore8((FLOAT8)(out0, out1), 0, output+out_offset); | ||||
|         vstore4(out2, 2, output+out_offset); | ||||
|     } else if (remain == 2) { | ||||
|         vstore4(out0, 0, output+out_offset); | ||||
|         vstore4(out1, 1, output+out_offset); | ||||
|         vstore8((FLOAT8)(out0, out1), 0, output+out_offset); | ||||
|     } else if (remain == 1) { | ||||
|         vstore4(out0, 0, output+out_offset); | ||||
|     } | ||||
| 
 | ||||
| #else | ||||
|     vstore16((FLOAT16)(out0, out1, out2, out3), 0, output+out_offset); | ||||
| #endif | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
|  | @ -539,44 +548,45 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, | |||
| 
 | ||||
|     const int out_offset = (((out_b_idx*out_c_block + out_c_idx*2)*out_h + out_h_idx)* out_w + out_w4_idx)*4; | ||||
| 
 | ||||
|     const int remain = out_w - out_w4_idx; | ||||
| 
 | ||||
|     __global FLOAT* _tempoutput = output + out_offset; | ||||
|     __global FLOAT* _tempoutput1 = _tempoutput + 4*out_h*out_w; | ||||
| 
 | ||||
| #ifdef BLOCK_LEAVE | ||||
|     const int remain = out_w - out_w4_idx; | ||||
|     if (remain >= 4) { | ||||
|         vstore4(out0, 0, _tempoutput); | ||||
|         vstore4(out1, 1, _tempoutput); | ||||
|         vstore4(out2, 2, _tempoutput); | ||||
|         vstore4(out3, 3, _tempoutput); | ||||
|         vstore16((FLOAT16)(out0, out1, out2, out3), 0, _tempoutput); | ||||
|     } else if (remain == 3) { | ||||
|         vstore4(out0, 0, _tempoutput); | ||||
|         vstore4(out1, 1, _tempoutput); | ||||
|         vstore8((FLOAT8)(out0, out1), 0, _tempoutput); | ||||
|         vstore4(out2, 2, _tempoutput); | ||||
|     } else if (remain == 2) { | ||||
|         vstore4(out0, 0, _tempoutput); | ||||
|         vstore4(out1, 1, _tempoutput); | ||||
|         vstore8((FLOAT8)(out0, out1), 0, _tempoutput); | ||||
|     } else if (remain == 1) { | ||||
|         vstore4(out0, 0, _tempoutput); | ||||
|     } | ||||
| #ifdef CHANNEL_LEAVE | ||||
|     if(out_c_idx*2+1 >= out_c_block) { | ||||
|         return; | ||||
|     } | ||||
| #endif | ||||
|     if (remain >= 4) { | ||||
|         vstore4(out4, 0, _tempoutput1); | ||||
|         vstore4(out5, 1, _tempoutput1); | ||||
|         vstore4(out6, 2, _tempoutput1); | ||||
|         vstore4(out7, 3, _tempoutput1); | ||||
|         vstore16((FLOAT16)(out4, out5, out6, out7), 0, _tempoutput1); | ||||
|     } else if (remain == 3) { | ||||
|         vstore4(out4, 0, _tempoutput1); | ||||
|         vstore4(out5, 1, _tempoutput1); | ||||
|         vstore8((FLOAT8)(out4, out5), 0, _tempoutput1); | ||||
|         vstore4(out6, 2, _tempoutput1); | ||||
|     } else if (remain == 2) { | ||||
|         vstore4(out4, 0, _tempoutput1); | ||||
|         vstore4(out5, 1, _tempoutput1); | ||||
|         vstore8((FLOAT8)(out4, out5), 0, _tempoutput1); | ||||
|     } else if (remain == 1) { | ||||
|         vstore4(out4, 0, _tempoutput1); | ||||
|     } | ||||
| #else | ||||
|     vstore16((FLOAT16)(out0, out1, out2, out3), 0, _tempoutput); | ||||
| #ifdef CHANNEL_LEAVE | ||||
|     if(out_c_idx*2+1 >= out_c_block) { | ||||
|         return; | ||||
|     } | ||||
| #endif | ||||
|     vstore16((FLOAT16)(out4, out5, out6, out7), 0, _tempoutput1); | ||||
| #endif | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
|  | @ -668,26 +678,36 @@ void conv_2d_1x1_c8h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, | |||
| 
 | ||||
|     const int out_offset = (((out_b_idx*out_c_block + out_c_idx*2)*out_h + out_h_idx)* out_w + out_w2_idx)*4; | ||||
| 
 | ||||
|     const int remain = out_w - out_w2_idx; | ||||
| 
 | ||||
|     __global FLOAT* _tempoutput = output + out_offset; | ||||
|     __global FLOAT* _tempoutput1 = _tempoutput + 4*out_h*out_w; | ||||
| 
 | ||||
| #ifdef BLOCK_LEAVE | ||||
|     const int remain = out_w - out_w2_idx; | ||||
|     if (remain >= 2) { | ||||
|         vstore4(out0, 0, _tempoutput); | ||||
|         vstore4(out1, 1, _tempoutput); | ||||
|         vstore8((FLOAT8)(out0, out1), 0, _tempoutput); | ||||
|     } else if (remain == 1) { | ||||
|         vstore4(out0, 0, _tempoutput); | ||||
|     } | ||||
| #ifdef CHANNEL_LEAVE | ||||
|     if(out_c_idx*2+1 >= out_c_block) { | ||||
|         return; | ||||
|     } | ||||
| #endif | ||||
|     if (remain >= 2) { | ||||
|         vstore4(out4, 0, _tempoutput1); | ||||
|         vstore4(out5, 1, _tempoutput1); | ||||
|         vstore8((FLOAT8)(out4, out5), 0, _tempoutput1); | ||||
|     } else if (remain == 1) { | ||||
|         vstore4(out4, 0, _tempoutput1); | ||||
|     } | ||||
| #else | ||||
|     vstore8((FLOAT8)(out0, out1), 0, _tempoutput); | ||||
| #ifdef CHANNEL_LEAVE | ||||
|     if(out_c_idx*2+1 >= out_c_block) { | ||||
|         return; | ||||
|     } | ||||
| #endif | ||||
|     vstore8((FLOAT8)(out4, out5), 0, _tempoutput1); | ||||
| #endif | ||||
| } | ||||
| 
 | ||||
| __kernel | ||||
|  | @ -814,15 +834,17 @@ void conv_2d_1x1_c4h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, | |||
| 
 | ||||
|     const int out_offset = (((out_b_idx*out_c_block + out_c_idx)*out_h + out_h_idx)* out_w + out_w2_idx)*4; | ||||
| 
 | ||||
| #ifdef BLOCK_LEAVE | ||||
|     const int remain = out_w - out_w2_idx; | ||||
| 
 | ||||
|     if (remain >= 2) { | ||||
|         vstore4(out0, 0, output+out_offset); | ||||
|         vstore4(out1, 1, output+out_offset); | ||||
|         vstore8((FLOAT8)(out0, out1), 0, output+out_offset); | ||||
|     } else if (remain == 1) { | ||||
|         vstore4(out0, 0, output+out_offset); | ||||
|     } | ||||
| 
 | ||||
| #else | ||||
|     vstore8((FLOAT8)(out0, out1), 0, output+out_offset); | ||||
| #endif | ||||
| } | ||||
| 
 | ||||
| __kernel | ||||
|  | @ -932,13 +954,29 @@ void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS | |||
| #endif | ||||
| 
 | ||||
|     const int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; | ||||
| #ifdef BLOCK_LEAVE | ||||
|     const int remain = out_hw.x - out_h_idx; | ||||
|     if(remain >= 4){ | ||||
|         vstore4(out0, 0, output+out_offset); | ||||
|         vstore4(out1, out_hw.y, output+out_offset); | ||||
|         vstore4(out2, 2 * out_hw.y, output+out_offset); | ||||
|         vstore4(out3, 3 * out_hw.y, output+out_offset); | ||||
|     }else if(remain == 3){ | ||||
|         vstore4(out0, 0, output+out_offset); | ||||
|         vstore4(out1, out_hw.y, output+out_offset); | ||||
|         vstore4(out2, 2 * out_hw.y, output+out_offset); | ||||
|     }else if(remain == 2){ | ||||
|         vstore4(out0, 0, output+out_offset); | ||||
|         vstore4(out1, out_hw.y, output+out_offset); | ||||
|     }else if(remain == 1){ | ||||
|         vstore4(out0, 0, output+out_offset); | ||||
|     } | ||||
| #else | ||||
|     vstore4(out0, 0, output+out_offset); | ||||
|     if(out_h_idx + 1 >= out_hw.x) return; | ||||
|     vstore4(out1, out_hw.y, output+out_offset); | ||||
|     if(out_h_idx + 2 >= out_hw.x) return; | ||||
|     vstore4(out2, 2 * out_hw.y, output+out_offset); | ||||
|     if(out_h_idx + 3 >= out_hw.x) return; | ||||
|     vstore4(out3, 3 * out_hw.y, output+out_offset); | ||||
| #endif | ||||
| } | ||||
| 
 | ||||
| __kernel | ||||
|  | @ -1086,6 +1124,7 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS | |||
| #endif | ||||
| 
 | ||||
|     int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; | ||||
| #ifdef BLOCK_LEAVE | ||||
|     const int remain = out_hw.x - out_h_idx; | ||||
|     if(remain >= 4){ | ||||
|         vstore4(out0, 0, output+out_offset); | ||||
|  | @ -1102,9 +1141,11 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS | |||
|     }else if(remain == 1){ | ||||
|         vstore4(out0, 0, output+out_offset); | ||||
|     } | ||||
| #ifdef CHANNEL_LEAVE | ||||
|     if(out_c_idx + 1 >= out_c_blocks){ | ||||
|         return; | ||||
|     } | ||||
| #endif | ||||
|     out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; | ||||
|     if(remain >= 4){ | ||||
|         vstore4(out4, 0, output+out_offset); | ||||
|  | @ -1121,6 +1162,22 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS | |||
|     }else if(remain == 1){ | ||||
|         vstore4(out4, 0, output+out_offset); | ||||
|     } | ||||
| #else | ||||
|     vstore4(out0, 0, output+out_offset); | ||||
|     vstore4(out1, out_hw.y, output+out_offset); | ||||
|     vstore4(out2, 2 * out_hw.y, output+out_offset); | ||||
|     vstore4(out3, 3 * out_hw.y, output+out_offset); | ||||
| #ifdef CHANNEL_LEAVE | ||||
|     if(out_c_idx + 1 >= out_c_blocks){ | ||||
|         return; | ||||
|     } | ||||
| #endif | ||||
|     out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; | ||||
|     vstore4(out4, 0, output+out_offset); | ||||
|     vstore4(out5, out_hw.y, output+out_offset); | ||||
|     vstore4(out6, 2 * out_hw.y, output+out_offset); | ||||
|     vstore4(out7, 3 * out_hw.y, output+out_offset); | ||||
| #endif | ||||
| } | ||||
| 
 | ||||
| __kernel | ||||
|  | @ -1230,6 +1287,7 @@ void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS | |||
| #endif | ||||
| 
 | ||||
|     int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; | ||||
| #ifdef BLOCK_LEAVE | ||||
|     const int remain = out_hw.x - out_h_idx; | ||||
|     if(remain >= 2){ | ||||
|         vstore4(out0, 0, output+out_offset); | ||||
|  | @ -1237,9 +1295,11 @@ void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS | |||
|     }else if(remain == 1){ | ||||
|         vstore4(out0, 0, output+out_offset); | ||||
|     } | ||||
| #ifdef CHANNEL_LEAVE | ||||
|     if(out_c_idx + 1 >= out_c_blocks){ | ||||
|         return; | ||||
|     } | ||||
| #endif | ||||
|     out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; | ||||
|     if(remain >= 2){ | ||||
|         vstore4(out2, 0, output+out_offset); | ||||
|  | @ -1247,4 +1307,198 @@ void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS | |||
|     }else if(remain == 1){ | ||||
|         vstore4(out2, 0, output+out_offset); | ||||
|     } | ||||
| #else | ||||
|     vstore4(out0, 0, output+out_offset); | ||||
|     vstore4(out1, out_hw.y, output+out_offset); | ||||
| #ifdef CHANNEL_LEAVE | ||||
|     if(out_c_idx + 1 >= out_c_blocks){ | ||||
|         return; | ||||
|     } | ||||
| #endif | ||||
|     out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; | ||||
|     vstore4(out2, 0, output+out_offset); | ||||
|     vstore4(out3, out_hw.y, output+out_offset); | ||||
| #endif | ||||
| } | ||||
| 
 | ||||
| __kernel | ||||
| void conv_2d_c8h1w4(GLOBAL_SIZE_2_DIMS | ||||
|                       __global const FLOAT *input, | ||||
|                       __global const FLOAT *weight, | ||||
|                       __global const FLOAT *bias, | ||||
|                       __global FLOAT *output, | ||||
|                       __private const int2 in_hw, | ||||
|                       __private const int inChannel, | ||||
|                       __private const int in_c_blocks, | ||||
|                       __private const int2 out_hw, | ||||
|                       __private const int2 filter_hw, | ||||
|                       __private const int2 stride_hw, | ||||
|                       __private const int2 pad_hw, | ||||
|                       __private const int2 dilate_hw, | ||||
|                       __private const int out_w_blocks, | ||||
|                       __private const int out_c_blocks, | ||||
|                       __private const int out_h_blocks) { | ||||
|     const int out_c_w_idx = get_global_id(0); //c/4 w | ||||
|     const int out_b_h_idx  = get_global_id(1); //b h | ||||
| 
 | ||||
|     DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx); | ||||
| 
 | ||||
|     const int out_c_idx = (out_c_w_idx / out_w_blocks) << 1; | ||||
|     const int out_w_idx = (out_c_w_idx % out_w_blocks) << 2; | ||||
|     const int out_b_idx = out_b_h_idx / out_hw.x;//equal to in_b_idx | ||||
|     const int out_h_idx = out_b_h_idx % out_hw.x; | ||||
|      | ||||
|     FLOAT4 out0 = vload4(out_c_idx, bias); | ||||
|     FLOAT4 out1 = out0; | ||||
|     FLOAT4 out2 = out0; | ||||
|     FLOAT4 out3 = out0; | ||||
|      | ||||
|     FLOAT4 out4 = vload4(out_c_idx + 1, bias); | ||||
|     FLOAT4 out5 = out4; | ||||
|     FLOAT4 out6 = out4; | ||||
|     FLOAT4 out7 = out4; | ||||
| 
 | ||||
|     const int in_w0_idx_base = mad24(out_w_idx, stride_hw.y, -pad_hw.y); | ||||
|     const int in_w1_idx_base = in_w0_idx_base + stride_hw.y; | ||||
|     const int in_w2_idx_base = in_w1_idx_base + stride_hw.y; | ||||
|     const int in_w3_idx_base = in_w2_idx_base + stride_hw.y; | ||||
| 
 | ||||
|     const int in_h_idx_base = mad24(out_h_idx, stride_hw.x, -pad_hw.x); | ||||
|      | ||||
|     const int kh_start = select(0, (-in_h_idx_base + dilate_hw.x - 1) / dilate_hw.x, in_h_idx_base < 0); | ||||
|     const int in_h_idx_start = mad24(kh_start, dilate_hw.x, in_h_idx_base); | ||||
|     const int in_h_idx_end = min(mad24(filter_hw.x, dilate_hw.x, in_h_idx_base), in_hw.x); | ||||
|      | ||||
|     const int weight_oc_offset = filter_hw.x * filter_hw.y * 4; | ||||
|     const int weight_ic_offset = out_c_blocks * weight_oc_offset; | ||||
|     for(ushort in_c_idx = 0; in_c_idx < in_c_blocks; in_c_idx++) { | ||||
|         //weights  NC4HW4  [1,  4*icC4,  ocC4*kh*kw,  1] xic4 | ||||
|         //index:   [0, 4*in_c_idx, out_c_idx*kh*kw + kh_start*kw + kw_start, 0] | ||||
|         int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + kh_start)*filter_hw.y + 0) * 4; | ||||
| 
 | ||||
|         for(int iy = in_h_idx_start; iy < in_h_idx_end; iy += dilate_hw.x) { | ||||
|             const int inp_offset_base = (((out_b_idx * in_c_blocks + in_c_idx) * in_hw.x + iy) * in_hw.y + 0) * 4; | ||||
| 
 | ||||
|             for(int fw = 0; fw < filter_hw.y; fw++) { | ||||
|                 const int in_w0_idx = fw * dilate_hw.y + in_w0_idx_base; | ||||
|                 const int in_w1_idx = fw * dilate_hw.y + in_w1_idx_base; | ||||
|                 const int in_w2_idx = fw * dilate_hw.y + in_w2_idx_base; | ||||
|                 const int in_w3_idx = fw * dilate_hw.y + in_w3_idx_base; | ||||
| 
 | ||||
|                 FLOAT4 in0 = (in_w0_idx < 0 || in_w0_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w0_idx, input+inp_offset_base); | ||||
|                 FLOAT4 in1 = (in_w1_idx < 0 || in_w1_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w1_idx, input+inp_offset_base); | ||||
|                 FLOAT4 in2 = (in_w2_idx < 0 || in_w2_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w2_idx, input+inp_offset_base); | ||||
|                 FLOAT4 in3 = (in_w3_idx < 0 || in_w3_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w3_idx, input+inp_offset_base); | ||||
| 
 | ||||
|                 FLOAT4 weight0 = vload4(0, weight+weight_offset); | ||||
|                 FLOAT4 weight1 = vload4(0, weight+weight_offset+weight_ic_offset); | ||||
|                 FLOAT4 weight2 = vload4(0, weight+weight_offset+weight_ic_offset*2); | ||||
|                 FLOAT4 weight3 = vload4(0, weight+weight_offset+weight_ic_offset*3); | ||||
|                  | ||||
|                 out0 = mad(in0.x, weight0, out0); | ||||
|                 out0 = mad(in0.y, weight1, out0); | ||||
|                 out0 = mad(in0.z, weight2, out0); | ||||
|                 out0 = mad(in0.w, weight3, out0); | ||||
|                  | ||||
|                 out1 = mad(in1.x, weight0, out1); | ||||
|                 out1 = mad(in1.y, weight1, out1); | ||||
|                 out1 = mad(in1.z, weight2, out1); | ||||
|                 out1 = mad(in1.w, weight3, out1); | ||||
|                  | ||||
|                 out2 = mad(in2.x, weight0, out2); | ||||
|                 out2 = mad(in2.y, weight1, out2); | ||||
|                 out2 = mad(in2.z, weight2, out2); | ||||
|                 out2 = mad(in2.w, weight3, out2); | ||||
|                  | ||||
|                 out3 = mad(in3.x, weight0, out3); | ||||
|                 out3 = mad(in3.y, weight1, out3); | ||||
|                 out3 = mad(in3.z, weight2, out3); | ||||
|                 out3 = mad(in3.w, weight3, out3); | ||||
|                  | ||||
|                 weight0 = vload4(0, weight+weight_oc_offset+weight_offset); | ||||
|                 weight1 = vload4(0, weight+weight_oc_offset+weight_offset+weight_ic_offset); | ||||
|                 weight2 = vload4(0, weight+weight_oc_offset+weight_offset+weight_ic_offset*2); | ||||
|                 weight3 = vload4(0, weight+weight_oc_offset+weight_offset+weight_ic_offset*3); | ||||
|                  | ||||
|                 out4 = mad(in0.x, weight0, out4); | ||||
|                 out4 = mad(in0.y, weight1, out4); | ||||
|                 out4 = mad(in0.z, weight2, out4); | ||||
|                 out4 = mad(in0.w, weight3, out4); | ||||
|                  | ||||
|                 out5 = mad(in1.x, weight0, out5); | ||||
|                 out5 = mad(in1.y, weight1, out5); | ||||
|                 out5 = mad(in1.z, weight2, out5); | ||||
|                 out5 = mad(in1.w, weight3, out5); | ||||
|                  | ||||
|                 out6 = mad(in2.x, weight0, out6); | ||||
|                 out6 = mad(in2.y, weight1, out6); | ||||
|                 out6 = mad(in2.z, weight2, out6); | ||||
|                 out6 = mad(in2.w, weight3, out6); | ||||
|                  | ||||
|                 out7 = mad(in3.x, weight0, out7); | ||||
|                 out7 = mad(in3.y, weight1, out7); | ||||
|                 out7 = mad(in3.z, weight2, out7); | ||||
|                 out7 = mad(in3.w, weight3, out7); | ||||
|                  | ||||
|                 weight_offset += 4; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| #ifdef RELU | ||||
|     out0 = fmax(out0, (FLOAT4)0); | ||||
|     out1 = fmax(out1, (FLOAT4)0); | ||||
|     out2 = fmax(out2, (FLOAT4)0); | ||||
|     out3 = fmax(out3, (FLOAT4)0); | ||||
|     out4 = fmax(out4, (FLOAT4)0); | ||||
|     out5 = fmax(out5, (FLOAT4)0); | ||||
|     out6 = fmax(out6, (FLOAT4)0); | ||||
|     out7 = fmax(out7, (FLOAT4)0); | ||||
| #endif | ||||
| 
 | ||||
| #ifdef RELU6 | ||||
|     out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6); | ||||
|     out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6); | ||||
|     out2 = clamp(out2, (FLOAT4)0, (FLOAT4)6); | ||||
|     out3 = clamp(out3, (FLOAT4)0, (FLOAT4)6); | ||||
|     out4 = clamp(out4, (FLOAT4)0, (FLOAT4)6); | ||||
|     out5 = clamp(out5, (FLOAT4)0, (FLOAT4)6); | ||||
|     out6 = clamp(out6, (FLOAT4)0, (FLOAT4)6); | ||||
|     out7 = clamp(out7, (FLOAT4)0, (FLOAT4)6); | ||||
| #endif | ||||
| 
 | ||||
|     int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; | ||||
| #ifdef BLOCK_LEAVE | ||||
|     const int remain = out_hw.y - out_w_idx; | ||||
|     if(remain >= 4){ | ||||
|         vstore16((FLOAT16)(out0, out1, out2, out3), 0, output+out_offset); | ||||
|     }else if(remain == 3){ | ||||
|         vstore8((FLOAT8)(out0, out1), 0, output+out_offset); | ||||
|         vstore4(out2, 2, output+out_offset); | ||||
|     }else if(remain == 2){ | ||||
|         vstore8((FLOAT8)(out0, out1), 0, output+out_offset); | ||||
|     }else if(remain == 1){ | ||||
|         vstore4(out0, 0, output+out_offset); | ||||
|     } | ||||
| #ifdef CHANNEL_LEAVE | ||||
|     if(out_c_idx + 1 >= out_c_blocks)return; | ||||
| #endif | ||||
|     out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; | ||||
|     if(remain >= 4){ | ||||
|         vstore16((FLOAT16)(out4, out5, out6, out7), 0, output+out_offset); | ||||
|     }else if(remain == 3){ | ||||
|         vstore8((FLOAT8)(out4, out5), 0, output+out_offset); | ||||
|         vstore4(out6, 2, output+out_offset); | ||||
|     }else if(remain == 2){ | ||||
|         vstore8((FLOAT8)(out4, out5), 0, output+out_offset); | ||||
|     }else if(remain == 1){ | ||||
|         vstore4(out4, 0, output+out_offset); | ||||
|     } | ||||
| #else | ||||
|     vstore16((FLOAT16)(out0, out1, out2, out3), 0, output+out_offset); | ||||
| #ifdef CHANNEL_LEAVE | ||||
|     if(out_c_idx + 1 >= out_c_blocks)return; | ||||
| #endif | ||||
|     out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; | ||||
|     vstore16((FLOAT16)(out4, out5, out6, out7), 0, output+out_offset); | ||||
| #endif | ||||
| } | ||||
|  |  | |||
|  | @ -116,12 +116,154 @@ __kernel void gemmWinograd(__read_only image2d_t uInput, __read_only image2d_t u | |||
| 
 | ||||
|         __private int out_y_idx = mad24(pos_z, unitHeight, pos_y); | ||||
|         __private int out_x_idx = mad24(pos_w, unitWidth, srcX); | ||||
|         WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0); | ||||
|         if(srcX + 1 >= unitWidth) return; | ||||
|         WI_F(uOutput, (int2)(out_x_idx + 1, out_y_idx), o1); | ||||
|         if(srcX + 2 >= unitWidth) return; | ||||
|         WI_F(uOutput, (int2)(out_x_idx + 2, out_y_idx), o2); | ||||
|         if(srcX + 3 >= unitWidth) return; | ||||
|         WI_F(uOutput, (int2)(out_x_idx + 3, out_y_idx), o3); | ||||
|         const int remain = unitWidth - srcX; | ||||
|         if(remain >= 4){ | ||||
|             WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 1, out_y_idx), o1); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 2, out_y_idx), o2); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 3, out_y_idx), o3); | ||||
|         }else if(remain == 3){ | ||||
|             WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 1, out_y_idx), o1); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 2, out_y_idx), o2); | ||||
|         }else if(remain == 2){ | ||||
|             WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 1, out_y_idx), o1); | ||||
|         }else if(remain == 1){ | ||||
|             WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0); | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| __kernel void gemmWinogradW2(__read_only image2d_t uInput, __read_only image2d_t uKernel, __write_only image2d_t uOutput, | ||||
|                    __private const int unitWidth, __private const int unitHeight, __private const int dstChannelC4, __private const int multiLength, __private const int alpha2) { | ||||
|      | ||||
|     int2 pos = (int2)(get_global_id(0), get_global_id(1)); | ||||
|     const int unitWidth8 = (unitWidth + 7) / 8; | ||||
|     if (pos.x < unitWidth8 * unitHeight && pos.y < alpha2 * dstChannelC4) { | ||||
|          | ||||
|         const int pos_x = pos.x % unitWidth8; | ||||
|         const int pos_y = pos.x / unitWidth8; | ||||
|         const int pos_z = pos.y % dstChannelC4; | ||||
|         const int pos_w = pos.y / dstChannelC4; | ||||
| 
 | ||||
|         FLOAT4 o0 = (FLOAT4)(0); | ||||
|         FLOAT4 o1 = (FLOAT4)(0); | ||||
|         FLOAT4 o2 = (FLOAT4)(0); | ||||
|         FLOAT4 o3 = (FLOAT4)(0); | ||||
|         FLOAT4 o4 = (FLOAT4)(0); | ||||
|         FLOAT4 o5 = (FLOAT4)(0); | ||||
|         FLOAT4 o6 = (FLOAT4)(0); | ||||
|         FLOAT4 o7 = (FLOAT4)(0); | ||||
|         int srcY = mad24(pos_w, unitHeight, pos_y); | ||||
|         int srcX = pos_x << 3; | ||||
| 
 | ||||
|         for (int k = 0; k < multiLength; ++k) { | ||||
|             __private int index = mul24(k, 4); | ||||
|             __private int x_offset = mul24(k, unitWidth); | ||||
|             FLOAT4 k0 = RI_F(uKernel, SAMPLER, (int2)(index, pos.y)); | ||||
|             FLOAT4 k1 = RI_F(uKernel, SAMPLER, (int2)(index + 1, pos.y)); | ||||
|             FLOAT4 k2 = RI_F(uKernel, SAMPLER, (int2)(index + 2, pos.y)); | ||||
|             FLOAT4 k3 = RI_F(uKernel, SAMPLER, (int2)(index + 3, pos.y)); | ||||
| 
 | ||||
|             FLOAT4 s0 = RI_F(uInput, SAMPLER, (int2)(srcX + x_offset, srcY)); | ||||
|             FLOAT4 s1 = RI_F(uInput, SAMPLER, (int2)(srcX + x_offset + 1, srcY)); | ||||
|             FLOAT4 s2 = RI_F(uInput, SAMPLER, (int2)(srcX + x_offset + 2, srcY)); | ||||
|             FLOAT4 s3 = RI_F(uInput, SAMPLER, (int2)(srcX + x_offset + 3, srcY)); | ||||
|             FLOAT4 s4 = RI_F(uInput, SAMPLER, (int2)(srcX + x_offset + 4, srcY)); | ||||
|             FLOAT4 s5 = RI_F(uInput, SAMPLER, (int2)(srcX + x_offset + 5, srcY)); | ||||
|             FLOAT4 s6 = RI_F(uInput, SAMPLER, (int2)(srcX + x_offset + 6, srcY)); | ||||
|             FLOAT4 s7 = RI_F(uInput, SAMPLER, (int2)(srcX + x_offset + 7, srcY)); | ||||
| 
 | ||||
|             o0 = mad(s0.x, k0, o0); | ||||
|             o0 = mad(s0.y, k1, o0); | ||||
|             o0 = mad(s0.z, k2, o0); | ||||
|             o0 = mad(s0.w, k3, o0); | ||||
| 
 | ||||
|             o1 = mad(s1.x, k0, o1); | ||||
|             o1 = mad(s1.y, k1, o1); | ||||
|             o1 = mad(s1.z, k2, o1); | ||||
|             o1 = mad(s1.w, k3, o1); | ||||
| 
 | ||||
|             o2 = mad(s2.x, k0, o2); | ||||
|             o2 = mad(s2.y, k1, o2); | ||||
|             o2 = mad(s2.z, k2, o2); | ||||
|             o2 = mad(s2.w, k3, o2); | ||||
| 
 | ||||
|             o3 = mad(s3.x, k0, o3); | ||||
|             o3 = mad(s3.y, k1, o3); | ||||
|             o3 = mad(s3.z, k2, o3); | ||||
|             o3 = mad(s3.w, k3, o3); | ||||
|              | ||||
|             o4 = mad(s4.x, k0, o4); | ||||
|             o4 = mad(s4.y, k1, o4); | ||||
|             o4 = mad(s4.z, k2, o4); | ||||
|             o4 = mad(s4.w, k3, o4); | ||||
| 
 | ||||
|             o5 = mad(s5.x, k0, o5); | ||||
|             o5 = mad(s5.y, k1, o5); | ||||
|             o5 = mad(s5.z, k2, o5); | ||||
|             o5 = mad(s5.w, k3, o5); | ||||
| 
 | ||||
|             o6 = mad(s6.x, k0, o6); | ||||
|             o6 = mad(s6.y, k1, o6); | ||||
|             o6 = mad(s6.z, k2, o6); | ||||
|             o6 = mad(s6.w, k3, o6); | ||||
| 
 | ||||
|             o7 = mad(s7.x, k0, o7); | ||||
|             o7 = mad(s7.y, k1, o7); | ||||
|             o7 = mad(s7.z, k2, o7); | ||||
|             o7 = mad(s7.w, k3, o7); | ||||
|         } | ||||
| 
 | ||||
|         __private int out_y_idx = mad24(pos_z, unitHeight, pos_y); | ||||
|         __private int out_x_idx = mad24(pos_w, unitWidth, srcX); | ||||
|         const int remain = unitWidth - srcX; | ||||
|         if(remain >= 8){ | ||||
|             WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 1, out_y_idx), o1); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 2, out_y_idx), o2); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 3, out_y_idx), o3); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 4, out_y_idx), o4); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 5, out_y_idx), o5); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 6, out_y_idx), o6); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 7, out_y_idx), o7); | ||||
|         }else if(remain == 7){ | ||||
|             WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 1, out_y_idx), o1); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 2, out_y_idx), o2); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 3, out_y_idx), o3); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 4, out_y_idx), o4); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 5, out_y_idx), o5); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 6, out_y_idx), o6); | ||||
|         }else if(remain == 6){ | ||||
|             WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 1, out_y_idx), o1); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 2, out_y_idx), o2); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 3, out_y_idx), o3); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 4, out_y_idx), o4); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 5, out_y_idx), o5); | ||||
|         }else if(remain == 5){ | ||||
|             WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 1, out_y_idx), o1); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 2, out_y_idx), o2); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 3, out_y_idx), o3); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 4, out_y_idx), o4); | ||||
|         }else if(remain == 4){ | ||||
|             WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 1, out_y_idx), o1); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 2, out_y_idx), o2); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 3, out_y_idx), o3); | ||||
|         }else if(remain == 3){ | ||||
|             WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 1, out_y_idx), o1); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 2, out_y_idx), o2); | ||||
|         }else if(remain == 2){ | ||||
|             WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0); | ||||
|             WI_F(uOutput, (int2)(out_x_idx + 1, out_y_idx), o1); | ||||
|         }else if(remain == 1){ | ||||
|             WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  |  | |||
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							|  | @ -68,6 +68,35 @@ __kernel void raster_buffer( | |||
|     output[outputIndex] = input[inputIndex]; | ||||
| } | ||||
| 
 | ||||
| __kernel void raster_buffer_combine( | ||||
|                     GLOBAL_SIZE_3_DIMS | ||||
|                     __global FLOAT *input, | ||||
|                     __private const int inputOffset, | ||||
|                     __private const int combineSrcOffset, | ||||
|                     __private const int inputStride0, | ||||
|                     __private const int inputStride1, | ||||
|                     __private const int inputStride2, | ||||
|                     __global FLOAT *output, | ||||
|                     __private const int outputOffset, | ||||
|                     __private const int combineDstOffset, | ||||
|                     __private const int outputStride0, | ||||
|                     __private const int outputStride1, | ||||
|                     __private const int outputStride2, | ||||
|                     __private const int global_size0 | ||||
|                     ) { | ||||
|     const int idx = get_global_id(0); | ||||
|     const int y = get_global_id(1); | ||||
|     const int z = get_global_id(2); | ||||
|      | ||||
|     DEAL_NON_UNIFORM_DIM3(idx, y, z); | ||||
|     const int x = idx % global_size0; | ||||
|     const int id = idx / global_size0; | ||||
|      | ||||
|     int inputIndex = inputOffset + id * combineSrcOffset + z * inputStride0 + y * inputStride1 + x * inputStride2; | ||||
|     int outputIndex = outputOffset + id * combineDstOffset + z * outputStride0 + y * outputStride1 + x * outputStride2; | ||||
|     output[outputIndex] = input[inputIndex]; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| __kernel void raster_image( | ||||
|                     GLOBAL_SIZE_3_DIMS | ||||
|  |  | |||
|  | @ -20,7 +20,8 @@ ErrorCode CommonExecution::onExecute(const std::vector<Tensor *> &inputs, const | |||
|     int idx = 0; | ||||
| #else | ||||
|     if(runtime->isUseRecordQueue()){ | ||||
|         runtime->getRecordings()->emplace_back(mRecording); | ||||
|         if(runtime->isDevideOpRecord()) | ||||
|             runtime->getRecordings()->emplace_back(mRecording); | ||||
|         return NO_ERROR; | ||||
|     } | ||||
| #endif | ||||
|  |  | |||
|  | @ -77,6 +77,8 @@ ConvExecution::ConvExecution(const std::vector<Tensor *> &inputs, const std::vec | |||
|     int kernelWidth   = conv2dCommonParams->kernelX(); | ||||
|     int kernelHeight  = conv2dCommonParams->kernelY(); | ||||
|     int outputChannel = conv2dCommonParams->outputCount(); | ||||
|     auto gpuType = mOpenCLBackend->getOpenCLRuntime()->getGpuType(); | ||||
|     mWeightUseBuffer = gpuType == GpuType::MALI; | ||||
| 
 | ||||
|     int weightSize             = 0; | ||||
|     const float *filterDataPtr = nullptr; | ||||
|  | @ -103,13 +105,12 @@ ConvExecution::ConvExecution(const std::vector<Tensor *> &inputs, const std::vec | |||
|     } | ||||
|     int inputChannel = weightSize / (kernelWidth * kernelHeight * outputChannel); | ||||
| 
 | ||||
|     auto gpuType = mOpenCLBackend->getOpenCLRuntime()->getGpuType(); | ||||
| 
 | ||||
|     //select opt conv method
 | ||||
|     std::string kernelName = "conv_2d_c4h1w4"; | ||||
|     if (kernelHeight == kernelWidth && kernelHeight == 1 && mPaddings[0] == 0 && | ||||
|         mPaddings[1] == 0) { | ||||
|         mConv1x1Opt = (mStrides[0] == 1 && mStrides[1] == 1 && gpuType == GpuType::MALI); | ||||
|         mConv1x1Opt = (mStrides[0] == 1 && mStrides[1] == 1 && gpuType == GpuType::MALI && !mWeightUseBuffer); | ||||
| #if 0 | ||||
|         if((gpuType == GpuType::ADRENO)){ | ||||
|             uint64_t useLocalSize = UNIT*UNIT*4*sizeof(float)*4; | ||||
|  | @ -193,6 +194,36 @@ ConvExecution::ConvExecution(const std::vector<Tensor *> &inputs, const std::vec | |||
|         } | ||||
|         mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*(mBiasBuffer.get()), biasPtrCL); | ||||
| 
 | ||||
|     }else if(kernelHeight == kernelWidth && kernelHeight == 1 && mPaddings[0] == 0 && mPaddings[1] == 0 && mWeightUseBuffer){ | ||||
|         cl_int error; | ||||
|         std::shared_ptr<Tensor> filterBuffer(Tensor::createDevice<float>({UP_DIV(outputChannel, 4), ROUND_UP(inputChannel, 4), 4})); | ||||
|          | ||||
|         int buffer_size = filterBuffer->elementSize(); | ||||
|         if(mOpenCLBackend->getOpenCLRuntime()->isSupportedFP16()) { | ||||
|             buffer_size *= sizeof(half_float::half); | ||||
|         } else { | ||||
|             buffer_size *= sizeof(float); | ||||
|         } | ||||
|          | ||||
|         mKernelBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size)); | ||||
|         auto kernelBufferPtr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*(mKernelBuffer.get()), true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &error); | ||||
|         if(kernelBufferPtr != nullptr && error == CL_SUCCESS){ | ||||
|             ::memset(kernelBufferPtr, 0, buffer_size); | ||||
|             for(int o = 0; o < outputChannel; o++){ | ||||
|                 for(int i = 0 ; i < inputChannel; i++){ | ||||
|                     int bufferIdx = (o/4) * ROUND_UP(inputChannel, 4)*4 + i*4 + (o%4); | ||||
|                     int filterIdx = o*inputChannel + i; | ||||
|                     if(mOpenCLBackend->getOpenCLRuntime()->isSupportedFP16()){ | ||||
|                         ((half_float::half*)kernelBufferPtr)[bufferIdx] = (half_float::half)(filterDataPtr[filterIdx]); | ||||
|                     }else{ | ||||
|                         ((float*)kernelBufferPtr)[bufferIdx] = (float)(filterDataPtr[filterIdx]); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         }else{ | ||||
|             MNN_ERROR("Map error ptrCL == nullptr \n"); | ||||
|         } | ||||
|         mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*(mKernelBuffer.get()), kernelBufferPtr); | ||||
|     }else{ | ||||
|         std::vector<int> filterImageShape{(int)inputChannel, (int)(UP_DIV(outputChannel, 4) * kernelWidth * kernelHeight)}; | ||||
|         std::shared_ptr<Tensor> filterBuffer( | ||||
|  | @ -223,15 +254,34 @@ ConvExecution::ConvExecution(const std::vector<Tensor *> &inputs, const std::vec | |||
|         } | ||||
|         mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(filterBufferCL, ptrCL); | ||||
| 
 | ||||
|         mFilter.reset(Tensor::createDevice<float>({1, filterImageShape[1], 1, 4 * filterImageShape[0]})); | ||||
|         mOpenCLBackend->onAcquireBuffer(mFilter.get(), Backend::STATIC); | ||||
|         MNN::OpenCL::ImageBufferConvertor imageBufferConvertor{mOpenCLBackend->getOpenCLRuntime()}; | ||||
|         if(mWeightUseBuffer){ | ||||
|             mFilter.reset(Tensor::createDevice<float>({UP_DIV(inputChannel, 4)*4, UP_DIV(outputChannel, 4), kernelWidth * kernelHeight, 4})); | ||||
|             int kernel_buffer_size = UP_DIV(outputChannel, 4)*4* UP_DIV(inputChannel, 4)*4* kernelWidth* kernelHeight; | ||||
|             if(mOpenCLBackend->getOpenCLRuntime()->isSupportedFP16()) { | ||||
|                 kernel_buffer_size *= sizeof(half_float::half); | ||||
|             } else { | ||||
|                 kernel_buffer_size *= sizeof(float); | ||||
|             } | ||||
|             mKernelBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, kernel_buffer_size)); | ||||
|             mFilter.get()->buffer().device = (uint64_t)mKernelBuffer.get(); | ||||
|             MNN::OpenCL::BufferConvertor bufferConvertor{mOpenCLBackend->getOpenCLRuntime()}; | ||||
|              | ||||
|         std::string buildOption = ""; | ||||
|         if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf() == false){ | ||||
|             buildOption = "-DBUFFER_INP_FP32"; | ||||
|             bool needTrans = false; | ||||
|             if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf() == false){ | ||||
|                 needTrans = true; | ||||
|             } | ||||
|             bufferConvertor.convertToNC4HW4Buffer(filterBuffer.get(), MNN::OpenCL::CONV2D_FILTER, mFilter.get(), needTrans); | ||||
|         } else{ | ||||
|             mFilter.reset(Tensor::createDevice<float>({1, filterImageShape[1], 1, 4 * filterImageShape[0]})); | ||||
|             mOpenCLBackend->onAcquireBuffer(mFilter.get(), Backend::STATIC); | ||||
|             MNN::OpenCL::ImageBufferConvertor imageBufferConvertor{mOpenCLBackend->getOpenCLRuntime()}; | ||||
|              | ||||
|             std::string buildOption = ""; | ||||
|             if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf() == false){ | ||||
|                 buildOption = "-DBUFFER_INP_FP32"; | ||||
|             } | ||||
|             imageBufferConvertor.convertBufferToImage(filterBuffer.get(), MNN::OpenCL::CONV2D_FILTER, mFilter.get(), false, buildOption); | ||||
|         } | ||||
|         imageBufferConvertor.convertBufferToImage(filterBuffer.get(), MNN::OpenCL::CONV2D_FILTER, mFilter.get(), false, buildOption); | ||||
|     } | ||||
| 
 | ||||
|     // Create Kernel
 | ||||
|  | @ -244,6 +294,9 @@ ConvExecution::ConvExecution(const std::vector<Tensor *> &inputs, const std::vec | |||
|     } else if (mConv2dCommonParams->relu6()) { | ||||
|         mBuildOptions.emplace("-DRELU6"); | ||||
|     } | ||||
|     if(mWeightUseBuffer){ | ||||
|         mBuildOptions.emplace("-DUSE_BUFFER"); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     mKernel           = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName, mBuildOptions); | ||||
|  | @ -255,7 +308,7 @@ ConvExecution::ConvExecution(const std::vector<Tensor *> &inputs, const std::vec | |||
| } | ||||
| 
 | ||||
| ConvExecution::~ConvExecution() { | ||||
|     if(mUseLocalMem || !mConv1x1Opt){ | ||||
|     if((mUseLocalMem || !mConv1x1Opt) && !mWeightUseBuffer){ | ||||
|         mOpenCLBackend->onReleaseBuffer(mFilter.get(), Backend::STATIC); | ||||
|     } | ||||
| } | ||||
|  | @ -329,28 +382,77 @@ ErrorCode ConvExecution::onResize(const std::vector<Tensor *> &inputs, const std | |||
| 
 | ||||
| 
 | ||||
|         }else{ | ||||
|             mGlobalWorkSize = { | ||||
|             static_cast<uint32_t>(UP_DIV(outputShape.at(3), 4) * static_cast<uint32_t>(UP_DIV(outputShape.at(2), 4))), | ||||
|             static_cast<uint32_t>(outputShape.at(0) * outputShape.at(1))}; | ||||
|              | ||||
|             auto kernel             = &mKernel; | ||||
|             uint32_t idx            = 0; | ||||
|             int inputImageShape[2]  = {inputHeight, inputWidth}; | ||||
|             int outputImageShape[2] = {height, width}; | ||||
|             int stideShape[2]       = {mStrides[0], mStrides[1]}; | ||||
|             kernel->setArg(idx++, mGlobalWorkSize[0]); | ||||
|             kernel->setArg(idx++, mGlobalWorkSize[1]); | ||||
|             kernel->setArg(idx++, openCLImage(input)); | ||||
|             kernel->setArg(idx++, openCLImage(mFilter.get())); | ||||
|             kernel->setArg(idx++, openCLImage(mBias.get())); | ||||
|             kernel->setArg(idx++, openCLImage(output)); | ||||
|             kernel->setArg(idx++, sizeof(inputImageShape), inputImageShape); | ||||
|             kernel->setArg(idx++, static_cast<int>(inputChannelBlocks)); | ||||
|             kernel->setArg(idx++, sizeof(outputImageShape), outputImageShape); | ||||
|             kernel->setArg(idx++, sizeof(stideShape), stideShape); | ||||
|             kernel->setArg(idx++, UP_DIV(width, 4)); | ||||
|             std::string kernelName = "conv_2d_1x1"; | ||||
|             mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), kernelName, mKernel).first; | ||||
|             const int total_kernel = 2; | ||||
|             std::string kernelName[total_kernel] = {"conv_2d_1x1", "conv_2d_1x1_c8h1w4"}; | ||||
|             int itemC[total_kernel] = {4, 8}; | ||||
|             int itemH[total_kernel] = {1, 1}; | ||||
|             int itemW[total_kernel] = {4, 4}; | ||||
|              | ||||
|             int actual_kernel = total_kernel; | ||||
| 
 | ||||
|             cl::Kernel kernel[total_kernel]; | ||||
|             std::vector<uint32_t> globalWorkSize[total_kernel]; | ||||
|             std::vector<uint32_t> localWorkSize[total_kernel]; | ||||
|             std::pair<int, int> min_cost(INT_MAX, 0);//(min_time, min_index)
 | ||||
|              | ||||
|             for(int knl_idx = 0; knl_idx < total_kernel; knl_idx++) { | ||||
|                 kernel[knl_idx]        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[knl_idx], mBuildOptions); | ||||
|                 uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx])); | ||||
|                  | ||||
|                 globalWorkSize[knl_idx] = {static_cast<uint32_t>(UP_DIV(outputShape.at(3), itemC[knl_idx]) * UP_DIV(outputShape.at(2), itemW[knl_idx])), static_cast<uint32_t>(outputShape.at(0) * UP_DIV(outputShape.at(1), itemH[knl_idx]))}; | ||||
|                 uint32_t idx            = 0; | ||||
|                 kernel[knl_idx].setArg(idx++, globalWorkSize[knl_idx][0]); | ||||
|                 kernel[knl_idx].setArg(idx++, globalWorkSize[knl_idx][1]); | ||||
|                 kernel[knl_idx].setArg(idx++, openCLImage(input)); | ||||
|                 if(mWeightUseBuffer){ | ||||
|                     kernel[knl_idx].setArg(idx++, *mKernelBuffer.get()); | ||||
|                 }else{ | ||||
|                     kernel[knl_idx].setArg(idx++, openCLImage(mFilter.get())); | ||||
|                 } | ||||
|                 kernel[knl_idx].setArg(idx++, openCLImage(mBias.get())); | ||||
|                 kernel[knl_idx].setArg(idx++, openCLImage(output)); | ||||
|                 kernel[knl_idx].setArg(idx++, sizeof(inputImageShape), inputImageShape); | ||||
|                 kernel[knl_idx].setArg(idx++, static_cast<int>(inputChannelBlocks)); | ||||
|                 kernel[knl_idx].setArg(idx++, sizeof(outputImageShape), outputImageShape); | ||||
|                 kernel[knl_idx].setArg(idx++, sizeof(stideShape), stideShape); | ||||
|                 kernel[knl_idx].setArg(idx++, UP_DIV(width, 4)); | ||||
|                 kernel[knl_idx].setArg(idx++, UP_DIV(outputShape.at(3), 4)); | ||||
|                  | ||||
|                 std::pair<std::vector<uint32_t>, uint32_t> retTune; | ||||
|                 retTune = localWS2DDefault(globalWorkSize[knl_idx], mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), kernelName[knl_idx] + info, kernel[knl_idx]); | ||||
|                  | ||||
|                 //printf("conv1x1 kernel_%d = %d  [%d, %d]\n", knl_idx, retTune.second, retTune.first[0], retTune.first[1]);
 | ||||
|                 if(min_cost.first > retTune.second) { | ||||
|                     min_cost.first = retTune.second; | ||||
|                     min_cost.second = knl_idx; | ||||
|                     mLocalWorkSize = {retTune.first[0], retTune.first[1]}; | ||||
|                 } | ||||
|             } | ||||
|             int min_index  = min_cost.second; | ||||
|             //printf("min_index = %d  %d\n", min_index, min_cost.first);
 | ||||
|             mGlobalWorkSize = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]}; | ||||
|             mKernel        = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[min_index], mBuildOptions); | ||||
| 
 | ||||
|             uint32_t idx = 0; | ||||
|             mKernel.setArg(idx++, mGlobalWorkSize[0]); | ||||
|             mKernel.setArg(idx++, mGlobalWorkSize[1]); | ||||
|             mKernel.setArg(idx++, openCLImage(input)); | ||||
|             if(mWeightUseBuffer){ | ||||
|                 mKernel.setArg(idx++, *mKernelBuffer.get()); | ||||
|             }else{ | ||||
|                 mKernel.setArg(idx++, openCLImage(mFilter.get())); | ||||
|             } | ||||
|             mKernel.setArg(idx++, openCLImage(mBias.get())); | ||||
|             mKernel.setArg(idx++, openCLImage(output)); | ||||
|             mKernel.setArg(idx++, sizeof(inputImageShape), inputImageShape); | ||||
|             mKernel.setArg(idx++, static_cast<int>(inputChannelBlocks)); | ||||
|             mKernel.setArg(idx++, sizeof(outputImageShape), outputImageShape); | ||||
|             mKernel.setArg(idx++, sizeof(stideShape), stideShape); | ||||
|             mKernel.setArg(idx++, UP_DIV(width, 4)); | ||||
|             mKernel.setArg(idx++, UP_DIV(outputShape.at(3), 4)); | ||||
|             recordKernel2d(mKernel, mGlobalWorkSize, mLocalWorkSize, mOpenCLBackend->getOpenCLRuntime()); | ||||
|         } | ||||
|     }else { | ||||
|  | @ -385,7 +487,11 @@ ErrorCode ConvExecution::onResize(const std::vector<Tensor *> &inputs, const std | |||
|             ret |= kernel[knl_idx].setArg(idx++, globalWorkSize[knl_idx][0]); | ||||
|             ret |= kernel[knl_idx].setArg(idx++, globalWorkSize[knl_idx][1]); | ||||
|             ret |= kernel[knl_idx].setArg(idx++, openCLImage(input)); | ||||
|             ret |= kernel[knl_idx].setArg(idx++, openCLImage(mFilter.get())); | ||||
|             if(mWeightUseBuffer){ | ||||
|                 ret |= kernel[knl_idx].setArg(idx++, openCLBuffer(mFilter.get())); | ||||
|             }else{ | ||||
|                 ret |= kernel[knl_idx].setArg(idx++, openCLImage(mFilter.get())); | ||||
|             } | ||||
|             ret |= kernel[knl_idx].setArg(idx++, openCLImage(mBias.get())); | ||||
|             ret |= kernel[knl_idx].setArg(idx++, openCLImage(output)); | ||||
|             ret |= kernel[knl_idx].setArg(idx++, sizeof(inputImageShape), inputImageShape); | ||||
|  | @ -418,7 +524,11 @@ ErrorCode ConvExecution::onResize(const std::vector<Tensor *> &inputs, const std | |||
|         ret |= mKernel.setArg(idx++, mGlobalWorkSize[0]); | ||||
|         ret |= mKernel.setArg(idx++, mGlobalWorkSize[1]); | ||||
|         ret |= mKernel.setArg(idx++, openCLImage(input)); | ||||
|         ret |= mKernel.setArg(idx++, openCLImage(mFilter.get())); | ||||
|         if(mWeightUseBuffer){ | ||||
|             ret |= mKernel.setArg(idx++, openCLBuffer(mFilter.get())); | ||||
|         }else{ | ||||
|             ret |= mKernel.setArg(idx++, openCLImage(mFilter.get())); | ||||
|         } | ||||
|         ret |= mKernel.setArg(idx++, openCLImage(mBias.get())); | ||||
|         ret |= mKernel.setArg(idx++, openCLImage(output)); | ||||
|         ret |= mKernel.setArg(idx++, sizeof(inputImageShape), inputImageShape); | ||||
|  | @ -456,7 +566,8 @@ ErrorCode ConvExecution::onExecute(const std::vector<Tensor *> &inputs, const st | |||
|         MNN_PRINT("kernel cost:%f    us Conv UseLocalMem\n",costTime); | ||||
|     #else | ||||
|         if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){ | ||||
|             mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
|             if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord()) | ||||
|                 mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
| #ifdef LOG_VERBOSE | ||||
|     MNN_PRINT("end ConvExecution onExecute !\n"); | ||||
| #endif | ||||
|  | @ -476,7 +587,8 @@ ErrorCode ConvExecution::onExecute(const std::vector<Tensor *> &inputs, const st | |||
|     MNN_PRINT("kernel cost:%d    us Conv2D\n",costTime); | ||||
| #else | ||||
|     if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){ | ||||
|         mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
|         if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord()) | ||||
|             mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
| #ifdef LOG_VERBOSE | ||||
|         MNN_PRINT("end ConvExecution onExecute !\n"); | ||||
| #endif | ||||
|  |  | |||
|  | @ -57,6 +57,7 @@ private: | |||
|     std::shared_ptr<cl::Buffer> mKernelBuffer; | ||||
|     std::shared_ptr<cl::Buffer> mBiasBuffer; | ||||
|     std::set<std::string> mBuildOptions; | ||||
|     bool mWeightUseBuffer = false; | ||||
| }; | ||||
| 
 | ||||
| } // namespace OpenCL
 | ||||
|  |  | |||
|  | @ -237,7 +237,6 @@ ErrorCode ConvWinograd::onResize(const std::vector<Tensor*>& inputs, const std:: | |||
|                                      "winogradTransformDest", buildOptions); | ||||
|             mMaxWGS_D[i] = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(mDestTransform[i])); | ||||
|         } | ||||
|         mMatMul[i] = runTime->buildKernel("gemm", "gemmWinograd", basic); | ||||
|         mMaxWGS_M[i] = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(mMatMul[i])); | ||||
|     } | ||||
|      | ||||
|  | @ -261,14 +260,6 @@ ErrorCode ConvWinograd::onResize(const std::vector<Tensor*>& inputs, const std:: | |||
|         ret |= mSourceTransform[b].setArg(8, icC4); | ||||
|         ret |= mSourceTransform[b].setArg(9, b); | ||||
| 
 | ||||
|         ret |= mMatMul[b].setArg(0, openCLImage(mSource.get())); | ||||
|         ret |= mMatMul[b].setArg(1, *mWeight); | ||||
|         ret |= mMatMul[b].setArg(2, openCLImage(mDest.get())); | ||||
|         ret |= mMatMul[b].setArg(3, wUnit); | ||||
|         ret |= mMatMul[b].setArg(4, hUnit); | ||||
|         ret |= mMatMul[b].setArg(5, ocC4); | ||||
|         ret |= mMatMul[b].setArg(6, icC4); | ||||
|         ret |= mMatMul[b].setArg(7, alpha*alpha); | ||||
| 
 | ||||
|         ret |= mDestTransform[b].setArg(0, openCLImage(mDest.get())); | ||||
|         ret |= mDestTransform[b].setArg(1, *mBias); | ||||
|  | @ -291,10 +282,56 @@ ErrorCode ConvWinograd::onResize(const std::vector<Tensor*>& inputs, const std:: | |||
| 
 | ||||
|         /*MatMul*/ | ||||
|         { | ||||
|             const int total_kernel                     = 2; | ||||
|             const std::string kernelName[total_kernel] = {"gemmWinograd", "gemmWinogradW2"}; | ||||
|             int itemW[total_kernel]                    = {4, 8}; | ||||
|             auto gemmHeight = ocC4; | ||||
|             mGWS_M[b] = {static_cast<uint32_t>(UP_DIV(wUnit, 4) * hUnit), static_cast<uint32_t>(alpha * alpha * ocC4)}; | ||||
|             std::string kernelName = "gemmWinograd"; | ||||
|             mLWS_M[b] = localWS2DDefault(mGWS_M[b], mMaxWGS_M[b], mOpenCLBackend->getOpenCLRuntime(), kernelName, mMatMul[b]).first; | ||||
|             int actual_kernel = total_kernel; | ||||
|              | ||||
|             cl::Kernel kernel[total_kernel]; | ||||
|             std::vector<uint32_t> globalWorkSize[total_kernel]; | ||||
|             std::vector<uint32_t> localWorkSize[total_kernel]; | ||||
|             std::pair<uint32_t, int> min_cost(UINT_MAX, 0); //(min_time, min_index)
 | ||||
|             for (int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) { | ||||
|                 cl_int ret = CL_SUCCESS; | ||||
|                 kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm", kernelName[knl_idx], basic); | ||||
|                 uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx])); | ||||
| 
 | ||||
|                 globalWorkSize[knl_idx] = {static_cast<uint32_t>(UP_DIV(wUnit, itemW[knl_idx]) * hUnit), static_cast<uint32_t>(alpha * alpha * ocC4)}; | ||||
|                 ret |= kernel[knl_idx].setArg(0, openCLImage(mSource.get())); | ||||
|                 ret |= kernel[knl_idx].setArg(1, *mWeight); | ||||
|                 ret |= kernel[knl_idx].setArg(2, openCLImage(mDest.get())); | ||||
|                 ret |= kernel[knl_idx].setArg(3, wUnit); | ||||
|                 ret |= kernel[knl_idx].setArg(4, hUnit); | ||||
|                 ret |= kernel[knl_idx].setArg(5, ocC4); | ||||
|                 ret |= kernel[knl_idx].setArg(6, icC4); | ||||
|                 ret |= kernel[knl_idx].setArg(7, alpha*alpha); | ||||
|                 MNN_CHECK_CL_SUCCESS(ret, "setArg ConvWinogradExecution gemm"); | ||||
| 
 | ||||
|                 std::pair<std::vector<uint32_t>, uint32_t> retTune; | ||||
|                 retTune = localWS2DDefault(globalWorkSize[knl_idx], maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), kernelName[knl_idx], kernel[knl_idx]); | ||||
|                 // printf("gemm %d, %d\n", knl_idx, retTune.second);
 | ||||
|                 if (min_cost.first > retTune.second) { | ||||
|                     min_cost.first  = retTune.second; | ||||
|                     min_cost.second = knl_idx; | ||||
|                     mLWS_M[b]       = {retTune.first[0], retTune.first[1]}; | ||||
|                 } | ||||
|             } | ||||
|             cl_int ret = CL_SUCCESS; | ||||
|             int min_index = min_cost.second; | ||||
|             //printf("gemm min_index = %d  %d\n", min_index, min_cost.first);
 | ||||
|             mMatMul[b] = runTime->buildKernel("gemm", kernelName[min_index], basic); | ||||
|              | ||||
|             ret |= mMatMul[b].setArg(0, openCLImage(mSource.get())); | ||||
|             ret |= mMatMul[b].setArg(1, *mWeight); | ||||
|             ret |= mMatMul[b].setArg(2, openCLImage(mDest.get())); | ||||
|             ret |= mMatMul[b].setArg(3, wUnit); | ||||
|             ret |= mMatMul[b].setArg(4, hUnit); | ||||
|             ret |= mMatMul[b].setArg(5, ocC4); | ||||
|             ret |= mMatMul[b].setArg(6, icC4); | ||||
|             ret |= mMatMul[b].setArg(7, alpha*alpha); | ||||
|             MNN_CHECK_CL_SUCCESS(ret, "setArg ConvWinogradExecution gemm"); | ||||
|             mGWS_M[b] = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]}; | ||||
|             recordKernel2d(mMatMul[b], mGWS_M[b], mLWS_M[b], mOpenCLBackend->getOpenCLRuntime()); | ||||
|         } | ||||
| 
 | ||||
|  | @ -319,7 +356,8 @@ ErrorCode ConvWinograd::onExecute(const std::vector<Tensor*>& inputs, const std: | |||
|     int costTime = 0; | ||||
|     #else | ||||
|     if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){ | ||||
|         mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
|         if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord()) | ||||
|             mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
|         return NO_ERROR; | ||||
|     } | ||||
|     #endif | ||||
|  |  | |||
|  | @ -182,7 +182,8 @@ ErrorCode DeconvExecution::onExecute(const std::vector<Tensor *> &inputs, const | |||
|     MNN_PRINT("kernel cost:%d    us Deconv\n",costTime); | ||||
| #else | ||||
|     if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){ | ||||
|         mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
|         if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord()) | ||||
|             mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
| #ifdef LOG_VERBOSE | ||||
|         MNN_PRINT("End DeconvExecution onExecute... \n"); | ||||
| #endif | ||||
|  |  | |||
|  | @ -169,7 +169,8 @@ ErrorCode DepthwiseConvExecution::onExecute(const std::vector<Tensor *> &inputs, | |||
|     MNN_PRINT("kernel cost:%d    us DepthwiseConv\n",costTime); | ||||
| #else | ||||
|     if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){ | ||||
|         mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
|         if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord()) | ||||
|             mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
| #ifdef LOG_VERBOSE | ||||
|         MNN_PRINT("End DepthwiseConvExecution onExecute... \n"); | ||||
| #endif | ||||
|  |  | |||
|  | @ -172,7 +172,8 @@ ErrorCode DepthwiseDeconvExecution::onExecute(const std::vector<Tensor *> &input | |||
|     MNN_PRINT("kernel cost:%d    us DepthwiseDeconv\n",costTime); | ||||
| #else | ||||
|     if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){ | ||||
|         mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
|         if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord()) | ||||
|             mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
| #ifdef LOG_VERBOSE | ||||
|         MNN_PRINT("End DepthwiseDeconvExecution onExecute... \n"); | ||||
| #endif | ||||
|  |  | |||
|  | @ -87,7 +87,8 @@ ErrorCode FuseExecution::onExecute(const std::vector<Tensor *> &inputs, const st | |||
|     MNN_PRINT("kernel cost:%d    us Fuse\n",costTime); | ||||
| #else | ||||
|     if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){ | ||||
|         mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
|         if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord()) | ||||
|             mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
| #ifdef LOG_VERBOSE | ||||
|         MNN_PRINT("end SoftmaxExecution onExecute !\n"); | ||||
| #endif | ||||
|  |  | |||
|  | @ -96,7 +96,8 @@ ErrorCode GridSampleExecution::onExecute(const std::vector<Tensor *> &inputs, co | |||
|     MNN_PRINT("kernel cost:%d    us GridSample\n", costTime); | ||||
| #else | ||||
|     if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){ | ||||
|         mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
|         if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord()) | ||||
|             mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
|         return NO_ERROR; | ||||
|     } | ||||
|     run3DKernelDefault(mKernel, mGlobalWorkSize, mLocalWorkSize, mOpenCLBackend->getOpenCLRuntime()); | ||||
|  |  | |||
|  | @ -107,7 +107,8 @@ ErrorCode Interp3DExecution::onExecute(const std::vector<Tensor *> &inputs, cons | |||
|     MNN_PRINT("kernel cost:%d    us Interp3D\n",costTime); | ||||
| #else | ||||
|     if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){ | ||||
|         mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
|         if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord()) | ||||
|             mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
| #ifdef LOG_VERBOSE | ||||
|         MNN_PRINT("End Interp3DExecution onExecute... \n"); | ||||
| #endif | ||||
|  |  | |||
|  | @ -99,7 +99,8 @@ ErrorCode InterpExecution::onExecute(const std::vector<Tensor *> &inputs, const | |||
|     MNN_PRINT("kernel cost:%d    us Interp\n",costTime); | ||||
| #else | ||||
|     if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){ | ||||
|         mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
|         if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord()) | ||||
|             mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
| #ifdef LOG_VERBOSE | ||||
|         MNN_PRINT("End InterpExecution onExecute... \n"); | ||||
| #endif | ||||
|  |  | |||
|  | @ -180,7 +180,8 @@ ErrorCode LayerNormExecution::onExecute(const std::vector<Tensor *> &inputs, con | |||
|     MNN_PRINT("kernel cost:%d    us LayerNorm\n",costTime); | ||||
| #else | ||||
|     if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){ | ||||
|         mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
|         if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord()) | ||||
|             mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
| #ifdef LOG_VERBOSE | ||||
|         MNN_PRINT("End LayerNormExecution onExecute... \n"); | ||||
| #endif | ||||
|  |  | |||
|  | @ -120,7 +120,8 @@ ErrorCode MatMulExecution::onExecute(const std::vector<Tensor *> &inputs, const | |||
|         MNN_PRINT("kernel cost:%d    us Matmul\n",costTime); | ||||
|     #else | ||||
|     if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){ | ||||
|         mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
|         if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord()) | ||||
|             mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
| #ifdef LOG_VERBOSE | ||||
|         MNN_PRINT("End MatMulExecution onExecute... \n"); | ||||
| #endif | ||||
|  |  | |||
|  | @ -155,7 +155,8 @@ ErrorCode PoolExecution::onExecute(const std::vector<Tensor *> &inputs, const st | |||
|     MNN_PRINT("kernel cost:%d    us Pooling\n",costTime); | ||||
| #else | ||||
|     if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){ | ||||
|         mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
|         if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord()) | ||||
|             mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
| #ifdef LOG_VERBOSE | ||||
|         MNN_PRINT("End PoolExecution onExecute... \n"); | ||||
| #endif | ||||
|  |  | |||
|  | @ -153,6 +153,7 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con | |||
|         return NO_ERROR; | ||||
|     } | ||||
|      | ||||
|     bool cancombine = CanCombine(outputs); | ||||
|     // Alloc Temp buffer
 | ||||
|     auto bufferPool     = ((OpenCLBackend *)backend())->getBufferPool(); | ||||
|     auto bufferUnitSize = runtime->isSupportedFP16() ? sizeof(half_float::half) : sizeof(float); | ||||
|  | @ -176,6 +177,9 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con | |||
|     bufferPool->recycle(mTempOutput); | ||||
|      | ||||
|     auto originNum = mTempInput.size(); | ||||
|     if(cancombine){ | ||||
|         regionNum = 1; | ||||
|     } | ||||
|     mUnits.resize(regionNum + originNum + 1); | ||||
|      | ||||
|     int kernel_idx = 0; | ||||
|  | @ -259,18 +263,23 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con | |||
|     } | ||||
|      | ||||
|     // buffer raster
 | ||||
|     for (auto& slice : des->regions) | ||||
|     { | ||||
|     if(cancombine){ | ||||
|         auto regions = des->regions; | ||||
|         auto slice = regions[0]; | ||||
|         int nums = regions.size(); | ||||
|         int src_offset = regions[1].src.offset - slice.src.offset; | ||||
|         int dst_offset = regions[1].dst.offset - slice.dst.offset; | ||||
|          | ||||
|         Unit &unit          = mUnits[kernel_idx++]; | ||||
|         unit.kernel         = runtime->buildKernel("raster", "raster_buffer", {}); | ||||
|         unit.kernel         = runtime->buildKernel("raster", "raster_buffer_combine", {}); | ||||
|          | ||||
|         unit.globalWorkSize = {(uint32_t)slice.size[2], | ||||
|                                (uint32_t)slice.size[1], | ||||
|                                (uint32_t)slice.size[0]}; | ||||
|         unit.globalWorkSize = {(uint32_t)slice.size[2] * nums, | ||||
|             (uint32_t)slice.size[1], | ||||
|             (uint32_t)slice.size[0]}; | ||||
|          | ||||
|         const std::vector<uint32_t> gws =  {(uint32_t)slice.size[2], | ||||
|                                                 (uint32_t)slice.size[1], | ||||
|                                                 (uint32_t)slice.size[0]}; | ||||
|         const std::vector<uint32_t> gws =  {(uint32_t)slice.size[2] * nums, | ||||
|             (uint32_t)slice.size[1], | ||||
|             (uint32_t)slice.size[0]}; | ||||
|         uint32_t mMaxWorkGroupSize      = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(unit.kernel)); | ||||
|          | ||||
|         uint32_t idx   = 0; | ||||
|  | @ -280,14 +289,17 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con | |||
|         ret |= unit.kernel.setArg(idx++, gws[2]); | ||||
|         ret |= unit.kernel.setArg(idx++, *(mTempInput[slice.origin])); | ||||
|         ret |= unit.kernel.setArg(idx++, slice.src.offset); | ||||
|         ret |= unit.kernel.setArg(idx++, src_offset); | ||||
|         ret |= unit.kernel.setArg(idx++, slice.src.stride[0]); | ||||
|         ret |= unit.kernel.setArg(idx++, slice.src.stride[1]); | ||||
|         ret |= unit.kernel.setArg(idx++, slice.src.stride[2]); | ||||
|         ret |= unit.kernel.setArg(idx++, *mTempOutput); | ||||
|         ret |= unit.kernel.setArg(idx++, slice.dst.offset); | ||||
|         ret |= unit.kernel.setArg(idx++, dst_offset); | ||||
|         ret |= unit.kernel.setArg(idx++, slice.dst.stride[0]); | ||||
|         ret |= unit.kernel.setArg(idx++, slice.dst.stride[1]); | ||||
|         ret |= unit.kernel.setArg(idx++, slice.dst.stride[2]); | ||||
|         ret |= unit.kernel.setArg(idx++, slice.size[2]); | ||||
|         if(ret != CL_SUCCESS) | ||||
|         { | ||||
|             MNN_PRINT("setArg err %d\n", (int)ret); | ||||
|  | @ -299,9 +311,54 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con | |||
|         unit.localWorkSize = {lws[0], lws[1], lws[2]}; | ||||
|          | ||||
|         unit.globalWorkSize = {ROUND_UP(gws[0], std::max((uint32_t)1, lws[0])), | ||||
|                                ROUND_UP(gws[1], std::max((uint32_t)1, lws[1])), | ||||
|                                ROUND_UP(gws[2], std::max((uint32_t)1, lws[2]))}; | ||||
|             ROUND_UP(gws[1], std::max((uint32_t)1, lws[1])), | ||||
|             ROUND_UP(gws[2], std::max((uint32_t)1, lws[2]))}; | ||||
|         recordKernel3d(unit.kernel, gws, lws, runtime); | ||||
|     }else{ | ||||
|         for (auto& slice : des->regions) | ||||
|         { | ||||
|             Unit &unit          = mUnits[kernel_idx++]; | ||||
|             unit.kernel         = runtime->buildKernel("raster", "raster_buffer", {}); | ||||
|              | ||||
|             unit.globalWorkSize = {(uint32_t)slice.size[2], | ||||
|                 (uint32_t)slice.size[1], | ||||
|                 (uint32_t)slice.size[0]}; | ||||
|              | ||||
|             const std::vector<uint32_t> gws =  {(uint32_t)slice.size[2], | ||||
|                 (uint32_t)slice.size[1], | ||||
|                 (uint32_t)slice.size[0]}; | ||||
|             uint32_t mMaxWorkGroupSize      = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(unit.kernel)); | ||||
|              | ||||
|             uint32_t idx   = 0; | ||||
|             cl_int ret = CL_SUCCESS; | ||||
|             ret |= unit.kernel.setArg(idx++, gws[0]); | ||||
|             ret |= unit.kernel.setArg(idx++, gws[1]); | ||||
|             ret |= unit.kernel.setArg(idx++, gws[2]); | ||||
|             ret |= unit.kernel.setArg(idx++, *(mTempInput[slice.origin])); | ||||
|             ret |= unit.kernel.setArg(idx++, slice.src.offset); | ||||
|             ret |= unit.kernel.setArg(idx++, slice.src.stride[0]); | ||||
|             ret |= unit.kernel.setArg(idx++, slice.src.stride[1]); | ||||
|             ret |= unit.kernel.setArg(idx++, slice.src.stride[2]); | ||||
|             ret |= unit.kernel.setArg(idx++, *mTempOutput); | ||||
|             ret |= unit.kernel.setArg(idx++, slice.dst.offset); | ||||
|             ret |= unit.kernel.setArg(idx++, slice.dst.stride[0]); | ||||
|             ret |= unit.kernel.setArg(idx++, slice.dst.stride[1]); | ||||
|             ret |= unit.kernel.setArg(idx++, slice.dst.stride[2]); | ||||
|             if(ret != CL_SUCCESS) | ||||
|             { | ||||
|                 MNN_PRINT("setArg err %d\n", (int)ret); | ||||
|             } | ||||
|              | ||||
|             std::string name = "rasterBuffer"; | ||||
|             const std::vector<uint32_t> lws = localWS3DDefault(gws, mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), name, unit.kernel).first; | ||||
|              | ||||
|             unit.localWorkSize = {lws[0], lws[1], lws[2]}; | ||||
|              | ||||
|             unit.globalWorkSize = {ROUND_UP(gws[0], std::max((uint32_t)1, lws[0])), | ||||
|                 ROUND_UP(gws[1], std::max((uint32_t)1, lws[1])), | ||||
|                 ROUND_UP(gws[2], std::max((uint32_t)1, lws[2]))}; | ||||
|             recordKernel3d(unit.kernel, gws, lws, runtime); | ||||
|         } | ||||
|     } | ||||
|      | ||||
|     //buffer to image
 | ||||
|  | @ -364,6 +421,44 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con | |||
|     return NO_ERROR; | ||||
| } | ||||
| 
 | ||||
| bool RasterExecution::CanCombine(const std::vector<Tensor *> &outputs){ | ||||
|     auto des = TensorUtils::getDescribe(outputs[0]); | ||||
|     auto regions = des->regions; | ||||
|     if(regions.size() < 2) | ||||
|         return false; | ||||
|     auto origin = regions[0].origin; | ||||
|     const int size0 = regions[0].size[0]; | ||||
|     const int size1 = regions[0].size[1]; | ||||
|     const int size2 = regions[0].size[2]; | ||||
|     const int src_offset = regions[1].src.offset - regions[0].src.offset; | ||||
|     const int dst_offset = regions[1].dst.offset - regions[0].dst.offset; | ||||
|     const int src_sride0 = regions[0].src.stride[0]; | ||||
|     const int src_sride1 = regions[0].src.stride[1]; | ||||
|     const int src_sride2 = regions[0].src.stride[2]; | ||||
|     const int dst_sride0 = regions[0].dst.stride[0]; | ||||
|     const int dst_sride1 = regions[0].dst.stride[1]; | ||||
|     const int dst_sride2 = regions[0].dst.stride[2]; | ||||
|     bool res = true; | ||||
|     for(int i = 1; i < regions.size(); ++i){ | ||||
|         res &= regions[i].origin == origin; | ||||
|         res &= regions[i].size[0] == size0; | ||||
|         res &= regions[i].size[1] == size1; | ||||
|         res &= regions[i].size[2] == size2; | ||||
|         res &= regions[i].src.stride[0] == src_sride0; | ||||
|         res &= regions[i].src.stride[1] == src_sride1; | ||||
|         res &= regions[i].src.stride[2] == src_sride2; | ||||
|         res &= regions[i].dst.stride[0] == dst_sride0; | ||||
|         res &= regions[i].dst.stride[1] == dst_sride1; | ||||
|         res &= regions[i].dst.stride[2] == dst_sride2; | ||||
|         res &= (regions[i].src.offset - regions[i - 1].src.offset) == src_offset; | ||||
|         res &= (regions[i].dst.offset - regions[i - 1].dst.offset) == dst_offset; | ||||
|         if(res == false){ | ||||
|             return res; | ||||
|         } | ||||
|     } | ||||
|     return res; | ||||
| } | ||||
| 
 | ||||
| OpenCLCreatorRegister<TypedCreator<RasterExecution>> __Raster_op(OpType_Raster, IMAGE); | ||||
| } // namespace OpenCL
 | ||||
| } // namespace MNN
 | ||||
|  |  | |||
|  | @ -26,6 +26,7 @@ public: | |||
|     virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; | ||||
|      | ||||
| private: | ||||
|     bool CanCombine(const std::vector<Tensor *> &outputs); | ||||
|     std::map<Tensor*, cl::Buffer *> mTempInput; | ||||
|     cl::Buffer *mTempOutput; | ||||
|     OpenCLBackend *mOpenCLBackend; | ||||
|  |  | |||
|  | @ -175,7 +175,8 @@ ErrorCode ReductionExecution::onExecute(const std::vector<Tensor *> &inputs, con | |||
|         MNN_PRINT("kernel cost:%d    us Reduct1D\n",costTime); | ||||
|     #else | ||||
|     if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){ | ||||
|         mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
|         if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord()) | ||||
|             mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
| #ifdef LOG_VERBOSE | ||||
|         MNN_PRINT("End ReductionExecution onExecute... \n"); | ||||
| #endif | ||||
|  |  | |||
|  | @ -133,7 +133,8 @@ ErrorCode RoiPooling::onExecute(const std::vector<Tensor *> &inputs, const std:: | |||
|     MNN_PRINT("kernel cost:%d    us RoiPooling\n",costTime); | ||||
| #else | ||||
|     if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){ | ||||
|         mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
|         if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord()) | ||||
|             mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
| #ifdef LOG_VERBOSE | ||||
|         MNN_PRINT("End RoiPooling onExecute... \n"); | ||||
| #endif | ||||
|  |  | |||
|  | @ -175,7 +175,8 @@ ErrorCode ScaleExecution::onExecute(const std::vector<Tensor *> &inputs, const s | |||
|     MNN_PRINT("kernel cost:%d    us Softmax\n",costTime); | ||||
| #else | ||||
|     if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){ | ||||
|         mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
|         if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord()) | ||||
|             mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
| #ifdef LOG_VERBOSE | ||||
|         MNN_PRINT("End ScaleExecution onExecute... \n"); | ||||
| #endif | ||||
|  |  | |||
|  | @ -117,7 +117,8 @@ ErrorCode SoftmaxExecution::onExecute(const std::vector<Tensor *> &inputs, const | |||
|     MNN_PRINT("kernel cost:%d    us Softmax\n",costTime); | ||||
| #else | ||||
|     if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){ | ||||
|         mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
|         if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord()) | ||||
|             mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
| #ifdef LOG_VERBOSE | ||||
|         MNN_PRINT("End SoftmaxExecution onExecute... \n"); | ||||
| #endif | ||||
|  |  | |||
|  | @ -79,7 +79,8 @@ ErrorCode UnaryExecution::onExecute(const std::vector<Tensor*>& inputs, const st | |||
| #else | ||||
|     auto openCLBackend = static_cast<OpenCLBackend*>(backend()); | ||||
|     if(openCLBackend->getOpenCLRuntime()->isUseRecordQueue()){ | ||||
|         mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
|         if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord()) | ||||
|             mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording); | ||||
| #ifdef LOG_VERBOSE | ||||
|         MNN_PRINT("End UnaryExecution onExecute... \n"); | ||||
| #endif | ||||
|  |  | |||
|  | @ -318,6 +318,7 @@ void VulkanBackend::onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTenso | |||
|             tempTensor.reset(Tensor::create(srcTensor->shape(), dstTensor->getType(), nullptr, _convert(TensorUtils::getDescribe(srcTensor)->dimensionFormat)), [dstTensor](void* t) { | ||||
|                 Tensor* temp = (Tensor*)t; | ||||
|                 MNNCPUCopyBuffer(temp, dstTensor); | ||||
|                 delete temp; | ||||
|             }); | ||||
|             dstTensor = tempTensor.get(); | ||||
|         } | ||||
|  |  | |||
|  | @ -36,9 +36,9 @@ public: | |||
|             MNN_ERROR("Invalid origin conv op\n"); | ||||
|             return nullptr; | ||||
|         } | ||||
|         auto op = originConv->expr().first->get()->UnPack(); | ||||
|         std::unique_ptr<MNN::OpT> op( originConv->expr().first->get()->UnPack()); | ||||
|         op->main.AsConvolution2D()->symmetricQuan->winogradAttr = encode(); | ||||
|         return (Express::Variable::create(Express::Expr::create(op, originConv->expr().first->inputs()))); | ||||
|         return (Express::Variable::create(Express::Expr::create(op.get(), originConv->expr().first->inputs()))); | ||||
|     } | ||||
|     std::vector<Attr> attrs; | ||||
| private: | ||||
|  |  | |||
|  | @ -567,7 +567,8 @@ void ConvolutionCommon::getConvParameters(std::shared_ptr<Int8Common> *quanCommo | |||
|     *originWeight = nullptr; | ||||
|     *originWeightSize = 0; | ||||
|     if (nullptr != conv2d->quanParameter()) { | ||||
|         *quanCommon = load(conv2d->quanParameter(), false); | ||||
|         bool forceFloat = conv2d->quanParameter()->index() != nullptr; | ||||
|         *quanCommon = load(conv2d->quanParameter(), forceFloat); | ||||
|         *originWeight     = (*quanCommon)->weightFloat.get(); | ||||
|         *originWeightSize = (*quanCommon)->weightFloat.size(); | ||||
|     } | ||||
|  |  | |||
|  | @ -581,7 +581,9 @@ static ErrorCode _createExecutions(Schedule::PipelineInfo& mInfo) { | |||
|                 // Try Backup
 | ||||
|                 iter.execution.reset(mBackupBackend->onCreate(iter.inputs, iter.outputs, iter.op)); | ||||
|                 if (nullptr == iter.execution) { | ||||
|                     MNN_ERROR("Create exection error : %d\n", iter.op->type()); | ||||
|                     if (mInfo.first.reportError) { | ||||
|                         MNN_ERROR("Create execution error : %d\n", iter.op->type()); | ||||
|                     } | ||||
|                     return NOT_SUPPORT; | ||||
|                 } | ||||
|             } | ||||
|  |  | |||
|  | @ -61,6 +61,7 @@ public: | |||
|         std::pair<std::shared_ptr<Backend>, std::shared_ptr<Backend>> cache; | ||||
|         bool needComputeShape = true; | ||||
|         bool needComputeGeometry = true; | ||||
|         bool reportError = true; | ||||
|         std::map<Tensor*, TENSORCACHE> inputTensorCopyCache; | ||||
|     }; | ||||
|     typedef std::pair<BackendCache, std::vector<OpCacheInfo>> PipelineInfo; | ||||
|  |  | |||
|  | @ -852,7 +852,7 @@ public: | |||
|             return true; | ||||
|         }; | ||||
|         auto y = _mobileNetV1Expr(); | ||||
|         bool res = func(y, 60.0f); | ||||
|         bool res = func(y, 62.0f); | ||||
|         if (!res) { | ||||
|             return false; | ||||
|         } | ||||
|  |  | |||
|  | @ -22,7 +22,7 @@ protected: | |||
|     template<typename Tin, typename Tout> | ||||
|     bool test(VARP (*opFunc)(VARP, VARP), string name, float threshold, | ||||
|               const vector<Tin>& data_x, const vector<Tin>& data_y, const vector<Tout>& data_out, | ||||
|               const vector<int>& shape_x, const vector<int>& shape_y, const vector<int>& shape_out, const vector<float> quantScales={-100, -100, -100}, const vector<float> zeroPoints={-100, -100, -100}) { | ||||
|               const vector<int>& shape_x, const vector<int>& shape_y, const vector<int>& shape_out, const vector<float> quantScales={-100.f, -100.f, -100.f}, const vector<float> zeroPoints={-100.f, -100.f, -100.f}, Dimensionformat format = NCHW) { | ||||
|         int size_x = 1, size_y = 1, size_out = 1; | ||||
|         for (int i = 0; i < shape_x.size(); ++i) { | ||||
|             size_x *= shape_x[i]; | ||||
|  | @ -34,8 +34,8 @@ protected: | |||
|             size_out *= shape_out[i]; | ||||
|         } | ||||
| 
 | ||||
|         auto input_x = _Input(shape_x, NCHW, halide_type_of<Tin>()); | ||||
|         auto input_y = _Input(shape_y, NCHW, halide_type_of<Tin>()); | ||||
|         auto input_x = _Input(shape_x, format, halide_type_of<Tin>()); | ||||
|         auto input_y = _Input(shape_y, format, halide_type_of<Tin>()); | ||||
|         input_x->setName("input_x"); | ||||
|         input_y->setName("input_y"); | ||||
|         if (quantScales[0] != -100) { // -100 means invalid scale.
 | ||||
|  | @ -593,6 +593,42 @@ public: | |||
|     } | ||||
| }; | ||||
| 
 | ||||
| class AddC4Test : public BinaryTestCommon { | ||||
| public: | ||||
|     virtual ~AddC4Test() = default; | ||||
|     virtual bool run(int precision) { | ||||
|         { | ||||
|             vector<float> inp2 = {1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6}, inp1 = {2}; | ||||
|             vector<float> rightResult = {3.1, 4.2, 5.3, 6.6,3.1, 4.2, 5.3, 6.6,3.1, 4.2, 5.3, 6.6,3.1, 4.2, 5.3, 6.6,3.1, 4.2, 5.3, 6.6, 3.1, 4.2, 5.3, 6.6, 3.1, 4.2, 5.3, 6.6, 3.1, 4.2, 5.3, 6.6}; | ||||
|             bool res = test<float, float>(MNN::Express::_Add, "AddInt8C4Test", 0.01, inp1, inp2, rightResult, {1, 1, 1, 1}, {1, 32, 1, 1}, {1, 32, 1, 1}, {0.4, 0.4, 1.0}, | ||||
|                                       {1., 2., 3.}, NC4HW4); | ||||
|             if (!res) { | ||||
|                 FUNC_PRINT(1); | ||||
|                 return false; | ||||
|             } | ||||
|         } | ||||
|         std::vector<float> i1 = { | ||||
|             -1.0, -2.0, 0.f, 0.f | ||||
|             -3.0, -4.0, 0.f, 0.f | ||||
|             -5.0, -6.0, 0.f, 0.f | ||||
|             -7.0, -8.0, 0.f, 0.f | ||||
|         }; | ||||
|         std::vector<float> i0 = { | ||||
|             1.0f, 0.0f, 0.f, 0.f | ||||
|         }; | ||||
|         std::vector<float> i2 = { | ||||
|             0.0, -1.0, 0.f, 0.f | ||||
|             -2.0, -3.0, 0.f, 0.f | ||||
|             -4.0, -5.0, 0.f, 0.f | ||||
|             -6.0, -7.0, 0.f, 0.f | ||||
|         }; | ||||
|         return test<float, float>(MNN::Express::_BiasAdd, "AddC4FloatTest", 0.01, | ||||
|                     i0, i1, i2, | ||||
|                                   {1, 1, 1, 1}, {4, 2, 1, 1}, {4, 2, 1, 1}, {-100.f, -100.f, -100.f}, {-100.f, -100.f, -100.f}, NC4HW4); | ||||
| 
 | ||||
|     } | ||||
| }; | ||||
| 
 | ||||
| // Float32 OpTest.
 | ||||
| MNNTestSuiteRegister(BinaryBroadcastShapeTest, "op/binary/broadcastShapeTest"); | ||||
| MNNTestSuiteRegister(AddTest, "op/binary/add"); | ||||
|  | @ -633,3 +669,5 @@ MNNTestSuiteRegister(FloorDivInt8Test, "op/binary/floordivInt8"); | |||
| MNNTestSuiteRegister(FloorModInt8Test, "op/binary/floormodInt8"); | ||||
| MNNTestSuiteRegister(Atan2Int8Test, "op/binary/atan2Int8"); | ||||
| MNNTestSuiteRegister(SquaredDifferenceInt8Test, "op/binary/sqdInt8"); | ||||
| 
 | ||||
| MNNTestSuiteRegister(AddC4Test, "op/binary/addC4"); | ||||
|  |  | |||
|  | @ -0,0 +1,225 @@ | |||
| //
 | ||||
| //  TopKV2Execution.hpp
 | ||||
| //  MNN
 | ||||
| //
 | ||||
| //  Created by MNN on 2023/07/19.
 | ||||
| //  Copyright © 2018, Alibaba Group Holding Limited
 | ||||
| //
 | ||||
| 
 | ||||
| #include <MNN/expr/Expr.hpp> | ||||
| #include <MNN/expr/ExprCreator.hpp> | ||||
| #include "MNNTestSuite.h" | ||||
| #include "TestUtils.h" | ||||
| #include <random> | ||||
| #include <vector> | ||||
| 
 | ||||
| using namespace MNN::Express; | ||||
| 
 | ||||
| 
 | ||||
| template<typename valueT, typename indexT> | ||||
| void MinHeapify(valueT * arr, indexT * index, int size, int i) { | ||||
|     int l = 2 * i + 1; | ||||
|     int r = 2 * i + 2; | ||||
|     int smallest = i; | ||||
|     if (l < size && arr[l] < arr[smallest]) { | ||||
|         smallest = l; | ||||
|     } | ||||
|     if (r < size && arr[r] < arr[smallest]) { | ||||
|         smallest = r; | ||||
|     } | ||||
|     if (smallest != i) { | ||||
|         std::swap(arr[i], arr[smallest]); | ||||
|         std::swap(index[i], index[smallest]); | ||||
|         MinHeapify<valueT, indexT>(arr, index, size, smallest); | ||||
|     } | ||||
| 
 | ||||
|     return; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| template<typename valueT, typename indexT> | ||||
| void BuildMinHeap(valueT * arr, indexT * index, int size) { | ||||
|     for (int i = size / 2 - 1; i >= 0; i--) { | ||||
|         MinHeapify<valueT, indexT>(arr, index, size, i); | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| template<typename valueT, typename indexT> | ||||
| void Sort(valueT * values, indexT * indices, const int num) { | ||||
|     valueT * _values = static_cast<valueT *>(values); | ||||
|     indexT * _indices = static_cast<indexT *>(indices); | ||||
|     for (int i = 0; i < num - 1; i++) { | ||||
|         for (int j = 0; j < num - i - 1; j++) { | ||||
|             if (_values[j] < _values[j + 1]) { | ||||
|                 std::swap(_values[j], _values[j + 1]); | ||||
|                 std::swap(_indices[j], _indices[j + 1]); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     return; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| template<typename valueT, typename indexT> | ||||
| void CpuKernelOneRow(const valueT * input, indexT * outputIndices, valueT * outputValues, const int K, const int length) { | ||||
|     for (int i = 0; i < K; i++) { | ||||
|         outputIndices[i] = i; | ||||
|         outputValues[i] = input[i]; | ||||
|     } | ||||
|     BuildMinHeap<valueT, indexT>(outputValues, outputIndices, K); | ||||
|     for (int i = K; i < length; i++) { | ||||
|         if (input[i] > outputValues[0]) { | ||||
|             outputValues[0] = input[i]; | ||||
|             outputIndices[0] = i; | ||||
|             MinHeapify<valueT, indexT>(outputValues, outputIndices, K, 0); | ||||
|         } | ||||
|     } | ||||
|     Sort<valueT, indexT>(outputValues, outputIndices, K); | ||||
| 
 | ||||
|     return; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| template<typename indexT, typename valueT> | ||||
| void CpuKernelAllRows(valueT * input, indexT * outputIndices, valueT * outputValues, const int K, const int lengthRow, const int numRow, int descendFlag) { | ||||
|     for (int i = 0; i < lengthRow * numRow; i++) { | ||||
|         input[i] = input[i] * descendFlag; | ||||
|     } | ||||
| 
 | ||||
|     for (int i = 0; i < numRow; i++) { | ||||
|         const valueT * inputThisRow = input + lengthRow * i; | ||||
|         indexT * outputIndicesThisRow = outputIndices + K * i; | ||||
|         valueT * outputValuesThisRow = outputValues + K * i; | ||||
|         CpuKernelOneRow(inputThisRow, outputIndicesThisRow, outputValuesThisRow, K, lengthRow); | ||||
|     } | ||||
| 
 | ||||
|     for (int i = 0; i < lengthRow * numRow; i++) { | ||||
|         input[i] = input[i] * descendFlag; | ||||
|     } | ||||
| 
 | ||||
|     for (int i = 0 ; i < numRow * K; i++) { | ||||
|         outputValues[i] = outputValues[i] * descendFlag; | ||||
|     } | ||||
| 
 | ||||
|     return; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| void RandomInitFloat(float * array, const int & numEle) { | ||||
|     std::mt19937 rng(4); | ||||
|     std::uniform_real_distribution<float> dist(0.0, 1.0); | ||||
| 
 | ||||
|     for (int i = 0; i < numEle; i++) { | ||||
|         array[i] = dist(rng); | ||||
|     } | ||||
|     return; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| void SetK(int * valuePtr, const int K) { | ||||
|     *valuePtr = K; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| bool checkIndicesHalf(const float * input, const float * expectedOutput0, const int * gotOutput1, const int K, const int numRow, const int lengthRow) { | ||||
|     for (int i = 0; i < numRow; i++) { | ||||
|         for (int j = 0; j < K; j++) { | ||||
|             bool condition = (convertFP32ToFP16(expectedOutput0[i * K + j]) != convertFP32ToFP16(input[gotOutput1[i * K + j] + i * lengthRow])); | ||||
|             if (condition) { | ||||
|                     MNN_PRINT("Conflict: Number %d. Value Correct is %f. Value Computed is %f.\n", i * K + j, convertFP32ToFP16(expectedOutput0[i * K + j]), convertFP32ToFP16(input[gotOutput1[i * K + j] + i * lengthRow])); | ||||
|                     return false; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     return true; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| bool checkIndicesFloat(const float * input, const float * expectedOutput0, const int * gotOutput1, const int K, const int numRow, const int lengthRow) { | ||||
|     for (int i = 0; i < numRow; i++) { | ||||
|         for (int j = 0; j < K; j++) { | ||||
|             bool condition = (expectedOutput0[i * K + j] != input[gotOutput1[i * K + j] + i * lengthRow]); | ||||
|             if (condition) { | ||||
|                     MNN_PRINT("Conflict: Number %d. Value Correct is %f. Value Computed is %f.\n", i * K + j, expectedOutput0[i * K + j], input[gotOutput1[i * K + j] + i * lengthRow]); | ||||
|                     return false; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     return true; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| void printTimeCost(uint64_t timeCost) { | ||||
|     uint64_t seconds = timeCost / 1000000; | ||||
|     uint64_t microseconds = timeCost % 1000000; | ||||
|     MNN_PRINT("%lu s %lu ms\n", seconds, microseconds / 1000); | ||||
| 
 | ||||
|     return; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| class TopKV2Test : public MNNTestCase { | ||||
| public: | ||||
|     virtual ~TopKV2Test() = default; | ||||
| 
 | ||||
|     virtual bool run(int precision) { | ||||
|         // set params
 | ||||
|         const int K = 10; | ||||
|         const int numRow = 180; | ||||
|          | ||||
|         const int lengthRow = 21491; | ||||
| 
 | ||||
|         // set input
 | ||||
|         VARP input0 = _Input({numRow, lengthRow}, NCHW, halide_type_of<float>()); | ||||
|         VARP input1 = _Input({1}, NCHW, halide_type_of<int>()); | ||||
|         RandomInitFloat(input0->writeMap<float>(), numRow * lengthRow); | ||||
|         SetK(input1->writeMap<int>(), K); | ||||
| 
 | ||||
|         auto timeStart = getTimeInUs(); | ||||
|         // calculate gotOutput
 | ||||
|         auto res = _TopKV2(input0, input1); | ||||
|         VARP output0 = res[0]; | ||||
|         VARP output1 = res[1]; | ||||
|         auto gotOutput0                        = output0->readMap<float>(); | ||||
|         auto gotOutput1                        = output1->readMap<int>(); | ||||
|         auto timeEnd = getTimeInUs(); | ||||
|         auto timeCost = timeEnd - timeStart; | ||||
| 
 | ||||
|         // calculate expectedOutput
 | ||||
|         std::vector<float> expectedOutput0(numRow * K); | ||||
|         std::vector<int> expectedOutput1(numRow * K); | ||||
|         CpuKernelAllRows<int, float>(input0->writeMap<float>(), expectedOutput1.data(), expectedOutput0.data(), K, lengthRow, numRow, 1); | ||||
| 
 | ||||
|         printTimeCost(timeCost); | ||||
| 
 | ||||
|         // check values
 | ||||
|         float errorScale = precision <= MNN::BackendConfig::Precision_High ? 1 : 20; | ||||
|         if (!checkVectorByRelativeError<float>(gotOutput0, expectedOutput0.data(), numRow * K, 0.001 * errorScale)) { | ||||
|             MNN_ERROR("TopKV2 test failed!\n"); | ||||
|             return false; | ||||
|         } | ||||
| 
 | ||||
|         // check indices
 | ||||
|         if (precision <= 1) { | ||||
|             if (!checkIndicesFloat(input0->readMap<float>(), expectedOutput0.data(), gotOutput1, K, numRow, lengthRow)) { | ||||
|                 MNN_ERROR("TopKV2 test failed!\n"); | ||||
|                 return false; | ||||
|             } | ||||
|         } else if (precision == 2) { | ||||
|             if (!checkIndicesHalf(input0->readMap<float>(), expectedOutput0.data(), gotOutput1, K, numRow, lengthRow)) { | ||||
|                 MNN_ERROR("TopKV2 test failed!\n"); | ||||
|                 return false; | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         return true; | ||||
|     } | ||||
| 
 | ||||
| }; | ||||
| 
 | ||||
| 
 | ||||
| MNNTestSuiteRegister(TopKV2Test, "op/TopKV2"); | ||||
|  | @ -158,7 +158,7 @@ int writeFb(std::unique_ptr<MNN::NetT>& netT, const std::string& MNNModelFile, c | |||
|         output.write((const char*)bufferOutput, sizeOutput); | ||||
|     } | ||||
|     if (!netT->subgraphs.empty()) { | ||||
|         MNN_PRINT("The model has subgraphs, please use MNN::Module to run it\n"); | ||||
|         MNN_PRINT("The model has subgraphs, please use MNN::Express::Module to run it\n"); | ||||
|     } | ||||
| 
 | ||||
| #ifdef MNN_DUMP_SUBGRAPH | ||||
|  |  | |||
|  | @ -21,58 +21,6 @@ | |||
| #include "onnxConverter.hpp" | ||||
| #include "onnxOpConverter.hpp" | ||||
| 
 | ||||
| std::vector<int> topoSort(const ::onnx::GraphProto& onnxGraph) { | ||||
|     std::vector<int> idxMap; | ||||
|     const int nodeCount   = onnxGraph.node_size(); | ||||
|     std::map<std::string, int> outputMap; | ||||
|     std::map<int, std::vector<int>> graph; // key --[in]--> values
 | ||||
|     std::vector<int> inDegree(nodeCount); | ||||
|     // build Graph and inDegree
 | ||||
|     for (int i = 0; i < nodeCount; ++i) { | ||||
|         const auto& onnxNode = onnxGraph.node(i); | ||||
|         if (onnxNode.op_type() == "Loop" || onnxNode.op_type() == "If") { | ||||
|             return idxMap; | ||||
|         } | ||||
|         for (int k = 0; k < onnxNode.output_size(); k++) { | ||||
|             outputMap.insert(std::make_pair(onnxNode.output(k), i)); | ||||
|         } | ||||
|     } | ||||
|     for (int i = 0; i < nodeCount; ++i) { | ||||
|         const auto& onnxNode = onnxGraph.node(i); | ||||
|         for (int k = 0; k < onnxNode.input_size(); k++) { | ||||
|             auto inputName = onnxNode.input(k); | ||||
|             auto iter = outputMap.find(inputName); | ||||
|             if (iter != outputMap.end()) { | ||||
|                 graph[iter->second].push_back(i); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     for (auto node : graph) { | ||||
|         for (auto output : node.second) { | ||||
|             inDegree[output]++; | ||||
|         } | ||||
|     } | ||||
|     // topo sort
 | ||||
|     std::queue<int> validNode; | ||||
|     for (int i = 0; i < nodeCount; i++) { | ||||
|         if (!inDegree[i]) { | ||||
|             validNode.push(i); | ||||
|         } | ||||
|     } | ||||
|     while (!validNode.empty()) { | ||||
|         int node = validNode.front(); | ||||
|         validNode.pop(); | ||||
|         idxMap.push_back(node); | ||||
|         for (auto succ : graph[node]) { | ||||
|             if (--inDegree[succ] == 0) { | ||||
|                 validNode.push(succ); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     MNN_ASSERT(idxMap.size() == nodeCount); | ||||
|     return idxMap; | ||||
| } | ||||
| 
 | ||||
| int onnx2MNNNet(const std::string inputModel, const std::string bizCode, | ||||
|                 std::unique_ptr<MNN::NetT>& netT) { | ||||
|     std::string modelDir; | ||||
|  | @ -117,6 +65,7 @@ int onnx2MNNNet(const std::string inputModel, const std::string bizCode, | |||
|             MNNOp->main.type = MNN::OpParameter_Input; | ||||
|             auto inputParam  = new MNN::InputT; | ||||
|             const auto it    = inputs.find(iter.first); | ||||
|             //FUNC_PRINT_ALL(iter.first.c_str(), s);
 | ||||
|             DCHECK(it != inputs.end()) << "Input Paramter ERROR ==> " << iter.first; | ||||
|             const auto& tensorInfo = (it->second)->type().tensor_type(); | ||||
|             const int inputDimSize = tensorInfo.shape().dim_size(); | ||||
|  | @ -138,8 +87,24 @@ int onnx2MNNNet(const std::string inputModel, const std::string bizCode, | |||
|     } | ||||
| 
 | ||||
|     // onnx model not all topo sort graph, sort it
 | ||||
|     std::vector<int> idxMap = topoSort(onnxGraph); | ||||
|     std::vector<int> idxMap = OnnxScope::topoSort(onnxGraph); | ||||
| 
 | ||||
|     auto makeConst = [&](const std::string& inputName) { | ||||
|         const auto it         = initializers.find(inputName); | ||||
|         if (it != initializers.end() && scope->lookupTensor(it->first) == -1) { | ||||
|             // Create const Op
 | ||||
|             MNN::OpT* constOp   = new MNN::OpT; | ||||
|             constOp->type       = MNN::OpType_Const; | ||||
|             constOp->main.type  = MNN::OpParameter_Blob; | ||||
|             constOp->main.value = onnxOpConverter::convertTensorToBlob(it->second, modelDir); | ||||
|             constOp->name    = it->first; | ||||
|             constOp->outputIndexes.push_back(scope->declareTensor(it->first)); | ||||
|             netT->oplists.emplace_back(constOp); | ||||
|         } | ||||
|     }; | ||||
|     for (int i=0; i<onnxGraph.output_size(); ++i) { | ||||
|         makeConst(onnxGraph.output(i).name()); | ||||
|     } | ||||
|     // onnx node ==> MNN node
 | ||||
|     for (int idx = 0; idx < nodeCount; ++idx) { | ||||
|         int i = idxMap.size() == nodeCount ? idxMap[idx] : idx; | ||||
|  | @ -158,17 +123,7 @@ int onnx2MNNNet(const std::string inputModel, const std::string bizCode, | |||
|         // convert initializer to be Constant node(op)
 | ||||
|         for (int k = 0; k < onnxNode.input_size(); ++k) { | ||||
|             const auto& inputName = onnxNode.input(k); | ||||
|             const auto it         = initializers.find(inputName); | ||||
|             if (it != initializers.end() && scope->lookupTensor(it->first) == -1) { | ||||
|                 // Create const Op
 | ||||
|                 MNN::OpT* constOp   = new MNN::OpT; | ||||
|                 constOp->type       = MNN::OpType_Const; | ||||
|                 constOp->main.type  = MNN::OpParameter_Blob; | ||||
|                 constOp->main.value = onnxOpConverter::convertTensorToBlob(it->second, modelDir); | ||||
|                 constOp->name    = it->first; | ||||
|                 constOp->outputIndexes.push_back(scope->declareTensor(it->first)); | ||||
|                 netT->oplists.emplace_back(constOp); | ||||
|             } | ||||
|             makeConst(inputName); | ||||
|         } | ||||
| 
 | ||||
|         // build input and output
 | ||||
|  |  | |||
|  | @ -6,6 +6,7 @@ | |||
| //  Copyright © 2018, Alibaba Group Holding Limited
 | ||||
| //
 | ||||
| 
 | ||||
| #include <queue> | ||||
| #include "onnxOpConverter.hpp" | ||||
| #include "OpCount.hpp" | ||||
| #include "OnnxTmpGraph.hpp" | ||||
|  | @ -20,6 +21,67 @@ static int32_t _limit(int64_t i64) { | |||
|     } | ||||
|     return i64; | ||||
| } | ||||
| std::vector<int> OnnxScope::topoSort(const onnx::GraphProto& onnxGraph) { | ||||
|     std::vector<int> idxMap; | ||||
|     const int nodeCount   = onnxGraph.node_size(); | ||||
|     std::map<std::string, int> outputMap; | ||||
|     std::map<int, std::vector<int>> graph; // key --[in]--> values
 | ||||
|     std::vector<int> inDegree(nodeCount); | ||||
|     // build Graph and inDegree
 | ||||
|     for (int i = 0; i < nodeCount; ++i) { | ||||
|         const auto& onnxNode = onnxGraph.node(i); | ||||
|         for (int k = 0; k < onnxNode.output_size(); k++) { | ||||
|             outputMap.insert(std::make_pair(onnxNode.output(k), i)); | ||||
|         } | ||||
|     } | ||||
|     for (int i = 0; i < nodeCount; ++i) { | ||||
|         const auto& onnxNode = onnxGraph.node(i); | ||||
|         for (int k = 0; k < onnxNode.input_size(); k++) { | ||||
|             auto inputName = onnxNode.input(k); | ||||
|             auto iter = outputMap.find(inputName); | ||||
|             if (iter != outputMap.end()) { | ||||
|                 graph[iter->second].push_back(i); | ||||
|             } | ||||
|         } | ||||
|         if (onnxNode.op_type() == "Loop") { | ||||
|             auto& body = onnxNode.attribute(0).g(); | ||||
|             for (int j=0; j<body.node_size(); ++j) { | ||||
|                 for (int k=0; k<body.node(j).input_size(); ++k) { | ||||
|                     auto inputName = body.node(j).input(k); | ||||
|                     auto iter = outputMap.find(inputName); | ||||
|                     if (iter != outputMap.end()) { | ||||
|                         graph[iter->second].push_back(i); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     for (auto node : graph) { | ||||
|         for (auto output : node.second) { | ||||
|             inDegree[output]++; | ||||
|         } | ||||
|     } | ||||
|     // topo sort
 | ||||
|     std::queue<int> validNode; | ||||
|     for (int i = 0; i < nodeCount; i++) { | ||||
|         if (!inDegree[i]) { | ||||
|             validNode.push(i); | ||||
|         } | ||||
|     } | ||||
|     while (!validNode.empty()) { | ||||
|         int node = validNode.front(); | ||||
|         validNode.pop(); | ||||
|         idxMap.push_back(node); | ||||
|         for (auto succ : graph[node]) { | ||||
|             if (--inDegree[succ] == 0) { | ||||
|                 validNode.push(succ); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     MNN_ASSERT(idxMap.size() == nodeCount); | ||||
|     return idxMap; | ||||
| } | ||||
| 
 | ||||
| class DefaultonnxOpConverter : public onnxOpConverter { | ||||
| public: | ||||
|     virtual void run(MNN::OpT* dstOp, const onnx::NodeProto* onnxNode, | ||||
|  | @ -107,6 +169,7 @@ onnxOpConverter* onnxOpConverterSuit::search(const std::string& name) { | |||
| MNN::DataType onnxOpConverter::convertDataType(int32_t itype) { | ||||
|     static std::map<::onnx::TensorProto_DataType, MNN::DataType> dataTypeMap{ | ||||
|         {onnx::TensorProto_DataType_FLOAT, MNN::DataType_DT_FLOAT}, | ||||
|         {onnx::TensorProto_DataType_FLOAT16, MNN::DataType_DT_HALF}, | ||||
|         {onnx::TensorProto_DataType_INT8, MNN::DataType_DT_INT8}, | ||||
|         {onnx::TensorProto_DataType_INT32, MNN::DataType_DT_INT32}, | ||||
|         {onnx::TensorProto_DataType_INT64, MNN::DataType_DT_INT32},  // For compability, use int32 instead of int64
 | ||||
|  | @ -188,6 +251,8 @@ MNN::BlobT* onnxOpConverter::convertTensorToBlob(const onnx::TensorProto* consta | |||
|         CASE_DATA_TYPE(onnx::TensorProto_DataType_DOUBLE, double); | ||||
|         CASE_DATA_TYPE(onnx::TensorProto_DataType_INT64, int64); | ||||
|         CASE_DATA_TYPE(onnx::TensorProto_DataType_INT32, int32); | ||||
|         CASE_DATA_TYPE(onnx::TensorProto_DataType_UINT8, int32); | ||||
|         CASE_DATA_TYPE(onnx::TensorProto_DataType_INT8, int32); | ||||
|         CASE_DATA_TYPE(onnx::TensorProto_DataType_FLOAT, float); | ||||
|         CASE_DATA_TYPE(onnx::TensorProto_DataType_UINT64, uint64); | ||||
|         CASE_DATA_TYPE(onnx::TensorProto_DataType_BOOL, int32); | ||||
|  | @ -265,13 +330,22 @@ MNN::BlobT* onnxOpConverter::convertTensorToBlob(const onnx::TensorProto* consta | |||
|             break; | ||||
|         } | ||||
|         case onnx::TensorProto_DataType_UINT8: { | ||||
|             auto source = (uint8_t*)tensor_content; | ||||
|             constantParam->uint8s.resize(dataSize); | ||||
|             for (int i = 0; i < dataSize; ++i) { | ||||
|                 constantParam->uint8s[i] = source[i]; | ||||
|             if (constantTp->int32_data_size() > 0) { | ||||
|                 auto source = (int32_t*)tensor_content; | ||||
|                 for (int i = 0; i < dataSize; ++i) { | ||||
|                     constantParam->uint8s[i] = source[i]; | ||||
|                 } | ||||
|             } else { | ||||
|                 ::memcpy(constantParam->uint8s.data(), tensor_content, dataSize * sizeof(uint8_t)); | ||||
|             } | ||||
|             break; | ||||
|         } | ||||
|         case onnx::TensorProto_DataType_FLOAT16: { | ||||
|             constantParam->uint8s.resize(dataSize * sizeof(int16_t)); | ||||
|             ::memcpy(constantParam->uint8s.data(), tensor_content, dataSize * sizeof(int16_t)); | ||||
|             break; | ||||
|         } | ||||
|         case onnx::TensorProto_DataType_FLOAT: { | ||||
|             float* tempFloatData = (float*)tensor_content; | ||||
|             constantParam->float32s.resize(dataSize); | ||||
|  | @ -411,7 +485,33 @@ std::vector<std::string> OnnxScope::buildSubGraph(const onnx::GraphProto* graph, | |||
|     } | ||||
|     // Find Extra Input from outside graph
 | ||||
|     std::map<std::string, int> outsideInputs; | ||||
|     for (int i = 0; i < graph->node_size(); i++) { | ||||
|     auto findConst = [&](const std::string& name) { | ||||
|         if (scope->lookupTensor(name) >= 0) { | ||||
|             return; | ||||
|         } | ||||
|         // onnx subgraph may use tensor from initializers in outter level graph, recurrsive find it
 | ||||
|         for (auto curScope = scope.get(); curScope != nullptr; ) { | ||||
|             const auto& curInits = curScope->mInitializers; | ||||
|             const auto it = curInits.find(name); | ||||
|             if (it != curInits.end()) { | ||||
|                 // Create const Op
 | ||||
|                 MNN::OpT* constOp   = new MNN::OpT; | ||||
|                 constOp->type       = MNN::OpType_Const; | ||||
|                 constOp->main.type  = MNN::OpParameter_Blob; | ||||
|                 constOp->main.value = onnxOpConverter::convertTensorToBlob(it->second); | ||||
|                 constOp->name    = it->first; | ||||
|                 constOp->outputIndexes.push_back(scope->declareTensor(it->first)); | ||||
|                 subgraph->nodes.emplace_back(constOp); | ||||
|                 break; | ||||
|             } | ||||
|             curScope = reinterpret_cast<decltype(curScope)>(curScope->mParent); | ||||
|         } | ||||
|     }; | ||||
|     for (int i=0; i<graph->output_size(); ++i) { | ||||
|         findConst(graph->output(i).name()); | ||||
|     } | ||||
|     auto indexes = OnnxScope::topoSort(*graph); | ||||
|     for (auto i : indexes) { | ||||
|         const auto& onnxNode = graph->node(i); | ||||
|         const auto& opType   = onnxNode.op_type(); | ||||
|         // name maybe null, use the first output name as node-name
 | ||||
|  | @ -423,26 +523,7 @@ std::vector<std::string> OnnxScope::buildSubGraph(const onnx::GraphProto* graph, | |||
|         MNNOp->main.type = opConverter->type(); | ||||
|         for (int k = 0; k < onnxNode.input_size(); ++k) { | ||||
|             const auto& inputName = onnxNode.input(k); | ||||
|             if (scope->lookupTensor(inputName) >= 0) { | ||||
|                 continue; | ||||
|             } | ||||
|             // onnx subgraph may use tensor from initializers in outter level graph, recurrsive find it
 | ||||
|             for (auto curScope = scope.get(); curScope != nullptr; ) { | ||||
|                 const auto& curInits = curScope->mInitializers; | ||||
|                 const auto it = curInits.find(inputName); | ||||
|                 if (it != curInits.end()) { | ||||
|                     // Create const Op
 | ||||
|                     MNN::OpT* constOp   = new MNN::OpT; | ||||
|                     constOp->type       = MNN::OpType_Const; | ||||
|                     constOp->main.type  = MNN::OpParameter_Blob; | ||||
|                     constOp->main.value = onnxOpConverter::convertTensorToBlob(it->second); | ||||
|                     constOp->name    = it->first; | ||||
|                     constOp->outputIndexes.push_back(scope->declareTensor(it->first)); | ||||
|                     subgraph->nodes.emplace_back(constOp); | ||||
|                     break; | ||||
|                 } | ||||
|                 curScope = reinterpret_cast<decltype(curScope)>(curScope->mParent); | ||||
|             } | ||||
|             findConst(inputName); | ||||
|         } | ||||
|         // build input and output
 | ||||
|         for (int k = 0; k < onnxNode.input_size(); k++) { | ||||
|  | @ -453,6 +534,7 @@ std::vector<std::string> OnnxScope::buildSubGraph(const onnx::GraphProto* graph, | |||
|                 if (iter == outsideInputs.end()) { | ||||
|                     idx = scope->declareTensor(inputName); | ||||
|                     std::unique_ptr<MNN::OpT> inputOp(new MNN::OpT); | ||||
|                     //FUNC_PRINT_ALL(inputName.c_str(), s);
 | ||||
|                     inputOp->name      = inputName; | ||||
|                     inputOp->type      = MNN::OpType_Input; | ||||
|                     inputOp->main.type = MNN::OpParameter_Input; | ||||
|  | @ -501,9 +583,10 @@ std::vector<std::string> OnnxScope::buildSubGraph(const onnx::GraphProto* graph, | |||
|     int N = graph->input_size() - 2, K = graph->output_size() - N - 1; | ||||
|     for (int i = 0; i < N + 1; i++) { | ||||
|         int idx = scope->lookupTensor(graph->output(i).name()); | ||||
|         MNN_ASSERT(idx >= 0); | ||||
|         if (idx >= 0) { | ||||
|             subgraph->outputs.push_back(idx); | ||||
|         } else { | ||||
|             FUNC_PRINT_ALL(graph->output(i).name().c_str(), s); | ||||
|         } | ||||
|     } | ||||
|     std::vector<std::string> resOutside; | ||||
|  |  | |||
|  | @ -19,6 +19,7 @@ | |||
| 
 | ||||
| class OnnxScope : public ConverterScope { | ||||
| public: | ||||
|     static std::vector<int> topoSort(const onnx::GraphProto& onnxGraph); | ||||
|     OnnxScope(const onnx::GraphProto* graph, MNN::NetT* net) : mGraph(graph), ConverterScope(net) { onnxInit(); } | ||||
|     OnnxScope(const onnx::GraphProto* graph, MNN::SubGraphProtoT* subnet, MNN::NetT* net, | ||||
|               OnnxScope* parent) : mGraph(graph), ConverterScope(subnet, net, parent) { onnxInit(); } | ||||
|  |  | |||
|  | @ -342,7 +342,12 @@ static auto gRegister = []() { | |||
|                         continue; | ||||
|                     } | ||||
|                     auto inputVar = inputVarIter->second; | ||||
|                     auto newVar = _Gelu(inputVar); | ||||
|                     std::unique_ptr<MNN::OpT> newUnary(new MNN::OpT); | ||||
|                     newUnary->type = OpType_UnaryOp; | ||||
|                     newUnary->main.type = OpParameter_UnaryOp; | ||||
|                     newUnary->main.value = new UnaryOpT; | ||||
|                     newUnary->main.AsUnaryOp()->opType = UnaryOpOperation_GELU_STANDARD; | ||||
|                     auto newVar = MNN::Express::Variable::create(MNN::Express::Expr::create(newUnary.get(), {inputVar})); | ||||
|                     newVar->setName(expr->outputName(0)); | ||||
|                     Expr::replace(expr, newVar->expr().first); | ||||
|                     return true; | ||||
|  |  | |||
|  | @ -0,0 +1,52 @@ | |||
| //
 | ||||
| //  DynamicQuantizeLinear.cpp
 | ||||
| //  MNNConverter
 | ||||
| //
 | ||||
| //  Created by MNN on 2023/08/14.
 | ||||
| //  Copyright © 2018, Alibaba Group Holding Limited
 | ||||
| //
 | ||||
| 
 | ||||
| #include <MNN/expr/ExprCreator.hpp> | ||||
| #include "MNN_generated.h" | ||||
| #include "OnnxExtraManager.hpp" | ||||
| 
 | ||||
| namespace MNN { | ||||
| namespace Express { | ||||
| // Ref from https://github.com/onnx/onnx/blob/main/docs/Operators.md#DynamicQuantizeLinear
 | ||||
| class OnnxDynamicQuantizeLinearTransform : public OnnxExtraManager::Transform { | ||||
| public: | ||||
|     virtual EXPRP onExecute(EXPRP expr) const override { | ||||
|         auto x   = expr->inputs()[0]; | ||||
|         auto range = _Scalar<float>(1.0f/255.0f); | ||||
|         auto maxX = _ReduceMax(x); | ||||
|         auto minX = _ReduceMin(x); | ||||
|         auto scale = (maxX - minX) * range; | ||||
|         auto scaleReq = _Reciprocal(scale); | ||||
|         // Qmin = 0
 | ||||
|         auto interZero = _Negative(minX * scaleReq); | ||||
|         auto zeroFloat = _Round(_Relu6(interZero, 0.0f, 255.0f)); | ||||
|         auto zero = _Cast<uint8_t>(zeroFloat); | ||||
|         auto y = _Cast<uint8_t>(_Round(_Relu6(_Round(x * scaleReq) + zeroFloat, 0.0f, 255.0f))); | ||||
|         std::unique_ptr<MNN::OpT> iden(new MNN::OpT); | ||||
|         iden->type = OpType_Identity; | ||||
|          | ||||
| 
 | ||||
|         auto newExpr = MNN::Express::Expr::create(iden.get(), {y, scale, zero}, 3); | ||||
|         newExpr->setName(expr->name()); | ||||
|         for (int i=0; i<3; ++i) { | ||||
|             auto v = MNN::Express::Variable::create(newExpr, i); | ||||
|             v->setName(expr->outputName(i)); | ||||
|         } | ||||
|         return newExpr; | ||||
|     } | ||||
| }; | ||||
| 
 | ||||
| 
 | ||||
| static auto gRegister = []() { | ||||
|     OnnxExtraManager::get()->insert("DynamicQuantizeLinear", | ||||
|                                     std::shared_ptr<OnnxExtraManager::Transform>(new OnnxDynamicQuantizeLinearTransform)); | ||||
|     return true; | ||||
| }(); | ||||
| 
 | ||||
| } // namespace Express
 | ||||
| } // namespace MNN
 | ||||
|  | @ -0,0 +1,45 @@ | |||
| //
 | ||||
| //  MatMulInteger.cpp
 | ||||
| //  MNNConverter
 | ||||
| //
 | ||||
| //  Created by MNN on 2023/08/14.
 | ||||
| //  Copyright © 2018, Alibaba Group Holding Limited
 | ||||
| //
 | ||||
| 
 | ||||
| #include <MNN/expr/ExprCreator.hpp> | ||||
| #include "MNN_generated.h" | ||||
| #include "OnnxExtraManager.hpp" | ||||
| 
 | ||||
| namespace MNN { | ||||
| namespace Express { | ||||
| // Ref from https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMulInteger
 | ||||
| // Use float instead of uint8 to complete it
 | ||||
| class OnnxMatMulIntegerTransform : public OnnxExtraManager::Transform { | ||||
| public: | ||||
|     virtual EXPRP onExecute(EXPRP expr) const override { | ||||
|         auto inputs = expr->inputs(); | ||||
|         auto x = inputs[0]; | ||||
|         auto y = inputs[1]; | ||||
|         x = _Cast<float>(x); | ||||
|         y = _Cast<float>(y); | ||||
|         if (inputs.size() > 2) { | ||||
|             x = x - _Cast<float>(inputs[2]); | ||||
|             y = y - _Cast<float>(inputs[3]); | ||||
|         } | ||||
|         auto z = _MatMul(x, y); | ||||
|         auto zInt = _Cast<int32_t>(z); | ||||
|         auto newExpr = zInt->expr().first; | ||||
|         newExpr->setName(expr->name()); | ||||
|         return newExpr; | ||||
|     } | ||||
| }; | ||||
| 
 | ||||
| 
 | ||||
| static auto gRegister = []() { | ||||
|     OnnxExtraManager::get()->insert("MatMulInteger", | ||||
|                                     std::shared_ptr<OnnxExtraManager::Transform>(new OnnxMatMulIntegerTransform)); | ||||
|     return true; | ||||
| }(); | ||||
| 
 | ||||
| } // namespace Express
 | ||||
| } // namespace MNN
 | ||||
|  | @ -358,13 +358,15 @@ public: | |||
|         if (config->keepInputFormat) { | ||||
|             // Change Output
 | ||||
|             auto& outputs = mNet->outputName; | ||||
|             std::vector<std::unique_ptr<MNN::OpT>> extraOp; | ||||
|             for (auto& op : mNet->oplists) { | ||||
|                 for (int idx : op->outputIndexes) { | ||||
|                     for (int j = 0; j < outputs.size(); j++) { | ||||
|                         if (mNet->tensorName[idx] == outputs[j]) { | ||||
|                             auto outputFormat = tensorFormats[idx]; | ||||
|                             if (outputFormat == MNN_DATA_FORMAT_NC4HW4) { | ||||
|                                 auto newOutputName = outputs[j] + "__tr"; | ||||
|                                 auto newOutputName = outputs[j] + "__before_tr"; | ||||
|                                 mNet->tensorName[idx] = newOutputName; | ||||
|                                 // Append a convert op
 | ||||
|                                 MNN::OpT* transformOp = new MNN::OpT; | ||||
|                                 MNN::TensorConvertInfoT* tc = new MNN::TensorConvertInfoT; | ||||
|  | @ -374,17 +376,20 @@ public: | |||
|                                 transformOp->main.value     = tc; | ||||
|                                 transformOp->name           = newOutputName; | ||||
|                                 transformOp->inputIndexes.push_back(idx); | ||||
|                                 transformOp->outputIndexes.push_back(mNet->tensorName.size()); | ||||
|                                 int newOutputIndex = (int)mNet->tensorName.size(); | ||||
|                                 transformOp->outputIndexes.push_back(newOutputIndex); | ||||
|                                 tensorFormats.push_back(originTensorType); | ||||
|                                 mNet->tensorName.push_back(transformOp->name); | ||||
|                                 mNet->tensorName.push_back(outputs[j]); | ||||
|                                 transformOp->type   = MNN::OpType_ConvertTensor; | ||||
|                                 outputs[j] = newOutputName; | ||||
|                                 mNet->oplists.emplace_back(transformOp); | ||||
|                                 extraOp.emplace_back(transformOp); | ||||
|                             } | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|             for (auto&& op : extraOp) { | ||||
|                 mNet->oplists.emplace_back(std::move(op)); | ||||
|             } | ||||
|         } else { | ||||
|             // Change Input
 | ||||
|             for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end(); iter++) { | ||||
|  |  | |||
|  | @ -9,6 +9,7 @@ | |||
| #include "cv/imgproc/geometric.hpp" | ||||
| #include <MNN/expr/NeuralNetWorkOp.hpp> | ||||
| #include <MNN/expr/MathOp.hpp> | ||||
| #include <cmath> | ||||
| 
 | ||||
| namespace MNN { | ||||
| namespace CV { | ||||
|  |  | |||
|  | @ -561,7 +561,7 @@ public: | |||
|                     mWinogradAttr.reset(new WinogradInt8Attr); | ||||
|                     mWinogradTransInputMaxPos = addParameter(mWinogradTransInputMax); | ||||
|                     mWinogradTransInputMinPos = addParameter(mWinogradTransInputMin); | ||||
|                     mWinogradTransWeightScalePos = addParameter(nullptr); | ||||
|                     mWinogradTransWeightScalePos = addParameter(mWinogradTransInputMax); | ||||
|                 } | ||||
|                 setName(mConvParameter.name); | ||||
|                 modules[i] = nullptr; | ||||
|  | @ -675,6 +675,10 @@ public: | |||
|         if (nullptr == originValue) { | ||||
|             return newValue; | ||||
|         } | ||||
|         auto ptr = originValue->readMap<float>(); | ||||
|         if (ptr[0] == -100.0f) { | ||||
|             return newValue; | ||||
|         } | ||||
|         switch (mScaleUpdateMethod) { | ||||
|             case NN::MovingAverage: | ||||
|                 return originValue * _Scalar<float>(mMomentum) + newValue * _Scalar<float>(1.0f-mMomentum); | ||||
|  | @ -700,7 +704,7 @@ public: | |||
|         maxUnit = std::max(std::min(maxUnit, MAX_UNIT), MIN_UNIT); | ||||
| 
 | ||||
|         auto units = std::pair<int, int>({0, 0}); | ||||
|         float maxRate = 1.0f, originCost = outH * outW * inChannel * outChannel * kernelH * kernelW; | ||||
|         float maxRate = 2.0f, originCost = outH * outW * inChannel * outChannel * kernelH * kernelW; | ||||
|         std::set<int> supportSu{4, 6}; | ||||
|         for (int uh = MIN_UNIT; uh <= maxUnit; ++uh) { | ||||
|             for (int uw = MIN_UNIT; uw <= maxUnit; ++uw) { | ||||
|  | @ -1097,8 +1101,8 @@ private: | |||
|     NN::ScaleUpdateMethod mScaleUpdateMethod; | ||||
|     bool mAccumulateToInt16 = false; | ||||
|     std::shared_ptr<WinogradInt8Attr> mWinogradAttr; | ||||
|     VARP mWinogradTransInputMin = nullptr; | ||||
|     VARP mWinogradTransInputMax = nullptr; | ||||
|     VARP mWinogradTransInputMin = _Const(-100.f); | ||||
|     VARP mWinogradTransInputMax = _Const(-100.f); | ||||
|     int mWinogradTransInputMinPos = -1; | ||||
|     int mWinogradTransInputMaxPos = -1; | ||||
|     int mWinogradTransWeightScalePos = -1; | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue