From 318a3de860de6b19f013087828540130e24069ac Mon Sep 17 00:00:00 2001 From: xiaying Date: Fri, 22 Aug 2025 18:04:08 +0800 Subject: [PATCH] MNN:Sync: Sync Internal 3.2.3 --- .gitignore | 4 +- README.md | 2 - docs/inference/module.md | 15 + docs/inference/npu.md | 4 + docs/tools/convert.md | 9 + docs/transformers/llm.md | 13 + express/Executor.cpp | 11 + include/MNN/expr/Executor.hpp | 1 + project/android/CMakeExports.txt | 6 + .../gradle/wrapper/gradle-wrapper.properties | 2 +- project/android/nativepub.gradle | 37 +- project/android/qnnprepare.gradle | 174 ++++ project/ios/MNN.xcodeproj/project.pbxproj | 8 - pymnn/CMakeLists.txt | 8 + pymnn/src/llm.h | 104 +- pymnn/src/util.h | 3 + source/backend/arm82/Arm82Functions.cpp | 966 +++++++++++++++--- source/backend/cpu/CPUAttention.cpp | 400 ++++---- source/backend/cpu/CPUAttention.hpp | 11 +- source/backend/cpu/CPUBackend.cpp | 14 +- source/backend/cpu/CPUBackend.hpp | 4 +- source/backend/cpu/CPURuntime.cpp | 8 +- source/backend/cpu/CPUSoftmax.cpp | 2 +- source/backend/cpu/KVCacheManager.cpp | 66 +- source/backend/cpu/KVCacheManager.hpp | 32 +- .../backend/cpu/arm/CommonOptFunctionNeon.cpp | 380 +++++++ source/backend/cpu/arm/arm32/MNNExpC8.S | 115 --- source/backend/cpu/arm/arm32/MNNSoftmax.S | 270 ----- source/backend/cpu/arm/arm64/MNNExpC8.S | 109 -- source/backend/cpu/arm/arm64/MNNSoftmax.S | 204 ---- ...NNGemmInt8AddBiasScaleHp128_SME2_w4_Fp32.S | 530 +++++----- ...NNGemmInt8AddBiasScaleHp128_SME2_w8_Fp32.S | 529 +++++----- .../backend/cpu/compute/CommonOptFunction.cpp | 256 +++-- .../backend/cpu/compute/CommonOptFunction.h | 21 +- .../cpu/compute/ConvInt8TiledExecutor.cpp | 7 +- source/backend/cpu/x86_x64/AVX2Functions.cpp | 3 +- .../cpu/x86_x64/FunctionDispatcher.cpp | 8 +- .../cpu/x86_x64/avx/FunctionSummary.hpp | 2 +- .../backend/cpu/x86_x64/avx/MathFunctions.cpp | 156 ++- .../cpu/x86_x64/avx/PackedFunction.cpp | 42 + .../cpu/x86_x64/avx512/PackedFunction.cpp | 37 + .../cpu/x86_x64/sse/FunctionSummary.hpp | 2 +- .../backend/cpu/x86_x64/sse/MathFunctions.cpp | 148 ++- .../backend/opencl/core/BufferConvertor.cpp | 39 +- .../backend/opencl/core/BufferConvertor.hpp | 2 +- source/backend/opencl/core/OpenCLBackend.cpp | 17 +- source/backend/qnn/backend/QNNBackend.cpp | 125 ++- source/backend/qnn/backend/QNNBackend.hpp | 6 + source/backend/qnn/backend/QNNUtils.cpp | 2 + source/backend/qnn/backend/QNNUtils.hpp | 2 + source/backend/qnn/backend/QNNWrapper.cpp | 14 +- source/backend/qnn/convertor/QNNConvertor.cpp | 154 ++- source/backend/qnn/convertor/QNNConvertor.hpp | 3 +- .../qnn/execution/QNNConvDepthwise.cpp | 335 +++++- .../qnn/execution/QNNConvDepthwise.hpp | 24 +- .../backend/qnn/execution/QNNConvolution.cpp | 659 ++++++++++-- .../backend/qnn/execution/QNNConvolution.hpp | 33 +- source/backend/qnn/execution/QNNQuant.cpp | 81 ++ source/backend/qnn/execution/QNNQuant.hpp | 32 + source/backend/qnn/execution/QNNScale.cpp | 133 ++- source/backend/qnn/execution/QNNScale.hpp | 1 + source/core/Backend.hpp | 10 +- tools/converter/source/common/cli.cpp | 7 + .../source/optimizer/PostConverter.cpp | 19 +- .../source/optimizer/onnxextra/OnnxClip.cpp | 2 +- .../optimizer/postconvert/RemoveCopy.cpp | 68 +- .../postconvert/RemoveTestNoUseOps.cpp | 149 +++ .../postconvert/RemoveTestNoUseOps.hpp | 132 +-- .../postconvert/RemoveUnusefulOp.cpp | 5 + tools/cpp/MNN2QNNModel.cpp | 35 +- tools/script/convertOnnxTest.py | 2 +- tools/script/genQNNModelsFromMNN.py | 168 +++ transformers/llm/engine/CMakeLists.txt | 15 + transformers/llm/engine/include/llm/llm.hpp 
| 23 + transformers/llm/engine/src/diskembedding.cpp | 1 + transformers/llm/engine/src/llm.cpp | 68 +- transformers/llm/engine/src/llmconfig.hpp | 4 - transformers/llm/engine/src/omni.cpp | 102 +- transformers/llm/engine/src/omni.hpp | 5 + transformers/llm/export/llmexport.py | 27 +- .../llm/export/utils/mnn_converter.py | 12 +- transformers/llm/export/utils/transformers.py | 4 +- 82 files changed, 5003 insertions(+), 2240 deletions(-) create mode 100644 project/android/qnnprepare.gradle delete mode 100644 source/backend/cpu/arm/arm32/MNNExpC8.S delete mode 100644 source/backend/cpu/arm/arm32/MNNSoftmax.S delete mode 100644 source/backend/cpu/arm/arm64/MNNExpC8.S delete mode 100644 source/backend/cpu/arm/arm64/MNNSoftmax.S create mode 100644 source/backend/qnn/execution/QNNQuant.cpp create mode 100644 source/backend/qnn/execution/QNNQuant.hpp create mode 100644 tools/converter/source/optimizer/postconvert/RemoveTestNoUseOps.cpp create mode 100644 tools/script/genQNNModelsFromMNN.py diff --git a/.gitignore b/.gitignore index eea1b1fa..4cfae4d0 100644 --- a/.gitignore +++ b/.gitignore @@ -376,4 +376,6 @@ datasets/* # qnn 3rdParty source/backend/qnn/3rdParty/include -apps/Android/MnnLlmChat/release_outputs +project/android/.cxx +pymnn/android/.cxx/ +pymnn/android/.cxx/abi_configuration_5u53tc49.json diff --git a/README.md b/README.md index b783063f..576f631f 100644 --- a/README.md +++ b/README.md @@ -11,8 +11,6 @@ ## News 🔥 -- [2025/08/08] Now we support [gpt-oss-20b](./apps/Android/MnnLlmChat/README.md#releases). -- [2025/08/05] MNN Chat Android is availabe in [GooglePlay](https://play.google.com/store/apps/details?id=com.alibaba.mnnllm.android.release) ! - [2025/06/11] New App MNN TaoAvatar released, you can talk with 3DAvatar offline with LLM, ASR, TTS, A2BS and NNR models all run local on your device!! [MNN TaoAvatar](./apps/Android/Mnn3dAvatar/README.md)

Icon diff --git a/docs/inference/module.md b/docs/inference/module.md index e9dcbfa3..cddd1c4d 100644 --- a/docs/inference/module.md +++ b/docs/inference/module.md @@ -183,6 +183,21 @@ struct Info { const Info* getInfo() const; ``` +### 获取设备信息 +调用`getDeviceInfo`函数可获取`Device`信息,可以参考代码: +```cpp +std::string soc_id, dsp_arch; +bool success = MNN::Express::Executor::RuntimeManager::getDeviceInfo("dsp_arch", MNN_FORWARD_NN, dsp_arch); +if(success) { + MNN_PRINT("Device dsp_arch: %s\n", dsp_arch.c_str()); +} + +success = MNN::Express::Executor::RuntimeManager::getDeviceInfo("soc_id", MNN_FORWARD_NN, soc_id); +if(success) { + MNN_PRINT("Device soc_id: %s\n", soc_id.c_str()); +} +``` + ### 执行推理 调用`onForward`执行推理。 diff --git a/docs/inference/npu.md b/docs/inference/npu.md index 0b54ad2b..aabae719 100644 --- a/docs/inference/npu.md +++ b/docs/inference/npu.md @@ -58,6 +58,10 @@ adb push ${MNN_ROOT}/source/backend/qnn/3rdParty/lib/hexagon-v${HEXAGON_ARCH}/un adb shell "cd /data/local/tmp && LD_LIBRARY_PATH=/data/local/tmp ADSP_LIBRARY_PATH=/data/local/tmp ./MyExe.out" ``` +### QNN量化功能说明 +- 仅权重量化(激活是浮点):只支持Linear权重int8、channel-wise的对称量化。 +- 激活&权重都量化:支持激活per-tensor对称量化,权重是int8/int4、channel-wise的对称量化。 + ## CoreML 适用于 Mac / iOS / iPad diff --git a/docs/tools/convert.md b/docs/tools/convert.md index 531390fa..c1459474 100644 --- a/docs/tools/convert.md +++ b/docs/tools/convert.md @@ -388,3 +388,12 @@ npu model path:./qnn_smolvlm_model.bin ./ModuleBasic.out qnn_smolvlm_model.mnn dir 0 0 10 ``` +### 生成多种QNN设备模型脚本 +tools/script/genQNNModelsFromMNN.py中提供了8Gen1 ~ 8Elite设备的QNN模型生成脚本 +``` +// 使用示例 +cd mnn_path +cd build +python3 ../tools/script/genQNNModelsFromMNN.py --config_path ../source/backend/qnn/convertor/config_example/ --graph_name visual_qnn --qnn_sdk_root_path /mnt/2Tpartition/tianbu/QNN/qairt/2.37.0.250724/ --src_model visual.mnn --executable_path ./MNN2QNNModel +``` +后续将在qnn_models文件夹下生成8Gen1 ~ 8Elite设备的QNN模型产物。 diff --git a/docs/transformers/llm.md b/docs/transformers/llm.md index c1edf595..9cfe76c8 100644 --- a/docs/transformers/llm.md +++ b/docs/transformers/llm.md @@ -106,6 +106,10 @@ optional arguments: mnn quant bit, 4 or 8, default is 4. --quant_block QUANT_BLOCK mnn quant block, 0 mean channle-wise, default is 128. + --visual_quant_bit VISUAL_QUANT_BIT + mnn visual model quant bit, 4 or 8, default is setting in utils/vision.py by different vit model. + --visual_quant_block VISUAL_QUANT_BLOCK + mnn visual model quant block, 0 mean channle-wise, default is setting in utils/vision.py by different vit model. --lm_quant_bit LM_QUANT_BIT mnn lm_head quant bit, 4 or 8, default is `quant_bit`. --mnnconvert MNNCONVERT @@ -113,10 +117,12 @@ optional arguments: --ppl Whether or not to get all logits of input tokens. --awq Whether or not to use awq quant. --sym Whether or not to using symmetric quant (without zeropoint), defualt is False. + --visual_sym Whether or not to using symmetric quant (without zeropoint) for visual model, defualt is False. --seperate_embed For lm and embed shared model, whether or not to sepearte embed to avoid quant, defualt is False, if True, embed weight will be seperate to embeddingbf16.bin. --lora_split Whether or not export lora split, defualt is False. 
``` + ### 权重读取 llmexport.py 同时支持 LLM 的验证功能,有较多的依赖。在没有相应环境的情况下,MNN-LLM也提供由 safetensors 或 gguf 文件读取权重的工具,可以降低内存需求,提高转换速度。使用方法如下: @@ -166,6 +172,7 @@ python3 gguf2mnn.py --gguf ~/third/llama.cpp/build/ggml-model-Q4_K.gguf --mnn_di ``` -DLLM_SUPPORT_VISION=true -DMNN_BUILD_OPENCV=true -DMNN_IMGCODECS=true ``` + - 需要开启音频功能时,增加相关编译宏 ``` -DLLM_SUPPORT_AUDIO=true -DMNN_BUILD_AUDIO=true @@ -195,6 +202,12 @@ cd project/android mkdir build_64 ../build_64.sh -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DMNN_ARM82=true -DMNN_OPENCL=true -DMNN_USE_LOGCAT=true ``` +高通设备部分视觉模型支持NPU功能,可增加`MNN_QNN` 和`MNN_WITH_PLUGIN`的宏启用QNN功能。 +``` +cd project/android +mkdir build_64 +../build_64.sh -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DMNN_ARM82=true -DMNN_OPENCL=true -DMNN_QNN=true -DMNN_WITH_PLUGIN=true -DMNN_USE_LOGCAT=true +``` #### iOS: 参考 transformers/llm/engine/ios/README.md ``` diff --git a/express/Executor.cpp b/express/Executor.cpp index 815ba8cb..d906ea7b 100644 --- a/express/Executor.cpp +++ b/express/Executor.cpp @@ -281,6 +281,17 @@ bool Executor::RuntimeManager::getInfo(Interpreter::SessionInfoCode code, void* return false; } +bool Executor::RuntimeManager::getDeviceInfo(const std::string& deviceKey, const MNNForwardType type, std::string& deviceValue) { + auto creator = MNNGetExtraRuntimeCreator(type); + if (creator != nullptr) { + auto res = creator->onGetDeviceInfo(deviceKey, deviceValue); + if(res) { + return true; + } + } + return false; +} + Executor::RuntimeManager::RuntimeManager() { mInside = new RuntimeAttr; mInside->mContent.reset(new RuntimeAttr::Immutable); diff --git a/include/MNN/expr/Executor.hpp b/include/MNN/expr/Executor.hpp index 010df348..744bc595 100644 --- a/include/MNN/expr/Executor.hpp +++ b/include/MNN/expr/Executor.hpp @@ -130,6 +130,7 @@ public: void setHint(Interpreter::HintMode mode, int* value, size_t size); void setHintPtr(Interpreter::HintMode mode, void* value); bool getInfo(Interpreter::SessionInfoCode code, void* ptr); + static bool getDeviceInfo(const std::string& deviceKey, const MNNForwardType type, std::string& deviceValue); BackendConfig* getBnConfig(); const RuntimeAttr* getInside() const { return mInside; diff --git a/project/android/CMakeExports.txt b/project/android/CMakeExports.txt index f64736cf..8c750fa8 100644 --- a/project/android/CMakeExports.txt +++ b/project/android/CMakeExports.txt @@ -25,3 +25,9 @@ set_target_properties( MNNOpenCV PROPERTIES IMPORTED_LOCATION ${CMAKE_CURRENT_LIST_DIR}/libs/${ANDROID_ABI}/libMNNOpenCV.so ) + +add_library(MNN_LLM SHARED IMPORTED GLOBAL ) +set_target_properties(MNN_LLM + PROPERTIES IMPORTED_LOCATION + ${CMAKE_CURRENT_LIST_DIR}/libs/${ANDROID_ABI}/libllm.so +) \ No newline at end of file diff --git a/project/android/gradle/wrapper/gradle-wrapper.properties b/project/android/gradle/wrapper/gradle-wrapper.properties index 43ea6285..9bbe87d1 100644 --- a/project/android/gradle/wrapper/gradle-wrapper.properties +++ b/project/android/gradle/wrapper/gradle-wrapper.properties @@ -3,4 +3,4 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists -distributionUrl=http\://mtl-gradle-mirror.oss-cn-hangzhou.aliyuncs.com/gradle-4.6-all.zip \ No newline at end of file +distributionUrl=http://mtl-gradle-mirror.oss-cn-hangzhou.aliyuncs.com/gradle-6.7.1-all.zip \ No newline at end of file diff --git 
a/project/android/nativepub.gradle b/project/android/nativepub.gradle index 6ef61921..2bf5e7b8 100644 --- a/project/android/nativepub.gradle +++ b/project/android/nativepub.gradle @@ -20,23 +20,48 @@ afterEvaluate { android.libraryVariants.all {variant -> def name = variant.buildType.name def zipTask = tasks.create(name: "zipNative${name.capitalize()}", type: Zip) { - from project.zipTree("${variant.packageLibrary.destinationDir.path}/${project.name}-${name}.aar") + from project.zipTree("${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}.aar") from "$buildDir/native-packed" - archiveName "${project.name}-${name}-tmp.aar" + archiveName "${artifactName}-${name}-tmp.aar" destinationDir(variant.packageLibrary.destinationDir) + inputs.dir("$buildDir/native-packed") + inputs.file("${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}.aar") + outputs.file("${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}.aar") } + + zipTask.dependsOn("externalNativeBuild${name.capitalize()}") + + // Ensure QNN dependencies are prepared when BUILD_QNN is enabled + if (project.ext.has('BUILD_QNN') && project.ext.BUILD_QNN) { + zipTask.dependsOn('prepareQnnDeps') + } + zipTask.doLast { copy { - from "${variant.packageLibrary.destinationDir.path}/${project.name}-${name}-tmp.aar" + from "${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}-tmp.aar" into "${variant.packageLibrary.destinationDir.path}" rename{ String fileName-> fileName.replace('-tmp','') } } - delete "${variant.packageLibrary.destinationDir.path}/${project.name}-${name}-tmp.aar" + delete "${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}-tmp.aar" + + println "Generated final AAR: ${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}.aar" } - tasks.findByName("assemble${name.capitalize()}").doLast { - zipTask.execute() + + def bundleTask = tasks.findByName("bundle${name.capitalize()}Aar") + if (bundleTask != null) { + zipTask.dependsOn(bundleTask) + } + + def assembleTask = tasks.findByName("assemble${name.capitalize()}") + if (assembleTask != null) { + assembleTask.dependsOn(zipTask) + zipTask.mustRunAfter(bundleTask) + } + def packageTask = tasks.findByName("package${name.capitalize()}Aar") + if (packageTask != null) { + packageTask.finalizedBy(zipTask) } } \ No newline at end of file diff --git a/project/android/qnnprepare.gradle b/project/android/qnnprepare.gradle new file mode 100644 index 00000000..da0c6538 --- /dev/null +++ b/project/android/qnnprepare.gradle @@ -0,0 +1,174 @@ +// QNN Dependencies Preparation +// This gradle file handles downloading and preparing QNN dependencies + +// QNN configuration +ext { + // QNN download settings + QNN_LIBS_URL = 'http://meta.alicdn.com/data/mnn/libs/qnn_inc_libs.zip' +} + +def qnnZipName = 'qnn_inc_libs.zip' +def qnnZipFile = new File(buildDir, qnnZipName) +def qnnTmpDir = new File(buildDir, 'qnn_tmp') + +task prepareQnnDeps { + group = 'build setup' + description = 'Download and extract QNN include/lib into source/backend/qnn/3rdParty when BUILD_QNN is enabled.' 
+ onlyIf { + project.ext.has('BUILD_QNN') && project.ext.BUILD_QNN + } + + // Define inputs and outputs for incremental build + inputs.property("qnn_libs_url", { QNN_LIBS_URL }) + inputs.property("build_qnn", { project.ext.has('BUILD_QNN') && project.ext.BUILD_QNN }) + outputs.dir(new File(project.rootDir, '../../source/backend/qnn/3rdParty')) + outputs.dir(new File(buildDir, 'native-packed/native/libs')) + + doLast { + println "BUILD_QNN is ON. Preparing QNN dependencies..." + + def url = QNN_LIBS_URL + println "Downloading QNN dependencies from ${url}" + + if (!qnnZipFile.exists()) { + println "Downloading ${qnnZipName} ..." + qnnZipFile.parentFile.mkdirs() + + try { + new java.net.URL(url).withInputStream { inputStream -> + qnnZipFile.withOutputStream { outputStream -> + outputStream << inputStream + } + } + println "Downloaded to: ${qnnZipFile.absolutePath}" + } catch (Exception e) { + throw new RuntimeException("Failed to download QNN dependencies from ${url}. Error: ${e.message}") + } + } else { + println "Using cached zip: ${qnnZipFile.absolutePath}" + } + + // Clean temp dir and unpack + project.delete(qnnTmpDir) + qnnTmpDir.mkdirs() + copy { + from zipTree(qnnZipFile) + into qnnTmpDir + } + + // Find the extracted QNN directory + def extractedQnnDir = findQnnDirectory(qnnTmpDir) + if (extractedQnnDir == null) { + throw new RuntimeException("Failed to find QNN directory structure in ${qnnZipFile.name}") + } + + copyQnnFiles(extractedQnnDir) + } +} + +def findQnnDirectory(File searchDir) { + // Look for directory containing both include and jniLibs (or lib) directories + def candidates = [] + + searchDir.eachDirRecurse { dir -> + def hasInclude = new File(dir, 'include').exists() + def hasLibs = new File(dir, 'jniLibs').exists() || new File(dir, 'lib').exists() + + if (hasInclude && hasLibs) { + candidates.add(dir) + } + } + + if (candidates.isEmpty()) { + // Fallback: look for directory that contains include + searchDir.eachDirRecurse { dir -> + if (new File(dir, 'include').exists()) { + candidates.add(dir) + } + } + } + + return candidates.isEmpty() ? null : candidates[0] +} + +def copyQnnFiles(File sourceDir) { + // Resolve destination directories + def qnnRoot = new File(project.rootDir, '../../source/backend/qnn/3rdParty') + def includeDir = new File(qnnRoot, 'include') + def libDir = new File(qnnRoot, 'lib') + + includeDir.mkdirs() + libDir.mkdirs() + + // Copy include files + def sourceInclude = new File(sourceDir, 'include') + if (sourceInclude.exists()) { + copy { + from sourceInclude + into includeDir + } + println "QNN includes copied to: ${includeDir.absolutePath}" + } else { + throw new RuntimeException("Include directory not found in ${sourceDir.absolutePath}") + } + + // Copy library files - try both jniLibs and lib directories + def sourceLibs = new File(sourceDir, 'jniLibs') + if (!sourceLibs.exists()) { + sourceLibs = new File(sourceDir, 'lib') + } + + if (sourceLibs.exists()) { + copy { + from sourceLibs + into libDir + } + println "QNN libs copied to: ${libDir.absolutePath}" + + // Also copy QNN .so files to native-packed for AAR packaging + copyQnnLibsToNativePacked(sourceLibs) + } else { + println "Warning: No lib/jniLibs directory found in ${sourceDir.absolutePath}" + } + + println "QNN dependencies preparation completed successfully." 
+} + +def copyQnnLibsToNativePacked(File sourceLibsDir) { + // Create native-packed directory structure for QNN .so files + def nativePackedLibsDir = new File(buildDir, 'native-packed/native/libs') + nativePackedLibsDir.mkdirs() + + println "Copying QNN .so files to native-packed for AAR packaging..." + + // Copy all .so files from QNN libs to native-packed + sourceLibsDir.eachFileRecurse { file -> + if (file.name.endsWith('.so')) { + // Determine the ABI directory (arm64-v8a, armeabi-v7a, etc.) + def relativePath = sourceLibsDir.toPath().relativize(file.toPath()) + def targetFile = new File(nativePackedLibsDir, relativePath.toString()) + + // Create parent directories if they don't exist + targetFile.parentFile.mkdirs() + + // Copy the .so file + copy { + from file + into targetFile.parentFile + } + + println "Copied QNN lib: ${file.name} -> ${targetFile.absolutePath}" + } + } + + println "QNN .so files copied to native-packed directory" +} + +// Ensure preparation runs before compilation when enabled +if (project.ext.has('BUILD_QNN') && project.ext.BUILD_QNN) { + afterEvaluate { + if (tasks.findByName('preBuild')) { + preBuild.dependsOn prepareQnnDeps + } + } +} diff --git a/project/ios/MNN.xcodeproj/project.pbxproj b/project/ios/MNN.xcodeproj/project.pbxproj index 52c0ace4..1b56c2ef 100644 --- a/project/ios/MNN.xcodeproj/project.pbxproj +++ b/project/ios/MNN.xcodeproj/project.pbxproj @@ -480,7 +480,6 @@ 92FF02E023AA0B5A00AC97F6 /* MNNMatrixSub.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016023AA0B4E00AC97F6 /* MNNMatrixSub.S */; }; 92FF02E123AA0B5A00AC97F6 /* MNNPowC8.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016123AA0B4E00AC97F6 /* MNNPowC8.S */; }; 92FF02E223AA0B5A00AC97F6 /* MNNMatrixAdd.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016223AA0B4E00AC97F6 /* MNNMatrixAdd.S */; }; - 92FF02E323AA0B5A00AC97F6 /* MNNExpC8.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016323AA0B4E00AC97F6 /* MNNExpC8.S */; }; 92FF02E523AA0B5A00AC97F6 /* MNNConvDwF23SourceTransUnit.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016523AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */; }; 92FF02E723AA0B5A00AC97F6 /* MNNDeconvRunForUnitDepthWise.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016723AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */; }; 92FF02E823AA0B5A00AC97F6 /* MNNSamplerC1BilinearOpt.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016823AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */; }; @@ -520,7 +519,6 @@ 92FF032023AA0B5A00AC97F6 /* MNNMatrixSub.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A123AA0B4E00AC97F6 /* MNNMatrixSub.S */; }; 92FF032123AA0B5A00AC97F6 /* MNNPowC8.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A223AA0B4E00AC97F6 /* MNNPowC8.S */; }; 92FF032223AA0B5A00AC97F6 /* MNNMatrixAdd.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A323AA0B4E00AC97F6 /* MNNMatrixAdd.S */; }; - 92FF032323AA0B5A00AC97F6 /* MNNExpC8.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A423AA0B4E00AC97F6 /* MNNExpC8.S */; }; 92FF032523AA0B5A00AC97F6 /* MNNConvDwF23SourceTransUnit.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A623AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */; }; 92FF032723AA0B5A00AC97F6 /* MNNDeconvRunForUnitDepthWise.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A823AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */; }; 92FF032823AA0B5A00AC97F6 /* MNNSamplerC1BilinearOpt.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A923AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */; }; @@ 
-1315,7 +1313,6 @@ 92FF016023AA0B4E00AC97F6 /* MNNMatrixSub.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNMatrixSub.S; sourceTree = ""; }; 92FF016123AA0B4E00AC97F6 /* MNNPowC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPowC8.S; sourceTree = ""; }; 92FF016223AA0B4E00AC97F6 /* MNNMatrixAdd.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNMatrixAdd.S; sourceTree = ""; }; - 92FF016323AA0B4E00AC97F6 /* MNNExpC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNExpC8.S; sourceTree = ""; }; 92FF016523AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNConvDwF23SourceTransUnit.S; sourceTree = ""; }; 92FF016723AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNDeconvRunForUnitDepthWise.S; sourceTree = ""; }; 92FF016823AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNSamplerC1BilinearOpt.S; sourceTree = ""; }; @@ -1355,7 +1352,6 @@ 92FF01A123AA0B4E00AC97F6 /* MNNMatrixSub.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNMatrixSub.S; sourceTree = ""; }; 92FF01A223AA0B4E00AC97F6 /* MNNPowC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPowC8.S; sourceTree = ""; }; 92FF01A323AA0B4E00AC97F6 /* MNNMatrixAdd.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNMatrixAdd.S; sourceTree = ""; }; - 92FF01A423AA0B4E00AC97F6 /* MNNExpC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNExpC8.S; sourceTree = ""; }; 92FF01A623AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNConvDwF23SourceTransUnit.S; sourceTree = ""; }; 92FF01A823AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNDeconvRunForUnitDepthWise.S; sourceTree = ""; }; 92FF01A923AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNSamplerC1BilinearOpt.S; sourceTree = ""; }; @@ -2591,7 +2587,6 @@ 92FF016023AA0B4E00AC97F6 /* MNNMatrixSub.S */, 92FF016123AA0B4E00AC97F6 /* MNNPowC8.S */, 92FF016223AA0B4E00AC97F6 /* MNNMatrixAdd.S */, - 92FF016323AA0B4E00AC97F6 /* MNNExpC8.S */, 92FF016523AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */, 92FF016723AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */, 92FF016823AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */, @@ -2677,7 +2672,6 @@ 92FF01A123AA0B4E00AC97F6 /* MNNMatrixSub.S */, 92FF01A223AA0B4E00AC97F6 /* MNNPowC8.S */, 92FF01A323AA0B4E00AC97F6 /* MNNMatrixAdd.S */, - 92FF01A423AA0B4E00AC97F6 /* MNNExpC8.S */, 92FF01A623AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */, 92FF01A823AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */, 92FF01A923AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */, @@ -3279,7 +3273,6 @@ 92FF02C223AA0B5A00AC97F6 /* MNNLoadU8AndSum.S in Sources */, 4819FB2E24C1396A0050BD09 /* GeometryLSTM.cpp in Sources */, 9558334B29B09A7B00488807 /* MNNGeluFP16.S in Sources */, - 
92FF02E323AA0B5A00AC97F6 /* MNNExpC8.S in Sources */, 92FF030223AA0B5A00AC97F6 /* MNNQuanToDestUint8.S in Sources */, 489D7AA12550FDC900AD896A /* MetalUnary.mm in Sources */, 92FF037323AA0B5A00AC97F6 /* CPUEltwiseInt8.cpp in Sources */, @@ -3438,7 +3431,6 @@ 92FF030723AA0B5A00AC97F6 /* MNNBlitC1ToFloatRGBA.S in Sources */, 92FF03A423AA0B5A00AC97F6 /* OptimizedComputer.cpp in Sources */, 92FF032E23AA0B5A00AC97F6 /* MNNReluWithSlopeChannel.S in Sources */, - 92FF032323AA0B5A00AC97F6 /* MNNExpC8.S in Sources */, 95278CEA2B9F09C0009E9B29 /* ShapeDynamicQuant.cpp in Sources */, 92FF044C23AA0B7100AC97F6 /* ShapePool3D.cpp in Sources */, 92FF029823AA0B5A00AC97F6 /* CPUTFQuantizedConv2D.cpp in Sources */, diff --git a/pymnn/CMakeLists.txt b/pymnn/CMakeLists.txt index 9eb89412..660a10f9 100644 --- a/pymnn/CMakeLists.txt +++ b/pymnn/CMakeLists.txt @@ -17,6 +17,7 @@ option(PYMNN_INTERNAL_SERVING "Internal use only." OFF) option(PYMNN_OPENCV_API "MNN OpenCV API be exposed" ON) option(PYMNN_IMGCODECS "MNN IMGCODECS API be exposed" OFF) option(PYMNN_AUDIO_API "MNN Audio API be exposed" OFF) +option(PYMNN_LLM_API "MNN LLM API be exposed" OFF) option(PYMNN_OHOS_INTERNAL "compile for harmony internal." OFF) option(PYMNN_LLM_COLLECTION "offline llm data collection." OFF) @@ -97,6 +98,10 @@ if(PYMNN_AUDIO_API) target_compile_definitions(mnnpybridge PRIVATE PYMNN_AUDIO_API) endif() +if(PYMNN_LLM_API) + target_compile_definitions(mnnpybridge PRIVATE PYMNN_LLM_API) +endif() + if(PYMNN_INTERNAL_SERVING) message(STATUS "mnnpybridge define PYMNN_INTERNAL_SERVING") target_compile_definitions(mnnpybridge PRIVATE PYMNN_INTERNAL_SERVING) @@ -214,6 +219,9 @@ else() if(PYMNN_NUMPY_USABLE) target_link_libraries(mnnpybridge PRIVATE numpy_python) endif() + if(PYMNN_LLM_API) + target_link_libraries(mnnpybridge PRIVATE MNN_LLM) + endif() endif() endif() diff --git a/pymnn/src/llm.h b/pymnn/src/llm.h index 3ddd20bc..2fe4c2b4 100644 --- a/pymnn/src/llm.h +++ b/pymnn/src/llm.h @@ -1,6 +1,10 @@ #include +#include +#ifdef BUILD_FOR_IOS +#include "MNN/llm/llm.hpp" +#else #include "llm/llm.hpp" - +#endif #ifdef PYMNN_LLM_COLLECTION #include "cpp/getLinearInput.hpp" #endif @@ -85,18 +89,87 @@ static PyObject* PyMNNLLM_getCurrentHistory(LLM *self, PyObject *args) { } static PyObject* PyMNNLLM_response(LLM *self, PyObject *args) { if (self->is_embedding) { + MNN_PRINT("[MNNLLM] response: is_embedding\n"); Py_RETURN_NONE; } - const char* query = NULL; + + PyObject* content = nullptr; int stream = 0; - if (!PyArg_ParseTuple(args, "s|p", &query, &stream)) { + int max_new_tokens = 2048; + + if (!PyArg_ParseTuple(args, "O|ii", &content, &stream, &max_new_tokens)) { + MNN_PRINT("[MNNLLM] response: PyArg_ParseTuple failed\n"); Py_RETURN_NONE; } + std::ostringstream null_os; - self->llm->response(query, stream ? &std::cout : &null_os); - return string2Object(null_os.str()); + std::ostream* output_stream = stream ? 
&std::cout : &null_os; + + if (isString(content)) { + std::string text = object2String(content); + MNN_PRINT("[MNNLLM] response: text=%s, stream=%d, max_new_tokens=%d\n", text.c_str(), stream, max_new_tokens); + self->llm->response(text, output_stream, nullptr, max_new_tokens); + } else if (isPyDict(content)) { + MNN::Transformer::MultimodalPrompt multimodal_input; + PyObject* text_obj = PyDict_GetItemString(content, "text"); + if (text_obj && isString(text_obj)) { + multimodal_input.prompt_template = object2String(text_obj); + } + PyObject* images_obj = PyDict_GetItemString(content, "images"); + if (images_obj && PyList_Check(images_obj)) { + Py_ssize_t img_count = PyList_Size(images_obj); + for (Py_ssize_t i = 0; i < img_count; i++) { + PyObject* img_dict = PyList_GetItem(images_obj, i); + if (isPyDict(img_dict)) { + PyObject* data_obj = PyDict_GetItemString(img_dict, "data"); + PyObject* width_obj = PyDict_GetItemString(img_dict, "width"); + PyObject* height_obj = PyDict_GetItemString(img_dict, "height"); + + if (data_obj && width_obj && height_obj) { + MNN::Transformer::PromptImagePart image_part; + image_part.image_data = toVar(data_obj); + image_part.width = PyLong_AsLong(width_obj); + image_part.height = PyLong_AsLong(height_obj); + + std::string key = "image_" + std::to_string(i); + multimodal_input.images[key] = image_part; + } + } + } + } + + PyObject* audios_obj = PyDict_GetItemString(content, "audios"); + if (audios_obj && PyList_Check(audios_obj)) { + Py_ssize_t audio_count = PyList_Size(audios_obj); + for (Py_ssize_t i = 0; i < audio_count; i++) { + PyObject* audio_dict = PyList_GetItem(audios_obj, i); + if (isPyDict(audio_dict)) { + MNN::Transformer::PromptAudioPart audio_part; + + PyObject* file_path_obj = PyDict_GetItemString(audio_dict, "file_path"); + if (file_path_obj && isString(file_path_obj)) { + audio_part.file_path = object2String(file_path_obj); + } + + PyObject* waveform_obj = PyDict_GetItemString(audio_dict, "waveform"); + if (waveform_obj) { + audio_part.waveform = toVar(waveform_obj); + } + + std::string key = "audio_" + std::to_string(i); + multimodal_input.audios[key] = audio_part; + } + } + } + MNN_PRINT("[MNNLLM] response: multimodal, stream=%d, max_new_tokens=%d\n", stream, max_new_tokens); + self->llm->response(multimodal_input, output_stream, nullptr, max_new_tokens); + } else { + PyMNN_ERROR("content must be str or dict"); + } + std::string response_str = null_os.str(); + MNN_PRINT("[MNNLLM] response: %s\n", response_str.c_str()); + return string2Object(response_str); } - static PyObject* PyMNNLLM_tokenizer_encode(LLM *self, PyObject *args) { if (self->is_embedding) { Py_RETURN_NONE; @@ -149,6 +222,14 @@ static PyObject* PyMNNLLM_reset(LLM *self, PyObject *args) { Py_RETURN_NONE; } +static PyObject* PyMNNLLM_get_statistics(LLM *self, PyObject *args) { + if (self->is_embedding) { + Py_RETURN_NONE; + } + auto statistics = self->llm->get_statistics(); + return string2Object(statistics); +} + #ifdef PYMNN_LLM_COLLECTION static PyObject* PyMNNLLM_enable_collection_mode(LLM *self, PyObject *args) { if (self->is_embedding) { @@ -205,12 +286,11 @@ static PyObject* PyMNNLLM_enable_collection_mode(LLM *self, PyObject *args) { return toPyObj(true); } #endif - static PyMethodDef PyMNNLLM_methods[] = { {"load", (PyCFunction)PyMNNLLM_load, METH_VARARGS, "load model."}, {"forward", (PyCFunction)PyMNNLLM_forward, METH_VARARGS, "forward `logits` by `input_ids`."}, {"generate", (PyCFunction)PyMNNLLM_generate, METH_VARARGS, "generate `output_ids` by `input_ids`."}, - 
{"response", (PyCFunction)PyMNNLLM_response, METH_VARARGS, "response `query` without hsitory."}, + {"response", (PyCFunction)PyMNNLLM_response, METH_VARARGS, "response `query` - supports both text and multimodal input."}, {"get_current_history", (PyCFunction)PyMNNLLM_getCurrentHistory, METH_VARARGS, "Get Current History."}, {"erase_history", (PyCFunction)PyMNNLLM_eraseHistory, METH_VARARGS, "Erase History."}, {"tokenizer_encode", (PyCFunction)PyMNNLLM_tokenizer_encode, METH_VARARGS, "tokenizer encode."}, @@ -219,6 +299,7 @@ static PyMethodDef PyMNNLLM_methods[] = { {"create_lora", (PyCFunction)PyMNNLLM_create_lora, METH_VARARGS, "create_lora."}, {"set_config", (PyCFunction)PyMNNLLM_set_config, METH_VARARGS, "set_config."}, {"reset", (PyCFunction)PyMNNLLM_reset, METH_VARARGS, "reset."}, + {"get_statistics", (PyCFunction)PyMNNLLM_get_statistics, METH_VARARGS, "get performance statistics."}, #ifdef PYMNN_LLM_COLLECTION {"enable_collection_mode", (PyCFunction)PyMNNLLM_enable_collection_mode, METH_VARARGS, "Enable data collection mode."}, #endif @@ -274,7 +355,7 @@ static PyObject* PyMNNLLM_create_lora(LLM *self, PyObject *args) { Py_RETURN_NONE; } auto lora = self->llm->create_lora(path); - LLM *llm = (LLM *)PyObject_Call((PyObject*)&PyMNNLLM, PyTuple_New(0), NULL); + LLM *llm = (LLM *)PyObject_Call((PyObject*)PyType_FindTLSType(&PyMNNLLM), PyTuple_New(0), NULL); if (!llm) { return NULL; } @@ -288,10 +369,11 @@ static PyObject* PyMNNLLM_create(PyObject *self, PyObject *args) { } const char* path = NULL; int embedding_model = 0; - if (!PyArg_ParseTuple(args, "s|p", &path, &embedding_model)) { + if (!PyArg_ParseTuple(args, "s|i", &path, &embedding_model)) { + PyMNN_ERROR_LOG("Invalid arguments. Usage: create(path, embedding_model=False)"); return NULL; } - LLM *llm = (LLM *)PyObject_Call((PyObject*)&PyMNNLLM, PyTuple_New(0), NULL); + LLM *llm = (LLM *)PyObject_Call((PyObject*)PyType_FindTLSType(&PyMNNLLM), PyTuple_New(0), NULL); if (!llm) { return NULL; } diff --git a/pymnn/src/util.h b/pymnn/src/util.h index 4ed57f8a..6a657457 100644 --- a/pymnn/src/util.h +++ b/pymnn/src/util.h @@ -398,6 +398,9 @@ static inline bool isPySequence(PyObject* obj) { // use isPySequence replace PySequence_Check return PyTuple_Check(obj) || PyList_Check(obj) || PyBytes_Check(obj); } +static inline bool isPyDict(PyObject* obj) { + return PyDict_Check(obj); +} static inline int PySequenceSize(PyObject* obj) { if (PyTuple_Check(obj)) return PyTuple_Size(obj); if (PyList_Check(obj)) return PyList_Size(obj); diff --git a/source/backend/arm82/Arm82Functions.cpp b/source/backend/arm82/Arm82Functions.cpp index 4082a1f2..b736e315 100644 --- a/source/backend/arm82/Arm82Functions.cpp +++ b/source/backend/arm82/Arm82Functions.cpp @@ -197,7 +197,7 @@ static void Sme2MNNPackForMatMul_B(float* destC, const float* sourceC, size_t h, auto stride0 = ROUND_UP(ic, LP) * kernelsize * HP; auto stride1 = HP * ROUND_UP(ic, LP); auto stride2 = HP * LP; - + size_t srcStride0 = l; // [h,k2,ic]->[hu,k2,ic/lp,hp,lp] size_t srcStride1 = 1; if (!transpose) { // [k2,ic,h]->[hu,k2,ic/lp,hp,lp] @@ -218,7 +218,7 @@ static void Sme2MNNPackForMatMul_B(float* destC, const float* sourceC, size_t h, } static void MNNScaleAndAddBiasFP16(FLOAT16* dst, const FLOAT16* src, const FLOAT16* bias, const FLOAT16* alpha, size_t planeNumber, - size_t biasNumber) { + size_t biasNumber) { for (int z = 0; z < biasNumber; ++z) { FLOAT16* dstZ = dst + planeNumber * 8 * z; const FLOAT16* srcZ = src + planeNumber * 8 * z; @@ -240,7 +240,7 @@ static void 
MNNGridSampleComputeCordFP16(FLOAT16* dst, const FLOAT16* src, size_ float16x8_t b = alignCorners ? zero : one; float16x8_t inW_sub_a = vsubq_f16(vdupq_n_f16(inW), a); float16x8_t inH_sub_a = vsubq_f16(vdupq_n_f16(inH), a); - + int area = outH * outW; int areaC8 = area / 8; int areaRemain = area - areaC8 * 8; @@ -251,14 +251,14 @@ static void MNNGridSampleComputeCordFP16(FLOAT16* dst, const FLOAT16* src, size_ cordH.val[0] = vmulq_f16(half, vsubq_f16(vmulq_f16(vaddq_f16(one, cordH.val[0]), inW_sub_a), b)); cordH.val[1] = vmulq_f16(half, vsubq_f16(vmulq_f16(vaddq_f16(one, cordH.val[1]), inH_sub_a), b)); vst2q_f16(dst, cordH); - + src += 16; dst += 16; } if (areaRemain == 0) { return; } - + // areaRemain FLOAT16 tempDst[16]; ::memcpy(tempDst, src, areaRemain * 2 * sizeof(int16_t)); @@ -281,7 +281,7 @@ static void MNNGridSampleComputeCord3DFp16(FLOAT* dst, const FLOAT* src, size_t size_t area = outH * outW * outD; size_t areaC8 = area / 8; size_t areaRemain = area - areaC8 * 8; - + for (int i = 0; i < areaC8; ++i) { auto cordH = vld3q_f16(src); // float16x8_t x = cordH.val[0]; @@ -296,7 +296,7 @@ static void MNNGridSampleComputeCord3DFp16(FLOAT* dst, const FLOAT* src, size_t if (areaRemain == 0) { return; } - + // areaRemain FLOAT16 tempDst[24]; ::memcpy(tempDst, src, areaRemain * 3 * sizeof(int16_t)); @@ -326,7 +326,7 @@ static void MNNRoiAlignMaxFP16(FLOAT16* dst, const FLOAT16* src, const std::vect for (int i = 0; i < samplingRatioArea; ++i) { const std::vector& pos = vecPos[preCalcIdx]; const std::vector& area = vecArea[preCalcIdx]; - + Vec val0 = Vec::load(src + pos[0] * 8); Vec val1 = Vec::load(src + pos[1] * 8); Vec val2 = Vec::load(src + pos[2] * 8); @@ -352,7 +352,7 @@ static void MNNRoiAlignAvgFP16(FLOAT16* dst, const FLOAT16* src, const std::vect for (int i = 0; i < samplingRatioArea; ++i) { const std::vector& pos = vecPos[preCalcIdx]; const std::vector& area = vecArea[preCalcIdx]; - + Vec val0 = Vec::load(src + pos[0] * 8); Vec val1 = Vec::load(src + pos[1] * 8); Vec val2 = Vec::load(src + pos[2] * 8); @@ -412,7 +412,7 @@ static void MNNAxByClampBroadcastC8FP16(float* CF, const float* AF, const float* } void ARM82StrassenMerge(FLOAT16* c11, FLOAT16* c12, FLOAT16* c21, FLOAT16* c22, FLOAT16* xAddr, - size_t cStride, size_t eSub, size_t hSub) { + size_t cStride, size_t eSub, size_t hSub) { const int pack = 8; for (int y = 0; y < hSub; ++y) { auto c11Y = c11 + y * cStride; @@ -445,7 +445,7 @@ void MNNUnpackTransposeInt16C8(int16_t* dst, const int16_t* src, size_t area, si int cAlign = cDiv4 * 8; int areaDiv4 = area / 4; int areaAlign = areaDiv4 * 4; - + if (areaAlign > 0) { for (int ci = 0; ci < cDiv4; ++ci) { auto srcH = src + ci * 8 * srcAreaOffset; @@ -481,15 +481,15 @@ void MNNUnpackTransposeInt16C8(int16_t* dst, const int16_t* src, size_t area, si if (c == cAlign) { return; } - + int cReamin = c - cAlign; auto srcAlign = src + srcAreaOffset * cAlign; auto dstAlign = dst + cAlign; - + for (int hi = 0; hi < area; ++hi) { auto srcHeight = srcAlign + hi * 8; auto dstHeight = dstAlign + hi * c; - + for (int ci = 0; ci < cReamin; ++ci) { dstHeight[ci] = srcHeight[ci]; } @@ -538,15 +538,15 @@ void MNNPackTransposeInt16C8(int16_t* dst, const int16_t* src, size_t area, size } } } - + if (cAlign == c) { return; } - + int cReamin = c - cAlign; auto srcAlign = src + cAlign; auto dstAlign = dst + dstAreaOffset * cAlign; - + for (int hi = 0; hi < area; ++hi) { auto srcHeight = srcAlign + hi * c; auto dstHeight = dstAlign + hi * 8; @@ -560,7 +560,7 @@ void MNNPackTransposeInt16C8(int16_t* dst, 
const int16_t* src, size_t area, size } static void _MNNDeconvRunForUnitDepthWise(const FLOAT16* dst, FLOAT16* src, const FLOAT16* weight, size_t fw, size_t fh, - size_t weight_y_step, size_t dilateX_step, size_t dilateY_step) { + size_t weight_y_step, size_t dilateX_step, size_t dilateY_step) { int fx, fy; auto src_z = src; auto weight_z = weight; @@ -576,7 +576,7 @@ static void _MNNDeconvRunForUnitDepthWise(const FLOAT16* dst, FLOAT16* src, cons } } static void _MNNDeconvRunForLineDepthwise(const FLOAT16* dst, FLOAT16* src, const FLOAT16* weight, size_t width, size_t src_w_setup, - size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step) { + size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step) { int dx; for (dx = 0; dx < width; ++dx) { auto dst_x = dst + dx * 8; @@ -703,7 +703,7 @@ static void _MNNComputeMatMulForE_1_FP16(const float* AF, const float* BF, float Vec::save(C + 4 * 8 * y + 8 * 2, s2); Vec::save(C + 4 * 8 * y + 8 * 3, s3); } - + for (int y=hC16*4+tId; y 11) { auto s0 = vld1q_f32((float*)(source)); // 00112233 auto s1 = vld1q_f32((float*)(source + srcStride0));// 00112233 auto s2 = vld1q_f32((float*)(source + 2 * srcStride0)); auto s3 = vld1q_f32((float*)(source + 3 * srcStride0)); - + auto s4 = vld1q_f32((float*)(source + 4 * srcStride0)); auto s5 = vld1q_f32((float*)(source + 5 * srcStride0)); auto s6 = vld1q_f32((float*)(source + 6 * srcStride0)); auto s7 = vld1q_f32((float*)(source + 7 * srcStride0)); - + auto s8 = vld1q_f32((float*)(source + 8 * srcStride0)); auto s9 = vld1q_f32((float*)(source + 9 * srcStride0)); auto s10 = vld1q_f32((float*)(source + 10 * srcStride0)); auto s11 = vld1q_f32((float*)(source + 11 * srcStride0)); - + auto zip1s01 = vzip1q_f32(s0, s1); // 00001111 auto zip1s23 = vzip1q_f32(s2, s3); // 00001111 auto zip1s45 = vzip1q_f32(s4, s5); // 00001111 auto zip1s67 = vzip1q_f32(s6, s7); // 00001111 auto zip1s89 = vzip1q_f32(s8, s9); // 00001111 auto zip1s1011 = vzip1q_f32(s10, s11); // 00001111 - + auto zip2s01 = vzip2q_f32(s0, s1); // 22223333 auto zip2s23 = vzip2q_f32(s2, s3); // 22223333 auto zip2s45 = vzip2q_f32(s4, s5); // 22223333 auto zip2s67 = vzip2q_f32(s6, s7); // 22223333 auto zip2s89 = vzip2q_f32(s8, s9); // 22223333 auto zip2s1011 = vzip2q_f32(s10, s11); // 22223333 - + auto zip1s0123_01 = vzip1q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 00000000 auto zip1s4567_01 = vzip1q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67); auto zip1s891011_01 = vzip1q_f64((float64x2_t)zip1s89, (float64x2_t)zip1s1011); - + auto zip2s0123_01 = vzip2q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 11111111 auto zip2s4567_01 = vzip2q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67); auto zip2s891011_01 = vzip2q_f64((float64x2_t)zip1s89, (float64x2_t)zip1s1011); - + auto zip1s0123_23 = vzip1q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 22222222 auto zip1s4567_23 = vzip1q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67); auto zip1s891011_23 = vzip1q_f64((float64x2_t)zip2s89, (float64x2_t)zip2s1011); - + auto zip2s0123_23 = vzip2q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 33333333 auto zip2s4567_23 = vzip2q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67); auto zip2s891011_23 = vzip2q_f64((float64x2_t)zip2s89, (float64x2_t)zip2s1011); - + vst1q_f64((float64_t*)dest, zip1s0123_01); vst1q_f64((float64_t*)(dest + 8), zip1s4567_01); vst1q_f64((float64_t*)(dest + 16), zip1s891011_01); - + vst1q_f64((float64_t*)(dest + dstStride0), zip2s0123_01); vst1q_f64((float64_t*)(dest + dstStride0 + 8), zip2s4567_01); 
vst1q_f64((float64_t*)(dest + dstStride0 + 16), zip2s891011_01); - + vst1q_f64((float64_t*)(dest + 2 * dstStride0), zip1s0123_23); vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 8), zip1s4567_23); vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 16), zip1s891011_23); - + vst1q_f64((float64_t*)(dest + 3 * dstStride0), zip2s0123_23); vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 8), zip2s4567_23); vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 16), zip2s891011_23); - + dest += 24; e -= 12; source += (12 * srcStride0); } - + if (e > 7) { auto s0 = vld1q_f32((float*)(source)); // 00112233 auto s1 = vld1q_f32((float*)(source + srcStride0));// 00112233 auto s2 = vld1q_f32((float*)(source + 2 * srcStride0)); auto s3 = vld1q_f32((float*)(source + 3 * srcStride0)); - + auto s4 = vld1q_f32((float*)(source + 4 * srcStride0)); auto s5 = vld1q_f32((float*)(source + 5 * srcStride0)); auto s6 = vld1q_f32((float*)(source + 6 * srcStride0)); auto s7 = vld1q_f32((float*)(source + 7 * srcStride0)); - + auto zip1s01 = vzip1q_f32(s0, s1); // 00001111 auto zip1s23 = vzip1q_f32(s2, s3); // 00001111 auto zip1s45 = vzip1q_f32(s4, s5); // 00001111 auto zip1s67 = vzip1q_f32(s6, s7); // 00001111 - + auto zip2s01 = vzip2q_f32(s0, s1); // 22223333 auto zip2s23 = vzip2q_f32(s2, s3); // 22223333 auto zip2s45 = vzip2q_f32(s4, s5); // 22223333 auto zip2s67 = vzip2q_f32(s6, s7); // 22223333 - + auto zip1s0123_01 = vzip1q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 00000000 auto zip1s4567_01 = vzip1q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67); - + auto zip2s0123_01 = vzip2q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 11111111 auto zip2s4567_01 = vzip2q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67); - + auto zip1s0123_23 = vzip1q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 22222222 auto zip1s4567_23 = vzip1q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67); - + auto zip2s0123_23 = vzip2q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 33333333 auto zip2s4567_23 = vzip2q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67); - + vst1q_f64((float64_t*)dest, zip1s0123_01); vst1q_f64((float64_t*)(dest + 8), zip1s4567_01); - + vst1q_f64((float64_t*)(dest + dstStride0), zip2s0123_01); vst1q_f64((float64_t*)(dest + dstStride0 + 8), zip2s4567_01); - + vst1q_f64((float64_t*)(dest + 2 * dstStride0), zip1s0123_23); vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 8), zip1s4567_23); - + vst1q_f64((float64_t*)(dest + 3 * dstStride0), zip2s0123_23); vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 8), zip2s4567_23); - + dest += 16; e -= 8; source += (8 * srcStride0); } - + if (e > 3) { auto s0 = vld1q_f32((float*)(source)); // 00112233 auto s1 = vld1q_f32((float*)(source + srcStride0));// 00112233 auto s2 = vld1q_f32((float*)(source + 2 * srcStride0)); auto s3 = vld1q_f32((float*)(source + 3 * srcStride0)); - + auto zip1s01 = vzip1q_f32(s0, s1); // 00001111 auto zip1s23 = vzip1q_f32(s2, s3); // 00001111 - + auto zip2s01 = vzip2q_f32(s0, s1); // 22223333 auto zip2s23 = vzip2q_f32(s2, s3); // 22223333 - + auto zip1s0123_01 = vzip1q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 00000000 - + auto zip2s0123_01 = vzip2q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 11111111 - + auto zip1s0123_23 = vzip1q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 22222222 - + auto zip2s0123_23 = vzip2q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 33333333 - + vst1q_f64((float64_t*)dest, zip1s0123_01); vst1q_f64((float64_t*)(dest + dstStride0), zip2s0123_01); vst1q_f64((float64_t*)(dest + 2 * dstStride0), 
zip1s0123_23); vst1q_f64((float64_t*)(dest + 3 * dstStride0), zip2s0123_23); - + dest += 8; e -= 4; source += (4 * srcStride0); } while (e > 0) { auto s0 = vld1q_f32((float*)(source)); // 00112233 - + ((float*)dest)[0] = s0[0]; ((float*)(dest + dstStride0))[0] = s0[1]; ((float*)(dest + 2 * dstStride0))[0] = s0[2]; ((float*)(dest + 3 * dstStride0))[0] = s0[3]; - + dest += 2; e -= 1; source += srcStride0; @@ -1201,7 +1201,7 @@ static void Sme2MNNPackForMatMul_A_FP16(float* destOrigin, float const** sourceG sourceN += (eReal * pack); destN += (4 * dstStride0); } // l>7 - + if (l > 3) { auto source = sourceN; auto dest = destN; @@ -1212,22 +1212,22 @@ static void Sme2MNNPackForMatMul_A_FP16(float* destOrigin, float const** sourceG auto s1 = vld1_f32((float*)(source + srcStride0));// 0011 auto s2 = vld1_f32((float*)(source + 2 * srcStride0)); auto s3 = vld1_f32((float*)(source + 3 * srcStride0)); - + auto s4 = vld1_f32((float*)(source + 4 * srcStride0)); auto s5 = vld1_f32((float*)(source + 5 * srcStride0)); auto s6 = vld1_f32((float*)(source + 6 * srcStride0)); auto s7 = vld1_f32((float*)(source + 7 * srcStride0)); - + auto s8 = vld1_f32((float*)(source + 8 * srcStride0)); auto s9 = vld1_f32((float*)(source + 9 * srcStride0)); auto s10 = vld1_f32((float*)(source + 10 * srcStride0)); auto s11 = vld1_f32((float*)(source + 11 * srcStride0)); - + auto s12 = vld1_f32((float*)(source + 12 * srcStride0)); auto s13 = vld1_f32((float*)(source + 13 * srcStride0)); auto s14 = vld1_f32((float*)(source + 14 * srcStride0)); auto s15 = vld1_f32((float*)(source + 15 * srcStride0)); - + auto zip1s01 = vzip1_f32(s0, s1); // 0000 auto zip1s23 = vzip1_f32(s2, s3); // 0000 auto zip1s45 = vzip1_f32(s4, s5); // 0000 @@ -1236,7 +1236,7 @@ static void Sme2MNNPackForMatMul_A_FP16(float* destOrigin, float const** sourceG auto zip1s1011 = vzip1_f32(s10, s11); // 0000 auto zip1s1213 = vzip1_f32(s12, s13); // 0000 auto zip1s1415 = vzip1_f32(s14, s15); // 0000 - + auto zip2s01 = vzip2_f32(s0, s1); // 1111 auto zip2s23 = vzip2_f32(s2, s3); // 1111 auto zip2s45 = vzip2_f32(s4, s5); // 1111 @@ -1245,7 +1245,7 @@ static void Sme2MNNPackForMatMul_A_FP16(float* destOrigin, float const** sourceG auto zip2s1011 = vzip2_f32(s10, s11); // 1111 auto zip2s1213 = vzip2_f32(s12, s13); // 1111 auto zip2s1415 = vzip2_f32(s14, s15); // 1111 - + vst1_f32((float32_t*)dest, zip1s01); vst1_f32((float32_t*)(dest + 4), zip1s23); vst1_f32((float32_t*)(dest + 8), zip1s45); @@ -1254,7 +1254,7 @@ static void Sme2MNNPackForMatMul_A_FP16(float* destOrigin, float const** sourceG vst1_f32((float32_t*)(dest + 20), zip1s1011); vst1_f32((float32_t*)(dest + 24), zip1s1213); vst1_f32((float32_t*)(dest + 28), zip1s1415); - + vst1_f32((float32_t*)(dest + dstStride0), zip2s01); vst1_f32((float32_t*)(dest + dstStride0 + 4), zip2s23); vst1_f32((float32_t*)(dest + dstStride0 + 8), zip2s45); @@ -1263,115 +1263,115 @@ static void Sme2MNNPackForMatMul_A_FP16(float* destOrigin, float const** sourceG vst1_f32((float32_t*)(dest + dstStride0 + 20), zip2s1011); vst1_f32((float32_t*)(dest + dstStride0 + 24), zip2s1213); vst1_f32((float32_t*)(dest + dstStride0 + 28), zip2s1415); - - + + dest += 32; e -= eDest; } - + if (e > 11) { auto s0 = vld1_f32((float*)(source)); // 0011 auto s1 = vld1_f32((float*)(source + srcStride0));// 0011 auto s2 = vld1_f32((float*)(source + 2 * srcStride0)); auto s3 = vld1_f32((float*)(source + 3 * srcStride0)); - + auto s4 = vld1_f32((float*)(source + 4 * srcStride0)); auto s5 = vld1_f32((float*)(source + 5 * srcStride0)); auto s6 = 
vld1_f32((float*)(source + 6 * srcStride0)); auto s7 = vld1_f32((float*)(source + 7 * srcStride0)); - + auto s8 = vld1_f32((float*)(source + 8 * srcStride0)); auto s9 = vld1_f32((float*)(source + 9 * srcStride0)); auto s10 = vld1_f32((float*)(source + 10 * srcStride0)); auto s11 = vld1_f32((float*)(source + 11 * srcStride0)); - + auto zip1s01 = vzip1_f32(s0, s1); // 0000 auto zip1s23 = vzip1_f32(s2, s3); // 0000 auto zip1s45 = vzip1_f32(s4, s5); // 0000 auto zip1s67 = vzip1_f32(s6, s7); // 0000 auto zip1s89 = vzip1_f32(s8, s9); // 0000 auto zip1s1011 = vzip1_f32(s10, s11); // 0000 - + auto zip2s01 = vzip2_f32(s0, s1); // 1111 auto zip2s23 = vzip2_f32(s2, s3); // 1111 auto zip2s45 = vzip2_f32(s4, s5); // 1111 auto zip2s67 = vzip2_f32(s6, s7); // 1111 auto zip2s89 = vzip2_f32(s8, s9); // 1111 auto zip2s1011 = vzip2_f32(s10, s11); // 1111 - + vst1_f32((float32_t*)dest, zip1s01); vst1_f32((float32_t*)(dest + 4), zip1s23); vst1_f32((float32_t*)(dest + 8), zip1s45); vst1_f32((float32_t*)(dest + 12), zip1s67); vst1_f32((float32_t*)(dest + 16), zip1s89); vst1_f32((float32_t*)(dest + 20), zip1s1011); - + vst1_f32((float32_t*)(dest + dstStride0), zip2s01); vst1_f32((float32_t*)(dest + dstStride0 + 4), zip2s23); vst1_f32((float32_t*)(dest + dstStride0 + 8), zip2s45); vst1_f32((float32_t*)(dest + dstStride0 + 12), zip2s67); vst1_f32((float32_t*)(dest + dstStride0 + 16), zip2s89); vst1_f32((float32_t*)(dest + dstStride0 + 20), zip2s1011); - + dest += 24; e -= 12; source += (12 * srcStride0); } - + if (e > 7) { auto s0 = vld1_f32((float*)(source)); // 0011 auto s1 = vld1_f32((float*)(source + srcStride0));// 0011 auto s2 = vld1_f32((float*)(source + 2 * srcStride0)); auto s3 = vld1_f32((float*)(source + 3 * srcStride0)); - + auto s4 = vld1_f32((float*)(source + 4 * srcStride0)); auto s5 = vld1_f32((float*)(source + 5 * srcStride0)); auto s6 = vld1_f32((float*)(source + 6 * srcStride0)); auto s7 = vld1_f32((float*)(source + 7 * srcStride0)); - + auto zip1s01 = vzip1_f32(s0, s1); // 0000 auto zip1s23 = vzip1_f32(s2, s3); // 0000 auto zip1s45 = vzip1_f32(s4, s5); // 0000 auto zip1s67 = vzip1_f32(s6, s7); // 0000 - + auto zip2s01 = vzip2_f32(s0, s1); // 1111 auto zip2s23 = vzip2_f32(s2, s3); // 1111 auto zip2s45 = vzip2_f32(s4, s5); // 1111 auto zip2s67 = vzip2_f32(s6, s7); // 1111 - + vst1_f32((float32_t*)dest, zip1s01); vst1_f32((float32_t*)(dest + 4), zip1s23); vst1_f32((float32_t*)(dest + 8), zip1s45); vst1_f32((float32_t*)(dest + 12), zip1s67); - + vst1_f32((float32_t*)(dest + dstStride0), zip2s01); vst1_f32((float32_t*)(dest + dstStride0 + 4), zip2s23); vst1_f32((float32_t*)(dest + dstStride0 + 8), zip2s45); vst1_f32((float32_t*)(dest + dstStride0 + 12), zip2s67); - + dest += 16; e -= 8; source += (8 * srcStride0); } - + if (e > 3) { auto s0 = vld1_f32((float*)(source)); // 0011 auto s1 = vld1_f32((float*)(source + srcStride0));// 0011 auto s2 = vld1_f32((float*)(source + 2 * srcStride0)); auto s3 = vld1_f32((float*)(source + 3 * srcStride0)); - + auto zip1s01 = vzip1_f32(s0, s1); // 0000 auto zip1s23 = vzip1_f32(s2, s3); // 0000 - + auto zip2s01 = vzip2_f32(s0, s1); // 1111 auto zip2s23 = vzip2_f32(s2, s3); // 1111 - + vst1_f32((float32_t*)dest, zip1s01); vst1_f32((float32_t*)(dest + 4), zip1s23); - + vst1_f32((float32_t*)(dest + dstStride0), zip2s01); vst1_f32((float32_t*)(dest + dstStride0 + 4), zip2s23); - + dest += 8; e -= 4; source += (4 * srcStride0); @@ -1379,22 +1379,22 @@ static void Sme2MNNPackForMatMul_A_FP16(float* destOrigin, float const** sourceG if (e > 1) { auto s0 = 
vld1_f32((float*)(source)); // 0011 auto s1 = vld1_f32((float*)(source + srcStride0));// 0011 - + auto zip1s01 = vzip1_f32(s0, s1); // 0000 - + auto zip2s01 = vzip2_f32(s0, s1); // 1111 - + vst1_f32((float32_t*)dest, zip1s01); - + vst1_f32((float32_t*)(dest + dstStride0), zip2s01); - + dest += 4; e -= 2; source += (2 * srcStride0); } if (e > 0) { auto s0 = vld1_f32((float*)(source)); // 0011 - + ((float*)dest)[0] = s0[0]; ((float*)(dest + dstStride0))[0] = s0[1]; } @@ -1423,6 +1423,730 @@ static void Sme2MNNPackForMatMul_A_FP16(float* destOrigin, float const** sourceG } } +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE +static void MNNAttenPackAndScaleSingleHead(float* dst, const float* srcHeadBase, size_t srcRowStride, const float* scale, const int32_t* units, size_t seqLen, size_t headDim) { + const int32_t eP = units[0]; + const int32_t lP = units[1]; + + if (lP != 1 && lP != 2) { + MNN_ERROR("This function only supports lP=1 or 2\n"); + return; + } + + const float scaleVal = scale[0]; + const float16x8_t vScale = vdupq_n_f16(scaleVal); + + const size_t packedHeadDim = UP_DIV(headDim, lP); + const size_t dstStrideDOuter = (size_t)eP * lP; + const size_t dstStrideSOuter = packedHeadDim * dstStrideDOuter; + + for (int s = 0; s < seqLen; ++s) { + const int sOuter = s / eP; + const int sInner = s % eP; + const FLOAT16* srcRowPtr = (FLOAT16*)srcHeadBase + s * srcRowStride; + FLOAT16* dstBasePtr = (FLOAT16*)dst + sOuter * dstStrideSOuter + sInner * lP; + + if (lP == 1) { + size_t d = 0; + for (; d + 7 < headDim; d += 8) { + float16x8_t sVec = vld1q_f16(srcRowPtr + d); + sVec = vmulq_f16(sVec, vScale); + + dstBasePtr[(d + 0) * dstStrideDOuter] = sVec[0]; + dstBasePtr[(d + 1) * dstStrideDOuter] = sVec[1]; + dstBasePtr[(d + 2) * dstStrideDOuter] = sVec[2]; + dstBasePtr[(d + 3) * dstStrideDOuter] = sVec[3]; + dstBasePtr[(d + 4) * dstStrideDOuter] = sVec[4]; + dstBasePtr[(d + 5) * dstStrideDOuter] = sVec[5]; + dstBasePtr[(d + 6) * dstStrideDOuter] = sVec[6]; + dstBasePtr[(d + 7) * dstStrideDOuter] = sVec[7]; + } + for (; d < headDim; ++d) { + dstBasePtr[d * dstStrideDOuter] = srcRowPtr[d] * scaleVal; + } + } else { // lP == 2 + const FLOAT16* srcDPtr = srcRowPtr; + FLOAT16* dstDPtr = dstBasePtr; + size_t dRealSize = headDim; + + while (dRealSize >= 16) { + float16x8_t s0 = vld1q_f16(srcDPtr); + float16x8_t s1 = vld1q_f16(srcDPtr + 8); + s0 = vmulq_f16(s0, vScale); + s1 = vmulq_f16(s1, vScale); + + float16x4_t lowS0_f16 = vget_low_f16(s0); // {s0, s1, s2, s3} + float16x4_t highS0_f16 = vget_high_f16(s0); // {s4, s5, s6, s7} + uint32x2_t lowS0_u32 = vreinterpret_u32_f16(lowS0_f16); + uint32x2_t highS0_u32 = vreinterpret_u32_f16(highS0_f16); + + *((uint32_t*)(dstDPtr + 0 * dstStrideDOuter)) = vget_lane_u32(lowS0_u32, 0); // Store pair {s0, s1} + *((uint32_t*)(dstDPtr + 1 * dstStrideDOuter)) = vget_lane_u32(lowS0_u32, 1); // Store pair {s2, s3} + *((uint32_t*)(dstDPtr + 2 * dstStrideDOuter)) = vget_lane_u32(highS0_u32, 0); // Store pair {s4, s5} + *((uint32_t*)(dstDPtr + 3 * dstStrideDOuter)) = vget_lane_u32(highS0_u32, 1); // Store pair {s6, s7} + + float16x4_t lowS1_f16 = vget_low_f16(s1); // {s8, s9, s10, s11} + float16x4_t highS1_f16 = vget_high_f16(s1); // {s12, s13, s14, s15} + uint32x2_t lowS1_u32 = vreinterpret_u32_f16(lowS1_f16); + uint32x2_t highS1_u32 = vreinterpret_u32_f16(highS1_f16); + + *((uint32_t*)(dstDPtr + 4 * dstStrideDOuter)) = vget_lane_u32(lowS1_u32, 0); + *((uint32_t*)(dstDPtr + 5 * dstStrideDOuter)) = vget_lane_u32(lowS1_u32, 1); + *((uint32_t*)(dstDPtr + 6 * dstStrideDOuter)) = 
vget_lane_u32(highS1_u32, 0); + *((uint32_t*)(dstDPtr + 7 * dstStrideDOuter)) = vget_lane_u32(highS1_u32, 1); + + dRealSize -= 16; + srcDPtr += 16; + dstDPtr += 8 * dstStrideDOuter; + } + // Remainder loop with padding + while (dRealSize > 0) { + if (dRealSize >= 2) { + dstDPtr[0] = srcDPtr[0] * scaleVal; + dstDPtr[1] = srcDPtr[1] * scaleVal; + + dRealSize -= 2; + srcDPtr += 2; + dstDPtr += dstStrideDOuter; + } else { // dRealSize == 1 + dstDPtr[0] = srcDPtr[0] * scaleVal; + dstDPtr[1] = (FLOAT16)0.0f; // Pad with zero + dRealSize = 0; + } + } + } + } +} + +static void MNNFlashAttentionUpdateBlockOutput( float* dst, float* src, const float* scale, const float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes) { + auto dstPtr = (float16_t*)dst; + auto srcPtr = (float16_t*)src; + const auto stride0 = plane * pack; + + if (idx == 0) { + memcpy(dst, src, size * bytes); + } else { + for (int j = 0; j < depthQuad; ++j) { + const auto baseOffset = j * stride0; + int i = 0; + const int plane4 = plane - (plane % 4); + for (; i < plane4; i += 4) { + + auto pdst0 = dstPtr + baseOffset + (i + 0) * pack; + auto psrc0 = srcPtr + baseOffset + (i + 0) * pack; + auto pdst1 = dstPtr + baseOffset + (i + 1) * pack; + auto psrc1 = srcPtr + baseOffset + (i + 1) * pack; + auto pdst2 = dstPtr + baseOffset + (i + 2) * pack; + auto psrc2 = srcPtr + baseOffset + (i + 2) * pack; + auto pdst3 = dstPtr + baseOffset + (i + 3) * pack; + auto psrc3 = srcPtr + baseOffset + (i + 3) * pack; + + float16x8_t src0 = vld1q_f16(psrc0); + float16x8_t dst0 = vld1q_f16(pdst0); + float16x8_t src1 = vld1q_f16(psrc1); + float16x8_t dst1 = vld1q_f16(pdst1); + float16x8_t src2 = vld1q_f16(psrc2); + float16x8_t dst2 = vld1q_f16(pdst2); + float16x8_t src3 = vld1q_f16(psrc3); + float16x8_t dst3 = vld1q_f16(pdst3); + + float32x4_t svec0 = vdupq_n_f32(scale[i + 0]); + float32x4_t svec1 = vdupq_n_f32(scale[i + 1]); + float32x4_t svec2 = vdupq_n_f32(scale[i + 2]); + float32x4_t svec3 = vdupq_n_f32(scale[i + 3]); + + + float32x4_t res00 = vfmaq_f32(vcvt_f32_f16(vget_low_f16(src0)), vcvt_f32_f16(vget_low_f16(dst0)), svec0); + float32x4_t res10 = vfmaq_f32(vcvt_f32_f16(vget_high_f16(src0)), vcvt_f32_f16(vget_high_f16(dst0)), svec0); + + float32x4_t res01 = vfmaq_f32(vcvt_f32_f16(vget_low_f16(src1)), vcvt_f32_f16(vget_low_f16(dst1)), svec1); + float32x4_t res11 = vfmaq_f32(vcvt_f32_f16(vget_high_f16(src1)), vcvt_f32_f16(vget_high_f16(dst1)), svec1); + + float32x4_t res02 = vfmaq_f32(vcvt_f32_f16(vget_low_f16(src2)), vcvt_f32_f16(vget_low_f16(dst2)), svec2); + float32x4_t res12 = vfmaq_f32(vcvt_f32_f16(vget_high_f16(src2)), vcvt_f32_f16(vget_high_f16(dst2)), svec2); + + float32x4_t res03 = vfmaq_f32(vcvt_f32_f16(vget_low_f16(src3)), vcvt_f32_f16(vget_low_f16(dst3)), svec3); + float32x4_t res13 = vfmaq_f32(vcvt_f32_f16(vget_high_f16(src3)), vcvt_f32_f16(vget_high_f16(dst3)), svec3); + + vst1q_f16(pdst0, vcombine_f16(vcvt_f16_f32(res00), vcvt_f16_f32(res10))); + vst1q_f16(pdst1, vcombine_f16(vcvt_f16_f32(res01), vcvt_f16_f32(res11))); + vst1q_f16(pdst2, vcombine_f16(vcvt_f16_f32(res02), vcvt_f16_f32(res12))); + vst1q_f16(pdst3, vcombine_f16(vcvt_f16_f32(res03), vcvt_f16_f32(res13))); + } + + for (; i < plane; ++i) { + auto pdst = dstPtr + baseOffset + i * pack; + auto psrc = srcPtr + baseOffset + i * pack; + + float16x8_t srcF16 = vld1q_f16(psrc); + float16x8_t dstF16 = vld1q_f16(pdst); + float32x4_t svec = vdupq_n_f32(scale[i]); + + float32x4_t s0 = vcvt_f32_f16(vget_low_f16(srcF16)); + float32x4_t s1 = 
vcvt_f32_f16(vget_high_f16(srcF16)); + float32x4_t d0 = vcvt_f32_f16(vget_low_f16(dstF16)); + float32x4_t d1 = vcvt_f32_f16(vget_high_f16(dstF16)); + + float32x4_t res0 = vfmaq_f32(s0, d0, svec); + float32x4_t res1 = vfmaq_f32(s1, d1, svec); + + vst1q_f16(pdst, vcombine_f16(vcvt_f16_f32(res0), vcvt_f16_f32(res1))); + } + } + } + + if (idx == kvBlocks - 1) { + for (int j = 0; j < depthQuad; ++j) { + const auto baseOffset = j * stride0; + int i = 0; + const int plane4 = plane - (plane % 4); + for (; i < plane4; i += 4) { + auto pdst0 = dstPtr + baseOffset + (i + 0) * pack; + auto pdst1 = dstPtr + baseOffset + (i + 1) * pack; + auto pdst2 = dstPtr + baseOffset + (i + 2) * pack; + auto pdst3 = dstPtr + baseOffset + (i + 3) * pack; + + float16x8_t dst0 = vld1q_f16(pdst0); + float16x8_t dst1 = vld1q_f16(pdst1); + float16x8_t dst2 = vld1q_f16(pdst2); + float16x8_t dst3 = vld1q_f16(pdst3); + + float32x4_t ns0 = vdupq_n_f32(1.0f / normalizeScale[i + 0]); + float32x4_t ns1 = vdupq_n_f32(1.0f / normalizeScale[i + 1]); + float32x4_t ns2 = vdupq_n_f32(1.0f / normalizeScale[i + 2]); + float32x4_t ns3 = vdupq_n_f32(1.0f / normalizeScale[i + 3]); + + float32x4_t d00 = vmulq_f32(vcvt_f32_f16(vget_low_f16(dst0)), ns0); + float32x4_t d10 = vmulq_f32(vcvt_f32_f16(vget_high_f16(dst0)), ns0); + float32x4_t d01 = vmulq_f32(vcvt_f32_f16(vget_low_f16(dst1)), ns1); + float32x4_t d11 = vmulq_f32(vcvt_f32_f16(vget_high_f16(dst1)), ns1); + float32x4_t d02 = vmulq_f32(vcvt_f32_f16(vget_low_f16(dst2)), ns2); + float32x4_t d12 = vmulq_f32(vcvt_f32_f16(vget_high_f16(dst2)), ns2); + float32x4_t d03 = vmulq_f32(vcvt_f32_f16(vget_low_f16(dst3)), ns3); + float32x4_t d13 = vmulq_f32(vcvt_f32_f16(vget_high_f16(dst3)), ns3); + + vst1q_f16(pdst0, vcombine_f16(vcvt_f16_f32(d00), vcvt_f16_f32(d10))); + vst1q_f16(pdst1, vcombine_f16(vcvt_f16_f32(d01), vcvt_f16_f32(d11))); + vst1q_f16(pdst2, vcombine_f16(vcvt_f16_f32(d02), vcvt_f16_f32(d12))); + vst1q_f16(pdst3, vcombine_f16(vcvt_f16_f32(d03), vcvt_f16_f32(d13))); + } + + for (; i < plane; ++i) { + auto pdst = dstPtr + baseOffset + i * pack; + float32x4_t nsvec = vdupq_n_f32(1.0f / normalizeScale[i]); + + float16x8_t dstF16 = vld1q_f16(pdst); + float32x4_t d0 = vcvt_f32_f16(vget_low_f16(dstF16)); + float32x4_t d1 = vcvt_f32_f16(vget_high_f16(dstF16)); + + d0 = vmulq_f32(d0, nsvec); + d1 = vmulq_f32(d1, nsvec); + + vst1q_f16(pdst, vcombine_f16(vcvt_f16_f32(d0), vcvt_f16_f32(d1))); + } + } + } +} + +static void MNNAttenUnpackAndConvertFp16(float* dst, float* src, size_t depth, size_t planesize, int pack) { + // src: (UP_DIV(depth, pack), planesize, pack), float16 + // dst: (planesize, depth), float32 + // pack=8 + + if (planesize == 1) { + MNNDequantizeFP16((int16_t*)src, dst, depth); + return; // no need to convert + } + const auto depthDiv8 = UP_DIV(depth, pack); + const auto srcStep = pack * planesize; + const auto dstStep = depth; + + auto remainDepth = depth % pack; + auto depthQuad = depthDiv8; + if (remainDepth > 0) { + depthQuad -= 1; // last quad is not full + } + + for (int i = 0; i < depthQuad; ++i) { + auto realsize = planesize; + auto srcPtr = (FLOAT16*)src + i * srcStep; + auto dstPtr = (float*)dst + i * pack; + while (realsize >= 8) { + float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack); + float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack); + float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack); + float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack); + float16x8_t s4_f16 = vld1q_f16(srcPtr + 4 * pack); + float16x8_t s5_f16 = vld1q_f16(srcPtr + 5 * pack); + float16x8_t s6_f16 = 
vld1q_f16(srcPtr + 6 * pack); + float16x8_t s7_f16 = vld1q_f16(srcPtr + 7 * pack); + + float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16)); + float32x4_t d01_f32 = vcvt_f32_f16(vget_high_f16(s0_f16)); + float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16)); + float32x4_t d11_f32 = vcvt_f32_f16(vget_high_f16(s1_f16)); + float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16)); + float32x4_t d21_f32 = vcvt_f32_f16(vget_high_f16(s2_f16)); + float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16)); + float32x4_t d31_f32 = vcvt_f32_f16(vget_high_f16(s3_f16)); + float32x4_t d40_f32 = vcvt_f32_f16(vget_low_f16(s4_f16)); + float32x4_t d41_f32 = vcvt_f32_f16(vget_high_f16(s4_f16)); + float32x4_t d50_f32 = vcvt_f32_f16(vget_low_f16(s5_f16)); + float32x4_t d51_f32 = vcvt_f32_f16(vget_high_f16(s5_f16)); + float32x4_t d60_f32 = vcvt_f32_f16(vget_low_f16(s6_f16)); + float32x4_t d61_f32 = vcvt_f32_f16(vget_high_f16(s6_f16)); + float32x4_t d70_f32 = vcvt_f32_f16(vget_low_f16(s7_f16)); + float32x4_t d71_f32 = vcvt_f32_f16(vget_high_f16(s7_f16)); + + vst1q_f32(dstPtr + 0 * dstStep, d00_f32); vst1q_f32(dstPtr + 0 * dstStep + 4, d01_f32); + vst1q_f32(dstPtr + 1 * dstStep, d10_f32); vst1q_f32(dstPtr + 1 * dstStep + 4, d11_f32); + vst1q_f32(dstPtr + 2 * dstStep, d20_f32); vst1q_f32(dstPtr + 2 * dstStep + 4, d21_f32); + vst1q_f32(dstPtr + 3 * dstStep, d30_f32); vst1q_f32(dstPtr + 3 * dstStep + 4, d31_f32); + vst1q_f32(dstPtr + 4 * dstStep, d40_f32); vst1q_f32(dstPtr + 4 * dstStep + 4, d41_f32); + vst1q_f32(dstPtr + 5 * dstStep, d50_f32); vst1q_f32(dstPtr + 5 * dstStep + 4, d51_f32); + vst1q_f32(dstPtr + 6 * dstStep, d60_f32); vst1q_f32(dstPtr + 6 * dstStep + 4, d61_f32); + vst1q_f32(dstPtr + 7 * dstStep, d70_f32); vst1q_f32(dstPtr + 7 * dstStep + 4, d71_f32); + + srcPtr += 8 * pack; + dstPtr += 8 * dstStep; + realsize -= 8; + } + if (realsize >= 4) { + float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack); + float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack); + float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack); + float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack); + + float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16)); + float32x4_t d01_f32 = vcvt_f32_f16(vget_high_f16(s0_f16)); + float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16)); + float32x4_t d11_f32 = vcvt_f32_f16(vget_high_f16(s1_f16)); + float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16)); + float32x4_t d21_f32 = vcvt_f32_f16(vget_high_f16(s2_f16)); + float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16)); + float32x4_t d31_f32 = vcvt_f32_f16(vget_high_f16(s3_f16)); + + vst1q_f32(dstPtr + 0 * dstStep, d00_f32); vst1q_f32(dstPtr + 0 * dstStep + 4, d01_f32); + vst1q_f32(dstPtr + 1 * dstStep, d10_f32); vst1q_f32(dstPtr + 1 * dstStep + 4, d11_f32); + vst1q_f32(dstPtr + 2 * dstStep, d20_f32); vst1q_f32(dstPtr + 2 * dstStep + 4, d21_f32); + vst1q_f32(dstPtr + 3 * dstStep, d30_f32); vst1q_f32(dstPtr + 3 * dstStep + 4, d31_f32); + + srcPtr += 4 * pack; + dstPtr += 4 * dstStep; + realsize -= 4; + } + while (realsize > 0) { + auto s0_fp16 = vld1q_f16(srcPtr); + auto s00_fp32 = vcvt_f32_f16(vget_low_f16(s0_fp16)); + auto s01_fp32 = vcvt_f32_f16(vget_high_f16(s0_fp16)); + vst1q_f32(dstPtr, s00_fp32); + vst1q_f32(dstPtr + 4, s01_fp32); + srcPtr += pack; + dstPtr += dstStep; + realsize--; + } + } + + // process remain depth < 8 + if (remainDepth >= 4) { + auto realsize = planesize; + auto srcPtr = (FLOAT16*)src + (depthDiv8 - 1) * srcStep; + auto dstPtr = (float*)dst + (depthDiv8 - 1) * pack; + auto extraDepth = remainDepth - 4; + + float tmp0[4]; + 
float tmp1[4]; + float tmp2[4]; + float tmp3[4]; + float tmp4[4]; + float tmp5[4]; + float tmp6[4]; + float tmp7[4]; + + while (realsize >= 8) { + float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack); + float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack); + float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack); + float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack); + float16x8_t s4_f16 = vld1q_f16(srcPtr + 4 * pack); + float16x8_t s5_f16 = vld1q_f16(srcPtr + 5 * pack); + float16x8_t s6_f16 = vld1q_f16(srcPtr + 6 * pack); + float16x8_t s7_f16 = vld1q_f16(srcPtr + 7 * pack); + + + float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16)); + float32x4_t d01_f32 = vcvt_f32_f16(vget_high_f16(s0_f16)); + float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16)); + float32x4_t d11_f32 = vcvt_f32_f16(vget_high_f16(s1_f16)); + float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16)); + float32x4_t d21_f32 = vcvt_f32_f16(vget_high_f16(s2_f16)); + float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16)); + float32x4_t d31_f32 = vcvt_f32_f16(vget_high_f16(s3_f16)); + float32x4_t d40_f32 = vcvt_f32_f16(vget_low_f16(s4_f16)); + float32x4_t d41_f32 = vcvt_f32_f16(vget_high_f16(s4_f16)); + float32x4_t d50_f32 = vcvt_f32_f16(vget_low_f16(s5_f16)); + float32x4_t d51_f32 = vcvt_f32_f16(vget_high_f16(s5_f16)); + float32x4_t d60_f32 = vcvt_f32_f16(vget_low_f16(s6_f16)); + float32x4_t d61_f32 = vcvt_f32_f16(vget_high_f16(s6_f16)); + float32x4_t d70_f32 = vcvt_f32_f16(vget_low_f16(s7_f16)); + float32x4_t d71_f32 = vcvt_f32_f16(vget_high_f16(s7_f16)); + + vst1q_f32(dstPtr + 0 * dstStep, d00_f32); vst1q_f32(tmp0, d01_f32); + vst1q_f32(dstPtr + 1 * dstStep, d10_f32); vst1q_f32(tmp1, d11_f32); + vst1q_f32(dstPtr + 2 * dstStep, d20_f32); vst1q_f32(tmp2, d21_f32); + vst1q_f32(dstPtr + 3 * dstStep, d30_f32); vst1q_f32(tmp3, d31_f32); + vst1q_f32(dstPtr + 4 * dstStep, d40_f32); vst1q_f32(tmp4, d41_f32); + vst1q_f32(dstPtr + 5 * dstStep, d50_f32); vst1q_f32(tmp5, d51_f32); + vst1q_f32(dstPtr + 6 * dstStep, d60_f32); vst1q_f32(tmp6, d61_f32); + vst1q_f32(dstPtr + 7 * dstStep, d70_f32); vst1q_f32(tmp7, d71_f32); + + memcpy(dstPtr + 0 * dstStep + 4, tmp0, sizeof(float) * extraDepth); + memcpy(dstPtr + 1 * dstStep + 4, tmp1, sizeof(float) * extraDepth); + memcpy(dstPtr + 2 * dstStep + 4, tmp2, sizeof(float) * extraDepth); + memcpy(dstPtr + 3 * dstStep + 4, tmp3, sizeof(float) * extraDepth); + memcpy(dstPtr + 4 * dstStep + 4, tmp4, sizeof(float) * extraDepth); + memcpy(dstPtr + 5 * dstStep + 4, tmp5, sizeof(float) * extraDepth); + memcpy(dstPtr + 6 * dstStep + 4, tmp6, sizeof(float) * extraDepth); + memcpy(dstPtr + 7 * dstStep + 4, tmp7, sizeof(float) * extraDepth); + + srcPtr += 8 * pack; + dstPtr += 8 * dstStep; + realsize -= 8; + } + if (realsize >= 4) { + float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack); + float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack); + float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack); + float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack); + + float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16)); + float32x4_t d01_f32 = vcvt_f32_f16(vget_high_f16(s0_f16)); + float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16)); + float32x4_t d11_f32 = vcvt_f32_f16(vget_high_f16(s1_f16)); + float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16)); + float32x4_t d21_f32 = vcvt_f32_f16(vget_high_f16(s2_f16)); + float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16)); + float32x4_t d31_f32 = vcvt_f32_f16(vget_high_f16(s3_f16)); + + vst1q_f32(dstPtr + 0 * dstStep, d00_f32); vst1q_f32(tmp0, d01_f32); + vst1q_f32(dstPtr + 1 * 
dstStep, d10_f32); vst1q_f32(tmp1, d11_f32); + vst1q_f32(dstPtr + 2 * dstStep, d20_f32); vst1q_f32(tmp2, d21_f32); + vst1q_f32(dstPtr + 3 * dstStep, d30_f32); vst1q_f32(tmp3, d31_f32); + + memcpy(dstPtr + 0 * dstStep + 4, tmp0, sizeof(float) * extraDepth); + memcpy(dstPtr + 1 * dstStep + 4, tmp1, sizeof(float) * extraDepth); + memcpy(dstPtr + 2 * dstStep + 4, tmp2, sizeof(float) * extraDepth); + memcpy(dstPtr + 3 * dstStep + 4, tmp3, sizeof(float) * extraDepth); + + srcPtr += 4 * pack; + dstPtr += 4 * dstStep; + realsize -= 4; + } + while (realsize > 0) { + auto s0_fp16 = vld1q_f16(srcPtr); + auto d00_fp32 = vcvt_f32_f16(vget_low_f16(s0_fp16)); + auto d01_fp32 = vcvt_f32_f16(vget_high_f16(s0_fp16)); + vst1q_f32(dstPtr, d00_fp32); + vst1q_f32(tmp0, d01_fp32); + memcpy(dstPtr + 4, tmp0, sizeof(float) * extraDepth); + srcPtr += pack; + dstPtr += dstStep; + realsize--; + } + } + + if (remainDepth > 0 && remainDepth < 4) { + auto realsize = planesize; + auto srcPtr = (FLOAT16*)src + (depthDiv8 - 1) * srcStep; + auto dstPtr = (float*)dst + (depthDiv8 - 1) * pack; + + float tmp0[4]; + float tmp1[4]; + float tmp2[4]; + float tmp3[4]; + float tmp4[4]; + float tmp5[4]; + float tmp6[4]; + float tmp7[4]; + + while (realsize >= 8) { + float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack); + float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack); + float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack); + float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack); + float16x8_t s4_f16 = vld1q_f16(srcPtr + 4 * pack); + float16x8_t s5_f16 = vld1q_f16(srcPtr + 5 * pack); + float16x8_t s6_f16 = vld1q_f16(srcPtr + 6 * pack); + float16x8_t s7_f16 = vld1q_f16(srcPtr + 7 * pack); + + float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16)); + float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16)); + float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16)); + float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16)); + float32x4_t d40_f32 = vcvt_f32_f16(vget_low_f16(s4_f16)); + float32x4_t d50_f32 = vcvt_f32_f16(vget_low_f16(s5_f16)); + float32x4_t d60_f32 = vcvt_f32_f16(vget_low_f16(s6_f16)); + float32x4_t d70_f32 = vcvt_f32_f16(vget_low_f16(s7_f16)); + + vst1q_f32(tmp0, d00_f32); + vst1q_f32(tmp1, d10_f32); + vst1q_f32(tmp2, d20_f32); + vst1q_f32(tmp3, d30_f32); + vst1q_f32(tmp4, d40_f32); + vst1q_f32(tmp5, d50_f32); + vst1q_f32(tmp6, d60_f32); + vst1q_f32(tmp7, d70_f32); + + memcpy(dstPtr + 0 * dstStep, tmp0, sizeof(float) * remainDepth); + memcpy(dstPtr + 1 * dstStep, tmp1, sizeof(float) * remainDepth); + memcpy(dstPtr + 2 * dstStep, tmp2, sizeof(float) * remainDepth); + memcpy(dstPtr + 3 * dstStep, tmp3, sizeof(float) * remainDepth); + memcpy(dstPtr + 4 * dstStep, tmp4, sizeof(float) * remainDepth); + memcpy(dstPtr + 5 * dstStep, tmp5, sizeof(float) * remainDepth); + memcpy(dstPtr + 6 * dstStep, tmp6, sizeof(float) * remainDepth); + memcpy(dstPtr + 7 * dstStep, tmp7, sizeof(float) * remainDepth); + + srcPtr += 8 * pack; + dstPtr += 8 * dstStep; + realsize -= 8; + } + if (realsize >= 4) { + float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack); + float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack); + float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack); + float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack); + + float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16)); + float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16)); + float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16)); + float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16)); + + vst1q_f32(tmp0, d00_f32); + vst1q_f32(tmp1, d10_f32); + vst1q_f32(tmp2, d20_f32); + 
vst1q_f32(tmp3, d30_f32); + + memcpy(dstPtr + 0 * dstStep, tmp0, sizeof(float) * remainDepth); + memcpy(dstPtr + 1 * dstStep, tmp1, sizeof(float) * remainDepth); + memcpy(dstPtr + 2 * dstStep, tmp2, sizeof(float) * remainDepth); + memcpy(dstPtr + 3 * dstStep, tmp3, sizeof(float) * remainDepth); + + srcPtr += 4 * pack; + dstPtr += 4 * dstStep; + realsize -= 4; + } + while (realsize > 0) { + auto s0_f16 = vld1q_f16(srcPtr); + float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16)); + vst1q_f32(tmp0, d00_f32); + memcpy(dstPtr + 0 * dstStep, tmp0, sizeof(float) * remainDepth); + srcPtr += pack; + dstPtr += dstStep; + realsize--; + } + } +} + +static void MNNAttenPackAndConvertFp32LP1(float* dst, const float* src, const int32_t* units, size_t depth, size_t planesize) { + int32_t eP = units[0]; + int32_t lP = units[1]; + + if (lP != 1) { + MNN_ERROR("This function only supports lP=1\n"); + return; + } + + auto dstStride1 = eP; + auto dstStride0 = planesize * dstStride1; + + for (int i = 0; i < depth; ++i) { + size_t realsize = planesize; + const float* srcPtr = src + i * planesize; + FLOAT16* dstPtr = (FLOAT16*)dst + (i % eP) + (i / eP) * dstStride0; + + while (realsize >= 16) { + float32x4_t s0_f32 = vld1q_f32(srcPtr); + float32x4_t s1_f32 = vld1q_f32(srcPtr + 4); + float32x4_t s2_f32 = vld1q_f32(srcPtr + 8); + float32x4_t s3_f32 = vld1q_f32(srcPtr + 12); + + float16x4_t d0_f16 = vcvt_f16_f32(s0_f32); + float16x4_t d1_f16 = vcvt_f16_f32(s1_f32); + float16x4_t d2_f16 = vcvt_f16_f32(s2_f32); + float16x4_t d3_f16 = vcvt_f16_f32(s3_f32); + + vst1_lane_f16(dstPtr, d0_f16, 0); + vst1_lane_f16(dstPtr + dstStride1, d0_f16, 1); + vst1_lane_f16(dstPtr + 2 * dstStride1, d0_f16, 2); + vst1_lane_f16(dstPtr + 3 * dstStride1, d0_f16, 3); + + vst1_lane_f16(dstPtr + 4 * dstStride1, d1_f16, 0); + vst1_lane_f16(dstPtr + 5 * dstStride1, d1_f16, 1); + vst1_lane_f16(dstPtr + 6 * dstStride1, d1_f16, 2); + vst1_lane_f16(dstPtr + 7 * dstStride1, d1_f16, 3); + + vst1_lane_f16(dstPtr + 8 * dstStride1, d2_f16, 0); + vst1_lane_f16(dstPtr + 9 * dstStride1, d2_f16, 1); + vst1_lane_f16(dstPtr + 10 * dstStride1, d2_f16, 2); + vst1_lane_f16(dstPtr + 11 * dstStride1, d2_f16, 3); + + vst1_lane_f16(dstPtr + 12 * dstStride1, d3_f16, 0); + vst1_lane_f16(dstPtr + 13 * dstStride1, d3_f16, 1); + vst1_lane_f16(dstPtr + 14 * dstStride1, d3_f16, 2); + vst1_lane_f16(dstPtr + 15 * dstStride1, d3_f16, 3); + + srcPtr += 16; + dstPtr += 16 * dstStride1; + realsize -= 16; + } + + if (realsize >= 8) { + float32x4_t s0_f32 = vld1q_f32(srcPtr); + float32x4_t s1_f32 = vld1q_f32(srcPtr + 4); + + float16x4_t d0_f16 = vcvt_f16_f32(s0_f32); + float16x4_t d1_f16 = vcvt_f16_f32(s1_f32); + + vst1_lane_f16(dstPtr, d0_f16, 0); + vst1_lane_f16(dstPtr + dstStride1, d0_f16, 1); + vst1_lane_f16(dstPtr + 2 * dstStride1, d0_f16, 2); + vst1_lane_f16(dstPtr + 3 * dstStride1, d0_f16, 3); + + vst1_lane_f16(dstPtr + 4 * dstStride1, d1_f16, 0); + vst1_lane_f16(dstPtr + 5 * dstStride1, d1_f16, 1); + vst1_lane_f16(dstPtr + 6 * dstStride1, d1_f16, 2); + vst1_lane_f16(dstPtr + 7 * dstStride1, d1_f16, 3); + + srcPtr += 8; + dstPtr += 8 * dstStride1; + realsize -= 8; + } + + if (realsize >= 4) { + float32x4_t s0_f32 = vld1q_f32(srcPtr); + float16x4_t d0_f16 = vcvt_f16_f32(s0_f32); + + vst1_lane_f16(dstPtr, d0_f16, 0); + vst1_lane_f16(dstPtr + dstStride1, d0_f16, 1); + vst1_lane_f16(dstPtr + 2 * dstStride1, d0_f16, 2); + vst1_lane_f16(dstPtr + 3 * dstStride1, d0_f16, 3); + + srcPtr += 4; + dstPtr += 4 * dstStride1; + realsize -= 4; + } + + for (; realsize > 0; --realsize) { 
+ *dstPtr = (FLOAT16)(*srcPtr); + srcPtr++; + dstPtr += dstStride1; + } + } +} + +static void MNNAttenPackAndConvertFp32(float* dst, float* src, const int32_t* units, size_t depth, size_t planesize) { + int32_t eP = units[0]; + int32_t lP = units[1]; // Now lP=1 or 2 + + if (lP != 1 && lP != 2) { + MNN_ERROR("This function only supports lP=1 or 2\n"); + return; + } + + // src [depth, planesize] (float32) + // dst [depth/eP, planesize/lP, eP, lP] (float16) + + if (lP == 1) { + MNNAttenPackAndConvertFp32LP1(dst, src, units, depth, planesize); + return; + } + + auto dstStride1 = eP * lP; + auto dstStride0 = UP_DIV(planesize, lP) * dstStride1; + + for (int i = 0; i < depth; ++i) { + size_t realsize = planesize; + const float* srcPtr = src + i * planesize; + FLOAT16* dstPtr = (FLOAT16*)dst + (i % eP) * lP + (i / eP) * dstStride0; + + while (realsize >= 16) { + float32x4_t s0 = vld1q_f32(srcPtr); + float32x4_t s1 = vld1q_f32(srcPtr + 4); + float32x4_t s2 = vld1q_f32(srcPtr + 8); + float32x4_t s3 = vld1q_f32(srcPtr + 12); + + float16x4_t h0 = vcvt_f16_f32(s0); + float16x4_t h1 = vcvt_f16_f32(s1); + float16x4_t h2 = vcvt_f16_f32(s2); + float16x4_t h3 = vcvt_f16_f32(s3); + + vst1_lane_u32((uint32_t*)dstPtr, vreinterpret_u32_f16(h0), 0); + vst1_lane_u32((uint32_t*)(dstPtr + dstStride1), vreinterpret_u32_f16(h0), 1); + + vst1_lane_u32((uint32_t*)(dstPtr + 2 * dstStride1), vreinterpret_u32_f16(h1), 0); + vst1_lane_u32((uint32_t*)(dstPtr + 3 * dstStride1), vreinterpret_u32_f16(h1), 1); + + vst1_lane_u32((uint32_t*)(dstPtr + 4 * dstStride1), vreinterpret_u32_f16(h2), 0); + vst1_lane_u32((uint32_t*)(dstPtr + 5 * dstStride1), vreinterpret_u32_f16(h2), 1); + + vst1_lane_u32((uint32_t*)(dstPtr + 6 * dstStride1), vreinterpret_u32_f16(h3), 0); + vst1_lane_u32((uint32_t*)(dstPtr + 7 * dstStride1), vreinterpret_u32_f16(h3), 1); + + realsize -= 16; + srcPtr += 16; + dstPtr += 8 * dstStride1; + } + + if (realsize >= 8) { + float32x4_t s0 = vld1q_f32(srcPtr); + float32x4_t s1 = vld1q_f32(srcPtr + 4); + + float16x4_t h0 = vcvt_f16_f32(s0); + float16x4_t h1 = vcvt_f16_f32(s1); + + vst1_lane_u32((uint32_t*)dstPtr, vreinterpret_u32_f16(h0), 0); + vst1_lane_u32((uint32_t*)(dstPtr + dstStride1), vreinterpret_u32_f16(h0), 1); + + vst1_lane_u32((uint32_t*)(dstPtr + 2 * dstStride1), vreinterpret_u32_f16(h1), 0); + vst1_lane_u32((uint32_t*)(dstPtr + 3 * dstStride1), vreinterpret_u32_f16(h1), 1); + + realsize -= 8; + srcPtr += 8; + dstPtr += 4 * dstStride1; + } + + if (realsize >= 4) { + float32x4_t s0 = vld1q_f32(srcPtr); + float16x4_t h0 = vcvt_f16_f32(s0); + + vst1_lane_u32((uint32_t*)dstPtr, vreinterpret_u32_f16(h0), 0); + vst1_lane_u32((uint32_t*)(dstPtr + dstStride1), vreinterpret_u32_f16(h0), 1); + + realsize -= 4; + srcPtr += 4; + dstPtr += 2 * dstStride1; + } + + if (realsize >= 2) { + float32x2_t s0 = vld1_f32(srcPtr); + float16x4_t h0 = vcvt_f16_f32(vcombine_f32(s0, s0)); + + vst1_lane_u32((uint32_t*)dstPtr, vreinterpret_u32_f16(h0), 0); + + realsize -= 2; + srcPtr += 2; + dstPtr += dstStride1; + } + + if (realsize > 0) { + dstPtr[0] = (FLOAT16)srcPtr[0]; + dstPtr[1] = (FLOAT16)0.0f; + } + } +} + +#endif // MNN_SUPPORT_TRANSFORMER_FUSE + #ifdef MNN_LOW_MEMORY void MNNAbsMaxFP16(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack) { if (pack == 4) { @@ -1463,7 +2187,7 @@ static void MNNDynamicQuantFP16(const float* src, int8_t* dst, const float* scal } int8_t* dstPtr = dst; auto srcPtr = (FLOAT16*)src; - + for (int i = 0; i < realSize; ++i) { auto scaleVal = static_cast(scale[i]); 
for (int c = 0; c < src_depth_quad; ++c) { @@ -1534,7 +2258,7 @@ static void MNNAsyQuantInfo_FP16(float* scale, float* bias, float* qscale, float auto stride0 = blockNum * blockLU * plane * innerSide; auto stride1 = blockLU * plane * innerSide; auto srcPtr = (FLOAT16*)src; - + // input shape: [kernelsize,blocknum,blocklu,DST_XUNIT,SRC_UNIT] or [ic/core->pack, plane, core->pack] // dequant scale/bias : [EU, blockNum, step] // quant scale/bias: [blockNum, plane] @@ -1563,7 +2287,7 @@ static void MNNAsyQuantInfo_FP16(float* scale, float* bias, float* qscale, float } return; } - + #ifdef __aarch64__ if (DST_XUNIT == 12 || DST_XUNIT == 16) { // Arm82/SME2, fp16: core->pack=8, SRC_UNIT=4 // max,min shape: [blockNum, EP] @@ -1653,12 +2377,13 @@ static void MNNAsyQuantInfo_FP16(float* scale, float* bias, float* qscale, float qbias[qind] = -min_ * 255.f / range - 128.0f; scale[sind] = range / 255.f; bias[sind] = min_ + (128.f / 255.f) * range; - + } } } #endif } + #endif // MNN_LOW_MEMORY static CoreFunctions* gInstance = nullptr; @@ -1726,6 +2451,8 @@ bool Arm82Functions::init() { FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain, MNNPackedMatMulRemainFP16); FUNC_PTR_ASSIGN(gInstance->MNNPackC4ForMatMul_A, Arm82MNNPackForMatMul_A); FUNC_PTR_ASSIGN(gInstance->MNNPackForMatMul_B, Arm82MNNPackForMatMul_B); + + FUNC_PTR_ASSIGN(gInstance->MNNSoftmax, origin->MNNSoftmax); #if defined(__aarch64__) gInstance->supportFp16arith = origin->supportFp16arith; gInstance->supportSDot = origin->supportSDot; @@ -1755,7 +2482,16 @@ bool Arm82Functions::init() { FUNC_PTR_ASSIGN(gInstance->MNNCountMaxMinValue, ARM82CountMinMaxValue); // return one min&max FUNC_PTR_ASSIGN(gInstance->MNNSumByAxisLForMatmul_A, origin->MNNSumByAxisLForMatmul_A); FUNC_PTR_ASSIGN(gInstance->MNNDepthwiseConvFastKernel, MNNDepthwiseConvFastKernelFP16); -#endif +#endif // __aarch64__ + +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE + // Attention + FUNC_PTR_ASSIGN(gInstance->MNNAttenUnpackAndConvertFp16, MNNAttenUnpackAndConvertFp16); + FUNC_PTR_ASSIGN(gInstance->MNNAttenPackAndConvertFp32, MNNAttenPackAndConvertFp32); + FUNC_PTR_ASSIGN(gInstance->MNNAttenPackAndScaleSingleHead, MNNAttenPackAndScaleSingleHead); + FUNC_PTR_ASSIGN(gInstance->MNNFlashAttentionUpdateBlockOutput, MNNFlashAttentionUpdateBlockOutput); +#endif // MNN_SUPPORT_TRANSFORMER_FUSE + gInstance->MNNComputeMatMulForH_1 = _MNNComputeMatMulForH_1_FP16; gInstance->MNNComputeMatMulForE_1 = _MNNComputeMatMulForE_1_FP16; diff --git a/source/backend/cpu/CPUAttention.cpp b/source/backend/cpu/CPUAttention.cpp index 01ff0690..aca373f0 100644 --- a/source/backend/cpu/CPUAttention.cpp +++ b/source/backend/cpu/CPUAttention.cpp @@ -24,10 +24,12 @@ #define FLOAT16_T float #endif +#define MNN_FLASH_ATTENTION_BLOCK_SIZE 64 + namespace MNN { template -void CPUAttention::pack_query(Tensor* query, char* pack_q, char* sum_q, int seq_len, int h, float q_scale) { +void CPUAttention::pack_query(Tensor* query, int8_t* pack_q, int8_t* sum_q, int seq_len, int h, float q_scale) { if (mUseGemmInt8) { // Shape of Query: numhead, [seqlen/eP8, headdim/lP8, eP8, lP8] mMinQ[h] = query->host()[h * mHeadDim]; mMaxQ[h] = query->host()[h * mHeadDim]; @@ -75,21 +77,21 @@ void CPUAttention::pack_query(Tensor* query, char* pack_q, char* sum_q, int seq_ } template -void CPUAttention::unpack_QK(float * unpack_qk_dst, char * pack_qk_src, int seq_len, int kv_seq_len) { +void CPUAttention::unpack_QK(float * unpack_qk_dst, int8_t * pack_qk_src, int seq_len, int kv_seq_len) { float * dst = unpack_qk_dst; T * src = (T *)(pack_qk_src); 
- // [kv_seq_len/unit, seq_len, unit] -> [seq_len, kv_seq_len] + // [kv_seq_len/mPack, seq_len, mPack] -> [seq_len, kv_seq_len] for (int i = 0; i < seq_len; i++) { for (int j = 0; j < kv_seq_len; j++) { - int out_index = j / unit; - int in_index = j % unit; - dst[i * kv_seq_len + j] = src[out_index * seq_len * unit + i * unit + in_index]; + int out_index = j / mPack; + int in_index = j % mPack; + dst[i * kv_seq_len + j] = src[out_index * seq_len * mPack + i * mPack + in_index]; } } } template -static void pack_QK(char * pack_qk_dst, float * qk_src, int seq_len, int kv_seq_len, int eP, int lP, int bytes) { +static void pack_QK(int8_t * pack_qk_dst, float * qk_src, int seq_len, int kv_seq_len, int eP, int lP, int bytes) { T * dst = reinterpret_cast(pack_qk_dst); float * src = reinterpret_cast(qk_src); // [seq_len, kv_seq_len] -> [seq_len/eP, kv_seq_len/lP, eP, lP] @@ -108,38 +110,50 @@ static void pack_QK(char * pack_qk_dst, float * qk_src, int seq_len, int kv_seq_ } template -static void mask_QK(float * unpack_qk, int seq_len, int kv_seq_len, float mScale, float min_val, const Tensor* mask) { - if (mask == nullptr) { - for (int i = 0; i < kv_seq_len; i++) { +static void mask_QK(float * unpack_qk, int seq_len, int kv_seq_len, float mScale, float min_val, const Tensor* maskTensor, int offset, int startIndx, int processedKvLen) { + + int endIndx = startIndx + processedKvLen; + if (maskTensor == nullptr) { + for (int i = 0; i < processedKvLen; i++) { unpack_qk[i] = unpack_qk[i] * mScale; } - } else if (mask->getType() == halide_type_of()) { + return; + } + const int8_t* mask = maskTensor->host(); + halide_type_t htype = maskTensor->getType(); + int maskSize = maskTensor->elementSize(); + + if (htype == halide_type_of()) { // float mask - T* fpmask_ptr = mask->host(); - if (mask->elementSize() == seq_len * kv_seq_len) { - // normal mask for all token - for (int i = 0; i < seq_len * kv_seq_len; i++) { - unpack_qk[i] = unpack_qk[i] * mScale + fpmask_ptr[i]; - } - } else { - // square mask just for new generation token - int offset = kv_seq_len - seq_len; + T* fpmask_ptr = (T*)mask; + if (maskSize == seq_len * kv_seq_len) { // sliding attention, mask shape: [seq_len, kv_seq_len] for (int i = 0; i < seq_len; ++i) { - auto unpack_qki = unpack_qk + i * kv_seq_len; - auto fpmask_ptri = fpmask_ptr + i * seq_len; - for (int j = 0; j < offset; ++j) { - unpack_qki[j] = unpack_qki[j] * mScale; + auto unpack_qki = unpack_qk + i * processedKvLen; + auto fpmask_ptri = fpmask_ptr + i * kv_seq_len; + for (int j = startIndx; j < endIndx; ++j) { + unpack_qki[j - startIndx] = unpack_qki[j - startIndx] * mScale + fpmask_ptri[j]; } - for (int j = 0; j < seq_len; ++j) { - unpack_qki[offset + j] = unpack_qki[offset + j] * mScale + fpmask_ptri[j]; + } + } else { // mask shape: [seq_len, seq_len] + for (int i = 0; i < seq_len; ++i) { + auto unpack_qki = unpack_qk + i * processedKvLen; + auto fpmask_ptri = fpmask_ptr + i * seq_len; + + auto notMaskIndx = ALIMIN(endIndx, offset); + auto stMaskIndx = ALIMAX(startIndx, offset); + for (int j = startIndx; j < notMaskIndx; ++j) { + unpack_qki[j - startIndx] = unpack_qki[j - startIndx] * mScale; + } + for (int j = stMaskIndx; j < endIndx; ++j) { + unpack_qki[j - startIndx] = unpack_qki[j - startIndx] * mScale + fpmask_ptri[j - offset]; } } } } else { // int mask - int* mask_ptr = mask->host(); - for (int i = 0; i < seq_len * kv_seq_len; i++) { - if (mask_ptr[i]) { + int* mask_ptr = (int*)mask; + for (int i = 0; i < seq_len * processedKvLen; i++) { + if (mask_ptr[i / 
processedKvLen * kv_seq_len + i % processedKvLen]) { unpack_qk[i] = unpack_qk[i] * mScale; } else { unpack_qk[i] = min_val; @@ -148,41 +162,47 @@ static void mask_QK(float * unpack_qk, int seq_len, int kv_seq_len, float mScale } } -static void softmax_QK(float* softmax_qk_addr, float* unpack_qk_addr, int seq_len, int kv_seq_len) { - for (int i = 0; i < seq_len; i++) { // softmax each row - MNNSoftmax(softmax_qk_addr + i * kv_seq_len, unpack_qk_addr + i * kv_seq_len, kv_seq_len); +typedef void(softmaxFunc)(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize); +template +static void softmaxQK(float* softmax_qk_addr, float* unpack_qk_addr, float* runningMax, float* runningSum, float* diffScale, const float* sinkPtr, softmaxFunc* sffunc, int seq_len, int kv_seq_len, int headIdx, bool isLastKvBlock) { + + // not sliding attention + if (sinkPtr == nullptr) { + sffunc(softmax_qk_addr, unpack_qk_addr, runningMax, runningSum, diffScale, seq_len, kv_seq_len); + return; + } + + float sink = ((T*)sinkPtr)[headIdx]; + if (!runningMax && !runningSum) { // Do not use flash attention + + for (int i = 0; i < seq_len; ++i) { + float exprOffset[4] = {1, 0, -sink, 1.f}; + MNNExp(softmax_qk_addr + i * kv_seq_len, unpack_qk_addr + i * kv_seq_len, exprOffset, kv_seq_len); + for (int j = 0; j < kv_seq_len; ++j) { + softmax_qk_addr[i * kv_seq_len + j] /= exprOffset[3]; + } + } + return; } -} -static void sink_softmax_QK(float* softmax_qk_addr, float* unpack_qk_addr, int seq_len, int kv_seq_len, float sink) { - // TODO: opt - std::vector buffer(2 * (kv_seq_len + 1)); - float* sinkSrc = buffer.data(); - float* sinkDst = buffer.data() + kv_seq_len + 1; - for (int i = 0; i < seq_len; i++) { // softmax each row - ::memcpy(sinkSrc, unpack_qk_addr + i * kv_seq_len, kv_seq_len * sizeof(float)); - sinkSrc[kv_seq_len] = sink; - float rowMax = sink; - for (int j = 0; j < kv_seq_len; j++) { - rowMax = ALIMAX(rowMax, sinkSrc[j]); + // Use flash attention + if (isLastKvBlock) { + for (int i = 0; i < seq_len; ++i) { + runningSum[i] += expf(sink - runningMax[i]); } - for (int j = 0; j < kv_seq_len + 1; j++) { - sinkSrc[j] = sinkSrc[j] - rowMax; - } - MNNSoftmax(sinkDst, sinkSrc, kv_seq_len + 1); - ::memcpy(softmax_qk_addr + i * kv_seq_len, sinkDst, kv_seq_len * sizeof(float)); } + MNNSoftmax(softmax_qk_addr, unpack_qk_addr, runningMax, runningSum, diffScale, seq_len, kv_seq_len); } template -static void unpack_QKV(char* pack_qkv, char* unpack_qkv, int mNumHead, int mHeadDim, int unit, int seq_len) { +static void unpack_QKV(int8_t* pack_qkv, int8_t* unpack_qkv, int mNumHead, int mHeadDim, int mPack, int seq_len) { auto src_ptr = reinterpret_cast(pack_qkv); auto dst_ptr = reinterpret_cast(unpack_qkv); for (int i = 0; i < seq_len; i++) { for (int j = 0; j < mHeadDim; j++) { - int a = j / unit; - int b = j % unit; - dst_ptr[i * mNumHead * mHeadDim + j] = src_ptr[a * seq_len * unit + i * unit + b]; + int a = j / mPack; + int b = j % mPack; + dst_ptr[i * mNumHead * mHeadDim + j] = src_ptr[a * seq_len * mPack + i * mPack + b]; } } } @@ -191,10 +211,10 @@ ErrorCode CPUAttention::onResize(const std::vector& inputs, const std:: auto core = static_cast(backend())->functions(); core->MNNGetMatMulPackMode(&eP, &lP, &hP); mThreadNum = ((CPUBackend *)backend())->threadNumber(); - unit = core->pack; + mPack = core->pack; bytes = core->bytes; int qkvQuantOptions = static_cast(backend())->getRuntime()->hint().qkvQuantOption; - mUseGemmInt8 = (qkvQuantOptions == 4); + mUseGemmInt8 = 
(qkvQuantOptions % 8 == 4); if (mUseGemmInt8) { static_cast(backend())->int8Functions()->MNNGetGemmUnit(&hP8, &lP8, &eP8); } @@ -208,7 +228,7 @@ ErrorCode CPUAttention::onResize(const std::vector& inputs, const std:: if (mUseGemmInt8) { mPackQ.reset(Tensor::createDevice({mThreadNum, UP_DIV(seq_len, eP8), UP_DIV(mHeadDim, lP8), eP8 * lP8})); mSumQ.reset(Tensor::createDevice({mThreadNum, UP_DIV(seq_len, eP8), eP8})); - mPackQKV.reset(Tensor::createDevice({mThreadNum, UP_DIV(mHeadDim, unit), seq_len, unit})); + mPackQKV.reset(Tensor::createDevice({mThreadNum, UP_DIV(mHeadDim, mPack), seq_len, mPack})); backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC); backend()->onAcquireBuffer(mSumQ.get(), Backend::DYNAMIC); backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC); @@ -220,18 +240,40 @@ ErrorCode CPUAttention::onResize(const std::vector& inputs, const std:: mQueryScale.resize(mNumHead); mQueryZeroPoint.resize(mNumHead); } else { - mPackQ.reset(Tensor::createDevice({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(mHeadDim, lP), eP})); - mPackQKV.reset(Tensor::createDevice({mThreadNum, UP_DIV(mHeadDim, unit), seq_len, unit})); + mPackQ.reset(Tensor::createDevice({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(mHeadDim, lP), eP * bytes})); + mPackQKV.reset(Tensor::createDevice({mThreadNum, UP_DIV(mHeadDim, mPack), seq_len, mPack * bytes})); backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC); backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC); + + // flash attention + if (qkvQuantOptions / 8 == 1) { + mRunningMax.reset(Tensor::createDevice({mThreadNum, seq_len * 4})); + mRunningSum.reset(Tensor::createDevice({mThreadNum, seq_len * 4})); + mExpfDiffMax.reset(Tensor::createDevice({mThreadNum, seq_len * 4})); + mTempOut.reset(Tensor::createDevice({mThreadNum, UP_DIV(mHeadDim, mPack), seq_len, mPack * bytes})); + + backend()->onAcquireBuffer(mRunningMax.get(), Backend::DYNAMIC); + backend()->onAcquireBuffer(mRunningSum.get(), Backend::DYNAMIC); + backend()->onAcquireBuffer(mExpfDiffMax.get(), Backend::DYNAMIC); + backend()->onAcquireBuffer(mTempOut.get(), Backend::DYNAMIC); + } + backend()->onReleaseBuffer(mPackQ.get(), Backend::DYNAMIC); backend()->onReleaseBuffer(mPackQKV.get(), Backend::DYNAMIC); + + if (qkvQuantOptions / 8 == 1) { + backend()->onReleaseBuffer(mRunningMax.get(), Backend::DYNAMIC); + backend()->onReleaseBuffer(mRunningSum.get(), Backend::DYNAMIC); + backend()->onReleaseBuffer(mExpfDiffMax.get(), Backend::DYNAMIC); + backend()->onReleaseBuffer(mTempOut.get(), Backend::DYNAMIC); + } } return NO_ERROR; } ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std::vector& outputs) { auto core = static_cast(backend())->functions(); + auto qkvQuantOptions = static_cast(backend())->getRuntime()->hint().qkvQuantOption; auto query = inputs[0]; auto key = inputs[1]; auto value = inputs[2]; @@ -283,146 +325,146 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: int max_len = mKVCacheManager->maxLength(); bool quant_key = mKVCacheManager->config()->mQuantKey; bool quant_value = mKVCacheManager->config()->mQuantValue; + + mBlockKV = (qkvQuantOptions / 8 == 1) ? 
ALIMIN(MNN_FLASH_ATTENTION_BLOCK_SIZE, kv_seq_len) : kv_seq_len; + int32_t units[2] = {eP, lP}; + // Temporary tensors for intermediate results - std::shared_ptr packQK(Tensor::createDevice({mThreadNum, UP_DIV(kv_seq_len, unit), seq_len, unit})); - std::shared_ptr unpackQK(Tensor::createDevice({mThreadNum, seq_len, kv_seq_len})); - std::shared_ptr softmMaxQ(Tensor::createDevice({mThreadNum, seq_len, kv_seq_len})); - std::shared_ptr newPackQK(Tensor::createDevice({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(kv_seq_len, lP), eP})); - std::shared_ptr dequantV(Tensor::createDevice({mKvNumHead, UP_DIV(mHeadDim, hP), kv_seq_len, hP})); - backend()->onAcquireBuffer(packQK.get(), Backend::STATIC); + std::shared_ptr unpackQK(Tensor::createDevice({mThreadNum, seq_len, mBlockKV})); + std::shared_ptr softmMaxQ(Tensor::createDevice({mThreadNum, seq_len, mBlockKV})); + std::shared_ptr newPackQK(Tensor::createDevice({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(mBlockKV, lP), eP * bytes})); + std::shared_ptr dequantV(Tensor::createDevice({mKvNumHead, UP_DIV(mHeadDim, hP), kv_seq_len, hP * bytes})); + // mTempQKBlock.reset(Tensor::createDevice({mThreadNum, UP_DIV(mBlockKV, mPack), seq_len, mPack * bytes})); + std::shared_ptr tempQKBlock(Tensor::createDevice({mThreadNum, UP_DIV(mBlockKV, mPack), seq_len, mPack * bytes})); backend()->onAcquireBuffer(unpackQK.get(), Backend::STATIC); backend()->onAcquireBuffer(softmMaxQ.get(), Backend::STATIC); backend()->onAcquireBuffer(newPackQK.get(), Backend::STATIC); + backend()->onAcquireBuffer(tempQKBlock.get(), Backend::STATIC); if (quant_value) { backend()->onAcquireBuffer(dequantV.get(), Backend::STATIC); mKVCacheManager->onDequantValue(dequantV.get()); } const float* sinksPtr = sinks ? sinks->host() : nullptr; std::function mCompute = [=](int tId) { - auto pack_q = mPackQ->host() + tId * UP_DIV(seq_len, eP) * ROUND_UP(mHeadDim, lP) * eP * bytes; - auto pack_qk = packQK->host() + tId * UP_DIV(kv_seq_len, unit) * seq_len * unit * bytes; - char * sum_q = nullptr; - auto unpack_qk = unpackQK->host() + tId * seq_len * kv_seq_len; - auto softmax_qk = softmMaxQ->host() + tId * seq_len * kv_seq_len; - auto new_pack_qk = newPackQK->host() + tId * UP_DIV(seq_len, eP) * ROUND_UP(kv_seq_len, lP) * eP * bytes; - auto pack_qkv = mPackQKV->host() + tId * UP_DIV(mHeadDim, unit) * seq_len * unit * bytes; + auto qReordered = mPackQ->host() + tId * mPackQ->stride(0); + auto qkPacked = tempQKBlock->host() + tId * tempQKBlock->stride(0); + int8_t * sum_q = nullptr; + auto qkFlatten = unpackQK->host() + tId * unpackQK->stride(0); + auto qkSoftmax = softmMaxQ->host() + tId * softmMaxQ->stride(0); + auto qkReordered = newPackQK->host() + tId * newPackQK->stride(0); + auto qkvPacked = mPackQKV->host() + tId * mPackQKV->stride(0); auto QxK = quant_key ? core->MNNPackedMatMul_int8 : core->MNNPackedMatMul; auto QxK_remain = quant_key ? core->MNNPackedMatMulRemain_int8 : core->MNNPackedMatMulRemain; + + // Flash Attention + auto runningMax = mRunningMax ? (float*)(mRunningMax->host() + tId * mRunningMax->stride(0)) : nullptr; + auto runningSum = mRunningSum ? (float*)(mRunningSum->host() + tId * mRunningSum->stride(0)) : nullptr; + auto diffScale = mExpfDiffMax ? (float*)(mExpfDiffMax->host() + tId * mExpfDiffMax->stride(0)) : nullptr; + auto outputPacked = mTempOut ? 
mTempOut->host() + tId * mTempOut->stride(0) : qkvPacked; int head_index = tId * tileCount; + int kvBlocks = UP_DIV(kv_seq_len, mBlockKV); + if (mUseGemmInt8) { - pack_q = mPackQ->host() + tId * UP_DIV(seq_len, eP8) * UP_DIV(mHeadDim, lP8) * eP8 * lP8; - sum_q = mSumQ->host() + tId * UP_DIV(seq_len, eP8) * eP8 * 4; + qReordered = mPackQ->host() + tId * UP_DIV(seq_len, eP8) * UP_DIV(mHeadDim, lP8) * eP8 * lP8; + sum_q = mSumQ->host() + tId * UP_DIV(seq_len, eP8) * eP8 * 4; } for (int h = head_index; h < head_index + tileCount && h < mNumHead; h++) { - int kv_h = h / group_size; - char * key_addr = mKVCacheManager->addrOfKey(kv_h); - char * scale_addr = mKVCacheManager->addrOfScale(kv_h); - char * zero_point_addr = mKVCacheManager->addrOfZeroPoint(kv_h); - char * key_sum_addr = mKVCacheManager->addrOfKeySum(kv_h); - char * value_addr = quant_value ? (dequantV->host() + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(kv_seq_len, lP) * hP * bytes) : mKVCacheManager->addrOfValue(kv_h); - if (bytes == 2) { - pack_query(query, pack_q, sum_q, seq_len, h, q_scale); - } else { - pack_query(query, pack_q, sum_q, seq_len, h, q_scale); + if (runningSum && runningMax) { + memset(runningSum, 0, mRunningSum->stride(0)); + if (sinksPtr == nullptr) { + for (int k = 0; k < seq_len; ++k) { + runningMax[k] = -std::numeric_limits::infinity(); + } + } else { + float sinkVal; + if (bytes == 2) { + sinkVal = ((FLOAT16_T*)sinksPtr)[h]; + } else { + sinkVal =sinksPtr[h]; + } + for (int k = 0; k < seq_len; ++k) { + runningMax[k] = sinkVal; + } + } } - // query @ key + int kv_h = h / group_size; + int8_t * key_addr = mKVCacheManager->addrOfKey(kv_h); + int8_t * scale_addr = mKVCacheManager->addrOfScale(kv_h); + int8_t * zero_point_addr = mKVCacheManager->addrOfZeroPoint(kv_h); + int8_t * key_sum_addr = mKVCacheManager->addrOfKeySum(kv_h); + int8_t * value_addr = quant_value ? 
(dequantV->host() + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(kv_seq_len, lP) * hP * bytes) : mKVCacheManager->addrOfValue(kv_h); if (mUseGemmInt8) { - auto GemmInt8Kernel = static_cast(backend())->int8Functions()->Int8GemmKernel; - if (bytes == 2 && unit == 8) { - GemmInt8Kernel = static_cast(backend())->int8Functions()->MNNGemmInt8AddBiasScale_Unit_FP16; - } - std::vector postScale(ROUND_UP(kv_seq_len, hP8), 0.0f); - for (int i = 0; i < kv_seq_len; i++) { - postScale[i] = ((float*)scale_addr)[i] * mQueryScale[h] * q_scale; - } - std::vector weightQuantBias(ROUND_UP(kv_seq_len, hP8), 0.0f); - for (int i = 0; i < kv_seq_len; i++) { - weightQuantBias[i] = -((float*)scale_addr)[i] * ((float*)zero_point_addr)[i] * q_scale; - } - std::vector biasFloat(ROUND_UP(kv_seq_len, hP8), 0.0f); - for (int i = 0; i < kv_seq_len; i++) { - biasFloat[i] = -mQueryScale[h] * mQueryZeroPoint[h] * ((float*)key_sum_addr)[i] * q_scale; - } - QuanPostTreatParameters post; - post.bias = nullptr; - post.biasFloat = biasFloat.data(); - post.blockNum = 1; - post.inputBias = nullptr; - post.inputScale = nullptr; - post.fp32minmax = nullptr; - post.scale = postScale.data(); - post.useInt8 = false; - post.weightKernelSum = weightQuantBias.data(); - int N = UP_DIV(seq_len, eP8); - for (int i = 0; i < N; i++) { - int realcount = ALIMIN(eP8, seq_len - i * eP8); - post.srcKernelSum = (float*)((char*)sum_q + i * eP8 * 4); - GemmInt8Kernel( - (int8_t*)pack_qk + i * eP8 * unit * bytes, - (int8_t*)pack_q + i * ROUND_UP(mHeadDim, lP8) * eP8, - (int8_t*)key_addr, - UP_DIV(mHeadDim, lP8), - seq_len * unit * bytes, - UP_DIV(kv_seq_len, unit), - &post, - realcount - ); + if (bytes == 2) { + pack_query(query, qReordered, sum_q, seq_len, h, q_scale); + } else { + pack_query(query, qReordered, sum_q, seq_len, h, q_scale); } + } else { + core->MNNAttenPackAndScaleSingleHead((float*)qReordered, (float*)(query->host() + h * mHeadDim * bytes), mHeadDim * mNumHead, &q_scale, units, seq_len, mHeadDim); } - else { + for (int i = 0; i < kvBlocks; ++i) { + int subKvSeqLen = ALIMIN(mBlockKV, kv_seq_len - i * mBlockKV); + auto keyPtr = key_addr + i * UP_DIV(mBlockKV, hP) * ROUND_UP(mHeadDim, lP) * hP * bytes; + auto valuePtr = value_addr + i * UP_DIV(mBlockKV, lP) * hP * lP * bytes; + // query @ key + { + int loop_e = seq_len / eP; + int remain = seq_len % eP; + auto qStride0 = ROUND_UP(mHeadDim, lP) * eP * bytes; + size_t shapeParameters[7] = {(size_t)eP * bytes, ROUND_UP((size_t)mHeadDim, lP), (size_t)subKvSeqLen, (size_t)seq_len * mPack * bytes, 0, 0, 0}; + for (int ei = 0 ; ei < loop_e; ei++) { + QxK((float*)(qkPacked + (ei * eP * mPack) * bytes), (float*)(qReordered + ei * qStride0), (float*)keyPtr, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr); + } + QxK_remain((float*)(qkPacked + (loop_e * eP * mPack) * bytes), (float*)(qReordered + loop_e * qStride0), (float*)keyPtr, remain, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr); + } + // qk: [kv_seq_len/mPack, seq_len, mPack] -> [seq_len/eP, kv_seq_len, eP] + { + if(bytes == 2) { + if (seq_len == 1) { + core->MNNLowpToFp32((int16_t*)qkPacked, qkFlatten, seq_len * subKvSeqLen); + } else { + core->MNNAttenUnpackAndConvertFp16(qkFlatten, (float*)qkPacked, subKvSeqLen, seq_len, mPack); + } + mask_QK(qkFlatten, seq_len, kv_seq_len, mScale, std::numeric_limits::lowest(), mask, kv_seq_len - seq_len, i * mBlockKV, subKvSeqLen); + softmaxQK(qkSoftmax, qkFlatten, runningMax, runningSum, diffScale, sinksPtr, core->MNNSoftmax, seq_len, 
subKvSeqLen, h, i == kvBlocks - 1); + core->MNNAttenPackAndConvertFp32((float*)qkReordered, qkSoftmax, units, seq_len, subKvSeqLen); + } else { + if (seq_len > 1) { + int32_t areaOffset[2] = {seq_len, seq_len}; + core->MNNUnpackCUnitTranspose(qkFlatten, (float*)qkPacked, seq_len, subKvSeqLen, areaOffset); + } else { + memcpy(qkFlatten, qkPacked, subKvSeqLen * sizeof(float)); + } + mask_QK(qkFlatten, seq_len, kv_seq_len, mScale, std::numeric_limits::lowest(), mask, kv_seq_len - seq_len, i * mBlockKV, subKvSeqLen); + softmaxQK(qkSoftmax, qkFlatten, runningMax, runningSum, diffScale, sinksPtr, core->MNNSoftmax, seq_len, subKvSeqLen, h, i == kvBlocks - 1); + packKvCache((float*)qkReordered, qkSoftmax, seq_len, subKvSeqLen, eP); + } + } + // qk @ v + // TODO: update qkvPacked using diffScale + size_t shapeParameters[7] = {(size_t)eP * bytes, ROUND_UP((size_t)subKvSeqLen, lP), (size_t)mHeadDim, (size_t)seq_len * mPack * bytes, 0, 0, 0}; + size_t bExtraStride = (UP_DIV(max_len, lP) - UP_DIV(subKvSeqLen + i * mBlockKV, lP) + UP_DIV(i * mBlockKV, lP)) * hP * lP * bytes; + shapeParameters[5] = quant_value ? 0 : bExtraStride; int loop_e = seq_len / eP; int remain = seq_len % eP; - auto qStride0 = ROUND_UP(mHeadDim, lP) * eP * bytes; - size_t shapeParameters[7] = {(size_t)eP * bytes, ROUND_UP((size_t)mHeadDim, lP), (size_t)kv_seq_len, (size_t)seq_len * unit * bytes, 0, 0, 0}; - for (int i = 0 ; i < loop_e; i++) { - QxK((float*)(pack_qk + (i * eP * unit) * bytes), (float*)(pack_q + i * qStride0), (float*)key_addr, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr); + auto qkStride0 = ROUND_UP(subKvSeqLen, lP) * eP * bytes; + for (int ei = 0 ; ei < loop_e; ei++) { + core->MNNPackedMatMul((float*)(qkvPacked + (ei * eP * mPack) * bytes), (float*)(qkReordered + ei * qkStride0), (float*)valuePtr, shapeParameters, nullptr, nullptr, nullptr, nullptr); } - QxK_remain((float*)(pack_qk + (loop_e * eP * unit) * bytes), (float*)(pack_q + loop_e * qStride0), (float*)key_addr, remain, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr); - } - // qk: [kv_seq_len/unit, seq_len, unit] -> [seq_len, kv_seq_len] -> [seq_len/eP, kv_seq_len, eP] - if (sinksPtr != nullptr) { - if(bytes == 2) { - unpack_QK(unpack_qk, pack_qk, seq_len, kv_seq_len); - mask_QK(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits::lowest(), mask); - sink_softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len, sinksPtr[h]); - pack_QK(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP, lP, bytes); - } else { - unpack_QK(unpack_qk, pack_qk, seq_len, kv_seq_len); - mask_QK(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits::lowest(), mask); - sink_softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len, sinksPtr[h]); - pack_QK(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP, lP, bytes); - } - } else { - if(bytes == 2) { - unpack_QK(unpack_qk, pack_qk, seq_len, kv_seq_len); - mask_QK(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits::lowest(), mask); - softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len); - pack_QK(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP, lP, bytes); - } else { - unpack_QK(unpack_qk, pack_qk, seq_len, kv_seq_len); - mask_QK(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits::lowest(), mask); - softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len); - pack_QK(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP, lP, bytes); + core->MNNPackedMatMulRemain((float*)(qkvPacked + (loop_e * eP * mPack) * bytes), (float*)(qkReordered + loop_e * 
qkStride0), (float*)valuePtr, remain, shapeParameters, nullptr, nullptr, nullptr, nullptr); + + if (runningMax != nullptr && runningSum != nullptr && diffScale != nullptr) { + core->MNNFlashAttentionUpdateBlockOutput((float*)outputPacked, (float*)qkvPacked, diffScale, runningSum, UP_DIV(mHeadDim, mPack), seq_len, mPack, i, kvBlocks, mPackQKV->stride(0) / bytes, bytes); } } - // qk @ v - size_t shapeParameters[7] = {(size_t)eP * bytes, ROUND_UP((size_t)kv_seq_len, lP), (size_t)mHeadDim, (size_t)seq_len * unit * bytes, 0, 0, 0}; - size_t bExtraStride = (UP_DIV(max_len, lP) - UP_DIV(kv_seq_len, lP)) * hP * lP * bytes; - shapeParameters[5] = quant_value ? 0 : bExtraStride; - int loop_e = seq_len / eP; - int remain = seq_len % eP; - auto qkStride0 = ROUND_UP(kv_seq_len, lP) * eP * bytes; - for (int i = 0 ; i < loop_e; i++) { - core->MNNPackedMatMul((float*)(pack_qkv + (i * eP * unit) * bytes), (float*)(new_pack_qk + i * qkStride0), (float*)value_addr, shapeParameters, nullptr, nullptr, nullptr, nullptr); - } - core->MNNPackedMatMulRemain((float*)(pack_qkv + (loop_e * eP * unit) * bytes), (float*)(new_pack_qk + loop_e * qkStride0), (float*)value_addr, remain, shapeParameters, nullptr, nullptr, nullptr, nullptr); - // unpack: [head_dim/unit, seq_len, unit] -> [seq_len, num_head, head_dim] - auto dst_ptr = outputs[0]->host() + h * mHeadDim * bytes; + // unpack: [head_dim/mPack, seq_len, mPack] -> [seq_len, num_head, head_dim] + auto dst_ptr = outputs[0]->host() + h * mHeadDim * bytes; if (bytes == 2) { - unpack_QKV(pack_qkv, dst_ptr, mNumHead, mHeadDim, unit, seq_len); + unpack_QKV((int8_t*)outputPacked, dst_ptr, mNumHead, mHeadDim, mPack, seq_len); } else { - unpack_QKV(pack_qkv, dst_ptr, mNumHead, mHeadDim, unit, seq_len); + unpack_QKV((int8_t*)outputPacked, dst_ptr, mNumHead, mHeadDim, mPack, seq_len); } + } }; @@ -431,10 +473,10 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: } MNN_CONCURRENCY_END(); - backend()->onReleaseBuffer(packQK.get(), Backend::STATIC); backend()->onReleaseBuffer(unpackQK.get(), Backend::STATIC); backend()->onReleaseBuffer(softmMaxQ.get(), Backend::STATIC); backend()->onReleaseBuffer(newPackQK.get(), Backend::STATIC); + backend()->onReleaseBuffer(tempQKBlock.get(), Backend::STATIC); if (quant_value){ backend()->onReleaseBuffer(dequantV.get(), Backend::STATIC); } @@ -460,9 +502,19 @@ CPUAttention::CPUAttention(Backend *backend, bool kv_cache) : Execution(backend) mPackQKV.reset(Tensor::createDevice({1, 1, 1, 1})); MNN::KVCacheManager::KVCacheConfig kvconfig; int qkvQuantOptions = static_cast(backend)->getRuntime()->hint().qkvQuantOption; - kvconfig.mUseInt8Kernel = (qkvQuantOptions == 4); - kvconfig.mQuantKey = (qkvQuantOptions == 4) || (qkvQuantOptions & 1); - kvconfig.mQuantValue = (qkvQuantOptions == 4) || ((qkvQuantOptions >> 1) & 1); + kvconfig.mUseInt8Kernel = (qkvQuantOptions % 8 == 4); + + // qkvQuantOption % 8: + // 0: Do not quantize + // 1: Only quantize key, use int8 asymmetric quantization + // 2: Only quantize value, use fp8 quantization + // 3: quantize both key and value + // 4: quantize query, key and value, and use gemm int8 kernel to compute K*V + + // qkvQuantOption / 8: + // 1: use flash attention + kvconfig.mQuantKey = (qkvQuantOptions % 8 == 4) || (qkvQuantOptions % 8 == 1) || (qkvQuantOptions % 8 == 3); + kvconfig.mQuantValue = (qkvQuantOptions % 8 == 4) || (qkvQuantOptions % 8 == 2); kvconfig.mKVCacheDir = static_cast(backend)->getRuntime()->hint().kvcacheDirPath; kvconfig.mKVCacheSizeLimit = 
static_cast(backend)->getRuntime()->hint().kvcacheSizeLimit; kvconfig.mExpandChunk = 64; diff --git a/source/backend/cpu/CPUAttention.hpp b/source/backend/cpu/CPUAttention.hpp index 066d925c..8739fb71 100644 --- a/source/backend/cpu/CPUAttention.hpp +++ b/source/backend/cpu/CPUAttention.hpp @@ -30,15 +30,16 @@ private: bool mKVCache = true; bool mUseGemmInt8 = false; int bytes = 4; - int mThreadNum = 1;; - int eP, lP, hP, unit; // float matmul packing + int mThreadNum = 1; + int mBlockKV = 512; + int eP, lP, hP, mPack; // float matmul packing int eP8, lP8, hP8; // GemmInt8 packing int mNumHead, mKvNumHead, mHeadDim; - std::shared_ptr mPackQ, mPackQKV, mSumQ; + std::shared_ptr mPackQ, mPackQKV, mSumQ, mRunningMax, mRunningSum, mTempQKBlock, mTempOut, mExpfDiffMax; std::shared_ptr mKVCacheManager = nullptr; std::vector mMinQ, mMaxQ, mQueryScale, mQueryZeroPoint; - template void pack_query(Tensor* query, char* pack_q, char* sum_q, int seq_len, int h, float q_scale); - template void unpack_QK(float * unpack_qk_dst, char * pack_qk_src, int seq_len, int kv_seq_len); + template void pack_query(Tensor* query, int8_t* pack_q, int8_t* sum_q, int seq_len, int h, float q_scale); + template void unpack_QK(float * unpack_qk_dst, int8_t * pack_qk_src, int seq_len, int kv_seq_len); KVMeta* mMeta; }; diff --git a/source/backend/cpu/CPUBackend.cpp b/source/backend/cpu/CPUBackend.cpp index c31d0aa1..b0142276 100644 --- a/source/backend/cpu/CPUBackend.cpp +++ b/source/backend/cpu/CPUBackend.cpp @@ -509,21 +509,21 @@ CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode p mDmaInfo.reset(new CPURuntime::DynamicAllocator); mDmaInfo->mDynamicAllocator.reset(mRuntime->createDynamicBufferAlloctor(0)); mDmaInfo->mCurrentDynamicAllocator = mDmaInfo->mDynamicAllocator.get(); + mDmaInfo->mCacheGroup.resize(MNN_CPU_MAX_BUFFER_INDEX); + for (int i=0; imCacheGroup.size(); ++i) { + mDmaInfo->mCacheGroup[i].reset(new CPUResizeCache); + } } else { mDmaInfo = dynamicAlloc; } mPrecisionMode = precision; mCoreFunctions = MNNGetCoreFunctions(); mInt8CoreFunctions = MNNGetInt8CoreFunctions(); - mCacheGroup.resize(MNN_CPU_MAX_BUFFER_INDEX); - for (int i=0; imCacheGroup[0].get(); } CPUBackend::~CPUBackend() { - mCacheGroup.clear(); + // Do nothing } void CPUBackend::_resetDynamicMemory() const { mRuntime->pCurrentStatus = mDmaInfo->mDynamicAllocator->apply(); @@ -560,7 +560,7 @@ bool CPUBackend::onSelectDynamicAllocator(int index, int maxIndex) { mRuntime->buffer(0)->release(); mDmaInfo->mCurrentDynamicAllocator = mDmaInfo->mDynamicAllocator.get(); } - mCache = mCacheGroup[index].get(); + mCache = mDmaInfo->mCacheGroup[index].get(); return true; } diff --git a/source/backend/cpu/CPUBackend.hpp b/source/backend/cpu/CPUBackend.hpp index 587a48bd..7c699408 100644 --- a/source/backend/cpu/CPUBackend.hpp +++ b/source/backend/cpu/CPUBackend.hpp @@ -23,12 +23,14 @@ namespace MNN { class WorkerThread; +class CPUResizeCache; class CPURuntime : public Runtime { public: struct DynamicAllocator { std::shared_ptr mDynamicAllocator; std::shared_ptr mDynamicAllocatorBackup; BufferAllocator* mCurrentDynamicAllocator = nullptr; + std::vector> mCacheGroup; }; friend class CPUBackend; CPURuntime(const Backend::Info& info); @@ -82,7 +84,6 @@ struct CoreFunctions; struct CoreInt8Functions; struct MatmulRelatedFunctions; -class CPUResizeCache; class CPUMemObj : public Backend::MemObj { public: CPUMemObj(BufferAllocator* allocator, MemChunk chunk, int size) : mAllocator(allocator), mChunk(chunk), mSize(size) {} @@ -199,7 +200,6 
@@ private: BackendConfig::MemoryMode mMemory; static std::map* gCreator; CPUResizeCache* mCache; - std::vector> mCacheGroup; }; /** execution cast wrapper. insert tensor cast dynamic. */ class CastWrapExecution : public Execution { diff --git a/source/backend/cpu/CPURuntime.cpp b/source/backend/cpu/CPURuntime.cpp index 423888e8..86f56194 100644 --- a/source/backend/cpu/CPURuntime.cpp +++ b/source/backend/cpu/CPURuntime.cpp @@ -170,7 +170,7 @@ cpu_mask_t MNNGetCPUMask(const std::vector& cpuIds) { // cpuinfo // Reference from: https://github.com/pytorch/cpuinfo -#if defined(ENABLE_ARMV82) && defined(__arm__) +#if (defined(ENABLE_ARMV82) && defined(__arm__)) || (defined(__ANDROID__) && defined(__aarch64__)) /* As per include/sys/system_properties.h in Android NDK */ #define CPUINFO_HARDWARE_VALUE_MAX 64 @@ -1180,7 +1180,7 @@ struct cpuinfo_arm_chipset cpuinfo_arm_android_decode_chipset(const struct cpuin // MNN_PRINT("chipset vendor, series, model is: %d, %d, %d\n", chipset.vendor, chipset.series, chipset.model); return chipset; } -static void _getInfoARMv7(MNNCPUInfo* cpuinfo_isa) { +static void _getInfoArm(MNNCPUInfo* cpuinfo_isa) { // Get White List And Black List struct cpuinfo_arm_linux_processor* arm_linux_processors = NULL; if (0 == cpuinfo_isa->groups.size()) { @@ -1500,8 +1500,8 @@ static void _fillInfo(MNNCPUInfo* cpuinfo_isa) { #if defined(__aarch64__) _getInfoAux(cpuinfo_isa); #endif -#if defined(ENABLE_ARMV82) && defined(__arm__) - _getInfoARMv7(cpuinfo_isa); +#if (defined(ENABLE_ARMV82) && defined(__arm__)) || (defined(__ANDROID__) && defined(__aarch64__)) + _getInfoArm(cpuinfo_isa); #endif // #ifdef arm / arm64 #endif // #ifdef __linux__ diff --git a/source/backend/cpu/CPUSoftmax.cpp b/source/backend/cpu/CPUSoftmax.cpp index d4811899..98c23444 100644 --- a/source/backend/cpu/CPUSoftmax.cpp +++ b/source/backend/cpu/CPUSoftmax.cpp @@ -200,7 +200,7 @@ int CPUSoftmax::_softmaxCommon(const uint8_t *srcData, uint8_t *dstData) { for (int v=0; vonAcquireBuffer(new_key, Backend::STATIC); for (int h = 0; h < mKvNumHead; h++) { memcpy( - new_key->host() + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, - mPastKey->host() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, + new_key->host() + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, + mPastKey->host() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8 ); } @@ -123,8 +123,8 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) { mBackend->onAcquireBuffer(new_key, Backend::STATIC); for (int h = 0; h < mKvNumHead; h++) { memcpy( - new_key->host() + h * new_key->stride(0), - mPastKey->host() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP), + new_key->host() + h * new_key->stride(0), + mPastKey->host() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP), ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) ); } @@ -135,12 +135,12 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) { mBackend->onAcquireBuffer(new_key, Backend::STATIC); for (int h = 0; h < mKvNumHead; h++) { memcpy( - new_key->host() + h * new_key->stride(0) * mBytes, - mPastKey->host() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes, + new_key->host() + h * new_key->stride(0) * mBytes, + mPastKey->host() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes, ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes ); if ((new_key->stride(0) - mPastKey->stride(0)) > 0) { - 
memset(new_key->host() + h * new_key->stride(0) * mBytes + mPastKey->stride(0) * mBytes, 0, (new_key->stride(0) - mPastKey->stride(0)) * mBytes); + memset(new_key->host() + h * new_key->stride(0) * mBytes + mPastKey->stride(0) * mBytes, 0, (new_key->stride(0) - mPastKey->stride(0)) * mBytes); } } mPastKey.reset(new_key); @@ -152,8 +152,8 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) { for (int h = 0; h < mKvNumHead; h++) { for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { memcpy( - new_value->host() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP, - mPastValue->host() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP, + new_value->host() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP, + mPastValue->host() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP, ROUND_UP(oldMaxLength, lP) * hP ); } @@ -166,12 +166,12 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) { for (int h = 0; h < mKvNumHead; h++) { for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { memcpy( - new_value->host() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes, - mPastValue->host() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes, + new_value->host() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes, + mPastValue->host() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes, ROUND_UP(oldMaxLength, lP) * hP * mBytes ); if ((new_value->stride(1) - mPastValue->stride(1)) > 0) { - memset(new_value->host() + (h * new_value->stride(0) + i * new_value->stride(1)) * mBytes + mPastValue->stride(1) * mBytes, 0, (new_value->stride(1) - mPastValue->stride(1)) * mBytes); + memset(new_value->host() + (h * new_value->stride(0) + i * new_value->stride(1)) * mBytes + mPastValue->stride(1) * mBytes, 0, (new_value->stride(1) - mPastValue->stride(1)) * mBytes); } } } @@ -189,7 +189,7 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) { for (int h = 0; h < mKvNumHead; h++) { memcpy( mMapKeyAddr + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, - mPastKey->host() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, + mPastKey->host() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8 ); } @@ -200,7 +200,7 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) { for (int h = 0; h < mKvNumHead; h++) { memcpy( mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP, - mPastKey->host() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP, + mPastKey->host() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP, UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP ); } @@ -214,7 +214,7 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) { for (int h = 0; h < mKvNumHead; h++) { memcpy( mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes, - mPastKey->host() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes, + mPastKey->host() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes ); } @@ -227,7 +227,7 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) { for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { memcpy( mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP, - mPastValue->host() + (h * UP_DIV(mHeadDim, hP) + i) 
* ROUND_UP(oldMaxLength, lP) * hP, + mPastValue->host() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP, ROUND_UP(oldMaxLength, lP) * hP ); } @@ -243,7 +243,7 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) { for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { memcpy( mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes, - mPastValue->host() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes, + mPastValue->host() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes, ROUND_UP(oldMaxLength, lP) * hP * mBytes ); } @@ -282,8 +282,8 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o memset(old_value->host(), 0, old_value->length(0) * old_value->stride(0) * mBytes); } mmapKVCache(oldKeySize, oldValueSize); - memcpy(old_key->host(), mMapKeyAddr, oldKeySize); - memcpy(old_value->host(), mMapValueAddr, oldValueSize); + memcpy(old_key->host(), mMapKeyAddr, oldKeySize); + memcpy(old_value->host(), mMapValueAddr, oldValueSize); // Step 2: Resize the kvcache files and remap them unmapKVCache(oldKeySize, oldValueSize); resetKVCacheFileSize(keySize, valueSize); @@ -293,7 +293,7 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o for (int h = 0; h < mKvNumHead; h++) { memcpy( mMapKeyAddr + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, - old_key->host() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, + old_key->host() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8 ); } @@ -301,7 +301,7 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o for (int h = 0; h < mKvNumHead; h++) { memcpy( mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP, - old_key->host() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP, + old_key->host() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP, UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP ); } @@ -309,7 +309,7 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o for (int h = 0; h < mKvNumHead; h++) { memcpy( mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes, - old_key->host() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes, + old_key->host() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes ); } @@ -319,7 +319,7 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { memcpy( mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP, - old_value->host() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP, + old_value->host() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP, ROUND_UP(oldMaxLength, lP) * hP ); } @@ -329,7 +329,7 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { memcpy( mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes, - old_value->host() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes, + old_value->host() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes, ROUND_UP(oldMaxLength, lP) * hP * mBytes ); } @@ -460,9 +460,9 
@@ void KVCacheManager::onRealloc(const KVMeta* meta) { mBackend->onAcquireBuffer(new_zeroPoint, Backend::STATIC); mBackend->onAcquireBuffer(new_sum, Backend::STATIC); for (int h = 0; h < mKvNumHead; h++) { - memcpy(new_scale->host() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyScale->host() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4); - memcpy(new_zeroPoint->host() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyZeroPoint->host() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4); - memcpy(new_sum->host() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeySum->host() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4); + memcpy(new_scale->host() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyScale->host() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4); + memcpy(new_zeroPoint->host() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyZeroPoint->host() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4); + memcpy(new_sum->host() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeySum->host() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4); } mKeyScale.reset(new_scale); mKeyZeroPoint.reset(new_zeroPoint); @@ -473,8 +473,8 @@ void KVCacheManager::onRealloc(const KVMeta* meta) { mBackend->onAcquireBuffer(new_scale, Backend::STATIC); mBackend->onAcquireBuffer(new_zeroPoint, Backend::STATIC); for (int h = 0; h < mKvNumHead; h++) { - memcpy(new_scale->host() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyScale->host() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes); - memcpy(new_zeroPoint->host() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyZeroPoint->host() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes); + memcpy(new_scale->host() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyScale->host() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes); + memcpy(new_zeroPoint->host() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyZeroPoint->host() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes); } mKeyScale.reset(new_scale); mKeyZeroPoint.reset(new_zeroPoint); @@ -690,8 +690,8 @@ void KVCacheManager::onDequantValue(Tensor * dequantedValues) { int tileCount = UP_DIV(mKvNumHead, mThreadNum); std::function dequant = [=](int tid) { for (int kv_h = tid * tileCount; kv_h < (tid+1) * tileCount && kv_h < mKvNumHead; kv_h++) { - char * dst = dequantedValues->host() + kv_h * UP_DIV(mHeadDim, hP) * mPastLength * hP * mBytes; - char * src = addrOfValue(kv_h); + int8_t * dst = dequantedValues->host() + kv_h * UP_DIV(mHeadDim, hP) * mPastLength * hP * mBytes; + int8_t * src = addrOfValue(kv_h); for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { if (mBytes == 2) { core->MNNFp8ToFp16((uint16_t*)dst, (uint8_t*)src, mPastLength * hP); diff --git a/source/backend/cpu/KVCacheManager.hpp b/source/backend/cpu/KVCacheManager.hpp index e3fb0e12..7083c346 100644 --- a/source/backend/cpu/KVCacheManager.hpp +++ b/source/backend/cpu/KVCacheManager.hpp @@ -46,8 +46,8 @@ private: std::shared_ptr mKeySum; // numhead, [maxlen/hP8, hP8] file_t mKeyCacheFD = INVALID_FILE; // The file descriptor of keys file_t mValueCacheFD = INVALID_FILE; // The file descriptor of values - char * mMapKeyAddr = nullptr; // Memory-mapped address of keys - char * mMapValueAddr = nullptr; // Memory-mapped address of values + int8_t * mMapKeyAddr = 
nullptr; // Memory-mapped address of keys + int8_t * mMapValueAddr = nullptr; // Memory-mapped address of values bool mKVCacheInDisk = false; // Whether the kvcache is in disk or in memory now int mPastLength = 0; // Length of past kvcache int mMaxLength = 0; // Capacity of current kvcache buffer (how many kv items can be stored at most) @@ -104,15 +104,15 @@ public: return mMaxLength; } uint8_t* keyAddr() { - char * baseAddr = mKVCacheInDisk ? mMapKeyAddr : mPastKey->host(); + int8_t * baseAddr = mKVCacheInDisk ? mMapKeyAddr : mPastKey->host(); return (uint8_t*)baseAddr; } uint8_t* valudAddr() { - char * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host(); + int8_t * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host(); return (uint8_t*)baseAddr; } - char * addrOfKey(int kv_h) { - char * baseAddr = mKVCacheInDisk ? mMapKeyAddr : mPastKey->host(); + int8_t * addrOfKey(int kv_h) { + int8_t * baseAddr = mKVCacheInDisk ? mMapKeyAddr : mPastKey->host(); if (mConfig.mUseInt8Kernel) { return baseAddr + kv_h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8; } else if (mConfig.mQuantKey) { @@ -121,35 +121,35 @@ public: return baseAddr + kv_h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes; } } - char * addrOfValue(int kv_h) { - char * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host(); + int8_t * addrOfValue(int kv_h) { + int8_t * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host(); if (mConfig.mQuantValue) { return baseAddr + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * hP; } else { return baseAddr + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * hP * mBytes; } } - char * addrOfScale(int kv_h) { + int8_t * addrOfScale(int kv_h) { if (mConfig.mUseInt8Kernel) { - return mKeyScale->host() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4; + return mKeyScale->host() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4; } else if (mConfig.mQuantKey) { - return mKeyScale->host() + kv_h * UP_DIV(mMaxLength, hP) * hP * mBytes; + return mKeyScale->host() + kv_h * UP_DIV(mMaxLength, hP) * hP * mBytes; } else { return nullptr; } } - char * addrOfZeroPoint(int kv_h) { + int8_t * addrOfZeroPoint(int kv_h) { if (mConfig.mUseInt8Kernel) { - return mKeyZeroPoint->host() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4; + return mKeyZeroPoint->host() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4; } else if (mConfig.mQuantKey) { - return mKeyZeroPoint->host() + kv_h * UP_DIV(mMaxLength, hP) * hP * mBytes; + return mKeyZeroPoint->host() + kv_h * UP_DIV(mMaxLength, hP) * hP * mBytes; } else { return nullptr; } } - char * addrOfKeySum(int kv_h) { + int8_t * addrOfKeySum(int kv_h) { if (mConfig.mUseInt8Kernel) { - return mKeySum->host() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4; + return mKeySum->host() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4; }else { return nullptr; } diff --git a/source/backend/cpu/arm/CommonOptFunctionNeon.cpp b/source/backend/cpu/arm/CommonOptFunctionNeon.cpp index e876b1b9..a5c14fd7 100644 --- a/source/backend/cpu/arm/CommonOptFunctionNeon.cpp +++ b/source/backend/cpu/arm/CommonOptFunctionNeon.cpp @@ -73,6 +73,386 @@ void MNNTranspose16Bit(int16_t* dstO, const int16_t* srcO, int32_t* dim) { } } + +#define EXP_APPROX_MIN_INPUT vdupq_n_f32(-88.0f) +#define EXP_APPROX_MAX_INPUT vdupq_n_f32(88.0f) +#define EXP_APPROX_LN2 vdupq_n_f32(0.69314718056f) // ln(2) +#define EXP_APPROX_LN2_INV vdupq_n_f32(1.44269504089f) // 1/ln(2) +// Fourth-order polynomial approximation coefficients of exp(r): +// P(x) = c4*x^4 + c3*x^3 + c2*x^2 + 
c1*x + c0 +#define EXP_APPROX_C4 vdupq_n_f32(0.0416624f) +#define EXP_APPROX_C3 vdupq_n_f32(0.166665f) +#define EXP_APPROX_C2 vdupq_n_f32(0.500000f) +#define EXP_APPROX_C1 vdupq_n_f32(1.0f) +#define EXP_APPROX_C0 vdupq_n_f32(1.0f) + +#ifndef __aarch64__ +static inline float32x4_t vrndaq_f32_compat(float32x4_t val) { + const float32x4_t v_zero = vdupq_n_f32(0.0f); + + float32x4_t v_truncated = vcvtq_f32_s32(vcvtq_s32_f32(val)); + + uint32x4_t v_is_positive_frac = vcgtq_f32(val, v_truncated); + uint32x4_t v_is_negative_frac = vcltq_f32(val, v_truncated); + + float32x4_t v_offset = vbslq_f32(v_is_positive_frac, vdupq_n_f32(1.0f), v_zero); + v_offset = vbslq_f32(v_is_negative_frac, vdupq_n_f32(-1.0f), v_offset); + + return vaddq_f32(v_truncated, v_offset); +} + +#endif + +static inline float32x4_t expApprox(float32x4_t x) { + x = vminq_f32(vmaxq_f32(x, EXP_APPROX_MIN_INPUT), EXP_APPROX_MAX_INPUT); + + float32x4_t k_float; + float32x4_t r; + float32x4_t exp_r; + +#if defined(__aarch64__) + // 1. x = k * ln(2) + r + k_float = vrndaq_f32(vmulq_f32(x, EXP_APPROX_LN2_INV)); + + // r = x - k * ln(2) + r = vfmsq_f32(x, k_float, EXP_APPROX_LN2); + + // 2. c0 + r*(c1 + r*(c2 + r*(c3 + r*c4))) (Horner's method) + exp_r = vfmaq_f32(EXP_APPROX_C3, EXP_APPROX_C4, r); // c3 + c4*r + exp_r = vfmaq_f32(EXP_APPROX_C2, exp_r, r); // c2 + r*(...) + exp_r = vfmaq_f32(EXP_APPROX_C1, exp_r, r); // c1 + r*(...) + exp_r = vfmaq_f32(EXP_APPROX_C0, exp_r, r); // c0 + r*(...) + +#else + + k_float = vrndaq_f32_compat(vmulq_f32(x, EXP_APPROX_LN2_INV)); + + + r = vsubq_f32(x, vmulq_f32(k_float, EXP_APPROX_LN2)); + + // 2. c0 + r*(c1 + r*(c2 + r*(c3 + r*c4))) + exp_r = vmlaq_f32(EXP_APPROX_C3, EXP_APPROX_C4, r); // c3 + c4*r + exp_r = vmlaq_f32(EXP_APPROX_C2, exp_r, r); // c2 + r*(...) + exp_r = vmlaq_f32(EXP_APPROX_C1, exp_r, r); // c1 + r*(...) + exp_r = vmlaq_f32(EXP_APPROX_C0, exp_r, r); // c0 + r*(...) 
+ +#endif + + int32x4_t k_int = vcvtq_s32_f32(k_float); + int32x4_t k_shifted = vshlq_n_s32(k_int, 23); + float32x4_t result = vreinterpretq_f32_s32(vaddq_s32(vreinterpretq_s32_f32(exp_r), k_shifted)); + return result; +} + +void MNNExpC8(float* dst, const float* src, float* offset, const float* parameters, size_t countC8) { + float32x4_t maxVec = vdupq_n_f32(offset[2]); + float32x4_t sumVec0 = vdupq_n_f32(0); + float32x4_t sumVec1 = vdupq_n_f32(0); + + float32x4_t c0 = vdupq_n_f32(offset[0]); + float32x4_t c1 = vdupq_n_f32(offset[1]); + + + for (int i = 0; i < countC8; ++i) { + float32x4_t srcVec0 = vld1q_f32(src); + float32x4_t srcVec1 = vld1q_f32(src + 4); + auto subVec0 = vaddq_f32(vmulq_f32(srcVec0, c0), maxVec); + auto subVec1 = vaddq_f32(vmulq_f32(srcVec1, c0), maxVec); + auto expVec0 = vaddq_f32(expApprox(subVec0), c1); + auto expVec1 = vaddq_f32(expApprox(subVec1), c1); + vst1q_f32(dst, expVec0); + vst1q_f32(dst + 4, expVec1); + sumVec0 = vaddq_f32(sumVec0, expVec0); + sumVec1 = vaddq_f32(sumVec1, expVec1); + + src += 8; + dst += 8; + } + + sumVec0 = vaddq_f32(sumVec0, sumVec1); + float32x2_t sumP = vpadd_f32(vget_low_f32(sumVec0), vget_high_f32(sumVec0)); + sumP = vpadd_f32(sumP, sumP); + offset[3] += vget_lane_f32(sumP, 0); +} + + +void MNNExp(float* destPtr, const float* srcPtr, float* offset, size_t size) { + float32x4_t maxVec = vdupq_n_f32(-offset[2]); + float32x4_t sumVec0 = vdupq_n_f32(0); + float32x4_t sumVec1 = vdupq_n_f32(0); + if (offset[0] == 1.f && offset[1] == 0.f) { + while (size >= 8) { + float32x4_t srcVec0 = vld1q_f32(srcPtr); + float32x4_t srcVec1 = vld1q_f32(srcPtr + 4); + auto subVec0 = vsubq_f32(srcVec0, maxVec); + auto subVec1 = vsubq_f32(srcVec1, maxVec); + auto expVec0 = expApprox(subVec0); + auto expVec1 = expApprox(subVec1); + vst1q_f32(destPtr, expVec0); + vst1q_f32(destPtr + 4, expVec1); + sumVec0 = vaddq_f32(sumVec0, expVec0); + sumVec1 = vaddq_f32(sumVec1, expVec1); + srcPtr += 8; + destPtr += 8; + size -= 8; + + } + while (size >= 4) { + float32x4_t srcVec0 = vld1q_f32(srcPtr); + auto subVec0 = vsubq_f32(srcVec0, maxVec); + auto expVec0 = expApprox(subVec0); + sumVec0 = vaddq_f32(sumVec0, expVec0); + vst1q_f32(destPtr, expVec0); + srcPtr += 4; + destPtr += 4; + size -= 4; + } + //merge + sumVec0 = vaddq_f32(sumVec0, sumVec1); + float32x2_t sumP = vpadd_f32(vget_low_f32(sumVec0), vget_high_f32(sumVec0)); + sumP = vpadd_f32(sumP, sumP); + auto newSum = vget_lane_f32(sumP, 0); + if (size > 0) { + float tmp[4]; + memcpy(tmp, srcPtr, size * sizeof(float)); + float32x4_t srcVec0 = vld1q_f32(tmp); + auto subVec0 = vsubq_f32(srcVec0, maxVec); + auto expVec0 = expApprox(subVec0); + vst1q_f32(tmp, expVec0); + for (int i = 0; i < size; ++i) { + newSum += tmp[i]; + destPtr[i] = tmp[i]; + } + } + offset[3] += newSum; + } else { + float32x4_t c0 = vdupq_n_f32(offset[0]); + float32x4_t c1 = vdupq_n_f32(offset[1]); + while (size >= 8) { + float32x4_t srcVec0 = vld1q_f32(srcPtr); + float32x4_t srcVec1 = vld1q_f32(srcPtr + 4); + auto subVec0 = vsubq_f32(vmulq_f32(srcVec0, c0), maxVec); + auto subVec1 = vsubq_f32(vmulq_f32(srcVec1, c0), maxVec); + auto expVec0 = vaddq_f32(expApprox(subVec0), c1); + auto expVec1 = vaddq_f32(expApprox(subVec1), c1); + vst1q_f32(destPtr, expVec0); + vst1q_f32(destPtr + 4, expVec1); + sumVec0 = vaddq_f32(sumVec0, expVec0); + sumVec1 = vaddq_f32(sumVec1, expVec1); + srcPtr += 8; + destPtr += 8; + size -= 8; + + } + while (size >= 4) { + float32x4_t srcVec0 = vld1q_f32(srcPtr); + auto subVec0 = vsubq_f32(vmulq_f32(srcVec0, c0), maxVec); + 
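The vectorized expApprox above reduces x to x = k*ln2 + r with |r| <= ln2/2, evaluates a degree-4 polynomial for exp(r) by Horner's rule, and multiplies by 2^k by adding k to the float exponent bits. A scalar sketch of the same idea, only for illustration and using hypothetical names, is:

#include <cmath>
#include <cstdint>
#include <cstring>
#include <cstdio>

// Scalar mirror of expApprox: clamp, range-reduce, Horner polynomial, then
// scale by 2^k through the exponent bits. std::round rounds halfway cases
// away from zero, matching vrndaq_f32.
static float expApproxScalar(float x) {
    x = std::fmin(std::fmax(x, -88.0f), 88.0f);
    const float ln2    = 0.69314718056f;
    const float invLn2 = 1.44269504089f;
    const float k = std::round(x * invLn2);
    const float r = x - k * ln2;
    // P(r) = c0 + r*(c1 + r*(c2 + r*(c3 + r*c4)))
    float p = 0.0416624f;      // c4
    p = p * r + 0.166665f;     // c3
    p = p * r + 0.5f;          // c2
    p = p * r + 1.0f;          // c1
    p = p * r + 1.0f;          // c0
    // Multiply by 2^k: add k << 23 to the bit pattern, like vshlq_n_s32 + vaddq_s32.
    int32_t bits;
    std::memcpy(&bits, &p, sizeof(bits));
    bits += (int32_t)k << 23;
    float result;
    std::memcpy(&result, &bits, sizeof(result));
    return result;
}

int main() {
    const float xs[] = {-3.0f, -0.5f, 0.0f, 1.0f, 5.0f};
    for (float x : xs) {
        printf("x=%5.2f approx=%.6f expf=%.6f\n", x, expApproxScalar(x), std::exp(x));
    }
    return 0;
}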
auto expVec0 = vaddq_f32(expApprox(subVec0), c1); + sumVec0 = vaddq_f32(sumVec0, expVec0); + vst1q_f32(destPtr, expVec0); + srcPtr += 4; + destPtr += 4; + size -= 4; + } + //merge + sumVec0 = vaddq_f32(sumVec0, sumVec1); + float32x2_t sumP = vpadd_f32(vget_low_f32(sumVec0), vget_high_f32(sumVec0)); + sumP = vpadd_f32(sumP, sumP); + auto newSum = vget_lane_f32(sumP, 0); + if (size > 0) { + float tmp[4]; + memcpy(tmp, srcPtr, size * sizeof(float)); + float32x4_t srcVec0 = vld1q_f32(tmp); + auto subVec0 = vsubq_f32(vmulq_f32(srcVec0, c0), maxVec); + auto expVec0 = vaddq_f32(expApprox(subVec0), c1); + vst1q_f32(tmp, expVec0); + for (int i = 0; i < size; ++i) { + newSum += tmp[i]; + destPtr[i] = tmp[i]; + } + } + offset[3] += newSum; + } +} + + +static inline void transposeAndStore4x4(const float* srcRowPtrs[4], float* dstColBase, size_t dstColStride) { + float32x4_t row0 = vld1q_f32(srcRowPtrs[0]); + float32x4_t row1 = vld1q_f32(srcRowPtrs[1]); + float32x4_t row2 = vld1q_f32(srcRowPtrs[2]); + float32x4_t row3 = vld1q_f32(srcRowPtrs[3]); + + // Step 1: Transpose 2x2 blocks of 2-element vectors + float32x4x2_t t01 = vtrnq_f32(row0, row1); + float32x4x2_t t23 = vtrnq_f32(row2, row3); + + // Step 2: Combine the results to get the full transpose + float32x4_t col0 = vcombine_f32(vget_low_f32(t01.val[0]), vget_low_f32(t23.val[0])); + float32x4_t col1 = vcombine_f32(vget_low_f32(t01.val[1]), vget_low_f32(t23.val[1])); + float32x4_t col2 = vcombine_f32(vget_high_f32(t01.val[0]), vget_high_f32(t23.val[0])); + float32x4_t col3 = vcombine_f32(vget_high_f32(t01.val[1]), vget_high_f32(t23.val[1])); + + vst1q_f32(dstColBase, col0); + vst1q_f32(dstColBase + dstColStride, col1); + vst1q_f32(dstColBase + 2 * dstColStride, col2); + vst1q_f32(dstColBase + 3 * dstColStride, col3); +} + + +void packKvCache(float* dst, const float* src, size_t seqLen, size_t kvSeqLen, size_t eP) { + if (seqLen == 0 || kvSeqLen == 0) { + return; + } + + // source [seqLen, kvSeqLen] + // dest [seqLen/eP, kvSeqLen, eP] + + const int kTileS = 4; // Tiling size for seqLen dimension + const int kTileK = 4; // Tiling size for kvSeqLen dimension + const size_t dstSOuterStride = kvSeqLen * eP; + + int s = 0; + for (; s + kTileS <= seqLen; s += kTileS) { + const int sOuter = s / eP; + const int sInner = s % eP; + if (sInner + kTileS > eP) { + break; + } + + float* dstSBase = dst + sOuter * dstSOuterStride + sInner; + const float* srcRowPtrs[kTileS]; + srcRowPtrs[0] = src + (s + 0) * kvSeqLen; + srcRowPtrs[1] = src + (s + 1) * kvSeqLen; + srcRowPtrs[2] = src + (s + 2) * kvSeqLen; + srcRowPtrs[3] = src + (s + 3) * kvSeqLen; + + int k = 0; + for (; k + kTileK <= kvSeqLen; k += kTileK) { + const float* currentSrcPtrs[kTileS]; + currentSrcPtrs[0] = srcRowPtrs[0] + k; + currentSrcPtrs[1] = srcRowPtrs[1] + k; + currentSrcPtrs[2] = srcRowPtrs[2] + k; + currentSrcPtrs[3] = srcRowPtrs[3] + k; + float* dstKBase = dstSBase + k * eP; + transposeAndStore4x4(currentSrcPtrs, dstKBase, eP); + } + + for (; k < kvSeqLen; ++k) { + float buffer[kTileS] = { + srcRowPtrs[0][k], + srcRowPtrs[1][k], + srcRowPtrs[2][k], + srcRowPtrs[3][k] + }; + vst1q_f32(dstSBase + k * eP, vld1q_f32(buffer)); + } + } + + for (; s < seqLen; ++s) { + const int sOuter = s / eP; + const int sInner = s % eP; + const float* srcRow = src + s * kvSeqLen; + float* dstSBase = dst + sOuter * dstSOuterStride + sInner; + for (int k = 0; k < kvSeqLen; ++k) { + dstSBase[k * eP] = srcRow[k]; + } + } +} + +void MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* 
updateScale, int outside, int reduceSize) { + for (int k = 0; k < outside; ++k) { + auto source = input + k * reduceSize; + auto dest = softmaxDst + k * reduceSize; + + // new max + auto srcPtr = source; + auto size = reduceSize; + float32x4_t maxVec0 = vdupq_n_f32(source[0]); + auto maxVec1 = maxVec0; + + float oldMax = source[0]; + if (runningMax) { + oldMax = runningMax[k]; + } + + while (size >= 8) { + float32x4_t srcVec0 = vld1q_f32(srcPtr); + float32x4_t srcVec1 = vld1q_f32(srcPtr + 4); + + maxVec0 = vmaxq_f32(maxVec0, srcVec0); + maxVec1 = vmaxq_f32(maxVec1, srcVec1); + + srcPtr += 8; + size -= 8; + } + + while (size >= 4) { + float32x4_t srcVec0 = vld1q_f32(srcPtr); + maxVec0 = vmaxq_f32(maxVec0, srcVec0); + srcPtr += 4; + size -= 4; + } + + maxVec0 = vmaxq_f32(maxVec0, maxVec1); + float32x2_t maxP = vpmax_f32(vget_low_f32(maxVec0), vget_high_f32(maxVec0)); + maxP = vpmax_f32(maxP, maxP); + auto newMax = vget_lane_f32(maxP, 0); + + while (size > 0) { + newMax = ALIMAX(newMax, srcPtr[0]); + srcPtr += 1; + size -= 1; + } + + newMax = ALIMAX(oldMax, newMax); + srcPtr = source; + auto destPtr = dest; + size = reduceSize; + + float exprOffset[4] = { + 1.0f, + 0.0f, + 0.0f, + 0.0f + }; + exprOffset[2] = -newMax; + + // expf(xi-newmax) & new sum + MNNExp(destPtr, srcPtr, exprOffset, size); + + if (runningMax != nullptr && runningSum != nullptr && updateScale != nullptr) { + // update runningSum, runningMax, scale=expf(oldMax-newMax) + float newSum = exprOffset[3]; + runningSum[k] = runningSum[k] * expf(oldMax - newMax) + newSum; + runningMax[k] = newMax; + updateScale[k] = expf(oldMax - newMax); + } else { + // Normalize + float sum = exprOffset[3]; + float scale = 1.0f / (sum + 1e-20f); + int count = reduceSize; + auto pDest = dest; + + float32x4_t scaleVec = vdupq_n_f32(scale); + while (count >= 4) { + float32x4_t data = vld1q_f32(pDest); + data = vmulq_f32(data, scaleVec); + vst1q_f32(pDest, data); + + pDest += 4; + count -= 4; + } + + while (count > 0) { + *pDest *= scale; + pDest++; + count--; + } + } + } +} + + #ifndef MNN_USE_NEON void MNNPackedSparseMatMulEpx1(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, unsigned int* NNZMap, int* dataOffsetMap) { diff --git a/source/backend/cpu/arm/arm32/MNNExpC8.S b/source/backend/cpu/arm/arm32/MNNExpC8.S deleted file mode 100644 index eac02b10..00000000 --- a/source/backend/cpu/arm/arm32/MNNExpC8.S +++ /dev/null @@ -1,115 +0,0 @@ -// -// MNNExpC8.S -// MNN -// -// Created by MNN on 2019/01/18. -// Copyright © 2018, Alibaba Group Holding Limited -// - -#ifdef __arm__ -#ifndef __aarch64__ - -#include "MNNAsmGlobal.h" -.text -.align 5 -//void MNNExpC8(float* dest, const float* source, const float* offset, const float* parameters, size_t countC8) -asm_function MNNExpC8 - -//r0: dest, r1:source, r2: offset, r3:parameters, r4:countC8 -push {r4, r5, lr} -ldr r4, [sp, #12] -vpush {q4-q7} -vmov.i32 q7, #0 -ldr r5, [r2, #0] -vdup.32 q4, r5 // Alpha -ldr r5, [r2, #4] -vdup.32 q5, r5 // Beta -ldr r5, [r2, #8] -vdup.32 q6, r5 // Bias - - -vld1.32 {q0, q1}, [r3] - -vmov.i32 q2, #87 -vcvt.f32.s32 q2, q2 -vneg.f32 q3, q2 - -Loop: - -vld1.32 {q8, q9}, [r1]! 
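The runningMax/runningSum/updateScale branch in MNNSoftmax above is the standard online-softmax recurrence behind the flash-attention path: each KV block contributes exp(x - newMax) terms, the previously accumulated sum is rescaled by exp(oldMax - newMax), and MNNFlashAttentionUpdateBlockOutput applies that same scale to the partial output. A small scalar sketch of the recurrence, with hypothetical names rather than the MNN kernels, is:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Online softmax over one row processed block by block: keep a running max
// and a running sum of exp(x - runningMax); when a block raises the max,
// earlier contributions are rescaled by exp(oldMax - newMax).
struct OnlineSoftmaxState {
    float runningMax = -INFINITY;
    float runningSum = 0.0f;
};

// Writes exp(x - newMax) into expOut and returns the scale that must be
// applied to anything accumulated from earlier blocks (the updateScale role).
static float processBlock(OnlineSoftmaxState& st, const float* x, int n, float* expOut) {
    float blockMax = x[0];
    for (int i = 1; i < n; ++i) blockMax = std::max(blockMax, x[i]);
    const float newMax = std::max(st.runningMax, blockMax);
    float blockSum = 0.0f;
    for (int i = 0; i < n; ++i) {
        expOut[i] = std::exp(x[i] - newMax);
        blockSum += expOut[i];
    }
    const float scale = std::exp(st.runningMax - newMax); // exp(oldMax - newMax)
    st.runningSum = st.runningSum * scale + blockSum;
    st.runningMax = newMax;
    return scale;
}

int main() {
    const float row[8] = {0.1f, 2.0f, -1.0f, 0.5f, 3.0f, 0.0f, 1.5f, -0.5f};
    OnlineSoftmaxState st;
    std::vector<float> probs;
    float block[4];
    for (int b = 0; b < 2; ++b) {
        const float scale = processBlock(st, row + 4 * b, 4, block);
        for (float& p : probs) p *= scale;              // rescale earlier blocks
        probs.insert(probs.end(), block, block + 4);
    }
    float total = 0.0f;
    for (float& p : probs) { p /= st.runningSum; total += p; }
    printf("sum of normalized probabilities = %.6f\n", total); // ~1.0
    return 0;
}

Processing the row in one block or in several gives the same normalized result up to floating-point rounding, which is why the KV block size (mBlockKV above) only affects memory traffic, not the output.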
-vmul.f32 q8, q8, q4 -vmul.f32 q9, q9, q4 -vadd.f32 q8, q8, q6 -vadd.f32 q9, q9, q6 - -vmin.f32 q8, q8, q2 -vmin.f32 q9, q9, q2 -vmax.f32 q10, q8, q3 -vmax.f32 q11, q9, q3 - -vmul.f32 q8, q10, d0[1] -vmul.f32 q9, q11, d0[1] -vcvt.s32.f32 q8, q8 -vcvt.s32.f32 q9, q9 - -vcvt.f32.s32 q12, q8 -vcvt.f32.s32 q13, q9 - -//q10, q11: t -vmls.f32 q10, q12, d0[0] -vmls.f32 q11, q13, d0[0] - -vmul.f32 q10, q10, d1[0] -vmul.f32 q11, q11, d1[0] - -.macro MLA_TWO z0 z1 z2 z3 -vdup.32 \z1, \z0 -vmla.f32 \z1, \z2, \z3 -.endm - -MLA_TWO d3[0], q12, q10, d3[1] -MLA_TWO d3[0], q13, q11, d3[1] -MLA_TWO d2[1], q14, q10, q12 -MLA_TWO d2[1], q15, q11, q13 -MLA_TWO d2[0], q12, q10, q14 -MLA_TWO d2[0], q13, q11, q15 -MLA_TWO d1[1], q14, q10, q12 -MLA_TWO d1[1], q15, q11, q13 -MLA_TWO d1[1], q12, q10, q14 -MLA_TWO d1[1], q13, q11, q15 - -//q12, q13 is expRemain - -vmul.f32 q12, q12, q12 -vmul.f32 q13, q13, q13 -vmul.f32 q12, q12, q12 -vmul.f32 q13, q13, q13 - -vshl.i32 q8, q8, #23 -vshl.i32 q9, q9, #23 -vadd.i32 q12, q12, q8 -vadd.i32 q13, q13, q9 - -vadd.f32 q12, q12, q5 -vadd.f32 q13, q13, q5 -vadd.f32 q7, q12, q7 - -vst1.32 {q12, q13}, [r0]! -vadd.f32 q7, q13, q7 - -subs r4, r4, #1 -bne Loop -add r5, r2, #12 -vld1.32 {d0[0]}, [r5] -vadd.f32 d14, d14, d15 -vtrn.32 d14, d15 -vadd.f32 d14, d14, d15 -vadd.f32 d0, d14, d0 -vst1.32 {d0[0]}, [r5] - -vpop {q4-q7} -pop {r4, r5, pc} - - -#endif -#endif diff --git a/source/backend/cpu/arm/arm32/MNNSoftmax.S b/source/backend/cpu/arm/arm32/MNNSoftmax.S deleted file mode 100644 index 3e272583..00000000 --- a/source/backend/cpu/arm/arm32/MNNSoftmax.S +++ /dev/null @@ -1,270 +0,0 @@ -// -// MNNSoftmax.S -// MNN -// -// Created by MNN on 2021/07/05. -// Copyright © 2018, Alibaba Group Holding Limited -// - -#ifdef __arm__ -#ifndef __aarch64__ - -#include "MNNAsmGlobal.h" -.text -.align 5 -//void MNNSoftmax(float* dest, const float* source, size_t size) -asm_function MNNSoftmax - push {r4, r5, r6, r7, r8, lr} - bic r3, r2, #3 - lsrs r4, r2, #2 - vpush.64 {d8, d9, d10, d11, d12, d13, d14, d15} - sub sp, sp, #8 - beq Loop_2 - vld1.32 {d16-d17}, [r1] - cmp r4, #1 - beq Loop_3 - add lr, r1, #16 - mov ip, #1 -Loop_4: - vld1.32 {d18-d19}, [lr]! - add ip, ip, #1 - cmp r4, ip - vmax.f32 q8, q8, q9 - bne Loop_4 -Loop_3: - vmov.32 ip, d16[0] - vmov s15, ip - vmov.32 ip, d16[1] - vmov s12, ip - vmov.32 ip, d17[0] - vcmpe.f32 s15, s12 - vmov s13, ip - vmov.32 ip, d17[1] - vmrs APSR_nzcv, FPSCR - vmov s14, ip - vmovle.f32 s15, s12 - vcmpe.f32 s13, s15 - vmrs APSR_nzcv, FPSCR - vmovpl.f32 s15, s13 - vcmpe.f32 s14, s15 - vmrs APSR_nzcv, FPSCR - vmovpl.f32 s15, s14 - cmp r2, r3 - bls Loop_25 -Loop_24: - add lr, r1, r3, lsl #2 - mov ip, r3 -Loop_11: - vldmia.32 lr!, {s14} - add ip, ip, #1 - vcmpe.f32 s14, s15 - vmrs APSR_nzcv, FPSCR - vmovpl.f32 s15, s14 - cmp r2, ip - bhi Loop_11 - cmp r4, #0 - beq Loop_8 -Loop_25: - vmov.f32 q11, #0.0 @ v4sf - mov r5, r1 - vldr d14, Loop_54 - vldr d15, Loop_54+8 - mov lr, r0 - vldr d12, Loop_54+16 - vldr d13, Loop_54+24 - mov ip, #0 - vmov.f32 q12, #1.0e+0 @ v4sf - vstr.32 s15, [sp, #4] - vmov.f32 q5, #5.0e-1 @ v4sf - vldr d8, Loop_54+32 - vldr d9, Loop_54+40 - vldr d0, Loop_54+48 - vldr d1, Loop_54+56 - vldr d2, Loop_54+64 - vldr d3, Loop_54+72 - vldr d4, Loop_54+80 - vldr d5, Loop_54+88 - vldr d30, Loop_54+96 - vldr d31, Loop_54+104 - vmov.i32 q14, #8388608 @ v4si - vdup.32 q13, d7[1] -Loop_12: - vld1.32 {d20-d21}, [r5]! 
- add ip, ip, #1 - vmov.i32 q3, #127 @ v4si - cmp r4, ip - vsub.f32 q10, q10, q13 - vmax.f32 q10, q10, q15 - vmin.f32 q10, q10, q2 - vmul.f32 q9, q10, q6 - vcvt.s32.f32 q9, q9 - vcvt.f32.s32 q8, q9 - vadd.i32 q9, q9, q3 - vmul.f32 q8, q8, q7 - vmul.i32 q9, q9, q14 - vsub.f32 q10, q10, q8 - vmul.f32 q8, q1, q10 - vadd.f32 q8, q8, q0 - vmul.f32 q8, q8, q10 - vadd.f32 q8, q8, q4 - vmul.f32 q8, q8, q10 - vadd.f32 q8, q8, q5 - vmul.f32 q8, q8, q10 - vadd.f32 q8, q8, q12 - vmul.f32 q8, q8, q10 - vadd.f32 q8, q8, q12 - vmul.f32 q9, q9, q8 - vadd.f32 q11, q9, q11 - vst1.32 {d18-d19}, [lr]! - bgt Loop_12 - vmov.32 ip, d22[0] - vldr.32 s15, [sp, #4] - cmp r2, r3 - vmov s12, ip - vmov.32 ip, d22[1] - vmov s11, ip - vmov.32 ip, d23[0] - vadd.f32 s12, s12, s11 - vmov s13, ip - vmov.32 ip, d23[1] - vadd.f32 s12, s12, s13 - vmov s14, ip - vadd.f32 s12, s12, s14 - bls Loop_52 -Loop_26: - lsl ip, r3, #2 - add r5, r1, r2, lsl #2 - add lr, r1, ip - movw r7, #13877 - movt r7, 179 - movw r6, #55317 - movt r6, 32310 - vldr.32 s11, Loop_54+112 - vldr.32 s10, Loop_54+116 - add r1, r0, ip - vldr.32 s9, Loop_54+120 - vmov.f32 s8, #5.0e-1 - vldr.32 s5, Loop_54+124 - vldr.32 s6, Loop_54+128 - vldr.32 s7, Loop_54+132 -Loop_17: - vldmia.32 lr!, {s14} - vmov s13, r6 - vsub.f32 s14, s14, s15 - vcmpe.f32 s14, s11 - vmrs APSR_nzcv, FPSCR - vmovle s13, r7 - ble Loop_14 - vcmpe.f32 s14, s10 - vmrs APSR_nzcv, FPSCR - bmi Loop_53 -Loop_14: - vadd.f32 s12, s12, s13 - cmp r5, lr - vstmia.32 r1!, {s13} - bne Loop_17 - vmov.f32 s15, #1.0e+0 - cmp r4, #0 - vdiv.f32 s14, s15, s12 - beq Loop_20 -Loop_23: - vdup.32 q9, d7[0] - mov ip, r0 - mov r1, #0 -Loop_19: - vld1.32 {d16-d17}, [ip] - add r1, r1, #1 - cmp r4, r1 - vmul.f32 q8, q8, q9 - vst1.32 {d16-d17}, [ip]! - bgt Loop_19 - cmp r2, r3 - bls Loop_1 - lsl ip, r3, #2 -Loop_20: - add r0, r0, ip -Loop_21: - vldr.32 s15, [r0] - add r3, r3, #1 - cmp r2, r3 - vmul.f32 s15, s15, s14 - vstmia.32 r0!, {s15} - bhi Loop_21 -Loop_1: - add sp, sp, #8 - vldm sp!, {d8-d15} - pop {r4, r5, r6, r7, r8, pc} -Loop_2: - vldr.32 s15, Loop_54+136 - cmp r2, r3 - bhi Loop_24 -Loop_8: - cmp r2, r3 - vldrhi.32 s12, Loop_54+136 - bhi Loop_26 - b Loop_1 -Loop_52: - vmov.f32 s15, #1.0e+0 - cmp r4, #0 - vdiv.f32 s14, s15, s12 - bne Loop_23 - b Loop_1 -Loop_53: - vdiv.f32 s4, s14, s9 - vmov.f32 s13, #1.0e+0 - vcvt.s32.f32 s4, s4 - vcvt.f32.s32 s3, s4 - vmov r8, s4 @ int - vmov.f32 s4, s7 - vmls.f32 s14, s3, s9 - vmov.f32 s3, s6 - add r8, r8, #127 - lsl r8, r8, #23 - vmla.f32 s3, s14, s5 - vmla.f32 s4, s3, s14 - vmov.f32 s3, s8 - vmla.f32 s3, s4, s14 - vmov.f32 s4, s13 - vmla.f32 s4, s3, s14 - vmla.f32 s13, s4, s14 - vmov s14, r8 - vmul.f32 s13, s13, s14 - b Loop_14 -Loop_54: - .word 1060205080 - .word 1060205080 - .word 1060205080 - .word 1060205080 - .word 1069066811 - .word 1069066811 - .word 1069066811 - .word 1069066811 - .word 1042983595 - .word 1042983595 - .word 1042983595 - .word 1042983595 - .word 1026206379 - .word 1026206379 - .word 1026206379 - .word 1026206379 - .word 1007192201 - .word 1007192201 - .word 1007192201 - .word 1007192201 - .word 1118699520 - .word 1118699520 - .word 1118699520 - .word 1118699520 - .word -1028784128 - .word -1028784128 - .word -1028784128 - .word -1028784128 - .word -1028784128 - .word 1118699520 - .word 1060205080 - .word 1007192201 - .word 1026206379 - .word 1042983595 - .word 0 -#endif -#endif diff --git a/source/backend/cpu/arm/arm64/MNNExpC8.S b/source/backend/cpu/arm/arm64/MNNExpC8.S deleted file mode 100644 index 3edb91f1..00000000 --- 
a/source/backend/cpu/arm/arm64/MNNExpC8.S +++ /dev/null @@ -1,109 +0,0 @@ -// -// MNNExpC8.S -// MNN -// -// Created by MNN on 2019/01/18. -// Copyright © 2018, Alibaba Group Holding Limited -// - -#ifdef __aarch64__ - -#include "MNNAsmGlobal.h" -.text -.align 5 - -//void MNNExpC8(float* dest, const float* source, const float* offset, const float* parameters, size_t countC8) -asm_function MNNExpC8 - -//x0: dest, x1:source, x2: offset, x3:parameters, x4:countC8 -ldr w5, [x2, #0] -ldr w6, [x2, #4] -ldr w7, [x2, #8] - -ld1 {v0.4s, v1.4s}, [x3] -movi v2.4s, #23 -movi v3.4s, #87 -scvtf v3.4s, v3.4s -fneg v4.4s, v3.4s -dup v30.4s, w5 -dup v31.4s, w6 -dup v29.4s, w7 - -// Summer -movi v28.4s, #0 - -Loop: - -ld1 {v16.4s, v17.4s}, [x1], #32 -fmul v16.4s, v16.4s, v30.4s -fmul v17.4s, v17.4s, v30.4s -fadd v16.4s, v16.4s, v29.4s -fadd v17.4s, v17.4s, v29.4s -fmin v16.4s, v16.4s, v3.4s -fmin v17.4s, v17.4s, v3.4s -fmax v18.4s, v16.4s, v4.4s -fmax v19.4s, v17.4s, v4.4s - -fmul v16.4s, v18.4s, v0.s[1] -fmul v17.4s, v19.4s, v0.s[1] -fcvtzs v16.4s, v16.4s -fcvtzs v17.4s, v17.4s -scvtf v20.4s, v16.4s -scvtf v21.4s, v17.4s - -//v18.4s, v19.4s: t -fmls v18.4s, v20.4s, v0.s[0] -fmls v19.4s, v21.4s, v0.s[0] - -fmul v18.4s, v18.4s, v0.s[2] -fmul v19.4s, v19.4s, v0.s[2] - -.macro MLA_TWO z0 z1 z2 z3 -dup \z1, \z0 -fmla \z1, \z2, \z3 -.endm - -MLA_TWO v1.s[2], v20.4s, v18.4s, v1.s[3] -MLA_TWO v1.s[2], v21.4s, v19.4s, v1.s[3] -MLA_TWO v1.s[1], v22.4s, v18.4s, v20.4s -MLA_TWO v1.s[1], v23.4s, v19.4s, v21.4s -MLA_TWO v1.s[0], v20.4s, v18.4s, v22.4s -MLA_TWO v1.s[0], v21.4s, v19.4s, v23.4s -MLA_TWO v0.s[3], v22.4s, v18.4s, v20.4s -MLA_TWO v0.s[3], v23.4s, v19.4s, v21.4s -MLA_TWO v0.s[3], v20.4s, v18.4s, v22.4s -MLA_TWO v0.s[3], v21.4s, v19.4s, v23.4s - -//v20.4s, v21.4s is expRemain -fmul v20.4s, v20.4s, v20.4s -fmul v21.4s, v21.4s, v21.4s -fmul v20.4s, v20.4s, v20.4s -fmul v21.4s, v21.4s, v21.4s - -ushl v16.4s, v16.4s, v2.4s -ushl v17.4s, v17.4s, v2.4s -add v20.4s, v20.4s, v16.4s -add v21.4s, v21.4s, v17.4s - -fadd v20.4s, v20.4s, v31.4s -fadd v21.4s, v21.4s, v31.4s - -st1 {v20.4s, v21.4s}, [x0], #32 -fadd v28.4s, v28.4s, v20.4s -fadd v28.4s, v28.4s, v21.4s - -subs x4, x4, #1 -bne Loop - -// Bias -add x7, x2, #12 -ld1 {v27.s}[0], [x7] -faddp v28.4s, v28.4s, v28.4s -faddp v28.2s, v28.2s, v28.2s -fadd v27.2s, v28.2s, v27.2s -st1 {v27.s}[0], [x7] - -ret - -#endif - diff --git a/source/backend/cpu/arm/arm64/MNNSoftmax.S b/source/backend/cpu/arm/arm64/MNNSoftmax.S deleted file mode 100644 index 03532335..00000000 --- a/source/backend/cpu/arm/arm64/MNNSoftmax.S +++ /dev/null @@ -1,204 +0,0 @@ -// -// MNNSoftmax.S -// MNN -// -// Created by MNN on 2021/07/05. -// Copyright © 2018, Alibaba Group Holding Limited -// - -#ifdef __aarch64__ - -#include "MNNAsmGlobal.h" -.text -.align 5 - -//void MNNSoftmax(float* dest, const float* source, size_t countC8) -asm_function MNNSoftmax - stp x21, x22, [sp, #-32]! 
- stp x19, x20, [sp, #16] - sxtw x8, w2 - lsr w9, w2, #2 - and x8, x8, #-4 - cbz w9, Loop_5 - ldr q0, [x1] - cmp w9, #1 - b.eq Loop_4 - add x10, x1, #16 - sub x11, x9, #1 -Loop_3: - ldr q1, [x10], #16 - subs x11, x11, #1 - fmax v0.4s, v0.4s, v1.4s - b.ne Loop_3 -Loop_4: - mov s1, v0.s[1] - fcmp s0, s1 - mov s2, v0.s[2] - fcsel s1, s0, s1, gt - fcmp s1, s2 - fcsel s1, s1, s2, gt - mov s0, v0.s[3] - fcmp s1, s0 - fcsel s0, s1, s0, gt - cmp w8, w2 - b.lo Loop_6 - b Loop_8 -Loop_5: - fmov s0, wzr - cmp w8, w2 - b.hs Loop_8 -Loop_6: - add x10, x1, w8, sxtw #2 - mov w11, w8 -Loop_7: - ldr s1, [x10], #4 - add w11, w11, #1 - fcmp s0, s1 - fcsel s0, s0, s1, gt - cmp w11, w2 - b.lo Loop_7 -Loop_8: - cbz w9, Loop_12 - mov w10, #-1028784128 - mov w12, #43579 - dup v4.4s, w10 - mov w10, #29208 - mov w11, #1118699520 - movk w12, #16312, lsl #16 - movk w10, #48945, lsl #16 - dup v5.4s, w11 - mov w11, #34953 - dup v6.4s, w12 - mov w12, #43691 - dup v7.4s, w10 - mov w10, #43691 - movk w11, #15368, lsl #16 - movk w12, #15658, lsl #16 - movk w10, #15914, lsl #16 - dup v2.4s, v0.s[0] - movi v1.16b, #0 - fmov v3.4s, #1.0 - dup v16.4s, w11 - dup v17.4s, w12 - dup v18.4s, w10 - movi v19.4s, #63, lsl #24 - mov x10, x9 - mov x11, x1 - mov x12, x0 -Loop_10: - ldr q20, [x11], #16 - subs x10, x10, #1 - fsub v20.4s, v20.4s, v2.4s - fmax v20.4s, v20.4s, v4.4s - fmin v20.4s, v20.4s, v5.4s - fmul v21.4s, v20.4s, v6.4s - fcvtzs v21.4s, v21.4s - scvtf v22.4s, v21.4s - fmla v20.4s, v22.4s, v7.4s - fmul v22.4s, v20.4s, v16.4s - fadd v22.4s, v22.4s, v17.4s - fmul v22.4s, v20.4s, v22.4s - fadd v22.4s, v22.4s, v18.4s - fmul v22.4s, v20.4s, v22.4s - fadd v22.4s, v22.4s, v19.4s - fmul v22.4s, v20.4s, v22.4s - fadd v22.4s, v22.4s, v3.4s - shl v21.4s, v21.4s, #23 - fmul v20.4s, v20.4s, v22.4s - add v21.4s, v21.4s, v3.4s - fadd v20.4s, v20.4s, v3.4s - fmul v20.4s, v20.4s, v21.4s - fadd v1.4s, v1.4s, v20.4s - str q20, [x12], #16 - b.ne Loop_10 - dup v2.4s, v1.s[1] - dup v3.4s, v1.s[2] - fadd v2.4s, v1.4s, v2.4s - fadd v2.4s, v3.4s, v2.4s - dup v1.4s, v1.s[3] - fadd v1.4s, v1.4s, v2.4s - b Loop_13 -Loop_12: - fmov s1, wzr -Loop_13: - cmp w8, w2 - fmov s2, #1.0 - b.hs Loop_16 - lsl x21, x8, #2 - mov w12, #29208 - mov w14, #34953 - mov w15, #43691 - mov w19, #43691 - mov w10, #-1028784128 - mov w11, #1118699520 - movk w12, #16177, lsl #16 - mov w13, #1065353216 - movk w14, #15368, lsl #16 - movk w15, #15658, lsl #16 - movk w19, #15914, lsl #16 - add x20, x1, x21 - add x21, x0, x21 - fmov s3, #0.5 - mov w1, w8 -Loop_15: - ldr s4, [x20], #4 - fmov s5, w10 - fmov s6, w11 - fmov s7, w12 - fsub s4, s4, s0 - fmaxnm s4, s4, s5 - fminnm s4, s4, s6 - fdiv s6, s4, s7 - fcvtzs w3, s6 - scvtf s6, w3 - fmul s6, s6, s7 - fmov s5, w14 - fsub s4, s4, s6 - fmov s7, w15 - fmul s5, s4, s5 - fadd s5, s5, s7 - fmov s6, w19 - fmul s5, s4, s5 - fadd s5, s5, s6 - fmul s5, s4, s5 - fadd s5, s5, s3 - fmul s5, s4, s5 - fadd s5, s5, s2 - add w3, w13, w3, lsl #23 - fmul s4, s4, s5 - fmov s7, w3 - fadd s4, s4, s2 - add w1, w1, #1 - fmul s4, s4, s7 - cmp w1, w2 - str s4, [x21], #4 - fadd s1, s1, s4 - b.lo Loop_15 -Loop_16: - fdiv s0, s2, s1 - cbz w9, Loop_19 - dup v1.4s, v0.s[0] - mov x10, x0 -Loop_18: - ldr q2, [x10] - subs x9, x9, #1 - fmul v2.4s, v1.4s, v2.4s - str q2, [x10], #16 - b.ne Loop_18 -Loop_19: - cmp w8, w2 - b.hs Loop_22 - add x9, x0, w8, sxtw #2 - Loop_21: - ldr s1, [x9] - add w8, w8, #1 - cmp w8, w2 - fmul s1, s0, s1 - str s1, [x9], #4 - b.lo Loop_21 -Loop_22: - ldp x19, x20, [sp, #16] - ldp x21, x22, [sp], #32 - ret -#endif - diff --git 
a/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScaleHp128_SME2_w4_Fp32.S b/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScaleHp128_SME2_w4_Fp32.S index 0f4a0140..6881d7c7 100644 --- a/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScaleHp128_SME2_w4_Fp32.S +++ b/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScaleHp128_SME2_w4_Fp32.S @@ -379,10 +379,10 @@ LoopSzEnd_TILE_1: .inst 0x05b02324 // dup z4.q, z25.q[2] .inst 0x05f02325 // dup z5.q, z25.q[3] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -404,10 +404,10 @@ LoopSzEnd_TILE_1: .inst 0x05f02325 // dup z5.q, z25.q[3] .inst 0x05702346 // dup z6.q, z26.q[1] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -431,11 +431,11 @@ LoopSzEnd_TILE_1: .inst 0x05702346 // dup z6.q, z26.q[1] .inst 0x05b02347 // dup z7.q, z26.q[2] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -461,11 +461,11 @@ TILE1_STORE48: .inst 0x05b02347 // dup z7.q, z26.q[2] .inst 0x05f02348 // dup z8.q, z26.q[3] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -492,12 +492,12 @@ TILE1_STORE52: .inst 0x05b02347 // dup z7.q, z26.q[2] .inst 0x05f02348 // dup z8.q, z26.q[3] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -526,12 +526,12 @@ TILE1_STORE56: .inst 0x05f02348 // dup z8.q, z26.q[3] .inst 0x05702369 // dup z9.q, z27.q[1] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -562,13 +562,13 @@ TILE1_STORE60: .inst 0x05702369 // dup z9.q, z27.q[1] .inst 0x05b0236a // dup z10.q, z27.q[2] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - 
add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -601,13 +601,13 @@ TILE1_STORE64: .inst 0x05b0236a // dup z10.q, z27.q[2] .inst 0x05f0236b // dup z11.q, z27.q[3] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -640,15 +640,14 @@ TILE1_STORE68: .inst 0x05702369 // dup z9.q, z27.q[1] .inst 0x05b0236a // dup z10.q, z27.q[2] .inst 0x05f0236b // dup z11.q, z27.q[3] - .inst 0x0570238c // dup z12.q, z28.q[1] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -667,7 +666,7 @@ TILE1_STORE68: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 + add x9, x6, x4, lsl #4 // +16 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] b TILE1_Dz_End @@ -685,14 +684,15 @@ TILE1_STORE72: .inst 0x05702369 // dup z9.q, z27.q[1] .inst 0x05b0236a // dup z10.q, z27.q[2] .inst 0x05f0236b // dup z11.q, z27.q[3] + .inst 0x0570238c // dup z12.q, z28.q[1] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -711,7 +711,7 @@ TILE1_STORE72: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 + add x9, x6, x4, lsl #4 // +16 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] @@ -731,14 +731,15 @@ TILE1_STORE76: .inst 0x05b0236a // dup z10.q, z27.q[2] .inst 0x05f0236b // dup z11.q, z27.q[3] .inst 0x0570238c // dup z12.q, z28.q[1] + .inst 0x05b0238d // dup z13.q, z28.q[2] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 
0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -757,8 +758,8 @@ TILE1_STORE76: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] @@ -779,14 +780,16 @@ TILE1_STORE80: .inst 0x05b0236a // dup z10.q, z27.q[2] .inst 0x05f0236b // dup z11.q, z27.q[3] .inst 0x0570238c // dup z12.q, z28.q[1] + .inst 0x05b0238d // dup z13.q, z28.q[2] + .inst 0x05f0238e // dup z14.q, z28.q[3] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -805,12 +808,13 @@ TILE1_STORE80: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] .inst 0xe400f5ed // st1b {z13.b}, p5, [x15] + .inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4] b TILE1_Dz_End TILE1_STORE84: @@ -827,14 +831,17 @@ TILE1_STORE84: .inst 0x05b0236a // dup z10.q, z27.q[2] .inst 0x05f0236b // dup z11.q, z27.q[3] .inst 0x0570238c // dup z12.q, z28.q[1] + .inst 0x05b0238d // dup z13.q, z28.q[2] + .inst 0x05f0238e // dup z14.q, z28.q[3] + .inst 0x057023af // dup z15.q, z29.q[1] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -853,12 +860,15 @@ TILE1_STORE84: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 + add x8, x23, x4, lsl #3 // +20 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] .inst 0xe400f5ed // st1b {z13.b}, p5, [x15] + .inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4] + .inst 0xe400f51d // st1b {z29.b}, p5, [x8] b TILE1_Dz_End TILE1_STORE88: @@ -876,14 +886,16 @@ TILE1_STORE88: .inst 0x05f0236b // dup z11.q, z27.q[3] .inst 0x0570238c // dup z12.q, z28.q[1] .inst 0x05b0238d // dup z13.q, z28.q[2] + .inst 0x05f0238e // dup z14.q, z28.q[3] + .inst 0x057023af // dup z15.q, z29.q[1] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b 
{z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -902,12 +914,16 @@ TILE1_STORE88: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 + add x8, x23, x4, lsl #3 // +20 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] .inst 0xe400f5ed // st1b {z13.b}, p5, [x15] + .inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4] + .inst 0xe400f51d // st1b {z29.b}, p5, [x8] + .inst 0xe404550f // st1b {z15.b}, p5, [x8, x4] b TILE1_Dz_End TILE1_STORE92: @@ -926,14 +942,17 @@ TILE1_STORE92: .inst 0x0570238c // dup z12.q, z28.q[1] .inst 0x05b0238d // dup z13.q, z28.q[2] .inst 0x05f0238e // dup z14.q, z28.q[3] + .inst 0x057023af // dup z15.q, z29.q[1] + .inst 0x05b023b0 // dup z16.q, z29.q[2] + .inst 0x05f023b1 // dup z17.q, z29.q[3] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -952,15 +971,18 @@ TILE1_STORE92: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 - add x8, x23, x4, lsl #3 + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 + add x8, x23, x4, lsl #3 // +20 + add x11, x5, x4, lsl #3 // +22 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] .inst 0xe400f5ed // st1b {z13.b}, p5, [x15] .inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4] .inst 0xe400f51d // st1b {z29.b}, p5, [x8] + .inst 0xe404550f // st1b {z15.b}, p5, [x8, x4] + .inst 0xe400f570 // st1b {z16.b}, p5, [x11] b TILE1_Dz_End TILE1_STORE96: @@ -980,14 +1002,16 @@ TILE1_STORE96: .inst 0x05b0238d // dup z13.q, z28.q[2] .inst 0x05f0238e // dup z14.q, z28.q[3] .inst 0x057023af // dup z15.q, z29.q[1] + .inst 0x05b023b0 // dup z16.q, z29.q[2] + .inst 0x05f023b1 // dup z17.q, z29.q[3] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -1006,9 +1030,10 @@ TILE1_STORE96: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 - add x8, x23, x4, lsl #3 + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 + add x8, x23, x4, lsl #3 // +20 + add x11, x5, x4, lsl #3 // +22 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] @@ -1016,6 +1041,8 @@ TILE1_STORE96: .inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4] .inst 0xe400f51d // st1b {z29.b}, p5, [x8] .inst 0xe404550f // st1b {z15.b}, p5, [x8, x4] + .inst 0xe400f570 // st1b {z16.b}, p5, [x11] + .inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4] b 
TILE1_Dz_End TILE1_STORE100: @@ -1036,14 +1063,15 @@ TILE1_STORE100: .inst 0x05f0238e // dup z14.q, z28.q[3] .inst 0x057023af // dup z15.q, z29.q[1] .inst 0x05b023b0 // dup z16.q, z29.q[2] + .inst 0x05f023b1 // dup z17.q, z29.q[3] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -1062,10 +1090,11 @@ TILE1_STORE100: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 - add x8, x23, x4, lsl #3 - add x11, x5, x4, lsl #3 + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 + add x8, x23, x4, lsl #3 // +20 + add x11, x5, x4, lsl #3 // +22 + add x13, x9, x4, lsl #3 // +24 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] @@ -1074,6 +1103,8 @@ TILE1_STORE100: .inst 0xe400f51d // st1b {z29.b}, p5, [x8] .inst 0xe404550f // st1b {z15.b}, p5, [x8, x4] .inst 0xe400f570 // st1b {z16.b}, p5, [x11] + .inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4] + .inst 0xe400f5be // st1b {z30.b}, p5, [x13] b TILE1_Dz_End TILE1_STORE104: @@ -1095,75 +1126,15 @@ TILE1_STORE104: .inst 0x057023af // dup z15.q, z29.q[1] .inst 0x05b023b0 // dup z16.q, z29.q[2] .inst 0x05f023b1 // dup z17.q, z29.q[3] - - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 - - .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] - .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] - .inst 0xe400f521 // st1b {z1.b}, p5, [x9] - .inst 0xe4045522 // st1b {z2.b}, p5, [x9, x4] - .inst 0xe400f5f9 // st1b {z25.b}, p5, [x15] - .inst 0xe40455e3 // st1b {z3.b}, p5, [x15, x4] - .inst 0xe400f504 // st1b {z4.b}, p5, [x8] - .inst 0xe4045505 // st1b {z5.b}, p5, [x8, x4] - .inst 0xe400f57a // st1b {z26.b}, p5, [x11] - .inst 0xe4045566 // st1b {z6.b}, p5, [x11, x4] - .inst 0xe400f5a7 // st1b {z7.b}, p5, [x13] - .inst 0xe40455a8 // st1b {z8.b}, p5, [x13, x4] - .inst 0xe400f6fb // st1b {z27.b}, p5, [x23] - .inst 0xe40456e9 // st1b {z9.b}, p5, [x23, x4] - .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] - .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 - add x8, x23, x4, lsl #3 - add x11, x5, x4, lsl #3 - - .inst 0xe400f53c // st1b {z28.b}, p5, [x9] - .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] - .inst 0xe400f5ed // st1b {z13.b}, p5, [x15] - .inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4] - .inst 0xe400f51d // st1b {z29.b}, p5, [x8] - .inst 0xe404550f // st1b {z15.b}, p5, [x8, x4] - .inst 0xe400f570 // st1b {z16.b}, p5, [x11] - .inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4] - b TILE1_Dz_End - -TILE1_STORE108: - .inst 0x05702300 // dup z0.q, z24.q[1] - .inst 0x05b02301 // dup z1.q, z24.q[2] - .inst 0x05f02302 // dup z2.q, z24.q[3] - .inst 0x05702323 // dup z3.q, z25.q[1] - .inst 0x05b02324 // dup z4.q, z25.q[2] - .inst 0x05f02325 // dup z5.q, z25.q[3] - .inst 0x05702346 // dup z6.q, z26.q[1] - .inst 0x05b02347 // dup z7.q, z26.q[2] - .inst 0x05f02348 // dup z8.q, z26.q[3] - .inst 0x05702369 // dup z9.q, z27.q[1] - .inst 
0x05b0236a // dup z10.q, z27.q[2] - .inst 0x05f0236b // dup z11.q, z27.q[3] - .inst 0x0570238c // dup z12.q, z28.q[1] - .inst 0x05b0238d // dup z13.q, z28.q[2] - .inst 0x05f0238e // dup z14.q, z28.q[3] - .inst 0x057023af // dup z15.q, z29.q[1] - .inst 0x05b023b0 // dup z16.q, z29.q[2] - .inst 0x05f023b1 // dup z17.q, z29.q[3] .inst 0x057023d2 // dup z18.q, z30.q[1] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -1182,11 +1153,11 @@ TILE1_STORE108: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 - add x8, x23, x4, lsl #3 - add x11, x5, x4, lsl #3 - add x13, x9, x4, lsl #3 + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 + add x8, x23, x4, lsl #3 // +20 + add x11, x5, x4, lsl #3 // +22 + add x13, x9, x4, lsl #3 // +24 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] @@ -1200,7 +1171,8 @@ TILE1_STORE108: .inst 0xe40455b2 // st1b {z18.b}, p5, [x13, x4] b TILE1_Dz_End - TILE1_STORE112: + /* oc=108 */ + TILE1_STORE108: .inst 0x05702300 // dup z0.q, z24.q[1] .inst 0x05b02301 // dup z1.q, z24.q[2] .inst 0x05f02302 // dup z2.q, z24.q[3] @@ -1222,13 +1194,13 @@ TILE1_STORE108: .inst 0x057023d2 // dup z18.q, z30.q[1] .inst 0x05b023d3 // dup z19.q, z30.q[2] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -1247,12 +1219,12 @@ TILE1_STORE108: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 - add x8, x23, x4, lsl #3 - add x11, x5, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 + add x8, x23, x4, lsl #3 // +20 + add x11, x5, x4, lsl #3 // +22 + add x13, x9, x4, lsl #3 // +24 + add x23, x15, x4, lsl #3 // +26 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] @@ -1267,6 +1239,77 @@ TILE1_STORE108: .inst 0xe400f6f3 // st1b {z19.b}, p5, [x23] b TILE1_Dz_End + /* oc=112 */ + TILE1_STORE112: + .inst 0x05702300 // dup z0.q, z24.q[1] + .inst 0x05b02301 // dup z1.q, z24.q[2] + .inst 0x05f02302 // dup z2.q, z24.q[3] + .inst 0x05702323 // dup z3.q, z25.q[1] + .inst 0x05b02324 // dup z4.q, z25.q[2] + .inst 0x05f02325 // dup z5.q, z25.q[3] + .inst 0x05702346 // dup z6.q, z26.q[1] + .inst 0x05b02347 // dup z7.q, z26.q[2] + .inst 0x05f02348 // dup z8.q, z26.q[3] + .inst 0x05702369 // dup z9.q, z27.q[1] + .inst 0x05b0236a // dup z10.q, z27.q[2] + .inst 0x05f0236b // dup z11.q, z27.q[3] + .inst 0x0570238c // dup z12.q, z28.q[1] + .inst 0x05b0238d // dup z13.q, z28.q[2] + .inst 
0x05f0238e // dup z14.q, z28.q[3] + .inst 0x057023af // dup z15.q, z29.q[1] + .inst 0x05b023b0 // dup z16.q, z29.q[2] + .inst 0x05f023b1 // dup z17.q, z29.q[3] + .inst 0x057023d2 // dup z18.q, z30.q[1] + .inst 0x05b023d3 // dup z19.q, z30.q[2] + .inst 0x05f023d4 // dup z20.q, z30.q[3] + + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 + + .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] + .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] + .inst 0xe400f521 // st1b {z1.b}, p5, [x9] + .inst 0xe4045522 // st1b {z2.b}, p5, [x9, x4] + .inst 0xe400f5f9 // st1b {z25.b}, p5, [x15] + .inst 0xe40455e3 // st1b {z3.b}, p5, [x15, x4] + .inst 0xe400f504 // st1b {z4.b}, p5, [x8] + .inst 0xe4045505 // st1b {z5.b}, p5, [x8, x4] + .inst 0xe400f57a // st1b {z26.b}, p5, [x11] + .inst 0xe4045566 // st1b {z6.b}, p5, [x11, x4] + .inst 0xe400f5a7 // st1b {z7.b}, p5, [x13] + .inst 0xe40455a8 // st1b {z8.b}, p5, [x13, x4] + .inst 0xe400f6fb // st1b {z27.b}, p5, [x23] + .inst 0xe40456e9 // st1b {z9.b}, p5, [x23, x4] + .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] + .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] + + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 + add x8, x23, x4, lsl #3 // +20 + add x11, x5, x4, lsl #3 // +22 + add x13, x9, x4, lsl #3 // +24 + add x23, x15, x4, lsl #3 // +26 + + .inst 0xe400f53c // st1b {z28.b}, p5, [x9] + .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] + .inst 0xe400f5ed // st1b {z13.b}, p5, [x15] + .inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4] + .inst 0xe400f51d // st1b {z29.b}, p5, [x8] + .inst 0xe404550f // st1b {z15.b}, p5, [x8, x4] + .inst 0xe400f570 // st1b {z16.b}, p5, [x11] + .inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4] + .inst 0xe400f5be // st1b {z30.b}, p5, [x13] + .inst 0xe40455b2 // st1b {z18.b}, p5, [x13, x4] + .inst 0xe400f6f3 // st1b {z19.b}, p5, [x23] + .inst 0xe40456f4 // st1b {z20.b}, p5, [x23, x4] + b TILE1_Dz_End + + /* oc=116 */ TILE1_STORE116: .inst 0x05702300 // dup z0.q, z24.q[1] .inst 0x05b02301 // dup z1.q, z24.q[2] @@ -1290,13 +1333,13 @@ TILE1_STORE108: .inst 0x05b023d3 // dup z19.q, z30.q[2] .inst 0x05f023d4 // dup z20.q, z30.q[3] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -1315,13 +1358,13 @@ TILE1_STORE108: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 - add x8, x23, x4, lsl #3 - add x11, x5, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 + add x8, x23, x4, lsl #3 // +20 + add x11, x5, x4, lsl #3 // +22 + add x13, x9, x4, lsl #3 // +24 + add x23, x15, x4, lsl #3 // +26 + add x5, x8, x4, lsl #3 // +28 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] @@ -1338,6 +1381,7 @@ TILE1_STORE108: .inst 0xe400f4bf // st1b {z31.b}, p5, [x5] b TILE1_Dz_End + /* oc=120 */ TILE1_STORE120: .inst 
0x05702300 // dup z0.q, z24.q[1] .inst 0x05b02301 // dup z1.q, z24.q[2] @@ -1362,13 +1406,13 @@ TILE1_STORE108: .inst 0x05f023d4 // dup z20.q, z30.q[3] .inst 0x057023f5 // dup z21.q, z31.q[1] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -1387,13 +1431,13 @@ TILE1_STORE108: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 - add x8, x23, x4, lsl #3 - add x11, x5, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 + add x8, x23, x4, lsl #3 // +20 + add x11, x5, x4, lsl #3 // +22 + add x13, x9, x4, lsl #3 // +24 + add x23, x15, x4, lsl #3 // +26 + add x5, x8, x4, lsl #3 // +28 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] @@ -1411,6 +1455,7 @@ TILE1_STORE108: .inst 0xe40454b5 // st1b {z21.b}, p5, [x5, x4] b TILE1_Dz_End + /* oc=124 */ TILE1_STORE124: .inst 0x05702300 // dup z0.q, z24.q[1] .inst 0x05b02301 // dup z1.q, z24.q[2] @@ -1487,6 +1532,7 @@ TILE1_STORE108: .inst 0xe400f596 // st1b {z22.b}, p5, [x12] b TILE1_Dz_End + /* oc=128 */ TILE1_STORE128: .inst 0x05702300 // dup z0.q, z24.q[1] .inst 0x05b02301 // dup z1.q, z24.q[2] diff --git a/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScaleHp128_SME2_w8_Fp32.S b/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScaleHp128_SME2_w8_Fp32.S index c4a21a6c..dc8f5b57 100644 --- a/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScaleHp128_SME2_w8_Fp32.S +++ b/source/backend/cpu/arm/arm64/sme2_asm/MNNGemmInt8AddBiasScaleHp128_SME2_w8_Fp32.S @@ -372,10 +372,10 @@ LoopSzEnd_TILE_1: .inst 0x05b02324 // dup z4.q, z25.q[2] .inst 0x05f02325 // dup z5.q, z25.q[3] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -397,10 +397,10 @@ LoopSzEnd_TILE_1: .inst 0x05f02325 // dup z5.q, z25.q[3] .inst 0x05702346 // dup z6.q, z26.q[1] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -424,11 +424,11 @@ LoopSzEnd_TILE_1: .inst 0x05702346 // dup z6.q, z26.q[1] .inst 0x05b02347 // dup z7.q, z26.q[2] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -454,11 +454,11 @@ TILE1_STORE48: .inst 0x05b02347 // dup z7.q, 
z26.q[2] .inst 0x05f02348 // dup z8.q, z26.q[3] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -485,12 +485,12 @@ TILE1_STORE52: .inst 0x05b02347 // dup z7.q, z26.q[2] .inst 0x05f02348 // dup z8.q, z26.q[3] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -519,12 +519,12 @@ TILE1_STORE56: .inst 0x05f02348 // dup z8.q, z26.q[3] .inst 0x05702369 // dup z9.q, z27.q[1] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -555,13 +555,13 @@ TILE1_STORE60: .inst 0x05702369 // dup z9.q, z27.q[1] .inst 0x05b0236a // dup z10.q, z27.q[2] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -594,13 +594,13 @@ TILE1_STORE64: .inst 0x05b0236a // dup z10.q, z27.q[2] .inst 0x05f0236b // dup z11.q, z27.q[3] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -633,15 +633,14 @@ TILE1_STORE68: .inst 0x05702369 // dup z9.q, z27.q[1] .inst 0x05b0236a // dup z10.q, z27.q[2] .inst 0x05f0236b // dup z11.q, z27.q[3] - .inst 0x0570238c // dup z12.q, z28.q[1] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -660,7 +659,7 @@ TILE1_STORE68: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, 
p5, [x5, x4] - add x9, x6, x4, lsl #4 + add x9, x6, x4, lsl #4 // +16 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] b TILE1_Dz_End @@ -678,14 +677,15 @@ TILE1_STORE72: .inst 0x05702369 // dup z9.q, z27.q[1] .inst 0x05b0236a // dup z10.q, z27.q[2] .inst 0x05f0236b // dup z11.q, z27.q[3] + .inst 0x0570238c // dup z12.q, z28.q[1] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -704,7 +704,7 @@ TILE1_STORE72: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 + add x9, x6, x4, lsl #4 // +16 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] @@ -724,14 +724,15 @@ TILE1_STORE76: .inst 0x05b0236a // dup z10.q, z27.q[2] .inst 0x05f0236b // dup z11.q, z27.q[3] .inst 0x0570238c // dup z12.q, z28.q[1] + .inst 0x05b0238d // dup z13.q, z28.q[2] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -750,8 +751,8 @@ TILE1_STORE76: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] @@ -772,14 +773,16 @@ TILE1_STORE80: .inst 0x05b0236a // dup z10.q, z27.q[2] .inst 0x05f0236b // dup z11.q, z27.q[3] .inst 0x0570238c // dup z12.q, z28.q[1] + .inst 0x05b0238d // dup z13.q, z28.q[2] + .inst 0x05f0238e // dup z14.q, z28.q[3] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -798,12 +801,13 @@ TILE1_STORE80: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] .inst 0xe400f5ed // st1b {z13.b}, p5, [x15] + .inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4] b TILE1_Dz_End TILE1_STORE84: @@ -820,14 +824,16 @@ TILE1_STORE84: .inst 0x05b0236a // dup z10.q, z27.q[2] .inst 0x05f0236b // dup z11.q, z27.q[3] .inst 0x0570238c // dup z12.q, z28.q[1] + .inst 0x05b0238d // dup z13.q, z28.q[2] + .inst 0x05f0238e // dup z14.q, z28.q[3] - 
add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -846,12 +852,15 @@ TILE1_STORE84: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 + add x8, x23, x4, lsl #3 // +20 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] .inst 0xe400f5ed // st1b {z13.b}, p5, [x15] + .inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4] + .inst 0xe400f51d // st1b {z29.b}, p5, [x8] b TILE1_Dz_End TILE1_STORE88: @@ -869,14 +878,16 @@ TILE1_STORE88: .inst 0x05f0236b // dup z11.q, z27.q[3] .inst 0x0570238c // dup z12.q, z28.q[1] .inst 0x05b0238d // dup z13.q, z28.q[2] + .inst 0x05f0238e // dup z14.q, z28.q[3] + .inst 0x057023af // dup z15.q, z29.q[1] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -895,12 +906,16 @@ TILE1_STORE88: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 + add x8, x23, x4, lsl #3 // +20 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] .inst 0xe400f5ed // st1b {z13.b}, p5, [x15] + .inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4] + .inst 0xe400f51d // st1b {z29.b}, p5, [x8] + .inst 0xe404550f // st1b {z15.b}, p5, [x8, x4] b TILE1_Dz_End TILE1_STORE92: @@ -919,14 +934,17 @@ TILE1_STORE92: .inst 0x0570238c // dup z12.q, z28.q[1] .inst 0x05b0238d // dup z13.q, z28.q[2] .inst 0x05f0238e // dup z14.q, z28.q[3] + .inst 0x057023af // dup z15.q, z29.q[1] + .inst 0x05b023b0 // dup z16.q, z29.q[2] + .inst 0x05f023b1 // dup z17.q, z29.q[3] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -945,15 +963,18 @@ TILE1_STORE92: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 - add x8, x23, x4, lsl #3 + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 + add x8, x23, x4, lsl #3 // +20 + add x11, x5, x4, lsl #3 // +22 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] 
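// Note on the "+N" annotations added throughout these store paths: x4 appears to hold
// the per-row destination stride, and each annotation records the row offset from the
// base pointer x6 in multiples of x4. Since "lsl #k" adds x4 << k, for example:
//   add x9,  x6,  x4, lsl #1   ->  x9  = x6 + 2*x4        (row +2)
//   add x8,  x9,  x4, lsl #2   ->  x8  = x6 + (2+4)*x4    (row +6)
//   add x15, x13, x4, lsl #3   ->  x15 = x6 + (10+8)*x4   (row +18)
// Each address register then serves two rows via the [reg] / [reg, x4] st1b addressing.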
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15] .inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4] .inst 0xe400f51d // st1b {z29.b}, p5, [x8] + .inst 0xe404550f // st1b {z15.b}, p5, [x8, x4] + .inst 0xe400f570 // st1b {z16.b}, p5, [x11] b TILE1_Dz_End TILE1_STORE96: @@ -973,14 +994,16 @@ TILE1_STORE96: .inst 0x05b0238d // dup z13.q, z28.q[2] .inst 0x05f0238e // dup z14.q, z28.q[3] .inst 0x057023af // dup z15.q, z29.q[1] + .inst 0x05b023b0 // dup z16.q, z29.q[2] + .inst 0x05f023b1 // dup z17.q, z29.q[3] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -999,9 +1022,10 @@ TILE1_STORE96: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 - add x8, x23, x4, lsl #3 + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 + add x8, x23, x4, lsl #3 // +20 + add x11, x5, x4, lsl #3 // +22 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] @@ -1009,6 +1033,8 @@ TILE1_STORE96: .inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4] .inst 0xe400f51d // st1b {z29.b}, p5, [x8] .inst 0xe404550f // st1b {z15.b}, p5, [x8, x4] + .inst 0xe400f570 // st1b {z16.b}, p5, [x11] + .inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4] b TILE1_Dz_End TILE1_STORE100: @@ -1029,14 +1055,15 @@ TILE1_STORE100: .inst 0x05f0238e // dup z14.q, z28.q[3] .inst 0x057023af // dup z15.q, z29.q[1] .inst 0x05b023b0 // dup z16.q, z29.q[2] + .inst 0x05f023b1 // dup z17.q, z29.q[3] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -1055,10 +1082,11 @@ TILE1_STORE100: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 - add x8, x23, x4, lsl #3 - add x11, x5, x4, lsl #3 + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 + add x8, x23, x4, lsl #3 // +20 + add x11, x5, x4, lsl #3 // +22 + add x13, x9, x4, lsl #3 // +24 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] @@ -1067,6 +1095,8 @@ TILE1_STORE100: .inst 0xe400f51d // st1b {z29.b}, p5, [x8] .inst 0xe404550f // st1b {z15.b}, p5, [x8, x4] .inst 0xe400f570 // st1b {z16.b}, p5, [x11] + .inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4] + .inst 0xe400f5be // st1b {z30.b}, p5, [x13] b TILE1_Dz_End TILE1_STORE104: @@ -1088,75 +1118,15 @@ TILE1_STORE104: .inst 0x057023af // dup z15.q, z29.q[1] .inst 0x05b023b0 // dup z16.q, z29.q[2] .inst 0x05f023b1 // dup z17.q, z29.q[3] - - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 - - 
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] - .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] - .inst 0xe400f521 // st1b {z1.b}, p5, [x9] - .inst 0xe4045522 // st1b {z2.b}, p5, [x9, x4] - .inst 0xe400f5f9 // st1b {z25.b}, p5, [x15] - .inst 0xe40455e3 // st1b {z3.b}, p5, [x15, x4] - .inst 0xe400f504 // st1b {z4.b}, p5, [x8] - .inst 0xe4045505 // st1b {z5.b}, p5, [x8, x4] - .inst 0xe400f57a // st1b {z26.b}, p5, [x11] - .inst 0xe4045566 // st1b {z6.b}, p5, [x11, x4] - .inst 0xe400f5a7 // st1b {z7.b}, p5, [x13] - .inst 0xe40455a8 // st1b {z8.b}, p5, [x13, x4] - .inst 0xe400f6fb // st1b {z27.b}, p5, [x23] - .inst 0xe40456e9 // st1b {z9.b}, p5, [x23, x4] - .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] - .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 - add x8, x23, x4, lsl #3 - add x11, x5, x4, lsl #3 - - .inst 0xe400f53c // st1b {z28.b}, p5, [x9] - .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] - .inst 0xe400f5ed // st1b {z13.b}, p5, [x15] - .inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4] - .inst 0xe400f51d // st1b {z29.b}, p5, [x8] - .inst 0xe404550f // st1b {z15.b}, p5, [x8, x4] - .inst 0xe400f570 // st1b {z16.b}, p5, [x11] - .inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4] - b TILE1_Dz_End - -TILE1_STORE108: - .inst 0x05702300 // dup z0.q, z24.q[1] - .inst 0x05b02301 // dup z1.q, z24.q[2] - .inst 0x05f02302 // dup z2.q, z24.q[3] - .inst 0x05702323 // dup z3.q, z25.q[1] - .inst 0x05b02324 // dup z4.q, z25.q[2] - .inst 0x05f02325 // dup z5.q, z25.q[3] - .inst 0x05702346 // dup z6.q, z26.q[1] - .inst 0x05b02347 // dup z7.q, z26.q[2] - .inst 0x05f02348 // dup z8.q, z26.q[3] - .inst 0x05702369 // dup z9.q, z27.q[1] - .inst 0x05b0236a // dup z10.q, z27.q[2] - .inst 0x05f0236b // dup z11.q, z27.q[3] - .inst 0x0570238c // dup z12.q, z28.q[1] - .inst 0x05b0238d // dup z13.q, z28.q[2] - .inst 0x05f0238e // dup z14.q, z28.q[3] - .inst 0x057023af // dup z15.q, z29.q[1] - .inst 0x05b023b0 // dup z16.q, z29.q[2] - .inst 0x05f023b1 // dup z17.q, z29.q[3] .inst 0x057023d2 // dup z18.q, z30.q[1] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -1175,11 +1145,11 @@ TILE1_STORE108: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 - add x8, x23, x4, lsl #3 - add x11, x5, x4, lsl #3 - add x13, x9, x4, lsl #3 + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 + add x8, x23, x4, lsl #3 // +20 + add x11, x5, x4, lsl #3 // +22 + add x13, x9, x4, lsl #3 // +24 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] @@ -1193,7 +1163,8 @@ TILE1_STORE108: .inst 0xe40455b2 // st1b {z18.b}, p5, [x13, x4] b TILE1_Dz_End - TILE1_STORE112: + /* oc=108 */ + TILE1_STORE108: .inst 0x05702300 // dup z0.q, z24.q[1] .inst 0x05b02301 // dup z1.q, z24.q[2] .inst 0x05f02302 // dup z2.q, z24.q[3] @@ -1215,13 +1186,13 @@ TILE1_STORE108: .inst 0x057023d2 // dup z18.q, z30.q[1] .inst 0x05b023d3 // dup z19.q, z30.q[2] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl 
#3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -1240,12 +1211,12 @@ TILE1_STORE108: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 - add x8, x23, x4, lsl #3 - add x11, x5, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 + add x8, x23, x4, lsl #3 // +20 + add x11, x5, x4, lsl #3 // +22 + add x13, x9, x4, lsl #3 // +24 + add x23, x15, x4, lsl #3 // +26 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] @@ -1260,6 +1231,77 @@ TILE1_STORE108: .inst 0xe400f6f3 // st1b {z19.b}, p5, [x23] b TILE1_Dz_End + /* oc=112 */ + TILE1_STORE112: + .inst 0x05702300 // dup z0.q, z24.q[1] + .inst 0x05b02301 // dup z1.q, z24.q[2] + .inst 0x05f02302 // dup z2.q, z24.q[3] + .inst 0x05702323 // dup z3.q, z25.q[1] + .inst 0x05b02324 // dup z4.q, z25.q[2] + .inst 0x05f02325 // dup z5.q, z25.q[3] + .inst 0x05702346 // dup z6.q, z26.q[1] + .inst 0x05b02347 // dup z7.q, z26.q[2] + .inst 0x05f02348 // dup z8.q, z26.q[3] + .inst 0x05702369 // dup z9.q, z27.q[1] + .inst 0x05b0236a // dup z10.q, z27.q[2] + .inst 0x05f0236b // dup z11.q, z27.q[3] + .inst 0x0570238c // dup z12.q, z28.q[1] + .inst 0x05b0238d // dup z13.q, z28.q[2] + .inst 0x05f0238e // dup z14.q, z28.q[3] + .inst 0x057023af // dup z15.q, z29.q[1] + .inst 0x05b023b0 // dup z16.q, z29.q[2] + .inst 0x05f023b1 // dup z17.q, z29.q[3] + .inst 0x057023d2 // dup z18.q, z30.q[1] + .inst 0x05b023d3 // dup z19.q, z30.q[2] + .inst 0x05f023d4 // dup z20.q, z30.q[3] + + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 + + .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] + .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] + .inst 0xe400f521 // st1b {z1.b}, p5, [x9] + .inst 0xe4045522 // st1b {z2.b}, p5, [x9, x4] + .inst 0xe400f5f9 // st1b {z25.b}, p5, [x15] + .inst 0xe40455e3 // st1b {z3.b}, p5, [x15, x4] + .inst 0xe400f504 // st1b {z4.b}, p5, [x8] + .inst 0xe4045505 // st1b {z5.b}, p5, [x8, x4] + .inst 0xe400f57a // st1b {z26.b}, p5, [x11] + .inst 0xe4045566 // st1b {z6.b}, p5, [x11, x4] + .inst 0xe400f5a7 // st1b {z7.b}, p5, [x13] + .inst 0xe40455a8 // st1b {z8.b}, p5, [x13, x4] + .inst 0xe400f6fb // st1b {z27.b}, p5, [x23] + .inst 0xe40456e9 // st1b {z9.b}, p5, [x23, x4] + .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] + .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] + + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 + add x8, x23, x4, lsl #3 // +20 + add x11, x5, x4, lsl #3 // +22 + add x13, x9, x4, lsl #3 // +24 + add x23, x15, x4, lsl #3 // +26 + + .inst 0xe400f53c // st1b {z28.b}, p5, [x9] + .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] + .inst 0xe400f5ed // st1b {z13.b}, p5, [x15] + .inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4] + .inst 0xe400f51d // st1b {z29.b}, p5, [x8] + .inst 0xe404550f // st1b {z15.b}, p5, [x8, x4] + .inst 0xe400f570 // st1b {z16.b}, p5, [x11] + .inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4] + .inst 0xe400f5be 
// st1b {z30.b}, p5, [x13] + .inst 0xe40455b2 // st1b {z18.b}, p5, [x13, x4] + .inst 0xe400f6f3 // st1b {z19.b}, p5, [x23] + .inst 0xe40456f4 // st1b {z20.b}, p5, [x23, x4] + b TILE1_Dz_End + + /* oc=116 */ TILE1_STORE116: .inst 0x05702300 // dup z0.q, z24.q[1] .inst 0x05b02301 // dup z1.q, z24.q[2] @@ -1283,13 +1325,13 @@ TILE1_STORE108: .inst 0x05b023d3 // dup z19.q, z30.q[2] .inst 0x05f023d4 // dup z20.q, z30.q[3] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -1308,13 +1350,13 @@ TILE1_STORE108: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 - add x8, x23, x4, lsl #3 - add x11, x5, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 + add x8, x23, x4, lsl #3 // +20 + add x11, x5, x4, lsl #3 // +22 + add x13, x9, x4, lsl #3 // +24 + add x23, x15, x4, lsl #3 // +26 + add x5, x8, x4, lsl #3 // +28 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] @@ -1331,6 +1373,7 @@ TILE1_STORE108: .inst 0xe400f4bf // st1b {z31.b}, p5, [x5] b TILE1_Dz_End + /* oc=120 */ TILE1_STORE120: .inst 0x05702300 // dup z0.q, z24.q[1] .inst 0x05b02301 // dup z1.q, z24.q[2] @@ -1355,13 +1398,13 @@ TILE1_STORE108: .inst 0x05f023d4 // dup z20.q, z30.q[3] .inst 0x057023f5 // dup z21.q, z31.q[1] - add x9, x6, x4, lsl #1 - add x15, x6, x4, lsl #2 - add x8, x9, x4, lsl #2 - add x11, x6, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #1 // +2 + add x15, x6, x4, lsl #2 // +4 + add x8, x9, x4, lsl #2 // +6 + add x11, x6, x4, lsl #3 // +8 + add x13, x9, x4, lsl #3 // +10 + add x23, x15, x4, lsl #3 // +12 + add x5, x8, x4, lsl #3 // +14 .inst 0xe400f4d8 // st1b {z24.b}, p5, [x6] .inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4] @@ -1380,13 +1423,13 @@ TILE1_STORE108: .inst 0xe400f4aa // st1b {z10.b}, p5, [x5] .inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4] - add x9, x6, x4, lsl #4 - add x15, x13, x4, lsl #3 - add x8, x23, x4, lsl #3 - add x11, x5, x4, lsl #3 - add x13, x9, x4, lsl #3 - add x23, x15, x4, lsl #3 - add x5, x8, x4, lsl #3 + add x9, x6, x4, lsl #4 // +16 + add x15, x13, x4, lsl #3 // +18 + add x8, x23, x4, lsl #3 // +20 + add x11, x5, x4, lsl #3 // +22 + add x13, x9, x4, lsl #3 // +24 + add x23, x15, x4, lsl #3 // +26 + add x5, x8, x4, lsl #3 // +28 .inst 0xe400f53c // st1b {z28.b}, p5, [x9] .inst 0xe404552c // st1b {z12.b}, p5, [x9, x4] @@ -1404,6 +1447,7 @@ TILE1_STORE108: .inst 0xe40454b5 // st1b {z21.b}, p5, [x5, x4] b TILE1_Dz_End + /* oc=124 */ TILE1_STORE124: .inst 0x05702300 // dup z0.q, z24.q[1] .inst 0x05b02301 // dup z1.q, z24.q[2] @@ -1480,6 +1524,7 @@ TILE1_STORE108: .inst 0xe400f596 // st1b {z22.b}, p5, [x12] b TILE1_Dz_End + /* oc=128 */ TILE1_STORE128: .inst 0x05702300 // dup z0.q, z24.q[1] .inst 0x05b02301 // dup z1.q, z24.q[2] diff --git a/source/backend/cpu/compute/CommonOptFunction.cpp b/source/backend/cpu/compute/CommonOptFunction.cpp index d5b29342..f61d04fe 
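The MNNSoftmax hunk below replaces the single-pass softmax with a block-wise, flash-attention style variant driven by runningMax/runningSum/updateScale, so attention can walk the KV cache in sub-blocks and defer normalization to the last block. The scalar sketch that follows shows only that recurrence; it uses plain expf instead of the kernel's polynomial exp approximation, assumes the caller initializes runningMax to a very negative value and runningSum to zero before the first block, and the helper name is illustrative rather than part of the engine's API.

    #include <algorithm>
    #include <cmath>

    // One KV block of streaming softmax for a single row of scores.
    // dst receives exp(src[i] - newMax) (unnormalized); the running state is updated
    // so previously accumulated attention output can be rescaled by updateScale.
    static void streamingSoftmaxBlock(float* dst, const float* src, int reduceSize,
                                      float& runningMax, float& runningSum,
                                      float& updateScale) {
        float blockMax = src[0];
        for (int i = 1; i < reduceSize; ++i) blockMax = std::max(blockMax, src[i]);
        float newMax = std::max(runningMax, blockMax);

        float blockSum = 0.f;
        for (int i = 0; i < reduceSize; ++i) {
            dst[i] = std::exp(src[i] - newMax);  // unnormalized probabilities
            blockSum += dst[i];
        }

        updateScale = std::exp(runningMax - newMax);        // rescales earlier output
        runningSum  = runningSum * updateScale + blockSum;  // online sum of exponentials
        runningMax  = newMax;
        // Division by runningSum is deferred to the final block, mirroring how
        // MNNFlashAttentionUpdateBlockOutput applies 1/normalizeScale at idx == kvBlocks-1.
    }

Deferring the division means each block only has to rescale the accumulated attention output once by updateScale, instead of renormalizing every previously computed probability.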
100644 --- a/source/backend/cpu/compute/CommonOptFunction.cpp +++ b/source/backend/cpu/compute/CommonOptFunction.cpp @@ -1390,6 +1390,89 @@ void MNNAccumulateSequenceNumber (float* dst, const float* src, int size) { *dst = sum; } +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE + +static void MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes) { + // source shape: [headDim/pack, seqLen, pack] + // scale & normalizeScale shape: [seqLen] + // dest shape: [headDim/pack, seqLen, pack] + auto stride0 = plane * pack; + + if (idx > 0) { + for (int j = 0; j < depthQuad; ++j) { + for (int i = 0; i < plane; ++i) { + auto dataNew = Vec::load(src + j * stride0 + i * pack); + auto dataOld = Vec::load(dst + j * stride0 + i * pack); + auto s = Vec(scale[i]); + dataNew = Vec::fma(dataNew, dataOld, s); + Vec::save(dst + j * stride0 + i * pack, dataNew); + } + } + } else { + memcpy(dst, src, size * bytes); + } + if (idx == kvBlocks - 1) { // if last subBlock, exp(xi)/sum(exp(xi)) + for (int j = 0; j < depthQuad; ++j) { + for (int i = 0; i < plane; ++i) { + auto dataNew = Vec::load(dst + j * stride0 + i * pack); + auto ns = Vec(1.0f / normalizeScale[i]); + dataNew = dataNew * ns; + Vec::save(dst + j * stride0 + i * pack, dataNew); + } + } + } +} + +static void MNNAttenPackAndScaleSingleHead(float* dst, const float* srcHeadBase, size_t srcRowStride, const float* scale, const int32_t* units, size_t seqLen, size_t headDim) { + const int32_t eP = units[0]; + const int32_t lP = units[1]; + + if (lP != 1) { + MNN_ERROR("This function only supports lP=1 or 2\n"); + return; + } + + const float scaleVal = scale[0]; +#ifdef MNN_USE_NEON + const float32x4_t vScale = vdupq_n_f32(scaleVal); +#endif + const size_t packedHeadDim = UP_DIV(headDim, lP); + const size_t dstStrideDOuter = (size_t)eP * lP; + const size_t dstStrideSOuter = packedHeadDim * dstStrideDOuter; + + for (int s = 0; s < seqLen; ++s) { + const int sOuter = s / eP; + const int sInner = s % eP; + const float* srcRowPtr = srcHeadBase + s * srcRowStride; + float* dstBasePtr = dst + sOuter * dstStrideSOuter + sInner * lP; + + size_t d = 0; +#ifdef MNN_USE_NEON + for (; d + 7 < headDim; d += 8) { + float32x4_t sVec0 = vld1q_f32(srcRowPtr + d); + float32x4_t sVec1 = vld1q_f32(srcRowPtr + d + 4); + sVec0 = vmulq_f32(sVec0, vScale); + sVec1 = vmulq_f32(sVec1, vScale); + + dstBasePtr[(d + 0) * dstStrideDOuter] = sVec0[0]; + dstBasePtr[(d + 1) * dstStrideDOuter] = sVec0[1]; + dstBasePtr[(d + 2) * dstStrideDOuter] = sVec0[2]; + dstBasePtr[(d + 3) * dstStrideDOuter] = sVec0[3]; + dstBasePtr[(d + 4) * dstStrideDOuter] = sVec1[0]; + dstBasePtr[(d + 5) * dstStrideDOuter] = sVec1[1]; + dstBasePtr[(d + 6) * dstStrideDOuter] = sVec1[2]; + dstBasePtr[(d + 7) * dstStrideDOuter] = sVec1[3]; + } +#else + for (; d < headDim; ++d) { + dstBasePtr[d * dstStrideDOuter] = srcRowPtr[d] * scaleVal; + } +#endif + } +} +#endif // MNN_SUPPORT_TRANSFORMER_FUSE + + #ifndef MNN_USE_NEON void MNNGetMatMulPackMode(int* eP, int *lP, int* hP) { *eP = 16; @@ -2619,30 +2702,54 @@ void MNNExpC8(float* dest, const float* source, float* offset, const float* para offset[3] = summer; } -void MNNSoftmax(float* dest, const float* source, size_t size) { - float maxValue = ALIMAX(source[0], source[1]); - for (int i = 2; i < size; ++i) { - maxValue = ALIMAX(maxValue, source[i]); - } - float xLimit = 87, param = 0.6931471805599453, sumValue = 0.f; - for (int i = 0; i < size; ++i) { - auto x = 
source[i] - maxValue; - x = x > -xLimit ? x : -xLimit; - x = x < xLimit ? x : xLimit; +void MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) { + for (int k = 0; k < outside; ++k) { + auto source = input + k * reduceSize; + auto dest = softmaxDst + k * reduceSize; - int div = (x / param); - int div2 = (div + 127) << 23; - auto xReamin = x - div * param; - float expBasic = *(float*)(&div2); + float oldMax = source[0]; + if (runningMax) { + oldMax = runningMax[k]; + } - auto t = xReamin; - auto expRemain = ((((1.0f / 120 * t + 1.0f / 24) * t + 1.0f / 6) * t + 0.5f) * t + 1.0f) * t + 1.0f; - dest[i] = expBasic * expRemain; - sumValue += dest[i]; - } - sumValue = 1.f / sumValue; - for (int i = 0; i < size; ++i) { - dest[i] *= sumValue; + // find max value of current block + float blockMax =source[0]; + for (int i = 1; i < reduceSize; ++i) { + blockMax = ALIMAX(blockMax, source[i]); + } + float newMax = ALIMAX(oldMax, blockMax); + + // caculate block's expr(xi-newmax) and update runningMax + float xLimit = 87, param = 0.6931471805599453; + float blockSum = 0.f; + for (int i = 0; i < reduceSize; ++i) { + auto x = source[i] - newMax; + x = x > -xLimit ? x : -xLimit; + x = x < xLimit ? x : xLimit; + + int div = (x / param); + int div2 = (div + 127) << 23; + auto xReamin = x - div * param; + float expBasic = *(float*)(&div2); + + auto t = xReamin; + auto expRemain = ((((1.0f / 120 * t + 1.0f / 24) * t + 1.0f / 6) * t + 0.5f) * t + 1.0f) * t + 1.0f; + dest[i] = expBasic * expRemain; + blockSum += dest[i]; + } + + if (runningMax != nullptr && runningSum != nullptr && updateScale != nullptr) { + // update runningSum, runningMax, scale=expf(oldMax-newMax) + runningSum[k] = runningSum[k] * expf(oldMax - newMax) + blockSum; + runningMax[k] = newMax; + updateScale[k] = expf(oldMax - newMax); + } else { + // Normalize + auto scale = 1.f / blockSum; + for (int i = 0; i < reduceSize; ++i) { + dest[i] *= scale; + } + } } } @@ -2657,6 +2764,68 @@ void MNNReluInt8(int8_t* dst, const int8_t* src, size_t size, ssize_t zeroPoint) } #endif // no MNN_USE_SSE +void MNNExp(float* dst, const float* src, float* offset, size_t dataSize) { + int countC8 = static_cast(dataSize) / 8; + int remain = static_cast(dataSize) % 8; + static const float parameters[] = { + (float)logf(2.0f), 1.0f / (float)logf(2.0f), 0.25f, 1.0f, 0.5f, 1.0f / 6.0f, 1.0f / 24.0f, 1.0f / 120.0f}; + if (countC8 > 0) { + // Align to eight so asm is easier to write + MNNExpC8(dst, src, offset, parameters, countC8); + } + if (remain > 0) { + auto param = parameters[0]; + float xLimit = 87; + float summer = offset[3]; + auto source = src + countC8 * 8; + auto dest = dst + countC8 * 8; + for (int i = 0; i < remain; ++i) { + auto x = source[i] * offset[0] + offset[2]; + x = ALIMAX(x, -xLimit); + x = ALIMIN(x, xLimit); + int div = (x * parameters[1]); + int div2 = (div + 127) << 23; + auto xReamin = x - div * param; + float expBasic = *(float*)(&div2); + auto t = xReamin * 0.25f; + auto expRemain = + ((((parameters[7] * t + parameters[6]) * t + parameters[5]) * t + parameters[4]) * t + 1.0f) * t + + 1.0f; + expRemain = expRemain * expRemain; + expRemain = expRemain * expRemain; + dest[i] = expBasic * expRemain + offset[1]; + summer+= dest[i]; + } + offset[3] = summer; + } +} + +void packKvCache(float* dst, const float* src, size_t seqLen, size_t kvSeqLen, size_t eP) { + if (seqLen == 0 || kvSeqLen == 0) { + return; + } + + const size_t dstSOuterStride = kvSeqLen * eP; + + for 
(size_t sBase = 0; sBase < seqLen; sBase += eP) { + const size_t numRowsInBlock = std::min(eP, seqLen - sBase); + const size_t sOuter = sBase / eP; + + float* dstSOuterBase = dst + sOuter * dstSOuterStride; + + for (size_t k = 0; k < kvSeqLen; ++k) { + float* dstColStart = dstSOuterBase + k * eP; + + for (size_t sInner = 0; sInner < numRowsInBlock; ++sInner) { + const size_t s = sBase + sInner; + + const float value = src[s * kvSeqLen + k]; + dstColStart[sInner] = value; + } + } + } +} + void MNNMaxFloat(float* input, float* maxBuffer, int32_t inputCountUnit) { for (int i = 0; i < inputCountUnit; i++) { for (int j = 0; j < UNIT; j++) { @@ -3173,42 +3342,6 @@ void MNNPackTranspose(float* dst, const float* src, size_t area, size_t depth, i } } -void MNNExp(float* dst, const float* src, float* offset, size_t dataSize) { - int countC8 = static_cast(dataSize) / 8; - int remain = static_cast(dataSize) % 8; - static const float parameters[] = { - (float)logf(2.0f), 1.0f / (float)logf(2.0f), 0.25f, 1.0f, 0.5f, 1.0f / 6.0f, 1.0f / 24.0f, 1.0f / 120.0f}; - if (countC8 > 0) { - // Align to eight so asm is easier to write - MNNExpC8(dst, src, offset, parameters, countC8); - } - if (remain > 0) { - auto param = parameters[0]; - float xLimit = 87; - float summer = offset[3]; - auto source = src + countC8 * 8; - auto dest = dst + countC8 * 8; - for (int i = 0; i < remain; ++i) { - auto x = source[i] * offset[0] + offset[2]; - x = ALIMAX(x, -xLimit); - x = ALIMIN(x, xLimit); - int div = (x * parameters[1]); - int div2 = (div + 127) << 23; - auto xReamin = x - div * param; - float expBasic = *(float*)(&div2); - auto t = xReamin * 0.25f; - auto expRemain = - ((((parameters[7] * t + parameters[6]) * t + parameters[5]) * t + parameters[4]) * t + 1.0f) * t + - 1.0f; - expRemain = expRemain * expRemain; - expRemain = expRemain * expRemain; - dest[i] = expBasic * expRemain + offset[1]; - summer+= dest[i]; - } - offset[3] = summer; - } -} - // Lambert's series with 7 divisions // reference from // https://varietyofsound.wordpress.com/2011/02/14/efficient-tanh-computation-using-lamberts-continued-fraction/ @@ -4011,6 +4144,7 @@ void MNNCoreFunctionInit() { gCoreFunction->chooseWinoDestUnrollTransform = WinogradFunction::chooseWinoDestUnrollTransform; gCoreFunction->MNNDeconvRunForLineDepthwise = MNNDeconvRunForLineDepthwise; gCoreFunction->MNNDeconvRunForUnitDepthWise = MNNDeconvRunForUnitDepthWise; + gCoreFunction->MNNSoftmax = MNNSoftmax; #ifdef MNN_USE_NEON gCoreFunction->MNNDepthwiseConvFastKernel = MNNDepthwiseConvFastKernel; #endif @@ -4019,6 +4153,12 @@ void MNNCoreFunctionInit() { #ifdef MNN_SUPPORT_QUANT_EXTEND gCoreFunction->MNNSelectUnaryFunctionForInt8 = CPUUnary::selectForInt8; #endif + +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE + gCoreFunction->MNNAttenPackAndScaleSingleHead = MNNAttenPackAndScaleSingleHead; + gCoreFunction->MNNFlashAttentionUpdateBlockOutput = MNNFlashAttentionUpdateBlockOutput; +#endif + gCoreFunction->MNNReluWithSlopeChannel = MNNReluWithSlopeChannel; gCoreFunction->MNNPoolingAvg = (decltype(gCoreFunction->MNNPoolingAvg))(poolingAvg); // Set min value as 1 << 24 diff --git a/source/backend/cpu/compute/CommonOptFunction.h b/source/backend/cpu/compute/CommonOptFunction.h index eaf676b2..5a7a92e6 100644 --- a/source/backend/cpu/compute/CommonOptFunction.h +++ b/source/backend/cpu/compute/CommonOptFunction.h @@ -32,12 +32,16 @@ void MNNAbsMaxFP32_Pack4(const float* source, float* absmax, size_t src_depth_qu void MNNAbsMaxFP32_Pack8(const float* source, float* absmax, size_t src_depth_quad, 
size_t realSize, int pack); void MNNQuantScaleFP32(float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch); void MNNDynamicUpdateConvBiasScale(float* newbias, float* oldbias, float* weightKernelSum, float* inputZero, size_t ocQuad); -#endif +#endif // MNN_LOW_MEMORY + #ifdef MNN_SME2 void MNNPackedMatMulRemainFP32_SME2(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); +#endif + +#endif // __aarch64__ + + -#endif -#endif void MNNFp32ToFp8(uint8_t* dst, const float* src, size_t size); void MNNFp8ToFp32(float* dst, const uint8_t* src, size_t size); void MNNFp16ToFp8(uint8_t* dst, const uint16_t* src, size_t size); @@ -117,7 +121,7 @@ void MNNReluWithSlopeCommon(float* dst, const float* src, size_t size, float slo void MNNHardSwishCommon(float* dst, const float* src, size_t size); void MNNGeluCommon(float* dst, const float* src, size_t size); void MNNGeluStandardCommon(float* dst, const float* src, size_t size); -void MNNSoftmax(float* dest, const float* source, size_t size); +void MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize); void MNNNorm(float* dest, const float* source, const float *gamma, const float *beta, float epsilon, size_t size, bool RMSNorm = false); // Get Pack for MatMul's e , l , h , the pack number must be 1 or 4 * n @@ -187,6 +191,8 @@ void MNNComputeMatMulForE_1(const float* A, const float* B, float* C, const floa void MNNCopyC4Int16WithStride(const float* sourceF, float* destF, size_t srcStride, size_t dstStride, size_t count); void MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count); +void packKvCache(float* dst, const float* src, size_t seqLen, size_t kvSeqLen, size_t eP); + struct SumByAxisParams { ssize_t kernelCountUnitDouble; ssize_t unitColBufferSize; @@ -401,6 +407,13 @@ struct CoreFunctions { void(*MNNSumWeightInt8)(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP); void(*MNNSumWeightInt8SmeHp64)(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP); + // Attention + void(*MNNAttenUnpackAndConvertFp16)(float* dst, float* src, size_t depth, size_t planesize, int pack); + void(*MNNAttenPackAndConvertFp32)(float* dst, float* src, const int32_t* units, size_t depth, size_t planesize); + void(*MNNAttenPackAndScaleSingleHead)(float* dst, const float* srcHeadBase, size_t srcRowStride, const float* scale, const int32_t* units, size_t seqLen, size_t headDim); + void(*MNNFlashAttentionUpdateBlockOutput)(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes); + void(*MNNSoftmax)(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize); + MatmulRelatedFunctions int8MatmulRelatedFunctions; MatmulRelatedFunctions sme2Int8MatmulRelatedFuncionsHp32; }; diff --git a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp index 0cc7624a..22a18fcf 100644 --- a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp +++ b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp @@ -805,17 +805,16 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input return OUT_OF_MEMORY; } if (mOnlineReorderWeightSme && planeSize > 1) { // only prefill need - int dstHp = 32; // 
if mOnlineReorderWeightSme==true, UNIT is 64 when model loaded. - int weightlenNew = ROUND_UP(outC, dstHp) * mBlockNum * ROUND_UP(ic / mBlockNum, SRC_UNIT) * kernelCount * SRC_UNIT * dstHp; + int weightlenNew = ROUND_UP(outC, SME_DECODE_MAXHP) * mBlockNum * ROUND_UP(ic / mBlockNum, SRC_UNIT) * kernelCount; if (mResourceInt8->mActBits == 4) { weightlenNew /= 2; } - mWeight4Prefill = bufferAlloc->alloc(weightlenNew + 2 * ROUND_UP(outC, dstHp) * QUANT_INFO_BYTES); + mWeight4Prefill = bufferAlloc->alloc(weightlenNew + 2 * mBlockNum * ROUND_UP(outC, SME_DECODE_MAXHP) * QUANT_INFO_BYTES); if (mWeight4Prefill.invalid()) { return OUT_OF_MEMORY; } if (mInputBlockNum > 1) { // only in this case, need to use weight_kernel_sum - mWeightKernelSum4Prefill = bufferAlloc->alloc(ROUND_UP(outC, 32) * mBlockNum * sizeof(float)); + mWeightKernelSum4Prefill = bufferAlloc->alloc(ROUND_UP(outC, SME_DECODE_MAXHP) * mBlockNum * sizeof(float)); if (mWeightKernelSum4Prefill.invalid()) { return OUT_OF_MEMORY; } diff --git a/source/backend/cpu/x86_x64/AVX2Functions.cpp b/source/backend/cpu/x86_x64/AVX2Functions.cpp index 6b5db53e..ce6ea6cb 100644 --- a/source/backend/cpu/x86_x64/AVX2Functions.cpp +++ b/source/backend/cpu/x86_x64/AVX2Functions.cpp @@ -62,7 +62,8 @@ bool AVX2Functions::init(int cpuFlags) { coreFunction->MNNComputeMatMulForH_1 = _AVX_MNNComputeMatMulForH_1; // Dynamic Quant coreFunction->MNNCountMaxMinValue = _AVX_MNNCountMinMaxValue; - + + coreFunction->MNNSoftmax = _AVX_MNNSoftmax; // For Packed Functions coreFunction->pack = 8; diff --git a/source/backend/cpu/x86_x64/FunctionDispatcher.cpp b/source/backend/cpu/x86_x64/FunctionDispatcher.cpp index a77f3ac0..0bdc3e42 100644 --- a/source/backend/cpu/x86_x64/FunctionDispatcher.cpp +++ b/source/backend/cpu/x86_x64/FunctionDispatcher.cpp @@ -24,7 +24,7 @@ struct FunctionGroup { int lP = 1; int hP = 4; void (*MNNExpC8)(float* dest, const float* source, float* offset, const float* parameters, size_t countC8) = _SSE_MNNExpC8; - void (*MNNSoftmax)(float* dest, const float* source, size_t size) = _SSE_MNNSoftmax; + void (*MNNSoftmax)(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) = _SSE_MNNSoftmax; void (*MNNReluInt8)(int8_t* dst, const int8_t* src, size_t size, ssize_t zeroPoint) = _SSE_MNNReluInt8; void (*MNNHardSwish)(float* dst, const float* src, size_t size) = _SSE_MNNHardSwish; void (*MNNGelu)(float* dst, const float* src, size_t size, float* parameters) = _SSE_MNNGelu; @@ -65,6 +65,8 @@ void MNNFunctionInit() { coreFunction->MNNPackForMatMul_B = _SSE_MNNPackForMatMul_B; // Dynamic Quant coreFunction->MNNCountMaxMinValue = _SSE_MNNCountMinMaxValue; + + coreFunction->MNNSoftmax = _SSE_MNNSoftmax; } #ifdef MNN_USE_AVX if (cpuFlags & libyuv::kCpuHasAVX2) { @@ -205,8 +207,8 @@ void MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count) { _SSE_MNNInt8ToInt16(dest, source, count); } -void MNNSoftmax(float* dest, const float* source, size_t size) { - gFunc.MNNSoftmax(dest, source, size); +void MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) { + gFunc.MNNSoftmax(softmaxDst, input, runningMax, runningSum, updateScale, outside, reduceSize); } void MNNNorm(float* dest, const float* source, const float *gamma, const float *beta, float epsilon, size_t size, bool RMSNorm) { diff --git a/source/backend/cpu/x86_x64/avx/FunctionSummary.hpp b/source/backend/cpu/x86_x64/avx/FunctionSummary.hpp index 
87bb7a07..a2d35834 100644 --- a/source/backend/cpu/x86_x64/avx/FunctionSummary.hpp +++ b/source/backend/cpu/x86_x64/avx/FunctionSummary.hpp @@ -53,7 +53,7 @@ void _AVX_MNNAsyQuantInfo(float* scale, float* bias, float* qscale, float* qbias void _AVX_MNNPackC4ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el); void _AVX_MNNExpC8(float* dest, const float* source, float* offset, const float* parameters, size_t countC8); -void _AVX_MNNSoftmax(float* dest, const float* source, size_t size); +void _AVX_MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize); void _AVX_MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minV, ssize_t maxV, const float* zeroPoint, ssize_t quanParamVec); void _AVX_MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* scale, size_t sizeQuad, const float* zeroPoint, ssize_t quanParamVec); void _AVX_MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dstO, const int8_t* srcO, const int8_t* weightO, const QuanPostTreatParameters* parameters, size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, int8_t* idxOrder); diff --git a/source/backend/cpu/x86_x64/avx/MathFunctions.cpp b/source/backend/cpu/x86_x64/avx/MathFunctions.cpp index feb2d6f3..fde08bbd 100644 --- a/source/backend/cpu/x86_x64/avx/MathFunctions.cpp +++ b/source/backend/cpu/x86_x64/avx/MathFunctions.cpp @@ -117,103 +117,71 @@ void _AVX_MNNExpC8(float* dest, const float* source, float* offset, const float* } -void _AVX_MNNSoftmax(float* dest, const float* source, size_t size) { - float tmpfloat8[8]; - int count = size / 8; - int remain = count * 8; - // step 1: get maxValue - float maxValue = source[0]; - if (count > 0) { - auto maxVal = _mm256_loadu_ps(source); - for (int i = 1; i < count; i++) { - maxVal = _mm256_max_ps(maxVal, _mm256_loadu_ps(source + i * 8)); - } - _mm256_storeu_ps(tmpfloat8, maxVal); - maxValue = tmpfloat8[0] > tmpfloat8[1] ? tmpfloat8[0] : tmpfloat8[1]; - for (int i = 2; i < 8; i++) { - maxValue = maxValue > tmpfloat8[i] ? maxValue : tmpfloat8[i]; - } - } - for (int i = remain; i < size; i++) { - maxValue = maxValue > source[i] ? 
maxValue : source[i]; - } +void _AVX_MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) { + const float xLimit = 87.0f; + const float param = 0.6931471805599453f; // ln(2) + const float inv_param = 1.0f / param; + const int32_t exp_offset = 127; + const float exp_scale = 8388608.0f; // 2^23 - // step 2: get exp(x - maxValue) and sum(exp(x - maxValue)) - float sumValue = 0.f; - if (count > 0) { - auto sumVal = _mm256_set1_ps(0.f); - auto p0 = _mm256_set1_ps(0.6931471805599453); - auto p1 = _mm256_set1_ps(1.4426950408889634); - auto p2 = _mm256_set1_ps(1.f); - auto p3 = _mm256_set1_ps(1.f); - auto p4 = _mm256_set1_ps(0.5); - auto p5 = _mm256_set1_ps(0.1666666666666666); - auto p6 = _mm256_set1_ps(0.041666666666666664); - auto p7 = _mm256_set1_ps(0.008333333333333333); - auto xMax = _mm256_set1_ps(87); - auto xMin = _mm256_set1_ps(-87); - auto basic = _mm256_set1_epi32(1 << 23); - auto temp127 = _mm256_set1_epi32(127); - for (int i = 0; i < count; ++i) { - auto x = _mm256_sub_ps(_mm256_loadu_ps(source + i * 8), _mm256_set1_ps(maxValue)); - x = _mm256_max_ps(x, xMin); - x = _mm256_min_ps(x, xMax); - auto div = _mm256_mul_ps(x, p1); - auto divInt = _mm256_cvtps_epi32(div); - div = _mm256_cvtepi32_ps(divInt); - auto div2 = _mm256_add_epi32(divInt, temp127); - div2 = _mm256_mullo_epi32(div2, basic); - auto expBasic = _mm256_castsi256_ps(div2); - auto xReamin = _mm256_sub_ps(x, _mm256_mul_ps(div, p0)); - auto t = xReamin; - auto c0 = _mm256_mul_ps(p7, t); - auto c1 = _mm256_add_ps(c0, p6); - auto c2 = _mm256_mul_ps(c1, t); - auto c3 = _mm256_add_ps(c2, p5); - auto c4 = _mm256_mul_ps(c3, t); - auto c5 = _mm256_add_ps(c4, p4); - auto c6 = _mm256_mul_ps(c5, t); - auto c7 = _mm256_add_ps(c6, p3); - auto c8 = _mm256_mul_ps(c7, t); - auto c9 = _mm256_add_ps(c8, p2); - auto expRemain = c9; - auto expRes = _mm256_mul_ps(expBasic, expRemain); - sumVal = _mm256_add_ps(expRes, sumVal); - _mm256_storeu_ps(dest + 8 * i, expRes); - } - _mm256_storeu_ps(tmpfloat8, sumVal); - for (int i = 0; i < 8; i++) { - sumValue += tmpfloat8[i]; - } - } - auto param = 0.6931471805599453; - float xLimit = 87; - for (int i = remain; i < size; i++) { - auto x = source[i] - maxValue; - x = x > -xLimit ? x : -xLimit; - x = x < xLimit ? 
x : xLimit; + for (int k = 0; k < outside; ++k) { + float* source = input + k * reduceSize; + float* dest = softmaxDst + k * reduceSize; - int div = (x / param); - int div2 = (div + 127) << 23; - auto xReamin = x - div * param; - float expBasic = *(float*)(&div2); + float tmpfloat8[8]; + int count = reduceSize/ 8; + int remain = count * 8; + // step 1: get maxValue + float maxValue = source[0]; - auto t = xReamin; - auto expRemain = ((((1.0f / 120 * t + 1.0f / 24) * t + 1.0f / 6) * t + 0.5f) * t + 1.0f) * t + 1.0f; - dest[i] = expBasic * expRemain; - sumValue += dest[i]; - } - // step 3: get x / sum and store - for (int i = 0; i < count; ++i) { - // using 1 / ((1 / x) * sum) instead x * (1 / sum) or x / sum for some bugs in intel cpu - auto x = _mm256_rcp_ps(_mm256_loadu_ps(dest + 8 * i)); - auto y = _mm256_set1_ps(sumValue); - auto z = _mm256_rcp_ps(_mm256_mul_ps(x, y)); - _mm256_storeu_ps(dest + 8 * i, z); - } - sumValue = 1.f / sumValue; - for (int i = remain; i < size; i++) { - dest[i] *= sumValue; + float oldMax = maxValue; + if (runningMax) { + oldMax = runningMax[k]; + } + + if (count > 0) { + auto maxVal = _mm256_loadu_ps(source); + for (int i = 1; i < count; i++) { + maxVal = _mm256_max_ps(maxVal, _mm256_loadu_ps(source + i * 8)); + } + _mm256_storeu_ps(tmpfloat8, maxVal); + maxValue = tmpfloat8[0] > tmpfloat8[1] ? tmpfloat8[0] : tmpfloat8[1]; + for (int i = 2; i < 8; i++) { + maxValue = maxValue > tmpfloat8[i] ? maxValue : tmpfloat8[i]; + } + } + for (int i = remain; i < reduceSize; i++) { + maxValue = maxValue > source[i] ? maxValue : source[i]; + } + + float newMax = ALIMAX(oldMax, maxValue); + + // step 2: get exp(x - newMax) and sum(exp(x - newMax)) + float exprOffset[4] = {1.0f, 0.0f, 0.0f, 0.0f }; + exprOffset[2] = -newMax; + MNNExp(dest, source, exprOffset, reduceSize); + float sumValue = exprOffset[3]; + + if (runningMax != nullptr && runningSum != nullptr && updateScale != nullptr) { + // === Step 3: Update running variables === + float scale = expf(oldMax - newMax); + runningSum[k] = runningSum[k] * scale + sumValue; + runningMax[k] = newMax; + updateScale[k] = scale; + } else { + // step 3: get x / sum and store + for (int i = 0; i < count; ++i) { + // using 1 / ((1 / x) * sum) instead x * (1 / sum) or x / sum for some bugs in intel cpu + auto x = _mm256_rcp_ps(_mm256_loadu_ps(dest + 8 * i)); + auto y = _mm256_set1_ps(sumValue); + auto z = _mm256_rcp_ps(_mm256_mul_ps(x, y)); + _mm256_storeu_ps(dest + 8 * i, z); + } + auto scale = 1.f / sumValue; + for (int i = remain; i < reduceSize; i++) { + dest[i] *= scale; + } + } } } diff --git a/source/backend/cpu/x86_x64/avx/PackedFunction.cpp b/source/backend/cpu/x86_x64/avx/PackedFunction.cpp index 999806ce..33a464d1 100644 --- a/source/backend/cpu/x86_x64/avx/PackedFunction.cpp +++ b/source/backend/cpu/x86_x64/avx/PackedFunction.cpp @@ -48,8 +48,45 @@ void _AVX_MNNConvRunForLineDepthwise(float* dst, const float* src, const float* size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, size_t height, size_t srcHStep, size_t dstHStep, const float* bias, const float* parameters); void _AVX_MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters); + +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE +void _AVX_MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes); +#endif } +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE +void 
_AVX_MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes) { + // source shape: [headDim/pack, seqLen, pack] + // scale & normalizeScale shape: [seqLen] + // dest shape: [headDim/pack, seqLen, pack] + auto stride0 = plane * pack; + + if (idx > 0) { + for (int j = 0; j < depthQuad; ++j) { + for (int i = 0; i < plane; ++i) { + auto dataNew = Vec::load(src + j * stride0 + i * pack); + auto dataOld = Vec::load(dst + j * stride0 + i * pack); + auto s = Vec(scale[i]); + dataNew = Vec::fma(dataNew, dataOld, s); + Vec::save(dst + j * stride0 + i * pack, dataNew); + } + } + } else { + memcpy(dst, src, size * bytes); + } + if (idx == kvBlocks - 1) { // if last subBlock, exp(xi)/sum(exp(xi)) + for (int j = 0; j < depthQuad; ++j) { + for (int i = 0; i < plane; ++i) { + auto dataNew = Vec::load(dst + j * stride0 + i * pack); + auto ns = Vec(1.0f / normalizeScale[i]); + dataNew = dataNew * ns; + Vec::save(dst + j * stride0 + i * pack, dataNew); + } + } + } +} +#endif + void _AVX_MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) { for (int i = 0; i < count; ++i) { @@ -728,4 +765,9 @@ void _AVX_ExtraInit(void* functions) { // sparse conv funcs coreFunction->MNNGetSparseMatMulPackMode = _AVX_MNNGetSparseMatMulPackMode; coreFunction->MNNAdjustOptimalSparseKernel = _AVX_MNNAdjustOptimalSparseKernel; + + // attention +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE + coreFunction->MNNFlashAttentionUpdateBlockOutput = _AVX_MNNFlashAttentionUpdateBlockOutput; +#endif } diff --git a/source/backend/cpu/x86_x64/avx512/PackedFunction.cpp b/source/backend/cpu/x86_x64/avx512/PackedFunction.cpp index e636490d..a8a56d89 100644 --- a/source/backend/cpu/x86_x64/avx512/PackedFunction.cpp +++ b/source/backend/cpu/x86_x64/avx512/PackedFunction.cpp @@ -1064,6 +1064,39 @@ static void _AVX512_MNNAdjustOptimalSparseKernel(int& sparseBlockOC, MNN::CoreFu } } +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE +void _AVX512_MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes) { + // source shape: [headDim/pack, seqLen, pack] + // scale & normalizeScale shape: [seqLen] + // dest shape: [headDim/pack, seqLen, pack] + auto stride0 = plane * pack; + + if (idx > 0) { + for (int j = 0; j < depthQuad; ++j) { + for (int i = 0; i < plane; ++i) { + auto dataNew = Vec::load(src + j * stride0 + i * pack); + auto dataOld = Vec::load(dst + j * stride0 + i * pack); + auto s = Vec(scale[i]); + dataNew = Vec::fma(dataNew, dataOld, s); + Vec::save(dst + j * stride0 + i * pack, dataNew); + } + } + } else { + memcpy(dst, src, size * bytes); + } + if (idx == kvBlocks - 1) { // if last subBlock, exp(xi)/sum(exp(xi)) + for (int j = 0; j < depthQuad; ++j) { + for (int i = 0; i < plane; ++i) { + auto dataNew = Vec::load(dst + j * stride0 + i * pack); + auto ns = Vec(1.0f / normalizeScale[i]); + dataNew = dataNew * ns; + Vec::save(dst + j * stride0 + i * pack, dataNew); + } + } + } +} +#endif + void _AVX512_ExtraInit(void* functions) { auto coreFunction = static_cast(functions); @@ -1100,4 +1133,8 @@ void _AVX512_ExtraInit(void* functions) { coreFunction->MNNGetSparseMatMulPackMode = _AVX512_MNNGetSparseMatMulPackMode; coreFunction->MNNAdjustOptimalSparseKernel = _AVX512_MNNAdjustOptimalSparseKernel; + +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE + coreFunction->MNNFlashAttentionUpdateBlockOutput 
= _AVX512_MNNFlashAttentionUpdateBlockOutput; +#endif } diff --git a/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp b/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp index e7752676..5c389ee7 100644 --- a/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp +++ b/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp @@ -81,7 +81,7 @@ void _SSE_MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count); void _SSE_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose); void _SSE_MNNReluInt8(int8_t* dst, const int8_t* src, size_t size, ssize_t zeroPoint); -void _SSE_MNNSoftmax(float* dest, const float* source, size_t size); +void _SSE_MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize); void _SSE_ExtraInit(void* functions); void _SSE_MNNNorm(float *dst, const float *src, const float *gamma, const float *beta, float epsilon, size_t size, bool RMSNorm); void _SSE_ImageProcessInit(void* functions, int cpuFlags); diff --git a/source/backend/cpu/x86_x64/sse/MathFunctions.cpp b/source/backend/cpu/x86_x64/sse/MathFunctions.cpp index 061d7000..87b903fa 100644 --- a/source/backend/cpu/x86_x64/sse/MathFunctions.cpp +++ b/source/backend/cpu/x86_x64/sse/MathFunctions.cpp @@ -69,100 +69,68 @@ void _SSE_MNNExpC8(float* dest, const float* source, float* offset, const float* offset[3] = total; } -void _SSE_MNNSoftmax(float* dest, const float* source, size_t size) { - float tmpfloat4[4]; - int count = static_cast(size / 4); - int remain = count * 4; - // step 1: get maxValue - float maxValue = source[0]; - if (count > 0) { - auto maxVal = _mm_loadu_ps(source); - for (int i = 1; i < count; i++) { - maxVal = _mm_max_ps(maxVal, _mm_loadu_ps(source + i * 4)); +void _SSE_MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) { + const float xLimit = 87.0f; + const float param = 0.6931471805599453f; // ln(2) + const float inv_param = 1.0f / param; + const int32_t exp_offset = 127; + const float exp_scale = 8388608.0f; // 2^23 + + for (int k = 0; k < outside; ++k) { + float* source = input + k * reduceSize; + float* dest = softmaxDst + k * reduceSize; + + float tmpfloat4[4]; + int count = static_cast(reduceSize / 4); + int remain = count * 4; + // step 1: get maxValue + float maxValue = source[0]; + float oldMax = maxValue; + if (runningMax) { + oldMax = runningMax[k]; } - _mm_storeu_ps(tmpfloat4, maxVal); - maxValue = tmpfloat4[0] > tmpfloat4[1] ? tmpfloat4[0] : tmpfloat4[1]; - maxValue = maxValue > tmpfloat4[2] ? maxValue : tmpfloat4[2]; - maxValue = maxValue > tmpfloat4[3] ? maxValue : tmpfloat4[3]; - } - for (int i = remain; i < size; i++) { - maxValue = maxValue > source[i] ? 
maxValue : source[i]; - } - - // step 2: get exp(x - maxValue) and sum(exp(x - maxValue)) - float sumValue = 0.f; - if (count > 0) { - auto sumVal = _mm_set1_ps(0.f); - auto p0 = _mm_set1_ps(0.6931471805599453); - auto p1 = _mm_set1_ps(1.4426950408889634); - auto p2 = _mm_set1_ps(1.f); - auto p3 = _mm_set1_ps(1.f); - auto p4 = _mm_set1_ps(0.5); - auto p5 = _mm_set1_ps(0.1666666666666666); - auto p6 = _mm_set1_ps(0.041666666666666664); - auto p7 = _mm_set1_ps(0.008333333333333333); - auto xMax = _mm_set1_ps(87); - auto xMin = _mm_set1_ps(-87); - // auto basic = _mm_set1_epi32(1 << 23); - for (int i = 0; i < count; ++i) { - auto x = _mm_sub_ps(_mm_loadu_ps(source + i * 4), _mm_set1_ps(maxValue)); - x = _mm_max_ps(x, xMin); - x = _mm_min_ps(x, xMax); - auto div = _mm_mul_ps(x, p1); - auto divInt = _mm_cvtps_epi32(div); - div = _mm_cvtepi32_ps(divInt); - auto div2 = _mm_add_epi32(divInt, _mm_set1_epi32(127)); - // div2 = _mm_mullo_epi32(div2, basic); - div2 = _mm_slli_epi32(div2, 23); - auto expBasic = _mm_castsi128_ps(div2); - auto xReamin = _mm_sub_ps(x, _mm_mul_ps(div, p0)); - auto t = xReamin; - auto c0 = _mm_mul_ps(p7, t); - auto c1 = _mm_add_ps(c0, p6); - auto c2 = _mm_mul_ps(c1, t); - auto c3 = _mm_add_ps(c2, p5); - auto c4 = _mm_mul_ps(c3, t); - auto c5 = _mm_add_ps(c4, p4); - auto c6 = _mm_mul_ps(c5, t); - auto c7 = _mm_add_ps(c6, p3); - auto c8 = _mm_mul_ps(c7, t); - auto c9 = _mm_add_ps(c8, p2); - auto expRemain = c9; - auto expRes = _mm_mul_ps(expBasic, expRemain); - sumVal = _mm_add_ps(expRes, sumVal); - _mm_storeu_ps(dest + 4 * i, expRes); + if (count > 0) { + auto maxVal = _mm_loadu_ps(source); + for (int i = 1; i < count; i++) { + maxVal = _mm_max_ps(maxVal, _mm_loadu_ps(source + i * 4)); + } + _mm_storeu_ps(tmpfloat4, maxVal); + maxValue = tmpfloat4[0] > tmpfloat4[1] ? tmpfloat4[0] : tmpfloat4[1]; + maxValue = maxValue > tmpfloat4[2] ? maxValue : tmpfloat4[2]; + maxValue = maxValue > tmpfloat4[3] ? maxValue : tmpfloat4[3]; + } + for (int i = remain; i < reduceSize; i++) { + maxValue = maxValue > source[i] ? maxValue : source[i]; } - _mm_storeu_ps(tmpfloat4, sumVal); - sumValue = tmpfloat4[0] + tmpfloat4[1] + tmpfloat4[2] + tmpfloat4[3]; - } - auto param = 0.6931471805599453; - float xLimit = 87; - for (int i = remain; i < size; i++) { - auto x = source[i] - maxValue; - x = x > -xLimit ? x : -xLimit; - x = x < xLimit ? 
x : xLimit; - int div = (x / param); - int div2 = (div + 127) << 23; - auto xReamin = x - div * param; - float expBasic = *(float*)(&div2); + float newMax = ALIMAX(oldMax, maxValue); - auto t = xReamin; - auto expRemain = ((((1.0f / 120 * t + 1.0f / 24) * t + 1.0f / 6) * t + 0.5f) * t + 1.0f) * t + 1.0f; - dest[i] = expBasic * expRemain; - sumValue += dest[i]; - } - // step 3: get x / sum and store - for (int i = 0; i < count; ++i) { - // using 1 / ((1 / x) * sum) instead x * (1 / sum) or x / sum for some bugs in intel cpu - auto x = _mm_rcp_ps(_mm_loadu_ps(dest + 4 * i)); - auto y = _mm_set1_ps(sumValue); - auto z = _mm_rcp_ps(_mm_mul_ps(x, y)); - _mm_storeu_ps(dest + 4 * i, z); - } - sumValue = 1.f / sumValue; - for (int i = remain; i < size; i++) { - dest[i] *= sumValue; + // step 2: get exp(x - newMax) and sum(exp(x - newMax)) + float exprOffset[4] = {1.0f, 0.0f, 0.0f, 0.0f }; + exprOffset[2] = -newMax; + MNNExp(dest, source, exprOffset, reduceSize); + float sumValue = exprOffset[3]; + + if (runningMax != nullptr && runningSum != nullptr && updateScale != nullptr) { + // === Step 3: Update running variables === + float scale = expf(oldMax - newMax); + runningSum[k] = runningSum[k] * scale + sumValue; + runningMax[k] = newMax; + updateScale[k] = scale; + } else { + // step 3: get x / sum and store + for (int i = 0; i < count; ++i) { + // using 1 / ((1 / x) * sum) instead x * (1 / sum) or x / sum for some bugs in intel cpu + auto x = _mm_rcp_ps(_mm_loadu_ps(dest + 4 * i)); + auto y = _mm_set1_ps(sumValue); + auto z = _mm_rcp_ps(_mm_mul_ps(x, y)); + _mm_storeu_ps(dest + 4 * i, z); + } + auto scale = 1.f / sumValue; + for (int i = remain; i < reduceSize; i++) { + dest[i] *= scale; + } + } } } diff --git a/source/backend/opencl/core/BufferConvertor.cpp b/source/backend/opencl/core/BufferConvertor.cpp index 327b0ead..9b89efff 100644 --- a/source/backend/opencl/core/BufferConvertor.cpp +++ b/source/backend/opencl/core/BufferConvertor.cpp @@ -324,9 +324,9 @@ bool BufferConvertor::convertToNC4HW4Buffer(const Tensor *buffer, const OpenCLBu auto formattedBufferShape = tensorShapeFormat(buffer);//NHWC std::vector imageShape; getImageShape(formattedBufferShape, type, &imageShape); - + uint32_t gws[2] = {static_cast(imageShape[0]), static_cast(imageShape[1])}; - + auto runtime = mOpenCLRuntime; std::string kernelName; std::string kernelFile = "buffer_convert_buf"; @@ -360,26 +360,23 @@ bool BufferConvertor::convertToNC4HW4Buffer(const Tensor *buffer, const OpenCLBu default: break; } - if (mBufferToImageKernel.get() == nullptr || mBufferToImageKernelName != kernelName) { - mBufferToImageKernelName = kernelName; - std::set buildOptions; - if(needTrans) { - //buildOptions.emplace("-DBUFFER_FORMAT_INP_TRANS"); - kernelName += "_floatin"; - } -#ifdef MNN_LOW_MEMORY - if (lowMemory) { - if (quantBit == 8) { - // int8 case - buildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT8"); - } else if (quantBit == 4){ - // int4 case - buildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT4"); - } else {/* More types to be supported. 
*/} - } -#endif - mBufferToImageKernel = runtime->buildKernelWithCache(kernelFile, kernelName, buildOptions, precision, buffer, image); + std::set buildOptions; + if(needTrans) { + //buildOptions.emplace("-DBUFFER_FORMAT_INP_TRANS"); + kernelName += "_floatin"; } +#ifdef MNN_LOW_MEMORY + if (lowMemory) { + if (quantBit == 8) { + // int8 case + buildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT8"); + } else if (quantBit == 4){ + // int4 case + buildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT4"); + } else {/* More types to be supported. */} + } +#endif + mBufferToImageKernel = runtime->buildKernelWithCache(kernelFile, kernelName, buildOptions, precision, buffer, image); auto kernel = mBufferToImageKernel->get(); uint32_t idx = 0; diff --git a/source/backend/opencl/core/BufferConvertor.hpp b/source/backend/opencl/core/BufferConvertor.hpp index 8bec0ec1..624df776 100644 --- a/source/backend/opencl/core/BufferConvertor.hpp +++ b/source/backend/opencl/core/BufferConvertor.hpp @@ -43,7 +43,7 @@ public: explicit BufferConvertor(OpenCLRuntime *opencl_runtime) : mOpenCLRuntime(opencl_runtime) { } bool convertToNC4HW4Buffer(const Tensor *input, const OpenCLBufferFormat type, Tensor *output, int precision, - bool needTrans, bool needWait = false, bool lowMemory = false, int quantBit = 0); + bool needTrans, bool needWait = true, bool lowMemory = false, int quantBit = 0); private: OpenCLRuntime *mOpenCLRuntime; diff --git a/source/backend/opencl/core/OpenCLBackend.cpp b/source/backend/opencl/core/OpenCLBackend.cpp index c377ba7f..a1cd0435 100644 --- a/source/backend/opencl/core/OpenCLBackend.cpp +++ b/source/backend/opencl/core/OpenCLBackend.cpp @@ -47,8 +47,13 @@ CLRuntime::CLRuntime(const Backend::Info& info){ mPrecision = mInfo.user->precision; mMemory = mInfo.user->memory; } - - mOpenCLRuntime.reset(new OpenCLRuntime(platform_size, platform_id, device_id, context_ptr, hint())); + + // protect + if(mPrecision > 2 || mPrecision < 0){ + mPrecision = BackendConfig::Precision_High; + } + + mOpenCLRuntime.reset(new OpenCLRuntime(platform_size, platform_id, device_id, context_ptr, hint())); //Whether runtimeError mCLRuntimeError = mOpenCLRuntime->isCreateError(); @@ -206,6 +211,10 @@ Backend* CLRuntime::onCreate(const BackendConfig* config, Backend* origin) const precision = config->precision; memory = config->memory; } + // protect + if(precision > 2 || precision < 0){ + precision = BackendConfig::Precision_High; + } auto backend = new OpenCLBackend(precision, memory, mInfo.gpuMode, mImagePool, mBufferPool, this); backend->setMetaPtr(pMeta); return backend; @@ -246,6 +255,10 @@ OpenCLBackend::OpenCLBackend(BackendConfig::PrecisionMode precision, BackendConf } else{ mPrecision = BackendConfig::Precision_High; } + // protect + if(mPrecision > 2 || mPrecision < 0){ + mPrecision = BackendConfig::Precision_High; + } mMemory = memory; // set tuneLevel, memtype, record mode setGpuMode(gpuMode); diff --git a/source/backend/qnn/backend/QNNBackend.cpp b/source/backend/qnn/backend/QNNBackend.cpp index 96cca4f1..31437782 100644 --- a/source/backend/qnn/backend/QNNBackend.cpp +++ b/source/backend/qnn/backend/QNNBackend.cpp @@ -18,6 +18,8 @@ struct QnnContext { Qnn_LogHandle_t logHandle = nullptr; Qnn_BackendHandle_t backendHandle = nullptr; Qnn_DeviceHandle_t deviceHandle = nullptr; + int soc_id; + int dsp_arch; }; static QnnContext gContext; @@ -94,6 +96,7 @@ bool PluginShapeRaw::compute(InferShapeContext* ctx) { auto dst = ctx->output(i); std::string key = prefix + std::to_string(i); auto attr = ctx->getAttr(key.c_str()); 
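The reworked MNNSoftmax above is a streaming (online) softmax: each call reduces one block of reduceSize scores per row, and when runningMax/runningSum/updateScale are supplied it only refreshes those running statistics, leaving the final normalization to MNNFlashAttentionUpdateBlockOutput after the last KV block. A scalar sketch of the same recurrence, assuming plain float buffers and using std::exp in place of the vectorized MNNExp, is shown below; the accumulator is assumed to start zero-filled (the real kernel instead memcpy's the first block's output).

// Scalar sketch of the streaming-softmax recurrence behind the new
// MNNSoftmax / MNNFlashAttentionUpdateBlockOutput pair (assumption: one row
// at a time, plain float buffers; the real code vectorizes this).
#include <algorithm>
#include <cmath>
#include <vector>

struct RunningStats {
    float max = -INFINITY; // running max over all blocks seen so far
    float sum = 0.0f;      // running sum of exp(x - max)
};

// Process one block of scores: write exp(x - newMax) into expOut, update the
// running statistics, and return the rescale factor exp(oldMax - newMax)
// that previously accumulated output must be multiplied by.
inline float softmaxBlock(const float* scores, float* expOut, int n, RunningStats& st) {
    float blockMax = scores[0];
    for (int i = 1; i < n; ++i) {
        blockMax = std::max(blockMax, scores[i]);
    }
    const float newMax = std::max(st.max, blockMax);
    float blockSum = 0.0f;
    for (int i = 0; i < n; ++i) {
        expOut[i] = std::exp(scores[i] - newMax);
        blockSum += expOut[i];
    }
    const float scale = std::exp(st.max - newMax); // rescales earlier partial results
    st.sum = st.sum * scale + blockSum;
    st.max = newMax;
    return scale;
}

// Merge one block's partial attention output, mirroring
// MNNFlashAttentionUpdateBlockOutput: acc = acc * scale + blockOut, with a
// final division by the running sum after the last block.
inline void mergeBlockOutput(std::vector<float>& acc, const std::vector<float>& blockOut,
                             float scale, bool lastBlock, const RunningStats& st) {
    for (std::size_t i = 0; i < acc.size(); ++i) {
        acc[i] = acc[i] * scale + blockOut[i];
    }
    if (lastBlock) {
        const float inv = 1.0f / st.sum;
        for (auto& v : acc) {
            v *= inv;
        }
    }
}

Rescaling the accumulated output by exp(oldMax - newMax) keeps every stored value expressed relative to the current maximum, so after the final division by the running sum the result matches a single-pass softmax over the whole row.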
+ if (nullptr == attr || nullptr == attr->tensor()) { MNN_ERROR("MNN_QNN: Failed to find raw shape %s.\n", key.c_str()); return false; @@ -886,7 +889,12 @@ ErrorCode QnnBackend::onResizeEnd() { #ifdef QNN_VERBOSE MNN_PRINT("start finalize\n"); #endif + buildOutputDequant(); finalizeGraph(); + for(auto func : mReleaseFunc){ + func(); + } + mReleaseFunc.clear(); #ifdef QNN_VERBOSE MNN_PRINT("end finalize\n"); #endif @@ -915,7 +923,14 @@ Backend::MemObj* QnnBackend::onAcquire(const Tensor* tensor, StorageType storage } Qnn_DataType_t tDataType; - MNN_ASSERT((tensor->getType().code == halide_type_float) || (tensor->getType().code == halide_type_int && tensor->getType().bits == 32)); + Qnn_QuantizeParams_t tQuantizeParams{}; + tQuantizeParams.encodingDefinition = QNN_DEFINITION_UNDEFINED; + tQuantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED; + Qnn_ScaleOffset_t tScaleOffsetEncoding; + tScaleOffsetEncoding.scale = 0.0f; + tScaleOffsetEncoding.offset = 0; + auto quant = TensorUtils::getDescribe(tensor)->quantAttr.get(); + //MNN_ASSERT((tensor->getType().code == halide_type_float) || (tensor->getType().code == halide_type_int && tensor->getType().bits == 32)); if (mUseFP16 && tensor->getType().code == halide_type_float) { tDataType = QNN_DATATYPE_FLOAT_16; } else if (tensor->getType().code == halide_type_float) { @@ -926,13 +941,19 @@ Backend::MemObj* QnnBackend::onAcquire(const Tensor* tensor, StorageType storage MNN_PRINT("MNN_QNN: Not supported data type in .\n"); return nullptr; } - - Qnn_QuantizeParams_t tQuantizeParams{}; - tQuantizeParams.encodingDefinition = QNN_DEFINITION_UNDEFINED; - tQuantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED; - Qnn_ScaleOffset_t tScaleOffsetEncoding; - tScaleOffsetEncoding.scale = 0.0f; - tScaleOffsetEncoding.offset = 0; + if(quant != nullptr && TensorUtils::getDescribe(tensor)->type == DataType_DT_INT8) { + tQuantizeParams.encodingDefinition = QNN_DEFINITION_DEFINED; + tQuantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; + tScaleOffsetEncoding.scale = quant->scale; + if(quant->zero != 0){ + MNN_PRINT("MNN_QNN: Not supported asymmetric quant in .\n"); + return nullptr; + } + tDataType = QNN_DATATYPE_SFIXED_POINT_8; + if (isOutput) { + tType = QNN_TENSOR_TYPE_NATIVE; + } + } tQuantizeParams.scaleOffsetEncoding = tScaleOffsetEncoding; std::unique_ptr tempTensor(new Tensor(tensor, gQnnTensorDimType, false)); @@ -947,17 +968,41 @@ Backend::MemObj* QnnBackend::onAcquire(const Tensor* tensor, StorageType storage Qnn_Tensor_t * qnnTensor = qnnTensorWrapper->getNativeTensor(); CALL_QNN(mRuntime->mQnnInterface.tensorCreateGraphTensor(mQnnGraphHandle, qnnTensor)); + mQNNTensorWrappers.push_back(qnnTensorWrapper); + mTensorMap.insert({TensorUtils::getDescribe(tensor), mTensorCounter}); if (isInput) { mInputTensorIndexes.push_back(mTensorCounter); qnnTensorWrapper->alloc(); } if (isOutput) { - mOutputTensorIndexes.push_back(mTensorCounter); - qnnTensorWrapper->alloc(); + if(quant != nullptr && TensorUtils::getDescribe(tensor)->type == DataType_DT_INT8){ + mTensorCounter += 1; + std::shared_ptr stageTensor; + stageTensor.reset(Tensor::create(tensor->shape(), nullptr, gQnnTensorDimType)); + tName = "QnnTensor_" + std::to_string(mTensorCounter); + tType = QNN_TENSOR_TYPE_APP_READ; + if (mUseFP16 && tensor->getType().code == halide_type_float) { + tDataType = QNN_DATATYPE_FLOAT_16; + } else if (tensor->getType().code == halide_type_float) { + tDataType = QNN_DATATYPE_FLOAT_32; + } else { + 
MNN_PRINT("MNN_QNN: Not supported data type in .\n"); + return nullptr; + } + Qnn_QuantizeParams_t tQuantizeParamstmp = QNN_QUANTIZE_PARAMS_INIT; + std::shared_ptr qnnOutputTensorWrapper = QNNTensorWrapper::create(tName, tType, tDataType, tDims, tQuantizeParamstmp); + CALL_QNN(mRuntime->mQnnInterface.tensorCreateGraphTensor(mQnnGraphHandle, qnnOutputTensorWrapper->getNativeTensor())); + mDeQuantOutputTensorMap.insert({TensorUtils::getDescribe(tensor), {tensor, stageTensor}}); + mQNNTensorWrappers.push_back(qnnOutputTensorWrapper); + mTensorMap.insert({TensorUtils::getDescribe(const_cast(stageTensor.get())), mTensorCounter}); + mOutputTensorIndexes.push_back(mTensorCounter); + qnnOutputTensorWrapper->alloc(); + } else{ + mOutputTensorIndexes.push_back(mTensorCounter); + qnnTensorWrapper->alloc(); + } } - mQNNTensorWrappers.push_back(qnnTensorWrapper); - mTensorMap.insert({TensorUtils::getDescribe(tensor), mTensorCounter}); mTensorCounter += 1; #ifdef QNN_VERBOSE @@ -1029,7 +1074,13 @@ void QnnBackend::inputIO(const Tensor* srcTensor, const Tensor* dstTensor) const } void QnnBackend::outputIO(const Tensor* srcTensor, const Tensor* dstTensor) const { - int srcIndex = getTensorIdx(srcTensor); + auto iter = mDeQuantOutputTensorMap.find(TensorUtils::getDescribe(srcTensor)); + int srcIndex = -1; + if(iter != mDeQuantOutputTensorMap.end()){ + srcIndex = getTensorIdx(iter->second.second.get()); + } else{ + srcIndex = getTensorIdx(srcTensor); + } std::shared_ptr srcQnnTensorWrapper = mQNNTensorWrappers[srcIndex]; std::shared_ptr srcDataContainer = srcQnnTensorWrapper->getDataContainer(); @@ -1091,7 +1142,7 @@ void QnnBackend::finalizeGraph() { // set QNN_PROFILE_LEVEL_DETAILED QnnProfile_Level_t profileLevel = QNN_PROFILE_LEVEL_DETAILED; MNN_PRINT("[QNN Profile] Creating QNN Profile Handle with DETAILED level.\n"); - auto profile_err = mRuntime->mQnnInterface.profileCreate(mQnnContextHandle, profileLevel, &mQnnProfileHandle); + auto profile_err = mRuntime->mQnnInterface.profileCreate(mRuntime->mQnnContextHandle, profileLevel, &mQnnProfileHandle); if (profile_err != QNN_SUCCESS || mQnnProfileHandle == nullptr) { MNN_ERROR("[QNN Profile] Failed to create QNN Profile Handle, error: %d\n", (int)profile_err); mQnnProfileHandle = nullptr; @@ -1216,6 +1267,27 @@ void QnnBackend::clean() { mTensorMap.clear(); mInputTensorIndexes.clear(); mOutputTensorIndexes.clear(); + mDeQuantOutputTensorMap.clear(); +} +void QnnBackend::buildOutputDequant(){ + Qnn_OpConfigVersion_t mOpConfigVersion = QNN_OPCONFIG_VERSION_1; + std::string mNodeName; + std::string mPackageName = "qti.aisw"; + std::string mNodeType; + std::vector mParams; + std::vector mInputs; + std::vector mOutputs; + for(auto iter : mDeQuantOutputTensorMap){ + mNodeType.clear(); + mParams.clear(); + mInputs.clear(); + mOutputs.clear(); + mNodeType = "Dequantize"; + std::string name = "Dequantize_I_" + std::to_string(getTensorIdx(iter.second.first)) + "_O_" + std::to_string(getTensorIdx(iter.second.second.get()));; + mInputs.push_back(*(getNativeTensor(iter.second.first))); // input + mOutputs.push_back(*(getNativeTensor(iter.second.second.get()))); // output + addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + } } QnnRuntime::QnnRuntime(const Backend::Info& info, QNN_INTERFACE_VER_TYPE qnnInterface, Qnn_LogHandle_t qnnLogHandle, Qnn_BackendHandle_t qnnBackendHandle, Qnn_DeviceHandle_t qnnDeviceHandle) { @@ -1324,12 +1396,23 @@ bool 
QnnRuntime::registerCustomOpPackage(QNN_INTERFACE_VER_TYPE qnnInterface, Qn class QnnRuntimeCreator : public RuntimeCreator { public: - virtual Runtime* onCreate(const Backend::Info& info) const { + virtual Runtime* onCreate(const Backend::Info& info) const override { return QnnRuntime::create(info); } - virtual bool onValid(Backend::Info& info) const { + virtual bool onValid(Backend::Info& info) const override { return true; } + virtual bool onGetDeviceInfo(const std::string& deviceKey, std::string& deviceValue) const override { + if(deviceKey == "soc_id" && gContext.soc_id != 0) { + deviceValue = std::to_string(gContext.soc_id); + return true; + } + if(deviceKey == "dsp_arch" && gContext.dsp_arch != 0) { + deviceValue = "v" + std::to_string(gContext.dsp_arch); + return true; + } + return false; + } }; } // end namespace QNN @@ -1401,6 +1484,8 @@ void registerQNNRuntimeCreator() { // Create Device. Qnn_DeviceHandle_t deviceHandle = nullptr; + QnnHtpDevice_Arch_t dspArch = QNN_HTP_DEVICE_ARCH_NONE; + uint32_t socId = 0; { // Check whether the device API is supported. bool supportDevice = QNN::checkCapability(qnnInterface, QNN_PROPERTY_GROUP_DEVICE); @@ -1422,10 +1507,8 @@ void registerQNNRuntimeCreator() { MNN_PRINT("[Warning]: deviceGetPlatformInfo Failed to query platform info"); } else { QnnDevice_HardwareDeviceInfo_t* hwDeviceInfo = backendPlatformInfoPtr->v1.hwDevices; - QnnHtpDevice_Arch_t arch = hwDeviceInfo->v1.deviceInfoExtension->onChipDevice.arch; - uint32_t socModel = hwDeviceInfo->v1.deviceInfoExtension->onChipDevice.socModel; - MNN_PRINT("Qnn Device soc_id: %d\n", socModel); - MNN_PRINT("Qnn Device dsp_arch: v%d\n", arch); + dspArch = hwDeviceInfo->v1.deviceInfoExtension->onChipDevice.arch; + socId = hwDeviceInfo->v1.deviceInfoExtension->onChipDevice.socModel; } } } else { @@ -1478,6 +1561,8 @@ void registerQNNRuntimeCreator() { QNN::gContext.backendHandle = backendHandle; QNN::gContext.deviceHandle = deviceHandle; QNN::gContext.logHandle = logHandle; + QNN::gContext.soc_id = socId; + QNN::gContext.dsp_arch = dspArch; QNN::registerQNNOps(); MNNInsertExtraRuntimeCreator(MNN_FORWARD_NN, new QNN::QnnRuntimeCreator, false); diff --git a/source/backend/qnn/backend/QNNBackend.hpp b/source/backend/qnn/backend/QNNBackend.hpp index 1c951ea1..a75acf06 100644 --- a/source/backend/qnn/backend/QNNBackend.hpp +++ b/source/backend/qnn/backend/QNNBackend.hpp @@ -78,6 +78,10 @@ public: std::shared_ptr getTensorWrapper(const Tensor * tensor); bool useCache() const; bool getUseFP16() const; + void buildOutputDequant(); + void pushReleaseFunc(std::function func){ + mReleaseFunc.push_back(func); + } private: void clean(); @@ -109,8 +113,10 @@ private: mutable int mTensorCounter = 0; mutable std::vector> mQNNTensorWrappers; mutable std::map mTensorMap; + mutable std::map>> mDeQuantOutputTensorMap; std::vector mInputTensorIndexes; std::vector mOutputTensorIndexes; + std::vector> mReleaseFunc; }; diff --git a/source/backend/qnn/backend/QNNUtils.cpp b/source/backend/qnn/backend/QNNUtils.cpp index 9d73b043..aeb27637 100644 --- a/source/backend/qnn/backend/QNNUtils.cpp +++ b/source/backend/qnn/backend/QNNUtils.cpp @@ -134,6 +134,8 @@ void registerQNNOps() { #ifdef MNN_SUPPORT_TRANSFORMER_FUSE ___QNNAttentionCreator__OpType_Attention__(); #endif + ___QNNQuantCreator__OpType_FloatToInt8__(); + ___QNNDeQuantCreator__OpType_Int8ToFloat__(); } Tensor::DimensionType gQnnTensorDimType = Tensor::TENSORFLOW; diff --git a/source/backend/qnn/backend/QNNUtils.hpp b/source/backend/qnn/backend/QNNUtils.hpp index 
3b41f872..92128fdb 100644 --- a/source/backend/qnn/backend/QNNUtils.hpp +++ b/source/backend/qnn/backend/QNNUtils.hpp @@ -104,6 +104,8 @@ extern void ___QNNMatMulCreator__OpType_MatMul__(); #ifdef MNN_SUPPORT_TRANSFORMER_FUSE extern void ___QNNAttentionCreator__OpType_Attention__(); #endif +extern void ___QNNQuantCreator__OpType_FloatToInt8__(); +extern void ___QNNDeQuantCreator__OpType_Int8ToFloat__(); void registerQNNOps(); extern Tensor::DimensionType gQnnTensorDimType; diff --git a/source/backend/qnn/backend/QNNWrapper.cpp b/source/backend/qnn/backend/QNNWrapper.cpp index f3e6cd8f..3731225f 100644 --- a/source/backend/qnn/backend/QNNWrapper.cpp +++ b/source/backend/qnn/backend/QNNWrapper.cpp @@ -25,7 +25,7 @@ std::shared_ptr QNNTensorWrapper::create(const std::string & n std::shared_ptr QNNTensorWrapper::createStaticTensor(const std::string & name, Qnn_DataType_t dataType, const std::vector & dimensions, const void * buffer, Qnn_QuantizeParams_t quantizeParam) { MNN_ASSERT(!name.empty() && !dimensions.empty() && buffer); - MNN_ASSERT(dataType == QNN_DATATYPE_SFIXED_POINT_8 || dataType == QNN_DATATYPE_INT_32); + MNN_ASSERT(dataType == QNN_DATATYPE_SFIXED_POINT_8 || dataType == QNN_DATATYPE_INT_32 || dataType == QNN_DATATYPE_SFIXED_POINT_32 || dataType == QNN_DATATYPE_UFIXED_POINT_8); std::shared_ptr tensorWrapper = QNNTensorWrapper::create(name, QNN_TENSOR_TYPE_STATIC, dataType, dimensions, quantizeParam); uint32_t numElement = 1; @@ -114,7 +114,8 @@ void * QNNTensorWrapper::alloc() { dims[i] = (int)mDimensions[i]; } - MNN_ASSERT(mQnnTensor.v1.dataType == QNN_DATATYPE_FLOAT_32 || mQnnTensor.v1.dataType == QNN_DATATYPE_FLOAT_16 || mQnnTensor.v1.dataType == QNN_DATATYPE_INT_32 || mQnnTensor.v1.dataType == QNN_DATATYPE_SFIXED_POINT_8); + MNN_ASSERT(mQnnTensor.v1.dataType == QNN_DATATYPE_FLOAT_32 || mQnnTensor.v1.dataType == QNN_DATATYPE_FLOAT_16 || mQnnTensor.v1.dataType == QNN_DATATYPE_INT_32 || mQnnTensor.v1.dataType == QNN_DATATYPE_SFIXED_POINT_8 || mQnnTensor.v1.dataType == QNN_DATATYPE_SFIXED_POINT_32 + || mQnnTensor.v1.dataType == QNN_DATATYPE_UFIXED_POINT_8); halide_type_t halideType; halideType.lanes = 1; @@ -134,6 +135,15 @@ void * QNNTensorWrapper::alloc() { case QNN_DATATYPE_SFIXED_POINT_8: halideType.code = halide_type_int; halideType.bits = 8; + break; + case QNN_DATATYPE_SFIXED_POINT_32: + halideType.code = halide_type_int; + halideType.bits = 32; + break; + case QNN_DATATYPE_UFIXED_POINT_8: + halideType.code = halide_type_int; + halideType.bits = 8; + break; default: break; } diff --git a/source/backend/qnn/convertor/QNNConvertor.cpp b/source/backend/qnn/convertor/QNNConvertor.cpp index e81ad28e..a575ee01 100644 --- a/source/backend/qnn/convertor/QNNConvertor.cpp +++ b/source/backend/qnn/convertor/QNNConvertor.cpp @@ -269,6 +269,10 @@ std::vector QNNTranslator::TranslateTensor(const QNNCommandTensor& if (isParam) { result.push_back(QNNTranslator::TranslateParamDataArray(dataNameSymbol, cmdT.dataType, cmdT.clientBuf)); } + if(hasQuant){ + std::vector linesQuantScaleOffset = TranslateQuantizeScaleOffsetDataArray(tensorNameSymbol, cmdT.quantizeParams, cmdT.rank, cmdT.dimensions); + APPEND_VECTOR(result, linesQuantScaleOffset); + } result.push_back(" Qnn_Tensor_t " + tensorNameSymbol + " = QNN_TENSOR_INIT;"); result.push_back(" {"); result.push_back(" " + tensorNameSymbol + ".version = QNN_TENSOR_VERSION_1;"); @@ -456,11 +460,122 @@ std::string QNNTranslator::TranslateParamDataArray(const std::string & dataNameS return result; } +std::vector 
QNNTranslator::TranslateQuantizeScaleOffsetDataArray(const std::string & tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParams, uint32_t rank, const uint32_t * dimensions){ + std::vector result; + if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET){ + result.push_back(" Qnn_ScaleOffset_t " + tensorNameSymbol + "_axis_scale_offset[] = {"); + int totalnum = (quantizeParams.axisScaleOffsetEncoding.numScaleOffsets + 3) / 4; + for(int i = 0; i < totalnum; ++i){ + std::string line = " "; + for(int j = 0; j < 4; ++j){ + int index = i * 4 + j; + if(index >= quantizeParams.axisScaleOffsetEncoding.numScaleOffsets) + break; + line += "{.scale= " + std::to_string(quantizeParams.axisScaleOffsetEncoding.scaleOffset[index].scale) + ", .offset= " + std::to_string(quantizeParams.axisScaleOffsetEncoding.scaleOffset[index].offset) + "}, "; + } + result.push_back(line); + } + result.push_back(" };"); + } + + if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET){ + result.push_back(" float " + tensorNameSymbol + "_bwaxis_scale[] = {"); + int totalnum = (quantizeParams.bwAxisScaleOffsetEncoding.numElements + 3) / 4; + for(int i = 0; i < totalnum; ++i){ + std::string line = " "; + for(int j = 0; j < 4; ++j){ + int index = i * 4 + j; + if(index >= quantizeParams.bwAxisScaleOffsetEncoding.numElements) + break; + line += std::to_string(quantizeParams.bwAxisScaleOffsetEncoding.scales[index]) + ", "; + } + result.push_back(line); + } + result.push_back(" };"); + if(quantizeParams.bwAxisScaleOffsetEncoding.offsets != nullptr){ + result.push_back(" int32_t " + tensorNameSymbol + "_bwaxis_offset[] = {"); + for(int i = 0; i < totalnum; ++i){ + std::string line = " "; + for(int j = 0; j < 4; ++j){ + int index = i * 4 + j; + if(index >= quantizeParams.bwAxisScaleOffsetEncoding.numElements) + break; + line += std::to_string(quantizeParams.bwAxisScaleOffsetEncoding.offsets[index]) + ", "; + } + result.push_back(line); + } + result.push_back(" };"); + } + } + + if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION){ + int axis = quantizeParams.blockwiseExpansion->axis; + int oc = dimensions[axis]; + int blockSize = quantizeParams.blockwiseExpansion->numBlocksPerAxis; + result.push_back(" Qnn_BlockwiseExpansion_t " + tensorNameSymbol + "_blockwiseExpansion = QNN_BLOCKWISE_EXPANSION_INIT;"); + + result.push_back(" Qnn_ScaleOffset_t " + tensorNameSymbol + "_blockwiseExpansionScaleOffset[] = {"); + int totalnum = (oc + 3) / 4; + for(int i = 0; i < totalnum; ++i){ + std::string line = " "; + for(int j = 0; j < 4; ++j){ + int index = i * 4 + j; + if(index >= oc) + break; + line += "{.scale= " + std::to_string(quantizeParams.blockwiseExpansion->scaleOffsets[index].scale) + ", .offset= " + std::to_string(quantizeParams.blockwiseExpansion->scaleOffsets[index].offset) + "}, "; + } + result.push_back(line); + } + result.push_back(" };"); + if(quantizeParams.blockwiseExpansion->blockScaleStorageType == QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8){ + result.push_back(" uint8_t " + tensorNameSymbol + "_blockwiseExpansionBlockScale[] = {"); + totalnum = (oc * blockSize + 3) / 4; + for(int i = 0; i < totalnum; ++i){ + std::string line = " "; + for(int j = 0; j < 4; ++j){ + int index = i * 4 + j; + if(index >= oc * blockSize) + break; + line 
+= std::to_string(quantizeParams.blockwiseExpansion->blocksScale8[index]) + ", "; + } + result.push_back(line); + } + result.push_back(" };"); + }else{ + result.push_back(" uint16_t " + tensorNameSymbol + "_blockwiseExpansionBlockScale[] = {"); + totalnum = (oc * blockSize + 3) / 4; + for(int i = 0; i < totalnum; ++i){ + std::string line = " "; + for(int j = 0; j < 4; ++j){ + int index = i * 4 + j; + if(index >= oc * blockSize) + break; + line += std::to_string(quantizeParams.blockwiseExpansion->blocksScale16[index]) + ", "; + } + result.push_back(line); + } + result.push_back(" };"); + } + result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.axis = " + std::to_string(quantizeParams.blockwiseExpansion->axis) + ";"); + result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.scaleOffsets = " + tensorNameSymbol + "_blockwiseExpansionScaleOffset;"); + result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.numBlocksPerAxis = " + std::to_string(quantizeParams.blockwiseExpansion->numBlocksPerAxis) + ";"); + result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.blockScaleBitwidth = " + std::to_string(quantizeParams.blockwiseExpansion->blockScaleBitwidth) + ";"); + if(quantizeParams.blockwiseExpansion->blockScaleStorageType == QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8){ + result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.blockScaleStorageType = QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8;"); + result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.blocksScale8 = " + tensorNameSymbol + "_blockwiseExpansionBlockScale;"); + }else{ + result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.blockScaleStorageType = QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_16;"); + result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.blocksScale16 = " + tensorNameSymbol + "_blockwiseExpansionBlockScale;"); + } + } + return result; +} + // Currently, only support QNN_QUANTIZATION_ENCODING_UNDEFINED, QNN_QUANTIZATION_ENCODING_SCALE_OFFSET. 
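TranslateQuantizeScaleOffsetDataArray above serializes each quantization encoding into static C arrays that the generated model source can later reference from the tensor's quantizeParams, emitting four initializers per source line. A simplified sketch of that chunked emission, assuming a plain ScaleOffset struct in place of Qnn_ScaleOffset_t and covering only the per-axis case, is:

// Simplified sketch of the chunked array emission used by
// TranslateQuantizeScaleOffsetDataArray (assumption: ScaleOffset stands in
// for Qnn_ScaleOffset_t; bitwidth and blockwise encodings are omitted).
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

struct ScaleOffset {
    float scale;
    int32_t offset;
};

std::vector<std::string> emitScaleOffsetArray(const std::string& symbol,
                                              const std::vector<ScaleOffset>& entries) {
    std::vector<std::string> lines;
    lines.push_back("    Qnn_ScaleOffset_t " + symbol + "_axis_scale_offset[] = {");
    const std::size_t perLine = 4; // four initializers per emitted source line
    for (std::size_t i = 0; i < entries.size(); i += perLine) {
        std::string line = "        ";
        for (std::size_t j = i; j < std::min(entries.size(), i + perLine); ++j) {
            line += "{.scale= " + std::to_string(entries[j].scale) +
                    ", .offset= " + std::to_string(entries[j].offset) + "}, ";
        }
        lines.push_back(line);
    }
    lines.push_back("    };");
    return lines;
}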
-std::vector QNNTranslator::TranslateTensorQuantizeParams(const std::string tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParmas) { +std::vector QNNTranslator::TranslateTensorQuantizeParams(const std::string tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParams) { std::vector result; - if (quantizeParmas.encodingDefinition == QNN_DEFINITION_UNDEFINED) { + if (quantizeParams.encodingDefinition == QNN_DEFINITION_UNDEFINED) { result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.encodingDefinition = QNN_DEFINITION_UNDEFINED;"); result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED;"); result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.scaleOffsetEncoding.scale = 0.0f;"); @@ -468,13 +583,42 @@ std::vector QNNTranslator::TranslateTensorQuantizeParams(const std: return result; } - if (quantizeParmas.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParmas.quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { + if (quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.encodingDefinition = QNN_DEFINITION_DEFINED;"); result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET;"); - result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.scaleOffsetEncoding.scale = " + std::to_string(quantizeParmas.scaleOffsetEncoding.scale) + ";"); - result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.scaleOffsetEncoding.offset = " + std::to_string(quantizeParmas.scaleOffsetEncoding.offset) + ";"); + result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.scaleOffsetEncoding.scale = " + std::to_string(quantizeParams.scaleOffsetEncoding.scale) + ";"); + result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.scaleOffsetEncoding.offset = " + std::to_string(quantizeParams.scaleOffsetEncoding.offset) + ";"); return result; } + + if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET){ + result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.encodingDefinition = QNN_DEFINITION_DEFINED;"); + result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET;"); + result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.axisScaleOffsetEncoding.axis = " + std::to_string(quantizeParams.axisScaleOffsetEncoding.axis) + ";"); + result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.axisScaleOffsetEncoding.numScaleOffsets = " + std::to_string(quantizeParams.axisScaleOffsetEncoding.numScaleOffsets) + ";"); + result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.axisScaleOffsetEncoding.scaleOffset = " + tensorNameSymbol + "_axis_scale_offset;"); + return result; + } + + if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET){ + result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.encodingDefinition = QNN_DEFINITION_DEFINED;"); + result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET;"); + result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.bwAxisScaleOffsetEncoding.axis = " + 
std::to_string(quantizeParams.bwAxisScaleOffsetEncoding.axis) + ";"); + result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.bwAxisScaleOffsetEncoding.bitwidth = " + std::to_string(quantizeParams.bwAxisScaleOffsetEncoding.bitwidth) + ";"); + result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.bwAxisScaleOffsetEncoding.numElements = " + std::to_string(quantizeParams.bwAxisScaleOffsetEncoding.numElements) + ";"); + result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.bwAxisScaleOffsetEncoding.scales = " + tensorNameSymbol + "_bwaxis_scale;"); + if(quantizeParams.bwAxisScaleOffsetEncoding.offsets != nullptr) + result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.bwAxisScaleOffsetEncoding.offset = " + tensorNameSymbol + "_bwaxis_offset;"); + return result; + } + + if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION){ + result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.encodingDefinition = QNN_DEFINITION_DEFINED;"); + result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION;"); + result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.blockwiseExpansion = &" + tensorNameSymbol + "_blockwiseExpansion;"); + return result; + } + MNN_ERROR("MNN_QNN: Unknown QuantizeParams.\n"); diff --git a/source/backend/qnn/convertor/QNNConvertor.hpp b/source/backend/qnn/convertor/QNNConvertor.hpp index 6864e233..afd2c875 100644 --- a/source/backend/qnn/convertor/QNNConvertor.hpp +++ b/source/backend/qnn/convertor/QNNConvertor.hpp @@ -69,7 +69,8 @@ private: static std::string MapDataType(Qnn_DataType_t dataType); static std::string TranslateDimensionsArray(const std::string & dimensionsNameSymbol, uint32_t rank, const uint32_t * dimensions); static std::string TranslateParamDataArray(const std::string & dataNameSymbol, Qnn_DataType_t dataType, const Qnn_ClientBuffer_t & clientBuf); - static std::vector TranslateTensorQuantizeParams(const std::string tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParmas); + static std::vector TranslateQuantizeScaleOffsetDataArray(const std::string & tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParams, uint32_t rank, const uint32_t * dimensions); + static std::vector TranslateTensorQuantizeParams(const std::string tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParams); static std::vector TranslateTensorClientBuf(const std::string & tensorNameSymbol, const std::string & dataNameSymbol, const std::string & sname, const Qnn_ClientBuffer_t & clientBuf, bool hasClientBuf, bool isParam); // Utility functions used by TranslateNode. 
static std::vector TranslateNodeParamArray(const std::string & nodeName,const std::string & paramArraySymbol, uint32_t numOfParams, const Qnn_Param_t * params); diff --git a/source/backend/qnn/execution/QNNConvDepthwise.cpp b/source/backend/qnn/execution/QNNConvDepthwise.cpp index ac5fb71b..c273e111 100644 --- a/source/backend/qnn/execution/QNNConvDepthwise.cpp +++ b/source/backend/qnn/execution/QNNConvDepthwise.cpp @@ -11,6 +11,157 @@ namespace MNN { namespace QNN { +void QNNConvDepthwise::isWeightQuantSupported(const Tensor *input, const int oc){ + Qnn_DataType_t dataType = mBackend->getNativeTensor(input)->v1.dataType; + if(mOp->main_as_Convolution2D()->quanParameter() == nullptr){ + mWeightQuant = false; + return; + }else{ + bool hasBais = false; + auto bias = mOp->main_as_Convolution2D()->bias(); + auto biasPtr = (float*)bias->data(); + for(int i = 0; i < oc; ++i){ + if(biasPtr[i] != 0.0f){ + hasBais = true; + break; + } + } + + std::shared_ptr quanCommon = ConvolutionCommon::load(mOp, this->backend(), false, true); + if(quanCommon->asymmetric || dataType == QNN_DATATYPE_FLOAT_16 || dataType == QNN_DATATYPE_FLOAT_32){ + // not support asymmetric and mBlockSize > 1 results incorrect now + mWeightQuant = false; + return; + } + + float inputScale = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.scale; + int inputOffset = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.offset; + if(inputOffset == 0){ + mWeightQuant = true; + }else{ + if(hasBais){ + mWeightQuant = false; + }else{ + mWeightQuant = true; + } + } + } +} + +ErrorCode QNNConvDepthwise::onEncodeQuantDequantDepthConv(Tensor *input, Tensor *output, const int n, const int ic, const int oc) { + auto conv2D = mOp->main_as_Convolution2D(); + auto common = conv2D->common(); + Qnn_DataType_t dataType = QNN_DATATYPE_FLOAT_32; + if(mBackend->getUseFP16()){ + dataType = QNN_DATATYPE_FLOAT_16; + } + + // create dequant input stage tensor + this->createStageTensor("DequantInput", dataType, getNHWCShape(input)); // mTempTensorWrappers[2] + this->createStageTensor("QuantOutput", dataType, getNHWCShape(output)); // mTempTensorWrappers[3] + + // add nodes + { + // dequant input + { + mParams.clear(); + mInputs.clear(); + mOutputs.clear(); + mNodeType = "Dequantize"; + std::string name = mNodeName + "_dequant_input"; + + mInputs.push_back(*(mBackend->getNativeTensor(input))); // input + mOutputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput + mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + } + + if (common->relu() || common->relu6()) { + this->createStageTensor("ReluTensor", dataType, getNHWCShape(output)); // mTempTensorWrappers[4] + // Stage one + { + mParams.clear(); + mInputs.clear(); + mOutputs.clear(); + mNodeType = "DepthWiseConv2d"; + std::string name = mNodeName + "_convDepthwise"; + mParams.push_back(*(mParamTensorWrappers[0]->getNativeParam())); // stride + mParams.push_back(*(mParamTensorWrappers[1]->getNativeParam())); // pad_amount + mParams.push_back(*(mParamTensorWrappers[2]->getNativeParam())); // dilation + + mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput + mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // weight + mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // bias + + mOutputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // ReluTensor + mBackend->addNodeToGraph(mOpConfigVersion, 
name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + } + + // Stage two + { + mParams.clear(); + mInputs.clear(); + mOutputs.clear(); + mNodeType = common->relu6() ? "ReluMinMax" : "Relu"; + std::string name = mNodeName + "_relu"; + if (common->relu6()) { + mParams.push_back(*(mParamScalarWrappers[1]->getNativeParam())); // min_value + mParams.push_back(*(mParamScalarWrappers[2]->getNativeParam())); // max_value + } + mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // ReluTensor + mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput + mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + } + + } else { + mParams.clear(); + mInputs.clear(); + mOutputs.clear(); + mNodeType = "DepthWiseConv2d"; + mParams.push_back(*(mParamTensorWrappers[0]->getNativeParam())); // stride + mParams.push_back(*(mParamTensorWrappers[1]->getNativeParam())); // pad_amount + mParams.push_back(*(mParamTensorWrappers[2]->getNativeParam())); // dilation + + mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput + mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // weight + mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // bias + + mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput + mBackend->addNodeToGraph(mOpConfigVersion, mNodeName.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + } + + // Quant output + { + auto QuantOutputTensor = mTempTensorWrappers[3]->getNativeTensor(); + if(mBackend->getUseFP16()){ + this->createStageTensor("CastOutput", QNN_DATATYPE_FLOAT_32, getNHWCShape(output)); + + mParams.clear(); + mInputs.clear(); + mOutputs.clear(); + mNodeType = "Cast"; + std::string name = mNodeName + "_Cast_Output"; + + mInputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput + mOutputs.push_back(*(mTempTensorWrappers.back()->getNativeTensor())); // CastOutput + mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + QuantOutputTensor = mTempTensorWrappers.back()->getNativeTensor(); + } + { + mParams.clear(); + mInputs.clear(); + mOutputs.clear(); + mNodeType = "Quantize"; + std::string name = mNodeName + "_Quant_Output"; + + mInputs.push_back(*(QuantOutputTensor)); // stage tensor + mOutputs.push_back(*(mBackend->getNativeTensor(output))); // output + mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + } + } + } + return NO_ERROR; +} + ErrorCode QNNConvDepthwise::onEncode(const std::vector &inputs, const std::vector &outputs) { auto conv2D = mOp->main_as_Convolution2D(); auto common = conv2D->common(); @@ -35,7 +186,8 @@ ErrorCode QNNConvDepthwise::onEncode(const std::vector &inputs, const padTop = std::get<1>(pads); padBottom = std::get<3>(pads); padLeft = std::get<0>(pads); padRight = std::get<2>(pads); dilationH = common->dilateY(); dilationW = common->dilateX(); } - + + isWeightQuantSupported(inputs[0], oc); // create all tensors and params { std::vector strideData = {(uint32_t)strideH, (uint32_t)strideW}; @@ -49,10 +201,24 @@ ErrorCode QNNConvDepthwise::onEncode(const std::vector &inputs, const this->createParamScalar("max_value", 6.0f); } - this->createWeight(dataType, oc, kernelH, kernelW); - this->createBias(dataType, oc); + this->createWeightAndBias(dataType, 
inputs[0], oc, kernelH, kernelW); + // dequant input and quant output + if(mWeightQuant == false && dataType != QNN_DATATYPE_FLOAT_16 && dataType != QNN_DATATYPE_FLOAT_32){ + return this->onEncodeQuantDequantDepthConv(inputs[0], outputs[0], n, ic, oc); + } + if (common->relu() || common->relu6()) { - this->createStageTensor("ReluTensor", dataType, getNHWCShape(outputs[0])); + Qnn_QuantizeParams_t quantize = DEFAULT_QUANTIZE_PARAMS; + Qnn_ScaleOffset_t tScaleOffsetEncoding; + auto quant = TensorUtils::getDescribe(outputs[0])->quantAttr.get(); + if(quant != nullptr && TensorUtils::getDescribe(outputs[0])->type == DataType_DT_INT8){ + quantize.encodingDefinition = QNN_DEFINITION_DEFINED; + quantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; + tScaleOffsetEncoding.scale = quant->scale; + tScaleOffsetEncoding.offset = quant->zero; + quantize.scaleOffsetEncoding = tScaleOffsetEncoding; + } + this->createStageTensor("ReluTensor", dataType, getNHWCShape(outputs[0]), quantize); } } @@ -112,43 +278,140 @@ ErrorCode QNNConvDepthwise::onEncode(const std::vector &inputs, const -void QNNConvDepthwise::createWeight(Qnn_DataType_t dataType, int oc, int kernelH, int kernelW) { - std::vector weightData; - const float* source = nullptr; - int weightElementNum = 0; - std::shared_ptr quanWeight; - ConvolutionCommon::getConvParameters(&quanWeight, mBackend, mOp, &source, &weightElementNum); - // oc ic h w ---> h w ic oc - weightData.resize(weightElementNum); - convertWeight(source, (float *) weightData.data(), oc, kernelH, kernelW); - this->createStaticFloatTensor("weight", dataType, {(uint32_t)kernelH, (uint32_t)kernelW, 1, (uint32_t)oc}, weightData.data()); -} +void QNNConvDepthwise::createWeightAndBias(Qnn_DataType_t dataType, const Tensor *input, int oc, int kernelH, int kernelW) { + if(mWeightQuant){ + Qnn_QuantizeParams_t weightQuantize{}; + std::shared_ptr quanCommon = ConvolutionCommon::load(mOp, this->backend(), false, true); + // [TODO] Support asymmetric and other quantBits. 
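The depthwise isWeightQuantSupported above gates the quantized-weight path: it requires a quantized weight blob, symmetric quantization, a fixed-point input, and either a zero input offset or a bias-free layer (a nonzero input zero point would otherwise have to be folded into the bias). A condensed sketch of that decision, assuming the flags have already been extracted from the op and the QNN input tensor (QNNConvolution adds further block-size and 1x1 conditions on top of this), is:

// Condensed sketch of the weight-quantization gate used above (assumption:
// the caller has already read these flags; asymmetric weights or a float
// input always fall back to the float weight path).
struct ConvQuantInfo {
    bool hasQuantizedWeights; // quanParameter() != nullptr
    bool asymmetricWeights;   // quanCommon->asymmetric
    bool floatInput;          // input tensor is FP16/FP32 rather than fixed point
    int  inputZeroPoint;      // scaleOffsetEncoding.offset of the input
    bool hasNonZeroBias;      // any bias element != 0
};

inline bool useQuantizedWeights(const ConvQuantInfo& info) {
    if (!info.hasQuantizedWeights || info.asymmetricWeights || info.floatInput) {
        return false;
    }
    // With a zero input offset the integer conv needs no offset correction;
    // otherwise only bias-free layers take the quantized path.
    return info.inputZeroPoint == 0 || !info.hasNonZeroBias;
}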
+ MNN_ASSERT(!quanCommon->asymmetric); -void QNNConvDepthwise::createBias(Qnn_DataType_t dataType, int oc) { - int biasElementNum = oc; - std::vector biasData; - biasData.resize(biasElementNum, 0); - auto bias = mOp->main_as_Convolution2D()->bias(); - if (nullptr != bias) { - ::memcpy((void *)biasData.data(), (void *)bias->data(), biasElementNum * sizeof(float)); - } - this->createStaticFloatTensor("bias", dataType, {(uint32_t)oc}, biasData.data()); -} - - -// oc, h, w ---> h, w, oc -void QNNConvDepthwise::convertWeight(const float * src, float * dst, int oc, int kernelH, int kernelW) { - for (int c = 0; c < oc; c++) { - for (int h = 0; h < kernelH; h++) { - for (int w = 0; w < kernelW; w++) { - int srcOffset = w + kernelW * (h + kernelH * c); - int dstOffset = c + oc * (w + kernelW * h); - dst[dstOffset] = src[srcOffset]; + // create weight + const int8_t * source = quanCommon->weight.get(); + std::vector quantWeightData(oc * kernelH * kernelW, 0); + if(quanCommon->canUseInt4){ + for (int c = 0; c < oc; c++) { + for (int h = 0; h < kernelH; h++) { + for (int w = 0; w < kernelW; w++) { + int srcOffset = w + kernelW * (h + kernelH * c); + int dstOffset = c + oc * (w + kernelW * h); + if(srcOffset % 2 == 0){ + quantWeightData[dstOffset] = ((source[srcOffset / 2] >> 4) & 0x0f) - 8; + }else{ + quantWeightData[dstOffset] = (source[srcOffset / 2] & 0x0f) - 8; + } + } + } } + }else{ + convertWeight(source, (int8_t *) quantWeightData.data(), oc, kernelH, kernelW); } + + mDequantAlpha = quanCommon->alpha.get(); + if(quanCommon->canUseInt4){ + weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED; + weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET; + Qnn_BwAxisScaleOffset_t weightBWAxisScaleOffsetEncoding{}; + weightBWAxisScaleOffsetEncoding.bitwidth = 4; + weightBWAxisScaleOffsetEncoding.axis = 3; + weightBWAxisScaleOffsetEncoding.numElements = oc; + mScale.resize(oc); + std::vector OffsetData(oc); + for (int i = 0; i < oc; i++) { + mScale[i] = mDequantAlpha[i]; + } + weightBWAxisScaleOffsetEncoding.scales = mScale.data(); + weightQuantize.bwAxisScaleOffsetEncoding = weightBWAxisScaleOffsetEncoding; + + this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)kernelH, (uint32_t)kernelW, 1, (uint32_t)oc}, (void *) quantWeightData.data(), weightQuantize); + std::function mReleaseWeightScaleOffset = [&](){ + std::vector().swap(mScale); + }; + mBackend->pushReleaseFunc(mReleaseWeightScaleOffset); + }else{ + weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED; + weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET; + Qnn_AxisScaleOffset_t weightAxisScaleOffsetEncoding{}; + weightAxisScaleOffsetEncoding.axis = 3; + weightAxisScaleOffsetEncoding.numScaleOffsets = oc; + mScaleOffsetData.resize(oc); + for (int i = 0; i < oc; i++) { + mScaleOffsetData[i].scale = mDequantAlpha[i]; + mScaleOffsetData[i].offset = 0; + } + weightAxisScaleOffsetEncoding.scaleOffset = mScaleOffsetData.data(); + weightQuantize.axisScaleOffsetEncoding = weightAxisScaleOffsetEncoding; + + this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)kernelH, (uint32_t)kernelW, 1, (uint32_t)oc}, (void *) quantWeightData.data(), weightQuantize); + + std::function mReleaseWeightScaleOffset = [&](){ + std::vector().swap(mScaleOffsetData); + }; + mBackend->pushReleaseFunc(mReleaseWeightScaleOffset); + } + // create bias + { + float inputScale = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.scale; 
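For int4 weights the packed buffer above stores two signed 4-bit values per byte; the loop splits each byte into high and low nibbles and re-centers them to [-8, 7] while transposing the layout. A standalone sketch of just the nibble decoding, assuming the same packing convention (even index in the high nibble, values stored as unsigned nibbles with a +8 offset), is:

// Minimal sketch of the int4 weight decoding used above (assumption: two
// weights per byte, even index in the high nibble, decoded range [-8, 7]).
#include <cstdint>
#include <vector>

std::vector<int8_t> unpackInt4Weights(const int8_t* packed, int count) {
    std::vector<int8_t> out(count);
    for (int i = 0; i < count; ++i) {
        const uint8_t byte = static_cast<uint8_t>(packed[i / 2]);
        const uint8_t nibble = (i % 2 == 0) ? (byte >> 4) : (byte & 0x0F);
        out[i] = static_cast<int8_t>(static_cast<int>(nibble) - 8);
    }
    return out;
}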
+ int inputOffset = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.offset; + std::vector biasData; + biasData.resize(oc, 0); + + Qnn_QuantizeParams_t biasQuantize{}; + biasQuantize.encodingDefinition = QNN_DEFINITION_DEFINED; + biasQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET; + Qnn_AxisScaleOffset_t biasAxisScaleOffsetEncoding{}; + biasAxisScaleOffsetEncoding.axis = 0; + biasAxisScaleOffsetEncoding.numScaleOffsets = oc; + mBiasScaleOffsetData.resize(oc); + + auto bias = mOp->main_as_Convolution2D()->bias(); + auto biasPtr = (float*)bias->data(); + if (nullptr != bias) { + for(int i = 0; i < oc; ++i){ + float biasScale = inputScale * mDequantAlpha[i]; + mBiasScaleOffsetData[i].scale = 0.f; + mBiasScaleOffsetData[i].offset = 0; + if(fabs(biasPtr[i]) < 0.000001 || fabs(biasScale) < 0.000001){ + biasData[i] = 0; + } else{ + biasData[i] = (int)(biasPtr[i] / (biasScale)); + } + } + } + biasAxisScaleOffsetEncoding.scaleOffset = mBiasScaleOffsetData.data(); + biasQuantize.axisScaleOffsetEncoding = biasAxisScaleOffsetEncoding; + + this->createStaticTensor("bias", QNN_DATATYPE_SFIXED_POINT_32, {(uint32_t)oc}, biasData.data(), biasQuantize); + std::function mReleaseBiasScaleOffset = [&](){ + std::vector().swap(mBiasScaleOffsetData); + }; + mBackend->pushReleaseFunc(mReleaseBiasScaleOffset); + } + }else{ + Qnn_DataType_t floatDatatype = QNN_DATATYPE_FLOAT_32; + if(mBackend->getUseFP16()){ + floatDatatype = QNN_DATATYPE_FLOAT_16; + } + std::vector weightData; + const float* source = nullptr; + int weightElementNum = 0; + std::shared_ptr quanWeight; + ConvolutionCommon::getConvParameters(&quanWeight, mBackend, mOp, &source, &weightElementNum); + // oc ic h w ---> h w ic oc + weightData.resize(weightElementNum); + convertWeight(source, (float *) weightData.data(), oc, kernelH, kernelW); + this->createStaticFloatTensor("weight", floatDatatype, {(uint32_t)kernelH, (uint32_t)kernelW, 1, (uint32_t)oc}, weightData.data()); + + // create bias + std::vector biasData; + biasData.resize(oc, 0); + auto bias = mOp->main_as_Convolution2D()->bias(); + if (nullptr != bias) { + ::memcpy((void *)biasData.data(), (void *)bias->data(), oc * sizeof(float)); + } + this->createStaticFloatTensor("bias", floatDatatype, {(uint32_t)oc}, biasData.data()); } -} +} class QNNConvDepthwiseCreator : public QnnBackend::Creator { public: diff --git a/source/backend/qnn/execution/QNNConvDepthwise.hpp b/source/backend/qnn/execution/QNNConvDepthwise.hpp index 1d27c87c..f85f0c84 100644 --- a/source/backend/qnn/execution/QNNConvDepthwise.hpp +++ b/source/backend/qnn/execution/QNNConvDepthwise.hpp @@ -19,10 +19,28 @@ class QNNConvDepthwise : public QNNCommonExecution { public: QNNConvDepthwise(Backend *backend, const Op *op) : QNNCommonExecution(backend, op) {} virtual ErrorCode onEncode(const std::vector &inputs, const std::vector &outputs) override; + ErrorCode onEncodeQuantDequantDepthConv(Tensor *input, Tensor *output, const int n, const int ic, const int oc); private: - void createWeight(Qnn_DataType_t dataType, int oc, int kernelH, int kernelW); - void createBias(Qnn_DataType_t dataType, int oc); - void convertWeight(const float * src, float * dst, int oc, int kernelH, int kernelW); +template + void convertWeight(const T * src, T * dst, int oc, int kernelH, int kernelW) { + for (int c = 0; c < oc; c++) { + for (int h = 0; h < kernelH; h++) { + for (int w = 0; w < kernelW; w++) { + int srcOffset = w + kernelW * (h + kernelH * c); + int dstOffset = c + oc * (w + kernelW * h); + 
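The bias handling above follows the usual convolution quantization identity: the int32 accumulator of a quantized conv has scale inputScale * weightScale[c], so a float bias must be divided by that product before being stored as SFIXED_POINT_32 (the patch truncates with a plain cast and zeroes the bias when either factor is near zero). A small worked example with made-up numbers:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical values for one output channel, for illustration only.
    const float inputScale  = 0.02f;   // activation scale
    const float weightScale = 0.005f;  // per-channel weight scale (alpha[c])
    const float biasFloat   = 0.37f;   // original float bias

    const float biasScale = inputScale * weightScale; // scale of the int32 accumulator
    const int32_t biasQ = static_cast<int32_t>(std::lround(biasFloat / biasScale));

    // Dequantizing recovers the original bias up to rounding error.
    std::printf("biasQ = %d, biasQ * biasScale = %f\n", biasQ, biasQ * biasScale);
    return 0;
}
```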
dst[dstOffset] = src[srcOffset]; + } + } + } + } + void isWeightQuantSupported(const Tensor *input, const int oc); + void createWeightAndBias(Qnn_DataType_t dataType, const Tensor *input, int oc, int kernelH, int kernelW); + std::vector mScale; + std::vector mScaleOffsetData; + std::vector mBiasScaleOffsetData; + std::vector mBlockScale; + float *mDequantAlpha = nullptr; + bool mWeightQuant = false; }; } // end namespace QNN diff --git a/source/backend/qnn/execution/QNNConvolution.cpp b/source/backend/qnn/execution/QNNConvolution.cpp index 68a61c86..18dae908 100644 --- a/source/backend/qnn/execution/QNNConvolution.cpp +++ b/source/backend/qnn/execution/QNNConvolution.cpp @@ -21,6 +21,63 @@ static std::pair closest_factors(int n) { } return {1, n}; } + +void QNNConvolution::isWeightQuantSupported(const Tensor *input, const int ic, const int oc){ + Qnn_DataType_t dataType = mBackend->getNativeTensor(input)->v1.dataType; + if(mOp->main_as_Convolution2D()->quanParameter() == nullptr){ + mWeightQuant = false; + return; + }else{ + bool hasBias = false; + auto bias = mOp->main_as_Convolution2D()->bias(); + auto biasPtr = (float*)bias->data(); + for(int i = 0; i < oc; ++i){ + if(biasPtr[i] != 0.0f){ + hasBias = true; + break; + } + } + + std::shared_ptr quanCommon = ConvolutionCommon::load(mOp, this->backend(), false, true); + int totalCount = quanCommon->alpha.size(); + mBlockSize = totalCount / oc; + if(quanCommon->asymmetric){ + // not support asymmetric and mBlockSize > 1 results incorrect now + mWeightQuant = false; + return; + } + + if(dataType == QNN_DATATYPE_FLOAT_16 || dataType == QNN_DATATYPE_FLOAT_32){ + if(mIsMatMul && mBlockSize == 1){ + mWeightQuant = true; + }else{ + mWeightQuant = false; + } + return; + } + + float inputScale = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.scale; + int inputOffset = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.offset; + if(inputOffset == 0){ + mWeightQuant = true; + }else{ + if(hasBias){ + mWeightQuant = false; + }else{ + mWeightQuant = true; + } + } + + if(mBlockSize > 1 && mWeightQuant){ + if(mIs1x1Conv && hasBias == false && (ic / mBlockSize) >= 16){ + mWeightQuant = true; + }else{ + mWeightQuant = false; + } + } + } +} + ErrorCode QNNConvolution::onEncode(const std::vector &inputs, const std::vector &outputs) { auto conv2D = mOp->main_as_Convolution2D(); auto common = conv2D->common(); @@ -46,32 +103,16 @@ ErrorCode QNNConvolution::onEncode(const std::vector &inputs, const st dilationH = common->dilateY(); dilationW = common->dilateX(); group = common->group(); } - - const float * weightSource = nullptr; - std::shared_ptr quanCommon; - if (mOp->main_as_Convolution2D()->quanParameter()) { - bool forceFloat = (common->kernelX() == 1 && common->kernelY() == 1) ? false : true; - - quanCommon = ConvolutionCommon::load(mOp, this->backend(), forceFloat); - if (quanCommon->weightFloat.get() == nullptr) { - // [TODO] Support asymmetric and other quantBits. 
- // isQuantWeight && symmetric quantization && int8 quantization && 1x1 conv - if (quanCommon->asymmetric || quanCommon->canUseInt4) { - return NOT_SUPPORT; - } - return this->onEncodeQuant(inputs[0], outputs[0], n, ih, iw, ic, oc, quanCommon); - } else { - weightSource = quanCommon->weightFloat.get(); - } - } else { - int weightElementNum; - ConvolutionCommon::getConvParameters(&quanCommon, mBackend, mOp, &weightSource, &weightElementNum); + mIs1x1Conv = kernelW==1 && strideH==1 && \ + strideW==1 && dilationH==1 && dilationW==1 && group==1 && \ + padTop==0 && padBottom==0 && padLeft==0 && padRight==0; + mIsMatMul = ih==1 && iw==1 && oh==1 && ow==1 && mIs1x1Conv; + isWeightQuantSupported(inputs[0], ic, oc); + + if(mIsMatMul && mWeightQuant && (dataType == QNN_DATATYPE_FLOAT_16 || dataType == QNN_DATATYPE_FLOAT_32)){ + return onEncodeFpAIntBMatMul(inputs[0], outputs[0], n, ih, iw, ic, oc); } - - #ifdef QNN_VERBOSE - MNN_PRINT("n:%d, ih:%d, iw:%d, ic:%d, oh:%d, ow:%d, oc:%d, kernelH:%d, kernelW:%d, dilationH:%d, dilationW:%d, strideH:%d, strideW:%d, group:%d, pad:%d %d %d %d\n", n, ih, iw, ic, oh, ow, oc, kernelH, kernelW, dilationH, \ - dilationW, strideH, strideW, group, padTop, padBottom, padLeft, padRight); - #endif + // create all tensors and params { std::vector strideData = {(uint32_t)strideH, (uint32_t)strideW}; @@ -85,12 +126,26 @@ ErrorCode QNNConvolution::onEncode(const std::vector &inputs, const st this->createParamScalar("min_value", 0.0f); this->createParamScalar("max_value", 6.0f); } + } - this->createWeight(dataType, oc, ic, kernelH, kernelW, group, weightSource); - this->createBias(dataType, oc); - if (common->relu() || common->relu6()) { - this->createStageTensor("ReluTensor", dataType, getNHWCShape(outputs[0])); + this->createWeightAndBias(dataType, inputs[0], oc, ic, kernelH, kernelW, group); + // dequant input and quant output + if(mWeightQuant == false && dataType != QNN_DATATYPE_FLOAT_16 && dataType != QNN_DATATYPE_FLOAT_32){ + return this->onEncodeQuantDequantConv(inputs[0], outputs[0], n, ic, oc); + } + + if (common->relu() || common->relu6()) { + Qnn_QuantizeParams_t quantize = DEFAULT_QUANTIZE_PARAMS; + Qnn_ScaleOffset_t tScaleOffsetEncoding; + auto quant = TensorUtils::getDescribe(outputs[0])->quantAttr.get(); + if(quant != nullptr && TensorUtils::getDescribe(outputs[0])->type == DataType_DT_INT8){ + quantize.encodingDefinition = QNN_DEFINITION_DEFINED; + quantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; + tScaleOffsetEncoding.scale = quant->scale; + tScaleOffsetEncoding.offset = quant->zero; + quantize.scaleOffsetEncoding = tScaleOffsetEncoding; } + this->createStageTensor("ReluTensor", dataType, getNHWCShape(outputs[0]), quantize); } // add nodes @@ -131,13 +186,34 @@ ErrorCode QNNConvolution::onEncode(const std::vector &inputs, const st } } else { - bool isMatmul = ih==1 && iw==1 && oh==1 && ow==1 && kernelH==1 && kernelW==1 && strideH==1 && \ - strideW==1 && dilationH==1 && dilationW==1 && group==1 && \ - padTop==0 && padBottom==0 && padLeft==0 && padRight==0; - if(isMatmul && n > 1) { + if(mIsMatMul && n > 1) { auto num = closest_factors(n); - this->createStageTensor("InputReshapeTensor", dataType, std::vector({1, num.first, num.second, ic})); - this->createStageTensor("OutputReshapeTensor", dataType, std::vector({1, num.first, num.second, oc})); + { + Qnn_QuantizeParams_t quantize = DEFAULT_QUANTIZE_PARAMS; + Qnn_ScaleOffset_t tScaleOffsetEncoding; + auto quant = TensorUtils::getDescribe(inputs[0])->quantAttr.get(); + if(quant != 
nullptr && TensorUtils::getDescribe(inputs[0])->type == DataType_DT_INT8){ + quantize.encodingDefinition = QNN_DEFINITION_DEFINED; + quantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; + tScaleOffsetEncoding.scale = quant->scale; + tScaleOffsetEncoding.offset = quant->zero; + quantize.scaleOffsetEncoding = tScaleOffsetEncoding; + } + this->createStageTensor("InputReshapeTensor", dataType, std::vector({1, num.first, num.second, ic}), quantize); + } + { + Qnn_QuantizeParams_t quantize = DEFAULT_QUANTIZE_PARAMS; + Qnn_ScaleOffset_t tScaleOffsetEncoding; + auto quant = TensorUtils::getDescribe(outputs[0])->quantAttr.get(); + if(quant != nullptr && TensorUtils::getDescribe(outputs[0])->type == DataType_DT_INT8){ + quantize.encodingDefinition = QNN_DEFINITION_DEFINED; + quantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; + tScaleOffsetEncoding.scale = quant->scale; + tScaleOffsetEncoding.offset = quant->zero; + quantize.scaleOffsetEncoding = tScaleOffsetEncoding; + } + this->createStageTensor("OutputReshapeTensor", dataType, std::vector({1, num.first, num.second, oc}), quantize); + } #ifdef QNN_VERBOSE MNN_PRINT("Matmul2Conv, start reshape batch:%d -> %dx%d\n", n, num.first, num.second); #endif @@ -208,52 +284,273 @@ ErrorCode QNNConvolution::onEncode(const std::vector &inputs, const st return NO_ERROR; } +ErrorCode QNNConvolution::onEncodeQuantDequantConv(Tensor *input, Tensor *output, const int n, const int ic, const int oc) { + auto conv2D = mOp->main_as_Convolution2D(); + auto common = conv2D->common(); + Qnn_DataType_t dataType = QNN_DATATYPE_FLOAT_32; + if(mBackend->getUseFP16()){ + dataType = QNN_DATATYPE_FLOAT_16; + } + + // create dequant input stage tensor + this->createStageTensor("DequantInput", dataType, getNHWCShape(input)); // mTempTensorWrappers[2] + this->createStageTensor("QuantOutput", dataType, getNHWCShape(output)); // mTempTensorWrappers[3] + + // add nodes + { + // dequant input + { + mParams.clear(); + mInputs.clear(); + mOutputs.clear(); + mNodeType = "Dequantize"; + std::string name = mNodeName + "_dequant_input"; + + mInputs.push_back(*(mBackend->getNativeTensor(input))); // input + mOutputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput + mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + } + + if (common->relu() || common->relu6()) { + this->createStageTensor("ReluTensor", dataType, getNHWCShape(output)); // mTempTensorWrappers[4] + // Stage one + { + mParams.clear(); + mInputs.clear(); + mOutputs.clear(); + mNodeType = "Conv2d"; + std::string name = mNodeName + "_conv"; + mParams.push_back(*(mParamTensorWrappers[0]->getNativeParam())); // stride + mParams.push_back(*(mParamTensorWrappers[1]->getNativeParam())); // pad_amount + mParams.push_back(*(mParamTensorWrappers[2]->getNativeParam())); // dilation + mParams.push_back(*(mParamScalarWrappers[0]->getNativeParam())); // group + + mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput + mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // weight + mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // bias + + mOutputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // ReluTensor + mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + } -ErrorCode QNNConvolution::onEncodeQuant(Tensor * input, Tensor * output, int n, int h, int w, int ic, 
int oc, std::shared_ptr quanCommon) { + // Stage two + { + mParams.clear(); + mInputs.clear(); + mOutputs.clear(); + mNodeType = common->relu6() ? "ReluMinMax" : "Relu"; + std::string name = mNodeName + "_relu"; + if (common->relu6()) { + mParams.push_back(*(mParamScalarWrappers[1]->getNativeParam())); // min_value + mParams.push_back(*(mParamScalarWrappers[2]->getNativeParam())); // max_value + } + mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // ReluTensor + mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput + mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + } + + } else { + if(mIsMatMul && n > 1) { + auto num = closest_factors(n); + this->createStageTensor("InputReshapeTensor", dataType, std::vector({1, num.first, num.second, ic})); // mTempTensorWrappers[4] + this->createStageTensor("OutputReshapeTensor", dataType, std::vector({1, num.first, num.second, oc})); // mTempTensorWrappers[5] + #ifdef QNN_VERBOSE + MNN_PRINT("Matmul2Conv, start reshape batch:%d -> %dx%d\n", n, num.first, num.second); + #endif + // reshape input + { + std::string name = mNodeName + "_input_reshape"; + mParams.clear(); + mInputs.clear(); + mOutputs.clear(); + mNodeType = "Reshape"; + + mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput + mOutputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // InputReshapeTensor + mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + } + // conv2d + { + std::string name = mNodeName; + mParams.clear(); + mInputs.clear(); + mOutputs.clear(); + mNodeType = "Conv2d"; + + mParams.push_back(*(mParamTensorWrappers[0]->getNativeParam())); // stride + mParams.push_back(*(mParamTensorWrappers[1]->getNativeParam())); // pad_amount + mParams.push_back(*(mParamTensorWrappers[2]->getNativeParam())); // dilation + mParams.push_back(*(mParamScalarWrappers[0]->getNativeParam())); // group + + mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // InputReshapeTensor + mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // weight + mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // bias + + mOutputs.push_back(*(mTempTensorWrappers[5]->getNativeTensor())); // OutputReshapeTensor + mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + } + + // reshape output + { + std::string name = mNodeName + "_output_reshape"; + mParams.clear(); + mInputs.clear(); + mOutputs.clear(); + mNodeType = "Reshape"; + + mInputs.push_back(*(mTempTensorWrappers[5]->getNativeTensor())); // OutputReshapeTensor + mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput + mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + } + } else{ + mParams.clear(); + mInputs.clear(); + mOutputs.clear(); + mNodeType = "Conv2d"; + mParams.push_back(*(mParamTensorWrappers[0]->getNativeParam())); // stride + mParams.push_back(*(mParamTensorWrappers[1]->getNativeParam())); // pad_amount + mParams.push_back(*(mParamTensorWrappers[2]->getNativeParam())); // dilation + mParams.push_back(*(mParamScalarWrappers[0]->getNativeParam())); // group + + mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput + 
mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // weight + mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // bias + + mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput + mBackend->addNodeToGraph(mOpConfigVersion, mNodeName.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + } + } + + // Quant output + { + auto QuantOutputTensor = mTempTensorWrappers[3]->getNativeTensor(); + if(mBackend->getUseFP16()){ + this->createStageTensor("CastOutput", QNN_DATATYPE_FLOAT_32, getNHWCShape(output)); + + mParams.clear(); + mInputs.clear(); + mOutputs.clear(); + mNodeType = "Cast"; + std::string name = mNodeName + "_Cast_Output"; + + mInputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput + mOutputs.push_back(*(mTempTensorWrappers.back()->getNativeTensor())); // CastOutput + mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + QuantOutputTensor = mTempTensorWrappers.back()->getNativeTensor(); + } + { + mParams.clear(); + mInputs.clear(); + mOutputs.clear(); + mNodeType = "Quantize"; + std::string name = mNodeName + "_Quant_Output"; + + mInputs.push_back(*(QuantOutputTensor)); // stage tensor + mOutputs.push_back(*(mBackend->getNativeTensor(output))); // output + mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + } + } + } + return NO_ERROR; +} + +ErrorCode QNNConvolution::onEncodeFpAIntBMatMul(Tensor * input, Tensor * output, int n, int h, int w, int ic, int oc) { // create parameters and stage tensors + auto conv2D = mOp->main_as_Convolution2D(); + auto common = conv2D->common(); + Qnn_DataType_t dataType = mBackend->getNativeTensor(input)->v1.dataType; { bool transposeWeightFlag = true; this->createParamScalar("transpose_in1", transposeWeightFlag); - + std::vector tempInputShape = {(uint32_t) n * h * w , (uint32_t) ic}; std::vector tempOutputShape = {(uint32_t) n * h * w , (uint32_t) oc}; - this->createStageTensor("tempInput", QNN_DATATYPE_FLOAT_16, tempInputShape); - this->createStageTensor("tempOutput", QNN_DATATYPE_FLOAT_16, tempOutputShape); + this->createStageTensor("tempInput", dataType, tempInputShape); + this->createStageTensor("tempOutput", dataType, tempOutputShape); - // create weight - const int8_t * source = quanCommon->weight.get(); - std::vector quantWeightData(oc * ic, 0); - ::memcpy(quantWeightData.data(), source, oc * ic * sizeof(int8_t)); - - float * dequantAlpha = quanCommon->alpha.get(); - - Qnn_QuantizeParams_t weightQuantize{}; - weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED; - weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET; - Qnn_AxisScaleOffset_t weightAxisScaleOffsetEncoding{}; - weightAxisScaleOffsetEncoding.axis = 0; - weightAxisScaleOffsetEncoding.numScaleOffsets = oc; - std::vector scaleOffsetData(oc); - for (int i = 0; i < oc; i++) { - if (quanCommon->asymmetric) { - scaleOffsetData[i].scale = dequantAlpha[2 * i + 1]; - // scaleOffsetData[i].offset = (int) dequantAlpha[2 * i + 0]; - scaleOffsetData[i].offset = 0; - } else { - scaleOffsetData[i].scale = dequantAlpha[i]; - scaleOffsetData[i].offset = 0; + // create weight and bias + { + Qnn_QuantizeParams_t weightQuantize{}; + std::shared_ptr quanCommon = ConvolutionCommon::load(mOp, this->backend(), false, true); + MNN_ASSERT(!quanCommon->asymmetric); + const int8_t * source = quanCommon->weight.get(); + 
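Conceptually, the FullyConnected node assembled in onEncodeFpAIntBMatMul keeps activations in float while the weight tensor is int8 (or unpacked int4) with one scale per output channel, i.e. W_float[o][k] = scale[o] * Wq[o][k] under the symmetric encoding used here (offsets are zero and asymmetric weights are asserted away). A reference implementation of that semantics, with illustrative names only:

```cpp
#include <cstdint>
#include <vector>

// Reference semantics of the "float A, int B" matmul built above: float
// activations against per-output-channel scaled integer weights.
std::vector<float> fpAIntBMatMul(const std::vector<float>& x,     // [ic]
                                 const std::vector<int8_t>& wq,   // [oc * ic], row per output channel
                                 const std::vector<float>& scale, // [oc]
                                 const std::vector<float>& bias,  // [oc], may be all zero
                                 int ic, int oc) {
    std::vector<float> y(oc, 0.0f);
    for (int o = 0; o < oc; ++o) {
        float acc = 0.0f;
        for (int k = 0; k < ic; ++k) {
            acc += x[k] * static_cast<float>(wq[o * ic + k]);
        }
        y[o] = acc * scale[o] + bias[o];
    }
    return y;
}
```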
std::vector quantWeightData(oc * ic, 0); + if(quanCommon->canUseInt4){ + for (int o = 0; o < oc; o++) { + for (int i = 0; i < ic; i++) { + uint32_t srcOffset = o * ic + i; + uint32_t dstOffset = srcOffset; + if(srcOffset % 2 == 0){ + quantWeightData[dstOffset] = ((source[srcOffset / 2] >> 4) & 0x0f) - 8; + }else{ + quantWeightData[dstOffset] = (source[srcOffset / 2] & 0x0f) - 8; + } + } + } + }else{ + ::memcpy(quantWeightData.data(), source, oc * ic * sizeof(int8_t)); } + mDequantAlpha = quanCommon->alpha.get(); + int totalCount = quanCommon->alpha.size(); + mBlockSize = totalCount / oc; + int blockNum = ic / mBlockSize; + if(quanCommon->canUseInt4){ + weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED; + weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET; + Qnn_BwAxisScaleOffset_t weightBWAxisScaleOffsetEncoding{}; + weightBWAxisScaleOffsetEncoding.bitwidth = 4; + weightBWAxisScaleOffsetEncoding.axis = 0; + weightBWAxisScaleOffsetEncoding.numElements = oc; + mScale.resize(oc); + std::vector OffsetData(oc); + for (int i = 0; i < oc; i++) { + mScale[i] = mDequantAlpha[i]; + } + weightBWAxisScaleOffsetEncoding.scales = mScale.data(); + weightQuantize.bwAxisScaleOffsetEncoding = weightBWAxisScaleOffsetEncoding; + + this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)oc, (uint32_t)ic}, (void *) quantWeightData.data(), weightQuantize); + std::function mReleaseWeightScaleOffset = [&](){ + std::vector().swap(mScale); + }; + mBackend->pushReleaseFunc(mReleaseWeightScaleOffset); + }else{ + weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED; + weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET; + Qnn_AxisScaleOffset_t weightAxisScaleOffsetEncoding{}; + weightAxisScaleOffsetEncoding.axis = 0; + weightAxisScaleOffsetEncoding.numScaleOffsets = oc; + mScaleOffsetData.resize(oc); + for (int i = 0; i < oc; i++) { + mScaleOffsetData[i].scale = mDequantAlpha[i]; + mScaleOffsetData[i].offset = 0; + } + weightAxisScaleOffsetEncoding.scaleOffset = mScaleOffsetData.data(); + weightQuantize.axisScaleOffsetEncoding = weightAxisScaleOffsetEncoding; + + this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)oc, (uint32_t)ic}, (void *) quantWeightData.data(), weightQuantize); + std::function mReleaseWeightScaleOffset = [&](){ + std::vector().swap(mScaleOffsetData); + }; + mBackend->pushReleaseFunc(mReleaseWeightScaleOffset); + } + //create bias + this->createBias(dataType, oc, input, quanCommon); + } + + if (common->relu6()) { + this->createParamScalar("min_value", 0.0f); + this->createParamScalar("max_value", 6.0f); + } + if (common->relu() || common->relu6()) { + this->createStageTensor("ReluTensor", dataType, getNHWCShape(output)); } - weightAxisScaleOffsetEncoding.scaleOffset = scaleOffsetData.data(); - weightQuantize.axisScaleOffsetEncoding = weightAxisScaleOffsetEncoding; - - this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t) oc, (uint32_t) ic}, (void *) quantWeightData.data(), weightQuantize); } // Stage One: reshape input { mNodeType = "Reshape"; - std::string name = mNodeName + "_reshapeOutput"; + std::string name = mNodeName + "_reshapeInput"; mParams.clear(); mInputs.clear(); mOutputs.clear(); @@ -273,6 +570,7 @@ ErrorCode QNNConvolution::onEncodeQuant(Tensor * input, Tensor * output, int n, mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // tempInput // mInputs.push_back(*(mBackend->getNativeTensor(input))); 
mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // weight + mInputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // bias mOutputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // tempOutput // mOutputs.push_back(*(mBackend->getNativeTensor(output))); mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); @@ -286,48 +584,217 @@ ErrorCode QNNConvolution::onEncodeQuant(Tensor * input, Tensor * output, int n, mInputs.clear(); mOutputs.clear(); mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); - mOutputs.push_back(*(mBackend->getNativeTensor(output))); + if (common->relu() || common->relu6()){ + mOutputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); //ReluTensor + }else{ + mOutputs.push_back(*(mBackend->getNativeTensor(output))); + } mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); } + + // Stage Four: relu or relu6 + if (common->relu() || common->relu6()){ + mNodeType.clear(); + mParams.clear(); + mInputs.clear(); + mOutputs.clear(); + mNodeType = common->relu6() ? "ReluMinMax" : "Relu"; + std::string name = mNodeName + "_relu"; + if (common->relu6()) { + mParams.push_back(*(mParamScalarWrappers[1]->getNativeParam())); // min_value + mParams.push_back(*(mParamScalarWrappers[2]->getNativeParam())); // max_value + } + mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // ReluTensor + mOutputs.push_back(*(mBackend->getNativeTensor(output))); // output + mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + } + return NO_ERROR; +} +bool QNNConvolution::createWeightAndBias(Qnn_DataType_t dataType, const Tensor *input, int oc, int ic, int kernelH, int kernelW, int group) { + if(mWeightQuant){ + Qnn_QuantizeParams_t weightQuantize{}; + std::shared_ptr quanCommon = ConvolutionCommon::load(mOp, this->backend(), false, true); + if(quanCommon->asymmetric) { + MNN_ERROR("[Error]: Qnn weight quant only support symmetric currently\n"); + return false; + } + const int8_t * source = quanCommon->weight.get(); + std::vector quantWeightData(oc * (ic / group) * kernelH * kernelW, 0); + if(quanCommon->canUseInt4){ + for (int o = 0; o < oc; o++) { + for (int i = 0; i < ic/group; i++) { + for (int h = 0; h < kernelH; h++) { + for (int w = 0; w < kernelW; w++) { + uint32_t srcOffset = w + kernelW * (h + kernelH * (i + ic/group * o)); + uint32_t dstOffset = o + oc * (i + ic/group * (w + kernelW * h)); + if(srcOffset % 2 == 0){ + quantWeightData[dstOffset] = ((source[srcOffset / 2] >> 4) & 0x0f) - 8; + }else{ + quantWeightData[dstOffset] = (source[srcOffset / 2] & 0x0f) - 8; + } + } + } + } + } + }else{ + convertWeight(source, (int8_t *) quantWeightData.data(), oc, ic/group, kernelH, kernelW); + } + mDequantAlpha = quanCommon->alpha.get(); + int totalCount = quanCommon->alpha.size(); + mBlockSize = totalCount / oc; + // Todo: result is wrong, need to verify + if(mBlockSize > 1){ + Qnn_QuantizeParams_t weightQuantize{}; + weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED; + weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION; + + weightBlockwiseExpansionEncoding.axis = 3; + weightBlockwiseExpansionEncoding.numBlocksPerAxis = mBlockSize; + weightBlockwiseExpansionEncoding.blockScaleBitwidth = 4; + weightBlockwiseExpansionEncoding.blockScaleStorageType = 
QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8; + mBlockScale.resize(oc * mBlockSize); + mScaleOffsetData.resize(oc); + for (int i = 0; i < oc; i++) { + float maxscale = -MAXFLOAT; + for(int j = 0; j < mBlockSize; ++j){ + if(mDequantAlpha[i * mBlockSize + j] > maxscale){ + maxscale = mDequantAlpha[i * mBlockSize + j]; + } + } + float blockScale = maxscale / 16.0f; + for(int j = 0; j < mBlockSize; ++j){ + int quantBlock = round(mDequantAlpha[i * mBlockSize + j] / blockScale); + mBlockScale[i * mBlockSize + j] = (uint8_t)std::min(std::max(quantBlock, 1), 16); + } + mScaleOffsetData[i].scale = blockScale; + mScaleOffsetData[i].offset = 0; + } + weightBlockwiseExpansionEncoding.scaleOffsets = mScaleOffsetData.data(); + weightBlockwiseExpansionEncoding.blocksScale8 = mBlockScale.data(); + weightQuantize.blockwiseExpansion = &weightBlockwiseExpansionEncoding; + this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)kernelH, (uint32_t)kernelW, (uint32_t)ic / (uint32_t)group, (uint32_t)oc}, (void *) quantWeightData.data(), weightQuantize); + std::function mReleaseWeightScaleOffset = [&](){ + std::vector().swap(mScaleOffsetData); + }; + mBackend->pushReleaseFunc(mReleaseWeightScaleOffset); + std::function mReleaseBlockScale = [&](){ + std::vector().swap(mBlockScale); + }; + mBackend->pushReleaseFunc(mReleaseBlockScale); + }else if(quanCommon->canUseInt4){ + weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED; + weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET; + Qnn_BwAxisScaleOffset_t weightBWAxisScaleOffsetEncoding{}; + weightBWAxisScaleOffsetEncoding.bitwidth = 4; + weightBWAxisScaleOffsetEncoding.axis = 3; + weightBWAxisScaleOffsetEncoding.numElements = oc; + mScale.resize(oc); + std::vector OffsetData(oc); + for (int i = 0; i < oc; i++) { + mScale[i] = mDequantAlpha[i]; + } + weightBWAxisScaleOffsetEncoding.scales = mScale.data(); + weightQuantize.bwAxisScaleOffsetEncoding = weightBWAxisScaleOffsetEncoding; + + this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)kernelH, (uint32_t)kernelW, (uint32_t)ic / (uint32_t)group, (uint32_t)oc}, (void *) quantWeightData.data(), weightQuantize); + std::function mReleaseWeightScaleOffset = [&](){ + std::vector().swap(mScale); + }; + mBackend->pushReleaseFunc(mReleaseWeightScaleOffset); + }else{ + weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED; + weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET; + Qnn_AxisScaleOffset_t weightAxisScaleOffsetEncoding{}; + weightAxisScaleOffsetEncoding.axis = 3; + weightAxisScaleOffsetEncoding.numScaleOffsets = oc; + mScaleOffsetData.resize(oc); + for (int i = 0; i < oc; i++) { + mScaleOffsetData[i].scale = mDequantAlpha[i]; + mScaleOffsetData[i].offset = 0; + } + weightAxisScaleOffsetEncoding.scaleOffset = mScaleOffsetData.data(); + weightQuantize.axisScaleOffsetEncoding = weightAxisScaleOffsetEncoding; + + this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)kernelH, (uint32_t)kernelW, (uint32_t)ic / (uint32_t)group, (uint32_t)oc}, (void *) quantWeightData.data(), weightQuantize); + std::function mReleaseWeightScaleOffset = [&](){ + std::vector().swap(mScaleOffsetData); + }; + mBackend->pushReleaseFunc(mReleaseWeightScaleOffset); + } + this->createBias(dataType, oc, input, quanCommon); + } else { + std::vector weightData; + const float* source = nullptr; + int weightElementNum = 0; + std::shared_ptr quanWeight; + 
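The blockwise-expansion branch above factors each per-block scale into a per-channel base scale plus a 4-bit multiplier: base = (max of the channel's block scales) / 16, multiplier = clamp(round(alpha / base), 1, 16), so alpha is approximated by base * multiplier. The hunk itself flags this path as still needing verification; the sketch below just restates the factorization outside the QNN types, assuming positive (symmetric) scales:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

struct BlockwiseScales {
    std::vector<float>   channelScale; // per-output-channel base scale
    std::vector<uint8_t> blockScale;   // per-(channel, block) 4-bit multiplier
};

// Mirrors the scale factorization in the blockwise-expansion hunk above.
static BlockwiseScales factorBlockScales(const float* alpha, int oc, int blockSize) {
    BlockwiseScales out;
    out.channelScale.resize(oc);
    out.blockScale.resize(oc * blockSize);
    for (int c = 0; c < oc; ++c) {
        float maxScale = 0.0f;
        for (int b = 0; b < blockSize; ++b) {
            maxScale = std::max(maxScale, alpha[c * blockSize + b]);
        }
        const float base = maxScale / 16.0f;
        for (int b = 0; b < blockSize; ++b) {
            const int q = static_cast<int>(std::lround(alpha[c * blockSize + b] / base));
            out.blockScale[c * blockSize + b] = static_cast<uint8_t>(std::min(std::max(q, 1), 16));
        }
        out.channelScale[c] = base;
    }
    return out;
}
```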
ConvolutionCommon::getConvParameters(&quanWeight, mBackend, mOp, &source, &weightElementNum); + // oc ic h w ---> h w ic oc + weightData.resize(weightElementNum); + convertWeight(source, (float *) weightData.data(), oc, ic/group, kernelH, kernelW); + Qnn_DataType_t floatDatatype = QNN_DATATYPE_FLOAT_32; + if(mBackend->getUseFP16()){ + floatDatatype = QNN_DATATYPE_FLOAT_16; + } + this->createStaticFloatTensor("weight", floatDatatype, {(uint32_t)kernelH, (uint32_t)kernelW, (uint32_t)ic / (uint32_t)group, (uint32_t)oc}, weightData.data()); + this->createBias(dataType, oc, input, nullptr); + } return NO_ERROR; } -void QNNConvolution::createWeight(Qnn_DataType_t dataType, int oc, int ic, int kernelH, int kernelW, int group, const float * source) { - std::vector weightData; - int weightElementNum = oc * ic / group * kernelH * kernelW; - // oc ic/group h w ---> h w ic/group oc - weightData.resize(weightElementNum); - - convertWeight(source, (float *) weightData.data(), oc, ic/group, kernelH, kernelW); - this->createStaticFloatTensor("weight", dataType, {(uint32_t)kernelH, (uint32_t)kernelW, (uint32_t)ic / (uint32_t)group, (uint32_t)oc}, weightData.data()); -} - - -void QNNConvolution::createBias(Qnn_DataType_t dataType, int oc) { +void QNNConvolution::createBias(Qnn_DataType_t dataType, int oc, const Tensor *input, std::shared_ptr quanCommon) { int biasElementNum = oc; - std::vector biasData; - biasData.resize(biasElementNum, 0); - auto bias = mOp->main_as_Convolution2D()->bias(); - if (nullptr != bias) { - ::memcpy((void *)biasData.data(), (void *)bias->data(), biasElementNum * sizeof(float)); - } - this->createStaticFloatTensor("bias", dataType, {(uint32_t)oc}, biasData.data()); -} + if(dataType != QNN_DATATYPE_FLOAT_16 && dataType != QNN_DATATYPE_FLOAT_32 && mWeightQuant){ + mDequantAlpha = quanCommon->alpha.get(); + float inputScale = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.scale; + int inputOffset = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.offset; + std::vector biasData; + biasData.resize(biasElementNum, 0); -// oc ic h w ---> h w ic oc -void QNNConvolution::convertWeight(const float * src, float * dst, int oc, int ic, int kernelH, int kernelW) { - for (int o = 0; o < oc; o++) { - for (int i = 0; i < ic; i++) { - for (int h = 0; h < kernelH; h++) { - for (int w = 0; w < kernelW; w++) { - uint32_t srcOffset = w + kernelW * (h + kernelH * (i + ic * o)); - uint32_t dstOffset = o + oc * (i + ic * (w + kernelW * h)); - dst[dstOffset] = src[srcOffset]; + Qnn_QuantizeParams_t biasQuantize{}; + biasQuantize.encodingDefinition = QNN_DEFINITION_DEFINED; + biasQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET; + Qnn_AxisScaleOffset_t biasAxisScaleOffsetEncoding{}; + biasAxisScaleOffsetEncoding.axis = 0; + biasAxisScaleOffsetEncoding.numScaleOffsets = biasElementNum; + mBiasScaleOffsetData.resize(biasElementNum); + + auto bias = mOp->main_as_Convolution2D()->bias(); + auto biasPtr = (float*)bias->data(); + if (nullptr != bias) { + for(int i = 0; i < biasElementNum; ++i){ + float biasScale = inputScale * mDequantAlpha[i]; + mBiasScaleOffsetData[i].scale = biasScale; + mBiasScaleOffsetData[i].offset = 0; + if(fabs(biasPtr[i]) < 0.000001 || fabs(biasScale) < 0.000001){ + biasData[i] = 0; + } else{ + biasData[i] = (int)(biasPtr[i] / biasScale); } } } + + biasAxisScaleOffsetEncoding.scaleOffset = mBiasScaleOffsetData.data(); + biasQuantize.axisScaleOffsetEncoding = biasAxisScaleOffsetEncoding; + + 
this->createStaticTensor("bias", QNN_DATATYPE_SFIXED_POINT_32, {(uint32_t)biasElementNum}, biasData.data(), biasQuantize); + std::function mReleaseBiasScaleOffset = [&](){ + std::vector().swap(mBiasScaleOffsetData); + }; + mBackend->pushReleaseFunc(mReleaseBiasScaleOffset); + }else{ + std::vector biasData; + biasData.resize(biasElementNum, 0); + auto bias = mOp->main_as_Convolution2D()->bias(); + if (nullptr != bias) { + ::memcpy((void *)biasData.data(), (void *)bias->data(), biasElementNum * sizeof(float)); + } + Qnn_DataType_t floatDatatype = QNN_DATATYPE_FLOAT_32; + if(mBackend->getUseFP16()){ + floatDatatype = QNN_DATATYPE_FLOAT_16; + } + this->createStaticFloatTensor("bias", floatDatatype, {(uint32_t)oc}, biasData.data()); } } diff --git a/source/backend/qnn/execution/QNNConvolution.hpp b/source/backend/qnn/execution/QNNConvolution.hpp index bc2dcc7f..7f3807d3 100644 --- a/source/backend/qnn/execution/QNNConvolution.hpp +++ b/source/backend/qnn/execution/QNNConvolution.hpp @@ -19,12 +19,37 @@ class QNNConvolution : public QNNCommonExecution { public: QNNConvolution(Backend *backend, const Op *op) : QNNCommonExecution(backend, op) {} virtual ErrorCode onEncode(const std::vector &inputs, const std::vector &outputs) override; - ErrorCode onEncodeQuant(Tensor * input, Tensor * output, int n, int h, int w, int ic, int oc, std::shared_ptr quanCommon); + ErrorCode onEncodeFpAIntBMatMul(Tensor * input, Tensor * output, int n, int h, int w, int ic, int oc); + ErrorCode onEncodeQuantDequantConv(Tensor *input, Tensor *output, const int n, const int ic, const int oc); private: - void createWeight(Qnn_DataType_t dataType, int oc, int ic, int kernelH, int kernelW, int group, const float * source); - void createBias(Qnn_DataType_t dataType, int oc); - void convertWeight(const float * src, float * dst, int oc, int ic, int kernelH, int kernelW); + template + void convertWeight(const T * src, T * dst, int oc, int ic, int kernelH, int kernelW) { + for (int o = 0; o < oc; o++) { + for (int i = 0; i < ic; i++) { + for (int h = 0; h < kernelH; h++) { + for (int w = 0; w < kernelW; w++) { + uint32_t srcOffset = w + kernelW * (h + kernelH * (i + ic * o)); + uint32_t dstOffset = o + oc * (i + ic * (w + kernelW * h)); + dst[dstOffset] = src[srcOffset]; + } + } + } + } + } + void isWeightQuantSupported(const Tensor *input, const int ic, const int oc); + bool createWeightAndBias(Qnn_DataType_t dataType, const Tensor *input, int oc, int ic, int kernelH, int kernelW, int group); + void createBias(Qnn_DataType_t dataType, int oc, const Tensor *input, std::shared_ptr quanCommon); + std::vector mScale; + std::vector mScaleOffsetData; + std::vector mBiasScaleOffsetData; + std::vector mBlockScale; + Qnn_BlockwiseExpansion_t weightBlockwiseExpansionEncoding = QNN_BLOCKWISE_EXPANSION_INIT; + float *mDequantAlpha = nullptr; + int mBlockSize = 1; + bool mWeightQuant = false; + bool mIsMatMul = false; + bool mIs1x1Conv = false; }; } // end namespace QNN diff --git a/source/backend/qnn/execution/QNNQuant.cpp b/source/backend/qnn/execution/QNNQuant.cpp new file mode 100644 index 00000000..e7e84243 --- /dev/null +++ b/source/backend/qnn/execution/QNNQuant.cpp @@ -0,0 +1,81 @@ +// +// QNNQuant.cpp +// MNN +// +// Created by MNN on b'2025/05/29'. 
+// Copyright © 2018, Alibaba Group Holding Limited
+//
+
+#include "QNNQuant.hpp"
+
+namespace MNN {
+namespace QNN {
+
+ErrorCode QNNQuant::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
+    this->createStageTensor("Cast", QNN_DATATYPE_FLOAT_32, getNHWCShape(outputs[0]));
+    // Stage one fp16 -> fp32
+    {
+        mNodeType = "Cast";
+        std::string name = mNodeName + "_Cast";
+
+        mInputs.push_back(*(mBackend->getNativeTensor(inputs[0]))); // input
+        mOutputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // stage tensor
+        mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
+    }
+
+    // Stage two fp32 -> int8
+    {
+        mNodeType.clear();
+        mParams.clear();
+        mInputs.clear();
+        mOutputs.clear();
+        mNodeType = "Quantize";
+        std::string name = mNodeName;
+
+        mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // stage tensor
+        mOutputs.push_back(*(mBackend->getNativeTensor(outputs[0]))); // output
+        mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
+    }
+    return NO_ERROR;
+}
+
+
+class QNNQuantCreator : public QnnBackend::Creator {
+public:
+    virtual QNNCommonExecution * onCreate(const std::vector<Tensor *>& inputs, const std::vector<Tensor *>& outputs, const MNN::Op* op,
+                                          Backend* backend) const override {
+        return new QNNQuant(backend, op);
+    }
+};
+
+ErrorCode QNNDeQuant::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
+    // Stage one int8 -> fp16
+    {
+        mNodeType.clear();
+        mParams.clear();
+        mInputs.clear();
+        mOutputs.clear();
+        mNodeType = "Dequantize";
+        std::string name = mNodeName;
+
+        mInputs.push_back(*(mBackend->getNativeTensor(inputs[0]))); // input
+        mOutputs.push_back(*(mBackend->getNativeTensor(outputs[0]))); // output
+        mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
+    }
+    return NO_ERROR;
+}
+
+
+class QNNDeQuantCreator : public QnnBackend::Creator {
+public:
+    virtual QNNCommonExecution * onCreate(const std::vector<Tensor *>& inputs, const std::vector<Tensor *>& outputs, const MNN::Op* op,
+                                          Backend* backend) const override {
+        return new QNNDeQuant(backend, op);
+    }
+};
+
+REGISTER_QNN_OP_CREATOR(QNNQuantCreator, OpType_FloatToInt8)
+REGISTER_QNN_OP_CREATOR(QNNDeQuantCreator, OpType_Int8ToFloat)
+
+} // end namespace QNN
+} // end namespace MNN
diff --git a/source/backend/qnn/execution/QNNQuant.hpp b/source/backend/qnn/execution/QNNQuant.hpp
new file mode 100644
index 00000000..b38460bf
--- /dev/null
+++ b/source/backend/qnn/execution/QNNQuant.hpp
@@ -0,0 +1,32 @@
+//
+//  QNNQuant.hpp
+//  MNN
+//
+//  Created by MNN on b'2025/05/29'.
+// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifndef MNN_QNNQUANT_HPP +#define MNN_QNNQUANT_HPP + +#include "QNNCommonExecution.hpp" + +namespace MNN { +namespace QNN { + +class QNNQuant : public QNNCommonExecution { +public: + QNNQuant(Backend *backend, const Op *op) : QNNCommonExecution(backend, op) {}; + virtual ErrorCode onEncode(const std::vector &inputs, const std::vector &outputs) override; +}; + +class QNNDeQuant : public QNNCommonExecution { +public: + QNNDeQuant(Backend *backend, const Op *op) : QNNCommonExecution(backend, op) {}; + virtual ErrorCode onEncode(const std::vector &inputs, const std::vector &outputs) override; +}; + +} // end namespace QNN +} // end namespace MNN + +#endif // end MNN_QNNQUANT_HPP diff --git a/source/backend/qnn/execution/QNNScale.cpp b/source/backend/qnn/execution/QNNScale.cpp index 22b147db..3b748b05 100644 --- a/source/backend/qnn/execution/QNNScale.cpp +++ b/source/backend/qnn/execution/QNNScale.cpp @@ -28,10 +28,25 @@ ErrorCode QNNScale::onEncode(const std::vector &inputs, const std::vec MNN_ASSERT(channel == mWeightData.size()); Qnn_DataType_t dataType = mBackend->getNativeTensor(inputs[0])->v1.dataType; - - this->createStaticFloatTensor("weight", dataType, {(uint32_t)channel}, mWeightData.data()); - this->createStaticFloatTensor("bias", dataType, {(uint32_t)channel}, mBiasData.data()); - this->createStageTensor("Stage", dataType, getNHWCShape(inputs[0])); + mNeedQuantDequant = dataType != QNN_DATATYPE_FLOAT_16 && dataType != QNN_DATATYPE_FLOAT_32; + if(mNeedQuantDequant){ + Qnn_DataType_t tempDataType = QNN_DATATYPE_FLOAT_32; + if(mBackend->getUseFP16()){ + tempDataType = QNN_DATATYPE_FLOAT_16; + } + this->createStaticFloatTensor("weight", tempDataType, {(uint32_t)channel}, mWeightData.data()); + this->createStaticFloatTensor("bias", tempDataType, {(uint32_t)channel}, mBiasData.data()); + this->createStageTensor("Stage", tempDataType, getNHWCShape(inputs[0])); + this->createStageTensor("Stage_dequantize_input", tempDataType, getNHWCShape(inputs[0])); + this->createStageTensor("Stage_add_output", tempDataType, getNHWCShape(outputs[0])); + if(mBackend->getUseFP16()){ + this->createStageTensor("Stage_cast_output", QNN_DATATYPE_FLOAT_32, getNHWCShape(outputs[0])); + } + }else{ + this->createStaticFloatTensor("weight", dataType, {(uint32_t)channel}, mWeightData.data()); + this->createStaticFloatTensor("bias", dataType, {(uint32_t)channel}, mBiasData.data()); + this->createStageTensor("Stage", dataType, getNHWCShape(inputs[0])); + } } // add nodes @@ -42,34 +57,98 @@ ErrorCode QNNScale::onEncode(const std::vector &inputs, const std::vec } void QNNScale::mulWeight(Tensor * input) { - mNodeType = "ElementWiseMultiply"; - std::string name = mNodeName + "_mul"; - mParams.clear(); - mInputs.clear(); - mOutputs.clear(); - - mInputs.push_back(*(mBackend->getNativeTensor(input))); - mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); - - mOutputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); - - mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + Qnn_DataType_t dataType = mBackend->getNativeTensor(input)->v1.dataType; + // need dequantize to float16 + if(mNeedQuantDequant){ + mNodeType.clear(); + mParams.clear(); + mInputs.clear(); + mOutputs.clear(); + mNodeType = "Dequantize"; + std::string name = mNodeName + "_Dequantize"; + + mInputs.push_back(*(mBackend->getNativeTensor(input))); // input + 
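When the surrounding tensors are int8, QNNScale (like the new Quant/DeQuant executions above) brackets the float multiply-add with explicit Dequantize and Quantize nodes. What those two nodes compute is the standard affine mapping between real values and int8 codes; the helpers below use the common zero-point convention, while QNN carries the equivalent information as a scale/offset pair (names are illustrative, not from the patch):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// real = scale * (q - zeroPoint)
inline float dequantize(int8_t q, float scale, int zeroPoint) {
    return scale * (static_cast<int>(q) - zeroPoint);
}

// q = clamp(round(real / scale) + zeroPoint, -128, 127)
inline int8_t quantize(float x, float scale, int zeroPoint) {
    const int q = static_cast<int>(std::lround(x / scale)) + zeroPoint;
    return static_cast<int8_t>(std::min(std::max(q, -128), 127));
}
```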
mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); //Stage_dequantize_input + mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + } + { + mNodeType.clear(); + mParams.clear(); + mInputs.clear(); + mOutputs.clear(); + mNodeType = "ElementWiseMultiply"; + std::string name = mNodeName + "_mul"; + + if(mNeedQuantDequant){ + mInputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); //Stage_dequantize_input + }else{ + mInputs.push_back(*(mBackend->getNativeTensor(input))); + } + mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); + + mOutputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); + + mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + } } void QNNScale::addBias(Tensor * output) { - mNodeType = "ElementWiseAdd"; - std::string name = mNodeName + "_add"; - mParams.clear(); - mInputs.clear(); - mOutputs.clear(); + Qnn_DataType_t dataType = mBackend->getNativeTensor(output)->v1.dataType; + { + mNodeType.clear(); + mParams.clear(); + mInputs.clear(); + mOutputs.clear(); + mNodeType = "ElementWiseAdd"; + std::string name = mNodeName + "_add"; + + mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); + mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); + + if(mNeedQuantDequant){ + mOutputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // Stage_add_output + }else{ + mOutputs.push_back(*(mBackend->getNativeTensor(output))); + } + + mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + } + + // need quantize output + if(mNeedQuantDequant){ + // Stage one fp16 -> fp32 + if(mBackend->getUseFP16()){ + mNodeType.clear(); + mParams.clear(); + mInputs.clear(); + mOutputs.clear(); + mNodeType = "Cast"; + std::string name = mNodeName + "_Cast"; + + mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // Stage_add_output + mOutputs.push_back(*(mTempTensorWrappers[5]->getNativeTensor())); // Stage_cast_output + mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + } - mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); - mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); - - mOutputs.push_back(*(mBackend->getNativeTensor(output))); - - mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + // Stage two fp32 -> int8 + { + mNodeType.clear(); + mParams.clear(); + mInputs.clear(); + mOutputs.clear(); + mNodeType = "Quantize"; + std::string name = mNodeName + "_Quantize"; + + if(mBackend->getUseFP16()){ + mInputs.push_back(*(mTempTensorWrappers[5]->getNativeTensor())); // Stage_cast_output + }else{ + mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // Stage_add_output + } + mOutputs.push_back(*(mBackend->getNativeTensor(output))); // output + mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); + } + } } ErrorCode QNNScale::onResize(const std::vector &inputs, const std::vector &outputs) { diff --git a/source/backend/qnn/execution/QNNScale.hpp b/source/backend/qnn/execution/QNNScale.hpp index 63b1a787..897c307f 100644 --- a/source/backend/qnn/execution/QNNScale.hpp +++ b/source/backend/qnn/execution/QNNScale.hpp @@ -25,6 +25,7 @@ 
private:
 private:
     std::vector<float> mWeightData;
     std::vector<float> mBiasData;
+    bool mNeedQuantDequant = false;
 };
 
 } // end namespace QNN
diff --git a/source/core/Backend.hpp b/source/core/Backend.hpp
index fa245823..98f08157 100644
--- a/source/core/Backend.hpp
+++ b/source/core/Backend.hpp
@@ -34,12 +34,17 @@ struct RuntimeHint {
     int cpuDecreaseRate = 50;
     int dynamicQuantOption = 0;
+    // qkvQuantOption % 8:
     // 0: Do not quantize
     // 1: Only quantize key, use int8 asymmetric quantization
     // 2: Only quantize value, use fp8 quantization
     // 3: quantize both key and value
     // 4: quantize query, key and value, and use gemm int8 kernel to compute K*V
-    int qkvQuantOption = 0;
+
+    // qkvQuantOption / 8:
+    // 1: use flash attention
+
+    int qkvQuantOption = 8;
 
     // the kvcache size limit of each layer
     // if the size of kvcache in memory exceeds the limit
@@ -403,6 +408,9 @@ public:
         info.mode = Backend::Info::DIRECT;
         return true;
     }
+    virtual bool onGetDeviceInfo(const std::string& deviceKey, std::string& deviceValue) const {
+        return false;
+    }
 protected:
     /** @brief deinitializer.
diff --git a/tools/converter/source/common/cli.cpp b/tools/converter/source/common/cli.cpp
index fcb25771..f7062edd 100644
--- a/tools/converter/source/common/cli.cpp
+++ b/tools/converter/source/common/cli.cpp
@@ -129,6 +129,13 @@ static int dumpModelInfo(const char* modelName) {
     } else {
         MNN_PRINT("Model Version: %s \n", info->version.c_str());
     }
+    if (!info->metaData.empty()) {
+        MNN_PRINT("MetaData: Begin \n");
+        for (auto& iter : info->metaData) {
+            MNN_PRINT("[Meta] %s : %s\n", iter.first.c_str(), iter.second.c_str());
+        }
+        MNN_PRINT("MetaData: End \n");
+    }
     return 0;
 }
diff --git a/tools/converter/source/optimizer/PostConverter.cpp b/tools/converter/source/optimizer/PostConverter.cpp
index 4215d553..2d60b1a8 100644
--- a/tools/converter/source/optimizer/PostConverter.cpp
+++ b/tools/converter/source/optimizer/PostConverter.cpp
@@ -130,7 +130,17 @@ bool CompleteSubGraph(const std::unordered_map& inputs, const
     return true;
 }
-
+static bool _hasDupName(std::unique_ptr<MNN::NetT>& originNet) {
+    std::set<std::string> names;
+    for (auto& tensorName : originNet->tensorName) {
+        if (names.find(tensorName) != names.end()) {
+            MNN_ERROR("Repeat name %s\n", tensorName.c_str());
+            return true;
+        }
+        names.insert(tensorName);
+    }
+    return false;
+}
 void RunNetPass(const std::vector<std::string>& passes, std::unique_ptr<MNN::NetT>& originNet) {
     for (auto pass : passes) {
         auto convert = PostConverter::get(pass);
@@ -138,7 +148,14 @@ void RunNetPass(const std::vector<std::string>& passes, std::unique_ptr<MNN::NetT>& originNet) {
+        auto originSize = originNet->oplists.size();
         bool valid = convert->onExecute(originNet);
+#ifdef DEBUG
+        auto hasDup = _hasDupName(originNet);
+        if (originSize != originNet->oplists.size() || hasDup) {
+            MNN_PRINT("%s: %d -> %d, dup: %d\n", pass.c_str(), originSize, originNet->oplists.size(), hasDup);
+        }
+#endif
         if (!valid) {
             LOG(INFO) << "Run " << pass << "Error\n";
         }
diff --git a/tools/converter/source/optimizer/onnxextra/OnnxClip.cpp b/tools/converter/source/optimizer/onnxextra/OnnxClip.cpp
index 54fff31a..3c508598 100644
--- a/tools/converter/source/optimizer/onnxextra/OnnxClip.cpp
+++ b/tools/converter/source/optimizer/onnxextra/OnnxClip.cpp
@@ -19,7 +19,7 @@ static EXPRP clipConvert(EXPRP expr, bool supportRelu6) {
     auto extraParam = op->main_as_Extra();
     // auto dataType = expr->outputInfo(0)->type.code;
     auto maxValue = std::numeric_limits<float>().max();
-    auto minValue = std::numeric_limits<float>().min();
+    auto minValue = std::numeric_limits<float>().lowest();
     if (nullptr != extraParam->attr()) {
         const int attrSize = extraParam->attr()->size();
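A side note on the OnnxClip change above: for floating-point types, std::numeric_limits<float>::min() is the smallest positive normal value, not the most negative representable one, so using it as Clip's default lower bound silently clamped all negative inputs; lowest() is the correct default. A two-line check:

```cpp
#include <cstdio>
#include <limits>

int main() {
    // min() is about 1.18e-38 (smallest positive normal); lowest() is about -3.4e+38.
    std::printf("min()    = %g\n", std::numeric_limits<float>::min());
    std::printf("lowest() = %g\n", std::numeric_limits<float>::lowest());
    return 0;
}
```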
for (int i = 0; i < attrSize; ++i) { diff --git a/tools/converter/source/optimizer/postconvert/RemoveCopy.cpp b/tools/converter/source/optimizer/postconvert/RemoveCopy.cpp index a588b66b..4a7c7ca1 100644 --- a/tools/converter/source/optimizer/postconvert/RemoveCopy.cpp +++ b/tools/converter/source/optimizer/postconvert/RemoveCopy.cpp @@ -12,20 +12,70 @@ class RemoveCopy : public PostConverter { public: virtual bool onExecute(std::unique_ptr& net) const override { - auto config = Global::Get(); - if (config->optimizeLevel < 1 || config->inSubGraph) { - return true; + std::set netOutputNames; + for (auto& t : net->outputName) { + netOutputNames.insert(t); } + for (auto iter = net->oplists.begin(); iter != net->oplists.end(); iter++) { + auto& op = *iter; + if (op->type == MNN::OpType_Input) { + for (auto o : op->outputIndexes) { + netOutputNames.insert(net->tensorName[o]); + } + } + } + auto config = Global::Get(); for (auto iter = net->oplists.begin(); iter != net->oplists.end();) { auto& op = *iter; - if (op->type != MNN::OpType_Identity) { + if (op->type != MNN::OpType_Identity || op->inputIndexes.size() != op->outputIndexes.size()) { iter++; continue; } + + bool hasOutputName = false; + for (auto o : op->outputIndexes) { + if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) { + hasOutputName = true; + break; + } + } + bool hasOutputFromInput = false; + for (auto o : op->inputIndexes) { + if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) { + hasOutputFromInput = true; + break; + } + } + if (hasOutputFromInput && hasOutputName) { + iter++; + continue; + } + auto originInput = op->inputIndexes; + auto originOutputs = op->outputIndexes; + MNN_ASSERT(originInput.size() == originOutputs.size()); + if (hasOutputName) { + bool valid = true; + for (int i=0; iinputIndexes.size(); ++i) { + auto o = op->outputIndexes[i]; + auto originInput = op->inputIndexes[i]; + if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) { + if (netOutputNames.find(net->tensorName[originInput]) != netOutputNames.end()) { + valid = false; + break; + } + auto originName = net->tensorName[originInput]; + net->tensorName[originInput] = net->tensorName[o]; + net->tensorName[o] = originName; + } + } + if (!valid) { + continue; + } + } + std::map replaceIndexes; for (int i=0; iinputIndexes.size();++i) { replaceIndexes.insert(std::make_pair(op->outputIndexes[i], op->inputIndexes[i])); - net->tensorName[op->inputIndexes[i]] = net->tensorName[op->outputIndexes[i]]; } for (auto subIter = net->oplists.begin(); subIter != net->oplists.end(); subIter++) { auto& subOp = *subIter; @@ -35,14 +85,6 @@ public: } } } - for (int v=0; vinputIndexes.size(); ++v) { - for (auto& o : net->outputName) { - if (o == net->tensorName[op->inputIndexes[v]]) { - o = net->tensorName[op->outputIndexes[v]]; - break; - } - } - } iter = net->oplists.erase(iter); } return true; diff --git a/tools/converter/source/optimizer/postconvert/RemoveTestNoUseOps.cpp b/tools/converter/source/optimizer/postconvert/RemoveTestNoUseOps.cpp new file mode 100644 index 00000000..873b5869 --- /dev/null +++ b/tools/converter/source/optimizer/postconvert/RemoveTestNoUseOps.cpp @@ -0,0 +1,149 @@ +#include "RemoveTestNoUseOps.hpp" +bool RemoveTestNoUseOps::onExecute(std::unique_ptr& net) const { + const MNN::NetT* const netPtr = net.get(); + std::set netOutputNames; + for (auto& t : net->outputName) { + netOutputNames.insert(t); + } + for (auto iter = net->oplists.begin(); iter != net->oplists.end(); iter++) { + auto& op = *iter; 
+ if (op->type == OpType_Input) { + for (auto o : op->outputIndexes) { + netOutputNames.insert(net->tensorName[o]); + } + } + } + + std::unordered_set removedInputs; + for (auto iter = net->oplists.begin(); iter != net->oplists.end();) { + auto& op = *iter; + bool shouldDelete = shouldDeleteJudge(op.get(), netPtr); + if (!shouldDelete) { + iter++; + continue; + } + bool hasOutputName = false; + for (auto o : op->outputIndexes) { + if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) { + hasOutputName = true; + break; + } + } + bool hasOutputFromInput = false; + for (auto o : op->inputIndexes) { + if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) { + hasOutputFromInput = true; + break; + } + } + if (hasOutputFromInput && hasOutputName) { + iter++; + continue; + } + bool deleteOutput = shouldDeleteOutput(op.get()); + // Find the next op + if (op->outputIndexes.empty() || op->inputIndexes.empty()) { + iter = net->oplists.erase(iter); + continue; + } + auto originInput = op->inputIndexes[0]; + auto originOutputs = op->outputIndexes; + if ((!deleteOutput) && hasOutputName) { + bool valid = true; + for (auto o : originOutputs) { + if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) { + if (netOutputNames.find(net->tensorName[originInput]) != netOutputNames.end()) { + valid = false; + break; + } + net->tensorName[originInput] = net->tensorName[o]; + } + } + if (!valid) { + continue; + } + } + for (auto subIter = net->oplists.begin(); subIter != net->oplists.end(); subIter++) { + auto& subOp = *subIter; + if (deleteOutput) { + for (auto iter=subOp->inputIndexes.begin(); iter != subOp->inputIndexes.end();) { + if (std::find(originOutputs.begin(), originOutputs.end(), *iter) != originOutputs.end()) { + iter = subOp->inputIndexes.erase(iter); + continue; + } + iter++; + } + } else { + for (int v = 0; v < subOp->inputIndexes.size(); ++v) { + if (std::find(originOutputs.begin(), originOutputs.end(), subOp->inputIndexes[v]) != originOutputs.end()) { + subOp->inputIndexes[v] = originInput; + } + } + } + } + bool removeUselessInput = shouldRemoveUnusefulInputs(op.get()); + if (removeUselessInput) { + for (int input : op->inputIndexes) { + removedInputs.emplace(input); + } + } + iter = net->oplists.erase(iter); + } + + // Remove the op only if the reference counts of it's all outputs + // are reduced to be zero. + std::unordered_map uselessIndex; + for (const auto& op : net->oplists) { + for (int input : op->inputIndexes) { + auto it = uselessIndex.find(input); + if (it == uselessIndex.end()) { + uselessIndex.emplace(input, 1); + } else { + ++it->second; + } + } + } + // Set reference count 1 for all net outputs. 
+ for (const auto& op : net->oplists) { + for (int output : op->outputIndexes) { + auto it = uselessIndex.find(output); + if (it == uselessIndex.end()) { + if (removedInputs.count(output)) { + uselessIndex.emplace(output, 0); + } else { + uselessIndex.emplace(output, 1); + } + } + } + } + + bool needIteration = false; + do { + needIteration = false; + for (auto iter = net->oplists.begin(); iter != net->oplists.end();) { + auto& op = *iter; + bool useless = true; + for (auto index : op->outputIndexes) { + if (uselessIndex.at(index) > 0) { + useless = false; + break; + } + } + if (!useless) { + iter++; + continue; + } + if (!op->inputIndexes.empty()) { + for (auto index : op->inputIndexes) { + auto it = uselessIndex.find(index); + MNN_ASSERT(it != uselessIndex.end()); + --it->second; + } + needIteration = true; + } + iter = net->oplists.erase(iter); + } + } while (needIteration); + + return true; +} diff --git a/tools/converter/source/optimizer/postconvert/RemoveTestNoUseOps.hpp b/tools/converter/source/optimizer/postconvert/RemoveTestNoUseOps.hpp index 6919d994..e7285b3b 100644 --- a/tools/converter/source/optimizer/postconvert/RemoveTestNoUseOps.hpp +++ b/tools/converter/source/optimizer/postconvert/RemoveTestNoUseOps.hpp @@ -25,135 +25,5 @@ public: virtual bool shouldDeleteOutput(const MNN::OpT* op) const = 0; - virtual bool onExecute(std::unique_ptr& net) const override { - const MNN::NetT* const netPtr = net.get(); - std::set netOutputNames; - for (auto& t : net->outputName) { - netOutputNames.insert(t); - } - std::unordered_set removedInputs; - for (auto iter = net->oplists.begin(); iter != net->oplists.end();) { - auto& op = *iter; - bool shouldDelete = shouldDeleteJudge(op.get(), netPtr); - if (!shouldDelete) { - iter++; - continue; - } - bool hasOutputName = false; - for (auto o : op->outputIndexes) { - if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) { - hasOutputName = true; - break; - } - } - bool hasOutputFromInput = false; - for (auto o : op->inputIndexes) { - if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) { - hasOutputFromInput = true; - break; - } - } - if (hasOutputFromInput && hasOutputName) { - iter++; - continue; - } - bool deleteOutput = shouldDeleteOutput(op.get()); - // Find the next op - if (op->outputIndexes.empty() || op->inputIndexes.empty()) { - iter = net->oplists.erase(iter); - continue; - } - auto originInput = op->inputIndexes[0]; - auto originOutputs = op->outputIndexes; - if ((!deleteOutput) && hasOutputName) { - for (auto o : originOutputs) { - if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) { - net->tensorName[originInput] = net->tensorName[o]; - } - } - } - for (auto subIter = net->oplists.begin(); subIter != net->oplists.end(); subIter++) { - auto& subOp = *subIter; - if (deleteOutput) { - for (auto iter=subOp->inputIndexes.begin(); iter != subOp->inputIndexes.end();) { - if (std::find(originOutputs.begin(), originOutputs.end(), *iter) != originOutputs.end()) { - iter = subOp->inputIndexes.erase(iter); - continue; - } - iter++; - } - } else { - for (int v = 0; v < subOp->inputIndexes.size(); ++v) { - if (std::find(originOutputs.begin(), originOutputs.end(), subOp->inputIndexes[v]) != originOutputs.end()) { - subOp->inputIndexes[v] = originInput; - } - } - } - } - bool removeUselessInput = shouldRemoveUnusefulInputs(op.get()); - if (removeUselessInput) { - for (int input : op->inputIndexes) { - removedInputs.emplace(input); - } - } - iter = net->oplists.erase(iter); - } - - // Remove the op 
only if the reference counts of it's all outputs - // are reduced to be zero. - std::unordered_map uselessIndex; - for (const auto& op : net->oplists) { - for (int input : op->inputIndexes) { - auto it = uselessIndex.find(input); - if (it == uselessIndex.end()) { - uselessIndex.emplace(input, 1); - } else { - ++it->second; - } - } - } - // Set reference count 1 for all net outputs. - for (const auto& op : net->oplists) { - for (int output : op->outputIndexes) { - auto it = uselessIndex.find(output); - if (it == uselessIndex.end()) { - if (removedInputs.count(output)) { - uselessIndex.emplace(output, 0); - } else { - uselessIndex.emplace(output, 1); - } - } - } - } - - bool needIteration = false; - do { - needIteration = false; - for (auto iter = net->oplists.begin(); iter != net->oplists.end();) { - auto& op = *iter; - bool useless = true; - for (auto index : op->outputIndexes) { - if (uselessIndex.at(index) > 0) { - useless = false; - break; - } - } - if (!useless) { - iter++; - continue; - } - if (!op->inputIndexes.empty()) { - for (auto index : op->inputIndexes) { - auto it = uselessIndex.find(index); - MNN_ASSERT(it != uselessIndex.end()); - --it->second; - } - needIteration = true; - } - iter = net->oplists.erase(iter); - } - } while (needIteration); - - return true; - } + virtual bool onExecute(std::unique_ptr& net) const override; }; diff --git a/tools/converter/source/optimizer/postconvert/RemoveUnusefulOp.cpp b/tools/converter/source/optimizer/postconvert/RemoveUnusefulOp.cpp index 03235eeb..631ca5d0 100644 --- a/tools/converter/source/optimizer/postconvert/RemoveUnusefulOp.cpp +++ b/tools/converter/source/optimizer/postconvert/RemoveUnusefulOp.cpp @@ -33,6 +33,11 @@ public: return true; } } + if (op->type == OpType_Identity) { + // Support 1->N + return op->inputIndexes.size() == 1 && op->outputIndexes.size() > 1; + } + if (op->type == OpType_Cast) { if (op->main.AsCastParam()->dstT == op->main.AsCastParam()->srcT) { return true; diff --git a/tools/cpp/MNN2QNNModel.cpp b/tools/cpp/MNN2QNNModel.cpp index 707e5a1d..7ea7716e 100644 --- a/tools/cpp/MNN2QNNModel.cpp +++ b/tools/cpp/MNN2QNNModel.cpp @@ -37,11 +37,21 @@ int main(int argc, const char* argv[]) { /** generate qnn .cpp and .bin */ - + std::string dstModelName = dstMNN; + size_t pos = dstModelName.find_last_of("/\\"); + std::string dstModelPath; + if (pos == std::string::npos) { + // current path + dstModelPath = "./"; + } else { + dstModelPath = dstModelName.substr(0, pos); + } + std::string qnnModelPath = dstModelPath + "/" + qnnModelName; + MNN_PRINT("[Temp Product]: Qnn temp product generate at %s\n", qnnModelPath.c_str()); MNN::ScheduleConfig config; config.type = MNN_FORWARD_NN; std::shared_ptr rtmgr(Executor::RuntimeManager::createRuntimeManager(config)); - rtmgr->setCache(qnnModelName.c_str()); + rtmgr->setCache(qnnModelPath.c_str()); MNN::Express::Module::Config mConfig; mConfig.shapeMutable = false; std::shared_ptr m(MNN::Express::Module::load(inputNames, outputNames, srcMNN, rtmgr, &mConfig), MNN::Express::Module::destroy); @@ -64,7 +74,7 @@ int main(int argc, const char* argv[]) { } int ret = 0; - std::string tarBinCmd = "cd " + qnnModelName + \ + std::string tarBinCmd = "cd " + qnnModelPath + \ " && " + \ "tar -cf " + qnnModelName + ".bin *.raw"; ret = system(tarBinCmd.c_str()); @@ -74,10 +84,10 @@ int main(int argc, const char* argv[]) { } std::string modelLibCmd = qnnSdkPath + "/bin/x86_64-linux-clang/qnn-model-lib-generator " + \ - "-c " + qnnModelName + "/" + qnnModelName + ".cpp " + \ - "-b " + qnnModelName 
+ "/" + qnnModelName + ".bin " + \ + "-c " + qnnModelPath + "/" + qnnModelName + ".cpp " + \ + "-b " + qnnModelPath + "/" + qnnModelName + ".bin " + \ "-t x86_64-linux-clang " + \ - "-o " + qnnModelName + "/lib "; + "-o " + qnnModelPath + "/lib "; ret = system(modelLibCmd.c_str()); if(ret) { MNN_ERROR("[Error]: qnn-model-lib-generator error!\n"); @@ -86,12 +96,13 @@ int main(int argc, const char* argv[]) { MNN_PRINT("[Pass]: qnn-model-lib-generator success!\n"); } + std::string qnnBin = dstModelPath + "/" + qnnModelName + ".bin"; std::string binaryGenCmd = qnnSdkPath + "/bin/x86_64-linux-clang/qnn-context-binary-generator " + \ - "--model " + qnnModelName + "/lib/x86_64-linux-clang/lib" + qnnModelName + ".so " + \ + "--model " + qnnModelPath + "/lib/x86_64-linux-clang/lib" + qnnModelName + ".so " + \ "--backend " + qnnSdkPath + "/lib/x86_64-linux-clang/libQnnHtp.so " + \ "--binary_file " + qnnModelName + " " + \ "--config_file " + qnnContextConfig + " " + \ - "--output_dir " + qnnModelName + "/binary"; + "--output_dir " + dstModelPath; ret = system(binaryGenCmd.c_str()); if(ret) { MNN_ERROR("[Error]: qnn-context-binary-generator error!\n"); @@ -100,7 +111,7 @@ int main(int argc, const char* argv[]) { MNN_PRINT("[Pass]: qnn-context-binary-generator success!\n"); } - + std::vector inputInfos(inputs.size()); for (int i=0; igetInfo(); @@ -122,7 +133,8 @@ int main(int argc, const char* argv[]) { dstNet->oplists.emplace_back(std::move(input)); } - std::string npuPath = std::string("/") + qnnModelName + std::string(".bin"); + std::string npuPath = std::string("/") + qnnModelName + std::string(".bin"); + MNN_PRINT("npu model path:%s\n", npuPath.c_str()); /** Fuse to Op*/ std::unique_ptr op(new OpT); @@ -204,7 +216,7 @@ int main(int argc, const char* argv[]) { } for (int i=0; ikey = "o_" + std::to_string(i) + "_0"; + attr->key = "o_0_" + std::to_string(i); attr->tensor.reset(new BlobT); attr->tensor->dataType = OpCommonUtils::convertDataType(outputInfos[i].type); attr->tensor->dims = outputInfos[i].dim; @@ -240,7 +252,6 @@ int main(int argc, const char* argv[]) { outputOs.close(); MNN_PRINT("[All Pass]: npu model generator success!\n"); - std::string qnnBin = qnnModelName + "/binary/" + qnnModelName + ".bin"; MNN_PRINT("[Output Product]:\nNew mnn model path: %s\nNpu model path: %s\n", dstMNN, qnnBin.c_str()); return 0; } diff --git a/tools/script/convertOnnxTest.py b/tools/script/convertOnnxTest.py index 9f55dd9f..fec82704 100755 --- a/tools/script/convertOnnxTest.py +++ b/tools/script/convertOnnxTest.py @@ -43,6 +43,6 @@ for w in gWrong: print(w) print('TEST_NAME_MODULE: 模型测试\nTEST_CASE_AMOUNT_MODULE: {\"blocked\":0,\"failed\":%d,\"passed\":%d,\"skipped\":0}\n'%(len(gWrong), total_num - len(gWrong))) print('TEST_CASE={\"name\":\"Onnx转换测试\",\"failed\":%d,\"passed\":%d}\n'%(len(gWrong), total_num - len(gWrong))) -print('Total Size: ', mnnsize,' MB, convert ', correct_num," model") +print('Total Size: ', total_size,' MB, convert ', correct_num," model") if len(gWrong) > 0: exit(1) diff --git a/tools/script/genQNNModelsFromMNN.py b/tools/script/genQNNModelsFromMNN.py new file mode 100644 index 00000000..cd7f87fc --- /dev/null +++ b/tools/script/genQNNModelsFromMNN.py @@ -0,0 +1,168 @@ +import json +import copy +import argparse +import os +import subprocess +import shutil # 导入 shutil 模块用于删除目录 + +def generate_all_configs(config_path, graph_name, qnn_sdk_root_path, src_model, executable_path, output_dir): + """ + 为每个组合创建子目录,生成配置文件,并调用C++可执行文件进行模型转换。 + """ + # --- 0. 
Preparation ---
+    # Create the main output directory
+    os.makedirs(output_dir, exist_ok=True)
+    print(f"All generated files will be stored under the main directory: '{output_dir}'")
+
+    # Define the (soc_id, dsp_arch) combinations
+    combinations = [
+        [36, 'v69'],
+        [42, 'v69'],
+        [43, 'v73'],
+        [57, 'v75'],
+        [69, 'v79']
+    ]
+
+    # --- 1. Read the template files ---
+    htp_template_file = os.path.join(config_path, "htp_backend_extensions.json")
+    context_template_file = os.path.join(config_path, "context_config.json")
+
+    try:
+        with open(htp_template_file, 'r', encoding='utf-8') as f:
+            base_htp_data = json.load(f)
+        print(f"Successfully read template file '{htp_template_file}'.")
+
+        with open(context_template_file, 'r', encoding='utf-8') as f:
+            base_context_data = json.load(f)
+        print(f"Successfully read template file '{context_template_file}'.")
+    except FileNotFoundError as e:
+        print(f"Error: template file not found. Make sure '{e.filename}' exists in the given path.")
+        return
+    except json.JSONDecodeError as e:
+        print(f"Error: invalid file format. Check whether {e.doc} is valid JSON.")
+        return
+
+    # --- 2. Iterate over the combinations, generate the files and run the command ---
+    for soc_id, dsp_arch in combinations:
+        print(f"\n{'='*15} Processing combination: soc_id={soc_id}, dsp_arch={dsp_arch} {'='*15}")
+
+        # --- Extra step: create a dedicated sub-directory for the current combination ---
+        new_graph_name = f"{graph_name}_{soc_id}_{dsp_arch}"
+        graph_specific_dir = output_dir
+
+        # --- Part A: generate the htp_backend_extensions file (with updated paths) ---
+        htp_config_data = copy.deepcopy(base_htp_data)
+        try:
+            htp_config_data["graphs"][0]["graph_names"] = [new_graph_name]
+            htp_config_data["devices"][0]["soc_id"] = soc_id
+            htp_config_data["devices"][0]["dsp_arch"] = dsp_arch
+        except (IndexError, KeyError) as e:
+            print(f"Error while processing combination: '{htp_template_file}' has an unexpected structure. Error: {e}")
+            continue
+
+        htp_output_filename = f"htp_backend_extensions_{soc_id}_{dsp_arch}.json"
+        # Update the path so that it points into the new sub-directory
+        htp_output_filepath = os.path.join(graph_specific_dir, htp_output_filename)
+        with open(htp_output_filepath, 'w', encoding='utf-8') as f:
+            json.dump(htp_config_data, f, indent=4, ensure_ascii=False)
+        print(f"-> Generated config file: '{htp_output_filepath}'")
+
+        # --- Part B: generate the context_config file (with updated paths) ---
+        context_config_data = copy.deepcopy(base_context_data)
+        try:
+            # htp_output_filename here is a relative path, which is correct because
+            # context_config and htp_backend_extensions live in the same directory.
+            context_config_data["backend_extensions"]["config_file_path"] = htp_output_filepath
+            path_template = context_config_data["backend_extensions"]["shared_library_path"]
+            new_lib_path = path_template.replace("{QNN_SDK_ROOT}", qnn_sdk_root_path)
+            context_config_data["backend_extensions"]["shared_library_path"] = new_lib_path
+        except KeyError as e:
+            print(f"Error while processing combination: '{context_template_file}' has an unexpected structure, missing key: {e}")
+            continue
+
+        context_output_filename = f"context_config_{soc_id}_{dsp_arch}.json"
+        # Update the path so that it points into the new sub-directory
+        context_output_filepath = os.path.join(graph_specific_dir, context_output_filename)
+        with open(context_output_filepath, 'w', encoding='utf-8') as f:
+            json.dump(context_config_data, f, indent=4, ensure_ascii=False)
+        print(f"-> Generated companion file: '{context_output_filepath}'")
+
+        # --- Part C: invoke the C++ executable (with updated paths) ---
+        dst_model_filename = f"{graph_name}_{soc_id}_{dsp_arch}.mnn"
+        # Update the path so that it points into the new sub-directory
+        dst_model_filepath = os.path.join(graph_specific_dir, dst_model_filename)
+
+        graph_product_dir = os.path.join(graph_specific_dir, new_graph_name)
+        os.makedirs(graph_product_dir, exist_ok=True)
+        print(f"-> Created/verified sub-directory: '{graph_product_dir}'")
+
+        command = [
+            executable_path,
+            src_model,
+            dst_model_filepath,
+            qnn_sdk_root_path,
+            new_graph_name,
+            context_output_filepath
+        ]
+
+        print("--> Preparing to run the command...")
+        print(f"    $ {' '.join(command)}")
+
+        try:
+            result = subprocess.run(command, check=True, capture_output=True, text=True)
+            print("--> Command executed successfully!")
+            # Even on success, print the C++ program's output; this is useful for
+            # spotting warnings and other diagnostics.
+            if result.stdout:
+                print("    --- C++ program output (stdout) ---")
+                print(result.stdout.strip())
+                print("    ------------------------------")
+
+        except FileNotFoundError:
+            print(f"!!! Command failed: executable not found at '{executable_path}'. Please check the path.")
+            break  # If the executable cannot be found, exit the loop immediately
+        except subprocess.CalledProcessError as e:
+            # This is the key error-handling part
+            print(f"!!! Command failed (return code: {e.returncode})")
+
+            # Check and print the stdout the C++ program produced before failing
+            if e.stdout:
+                print("    --- C++ program output (stdout) ---")
+                print(e.stdout.strip())
+                print("    ------------------------------")
+
+            # Check and print the stderr the C++ program produced before failing
+            # (the error log is usually here)
+            if e.stderr:
+                print("    --- C++ program errors (stderr) ---")
+                print(e.stderr.strip())
+                print("    ------------------------------")
+        except Exception as e:
+            print(f"!!! Unknown error during execution: {e}")
+
+        finally:
+            # --- Step 3: cleanup ---
+            # Check whether the directory exists, then remove it
+            if os.path.exists(graph_product_dir):
+                print(f"--> Cleaning up temporary files and directory: '{graph_product_dir}'")
+                shutil.rmtree(graph_product_dir)
+            else:
+                print("--> Nothing to clean up, the temporary directory was never created.")
+
+# --- Script entry point ---
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Create a sub-directory for each combination, generate the QNN config files and invoke the model conversion tool.",
+        formatter_class=argparse.RawTextHelpFormatter
+    )
+    # ... (the argparse section is unchanged) ...
+    gen_group = parser.add_argument_group('File generation arguments')
+    gen_group.add_argument("--config_path", required=True, help="[Required] Directory containing the template files.")
+    gen_group.add_argument("--graph_name", required=True, help="[Required] Name of the model graph (without suffixes such as soc_id).")
+    gen_group.add_argument("--qnn_sdk_root_path", required=True, help="[Required] Root path of the QNN SDK.")
+
+    exec_group = parser.add_argument_group('Model conversion arguments')
+    exec_group.add_argument("--src_model", required=True, help="[Required] Path of the source model file (e.g. my_model.mnn).")
+    exec_group.add_argument("--executable_path", required=True, help="[Required] Path of the C++ model conversion executable.")
+    exec_group.add_argument("--output_dir", default="./qnn_models", help="Output directory for all generated files (default: ./qnn_models).")
+
+    args = parser.parse_args()
+    generate_all_configs(args.config_path, args.graph_name, args.qnn_sdk_root_path, args.src_model, args.executable_path, args.output_dir)
diff --git a/transformers/llm/engine/CMakeLists.txt b/transformers/llm/engine/CMakeLists.txt
index 79361828..955a67a8 100644
--- a/transformers/llm/engine/CMakeLists.txt
+++ b/transformers/llm/engine/CMakeLists.txt
@@ -51,6 +51,21 @@ if (LLM_SUPPORT_AUDIO AND MNN_BUILD_AUDIO)
     add_definitions(-DLLM_SUPPORT_AUDIO)
 endif()
 
+IF(CMAKE_SYSTEM_NAME MATCHES "^Android" AND NOT MNN_BUILD_FOR_ANDROID_COMMAND)
+IF(NOT NATIVE_INCLUDE_OUTPUT)
+    set(NATIVE_INCLUDE_OUTPUT ".")
+ENDIF()
+add_custom_command(
+    TARGET llm
+    POST_BUILD
+    COMMAND ${CMAKE_COMMAND}
+    ARGS -E copy_directory ${CMAKE_CURRENT_LIST_DIR}/include ${NATIVE_INCLUDE_OUTPUT}
+)
+ELSE()
+INSTALL(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/include/ DESTINATION include FILES_MATCHING PATTERN *.hpp)
+ENDIF()
+
+
 add_executable(llm_demo ${CMAKE_CURRENT_LIST_DIR}/demo/llm_demo.cpp)
 add_executable(embedding_demo ${CMAKE_CURRENT_LIST_DIR}/demo/embedding_demo.cpp)
 add_executable(reranker_demo ${CMAKE_CURRENT_LIST_DIR}/demo/reranker_demo.cpp)
diff --git a/transformers/llm/engine/include/llm/llm.hpp b/transformers/llm/engine/include/llm/llm.hpp
index af25ef18..dc515e94 100644
--- a/transformers/llm/engine/include/llm/llm.hpp
+++ b/transformers/llm/engine/include/llm/llm.hpp
@@ -37,6 +37,23 @@ struct TimePerformance;
 using ChatMessage = std::pair<std::string, std::string>; // <role, content>
 using ChatMessages = std::vector<ChatMessage>;
 
+struct MNN_PUBLIC PromptImagePart {
+    MNN::Express::VARP image_data;
+    int width;
+    int height;
+};
+
+struct MNN_PUBLIC PromptAudioPart {
+    std::string file_path;
+    MNN::Express::VARP waveform;
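+    // Either field may be supplied; when both are set, the in-memory waveform is
+    // preferred over file_path (see Omni::processAudioContent).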
+}; + +struct MNN_PUBLIC MultimodalPrompt { + std::string prompt_template; + std::map images; + std::map audios; +}; + enum TuneType { // op encoder number for commit OP_ENCODER_NUMBER = 0, @@ -108,11 +125,17 @@ public: std::string dump_config(); bool set_config(const std::string& content); Llm* create_lora(const std::string& lora_path); + std::string get_statistics(); // tokenier function bool is_stop(int token); std::string tokenizer_decode(int token); virtual std::vector tokenizer_encode(const std::string& query); friend class Pipeline; + virtual std::vector tokenizer_encode(const MultimodalPrompt& multimodal_input); + void response(const MultimodalPrompt& multimodal_input, + std::ostream* os = &std::cout, + const char* end_with = nullptr, + int max_new_tokens = -1); const LlmContext* getContext() const { return mContext.get(); } diff --git a/transformers/llm/engine/src/diskembedding.cpp b/transformers/llm/engine/src/diskembedding.cpp index 47143eca..fb0ed1e3 100644 --- a/transformers/llm/engine/src/diskembedding.cpp +++ b/transformers/llm/engine/src/diskembedding.cpp @@ -48,6 +48,7 @@ DiskEmbedding::DiskEmbedding(const std::shared_ptr& config, std::stri if (mQuantBit != 16) { if (mQuantBlock == 0) { mBlockNum = 1; + mQuantBlock = mHiddenSize; // be used for mDequantFunc. } else { mBlockNum = mHiddenSize / mQuantBlock; } diff --git a/transformers/llm/engine/src/llm.cpp b/transformers/llm/engine/src/llm.cpp index c0bee418..d3480635 100644 --- a/transformers/llm/engine/src/llm.cpp +++ b/transformers/llm/engine/src/llm.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -97,11 +98,44 @@ bool Llm::set_config(const std::string& content) { return res; } +std::string Llm::get_statistics() { + auto context = getContext(); + int prompt_len = context->prompt_len; + int decode_len = context->gen_seq_len; + int64_t vision_time = context->vision_us; + int64_t audio_time = context->audio_us; + int64_t prefill_time = context->prefill_us; + int64_t decode_time = context->decode_us; + int64_t sample_time = context->sample_us; + float vision_s = vision_time / 1e6; + float audio_s = audio_time / 1e6; + float prefill_s = prefill_time / 1e6; + float decode_s = decode_time / 1e6; + float sample_s = sample_time / 1e6; + float prefill_speed = (prefill_s > 0.0f) ? (prompt_len / prefill_s) : 0.0f; + float decode_speed = (decode_s > 0.0f) ? 
(decode_len / decode_s) : 0.0f; + + std::ostringstream json_stream; + json_stream << "{" + << "\"prompt_tokens\":" << prompt_len << "," + << "\"decode_tokens\":" << decode_len << "," + << "\"vision_time\":" << std::fixed << std::setprecision(2) << vision_s << "," + << "\"audio_time\":" << std::fixed << std::setprecision(2) << audio_s << "," + << "\"prefill_time\":" << std::fixed << std::setprecision(2) << prefill_s << "," + << "\"decode_time\":" << std::fixed << std::setprecision(2) << decode_s << "," + << "\"sample_time\":" << std::fixed << std::setprecision(2) << sample_s << "," + << "\"prefill_speed\":" << std::fixed << std::setprecision(2) << prefill_speed << "," + << "\"decode_speed\":" << std::fixed << std::setprecision(2) << decode_speed + << "}"; + + return json_stream.str(); +} + void Llm::setRuntimeHint(std::shared_ptr &rtg) { rtg->setHint(MNN::Interpreter::INIT_THREAD_NUMBER, 4); rtg->setHint(MNN::Interpreter::MEM_ALLOCATOR_TYPE, 0); - rtg->setHint(MNN::Interpreter::QKV_QUANT_OPTIONS, mConfig->quant_qkv()); + rtg->setHint(MNN::Interpreter::QKV_QUANT_OPTIONS, mConfig->config_.value("quant_qkv", 8)); rtg->setHint(MNN::Interpreter::KVCACHE_SIZE_LIMIT, mConfig->kvcache_limit()); if (mConfig->use_cached_mmap()) { rtg->setHint(MNN::Interpreter::USE_CACHED_MMAP, 1); @@ -113,6 +147,8 @@ void Llm::setRuntimeHint(std::shared_ptr &rtg if (mConfig->use_mmap()) { rtg->setExternalPath(tmpPath, MNN::Interpreter::EXTERNAL_WEIGHT_DIR); } + // set npu model dir + rtg->setExternalPath(mConfig->npu_model_dir(), 3); auto dynamicOption = mConfig->dynamic_option(); if (mConfig->dynamic_option()) { rtg->setHint(MNN::Interpreter::DYNAMIC_QUANT_OPTIONS, mConfig->dynamic_option()); @@ -246,8 +282,7 @@ void Llm::load() { if (needHiddenState) { outputNames.emplace_back("hidden_states"); } - // set npu model dir - mRuntimeManager->setExternalPath(mConfig->npu_model_dir(), 3); + mRuntimeManager->setExternalFile(mConfig->llm_weight()); mModules[0].reset(Module::load(inputNames, outputNames, model_path.c_str(), mRuntimeManager, &module_config)); mRuntimeManager->setExternalFile(""); @@ -594,6 +629,29 @@ std::vector Llm::tokenizer_encode(const std::string& user_content) { return mTokenizer->encode(user_content); } +std::vector Llm::tokenizer_encode(const MultimodalPrompt& multimodal_input) { + return mTokenizer->encode(multimodal_input.prompt_template); +} + +void Llm::response(const MultimodalPrompt& multimodal_input, + std::ostream* os, const char* end_with, int max_new_tokens) { + auto prompt = multimodal_input.prompt_template; + if (mConfig->use_template()) { + prompt = mPrompt->applyTemplate(prompt, true); + } + + int prompt_len = 0; + int decode_len = 0; + int64_t vision_time = 0; + int64_t audio_time = 0; + int64_t prefill_time = 0; + int64_t decode_time = 0; + int64_t sample_time = 0; + + std::vector input_ids = tokenizer_encode(multimodal_input); + response(input_ids, os, end_with, max_new_tokens); +} + std::vector Llm::generate(MNN::Express::VARP input_embeds, int max_tokens) { if (max_tokens < 0) { max_tokens = mConfig->max_new_tokens(); @@ -621,8 +679,8 @@ std::vector Llm::generate(MNN::Express::VARP input_embeds, int max_tokens) } { std::ofstream outFile("logits.txt"); - auto temp = outputs[0]->readMap(); - for (size_t i = 0; i < outputs[0]->getInfo()->size; ++i) { + auto temp = mGenerateParam->outputs[0]->readMap(); + for (size_t i = 0; i < mGenerateParam->outputs[0]->getInfo()->size; ++i) { outFile << temp[i] << " "; // 每个数字后加空格 } outFile.close(); diff --git 
a/transformers/llm/engine/src/llmconfig.hpp b/transformers/llm/engine/src/llmconfig.hpp index b2d3b929..58386d90 100644 --- a/transformers/llm/engine/src/llmconfig.hpp +++ b/transformers/llm/engine/src/llmconfig.hpp @@ -359,10 +359,6 @@ public: return config_.value("memory", "low"); } - int quant_qkv() const { - return config_.value("quant_qkv", 0); - } - int kvcache_limit() const { return config_.value("kvcache_limit", -1); } diff --git a/transformers/llm/engine/src/omni.cpp b/transformers/llm/engine/src/omni.cpp index 7a26269e..7ad98036 100644 --- a/transformers/llm/engine/src/omni.cpp +++ b/transformers/llm/engine/src/omni.cpp @@ -9,6 +9,7 @@ #define _USE_MATH_DEFINES #endif #include +#include #include #include #include "omni.hpp" @@ -555,8 +556,16 @@ std::vector Omni::minicpmVisionProcess(VARP image) { std::vector Omni::visionProcess(const std::string& file) { #ifdef LLM_SUPPORT_VISION VARP image = MNN::CV::imread(file); + return visionProcess(image); +#else + return std::vector(0); +#endif +} + +std::vector Omni::visionProcess(VARP image) { +#ifdef LLM_SUPPORT_VISION if (image == nullptr) { - MNN_PRINT("Omni Can't open image: %s\n", file.c_str()); + MNN_PRINT("Omni Can't open image\n"); return std::vector(0); } Timer _t; @@ -573,7 +582,7 @@ std::vector Omni::visionProcess(const std::string& file) { } else { imgIds = defaultVisionProcess(image); } - mContext->vision_us = _t.durationInUs(); + mContext->vision_us += _t.durationInUs(); // set vision number for image idx mVisionNum += 1; return imgIds; @@ -591,9 +600,19 @@ std::vector Omni::audioProcess(const std::string& file) { MNN_PRINT("Omni Can't open audio: %s\n", file.c_str()); return std::vector(0); } - // int sample_rate = load_res.second; - int wav_len = waveform->getInfo()->dim[0]; - int hop_length = 160; + return audioProcess(waveform); +#else + return std::vector(0); +#endif +} + +std::vector Omni::audioProcess(MNN::Express::VARP waveform) { +#ifdef LLM_SUPPORT_AUDIO + if (waveform == nullptr) { + MNN_PRINT("Omni Can't process audio: waveform is null\n"); + return std::vector(0); + } + Timer _t; auto input_features = MNN::AUDIO::whisper_fbank(waveform); VARP audio_embedding; @@ -727,21 +746,30 @@ void Omni::addPositionIds(int t, int h, int w) { } } -std::vector Omni::tokenizer_encode(const std::string& prompt) { - // split query +std::vector Omni::tokenizer_encode(const MultimodalPrompt& multimodal_input) { + std::string prompt = multimodal_input.prompt_template; + // MNN_PRINT("tokenizer_encode(MultimodalPrompt) prompt: %s", prompt.c_str()); std::regex multimode_regex("<(img|audio)>(.*?)"); std::string::const_iterator searchStart(prompt.cbegin()); std::smatch match; - std::vector img_infos; std::vector ids{}; - mPositionIds.clear(); + while (std::regex_search(searchStart, prompt.cend(), match, multimode_regex)) { - // std::cout << "img match: " << match[1].str() << std::endl; auto txt_ids = mTokenizer->encode(match.prefix().str()); addPositionIds(txt_ids.size()); ids.insert(ids.end(), txt_ids.begin(), txt_ids.end()); - auto mul_ids = multimodeProcess(match[1].str(), match[2].str()); + std::string mode = match[1].str(); + std::string content = match[2].str(); + std::vector mul_ids; + if (mode == "img") { + mul_ids = processImageContent(content, multimodal_input.images); + // MNN_PRINT("tokenizer_encode(MultimodalPrompt) image mul_ids size: %lu", mul_ids.size()); + } else if (mode == "audio") { + mul_ids = processAudioContent(content, multimodal_input.audios); + // MNN_PRINT("tokenizer_encode(MultimodalPrompt) audio mul_ids 
size: %lu", mul_ids.size()); + } + ids.insert(ids.end(), mul_ids.begin(), mul_ids.end()); searchStart = match.suffix().first; } @@ -753,6 +781,43 @@ std::vector Omni::tokenizer_encode(const std::string& prompt) { return ids; } +std::vector Omni::tokenizer_encode(const std::string& prompt) { + MultimodalPrompt multimodal_input; + multimodal_input.prompt_template = prompt; + return tokenizer_encode(multimodal_input); +} + +std::vector Omni::processImageContent(const std::string& content, const std::map& images) { + auto it = images.find(content); + if (it != images.end()) { + if (it->second.height > 0 && it->second.width > 0) { + mVisionHeight = it->second.height; + mVisionWidth = it->second.width; + } + // MNN_PRINT("processImageContent: using placeholder '%s' with size %dx%d", content.c_str(), mVisionWidth, mVisionHeight); + return visionProcess(it->second.image_data); + } + // MNN_PRINT("processImageContent: treating '%s' as file path or URL", content.c_str()); + return multimodeProcess("img", content); +} + +std::vector Omni::processAudioContent(const std::string& content, const std::map& audios) { + auto it = audios.find(content); + if (it != audios.end()) { + // MNN_PRINT("processAudioContent: using placeholder '%s'", content.c_str()); + if (it->second.waveform.get() != nullptr) { + return audioProcess(it->second.waveform); + } else if (!it->second.file_path.empty()) { + return audioProcess(it->second.file_path); + } else { + MNN_PRINT("processAudioContent: audio_part has no valid input\n"); + return std::vector(0); + } + } + // MNN_PRINT("processAudioContent: treating '%s' as file path", content.c_str()); + return multimodeProcess("audio", content); +} + VARP Omni::embedding(const std::vector& input_ids) { if (input_ids.size() == 1) { return Llm::embedding(input_ids); @@ -845,21 +910,28 @@ VARP Omni::gen_position_ids(int seq_len) { auto ptr = positionIds->writeMap(); if (mContext->gen_seq_len > 0) { for (int i=0; igen_seq_len + mPositionIds.back() + i; + // auto pos = mContext->gen_seq_len + mPositionIds.back() + i; + auto pos = mContext->all_seq_len + i; ptr[i + 0] = pos; ptr[i + seq_len] = pos; ptr[i + seq_len * 2] = pos; } } else { for (int i = 0; i < seq_len; i++) { - ptr[i] = mPositionIds.mT[i]; - ptr[i + seq_len] = mPositionIds.mH[i]; - ptr[i + seq_len * 2] = mPositionIds.mW[i]; + ptr[i] = mPositionIds.mT[i] + mContext->all_seq_len; + ptr[i + seq_len] = mPositionIds.mH[i] + mContext->all_seq_len; + ptr[i + seq_len * 2] = mPositionIds.mW[i] + mContext->all_seq_len; } if (mTalker) { mTalker->setPostionIds(mPositionIds); } } + // // dump position ids + // printf("position_ids = ["); + // for (int i = 0; i < seq_len; i++) { + // printf("%d ", ptr[i]); + // } + // printf("]\n"); return positionIds; } diff --git a/transformers/llm/engine/src/omni.hpp b/transformers/llm/engine/src/omni.hpp index 217e0ffa..4802e82e 100644 --- a/transformers/llm/engine/src/omni.hpp +++ b/transformers/llm/engine/src/omni.hpp @@ -105,12 +105,14 @@ public: virtual void load() override; virtual std::vector forwardRaw(Express::VARP hiddenState, Express::VARP mask, Express::VARP inputPos) override; virtual std::vector tokenizer_encode(const std::string& query) override; + virtual std::vector tokenizer_encode(const MultimodalPrompt& multimodal_input) override; virtual Express::VARP embedding(const std::vector& input_ids) override; virtual Express::VARP gen_position_ids(int seq_len) override; virtual void response(const std::vector& input_ids, std::ostream* os = &std::cout, const char* end_with = nullptr, int 
max_new_tokens = -1) override; virtual void setWavformCallback(std::function callback) override; virtual void generateWavform() override; // some models preprocess function + std::vector visionProcess(VARP image); std::vector defaultVisionProcess(VARP image); std::vector qwen2VisionProcess(VARP image); std::vector smolvlmVisionProcess(VARP image); @@ -126,6 +128,9 @@ private: std::vector multimodeProcess(const std::string& mode, std::string info); std::vector visionProcess(const std::string& file); std::vector audioProcess(const std::string& file); + std::vector audioProcess(MNN::Express::VARP waveform); + std::vector processImageContent(const std::string& content, const std::map& images); + std::vector processAudioContent(const std::string& content, const std::map& audios); std::shared_ptr mVisionModule, mAudioModule; std::vector mVisionEmbeddings, mAudioEmbeddings; std::shared_ptr mTalker; diff --git a/transformers/llm/export/llmexport.py b/transformers/llm/export/llmexport.py index cb307fd4..a8700354 100644 --- a/transformers/llm/export/llmexport.py +++ b/transformers/llm/export/llmexport.py @@ -225,6 +225,7 @@ class LlmExporter(torch.nn.Module): self.llm_config['jinja'] = prompt_template['jinja'] # load modules ModelMapper.do_map(self, self.model, self.model_map['model']) + # rebuild modules if self.lm_ is None: out_features, in_features = self.embed_.weight.shape @@ -244,6 +245,7 @@ class LlmExporter(torch.nn.Module): self.embed = Embedding(embed_copy, self) else: self.embed = Embedding(self.embed_, self) + # tie word embeddings self.tie_word_embeddings = not self.args.seperate_embed and self.lm_.weight.equal(self.embed_.weight) if self.tie_word_embeddings: @@ -802,14 +804,21 @@ class LlmExporter(torch.nn.Module): if self.mnn_converter: fuse_transformer = self.visual.transformer_fuse native_group_conv = self.visual.group_conv_native + quant_bit_visual = self.visual.quant_bit + quant_block_visual = self.visual.quant_block if self.args.transformer_fuse: fuse_transformer = True if self.args.group_conv_native: native_group_conv = True - self.mnn_converter.export(vision_onnx, self.visual.quant_bit, - self.visual.quant_block, + if self.args.visual_quant_bit is not None: + quant_bit_visual = self.args.visual_quant_bit + if self.args.visual_quant_block is not None: + quant_block_visual = self.args.visual_quant_block + self.mnn_converter.export(vision_onnx, quant_bit_visual, + quant_block_visual, transformer_fuse=fuse_transformer, - group_conv_native=native_group_conv) + group_conv_native=native_group_conv, + weight_sym=self.args.visual_sym) def export_audio(self): if self.audio is None: @@ -1237,6 +1246,9 @@ def export(path, 'onnx_slim': onnx_slim, 'quant_bit': quant_bit, 'quant_block': quant_block, + 'visual_quant_bit': visual_quant_bit, + 'visual_quant_block': visual_quant_block, + 'visual_sym': visual_sym, 'lm_quant_bit': lm_quant_bit, 'mnnconvert': mnnconvert, 'ppl': ppl, @@ -1276,6 +1288,8 @@ def main(): parser.add_argument('--onnx_slim', action='store_true', help='Whether or not to use onnx-slim.') parser.add_argument('--quant_bit', type=int, default=4, help='mnn quant bit, 4 or 8, default is 4.') parser.add_argument('--quant_block', type=int, default=64, help='mnn quant block, 0 mean channle-wise, default is 64.') + parser.add_argument('--visual_quant_bit', type=int, default=None, help='mnn viusal quant bit, 4 or 8, default is setting in utils/vision.py by different vit model.') + parser.add_argument('--visual_quant_block', type=int, default=None, help='mnn quant block, default is 
setting in utils/vision.py by different vit model.') parser.add_argument('--lm_quant_bit', type=int, default=None, help='mnn lm_head quant bit, 4 or 8, default is `quant_bit`.') parser.add_argument('--mnnconvert', type=str, default='../../../build/MNNConvert', help='local mnnconvert path, if invalid, using pymnn.') parser.add_argument('--ppl', action='store_true', help='Whether or not to get all logits of input tokens.') @@ -1284,9 +1298,10 @@ def main(): parser.add_argument('--transformer_fuse', action='store_true', help='Whether or not to fuse vision transformer op.') parser.add_argument('--group_conv_native', action='store_true', help='Whether or not to keep native group_conv.') parser.add_argument('--smooth', action='store_true', help='Whether or not to use smooth quant.') - parser.add_argument('--sym', action='store_true', help='Whether or not to using symmetric quant (without zeropoint), defualt is False.') - parser.add_argument('--seperate_embed', action='store_true', help='For lm and embed shared model, whether or not to sepearte embed to avoid quant, defualt is False, if True, embed weight will be seperate to embeddingbf16.bin.') - parser.add_argument('--lora_split', action='store_true', help='Whether or not export lora split, defualt is False.') + parser.add_argument('--sym', action='store_true', help='Whether or not to using symmetric quant (without zeropoint), default is False.') + parser.add_argument('--visual_sym', action='store_true', help='Whether or not to using symmetric quant (without zeropoint) for visual model, default is False.') + parser.add_argument('--seperate_embed', action='store_true', help='For lm and embed shared model, whether or not to sepearte embed to avoid quant, default is False, if True, embed weight will be seperate to embeddingbf16.bin.') + parser.add_argument('--lora_split', action='store_true', help='Whether or not export lora split, default is False.') parser.add_argument('--calib_data', type=str, default=None, help='calibration data path, defaut is `None` mean not use calib data.') args = parser.parse_args() diff --git a/transformers/llm/export/utils/mnn_converter.py b/transformers/llm/export/utils/mnn_converter.py index 59aee05f..3fb79d04 100644 --- a/transformers/llm/export/utils/mnn_converter.py +++ b/transformers/llm/export/utils/mnn_converter.py @@ -51,7 +51,7 @@ class MNNConveter: os.close(log_fd) @spinner_run(f'convert onnx model to ') - def onnx2mnn(self, onnx_path, mnn_path, args = [], transformer_fuse = True, group_conv_native = False, save_external_data = True): + def onnx2mnn(self, onnx_path, mnn_path, args = [], transformer_fuse = True, group_conv_native = False, weight_sym = False, save_external_data = True): convert_args = [ '', '-f', @@ -66,6 +66,8 @@ class MNNConveter: convert_args += ['--transformerFuse'] if group_conv_native: convert_args += ['--groupConvNative'] + if weight_sym: + convert_args += ['--weightQuantAsymmetric=0'] if save_external_data: convert_args += ['--saveExternalData'] convert_args += args @@ -112,7 +114,7 @@ class MNNConveter: self.convert(convert_args) return mnn_path - def export(self, onnx_path, quant_bit = None, quant_block = None, transformer_fuse = True, group_conv_native = False): + def export(self, onnx_path, quant_bit = None, quant_block = None, transformer_fuse = True, group_conv_native = False, weight_sym = None): self.onnx_model_path = onnx_path self.mnn_name = os.path.basename(onnx_path).replace('.onnx', '.mnn') self.mnn_model_path = os.path.join(self.config.args.dst_path, self.mnn_name) @@ 
-133,10 +135,10 @@ class MNNConveter: ] if quant_bit == 32: quant_args = [] - self.onnx2mnn(self.onnx_model_path, self.mnn_model_path, quant_args, transformer_fuse=transformer_fuse, group_conv_native=group_conv_native) + self.onnx2mnn(self.onnx_model_path, self.mnn_model_path, quant_args, transformer_fuse=transformer_fuse, group_conv_native=group_conv_native, weight_sym=weight_sym) else: mnn_json = f'{self.mnn_model_path}.json' - self.onnx2mnn(self.onnx_model_path, self.mnn_model_path, transformer_fuse=transformer_fuse, group_conv_native=group_conv_native) + self.onnx2mnn(self.onnx_model_path, self.mnn_model_path, transformer_fuse=transformer_fuse, group_conv_native=group_conv_native, weight_sym=weight_sym) self.mnn2json(self.mnn_model_path, mnn_json) self.rebuild(mnn_json) self.json2mnn(mnn_json, self.mnn_model_path) @@ -511,4 +513,4 @@ class MNNConveter: } if name.startswith('/expert/'): post_reshape['main']['dims'] = [-1, oc] - return [pre_reshape, pre_convert, conv_op, post_convert, post_reshape] \ No newline at end of file + return [pre_reshape, pre_convert, conv_op, post_convert, post_reshape] diff --git a/transformers/llm/export/utils/transformers.py b/transformers/llm/export/utils/transformers.py index f48cb6de..5b761960 100644 --- a/transformers/llm/export/utils/transformers.py +++ b/transformers/llm/export/utils/transformers.py @@ -266,7 +266,6 @@ class Rotary(torch.nn.Module): def get_theta(): return 1.0 / (self.rope_theta ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float32) / self.rotary_dim)) - # default rope type's theta self.theta = get_theta() # other type @@ -278,7 +277,6 @@ class Rotary(torch.nn.Module): rope_type = config.rope_scaling['type'] elif 'rope_type' in config.rope_scaling: rope_type = config.rope_scaling['rope_type'] - # gen theta for rope_type if rope_type == 'dynamic': # NTK if 'alpha' in config.rope_scaling: # NTKAlpha in Hunyuan @@ -286,7 +284,7 @@ class Rotary(torch.nn.Module): else: # NTKScaling pass self.theta = get_theta() - elif rope_type == 'yarn': # YaRN in gpt-oss + elif rope_type == 'yarn': self.is_scaled = True self.theta, self.attention_scaling = _compute_yarn_parameters( rotary_dim=self.rotary_dim,