mirror of https://github.com/alibaba/MNN.git
[MNN:Sync] Sync Internal 2.6.3
This commit is contained in:
parent: c603c52955
commit: 98ba00c2f3
@@ -271,7 +271,7 @@ struct GemmBatchedIdentityThreadblockSwizzle {
    return GemmCoord(
      (problem_size.m() + tile_size.m() - 1) / tile_size.m(),
      (problem_size.n() + tile_size.n() - 1) / tile_size.n(),
-     batch_count % (1 << 16));
+     batch_count >= 65536 ? 65535 : batch_count);
  }

  /// Computes CUDA grid dimensions given a size in units of logical tiles
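The fix matters because CUDA caps gridDim.z at 65535: the old `batch_count % (1 << 16)` silently wraps (65536 becomes 0, launching nothing), while clamping keeps the launch valid. A minimal Python sketch of the new grid computation, with illustrative tile sizes that are an assumption, not the actual CUTLASS configuration:

```python
TILE_M, TILE_N = 128, 128  # illustrative tile shape, not the real config

def grid_shape(m, n, batch_count):
    # Ceil-divide the problem into tiles; clamp z to CUDA's gridDim.z limit.
    grid_m = (m + TILE_M - 1) // TILE_M
    grid_n = (n + TILE_N - 1) // TILE_N
    grid_z = 65535 if batch_count >= 65536 else batch_count
    return (grid_m, grid_n, grid_z)

assert grid_shape(256, 512, 65536) == (2, 4, 65535)  # old code: 65536 % (1 << 16) == 0
```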
@@ -22,6 +22,9 @@ std::string OpenCLTarget::type() {
 }
 std::string OpenCLTarget::macro() {
     return
+        "#ifdef MNN_SUPPORT_FP16\n"
+        "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
+        "#endif\n"
         "#define OFFSET_CHECK\\\n"
         "\tconst int c = get_global_id(0), w = get_global_id(1), hb = get_global_id(2);\\\n"
         "\tif (c >= global_size_dim0 || w >= global_size_dim1 || hb >= global_size_dim2) { return; }\\\n"
@@ -113,61 +116,61 @@ std::string OpenCLTarget::codegen(std::vector<std::string>& inputs, const Comman
            ss << inpName << "=" << operand << " * " << operand;
            break;
        case UnaryOpOperation_ERF:
-           ss << inpName << "=erf(convert_float4(" << operand << "))";
+           ss << inpName << "=CONVERT_FLOAT4(erf(convert_float4(" << operand << ")))";
            break;
        case UnaryOpOperation_ERFC:
-           ss << inpName << "=erfc(convert_float4(" << operand << "))";
+           ss << inpName << "=CONVERT_FLOAT4(erfc(convert_float4(" << operand << ")))";
            break;
        case UnaryOpOperation_SQRT:
-           ss << inpName << "=sqrt(convert_float4(" << operand << "))";
+           ss << inpName << "=CONVERT_FLOAT4(sqrt(convert_float4(" << operand << ")))";
            break;
        case UnaryOpOperation_RSQRT:
-           ss << inpName << "=rsqrt(convert_float4(" << operand << "))";
+           ss << inpName << "=CONVERT_FLOAT4(rsqrt(convert_float4(" << operand << ")))";
            break;
        case UnaryOpOperation_ABS:
-           ss << inpName << "=fabs(convert_float4(" << operand << "))";
+           ss << inpName << "=CONVERT_FLOAT4(fabs(convert_float4(" << operand << ")))";
            break;
        case UnaryOpOperation_SIN:
-           ss << inpName << "=sin(convert_float4(" << operand << "))";
+           ss << inpName << "=CONVERT_FLOAT4(sin(convert_float4(" << operand << ")))";
            break;
        case UnaryOpOperation_COS:
-           ss << inpName << "=cos(convert_float4(" << operand << "))";
+           ss << inpName << "=CONVERT_FLOAT4(cos(convert_float4(" << operand << ")))";
            break;
        case UnaryOpOperation_SIGN:
-           ss << inpName << "=sign(convert_float4(" << operand << "))";
+           ss << inpName << "=CONVERT_FLOAT4(sign(convert_float4(" << operand << ")))";
            break;
        case UnaryOpOperation_EXP:
-           ss << inpName << "=exp(convert_float4(" << operand << "))";
+           ss << inpName << "=CONVERT_FLOAT4(exp(convert_float4(" << operand << ")))";
            break;
        case UnaryOpOperation_NEG:
            ss << inpName << "=-(" << operand << ")";
            break;
        case UnaryOpOperation_TAN:
-           ss << inpName << "=tan(convert_float4(" << operand << "))";
+           ss << inpName << "=CONVERT_FLOAT4(tan(convert_float4(" << operand << ")))";
            break;
        case UnaryOpOperation_CEIL:
-           ss << inpName << "=ceil(convert_float4(" << operand << "))";
+           ss << inpName << "=CONVERT_FLOAT4(ceil(convert_float4(" << operand << ")))";
            break;
        case UnaryOpOperation_LOG1P:
-           ss << inpName << "=log1p(convert_float4(" << operand << "))";
+           ss << inpName << "=CONVERT_FLOAT4(log1p(convert_float4(" << operand << ")))";
            break;
        case UnaryOpOperation_FLOOR:
-           ss << inpName << "=floor(convert_float4(" << operand << "))";
+           ss << inpName << "=CONVERT_FLOAT4(floor(convert_float4(" << operand << ")))";
            break;
        case UnaryOpOperation_ROUND:
-           ss << inpName << "=round(convert_float4(" << operand << "))";
+           ss << inpName << "=CONVERT_FLOAT4(round(convert_float4(" << operand << ")))";
            break;
        case UnaryOpOperation_SIGMOID:
-           ss << inpName << "=native_recip((float4)1+native_exp(convert_float4(-" << operand << ")))";
+           ss << inpName << "=CONVERT_FLOAT4(native_recip((float4)1+native_exp(convert_float4(-" << operand << "))))";
            break;
        case UnaryOpOperation_TANH:
-           ss << inpName << "=tanh(convert_float4(" << operand << "))";
+           ss << inpName << "=CONVERT_FLOAT4(tanh(convert_float4(" << operand << ")))";
            break;
        case UnaryOpOperation_RECIPROCAL:
-           ss << inpName << "=native_recip(convert_float4(" << operand << "))";
+           ss << inpName << "=CONVERT_FLOAT4(native_recip(convert_float4(" << operand << ")))";
            break;
        case UnaryOpOperation_LOG:
-           ss << inpName << "=native_log(convert_float4(" << operand << "+(float4)((float)0.0000001)))";
+           ss << inpName << "=CONVERT_FLOAT4(native_log(convert_float4(" << operand << ")+(float4)((float)0.0000001)))";
            break;
        default:
            MNN_ASSERT(false);
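The pattern of this change is uniform: each unary op still computes in `float4`, but the result is now wrapped back through `CONVERT_FLOAT4`, so the emitted expression stays type-correct when `FLOAT4` is `half4` under MNN_SUPPORT_FP16. A small Python sketch of the string the codegen now emits (function and variable names are illustrative):

```python
def emit_unary(op: str, inp_name: str, operand: str) -> str:
    # Compute in float4 for precision, then convert the result back to FLOAT4
    # (half4 when MNN_SUPPORT_FP16 is defined, float4 otherwise).
    return f"{inp_name}=CONVERT_FLOAT4({op}(convert_float4({operand})))"

print(emit_unary("erf", "in0", "val0"))
# in0=CONVERT_FLOAT4(erf(convert_float4(val0)))
```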
@@ -198,13 +201,13 @@ std::string OpenCLTarget::codegen(std::vector<std::string>& inputs, const Comman
    return ss.str();
 }
 std::string OpenCLTarget::load(const std::string& base, const std::string& offset, const Command* cmd, std::string& inpName) {
-    return "FLOAT4 " + inpName + "=read_imagef(" + base + ", SAMPLER, " + offset + ")";
+    return "FLOAT4 " + inpName + "=RI_F(" + base + ", SAMPLER, " + offset + ")";
 }
 std::string OpenCLTarget::loadscalar(const std::string& base, std::string& inpName) {
-    return "FLOAT4 " + inpName + "=((float4)read_imagef(" + base + ", SAMPLER, (int2)(0, 0)).x)";
+    return "FLOAT4 " + inpName + "=(RI_F(" + base + ", SAMPLER, (int2)(0, 0)).x)";
 }
 std::string OpenCLTarget::store(const std::string base, const std::string& offset, const std::string& data) {
-    return "write_imagef(" + base + ", " + offset + ", " + data + ");\n";
+    return "WI_F(" + base + ", " + offset + ", " + data + ");\n";
 }

 std::string OpenCLTarget::proto(const std::string& name, const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, bool hasSingleConvertRaster) {
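Same idea as the codegen change: the hard-coded `read_imagef`/`write_imagef` calls become MNN's `RI_F`/`WI_F` macros, which — as I understand the convention, an assumption worth checking against the kernel headers — resolve to the float or half image accessors depending on MNN_SUPPORT_FP16. A hedged sketch of that selection:

```python
def image_access_macros(support_fp16: bool) -> dict:
    # Assumed mapping: under FP16, FLOAT4 is half4, so image I/O must use the
    # half variants; the macros hide that choice from generated kernels.
    if support_fp16:
        return {"RI_F": "read_imageh", "WI_F": "write_imageh"}
    return {"RI_F": "read_imagef", "WI_F": "write_imagef"}
```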
@@ -22,7 +22,7 @@ private:
    std::string store(const std::string base, const std::string& offset, const std::string& data) override;
    std::string proto(const std::string& name, const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, bool hasSingleConvertRaster = false) override;
    template <typename T>
-   std::string numval(T t) { return "((float4)" + std::to_string(t) + ")"; }
+   std::string numval(T t) { return "((FLOAT4)" + std::to_string(t) + ")"; }
 };

 }
@@ -29,7 +29,59 @@ On top of C++, MNN adds a Python extension. The extension consists of two parts
### MNNTools
MNNTools currently provides two tools; see [mnnconvert](../tools/python.html#mnnconvert) and [mnnquant](../tools/python.html#mnnquant) for usage.

-## Using the Python Session API
+## Using the Python Module API
### Data types
The `Module API` in Python differs slightly from the C++ function names but is used in much the same way. The main data types are:
- [_Module](../pymnn/_Module.md) — a model instance
- [Var](../pymnn/Var.md) — model inputs and outputs
### Inference workflow
The basic inference workflow is:
- [Create a Module](../pymnn/nn.html#load-module-from-file-file-name-input-names-output-names-dynamic-shape-mutable-rearrange-backend-memory-mode-power-mode-precision-mode)
- Create inputs: a `Var` created with the `expr` or `numpy` functions can be used directly as input
- [Run inference](../pymnn/_Module.html#forward-input)
- Get outputs: outputs are of type `Var` and can be post-processed with the `expr` or `numpy` functions
### Example
```python
import MNN.nn as nn
import MNN.cv as cv
import MNN.numpy as np
import MNN.expr as expr

# Configure the backend, thread count, precision, etc.; see the API docs for the keys and values
config = {}
config['precision'] = 'low' # use fp16 inference when the hardware (armv8.2) supports it
config['backend'] = 0       # CPU
config['numThread'] = 4     # thread count

rt = nn.create_runtime_manager((config,))
# load the model and create a _Module
net = nn.load_module_from_file('mobilenet_v1.mnn', ['data'], ['prob'], runtime_manager=rt)

# read the image
image = cv.imread('cat.jpg')
# convert to float32 with shape [224,224,3]
image = cv.resize(image, (224, 224), mean=[103.94, 116.78, 123.68], norm=[0.017, 0.017, 0.017])
# add a batch dimension: HWC to NHWC
input_var = np.expand_dims(image, 0)
# NHWC to NC4HW4
input_var = expr.convert(input_var, expr.NC4HW4)

# run inference
output_var = net.forward(input_var)

# NC4HW4 to NHWC
output_var = expr.convert(output_var, expr.NHWC)
# print the classification result; class 282 is cat
print("output belong to class: {}".format(np.argmax(output_var)))
# output belong to class: 282
```
For more examples, see [Example](../pymnn/RuntimeManager.html#example) and the [demo project](../start/demo.html#id5).

## Using the Python Session API *[deprecated]*

Running inference through this API is not recommended; use the Module API instead.

### Data types
The `Session API` in Python has essentially the same function names and usage as its C++ counterpart. The main data types are:
- [Interpreter](../pymnn/Interpreter.md) — the interpreter, which holds the model resources
@@ -118,107 +170,7 @@ print("output belong to class: {}".format(np.argmax(output_var, 1)))
# output belong to class: array([282, 385], dtype=int32)
```
For more examples, see [Example](../pymnn/Interpreter.html#example) and the [demo project](../start/demo.html#session).
## Using the Python Expr API
### Data types
Compared with C++, Python's `Expr API` differs slightly in naming and usage but is functionally identical. The main data type is:
- [Var](../pymnn/Var.md) — a variable in expression computation
### Typical usage
`Expr` offers numerical computation as well as model inference. In practice it is used more often for graph construction and computation than for inference. When `Expr` is used for model inference, the main workflow is:
- [Load the compute graph](../pymnn/expr.html#load-as-dict-filename)
- Get inputs and outputs: access them like a Python `dict`, e.g. `net['input']`
- [Write input data](../pymnn/Var.html#write-data)
- [Read output data](../pymnn/Var.html#read): reading is not limited to `read`; printing or otherwise using the output may also trigger the read
### Example
`Expr` used for model inference:
```python
import MNN.cv as cv
import MNN.numpy as np
import MNN.expr as expr

net = expr.load_as_dict('mobilenet_v1.mnn')
input_var = net['data']
output_var = net['prob']

# read the image
image = cv.imread('cat.jpg')
# convert to float32 with shape [224,224,3]
image = cv.resize(image, (224, 224), mean=[103.94, 116.78, 123.68], norm=[0.017, 0.017, 0.017])
# add a batch dimension: HWC to NHWC
input_data = np.expand_dims(image, 0)
# NHWC to NC4HW4
input_data = expr.convert(input_data, expr.NC4HW4)

input_var.write(input_data.read_as_tuple())

# print the classification result; class 282 is cat
print("output belong to class: {}".format(np.argmax(output_var)))
```
`Expr` used for numerical computation and data I/O:
```python
import MNN.numpy as np
import MNN.expr as expr

x = expr.range(0., 10., 1.)
y = expr.fill([10], 3.1415)
z = expr.sin(x * y + x / y)
expr.save([z], 'z.mnn')
a = expr.load_as_list('z.mnn')[0]
print(a)
'''
array([ 0.        , -0.31288275,  0.59434694, -0.8161286 ,  0.955958  ,
       -0.9997932 ,  0.943233  , -0.79195637,  0.561154  , -0.27400237],
      dtype=float32)
'''
```
For more examples, see [Example](../pymnn/Var.html#example) and the [demo project](../start/demo.html#id5).
## Using the cv/numpy API
### Data types
Python's `cv` and `numpy` interfaces: `cv` wraps the C++ implementation in `tools/cv`, and `numpy` wraps the `expr` interface. Both exist mainly to make MNN easier to use; they are partially API-compatible with `opencv` and `numpy` and essentially identical to them in usage and approach. The main data types are:
@@ -1,5 +1,122 @@
# Release Notes
-## 2.4.0 (`Latest`)
+## 2.6.0 (`Latest`)
#### New features
- New int8-quantized operators:
  - Softmax
  - Interp
  - Binary
  - Unary
  - Scale
- OpenCL supports specific cases of the Loop operator:
  - BatchMatMul
  - Gather
- x86_64 supports Gelu-bf16
- CUDA supports bf16 model inference
- The benchmark tool can measure quantized-model performance directly (no need to quantize the model with the quantization tool first)
- Pymnn Tensor/Var accepts mixed-type data when created from a tuple
- Weight-quantized models support a low-memory inference mode that dequantizes at compute time
- ChatGLM-6B model inference runs in about 3 GB of memory
- Built a ChatGLM-MNN Android app
#### Optimizations
- OpenCL supports Qualcomm record queues to cut the time spent creating GPU command buffers;
  benchmark results on a OnePlus 9 (ms):

  |Model |unrecord |record |
  |-------|---------|-------|
  |resnet-v2-50.mnn |21.254 |20.160|
  |MobileNetV2_224.mnn |4.853 |4.186|
  |mobilenet-v1-1.0.mnn |6.424 |5.315|
  |nasnet.mnn |46.751 |20.260|
  |SqueezeNetV1.0.mnn |7.35 |6.832|
  |squeezenetv1.1.mnn |3.936 |3.693|
  |mobilenetV3.mnn |14.201 |6.743|
  |inception-v3.mnn |33.111 |32.032|

- Sparse convolution memory optimization; lower memory footprint
- Lower constant-memory usage when running MNN models on heterogeneous backends (CPU low precision / GPU)
- CUDA int8 operator performance optimized
- Fewer regions produced by Permute geometry computation
- Retuned the ConvolutionInt8 and im2col tile sizes for AVX512-VNNI, a 20%–30% speedup
- New x86 SIMD implementations of bilinear/nearest sampling, improving ImageProcess performance by about 50%
#### Bug fixes
- Resolved via GitHub issues:
  - Fixed a CUDA Raster error that produced all-zero output (issue-2333)
  - Fixed incorrect results from the OpenCL Gather operator (issue-2424)
  - Fixed an ImageProcess error (issue-2386)
  - OpenCL supports a user-selected device id (issue-2343)
- Other fixes:
  - CUDA CMakeLists now reports an error for unsupported architectures
  - The testMNNFromOnnx script no longer enables DEBUG mode when the model tests pass
  - shape_mutable in load_module_from_file now defaults to True (models containing subgraphs cannot run when it is False)
  - When MNNConvert is used with the keepInputFormat option, output tensors are also converted to the original format
  - Fixed a crash when logging with an empty device
  - Fixed the BinaryOp unit test failing to compile on Windows
  - Fixed the MNN_SUPPORT_DEPRECATED_OP macro not gating OptimizedComputer
  - Fixed incorrect fp16 convolution results with multiple threads when tiling along the channel dimension
  - Fixed an out-of-bounds access in deconvolutionInt8
  - Fixed TensorArrayWrite geometry computation producing zero regions
  - Fixed a CUDA depthwise conv error
  - Fixed assorted documentation format and content errors
  - Fixed createRuntime and setGlobalConfig errors under multi-threading
  - Fixed a compile failure caused by dead code in Vec.hpp
  - Fixed an OpenCL gpuDevice assertion failure
  - Fixed an OpenCL binary mod error
  - Fixed a CUDA argmax error
  - Fixed a wrong shape in pymnn/example/mnn_numpy_cv_demo.py
## 2.5.0
#### New features
- New MNN OpenCV operators:
  - erode
  - convertMaps
  - remap
  - adaptiveThreshold
  - bilateralFilter
  - solve (solve was also added to MNN numpy)
  - normalize
  - split
  - merge
  - addWeight
- Support converting Tflite int8 quantized models to MNN models
- ARM CPU supports GELU-bf16
- New CUDA operators:
  - GridSampler
  - Multi-Input Convolution
  - Multi-Input Deconvolution
- CUDA supports a user-set device_id for multi-GPU inference
- Support Deconvolution-int8
- runSession/runSessionWithCallBack now take a lock to avoid errors from multi-threaded calls
- Support converting ONNX models that are not topologically sorted
- Support converting multiple ONNX Softmax versions
#### Refactoring / optimizations
- Optimized the timing of memory allocation and reclamation; added Session
- Simplified ONNX Slice operator conversion
- CUDA performance optimizations:
  - Argmax optimized for large dim sizes
  - Softmax optimized for large channel counts
  - MatMul pre-rearrangement logic optimized
  - After optimization, the ChatGLM model outperforms PyTorch 2.0 on an A10 GPU
- OpenCL optimizations; resnet tests outperform OpenVINO:
  - Optimized the winograd operator with the Intel subgroup extension; adjusted the data layout and eligibility conditions
  - Adjusted the conv2d data layout by input size and optimized it with the Intel subgroup extension
  - After optimization, ResNet18 outperforms OpenVINO on an Intel UHD Graphics 630
- GELU-bf16 implementation brings a speedup
#### Bug fixes
- Resolved via GitHub issues:
  - Fixed CPURaster singleConvert errors in some cases (issue-2264)
  - Fixed incorrect atan2 results
  - Fixed ONNX dynamic-shape conversion errors (issue-2276)
  - Fixed Im2col errors with i8mm
  - Fixed a CPUConvolutionDepthwise error (issue-2291)
  - Fixed a CUDA int8 compile failure (issue-2321)
  - Fixed ONNX Loop conversion failing when M and cond are optional (issue-2267)
  - Fixed turnPackRegion errors in Raster fastblit in some cases (issue-2337)
- Other fixes:
  - Fixed identity elimination in ONNX subgraphs changing the output count relative to the original subgraph
  - Fixed conversion of ONNX sequence-related operators
  - Fixed TensorArrayConcat with newAxis = 1 (a stack in that case)
  - Fixed TensorArray eleSize computation when axis < 0
  - Fixed MNN training models failing to run with low-precision compute or on GPU
## 2.4.0
#### New features
- NNAPI supports int8 quantized models
- MNN OpenCL/Metal supports online operator fusion and code generation
@@ -1,5 +1,5 @@
<!-- pymnn/CVImageProcess.md -->
-## MNN.CVImageProcess
+## MNN.CVImageProcess *[deprecated]*

```python
class CVImageProcess
@@ -10,6 +10,7 @@ CVImageProcess is for image processing; it provides the following features:
- Affine transformation of images, similar to `cv2.resize` and `cv2.warpAffine`, configured through [CVMatrix](CVMatrix.md)
- Image normalization via `mean` and `normal`: `x = (x - mean) / normal`

*This interface is not recommended; use [cv](cv.md) instead*
---
### `MNN.CV_ImageFormat_*`
A data type describing the image format; RGB, RGBA, BGR, BGRA, GRAY, and YUV_NV21 are supported
@@ -1,5 +1,5 @@
<!-- pymnn/CVMatrix.md -->
-## MNN.CVMatrix
+## MNN.CVMatrix *[deprecated]*

```python
class CVMatrix
@@ -1,10 +1,12 @@
-## MNN.Interpreter
+## MNN.Interpreter *[deprecated]*

```python
class Interpreter
```
Interpreter is the holder of model data in the MNN V2 API. MNN inference has two levels of abstraction: the interpreter, Interpreter, and the session, [Session](Session.md).

*This interface is not recommended; use [nn](nn.md) instead*

---
### `Interpreter(model_path)`
Loads a `.mnn` model file, creates an MNN interpreter, and returns the interpreter object
@@ -1,10 +1,12 @@
-## MNN.Session
+## MNN.Session *[deprecated]*

```python
class Session
```
Session is the holder of inference data in the MNN V2 API. A Session is created from an [Interpreter](Interpreter.md); multiple inferences can share one model, i.e. multiple Sessions can share one Interpreter.

*This interface is not recommended; use [nn](nn.md) instead*

---
### `Session()`
Creates an empty Tensor
@@ -1,11 +1,13 @@
<!-- pymnn/Tensor.md -->
-## MNN.Tensor
+## MNN.Tensor *[deprecated]*

```python
class Tensor
```
Tensor is the fundamental data structure in the MNN V2 API and its most basic data container. A Tensor stores the data along with its data type, shape, and other metadata, all of which can be queried through the Tensor's own methods.

*This interface is not recommended; use [Var](Var.md) instead*

---
### `MNN.Halide_Type_*`
A data type describing a Tensor's element type
@@ -103,45 +103,6 @@ array([0., 1., 2., 3.], dtype=float32)
       [3., 4.]], dtype=float32)]
```
---
### `load_as_dict(fileName)`
Loads a model from a file and converts it into a compute graph, returning all node names and `Var`s of the graph as a `dict`

Args:
- `fileName:str` path to the model file

Returns: the loaded model's compute graph, with `str` keys and `Var` values

Return type: `dict`

Example:

```python
>>> vars = expr.load_as_dict('mobilenet_v1.mnn')
>>> vars.keys()
dict_keys(['conv1', 'conv2_1/dw', 'conv2_1/sep', 'conv2_2/dw', 'conv2_2/sep', 'conv3_1/dw', 'conv3_1/sep', 'conv3_2/dw', 'conv3_2/sep', 'conv4_1/dw', 'conv4_1/sep', 'conv4_2/dw', 'conv4_2/sep', 'conv5_1/dw', 'conv5_1/sep', 'conv5_2/dw', 'conv5_2/sep', 'conv5_3/dw', 'conv5_3/sep', 'conv5_4/dw', 'conv5_4/sep', 'conv5_5/dw', 'conv5_5/sep', 'conv5_6/dw', 'conv5_6/sep', 'conv6/dw', 'conv6/sep', 'data', 'fc7', 'pool6', 'prob'])
```
---
### `get_inputs_and_outputs(allVariable)`
Gets the input and output nodes of a compute graph in `dict` form; useful for obtaining input/output information when using the V3 API

Args:
- `allVariable:dict` the compute graph as a `dict`, with `str` keys and `Var` values

Returns: the graph's inputs and outputs, each as a `dict` with `str` keys and `Var` values

Return type: `(dict, dict)`

Example:

```python
>>> vars = expr.load_as_dict('mobilenet_v1.mnn')
>>> inputs, outputs = expr.get_inputs_and_outputs(vars)
>>> inputs.keys()
dict_keys(['data'])
>>> outputs.keys()
dict_keys(['prob'])
```
---
### `gc(full)`
Manually reclaims memory. When MNN expression evaluation runs in a loop, constant data is not released at the end of each iteration, so memory grows with the iteration count; call this function at the end of each iteration to reclaim the constant memory
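A hedged sketch of the loop pattern this is meant for (the model name is a placeholder, and passing `full=1` is an assumption based on the description above):

```python
import MNN.expr as expr
import MNN.numpy as np

net = expr.load_as_dict('mobilenet_v1.mnn')  # placeholder model
for _ in range(100):
    net['data'].write(np.ones([1, 3, 224, 224]).read_as_tuple())
    _ = net['prob'].read()
    expr.gc(1)  # reclaim constant memory each iteration to keep usage flat
```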
@@ -3049,4 +3010,48 @@ roialign

```python
TODO
```
---
**The following functions are for framework developers; regular users are advised not to use them!**

---
### `load_as_dict(fileName)` *[deprecated]*
Loads a model from a file and converts it into a compute graph, returning all node names and `Var`s of the graph as a `dict`

*Using this interface is not recommended*

Args:
- `fileName:str` path to the model file

Returns: the loaded model's compute graph, with `str` keys and `Var` values

Return type: `dict`

Example:

```python
>>> vars = expr.load_as_dict('mobilenet_v1.mnn')
>>> vars.keys()
dict_keys(['conv1', 'conv2_1/dw', 'conv2_1/sep', 'conv2_2/dw', 'conv2_2/sep', 'conv3_1/dw', 'conv3_1/sep', 'conv3_2/dw', 'conv3_2/sep', 'conv4_1/dw', 'conv4_1/sep', 'conv4_2/dw', 'conv4_2/sep', 'conv5_1/dw', 'conv5_1/sep', 'conv5_2/dw', 'conv5_2/sep', 'conv5_3/dw', 'conv5_3/sep', 'conv5_4/dw', 'conv5_4/sep', 'conv5_5/dw', 'conv5_5/sep', 'conv5_6/dw', 'conv5_6/sep', 'conv6/dw', 'conv6/sep', 'data', 'fc7', 'pool6', 'prob'])
```
---
### `get_inputs_and_outputs(allVariable)` *[deprecated]*
Gets the input and output nodes of a compute graph in `dict` form; useful for obtaining input/output information when using the V3 API

Args:
- `allVariable:dict` the compute graph as a `dict`, with `str` keys and `Var` values

Returns: the graph's inputs and outputs, each as a `dict` with `str` keys and `Var` values

Return type: `(dict, dict)`

Example:

```python
>>> vars = expr.load_as_dict('mobilenet_v1.mnn')
>>> inputs, outputs = expr.get_inputs_and_outputs(vars)
>>> inputs.keys()
dict_keys(['data'])
>>> outputs.keys()
dict_keys(['prob'])
```
@@ -628,6 +628,8 @@ void Executor::_makeCache(const std::vector<EXPRP>& expr, bool forceCPU) {
        expr->inside()->mCache = cahce;
    }
    cahce->mCacheBuffers = std::move(opBuffers);
+   // Don't report error when use expr dynamic compute, which will be called in model convert
+   scheduleInfo.pipelineInfo[0].first.reportError = false;
    scheduleInfo.pipelineInfo[0].first.info.numThread = 1;
    if (forceCPU) {
        scheduleInfo.pipelineInfo[0].first.info.type = MNN_FORWARD_CPU;
@@ -31,7 +31,8 @@ static Scope<ExecutorRef>* _getGlobalScope() {
#if TARGET_OS_IPHONE
        pthread_key_create(&gKey, NULL);
#else
-       g_executor_scope = new Scope<ExecutorRef>;
+       thread_local static Scope<ExecutorRef> initValue;
+       g_executor_scope = &initValue;
#endif
    });
#if TARGET_OS_IPHONE
@@ -69,6 +69,6 @@ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \
#define STR(x) STR_IMP(x)
#define MNN_VERSION_MAJOR 2
#define MNN_VERSION_MINOR 6
-#define MNN_VERSION_PATCH 2
+#define MNN_VERSION_PATCH 3
#define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." STR(MNN_VERSION_PATCH)
#endif /* MNNDefine_h */
@@ -63,6 +63,11 @@ typedef enum {
    then choose the better one according to performance*/
    MNN_GPU_MEMORY_BUFFER = 1 << 6,/* User assign mode */
    MNN_GPU_MEMORY_IMAGE = 1 << 7,/* User assign mode */
+   // choose one opencl memory mode Only, this mode Only support for Qualcomm gpu
+   /* User can try MNN_GPU_RECORD_OP and MNN_GPU_RECORD_KERNEL both,
+   then choose the better one according to performance*/
+   MNN_GPU_RECORD_OP = 1 << 8,/* the kernels in one op execution record into one recording */
+   MNN_GPU_RECORD_BATCH = 1 << 9,/* 10 kernels record into one recording */
} MNNGpuMode;

#ifdef __cplusplus
@@ -4,9 +4,11 @@
# |-- Release
# |--- Dynamic
# | |--- MD
# | |--- MT
# |
# |--- Static
#     |--- MD
#     |--- MT
#
Param(
    [Parameter(Mandatory=$true)][String]$path,
@@ -27,7 +29,7 @@ Remove-Item -Path $PACKAGE_PATH/include -Recurse -ErrorAction Ignore
cp -r include $PACKAGE_PATH
cp -r tools/cv/include/cv $PACKAGE_PATH/include
pushd $PACKAGE_LIB_PATH
-mkdir -p Release\Dynamic\MD, Release\Static\MD
+mkdir -p Release\Dynamic\MT, Release\Dynamic\MD, Release\Static\MD, Release\Static\MT
popd

$CMAKE_ARGS = "-DMNN_SEP_BUILD=OFF -DMNN_BUILD_TRAIN=ON -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON -DMNN_OPENCL=ON -DMNN_VULKAN=ON -DMNN_AVX512=ON"
@@ -68,6 +70,14 @@ function Build([String]$cmake_cmd, [String]$ninja_cmd = "ninja MNN") {
    exit 1
}

+##### Release/Dynamic/MT ####
+log "Release/Dynamic/MT"
+Remove-Item CMakeCache.txt -ErrorAction Ignore
+Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=ON .."
+cp MNN.lib, MNN.dll, MNN.pdb $PACKAGE_LIB_PATH\Release\Dynamic\MT
+rm MNN.*
+
##### Release/Dynamic/MD ####
log "Release/Dynamic/MD"
Remove-Item CMakeCache.txt -ErrorAction Ignore
@@ -75,6 +85,12 @@ Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_M
cp MNN.lib, MNN.dll, MNN.pdb $PACKAGE_LIB_PATH\Release\Dynamic\MD
rm MNN.*

+##### Release/Static/MT ####
+log "Release/Static/MT"
+Remove-Item CMakeCache.txt -ErrorAction Ignore
+Build "cmake -G Ninja $CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=ON -DMNN_BUILD_SHARED_LIBS=OFF .."
+cp MNN.lib $PACKAGE_LIB_PATH\Release\Static\MT
+
##### Release/Static/MD ####
log "Release/Static/MD"
Remove-Item CMakeCache.txt -ErrorAction Ignore
@@ -1839,6 +1839,7 @@
488873A8215B639D0079B12E /* source */ = {
isa = PBXGroup;
children = (
CE482EF5288536DA007CD935 /* internal */,
4DF87C482887D3560003E2D4 /* calib3d */,
4D4CF4612760946500A36D9F /* imgproc */,
4A5BEC6226AAB3D70032F6BD /* common */,
@@ -2873,15 +2874,18 @@
CEA82BDC2A15F8AD002CBC95 /* IdstConvolutionInt8.hpp in Headers */,
4DE4E82C275E307B0016A916 /* cv in Headers */,
1F501F842397BA5B004E8721 /* ImageProcess.hpp in Headers */,
CECF8C5D299CACFD00D3875B /* Log.hpp in Headers */,
1F501F822397BA5B004E8721 /* Interpreter.hpp in Headers */,
C4F906B327688C3A0026B847 /* NMSModule.hpp in Headers */,
1F501F882397BA5B004E8721 /* Tensor.hpp in Headers */,
1F501F872397BA5B004E8721 /* Matrix.h in Headers */,
CECF8C5A299CACFD00D3875B /* WorkerThread.hpp in Headers */,
48C84B85250F711700EE7666 /* IfModule.hpp in Headers */,
4D9A937326255BDA00F9B43C /* CoreMLUnary.hpp in Headers */,
48C84B98250F71E900EE7666 /* CPUSoftmax.hpp in Headers */,
4882C8B8241A22B800DAC168 /* OpCommonUtils.hpp in Headers */,
48608B54250632EC00CB1D71 /* GeometryComputer.hpp in Headers */,
CECF8C7A299CAD9400D3875B /* sha1.h in Headers */,
4894C6EC27016F7200D8BE79 /* CPUResizeCache.hpp in Headers */,
92FF04A623AA0BFB00AC97F6 /* FileLoader.hpp in Headers */,
48F34733273A7C8400C45394 /* ImageProcessFunction.hpp in Headers */,
@@ -2896,6 +2900,7 @@
48925F352744AC0700919B37 /* CPUROIAlign.hpp in Headers */,
92FF029623AA0B5A00AC97F6 /* CPUCast.hpp in Headers */,
4D9A937826255BDA00F9B43C /* CoreMLBinary.hpp in Headers */,
CECF8C85299CAD9400D3875B /* log_util.h in Headers */,
489D7AB02550FDC900AD896A /* MetalDefine.h in Headers */,
4D6D7FD52656896600F80814 /* DenseConvolutionTiledExecutor.hpp in Headers */,
4D9A936626255BDA00F9B43C /* CoreMLExecutor.h in Headers */,
@@ -2905,6 +2910,7 @@
1F501F802397BA5B004E8721 /* MNNDefine.h in Headers */,
19D0FE76285C66F200B74B1A /* MetalLayerNorm.hpp in Headers */,
489D7A682550FDC800AD896A /* MetalReduction.hpp in Headers */,
CECF8C86299CAD9400D3875B /* sds.h in Headers */,
1F501F7F2397BA5B004E8721 /* HalideRuntime.h in Headers */,
92FF029E23AA0B5A00AC97F6 /* CPUDeconvolutionDepthwise.hpp in Headers */,
4D9A935B26255BDA00F9B43C /* NeuralNetwork.pb-c.h in Headers */,
@@ -2925,8 +2931,10 @@
481C2DEE25FE2CD6001ED6DF /* Arm82Functions.hpp in Headers */,
4894C6EA27016F7200D8BE79 /* UnaryUtils.hpp in Headers */,
EBD4842A2485FF650083CE95 /* Arm82Interp.hpp in Headers */,
CECF8C81299CAD9400D3875B /* log_util_imp.h in Headers */,
92FF037623AA0B5A00AC97F6 /* CPUBinary.hpp in Headers */,
4D9A935826255BDA00F9B43C /* FeatureTypes.pb-c.h in Headers */,
CECF8C7C299CAD9400D3875B /* hmac-sha.h in Headers */,
48608B53250632EC00CB1D71 /* GeometryComputerUtils.hpp in Headers */,
950B28F529F629A90002F454 /* CPUBinaryInt8.hpp in Headers */,
489D7A732550FDC800AD896A /* MetalBackend.hpp in Headers */,
@@ -2948,6 +2956,7 @@
4DF87C522887D3F20003E2D4 /* CPUSvd.hpp in Headers */,
48747D4B245D9D24000B9709 /* RuntimeFactory.hpp in Headers */,
92FF03B323AA0B5A00AC97F6 /* ConvolutionDepthwise3x3.hpp in Headers */,
CECF8C77299CAD9400D3875B /* log_builder.h in Headers */,
4D9A937226255BDA00F9B43C /* CoreMLConvolution.hpp in Headers */,
92FF038B23AA0B5A00AC97F6 /* CPUUnravelIndex.hpp in Headers */,
4AF4FB26269ED235005BA97B /* SparseConvInt8TiledExecutor.hpp in Headers */,
@@ -2983,6 +2992,7 @@
92FF03CA23AA0B5A00AC97F6 /* CPUConvolutionDepthwise.hpp in Headers */,
92FF04A923AA0BFB00AC97F6 /* Schedule.hpp in Headers */,
489D7A9F2550FDC900AD896A /* MetalConvolutionCommon.hpp in Headers */,
CECF8C80299CAD9400D3875B /* lz4.h in Headers */,
92FF028623AA0B5A00AC97F6 /* CPUDeconvolution.hpp in Headers */,
489D7A722550FDC800AD896A /* MetalReLU6.hpp in Headers */,
92FF04B523AA0BFB00AC97F6 /* TensorUtils.hpp in Headers */,
@@ -3035,20 +3045,24 @@
92FF03A623AA0B5A00AC97F6 /* ConvolutionTiledExecutor.hpp in Headers */,
92FF036523AA0B5A00AC97F6 /* CPUResize.hpp in Headers */,
92FF04B423AA0BFB00AC97F6 /* MNNMemoryUtils.h in Headers */,
CECF8C88299CAD9400D3875B /* log_api.h in Headers */,
4A224A0D27D0C2D9000A9260 /* ConvolutionPackWinograd.hpp in Headers */,
4A224A0E27D0C2D9000A9260 /* ConvolutionPackFreeWinograd.hpp in Headers */,
4D9A937426255BDA00F9B43C /* CoreMLReduction.hpp in Headers */,
48C84B8B250F711700EE7666 /* PipelineModule.hpp in Headers */,
F41497D7278D8A21004A363A /* RuntimeAttr.hpp in Headers */,
CECF8C5B299CACFD00D3875B /* LogHelper.hpp in Headers */,
92FF04C123AA0BFB00AC97F6 /* Backend.hpp in Headers */,
482BFBCD28351BA1009210E4 /* ShaderMap.hpp in Headers */,
489D7A812550FDC900AD896A /* MetalPooling.hpp in Headers */,
CECF8C7F299CAD9400D3875B /* md5.h in Headers */,
92FF02A623AA0B5A00AC97F6 /* CPUQuantizedMaxPool.hpp in Headers */,
92FF028023AA0B5A00AC97F6 /* CPUFloatToInt8.hpp in Headers */,
92FF028723AA0B5A00AC97F6 /* CPUFixedPoint.hpp in Headers */,
C43C8227251894F400A0FF84 /* Vec.hpp in Headers */,
4819FB1D24C138DF0050BD09 /* GeometryConvUtils.hpp in Headers */,
489D7A952550FDC900AD896A /* MetalMatMul.hpp in Headers */,
CECF8C83299CAD9400D3875B /* log_define.h in Headers */,
C48CAE2628900C4A00271A6D /* ConvInt8Winograd.hpp in Headers */,
48F34730273A7C7300C45394 /* CPUImageProcess.hpp in Headers */,
489D7A702550FDC800AD896A /* MetalRaster.hpp in Headers */,
@@ -3272,6 +3286,7 @@
489D7A8A2550FDC900AD896A /* MetalConvolutionDepthwise.mm in Sources */,
48123003269EA83400EB7ABA /* ShapeUnique.cpp in Sources */,
92FF037D23AA0B5A00AC97F6 /* CPURelu.cpp in Sources */,
CECF8C5E299CACFD00D3875B /* WorkerThread.cpp in Sources */,
489D7A842550FDC900AD896A /* MetalBinary.mm in Sources */,
48747D6B245D9E33000B9709 /* GeometryFill.cpp in Sources */,
4819FB1F24C138DF0050BD09 /* GeometryConvUtils.cpp in Sources */,
@@ -3370,6 +3385,7 @@
48F34734273A7C8400C45394 /* ImageProcessFunction.cpp in Sources */,
6A131E4025823349002EC3D6 /* PluginKernel.cpp in Sources */,
48958781268EBA6F00EA01A7 /* CPUSegmentMean.cpp in Sources */,
CECF8C7B299CAD9400D3875B /* sha1.c in Sources */,
4D9A937026255BDA00F9B43C /* CoreMLUnary.cpp in Sources */,
92FF04A823AA0BFB00AC97F6 /* AutoTime.cpp in Sources */,
92FF04AE23AA0BFB00AC97F6 /* Backend.cpp in Sources */,
@@ -3424,6 +3440,7 @@
92FF03CE23AA0B5A00AC97F6 /* CPUOPRegister.cpp in Sources */,
92FF02B323AA0B5A00AC97F6 /* CPUInstanceNorm.cpp in Sources */,
4819FB2C24C1396A0050BD09 /* GeometryPoolGrad.cpp in Sources */,
CECF8C7E299CAD9400D3875B /* log_builder.cpp in Sources */,
92FF042223AA0B7100AC97F6 /* ShapeConcat.cpp in Sources */,
4D6D7FD12656891400F80814 /* MNNPackedSparseMatMulEpx4.S in Sources */,
4D5662CC299B76ED0031C1A1 /* MNNMaxPoolInt8.S in Sources */,
@@ -3499,6 +3516,7 @@
4D759B2C25FF89EE0037B0B6 /* GeometryShape.cpp in Sources */,
11A01A07258785EA00745FA7 /* MNNVectorTop1Float.S in Sources */,
48747D6E245D9E33000B9709 /* GeometrySlice.cpp in Sources */,
CECF8C7D299CAD9400D3875B /* md5.c in Sources */,
92FF041923AA0B7100AC97F6 /* ShapeQuantizedMaxPool.cpp in Sources */,
92FF038A23AA0B5A00AC97F6 /* CPURange.cpp in Sources */,
CE125CC92A52BF6B003698C9 /* MNNBilinearLineC8.S in Sources */,
@@ -3555,7 +3573,9 @@
92FF042E23AA0B7100AC97F6 /* ShapeProposal.cpp in Sources */,
92FF025923AA0B5A00AC97F6 /* CPUPoolInt8.cpp in Sources */,
92FF045B23AA0B7100AC97F6 /* ShapeShape.cpp in Sources */,
CECF8C87299CAD9400D3875B /* sds.c in Sources */,
4D6D7FD72656896D00F80814 /* SparseConvolutionTiledExecutor.cpp in Sources */,
CECF8C82299CAD9400D3875B /* log_api.cpp in Sources */,
92FF03A823AA0B5A00AC97F6 /* WinogradOptFunction.cpp in Sources */,
950B28E229F627E00002F454 /* MNNBinarySubInt8.S in Sources */,
950B28F029F627F70002F454 /* MNNBinarySubInt8.S in Sources */,
@@ -3566,6 +3586,7 @@
4D9A936026255BDA00F9B43C /* Model.pb-c.c in Sources */,
CE9AFED628E54E3300566949 /* CPUInterp3D.cpp in Sources */,
C4F906B427688C3A0026B847 /* NMSModule.cpp in Sources */,
CECF8C64299CAD8400D3875B /* LogHelper.mm in Sources */,
48FA474523AA127B00172C3B /* Executor.cpp in Sources */,
92FF02EA23AA0B5A00AC97F6 /* MNNGemmInt8AddBiasScale_16x4_Unit.S in Sources */,
48A8A61A21D101DE00C2B9A7 /* Matrix_CV.cpp in Sources */,
@@ -3590,6 +3611,7 @@
92FF027F23AA0B5A00AC97F6 /* CPUDeconvolutionDepthwise.cpp in Sources */,
EBECA3A724643D5D0062C7A3 /* MNNQuantizeFP16_UNIT4.S in Sources */,
92FF04A423AA0BFB00AC97F6 /* Interpreter.cpp in Sources */,
CECF8C5C299CACFD00D3875B /* Log.cpp in Sources */,
92FF045623AA0B7100AC97F6 /* ShapeReshape.cpp in Sources */,
92FF032523AA0B5A00AC97F6 /* MNNConvDwF23SourceTransUnit.S in Sources */,
92FF044423AA0B7100AC97F6 /* ShapeLSTM.cpp in Sources */,
@@ -3625,6 +3647,7 @@
92FF02B623AA0B5A00AC97F6 /* CPUUnary.cpp in Sources */,
92FF032723AA0B5A00AC97F6 /* MNNDeconvRunForUnitDepthWise.S in Sources */,
CE7DC00028E2DE6B00797689 /* ShapeConvTranspose3D.cpp in Sources */,
CECF8C78299CAD9400D3875B /* log_util_imp.cpp in Sources */,
92FF02CA23AA0B5A00AC97F6 /* MNNUnPackC4.S in Sources */,
48925F372744AC2A00919B37 /* ShapeROIAlign.cpp in Sources */,
92FF02E723AA0B5A00AC97F6 /* MNNDeconvRunForUnitDepthWise.S in Sources */,
@@ -3649,11 +3672,13 @@
92FF02FF23AA0B5A00AC97F6 /* MNNFloat2Int8.S in Sources */,
4D9A937926255BDA00F9B43C /* CoreMLRaster.cpp in Sources */,
48417FF224D13BF50056D9A7 /* GeometrySelect.cpp in Sources */,
CECF8C84299CAD9400D3875B /* lz4.c in Sources */,
489D7A7E2550FDC900AD896A /* MNNMetalContext.mm in Sources */,
92FF033423AA0B5A00AC97F6 /* MNNUInt8ToInt16WithOffsetC4Common.S in Sources */,
92FF036B23AA0B5A00AC97F6 /* CPUResize.cpp in Sources */,
92FF02C723AA0B5A00AC97F6 /* MNNCopyC4WithStride.S in Sources */,
92FF030923AA0B5A00AC97F6 /* MNNNV21ToBGRUnit.S in Sources */,
CECF8C79299CAD9400D3875B /* hmac-sha.cpp in Sources */,
92FF032623AA0B5A00AC97F6 /* MNNWinogradMatrixProductLeft.S in Sources */,
92FF04C023AA0BFB00AC97F6 /* Tensor.cpp in Sources */,
CEE9B95B2A3AA4D4006438F2 /* MNNBilinearLineC8.S in Sources */,
@@ -3981,10 +4006,12 @@
isa = XCBuildConfiguration;
buildSettings = {
CODE_SIGN_IDENTITY = "Apple Development";
"CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "";
"CODE_SIGN_IDENTITY[sdk=macosx*]" = "Apple Development";
CODE_SIGN_STYLE = Automatic;
DEAD_CODE_STRIPPING = YES;
DEFINES_MODULE = YES;
-DEVELOPMENT_TEAM = 6G7464HHUS;
+DEVELOPMENT_TEAM = Q48UX93J22;
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
@@ -4027,7 +4054,7 @@
METAL_LIBRARY_FILE_BASE = mnn;
ONLY_ACTIVE_ARCH = YES;
OTHER_CFLAGS = "";
-PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.abcd111ss;
+PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.abcd111;
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
PROVISIONING_PROFILE_SPECIFIER = "";
"PROVISIONING_PROFILE_SPECIFIER[sdk=macosx*]" = "";
@@ -4048,7 +4075,7 @@
CODE_SIGN_STYLE = Automatic;
DEAD_CODE_STRIPPING = YES;
DEFINES_MODULE = YES;
-DEVELOPMENT_TEAM = 6G7464HHUS;
+DEVELOPMENT_TEAM = Q48UX93J22;
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
@@ -4089,7 +4116,7 @@
MACH_O_TYPE = staticlib;
METAL_LIBRARY_FILE_BASE = mnn;
OTHER_CFLAGS = "";
-PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.abcd111ss;
+PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.abcd111;
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
PROVISIONING_PROFILE_SPECIFIER = "";
"PROVISIONING_PROFILE_SPECIFIER[sdk=macosx*]" = "";
@@ -4108,7 +4135,7 @@
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
ASSETCATALOG_COMPILER_LAUNCHIMAGE_NAME = LaunchImage;
CODE_SIGN_STYLE = Automatic;
-DEVELOPMENT_TEAM = 6G7464HHUS;
+DEVELOPMENT_TEAM = Q48UX93J22;
GCC_ENABLE_CPP_EXCEPTIONS = NO;
GCC_ENABLE_CPP_RTTI = NO;
HEADER_SEARCH_PATHS = (
@@ -4121,7 +4148,7 @@
IPHONEOS_DEPLOYMENT_TARGET = 9.0;
LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)";
-PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.abcd111ss;
+PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.abcd111;
PRODUCT_NAME = "$(TARGET_NAME)";
TARGETED_DEVICE_FAMILY = "1,2";
};
@@ -4133,7 +4160,7 @@
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
ASSETCATALOG_COMPILER_LAUNCHIMAGE_NAME = LaunchImage;
CODE_SIGN_STYLE = Automatic;
-DEVELOPMENT_TEAM = 6G7464HHUS;
+DEVELOPMENT_TEAM = Q48UX93J22;
GCC_ENABLE_CPP_EXCEPTIONS = NO;
GCC_ENABLE_CPP_RTTI = NO;
HEADER_SEARCH_PATHS = (
@@ -4146,7 +4173,7 @@
IPHONEOS_DEPLOYMENT_TARGET = 9.0;
LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)";
-PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.abcd111ss;
+PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.playground.abcd111;
PRODUCT_NAME = "$(TARGET_NAME)";
TARGETED_DEVICE_FAMILY = "1,2";
};
@@ -4260,3 +4287,4 @@
};
rootObject = 0F1465AE1FA18D1000F9860A /* Project object */;
}
@@ -76,9 +76,15 @@ endif()
if(PYMNN_IMGPROC_FILTER)
    target_compile_definitions(mnnpybridge PRIVATE PYMNN_IMGPROC_FILTER)
endif()
+if(PYMNN_IMGPROC_HISTOGRAMS)
+    target_compile_definitions(mnnpybridge PRIVATE PYMNN_IMGPROC_HISTOGRAMS)
+endif()
+if(PYMNN_CALIB3D)
+    target_compile_definitions(mnnpybridge PRIVATE PYMNN_CALIB3D)
+endif()
if(PYMNN_CVCORE)
    target_compile_definitions(mnnpybridge PRIVATE PYMNN_CVCORE)
endif()

if(PYMNN_INTERNAL_SERVING)
    message(STATUS "mnnpybridge define PYMNN_INTERNAL_SERVING")
@@ -1,5 +1,4 @@
import MNN.expr as _F
-import MNN.cv as _cv

# Linear algebra
def norm(x, ord=None, axis=None, keepdims=False):
@@ -48,4 +47,5 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
    return (u, w, vt)

def solve(a, b):
+    import _mnncengine.cv as _cv
    return _cv.solve(a, b)[1]
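Deferring the `_mnncengine.cv` import into `solve` avoids a hard dependency at module load time. A hedged usage sketch, assuming MNN's numpy-compatible `linalg.solve` mirrors NumPy semantics:

```python
import MNN.numpy as np

a = np.array([[3.0, 1.0], [1.0, 2.0]])
b = np.array([9.0, 8.0])
x = np.linalg.solve(a, b)  # expect a result close to [2., 3.]
print(x)
```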
@@ -44,12 +44,15 @@ def build_deps():
        shutil.rmtree(cmake_build_dir)
    os.makedirs(cmake_build_dir)
    os.chdir(cmake_build_dir)
+   extra_opts = '-DMNN_LOW_MEMORY=ON'
+   extra_opts += ' -DMNN_VULKAN=ON -DMNN_VULKAN_IMAGE=OFF'
+   extra_opts += ' -DMNN_OPENCL=ON'
    if IS_WINDOWS:
-       os.system('cmake -G "Ninja" -DMNN_BUILD_TRAIN=ON -DMNN_BUILD_CONVERTER=on -DMNN_BUILD_TORCH=OFF\
+       os.system('cmake -G "Ninja" ' + extra_opts +' -DMNN_BUILD_TRAIN=ON -DMNN_BUILD_CONVERTER=on -DMNN_BUILD_TORCH=OFF\
    -DMNN_BUILD_SHARED_LIBS=OFF -DCMAKE_BUILD_TYPE=Release -DMNN_WIN_RUNTIME_MT=ON\
    -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON -DMNN_AAPL_FMWK=OFF -DMNN_SEP_BUILD=OFF .. && ninja MNN MNNTrain MNNConvert')
    elif IS_LINUX:
-       extra_opts = '-DMNN_TENSORRT=ON \
+       extra_opts += '-DMNN_TENSORRT=ON \
    -DCMAKE_LIBRARY_PATH=/usr/local/cuda/lib64/stubs/ ' if USE_TRT else ' '
        extra_opts += ' -DMNN_INTERNAL=ON ' if IS_INTERNAL_BUILD else ' '
        extra_opts += ' -DMNN_BUILD_TORCH=ON ' if IS_BUILD_TORCH else ' '
@@ -59,11 +62,11 @@ def build_deps():
    -DMNN_BUILD_SHARED_LIBS=OFF -DMNN_AAPL_FMWK=OFF -DMNN_SEP_BUILD=OFF -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON \
    -DMNN_USE_THREAD_POOL=ON -DMNN_OPENMP=OFF .. && make MNN MNNTrain MNNConvert -j4')
    else:
-       extra_opts = ' -DMNN_INTERNAL=ON ' if IS_INTERNAL_BUILD else ' '
+       extra_opts += ' -DMNN_INTERNAL=ON ' if IS_INTERNAL_BUILD else ' '
        extra_opts += ' -DMNN_BUILD_TORCH=ON ' if IS_BUILD_TORCH else ' '
        print(extra_opts)
        os.system('cmake ' + extra_opts + '-DMNN_BUILD_CONVERTER=on -DMNN_BUILD_TRAIN=ON -DCMAKE_BUILD_TYPE=Release \
-   -DMNN_BUILD_SHARED_LIBS=OFF -DMNN_AAPL_FMWK=OFF -DMNN_SEP_BUILD=OFF -DMNN_EXPR_SHAPE_EAGER=ON -DMNN_TRAIN_DEBUG=ON\
+   -DMNN_BUILD_SHARED_LIBS=OFF -DMNN_AAPL_FMWK=OFF -DMNN_SEP_BUILD=OFF\
    -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON \
    .. && make MNN MNNTrain MNNConvert -j4')
################################################################################
@@ -21,19 +21,16 @@ using Vec4 = MNN::Math::Vec<float, 4>;
namespace MNN {

ErrorCode CPUBinary::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
-    const int input0DataCount = ((CPUBackend*)backend())->getTensorSize(inputs[0]);
-    const int input1DataCount = ((CPUBackend*)backend())->getTensorSize(inputs[1]);
+    auto input0DataCount = TensorUtils::getRawSize(inputs[0]);
+    auto input1DataCount = TensorUtils::getRawSize(inputs[1]);
    if (input1DataCount == input0DataCount) {
        mNeedBroadcastIndex = -1;
-        mTotalSize = input1DataCount;
    } else if (input0DataCount == 1) {
        mNeedBroadcastIndex = 0;
-        mTotalSize = input1DataCount;
    } else {
        mNeedBroadcastIndex = 1;
-        mTotalSize = input0DataCount;
    }
-    MNN_ASSERT(mTotalSize == ((CPUBackend*)backend())->getTensorSize(outputs[0]));
+    mTotalSize = ((CPUBackend*)backend())->getTensorSize(outputs[0]);

    if(mActivationType == 1 && outputs[0]->getType().code == halide_type_float) {
        mActivationExe.reset(new CPURelu(backend(), 0.0));
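A minimal Python sketch of the broadcast-index selection above (shapes flattened to raw element counts; names are illustrative):

```python
def pick_broadcast(count0: int, count1: int) -> int:
    # Which input needs broadcasting: -1 = none, 0 = input0, 1 = input1.
    if count0 == count1:
        return -1
    if count0 == 1:
        return 0
    return 1

assert pick_broadcast(16, 16) == -1
assert pick_broadcast(1, 16) == 0
assert pick_broadcast(16, 1) == 1
```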
@@ -19,19 +19,16 @@
namespace MNN {

ErrorCode CPUBinaryInt8::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
-    const int input0DataCount = ((CPUBackend*)backend())->getTensorSize(inputs[0]);
-    const int input1DataCount = ((CPUBackend*)backend())->getTensorSize(inputs[1]);
+    auto input0DataCount = TensorUtils::getRawSize(inputs[0]);
+    auto input1DataCount = TensorUtils::getRawSize(inputs[1]);
    if (input1DataCount == input0DataCount) {
        mNeedBroadcastIndex = -1;
-        mTotalSize = input1DataCount;
    } else if (input0DataCount == 1) {
        mNeedBroadcastIndex = 0;
-        mTotalSize = input1DataCount;
    } else {
        mNeedBroadcastIndex = 1;
-        mTotalSize = input0DataCount;
    }
-    MNN_ASSERT(mTotalSize == ((CPUBackend*)backend())->getTensorSize(outputs[0]));
+    mTotalSize = ((CPUBackend*)backend())->getTensorSize(outputs[0]);

    auto core = static_cast<CPUBackend*>(backend())->functions();

@@ -835,10 +835,12 @@ public:
        auto dstIter = *(iter0 + iter0Stride * iter);
        auto srcOffset = srcIter * step1 + srcView->offset();
        auto dstOffset = dstIter * step0 + dstView->offset();
-       if (srcOffset >= 0 && srcOffset < inputSize) {
-           _blit(reg, bytes, input->host<uint8_t>() + bytes * srcOffset, output->host<uint8_t>() + bytes * dstOffset, proc);
-       } else {
-           _zero(reg, bytes, output->host<uint8_t>() + bytes * dstOffset);
+       if (dstOffset >= 0) {
+           if (srcOffset >= 0 && srcOffset < inputSize) {
+               _blit(reg, bytes, input->host<uint8_t>() + bytes * srcOffset, output->host<uint8_t>() + bytes * dstOffset, proc);
+           } else {
+               _zero(reg, bytes, output->host<uint8_t>() + bytes * dstOffset);
+           }
        }
    }
    return NO_ERROR;
@@ -16,6 +16,10 @@
asm_function MNNFloat2Int8
//void MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, float* scale, size_t aMin, size_t aMax, size_t zeroPoint);
//x0:src, x1:dst, x2:sizeQuad, x3:scale, x4:aMin, x5:aMax, x6:zeroPoint
+stp d14, d15, [sp, #-64]!
+stp d12, d13, [sp, #16]
+stp d10, d11, [sp, #32]
+stp d8, d9, [sp, #48]

ld1 {v31.4s}, [x3]

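For orientation before the unrolled loops, a hedged NumPy reference of what this kernel computes per element (fcvtas rounds to nearest with ties away from zero; results are then saturated into [aMin, aMax]):

```python
import numpy as np

def mnn_float_to_int8(src, scale, a_min, a_max, zero_point):
    # Reference for the vectorized loops below: y = round(src * scale + zero_point)
    # with half-away-from-zero rounding (fcvtas), clamped to [a_min, a_max].
    y = src.astype(np.float32) * np.float32(scale) + np.float32(zero_point)
    y = np.where(y >= 0, np.floor(y + 0.5), np.ceil(y - 0.5))
    return np.clip(y, a_min, a_max).astype(np.int8)

print(mnn_float_to_int8(np.array([0.4, -0.5, 300.0]), 0.5, -127, 127, 1))
# [  1   1 127]
```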
@ -23,12 +27,380 @@ dup v30.16b, w4
|
|||
dup v29.16b, w5
|
||||
|
||||
// copy zero point
|
||||
mov v28.s[0], w6
|
||||
mov v28.s[1], w6
|
||||
mov v28.s[2], w6
|
||||
mov v28.s[3], w6
|
||||
dup v28.4s, w6
|
||||
scvtf v28.4s, v28.4s
|
||||
|
||||
FL32:
|
||||
cmp x2, #32
|
||||
ble FL16
|
||||
|
||||
FLLoop32:
|
||||
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64
|
||||
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64
|
||||
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64
|
||||
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64
|
||||
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64
|
||||
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64
|
||||
// ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64
|
||||
// ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0], #64
|
||||
fmul v0.4s, v0.4s, v31.4s
|
||||
fadd v0.4s, v0.4s, v28.4s
|
||||
fmul v1.4s, v1.4s, v31.4s
|
||||
fadd v1.4s, v1.4s, v28.4s
|
||||
fmul v2.4s, v2.4s, v31.4s
|
||||
fadd v2.4s, v2.4s, v28.4s
|
||||
fmul v3.4s, v3.4s, v31.4s
|
||||
fadd v3.4s, v3.4s, v28.4s
|
||||
|
||||
fmul v4.4s, v4.4s, v31.4s
|
||||
fadd v4.4s, v4.4s, v28.4s
|
||||
fmul v5.4s, v5.4s, v31.4s
|
||||
fadd v5.4s, v5.4s, v28.4s
|
||||
fmul v6.4s, v6.4s, v31.4s
|
||||
fadd v6.4s, v6.4s, v28.4s
|
||||
fmul v7.4s, v7.4s, v31.4s
|
||||
fadd v7.4s, v7.4s, v28.4s
|
||||
|
||||
fmul v8.4s, v8.4s, v31.4s
|
||||
fadd v8.4s, v8.4s, v28.4s
|
||||
fmul v9.4s, v9.4s, v31.4s
|
||||
fadd v9.4s, v9.4s, v28.4s
|
||||
fmul v10.4s, v10.4s, v31.4s
|
||||
fadd v10.4s, v10.4s, v28.4s
|
||||
fmul v11.4s, v11.4s, v31.4s
|
||||
fadd v11.4s, v11.4s, v28.4s
|
||||
|
||||
fmul v12.4s, v12.4s, v31.4s
|
||||
fadd v12.4s, v12.4s, v28.4s
|
||||
fmul v13.4s, v13.4s, v31.4s
|
||||
fadd v13.4s, v13.4s, v28.4s
|
||||
fmul v14.4s, v14.4s, v31.4s
|
||||
fadd v14.4s, v14.4s, v28.4s
|
||||
fmul v15.4s, v15.4s, v31.4s
|
||||
fadd v15.4s, v15.4s, v28.4s
|
||||
|
||||
|
||||
fmul v16.4s, v16.4s, v31.4s
|
||||
fadd v16.4s, v16.4s, v28.4s
|
||||
fmul v17.4s, v17.4s, v31.4s
|
||||
fadd v17.4s, v17.4s, v28.4s
|
||||
fmul v18.4s, v18.4s, v31.4s
|
||||
fadd v18.4s, v18.4s, v28.4s
|
||||
fmul v19.4s, v19.4s, v31.4s
|
||||
fadd v19.4s, v19.4s, v28.4s
|
||||
|
||||
fmul v20.4s, v20.4s, v31.4s
|
||||
fadd v20.4s, v20.4s, v28.4s
|
||||
fmul v21.4s, v21.4s, v31.4s
|
||||
fadd v21.4s, v21.4s, v28.4s
|
||||
fmul v22.4s, v22.4s, v31.4s
|
||||
fadd v22.4s, v22.4s, v28.4s
|
||||
fmul v23.4s, v23.4s, v31.4s
|
||||
fadd v23.4s, v23.4s, v28.4s
|
||||
|
||||
fcvtas v0.4s, v0.4s
|
||||
fcvtas v1.4s, v1.4s
|
||||
fcvtas v2.4s, v2.4s
|
||||
fcvtas v3.4s, v3.4s
|
||||
fcvtas v4.4s, v4.4s
|
||||
fcvtas v5.4s, v5.4s
|
||||
fcvtas v6.4s, v6.4s
|
||||
fcvtas v7.4s, v7.4s
|
||||
|
||||
fcvtas v8.4s, v8.4s
|
||||
fcvtas v9.4s, v9.4s
|
||||
fcvtas v10.4s, v10.4s
|
||||
fcvtas v11.4s, v11.4s
|
||||
fcvtas v12.4s, v12.4s
|
||||
fcvtas v13.4s, v13.4s
|
||||
fcvtas v14.4s, v14.4s
|
||||
fcvtas v15.4s, v15.4s
|
||||
|
||||
fcvtas v16.4s, v16.4s
|
||||
fcvtas v17.4s, v17.4s
|
||||
fcvtas v18.4s, v18.4s
|
||||
fcvtas v19.4s, v19.4s
|
||||
fcvtas v20.4s, v20.4s
|
||||
fcvtas v21.4s, v21.4s
|
||||
fcvtas v22.4s, v22.4s
|
||||
fcvtas v23.4s, v23.4s
|
||||
|
||||
|
||||
sqxtn v24.4h, v0.4s
|
||||
sqxtn2 v24.8h, v1.4s
|
||||
sqxtn v25.4h, v2.4s
|
||||
sqxtn2 v25.8h, v3.4s
|
||||
sqxtn v26.4h, v4.4s
|
||||
sqxtn2 v26.8h, v5.4s
|
||||
sqxtn v27.4h, v6.4s
|
||||
sqxtn2 v27.8h, v7.4s
|
||||
|
||||
sqxtn v0.4h, v8.4s
|
||||
sqxtn2 v0.8h, v9.4s
|
||||
sqxtn v1.4h, v10.4s
|
||||
sqxtn2 v1.8h, v11.4s
|
||||
sqxtn v2.4h, v12.4s
|
||||
sqxtn2 v2.8h, v13.4s
|
||||
sqxtn v3.4h, v14.4s
|
||||
sqxtn2 v3.8h, v15.4s
|
||||
|
||||
sqxtn v4.4h, v16.4s
|
||||
sqxtn2 v4.8h, v17.4s
|
||||
sqxtn v5.4h, v18.4s
|
||||
sqxtn2 v5.8h, v19.4s
|
||||
sqxtn v6.4h, v20.4s
|
||||
sqxtn2 v6.8h, v21.4s
|
||||
sqxtn v7.4h, v22.4s
|
||||
sqxtn2 v7.8h, v23.4s
|
||||
|
||||
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64
|
||||
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64
|
||||
|
||||
sqxtn v24.8b, v24.8h
|
||||
sqxtn2 v24.16b, v25.8h
|
||||
sqxtn v26.8b, v26.8h
|
||||
sqxtn2 v26.16b, v27.8h
|
||||
sqxtn v0.8b, v0.8h
|
||||
sqxtn2 v0.16b, v1.8h
|
||||
sqxtn v2.8b, v2.8h
|
||||
sqxtn2 v2.16b, v3.8h
|
||||
|
||||
sqxtn v4.8b, v4.8h
|
||||
sqxtn v6.8b, v6.8h
|
||||
sqxtn2 v4.16b, v5.8h
|
||||
sqxtn2 v6.16b, v7.8h
|
||||
|
||||
fmul v8.4s, v8.4s, v31.4s
|
||||
fadd v8.4s, v8.4s, v28.4s
|
||||
fmul v9.4s, v9.4s, v31.4s
|
||||
fadd v9.4s, v9.4s, v28.4s
|
||||
fmul v10.4s, v10.4s, v31.4s
|
||||
fadd v10.4s, v10.4s, v28.4s
|
||||
fmul v11.4s, v11.4s, v31.4s
|
||||
fadd v11.4s, v11.4s, v28.4s
|
||||
|
||||
fmul v12.4s, v12.4s, v31.4s
|
||||
fadd v12.4s, v12.4s, v28.4s
|
||||
fmul v13.4s, v13.4s, v31.4s
|
||||
fadd v13.4s, v13.4s, v28.4s
|
||||
fmul v14.4s, v14.4s, v31.4s
|
||||
fadd v14.4s, v14.4s, v28.4s
|
||||
fmul v15.4s, v15.4s, v31.4s
|
||||
fadd v15.4s, v15.4s, v28.4s
|
||||
|
||||
fcvtas v8.4s, v8.4s
|
||||
fcvtas v9.4s, v9.4s
|
||||
fcvtas v10.4s, v10.4s
|
||||
fcvtas v11.4s, v11.4s
|
||||
fcvtas v12.4s, v12.4s
|
||||
fcvtas v13.4s, v13.4s
|
||||
fcvtas v14.4s, v14.4s
|
||||
fcvtas v15.4s, v15.4s
|
||||
|
||||
sqxtn v16.4h, v8.4s
|
||||
sqxtn2 v16.8h, v9.4s
|
||||
sqxtn v17.4h, v10.4s
|
||||
sqxtn2 v17.8h, v11.4s
|
||||
sqxtn v18.4h, v12.4s
|
||||
sqxtn2 v18.8h, v13.4s
|
||||
sqxtn v19.4h, v14.4s
|
||||
sqxtn2 v19.8h, v15.4s
|
||||
|
||||
smin v24.16b, v24.16b, v29.16b
|
||||
smax v24.16b, v24.16b, v30.16b
|
||||
smin v25.16b, v26.16b, v29.16b
|
||||
smax v25.16b, v25.16b, v30.16b
|
||||
|
||||
sqxtn v20.8b, v16.8h
|
||||
sqxtn2 v20.16b, v17.8h
|
||||
sqxtn v21.8b, v18.8h
|
||||
sqxtn2 v21.16b, v19.8h
|
||||
|
||||
smin v26.16b, v0.16b, v29.16b
|
||||
smax v26.16b, v26.16b, v30.16b
|
||||
smin v27.16b, v2.16b, v29.16b
|
||||
smax v27.16b, v27.16b, v30.16b
|
||||
|
||||
smin v12.16b, v4.16b, v29.16b
|
||||
smax v12.16b, v12.16b, v30.16b
|
||||
smin v13.16b, v6.16b, v29.16b
|
||||
smax v13.16b, v13.16b, v30.16b
|
||||
|
||||
smin v14.16b, v20.16b, v29.16b
|
||||
smax v14.16b, v14.16b, v30.16b
|
||||
smin v15.16b, v21.16b, v29.16b
|
||||
smax v15.16b, v15.16b, v30.16b
|
||||
|
||||
st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x1], #64
|
||||
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64
|
||||
|
||||
sub x2, x2, #32
|
||||
cmp x2, #32
|
||||
bge FLLoop32
|
||||
|
||||
FL16:
|
||||
cmp x2, #16
|
||||
ble FL8
|
||||
|
||||
FLLoop16:
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64
fmul v0.4s, v0.4s, v31.4s
fadd v0.4s, v0.4s, v28.4s
fmul v1.4s, v1.4s, v31.4s
fadd v1.4s, v1.4s, v28.4s
fmul v2.4s, v2.4s, v31.4s
fadd v2.4s, v2.4s, v28.4s
fmul v3.4s, v3.4s, v31.4s
fadd v3.4s, v3.4s, v28.4s

fmul v4.4s, v4.4s, v31.4s
fadd v4.4s, v4.4s, v28.4s
fmul v5.4s, v5.4s, v31.4s
fadd v5.4s, v5.4s, v28.4s
fmul v6.4s, v6.4s, v31.4s
fadd v6.4s, v6.4s, v28.4s
fmul v7.4s, v7.4s, v31.4s
fadd v7.4s, v7.4s, v28.4s

fmul v8.4s, v8.4s, v31.4s
fadd v8.4s, v8.4s, v28.4s
fmul v9.4s, v9.4s, v31.4s
fadd v9.4s, v9.4s, v28.4s
fmul v10.4s, v10.4s, v31.4s
fadd v10.4s, v10.4s, v28.4s
fmul v11.4s, v11.4s, v31.4s
fadd v11.4s, v11.4s, v28.4s

fmul v12.4s, v12.4s, v31.4s
fadd v12.4s, v12.4s, v28.4s
fmul v13.4s, v13.4s, v31.4s
fadd v13.4s, v13.4s, v28.4s
fmul v14.4s, v14.4s, v31.4s
fadd v14.4s, v14.4s, v28.4s
fmul v15.4s, v15.4s, v31.4s
fadd v15.4s, v15.4s, v28.4s

fcvtas v0.4s, v0.4s
fcvtas v1.4s, v1.4s
fcvtas v2.4s, v2.4s
fcvtas v3.4s, v3.4s
fcvtas v4.4s, v4.4s
fcvtas v5.4s, v5.4s
fcvtas v6.4s, v6.4s
fcvtas v7.4s, v7.4s

fcvtas v8.4s, v8.4s
fcvtas v9.4s, v9.4s
fcvtas v10.4s, v10.4s
fcvtas v11.4s, v11.4s
fcvtas v12.4s, v12.4s
fcvtas v13.4s, v13.4s
fcvtas v14.4s, v14.4s
fcvtas v15.4s, v15.4s

sqxtn v16.4h, v0.4s
sqxtn2 v16.8h, v1.4s
sqxtn v17.4h, v2.4s
sqxtn2 v17.8h, v3.4s
sqxtn v18.4h, v4.4s
sqxtn2 v18.8h, v5.4s
sqxtn v19.4h, v6.4s
sqxtn2 v19.8h, v7.4s

sqxtn v20.4h, v8.4s
sqxtn2 v20.8h, v9.4s
sqxtn v21.4h, v10.4s
sqxtn2 v21.8h, v11.4s
sqxtn v22.4h, v12.4s
sqxtn2 v22.8h, v13.4s
sqxtn v23.4h, v14.4s
sqxtn2 v23.8h, v15.4s

sqxtn v24.8b, v16.8h
sqxtn2 v24.16b, v17.8h
sqxtn v25.8b, v18.8h
sqxtn2 v25.16b, v19.8h
sqxtn v26.8b, v20.8h
sqxtn2 v26.16b, v21.8h
sqxtn v27.8b, v22.8h
sqxtn2 v27.16b, v23.8h
smin v24.16b, v24.16b, v29.16b
smax v24.16b, v24.16b, v30.16b
smin v25.16b, v25.16b, v29.16b
smax v25.16b, v25.16b, v30.16b
smin v26.16b, v26.16b, v29.16b
smax v26.16b, v26.16b, v30.16b
smin v27.16b, v27.16b, v29.16b
smax v27.16b, v27.16b, v30.16b

st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x1], #64

sub x2, x2, #16
cmp x2, #16
bge FLLoop16

FL8:
cmp x2, #8
ble FL4

FLLoop8:
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64
fmul v0.4s, v0.4s, v31.4s
fadd v0.4s, v0.4s, v28.4s
fmul v1.4s, v1.4s, v31.4s
fadd v1.4s, v1.4s, v28.4s
fmul v2.4s, v2.4s, v31.4s
fadd v2.4s, v2.4s, v28.4s
fmul v3.4s, v3.4s, v31.4s
fadd v3.4s, v3.4s, v28.4s

fmul v4.4s, v4.4s, v31.4s
fadd v4.4s, v4.4s, v28.4s
fmul v5.4s, v5.4s, v31.4s
fadd v5.4s, v5.4s, v28.4s
fmul v6.4s, v6.4s, v31.4s
fadd v6.4s, v6.4s, v28.4s
fmul v7.4s, v7.4s, v31.4s
fadd v7.4s, v7.4s, v28.4s

fcvtas v0.4s, v0.4s
fcvtas v1.4s, v1.4s
fcvtas v2.4s, v2.4s
fcvtas v3.4s, v3.4s
fcvtas v4.4s, v4.4s
fcvtas v5.4s, v5.4s
fcvtas v6.4s, v6.4s
fcvtas v7.4s, v7.4s

sqxtn v8.4h, v0.4s
sqxtn2 v8.8h, v1.4s
sqxtn v9.4h, v2.4s
sqxtn2 v9.8h, v3.4s
sqxtn v10.4h, v4.4s
sqxtn2 v10.8h, v5.4s
sqxtn v11.4h, v6.4s
sqxtn2 v11.8h, v7.4s

sqxtn v12.8b, v8.8h
sqxtn2 v12.16b, v9.8h
sqxtn v13.8b, v10.8h
sqxtn2 v13.16b, v11.8h
smin v12.16b, v12.16b, v29.16b
smax v12.16b, v12.16b, v30.16b
smin v13.16b, v13.16b, v29.16b
smax v13.16b, v13.16b, v30.16b

st1 {v12.4s, v13.4s}, [x1], #32

sub x2, x2, #8
cmp x2, #8
bge FLLoop8

FL4:
cmp x2, #3
ble FL1

@@ -89,6 +461,9 @@ subs x2, x2, #1
bne FLLoop1

FLEnd:

ldp d8, d9, [sp, #48]
ldp d10, d11, [sp, #32]
ldp d12, d13, [sp, #16]
ldp d14, d15, [sp], #64
ret
#endif
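For reference, a minimal scalar sketch of what the vectorized loops above compute. Function and parameter names are assumed for illustration; the real entry point, scale/zero-point layout, and clamp bounds come from the surrounding file:

#include <math.h>
#include <stdint.h>
#include <stddef.h>

// Hypothetical scalar reference for the float -> int8 quantization loops.
static inline void quantF32ToI8Ref(const float* src, int8_t* dst, size_t n,
                                   float scale, float zeroPoint,
                                   int8_t minV, int8_t maxV) {
    for (size_t i = 0; i < n; ++i) {
        float v = src[i] * scale + zeroPoint;   // fmul v31 + fadd v28
        long r = lroundf(v);                    // fcvtas: round, ties away from zero
        if (r > 127) r = 127;                   // sqxtn saturates to int8
        if (r < -128) r = -128;
        int8_t q = (int8_t)r;
        if (q > maxV) q = maxV;                 // smin clamp
        if (q < minV) q = minV;                 // smax clamp
        dst[i] = q;
    }
}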

@@ -595,12 +595,13 @@ Tile1LoopEnd:
bge LoopDz_TILE_1

End:
ldp x23, x24, [sp, #(16 * 6)]
ldp x19, x20, [sp, #(16 * 5)]
ldp x21, x22, [sp, #(16 * 4)]
ldp d8, d9, [sp, #(16 * 3)]
ldp d10, d11, [sp, #(16 * 2)]
ldp d12, d13, [sp, #(16 * 1)]
ldp d14, d15, [sp], #(16 * 6)
ldp d14, d15, [sp], #(16 * 7)
ret

#endif // __aarch64__

@@ -18,6 +18,10 @@ asm_function MNNMatrixAdd

//Auto: x0: C, x1:A, x2:B, x3:widthC4
//x4:cStride, x5:aStride, x6:bStride, x7:height

stp d14, d15, [sp, #-64]!
stp d12, d13, [sp, #16]
stp d10, d11, [sp, #32]
stp d8, d9, [sp, #48]

mov x12, #4 //sizeof(float)
mul x4, x12, x4

@@ -31,6 +35,72 @@ mov x10, x2

mov x11, x3

L16:
cmp x11, #16
blt L8

L16Loop:
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x1], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2], #64
ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64

fadd v0.4s, v0.4s, v8.4s
fadd v1.4s, v1.4s, v9.4s
fadd v2.4s, v2.4s, v10.4s
fadd v3.4s, v3.4s, v11.4s
fadd v4.4s, v4.4s, v12.4s
fadd v5.4s, v5.4s, v13.4s
fadd v6.4s, v6.4s, v14.4s
fadd v7.4s, v7.4s, v15.4s

sub x11, x11, #16

fadd v16.4s, v16.4s, v24.4s
fadd v17.4s, v17.4s, v25.4s
fadd v18.4s, v18.4s, v26.4s
fadd v19.4s, v19.4s, v27.4s
fadd v20.4s, v20.4s, v28.4s
fadd v21.4s, v21.4s, v29.4s
fadd v22.4s, v22.4s, v30.4s
fadd v23.4s, v23.4s, v31.4s

st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64
cmp x11, #16
bge L16Loop

L8:
cmp x11, #8
blt L4

L8Loop:
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64

fadd v0.4s, v0.4s, v8.4s
fadd v1.4s, v1.4s, v9.4s
fadd v2.4s, v2.4s, v10.4s
fadd v3.4s, v3.4s, v11.4s
fadd v4.4s, v4.4s, v12.4s
fadd v5.4s, v5.4s, v13.4s
fadd v6.4s, v6.4s, v14.4s
fadd v7.4s, v7.4s, v15.4s
sub x11, x11, #8

st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64
cmp x11, #8
bge L8Loop

L4:
cmp x11, #4
blt L1

@@ -89,5 +159,10 @@ add x2, x10, x6
subs x7, x7, #1
bne LoopY

End:
ldp d8, d9, [sp, #48]
ldp d10, d11, [sp, #32]
ldp d12, d13, [sp, #16]
ldp d14, d15, [sp], #64
ret
#endif
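For reference, a scalar sketch of the contract the assembly implements. Parameter meanings follow the register comments above; stride units are an assumption (the `mul x4, x12, x4` converts the C stride to bytes, while this sketch keeps all strides in float elements):

#include <stddef.h>

// Hypothetical scalar reference for MNNMatrixAdd: C = A + B over `height`
// rows of widthC4 * 4 floats, each matrix with its own row stride.
static void matrixAddRef(float* C, const float* A, const float* B,
                         size_t widthC4, size_t cStride, size_t aStride,
                         size_t bStride, size_t height) {
    for (size_t y = 0; y < height; ++y) {
        for (size_t x = 0; x < widthC4 * 4; ++x) {
            C[y * cStride + x] = A[y * aStride + x] + B[y * bStride + x];
        }
    }
}

MNNMatrixSub below is identical except that the element-wise operation is subtraction.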

@@ -18,6 +18,10 @@ asm_function MNNMatrixSub

//Auto: x0: C, x1:A, x2:B, x3:widthC4
//x4:cStride, x5:aStride, x6:bStride, x7:height

stp d14, d15, [sp, #-64]!
stp d12, d13, [sp, #16]
stp d10, d11, [sp, #32]
stp d8, d9, [sp, #48]

mov x12, #4 //sizeof(float)
mul x4, x12, x4

@@ -31,6 +35,72 @@ mov x10, x2

mov x11, x3

L16:
cmp x11, #16
blt L8

L16Loop:
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x1], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2], #64
ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x2], #64

fsub v0.4s, v0.4s, v8.4s
fsub v1.4s, v1.4s, v9.4s
fsub v2.4s, v2.4s, v10.4s
fsub v3.4s, v3.4s, v11.4s
fsub v4.4s, v4.4s, v12.4s
fsub v5.4s, v5.4s, v13.4s
fsub v6.4s, v6.4s, v14.4s
fsub v7.4s, v7.4s, v15.4s

sub x11, x11, #16

fsub v16.4s, v16.4s, v24.4s
fsub v17.4s, v17.4s, v25.4s
fsub v18.4s, v18.4s, v26.4s
fsub v19.4s, v19.4s, v27.4s
fsub v20.4s, v20.4s, v28.4s
fsub v21.4s, v21.4s, v29.4s
fsub v22.4s, v22.4s, v30.4s
fsub v23.4s, v23.4s, v31.4s

st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64
cmp x11, #16
bge L16Loop

L8:
cmp x11, #8
blt L4

L8Loop:
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64

fsub v0.4s, v0.4s, v8.4s
fsub v1.4s, v1.4s, v9.4s
fsub v2.4s, v2.4s, v10.4s
fsub v3.4s, v3.4s, v11.4s
fsub v4.4s, v4.4s, v12.4s
fsub v5.4s, v5.4s, v13.4s
fsub v6.4s, v6.4s, v14.4s
fsub v7.4s, v7.4s, v15.4s
sub x11, x11, #8

st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64
cmp x11, #8
bge L8Loop

L4:
cmp x11, #4
blt L1

@@ -89,6 +159,11 @@ add x2, x10, x6
subs x7, x7, #1
bne LoopY

End:
ldp d8, d9, [sp, #48]
ldp d10, d11, [sp, #32]
ldp d12, d13, [sp, #16]
ldp d14, d15, [sp], #64
ret

#endif

@@ -17,7 +17,10 @@ asm_function MNNReluWithSlopeChannel

//Auto Load:
//x0:dst, x1:src, x2:slope, x3:sizeQuad, x4:depthQuad

stp d14, d15, [sp, #-64]!
stp d12, d13, [sp, #16]
stp d10, d11, [sp, #32]
stp d8, d9, [sp, #48]

cmp x4, #0
beq PReluEnd

@@ -26,48 +29,164 @@ beq PReluEnd

PReluZLoop:
ld1 {v23.4s}, [x2], #16
ld1 {v31.4s}, [x2], #16
mov x5, x3

PReluL16:
cmp x5, #15
ble PReluL8

PReluL16Loop:
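// Per vector: fcmle writes an all-ones lane mask where x <= 0, fmul computes
// slope * x with the channel slope in v31, and bit copies the slope result
// into the source register only on masked lanes, i.e. dst = x > 0 ? x : slope * x.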
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x1], #64

fcmle v16.4s, v0.4s, #0
fcmle v17.4s, v1.4s, #0
fcmle v18.4s, v2.4s, #0
fcmle v19.4s, v3.4s, #0
fcmle v20.4s, v4.4s, #0
fcmle v21.4s, v5.4s, #0
fcmle v22.4s, v6.4s, #0
fcmle v23.4s, v7.4s, #0

fmul v8.4s, v0.4s, v31.4s
fmul v9.4s, v1.4s, v31.4s
fmul v10.4s, v2.4s, v31.4s
fmul v11.4s, v3.4s, v31.4s
fmul v12.4s, v4.4s, v31.4s
fmul v13.4s, v5.4s, v31.4s
fmul v14.4s, v6.4s, v31.4s
fmul v15.4s, v7.4s, v31.4s

fcmle v28.4s, v24.4s, #0
fcmle v29.4s, v25.4s, #0
fcmle v30.4s, v26.4s, #0

bit v0.16b, v8.16b, v16.16b
bit v1.16b, v9.16b, v17.16b
bit v2.16b, v10.16b, v18.16b
bit v3.16b, v11.16b, v19.16b
bit v4.16b, v12.16b, v20.16b
bit v5.16b, v13.16b, v21.16b
bit v6.16b, v14.16b, v22.16b
bit v7.16b, v15.16b, v23.16b

st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64
fcmle v8.4s, v27.4s, #0
fmul v9.4s, v24.4s, v31.4s
fmul v10.4s, v25.4s, v31.4s
fmul v11.4s, v26.4s, v31.4s
fmul v12.4s, v27.4s, v31.4s
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64

fcmle v13.4s, v16.4s, #0
fcmle v14.4s, v17.4s, #0
fcmle v15.4s, v18.4s, #0
fcmle v0.4s, v19.4s, #0

fmul v20.4s, v16.4s, v31.4s
fmul v21.4s, v17.4s, v31.4s
fmul v22.4s, v18.4s, v31.4s
fmul v23.4s, v19.4s, v31.4s

bit v24.16b, v9.16b, v28.16b
bit v25.16b, v10.16b, v29.16b
bit v26.16b, v11.16b, v30.16b
bit v27.16b, v12.16b, v8.16b
bit v16.16b, v20.16b, v13.16b
bit v17.16b, v21.16b, v14.16b
bit v18.16b, v22.16b, v15.16b
bit v19.16b, v23.16b, v0.16b

st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64

sub x5, x5, #16
cmp x5, #16
bge PReluL16Loop

PReluL8:
cmp x5, #7
ble PReluL4

PReluL8Loop:
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64

fcmle v16.4s, v0.4s, #0
fcmle v17.4s, v1.4s, #0
fcmle v18.4s, v2.4s, #0
fcmle v19.4s, v3.4s, #0
fcmle v20.4s, v4.4s, #0
fcmle v21.4s, v5.4s, #0
fcmle v22.4s, v6.4s, #0
fcmle v23.4s, v7.4s, #0

fmul v8.4s, v0.4s, v31.4s
fmul v9.4s, v1.4s, v31.4s
fmul v10.4s, v2.4s, v31.4s
fmul v11.4s, v3.4s, v31.4s
fmul v12.4s, v4.4s, v31.4s
fmul v13.4s, v5.4s, v31.4s
fmul v14.4s, v6.4s, v31.4s
fmul v15.4s, v7.4s, v31.4s

bit v0.16b, v8.16b, v16.16b
bit v1.16b, v9.16b, v17.16b
bit v2.16b, v10.16b, v18.16b
bit v3.16b, v11.16b, v19.16b
bit v4.16b, v12.16b, v20.16b
bit v5.16b, v13.16b, v21.16b
bit v6.16b, v14.16b, v22.16b
bit v7.16b, v15.16b, v23.16b

st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64

sub x5, x5, #8
cmp x5, #8
bge PReluL8Loop

PReluL4:
cmp x5, #3
ble PReluL1

PReluL4Loop:
ld1 {v0.4s, v1.4s}, [x1], #32
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64

fcmle v20.4s, v0.4s, #0
fcmle v21.4s, v1.4s, #0
fcmle v8.4s, v0.4s, #0
fcmle v9.4s, v1.4s, #0
fcmle v10.4s, v2.4s, #0
fcmle v11.4s, v3.4s, #0

ld1 {v2.4s, v3.4s}, [x1], #32
fmul v4.4s, v0.4s, v31.4s
fmul v5.4s, v1.4s, v31.4s
fmul v6.4s, v2.4s, v31.4s
fmul v7.4s, v3.4s, v31.4s

fmul v16.4s, v0.4s, v23.4s
fmul v17.4s, v1.4s, v23.4s
bit v0.16b, v16.16b, v20.16b
bit v1.16b, v17.16b, v21.16b
bit v0.16b, v4.16b, v8.16b
bit v1.16b, v5.16b, v9.16b
bit v2.16b, v6.16b, v10.16b
bit v3.16b, v7.16b, v11.16b

fmul v16.4s, v2.4s, v23.4s
fmul v17.4s, v3.4s, v23.4s
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64

st1 {v0.4s, v1.4s}, [x0], #32

fcmle v20.4s, v2.4s, #0
fcmle v21.4s, v3.4s, #0
bit v2.16b, v16.16b, v20.16b
bit v3.16b, v17.16b, v21.16b

st1 {v2.4s, v3.4s}, [x0], #32
sub x5, x5, #4
cmp x5, #4
bge PReluL4Loop

PReluL1:
cmp x5, #0

beq PReluL1End

PReluL1Loop:
ld1 {v0.4s}, [x1], #16
fcmle v2.4s, v0.4s, #0
fmul v1.4s, v0.4s, v23.4s
fmul v1.4s, v0.4s, v31.4s
bit v0.16b, v1.16b, v2.16b
st1 {v0.4s}, [x0], #16
subs x5, x5, #1

@@ -80,6 +199,10 @@ bne PReluZLoop

PReluEnd:
ldp d8, d9, [sp, #48]
ldp d10, d11, [sp, #32]
ldp d12, d13, [sp, #16]
ldp d14, d15, [sp], #64

ret
#endif
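For reference, a scalar sketch of the PRelu contract. Names are assumed; the layout is NC4HW4, with one 4-float slope vector per depth quad as loaded from x2 above:

#include <stddef.h>

// Hypothetical scalar reference for MNNReluWithSlopeChannel.
static void preluRef(float* dst, const float* src, const float* slope,
                     size_t sizeQuad, size_t depthQuad) {
    for (size_t z = 0; z < depthQuad; ++z) {
        for (size_t p = 0; p < sizeQuad; ++p) {
            for (int k = 0; k < 4; ++k) {
                size_t idx = (z * sizeQuad + p) * 4 + k;
                float x = src[idx];
                dst[idx] = x > 0.0f ? x : x * slope[z * 4 + k];
            }
        }
    }
}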

@@ -16,6 +16,10 @@ asm_function MNNScaleAndAddBias
//void MNNScaleAndAddBias(float* dst, const float* src, const float* bias, const float* alpha, size_t planeNumber, size_t biasNumber)

//Auto: x0:dst, x1:src, x2:bias, x3:alpha, x4:planeNumber, x5:biasNumber
stp d14, d15, [sp, #-64]!
stp d12, d13, [sp, #16]
stp d10, d11, [sp, #32]
stp d8, d9, [sp, #48]

cmp x4, #0
beq BSEnd

@@ -27,8 +31,158 @@ BSLoopZ:
mov x6, x4
ld1 {v31.4s}, [x2], #16
ld1 {v30.4s}, [x3], #16

BSL32:
cmp x6, #31
ble BSL16
BSLoopP32:
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x1], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x1], #64
fmul v0.4s, v0.4s, v30.4s
fmul v1.4s, v1.4s, v30.4s
fmul v2.4s, v2.4s, v30.4s
fmul v3.4s, v3.4s, v30.4s
fmul v4.4s, v4.4s, v30.4s
fmul v5.4s, v5.4s, v30.4s
fmul v6.4s, v6.4s, v30.4s
fmul v7.4s, v7.4s, v30.4s
sub x6, x6, #32
fmul v8.4s, v8.4s, v30.4s
fmul v9.4s, v9.4s, v30.4s
fmul v10.4s, v10.4s, v30.4s
fmul v11.4s, v11.4s, v30.4s
fmul v12.4s, v12.4s, v30.4s
fmul v13.4s, v13.4s, v30.4s
fmul v14.4s, v14.4s, v30.4s
fmul v15.4s, v15.4s, v30.4s

fmul v16.4s, v16.4s, v30.4s
fmul v17.4s, v17.4s, v30.4s
fmul v18.4s, v18.4s, v30.4s
fmul v19.4s, v19.4s, v30.4s
fmul v20.4s, v20.4s, v30.4s
fmul v21.4s, v21.4s, v30.4s
fmul v22.4s, v22.4s, v30.4s
fmul v23.4s, v23.4s, v30.4s

fmul v24.4s, v24.4s, v30.4s
fmul v25.4s, v25.4s, v30.4s
fmul v26.4s, v26.4s, v30.4s
fmul v27.4s, v27.4s, v30.4s

fadd v0.4s, v0.4s, v31.4s
fadd v1.4s, v1.4s, v31.4s
fadd v2.4s, v2.4s, v31.4s
fadd v3.4s, v3.4s, v31.4s
fadd v4.4s, v4.4s, v31.4s
fadd v5.4s, v5.4s, v31.4s
fadd v6.4s, v6.4s, v31.4s
fadd v7.4s, v7.4s, v31.4s
cmp x6, #32
fadd v8.4s, v8.4s, v31.4s
fadd v9.4s, v9.4s, v31.4s
fadd v10.4s, v10.4s, v31.4s
fadd v11.4s, v11.4s, v31.4s
fadd v12.4s, v12.4s, v31.4s
fadd v13.4s, v13.4s, v31.4s
fadd v14.4s, v14.4s, v31.4s
fadd v15.4s, v15.4s, v31.4s
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64

fadd v16.4s, v16.4s, v31.4s
fadd v17.4s, v17.4s, v31.4s
fadd v18.4s, v18.4s, v31.4s
fadd v19.4s, v19.4s, v31.4s
fadd v20.4s, v20.4s, v31.4s
fadd v21.4s, v21.4s, v31.4s
fadd v22.4s, v22.4s, v31.4s
fadd v23.4s, v23.4s, v31.4s
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64

fadd v24.4s, v24.4s, v31.4s
fadd v25.4s, v25.4s, v31.4s
fadd v26.4s, v26.4s, v31.4s
fadd v27.4s, v27.4s, v31.4s

fmul v0.4s, v0.4s, v30.4s
fmul v1.4s, v1.4s, v30.4s
fmul v2.4s, v2.4s, v30.4s
fmul v3.4s, v3.4s, v30.4s

st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64
fadd v0.4s, v0.4s, v31.4s
fadd v1.4s, v1.4s, v31.4s
fadd v2.4s, v2.4s, v31.4s
fadd v3.4s, v3.4s, v31.4s
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64
st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64

st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64

bge BSLoopP32

BSL16:
cmp x6, #15
ble BSL8_
BSLoopP16:
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64
fmul v0.4s, v0.4s, v30.4s
fmul v1.4s, v1.4s, v30.4s
fmul v2.4s, v2.4s, v30.4s
fmul v3.4s, v3.4s, v30.4s
fmul v4.4s, v4.4s, v30.4s
fmul v5.4s, v5.4s, v30.4s
fmul v6.4s, v6.4s, v30.4s
fmul v7.4s, v7.4s, v30.4s
sub x6, x6, #16
fmul v8.4s, v8.4s, v30.4s
fmul v9.4s, v9.4s, v30.4s
fmul v10.4s, v10.4s, v30.4s
fmul v11.4s, v11.4s, v30.4s
fmul v12.4s, v12.4s, v30.4s
fmul v13.4s, v13.4s, v30.4s
fmul v14.4s, v14.4s, v30.4s
fmul v15.4s, v15.4s, v30.4s

fadd v0.4s, v0.4s, v31.4s
fadd v1.4s, v1.4s, v31.4s
fadd v2.4s, v2.4s, v31.4s
fadd v3.4s, v3.4s, v31.4s
fadd v4.4s, v4.4s, v31.4s
fadd v5.4s, v5.4s, v31.4s
fadd v6.4s, v6.4s, v31.4s
fadd v7.4s, v7.4s, v31.4s
cmp x6, #16
fadd v8.4s, v8.4s, v31.4s
fadd v9.4s, v9.4s, v31.4s
fadd v10.4s, v10.4s, v31.4s
fadd v11.4s, v11.4s, v31.4s
fadd v12.4s, v12.4s, v31.4s
fadd v13.4s, v13.4s, v31.4s
fadd v14.4s, v14.4s, v31.4s
fadd v15.4s, v15.4s, v31.4s

st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64

bge BSLoopP16

BSL8_:
cmp x6, #7
ble BSLoopP1
ble BSL1
BSLoopP8:
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64
fmul v0.4s, v0.4s, v30.4s

@@ -53,7 +207,7 @@ BSLoopZ:
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64
cmp x6, #8
bge BSLoopP8

BSL1:
cmp x6, #0
beq BSLoopPEnd

@@ -71,7 +225,10 @@ BSLoopZ:

BSEnd:

ldp d8, d9, [sp, #48]
ldp d10, d11, [sp, #32]
ldp d12, d13, [sp, #16]
ldp d14, d15, [sp], #64

ret
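For reference, a scalar sketch of the routine's contract, following the signature in the comment above (v30 holds the alpha quad and v31 the bias quad for the current channel; NC4HW4 layout is assumed):

#include <stddef.h>

// Hypothetical scalar reference for MNNScaleAndAddBias.
static void scaleAndAddBiasRef(float* dst, const float* src, const float* bias,
                               const float* alpha, size_t planeNumber,
                               size_t biasNumber) {
    for (size_t z = 0; z < biasNumber; ++z) {
        for (size_t p = 0; p < planeNumber; ++p) {
            for (int k = 0; k < 4; ++k) {
                size_t idx = (z * planeNumber + p) * 4 + k;
                dst[idx] = src[idx] * alpha[z * 4 + k] + bias[z * 4 + k];
            }
        }
    }
}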

@@ -148,7 +148,7 @@ WinogradConfig ConvolutionPackWinograd::bestWinogradUnit(const Convolution2DComm
int oc = outputTensor->channel();
int ePack, hPack, lPack;
core->MNNGetMatMulPackMode(&ePack, &lPack, &hPack);
int unit2 = UP_DIV(ow * oh, ePack * threadNumber);
int unit2 = UP_DIV(ow * oh, threadNumber);
int maxUnit = (int)::sqrtf((float)unit2);
maxUnit = std::min(maxUnit, CONVOLUTION_WINOGRAD_MAX_UNIT);
maxUnit = std::max(maxUnit, CONVOLUTION_WINOGRAD_MIN_UNIT);
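// unit2 estimates how many output pixels each thread gets; its square root is
// the candidate (square) Winograd unit, clamped to the supported MIN/MAX range.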

@@ -257,7 +257,7 @@ public:
auto srcStride0 = cmd->view()->GetAs<View>(1)->stride()->data();
auto srcStride1 = cmd->view()->GetAs<View>(2)->stride()->data();
auto dstStride = cmd->view()->GetAs<View>(0)->stride()->data();
//MNN_PRINT("Binary Loop in optype:%d\n", opType);
// MNN_PRINT("Binary Loop in optype:%d\n", opType);
BinaryBlit((uint8_t*)dst, (const uint8_t*)src0, (const uint8_t*)src1,
cmd->size()->data(), srcStride0, srcStride1, dstStride, type, runtime, opType);

@@ -127,7 +127,7 @@ __global__ void FLOAT##Name(const T *input, T *output,\
}\

template<typename T>
__global__ void blit_2(const T *input, T *output,
__global__ void blit_2_float(const T *input, T *output,
int count,
DivModFast sizeZ, DivModFast sizeY, DivModFast sizeX,
int strideZ, int strideY,

@@ -143,6 +143,23 @@ __global__ void blit_2(const T *input, T *output,
dstF[0] = ((int2 *)(input+srcOffset))[0];
}
}
template<typename T>
__global__ void blit_2_half(const T *input, T *output,
int count,
DivModFast sizeZ, DivModFast sizeY, DivModFast sizeX,
int strideZ, int strideY,
int dstStrideZ, int dstStrideY
) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
int ix, tmp, iy, iz;
sizeX.divmod(i, tmp, ix);
sizeY.divmod(tmp, iz, iy);
int srcOffset = iz * strideZ + iy * strideY + (ix << 1);
int dstOffset = iz * dstStrideZ + iy * dstStrideY + (ix << 1);
int* dstF = (int *)(output+dstOffset);
dstF[0] = ((int *)(input+srcOffset))[0];
}
}
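// blit_2_float moves two floats per iteration through one int2 (8-byte)
// load/store; blit_2_half moves two halves through one int (4 bytes). Both
// assume the innermost stride is 1, hence the (ix << 1) element offset.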

struct Bytes512 {
int4 x[4];

@@ -186,26 +203,73 @@ UNARY_FUNC(GELU_STANDARD, (erf(x*0.7071067932881648f)+1.f)*x*0.5);
void RasterBlit(uint8_t* output, const uint8_t* input, const int32_t* size, const int32_t* srcStride, const int32_t* dstStride, int bytes, CUDARuntime* runtime) {
int count = size[0] * size[1] * size[2];

// MNN_PRINT("blit info size:%d-%d-%d, srcStride:%d-%d-%d, dstStride:%d-%d-%d\n", size[0], size[1], size[2], srcStride[0], srcStride[1], srcStride[2], dstStride[0], dstStride[1], dstStride[2]);
bool isThirdSizeVector = (size[2] % 2 == 0 && srcStride[2] == 1 && dstStride[2] == 1);
bool isSecondSizeVector = (size[1] % 2 == 0 && srcStride[1] == 1 && dstStride[1] == 1) && (size[2] == 1 && srcStride[2] == 1 && dstStride[2] == 1);
bool isFirstSizeVector = (size[0] % 2 == 0 && srcStride[0] == 1 && dstStride[0] == 1) && (size[1] == 1 && srcStride[1] == 1 && dstStride[1] == 1) && (size[2] == 1 && srcStride[2] == 1 && dstStride[2] == 1);
bool isSizeVector = isThirdSizeVector || isSecondSizeVector || isFirstSizeVector;
if(count > 16384 && isSizeVector) {
int32_t newSize[3], newSrcStride[3], newDstStride[3];
newSize[0] = size[0];
newSize[1] = size[1];
newSize[2] = size[2];
newSrcStride[0] = srcStride[0];
newSrcStride[1] = srcStride[1];
newSrcStride[2] = srcStride[2];
newDstStride[0] = dstStride[0];
newDstStride[1] = dstStride[1];
newDstStride[2] = dstStride[2];
if(isSecondSizeVector) {
/* size : [size_0, size_1, 1] srcStride : [ss_0, 1, 1] dstStride : [ds_0, 1, 1]
--> newSize: [1, size_0, size_1] newSrcStride: [1, ss_0, 1] newDstStride: [1, ds_0, 1]
*/
newSize[2] = size[1];
newSize[1] = size[0];
newSize[0] = 1;
newSrcStride[1] = srcStride[0];
newSrcStride[0] = 1;
newDstStride[1] = dstStride[0];
newDstStride[0] = 1;
}
if(isFirstSizeVector) {
/* size : [size_0, 1, 1] srcStride : [1, 1, 1] dstStride : [1, 1, 1]
--> newSize: [1, 1, size_0] newSrcStride: [1, 1, 1] newDstStride: [1, 1, 1]
*/
newSize[2] = size[0];
newSize[0] = 1;
}
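// The rewrites above fold a contiguous outer axis into the innermost slot so
// that the 2-element vector kernels still apply, e.g. size [N, 1, 1] with
// unit strides becomes [1, 1, N].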

DivModFast new_sz(newSize[0]);
DivModFast new_sy(newSize[1]);
DivModFast new_sx(newSize[2]/2);

int newCount = count / 2;
int block_num = runtime->blocks_num(newCount);
int threads_num = runtime->threads_num();

// Forbid address misalignment
if(bytes == 4 && reinterpret_cast<std::uintptr_t>(input) % 8 == 0 && reinterpret_cast<std::uintptr_t>(output) % 8 == 0) {
blit_2_float<<<block_num, threads_num>>>((const float*)input, (float*)output,
newCount,
new_sz, new_sy, new_sx,
newSrcStride[0], newSrcStride[1],
newDstStride[0], newDstStride[1]);
checkKernelErrors;
return;
} else if(bytes == 2 && reinterpret_cast<std::uintptr_t>(input) % 4 == 0 && reinterpret_cast<std::uintptr_t>(output) % 4 == 0) {
blit_2_half<<<block_num, threads_num>>>((const half*)input, (half*)output,
newCount,
new_sz, new_sy, new_sx,
newSrcStride[0], newSrcStride[1],
newDstStride[0], newDstStride[1]);
checkKernelErrors;
return;
}
}

DivModFast sz(size[0]);
DivModFast sy(size[1]);
DivModFast sx(size[2]);

// MNN_PRINT("blit info size:%d-%d-%d, srcStride:%d-%d-%d, dstStride:%d-%d-%d\n", size[0], size[1], size[2], srcStride[0], srcStride[1], srcStride[2], dstStride[0], dstStride[1], dstStride[2]);
if(bytes == 4 && count > 16384 && size[2] % 2 == 0 && srcStride[2] == 1 && dstStride[2] == 1) {
//printf("%d-%d-%d, %d-%d-%d,-%d-%d-%d\n\n", size[0], size[1], size[2], srcStride[0], srcStride[1], srcStride[2], dstStride[0], dstStride[1], dstStride[2]);
count /= 2;
int block_num = runtime->blocks_num(count);
int threads_num = runtime->threads_num();
DivModFast sx_2((size[2]/2));

blit_2<<<block_num, threads_num>>>((const float*)input, (float*)output,
count,
sz, sy, sx_2,
srcStride[0], srcStride[1],
dstStride[0], dstStride[1]);
return;
}

int block_num = runtime->blocks_num(count);
int threads_num = runtime->threads_num();

@@ -225,7 +289,7 @@ void RasterBlit(uint8_t* output, const uint8_t* input, const int32_t* size, cons
dstStride[0], dstStride[1], dstStride[2]);
break;
case 4:
blit<<<block_num, threads_num>>>((const float*)input, (float*)output,
count,
sz, sy, sx,
srcStride[0], srcStride[1], srcStride[2],

@@ -248,6 +312,7 @@ void RasterBlit(uint8_t* output, const uint8_t* input, const int32_t* size, cons
default:
break;
}
checkKernelErrors;
}

template<typename T0, typename T1>
@@ -538,7 +603,6 @@ __global__ void Binary##Name(\
) { \
int count = sizeZ * sizeY * sizeX;\
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {\
int total = sizeZ * sizeY * sizeX;\
int ix = i % sizeX;\
int tmp = i / sizeX;\
int iy = tmp % sizeY;\

@@ -563,15 +627,14 @@ __global__ void BinaryMid##Name(\
int sizeZ, int sizeY, int sizeX,\
int strideZ, int strideY, int strideX,\
int strideZ1, int strideY1, int strideX1,\
int dstStrideZ, int dstStrideY, int dstStrideX, int activationType, int bytes\
int dstStrideZ, int dstStrideY, int dstStrideX, int activationType,\
DivModFast d_sizeY, DivModFast d_sizeX\
) { \
int count = sizeZ * sizeY * sizeX;\
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {\
int total = sizeZ * sizeY * sizeX;\
int ix = i % sizeX;\
int tmp = i / sizeX;\
int iy = tmp % sizeY;\
int iz = tmp / sizeY;\
int ix, tmp, iy, iz;\
d_sizeX.divmod(i, tmp, ix);\
d_sizeY.divmod(tmp, iz, iy);\
int srcOffset = iz * strideZ + iy * strideY + ix * strideX;\
int srcOffset1 = iz * strideZ1 + iy * strideY1 + ix * strideX1;\
int dstOffset = iz * dstStrideZ + iy * dstStrideY + ix * dstStrideX;\

@@ -581,11 +644,95 @@ __global__ void BinaryMid##Name(\
if(activationType == 1) {\
val = (val < 0.0f ? 0.0f : val);\
}\
if(bytes == 2) {\
val = min(val, 65504.0f);\
val = max(val, -65504.0f);\
output[dstOffset] = val;\
}\
}\
template<typename TIn, typename TOut>\
__global__ void BinaryMid4_##Name(\
const TIn *input0, const TIn* input1, TOut *output,\
int sizeZ, int sizeY, int sizeX,\
int strideZ, int strideY,\
int strideZ1, int strideY1,\
int dstStrideZ, int dstStrideY, int activationType,\
DivModFast d_sizeY, DivModFast d_sizeX,\
bool inp0Broadcast, bool inp1Broadcast\
) { \
int count = sizeZ * sizeY * sizeX;\
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {\
int ix, tmp, iy, iz;\
d_sizeX.divmod(i, tmp, ix);\
d_sizeY.divmod(tmp, iz, iy);\
ix = ix << 2;\
int srcOffset = iz * strideZ + iy * strideY + ix;\
int srcOffset1 = iz * strideZ1 + iy * strideY1 + ix;\
int dstOffset = iz * dstStrideZ + iy * dstStrideY + ix;\
float4 xx = inp0Broadcast ? make_float4(input0[srcOffset-ix],input0[srcOffset-ix], input0[srcOffset-ix], input0[srcOffset-ix]) : ((float4 *)(input0+srcOffset))[0];\
float4 yy = inp1Broadcast ? make_float4(input1[srcOffset1-ix],input1[srcOffset1-ix], input1[srcOffset1-ix], input1[srcOffset1-ix]) :((float4 *)(input1+srcOffset1))[0];\
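/* An innermost stride of 0 marks a broadcast operand: srcOffset - ix backs \
   out the lane offset added above, so all four lanes re-read the single     \
   scalar instead of reading past it. */\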
float x = xx.x;\
float y = yy.x;\
float val = (float)(Func);\
if(activationType == 1) {\
val = (val < 0.0f ? 0.0f : val);\
}\
output[dstOffset] = val;\
x = xx.y;\
y = yy.y;\
val = (float)(Func);\
if(activationType == 1) {\
val = (val < 0.0f ? 0.0f : val);\
}\
output[dstOffset+1] = val;\
x = xx.z;\
y = yy.z;\
val = (float)(Func);\
if(activationType == 1) {\
val = (val < 0.0f ? 0.0f : val);\
}\
output[dstOffset+2] = val;\
x = xx.w;\
y = yy.w;\
val = (float)(Func);\
if(activationType == 1) {\
val = (val < 0.0f ? 0.0f : val);\
}\
output[dstOffset+3] = val;\
}\
}\
template<typename TIn, typename TOut>\
__global__ void BinaryMidHalf2_##Name(\
const TIn *input0, const TIn* input1, TOut *output,\
int sizeZ, int sizeY, int sizeX,\
int strideZ, int strideY,\
int strideZ1, int strideY1,\
int dstStrideZ, int dstStrideY, int activationType,\
DivModFast d_sizeY, DivModFast d_sizeX,\
bool inp0Broadcast, bool inp1Broadcast\
) { \
int count = sizeZ * sizeY * sizeX;\
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {\
int ix, tmp, iy, iz;\
d_sizeX.divmod(i, tmp, ix);\
d_sizeY.divmod(tmp, iz, iy);\
ix = ix << 1;\
int srcOffset = iz * strideZ + iy * strideY + ix;\
int srcOffset1 = iz * strideZ1 + iy * strideY1 + ix;\
int dstOffset = iz * dstStrideZ + iy * dstStrideY + ix;\
half2 xx = inp0Broadcast ? make_half2(input0[srcOffset-ix], input0[srcOffset-ix]) : ((half2 *)(input0+srcOffset))[0];\
half2 yy = inp1Broadcast ? make_half2(input1[srcOffset1-ix], input1[srcOffset1-ix]) : ((half2 *)(input1+srcOffset1))[0];\
float x = xx.x;\
float y = yy.x;\
float val = (float)(Func);\
if(activationType == 1) {\
val = (val < 0.0f ? 0.0f : val);\
}\
output[dstOffset] = val;\
x = xx.y;\
y = yy.y;\
val = (float)(Func);\
if(activationType == 1) {\
val = (val < 0.0f ? 0.0f : val);\
}\
output[dstOffset+1] = val;\
}\
}\
template<typename TIn, typename TOut>\

@@ -595,8 +742,7 @@ __global__ void BinaryMidLinear##Name(\
int strideZ,\
int strideZ1,\
int dstStrideZ,\
int activationType,\
int bytes\
int activationType\
) { \
int count = sizeZ;\
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {\

@@ -610,10 +756,6 @@ __global__ void BinaryMidLinear##Name(\
if(activationType == 1) {\
val = (val < 0.0f ? 0.0f : val);\
}\
if(bytes == 2) {\
val = min(val, 65504.0f);\
val = max(val, -65504.0f);\
}\
output[dstOffset] = (TOut)val;\
}\
}\

@@ -622,15 +764,16 @@ __global__ void BinaryMidLinear##Name(\
template<typename TIn, typename TOut>\
__global__ void BinaryMidLinear4_##Name(\
const TIn *input0, const TIn* input1, TOut *output,\
int count_4, int activationType\
int count_4, int activationType,\
bool inp0Broadcast, bool inp1Broadcast\
) { \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count_4); i += blockDim.x * gridDim.x) {\
int iz = i;\
int srcOffset = iz << 2;\
int srcOffset1 = iz << 2;\
int dstOffset = iz << 2;\
float4 xx = ((float4 *)(input0+srcOffset))[0];\
float4 yy = ((float4 *)(input1+srcOffset1))[0];\
float4 xx = inp0Broadcast ? make_float4(input0[0], input0[0], input0[0], input0[0]) : ((float4 *)(input0+srcOffset))[0];\
float4 yy = inp1Broadcast ? make_float4(input1[0], input1[0], input1[0], input1[0]) : ((float4 *)(input1+srcOffset1))[0];\
float x = xx.x;\
float y = yy.x;\
TOut val = (TOut)(Func);\

@@ -664,23 +807,22 @@ __global__ void BinaryMidLinear4_##Name(\
template<typename TIn, typename TOut>\
__global__ void BinaryMidLinearHalf4_##Name(\
const TIn *input0, const TIn* input1, TOut *output,\
int count_4, int activationType\
int count_4, int activationType,\
bool inp0Broadcast, bool inp1Broadcast\
) { \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count_4); i += blockDim.x * gridDim.x) {\
int iz = i;\
int srcOffset = iz << 2;\
int srcOffset1 = iz << 2;\
int dstOffset = iz << 2;\
half2 xx = ((half2 *)(input0+srcOffset))[0];\
half2 yy = ((half2 *)(input1+srcOffset1))[0];\
half2 xx = inp0Broadcast ? make_half2(input0[0], input0[0]) : ((half2 *)(input0+srcOffset))[0];\
half2 yy = inp1Broadcast ? make_half2(input1[0], input1[0]) : ((half2 *)(input1+srcOffset1))[0];\
float x = (float)xx.x;\
float y = (float)yy.x;\
float val = (float)(Func);\
if(activationType == 1) {\
val = (val < 0.0f ? 0.0f : val);\
}\
val = min(val, 65504.0f);\
val = max(val, -65504.0f);\
output[dstOffset] = (TOut)val;\
x = (float)xx.y;\
y = (float)yy.y;\

@@ -688,19 +830,15 @@ __global__ void BinaryMidLinearHalf4_##Name(\
if(activationType == 1) {\
val = (val < 0.0f ? 0.0f : val);\
}\
val = min(val, 65504.0f);\
val = max(val, -65504.0f);\
output[dstOffset+1] = (TOut)val;\
xx = ((half2 *)(input0+srcOffset))[1];\
yy = ((half2 *)(input1+srcOffset1))[1];\
xx = inp0Broadcast ? make_half2(input0[0], input0[0]) : ((half2 *)(input0+srcOffset))[1];\
yy = inp1Broadcast ? make_half2(input1[0], input1[0]) : ((half2 *)(input1+srcOffset1))[1];\
x = (float)xx.x;\
y = (float)yy.x;\
val = (float)(Func);\
if(activationType == 1) {\
val = (val < 0.0f ? 0.0f : val);\
}\
val = min(val, 65504.0f);\
val = max(val, -65504.0f);\
output[dstOffset+2] = (TOut)val;\
x = (float)xx.y;\
y = (float)yy.y;\

@@ -708,8 +846,6 @@ __global__ void BinaryMidLinearHalf4_##Name(\
if(activationType == 1) {\
val = (val < 0.0f ? 0.0f : val);\
}\
val = min(val, 65504.0f);\
val = max(val, -65504.0f);\
output[dstOffset+3] = (TOut)val;\
}\
}\

@@ -784,18 +920,21 @@ void BinaryBlitTemplateFloat(T* output, const T* input, const T* input1, const i
int count = size[0] * size[1] * size[2];
int block_num = runtime->blocks_num(count);
int threads_num = runtime->threads_num();
// MNN_PRINT("binary :%d %d %d, %d %d %d, %d %d %d, %d %d %d, \n", size[0], size[1], size[2], srcStride[0], srcStride[1], srcStride[2], srcStride1[0], srcStride1[1], srcStride1[2], dstStride[0], dstStride[1], dstStride[2]);
#define COMPUTE_FLOAT(TYPE, TOut)\
if (opType == MNN::BinaryOpOperation_##TYPE ) {\
if (size[2] == count) {\
if(count % 4 == 0 && count > 16384 && srcStride[2] == 1 && srcStride1[2] == 1 && dstStride[2] == 1) {\
if(count % 4 == 0 && count > 16384 && (srcStride[2] == 0 || srcStride[2] == 1) && (srcStride1[2] == 0 || srcStride1[2] == 1) && dstStride[2] == 1) {\
block_num = runtime->blocks_num(count/4);\
threads_num = runtime->threads_num();\
bool srcBroadcast = srcStride[2] == 0;\
bool srcBroadcast1 = srcStride1[2] == 0;\
if(bytes == 4) {\
BinaryMidLinear4_##TYPE<<<block_num, threads_num>>>((const T*)input, (const T*)(input1), (TOut*)output,\
count/4, activationType);\
count/4, activationType, srcBroadcast, srcBroadcast1);\
} else {\
BinaryMidLinearHalf4_##TYPE<<<block_num, threads_num>>>((const T*)input, (const T*)(input1), (TOut*)output,\
count/4, activationType);\
count/4, activationType, srcBroadcast, srcBroadcast1);\
}\
} else {\
BinaryMidLinear##TYPE<<<block_num, threads_num>>>((const T*)input, (const T*)(input1), (TOut*)output,\

@@ -803,14 +942,41 @@ void BinaryBlitTemplateFloat(T* output, const T* input, const T* input1, const i
srcStride[2],\
srcStride1[2],\
dstStride[2],\
activationType, bytes);\
activationType);\
}\
} else {\
BinaryMid##TYPE<<<block_num, threads_num>>>((const T*)input, (const T*)(input1), (TOut*)output,\
size[0], size[1], size[2],\
srcStride[0], srcStride[1], srcStride[2],\
srcStride1[0], srcStride1[1], srcStride1[2],\
dstStride[0], dstStride[1], dstStride[2], activationType, bytes);\
bool isVectorSizeZ = (size[0] == 1 || ((srcStride[2] == 0 || srcStride[0] % bytes == 0) && (srcStride1[2] == 0 || srcStride1[0] % bytes == 0) && dstStride[0] % bytes == 0));\
bool isVectorSizeY = (size[1] == 1 || ((srcStride[2] == 0 || srcStride[1] % bytes == 0) && (srcStride1[2] == 0 || srcStride1[1] % bytes == 0) && dstStride[1] % bytes == 0));\
bool isVector4 = size[2] % bytes == 0 && isVectorSizeZ && isVectorSizeY;\
if(isVector4 && count > 16384 && (srcStride[2] == 0 || srcStride[2] == 1) && (srcStride1[2] == 0 || srcStride1[2] == 1) && dstStride[2] == 1) {\
block_num = runtime->blocks_num(count/bytes);\
threads_num = runtime->threads_num();\
DivModFast sy(size[1]);\
DivModFast sx(size[2]/bytes);\
bool srcBroadcast = srcStride[2] == 0;\
bool srcBroadcast1 = srcStride1[2] == 0;\
if(bytes == 4) {\
BinaryMid4_##TYPE<<<block_num, threads_num>>>((const T*)input, (const T*)(input1), (TOut*)output,\
size[0], size[1], size[2]/4,\
srcStride[0], srcStride[1],\
srcStride1[0], srcStride1[1],\
dstStride[0], dstStride[1], activationType, sy, sx, srcBroadcast, srcBroadcast1);\
} else {\
BinaryMidHalf2_##TYPE<<<block_num, threads_num>>>((const T*)input, (const T*)(input1), (TOut*)output,\
size[0], size[1], size[2]/2,\
srcStride[0], srcStride[1],\
srcStride1[0], srcStride1[1],\
dstStride[0], dstStride[1], activationType, sy, sx, srcBroadcast, srcBroadcast1);\
}\
} else {\
DivModFast sy(size[1]);\
DivModFast sx(size[2]);\
BinaryMid##TYPE<<<block_num, threads_num>>>((const T*)input, (const T*)(input1), (TOut*)output,\
size[0], size[1], size[2],\
srcStride[0], srcStride[1], srcStride[2],\
srcStride1[0], srcStride1[1], srcStride1[2],\
dstStride[0], dstStride[1], dstStride[2], activationType, sy, sx);\
}\
}\
return;\
}\

@@ -118,7 +118,7 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con
auto& slice0 = des->regions[0];
for (int i=0; i< des->regions.size(); ++i) {
auto& slice = des->regions[i];
// MNN_PRINT("%d-%d-%d-%d-%d\n", ____inputs[0]->batch(), ____inputs[0]->height(), ____inputs[0]->width(), ____inputs[0]->channel(), outputs[0]->channel());
// MNN_PRINT("%d-%d-%d-%d-%d\n", ____inputs[i]->batch(), ____inputs[i]->height(), ____inputs[i]->width(), ____inputs[i]->channel(), outputs[0]->channel());
// MNN_PRINT("%d-%d-%d, %d-%d-%d, %d-%d-%d, %d-%d\n\n", slice.size[0], slice.size[1], slice.size[2], slice.src.stride[0], slice.src.stride[1], slice.src.stride[2], slice.dst.stride[0], slice.dst.stride[1], slice.dst.stride[2], slice.src.offset, slice.dst.offset);
if (TensorUtils::getDescribe(slice.origin)->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) {
mFast = false;

@@ -132,13 +132,13 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con
mFast = false;
break;
}
}
// MNN_PRINT("raster fast:%d regionNum:%d\n\n\n", mFast, des->regions.size());
}
//MNN_PRINT("raster fast:%d regionNum:%d\n\n\n", mFast, des->regions.size());
if (mFast) {
int srcStep = 1;
int dstStep = 1;
for (int i=0; i< des->regions.size(); ++i) {
auto& slice = des->regions[i];
int srcStep = 1;
int dstStep = 1;
auto& slice = des->regions[i];
if(slice.dst.offset / (slice.size[2] * slice.size[1]) >= 1) {
int batchChannel = slice.dst.offset / (slice.size[1] * slice.size[2]) + 1;
dstStep = dstStep > batchChannel ? dstStep : batchChannel;

@@ -155,11 +155,12 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con
int tmp = slice.src.stride[0] / slice.dst.stride[0];
srcStep = srcStep > tmp ? srcStep : tmp;
}
}

for (int i=0; i< des->regions.size(); ++i) {
auto& slice = des->regions[i];
if (slice.origin == nullptr) {
if(____inputs[i]->channel() > slice.size[1]) {
int tmp = ____inputs[i]->channel() / slice.size[1];
srcStep = srcStep > tmp ? srcStep : tmp;
}

if (slice.origin == nullptr) {
continue;
}
Tensor::InsideDescribe::Region newRegion;

@@ -167,7 +168,7 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con
_turnToNewRegion(slice, newRegion, srcStep, dstStep);
mFastBlit.emplace_back(std::make_pair(slice.origin, std::move(newRegion)));
// MNN_PRINT("new step %d-%d:%d-%d-%d, %d-%d-%d, %d-%d-%d, %d-%d\n\n", srcStep, dstStep, newRegion.size[0], newRegion.size[1], newRegion.size[2], newRegion.src.stride[0], newRegion.src.stride[1], newRegion.src.stride[2], newRegion.dst.stride[0], newRegion.dst.stride[1], newRegion.dst.stride[2], newRegion.src.offset, newRegion.dst.offset);
}
}
return NO_ERROR;
}
}

@@ -299,19 +300,21 @@ void RasterExecution::executeFaster(const std::vector<Tensor *> &inputs, const s
if (mNeedZero) {
auto size = static_cast<CUDABackend*>(backend())->realSize(output) * bytes;
cudaMemset((uint8_t*)output->deviceId(), 0, size);
checkKernelErrors;
}
// Use mFastBlit
for (auto& iter : mFastBlit) {
auto srcPtr = (uint8_t*)iter.first->deviceId() + iter.second.src.offset * bytes;
auto dstPtr = (uint8_t*)output->deviceId() + iter.second.dst.offset * bytes;
RasterBlit(dstPtr, srcPtr, iter.second.size, iter.second.src.stride, iter.second.dst.stride, bytes, runtime);
checkKernelErrors;
}
}

ErrorCode RasterExecution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
if (mFast) {
executeFaster(inputs, outputs);
return NO_ERROR;
}
auto bn = static_cast<CUDABackend*>(backend());
auto input = outputs[0];

@@ -345,13 +348,15 @@ ErrorCode RasterExecution::onExecute(const std::vector<Tensor *> &inputs, const
cudaMemcpy(dstPtr, srcPtr, bn->realSize(realInput) * bytes, cudaMemcpyDeviceToDevice);
return NO_ERROR;
}
UnpackBuffer(dstPtr, srcPtr, &pack, bytes, runtime);
checkKernelErrors;
} else {
if (output->dimensions() <= 1) {
cudaMemcpy(dstPtr, srcPtr, bn->realSize(realInput) * bytes, cudaMemcpyDeviceToDevice);
return NO_ERROR;
}
PackBuffer(dstPtr, srcPtr, &pack, bytes, runtime);
checkKernelErrors;
}
return NO_ERROR;
}

@@ -359,9 +364,11 @@ ErrorCode RasterExecution::onExecute(const std::vector<Tensor *> &inputs, const
if (mNeedZero) {
auto size = static_cast<CUDABackend*>(backend())->realSize(mOutputPtr) * bytes;
cudaMemset((uint8_t*)mOutputPtr->deviceId(), 0, size);
checkKernelErrors;
}
for (auto& iter : mTempInput) {
backend()->onCopyBuffer(iter.first, iter.second);
checkKernelErrors;
}
//MNN_PRINT("\n%d\n", mFuseRaster.first);
if(mFuseRaster.first > 0) {

@@ -373,16 +380,19 @@ ErrorCode RasterExecution::onExecute(const std::vector<Tensor *> &inputs, const
//MNN_PRINT("fuseRaster:%p-%p\n", mSrcOffset, mDstOffset);

FuseRasterBlit(dstPtr, srcPtr, slice.size, slice.src.stride, slice.dst.stride, mFuseRaster.second, mOffset, bytes, runtime, mFuseRaster.first);
checkKernelErrors;
} else {
for (auto& iter : mTempInputCopy) {
auto srcPtr = (uint8_t*)iter.first->deviceId() + iter.second->src.offset * bytes;
auto dstPtr = (uint8_t*)mOutputPtr->deviceId() + iter.second->dst.offset * bytes;
RasterBlit(dstPtr, srcPtr, iter.second->size, iter.second->src.stride, iter.second->dst.stride, bytes, runtime);
checkKernelErrors;
}
}

if (nullptr != mTempOutput) {
backend()->onCopyBuffer(mTempOutput.get(), output);
checkKernelErrors;
}
return NO_ERROR;
}

@@ -209,12 +209,12 @@ ErrorCode SoftmaxExecution::onExecute(const std::vector<Tensor *> &inputs, const
int block_num = runtime->blocks_num(count);
int threads_num = runtime->threads_num();
if (static_cast<CUDABackend*>(backend())->useFp16()) {
if(axis % 256 == 0 || axis >= 768) {
block_num = count;
int calc_multi_num = (axis + 255) / 256;
SOFTMAX_AXIS_REDUCE<<<block_num, 256>>>((const half*)input, (half*)dst, inside, axis, 256, calc_multi_num, outside, count);
checkKernelErrors;
} else if(axis % 64 == 0 || axis >= 256) {
} else if(axis % 64 == 0 || axis > 32) {
block_num = count;
int calc_multi_num = (axis + 63) / 64;
SOFTMAX_AXIS_REDUCE<<<block_num, 64>>>((const half*)input, (half*)dst, inside, axis, 64, calc_multi_num, outside, count);

@@ -234,7 +234,7 @@ ErrorCode SoftmaxExecution::onExecute(const std::vector<Tensor *> &inputs, const
int calc_multi_num = (axis + 255) / 256;
SOFTMAX_AXIS_REDUCE<<<block_num, 256>>>((const float*)input, (float*)dst, inside, axis, 256, calc_multi_num, outside, count);
checkKernelErrors;
} else if(axis % 64 == 0 || axis >= 256) {
} else if(axis % 64 == 0 || axis > 32) {
block_num = count;
int calc_multi_num = (axis + 63) / 64;
SOFTMAX_AXIS_REDUCE<<<block_num, 64>>>((const float*)input, (float*)dst, inside, axis, 64, calc_multi_num, outside, count);
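// Kernel choice: long or 256-aligned reduce axes use the 256-thread variant;
// the relaxed `axis > 32` threshold now routes mid-sized axes to the
// 64-thread variant instead of the generic fallback.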

@@ -0,0 +1,312 @@
#include "TopKV2Execution.hpp"
#include <memory>


namespace MNN {
namespace CUDA {


// rank TopK in the corresponding thread
template<typename indexT, typename valueT>
__device__ void TopKInThread(const valueT * inputDevice, indexT * indicesThread, valueT * valuesThread, const int K, const int numElePerRow, const valueT minValue, const int descendFlag) {
for (int i = 0; i < K; i++) {
indicesThread[i] = -1;
valuesThread[i] = (valueT)(descendFlag) * minValue;
}

int idxFirstEleInRow = threadIdx.x + blockIdx.x * blockDim.x;

for (indexT i = idxFirstEleInRow; i < numElePerRow; i += gridDim.x * blockDim.x) {
valueT data = inputDevice[i];
if ((valueT)(descendFlag) * data <= (valueT)(descendFlag) * valuesThread[K - 1]) {
continue;
} else {
for (int j = K - 2; j >= 0; j--) {
if ((valueT)(descendFlag) * data > (valueT)(descendFlag) * valuesThread[j]) {
valuesThread[j + 1] = valuesThread[j];
indicesThread[j + 1] = indicesThread[j];
valuesThread[j] = data;
indicesThread[j] = i;
} else {
break;
}
}
}
}

return;
}
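// Each thread keeps its K best candidates as a sorted insertion list.
// descendFlag is +1 or -1: multiplying both sides of every comparison by it
// flips the ordering, so the same code serves largest-K and smallest-K.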
|
||||
|
||||
|
||||
// reduce TopK results of two offsets
|
||||
template<typename indexT, typename valueT>
|
||||
__device__ void ReduceTopK(indexT * indicesArray, valueT * valuesArray, const int offset1, const int offset2, const int K, const int descendFlag) {
|
||||
indexT idx1 = offset1 + K - 1;
|
||||
indexT idx2 = offset2 + K - 1;
|
||||
indexT idxVirtual = offset1 + 2 * K -1;
|
||||
|
||||
while (idx2 >= offset2) {
|
||||
if (idx1 < offset1) {
|
||||
while (idxVirtual >= offset1) {
|
||||
indicesArray[idxVirtual] = indicesArray[offset2 + (idxVirtual - offset1)];
|
||||
valuesArray[idxVirtual] = valuesArray[offset2 + (idxVirtual - offset1)];
|
||||
idxVirtual --;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
if ((valueT)(descendFlag) * valuesArray[idx1] <= (valueT)(descendFlag) * valuesArray[idx2]) {
|
||||
if (idxVirtual <= offset1 + K - 1) {
|
||||
indicesArray[idxVirtual] = indicesArray[idx1];
|
||||
valuesArray[idxVirtual] = valuesArray[idx1];
|
||||
}
|
||||
idx1 --;
|
||||
} else {
|
||||
if (idxVirtual <= offset1 + K - 1) {
|
||||
indicesArray[idxVirtual] = indicesArray[idx2];
|
||||
valuesArray[idxVirtual] = valuesArray[idx2];
|
||||
}
|
||||
idx2 --;
|
||||
}
|
||||
idxVirtual --;
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
// get results of all blocks' TopK in one row
|
||||
template<typename indexT, typename valueT>
|
||||
__device__ void TopKOneRow(const valueT * inputDevice, indexT * indicesBlock, valueT * valuesBlock, indexT * tempIndicesDevice, valueT * tempValuesDevice, const int K, const int lengthRow, valueT minValue, const int descendFlag) {
|
||||
indexT * indicesThread = indicesBlock + threadIdx.x * K;
|
||||
valueT * valuesThread = valuesBlock + threadIdx.x * K;
|
||||
|
||||
// rank TopK
|
||||
TopKInThread<indexT, valueT>(inputDevice, indicesThread, valuesThread, K, lengthRow, minValue, descendFlag);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// reduce
|
||||
for(int stride = (blockDim.x >> 1); stride > 0; stride >>= 1) {
|
||||
if(threadIdx.x < stride) {
|
||||
ReduceTopK<indexT, valueT>(indicesBlock, valuesBlock, threadIdx.x * K, (threadIdx.x + stride) * K, K, descendFlag);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// move data from block's smem to global memory(prepare for the next kernel function)
|
||||
if (threadIdx.x == 0) {
|
||||
for(int i = 0; i < K; i++) {
|
||||
tempIndicesDevice[K * blockIdx.x + i] = indicesBlock[i];
|
||||
tempValuesDevice[K * blockIdx.x + i] = valuesBlock[i];
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
// get results of the final TopK from all block's TopK in a row
|
||||
template<typename indexT, typename valueT>
|
||||
__device__ void GetResultOneRow(indexT * outputIndicesDevice, valueT * outputValuesDevice, indexT * tempIndicesDevice, valueT * tempValuesDevice, indexT * finalIndices, valueT * finalValues, const int K, const int reduceLength, const int descendFlag) {
|
||||
// move data from global memory to a block's smem
|
||||
if (threadIdx.x < reduceLength) {
|
||||
for (int i = 0; i < K; i++) {
|
||||
finalIndices[threadIdx.x * K + i] = tempIndicesDevice[threadIdx.x * K + i];
|
||||
finalValues[threadIdx.x * K + i] = tempValuesDevice[threadIdx.x * K + i];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// the first round of reducing needs special action
|
||||
int stride = blockDim.x >> 1;
|
||||
if ((threadIdx.x < stride) && (threadIdx.x + stride < reduceLength)) {
|
||||
ReduceTopK<indexT, valueT>(finalIndices, finalValues, threadIdx.x * K, (threadIdx.x + stride) * K, K, descendFlag);
|
||||
}
|
||||
__syncthreads();
|
||||
stride >>= 1;
|
||||
|
||||
// the remaining rounds of reducing
|
||||
for (; stride > 0; stride >>= 1) {
|
||||
if (threadIdx.x < stride) {
|
||||
ReduceTopK<indexT, valueT>(finalIndices, finalValues, threadIdx.x * K, (threadIdx.x + stride) * K, K, descendFlag);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
//move data from a block's smem to global memory
|
||||
if (threadIdx.x == 0) {
|
||||
for (int i = 0; i < K; i++) {
|
||||
outputIndicesDevice[i] = finalIndices[i];
|
||||
outputValuesDevice[i] = finalValues[i];
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
// allocate addresses for each row and call <TopKOneRow>
|
||||
template<typename indexT, typename valueT>
|
||||
__global__ void TopKAllRows(const valueT * inputDevice, indexT * tempIndicesDevice, valueT * tempValuesDevice, const int K, const int lengthRow, valueT minValue, const int descendFlag) {
|
||||
extern __shared__ char smem[];
|
||||
indexT * indicesBlock = reinterpret_cast<indexT *>(smem);
|
||||
valueT * valuesBlock = reinterpret_cast<valueT *>(&smem[blockDim.x * K * sizeof(indexT)]);
|
||||
|
||||
int idxRow = blockIdx.y;
|
||||
|
||||
const valueT * inputDeviceThisRow = inputDevice + idxRow * lengthRow;
|
||||
indexT * tempIndicesDeviceThisRow = tempIndicesDevice + idxRow * gridDim.x * K;
|
||||
valueT * tempValuesDeviceThisRow = tempValuesDevice + idxRow * gridDim.x * K;
|
||||
|
||||
TopKOneRow<indexT, valueT>(inputDeviceThisRow, indicesBlock, valuesBlock, tempIndicesDeviceThisRow, tempValuesDeviceThisRow, K, lengthRow, minValue, descendFlag);
|
||||
__syncthreads();
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
// allocate addresses for each row and call <GetResultOneRow>
|
||||
// This kernel assumes that each row of data corresponds to one block.
|
||||
template<typename indexT, typename valueT>
|
||||
__global__ void GetResultAllRows(indexT * outputIndicesDevice, valueT * outputValuesDevice, indexT * tempIndicesDevice, valueT * tempValuesDevice, const int K, const int numBlockPerRow, const int descendFlag) {
|
||||
extern __shared__ char smem[];
|
||||
indexT * finalIndices = reinterpret_cast<indexT *>(smem);
|
||||
valueT * finalValues = reinterpret_cast<valueT *>(&smem[numBlockPerRow * K * sizeof(indexT)]);
|
||||
|
||||
int idxRow = blockIdx.x; // each block corresponds to a row
|
||||
|
||||
indexT * outputIndicesDeviceThisRow = outputIndicesDevice + idxRow * K;
|
||||
valueT * outputValuesDeviceThisRow = outputValuesDevice + idxRow * K;
|
||||
indexT * tempIndicesDeviceThisRow = tempIndicesDevice + idxRow * numBlockPerRow * K;
|
||||
valueT * tempValuesDeviceThisRow = tempValuesDevice + idxRow * numBlockPerRow * K;
|
||||
|
||||
GetResultOneRow<indexT, valueT>(outputIndicesDeviceThisRow, outputValuesDeviceThisRow, tempIndicesDeviceThisRow, tempValuesDeviceThisRow, finalIndices, finalValues, K, numBlockPerRow, descendFlag);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
int CalculateNumThreadPerBlock(const int K) {
|
||||
int numThreadPerBlock;
|
||||
if (K <= 48) {
|
||||
numThreadPerBlock = 128;
|
||||
} else if (K <= 96) {
|
||||
numThreadPerBlock = 64;
|
||||
} else {
|
||||
numThreadPerBlock = 32;
|
||||
}
|
||||
|
||||
return numThreadPerBlock;
|
||||
}


TopKV2Execution::TopKV2Execution(const Op* op, Backend* backend) : Execution(backend) {
    mOp = op;
}


ErrorCode TopKV2Execution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    // prepare some params for the kernel function
    Tensor * inputTensor = inputs[0];

    int lengthRow = inputTensor->buffer().dim[inputTensor->buffer().dimensions - 1].extent;
    int numRow = inputTensor->elementSize() / lengthRow;

    mParams.mLengthRow = lengthRow;
    mParams.mNumRow = numRow;

    auto boolDescendFlag = mOp->main_as_TopKV2();
    if (boolDescendFlag != nullptr) {
        mParams.mDescendFlag = boolDescendFlag ? 1 : -1;
    }
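    // Editor's note: this branch is only entered when the pointer is non-null,
    // so the ternary above always evaluates to 1; presumably a flag read from
    // the TopKV2 parameters (an ascending/descending field) was intended here.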

    mParams.mNumElePerRow = mParams.mLengthRow;
    mParams.mNumK = outputs[0]->buffer().dim[outputs[0]->buffer().dimensions-1].extent;
    mParams.mNumElePerThread = mParams.mNumK;
    mParams.mNumThreadPerBlock = CalculateNumThreadPerBlock(mParams.mNumK);
    mParams.mNumElePerBlock = mParams.mNumElePerThread * mParams.mNumThreadPerBlock;
    mParams.mNumBlockPerRow = (mParams.mNumElePerRow - 1 + mParams.mNumElePerBlock) / mParams.mNumElePerBlock;
    mParams.mNumBlockFinal = mParams.mNumRow;
    mParams.mNumThreadFinal = std::pow(2, (std::ceil(std::log2(mParams.mNumBlockPerRow))));
    mParams.mNumBlockTotal = mParams.mNumBlockPerRow * mParams.mNumRow;

    // prepare temp buffer
    auto pool = static_cast<CUDABackend*>(backend())->getStaticBufferPool();

    if (inputTensor->getType().code == halide_type_int && inputTensor->getType().bits == 32) {
        std::pair<void*, int> bufferIndices = pool->alloc(mParams.mNumBlockTotal * mParams.mNumK * sizeof(int));
        mParams.mBufferIndices = (void*)((uint8_t*)bufferIndices.first + bufferIndices.second);
        std::pair<void*, int> bufferValues = pool->alloc(mParams.mNumBlockTotal * mParams.mNumK * sizeof(int));
        mParams.mBufferValues = (void*)((uint8_t*)bufferValues.first + bufferValues.second);
        pool->free(bufferIndices);
        pool->free(bufferValues);
    } else if (static_cast<CUDABackend*>(backend())->useFp16()) {
        std::pair<void*, int> bufferIndices = pool->alloc(mParams.mNumBlockTotal * mParams.mNumK * sizeof(int));
        mParams.mBufferIndices = (void*)((uint8_t*)bufferIndices.first + bufferIndices.second);
        std::pair<void*, int> bufferValues = pool->alloc(mParams.mNumBlockTotal * mParams.mNumK * sizeof(half));
        mParams.mBufferValues = (void*)((uint8_t*)bufferValues.first + bufferValues.second);
        pool->free(bufferIndices);
        pool->free(bufferValues);
    } else {
        std::pair<void*, int> bufferIndices = pool->alloc(mParams.mNumBlockTotal * mParams.mNumK * sizeof(int));
        mParams.mBufferIndices = (void*)((uint8_t*)bufferIndices.first + bufferIndices.second);
        std::pair<void*, int> bufferValues = pool->alloc(mParams.mNumBlockTotal * mParams.mNumK * sizeof(float));
        mParams.mBufferValues = (void*)((uint8_t*)bufferValues.first + bufferValues.second);
        pool->free(bufferIndices);
        pool->free(bufferValues);
    }
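    // Editor's note: alloc() immediately followed by free() is presumably the
    // static-pool idiom here: free() only marks the chunk reusable for later
    // ops, while the (base + offset) addresses computed above remain valid
    // for this execution.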

    return NO_ERROR;
}


ErrorCode TopKV2Execution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    // get input and output pointers
    void * inputDeviceAddr = reinterpret_cast<void *>(inputs[0]->deviceId());
    void * outputIndicesDeviceAddr = reinterpret_cast<void *>(outputs[1]->deviceId());
    void * outputValuesDeviceAddr = reinterpret_cast<void *>(outputs[0]->deviceId());

    // configure threads
    dim3 grid1 = {mParams.mNumBlockPerRow, mParams.mNumRow};
    dim3 block1 = {mParams.mNumThreadPerBlock, 1};
    int smemSize_1 = mParams.mNumThreadPerBlock * mParams.mNumK;
    dim3 grid2 = {mParams.mNumBlockFinal};
    dim3 block2 = {mParams.mNumThreadFinal};
    int smemSize_2 = mParams.mNumBlockPerRow * mParams.mNumK;
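    // Editor's note: smemSize_1 and smemSize_2 count (index, value) pairs per
    // block, not bytes; the per-element sizes are multiplied in at the launch
    // sites below. As written, the fp16 branch still budgets sizeof(float)
    // per value, a conservative over-allocation.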

    if (inputs[0]->getType().code == halide_type_int && inputs[0]->getType().bits == 32) {
        TopKAllRows<int, int><<<grid1, block1, smemSize_1 * (sizeof(int) + sizeof(int))>>>(static_cast<const int *>(inputDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<int *>(mParams.mBufferValues), mParams.mNumK, mParams.mLengthRow, mParams.mMinInt, mParams.mDescendFlag);
        checkKernelErrors;
        GetResultAllRows<int, int><<<grid2, block2, smemSize_2 * (sizeof(int) + sizeof(int))>>>(static_cast<int *>(outputIndicesDeviceAddr), static_cast<int *>(outputValuesDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<int *>(mParams.mBufferValues), mParams.mNumK, mParams.mNumBlockPerRow, mParams.mDescendFlag);
        checkKernelErrors;
    } else if (static_cast<CUDABackend*>(backend())->useFp16()) {
        TopKAllRows<int, half><<<grid1, block1, smemSize_1 * (sizeof(float) + sizeof(int))>>>(static_cast<const half *>(inputDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<half *>(mParams.mBufferValues), mParams.mNumK, mParams.mLengthRow, mParams.mMinHalf, mParams.mDescendFlag);
        checkKernelErrors;
        GetResultAllRows<int, half><<<grid2, block2, smemSize_2 * (sizeof(float) + sizeof(int))>>>(static_cast<int *>(outputIndicesDeviceAddr), static_cast<half *>(outputValuesDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<half *>(mParams.mBufferValues), mParams.mNumK, mParams.mNumBlockPerRow, mParams.mDescendFlag);
        checkKernelErrors;
    } else {
        TopKAllRows<int, float><<<grid1, block1, smemSize_1 * (sizeof(float) + sizeof(int))>>>(static_cast<const float *>(inputDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<float *>(mParams.mBufferValues), mParams.mNumK, mParams.mLengthRow, mParams.mMinFloat, mParams.mDescendFlag);
        checkKernelErrors;
        GetResultAllRows<int, float><<<grid2, block2, smemSize_2 * (sizeof(float) + sizeof(int))>>>(static_cast<int *>(outputIndicesDeviceAddr), static_cast<float *>(outputValuesDeviceAddr), static_cast<int *>(mParams.mBufferIndices), static_cast<float *>(mParams.mBufferValues), mParams.mNumK, mParams.mNumBlockPerRow, mParams.mDescendFlag);
        checkKernelErrors;
    }

    return NO_ERROR;
}


class TopKV2Creator : public CUDABackend::Creator {
public:
    virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                const MNN::Op* op, Backend* backend) const override {
        return new TopKV2Execution(op, backend);
    }
};


static CUDACreatorRegister<TopKV2Creator> __init(OpType_TopKV2);


}
}

@@ -0,0 +1,68 @@
//
//  TopKV2Execution.hpp
//  MNN
//
//  Created by MNN on 2023/07/19.
//  Copyright © 2018, Alibaba Group Holding Limited
//


#ifndef TopKV2Execution_hpp
#define TopKV2Execution_hpp

#include "core/Execution.hpp"
#include "core/Macro.h"
#include "backend/cuda/core/CUDABackend.hpp"
#include <memory>
#include <limits>
#include "cuda_fp16.h"

namespace MNN {
namespace CUDA {


class TopKV2Execution : public Execution {
public:
    TopKV2Execution(const Op * op, Backend * backend);
    virtual ~TopKV2Execution() = default;

    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
    virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;

private:
    struct TopKV2Params {
        int mLengthRow;
        int mNumRow;
        int mDescendFlag = 1;
        void * mBufferIndices;
        void * mBufferValues;

        int mNumK;
        int mNumElePerRow;
        int mNumElePerThread;
        int mNumThreadPerBlock;
        int mNumElePerBlock;
        int mNumBlockPerRow;
        int mNumBlockTotal;
        int mNumBlockFinal;
        int mNumThreadFinal;

        float mMinFloat = std::numeric_limits<float>::lowest();
        half mMinHalf = __float2half(-65504.0f);
        int mMinInt = -std::numeric_limits<int>::max();
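        // Editor's note: -65504.0f is the most negative finite fp16 value, and
        // -std::numeric_limits<int>::max() (INT_MIN + 1) serves the same
        // sentinel role for the int path, mirroring mMinFloat above.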
    };

    const Op * mOp;
    TopKV2Params mParams;
};


}
}

#endif

@@ -180,13 +180,13 @@ void UnpackBuffer(void* output, const void* input, const PackInfo* info, int bytes
    int block_num = runtime->blocks_num(maxCount);
    int block_size = runtime->threads_num();
    int axisAlign = UP_DIV(info->axis / 4, PACK_NUMBER / 4) * PACK_NUMBER / 4;
    if(bytes == 4) {
        UNPACKCOMMON_4<<<block_num, block_size>>>((const int4*)input, (int4*)output,
            maxCount, info->inside, info->axis / 4, info->outside,
            info->insideStride / 4, info->axisStride, axisAlign, is, cs);
        checkKernelErrors;
        return;
    }
    if(bytes == 2) {
        UNPACKCOMMON_4<<<block_num, block_size>>>((const int2*)input, (int2*)output,
            maxCount, info->inside, info->axis / 4, info->outside,

@@ -429,6 +429,7 @@ Execution* OpenCLBackend::onCreate(const std::vector<Tensor*>& inputs, const std
    auto iter = creators->find(std::make_pair(op->type(), mOpenCLRuntime->getGpuMemType()));

    if (iter == creators->end()) {
        mOpenCLRuntime->setDevideOpRecord();
#if 0//close log
        if (nullptr != op->name()) {
            MNN_PRINT("Don't support type %s memObject:%d, %s\n", EnumNameOpType(op->type()), mOpenCLRuntime->getGpuMemType(), op->name()->c_str());

@@ -462,6 +463,7 @@ Execution* OpenCLBackend::onCreate(const std::vector<Tensor*>& inputs, const std
    }

    if (!valid) {
        mOpenCLRuntime->setDevideOpRecord();
#if 0//close log
        for (auto t : inputs) {
            auto tensorShape = OpenCL::tensorShapeFormat(t);

@@ -479,6 +481,7 @@ Execution* OpenCLBackend::onCreate(const std::vector<Tensor*>& inputs, const std
    auto exe = iter->second->onCreate(inputs, outputs, op, this);
    if (NULL == exe) {
        mOpenCLRuntime->setDevideOpRecord();
#if 0//close log
        if (nullptr != op->name()) {
            MNN_PRINT("The Creator Don't support type %s, memObject:%d, %s\n", MNN::EnumNameOpType(op->type()), mOpenCLRuntime->getGpuMemType(), op->name()->c_str());

@@ -498,12 +501,14 @@ void OpenCLBackend::onResizeBegin() {
#ifndef ENABLE_OPENCL_TIME_PROFILER
    mOpenCLRuntime->setCommandQueueProfileEnable();
#endif
    mOpenCLRuntime->releaseRecord();
}

void OpenCLBackend::onResizeEnd() {
#ifndef ENABLE_OPENCL_TIME_PROFILER
    mOpenCLRuntime->setCommandQueueProfileDisable();
#endif
    mOpenCLRuntime->endRecord();
}

void OpenCLBackend::onExecuteBegin() const {

@@ -515,6 +520,7 @@ void OpenCLBackend::onExecuteBegin() const {
void OpenCLBackend::onExecuteEnd() const {
    mOpenCLRuntime->mQueueCount = 0;
    mOpenCLRuntime->clearRecord();
    mOpenCLRuntime->enqeueRecord();
}

@@ -569,11 +569,13 @@ void startRecord(OpenCLRuntime *runtime, cl_recording_qcom &recording){
    MNN_PRINT("start startRecord !\n");
#endif
    cl_int res = CL_SUCCESS;
    if(recording != NULL){
        clReleaseRecordingQCOM(recording);
    if(runtime->isDevideOpRecord()){
        if(recording != NULL){
            clReleaseRecordingQCOM(recording);
        }
        recording = runtime->recordableQueue().NewRecordingQCOM(&res);
        MNN_CHECK_CL_SUCCESS(res, "clNewRecordingQCOM");
    }
    recording = runtime->recordableQueue().NewRecordingQCOM(&res);
    MNN_CHECK_CL_SUCCESS(res, "clNewRecordingQCOM");
#ifdef LOG_VERBOSE
    MNN_PRINT("end startRecord !\n");
#endif

@@ -588,9 +590,11 @@ void endRecord(OpenCLRuntime *runtime, cl_recording_qcom &recording){
#ifdef LOG_VERBOSE
    MNN_PRINT("start endRecord !\n");
#endif
    cl_int res = CL_SUCCESS;
    res = clEndRecordingQCOM(recording);
    MNN_CHECK_CL_SUCCESS(res, "clEndRecordingQCOM");
    if(runtime->isDevideOpRecord()){
        cl_int res = CL_SUCCESS;
        res = clEndRecordingQCOM(recording);
        MNN_CHECK_CL_SUCCESS(res, "clEndRecordingQCOM");
    }
#ifdef LOG_VERBOSE
    MNN_PRINT("end endRecord !\n");
#endif

@@ -607,6 +611,25 @@ void recordKernel2d(const ::cl::Kernel &kernel, const std::vector<uint32_t> &gws
    MNN_PRINT("start recordKernel !\n");
#endif
    cl_int res = CL_SUCCESS;
    if(!runtime->isDevideOpRecord()){
        auto RecordNum = runtime->getRecordNum();
        auto maxRecordNum = runtime->getUseRecordableQueueSize();
        if(RecordNum == 0){
            cl_recording_qcom recording = runtime->recordableQueue().NewRecordingQCOM(&res);
            MNN_CHECK_CL_SUCCESS(res, "clNewRecordingQCOM");
            runtime->getRecordings()->emplace_back(recording);
        }else if(RecordNum == maxRecordNum){
            res = clEndRecordingQCOM( runtime->getRecordings()->back());
            MNN_CHECK_CL_SUCCESS(res, "clEndRecordingQCOM");
            cl_recording_qcom recording = runtime->recordableQueue().NewRecordingQCOM(&res);
            MNN_CHECK_CL_SUCCESS(res, "clNewRecordingQCOM");
            runtime->getRecordings()->emplace_back(recording);
            RecordNum = 0;
        }
        RecordNum++;
        runtime->setRecordNum(RecordNum);
    }

    std::vector<uint32_t> internalGlobalWS = gws;
    for (size_t i = 0; i < 2; ++i) {
        internalGlobalWS[i] = ROUND_UP(gws[i], std::max((uint32_t)1, lws[i]));

@@ -642,7 +665,24 @@ void recordKernel3d(const ::cl::Kernel &kernel, const std::vector<uint32_t> &gws
    for (size_t i = 0; i < 3; ++i) {
        internalGlobalWS[i] = ROUND_UP(gws[i], std::max((uint32_t)1, lws[i]));
    }

    if(!runtime->isDevideOpRecord()){
        auto maxRecordNum = runtime->getUseRecordableQueueSize();
        auto RecordNum = runtime->getRecordNum();
        if(RecordNum == 0){
            cl_recording_qcom recording = runtime->recordableQueue().NewRecordingQCOM(&res);
            MNN_CHECK_CL_SUCCESS(res, "clNewRecordingQCOM");
            runtime->getRecordings()->emplace_back(recording);
        }else if(RecordNum == maxRecordNum){
            res = clEndRecordingQCOM( runtime->getRecordings()->back());
            MNN_CHECK_CL_SUCCESS(res, "clEndRecordingQCOM");
            cl_recording_qcom recording = runtime->recordableQueue().NewRecordingQCOM(&res);
            MNN_CHECK_CL_SUCCESS(res, "clNewRecordingQCOM");
            runtime->getRecordings()->emplace_back(recording);
            RecordNum = 0;
        }
        RecordNum++;
        runtime->setRecordNum(RecordNum);
    }
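    // Editor's note: in batch mode (isDevideOpRecord() == false) kernels are
    // appended to the current recording, and a fresh recording is opened every
    // maxRecordNum kernels, so a single EnqueueRecordingQCOM later replays a
    // whole batch of kernels rather than one.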

    if(lws[0]==0 || lws[1]==0 || lws[2]==0){
        res = runtime->recordableQueue().enqueueNDRangeKernel(

@@ -235,9 +235,12 @@ OpenCLRuntime::OpenCLRuntime(const BackendConfig::PrecisionMode precision, const
#if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER)
    {
        if((false == OpenCLSymbolsOperator::getOpenclSymbolsPtr()->isQcomError()) && getDeviceSupportsExtension(*(mFirstGPUDevicePtr.get()), "cl_qcom_recordable_queues")){
            mMaxRecordableQueueSize = mFirstGPUDevicePtr->getInfo<CL_DEVICE_RECORDABLE_QUEUE_MAX_SIZE>();
            uint32_t MaxRecordableQueueSize = mFirstGPUDevicePtr->getInfo<CL_DEVICE_RECORDABLE_QUEUE_MAX_SIZE>();
            cl_int err;
            if(mMaxRecordableQueueSize > 0){
            if(MaxRecordableQueueSize > 0){
                // TODO: Use setSessionHint to set the number of mUseRecordableQueueSize
                mUseRecordableQueueSize = 10;
                mUseRecordableQueueSize = MaxRecordableQueueSize < mUseRecordableQueueSize ? MaxRecordableQueueSize : mUseRecordableQueueSize;
                mUseRecordQueue = true;
                mRecordableQueuePtr = std::make_shared<cl::CommandQueue>(*mContext, *mFirstGPUDevicePtr, CL_QUEUE_RECORDABLE_QCOM, &err);
                if(err != CL_SUCCESS){

@@ -309,6 +312,23 @@ void OpenCLRuntime::setGpuMode(const int cl_mode_num) {
    if(totalSet != 1) {
        MNN_PRINT("set multi tuning mode is not permitted, please check cl_mode:%x!\n", cl_mode_num);
    }

    totalSet = 0;
    isSet = (cl_mode_num & MNN_GPU_RECORD_OP);
    if(isSet) {
        mDevideOpRecord = true;
        totalSet++;
    }

    isSet = (cl_mode_num & MNN_GPU_RECORD_BATCH);
    if(isSet) {
        mDevideOpRecord = false;
        totalSet++;
    }

    if(totalSet > 1) {
        MNN_PRINT("set multi record kernel mode is not permitted, please check cl_mode:%x!\n", cl_mode_num);
    }
}
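// Editor's note: MNN_GPU_RECORD_OP selects one recording per op, while
// MNN_GPU_RECORD_BATCH batches many kernels per recording; as with the
// tuning modes above, setting both at once is rejected.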

void OpenCLRuntime::setCommandQueueProfileEnable() {

@@ -344,6 +364,7 @@ OpenCLRuntime::~OpenCLRuntime() {
#ifdef LOG_VERBOSE
    MNN_PRINT("start ~OpenCLRuntime !\n");
#endif
    releaseRecord();
    mBuildProgramMap.clear();
    mRecordings.clear();
    mCommandQueuePtr.reset();

@@ -711,7 +732,7 @@ bool OpenCLRuntime::setCache(std::pair<const void*, size_t> cache) {

void OpenCLRuntime::clearRecord(){
#if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER)
    if(mUseRecordQueue){
    if(mUseRecordQueue && mDevideOpRecord){
        for(int i = 0; i < mRecordings.size(); ++i){
            cl_int res = mCommandQueuePtr->EnqueueRecordingQCOM(mRecordings[i], 0, nullptr, 0, nullptr,
                         0, nullptr, 0, nullptr, 0, nullptr, nullptr);

@@ -722,4 +743,40 @@ void OpenCLRuntime::clearRecord(){
    }
#endif
}

void OpenCLRuntime::enqeueRecord(){
#if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER)
    if(mUseRecordQueue && !mDevideOpRecord){
        for(int i = 0; i < mRecordings.size(); ++i){
            cl_int res = mCommandQueuePtr->EnqueueRecordingQCOM(mRecordings[i], 0, nullptr, 0, nullptr,
                         0, nullptr, 0, nullptr, 0, nullptr, nullptr);
            MNN_CHECK_CL_SUCCESS(res, "EnqueueRecordingQCOM");
        }
        mCommandQueuePtr->finish();
    }
#endif
}

void OpenCLRuntime::endRecord(){
#if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER)
    if(mUseRecordQueue && !mDevideOpRecord){
        if(!mRecordings.empty()){
            cl_int res = clEndRecordingQCOM(mRecordings.back());
            MNN_CHECK_CL_SUCCESS(res, "clEndRecordingQCOM");
        }
    }
#endif
}

void OpenCLRuntime::releaseRecord(){
#if !defined(ENABLE_OPENCL_TIME_PROFILER) && defined(MNN_USE_LIB_WRAPPER)
    if(mUseRecordQueue && !mDevideOpRecord){
        for(int i = 0; i < mRecordings.size(); ++i){
            cl_int res = clReleaseRecordingQCOM(mRecordings[i]);
            MNN_CHECK_CL_SUCCESS(res, "clReleaseRecordingQCOM");
        }
        mRecordings.clear();
    }
#endif
}
} // namespace MNN

@@ -72,12 +72,24 @@ public:
    std::vector<cl_recording_qcom> *getRecordings(){
        return &mRecordings;
    }
    uint32_t getMaxRecordableQueueSize(){
        return mMaxRecordableQueueSize;
    uint32_t getUseRecordableQueueSize(){
        return mUseRecordableQueueSize;
    }
    bool isUseRecordQueue(){
        return mUseRecordQueue;
    }
    bool isDevideOpRecord(){
        return mDevideOpRecord;
    }
    void setDevideOpRecord(){
        mDevideOpRecord = true;
    }
    void setRecordNum(int num){
        mRecordNums = num;
    }
    uint32_t getRecordNum(){
        return mRecordNums;
    }
    GpuType getGpuType() {
        return mGpuType;
    }

@@ -105,6 +117,9 @@ public:
    void setCommandQueueProfileEnable();
    void setCommandQueueProfileDisable();
    void clearRecord();
    void enqeueRecord();
    void endRecord();
    void releaseRecord();

    unsigned int mQueueCount = 0;
    unsigned int getQueueNum();

@@ -153,8 +168,10 @@ private:
    uint64_t mMaxLocalMemSize;
    uint32_t mMaxThreadsPerDevice;
    uint32_t mMaxWorkGroupSize;
    uint32_t mMaxRecordableQueueSize;
    uint32_t mUseRecordableQueueSize;
    uint32_t mRecordNums = 0;
    bool mUseRecordQueue = false;
    bool mDevideOpRecord = true;
    bool mIsSupportedFP16 = false;
    bool mIsDeviceSupportedFP16 = false;
    bool mIsDeviceSupportedLowPower = false;

@@ -59,9 +59,9 @@ std::pair<std::vector<uint32_t>, uint32_t> ConvBufCommonExecution::gws2dLwsTune
                lws_prefer[1] = lws[1];
            }
        }
        lws[0]++;
        lws[0]<<=1;
    }
    lws[1]++;
    lws[1]<<=1;
}
} else if(runtime->getCLTuneLevel() == Wide) {
    while(lws[1] <= gws[1] || lws[1] <= 6) {

@@ -88,12 +88,12 @@ std::pair<std::vector<uint32_t>, uint32_t> ConvBufCommonExecution::gws2dLwsTune
            }
        }
        do {
            lws[0]++;
            lws[0]<<=1;
        }
        while(((2*gws[0])%lws[0] > 1) && (lws[0] & (lws[0] - 1)) != 0 && (lws[0] <= gws[0]) && (lws[0] > 6));//divisible powOfTwo lessThanSix
    }
    do {
        lws[1]++;
        lws[1]<<=1;
    }
    while(((2*gws[1])%lws[1] > 1) && (lws[1] & (lws[1] - 1)) != 0 && (lws[1] <= gws[1]) && (lws[1] > 6));//divisible powOfTwo lessThanSix
}
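
Editor's note: replacing `++` with `<<= 1` makes the tuner enumerate power-of-two local sizes only, matching the `(lws & (lws - 1))` power-of-two test in the loop conditions. A minimal sketch of the resulting search (names hypothetical):

// Candidates are now 1, 2, 4, 8, ... instead of every integer.
for (uint32_t candidate = 1; candidate <= gws0; candidate <<= 1) {
    // time the kernel with lws[0] = candidate and keep the fastest
}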

@@ -122,12 +122,12 @@ std::pair<std::vector<uint32_t>, uint32_t> ConvBufCommonExecution::gws2dLwsTune
            }
        }
        do {
            lws[0]++;
            lws[0]<<=1;
        }
        while(((2*gws[0])%lws[0] > 1) && (lws[0] & (lws[0] - 1)) != 0 && (lws[0] <= gws[0]) && (lws[0] > 6));//divisible powOfTwo lessThanSix
    }
    do {
        lws[1]++;
        lws[1]<<=1;
    }
    while(((2*gws[1])%lws[1] > 1) && (lws[1] & (lws[1] - 1)) != 0 && (lws[1] <= gws[1]) && (lws[1] <= 6));//divisible powOfTwo lessThanSix
}

@@ -156,12 +156,12 @@ std::pair<std::vector<uint32_t>, uint32_t> ConvBufCommonExecution::gws2dLwsTune
            }
        }
        do {
            lws[0]++;
            lws[0]<<=1;
        }
        while(((2*gws[0])%lws[0] > 1) && (lws[0] & (lws[0] - 1)) != 0 && (lws[0] <= gws[0]) && (lws[0] <= 6));//divisible powOfTwo lessThanSix
    }
    do {
        lws[1]++;
        lws[1]<<=1;
    }
    while(((2*gws[1])%lws[1] > 1) && (lws[1] & (lws[1] - 1)) != 0 && (lws[1] <= gws[1]) && (lws[1] <= 6));//divisible powOfTwo lessThanSix
}

@@ -476,7 +476,14 @@ ErrorCode ConvBufExecution::onResize(const std::vector<Tensor *> &inputs, const
    std::vector<uint32_t> localWorkSize[total_kernel];
    std::pair<int, int> min_cost(INT_MAX, 0);//(min_time, min_index)
    for(int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) {
        kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_buf", kernelName[knl_idx], mBuildOptions);
        std::set<std::string> buildOption = mBuildOptions;
        if(outputShape.at(3) % itemC[knl_idx] != 0){
            buildOption.emplace("-DCHANNEL_LEAVE");
        }
        if((outputShape.at(2) % itemW[knl_idx]) != 0){
            buildOption.emplace("-DBLOCK_LEAVE");
        }
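        // Editor's note: CHANNEL_LEAVE and BLOCK_LEAVE presumably compile in
        // the kernels' tail-handling paths for shapes where the channel count
        // or spatial size is not an exact multiple of the per-work-item tile
        // (itemC / itemW).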
        kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_buf", kernelName[knl_idx], buildOption);
        uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx]));

        uint32_t idx = 0;

@@ -515,7 +522,14 @@ ErrorCode ConvBufExecution::onResize(const std::vector<Tensor *> &inputs, const
    }
    mGlobalWorkSize = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]};

    mKernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_buf", kernelName[min_index], mBuildOptions);
    std::set<std::string> buildOption = mBuildOptions;
    if(outputShape.at(3) % itemC[min_index] != 0){
        buildOption.emplace("-DCHANNEL_LEAVE");
    }
    if((outputShape.at(2) % itemW[min_index]) != 0){
        buildOption.emplace("-DBLOCK_LEAVE");
    }
    mKernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_buf", kernelName[min_index], buildOption);
    uint32_t idx = 0;
    cl_int ret = CL_SUCCESS;

@@ -542,48 +556,29 @@ ErrorCode ConvBufExecution::onResize(const std::vector<Tensor *> &inputs, const
    int dilationShape[2] = {mDilations[0], mDilations[1]};

    // {"conv_2d_c4h1w2", "conv_2d_c4h1w1", "conv_2d_c8h1w1", "conv_2d_c4h1w4", "conv_2d_c8h2w1", "conv_2d_c4h4w1"};
    const int total_kernel = 6;
    std::string kernelName[total_kernel] = {"conv_2d_c4h1w1", "conv_2d_c4h1w2", "conv_2d_c4h4w1", "conv_2d_c8h2w1", "conv_2d_c8h4w1", "conv_2d_c4h1w4"};
    int itemC[total_kernel] = {4, 4, 4, 8, 8, 4};
    int itemH[total_kernel] = {1, 1, 4, 2, 4, 1};
    int itemW[total_kernel] = {1, 2, 1, 1, 1, 4};
    const int total_kernel = 7;
    std::string kernelName[total_kernel] = {"conv_2d_c4h1w1", "conv_2d_c4h1w2", "conv_2d_c4h4w1", "conv_2d_c8h2w1", "conv_2d_c8h4w1", "conv_2d_c4h1w4", "conv_2d_c8h1w4"};
    int itemC[total_kernel] = {4, 4, 4, 8, 8, 4, 8};
    int itemH[total_kernel] = {1, 1, 4, 2, 4, 1, 1};
    int itemW[total_kernel] = {1, 2, 1, 1, 1, 4, 4};


    int actual_kernel = total_kernel;
    if(mOpenCLBackend->getOpenCLRuntime()->getCLTuneLevel() == Normal) {
        actual_kernel = 2;
    } else if(mOpenCLBackend->getOpenCLRuntime()->getCLTuneLevel() == Fast || mOpenCLBackend->getOpenCLRuntime()->getCLTuneLevel() == None) {
        actual_kernel = 1;
    }else if(mOpenCLBackend->getOpenCLRuntime()->getCLTuneLevel() == Wide){
        actual_kernel = 4;
        auto gpuType = mOpenCLBackend->getOpenCLRuntime()->getGpuType();
        auto maliArType = mOpenCLBackend->getOpenCLRuntime()->getMaliAr();
        if(gpuType == MNN::MALI && maliArType == MNN::VALHALL){
            if(outputShape.at(3) <= 8){
                kernelName[3] = "conv_2d_c4h1w4";
                itemC[3] = 4;
                itemH[3] = 1;
                itemW[3] = 4;
            }else{
                kernelName[2] = "conv_2d_c8h2w1";
                itemC[2] = 8;
                itemH[2] = 2;
                itemW[2] = 1;

                kernelName[3] = "conv_2d_c8h4w1";
                itemC[3] = 8;
                itemH[3] = 4;
                itemW[3] = 1;
            }
        }
    }


    cl::Kernel kernel[total_kernel];
    std::vector<uint32_t> globalWorkSize[total_kernel];
    std::vector<uint32_t> localWorkSize[total_kernel];
    std::pair<int, int> min_cost(INT_MAX, 0);//(min_time, min_index)
    for(int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) {
        kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_buf", kernelName[knl_idx], mBuildOptions);
        std::set<std::string> buildOption = mBuildOptions;
        if(outputShape.at(3) % itemC[knl_idx] != 0){
            buildOption.emplace("-DCHANNEL_LEAVE");
        }
        if((outputShape.at(2) % itemW[knl_idx]) != 0 || (outputShape.at(1) % itemH[knl_idx]) != 0){
            buildOption.emplace("-DBLOCK_LEAVE");
        }
        kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_buf", kernelName[knl_idx], buildOption);
        uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx]));

        globalWorkSize[knl_idx] = {static_cast<uint32_t>(UP_DIV(outputShape.at(3), itemC[knl_idx]) * UP_DIV(outputShape.at(2), itemW[knl_idx])), static_cast<uint32_t>(outputShape.at(0) * UP_DIV(outputShape.at(1), itemH[knl_idx]))};

@@ -620,7 +615,14 @@ ErrorCode ConvBufExecution::onResize(const std::vector<Tensor *> &inputs, const
    int min_index = min_cost.second;
    mGlobalWorkSize = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]};

    mKernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_buf", kernelName[min_index], mBuildOptions);
    std::set<std::string> buildOption = mBuildOptions;
    if(outputShape.at(3) % itemC[min_index] != 0){
        buildOption.emplace("-DCHANNEL_LEAVE");
    }
    if((outputShape.at(2) % itemW[min_index]) != 0 || (outputShape.at(1) % itemH[min_index]) != 0){
        buildOption.emplace("-DBLOCK_LEAVE");
    }
    mKernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_buf", kernelName[min_index], buildOption);

    uint32_t idx = 0;
    cl_int ret = CL_SUCCESS;

@@ -38,7 +38,7 @@ bool ConvBufWinograd::valid(const Convolution2DCommon* common, const Tensor* inp
    if(input->channel() < 32 || input->channel() > input_channel_limit){
        return false;
    }
    return (input->width() <= 16 && input->height() <= 16);
    return (input->width() <= 32 && input->height() <= 32);
}

ConvBufWinograd::ConvBufWinograd(const MNN::Convolution2D* op, Backend* backend) : Execution(backend) {

@@ -147,6 +147,7 @@ ErrorCode RasterBufExecution::onResize(const std::vector<Tensor *> &____inputs,
        return NO_ERROR;
    }

    bool cancombine = CanCombine(outputs);
    // Alloc Temp buffer
    auto bufferPool = ((OpenCLBackend *)backend())->getBufferPool();
    auto bufferUnitSize = runtime->isSupportedFP16() ? sizeof(half_float::half) : sizeof(float);

@@ -170,6 +171,9 @@ ErrorCode RasterBufExecution::onResize(const std::vector<Tensor *> &____inputs,
    bufferPool->recycle(mTempOutput);

    auto originNum = mTempInput.size();
    if(cancombine){
        regionNum = 1;
    }
    mUnits.resize(regionNum + originNum + 1);

    int kernel_idx = 0;

@@ -241,16 +245,25 @@ ErrorCode RasterBufExecution::onResize(const std::vector<Tensor *> &____inputs,
    }

    // buffer raster
    for (auto& slice : des->regions)
    {
    if(cancombine){
        auto regions = des->regions;
        auto slice = regions[0];
        int nums = regions.size();
        int src_offset = regions[1].src.offset - slice.src.offset;
        int dst_offset = regions[1].dst.offset - slice.dst.offset;

        Unit &unit = mUnits[kernel_idx++];
        unit.kernel = runtime->buildKernel("raster_buf", "raster_buffer", {});

        const std::vector<uint32_t> gws = {(uint32_t)slice.size[2],
                                           (uint32_t)slice.size[1],
                                           (uint32_t)slice.size[0]};
        unit.kernel = runtime->buildKernel("raster", "raster_buffer_combine", {});

        unit.globalWorkSize = {(uint32_t)slice.size[2] * nums,
                               (uint32_t)slice.size[1],
                               (uint32_t)slice.size[0]};

        const std::vector<uint32_t> gws = {(uint32_t)slice.size[2] * nums,
                                           (uint32_t)slice.size[1],
                                           (uint32_t)slice.size[0]};
        uint32_t mMaxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(unit.kernel));


        uint32_t idx = 0;
        cl_int ret = CL_SUCCESS;
        ret |= unit.kernel.setArg(idx++, gws[0]);

@@ -258,27 +271,71 @@ ErrorCode RasterBufExecution::onResize(const std::vector<Tensor *> &____inputs,
        ret |= unit.kernel.setArg(idx++, gws[2]);
        ret |= unit.kernel.setArg(idx++, *(mTempInput[slice.origin]));
        ret |= unit.kernel.setArg(idx++, slice.src.offset);
        ret |= unit.kernel.setArg(idx++, src_offset);
        ret |= unit.kernel.setArg(idx++, slice.src.stride[0]);
        ret |= unit.kernel.setArg(idx++, slice.src.stride[1]);
        ret |= unit.kernel.setArg(idx++, slice.src.stride[2]);
        ret |= unit.kernel.setArg(idx++, *mTempOutput);
        ret |= unit.kernel.setArg(idx++, slice.dst.offset);
        ret |= unit.kernel.setArg(idx++, dst_offset);
        ret |= unit.kernel.setArg(idx++, slice.dst.stride[0]);
        ret |= unit.kernel.setArg(idx++, slice.dst.stride[1]);
        ret |= unit.kernel.setArg(idx++, slice.dst.stride[2]);
        ret |= unit.kernel.setArg(idx++, slice.size[2]);
        if(ret != CL_SUCCESS)
        {
            MNN_PRINT("setArg err %d\n", (int)ret);
        }

        std::string name = "raster_buffer";
        std::string name = "rasterBuffer";
        const std::vector<uint32_t> lws = localWS3DDefault(gws, mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), name, unit.kernel).first;

        unit.localWorkSize = {lws[0], lws[1], lws[2]};

        unit.globalWorkSize = {ROUND_UP(gws[0], std::max((uint32_t)1, lws[0])),
                               ROUND_UP(gws[1], std::max((uint32_t)1, lws[1])),
                               ROUND_UP(gws[2], std::max((uint32_t)1, lws[2]))};
                               ROUND_UP(gws[1], std::max((uint32_t)1, lws[1])),
                               ROUND_UP(gws[2], std::max((uint32_t)1, lws[2]))};
        recordKernel3d(unit.kernel, gws, lws, runtime);
    }else{
        for (auto& slice : des->regions)
        {
            Unit &unit = mUnits[kernel_idx++];
            unit.kernel = runtime->buildKernel("raster_buf", "raster_buffer", {});

            const std::vector<uint32_t> gws = {(uint32_t)slice.size[2],
                                               (uint32_t)slice.size[1],
                                               (uint32_t)slice.size[0]};
            uint32_t mMaxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(unit.kernel));

            uint32_t idx = 0;
            cl_int ret = CL_SUCCESS;
            ret |= unit.kernel.setArg(idx++, gws[0]);
            ret |= unit.kernel.setArg(idx++, gws[1]);
            ret |= unit.kernel.setArg(idx++, gws[2]);
            ret |= unit.kernel.setArg(idx++, *(mTempInput[slice.origin]));
            ret |= unit.kernel.setArg(idx++, slice.src.offset);
            ret |= unit.kernel.setArg(idx++, slice.src.stride[0]);
            ret |= unit.kernel.setArg(idx++, slice.src.stride[1]);
            ret |= unit.kernel.setArg(idx++, slice.src.stride[2]);
            ret |= unit.kernel.setArg(idx++, *mTempOutput);
            ret |= unit.kernel.setArg(idx++, slice.dst.offset);
            ret |= unit.kernel.setArg(idx++, slice.dst.stride[0]);
            ret |= unit.kernel.setArg(idx++, slice.dst.stride[1]);
            ret |= unit.kernel.setArg(idx++, slice.dst.stride[2]);
            if(ret != CL_SUCCESS)
            {
                MNN_PRINT("setArg err %d\n", (int)ret);
            }

            std::string name = "raster_buffer";
            const std::vector<uint32_t> lws = localWS3DDefault(gws, mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), name, unit.kernel).first;

            unit.localWorkSize = {lws[0], lws[1], lws[2]};

            unit.globalWorkSize = {ROUND_UP(gws[0], std::max((uint32_t)1, lws[0])),
                                   ROUND_UP(gws[1], std::max((uint32_t)1, lws[1])),
                                   ROUND_UP(gws[2], std::max((uint32_t)1, lws[2]))};
        }
    }

    // buffer to nc4hw4 buffer

@@ -349,6 +406,44 @@ public:
    }
};

bool RasterBufExecution::CanCombine(const std::vector<Tensor *> &outputs){
    auto des = TensorUtils::getDescribe(outputs[0]);
    auto regions = des->regions;
    if(regions.size() < 2)
        return false;
    auto origin = regions[0].origin;
    const int size0 = regions[0].size[0];
    const int size1 = regions[0].size[1];
    const int size2 = regions[0].size[2];
    const int src_offset = regions[1].src.offset - regions[0].src.offset;
    const int dst_offset = regions[1].dst.offset - regions[0].dst.offset;
    const int src_sride0 = regions[0].src.stride[0];
    const int src_sride1 = regions[0].src.stride[1];
    const int src_sride2 = regions[0].src.stride[2];
    const int dst_sride0 = regions[0].dst.stride[0];
    const int dst_sride1 = regions[0].dst.stride[1];
    const int dst_sride2 = regions[0].dst.stride[2];
    bool res = true;
    for(int i = 1; i < regions.size(); ++i){
        res &= regions[i].origin == origin;
        res &= regions[i].size[0] == size0;
        res &= regions[i].size[1] == size1;
        res &= regions[i].size[2] == size2;
        res &= regions[i].src.stride[0] == src_sride0;
        res &= regions[i].src.stride[1] == src_sride1;
        res &= regions[i].src.stride[2] == src_sride2;
        res &= regions[i].dst.stride[0] == dst_sride0;
        res &= regions[i].dst.stride[1] == dst_sride1;
        res &= regions[i].dst.stride[2] == dst_sride2;
        res &= (regions[i].src.offset - regions[i - 1].src.offset) == src_offset;
        res &= (regions[i].dst.offset - regions[i - 1].dst.offset) == dst_offset;
        if(res == false){
            return res;
        }
    }
    return res;
}
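
Editor's note: what CanCombine establishes is that every region is an identically-shaped copy whose src/dst offsets advance by a constant step, so one kernel launch with a widened first dimension can cover them all. A sketch of the index math such a combined kernel can rely on (kernel-side names hypothetical):

// For region r (0 <= r < nums):
//   src(r) = regions[0].src.offset + r * src_offset
//   dst(r) = regions[0].dst.offset + r * dst_offset
// With gws[0] widened to slice.size[2] * nums:
//   int r   = get_global_id(0) / slice_size2;  // which region
//   int pos = get_global_id(0) % slice_size2;  // position inside it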

OpenCLCreatorRegister<RasterCreator> __RasterBuf_op(OpType_Raster, BUFFER);
} // namespace OpenCL
} // namespace MNN

@@ -28,6 +28,7 @@ public:
    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;

private:
    bool CanCombine(const std::vector<Tensor *> &outputs);
    std::map<Tensor*, cl::Buffer *> mTempInput;
    cl::Buffer *mTempOutput;
    OpenCLBackend *mOpenCLBackend;

@@ -285,13 +285,19 @@ __kernel
#if SET_ATTRIBUTE
__attribute__((work_group_size_hint(16, 16, 1)))
#endif
void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only image2d_t weights,
void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
#ifdef USE_BUFFER
                 __global const FLOAT *weights,
#else
                 __read_only image2d_t weights,
#endif
                 __read_only image2d_t bias,
                 __write_only image2d_t output,
                 __private const int2 input_shape,
                 __private const int in_channel_block, __private const int2 output_shape,
                 __private const int2 stride_shape,
                 __private const int output_width_4) {
                 __private const int output_width_4,
                 __private const int out_channel_blocks) {

    const int output_channel_width_idx = get_global_id(0);
    const int output_batch_height_idx = get_global_id(1);

@@ -305,6 +311,12 @@ void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only ima
    FLOAT4 out2 = out0;
    FLOAT4 out3 = out0;

#ifdef MNN_CONV_S1D1
    int intput_width_idx0 = output_width_block_idx << 2;
    int intput_width_idx1 = intput_width_idx0 + 1;
    int intput_width_idx2 = intput_width_idx0 + 2;
    int intput_width_idx3 = intput_width_idx0 + 3;
#else
    int intput_width_idx0 = mul24(output_width_block_idx, stride_shape.y*4);
    int intput_width_idx1 = intput_width_idx0 + stride_shape.y;
    int intput_width_idx2 = intput_width_idx1 + stride_shape.y;

@@ -314,7 +326,7 @@ void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only ima
    intput_width_idx1 = select(intput_width_idx1, INT_MIN, intput_width_idx1 >= input_shape.y);
    intput_width_idx2 = select(intput_width_idx2, INT_MIN, intput_width_idx2 >= input_shape.y);
    intput_width_idx3 = select(intput_width_idx3, INT_MIN, intput_width_idx3 >= input_shape.y);

#endif
    int batch_index = output_batch_height_idx / output_shape.x;
    int input_height_block_idx = mul24((output_batch_height_idx % output_shape.x), stride_shape.x) + batch_index * input_shape.x;


@@ -326,19 +338,26 @@ void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only ima
    FLOAT4 weights1;
    FLOAT4 weights2;
    FLOAT4 weights3;
    int weight_offset = output_channel_block_idx * in_channel_block * 4 * 4;

    for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block; ++in_channel_block_idx) {
        int input_width_base = in_channel_block_idx * input_shape.y;
        int weights_width_base = in_channel_block_idx << 2;
        in0 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx0, input_height_block_idx));
        in1 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx1, input_height_block_idx));
        in2 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx2, input_height_block_idx));
        in3 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx3, input_height_block_idx));

#ifdef USE_BUFFER
        weights0 = vload4(weights_width_base, weights + weight_offset);
        weights1 = vload4(weights_width_base + 1, weights + weight_offset);
        weights2 = vload4(weights_width_base + 2, weights + weight_offset);
        weights3 = vload4(weights_width_base + 3, weights + weight_offset);
#else
        weights0 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 0, output_channel_block_idx));
        weights1 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 1, output_channel_block_idx));
        weights2 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 2, output_channel_block_idx));
        weights3 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 3, output_channel_block_idx));
#endif
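        // Editor's note: per the OpenCL C spec, vload4(n, p) reads four FLOATs
        // starting at p + 4*n, so weights_width_base steps through FLOAT4 rows
        // of the buffer weight layout assumed here, while the image path above
        // samples the same rows by (x, y) coordinate.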
        in0 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx0, input_height_block_idx));
        in1 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx1, input_height_block_idx));
        in2 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx2, input_height_block_idx));
        in3 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx3, input_height_block_idx));

        CALCULATE_OUTPUT(0);
        CALCULATE_OUTPUT(1);

|
|||
#if SET_ATTRIBUTE
|
||||
__attribute__((work_group_size_hint(16, 16, 1)))
|
||||
#endif
|
||||
void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only image2d_t weights,
|
||||
void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
|
||||
#ifdef USE_BUFFER
|
||||
__global const FLOAT *weights,
|
||||
#else
|
||||
__read_only image2d_t weights,
|
||||
#endif
|
||||
__read_only image2d_t bias,
|
||||
__write_only image2d_t output,
|
||||
__private const int2 input_shape,
|
||||
__private const int in_channel_block, __private const int2 output_shape,
|
||||
__private const int2 stride_shape,
|
||||
__private const int output_width_4,
|
||||
__private const int out_channel_blocks) {
|
||||
|
||||
const int output_channel_width_idx = get_global_id(0);
|
||||
const int output_batch_height_idx = get_global_id(1);
|
||||
DEAL_NON_UNIFORM_DIM2(output_channel_width_idx, output_batch_height_idx);
|
||||
|
||||
const int output_channel_block_idx = output_channel_width_idx / output_width_4;
|
||||
const int output_width_block_idx = output_channel_width_idx % output_width_4;
|
||||
const int output_channel_idx = output_channel_block_idx << 1;
|
||||
|
||||
FLOAT4 out0 = RI_F(bias, SAMPLER, (int2)(output_channel_idx, 0));
|
||||
FLOAT4 out1 = out0;
|
||||
FLOAT4 out2 = out0;
|
||||
FLOAT4 out3 = out0;
|
||||
|
||||
FLOAT4 out4 = RI_F(bias, SAMPLER, (int2)(output_channel_idx + 1, 0));
|
||||
FLOAT4 out5 = out4;
|
||||
FLOAT4 out6 = out4;
|
||||
FLOAT4 out7 = out4;
|
||||
|
||||
#ifdef MNN_CONV_S1D1
|
||||
int intput_width_idx0 = output_width_block_idx << 2;
|
||||
int intput_width_idx1 = intput_width_idx0 + 1;
|
||||
int intput_width_idx2 = intput_width_idx0 + 2;
|
||||
int intput_width_idx3 = intput_width_idx0 + 3;
|
||||
#else
|
||||
int intput_width_idx0 = mul24(output_width_block_idx, stride_shape.y*4);
|
||||
int intput_width_idx1 = intput_width_idx0 + stride_shape.y;
|
||||
int intput_width_idx2 = intput_width_idx1 + stride_shape.y;
|
||||
int intput_width_idx3 = intput_width_idx2 + stride_shape.y;
|
||||
|
||||
intput_width_idx0 = select(intput_width_idx0, INT_MIN, intput_width_idx0 >= input_shape.y);
|
||||
intput_width_idx1 = select(intput_width_idx1, INT_MIN, intput_width_idx1 >= input_shape.y);
|
||||
intput_width_idx2 = select(intput_width_idx2, INT_MIN, intput_width_idx2 >= input_shape.y);
|
||||
intput_width_idx3 = select(intput_width_idx3, INT_MIN, intput_width_idx3 >= input_shape.y);
|
||||
#endif
|
||||
|
||||
int batch_index = output_batch_height_idx / output_shape.x;
|
||||
int input_height_block_idx = mul24((output_batch_height_idx % output_shape.x), stride_shape.x) + batch_index * input_shape.x;
|
||||
|
||||
FLOAT4 in0;
|
||||
FLOAT4 in1;
|
||||
FLOAT4 in2;
|
||||
FLOAT4 in3;
|
||||
FLOAT4 weights0;
|
||||
FLOAT4 weights1;
|
||||
FLOAT4 weights2;
|
||||
FLOAT4 weights3;
|
||||
FLOAT4 weights4;
|
||||
FLOAT4 weights5;
|
||||
FLOAT4 weights6;
|
||||
FLOAT4 weights7;
|
||||
int weight_offset = output_channel_idx * in_channel_block * 4 * 4;
|
||||
int weight_offset1 = weight_offset + in_channel_block * 4 * 4;
|
||||
|
||||
for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block; ++in_channel_block_idx) {
|
||||
int input_width_base = in_channel_block_idx * input_shape.y;
|
||||
int weights_width_base = in_channel_block_idx << 2;
|
||||
in0 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx0, input_height_block_idx));
|
||||
in1 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx1, input_height_block_idx));
|
||||
in2 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx2, input_height_block_idx));
|
||||
in3 = RI_F(input, SAMPLER, (int2)(input_width_base + intput_width_idx3, input_height_block_idx));
|
||||
|
||||
#ifdef USE_BUFFER
|
||||
weights0 = vload4(weights_width_base, weights + weight_offset);
|
||||
weights1 = vload4(weights_width_base + 1, weights + weight_offset);
|
||||
weights2 = vload4(weights_width_base + 2, weights + weight_offset);
|
||||
weights3 = vload4(weights_width_base + 3, weights + weight_offset);
|
||||
|
||||
weights4 = vload4(weights_width_base, weights + weight_offset1);
|
||||
weights5 = vload4(weights_width_base + 1, weights + weight_offset1);
|
||||
weights6 = vload4(weights_width_base + 2, weights + weight_offset1);
|
||||
weights7 = vload4(weights_width_base + 3, weights + weight_offset1);
|
||||
#else
|
||||
weights0 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 0, output_channel_idx));
|
||||
weights1 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 1, output_channel_idx));
|
||||
weights2 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 2, output_channel_idx));
|
||||
weights3 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 3, output_channel_idx));
|
||||
|
||||
weights4 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 0, output_channel_idx + 1));
|
||||
weights5 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 1, output_channel_idx + 1));
|
||||
weights6 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 2, output_channel_idx + 1));
|
||||
weights7 = RI_F(weights, SAMPLER, (int2)(weights_width_base + 3, output_channel_idx + 1));
|
||||
#endif
|
||||
|
||||
CALCULATE_OUTPUT(0);
|
||||
CALCULATE_OUTPUT(1);
|
||||
CALCULATE_OUTPUT(2);
|
||||
CALCULATE_OUTPUT(3);
|
||||
|
||||
CALCULATE_OUTPUT_WEIGHTS4(4, 0);
|
||||
CALCULATE_OUTPUT_WEIGHTS4(5, 1);
|
||||
CALCULATE_OUTPUT_WEIGHTS4(6, 2);
|
||||
CALCULATE_OUTPUT_WEIGHTS4(7, 3);
|
||||
}
|
||||
|
||||
#ifdef RELU
|
||||
out0 = fmax(out0, (FLOAT4)0);
|
||||
out1 = fmax(out1, (FLOAT4)0);
|
||||
out2 = fmax(out2, (FLOAT4)0);
|
||||
out3 = fmax(out3, (FLOAT4)0);
|
||||
out4 = fmax(out4, (FLOAT4)0);
|
||||
out5 = fmax(out5, (FLOAT4)0);
|
||||
out6 = fmax(out6, (FLOAT4)0);
|
||||
out7 = fmax(out7, (FLOAT4)0);
|
||||
#endif
|
||||
|
||||
#ifdef RELU6
|
||||
out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
|
||||
out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6);
|
||||
out2 = clamp(out2, (FLOAT4)0, (FLOAT4)6);
|
||||
out3 = clamp(out3, (FLOAT4)0, (FLOAT4)6);
|
||||
out4 = clamp(out4, (FLOAT4)0, (FLOAT4)6);
|
||||
out5 = clamp(out5, (FLOAT4)0, (FLOAT4)6);
|
||||
out6 = clamp(out6, (FLOAT4)0, (FLOAT4)6);
|
||||
out7 = clamp(out7, (FLOAT4)0, (FLOAT4)6);
|
||||
#endif
|
||||
|
||||
const int out_x_base = mul24(output_channel_idx, output_shape.y);
|
||||
int out_x_idx = output_width_block_idx << 2;
|
||||
|
||||
const int remain = output_shape.y - out_x_idx;
|
||||
int output_idx = out_x_base + out_x_idx;
|
||||
if (remain >= 4) {
|
||||
WI_F(output, (int2)(output_idx, output_batch_height_idx), out0);
|
||||
WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out1);
|
||||
WI_F(output, (int2)(output_idx + 2, output_batch_height_idx), out2);
|
||||
WI_F(output, (int2)(output_idx + 3, output_batch_height_idx), out3);
|
||||
} else if (remain == 3) {
|
||||
WI_F(output, (int2)(output_idx, output_batch_height_idx), out0);
|
||||
WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out1);
|
||||
WI_F(output, (int2)(output_idx + 2, output_batch_height_idx), out2);
|
||||
} else if (remain == 2) {
|
||||
WI_F(output, (int2)(output_idx, output_batch_height_idx), out0);
|
||||
WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out1);
|
||||
} else if (remain == 1) {
|
||||
WI_F(output, (int2)(output_idx, output_batch_height_idx), out0);
|
||||
}
|
||||
|
||||
if(output_channel_idx + 1 >= out_channel_blocks)
|
||||
return;
|
||||
output_idx += output_shape.y;
|
||||
if (remain >= 4) {
|
||||
WI_F(output, (int2)(output_idx, output_batch_height_idx), out4);
|
||||
WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out5);
|
||||
WI_F(output, (int2)(output_idx + 2, output_batch_height_idx), out6);
|
||||
WI_F(output, (int2)(output_idx + 3, output_batch_height_idx), out7);
|
||||
} else if (remain == 3) {
|
||||
WI_F(output, (int2)(output_idx, output_batch_height_idx), out4);
|
||||
WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out5);
|
||||
WI_F(output, (int2)(output_idx + 2, output_batch_height_idx), out6);
|
||||
} else if (remain == 2) {
|
||||
WI_F(output, (int2)(output_idx, output_batch_height_idx), out4);
|
||||
WI_F(output, (int2)(output_idx + 1, output_batch_height_idx), out5);
|
||||
} else if (remain == 1) {
|
||||
WI_F(output, (int2)(output_idx, output_batch_height_idx), out4);
|
||||
}
|
||||
}
|
||||
|
||||
__kernel
|
||||
#if SET_ATTRIBUTE
|
||||
__attribute__((work_group_size_hint(16, 16, 1)))
|
||||
#endif
|
||||
void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
|
||||
#ifdef USE_BUFFER
|
||||
__global const FLOAT *weights,
|
||||
#else
|
||||
__read_only image2d_t weights,
|
||||
#endif
|
||||
#ifdef BIAS
|
||||
__read_only image2d_t bias,
|
||||
#endif
|
||||
|
@@ -425,26 +624,36 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only

#ifdef MNN_CONV_S1D1
    const int height_start = mad24((output_batch_height_idx % output_shape.x), 1, -padding_shape.x);
    int in_height_start = select(0, (-height_start), height_start < 0) + height_start;
    const int kh_start = select(0, (-height_start), height_start < 0);
    int in_height_start = kh_start + height_start;
    int in_height_end = min(weights_shape.x + height_start, input_shape.x);

    const int batch_idx = mul24((output_batch_height_idx / output_shape.x), input_shape.x);
    const int weights_h_idx = mul24(out_channel_block_idx, mul24(weights_shape.y, weights_shape.x)) + mul24(select(0, (-height_start), height_start < 0), weights_shape.y);
#else
    const int height_start = mad24((output_batch_height_idx % output_shape.x), stride_shape.x, -padding_shape.x);
    int in_height_start = mad24(select(0, (-height_start + dilation_shape.x - 1) / dilation_shape.x, height_start < 0), dilation_shape.x, height_start);
    const int kh_start = select(0, (-height_start + dilation_shape.x - 1) / dilation_shape.x, height_start < 0);
    int in_height_start = mad24(kh_start, dilation_shape.x, height_start);
    int in_height_end = min(mad24(weights_shape.x, dilation_shape.x, height_start), input_shape.x);

    const int batch_idx = mul24((output_batch_height_idx / output_shape.x), input_shape.x);
    const int weights_h_idx = mul24(out_channel_block_idx, mul24(weights_shape.y, weights_shape.x)) + mul24(select(0, (-height_start + dilation_shape.x - 1) / dilation_shape.x, height_start < 0), weights_shape.y);
#endif

#ifdef USE_BUFFER
    const int weight_oc_offset = out_channel_blocks * weights_shape.x * weights_shape.y * 4;
#endif

    FLOAT4 in0, in1, in2, in3;
    FLOAT4 weights0, weights1, weights2, weights3;
    for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block_length; ++in_channel_block_idx) {
        const int in_idx = mul24(in_channel_block_idx, input_shape.y);
#ifdef USE_BUFFER
        int weight_offset = ((((4*in_channel_block_idx+0)* out_channel_blocks + out_channel_block_idx) *weights_shape.x + kh_start)*weights_shape.y + 0) * 4;
#else
        int weights_x_idx = in_channel_block_idx << 2;
        int weights_y_idx = weights_h_idx;
#endif
        for (int iy = in_height_start; iy < in_height_end; iy += dilation_shape.x) {
            int in_hb_value = iy + batch_idx;
#ifdef MNN_CONV_S1D1

@@ -453,12 +662,18 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only
            READ_INPUT_IMAGE(1, 0);
            READ_INPUT_IMAGE(2, 0);
            READ_INPUT_IMAGE(3, 0);

#ifdef USE_BUFFER
            weights0 = vload4(0, weights+weight_offset);
            weights1 = vload4(0, weights+weight_offset+weight_oc_offset);
            weights2 = vload4(0, weights+weight_offset+weight_oc_offset*2);
            weights3 = vload4(0, weights+weight_offset+weight_oc_offset*3);
            weight_offset += 4;
#else
            weights0 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 0, weights_y_idx));
            weights1 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 1, weights_y_idx));
            weights2 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weights_y_idx));
            weights3 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 3, weights_y_idx++));

#endif
            CALCULATE_OUTPUT(0);
            CALCULATE_OUTPUT(1);
            CALCULATE_OUTPUT(2);

@@ -469,12 +684,18 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only
                in1 = in2;
                in2 = in3;
                READ_INPUT_IMAGE(3, w);

#ifdef USE_BUFFER
                weights0 = vload4(0, weights+weight_offset);
                weights1 = vload4(0, weights+weight_offset+weight_oc_offset);
                weights2 = vload4(0, weights+weight_offset+weight_oc_offset*2);
                weights3 = vload4(0, weights+weight_offset+weight_oc_offset*3);
                weight_offset += 4;
#else
                weights0 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 0, weights_y_idx));
                weights1 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 1, weights_y_idx));
                weights2 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weights_y_idx));
                weights3 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 3, weights_y_idx++));

#endif
                CALCULATE_OUTPUT(0);
                CALCULATE_OUTPUT(1);
                CALCULATE_OUTPUT(2);

@@ -487,12 +708,18 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only
            READ_INPUT_IMAGE(1, input_width_base);
            READ_INPUT_IMAGE(2, input_width_base);
            READ_INPUT_IMAGE(3, input_width_base);

#ifdef USE_BUFFER
            weights0 = vload4(0, weights+weight_offset);
            weights1 = vload4(0, weights+weight_offset+weight_oc_offset);
            weights2 = vload4(0, weights+weight_offset+weight_oc_offset*2);
            weights3 = vload4(0, weights+weight_offset+weight_oc_offset*3);
            weight_offset += 4;
#else
            weights0 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 0, weights_y_idx));
            weights1 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 1, weights_y_idx));
            weights2 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weights_y_idx));
            weights3 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 3, weights_y_idx++));

#endif
            CALCULATE_OUTPUT(0);
            CALCULATE_OUTPUT(1);
            CALCULATE_OUTPUT(2);

@ -542,7 +769,12 @@ __kernel
|
|||
#if SET_ATTRIBUTE
|
||||
__attribute__((work_group_size_hint(16, 16, 1)))
|
||||
#endif
|
||||
void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only image2d_t weights,
|
||||
void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
|
||||
#ifdef USE_BUFFER
|
||||
__global const FLOAT *weights,
|
||||
#else
|
||||
__read_only image2d_t weights,
|
||||
#endif
|
||||
#ifdef BIAS
|
||||
__read_only image2d_t bias,
|
||||
#endif
|
||||
|
@ -581,6 +813,11 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only
|
|||
FLOAT4 out6 = out4;
|
||||
FLOAT4 out7 = out4;
|
||||
|
||||
#ifdef USE_BUFFER
|
||||
const int weight_oc_offset = weights_shape.x * weights_shape.y * 4;
|
||||
const int weight_ic_offset = out_channel_blocks * weight_oc_offset;
|
||||
#endif
|
||||
|
||||
int in_width0 = mad24(out_width_block_idx, stride_shape.y, -padding_shape.y);
|
||||
int in_height0 = mad24(out_height_block_idx, stride_shape.x<<2, -padding_shape.x);
|
||||
int in_height1 = in_height0 + stride_shape.x;
|
||||
|
@@ -595,8 +832,12 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only
FLOAT4 weights0, weights1, weights2, weights3, weights4, weights5, weights6, weights7;
for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block_length; ++in_channel_block_idx) {
const int in_idx = mul24(in_channel_block_idx, input_shape.y);
#ifdef USE_BUFFER
int weight_offset = ((((4*in_channel_block_idx+0)* out_channel_blocks + out_channel_block_idx) *weights_shape.x + 0)*weights_shape.y + 0) * 4;
#else
int weights_x_idx = in_channel_block_idx << 2;
int weights_y_idx = weights_h_idx;
#endif
for (int iy = 0; iy < weights_shape.x * dilation_shape.x; iy += dilation_shape.x) {
int h0 = select(in_height0 + iy + batch_idx, -1, (in_height0 + iy < 0 || in_height0 + iy >= input_shape.x));
int h1 = select(in_height1 + iy + batch_idx, -1, (in_height1 + iy < 0 || in_height1 + iy >= input_shape.x));

@@ -610,6 +851,17 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only
in2 = RI_F(input, SAMPLER, (int2)(w0, h2));
in3 = RI_F(input, SAMPLER, (int2)(w0, h3));

#ifdef USE_BUFFER
weights0 = vload4(0, weights+weight_offset);
weights1 = vload4(0, weights+weight_offset+weight_ic_offset);
weights2 = vload4(0, weights+weight_offset+weight_ic_offset*2);
weights3 = vload4(0, weights+weight_offset+weight_ic_offset*3);
weights4 = vload4(0, weights+weight_offset + weight_oc_offset);
weights5 = vload4(0, weights+weight_offset+weight_ic_offset + weight_oc_offset);
weights6 = vload4(0, weights+weight_offset+weight_ic_offset*2 + weight_oc_offset);
weights7 = vload4(0, weights+weight_offset+weight_ic_offset*3 + weight_oc_offset);
weight_offset += 4;
#else
weights0 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 0, weights_y_idx));
weights1 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 1, weights_y_idx));
weights2 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weights_y_idx));

@@ -618,7 +870,8 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only
weights5 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 1, weight_size + weights_y_idx));
weights6 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weight_size + weights_y_idx));
weights7 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 3, weight_size + weights_y_idx++));

#endif

CALCULATE_OUTPUT(0);
CALCULATE_OUTPUT(1);
CALCULATE_OUTPUT(2);

@@ -703,7 +956,12 @@ __kernel
#if SET_ATTRIBUTE
__attribute__((work_group_size_hint(16, 16, 1)))
#endif
void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only image2d_t weights,
void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,
#ifdef USE_BUFFER
__global const FLOAT *weights,
#else
__read_only image2d_t weights,
#endif
#ifdef BIAS
__read_only image2d_t bias,
#endif

@@ -749,10 +1007,17 @@ void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only

FLOAT4 in0, in1, in2, in3;
FLOAT4 weights0, weights1, weights2, weights3;
#ifdef USE_BUFFER
const int weight_oc_offset = out_channel_blocks * weights_shape.x * weights_shape.y * 4;
#endif
for (int in_channel_block_idx = 0; in_channel_block_idx < in_channel_block_length; ++in_channel_block_idx) {
const int in_idx = mul24(in_channel_block_idx, input_shape.y);
#ifdef USE_BUFFER
int weight_offset = ((((4*in_channel_block_idx+0)* out_channel_blocks + out_channel_block_idx) *weights_shape.x + 0)*weights_shape.y + 0) * 4;
#else
int weights_x_idx = in_channel_block_idx << 2;
int weights_y_idx = weights_h_idx;
#endif
for (int iy = 0; iy < weights_shape.x * dilation_shape.x; iy += dilation_shape.x) {
int h0 = select(in_height0 + iy + batch_idx, -1, (in_height0 + iy < 0 || in_height0 + iy >= input_shape.x));
int h1 = select(in_height1 + iy + batch_idx, -1, (in_height1 + iy < 0 || in_height1 + iy >= input_shape.x));

@@ -765,11 +1030,18 @@ void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input, __read_only
in1 = RI_F(input, SAMPLER, (int2)(w0, h1));
in2 = RI_F(input, SAMPLER, (int2)(w0, h2));
in3 = RI_F(input, SAMPLER, (int2)(w0, h3));

#ifdef USE_BUFFER
weights0 = vload4(0, weights+weight_offset);
weights1 = vload4(0, weights+weight_offset+weight_oc_offset);
weights2 = vload4(0, weights+weight_offset+weight_oc_offset*2);
weights3 = vload4(0, weights+weight_offset+weight_oc_offset*3);
weight_offset += 4;
#else
weights0 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 0, weights_y_idx));
weights1 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 1, weights_y_idx));
weights2 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 2, weights_y_idx));
weights3 = RI_F(weights, SAMPLER, (int2)(weights_x_idx + 3, weights_y_idx++));
#endif

CALCULATE_OUTPUT(0);
CALCULATE_OUTPUT(1);

@@ -186,9 +186,13 @@ void conv_2d_c4h1w2(GLOBAL_SIZE_2_DIMS
#endif

const int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
#ifdef BLOCK_LEAVE
vstore4(out0, 0, output+out_offset);
if(out_w_idx + 1 >= out_hw.y) return;
vstore4(out1, 1, output+out_offset);
#else
vstore8((FLOAT8)(out0, out1), 0, output+out_offset);
#endif
}

__kernel

@@ -298,13 +302,22 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS
#endif

const int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
vstore4(out0, 0, output+out_offset);
if(out_w_idx + 1 >= out_hw.y) return;
vstore4(out1, 1, output+out_offset);
if(out_w_idx + 2 >= out_hw.y) return;
vstore4(out2, 2, output+out_offset);
if(out_w_idx + 3 >= out_hw.y) return;
vstore4(out3, 3, output+out_offset);
#ifdef BLOCK_LEAVE
const int remain = out_hw.y - out_w_idx;

if (remain >= 4) {
vstore16((FLOAT16)(out0, out1, out2, out3), 0, output+out_offset);
}else if(remain == 3){
vstore8((FLOAT8)(out0, out1), 0, output+out_offset);
vstore4(out2, 2, output+out_offset);
}else if(remain == 2){
vstore8((FLOAT8)(out0, out1), 0, output+out_offset);
}else if(remain == 1){
vstore4(out0, 0, output+out_offset);
}
#else
vstore16((FLOAT16)(out0, out1, out2, out3), 0, output+out_offset);
#endif
}
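The BLOCK_LEAVE ladders above pick the widest contiguous store that still fits in the remaining output width. A scalar sketch of the same dispatch, with plain memcpy standing in for vstore16/vstore8/vstore4 (our illustration, not MNN code):

#include <cstring>

// out holds four FLOAT4 results laid out contiguously; remain is the number
// of output columns still inside the row.
void storeTail(const float out[4][4], float* dst, int remain) {
    if (remain >= 4) {            // vstore16 in the kernel
        std::memcpy(dst, out, 16 * sizeof(float));
    } else if (remain == 3) {     // vstore8 + vstore4 at element offset 2
        std::memcpy(dst, out, 8 * sizeof(float));
        std::memcpy(dst + 8, out[2], 4 * sizeof(float));
    } else if (remain == 2) {     // vstore8
        std::memcpy(dst, out, 8 * sizeof(float));
    } else if (remain == 1) {     // vstore4
        std::memcpy(dst, out, 4 * sizeof(float));
    }
}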
__kernel

@@ -393,25 +406,21 @@ void conv_2d_1x1_c4h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,
#endif

const int out_offset = (((out_b_idx*out_c_block + out_c_idx)*out_h + out_h_idx)* out_w + out_w4_idx)*4;

#ifdef BLOCK_LEAVE
const int remain = out_w - out_w4_idx;

if (remain >= 4) {
vstore4(out0, 0, output+out_offset);
vstore4(out1, 1, output+out_offset);
vstore4(out2, 2, output+out_offset);
vstore4(out3, 3, output+out_offset);
vstore16((FLOAT16)(out0, out1, out2, out3), 0, output+out_offset);
} else if (remain == 3) {
vstore4(out0, 0, output+out_offset);
vstore4(out1, 1, output+out_offset);
vstore8((FLOAT8)(out0, out1), 0, output+out_offset);
vstore4(out2, 2, output+out_offset);
} else if (remain == 2) {
vstore4(out0, 0, output+out_offset);
vstore4(out1, 1, output+out_offset);
vstore8((FLOAT8)(out0, out1), 0, output+out_offset);
} else if (remain == 1) {
vstore4(out0, 0, output+out_offset);
}

#else
vstore16((FLOAT16)(out0, out1, out2, out3), 0, output+out_offset);
#endif
}

@@ -539,44 +548,45 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,

const int out_offset = (((out_b_idx*out_c_block + out_c_idx*2)*out_h + out_h_idx)* out_w + out_w4_idx)*4;

const int remain = out_w - out_w4_idx;

__global FLOAT* _tempoutput = output + out_offset;
__global FLOAT* _tempoutput1 = _tempoutput + 4*out_h*out_w;

#ifdef BLOCK_LEAVE
const int remain = out_w - out_w4_idx;
if (remain >= 4) {
vstore4(out0, 0, _tempoutput);
vstore4(out1, 1, _tempoutput);
vstore4(out2, 2, _tempoutput);
vstore4(out3, 3, _tempoutput);
vstore16((FLOAT16)(out0, out1, out2, out3), 0, _tempoutput);
} else if (remain == 3) {
vstore4(out0, 0, _tempoutput);
vstore4(out1, 1, _tempoutput);
vstore8((FLOAT8)(out0, out1), 0, _tempoutput);
vstore4(out2, 2, _tempoutput);
} else if (remain == 2) {
vstore4(out0, 0, _tempoutput);
vstore4(out1, 1, _tempoutput);
vstore8((FLOAT8)(out0, out1), 0, _tempoutput);
} else if (remain == 1) {
vstore4(out0, 0, _tempoutput);
}
#ifdef CHANNEL_LEAVE
if(out_c_idx*2+1 >= out_c_block) {
return;
}
#endif
if (remain >= 4) {
vstore4(out4, 0, _tempoutput1);
vstore4(out5, 1, _tempoutput1);
vstore4(out6, 2, _tempoutput1);
vstore4(out7, 3, _tempoutput1);
vstore16((FLOAT16)(out4, out5, out6, out7), 0, _tempoutput1);
} else if (remain == 3) {
vstore4(out4, 0, _tempoutput1);
vstore4(out5, 1, _tempoutput1);
vstore8((FLOAT8)(out4, out5), 0, _tempoutput1);
vstore4(out6, 2, _tempoutput1);
} else if (remain == 2) {
vstore4(out4, 0, _tempoutput1);
vstore4(out5, 1, _tempoutput1);
vstore8((FLOAT8)(out4, out5), 0, _tempoutput1);
} else if (remain == 1) {
vstore4(out4, 0, _tempoutput1);
}
#else
vstore16((FLOAT16)(out0, out1, out2, out3), 0, _tempoutput);
#ifdef CHANNEL_LEAVE
if(out_c_idx*2+1 >= out_c_block) {
return;
}
#endif
vstore16((FLOAT16)(out4, out5, out6, out7), 0, _tempoutput1);
#endif
}
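A note on the CHANNEL_LEAVE guard the c8 variants use: each work-item writes two adjacent channel blocks, so the second block's stores must be skipped on the ragged channel edge. A sketch with our helper name:

#include <cstdio>

// Kernel guard: if (out_c_idx*2 + 1 >= out_c_block) return;
static bool writeSecondBlock(int outCIdx, int outCBlocks) {
    return outCIdx * 2 + 1 < outCBlocks;
}

int main() {
    // With 5 channel blocks, the item covering blocks {4,5} must drop block 5.
    printf("%d %d\n", writeSecondBlock(1, 5), writeSecondBlock(2, 5)); // 1 0
    return 0;
}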
@@ -668,26 +678,36 @@ void conv_2d_1x1_c8h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,

const int out_offset = (((out_b_idx*out_c_block + out_c_idx*2)*out_h + out_h_idx)* out_w + out_w2_idx)*4;

const int remain = out_w - out_w2_idx;

__global FLOAT* _tempoutput = output + out_offset;
__global FLOAT* _tempoutput1 = _tempoutput + 4*out_h*out_w;

#ifdef BLOCK_LEAVE
const int remain = out_w - out_w2_idx;
if (remain >= 2) {
vstore4(out0, 0, _tempoutput);
vstore4(out1, 1, _tempoutput);
vstore8((FLOAT8)(out0, out1), 0, _tempoutput);
} else if (remain == 1) {
vstore4(out0, 0, _tempoutput);
}
#ifdef CHANNEL_LEAVE
if(out_c_idx*2+1 >= out_c_block) {
return;
}
#endif
if (remain >= 2) {
vstore4(out4, 0, _tempoutput1);
vstore4(out5, 1, _tempoutput1);
vstore8((FLOAT8)(out4, out5), 0, _tempoutput1);
} else if (remain == 1) {
vstore4(out4, 0, _tempoutput1);
}
#else
vstore8((FLOAT8)(out0, out1), 0, _tempoutput);
#ifdef CHANNEL_LEAVE
if(out_c_idx*2+1 >= out_c_block) {
return;
}
#endif
vstore8((FLOAT8)(out4, out5), 0, _tempoutput1);
#endif
}

__kernel

@@ -814,15 +834,17 @@ void conv_2d_1x1_c4h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,

const int out_offset = (((out_b_idx*out_c_block + out_c_idx)*out_h + out_h_idx)* out_w + out_w2_idx)*4;

#ifdef BLOCK_LEAVE
const int remain = out_w - out_w2_idx;

if (remain >= 2) {
vstore4(out0, 0, output+out_offset);
vstore4(out1, 1, output+out_offset);
vstore8((FLOAT8)(out0, out1), 0, output+out_offset);
} else if (remain == 1) {
vstore4(out0, 0, output+out_offset);
}

#else
vstore8((FLOAT8)(out0, out1), 0, output+out_offset);
#endif
}

__kernel

@@ -932,13 +954,29 @@ void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS
#endif

const int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
#ifdef BLOCK_LEAVE
const int remain = out_hw.x - out_h_idx;
if(remain >= 4){
vstore4(out0, 0, output+out_offset);
vstore4(out1, out_hw.y, output+out_offset);
vstore4(out2, 2 * out_hw.y, output+out_offset);
vstore4(out3, 3 * out_hw.y, output+out_offset);
}else if(remain == 3){
vstore4(out0, 0, output+out_offset);
vstore4(out1, out_hw.y, output+out_offset);
vstore4(out2, 2 * out_hw.y, output+out_offset);
}else if(remain == 2){
vstore4(out0, 0, output+out_offset);
vstore4(out1, out_hw.y, output+out_offset);
}else if(remain == 1){
vstore4(out0, 0, output+out_offset);
}
#else
vstore4(out0, 0, output+out_offset);
if(out_h_idx + 1 >= out_hw.x) return;
vstore4(out1, out_hw.y, output+out_offset);
if(out_h_idx + 2 >= out_hw.x) return;
vstore4(out2, 2 * out_hw.y, output+out_offset);
if(out_h_idx + 3 >= out_hw.x) return;
vstore4(out3, 3 * out_hw.y, output+out_offset);
#endif
}
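Unlike the w-split kernels, these h-split stores cannot fuse into one wide vstore: the four results land in different output rows, one row stride apart, so each vstore4 gets an element offset that is a multiple of out_hw.y. A sketch of the offsets (our helper):

#include <cstdio>

// Element offset, in floats, of the k-th row's vstore4 relative to out_offset:
// vstore4(outk, k * out_hw.y, output + out_offset) lands k full rows below.
static int rowStoreOffsetFloats(int k, int outW) {
    return k * outW * 4;
}

int main() {
    const int outW = 16;
    for (int k = 0; k < 4; ++k)
        printf("out%d -> +%d floats\n", k, rowStoreOffsetFloats(k, outW));
    return 0;
}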
__kernel

@@ -1086,6 +1124,7 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS
#endif

int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
#ifdef BLOCK_LEAVE
const int remain = out_hw.x - out_h_idx;
if(remain >= 4){
vstore4(out0, 0, output+out_offset);

@@ -1102,9 +1141,11 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS
}else if(remain == 1){
vstore4(out0, 0, output+out_offset);
}
#ifdef CHANNEL_LEAVE
if(out_c_idx + 1 >= out_c_blocks){
return;
}
#endif
out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
if(remain >= 4){
vstore4(out4, 0, output+out_offset);

@@ -1121,6 +1162,22 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS
}else if(remain == 1){
vstore4(out4, 0, output+out_offset);
}
#else
vstore4(out0, 0, output+out_offset);
vstore4(out1, out_hw.y, output+out_offset);
vstore4(out2, 2 * out_hw.y, output+out_offset);
vstore4(out3, 3 * out_hw.y, output+out_offset);
#ifdef CHANNEL_LEAVE
if(out_c_idx + 1 >= out_c_blocks){
return;
}
#endif
out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
vstore4(out4, 0, output+out_offset);
vstore4(out5, out_hw.y, output+out_offset);
vstore4(out6, 2 * out_hw.y, output+out_offset);
vstore4(out7, 3 * out_hw.y, output+out_offset);
#endif
}

__kernel

@@ -1230,6 +1287,7 @@ void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS
#endif

int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
#ifdef BLOCK_LEAVE
const int remain = out_hw.x - out_h_idx;
if(remain >= 2){
vstore4(out0, 0, output+out_offset);

@@ -1237,9 +1295,11 @@ void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS
}else if(remain == 1){
vstore4(out0, 0, output+out_offset);
}
#ifdef CHANNEL_LEAVE
if(out_c_idx + 1 >= out_c_blocks){
return;
}
#endif
out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
if(remain >= 2){
vstore4(out2, 0, output+out_offset);

@@ -1247,4 +1307,198 @@ void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS
}else if(remain == 1){
vstore4(out2, 0, output+out_offset);
}
#else
vstore4(out0, 0, output+out_offset);
vstore4(out1, out_hw.y, output+out_offset);
#ifdef CHANNEL_LEAVE
if(out_c_idx + 1 >= out_c_blocks){
return;
}
#endif
out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
vstore4(out2, 0, output+out_offset);
vstore4(out3, out_hw.y, output+out_offset);
#endif
}

__kernel
void conv_2d_c8h1w4(GLOBAL_SIZE_2_DIMS
__global const FLOAT *input,
__global const FLOAT *weight,
__global const FLOAT *bias,
__global FLOAT *output,
__private const int2 in_hw,
__private const int inChannel,
__private const int in_c_blocks,
__private const int2 out_hw,
__private const int2 filter_hw,
__private const int2 stride_hw,
__private const int2 pad_hw,
__private const int2 dilate_hw,
__private const int out_w_blocks,
__private const int out_c_blocks,
__private const int out_h_blocks) {
const int out_c_w_idx = get_global_id(0); //c/4 w
const int out_b_h_idx = get_global_id(1); //b h

DEAL_NON_UNIFORM_DIM2(out_c_w_idx, out_b_h_idx);

const int out_c_idx = (out_c_w_idx / out_w_blocks) << 1;
const int out_w_idx = (out_c_w_idx % out_w_blocks) << 2;
const int out_b_idx = out_b_h_idx / out_hw.x;//equal to in_b_idx
const int out_h_idx = out_b_h_idx % out_hw.x;

FLOAT4 out0 = vload4(out_c_idx, bias);
FLOAT4 out1 = out0;
FLOAT4 out2 = out0;
FLOAT4 out3 = out0;

FLOAT4 out4 = vload4(out_c_idx + 1, bias);
FLOAT4 out5 = out4;
FLOAT4 out6 = out4;
FLOAT4 out7 = out4;

const int in_w0_idx_base = mad24(out_w_idx, stride_hw.y, -pad_hw.y);
const int in_w1_idx_base = in_w0_idx_base + stride_hw.y;
const int in_w2_idx_base = in_w1_idx_base + stride_hw.y;
const int in_w3_idx_base = in_w2_idx_base + stride_hw.y;

const int in_h_idx_base = mad24(out_h_idx, stride_hw.x, -pad_hw.x);

const int kh_start = select(0, (-in_h_idx_base + dilate_hw.x - 1) / dilate_hw.x, in_h_idx_base < 0);
const int in_h_idx_start = mad24(kh_start, dilate_hw.x, in_h_idx_base);
const int in_h_idx_end = min(mad24(filter_hw.x, dilate_hw.x, in_h_idx_base), in_hw.x);

const int weight_oc_offset = filter_hw.x * filter_hw.y * 4;
const int weight_ic_offset = out_c_blocks * weight_oc_offset;
for(ushort in_c_idx = 0; in_c_idx < in_c_blocks; in_c_idx++) {
//weights NC4HW4 [1, 4*icC4, ocC4*kh*kw, 1] xic4
//index: [0, 4*in_c_idx, out_c_idx*kh*kw + kh_start*kw + kw_start, 0]
int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + kh_start)*filter_hw.y + 0) * 4;

for(int iy = in_h_idx_start; iy < in_h_idx_end; iy += dilate_hw.x) {
const int inp_offset_base = (((out_b_idx * in_c_blocks + in_c_idx) * in_hw.x + iy) * in_hw.y + 0) * 4;

for(int fw = 0; fw < filter_hw.y; fw++) {
const int in_w0_idx = fw * dilate_hw.y + in_w0_idx_base;
const int in_w1_idx = fw * dilate_hw.y + in_w1_idx_base;
const int in_w2_idx = fw * dilate_hw.y + in_w2_idx_base;
const int in_w3_idx = fw * dilate_hw.y + in_w3_idx_base;

FLOAT4 in0 = (in_w0_idx < 0 || in_w0_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w0_idx, input+inp_offset_base);
FLOAT4 in1 = (in_w1_idx < 0 || in_w1_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w1_idx, input+inp_offset_base);
FLOAT4 in2 = (in_w2_idx < 0 || in_w2_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w2_idx, input+inp_offset_base);
FLOAT4 in3 = (in_w3_idx < 0 || in_w3_idx >= in_hw.y) ? (FLOAT4)0 : vload4(in_w3_idx, input+inp_offset_base);

FLOAT4 weight0 = vload4(0, weight+weight_offset);
FLOAT4 weight1 = vload4(0, weight+weight_offset+weight_ic_offset);
FLOAT4 weight2 = vload4(0, weight+weight_offset+weight_ic_offset*2);
FLOAT4 weight3 = vload4(0, weight+weight_offset+weight_ic_offset*3);

out0 = mad(in0.x, weight0, out0);
out0 = mad(in0.y, weight1, out0);
out0 = mad(in0.z, weight2, out0);
out0 = mad(in0.w, weight3, out0);

out1 = mad(in1.x, weight0, out1);
out1 = mad(in1.y, weight1, out1);
out1 = mad(in1.z, weight2, out1);
out1 = mad(in1.w, weight3, out1);

out2 = mad(in2.x, weight0, out2);
out2 = mad(in2.y, weight1, out2);
out2 = mad(in2.z, weight2, out2);
out2 = mad(in2.w, weight3, out2);

out3 = mad(in3.x, weight0, out3);
out3 = mad(in3.y, weight1, out3);
out3 = mad(in3.z, weight2, out3);
out3 = mad(in3.w, weight3, out3);

weight0 = vload4(0, weight+weight_oc_offset+weight_offset);
weight1 = vload4(0, weight+weight_oc_offset+weight_offset+weight_ic_offset);
weight2 = vload4(0, weight+weight_oc_offset+weight_offset+weight_ic_offset*2);
weight3 = vload4(0, weight+weight_oc_offset+weight_offset+weight_ic_offset*3);

out4 = mad(in0.x, weight0, out4);
out4 = mad(in0.y, weight1, out4);
out4 = mad(in0.z, weight2, out4);
out4 = mad(in0.w, weight3, out4);

out5 = mad(in1.x, weight0, out5);
out5 = mad(in1.y, weight1, out5);
out5 = mad(in1.z, weight2, out5);
out5 = mad(in1.w, weight3, out5);

out6 = mad(in2.x, weight0, out6);
out6 = mad(in2.y, weight1, out6);
out6 = mad(in2.z, weight2, out6);
out6 = mad(in2.w, weight3, out6);

out7 = mad(in3.x, weight0, out7);
out7 = mad(in3.y, weight1, out7);
out7 = mad(in3.z, weight2, out7);
out7 = mad(in3.w, weight3, out7);

weight_offset += 4;
}
}
}
#ifdef RELU
out0 = fmax(out0, (FLOAT4)0);
out1 = fmax(out1, (FLOAT4)0);
out2 = fmax(out2, (FLOAT4)0);
out3 = fmax(out3, (FLOAT4)0);
out4 = fmax(out4, (FLOAT4)0);
out5 = fmax(out5, (FLOAT4)0);
out6 = fmax(out6, (FLOAT4)0);
out7 = fmax(out7, (FLOAT4)0);
#endif

#ifdef RELU6
out0 = clamp(out0, (FLOAT4)0, (FLOAT4)6);
out1 = clamp(out1, (FLOAT4)0, (FLOAT4)6);
out2 = clamp(out2, (FLOAT4)0, (FLOAT4)6);
out3 = clamp(out3, (FLOAT4)0, (FLOAT4)6);
out4 = clamp(out4, (FLOAT4)0, (FLOAT4)6);
out5 = clamp(out5, (FLOAT4)0, (FLOAT4)6);
out6 = clamp(out6, (FLOAT4)0, (FLOAT4)6);
out7 = clamp(out7, (FLOAT4)0, (FLOAT4)6);
#endif

int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
#ifdef BLOCK_LEAVE
const int remain = out_hw.y - out_w_idx;
if(remain >= 4){
vstore16((FLOAT16)(out0, out1, out2, out3), 0, output+out_offset);
}else if(remain == 3){
vstore8((FLOAT8)(out0, out1), 0, output+out_offset);
vstore4(out2, 2, output+out_offset);
}else if(remain == 2){
vstore8((FLOAT8)(out0, out1), 0, output+out_offset);
}else if(remain == 1){
vstore4(out0, 0, output+out_offset);
}
#ifdef CHANNEL_LEAVE
if(out_c_idx + 1 >= out_c_blocks)return;
#endif
out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
if(remain >= 4){
vstore16((FLOAT16)(out4, out5, out6, out7), 0, output+out_offset);
}else if(remain == 3){
vstore8((FLOAT8)(out4, out5), 0, output+out_offset);
vstore4(out6, 2, output+out_offset);
}else if(remain == 2){
vstore8((FLOAT8)(out4, out5), 0, output+out_offset);
}else if(remain == 1){
vstore4(out4, 0, output+out_offset);
}
#else
vstore16((FLOAT16)(out0, out1, out2, out3), 0, output+out_offset);
#ifdef CHANNEL_LEAVE
if(out_c_idx + 1 >= out_c_blocks)return;
#endif
out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4;
vstore16((FLOAT16)(out4, out5, out6, out7), 0, output+out_offset);
#endif
}
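conv_2d_c8h1w4 above is the fully buffer-based variant; every inner step is the same mad() chain, with the four lanes of one input FLOAT4 scaling the four per-lane weight rows. A scalar expansion of one chain, as our reference sketch:

#include <cstdio>

struct F4 { float x, y, z, w; };

// Scalar expansion of: out = mad(in.x, w0, out); out = mad(in.y, w1, out);
//                      out = mad(in.z, w2, out); out = mad(in.w, w3, out);
static F4 madAccumulate(F4 acc, F4 in, const F4 w[4]) {
    acc.x += in.x * w[0].x + in.y * w[1].x + in.z * w[2].x + in.w * w[3].x;
    acc.y += in.x * w[0].y + in.y * w[1].y + in.z * w[2].y + in.w * w[3].y;
    acc.z += in.x * w[0].z + in.y * w[1].z + in.z * w[2].z + in.w * w[3].z;
    acc.w += in.x * w[0].w + in.y * w[1].w + in.z * w[2].w + in.w * w[3].w;
    return acc;
}

int main() {
    F4 acc = {0, 0, 0, 0}, in = {1, 2, 3, 4};
    F4 w[4] = {{1,0,0,0}, {0,1,0,0}, {0,0,1,0}, {0,0,0,1}};
    acc = madAccumulate(acc, in, w);
    printf("%g %g %g %g\n", acc.x, acc.y, acc.z, acc.w); // 1 2 3 4
    return 0;
}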
@@ -116,12 +116,154 @@ __kernel void gemmWinograd(__read_only image2d_t uInput, __read_only image2d_t u

__private int out_y_idx = mad24(pos_z, unitHeight, pos_y);
__private int out_x_idx = mad24(pos_w, unitWidth, srcX);
WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0);
if(srcX + 1 >= unitWidth) return;
WI_F(uOutput, (int2)(out_x_idx + 1, out_y_idx), o1);
if(srcX + 2 >= unitWidth) return;
WI_F(uOutput, (int2)(out_x_idx + 2, out_y_idx), o2);
if(srcX + 3 >= unitWidth) return;
WI_F(uOutput, (int2)(out_x_idx + 3, out_y_idx), o3);
const int remain = unitWidth - srcX;
if(remain >= 4){
WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0);
WI_F(uOutput, (int2)(out_x_idx + 1, out_y_idx), o1);
WI_F(uOutput, (int2)(out_x_idx + 2, out_y_idx), o2);
WI_F(uOutput, (int2)(out_x_idx + 3, out_y_idx), o3);
}else if(remain == 3){
WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0);
WI_F(uOutput, (int2)(out_x_idx + 1, out_y_idx), o1);
WI_F(uOutput, (int2)(out_x_idx + 2, out_y_idx), o2);
}else if(remain == 2){
WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0);
WI_F(uOutput, (int2)(out_x_idx + 1, out_y_idx), o1);
}else if(remain == 1){
WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0);
}
}
}

__kernel void gemmWinogradW2(__read_only image2d_t uInput, __read_only image2d_t uKernel, __write_only image2d_t uOutput,
__private const int unitWidth, __private const int unitHeight, __private const int dstChannelC4, __private const int multiLength, __private const int alpha2) {

int2 pos = (int2)(get_global_id(0), get_global_id(1));
const int unitWidth8 = (unitWidth + 7) / 8;
if (pos.x < unitWidth8 * unitHeight && pos.y < alpha2 * dstChannelC4) {

const int pos_x = pos.x % unitWidth8;
const int pos_y = pos.x / unitWidth8;
const int pos_z = pos.y % dstChannelC4;
const int pos_w = pos.y / dstChannelC4;

FLOAT4 o0 = (FLOAT4)(0);
FLOAT4 o1 = (FLOAT4)(0);
FLOAT4 o2 = (FLOAT4)(0);
FLOAT4 o3 = (FLOAT4)(0);
FLOAT4 o4 = (FLOAT4)(0);
FLOAT4 o5 = (FLOAT4)(0);
FLOAT4 o6 = (FLOAT4)(0);
FLOAT4 o7 = (FLOAT4)(0);
int srcY = mad24(pos_w, unitHeight, pos_y);
int srcX = pos_x << 3;

for (int k = 0; k < multiLength; ++k) {
__private int index = mul24(k, 4);
__private int x_offset = mul24(k, unitWidth);
FLOAT4 k0 = RI_F(uKernel, SAMPLER, (int2)(index, pos.y));
FLOAT4 k1 = RI_F(uKernel, SAMPLER, (int2)(index + 1, pos.y));
FLOAT4 k2 = RI_F(uKernel, SAMPLER, (int2)(index + 2, pos.y));
FLOAT4 k3 = RI_F(uKernel, SAMPLER, (int2)(index + 3, pos.y));

FLOAT4 s0 = RI_F(uInput, SAMPLER, (int2)(srcX + x_offset, srcY));
FLOAT4 s1 = RI_F(uInput, SAMPLER, (int2)(srcX + x_offset + 1, srcY));
FLOAT4 s2 = RI_F(uInput, SAMPLER, (int2)(srcX + x_offset + 2, srcY));
FLOAT4 s3 = RI_F(uInput, SAMPLER, (int2)(srcX + x_offset + 3, srcY));
FLOAT4 s4 = RI_F(uInput, SAMPLER, (int2)(srcX + x_offset + 4, srcY));
FLOAT4 s5 = RI_F(uInput, SAMPLER, (int2)(srcX + x_offset + 5, srcY));
FLOAT4 s6 = RI_F(uInput, SAMPLER, (int2)(srcX + x_offset + 6, srcY));
FLOAT4 s7 = RI_F(uInput, SAMPLER, (int2)(srcX + x_offset + 7, srcY));

o0 = mad(s0.x, k0, o0);
o0 = mad(s0.y, k1, o0);
o0 = mad(s0.z, k2, o0);
o0 = mad(s0.w, k3, o0);

o1 = mad(s1.x, k0, o1);
o1 = mad(s1.y, k1, o1);
o1 = mad(s1.z, k2, o1);
o1 = mad(s1.w, k3, o1);

o2 = mad(s2.x, k0, o2);
o2 = mad(s2.y, k1, o2);
o2 = mad(s2.z, k2, o2);
o2 = mad(s2.w, k3, o2);

o3 = mad(s3.x, k0, o3);
o3 = mad(s3.y, k1, o3);
o3 = mad(s3.z, k2, o3);
o3 = mad(s3.w, k3, o3);

o4 = mad(s4.x, k0, o4);
o4 = mad(s4.y, k1, o4);
o4 = mad(s4.z, k2, o4);
o4 = mad(s4.w, k3, o4);

o5 = mad(s5.x, k0, o5);
o5 = mad(s5.y, k1, o5);
o5 = mad(s5.z, k2, o5);
o5 = mad(s5.w, k3, o5);

o6 = mad(s6.x, k0, o6);
o6 = mad(s6.y, k1, o6);
o6 = mad(s6.z, k2, o6);
o6 = mad(s6.w, k3, o6);

o7 = mad(s7.x, k0, o7);
o7 = mad(s7.y, k1, o7);
o7 = mad(s7.z, k2, o7);
o7 = mad(s7.w, k3, o7);
}

__private int out_y_idx = mad24(pos_z, unitHeight, pos_y);
__private int out_x_idx = mad24(pos_w, unitWidth, srcX);
const int remain = unitWidth - srcX;
if(remain >= 8){
WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0);
WI_F(uOutput, (int2)(out_x_idx + 1, out_y_idx), o1);
WI_F(uOutput, (int2)(out_x_idx + 2, out_y_idx), o2);
WI_F(uOutput, (int2)(out_x_idx + 3, out_y_idx), o3);
WI_F(uOutput, (int2)(out_x_idx + 4, out_y_idx), o4);
WI_F(uOutput, (int2)(out_x_idx + 5, out_y_idx), o5);
WI_F(uOutput, (int2)(out_x_idx + 6, out_y_idx), o6);
WI_F(uOutput, (int2)(out_x_idx + 7, out_y_idx), o7);
}else if(remain == 7){
WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0);
WI_F(uOutput, (int2)(out_x_idx + 1, out_y_idx), o1);
WI_F(uOutput, (int2)(out_x_idx + 2, out_y_idx), o2);
WI_F(uOutput, (int2)(out_x_idx + 3, out_y_idx), o3);
WI_F(uOutput, (int2)(out_x_idx + 4, out_y_idx), o4);
WI_F(uOutput, (int2)(out_x_idx + 5, out_y_idx), o5);
WI_F(uOutput, (int2)(out_x_idx + 6, out_y_idx), o6);
}else if(remain == 6){
WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0);
WI_F(uOutput, (int2)(out_x_idx + 1, out_y_idx), o1);
WI_F(uOutput, (int2)(out_x_idx + 2, out_y_idx), o2);
WI_F(uOutput, (int2)(out_x_idx + 3, out_y_idx), o3);
WI_F(uOutput, (int2)(out_x_idx + 4, out_y_idx), o4);
WI_F(uOutput, (int2)(out_x_idx + 5, out_y_idx), o5);
}else if(remain == 5){
WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0);
WI_F(uOutput, (int2)(out_x_idx + 1, out_y_idx), o1);
WI_F(uOutput, (int2)(out_x_idx + 2, out_y_idx), o2);
WI_F(uOutput, (int2)(out_x_idx + 3, out_y_idx), o3);
WI_F(uOutput, (int2)(out_x_idx + 4, out_y_idx), o4);
}else if(remain == 4){
WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0);
WI_F(uOutput, (int2)(out_x_idx + 1, out_y_idx), o1);
WI_F(uOutput, (int2)(out_x_idx + 2, out_y_idx), o2);
WI_F(uOutput, (int2)(out_x_idx + 3, out_y_idx), o3);
}else if(remain == 3){
WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0);
WI_F(uOutput, (int2)(out_x_idx + 1, out_y_idx), o1);
WI_F(uOutput, (int2)(out_x_idx + 2, out_y_idx), o2);
}else if(remain == 2){
WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0);
WI_F(uOutput, (int2)(out_x_idx + 1, out_y_idx), o1);
}else if(remain == 1){
WI_F(uOutput, (int2)(out_x_idx, out_y_idx), o0);
}
}
}
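gemmWinogradW2 widens the tile from four to eight unit columns, so get_global_id(0) is decoded through unitWidth8 and the store ladder handles up to seven leftover columns. A sketch of the decode with an assumed unitWidth:

#include <cstdio>

int main() {
    int unitWidth = 37;                        // assumed example value
    int unitWidth8 = (unitWidth + 7) / 8;      // tiles of 8 unit columns
    for (int tileX = 0; tileX < unitWidth8; ++tileX) {
        int srcX = tileX << 3;
        int remain = unitWidth - srcX;         // drives the 8..1 store ladder
        printf("tile %d: columns [%d, %d), %d writes\n",
               tileX, srcX, srcX + 8, remain >= 8 ? 8 : remain);
    }
    return 0;
}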
File diff suppressed because one or more lines are too long

@@ -68,6 +68,35 @@ __kernel void raster_buffer(
output[outputIndex] = input[inputIndex];
}

__kernel void raster_buffer_combine(
GLOBAL_SIZE_3_DIMS
__global FLOAT *input,
__private const int inputOffset,
__private const int combineSrcOffset,
__private const int inputStride0,
__private const int inputStride1,
__private const int inputStride2,
__global FLOAT *output,
__private const int outputOffset,
__private const int combineDstOffset,
__private const int outputStride0,
__private const int outputStride1,
__private const int outputStride2,
__private const int global_size0
) {
const int idx = get_global_id(0);
const int y = get_global_id(1);
const int z = get_global_id(2);

DEAL_NON_UNIFORM_DIM3(idx, y, z);
const int x = idx % global_size0;
const int id = idx / global_size0;

int inputIndex = inputOffset + id * combineSrcOffset + z * inputStride0 + y * inputStride1 + x * inputStride2;
int outputIndex = outputOffset + id * combineDstOffset + z * outputStride0 + y * outputStride1 + x * outputStride2;
output[outputIndex] = input[inputIndex];
}
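raster_buffer_combine above folds several regions whose src/dst offsets advance by a constant step into a single dispatch. A sketch of the flat index it computes (names are ours):

#include <cstdio>

// `id` picks the region, (x, y, z) the element inside it.
static int combinedIndex(int base, int id, int step,
                         int z, int s0, int y, int s1, int x, int s2) {
    return base + id * step + z * s0 + y * s1 + x * s2;
}

int main() {
    // Two regions, 8 floats apart: region 1, element (x=2, y=1, z=0).
    printf("%d\n", combinedIndex(0, 1, 8, 0, 32, 1, 4, 2, 1)); // 8 + 4 + 2 = 14
    return 0;
}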
__kernel void raster_image(
GLOBAL_SIZE_3_DIMS

@@ -20,7 +20,8 @@ ErrorCode CommonExecution::onExecute(const std::vector<Tensor *> &inputs, const
int idx = 0;
#else
if(runtime->isUseRecordQueue()){
runtime->getRecordings()->emplace_back(mRecording);
if(runtime->isDevideOpRecord())
runtime->getRecordings()->emplace_back(mRecording);
return NO_ERROR;
}
#endif
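Reading this flattened hunk as old line followed by its replacement, the unconditional emplace_back becomes conditional on per-op recording; the same change repeats across the onExecute implementations below. A compilable sketch of the new control flow, with stand-in types for the MNN runtime pieces (only the calls shown in the diff are real):

#include <vector>

struct Recording {};
enum ErrorCode { NO_ERROR };

// Minimal stand-in for OpenCLRuntime's record-queue switches.
struct Runtime {
    bool useRecordQueue = true;
    bool devideOpRecord = true;   // spelling follows MNN's isDevideOpRecord()
    std::vector<Recording*> recordings;
};

ErrorCode onExecuteEpilogue(Runtime& rt, Recording* mRecording) {
    if (rt.useRecordQueue) {
        // Only republish this op's recording when per-op recording is on.
        if (rt.devideOpRecord)
            rt.recordings.emplace_back(mRecording);
        return NO_ERROR;          // skip the normal enqueue path entirely
    }
    return NO_ERROR;              // real code falls through to enqueue here
}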
@@ -77,6 +77,8 @@ ConvExecution::ConvExecution(const std::vector<Tensor *> &inputs, const std::vec
int kernelWidth = conv2dCommonParams->kernelX();
int kernelHeight = conv2dCommonParams->kernelY();
int outputChannel = conv2dCommonParams->outputCount();
auto gpuType = mOpenCLBackend->getOpenCLRuntime()->getGpuType();
mWeightUseBuffer = gpuType == GpuType::MALI;

int weightSize = 0;
const float *filterDataPtr = nullptr;

@@ -103,13 +105,12 @@ ConvExecution::ConvExecution(const std::vector<Tensor *> &inputs, const std::vec
}
int inputChannel = weightSize / (kernelWidth * kernelHeight * outputChannel);

auto gpuType = mOpenCLBackend->getOpenCLRuntime()->getGpuType();

//select opt conv method
std::string kernelName = "conv_2d_c4h1w4";
if (kernelHeight == kernelWidth && kernelHeight == 1 && mPaddings[0] == 0 &&
mPaddings[1] == 0) {
mConv1x1Opt = (mStrides[0] == 1 && mStrides[1] == 1 && gpuType == GpuType::MALI);
mConv1x1Opt = (mStrides[0] == 1 && mStrides[1] == 1 && gpuType == GpuType::MALI && !mWeightUseBuffer);
#if 0
if((gpuType == GpuType::ADRENO)){
uint64_t useLocalSize = UNIT*UNIT*4*sizeof(float)*4;

@@ -193,6 +194,36 @@ ConvExecution::ConvExecution(const std::vector<Tensor *> &inputs, const std::vec
}
mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*(mBiasBuffer.get()), biasPtrCL);

}else if(kernelHeight == kernelWidth && kernelHeight == 1 && mPaddings[0] == 0 && mPaddings[1] == 0 && mWeightUseBuffer){
cl_int error;
std::shared_ptr<Tensor> filterBuffer(Tensor::createDevice<float>({UP_DIV(outputChannel, 4), ROUND_UP(inputChannel, 4), 4}));

int buffer_size = filterBuffer->elementSize();
if(mOpenCLBackend->getOpenCLRuntime()->isSupportedFP16()) {
buffer_size *= sizeof(half_float::half);
} else {
buffer_size *= sizeof(float);
}

mKernelBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size));
auto kernelBufferPtr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*(mKernelBuffer.get()), true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &error);
if(kernelBufferPtr != nullptr && error == CL_SUCCESS){
::memset(kernelBufferPtr, 0, buffer_size);
for(int o = 0; o < outputChannel; o++){
for(int i = 0 ; i < inputChannel; i++){
int bufferIdx = (o/4) * ROUND_UP(inputChannel, 4)*4 + i*4 + (o%4);
int filterIdx = o*inputChannel + i;
if(mOpenCLBackend->getOpenCLRuntime()->isSupportedFP16()){
((half_float::half*)kernelBufferPtr)[bufferIdx] = (half_float::half)(filterDataPtr[filterIdx]);
}else{
((float*)kernelBufferPtr)[bufferIdx] = (float)(filterDataPtr[filterIdx]);
}
}
}
}else{
MNN_ERROR("Map error ptrCL == nullptr \n");
}
mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*(mKernelBuffer.get()), kernelBufferPtr);
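The packing loop above rearranges plain OIHW 1x1 weights so that four consecutive output channels of one input channel share a FLOAT4, which is what the buffer-path conv_2d_1x1 kernels vload4. Our helpers mirroring that index math:

#include <cassert>

static int roundUp4(int x) { return (x + 3) / 4 * 4; }   // ROUND_UP(x, 4)

// bufferIdx = (o/4) * ROUND_UP(inputChannel, 4)*4 + i*4 + (o%4)
static int packedIdx(int o, int i, int inputChannel) {
    return (o / 4) * roundUp4(inputChannel) * 4 + i * 4 + (o % 4);
}

int main() {
    const int ic = 5;
    // Adjacent output channels land in adjacent lanes of the same FLOAT4...
    assert(packedIdx(1, 2, ic) - packedIdx(0, 2, ic) == 1);
    // ...and stepping one input channel moves one full FLOAT4.
    assert(packedIdx(0, 3, ic) - packedIdx(0, 2, ic) == 4);
    return 0;
}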
}else{
std::vector<int> filterImageShape{(int)inputChannel, (int)(UP_DIV(outputChannel, 4) * kernelWidth * kernelHeight)};
std::shared_ptr<Tensor> filterBuffer(

@@ -223,15 +254,34 @@ ConvExecution::ConvExecution(const std::vector<Tensor *> &inputs, const std::vec
}
mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(filterBufferCL, ptrCL);

mFilter.reset(Tensor::createDevice<float>({1, filterImageShape[1], 1, 4 * filterImageShape[0]}));
mOpenCLBackend->onAcquireBuffer(mFilter.get(), Backend::STATIC);
MNN::OpenCL::ImageBufferConvertor imageBufferConvertor{mOpenCLBackend->getOpenCLRuntime()};

std::string buildOption = "";
if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf() == false){
buildOption = "-DBUFFER_INP_FP32";
if(mWeightUseBuffer){
mFilter.reset(Tensor::createDevice<float>({UP_DIV(inputChannel, 4)*4, UP_DIV(outputChannel, 4), kernelWidth * kernelHeight, 4}));
int kernel_buffer_size = UP_DIV(outputChannel, 4)*4* UP_DIV(inputChannel, 4)*4* kernelWidth* kernelHeight;
if(mOpenCLBackend->getOpenCLRuntime()->isSupportedFP16()) {
kernel_buffer_size *= sizeof(half_float::half);
} else {
kernel_buffer_size *= sizeof(float);
}
mKernelBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, kernel_buffer_size));
mFilter.get()->buffer().device = (uint64_t)mKernelBuffer.get();
MNN::OpenCL::BufferConvertor bufferConvertor{mOpenCLBackend->getOpenCLRuntime()};

bool needTrans = false;
if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf() == false){
needTrans = true;
}
bufferConvertor.convertToNC4HW4Buffer(filterBuffer.get(), MNN::OpenCL::CONV2D_FILTER, mFilter.get(), needTrans);
} else{
mFilter.reset(Tensor::createDevice<float>({1, filterImageShape[1], 1, 4 * filterImageShape[0]}));
mOpenCLBackend->onAcquireBuffer(mFilter.get(), Backend::STATIC);
MNN::OpenCL::ImageBufferConvertor imageBufferConvertor{mOpenCLBackend->getOpenCLRuntime()};

std::string buildOption = "";
if(mOpenCLBackend->getOpenCLRuntime()->isWeightCpuTransHalf() == false){
buildOption = "-DBUFFER_INP_FP32";
}
imageBufferConvertor.convertBufferToImage(filterBuffer.get(), MNN::OpenCL::CONV2D_FILTER, mFilter.get(), false, buildOption);
}
imageBufferConvertor.convertBufferToImage(filterBuffer.get(), MNN::OpenCL::CONV2D_FILTER, mFilter.get(), false, buildOption);
}

// Create Kernel

@@ -244,6 +294,9 @@ ConvExecution::ConvExecution(const std::vector<Tensor *> &inputs, const std::vec
} else if (mConv2dCommonParams->relu6()) {
mBuildOptions.emplace("-DRELU6");
}
if(mWeightUseBuffer){
mBuildOptions.emplace("-DUSE_BUFFER");
}

mKernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName, mBuildOptions);

@@ -255,7 +308,7 @@ ConvExecution::ConvExecution(const std::vector<Tensor *> &inputs, const std::vec
}

ConvExecution::~ConvExecution() {
if(mUseLocalMem || !mConv1x1Opt){
if((mUseLocalMem || !mConv1x1Opt) && !mWeightUseBuffer){
mOpenCLBackend->onReleaseBuffer(mFilter.get(), Backend::STATIC);
}
}

@@ -329,28 +382,77 @@ ErrorCode ConvExecution::onResize(const std::vector<Tensor *> &inputs, const std

}else{
mGlobalWorkSize = {
static_cast<uint32_t>(UP_DIV(outputShape.at(3), 4) * static_cast<uint32_t>(UP_DIV(outputShape.at(2), 4))),
static_cast<uint32_t>(outputShape.at(0) * outputShape.at(1))};

auto kernel = &mKernel;
uint32_t idx = 0;
int inputImageShape[2] = {inputHeight, inputWidth};
int outputImageShape[2] = {height, width};
int stideShape[2] = {mStrides[0], mStrides[1]};
kernel->setArg(idx++, mGlobalWorkSize[0]);
kernel->setArg(idx++, mGlobalWorkSize[1]);
kernel->setArg(idx++, openCLImage(input));
kernel->setArg(idx++, openCLImage(mFilter.get()));
kernel->setArg(idx++, openCLImage(mBias.get()));
kernel->setArg(idx++, openCLImage(output));
kernel->setArg(idx++, sizeof(inputImageShape), inputImageShape);
kernel->setArg(idx++, static_cast<int>(inputChannelBlocks));
kernel->setArg(idx++, sizeof(outputImageShape), outputImageShape);
kernel->setArg(idx++, sizeof(stideShape), stideShape);
kernel->setArg(idx++, UP_DIV(width, 4));
std::string kernelName = "conv_2d_1x1";
mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), kernelName, mKernel).first;
const int total_kernel = 2;
std::string kernelName[total_kernel] = {"conv_2d_1x1", "conv_2d_1x1_c8h1w4"};
int itemC[total_kernel] = {4, 8};
int itemH[total_kernel] = {1, 1};
int itemW[total_kernel] = {4, 4};

int actual_kernel = total_kernel;

cl::Kernel kernel[total_kernel];
std::vector<uint32_t> globalWorkSize[total_kernel];
std::vector<uint32_t> localWorkSize[total_kernel];
std::pair<int, int> min_cost(INT_MAX, 0);//(min_time, min_index)

for(int knl_idx = 0; knl_idx < total_kernel; knl_idx++) {
kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[knl_idx], mBuildOptions);
uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx]));

globalWorkSize[knl_idx] = {static_cast<uint32_t>(UP_DIV(outputShape.at(3), itemC[knl_idx]) * UP_DIV(outputShape.at(2), itemW[knl_idx])), static_cast<uint32_t>(outputShape.at(0) * UP_DIV(outputShape.at(1), itemH[knl_idx]))};
uint32_t idx = 0;
kernel[knl_idx].setArg(idx++, globalWorkSize[knl_idx][0]);
kernel[knl_idx].setArg(idx++, globalWorkSize[knl_idx][1]);
kernel[knl_idx].setArg(idx++, openCLImage(input));
if(mWeightUseBuffer){
kernel[knl_idx].setArg(idx++, *mKernelBuffer.get());
}else{
kernel[knl_idx].setArg(idx++, openCLImage(mFilter.get()));
}
kernel[knl_idx].setArg(idx++, openCLImage(mBias.get()));
kernel[knl_idx].setArg(idx++, openCLImage(output));
kernel[knl_idx].setArg(idx++, sizeof(inputImageShape), inputImageShape);
kernel[knl_idx].setArg(idx++, static_cast<int>(inputChannelBlocks));
kernel[knl_idx].setArg(idx++, sizeof(outputImageShape), outputImageShape);
kernel[knl_idx].setArg(idx++, sizeof(stideShape), stideShape);
kernel[knl_idx].setArg(idx++, UP_DIV(width, 4));
kernel[knl_idx].setArg(idx++, UP_DIV(outputShape.at(3), 4));

std::pair<std::vector<uint32_t>, uint32_t> retTune;
retTune = localWS2DDefault(globalWorkSize[knl_idx], mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), kernelName[knl_idx] + info, kernel[knl_idx]);

//printf("conv1x1 kernel_%d = %d [%d, %d]\n", knl_idx, retTune.second, retTune.first[0], retTune.first[1]);
if(min_cost.first > retTune.second) {
min_cost.first = retTune.second;
min_cost.second = knl_idx;
mLocalWorkSize = {retTune.first[0], retTune.first[1]};
}
}
int min_index = min_cost.second;
//printf("min_index = %d %d\n", min_index, min_cost.first);
mGlobalWorkSize = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]};
mKernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d", kernelName[min_index], mBuildOptions);

uint32_t idx = 0;
mKernel.setArg(idx++, mGlobalWorkSize[0]);
mKernel.setArg(idx++, mGlobalWorkSize[1]);
mKernel.setArg(idx++, openCLImage(input));
if(mWeightUseBuffer){
mKernel.setArg(idx++, *mKernelBuffer.get());
}else{
mKernel.setArg(idx++, openCLImage(mFilter.get()));
}
mKernel.setArg(idx++, openCLImage(mBias.get()));
mKernel.setArg(idx++, openCLImage(output));
mKernel.setArg(idx++, sizeof(inputImageShape), inputImageShape);
mKernel.setArg(idx++, static_cast<int>(inputChannelBlocks));
mKernel.setArg(idx++, sizeof(outputImageShape), outputImageShape);
mKernel.setArg(idx++, sizeof(stideShape), stideShape);
mKernel.setArg(idx++, UP_DIV(width, 4));
mKernel.setArg(idx++, UP_DIV(outputShape.at(3), 4));
recordKernel2d(mKernel, mGlobalWorkSize, mLocalWorkSize, mOpenCLBackend->getOpenCLRuntime());
}
}else {
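The block above introduces the tune-and-select pattern that reappears in the Winograd gemm further down: build each candidate kernel, let localWS2DDefault time its best local size, keep the cheapest, then rebuild the winner for real use. A condensed sketch with stand-in names and cost values:

#include <climits>
#include <string>
#include <utility>
#include <vector>

struct Candidate { std::string name; int cost; };  // cost ~ measured runtime

// Returns the index of the cheapest candidate, mirroring the
// (min_time, min_index) pair used in the diff.
int pickKernel(const std::vector<Candidate>& candidates) {
    std::pair<int, int> minCost(INT_MAX, 0);
    for (int i = 0; i < (int)candidates.size(); ++i) {
        if (minCost.first > candidates[i].cost) {
            minCost = {candidates[i].cost, i};
        }
    }
    return minCost.second;  // the real code rebuilds this kernel afterwards
}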
@ -385,7 +487,11 @@ ErrorCode ConvExecution::onResize(const std::vector<Tensor *> &inputs, const std
|
|||
ret |= kernel[knl_idx].setArg(idx++, globalWorkSize[knl_idx][0]);
|
||||
ret |= kernel[knl_idx].setArg(idx++, globalWorkSize[knl_idx][1]);
|
||||
ret |= kernel[knl_idx].setArg(idx++, openCLImage(input));
|
||||
ret |= kernel[knl_idx].setArg(idx++, openCLImage(mFilter.get()));
|
||||
if(mWeightUseBuffer){
|
||||
ret |= kernel[knl_idx].setArg(idx++, openCLBuffer(mFilter.get()));
|
||||
}else{
|
||||
ret |= kernel[knl_idx].setArg(idx++, openCLImage(mFilter.get()));
|
||||
}
|
||||
ret |= kernel[knl_idx].setArg(idx++, openCLImage(mBias.get()));
|
||||
ret |= kernel[knl_idx].setArg(idx++, openCLImage(output));
|
||||
ret |= kernel[knl_idx].setArg(idx++, sizeof(inputImageShape), inputImageShape);
|
||||
|
@ -418,7 +524,11 @@ ErrorCode ConvExecution::onResize(const std::vector<Tensor *> &inputs, const std
|
|||
ret |= mKernel.setArg(idx++, mGlobalWorkSize[0]);
|
||||
ret |= mKernel.setArg(idx++, mGlobalWorkSize[1]);
|
||||
ret |= mKernel.setArg(idx++, openCLImage(input));
|
||||
ret |= mKernel.setArg(idx++, openCLImage(mFilter.get()));
|
||||
if(mWeightUseBuffer){
|
||||
ret |= mKernel.setArg(idx++, openCLBuffer(mFilter.get()));
|
||||
}else{
|
||||
ret |= mKernel.setArg(idx++, openCLImage(mFilter.get()));
|
||||
}
|
||||
ret |= mKernel.setArg(idx++, openCLImage(mBias.get()));
|
||||
ret |= mKernel.setArg(idx++, openCLImage(output));
|
||||
ret |= mKernel.setArg(idx++, sizeof(inputImageShape), inputImageShape);
|
||||
|
@ -456,7 +566,8 @@ ErrorCode ConvExecution::onExecute(const std::vector<Tensor *> &inputs, const st
|
|||
MNN_PRINT("kernel cost:%f us Conv UseLocalMem\n",costTime);
|
||||
#else
|
||||
if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){
|
||||
mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
|
||||
if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord())
|
||||
mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
|
||||
#ifdef LOG_VERBOSE
|
||||
MNN_PRINT("end ConvExecution onExecute !\n");
|
||||
#endif
|
||||
|
@ -476,7 +587,8 @@ ErrorCode ConvExecution::onExecute(const std::vector<Tensor *> &inputs, const st
|
|||
MNN_PRINT("kernel cost:%d us Conv2D\n",costTime);
|
||||
#else
|
||||
if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){
|
||||
mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
|
||||
if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord())
|
||||
mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
|
||||
#ifdef LOG_VERBOSE
|
||||
MNN_PRINT("end ConvExecution onExecute !\n");
|
||||
#endif
|
||||
|
|
|
@ -57,6 +57,7 @@ private:
|
|||
std::shared_ptr<cl::Buffer> mKernelBuffer;
|
||||
std::shared_ptr<cl::Buffer> mBiasBuffer;
|
||||
std::set<std::string> mBuildOptions;
|
||||
bool mWeightUseBuffer = false;
|
||||
};
|
||||
|
||||
} // namespace OpenCL
|
||||
|
|
|
@ -237,7 +237,6 @@ ErrorCode ConvWinograd::onResize(const std::vector<Tensor*>& inputs, const std::
|
|||
"winogradTransformDest", buildOptions);
|
||||
mMaxWGS_D[i] = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(mDestTransform[i]));
|
||||
}
|
||||
mMatMul[i] = runTime->buildKernel("gemm", "gemmWinograd", basic);
|
||||
mMaxWGS_M[i] = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(mMatMul[i]));
|
||||
}
|
||||
|
||||
|
@ -261,14 +260,6 @@ ErrorCode ConvWinograd::onResize(const std::vector<Tensor*>& inputs, const std::
|
|||
ret |= mSourceTransform[b].setArg(8, icC4);
|
||||
ret |= mSourceTransform[b].setArg(9, b);
|
||||
|
||||
ret |= mMatMul[b].setArg(0, openCLImage(mSource.get()));
|
||||
ret |= mMatMul[b].setArg(1, *mWeight);
|
||||
ret |= mMatMul[b].setArg(2, openCLImage(mDest.get()));
|
||||
ret |= mMatMul[b].setArg(3, wUnit);
|
||||
ret |= mMatMul[b].setArg(4, hUnit);
|
||||
ret |= mMatMul[b].setArg(5, ocC4);
|
||||
ret |= mMatMul[b].setArg(6, icC4);
|
||||
ret |= mMatMul[b].setArg(7, alpha*alpha);
|
||||
|
||||
ret |= mDestTransform[b].setArg(0, openCLImage(mDest.get()));
|
||||
ret |= mDestTransform[b].setArg(1, *mBias);
|
||||
|
@ -291,10 +282,56 @@ ErrorCode ConvWinograd::onResize(const std::vector<Tensor*>& inputs, const std::
|
|||
|
||||
/*MatMul*/
|
||||
{
|
||||
const int total_kernel = 2;
|
||||
const std::string kernelName[total_kernel] = {"gemmWinograd", "gemmWinogradW2"};
|
||||
int itemW[total_kernel] = {4, 8};
|
||||
auto gemmHeight = ocC4;
|
||||
mGWS_M[b] = {static_cast<uint32_t>(UP_DIV(wUnit, 4) * hUnit), static_cast<uint32_t>(alpha * alpha * ocC4)};
|
||||
std::string kernelName = "gemmWinograd";
|
||||
mLWS_M[b] = localWS2DDefault(mGWS_M[b], mMaxWGS_M[b], mOpenCLBackend->getOpenCLRuntime(), kernelName, mMatMul[b]).first;
|
||||
int actual_kernel = total_kernel;
|
||||
|
||||
cl::Kernel kernel[total_kernel];
|
||||
std::vector<uint32_t> globalWorkSize[total_kernel];
|
||||
std::vector<uint32_t> localWorkSize[total_kernel];
|
||||
std::pair<uint32_t, int> min_cost(UINT_MAX, 0); //(min_time, min_index)
|
||||
for (int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) {
|
||||
cl_int ret = CL_SUCCESS;
|
||||
kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm", kernelName[knl_idx], basic);
|
||||
uint32_t maxWorkGroupSize = static_cast<uint32_t>(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx]));
|
||||
|
||||
globalWorkSize[knl_idx] = {static_cast<uint32_t>(UP_DIV(wUnit, itemW[knl_idx]) * hUnit), static_cast<uint32_t>(alpha * alpha * ocC4)};
|
||||
ret |= kernel[knl_idx].setArg(0, openCLImage(mSource.get()));
|
||||
ret |= kernel[knl_idx].setArg(1, *mWeight);
|
||||
ret |= kernel[knl_idx].setArg(2, openCLImage(mDest.get()));
|
||||
ret |= kernel[knl_idx].setArg(3, wUnit);
|
||||
ret |= kernel[knl_idx].setArg(4, hUnit);
|
||||
ret |= kernel[knl_idx].setArg(5, ocC4);
|
||||
ret |= kernel[knl_idx].setArg(6, icC4);
|
||||
ret |= kernel[knl_idx].setArg(7, alpha*alpha);
|
||||
MNN_CHECK_CL_SUCCESS(ret, "setArg ConvWinogradExecution gemm");
|
||||
|
||||
std::pair<std::vector<uint32_t>, uint32_t> retTune;
|
||||
retTune = localWS2DDefault(globalWorkSize[knl_idx], maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), kernelName[knl_idx], kernel[knl_idx]);
|
||||
// printf("gemm %d, %d\n", knl_idx, retTune.second);
|
||||
if (min_cost.first > retTune.second) {
|
||||
min_cost.first = retTune.second;
|
||||
min_cost.second = knl_idx;
|
||||
mLWS_M[b] = {retTune.first[0], retTune.first[1]};
|
||||
}
|
||||
}
|
||||
cl_int ret = CL_SUCCESS;
|
||||
int min_index = min_cost.second;
|
||||
//printf("gemm min_index = %d %d\n", min_index, min_cost.first);
|
||||
mMatMul[b] = runTime->buildKernel("gemm", kernelName[min_index], basic);
|
||||
|
||||
ret |= mMatMul[b].setArg(0, openCLImage(mSource.get()));
|
||||
ret |= mMatMul[b].setArg(1, *mWeight);
|
||||
ret |= mMatMul[b].setArg(2, openCLImage(mDest.get()));
|
||||
ret |= mMatMul[b].setArg(3, wUnit);
|
||||
ret |= mMatMul[b].setArg(4, hUnit);
|
||||
ret |= mMatMul[b].setArg(5, ocC4);
|
||||
ret |= mMatMul[b].setArg(6, icC4);
|
||||
ret |= mMatMul[b].setArg(7, alpha*alpha);
|
||||
MNN_CHECK_CL_SUCCESS(ret, "setArg ConvWinogradExecution gemm");
|
||||
mGWS_M[b] = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]};
|
||||
recordKernel2d(mMatMul[b], mGWS_M[b], mLWS_M[b], mOpenCLBackend->getOpenCLRuntime());
|
||||
}
|
||||
|
||||
|
@ -319,7 +356,8 @@ ErrorCode ConvWinograd::onExecute(const std::vector<Tensor*>& inputs, const std:
|
|||
int costTime = 0;
|
||||
#else
|
||||
if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){
|
||||
mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
|
||||
if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord())
|
||||
mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
|
||||
return NO_ERROR;
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -182,7 +182,8 @@ ErrorCode DeconvExecution::onExecute(const std::vector<Tensor *> &inputs, const
|
|||
MNN_PRINT("kernel cost:%d us Deconv\n",costTime);
|
||||
#else
|
||||
if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){
|
||||
mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
|
||||
if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord())
|
||||
mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
|
||||
#ifdef LOG_VERBOSE
|
||||
MNN_PRINT("End DeconvExecution onExecute... \n");
|
||||
#endif
|
||||
|
|
|
@ -169,7 +169,8 @@ ErrorCode DepthwiseConvExecution::onExecute(const std::vector<Tensor *> &inputs,
|
|||
MNN_PRINT("kernel cost:%d us DepthwiseConv\n",costTime);
|
||||
#else
|
||||
if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){
|
||||
mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
|
||||
if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord())
|
||||
mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
|
||||
#ifdef LOG_VERBOSE
|
||||
MNN_PRINT("End DepthwiseConvExecution onExecute... \n");
|
||||
#endif
|
||||
|
|
|
@ -172,7 +172,8 @@ ErrorCode DepthwiseDeconvExecution::onExecute(const std::vector<Tensor *> &input
|
|||
MNN_PRINT("kernel cost:%d us DepthwiseDeconv\n",costTime);
|
||||
#else
|
||||
if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){
|
||||
mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
|
||||
if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord())
|
||||
mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
|
||||
#ifdef LOG_VERBOSE
|
||||
MNN_PRINT("End DepthwiseDeconvExecution onExecute... \n");
|
||||
#endif
|
||||
|
|
|
@ -87,7 +87,8 @@ ErrorCode FuseExecution::onExecute(const std::vector<Tensor *> &inputs, const st
|
|||
MNN_PRINT("kernel cost:%d us Fuse\n",costTime);
|
||||
#else
|
||||
if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){
|
||||
mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
|
||||
if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord())
|
||||
mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
|
||||
#ifdef LOG_VERBOSE
|
||||
MNN_PRINT("end SoftmaxExecution onExecute !\n");
|
||||
#endif
|
||||
|
|
|
@ -96,7 +96,8 @@ ErrorCode GridSampleExecution::onExecute(const std::vector<Tensor *> &inputs, co
|
|||
MNN_PRINT("kernel cost:%d us GridSample\n", costTime);
|
||||
#else
|
||||
if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){
|
||||
mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
|
||||
if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord())
|
||||
mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
|
||||
return NO_ERROR;
|
||||
}
|
||||
run3DKernelDefault(mKernel, mGlobalWorkSize, mLocalWorkSize, mOpenCLBackend->getOpenCLRuntime());
|
||||
|
|
|
@ -107,7 +107,8 @@ ErrorCode Interp3DExecution::onExecute(const std::vector<Tensor *> &inputs, cons
|
|||
MNN_PRINT("kernel cost:%d us Interp3D\n",costTime);
|
||||
#else
|
||||
if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){
|
||||
mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
|
||||
if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord())
|
||||
mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
|
||||
#ifdef LOG_VERBOSE
|
||||
MNN_PRINT("End Interp3DExecution onExecute... \n");
|
||||
#endif
|
||||
|
|
|
@@ -99,7 +99,8 @@ ErrorCode InterpExecution::onExecute(const std::vector<Tensor *> &inputs, const
    MNN_PRINT("kernel cost:%d us Interp\n",costTime);
#else
    if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){
        mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
        if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord())
            mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
#ifdef LOG_VERBOSE
    MNN_PRINT("End InterpExecution onExecute... \n");
#endif
@@ -180,7 +180,8 @@ ErrorCode LayerNormExecution::onExecute(const std::vector<Tensor *> &inputs, con
    MNN_PRINT("kernel cost:%d us LayerNorm\n",costTime);
#else
    if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){
        mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
        if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord())
            mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
#ifdef LOG_VERBOSE
    MNN_PRINT("End LayerNormExecution onExecute... \n");
#endif
@@ -120,7 +120,8 @@ ErrorCode MatMulExecution::onExecute(const std::vector<Tensor *> &inputs, const
    MNN_PRINT("kernel cost:%d us Matmul\n",costTime);
#else
    if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){
        mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
        if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord())
            mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
#ifdef LOG_VERBOSE
    MNN_PRINT("End MatMulExecution onExecute... \n");
#endif
@@ -155,7 +155,8 @@ ErrorCode PoolExecution::onExecute(const std::vector<Tensor *> &inputs, const st
    MNN_PRINT("kernel cost:%d us Pooling\n",costTime);
#else
    if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){
        mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
        if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord())
            mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
#ifdef LOG_VERBOSE
    MNN_PRINT("End PoolExecution onExecute... \n");
#endif
@@ -152,7 +152,8 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con
        endRecord(mOpenCLBackend->getOpenCLRuntime(), mRecording);
        return NO_ERROR;
    }

    bool cancombine = CanCombine(outputs);
    // Alloc Temp buffer
    auto bufferPool     = ((OpenCLBackend *)backend())->getBufferPool();
    auto bufferUnitSize = runtime->isSupportedFP16() ? sizeof(half_float::half) : sizeof(float);
@@ -176,6 +177,9 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con
    bufferPool->recycle(mTempOutput);

    auto originNum = mTempInput.size();
    if(cancombine){
        regionNum = 1;
    }
    mUnits.resize(regionNum + originNum + 1);

    int kernel_idx = 0;
@@ -259,20 +263,25 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con
    }

    // buffer raster
    for (auto& slice : des->regions)
    {
    if(cancombine){
        auto regions = des->regions;
        auto slice = regions[0];
        int nums = regions.size();
        int src_offset = regions[1].src.offset - slice.src.offset;
        int dst_offset = regions[1].dst.offset - slice.dst.offset;

        Unit &unit = mUnits[kernel_idx++];
        unit.kernel = runtime->buildKernel("raster", "raster_buffer", {});

        unit.globalWorkSize = {(uint32_t)slice.size[2],
                               (uint32_t)slice.size[1],
                               (uint32_t)slice.size[0]};

        const std::vector<uint32_t> gws = {(uint32_t)slice.size[2],
                                           (uint32_t)slice.size[1],
                                           (uint32_t)slice.size[0]};
        unit.kernel = runtime->buildKernel("raster", "raster_buffer_combine", {});

        unit.globalWorkSize = {(uint32_t)slice.size[2] * nums,
                               (uint32_t)slice.size[1],
                               (uint32_t)slice.size[0]};

        const std::vector<uint32_t> gws = {(uint32_t)slice.size[2] * nums,
                                           (uint32_t)slice.size[1],
                                           (uint32_t)slice.size[0]};
        uint32_t mMaxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(unit.kernel));

        uint32_t idx = 0;
        cl_int ret = CL_SUCCESS;
        ret |= unit.kernel.setArg(idx++, gws[0]);
@@ -280,14 +289,17 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con
        ret |= unit.kernel.setArg(idx++, gws[2]);
        ret |= unit.kernel.setArg(idx++, *(mTempInput[slice.origin]));
        ret |= unit.kernel.setArg(idx++, slice.src.offset);
        ret |= unit.kernel.setArg(idx++, src_offset);
        ret |= unit.kernel.setArg(idx++, slice.src.stride[0]);
        ret |= unit.kernel.setArg(idx++, slice.src.stride[1]);
        ret |= unit.kernel.setArg(idx++, slice.src.stride[2]);
        ret |= unit.kernel.setArg(idx++, *mTempOutput);
        ret |= unit.kernel.setArg(idx++, slice.dst.offset);
        ret |= unit.kernel.setArg(idx++, dst_offset);
        ret |= unit.kernel.setArg(idx++, slice.dst.stride[0]);
        ret |= unit.kernel.setArg(idx++, slice.dst.stride[1]);
        ret |= unit.kernel.setArg(idx++, slice.dst.stride[2]);
        ret |= unit.kernel.setArg(idx++, slice.size[2]);
        if(ret != CL_SUCCESS)
        {
            MNN_PRINT("setArg err %d\n", (int)ret);
@@ -299,9 +311,54 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con
        unit.localWorkSize = {lws[0], lws[1], lws[2]};

        unit.globalWorkSize = {ROUND_UP(gws[0], std::max((uint32_t)1, lws[0])),
                               ROUND_UP(gws[1], std::max((uint32_t)1, lws[1])),
                               ROUND_UP(gws[2], std::max((uint32_t)1, lws[2]))};
                               ROUND_UP(gws[1], std::max((uint32_t)1, lws[1])),
                               ROUND_UP(gws[2], std::max((uint32_t)1, lws[2]))};
        recordKernel3d(unit.kernel, gws, lws, runtime);
    }else{
        for (auto& slice : des->regions)
        {
            Unit &unit = mUnits[kernel_idx++];
            unit.kernel = runtime->buildKernel("raster", "raster_buffer", {});

            unit.globalWorkSize = {(uint32_t)slice.size[2],
                                   (uint32_t)slice.size[1],
                                   (uint32_t)slice.size[0]};

            const std::vector<uint32_t> gws = {(uint32_t)slice.size[2],
                                               (uint32_t)slice.size[1],
                                               (uint32_t)slice.size[0]};
            uint32_t mMaxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(unit.kernel));

            uint32_t idx = 0;
            cl_int ret = CL_SUCCESS;
            ret |= unit.kernel.setArg(idx++, gws[0]);
            ret |= unit.kernel.setArg(idx++, gws[1]);
            ret |= unit.kernel.setArg(idx++, gws[2]);
            ret |= unit.kernel.setArg(idx++, *(mTempInput[slice.origin]));
            ret |= unit.kernel.setArg(idx++, slice.src.offset);
            ret |= unit.kernel.setArg(idx++, slice.src.stride[0]);
            ret |= unit.kernel.setArg(idx++, slice.src.stride[1]);
            ret |= unit.kernel.setArg(idx++, slice.src.stride[2]);
            ret |= unit.kernel.setArg(idx++, *mTempOutput);
            ret |= unit.kernel.setArg(idx++, slice.dst.offset);
            ret |= unit.kernel.setArg(idx++, slice.dst.stride[0]);
            ret |= unit.kernel.setArg(idx++, slice.dst.stride[1]);
            ret |= unit.kernel.setArg(idx++, slice.dst.stride[2]);
            if(ret != CL_SUCCESS)
            {
                MNN_PRINT("setArg err %d\n", (int)ret);
            }

            std::string name = "rasterBuffer";
            const std::vector<uint32_t> lws = localWS3DDefault(gws, mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), name, unit.kernel).first;

            unit.localWorkSize = {lws[0], lws[1], lws[2]};

            unit.globalWorkSize = {ROUND_UP(gws[0], std::max((uint32_t)1, lws[0])),
                                   ROUND_UP(gws[1], std::max((uint32_t)1, lws[1])),
                                   ROUND_UP(gws[2], std::max((uint32_t)1, lws[2]))};
            recordKernel3d(unit.kernel, gws, lws, runtime);
        }
    }

    //buffer to image
@@ -364,6 +421,44 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con
    return NO_ERROR;
}

bool RasterExecution::CanCombine(const std::vector<Tensor *> &outputs){
    auto des = TensorUtils::getDescribe(outputs[0]);
    auto regions = des->regions;
    if(regions.size() < 2)
        return false;
    auto origin = regions[0].origin;
    const int size0 = regions[0].size[0];
    const int size1 = regions[0].size[1];
    const int size2 = regions[0].size[2];
    const int src_offset = regions[1].src.offset - regions[0].src.offset;
    const int dst_offset = regions[1].dst.offset - regions[0].dst.offset;
    const int src_sride0 = regions[0].src.stride[0];
    const int src_sride1 = regions[0].src.stride[1];
    const int src_sride2 = regions[0].src.stride[2];
    const int dst_sride0 = regions[0].dst.stride[0];
    const int dst_sride1 = regions[0].dst.stride[1];
    const int dst_sride2 = regions[0].dst.stride[2];
    bool res = true;
    for(int i = 1; i < regions.size(); ++i){
        res &= regions[i].origin == origin;
        res &= regions[i].size[0] == size0;
        res &= regions[i].size[1] == size1;
        res &= regions[i].size[2] == size2;
        res &= regions[i].src.stride[0] == src_sride0;
        res &= regions[i].src.stride[1] == src_sride1;
        res &= regions[i].src.stride[2] == src_sride2;
        res &= regions[i].dst.stride[0] == dst_sride0;
        res &= regions[i].dst.stride[1] == dst_sride1;
        res &= regions[i].dst.stride[2] == dst_sride2;
        res &= (regions[i].src.offset - regions[i - 1].src.offset) == src_offset;
        res &= (regions[i].dst.offset - regions[i - 1].dst.offset) == dst_offset;
        if(res == false){
            return res;
        }
    }
    return res;
}

OpenCLCreatorRegister<TypedCreator<RasterExecution>> __Raster_op(OpType_Raster, IMAGE);
} // namespace OpenCL
} // namespace MNN
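CanCombine accepts a region list only when every region shares the same origin, sizes, and strides, and consecutive src/dst offsets advance by a constant delta; the regions then form an arithmetic sequence that the combined kernel can index as base + i * delta in a single dispatch. A minimal sketch of two regions that would pass the test, using a simplified stand-in for MNN's Region:

// Sketch with a simplified stand-in for MNN's Tensor region descriptor.
#include <cstdio>

struct View   { int offset; int stride[3]; };
struct Region { const void* origin; View src, dst; int size[3]; };

int main() {
    // Two 1x2x4 copies from the same origin; src/dst offsets advance by a
    // fixed delta (+8), so region i's offsets are base + i * 8 and one
    // combined launch covers both.
    Region r0{nullptr, {0, {8, 4, 1}}, {0, {8, 4, 1}}, {1, 2, 4}};
    Region r1{nullptr, {8, {8, 4, 1}}, {8, {8, 4, 1}}, {1, 2, 4}};
    int src_delta = r1.src.offset - r0.src.offset; // 8, checked for every pair
    int dst_delta = r1.dst.offset - r0.dst.offset; // 8
    printf("combinable deltas: src=%d dst=%d\n", src_delta, dst_delta);
    return 0;
}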
@@ -26,6 +26,7 @@ public:
    virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;

private:
    bool CanCombine(const std::vector<Tensor *> &outputs);
    std::map<Tensor*, cl::Buffer *> mTempInput;
    cl::Buffer *mTempOutput;
    OpenCLBackend *mOpenCLBackend;
@@ -175,7 +175,8 @@ ErrorCode ReductionExecution::onExecute(const std::vector<Tensor *> &inputs, con
    MNN_PRINT("kernel cost:%d us Reduct1D\n",costTime);
#else
    if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){
        mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
        if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord())
            mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
#ifdef LOG_VERBOSE
    MNN_PRINT("End ReductionExecution onExecute... \n");
#endif
@@ -133,7 +133,8 @@ ErrorCode RoiPooling::onExecute(const std::vector<Tensor *> &inputs, const std::
    MNN_PRINT("kernel cost:%d us RoiPooling\n",costTime);
#else
    if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){
        mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
        if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord())
            mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
#ifdef LOG_VERBOSE
    MNN_PRINT("End RoiPooling onExecute... \n");
#endif
@@ -175,7 +175,8 @@ ErrorCode ScaleExecution::onExecute(const std::vector<Tensor *> &inputs, const s
    MNN_PRINT("kernel cost:%d us Softmax\n",costTime);
#else
    if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){
        mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
        if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord())
            mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
#ifdef LOG_VERBOSE
    MNN_PRINT("End ScaleExecution onExecute... \n");
#endif
@@ -117,7 +117,8 @@ ErrorCode SoftmaxExecution::onExecute(const std::vector<Tensor *> &inputs, const
    MNN_PRINT("kernel cost:%d us Softmax\n",costTime);
#else
    if(mOpenCLBackend->getOpenCLRuntime()->isUseRecordQueue()){
        mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
        if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord())
            mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
#ifdef LOG_VERBOSE
    MNN_PRINT("End SoftmaxExecution onExecute... \n");
#endif
@@ -79,7 +79,8 @@ ErrorCode UnaryExecution::onExecute(const std::vector<Tensor*>& inputs, const st
#else
    auto openCLBackend = static_cast<OpenCLBackend*>(backend());
    if(openCLBackend->getOpenCLRuntime()->isUseRecordQueue()){
        mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
        if(mOpenCLBackend->getOpenCLRuntime()->isDevideOpRecord())
            mOpenCLBackend->getOpenCLRuntime()->getRecordings()->emplace_back(mRecording);
#ifdef LOG_VERBOSE
    MNN_PRINT("End UnaryExecution onExecute... \n");
#endif
@@ -318,6 +318,7 @@ void VulkanBackend::onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTenso
        tempTensor.reset(Tensor::create(srcTensor->shape(), dstTensor->getType(), nullptr, _convert(TensorUtils::getDescribe(srcTensor)->dimensionFormat)), [dstTensor](void* t) {
            Tensor* temp = (Tensor*)t;
            MNNCPUCopyBuffer(temp, dstTensor);
            delete temp;
        });
        dstTensor = tempTensor.get();
    }
@@ -36,9 +36,9 @@ public:
            MNN_ERROR("Invalid origin conv op\n");
            return nullptr;
        }
        auto op = originConv->expr().first->get()->UnPack();
        std::unique_ptr<MNN::OpT> op(originConv->expr().first->get()->UnPack());
        op->main.AsConvolution2D()->symmetricQuan->winogradAttr = encode();
        return (Express::Variable::create(Express::Expr::create(op, originConv->expr().first->inputs())));
        return (Express::Variable::create(Express::Expr::create(op.get(), originConv->expr().first->inputs())));
    }
    std::vector<Attr> attrs;
private:
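The hunk above fixes a leak: a flatbuffers-style UnPack() returns a caller-owned heap object, and the old code never freed it. Wrapping it in std::unique_ptr and handing Expr::create the raw op.get() keeps every return path leak-free. A minimal sketch of the ownership pattern, with hypothetical names:

// Ownership sketch: a factory returns a raw heap pointer; wrap it immediately
// so every path (including early returns) still frees it.
#include <memory>

struct OpT { int attr = 0; };
OpT* UnPack() { return new OpT; }       // caller-owned, as with flatbuffers' UnPack

void consume(const OpT* op) { /* reads op; does not take ownership */ }

void fixed() {
    std::unique_ptr<OpT> op(UnPack());  // freed automatically at scope exit
    op->attr = 42;
    consume(op.get());
}                                       // old code lost the pointer here: a leak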
@@ -567,7 +567,8 @@ void ConvolutionCommon::getConvParameters(std::shared_ptr<Int8Common> *quanCommo
    *originWeight = nullptr;
    *originWeightSize = 0;
    if (nullptr != conv2d->quanParameter()) {
        *quanCommon = load(conv2d->quanParameter(), false);
        bool forceFloat = conv2d->quanParameter()->index() != nullptr;
        *quanCommon = load(conv2d->quanParameter(), forceFloat);
        *originWeight = (*quanCommon)->weightFloat.get();
        *originWeightSize = (*quanCommon)->weightFloat.size();
    }
@@ -581,7 +581,9 @@ static ErrorCode _createExecutions(Schedule::PipelineInfo& mInfo) {
            // Try Backup
            iter.execution.reset(mBackupBackend->onCreate(iter.inputs, iter.outputs, iter.op));
            if (nullptr == iter.execution) {
                MNN_ERROR("Create exection error : %d\n", iter.op->type());
                if (mInfo.first.reportError) {
                    MNN_ERROR("Create execution error : %d\n", iter.op->type());
                }
                return NOT_SUPPORT;
            }
        }
@@ -61,6 +61,7 @@ public:
    std::pair<std::shared_ptr<Backend>, std::shared_ptr<Backend>> cache;
    bool needComputeShape = true;
    bool needComputeGeometry = true;
    bool reportError = true;
    std::map<Tensor*, TENSORCACHE> inputTensorCopyCache;
};
typedef std::pair<BackendCache, std::vector<OpCacheInfo>> PipelineInfo;
@@ -852,7 +852,7 @@ public:
        return true;
    };
    auto y = _mobileNetV1Expr();
    bool res = func(y, 60.0f);
    bool res = func(y, 62.0f);
    if (!res) {
        return false;
    }
@@ -22,7 +22,7 @@ protected:
    template<typename Tin, typename Tout>
    bool test(VARP (*opFunc)(VARP, VARP), string name, float threshold,
              const vector<Tin>& data_x, const vector<Tin>& data_y, const vector<Tout>& data_out,
              const vector<int>& shape_x, const vector<int>& shape_y, const vector<int>& shape_out, const vector<float> quantScales={-100, -100, -100}, const vector<float> zeroPoints={-100, -100, -100}) {
              const vector<int>& shape_x, const vector<int>& shape_y, const vector<int>& shape_out, const vector<float> quantScales={-100.f, -100.f, -100.f}, const vector<float> zeroPoints={-100.f, -100.f, -100.f}, Dimensionformat format = NCHW) {
        int size_x = 1, size_y = 1, size_out = 1;
        for (int i = 0; i < shape_x.size(); ++i) {
            size_x *= shape_x[i];
@@ -34,8 +34,8 @@ protected:
            size_out *= shape_out[i];
        }

        auto input_x = _Input(shape_x, NCHW, halide_type_of<Tin>());
        auto input_y = _Input(shape_y, NCHW, halide_type_of<Tin>());
        auto input_x = _Input(shape_x, format, halide_type_of<Tin>());
        auto input_y = _Input(shape_y, format, halide_type_of<Tin>());
        input_x->setName("input_x");
        input_y->setName("input_y");
        if (quantScales[0] != -100) { // -100 means invalid scale.
@@ -593,6 +593,42 @@ public:
    }
};

class AddC4Test : public BinaryTestCommon {
public:
    virtual ~AddC4Test() = default;
    virtual bool run(int precision) {
        {
            vector<float> inp2 = {1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6}, inp1 = {2};
            vector<float> rightResult = {3.1, 4.2, 5.3, 6.6, 3.1, 4.2, 5.3, 6.6, 3.1, 4.2, 5.3, 6.6, 3.1, 4.2, 5.3, 6.6, 3.1, 4.2, 5.3, 6.6, 3.1, 4.2, 5.3, 6.6, 3.1, 4.2, 5.3, 6.6, 3.1, 4.2, 5.3, 6.6};
            bool res = test<float, float>(MNN::Express::_Add, "AddInt8C4Test", 0.01, inp1, inp2, rightResult, {1, 1, 1, 1}, {1, 32, 1, 1}, {1, 32, 1, 1}, {0.4, 0.4, 1.0},
                                          {1., 2., 3.}, NC4HW4);
            if (!res) {
                FUNC_PRINT(1);
                return false;
            }
        }
        std::vector<float> i1 = {
            -1.0, -2.0, 0.f, 0.f,
            -3.0, -4.0, 0.f, 0.f,
            -5.0, -6.0, 0.f, 0.f,
            -7.0, -8.0, 0.f, 0.f
        };
        std::vector<float> i0 = {
            1.0f, 0.0f, 0.f, 0.f
        };
        std::vector<float> i2 = {
            0.0, -1.0, 0.f, 0.f,
            -2.0, -3.0, 0.f, 0.f,
            -4.0, -5.0, 0.f, 0.f,
            -6.0, -7.0, 0.f, 0.f
        };
        return test<float, float>(MNN::Express::_BiasAdd, "AddC4FloatTest", 0.01,
                                  i0, i1, i2,
                                  {1, 1, 1, 1}, {4, 2, 1, 1}, {4, 2, 1, 1}, {-100.f, -100.f, -100.f}, {-100.f, -100.f, -100.f}, NC4HW4);
    }
};

// Float32 OpTest.
MNNTestSuiteRegister(BinaryBroadcastShapeTest, "op/binary/broadcastShapeTest");
MNNTestSuiteRegister(AddTest, "op/binary/add");
@@ -633,3 +669,5 @@ MNNTestSuiteRegister(FloorDivInt8Test, "op/binary/floordivInt8");
MNNTestSuiteRegister(FloorModInt8Test, "op/binary/floormodInt8");
MNNTestSuiteRegister(Atan2Int8Test, "op/binary/atan2Int8");
MNNTestSuiteRegister(SquaredDifferenceInt8Test, "op/binary/sqdInt8");

MNNTestSuiteRegister(AddC4Test, "op/binary/addC4");
@@ -0,0 +1,225 @@
//
//  TopKV2Execution.hpp
//  MNN
//
//  Created by MNN on 2023/07/19.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include <MNN/expr/Expr.hpp>
#include <MNN/expr/ExprCreator.hpp>
#include "MNNTestSuite.h"
#include "TestUtils.h"
#include <random>
#include <vector>

using namespace MNN::Express;

template<typename valueT, typename indexT>
void MinHeapify(valueT * arr, indexT * index, int size, int i) {
    int l = 2 * i + 1;
    int r = 2 * i + 2;
    int smallest = i;
    if (l < size && arr[l] < arr[smallest]) {
        smallest = l;
    }
    if (r < size && arr[r] < arr[smallest]) {
        smallest = r;
    }
    if (smallest != i) {
        std::swap(arr[i], arr[smallest]);
        std::swap(index[i], index[smallest]);
        MinHeapify<valueT, indexT>(arr, index, size, smallest);
    }

    return;
}

template<typename valueT, typename indexT>
void BuildMinHeap(valueT * arr, indexT * index, int size) {
    for (int i = size / 2 - 1; i >= 0; i--) {
        MinHeapify<valueT, indexT>(arr, index, size, i);
    }
}

template<typename valueT, typename indexT>
void Sort(valueT * values, indexT * indices, const int num) {
    valueT * _values = static_cast<valueT *>(values);
    indexT * _indices = static_cast<indexT *>(indices);
    for (int i = 0; i < num - 1; i++) {
        for (int j = 0; j < num - i - 1; j++) {
            if (_values[j] < _values[j + 1]) {
                std::swap(_values[j], _values[j + 1]);
                std::swap(_indices[j], _indices[j + 1]);
            }
        }
    }

    return;
}

template<typename valueT, typename indexT>
void CpuKernelOneRow(const valueT * input, indexT * outputIndices, valueT * outputValues, const int K, const int length) {
    for (int i = 0; i < K; i++) {
        outputIndices[i] = i;
        outputValues[i] = input[i];
    }
    BuildMinHeap<valueT, indexT>(outputValues, outputIndices, K);
    for (int i = K; i < length; i++) {
        if (input[i] > outputValues[0]) {
            outputValues[0] = input[i];
            outputIndices[0] = i;
            MinHeapify<valueT, indexT>(outputValues, outputIndices, K, 0);
        }
    }
    Sort<valueT, indexT>(outputValues, outputIndices, K);

    return;
}

template<typename indexT, typename valueT>
void CpuKernelAllRows(valueT * input, indexT * outputIndices, valueT * outputValues, const int K, const int lengthRow, const int numRow, int descendFlag) {
    for (int i = 0; i < lengthRow * numRow; i++) {
        input[i] = input[i] * descendFlag;
    }

    for (int i = 0; i < numRow; i++) {
        const valueT * inputThisRow = input + lengthRow * i;
        indexT * outputIndicesThisRow = outputIndices + K * i;
        valueT * outputValuesThisRow = outputValues + K * i;
        CpuKernelOneRow(inputThisRow, outputIndicesThisRow, outputValuesThisRow, K, lengthRow);
    }

    for (int i = 0; i < lengthRow * numRow; i++) {
        input[i] = input[i] * descendFlag;
    }

    for (int i = 0; i < numRow * K; i++) {
        outputValues[i] = outputValues[i] * descendFlag;
    }

    return;
}

void RandomInitFloat(float * array, const int & numEle) {
    std::mt19937 rng(4);
    std::uniform_real_distribution<float> dist(0.0, 1.0);

    for (int i = 0; i < numEle; i++) {
        array[i] = dist(rng);
    }
    return;
}

void SetK(int * valuePtr, const int K) {
    *valuePtr = K;
}

bool checkIndicesHalf(const float * input, const float * expectedOutput0, const int * gotOutput1, const int K, const int numRow, const int lengthRow) {
    for (int i = 0; i < numRow; i++) {
        for (int j = 0; j < K; j++) {
            bool condition = (convertFP32ToFP16(expectedOutput0[i * K + j]) != convertFP32ToFP16(input[gotOutput1[i * K + j] + i * lengthRow]));
            if (condition) {
                MNN_PRINT("Conflict: Number %d. Value Correct is %f. Value Computed is %f.\n", i * K + j, convertFP32ToFP16(expectedOutput0[i * K + j]), convertFP32ToFP16(input[gotOutput1[i * K + j] + i * lengthRow]));
                return false;
            }
        }
    }

    return true;
}

bool checkIndicesFloat(const float * input, const float * expectedOutput0, const int * gotOutput1, const int K, const int numRow, const int lengthRow) {
    for (int i = 0; i < numRow; i++) {
        for (int j = 0; j < K; j++) {
            bool condition = (expectedOutput0[i * K + j] != input[gotOutput1[i * K + j] + i * lengthRow]);
            if (condition) {
                MNN_PRINT("Conflict: Number %d. Value Correct is %f. Value Computed is %f.\n", i * K + j, expectedOutput0[i * K + j], input[gotOutput1[i * K + j] + i * lengthRow]);
                return false;
            }
        }
    }

    return true;
}

void printTimeCost(uint64_t timeCost) {
    uint64_t seconds = timeCost / 1000000;
    uint64_t microseconds = timeCost % 1000000;
    MNN_PRINT("%lu s %lu ms\n", seconds, microseconds / 1000);

    return;
}

class TopKV2Test : public MNNTestCase {
public:
    virtual ~TopKV2Test() = default;

    virtual bool run(int precision) {
        // set params
        const int K = 10;
        const int numRow = 180;

        const int lengthRow = 21491;

        // set input
        VARP input0 = _Input({numRow, lengthRow}, NCHW, halide_type_of<float>());
        VARP input1 = _Input({1}, NCHW, halide_type_of<int>());
        RandomInitFloat(input0->writeMap<float>(), numRow * lengthRow);
        SetK(input1->writeMap<int>(), K);

        auto timeStart = getTimeInUs();
        // calculate gotOutput
        auto res = _TopKV2(input0, input1);
        VARP output0 = res[0];
        VARP output1 = res[1];
        auto gotOutput0 = output0->readMap<float>();
        auto gotOutput1 = output1->readMap<int>();
        auto timeEnd = getTimeInUs();
        auto timeCost = timeEnd - timeStart;

        // calculate expectedOutput
        std::vector<float> expectedOutput0(numRow * K);
        std::vector<int> expectedOutput1(numRow * K);
        CpuKernelAllRows<int, float>(input0->writeMap<float>(), expectedOutput1.data(), expectedOutput0.data(), K, lengthRow, numRow, 1);

        printTimeCost(timeCost);

        // check values
        float errorScale = precision <= MNN::BackendConfig::Precision_High ? 1 : 20;
        if (!checkVectorByRelativeError<float>(gotOutput0, expectedOutput0.data(), numRow * K, 0.001 * errorScale)) {
            MNN_ERROR("TopKV2 test failed!\n");
            return false;
        }

        // check indices
        if (precision <= 1) {
            if (!checkIndicesFloat(input0->readMap<float>(), expectedOutput0.data(), gotOutput1, K, numRow, lengthRow)) {
                MNN_ERROR("TopKV2 test failed!\n");
                return false;
            }
        } else if (precision == 2) {
            if (!checkIndicesHalf(input0->readMap<float>(), expectedOutput0.data(), gotOutput1, K, numRow, lengthRow)) {
                MNN_ERROR("TopKV2 test failed!\n");
                return false;
            }
        }

        return true;
    }
};

MNNTestSuiteRegister(TopKV2Test, "op/TopKV2");
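The reference kernels above keep a K-element min-heap whose root is the smallest retained value, so each row costs O(n log K) rather than a full O(n log n) sort. A quick standalone check of CpuKernelOneRow, assuming the templates above are in scope:

// Minimal check of the reference kernel above on one small row (K = 3).
#include <cstdio>

int main() {
    const float row[8] = {0.3f, 0.9f, 0.1f, 0.7f, 0.5f, 0.8f, 0.2f, 0.4f};
    int   idx[3];
    float val[3];
    CpuKernelOneRow<float, int>(row, idx, val, 3, 8);
    for (int i = 0; i < 3; ++i) {
        printf("top[%d] = %f at %d\n", i, val[i], idx[i]); // 0.9@1, 0.8@5, 0.7@3
    }
    return 0;
}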
@@ -158,7 +158,7 @@ int writeFb(std::unique_ptr<MNN::NetT>& netT, const std::string& MNNModelFile, c
        output.write((const char*)bufferOutput, sizeOutput);
    }
    if (!netT->subgraphs.empty()) {
        MNN_PRINT("The model has subgraphs, please use MNN::Module to run it\n");
        MNN_PRINT("The model has subgraphs, please use MNN::Express::Module to run it\n");
    }

#ifdef MNN_DUMP_SUBGRAPH
@@ -21,58 +21,6 @@
#include "onnxConverter.hpp"
#include "onnxOpConverter.hpp"

std::vector<int> topoSort(const ::onnx::GraphProto& onnxGraph) {
    std::vector<int> idxMap;
    const int nodeCount = onnxGraph.node_size();
    std::map<std::string, int> outputMap;
    std::map<int, std::vector<int>> graph; // key --[in]--> values
    std::vector<int> inDegree(nodeCount);
    // build Graph and inDegree
    for (int i = 0; i < nodeCount; ++i) {
        const auto& onnxNode = onnxGraph.node(i);
        if (onnxNode.op_type() == "Loop" || onnxNode.op_type() == "If") {
            return idxMap;
        }
        for (int k = 0; k < onnxNode.output_size(); k++) {
            outputMap.insert(std::make_pair(onnxNode.output(k), i));
        }
    }
    for (int i = 0; i < nodeCount; ++i) {
        const auto& onnxNode = onnxGraph.node(i);
        for (int k = 0; k < onnxNode.input_size(); k++) {
            auto inputName = onnxNode.input(k);
            auto iter = outputMap.find(inputName);
            if (iter != outputMap.end()) {
                graph[iter->second].push_back(i);
            }
        }
    }
    for (auto node : graph) {
        for (auto output : node.second) {
            inDegree[output]++;
        }
    }
    // topo sort
    std::queue<int> validNode;
    for (int i = 0; i < nodeCount; i++) {
        if (!inDegree[i]) {
            validNode.push(i);
        }
    }
    while (!validNode.empty()) {
        int node = validNode.front();
        validNode.pop();
        idxMap.push_back(node);
        for (auto succ : graph[node]) {
            if (--inDegree[succ] == 0) {
                validNode.push(succ);
            }
        }
    }
    MNN_ASSERT(idxMap.size() == nodeCount);
    return idxMap;
}

int onnx2MNNNet(const std::string inputModel, const std::string bizCode,
                std::unique_ptr<MNN::NetT>& netT) {
    std::string modelDir;
@@ -117,6 +65,7 @@ int onnx2MNNNet(const std::string inputModel, const std::string bizCode,
            MNNOp->main.type = MNN::OpParameter_Input;
            auto inputParam  = new MNN::InputT;
            const auto it    = inputs.find(iter.first);
            //FUNC_PRINT_ALL(iter.first.c_str(), s);
            DCHECK(it != inputs.end()) << "Input Paramter ERROR ==> " << iter.first;
            const auto& tensorInfo = (it->second)->type().tensor_type();
            const int inputDimSize = tensorInfo.shape().dim_size();
@@ -138,8 +87,24 @@ int onnx2MNNNet(const std::string inputModel, const std::string bizCode,
    }

    // not every onnx model is topologically sorted, so sort it first
    std::vector<int> idxMap = topoSort(onnxGraph);
    std::vector<int> idxMap = OnnxScope::topoSort(onnxGraph);

    auto makeConst = [&](const std::string& inputName) {
        const auto it = initializers.find(inputName);
        if (it != initializers.end() && scope->lookupTensor(it->first) == -1) {
            // Create const Op
            MNN::OpT* constOp   = new MNN::OpT;
            constOp->type       = MNN::OpType_Const;
            constOp->main.type  = MNN::OpParameter_Blob;
            constOp->main.value = onnxOpConverter::convertTensorToBlob(it->second, modelDir);
            constOp->name       = it->first;
            constOp->outputIndexes.push_back(scope->declareTensor(it->first));
            netT->oplists.emplace_back(constOp);
        }
    };
    for (int i = 0; i < onnxGraph.output_size(); ++i) {
        makeConst(onnxGraph.output(i).name());
    }
    // onnx node ==> MNN node
    for (int idx = 0; idx < nodeCount; ++idx) {
        int i = idxMap.size() == nodeCount ? idxMap[idx] : idx;

@@ -158,17 +123,7 @@ int onnx2MNNNet(const std::string inputModel, const std::string bizCode,
        // convert initializer to be Constant node(op)
        for (int k = 0; k < onnxNode.input_size(); ++k) {
            const auto& inputName = onnxNode.input(k);
            const auto it = initializers.find(inputName);
            if (it != initializers.end() && scope->lookupTensor(it->first) == -1) {
                // Create const Op
                MNN::OpT* constOp   = new MNN::OpT;
                constOp->type       = MNN::OpType_Const;
                constOp->main.type  = MNN::OpParameter_Blob;
                constOp->main.value = onnxOpConverter::convertTensorToBlob(it->second, modelDir);
                constOp->name       = it->first;
                constOp->outputIndexes.push_back(scope->declareTensor(it->first));
                netT->oplists.emplace_back(constOp);
            }
            makeConst(inputName);
        }

        // build input and output
@@ -6,6 +6,7 @@
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include <queue>
#include "onnxOpConverter.hpp"
#include "OpCount.hpp"
#include "OnnxTmpGraph.hpp"
@@ -20,6 +21,67 @@ static int32_t _limit(int64_t i64) {
    }
    return i64;
}
std::vector<int> OnnxScope::topoSort(const onnx::GraphProto& onnxGraph) {
    std::vector<int> idxMap;
    const int nodeCount = onnxGraph.node_size();
    std::map<std::string, int> outputMap;
    std::map<int, std::vector<int>> graph; // key --[in]--> values
    std::vector<int> inDegree(nodeCount);
    // build Graph and inDegree
    for (int i = 0; i < nodeCount; ++i) {
        const auto& onnxNode = onnxGraph.node(i);
        for (int k = 0; k < onnxNode.output_size(); k++) {
            outputMap.insert(std::make_pair(onnxNode.output(k), i));
        }
    }
    for (int i = 0; i < nodeCount; ++i) {
        const auto& onnxNode = onnxGraph.node(i);
        for (int k = 0; k < onnxNode.input_size(); k++) {
            auto inputName = onnxNode.input(k);
            auto iter = outputMap.find(inputName);
            if (iter != outputMap.end()) {
                graph[iter->second].push_back(i);
            }
        }
        if (onnxNode.op_type() == "Loop") {
            auto& body = onnxNode.attribute(0).g();
            for (int j = 0; j < body.node_size(); ++j) {
                for (int k = 0; k < body.node(j).input_size(); ++k) {
                    auto inputName = body.node(j).input(k);
                    auto iter = outputMap.find(inputName);
                    if (iter != outputMap.end()) {
                        graph[iter->second].push_back(i);
                    }
                }
            }
        }
    }
    for (auto node : graph) {
        for (auto output : node.second) {
            inDegree[output]++;
        }
    }
    // topo sort
    std::queue<int> validNode;
    for (int i = 0; i < nodeCount; i++) {
        if (!inDegree[i]) {
            validNode.push(i);
        }
    }
    while (!validNode.empty()) {
        int node = validNode.front();
        validNode.pop();
        idxMap.push_back(node);
        for (auto succ : graph[node]) {
            if (--inDegree[succ] == 0) {
                validNode.push(succ);
            }
        }
    }
    MNN_ASSERT(idxMap.size() == nodeCount);
    return idxMap;
}

class DefaultonnxOpConverter : public onnxOpConverter {
public:
    virtual void run(MNN::OpT* dstOp, const onnx::NodeProto* onnxNode,
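The moved topoSort is Kahn's algorithm over producer-to-consumer edges; the new Loop handling additionally adds an edge from any producer whose output the loop body captures, so a Loop node is never scheduled before the tensors its subgraph reads. A toy run of the same scheme on a four-node graph:

// Kahn's algorithm on a toy graph, mirroring the topoSort above:
// edges 0 -> 2, 1 -> 2, 2 -> 3. Prints a valid order, e.g. "0 1 2 3".
#include <cstdio>
#include <map>
#include <queue>
#include <vector>

int main() {
    const int n = 4;
    std::map<int, std::vector<int>> graph{{0, {2}}, {1, {2}}, {2, {3}}};
    std::vector<int> inDegree(n);
    for (auto& node : graph)
        for (int succ : node.second) inDegree[succ]++;
    std::queue<int> ready;
    for (int i = 0; i < n; ++i)
        if (!inDegree[i]) ready.push(i);
    while (!ready.empty()) {
        int node = ready.front(); ready.pop();
        printf("%d ", node);
        for (int succ : graph[node])
            if (--inDegree[succ] == 0) ready.push(succ);
    }
    printf("\n");
    return 0;
}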
@@ -107,6 +169,7 @@ onnxOpConverter* onnxOpConverterSuit::search(const std::string& name) {
MNN::DataType onnxOpConverter::convertDataType(int32_t itype) {
    static std::map<::onnx::TensorProto_DataType, MNN::DataType> dataTypeMap{
        {onnx::TensorProto_DataType_FLOAT, MNN::DataType_DT_FLOAT},
        {onnx::TensorProto_DataType_FLOAT16, MNN::DataType_DT_HALF},
        {onnx::TensorProto_DataType_INT8, MNN::DataType_DT_INT8},
        {onnx::TensorProto_DataType_INT32, MNN::DataType_DT_INT32},
        {onnx::TensorProto_DataType_INT64, MNN::DataType_DT_INT32}, // For compatibility, use int32 instead of int64
@@ -188,6 +251,8 @@ MNN::BlobT* onnxOpConverter::convertTensorToBlob(const onnx::TensorProto* consta
        CASE_DATA_TYPE(onnx::TensorProto_DataType_DOUBLE, double);
        CASE_DATA_TYPE(onnx::TensorProto_DataType_INT64, int64);
        CASE_DATA_TYPE(onnx::TensorProto_DataType_INT32, int32);
        CASE_DATA_TYPE(onnx::TensorProto_DataType_UINT8, int32);
        CASE_DATA_TYPE(onnx::TensorProto_DataType_INT8, int32);
        CASE_DATA_TYPE(onnx::TensorProto_DataType_FLOAT, float);
        CASE_DATA_TYPE(onnx::TensorProto_DataType_UINT64, uint64);
        CASE_DATA_TYPE(onnx::TensorProto_DataType_BOOL, int32);
@@ -265,13 +330,22 @@ MNN::BlobT* onnxOpConverter::convertTensorToBlob(const onnx::TensorProto* consta
                break;
            }
            case onnx::TensorProto_DataType_UINT8: {
                auto source = (uint8_t*)tensor_content;
                constantParam->uint8s.resize(dataSize);
                for (int i = 0; i < dataSize; ++i) {
                    constantParam->uint8s[i] = source[i];
                if (constantTp->int32_data_size() > 0) {
                    auto source = (int32_t*)tensor_content;
                    for (int i = 0; i < dataSize; ++i) {
                        constantParam->uint8s[i] = source[i];
                    }
                } else {
                    ::memcpy(constantParam->uint8s.data(), tensor_content, dataSize * sizeof(uint8_t));
                }
                break;
            }
            case onnx::TensorProto_DataType_FLOAT16: {
                constantParam->uint8s.resize(dataSize * sizeof(int16_t));
                ::memcpy(constantParam->uint8s.data(), tensor_content, dataSize * sizeof(int16_t));
                break;
            }
            case onnx::TensorProto_DataType_FLOAT: {
                float* tempFloatData = (float*)tensor_content;
                constantParam->float32s.resize(dataSize);
@@ -411,7 +485,33 @@ std::vector<std::string> OnnxScope::buildSubGraph(const onnx::GraphProto* graph,
    }
    // Find Extra Input from outside graph
    std::map<std::string, int> outsideInputs;
    for (int i = 0; i < graph->node_size(); i++) {
    auto findConst = [&](const std::string& name) {
        if (scope->lookupTensor(name) >= 0) {
            return;
        }
        // an onnx subgraph may use tensors from initializers in the outer level graph; find them recursively
        for (auto curScope = scope.get(); curScope != nullptr; ) {
            const auto& curInits = curScope->mInitializers;
            const auto it = curInits.find(name);
            if (it != curInits.end()) {
                // Create const Op
                MNN::OpT* constOp   = new MNN::OpT;
                constOp->type       = MNN::OpType_Const;
                constOp->main.type  = MNN::OpParameter_Blob;
                constOp->main.value = onnxOpConverter::convertTensorToBlob(it->second);
                constOp->name       = it->first;
                constOp->outputIndexes.push_back(scope->declareTensor(it->first));
                subgraph->nodes.emplace_back(constOp);
                break;
            }
            curScope = reinterpret_cast<decltype(curScope)>(curScope->mParent);
        }
    };
    for (int i = 0; i < graph->output_size(); ++i) {
        findConst(graph->output(i).name());
    }
    auto indexes = OnnxScope::topoSort(*graph);
    for (auto i : indexes) {
        const auto& onnxNode = graph->node(i);
        const auto& opType = onnxNode.op_type();
        // name maybe null, use the first output name as node-name
@@ -423,26 +523,7 @@ std::vector<std::string> OnnxScope::buildSubGraph(const onnx::GraphProto* graph,
        MNNOp->main.type = opConverter->type();
        for (int k = 0; k < onnxNode.input_size(); ++k) {
            const auto& inputName = onnxNode.input(k);
            if (scope->lookupTensor(inputName) >= 0) {
                continue;
            }
            // an onnx subgraph may use tensors from initializers in the outer level graph; find them recursively
            for (auto curScope = scope.get(); curScope != nullptr; ) {
                const auto& curInits = curScope->mInitializers;
                const auto it = curInits.find(inputName);
                if (it != curInits.end()) {
                    // Create const Op
                    MNN::OpT* constOp   = new MNN::OpT;
                    constOp->type       = MNN::OpType_Const;
                    constOp->main.type  = MNN::OpParameter_Blob;
                    constOp->main.value = onnxOpConverter::convertTensorToBlob(it->second);
                    constOp->name       = it->first;
                    constOp->outputIndexes.push_back(scope->declareTensor(it->first));
                    subgraph->nodes.emplace_back(constOp);
                    break;
                }
                curScope = reinterpret_cast<decltype(curScope)>(curScope->mParent);
            }
            findConst(inputName);
        }
        // build input and output
        for (int k = 0; k < onnxNode.input_size(); k++) {
@@ -453,6 +534,7 @@ std::vector<std::string> OnnxScope::buildSubGraph(const onnx::GraphProto* graph,
            if (iter == outsideInputs.end()) {
                idx = scope->declareTensor(inputName);
                std::unique_ptr<MNN::OpT> inputOp(new MNN::OpT);
                //FUNC_PRINT_ALL(inputName.c_str(), s);
                inputOp->name      = inputName;
                inputOp->type      = MNN::OpType_Input;
                inputOp->main.type = MNN::OpParameter_Input;
@@ -501,9 +583,10 @@ std::vector<std::string> OnnxScope::buildSubGraph(const onnx::GraphProto* graph,
    int N = graph->input_size() - 2, K = graph->output_size() - N - 1;
    for (int i = 0; i < N + 1; i++) {
        int idx = scope->lookupTensor(graph->output(i).name());
        MNN_ASSERT(idx >= 0);
        if (idx >= 0) {
            subgraph->outputs.push_back(idx);
        } else {
            FUNC_PRINT_ALL(graph->output(i).name().c_str(), s);
        }
    }
    std::vector<std::string> resOutside;
@@ -19,6 +19,7 @@

class OnnxScope : public ConverterScope {
public:
    static std::vector<int> topoSort(const onnx::GraphProto& onnxGraph);
    OnnxScope(const onnx::GraphProto* graph, MNN::NetT* net) : mGraph(graph), ConverterScope(net) { onnxInit(); }
    OnnxScope(const onnx::GraphProto* graph, MNN::SubGraphProtoT* subnet, MNN::NetT* net,
              OnnxScope* parent) : mGraph(graph), ConverterScope(subnet, net, parent) { onnxInit(); }
@@ -342,7 +342,12 @@ static auto gRegister = []() {
            continue;
        }
        auto inputVar = inputVarIter->second;
        auto newVar = _Gelu(inputVar);
        std::unique_ptr<MNN::OpT> newUnary(new MNN::OpT);
        newUnary->type = OpType_UnaryOp;
        newUnary->main.type = OpParameter_UnaryOp;
        newUnary->main.value = new UnaryOpT;
        newUnary->main.AsUnaryOp()->opType = UnaryOpOperation_GELU_STANDARD;
        auto newVar = MNN::Express::Variable::create(MNN::Express::Expr::create(newUnary.get(), {inputVar}));
        newVar->setName(expr->outputName(0));
        Expr::replace(expr, newVar->expr().first);
        return true;
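The fusion now emits a UnaryOp with UnaryOpOperation_GELU_STANDARD directly instead of calling _Gelu, pinning the matched pattern to one GELU variant. Assuming GELU_STANDARD denotes the exact erf-based form (as the name suggests) and the plain GELU op the common tanh approximation, the two differ slightly:

\mathrm{GELU}(x) = x\,\Phi(x) = \frac{x}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right)
\approx \frac{x}{2}\left(1 + \tanh\!\left(\sqrt{\tfrac{2}{\pi}}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right)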
@@ -0,0 +1,52 @@
//
//  DynamicQuantizeLinear.cpp
//  MNNConverter
//
//  Created by MNN on 2023/08/14.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include <MNN/expr/ExprCreator.hpp>
#include "MNN_generated.h"
#include "OnnxExtraManager.hpp"

namespace MNN {
namespace Express {
// Ref from https://github.com/onnx/onnx/blob/main/docs/Operators.md#DynamicQuantizeLinear
class OnnxDynamicQuantizeLinearTransform : public OnnxExtraManager::Transform {
public:
    virtual EXPRP onExecute(EXPRP expr) const override {
        auto x = expr->inputs()[0];
        auto range = _Scalar<float>(1.0f / 255.0f);
        auto maxX = _ReduceMax(x);
        auto minX = _ReduceMin(x);
        auto scale = (maxX - minX) * range;
        auto scaleReq = _Reciprocal(scale);
        // Qmin = 0
        auto interZero = _Negative(minX * scaleReq);
        auto zeroFloat = _Round(_Relu6(interZero, 0.0f, 255.0f));
        auto zero = _Cast<uint8_t>(zeroFloat);
        auto y = _Cast<uint8_t>(_Round(_Relu6(_Round(x * scaleReq) + zeroFloat, 0.0f, 255.0f)));
        std::unique_ptr<MNN::OpT> iden(new MNN::OpT);
        iden->type = OpType_Identity;

        auto newExpr = MNN::Express::Expr::create(iden.get(), {y, scale, zero}, 3);
        newExpr->setName(expr->name());
        for (int i = 0; i < 3; ++i) {
            auto v = MNN::Express::Variable::create(newExpr, i);
            v->setName(expr->outputName(i));
        }
        return newExpr;
    }
};

static auto gRegister = []() {
    OnnxExtraManager::get()->insert("DynamicQuantizeLinear",
                                    std::shared_ptr<OnnxExtraManager::Transform>(new OnnxDynamicQuantizeLinearTransform));
    return true;
}();

} // namespace Express
} // namespace MNN
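Matching the expression graph above, the transform implements the DynamicQuantizeLinear math for the uint8 range [0, 255]:

s = \frac{\max(x) - \min(x)}{255}, \qquad
z = \mathrm{clip}\!\left(\mathrm{round}\!\left(-\frac{\min(x)}{s}\right),\, 0,\, 255\right), \qquad
y = \mathrm{clip}\!\left(\mathrm{round}\!\left(\frac{x}{s}\right) + z,\, 0,\, 255\right)

Note that the ONNX spec first extends the data range to include 0 (it uses max(max(x), 0) and min(min(x), 0)) so that zero is always exactly representable; the graph above uses the plain min/max of the data.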
@@ -0,0 +1,45 @@
//
//  MatMulInteger.cpp
//  MNNConverter
//
//  Created by MNN on 2023/08/14.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include <MNN/expr/ExprCreator.hpp>
#include "MNN_generated.h"
#include "OnnxExtraManager.hpp"

namespace MNN {
namespace Express {
// Ref from https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMulInteger
// Use float instead of uint8 to complete it
class OnnxMatMulIntegerTransform : public OnnxExtraManager::Transform {
public:
    virtual EXPRP onExecute(EXPRP expr) const override {
        auto inputs = expr->inputs();
        auto x = inputs[0];
        auto y = inputs[1];
        x = _Cast<float>(x);
        y = _Cast<float>(y);
        if (inputs.size() > 2) {
            x = x - _Cast<float>(inputs[2]);
            y = y - _Cast<float>(inputs[3]);
        }
        auto z = _MatMul(x, y);
        auto zInt = _Cast<int32_t>(z);
        auto newExpr = zInt->expr().first;
        newExpr->setName(expr->name());
        return newExpr;
    }
};

static auto gRegister = []() {
    OnnxExtraManager::get()->insert("MatMulInteger",
                                    std::shared_ptr<OnnxExtraManager::Transform>(new OnnxMatMulIntegerTransform));
    return true;
}();

} // namespace Express
} // namespace MNN
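The transform realizes the MatMulInteger contract in float arithmetic: subtract the optional zero points, multiply, and cast the result back to int32,

Z_{ij} = \sum_{k} \left(X_{ik} - z_x\right)\left(Y_{kj} - z_y\right)

Float32 accumulation is exact only while every partial sum stays within the 2^{24} exactly-representable integer range; with 8-bit operands each product is below 2^{16}, so reductions up to roughly K = 256 are guaranteed exact, while longer ones may round in the low bits.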
@@ -358,13 +358,15 @@ public:
        if (config->keepInputFormat) {
            // Change Output
            auto& outputs = mNet->outputName;
            std::vector<std::unique_ptr<MNN::OpT>> extraOp;
            for (auto& op : mNet->oplists) {
                for (int idx : op->outputIndexes) {
                    for (int j = 0; j < outputs.size(); j++) {
                        if (mNet->tensorName[idx] == outputs[j]) {
                            auto outputFormat = tensorFormats[idx];
                            if (outputFormat == MNN_DATA_FORMAT_NC4HW4) {
                                auto newOutputName = outputs[j] + "__tr";
                                auto newOutputName = outputs[j] + "__before_tr";
                                mNet->tensorName[idx] = newOutputName;
                                // Append a convert op
                                MNN::OpT* transformOp = new MNN::OpT;
                                MNN::TensorConvertInfoT* tc = new MNN::TensorConvertInfoT;

@@ -374,17 +376,20 @@ public:
                                transformOp->main.value = tc;
                                transformOp->name = newOutputName;
                                transformOp->inputIndexes.push_back(idx);
                                transformOp->outputIndexes.push_back(mNet->tensorName.size());
                                int newOutputIndex = (int)mNet->tensorName.size();
                                transformOp->outputIndexes.push_back(newOutputIndex);
                                tensorFormats.push_back(originTensorType);
                                mNet->tensorName.push_back(transformOp->name);
                                mNet->tensorName.push_back(outputs[j]);
                                transformOp->type = MNN::OpType_ConvertTensor;
                                outputs[j] = newOutputName;
                                mNet->oplists.emplace_back(transformOp);
                                extraOp.emplace_back(transformOp);
                            }
                        }
                    }
                }
            }
            for (auto&& op : extraOp) {
                mNet->oplists.emplace_back(std::move(op));
            }
        } else {
            // Change Input
            for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end(); iter++) {
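The rewrite above collects the new ConvertTensor ops in extraOp and appends them only once the loop is done; pushing into mNet->oplists while range-iterating over it can reallocate the vector and invalidate the iterators. A minimal sketch of the deferred-append pattern, with hypothetical names:

// Deferred-append pattern: never grow a container while range-iterating it.
#include <memory>
#include <vector>

struct Op { int id; };

void appendDerivedOps(std::vector<std::unique_ptr<Op>>& oplists) {
    std::vector<std::unique_ptr<Op>> extraOp;
    for (auto& op : oplists) {                 // safe: oplists is not modified here
        if (op->id % 2 == 0) {
            extraOp.emplace_back(new Op{op->id + 100});
        }
    }
    for (auto&& op : extraOp) {                // append once iteration is done
        oplists.emplace_back(std::move(op));
    }
}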
@@ -9,6 +9,7 @@
#include "cv/imgproc/geometric.hpp"
#include <MNN/expr/NeuralNetWorkOp.hpp>
#include <MNN/expr/MathOp.hpp>
#include <cmath>

namespace MNN {
namespace CV {
@@ -561,7 +561,7 @@ public:
            mWinogradAttr.reset(new WinogradInt8Attr);
            mWinogradTransInputMaxPos = addParameter(mWinogradTransInputMax);
            mWinogradTransInputMinPos = addParameter(mWinogradTransInputMin);
            mWinogradTransWeightScalePos = addParameter(nullptr);
            mWinogradTransWeightScalePos = addParameter(mWinogradTransInputMax);
        }
        setName(mConvParameter.name);
        modules[i] = nullptr;

@@ -675,6 +675,10 @@ public:
        if (nullptr == originValue) {
            return newValue;
        }
        auto ptr = originValue->readMap<float>();
        if (ptr[0] == -100.0f) {
            return newValue;
        }
        switch (mScaleUpdateMethod) {
            case NN::MovingAverage:
                return originValue * _Scalar<float>(mMomentum) + newValue * _Scalar<float>(1.0f - mMomentum);

@@ -700,7 +704,7 @@ public:
        maxUnit = std::max(std::min(maxUnit, MAX_UNIT), MIN_UNIT);

        auto units = std::pair<int, int>({0, 0});
        float maxRate = 1.0f, originCost = outH * outW * inChannel * outChannel * kernelH * kernelW;
        float maxRate = 2.0f, originCost = outH * outW * inChannel * outChannel * kernelH * kernelW;
        std::set<int> supportSu{4, 6};
        for (int uh = MIN_UNIT; uh <= maxUnit; ++uh) {
            for (int uw = MIN_UNIT; uw <= maxUnit; ++uw) {

@@ -1097,8 +1101,8 @@ private:
    NN::ScaleUpdateMethod mScaleUpdateMethod;
    bool mAccumulateToInt16 = false;
    std::shared_ptr<WinogradInt8Attr> mWinogradAttr;
    VARP mWinogradTransInputMin = nullptr;
    VARP mWinogradTransInputMax = nullptr;
    VARP mWinogradTransInputMin = _Const(-100.f);
    VARP mWinogradTransInputMax = _Const(-100.f);
    int mWinogradTransInputMinPos = -1;
    int mWinogradTransInputMaxPos = -1;
    int mWinogradTransWeightScalePos = -1;