MNN:Sync: Sync Internal 3.0.5

This commit is contained in:
xiaying 2025-02-12 11:14:19 +08:00
parent 32a8e0d72f
commit 3b6ddc0341
118 changed files with 6247 additions and 3837 deletions

View File

@@ -187,9 +187,6 @@ endif()
 if(MNN_SUPPORT_TRANSFORMER_FUSE)
     add_definitions(-DMNN_SUPPORT_TRANSFORMER_FUSE)
 endif()
-if(MNN_BUILD_AUDIO)
-    add_definitions(-DMNN_BUILD_AUDIO)
-endif()
 # debug options
 if(MNN_DEBUG_MEMORY)
     add_definitions(-DMNN_DEBUG_MEMORY)
@@ -216,6 +213,8 @@ option(MNN_TENSORRT "Enable TensorRT" OFF)
 option(MNN_COREML "Enable CoreML" OFF)
 option(MNN_NNAPI "Enable NNAPI" OFF)
+option(MNN_GPU_TIME_PROFILE "Enable time profiling for the OpenCL backend and Vulkan backend." OFF)
 option(MNN_CUDA_PROFILE "Enable CUDA profile" OFF)
 if (NOT MNN_CUDA OR NOT CMAKE_SYSTEM_NAME MATCHES "^Linux")
@@ -470,6 +469,11 @@ IF(MNN_BUILD_LLM)
     list(APPEND MNN_EXTRA_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/transformers/llm/engine/include/llm/llm.hpp)
 ENDIF()
+IF(MNN_BUILD_DIFFUSION)
+    file(GLOB MNN_DIFFUSION_HDRS ${CMAKE_CURRENT_SOURCE_DIR}/transformers/diffusion/engine/include/diffusion/*)
+    list(APPEND MNN_EXTRA_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/transformers/diffusion/engine/include/diffusion/diffusion.hpp)
+ENDIF()
 # Add Thread dependency
@@ -921,6 +925,15 @@ ELSE()
     ENDFOREACH()
 ENDIF()
+IF(MNN_BUILD_DIFFUSION)
+    if (NOT MNN_AAPL_FMWK)
+        INSTALL(FILES ${MNN_DIFFUSION_HDRS} DESTINATION include/MNN/diffusion)
+    endif()
+    FOREACH(HDR ${MNN_DIFFUSION_HDRS})
+        SET_SOURCE_FILES_PROPERTIES(${HDR} PROPERTIES MACOSX_PACKAGE_LOCATION Headers/diffusion )
+    ENDFOREACH()
+ENDIF()
 if (NOT MNN_AAPL_FMWK)
     INSTALL(FILES ${MNN_PUB_HDRS} DESTINATION include/MNN/)
     INSTALL(FILES ${MNN_EXPR_PUB_HDRS} DESTINATION include/MNN/expr/)

View File

@@ -165,3 +165,4 @@ MNN refers to the following projects:
 - [libyuv](https://chromium.googlesource.com/libyuv/libyuv)
 - [libjpeg](https://github.com/libjpeg-turbo/libjpeg-turbo)
 - [opencv](https://github.com/opencv/opencv)
+- [onnxruntime](https://github.com/microsoft/onnxruntime)

View File

@@ -155,4 +155,5 @@ MNN references and draws on the following projects:
 - [libyuv](https://chromium.googlesource.com/libyuv/libyuv)
 - [libjpeg](https://github.com/libjpeg-turbo/libjpeg-turbo)
 - [opencv](https://github.com/opencv/opencv)
+- [onnxruntime](https://github.com/microsoft/onnxruntime)

View File

@@ -163,3 +163,4 @@ MNN refers to the following projects:
 - [libyuv](https://chromium.googlesource.com/libyuv/libyuv)
 - [libjpeg](https://github.com/libjpeg-turbo/libjpeg-turbo)
 - [opencv](https://github.com/opencv/opencv)
+- [onnxruntime](https://github.com/microsoft/onnxruntime)

View File

@@ -59,7 +59,7 @@ MNN uses CMake to build the project; the CMake macro definitions are listed below:
 | MNN_SSE_USE_FP16_INSTEAD | Whether to use `FP16` instead of `BF16` on x86 platforms; default `OFF` |
 | MNN_AVX512_VNNI | Whether to use `avx512_vnni` instructions; only effective when `MNN_AVX512=ON`; default `OFF` |
 | MNN_OPENCL_SIZE_CUT | Whether to disable the OpenCL Buffer implementation to reduce OpenCL size; only effective when `MNN_OPENCL=ON`; default `OFF` |
-| MNN_OPENCL_PROFILE | Whether to enable OpenCL kernel performance profiling; only effective when `MNN_OPENCL=ON`; default `OFF` |
+| MNN_GPU_TIME_PROFILE | Whether to enable kernel performance profiling for the OpenCL and Vulkan backends; only effective when `MNN_OPENCL=ON` or `MNN_VULKAN=ON`; default `OFF` |
 | MNN_METALLIB_SOURCE | Whether to use the Metal source directly when building with Metal; only effective when `MNN_METAL=ON`; default `ON` |
 | MNN_VULKAN_DEBUG | Whether to enable Vulkan DEBUG mode; only effective when `MNN_VULKAN=ON`; default `OFF` |
 | MNN_OPENGL_REGEN | Whether to regenerate the OpenGL kernels; only effective when `MNN_OPENGL=ON`; default `OFF` |

View File

@@ -49,13 +49,13 @@ python3 convert_mnn.py onnx_path mnn_save_path "--weightQuantBits=8 --transforme
 cd mnn_path
 mkdir build
 cd build
-cmake .. -DMNN_BUILD_DIFFUSION=ON -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON -DMNN_OPENCL=ON -DMNN_SEP_BUILD=OFF -DMNN_SUPPORT_TRANSFORMER_FUSE=ON
+cmake .. -DMNN_LOW_MEMORY=ON -DMNN_BUILD_DIFFUSION=ON -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON -DMNN_OPENCL=ON -DMNN_SEP_BUILD=OFF -DMNN_SUPPORT_TRANSFORMER_FUSE=ON
 make -j32
 ```
 ### On Android
 ```
 cd mnn_path/project/android/build
-../build_64.sh -DMNN_BUILD_DIFFUSION=ON -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON -DMNN_OPENCL=ON -DMNN_SEP_BUILD=OFF -DMNN_SUPPORT_TRANSFORMER_FUSE=ON
+../build_64.sh -DMNN_LOW_MEMORY=ON -DMNN_BUILD_DIFFUSION=ON -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON -DMNN_OPENCL=ON -DMNN_SEP_BUILD=OFF -DMNN_SUPPORT_TRANSFORMER_FUSE=ON
 ../updateTest.sh
 ```
 ## Running the Diffusion Demo
@@ -90,6 +90,6 @@ cd mnn_path/project/android/build
 ```
 ## FAQ
 1. How do I resolve runtime errors or segmentation faults in the demo?
-   - The most common cause is insufficient device memory: devices that support OpenCL fp16 generally need more than 3GB of memory, while devices without fp16 support need more than 6GB of GPU memory.
+   - The most common cause is insufficient device memory: devices that support OpenCL fp16 generally need more than 2GB of memory, while devices without fp16 support need more than 4GB of GPU memory.
 2. Why do errors occur when running with other backends?
    - Other backends do not yet support the fused transformer plugin ops; remove --transformerFuse during the onnx->mnn model conversion step.

View File

@@ -1320,20 +1320,5 @@ VARP _Histogram(VARP x, int bin, int min, int max, int channel) {
     return (Variable::create(Expr::create(std::move(op), {x})));
 }
-#ifdef MNN_BUILD_AUDIO
-VARP _Stft(VARP sample, VARP window, int n_fft, int hop_length, bool abs) {
-    std::unique_ptr<OpT> op(new OpT);
-    op->type = OpType_Stft;
-    op->main.type = OpParameter_StftParam;
-    auto param = new StftParamT;
-    param->n_fft = n_fft;
-    param->hop_length = hop_length;
-    param->abs = abs;
-    op->main.value = param;
-    EXPRP expr = Expr::create(std::move(op), {sample, window});
-    return Variable::create(expr);
-}
-#endif
 } // namespace Express
 } // namespace MNN

View File

@@ -1902,8 +1902,12 @@ VARP _Col2Im(VARP x, VARP outputShape, INTS kernelSize, INTS dilate, INTS pads,
     auto common = new Convolution2DCommonT;
     param->common.reset(common);
     op->main.value = param;
-    common->padX = pads[0];
-    common->padY = pads[1];
+    if (pads.size() == 4) {
+        common->pads = pads;
+    } else {
+        common->padX = pads[0];
+        common->padY = pads[1];
+    }
     common->strideX = stride[0];
     common->strideY = stride[1];
     common->dilateX = dilate[0];
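
In practice this means `_Col2Im` now accepts either the old 2-element symmetric padding or a full 4-element padding vector. A minimal usage sketch (illustrative values only; the 4-element ordering follows whatever convention `Convolution2DCommon::pads` uses, which this diff does not spell out):

```cpp
#include <MNN/expr/ExprCreator.hpp>

using namespace MNN::Express;

// Illustrative only: both call forms are accepted after this change.
void col2imPadsExample(VARP x, VARP outputShape) {
    // 2-element pads: interpreted as {padX, padY}, as before.
    auto a = _Col2Im(x, outputShape, /*kernelSize*/ {3, 3}, /*dilate*/ {1, 1},
                     /*pads*/ {1, 1}, /*stride*/ {1, 1});
    // 4-element pads: forwarded unchanged to Convolution2DCommon::pads,
    // enabling asymmetric begin/end padding.
    auto b = _Col2Im(x, outputShape, {3, 3}, {1, 1}, {1, 1, 2, 2}, {1, 1});
}
```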

View File

@@ -713,6 +713,7 @@ Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::
 Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, std::shared_ptr<BufferStorage> bufferStorage, std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtMgr, const Module::Config* config, std::map<std::string, SubGraph>& subGraphMap) {
     MNN_ASSERT(nullptr != rtMgr);
+    MNN_ASSERT(nullptr != config);
     std::shared_ptr<Schedule::ScheduleInfo> sharedConst;
     auto buffer = bufferStorage->buffer();
     auto length = bufferStorage->size();
@@ -721,12 +722,14 @@ Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::
     // Extra Const Tensors
     sharedConst.reset(new Schedule::ScheduleInfo);
     auto curExe = ExecutorScope::Current();
-    bool permitCodeGen = false;
+    bool preReplaceConstTensor = true;
     std::shared_ptr<ModuleRuntimeConfig> modRuntimeCfgPtr(new ModuleRuntimeConfig);
     if (!rtMgr->getInside()->mContent->mExternalFile.empty()) {
         modRuntimeCfgPtr->externalFile = rtMgr->getInside()->mContent->mExternalFile;
     }
-    permitCodeGen = rtMgr->getInside()->mContent->modes.codegenMode == Interpreter::Session_Codegen_Enable;
+    if (rtMgr->getInside()->mContent->modes.codegenMode == Interpreter::Session_Codegen_Enable || (!config->shapeMutable)) {
+        preReplaceConstTensor = false;
+    }
     std::shared_ptr<Backend> defaultBackend = curExe->getAttr()->constantBackend;
     std::vector<std::shared_ptr<Tensor>> allTensors;
     sharedConst->allTensors.resize(net->tensorName()->size());
@@ -795,7 +798,7 @@ Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::
     for (int i=0; i<subModulesInfo.size(); ++i) {
         subModules[i].reset(_createSubModule(bufferStorage, subModulesInfo[i], subGraphMap, sharedConst, *config, modRuntime));
    }
-    if (!permitCodeGen) {
+    if (preReplaceConstTensor) {
         // Prereplace const tensor
         auto curBackend = sharedConst->constReplaceBackend.get();
         if (sharedConst->constReplaceBackend->type() != sharedConst->defaultBackend->type()) {
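
The net effect of the renamed flag is that constant pre-replacement becomes the default and is switched off in exactly two cases. A condensed restatement (a sketch, not MNN code):

```cpp
// Constants are pre-replaced at load time unless codegen is enabled
// (the old condition) or the module is loaded with shapeMutable == false
// (the condition added by this commit).
bool preReplaceConstTensor(bool codegenEnabled, bool shapeMutable) {
    return !codegenEnabled && shapeMutable;
}
```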

View File

@@ -76,6 +76,6 @@ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \
 #define STR(x) STR_IMP(x)
 #define MNN_VERSION_MAJOR 3
 #define MNN_VERSION_MINOR 0
-#define MNN_VERSION_PATCH 4
+#define MNN_VERSION_PATCH 5
 #define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." STR(MNN_VERSION_PATCH)
 #endif /* MNNDefine_h */

View File

@@ -138,9 +138,6 @@ MNN_PUBLIC VARP _CumSum(VARP x, int axis, bool exclusive = false, bool reverse =
 MNN_PUBLIC VARP _CumProd(VARP x, int axis);
 MNN_PUBLIC VARPS _Svd(VARP x);
 MNN_PUBLIC VARP _Histogram(VARP x, int bin, int min, int max, int channel = -1);
-#ifdef MNN_BUILD_AUDIO
-MNN_PUBLIC VARP _Stft(VARP sample, VARP window, int n_fft, int hop_length, bool abse = true);
-#endif
 }; // namespace Express
 }; // namespace MNN

View File

@@ -25,6 +25,7 @@ using Vec = MNN::Math::Vec<FLOAT16, 8>;
 extern "C" {
 // (UP_DIV(l,8), e, 8) -> (UP_DIV(e,eP), l, eP)
 void Arm82MNNPackForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el);
+// void MNNPackTransposeInt16C8(int16_t* dst, const int16_t* src, size_t area, size_t depth, int32_t* areaOffset);

 // C(UP_DIV(h,8), e, h8) = B(UP_DIV(h,hP), l, hP) * A(l, eP), hP = 24
 // parameter: [aStride, l, h, cStride, bExtraStride]
@@ -372,14 +373,44 @@ void MNNUnpackTransposeInt16C8(int16_t* dst, const int16_t* src, size_t area, si
     int c = (int)depth;
     int cDiv4 = c / 8;
     int cAlign = cDiv4 * 8;
-    for (int hi = 0; hi < area; ++hi) {
-        auto srcHeight = src + hi * 8;
-        auto dstHeight = dst + hi * c;
-        for (int ci = 0; ci < cDiv4; ++ci) {
-            vst1q_s16(dstHeight + ci * 8, vld1q_s16(srcHeight + 8 * ci * srcAreaOffset));
+    int areaDiv4 = area / 4;
+    int areaAlign = areaDiv4 * 4;
+    if (areaAlign > 0) {
+        for (int ci = 0; ci < cDiv4; ++ci) {
+            auto srcH = src + ci * 8 * srcAreaOffset;
+            auto dstH = dst + ci * 8;
+            for (int hi = 0; hi < areaAlign; hi+=4) {
+                auto src0 = srcH + hi * 8;
+                auto src1 = srcH + hi * 8 + 8;
+                auto src2 = srcH + hi * 8 + 16;
+                auto src3 = srcH + hi * 8 + 24;
+                auto dst0 = dstH + hi * c;
+                auto dst1 = dstH + hi * c + c;
+                auto dst2 = dstH + hi * c + 2 * c;
+                auto dst3 = dstH + hi * c + 3 * c;
+                vst1q_s16(dst0, vld1q_s16(src0));
+                vst1q_s16(dst1, vld1q_s16(src1));
+                vst1q_s16(dst2, vld1q_s16(src2));
+                vst1q_s16(dst3, vld1q_s16(src3));
+            }
+        }
+    }
+    if (areaAlign < area) {
+        for (int ci = 0; ci < cDiv4; ++ci) {
+            auto srcH = src + 8 * ci * srcAreaOffset;
+            auto dstH = dst + ci * 8;
+            for (int hi = areaAlign; hi < area; ++hi) {
+                auto src0 = srcH + hi * 8;
+                auto dst0 = dstH + hi * c;
+                vst1q_s16(dst0, vld1q_s16(src0));
+            }
         }
     }
+    if (c == cAlign) {
+        return;
+    }
     int cReamin = c - cAlign;
     auto srcAlign = src + srcAreaOffset * cAlign;
@@ -404,11 +435,37 @@ void MNNPackTransposeInt16C8(int16_t* dst, const int16_t* src, size_t area, size
     int c = (int)depth;
     int cDiv4 = c / 8;
     int cAlign = cDiv4 * 8;
-    for (int hi = 0; hi < area; ++hi) {
-        auto srcHeight = (src + hi * c);
-        auto dstHeight = (dst + hi * 8);
-        for (int ci = 0; ci < cDiv4; ++ci) {
-            vst1q_s16(dstHeight + ci * dstAreaOffset * 8, vld1q_s16(srcHeight + 8 * ci));
+    int areaDiv4 = area / 4;
+    int areaAlign = areaDiv4 * 4;
+    if (areaAlign > 0) {
+        for (int ci = 0; ci < cDiv4; ++ci) {
+            auto srcH = src + ci * 8;
+            auto dstH = dst + ci * dstAreaOffset * 8;
+            for (int hi = 0; hi < areaAlign; hi+=4) {
+                auto src0 = srcH + hi * c;
+                auto src1 = srcH + hi * c + c;
+                auto src2 = srcH + hi * c + 2 * c;
+                auto src3 = srcH + hi * c + 3 * c;
+                auto dst0 = dstH + hi * 8;
+                auto dst1 = dstH + hi * 8 + 8;
+                auto dst2 = dstH + hi * 8 + 16;
+                auto dst3 = dstH + hi * 8 + 24;
+                vst1q_s16(dst0, vld1q_s16(src0));
+                vst1q_s16(dst1, vld1q_s16(src1));
+                vst1q_s16(dst2, vld1q_s16(src2));
+                vst1q_s16(dst3, vld1q_s16(src3));
+            }
+        }
+    }
+    if (areaAlign < area) {
+        for (int ci = 0; ci < cDiv4; ++ci) {
+            auto srcH = src + ci * 8;
+            auto dstH = dst + ci * dstAreaOffset * 8;
+            for (int hi = areaAlign; hi < area; ++hi) {
+                auto src0 = srcH + hi * c;
+                auto dst0 = dstH + hi * 8;
+                vst1q_s16(dst0, vld1q_s16(src0));
+            }
         }
     }
 }
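
What these two routines compute is easier to see in scalar form. A minimal reference sketch (my paraphrase, not MNN code; `srcAreaOffset` spacing as in the NEON version):

```cpp
#include <cstdint>
#include <cstring>
#include <cstddef>

// Scalar reference for MNNUnpackTransposeInt16C8: converts a channel-packed
// (UP_DIV(c,8), area, 8) layout back to an (area, c) row layout, moving 8
// int16 values at a time. The NEON version above performs the same copies
// with vld1q_s16/vst1q_s16 and, after this commit, iterates channel-group
// first and unrolls the area loop by 4 so four rows are in flight at once.
static void unpackTransposeInt16C8Ref(int16_t* dst, const int16_t* src,
                                      size_t area, size_t depth,
                                      size_t srcAreaOffset) {
    int c = (int)depth;
    int cDiv8 = c / 8;
    for (int hi = 0; hi < (int)area; ++hi) {
        for (int ci = 0; ci < cDiv8; ++ci) {
            // One 8-channel group: contiguous in src, contiguous in the dst row.
            std::memcpy(dst + hi * c + ci * 8,
                        src + ci * 8 * srcAreaOffset + hi * 8,
                        8 * sizeof(int16_t));
        }
    }
    // The channel remainder (c % 8) is handled separately in the real code.
}
```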

View File

@@ -63,6 +63,12 @@
 fcvtn \d1\().4h, \s2\().4s
 fcvtn2 \d1\().8h, \s3\().4s
 .endm
+.macro ADD_FLOAT d0, d1, d2, d3, s0, s1, s2, s3
+    fadd \d0\().4s, \d0\().4s, \s0\().4s
+    fadd \d1\().4s, \d1\().4s, \s1\().4s
+    fadd \d2\().4s, \d2\().4s, \s2\().4s
+    fadd \d3\().4s, \d3\().4s, \s3\().4s
+.endm

 asm_function MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16
 /*
@@ -90,36 +96,32 @@ struct QuanPostTreatParameters {
 //Auto: x0:dst, x1:src, x2:weight, x3:src_depth_quad, x4:dst_step
 //x5:dst_depth_quad, x6: parameters, x7: realDstCount
-//Load from x6: x8: scale, x9: bias, x25: xKernelSum, x26: weightQuantBias, x23: fp32minmax, x27: blockNum
-ldr x8, [x6, #0]
+//Load from x6: x9: bias, x8: xKernelSum, x23: fp32minmax
 ldr x9, [x6, #8]
-//ldr w12, [x6, #16]

-stp d14, d15, [sp, #(-16 * 10)]!
+stp d14, d15, [sp, #(-16 * 8)]!
 stp d12, d13, [sp, #(16 * 1)]
 stp d10, d11, [sp, #(16 * 2)]
 stp d8, d9, [sp, #(16 * 3)]
 stp x21, x22, [sp, #(16 * 4)]
 stp x19, x20, [sp, #(16 * 5)]
-stp x27, x28, [sp, #(16 * 6)]
+stp x23, x24, [sp, #(16 * 6)]
 stp x25, x26, [sp, #(16 * 7)]
-stp x23, x24, [sp, #(16 * 8)]

-ldr x25, [x6, #40] // xKernelSum
-ldr x26, [x6, #48] // weightQuantBias
+ldr x8, [x6, #40] // srcKernelSum
+ldr x24, [x6, #80] // extraScale
+ldr x15, [x6, #96] // accumBuffer
+mov x10, x15
+mov x25, x24
+mov x21, #16 // sizeof(float) * pack
 ldr x23, [x6, #56] // fp32minmax
-//add x24, x23, #4
-mov x21, #16 // sizeof(float16_t) * PACK
-Start:
-lsl x15, x3, #5 // x15 = src_depth_quad * UNIT * SRC_UNIT
-lsl x22, x7, #2 // src_steps
-ldr x27, [x6, #80] // extra scale
+lsl x22, x7, #2 // eDest * SRC_UNIT

 TILE_12:
     cmp x7, #12
     blt TILE_8
-    sub x4, x4, #128
 L8LoopDz_TILE_12:
     mov x11, x1
     mov x13, x3
@@ -130,7 +132,6 @@ L8LoopDz_TILE_12:
     SET_BIAS v24, v25, v26, v27
     SET_BIAS v28, v29, v30, v31

-    mov x28, x2
 L8LoopSz_TILE_12:
     ld1 {v3.16b, v4.16b}, [x2], #32 // weight
     ld1 {v0.16b, v1.16b, v2.16b}, [x11], #48 // src
@@ -164,13 +165,12 @@ L8LoopDz_TILE_12:
     bne L8LoopSz_TILE_12

 L8LoopSzEnd_TILE_12:
-    add x2, x28, x15
     sub x5, x5, #1

 L8Tile12Quan:
-    ld1 {v0.4s, v1.4s}, [x8], #32 // scale
-    ld1 {v2.4s, v3.4s, v4.4s}, [x25] // x kernel sum
-    ld1 {v5.4s, v6.4s}, [x26], #32 // weight quan zeropoint
+    ld1 {v0.4s, v1.4s}, [x2], #32 // scale
+    ld1 {v2.4s, v3.4s, v4.4s}, [x8] // x kernel sum
+    ld1 {v5.4s, v6.4s}, [x2], #32 // weight quan zeropoint
     Int32ToFloat v8, v9, v10, v11
     Int32ToFloat v12, v13, v14, v15
     Int32ToFloat v16, v17, v18, v19
@@ -185,16 +185,16 @@ L8LoopDz_TILE_12:
     MUL_SCALE v1, v24, v25, v26, v27
     MUL_SCALE v1, v28, v29, v30, v31

-    cbz x27, TILE12_L8_MLA_TERM
-    ld1 {v0.4s, v1.4s}, [x27], #32
-    ld1 {v7.4s}, [x27]
+    cbz x25, TILE12_L8_MLA_TERM
+    ld1 {v0.4s, v1.4s}, [x24], #32
+    ld1 {v7.4s}, [x24]
     MUL_EXTRA_SCALE v0, v8, v9, v10, v11
     MUL_EXTRA_SCALE v1, v12, v13, v14, v15
     MUL_EXTRA_SCALE v7, v16, v17, v18, v19
     MUL_EXTRA_SCALE v0, v20, v21, v22, v23
     MUL_EXTRA_SCALE v1, v24, v25, v26, v27
     MUL_EXTRA_SCALE v7, v28, v29, v30, v31
-    sub x27, x27, #32
+    sub x24, x24, #32

 TILE12_L8_MLA_TERM:
     MLA_WEIGHTZERO v8, v2, v5, 0 // tile:0, oc:0-3
@@ -222,7 +222,6 @@ L8LoopDz_TILE_12:
     MLA_WEIGHTZERO v29, v4, v6, 1 // tile:9, oc:4-7
     MLA_WEIGHTZERO v30, v4, v6, 2 // tile:10, oc:4-7
     MLA_WEIGHTZERO v31, v4, v6, 3 // tile:11, oc:4-7
-    sub x4, x4, #128

     cbz x9, TILE12_ADD_DSTV
 TILE12_ADD_BIAS:
@@ -233,40 +232,40 @@ L8LoopDz_TILE_12:
     ADD_BIAS_FLOAT v20, v21, v22, v23, v1
     ADD_BIAS_FLOAT v24, v25, v26, v27, v1
     ADD_BIAS_FLOAT v28, v29, v30, v31, v1
-    Float32ToHalf v8, v20, v9, v21, v0, v1
-    Float32ToHalf v10, v22, v11, v23, v2, v3
-    Float32ToHalf v12, v24, v13, v25, v4, v5
-    Float32ToHalf v14, v26, v15, v27, v6, v7
-    Float32ToHalf v16, v28, v17, v29, v8, v9
-    Float32ToHalf v18, v30, v19, v31, v10, v11
-    b TILE12_POST
+    cbnz x0, TILE12_POST
+    b TILE12_L8_ACCUM_BUFFER

 TILE12_ADD_DSTV:
-    Float32ToHalf v8, v20, v9, v21, v0, v1
-    Float32ToHalf v10, v22, v11, v23, v2, v3
-    Float32ToHalf v12, v24, v13, v25, v4, v5
-    Float32ToHalf v14, v26, v15, v27, v6, v7
-    Float32ToHalf v16, v28, v17, v29, v8, v9
-    Float32ToHalf v18, v30, v19, v31, v10, v11
-    ld1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x0], #64
-    ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x0], #64
-    ld1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x0]
-    fadd v0.8h, v0.8h, v20.8h
-    fadd v1.8h, v1.8h, v21.8h
-    fadd v2.8h, v2.8h, v22.8h
-    fadd v3.8h, v3.8h, v23.8h
-    fadd v4.8h, v4.8h, v12.8h
-    fadd v5.8h, v5.8h, v13.8h
-    fadd v6.8h, v6.8h, v14.8h
-    fadd v7.8h, v7.8h, v15.8h
-    fadd v8.8h, v8.8h, v16.8h
-    fadd v9.8h, v9.8h, v17.8h
-    fadd v10.8h, v10.8h, v18.8h
-    fadd v11.8h, v11.8h, v19.8h
-    sub x0, x0, #128
+    ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64
+    ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64
+    ADD_FLOAT v8, v9, v10, v11, v0, v1, v2, v3
+    ADD_FLOAT v12, v13, v14, v15, v4, v5, v6, v7
+    ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64
+    ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64
+    ADD_FLOAT v16, v17, v18, v19, v0, v1, v2, v3
+    ADD_FLOAT v20, v21, v22, v23, v4, v5, v6, v7
+    ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64
+    ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64
+    ADD_FLOAT v24, v25, v26, v27, v0, v1, v2, v3
+    ADD_FLOAT v28, v29, v30, v31, v4, v5, v6, v7
+    cbnz x0, TILE12_POST
+
+TILE12_L8_ACCUM_BUFFER:
+    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
+    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
+    st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x15], #64
+    st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x15], #64
+    st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x15], #64
+    st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x15], #64
+    b L8Tile12LoopCheck

 TILE12_POST:
+    Float32ToHalf v8, v20, v9, v21, v0, v1
+    Float32ToHalf v10, v22, v11, v23, v2, v3
+    Float32ToHalf v12, v24, v13, v25, v4, v5
+    Float32ToHalf v14, v26, v15, v27, v6, v7
+    Float32ToHalf v16, v28, v17, v29, v8, v9
+    Float32ToHalf v18, v30, v19, v31, v10, v11
     cbz x23, TILE12_STORE
     ld1r {v24.8h}, [x23], #2 // f32 min
     ld1r {v25.8h}, [x23] // f32 max
@@ -281,24 +280,21 @@ L8LoopDz_TILE_12:
     st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64
     st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64
     st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], x4
-    add x4, x4, #128

 L8Tile12LoopCheck:
     cmp x5, #1
     bge L8LoopDz_TILE_12
-    blt End
+    b End

 TILE_8:
     cmp x7, #8
     blt TILE_4
-    mov x10, x0
+    sub x19, x4, #64
+    mov x6, x0
     mov x12, x2
     mov x14, x5
-    mov x19, x8 // scale
     mov x20, x9 // bias
-    mov x6, x26 // weightQuantBias

 L8LoopDz_TILE_8:
-    //ld1 {v0.4s, v1.4s}, [x20], #32 // bias
     mov x11, x1
     mov x13, x3
@@ -306,7 +302,6 @@ L8LoopDz_TILE_8:
     SET_BIAS v12, v13, v14, v15
     SET_BIAS v16, v17, v18, v19
     SET_BIAS v20, v21, v22, v23

-    mov x28, x12
 L8LoopSz_TILE_8:
     ld1 {v3.16b, v4.16b}, [x12], #32 // weight
     ld1 {v0.16b, v1.16b}, [x11], x22 // src
@@ -332,13 +327,12 @@ L8LoopDz_TILE_8:
     bne L8LoopSz_TILE_8

 L8LoopSzEnd_TILE_8:
-    add x12, x28, x15
     sub x14, x14, #1

 L8Tile8Quan:
-    ld1 {v0.4s, v1.4s}, [x19], #32 // scale
-    ld1 {v2.4s, v3.4s}, [x25] // x kernel sum
-    ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint
+    ld1 {v0.4s, v1.4s}, [x12], #32 // scale
+    ld1 {v2.4s, v3.4s}, [x8] // x kernel sum
+    ld1 {v24.4s, v25.4s}, [x12], #32 // weight quan zeropoint
     Int32ToFloat v8, v9, v10, v11
     Int32ToFloat v12, v13, v14, v15
     Int32ToFloat v16, v17, v18, v19
@@ -348,8 +342,8 @@ L8LoopDz_TILE_8:
     MUL_SCALE v1, v16, v17, v18, v19
     MUL_SCALE v1, v20, v21, v22, v23

-    cbz x27, TILE8_L8_MLA_TERM
-    ld1 {v4.4s, v5.4s}, [x27]
+    cbz x25, TILE8_L8_MLA_TERM
+    ld1 {v4.4s, v5.4s}, [x24]
     MUL_EXTRA_SCALE v4, v8, v9, v10, v11
     MUL_EXTRA_SCALE v5, v12, v13, v14, v15
     MUL_EXTRA_SCALE v4, v16, v17, v18, v19
@@ -373,8 +367,6 @@ L8LoopDz_TILE_8:
     MLA_WEIGHTZERO v22, v3, v25, 2 // tile:6, oc:4-7
     MLA_WEIGHTZERO v23, v3, v25, 3 // tile:7, oc:4-7
-    sub x4, x4, #64

     cbz x9, TILE8_ADD_DSTV
 TILE8_ADD_BIAS:
     ld1 {v0.4s, v1.4s}, [x20], #32
@@ -382,31 +374,33 @@ L8LoopDz_TILE_8:
     ADD_BIAS_FLOAT v12, v13, v14, v15, v0
     ADD_BIAS_FLOAT v16, v17, v18, v19, v1
     ADD_BIAS_FLOAT v20, v21, v22, v23, v1
-    Float32ToHalf v8, v16, v9, v17, v0, v1
-    Float32ToHalf v10, v18, v11, v19, v2, v3
-    Float32ToHalf v12, v20, v13, v21, v4, v5
-    Float32ToHalf v14, v22, v15, v23, v6, v7
-    b TILE8_POST
+    cbnz x0, TILE8_POST
+    b TILE8_L8_ACCUM_BUFFER

 TILE8_ADD_DSTV:
-    Float32ToHalf v8, v16, v9, v17, v0, v1
-    Float32ToHalf v10, v18, v11, v19, v2, v3
-    Float32ToHalf v12, v20, v13, v21, v4, v5
-    Float32ToHalf v14, v22, v15, v23, v6, v7
-    ld1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x10], #64
-    ld1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x10]
-    fadd v0.8h, v0.8h, v24.8h
-    fadd v1.8h, v1.8h, v25.8h
-    fadd v2.8h, v2.8h, v26.8h
-    fadd v3.8h, v3.8h, v27.8h
-    fadd v4.8h, v4.8h, v28.8h
-    fadd v5.8h, v5.8h, v29.8h
-    fadd v6.8h, v6.8h, v30.8h
-    fadd v7.8h, v7.8h, v31.8h
-    sub x10, x10, #64
+    ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64
+    ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64
+    ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64
+    ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10], #64
+    ADD_FLOAT v8, v9, v10, v11, v0, v1, v2, v3
+    ADD_FLOAT v12, v13, v14, v15, v4, v5, v6, v7
+    ADD_FLOAT v16, v17, v18, v19, v24, v25, v26, v27
+    ADD_FLOAT v20, v21, v22, v23, v28, v29, v30, v31
+    cbnz x0, TILE8_POST
+
+TILE8_L8_ACCUM_BUFFER:
+    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
+    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
+    st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x15], #64
+    st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x15], #64
+    b L8Tile8LoopCheck

 TILE8_POST:
+    Float32ToHalf v8, v16, v9, v17, v0, v1
+    Float32ToHalf v10, v18, v11, v19, v2, v3
+    Float32ToHalf v12, v20, v13, v21, v4, v5
+    Float32ToHalf v14, v22, v15, v23, v6, v7
     cbz x23, TILE8_STORE
     ld1r {v24.8h}, [x23], #2 // f16 min
     ld1r {v25.8h}, [x23] // f16 max
@@ -415,45 +409,35 @@ L8LoopDz_TILE_8:
     sub x23, x23, #2

 TILE8_STORE:
-    st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64
-    st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], x4
-    //st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], #64
-    //st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10], x4
-    //st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64
-    //st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], x4
-    add x4, x4, #64
+    st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x6], #64
+    st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x6], x19

 L8Tile8LoopCheck:
     cmp x14, #1
     bge L8LoopDz_TILE_8
-    cbz x27, Tile8End
-    add x27, x27, #32

 Tile8End:
-    sub x7, x7, #8
+    cbz x0, Tile8_End_Offset
     add x0, x0, x21, LSL #3
+
+Tile8_End_Offset:
+    sub x7, x7, #8
     add x1, x1, #32
-    add x25, x25, #32
+    add x8, x8, #32
+    add x24, x24, #32
+    cbz x7, End

 TILE_4:
     cmp x7, #4
     blt TILE_1
-    mov x10, x0
+    mov x6, x0
     mov x12, x2
     mov x14, x5
-    mov x19, x8
     mov x20, x9
-    mov x6, x26 // weightQuantBias

 L8LoopDz_TILE_4:
-    //ld1 {v0.4s, v1.4s}, [x20], #32 // bias
     mov x11, x1
     mov x13, x3
     SET_BIAS v8, v9, v10, v11
     SET_BIAS v12, v13, v14, v15

-    mov x28, x12
 L8LoopSz_TILE_4:
     ld1 {v3.16b, v4.16b}, [x12], #32 // weight
     ld1 {v0.16b}, [x11], x22 // src
@@ -469,20 +453,19 @@ L8LoopDz_TILE_4:
     bne L8LoopSz_TILE_4

 L8LoopSzEnd_TILE_4:
-    add x12, x28, x15
     sub x14, x14, #1

 L8Tile4Quan:
-    ld1 {v0.4s, v1.4s}, [x19], #32 // scale
-    ld1 {v2.4s}, [x25] // x kernel sum
-    ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint
+    ld1 {v0.4s, v1.4s}, [x12], #32 // scale
+    ld1 {v2.4s}, [x8] // x kernel sum
+    ld1 {v24.4s, v25.4s}, [x12], #32 // weight quan zeropoint
     Int32ToFloat v8, v9, v10, v11
     Int32ToFloat v12, v13, v14, v15
     MUL_SCALE v0, v8, v9, v10, v11
     MUL_SCALE v1, v12, v13, v14, v15

-    cbz x27, TILE4_L8_MLA_TERM
-    ld1 {v4.4s}, [x27]
+    cbz x25, TILE4_L8_MLA_TERM
+    ld1 {v4.4s}, [x24]
     MUL_EXTRA_SCALE v4, v8, v9, v10, v11
     MUL_EXTRA_SCALE v4, v12, v13, v14, v15
@@ -498,63 +481,61 @@ L8LoopDz_TILE_4:
     cbz x9, TILE4_ADD_DSTV
 TILE4_ADD_BIAS:
-    ld1 {v0.4s, v1.4s}, [x20], #32
-    ADD_BIAS_FLOAT v8, v9, v10, v11, v0
-    ADD_BIAS_FLOAT v12, v13, v14, v15, v1
-    Float32ToHalf v8, v12, v9, v13, v0, v1
-    Float32ToHalf v10, v14, v11, v15, v2, v3
-    b TILE4_POST
+    ld1 {v4.4s, v5.4s}, [x20], #32
+    ADD_BIAS_FLOAT v8, v9, v10, v11, v4
+    ADD_BIAS_FLOAT v12, v13, v14, v15, v5
+    cbnz x0, TILE4_POST
+    b TILE4_L8_ACCUM_BUFFER

 TILE4_ADD_DSTV:
-    Float32ToHalf v8, v12, v9, v13, v0, v1
-    Float32ToHalf v10, v14, v11, v15, v2, v3
-    ld1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x10]
-    fadd v0.8h, v0.8h, v20.8h
-    fadd v1.8h, v1.8h, v21.8h
-    fadd v2.8h, v2.8h, v22.8h
-    fadd v3.8h, v3.8h, v23.8h
+    ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64
+    ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64
+    ADD_FLOAT v8, v9, v10, v11, v16, v17, v18, v19
+    ADD_FLOAT v12, v13, v14, v15, v20, v21, v22, v23
+    cbnz x0, TILE4_POST
+
+TILE4_L8_ACCUM_BUFFER:
+    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
+    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
+    b L8Tile4LoopCheck

 TILE4_POST:
+    Float32ToHalf v8, v12, v9, v13, v0, v1
+    Float32ToHalf v10, v14, v11, v15, v2, v3
     cbz x23, TILE4_STORE
     ld1r {v24.8h}, [x23], #2 // f16 min
     ld1r {v25.8h}, [x23] // f16 max
     sub x23, x23, #2
     ReLU_FP16 v0, v1, v2, v3, v24, v25
-    //st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], x4
-    //st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10], x4
 TILE4_STORE:
-    st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], x4
+    st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x6], x4

 L8Tile4LoopCheck:
     cmp x14, #1
     bge L8LoopDz_TILE_4
-    cbz x27, Tile4End
-    add x27, x27, #16

 Tile4End:
-    sub x7, x7, #4
+    cbz x0, Tile4_End_Offset
     add x0, x0, x21, LSL #2
+
+Tile4_End_Offset:
+    sub x7, x7, #4
     add x1, x1, #16
-    add x25, x25, #16
+    add x8, x8, #16
+    add x24, x24, #16
+    cbz x7, End

 TILE_1:
-    cbz x7, End
-    mov x10, x0
+    mov x6, x0
     mov x12, x2
     mov x14, x5
-    mov x19, x8
     mov x20, x9
-    mov x6, x26 // weightQuantBias

 L8LoopDz_TILE_1:
-    //ld1 {v0.4s, v1.4s}, [x20], #32 // bias
     mov x11, x1
     mov x13, x3
     movi v8.16b, #0
     movi v9.16b, #0

-    mov x28, x12
 L8LoopSz_TILE_1:
     ld1 {v3.16b, v4.16b}, [x12], #32 // weight
     ld1 {v0.s}[0], [x11], x22 // src
@@ -564,20 +545,19 @@ L8LoopDz_TILE_1:
     bne L8LoopSz_TILE_1

 L8LoopSzEnd_TILE_1:
-    add x12, x28, x15
     sub x14, x14, #1

 L8Tile1Quan:
-    ld1 {v0.4s, v1.4s}, [x19], #32 // scale
-    ld1 {v2.s}[0], [x25] // x kernel sum
-    ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint
+    ld1 {v0.4s, v1.4s}, [x12], #32 // scale
+    ld1 {v2.s}[0], [x8] // x kernel sum
+    ld1 {v24.4s, v25.4s}, [x12], #32 // weight quan zeropoint
     scvtf v8.4s, v8.4s
     scvtf v9.4s, v9.4s
     fmul v8.4s, v8.4s, v0.4s
     fmul v9.4s, v9.4s, v1.4s

-    cbz x27, TILE1_L8_MLA_TERM
-    ld1 {v4.s}[0], [x27]
+    cbz x25, TILE1_L8_MLA_TERM
+    ld1 {v4.s}[0], [x24]
     fmul v8.4s, v8.4s, v4.s[0]
     fmul v9.4s, v9.4s, v4.s[0]
@@ -587,52 +567,56 @@ L8LoopDz_TILE_1:
     cbz x9, TILE1_ADD_DSTV
 TILE1_ADD_BIAS:
-    ld1 {v20.4s, v21.4s}, [x20], #32
-    fadd v8.4s, v8.4s, v20.4s
-    fadd v9.4s, v9.4s, v21.4s
-    fcvtn v0.4h, v8.4s
-    fcvtn2 v0.8h, v9.4s
-    b TILE1_POST
+    ld1 {v10.4s, v11.4s}, [x20], #32
+    fadd v8.4s, v8.4s, v10.4s
+    fadd v9.4s, v9.4s, v11.4s
+    cbnz x0, TILE1_POST
+    b TILE1_L8_ACCUM_BUFFER

 TILE1_ADD_DSTV:
-    fcvtn v0.4h, v8.4s
-    fcvtn2 v0.8h, v9.4s
-    ld1 {v3.8h}, [x10]
-    fadd v0.8h, v0.8h, v3.8h
+    ld1 {v10.4s, v11.4s}, [x10], #32
+    fadd v8.4s, v8.4s, v10.4s
+    fadd v9.4s, v9.4s, v11.4s
+    cbnz x0, TILE1_POST
+
+TILE1_L8_ACCUM_BUFFER:
+    st1 {v8.4s, v9.4s}, [x15], #32
+    b L8Tile1LoopCheck

 TILE1_POST:
+    fcvtn v0.4h, v8.4s
+    fcvtn2 v0.8h, v9.4s
     cbz x23, TILE1_STORE
-    ld1r {v24.8h}, [x23], #2 // f32 min
-    ld1r {v25.8h}, [x23] // f32 max
+    ld1r {v24.8h}, [x23], #2 // f16 min
+    ld1r {v25.8h}, [x23] // f16 max
     sub x23, x23, #2
     fmax v0.8h, v24.8h, v0.8h
     fmin v0.8h, v25.8h, v0.8h
 TILE1_STORE:
-    st1 {v0.8h}, [x10], x4
+    st1 {v0.8h}, [x6], x4

 L8Tile1LoopCheck:
     cmp x14, #1
     bge L8LoopDz_TILE_1
-    cbz x27, Tile1End
-    add x27, x27, #4

 Tile1End:
-    sub x7, x7, #1
+    cbz x0, Tile1_End_Offset
     add x0, x0, x21
+
+Tile1_End_Offset:
+    add x24, x24, #4
+    subs x7, x7, #1
     add x1, x1, #4
-    add x25, x25, #4
-    b TILE_1
+    add x8, x8, #4
+    bne TILE_1

 End:
-ldp x23, x24, [sp, #(16 * 8)]
 ldp x25, x26, [sp, #(16 * 7)]
-ldp x27, x28, [sp, #(16 * 6)]
+ldp x23, x24, [sp, #(16 * 6)]
 ldp x19, x20, [sp, #(16 * 5)]
 ldp x21, x22, [sp, #(16 * 4)]
 ldp d8, d9, [sp, #(16 * 3)]
 ldp d10, d11, [sp, #(16 * 2)]
 ldp d12, d13, [sp, #(16 * 1)]
-ldp d14, d15, [sp], #(16 * 10)
+ldp d14, d15, [sp], #(16 * 8)
 ret

 #endif // __aarch64__
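
The structural change in this kernel (and in the w4 variant below) is the same throughout: instead of converting int32 accumulators straight to fp16 and adding the previous dst tile in half precision, partial results now stay in fp32. When x0 (dst) is null, the tile is parked in the fp32 accumBuffer (x15) between weight blocks; only the final block converts to fp16, clamps, and stores. A rough sketch of that epilogue flow, as I read the diff (illustrative names, not MNN API; `__fp16` is the ARM-specific half type):

```cpp
#include <cstddef>

// Sketch of the per-tile epilogue implied by the diff. 'acc' holds the
// dequantized fp32 partial sums for one tile.
void tileEpilogue(float* acc, size_t n,
                  const float* bias,        // x9: null for non-first blocks
                  float* accumBuffer,       // x15: fp32 scratch between blocks
                  __fp16* dst,              // x0: null for all but the last block
                  const __fp16* minmax) {   // x23: optional {min, max} clamp
    if (bias != nullptr) {                  // first block: start from bias
        for (size_t i = 0; i < n; ++i) acc[i] += bias[i];
    } else {                                // later blocks: accumulate in fp32
        for (size_t i = 0; i < n; ++i) acc[i] += accumBuffer[i];
    }
    if (dst == nullptr) {                   // not last: park fp32 sums and return
        for (size_t i = 0; i < n; ++i) accumBuffer[i] = acc[i];
        return;
    }
    for (size_t i = 0; i < n; ++i) {        // last block: fp16 convert + clamp
        __fp16 v = (__fp16)acc[i];
        if (minmax != nullptr) {
            if (v < minmax[0]) v = minmax[0];
            if (v > minmax[1]) v = minmax[1];
        }
        dst[i] = v;
    }
}
```

Keeping the cross-block accumulation in fp32 (the new `ADD_FLOAT` paths) avoids the rounding loss the old fp16 `fadd` chain incurred on every block.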

View File

@ -63,6 +63,12 @@
fcvtn \d1\().4h, \s2\().4s fcvtn \d1\().4h, \s2\().4s
fcvtn2 \d1\().8h, \s3\().4s fcvtn2 \d1\().8h, \s3\().4s
.endm .endm
.macro ADD_FLOAT d0, d1, d2, d3, s0, s1, s2, s3
fadd \d0\().4s, \d0\().4s, \s0\().4s
fadd \d1\().4s, \d1\().4s, \s1\().4s
fadd \d2\().4s, \d2\().4s, \s2\().4s
fadd \d3\().4s, \d3\().4s, \s3\().4s
.endm
asm_function MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16 asm_function MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16
/* /*
@ -90,41 +96,37 @@ struct QuanPostTreatParameters {
//Auto: x0:dst, x1:src, x2:weight, x3:src_depth_quad, x4:dst_step //Auto: x0:dst, x1:src, x2:weight, x3:src_depth_quad, x4:dst_step
//x5:dst_depth_quad, x6: parameters, x7: realDstCount //x5:dst_depth_quad, x6: parameters, x7: realDstCount
//Load from x6: x8: scale, x9: bias, x25: xKernelSum, x26: weightQuantBias, x23: fp32minmax, x27: blockNum //Load from x6: x9: bias, x8: xKernelSum, x23: fp32minmax
ldr x8, [x6, #0]
ldr x9, [x6, #8] ldr x9, [x6, #8]
//ldr w12, [x6, #16]
stp d14, d15, [sp, #(-16 * 10)]! stp d14, d15, [sp, #(-16 * 8)]!
stp d12, d13, [sp, #(16 * 1)] stp d12, d13, [sp, #(16 * 1)]
stp d10, d11, [sp, #(16 * 2)] stp d10, d11, [sp, #(16 * 2)]
stp d8, d9, [sp, #(16 * 3)] stp d8, d9, [sp, #(16 * 3)]
stp x21, x22, [sp, #(16 * 4)] stp x21, x22, [sp, #(16 * 4)]
stp x19, x20, [sp, #(16 * 5)] stp x19, x20, [sp, #(16 * 5)]
stp x27, x28, [sp, #(16 * 6)] stp x23, x24, [sp, #(16 * 6)]
stp x25, x26, [sp, #(16 * 7)] stp x25, x26, [sp, #(16 * 7)]
stp x23, x24, [sp, #(16 * 8)]
//ldr w27, [x6, #20] ldr x8, [x6, #40] // srcKernelSum
ldr x25, [x6, #40] // xKernelSum ldr x24, [x6, #80] // extraScale
ldr x26, [x6, #48] // weightQuantBias ldr x15, [x6, #96] // accumBuffer
mov x10, x15
mov x25, x24
mov x21, #16 // sizeof(float) * pack
ldr x23, [x6, #56] // fp32minmax ldr x23, [x6, #56] // fp32minmax
lsl x22, x7, #2 // eDest * SRC_UNIT
mov x21, #16 // sizeof(float16_t) * PACK
Start:
lsl x15, x3, #4 // x15 = src_depth_quad * UNIT * SRC_UNIT * sizeof(int4_t)
lsl x22, x7, #2 // src_steps
ldr x27, [x6, #80] // extra scale
TILE_12: TILE_12:
cmp x7, #12 cmp x7, #12
blt TILE_8 blt TILE_8
sub x4, x4, #128
L8LoopDz_TILE_12: L8LoopDz_TILE_12:
mov x11, x1 mov x11, x1
mov x13, x3 mov x13, x3
movi v7.16b, #15 movi v7.16b, #15
// Init 0
SET_BIAS v8, v9, v10, v11 SET_BIAS v8, v9, v10, v11
SET_BIAS v12, v13, v14, v15 SET_BIAS v12, v13, v14, v15
SET_BIAS v16, v17, v18, v19 SET_BIAS v16, v17, v18, v19
@ -132,7 +134,6 @@ L8LoopDz_TILE_12:
SET_BIAS v24, v25, v26, v27 SET_BIAS v24, v25, v26, v27
SET_BIAS v28, v29, v30, v31 SET_BIAS v28, v29, v30, v31
mov x28, x2
L8LoopSz_TILE_12: L8LoopSz_TILE_12:
ld1 {v5.16b}, [x2], #16 // weight ld1 {v5.16b}, [x2], #16 // weight
ld1 {v0.16b, v1.16b, v2.16b}, [x11], #48 // src ld1 {v0.16b, v1.16b, v2.16b}, [x11], #48 // src
@ -171,13 +172,12 @@ L8LoopDz_TILE_12:
bne L8LoopSz_TILE_12 bne L8LoopSz_TILE_12
L8LoopSzEnd_TILE_12: L8LoopSzEnd_TILE_12:
add x2, x28, x15
sub x5, x5, #1 sub x5, x5, #1
L8Tile12Quan: L8Tile12Quan:
ld1 {v0.4s, v1.4s}, [x8], #32 // scale ld1 {v0.4s, v1.4s}, [x2], #32 // scale
ld1 {v2.4s, v3.4s, v4.4s}, [x25] // x kernel sum ld1 {v2.4s, v3.4s, v4.4s}, [x8] // x kernel sum
ld1 {v5.4s, v6.4s}, [x26], #32 // weight quan zeropoint ld1 {v5.4s, v6.4s}, [x2], #32 // weight quan zeropoint
Int32ToFloat v8, v9, v10, v11 Int32ToFloat v8, v9, v10, v11
Int32ToFloat v12, v13, v14, v15 Int32ToFloat v12, v13, v14, v15
Int32ToFloat v16, v17, v18, v19 Int32ToFloat v16, v17, v18, v19
@ -192,16 +192,16 @@ L8LoopDz_TILE_12:
MUL_SCALE v1, v24, v25, v26, v27 MUL_SCALE v1, v24, v25, v26, v27
MUL_SCALE v1, v28, v29, v30, v31 MUL_SCALE v1, v28, v29, v30, v31
cbz x27, TILE12_L8_MLA_TERM cbz x25, TILE12_L8_MLA_TERM
ld1 {v0.4s, v1.4s}, [x27], #32 ld1 {v0.4s, v1.4s}, [x24], #32
ld1 {v7.4s}, [x27] ld1 {v7.4s}, [x24]
MUL_EXTRA_SCALE v0, v8, v9, v10, v11 MUL_EXTRA_SCALE v0, v8, v9, v10, v11
MUL_EXTRA_SCALE v1, v12, v13, v14, v15 MUL_EXTRA_SCALE v1, v12, v13, v14, v15
MUL_EXTRA_SCALE v7, v16, v17, v18, v19 MUL_EXTRA_SCALE v7, v16, v17, v18, v19
MUL_EXTRA_SCALE v0, v20, v21, v22, v23 MUL_EXTRA_SCALE v0, v20, v21, v22, v23
MUL_EXTRA_SCALE v1, v24, v25, v26, v27 MUL_EXTRA_SCALE v1, v24, v25, v26, v27
MUL_EXTRA_SCALE v7, v28, v29, v30, v31 MUL_EXTRA_SCALE v7, v28, v29, v30, v31
sub x27, x27, #32 sub x24, x24, #32
TILE12_L8_MLA_TERM: TILE12_L8_MLA_TERM:
MLA_WEIGHTZERO v8, v2, v5, 0 // tile:0, oc:0-3 MLA_WEIGHTZERO v8, v2, v5, 0 // tile:0, oc:0-3
@ -229,7 +229,6 @@ L8LoopDz_TILE_12:
MLA_WEIGHTZERO v29, v4, v6, 1 // tile:9, oc:4-7 MLA_WEIGHTZERO v29, v4, v6, 1 // tile:9, oc:4-7
MLA_WEIGHTZERO v30, v4, v6, 2 // tile:10, oc:4-7 MLA_WEIGHTZERO v30, v4, v6, 2 // tile:10, oc:4-7
MLA_WEIGHTZERO v31, v4, v6, 3 // tile:11, oc:4-7 MLA_WEIGHTZERO v31, v4, v6, 3 // tile:11, oc:4-7
sub x4, x4, #128
cbz x9, TILE12_ADD_DSTV cbz x9, TILE12_ADD_DSTV
TILE12_ADD_BIAS: TILE12_ADD_BIAS:
@ -240,40 +239,40 @@ L8LoopDz_TILE_12:
ADD_BIAS_FLOAT v20, v21, v22, v23, v1 ADD_BIAS_FLOAT v20, v21, v22, v23, v1
ADD_BIAS_FLOAT v24, v25, v26, v27, v1 ADD_BIAS_FLOAT v24, v25, v26, v27, v1
ADD_BIAS_FLOAT v28, v29, v30, v31, v1 ADD_BIAS_FLOAT v28, v29, v30, v31, v1
cbnz x0, TILE12_POST
Float32ToHalf v8, v20, v9, v21, v0, v1 b TILE12_L8_ACCUM_BUFFER
Float32ToHalf v10, v22, v11, v23, v2, v3
Float32ToHalf v12, v24, v13, v25, v4, v5
Float32ToHalf v14, v26, v15, v27, v6, v7
Float32ToHalf v16, v28, v17, v29, v8, v9
Float32ToHalf v18, v30, v19, v31, v10, v11
b TILE12_POST
TILE12_ADD_DSTV: TILE12_ADD_DSTV:
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64
ADD_FLOAT v8, v9, v10, v11, v0, v1, v2, v3
ADD_FLOAT v12, v13, v14, v15, v4, v5, v6, v7
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64
ADD_FLOAT v16, v17, v18, v19, v0, v1, v2, v3
ADD_FLOAT v20, v21, v22, v23, v4, v5, v6, v7
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64
ADD_FLOAT v24, v25, v26, v27, v0, v1, v2, v3
ADD_FLOAT v28, v29, v30, v31, v4, v5, v6, v7
cbnz x0, TILE12_POST
TILE12_L8_ACCUM_BUFFER:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x15], #64
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x15], #64
st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x15], #64
st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x15], #64
b L8Tile12LoopCheck
TILE12_POST:
Float32ToHalf v8, v20, v9, v21, v0, v1 Float32ToHalf v8, v20, v9, v21, v0, v1
Float32ToHalf v10, v22, v11, v23, v2, v3 Float32ToHalf v10, v22, v11, v23, v2, v3
Float32ToHalf v12, v24, v13, v25, v4, v5 Float32ToHalf v12, v24, v13, v25, v4, v5
Float32ToHalf v14, v26, v15, v27, v6, v7 Float32ToHalf v14, v26, v15, v27, v6, v7
Float32ToHalf v16, v28, v17, v29, v8, v9 Float32ToHalf v16, v28, v17, v29, v8, v9
Float32ToHalf v18, v30, v19, v31, v10, v11 Float32ToHalf v18, v30, v19, v31, v10, v11
ld1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x0], #64
ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x0], #64
ld1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x0]
fadd v0.8h, v0.8h, v20.8h
fadd v1.8h, v1.8h, v21.8h
fadd v2.8h, v2.8h, v22.8h
fadd v3.8h, v3.8h, v23.8h
fadd v4.8h, v4.8h, v12.8h
fadd v5.8h, v5.8h, v13.8h
fadd v6.8h, v6.8h, v14.8h
fadd v7.8h, v7.8h, v15.8h
fadd v8.8h, v8.8h, v16.8h
fadd v9.8h, v9.8h, v17.8h
fadd v10.8h, v10.8h, v18.8h
fadd v11.8h, v11.8h, v19.8h
sub x0, x0, #128
TILE12_POST:
cbz x23, TILE12_STORE cbz x23, TILE12_STORE
ld1r {v24.8h}, [x23], #2 // f32 min ld1r {v24.8h}, [x23], #2 // f32 min
ld1r {v25.8h}, [x23] // f32 max ld1r {v25.8h}, [x23] // f32 max
@ -288,7 +287,6 @@ L8LoopDz_TILE_12:
st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64
st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64 st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64
st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], x4 st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], x4
add x4, x4, #128
L8Tile12LoopCheck: L8Tile12LoopCheck:
cmp x5, #1 cmp x5, #1
bge L8LoopDz_TILE_12 bge L8LoopDz_TILE_12
@ -297,12 +295,11 @@ L8LoopDz_TILE_12:
TILE_8: TILE_8:
cmp x7, #8 cmp x7, #8
blt TILE_4 blt TILE_4
mov x10, x0 sub x19, x4, #64
mov x6, x0
mov x12, x2 mov x12, x2
mov x14, x5 mov x14, x5
mov x19, x8 // scale
mov x20, x9 // bias mov x20, x9 // bias
mov x6, x26 // weightQuantBias
L8LoopDz_TILE_8: L8LoopDz_TILE_8:
mov x11, x1 mov x11, x1
@ -313,15 +310,13 @@ L8LoopDz_TILE_8:
SET_BIAS v12, v13, v14, v15 SET_BIAS v12, v13, v14, v15
SET_BIAS v16, v17, v18, v19 SET_BIAS v16, v17, v18, v19
SET_BIAS v20, v21, v22, v23 SET_BIAS v20, v21, v22, v23
mov x28, x12
L8LoopSz_TILE_8: L8LoopSz_TILE_8:
ld1 {v5.16b}, [x12], #16 // weight ld1 {v5.16b}, [x12], #16 // weight
ld1 {v0.16b, v1.16b}, [x11], x22 // src ld1 {v0.16b, v1.16b}, [x11], x22 // src
// int4->int8 // int4->int8
ushr v3.16b, v5.16b, #4 ushr v3.16b, v5.16b, #4
and v4.16b, v5.16b, v7.16b and v4.16b, v5.16b, v7.16b
//zip1 v3.16b, v5.16b, v6.16b
//zip2 v4.16b, v5.16b, v6.16b
.inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0]
.inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1]
@ -345,13 +340,12 @@ L8LoopDz_TILE_8:
bne L8LoopSz_TILE_8 bne L8LoopSz_TILE_8
L8LoopSzEnd_TILE_8: L8LoopSzEnd_TILE_8:
add x12, x28, x15
sub x14, x14, #1 sub x14, x14, #1
L8Tile8Quan: L8Tile8Quan:
ld1 {v0.4s, v1.4s}, [x19], #32 // scale ld1 {v0.4s, v1.4s}, [x12], #32 // scale
ld1 {v2.4s, v3.4s}, [x25] // x kernel sum ld1 {v2.4s, v3.4s}, [x8] // x kernel sum
ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint ld1 {v24.4s, v25.4s}, [x12], #32 // weight quan zeropoint
Int32ToFloat v8, v9, v10, v11 Int32ToFloat v8, v9, v10, v11
Int32ToFloat v12, v13, v14, v15 Int32ToFloat v12, v13, v14, v15
Int32ToFloat v16, v17, v18, v19 Int32ToFloat v16, v17, v18, v19
@ -361,8 +355,8 @@ L8LoopDz_TILE_8:
MUL_SCALE v1, v16, v17, v18, v19 MUL_SCALE v1, v16, v17, v18, v19
MUL_SCALE v1, v20, v21, v22, v23 MUL_SCALE v1, v20, v21, v22, v23
cbz x27, TILE8_L8_MLA_TERM cbz x25, TILE8_L8_MLA_TERM
ld1 {v4.4s, v5.4s}, [x27] ld1 {v4.4s, v5.4s}, [x24]
MUL_EXTRA_SCALE v4, v8, v9, v10, v11 MUL_EXTRA_SCALE v4, v8, v9, v10, v11
MUL_EXTRA_SCALE v5, v12, v13, v14, v15 MUL_EXTRA_SCALE v5, v12, v13, v14, v15
MUL_EXTRA_SCALE v4, v16, v17, v18, v19 MUL_EXTRA_SCALE v4, v16, v17, v18, v19
@ -386,8 +380,6 @@ L8LoopDz_TILE_8:
MLA_WEIGHTZERO v22, v3, v25, 2 // tile:6, oc:4-7 MLA_WEIGHTZERO v22, v3, v25, 2 // tile:6, oc:4-7
MLA_WEIGHTZERO v23, v3, v25, 3 // tile:7, oc:4-7 MLA_WEIGHTZERO v23, v3, v25, 3 // tile:7, oc:4-7
sub x4, x4, #64
cbz x9, TILE8_ADD_DSTV cbz x9, TILE8_ADD_DSTV
TILE8_ADD_BIAS: TILE8_ADD_BIAS:
ld1 {v0.4s, v1.4s}, [x20], #32 ld1 {v0.4s, v1.4s}, [x20], #32
@ -395,31 +387,33 @@ L8LoopDz_TILE_8:
ADD_BIAS_FLOAT v12, v13, v14, v15, v0 ADD_BIAS_FLOAT v12, v13, v14, v15, v0
ADD_BIAS_FLOAT v16, v17, v18, v19, v1 ADD_BIAS_FLOAT v16, v17, v18, v19, v1
ADD_BIAS_FLOAT v20, v21, v22, v23, v1 ADD_BIAS_FLOAT v20, v21, v22, v23, v1
cbnz x0, TILE8_POST
Float32ToHalf v8, v16, v9, v17, v0, v1 b TILE8_L8_ACCUM_BUFFER
Float32ToHalf v10, v18, v11, v19, v2, v3
Float32ToHalf v12, v20, v13, v21, v4, v5
Float32ToHalf v14, v22, v15, v23, v6, v7
b TILE8_POST
TILE8_ADD_DSTV: TILE8_ADD_DSTV:
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64
ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10], #64
ADD_FLOAT v8, v9, v10, v11, v0, v1, v2, v3
ADD_FLOAT v12, v13, v14, v15, v4, v5, v6, v7
ADD_FLOAT v16, v17, v18, v19, v24, v25, v26, v27
ADD_FLOAT v20, v21, v22, v23, v28, v29, v30, v31
cbnz x0, TILE8_POST
TILE8_L8_ACCUM_BUFFER:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x15], #64
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x15], #64
b L8Tile8LoopCheck
TILE8_POST:
Float32ToHalf v8, v16, v9, v17, v0, v1 Float32ToHalf v8, v16, v9, v17, v0, v1
Float32ToHalf v10, v18, v11, v19, v2, v3 Float32ToHalf v10, v18, v11, v19, v2, v3
Float32ToHalf v12, v20, v13, v21, v4, v5 Float32ToHalf v12, v20, v13, v21, v4, v5
Float32ToHalf v14, v22, v15, v23, v6, v7 Float32ToHalf v14, v22, v15, v23, v6, v7
ld1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x10], #64
ld1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x10]
fadd v0.8h, v0.8h, v24.8h
fadd v1.8h, v1.8h, v25.8h
fadd v2.8h, v2.8h, v26.8h
fadd v3.8h, v3.8h, v27.8h
fadd v4.8h, v4.8h, v28.8h
fadd v5.8h, v5.8h, v29.8h
fadd v6.8h, v6.8h, v30.8h
fadd v7.8h, v7.8h, v31.8h
sub x10, x10, #64
TILE8_POST:
cbz x23, TILE8_STORE cbz x23, TILE8_STORE
ld1r {v24.8h}, [x23], #2 // f16 min ld1r {v24.8h}, [x23], #2 // f16 min
ld1r {v25.8h}, [x23] // f16 max ld1r {v25.8h}, [x23] // f16 max
@ -428,49 +422,42 @@ L8LoopDz_TILE_8:
sub x23, x23, #2 sub x23, x23, #2
TILE8_STORE: TILE8_STORE:
st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x6], #64
st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], x4 st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x6], x19
add x4, x4, #64
L8Tile8LoopCheck: L8Tile8LoopCheck:
cmp x14, #1 cmp x14, #1
bge L8LoopDz_TILE_8 bge L8LoopDz_TILE_8
cbz x27, Tile8End
add x27, x27, #32
Tile8End: Tile8End:
sub x7, x7, #8 cbz x0, Tile8_End_Offset
add x0, x0, x21, LSL #3 add x0, x0, x21, LSL #3
Tile8_End_Offset:
sub x7, x7, #8
add x1, x1, #32 add x1, x1, #32
add x25, x25, #32 add x8, x8, #32
add x24, x24, #32
cbz x7, End
TILE_4: TILE_4:
movi v7.16b, #15 movi v7.16b, #15
cmp x7, #4 cmp x7, #4
blt TILE_1 blt TILE_1
mov x10, x0 mov x6, x0
mov x12, x2 mov x12, x2
mov x14, x5 mov x14, x5
mov x19, x8
mov x20, x9 mov x20, x9
mov x6, x26 // weightQuantBias
L8LoopDz_TILE_4: L8LoopDz_TILE_4:
mov x11, x1 mov x11, x1
mov x13, x3 mov x13, x3
SET_BIAS v8, v9, v10, v11 SET_BIAS v8, v9, v10, v11
SET_BIAS v12, v13, v14, v15 SET_BIAS v12, v13, v14, v15
mov x28, x12
L8LoopSz_TILE_4: L8LoopSz_TILE_4:
ld1 {v5.16b}, [x12], #16 // weight ld1 {v5.16b}, [x12], #16 // weight
ld1 {v0.16b}, [x11], x22 // src ld1 {v0.16b}, [x11], x22 // src
// int4->int8 // int4->int8
ushr v3.16b, v5.16b, #4 ushr v3.16b, v5.16b, #4
and v4.16b, v5.16b, v7.16b and v4.16b, v5.16b, v7.16b
//zip1 v3.16b, v5.16b, v6.16b
//zip2 v4.16b, v5.16b, v6.16b
.inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0]
.inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1]
@ -485,20 +472,19 @@ L8LoopDz_TILE_4:
bne L8LoopSz_TILE_4 bne L8LoopSz_TILE_4
L8LoopSzEnd_TILE_4: L8LoopSzEnd_TILE_4:
add x12, x28, x15
sub x14, x14, #1 sub x14, x14, #1
L8Tile4Quan: L8Tile4Quan:
ld1 {v0.4s, v1.4s}, [x19], #32 // scale ld1 {v0.4s, v1.4s}, [x12], #32 // scale
ld1 {v2.4s}, [x25] // x kernel sum ld1 {v2.4s}, [x8] // x kernel sum
ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint ld1 {v24.4s, v25.4s}, [x12], #32 // weight quan zeropoint
Int32ToFloat v8, v9, v10, v11 Int32ToFloat v8, v9, v10, v11
Int32ToFloat v12, v13, v14, v15 Int32ToFloat v12, v13, v14, v15
MUL_SCALE v0, v8, v9, v10, v11 MUL_SCALE v0, v8, v9, v10, v11
MUL_SCALE v1, v12, v13, v14, v15 MUL_SCALE v1, v12, v13, v14, v15
cbz x27, TILE4_L8_MLA_TERM cbz x25, TILE4_L8_MLA_TERM
ld1 {v4.4s}, [x27] ld1 {v4.4s}, [x24]
MUL_EXTRA_SCALE v4, v8, v9, v10, v11 MUL_EXTRA_SCALE v4, v8, v9, v10, v11
MUL_EXTRA_SCALE v4, v12, v13, v14, v15 MUL_EXTRA_SCALE v4, v12, v13, v14, v15
@ -514,64 +500,63 @@ L8LoopDz_TILE_4:
cbz x9, TILE4_ADD_DSTV cbz x9, TILE4_ADD_DSTV
TILE4_ADD_BIAS: TILE4_ADD_BIAS:
ld1 {v0.4s, v1.4s}, [x20], #32 ld1 {v4.4s, v5.4s}, [x20], #32
ADD_BIAS_FLOAT v8, v9, v10, v11, v0 ADD_BIAS_FLOAT v8, v9, v10, v11, v4
ADD_BIAS_FLOAT v12, v13, v14, v15, v1 ADD_BIAS_FLOAT v12, v13, v14, v15, v5
Float32ToHalf v8, v12, v9, v13, v0, v1 cbnz x0, TILE4_POST
Float32ToHalf v10, v14, v11, v15, v2, v3 b TILE4_L8_ACCUM_BUFFER
b TILE4_POST
TILE4_ADD_DSTV: TILE4_ADD_DSTV:
Float32ToHalf v8, v12, v9, v13, v0, v1 ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64
Float32ToHalf v10, v14, v11, v15, v2, v3 ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64
ld1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x10] ADD_FLOAT v8, v9, v10, v11, v16, v17, v18, v19
fadd v0.8h, v0.8h, v20.8h ADD_FLOAT v12, v13, v14, v15, v20, v21, v22, v23
fadd v1.8h, v1.8h, v21.8h cbnz x0, TILE4_POST
fadd v2.8h, v2.8h, v22.8h
fadd v3.8h, v3.8h, v23.8h TILE4_L8_ACCUM_BUFFER:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
b L8Tile4LoopCheck
TILE4_POST: TILE4_POST:
Float32ToHalf v8, v12, v9, v13, v0, v1
Float32ToHalf v10, v14, v11, v15, v2, v3
cbz x23, TILE4_STORE cbz x23, TILE4_STORE
ld1r {v24.8h}, [x23], #2 // f16 min ld1r {v24.8h}, [x23], #2 // f16 min
ld1r {v25.8h}, [x23] // f16 max ld1r {v25.8h}, [x23] // f16 max
sub x23, x23, #2 sub x23, x23, #2
ReLU_FP16 v0, v1, v2, v3, v24, v25 ReLU_FP16 v0, v1, v2, v3, v24, v25
//st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], x4
//st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10], x4
TILE4_STORE: TILE4_STORE:
st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], x4 st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x6], x4
L8Tile4LoopCheck: L8Tile4LoopCheck:
cmp x14, #1 cmp x14, #1
bge L8LoopDz_TILE_4 bge L8LoopDz_TILE_4
cbz x27, Tile4End
add x27, x27, #16
Tile4End: Tile4End:
sub x7, x7, #4 cbz x0, Tile4_End_Offset
add x0, x0, x21, LSL #2 add x0, x0, x21, LSL #2
Tile4_End_Offset:
sub x7, x7, #4
add x1, x1, #16 add x1, x1, #16
add x25, x25, #16 add x8, x8, #16
add x24, x24, #16
TILE_1:
// Already execute: [movi v7.16b, #15] in TILE_4
cbz x7, End cbz x7, End
mov x10, x0
TILE_1:
// Already execute: [movi v7.16b, #15] in TILE_4
mov x6, x0
mov x12, x2 mov x12, x2
mov x14, x5 mov x14, x5
mov x19, x8
mov x20, x9 mov x20, x9
mov x6, x26 // weightQuantBias
L8LoopDz_TILE_1: L8LoopDz_TILE_1:
mov x11, x1 mov x11, x1
mov x13, x3 mov x13, x3
movi v8.16b, #0 movi v8.16b, #0
movi v9.16b, #0 movi v9.16b, #0
mov x28, x12
cmp x22, #4 cmp x22, #4
bne L8LoopSz_TILE_1_lu1 bne L8LoopSz_TILE_1_lu1
cmp x13, #4 cmp x13, #4
@ -610,7 +595,6 @@ L8LoopDz_TILE_1:
and v29.16b, v13.16b, v7.16b and v29.16b, v13.16b, v7.16b
cmp x13, #8 cmp x13, #8
//sub x12, x12, x15
.inst 0x4f80e1c8 // sdot v8.4s, v14.16b, v0.4b[0] .inst 0x4f80e1c8 // sdot v8.4s, v14.16b, v0.4b[0]
.inst 0x4f80e2c9 // sdot v9.4s, v22.16b, v0.4b[0] .inst 0x4f80e2c9 // sdot v9.4s, v22.16b, v0.4b[0]
.inst 0x4fa0e1e8 // sdot v8.4s, v15.16b, v0.4b[1] .inst 0x4fa0e1e8 // sdot v8.4s, v15.16b, v0.4b[1]
@ -653,7 +637,6 @@ L8LoopDz_TILE_1:
and v25.16b, v6.16b, v7.16b and v25.16b, v6.16b, v7.16b
cmp x13, #4 cmp x13, #4
//sub x12, x12, x15
.inst 0x4f80e188 // sdot v8.4s, v12.16b, v0.4b[0] .inst 0x4f80e188 // sdot v8.4s, v12.16b, v0.4b[0]
.inst 0x4f80e2c9 // sdot v9.4s, v22.16b, v0.4b[0] .inst 0x4f80e2c9 // sdot v9.4s, v22.16b, v0.4b[0]
.inst 0x4fa0e1e8 // sdot v8.4s, v15.16b, v0.4b[1] .inst 0x4fa0e1e8 // sdot v8.4s, v15.16b, v0.4b[1]
@ -669,7 +652,6 @@ L8LoopDz_TILE_1:
L8LoopSz_TILE_1_lu1: L8LoopSz_TILE_1_lu1:
ld1 {v4.16b}, [x12], #16 // weight ld1 {v4.16b}, [x12], #16 // weight
ld1 {v0.s}[0], [x11], x22 // src ld1 {v0.s}[0], [x11], x22 // src
//ld1 {v4.d}[0], [x12], #8 // weight
subs x13, x13, #1 subs x13, x13, #1
// int4->int8 // int4->int8
ushr v3.16b, v4.16b, #4 ushr v3.16b, v4.16b, #4
@ -680,20 +662,19 @@ L8LoopDz_TILE_1:
bne L8LoopSz_TILE_1_lu1 bne L8LoopSz_TILE_1_lu1
L8LoopSzEnd_TILE_1: L8LoopSzEnd_TILE_1:
add x12, x28, x15
sub x14, x14, #1 sub x14, x14, #1
L8Tile1Quan: L8Tile1Quan:
ld1 {v0.4s, v1.4s}, [x19], #32 // scale ld1 {v0.4s, v1.4s}, [x12], #32 // scale
ld1 {v2.s}[0], [x25] // x kernel sum ld1 {v2.s}[0], [x8] // x kernel sum
ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint ld1 {v24.4s, v25.4s}, [x12], #32 // weight quan zeropoint
scvtf v8.4s, v8.4s scvtf v8.4s, v8.4s
scvtf v9.4s, v9.4s scvtf v9.4s, v9.4s
fmul v8.4s, v8.4s, v0.4s fmul v8.4s, v8.4s, v0.4s
fmul v9.4s, v9.4s, v1.4s fmul v9.4s, v9.4s, v1.4s
cbz x27, TILE1_L8_MLA_TERM cbz x25, TILE1_L8_MLA_TERM
ld1 {v4.s}[0], [x27] ld1 {v4.s}[0], [x24]
fmul v8.4s, v8.4s, v4.s[0] fmul v8.4s, v8.4s, v4.s[0]
fmul v9.4s, v9.4s, v4.s[0] fmul v9.4s, v9.4s, v4.s[0]
@ -703,20 +684,25 @@ L8LoopDz_TILE_1:
cbz x9, TILE1_ADD_DSTV cbz x9, TILE1_ADD_DSTV
TILE1_ADD_BIAS: TILE1_ADD_BIAS:
ld1 {v20.4s, v21.4s}, [x20], #32 ld1 {v10.4s, v11.4s}, [x20], #32
fadd v8.4s, v8.4s, v20.4s fadd v8.4s, v8.4s, v10.4s
fadd v9.4s, v9.4s, v21.4s fadd v9.4s, v9.4s, v11.4s
fcvtn v0.4h, v8.4s cbnz x0, TILE1_POST
fcvtn2 v0.8h, v9.4s b TILE1_L8_ACCUM_BUFFER
b TILE1_POST
TILE1_ADD_DSTV: TILE1_ADD_DSTV:
fcvtn v0.4h, v8.4s ld1 {v10.4s, v11.4s}, [x10], #32
fcvtn2 v0.8h, v9.4s fadd v8.4s, v8.4s, v10.4s
ld1 {v3.8h}, [x10] fadd v9.4s, v9.4s, v11.4s
fadd v0.8h, v0.8h, v3.8h cbnz x0, TILE1_POST
TILE1_L8_ACCUM_BUFFER:
st1 {v8.4s, v9.4s}, [x15], #32
b L8Tile1LoopCheck
TILE1_POST: TILE1_POST:
fcvtn v0.4h, v8.4s
fcvtn2 v0.8h, v9.4s
cbz x23, TILE1_STORE cbz x23, TILE1_STORE
ld1r {v24.8h}, [x23], #2 // f16 min ld1r {v24.8h}, [x23], #2 // f16 min
ld1r {v25.8h}, [x23] // f16 max ld1r {v25.8h}, [x23] // f16 max
@ -724,30 +710,30 @@ L8LoopDz_TILE_1:
fmax v0.8h, v24.8h, v0.8h fmax v0.8h, v24.8h, v0.8h
fmin v0.8h, v25.8h, v0.8h fmin v0.8h, v25.8h, v0.8h
TILE1_STORE: TILE1_STORE:
st1 {v0.8h}, [x10], x4 st1 {v0.8h}, [x6], x4
L8Tile1LoopCheck: L8Tile1LoopCheck:
cmp x14, #1 cmp x14, #1
bge L8LoopDz_TILE_1 bge L8LoopDz_TILE_1
cbz x27, Tile1End
add x27, x27, #4
Tile1End: Tile1End:
sub x7, x7, #1 cbz x0, Tile1_End_Offset
add x0, x0, x21 add x0, x0, x21
Tile1_End_Offset:
add x24, x24, #4
subs x7, x7, #1
add x1, x1, #4 add x1, x1, #4
add x25, x25, #4 add x8, x8, #4
b TILE_1 bne TILE_1
End: End:
ldp x23, x24, [sp, #(16 * 8)]
ldp x25, x26, [sp, #(16 * 7)] ldp x25, x26, [sp, #(16 * 7)]
ldp x27, x28, [sp, #(16 * 6)] ldp x23, x24, [sp, #(16 * 6)]
ldp x19, x20, [sp, #(16 * 5)] ldp x19, x20, [sp, #(16 * 5)]
ldp x21, x22, [sp, #(16 * 4)] ldp x21, x22, [sp, #(16 * 4)]
ldp d8, d9, [sp, #(16 * 3)] ldp d8, d9, [sp, #(16 * 3)]
ldp d10, d11, [sp, #(16 * 2)] ldp d10, d11, [sp, #(16 * 2)]
ldp d12, d13, [sp, #(16 * 1)] ldp d12, d13, [sp, #(16 * 1)]
ldp d14, d15, [sp], #(16 * 10) ldp d14, d15, [sp], #(16 * 8)
ret ret
#endif // __aarch64__ #endif // __aarch64__
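Across the tile bodies above (and the larger tiles in the sibling kernels below), the commit replaces the old read-modify-write on the fp16 destination with an fp32 accumulation buffer: when the destination pointer (x0) is null, raw float32 partial sums are spilled to the scratch buffer in x15 (the new TILE*_ACCUM_BUFFER/TILE*_TEMP_BUFFER paths) instead of being repeatedly rounded through fp16; only on the final pass, when x0 is live, does the kernel clamp, narrow to fp16, and store. A C++ restatement of that control flow for one 8-channel tile; names are illustrative, the separate read/write cursors (x10/x15) are collapsed into one pointer, and the fp16 vector ops assume the ARMv8.2 half-precision extension:

    #include <arm_neon.h> // assumes __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

    // Sketch of the TILE1_ADD_BIAS / TILE1_ADD_DSTV / accum-buffer / POST flow.
    static void tile1Step(float16_t* dst, float* accum, const float* bias,
                          float32x4_t acc0, float32x4_t acc1,
                          const float16_t* minmax /* nullable */) {
        if (bias) {                        // cbz x9 picks bias vs. dstv path
            acc0 = vaddq_f32(acc0, vld1q_f32(bias));
            acc1 = vaddq_f32(acc1, vld1q_f32(bias + 4));
        } else {                           // re-add previously spilled partials
            acc0 = vaddq_f32(acc0, vld1q_f32(accum));
            acc1 = vaddq_f32(acc1, vld1q_f32(accum + 4));
        }
        if (dst == nullptr) {              // cbz/cbnz x0: not the last pass yet
            vst1q_f32(accum, acc0);        // spill fp32, defer fp16 conversion
            vst1q_f32(accum + 4, acc1);
            return;
        }
        float16x8_t v = vcvt_high_f16_f32(vcvt_f16_f32(acc0), acc1); // fcvtn/fcvtn2
        if (minmax) {                      // optional fused clamp from x23/x14
            v = vmaxq_f16(vdupq_n_f16(minmax[0]), v);
            v = vminq_f16(vdupq_n_f16(minmax[1]), v);
        }
        vst1q_f16(dst, v);
    }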


@ -56,24 +56,6 @@
fmax \s0\().8h, \s0\().8h, \z0\().8h fmax \s0\().8h, \s0\().8h, \z0\().8h
fmax \s1\().8h, \s1\().8h, \z0\().8h fmax \s1\().8h, \s1\().8h, \z0\().8h
.endm .endm
.macro SET_BIAS s, d0, d1, d2, d3, d4, idx
dup \d0\().2d, \s\().d[\idx]
dup \d1\().2d, \s\().d[\idx]
dup \d2\().2d, \s\().d[\idx]
dup \d3\().2d, \s\().d[\idx]
dup \d4\().2d, \s\().d[\idx]
.endm
.macro SET_BIAS_4 s, d0, d1, d2, d3, idx
dup \d0\().2d, \s\().d[\idx]
dup \d1\().2d, \s\().d[\idx]
dup \d2\().2d, \s\().d[\idx]
dup \d3\().2d, \s\().d[\idx]
.endm
.macro SET_BIAS_2 s, d0, d1, idx
dup \d0\().2d, \s\().d[\idx]
dup \d1\().2d, \s\().d[\idx]
.endm
.macro Int32ToFloat z0, z1, z2, z3 .macro Int32ToFloat z0, z1, z2, z3
scvtf \z0\().4s, \z0\().4s scvtf \z0\().4s, \z0\().4s
scvtf \z1\().4s, \z1\().4s scvtf \z1\().4s, \z1\().4s
@ -129,13 +111,13 @@ struct QuanPostTreatParameters {
//Auto: x0:dst, x1:src, x2:weight, x3:src_depth_quad, x4:dst_step //Auto: x0:dst, x1:src, x2:weight, x3:src_depth_quad, x4:dst_step
//x5:dst_depth_quad, x6: parameters, x7: realDstCount //x5:dst_depth_quad, x6: parameters, x7: realDstCount
//Load from x7: x8: scale, x9: biasFloat, x27: srcKernelSum, x28: weightQuanBias, x14: fp32minmax //Load from x6: x9: biasFloat, x27: srcKernelSum, x14: fp32minmax, x23: extraScale, x15: accumBuffer
// x12, x15, x8,x28
/* For FP16 /* For FP16
UNIT = 8; UNIT = 8;
SRC_UNIT = 8; SRC_UNIT = 8;
DST_XUNIT = 10; DST_XUNIT = 10;
*/ */
ldr x8, [x6, #0]
ldr x9, [x6, #8] ldr x9, [x6, #8]
stp d14, d15, [sp, #(-16 * 10)]! stp d14, d15, [sp, #(-16 * 10)]!
@ -145,29 +127,25 @@ stp d8, d9, [sp, #(16 * 3)]
stp x21, x22, [sp, #(16 * 4)] stp x21, x22, [sp, #(16 * 4)]
stp x19, x20, [sp, #(16 * 5)] stp x19, x20, [sp, #(16 * 5)]
stp x23, x24, [sp, #(16 * 6)] stp x23, x24, [sp, #(16 * 6)]
stp x25, x26, [sp, #(16 * 7)] stp x27, x28, [sp, #(16 * 7)]
stp x27, x28, [sp, #(16 * 8)] stp x25, x26, [sp, #(16 * 8)]
// ldr w23, [x6, #24]
ldr x27, [x6, #40] // srcKernelSum ldr x27, [x6, #40] // srcKernelSum
ldr x28, [x6, #48] // weightQuanBias
ldr x14, [x6, #56] // fp32minmax ldr x14, [x6, #56] // fp32minmax
lsl x22, x7, #3 // eDest * GEMM_INT8_SRC_UNIT lsl x22, x7, #3 // eDest * GEMM_INT8_SRC_UNIT
mov x21, #16 // sizeof(float16_t) * UNIT mov x21, #16 // sizeof(float16_t) * UNIT
Start:
lsl x15, x3, #6 // x15 = src_depth_quad * UNIT * UNIT_SRC * sizeof(int8_t) = src_depth_quad * 64 = src_depth_quad << 6
ldr x23, [x6, #80] // extra scale ldr x23, [x6, #80] // extra scale
ldr x15, [x6, #96]
mov x10, x15 // tag dst address
mov x25, x23
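The loads above pin down the parameter block's layout: the old scale (offset 0) and weightQuanBias (offset 48) loads are gone because both arrays are now streamed inline with the packed weights (note the later "ld1 {v20.4s, v21.4s}, [x2], #32 // scale" reads), and a new fp32 accumulation buffer is fetched at offset 96 into x15. Below is an illustrative reconstruction of the struct from those offsets alone; every field name is an assumption, not the definitive QuanPostTreatParameters:

    #include <cstddef>
    #include <cstdint>

    struct QuanPostTreatParametersSketch {            // names assumed
        const float* scale;           // +0   (load removed: now with weights)
        const float* biasFloat;       // +8   -> x9
        uint8_t      opaque0[24];     // +16..39, not touched by this kernel
        const float* srcKernelSum;    // +40  -> x27
        const float* weightQuanBias;  // +48  (load removed: now with weights)
        const float* fp32minmax;      // +56  -> x14
        uint8_t      opaque1[16];     // +64..79
        const float* extraScale;      // +80  -> x23
        uint8_t      opaque2[8];      // +88..95
        float*       accumBuffer;     // +96  -> x15 (new fp32 scratch)
    };
    static_assert(offsetof(QuanPostTreatParametersSketch, accumBuffer) == 96, "");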
TILE_10: TILE_10:
cmp x7, #10 cmp x7, #10
blt TILE_8 blt TILE_8
sub x4, x4, #128 sub x4, x4, #128
LoopDz_TILE_10: LoopDz_TILE_10:
//ld1 {v0.4s, v1.4s}, [x9], #32 // bias
mov x11, x1 // src mov x11, x1 // src
mov x12, x2 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
mov x10, x0 // tag dst address
SET_0_5 v12, v16, v20, v24, v28 // oc:0,1,0,1 SET_0_5 v12, v16, v20, v24, v28 // oc:0,1,0,1
SET_0_5 v13, v17, v21, v25, v29 // oc:2,3,2,3 SET_0_5 v13, v17, v21, v25, v29 // oc:2,3,2,3
@ -175,7 +153,7 @@ LoopDz_TILE_10:
SET_0_5 v15, v19, v23, v27, v31 // oc:6,7,6,7 SET_0_5 v15, v19, v23, v27, v31 // oc:6,7,6,7
LoopSz_TILE_10: LoopSz_TILE_10:
ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [x12], #64 // weight ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [x2], #64 // weight
ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], #64 // src: E0-E9 ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], #64 // src: E0-E9
ld1 {v7.16b}, [x11], #16 ld1 {v7.16b}, [x11], #16
subs x13, x13, #1 subs x13, x13, #1
@ -205,8 +183,6 @@ LoopSz_TILE_10:
.inst 0x4e8ba4ff // smmla v31.4s, v7.16b, v11.16b // tile8-oc6, tile8-oc7, tile9-oc6, tile9-oc7 .inst 0x4e8ba4ff // smmla v31.4s, v7.16b, v11.16b // tile8-oc6, tile8-oc7, tile9-oc6, tile9-oc7
bne LoopSz_TILE_10 bne LoopSz_TILE_10
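Each smmla multiplies a 2x8 int8 activation block by the transpose of a 2x8 weight block, so one accumulator holds a 2x2 tile laid out {tileN-ocA, tileN-ocB, tileN+1-ocA, tileN+1-ocB}, exactly as the comments annotate; the uzp1/uzp2 pass below then regroups the 64-bit lane pairs into per-tile oc:0-3 / oc:4-7 vectors. A scalar emulation of one step and the regrouping (sketch):

    #include <array>
    #include <cstdint>

    using Acc = std::array<int32_t, 4>; // lanes: {t0*c0, t0*c1, t1*c0, t1*c1}

    // One SMMLA: acc += A(2x8) * B(2x8)^T, int8 inputs, int32 accumulators.
    inline void smmla(Acc& acc, const int8_t a[2][8], const int8_t b[2][8]) {
        for (int t = 0; t < 2; ++t)
            for (int c = 0; c < 2; ++c)
                for (int k = 0; k < 8; ++k)
                    acc[2 * t + c] += a[t][k] * b[c][k];
    }

    // uzp1/uzp2 on .2d lanes: split even/odd 64-bit halves of two registers.
    inline Acc uzp1(const Acc& x, const Acc& y) { return {x[0], x[1], y[0], y[1]}; }
    inline Acc uzp2(const Acc& x, const Acc& y) { return {x[2], x[3], y[2], y[3]}; }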
LoopSzEnd_TILE_10: LoopSzEnd_TILE_10:
add x2, x2, x15 // weight += dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT) * sizeof(int8_t);
sub x5, x5, #1 // dz--
// transpose // transpose
uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3
uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3
@ -239,14 +215,11 @@ LoopSzEnd_TILE_10:
Int32ToFloat v16, v17, v18, v19 Int32ToFloat v16, v17, v18, v19
Tile10Quan: Tile10Quan:
ld1 {v20.4s, v21.4s}, [x8], #32 // scale ld1 {v20.4s, v21.4s}, [x2], #32 // scale
ld1 {v22.4s, v23.4s}, [x27], #32 // x kernel sum ld1 {v22.4s, v23.4s}, [x27], #32 // x kernel sum
ld1 {v24.d}[0], [x27] ld1 {v24.d}[0], [x27]
sub x27, x27, #32 sub x27, x27, #32
ld1 {v25.4s, v26.4s}, [x28], #32 // weight quan zeropoint ld1 {v25.4s, v26.4s}, [x2], #32 // weight quan zeropoint
//ld1r {v27.4s}, [x6], #4 // f32 min
//ld1r {v28.4s}, [x6] // f32 max
//sub x6, x6, #4
MUL_SCALE v20, v0, v1, v4, v5 MUL_SCALE v20, v0, v1, v4, v5
MUL_SCALE v21, v2, v3, v6, v7 MUL_SCALE v21, v2, v3, v6, v7
MUL_SCALE v20, v8, v9, v12, v13 MUL_SCALE v20, v8, v9, v12, v13
@ -256,7 +229,7 @@ Tile10Quan:
fmul v18.4s, v18.4s, v21.4s fmul v18.4s, v18.4s, v21.4s
fmul v19.4s, v19.4s, v21.4s fmul v19.4s, v19.4s, v21.4s
cbz x23, TILE10_MLA cbz x25, TILE10_MLA
ld1 {v27.4s, v28.4s}, [x23], #32 ld1 {v27.4s, v28.4s}, [x23], #32
ld1 {v29.d}[0], [x23] ld1 {v29.d}[0], [x23]
MUL_EXTRA_SCALE v27, v0, v1, v4, v5 MUL_EXTRA_SCALE v27, v0, v1, v4, v5
@ -307,39 +280,52 @@ Tile10Quan:
fadd v17.4s, v17.4s, v20.4s fadd v17.4s, v17.4s, v20.4s
fadd v18.4s, v18.4s, v21.4s fadd v18.4s, v18.4s, v21.4s
fadd v19.4s, v19.4s, v21.4s fadd v19.4s, v19.4s, v21.4s
cbnz x0, TILE10_POST // to Relu post
// float32->float16 b TILE10_TEMP_BUFFER
Float32ToHalf v0, v2, v1, v3, v20, v21
Float32ToHalf v4, v6, v5, v7, v22, v23
Float32ToHalf v8, v10, v9, v11, v24, v25
Float32ToHalf v12, v14, v13, v15, v26, v27
Float32ToHalf v16, v18, v17, v19, v30, v31
b TILE10_POST // to Relu post
TILE10_ADD_DSTV: TILE10_ADD_DSTV:
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64
ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10], #64
fadd v0.4s, v0.4s, v20.4s
fadd v1.4s, v1.4s, v21.4s
fadd v2.4s, v2.4s, v22.4s
fadd v3.4s, v3.4s, v23.4s
fadd v4.4s, v4.4s, v24.4s
fadd v5.4s, v5.4s, v25.4s
fadd v6.4s, v6.4s, v26.4s
fadd v7.4s, v7.4s, v27.4s
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64
fadd v8.4s, v8.4s, v28.4s
fadd v9.4s, v9.4s, v29.4s
fadd v10.4s, v10.4s, v30.4s
fadd v11.4s, v11.4s, v31.4s
fadd v12.4s, v12.4s, v20.4s
fadd v13.4s, v13.4s, v21.4s
fadd v14.4s, v14.4s, v22.4s
fadd v15.4s, v15.4s, v23.4s
fadd v16.4s, v16.4s, v24.4s
fadd v17.4s, v17.4s, v25.4s
fadd v18.4s, v18.4s, v26.4s
fadd v19.4s, v19.4s, v27.4s
cbnz x0, TILE10_POST
TILE10_TEMP_BUFFER:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x15], #64
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x15], #64
b Tile10LoopCheck
TILE10_POST:
// float32->float16 // float32->float16
Float32ToHalf v0, v2, v1, v3, v20, v21 Float32ToHalf v0, v2, v1, v3, v20, v21
Float32ToHalf v4, v6, v5, v7, v22, v23 Float32ToHalf v4, v6, v5, v7, v22, v23
Float32ToHalf v8, v10, v9, v11, v24, v25 Float32ToHalf v8, v10, v9, v11, v24, v25
Float32ToHalf v12, v14, v13, v15, v26, v27 Float32ToHalf v12, v14, v13, v15, v26, v27
Float32ToHalf v16, v18, v17, v19, v30, v31 Float32ToHalf v16, v18, v17, v19, v30, v31
ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64
ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64
ld1 {v8.8h, v9.8h}, [x10]
fadd v20.8h, v20.8h, v0.8h
fadd v21.8h, v21.8h, v1.8h
fadd v22.8h, v22.8h, v2.8h
fadd v23.8h, v23.8h, v3.8h
fadd v24.8h, v24.8h, v4.8h
fadd v25.8h, v25.8h, v5.8h
fadd v26.8h, v26.8h, v6.8h
fadd v27.8h, v27.8h, v7.8h
fadd v30.8h, v30.8h, v8.8h
fadd v31.8h, v31.8h, v9.8h
TILE10_POST:
cbz x14, TILE10_STORE cbz x14, TILE10_STORE
ld1r {v29.8h}, [x14], #2 // f32 min ld1r {v29.8h}, [x14], #2 // f32 min
ld1r {v28.8h}, [x14] // f32 max ld1r {v28.8h}, [x14] // f32 max
@ -356,26 +342,21 @@ Tile10Quan:
st1 {v30.8h, v31.8h}, [x0], x4 st1 {v30.8h, v31.8h}, [x0], x4
Tile10LoopCheck: Tile10LoopCheck:
cmp x5, #1 subs x5, x5, #1 // dz--
bge LoopDz_TILE_10 bne LoopDz_TILE_10
b End b End
TILE_8: TILE_8:
cmp x7, #8 cmp x7, #8
blt TILE_4 blt TILE_4
sub x8, x4, #64 // just for Tile8, revert it when Tile8end
mov x24, x5 // dst_depth_quad mov x24, x5 // dst_depth_quad
mov x26, x0 // dst mov x6, x0 // dst
mov x25, x2 // weight mov x12, x2 // weight
mov x19, x8 // scale
mov x20, x9 // bias mov x20, x9 // bias
mov x6, x28 // weightQuanBias
sub x4, x4, #64 // For Tile8; revert it when Tile8 ends
LoopDz_TILE_8: LoopDz_TILE_8:
//ld1 {v0.4s, v1.4s}, [x20], #32 // bias
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
mov x10, x26 // tag dst
SET_0_4 v12, v16, v20, v24 // oc:0,1,0,1 SET_0_4 v12, v16, v20, v24 // oc:0,1,0,1
SET_0_4 v13, v17, v21, v25 // oc:2,3,2,3 SET_0_4 v13, v17, v21, v25 // oc:2,3,2,3
@ -407,7 +388,6 @@ LoopSz_TILE_8:
bne LoopSz_TILE_8 bne LoopSz_TILE_8
LoopSzEnd_TILE_8: LoopSzEnd_TILE_8:
add x25, x25, x15
sub x24, x24, #1 sub x24, x24, #1
uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3
uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3
@ -434,15 +414,15 @@ LoopSzEnd_TILE_8:
Int32ToFloat v12, v13, v14, v15 Int32ToFloat v12, v13, v14, v15
Tile8Quan: Tile8Quan:
ld1 {v20.4s, v21.4s}, [x19], #32 // scale ld1 {v20.4s, v21.4s}, [x12], #32 // scale
ld1 {v22.4s, v23.4s}, [x27] // x kernel sum ld1 {v22.4s, v23.4s}, [x27] // x kernel sum
ld1 {v25.4s, v26.4s}, [x6], #32 // weight quan zeropoint ld1 {v25.4s, v26.4s}, [x12], #32 // weight quan zeropoint
MUL_SCALE v20, v0, v1, v4, v5 MUL_SCALE v20, v0, v1, v4, v5
MUL_SCALE v21, v2, v3, v6, v7 MUL_SCALE v21, v2, v3, v6, v7
MUL_SCALE v20, v8, v9, v12, v13 MUL_SCALE v20, v8, v9, v12, v13
MUL_SCALE v21, v10, v11, v14, v15 MUL_SCALE v21, v10, v11, v14, v15
cbz x23, TILE8_MLA cbz x25, TILE8_MLA
ld1 {v27.4s, v28.4s}, [x23] ld1 {v27.4s, v28.4s}, [x23]
MUL_EXTRA_SCALE v27, v0, v1, v4, v5 MUL_EXTRA_SCALE v27, v0, v1, v4, v5
MUL_EXTRA_SCALE v28, v8, v9, v12, v13 MUL_EXTRA_SCALE v28, v8, v9, v12, v13
@ -477,31 +457,45 @@ Tile8Quan:
ADD_BIAS_FLOAT v2, v3, v6, v7, v17 ADD_BIAS_FLOAT v2, v3, v6, v7, v17
ADD_BIAS_FLOAT v8, v9, v12, v13, v16 ADD_BIAS_FLOAT v8, v9, v12, v13, v16
ADD_BIAS_FLOAT v10, v11, v14, v15, v17 ADD_BIAS_FLOAT v10, v11, v14, v15, v17
// float32->float16 cbnz x0, TILE8_POST
Float32ToHalf v0, v2, v1, v3, v20, v21 b TILE8_TEMP_BUFFER
Float32ToHalf v4, v6, v5, v7, v22, v23
Float32ToHalf v8, v10, v9, v11, v24, v25
Float32ToHalf v12, v14, v13, v15, v26, v27
b TILE8_POST
TILE8_ADD_DSTV: TILE8_ADD_DSTV:
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64
fadd v0.4s, v0.4s, v16.4s
fadd v1.4s, v1.4s, v17.4s
fadd v2.4s, v2.4s, v18.4s
fadd v3.4s, v3.4s, v19.4s
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64
fadd v4.4s, v4.4s, v20.4s
fadd v5.4s, v5.4s, v21.4s
fadd v6.4s, v6.4s, v22.4s
fadd v7.4s, v7.4s, v23.4s
fadd v8.4s, v8.4s, v24.4s
fadd v9.4s, v9.4s, v25.4s
fadd v10.4s, v10.4s, v26.4s
fadd v11.4s, v11.4s, v27.4s
fadd v12.4s, v12.4s, v16.4s
fadd v13.4s, v13.4s, v17.4s
fadd v14.4s, v14.4s, v18.4s
fadd v15.4s, v15.4s, v19.4s
cbnz x0, TILE8_POST
TILE8_TEMP_BUFFER:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x15], #64
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
b Tile8LoopCheck
TILE8_POST:
// float32->float16 // float32->float16
Float32ToHalf v0, v2, v1, v3, v20, v21 Float32ToHalf v0, v2, v1, v3, v20, v21
Float32ToHalf v4, v6, v5, v7, v22, v23 Float32ToHalf v4, v6, v5, v7, v22, v23
Float32ToHalf v8, v10, v9, v11, v24, v25 Float32ToHalf v8, v10, v9, v11, v24, v25
Float32ToHalf v12, v14, v13, v15, v26, v27 Float32ToHalf v12, v14, v13, v15, v26, v27
ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64
ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10]
fadd v20.8h, v20.8h, v0.8h
fadd v21.8h, v21.8h, v1.8h
fadd v22.8h, v22.8h, v2.8h
fadd v23.8h, v23.8h, v3.8h
fadd v24.8h, v24.8h, v4.8h
fadd v25.8h, v25.8h, v5.8h
fadd v26.8h, v26.8h, v6.8h
fadd v27.8h, v27.8h, v7.8h
TILE8_POST:
cbz x14, TILE8_STORE cbz x14, TILE8_STORE
ld1r {v29.8h}, [x14], #2 // f32 min ld1r {v29.8h}, [x14], #2 // f32 min
ld1r {v28.8h}, [x14] // f32 max ld1r {v28.8h}, [x14] // f32 max
@ -510,37 +504,31 @@ Tile8Quan:
ReLU_FP16 v24, v25, v26, v27, v29, v28 ReLU_FP16 v24, v25, v26, v27, v29, v28
TILE8_STORE: TILE8_STORE:
st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x6], #64
st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x26], #64 st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x6], x8
st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x26], x4
Tile8LoopCheck: Tile8LoopCheck:
cmp x24, #1 cmp x24, #1
bge LoopDz_TILE_8 bge LoopDz_TILE_8
cbz x23, Tile8End cbz x0, Tile8End
add x23, x23, #32 add x0, x0, x21, LSL #3
Tile8End: Tile8End:
sub x7, x7, #8 sub x7, x7, #8
add x0, x0, x21, LSL #3
add x1, x1, #64 add x1, x1, #64
add x27, x27, #32 add x27, x27, #32
add x4, x4, #64 // Revert it add x23, x23, #32
TILE_4: TILE_4:
cmp x7, #4 cmp x7, #4
blt TILE_2 blt TILE_2
mov x24, x5 // dst_depth_quad mov x24, x5 // dst_depth_quad
mov x26, x0 // dst mov x6, x0 // dst
mov x25, x2 // weight mov x12, x2 // weight
mov x19, x8 // scale
mov x20, x9 // bias mov x20, x9 // bias
mov x6, x28 // weightQuanBias
LoopDz_TILE_4: LoopDz_TILE_4:
//ld1 {v0.4s, v1.4s}, [x20], #32 // bias
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
mov x10, x26 // tag dst
SET_0_2 v12, v16 // oc:0,1,0,1 SET_0_2 v12, v16 // oc:0,1,0,1
SET_0_2 v13, v17 // oc:2,3,2,3 SET_0_2 v13, v17 // oc:2,3,2,3
@ -562,7 +550,6 @@ LoopSz_TILE_4:
bne LoopSz_TILE_4 bne LoopSz_TILE_4
LoopSzEnd_TILE_4: LoopSzEnd_TILE_4:
add x25, x25, x15
sub x24, x24, #1 sub x24, x24, #1
uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3
uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3
@ -577,13 +564,13 @@ LoopSzEnd_TILE_4:
Int32ToFloat v4, v5, v6, v7 Int32ToFloat v4, v5, v6, v7
Tile4Quan: Tile4Quan:
ld1 {v20.4s, v21.4s}, [x19], #32 // scale ld1 {v20.4s, v21.4s}, [x12], #32 // scale
ld1 {v22.4s}, [x27] // x kernel sum ld1 {v22.4s}, [x27] // x kernel sum
ld1 {v25.4s, v26.4s}, [x6], #32 // weight quan zeropoint ld1 {v25.4s, v26.4s}, [x12], #32 // weight quan zeropoint
MUL_SCALE v20, v0, v1, v4, v5 MUL_SCALE v20, v0, v1, v4, v5
MUL_SCALE v21, v2, v3, v6, v7 MUL_SCALE v21, v2, v3, v6, v7
cbz x23, TILE4_MLA cbz x25, TILE4_MLA
ld1 {v27.4s}, [x23] ld1 {v27.4s}, [x23]
MUL_EXTRA_SCALE v27, v0, v1, v4, v5 MUL_EXTRA_SCALE v27, v0, v1, v4, v5
MUL_EXTRA_SCALE v27, v2, v3, v6, v7 MUL_EXTRA_SCALE v27, v2, v3, v6, v7
@ -604,22 +591,31 @@ Tile4Quan:
ld1 {v16.4s, v17.4s}, [x20], #32 // bias ld1 {v16.4s, v17.4s}, [x20], #32 // bias
ADD_BIAS_FLOAT v0, v1, v4, v5, v16 ADD_BIAS_FLOAT v0, v1, v4, v5, v16
ADD_BIAS_FLOAT v2, v3, v6, v7, v17 ADD_BIAS_FLOAT v2, v3, v6, v7, v17
// float32->float16 cbnz x0, TILE4_POST
Float32ToHalf v0, v2, v1, v3, v20, v21 b TILE4_TEMP_BUFFER
Float32ToHalf v4, v6, v5, v7, v22, v23
b TILE4_POST
TILE4_ADD_DSTV: TILE4_ADD_DSTV:
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64
fadd v0.4s, v0.4s, v20.4s
fadd v1.4s, v1.4s, v21.4s
fadd v2.4s, v2.4s, v22.4s
fadd v3.4s, v3.4s, v23.4s
fadd v4.4s, v4.4s, v24.4s
fadd v5.4s, v5.4s, v25.4s
fadd v6.4s, v6.4s, v26.4s
fadd v7.4s, v7.4s, v27.4s
cbnz x0, TILE4_POST
TILE4_TEMP_BUFFER:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x15], #64
b Tile4LoopCheck
TILE4_POST:
// float32->float16 // float32->float16
Float32ToHalf v0, v2, v1, v3, v20, v21 Float32ToHalf v0, v2, v1, v3, v20, v21
Float32ToHalf v4, v6, v5, v7, v22, v23 Float32ToHalf v4, v6, v5, v7, v22, v23
ld1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x10]
fadd v20.8h, v20.8h, v24.8h
fadd v21.8h, v21.8h, v25.8h
fadd v22.8h, v22.8h, v26.8h
fadd v23.8h, v23.8h, v27.8h
TILE4_POST:
cbz x14, TILE4_STORE cbz x14, TILE4_STORE
ld1r {v29.8h}, [x14], #2 // f32 min ld1r {v29.8h}, [x14], #2 // f32 min
ld1r {v28.8h}, [x14] // f32 max ld1r {v28.8h}, [x14] // f32 max
@ -627,40 +623,30 @@ Tile4Quan:
ReLU_FP16 v20, v21, v22, v23, v29, v28 ReLU_FP16 v20, v21, v22, v23, v29, v28
TILE4_STORE: TILE4_STORE:
st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x26], x4 st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x6], x4
Tile4LoopCheck: Tile4LoopCheck:
cmp x24, #1 cmp x24, #1
bge LoopDz_TILE_4 bge LoopDz_TILE_4
cbz x23, Tile4End cbz x0, Tile4End
add x23, x23, #16 add x0, x0, x21, LSL #2
Tile4End: Tile4End:
sub x7, x7, #4 sub x7, x7, #4
add x0, x0, x21, LSL #2
add x1, x1, #32 add x1, x1, #32
add x27, x27, #16 add x27, x27, #16
//b TILE_4 add x23, x23, #16
TILE_2: TILE_2:
cmp x7, #2 cmp x7, #2
blt TILE_1 blt TILE_1
mov x24, x5 // dst_depth_quad mov x24, x5 // dst_depth_quad
mov x26, x0 // dst mov x6, x0 // dst
mov x25, x2 // weight mov x12, x2 // weight
mov x19, x8 // scale
mov x20, x9 // bias mov x20, x9 // bias
mov x6, x28 // weightQuanBias
LoopDz_TILE_2: LoopDz_TILE_2:
//ld1 {v0.4s, v1.4s}, [x20], #32 // bias
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
mov x10, x26 // tag dst
// v12 oc:0,1,0,1
// v13 oc:2,3,2,3
// v14 oc:4,5,4,5
// v15 oc:6,7,6,7
SET_0_4 v12, v13, v14, v15 SET_0_4 v12, v13, v14, v15
LoopSz_TILE_2: LoopSz_TILE_2:
ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [x12], #64 ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [x12], #64
@ -672,7 +658,6 @@ LoopSz_TILE_2:
subs x13, x13, #1 subs x13, x13, #1
bne LoopSz_TILE_2 bne LoopSz_TILE_2
LoopSzEnd_TILE_2: LoopSzEnd_TILE_2:
add x25, x25, x15
sub x24, x24, #1 sub x24, x24, #1
uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3
uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3
@ -681,15 +666,15 @@ LoopSzEnd_TILE_2:
Int32ToFloat v0, v1, v2, v3 Int32ToFloat v0, v1, v2, v3
Tile2Quan: Tile2Quan:
ld1 {v20.4s, v21.4s}, [x19], #32 // scale ld1 {v20.4s, v21.4s}, [x12], #32 // scale
ld1 {v22.d}[0], [x27] // x kernel sum ld1 {v22.d}[0], [x27] // x kernel sum
ld1 {v25.4s, v26.4s}, [x6], #32 // weight quan zeropoint ld1 {v25.4s, v26.4s}, [x12], #32 // weight quan zeropoint
fmul v0.4s, v0.4s, v20.4s fmul v0.4s, v0.4s, v20.4s
fmul v1.4s, v1.4s, v20.4s fmul v1.4s, v1.4s, v20.4s
fmul v2.4s, v2.4s, v21.4s fmul v2.4s, v2.4s, v21.4s
fmul v3.4s, v3.4s, v21.4s fmul v3.4s, v3.4s, v21.4s
cbz x23, TILE2_MLA cbz x25, TILE2_MLA
ld1 {v27.d}[0], [x23] ld1 {v27.d}[0], [x23]
fmul v0.4s, v0.4s, v27.s[0] fmul v0.4s, v0.4s, v27.s[0]
fmul v1.4s, v1.4s, v27.s[1] fmul v1.4s, v1.4s, v27.s[1]
@ -709,17 +694,24 @@ Tile2Quan:
fadd v1.4s, v1.4s, v16.4s fadd v1.4s, v1.4s, v16.4s
fadd v2.4s, v2.4s, v17.4s fadd v2.4s, v2.4s, v17.4s
fadd v3.4s, v3.4s, v17.4s fadd v3.4s, v3.4s, v17.4s
// float32->float16 cbnz x0, TILE2_POST
Float32ToHalf v0, v2, v1, v3, v20, v21 b TILE2_TEMP_BUFFER
b TILE2_POST
TILE2_ADD_DSTV: TILE2_ADD_DSTV:
Float32ToHalf v0, v2, v1, v3, v20, v21 ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64
ld1 {v24.8h, v25.8h}, [x10] fadd v0.4s, v0.4s, v4.4s
fadd v20.8h, v20.8h, v24.8h fadd v1.4s, v1.4s, v5.4s
fadd v21.8h, v21.8h, v25.8h fadd v2.4s, v2.4s, v6.4s
fadd v3.4s, v3.4s, v7.4s
cbnz x0, TILE2_POST
TILE2_TEMP_BUFFER:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], #64
b Tile2LoopCheck
TILE2_POST: TILE2_POST:
// float32->float16
Float32ToHalf v0, v2, v1, v3, v20, v21
cbz x14, TILE2_STORE cbz x14, TILE2_STORE
ld1r {v29.8h}, [x14], #2 // f32 min ld1r {v29.8h}, [x14], #2 // f32 min
ld1r {v28.8h}, [x14] // f32 max ld1r {v28.8h}, [x14] // f32 max
@ -730,53 +722,36 @@ Tile2Quan:
fmin v21.8h, v21.8h, v28.8h fmin v21.8h, v21.8h, v28.8h
TILE2_STORE: TILE2_STORE:
st1 {v20.8h, v21.8h}, [x26], x4 st1 {v20.8h, v21.8h}, [x6], x4
Tile2LoopCheck: Tile2LoopCheck:
cmp x24, #1 cmp x24, #1
bge LoopDz_TILE_2 bge LoopDz_TILE_2
cbz x23, Tile2End cbz x0, Tile2End
add x23, x23, #8 add x0, x0, x21, LSL #1
Tile2End: Tile2End:
sub x7, x7, #2 sub x7, x7, #2
add x0, x0, x21, LSL #1
add x1, x1, #16 add x1, x1, #16
add x27, x27, #8 add x27, x27, #8
add x23, x23, #8
TILE_1: TILE_1:
cmp x7, #1 cmp x7, #1
blt End blt End
mov x24, x5 // dst_depth_quad mov x24, x5 // dst_depth_quad
mov x26, x0 // dst mov x6, x0 // dst
mov x25, x2 // weight mov x12, x2 // weight
mov x19, x8 // scale
mov x20, x9 // bias mov x20, x9 // bias
mov x6, x28 // weightQuanBias
LoopDz_TILE_1: LoopDz_TILE_1:
//ld1 {v7.4s, v8.4s}, [x20], #32 // bias
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
mov x10, x26
//dup v16.2d, v7.d[0] // oc:0,1,0,1
//dup v17.2d, v7.d[1] // oc:2,3,2,3
//dup v18.2d, v8.d[0] // oc:4,5,4,5
//dup v19.2d, v8.d[1] // oc:6,7,6,7
movi v16.4s, #0 // oc:0,1,0,1 movi v16.4s, #0 // oc:0,1,0,1
movi v17.4s, #0 // oc:2,3,2,3 movi v17.4s, #0 // oc:2,3,2,3
movi v18.4s, #0 // oc:4,5,4,5 movi v18.4s, #0 // oc:4,5,4,5
movi v19.4s, #0 // oc:6,7,6,7 movi v19.4s, #0 // oc:6,7,6,7
//movi v22.4s, #0 // oc:0,1,0,1
//movi v23.4s, #0 // oc:2,3,2,3
//movi v24.4s, #0 // oc:4,5,4,5
//movi v25.4s, #0 // oc:6,7,6,7
LoopSz_TILE_1: LoopSz_TILE_1:
// src : 1 x [1 x 8] : v2
// weight : 2 x [2 x 8] : v0-1
// dst : 1 x 2 x [2] : v30-v31
ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [x12], #64 // weight ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [x12], #64 // weight
ld1 {v2.8b}, [x11], x22 // src ld1 {v2.8b}, [x11], x22 // src
subs x13, x13, #1 subs x13, x13, #1
@ -787,44 +762,48 @@ LoopSz_TILE_1:
bne LoopSz_TILE_1 bne LoopSz_TILE_1
LoopSzEnd_TILE_1: LoopSzEnd_TILE_1:
add x25, x25, x15
sub x24, x24, #1 sub x24, x24, #1
uzp1 v27.2d, v16.2d, v17.2d uzp1 v25.2d, v16.2d, v17.2d
uzp1 v26.2d, v18.2d, v19.2d uzp1 v26.2d, v18.2d, v19.2d
scvtf v27.4s, v27.4s scvtf v25.4s, v25.4s
scvtf v26.4s, v26.4s scvtf v26.4s, v26.4s
Tile1Quan: Tile1Quan:
ld1 {v0.4s, v1.4s}, [x19], #32 // scale ld1 {v0.4s, v1.4s}, [x12], #32 // scale
ld1 {v6.s}[0], [x27] // x kernel sum ld1 {v6.s}[0], [x27] // x kernel sum
ld1 {v8.4s, v9.4s}, [x6], #32 // weight quan zeropoint ld1 {v8.4s, v9.4s}, [x12], #32 // weight quan zeropoint
fmul v27.4s, v27.4s, v0.4s fmul v25.4s, v25.4s, v0.4s
fmul v26.4s, v26.4s, v1.4s fmul v26.4s, v26.4s, v1.4s
cbz x23, TILE1_MLA cbz x25, TILE1_MLA
ld1 {v4.s}[0], [x23] ld1 {v4.s}[0], [x23]
fmul v27.4s, v27.4s, v4.s[0] fmul v25.4s, v25.4s, v4.s[0]
fmul v26.4s, v26.4s, v4.s[0] fmul v26.4s, v26.4s, v4.s[0]
TILE1_MLA: TILE1_MLA:
MLA_WEIGHTZERO v27, v6, v8, 0 // tile:0, oc:0-3 MLA_WEIGHTZERO v25, v6, v8, 0 // tile:0, oc:0-3
MLA_WEIGHTZERO v26, v6, v9, 0 // tile:0, oc:4-7 MLA_WEIGHTZERO v26, v6, v9, 0 // tile:0, oc:4-7
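Reading the macro names, the quantization epilogue for one lane is: convert the int32 accumulator, scale by the per-channel scale (now streamed from x12), optionally by the per-tile extra scale, then add srcKernelSum x weightQuanZero to cancel the weight zero-point. A scalar sketch under those assumptions (bias or spilled partials are folded in afterwards):

    #include <cstdint>

    // Per-lane dequantization implied by Tile1Quan above (names illustrative):
    inline float dequantLane(int32_t acc, float scale, float extraScale,
                             float srcKernelSum, float weightQuanZero) {
        float v = (float)acc * scale;       // scvtf + fmul
        v *= extraScale;                    // skipped when the pointer is null
        v += srcKernelSum * weightQuanZero; // MLA_WEIGHTZERO
        return v;
    }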
cbz x9, TILE1_ADD_DSTV cbz x9, TILE1_ADD_DSTV
TILE1_ADD_BIAS: TILE1_ADD_BIAS:
ld1 {v16.4s, v17.4s}, [x20], #32 // bias ld1 {v16.4s, v17.4s}, [x20], #32 // bias
fadd v27.4s, v27.4s, v16.4s fadd v25.4s, v25.4s, v16.4s
fadd v26.4s, v26.4s, v17.4s fadd v26.4s, v26.4s, v17.4s
fcvtn v0.4h, v27.4s cbnz x0, TILE1_POST
fcvtn2 v0.8h, v26.4s b TILE1_TEMP_BUFFER
b TILE1_POST
TILE1_ADD_DSTV: TILE1_ADD_DSTV:
fcvtn v0.4h, v27.4s ld1 {v16.4s, v17.4s}, [x10], #32
fcvtn2 v0.8h, v26.4s fadd v25.4s, v25.4s, v16.4s
ld1 {v24.8h}, [x10] fadd v26.4s, v26.4s, v17.4s
fadd v0.8h, v0.8h, v24.8h cbnz x0, TILE1_POST
TILE1_TEMP_BUFFER:
st1 {v25.4s, v26.4s}, [x15], #32
b Tile1LoopEnd
TILE1_POST: TILE1_POST:
fcvtn v0.4h, v25.4s
fcvtn2 v0.8h, v26.4s
cbz x14, TILE1_STORE cbz x14, TILE1_STORE
ld1r {v29.8h}, [x14], #2 // f32 min ld1r {v29.8h}, [x14], #2 // f32 min
ld1r {v28.8h}, [x14] // f32 max ld1r {v28.8h}, [x14] // f32 max
@ -832,15 +811,15 @@ Tile1Quan:
fmax v0.8h, v0.8h, v29.8h fmax v0.8h, v0.8h, v29.8h
fmin v0.8h, v0.8h, v28.8h fmin v0.8h, v0.8h, v28.8h
TILE1_STORE: TILE1_STORE:
st1 {v0.8h}, [x26], x4 st1 {v0.8h}, [x6], x4
Tile1LoopEnd: Tile1LoopEnd:
cmp x24, #1 cmp x24, #1
bge LoopDz_TILE_1 bge LoopDz_TILE_1
End: End:
ldp x27, x28, [sp, #(16 * 8)] ldp x25, x26, [sp, #(16 * 8)]
ldp x25, x26, [sp, #(16 * 7)] ldp x27, x28, [sp, #(16 * 7)]
ldp x23, x24, [sp, #(16 * 6)] ldp x23, x24, [sp, #(16 * 6)]
ldp x19, x20, [sp, #(16 * 5)] ldp x19, x20, [sp, #(16 * 5)]
ldp x21, x22, [sp, #(16 * 4)] ldp x21, x22, [sp, #(16 * 4)]


@ -115,7 +115,6 @@ UNIT = 8;
SRC_UNIT = 8; SRC_UNIT = 8;
DST_XUNIT = 10; DST_XUNIT = 10;
*/ */
ldr x8, [x6, #0]
ldr x9, [x6, #8] ldr x9, [x6, #8]
stp d14, d15, [sp, #(-16 * 10)]! stp d14, d15, [sp, #(-16 * 10)]!
@ -127,27 +126,24 @@ stp x19, x20, [sp, #(16 * 5)]
stp x23, x24, [sp, #(16 * 6)] stp x23, x24, [sp, #(16 * 6)]
stp x25, x26, [sp, #(16 * 7)] stp x25, x26, [sp, #(16 * 7)]
stp x27, x28, [sp, #(16 * 8)] stp x27, x28, [sp, #(16 * 8)]
// ldr w23, [x6, #24]
ldr x27, [x6, #40] // srcKernelSum ldr x27, [x6, #40] // srcKernelSum
ldr x28, [x6, #48] // weightQuanBias
ldr x14, [x6, #56] // fp32minmax ldr x14, [x6, #56] // fp32minmax
lsl x22, x7, #3 // eDest * GEMM_INT8_SRC_UNIT lsl x22, x7, #3 // eDest * GEMM_INT8_SRC_UNIT
mov x21, #16 // sizeof(float16_t) * UNIT mov x21, #16 // sizeof(float16_t) * UNIT
Start:
lsl x15, x3, #5 // x15 = src_depth_quad * UNIT * UNIT_SRC * sizeof(int4_t) = src_depth_quad * 8 * 8 * 0.5 = src_depth_quad << 5
ldr x23, [x6, #80] // extra scale ldr x23, [x6, #80] // extra scale
ldr x15, [x6, #96]
mov x10, x15 // tag dst address
mov x25, x23
TILE_10: TILE_10:
cmp x7, #10 cmp x7, #10
blt TILE_8 blt TILE_8
sub x4, x4, #128 // For Tile10 sub x4, x4, #128 // For Tile10
LoopDz_TILE_10: LoopDz_TILE_10:
//ld1 {v0.4s, v1.4s}, [x9], #32 // bias
mov x11, x1 // src mov x11, x1 // src
mov x12, x2 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
mov x10, x0 // tag dst address
movi v2.16b, #15 movi v2.16b, #15
SET_0_5 v12, v16, v20, v24, v28 // oc:0,1,0,1 SET_0_5 v12, v16, v20, v24, v28 // oc:0,1,0,1
@ -156,11 +152,10 @@ LoopDz_TILE_10:
SET_0_5 v15, v19, v23, v27, v31 // oc:6,7,6,7 SET_0_5 v15, v19, v23, v27, v31 // oc:6,7,6,7
LoopSz_TILE_10: LoopSz_TILE_10:
ld1 {v0.16b, v1.16b}, [x12], #32 // weight ld1 {v0.16b, v1.16b}, [x2], #32 // weight
ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], #64 // src: E0-E9 ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], #64 // src: E0-E9
ld1 {v7.16b}, [x11], #16 ld1 {v7.16b}, [x11], #16
// int4->int8 // int4->int8
ushr v8.16b, v0.16b, #4 // oc:0-1 ushr v8.16b, v0.16b, #4 // oc:0-1
ushr v9.16b, v1.16b, #4 // oc:2-3 ushr v9.16b, v1.16b, #4 // oc:2-3
and v10.16b, v0.16b, v2.16b // oc:4-5 and v10.16b, v0.16b, v2.16b // oc:4-5
@ -193,8 +188,6 @@ LoopSz_TILE_10:
.inst 0x4e8ba4ff // smmla v31.4s, v7.16b, v11.16b // tile8-oc6, tile8-oc7, tile9-oc6, tile9-oc7 .inst 0x4e8ba4ff // smmla v31.4s, v7.16b, v11.16b // tile8-oc6, tile8-oc7, tile9-oc6, tile9-oc7
bne LoopSz_TILE_10 bne LoopSz_TILE_10
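In this int4 variant the nibble split doubles as a channel split: one 32-byte weight load covers all eight output channels for the step, with the high nibbles of v0/v1 feeding oc0-3 and the low nibbles oc4-7, per the comments. A sketch of the unpack under that assumed packing:

    #include <cstdint>

    // Assumed int4 packing for the smmla path: high nibbles -> oc0-3 operands,
    // low nibbles -> oc4-7 operands (ordering taken from the comments above).
    inline void unpackSmmlaInt4(const uint8_t in[32], int8_t hi[32], int8_t lo[32]) {
        for (int i = 0; i < 32; ++i) {
            hi[i] = (int8_t)(in[i] >> 4);   // feeds v8/v9:  oc 0-3
            lo[i] = (int8_t)(in[i] & 0x0F); // feeds v10/v11: oc 4-7
        }
    }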
LoopSzEnd_TILE_10: LoopSzEnd_TILE_10:
add x2, x2, x15 // weight += dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT * 0.5);
sub x5, x5, #1 // dz--
// transpose // transpose
uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3
uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3
@ -227,11 +220,11 @@ LoopSzEnd_TILE_10:
Int32ToFloat v16, v17, v18, v19 Int32ToFloat v16, v17, v18, v19
Tile10Quan: Tile10Quan:
ld1 {v20.4s, v21.4s}, [x8], #32 // scale ld1 {v20.4s, v21.4s}, [x2], #32 // scale
ld1 {v22.4s, v23.4s}, [x27], #32 // x kernel sum ld1 {v22.4s, v23.4s}, [x27], #32 // x kernel sum
ld1 {v24.d}[0], [x27] ld1 {v24.d}[0], [x27]
sub x27, x27, #32 sub x27, x27, #32
ld1 {v25.4s, v26.4s}, [x28], #32 // weight quan zeropoint ld1 {v25.4s, v26.4s}, [x2], #32 // weight quan zeropoint
MUL_SCALE v20, v0, v1, v4, v5 MUL_SCALE v20, v0, v1, v4, v5
MUL_SCALE v21, v2, v3, v6, v7 MUL_SCALE v21, v2, v3, v6, v7
MUL_SCALE v20, v8, v9, v12, v13 MUL_SCALE v20, v8, v9, v12, v13
@ -241,7 +234,7 @@ Tile10Quan:
fmul v18.4s, v18.4s, v21.4s fmul v18.4s, v18.4s, v21.4s
fmul v19.4s, v19.4s, v21.4s fmul v19.4s, v19.4s, v21.4s
cbz x23, TILE10_MLA cbz x25, TILE10_MLA
ld1 {v27.4s, v28.4s}, [x23], #32 ld1 {v27.4s, v28.4s}, [x23], #32
ld1 {v29.d}[0], [x23] ld1 {v29.d}[0], [x23]
MUL_EXTRA_SCALE v27, v0, v1, v4, v5 MUL_EXTRA_SCALE v27, v0, v1, v4, v5
@ -292,16 +285,46 @@ Tile10Quan:
fadd v17.4s, v17.4s, v20.4s fadd v17.4s, v17.4s, v20.4s
fadd v18.4s, v18.4s, v21.4s fadd v18.4s, v18.4s, v21.4s
fadd v19.4s, v19.4s, v21.4s fadd v19.4s, v19.4s, v21.4s
cbnz x0, TILE10_POST // to Relu post
// float32->float16 b TILE10_TEMP_BUFFER
Float32ToHalf v0, v2, v1, v3, v20, v21
Float32ToHalf v4, v6, v5, v7, v22, v23
Float32ToHalf v8, v10, v9, v11, v24, v25
Float32ToHalf v12, v14, v13, v15, v26, v27
Float32ToHalf v16, v18, v17, v19, v30, v31
b TILE10_POST // to Relu post
TILE10_ADD_DSTV: TILE10_ADD_DSTV:
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64
ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10], #64
fadd v0.4s, v0.4s, v20.4s
fadd v1.4s, v1.4s, v21.4s
fadd v2.4s, v2.4s, v22.4s
fadd v3.4s, v3.4s, v23.4s
fadd v4.4s, v4.4s, v24.4s
fadd v5.4s, v5.4s, v25.4s
fadd v6.4s, v6.4s, v26.4s
fadd v7.4s, v7.4s, v27.4s
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64
fadd v8.4s, v8.4s, v28.4s
fadd v9.4s, v9.4s, v29.4s
fadd v10.4s, v10.4s, v30.4s
fadd v11.4s, v11.4s, v31.4s
fadd v12.4s, v12.4s, v20.4s
fadd v13.4s, v13.4s, v21.4s
fadd v14.4s, v14.4s, v22.4s
fadd v15.4s, v15.4s, v23.4s
fadd v16.4s, v16.4s, v24.4s
fadd v17.4s, v17.4s, v25.4s
fadd v18.4s, v18.4s, v26.4s
fadd v19.4s, v19.4s, v27.4s
cbnz x0, TILE10_POST
TILE10_TEMP_BUFFER:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x15], #64
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x15], #64
b Tile10LoopCheck
TILE10_POST:
// float32->float16 // float32->float16
Float32ToHalf v0, v2, v1, v3, v20, v21 Float32ToHalf v0, v2, v1, v3, v20, v21
Float32ToHalf v4, v6, v5, v7, v22, v23 Float32ToHalf v4, v6, v5, v7, v22, v23
@ -309,22 +332,6 @@ Tile10Quan:
Float32ToHalf v12, v14, v13, v15, v26, v27 Float32ToHalf v12, v14, v13, v15, v26, v27
Float32ToHalf v16, v18, v17, v19, v30, v31 Float32ToHalf v16, v18, v17, v19, v30, v31
ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64
ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64
ld1 {v8.8h, v9.8h}, [x10]
fadd v20.8h, v20.8h, v0.8h
fadd v21.8h, v21.8h, v1.8h
fadd v22.8h, v22.8h, v2.8h
fadd v23.8h, v23.8h, v3.8h
fadd v24.8h, v24.8h, v4.8h
fadd v25.8h, v25.8h, v5.8h
fadd v26.8h, v26.8h, v6.8h
fadd v27.8h, v27.8h, v7.8h
fadd v30.8h, v30.8h, v8.8h
fadd v31.8h, v31.8h, v9.8h
TILE10_POST:
cbz x14, TILE10_STORE cbz x14, TILE10_STORE
ld1r {v29.8h}, [x14], #2 // f32 min ld1r {v29.8h}, [x14], #2 // f32 min
ld1r {v28.8h}, [x14] // f32 max ld1r {v28.8h}, [x14] // f32 max
@ -341,29 +348,22 @@ Tile10Quan:
st1 {v30.8h, v31.8h}, [x0], x4 st1 {v30.8h, v31.8h}, [x0], x4
Tile10LoopCheck: Tile10LoopCheck:
cmp x5, #1 subs x5, x5, #1 // dz--
bge LoopDz_TILE_10 bne LoopDz_TILE_10
b End b End
TILE_8: TILE_8:
//ld1r {v28.4s}, [x6], #4 // f32 min
//ld1r {v29.4s}, [x6] // f32 max
movi v30.16b, #15 movi v30.16b, #15
cmp x7, #8 cmp x7, #8
blt TILE_4 blt TILE_4
sub x4, x4, #64 // just for Tile8, revert it when Tile8end sub x8, x4, #64 // just for Tile8; revert it when Tile8 ends
mov x24, x5 // dst_depth_quad mov x24, x5 // dst_depth_quad
mov x26, x0 // dst mov x26, x0 // dst
mov x25, x2 // weight mov x12, x2 // weight
mov x19, x8 // scale
mov x20, x9 // bias mov x20, x9 // bias
mov x6, x28 // weightQuanBias
LoopDz_TILE_8: LoopDz_TILE_8:
//ld1 {v0.4s, v1.4s}, [x20], #32 // bias
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
mov x10, x26 // tag dst
SET_0_4 v12, v16, v20, v24 // oc:0,1,0,1 SET_0_4 v12, v16, v20, v24 // oc:0,1,0,1
SET_0_4 v13, v17, v21, v25 // oc:2,3,2,3 SET_0_4 v13, v17, v21, v25 // oc:2,3,2,3
@ -371,7 +371,6 @@ LoopDz_TILE_8:
SET_0_4 v15, v19, v23, v27 // oc:6,7,6,7 SET_0_4 v15, v19, v23, v27 // oc:6,7,6,7
LoopSz_TILE_8: LoopSz_TILE_8:
ld1 {v0.16b, v1.16b}, [x12], #32 // weight ld1 {v0.16b, v1.16b}, [x12], #32 // weight
//movi v2.16b, #15
ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], x22 // src: E0-E7 ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], x22 // src: E0-E7
// int4->int8 // int4->int8
@ -403,7 +402,6 @@ LoopSz_TILE_8:
bne LoopSz_TILE_8 bne LoopSz_TILE_8
LoopSzEnd_TILE_8: LoopSzEnd_TILE_8:
add x25, x25, x15
sub x24, x24, #1 sub x24, x24, #1
uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3
uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3
@ -430,15 +428,15 @@ LoopSzEnd_TILE_8:
Int32ToFloat v12, v13, v14, v15 Int32ToFloat v12, v13, v14, v15
Tile8Quan: Tile8Quan:
ld1 {v20.4s, v21.4s}, [x19], #32 // scale ld1 {v20.4s, v21.4s}, [x12], #32 // scale
ld1 {v22.4s, v23.4s}, [x27] // x kernel sum ld1 {v22.4s, v23.4s}, [x27] // x kernel sum
ld1 {v25.4s, v26.4s}, [x6], #32 // weight quan zeropoint ld1 {v25.4s, v26.4s}, [x12], #32 // weight quan zeropoint
MUL_SCALE v20, v0, v1, v4, v5 MUL_SCALE v20, v0, v1, v4, v5
MUL_SCALE v21, v2, v3, v6, v7 MUL_SCALE v21, v2, v3, v6, v7
MUL_SCALE v20, v8, v9, v12, v13 MUL_SCALE v20, v8, v9, v12, v13
MUL_SCALE v21, v10, v11, v14, v15 MUL_SCALE v21, v10, v11, v14, v15
cbz x23, TILE8_MLA cbz x25, TILE8_MLA
ld1 {v27.4s, v28.4s}, [x23] ld1 {v27.4s, v28.4s}, [x23]
MUL_EXTRA_SCALE v27, v0, v1, v4, v5 MUL_EXTRA_SCALE v27, v0, v1, v4, v5
MUL_EXTRA_SCALE v28, v8, v9, v12, v13 MUL_EXTRA_SCALE v28, v8, v9, v12, v13
@ -473,31 +471,45 @@ Tile8Quan:
ADD_BIAS_FLOAT v2, v3, v6, v7, v17 ADD_BIAS_FLOAT v2, v3, v6, v7, v17
ADD_BIAS_FLOAT v8, v9, v12, v13, v16 ADD_BIAS_FLOAT v8, v9, v12, v13, v16
ADD_BIAS_FLOAT v10, v11, v14, v15, v17 ADD_BIAS_FLOAT v10, v11, v14, v15, v17
// float32->float16 cbnz x0, TILE8_POST
Float32ToHalf v0, v2, v1, v3, v20, v21 b TILE8_TEMP_BUFFER
Float32ToHalf v4, v6, v5, v7, v22, v23
Float32ToHalf v8, v10, v9, v11, v24, v25
Float32ToHalf v12, v14, v13, v15, v26, v27
b TILE8_POST
TILE8_ADD_DSTV: TILE8_ADD_DSTV:
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64
fadd v0.4s, v0.4s, v16.4s
fadd v1.4s, v1.4s, v17.4s
fadd v2.4s, v2.4s, v18.4s
fadd v3.4s, v3.4s, v19.4s
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64
fadd v4.4s, v4.4s, v20.4s
fadd v5.4s, v5.4s, v21.4s
fadd v6.4s, v6.4s, v22.4s
fadd v7.4s, v7.4s, v23.4s
fadd v8.4s, v8.4s, v24.4s
fadd v9.4s, v9.4s, v25.4s
fadd v10.4s, v10.4s, v26.4s
fadd v11.4s, v11.4s, v27.4s
fadd v12.4s, v12.4s, v16.4s
fadd v13.4s, v13.4s, v17.4s
fadd v14.4s, v14.4s, v18.4s
fadd v15.4s, v15.4s, v19.4s
cbnz x0, TILE8_POST
TILE8_TEMP_BUFFER:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x15], #64
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
b Tile8LoopCheck
TILE8_POST:
// float32->float16 // float32->float16
Float32ToHalf v0, v2, v1, v3, v20, v21 Float32ToHalf v0, v2, v1, v3, v20, v21
Float32ToHalf v4, v6, v5, v7, v22, v23 Float32ToHalf v4, v6, v5, v7, v22, v23
Float32ToHalf v8, v10, v9, v11, v24, v25 Float32ToHalf v8, v10, v9, v11, v24, v25
Float32ToHalf v12, v14, v13, v15, v26, v27 Float32ToHalf v12, v14, v13, v15, v26, v27
ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64
ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10]
fadd v20.8h, v20.8h, v0.8h
fadd v21.8h, v21.8h, v1.8h
fadd v22.8h, v22.8h, v2.8h
fadd v23.8h, v23.8h, v3.8h
fadd v24.8h, v24.8h, v4.8h
fadd v25.8h, v25.8h, v5.8h
fadd v26.8h, v26.8h, v6.8h
fadd v27.8h, v27.8h, v7.8h
TILE8_POST:
cbz x14, TILE8_STORE cbz x14, TILE8_STORE
ld1r {v29.8h}, [x14], #2 // f32 min ld1r {v29.8h}, [x14], #2 // f32 min
ld1r {v28.8h}, [x14] // f32 max ld1r {v28.8h}, [x14] // f32 max
@ -507,35 +519,29 @@ Tile8Quan:
TILE8_STORE: TILE8_STORE:
st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x26], #64 st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x26], #64
st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x26], x4 st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x26], x8
Tile8LoopCheck: Tile8LoopCheck:
cmp x24, #1 cmp x24, #1
bge LoopDz_TILE_8 bge LoopDz_TILE_8
cbz x23, Tile8End cbz x0, Tile8End
add x23, x23, #32 add x0, x0, x21, LSL #3
Tile8End: Tile8End:
sub x7, x7, #8 sub x7, x7, #8
add x0, x0, x21, LSL #3 add x23, x23, #32
add x1, x1, #64 add x1, x1, #64
add x27, x27, #32 add x27, x27, #32
add x4, x4, #64 // Revert x4 for following tiles
TILE_4: TILE_4:
cmp x7, #4 cmp x7, #4
blt TILE_2 blt TILE_2
mov x24, x5 // dst_depth_quad mov x24, x5 // dst_depth_quad
mov x26, x0 // dst mov x26, x0 // dst
mov x25, x2 // weight mov x12, x2 // weight
mov x19, x8 // scale
mov x20, x9 // bias mov x20, x9 // bias
mov x6, x28 // weightQuanBias
LoopDz_TILE_4: LoopDz_TILE_4:
//ld1 {v0.4s, v1.4s}, [x20], #32 // bias
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
mov x10, x26 // tag dst
SET_0_2 v12, v16 // oc:0,1,0,1 SET_0_2 v12, v16 // oc:0,1,0,1
SET_0_2 v13, v17 // oc:2,3,2,3 SET_0_2 v13, v17 // oc:2,3,2,3
@ -562,7 +568,6 @@ LoopSz_TILE_4:
bne LoopSz_TILE_4 bne LoopSz_TILE_4
LoopSzEnd_TILE_4: LoopSzEnd_TILE_4:
add x25, x25, x15
sub x24, x24, #1 sub x24, x24, #1
uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3
uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3
@ -577,13 +582,13 @@ LoopSzEnd_TILE_4:
Int32ToFloat v4, v5, v6, v7 Int32ToFloat v4, v5, v6, v7
Tile4Quan: Tile4Quan:
ld1 {v20.4s, v21.4s}, [x19], #32 // scale ld1 {v20.4s, v21.4s}, [x12], #32 // scale
ld1 {v22.4s}, [x27] // x kernel sum ld1 {v22.4s}, [x27] // x kernel sum
ld1 {v25.4s, v26.4s}, [x6], #32 // weight quan zeropoint ld1 {v25.4s, v26.4s}, [x12], #32 // weight quan zeropoint
MUL_SCALE v20, v0, v1, v4, v5 MUL_SCALE v20, v0, v1, v4, v5
MUL_SCALE v21, v2, v3, v6, v7 MUL_SCALE v21, v2, v3, v6, v7
cbz x23, TILE4_MLA cbz x25, TILE4_MLA
ld1 {v27.4s}, [x23] ld1 {v27.4s}, [x23]
MUL_EXTRA_SCALE v27, v0, v1, v4, v5 MUL_EXTRA_SCALE v27, v0, v1, v4, v5
MUL_EXTRA_SCALE v27, v2, v3, v6, v7 MUL_EXTRA_SCALE v27, v2, v3, v6, v7
@ -604,22 +609,31 @@ Tile4Quan:
ld1 {v16.4s, v17.4s}, [x20], #32 // bias ld1 {v16.4s, v17.4s}, [x20], #32 // bias
ADD_BIAS_FLOAT v0, v1, v4, v5, v16 ADD_BIAS_FLOAT v0, v1, v4, v5, v16
ADD_BIAS_FLOAT v2, v3, v6, v7, v17 ADD_BIAS_FLOAT v2, v3, v6, v7, v17
// float32->float16 cbnz x0, TILE4_POST
Float32ToHalf v0, v2, v1, v3, v20, v21 b TILE4_TEMP_BUFFER
Float32ToHalf v4, v6, v5, v7, v22, v23
b TILE4_POST
TILE4_ADD_DSTV: TILE4_ADD_DSTV:
// float32->float16 ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64
Float32ToHalf v0, v2, v1, v3, v20, v21 ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64
Float32ToHalf v4, v6, v5, v7, v22, v23 fadd v0.4s, v0.4s, v20.4s
ld1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x10] fadd v1.4s, v1.4s, v21.4s
fadd v20.8h, v20.8h, v24.8h fadd v2.4s, v2.4s, v22.4s
fadd v21.8h, v21.8h, v25.8h fadd v3.4s, v3.4s, v23.4s
fadd v22.8h, v22.8h, v26.8h fadd v4.4s, v4.4s, v24.4s
fadd v23.8h, v23.8h, v27.8h fadd v5.4s, v5.4s, v25.4s
fadd v6.4s, v6.4s, v26.4s
fadd v7.4s, v7.4s, v27.4s
cbnz x0, TILE4_POST
TILE4_TEMP_BUFFER:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x15], #64
b Tile4LoopCheck
TILE4_POST: TILE4_POST:
// float32->float16
Float32ToHalf v0, v2, v1, v3, v20, v21
Float32ToHalf v4, v6, v5, v7, v22, v23
cbz x14, TILE4_STORE cbz x14, TILE4_STORE
ld1r {v29.8h}, [x14], #2 // f32 min ld1r {v29.8h}, [x14], #2 // f32 min
ld1r {v28.8h}, [x14] // f32 max ld1r {v28.8h}, [x14] // f32 max
@ -632,35 +646,25 @@ Tile4Quan:
Tile4LoopCheck: Tile4LoopCheck:
cmp x24, #1 cmp x24, #1
bge LoopDz_TILE_4 bge LoopDz_TILE_4
cbz x23, Tile4End cbz x0, Tile4End
add x23, x23, #16 add x0, x0, x21, LSL #2
Tile4End: Tile4End:
sub x7, x7, #4 sub x7, x7, #4
add x0, x0, x21, LSL #2
add x1, x1, #32 add x1, x1, #32
add x27, x27, #16 add x27, x27, #16
//b TILE_4 add x23, x23, #16
TILE_2: TILE_2:
cmp x7, #2 cmp x7, #2
blt TILE_1 blt TILE_1
mov x24, x5 // dst_depth_quad mov x24, x5 // dst_depth_quad
mov x26, x0 // dst mov x26, x0 // dst
mov x25, x2 // weight mov x12, x2 // weight
mov x19, x8 // scale
mov x20, x9 // bias mov x20, x9 // bias
mov x6, x28 // weightQuanBias
LoopDz_TILE_2: LoopDz_TILE_2:
//ld1 {v0.4s, v1.4s}, [x20], #32 // bias
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
mov x10, x26 // tag dst
// v12 oc:0,1,0,1
// v13 oc:2,3,2,3
// v14 oc:4,5,4,5
// v15 oc:6,7,6,7
SET_0_4 v12, v13, v14, v15 SET_0_4 v12, v13, v14, v15
LoopSz_TILE_2: LoopSz_TILE_2:
ld1 {v2.16b, v3.16b}, [x12], #32 // weight ld1 {v2.16b, v3.16b}, [x12], #32 // weight
@ -678,7 +682,6 @@ LoopSz_TILE_2:
subs x13, x13, #1 subs x13, x13, #1
bne LoopSz_TILE_2 bne LoopSz_TILE_2
LoopSzEnd_TILE_2: LoopSzEnd_TILE_2:
add x25, x25, x15
sub x24, x24, #1 sub x24, x24, #1
uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3
uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3
@ -687,15 +690,15 @@ LoopSzEnd_TILE_2:
Int32ToFloat v0, v1, v2, v3 Int32ToFloat v0, v1, v2, v3
Tile2Quan: Tile2Quan:
ld1 {v20.4s, v21.4s}, [x19], #32 // scale ld1 {v20.4s, v21.4s}, [x12], #32 // scale
ld1 {v22.d}[0], [x27] // x kernel sum ld1 {v22.d}[0], [x27] // x kernel sum
ld1 {v25.4s, v26.4s}, [x6], #32 // weight quan zeropoint ld1 {v25.4s, v26.4s}, [x12], #32 // weight quan zeropoint
fmul v0.4s, v0.4s, v20.4s fmul v0.4s, v0.4s, v20.4s
fmul v1.4s, v1.4s, v20.4s fmul v1.4s, v1.4s, v20.4s
fmul v2.4s, v2.4s, v21.4s fmul v2.4s, v2.4s, v21.4s
fmul v3.4s, v3.4s, v21.4s fmul v3.4s, v3.4s, v21.4s
cbz x23, TILE2_MLA cbz x25, TILE2_MLA
ld1 {v27.d}[0], [x23] ld1 {v27.d}[0], [x23]
fmul v0.4s, v0.4s, v27.s[0] fmul v0.4s, v0.4s, v27.s[0]
fmul v1.4s, v1.4s, v27.s[1] fmul v1.4s, v1.4s, v27.s[1]
@ -715,17 +718,24 @@ Tile2Quan:
fadd v1.4s, v1.4s, v16.4s fadd v1.4s, v1.4s, v16.4s
fadd v2.4s, v2.4s, v17.4s fadd v2.4s, v2.4s, v17.4s
fadd v3.4s, v3.4s, v17.4s fadd v3.4s, v3.4s, v17.4s
// float32->float16 cbnz x0, TILE2_POST
Float32ToHalf v0, v2, v1, v3, v20, v21 b TILE2_TEMP_BUFFER
b TILE2_POST
TILE2_ADD_DSTV: TILE2_ADD_DSTV:
Float32ToHalf v0, v2, v1, v3, v20, v21 ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64
ld1 {v24.8h, v25.8h}, [x10] fadd v0.4s, v0.4s, v4.4s
fadd v20.8h, v20.8h, v24.8h fadd v1.4s, v1.4s, v5.4s
fadd v21.8h, v21.8h, v25.8h fadd v2.4s, v2.4s, v6.4s
fadd v3.4s, v3.4s, v7.4s
cbnz x0, TILE2_POST
TILE2_TEMP_BUFFER:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], #64
b Tile2LoopCheck
TILE2_POST: TILE2_POST:
// float32->float16
Float32ToHalf v0, v2, v1, v3, v20, v21
cbz x14, TILE2_STORE cbz x14, TILE2_STORE
ld1r {v29.8h}, [x14], #2 // f32 min ld1r {v29.8h}, [x14], #2 // f32 min
ld1r {v28.8h}, [x14] // f32 max ld1r {v28.8h}, [x14] // f32 max
@ -741,14 +751,13 @@ Tile2Quan:
Tile2LoopCheck: Tile2LoopCheck:
cmp x24, #1 cmp x24, #1
bge LoopDz_TILE_2 bge LoopDz_TILE_2
cbz x23, Tile2End cbz x0, Tile2End
add x23, x23, #8 add x0, x0, x21, LSL #1
Tile2End: Tile2End:
sub x7, x7, #2 sub x7, x7, #2
add x0, x0, x21, LSL #1
add x1, x1, #16 add x1, x1, #16
add x27, x27, #8 add x27, x27, #8
//b TILE_2 add x23, x23, #8
TILE_1: TILE_1:
@ -756,15 +765,11 @@ TILE_1:
blt End blt End
mov x24, x5 // dst_depth_quad mov x24, x5 // dst_depth_quad
mov x26, x0 // dst mov x26, x0 // dst
mov x25, x2 // weight mov x12, x2 // weight
mov x19, x8 // scale
mov x20, x9 // bias mov x20, x9 // bias
mov x6, x28 // weightQuanBias
LoopDz_TILE_1: LoopDz_TILE_1:
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
mov x10, x26
movi v16.4s, #0 // oc:0,1,0,1 movi v16.4s, #0 // oc:0,1,0,1
movi v17.4s, #0 // oc:2,3,2,3 movi v17.4s, #0 // oc:2,3,2,3
@ -845,44 +850,48 @@ LoopSz1_TILE_1_lu1:
bne LoopSz1_TILE_1_lu1 bne LoopSz1_TILE_1_lu1
LoopSzEnd_TILE_1: LoopSzEnd_TILE_1:
add x25, x25, x15
sub x24, x24, #1 sub x24, x24, #1
uzp1 v27.2d, v16.2d, v17.2d uzp1 v25.2d, v16.2d, v17.2d
uzp1 v26.2d, v18.2d, v19.2d uzp1 v26.2d, v18.2d, v19.2d
scvtf v27.4s, v27.4s scvtf v25.4s, v25.4s
scvtf v26.4s, v26.4s scvtf v26.4s, v26.4s
Tile1Quan: Tile1Quan:
ld1 {v0.4s, v1.4s}, [x19], #32 // scale ld1 {v0.4s, v1.4s}, [x12], #32 // scale
ld1 {v6.s}[0], [x27] // x kernel sum ld1 {v6.s}[0], [x27] // x kernel sum
ld1 {v8.4s, v9.4s}, [x6], #32 // weight quan zeropoint ld1 {v8.4s, v9.4s}, [x12], #32 // weight quan zeropoint
fmul v27.4s, v27.4s, v0.4s fmul v25.4s, v25.4s, v0.4s
fmul v26.4s, v26.4s, v1.4s fmul v26.4s, v26.4s, v1.4s
cbz x23, TILE1_MLA cbz x25, TILE1_MLA
ld1 {v4.s}[0], [x23] ld1 {v4.s}[0], [x23]
fmul v27.4s, v27.4s, v4.s[0] fmul v25.4s, v25.4s, v4.s[0]
fmul v26.4s, v26.4s, v4.s[0] fmul v26.4s, v26.4s, v4.s[0]
TILE1_MLA: TILE1_MLA:
MLA_WEIGHTZERO v27, v6, v8, 0 // tile:0, oc:0-3 MLA_WEIGHTZERO v25, v6, v8, 0 // tile:0, oc:0-3
MLA_WEIGHTZERO v26, v6, v9, 0 // tile:0, oc:4-7 MLA_WEIGHTZERO v26, v6, v9, 0 // tile:0, oc:4-7
cbz x9, TILE1_ADD_DSTV cbz x9, TILE1_ADD_DSTV
TILE1_ADD_BIAS: TILE1_ADD_BIAS:
ld1 {v16.4s, v17.4s}, [x20], #32 // bias ld1 {v16.4s, v17.4s}, [x20], #32 // bias
fadd v27.4s, v27.4s, v16.4s fadd v25.4s, v25.4s, v16.4s
fadd v26.4s, v26.4s, v17.4s fadd v26.4s, v26.4s, v17.4s
fcvtn v0.4h, v27.4s cbnz x0, TILE1_POST
fcvtn2 v0.8h, v26.4s b TILE1_TEMP_BUFFER
b TILE1_POST
TILE1_ADD_DSTV: TILE1_ADD_DSTV:
fcvtn v0.4h, v27.4s ld1 {v16.4s, v17.4s}, [x10], #32
fcvtn2 v0.8h, v26.4s fadd v25.4s, v25.4s, v16.4s
ld1 {v24.8h}, [x10] fadd v26.4s, v26.4s, v17.4s
fadd v0.8h, v0.8h, v24.8h cbnz x0, TILE1_POST
TILE1_TEMP_BUFFER:
st1 {v25.4s, v26.4s}, [x15], #32
b Tile1LoopEnd
TILE1_POST: TILE1_POST:
fcvtn v0.4h, v25.4s
fcvtn2 v0.8h, v26.4s
cbz x14, TILE1_STORE cbz x14, TILE1_STORE
ld1r {v29.8h}, [x14], #2 // f32 min ld1r {v29.8h}, [x14], #2 // f32 min
ld1r {v28.8h}, [x14] // f32 max ld1r {v28.8h}, [x14] // f32 max


@ -776,6 +776,10 @@ void registerCPURuntimeCreator() {
#endif #endif
#ifdef MNN_USE_ARMV82 #ifdef MNN_USE_ARMV82
registerArm82RuntimeCreator(); registerArm82RuntimeCreator();
#endif
#ifdef MNN_KLEIDIAI_ENABLED
// Init kleidiAI
KleidiAI& kai = KleidiAI::getInstance(*MNNGetCPUInfo(), false, false);
#endif #endif
// TODO: Merge _initCoreFunction MNNFunctionInit and cpuinfo_arm_init // TODO: Merge _initCoreFunction MNNFunctionInit and cpuinfo_arm_init
MNNInsertExtraRuntimeCreator(MNN_FORWARD_CPU, new CPURuntimeCreator); MNNInsertExtraRuntimeCreator(MNN_FORWARD_CPU, new CPURuntimeCreator);


@ -237,9 +237,6 @@ private:
CPUBackend::addCreator(opType, &_temp); \ CPUBackend::addCreator(opType, &_temp); \
} }
#define REGISTER_CPU_OP_CREATOR_AUDIO(name, opType) \
REGISTER_CPU_OP_CREATOR(name, opType)
} // namespace MNN } // namespace MNN
#endif /* CPUBackend_hpp */ #endif /* CPUBackend_hpp */


@ -25,6 +25,21 @@
#endif #endif
namespace MNN { namespace MNN {
void CPUConvolution::Resource::copyBias(float* dst, const float* bias, int outputCount, Backend* backend) {
auto core = static_cast<CPUBackend*>(backend)->functions();
int bytes = core->bytes;
int unit = core->pack;
auto alignOutput = UP_DIV(outputCount, unit) * unit;
int remain = alignOutput - outputCount;
if (bytes < 4) {
core->MNNFp32ToLowp(bias, (int16_t*)dst, outputCount);
} else {
::memcpy(dst, bias, outputCount * bytes);
}
if (remain > 0) {
::memset((uint8_t*)dst + outputCount * bytes, 0, remain * bytes);
}
}
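A minimal usage sketch for the new helper (hypothetical call site): with pack = 8 and fp32 data, outputCount = 10 needs a destination of UP_DIV(10, 8) * 8 = 16 floats, and the trailing six are zeroed:

    // Illustrative only; 'backend' is an initialized CPU Backend*.
    float bias[10] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f};
    std::vector<float> dst(UP_DIV(10, 8) * 8); // 16 floats, pack-aligned
    CPUConvolution::Resource::copyBias(dst.data(), bias, 10, backend);
    // dst[0..9] == bias, dst[10..15] == 0.0f; fp16 backends lower via MNNFp32ToLowp.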
bool CPUConvolution::Resource::copyBiasAlign(const float* bias, int outputCount) { bool CPUConvolution::Resource::copyBiasAlign(const float* bias, int outputCount) {
auto core = static_cast<CPUBackend*>(backend)->functions(); auto core = static_cast<CPUBackend*>(backend)->functions();
@ -38,14 +53,7 @@ bool CPUConvolution::Resource::copyBiasAlign(const float* bias, int outputCount)
MNN_ERROR("Error for alloc memory for Alloc Bias\n"); MNN_ERROR("Error for alloc memory for Alloc Bias\n");
return false; return false;
} }
if (bytes < 4) { copyBias(mBias->host<float>(), bias, outputCount, backend);
core->MNNFp32ToLowp(bias, mBias->host<int16_t>(), outputCount);
} else {
::memcpy(mBias->host<float>(), bias, outputCount * bytes);
}
if (remain > 0) {
::memset(mBias->host<uint8_t>() + outputCount * bytes, 0, remain * bytes);
}
return true; return true;
} }
CPUConvolution::MutableResourceInt8::MutableResourceInt8(std::shared_ptr<ResourceInt8> res, Backend* backend) : mResource(res) { CPUConvolution::MutableResourceInt8::MutableResourceInt8(std::shared_ptr<ResourceInt8> res, Backend* backend) : mResource(res) {
@ -104,9 +112,6 @@ void CPUConvolution::MutableResourceInt8::updateInputOutputScale(std::vector<flo
mOutputScale = mResource->mOutputScale; mOutputScale = mResource->mOutputScale;
mInputZeroPoint = mResource->mInputZeroPoint; mInputZeroPoint = mResource->mInputZeroPoint;
mOutputZeroPoint = mResource->mOutputZeroPoint; mOutputZeroPoint = mResource->mOutputZeroPoint;
// if (mInputScale == inputScale && mOutputScale == outputScale) {
// return;
// }
if (inputScale != 0 && outputScale != 0) { if (inputScale != 0 && outputScale != 0) {
mInputScale = inputScale; mInputScale = inputScale;
mOutputScale = outputScale; mOutputScale = outputScale;
@ -138,7 +143,6 @@ void CPUConvolution::MutableResourceInt8::updateInputOutputScale(std::vector<flo
// compute outputZeroPointFused in asymmetric quant // compute outputZeroPointFused in asymmetric quant
int outputZeroPointFused = static_cast<int32_t>(mOutputZeroPoint / scale[i]); int outputZeroPointFused = static_cast<int32_t>(mOutputZeroPoint / scale[i]);
bias[i] = static_cast<int32_t>(biasData[i] / (mInputScale * alphaValue)) - mResource->mInt8WeightKernelSum[i] * (mInputZeroPoint + offset) + outputZeroPointFused; bias[i] = static_cast<int32_t>(biasData[i] / (mInputScale * alphaValue)) - mResource->mInt8WeightKernelSum[i] * (mInputZeroPoint + offset) + outputZeroPointFused;
// biasfloat[i] = biasData[i] / mOutputScale - mResource->mInt8WeightKernelSum[i] * (mInputZeroPoint + offset) * scale[i] + mOutputZeroPoint;
biasfloat[i] = bias[i] * scale[i]; biasfloat[i] = bias[i] * scale[i];
} }
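Restating the fused bias computed above for clarity (alpha is the per-channel weight scale, offset the asymmetric-quant shift):

    bias[i]      = biasData[i] / (inputScale * alpha[i])
                   - int8WeightKernelSum[i] * (inputZeroPoint + offset)
                   + outputZeroPoint / scale[i]
    biasfloat[i] = bias[i] * scale[i]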
} }


@ -45,11 +45,11 @@ public:
std::shared_ptr<Tensor> mScaleBias; std::shared_ptr<Tensor> mScaleBias;
}; };
struct Resource { struct Resource {
std::shared_ptr<Tensor> mWeightKernelSum;
std::shared_ptr<Tensor> mWeight; std::shared_ptr<Tensor> mWeight;
std::shared_ptr<Tensor> mBias; std::shared_ptr<Tensor> mBias;
ResourceDequantizeInfo mDequantize; ResourceDequantizeInfo mDequantize;
Backend* backend; Backend* backend;
static void copyBias(float* dst, const float* bias, int outputCount, Backend* backend);
bool copyBiasAlign(const float* bias, int outputCount); bool copyBiasAlign(const float* bias, int outputCount);
int hU; int hU;
int lU; int lU;
@ -79,6 +79,7 @@ public:
int8_t mClampMin; int8_t mClampMin;
int8_t mClampMax; int8_t mClampMax;
bool mDynamicQuant = false; bool mDynamicQuant = false;
int32_t mBlockNum = 1;
}; };
struct MutableResourceInt8 { struct MutableResourceInt8 {
MutableResourceInt8(std::shared_ptr<ResourceInt8> res, Backend* backend); MutableResourceInt8(std::shared_ptr<ResourceInt8> res, Backend* backend);


@ -23,9 +23,9 @@
namespace MNN { namespace MNN {
CPUDeconvolutionBasic::CPUDeconvolutionBasic(const Tensor* input, const Op* convOp, Backend* b) CPUDeconvolutionBasic::CPUDeconvolutionBasic(int inputChannel, const Op* convOp, Backend* b)
: CPUConvolution(convOp->main_as_Convolution2D()->common(), b) { : CPUConvolution(convOp->main_as_Convolution2D()->common(), b) {
mSrcCount = input->channel(); mSrcCount = inputChannel;
mPostParameters = getPostParameters(); mPostParameters = getPostParameters();
} }
@ -38,33 +38,6 @@ ErrorCode CPUDeconvolutionBasic::onResize(const std::vector<Tensor*>& inputs, co
return NO_ERROR; return NO_ERROR;
} }
CPUDeconvolutionCommon::CPUDeconvolutionCommon(const Tensor* input, const Op* convOp, Backend* b, bool dynamicWeight)
: CPUDeconvolutionBasic(input, convOp, b) {
auto conv2D = convOp->main_as_Convolution2D();
int outputCount = mCommon->outputCount();
auto core = static_cast<CPUBackend*>(b)->functions();
mDynamicWeight = dynamicWeight;
mBias.reset(Tensor::createDevice<float>(std::vector<int>{UP_DIV(outputCount, core->pack) * core->pack}));
if (dynamicWeight) {
return;
}
bool success = b->onAcquireBuffer(mBias.get(), Backend::STATIC);
if (!success) {
mValid = false;
return;
}
::memset(mBias->host<float>(), 0, mBias->length(0) * core->bytes);
if (core->bytes == 4) {
::memcpy(mBias->host<float>(), conv2D->bias()->data(), conv2D->bias()->size() * sizeof(float));
} else {
core->MNNFp32ToLowp(conv2D->bias()->data(), mBias->host<int16_t>(), conv2D->bias()->size());
}
}
CPUDeconvolutionCommon::~CPUDeconvolutionCommon() {
// Do nothing
}
// Float Weight. // Float Weight.
static void _transformWeight(const uint8_t* tempWeight, uint8_t* dest, int outputCount, int srcCount, int fh, int fw, static void _transformWeight(const uint8_t* tempWeight, uint8_t* dest, int outputCount, int srcCount, int fh, int fw,
uint8_t* cache, const CoreFunctions* core) { uint8_t* cache, const CoreFunctions* core) {
@ -82,66 +55,90 @@ static void _transformWeight(const uint8_t* tempWeight, uint8_t* dest, int outpu
//printf("%d - %d - %d - %d\n", outputCount, srcCount, fh, fw); //printf("%d - %d - %d - %d\n", outputCount, srcCount, fh, fw);
core->MNNPackForMatMul_B((float*)dest, (const float*)cache, outputC4 * fw * fh * core->pack, srcCount, false); core->MNNPackForMatMul_B((float*)dest, (const float*)cache, outputC4 * fw * fh * core->pack, srcCount, false);
} }
std::shared_ptr<DeconvolutionResource> CPUDeconvolution::makeResource(int srcCount, const Op *convOp, Backend* backend, bool dynamic) {
CPUDeconvolution::CPUDeconvolution(const Tensor* input, const Op* convOp, Backend* backend, bool dynamicWeight) auto core = static_cast<CPUBackend*>(backend)->functions();
: MNN::CPUDeconvolutionCommon(input, convOp, backend, dynamicWeight) { auto coreInt8 = static_cast<CPUBackend*>(backend)->int8Functions();
auto core = static_cast<CPUBackend*>(backend)->functions();
auto coreInt8 = static_cast<CPUBackend*>(backend)->int8Functions();
int eP, lP, hP; int eP, lP, hP;
core->MNNGetMatMulPackMode(&eP, &lP, &hP); core->MNNGetMatMulPackMode(&eP, &lP, &hP);
auto conv2d = convOp->main_as_Convolution2D(); auto conv2d = convOp->main_as_Convolution2D();
auto layer = conv2d->common(); auto layer = conv2d->common();
int outputCount = layer->outputCount(); int outputCount = layer->outputCount();
const auto outputChannleUp4 = UP_DIV(outputCount, hP) * hP; const auto outputChannleUp4 = UP_DIV(outputCount, hP) * hP;
int fw = layer->kernelX(); int fw = layer->kernelX();
int fh = layer->kernelY(); int fh = layer->kernelY();
int srcCount = mSrcCount; std::shared_ptr<DeconvolutionResource> res(new DeconvolutionResource);
mParam.fh = fh; res->mParam.fh = fh;
mParam.fw = fw; res->mParam.fw = fw;
mParam.srcCount = srcCount; res->mParam.srcCount = srcCount;
mParam.outputCount = outputCount; res->mParam.outputCount = outputCount;
if (dynamic) {
return res;
}
auto outputAlign = UP_DIV(layer->outputCount(), core->pack) * core->pack * fw * fh; auto outputAlign = UP_DIV(layer->outputCount(), core->pack) * core->pack * fw * fh;
mWeight.reset(Tensor::createDevice<float>(std::vector<int>{UP_DIV(outputAlign, hP), UP_DIV(srcCount, lP) * lP, hP})); const float* tempWeight = nullptr;
std::shared_ptr<Tensor> cache(Tensor::createDevice<float>({outputAlign * srcCount})); int tempWeightSize = 0;
if (dynamicWeight) {
mOrigin.reset(new CPUDeconvolutionOrigin(input, mWeight.get(), convOp, backend, false));
mWeightTransformCache = cache;
return;
}
const float* tempWeight = nullptr;
int tempWeightSize = 0;
std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon; std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
ConvolutionCommon::getConvParameters(&quanCommon, backend, convOp, &tempWeight, &tempWeightSize); ConvolutionCommon::getConvParameters(&quanCommon, backend, convOp, &tempWeight, &tempWeightSize);
bool success = backend->onAcquireBuffer(mWeight.get(), Backend::STATIC) &&
backend->onAcquireBuffer(cache.get(), Backend::STATIC);
if (!success) {
mValid = false;
return;
}
AutoStorage<uint8_t> lowpWeight; AutoStorage<uint8_t> lowpWeight;
if (core->bytes < 4) { if (core->bytes < 4) {
lowpWeight.reset(outputCount * srcCount * fh * fw * core->bytes); lowpWeight.reset(outputCount * srcCount * fh * fw * core->bytes);
if (lowpWeight.get() == nullptr) { if (lowpWeight.get() == nullptr) {
mValid = false; return nullptr;
return;
} }
core->MNNFp32ToLowp(tempWeight, (int16_t*)lowpWeight.get(), outputCount * srcCount * fh * fw); core->MNNFp32ToLowp(tempWeight, (int16_t*)lowpWeight.get(), outputCount * srcCount * fh * fw);
tempWeight = (float*)lowpWeight.get(); tempWeight = (float*)lowpWeight.get();
quanCommon.reset();
} }
mWeight.reset(Tensor::createDevice<float>(std::vector<int>{UP_DIV(outputAlign, hP), UP_DIV(srcCount, lP) * lP, hP})); res->mWeight.reset(Tensor::createDevice<float>(std::vector<int>{UP_DIV(outputAlign, hP), UP_DIV(srcCount, lP) * lP, hP}));
success = backend->onAcquireBuffer(mWeight.get(), Backend::STATIC); res->mBias.reset(Tensor::createDevice<float>({UP_DIV(outputCount, core->pack) * core->pack}));
if (!success) { bool success = backend->onAcquireBuffer(res->mWeight.get(), Backend::STATIC) && backend->onAcquireBuffer(res->mBias.get(), Backend::STATIC);
mValid = false; AutoStorage<float> cache(outputAlign * srcCount);
if (!success || cache.get() == nullptr) {
MNN_ERROR("Alloc memory error for deconvolution\n");
return nullptr;
}
CPUConvolution::Resource::copyBias(res->mBias->host<float>(), convOp->main_as_Convolution2D()->bias()->data(), outputCount, backend);
_transformWeight((uint8_t*)tempWeight, res->mWeight->host<uint8_t>(), outputCount, srcCount, fh, fw, (uint8_t*)cache.get(), core);
return res;
}
bool CPUDeconvolution::onClone(Backend* bn, const Op* op, Execution** dst) {
if (mDynamicWeight) {
return false;
}
if (nullptr == dst) {
return true;
}
auto exe = new CPUDeconvolution(mSrcCount, op, bn, mDynamicWeight, mResource);
*dst = exe;
return true;
}
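The point of makeResource/onClone above: pre-transformed static weights and bias now live in a shared DeconvolutionResource, so a clone aliases the same tensors instead of repacking them, while dynamic-weight executions simply decline to clone. A hedged usage sketch (flow inferred from the code above, not verified against the full sources):

// One resource, many executions: the clone shares res instead of repacking.
auto res = CPUDeconvolution::makeResource(srcCount, op, backend, /*dynamic*/ false);
Execution* a = new CPUDeconvolution(srcCount, op, backend, false, res);
Execution* b = nullptr;
a->onClone(backend, op, &b); // b reuses res->mWeight and res->mBias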
CPUDeconvolution::CPUDeconvolution(int srcCount, const Op* convOp, Backend* backend, bool dynamicWeight, std::shared_ptr<DeconvolutionResource> resource) : MNN::CPUDeconvolutionBasic(srcCount, convOp, backend) {
mDynamicWeight = dynamicWeight;
mResource = resource;
if (dynamicWeight) {
auto core = static_cast<CPUBackend*>(backend)->functions();
auto coreInt8 = static_cast<CPUBackend*>(backend)->int8Functions();
int eP, lP, hP;
core->MNNGetMatMulPackMode(&eP, &lP, &hP);
auto conv2d = convOp->main_as_Convolution2D();
auto layer = conv2d->common();
int outputCount = layer->outputCount();
const auto outputChannleUp4 = UP_DIV(outputCount, hP) * hP;
int fw = layer->kernelX();
int fh = layer->kernelY();
auto outputAlign = UP_DIV(layer->outputCount(), core->pack) * core->pack * fw * fh;
mWeight.reset(Tensor::createDevice<float>(std::vector<int>{UP_DIV(outputAlign, hP), UP_DIV(srcCount, lP) * lP, hP}));
mBias.reset(Tensor::createDevice<float>({UP_DIV(outputCount, core->pack) * core->pack}));
mOrigin.reset(new CPUDeconvolutionOrigin(srcCount, convOp, backend));
mWeightTransformCache.reset(Tensor::createDevice<float>({outputAlign * srcCount}));
return; return;
} else {
mWeight = mResource->mWeight;
mBias = mResource->mBias;
} }
auto dest = mWeight->host<uint8_t>(); mOrigin.reset(new CPUDeconvolutionOrigin(srcCount, convOp, backend));
_transformWeight((uint8_t*)tempWeight, dest, outputCount, srcCount, fh, fw, cache->host<uint8_t>(), core);
backend->onReleaseBuffer(cache.get(), Backend::STATIC);
mOrigin.reset(new CPUDeconvolutionOrigin(input, mWeight.get(), convOp, backend, false));
} }
CPUDeconvolution::~CPUDeconvolution() { CPUDeconvolution::~CPUDeconvolution() {
@ -150,7 +147,7 @@ CPUDeconvolution::~CPUDeconvolution() {
ErrorCode CPUDeconvolution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) { ErrorCode CPUDeconvolution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
if (mDynamicWeight) { if (mDynamicWeight) {
auto core = static_cast<CPUBackend*>(backend())->functions(); auto core = static_cast<CPUBackend*>(backend())->functions();
_transformWeight(inputs[1]->host<uint8_t>(), mWeight->host<uint8_t>(), mParam.outputCount, mParam.srcCount, mParam.fh, mParam.fw, mWeightTransformCache->host<uint8_t>(), core); _transformWeight(inputs[1]->host<uint8_t>(), mWeight->host<uint8_t>(), mResource->mParam.outputCount, mResource->mParam.srcCount, mResource->mParam.fh, mResource->mParam.fw, mWeightTransformCache->host<uint8_t>(), core);
::memset(mBias->host<uint8_t>(), 0, mBias->length(0) * core->bytes); ::memset(mBias->host<uint8_t>(), 0, mBias->length(0) * core->bytes);
if (inputs.size() >= 3) { if (inputs.size() >= 3) {
::memcpy(mBias->host<uint8_t>(), inputs[2]->host<uint8_t>(), TensorUtils::getRawSize(inputs[2]) * core->bytes); ::memcpy(mBias->host<uint8_t>(), inputs[2]->host<uint8_t>(), TensorUtils::getRawSize(inputs[2]) * core->bytes);
@ -186,7 +183,7 @@ ErrorCode CPUDeconvolution::onResize(const std::vector<Tensor *> &inputs, const
return NO_ERROR; return NO_ERROR;
} }
CPUDeconvolutionOrigin::CPUDeconvolutionOrigin(const Tensor *input, Tensor *weight, const Op *convOp, Backend *b, bool ModeInt8) : CPUDeconvolutionBasic(input, convOp, b) { CPUDeconvolutionOrigin::CPUDeconvolutionOrigin(int inputChannel, const Op *convOp, Backend *b) : CPUDeconvolutionBasic(inputChannel, convOp, b) {
// Do nothing // Do nothing
} }
@ -362,7 +359,12 @@ public:
const MNN::Op* op, Backend* backend) const { const MNN::Op* op, Backend* backend) const {
auto convOp = op->main_as_Convolution2D(); auto convOp = op->main_as_Convolution2D();
auto common = convOp->common(); auto common = convOp->common();
return new CPUDeconvolution(inputs[0], op, backend, inputs.size() > 1); auto res = CPUDeconvolution::makeResource(inputs[0]->channel(), op, backend, inputs.size() > 1);
if (nullptr == res) {
MNN_ERROR("CPUDeconvolution makeResource error\n");
return nullptr;
}
return new CPUDeconvolution(inputs[0]->channel(), op, backend, inputs.size() > 1, res);
} }
}; };
View File
@ -14,9 +14,20 @@
#include "compute/StrassenMatmulComputor.hpp" #include "compute/StrassenMatmulComputor.hpp"
#include "core/TensorUtils.hpp" #include "core/TensorUtils.hpp"
namespace MNN { namespace MNN {
struct DeconvolutionResource {
struct Param {
int outputCount;
int srcCount;
int fh;
int fw;
};
Param mParam;
std::shared_ptr<Tensor> mBias;
std::shared_ptr<Tensor> mWeight;
};
class CPUDeconvolutionBasic : public CPUConvolution { class CPUDeconvolutionBasic : public CPUConvolution {
public: public:
CPUDeconvolutionBasic(const Tensor *input, const Op *convOp, Backend *b); CPUDeconvolutionBasic(int inputChannel, const Op *convOp, Backend *b);
virtual ~CPUDeconvolutionBasic() = default; virtual ~CPUDeconvolutionBasic() = default;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
@ -25,19 +36,9 @@ protected:
std::vector<float> mPostParameters; std::vector<float> mPostParameters;
}; };
class CPUDeconvolutionCommon : public CPUDeconvolutionBasic {
public:
CPUDeconvolutionCommon(const Tensor *input, const Op *convOp, Backend *b, bool dynamicWeight);
virtual ~CPUDeconvolutionCommon();
protected:
std::shared_ptr<Tensor> mBias;
bool mDynamicWeight;
};
class CPUDeconvolutionOrigin : public CPUDeconvolutionBasic { class CPUDeconvolutionOrigin : public CPUDeconvolutionBasic {
public: public:
CPUDeconvolutionOrigin(const Tensor *input, Tensor *weight, const Op *convOp, Backend *b, bool ModeInt8); CPUDeconvolutionOrigin(int inputChannel, const Op *convOp, Backend *b);
virtual ~CPUDeconvolutionOrigin() = default; virtual ~CPUDeconvolutionOrigin() = default;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
@ -50,22 +51,20 @@ private:
std::vector<std::pair<std::function<void(uint8_t*, int)>, int>> mExecuteFuntion; std::vector<std::pair<std::function<void(uint8_t*, int)>, int>> mExecuteFuntion;
}; };
class CPUDeconvolution : public CPUDeconvolutionCommon { class CPUDeconvolution : public CPUDeconvolutionBasic {
public: public:
CPUDeconvolution(const Tensor *input, const Op *convOp, Backend *b, bool dynamicWeight); static std::shared_ptr<DeconvolutionResource> makeResource(int inputChannel, const Op *convOp, Backend *b, bool dynamic);
CPUDeconvolution(int inputChannel, const Op *convOp, Backend *b, bool dynamicWeight, std::shared_ptr<DeconvolutionResource> res);
virtual ~CPUDeconvolution(); virtual ~CPUDeconvolution();
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
struct Param {
int outputCount;
int srcCount;
int fh;
int fw;
};
private: private:
Param mParam; bool mDynamicWeight;
std::shared_ptr<DeconvolutionResource> mResource;
std::shared_ptr<Tensor> mWeight; std::shared_ptr<Tensor> mWeight;
std::shared_ptr<Tensor> mBias;
std::shared_ptr<Tensor> mWeightTransformCache; std::shared_ptr<Tensor> mWeightTransformCache;
std::vector<Tensor *> mTempInputs; std::vector<Tensor *> mTempInputs;
std::shared_ptr<CPUDeconvolutionOrigin> mOrigin; std::shared_ptr<CPUDeconvolutionOrigin> mOrigin;
View File
@ -15,49 +15,66 @@
namespace MNN { namespace MNN {
CPUDeconvolutionDepthwise::CPUDeconvolutionDepthwise(const Tensor* input, const Op* convOp, Backend* b) std::shared_ptr<DeconvolutionResource> CPUDeconvolutionDepthwise::makeResource(int inputChannel, const Op *convOp, Backend* backend) {
: MNN::CPUDeconvolutionCommon(input, convOp, b, false) { std::shared_ptr<DeconvolutionResource> res(new DeconvolutionResource);
auto conv = convOp->main_as_Convolution2D(); auto conv = convOp->main_as_Convolution2D();
auto layer = convOp->main_as_Convolution2D()->common(); auto layer = convOp->main_as_Convolution2D()->common();
int kw = layer->kernelX(); int kw = layer->kernelX();
int kh = layer->kernelY(); int kh = layer->kernelY();
int outputCount = layer->outputCount(); int outputCount = layer->outputCount();
auto core = static_cast<CPUBackend*>(backend())->functions(); auto core = static_cast<CPUBackend*>(backend)->functions();
int depthQuad = UP_DIV(outputCount, core->pack); int depthQuad = UP_DIV(outputCount, core->pack);
const float* tempWeight = nullptr; const float* tempWeight = nullptr;
int tempWeightSize = 0; int tempWeightSize = 0;
std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon; std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
ConvolutionCommon::getConvParameters(&quanCommon, b, convOp, &tempWeight, &tempWeightSize); ConvolutionCommon::getConvParameters(&quanCommon, backend, convOp, &tempWeight, &tempWeightSize);
if (nullptr == tempWeight) {
return nullptr;
}
// Reorder weight from whc -> pwhc4 // Reorder weight from whc -> pwhc4
int kernelSize = depthQuad * core->pack * kw * kh; int kernelSize = depthQuad * core->pack * kw * kh;
mWeight.reset(Tensor::createDevice<float>(std::vector<int>{kernelSize})); res->mWeight.reset(Tensor::createDevice<float>(std::vector<int>{kernelSize}));
auto sucess = backend()->onAcquireBuffer(mWeight.get(), Backend::STATIC); res->mBias.reset(Tensor::createDevice<float>(std::vector<int>{depthQuad * core->pack}));
auto sucess = backend->onAcquireBuffer(res->mWeight.get(), Backend::STATIC) && backend->onAcquireBuffer(res->mBias.get(), Backend::STATIC);
if (!sucess) { if (!sucess) {
mValid = false; return nullptr;
return;
} }
CPUConvolution::Resource::copyBias(res->mBias->host<float>(), convOp->main_as_Convolution2D()->bias()->data(), outputCount, backend);
AutoStorage<uint8_t> weightTempStorage; AutoStorage<uint8_t> weightTempStorage;
if (core->bytes < 4) { if (core->bytes < 4) {
weightTempStorage.reset(kernelSize * core->bytes); weightTempStorage.reset(kernelSize * core->bytes);
if (weightTempStorage.get() == nullptr) { if (weightTempStorage.get() == nullptr) {
mValid = false; return nullptr;
return;
} }
core->MNNFp32ToLowp(tempWeight, (int16_t*)weightTempStorage.get(), kernelSize); core->MNNFp32ToLowp(tempWeight, (int16_t*)weightTempStorage.get(), kernelSize);
tempWeight = (const float*)weightTempStorage.get(); tempWeight = (const float*)weightTempStorage.get();
} }
auto weight = mWeight->host<float>(); auto weight = res->mWeight->host<float>();
int offset[] = { int offset[] = {
kw * kh, kw * kh,
kw * kh kw * kh
}; };
core->MNNPackCUnit(weight, tempWeight, kw * kh, outputCount, offset); core->MNNPackCUnit(weight, tempWeight, kw * kh, outputCount, offset);
mOrigin.reset(new CPUDeconvolutionDepthwiseBasic(input, convOp, b)); return res;
}
CPUDeconvolutionDepthwise::CPUDeconvolutionDepthwise(int inputChannel, const Op* convOp, Backend* b, std::shared_ptr<DeconvolutionResource> res)
: MNN::CPUDeconvolutionBasic(inputChannel, convOp, b) {
mResource = res;
mOrigin.reset(new CPUDeconvolutionDepthwiseBasic(inputChannel, convOp, b));
}
bool CPUDeconvolutionDepthwise::onClone(Backend* bn, const Op* op, Execution** dst) {
if (nullptr == dst) {
return true;
}
*dst = new CPUDeconvolutionDepthwise(mSrcCount, op, bn, mResource);
return true;
} }
CPUDeconvolutionDepthwise::~CPUDeconvolutionDepthwise() { CPUDeconvolutionDepthwise::~CPUDeconvolutionDepthwise() {
backend()->onReleaseBuffer(mWeight.get(), Backend::STATIC); // Do nothing
} }
ErrorCode CPUDeconvolutionDepthwiseMultiInput::onResize(const std::vector<Tensor*>& inputs, ErrorCode CPUDeconvolutionDepthwiseMultiInput::onResize(const std::vector<Tensor*>& inputs,
@ -214,9 +231,14 @@ public:
virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
const MNN::Op* op, Backend* backend) const { const MNN::Op* op, Backend* backend) const {
if (1 < inputs.size()) { if (1 < inputs.size()) {
return new CPUDeconvolutionDepthwiseMultiInput(inputs[0], op, backend); return new CPUDeconvolutionDepthwiseMultiInput(inputs[0]->channel(), op, backend);
} }
return new CPUDeconvolutionDepthwise(inputs[0], op, backend); auto res = CPUDeconvolutionDepthwise::makeResource(inputs[0]->channel(), op, backend);
if (nullptr == res.get()) {
MNN_ERROR("Create Resource error for DeconvolutionDepthwise\n");
return nullptr;
}
return new CPUDeconvolutionDepthwise(inputs[0]->channel(), op, backend, res);
} }
}; };
View File
@ -14,8 +14,8 @@
namespace MNN { namespace MNN {
class CPUDeconvolutionDepthwiseBasic : public CPUDeconvolutionBasic { class CPUDeconvolutionDepthwiseBasic : public CPUDeconvolutionBasic {
public: public:
CPUDeconvolutionDepthwiseBasic(const Tensor *input, const Op *convOp, Backend *b) CPUDeconvolutionDepthwiseBasic(int inputChannel, const Op *convOp, Backend *b)
: CPUDeconvolutionBasic(input, convOp, b) { : CPUDeconvolutionBasic(inputChannel, convOp, b) {
} }
virtual ~CPUDeconvolutionDepthwiseBasic() = default; virtual ~CPUDeconvolutionDepthwiseBasic() = default;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
@ -27,8 +27,8 @@ private:
class CPUDeconvolutionDepthwiseMultiInput : public CPUDeconvolutionDepthwiseBasic { class CPUDeconvolutionDepthwiseMultiInput : public CPUDeconvolutionDepthwiseBasic {
public: public:
CPUDeconvolutionDepthwiseMultiInput(const Tensor *input, const Op *convOp, Backend *b) CPUDeconvolutionDepthwiseMultiInput(int inputChannel, const Op *convOp, Backend *b)
: CPUDeconvolutionDepthwiseBasic(input, convOp, b) { : CPUDeconvolutionDepthwiseBasic(inputChannel, convOp, b) {
} }
virtual ~CPUDeconvolutionDepthwiseMultiInput() = default; virtual ~CPUDeconvolutionDepthwiseMultiInput() = default;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
@ -40,20 +40,23 @@ private:
std::vector<Tensor *> mInputs; std::vector<Tensor *> mInputs;
}; };
class CPUDeconvolutionDepthwise : public CPUDeconvolutionCommon { class CPUDeconvolutionDepthwise : public CPUDeconvolutionBasic {
public: public:
CPUDeconvolutionDepthwise(const Tensor *input, const Op *convOp, Backend *b); static std::shared_ptr<DeconvolutionResource> makeResource(int inputChannel, const Op *convOp, Backend *b);
CPUDeconvolutionDepthwise(int inputChannel, const Op *convOp, Backend *b, std::shared_ptr<DeconvolutionResource> res);
virtual ~CPUDeconvolutionDepthwise(); virtual ~CPUDeconvolutionDepthwise();
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override { virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override {
mInputs = {inputs[0], mWeight.get(), mBias.get()}; mInputs = {inputs[0], mResource->mWeight.get(), mResource->mBias.get()};
return mOrigin->onResize(mInputs, outputs); return mOrigin->onResize(mInputs, outputs);
} }
virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override { virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override {
return mOrigin->onExecute(mInputs, outputs); return mOrigin->onExecute(mInputs, outputs);
} }
private: private:
std::shared_ptr<Tensor> mWeight; std::shared_ptr<DeconvolutionResource> mResource;
std::vector<Tensor *> mInputs; std::vector<Tensor *> mInputs;
std::unique_ptr<CPUDeconvolutionDepthwiseBasic> mOrigin; std::unique_ptr<CPUDeconvolutionDepthwiseBasic> mOrigin;
}; };
View File
@ -34,6 +34,7 @@ extern void ___CPUROIAlignCreator__OpType_ROIAlign__();
extern void ___CPUROIPoolingCreator__OpType_ROIPooling__(); extern void ___CPUROIPoolingCreator__OpType_ROIPooling__();
extern void ___CPUTopKV2Creator__OpType_TopKV2__(); extern void ___CPUTopKV2Creator__OpType_TopKV2__();
extern void ___CPUUnaryCreator__OpType_UnaryOp__(); extern void ___CPUUnaryCreator__OpType_UnaryOp__();
extern void ___CPUStftCreator__OpType_Stft__();
extern void ___CPUReductionCreator__OpType_Reduction__(); extern void ___CPUReductionCreator__OpType_Reduction__();
extern void ___CPUReluCreator__OpType_ReLU__(); extern void ___CPUReluCreator__OpType_ReLU__();
extern void ___CPUReluCreator__OpType_PReLU__(); extern void ___CPUReluCreator__OpType_PReLU__();
@ -78,9 +79,6 @@ extern void ___CPUTextureCreator__OpType_Texture__();
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE #ifdef MNN_SUPPORT_TRANSFORMER_FUSE
extern void ___CPUAttentionCreator__OpType_Attention__(); extern void ___CPUAttentionCreator__OpType_Attention__();
#endif #endif
#ifdef MNN_BUILD_AUDIO
extern void ___CPUStftCreator__OpType_Stft__();
#endif
void registerCPUOps() { void registerCPUOps() {
___CPUCropAndResizeCreator__OpType_CropAndResize__(); ___CPUCropAndResizeCreator__OpType_CropAndResize__();
___CPUArgMaxCreator__OpType_ArgMax__(); ___CPUArgMaxCreator__OpType_ArgMax__();
@ -116,6 +114,7 @@ ___CPUROIAlignCreator__OpType_ROIAlign__();
___CPUROIPoolingCreator__OpType_ROIPooling__(); ___CPUROIPoolingCreator__OpType_ROIPooling__();
___CPUTopKV2Creator__OpType_TopKV2__(); ___CPUTopKV2Creator__OpType_TopKV2__();
___CPUUnaryCreator__OpType_UnaryOp__(); ___CPUUnaryCreator__OpType_UnaryOp__();
___CPUStftCreator__OpType_Stft__();
___CPUReductionCreator__OpType_Reduction__(); ___CPUReductionCreator__OpType_Reduction__();
___CPUReluCreator__OpType_ReLU__(); ___CPUReluCreator__OpType_ReLU__();
___CPUReluCreator__OpType_PReLU__(); ___CPUReluCreator__OpType_PReLU__();
@ -159,8 +158,5 @@ ___CPUTextureCreator__OpType_Texture__();
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE #ifdef MNN_SUPPORT_TRANSFORMER_FUSE
___CPUAttentionCreator__OpType_Attention__(); ___CPUAttentionCreator__OpType_Attention__();
#endif #endif
#ifdef MNN_BUILD_AUDIO
___CPUStftCreator__OpType_Stft__();
#endif
} }
} }
View File
@ -104,7 +104,7 @@ ErrorCode CPURaster::onResize(const std::vector<Tensor *> &____inputs, const std
if (des->regions.size() == 1) { if (des->regions.size() == 1) {
OpCommonUtils::turnRegion2Convert(des->regions[0], output, mSingleConvert); OpCommonUtils::turnRegion2Convert(des->regions[0], output, mSingleConvert);
if (mSingleConvert.type > 0) { if (mSingleConvert.type > 0) {
mUseThreads = (mSingleConvert.batch * mSingleConvert.area > LAUNCH_MULTI_THREADS_WORKLOAD) ? true : false; mUseThreads = (mSingleConvert.batch * mSingleConvert.channel * mSingleConvert.area > LAUNCH_MULTI_THREADS_WORKLOAD) ? true : false;
return NO_ERROR; return NO_ERROR;
} }
} }
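The old heuristic ignored the channel count when sizing the workload. As an illustrative example, a single-batch 256-channel 32x32 layout convert previously scored 1 * 1024 = 1024 elements; with the channel term it scores 1 * 256 * 1024 = 262144, so channel-heavy converts now cross LAUNCH_MULTI_THREADS_WORKLOAD and run multi-threaded as intended.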
View File
@ -6,12 +6,16 @@
// Copyright © 2018, Alibaba Group Holding Limited // Copyright © 2018, Alibaba Group Holding Limited
// //
#ifdef MNN_BUILD_AUDIO /**
Referenced from onnxruntime.
*/
#ifndef M_PI #ifndef M_PI
#define M_PI 3.141592654 #define M_PI 3.141592654
#endif #endif
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
#include <complex>
#include "backend/cpu/CPUStft.hpp" #include "backend/cpu/CPUStft.hpp"
#include "backend/cpu/CPUBackend.hpp" #include "backend/cpu/CPUBackend.hpp"
#include "core/Concurrency.h" #include "core/Concurrency.h"
@ -20,74 +24,474 @@
#include "compute/CommonOptFunction.h" #include "compute/CommonOptFunction.h"
namespace MNN { namespace MNN {
std::vector<float> CPUStft::gSinTable;
std::vector<float> CPUStft::gCosTable;
static void MNNDftAbs(const float* input, const float* window, float* output, float* buffer, int nfft) { #define ___RETURN_IF_ERROR(x) {auto code = (x); if (NO_ERROR != code) {return code;}}
for (int i = 0; i < nfft; ++i) { #define ___RETURN_IF(x, y) {if (x) {return NOT_SUPPORT;}}
buffer[i] = input[i] * window[i];
static bool is_real_valued_signal(const Tensor* shape) {
return shape->dimensions() == 2 || shape->length(shape->dimensions() -1) == 1;
}
static bool is_complex_valued_signal(const Tensor* shape) {
return shape->dimensions() > 2 && shape->length(shape->dimensions() -1) == 2;
}
static bool is_power_of_2(size_t size) {
size_t n_bits = 0;
while (size != 0) {
n_bits += size & 1;
size = size >> 1;
} }
for (int k = 0; k < nfft / 2 + 1; ++k) { return n_bits == 1;
float real_sum = 0.f, imag_sum = 0.f; }
for (int n = 0; n < nfft; ++n) {
int index = (n * k) % nfft; static const unsigned char BitReverseTable256[] = {
real_sum += buffer[n] * CPUStft::gCosTable[index]; 0x00, 0x80, 0x40, 0xC0, 0x20, 0xA0, 0x60, 0xE0, 0x10, 0x90, 0x50, 0xD0, 0x30, 0xB0, 0x70, 0xF0, 0x08, 0x88, 0x48,
imag_sum -= buffer[n] * CPUStft::gSinTable[index]; 0xC8, 0x28, 0xA8, 0x68, 0xE8, 0x18, 0x98, 0x58, 0xD8, 0x38, 0xB8, 0x78, 0xF8, 0x04, 0x84, 0x44, 0xC4, 0x24, 0xA4,
0x64, 0xE4, 0x14, 0x94, 0x54, 0xD4, 0x34, 0xB4, 0x74, 0xF4, 0x0C, 0x8C, 0x4C, 0xCC, 0x2C, 0xAC, 0x6C, 0xEC, 0x1C,
0x9C, 0x5C, 0xDC, 0x3C, 0xBC, 0x7C, 0xFC, 0x02, 0x82, 0x42, 0xC2, 0x22, 0xA2, 0x62, 0xE2, 0x12, 0x92, 0x52, 0xD2,
0x32, 0xB2, 0x72, 0xF2, 0x0A, 0x8A, 0x4A, 0xCA, 0x2A, 0xAA, 0x6A, 0xEA, 0x1A, 0x9A, 0x5A, 0xDA, 0x3A, 0xBA, 0x7A,
0xFA, 0x06, 0x86, 0x46, 0xC6, 0x26, 0xA6, 0x66, 0xE6, 0x16, 0x96, 0x56, 0xD6, 0x36, 0xB6, 0x76, 0xF6, 0x0E, 0x8E,
0x4E, 0xCE, 0x2E, 0xAE, 0x6E, 0xEE, 0x1E, 0x9E, 0x5E, 0xDE, 0x3E, 0xBE, 0x7E, 0xFE, 0x01, 0x81, 0x41, 0xC1, 0x21,
0xA1, 0x61, 0xE1, 0x11, 0x91, 0x51, 0xD1, 0x31, 0xB1, 0x71, 0xF1, 0x09, 0x89, 0x49, 0xC9, 0x29, 0xA9, 0x69, 0xE9,
0x19, 0x99, 0x59, 0xD9, 0x39, 0xB9, 0x79, 0xF9, 0x05, 0x85, 0x45, 0xC5, 0x25, 0xA5, 0x65, 0xE5, 0x15, 0x95, 0x55,
0xD5, 0x35, 0xB5, 0x75, 0xF5, 0x0D, 0x8D, 0x4D, 0xCD, 0x2D, 0xAD, 0x6D, 0xED, 0x1D, 0x9D, 0x5D, 0xDD, 0x3D, 0xBD,
0x7D, 0xFD, 0x03, 0x83, 0x43, 0xC3, 0x23, 0xA3, 0x63, 0xE3, 0x13, 0x93, 0x53, 0xD3, 0x33, 0xB3, 0x73, 0xF3, 0x0B,
0x8B, 0x4B, 0xCB, 0x2B, 0xAB, 0x6B, 0xEB, 0x1B, 0x9B, 0x5B, 0xDB, 0x3B, 0xBB, 0x7B, 0xFB, 0x07, 0x87, 0x47, 0xC7,
0x27, 0xA7, 0x67, 0xE7, 0x17, 0x97, 0x57, 0xD7, 0x37, 0xB7, 0x77, 0xF7, 0x0F, 0x8F, 0x4F, 0xCF, 0x2F, 0xAF, 0x6F,
0xEF, 0x1F, 0x9F, 0x5F, 0xDF, 0x3F, 0xBF, 0x7F, 0xFF};
template <typename T>
static inline T bit_reverse(T num, unsigned significant_bits) {
if (significant_bits > 32) {
MNN_ERROR("Unsupported bit size.");
}
uint32_t num_32 = static_cast<uint32_t>(num);
uint32_t rev = (BitReverseTable256[num_32 & 0xff] << 24) | (BitReverseTable256[(num_32 >> 8) & 0xff] << 16) |
(BitReverseTable256[(num_32 >> 16) & 0xff] << 8) | (BitReverseTable256[(num_32 >> 24) & 0xff]);
return static_cast<T>(((uint64_t)rev) >> (32 - significant_bits));
}
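For intuition, a few values the table-driven reversal produces (worked by hand from the definition above):

// bit_reverse(0b001, 3)  == 0b100    (1 -> 4)
// bit_reverse(0b0110, 4) == 0b0110   (6 is palindromic over 4 bits)
// bit_reverse(3, 8)      == 192      (0000'0011 -> 1100'0000)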
template <typename T>
static T compute_angular_velocity(size_t number_of_samples, bool inverse) {
// Calculate fundamental angular velocity
static const T pi = static_cast<T>(M_PI);
static const T tau = 2 * pi;
T inverse_switch = inverse ? 1.f : -1.f;
T angular_velocity = inverse_switch * tau / number_of_samples;
return angular_velocity;
}
template <typename T>
static std::complex<T> compute_exponential(size_t index, const T angular_velocity) {
const T angle = static_cast<T>(index) * angular_velocity;
return std::complex<T>(cos(angle), sin(angle));
}
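In math terms, these two helpers build the DFT twiddle factors

$$\omega = \pm\frac{2\pi}{N}\ (\text{minus for forward, plus for inverse}), \qquad W_N^{k} = e^{ik\omega} = \cos(k\omega) + i\sin(k\omega),$$

so that the transform below evaluates $X[j] = \sum_{n=0}^{N-1} x[n]\, W_N^{nj}$; the $1/N$ normalization of the inverse is applied separately, after the butterfly passes.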
template <typename T, typename U>
static ErrorCode fft_radix2(Backend* backend, const Tensor* X, Tensor* Y, size_t X_offset, size_t X_stride,
size_t Y_offset, size_t Y_stride, int64_t axis, size_t dft_length, const Tensor* window,
bool is_onesided, bool inverse, std::vector<std::complex<T>>& V,
std::vector<std::complex<T>>& temp_output) {
// Get shape and significant bits
const auto X_shape = X->shape();
size_t number_of_samples = static_cast<size_t>(X_shape[axis]);
unsigned significant_bits = static_cast<unsigned>(log2(dft_length));
// Get data
auto* X_data = const_cast<U*>(reinterpret_cast<const U*>(X->host<void>())) + X_offset;
// Get window
float* window_data = nullptr;
if (window) {
window_data = const_cast<float*>(reinterpret_cast<const float*>(window->host<void>()));
}
size_t Y_data_stride = 1;
std::complex<T>* Y_data;
if (is_onesided) {
if (temp_output.size() != dft_length) {
temp_output.resize(dft_length);
} }
output[k] = sqrtf(real_sum * real_sum + imag_sum * imag_sum); Y_data = temp_output.data();
} else {
Y_data = reinterpret_cast<std::complex<T>*>(Y->host<void>()) + Y_offset;
Y_data_stride = Y_stride;
} }
auto angular_velocity = compute_angular_velocity<T>(dft_length, inverse);
// Create vandermonde matrix V ordered with the bit-reversed permutation
if (V.size() != dft_length) {
V.resize(dft_length);
for (size_t i = 0; i < dft_length; i++) {
size_t bit_reversed_index = bit_reverse(i, significant_bits);
V[bit_reversed_index] = compute_exponential(i, angular_velocity);
}
}
for (size_t i = 0; i < dft_length; i++) {
size_t bit_reversed_index = bit_reverse(i, significant_bits);
auto x = (bit_reversed_index < number_of_samples) ? *(X_data + bit_reversed_index * X_stride) : 0;
auto window_element = window_data ? *(window_data + bit_reversed_index) : 1;
*(Y_data + i * Y_data_stride) = std::complex<T>(1, 0) * x * window_element;
}
// Run fft_radix2
unsigned current_significant_bits = 0;
for (size_t i = 2; i <= dft_length; i <<= 1) {
size_t midpoint = i >> 1;
current_significant_bits++;
for (size_t k = 0; k < midpoint; k++) {
auto first_idx = bit_reverse(k, current_significant_bits);
auto second_idx = bit_reverse(midpoint + k, current_significant_bits);
for (size_t j = 0; j < dft_length; j += i) {
auto even_index = k + j;
auto odd_index = k + j + midpoint;
std::complex<T>* even = (Y_data + even_index * Y_data_stride);
std::complex<T>* odd = (Y_data + odd_index * Y_data_stride);
std::complex<T> first = *even + (V[first_idx] * *odd);
std::complex<T> second = *even + (V[second_idx] * *odd);
*even = first;
*odd = second;
}
}
}
// Scale the output if inverse
if (inverse) {
for (size_t i = 0; i < dft_length; i++) {
std::complex<T>& val = *(Y_data + i * Y_data_stride);
val /= static_cast<T>(dft_length);
}
}
if (is_onesided) {
const size_t output_size = (dft_length >> 1) + 1;
auto destination = reinterpret_cast<std::complex<T>*>(Y->host<void>()) + Y_offset;
for (size_t i = 0; i < output_size; i++) {
*(destination + Y_stride * i) = *(Y_data + i * Y_data_stride);
}
}
return NO_ERROR;
}
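Stripped of strides, windowing, and the onesided copy-out, fft_radix2 is the classic iterative Cooley-Tukey scheme: a bit-reversal permutation followed by log2(N) rounds of butterflies. A self-contained textbook sketch for reference (it advances the twiddle incrementally, whereas the code above caches the twiddles bit-reversed in V; an illustrative variant, not MNN's API):

#include <cmath>
#include <complex>
#include <cstddef>
#include <utility>
#include <vector>

static void fft_sketch(std::vector<std::complex<float>>& a, bool inverse) {
    const size_t n = a.size(); // assumed to be a power of two
    // Bit-reversal permutation.
    for (size_t i = 1, j = 0; i < n; ++i) {
        size_t bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) std::swap(a[i], a[j]);
    }
    const float sign = inverse ? 1.0f : -1.0f;
    for (size_t len = 2; len <= n; len <<= 1) {
        const float ang = sign * 2.0f * 3.14159265f / (float)len;
        const std::complex<float> wlen(std::cos(ang), std::sin(ang));
        for (size_t i = 0; i < n; i += len) {
            std::complex<float> w(1.0f, 0.0f);
            for (size_t k = 0; k < len / 2; ++k, w *= wlen) {
                const auto u = a[i + k];
                const auto v = w * a[i + k + len / 2];
                a[i + k] = u + v;           // "even" butterfly output
                a[i + k + len / 2] = u - v; // "odd" butterfly output
            }
        }
    }
    if (inverse) {
        for (auto& x : a) x /= (float)n; // 1/N scaling, as in the code above
    }
}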
template <typename T>
T next_power_of_2(T in) {
in--;
T out = 1;
while (out <= in) {
out <<= 1;
}
return out;
}
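For example, next_power_of_2(5) == 8, next_power_of_2(8) == 8 (the initial decrement keeps exact powers of two unchanged), and next_power_of_2(1) == 1. Bluestein's method below uses it to pick the padded convolution length M = next_power_of_2(2N - 1).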
template <typename T, typename U>
static ErrorCode dft_bluestein_z_chirp(Backend* bn, const Tensor* X, Tensor* Y, std::shared_ptr<Tensor>& b_fft_p, std::shared_ptr<Tensor>& chirp_p, size_t X_offset, size_t X_stride, size_t Y_offset, size_t Y_stride,
int64_t axis, size_t dft_length, const Tensor* window, bool inverse, std::vector<std::complex<T>>& V,
std::vector<std::complex<T>>& temp_output) {
static const T pi = static_cast<T>(M_PI);
size_t N = static_cast<size_t>(dft_length);
size_t M = next_power_of_2(2 * N - 1);
auto dft_input_shape = std::vector<int>({1, (int)M, 2});
T scale = inverse ? 1.f / N : 1.f;
T direction = inverse ? 1.f : -1.f;
bool should_recreate_b_fft = b_fft_p->elementSize() != M * 2;
bool should_recreate_chirp = chirp_p->elementSize() != M * 2;
bool should_recreate = should_recreate_b_fft || should_recreate_chirp;
if (should_recreate) {
std::shared_ptr<Tensor> b_p(Tensor::create(dft_input_shape, X->getType()));
auto& b = *b_p;
b_fft_p.reset(Tensor::create(dft_input_shape, Y->getType()));
auto& b_fft = *b_fft_p;
chirp_p.reset(Tensor::create(dft_input_shape, X->getType()));
auto& chirp = *chirp_p;
std::complex<T>* b_data = reinterpret_cast<std::complex<T>*>(b.host<void>());
std::complex<T>* b_fft_data = reinterpret_cast<std::complex<T>*>(b_fft.host<void>());
std::complex<T>* chirp_data = reinterpret_cast<std::complex<T>*>(chirp.host<void>());
memset(reinterpret_cast<void*>(b_data), 0, b.usize());
memset(reinterpret_cast<void*>(b_fft_data), 0, b_fft.usize());
memset(reinterpret_cast<void*>(chirp_data), 0, chirp.usize());
for (size_t n = 0; n < N; n++) {
std::complex<T>& chirp_n = *(chirp_data + n);
// chirp
auto exponent = direction * pi * n * n / N;
chirp_n = std::complex<T>(cos(exponent), sin(exponent));
// b
std::complex<T>& b_n = *(b_data + n);
b_n = std::conj(chirp_n);
}
for (size_t n = M - N + 1; n < M; n++) {
std::complex<T>& b_n = *(b_data + n);
std::complex<T>& b_m_minus_n = *(b_data + M - n);
b_n = b_m_minus_n;
}
// Forward FFT radix2 for the "b" signal
// This will be cached and reused!
auto code = ((fft_radix2<T, std::complex<T>>(bn, &b, &b_fft, 0, 1, 0, 1, 1, M, nullptr,
false, false, V, temp_output)));
if (NO_ERROR != code) {
FUNC_PRINT(1);
return code;
}
}
// Get data
auto* X_data = const_cast<U*>(reinterpret_cast<const U*>(X->host<void>())) + X_offset;
auto* Y_data = reinterpret_cast<std::complex<T>*>(Y->host<void>()) + Y_offset;
float* window_data = nullptr;
if (window) {
window_data = const_cast<float*>(reinterpret_cast<const float*>(window->host<void>()));
}
std::shared_ptr<Tensor> a_p(Tensor::create(dft_input_shape, X->getType()));
auto& a = *a_p;
std::shared_ptr<Tensor> a_fft_p(Tensor::create(dft_input_shape, Y->getType()));
auto& a_fft = *a_fft_p;
std::complex<T>* a_data = reinterpret_cast<std::complex<T>*>(a.host<void>());
std::complex<T>* a_fft_data = reinterpret_cast<std::complex<T>*>(a_fft.host<void>());
std::complex<T>* b_fft_data = reinterpret_cast<std::complex<T>*>(b_fft_p->host<void>());
std::complex<T>* chirp_data = reinterpret_cast<std::complex<T>*>(chirp_p->host<void>());
memset(reinterpret_cast<void*>(a_data), 0, a.usize());
const auto& X_shape = X->shape();
size_t number_of_samples = static_cast<size_t>(X_shape[axis]);
// Prepare "a" signal
for (size_t n = 0; n < number_of_samples; n++) {
std::complex<T>& a_n = *(a_data + n);
std::complex<T>& chirp_n = *(chirp_data + n);
auto window_n = window_data ? *(window_data + n) : 1;
a_n = *(X_data + n * X_stride); // input
a_n *= window_n;
a_n *= chirp_n;
}
// Forward FFT radix2 for the "a" signal
{
auto code = ((fft_radix2<T, std::complex<T>>(bn, &a, &a_fft, 0, 1, 0, 1, 1, M, nullptr,
false, false, V, temp_output)));
if (NO_ERROR != code) {
return code;
}
}
for (size_t i = 0; i < M; i++) {
std::complex<T>& a_i = *(a_fft_data + i);
std::complex<T>& b_i = *(b_fft_data + i);
a_i *= b_i;
}
// Inverse FFT radix2 for the "a" signal
{
auto code = ((fft_radix2<T, std::complex<T>>(bn, &a_fft, &a, 0, 1, 0, 1, 1, M, nullptr,
false, true, V, temp_output)));
if (NO_ERROR != code) {
return code;
}
}
const auto& Y_shape = Y->shape();
size_t dft_output_size = static_cast<size_t>(Y_shape[(axis)]);
for (size_t i = 0; i < dft_output_size; i++) {
std::complex<T>& chirp_i = *(chirp_data + i);
std::complex<T>& out = *(Y_data + i * Y_stride);
std::complex<T>& c_i = *(a_data + i);
if (i > 0) {
// The inverse fft is computed using the same cached vandermonde matrix (V) created by the
// forward fft. This reversal causes the output to be reversed as well.
// Therefore we undo the reversal when writing the output back out.
c_i = *(a_data + M - i);
}
out = c_i * chirp_i * scale;
}
return NO_ERROR;
}
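The chirp construction implements Bluestein's identity, which rewrites an arbitrary-length DFT as a convolution that power-of-two FFTs can evaluate:

$$nk = \frac{n^2 + k^2 - (k-n)^2}{2} \;\Rightarrow\; X_k = e^{-i\pi k^2/N} \sum_{n=0}^{N-1} \left(x_n\, e^{-i\pi n^2/N}\right) e^{\,i\pi (k-n)^2/N}.$$

The parenthesized term is the "a" signal (input times chirp), the trailing factor comes from the "b" signal (conjugate chirp), and the convolution costs three radix-2 FFTs of length M >= 2N - 1. Because "b" depends only on N, its transform b_fft is cached across calls, which is exactly what the should_recreate check above exploits.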
template <typename T, typename U>
static ErrorCode discrete_fourier_transform(Backend* ctx, const Tensor* X, Tensor* Y, std::shared_ptr<Tensor>& b_fft, std::shared_ptr<Tensor>& chirp,
int64_t axis, int64_t dft_length, const Tensor* window, bool is_onesided, bool inverse,
std::vector<std::complex<T>>& V,
std::vector<std::complex<T>>& temp_output) {
// Get shape
const auto& X_shape = X->shape();
const auto& Y_shape = Y->shape();
auto batch_and_signal_rank = X->dimensions();
auto total_dfts = static_cast<size_t>(X->elementSize() / X->length(axis));
auto is_input_real = X->dimensions() == 2 || X->length(X->dimensions() - 1) == 1;
auto complex_input_factor = is_input_real ? 1 : 2;
if (X->dimensions() > 2) {
total_dfts /= (X->length(X->dimensions() - 1));
batch_and_signal_rank -= 1;
}
// Calculate x/y offsets/strides
for (size_t i = 0; i < total_dfts; i++) {
size_t X_offset = 0;
size_t X_stride = X->stride(axis) / complex_input_factor;
size_t cumulative_packed_stride = total_dfts;
size_t temp = i;
for (size_t r = 0; r < batch_and_signal_rank; r++) {
if (r == static_cast<size_t>(axis)) {
continue;
}
cumulative_packed_stride /= (X_shape[r]);
auto index = temp / cumulative_packed_stride;
temp -= (index * cumulative_packed_stride);
X_offset += index * X->stride(r) / complex_input_factor;
}
size_t Y_offset = 0;
size_t Y_stride = Y->stride(axis) / 2;
cumulative_packed_stride = total_dfts;
temp = i;
for (size_t r = 0; r < batch_and_signal_rank; r++) {
if (r == static_cast<size_t>(axis)) {
continue;
}
cumulative_packed_stride /= (X_shape[r]);
auto index = temp / cumulative_packed_stride;
temp -= (index * cumulative_packed_stride);
Y_offset += index * (size_t)(Y->stride(r) / 2);
}
if (is_power_of_2((dft_length))) {
___RETURN_IF_ERROR((fft_radix2<T, U>(ctx, X, Y, X_offset, X_stride, Y_offset, Y_stride, axis, (dft_length), window,
is_onesided, inverse, V, temp_output)));
} else {
___RETURN_IF_ERROR(
(dft_bluestein_z_chirp<T, U>(ctx, X, Y, b_fft, chirp, X_offset, X_stride, Y_offset, Y_stride, axis, (dft_length), window, inverse, V, temp_output)));
}
}
return NO_ERROR;
}
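The offset loop is a mixed-radix decomposition of the flat DFT index over every non-axis dimension. A hedged worked example (shape hypothetical): for real-valued X of shape [2, 3, 8, 1] with axis = 2, total_dfts = 6 and the contiguous element strides are {24, 8, 1, 1}; DFT number i = 4 then decomposes to index 4 / 3 = 1 along dim 0 and 4 mod 3 = 1 along dim 1, giving X_offset = 1 * 24 + 1 * 8 = 32.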
static ErrorCode discrete_fourier_transform(Backend* ctx, int64_t axis, bool is_onesided, bool inverse, Tensor* X, Tensor* dft_length, Tensor* Y) {
// Get input shape
const auto is_real_valued = is_real_valued_signal(X);
const auto is_complex_valued = is_complex_valued_signal(X);
if (axis < 0) {
axis = axis + X->dimensions();
}
int64_t number_of_samples = static_cast<int64_t>(X->length(axis));
if (dft_length) {
const auto& dft_length_shape = dft_length->shape();
number_of_samples = dft_length->host<int>()[0];
}
// Get the DFT output size. Onesided will return only the unique values!
// note: x >> 1 === std::floor(x / 2.f)
auto dft_output_size = is_onesided ? ((number_of_samples >> 1) + 1) : number_of_samples;
std::shared_ptr<Tensor> b_fft(new Tensor), chirp(new Tensor);
std::vector<std::complex<float>> V;
std::vector<std::complex<float>> temp_output;
if (is_real_valued) {
___RETURN_IF_ERROR((discrete_fourier_transform<float, float>(ctx, X, Y, b_fft, chirp, axis, number_of_samples, nullptr,
is_onesided, inverse, V, temp_output)));
} else if (is_complex_valued) {
___RETURN_IF_ERROR((discrete_fourier_transform<float, std::complex<float>>(
ctx, X, Y, b_fft, chirp, axis, number_of_samples, nullptr, is_onesided, inverse, V, temp_output)));
}
return NO_ERROR;
}
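The onesided option rests on the conjugate symmetry of a real signal's spectrum:

$$X[N-k] = \overline{X[k]}, \qquad 0 < k < N,$$

so only the $\lfloor N/2 \rfloor + 1$ non-redundant bins are computed and stored, matching the (number_of_samples >> 1) + 1 sizing above.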
template <typename T, typename U>
static ErrorCode short_time_fourier_transform(Backend* ctx, Tensor* signal, Tensor* Y, int frame_step, Tensor* window, bool is_onesided, bool /*inverse*/) {
// Attr("onesided"): default = 1
// Input(0, "signal") type = T1
// Input(1, "frame_length") type = T2
// Input(2, "window") type = T1, optional
// Input(3, "frame_step") type = T2
// Output(0, "output") type = T1
// Get input signal shape
const auto& signal_shape = signal->shape();
const auto batch_size = signal_shape[0];
const auto signal_size = signal_shape[1];
const auto signal_components = signal_shape.size() == 2 ? 1 : signal_shape[2];
// Get the frame length
int frame_length = window->length(0);
// Get window length
// Calculate the window size with preference to the window input.
const auto window_size = frame_length;
MNN_ASSERT(window_size <= signal_size);
// Calculate the number of dfts to run
const auto n_dfts =
static_cast<int64_t>(std::floor((signal_size - window_size) / static_cast<float>(frame_step))) + 1;
// Calculate the output spectra length (onesided will return only the unique values)
// note: x >> 1 === std::floor(x / 2.f)
const auto dft_output_size = is_onesided ? (window_size >> 1) + 1 : window_size;
auto Y_data = reinterpret_cast<float*>(Y->host<void>());
// Get/create the signal mutable data
auto* signal_data = const_cast<float*>(reinterpret_cast<const float*>(signal->host<void>()));
// Define tensor shapes for each dft run
const int64_t output_components = 2;
auto dft_input_shape = std::vector<int>{1, window_size, signal_components};
auto dft_output_shape = std::vector<int>{1, dft_output_size, output_components};
std::shared_ptr<Tensor> b_fft(new Tensor), chirp(new Tensor);
std::vector<std::complex<T>> V;
std::vector<std::complex<T>> temp_output;
// Tensors do not own the backing memory, so no worries on destruction
std::shared_ptr<Tensor> input(Tensor::createDevice(dft_input_shape, signal->getType()));
std::shared_ptr<Tensor> output(Tensor::createDevice(dft_output_shape, Y->getType()));
// Run each dft of each batch as if it was a real-valued batch size 1 dft operation
for (int64_t batch_idx = 0; batch_idx < batch_size; batch_idx++) {
for (int64_t i = 0; i < n_dfts; i++) {
auto input_frame_begin =
signal_data + (batch_idx * signal_size * signal_components) + (i * frame_step * signal_components);
auto output_frame_begin = Y_data + (batch_idx * n_dfts * dft_output_size * output_components) + (i * dft_output_size * output_components);
input->buffer().host = (uint8_t*)input_frame_begin;
output->buffer().host = (uint8_t*)output_frame_begin;
// Run individual dft
___RETURN_IF_ERROR((discrete_fourier_transform<T, U>(ctx, input.get(), output.get(), b_fft, chirp, 1, window_size, window, is_onesided, false, V, temp_output)));
}
}
return NO_ERROR;
} }
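A hedged sizing example (numbers illustrative): a one-second 16 kHz signal with a 400-sample window and frame_step = 160 gives n_dfts = floor((16000 - 400) / 160) + 1 = 98 frames, each with (400 >> 1) + 1 = 201 onesided bins, i.e. an output of shape [batch, 98, 201, 2].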
CPUStft::CPUStft(Backend* backend, int nfft, int hop_length, bool abs) CPUStft::CPUStft(Backend* backend, bool abs)
: Execution(backend), mNfft(nfft), mHopLength(hop_length), mAbs(abs) { : Execution(backend), mAbs(abs) {
if (gSinTable.empty() || gCosTable.empty()) { // nothing to do
gSinTable.resize(nfft);
gCosTable.resize(nfft);
for (int i = 0; i < nfft; i++) {
float angle = 2 * M_PI * i / nfft;
gSinTable[i] = sinf(angle);
gCosTable[i] = cosf(angle);
}
}
} }
ErrorCode CPUStft::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) { ErrorCode CPUStft::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
auto cpuBn = static_cast<CPUBackend*>(backend());
mTmpFrames.buffer().dim[0].extent = cpuBn->threadNumber();
mTmpFrames.buffer().dim[1].extent = mNfft;
TensorUtils::getDescribe(&mTmpFrames)->dimensionFormat = MNN_DATA_FORMAT_NHWC;
mTmpFrames.buffer().dimensions = 2;
mTmpFrames.buffer().type = inputs[0]->getType();
backend()->onAcquireBuffer(&mTmpFrames, Backend::DYNAMIC);
backend()->onReleaseBuffer(&mTmpFrames, Backend::DYNAMIC);
return NO_ERROR; return NO_ERROR;
} }
ErrorCode CPUStft::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) { ErrorCode CPUStft::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
const float* sample = inputs[0]->host<float>(); auto signal = inputs[0];
const float* window = inputs[1]->host<float>(); const auto is_real_valued = is_real_valued_signal(signal);
float* buffer = mTmpFrames.host<float>(); const auto is_complex_valued = is_complex_valued_signal(signal);
float* output = outputs[0]->host<float>(); int frameStep = inputs[1]->host<int>()[0];
auto outputShape = outputs[0]->shape(); if (is_real_valued) {
int frames = outputShape[0]; ___RETURN_IF_ERROR((short_time_fourier_transform<float, float>(backend(), inputs[0], outputs[0], frameStep, inputs[2], mAbs, false)));
int col = outputShape[1]; } else if (is_complex_valued) {
auto cpuBn = static_cast<CPUBackend*>(backend()); ___RETURN_IF_ERROR((short_time_fourier_transform<float, std::complex<float>>(backend(), inputs[0], outputs[0], frameStep, inputs[2], mAbs, false)));
int threadNum = cpuBn->threadNumber(); } else {
// div frames to threadNum MNN_ASSERT(false);
int threadNumber = std::min(threadNum, frames); }
int sizeDivide = frames / threadNumber;
MNN_CONCURRENCY_BEGIN(tId, threadNumber) {
int number = sizeDivide;
if (tId == threadNumber - 1) {
number = frames - tId * sizeDivide;
}
for (int i = tId * sizeDivide; i < tId * sizeDivide + number; ++i) {
MNNDftAbs(sample + i * mHopLength, window, output + i * col, buffer + tId * mNfft, mNfft);
}
};
MNN_CONCURRENCY_END();
return NO_ERROR; return NO_ERROR;
} }
@ -96,10 +500,9 @@ public:
virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
const MNN::Op* op, Backend* backend) const { const MNN::Op* op, Backend* backend) const {
auto stft = op->main_as_StftParam(); auto stft = op->main_as_StftParam();
return new CPUStft(backend, stft->n_fft(), stft->hop_length(), stft->abs()); return new CPUStft(backend, stft->abs());
} }
}; };
REGISTER_CPU_OP_CREATOR_AUDIO(CPUStftCreator, OpType_Stft); REGISTER_CPU_OP_CREATOR(CPUStftCreator, OpType_Stft);
} // namespace MNN } // namespace MNN
#endif // MNN_BUILD_AUDIO
View File
@ -6,7 +6,6 @@
// Copyright © 2018, Alibaba Group Holding Limited // Copyright © 2018, Alibaba Group Holding Limited
// //
#ifdef MNN_BUILD_AUDIO
#ifndef CPUStft_hpp #ifndef CPUStft_hpp
#define CPUStft_hpp #define CPUStft_hpp
@ -15,19 +14,14 @@
namespace MNN { namespace MNN {
class CPUStft : public Execution { class CPUStft : public Execution {
public: public:
CPUStft(Backend *backend, int nfft, int hop_length, bool abs); CPUStft(Backend *backend, bool abs);
virtual ~CPUStft() = default; virtual ~CPUStft() = default;
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
static std::vector<float> gSinTable;
static std::vector<float> gCosTable;
private: private:
int mNfft, mHopLength;
bool mAbs; bool mAbs;
Tensor mTmpFrames;
}; };
} // namespace MNN } // namespace MNN
#endif /* CPUStft.hpp */ #endif /* CPUStft.hpp */
#endif // MNN_BUILD_AUDIO
View File
@ -39,7 +39,7 @@ struct QuanPostTreatParameters {
//Auto: r0: dst*, r1: src*, r2:weight*, r3: src_depth_quad //Auto: r0: dst*, r1: src*, r2:weight*, r3: src_depth_quad
// Load from sp: r4: dst_step, r5: dst_depth_quad, r6: post, r10: real // Load from sp: r4: dst_step, r5: dst_depth_quad, r6: post, r10: real
// Load from post: r8: scale, lr: bias, r7: maxValue, r6: minValue // Load from post: lr: bias, r7: maxValue, r6: minValue
push {r4-r8, r10, lr} // avoid touching the platform register r9 push {r4-r8, r10, lr} // avoid touching the platform register r9
@ -47,28 +47,25 @@ ldr r4, [sp, #28]
ldr r5, [sp, #32] ldr r5, [sp, #32]
ldr r6, [sp, #36] ldr r6, [sp, #36]
ldr r10, [sp, #40] ldr r10, [sp, #40]
ldr r8, [r6, #0]
ldr lr, [r6, #4] ldr lr, [r6, #4]
vpush {q4-q7} vpush {q4-q7}
sub sp, sp, #36 sub sp, sp, #40
ldr r7, [r6, #16] // r7: useInt8 ldr r7, [r6, #16] // r7: useInt8
ldr r12, [r6, #28] // srcKernelSum ldr r8, [r6, #28] // srcKernelSum
str r12, [sp, #4]
ldr r12, [r6, #32] // weightBias
str r12, [sp, #8]
ldr r12, [r6, #36] // f32minmax ldr r12, [r6, #36] // f32minmax
str r12, [sp, #12] str r12, [sp, #12]
ldr r12, [r6, #8] // int8 max ldr r12, [r6, #8] // int8 max
str r12, [sp, #16] str r12, [sp, #16]
ldr r12, [r6, #12] // int8 min ldr r12, [r6, #12] // int8 min
str r12, [sp, #20] str r12, [sp, #20]
lsl r12, r3, #6 // weight_stride = src_depth_quad*LP*HP
str r12, [sp, #24]
ldr r12, [r6, #48] // extraScale ldr r12, [r6, #48] // extraScale
str r12, [sp, #28] str r12, [sp, #28]
ldr r12, [r6, #56] // accumBuffer
str r12, [sp, #32]
str r12, [sp, #36]
Start: Start:
cmp r10, #2 cmp r10, #2
@ -76,7 +73,6 @@ blt L1LoopDz
L2LoopDz: L2LoopDz:
mov r10, r1 mov r10, r1
str r2, [sp, #32] // store weight ptr
subs r12, r3, #1 subs r12, r3, #1
// first four output // first four output
vld1.8 {q2}, [r1]! vld1.8 {q2}, [r1]!
@ -153,7 +149,7 @@ L2LoopDz:
L2LoopSzEnd: L2LoopSzEnd:
L2Quan: L2Quan:
vld1.f32 {q5}, [r8]! // scale vld1.f32 {q5}, [r2]! // scale
vpadd.s32 d16, d16, d17 vpadd.s32 d16, d16, d17
vpadd.s32 d20, d20, d21 vpadd.s32 d20, d20, d21
@ -188,12 +184,10 @@ L2LoopDz:
vmulq.f32 q1, q1, d10[1] vmulq.f32 q1, q1, d10[1]
L2_MLA: L2_MLA:
ldr r6, [sp, #4] // srcKernelSum vld1.f32 {d12[0]}, [r8]! // tile 0
vld1.f32 {d12[0]}, [r6]! // tile 0 vld1.f32 {d12[1]}, [r8] // tile 1
vld1.f32 {d12[1]}, [r6] // tile 1 sub r8, r8, #4
ldr r6, [sp, #8] // weightBias vld1.f32 {q7}, [r2]!
vld1.f32 {q7}, [r6]!
str r6, [sp, #8] // update next 4 weightBias
vmla.f32 q0, q7, d12[0] vmla.f32 q0, q7, d12[0]
vmla.f32 q1, q7, d12[1] vmla.f32 q1, q7, d12[1]
@ -207,12 +201,24 @@ L2LoopDz:
vld1.f32 {q4}, [lr]! // bias vld1.f32 {q4}, [lr]! // bias
vadd.f32 q0, q0, q4 // bias vadd.f32 q0, q0, q4 // bias
vadd.f32 q1, q1, q4 vadd.f32 q1, q1, q4
b L2_POST cmp r0, #0
bne L2_POST
b L2_BUFFER
L2_ADD_DSTV: L2_ADD_DSTV:
vld1.f32 {q4, q5}, [r0] ldr r6, [sp, #32] // accumBuffer
vld1.f32 {q4, q5}, [r6]!
vadd.f32 q0, q0, q4 vadd.f32 q0, q0, q4
vadd.f32 q1, q1, q5 vadd.f32 q1, q1, q5
str r6, [sp, #32]
cmp r0, #0
bne L2_POST
L2_BUFFER:
ldr r6, [sp, #36] // accumBuffer
vst1.f32 {q0, q1}, [r6]!
str r6, [sp, #36]
b L2LoopCheck
L2_POST: L2_POST:
ldr r6, [sp, #12] // fp32 minmax ldr r6, [sp, #12] // fp32 minmax
@ -265,16 +271,12 @@ L2LoopDz:
L2LoopCheck: L2LoopCheck:
subs r5, r5, #1 subs r5, r5, #1
mov r1, r10 mov r1, r10
ldr r2, [sp, #32] // origin weight ptr
ldr r6, [sp, #24] // weight stride
add r2, r2, r6 // next oc4 weight ptr
bne L2LoopDz bne L2LoopDz
b End b End
L1LoopDz: L1LoopDz:
mov r10, r1 mov r10, r1
str r2, [sp, #32] // store weight ptr
subs r12, r3, #1 subs r12, r3, #1
// first four output // first four output
vld1.8 {q2}, [r1]! vld1.8 {q2}, [r1]!
@ -321,8 +323,7 @@ L1LoopDz:
L1LoopSzEnd: L1LoopSzEnd:
L1Quan: L1Quan:
//vld1.f32 {q4}, [lr]! // bias vld1.f32 {q5}, [r2]! // scale
vld1.f32 {q5}, [r8]! // scale
vpadd.s32 d16, d16, d17 vpadd.s32 d16, d16, d17
vpadd.s32 d20, d20, d21 vpadd.s32 d20, d20, d21
@ -344,13 +345,9 @@ L1LoopDz:
vmulq.f32 q0, q0, d10[0] vmulq.f32 q0, q0, d10[0]
L1_MLA: L1_MLA:
ldr r6, [sp, #4] // srcKernelSum vld1.f32 {d12[0]}, [r8] // tile 0
vld1.f32 {d12[0]}, [r6] // tile 0 vld1.f32 {q7}, [r2]!
ldr r6, [sp, #8] // weightBias
vld1.f32 {q7}, [r6]!
str r6, [sp, #8] // update next 4 weightBias
vmla.f32 q0, q7, d12[0] vmla.f32 q0, q7, d12[0]
//vadd.f32 q0, q0, q4
cmp r7, #0 cmp r7, #0
bne L1QuanUseInt8 bne L1QuanUseInt8
@ -359,11 +356,23 @@ L1LoopDz:
beq L1_ADD_DSTV beq L1_ADD_DSTV
vld1.f32 {q4}, [lr]! // bias vld1.f32 {q4}, [lr]! // bias
vadd.f32 q0, q0, q4 vadd.f32 q0, q0, q4
b L1_POST cmp r0, #0
bne L1_POST
b L1_BUFFER
L1_ADD_DSTV: L1_ADD_DSTV:
vld1.f32 {q4}, [r0] ldr r6, [sp, #32] // accumBuffer
vld1.f32 {q4}, [r6]!
vadd.f32 q0, q0, q4 vadd.f32 q0, q0, q4
str r6, [sp, #32]
cmp r0, #0
bne L1_POST
L1_BUFFER:
ldr r6, [sp, #36] // accumBuffer
vst1.f32 {q0}, [r6]!
str r6, [sp, #36]
b L1LoopCheck
L1_POST: L1_POST:
ldr r6, [sp, #12] // fp32 minmax ldr r6, [sp, #12] // fp32 minmax
@ -406,13 +415,10 @@ L1LoopDz:
L1LoopCheck: L1LoopCheck:
subs r5, r5, #1 subs r5, r5, #1
mov r1, r10 mov r1, r10
ldr r2, [sp, #32] // origin weight ptr
ldr r6, [sp, #24] // weight stride
add r2, r2, r6 // next oc4 weight ptr
bne L1LoopDz bne L1LoopDz
End: End:
add sp, sp, #36 add sp, sp, #40
vpop {q4-q7} vpop {q4-q7}
pop {r4-r8, r10, pc} pop {r4-r8, r10, pc}
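A hedged reading of the new L2_ADD_DSTV / L2_BUFFER control flow above, reconstructed from the branches (C-style pseudocode; names are hypothetical, not MNN symbols):

// Per-tile epilogue, after the srcKernelSum * weightBias correction:
// if (bias)  acc += bias;              // first quantization block
// else       acc += *accumRead++;      // continue a partial block sum
// if (dst)   { clamp to f32minmax; store to dst; }   // final block
// else       { *accumWrite++ = acc; }  // park the partial sum in accumBuffer

In other words, multi-block quantized GEMM now accumulates in float via accumBuffer and only writes the destination on the last block; the old per-oc4 weight-pointer save/restore also appears to go away because scale and weightBias are streamed inline from r2.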
View File
@ -29,7 +29,7 @@ asm_function MNNGemmInt8AddBiasScale_16x4_Unit_FAST
//Auto: r0: dst*, r1: src*, r2:weight*, r3: src_depth_quad //Auto: r0: dst*, r1: src*, r2:weight*, r3: src_depth_quad
// Load from sp: r4: dst_step, r5: dst_depth_quad, r6: post, r10: real // Load from sp: r4: dst_step, r5: dst_depth_quad, r6: post, r10: real
// Load from post: r8: scale, lr: bias, r7: maxValue, r6: minValue // Load from post: lr: bias, r7: maxValue, r6: minValue
push {r4-r8, r10, lr} // avoid touching the platform register r9 push {r4-r8, r10, lr} // avoid touching the platform register r9
@ -38,18 +38,14 @@ ldr r5, [sp, #32]
ldr r6, [sp, #36] ldr r6, [sp, #36]
ldr r10, [sp, #40] ldr r10, [sp, #40]
ldr r8, [r6, #0]
ldr lr, [r6, #4] ldr lr, [r6, #4]
vpush {q4-q7} vpush {q4-q7}
sub sp, sp, #36 sub sp, sp, #24
// Only int8 output use this kernel. // Only int8 output use this kernel.
ldr r12, [r6, #28] // srcKernelSum ldr r8, [r6, #28] // srcKernelSum
str r12, [sp, #4]
ldr r12, [r6, #32] // weightBias
str r12, [sp, #8]
ldr r12, [r6, #36] // f32minmax ldr r12, [r6, #36] // f32minmax
str r12, [sp, #12] str r12, [sp, #12]
ldr r12, [r6, #8] // int8 max ldr r12, [r6, #8] // int8 max
@ -130,7 +126,7 @@ L2LoopDz:
L2Quan: L2Quan:
vld1.f32 {q14}, [lr]! // bias vld1.f32 {q14}, [lr]! // bias
vld1.f32 {q15}, [r8]! // scale vld1.f32 {q15}, [r2]! // scale
vpadd.s32 d20, d0, d1 vpadd.s32 d20, d0, d1
vpadd.s32 d21, d2, d3 vpadd.s32 d21, d2, d3
@ -152,12 +148,10 @@ L2LoopDz:
vmulq.f32 q0, q0, q15 // mul scale vmulq.f32 q0, q0, q15 // mul scale
vmulq.f32 q1, q1, q15 vmulq.f32 q1, q1, q15
ldr r6, [sp, #4] // srcKernelSum vld1.f32 {d12[0]}, [r8]! // tile 0
vld1.f32 {d12[0]}, [r6]! // tile 0 vld1.f32 {d12[1]}, [r8] // tile 1
vld1.f32 {d12[1]}, [r6] // tile 1 vld1.f32 {q7}, [r2]!
ldr r6, [sp, #8] // weightBias sub r8, r8, #4
vld1.f32 {q7}, [r6]!
str r6, [sp, #8] // update next 4 weightBias
vmla.f32 q0, q7, d12[0] // add srcKernelSum x weightBias vmla.f32 q0, q7, d12[0] // add srcKernelSum x weightBias
vmla.f32 q1, q7, d12[1] vmla.f32 q1, q7, d12[1]
@ -250,7 +244,7 @@ L1LoopDz:
vld1.f32 {q14}, [lr]! vld1.f32 {q14}, [lr]!
vpadd.s32 d20, d0, d1 vpadd.s32 d20, d0, d1
vpadd.s32 d21, d2, d3 vpadd.s32 d21, d2, d3
vld1.f32 {q15}, [r8]! vld1.f32 {q15}, [r2]!
vpadd.s32 d22, d4, d5 vpadd.s32 d22, d4, d5
vpadd.s32 d23, d6, d7 vpadd.s32 d23, d6, d7
@ -261,11 +255,8 @@ L1LoopDz:
vcvt.f32.s32 q0, q8 vcvt.f32.s32 q0, q8
vmulq.f32 q0, q0, q15 vmulq.f32 q0, q0, q15
ldr r6, [sp, #4] // srcKernelSum vld1.f32 {d12[0]}, [r8] // tile 0
vld1.f32 {d12[0]}, [r6] // tile 0 vld1.f32 {q7}, [r2]!
ldr r6, [sp, #8] // weightBias
vld1.f32 {q7}, [r6]!
str r6, [sp, #8] // update next 4 weightBias
vmla.f32 q0, q7, d12[0] vmla.f32 q0, q7, d12[0]
vadd.f32 q0, q0, q14 // add bias vadd.f32 q0, q0, q14 // add bias
@ -296,7 +287,7 @@ L1LoopCheck:
bne L1LoopDz bne L1LoopDz
End: End:
add sp, sp, #36 add sp, sp, #24
vpop {q4-q7} vpop {q4-q7}
pop {r4-r8, r10, pc} pop {r4-r8, r10, pc}
View File
@ -28,9 +28,10 @@ struct QuanPostTreatParameters {
float* weightQuanBias; float* weightQuanBias;
float* fp32minmax; float* fp32minmax;
ssize_t blockNum = 1; ssize_t blockNum = 1;
const int32_t* bias; const int32_t* bias = nullptr;
const float* extraScale = nullptr; const float* extraScale = nullptr;
const float* extraBias = nullptr; const float* extraBias = nullptr;
float* accumBuffer = nullptr;
}; };
*/ */
@ -47,11 +48,10 @@ ldr r4, [sp, #28]
ldr r5, [sp, #32] ldr r5, [sp, #32]
ldr r6, [sp, #36] ldr r6, [sp, #36]
ldr r10, [sp, #40] ldr r10, [sp, #40]
ldr r8, [r6, #0]
ldr lr, [r6, #4] ldr lr, [r6, #4]
vpush {q4-q7} vpush {q4-q7}
sub sp, sp, #36 sub sp, sp, #32
// Branch1: input is int8_t, output is float32, DO NOT USE "scale". // Branch1: input is int8_t, output is float32, DO NOT USE "scale".
// Branch2: input is int8_t, output is float32. USE "scale", DO NOT USE "minValue" and "maxValue". // Branch2: input is int8_t, output is float32. USE "scale", DO NOT USE "minValue" and "maxValue".
// Branch3: input is int8_t, output is int8_t. USE "scale", "minValue" and "maxValue". // Branch3: input is int8_t, output is int8_t. USE "scale", "minValue" and "maxValue".
@ -59,16 +59,14 @@ sub sp, sp, #36
ldr r7, [r6, #16] // r7: useInt8 ldr r7, [r6, #16] // r7: useInt8
ldr r12, [r6, #28] // srcKernelSum ldr r8, [r6, #28] // srcKernelSum
str r12, [sp, #4]
ldr r12, [r6, #32] // weightBias
str r12, [sp, #8]
ldr r12, [r6, #36] // f32minmax ldr r12, [r6, #36] // f32minmax
str r12, [sp, #12] str r12, [sp, #12]
lsl r12, r3, #5 // weight_stride = src_depth_quad*LP*HP
str r12, [sp, #16]
ldr r12, [r6, #48] // extraScale ldr r12, [r6, #48] // extraScale
str r12, [sp, #20] str r12, [sp, #20]
ldr r12, [r6, #56] // accumBuffer
str r12, [sp, #24]
str r12, [sp, #28]
Start: Start:
cmp r10, #2 cmp r10, #2
@ -76,7 +74,6 @@ blt L1LoopDz
L2LoopDz: L2LoopDz:
mov r10, r1 mov r10, r1
str r2, [sp, #24] // store weight ptr
subs r12, r3, #1 subs r12, r3, #1
// first four output // first four output
vld1.8 {q2}, [r1]! vld1.8 {q2}, [r1]!
@ -166,7 +163,7 @@ L2LoopDz:
L2LoopSzEnd: L2LoopSzEnd:
L2Quan: L2Quan:
vld1.f32 {q5}, [r8]! // scale vld1.f32 {q5}, [r2]! // scale
vpadd.s32 d16, d16, d17 vpadd.s32 d16, d16, d17
vpadd.s32 d20, d20, d21 vpadd.s32 d20, d20, d21
@ -185,9 +182,6 @@ L2LoopDz:
vpadd.s32 d18, d24, d26 vpadd.s32 d18, d24, d26
vpadd.s32 d19, d28, d30 vpadd.s32 d19, d28, d30
// vaddq.s32 q0, q8, q4 // add bias
// vaddq.s32 q1, q9, q4
vcvt.f32.s32 q0, q8 vcvt.f32.s32 q0, q8
vcvt.f32.s32 q1, q9 vcvt.f32.s32 q1, q9
@ -204,12 +198,10 @@ L2LoopDz:
vmulq.f32 q1, q1, d10[1] vmulq.f32 q1, q1, d10[1]
L2_MLA: L2_MLA:
ldr r6, [sp, #4] // srcKernelSum vld1.f32 {d12[0]}, [r8]! // tile 0
vld1.f32 {d12[0]}, [r6]! // tile 0 vld1.f32 {d12[1]}, [r8] // tile 1
vld1.f32 {d12[1]}, [r6] // tile 1 sub r8, r8, #4
ldr r6, [sp, #8] // weightBias vld1.f32 {q7}, [r2]!
vld1.f32 {q7}, [r6]!
str r6, [sp, #8] // update next 4 weightBias
vmla.f32 q0, q7, d12[0] vmla.f32 q0, q7, d12[0]
vmla.f32 q1, q7, d12[1] vmla.f32 q1, q7, d12[1]
@ -220,12 +212,24 @@ L2LoopDz:
vld1.f32 {q4}, [lr]! // bias vld1.f32 {q4}, [lr]! // bias
vadd.f32 q0, q0, q4 // bias vadd.f32 q0, q0, q4 // bias
vadd.f32 q1, q1, q4 vadd.f32 q1, q1, q4
b L2_POST cmp r0, #0
bne L2_POST
b L2_BUFFER
L2_ADD_DSTV: L2_ADD_DSTV:
vld1.f32 {q4, q5}, [r0] ldr r6, [sp, #24] // accumBuffer
vld1.f32 {q4, q5}, [r6]!
vadd.f32 q0, q0, q4 vadd.f32 q0, q0, q4
vadd.f32 q1, q1, q5 vadd.f32 q1, q1, q5
str r6, [sp, #24]
cmp r0, #0
bne L2_POST
L2_BUFFER:
ldr r6, [sp, #28] // accumBuffer
vst1.f32 {q0, q1}, [r6]!
str r6, [sp, #28]
b L2LoopCheck
L2_POST: L2_POST:
ldr r6, [sp, #12] // fp32 minmax ldr r6, [sp, #12] // fp32 minmax
@ -246,16 +250,12 @@ L2LoopDz:
L2LoopCheck: L2LoopCheck:
subs r5, r5, #1 subs r5, r5, #1
mov r1, r10 mov r1, r10
ldr r2, [sp, #24] // origin weight ptr
ldr r6, [sp, #16] // weight stride
add r2, r2, r6 // next oc4 weight ptr
bne L2LoopDz bne L2LoopDz
b End b End
L1LoopDz: L1LoopDz:
mov r10, r1 mov r10, r1
str r2, [sp, #24] // store weight ptr
subs r12, r3, #1 subs r12, r3, #1
// first four output // first four output
vld1.8 {q2}, [r1]! vld1.8 {q2}, [r1]!
@ -316,8 +316,7 @@ L1LoopDz:
L1LoopSzEnd: L1LoopSzEnd:
L1Quan: L1Quan:
//vld1.f32 {q4}, [lr]! // bias vld1.f32 {q5}, [r2]! // scale
vld1.f32 {q5}, [r8]! // scale
vpadd.s32 d16, d16, d17 vpadd.s32 d16, d16, d17
vpadd.s32 d20, d20, d21 vpadd.s32 d20, d20, d21
@ -339,23 +338,31 @@ L1LoopDz:
vmulq.f32 q0, q0, d10[0] vmulq.f32 q0, q0, d10[0]
L1_MLA: L1_MLA:
ldr r6, [sp, #4] // srcKernelSum vld1.f32 {d12[0]}, [r8] // tile 0
vld1.f32 {d12[0]}, [r6] // tile 0 vld1.f32 {q7}, [r2]!
ldr r6, [sp, #8] // weightBias
vld1.f32 {q7}, [r6]!
str r6, [sp, #8] // update next 4 weightBias
vmla.f32 q0, q7, d12[0] vmla.f32 q0, q7, d12[0]
//vadd.f32 q0, q0, q4
cmp lr, #0 cmp lr, #0
beq L1_ADD_DSTV beq L1_ADD_DSTV
vld1.f32 {q4}, [lr]! // bias vld1.f32 {q4}, [lr]! // bias
vadd.f32 q0, q0, q4 vadd.f32 q0, q0, q4
b L1_POST cmp r0, #0
bne L1_POST
b L1_BUFFER
L1_ADD_DSTV: L1_ADD_DSTV:
vld1.f32 {q4}, [r0] ldr r6, [sp, #24] // accumBuffer
vld1.f32 {q4}, [r6]!
vadd.f32 q0, q0, q4 vadd.f32 q0, q0, q4
str r6, [sp, #24]
cmp r0, #0
bne L1_POST
L1_BUFFER:
ldr r6, [sp, #28] // accumBuffer
vst1.f32 {q0}, [r6]!
str r6, [sp, #28]
b L1LoopCheck
L1_POST: L1_POST:
ldr r6, [sp, #12] // fp32 minmax ldr r6, [sp, #12] // fp32 minmax
@ -374,13 +381,10 @@ L1LoopDz:
L1LoopCheck: L1LoopCheck:
subs r5, r5, #1 subs r5, r5, #1
mov r1, r10 mov r1, r10
ldr r2, [sp, #24] // origin weight ptr
ldr r6, [sp, #16] // weight stride
add r2, r2, r6 // next oc4 weight ptr
bne L1LoopDz bne L1LoopDz
End: End:
add sp, sp, #36 add sp, sp, #32
vpop {q4-q7} vpop {q4-q7}
pop {r4-r8, r10, pc} pop {r4-r8, r10, pc}
@ -96,47 +96,39 @@ struct QuanPostTreatParameters {
// x5: dst_depth_quad, x6: post, x7: realSize // x5: dst_depth_quad, x6: post, x7: realSize
//Load from post: //Load from post:
// x7: scale, x10: bias, w11: maxValue, w6: minValue, w13: UseInt8, x14: srcKernelSum, x12: weightQuantBias // x10: bias, w11: maxValue, w6: minValue, w13: UseInt8, x14: srcKernelSum
mov x8, x7
mov x15, x6
ldr x7, [x15, #0]
ldr x10, [x15, #8]
ldr w11, [x15, #16]
ldr w6, [x15, #20]
ldr w13, [x15, #24]
ldr x14, [x15, #40] // srcKernelSum
ldr x12, [x15, #48] // weightQuantBias
stp d14, d15, [sp, #(-16 * 8)]! ldr x10, [x6, #8]
ldr w11, [x6, #16]
ldr w13, [x6, #24]
ldr x14, [x6, #40] // srcKernelSum
stp d14, d15, [sp, #(-16 * 6)]!
stp d12, d13, [sp, #(16 * 1)] stp d12, d13, [sp, #(16 * 1)]
stp d10, d11, [sp, #(16 * 2)] stp d10, d11, [sp, #(16 * 2)]
stp d8, d9, [sp, #(16 * 3)] stp d8, d9, [sp, #(16 * 3)]
stp x19, x20, [sp, #(16 * 4)] stp x19, x20, [sp, #(16 * 4)]
stp x21, x22, [sp, #(16 * 5)] stp x23, x24, [sp, #(16 * 5)]
stp x23, x24, [sp, #(16 * 6)]
ldr x19, [x15, #56] // fp32 min max ldr x19, [x6, #56] // fp32 min max
ldr x21, [x15, #64] // blockNum ldr x23, [x6, #80] // extraScale
ldr x23, [x15, #80] // extraScale ldr x15, [x6, #96] // accumBuffer
lsl x21, x3, #6 // src_depth_quad* SRC_UNIT * UNIT * sizeof(int8_t) ldr w6, [x6, #20] // minValue
add x20, x19, #4 add x20, x19, #4
lsl x24, x8, #4 // eDest * SRC_UNIT lsl x24, x7, #4 // eDest * SRC_UNIT
Start: Start:
cmp x8, #3 cmp x7, #3
beq L3Dz beq L3Dz
cmp x8, #2 cmp x7, #2
beq L2Dz beq L2Dz
cmp x8, #1 cmp x7, #1
beq L1Dz beq L1Dz
cmp w13, #1 mov x7, x15
bne L4LoopDz
L4LoopDz: L4LoopDz:
mov x8, x1 mov x8, x1
mov x22, x2
ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x2], #64 ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x2], #64
ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [x1], #64 ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [x1], #64
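The prologue loads above imply the layout of the post-treatment parameter block. A hedged C++ reconstruction (field names from this diff's comments, byte offsets from the literal #imm values; the offsets are shown as comments and this declaration is not the authoritative MNN struct):

#include <cstdint>

struct QuanPostParamsSketch {
    const float*   scale;           // +0   (old code; scale now travels with the weight stream)
    const float*   bias;            // +8   -> x10
    int32_t        maxValue;        // +16  -> w11
    int32_t        minValue;        // +20  -> w6
    int32_t        useInt8;         // +24  -> w13
    const float*   srcKernelSum;    // +40  -> x14
    const float*   weightQuantBias; // +48  (old code; also folded into the weight stream)
    const float*   fp32minmax;      // +56  -> x19
    int32_t        blockNum;        // +64  (loaded only by the old code)
    const float*   extraScale;      // +80  -> x23
    float*         accumBuffer;     // +96  -> x15 (new: carries partial sums between blocks)
};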
@ -276,9 +268,9 @@ L4LoopDz:
addp v15.4s, v10.4s, v11.4s addp v15.4s, v10.4s, v11.4s
L4Quan: L4Quan:
ld1 {v1.4s}, [x7], #16 // scalefuse ld1 {v1.4s}, [x2], #16 // scalefuse
ld1 {v20.4s}, [x14] // srcKernelSum ld1 {v20.4s}, [x14] // srcKernelSum
ld1 {v21.4s}, [x12], #16 // weightQuanZero ld1 {v21.4s}, [x2], #16 // weightQuanZero
scvtf v4.4s, v12.4s scvtf v4.4s, v12.4s
scvtf v5.4s, v13.4s scvtf v5.4s, v13.4s
@ -307,14 +299,20 @@ L4LoopDz:
fadd v5.4s, v5.4s, v0.4s fadd v5.4s, v5.4s, v0.4s
fadd v6.4s, v6.4s, v0.4s fadd v6.4s, v6.4s, v0.4s
fadd v7.4s, v7.4s, v0.4s fadd v7.4s, v7.4s, v0.4s
b L4_POST cbnz x0, L4_POST
b L4_BUFFER
L4_ADD_DSTV: L4_ADD_DSTV:
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0] ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x7], #64
fadd v4.4s, v4.4s, v8.4s fadd v4.4s, v4.4s, v8.4s
fadd v5.4s, v5.4s, v9.4s fadd v5.4s, v5.4s, v9.4s
fadd v6.4s, v6.4s, v10.4s fadd v6.4s, v6.4s, v10.4s
fadd v7.4s, v7.4s, v11.4s fadd v7.4s, v7.4s, v11.4s
cbnz x0, L4_POST
L4_BUFFER:
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x15], #64
b L4LoopCheck
L4_POST: L4_POST:
cbz x19, L4_STORE cbz x19, L4_STORE
@ -354,18 +352,17 @@ L4LoopDz:
L4LoopCheck: L4LoopCheck:
subs x5, x5, #1 subs x5, x5, #1
mov x1, x8 mov x1, x8
add x2, x22, x21
bne L4LoopDz bne L4LoopDz
b End b End
L3Dz: L3Dz:
mov x7, x15
cmp w13, #1 cmp w13, #1
bne L3LoopDz bne L3LoopDz
sub x4, x4, #8 sub x4, x4, #8
L3LoopDz: L3LoopDz:
mov x8, x1 mov x8, x1
mov x22, x2
ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x2], #64 ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x2], #64
ld1 {v4.16b, v5.16b, v6.16b}, [x1], x24 ld1 {v4.16b, v5.16b, v6.16b}, [x1], x24
@ -479,10 +476,10 @@ L3LoopDz:
addp v14.4s, v8.4s, v9.4s addp v14.4s, v8.4s, v9.4s
L3Quan: L3Quan:
ld1 {v1.4s}, [x7], #16 ld1 {v1.4s}, [x2], #16
ld1 {v20.d}[0], [x14], #8 // srcKernelSum ld1 {v20.d}[0], [x14], #8 // srcKernelSum
ld1 {v20.s}[2], [x14] ld1 {v20.s}[2], [x14]
ld1 {v21.4s}, [x12], #16 // weightQuanZero ld1 {v21.4s}, [x2], #16 // weightQuanZero
scvtf v4.4s, v12.4s scvtf v4.4s, v12.4s
scvtf v5.4s, v13.4s scvtf v5.4s, v13.4s
@ -512,13 +509,19 @@ L3LoopDz:
fadd v4.4s, v4.4s, v0.4s fadd v4.4s, v4.4s, v0.4s
fadd v5.4s, v5.4s, v0.4s fadd v5.4s, v5.4s, v0.4s
fadd v6.4s, v6.4s, v0.4s fadd v6.4s, v6.4s, v0.4s
b L3_POST cbnz x0, L3_POST
b L3_BUFFER
L3_ADD_DSTV: L3_ADD_DSTV:
ld1 {v0.4s, v1.4s, v2.4s}, [x0] ld1 {v8.4s, v9.4s, v10.4s}, [x7], #48
fadd v4.4s, v4.4s, v0.4s fadd v4.4s, v4.4s, v8.4s
fadd v5.4s, v5.4s, v1.4s fadd v5.4s, v5.4s, v9.4s
fadd v6.4s, v6.4s, v2.4s fadd v6.4s, v6.4s, v10.4s
cbnz x0, L3_POST
L3_BUFFER:
st1 {v4.4s, v5.4s, v6.4s}, [x15], #48
b L3LoopCheck
L3_POST: L3_POST:
cbz x19, L3_STORE cbz x19, L3_STORE
@ -559,15 +562,14 @@ L3LoopDz:
L3LoopCheck: L3LoopCheck:
subs x5, x5, #1 subs x5, x5, #1
mov x1, x8 mov x1, x8
add x2, x22, x21
bne L3LoopDz bne L3LoopDz
b End b End
L2Dz: L2Dz:
mov x7, x15
L2LoopDz: L2LoopDz:
mov x8, x1 mov x8, x1
mov x22, x2
ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x2], #64 ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x2], #64
ld1 {v4.16b, v5.16b}, [x1], x24 ld1 {v4.16b, v5.16b}, [x1], x24
@ -649,9 +651,9 @@ L2LoopDz:
addp v13.4s, v6.4s, v7.4s addp v13.4s, v6.4s, v7.4s
L2Quan: L2Quan:
ld1 {v1.4s}, [x7], #16 ld1 {v1.4s}, [x2], #16
ld1 {v20.d}[0], [x14] // srcKernelSum ld1 {v20.d}[0], [x14] // srcKernelSum
ld1 {v21.4s}, [x12], #16 // weightQuanZero ld1 {v21.4s}, [x2], #16 // weightQuanZero
scvtf v4.4s, v12.4s scvtf v4.4s, v12.4s
scvtf v5.4s, v13.4s scvtf v5.4s, v13.4s
@ -674,12 +676,18 @@ L2LoopDz:
ld1 {v0.4s}, [x10], #16 ld1 {v0.4s}, [x10], #16
fadd v4.4s, v4.4s, v0.4s fadd v4.4s, v4.4s, v0.4s
fadd v5.4s, v5.4s, v0.4s fadd v5.4s, v5.4s, v0.4s
b L2_POST cbnz x0, L2_POST
b L2_BUFFER
L2_ADD_DSTV: L2_ADD_DSTV:
ld1 {v0.4s, v1.4s}, [x0] ld1 {v8.4s, v9.4s}, [x7], #32
fadd v4.4s, v4.4s, v0.4s fadd v4.4s, v4.4s, v8.4s
fadd v5.4s, v5.4s, v1.4s fadd v5.4s, v5.4s, v9.4s
cbnz x0, L2_POST
L2_BUFFER:
st1 {v4.4s, v5.4s}, [x15], #32
b L2LoopCheck
L2_POST: L2_POST:
cbz x19, L2_STORE cbz x19, L2_STORE
@ -713,19 +721,18 @@ L2LoopDz:
L2LoopCheck: L2LoopCheck:
subs x5, x5, #1 subs x5, x5, #1
mov x1, x8 mov x1, x8
add x2, x22, x21
bne L2LoopDz bne L2LoopDz
b End b End
L1Dz: L1Dz:
mov x7, x15
L1LoopDz: L1LoopDz:
mov x8, x1 mov x8, x1
mov x22, x2
ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x2], #64 ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x2], #64
dup v16.4s, wzr dup v16.4s, wzr
dup v17.4s, wzr dup v17.4s, wzr
ld1 {v4.16b}, [x1], x24 ld1 {v4.16b}, [x1], #16
smull v8.8h, v0.8b, v4.8b smull v8.8h, v0.8b, v4.8b
dup v18.4s, wzr dup v18.4s, wzr
@ -742,7 +749,7 @@ L1LoopDz:
L1LoopSz: L1LoopSz:
sadalp v16.4s, v8.8h sadalp v16.4s, v8.8h
ld1 {v4.16b}, [x1], x24 ld1 {v4.16b}, [x1], #16
sadalp v17.4s, v9.8h sadalp v17.4s, v9.8h
sadalp v18.4s, v10.8h sadalp v18.4s, v10.8h
sadalp v19.4s, v11.8h sadalp v19.4s, v11.8h
@ -778,9 +785,9 @@ L1LoopDz:
addp v12.4s, v4.4s, v5.4s addp v12.4s, v4.4s, v5.4s
L1Quan: L1Quan:
ld1 {v1.4s}, [x7], #16 ld1 {v1.4s}, [x2], #16
ld1 {v20.s}[0], [x14] // srcKernelSum ld1 {v20.s}[0], [x14] // srcKernelSum
ld1 {v21.4s}, [x12], #16 // weightQuanZero ld1 {v21.4s}, [x2], #16 // weightQuanZero
scvtf v4.4s, v12.4s scvtf v4.4s, v12.4s
MUL_SCALE1 v1, v4 MUL_SCALE1 v1, v4
@ -799,11 +806,17 @@ L1LoopDz:
cbz x10, L1_ADD_DSTV cbz x10, L1_ADD_DSTV
ld1 {v0.4s}, [x10], #16 ld1 {v0.4s}, [x10], #16
fadd v4.4s, v4.4s, v0.4s fadd v4.4s, v4.4s, v0.4s
b L1_POST cbnz x0, L1_POST
b L1_BUFFER
L1_ADD_DSTV: L1_ADD_DSTV:
ld1 {v0.4s}, [x0] ld1 {v8.4s}, [x7], #16
fadd v4.4s, v4.4s, v0.4s fadd v4.4s, v4.4s, v8.4s
cbnz x0, L1_POST
L1_BUFFER:
st1 {v4.4s}, [x15], #16
b L1LoopCheck
L1_POST: L1_POST:
cbz x19, L1_STORE cbz x19, L1_STORE
@ -834,17 +847,15 @@ L1LoopDz:
L1LoopCheck: L1LoopCheck:
subs x5, x5, #1 subs x5, x5, #1
mov x1, x8 mov x1, x8
add x2, x22, x21
bne L1LoopDz bne L1LoopDz
End: End:
ldp x23, x24, [sp, #(16 * 6)] ldp x23, x24, [sp, #(16 * 5)]
ldp x21, x22, [sp, #(16 * 5)]
ldp x19, x20, [sp, #(16 * 4)] ldp x19, x20, [sp, #(16 * 4)]
ldp d8, d9, [sp, #(16 * 3)] ldp d8, d9, [sp, #(16 * 3)]
ldp d10, d11, [sp, #(16 * 2)] ldp d10, d11, [sp, #(16 * 2)]
ldp d12, d13, [sp, #(16 * 1)] ldp d12, d13, [sp, #(16 * 1)]
ldp d14, d15, [sp], #(16 * 8) ldp d14, d15, [sp], #(16 * 6)
ret ret
#endif #endif
@ -69,10 +69,9 @@ struct QuanPostTreatParameters {
// x5: dst_depth_quad, x6: post, x7: remain // x5: dst_depth_quad, x6: post, x7: remain
//Load from post: //Load from post:
// x7: scale, x10: bias, w11: maxValue, w13: minValue, w12: useInt8 // x10: bias, w11: maxValue, w13: minValue, w12: useInt8
// x19: srcKernelSum, x20: weightQuanBias // x7: srcKernelSum
mov x8, x7 mov x8, x7
ldr x7, [x6, #0]
ldr x10, [x6, #8] ldr x10, [x6, #8]
ldr w11, [x6, #16] ldr w11, [x6, #16]
ldr w13, [x6, #20] ldr w13, [x6, #20]
@ -83,9 +82,8 @@ stp d12, d13, [sp, #(16 * 1)]
stp d10, d11, [sp, #(16 * 2)] stp d10, d11, [sp, #(16 * 2)]
stp d8, d9, [sp, #(16 * 3)] stp d8, d9, [sp, #(16 * 3)]
stp x21, x22, [sp, #(16 * 4)] stp x21, x22, [sp, #(16 * 4)]
stp x19, x20, [sp, #(16 * 5)] ldr x7, [x6, #40]
ldr x19, [x6, #40] ldr x15, [x6, #96] // extraScale
ldr x20, [x6, #48]
cmp x8, #3 cmp x8, #3
beq L3Dz beq L3Dz
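A pattern repeated across every kernel in this commit: scale (old x7) and weightQuantBias (old x12/x20) are no longer independent arrays but are appended to each output-channel block of the weight stream, so the quantization vectors are read from the weight pointer (x2) right after the int8 weights, and the per-dz weight-stride bookkeeping disappears. A hedged sketch of the implied layout:

#include <cstdint>

// Illustrative view of one dz block (4 output channels) in the reordered
// weight stream; sizes inferred from the ld1 {...}, [x2], #16 loads below.
struct WeightBlockViewSketch {
    const int8_t* weights;     // src_depth_quad * 64 bytes of int8 weights
    const float*  scale;       // 4 floats, consumed by ld1 {v1.4s}, [x2], #16
    const float*  weightZero;  // 4 floats, consumed by the next ld1 ..., #16
};

WeightBlockViewSketch viewBlockSketch(const uint8_t* base, int srcDepthQuad) {
    auto w = reinterpret_cast<const int8_t*>(base);
    auto s = reinterpret_cast<const float*>(base + srcDepthQuad * 64);
    return { w, s, s + 4 };
}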
@ -229,16 +227,24 @@ L4LoopDz:
addp v15.4s, v22.4s, v23.4s addp v15.4s, v22.4s, v23.4s
L4Quan: L4Quan:
ld1 {v1.4s}, [x7], #16 // scale ld1 {v1.4s}, [x2], #16 // scale
ld1 {v2.4s}, [x19] // x kernel sum ld1 {v2.4s}, [x7] // x kernel sum
ld1 {v24.4s}, [x20], #16 // weight quan zeropoint ld1 {v24.4s}, [x2], #16 // weight quan zeropoint
TILE4_INT2FLOAT: TILE4_INT2FLOAT:
scvtf v4.4s, v12.4s scvtf v4.4s, v12.4s
scvtf v5.4s, v13.4s scvtf v5.4s, v13.4s
scvtf v6.4s, v14.4s scvtf v6.4s, v14.4s
scvtf v7.4s, v15.4s scvtf v7.4s, v15.4s
cbz x15, TILE4_SCALE
ld1 {v12.4s}, [x15]
fmul v4.4s, v4.4s, v12.s[0]
fmul v5.4s, v5.4s, v12.s[1]
fmul v6.4s, v6.4s, v12.s[2]
fmul v7.4s, v7.4s, v12.s[3]
TILE4_SCALE:
fmul v12.4s, v4.4s, v1.4s fmul v12.4s, v4.4s, v1.4s
fmul v13.4s, v5.4s, v1.4s fmul v13.4s, v5.4s, v1.4s
fmul v14.4s, v6.4s, v1.4s fmul v14.4s, v6.4s, v1.4s
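TILE4_SCALE above is the new per-tile extra scale: one scalar per input column applied before the per-channel scale, skipped entirely when x15 is null. In scalar form (hypothetical names):

#include <cstdint>

// acc is [tile][oc]; extraScale has one entry per tile, scale one per channel.
void extraScaleSketch(float out[4][4], const int32_t acc[4][4],
                      const float* extraScale /* may be null */, const float scale[4]) {
    for (int t = 0; t < 4; ++t) {
        for (int oc = 0; oc < 4; ++oc) {
            float v = static_cast<float>(acc[t][oc]);   // scvtf
            if (extraScale) v *= extraScale[t];         // fmul v4.4s, v4.4s, v12.s[t]
            out[t][oc] = v * scale[oc];                 // fmul v12.4s, v4.4s, v1.4s
        }
    }
}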
@ -297,7 +303,7 @@ L4LoopCheck:
b End b End
L3Dz: L3Dz:
add x3, x19, #8 add x3, x7, #8
cmp w12, #1 cmp w12, #1
bne L3LoopDz bne L3LoopDz
sub x4, x4, #8 sub x4, x4, #8
@ -408,16 +414,24 @@ L3LoopDz:
ld1 {v0.4s}, [x10], #16 ld1 {v0.4s}, [x10], #16
L3Quan: L3Quan:
ld1 {v1.4s}, [x7], #16 ld1 {v1.4s}, [x2], #16
ld1 {v2.d}[0], [x19] // x kernel sum ld1 {v2.d}[0], [x7] // x kernel sum
ld1 {v2.s}[2], [x6] ld1 {v2.s}[2], [x6]
ld1 {v24.4s}, [x20], #16 // weight quan zeropoint ld1 {v24.4s}, [x2], #16 // weight quan zeropoint
TILE3_INT2FLOAT: TILE3_INT2FLOAT:
scvtf v4.4s, v12.4s scvtf v4.4s, v12.4s
scvtf v5.4s, v13.4s scvtf v5.4s, v13.4s
scvtf v6.4s, v14.4s scvtf v6.4s, v14.4s
cbz x15, TILE3_SCALE
ld1 {v12.d}[0], [x15], #8
ld1 {v12.s}[2], [x15]
sub x15, x15, #8
fmul v4.4s, v4.4s, v12.s[0]
fmul v5.4s, v5.4s, v12.s[1]
fmul v6.4s, v6.4s, v12.s[2]
TILE3_SCALE:
fmul v12.4s, v4.4s, v1.4s fmul v12.4s, v4.4s, v1.4s
fmul v13.4s, v5.4s, v1.4s fmul v13.4s, v5.4s, v1.4s
fmul v14.4s, v6.4s, v1.4s fmul v14.4s, v6.4s, v1.4s
@ -544,15 +558,20 @@ L2LoopDz:
addp v13.4s, v18.4s, v19.4s addp v13.4s, v18.4s, v19.4s
L2Quan: L2Quan:
ld1 {v1.4s}, [x7], #16 ld1 {v1.4s}, [x2], #16
ld1 {v2.d}[0], [x19] // x kernel sum ld1 {v2.d}[0], [x7] // x kernel sum
ld1 {v24.4s}, [x20], #16 // weight quan zeropoint ld1 {v24.4s}, [x2], #16 // weight quan zeropoint
ld1 {v0.4s}, [x10], #16 ld1 {v0.4s}, [x10], #16
TILE2_INT2FLOAT: TILE2_INT2FLOAT:
scvtf v4.4s, v12.4s scvtf v4.4s, v12.4s
scvtf v5.4s, v13.4s scvtf v5.4s, v13.4s
cbz x15, TILE2_SCALE
ld1 {v12.d}[0], [x15]
fmul v4.4s, v4.4s, v12.s[0]
fmul v5.4s, v5.4s, v12.s[1]
TILE2_SCALE:
fmul v12.4s, v4.4s, v1.4s fmul v12.4s, v4.4s, v1.4s
fmul v13.4s, v5.4s, v1.4s fmul v13.4s, v5.4s, v1.4s
MLA_WEIGHTZERO v12, v2, v24, 0 // tile:0, oc:0-3 MLA_WEIGHTZERO v12, v2, v24, 0 // tile:0, oc:0-3
@ -647,12 +666,17 @@ L1LoopDz:
L1Quan: L1Quan:
ld1 {v1.4s}, [x7], #16 ld1 {v1.4s}, [x2], #16
ld1 {v2.s}[0], [x19] // x kernel sum ld1 {v2.s}[0], [x7] // x kernel sum
ld1 {v24.4s}, [x20], #16 // weight quan zeropoint ld1 {v24.4s}, [x2], #16 // weight quan zeropoint
TILE1_INT2FLOAT: TILE1_INT2FLOAT:
scvtf v4.4s, v12.4s scvtf v4.4s, v12.4s
cbz x15, TILE1_SCALE
ld1 {v12.s}[0], [x15]
fmul v4.4s, v4.4s, v12.s[0]
TILE1_SCALE:
fmul v12.4s, v4.4s, v1.4s fmul v12.4s, v4.4s, v1.4s
MLA_WEIGHTZERO v12, v2, v24, 0 // tile:0, oc:0-3 MLA_WEIGHTZERO v12, v2, v24, 0 // tile:0, oc:0-3
fadd v12.4s, v12.4s, v0.4s fadd v12.4s, v12.4s, v0.4s
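MLA_WEIGHTZERO folds in the asymmetric-quantization correction: each tile's input kernel sum multiplied by the per-channel weight zero point. A hedged scalar equivalent of the TILE1 epilogue above:

#include <cstdint>

// One tile, four output channels; names follow the comments in this diff.
void tile1QuanSketch(float out[4], const int32_t acc[4], const float scale[4],
                     float kernelSum, const float weightZero[4], const float bias[4]) {
    for (int oc = 0; oc < 4; ++oc) {
        float v = static_cast<float>(acc[oc]) * scale[oc]; // fmul v12.4s, v4.4s, v1.4s
        v += kernelSum * weightZero[oc];                   // MLA_WEIGHTZERO v12, v2, v24, 0
        out[oc] = v + bias[oc];                            // fadd v12.4s, v12.4s, v0.4s
    }
}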
@ -682,7 +706,6 @@ L1LoopCheck:
bne L1LoopDz bne L1LoopDz
End: End:
ldp x19, x20, [sp, #80]
ldp x21, x22, [sp, #64] ldp x21, x22, [sp, #64]
ldp d8, d9, [sp, #48] ldp d8, d9, [sp, #48]
ldp d10, d11, [sp, #32] ldp d10, d11, [sp, #32]
@ -110,46 +110,43 @@ struct QuanPostTreatParameters {
//Auto: x0:dst, x1:src, x2:weight, x3:src_depth_quad, x4:dst_step //Auto: x0:dst, x1:src, x2:weight, x3:src_depth_quad, x4:dst_step
//x5:dst_depth_quad, x6: parameters, x7: realDstCount //x5:dst_depth_quad, x6: parameters, x7: realDstCount
//Load from x6: x8: scale, x9: bias, w28: useInt8, x25: xKernelSum, x26: weightQuantBias, x23: fp32minmax //Load from x6: x9: bias, w19: useInt8, x8: xKernelSum, x23: fp32minmax
// x24: extraScale // x24: extraScale
ldr x8, [x6, #0]
ldr x9, [x6, #8] ldr x9, [x6, #8]
stp d14, d15, [sp, #(-16 * 9)]! stp d14, d15, [sp, #(-16 * 8)]!
stp d12, d13, [sp, #(16 * 1)] stp d12, d13, [sp, #(16 * 1)]
stp d10, d11, [sp, #(16 * 2)] stp d10, d11, [sp, #(16 * 2)]
stp d8, d9, [sp, #(16 * 3)] stp d8, d9, [sp, #(16 * 3)]
stp x21, x22, [sp, #(16 * 4)] stp x21, x22, [sp, #(16 * 4)]
stp x19, x20, [sp, #(16 * 5)] stp x19, x20, [sp, #(16 * 5)]
stp x27, x28, [sp, #(16 * 6)] stp x23, x24, [sp, #(16 * 6)]
stp x25, x26, [sp, #(16 * 7)] stp x25, x26, [sp, #(16 * 7)]
stp x23, x24, [sp, #(16 * 8)]
lsl x15, x3, #5 // x15 = src_depth_quad * UNIT * SRC_UNIT ldr w19, [x6, #24] // useInt8
ldr x8, [x6, #40] // xKernelSum
ldr w28, [x6, #24] // useInt8
ldr x25, [x6, #40] // xKernelSum
ldr x26, [x6, #48] // weightQuantBias
ldr x24, [x6, #80] // extraScale ldr x24, [x6, #80] // extraScale
ldr x15, [x6, #96] // accumBuffer
mov x10, x15
mov x26, x24
lsl x22, x7, #2 // eDest * SRC_UNIT
add x23, x6, #16 // int8 max ptr add x23, x6, #16 // int8 max ptr
mov x21, #4 // sizeof(int8_t) * pack mov x21, #4 // sizeof(int8_t) * pack
cbnz w28, Start cbnz w19, TILE_12
mov x21, #16 // sizeof(float) * pack mov x21, #16 // sizeof(float) * pack
ldr x23, [x6, #56] // fp32minmax ldr x23, [x6, #56] // fp32minmax
Start:
lsl x22, x7, #2 // eDest * SRC_UNIT
TILE_12: TILE_12:
cmp x7, #12 cmp x7, #12
blt TILE_8 blt TILE_8
sub x25, x4, #128
cmp x5, #2 cmp x5, #2
blt L4LoopDz_TILE_12 blt L4LoopDz_TILE_12
L8LoopDz_TILE_12: L8LoopDz_TILE_12:
mov x11, x1 mov x11, x1
mov x13, x3 mov x13, x3
mov x20, x0 // tag dst address
mov x27, x2
SET_BIAS v8, v9, v10, v11 SET_BIAS v8, v9, v10, v11
SET_BIAS v12, v13, v14, v15 SET_BIAS v12, v13, v14, v15
@ -191,13 +188,12 @@ L8LoopDz_TILE_12:
bne L8LoopSz_TILE_12 bne L8LoopSz_TILE_12
L8LoopSzEnd_TILE_12: L8LoopSzEnd_TILE_12:
add x2, x27, x15
sub x5, x5, #2 sub x5, x5, #2
L8Tile12Quan: L8Tile12Quan:
ld1 {v0.4s, v1.4s}, [x8], #32 // scale ld1 {v0.4s, v1.4s}, [x2], #32 // scale
ld1 {v2.4s, v3.4s, v4.4s}, [x25] // x kernel sum ld1 {v2.4s, v3.4s, v4.4s}, [x8] // x kernel sum
ld1 {v5.4s, v6.4s}, [x26], #32 // weight quan zeropoint ld1 {v5.4s, v6.4s}, [x2], #32 // weight quan zeropoint
Int32ToFloat v8, v9, v10, v11 Int32ToFloat v8, v9, v10, v11
Int32ToFloat v12, v13, v14, v15 Int32ToFloat v12, v13, v14, v15
Int32ToFloat v16, v17, v18, v19 Int32ToFloat v16, v17, v18, v19
@ -212,7 +208,7 @@ L8LoopDz_TILE_12:
MUL_SCALE v1, v24, v25, v26, v27 MUL_SCALE v1, v24, v25, v26, v27
MUL_SCALE v1, v28, v29, v30, v31 MUL_SCALE v1, v28, v29, v30, v31
cbz x24, TILE12_L8_MLA cbz x26, TILE12_L8_MLA
ld1 {v0.4s, v1.4s}, [x24], #32 ld1 {v0.4s, v1.4s}, [x24], #32
ld1 {v7.4s}, [x24] ld1 {v7.4s}, [x24]
MUL_EXTRA_SCALE v0, v8, v9, v10, v11 MUL_EXTRA_SCALE v0, v8, v9, v10, v11
@ -250,9 +246,8 @@ L8LoopDz_TILE_12:
MLA_WEIGHTZERO v30, v4, v6, 2 // tile:10, oc:4-7 MLA_WEIGHTZERO v30, v4, v6, 2 // tile:10, oc:4-7
MLA_WEIGHTZERO v31, v4, v6, 3 // tile:11, oc:4-7 MLA_WEIGHTZERO v31, v4, v6, 3 // tile:11, oc:4-7
cmp w28, #1 cmp w19, #1
beq L8Tile12QuanUseInt8 beq L8Tile12QuanUseInt8
sub x4, x4, #128
cbz x9, TILE12_ADD_DSTV cbz x9, TILE12_ADD_DSTV
TILE12_ADD_BIAS: TILE12_ADD_BIAS:
@ -263,21 +258,32 @@ L8LoopDz_TILE_12:
ADD_BIAS_FLOAT v20, v21, v22, v23, v1 ADD_BIAS_FLOAT v20, v21, v22, v23, v1
ADD_BIAS_FLOAT v24, v25, v26, v27, v1 ADD_BIAS_FLOAT v24, v25, v26, v27, v1
ADD_BIAS_FLOAT v28, v29, v30, v31, v1 ADD_BIAS_FLOAT v28, v29, v30, v31, v1
b TILE12_POST cbnz x0, TILE12_POST
b TILE12_L8_ACCUM_BUFFER
TILE12_ADD_DSTV: TILE12_ADD_DSTV:
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x20], #64 ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20], #64 ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64
ADD_FLOAT v8, v9, v10, v11, v0, v1, v2, v3 ADD_FLOAT v8, v9, v10, v11, v0, v1, v2, v3
ADD_FLOAT v12, v13, v14, v15, v4, v5, v6, v7 ADD_FLOAT v12, v13, v14, v15, v4, v5, v6, v7
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x20], x4 ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20], #64 ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64
ADD_FLOAT v16, v17, v18, v19, v0, v1, v2, v3 ADD_FLOAT v16, v17, v18, v19, v0, v1, v2, v3
ADD_FLOAT v20, v21, v22, v23, v4, v5, v6, v7 ADD_FLOAT v20, v21, v22, v23, v4, v5, v6, v7
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x20], #64 ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20] ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64
ADD_FLOAT v24, v25, v26, v27, v0, v1, v2, v3 ADD_FLOAT v24, v25, v26, v27, v0, v1, v2, v3
ADD_FLOAT v28, v29, v30, v31, v4, v5, v6, v7 ADD_FLOAT v28, v29, v30, v31, v4, v5, v6, v7
cbnz x0, TILE12_POST
TILE12_L8_ACCUM_BUFFER:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x15], #64
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x15], #64
st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x15], #64
st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x15], #64
b L8Tile12LoopCheck
TILE12_POST: TILE12_POST:
cbz x23, TILE12_STORE cbz x23, TILE12_STORE
@ -294,11 +300,10 @@ L8LoopDz_TILE_12:
TILE12_STORE: TILE12_STORE:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64 st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64 st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], x4 st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], x25
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64
st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64
st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0], x4 st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0], x25
add x4, x4, #128
b L8Tile12LoopCheck b L8Tile12LoopCheck
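The store above also shows the simplified stride handling: instead of shrinking x4 by 128 inside the loop and restoring it afterwards, the new code precomputes x25 = dst_step - 128 once (sub x25, x4, #128 under TILE_12), so two #64 post-increments plus one x25 hop advance exactly one dst row. A small pointer-math sketch, with assumed names:

#include <cstdint>

// Stores 12 tiles x 4 channels (48 floats = 192 bytes) and lands on the next
// output row: 64 + 64 + (dstStep - 128) == dstStep bytes of total advance.
float* storeRowSketch(float* dst, const float* v, int64_t dstStep) {
    for (int i = 0; i < 48; ++i) dst[i] = v[i];
    return reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(dst) + dstStep);
}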
L8Tile12QuanUseInt8: L8Tile12QuanUseInt8:
@ -377,9 +382,9 @@ L4LoopDz_TILE_12:
L4LoopSzEnd_TILE_12: L4LoopSzEnd_TILE_12:
L4Tile12Quan: L4Tile12Quan:
ld1 {v0.4s}, [x8] // scale ld1 {v0.4s}, [x2], #16 // scale
ld1 {v2.4s, v3.4s, v4.4s}, [x25]// x kernel sum ld1 {v2.4s, v3.4s, v4.4s}, [x8]// x kernel sum
ld1 {v5.4s}, [x26], #16 // weight quan zeropoint ld1 {v5.4s}, [x2], #16 // weight quan zeropoint
Int32ToFloat v8, v9, v10, v11 Int32ToFloat v8, v9, v10, v11
Int32ToFloat v12, v13, v14, v15 Int32ToFloat v12, v13, v14, v15
Int32ToFloat v16, v17, v18, v19 Int32ToFloat v16, v17, v18, v19
@ -387,7 +392,7 @@ L4LoopDz_TILE_12:
MUL_SCALE v0, v12, v13, v14, v15 MUL_SCALE v0, v12, v13, v14, v15
MUL_SCALE v0, v16, v17, v18, v19 MUL_SCALE v0, v16, v17, v18, v19
cbz x24, TILE12_L4_MLA cbz x26, TILE12_L4_MLA
ld1 {v0.4s, v1.4s}, [x24], #32 ld1 {v0.4s, v1.4s}, [x24], #32
ld1 {v7.4s}, [x24] ld1 {v7.4s}, [x24]
MUL_EXTRA_SCALE v0, v8, v9, v10, v11 MUL_EXTRA_SCALE v0, v8, v9, v10, v11
@ -408,9 +413,8 @@ L4LoopDz_TILE_12:
MLA_WEIGHTZERO v17, v4, v5, 1 // tile:9, oc:0-3 MLA_WEIGHTZERO v17, v4, v5, 1 // tile:9, oc:0-3
MLA_WEIGHTZERO v18, v4, v5, 2 // tile:10, oc:0-3 MLA_WEIGHTZERO v18, v4, v5, 2 // tile:10, oc:0-3
MLA_WEIGHTZERO v19, v4, v5, 3 // tile:11, oc:0-3 MLA_WEIGHTZERO v19, v4, v5, 3 // tile:11, oc:0-3
cmp w28, #1 cmp w19, #1
beq L4Tile12QuanUseInt8 beq L4Tile12QuanUseInt8
sub x4, x4, #128
TILE12_L4_ADD_BIAS: TILE12_L4_ADD_BIAS:
cbz x9, TILE12_L4_ADD_DSTV cbz x9, TILE12_L4_ADD_DSTV
@ -418,16 +422,23 @@ L4LoopDz_TILE_12:
ADD_BIAS_FLOAT v8, v9, v10, v11, v0 ADD_BIAS_FLOAT v8, v9, v10, v11, v0
ADD_BIAS_FLOAT v12, v13, v14, v15, v0 ADD_BIAS_FLOAT v12, v13, v14, v15, v0
ADD_BIAS_FLOAT v16, v17, v18, v19, v0 ADD_BIAS_FLOAT v16, v17, v18, v19, v0
b TILE12_L4_POST cbnz x0, TILE12_L4_POST
b TILE12_L4_ACCUM_BUFFER
TILE12_L4_ADD_DSTV: TILE12_L4_ADD_DSTV:
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64
ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0] ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10], #64
sub x0, x0, #128
ADD_FLOAT v8, v9, v10, v11, v20, v21, v22, v23 ADD_FLOAT v8, v9, v10, v11, v20, v21, v22, v23
ADD_FLOAT v12, v13, v14, v15, v24, v25, v26, v27 ADD_FLOAT v12, v13, v14, v15, v24, v25, v26, v27
ADD_FLOAT v16, v17, v18, v19, v28, v29, v30, v31 ADD_FLOAT v16, v17, v18, v19, v28, v29, v30, v31
cbnz x0, TILE12_L4_POST
TILE12_L4_ACCUM_BUFFER:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x15], #64
b End
TILE12_L4_POST: TILE12_L4_POST:
cbz x23, TILE12_L4_STORE cbz x23, TILE12_L4_STORE
@ -440,8 +451,7 @@ L4LoopDz_TILE_12:
TILE12_L4_STORE: TILE12_L4_STORE:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64 st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64 st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], x4 st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], x25
add x4, x4, #128
b End b End
L4Tile12QuanUseInt8: L4Tile12QuanUseInt8:
@ -474,18 +484,16 @@ L4LoopDz_TILE_12:
TILE_8: TILE_8:
cmp x7, #8 cmp x7, #8
blt TILE_4 blt TILE_4
mov x10, x0 sub x25, x4, #64
mov x6, x0
mov x12, x2 mov x12, x2
mov x14, x5 mov x14, x5
mov x19, x8 // scale
mov x20, x9 // bias mov x20, x9 // bias
mov x6, x26 // weightQuantBias
cmp x5, #2 cmp x5, #2
blt L4LoopDz_TILE_8 blt L4LoopDz_TILE_8
L8LoopDz_TILE_8: L8LoopDz_TILE_8:
mov x11, x1 mov x11, x1
mov x13, x3 mov x13, x3
mov x27, x12
SET_BIAS v8, v9, v10, v11 SET_BIAS v8, v9, v10, v11
SET_BIAS v12, v13, v14, v15 SET_BIAS v12, v13, v14, v15
@ -517,13 +525,12 @@ L8LoopDz_TILE_8:
bne L8LoopSz_TILE_8 bne L8LoopSz_TILE_8
L8LoopSzEnd_TILE_8: L8LoopSzEnd_TILE_8:
add x12, x27, x15
sub x14, x14, #2 sub x14, x14, #2
L8Tile8Quan: L8Tile8Quan:
ld1 {v0.4s, v1.4s}, [x19], #32 // scale ld1 {v0.4s, v1.4s}, [x12], #32 // scale
ld1 {v2.4s, v3.4s}, [x25] // x kernel sum ld1 {v2.4s, v3.4s}, [x8] // x kernel sum
ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint ld1 {v24.4s, v25.4s}, [x12], #32 // weight quan zeropoint
Int32ToFloat v8, v9, v10, v11 Int32ToFloat v8, v9, v10, v11
Int32ToFloat v12, v13, v14, v15 Int32ToFloat v12, v13, v14, v15
Int32ToFloat v16, v17, v18, v19 Int32ToFloat v16, v17, v18, v19
@ -533,7 +540,7 @@ L8LoopDz_TILE_8:
MUL_SCALE v1, v16, v17, v18, v19 MUL_SCALE v1, v16, v17, v18, v19
MUL_SCALE v1, v20, v21, v22, v23 MUL_SCALE v1, v20, v21, v22, v23
cbz x24, TILE8_L8_MLA cbz x26, TILE8_L8_MLA
ld1 {v0.4s, v1.4s}, [x24] ld1 {v0.4s, v1.4s}, [x24]
MUL_EXTRA_SCALE v0, v8, v9, v10, v11 MUL_EXTRA_SCALE v0, v8, v9, v10, v11
MUL_EXTRA_SCALE v1, v12, v13, v14, v15 MUL_EXTRA_SCALE v1, v12, v13, v14, v15
@ -558,9 +565,8 @@ L8LoopDz_TILE_8:
MLA_WEIGHTZERO v22, v3, v25, 2 // tile:6, oc:4-7 MLA_WEIGHTZERO v22, v3, v25, 2 // tile:6, oc:4-7
MLA_WEIGHTZERO v23, v3, v25, 3 // tile:7, oc:4-7 MLA_WEIGHTZERO v23, v3, v25, 3 // tile:7, oc:4-7
cmp w28, #1 cmp w19, #1
beq L8Tile8QuanUseInt8 beq L8Tile8QuanUseInt8
sub x4, x4, #64
cbz x9, TILE8_ADD_DSTV cbz x9, TILE8_ADD_DSTV
TILE8_ADD_BIAS: TILE8_ADD_BIAS:
@ -569,19 +575,27 @@ L8LoopDz_TILE_8:
ADD_BIAS_FLOAT v12, v13, v14, v15, v0 ADD_BIAS_FLOAT v12, v13, v14, v15, v0
ADD_BIAS_FLOAT v16, v17, v18, v19, v1 ADD_BIAS_FLOAT v16, v17, v18, v19, v1
ADD_BIAS_FLOAT v20, v21, v22, v23, v1 ADD_BIAS_FLOAT v20, v21, v22, v23, v1
b TILE8_POST cbnz x0, TILE8_POST
b TILE8_L8_ACCUM_BUFFER
TILE8_ADD_DSTV: TILE8_ADD_DSTV:
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64 ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], x4 ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64 ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64
ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10] ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10], #64
ADD_FLOAT v8, v9, v10, v11, v0, v1, v2, v3 ADD_FLOAT v8, v9, v10, v11, v0, v1, v2, v3
ADD_FLOAT v12, v13, v14, v15, v4, v5, v6, v7 ADD_FLOAT v12, v13, v14, v15, v4, v5, v6, v7
ADD_FLOAT v16, v17, v18, v19, v24, v25, v26, v27 ADD_FLOAT v16, v17, v18, v19, v24, v25, v26, v27
ADD_FLOAT v20, v21, v22, v23, v28, v29, v30, v31 ADD_FLOAT v20, v21, v22, v23, v28, v29, v30, v31
sub x10, x10, #128 cbnz x0, TILE8_POST
sub x10, x10, x4
TILE8_L8_ACCUM_BUFFER:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x15], #64
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x15], #64
b L8Tile8LoopCheck
TILE8_POST: TILE8_POST:
cbz x23, TILE8_STORE cbz x23, TILE8_STORE
@ -594,11 +608,10 @@ L8LoopDz_TILE_8:
sub x23, x23, #4 sub x23, x23, #4
TILE8_STORE: TILE8_STORE:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], #64 st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x6], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10], x4 st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], x25
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64 st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x6], #64
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], x4 st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x6], x25
add x4, x4, #64
b L8Tile8LoopCheck b L8Tile8LoopCheck
L8Tile8QuanUseInt8: L8Tile8QuanUseInt8:
@ -630,8 +643,8 @@ L8LoopDz_TILE_8:
smin v17.16b, v7.16b, v17.16b smin v17.16b, v7.16b, v17.16b
smin v18.16b, v7.16b, v18.16b smin v18.16b, v7.16b, v18.16b
smin v19.16b, v7.16b, v19.16b smin v19.16b, v7.16b, v19.16b
st1 {v16.16b, v17.16b}, [x10], x4 st1 {v16.16b, v17.16b}, [x6], x4
st1 {v18.16b, v19.16b}, [x10], x4 st1 {v18.16b, v19.16b}, [x6], x4
L8Tile8LoopCheck: L8Tile8LoopCheck:
cmp x14, #1 cmp x14, #1
@ -663,15 +676,15 @@ L4LoopDz_TILE_8:
L4LoopSzEnd_TILE_8: L4LoopSzEnd_TILE_8:
L4Tile8Quan: L4Tile8Quan:
ld1 {v0.4s}, [x19], #16 // scale ld1 {v0.4s}, [x12], #16 // scale
ld1 {v2.4s, v3.4s}, [x25] // x kernel sum ld1 {v2.4s, v3.4s}, [x8] // x kernel sum
ld1 {v24.4s}, [x6], #16 // weight quan zeropoint ld1 {v24.4s}, [x12], #16 // weight quan zeropoint
Int32ToFloat v8, v9, v10, v11 Int32ToFloat v8, v9, v10, v11
Int32ToFloat v12, v13, v14, v15 Int32ToFloat v12, v13, v14, v15
MUL_SCALE v0, v8, v9, v10, v11 MUL_SCALE v0, v8, v9, v10, v11
MUL_SCALE v0, v12, v13, v14, v15 MUL_SCALE v0, v12, v13, v14, v15
cbz x24, TILE8_L4_MLA cbz x26, TILE8_L4_MLA
ld1 {v0.4s, v1.4s}, [x24] ld1 {v0.4s, v1.4s}, [x24]
MUL_EXTRA_SCALE v0, v8, v9, v10, v11 MUL_EXTRA_SCALE v0, v8, v9, v10, v11
MUL_EXTRA_SCALE v1, v12, v13, v14, v15 MUL_EXTRA_SCALE v1, v12, v13, v14, v15
@ -685,23 +698,34 @@ L4LoopDz_TILE_8:
MLA_WEIGHTZERO v13, v3, v24, 1 // tile:5, oc:0-3 MLA_WEIGHTZERO v13, v3, v24, 1 // tile:5, oc:0-3
MLA_WEIGHTZERO v14, v3, v24, 2 // tile:6, oc:0-3 MLA_WEIGHTZERO v14, v3, v24, 2 // tile:6, oc:0-3
MLA_WEIGHTZERO v15, v3, v24, 3 // tile:7, oc:0-3 MLA_WEIGHTZERO v15, v3, v24, 3 // tile:7, oc:0-3
cmp w28, #1 cmp w19, #1
beq L4Tile8QuanUseInt8 beq L4Tile8QuanUseInt8
sub x4, x4, #64
cbz x9, TILE8_L4_ADD_DSTV cbz x9, TILE8_L4_ADD_DSTV
TILE8_L4_ADD_BIAS: TILE8_L4_ADD_BIAS:
ld1 {v4.4s}, [x20], #16 ld1 {v4.4s}, [x20], #16
ADD_BIAS_FLOAT v8, v9, v10, v11, v4 ADD_BIAS_FLOAT v8, v9, v10, v11, v4
ADD_BIAS_FLOAT v12, v13, v14, v15, v4 ADD_BIAS_FLOAT v12, v13, v14, v15, v4
b TILE8_L4_POST cbnz x0, TILE8_L4_POST
b TILE8_L4_ACCUM_BUFFER
TILE8_L4_ADD_DSTV: TILE8_L4_ADD_DSTV:
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64 ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10] ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64
sub x10, x10, #64 ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10], #64
ADD_FLOAT v8, v9, v10, v11, v4, v5, v6, v7 ADD_FLOAT v8, v9, v10, v11, v0, v1, v2, v3
ADD_FLOAT v12, v13, v14, v15, v16, v17, v18, v19 ADD_FLOAT v12, v13, v14, v15, v4, v5, v6, v7
ADD_FLOAT v16, v17, v18, v19, v24, v25, v26, v27
ADD_FLOAT v20, v21, v22, v23, v28, v29, v30, v31
cbnz x0, TILE8_L4_POST
TILE8_L4_ACCUM_BUFFER:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x15], #64
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x15], #64
b Tile8End
TILE8_L4_POST: TILE8_L4_POST:
cbz x23, TILE8_L4_STORE cbz x23, TILE8_L4_STORE
@ -712,9 +736,8 @@ L4LoopDz_TILE_8:
sub x23, x23, #4 sub x23, x23, #4
TILE8_L4_STORE: TILE8_L4_STORE:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], #64 st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x6], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10], x4 st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], x25
add x4, x4, #64
b Tile8End b Tile8End
L4Tile8QuanUseInt8: L4Tile8QuanUseInt8:
@ -735,33 +758,28 @@ L4LoopDz_TILE_8:
smax v17.16b, v6.16b, v17.16b smax v17.16b, v6.16b, v17.16b
smin v16.16b, v7.16b, v16.16b smin v16.16b, v7.16b, v16.16b
smin v17.16b, v7.16b, v17.16b smin v17.16b, v7.16b, v17.16b
st1 {v16.16b, v17.16b}, [x10], x4 st1 {v16.16b, v17.16b}, [x6], x4
Tile8End: Tile8End:
cbz x24, Tile8_End_Offset cbz x0, Tile8_End_Offset
add x24, x24, #32 add x0, x0, x21, LSL #3
Tile8_End_Offset: Tile8_End_Offset:
sub x7, x7, #8 sub x7, x7, #8
add x0, x0, x21, LSL #3
add x1, x1, #32 add x1, x1, #32
add x25, x25, #32 add x8, x8, #32
add x24, x24, #32
TILE_4: TILE_4:
cmp x7, #4 cmp x7, #4
blt TILE_1 blt TILE_1
mov x10, x0 mov x6, x0
mov x12, x2 mov x12, x2
mov x14, x5 mov x14, x5
mov x19, x8
mov x20, x9 mov x20, x9
mov x6, x26 // weightQuantBias
cmp x5, #2 cmp x5, #2
blt L4LoopDz_TILE_4 blt L4LoopDz_TILE_4
L8LoopDz_TILE_4: L8LoopDz_TILE_4:
//ld1 {v0.4s, v1.4s}, [x20], #32 // bias
mov x11, x1 mov x11, x1
mov x13, x3 mov x13, x3
mov x27, x12
SET_BIAS v8, v9, v10, v11 SET_BIAS v8, v9, v10, v11
SET_BIAS v12, v13, v14, v15 SET_BIAS v12, v13, v14, v15
@ -782,19 +800,18 @@ L8LoopDz_TILE_4:
bne L8LoopSz_TILE_4 bne L8LoopSz_TILE_4
L8LoopSzEnd_TILE_4: L8LoopSzEnd_TILE_4:
add x12, x27, x15
sub x14, x14, #2 sub x14, x14, #2
L8Tile4Quan: L8Tile4Quan:
ld1 {v0.4s, v1.4s}, [x19], #32 // scale ld1 {v0.4s, v1.4s}, [x12], #32 // scale
ld1 {v2.4s}, [x25] // x kernel sum ld1 {v2.4s}, [x8] // x kernel sum
ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint ld1 {v24.4s, v25.4s}, [x12], #32 // weight quan zeropoint
Int32ToFloat v8, v9, v10, v11 Int32ToFloat v8, v9, v10, v11
Int32ToFloat v12, v13, v14, v15 Int32ToFloat v12, v13, v14, v15
MUL_SCALE v0, v8, v9, v10, v11 MUL_SCALE v0, v8, v9, v10, v11
MUL_SCALE v1, v12, v13, v14, v15 MUL_SCALE v1, v12, v13, v14, v15
cbz x24, TILE4_L8_MLA cbz x26, TILE4_L8_MLA
ld1 {v0.4s}, [x24] ld1 {v0.4s}, [x24]
MUL_EXTRA_SCALE v0, v8, v9, v10, v11 MUL_EXTRA_SCALE v0, v8, v9, v10, v11
MUL_EXTRA_SCALE v0, v12, v13, v14, v15 MUL_EXTRA_SCALE v0, v12, v13, v14, v15
@ -809,7 +826,7 @@ L8LoopDz_TILE_4:
MLA_WEIGHTZERO v14, v2, v25, 2 // tile:2, oc:4-7 MLA_WEIGHTZERO v14, v2, v25, 2 // tile:2, oc:4-7
MLA_WEIGHTZERO v15, v2, v25, 3 // tile:3, oc:4-7 MLA_WEIGHTZERO v15, v2, v25, 3 // tile:3, oc:4-7
cmp w28, #1 cmp w19, #1
beq L8Tile4QuanUseInt8 beq L8Tile4QuanUseInt8
cbz x9, TILE4_ADD_DSTV cbz x9, TILE4_ADD_DSTV
@ -817,14 +834,20 @@ L8LoopDz_TILE_4:
ld1 {v4.4s, v5.4s}, [x20], #32 ld1 {v4.4s, v5.4s}, [x20], #32
ADD_BIAS_FLOAT v8, v9, v10, v11, v4 ADD_BIAS_FLOAT v8, v9, v10, v11, v4
ADD_BIAS_FLOAT v12, v13, v14, v15, v5 ADD_BIAS_FLOAT v12, v13, v14, v15, v5
b TILE4_POST cbnz x0, TILE4_POST
b TILE4_L8_ACCUM_BUFFER
TILE4_ADD_DSTV: TILE4_ADD_DSTV:
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], x4 ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10] ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64
sub x10, x10, x4 ADD_FLOAT v8, v9, v10, v11, v16, v17, v18, v19
ADD_FLOAT v8, v9, v10, v11, v4, v5, v6, v7 ADD_FLOAT v12, v13, v14, v15, v20, v21, v22, v23
ADD_FLOAT v12, v13, v14, v15, v16, v17, v18, v19 cbnz x0, TILE4_POST
TILE4_L8_ACCUM_BUFFER:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
b L8Tile4LoopCheck
TILE4_POST: TILE4_POST:
cbz x23, TILE4_STORE cbz x23, TILE4_STORE
@ -835,8 +858,8 @@ L8LoopDz_TILE_4:
sub x23, x23, #4 sub x23, x23, #4
TILE4_STORE: TILE4_STORE:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], x4 st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x6], x4
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10], x4 st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], x4
b L8Tile4LoopCheck b L8Tile4LoopCheck
L8Tile4QuanUseInt8: L8Tile4QuanUseInt8:
@ -857,8 +880,8 @@ L8LoopDz_TILE_4:
smax v17.16b, v6.16b, v17.16b smax v17.16b, v6.16b, v17.16b
smin v16.16b, v7.16b, v16.16b smin v16.16b, v7.16b, v16.16b
smin v17.16b, v7.16b, v17.16b smin v17.16b, v7.16b, v17.16b
st1 {v16.16b}, [x10], x4 st1 {v16.16b}, [x6], x4
st1 {v17.16b}, [x10], x4 st1 {v17.16b}, [x6], x4
L8Tile4LoopCheck: L8Tile4LoopCheck:
cmp x14, #1 cmp x14, #1
@ -884,13 +907,13 @@ L4LoopDz_TILE_4:
L4LoopSzEnd_TILE_4: L4LoopSzEnd_TILE_4:
L4Tile4Quan: L4Tile4Quan:
ld1 {v0.4s}, [x19], #16 // scale ld1 {v0.4s}, [x12], #16 // scale
ld1 {v2.4s}, [x25] // x kernel sum ld1 {v2.4s}, [x8] // x kernel sum
ld1 {v24.4s}, [x6], #16 // weight quan zeropoint ld1 {v24.4s}, [x12], #16 // weight quan zeropoint
Int32ToFloat v8, v9, v10, v11 Int32ToFloat v8, v9, v10, v11
MUL_SCALE v0, v8, v9, v10, v11 MUL_SCALE v0, v8, v9, v10, v11
cbz x24, TILE4_L4_MLA cbz x26, TILE4_L4_MLA
ld1 {v0.4s}, [x24] ld1 {v0.4s}, [x24]
MUL_EXTRA_SCALE v0, v8, v9, v10, v11 MUL_EXTRA_SCALE v0, v8, v9, v10, v11
@ -900,18 +923,24 @@ L4LoopDz_TILE_4:
MLA_WEIGHTZERO v10, v2, v24, 2 // tile:2, oc:0-3 MLA_WEIGHTZERO v10, v2, v24, 2 // tile:2, oc:0-3
MLA_WEIGHTZERO v11, v2, v24, 3 // tile:3, oc:0-3 MLA_WEIGHTZERO v11, v2, v24, 3 // tile:3, oc:0-3
cmp w28, #1 cmp w19, #1
beq L4Tile4QuanUseInt8 beq L4Tile4QuanUseInt8
cbz x9, TILE4_L4_ADD_DSTV cbz x9, TILE4_L4_ADD_DSTV
TILE4_L4_ADD_BIAS: TILE4_L4_ADD_BIAS:
ld1 {v3.4s}, [x20], #16 ld1 {v3.4s}, [x20], #16
ADD_BIAS_FLOAT v8, v9, v10, v11, v3 ADD_BIAS_FLOAT v8, v9, v10, v11, v3
b TILE4_L4_POST cbnz x0, TILE4_L4_POST
b TILE4_L4_ACCUM_BUFFER
TILE4_L4_ADD_DSTV: TILE4_L4_ADD_DSTV:
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10] ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64
ADD_FLOAT v8, v9, v10, v11, v12, v13, v14, v15 ADD_FLOAT v8, v9, v10, v11, v16, v17, v18, v19
cbnz x0, TILE4_L4_POST
TILE4_L4_ACCUM_BUFFER:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
b Tile4End
TILE4_L4_POST: TILE4_L4_POST:
cbz x23, TILE4_L4_STORE cbz x23, TILE4_L4_STORE
@ -921,7 +950,7 @@ L4LoopDz_TILE_4:
sub x23, x23, #4 sub x23, x23, #4
TILE4_L4_STORE: TILE4_L4_STORE:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], x4 st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x6], x4
b Tile4End b Tile4End
L4Tile4QuanUseInt8: L4Tile4QuanUseInt8:
@ -937,31 +966,27 @@ L4LoopDz_TILE_4:
Int16ToInt8_ONE v0, v1, v16 Int16ToInt8_ONE v0, v1, v16
smax v16.16b, v6.16b, v16.16b smax v16.16b, v6.16b, v16.16b
smin v16.16b, v7.16b, v16.16b smin v16.16b, v7.16b, v16.16b
st1 {v16.16b}, [x10], x4 st1 {v16.16b}, [x6], x4
Tile4End: Tile4End:
cbz x24, Tile4_End_Offset cbz x0, Tile4_End_Offset
add x24, x24, #16 add x0, x0, x21, LSL #2
Tile4_End_Offset: Tile4_End_Offset:
sub x7, x7, #4 sub x7, x7, #4
add x0, x0, x21, LSL #2
add x1, x1, #16 add x1, x1, #16
add x25, x25, #16 add x8, x8, #16
add x24, x24, #16
TILE_1: TILE_1:
cbz x7, End cbz x7, End
mov x10, x0
mov x12, x2 mov x12, x2
mov x14, x5 mov x14, x5
mov x19, x8
mov x20, x9 mov x20, x9
mov x6, x26 // weightQuantBias mov x6, x0
cmp x5, #2 cmp x5, #2
blt L4LoopDz_TILE_1 blt L4LoopDz_TILE_1
L8LoopDz_TILE_1: L8LoopDz_TILE_1:
mov x11, x1 mov x11, x1
mov x13, x3 mov x13, x3
mov x27, x12
movi v8.16b, #0 movi v8.16b, #0
movi v9.16b, #0 movi v9.16b, #0
@ -974,19 +999,18 @@ L8LoopDz_TILE_1:
bne L8LoopSz_TILE_1 bne L8LoopSz_TILE_1
L8LoopSzEnd_TILE_1: L8LoopSzEnd_TILE_1:
add x12, x27, x15
sub x14, x14, #2 sub x14, x14, #2
L8Tile1Quan: L8Tile1Quan:
ld1 {v0.4s, v1.4s}, [x19], #32 // scale ld1 {v0.4s, v1.4s}, [x12], #32 // scale
ld1 {v2.s}[0], [x25] // x kernel sum ld1 {v2.s}[0], [x8] // x kernel sum
ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint ld1 {v24.4s, v25.4s}, [x12], #32 // weight quan zeropoint
scvtf v8.4s, v8.4s scvtf v8.4s, v8.4s
scvtf v9.4s, v9.4s scvtf v9.4s, v9.4s
fmul v8.4s, v8.4s, v0.4s fmul v8.4s, v8.4s, v0.4s
fmul v9.4s, v9.4s, v1.4s fmul v9.4s, v9.4s, v1.4s
cbz x24, TILE1_L8_MLA cbz x26, TILE1_L8_MLA
ld1 {v0.s}[0], [x24] ld1 {v0.s}[0], [x24]
fmul v8.4s, v8.4s, v0.s[0] fmul v8.4s, v8.4s, v0.s[0]
fmul v9.4s, v9.4s, v0.s[0] fmul v9.4s, v9.4s, v0.s[0]
@ -995,7 +1019,7 @@ L8LoopDz_TILE_1:
MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3 MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3
MLA_WEIGHTZERO v9, v2, v25, 0 // tile:0, oc:4-7 MLA_WEIGHTZERO v9, v2, v25, 0 // tile:0, oc:4-7
cmp w28, #1 cmp w19, #1
beq L8Tile1QuanUseInt8 beq L8Tile1QuanUseInt8
cbz x9, TILE1_ADD_DSTV cbz x9, TILE1_ADD_DSTV
@ -1003,14 +1027,18 @@ L8LoopDz_TILE_1:
ld1 {v10.4s, v11.4s}, [x20], #32 ld1 {v10.4s, v11.4s}, [x20], #32
fadd v8.4s, v8.4s, v10.4s fadd v8.4s, v8.4s, v10.4s
fadd v9.4s, v9.4s, v11.4s fadd v9.4s, v9.4s, v11.4s
b TILE1_POST cbnz x0, TILE1_POST
b TILE1_L8_ACCUM_BUFFER
TILE1_ADD_DSTV: TILE1_ADD_DSTV:
ld1 {v10.4s}, [x10], x4 ld1 {v10.4s, v11.4s}, [x10], #32
ld1 {v11.4s}, [x10]
sub x10, x10, x4
fadd v8.4s, v8.4s, v10.4s fadd v8.4s, v8.4s, v10.4s
fadd v9.4s, v9.4s, v11.4s fadd v9.4s, v9.4s, v11.4s
cbnz x0, TILE1_POST
TILE1_L8_ACCUM_BUFFER:
st1 {v8.4s, v9.4s}, [x15], #32
b L8Tile1LoopCheck
TILE1_POST: TILE1_POST:
cbz x23, TILE1_STORE cbz x23, TILE1_STORE
@ -1023,8 +1051,8 @@ L8LoopDz_TILE_1:
fmax v9.4s, v9.4s, v26.4s fmax v9.4s, v9.4s, v26.4s
TILE1_STORE: TILE1_STORE:
st1 {v8.4s}, [x10], x4 st1 {v8.4s}, [x6], x4
st1 {v9.4s}, [x10], x4 st1 {v9.4s}, [x6], x4
b L8Tile1LoopCheck b L8Tile1LoopCheck
L8Tile1QuanUseInt8: L8Tile1QuanUseInt8:
@ -1043,8 +1071,8 @@ L8LoopDz_TILE_1:
sqxtn v16.8b, v0.8h sqxtn v16.8b, v0.8h
smax v16.16b, v6.16b, v16.16b smax v16.16b, v6.16b, v16.16b
smin v16.16b, v7.16b, v16.16b smin v16.16b, v7.16b, v16.16b
st1 {v16.s}[0], [x10], x4 st1 {v16.s}[0], [x6], x4
st1 {v16.s}[1], [x10], x4 st1 {v16.s}[1], [x6], x4
L8Tile1LoopCheck: L8Tile1LoopCheck:
cmp x14, #1 cmp x14, #1
@ -1066,30 +1094,36 @@ L4LoopDz_TILE_1:
L4LoopSzEnd_TILE_1: L4LoopSzEnd_TILE_1:
L4Tile1Quan: L4Tile1Quan:
ld1 {v0.4s}, [x19], #16 // scale ld1 {v0.4s}, [x12], #16 // scale
ld1 {v2.s}[0], [x25] // x kernel sum ld1 {v2.s}[0], [x8] // x kernel sum
ld1 {v24.4s}, [x6], #16 // weight quan zeropoint ld1 {v24.4s}, [x12], #16 // weight quan zeropoint
scvtf v8.4s, v8.4s scvtf v8.4s, v8.4s
fmul v8.4s, v8.4s, v0.4s fmul v8.4s, v8.4s, v0.4s
cbz x24, TILE1_L4_MLA cbz x26, TILE1_L4_MLA
ld1 {v0.s}[0], [x24] ld1 {v0.s}[0], [x24]
fmul v8.4s, v8.4s, v0.s[0] fmul v8.4s, v8.4s, v0.s[0]
TILE1_L4_MLA: TILE1_L4_MLA:
MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3 MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3
cmp w28, #1 cmp w19, #1
beq L4Tile1QuanUseInt8 beq L4Tile1QuanUseInt8
cbz x9, TILE1_L4_ADD_DSTV cbz x9, TILE1_L4_ADD_DSTV
TILE1_L4_ADD_BIAS: TILE1_L4_ADD_BIAS:
ld1 {v4.4s}, [x20], #16 ld1 {v4.4s}, [x20], #16
fadd v8.4s, v8.4s, v4.4s fadd v8.4s, v8.4s, v4.4s
b TILE1_L4_POST cbnz x0, TILE1_L4_POST
b TILE1_L4_ACCUM_BUFFER
TILE1_L4_ADD_DSTV: TILE1_L4_ADD_DSTV:
ld1 {v4.4s}, [x10] ld1 {v10.4s}, [x10], #16
fadd v8.4s, v8.4s, v4.4s fadd v8.4s, v8.4s, v10.4s
cbnz x0, TILE1_L4_POST
TILE1_L4_ACCUM_BUFFER:
st1 {v8.4s}, [x15], #16
b Tile1End
TILE1_L4_POST: TILE1_L4_POST:
cbz x23, TILE1_L4_STORE cbz x23, TILE1_L4_STORE
@ -1099,7 +1133,7 @@ L4LoopDz_TILE_1:
fmax v8.4s, v8.4s, v26.4s fmax v8.4s, v8.4s, v26.4s
fmin v8.4s, v8.4s, v27.4s fmin v8.4s, v8.4s, v27.4s
TILE1_L4_STORE: TILE1_L4_STORE:
st1 {v8.4s}, [x10], x4 st1 {v8.4s}, [x6], x4
b Tile1End b Tile1End
L4Tile1QuanUseInt8: L4Tile1QuanUseInt8:
@ -1115,29 +1149,27 @@ L4LoopDz_TILE_1:
sqxtn v16.8b, v0.8h sqxtn v16.8b, v0.8h
smax v16.8b, v6.8b, v16.8b smax v16.8b, v6.8b, v16.8b
smin v16.8b, v7.8b, v16.8b smin v16.8b, v7.8b, v16.8b
st1 {v16.s}[0], [x10], x4 st1 {v16.s}[0], [x6], x4
Tile1End: Tile1End:
cbz x24, Tile1_End_Offset cbz x0, Tile1_End_Offset
add x24, x24, #4
Tile1_End_Offset:
subs x7, x7, #1
add x0, x0, x21 add x0, x0, x21
Tile1_End_Offset:
add x24, x24, #4
subs x7, x7, #1
add x1, x1, #4 add x1, x1, #4
add x25, x25, #4 add x8, x8, #4
bne TILE_1 bne TILE_1
End: End:
ldp x23, x24, [sp, #(16 * 8)]
ldp x25, x26, [sp, #(16 * 7)] ldp x25, x26, [sp, #(16 * 7)]
ldp x27, x28, [sp, #(16 * 6)] ldp x23, x24, [sp, #(16 * 6)]
ldp x19, x20, [sp, #(16 * 5)] ldp x19, x20, [sp, #(16 * 5)]
ldp x21, x22, [sp, #(16 * 4)] ldp x21, x22, [sp, #(16 * 4)]
ldp d8, d9, [sp, #(16 * 3)] ldp d8, d9, [sp, #(16 * 3)]
ldp d10, d11, [sp, #(16 * 2)] ldp d10, d11, [sp, #(16 * 2)]
ldp d12, d13, [sp, #(16 * 1)] ldp d12, d13, [sp, #(16 * 1)]
ldp d14, d15, [sp], #(16 * 9) ldp d14, d15, [sp], #(16 * 8)
ret ret
#endif // __aarch64__ #endif // __aarch64__
@ -119,13 +119,12 @@ struct QuanPostTreatParameters {
//Auto: x0:dst, x1:src, x2:weight, x3:src_depth_quad, x4:dst_step //Auto: x0:dst, x1:src, x2:weight, x3:src_depth_quad, x4:dst_step
//x5:dst_depth_quad, x6: parameters, x7: realDstCount //x5:dst_depth_quad, x6: parameters, x7: realDstCount
//Load from x6: x8: scale, x9: bias, w23: useInt8, x27: srcKernelSum, x28: weightQuanBias, //Load from x6: x9: bias, w23: useInt8, x8: srcKernelSum
// EP=10,LP=8,HP=8 // EP=10,LP=8,HP=8
ldr x8, [x6, #0]
ldr x9, [x6, #8] ldr x9, [x6, #8]
stp d14, d15, [sp, #(-16 * 10)]! stp d14, d15, [sp, #(-16 * 8)]!
stp d12, d13, [sp, #(16 * 1)] stp d12, d13, [sp, #(16 * 1)]
stp d10, d11, [sp, #(16 * 2)] stp d10, d11, [sp, #(16 * 2)]
stp d8, d9, [sp, #(16 * 3)] stp d8, d9, [sp, #(16 * 3)]
@ -133,14 +132,14 @@ stp x21, x22, [sp, #(16 * 4)]
stp x19, x20, [sp, #(16 * 5)] stp x19, x20, [sp, #(16 * 5)]
stp x23, x24, [sp, #(16 * 6)] stp x23, x24, [sp, #(16 * 6)]
stp x25, x26, [sp, #(16 * 7)] stp x25, x26, [sp, #(16 * 7)]
stp x27, x28, [sp, #(16 * 8)]
ldr w23, [x6, #24] ldr w23, [x6, #24]
ldr x27, [x6, #40] // srcKernelSum ldr x8, [x6, #40] // srcKernelSum
ldr x28, [x6, #48] // weightQuanBias
lsl x15, x3, #6 // x15 = src_depth_quad * UNIT * UNIT_SRC = src_depth_quad * 64 = src_depth_quad << 6
ldr x10, [x6, #80] // extra scale ldr x10, [x6, #80] // extra scale
ldr x15, [x6, #96] // accumBuffer
mov x19, x15
mov x26, x10
mov x21, #4 // sizeof(int8_t) * pack mov x21, #4 // sizeof(int8_t) * pack
add x14, x6, #16 // int8 max ptr add x14, x6, #16 // int8 max ptr
cbnz w23, Start cbnz w23, Start
@ -153,9 +152,8 @@ lsl x22, x7, #3 // eDest * GEMM_INT8_SRC_UNIT
TILE_10: TILE_10:
cmp x7, #10 cmp x7, #10
blt TILE_8 blt TILE_8
sub x25, x4, #128
sub x4, x4, #32 // For int8 output, x4-64 sub x4, x4, #32 // For int8 output, x4-64
cbnz w23, TILE10_DZ
sub x4, x4, #96 // For float32 output, x4-32-96=x4-128
TILE10_DZ: TILE10_DZ:
cmp x5, #2 cmp x5, #2
@ -163,7 +161,6 @@ blt LoopDz4_TILE_10
LoopDz8_TILE_10: LoopDz8_TILE_10:
mov x11, x1 // src mov x11, x1 // src
mov x12, x2 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
SET_0_5 v12, v16, v20, v24, v28 // oc:0,1,0,1 SET_0_5 v12, v16, v20, v24, v28 // oc:0,1,0,1
@ -172,7 +169,7 @@ LoopDz8_TILE_10:
SET_0_5 v15, v19, v23, v27, v31 // oc:6,7,6,7 SET_0_5 v15, v19, v23, v27, v31 // oc:6,7,6,7
LoopSz_TILE_10: LoopSz_TILE_10:
ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [x12], #64 // weight ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [x2], #64 // weight
ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], #64 // src: E0-E9 ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], #64 // src: E0-E9
ld1 {v7.16b}, [x11], #16 ld1 {v7.16b}, [x11], #16
subs x13, x13, #1 subs x13, x13, #1
@ -202,7 +199,6 @@ LoopSz_TILE_10:
.inst 0x4e8ba4ff // smmla v31.4s, v7.16b, v11.16b // tile8-oc6, tile8-oc7, tile9-oc6, tile9-oc7 .inst 0x4e8ba4ff // smmla v31.4s, v7.16b, v11.16b // tile8-oc6, tile8-oc7, tile9-oc6, tile9-oc7
bne LoopSz_TILE_10 bne LoopSz_TILE_10
LoopSzEnd_TILE_10: LoopSzEnd_TILE_10:
add x2, x2, x15 // weight += dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT);
sub x5, x5, #2 // dz-2 sub x5, x5, #2 // dz-2
// transpose // transpose
uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3
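Each smmla produces a 2x2 block (two tiles by two output channels), so the accumulators come out interleaved; the uzp1/uzp2 pairs on .2d lanes regroup them into per-tile oc:0-3 vectors, exactly as the E0/E1 comments state. A tiny model of that de-interleave:

#include <cstdint>

// A NEON register viewed as two 64-bit lanes; each lane packs two int32
// accumulators of one smmla 2x2 block (e.g. v12 = {t0oc0,t0oc1 | t1oc0,t1oc1}).
struct Vec2dSketch { int64_t d[2]; };

Vec2dSketch uzp1(Vec2dSketch a, Vec2dSketch b) { return {{ a.d[0], b.d[0] }}; } // tile 0
Vec2dSketch uzp2(Vec2dSketch a, Vec2dSketch b) { return {{ a.d[1], b.d[1] }}; } // tile 1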
@ -234,11 +230,11 @@ LoopSzEnd_TILE_10:
Int32ToFloat v16, v17, v18, v19 Int32ToFloat v16, v17, v18, v19
Tile10Quan: Tile10Quan:
ld1 {v20.4s, v21.4s}, [x8], #32 // scale ld1 {v20.4s, v21.4s}, [x2], #32 // scale
ld1 {v22.4s, v23.4s}, [x27], #32 // x kernel sum ld1 {v22.4s, v23.4s}, [x8], #32 // x kernel sum
ld1 {v24.d}[0], [x27] ld1 {v24.d}[0], [x8]
ld1 {v25.4s, v26.4s}, [x28], #32 // weight quan zeropoint ld1 {v25.4s, v26.4s}, [x2], #32 // weight quan zeropoint
sub x27, x27, #32 sub x8, x8, #32
MUL_SCALE v20, v0, v1, v2, v3 MUL_SCALE v20, v0, v1, v2, v3
MUL_SCALE v20, v4, v5, v6, v7 MUL_SCALE v20, v4, v5, v6, v7
MUL_SCALE v21, v10, v11, v12, v13 MUL_SCALE v21, v10, v11, v12, v13
@ -248,7 +244,7 @@ Tile10Quan:
fmul v18.4s, v18.4s, v21.4s fmul v18.4s, v18.4s, v21.4s
fmul v19.4s, v19.4s, v21.4s fmul v19.4s, v19.4s, v21.4s
cbz x10, TILE10_MLA cbz x26, TILE10_MLA
ld1 {v27.4s, v28.4s}, [x10], #32 ld1 {v27.4s, v28.4s}, [x10], #32
ld1 {v29.d}[0], [x10] ld1 {v29.d}[0], [x10]
MUL_EXTRA_SCALE v27, v0, v1, v2, v3 MUL_EXTRA_SCALE v27, v0, v1, v2, v3
@ -300,28 +296,31 @@ Tile10Quan:
fadd v9.4s, v9.4s, v20.4s fadd v9.4s, v9.4s, v20.4s
fadd v18.4s, v18.4s, v21.4s fadd v18.4s, v18.4s, v21.4s
fadd v19.4s, v19.4s, v21.4s fadd v19.4s, v19.4s, v21.4s
b TILE10_POST cbnz x0, TILE10_POST
b TILE10_L8_ACCUM_BUFFER
TILE10_ADD_DSTV: TILE10_ADD_DSTV:
// first batch10 // first batch10
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x19], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x19], #64
ld1 {v28.4s, v29.4s}, [x0], x4 ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x19], #64
ADD_FLOAT v0, v1, v2, v3, v20, v21, v22, v23 ADD_FLOAT v0, v1, v2, v3, v20, v21, v22, v23
ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27 ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27
fadd v8.4s, v8.4s, v28.4s ADD_FLOAT v8, v9, v10, v11, v28, v29, v30, v31
fadd v9.4s, v9.4s, v29.4s
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64
ld1 {v28.4s, v29.4s}, [x0]
ADD_FLOAT v10, v11, v12, v13, v20, v21, v22, v23
ADD_FLOAT v14, v15, v16, v17, v24, v25, v26, v27
fadd v18.4s, v18.4s, v28.4s
fadd v19.4s, v19.4s, v29.4s
sub x0, x0, #256 ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x19], #64
sub x0, x0, x4 ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x19], #64
ADD_FLOAT v12, v13, v14, v15, v20, v21, v22, v23
ADD_FLOAT v16, v17, v18, v19, v24, v25, v26, v27
cbnz x0, TILE10_POST
TILE10_L8_ACCUM_BUFFER:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x15], #64
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x15], #64
b Tile10LoopCheck
TILE10_POST: TILE10_POST:
cbz x14, TILE10_STORE cbz x14, TILE10_STORE
@ -337,10 +336,10 @@ Tile10Quan:
TILE10_STORE: TILE10_STORE:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64 st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64
st1 {v8.4s, v9.4s}, [x0], x4 st1 {v8.4s, v9.4s}, [x0], x25
st1 {v10.4s, v11.4s, v12.4s, v13.4s}, [x0], #64 st1 {v10.4s, v11.4s, v12.4s, v13.4s}, [x0], #64
st1 {v14.4s, v15.4s, v16.4s, v17.4s}, [x0], #64 st1 {v14.4s, v15.4s, v16.4s, v17.4s}, [x0], #64
st1 {v18.4s, v19.4s}, [x0], x4 st1 {v18.4s, v19.4s}, [x0], x25
b Tile10LoopCheck b Tile10LoopCheck
Tile10QuanUseInt8: Tile10QuanUseInt8:
@ -406,18 +405,17 @@ Tile10LoopCheck:
LoopDz4_TILE_10: LoopDz4_TILE_10:
mov x11, x1 // src mov x11, x1 // src
mov x12, x2 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
SET_0_5 v12, v13, v16, v17, v20 SET_0_5 v12, v13, v16, v17, v20
SET_0_5 v21, v24, v25, v28, v29 SET_0_5 v21, v24, v25, v28, v29
LoopSz4_TILE_10: LoopSz4_TILE_10:
ld1 {v8.16b, v9.16b}, [x12] // weight ld1 {v8.16b, v9.16b}, [x2] // weight
ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], #64 // src: E0-E9 ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], #64 // src: E0-E9
ld1 {v7.16b}, [x11], #16 ld1 {v7.16b}, [x11], #16
subs x13, x13, #1 subs x13, x13, #1
add x12, x12, #64 // x12+lp*hp add x2, x2, #64 // x2+lp*hp
.inst 0x4e88a46c // smmla v12.4s, v3.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 .inst 0x4e88a46c // smmla v12.4s, v3.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1
.inst 0x4e89a46d // smmla v13.4s, v3.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 .inst 0x4e89a46d // smmla v13.4s, v3.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3
@ -452,16 +450,15 @@ LoopSz4End_TILE_10:
scvtf v9.4s, v9.4s scvtf v9.4s, v9.4s
Tile10Quan_L4: Tile10Quan_L4:
ld1 {v20.4s}, [x8] // scale ld1 {v20.4s}, [x2], #16 // scale
ld1 {v22.4s, v23.4s}, [x27], #32 // x kernel sum ld1 {v22.4s, v23.4s}, [x8], #32 // x kernel sum
ld1 {v24.d}[0], [x27] ld1 {v24.d}[0], [x8]
ld1 {v25.4s}, [x28] // weight quan zeropoint ld1 {v25.4s}, [x2] // weight quan zeropoint
MUL_SCALE v20, v0, v1, v2, v3 MUL_SCALE v20, v0, v1, v2, v3
MUL_SCALE v20, v4, v5, v6, v7 MUL_SCALE v20, v4, v5, v6, v7
fmul v8.4s, v8.4s, v20.4s fmul v8.4s, v8.4s, v20.4s
fmul v9.4s, v9.4s, v20.4s fmul v9.4s, v9.4s, v20.4s
cbz x26, TILE10_MLA_L4
cbz x10, TILE10_MLA_L4
ld1 {v27.4s, v28.4s}, [x10], #32 ld1 {v27.4s, v28.4s}, [x10], #32
ld1 {v29.d}[0], [x10] ld1 {v29.d}[0], [x10]
MUL_EXTRA_SCALE v27, v0, v1, v2, v3 MUL_EXTRA_SCALE v27, v0, v1, v2, v3
@ -480,7 +477,6 @@ Tile10Quan_L4:
MLA_WEIGHTZERO v7, v23, v25, 3 // tile:7, oc:0-3 MLA_WEIGHTZERO v7, v23, v25, 3 // tile:7, oc:0-3
MLA_WEIGHTZERO v8, v24, v25, 0 // tile:8, oc:0-3 MLA_WEIGHTZERO v8, v24, v25, 0 // tile:8, oc:0-3
MLA_WEIGHTZERO v9, v24, v25, 1 // tile:9, oc:0-3 MLA_WEIGHTZERO v9, v24, v25, 1 // tile:9, oc:0-3
//sub x4, x4, #128
cbnz w23, Tile10QuanUseInt8_L4 cbnz w23, Tile10QuanUseInt8_L4
@ -491,19 +487,25 @@ Tile10Quan_L4:
ADD_BIAS_FLOAT v4, v5, v6, v7, v20 ADD_BIAS_FLOAT v4, v5, v6, v7, v20
fadd v8.4s, v8.4s, v20.4s fadd v8.4s, v8.4s, v20.4s
fadd v9.4s, v9.4s, v20.4s fadd v9.4s, v9.4s, v20.4s
cbz x0, TILE10_L4_ACCUM_BUFFER
b TILE10_POST_L4 b TILE10_POST_L4
TILE10_ADD_DSTV_L4: TILE10_ADD_DSTV_L4:
// first batch10 // first batch10
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x19], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x19], #64
ld1 {v28.4s, v29.4s}, [x0] ld1 {v28.4s, v29.4s}, [x19], #32
ADD_FLOAT v0, v1, v2, v3, v20, v21, v22, v23 ADD_FLOAT v0, v1, v2, v3, v20, v21, v22, v23
ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27 ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27
fadd v8.4s, v8.4s, v28.4s fadd v8.4s, v8.4s, v28.4s
fadd v9.4s, v9.4s, v29.4s fadd v9.4s, v9.4s, v29.4s
cbnz x0, TILE10_POST_L4
sub x0, x0, #128 TILE10_L4_ACCUM_BUFFER:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x15], #64
st1 {v8.4s, v9.4s}, [x15], #32
b End
TILE10_POST_L4: TILE10_POST_L4:
cbz x14, TILE10_STORE_L4 cbz x14, TILE10_STORE_L4
@ -520,7 +522,7 @@ Tile10Quan_L4:
TILE10_STORE_L4: TILE10_STORE_L4:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64 st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64
st1 {v8.4s, v9.4s}, [x0], x4 st1 {v8.4s, v9.4s}, [x0], x25
b End b End
Tile10QuanUseInt8_L4: Tile10QuanUseInt8_L4:
@ -578,21 +580,17 @@ TILE_8:
TILE_Remain: TILE_Remain:
cmp x7, #8 cmp x7, #8
blt TILE_4 blt TILE_4
cbnz w23, TILE8_START sub x25, x4, #64 // For float32 output, add #64 when tile8 end.
sub x4, x4, #64 // For float32 output, add #64 when tile8 end.
TILE8_START: TILE8_START:
mov x24, x5 // dst_depth_quad mov x24, x5 // dst_depth_quad
mov x26, x0 // dst mov x12, x2 // weight
mov x25, x2 // weight mov x6, x0
mov x19, x8 // scale
mov x20, x9 // bias mov x20, x9 // bias
mov x6, x28 // weightQuanBias
cmp x5, #2 cmp x5, #2
blt LoopDz4_TILE_8 blt LoopDz4_TILE_8
LoopDz_TILE_8: LoopDz_TILE_8:
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
SET_0_4 v12, v16, v20, v24 SET_0_4 v12, v16, v20, v24
SET_0_4 v13, v17, v21, v25 SET_0_4 v13, v17, v21, v25
@ -624,7 +622,6 @@ LoopSz_TILE_8:
bne LoopSz_TILE_8 bne LoopSz_TILE_8
LoopSzEnd_TILE_8: LoopSzEnd_TILE_8:
add x25, x25, x15
sub x24, x24, #2 // dz-2 sub x24, x24, #2 // dz-2
uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3
uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3
@ -651,15 +648,14 @@ LoopSzEnd_TILE_8:
Int32ToFloat v12, v13, v14, v15 Int32ToFloat v12, v13, v14, v15
Tile8Quan: Tile8Quan:
ld1 {v20.4s, v21.4s}, [x19], #32 // scale ld1 {v20.4s, v21.4s}, [x12], #32 // scale
ld1 {v22.4s, v23.4s}, [x27] // x kernel sum ld1 {v22.4s, v23.4s}, [x8] // x kernel sum
ld1 {v25.4s, v26.4s}, [x6], #32 // weight quan zeropoint ld1 {v25.4s, v26.4s}, [x12], #32 // weight quan zeropoint
MUL_SCALE v20, v0, v1, v2, v3 MUL_SCALE v20, v0, v1, v2, v3
MUL_SCALE v20, v4, v5, v6, v7 MUL_SCALE v20, v4, v5, v6, v7
MUL_SCALE v21, v8, v9, v10, v11 MUL_SCALE v21, v8, v9, v10, v11
MUL_SCALE v21, v12, v13, v14, v15 MUL_SCALE v21, v12, v13, v14, v15
cbz x26, TILE8_MLA
cbz x10, TILE8_MLA
ld1 {v27.4s, v28.4s}, [x10] ld1 {v27.4s, v28.4s}, [x10]
MUL_EXTRA_SCALE v27, v0, v1, v2, v3 MUL_EXTRA_SCALE v27, v0, v1, v2, v3
MUL_EXTRA_SCALE v28, v4, v5, v6, v7 MUL_EXTRA_SCALE v28, v4, v5, v6, v7
@ -694,19 +690,26 @@ Tile8Quan:
ADD_BIAS_FLOAT v4, v5, v6, v7, v16 ADD_BIAS_FLOAT v4, v5, v6, v7, v16
ADD_BIAS_FLOAT v8, v9, v10, v11, v17 ADD_BIAS_FLOAT v8, v9, v10, v11, v17
ADD_BIAS_FLOAT v12, v13, v14, v15, v17 ADD_BIAS_FLOAT v12, v13, v14, v15, v17
b TILE8_POST cbnz x0, TILE8_POST
b TILE8_L8_ACCUM_BUFFER
TILE8_ADD_DSTV: TILE8_ADD_DSTV:
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x26], #64 ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x19], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x26], x4 ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x19], #64
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x26], #64 ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x19], #64
ADD_FLOAT v0, v1, v2, v3, v20, v21, v22, v23 ADD_FLOAT v0, v1, v2, v3, v20, v21, v22, v23
ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27 ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x26] ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x19], #64
ADD_FLOAT v8, v9, v10, v11, v16, v17, v18, v19 ADD_FLOAT v8, v9, v10, v11, v16, v17, v18, v19
ADD_FLOAT v12, v13, v14, v15, v20, v21, v22, v23 ADD_FLOAT v12, v13, v14, v15, v20, v21, v22, v23
sub x26, x26, x4 cbnz x0, TILE8_POST
sub x26, x26, #128
TILE8_L8_ACCUM_BUFFER:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x15], #64
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
b Tile8LoopCheck
TILE8_POST: TILE8_POST:
cbz x14, TILE8_STORE cbz x14, TILE8_STORE
@ -716,10 +719,10 @@ Tile8Quan:
ReLU_FP32 v12, v13, v14, v15, v30, v31 ReLU_FP32 v12, v13, v14, v15, v30, v31
TILE8_STORE: TILE8_STORE:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x26], #64 st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x6], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x26], x4 st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x6], x25
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x26], #64 st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x6], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x26], x4 st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], x25
b Tile8LoopCheck b Tile8LoopCheck
Tile8QuanUseInt8: Tile8QuanUseInt8:
@ -749,8 +752,8 @@ Tile8Quan:
smin v29.16b, v31.16b, v29.16b smin v29.16b, v31.16b, v29.16b
smin v18.16b, v31.16b, v18.16b smin v18.16b, v31.16b, v18.16b
smin v19.16b, v31.16b, v19.16b smin v19.16b, v31.16b, v19.16b
st1 {v28.16b, v29.16b}, [x26], x4 st1 {v28.16b, v29.16b}, [x6], x4
st1 {v18.16b, v19.16b}, [x26], x4 // dst += dz * dst_step st1 {v18.16b, v19.16b}, [x6], x4 // dst += dz * dst_step
Tile8LoopCheck: Tile8LoopCheck:
cmp x24, #2 cmp x24, #2
bge LoopDz_TILE_8 bge LoopDz_TILE_8
@ -758,7 +761,6 @@ Tile8LoopCheck:
LoopDz4_TILE_8: LoopDz4_TILE_8:
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
SET_0_4 v12, v13, v16, v17 SET_0_4 v12, v13, v16, v17
SET_0_4 v20, v21, v24, v25 SET_0_4 v20, v21, v24, v25
@ -781,7 +783,6 @@ LoopSz4_TILE_8:
bne LoopSz4_TILE_8 bne LoopSz4_TILE_8
LoopSz4End_TILE_8: LoopSz4End_TILE_8:
add x25, x25, x15
uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3
uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3
uzp1 v2.2d, v16.2d, v17.2d // E2: oc:0-3 uzp1 v2.2d, v16.2d, v17.2d // E2: oc:0-3
@ -794,13 +795,12 @@ LoopSz4End_TILE_8:
Int32ToFloat v4, v5, v6, v7 Int32ToFloat v4, v5, v6, v7
Tile8Quan_L4: Tile8Quan_L4:
ld1 {v20.4s}, [x19] // scale ld1 {v20.4s}, [x12], #16 // scale
ld1 {v22.4s, v23.4s}, [x27] // x kernel sum ld1 {v22.4s, v23.4s}, [x8] // x kernel sum
ld1 {v25.4s}, [x6] // weight quan zeropoint ld1 {v25.4s}, [x12] // weight quan zeropoint
MUL_SCALE v20, v0, v1, v2, v3 MUL_SCALE v20, v0, v1, v2, v3
MUL_SCALE v20, v4, v5, v6, v7 MUL_SCALE v20, v4, v5, v6, v7
cbz x26, TILE8_MLA_L4
cbz x10, TILE8_MLA_L4
ld1 {v27.4s, v28.4s}, [x10] ld1 {v27.4s, v28.4s}, [x10]
MUL_EXTRA_SCALE v27, v0, v1, v2, v3 MUL_EXTRA_SCALE v27, v0, v1, v2, v3
MUL_EXTRA_SCALE v28, v4, v5, v6, v7 MUL_EXTRA_SCALE v28, v4, v5, v6, v7
@ -822,14 +822,20 @@ Tile8Quan_L4:
ld1 {v16.4s}, [x20] ld1 {v16.4s}, [x20]
ADD_BIAS_FLOAT v0, v1, v2, v3, v16 ADD_BIAS_FLOAT v0, v1, v2, v3, v16
ADD_BIAS_FLOAT v4, v5, v6, v7, v16 ADD_BIAS_FLOAT v4, v5, v6, v7, v16
cbz x0, TILE8_L4_ACCUM_BUFFER
b TILE8_POST_L4 b TILE8_POST_L4
TILE8_ADD_DSTV_L4: TILE8_ADD_DSTV_L4:
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x26], #64 ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x19], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x26] ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x19], #64
ADD_FLOAT v0, v1, v2, v3, v20, v21, v22, v23 ADD_FLOAT v0, v1, v2, v3, v20, v21, v22, v23
ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27 ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27
sub x26, x26, #64 cbnz x0, TILE8_POST_L4
TILE8_L4_ACCUM_BUFFER:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x15], #64
b Tile8Check
TILE8_POST_L4: TILE8_POST_L4:
cbz x14, TILE8_STORE_L4 cbz x14, TILE8_STORE_L4
@ -837,8 +843,8 @@ Tile8Quan_L4:
ReLU_FP32 v4, v5, v6, v7, v30, v31 ReLU_FP32 v4, v5, v6, v7, v30, v31
TILE8_STORE_L4: TILE8_STORE_L4:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x26], #64 st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x6], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x26], x4 st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x6], x25
b Tile8Check b Tile8Check
Tile8QuanUseInt8_L4: Tile8QuanUseInt8_L4:
@ -857,34 +863,30 @@ Tile8Quan_L4:
smax v17.16b, v30.16b, v17.16b smax v17.16b, v30.16b, v17.16b
smin v16.16b, v31.16b, v16.16b smin v16.16b, v31.16b, v16.16b
smin v17.16b, v31.16b, v17.16b smin v17.16b, v31.16b, v17.16b
st1 {v16.16b, v17.16b}, [x26], x4 st1 {v16.16b, v17.16b}, [x6], x4
Tile8Check: Tile8Check:
cbz x10, Tile8End cbz x0, Tile8End
add x10, x10, #32 add x0, x0, x21, LSL #3
Tile8End: Tile8End:
sub x7, x7, #8 sub x7, x7, #8
add x0, x0, x21, LSL #3
add x1, x1, #64 add x1, x1, #64
add x27, x27, #32 add x8, x8, #32
add x10, x10, #32
cbnz w23, TILE_4 cbnz w23, TILE_4
add x4, x4, #64 // Revert x4 for following tile.
TILE_4: TILE_4:
cmp x7, #4 cmp x7, #4
blt TILE_2 blt TILE_2
mov x24, x5 // dst_depth_quad mov x24, x5 // dst_depth_quad
mov x26, x0 // dst mov x6, x0 // dst
mov x25, x2 // weight mov x12, x2 // weight
mov x19, x8 // scale
mov x20, x9 // bias mov x20, x9 // bias
mov x6, x28 // weightQuanBias
cmp x5, #2 cmp x5, #2
blt LoopDz4_TILE_4 blt LoopDz4_TILE_4
LoopDz_TILE_4: LoopDz_TILE_4:
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
SET_0_4 v12, v13, v14, v15 SET_0_4 v12, v13, v14, v15
SET_0_4 v16, v17, v18, v19 SET_0_4 v16, v17, v18, v19
@ -904,7 +906,6 @@ LoopSz_TILE_4:
.inst 0x4e8ba4b3 // smmla v19.4s, v5.16b, v11.16b // tile2-oc6, tile2-oc7, tile3-oc6, tile3-oc7 .inst 0x4e8ba4b3 // smmla v19.4s, v5.16b, v11.16b // tile2-oc6, tile2-oc7, tile3-oc6, tile3-oc7
bne LoopSz_TILE_4 bne LoopSz_TILE_4
LoopSzEnd_TILE_4: LoopSzEnd_TILE_4:
add x25, x25, x15
sub x24, x24, #2 sub x24, x24, #2
uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3
uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3
@ -919,13 +920,12 @@ LoopSzEnd_TILE_4:
Int32ToFloat v4, v5, v6, v7 Int32ToFloat v4, v5, v6, v7
Tile4Quan: Tile4Quan:
ld1 {v20.4s, v21.4s}, [x19], #32 // scale ld1 {v20.4s, v21.4s}, [x12], #32 // scale
ld1 {v22.4s}, [x27] // x kernel sum ld1 {v22.4s}, [x8] // x kernel sum
ld1 {v25.4s, v26.4s}, [x6], #32 // weight quan zeropoint ld1 {v25.4s, v26.4s}, [x12], #32 // weight quan zeropoint
MUL_SCALE v20, v0, v1, v2, v3 MUL_SCALE v20, v0, v1, v2, v3
MUL_SCALE v21, v4, v5, v6, v7 MUL_SCALE v21, v4, v5, v6, v7
cbz x26, TILE4_MLA
cbz x10, TILE4_MLA
ld1 {v27.4s}, [x10] ld1 {v27.4s}, [x10]
MUL_EXTRA_SCALE v27, v0, v1, v2, v3 MUL_EXTRA_SCALE v27, v0, v1, v2, v3
MUL_EXTRA_SCALE v27, v4, v5, v6, v7 MUL_EXTRA_SCALE v27, v4, v5, v6, v7
@ -947,14 +947,20 @@ Tile4Quan:
ld1 {v16.4s, v17.4s}, [x20], #32 // bias ld1 {v16.4s, v17.4s}, [x20], #32 // bias
ADD_BIAS_FLOAT v0, v1, v2, v3, v16 ADD_BIAS_FLOAT v0, v1, v2, v3, v16
ADD_BIAS_FLOAT v4, v5, v6, v7, v17 ADD_BIAS_FLOAT v4, v5, v6, v7, v17
b TILE4_POST cbnz x0, TILE4_POST
b TILE4_L8_ACCUM_BUFFER
TILE4_ADD_DSTV: TILE4_ADD_DSTV:
ld1 {v15.4s, v16.4s, v17.4s, v18.4s}, [x26], x4 ld1 {v15.4s, v16.4s, v17.4s, v18.4s}, [x19], #64
ld1 {v19.4s, v20.4s, v21.4s, v22.4s}, [x26] ld1 {v19.4s, v20.4s, v21.4s, v22.4s}, [x19], #64
ADD_FLOAT v0, v1, v2, v3, v15, v16, v17, v18 ADD_FLOAT v0, v1, v2, v3, v15, v16, v17, v18
ADD_FLOAT v4, v5, v6, v7, v19, v20, v21, v22 ADD_FLOAT v4, v5, v6, v7, v19, v20, v21, v22
sub x26, x26, x4 cbnz x0, TILE4_POST
TILE4_L8_ACCUM_BUFFER:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x15], #64
b Tile4LoopCheck
TILE4_POST: TILE4_POST:
cbz x14, TILE4_STORE cbz x14, TILE4_STORE
@ -962,8 +968,8 @@ Tile4Quan:
ReLU_FP32 v4, v5, v6, v7, v30, v31 ReLU_FP32 v4, v5, v6, v7, v30, v31
TILE4_STORE: TILE4_STORE:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x26], x4 st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x6], x4
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x26], x4 st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x6], x4
b Tile4LoopCheck b Tile4LoopCheck
Tile4QuanUseInt8: Tile4QuanUseInt8:
@ -980,8 +986,8 @@ Tile4Quan:
smin v19.16b, v31.16b, v19.16b smin v19.16b, v31.16b, v19.16b
smax v20.16b, v30.16b, v20.16b smax v20.16b, v30.16b, v20.16b
smin v20.16b, v31.16b, v20.16b smin v20.16b, v31.16b, v20.16b
st1 {v19.16b}, [x26], x4 // dst += dz * dst_step st1 {v19.16b}, [x6], x4 // dst += dz * dst_step
st1 {v20.16b}, [x26], x4 st1 {v20.16b}, [x6], x4
Tile4LoopCheck: Tile4LoopCheck:
cmp x24, #2 cmp x24, #2
bge LoopDz_TILE_4 bge LoopDz_TILE_4
@ -989,7 +995,6 @@ Tile4LoopCheck:
LoopDz4_TILE_4: LoopDz4_TILE_4:
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
SET_0_4 v12, v13, v16, v17 SET_0_4 v12, v13, v16, v17
LoopSz4_TILE_4: LoopSz4_TILE_4:
@ -1004,7 +1009,6 @@ LoopSz4_TILE_4:
.inst 0x4e89a4b1 // smmla v17.4s, v5.16b, v9.16b // tile2-oc2, tile2-oc3, tile3-oc2, tile3-oc3 .inst 0x4e89a4b1 // smmla v17.4s, v5.16b, v9.16b // tile2-oc2, tile2-oc3, tile3-oc2, tile3-oc3
bne LoopSz4_TILE_4 bne LoopSz4_TILE_4
LoopSz4End_TILE_4: LoopSz4End_TILE_4:
add x25, x25, x15
sub x24, x24, #1 sub x24, x24, #1
uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3
uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3
@ -1013,12 +1017,11 @@ LoopSz4End_TILE_4:
Int32ToFloat v0, v1, v2, v3 Int32ToFloat v0, v1, v2, v3
Tile4Quan_L4: Tile4Quan_L4:
ld1 {v20.4s}, [x19] // scale ld1 {v20.4s}, [x12], #16 // scale
ld1 {v22.4s}, [x27] // x kernel sum ld1 {v22.4s}, [x8] // x kernel sum
ld1 {v25.4s}, [x6] // weight quan zeropoint ld1 {v25.4s}, [x12], #16 // weight quan zeropoint
MUL_SCALE v20, v0, v1, v2, v3 MUL_SCALE v20, v0, v1, v2, v3
cbz x26, TILE4_MLA_L4
cbz x10, TILE4_MLA_L4
ld1 {v27.4s}, [x10] ld1 {v27.4s}, [x10]
MUL_EXTRA_SCALE v27, v0, v1, v2, v3 MUL_EXTRA_SCALE v27, v0, v1, v2, v3
@ -1034,18 +1037,24 @@ Tile4Quan_L4:
cbz x9, TILE4_ADD_DSTV_L4 cbz x9, TILE4_ADD_DSTV_L4
ld1 {v16.4s}, [x20] // bias ld1 {v16.4s}, [x20] // bias
ADD_BIAS_FLOAT v0, v1, v2, v3, v16 ADD_BIAS_FLOAT v0, v1, v2, v3, v16
b TILE4_POST_L4 cbnz x0, TILE4_POST_L4
b TILE4_L4_ACCUM_BUFFER
TILE4_ADD_DSTV_L4: TILE4_ADD_DSTV_L4:
ld1 {v15.4s, v16.4s, v17.4s, v18.4s}, [x26] ld1 {v15.4s, v16.4s, v17.4s, v18.4s}, [x19], #64
ADD_FLOAT v0, v1, v2, v3, v15, v16, v17, v18 ADD_FLOAT v0, v1, v2, v3, v15, v16, v17, v18
cbnz x0, TILE4_POST_L4
TILE4_L4_ACCUM_BUFFER:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], #64
b Tile4Check
TILE4_POST_L4: TILE4_POST_L4:
cbz x14, TILE4_STORE_L4 cbz x14, TILE4_STORE_L4
ReLU_FP32 v0, v1, v2, v3, v30, v31 ReLU_FP32 v0, v1, v2, v3, v30, v31
TILE4_STORE_L4: TILE4_STORE_L4:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x26], x4 st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x6], x4
b Tile4Check b Tile4Check
Tile4QuanUseInt8_L4: Tile4QuanUseInt8_L4:
@ -1056,31 +1065,28 @@ Tile4Quan_L4:
Int16ToInt8_ONE v8, v9, v19 Int16ToInt8_ONE v8, v9, v19
smax v19.16b, v30.16b, v19.16b smax v19.16b, v30.16b, v19.16b
smin v19.16b, v31.16b, v19.16b smin v19.16b, v31.16b, v19.16b
st1 {v19.16b}, [x26], x4 // dst += dz * dst_step st1 {v19.16b}, [x6], x4 // dst += dz * dst_step
Tile4Check: Tile4Check:
cbz x10, Tile4End cbz x0, Tile4End
add x10, x10, #16 add x0, x0, x21, LSL #2
Tile4End: Tile4End:
sub x7, x7, #4 sub x7, x7, #4
add x0, x0, x21, LSL #2
add x1, x1, #32 add x1, x1, #32
add x27, x27, #16 add x8, x8, #16
add x10, x10, #16
TILE_2: TILE_2:
cmp x7, #2 cmp x7, #2
blt TILE_1 blt TILE_1
mov x24, x5 // dst_depth_quad mov x24, x5 // dst_depth_quad
mov x26, x0 // dst mov x6, x0 // dst
mov x25, x2 // weight mov x12, x2 // weight
mov x19, x8 // scale
mov x20, x9 // bias mov x20, x9 // bias
mov x6, x28 // weightQuanBias
cmp x5, #2 cmp x5, #2
blt LoopDz4_TILE_2 blt LoopDz4_TILE_2
LoopDz_TILE_2: LoopDz_TILE_2:
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
SET_0_4 v12, v13, v14, v15 SET_0_4 v12, v13, v14, v15
LoopSz_TILE_2: LoopSz_TILE_2:
@ -1093,7 +1099,6 @@ LoopSz_TILE_2:
subs x13, x13, #1 subs x13, x13, #1
bne LoopSz_TILE_2 bne LoopSz_TILE_2
LoopSzEnd_TILE_2: LoopSzEnd_TILE_2:
add x25, x25, x15
sub x24, x24, #2 sub x24, x24, #2
uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3
uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3
@ -1102,15 +1107,14 @@ LoopSzEnd_TILE_2:
Int32ToFloat v0, v1, v2, v3 Int32ToFloat v0, v1, v2, v3
Tile2Quan: Tile2Quan:
ld1 {v20.4s, v21.4s}, [x19], #32 // scale ld1 {v20.4s, v21.4s}, [x12], #32 // scale
ld1 {v22.d}[0], [x27] // x kernel sum ld1 {v22.d}[0], [x8] // x kernel sum
ld1 {v25.4s, v26.4s}, [x6], #32 // weight quan zeropoint ld1 {v25.4s, v26.4s}, [x12], #32 // weight quan zeropoint
fmul v0.4s, v0.4s, v20.4s fmul v0.4s, v0.4s, v20.4s
fmul v1.4s, v1.4s, v20.4s fmul v1.4s, v1.4s, v20.4s
fmul v2.4s, v2.4s, v21.4s fmul v2.4s, v2.4s, v21.4s
fmul v3.4s, v3.4s, v21.4s fmul v3.4s, v3.4s, v21.4s
cbz x26, TILE2_MLA
cbz x10, TILE2_MLA
ld1 {v27.d}[0], [x10] ld1 {v27.d}[0], [x10]
fmul v0.4s, v0.4s, v27.s[0] fmul v0.4s, v0.4s, v27.s[0]
fmul v1.4s, v1.4s, v27.s[1] fmul v1.4s, v1.4s, v27.s[1]
@ -1132,23 +1136,24 @@ Tile2Quan:
fadd v1.4s, v1.4s, v16.4s fadd v1.4s, v1.4s, v16.4s
fadd v2.4s, v2.4s, v17.4s fadd v2.4s, v2.4s, v17.4s
fadd v3.4s, v3.4s, v17.4s fadd v3.4s, v3.4s, v17.4s
b TILE2_POST cbnz x0, TILE2_POST
b TILE2_L8_ACCUM_BUFFER
TILE2_ADD_DSTV: TILE2_ADD_DSTV:
ld1 {v18.4s, v19.4s}, [x26], x4 ld1 {v15.4s, v16.4s, v17.4s, v18.4s}, [x19], #64
ld1 {v20.4s, v21.4s}, [x26] ADD_FLOAT v0, v1, v2, v3, v15, v16, v17, v18
fadd v0.4s, v0.4s, v18.4s cbnz x0, TILE2_POST
fadd v1.4s, v1.4s, v19.4s
fadd v2.4s, v2.4s, v20.4s TILE2_L8_ACCUM_BUFFER:
fadd v3.4s, v3.4s, v21.4s st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], #64
sub x26, x26, x4 b Tile2LoopCheck
TILE2_POST: TILE2_POST:
cbz x14, TILE2_STORE cbz x14, TILE2_STORE
ReLU_FP32 v0, v1, v2, v3, v30, v31 ReLU_FP32 v0, v1, v2, v3, v30, v31
TILE2_STORE: TILE2_STORE:
st1 {v0.4s, v1.4s}, [x26], x4 st1 {v0.4s, v1.4s}, [x6], x4
st1 {v2.4s, v3.4s}, [x26], x4 st1 {v2.4s, v3.4s}, [x6], x4
b Tile2LoopCheck b Tile2LoopCheck
Tile2QuanUseInt8: Tile2QuanUseInt8:
@ -1171,8 +1176,8 @@ Tile2Quan:
smin v19.8b, v31.8b, v19.8b smin v19.8b, v31.8b, v19.8b
smax v20.8b, v30.8b, v20.8b smax v20.8b, v30.8b, v20.8b
smin v20.8b, v31.8b, v20.8b smin v20.8b, v31.8b, v20.8b
st1 {v19.8b}, [x26], x4 // dst += dz * dst_step st1 {v19.8b}, [x6], x4 // dst += dz * dst_step
st1 {v20.8b}, [x26], x4 st1 {v20.8b}, [x6], x4
Tile2LoopCheck: Tile2LoopCheck:
cmp x24, #2 cmp x24, #2
@ -1180,7 +1185,6 @@ Tile2LoopCheck:
cbz x24, Tile2Check cbz x24, Tile2Check
LoopDz4_TILE_2: LoopDz4_TILE_2:
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
movi v12.4s, #0 movi v12.4s, #0
movi v13.4s, #0 movi v13.4s, #0
@ -1194,20 +1198,18 @@ LoopSz4_TILE_2:
add x12, x12, #64 add x12, x12, #64
bne LoopSz4_TILE_2 bne LoopSz4_TILE_2
LoopSz4End_TILE_2: LoopSz4End_TILE_2:
add x25, x25, x15
uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3
uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3
scvtf v0.4s, v0.4s scvtf v0.4s, v0.4s
scvtf v1.4s, v1.4s scvtf v1.4s, v1.4s
Tile2Quan_L4: Tile2Quan_L4:
ld1 {v20.4s}, [x19] ld1 {v20.4s}, [x12], #16
ld1 {v22.d}[0], [x27] // x kernel sum ld1 {v22.d}[0], [x8] // x kernel sum
ld1 {v25.4s}, [x6] // weight quan zeropoint ld1 {v25.4s}, [x12] // weight quan zeropoint
fmul v0.4s, v0.4s, v20.4s fmul v0.4s, v0.4s, v20.4s
fmul v1.4s, v1.4s, v20.4s fmul v1.4s, v1.4s, v20.4s
cbz x26, TILE2_MLA_L4
cbz x10, TILE2_MLA_L4
ld1 {v27.d}[0], [x10] ld1 {v27.d}[0], [x10]
fmul v0.4s, v0.4s, v27.s[0] fmul v0.4s, v0.4s, v27.s[0]
fmul v1.4s, v1.4s, v27.s[1] fmul v1.4s, v1.4s, v27.s[1]
@ -1223,18 +1225,24 @@ Tile2Quan_L4:
ld1 {v16.4s}, [x20] // bias ld1 {v16.4s}, [x20] // bias
fadd v0.4s, v0.4s, v16.4s fadd v0.4s, v0.4s, v16.4s
fadd v1.4s, v1.4s, v16.4s fadd v1.4s, v1.4s, v16.4s
b TILE2_POST_L4 cbnz x0, TILE2_POST_L4
b TILE2_L4_ACCUM_BUFFER
TILE2_ADD_DSTV_L4: TILE2_ADD_DSTV_L4:
ld1 {v18.4s, v19.4s}, [x26] ld1 {v15.4s, v16.4s}, [x19], #32
fadd v0.4s, v0.4s, v18.4s fadd v0.4s, v0.4s, v15.4s
fadd v1.4s, v1.4s, v19.4s fadd v1.4s, v1.4s, v16.4s
cbnz x0, TILE2_POST_L4
TILE2_L4_ACCUM_BUFFER:
st1 {v0.4s, v1.4s}, [x15], #32
b Tile2Check
TILE2_POST_L4: TILE2_POST_L4:
cbz x14, TILE2_STORE_L4 cbz x14, TILE2_STORE_L4
ReLU_FP32_2 v0, v1, v30, v31 ReLU_FP32_2 v0, v1, v30, v31
TILE2_STORE_L4: TILE2_STORE_L4:
st1 {v0.4s, v1.4s}, [x26], x4 st1 {v0.4s, v1.4s}, [x6], x4
b Tile2Check b Tile2Check
Tile2QuanUseInt8_L4: Tile2QuanUseInt8_L4:
@ -1248,32 +1256,29 @@ Tile2Quan_L4:
sqxtn v19.8b, v6.8h sqxtn v19.8b, v6.8h
smax v19.8b, v30.8b, v19.8b smax v19.8b, v30.8b, v19.8b
smin v19.8b, v31.8b, v19.8b smin v19.8b, v31.8b, v19.8b
st1 {v19.8b}, [x26], x4 // dst += dz * dst_step st1 {v19.8b}, [x6], x4 // dst += dz * dst_step
Tile2Check: Tile2Check:
cbz x10, Tile2End cbz x0, Tile2End
add x10, x10, #8 add x0, x0, x21, LSL #1
Tile2End: Tile2End:
sub x7, x7, #2 sub x7, x7, #2
add x0, x0, x21, LSL #1
add x1, x1, #16 add x1, x1, #16
add x27, x27, #8 add x8, x8, #8
add x10, x10, #8
TILE_1: TILE_1:
cmp x7, #1 cmp x7, #1
blt End blt End
mov x24, x5 // dst_depth_quad mov x24, x5 // dst_depth_quad
mov x26, x0 // dst mov x6, x0 // dst
mov x25, x2 // weight mov x12, x2 // weight
mov x19, x8 // scale
mov x20, x9 // bias mov x20, x9 // bias
mov x6, x28 // weightQuanBias
cmp x5, #2 cmp x5, #2
blt LoopDz4_TILE_1 blt LoopDz4_TILE_1
LoopDz_TILE_1: LoopDz_TILE_1:
//ld1 {v0.4s}, [x20], #16 // bias //ld1 {v0.4s}, [x20], #16 // bias
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
movi v16.4s, #0 movi v16.4s, #0
@ -1291,27 +1296,25 @@ LoopSz_TILE_1:
.inst 0x4e8ba453 // smmla v19.4s, v2.16b, v11.16b .inst 0x4e8ba453 // smmla v19.4s, v2.16b, v11.16b
bne LoopSz_TILE_1 bne LoopSz_TILE_1
LoopSzEnd_TILE_1: LoopSzEnd_TILE_1:
add x25, x25, x15
sub x24, x24, #2 sub x24, x24, #2
uzp1 v27.2d, v16.2d, v17.2d uzp1 v25.2d, v16.2d, v17.2d
uzp1 v26.2d, v18.2d, v19.2d uzp1 v26.2d, v18.2d, v19.2d
scvtf v27.4s, v27.4s scvtf v25.4s, v25.4s
scvtf v26.4s, v26.4s scvtf v26.4s, v26.4s
Tile1Quan: Tile1Quan:
ld1 {v0.4s, v1.4s}, [x19], #32 // scale ld1 {v0.4s, v1.4s}, [x12], #32 // scale
ld1 {v6.s}[0], [x27] // x kernel sum ld1 {v6.s}[0], [x8] // x kernel sum
ld1 {v8.4s, v9.4s}, [x6], #32 // weight quan zeropoint ld1 {v8.4s, v9.4s}, [x12], #32 // weight quan zeropoint
fmul v27.4s, v27.4s, v0.4s fmul v25.4s, v25.4s, v0.4s
fmul v26.4s, v26.4s, v1.4s fmul v26.4s, v26.4s, v1.4s
cbz x26, TILE1_MLA
cbz x10, TILE1_MLA
ld1 {v10.s}[0], [x10] ld1 {v10.s}[0], [x10]
fmul v27.4s, v27.4s, v10.s[0] fmul v25.4s, v25.4s, v10.s[0]
fmul v26.4s, v26.4s, v10.s[0] fmul v26.4s, v26.4s, v10.s[0]
TILE1_MLA: TILE1_MLA:
MLA_WEIGHTZERO v27, v6, v8, 0 // tile:0, oc:0-3 MLA_WEIGHTZERO v25, v6, v8, 0 // tile:0, oc:0-3
MLA_WEIGHTZERO v26, v6, v9, 0 // tile:0, oc:4-7 MLA_WEIGHTZERO v26, v6, v9, 0 // tile:0, oc:4-7
cbnz w23, Tile1QuanUseInt8 cbnz w23, Tile1QuanUseInt8
@ -1319,36 +1322,40 @@ Tile1Quan:
TILE1_ADD_BIAS: TILE1_ADD_BIAS:
cbz x9, TILE1_ADD_DSTV cbz x9, TILE1_ADD_DSTV
ld1 {v16.4s, v17.4s}, [x20], #32 // bias ld1 {v16.4s, v17.4s}, [x20], #32 // bias
fadd v27.4s, v27.4s, v16.4s fadd v25.4s, v25.4s, v16.4s
fadd v26.4s, v26.4s, v17.4s fadd v26.4s, v26.4s, v17.4s
b TILE1_POST cbnz x0, TILE1_POST
b TILE1_L8_ACCUM_BUFFER
TILE1_ADD_DSTV: TILE1_ADD_DSTV:
ld1 {v16.4s}, [x26], x4 ld1 {v15.4s, v16.4s}, [x19], #32
ld1 {v17.4s}, [x26] fadd v25.4s, v25.4s, v15.4s
fadd v27.4s, v27.4s, v16.4s fadd v26.4s, v26.4s, v16.4s
fadd v26.4s, v26.4s, v17.4s cbnz x0, TILE1_POST
sub x26, x26, x4
TILE1_L8_ACCUM_BUFFER:
st1 {v25.4s, v26.4s}, [x15], #32
b Tile1LoopEnd
TILE1_POST: TILE1_POST:
cbz x14, TILE1_STORE cbz x14, TILE1_STORE
fmin v27.4s, v27.4s, v31.4s fmin v25.4s, v25.4s, v31.4s
fmax v27.4s, v27.4s, v30.4s fmax v25.4s, v25.4s, v30.4s
fmin v26.4s, v26.4s, v31.4s fmin v26.4s, v26.4s, v31.4s
fmax v26.4s, v26.4s, v30.4s fmax v26.4s, v26.4s, v30.4s
TILE1_STORE: TILE1_STORE:
st1 {v27.4s}, [x26], x4 st1 {v25.4s}, [x6], x4
st1 {v26.4s}, [x26], x4 st1 {v26.4s}, [x6], x4
b Tile1LoopEnd b Tile1LoopEnd
Tile1QuanUseInt8: Tile1QuanUseInt8:
ld1 {v16.4s, v17.4s}, [x20], #32 // bias ld1 {v16.4s, v17.4s}, [x20], #32 // bias
fadd v27.4s, v27.4s, v16.4s fadd v25.4s, v25.4s, v16.4s
fadd v26.4s, v26.4s, v17.4s fadd v26.4s, v26.4s, v17.4s
fcvtas v27.4s, v27.4s fcvtas v25.4s, v25.4s
fcvtas v26.4s, v26.4s fcvtas v26.4s, v26.4s
sqxtn v6.4h, v27.4s sqxtn v6.4h, v25.4s
sqxtn v7.4h, v26.4s sqxtn v7.4h, v26.4s
sqxtn v6.8b, v6.8h sqxtn v6.8b, v6.8h
sqxtn v7.8b, v7.8h sqxtn v7.8b, v7.8h
@ -1356,8 +1363,8 @@ Tile1Quan:
smin v6.16b, v31.16b, v6.16b smin v6.16b, v31.16b, v6.16b
smax v7.16b, v30.16b, v7.16b smax v7.16b, v30.16b, v7.16b
smin v7.16b, v31.16b, v7.16b smin v7.16b, v31.16b, v7.16b
st1 {v6.s}[0], [x26], x4 // dst += dz * dst_step st1 {v6.s}[0], [x6], x4 // dst += dz * dst_step
st1 {v7.s}[0], [x26], x4 st1 {v7.s}[0], [x6], x4
Tile1LoopEnd: Tile1LoopEnd:
cmp x24, #2 cmp x24, #2
@ -1366,7 +1373,6 @@ Tile1LoopEnd:
LoopDz4_TILE_1: LoopDz4_TILE_1:
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
movi v16.4s, #0 movi v16.4s, #0
@ -1380,16 +1386,15 @@ LoopSz4_TILE_1:
.inst 0x4e89a451 // smmla v17.4s, v2.16b, v9.16b .inst 0x4e89a451 // smmla v17.4s, v2.16b, v9.16b
bne LoopSz4_TILE_1 bne LoopSz4_TILE_1
LoopSz4End_TILE_1: LoopSz4End_TILE_1:
add x25, x25, x15
uzp1 v27.2d, v16.2d, v17.2d uzp1 v27.2d, v16.2d, v17.2d
scvtf v27.4s, v27.4s scvtf v27.4s, v27.4s
Tile1Quan_L4: Tile1Quan_L4:
ld1 {v0.4s}, [x19] // scale ld1 {v0.4s}, [x12], #16 // scale
ld1 {v6.s}[0], [x27] // x kernel sum ld1 {v6.s}[0], [x8] // x kernel sum
ld1 {v8.4s}, [x6] // weight quan zeropoint ld1 {v8.4s}, [x12] // weight quan zeropoint
fmul v27.4s, v27.4s, v0.4s fmul v27.4s, v27.4s, v0.4s
cbz x10, TILE1_MLA_L4 cbz x26, TILE1_MLA_L4
ld1 {v10.s}[0], [x10] ld1 {v10.s}[0], [x10]
fmul v27.4s, v27.4s, v10.s[0] fmul v27.4s, v27.4s, v10.s[0]
@ -1402,11 +1407,17 @@ Tile1Quan_L4:
cbz x9, TILE1_ADD_DSTV_L4 cbz x9, TILE1_ADD_DSTV_L4
ld1 {v16.4s}, [x20] // bias ld1 {v16.4s}, [x20] // bias
fadd v27.4s, v27.4s, v16.4s fadd v27.4s, v27.4s, v16.4s
b TILE1_POST_L4 cbnz x0, TILE1_POST_L4
b TILE1_L4_ACCUM_BUFFER
TILE1_ADD_DSTV_L4: TILE1_ADD_DSTV_L4:
ld1 {v16.4s}, [x26] ld1 {v15.4s}, [x19], #16
fadd v27.4s, v27.4s, v16.4s fadd v27.4s, v27.4s, v15.4s
cbnz x0, TILE1_POST_L4
TILE1_L4_ACCUM_BUFFER:
st1 {v27.4s}, [x15], #16
b End
TILE1_POST_L4: TILE1_POST_L4:
cbz x14, TILE1_STORE_L4 cbz x14, TILE1_STORE_L4
@ -1414,7 +1425,7 @@ Tile1Quan_L4:
fmax v27.4s, v27.4s, v30.4s fmax v27.4s, v27.4s, v30.4s
TILE1_STORE_L4: TILE1_STORE_L4:
st1 {v27.4s}, [x26], x4 st1 {v27.4s}, [x6], x4
b End b End
Tile1QuanUseInt8_L4: Tile1QuanUseInt8_L4:
@ -1425,10 +1436,9 @@ Tile1Quan_L4:
sqxtn v6.8b, v6.8h sqxtn v6.8b, v6.8h
smax v6.8b, v30.8b, v6.8b smax v6.8b, v30.8b, v6.8b
smin v6.8b, v31.8b, v6.8b smin v6.8b, v31.8b, v6.8b
st1 {v6.s}[0], [x26], x4 // dst += dz * dst_step st1 {v6.s}[0], [x6], x4 // dst += dz * dst_step
End: End:
ldp x27, x28, [sp, #(16 * 8)]
ldp x25, x26, [sp, #(16 * 7)] ldp x25, x26, [sp, #(16 * 7)]
ldp x23, x24, [sp, #(16 * 6)] ldp x23, x24, [sp, #(16 * 6)]
ldp x19, x20, [sp, #(16 * 5)] ldp x19, x20, [sp, #(16 * 5)]
@ -1436,7 +1446,7 @@ ldp x21, x22, [sp, #(16 * 4)]
ldp d8, d9, [sp, #(16 * 3)] ldp d8, d9, [sp, #(16 * 3)]
ldp d10, d11, [sp, #(16 * 2)] ldp d10, d11, [sp, #(16 * 2)]
ldp d12, d13, [sp, #(16 * 1)] ldp d12, d13, [sp, #(16 * 1)]
ldp d14, d15, [sp], #(16 * 10) ldp d14, d15, [sp], #(16 * 8)
ret ret
#endif // __aarch64__ #endif // __aarch64__
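The diff above rewires every TILEx epilogue around the new accumBuffer pointer (x15, loaded from offset #96 of the post parameters). Previously each tile's partial sums were read back from and re-accumulated in dst itself (the old x26 loads); this commit appears to route them through accumBuffer instead, so dst (x0) can stay null until the final weight block. Below is a rough C++ sketch of the control flow those cbz x9 / cbnz x0 branches implement; all names are illustrative, not MNN API: x9 is the bias pointer (non-null only for the first weight block), x19 the accumBuffer read cursor holding earlier blocks' partials, x15 the accumBuffer write cursor.

#include <algorithm>
#include <cstring>

// Hedged sketch of the per-tile epilogue; layout details are simplified.
static void tileEpilogueSketch(float* dst, const float* bias, int ocPack,
                               const float* accumRead, float* accumWrite,
                               float* sum, int count, const float* fp32minmax) {
    if (bias != nullptr) {        // TILEx_ADD_BIAS: first block, bias added once
        for (int i = 0; i < count; ++i) sum[i] += bias[i % ocPack]; // per-oc bias, layout simplified
    } else {                      // TILEx_ADD_DSTV: fold in earlier blocks' partials
        for (int i = 0; i < count; ++i) sum[i] += accumRead[i];
    }
    if (dst == nullptr) {         // TILEx_*_ACCUM_BUFFER: park the partial sums
        std::memcpy(accumWrite, sum, count * sizeof(float));
        return;
    }
    if (fp32minmax != nullptr) {  // TILEx_POST: optional relu/relu6 clamp
        for (int i = 0; i < count; ++i)
            sum[i] = std::min(fp32minmax[1], std::max(fp32minmax[0], sum[i]));
    }
    std::memcpy(dst, sum, count * sizeof(float)); // TILEx_STORE
}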

@ -0,0 +1,233 @@
//
// MNNAbsMaxFP32_Pack8.S
//
// Created by MNN on 2023/10/31.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef __aarch64__
#include "MNNAsmGlobal.h"
.text
.align 5
.macro Abs z0, z1, z2, z3
fabs \z0\().4s, \z0\().4s
fabs \z1\().4s, \z1\().4s
fabs \z2\().4s, \z2\().4s
fabs \z3\().4s, \z3\().4s
.endm
.macro Max d0, d1, d2, d3, z0, z1, z2, z3
fmax \d0\().4s, \d0\().4s, \z0\().4s
fmax \d1\().4s, \d1\().4s, \z1\().4s
fmax \d2\().4s, \d2\().4s, \z2\().4s
fmax \d3\().4s, \d3\().4s, \z3\().4s
.endm
.macro ReduceMax8 s0, s1, s2, s3, s4, s5, s6, s7, z0
fmaxp \s0\().4s, \s0\().4s, \s1\().4s // 0 0 0 0
fmaxp \s2\().4s, \s2\().4s, \s3\().4s // 1 1 1 1
fmaxp \s4\().4s, \s4\().4s, \s5\().4s // 2 2 2 2
fmaxp \s6\().4s, \s6\().4s, \s7\().4s // 3 3 3 3
fmaxp \s0\().4s, \s0\().4s, \s2\().4s // 0 0 1 1
fmaxp \s4\().4s, \s4\().4s, \s6\().4s // 2 2 3 3
fmaxp \z0\().4s, \s0\().4s, \s4\().4s // 0 1 2 3
.endm
.macro ReduceMax4 s0, s1, s2, s3, z0, z1, z2
fmaxp \z1\().4s, \s0\().4s, \s1\().4s // 0 0 0 0
fmaxp \z2\().4s, \s2\().4s, \s3\().4s // 1 1 1 1
fmaxp \s0\().4s, \z1\().4s, \z2\().4s // 0 0 1 1
fmaxp \z0\().4s, \s0\().4s, \s0\().4s // 0 1
.endm
//void MNNAbsMaxFP32_Pack8(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack)
asm_function MNNAbsMaxFP32_Pack8
// x0: source, x1:absmax, x2:src_depth_quad, x3:realSize
stp d14, d15, [sp, #(-16 * 4)]!
stp d12, d13, [sp, #(16 * 1)]
stp d10, d11, [sp, #(16 * 2)]
stp d8, d9, [sp, #(16 * 3)]
Start:
lsl x6, x3, #5 // src_step = batch * 8 * sizeof(float32_t) = batch << 5
TILE_10:
cmp x3, #10
blt TILE_8
mov x5, x2 // src_depth_quad
mov x7, x0 // src
sub x8, x6, #256 // src_step - 256
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x7], #64 // E0, E0, E1, E1
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x7], #64 // E2, E2, E3, E3
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x7], #64 // E4, E4, E5, E5
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x7], #64 // E6, E6, E7, E7
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x7], x8 // E8, E8, E9, E9
Abs v0, v1, v2, v3
Abs v4, v5, v6, v7
Abs v8, v9, v10, v11
Abs v12, v13, v14, v15
Abs v16, v17, v18, v19
subs x5, x5, #1
beq Tile10End
LoopSz_10:
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x7], #64 // E0, E0, E1, E1
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x7], #64 // E2, E2, E3, E3
ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x7], #64 // E4, E4, E5, E5
Abs v20, v21, v22, v23
Abs v24, v25, v26, v27
Abs v28, v29, v30, v31
Max v0, v1, v2, v3, v20, v21, v22, v23
Max v4, v5, v6, v7, v24, v25, v26, v27
Max v8, v9, v10, v11, v28, v29, v30, v31
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x7], #64 // E6, E6, E7, E7
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x7], x8 // E8, E8, E9, E9
Abs v20, v21, v22, v23
Abs v24, v25, v26, v27
Max v12, v13, v14, v15, v20, v21, v22, v23
Max v16, v17, v18, v19, v24, v25, v26, v27
subs x5, x5, #1
bne LoopSz_10
Tile10End:
ReduceMax8 v0, v1, v2, v3, v4, v5, v6, v7, v20
ReduceMax8 v8, v9, v10, v11, v12, v13, v14, v15, v21
ReduceMax4 v16, v17, v18, v19, v22, v23, v24
st1 {v20.4s, v21.4s}, [x1], #32
st1 {v22.2s}, [x1], #8
sub x3, x3, #10
add x0, x0, #320 // src += 10 * pack * sizeof(float)
cbz x3, End
b TILE_10
TILE_8:
cmp x3, #8
blt TILE_4
mov x5, x2 // src_depth_quad
mov x7, x0 // src
sub x8, x6, #192 // src_step - 192
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x7], #64 // E0, E0, E1, E1
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x7], #64 // E2, E2, E3, E3
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x7], #64 // E4, E4, E5, E5
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x7], x8 // E6, E6, E7, E7
Abs v0, v1, v2, v3
Abs v4, v5, v6, v7
Abs v8, v9, v10, v11
Abs v12, v13, v14, v15
subs x5, x5, #1
beq Tile8End
LoopSz_8:
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x7], #64 // E0, E0, E1, E1
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x7], #64 // E2, E2, E3, E3
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x7], #64 // E4, E4, E5, E5
ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x7], x8 // E6, E6, E7, E7
Abs v16, v17, v18, v19
Abs v20, v21, v22, v23
Abs v24, v25, v26, v27
Abs v28, v29, v30, v31
Max v0, v1, v2, v3, v16, v17, v18, v19
Max v4, v5, v6, v7, v20, v21, v22, v23
Max v8, v9, v10, v11, v24, v25, v26, v27
Max v12, v13, v14, v15, v28, v29, v30, v31
subs x5, x5, #1
bne LoopSz_8
Tile8End:
ReduceMax8 v0, v1, v2, v3, v4, v5, v6, v7, v16
ReduceMax8 v8, v9, v10, v11, v12, v13, v14, v15, v17
st1 {v16.4s, v17.4s}, [x1], #32
sub x3, x3, #8
add x0, x0, #256 // src += 8 * pack * sizeof(float)
b TILE_8
TILE_4:
cmp x3, #4
blt TILE_1
mov x5, x2 // src_depth_quad
mov x7, x0 // src
sub x8, x6, #64 // src_step - 64
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x7], #64 // E0, E0, E1, E1
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x7], x8 // E2, E2, E3, E3
Abs v0, v1, v2, v3
Abs v4, v5, v6, v7
subs x5, x5, #1
beq Tile4End
LoopSz_4:
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x7], #64 // E0, E0, E1, E1
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x7], x8 // E2, E2, E3, E3
Abs v16, v17, v18, v19
Abs v20, v21, v22, v23
Max v0, v1, v2, v3, v16, v17, v18, v19
Max v4, v5, v6, v7, v20, v21, v22, v23
subs x5, x5, #1
bne LoopSz_4
Tile4End:
ReduceMax8 v0, v1, v2, v3, v4, v5, v6, v7, v16
st1 {v16.4s}, [x1], #16
sub x3, x3, #4
add x0, x0, #128 // src += 4 * pack * sizeof(float)
b TILE_4
TILE_1:
cmp x3, #1
blt End
mov x5, x2 // src_depth_quad
mov x7, x0 // src
// absmax: v0, v1
ld1 {v0.4s, v1.4s}, [x7], x6
fabs v0.4s, v0.4s
fabs v1.4s, v1.4s
subs x5, x5, #1
beq Tile1End
LoopSz_1:
ld1 {v16.4s, v17.4s}, [x7], x6
// absmax = fmax(absmax, abs(x))
fabs v16.4s, v16.4s
fabs v17.4s, v17.4s
fmax v0.4s, v0.4s, v16.4s
fmax v1.4s, v1.4s, v17.4s
subs x5, x5, #1
bne LoopSz_1
Tile1End:
// reduce max
fmaxp v2.4s, v0.4s, v1.4s // 0 0 0 0
fmaxp v3.4s, v2.4s, v2.4s // 0 0
fmaxp v4.4s, v3.4s, v3.4s
st1 {v4.s}[0], [x1], #4
subs x3, x3, #1
add x0, x0, #32 // src += 1 * 8(pack) * 4(sizeof(float32_t))
bne TILE_1
End:
ldp d8, d9, [sp, #(16 * 3)]
ldp d10, d11, [sp, #(16 * 2)]
ldp d12, d13, [sp, #(16 * 1)]
ldp d14, d15, [sp], #(16 * 4)
ret
#endif
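For reference, here is a scalar C++ sketch of what MNNAbsMaxFP32_Pack8 computes, reconstructed from the assembly above: per batch element, the maximum of |x| over all depth quads and the 8 packed lanes. The _ref suffix and the [src_depth_quad][realSize][pack] layout reading are my assumptions, taken from the x6 = realSize * pack * sizeof(float) stride.

#include <algorithm>
#include <cmath>
#include <cstddef>

void MNNAbsMaxFP32_Pack8_ref(const float* source, float* absmax,
                             size_t src_depth_quad, size_t realSize, int pack) {
    // pack is expected to be 8 for this kernel
    for (size_t e = 0; e < realSize; ++e) {
        float maxVal = 0.f;
        for (size_t z = 0; z < src_depth_quad; ++z) {
            const float* x = source + (z * realSize + e) * pack;
            for (int k = 0; k < pack; ++k) {
                maxVal = std::max(maxVal, std::fabs(x[k])); // fabs + fmax
            }
        }
        absmax[e] = maxVal; // the fmaxp trees above do this reduction in-register
    }
}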

@ -0,0 +1,328 @@
//
// MNNDynamicQuantFP32_Pack8.S
// MNN
//
// Created by MNN on 2023/10/31.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef __aarch64__
#include "MNNAsmGlobal.h"
.text
.align 5
.macro Round z0, z1, z2, z3
fcvtas \z0\().4s, \z0\().4s
fcvtas \z1\().4s, \z1\().4s
fcvtas \z2\().4s, \z2\().4s
fcvtas \z3\().4s, \z3\().4s
.endm
.macro Transpose z0, z1, z2, z3, t0, t1, t2, t3
trn1 \t0\().4s, \z0\().4s, \z1\().4s
trn1 \t1\().4s, \z2\().4s, \z3\().4s
trn2 \t2\().4s, \z0\().4s, \z1\().4s
trn2 \t3\().4s, \z2\().4s, \z3\().4s
trn1 \z0\().2d, \t0\().2d, \t1\().2d
trn1 \z1\().2d, \t2\().2d, \t3\().2d
trn2 \z2\().2d, \t0\().2d, \t1\().2d
trn2 \z3\().2d, \t2\().2d, \t3\().2d
.endm
.macro Add_4x4 d0, d1, d2, d3
add \d0\().4s, \d1\().4s, \d0\().4s
add \d2\().4s, \d3\().4s, \d2\().4s
add \d0\().4s, \d0\().4s, \d2\().4s
.endm
//void MNNDynamicQuantFP32_Pack8(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, int pack)
asm_function MNNDynamicQuantFP32_Pack8
// x0: src, x1:dst, x2:scale, x3:src_depth_quad, x4:realSize
stp d14, d15, [sp, #(-16 * 4)]!
stp d12, d13, [sp, #(16 * 1)]
stp d10, d11, [sp, #(16 * 2)]
stp d8, d9, [sp, #(16 * 3)]
Start:
lsl x6, x4, #3 // dst_step = batch * unit * sizeof(int8_t) = batch * 8 = batch << 3
lsl x7, x6, #2 // src_step = dst_step * 4 (sizeof(float32_t)) = dst_step << 2
TILE_10:
cmp x4, #10
blt TILE_8
sub x8, x7, #256 // src_step - 256
sub x11, x6, #64 // dst_step-64
mov x9, x0 // src
mov x10, x1 // dst
mov x12, x3 // src_depth_quad
// quant_scale: v16, v17, v30 (10 batches * sizeof(float32_t))
ld1 {v16.4s, v17.4s}, [x2], #32
ld1 {v30.8b}, [x2], #8
LoopSz_10:
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x9], #64 // E0, E0, E1, E1
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x9], #64 // E2, E2, E3, E3
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x9], #64 // E4, E4, E5, E5
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9], #64 // E6, E6, E7, E7
ld1 {v18.4s, v19.4s, v20.4s, v21.4s}, [x9], x8 // E8, E8, E9, E9
// float32_t x = x * quant_scale
fmul v0.4s, v0.4s, v16.s[0]
fmul v1.4s, v1.4s, v16.s[0]
fmul v2.4s, v2.4s, v16.s[1]
fmul v3.4s, v3.4s, v16.s[1]
fmul v4.4s, v4.4s, v16.s[2]
fmul v5.4s, v5.4s, v16.s[2]
fmul v6.4s, v6.4s, v16.s[3]
fmul v7.4s, v7.4s, v16.s[3]
fmul v8.4s, v8.4s, v17.s[0]
fmul v9.4s, v9.4s, v17.s[0]
fmul v10.4s, v10.4s, v17.s[1]
fmul v11.4s, v11.4s, v17.s[1]
fmul v12.4s, v12.4s, v17.s[2]
fmul v13.4s, v13.4s, v17.s[2]
fmul v14.4s, v14.4s, v17.s[3]
fmul v15.4s, v15.4s, v17.s[3]
fmul v18.4s, v18.4s, v30.s[0]
fmul v19.4s, v19.4s, v30.s[0]
fmul v20.4s, v20.4s, v30.s[1]
fmul v21.4s, v21.4s, v30.s[1]
// int32_t x = round(x)
Round v0, v1, v2, v3
Round v4, v5, v6, v7
Round v8, v9, v10, v11
Round v12, v13, v14, v15
Round v18, v19, v20, v21
// y = (int8_t)x
sqxtn v22.4h, v0.4s
sqxtn2 v22.8h, v1.4s
sqxtn v23.4h, v2.4s
sqxtn2 v23.8h, v3.4s
sqxtn v24.4h, v4.4s
sqxtn2 v24.8h, v5.4s
sqxtn v25.4h, v6.4s
sqxtn2 v25.8h, v7.4s
sqxtn v26.4h, v8.4s
sqxtn2 v26.8h, v9.4s
sqxtn v27.4h, v10.4s
sqxtn2 v27.8h, v11.4s
sqxtn v28.4h, v12.4s
sqxtn2 v28.8h, v13.4s
sqxtn v29.4h, v14.4s
sqxtn2 v29.8h, v15.4s
sqxtn v0.4h, v18.4s
sqxtn2 v0.8h, v19.4s
sqxtn v1.4h, v20.4s
sqxtn2 v1.8h, v21.4s
sqxtn v2.8b, v22.8h
sqxtn2 v2.16b, v23.8h
sqxtn v3.8b, v24.8h
sqxtn2 v3.16b, v25.8h
sqxtn v4.8b, v26.8h
sqxtn2 v4.16b, v27.8h
sqxtn v5.8b, v28.8h
sqxtn2 v5.16b, v29.8h
sqxtn v6.8b, v0.8h
sqxtn2 v6.16b, v1.8h
st1 {v2.16b, v3.16b, v4.16b, v5.16b}, [x10], #64
st1 {v6.16b}, [x10], x11
subs x12, x12, #1
bne LoopSz_10
Tile10End:
sub x4, x4, #10 // batch -= 10
add x0, x0, #320 // src += 10 * 8 * sizeof(float32_t)
add x1, x1, #80 // dst += 10 * 8 * sizeof(int8_t)
b TILE_10
TILE_8:
cmp x4, #8
blt TILE_4
sub x8, x7, #192 // src_step - 192
mov x9, x0 // src
mov x10, x1 // dst
mov x12, x3 // src_depth_quad
// quant_scale: v16, v17 (8 batches * sizeof(float32_t))
ld1 {v16.4s, v17.4s}, [x2], #32
LoopSz_8:
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x9], #64 // E0, E0, E1, E1
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x9], #64 // E2, E2, E3, E3
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x9], #64 // E4, E4, E5, E5
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x9], x8 // E6, E6, E7, E7
// float32_t x = x * quant_scale
fmul v0.4s, v0.4s, v16.s[0]
fmul v1.4s, v1.4s, v16.s[0]
fmul v2.4s, v2.4s, v16.s[1]
fmul v3.4s, v3.4s, v16.s[1]
fmul v4.4s, v4.4s, v16.s[2]
fmul v5.4s, v5.4s, v16.s[2]
fmul v6.4s, v6.4s, v16.s[3]
fmul v7.4s, v7.4s, v16.s[3]
fmul v8.4s, v8.4s, v17.s[0]
fmul v9.4s, v9.4s, v17.s[0]
fmul v10.4s, v10.4s, v17.s[1]
fmul v11.4s, v11.4s, v17.s[1]
fmul v12.4s, v12.4s, v17.s[2]
fmul v13.4s, v13.4s, v17.s[2]
fmul v14.4s, v14.4s, v17.s[3]
fmul v15.4s, v15.4s, v17.s[3]
// int32_t x = round(x)
Round v0, v1, v2, v3
Round v4, v5, v6, v7
Round v8, v9, v10, v11
Round v12, v13, v14, v15
// y = (int8_t)x
sqxtn v18.4h, v0.4s
sqxtn2 v18.8h, v1.4s
sqxtn v19.4h, v2.4s
sqxtn2 v19.8h, v3.4s
sqxtn v20.4h, v4.4s
sqxtn2 v20.8h, v5.4s
sqxtn v21.4h, v6.4s
sqxtn2 v21.8h, v7.4s
sqxtn v22.4h, v8.4s
sqxtn2 v22.8h, v9.4s
sqxtn v23.4h, v10.4s
sqxtn2 v23.8h, v11.4s
sqxtn v24.4h, v12.4s
sqxtn2 v24.8h, v13.4s
sqxtn v25.4h, v14.4s
sqxtn2 v25.8h, v15.4s
sqxtn v26.8b, v18.8h
sqxtn2 v26.16b, v19.8h
sqxtn v27.8b, v20.8h
sqxtn2 v27.16b, v21.8h
sqxtn v28.8b, v22.8h
sqxtn2 v28.16b, v23.8h
sqxtn v29.8b, v24.8h
sqxtn2 v29.16b, v25.8h
st1 {v26.16b, v27.16b, v28.16b, v29.16b}, [x10], x6
subs x12, x12, #1
bne LoopSz_8
Tile8End:
sub x4, x4, #8 // batch -= 8
add x0, x0, #256 // src += 8 * 8 * sizeof(float32_t)
add x1, x1, #64 // dst += 8 * 8 * sizeof(int8_t)
b TILE_8
TILE_4:
cmp x4, #4
blt TILE_1
sub x8, x7, #64
mov x9, x0 // src
mov x10, x1 // dst
mov x12, x3 // src_depth_quad
ld1 {v16.4s}, [x2], #16
LoopSz_4:
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x9], #64 // E0, E0, E1, E1
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x9], x8 // E2, E2, E3, E3
// float32_t x = x * quant_scale
fmul v0.4s, v0.4s, v16.s[0]
fmul v1.4s, v1.4s, v16.s[0]
fmul v2.4s, v2.4s, v16.s[1]
fmul v3.4s, v3.4s, v16.s[1]
fmul v4.4s, v4.4s, v16.s[2]
fmul v5.4s, v5.4s, v16.s[2]
fmul v6.4s, v6.4s, v16.s[3]
fmul v7.4s, v7.4s, v16.s[3]
// int32_t x = round(x)
Round v0, v1, v2, v3
Round v4, v5, v6, v7
// y = (int8_t)x
sqxtn v18.4h, v0.4s
sqxtn2 v18.8h, v1.4s
sqxtn v19.4h, v2.4s
sqxtn2 v19.8h, v3.4s
sqxtn v20.4h, v4.4s
sqxtn2 v20.8h, v5.4s
sqxtn v21.4h, v6.4s
sqxtn2 v21.8h, v7.4s
sqxtn v26.8b, v18.8h
sqxtn2 v26.16b, v19.8h
sqxtn v27.8b, v20.8h
sqxtn2 v27.16b, v21.8h
st1 {v26.16b, v27.16b}, [x10], x6
subs x12, x12, #1
bne LoopSz_4
Tile4End:
sub x4, x4, #4 // batch -= 4
add x0, x0, #128 // src += 4 * 8 * sizeof(float32_t)
add x1, x1, #32 // dst += 4 * 8 * sizeof(int8_t)
b TILE_4
TILE_1:
cmp x4, #1
blt End
mov x9, x0 // src
mov x10, x1 // dst
mov x12, x3 // src_depth_quad
// quant_scale: v8
ld1 {v8.s}[0], [x2], #4
LoopSz_1:
ld1 {v0.4s, v1.4s}, [x9], x7
fmul v0.4s, v0.4s, v8.s[0]
fmul v1.4s, v1.4s, v8.s[0]
// int16_t x = round(x)
fcvtas v0.4s, v0.4s
fcvtas v1.4s, v1.4s
// y = (int8_t)x
sqxtn v7.4h, v0.4s
sqxtn2 v7.8h, v1.4s
sqxtn v7.8b, v7.8h
st1 {v7.8b}, [x10], x6
subs x12, x12, #1
bne LoopSz_1
Tile1End:
subs x4, x4, #1 // batch -= 1
add x0, x0, #32 // src += 1 * 8 * sizeof(float32_t)
add x1, x1, #8 // dst += 1 * 8 * sizeof(int8_t)
bne TILE_1
End:
ldp d8, d9, [sp, #(16 * 3)]
ldp d10, d11, [sp, #(16 * 2)]
ldp d12, d13, [sp, #(16 * 1)]
ldp d14, d15, [sp], #(16 * 4)
ret
#endif
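A scalar C++ sketch of MNNDynamicQuantFP32_Pack8, reconstructed from the assembly above (the _ref suffix and layout indexing are assumptions): each lane is scaled by its batch element's quant scale, rounded by fcvtas (round to nearest, ties away from zero, which std::round matches), and narrowed with saturation by the sqxtn chain.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstddef>

void MNNDynamicQuantFP32_Pack8_ref(const float* src, int8_t* dst,
                                   const float* scale, size_t src_depth_quad,
                                   size_t realSize, int pack) {
    for (size_t z = 0; z < src_depth_quad; ++z) {
        for (size_t e = 0; e < realSize; ++e) {
            const float* x = src + (z * realSize + e) * pack;
            int8_t* y = dst + (z * realSize + e) * pack;
            for (int k = 0; k < pack; ++k) {
                float v = std::round(x[k] * scale[e]);    // fmul + fcvtas
                v = std::min(127.0f, std::max(-128.0f, v)); // sqxtn saturation
                y[k] = static_cast<int8_t>(v);
            }
        }
    }
}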

@ -52,114 +52,29 @@
Note: only used in dynamic quant, so there is no need to compare min/max! Note: only used in dynamic quant, so there is no need to compare min/max!
*/ */
asm_function MNNDynamicUpdateConvBiasScale asm_function MNNDynamicUpdateConvBiasScale
//MNNDynamicUpdateConvBiasScale(biasFloat.data(), scaleFloat.data(), biasfp32, weightDequantScale, //MNNDynamicUpdateConvBiasScale(float* newbias, float* oldbias, float* weightKernelSum, float* inputZero, size_t ocQuad);
//inputScale, weightKernelSum, inputZero, UP_DIV(output->channel(), 4), alphaSize) //x0:newbias, x1:oldbias, x2:weightKernelSum, x3:inputZero, x4:ocQuad
//x0:biasFloat, x1:scaleFloat, x2:biasfp32, x3:weightDequantScale, x4:inputScale, x5:weightKernelSum, x6:inputZero, x7:ocQuad
//Load from sp: x9: scaleSize
ldr x9, [sp, #0]
stp d14, d15, [sp, #-64]! stp d14, d15, [sp, #-64]!
stp d12, d13, [sp, #16] stp d12, d13, [sp, #16]
stp d10, d11, [sp, #32] stp d10, d11, [sp, #32]
stp d8, d9, [sp, #48] stp d8, d9, [sp, #48]
ld1r {v31.4s}, [x4] // input dequant scale ld1r {v30.4s}, [x3] // input dequant zero:fp32 zero
ld1r {v30.4s}, [x6] // input dequant zero:fp32 zero
lsr x9, x9, #2
// fuse scale
SCALE_L24:
cmp x9, #24
blt SCALE_L16
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x3], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x3], #64
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x3], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x3], #64
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x3], #64
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x3], #64
MUL_CONSTANT v0, v1, v2, v3, v31 // w_scale * x_scale
MUL_CONSTANT v4, v5, v6, v7, v31
MUL_CONSTANT v8, v9, v10, v11, v31
MUL_CONSTANT v12, v13, v14, v15, v31
MUL_CONSTANT v16, v17, v18, v19, v31
MUL_CONSTANT v20, v21, v22, v23, v31
sub x9, x9, #24
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x1], #64
b SCALE_L24
SCALE_L16:
cmp x9, #16
blt SCALE_L8
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x3], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x3], #64
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x3], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x3], #64
MUL_CONSTANT v0, v1, v2, v3, v31 // w_scale * x_scale
MUL_CONSTANT v4, v5, v6, v7, v31
MUL_CONSTANT v8, v9, v10, v11, v31
MUL_CONSTANT v12, v13, v14, v15, v31
sub x9, x9, #16
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64
b SCALE_L16
SCALE_L8:
cmp x9, #8
blt SCALE_L4
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x3], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x3], #64
MUL_CONSTANT v0, v1, v2, v3, v31 // w_scale * x_scale
MUL_CONSTANT v4, v5, v6, v7, v31
sub x9, x9, #8
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64
b SCALE_L8
SCALE_L4:
cmp x9, #4
blt SCALE_L1
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x3], #64
MUL_CONSTANT v0, v1, v2, v3, v31 // w_scale * x_scale
sub x9, x9, #4
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64
b SCALE_L4
SCALE_L1:
cmp x9, #1
blt BIAS_L8
ld1 {v0.4s}, [x3], #16
fmul v0.4s, v0.4s, v31.4s
sub x9, x9, #1
st1 {v0.4s}, [x1], #16
b SCALE_L1
// Bias: // Bias:
BIAS_L16: BIAS_L16:
cmp x7, #16 cmp x4, #16
blt BIAS_L8 blt BIAS_L8
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x2], #64 // oldbias ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 // oldbias
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x2], #64 ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64 ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64 // weightKernelSum ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 // weightKernelSum
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x5], #64 ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x2], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x5], #64 ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x2], #64
sub x7, x7, #16 sub x4, x4, #16
MUL_CONSTANT v16, v17, v18, v19, v30 // w_sum * x_zero MUL_CONSTANT v16, v17, v18, v19, v30 // w_sum * x_zero
MUL_CONSTANT v20, v21, v22, v23, v30 // w_sum * x_zero MUL_CONSTANT v20, v21, v22, v23, v30 // w_sum * x_zero
@ -167,7 +82,7 @@ MUL_CONSTANT v24, v25, v26, v27, v30 // w_sum * x_zero
SUB4 v0, v1, v2, v3, v16, v17, v18, v19 SUB4 v0, v1, v2, v3, v16, v17, v18, v19
SUB4 v4, v5, v6, v7, v20, v21, v22, v23 SUB4 v4, v5, v6, v7, v20, v21, v22, v23
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64 ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64
SUB4 v8, v9, v10, v11, v24, v25, v26, v27 SUB4 v8, v9, v10, v11, v24, v25, v26, v27
MUL_CONSTANT v16, v17, v18, v19, v30 // w_sum * x_zero MUL_CONSTANT v16, v17, v18, v19, v30 // w_sum * x_zero
SUB4 v12, v13, v14, v15, v16, v17, v18, v19 SUB4 v12, v13, v14, v15, v16, v17, v18, v19
@ -179,14 +94,14 @@ st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64
b BIAS_L16 b BIAS_L16
BIAS_L8: BIAS_L8:
cmp x7, #8 cmp x4, #8
blt BIAS_L4 blt BIAS_L4
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x2], #64 // oldbias ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 // oldbias
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x2], #64 ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64 // weightKernelSum ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x2], #64 // weightKernelSum
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x5], #64 ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x2], #64
sub x7, x7, #8 sub x4, x4, #8
MUL_CONSTANT v16, v17, v18, v19, v30 // w_sum * x_zero MUL_CONSTANT v16, v17, v18, v19, v30 // w_sum * x_zero
MUL_CONSTANT v20, v21, v22, v23, v30 // w_sum * x_zero MUL_CONSTANT v20, v21, v22, v23, v30 // w_sum * x_zero
@ -197,12 +112,12 @@ st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64
b BIAS_L8 b BIAS_L8
BIAS_L4: BIAS_L4:
cmp x7, #4 cmp x4, #4
blt BIAS_L1 blt BIAS_L1
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x2], #64 // oldbias ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 // oldbias
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x5], #64 // weightKernelSum ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 // weightKernelSum
sub x7, x7, #4 sub x4, x4, #4
MUL_CONSTANT v8, v9, v10, v11, v30 // w_sum * x_zero MUL_CONSTANT v8, v9, v10, v11, v30 // w_sum * x_zero
SUB4 v0, v1, v2, v3, v8, v9, v10, v11 SUB4 v0, v1, v2, v3, v8, v9, v10, v11
@ -210,11 +125,11 @@ st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64
b BIAS_L4 b BIAS_L4
BIAS_L1: BIAS_L1:
cmp x7, #1 cmp x4, #1
blt End blt End
ld1 {v0.4s}, [x2], #16 // oldbias ld1 {v0.4s}, [x1], #16 // oldbias
ld1 {v4.4s}, [x5], #16 // weightKernelSum ld1 {v4.4s}, [x2], #16 // weightKernelSum
sub x7, x7, #1 sub x4, x4, #1
fmul v4.4s, v4.4s, v30.4s // w_sum * x_zero fmul v4.4s, v4.4s, v30.4s // w_sum * x_zero
fsub v0.4s, v0.4s, v4.4s // oldbias - w_sum * x_zero fsub v0.4s, v0.4s, v4.4s // oldbias - w_sum * x_zero
st1 {v0.4s}, [x0], #16 st1 {v0.4s}, [x0], #16
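This hunk strips the scale-fusion half of MNNDynamicUpdateConvBiasScale (the old SCALE_L24..SCALE_L1 loops) and narrows the signature to five arguments; what remains is only the bias correction newbias = oldbias - weightKernelSum * inputZero, applied over ocQuad packs of 4 floats. A scalar sketch (the _ref suffix is mine):

#include <cstddef>

void MNNDynamicUpdateConvBiasScale_ref(float* newbias, const float* oldbias,
                                       const float* weightKernelSum,
                                       const float* inputZero, size_t ocQuad) {
    const float zero = inputZero[0]; // broadcast, as ld1r {v30.4s} does
    for (size_t i = 0; i < ocQuad * 4; ++i) {
        newbias[i] = oldbias[i] - weightKernelSum[i] * zero;
    }
}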

@ -93,43 +93,38 @@ struct QuanPostTreatParameters {
// x5: dst_depth_quad, x6: post, x7: realSize // x5: dst_depth_quad, x6: post, x7: realSize
//Load from post: //Load from post:
// x7: scale, x10: bias, w11: maxValue, w6: minValue, w13: UseInt8, x14: srcKernelSum, x12: weightQuantBias // x10: bias, w11: maxValue, w6: minValue, w13: UseInt8, x14: srcKernelSum
mov x8, x7
mov x15, x6
ldr x7, [x15, #0]
ldr x10, [x15, #8]
ldr w11, [x15, #16]
ldr w6, [x15, #20]
ldr w13, [x15, #24]
ldr x14, [x15, #40] // srcKernelSum
ldr x12, [x15, #48] // weightQuantBias
stp d14, d15, [sp, #(-16 * 8)]! ldr x10, [x6, #8]
ldr w11, [x6, #16]
ldr w13, [x6, #24]
ldr x14, [x6, #40] // srcKernelSum
stp d14, d15, [sp, #(-16 * 6)]!
stp d12, d13, [sp, #(16 * 1)] stp d12, d13, [sp, #(16 * 1)]
stp d10, d11, [sp, #(16 * 2)] stp d10, d11, [sp, #(16 * 2)]
stp d8, d9, [sp, #(16 * 3)] stp d8, d9, [sp, #(16 * 3)]
stp x19, x20, [sp, #(16 * 4)] stp x19, x20, [sp, #(16 * 4)]
stp x21, x22, [sp, #(16 * 5)] stp x23, x24, [sp, #(16 * 5)]
stp x23, x24, [sp, #(16 * 6)]
ldr x19, [x15, #56] // fp32 min max ldr x19, [x6, #56] // fp32 min max
ldr x23, [x15, #80] // extraScale ldr x23, [x6, #80] // extraScale
lsl x21, x3, #5 // src_depth_quad* SRC_UNIT * UNIT * sizeof(int4_t) ldr x15, [x6, #96] // accumBuffer
ldr w6, [x6, #20] // minValue
add x20, x19, #4 add x20, x19, #4
Start: Start:
cmp x8, #3 cmp x7, #3
beq L3Dz beq L3Dz
cmp x8, #2 cmp x7, #2
beq L2Dz beq L2Dz
cmp x8, #1 cmp x7, #1
beq L1Dz beq L1Dz
mov x7, x15
L4LoopDz: L4LoopDz:
mov x8, x1 mov x8, x1
mov x22, x2
ld1 {v10.16b, v11.16b}, [x2], #32 // weight ld1 {v10.16b, v11.16b}, [x2], #32 // weight
ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [x1], #64 // src ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [x1], #64 // src
// int4->int8 // int4->int8
@ -281,9 +276,9 @@ L4LoopDz:
addp v15.4s, v10.4s, v11.4s addp v15.4s, v10.4s, v11.4s
L4Quan: L4Quan:
ld1 {v1.4s}, [x7], #16 // scalefuse ld1 {v1.4s}, [x2], #16 // scalefuse
ld1 {v20.4s}, [x14] // srcKernelSum ld1 {v20.4s}, [x14] // srcKernelSum
ld1 {v21.4s}, [x12], #16 // weightQuanZero ld1 {v21.4s}, [x2], #16 // weightQuanZero
scvtf v4.4s, v12.4s scvtf v4.4s, v12.4s
scvtf v5.4s, v13.4s scvtf v5.4s, v13.4s
@ -309,14 +304,20 @@ L4LoopDz:
fadd v5.4s, v5.4s, v0.4s fadd v5.4s, v5.4s, v0.4s
fadd v6.4s, v6.4s, v0.4s fadd v6.4s, v6.4s, v0.4s
fadd v7.4s, v7.4s, v0.4s fadd v7.4s, v7.4s, v0.4s
b L4_POST cbnz x0, L4_POST
b L4_BUFFER
L4_ADD_DSTV: L4_ADD_DSTV:
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0] ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x7], #64
fadd v4.4s, v4.4s, v8.4s fadd v4.4s, v4.4s, v8.4s
fadd v5.4s, v5.4s, v9.4s fadd v5.4s, v5.4s, v9.4s
fadd v6.4s, v6.4s, v10.4s fadd v6.4s, v6.4s, v10.4s
fadd v7.4s, v7.4s, v11.4s fadd v7.4s, v7.4s, v11.4s
cbnz x0, L4_POST
L4_BUFFER:
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x15], #64
b L4LoopCheck
L4_POST: L4_POST:
cbz x19, L4_STORE cbz x19, L4_STORE
@ -330,18 +331,17 @@ L4LoopDz:
L4LoopCheck: L4LoopCheck:
subs x5, x5, #1 subs x5, x5, #1
mov x1, x8 mov x1, x8
add x2, x22, x21
bne L4LoopDz bne L4LoopDz
b End b End
L3Dz: L3Dz:
mov x7, x15
cmp w13, #1 cmp w13, #1
bne L3LoopDz bne L3LoopDz
sub x4, x4, #8 sub x4, x4, #8
L3LoopDz: L3LoopDz:
mov x8, x1 mov x8, x1
mov x22, x2
ld1 {v10.16b, v11.16b}, [x2], #32 ld1 {v10.16b, v11.16b}, [x2], #32
ld1 {v4.16b, v5.16b, v6.16b}, [x1], #48 ld1 {v4.16b, v5.16b, v6.16b}, [x1], #48
// int4->int8 // int4->int8
@ -467,10 +467,10 @@ L3LoopDz:
addp v14.4s, v8.4s, v9.4s addp v14.4s, v8.4s, v9.4s
L3Quan: L3Quan:
ld1 {v1.4s}, [x7], #16 ld1 {v1.4s}, [x2], #16
ld1 {v20.d}[0], [x14], #8 // srcKernelSum ld1 {v20.d}[0], [x14], #8 // srcKernelSum
ld1 {v20.s}[2], [x14] ld1 {v20.s}[2], [x14]
ld1 {v21.4s}, [x12], #16 // weightQuanZero ld1 {v21.4s}, [x2], #16 // weightQuanZero
scvtf v4.4s, v12.4s scvtf v4.4s, v12.4s
scvtf v5.4s, v13.4s scvtf v5.4s, v13.4s
@ -497,13 +497,19 @@ L3LoopDz:
fadd v4.4s, v4.4s, v0.4s fadd v4.4s, v4.4s, v0.4s
fadd v5.4s, v5.4s, v0.4s fadd v5.4s, v5.4s, v0.4s
fadd v6.4s, v6.4s, v0.4s fadd v6.4s, v6.4s, v0.4s
b L3_POST cbnz x0, L3_POST
b L3_BUFFER
L3_ADD_DSTV: L3_ADD_DSTV:
ld1 {v0.4s, v1.4s, v2.4s}, [x0] ld1 {v8.4s, v9.4s, v10.4s}, [x7], #48
fadd v4.4s, v4.4s, v0.4s fadd v4.4s, v4.4s, v8.4s
fadd v5.4s, v5.4s, v1.4s fadd v5.4s, v5.4s, v9.4s
fadd v6.4s, v6.4s, v2.4s fadd v6.4s, v6.4s, v10.4s
cbnz x0, L3_POST
L3_BUFFER:
st1 {v4.4s, v5.4s, v6.4s}, [x15], #48
b L3LoopCheck
L3_POST: L3_POST:
cbz x19, L3_STORE cbz x19, L3_STORE
@ -516,15 +522,14 @@ L3LoopDz:
L3LoopCheck: L3LoopCheck:
subs x5, x5, #1 subs x5, x5, #1
mov x1, x8 mov x1, x8
add x2, x22, x21
bne L3LoopDz bne L3LoopDz
b End b End
L2Dz: L2Dz:
mov x7, x15
L2LoopDz: L2LoopDz:
mov x8, x1 mov x8, x1
mov x22, x2
ld1 {v10.16b, v11.16b}, [x2], #32 ld1 {v10.16b, v11.16b}, [x2], #32
ld1 {v4.16b, v5.16b}, [x1], #32 ld1 {v4.16b, v5.16b}, [x1], #32
// int4->int8 // int4->int8
@ -617,9 +622,9 @@ L2LoopDz:
addp v13.4s, v6.4s, v7.4s addp v13.4s, v6.4s, v7.4s
L2Quan: L2Quan:
ld1 {v1.4s}, [x7], #16 ld1 {v1.4s}, [x2], #16
ld1 {v20.d}[0], [x14] // srcKernelSum ld1 {v20.d}[0], [x14] // srcKernelSum
ld1 {v21.4s}, [x12], #16 // weightQuanZero ld1 {v21.4s}, [x2], #16 // weightQuanZero
scvtf v4.4s, v12.4s scvtf v4.4s, v12.4s
scvtf v5.4s, v13.4s scvtf v5.4s, v13.4s
@ -639,12 +644,18 @@ L2LoopDz:
ld1 {v0.4s}, [x10], #16 ld1 {v0.4s}, [x10], #16
fadd v4.4s, v4.4s, v0.4s fadd v4.4s, v4.4s, v0.4s
fadd v5.4s, v5.4s, v0.4s fadd v5.4s, v5.4s, v0.4s
b L2_POST cbnz x0, L2_POST
b L2_BUFFER
L2_ADD_DSTV: L2_ADD_DSTV:
ld1 {v0.4s, v1.4s}, [x0] ld1 {v8.4s, v9.4s}, [x7], #32
fadd v4.4s, v4.4s, v0.4s fadd v4.4s, v4.4s, v8.4s
fadd v5.4s, v5.4s, v1.4s fadd v5.4s, v5.4s, v9.4s
cbnz x0, L2_POST
L2_BUFFER:
st1 {v4.4s, v5.4s}, [x15], #32
b L2LoopCheck
L2_POST: L2_POST:
cbz x19, L2_STORE cbz x19, L2_STORE
@ -658,15 +669,14 @@ L2LoopDz:
L2LoopCheck: L2LoopCheck:
subs x5, x5, #1 subs x5, x5, #1
mov x1, x8 mov x1, x8
add x2, x22, x21
bne L2LoopDz bne L2LoopDz
b End b End
L1Dz: L1Dz:
mov x7, x15
L1LoopDz: L1LoopDz:
mov x8, x1 mov x8, x1
mov x22, x2
ld1 {v10.16b, v11.16b}, [x2], #32 ld1 {v10.16b, v11.16b}, [x2], #32
// int4->int8 // int4->int8
movi v8.16b, #15 movi v8.16b, #15
@ -729,16 +739,15 @@ L1LoopDz:
sadalp v18.4s, v10.8h sadalp v18.4s, v10.8h
sadalp v19.4s, v11.8h sadalp v19.4s, v11.8h
//ld1 {v0.4s}, [x10], #16
addp v4.4s, v16.4s, v17.4s addp v4.4s, v16.4s, v17.4s
addp v5.4s, v18.4s, v19.4s addp v5.4s, v18.4s, v19.4s
addp v12.4s, v4.4s, v5.4s addp v12.4s, v4.4s, v5.4s
L1Quan: L1Quan:
ld1 {v1.4s}, [x7], #16 ld1 {v1.4s}, [x2], #16
ld1 {v20.s}[0], [x14] // srcKernelSum ld1 {v20.s}[0], [x14] // srcKernelSum
ld1 {v21.4s}, [x12], #16 // weightQuanZero ld1 {v21.4s}, [x2], #16 // weightQuanZero
scvtf v4.4s, v12.4s scvtf v4.4s, v12.4s
MUL_SCALE1 v1, v4 MUL_SCALE1 v1, v4
@ -754,11 +763,17 @@ L1LoopDz:
cbz x10, L1_ADD_DSTV cbz x10, L1_ADD_DSTV
ld1 {v0.4s}, [x10], #16 ld1 {v0.4s}, [x10], #16
fadd v4.4s, v4.4s, v0.4s fadd v4.4s, v4.4s, v0.4s
b L1_POST cbnz x0, L1_POST
b L1_BUFFER
L1_ADD_DSTV: L1_ADD_DSTV:
ld1 {v0.4s}, [x0] ld1 {v8.4s}, [x7], #16
fadd v4.4s, v4.4s, v0.4s fadd v4.4s, v4.4s, v8.4s
cbnz x0, L1_POST
L1_BUFFER:
st1 {v4.4s}, [x15], #16
b L1LoopCheck
L1_POST: L1_POST:
cbz x19, L1_STORE cbz x19, L1_STORE
@ -772,17 +787,15 @@ L1LoopDz:
L1LoopCheck: L1LoopCheck:
subs x5, x5, #1 subs x5, x5, #1
mov x1, x8 mov x1, x8
add x2, x22, x21
bne L1LoopDz bne L1LoopDz
End: End:
ldp x23, x24, [sp, #(16 * 6)] ldp x23, x24, [sp, #(16 * 5)]
ldp x21, x22, [sp, #(16 * 5)]
ldp x19, x20, [sp, #(16 * 4)] ldp x19, x20, [sp, #(16 * 4)]
ldp d8, d9, [sp, #(16 * 3)] ldp d8, d9, [sp, #(16 * 3)]
ldp d10, d11, [sp, #(16 * 2)] ldp d10, d11, [sp, #(16 * 2)]
ldp d12, d13, [sp, #(16 * 1)] ldp d12, d13, [sp, #(16 * 1)]
ldp d14, d15, [sp], #(16 * 8) ldp d14, d15, [sp], #(16 * 6)
ret ret
#endif #endif
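Both int4 GEMM kernels in this commit index the QuanPostTreatParameters block by raw byte offset. The slots they touch, with offsets copied from the ldr instructions above and below and field names assumed, are: scale at #0, bias at #8, maxValue at #16, minValue at #20, useInt8 at #24, srcKernelSum at #40, weightQuantBias at #48, fp32minmax at #56, extraScale at #80, and the accumBuffer added by this commit at #96. A C++ sketch of reading them the way the assembly does:

#include <cstdint>
#include <cstring>
#include <cstddef>

// Hypothetical helper mirroring e.g. "ldr x15, [x6, #96]"; the offsets come
// from the assembly, the field names are illustrative, not MNN's header.
template <typename T>
static T loadPostField(const void* post, size_t byteOffset) {
    T value;
    std::memcpy(&value, static_cast<const uint8_t*>(post) + byteOffset, sizeof(T));
    return value;
}

// Usage sketch:
//   auto scale        = loadPostField<const float*>(post, 0);
//   auto biasFloat    = loadPostField<const float*>(post, 8);
//   auto srcKernelSum = loadPostField<const float*>(post, 40);
//   auto fp32minmax   = loadPostField<const float*>(post, 56);
//   auto extraScale   = loadPostField<const float*>(post, 80);
//   auto accumBuffer  = loadPostField<float*>(post, 96);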

@ -110,42 +110,35 @@ struct QuanPostTreatParameters {
//Auto: x0:dst, x1:src, x2:weight, x3:src_depth_quad, x4:dst_step //Auto: x0:dst, x1:src, x2:weight, x3:src_depth_quad, x4:dst_step
//x5:dst_depth_quad, x6: parameters, x7: realDstCount //x5:dst_depth_quad, x6: parameters, x7: realDstCount
//Load from x6: x8: scale, x9: bias, x25: xKernelSum, x26: weightQuantBias, x23: fp32minmax //Load from x6: x9: bias, x8: xKernelSum, x23: fp32minmax
ldr x8, [x6, #0]
ldr x9, [x6, #8] ldr x9, [x6, #8]
stp d14, d15, [sp, #(-16 * 10)]! stp d14, d15, [sp, #(-16 * 8)]!
stp d12, d13, [sp, #(16 * 1)] stp d12, d13, [sp, #(16 * 1)]
stp d10, d11, [sp, #(16 * 2)] stp d10, d11, [sp, #(16 * 2)]
stp d8, d9, [sp, #(16 * 3)] stp d8, d9, [sp, #(16 * 3)]
stp x21, x22, [sp, #(16 * 4)] stp x21, x22, [sp, #(16 * 4)]
stp x19, x20, [sp, #(16 * 5)] stp x19, x20, [sp, #(16 * 5)]
stp x27, x28, [sp, #(16 * 6)] stp x23, x24, [sp, #(16 * 6)]
stp x25, x26, [sp, #(16 * 7)]
stp x23, x24, [sp, #(16 * 8)]
lsl x15, x3, #4 // x15 = src_depth_quad * UNIT * SRC_UNIT * sizeof(int4_t) ldr x8, [x6, #40] // srcKernelSum
ldr x25, [x6, #40] // xKernelSum
ldr x26, [x6, #48] // weightQuantBias
ldr x24, [x6, #80] // extraScale ldr x24, [x6, #80] // extraScale
ldr x15, [x6, #96] // accumBuffer
mov x10, x15
mov x21, #16 // sizeof(float) * pack mov x21, #16 // sizeof(float) * pack
ldr x23, [x6, #56] // fp32minmax ldr x23, [x6, #56] // fp32minmax
Start:
lsl x22, x7, #2 // eDest * SRC_UNIT lsl x22, x7, #2 // eDest * SRC_UNIT
TILE_12: TILE_12:
cmp x7, #12 cmp x7, #12
blt TILE_8 blt TILE_8
sub x4, x4, #128
cmp x5, #2 cmp x5, #2
blt L4LoopDz_TILE_12 blt L4LoopDz_TILE_12
L8LoopDz_TILE_12: L8LoopDz_TILE_12:
//ld1 {v0.4s, v1.4s}, [x9], #32 // bias
mov x11, x1 mov x11, x1
mov x13, x3 mov x13, x3
mov x20, x0 // tag dst address
mov x27, x2
movi v7.16b, #15 movi v7.16b, #15
SET_BIAS v8, v9, v10, v11 SET_BIAS v8, v9, v10, v11
@ -193,13 +186,12 @@ L8LoopDz_TILE_12:
bne L8LoopSz_TILE_12 bne L8LoopSz_TILE_12
L8LoopSzEnd_TILE_12: L8LoopSzEnd_TILE_12:
add x2, x27, x15
sub x5, x5, #2 sub x5, x5, #2
L8Tile12Quan: L8Tile12Quan:
ld1 {v0.4s, v1.4s}, [x8], #32 // scale ld1 {v0.4s, v1.4s}, [x2], #32 // scale
ld1 {v2.4s, v3.4s, v4.4s}, [x25] // x kernel sum ld1 {v2.4s, v3.4s, v4.4s}, [x8] // x kernel sum
ld1 {v5.4s, v6.4s}, [x26], #32 // weight quan zeropoint ld1 {v5.4s, v6.4s}, [x2], #32 // weight quan zeropoint
Int32ToFloat v8, v9, v10, v11 Int32ToFloat v8, v9, v10, v11
Int32ToFloat v12, v13, v14, v15 Int32ToFloat v12, v13, v14, v15
Int32ToFloat v16, v17, v18, v19 Int32ToFloat v16, v17, v18, v19
@ -252,8 +244,6 @@ L8LoopDz_TILE_12:
MLA_WEIGHTZERO v30, v4, v6, 2 // tile:10, oc:4-7 MLA_WEIGHTZERO v30, v4, v6, 2 // tile:10, oc:4-7
MLA_WEIGHTZERO v31, v4, v6, 3 // tile:11, oc:4-7 MLA_WEIGHTZERO v31, v4, v6, 3 // tile:11, oc:4-7
sub x4, x4, #128
cbz x9, TILE12_ADD_DSTV cbz x9, TILE12_ADD_DSTV
TILE12_ADD_BIAS: TILE12_ADD_BIAS:
ld1 {v0.4s, v1.4s}, [x9], #32 ld1 {v0.4s, v1.4s}, [x9], #32
@ -263,21 +253,32 @@ L8LoopDz_TILE_12:
ADD_BIAS_FLOAT v20, v21, v22, v23, v1 ADD_BIAS_FLOAT v20, v21, v22, v23, v1
ADD_BIAS_FLOAT v24, v25, v26, v27, v1 ADD_BIAS_FLOAT v24, v25, v26, v27, v1
ADD_BIAS_FLOAT v28, v29, v30, v31, v1 ADD_BIAS_FLOAT v28, v29, v30, v31, v1
b TILE12_POST cbnz x0, TILE12_POST
b TILE12_L8_ACCUM_BUFFER
TILE12_ADD_DSTV: TILE12_ADD_DSTV:
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x20], #64 ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20], #64 ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64
ADD_FLOAT v8, v9, v10, v11, v0, v1, v2, v3 ADD_FLOAT v8, v9, v10, v11, v0, v1, v2, v3
ADD_FLOAT v12, v13, v14, v15, v4, v5, v6, v7 ADD_FLOAT v12, v13, v14, v15, v4, v5, v6, v7
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x20], x4 ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20], #64 ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64
ADD_FLOAT v16, v17, v18, v19, v0, v1, v2, v3 ADD_FLOAT v16, v17, v18, v19, v0, v1, v2, v3
ADD_FLOAT v20, v21, v22, v23, v4, v5, v6, v7 ADD_FLOAT v20, v21, v22, v23, v4, v5, v6, v7
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x20], #64 ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20] ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64
ADD_FLOAT v24, v25, v26, v27, v0, v1, v2, v3 ADD_FLOAT v24, v25, v26, v27, v0, v1, v2, v3
ADD_FLOAT v28, v29, v30, v31, v4, v5, v6, v7 ADD_FLOAT v28, v29, v30, v31, v4, v5, v6, v7
cbnz x0, TILE12_POST
TILE12_L8_ACCUM_BUFFER:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x15], #64
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x15], #64
st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x15], #64
st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x15], #64
b L8Tile12LoopCheck
TILE12_POST: TILE12_POST:
cbz x23, TILE12_STORE cbz x23, TILE12_STORE
@ -298,7 +299,6 @@ L8LoopDz_TILE_12:
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64
st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64
st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0], x4 st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0], x4
add x4, x4, #128
L8Tile12LoopCheck: L8Tile12LoopCheck:
cmp x5, #1 cmp x5, #1
@ -335,9 +335,9 @@ L4LoopDz_TILE_12:
L4LoopSzEnd_TILE_12: L4LoopSzEnd_TILE_12:
L4Tile12Quan: L4Tile12Quan:
ld1 {v0.4s}, [x8] // scale ld1 {v0.4s}, [x2], #16 // scale
ld1 {v2.4s, v3.4s, v4.4s}, [x25]// x kernel sum ld1 {v2.4s, v3.4s, v4.4s}, [x8]// x kernel sum
ld1 {v5.4s}, [x26], #16 // weight quan zeropoint ld1 {v5.4s}, [x2], #16 // weight quan zeropoint
Int32ToFloat v8, v9, v10, v11 Int32ToFloat v8, v9, v10, v11
Int32ToFloat v12, v13, v14, v15 Int32ToFloat v12, v13, v14, v15
Int32ToFloat v16, v17, v18, v19 Int32ToFloat v16, v17, v18, v19
@ -367,24 +367,29 @@ L4LoopDz_TILE_12:
MLA_WEIGHTZERO v18, v4, v5, 2 // tile:10, oc:0-3 MLA_WEIGHTZERO v18, v4, v5, 2 // tile:10, oc:0-3
MLA_WEIGHTZERO v19, v4, v5, 3 // tile:11, oc:0-3 MLA_WEIGHTZERO v19, v4, v5, 3 // tile:11, oc:0-3
sub x4, x4, #128
TILE12_L4_ADD_BIAS: TILE12_L4_ADD_BIAS:
cbz x9, TILE12_L4_ADD_DSTV cbz x9, TILE12_L4_ADD_DSTV
ld1 {v0.4s}, [x9] // bias ld1 {v0.4s}, [x9] // bias
ADD_BIAS_FLOAT v8, v9, v10, v11, v0 ADD_BIAS_FLOAT v8, v9, v10, v11, v0
ADD_BIAS_FLOAT v12, v13, v14, v15, v0 ADD_BIAS_FLOAT v12, v13, v14, v15, v0
ADD_BIAS_FLOAT v16, v17, v18, v19, v0 ADD_BIAS_FLOAT v16, v17, v18, v19, v0
b TILE12_L4_POST cbnz x0, TILE12_L4_POST
b TILE12_L4_ACCUM_BUFFER
TILE12_L4_ADD_DSTV: TILE12_L4_ADD_DSTV:
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64
ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0] ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10], #64
sub x0, x0, #128
ADD_FLOAT v8, v9, v10, v11, v20, v21, v22, v23 ADD_FLOAT v8, v9, v10, v11, v20, v21, v22, v23
ADD_FLOAT v12, v13, v14, v15, v24, v25, v26, v27 ADD_FLOAT v12, v13, v14, v15, v24, v25, v26, v27
ADD_FLOAT v16, v17, v18, v19, v28, v29, v30, v31 ADD_FLOAT v16, v17, v18, v19, v28, v29, v30, v31
cbnz x0, TILE12_L4_POST
TILE12_L4_ACCUM_BUFFER:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x15], #64
b End
TILE12_L4_POST: TILE12_L4_POST:
cbz x23, TILE12_L4_STORE cbz x23, TILE12_L4_STORE
@ -398,25 +403,21 @@ L4LoopDz_TILE_12:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64 st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64 st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], x4 st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], x4
add x4, x4, #128
b End b End
TILE_8: TILE_8:
cmp x7, #8 cmp x7, #8
blt TILE_4 blt TILE_4
mov x10, x0 sub x19, x4, #64
mov x6, x0
mov x12, x2 mov x12, x2
mov x14, x5 mov x14, x5
mov x19, x8 // scale
mov x20, x9 // bias mov x20, x9 // bias
mov x6, x26 // weightQuantBias
cmp x5, #2 cmp x5, #2
blt L4LoopDz_TILE_8 blt L4LoopDz_TILE_8
L8LoopDz_TILE_8: L8LoopDz_TILE_8:
//ld1 {v0.4s, v1.4s}, [x20], #32 // bias
mov x11, x1 mov x11, x1
mov x13, x3 mov x13, x3
mov x27, x12
movi v7.16b, #15 movi v7.16b, #15
SET_BIAS v8, v9, v10, v11 SET_BIAS v8, v9, v10, v11
@ -453,14 +454,12 @@ L8LoopDz_TILE_8:
bne L8LoopSz_TILE_8 bne L8LoopSz_TILE_8
L8LoopSzEnd_TILE_8: L8LoopSzEnd_TILE_8:
//add x12, x12, x15 sub x14, x14, #2
add x12, x27, x15
sub x14, x14, #2
L8Tile8Quan: L8Tile8Quan:
ld1 {v0.4s, v1.4s}, [x19], #32 // scale ld1 {v0.4s, v1.4s}, [x12], #32 // scale
ld1 {v2.4s, v3.4s}, [x25] // x kernel sum ld1 {v2.4s, v3.4s}, [x8] // x kernel sum
ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint ld1 {v24.4s, v25.4s}, [x12], #32 // weight quan zeropoint
Int32ToFloat v8, v9, v10, v11 Int32ToFloat v8, v9, v10, v11
Int32ToFloat v12, v13, v14, v15 Int32ToFloat v12, v13, v14, v15
Int32ToFloat v16, v17, v18, v19 Int32ToFloat v16, v17, v18, v19
@ -495,8 +494,6 @@ L8LoopDz_TILE_8:
MLA_WEIGHTZERO v22, v3, v25, 2 // tile:6, oc:4-7 MLA_WEIGHTZERO v22, v3, v25, 2 // tile:6, oc:4-7
MLA_WEIGHTZERO v23, v3, v25, 3 // tile:7, oc:4-7 MLA_WEIGHTZERO v23, v3, v25, 3 // tile:7, oc:4-7
sub x4, x4, #64
cbz x9, TILE8_ADD_DSTV cbz x9, TILE8_ADD_DSTV
TILE8_ADD_BIAS: TILE8_ADD_BIAS:
ld1 {v0.4s, v1.4s}, [x20], #32 ld1 {v0.4s, v1.4s}, [x20], #32
@ -504,19 +501,27 @@ L8LoopDz_TILE_8:
ADD_BIAS_FLOAT v12, v13, v14, v15, v0 ADD_BIAS_FLOAT v12, v13, v14, v15, v0
ADD_BIAS_FLOAT v16, v17, v18, v19, v1 ADD_BIAS_FLOAT v16, v17, v18, v19, v1
ADD_BIAS_FLOAT v20, v21, v22, v23, v1 ADD_BIAS_FLOAT v20, v21, v22, v23, v1
b TILE8_POST cbnz x0, TILE8_POST
b TILE8_L8_ACCUM_BUFFER
TILE8_ADD_DSTV: TILE8_ADD_DSTV:
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64 ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], x4 ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64 ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64
ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10] ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10], #64
ADD_FLOAT v8, v9, v10, v11, v0, v1, v2, v3 ADD_FLOAT v8, v9, v10, v11, v0, v1, v2, v3
ADD_FLOAT v12, v13, v14, v15, v4, v5, v6, v7 ADD_FLOAT v12, v13, v14, v15, v4, v5, v6, v7
ADD_FLOAT v16, v17, v18, v19, v24, v25, v26, v27 ADD_FLOAT v16, v17, v18, v19, v24, v25, v26, v27
ADD_FLOAT v20, v21, v22, v23, v28, v29, v30, v31 ADD_FLOAT v20, v21, v22, v23, v28, v29, v30, v31
sub x10, x10, #128 cbnz x0, TILE8_POST
sub x10, x10, x4
TILE8_L8_ACCUM_BUFFER:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x15], #64
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x15], #64
b L8Tile8LoopCheck
TILE8_POST: TILE8_POST:
cbz x23, TILE8_STORE cbz x23, TILE8_STORE
@ -529,11 +534,10 @@ L8LoopDz_TILE_8:
sub x23, x23, #4 sub x23, x23, #4
TILE8_STORE: TILE8_STORE:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], #64 st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x6], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10], x4 st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], x19
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64 st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x6], #64
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], x4 st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x6], x19
add x4, x4, #64
L8Tile8LoopCheck: L8Tile8LoopCheck:
cmp x14, #1 cmp x14, #1
@ -541,7 +545,6 @@ L8LoopDz_TILE_8:
cbz x14, Tile8End cbz x14, Tile8End
L4LoopDz_TILE_8: L4LoopDz_TILE_8:
//ld1 {v0.4s}, [x20], #16 // bias
mov x11, x1 mov x11, x1
mov x13, x3 mov x13, x3
movi v7.16b, #15 movi v7.16b, #15
@ -569,9 +572,9 @@ L4LoopDz_TILE_8:
L4LoopSzEnd_TILE_8: L4LoopSzEnd_TILE_8:
L4Tile8Quan: L4Tile8Quan:
ld1 {v0.4s}, [x19], #16 // scale ld1 {v0.4s}, [x12], #16 // scale
ld1 {v2.4s, v3.4s}, [x25] // x kernel sum ld1 {v2.4s, v3.4s}, [x8] // x kernel sum
ld1 {v24.4s}, [x6], #16 // weight quan zeropoint ld1 {v24.4s}, [x12], #16 // weight quan zeropoint
Int32ToFloat v8, v9, v10, v11 Int32ToFloat v8, v9, v10, v11
Int32ToFloat v12, v13, v14, v15 Int32ToFloat v12, v13, v14, v15
MUL_SCALE v0, v8, v9, v10, v11 MUL_SCALE v0, v8, v9, v10, v11
@ -592,21 +595,31 @@ L4LoopDz_TILE_8:
MLA_WEIGHTZERO v14, v3, v24, 2 // tile:6, oc:0-3 MLA_WEIGHTZERO v14, v3, v24, 2 // tile:6, oc:0-3
MLA_WEIGHTZERO v15, v3, v24, 3 // tile:7, oc:0-3 MLA_WEIGHTZERO v15, v3, v24, 3 // tile:7, oc:0-3
sub x4, x4, #64
cbz x9, TILE8_L4_ADD_DSTV cbz x9, TILE8_L4_ADD_DSTV
TILE8_L4_ADD_BIAS: TILE8_L4_ADD_BIAS:
ld1 {v4.4s}, [x20], #16 ld1 {v4.4s}, [x20], #16
ADD_BIAS_FLOAT v8, v9, v10, v11, v4 ADD_BIAS_FLOAT v8, v9, v10, v11, v4
ADD_BIAS_FLOAT v12, v13, v14, v15, v4 ADD_BIAS_FLOAT v12, v13, v14, v15, v4
b TILE8_L4_POST cbnz x0, TILE8_L4_POST
b TILE8_L4_ACCUM_BUFFER
TILE8_L4_ADD_DSTV: TILE8_L4_ADD_DSTV:
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64 ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10] ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64
sub x10, x10, #64 ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10], #64
ADD_FLOAT v8, v9, v10, v11, v4, v5, v6, v7 ADD_FLOAT v8, v9, v10, v11, v0, v1, v2, v3
ADD_FLOAT v12, v13, v14, v15, v16, v17, v18, v19 ADD_FLOAT v12, v13, v14, v15, v4, v5, v6, v7
ADD_FLOAT v16, v17, v18, v19, v24, v25, v26, v27
ADD_FLOAT v20, v21, v22, v23, v28, v29, v30, v31
cbnz x0, TILE8_L4_POST
TILE8_L4_ACCUM_BUFFER:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x15], #64
st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x15], #64
b Tile8End
TILE8_L4_POST: TILE8_L4_POST:
cbz x23, TILE8_L4_STORE cbz x23, TILE8_L4_STORE
@ -617,36 +630,30 @@ L4LoopDz_TILE_8:
sub x23, x23, #4 sub x23, x23, #4
TILE8_L4_STORE: TILE8_L4_STORE:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], #64 st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x6], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10], x4 st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], x19
add x4, x4, #64
Tile8End: Tile8End:
cbz x24, Tile8_End_Offset cbz x0, Tile8_End_Offset
add x24, x24, #32 add x0, x0, x21, LSL #3
Tile8_End_Offset: Tile8_End_Offset:
sub x7, x7, #8 sub x7, x7, #8
add x0, x0, x21, LSL #3
add x1, x1, #32 add x1, x1, #32
add x25, x25, #32 add x8, x8, #32
add x24, x24, #32
TILE_4: TILE_4:
cmp x7, #4 cmp x7, #4
blt TILE_1_Init blt TILE_1_Init
mov x10, x0 mov x6, x0
mov x12, x2 mov x12, x2
mov x14, x5 mov x14, x5
mov x19, x8
mov x20, x9 mov x20, x9
mov x6, x26 // weightQuantBias
cmp x5, #2 cmp x5, #2
blt L4LoopDz_TILE_4 blt L4LoopDz_TILE_4
L8LoopDz_TILE_4: L8LoopDz_TILE_4:
//ld1 {v0.4s, v1.4s}, [x20], #32 // bias
mov x11, x1 mov x11, x1
mov x13, x3 mov x13, x3
mov x27, x12
movi v7.16b, #15 movi v7.16b, #15
SET_BIAS v8, v9, v10, v11 SET_BIAS v8, v9, v10, v11
@ -672,14 +679,12 @@ L8LoopDz_TILE_4:
bne L8LoopSz_TILE_4 bne L8LoopSz_TILE_4
L8LoopSzEnd_TILE_4: L8LoopSzEnd_TILE_4:
//add x12, x12, x15
add x12, x27, x15
sub x14, x14, #2 sub x14, x14, #2
L8Tile4Quan: L8Tile4Quan:
ld1 {v0.4s, v1.4s}, [x19], #32 // scale ld1 {v0.4s, v1.4s}, [x12], #32 // scale
ld1 {v2.4s}, [x25] // x kernel sum ld1 {v2.4s}, [x8] // x kernel sum
ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint ld1 {v24.4s, v25.4s}, [x12], #32 // weight quan zeropoint
Int32ToFloat v8, v9, v10, v11 Int32ToFloat v8, v9, v10, v11
Int32ToFloat v12, v13, v14, v15 Int32ToFloat v12, v13, v14, v15
MUL_SCALE v0, v8, v9, v10, v11 MUL_SCALE v0, v8, v9, v10, v11
@ -705,14 +710,20 @@ L8LoopDz_TILE_4:
ld1 {v4.4s, v5.4s}, [x20], #32 ld1 {v4.4s, v5.4s}, [x20], #32
ADD_BIAS_FLOAT v8, v9, v10, v11, v4 ADD_BIAS_FLOAT v8, v9, v10, v11, v4
ADD_BIAS_FLOAT v12, v13, v14, v15, v5 ADD_BIAS_FLOAT v12, v13, v14, v15, v5
b TILE4_POST cbnz x0, TILE4_POST
b TILE4_L8_ACCUM_BUFFER
TILE4_ADD_DSTV: TILE4_ADD_DSTV:
ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], x4 ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10] ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64
sub x10, x10, x4 ADD_FLOAT v8, v9, v10, v11, v16, v17, v18, v19
ADD_FLOAT v8, v9, v10, v11, v4, v5, v6, v7 ADD_FLOAT v12, v13, v14, v15, v20, v21, v22, v23
ADD_FLOAT v12, v13, v14, v15, v16, v17, v18, v19 cbnz x0, TILE4_POST
TILE4_L8_ACCUM_BUFFER:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
b L8Tile4LoopCheck
TILE4_POST: TILE4_POST:
cbz x23, TILE4_STORE cbz x23, TILE4_STORE
@ -723,8 +734,8 @@ L8LoopDz_TILE_4:
sub x23, x23, #4 sub x23, x23, #4
TILE4_STORE: TILE4_STORE:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], x4 st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x6], x4
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10], x4 st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], x4
L8Tile4LoopCheck: L8Tile4LoopCheck:
cmp x14, #1 cmp x14, #1
@ -732,7 +743,6 @@ L8LoopDz_TILE_4:
cbz x14, Tile4End cbz x14, Tile4End
L4LoopDz_TILE_4: L4LoopDz_TILE_4:
//ld1 {v0.4s}, [x20], #16 // bias
mov x11, x1 mov x11, x1
mov x13, x3 mov x13, x3
movi v7.16b, #15 movi v7.16b, #15
@ -753,9 +763,9 @@ L4LoopDz_TILE_4:
L4LoopSzEnd_TILE_4: L4LoopSzEnd_TILE_4:
L4Tile4Quan: L4Tile4Quan:
ld1 {v0.4s}, [x19], #16 // scale ld1 {v0.4s}, [x12], #16 // scale
ld1 {v2.4s}, [x25] // x kernel sum ld1 {v2.4s}, [x8] // x kernel sum
ld1 {v24.4s}, [x6], #16 // weight quan zeropoint ld1 {v24.4s}, [x12], #16 // weight quan zeropoint
Int32ToFloat v8, v9, v10, v11 Int32ToFloat v8, v9, v10, v11
MUL_SCALE v0, v8, v9, v10, v11 MUL_SCALE v0, v8, v9, v10, v11
@ -773,11 +783,17 @@ L4LoopDz_TILE_4:
TILE4_L4_ADD_BIAS: TILE4_L4_ADD_BIAS:
ld1 {v3.4s}, [x20], #16 ld1 {v3.4s}, [x20], #16
ADD_BIAS_FLOAT v8, v9, v10, v11, v3 ADD_BIAS_FLOAT v8, v9, v10, v11, v3
b TILE4_L4_POST cbnz x0, TILE4_L4_POST
b TILE4_L4_ACCUM_BUFFER
TILE4_L4_ADD_DSTV: TILE4_L4_ADD_DSTV:
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10] ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64
ADD_FLOAT v8, v9, v10, v11, v12, v13, v14, v15 ADD_FLOAT v8, v9, v10, v11, v16, v17, v18, v19
cbnz x0, TILE4_L4_POST
TILE4_L4_ACCUM_BUFFER:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
b Tile4End
TILE4_L4_POST: TILE4_L4_POST:
cbz x23, TILE4_L4_STORE cbz x23, TILE4_L4_STORE
@ -787,17 +803,16 @@ L4LoopDz_TILE_4:
sub x23, x23, #4 sub x23, x23, #4
TILE4_L4_STORE: TILE4_L4_STORE:
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], x4 st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x6], x4
Tile4End: Tile4End:
cbz x24, Tile4_End_Offset cbz x0, Tile4_End_Offset
add x24, x24, #16 add x0, x0, x21, LSL #2
Tile4_End_Offset: Tile4_End_Offset:
sub x7, x7, #4 sub x7, x7, #4
add x0, x0, x21, LSL #2
add x1, x1, #16 add x1, x1, #16
add x25, x25, #16 add x8, x8, #16
add x24, x24, #16
TILE_1_Init: TILE_1_Init:
cbz x7, End cbz x7, End
@ -807,25 +822,19 @@ TILE_1_Init:
ld1r {v27.4s}, [x23] // f32 max ld1r {v27.4s}, [x23] // f32 max
sub x23, x23, #4 sub x23, x23, #4
TILE_1: TILE_1:
mov x10, x0 mov x6, x0
mov x12, x2 mov x12, x2
mov x14, x5 mov x14, x5
mov x19, x8
mov x20, x9 mov x20, x9
mov x6, x26 // weightQuantBias
cmp x5, #2 cmp x5, #2
blt L4LoopDz_TILE_1 blt L4LoopDz_TILE_1
L8LoopDz_TILE_1: L8LoopDz_TILE_1:
//ld1 {v0.4s, v1.4s}, [x20], #32 // bias
mov x11, x1 mov x11, x1
mov x13, x3 mov x13, x3
mov x27, x12
movi v8.16b, #0 movi v8.16b, #0
movi v9.16b, #0 movi v9.16b, #0
//cmp x13, #4
b L8LoopSz_TILE_1_lu1 b L8LoopSz_TILE_1_lu1
//lsl x22, x22, #2
L8LoopSz_TILE_1_lu4: L8LoopSz_TILE_1_lu4:
ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x12], #64 // weight: hu=0,1,2,3,pack=0~7 ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x12], #64 // weight: hu=0,1,2,3,pack=0~7
@ -864,29 +873,21 @@ L8LoopDz_TILE_1:
L8LoopSz_TILE_1_lu1: L8LoopSz_TILE_1_lu1:
ld1 {v5.16b}, [x12], #16 // weight ld1 {v5.16b}, [x12], #16 // weight
ld1 {v0.s}[0], [x11], x22 // src ld1 {v0.s}[0], [x11], x22 // src
//ld1 {v4.d}[0], [x12], #8 // weight
subs x13, x13, #1 subs x13, x13, #1
// int4->int8 // int4->int8
ushr v3.16b, v5.16b, #4 ushr v3.16b, v5.16b, #4
and v12.16b, v5.16b, v7.16b and v12.16b, v5.16b, v7.16b
//ushr v10.16b, v4.16b, #4
//and v11.16b, v4.16b, v7.16b
//zip1 v12.16b, v10.16b, v11.16b
//sub x12, x12, x15
.inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0]
.inst 0x4f80e189 // sdot v9.4s, v12.16b, v0.4b[0] .inst 0x4f80e189 // sdot v9.4s, v12.16b, v0.4b[0]
bne L8LoopSz_TILE_1_lu1 bne L8LoopSz_TILE_1_lu1
L8LoopSzEnd_TILE_1: L8LoopSzEnd_TILE_1:
add x12, x27, x15
sub x14, x14, #2 sub x14, x14, #2
L8Tile1Quan: L8Tile1Quan:
ld1 {v0.4s, v1.4s}, [x19], #32 // scale ld1 {v0.4s, v1.4s}, [x12], #32 // scale
ld1 {v2.s}[0], [x25] // x kernel sum ld1 {v2.s}[0], [x8] // x kernel sum
ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint ld1 {v24.4s, v25.4s}, [x12], #32 // weight quan zeropoint
scvtf v8.4s, v8.4s scvtf v8.4s, v8.4s
scvtf v9.4s, v9.4s scvtf v9.4s, v9.4s
fmul v8.4s, v8.4s, v0.4s fmul v8.4s, v8.4s, v0.4s
@ -906,14 +907,18 @@ L8LoopDz_TILE_1:
ld1 {v10.4s, v11.4s}, [x20], #32 ld1 {v10.4s, v11.4s}, [x20], #32
fadd v8.4s, v8.4s, v10.4s fadd v8.4s, v8.4s, v10.4s
fadd v9.4s, v9.4s, v11.4s fadd v9.4s, v9.4s, v11.4s
b TILE1_POST cbnz x0, TILE1_POST
b TILE1_L8_ACCUM_BUFFER
TILE1_ADD_DSTV: TILE1_ADD_DSTV:
ld1 {v10.4s}, [x10], x4 ld1 {v10.4s, v11.4s}, [x10], #32
ld1 {v11.4s}, [x10]
sub x10, x10, x4
fadd v8.4s, v8.4s, v10.4s fadd v8.4s, v8.4s, v10.4s
fadd v9.4s, v9.4s, v11.4s fadd v9.4s, v9.4s, v11.4s
cbnz x0, TILE1_POST
TILE1_L8_ACCUM_BUFFER:
st1 {v8.4s, v9.4s}, [x15], #32
b L8Tile1LoopCheck
TILE1_POST: TILE1_POST:
cbz x23, TILE1_STORE cbz x23, TILE1_STORE
@ -923,8 +928,8 @@ L8LoopDz_TILE_1:
fmax v9.4s, v9.4s, v26.4s fmax v9.4s, v9.4s, v26.4s
TILE1_STORE: TILE1_STORE:
st1 {v8.4s}, [x10], x4 st1 {v8.4s}, [x6], x4
st1 {v9.4s}, [x10], x4 st1 {v9.4s}, [x6], x4
L8Tile1LoopCheck: L8Tile1LoopCheck:
cmp x14, #1 cmp x14, #1
@ -932,7 +937,6 @@ L8LoopDz_TILE_1:
cbz x14, Tile1End cbz x14, Tile1End
L4LoopDz_TILE_1: L4LoopDz_TILE_1:
//ld1 {v0.4s}, [x20], #16 // bias
mov x11, x1 mov x11, x1
mov x13, x3 mov x13, x3
movi v8.16b, #0 movi v8.16b, #0
@ -949,9 +953,9 @@ L4LoopDz_TILE_1:
L4LoopSzEnd_TILE_1: L4LoopSzEnd_TILE_1:
L4Tile1Quan: L4Tile1Quan:
ld1 {v0.4s}, [x19], #16 // scale ld1 {v0.4s}, [x12], #16 // scale
ld1 {v2.s}[0], [x25] // x kernel sum ld1 {v2.s}[0], [x8] // x kernel sum
ld1 {v24.4s}, [x6], #16 // weight quan zeropoint ld1 {v24.4s}, [x12], #16 // weight quan zeropoint
scvtf v8.4s, v8.4s scvtf v8.4s, v8.4s
fmul v8.4s, v8.4s, v0.4s fmul v8.4s, v8.4s, v0.4s
@ -966,40 +970,43 @@ L4LoopDz_TILE_1:
TILE1_L4_ADD_BIAS: TILE1_L4_ADD_BIAS:
ld1 {v4.4s}, [x20], #16 ld1 {v4.4s}, [x20], #16
fadd v8.4s, v8.4s, v4.4s fadd v8.4s, v8.4s, v4.4s
b TILE1_L4_POST cbnz x0, TILE1_L4_POST
b TILE1_L4_ACCUM_BUFFER
TILE1_L4_ADD_DSTV: TILE1_L4_ADD_DSTV:
ld1 {v4.4s}, [x10] ld1 {v10.4s}, [x10], #16
fadd v8.4s, v8.4s, v4.4s fadd v8.4s, v8.4s, v10.4s
cbnz x0, TILE1_L4_POST
TILE1_L4_ACCUM_BUFFER:
st1 {v8.4s}, [x15], #16
b Tile1End
TILE1_L4_POST: TILE1_L4_POST:
cbz x23, TILE1_L4_STORE cbz x23, TILE1_L4_STORE
fmax v8.4s, v8.4s, v26.4s fmax v8.4s, v8.4s, v26.4s
fmin v8.4s, v8.4s, v27.4s fmin v8.4s, v8.4s, v27.4s
TILE1_L4_STORE: TILE1_L4_STORE:
st1 {v8.4s}, [x10], x4 st1 {v8.4s}, [x6], x4
Tile1End: Tile1End:
cbz x24, Tile1_End_Offset cbz x0, Tile1_End_Offset
add x24, x24, #4
Tile1_End_Offset:
subs x7, x7, #1
add x0, x0, x21 add x0, x0, x21
Tile1_End_Offset:
add x24, x24, #4
subs x7, x7, #1
add x1, x1, #4 add x1, x1, #4
add x25, x25, #4 add x8, x8, #4
bne TILE_1 bne TILE_1
End: End:
ldp x23, x24, [sp, #(16 * 8)] ldp x23, x24, [sp, #(16 * 6)]
ldp x25, x26, [sp, #(16 * 7)]
ldp x27, x28, [sp, #(16 * 6)]
ldp x19, x20, [sp, #(16 * 5)] ldp x19, x20, [sp, #(16 * 5)]
ldp x21, x22, [sp, #(16 * 4)] ldp x21, x22, [sp, #(16 * 4)]
ldp d8, d9, [sp, #(16 * 3)] ldp d8, d9, [sp, #(16 * 3)]
ldp d10, d11, [sp, #(16 * 2)] ldp d10, d11, [sp, #(16 * 2)]
ldp d12, d13, [sp, #(16 * 1)] ldp d12, d13, [sp, #(16 * 1)]
ldp d14, d15, [sp], #(16 * 10) ldp d14, d15, [sp], #(16 * 8)
ret ret
#endif // __aarch64__ #endif // __aarch64__
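Taken together, the edits in this kernel replace the old read-modify-write on the destination with an explicit accumulation buffer: x15 now carries the accumBuffer pointer, partial results are staged there between dz blocks, and a non-null dst in x0 marks the final block that adds bias, clamps, and stores. This is also why x25-x28 no longer need saving and the frame shrinks from 16*10 to 16*8 bytes. Below is a compilable C++ sketch of the per-tile epilogue implied by the cbz x9 / cbnz x0 branches; the function and pointer names are illustrative, only the branch structure comes from the assembly.

#include <algorithm>

// Sketch only: assumed epilogue for one pack-of-4 result vector.
// bias ~ x9, dst ~ x0, accumSrc ~ x10, accumDst ~ x15, minmax clamp ~ x23.
static void tileEpilogue(float acc[4], const float* bias, float* dst,
                         const float* accumSrc, float* accumDst,
                         const float* minmax) {
    for (int i = 0; i < 4; ++i) {
        // TILE*_ADD_BIAS on the last block; otherwise TILE*_ADD_DSTV folds
        // in the partial sums staged by the previous blocks.
        acc[i] += (bias != nullptr) ? bias[i] : accumSrc[i];
    }
    if (dst == nullptr) {                     // cbnz x0, TILE*_POST not taken
        std::copy(acc, acc + 4, accumDst);    // TILE*_ACCUM_BUFFER
        return;
    }
    if (minmax != nullptr) {                  // TILE*_POST: optional clamp
        for (int i = 0; i < 4; ++i) {
            acc[i] = std::min(std::max(acc[i], minmax[0]), minmax[1]);
        }
    }
    std::copy(acc, acc + 4, dst);             // TILE*_STORE
}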

View File

@ -100,27 +100,24 @@ struct QuanPostTreatParameters {
//Auto: x0:dst, x1:src, x2:weight, x3:src_depth_quad, x4:dst_step //Auto: x0:dst, x1:src, x2:weight, x3:src_depth_quad, x4:dst_step
//x5:dst_depth_quad, x6: parameters, x7: realDstCount //x5:dst_depth_quad, x6: parameters, x7: realDstCount
//Load from x6: x8: scale, x9: bias, x27: srcKernelSum, x28: weightQuanBias, //Load from x6: x9: bias, x8: srcKernelSum
ldr x8, [x6, #0]
ldr x9, [x6, #8] ldr x9, [x6, #8]
stp d14, d15, [sp, #(-16 * 10)]! stp d14, d15, [sp, #(-16 * 8)]!
stp d12, d13, [sp, #(16 * 1)] stp d12, d13, [sp, #(16 * 1)]
stp d10, d11, [sp, #(16 * 2)] stp d10, d11, [sp, #(16 * 2)]
stp d8, d9, [sp, #(16 * 3)] stp d8, d9, [sp, #(16 * 3)]
stp x21, x22, [sp, #(16 * 4)] stp x21, x22, [sp, #(16 * 4)]
stp x19, x20, [sp, #(16 * 5)] stp x19, x20, [sp, #(16 * 5)]
stp x23, x24, [sp, #(16 * 6)] stp x23, x24, [sp, #(16 * 6)]
stp x25, x26, [sp, #(16 * 7)] ldr x8, [x6, #40] // srcKernelSum
stp x27, x28, [sp, #(16 * 8)]
ldr x27, [x6, #40] // srcKernelSum
ldr x28, [x6, #48] // weightQuanBias
lsl x15, x3, #5 // x15 = src_depth_quad * UNIT * UNIT_SRC = src_depth_quad * 64 * (sizeof(int4)) = src_depth_quad << 5
mov x21, #16 // sizeof(float) * pack mov x21, #16 // sizeof(float) * pack
ldr x14, [x6, #56] // float32 maxmin ptr ldr x14, [x6, #56] // float32 maxmin ptr
ldr x23, [x6, #80] // extra scale ldr x23, [x6, #80] // extra scale
ldr x15, [x6, #96] // accumBuffer
mov x10, x15
mov x19, x23
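For orientation, the loads above pin down which fields of QuanPostTreatParameters (named in the hunk header) this kernel still reads. A hypothetical mirror of that layout follows; the offsets come from the ldr instructions, while the field names are guesses.

#include <cstddef>
#include <cstdint>

// Hypothetical mirror of the block passed in x6; only offsets are certain.
struct QuanPostParamsView {
    void*        oldScale;          // +0   old `ldr x8, [x6, #0]`, no longer read
    const float* bias;              // +8   -> x9
    uint8_t      pad0[24];          // +16..39
    const float* srcKernelSum;      // +40  -> x8
    void*        oldWeightQuanBias; // +48  old `ldr x28, [x6, #48]`, removed
    const float* fp32minmax;        // +56  -> x14
    uint8_t      pad1[16];          // +64..79
    const float* extraScale;        // +80  -> x23 (saved to x19)
    uint8_t      pad2[8];           // +88..95
    float*       accumBuffer;       // +96  -> x15 (x10 serves as read cursor)
};
static_assert(offsetof(QuanPostParamsView, accumBuffer) == 96, "offset check");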
Start: Start:
lsl x22, x7, #3// eDest * GEMM_INT8_SRC_UNIT lsl x22, x7, #3// eDest * GEMM_INT8_SRC_UNIT
@ -134,7 +131,6 @@ blt LoopDz4_TILE_10
LoopDz8_TILE_10: LoopDz8_TILE_10:
mov x11, x1 // src mov x11, x1 // src
mov x12, x2 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
movi v2.16b, #15 movi v2.16b, #15
@ -144,7 +140,7 @@ LoopDz8_TILE_10:
SET_0_5 v15, v19, v23, v27, v31 // oc:6,7,6,7 SET_0_5 v15, v19, v23, v27, v31 // oc:6,7,6,7
LoopSz_TILE_10: LoopSz_TILE_10:
ld1 {v0.16b, v1.16b}, [x12], #32 // weight ld1 {v0.16b, v1.16b}, [x2], #32 // weight
ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], #64 // src: E0-E9 ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], #64 // src: E0-E9
ld1 {v7.16b}, [x11], #16 ld1 {v7.16b}, [x11], #16
// int4->int8 // int4->int8
@ -181,7 +177,6 @@ LoopSz_TILE_10:
.inst 0x4e8ba4ff // smmla v31.4s, v7.16b, v11.16b // tile8-oc6, tile8-oc7, tile9-oc6, tile9-oc7 .inst 0x4e8ba4ff // smmla v31.4s, v7.16b, v11.16b // tile8-oc6, tile8-oc7, tile9-oc6, tile9-oc7
bne LoopSz_TILE_10 bne LoopSz_TILE_10
LoopSzEnd_TILE_10: LoopSzEnd_TILE_10:
add x2, x2, x15 // weight += dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT * 0.5);
sub x5, x5, #2 // dz-2 sub x5, x5, #2 // dz-2
// transpose // transpose
uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3
@ -213,11 +208,11 @@ LoopSzEnd_TILE_10:
Int32ToFloat v16, v17, v18, v19 Int32ToFloat v16, v17, v18, v19
Tile10Quan: Tile10Quan:
ld1 {v20.4s, v21.4s}, [x8], #32 // scale ld1 {v20.4s, v21.4s}, [x2], #32 // scale
ld1 {v22.4s, v23.4s}, [x27], #32 // x kernel sum ld1 {v22.4s, v23.4s}, [x8], #32 // x kernel sum
ld1 {v24.d}[0], [x27] ld1 {v24.d}[0], [x8]
ld1 {v25.4s, v26.4s}, [x28], #32 // weight quan zeropoint ld1 {v25.4s, v26.4s}, [x2], #32 // weight quan zeropoint
sub x27, x27, #32 sub x8, x8, #32
MUL_SCALE v20, v0, v1, v2, v3 MUL_SCALE v20, v0, v1, v2, v3
MUL_SCALE v20, v4, v5, v6, v7 MUL_SCALE v20, v4, v5, v6, v7
MUL_SCALE v21, v10, v11, v12, v13 MUL_SCALE v21, v10, v11, v12, v13
@ -227,7 +222,7 @@ Tile10Quan:
fmul v18.4s, v18.4s, v21.4s fmul v18.4s, v18.4s, v21.4s
fmul v19.4s, v19.4s, v21.4s fmul v19.4s, v19.4s, v21.4s
cbz x23, TILE10_MLA cbz x19, TILE10_MLA
ld1 {v27.4s, v28.4s}, [x23], #32 ld1 {v27.4s, v28.4s}, [x23], #32
ld1 {v29.d}[0], [x23] ld1 {v29.d}[0], [x23]
MUL_EXTRA_SCALE v27, v0, v1, v2, v3 MUL_EXTRA_SCALE v27, v0, v1, v2, v3
@ -277,28 +272,31 @@ Tile10Quan:
fadd v9.4s, v9.4s, v20.4s fadd v9.4s, v9.4s, v20.4s
fadd v18.4s, v18.4s, v21.4s fadd v18.4s, v18.4s, v21.4s
fadd v19.4s, v19.4s, v21.4s fadd v19.4s, v19.4s, v21.4s
b TILE10_POST cbnz x0, TILE10_POST
b TILE10_L8_ACCUM_BUFFER
TILE10_ADD_DSTV: TILE10_ADD_DSTV:
// first batch10 // first batch10
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64
ld1 {v28.4s, v29.4s}, [x0], x4 ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10], #64
ADD_FLOAT v0, v1, v2, v3, v20, v21, v22, v23 ADD_FLOAT v0, v1, v2, v3, v20, v21, v22, v23
ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27 ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27
fadd v8.4s, v8.4s, v28.4s ADD_FLOAT v8, v9, v10, v11, v28, v29, v30, v31
fadd v9.4s, v9.4s, v29.4s
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64
ld1 {v28.4s, v29.4s}, [x0]
ADD_FLOAT v10, v11, v12, v13, v20, v21, v22, v23
ADD_FLOAT v14, v15, v16, v17, v24, v25, v26, v27
fadd v18.4s, v18.4s, v28.4s
fadd v19.4s, v19.4s, v29.4s
sub x0, x0, #256 ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64
sub x0, x0, x4 ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64
ADD_FLOAT v12, v13, v14, v15, v20, v21, v22, v23
ADD_FLOAT v16, v17, v18, v19, v24, v25, v26, v27
cbnz x0, TILE10_POST
TILE10_L8_ACCUM_BUFFER:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x15], #64
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x15], #64
b Tile10LoopCheck
TILE10_POST: TILE10_POST:
cbz x14, TILE10_STORE cbz x14, TILE10_STORE
@ -326,14 +324,13 @@ Tile10LoopCheck:
LoopDz4_TILE_10: LoopDz4_TILE_10:
mov x11, x1 // src mov x11, x1 // src
mov x12, x2 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
SET_0_5 v12, v13, v16, v17, v20 SET_0_5 v12, v13, v16, v17, v20
SET_0_5 v21, v24, v25, v28, v29 SET_0_5 v21, v24, v25, v28, v29
LoopSz4_TILE_10: LoopSz4_TILE_10:
ld1 {v0.16b, v1.16b}, [x12], #32 // weight ld1 {v0.16b, v1.16b}, [x2], #32 // weight
ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], #64 // src: E0-E9 ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], #64 // src: E0-E9
ld1 {v7.16b}, [x11], #16 ld1 {v7.16b}, [x11], #16
subs x13, x13, #1 subs x13, x13, #1
@ -376,16 +373,16 @@ LoopSz4End_TILE_10:
scvtf v9.4s, v9.4s scvtf v9.4s, v9.4s
Tile10Quan_L4: Tile10Quan_L4:
ld1 {v20.4s}, [x8] // scale ld1 {v20.4s}, [x2], #16 // scale
ld1 {v22.4s, v23.4s}, [x27], #32 // x kernel sum ld1 {v22.4s, v23.4s}, [x8], #32 // x kernel sum
ld1 {v24.d}[0], [x27] ld1 {v24.d}[0], [x8]
ld1 {v25.4s}, [x28] // weight quan zeropoint ld1 {v25.4s}, [x2] // weight quan zeropoint
MUL_SCALE v20, v0, v1, v2, v3 MUL_SCALE v20, v0, v1, v2, v3
MUL_SCALE v20, v4, v5, v6, v7 MUL_SCALE v20, v4, v5, v6, v7
fmul v8.4s, v8.4s, v20.4s fmul v8.4s, v8.4s, v20.4s
fmul v9.4s, v9.4s, v20.4s fmul v9.4s, v9.4s, v20.4s
cbz x23, TILE10_MLA_L4 cbz x19, TILE10_MLA_L4
ld1 {v27.4s, v28.4s}, [x23], #32 ld1 {v27.4s, v28.4s}, [x23], #32
ld1 {v29.d}[0], [x23] ld1 {v29.d}[0], [x23]
MUL_EXTRA_SCALE v27, v0, v1, v2, v3 MUL_EXTRA_SCALE v27, v0, v1, v2, v3
@ -404,7 +401,6 @@ Tile10Quan_L4:
MLA_WEIGHTZERO v7, v23, v25, 3 // tile:7, oc:0-3 MLA_WEIGHTZERO v7, v23, v25, 3 // tile:7, oc:0-3
MLA_WEIGHTZERO v8, v24, v25, 0 // tile:8, oc:0-3 MLA_WEIGHTZERO v8, v24, v25, 0 // tile:8, oc:0-3
MLA_WEIGHTZERO v9, v24, v25, 1 // tile:9, oc:0-3 MLA_WEIGHTZERO v9, v24, v25, 1 // tile:9, oc:0-3
//sub x4, x4, #128
TILE10_ADD_BIAS_L4: TILE10_ADD_BIAS_L4:
cbz x9, TILE10_ADD_DSTV_L4 cbz x9, TILE10_ADD_DSTV_L4
@ -413,19 +409,25 @@ Tile10Quan_L4:
ADD_BIAS_FLOAT v4, v5, v6, v7, v20 ADD_BIAS_FLOAT v4, v5, v6, v7, v20
fadd v8.4s, v8.4s, v20.4s fadd v8.4s, v8.4s, v20.4s
fadd v9.4s, v9.4s, v20.4s fadd v9.4s, v9.4s, v20.4s
cbz x0, TILE10_L4_ACCUM_BUFFER
b TILE10_POST_L4 b TILE10_POST_L4
TILE10_ADD_DSTV_L4: TILE10_ADD_DSTV_L4:
// first batch10 // first batch10
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64
ld1 {v28.4s, v29.4s}, [x0] ld1 {v28.4s, v29.4s}, [x10], #32
ADD_FLOAT v0, v1, v2, v3, v20, v21, v22, v23 ADD_FLOAT v0, v1, v2, v3, v20, v21, v22, v23
ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27 ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27
fadd v8.4s, v8.4s, v28.4s fadd v8.4s, v8.4s, v28.4s
fadd v9.4s, v9.4s, v29.4s fadd v9.4s, v9.4s, v29.4s
cbnz x0, TILE10_POST_L4
sub x0, x0, #128 TILE10_L4_ACCUM_BUFFER:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x15], #64
st1 {v8.4s, v9.4s}, [x15], #32
b End
TILE10_POST_L4: TILE10_POST_L4:
cbz x14, TILE10_STORE_L4 cbz x14, TILE10_STORE_L4
@ -459,17 +461,14 @@ TILE_8:
TILE8_START: TILE8_START:
mov x24, x5 // dst_depth_quad mov x24, x5 // dst_depth_quad
mov x26, x0 // dst mov x6, x0 // dst
mov x25, x2 // weight mov x12, x2 // weight
mov x19, x8 // scale
mov x20, x9 // bias mov x20, x9 // bias
mov x6, x28 // weightQuanBias
cmp x5, #2 cmp x5, #2
blt LoopDz4_TILE_8 blt LoopDz4_TILE_8
LoopDz_TILE_8: LoopDz_TILE_8:
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
SET_0_4 v12, v16, v20, v24 SET_0_4 v12, v16, v20, v24
SET_0_4 v13, v17, v21, v25 SET_0_4 v13, v17, v21, v25
@ -508,7 +507,6 @@ LoopSz_TILE_8:
bne LoopSz_TILE_8 bne LoopSz_TILE_8
LoopSzEnd_TILE_8: LoopSzEnd_TILE_8:
add x25, x25, x15
sub x24, x24, #2 // dz-2 sub x24, x24, #2 // dz-2
uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3
uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3
@ -535,15 +533,15 @@ LoopSzEnd_TILE_8:
Int32ToFloat v12, v13, v14, v15 Int32ToFloat v12, v13, v14, v15
Tile8Quan: Tile8Quan:
ld1 {v20.4s, v21.4s}, [x19], #32 // scale ld1 {v20.4s, v21.4s}, [x12], #32 // scale
ld1 {v22.4s, v23.4s}, [x27] // x kernel sum ld1 {v22.4s, v23.4s}, [x8] // x kernel sum
ld1 {v25.4s, v26.4s}, [x6], #32 // weight quan zeropoint ld1 {v25.4s, v26.4s}, [x12], #32 // weight quan zeropoint
MUL_SCALE v20, v0, v1, v2, v3 MUL_SCALE v20, v0, v1, v2, v3
MUL_SCALE v20, v4, v5, v6, v7 MUL_SCALE v20, v4, v5, v6, v7
MUL_SCALE v21, v8, v9, v10, v11 MUL_SCALE v21, v8, v9, v10, v11
MUL_SCALE v21, v12, v13, v14, v15 MUL_SCALE v21, v12, v13, v14, v15
cbz x23, TILE8_MLA cbz x19, TILE8_MLA
ld1 {v18.4s, v19.4s}, [x23] ld1 {v18.4s, v19.4s}, [x23]
MUL_EXTRA_SCALE v18, v0, v1, v2, v3 MUL_EXTRA_SCALE v18, v0, v1, v2, v3
MUL_EXTRA_SCALE v19, v4, v5, v6, v7 MUL_EXTRA_SCALE v19, v4, v5, v6, v7
@ -576,19 +574,26 @@ Tile8Quan:
ADD_BIAS_FLOAT v4, v5, v6, v7, v16 ADD_BIAS_FLOAT v4, v5, v6, v7, v16
ADD_BIAS_FLOAT v8, v9, v10, v11, v17 ADD_BIAS_FLOAT v8, v9, v10, v11, v17
ADD_BIAS_FLOAT v12, v13, v14, v15, v17 ADD_BIAS_FLOAT v12, v13, v14, v15, v17
b TILE8_POST cbnz x0, TILE8_POST
b TILE8_L8_ACCUM_BUFFER
TILE8_ADD_DSTV: TILE8_ADD_DSTV:
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x26], #64 ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x26], x4 ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64
ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x26], #64 ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64
ADD_FLOAT v0, v1, v2, v3, v20, v21, v22, v23 ADD_FLOAT v0, v1, v2, v3, v20, v21, v22, v23
ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27 ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x26] ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64
ADD_FLOAT v8, v9, v10, v11, v16, v17, v18, v19 ADD_FLOAT v8, v9, v10, v11, v16, v17, v18, v19
ADD_FLOAT v12, v13, v14, v15, v20, v21, v22, v23 ADD_FLOAT v12, v13, v14, v15, v20, v21, v22, v23
sub x26, x26, x4 cbnz x0, TILE8_POST
sub x26, x26, #128
TILE8_L8_ACCUM_BUFFER:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x15], #64
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x15], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x15], #64
b Tile8LoopCheck
TILE8_POST: TILE8_POST:
cbz x14, TILE8_STORE cbz x14, TILE8_STORE
@ -598,10 +603,10 @@ Tile8Quan:
ReLU_FP32 v12, v13, v14, v15, v30, v31 ReLU_FP32 v12, v13, v14, v15, v30, v31
TILE8_STORE: TILE8_STORE:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x26], #64 st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x6], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x26], x4 st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x6], x4
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x26], #64 st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x6], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x26], x4 st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], x4
b Tile8LoopCheck b Tile8LoopCheck
Tile8LoopCheck: Tile8LoopCheck:
@ -611,7 +616,6 @@ Tile8LoopCheck:
LoopDz4_TILE_8: LoopDz4_TILE_8:
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
SET_0_4 v12, v13, v16, v17 SET_0_4 v12, v13, v16, v17
SET_0_4 v20, v21, v24, v25 SET_0_4 v20, v21, v24, v25
@ -637,7 +641,6 @@ LoopSz4_TILE_8:
bne LoopSz4_TILE_8 bne LoopSz4_TILE_8
LoopSz4End_TILE_8: LoopSz4End_TILE_8:
add x25, x25, x15
uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3
uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3
uzp1 v2.2d, v16.2d, v17.2d // E2: oc:0-3 uzp1 v2.2d, v16.2d, v17.2d // E2: oc:0-3
@ -650,13 +653,13 @@ LoopSz4End_TILE_8:
Int32ToFloat v4, v5, v6, v7 Int32ToFloat v4, v5, v6, v7
Tile8Quan_L4: Tile8Quan_L4:
ld1 {v20.4s}, [x19] // scale ld1 {v20.4s}, [x12], #16 // scale
ld1 {v22.4s, v23.4s}, [x27] // x kernel sum ld1 {v22.4s, v23.4s}, [x8] // x kernel sum
ld1 {v25.4s}, [x6] // weight quan zeropoint ld1 {v25.4s}, [x12] // weight quan zeropoint
MUL_SCALE v20, v0, v1, v2, v3 MUL_SCALE v20, v0, v1, v2, v3
MUL_SCALE v20, v4, v5, v6, v7 MUL_SCALE v20, v4, v5, v6, v7
cbz x23, TILE8_MLA_L4 cbz x19, TILE8_MLA_L4
ld1 {v18.4s, v19.4s}, [x23] ld1 {v18.4s, v19.4s}, [x23]
MUL_EXTRA_SCALE v18, v0, v1, v2, v3 MUL_EXTRA_SCALE v18, v0, v1, v2, v3
MUL_EXTRA_SCALE v19, v4, v5, v6, v7 MUL_EXTRA_SCALE v19, v4, v5, v6, v7
@ -676,14 +679,20 @@ Tile8Quan_L4:
ld1 {v16.4s}, [x20] ld1 {v16.4s}, [x20]
ADD_BIAS_FLOAT v0, v1, v2, v3, v16 ADD_BIAS_FLOAT v0, v1, v2, v3, v16
ADD_BIAS_FLOAT v4, v5, v6, v7, v16 ADD_BIAS_FLOAT v4, v5, v6, v7, v16
b TILE8_POST_L4 cbnz x0, TILE8_POST_L4
b TILE8_L4_ACCUM_BUFFER
TILE8_ADD_DSTV_L4: TILE8_ADD_DSTV_L4:
ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x26], #64 ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], #64
ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x26] ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64
ADD_FLOAT v0, v1, v2, v3, v20, v21, v22, v23 ADD_FLOAT v0, v1, v2, v3, v20, v21, v22, v23
ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27 ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27
sub x26, x26, #64 cbnz x0, TILE8_POST_L4
TILE8_L4_ACCUM_BUFFER:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x15], #64
b Tile8Check
TILE8_POST_L4: TILE8_POST_L4:
cbz x14, TILE8_STORE_L4 cbz x14, TILE8_STORE_L4
@ -691,36 +700,32 @@ Tile8Quan_L4:
ReLU_FP32 v4, v5, v6, v7, v30, v31 ReLU_FP32 v4, v5, v6, v7, v30, v31
TILE8_STORE_L4: TILE8_STORE_L4:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x26], #64 st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x6], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x26], x4 st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x6], x4
b Tile8Check
Tile8Check: Tile8Check:
cbz x23, Tile8End cbz x0, Tile8End
add x23, x23, #32 add x0, x0, x21, LSL #3
Tile8End: Tile8End:
sub x7, x7, #8 sub x7, x7, #8
add x0, x0, x21, LSL #3
add x1, x1, #64 add x1, x1, #64
add x27, x27, #32 add x8, x8, #32
add x4, x4, #64 // Revert x4 for following tile. add x4, x4, #64 // Revert x4 for following tile.
add x23, x23, #32
TILE_4: TILE_4:
cmp x7, #4 cmp x7, #4
blt TILE_2 blt TILE_2
mov x24, x5 // dst_depth_quad mov x24, x5 // dst_depth_quad
mov x26, x0 // dst mov x6, x0 // dst
mov x25, x2 // weight mov x12, x2 // weight
mov x19, x8 // scale
mov x20, x9 // bias mov x20, x9 // bias
mov x6, x28 // weightQuanBias
cmp x5, #2 cmp x5, #2
blt LoopDz4_TILE_4 blt LoopDz4_TILE_4
LoopDz_TILE_4: LoopDz_TILE_4:
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
SET_0_4 v12, v13, v14, v15 SET_0_4 v12, v13, v14, v15
SET_0_4 v16, v17, v18, v19 SET_0_4 v16, v17, v18, v19
@ -746,7 +751,6 @@ LoopSz_TILE_4:
.inst 0x4e8ba4b3 // smmla v19.4s, v5.16b, v11.16b // tile2-oc6, tile2-oc7, tile3-oc6, tile3-oc7 .inst 0x4e8ba4b3 // smmla v19.4s, v5.16b, v11.16b // tile2-oc6, tile2-oc7, tile3-oc6, tile3-oc7
bne LoopSz_TILE_4 bne LoopSz_TILE_4
LoopSzEnd_TILE_4: LoopSzEnd_TILE_4:
add x25, x25, x15
sub x24, x24, #2 sub x24, x24, #2
uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3
uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3
@ -761,13 +765,13 @@ LoopSzEnd_TILE_4:
Int32ToFloat v4, v5, v6, v7 Int32ToFloat v4, v5, v6, v7
Tile4Quan: Tile4Quan:
ld1 {v20.4s, v21.4s}, [x19], #32 // scale ld1 {v20.4s, v21.4s}, [x12], #32 // scale
ld1 {v22.4s}, [x27] // x kernel sum ld1 {v22.4s}, [x8] // x kernel sum
ld1 {v25.4s, v26.4s}, [x6], #32 // weight quan zeropoint ld1 {v25.4s, v26.4s}, [x12], #32 // weight quan zeropoint
MUL_SCALE v20, v0, v1, v2, v3 MUL_SCALE v20, v0, v1, v2, v3
MUL_SCALE v21, v4, v5, v6, v7 MUL_SCALE v21, v4, v5, v6, v7
cbz x23, TILE4_MLA cbz x19, TILE4_MLA
ld1 {v27.4s}, [x23] ld1 {v27.4s}, [x23]
MUL_EXTRA_SCALE v27, v0, v1, v2, v3 MUL_EXTRA_SCALE v27, v0, v1, v2, v3
MUL_EXTRA_SCALE v27, v4, v5, v6, v7 MUL_EXTRA_SCALE v27, v4, v5, v6, v7
@ -787,14 +791,20 @@ Tile4Quan:
ld1 {v16.4s, v17.4s}, [x20], #32 // bias ld1 {v16.4s, v17.4s}, [x20], #32 // bias
ADD_BIAS_FLOAT v0, v1, v2, v3, v16 ADD_BIAS_FLOAT v0, v1, v2, v3, v16
ADD_BIAS_FLOAT v4, v5, v6, v7, v17 ADD_BIAS_FLOAT v4, v5, v6, v7, v17
b TILE4_POST cbnz x0, TILE4_POST
b TILE4_L8_ACCUM_BUFFER
TILE4_ADD_DSTV: TILE4_ADD_DSTV:
ld1 {v15.4s, v16.4s, v17.4s, v18.4s}, [x26], x4 ld1 {v15.4s, v16.4s, v17.4s, v18.4s}, [x10], #64
ld1 {v19.4s, v20.4s, v21.4s, v22.4s}, [x26] ld1 {v19.4s, v20.4s, v21.4s, v22.4s}, [x10], #64
ADD_FLOAT v0, v1, v2, v3, v15, v16, v17, v18 ADD_FLOAT v0, v1, v2, v3, v15, v16, v17, v18
ADD_FLOAT v4, v5, v6, v7, v19, v20, v21, v22 ADD_FLOAT v4, v5, v6, v7, v19, v20, v21, v22
sub x26, x26, x4 cbnz x0, TILE4_POST
TILE4_L8_ACCUM_BUFFER:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x15], #64
b Tile4LoopCheck
TILE4_POST: TILE4_POST:
cbz x14, TILE4_STORE cbz x14, TILE4_STORE
@ -802,9 +812,8 @@ Tile4Quan:
ReLU_FP32 v4, v5, v6, v7, v30, v31 ReLU_FP32 v4, v5, v6, v7, v30, v31
TILE4_STORE: TILE4_STORE:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x26], x4 st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x6], x4
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x26], x4 st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x6], x4
b Tile4LoopCheck
Tile4LoopCheck: Tile4LoopCheck:
cmp x24, #2 cmp x24, #2
@ -813,7 +822,6 @@ Tile4LoopCheck:
LoopDz4_TILE_4: LoopDz4_TILE_4:
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
SET_0_4 v12, v13, v16, v17 SET_0_4 v12, v13, v16, v17
LoopSz4_TILE_4: LoopSz4_TILE_4:
@ -831,7 +839,6 @@ LoopSz4_TILE_4:
.inst 0x4e89a4b1 // smmla v17.4s, v5.16b, v9.16b // tile2-oc2, tile2-oc3, tile3-oc2, tile3-oc3 .inst 0x4e89a4b1 // smmla v17.4s, v5.16b, v9.16b // tile2-oc2, tile2-oc3, tile3-oc2, tile3-oc3
bne LoopSz4_TILE_4 bne LoopSz4_TILE_4
LoopSz4End_TILE_4: LoopSz4End_TILE_4:
add x25, x25, x15
sub x24, x24, #1 sub x24, x24, #1
uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3
uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3
@ -840,12 +847,12 @@ LoopSz4End_TILE_4:
Int32ToFloat v0, v1, v2, v3 Int32ToFloat v0, v1, v2, v3
Tile4Quan_L4: Tile4Quan_L4:
ld1 {v20.4s}, [x19] // scale ld1 {v20.4s}, [x12], #16 // scale
ld1 {v22.4s}, [x27] // x kernel sum ld1 {v22.4s}, [x8] // x kernel sum
ld1 {v25.4s}, [x6] // weight quan zeropoint ld1 {v25.4s}, [x12] // weight quan zeropoint
MUL_SCALE v20, v0, v1, v2, v3 MUL_SCALE v20, v0, v1, v2, v3
cbz x23, TILE4_MLA_L4 cbz x19, TILE4_MLA_L4
ld1 {v27.4s}, [x23] ld1 {v27.4s}, [x23]
MUL_EXTRA_SCALE v27, v0, v1, v2, v3 MUL_EXTRA_SCALE v27, v0, v1, v2, v3
@ -859,44 +866,48 @@ Tile4Quan_L4:
cbz x9, TILE4_ADD_DSTV_L4 cbz x9, TILE4_ADD_DSTV_L4
ld1 {v16.4s}, [x20] // bias ld1 {v16.4s}, [x20] // bias
ADD_BIAS_FLOAT v0, v1, v2, v3, v16 ADD_BIAS_FLOAT v0, v1, v2, v3, v16
b TILE4_POST_L4 cbnz x0, TILE4_POST_L4
b TILE4_L4_ACCUM_BUFFER
TILE4_ADD_DSTV_L4: TILE4_ADD_DSTV_L4:
ld1 {v15.4s, v16.4s, v17.4s, v18.4s}, [x26] ld1 {v15.4s, v16.4s, v17.4s, v18.4s}, [x10], #64
ADD_FLOAT v0, v1, v2, v3, v15, v16, v17, v18 ADD_FLOAT v0, v1, v2, v3, v15, v16, v17, v18
cbnz x0, TILE4_POST_L4
TILE4_L4_ACCUM_BUFFER:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], #64
b Tile4Check
TILE4_POST_L4: TILE4_POST_L4:
cbz x14, TILE4_STORE_L4 cbz x14, TILE4_STORE_L4
ReLU_FP32 v0, v1, v2, v3, v30, v31 ReLU_FP32 v0, v1, v2, v3, v30, v31
TILE4_STORE_L4: TILE4_STORE_L4:
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x26], x4 st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x6], x4
b Tile4Check b Tile4Check
Tile4Check: Tile4Check:
cbz x23, Tile4End cbz x0, Tile4End
add x23, x23, #16 add x0, x0, x21, LSL #2
Tile4End: Tile4End:
sub x7, x7, #4 sub x7, x7, #4
add x0, x0, x21, LSL #2
add x1, x1, #32 add x1, x1, #32
add x27, x27, #16 add x8, x8, #16
add x23, x23, #16
TILE_2: TILE_2:
cmp x7, #2 cmp x7, #2
blt TILE_1 blt TILE_1
mov x24, x5 // dst_depth_quad mov x24, x5 // dst_depth_quad
mov x26, x0 // dst mov x6, x0 // dst
mov x25, x2 // weight mov x12, x2 // weight
mov x19, x8 // scale
mov x20, x9 // bias mov x20, x9 // bias
mov x6, x28 // weightQuanBias
cmp x5, #2 cmp x5, #2
blt LoopDz4_TILE_2 blt LoopDz4_TILE_2
LoopDz_TILE_2: LoopDz_TILE_2:
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
SET_0_4 v12, v13, v14, v15 SET_0_4 v12, v13, v14, v15
LoopSz_TILE_2: LoopSz_TILE_2:
@ -915,7 +926,6 @@ LoopSz_TILE_2:
subs x13, x13, #1 subs x13, x13, #1
bne LoopSz_TILE_2 bne LoopSz_TILE_2
LoopSzEnd_TILE_2: LoopSzEnd_TILE_2:
add x25, x25, x15
sub x24, x24, #2 sub x24, x24, #2
uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3
uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3
@ -924,15 +934,15 @@ LoopSzEnd_TILE_2:
Int32ToFloat v0, v1, v2, v3 Int32ToFloat v0, v1, v2, v3
Tile2Quan: Tile2Quan:
ld1 {v20.4s, v21.4s}, [x19], #32 // scale ld1 {v20.4s, v21.4s}, [x12], #32 // scale
ld1 {v22.d}[0], [x27] // x kernel sum ld1 {v22.d}[0], [x8] // x kernel sum
ld1 {v25.4s, v26.4s}, [x6], #32 // weight quan zeropoint ld1 {v25.4s, v26.4s}, [x12], #32 // weight quan zeropoint
fmul v0.4s, v0.4s, v20.4s fmul v0.4s, v0.4s, v20.4s
fmul v1.4s, v1.4s, v20.4s fmul v1.4s, v1.4s, v20.4s
fmul v2.4s, v2.4s, v21.4s fmul v2.4s, v2.4s, v21.4s
fmul v3.4s, v3.4s, v21.4s fmul v3.4s, v3.4s, v21.4s
cbz x23, TILE2_MLA cbz x19, TILE2_MLA
ld1 {v27.d}[0], [x23] ld1 {v27.d}[0], [x23]
fmul v0.4s, v0.4s, v27.s[0] fmul v0.4s, v0.4s, v27.s[0]
fmul v1.4s, v1.4s, v27.s[1] fmul v1.4s, v1.4s, v27.s[1]
@ -952,23 +962,24 @@ Tile2Quan:
fadd v1.4s, v1.4s, v16.4s fadd v1.4s, v1.4s, v16.4s
fadd v2.4s, v2.4s, v17.4s fadd v2.4s, v2.4s, v17.4s
fadd v3.4s, v3.4s, v17.4s fadd v3.4s, v3.4s, v17.4s
b TILE2_POST cbnz x0, TILE2_POST
b TILE2_L8_ACCUM_BUFFER
TILE2_ADD_DSTV: TILE2_ADD_DSTV:
ld1 {v18.4s, v19.4s}, [x26], x4 ld1 {v15.4s, v16.4s, v17.4s, v18.4s}, [x10], #64
ld1 {v20.4s, v21.4s}, [x26] ADD_FLOAT v0, v1, v2, v3, v15, v16, v17, v18
fadd v0.4s, v0.4s, v18.4s cbnz x0, TILE2_POST
fadd v1.4s, v1.4s, v19.4s
fadd v2.4s, v2.4s, v20.4s TILE2_L8_ACCUM_BUFFER:
fadd v3.4s, v3.4s, v21.4s st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], #64
sub x26, x26, x4 b Tile2LoopCheck
TILE2_POST: TILE2_POST:
cbz x14, TILE2_STORE cbz x14, TILE2_STORE
ReLU_FP32 v0, v1, v2, v3, v30, v31 ReLU_FP32 v0, v1, v2, v3, v30, v31
TILE2_STORE: TILE2_STORE:
st1 {v0.4s, v1.4s}, [x26], x4 st1 {v0.4s, v1.4s}, [x6], x4
st1 {v2.4s, v3.4s}, [x26], x4 st1 {v2.4s, v3.4s}, [x6], x4
b Tile2LoopCheck b Tile2LoopCheck
Tile2LoopCheck: Tile2LoopCheck:
@ -977,7 +988,6 @@ Tile2LoopCheck:
cbz x24, Tile2Check cbz x24, Tile2Check
LoopDz4_TILE_2: LoopDz4_TILE_2:
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
movi v12.4s, #0 movi v12.4s, #0
movi v13.4s, #0 movi v13.4s, #0
@ -993,20 +1003,19 @@ LoopSz4_TILE_2:
subs x13, x13, #1 subs x13, x13, #1
bne LoopSz4_TILE_2 bne LoopSz4_TILE_2
LoopSz4End_TILE_2: LoopSz4End_TILE_2:
add x25, x25, x15
uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3
uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3
scvtf v0.4s, v0.4s scvtf v0.4s, v0.4s
scvtf v1.4s, v1.4s scvtf v1.4s, v1.4s
Tile2Quan_L4: Tile2Quan_L4:
ld1 {v20.4s}, [x19] ld1 {v20.4s}, [x12], #16
ld1 {v22.d}[0], [x27] // x kernel sum ld1 {v22.d}[0], [x8] // x kernel sum
ld1 {v25.4s}, [x6] // weight quan zeropoint ld1 {v25.4s}, [x12] // weight quan zeropoint
fmul v0.4s, v0.4s, v20.4s fmul v0.4s, v0.4s, v20.4s
fmul v1.4s, v1.4s, v20.4s fmul v1.4s, v1.4s, v20.4s
cbz x23, TILE2_MLA_L4 cbz x19, TILE2_MLA_L4
ld1 {v27.d}[0], [x23] ld1 {v27.d}[0], [x23]
fmul v0.4s, v0.4s, v27.s[0] fmul v0.4s, v0.4s, v27.s[0]
fmul v1.4s, v1.4s, v27.s[1] fmul v1.4s, v1.4s, v27.s[1]
@ -1020,44 +1029,46 @@ Tile2Quan_L4:
ld1 {v16.4s}, [x20] // bias ld1 {v16.4s}, [x20] // bias
fadd v0.4s, v0.4s, v16.4s fadd v0.4s, v0.4s, v16.4s
fadd v1.4s, v1.4s, v16.4s fadd v1.4s, v1.4s, v16.4s
b TILE2_POST_L4 cbnz x0, TILE2_POST_L4
b TILE2_L4_ACCUM_BUFFER
TILE2_ADD_DSTV_L4: TILE2_ADD_DSTV_L4:
ld1 {v18.4s, v19.4s}, [x26] ld1 {v15.4s, v16.4s}, [x10], #32
fadd v0.4s, v0.4s, v18.4s fadd v0.4s, v0.4s, v15.4s
fadd v1.4s, v1.4s, v19.4s fadd v1.4s, v1.4s, v16.4s
cbnz x0, TILE2_POST_L4
TILE2_L4_ACCUM_BUFFER:
st1 {v0.4s, v1.4s}, [x15], #32
b Tile2Check
TILE2_POST_L4: TILE2_POST_L4:
cbz x14, TILE2_STORE_L4 cbz x14, TILE2_STORE_L4
ReLU_FP32_2 v0, v1, v30, v31 ReLU_FP32_2 v0, v1, v30, v31
TILE2_STORE_L4: TILE2_STORE_L4:
st1 {v0.4s, v1.4s}, [x26], x4 st1 {v0.4s, v1.4s}, [x6], x4
b Tile2Check b Tile2Check
Tile2Check: Tile2Check:
cbz x23, Tile2End cbz x0, Tile2End
add x23, x23, #8 add x0, x0, x21, LSL #1
Tile2End: Tile2End:
sub x7, x7, #2 sub x7, x7, #2
add x0, x0, x21, LSL #1
add x1, x1, #16 add x1, x1, #16
add x27, x27, #8 add x8, x8, #8
add x23, x23, #8
TILE_1: TILE_1:
cmp x7, #1 cmp x7, #1
blt End blt End
mov x24, x5 // dst_depth_quad mov x24, x5 // dst_depth_quad
mov x26, x0 // dst mov x6, x0 // dst
mov x25, x2 // weight mov x12, x2 // weight
mov x19, x8 // scale
mov x20, x9 // bias mov x20, x9 // bias
mov x6, x28 // weightQuanBias
cmp x5, #2 cmp x5, #2
blt LoopDz4_TILE_1 blt LoopDz4_TILE_1
LoopDz_TILE_1: LoopDz_TILE_1:
//ld1 {v0.4s}, [x20], #16 // bias
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
movi v16.4s, #0 movi v16.4s, #0
@ -1137,53 +1148,56 @@ LoopSz_TILE_1_lu1:
.inst 0x4e8ba453 // smmla v19.4s, v2.16b, v11.16b .inst 0x4e8ba453 // smmla v19.4s, v2.16b, v11.16b
bne LoopSz_TILE_1_lu1 bne LoopSz_TILE_1_lu1
LoopSzEnd_TILE_1: LoopSzEnd_TILE_1:
add x25, x25, x15
sub x24, x24, #2 sub x24, x24, #2
uzp1 v27.2d, v16.2d, v17.2d uzp1 v25.2d, v16.2d, v17.2d
uzp1 v26.2d, v18.2d, v19.2d uzp1 v26.2d, v18.2d, v19.2d
scvtf v27.4s, v27.4s scvtf v25.4s, v25.4s
scvtf v26.4s, v26.4s scvtf v26.4s, v26.4s
Tile1Quan: Tile1Quan:
ld1 {v0.4s, v1.4s}, [x19], #32 // scale ld1 {v0.4s, v1.4s}, [x12], #32 // scale
ld1 {v6.s}[0], [x27] // x kernel sum ld1 {v6.s}[0], [x8] // x kernel sum
ld1 {v8.4s, v9.4s}, [x6], #32 // weight quan zeropoint ld1 {v8.4s, v9.4s}, [x12], #32 // weight quan zeropoint
fmul v27.4s, v27.4s, v0.4s fmul v25.4s, v25.4s, v0.4s
fmul v26.4s, v26.4s, v1.4s fmul v26.4s, v26.4s, v1.4s
cbz x23, TILE1_MLA cbz x19, TILE1_MLA
ld1 {v10.s}[0], [x23] ld1 {v10.s}[0], [x23]
fmul v27.4s, v27.4s, v10.s[0] fmul v25.4s, v25.4s, v10.s[0]
fmul v26.4s, v26.4s, v10.s[0] fmul v26.4s, v26.4s, v10.s[0]
TILE1_MLA: TILE1_MLA:
MLA_WEIGHTZERO v27, v6, v8, 0 // tile:0, oc:0-3 MLA_WEIGHTZERO v25, v6, v8, 0 // tile:0, oc:0-3
MLA_WEIGHTZERO v26, v6, v9, 0 // tile:0, oc:4-7 MLA_WEIGHTZERO v26, v6, v9, 0 // tile:0, oc:4-7
TILE1_ADD_BIAS: TILE1_ADD_BIAS:
cbz x9, TILE1_ADD_DSTV cbz x9, TILE1_ADD_DSTV
ld1 {v16.4s, v17.4s}, [x20], #32 // bias ld1 {v16.4s, v17.4s}, [x20], #32 // bias
fadd v27.4s, v27.4s, v16.4s fadd v25.4s, v25.4s, v16.4s
fadd v26.4s, v26.4s, v17.4s fadd v26.4s, v26.4s, v17.4s
b TILE1_POST cbnz x0, TILE1_POST
b TILE1_L8_ACCUM_BUFFER
TILE1_ADD_DSTV: TILE1_ADD_DSTV:
ld1 {v16.4s}, [x26], x4 ld1 {v15.4s, v16.4s}, [x10], #32
ld1 {v17.4s}, [x26] fadd v25.4s, v25.4s, v15.4s
fadd v27.4s, v27.4s, v16.4s fadd v26.4s, v26.4s, v16.4s
fadd v26.4s, v26.4s, v17.4s cbnz x0, TILE1_POST
sub x26, x26, x4
TILE1_L8_ACCUM_BUFFER:
st1 {v25.4s, v26.4s}, [x15], #32
b Tile1LoopEnd
TILE1_POST: TILE1_POST:
cbz x14, TILE1_STORE cbz x14, TILE1_STORE
fmin v27.4s, v27.4s, v31.4s fmin v25.4s, v25.4s, v31.4s
fmax v27.4s, v27.4s, v30.4s fmax v25.4s, v25.4s, v30.4s
fmin v26.4s, v26.4s, v31.4s fmin v26.4s, v26.4s, v31.4s
fmax v26.4s, v26.4s, v30.4s fmax v26.4s, v26.4s, v30.4s
TILE1_STORE: TILE1_STORE:
st1 {v27.4s}, [x26], x4 st1 {v25.4s}, [x6], x4
st1 {v26.4s}, [x26], x4 st1 {v26.4s}, [x6], x4
b Tile1LoopEnd b Tile1LoopEnd
Tile1LoopEnd: Tile1LoopEnd:
@ -1193,7 +1207,6 @@ Tile1LoopEnd:
LoopDz4_TILE_1: LoopDz4_TILE_1:
mov x11, x1 // src mov x11, x1 // src
mov x12, x25 // weight
mov x13, x3 // src_depth_quad mov x13, x3 // src_depth_quad
movi v16.4s, #0 movi v16.4s, #0
@ -1210,16 +1223,16 @@ LoopSz4_TILE_1:
.inst 0x4e89a451 // smmla v17.4s, v2.16b, v9.16b .inst 0x4e89a451 // smmla v17.4s, v2.16b, v9.16b
bne LoopSz4_TILE_1 bne LoopSz4_TILE_1
LoopSz4End_TILE_1: LoopSz4End_TILE_1:
add x25, x25, x15
uzp1 v27.2d, v16.2d, v17.2d uzp1 v27.2d, v16.2d, v17.2d
scvtf v27.4s, v27.4s scvtf v27.4s, v27.4s
Tile1Quan_L4: Tile1Quan_L4:
ld1 {v0.4s}, [x19] // scale ld1 {v0.4s}, [x12], #16 // scale
ld1 {v6.s}[0], [x27] // x kernel sum ld1 {v6.s}[0], [x8] // x kernel sum
ld1 {v8.4s}, [x6] // weight quan zeropoint ld1 {v8.4s}, [x12] // weight quan zeropoint
fmul v27.4s, v27.4s, v0.4s fmul v27.4s, v27.4s, v0.4s
cbz x23, TILE1_MLA_L4 cbz x19, TILE1_MLA_L4
ld1 {v10.s}[0], [x23] ld1 {v10.s}[0], [x23]
fmul v27.4s, v27.4s, v10.s[0] fmul v27.4s, v27.4s, v10.s[0]
@ -1230,11 +1243,17 @@ Tile1Quan_L4:
cbz x9, TILE1_ADD_DSTV_L4 cbz x9, TILE1_ADD_DSTV_L4
ld1 {v16.4s}, [x20] // bias ld1 {v16.4s}, [x20] // bias
fadd v27.4s, v27.4s, v16.4s fadd v27.4s, v27.4s, v16.4s
b TILE1_POST_L4 cbnz x0, TILE1_POST_L4
b TILE1_L4_ACCUM_BUFFER
TILE1_ADD_DSTV_L4: TILE1_ADD_DSTV_L4:
ld1 {v16.4s}, [x26] ld1 {v15.4s}, [x10], #16
fadd v27.4s, v27.4s, v16.4s fadd v27.4s, v27.4s, v15.4s
cbnz x0, TILE1_POST_L4
TILE1_L4_ACCUM_BUFFER:
st1 {v27.4s}, [x15], #16
b End
TILE1_POST_L4: TILE1_POST_L4:
cbz x14, TILE1_STORE_L4 cbz x14, TILE1_STORE_L4
@ -1242,19 +1261,17 @@ Tile1Quan_L4:
fmax v27.4s, v27.4s, v30.4s fmax v27.4s, v27.4s, v30.4s
TILE1_STORE_L4: TILE1_STORE_L4:
st1 {v27.4s}, [x26], x4 st1 {v27.4s}, [x6], x4
b End b End
End: End:
ldp x27, x28, [sp, #(16 * 8)]
ldp x25, x26, [sp, #(16 * 7)]
ldp x23, x24, [sp, #(16 * 6)] ldp x23, x24, [sp, #(16 * 6)]
ldp x19, x20, [sp, #(16 * 5)] ldp x19, x20, [sp, #(16 * 5)]
ldp x21, x22, [sp, #(16 * 4)] ldp x21, x22, [sp, #(16 * 4)]
ldp d8, d9, [sp, #(16 * 3)] ldp d8, d9, [sp, #(16 * 3)]
ldp d10, d11, [sp, #(16 * 2)] ldp d10, d11, [sp, #(16 * 2)]
ldp d12, d13, [sp, #(16 * 1)] ldp d12, d13, [sp, #(16 * 1)]
ldp d14, d15, [sp], #(16 * 10) ldp d14, d15, [sp], #(16 * 8)
ret ret
#endif // __aarch64__ #endif // __aarch64__

View File

@ -221,18 +221,12 @@ void MNNQuantSumFP32(float* sum, const float* dequant_scale, size_t thread, size
} }
} }
void MNNDynamicUpdateConvBiasScale(float* newbias, float* newscale, float* oldbias, float* weightScale, float* inputScale, float* weightKernelSum, float* inputZero, size_t ocQuad, size_t scaleSize) { void MNNDynamicUpdateConvBiasScale(float* newbias, float* oldbias, float* weightKernelSum, float* inputZero, size_t ocQuad) {
int ocUp4 = 4 * ocQuad; int ocUp4 = 4 * ocQuad;
int pack = 4; int pack = 4;
int blockNum = scaleSize / ocUp4;
for (int i = 0; i < ocUp4; ++i) { for (int i = 0; i < ocUp4; ++i) {
newbias[i] = oldbias[i] - weightKernelSum[i] * inputZero[0]; newbias[i] = oldbias[i] - weightKernelSum[i] * inputZero[0];
} }
for (int k = 0; k < blockNum; ++k) {
for (int i = 0; i < ocUp4; ++i) {
newscale[i + k * ocUp4] = weightScale[i + k * ocUp4] * inputScale[0];
}
}
} }
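Because the old and new columns are interleaved above, here is the post-change helper written out plainly, taken from the added side of the hunk: the per-block scale fusion is gone and only the zero-point fold into the bias remains.

void MNNDynamicUpdateConvBiasScale(float* newbias, float* oldbias,
                                   float* weightKernelSum, float* inputZero,
                                   size_t ocQuad) {
    int ocUp4 = 4 * ocQuad;  // oc rounded up to pack = 4
    for (int i = 0; i < ocUp4; ++i) {
        newbias[i] = oldbias[i] - weightKernelSum[i] * inputZero[0];
    }
}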
#endif // LOW_MEMORY #endif // LOW_MEMORY
@ -245,6 +239,10 @@ static void MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_q
MNNAbsMaxFP32_Pack4(source, absmax, src_depth_quad, realSize, pack); MNNAbsMaxFP32_Pack4(source, absmax, src_depth_quad, realSize, pack);
return; return;
} }
if (pack == 8) {
MNNAbsMaxFP32_Pack8(source, absmax, src_depth_quad, realSize, pack);
return;
}
#endif #endif
// source: (ic/4, N, 4) // source: (ic/4, N, 4)
auto srcStep = pack * realSize; auto srcStep = pack * realSize;
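The dispatch now also covers pack == 8. For readers without the vector kernels at hand, here is a scalar restatement of the reduction the generic path performs, assuming the (src_depth_quad, realSize, pack) layout stated in the source comment: per input position, take the maximum absolute value across every channel quad.

#include <cmath>
#include <cstddef>

// Scalar sketch of the generic absmax reduction (assumed semantics).
static void absMaxGeneric(const float* source, float* absmax,
                          size_t src_depth_quad, size_t realSize, int pack) {
    size_t srcStep = (size_t)pack * realSize;
    for (size_t i = 0; i < realSize; ++i) {
        float maxVal = 0.0f;
        for (size_t c = 0; c < src_depth_quad; ++c) {
            const float* ptr = source + c * srcStep + i * (size_t)pack;
            for (int k = 0; k < pack; ++k) {
                maxVal = std::fmax(maxVal, std::fabs(ptr[k]));
            }
        }
        absmax[i] = maxVal;  // one quantization basis per position
    }
}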
@ -266,6 +264,10 @@ void MNNDynamicQuantFP32(const float* src, int8_t* dst, const float* scale, size
MNNDynamicQuantFP32_Pack4(src, dst, scale, src_depth_quad, realSize, pack); MNNDynamicQuantFP32_Pack4(src, dst, scale, src_depth_quad, realSize, pack);
return; return;
} }
if (pack == 8) {
MNNDynamicQuantFP32_Pack8(src, dst, scale, src_depth_quad, realSize, pack);
return;
}
#endif #endif
#ifdef MNN_USE_SSE #ifdef MNN_USE_SSE
uint8_t* dstPtr = reinterpret_cast<uint8_t*>(dst); uint8_t* dstPtr = reinterpret_cast<uint8_t*>(dst);
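MNNDynamicQuantFP32 gains the same pack == 8 fast path. Its per-element job, stated here as assumed semantics since the kernel bodies are not in this hunk, is a scale multiply followed by a saturating round to int8.

#include <cmath>
#include <cstdint>

// Assumed per-element semantics of the dynamic-quant kernels; sketch only.
static inline int8_t quantizeOne(float v, float scale) {
    float q = roundf(v * scale);
    if (q > 127.0f)  q = 127.0f;
    if (q < -127.0f) q = -127.0f;
    return (int8_t)q;
}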
@ -3254,7 +3256,7 @@ static void generalIm2col(float* destOrigin, float const** sourceGroup, const in
} }
} }
} }
#endif #endif // MNN_LOW_MEMORY
namespace MNN { namespace MNN {

View File

@ -137,10 +137,12 @@ void MNNPackedMatMulRemain_int4(float* C, const float* A, const float* B, size_t
void MNNPackedMatMul_int8(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); void MNNPackedMatMul_int8(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
void MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); void MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
void MNNAbsMaxFP32_Pack4(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack); void MNNAbsMaxFP32_Pack4(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack);
void MNNAbsMaxFP32_Pack8(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack);
void MNNQuantScaleFP32(float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch); void MNNQuantScaleFP32(float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch);
void MNNDynamicQuantFP32_Pack4(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, int pack); void MNNDynamicQuantFP32_Pack4(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, int pack);
void MNNDynamicQuantFP32_Pack8(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, int pack);
void MNNQuantSumFP32(float* sum, const float* dequant_scale, size_t thread, size_t batch); void MNNQuantSumFP32(float* sum, const float* dequant_scale, size_t thread, size_t batch);
void MNNDynamicUpdateConvBiasScale(float* newbias, float* newscale, float* oldbias, float* weightScale, float* inputScale, float* weightKernelSum, float* inputZero, size_t ocQuad, size_t scaleSize); void MNNDynamicUpdateConvBiasScale(float* newbias, float* oldbias, float* weightKernelSum, float* inputZero, size_t ocQuad);
void MNNPackForSparseMatMul_B(float* dest, unsigned int* NNZMap, int* dataOffsetMap, int sparseBlockOC, const float* source, size_t h, size_t l, const int eP, bool transpose); void MNNPackForSparseMatMul_B(float* dest, unsigned int* NNZMap, int* dataOffsetMap, int sparseBlockOC, const float* source, size_t h, size_t l, const int eP, bool transpose);
struct SparseMatMulParas struct SparseMatMulParas
@ -230,7 +232,7 @@ struct CoreFunctions {
void(*MNNComputeMatMulForH_1)(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId); void(*MNNComputeMatMulForH_1)(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId);
void(*MNNComputeMatMulForE_1)(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId); void(*MNNComputeMatMulForE_1)(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId);
void(*MNNCountMaxMinValue)(float* source, float* minVal, float* maxVal, size_t size); void(*MNNCountMaxMinValue)(float* source, float* minVal, float* maxVal, size_t size);
void(*MNNDynamicUpdateConvBiasScale)(float* newbias, float* newscale, float* oldbias, float* weightScale, float* inputScale, float* weightKernelSum, float* inputZero, size_t ocQuad, size_t scaleSize); void(*MNNDynamicUpdateConvBiasScale)(float* newbias, float* oldbias, float* weightKernelSum, float* inputZero, size_t ocQuad);
typedef void(*MNNPackedMatMulKernel)(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias); typedef void(*MNNPackedMatMulKernel)(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias);
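Every caller that reaches this entry through the CoreFunctions table must drop the four scale-related arguments. An illustrative call site, with placeholder variable names; only the new arity comes from the header above.

// Hypothetical caller after this signature change (names are placeholders).
core->MNNDynamicUpdateConvBiasScale(newBias, oldBias, weightKernelSum,
                                    inputZeroPoint, ocQuad);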

File diff suppressed because it is too large Load Diff

View File

@ -24,7 +24,8 @@ public:
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
virtual void getPackParameter(int* Unit, int* SrcUnit, int* DestUnit, const CoreInt8Functions* core) = 0; virtual void getPackParameter(int* Unit, int* SrcUnit, int* DestUnit, const CoreInt8Functions* core) = 0;
static void reorderWeight(Tensor* weight, const uint8_t* weightSrc, int SRC_UNIT, int UNIT, int ic, int oc, int kernelCount, int pack, int blockNum = 1, int32_t initval = 0); static void packWeightAndQuantInfo(int8_t* dstbuffer, const int8_t* weight, const int8_t* quantInfo, int32_t* info, int infoBytes = 4);
static void reorderWeight(uint8_t* dst, const uint8_t* src, int32_t* info, int32_t initval = 0);
protected: protected:
ConvolutionCommon::Im2ColParameter mIm2ColParamter; ConvolutionCommon::Im2ColParameter mIm2ColParamter;
@ -67,9 +68,8 @@ private:
std::function<void(float* dest, int8_t* source, const float* scale, ssize_t realDstCount, SumByAxisParams sumParams)> mSumByAxisLFunc; std::function<void(float* dest, int8_t* source, const float* scale, ssize_t realDstCount, SumByAxisParams sumParams)> mSumByAxisLFunc;
std::shared_ptr<Tensor> mQuantInput; std::shared_ptr<Tensor> mQuantInput;
std::shared_ptr<Tensor> mDynamicBias; std::shared_ptr<Tensor> mDynamicBias;
std::shared_ptr<Tensor> mScaleFuse; std::shared_ptr<Tensor> mAccumBuffer;
std::shared_ptr<Tensor> mBatchQuantInfo; std::shared_ptr<Tensor> mBatchQuantInfo;
std::shared_ptr<Tensor> mInputDeqScales;
std::shared_ptr<Tensor> mTempMaxMinValueBuffer; std::shared_ptr<Tensor> mTempMaxMinValueBuffer;
std::vector<uint8_t> mTempSrcSum; std::vector<uint8_t> mTempSrcSum;
std::vector<int32_t> mDivides; std::vector<int32_t> mDivides;
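packWeightAndQuantInfo is the glue for the merged layout introduced in this commit: reordered int8 weights are interleaved with the float scale/zero vectors per output-channel quad, which is what lets the GEMM kernels below derive scale_dz from weight_dz instead of reading post->scale. A hedged sketch of one plausible implementation, using the info = {blockNum, hU, lU, hP, lP, ocUpPack} convention visible at the call sites in this commit (the per-block ordering inside quantInfo, scales first then zeros, is an assumption):

    #include <cstdint>
    #include <cstring>
    static void packWeightAndQuantInfoSketch(int8_t* dst, const int8_t* weight,
                                             const int8_t* quantInfo,
                                             const int32_t* info, int infoBytes = 4) {
        int blockNum = info[0], hU = info[1], lU = info[2], hP = info[3], lP = info[4];
        int ocUpPack = info[5];                          // padded oc, in quant values
        size_t weightSlice = (size_t)lU * hP * lP;       // int8 weight bytes per (block, hU)
        for (int b = 0; b < blockNum; ++b) {
            for (int h = 0; h < hU; ++h) {
                std::memcpy(dst, weight + ((size_t)b * hU + h) * weightSlice, weightSlice);
                dst += weightSlice;
                // hP scales, then hP zeros, for this quad of output channels
                auto scaleSrc = quantInfo + ((size_t)b * ocUpPack + (size_t)h * hP) * infoBytes;
                auto zeroSrc  = quantInfo + ((size_t)(blockNum + b) * ocUpPack + (size_t)h * hP) * infoBytes;
                std::memcpy(dst, scaleSrc, (size_t)hP * infoBytes); dst += (size_t)hP * infoBytes;
                std::memcpy(dst, zeroSrc,  (size_t)hP * infoBytes); dst += (size_t)hP * infoBytes;
            }
        }
    }

The resulting per-quad stride, lU*hP*lP + 2*hP*infoBytes bytes, matches the new weight_step_Z expressions in the kernels further down.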

View File

@ -32,11 +32,12 @@ std::shared_ptr<ConvInt8Winograd::WinoResource> ConvInt8Winograd::makeWinoResour
int kySize = attr[2], kxSize = attr[3], unitY = attr[4], unitX = attr[5]; attr += 6; int kySize = attr[2], kxSize = attr[3], unitY = attr[4], unitX = attr[5]; attr += 6;
int alphaY = kySize + unitY - 1, alphaX = kxSize + unitX - 1, alpha2 = alphaY * alphaX; int alphaY = kySize + unitY - 1, alphaX = kxSize + unitX - 1, alpha2 = alphaY * alphaX;
std::shared_ptr<Tensor> weight, offsets, scales, inputScales; std::shared_ptr<Tensor> weight, offsets, scales, inputScales, mergeInfo;
weight.reset(Tensor::createDevice<int8_t>({alpha2, ocDivUnit, ic4, UNIT, SRC_UNIT})); weight.reset(Tensor::createDevice<int8_t>({1, ocDivUnit, ic4, UNIT, SRC_UNIT}));
offsets.reset(Tensor::createDevice<float>({alpha2, oc4, pack})); offsets.reset(Tensor::createDevice<float>({alpha2, oc4, pack}));
scales.reset(Tensor::createDevice<float>({alpha2, oc4 * pack})); scales.reset(Tensor::createDevice<float>({1, 2 * oc4 * pack}));
inputScales.reset(Tensor::createDevice<float>({alpha2, pack})); inputScales.reset(Tensor::createDevice<float>({alpha2, pack}));
mergeInfo.reset(Tensor::createDevice<int8_t>({alpha2, weight->stride(0) + scales->size()}));
auto allocTensors = [=](std::vector<std::shared_ptr<Tensor>> tensors) -> bool { auto allocTensors = [=](std::vector<std::shared_ptr<Tensor>> tensors) -> bool {
bool success = true; bool success = true;
@ -45,7 +46,15 @@ std::shared_ptr<ConvInt8Winograd::WinoResource> ConvInt8Winograd::makeWinoResour
} }
return success; return success;
}; };
if (!allocTensors({weight, offsets, scales, inputScales})) {
if (!allocTensors({offsets, scales, inputScales, mergeInfo})) {
MNN_ERROR("Memory not enough\n");
return nullptr;
}
std::shared_ptr<Tensor> originWeightFloat, weightFloat;
originWeightFloat.reset(Tensor::createDevice<float>({oc, ic, kySize, kxSize}));
weightFloat.reset(Tensor::createDevice<float>({alpha2, oc, ic, 1, 1}));
if (!allocTensors({weight, originWeightFloat, weightFloat})) {
MNN_ERROR("Memory not enough\n"); MNN_ERROR("Memory not enough\n");
return nullptr; return nullptr;
} }
@ -61,14 +70,6 @@ std::shared_ptr<ConvInt8Winograd::WinoResource> ConvInt8Winograd::makeWinoResour
inputScales->host<float>()[i * pack + u] = scale; inputScales->host<float>()[i * pack + u] = scale;
} }
} }
std::shared_ptr<Tensor> originWeightFloat, weightFloat;
originWeightFloat.reset(Tensor::createDevice<float>({oc, ic, kySize, kxSize}));
weightFloat.reset(Tensor::createDevice<float>({alpha2, oc, ic, 1, 1}));
if (!allocTensors({originWeightFloat, weightFloat})) {
MNN_ERROR("Memory not enough\n");
return nullptr;
}
for (int c = 0; c < oc * ic; ++c) { for (int c = 0; c < oc * ic; ++c) {
for (int h = 0; h < kySize; ++h) { for (int h = 0; h < kySize; ++h) {
for (int w = 0; w < kxSize; ++w) { for (int w = 0; w < kxSize; ++w) {
@ -80,7 +81,7 @@ std::shared_ptr<ConvInt8Winograd::WinoResource> ConvInt8Winograd::makeWinoResour
} }
Math::WinogradGenerater generator({unitY, unitX}, {kySize, kxSize}, 1, true); Math::WinogradGenerater generator({unitY, unitX}, {kySize, kxSize}, 1, true);
generator.transformWeight(weightFloat.get(), originWeightFloat.get(), true); generator.transformWeight(weightFloat.get(), originWeightFloat.get(), true);
auto scalePtr = scales->host<float>();
for (int a = 0; a < alpha2; ++a) { for (int a = 0; a < alpha2; ++a) {
for (int oz = 0; oz < oc; ++oz) { for (int oz = 0; oz < oc; ++oz) {
int oz4 = oz / UNIT, ozRemain = oz % UNIT; int oz4 = oz / UNIT, ozRemain = oz % UNIT;
@ -89,7 +90,7 @@ std::shared_ptr<ConvInt8Winograd::WinoResource> ConvInt8Winograd::makeWinoResour
float scale = weightScaleData[a * oc + oz]; float scale = weightScaleData[a * oc + oz];
for (int sz = 0; sz < ic; ++sz) { for (int sz = 0; sz < ic; ++sz) {
int sz4 = sz / SRC_UNIT, szRemain = sz % SRC_UNIT; int sz4 = sz / SRC_UNIT, szRemain = sz % SRC_UNIT;
int index = (((a * ocDivUnit + oz4) * ic4 + sz4) * UNIT + ozRemain) * SRC_UNIT + szRemain; int index = ((oz4 * ic4 + sz4) * UNIT + ozRemain) * SRC_UNIT + szRemain;
float srcData = weightFloat->host<float>()[(a * oc + oz) * ic + sz]; float srcData = weightFloat->host<float>()[(a * oc + oz) * ic + sz];
// -ffast-math may cause inexact input then wrong rounded result, add eps to avoid this // -ffast-math may cause inexact input then wrong rounded result, add eps to avoid this
float eps = ((srcData/scale) > 0 ? 1 : -1) * 1e-6; float eps = ((srcData/scale) > 0 ? 1 : -1) * 1e-6;
@ -102,20 +103,24 @@ std::shared_ptr<ConvInt8Winograd::WinoResource> ConvInt8Winograd::makeWinoResour
} }
offsets->host<float>()[a * oc4 * pack + oz] = offset * scale * inputScaleData[a]; offsets->host<float>()[a * oc4 * pack + oz] = offset * scale * inputScaleData[a];
scales->host<float>()[a * oc4 * pack + oz] = scale * inputScaleData[a]; scalePtr[oz] = scale * inputScaleData[a];
} }
int32_t params[6] = {1, ocDivUnit, ic4, UNIT, SRC_UNIT, oc4 * pack};
ConvInt8TiledExecutor::packWeightAndQuantInfo(mergeInfo->host<int8_t>() + a * mergeInfo->stride(0), weight->host<int8_t>(), scales->host<int8_t>(), params);
} }
backend->onReleaseBuffer(originWeightFloat.get(), Backend::STATIC);
backend->onReleaseBuffer(weightFloat.get(), Backend::STATIC);
std::shared_ptr<WinoResource> resource(new WinoResource); std::shared_ptr<WinoResource> resource(new WinoResource);
resource->weight = weight; resource->weight = mergeInfo;
resource->offsets = offsets; resource->offsets = offsets;
resource->scales = scales; resource->scales = scales;
resource->transInputScales = inputScales; resource->transInputScales = inputScales;
std::vector<int32_t> inputZeroPoints(inputPointData, inputPointData + alpha2); std::vector<int32_t> inputZeroPoints(inputPointData, inputPointData + alpha2);
resource->transInputZeroPoints = inputZeroPoints; resource->transInputZeroPoints = inputZeroPoints;
resource->backend = backend; resource->backend = backend;
backend->onReleaseBuffer(weight.get(), Backend::STATIC);
backend->onReleaseBuffer(originWeightFloat.get(), Backend::STATIC);
backend->onReleaseBuffer(weightFloat.get(), Backend::STATIC);
return resource; return resource;
} }
@ -551,7 +556,7 @@ ErrorCode ConvInt8Winograd::WinoExecution::onExecute(const std::vector<Tensor *>
auto _dstFloatPtr = _dstOrigin + i * dc_4 * xC * pack; auto _dstFloatPtr = _dstOrigin + i * dc_4 * xC * pack;
auto _weightInt8Ptr = weight + i * mWinoResource->weight->stride(0); auto _weightInt8Ptr = weight + i * mWinoResource->weight->stride(0);
quanParam.biasFloat = (mWinoResource->offsets->host<float>() + i * mWinoResource->offsets->stride(0)); quanParam.biasFloat = (mWinoResource->offsets->host<float>() + i * mWinoResource->offsets->stride(0));
quanParam.scale = mWinoResource->scales->host<float>() + i * dc_4 * pack; quanParam.scale = mWinoResource->scales->host<float>() + i * dc_4 * pack;
quanParam.extraScale = nullptr; quanParam.extraScale = nullptr;
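For readers tracing the merged buffer: each alpha slice written by makeWinoResource is ocDivUnit runs of [ic4*UNIT*SRC_UNIT weight bytes | UNIT scales | UNIT zeros], so the per-quad quant info can be recovered without a separate scales tensor. A hypothetical accessor, assuming blockNum == 1 and the packing sketched earlier:

    #include <cstdint>
    static const float* winoScaleQuad(const int8_t* alphaBase, int dz, int ic4,
                                      int UNIT, int SRC_UNIT) {
        int stepY = UNIT * SRC_UNIT;                              // bytes per (dz, sz) tile
        int stepZ = ic4 * stepY + 2 * (int)sizeof(float) * UNIT;  // weights + scale/zero tail
        return reinterpret_cast<const float*>(alphaBase + dz * stepZ + ic4 * stepY);
    }
    // The matching zeros would start at winoScaleQuad(...) + UNIT.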

View File

@ -11,6 +11,7 @@
#include "backend/cpu/CPUConvolution.hpp" #include "backend/cpu/CPUConvolution.hpp"
#include "backend/cpu/compute/Int8FunctionsOpt.h" #include "backend/cpu/compute/Int8FunctionsOpt.h"
#include "ConvInt8TiledExecutor.hpp"
namespace MNN { namespace MNN {
class ConvInt8Winograd : public CPUConvolution { class ConvInt8Winograd : public CPUConvolution {

View File

@ -189,7 +189,7 @@ std::pair<int, bool> ConvolutionTiledExecutor::turnIm2ColToBlitInfo(float const
auto srcKx = srcKy + ((oxBegin + sta) * p.strideX + p.dilateX * kx - p.padX) * bytes * unit; auto srcKx = srcKy + ((oxBegin + sta) * p.strideX + p.dilateX * kx - p.padX) * bytes * unit;
srcPtr[number] = (const float*)srcKx; srcPtr[number] = (const float*)srcKx;
el[4 * number + 0] = end - sta; el[4 * number + 0] = end - sta;
el[4 * number + 1] = p.icup4; el[4 * number + 1] = p.icup4;
el[4 * number + 2] = eStart + sta; el[4 * number + 2] = eStart + sta;
el[4 * number + 3] = lOffset; el[4 * number + 3] = lOffset;
number++; number++;

View File

@ -29,7 +29,7 @@ extern "C" {
void MNNInt8ToUInt8(void* ptr, int count); void MNNInt8ToUInt8(void* ptr, int count);
} }
#endif #endif
#define QUANT_INFO_BYTES 4
namespace MNN { namespace MNN {
IdstConvolutionInt8::IdstConvolutionInt8(const Convolution2DCommon* convOp, Backend* b, IdstConvolutionInt8::IdstConvolutionInt8(const Convolution2DCommon* convOp, Backend* b,
@ -39,8 +39,8 @@ IdstConvolutionInt8::IdstConvolutionInt8(const Convolution2DCommon* convOp, Back
int UNIT, SRC_UNIT, DST_XUNIT; int UNIT, SRC_UNIT, DST_XUNIT;
core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT); core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
int PackUnit = static_cast<CPUBackend*>(b)->functions()->pack; int PackUnit = static_cast<CPUBackend*>(b)->functions()->pack;
int ocUp4 = ROUND_UP(biasSize, PackUnit);
mBias.reset(ROUND_UP(biasSize, PackUnit)); mBias.reset(ocUp4);
mBias.clear(); mBias.clear();
auto biasDest = mBias.get(); auto biasDest = mBias.get();
mAMin = common->quan->aMin(); mAMin = common->quan->aMin();
@ -65,21 +65,31 @@ IdstConvolutionInt8::IdstConvolutionInt8(const Convolution2DCommon* convOp, Back
auto kernelCount = kx * ky; auto kernelCount = kx * ky;
auto srcCount = mSrcCount; auto srcCount = mSrcCount;
std::vector<int> shape; std::vector<int> shape;
shape = {UP_DIV(outputCount, UNIT), UP_DIV(srcCount, SRC_UNIT) * kernelCount, UNIT, SRC_UNIT}; shape = {1, UP_DIV(outputCount, UNIT), UP_DIV(srcCount, SRC_UNIT) * kernelCount, UNIT, SRC_UNIT};
mFakeBias.reset(Tensor::createDevice<float>({ocUp4}));
mWeight.reset(Tensor::createDevice<int8_t>(shape)); int weightlen = shape[0] * shape[1] * shape[2] * shape[3] * shape[4];
mFakeBias.reset(Tensor::createDevice<float>({(int)ROUND_UP(biasSize, PackUnit)})); int quantlen = 2 * ocUp4 * QUANT_INFO_BYTES;
mFakeWeightBias.reset(Tensor::createDevice<float>({(int)ROUND_UP(biasSize, PackUnit)})); mWeight.reset(Tensor::createDevice<int8_t>({weightlen + quantlen}));
mValid = b->onAcquireBuffer(mWeight.get(), Backend::STATIC); mValid = b->onAcquireBuffer(mWeight.get(), Backend::STATIC);
mValid &= b->onAcquireBuffer(mFakeBias.get(), Backend::STATIC); mValid &= b->onAcquireBuffer(mFakeBias.get(), Backend::STATIC);
mValid &= b->onAcquireBuffer(mFakeWeightBias.get(), Backend::STATIC);
if (!mValid) { if (!mValid) {
MNN_ERROR("Memory not enough\n"); MNN_ERROR("Memory not enough\n");
return; return;
} }
ConvInt8TiledExecutor::reorderWeight(mWeight.get(), (uint8_t*)common->weight.get(), SRC_UNIT, UNIT, srcCount, outputCount, kernelCount, PackUnit); AutoStorage<uint8_t> weightReordered(weightlen);
AutoStorage<float> fakeWeightScaleBias(2 * ocUp4);
if (weightReordered.get() == nullptr || fakeWeightScaleBias.get() == nullptr) {
MNN_ERROR("Memory not enough\n");
return;
}
int32_t info[6] = {1, outputCount, srcCount, kernelCount, UNIT, SRC_UNIT};
ConvInt8TiledExecutor::reorderWeight(weightReordered.get(), (uint8_t*)common->weight.get(), info);
::memset(mFakeBias->host<float>(), 0, mFakeBias->size()); ::memset(mFakeBias->host<float>(), 0, mFakeBias->size());
::memset(mFakeWeightBias->host<float>(), 0, mFakeWeightBias->size()); auto ptr = (float*)fakeWeightScaleBias.get();
::memset(ptr, 0, 2 * ocUp4 * 4);
for (int i = 0; i < ocUp4; ++i) {
ptr[i] = 1.f;
}
#ifdef MNN_USE_SSE #ifdef MNN_USE_SSE
for (int oz = 0; oz < outputCount; ++oz) { for (int oz = 0; oz < outputCount; ++oz) {
auto srcZ = common->weight.get() + oz * kernelCount * srcCount; auto srcZ = common->weight.get() + oz * kernelCount * srcCount;
@ -90,6 +100,8 @@ IdstConvolutionInt8::IdstConvolutionInt8(const Convolution2DCommon* convOp, Back
mFakeBias->host<float>()[oz] = static_cast<float>(offset) * 1.f; mFakeBias->host<float>()[oz] = static_cast<float>(offset) * 1.f;
} }
#endif #endif
int32_t params[6] = {shape[0], shape[1], shape[2], shape[3], shape[4], ocUp4};
ConvInt8TiledExecutor::packWeightAndQuantInfo(mWeight->host<int8_t>(), (int8_t*)weightReordered.get(), (int8_t*)fakeWeightScaleBias.get(), params, QUANT_INFO_BYTES);
} }
IdstConvolutionInt8::~IdstConvolutionInt8() { IdstConvolutionInt8::~IdstConvolutionInt8() {
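The new reorderWeight takes its geometry through info = {blockNum, oc, ic, kernelCount, UNIT, SRC_UNIT} rather than positional arguments. A hedged scalar sketch under that reading; the kernel-position ordering along the L axis and the [oc, ic, k] source layout are assumptions:

    #include <cstdint>
    #include <cstring>
    // Sketch: [oc, ic, k] -> [UP_DIV(oc,UNIT), UP_DIV(ic,SRC_UNIT)*k, UNIT, SRC_UNIT],
    // padding filled with initval. blockNum handling is omitted for brevity.
    static void reorderWeightSketch(uint8_t* dst, const uint8_t* src,
                                    const int32_t* info, int32_t initval = 0) {
        int oc = info[1], ic = info[2], kernelCount = info[3];
        int UNIT = info[4], SRC_UNIT = info[5];
        int hU = (oc + UNIT - 1) / UNIT, icDiv = (ic + SRC_UNIT - 1) / SRC_UNIT;
        size_t total = (size_t)hU * icDiv * kernelCount * UNIT * SRC_UNIT;
        std::memset(dst, initval, total);
        for (int oz = 0; oz < oc; ++oz) {
            for (int sz = 0; sz < ic; ++sz) {
                for (int k = 0; k < kernelCount; ++k) {
                    size_t l = (size_t)k * icDiv + sz / SRC_UNIT; // assumed L ordering
                    size_t idx = (((size_t)(oz / UNIT) * icDiv * kernelCount + l) * UNIT
                                  + oz % UNIT) * SRC_UNIT + sz % SRC_UNIT;
                    dst[idx] = src[((size_t)oz * ic + sz) * kernelCount + k];
                }
            }
        }
    }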
@ -118,7 +130,7 @@ ErrorCode IdstConvolutionInt8::onResize(const std::vector<Tensor*>& inputs, cons
mTempBuffer.buffer().dimensions = 3; mTempBuffer.buffer().dimensions = 3;
mTempBuffer.buffer().dim[0].extent = number; mTempBuffer.buffer().dim[0].extent = number;
mTempBuffer.buffer().dim[1].extent = DST_XUNIT; mTempBuffer.buffer().dim[1].extent = DST_XUNIT;
mTempBuffer.buffer().dim[2].extent = mWeight->length(1) * SRC_UNIT; mTempBuffer.buffer().dim[2].extent = mIm2ColParamter.kernelCountUnit * SRC_UNIT;
TensorUtils::setLinearLayout(&mTempBuffer); TensorUtils::setLinearLayout(&mTempBuffer);
bool success = backend()->onAcquireBuffer(&mSrcCopyBuffer, Backend::DYNAMIC); bool success = backend()->onAcquireBuffer(&mSrcCopyBuffer, Backend::DYNAMIC);
@ -183,7 +195,6 @@ ErrorCode IdstConvolutionInt8::onExecute(const std::vector<Tensor*>& inputs, con
quanParam.useInt8 = 0; quanParam.useInt8 = 0;
float fp32minmax[2] = {-std::numeric_limits<float>().max(), std::numeric_limits<float>().max()}; float fp32minmax[2] = {-std::numeric_limits<float>().max(), std::numeric_limits<float>().max()};
quanParam.fp32minmax = fp32minmax; quanParam.fp32minmax = fp32minmax;
quanParam.weightQuanBias = mFakeWeightBias->host<float>();
std::vector<float> fakeSrcKernleSum(DST_XUNIT, 0.f); std::vector<float> fakeSrcKernleSum(DST_XUNIT, 0.f);
quanParam.srcKernelSum = fakeSrcKernleSum.data(); quanParam.srcKernelSum = fakeSrcKernleSum.data();
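Why the fake scale/zero buffers work: with the packed tail holding scale = 1 and zero = 0, and srcKernelSum forced to 0, the fused epilogue collapses to the plain accumulate this legacy path expects:

    // value = scale * acc + srcKernelSum * weightZero + bias
    //       = 1     * acc + 0            * 0          + bias = acc + bias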

View File

@ -40,7 +40,6 @@ private:
std::vector<float> mPostParameters; std::vector<float> mPostParameters;
// mFakeBias used by GemmKernel // mFakeBias used by GemmKernel
std::shared_ptr<Tensor> mFakeBias; std::shared_ptr<Tensor> mFakeBias;
std::shared_ptr<Tensor> mFakeWeightBias;
MemChunk mBlitInfo; MemChunk mBlitInfo;
std::pair<size_t, size_t> mBlitInfoStride; std::pair<size_t, size_t> mBlitInfoStride;
}; };

View File

@ -1416,7 +1416,7 @@ static void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, co
size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount) { size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount) {
const int bytes = ((post->useInt8 == 1) ? 1 : 4); const int bytes = ((post->useInt8 == 1) ? 1 : 4);
float fp32min = 0, fp32max = 0; float fp32min = 0, fp32max = 0;
int weight_step_Z = src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT); int weight_step_Z = src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT) + 4 * 2 * GEMM_INT8_UNIT;
int weight_step_Y = (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT); int weight_step_Y = (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT);
const auto srcSumPtr = post->srcKernelSum; const auto srcSumPtr = post->srcKernelSum;
if (0 == post->useInt8 && post->fp32minmax) { if (0 == post->useInt8 && post->fp32minmax) {
@ -1425,17 +1425,19 @@ static void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, co
} }
float* biasPtr = (float*)post->biasFloat; float* biasPtr = (float*)post->biasFloat;
auto accumbuff = post->accumBuffer;
for (int dz = 0; dz < dst_depth_quad; ++dz) { for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + weight_step_Z * dz; const auto weight_dz = weight + weight_step_Z * dz;
const auto bias_dz = biasPtr + dz * GEMM_INT8_UNIT; const auto bias_dz = biasPtr + dz * GEMM_INT8_UNIT;
const auto weight_zero = post->weightQuanBias + (dz * GEMM_INT8_UNIT); const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const float* scale_dz = nullptr; const auto weight_zero = scale_dz + GEMM_INT8_UNIT;
scale_dz = post->scale + (dz * GEMM_INT8_UNIT); auto dst_z = dst + dz * dst_step;
auto dst_z = dst + dz * dst_step; auto accum_z = accumbuff + dz * realCount * GEMM_INT8_UNIT;
for (int w = 0; w < realCount; ++w) { for (int w = 0; w < realCount; ++w) {
const auto src_x = src + w * GEMM_INT8_SRC_UNIT; const auto src_x = src + w * GEMM_INT8_SRC_UNIT;
auto dst_x = dst_z + w * GEMM_INT8_UNIT * bytes; auto dst_x = dst_z + w * GEMM_INT8_UNIT * bytes;
auto accum_x = accum_z + w * GEMM_INT8_UNIT;
int32_t dstTemp[4] = {0, 0, 0, 0}; int32_t dstTemp[4] = {0, 0, 0, 0};
for (int sz = 0; sz < src_depth_quad; ++sz) { for (int sz = 0; sz < src_depth_quad; ++sz) {
@ -1459,13 +1461,18 @@ static void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, co
if (biasPtr) { if (biasPtr) {
value += bias_dz[j]; value += bias_dz[j];
} else { } else {
float dstv = ((float*)dst_x)[j]; float dstv = ((float*)accum_x)[j];
value += dstv; value += dstv;
} }
if (post->fp32minmax) { if (dst) {
value = std::min(std::max(fp32min, value), fp32max); if (post->fp32minmax) {
value = std::min(std::max(fp32min, value), fp32max);
}
((float*)dst_x)[j] = value;
} else {
((float*)accum_x)[j] = value;
} }
((float*)dst_x)[j] = value;
} else { } else {
value += bias_dz[j]; value += bias_dz[j];
value = ALIMAX(value, post->minValue); value = ALIMAX(value, post->minValue);
@ -1481,8 +1488,8 @@ static void MNNGemmInt8AddBiasScale_16x4_w4_Unit(int8_t* dst, const int8_t* src,
uint32_t c = 0xf; uint32_t c = 0xf;
const int bytes = 4; const int bytes = 4;
float fp32min = 0, fp32max = 0; float fp32min = 0, fp32max = 0;
int weight_step_Z = 0.5 * (src_depth_quad) * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT);
int weight_step_Y = 0.5 * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT); int weight_step_Y = 0.5 * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT);
int weight_step_Z = weight_step_Y * src_depth_quad + 4 * 2 * GEMM_INT8_UNIT;
MNN_ASSERT(post->useInt8==0); MNN_ASSERT(post->useInt8==0);
if (post->fp32minmax) { if (post->fp32minmax) {
fp32min = (post->fp32minmax)[0]; fp32min = (post->fp32minmax)[0];
@ -1490,18 +1497,20 @@ static void MNNGemmInt8AddBiasScale_16x4_w4_Unit(int8_t* dst, const int8_t* src,
} }
float* biasPtr = (float*)post->biasFloat; float* biasPtr = (float*)post->biasFloat;
auto accumbuff = post->accumBuffer;
const auto srcSumPtr = post->srcKernelSum; const auto srcSumPtr = post->srcKernelSum;
for (int dz = 0; dz < dst_depth_quad; ++dz) { for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + weight_step_Z * dz; const auto weight_dz = weight + weight_step_Z * dz;
const auto bias_dz = biasPtr + dz * GEMM_INT8_UNIT; const auto bias_dz = biasPtr + dz * GEMM_INT8_UNIT;
const auto weight_zero = post->weightQuanBias + (dz * GEMM_INT8_UNIT); const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const float* scale_dz = nullptr; const auto weight_zero = scale_dz + GEMM_INT8_UNIT;
scale_dz = post->scale + (dz * GEMM_INT8_UNIT);
auto dst_z = dst + dz * dst_step; auto dst_z = dst + dz * dst_step;
auto accum_z = accumbuff + dz * realCount * GEMM_INT8_UNIT;
for (int w = 0; w < realCount; ++w) { for (int w = 0; w < realCount; ++w) {
const auto src_x = src + w * GEMM_INT8_SRC_UNIT; const auto src_x = src + w * GEMM_INT8_SRC_UNIT;
auto dst_x = dst_z + w * GEMM_INT8_UNIT * bytes; auto dst_x = dst_z + w * GEMM_INT8_UNIT * bytes;
auto accum_x = accum_z + w * GEMM_INT8_UNIT;
int32_t dstTemp[4] = {0, 0, 0, 0}; int32_t dstTemp[4] = {0, 0, 0, 0};
for (int sz = 0; sz < src_depth_quad; ++sz) { for (int sz = 0; sz < src_depth_quad; ++sz) {
@ -1531,13 +1540,17 @@ static void MNNGemmInt8AddBiasScale_16x4_w4_Unit(int8_t* dst, const int8_t* src,
if (biasPtr) { if (biasPtr) {
value += bias_dz[j]; value += bias_dz[j];
} else { } else {
float dstv = ((float*)dst_x)[j]; float dstv = ((float*)accum_x)[j];
value += dstv; value += dstv;
} }
if (post->fp32minmax) { if (dst) {
value = std::min(std::max(fp32min, value), fp32max); if (post->fp32minmax) {
value = std::min(std::max(fp32min, value), fp32max);
}
((float*)dst_x)[j] = value;
} else {
((float*)accum_x)[j] = value;
} }
((float*)dst_x)[j] = value;
} }
} }
} }
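The weight_step_Z arithmetic is easiest to sanity-check with numbers. For the reference kernels above, with GEMM_INT8_UNIT = 4, GEMM_INT8_SRC_UNIT = 16 and, say, src_depth_quad = 8 (illustrative values only):

    constexpr int UNIT = 4, SRC_UNIT = 16, srcQuad = 8;
    static_assert(srcQuad * UNIT * SRC_UNIT + 4 * 2 * UNIT == 544,
                  "int8 path: 512B weights + 32B scale/zero tail per dz");
    static_assert((UNIT * SRC_UNIT / 2) * srcQuad + 4 * 2 * UNIT == 288,
                  "w4 path: 256B packed weights + 32B scale/zero tail per dz");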

View File

@ -51,6 +51,7 @@ struct QuanPostTreatParameters {
const int32_t* bias = nullptr; const int32_t* bias = nullptr;
const float* extraScale = nullptr; const float* extraScale = nullptr;
const float* extraBias = nullptr; const float* extraBias = nullptr;
float* accumBuffer = nullptr;
}; };
struct QuanPrePostParameters{ struct QuanPrePostParameters{
float* inputScale; float* inputScale;
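accumBuffer is what lets block-wise (low-memory) quantization split one GEMM along K without losing precision in int8: intermediate blocks park raw float partial sums there, and only the final block clamps and writes dst. A hedged sketch of the calling convention the kernels in this commit implement; the driver and its names are illustrative, not repository code:

    #include <cstddef>
    #include <cstdint>
    #include "backend/cpu/compute/Int8FunctionsOpt.h" // QuanPostTreatParameters, as in this backend
    // 'kernel' stands for any of the MNNGemmInt8AddBiasScale_* variants above.
    typedef void (*GemmInt8Kernel)(int8_t*, const int8_t*, const int8_t*, size_t,
                                   size_t, size_t, const QuanPostTreatParameters*, size_t);
    static void runBlockwiseGemm(GemmInt8Kernel kernel, int8_t* dst,
                                 const int8_t** srcBlocks, const int8_t** weightBlocks,
                                 int blockNum, size_t quadPerBlock, size_t dstStep,
                                 size_t dstDepthQuad, size_t realCount,
                                 float* bias, float* accum,
                                 QuanPostTreatParameters quan) {
        quan.accumBuffer = accum; // dstDepthQuad * realCount * PACK floats (assumed sizing)
        for (int b = 0; b < blockNum; ++b) {
            quan.biasFloat = (b == 0) ? bias : nullptr;        // seed the sum with bias once
            int8_t* out = (b == blockNum - 1) ? dst : nullptr; // clamp + store only at the end
            kernel(out, srcBlocks[b], weightBlocks[b], quadPerBlock,
                   dstStep, dstDepthQuad, &quan, realCount);
        }
    }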

View File

@ -52,6 +52,7 @@ bool AVX2Functions::init(int cpuFlags) {
#ifdef MNN_LOW_MEMORY #ifdef MNN_LOW_MEMORY
coreFunction->MNNAbsMax = _AVX_MNNAbsMaxFP32; coreFunction->MNNAbsMax = _AVX_MNNAbsMaxFP32;
coreFunction->MNNDynamicQuant = _AVX_MNNDynamicQuant;
#endif #endif
coreFunction->MNNPackC4ForMatMul_A = _AVX_MNNPackC4ForMatMul_A; coreFunction->MNNPackC4ForMatMul_A = _AVX_MNNPackC4ForMatMul_A;
coreFunction->MNNPackForMatMul_B = _AVX_MNNPackForMatMul_B; coreFunction->MNNPackForMatMul_B = _AVX_MNNPackForMatMul_B;

View File

@ -57,6 +57,7 @@ void MNNFunctionInit() {
#ifdef MNN_LOW_MEMORY #ifdef MNN_LOW_MEMORY
coreFunction->MNNAbsMax = _SSE_MNNAbsMaxFP32; coreFunction->MNNAbsMax = _SSE_MNNAbsMaxFP32;
coreFunction->MNNDynamicQuant = _SSE_MNNDynamicQuant;
#endif #endif
coreFunction->MNNPackC4ForMatMul_A = _SSE_MNNPackC4ForMatMul_A; coreFunction->MNNPackC4ForMatMul_A = _SSE_MNNPackC4ForMatMul_A;
coreFunction->MNNPackForMatMul_B = _SSE_MNNPackForMatMul_B; coreFunction->MNNPackForMatMul_B = _SSE_MNNPackForMatMul_B;

View File

@ -46,6 +46,7 @@ void _AVX_MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, s
#ifdef MNN_LOW_MEMORY #ifdef MNN_LOW_MEMORY
void _AVX_MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack); void _AVX_MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack);
void _AVX_MNNDynamicQuant(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, int pack);
#endif #endif
void _AVX_MNNPackC4ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el); void _AVX_MNNPackC4ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el);

View File

@ -66,26 +66,189 @@ void _AVX_MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, s
void _AVX_MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack) { void _AVX_MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack) {
// source: (ic/8, N, 8) // source: (ic/8, N, 8)
auto srcStep = pack * realSize; auto srcStep = pack * realSize;
auto constant = _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF)); if (pack == 8) {
float temp[8]; float temp[8];
for (int i = 0; i < realSize; ++i) { auto constant = _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF));
__m256 res = _mm256_setzero_ps(); for (int i = 0; i < realSize; ++i) {
for (int c = 0; c < src_depth_quad; ++c) { __m256 res = _mm256_setzero_ps();
auto src0 = source + c * srcStep + i * pack; for (int c = 0; c < src_depth_quad; ++c) {
__m256 vecA = _mm256_loadu_ps(src0); auto src0 = source + c * srcStep + i * pack;
__m256 absVecA = _mm256_and_ps(vecA, constant); __m256 vecA = _mm256_loadu_ps(src0);
__m256 mask = _mm256_cmp_ps(absVecA, res, 1); __m256 absVecA = _mm256_and_ps(vecA, constant);
res = _mm256_blendv_ps(absVecA, res, mask); __m256 mask = _mm256_cmp_ps(absVecA, res, 1);
res = _mm256_blendv_ps(absVecA, res, mask);
}
_mm256_storeu_ps(temp, res);
float absmaxVal = temp[0];
for (int k = 1; k < pack; ++k) {
if (absmaxVal < temp[k]) {
absmaxVal = temp[k];
}
}
absmax[i] = absmaxVal;
} }
_mm256_storeu_ps(temp, res); return;
float absmaxVal = temp[0]; }
for (int k = 1; k < pack; ++k) { if (pack == 4) {
if (absmaxVal < temp[k]) { float tmp[4];
absmaxVal = temp[k]; __m128 mask = _mm_set1_ps(-0.0f);
for (int i = 0; i < realSize; ++i) {
__m128 absmax_ = _mm_loadu_ps(source + i * pack);
absmax_ = _mm_andnot_ps(mask, absmax_);
auto src0 = source + i * pack;
for (int j = 1; j < src_depth_quad; ++j) {
__m128 vec = _mm_loadu_ps(src0 + j * srcStep);
vec = _mm_andnot_ps(mask, vec);
absmax_ = _mm_max_ps(absmax_, vec);
}
_mm_storeu_ps(tmp, absmax_);
float res = tmp[0];
for (int j = 1; j < pack; ++j) {
res = ALIMAX(res, tmp[j]);
}
absmax[i] = res;
}
return;
}
MNN_ERROR("absmax error: x86_x64 avx2 don't suppport pack=%d yet\n", pack);
return;
}
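A scalar reference for the routine above, matching the semantics the old generic loop had (per-column max of |x| across all depth quads); useful as a test oracle when adding new pack widths:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    static void absMaxRef(const float* src, float* absmax, size_t srcDepthQuad,
                          size_t realSize, int pack) {
        for (size_t i = 0; i < realSize; ++i) {
            float m = 0.f;
            for (size_t c = 0; c < srcDepthQuad; ++c) {
                for (int k = 0; k < pack; ++k) {
                    m = std::max(m, std::fabs(src[(c * realSize + i) * pack + k]));
                }
            }
            absmax[i] = m;
        }
    }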
void _AVX_MNNDynamicQuant(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, int pack) {
auto srcStep = realSize * pack;
if (pack == 8) { // core->pack
auto offset = _mm256_set1_epi32(128);
int32_t* dstPtr = reinterpret_cast<int32_t*>(dst);
int32_t tmp[8];
for (int i = 0; i < src_depth_quad; ++i) {
int xcount = realSize;
auto srcPtr = src + i * srcStep;
auto scalePtr = scale;
while (xcount > 3) {
auto scale0 = _mm256_set1_ps(scalePtr[0]);
auto scale1 = _mm256_set1_ps(scalePtr[1]);
auto scale2 = _mm256_set1_ps(scalePtr[2]);
auto scale3 = _mm256_set1_ps(scalePtr[3]);
auto data0 = _mm256_loadu_ps(srcPtr);
auto data1 = _mm256_loadu_ps(srcPtr + pack);
auto data2 = _mm256_loadu_ps(srcPtr + 2 * pack);
auto data3 = _mm256_loadu_ps(srcPtr + 3 * pack);
data0 = _mm256_mul_ps(data0, scale0);
data1 = _mm256_mul_ps(data1, scale1);
data2 = _mm256_mul_ps(data2, scale2);
data3 = _mm256_mul_ps(data3, scale3);
data0 = _mm256_round_ps(data0, 0);
data1 = _mm256_round_ps(data1, 0);
data2 = _mm256_round_ps(data2, 0);
data3 = _mm256_round_ps(data3, 0);
auto r0 = _mm256_cvtps_epi32(data0);
auto r1 = _mm256_cvtps_epi32(data1);
auto r2 = _mm256_cvtps_epi32(data2);
auto r3 = _mm256_cvtps_epi32(data3);
r0 = _mm256_add_epi32(r0, offset);
r1 = _mm256_add_epi32(r1, offset);
r2 = _mm256_add_epi32(r2, offset);
r3 = _mm256_add_epi32(r3, offset);
auto r0_16 = _mm256_packs_epi32(r0, r1); // 0000111100001111
auto r1_16 = _mm256_packs_epi32(r2, r3); // 2222333322223333
auto r0_8 = _mm256_packus_epi16(r0_16, r1_16); // 0000111122223333 0000111122223333
_mm256_storeu_si256((__m256i *)tmp, r0_8);
for (int k = 0; k < 4; ++k) {
dstPtr[2 * k] = tmp[k];
dstPtr[2 * k + 1] = tmp[k + 4];
}
// next round
xcount -= 4;
scalePtr += 4;
srcPtr += (4 * pack);
dstPtr += 8;
}
while (xcount) {
auto scale0 = _mm256_set1_ps(scalePtr[0]);
auto data0 = _mm256_loadu_ps(srcPtr);
data0 = _mm256_mul_ps(data0, scale0);
data0 = _mm256_round_ps(data0, 0);
auto r0 = _mm256_cvtps_epi32(data0);
r0 = _mm256_add_epi32(r0, offset);
auto r0_16 = _mm256_packs_epi32(r0, r0); // column 0 duplicated within each 128-bit lane
auto r0_8 = _mm256_packus_epi16(r0_16, r0_16); // 8 bytes of column 0 per lane; only two dwords are kept below
_mm256_storeu_si256((__m256i *)tmp, r0_8);
dstPtr[0] = tmp[0];
dstPtr[1] = tmp[4];
// next round
xcount--;
scalePtr += 1;
srcPtr += pack;
dstPtr += 2;
} }
} }
absmax[i] = absmaxVal; return;
} }
if (pack == 4) { // LP = 4
auto offset = _mm_set1_epi32(128);
int32_t tmp[4];
int32_t* dstPtr = reinterpret_cast<int32_t*>(dst);
for (int i = 0; i < src_depth_quad; ++i) {
int xcount = realSize;
auto srcPtr = src + i * srcStep;
auto scalePtr = scale;
while (xcount > 3) {
auto scale0 = _mm_set1_ps(scalePtr[0]);
auto scale1 = _mm_set1_ps(scalePtr[1]);
auto scale2 = _mm_set1_ps(scalePtr[2]);
auto scale3 = _mm_set1_ps(scalePtr[3]);
auto data0 = _mm_loadu_ps(srcPtr);
auto data1 = _mm_loadu_ps(srcPtr + pack);
auto data2 = _mm_loadu_ps(srcPtr + 2 * pack);
auto data3 = _mm_loadu_ps(srcPtr + 3 * pack);
data0 = _mm_mul_ps(data0, scale0);
data1 = _mm_mul_ps(data1, scale1);
data2 = _mm_mul_ps(data2, scale2);
data3 = _mm_mul_ps(data3, scale3);
data0 = _mm_round_ps(data0, 0);
data1 = _mm_round_ps(data1, 0);
data2 = _mm_round_ps(data2, 0);
data3 = _mm_round_ps(data3, 0);
auto r0 = _mm_cvtps_epi32(data0);
auto r1 = _mm_cvtps_epi32(data1);
auto r2 = _mm_cvtps_epi32(data2);
auto r3 = _mm_cvtps_epi32(data3);
r0 = _mm_add_epi32(r0, offset);
r1 = _mm_add_epi32(r1, offset);
r2 = _mm_add_epi32(r2, offset);
r3 = _mm_add_epi32(r3, offset);
auto r0_16 = _mm_packs_epi32(r0, r1); // 00001111
auto r1_16 = _mm_packs_epi32(r2, r3); // 22223333
auto r0_8 = _mm_packus_epi16(r0_16, r1_16); // 0000111122223333
_mm_storeu_si128((__m128i *)dstPtr, r0_8);
// next round
xcount -= 4;
scalePtr += 4;
srcPtr += (4 * pack);
dstPtr += 4;
}
while (xcount) {
auto scale0 = _mm_set1_ps(scalePtr[0]);
auto data0 = _mm_loadu_ps(srcPtr);
data0 = _mm_mul_ps(data0, scale0);
auto r0 = _mm_cvtps_epi32(_mm_round_ps(data0, 0));
r0 = _mm_add_epi32(r0, offset);
auto r0_16 = _mm_packs_epi32(r0, r0); // column 0 duplicated
auto r0_8 = _mm_packus_epi16(r0_16, r0_16); // 4 bytes of column 0, repeated; one dword kept below
_mm_storeu_si128((__m128i *)tmp, r0_8);
dstPtr[0] = tmp[0];
// next round
xcount--;
scalePtr += 1;
srcPtr += pack;
dstPtr += 1;
}
}
return;
}
MNN_ERROR("dynamic quant error: x86_x64 avx2 don't suppport pack=%d yet\n", pack);
return;
} }
#endif #endif
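And the matching scalar oracle for _AVX_MNNDynamicQuant. Note the +128 offset and unsigned saturation: the x86 kernels keep activations in a uint8 domain stored through an int8_t*, and _mm256_round_ps(x, 0) rounds to nearest-even, which std::nearbyint reproduces under the default rounding mode:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <cstdint>
    static void dynamicQuantRef(const float* src, int8_t* dst, const float* scale,
                                size_t srcDepthQuad, size_t realSize, int pack) {
        for (size_t c = 0; c < srcDepthQuad; ++c) {
            for (size_t i = 0; i < realSize; ++i) {
                for (int k = 0; k < pack; ++k) {
                    size_t idx = (c * realSize + i) * pack + k;
                    int q = (int)std::nearbyint(src[idx] * scale[i]) + 128;
                    q = std::min(255, std::max(0, q)); // packus-style saturation
                    dst[idx] = (int8_t)q;              // uint8 bit pattern in int8 storage
                }
            }
        }
    }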

View File

@ -72,8 +72,9 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const
if (post->biasFloat) { if (post->biasFloat) {
biasPtr = post->biasFloat; biasPtr = post->biasFloat;
} }
auto accumbuff = post->accumBuffer;
int weight_step_Z = 0.5 * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); int weight_step_Z = 0.5 * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H) + 4 * 2 * GEMMINT8_AVX2_H;
int weight_step_Y = 0.5 * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); int weight_step_Y = 0.5 * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
const __m128i mask = _mm_set1_epi8(0xf); const __m128i mask = _mm_set1_epi8(0xf);
@ -122,11 +123,13 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const
if (GEMMINT8_AVX2_E == realDst) { if (GEMMINT8_AVX2_E == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) { for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * weight_step_Z; const auto weight_dz = weight + dz * weight_step_Z;
const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const float* scale_dz = post->scale + dz * AVX2_PACKINT8; const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
auto dst_z = dst + dz * dst_step_tmp; auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src; const auto src_x = src;
auto dst_x = dst_z; auto dst_x = dst_z;
auto accum_z = accumbuff + dz * realDst * AVX2_PACKINT8;
auto accum_x = accum_z;
__m256i D00 = _mm256_set1_epi32(0); __m256i D00 = _mm256_set1_epi32(0);
__m256i D01 = _mm256_set1_epi32(0); __m256i D01 = _mm256_set1_epi32(0);
__m256i D02 = _mm256_set1_epi32(0); __m256i D02 = _mm256_set1_epi32(0);
@ -199,41 +202,49 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const
f2 = _mm256_add_ps(f2, biasValue); f2 = _mm256_add_ps(f2, biasValue);
f3 = _mm256_add_ps(f3, biasValue); f3 = _mm256_add_ps(f3, biasValue);
} else { } else {
auto dstv0 = _mm256_loadu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8); auto dstv0 = _mm256_loadu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8);
auto dstv1 = _mm256_loadu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8); auto dstv1 = _mm256_loadu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8);
auto dstv2 = _mm256_loadu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8); auto dstv2 = _mm256_loadu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8);
auto dstv3 = _mm256_loadu_ps(((float*)dst_x) + 3 * AVX2_PACKINT8); auto dstv3 = _mm256_loadu_ps(((float*)accum_x) + 3 * AVX2_PACKINT8);
f0 = _mm256_add_ps(f0, dstv0); f0 = _mm256_add_ps(f0, dstv0);
f1 = _mm256_add_ps(f1, dstv1); f1 = _mm256_add_ps(f1, dstv1);
f2 = _mm256_add_ps(f2, dstv2); f2 = _mm256_add_ps(f2, dstv2);
f3 = _mm256_add_ps(f3, dstv3); f3 = _mm256_add_ps(f3, dstv3);
} }
if (post->fp32minmax) { if (dst) {
f0 = _mm256_min_ps(f0, fp32max); if (post->fp32minmax) {
f1 = _mm256_min_ps(f1, fp32max); f0 = _mm256_min_ps(f0, fp32max);
f2 = _mm256_min_ps(f2, fp32max); f1 = _mm256_min_ps(f1, fp32max);
f3 = _mm256_min_ps(f3, fp32max); f2 = _mm256_min_ps(f2, fp32max);
f0 = _mm256_max_ps(f0, fp32min); f3 = _mm256_min_ps(f3, fp32max);
f1 = _mm256_max_ps(f1, fp32min); f0 = _mm256_max_ps(f0, fp32min);
f2 = _mm256_max_ps(f2, fp32min); f1 = _mm256_max_ps(f1, fp32min);
f3 = _mm256_max_ps(f3, fp32min); f2 = _mm256_max_ps(f2, fp32min);
f3 = _mm256_max_ps(f3, fp32min);
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
_mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
_mm256_storeu_ps(((float*)dst_x) + 3 * AVX2_PACKINT8, f3);
}
} else {
_mm256_storeu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8, f1);
_mm256_storeu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8, f2);
_mm256_storeu_ps(((float*)accum_x) + 3 * AVX2_PACKINT8, f3);
} }
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
_mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
_mm256_storeu_ps(((float*)dst_x) + 3 * AVX2_PACKINT8, f3);
} }
return; return;
} }
if (3 == realDst) { if (3 == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) { for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * weight_step_Z; const auto weight_dz = weight + dz * weight_step_Z;
const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const float* scale_dz = post->scale + dz * AVX2_PACKINT8; const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
auto dst_z = dst + dz * dst_step_tmp; auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src; const auto src_x = src;
auto dst_x = dst_z; auto dst_x = dst_z;
auto accum_z = accumbuff + dz * realDst * AVX2_PACKINT8;
auto accum_x = accum_z;
__m256i D00 = _mm256_set1_epi32(0); __m256i D00 = _mm256_set1_epi32(0);
__m256i D01 = _mm256_set1_epi32(0); __m256i D01 = _mm256_set1_epi32(0);
__m256i D02 = _mm256_set1_epi32(0); __m256i D02 = _mm256_set1_epi32(0);
@ -295,36 +306,43 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const
f1 = _mm256_add_ps(f1, biasValue); f1 = _mm256_add_ps(f1, biasValue);
f2 = _mm256_add_ps(f2, biasValue); f2 = _mm256_add_ps(f2, biasValue);
} else { } else {
auto dstv0 = _mm256_loadu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8); auto dstv0 = _mm256_loadu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8);
auto dstv1 = _mm256_loadu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8); auto dstv1 = _mm256_loadu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8);
auto dstv2 = _mm256_loadu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8); auto dstv2 = _mm256_loadu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8);
f0 = _mm256_add_ps(f0, dstv0); f0 = _mm256_add_ps(f0, dstv0);
f1 = _mm256_add_ps(f1, dstv1); f1 = _mm256_add_ps(f1, dstv1);
f2 = _mm256_add_ps(f2, dstv2); f2 = _mm256_add_ps(f2, dstv2);
} }
if (post->fp32minmax) { if (dst) {
f0 = _mm256_min_ps(f0, fp32max); if (post->fp32minmax) {
f1 = _mm256_min_ps(f1, fp32max); f0 = _mm256_min_ps(f0, fp32max);
f2 = _mm256_min_ps(f2, fp32max); f1 = _mm256_min_ps(f1, fp32max);
f0 = _mm256_max_ps(f0, fp32min); f2 = _mm256_min_ps(f2, fp32max);
f1 = _mm256_max_ps(f1, fp32min); f0 = _mm256_max_ps(f0, fp32min);
f2 = _mm256_max_ps(f2, fp32min); f1 = _mm256_max_ps(f1, fp32min);
f2 = _mm256_max_ps(f2, fp32min);
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
_mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
}
} else {
_mm256_storeu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8, f1);
_mm256_storeu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8, f2);
} }
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
_mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
} }
return; return;
} }
if (2 == realDst) { if (2 == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) { for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * weight_step_Z; const auto weight_dz = weight + dz * weight_step_Z;
const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const float* scale_dz = post->scale + dz * AVX2_PACKINT8; const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
auto dst_z = dst + dz * dst_step_tmp; auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src; const auto src_x = src;
auto dst_x = dst_z; auto dst_x = dst_z;
auto accum_z = accumbuff + dz * realDst * AVX2_PACKINT8;
auto accum_x = accum_z;
__m256i D00 = _mm256_set1_epi32(0); __m256i D00 = _mm256_set1_epi32(0);
__m256i D01 = _mm256_set1_epi32(0); __m256i D01 = _mm256_set1_epi32(0);
@ -377,31 +395,37 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const
f0 = _mm256_add_ps(f0, biasValue); f0 = _mm256_add_ps(f0, biasValue);
f1 = _mm256_add_ps(f1, biasValue); f1 = _mm256_add_ps(f1, biasValue);
} else { } else {
auto dstv0 = _mm256_loadu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8); auto dstv0 = _mm256_loadu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8);
auto dstv1 = _mm256_loadu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8); auto dstv1 = _mm256_loadu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8);
f0 = _mm256_add_ps(f0, dstv0); f0 = _mm256_add_ps(f0, dstv0);
f1 = _mm256_add_ps(f1, dstv1); f1 = _mm256_add_ps(f1, dstv1);
} }
if (post->fp32minmax) { if (dst) {
f0 = _mm256_min_ps(f0, fp32max); if (post->fp32minmax) {
f1 = _mm256_min_ps(f1, fp32max); f0 = _mm256_min_ps(f0, fp32max);
f0 = _mm256_max_ps(f0, fp32min); f1 = _mm256_min_ps(f1, fp32max);
f1 = _mm256_max_ps(f1, fp32min); f0 = _mm256_max_ps(f0, fp32min);
f1 = _mm256_max_ps(f1, fp32min);
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
}
} else {
_mm256_storeu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8, f1);
} }
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
} }
return; return;
} }
if (1 == realDst) { if (1 == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) { for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * weight_step_Z; const auto weight_dz = weight + dz * weight_step_Z;
const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const float* scale_dz = post->scale + dz * AVX2_PACKINT8; const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
auto dst_z = dst + dz * dst_step_tmp; auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src; const auto src_x = src;
auto dst_x = dst_z; auto dst_x = dst_z;
auto accum_z = accumbuff + dz * realDst * AVX2_PACKINT8;
auto accum_x = accum_z;
__m256i D00 = _mm256_set1_epi32(0); __m256i D00 = _mm256_set1_epi32(0);
__m256i D10 = _mm256_set1_epi32(0); __m256i D10 = _mm256_set1_epi32(0);
@ -441,16 +465,19 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const
auto biasValue = _mm256_loadu_ps(bias_dz); auto biasValue = _mm256_loadu_ps(bias_dz);
f0 = _mm256_add_ps(f0, biasValue); f0 = _mm256_add_ps(f0, biasValue);
} else { } else {
auto dstv = _mm256_loadu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8); auto dstv = _mm256_loadu_ps(((float*)accum_x));
f0 = _mm256_add_ps(f0, dstv); f0 = _mm256_add_ps(f0, dstv);
} }
if (post->fp32minmax) { if (dst) {
f0 = _mm256_min_ps(f0, fp32max); if (post->fp32minmax) {
f0 = _mm256_max_ps(f0, fp32min); f0 = _mm256_min_ps(f0, fp32max);
f0 = _mm256_max_ps(f0, fp32min);
_mm256_storeu_ps(((float*)dst_x), f0);
}
} else {
_mm256_storeu_ps(((float*)accum_x) , f0);
} }
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
} }
return; return;
} }
@ -474,6 +501,7 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, cons
if (post->biasFloat) { if (post->biasFloat) {
biasPtr = post->biasFloat; biasPtr = post->biasFloat;
} }
auto accumbuff = post->accumBuffer;
auto srcKernelSumPtr = post->srcKernelSum; auto srcKernelSumPtr = post->srcKernelSum;
__m256 kernelSum0 = _mm256_setzero_ps(); __m256 kernelSum0 = _mm256_setzero_ps();
__m256 kernelSum1 = _mm256_setzero_ps(); __m256 kernelSum1 = _mm256_setzero_ps();
@ -514,15 +542,18 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, cons
} }
} }
} }
//printf("e=%d, sz=%d, dz=%d\n", realDst, src_depth_quad, dst_depth_quad); int weight_step_Z = src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H) + 4 * 2 * GEMMINT8_AVX2_H;
int weight_step_Y = (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
if (GEMMINT8_AVX2_E == realDst) { if (GEMMINT8_AVX2_E == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) { for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); const auto weight_dz = weight + dz * weight_step_Z;
const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const float* scale_dz = post->scale + dz * AVX2_PACKINT8; const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
auto dst_z = dst + dz * dst_step_tmp; auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src; const auto src_x = src;
auto dst_x = dst_z; auto dst_x = dst_z;
auto accum_z = accumbuff + dz * realDst * AVX2_PACKINT8;
auto accum_x = accum_z;
__m256i D00 = _mm256_set1_epi32(0); __m256i D00 = _mm256_set1_epi32(0);
__m256i D01 = _mm256_set1_epi32(0); __m256i D01 = _mm256_set1_epi32(0);
__m256i D02 = _mm256_set1_epi32(0); __m256i D02 = _mm256_set1_epi32(0);
@ -616,42 +647,51 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, cons
POSTTREAT(3); POSTTREAT(3);
} else { } else {
if (nullptr == biasPtr) { if (nullptr == biasPtr) {
auto dstv0 = _mm256_loadu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8); auto dstv0 = _mm256_loadu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8);
auto dstv1 = _mm256_loadu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8); auto dstv1 = _mm256_loadu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8);
auto dstv2 = _mm256_loadu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8); auto dstv2 = _mm256_loadu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8);
auto dstv3 = _mm256_loadu_ps(((float*)dst_x) + 3 * AVX2_PACKINT8); auto dstv3 = _mm256_loadu_ps(((float*)accum_x) + 3 * AVX2_PACKINT8);
f0 = _mm256_add_ps(f0, dstv0); f0 = _mm256_add_ps(f0, dstv0);
f1 = _mm256_add_ps(f1, dstv1); f1 = _mm256_add_ps(f1, dstv1);
f2 = _mm256_add_ps(f2, dstv2); f2 = _mm256_add_ps(f2, dstv2);
f3 = _mm256_add_ps(f3, dstv3); f3 = _mm256_add_ps(f3, dstv3);
} }
if (post->fp32minmax) { if (dst) {
f0 = _mm256_min_ps(f0, fp32max); if (post->fp32minmax) {
f1 = _mm256_min_ps(f1, fp32max); f0 = _mm256_min_ps(f0, fp32max);
f2 = _mm256_min_ps(f2, fp32max); f1 = _mm256_min_ps(f1, fp32max);
f3 = _mm256_min_ps(f3, fp32max); f2 = _mm256_min_ps(f2, fp32max);
f0 = _mm256_max_ps(f0, fp32min); f3 = _mm256_min_ps(f3, fp32max);
f1 = _mm256_max_ps(f1, fp32min); f0 = _mm256_max_ps(f0, fp32min);
f2 = _mm256_max_ps(f2, fp32min); f1 = _mm256_max_ps(f1, fp32min);
f3 = _mm256_max_ps(f3, fp32min); f2 = _mm256_max_ps(f2, fp32min);
f3 = _mm256_max_ps(f3, fp32min);
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
_mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
_mm256_storeu_ps(((float*)dst_x) + 3 * AVX2_PACKINT8, f3);
}
} else {
_mm256_storeu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8, f1);
_mm256_storeu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8, f2);
_mm256_storeu_ps(((float*)accum_x) + 3 * AVX2_PACKINT8, f3);
} }
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
_mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
_mm256_storeu_ps(((float*)dst_x) + 3 * AVX2_PACKINT8, f3);
} }
} }
return; return;
} }
if (3 == realDst) { if (3 == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) { for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); const auto weight_dz = weight + dz * weight_step_Z;
const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const float* scale_dz = post->scale + dz * AVX2_PACKINT8; const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
auto dst_z = dst + dz * dst_step_tmp; auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src; const auto src_x = src;
auto dst_x = dst_z; auto dst_x = dst_z;
auto accum_z = accumbuff + dz * realDst * AVX2_PACKINT8;
auto accum_x = accum_z;
__m256i D00 = _mm256_set1_epi32(0); __m256i D00 = _mm256_set1_epi32(0);
__m256i D01 = _mm256_set1_epi32(0); __m256i D01 = _mm256_set1_epi32(0);
__m256i D02 = _mm256_set1_epi32(0); __m256i D02 = _mm256_set1_epi32(0);
@ -730,36 +770,45 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, cons
POSTTREAT(2); POSTTREAT(2);
} else { } else {
if (nullptr == biasPtr) { if (nullptr == biasPtr) {
auto dstv0 = _mm256_loadu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8); auto dstv0 = _mm256_loadu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8);
auto dstv1 = _mm256_loadu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8); auto dstv1 = _mm256_loadu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8);
auto dstv2 = _mm256_loadu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8); auto dstv2 = _mm256_loadu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8);
f0 = _mm256_add_ps(f0, dstv0); f0 = _mm256_add_ps(f0, dstv0);
f1 = _mm256_add_ps(f1, dstv1); f1 = _mm256_add_ps(f1, dstv1);
f2 = _mm256_add_ps(f2, dstv2); f2 = _mm256_add_ps(f2, dstv2);
} }
if (post->fp32minmax) { if (dst){
f0 = _mm256_min_ps(f0, fp32max); if (post->fp32minmax) {
f1 = _mm256_min_ps(f1, fp32max); f0 = _mm256_min_ps(f0, fp32max);
f2 = _mm256_min_ps(f2, fp32max); f1 = _mm256_min_ps(f1, fp32max);
f0 = _mm256_max_ps(f0, fp32min); f2 = _mm256_min_ps(f2, fp32max);
f1 = _mm256_max_ps(f1, fp32min); f0 = _mm256_max_ps(f0, fp32min);
f2 = _mm256_max_ps(f2, fp32min); f1 = _mm256_max_ps(f1, fp32min);
f2 = _mm256_max_ps(f2, fp32min);
}
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
_mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
} else {
_mm256_storeu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8, f1);
_mm256_storeu_ps(((float*)accum_x) + 2 * AVX2_PACKINT8, f2);
} }
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
_mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2);
} }
} }
return; return;
} }
if (2 == realDst) { if (2 == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) { for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); const auto weight_dz = weight + dz * weight_step_Z;
const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const float* scale_dz = post->scale + dz * AVX2_PACKINT8; const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
auto dst_z = dst + dz * dst_step_tmp; auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src; const auto src_x = src;
auto dst_x = dst_z; auto dst_x = dst_z;
auto accum_z = accumbuff + dz * realDst * AVX2_PACKINT8;
auto accum_x = accum_z;
__m256i D00 = _mm256_set1_epi32(0); __m256i D00 = _mm256_set1_epi32(0);
__m256i D01 = _mm256_set1_epi32(0); __m256i D01 = _mm256_set1_epi32(0);
@ -823,31 +872,38 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, cons
POSTTREAT(1); POSTTREAT(1);
} else { } else {
if (nullptr == biasPtr) { if (nullptr == biasPtr) {
auto dstv0 = _mm256_loadu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8); auto dstv0 = _mm256_loadu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8);
auto dstv1 = _mm256_loadu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8); auto dstv1 = _mm256_loadu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8);
f0 = _mm256_add_ps(f0, dstv0); f0 = _mm256_add_ps(f0, dstv0);
f1 = _mm256_add_ps(f1, dstv1); f1 = _mm256_add_ps(f1, dstv1);
} }
if (post->fp32minmax) { if (dst) {
f0 = _mm256_min_ps(f0, fp32max); if (post->fp32minmax) {
f1 = _mm256_min_ps(f1, fp32max); f0 = _mm256_min_ps(f0, fp32max);
f0 = _mm256_max_ps(f0, fp32min); f1 = _mm256_min_ps(f1, fp32max);
f1 = _mm256_max_ps(f1, fp32min); f0 = _mm256_max_ps(f0, fp32min);
f1 = _mm256_max_ps(f1, fp32min);
}
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
} else {
_mm256_storeu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)accum_x) + 1 * AVX2_PACKINT8, f1);
} }
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
_mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1);
} }
} }
return; return;
} }
if (1 == realDst) { if (1 == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) { for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); const auto weight_dz = weight + dz * weight_step_Z;
const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const float* scale_dz = post->scale + dz * AVX2_PACKINT8; const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
auto dst_z = dst + dz * dst_step_tmp; auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src; const auto src_x = src;
auto dst_x = dst_z; auto dst_x = dst_z;
auto accum_z = accumbuff + dz * realDst * AVX2_PACKINT8;
auto accum_x = accum_z;
__m256i D00 = _mm256_set1_epi32(0); __m256i D00 = _mm256_set1_epi32(0);
__m256i D10 = _mm256_set1_epi32(0); __m256i D10 = _mm256_set1_epi32(0);
@ -895,14 +951,18 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, cons
POSTTREAT(0); POSTTREAT(0);
} else { } else {
if (nullptr == biasPtr) { if (nullptr == biasPtr) {
auto dstv0 = _mm256_loadu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8); auto dstv0 = _mm256_loadu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8);
f0 = _mm256_add_ps(f0, dstv0); f0 = _mm256_add_ps(f0, dstv0);
} }
if (post->fp32minmax) { if (dst) {
f0 = _mm256_min_ps(f0, fp32max); if (post->fp32minmax) {
f0 = _mm256_max_ps(f0, fp32min); f0 = _mm256_min_ps(f0, fp32max);
f0 = _mm256_max_ps(f0, fp32min);
}
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
} else {
_mm256_storeu_ps(((float*)accum_x) + 0 * AVX2_PACKINT8, f0);
} }
_mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0);
} }
} }
return; return;
@ -942,13 +1002,13 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src,
kernelSum2 = _mm256_set1_ps(post->srcKernelSum[2]); kernelSum2 = _mm256_set1_ps(post->srcKernelSum[2]);
} }
} }
//printf("e=%d, sz=%d, dz=%d\n", realDst, src_depth_quad, dst_depth_quad); int weight_step_Z = src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H) + 4 * 2 * GEMMINT8_AVX2_H;
int weight_step_Y = (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H);
if (GEMMINT8_AVX2_E == realDst) { if (GEMMINT8_AVX2_E == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) { for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); const auto weight_dz = weight + dz * weight_step_Z;
const auto bias_dz = post->biasFloat + dz * AVX2_PACKINT8; const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
const float* scale_dz = post->scale + dz * AVX2_PACKINT8;
auto dst_z = dst + dz * dst_step_tmp; auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src; const auto src_x = src;
auto dst_x = dst_z; auto dst_x = dst_z;
@ -978,12 +1038,7 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src,
auto D2 = D02; auto D2 = D02;
auto D3 = D03; auto D3 = D03;
// auto biasValue0 = _mm256_loadu_si256((__m256i*)(bias_dz));
auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz); auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz);
// D0 = _mm256_add_epi32(D0, biasValue0);
// D1 = _mm256_add_epi32(D1, biasValue0);
// D2 = _mm256_add_epi32(D2, biasValue0);
// D3 = _mm256_add_epi32(D3, biasValue0);
auto scaleValue = _mm256_loadu_ps(scale_dz); auto scaleValue = _mm256_loadu_ps(scale_dz);
auto f0 = _mm256_cvtepi32_ps(D0); auto f0 = _mm256_cvtepi32_ps(D0);
@ -1003,7 +1058,7 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src,
f1 = _mm256_add_ps(f1, xy0_1); f1 = _mm256_add_ps(f1, xy0_1);
f2 = _mm256_add_ps(f2, xy0_2); f2 = _mm256_add_ps(f2, xy0_2);
f3 = _mm256_add_ps(f3, xy0_3); f3 = _mm256_add_ps(f3, xy0_3);
auto biasValue = _mm256_loadu_ps(bias_dz); auto biasValue = _mm256_loadu_ps(weightBias_dz);
f0 = _mm256_add_ps(f0, biasValue); f0 = _mm256_add_ps(f0, biasValue);
f1 = _mm256_add_ps(f1, biasValue); f1 = _mm256_add_ps(f1, biasValue);
f2 = _mm256_add_ps(f2, biasValue); f2 = _mm256_add_ps(f2, biasValue);
@ -1032,10 +1087,9 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src,
} }
if (3 == realDst) { if (3 == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) { for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); const auto weight_dz = weight + dz * weight_step_Z;
const auto bias_dz = post->biasFloat + dz * AVX2_PACKINT8; const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
const float* scale_dz = post->scale + dz * AVX2_PACKINT8;
auto dst_z = dst + dz * dst_step_tmp; auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src; const auto src_x = src;
auto dst_x = dst_z; auto dst_x = dst_z;
@ -1081,7 +1135,7 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src,
f0 = _mm256_add_ps(f0, xy0_0); f0 = _mm256_add_ps(f0, xy0_0);
f1 = _mm256_add_ps(f1, xy0_1); f1 = _mm256_add_ps(f1, xy0_1);
f2 = _mm256_add_ps(f2, xy0_2); f2 = _mm256_add_ps(f2, xy0_2);
auto biasValue = _mm256_loadu_ps(bias_dz); auto biasValue = _mm256_loadu_ps(weightBias_dz);
f0 = _mm256_add_ps(f0, biasValue); f0 = _mm256_add_ps(f0, biasValue);
f1 = _mm256_add_ps(f1, biasValue); f1 = _mm256_add_ps(f1, biasValue);
f2 = _mm256_add_ps(f2, biasValue); f2 = _mm256_add_ps(f2, biasValue);
@ -1105,10 +1159,9 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src,
} }
if (2 == realDst) { if (2 == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) { for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); const auto weight_dz = weight + dz * weight_step_Z;
const auto bias_dz = post->biasFloat + dz * AVX2_PACKINT8; const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
const float* scale_dz = post->scale + dz * AVX2_PACKINT8;
auto dst_z = dst + dz * dst_step_tmp; auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src; const auto src_x = src;
auto dst_x = dst_z; auto dst_x = dst_z;
@ -1141,7 +1194,7 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src,
f1 = _mm256_mul_ps(f1, scaleValue); f1 = _mm256_mul_ps(f1, scaleValue);
f0 = _mm256_add_ps(f0, xy0_0); f0 = _mm256_add_ps(f0, xy0_0);
f1 = _mm256_add_ps(f1, xy0_1); f1 = _mm256_add_ps(f1, xy0_1);
auto biasValue = _mm256_loadu_ps(bias_dz); auto biasValue = _mm256_loadu_ps(weightBias_dz);
f0 = _mm256_add_ps(f0, biasValue); f0 = _mm256_add_ps(f0, biasValue);
f1 = _mm256_add_ps(f1, biasValue); f1 = _mm256_add_ps(f1, biasValue);
if (post->useInt8 == 0) { if (post->useInt8 == 0) {
@ -1160,10 +1213,9 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src,
} }
if (1 == realDst) { if (1 == realDst) {
for (int dz = 0; dz < dst_depth_quad; ++dz) { for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); const auto weight_dz = weight + dz * weight_step_Z;
const auto bias_dz = post->biasFloat + dz * AVX2_PACKINT8; const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;
const float* scale_dz = post->scale + dz * AVX2_PACKINT8;
auto dst_z = dst + dz * dst_step_tmp; auto dst_z = dst + dz * dst_step_tmp;
const auto src_x = src; const auto src_x = src;
auto dst_x = dst_z; auto dst_x = dst_z;
@ -1187,7 +1239,7 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src,
            auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension first auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimension first
f0 = _mm256_mul_ps(f0, scaleValue); f0 = _mm256_mul_ps(f0, scaleValue);
f0 = _mm256_add_ps(f0, xy0_0); f0 = _mm256_add_ps(f0, xy0_0);
auto biasValue = _mm256_loadu_ps(bias_dz); auto biasValue = _mm256_loadu_ps(weightBias_dz);
f0 = _mm256_add_ps(f0, biasValue); f0 = _mm256_add_ps(f0, biasValue);
if (post->useInt8 == 0) { if (post->useInt8 == 0) {
f0 = _mm256_min_ps(f0, fp32max); f0 = _mm256_min_ps(f0, fp32max);
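
The recurring change in this file packs the per-output-channel dequantization parameters into the weight blob itself: each dz block now ends with two GEMMINT8_AVX2_H-wide float arrays (the scale, then the values loaded as biasValue), so the kernel derives every pointer from weight alone instead of reading post->scale and post->weightQuanBias. A sketch of the assumed per-block layout, using the names from the diff:

// Per dz block (sketch inferred from the pointer math above, not authoritative):
//   [ int8 weights : src_depth_quad * GEMMINT8_AVX2_L * GEMMINT8_AVX2_H bytes ]
//   [ float array  : GEMMINT8_AVX2_H ]   <- scale_dz
//   [ float array  : GEMMINT8_AVX2_H ]   <- weightBias_dz
const auto weight_dz = weight + dz * weight_step_Z;
const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const auto weightBias_dz = scale_dz + GEMMINT8_AVX2_H;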

File diff suppressed because it is too large

File diff suppressed because it is too large


@ -70,19 +70,182 @@ void _AVX512_MNNComputeScaleZeroScalar(float* source, float* min, float* max, si
max[0] = max_; max[0] = max_;
} }
void _AVX512_MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack) { static void _AVX512_MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack) {
// source: (ic/4, N, 4) auto srcStep = realSize * pack;
auto srcStep = pack * realSize; if (pack == 4) {
for (int i = 0; i < realSize; ++i) { __m128 mask = _mm_set1_ps(-0.0f);
float absmaxVal = 0.f; // absmaxVal>=0 float tmp[4];
for (int c = 0; c < src_depth_quad; ++c) { for (int i = 0; i < realSize; ++i) {
auto src = source + c * srcStep + i * pack; __m128 absmax_ = _mm_loadu_ps(source + i * pack);
for (int k = 0; k < pack; ++k) { absmax_ = _mm_andnot_ps(mask, absmax_);
absmaxVal = std::max(absmaxVal, std::abs(src[k])); auto src0 = source + i * pack;
for (int j = 1; j < src_depth_quad; ++j) {
__m128 vec = _mm_loadu_ps(src0 + j * srcStep);
vec = _mm_andnot_ps(mask, vec);
absmax_ = _mm_max_ps(absmax_, vec);
}
_mm_storeu_ps(tmp, absmax_);
float res = tmp[0];
for (int j = 1; j < pack; ++j) {
res = ALIMAX(res, tmp[j]);
}
absmax[i] = res;
}
return;
}
if (pack == 16) {
float tmp[16];
for (int i = 0; i < realSize; ++i) {
auto absmax_ = _mm512_loadu_ps(source + i * pack);
absmax_ = _mm512_abs_ps(absmax_);
auto src0 = source + i * pack;
for (int j = 1; j < src_depth_quad; ++j) {
auto vec = _mm512_loadu_ps(src0 + j * srcStep);
vec = _mm512_abs_ps(vec);
absmax_ = _mm512_max_ps(absmax_, vec);
}
auto maxval = _mm512_reduce_max_ps(absmax_);
absmax[i] = maxval;
}
return;
}
MNN_ERROR("absMax error: x86_x64 avx512 don't suppport pack=%d yet\n", pack);
}
static void _AVX512_DynamicQuant(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, int pack) {
auto srcStep = realSize * pack;
if (pack == 16) { // core->pack=16
auto offset = _mm512_set1_epi32(128);
int32_t tmp[16];
int32_t* dstPtr = reinterpret_cast<int32_t*>(dst);
for (int i = 0; i < src_depth_quad; ++i) {
int xcount = realSize;
auto srcPtr = src + i * srcStep;
auto scalePtr = scale;
while (xcount > 3) {
auto scale0 = _mm512_set1_ps(scalePtr[0]);
auto scale1 = _mm512_set1_ps(scalePtr[1]);
auto scale2 = _mm512_set1_ps(scalePtr[2]);
auto scale3 = _mm512_set1_ps(scalePtr[3]);
auto data0 = _mm512_loadu_ps(srcPtr);
auto data1 = _mm512_loadu_ps(srcPtr + pack);
auto data2 = _mm512_loadu_ps(srcPtr + 2 * pack);
auto data3 = _mm512_loadu_ps(srcPtr + 3 * pack);
data0 = _mm512_mul_ps(data0, scale0);
data1 = _mm512_mul_ps(data1, scale1);
data2 = _mm512_mul_ps(data2, scale2);
data3 = _mm512_mul_ps(data3, scale3);
auto r0 = _mm512_cvt_roundps_epi32(data0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
auto r1 = _mm512_cvt_roundps_epi32(data1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
auto r2 = _mm512_cvt_roundps_epi32(data2, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
auto r3 = _mm512_cvt_roundps_epi32(data3, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
r0 = _mm512_add_epi32(r0, offset); // int32x16
r1 = _mm512_add_epi32(r1, offset); // int32x16
r2 = _mm512_add_epi32(r2, offset);
r3 = _mm512_add_epi32(r3, offset);
auto r0_16 = _mm512_packs_epi32(r0, r1); // 00001111 00001111 00001111 00001111
auto r1_16 = _mm512_packs_epi32(r2, r3); // 22223333 22223333 22223333 22223333
auto r0_8 = _mm512_packus_epi16(r0_16, r1_16); // 0000111122223333 0000111122223333 0000111122223333 0000111122223333
_mm512_storeu_si512(tmp, r0_8);
for (int k = 0; k < 4; ++k) {
dstPtr[k * 4 + 0] = tmp[k + 4 * 0];
dstPtr[k * 4 + 1] = tmp[k + 4 * 1];
dstPtr[k * 4 + 2] = tmp[k + 4 * 2];
dstPtr[k * 4 + 3] = tmp[k + 4 * 3];
}
// next round
xcount -= 4;
scalePtr += 4;
srcPtr += (4 * pack);
dstPtr += 16;
}
while (xcount) {
auto scale0 = _mm512_set1_ps(scalePtr[0]);
auto data0 = _mm512_loadu_ps(srcPtr);
data0 = _mm512_mul_ps(data0, scale0);
auto r0 = _mm512_cvt_roundps_epi32(data0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
r0 = _mm512_add_epi32(r0, offset); // int32x16
auto r0_16 = _mm512_packs_epi32(r0, r0); // 00001111 00001111 00001111 00001111
auto r0_8 = _mm512_packus_epi16(r0_16, r0_16); // 0000111122223333 0000111122223333 0000111122223333 0000111122223333
_mm512_storeu_si512(tmp, r0_8);
dstPtr[0] = tmp[4 * 0];
dstPtr[1] = tmp[4 * 1];
dstPtr[2] = tmp[4 * 2];
dstPtr[3] = tmp[4 * 3];
// next round
xcount--;
scalePtr += 1;
srcPtr += pack;
dstPtr += 4;
} }
} }
absmax[i] = absmaxVal; return;
} }
if (pack == 4) { // LP=4;
auto offset = _mm_set1_epi32(128);
int32_t tmp[4];
int32_t* dstPtr = reinterpret_cast<int32_t*>(dst);
for (int i = 0; i < src_depth_quad; ++i) {
int xcount = realSize;
auto srcPtr = src + i * srcStep;
auto scalePtr = scale;
while (xcount > 3) {
auto scale0 = _mm_set1_ps(scalePtr[0]);
auto scale1 = _mm_set1_ps(scalePtr[1]);
auto scale2 = _mm_set1_ps(scalePtr[2]);
auto scale3 = _mm_set1_ps(scalePtr[3]);
auto data0 = _mm_loadu_ps(srcPtr);
auto data1 = _mm_loadu_ps(srcPtr + pack);
auto data2 = _mm_loadu_ps(srcPtr + 2 * pack);
auto data3 = _mm_loadu_ps(srcPtr + 3 * pack);
data0 = _mm_mul_ps(data0, scale0);
data1 = _mm_mul_ps(data1, scale1);
data2 = _mm_mul_ps(data2, scale2);
data3 = _mm_mul_ps(data3, scale3);
data0 = _mm_round_ps(data0, 0);
data1 = _mm_round_ps(data1, 0);
data2 = _mm_round_ps(data2, 0);
data3 = _mm_round_ps(data3, 0);
auto r0 = _mm_cvtps_epi32(data0);
auto r1 = _mm_cvtps_epi32(data1);
auto r2 = _mm_cvtps_epi32(data2);
auto r3 = _mm_cvtps_epi32(data3);
r0 = _mm_add_epi32(r0, offset);
r1 = _mm_add_epi32(r1, offset);
r2 = _mm_add_epi32(r2, offset);
r3 = _mm_add_epi32(r3, offset);
auto r0_16 = _mm_packs_epi32(r0, r1); // 00001111
auto r1_16 = _mm_packs_epi32(r2, r3); // 22223333
auto r0_8 = _mm_packus_epi16(r0_16, r1_16); // 0000111122223333
_mm_storeu_si128((__m128i *)dstPtr, r0_8);
// next round
xcount -= 4;
scalePtr += 4;
srcPtr += (4 * pack);
dstPtr += 4;
}
while (xcount) {
auto scale0 = _mm_set1_ps(scalePtr[0]);
auto data0 = _mm_loadu_ps(srcPtr);
data0 = _mm_mul_ps(data0, scale0);
auto r0 = _mm_cvtps_epi32(_mm_round_ps(data0, 0));
r0 = _mm_add_epi32(r0, offset);
auto r0_16 = _mm_packs_epi32(r0, r0); // 00001111
auto r0_8 = _mm_packus_epi16(r0_16, r0_16); // 0000111122223333
_mm_storeu_si128((__m128i *)tmp, r0_8);
dstPtr[0] = tmp[0];
// next round
xcount--;
scalePtr += 1;
srcPtr += pack;
dstPtr += 1;
}
}
return;
}
MNN_ERROR("dynamic quant error: x86_x64 avx512 don't suppport pack=%d yet\n", pack);
return;
} }
void _AVX512_MNNReluWithSlopeChannel(float* dst, const float* src, const float* slope, size_t sizeQuad, size_t depthQuad) { void _AVX512_MNNReluWithSlopeChannel(float* dst, const float* src, const float* slope, size_t sizeQuad, size_t depthQuad) {
@ -701,5 +864,5 @@ void _AVX512_ExtraInit(void* functions) {
coreFunction->MNNGetSparseMatMulPackMode = _AVX512_MNNGetSparseMatMulPackMode; coreFunction->MNNGetSparseMatMulPackMode = _AVX512_MNNGetSparseMatMulPackMode;
coreFunction->MNNAdjustOptimalSparseKernel = _AVX512_MNNAdjustOptimalSparseKernel; coreFunction->MNNAdjustOptimalSparseKernel = _AVX512_MNNAdjustOptimalSparseKernel;
coreFunction->MNNDynamicQuant = _AVX512_DynamicQuant;
} }
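
For reference, a scalar sketch of what the two vectorized routines above compute, assuming the (src_depth_quad, realSize, pack) input layout; the helper names are illustrative, not MNN API:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Per-column absolute maximum over all depth blocks.
static void absMaxRef(const float* src, float* absmax, size_t srcDepthQuad,
                      size_t realSize, int pack) {
    size_t srcStep = realSize * pack;
    for (size_t i = 0; i < realSize; ++i) {
        float m = 0.f;
        for (size_t c = 0; c < srcDepthQuad; ++c)
            for (int k = 0; k < pack; ++k)
                m = std::max(m, std::fabs(src[c * srcStep + i * pack + k]));
        absmax[i] = m;
    }
}

// Round-to-nearest-even of x*scale, then +128; the packs/packus instructions
// saturate to [0, 255], i.e. int8 values stored with a +128 offset.
static void dynamicQuantRef(const float* src, uint8_t* dst, const float* scale,
                            size_t srcDepthQuad, size_t realSize, int pack) {
    size_t srcStep = realSize * pack;
    for (size_t c = 0; c < srcDepthQuad; ++c)
        for (size_t i = 0; i < realSize; ++i)
            for (int k = 0; k < pack; ++k) {
                float v = std::nearbyint(src[c * srcStep + i * pack + k] * scale[i]) + 128.f;
                dst[c * srcStep + i * pack + k] = (uint8_t)std::min(255.f, std::max(0.f, v));
            }
}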


@ -61,6 +61,7 @@ void _SSE_MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, s
void _SSE_MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack); void _SSE_MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack);
void _SSE_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, void _SSE_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step,
size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst); size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst);
void _SSE_MNNDynamicQuant(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, int pack);
#endif #endif
void _SSE_MNNPackC4ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el); void _SSE_MNNPackC4ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el);
void _SSE_MNNConvRunForLineDepthwise(float* dst, const float* src, const float* weight, size_t width, size_t src_w_setup, void _SSE_MNNConvRunForLineDepthwise(float* dst, const float* src, const float* weight, size_t width, size_t src_w_setup,


@ -75,14 +75,19 @@ void _SSE_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, cons
if (post->biasFloat) { if (post->biasFloat) {
biasPtr = post->biasFloat; biasPtr = post->biasFloat;
} }
auto accumbuff = post->accumBuffer;
int weight_step_Y = GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT;
int weight_step_Z = src_depth_quad * weight_step_Y + 4 * 2 * GEMM_INT8_UNIT;
for (int dz = 0; dz < dst_depth_quad; ++dz) { for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT); const auto weight_dz = weight + dz * weight_step_Z;
const auto weightBias_dz = post->weightQuanBias + dz * GEMM_INT8_UNIT; const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const float* scale_dz = nullptr; const auto weightBias_dz = scale_dz + GEMM_INT8_UNIT;
scale_dz = post->scale + dz * GEMM_INT8_UNIT;
auto dst_z = dst + dz * dst_step_tmp; auto dst_z = dst + dz * dst_step_tmp;
auto accum_z = accumbuff + dz * realDst * GEMM_INT8_UNIT;
const auto src_x = src; const auto src_x = src;
auto dst_x = dst_z; auto dst_x = dst_z;
auto accum_x = accum_z;
__m128i d0 = _mm_set1_epi32(0); __m128i d0 = _mm_set1_epi32(0);
__m128i d1 = _mm_set1_epi32(0); __m128i d1 = _mm_set1_epi32(0);
__m128i d2 = _mm_set1_epi32(0); __m128i d2 = _mm_set1_epi32(0);
@ -177,12 +182,6 @@ d##i##j = _mm_add_epi32(_mm_madd_epi16(S##i##j##0, W##i##j##0), _mm_madd_epi16(S
E1 = _mm_hadd_epi32(E2, E3); E1 = _mm_hadd_epi32(E2, E3);
d3 = _mm_hadd_epi32(E0, E1); d3 = _mm_hadd_epi32(E0, E1);
auto scaleValue = _mm_loadu_ps(scale_dz); auto scaleValue = _mm_loadu_ps(scale_dz);
// auto biasValue = _mm_loadu_si128((__m128i*)(bias_dz));
// d0 = _mm_add_epi32(d0, biasValue);
// d1 = _mm_add_epi32(d1, biasValue);
// d2 = _mm_add_epi32(d2, biasValue);
// d3 = _mm_add_epi32(d3, biasValue);
//auto biasValue = _mm_loadu_ps((float*)(bias_dz));
auto weightBiasValue = _mm_loadu_ps((float*)weightBias_dz); auto weightBiasValue = _mm_loadu_ps((float*)weightBias_dz);
__m128 f0 = _mm_cvtepi32_ps(d0); __m128 f0 = _mm_cvtepi32_ps(d0);
__m128 f1 = _mm_cvtepi32_ps(d1); __m128 f1 = _mm_cvtepi32_ps(d1);
@ -278,22 +277,28 @@ d##i##j = _mm_add_epi32(_mm_madd_epi16(S##i##j##0, W##i##j##0), _mm_madd_epi16(S
__m128 f[4] = {f0, f1, f2, f3}; __m128 f[4] = {f0, f1, f2, f3};
if (nullptr == biasPtr) { if (nullptr == biasPtr) {
for (int j = 0; j < realDst; ++j) { for (int j = 0; j < realDst; ++j) {
auto dstv = _mm_loadu_ps(((float*)dst_x) + j * 4); auto dstv = _mm_loadu_ps(((float*)accum_x) + j * 4);
f[j] = _mm_add_ps(dstv, f[j]); f[j] = _mm_add_ps(dstv, f[j]);
} }
} }
if (post->fp32minmax) { if (dst) {
f[0] = _mm_min_ps(f[0], fp32max); if (post->fp32minmax) {
f[1] = _mm_min_ps(f[1], fp32max); f[0] = _mm_min_ps(f[0], fp32max);
f[2] = _mm_min_ps(f[2], fp32max); f[1] = _mm_min_ps(f[1], fp32max);
f[3] = _mm_min_ps(f[3], fp32max); f[2] = _mm_min_ps(f[2], fp32max);
f[0] = _mm_max_ps(f[0], fp32min); f[3] = _mm_min_ps(f[3], fp32max);
f[1] = _mm_max_ps(f[1], fp32min); f[0] = _mm_max_ps(f[0], fp32min);
f[2] = _mm_max_ps(f[2], fp32min); f[1] = _mm_max_ps(f[1], fp32min);
f[3] = _mm_max_ps(f[3], fp32min); f[2] = _mm_max_ps(f[2], fp32min);
} f[3] = _mm_max_ps(f[3], fp32min);
for (int j = 0; j < realDst; ++j) { }
_mm_storeu_ps(((float*)dst_x) + j * 4, f[j]); for (int j = 0; j < realDst; ++j) {
_mm_storeu_ps(((float*)dst_x) + j * 4, f[j]);
}
} else {
for (int j = 0; j < realDst; ++j) {
_mm_storeu_ps(((float*)accum_x) + j * 4, f[j]);
}
} }
} }
} }
@ -322,7 +327,8 @@ void _SSE_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const
if (post->biasFloat) { if (post->biasFloat) {
biasPtr = post->biasFloat; biasPtr = post->biasFloat;
} }
int weight_step_Z = 0.5 * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT); auto accumbuff = post->accumBuffer;
int weight_step_Z = 0.5 * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT) + 4 * 2 * GEMM_INT8_UNIT;
int weight_step_Y = 0.5 * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT); int weight_step_Y = 0.5 * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT);
auto oneValue = _mm_set1_epi16(1); auto oneValue = _mm_set1_epi16(1);
@ -370,12 +376,13 @@ void _SSE_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const
} }
for (int dz = 0; dz < dst_depth_quad; ++dz) { for (int dz = 0; dz < dst_depth_quad; ++dz) {
const auto weight_dz = weight + dz * weight_step_Z; const auto weight_dz = weight + dz * weight_step_Z;
const auto weightBias_dz = post->weightQuanBias + dz * GEMM_INT8_UNIT; const float* scale_dz = reinterpret_cast<const float*>(weight_dz + src_depth_quad * weight_step_Y);
const float* scale_dz = nullptr; const auto weightBias_dz = scale_dz + GEMM_INT8_UNIT;
scale_dz = post->scale + dz * GEMM_INT8_UNIT;
auto dst_z = dst + dz * dst_step_tmp; auto dst_z = dst + dz * dst_step_tmp;
auto accum_z = accumbuff + dz * realDst * GEMM_INT8_UNIT;
const auto src_x = src; const auto src_x = src;
auto dst_x = dst_z; auto dst_x = dst_z;
auto accum_x = accum_z;
__m128i d0 = _mm_set1_epi32(0); __m128i d0 = _mm_set1_epi32(0);
__m128i d1 = _mm_set1_epi32(0); __m128i d1 = _mm_set1_epi32(0);
__m128i d2 = _mm_set1_epi32(0); __m128i d2 = _mm_set1_epi32(0);
@ -521,22 +528,28 @@ auto d##i##j = _mm_add_epi32(_mm_madd_epi16(S##i##j##0, W##i##j##0), _mm_madd_ep
__m128 f[4] = {f0, f1, f2, f3}; __m128 f[4] = {f0, f1, f2, f3};
if (nullptr == biasPtr) { if (nullptr == biasPtr) {
for (int j = 0; j < realDst; ++j) { for (int j = 0; j < realDst; ++j) {
auto dstv = _mm_loadu_ps(((float*)dst_x) + j * 4); auto dstv = _mm_loadu_ps(((float*)accum_x) + j * 4);
f[j] = _mm_add_ps(dstv, f[j]); f[j] = _mm_add_ps(dstv, f[j]);
} }
} }
if (post->fp32minmax) { if (dst) {
f[0] = _mm_min_ps(f[0], fp32max); if (post->fp32minmax) {
f[1] = _mm_min_ps(f[1], fp32max); f[0] = _mm_min_ps(f[0], fp32max);
f[2] = _mm_min_ps(f[2], fp32max); f[1] = _mm_min_ps(f[1], fp32max);
f[3] = _mm_min_ps(f[3], fp32max); f[2] = _mm_min_ps(f[2], fp32max);
f[0] = _mm_max_ps(f[0], fp32min); f[3] = _mm_min_ps(f[3], fp32max);
f[1] = _mm_max_ps(f[1], fp32min); f[0] = _mm_max_ps(f[0], fp32min);
f[2] = _mm_max_ps(f[2], fp32min); f[1] = _mm_max_ps(f[1], fp32min);
f[3] = _mm_max_ps(f[3], fp32min); f[2] = _mm_max_ps(f[2], fp32min);
} f[3] = _mm_max_ps(f[3], fp32min);
for (int j = 0; j < realDst; ++j) { }
_mm_storeu_ps(((float*)dst_x) + j * 4, f[j]); for (int j = 0; j < realDst; ++j) {
_mm_storeu_ps(((float*)dst_x) + j * 4, f[j]);
}
} else {
for (int j = 0; j < realDst; ++j) {
_mm_storeu_ps(((float*)accum_x) + j * 4, f[j]);
}
} }
} }
} }
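
The structural change in this file is the dst/accumBuffer contract: when dst is null the partial sums skip the fp32 min/max clamp and are parked in post->accumBuffer, and when biasPtr is null the previous partials are read back from that buffer and folded in. A scalar sketch of the assumed contract (names illustrative):

#include <algorithm>

static void storeTile(float* dst, float* accum, const float* f, int count,
                      bool hasBias, bool hasMinMax, float fp32min, float fp32max) {
    for (int j = 0; j < count; ++j) {
        float v = f[j];
        if (!hasBias) v += accum[j];   // fold in partials from earlier tiles
        if (dst) {                     // final tile: clamp (if requested) and emit
            if (hasMinMax) v = std::min(std::max(v, fp32min), fp32max);
            dst[j] = v;
        } else {                       // intermediate tile: stash for the next call
            accum[j] = v;
        }
    }
}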


@ -71,28 +71,298 @@ void _SSE_MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, s
#ifdef MNN_LOW_MEMORY #ifdef MNN_LOW_MEMORY
// Dynamic quant // Dynamic quant
void _SSE_MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack) { void _SSE_MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack) {
// source: (ic/4, N, 4) size_t srcStep = realSize * pack;
auto srcStep = pack * realSize; __m128 mask = _mm_set1_ps(-0.0f);
auto constant = _mm_castsi128_ps(_mm_set1_epi32(0x7FFFFFFF)); if (pack == 4) { // input c4
float temp[4]; float tmp[4];
for (int i = 0; i < realSize; ++i) { for (int i = 0; i < realSize; ++i) {
__m128 res = _mm_setzero_ps(); __m128 absmax_ = _mm_loadu_ps(source + i * pack);
for (int c = 0; c < src_depth_quad; ++c) { absmax_ = _mm_andnot_ps(mask, absmax_);
auto src0 = source + c * srcStep + i * pack; auto src0 = source + i * pack;
__m128 vecA = _mm_loadu_ps(src0); for (int j = 1; j < src_depth_quad; ++j) {
__m128 absVecA = _mm_and_ps(vecA, constant); __m128 vec = _mm_loadu_ps(src0 + j * srcStep);
__m128 mask = _mm_cmpgt_ps(res, absVecA); vec = _mm_andnot_ps(mask, vec);
res = _mm_blendv_ps(absVecA, res, mask); absmax_ = _mm_max_ps(absmax_, vec);
}
_mm_storeu_ps(tmp, absmax_);
float res = tmp[0];
for (int j = 1; j < pack; ++j) {
res = ALIMAX(res, tmp[j]);
}
absmax[i] = res;
} }
_mm_storeu_ps(temp, res); return;
float absmaxVal = temp[0]; }
for (int k = 1; k < pack; ++k) { if (pack == 16) { // (lu,ep,lp)
if (absmaxVal < temp[k]) { float tmp[16];
absmaxVal = temp[k]; for (int i = 0; i < realSize; ++i) {
__m128 absmax0 = _mm_loadu_ps(source + i * pack);
__m128 absmax1 = _mm_loadu_ps(source + i * pack + 4);
__m128 absmax2 = _mm_loadu_ps(source + i * pack + 8);
__m128 absmax3 = _mm_loadu_ps(source + i * pack + 12);
absmax0 = _mm_andnot_ps(mask, absmax0);
absmax1 = _mm_andnot_ps(mask, absmax1);
absmax2 = _mm_andnot_ps(mask, absmax2);
absmax3 = _mm_andnot_ps(mask, absmax3);
auto src0 = source + i * pack;
for (int j = 1; j < src_depth_quad; ++j) {
__m128 vec0 = _mm_loadu_ps(src0 + j * srcStep);
__m128 vec1 = _mm_loadu_ps(src0 + j * srcStep + 4);
__m128 vec2 = _mm_loadu_ps(src0 + j * srcStep + 8);
__m128 vec3 = _mm_loadu_ps(src0 + j * srcStep + 12);
vec0 = _mm_andnot_ps(mask, vec0);
vec1 = _mm_andnot_ps(mask, vec1);
vec2 = _mm_andnot_ps(mask, vec2);
vec3 = _mm_andnot_ps(mask, vec3);
absmax0 = _mm_max_ps(absmax0, vec0);
absmax1 = _mm_max_ps(absmax1, vec1);
absmax2 = _mm_max_ps(absmax2, vec2);
absmax3 = _mm_max_ps(absmax3, vec3);
}
absmax0 = _mm_max_ps(absmax0, absmax1);
absmax2 = _mm_max_ps(absmax2, absmax3);
absmax0 = _mm_max_ps(absmax0, absmax2);
_mm_storeu_ps(tmp, absmax0);
float res = tmp[0];
for (int j = 1; j < 4; ++j) {
res = ALIMAX(res, tmp[j]);
}
absmax[i] = res;
}
return;
}
MNN_ERROR("absMax error: x86_x64 sse don't suppport pack=%d yet\n", pack);
return;
}
void _SSE_MNNDynamicQuant(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, int pack) {
auto srcStep = realSize * pack;
if (pack == 4) { // core->pack
auto offset = _mm_set1_epi32(128);
int32_t tmp[4];
int32_t* dstPtr = reinterpret_cast<int32_t*>(dst);
for (int i = 0; i < src_depth_quad; ++i) {
int xcount = realSize;
auto srcPtr = src + i * srcStep;
auto scalePtr = scale;
while (xcount > 3) {
auto scale0 = _mm_set1_ps(scalePtr[0]);
auto scale1 = _mm_set1_ps(scalePtr[1]);
auto scale2 = _mm_set1_ps(scalePtr[2]);
auto scale3 = _mm_set1_ps(scalePtr[3]);
auto data0 = _mm_loadu_ps(srcPtr);
auto data1 = _mm_loadu_ps(srcPtr + pack);
auto data2 = _mm_loadu_ps(srcPtr + 2 * pack);
auto data3 = _mm_loadu_ps(srcPtr + 3 * pack);
data0 = _mm_mul_ps(data0, scale0);
data1 = _mm_mul_ps(data1, scale1);
data2 = _mm_mul_ps(data2, scale2);
data3 = _mm_mul_ps(data3, scale3);
data0 = _mm_round_ps(data0, 0);
data1 = _mm_round_ps(data1, 0);
data2 = _mm_round_ps(data2, 0);
data3 = _mm_round_ps(data3, 0);
auto r0 = _mm_cvtps_epi32(data0);
auto r1 = _mm_cvtps_epi32(data1);
auto r2 = _mm_cvtps_epi32(data2);
auto r3 = _mm_cvtps_epi32(data3);
r0 = _mm_add_epi32(r0, offset);
r1 = _mm_add_epi32(r1, offset);
r2 = _mm_add_epi32(r2, offset);
r3 = _mm_add_epi32(r3, offset);
auto r0_16 = _mm_packs_epi32(r0, r1); // 00001111
auto r1_16 = _mm_packs_epi32(r2, r3); // 22223333
auto r0_8 = _mm_packus_epi16(r0_16, r1_16); // 0000111122223333
_mm_storeu_si128((__m128i *)dstPtr, r0_8);
// next round
xcount -= 4;
scalePtr += 4;
srcPtr += (4 * pack);
dstPtr += 4;
}
while (xcount) {
auto scale0 = _mm_set1_ps(scalePtr[0]);
auto data0 = _mm_loadu_ps(srcPtr);
data0 = _mm_mul_ps(data0, scale0);
auto r0 = _mm_cvtps_epi32(_mm_round_ps(data0, 0));
r0 = _mm_add_epi32(r0, offset);
auto r0_16 = _mm_packs_epi32(r0, r0); // 00001111
auto r0_8 = _mm_packus_epi16(r0_16, r0_16); // 0000111122223333
_mm_storeu_si128((__m128i *)tmp, r0_8);
dstPtr[0] = tmp[0];
// next round
xcount--;
scalePtr += 1;
srcPtr += pack;
dstPtr += 1;
} }
} }
absmax[i] = absmaxVal; return;
} }
if (pack == 16) {
auto offset = _mm_set1_epi32(128);
int32_t tmp[4];
int32_t* dstPtr = reinterpret_cast<int32_t*>(dst);
for (int i = 0; i < src_depth_quad; ++i) {
int xcount = realSize;
auto srcPtr = src + i * srcStep;
auto scalePtr = scale;
while (xcount > 3) {
auto scale0 = _mm_set1_ps(scalePtr[0]);
auto scale1 = _mm_set1_ps(scalePtr[1]);
auto scale2 = _mm_set1_ps(scalePtr[2]);
auto scale3 = _mm_set1_ps(scalePtr[3]);
auto data00 = _mm_loadu_ps(srcPtr);
auto data01 = _mm_loadu_ps(srcPtr + 4);
auto data02 = _mm_loadu_ps(srcPtr + 8);
auto data03 = _mm_loadu_ps(srcPtr + 12);
auto data10 = _mm_loadu_ps(srcPtr + pack);
auto data11 = _mm_loadu_ps(srcPtr + pack + 4);
auto data12 = _mm_loadu_ps(srcPtr + pack + 8);
auto data13 = _mm_loadu_ps(srcPtr + pack + 12);
auto data20 = _mm_loadu_ps(srcPtr + 2 * pack);
auto data21 = _mm_loadu_ps(srcPtr + 2 * pack + 4);
auto data22 = _mm_loadu_ps(srcPtr + 2 * pack + 8);
auto data23 = _mm_loadu_ps(srcPtr + 2 * pack + 12);
auto data30 = _mm_loadu_ps(srcPtr + 3 * pack);
auto data31 = _mm_loadu_ps(srcPtr + 3 * pack + 4);
auto data32 = _mm_loadu_ps(srcPtr + 3 * pack + 8);
auto data33 = _mm_loadu_ps(srcPtr + 3 * pack + 12);
data00 = _mm_mul_ps(data00, scale0);
data01 = _mm_mul_ps(data01, scale0);
data02 = _mm_mul_ps(data02, scale0);
data03 = _mm_mul_ps(data03, scale0);
data10 = _mm_mul_ps(data10, scale1);
data11 = _mm_mul_ps(data11, scale1);
data12 = _mm_mul_ps(data12, scale1);
data13 = _mm_mul_ps(data13, scale1);
data20 = _mm_mul_ps(data20, scale2);
data21 = _mm_mul_ps(data21, scale2);
data22 = _mm_mul_ps(data22, scale2);
data23 = _mm_mul_ps(data23, scale2);
data30 = _mm_mul_ps(data30, scale3);
data31 = _mm_mul_ps(data31, scale3);
data32 = _mm_mul_ps(data32, scale3);
data33 = _mm_mul_ps(data33, scale3);
data00 = _mm_round_ps(data00, 0);
data01 = _mm_round_ps(data01, 0);
data02 = _mm_round_ps(data02, 0);
data03 = _mm_round_ps(data03, 0);
data10 = _mm_round_ps(data10, 0);
data11 = _mm_round_ps(data11, 0);
data12 = _mm_round_ps(data12, 0);
data13 = _mm_round_ps(data13, 0);
data20 = _mm_round_ps(data20, 0);
data21 = _mm_round_ps(data21, 0);
data22 = _mm_round_ps(data22, 0);
data23 = _mm_round_ps(data23, 0);
data30 = _mm_round_ps(data30, 0);
data31 = _mm_round_ps(data31, 0);
data32 = _mm_round_ps(data32, 0);
data33 = _mm_round_ps(data33, 0);
auto r00 = _mm_cvtps_epi32(data00);
auto r01 = _mm_cvtps_epi32(data01);
auto r02 = _mm_cvtps_epi32(data02);
auto r03 = _mm_cvtps_epi32(data03);
auto r10 = _mm_cvtps_epi32(data10);
auto r11 = _mm_cvtps_epi32(data11);
auto r12 = _mm_cvtps_epi32(data12);
auto r13 = _mm_cvtps_epi32(data13);
auto r20 = _mm_cvtps_epi32(data20);
auto r21 = _mm_cvtps_epi32(data21);
auto r22 = _mm_cvtps_epi32(data22);
auto r23 = _mm_cvtps_epi32(data23);
auto r30 = _mm_cvtps_epi32(data30);
auto r31 = _mm_cvtps_epi32(data31);
auto r32 = _mm_cvtps_epi32(data32);
auto r33 = _mm_cvtps_epi32(data33);
r00 = _mm_add_epi32(r00, offset);
r01 = _mm_add_epi32(r01, offset);
r02 = _mm_add_epi32(r02, offset);
r03 = _mm_add_epi32(r03, offset);
r10 = _mm_add_epi32(r10, offset);
r11 = _mm_add_epi32(r11, offset);
r12 = _mm_add_epi32(r12, offset);
r13 = _mm_add_epi32(r13, offset);
r20 = _mm_add_epi32(r20, offset);
r21 = _mm_add_epi32(r21, offset);
r22 = _mm_add_epi32(r22, offset);
r23 = _mm_add_epi32(r23, offset);
r30 = _mm_add_epi32(r30, offset);
r31 = _mm_add_epi32(r31, offset);
r32 = _mm_add_epi32(r32, offset);
r33 = _mm_add_epi32(r33, offset);
auto r00_16 = _mm_packs_epi32(r00, r01); // 00000000
auto r01_16 = _mm_packs_epi32(r02, r03); // 00000000
auto r0_8 = _mm_packus_epi16(r00_16, r01_16); // 0000000000000000
auto r10_16 = _mm_packs_epi32(r10, r11);
auto r11_16 = _mm_packs_epi32(r12, r13);
auto r1_8 = _mm_packus_epi16(r10_16, r11_16);
auto r20_16 = _mm_packs_epi32(r20, r21);
auto r21_16 = _mm_packs_epi32(r22, r23);
auto r2_8 = _mm_packus_epi16(r20_16, r21_16);
auto r30_16 = _mm_packs_epi32(r30, r31);
auto r31_16 = _mm_packs_epi32(r32, r33);
auto r3_8 = _mm_packus_epi16(r30_16, r31_16);
_mm_storeu_si128((__m128i *)dstPtr, r0_8);
_mm_storeu_si128((__m128i *)(dstPtr + 4), r1_8);
_mm_storeu_si128((__m128i *)(dstPtr + 8), r2_8);
_mm_storeu_si128((__m128i *)(dstPtr + 12), r3_8);
// next round
xcount -= 4;
scalePtr += 4;
srcPtr += (4 * pack);
dstPtr += pack;
}
while (xcount) {
auto scale0 = _mm_set1_ps(scalePtr[0]);
auto data00 = _mm_loadu_ps(srcPtr);
auto data01 = _mm_loadu_ps(srcPtr + 4);
auto data02 = _mm_loadu_ps(srcPtr + 8);
auto data03 = _mm_loadu_ps(srcPtr + 12);
data00 = _mm_mul_ps(data00, scale0);
data01 = _mm_mul_ps(data01, scale0);
data02 = _mm_mul_ps(data02, scale0);
data03 = _mm_mul_ps(data03, scale0);
data00 = _mm_round_ps(data00, 0);
data01 = _mm_round_ps(data01, 0);
data02 = _mm_round_ps(data02, 0);
data03 = _mm_round_ps(data03, 0);
auto r00 = _mm_cvtps_epi32(data00);
auto r01 = _mm_cvtps_epi32(data01);
auto r02 = _mm_cvtps_epi32(data02);
auto r03 = _mm_cvtps_epi32(data03);
r00 = _mm_add_epi32(r00, offset);
r01 = _mm_add_epi32(r01, offset);
r02 = _mm_add_epi32(r02, offset);
r03 = _mm_add_epi32(r03, offset);
auto r00_16 = _mm_packs_epi32(r00, r01); // 00000000
auto r01_16 = _mm_packs_epi32(r02, r03); // 00000000
auto r0_8 = _mm_packus_epi16(r00_16, r01_16); // 0000000000000000
_mm_storeu_si128((__m128i *)dstPtr, r0_8);
// next round
xcount--;
scalePtr += 1;
srcPtr += pack;
dstPtr += 4;
}
}
return;
}
MNN_ERROR("dynamic quant error: x86_x64 sse don't suppport pack=%d yet\n", pack);
return;
} }
#endif #endif


@ -179,19 +179,58 @@ static void createLibrary(id<MTLDevice> device, NSMutableDictionary<NSString *,
return cmdBuffer; return cmdBuffer;
} }
bool getCloseThreadgroup(const std::map<std::string, std::vector<std::pair<std::vector<uint32_t>, std::tuple<std::vector<uint32_t>, std::vector<uint32_t>, uint32_t>>>> &tuneMap, const std::vector<uint32_t> &gws, const std::string &kernelName, std::tuple<std::vector<uint32_t>, std::vector<uint32_t>, uint32_t>& res){
float minScale = 0.1;
auto iter = tuneMap.find(kernelName);
if(iter == tuneMap.end()){
return false;
}
auto gwsAndLws = iter->second;
int size = gws.size();
uint32_t minPoint = UINT_MAX;
int index = -1;
for(int i = 0; i < gwsAndLws.size(); ++i){
uint32_t point = 0;
for(int j = 0; j < size; ++j){
point += std::abs(static_cast<int>(gws[j]) - static_cast<int>(gwsAndLws[i].first[j]));
}
if(point < minPoint){
index = i;
minPoint = point;
}
}
if(index != -1){
res = gwsAndLws[index].second;
return true;
}
return false;
}
- (std::tuple<MTLSize, MTLSize, NSUInteger>) getGridAndThreadgroup: (id<MTLComputePipelineState>)pipeline gid:(MTLSize)threads loop:(NSUInteger)count buffer:(NSArray *)buffers runtime:(MetalRuntime *) rt shaderName:(std::string) kernelName offsets:(int *) offset_arr queue:(id<MTLCommandQueue>) cmdqueue { - (std::tuple<MTLSize, MTLSize, NSUInteger>) getGridAndThreadgroup: (id<MTLComputePipelineState>)pipeline gid:(MTLSize)threads loop:(NSUInteger)count buffer:(NSArray *)buffers runtime:(MetalRuntime *) rt shaderName:(std::string) kernelName offsets:(int *) offset_arr queue:(id<MTLCommandQueue>) cmdqueue {
NSUInteger gid_x = threads.width; NSUInteger gid_x = threads.width;
NSUInteger gid_y = threads.height; NSUInteger gid_y = threads.height;
NSUInteger gid_z = threads.depth; NSUInteger gid_z = threads.depth;
auto& tunedThreadGroup = rt->getTunedThreadGroup(); auto& tunedThreadGroup = rt->getTunedThreadGroup();
std::vector<uint32_t> gws = {(uint32_t)gid_x, (uint32_t)gid_y, (uint32_t)gid_z}; std::vector<uint32_t> gws = {(uint32_t)gid_x, (uint32_t)gid_y, (uint32_t)gid_z};
std::pair<std::string, std::vector<uint32_t>> info = std::make_pair(kernelName, gws); std::pair<std::string, std::vector<uint32_t>> info = std::make_pair(kernelName, gws);
if (tunedThreadGroup.find(info) != tunedThreadGroup.end()) { bool exactRes = tunedThreadGroup.find(info) != tunedThreadGroup.end();
std::tuple<std::vector<uint32_t>, std::vector<uint32_t>, uint32_t> tuneLwsRes;
bool closeRes = false;
if(!exactRes) {
auto& tunedThreadGroupVec = rt->getTunedThreadGroupVec();
if(getCloseThreadgroup(tunedThreadGroupVec, gws, kernelName, tuneLwsRes)){
closeRes = true;
}
} else {
tuneLwsRes = tunedThreadGroup[info];
}
if (exactRes || closeRes) {
//printf("conv2d1x1LocalWSOpt Found! gws:%d %d lws:%d %d\n", gws[0], gws[1], tunedLws[info][0], tunedLws[info][1]); //printf("conv2d1x1LocalWSOpt Found! gws:%d %d lws:%d %d\n", gws[0], gws[1], tunedLws[info][0], tunedLws[info][1]);
auto groupNum = std::get<0>(tunedThreadGroup[info]); auto groupNum = std::get<0>(tuneLwsRes);
auto groupSize = std::get<1>(tunedThreadGroup[info]); auto groupSize = std::get<1>(tuneLwsRes);
auto timeCost = std::get<2>(tunedThreadGroup[info]); auto timeCost = std::get<2>(tuneLwsRes);
MTLSize _groupNum = {(NSUInteger)groupNum[0], (NSUInteger)groupNum[1], (NSUInteger)groupNum[2]}; MTLSize _groupNum = {(NSUInteger)groupNum[0], (NSUInteger)groupNum[1], (NSUInteger)groupNum[2]};
MTLSize _groupSize = {(NSUInteger)groupSize[0], (NSUInteger)groupSize[1], (NSUInteger)groupSize[2]}; MTLSize _groupSize = {(NSUInteger)groupSize[0], (NSUInteger)groupSize[1], (NSUInteger)groupSize[2]};
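
When no exact (kernelName, gws) record exists, getCloseThreadgroup above falls back to the tuned threadgroup of the nearest recorded gws for the same kernel, scored by L1 distance; ties go to the first minimum. A toy check of that metric (hypothetical sizes):

#include <cstdint>
#include <cstdlib>
#include <vector>

// Same L1 distance as the fallback above (assumes equal-length gws vectors).
static uint32_t l1(const std::vector<uint32_t>& a, const std::vector<uint32_t>& b) {
    uint32_t d = 0;
    for (size_t j = 0; j < a.size(); ++j)
        d += std::abs(static_cast<int>(a[j]) - static_cast<int>(b[j]));
    return d;
}
// l1({64,32,1}, {60,32,1}) == 4 beats l1({64,32,1}, {128,32,1}) == 64,
// so the {60,32,1} record's threadgroup configuration would be reused.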


@ -61,6 +61,9 @@ public:
std::map<std::pair<std::string, std::vector<uint32_t>>, std::tuple<std::vector<uint32_t>, std::vector<uint32_t>, uint32_t>>& getTunedThreadGroup() { std::map<std::pair<std::string, std::vector<uint32_t>>, std::tuple<std::vector<uint32_t>, std::vector<uint32_t>, uint32_t>>& getTunedThreadGroup() {
return mTunedThreadGroup; return mTunedThreadGroup;
}; };
std::map<std::string, std::vector<std::pair<std::vector<uint32_t>, std::tuple<std::vector<uint32_t>, std::vector<uint32_t>, uint32_t>>>>& getTunedThreadGroupVec() {
return mTunedThreadGroupVec;
}
virtual Backend *onCreate(const BackendConfig* config, Backend* origin) const override; virtual Backend *onCreate(const BackendConfig* config, Backend* origin) const override;
virtual void onGabageCollect(int level) override; virtual void onGabageCollect(int level) override;
virtual CompilerType onGetCompilerType() const override { virtual CompilerType onGetCompilerType() const override {
@ -88,6 +91,7 @@ private:
mutable std::vector<SingleBufferWithAllocator> mDynamic; mutable std::vector<SingleBufferWithAllocator> mDynamic;
MetalTuneLevel mTuneLevel = Wide; MetalTuneLevel mTuneLevel = Wide;
std::map<std::pair<std::string, std::vector<uint32_t>>, std::tuple<std::vector<uint32_t>, std::vector<uint32_t>, uint32_t>> mTunedThreadGroup; std::map<std::pair<std::string, std::vector<uint32_t>>, std::tuple<std::vector<uint32_t>, std::vector<uint32_t>, uint32_t>> mTunedThreadGroup;
std::map<std::string, std::vector<std::pair<std::vector<uint32_t>, std::tuple<std::vector<uint32_t>, std::vector<uint32_t>, uint32_t>>>> mTunedThreadGroupVec;
private: private:
id<MTLCommandQueue> mQueue = nil; id<MTLCommandQueue> mQueue = nil;


@ -1060,6 +1060,7 @@ bool MetalRuntime::setCache(std::pair<const void*, size_t> cache) {//Get Cache
} }
uint32_t cost = tun->timeCost(); uint32_t cost = tun->timeCost();
mTunedThreadGroup.insert(std::make_pair(std::make_pair(tun->key()->str(), glo), std::make_tuple(grop, loc, cost))); mTunedThreadGroup.insert(std::make_pair(std::make_pair(tun->key()->str(), glo), std::make_tuple(grop, loc, cost)));
mTunedThreadGroupVec[tun->key()->str()].emplace_back(std::make_pair(glo, std::make_tuple(grop, loc, cost)));
} }
} }
return true; return true;
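
For orientation, the two tuning caches populated here have these shapes (restated from the declarations above); the exact-match map is keyed by (kernel name, gws), while the per-kernel vector feeds the nearest-gws fallback:

#include <cstdint>
#include <map>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

// One tuned record: (threadgroup count, threadgroup size, measured cost).
using TuneRecord = std::tuple<std::vector<uint32_t>, std::vector<uint32_t>, uint32_t>;
// Exact match: (kernelName, gws) -> record.
using ExactCache = std::map<std::pair<std::string, std::vector<uint32_t>>, TuneRecord>;
// Nearest match: kernelName -> [(gws, record), ...].
using NearestCache = std::map<std::string, std::vector<std::pair<std::vector<uint32_t>, TuneRecord>>>;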


@ -9,9 +9,8 @@
file(GLOB_RECURSE MNN_OpenCL_SRC ${CMAKE_CURRENT_LIST_DIR}/*) file(GLOB_RECURSE MNN_OpenCL_SRC ${CMAKE_CURRENT_LIST_DIR}/*)
option(MNN_OPENCL_SIZE_CUT "Disable MNN OpenCL Buffer Opt" OFF) option(MNN_OPENCL_SIZE_CUT "Disable MNN OpenCL Buffer Opt" OFF)
option(MNN_OPENCL_PROFILE "Enable MNN OpenCL Kernel Profile" OFF)
IF (MNN_OPENCL_PROFILE) IF (MNN_GPU_TIME_PROFILE)
add_definitions(-DENABLE_OPENCL_TIME_PROFILER) add_definitions(-DENABLE_OPENCL_TIME_PROFILER)
ENDIF() ENDIF()


@ -40,7 +40,7 @@ CLRuntime::CLRuntime(const Backend::Info& info, int platformSize, int platformId
} }
// Shader precision // Shader precision
mOpenCLRuntime.reset(new OpenCLRuntime(precision, mInfo.gpuMode, platformSize, platformId, deviceId, contextPtr, glshared)); mOpenCLRuntime.reset(new OpenCLRuntime(precision, mInfo.gpuMode, platformSize, platformId, deviceId, contextPtr, glshared, hint()));
//Whether runtimeError //Whether runtimeError
mCLRuntimeError = mOpenCLRuntime->isCreateError(); mCLRuntimeError = mOpenCLRuntime->isCreateError();
mPrecision = precision; mPrecision = precision;
@ -588,6 +588,8 @@ void OpenCLBackend::onResizeBegin() {
#ifndef ENABLE_OPENCL_TIME_PROFILER #ifndef ENABLE_OPENCL_TIME_PROFILER
mOpenCLRuntime->setCommandQueueProfileEnable(); mOpenCLRuntime->setCommandQueueProfileEnable();
#endif #endif
// update mUseRecordableQueueSize if hint has changed
mUseRecordableQueueSize = mCLRuntime->hint().encorderNumForCommit <= mUseRecordableQueueSize ? mCLRuntime->hint().encorderNumForCommit : mUseRecordableQueueSize;
releaseRecord(); releaseRecord();
} }
@ -1352,15 +1354,16 @@ void OpenCLBackend::recordKernel2d(const std::shared_ptr<KernelWrap> &kernelW, c
cl_int res = CL_SUCCESS; cl_int res = CL_SUCCESS;
if(!mDevideOpRecord){ if(!mDevideOpRecord){
RecordInfo info; RecordInfo info;
int recordNum = mRecordNums == mUseRecordableQueueSize ? 0 : mRecordNums;
if(updateInfo != nullptr){ if(updateInfo != nullptr){
for(int i = 0; i < updateInfo->update_kernel_args.size(); ++i){ for(int i = 0; i < updateInfo->update_kernel_args.size(); ++i){
updateInfo->update_kernel_args[i].dispatch_index = mRecordNums; updateInfo->update_kernel_args[i].dispatch_index = recordNum;
} }
for(int i = 0; i < updateInfo->update_global_size.size(); ++i){ for(int i = 0; i < updateInfo->update_global_size.size(); ++i){
updateInfo->update_global_size[i].dispatch_index = mRecordNums; updateInfo->update_global_size[i].dispatch_index = recordNum;
} }
for(int i = 0; i < updateInfo->update_local_size.size(); ++i){ for(int i = 0; i < updateInfo->update_local_size.size(); ++i){
updateInfo->update_local_size[i].dispatch_index = mRecordNums; updateInfo->update_local_size[i].dispatch_index = recordNum;
} }
info.updateInfo.emplace_back(updateInfo); info.updateInfo.emplace_back(updateInfo);
} }
@ -1421,15 +1424,16 @@ void OpenCLBackend::recordKernel3d(const std::shared_ptr<KernelWrap> &kernelW, c
} }
if(!mDevideOpRecord){ if(!mDevideOpRecord){
RecordInfo info; RecordInfo info;
int recordNum = mRecordNums == mUseRecordableQueueSize ? 0 : mRecordNums;
if(updateInfo != nullptr){ if(updateInfo != nullptr){
for(int i = 0; i < updateInfo->update_kernel_args.size(); ++i){ for(int i = 0; i < updateInfo->update_kernel_args.size(); ++i){
updateInfo->update_kernel_args[i].dispatch_index = mRecordNums; updateInfo->update_kernel_args[i].dispatch_index = recordNum;
} }
for(int i = 0; i < updateInfo->update_global_size.size(); ++i){ for(int i = 0; i < updateInfo->update_global_size.size(); ++i){
updateInfo->update_global_size[i].dispatch_index = mRecordNums; updateInfo->update_global_size[i].dispatch_index = recordNum;
} }
for(int i = 0; i < updateInfo->update_local_size.size(); ++i){ for(int i = 0; i < updateInfo->update_local_size.size(); ++i){
updateInfo->update_local_size[i].dispatch_index = mRecordNums; updateInfo->update_local_size[i].dispatch_index = recordNum;
} }
info.updateInfo.emplace_back(updateInfo); info.updateInfo.emplace_back(updateInfo);
} }
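
Both record paths now clamp the dispatch index: once mRecordNums reaches the recordable-queue capacity, recording presumably rolls over to a fresh queue, so newly recorded kernels are indexed from 0 again instead of pointing past the end:

// dispatch_index is relative to the current recordable queue, which holds at
// most mUseRecordableQueueSize kernels; on rollover the index restarts at 0.
int recordNum = (mRecordNums == mUseRecordableQueueSize) ? 0 : mRecordNums;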


@ -38,7 +38,7 @@ static void callback(const char *buffer, size_t length, size_t final, void *user
} }
#endif #endif
OpenCLRuntime::OpenCLRuntime(const BackendConfig::PrecisionMode precision, const int cl_mode, int platformSize, int platformId, int deviceId, void *contextPtr, void *glShared) { OpenCLRuntime::OpenCLRuntime(const BackendConfig::PrecisionMode precision, const int cl_mode, int platformSize, int platformId, int deviceId, void *contextPtr, void *glShared, const RuntimeHint& hint) {
#ifdef LOG_VERBOSE #ifdef LOG_VERBOSE
MNN_PRINT("start OpenCLRuntime !\n"); MNN_PRINT("start OpenCLRuntime !\n");
#endif #endif
@ -271,8 +271,7 @@ OpenCLRuntime::OpenCLRuntime(const BackendConfig::PrecisionMode precision, const
uint32_t MaxRecordableQueueSize = mFirstGPUDevicePtr->getInfo<CL_DEVICE_RECORDABLE_QUEUE_MAX_SIZE>(); uint32_t MaxRecordableQueueSize = mFirstGPUDevicePtr->getInfo<CL_DEVICE_RECORDABLE_QUEUE_MAX_SIZE>();
cl_int err; cl_int err;
if(MaxRecordableQueueSize > 0){ if(MaxRecordableQueueSize > 0){
// TODO: Use setSessionHint to set the number of mUseRecordableQueueSize mUseRecordableQueueSize = hint.encorderNumForCommit;
mUseRecordableQueueSize = MaxRecordableQueueSize;
mUseRecordableQueueSize = MaxRecordableQueueSize < mUseRecordableQueueSize ? MaxRecordableQueueSize : mUseRecordableQueueSize; mUseRecordableQueueSize = MaxRecordableQueueSize < mUseRecordableQueueSize ? MaxRecordableQueueSize : mUseRecordableQueueSize;
mUseRecordQueue = true; mUseRecordQueue = true;
mRecordableQueuePtr = std::make_shared<cl::CommandQueue>(*mContext, *mFirstGPUDevicePtr, CL_QUEUE_RECORDABLE_QCOM, &err); mRecordableQueuePtr = std::make_shared<cl::CommandQueue>(*mContext, *mFirstGPUDevicePtr, CL_QUEUE_RECORDABLE_QCOM, &err);


@ -69,7 +69,7 @@ private:
}; };
class OpenCLRuntime { class OpenCLRuntime {
public: public:
OpenCLRuntime(const BackendConfig::PrecisionMode precision, const int cl_mode, int platformSize, int platformId, int deviceId, void *contextPtr, void *glShared); OpenCLRuntime(const BackendConfig::PrecisionMode precision, const int cl_mode, int platformSize, int platformId, int deviceId, void *contextPtr, void *glShared, const RuntimeHint& hint);
~OpenCLRuntime(); ~OpenCLRuntime();
OpenCLRuntime(const OpenCLRuntime &) = delete; OpenCLRuntime(const OpenCLRuntime &) = delete;
OpenCLRuntime &operator=(const OpenCLRuntime &) = delete; OpenCLRuntime &operator=(const OpenCLRuntime &) = delete;


@ -526,6 +526,9 @@ ErrorCode AttentionBufExecution::onResize(const std::vector<Tensor *> &inputs, c
mOpenCLBackend->startRecord(mRecording); mOpenCLBackend->startRecord(mRecording);
//clear update arg vector, if prefill and decode use the same one //clear update arg vector, if prefill and decode use the same one
mOpRecordUpdateInfo.clear(); mOpRecordUpdateInfo.clear();
mRgQUpdateInfo.update_kernel_args.clear();
mRgQUpdateInfo.update_global_size.clear();
mRgQUpdateInfo.update_local_size.clear();
mRgUpdateInfo.update_kernel_args.clear(); mRgUpdateInfo.update_kernel_args.clear();
mRgUpdateInfo.update_global_size.clear(); mRgUpdateInfo.update_global_size.clear();
mRgUpdateInfo.update_local_size.clear(); mRgUpdateInfo.update_local_size.clear();
@ -615,6 +618,7 @@ ErrorCode AttentionBufExecution::onResize(const std::vector<Tensor *> &inputs, c
mGlobalWorkSizeRearrgQ[0] = ROUND_UP(mGlobalWorkSizeRearrgQ[0], std::max((uint32_t)1, mLocalWorkSizeRearrgQ[0])); mGlobalWorkSizeRearrgQ[0] = ROUND_UP(mGlobalWorkSizeRearrgQ[0], std::max((uint32_t)1, mLocalWorkSizeRearrgQ[0]));
mGlobalWorkSizeRearrgQ[1] = ROUND_UP(mGlobalWorkSizeRearrgQ[1], std::max((uint32_t)1, mLocalWorkSizeRearrgQ[1])); mGlobalWorkSizeRearrgQ[1] = ROUND_UP(mGlobalWorkSizeRearrgQ[1], std::max((uint32_t)1, mLocalWorkSizeRearrgQ[1]));
mGlobalWorkSizeRearrgQ[2] = ROUND_UP(mGlobalWorkSizeRearrgQ[2], std::max((uint32_t)1, mLocalWorkSizeRearrgQ[2])); mGlobalWorkSizeRearrgQ[2] = ROUND_UP(mGlobalWorkSizeRearrgQ[2], std::max((uint32_t)1, mLocalWorkSizeRearrgQ[2]));
mOpRecordUpdateInfo.emplace_back(&mRgQUpdateInfo);
mOpenCLBackend->recordKernel3d(mKernel_rearrangeQ, mGlobalWorkSizeRearrgQ, mLocalWorkSizeRearrgQ); mOpenCLBackend->recordKernel3d(mKernel_rearrangeQ, mGlobalWorkSizeRearrgQ, mLocalWorkSizeRearrgQ);
} }
{ {
@ -724,6 +728,7 @@ ErrorCode AttentionBufExecution::onResize(const std::vector<Tensor *> &inputs, c
mGlobalWorkSizeSoftMax[0] = ROUND_UP(mGlobalWorkSizeSoftMax[0], std::max((uint32_t)1, mLocalWorkSizeSoftMax[0])); mGlobalWorkSizeSoftMax[0] = ROUND_UP(mGlobalWorkSizeSoftMax[0], std::max((uint32_t)1, mLocalWorkSizeSoftMax[0]));
mGlobalWorkSizeSoftMax[1] = ROUND_UP(mGlobalWorkSizeSoftMax[1], std::max((uint32_t)1, mLocalWorkSizeSoftMax[1])); mGlobalWorkSizeSoftMax[1] = ROUND_UP(mGlobalWorkSizeSoftMax[1], std::max((uint32_t)1, mLocalWorkSizeSoftMax[1]));
mGlobalWorkSizeSoftMax[2] = ROUND_UP(mGlobalWorkSizeSoftMax[2], std::max((uint32_t)1, mLocalWorkSizeSoftMax[2])); mGlobalWorkSizeSoftMax[2] = ROUND_UP(mGlobalWorkSizeSoftMax[2], std::max((uint32_t)1, mLocalWorkSizeSoftMax[2]));
mOpRecordUpdateInfo.emplace_back(&mSoftMaxUpdateInfo);
mOpenCLBackend->recordKernel3d(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax); mOpenCLBackend->recordKernel3d(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax);
} }
{ {
@ -1066,6 +1071,9 @@ ErrorCode AttentionBufExecution::UpdateArgs(const std::vector<Tensor *> &inputs,
mQkvUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->value()))(); mQkvUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->value()))();
}else{ }else{
mRgUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->key()))(); mRgUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->key()))();
mQkUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->key()))();
mRgVUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->value()))();
mQkvUpdateInfo.update_kernel_args[0].arg_value = &(*(mKVCacheCLManager->value()))();
} }
} else { } else {
#endif #endif


@ -89,6 +89,7 @@ private:
uint32_t mMaxWorkGroupSize; uint32_t mMaxWorkGroupSize;
OpenCLBackend *mOpenCLBackend; OpenCLBackend *mOpenCLBackend;
RecordUpdateInfo mRgUpdateInfo; RecordUpdateInfo mRgUpdateInfo;
RecordUpdateInfo mRgQUpdateInfo;
RecordUpdateInfo mQkUpdateInfo; RecordUpdateInfo mQkUpdateInfo;
RecordUpdateInfo mSoftMaxUpdateInfo; RecordUpdateInfo mSoftMaxUpdateInfo;
RecordUpdateInfo mRgVUpdateInfo; RecordUpdateInfo mRgVUpdateInfo;


@ -779,31 +779,7 @@ public:
std::vector<int> outputShape = tensorShapeFormat(output); std::vector<int> outputShape = tensorShapeFormat(output);
const int outputChannel = outputShape.at(3); const int outputChannel = outputShape.at(3);
const int inputChannels = inputShape.at(3); const int inputChannels = inputShape.at(3);
#ifdef MNN_LOW_MEMORY
if (static_cast<OpenCLBackend *>(backend)->getMemory() == BackendConfig::Memory_Low){
auto conv2dParams = op->main_as_Convolution2D();
if (conv2dParams->quanParameter() != nullptr) {
if (((conv2dParams->quanParameter()->type() == 4) ||
(conv2dParams->quanParameter()->type() == 1) ||
(conv2dParams->quanParameter()->type() == 2))) {
if ((1 == conv2dParams->quanParameter()->type() || 2 == conv2dParams->quanParameter()->type()) && conv2dParams->quanParameter()->has_scaleInt()) {
// Don't support IDST-int8 because of error
return nullptr;
}
for (int i = 0; i < inputs.size(); ++i) {
TensorUtils::setTensorSupportPack(inputs[i], false);
}
for (int i = 0; i < outputs.size(); ++i) {
TensorUtils::setTensorSupportPack(outputs[i], false);
}
return new ConvBufLowMemoryExecution(inputs, outputs, op, backend);
} else {
//MNN_ERROR("OpenCL Conv buf low memory init error. For Opencl Backend, only support low memory mode of int8 or int4 dequantization currently.\n");
return nullptr;
}
}
}
#endif
if (nullptr != op->main_as_Convolution2D()->quanParameter()) { if (nullptr != op->main_as_Convolution2D()->quanParameter()) {
auto quan = op->main_as_Convolution2D()->quanParameter(); auto quan = op->main_as_Convolution2D()->quanParameter();
if (1 == quan->type() || 2 == quan->type()) { if (1 == quan->type() || 2 == quan->type()) {
@ -854,6 +830,33 @@ public:
return new ConvSubgroupBuf(inputs, outputs, op, backend); return new ConvSubgroupBuf(inputs, outputs, op, backend);
} }
#endif /* MNN_SUPPORT_INTEL_SUBGROUP */ #endif /* MNN_SUPPORT_INTEL_SUBGROUP */
#ifdef MNN_LOW_MEMORY
if (static_cast<OpenCLBackend *>(backend)->getMemory() == BackendConfig::Memory_Low){
auto conv2dParams = op->main_as_Convolution2D();
if (conv2dParams->quanParameter() != nullptr) {
if (((conv2dParams->quanParameter()->type() == 4) ||
(conv2dParams->quanParameter()->type() == 1) ||
(conv2dParams->quanParameter()->type() == 2))) {
if ((1 == conv2dParams->quanParameter()->type() || 2 == conv2dParams->quanParameter()->type()) && conv2dParams->quanParameter()->has_scaleInt()) {
// Don't support IDST-int8 because of error
return nullptr;
}
for (int i = 0; i < inputs.size(); ++i) {
TensorUtils::setTensorSupportPack(inputs[i], false);
}
for (int i = 0; i < outputs.size(); ++i) {
TensorUtils::setTensorSupportPack(outputs[i], false);
}
return new ConvBufLowMemoryExecution(inputs, outputs, op, backend);
} else {
MNN_ERROR("OpenCL Conv buf low memory init error. For Opencl Backend, only support low memory mode of int8 or int4 dequantization currently.\n");
return nullptr;
}
}
}
#endif
for (int i = 0; i < inputs.size(); ++i) { for (int i = 0; i < inputs.size(); ++i) {
TensorUtils::setTensorSupportPack(inputs[i], false); TensorUtils::setTensorSupportPack(inputs[i], false);
} }


@ -832,7 +832,8 @@ ErrorCode ConvBufLowMemoryExecution::onResize(const std::vector<Tensor *> &input
if(batch == 1){ if(batch == 1){
tuneGemvLowMemory(input, output); tuneGemvLowMemory(input, output);
} else { } else {
        if(batch > 512){ // when batch is large, convert the weight to float and do the GEMM in floating point
if(batch > 128){
useFPWeightGemmLowMemory(input, output); useFPWeightGemmLowMemory(input, output);
mUseFPWeight = true; mUseFPWeight = true;
} else { } else {


@ -5,6 +5,15 @@ else()
FILE(GLOB_RECURSE MNN_Vulkan_SRC ${CMAKE_CURRENT_LIST_DIR}/buffer/* ${CMAKE_CURRENT_LIST_DIR}/component/* ${CMAKE_CURRENT_LIST_DIR}/runtime/* ${CMAKE_CURRENT_LIST_DIR}/vulkan/*) FILE(GLOB_RECURSE MNN_Vulkan_SRC ${CMAKE_CURRENT_LIST_DIR}/buffer/* ${CMAKE_CURRENT_LIST_DIR}/component/* ${CMAKE_CURRENT_LIST_DIR}/runtime/* ${CMAKE_CURRENT_LIST_DIR}/vulkan/*)
endif() endif()
if(MNN_GPU_TIME_PROFILE)
if(APPLE)
        message(STATUS "Time profiling for the Vulkan backend is currently not supported on Apple systems.")
else()
add_definitions(-DENABLE_VULKAN_TIME_PROFILE)
endif()
ENDIF()
include_directories("./") include_directories("./")
if(MNN_USE_SYSTEM_LIB) if(MNN_USE_SYSTEM_LIB)
find_package(Vulkan REQUIRED) find_package(Vulkan REQUIRED)
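
With this change both GPU backends key off a single CMake switch: configuring with, e.g., cmake -DMNN_GPU_TIME_PROFILE=ON defines ENABLE_OPENCL_TIME_PROFILER for the OpenCL backend (see the OpenCL CMake hunk above) and ENABLE_VULKAN_TIME_PROFILE here, except on Apple systems where the Vulkan variant is skipped.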


@ -18,6 +18,12 @@
//#define MNN_OPEN_TIME_TRACE //#define MNN_OPEN_TIME_TRACE
#include <MNN/AutoTime.hpp> #include <MNN/AutoTime.hpp>
// #define MNN_OP_SUPPORT_LOG // #define MNN_OP_SUPPORT_LOG
#ifdef ENABLE_VULKAN_TIME_PROFILE
#include <chrono>
#include <unordered_map>
#include <algorithm>
#endif
//#define MNN_VULKAN_DUMP_MEMORY_USAGE //#define MNN_VULKAN_DUMP_MEMORY_USAGE
namespace MNN { namespace MNN {
@ -199,6 +205,9 @@ Execution* VulkanBackend::onCreate(const std::vector<Tensor*>& inputs, const std
return nullptr; return nullptr;
} }
std::shared_ptr<VulkanBasicExecution> originExecution ((VulkanBasicExecution*)iter->second->onCreate(inputs, outputs, op, this)); std::shared_ptr<VulkanBasicExecution> originExecution ((VulkanBasicExecution*)iter->second->onCreate(inputs, outputs, op, this));
#ifdef ENABLE_VULKAN_TIME_PROFILE
originExecution->setName(EnumNameOpType(op->type()));
#endif
if (nullptr == originExecution) { if (nullptr == originExecution) {
#ifdef MNN_OP_SUPPORT_LOG #ifdef MNN_OP_SUPPORT_LOG
MNN_ERROR("Vulkan don't support for %s, type=%s, Special case\n", name.c_str(), EnumNameOpType(op->type())); MNN_ERROR("Vulkan don't support for %s, type=%s, Special case\n", name.c_str(), EnumNameOpType(op->type()));
@ -217,9 +226,26 @@ void VulkanBackend::onExecuteBegin() const {
} }
// FUNC_PRINT_ALL(mDynamicMemoryPool->computeSize(), f); // FUNC_PRINT_ALL(mDynamicMemoryPool->computeSize(), f);
} }
void VulkanBackend::onExecuteEnd() const { void VulkanBackend::onExecuteEnd() const {
#ifdef ENABLE_VULKAN_TIME_PROFILE
auto startTime = std::chrono::high_resolution_clock::now();
_finish();
auto endTime = std::chrono::high_resolution_clock::now();
float totalTime = std::chrono::duration_cast<std::chrono::nanoseconds>(endTime - startTime).count() / (1e6f);
printTimeProfile();
MNN_PRINT("Total time calculated by CPU is %6.2f ms.\n", totalTime);
mQueryPools.clear();
mExecutionNames.clear();
#else
_finish();
#endif
}
void VulkanBackend::finish() {
_finish(); _finish();
} }
void VulkanBackend::_finish() const { void VulkanBackend::_finish() const {
if (mCmdBuffers.empty()) { if (mCmdBuffers.empty()) {
return; return;
@ -450,4 +476,32 @@ std::vector<uint32_t> VulkanBackend::autoTunePipeline(const VulkanPipeline* pipe
return lws_prefer; return lws_prefer;
} }
#ifdef ENABLE_VULKAN_TIME_PROFILE
void VulkanBackend::printTimeProfile() const {
MNN_ASSERT(mQueryPools.size() == mExecutionNames.size());
float timeTotal = 0.0f;
std::unordered_map<std::string, float> timeTable;
MNN_PRINT("Vulkan Time Profiling:\n");
for (int i = 0; i < mQueryPools.size(); i++) {
float timeCurr = mQueryPools[i]->VulkanGetQueryPoolResults();
timeTable[mExecutionNames[i]] += timeCurr;
timeTotal += timeCurr;
MNN_PRINT("%-30s time is %4.2f ms.\n", mExecutionNames[i].c_str(), timeCurr);
}
std::vector<std::pair<std::string, float>> timeVectorForSort(timeTable.begin(), timeTable.end());
std::sort(timeVectorForSort.begin(), timeVectorForSort.end(), [](const std::pair<std::string, float>& a, const std::pair<std::string, float>& b) {
return a.second > b.second;
});
MNN_PRINT("\nSummary:\n");
for (int i = 0; i < timeVectorForSort.size(); i++) {
MNN_PRINT("%-30s time is %4.2f ms.\n", timeVectorForSort[i].first.c_str(), timeVectorForSort[i].second);
}
MNN_PRINT("\nTotal time summed up by commandBuffers is %6.2f ms\n", timeTotal);
}
#endif
} // namespace MNN } // namespace MNN
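
printTimeProfile relies on VulkanQueryPool::VulkanGetQueryPoolResults returning per-dispatch milliseconds. In raw Vulkan terms such a helper plausibly reduces to the following (a sketch, not MNN's actual implementation; timestampPeriod is VkPhysicalDeviceLimits::timestampPeriod, in nanoseconds per tick):

#include <vulkan/vulkan.h>

static float queryElapsedMs(VkDevice device, VkQueryPool pool, float timestampPeriod) {
    uint64_t ticks[2] = {0, 0};
    // Read back the two timestamps written around the dispatch.
    vkGetQueryPoolResults(device, pool, /*firstQuery*/0, /*queryCount*/2,
                          sizeof(ticks), ticks, sizeof(uint64_t),
                          VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT);
    return float(ticks[1] - ticks[0]) * timestampPeriod / 1.0e6f;
}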


@ -27,6 +27,7 @@ public:
const MNN::Op* op) override; const MNN::Op* op) override;
virtual void onExecuteBegin() const override; virtual void onExecuteBegin() const override;
virtual void onExecuteEnd() const override; virtual void onExecuteEnd() const override;
void finish();
virtual bool onSelectDynamicAllocator(int index, int maxIndex) override; virtual bool onSelectDynamicAllocator(int index, int maxIndex) override;
virtual void onResizeBegin() override; virtual void onResizeBegin() override;
virtual ErrorCode onResizeEnd() override; virtual ErrorCode onResizeEnd() override;
@ -94,9 +95,21 @@ public:
void copyToGPUBuffer(const void* src, VkBuffer buffer, VkDeviceSize size, VkDeviceSize offset) const; void copyToGPUBuffer(const void* src, VkBuffer buffer, VkDeviceSize size, VkDeviceSize offset) const;
const VulkanDevice& device() const; const VulkanDevice& device() const;
#ifdef ENABLE_VULKAN_TIME_PROFILE
void pushQueryPool(std::shared_ptr<VulkanQueryPool> queryPool) {
mQueryPools.push_back(queryPool);
}
void pushExecutionName(std::string executionName) {
mExecutionNames.push_back(executionName);
}
#endif
private: private:
void _finish() const; void _finish() const;
void _requireHostBuffer(size_t size) const; void _requireHostBuffer(size_t size) const;
#ifdef ENABLE_VULKAN_TIME_PROFILE
void printTimeProfile() const;
#endif
mutable std::shared_ptr<VulkanBuffer> mHostBuffer; mutable std::shared_ptr<VulkanBuffer> mHostBuffer;
std::shared_ptr<VulkanCommandPool::Buffer> mCmdBuffer; std::shared_ptr<VulkanCommandPool::Buffer> mCmdBuffer;
@ -111,6 +124,11 @@ private:
bool mDirect; bool mDirect;
const VulkanRuntime* mRuntime; const VulkanRuntime* mRuntime;
bool mUseAutoTune = true; bool mUseAutoTune = true;
#ifdef ENABLE_VULKAN_TIME_PROFILE
mutable std::vector<std::shared_ptr<VulkanQueryPool>> mQueryPools;
mutable std::vector<std::string> mExecutionNames;
#endif
}; };


@ -18,12 +18,18 @@ VulkanBasicExecutionDirect::VulkanBasicExecutionDirect(std::shared_ptr<VulkanBas
ErrorCode VulkanBasicExecutionDirect::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) { ErrorCode VulkanBasicExecutionDirect::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
auto extra = static_cast<VulkanBackend *>(backend()); auto extra = static_cast<VulkanBackend *>(backend());
#ifdef ENABLE_VULKAN_TIME_PROFILE
extra->pushExecutionName(mEncoder->getName());
extra->pushQueryPool(mQueryPool);
#endif
extra->pushCommand(mCmdBuffer->get()); extra->pushCommand(mCmdBuffer->get());
return NO_ERROR; return NO_ERROR;
} }
ErrorCode VulkanBasicExecutionDirect::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) { ErrorCode VulkanBasicExecutionDirect::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
mCmdBuffer->begin(0); mCmdBuffer->begin(0);
auto vkBn = static_cast<VulkanBackend*>(backend()); auto vkBn = static_cast<VulkanBackend*>(backend());
for (auto input : inputs) { for (auto input : inputs) {
auto des = TensorUtils::getDescribe(input); auto des = TensorUtils::getDescribe(input);
@ -38,11 +44,22 @@ ErrorCode VulkanBasicExecutionDirect::onResize(const std::vector<Tensor *> &inpu
auto offset = des->extra.offset; auto offset = des->extra.offset;
mCmdBuffer->barrierSource(vkTensor->buffer(), offset, vkBn->getTensorSize(input)); mCmdBuffer->barrierSource(vkTensor->buffer(), offset, vkBn->getTensorSize(input));
} }
#ifdef ENABLE_VULKAN_TIME_PROFILE
mQueryPool.reset(new VulkanQueryPool(vkBn->device()));
mQueryPool->VulkanCmdResetQueryPool(mCmdBuffer.get()->get());
mQueryPool->VulkanCmdWriteTimestamp(mCmdBuffer.get()->get(), 0);
auto code = mEncoder->onEncode(inputs, outputs, mCmdBuffer.get()); auto code = mEncoder->onEncode(inputs, outputs, mCmdBuffer.get());
mQueryPool->VulkanCmdWriteTimestamp(mCmdBuffer.get()->get(), 1);
#else
auto code = mEncoder->onEncode(inputs, outputs, mCmdBuffer.get());
#endif
mCmdBuffer->end(); mCmdBuffer->end();
return code; return code;
} }
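
The profiling wrap above, expressed in raw Vulkan calls for orientation (a sketch; MNN routes these through VulkanQueryPool rather than calling them directly):

#include <vulkan/vulkan.h>

static void wrapWithTimestamps(VkCommandBuffer cmd, VkQueryPool pool) {
    vkCmdResetQueryPool(cmd, pool, /*firstQuery*/0, /*queryCount*/2);
    vkCmdWriteTimestamp(cmd, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, pool, 0);
    // ... mEncoder->onEncode() records the execution's dispatches here ...
    vkCmdWriteTimestamp(cmd, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT, pool, 1);
}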
VulkanBasicExecutionInDirect::VulkanBasicExecutionInDirect(std::shared_ptr<VulkanBasicExecution> encoder) : Execution(encoder->backend()) { VulkanBasicExecutionInDirect::VulkanBasicExecutionInDirect(std::shared_ptr<VulkanBasicExecution> encoder) : Execution(encoder->backend()) {
mEncoder = encoder; mEncoder = encoder;
} }


@ -29,6 +29,16 @@ public:
virtual bool onClone(Backend* bn, const Op* op, VulkanBasicExecution** dst) { virtual bool onClone(Backend* bn, const Op* op, VulkanBasicExecution** dst) {
return false; return false;
} }
#ifdef ENABLE_VULKAN_TIME_PROFILE
void setName(const char * name) {
mName = name;
}
std::string getName() {
return mName;
}
protected:
std::string mName = "General_Execution";
#endif
private: private:
Backend* mBackend; Backend* mBackend;
}; };
@ -56,6 +66,9 @@ public:
private: private:
std::shared_ptr<VulkanBasicExecution> mEncoder; std::shared_ptr<VulkanBasicExecution> mEncoder;
std::shared_ptr<VulkanCommandPool::Buffer> mCmdBuffer; std::shared_ptr<VulkanCommandPool::Buffer> mCmdBuffer;
#ifdef ENABLE_VULKAN_TIME_PROFILE
std::shared_ptr<VulkanQueryPool> mQueryPool;
#endif
}; };
class VulkanBasicExecutionInDirect : public Execution { class VulkanBasicExecutionInDirect : public Execution {
public: public:


@ -109,7 +109,7 @@ public:
mKernelReorder.exe->onEncode({}, {mKernel.get()}, prearrangeCmd.get()); mKernelReorder.exe->onEncode({}, {mKernel.get()}, prearrangeCmd.get());
prearrangeCmd->end(); prearrangeCmd->end();
vkBn->pushCommand(prearrangeCmd->get()); vkBn->pushCommand(prearrangeCmd->get());
vkBn->onExecuteEnd(); vkBn->finish();
mKernelReorder.exe = nullptr; mKernelReorder.exe = nullptr;
} }
} }


@ -133,7 +133,7 @@ VulkanDeconvolution* VulkanDeconvolution::create(Backend* bn, const Op* op, OpTy
exeRes->mKernelReorder.exe->onEncode({}, {exeRes->mKernel.get()}, prearrangeCmd.get()); exeRes->mKernelReorder.exe->onEncode({}, {exeRes->mKernel.get()}, prearrangeCmd.get());
prearrangeCmd->end(); prearrangeCmd->end();
vkBn->pushCommand(prearrangeCmd->get()); vkBn->pushCommand(prearrangeCmd->get());
vkBn->onExecuteEnd(); vkBn->finish();
exeRes->mKernelReorder.exe = nullptr; exeRes->mKernelReorder.exe = nullptr;
} }
std::vector<VkDescriptorType> types{ std::vector<VkDescriptorType> types{
@ -25,6 +25,13 @@
// #define MNN_OP_SUPPORT_LOG // #define MNN_OP_SUPPORT_LOG
//#define MNN_VULKAN_DUMP_MEMORY_USAGE //#define MNN_VULKAN_DUMP_MEMORY_USAGE
#define MNN_VULKAN_MAX_CACHE_CONVSIZE 50 #define MNN_VULKAN_MAX_CACHE_CONVSIZE 50
#ifdef ENABLE_VULKAN_TIME_PROFILE
#include <chrono>
#include <unordered_map>
#include <algorithm>
#endif
namespace MNN { namespace MNN {
static std::map<OpType, VulkanBackend::Creator*>* gCreator = nullptr; static std::map<OpType, VulkanBackend::Creator*>* gCreator = nullptr;
@ -224,6 +231,9 @@ Execution* VulkanBackend::onCreate(const std::vector<Tensor*>& inputs, const std
return nullptr; return nullptr;
} }
auto originExecution = (VulkanBasicExecution*)iter->second->onCreate(inputs, outputs, op, this); auto originExecution = (VulkanBasicExecution*)iter->second->onCreate(inputs, outputs, op, this);
#ifdef ENABLE_VULKAN_TIME_PROFILE
    if (nullptr != originExecution) {
        originExecution->setName(EnumNameOpType(op->type()));
    }
#endif
if (nullptr == originExecution) { if (nullptr == originExecution) {
#ifdef MNN_OP_SUPPORT_LOG #ifdef MNN_OP_SUPPORT_LOG
MNN_ERROR("Vulkan don't support for %s, type=%s, Special case\n", name.c_str(), EnumNameOpType(op->type())); MNN_ERROR("Vulkan don't support for %s, type=%s, Special case\n", name.c_str(), EnumNameOpType(op->type()));
@ -242,7 +252,23 @@ void VulkanBackend::onExecuteBegin() const {
} }
// FUNC_PRINT_ALL(mDynamicMemoryPool->computeSize(), f); // FUNC_PRINT_ALL(mDynamicMemoryPool->computeSize(), f);
} }
void VulkanBackend::onExecuteEnd() const { void VulkanBackend::onExecuteEnd() const {
#ifdef ENABLE_VULKAN_TIME_PROFILE
auto startTime = std::chrono::high_resolution_clock::now();
_finish();
auto endTime = std::chrono::high_resolution_clock::now();
float totalTime = std::chrono::duration_cast<std::chrono::nanoseconds>(endTime - startTime).count() / (1e6f);
printTimeProfile();
MNN_PRINT("Total time calculated by CPU is %6.2f ms.\n", totalTime);
mQueryPools.clear();
mExecutionNames.clear();
#else
_finish();
#endif
}
void VulkanBackend::finish() {
_finish(); _finish();
} }
void VulkanBackend::_finish() const { void VulkanBackend::_finish() const {
@ -610,4 +636,33 @@ std::vector<uint32_t> VulkanBackend::autoTunePipeline(SharedPtr<VulkanPipeline>
#ifdef ENABLE_VULKAN_TIME_PROFILE
void VulkanBackend::printTimeProfile() const {
MNN_ASSERT(mQueryPools.size() == mExecutionNames.size());
float timeTotal = 0.0f;
std::unordered_map<std::string, float> timeTable;
MNN_PRINT("Vulkan Time Profiling:\n");
for (int i = 0; i < mQueryPools.size(); i++) {
float timeCurr = mQueryPools[i]->VulkanGetQueryPoolResults();
timeTable[mExecutionNames[i]] += timeCurr;
timeTotal += timeCurr;
MNN_PRINT("%-30s time is %4.2f ms.\n", mExecutionNames[i].c_str(), timeCurr);
}
std::vector<std::pair<std::string, float>> timeVectorForSort(timeTable.begin(), timeTable.end());
std::sort(timeVectorForSort.begin(), timeVectorForSort.end(), [](const std::pair<std::string, float>& a, const std::pair<std::string, float>& b) {
return a.second > b.second;
});
MNN_PRINT("\nSummary:\n");
for (int i = 0; i < timeVectorForSort.size(); i++) {
MNN_PRINT("%-30s time is %4.2f ms.\n", timeVectorForSort[i].first.c_str(), timeVectorForSort[i].second);
}
MNN_PRINT("\nTotal time summed up by commandBuffers is %6.2f ms\n", timeTotal);
}
#endif
} // namespace MNN } // namespace MNN
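Note: printTimeProfile() above depends on VulkanQueryPool::VulkanGetQueryPoolResults() returning each pair's elapsed time in milliseconds. A hedged sketch of what that conversion has to do (the helper name and body here are assumptions, not MNN's actual implementation; timestampPeriod comes from VkPhysicalDeviceLimits and is nanoseconds per tick):

    float queryPoolMs(VkDevice device, VkQueryPool pool, float timestampPeriod) {
        uint64_t ts[2] = {0, 0};
        vkGetQueryPoolResults(device, pool, 0, 2, sizeof(ts), ts, sizeof(uint64_t),
                              VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT);
        // Tick delta times nanoseconds-per-tick, scaled to milliseconds.
        return (ts[1] - ts[0]) * timestampPeriod / 1e6f;
    }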
@ -33,6 +33,7 @@ public:
const MNN::Op* op) override; const MNN::Op* op) override;
virtual void onExecuteBegin() const override; virtual void onExecuteBegin() const override;
virtual void onExecuteEnd() const override; virtual void onExecuteEnd() const override;
void finish();
virtual void onResizeBegin() override; virtual void onResizeBegin() override;
virtual ErrorCode onResizeEnd() override; virtual ErrorCode onResizeEnd() override;
virtual void onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTensor) const override; virtual void onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTensor) const override;
@ -87,11 +88,24 @@ public:
float getPipelineTime(const VulkanPipeline* pipeline, std::shared_ptr<VulkanLayout::DescriptorSet> des, std::vector<uint32_t> groupSize); float getPipelineTime(const VulkanPipeline* pipeline, std::shared_ptr<VulkanLayout::DescriptorSet> des, std::vector<uint32_t> groupSize);
const VulkanDevice& device() const;
#ifdef ENABLE_VULKAN_TIME_PROFILE
void pushQueryPool(std::shared_ptr<VulkanQueryPool> queryPool) {
mQueryPools.push_back(queryPool);
}
void pushExecutionName(std::string executionName) {
mExecutionNames.push_back(executionName);
}
#endif
private: private:
bool _supportImageSize(const Tensor* tensor); bool _supportImageSize(const Tensor* tensor);
const VulkanDevice& device() const;
void _finish() const; void _finish() const;
void _allocHostBuffer(size_t size) const; void _allocHostBuffer(size_t size) const;
#ifdef ENABLE_VULKAN_TIME_PROFILE
void printTimeProfile() const;
#endif
std::shared_ptr<VulkanCommandPool::Buffer> mCmdBuffer; std::shared_ptr<VulkanCommandPool::Buffer> mCmdBuffer;
std::shared_ptr<VulkanCommandPool::Buffer> mInitBuffer; std::shared_ptr<VulkanCommandPool::Buffer> mInitBuffer;
@ -108,6 +122,11 @@ private:
bool mDirect; bool mDirect;
const VulkanRuntime* mRuntime; const VulkanRuntime* mRuntime;
std::shared_ptr<VulkanMemoryPool> mDynamicMemoryPool; std::shared_ptr<VulkanMemoryPool> mDynamicMemoryPool;
#ifdef ENABLE_VULKAN_TIME_PROFILE
mutable std::vector<std::shared_ptr<VulkanQueryPool>> mQueryPools;
mutable std::vector<std::string> mExecutionNames;
#endif
}; };
@ -2125,321 +2125,6 @@ const unsigned char glsl_softmaxImage_AXIS_C_comp[] = {
}; };
unsigned int glsl_softmaxImage_AXIS_C_comp_len = 5204; unsigned int glsl_softmaxImage_AXIS_C_comp_len = 5204;
const unsigned char glsl_convolutionDepthwiseMali_comp[] = {
0x03, 0x02, 0x23, 0x07, 0x00, 0x00, 0x01, 0x00, 0x0b, 0x00, 0x08, 0x00,
0xec, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x11, 0x00, 0x02, 0x00,
0x01, 0x00, 0x00, 0x00, 0x11, 0x00, 0x02, 0x00, 0x38, 0x00, 0x00, 0x00,
0x0b, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x00, 0x47, 0x4c, 0x53, 0x4c,
0x2e, 0x73, 0x74, 0x64, 0x2e, 0x34, 0x35, 0x30, 0x00, 0x00, 0x00, 0x00,
0x0e, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x0f, 0x00, 0x06, 0x00, 0x05, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00,
0x10, 0x00, 0x06, 0x00, 0x04, 0x00, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00,
0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x03, 0x00, 0x03, 0x00, 0x02, 0x00, 0x00, 0x00, 0xb8, 0x01, 0x00, 0x00,
0x05, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e,
0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x08, 0x00, 0x0d, 0x00, 0x00, 0x00,
0x67, 0x6c, 0x5f, 0x47, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x49, 0x6e, 0x76,
0x6f, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x49, 0x44, 0x00, 0x00, 0x00,
0x05, 0x00, 0x05, 0x00, 0x13, 0x00, 0x00, 0x00, 0x63, 0x6f, 0x6e, 0x73,
0x74, 0x42, 0x75, 0x66, 0x66, 0x65, 0x72, 0x00, 0x06, 0x00, 0x04, 0x00,
0x13, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x70, 0x61, 0x64, 0x00,
0x06, 0x00, 0x06, 0x00, 0x13, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x53, 0x69, 0x7a, 0x65, 0x00, 0x00,
0x06, 0x00, 0x05, 0x00, 0x13, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
0x73, 0x74, 0x72, 0x69, 0x64, 0x65, 0x00, 0x00, 0x06, 0x00, 0x05, 0x00,
0x13, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x64, 0x69, 0x6c, 0x61,
0x74, 0x65, 0x00, 0x00, 0x06, 0x00, 0x06, 0x00, 0x13, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x53, 0x69, 0x7a,
0x65, 0x00, 0x00, 0x00, 0x06, 0x00, 0x06, 0x00, 0x13, 0x00, 0x00, 0x00,
0x05, 0x00, 0x00, 0x00, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x53, 0x69,
0x7a, 0x65, 0x00, 0x00, 0x06, 0x00, 0x05, 0x00, 0x13, 0x00, 0x00, 0x00,
0x06, 0x00, 0x00, 0x00, 0x6f, 0x66, 0x66, 0x73, 0x65, 0x74, 0x00, 0x00,
0x05, 0x00, 0x05, 0x00, 0x15, 0x00, 0x00, 0x00, 0x75, 0x43, 0x6f, 0x6e,
0x73, 0x74, 0x61, 0x6e, 0x74, 0x00, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00,
0x6d, 0x00, 0x00, 0x00, 0x75, 0x42, 0x69, 0x61, 0x73, 0x00, 0x00, 0x00,
0x05, 0x00, 0x04, 0x00, 0x9e, 0x00, 0x00, 0x00, 0x75, 0x49, 0x6e, 0x70,
0x75, 0x74, 0x00, 0x00, 0x05, 0x00, 0x04, 0x00, 0xb0, 0x00, 0x00, 0x00,
0x75, 0x4b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x00, 0x05, 0x00, 0x04, 0x00,
0xc7, 0x00, 0x00, 0x00, 0x75, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x00,
0x47, 0x00, 0x04, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00,
0x1c, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, 0x13, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x48, 0x00, 0x05, 0x00, 0x13, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x23, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00,
0x13, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00,
0x10, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, 0x13, 0x00, 0x00, 0x00,
0x03, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00,
0x48, 0x00, 0x05, 0x00, 0x13, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x23, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00,
0x13, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00,
0x30, 0x00, 0x00, 0x00, 0x48, 0x00, 0x05, 0x00, 0x13, 0x00, 0x00, 0x00,
0x06, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00,
0x47, 0x00, 0x03, 0x00, 0x13, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
0x47, 0x00, 0x04, 0x00, 0x15, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x15, 0x00, 0x00, 0x00,
0x21, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00,
0x6d, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x47, 0x00, 0x04, 0x00, 0x6d, 0x00, 0x00, 0x00, 0x21, 0x00, 0x00, 0x00,
0x03, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0x9e, 0x00, 0x00, 0x00,
0x22, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00,
0x9e, 0x00, 0x00, 0x00, 0x21, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x47, 0x00, 0x04, 0x00, 0xb0, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0xb0, 0x00, 0x00, 0x00,
0x21, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00,
0xc7, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x47, 0x00, 0x04, 0x00, 0xc7, 0x00, 0x00, 0x00, 0x21, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x47, 0x00, 0x03, 0x00, 0xc7, 0x00, 0x00, 0x00,
0x19, 0x00, 0x00, 0x00, 0x47, 0x00, 0x04, 0x00, 0xda, 0x00, 0x00, 0x00,
0x0b, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, 0x13, 0x00, 0x02, 0x00,
0x02, 0x00, 0x00, 0x00, 0x21, 0x00, 0x03, 0x00, 0x03, 0x00, 0x00, 0x00,
0x02, 0x00, 0x00, 0x00, 0x15, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00,
0x20, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x17, 0x00, 0x04, 0x00,
0x07, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
0x15, 0x00, 0x04, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x17, 0x00, 0x04, 0x00, 0x0b, 0x00, 0x00, 0x00,
0x0a, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00,
0x0c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00,
0x3b, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x17, 0x00, 0x04, 0x00, 0x11, 0x00, 0x00, 0x00,
0x06, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x17, 0x00, 0x04, 0x00,
0x12, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x1e, 0x00, 0x09, 0x00, 0x13, 0x00, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00,
0x11, 0x00, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00,
0x12, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00,
0x20, 0x00, 0x04, 0x00, 0x14, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
0x13, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x14, 0x00, 0x00, 0x00,
0x15, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00,
0x06, 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00,
0x20, 0x00, 0x04, 0x00, 0x17, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
0x12, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, 0x0a, 0x00, 0x00, 0x00,
0x1d, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00,
0x20, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00,
0x14, 0x00, 0x02, 0x00, 0x2e, 0x00, 0x00, 0x00, 0x17, 0x00, 0x04, 0x00,
0x2f, 0x00, 0x00, 0x00, 0x2e, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
0x2b, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x35, 0x00, 0x00, 0x00,
0x04, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00,
0x3d, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x20, 0x00, 0x04, 0x00,
0x3e, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00,
0x2b, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x42, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x05, 0x00, 0x11, 0x00, 0x00, 0x00,
0x47, 0x00, 0x00, 0x00, 0x42, 0x00, 0x00, 0x00, 0x42, 0x00, 0x00, 0x00,
0x2b, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x4a, 0x00, 0x00, 0x00,
0x03, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00,
0x4e, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x16, 0x00, 0x03, 0x00,
0x66, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x17, 0x00, 0x04, 0x00,
0x67, 0x00, 0x00, 0x00, 0x66, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
0x19, 0x00, 0x09, 0x00, 0x6a, 0x00, 0x00, 0x00, 0x66, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x1b, 0x00, 0x03, 0x00, 0x6b, 0x00, 0x00, 0x00, 0x6a, 0x00, 0x00, 0x00,
0x20, 0x00, 0x04, 0x00, 0x6c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x6b, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x6c, 0x00, 0x00, 0x00,
0x6d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00,
0x0a, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x2b, 0x00, 0x04, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x89, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0x6c, 0x00, 0x00, 0x00,
0x9e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00,
0x6c, 0x00, 0x00, 0x00, 0xb0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x19, 0x00, 0x09, 0x00, 0xc5, 0x00, 0x00, 0x00, 0x66, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x20, 0x00, 0x04, 0x00, 0xc6, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0xc5, 0x00, 0x00, 0x00, 0x3b, 0x00, 0x04, 0x00, 0xc6, 0x00, 0x00, 0x00,
0xc7, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x04, 0x00,
0x0a, 0x00, 0x00, 0x00, 0xd9, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00,
0x2c, 0x00, 0x06, 0x00, 0x0b, 0x00, 0x00, 0x00, 0xda, 0x00, 0x00, 0x00,
0xd9, 0x00, 0x00, 0x00, 0xd9, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00,
0x2e, 0x00, 0x03, 0x00, 0x07, 0x00, 0x00, 0x00, 0xe5, 0x00, 0x00, 0x00,
0x2c, 0x00, 0x05, 0x00, 0x11, 0x00, 0x00, 0x00, 0xe6, 0x00, 0x00, 0x00,
0x4e, 0x00, 0x00, 0x00, 0x4e, 0x00, 0x00, 0x00, 0x36, 0x00, 0x05, 0x00,
0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x03, 0x00, 0x00, 0x00, 0xf8, 0x00, 0x02, 0x00, 0x05, 0x00, 0x00, 0x00,
0x3d, 0x00, 0x04, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x00, 0x00,
0x0d, 0x00, 0x00, 0x00, 0x7c, 0x00, 0x04, 0x00, 0x07, 0x00, 0x00, 0x00,
0x0f, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x41, 0x00, 0x05, 0x00,
0x17, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00,
0x16, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x12, 0x00, 0x00, 0x00,
0x19, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x51, 0x00, 0x05, 0x00,
0x06, 0x00, 0x00, 0x00, 0x1f, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00,
0x02, 0x00, 0x00, 0x00, 0x41, 0x00, 0x06, 0x00, 0x20, 0x00, 0x00, 0x00,
0x21, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00,
0x1d, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00,
0x22, 0x00, 0x00, 0x00, 0x21, 0x00, 0x00, 0x00, 0x8b, 0x00, 0x05, 0x00,
0x06, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x1f, 0x00, 0x00, 0x00,
0x22, 0x00, 0x00, 0x00, 0x87, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00,
0x29, 0x00, 0x00, 0x00, 0x1f, 0x00, 0x00, 0x00, 0x22, 0x00, 0x00, 0x00,
0x4f, 0x00, 0x07, 0x00, 0x11, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x00, 0x00,
0x0f, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x4f, 0x00, 0x07, 0x00, 0x11, 0x00, 0x00, 0x00,
0x2d, 0x00, 0x00, 0x00, 0x19, 0x00, 0x00, 0x00, 0xe5, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0xb1, 0x00, 0x05, 0x00,
0x2f, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x00, 0x00,
0x2d, 0x00, 0x00, 0x00, 0x9b, 0x00, 0x04, 0x00, 0x2e, 0x00, 0x00, 0x00,
0x31, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, 0xf7, 0x00, 0x03, 0x00,
0x33, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfa, 0x00, 0x04, 0x00,
0x31, 0x00, 0x00, 0x00, 0x32, 0x00, 0x00, 0x00, 0x33, 0x00, 0x00, 0x00,
0xf8, 0x00, 0x02, 0x00, 0x32, 0x00, 0x00, 0x00, 0x41, 0x00, 0x05, 0x00,
0x3e, 0x00, 0x00, 0x00, 0x3f, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00,
0x3d, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x11, 0x00, 0x00, 0x00,
0x40, 0x00, 0x00, 0x00, 0x3f, 0x00, 0x00, 0x00, 0x84, 0x00, 0x05, 0x00,
0x11, 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x00, 0x00,
0x40, 0x00, 0x00, 0x00, 0x41, 0x00, 0x05, 0x00, 0x3e, 0x00, 0x00, 0x00,
0x43, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00, 0x42, 0x00, 0x00, 0x00,
0x3d, 0x00, 0x04, 0x00, 0x11, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00,
0x43, 0x00, 0x00, 0x00, 0x82, 0x00, 0x05, 0x00, 0x11, 0x00, 0x00, 0x00,
0x45, 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00,
0x7e, 0x00, 0x04, 0x00, 0x11, 0x00, 0x00, 0x00, 0x49, 0x00, 0x00, 0x00,
0x45, 0x00, 0x00, 0x00, 0x41, 0x00, 0x05, 0x00, 0x3e, 0x00, 0x00, 0x00,
0x4b, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00, 0x4a, 0x00, 0x00, 0x00,
0x3d, 0x00, 0x04, 0x00, 0x11, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00,
0x4b, 0x00, 0x00, 0x00, 0x80, 0x00, 0x05, 0x00, 0x11, 0x00, 0x00, 0x00,
0x4d, 0x00, 0x00, 0x00, 0x49, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00,
0x82, 0x00, 0x05, 0x00, 0x11, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00,
0x4d, 0x00, 0x00, 0x00, 0xe6, 0x00, 0x00, 0x00, 0x87, 0x00, 0x05, 0x00,
0x11, 0x00, 0x00, 0x00, 0x53, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00,
0x4c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x07, 0x00, 0x11, 0x00, 0x00, 0x00,
0x54, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x2a, 0x00, 0x00, 0x00,
0x47, 0x00, 0x00, 0x00, 0x53, 0x00, 0x00, 0x00, 0x41, 0x00, 0x05, 0x00,
0x3e, 0x00, 0x00, 0x00, 0x56, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00,
0x4e, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x11, 0x00, 0x00, 0x00,
0x57, 0x00, 0x00, 0x00, 0x56, 0x00, 0x00, 0x00, 0x41, 0x00, 0x05, 0x00,
0x17, 0x00, 0x00, 0x00, 0x58, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00,
0x35, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x12, 0x00, 0x00, 0x00,
0x59, 0x00, 0x00, 0x00, 0x58, 0x00, 0x00, 0x00, 0x4f, 0x00, 0x07, 0x00,
0x11, 0x00, 0x00, 0x00, 0x5a, 0x00, 0x00, 0x00, 0x59, 0x00, 0x00, 0x00,
0x59, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x82, 0x00, 0x05, 0x00, 0x11, 0x00, 0x00, 0x00, 0x5c, 0x00, 0x00, 0x00,
0x5a, 0x00, 0x00, 0x00, 0x45, 0x00, 0x00, 0x00, 0x80, 0x00, 0x05, 0x00,
0x11, 0x00, 0x00, 0x00, 0x5f, 0x00, 0x00, 0x00, 0x5c, 0x00, 0x00, 0x00,
0x4c, 0x00, 0x00, 0x00, 0x82, 0x00, 0x05, 0x00, 0x11, 0x00, 0x00, 0x00,
0x61, 0x00, 0x00, 0x00, 0x5f, 0x00, 0x00, 0x00, 0xe6, 0x00, 0x00, 0x00,
0x87, 0x00, 0x05, 0x00, 0x11, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00,
0x61, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x07, 0x00,
0x11, 0x00, 0x00, 0x00, 0x65, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0x27, 0x00, 0x00, 0x00, 0x57, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00,
0x3d, 0x00, 0x04, 0x00, 0x6b, 0x00, 0x00, 0x00, 0x6e, 0x00, 0x00, 0x00,
0x6d, 0x00, 0x00, 0x00, 0x50, 0x00, 0x05, 0x00, 0x11, 0x00, 0x00, 0x00,
0x70, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x42, 0x00, 0x00, 0x00,
0x64, 0x00, 0x04, 0x00, 0x6a, 0x00, 0x00, 0x00, 0x71, 0x00, 0x00, 0x00,
0x6e, 0x00, 0x00, 0x00, 0x5f, 0x00, 0x07, 0x00, 0x67, 0x00, 0x00, 0x00,
0x72, 0x00, 0x00, 0x00, 0x71, 0x00, 0x00, 0x00, 0x70, 0x00, 0x00, 0x00,
0x02, 0x00, 0x00, 0x00, 0x42, 0x00, 0x00, 0x00, 0x51, 0x00, 0x05, 0x00,
0x06, 0x00, 0x00, 0x00, 0x76, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0xf9, 0x00, 0x02, 0x00, 0x77, 0x00, 0x00, 0x00,
0xf8, 0x00, 0x02, 0x00, 0x77, 0x00, 0x00, 0x00, 0xf5, 0x00, 0x07, 0x00,
0x67, 0x00, 0x00, 0x00, 0xe8, 0x00, 0x00, 0x00, 0x72, 0x00, 0x00, 0x00,
0x32, 0x00, 0x00, 0x00, 0xeb, 0x00, 0x00, 0x00, 0x7a, 0x00, 0x00, 0x00,
0xf5, 0x00, 0x07, 0x00, 0x06, 0x00, 0x00, 0x00, 0xe7, 0x00, 0x00, 0x00,
0x76, 0x00, 0x00, 0x00, 0x32, 0x00, 0x00, 0x00, 0xc4, 0x00, 0x00, 0x00,
0x7a, 0x00, 0x00, 0x00, 0x51, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00,
0x7e, 0x00, 0x00, 0x00, 0x65, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
0xb1, 0x00, 0x05, 0x00, 0x2e, 0x00, 0x00, 0x00, 0x7f, 0x00, 0x00, 0x00,
0xe7, 0x00, 0x00, 0x00, 0x7e, 0x00, 0x00, 0x00, 0xf6, 0x00, 0x04, 0x00,
0x79, 0x00, 0x00, 0x00, 0x7a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0xfa, 0x00, 0x04, 0x00, 0x7f, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00,
0x79, 0x00, 0x00, 0x00, 0xf8, 0x00, 0x02, 0x00, 0x78, 0x00, 0x00, 0x00,
0x41, 0x00, 0x06, 0x00, 0x20, 0x00, 0x00, 0x00, 0x82, 0x00, 0x00, 0x00,
0x15, 0x00, 0x00, 0x00, 0x4a, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00,
0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x83, 0x00, 0x00, 0x00,
0x82, 0x00, 0x00, 0x00, 0x84, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00,
0x84, 0x00, 0x00, 0x00, 0xe7, 0x00, 0x00, 0x00, 0x83, 0x00, 0x00, 0x00,
0x51, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x86, 0x00, 0x00, 0x00,
0x45, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x80, 0x00, 0x05, 0x00,
0x06, 0x00, 0x00, 0x00, 0x87, 0x00, 0x00, 0x00, 0x84, 0x00, 0x00, 0x00,
0x86, 0x00, 0x00, 0x00, 0x51, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00,
0x8b, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0xf9, 0x00, 0x02, 0x00, 0x8c, 0x00, 0x00, 0x00, 0xf8, 0x00, 0x02, 0x00,
0x8c, 0x00, 0x00, 0x00, 0xf5, 0x00, 0x07, 0x00, 0x67, 0x00, 0x00, 0x00,
0xeb, 0x00, 0x00, 0x00, 0xe8, 0x00, 0x00, 0x00, 0x78, 0x00, 0x00, 0x00,
0xc0, 0x00, 0x00, 0x00, 0x8d, 0x00, 0x00, 0x00, 0xf5, 0x00, 0x07, 0x00,
0x06, 0x00, 0x00, 0x00, 0xe9, 0x00, 0x00, 0x00, 0x8b, 0x00, 0x00, 0x00,
0x78, 0x00, 0x00, 0x00, 0xc2, 0x00, 0x00, 0x00, 0x8d, 0x00, 0x00, 0x00,
0x51, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x93, 0x00, 0x00, 0x00,
0x65, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xb1, 0x00, 0x05, 0x00,
0x2e, 0x00, 0x00, 0x00, 0x94, 0x00, 0x00, 0x00, 0xe9, 0x00, 0x00, 0x00,
0x93, 0x00, 0x00, 0x00, 0xf6, 0x00, 0x04, 0x00, 0x8e, 0x00, 0x00, 0x00,
0x8d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfa, 0x00, 0x04, 0x00,
0x94, 0x00, 0x00, 0x00, 0x8d, 0x00, 0x00, 0x00, 0x8e, 0x00, 0x00, 0x00,
0xf8, 0x00, 0x02, 0x00, 0x8d, 0x00, 0x00, 0x00, 0x41, 0x00, 0x06, 0x00,
0x20, 0x00, 0x00, 0x00, 0x97, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00,
0x4a, 0x00, 0x00, 0x00, 0x89, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00,
0x06, 0x00, 0x00, 0x00, 0x98, 0x00, 0x00, 0x00, 0x97, 0x00, 0x00, 0x00,
0x84, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x99, 0x00, 0x00, 0x00,
0xe9, 0x00, 0x00, 0x00, 0x98, 0x00, 0x00, 0x00, 0x51, 0x00, 0x05, 0x00,
0x06, 0x00, 0x00, 0x00, 0x9b, 0x00, 0x00, 0x00, 0x45, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00,
0x9c, 0x00, 0x00, 0x00, 0x99, 0x00, 0x00, 0x00, 0x9b, 0x00, 0x00, 0x00,
0x3d, 0x00, 0x04, 0x00, 0x6b, 0x00, 0x00, 0x00, 0x9f, 0x00, 0x00, 0x00,
0x9e, 0x00, 0x00, 0x00, 0x41, 0x00, 0x06, 0x00, 0x20, 0x00, 0x00, 0x00,
0xa2, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00, 0x35, 0x00, 0x00, 0x00,
0x89, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00,
0xa3, 0x00, 0x00, 0x00, 0xa2, 0x00, 0x00, 0x00, 0x84, 0x00, 0x05, 0x00,
0x06, 0x00, 0x00, 0x00, 0xa4, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00,
0xa3, 0x00, 0x00, 0x00, 0x80, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00,
0xa5, 0x00, 0x00, 0x00, 0x9c, 0x00, 0x00, 0x00, 0xa4, 0x00, 0x00, 0x00,
0x41, 0x00, 0x06, 0x00, 0x20, 0x00, 0x00, 0x00, 0xa8, 0x00, 0x00, 0x00,
0x15, 0x00, 0x00, 0x00, 0x35, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00,
0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0xa9, 0x00, 0x00, 0x00,
0xa8, 0x00, 0x00, 0x00, 0x84, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00,
0xaa, 0x00, 0x00, 0x00, 0x29, 0x00, 0x00, 0x00, 0xa9, 0x00, 0x00, 0x00,
0x80, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0xab, 0x00, 0x00, 0x00,
0x87, 0x00, 0x00, 0x00, 0xaa, 0x00, 0x00, 0x00, 0x50, 0x00, 0x05, 0x00,
0x11, 0x00, 0x00, 0x00, 0xac, 0x00, 0x00, 0x00, 0xa5, 0x00, 0x00, 0x00,
0xab, 0x00, 0x00, 0x00, 0x64, 0x00, 0x04, 0x00, 0x6a, 0x00, 0x00, 0x00,
0xad, 0x00, 0x00, 0x00, 0x9f, 0x00, 0x00, 0x00, 0x5f, 0x00, 0x07, 0x00,
0x67, 0x00, 0x00, 0x00, 0xae, 0x00, 0x00, 0x00, 0xad, 0x00, 0x00, 0x00,
0xac, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x42, 0x00, 0x00, 0x00,
0x3d, 0x00, 0x04, 0x00, 0x6b, 0x00, 0x00, 0x00, 0xb1, 0x00, 0x00, 0x00,
0xb0, 0x00, 0x00, 0x00, 0x41, 0x00, 0x06, 0x00, 0x20, 0x00, 0x00, 0x00,
0xb4, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00, 0x4e, 0x00, 0x00, 0x00,
0x89, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00,
0xb5, 0x00, 0x00, 0x00, 0xb4, 0x00, 0x00, 0x00, 0x84, 0x00, 0x05, 0x00,
0x06, 0x00, 0x00, 0x00, 0xb6, 0x00, 0x00, 0x00, 0xe7, 0x00, 0x00, 0x00,
0xb5, 0x00, 0x00, 0x00, 0x80, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00,
0xb7, 0x00, 0x00, 0x00, 0xe9, 0x00, 0x00, 0x00, 0xb6, 0x00, 0x00, 0x00,
0x50, 0x00, 0x05, 0x00, 0x11, 0x00, 0x00, 0x00, 0xb9, 0x00, 0x00, 0x00,
0xb7, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x64, 0x00, 0x04, 0x00,
0x6a, 0x00, 0x00, 0x00, 0xba, 0x00, 0x00, 0x00, 0xb1, 0x00, 0x00, 0x00,
0x5f, 0x00, 0x07, 0x00, 0x67, 0x00, 0x00, 0x00, 0xbb, 0x00, 0x00, 0x00,
0xba, 0x00, 0x00, 0x00, 0xb9, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
0x42, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x67, 0x00, 0x00, 0x00,
0xc0, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x32, 0x00, 0x00, 0x00,
0xbb, 0x00, 0x00, 0x00, 0xae, 0x00, 0x00, 0x00, 0xeb, 0x00, 0x00, 0x00,
0x80, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0xc2, 0x00, 0x00, 0x00,
0xe9, 0x00, 0x00, 0x00, 0x4e, 0x00, 0x00, 0x00, 0xf9, 0x00, 0x02, 0x00,
0x8c, 0x00, 0x00, 0x00, 0xf8, 0x00, 0x02, 0x00, 0x8e, 0x00, 0x00, 0x00,
0xf9, 0x00, 0x02, 0x00, 0x7a, 0x00, 0x00, 0x00, 0xf8, 0x00, 0x02, 0x00,
0x7a, 0x00, 0x00, 0x00, 0x80, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00,
0xc4, 0x00, 0x00, 0x00, 0xe7, 0x00, 0x00, 0x00, 0x4e, 0x00, 0x00, 0x00,
0xf9, 0x00, 0x02, 0x00, 0x77, 0x00, 0x00, 0x00, 0xf8, 0x00, 0x02, 0x00,
0x79, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0xc5, 0x00, 0x00, 0x00,
0xc8, 0x00, 0x00, 0x00, 0xc7, 0x00, 0x00, 0x00, 0x51, 0x00, 0x05, 0x00,
0x06, 0x00, 0x00, 0x00, 0xca, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x41, 0x00, 0x06, 0x00, 0x20, 0x00, 0x00, 0x00,
0xcc, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00,
0x89, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00,
0xcd, 0x00, 0x00, 0x00, 0xcc, 0x00, 0x00, 0x00, 0x84, 0x00, 0x05, 0x00,
0x06, 0x00, 0x00, 0x00, 0xce, 0x00, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00,
0xcd, 0x00, 0x00, 0x00, 0x80, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00,
0xcf, 0x00, 0x00, 0x00, 0xca, 0x00, 0x00, 0x00, 0xce, 0x00, 0x00, 0x00,
0x51, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0xd1, 0x00, 0x00, 0x00,
0x0f, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x41, 0x00, 0x06, 0x00,
0x20, 0x00, 0x00, 0x00, 0xd3, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00,
0x16, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x3d, 0x00, 0x04, 0x00,
0x06, 0x00, 0x00, 0x00, 0xd4, 0x00, 0x00, 0x00, 0xd3, 0x00, 0x00, 0x00,
0x84, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0xd5, 0x00, 0x00, 0x00,
0x29, 0x00, 0x00, 0x00, 0xd4, 0x00, 0x00, 0x00, 0x80, 0x00, 0x05, 0x00,
0x06, 0x00, 0x00, 0x00, 0xd6, 0x00, 0x00, 0x00, 0xd1, 0x00, 0x00, 0x00,
0xd5, 0x00, 0x00, 0x00, 0x50, 0x00, 0x05, 0x00, 0x11, 0x00, 0x00, 0x00,
0xd7, 0x00, 0x00, 0x00, 0xcf, 0x00, 0x00, 0x00, 0xd6, 0x00, 0x00, 0x00,
0x63, 0x00, 0x04, 0x00, 0xc8, 0x00, 0x00, 0x00, 0xd7, 0x00, 0x00, 0x00,
0xe8, 0x00, 0x00, 0x00, 0xf9, 0x00, 0x02, 0x00, 0x33, 0x00, 0x00, 0x00,
0xf8, 0x00, 0x02, 0x00, 0x33, 0x00, 0x00, 0x00, 0xfd, 0x00, 0x01, 0x00,
0x38, 0x00, 0x01, 0x00
};
unsigned int glsl_convolutionDepthwiseMali_comp_len = 3724;
const unsigned char glsl_relu_comp[] = { const unsigned char glsl_relu_comp[] = {
0x03, 0x02, 0x23, 0x07, 0x00, 0x00, 0x01, 0x00, 0x0b, 0x00, 0x08, 0x00, 0x03, 0x02, 0x23, 0x07, 0x00, 0x00, 0x01, 0x00, 0x0b, 0x00, 0x08, 0x00,
0x5a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x11, 0x00, 0x02, 0x00, 0x5a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x11, 0x00, 0x02, 0x00,
@ -10,7 +10,6 @@ mMaps.insert(std::make_pair("glsl_softmaxImage_AXIS_N_comp", std::make_pair(glsl
mMaps.insert(std::make_pair("glsl_softmaxImage_AXIS_H_comp", std::make_pair(glsl_softmaxImage_AXIS_H_comp,glsl_softmaxImage_AXIS_H_comp_len))); mMaps.insert(std::make_pair("glsl_softmaxImage_AXIS_H_comp", std::make_pair(glsl_softmaxImage_AXIS_H_comp,glsl_softmaxImage_AXIS_H_comp_len)));
mMaps.insert(std::make_pair("glsl_softmaxImage_AXIS_W_comp", std::make_pair(glsl_softmaxImage_AXIS_W_comp,glsl_softmaxImage_AXIS_W_comp_len))); mMaps.insert(std::make_pair("glsl_softmaxImage_AXIS_W_comp", std::make_pair(glsl_softmaxImage_AXIS_W_comp,glsl_softmaxImage_AXIS_W_comp_len)));
mMaps.insert(std::make_pair("glsl_softmaxImage_AXIS_C_comp", std::make_pair(glsl_softmaxImage_AXIS_C_comp,glsl_softmaxImage_AXIS_C_comp_len))); mMaps.insert(std::make_pair("glsl_softmaxImage_AXIS_C_comp", std::make_pair(glsl_softmaxImage_AXIS_C_comp,glsl_softmaxImage_AXIS_C_comp_len)));
mMaps.insert(std::make_pair("glsl_convolutionDepthwiseMali_comp", std::make_pair(glsl_convolutionDepthwiseMali_comp,glsl_convolutionDepthwiseMali_comp_len)));
mMaps.insert(std::make_pair("glsl_relu_comp", std::make_pair(glsl_relu_comp,glsl_relu_comp_len))); mMaps.insert(std::make_pair("glsl_relu_comp", std::make_pair(glsl_relu_comp,glsl_relu_comp_len)));
mMaps.insert(std::make_pair("glsl_unaryImage_comp", std::make_pair(glsl_unaryImage_comp,glsl_unaryImage_comp_len))); mMaps.insert(std::make_pair("glsl_unaryImage_comp", std::make_pair(glsl_unaryImage_comp,glsl_unaryImage_comp_len)));
mMaps.insert(std::make_pair("glsl_unaryImage_SIGMOID_comp", std::make_pair(glsl_unaryImage_SIGMOID_comp,glsl_unaryImage_SIGMOID_comp_len))); mMaps.insert(std::make_pair("glsl_unaryImage_SIGMOID_comp", std::make_pair(glsl_unaryImage_SIGMOID_comp,glsl_unaryImage_SIGMOID_comp_len)));
@ -20,6 +20,10 @@ VulkanBasicExecutionDirect::VulkanBasicExecutionDirect(std::shared_ptr<VulkanBas
ErrorCode VulkanBasicExecutionDirect::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) { ErrorCode VulkanBasicExecutionDirect::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
auto extra = static_cast<VulkanBackend *>(backend()); auto extra = static_cast<VulkanBackend *>(backend());
#ifdef ENABLE_VULKAN_TIME_PROFILE
extra->pushExecutionName(mEncoder->getName());
extra->pushQueryPool(mQueryPool);
#endif
extra->pushCommand(mCmdBuffer->get()); extra->pushCommand(mCmdBuffer->get());
return NO_ERROR; return NO_ERROR;
} }
@ -58,7 +62,18 @@ ErrorCode VulkanBasicExecutionDirect::onResize(const std::vector<Tensor *> &inpu
auto initCmdBuffer = static_cast<VulkanBackend*>(backend())->getInitCommandBuffer(); auto initCmdBuffer = static_cast<VulkanBackend*>(backend())->getInitCommandBuffer();
_initLayout(inputs, outputs, initCmdBuffer); _initLayout(inputs, outputs, initCmdBuffer);
mCmdBuffer->begin(0); mCmdBuffer->begin(0);
#ifdef ENABLE_VULKAN_TIME_PROFILE
auto vkBn = static_cast<VulkanBackend*>(backend());
mQueryPool.reset(new VulkanQueryPool(vkBn->device()));
mQueryPool->VulkanCmdResetQueryPool(mCmdBuffer.get()->get());
mQueryPool->VulkanCmdWriteTimestamp(mCmdBuffer.get()->get(), 0);
auto code = mEncoder->onEncode(inputs, outputs, mCmdBuffer.get()); auto code = mEncoder->onEncode(inputs, outputs, mCmdBuffer.get());
mQueryPool->VulkanCmdWriteTimestamp(mCmdBuffer.get()->get(), 1);
#else
auto code = mEncoder->onEncode(inputs, outputs, mCmdBuffer.get());
#endif
for (auto output : outputs) { for (auto output : outputs) {
auto vkTensor = reinterpret_cast<VulkanTensor*>(output->deviceId()); auto vkTensor = reinterpret_cast<VulkanTensor*>(output->deviceId());
for (int i=0; i<vkTensor->imageSize(); ++i) { for (int i=0; i<vkTensor->imageSize(); ++i) {
@ -67,6 +82,7 @@ ErrorCode VulkanBasicExecutionDirect::onResize(const std::vector<Tensor *> &inpu
} }
} }
_postTreat(outputs, mCmdBuffer.get()); _postTreat(outputs, mCmdBuffer.get());
mCmdBuffer->end(); mCmdBuffer->end();
#ifdef MNN_VULKAN_DEBUG #ifdef MNN_VULKAN_DEBUG
#ifdef MNN_VULKAN_DEBUG_EAGER #ifdef MNN_VULKAN_DEBUG_EAGER
@ -26,6 +26,16 @@ public:
Backend* backend() { Backend* backend() {
return mBackend; return mBackend;
} }
#ifdef ENABLE_VULKAN_TIME_PROFILE
void setName(const char * name) {
mName = name;
}
std::string getName() {
return mName;
}
protected:
std::string mName;
#endif
private: private:
Backend* mBackend; Backend* mBackend;
}; };
@ -40,6 +50,9 @@ public:
private: private:
std::shared_ptr<VulkanBasicExecution> mEncoder; std::shared_ptr<VulkanBasicExecution> mEncoder;
std::shared_ptr<VulkanCommandPool::Buffer> mCmdBuffer; std::shared_ptr<VulkanCommandPool::Buffer> mCmdBuffer;
#ifdef ENABLE_VULKAN_TIME_PROFILE
std::shared_ptr<VulkanQueryPool> mQueryPool;
#endif
}; };
class VulkanBasicExecutionInDirect : public Execution { class VulkanBasicExecutionInDirect : public Execution {
public: public:
@ -1,5 +1,5 @@
// //
// VulkanConvolutionImpl.hpp // VulkanConvolution1x1.hpp
// MNN // MNN
// //
// Created by MNN on 2025/01/10. // Created by MNN on 2025/01/10.
@ -1,5 +1,5 @@
// //
// VulkanConvolutionImpl.hpp // VulkanConvolution1x1.hpp
// MNN // MNN
// //
// Created by MNN on 2025/01/10. // Created by MNN on 2025/01/10.
@ -14,8 +14,6 @@ extern const unsigned char glsl_softmaxImage_AXIS_W_comp[];
extern unsigned int glsl_softmaxImage_AXIS_W_comp_len; extern unsigned int glsl_softmaxImage_AXIS_W_comp_len;
extern const unsigned char glsl_softmaxImage_AXIS_C_comp[]; extern const unsigned char glsl_softmaxImage_AXIS_C_comp[];
extern unsigned int glsl_softmaxImage_AXIS_C_comp_len; extern unsigned int glsl_softmaxImage_AXIS_C_comp_len;
extern const unsigned char glsl_convolutionDepthwiseMali_comp[];
extern unsigned int glsl_convolutionDepthwiseMali_comp_len;
extern const unsigned char glsl_relu_comp[]; extern const unsigned char glsl_relu_comp[];
extern unsigned int glsl_relu_comp_len; extern unsigned int glsl_relu_comp_len;
extern const unsigned char glsl_unaryImage_comp[]; extern const unsigned char glsl_unaryImage_comp[];
@ -494,6 +494,14 @@ std::shared_ptr<ConvolutionCommon::Int8Common> ConvolutionCommon::load(const Op*
if (backend && backend->getRuntime()) { if (backend && backend->getRuntime()) {
useCachedMmap = backend->getRuntime()->hint().useCachedMmap > 1; useCachedMmap = backend->getRuntime()->hint().useCachedMmap > 1;
} }
if (USE_EXTERNAL_DATA(conv) && op->externalPath() && quan->type() == 8) {
std::unique_ptr<FileLoader> external(new FileLoader(op->externalPath()->c_str()));
auto param = op->main_as_Convolution2D();
external->offset(param->external()->data()[0]);
result->weightFloat.reset(param->external()->data()[1] / sizeof(float));
external->read((char*)(result->weightFloat.get()), param->external()->data()[1]);
return result;
}
if (USE_EXTERNAL_DATA(conv) && op->externalPath() && quan->buffer() == nullptr) { if (USE_EXTERNAL_DATA(conv) && op->externalPath() && quan->buffer() == nullptr) {
auto external_info = conv->external()->data(); auto external_info = conv->external()->data();
buffer_size = external_info[1]; buffer_size = external_info[1];
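Note: the new type == 8 branch above is a fast path for external float weights. Read together with the _RebuildExternalOp change further down, the implied layout of external() is: data()[0] holds the byte offset of the op's blob in the external file, data()[1] the byte size of the float weights, data()[2] the byte size of the float bias. A minimal sketch of the weight read using plain iostreams in place of MNN's FileLoader (illustrative only):

    #include <fstream>
    #include <vector>

    std::vector<float> loadExternalWeight(const char* path, int64_t offset, int64_t bytes) {
        std::vector<float> weight(bytes / sizeof(float));
        std::ifstream file(path, std::ios::binary);
        file.seekg(offset);                                        // external[0]
        file.read(reinterpret_cast<char*>(weight.data()), bytes);  // external[1]
        return weight;
    }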
@ -673,10 +673,12 @@ static bool _RebuildExternalOp(FileLoader* external, const MNN::Op* origin, flat
external->read((char*)param->quanParameter->index.data(), param->external[4]); external->read((char*)param->quanParameter->index.data(), param->external[4]);
} }
} else { } else {
external->offset(param->external[0]); // Create quanParameter; the external weight will be loaded in ConvolutionCommon::load
param->weight.resize(param->external[1] / sizeof(float)); param->quanParameter.reset(new IDSTQuanT);
external->read((char*)param->weight.data(), param->external[1]); param->quanParameter->type = 8;
op->externalPath = external->path();
param->bias.resize(param->external[2] / sizeof(float)); param->bias.resize(param->external[2] / sizeof(float));
external->offset(param->external[0] + param->external[1]);
external->read((char*)param->bias.data(), param->external[2]); external->read((char*)param->bias.data(), param->external[2]);
} }
break; break;
@ -724,6 +726,7 @@ Execution* OpCommonUtils::createExecutionWithExternal(Backend* backend, const st
return execution; return execution;
} }
if (op->main_type() == OpParameter_Convolution2D) { if (op->main_type() == OpParameter_Convolution2D) {
// For convolution / deconvolution, the execution still needs op info after creation, so try to clone it and remove newOp
Execution* copyExe = nullptr; Execution* copyExe = nullptr;
execution->onClone(backend, op, &copyExe); execution->onClone(backend, op, &copyExe);
if (nullptr != copyExe) { if (nullptr != copyExe) {
@ -231,15 +231,25 @@ public:
auto sh = common->strideY(); auto sh = common->strideY();
auto dw = common->dilateX(); auto dw = common->dilateX();
auto dh = common->dilateY(); auto dh = common->dilateY();
auto pw = common->padX(); int pl,pt,pr,pb;
auto ph = common->padY(); if (common->pads() == nullptr) {
pl = common->padX();
pr = common->padX();
pt = common->padY();
pb = common->padY();
} else {
pl = common->pads()->data()[1];
pr = common->pads()->data()[3];
pt = common->pads()->data()[0];
pb = common->pads()->data()[2];
}
auto batch = output->batch(); auto batch = output->batch();
auto ic = output->channel(); auto ic = output->channel();
auto iw = output->width(); auto iw = output->width();
auto ih = output->height(); auto ih = output->height();
auto pads = std::make_pair(pw, ph); auto pads = std::make_pair(pl, pt);
auto ow = (iw + pw * 2 - kw) / sw + 1; auto ow = (iw + pl + pr - kw) / sw + 1;
auto oh = (ih + ph * 2 - kh) / sh + 1; auto oh = (ih + pt + pb - kh) / sh + 1;
auto shape = output->shape(); auto shape = output->shape();
auto ishape = input->shape(); auto ishape = input->shape();
int n = ishape[0]; int n = ishape[0];
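Note: the change above stops assuming symmetric padding; the output extent now uses the left/right (and top/bottom) sums rather than twice a single pad value. For example, with iw = 13, kw = 3, sw = 1, pl = 1, pr = 0, the corrected formula gives ow = (13 + 1 + 0 - 3) / 1 + 1 = 12, while the old symmetric form (iw + 2 * pw - kw) / sw + 1 would have reported 13.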
@ -451,8 +451,16 @@ public:
if (newAxis == 0) { if (newAxis == 0) {
concatLen = attr->elemShape[shapeIndex][concatAxis]; concatLen = attr->elemShape[shapeIndex][concatAxis];
} }
if (1 == outside && outDes->regions.size() > 0) {
// If outside is 1, fuse to one region
outDes->regions[outDes->regions.size() - 1].size[2] += inside * concatLen;
concatSum += concatLen;
continue;
}
if (concatLast == concatLen) { if (concatLast == concatLen) {
// Fuse to last region
outDes->regions[outDes->regions.size() - 1].size[0] += 1; outDes->regions[outDes->regions.size() - 1].size[0] += 1;
concatSum += concatLen;
continue; continue;
} }
Tensor::InsideDescribe::Region reg; Tensor::InsideDescribe::Region reg;
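Note: the added branch above is a fusion shortcut. When outside == 1, each input's slice lands contiguously after the previous one in the output, so instead of appending one region per input, the computer grows the last region's innermost extent by inside * concatLen; e.g. concatenating three 4x8 tensors along axis 0 can collapse into one region spanning all 96 elements instead of three regions of 32.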
@ -102,6 +102,7 @@ extern void ___Conv2DBackpropFilterSizeComputer__OpType_Conv2DBackPropFilter__()
extern void ___Im2ColSizeComputer__OpType_Im2Col__(); extern void ___Im2ColSizeComputer__OpType_Im2Col__();
extern void ___Col2ImSizeComputer__OpType_Col2Im__(); extern void ___Col2ImSizeComputer__OpType_Col2Im__();
extern void ___ShapeScatterNd__OpType_ScatterNd__(); extern void ___ShapeScatterNd__OpType_ScatterNd__();
extern void ___StftOpComputer__OpType_Stft__();
extern void ___LSTMComputer__OpType_LSTM__(); extern void ___LSTMComputer__OpType_LSTM__();
extern void ___LSTMBlockCellComputer__OpType_LSTMBlockCell__(); extern void ___LSTMBlockCellComputer__OpType_LSTMBlockCell__();
extern void ___RNNComputer__OpType_RNN__(); extern void ___RNNComputer__OpType_RNN__();
@ -122,9 +123,6 @@ extern void ___FmhaV2SizeComputer__OpType_FmhaV2__();
extern void ___FmhcaSizeComputer__OpType_Fmhca__(); extern void ___FmhcaSizeComputer__OpType_Fmhca__();
extern void ___AttentionSizeComputer__OpType_Attention__(); extern void ___AttentionSizeComputer__OpType_Attention__();
#endif #endif
#ifdef MNN_BUILD_AUDIO
extern void ___StftOpComputer__OpType_Stft__();
#endif
void registerShapeOps() { void registerShapeOps() {
___ShapeSizeComputer__OpType_Shape__(); ___ShapeSizeComputer__OpType_Shape__();
___ShapeRasterComputer__OpType_Raster__(); ___ShapeRasterComputer__OpType_Raster__();
@ -228,6 +226,7 @@ ___Conv2DBackpropFilterSizeComputer__OpType_Conv2DBackPropFilter__();
___Im2ColSizeComputer__OpType_Im2Col__(); ___Im2ColSizeComputer__OpType_Im2Col__();
___Col2ImSizeComputer__OpType_Col2Im__(); ___Col2ImSizeComputer__OpType_Col2Im__();
___ShapeScatterNd__OpType_ScatterNd__(); ___ShapeScatterNd__OpType_ScatterNd__();
___StftOpComputer__OpType_Stft__();
___LSTMComputer__OpType_LSTM__(); ___LSTMComputer__OpType_LSTM__();
___LSTMBlockCellComputer__OpType_LSTMBlockCell__(); ___LSTMBlockCellComputer__OpType_LSTMBlockCell__();
___RNNComputer__OpType_RNN__(); ___RNNComputer__OpType_RNN__();
@ -247,8 +246,5 @@ ___FmhaV2SizeComputer__OpType_FmhaV2__();
___FmhcaSizeComputer__OpType_Fmhca__(); ___FmhcaSizeComputer__OpType_Fmhca__();
___AttentionSizeComputer__OpType_Attention__(); ___AttentionSizeComputer__OpType_Attention__();
#endif #endif
#ifdef MNN_BUILD_AUDIO
___StftOpComputer__OpType_Stft__();
#endif
} }
} }
@ -6,8 +6,6 @@
// Copyright © 2018, Alibaba Group Holding Limited // Copyright © 2018, Alibaba Group Holding Limited
// //
#ifdef MNN_BUILD_AUDIO
#include "shape/SizeComputer.hpp" #include "shape/SizeComputer.hpp"
#include "core/Macro.h" #include "core/Macro.h"
#include "core/TensorUtils.hpp" #include "core/TensorUtils.hpp"
@ -17,22 +15,27 @@ namespace MNN {
class StftOpComputer : public SizeComputer { class StftOpComputer : public SizeComputer {
virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs, virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs,
const std::vector<Tensor*>& outputs) const override { const std::vector<Tensor*>& outputs) const override {
int sample_length = inputs[0]->elementSize(); int batch_size = inputs[0]->length(0);
auto stft = op->main_as_StftParam(); int signal_length = inputs[0]->length(1);
bool abs = stft->abs(); outputs[0]->buffer().dimensions = 4;
int n_fft = stft->n_fft(); outputs[0]->setLength(3, 2);
int hop_length = stft->hop_length(); outputs[0]->setLength(0, batch_size);
int frames = (sample_length - n_fft) / hop_length + 1; int frame_length = inputs[2]->length(0);
// Scalar int nstfts = ((signal_length - frame_length) / inputs[1]->host<int>()[0]) + 1;
outputs[0]->buffer().dimensions = 2; outputs[0]->setLength(1, nstfts);
outputs[0]->setLength(0, frames);
outputs[0]->setLength(1, n_fft / 2 + 1); int dft_unique_bins;
outputs[0]->buffer().type = inputs[0]->getType(); if (op->main_as_StftParam()->abs()) {
dft_unique_bins = frame_length / 2 + 1;
} else {
dft_unique_bins = frame_length;
}
outputs[0]->setLength(2, dft_unique_bins);
TensorUtils::getDescribe(outputs[0])->dimensionFormat = TensorUtils::getDescribe(inputs[0])->dimensionFormat; TensorUtils::getDescribe(outputs[0])->dimensionFormat = TensorUtils::getDescribe(inputs[0])->dimensionFormat;
return true; return true;
} }
}; };
REGISTER_SHAPE_INPUTS(StftOpComputer, OpType_Stft, std::vector<int>{1});
REGISTER_SHAPE_AUDIO(StftOpComputer, OpType_Stft);
} // namespace MNN } // namespace MNN
#endif // MNN_BUILD_AUDIO
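Note: the recomputed shape matches an ONNX-style STFT. For an input of shape [batch, signal_length], a hop length read from inputs[1] (hence the REGISTER_SHAPE_INPUTS index {1}) and a window of frame_length taken from inputs[2], the output is [batch, nstfts, dft_unique_bins, 2], where nstfts = (signal_length - frame_length) / hop + 1 and dft_unique_bins is frame_length / 2 + 1 when abs (onesided) is set, else frame_length. For example, signal_length = 16000, frame_length = 400, hop = 160 gives nstfts = (16000 - 400) / 160 + 1 = 98, so a onesided output of shape [1, 98, 201, 2].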
@ -186,13 +186,4 @@ public:
#endif #endif
#ifdef MNN_BUILD_AUDIO
#define REGISTER_SHAPE_AUDIO(name, op) \
void ___##name##__##op##__() { \
name* _temp = new name; \
SizeComputerSuite* ts = SizeComputerSuite::get(); \
ts->insert(_temp, op); \
}
#endif
#endif #endif
@ -18,9 +18,9 @@ using namespace MNN;
class HybridConvSpeedTestCommon : public MNNTestCase { class HybridConvSpeedTestCommon : public MNNTestCase {
protected: protected:
static bool testKernel(std::string title, INTS inputShape, INTS kernel, INTS channel, INTS pad, INTS strides, INTS dilate, int batch = 1, int nbit = 8, int precision = 1, bool testSpeed = false, int block = 0) { static bool testKernel(std::string title, INTS inputShape, INTS kernel, INTS channel, INTS pad, INTS strides, INTS dilate, int batch = 1, int nbit = 8, int precision = 1, bool testSpeed = false, int block = 0) {
float fac = 1.23; float fac = 0.23;
int res = 10; int res = 10;
float tail = 0.2; float tail = 0.05;
int ic = channel[0], oc = channel[1]; int ic = channel[0], oc = channel[1];
int iw = inputShape[0], ih = inputShape[1]; int iw = inputShape[0], ih = inputShape[1];
std::vector<float> bias(oc), biastest(oc), biasdup(oc); std::vector<float> bias(oc), biastest(oc), biasdup(oc);
@ -39,7 +39,7 @@ protected:
auto xPtr = x->writeMap<float>(); auto xPtr = x->writeMap<float>();
int8_t xMin = -(1<<(nbit-1))+1, xMax = (1<<(nbit-1))-1; int8_t xMin = -(1<<(nbit-1))+1, xMax = (1<<(nbit-1))-1;
for (int i=0; i<xInfo->size; ++i) { for (int i=0; i<xInfo->size; ++i) {
xPtr[i] = (i % (xMax - xMin + 1) - (xMax / 2)) * 0.27; xPtr[i] = (i % (xMax - xMin + 1) - (xMax / 2)) * 0.17;
} }
x = _Convert(x, NC4HW4); x = _Convert(x, NC4HW4);
for (int i = 0; i < oc; ++i) { for (int i = 0; i < oc; ++i) {
@ -101,8 +101,12 @@ protected:
} }
if (testSpeed) { if (testSpeed) {
x.fix(VARP::INPUT); x.fix(VARP::INPUT);
MNN::Timer _t;
const int LOOP = 20; const int LOOP = 20;
{
x->writeMap<FLOAT_T>();
y->readMap<FLOAT_T>();
}
MNN::Timer _t;
for (int i = 0; i < LOOP; ++i) { for (int i = 0; i < LOOP; ++i) {
x->writeMap<FLOAT_T>(); x->writeMap<FLOAT_T>();
y->readMap<FLOAT_T>(); y->readMap<FLOAT_T>();
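Note: moving the MNN::Timer construction below one untimed writeMap/readMap pass is a deliberate warm-up: the first run pays one-off costs (lazy allocation, pipeline tuning, cache fills) that would otherwise inflate the average of the 20-iteration timed loop.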
@ -119,31 +123,26 @@ class HybridConvSpeedInt8Test : public HybridConvSpeedTestCommon {
public: public:
virtual bool run(int precision) { virtual bool run(int precision) {
INTS strides = {1, 1}, dilate = {1, 1}, pad = {0, 0}, inputShape = {1, 1}; // {w, h} INTS strides = {1, 1}, dilate = {1, 1}, pad = {0, 0}, inputShape = {1, 1}; // {w, h}
INTS channel0 = {4096, 4096}; // {ic, co} int batch[] = {1, 256, 512, 1024};
INTS channel1 = {1496, 256};
int batch[3] = {23, 13, 1};
std::vector<int> blocks = {32, 128, 0}; std::vector<int> blocks = {32, 128, 0};
std::vector<std::vector<int>> channels = {{1536, 2048}, {1536, 8960}};
std::vector<int> kernels = {1, 1}; std::vector<int> kernels = {1, 1};
std::vector<int> weightBits = {4, 8}; std::vector<int> weightBits = {4, 8};
bool lowmemory = true; bool lowmemory = true;
int batchNum = sizeof(batch) / sizeof(int); int batchNum = sizeof(batch) / sizeof(int);
bool correct = true; bool correct = true;
for (auto block : blocks) { for (auto& bits : weightBits) {
for (auto& bits : weightBits) { for (auto &channel: channels) {
MNN_PRINT("Test for %d bits, block=%d\n", bits, block); for (auto block : blocks) {
for (int n = 0; n < batchNum; ++n) { MNN_PRINT("Test for %d bits, block=%d\n", bits, block);
auto res = testKernel("Low memory HybridConv test:", inputShape, kernels, channel0, pad, strides, dilate, batch[n], bits, precision, true, block); for (int n = 0; n < batchNum; ++n) {
if (!res) { auto res = testKernel("Low memory HybridConv test:", inputShape, kernels, channel, pad, strides, dilate, batch[n], bits, precision, true, block);
MNN_ERROR("Error: low memory hybridConv when n=%d, ic=%d, oc=%d\n", batch[n], channel0[0], channel0[1]); if (!res) {
correct = false; MNN_ERROR("Error: low memory hybridConv when bits=%d, n=%d, ic=%d, oc=%d\n", bits, batch[n], channel[0], channel[1]);
} correct = false;
} return false;
for (int n = 0; n < batchNum; ++n) { }
auto res = testKernel("Low memory HybridConv test:", inputShape, kernels, channel1, pad, strides, dilate, batch[n], bits, precision, true, block);
if (!res) {
MNN_ERROR("Error: low memory hybridConv when n=%d, ic=%d, oc=%d\n", batch[n], channel1[0], channel1[1]);
correct = false;
} }
} }
} }
@ -155,48 +154,40 @@ public:
class HybridConvInt8Test : public HybridConvSpeedTestCommon { class HybridConvInt8Test : public HybridConvSpeedTestCommon {
public: public:
virtual bool run(int precision) { virtual bool run(int precision) {
INTS strides = {1, 1}, dilate = {1, 1}, pad = {0, 0}, inputShape = {1, 1}; // {w, h} INTS strides = {1, 1}, dilate = {1, 1}, pad = {0, 0}, inputShape = {1, 1}; // {w, h}
int testBatchCount = 5; int batch[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 21, 22, 23, 1909};
// std::vector<int> batch(testBatchCount); std::vector<int> blocks = {32, 128, 0};
std::vector<int> batch = {1, 23, 149, 38, 29}; std::vector<std::vector<int>> channels = {{3, 7}, {4, 7}, {5, 7}, {12, 16}, {2048, 54}, {8, 8}, {8, 9}, {8, 16}, {7, 9}, {9, 9}, {2048, 54}, {1, 10}, {20, 153}, {9, 18}, {64, 12}, {1496, 11}, {10, 9}};
std::vector<int> kernels = {1, 1}; std::vector<int> kernels = {1, 1};
std::vector<int> weightBits = {4, 8};
bool lowmemory = true; bool lowmemory = true;
{ int batchNum = sizeof(batch) / sizeof(int);
std::vector< std::vector<int>> channels = {{7, 9}, {9, 9}, {2048, 54}, {1, 10}, {20, 153}, {9, 18}}; bool correct = true;
for (int i = 0; i < channels.size(); ++i) { for (auto block : blocks) {
for (int n = 0; n < batch.size(); ++n) { for (auto& bits : weightBits) {
auto res = testKernel("Low memory HybridConv test:", inputShape, kernels, channels[i], pad, strides, dilate, batch[n], 8, precision); for (auto &channel: channels) {
res &= testKernel("Low memory HybridConv test:", inputShape, kernels, channels[i], pad, strides, dilate, batch[n], 4, precision); for (int n = 0; n < batchNum; ++n) {
if (!res) { auto res = testKernel("Low memory HybridConv test:", inputShape, kernels, channel, pad, strides, dilate, batch[n], bits, precision, false, block);
MNN_ERROR("Error: low memory hybridConv when bits=8, n=%d, ic=%d, oc=%d\n", batch[n], channels[i][0], channels[i][1]); if (!res) {
return false; MNN_ERROR("Error: low memory hybridConv when bits=%d, n=%d, ic=%d, oc=%d\n", bits, batch[n], channel[0], channel[1]);
} correct = false;
} return false;
} }
}
{
std::vector< std::vector<int>> channels = {{12, 16}, {2048, 54}, {8, 8}, {8, 9}, {8, 16}};
for (int i = 0; i < channels.size(); ++i) {
for (int n = 0; n < batch.size(); ++n) {
auto res = testKernel("Low memory HybridConv test:", inputShape, kernels, channels[i], pad, strides, dilate, batch[n], 4, precision);
if (!res) {
MNN_ERROR("Error: low memory hybridConv when bits=4, n=%d, ic=%d, oc=%d\n", batch[n], channels[i][0], channels[i][1]);
return false;
} }
} }
} }
} }
return true; return correct;
} }
}; };
class DenseConvInt8Test : public HybridConvSpeedTestCommon { class DenseConvInt8Test : public HybridConvSpeedTestCommon {
public: public:
virtual bool run(int precision) { virtual bool run(int precision) {
std::vector< std::vector<int>> channels = {{4, 256}, {512, 128}, {1, 8}, {7, 9}}; std::vector< std::vector<int>> channels = {{4, 17}, {8, 256}, {5, 8}, {3, 17}, {7, 26}, {9, 26}, {1, 8}, {7, 9}, {256, 256}, {1024, 2048}};
INTS strides = {1, 1}, dilate = {1, 3}, pad = {0, 3}, inputShape = {1, 131}; // {w, h} INTS strides = {1, 1}, dilate = {1, 3}, pad = {0, 3}, inputShape = {1, 131}; // {w, h}
std::vector<int> batch = {1, 13}; std::vector<int> batch = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 21, 22, 25, 28};
std::vector<std::vector<int>> kernels = {{1, 1}, {1, 3}}; std::vector<std::vector<int>> kernels = {{1, 1}, {1, 3}};
std::vector<int> weightBits = {4, 8}; std::vector<int> weightBits = {4, 8};
bool lowmemory = true; bool lowmemory = true;
@ -207,12 +198,12 @@ public:
for (auto kernel : kernels) { for (auto kernel : kernels) {
std::vector<int> blocks = {0}; std::vector<int> blocks = {0};
if (kernel[0] == 1 && kernel[1] == 1) { if (kernel[0] == 1 && kernel[1] == 1) {
blocks = {0, 32}; blocks = {0, 32, 128};
} }
for (auto block : blocks) { for (auto block : blocks) {
auto res = testKernel("Low memory ConvInt8 with kernel test:", inputShape, kernel, channels[i], pad, strides, dilate, batch[n], bits, precision, false, block); auto res = testKernel("Low memory ConvInt8 with kernel test:", inputShape, kernel, channels[i], pad, strides, dilate, batch[n], bits, precision, false, block);
if (!res) { if (!res) {
MNN_ERROR("Error: low memory ConvInt8 with %dx%d kernel when bits=%d n=%d, ic=%d, oc=%d, block=%d\n", kernel[0], kernel[1], bits, batch[n], channels[i][0], channels[i][1], block); MNN_ERROR("Error: low memory ConvInt8 with %dx%d kernel when bits=%d, n=%d, ic=%d, oc=%d, block=%d\n", kernel[0], kernel[1], bits, batch[n], channels[i][0], channels[i][1], block);
return false; return false;
} }
} }
@ -9,6 +9,7 @@
#include "audio/audio.hpp" #include "audio/audio.hpp"
#include <MNN/expr/MathOp.hpp> #include <MNN/expr/MathOp.hpp>
#include <MNN/expr/NeuralNetWorkOp.hpp> #include <MNN/expr/NeuralNetWorkOp.hpp>
#include "MNN_generated.h"
#include <cmath> #include <cmath>
#include <algorithm> #include <algorithm>
#include <complex> #include <complex>
@ -399,11 +400,12 @@ VARP spectrogram(VARP waveform, const SpectrogramParams *params) {
power = params->power; power = params->power;
} }
if (pad_left > 1 || pad_right > 1) { if (pad_left > 1 || pad_right > 1) {
waveform = _Pad(waveform, _var<int>({pad_left, pad_right}, {2}), CONSTANT); waveform = MNN::Express::_Pad(waveform, _var<int>({pad_left, pad_right}, {2}), MNN::Express::CONSTANT);
} }
if (center) { if (center) {
waveform = _Pad(waveform, _var<int>({n_fft / 2, n_fft / 2}, {2}), static_cast<PadValueMode>(pad_mode)); waveform = MNN::Express::_Pad(waveform, _var<int>({n_fft / 2, n_fft / 2}, {2}), static_cast<MNN::Express::PadValueMode>(pad_mode));
} }
waveform = _Reshape(waveform, {1, -1, 1});
hop_length = hop_length ? hop_length : n_fft / 2; hop_length = hop_length ? hop_length : n_fft / 2;
win_length = win_length ? win_length : n_fft; win_length = win_length ? win_length : n_fft;
VARP window; VARP window;
@ -418,15 +420,35 @@ VARP spectrogram(VARP waveform, const SpectrogramParams *params) {
window = hann_window(win_length); window = hann_window(win_length);
break; break;
} }
auto specgram = _Stft(waveform, window, n_fft, hop_length); std::unique_ptr<OpT> op(new OpT);
op->type = OpType_Stft;
op->main.type = OpParameter_StftParam;
auto param = new StftParamT;
param->abs = true;
op->main.value = param;
EXPRP stftexpr = Expr::create(std::move(op), {waveform, _Scalar<int>(hop_length), window});
int nstfts = ((waveform->getInfo()->dim[1] - n_fft) / hop_length) + 1;
int dft_unique_bins = n_fft / 2 + 1;
auto specgram = MNN::Express::Variable::create(stftexpr);
specgram = _Square(specgram);
auto startsDims = std::vector<int>{0, 0, 0, 0};
auto starts1Dims = std::vector<int>{0, 0, 0, 1};
auto sizeDims = std::vector<int>{1, nstfts, dft_unique_bins, 1};
auto startVar = _Const(startsDims.data(), {4}, NCHW, halide_type_of<int>());
auto start1Var = _Const(starts1Dims.data(), {4}, NCHW, halide_type_of<int>());
auto sizeVar = _Const(sizeDims.data(), {4}, NCHW, halide_type_of<int>());
auto specgramReal = _Slice(specgram, startVar, sizeVar);
auto specgramVirt = _Slice(specgram, start1Var, sizeVar);
specgram = specgramReal + specgramVirt;
specgram = _Reshape(specgram, {nstfts, dft_unique_bins});
if (normalized) { if (normalized) {
float window_norm = std::sqrt(_ReduceSum(_Square(window))->readMap<float>()[0]); float window_norm = 1.0f / _ReduceSum(_Square(window))->readMap<float>()[0];
specgram = specgram / _Scalar<float>(window_norm); specgram = specgram * _Scalar<float>(window_norm);
} }
if (power == 2.0) { if (power == 1.0f) {
specgram = _Square(specgram); specgram = _Sqrt(specgram);
} else if (power > 2.0) { } else if (power != 2.0f) {
specgram = _Pow(specgram, _Scalar<float>(power)); specgram = _Pow(specgram, _Scalar<float>(power / 2.0f));
} }
return specgram; return specgram;
} }
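Note: the reworked path above consumes the complex STFT output [1, nstfts, dft_unique_bins, 2] directly: squaring the tensor and adding the two slices of the last axis yields the power spectrum re^2 + im^2. That also explains the seemingly inverted follow-ups: normalization now multiplies by 1 / sum(window^2), the square of the old 1 / sqrt(sum(window^2)) magnitude scaling; power == 1 takes a square root; and any other power p is applied as _Pow(specgram, p / 2), because the tensor already holds |X|^2 rather than |X|.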