mirror of https://github.com/alibaba/MNN.git
				
				
				
			MNN:Sync: Sync Internal 3.2.3
This commit is contained in:
		
							parent
							
								
									8f175e2748
								
							
						
					
					
						commit
						318a3de860
					
				|  | @ -376,4 +376,6 @@ datasets/* | |||
| 
 | ||||
| # qnn 3rdParty | ||||
| source/backend/qnn/3rdParty/include | ||||
| apps/Android/MnnLlmChat/release_outputs | ||||
| project/android/.cxx | ||||
| pymnn/android/.cxx/ | ||||
| pymnn/android/.cxx/abi_configuration_5u53tc49.json | ||||
|  |  | |||
|  | @ -11,8 +11,6 @@ | |||
| 
 | ||||
| 
 | ||||
| ## News 🔥 | ||||
| - [2025/08/08] Now we support [gpt-oss-20b](./apps/Android/MnnLlmChat/README.md#releases). | ||||
| - [2025/08/05] MNN Chat Android is availabe in [GooglePlay](https://play.google.com/store/apps/details?id=com.alibaba.mnnllm.android.release) ! | ||||
| - [2025/06/11] New App MNN TaoAvatar released, you can talk with 3DAvatar offline with LLM, ASR, TTS, A2BS and NNR models all run local on your device!! [MNN TaoAvatar](./apps/Android/Mnn3dAvatar/README.md)  | ||||
| <p align="center"> | ||||
|   <img width="20%" alt="Icon"  src="https://meta.alicdn.com/data/mnn/avatar/avatar_demo.gif" style="margin: 0 10px;"> | ||||
|  |  | |||
|  | @ -183,6 +183,21 @@ struct Info { | |||
| const Info* getInfo() const; | ||||
| ``` | ||||
| 
 | ||||
| ### 获取设备信息 | ||||
| 调用`getDeviceInfo`函数可获取`Device`信息,可以参考代码: | ||||
| ```cpp | ||||
| std::string soc_id, dsp_arch; | ||||
| bool success = MNN::Express::Executor::RuntimeManager::getDeviceInfo("dsp_arch", MNN_FORWARD_NN, dsp_arch); | ||||
| if(success) { | ||||
|     MNN_PRINT("Device dsp_arch: %s\n", dsp_arch.c_str()); | ||||
| } | ||||
| 
 | ||||
| success = MNN::Express::Executor::RuntimeManager::getDeviceInfo("soc_id", MNN_FORWARD_NN, soc_id); | ||||
| if(success) { | ||||
|     MNN_PRINT("Device soc_id: %s\n", soc_id.c_str()); | ||||
| } | ||||
| ``` | ||||
| 
 | ||||
| ### 执行推理 | ||||
| 调用`onForward`执行推理。 | ||||
| 
 | ||||
|  |  | |||
|  | @ -58,6 +58,10 @@ adb push ${MNN_ROOT}/source/backend/qnn/3rdParty/lib/hexagon-v${HEXAGON_ARCH}/un | |||
| adb shell "cd /data/local/tmp && LD_LIBRARY_PATH=/data/local/tmp ADSP_LIBRARY_PATH=/data/local/tmp ./MyExe.out" | ||||
| ``` | ||||
| 
 | ||||
| ### QNN量化功能说明 | ||||
| - 仅权重量化(激活是浮点):只支持Linear权重int8、channel-wise的对称量化。 | ||||
| - 激活&权重都量化:支持激活per-tensor对称量化,权重是int8/int4、channel-wise的对称量化。 | ||||
| 
 | ||||
| ## CoreML | ||||
| 适用于 Mac / iOS / iPad | ||||
| 
 | ||||
|  |  | |||
|  | @ -388,3 +388,12 @@ npu model path:./qnn_smolvlm_model.bin | |||
|   | ||||
| ./ModuleBasic.out qnn_smolvlm_model.mnn dir 0 0 10 | ||||
| ``` | ||||
| ### 生成多种QNN设备模型脚本 | ||||
| tools/script/genQNNModelsFromMNN.py中提供了8Gen1 ~ 8Elite设备的QNN模型生成脚本 | ||||
| ``` | ||||
| // 使用示例 | ||||
| cd mnn_path | ||||
| cd build | ||||
| python3 ../tools/script/genQNNModelsFromMNN.py --config_path ../source/backend/qnn/convertor/config_example/ --graph_name visual_qnn --qnn_sdk_root_path /mnt/2Tpartition/tianbu/QNN/qairt/2.37.0.250724/ --src_model visual.mnn --executable_path ./MNN2QNNModel | ||||
| ``` | ||||
| 后续将在qnn_models文件夹下生成8Gen1 ~ 8Elite设备的QNN模型产物。 | ||||
|  |  | |||
|  | @ -106,6 +106,10 @@ optional arguments: | |||
|                         mnn quant bit, 4 or 8, default is 4. | ||||
|   --quant_block QUANT_BLOCK | ||||
|                         mnn quant block, 0 mean channle-wise, default is 128. | ||||
|   --visual_quant_bit VISUAL_QUANT_BIT | ||||
|                         mnn visual model quant bit, 4 or 8, default is setting in utils/vision.py by different vit model. | ||||
|   --visual_quant_block VISUAL_QUANT_BLOCK | ||||
|                         mnn visual model quant block, 0 mean channle-wise, default is setting in utils/vision.py by different vit model. | ||||
|   --lm_quant_bit LM_QUANT_BIT | ||||
|                         mnn lm_head quant bit, 4 or 8, default is `quant_bit`. | ||||
|   --mnnconvert MNNCONVERT | ||||
|  | @ -113,10 +117,12 @@ optional arguments: | |||
|   --ppl                 Whether or not to get all logits of input tokens. | ||||
|   --awq                 Whether or not to use awq quant. | ||||
|   --sym                 Whether or not to using symmetric quant (without zeropoint), defualt is False. | ||||
|   --visual_sym          Whether or not to using symmetric quant (without zeropoint) for visual model, defualt is False. | ||||
|   --seperate_embed      For lm and embed shared model, whether or not to sepearte embed to avoid quant, defualt is False, if True, embed weight will be seperate to embeddingbf16.bin. | ||||
|   --lora_split          Whether or not export lora split, defualt is False. | ||||
| ``` | ||||
| 
 | ||||
| 
 | ||||
| ### 权重读取 | ||||
| llmexport.py 同时支持 LLM 的验证功能,有较多的依赖。在没有相应环境的情况下,MNN-LLM也提供由 safetensors 或 gguf 文件读取权重的工具,可以降低内存需求,提高转换速度。使用方法如下: | ||||
| 
 | ||||
|  | @ -166,6 +172,7 @@ python3 gguf2mnn.py --gguf ~/third/llama.cpp/build/ggml-model-Q4_K.gguf --mnn_di | |||
| ``` | ||||
| -DLLM_SUPPORT_VISION=true -DMNN_BUILD_OPENCV=true -DMNN_IMGCODECS=true | ||||
| ``` | ||||
| 
 | ||||
| - 需要开启音频功能时,增加相关编译宏 | ||||
| ``` | ||||
| -DLLM_SUPPORT_AUDIO=true -DMNN_BUILD_AUDIO=true | ||||
|  | @ -195,6 +202,12 @@ cd project/android | |||
| mkdir build_64 | ||||
| ../build_64.sh -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DMNN_ARM82=true -DMNN_OPENCL=true -DMNN_USE_LOGCAT=true | ||||
| ``` | ||||
| 高通设备部分视觉模型支持NPU功能,可增加`MNN_QNN` 和`MNN_WITH_PLUGIN`的宏启用QNN功能。 | ||||
| ``` | ||||
| cd project/android | ||||
| mkdir build_64 | ||||
| ../build_64.sh -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DMNN_ARM82=true -DMNN_OPENCL=true -DMNN_QNN=true -DMNN_WITH_PLUGIN=true -DMNN_USE_LOGCAT=true | ||||
| ``` | ||||
| 
 | ||||
| #### iOS: 参考 transformers/llm/engine/ios/README.md | ||||
| ``` | ||||
|  |  | |||
|  | @ -281,6 +281,17 @@ bool Executor::RuntimeManager::getInfo(Interpreter::SessionInfoCode code, void* | |||
|     return false; | ||||
| } | ||||
| 
 | ||||
| bool Executor::RuntimeManager::getDeviceInfo(const std::string& deviceKey, const MNNForwardType type, std::string& deviceValue) { | ||||
|     auto creator = MNNGetExtraRuntimeCreator(type); | ||||
|     if (creator != nullptr) { | ||||
|         auto res = creator->onGetDeviceInfo(deviceKey, deviceValue); | ||||
|         if(res) { | ||||
|             return true; | ||||
|         } | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| 
 | ||||
| Executor::RuntimeManager::RuntimeManager() { | ||||
|     mInside = new RuntimeAttr; | ||||
|     mInside->mContent.reset(new RuntimeAttr::Immutable); | ||||
|  |  | |||
|  | @ -130,6 +130,7 @@ public: | |||
|         void setHint(Interpreter::HintMode mode, int* value, size_t size); | ||||
|         void setHintPtr(Interpreter::HintMode mode, void* value); | ||||
|         bool getInfo(Interpreter::SessionInfoCode code, void* ptr); | ||||
|         static bool getDeviceInfo(const std::string& deviceKey, const MNNForwardType type, std::string& deviceValue); | ||||
|         BackendConfig* getBnConfig(); | ||||
|         const RuntimeAttr* getInside() const { | ||||
|             return mInside; | ||||
|  |  | |||
|  | @ -25,3 +25,9 @@ set_target_properties( MNNOpenCV | |||
|                 PROPERTIES IMPORTED_LOCATION | ||||
|                 ${CMAKE_CURRENT_LIST_DIR}/libs/${ANDROID_ABI}/libMNNOpenCV.so | ||||
|                 ) | ||||
| 
 | ||||
| add_library(MNN_LLM  SHARED IMPORTED GLOBAL ) | ||||
| set_target_properties(MNN_LLM | ||||
|             PROPERTIES IMPORTED_LOCATION | ||||
|             ${CMAKE_CURRENT_LIST_DIR}/libs/${ANDROID_ABI}/libllm.so | ||||
| ) | ||||
|  | @ -3,4 +3,4 @@ distributionBase=GRADLE_USER_HOME | |||
| distributionPath=wrapper/dists | ||||
| zipStoreBase=GRADLE_USER_HOME | ||||
| zipStorePath=wrapper/dists | ||||
| distributionUrl=http\://mtl-gradle-mirror.oss-cn-hangzhou.aliyuncs.com/gradle-4.6-all.zip | ||||
| distributionUrl=http://mtl-gradle-mirror.oss-cn-hangzhou.aliyuncs.com/gradle-6.7.1-all.zip | ||||
|  | @ -20,23 +20,48 @@ afterEvaluate { | |||
| android.libraryVariants.all {variant -> | ||||
|     def name = variant.buildType.name | ||||
|     def zipTask = tasks.create(name: "zipNative${name.capitalize()}", type: Zip) { | ||||
|         from project.zipTree("${variant.packageLibrary.destinationDir.path}/${project.name}-${name}.aar") | ||||
|         from project.zipTree("${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}.aar") | ||||
|         from "$buildDir/native-packed" | ||||
|         archiveName "${project.name}-${name}-tmp.aar" | ||||
|         archiveName "${artifactName}-${name}-tmp.aar" | ||||
|         destinationDir(variant.packageLibrary.destinationDir) | ||||
|         inputs.dir("$buildDir/native-packed") | ||||
|         inputs.file("${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}.aar") | ||||
|         outputs.file("${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}.aar") | ||||
|     } | ||||
|      | ||||
|     zipTask.dependsOn("externalNativeBuild${name.capitalize()}") | ||||
|      | ||||
|     // Ensure QNN dependencies are prepared when BUILD_QNN is enabled | ||||
|     if (project.ext.has('BUILD_QNN') && project.ext.BUILD_QNN) { | ||||
|         zipTask.dependsOn('prepareQnnDeps') | ||||
|     } | ||||
|      | ||||
|     zipTask.doLast { | ||||
|         copy { | ||||
|             from "${variant.packageLibrary.destinationDir.path}/${project.name}-${name}-tmp.aar" | ||||
|             from "${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}-tmp.aar" | ||||
|             into "${variant.packageLibrary.destinationDir.path}" | ||||
|             rename{ | ||||
|                 String fileName-> | ||||
|                     fileName.replace('-tmp','') | ||||
|             } | ||||
|         } | ||||
|         delete "${variant.packageLibrary.destinationDir.path}/${project.name}-${name}-tmp.aar" | ||||
|         delete "${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}-tmp.aar" | ||||
|          | ||||
|         println "Generated final AAR: ${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}.aar" | ||||
|     } | ||||
|     tasks.findByName("assemble${name.capitalize()}").doLast { | ||||
|         zipTask.execute() | ||||
|      | ||||
|     def bundleTask = tasks.findByName("bundle${name.capitalize()}Aar") | ||||
|     if (bundleTask != null) { | ||||
|         zipTask.dependsOn(bundleTask) | ||||
|     } | ||||
|      | ||||
|     def assembleTask = tasks.findByName("assemble${name.capitalize()}") | ||||
|     if (assembleTask != null) { | ||||
|         assembleTask.dependsOn(zipTask) | ||||
|         zipTask.mustRunAfter(bundleTask) | ||||
|     } | ||||
|     def packageTask = tasks.findByName("package${name.capitalize()}Aar") | ||||
|     if (packageTask != null) { | ||||
|         packageTask.finalizedBy(zipTask) | ||||
|     } | ||||
| } | ||||
|  | @ -0,0 +1,174 @@ | |||
| // QNN Dependencies Preparation | ||||
| // This gradle file handles downloading and preparing QNN dependencies | ||||
| 
 | ||||
| // QNN configuration | ||||
| ext { | ||||
|     // QNN download settings | ||||
|     QNN_LIBS_URL = 'http://meta.alicdn.com/data/mnn/libs/qnn_inc_libs.zip' | ||||
| } | ||||
| 
 | ||||
| def qnnZipName = 'qnn_inc_libs.zip' | ||||
| def qnnZipFile = new File(buildDir, qnnZipName) | ||||
| def qnnTmpDir = new File(buildDir, 'qnn_tmp') | ||||
| 
 | ||||
| task prepareQnnDeps { | ||||
|     group = 'build setup' | ||||
|     description = 'Download and extract QNN include/lib into source/backend/qnn/3rdParty when BUILD_QNN is enabled.' | ||||
|     onlyIf {  | ||||
|         project.ext.has('BUILD_QNN') && project.ext.BUILD_QNN | ||||
|     } | ||||
|      | ||||
|     // Define inputs and outputs for incremental build | ||||
|     inputs.property("qnn_libs_url", { QNN_LIBS_URL }) | ||||
|     inputs.property("build_qnn", { project.ext.has('BUILD_QNN') && project.ext.BUILD_QNN }) | ||||
|     outputs.dir(new File(project.rootDir, '../../source/backend/qnn/3rdParty')) | ||||
|     outputs.dir(new File(buildDir, 'native-packed/native/libs')) | ||||
|      | ||||
|     doLast { | ||||
|         println "BUILD_QNN is ON. Preparing QNN dependencies..." | ||||
|          | ||||
|         def url = QNN_LIBS_URL | ||||
|         println "Downloading QNN dependencies from ${url}" | ||||
|          | ||||
|         if (!qnnZipFile.exists()) { | ||||
|             println "Downloading ${qnnZipName} ..." | ||||
|             qnnZipFile.parentFile.mkdirs() | ||||
|              | ||||
|             try { | ||||
|                 new java.net.URL(url).withInputStream { inputStream -> | ||||
|                     qnnZipFile.withOutputStream { outputStream -> | ||||
|                         outputStream << inputStream | ||||
|                     } | ||||
|                 } | ||||
|                 println "Downloaded to: ${qnnZipFile.absolutePath}" | ||||
|             } catch (Exception e) { | ||||
|                 throw new RuntimeException("Failed to download QNN dependencies from ${url}. Error: ${e.message}") | ||||
|             } | ||||
|         } else { | ||||
|             println "Using cached zip: ${qnnZipFile.absolutePath}" | ||||
|         } | ||||
| 
 | ||||
|         // Clean temp dir and unpack | ||||
|         project.delete(qnnTmpDir) | ||||
|         qnnTmpDir.mkdirs() | ||||
|         copy { | ||||
|             from zipTree(qnnZipFile) | ||||
|             into qnnTmpDir | ||||
|         } | ||||
| 
 | ||||
|         // Find the extracted QNN directory | ||||
|         def extractedQnnDir = findQnnDirectory(qnnTmpDir) | ||||
|         if (extractedQnnDir == null) { | ||||
|             throw new RuntimeException("Failed to find QNN directory structure in ${qnnZipFile.name}") | ||||
|         } | ||||
|          | ||||
|         copyQnnFiles(extractedQnnDir) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| def findQnnDirectory(File searchDir) { | ||||
|     // Look for directory containing both include and jniLibs (or lib) directories | ||||
|     def candidates = [] | ||||
|      | ||||
|     searchDir.eachDirRecurse { dir -> | ||||
|         def hasInclude = new File(dir, 'include').exists() | ||||
|         def hasLibs = new File(dir, 'jniLibs').exists() || new File(dir, 'lib').exists() | ||||
|          | ||||
|         if (hasInclude && hasLibs) { | ||||
|             candidates.add(dir) | ||||
|         } | ||||
|     } | ||||
|      | ||||
|     if (candidates.isEmpty()) { | ||||
|         // Fallback: look for directory that contains include | ||||
|         searchDir.eachDirRecurse { dir -> | ||||
|             if (new File(dir, 'include').exists()) { | ||||
|                 candidates.add(dir) | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|      | ||||
|     return candidates.isEmpty() ? null : candidates[0] | ||||
| } | ||||
| 
 | ||||
| def copyQnnFiles(File sourceDir) { | ||||
|     // Resolve destination directories | ||||
|     def qnnRoot = new File(project.rootDir, '../../source/backend/qnn/3rdParty') | ||||
|     def includeDir = new File(qnnRoot, 'include') | ||||
|     def libDir = new File(qnnRoot, 'lib') | ||||
|      | ||||
|     includeDir.mkdirs() | ||||
|     libDir.mkdirs() | ||||
| 
 | ||||
|     // Copy include files | ||||
|     def sourceInclude = new File(sourceDir, 'include') | ||||
|     if (sourceInclude.exists()) { | ||||
|         copy { | ||||
|             from sourceInclude | ||||
|             into includeDir | ||||
|         } | ||||
|         println "QNN includes copied to: ${includeDir.absolutePath}" | ||||
|     } else { | ||||
|         throw new RuntimeException("Include directory not found in ${sourceDir.absolutePath}") | ||||
|     } | ||||
| 
 | ||||
|     // Copy library files - try both jniLibs and lib directories | ||||
|     def sourceLibs = new File(sourceDir, 'jniLibs') | ||||
|     if (!sourceLibs.exists()) { | ||||
|         sourceLibs = new File(sourceDir, 'lib') | ||||
|     } | ||||
|      | ||||
|     if (sourceLibs.exists()) { | ||||
|         copy { | ||||
|             from sourceLibs | ||||
|             into libDir | ||||
|         } | ||||
|         println "QNN libs copied to: ${libDir.absolutePath}" | ||||
|          | ||||
|         // Also copy QNN .so files to native-packed for AAR packaging | ||||
|         copyQnnLibsToNativePacked(sourceLibs) | ||||
|     } else { | ||||
|         println "Warning: No lib/jniLibs directory found in ${sourceDir.absolutePath}" | ||||
|     } | ||||
|      | ||||
|     println "QNN dependencies preparation completed successfully." | ||||
| } | ||||
| 
 | ||||
| def copyQnnLibsToNativePacked(File sourceLibsDir) { | ||||
|     // Create native-packed directory structure for QNN .so files | ||||
|     def nativePackedLibsDir = new File(buildDir, 'native-packed/native/libs') | ||||
|     nativePackedLibsDir.mkdirs() | ||||
|      | ||||
|     println "Copying QNN .so files to native-packed for AAR packaging..." | ||||
|      | ||||
|     // Copy all .so files from QNN libs to native-packed | ||||
|     sourceLibsDir.eachFileRecurse { file -> | ||||
|         if (file.name.endsWith('.so')) { | ||||
|             // Determine the ABI directory (arm64-v8a, armeabi-v7a, etc.) | ||||
|             def relativePath = sourceLibsDir.toPath().relativize(file.toPath()) | ||||
|             def targetFile = new File(nativePackedLibsDir, relativePath.toString()) | ||||
|              | ||||
|             // Create parent directories if they don't exist | ||||
|             targetFile.parentFile.mkdirs() | ||||
|              | ||||
|             // Copy the .so file | ||||
|             copy { | ||||
|                 from file | ||||
|                 into targetFile.parentFile | ||||
|             } | ||||
|              | ||||
|             println "Copied QNN lib: ${file.name} -> ${targetFile.absolutePath}" | ||||
|         } | ||||
|     } | ||||
|      | ||||
|     println "QNN .so files copied to native-packed directory" | ||||
| } | ||||
| 
 | ||||
| // Ensure preparation runs before compilation when enabled | ||||
| if (project.ext.has('BUILD_QNN') && project.ext.BUILD_QNN) { | ||||
|     afterEvaluate { | ||||
|         if (tasks.findByName('preBuild')) { | ||||
|             preBuild.dependsOn prepareQnnDeps | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | @ -480,7 +480,6 @@ | |||
| 		92FF02E023AA0B5A00AC97F6 /* MNNMatrixSub.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016023AA0B4E00AC97F6 /* MNNMatrixSub.S */; }; | ||||
| 		92FF02E123AA0B5A00AC97F6 /* MNNPowC8.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016123AA0B4E00AC97F6 /* MNNPowC8.S */; }; | ||||
| 		92FF02E223AA0B5A00AC97F6 /* MNNMatrixAdd.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016223AA0B4E00AC97F6 /* MNNMatrixAdd.S */; }; | ||||
| 		92FF02E323AA0B5A00AC97F6 /* MNNExpC8.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016323AA0B4E00AC97F6 /* MNNExpC8.S */; }; | ||||
| 		92FF02E523AA0B5A00AC97F6 /* MNNConvDwF23SourceTransUnit.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016523AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */; }; | ||||
| 		92FF02E723AA0B5A00AC97F6 /* MNNDeconvRunForUnitDepthWise.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016723AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */; }; | ||||
| 		92FF02E823AA0B5A00AC97F6 /* MNNSamplerC1BilinearOpt.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016823AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */; }; | ||||
|  | @ -520,7 +519,6 @@ | |||
| 		92FF032023AA0B5A00AC97F6 /* MNNMatrixSub.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A123AA0B4E00AC97F6 /* MNNMatrixSub.S */; }; | ||||
| 		92FF032123AA0B5A00AC97F6 /* MNNPowC8.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A223AA0B4E00AC97F6 /* MNNPowC8.S */; }; | ||||
| 		92FF032223AA0B5A00AC97F6 /* MNNMatrixAdd.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A323AA0B4E00AC97F6 /* MNNMatrixAdd.S */; }; | ||||
| 		92FF032323AA0B5A00AC97F6 /* MNNExpC8.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A423AA0B4E00AC97F6 /* MNNExpC8.S */; }; | ||||
| 		92FF032523AA0B5A00AC97F6 /* MNNConvDwF23SourceTransUnit.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A623AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */; }; | ||||
| 		92FF032723AA0B5A00AC97F6 /* MNNDeconvRunForUnitDepthWise.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A823AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */; }; | ||||
| 		92FF032823AA0B5A00AC97F6 /* MNNSamplerC1BilinearOpt.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A923AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */; }; | ||||
|  | @ -1315,7 +1313,6 @@ | |||
| 		92FF016023AA0B4E00AC97F6 /* MNNMatrixSub.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNMatrixSub.S; sourceTree = "<group>"; }; | ||||
| 		92FF016123AA0B4E00AC97F6 /* MNNPowC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPowC8.S; sourceTree = "<group>"; }; | ||||
| 		92FF016223AA0B4E00AC97F6 /* MNNMatrixAdd.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNMatrixAdd.S; sourceTree = "<group>"; }; | ||||
| 		92FF016323AA0B4E00AC97F6 /* MNNExpC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNExpC8.S; sourceTree = "<group>"; }; | ||||
| 		92FF016523AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNConvDwF23SourceTransUnit.S; sourceTree = "<group>"; }; | ||||
| 		92FF016723AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNDeconvRunForUnitDepthWise.S; sourceTree = "<group>"; }; | ||||
| 		92FF016823AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNSamplerC1BilinearOpt.S; sourceTree = "<group>"; }; | ||||
|  | @ -1355,7 +1352,6 @@ | |||
| 		92FF01A123AA0B4E00AC97F6 /* MNNMatrixSub.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNMatrixSub.S; sourceTree = "<group>"; }; | ||||
| 		92FF01A223AA0B4E00AC97F6 /* MNNPowC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPowC8.S; sourceTree = "<group>"; }; | ||||
| 		92FF01A323AA0B4E00AC97F6 /* MNNMatrixAdd.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNMatrixAdd.S; sourceTree = "<group>"; }; | ||||
| 		92FF01A423AA0B4E00AC97F6 /* MNNExpC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNExpC8.S; sourceTree = "<group>"; }; | ||||
| 		92FF01A623AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNConvDwF23SourceTransUnit.S; sourceTree = "<group>"; }; | ||||
| 		92FF01A823AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNDeconvRunForUnitDepthWise.S; sourceTree = "<group>"; }; | ||||
| 		92FF01A923AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNSamplerC1BilinearOpt.S; sourceTree = "<group>"; }; | ||||
|  | @ -2591,7 +2587,6 @@ | |||
| 				92FF016023AA0B4E00AC97F6 /* MNNMatrixSub.S */, | ||||
| 				92FF016123AA0B4E00AC97F6 /* MNNPowC8.S */, | ||||
| 				92FF016223AA0B4E00AC97F6 /* MNNMatrixAdd.S */, | ||||
| 				92FF016323AA0B4E00AC97F6 /* MNNExpC8.S */, | ||||
| 				92FF016523AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */, | ||||
| 				92FF016723AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */, | ||||
| 				92FF016823AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */, | ||||
|  | @ -2677,7 +2672,6 @@ | |||
| 				92FF01A123AA0B4E00AC97F6 /* MNNMatrixSub.S */, | ||||
| 				92FF01A223AA0B4E00AC97F6 /* MNNPowC8.S */, | ||||
| 				92FF01A323AA0B4E00AC97F6 /* MNNMatrixAdd.S */, | ||||
| 				92FF01A423AA0B4E00AC97F6 /* MNNExpC8.S */, | ||||
| 				92FF01A623AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */, | ||||
| 				92FF01A823AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */, | ||||
| 				92FF01A923AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */, | ||||
|  | @ -3279,7 +3273,6 @@ | |||
| 				92FF02C223AA0B5A00AC97F6 /* MNNLoadU8AndSum.S in Sources */, | ||||
| 				4819FB2E24C1396A0050BD09 /* GeometryLSTM.cpp in Sources */, | ||||
| 				9558334B29B09A7B00488807 /* MNNGeluFP16.S in Sources */, | ||||
| 				92FF02E323AA0B5A00AC97F6 /* MNNExpC8.S in Sources */, | ||||
| 				92FF030223AA0B5A00AC97F6 /* MNNQuanToDestUint8.S in Sources */, | ||||
| 				489D7AA12550FDC900AD896A /* MetalUnary.mm in Sources */, | ||||
| 				92FF037323AA0B5A00AC97F6 /* CPUEltwiseInt8.cpp in Sources */, | ||||
|  | @ -3438,7 +3431,6 @@ | |||
| 				92FF030723AA0B5A00AC97F6 /* MNNBlitC1ToFloatRGBA.S in Sources */, | ||||
| 				92FF03A423AA0B5A00AC97F6 /* OptimizedComputer.cpp in Sources */, | ||||
| 				92FF032E23AA0B5A00AC97F6 /* MNNReluWithSlopeChannel.S in Sources */, | ||||
| 				92FF032323AA0B5A00AC97F6 /* MNNExpC8.S in Sources */, | ||||
| 				95278CEA2B9F09C0009E9B29 /* ShapeDynamicQuant.cpp in Sources */, | ||||
| 				92FF044C23AA0B7100AC97F6 /* ShapePool3D.cpp in Sources */, | ||||
| 				92FF029823AA0B5A00AC97F6 /* CPUTFQuantizedConv2D.cpp in Sources */, | ||||
|  |  | |||
|  | @ -17,6 +17,7 @@ option(PYMNN_INTERNAL_SERVING "Internal use only." OFF) | |||
| option(PYMNN_OPENCV_API "MNN OpenCV API be exposed" ON) | ||||
| option(PYMNN_IMGCODECS "MNN IMGCODECS API be exposed" OFF) | ||||
| option(PYMNN_AUDIO_API "MNN Audio API be exposed" OFF) | ||||
| option(PYMNN_LLM_API "MNN LLM API be exposed" OFF) | ||||
| option(PYMNN_OHOS_INTERNAL "compile for harmony internal." OFF) | ||||
| option(PYMNN_LLM_COLLECTION "offline llm data collection." OFF) | ||||
| 
 | ||||
|  | @ -97,6 +98,10 @@ if(PYMNN_AUDIO_API) | |||
|     target_compile_definitions(mnnpybridge PRIVATE PYMNN_AUDIO_API) | ||||
| endif() | ||||
| 
 | ||||
| if(PYMNN_LLM_API) | ||||
|     target_compile_definitions(mnnpybridge PRIVATE PYMNN_LLM_API) | ||||
| endif() | ||||
| 
 | ||||
| if(PYMNN_INTERNAL_SERVING) | ||||
|     message(STATUS "mnnpybridge define PYMNN_INTERNAL_SERVING") | ||||
|     target_compile_definitions(mnnpybridge PRIVATE PYMNN_INTERNAL_SERVING) | ||||
|  | @ -214,6 +219,9 @@ else() | |||
|         if(PYMNN_NUMPY_USABLE) | ||||
|             target_link_libraries(mnnpybridge PRIVATE numpy_python) | ||||
|         endif() | ||||
|         if(PYMNN_LLM_API) | ||||
|             target_link_libraries(mnnpybridge PRIVATE MNN_LLM) | ||||
|         endif() | ||||
|     endif() | ||||
| endif() | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										112
									
								
								pymnn/src/llm.h
								
								
								
								
							
							
						
						
									
										112
									
								
								pymnn/src/llm.h
								
								
								
								
							|  | @ -1,6 +1,10 @@ | |||
| #include <sstream> | ||||
| #include <iostream> | ||||
| #ifdef BUILD_FOR_IOS | ||||
| #include "MNN/llm/llm.hpp" | ||||
| #else | ||||
| #include "llm/llm.hpp" | ||||
| 
 | ||||
| #endif | ||||
| #ifdef PYMNN_LLM_COLLECTION | ||||
| #include "cpp/getLinearInput.hpp" | ||||
| #endif | ||||
|  | @ -85,18 +89,87 @@ static PyObject* PyMNNLLM_getCurrentHistory(LLM *self, PyObject *args) { | |||
| } | ||||
| static PyObject* PyMNNLLM_response(LLM *self, PyObject *args) { | ||||
|     if (self->is_embedding) { | ||||
|         MNN_PRINT("[MNNLLM] response: is_embedding\n"); | ||||
|         Py_RETURN_NONE; | ||||
|     } | ||||
|     const char* query = NULL; | ||||
|     int stream = 0; | ||||
|     if (!PyArg_ParseTuple(args, "s|p", &query, &stream)) { | ||||
|         Py_RETURN_NONE; | ||||
|     } | ||||
|     std::ostringstream null_os; | ||||
|     self->llm->response(query, stream ? &std::cout : &null_os); | ||||
|     return string2Object(null_os.str()); | ||||
| } | ||||
|      | ||||
|     PyObject* content = nullptr; | ||||
|     int stream = 0; | ||||
|     int max_new_tokens = 2048; | ||||
|      | ||||
|     if (!PyArg_ParseTuple(args, "O|ii", &content, &stream, &max_new_tokens)) { | ||||
|         MNN_PRINT("[MNNLLM] response: PyArg_ParseTuple failed\n"); | ||||
|         Py_RETURN_NONE; | ||||
|     } | ||||
|      | ||||
|     std::ostringstream null_os; | ||||
|     std::ostream* output_stream = stream ? &std::cout : &null_os; | ||||
|      | ||||
|     if (isString(content)) { | ||||
|         std::string text = object2String(content); | ||||
|         MNN_PRINT("[MNNLLM] response: text=%s, stream=%d, max_new_tokens=%d\n", text.c_str(), stream, max_new_tokens); | ||||
|         self->llm->response(text, output_stream, nullptr, max_new_tokens); | ||||
|     } else if (isPyDict(content)) { | ||||
|         MNN::Transformer::MultimodalPrompt multimodal_input; | ||||
|         PyObject* text_obj = PyDict_GetItemString(content, "text"); | ||||
|         if (text_obj && isString(text_obj)) { | ||||
|             multimodal_input.prompt_template = object2String(text_obj); | ||||
|         } | ||||
|         PyObject* images_obj = PyDict_GetItemString(content, "images"); | ||||
|         if (images_obj && PyList_Check(images_obj)) { | ||||
|             Py_ssize_t img_count = PyList_Size(images_obj); | ||||
|             for (Py_ssize_t i = 0; i < img_count; i++) { | ||||
|                 PyObject* img_dict = PyList_GetItem(images_obj, i); | ||||
|                 if (isPyDict(img_dict)) { | ||||
|                     PyObject* data_obj = PyDict_GetItemString(img_dict, "data"); | ||||
|                     PyObject* width_obj = PyDict_GetItemString(img_dict, "width"); | ||||
|                     PyObject* height_obj = PyDict_GetItemString(img_dict, "height"); | ||||
|                      | ||||
|                     if (data_obj && width_obj && height_obj) { | ||||
|                         MNN::Transformer::PromptImagePart image_part; | ||||
|                         image_part.image_data = toVar(data_obj); | ||||
|                         image_part.width = PyLong_AsLong(width_obj); | ||||
|                         image_part.height = PyLong_AsLong(height_obj); | ||||
|                          | ||||
|                         std::string key = "image_" + std::to_string(i); | ||||
|                         multimodal_input.images[key] = image_part; | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|          | ||||
|         PyObject* audios_obj = PyDict_GetItemString(content, "audios"); | ||||
|         if (audios_obj && PyList_Check(audios_obj)) { | ||||
|             Py_ssize_t audio_count = PyList_Size(audios_obj); | ||||
|             for (Py_ssize_t i = 0; i < audio_count; i++) { | ||||
|                 PyObject* audio_dict = PyList_GetItem(audios_obj, i); | ||||
|                 if (isPyDict(audio_dict)) { | ||||
|                     MNN::Transformer::PromptAudioPart audio_part; | ||||
|                      | ||||
|                     PyObject* file_path_obj = PyDict_GetItemString(audio_dict, "file_path"); | ||||
|                     if (file_path_obj && isString(file_path_obj)) { | ||||
|                         audio_part.file_path = object2String(file_path_obj); | ||||
|                     } | ||||
|                      | ||||
|                     PyObject* waveform_obj = PyDict_GetItemString(audio_dict, "waveform"); | ||||
|                     if (waveform_obj) { | ||||
|                         audio_part.waveform = toVar(waveform_obj); | ||||
|                     } | ||||
|                      | ||||
|                     std::string key = "audio_" + std::to_string(i); | ||||
|                     multimodal_input.audios[key] = audio_part; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         MNN_PRINT("[MNNLLM] response: multimodal, stream=%d, max_new_tokens=%d\n", stream, max_new_tokens); | ||||
|         self->llm->response(multimodal_input, output_stream, nullptr, max_new_tokens); | ||||
|     } else { | ||||
|         PyMNN_ERROR("content must be str or dict"); | ||||
|     } | ||||
|     std::string response_str = null_os.str(); | ||||
|     MNN_PRINT("[MNNLLM] response: %s\n", response_str.c_str()); | ||||
|     return string2Object(response_str); | ||||
| } | ||||
| static PyObject* PyMNNLLM_tokenizer_encode(LLM *self, PyObject *args) { | ||||
|     if (self->is_embedding) { | ||||
|         Py_RETURN_NONE; | ||||
|  | @ -149,6 +222,14 @@ static PyObject* PyMNNLLM_reset(LLM *self, PyObject *args) { | |||
|     Py_RETURN_NONE; | ||||
| } | ||||
| 
 | ||||
| static PyObject* PyMNNLLM_get_statistics(LLM *self, PyObject *args) { | ||||
|     if (self->is_embedding) { | ||||
|         Py_RETURN_NONE; | ||||
|     } | ||||
|     auto statistics = self->llm->get_statistics(); | ||||
|     return string2Object(statistics); | ||||
| } | ||||
| 
 | ||||
| #ifdef PYMNN_LLM_COLLECTION | ||||
| static PyObject* PyMNNLLM_enable_collection_mode(LLM *self, PyObject *args) { | ||||
|     if (self->is_embedding) { | ||||
|  | @ -205,12 +286,11 @@ static PyObject* PyMNNLLM_enable_collection_mode(LLM *self, PyObject *args) { | |||
|     return toPyObj(true);   | ||||
| } | ||||
| #endif | ||||
| 
 | ||||
| static PyMethodDef PyMNNLLM_methods[] = { | ||||
|     {"load", (PyCFunction)PyMNNLLM_load, METH_VARARGS, "load model."}, | ||||
|     {"forward", (PyCFunction)PyMNNLLM_forward, METH_VARARGS, "forward `logits` by `input_ids`."}, | ||||
|     {"generate", (PyCFunction)PyMNNLLM_generate, METH_VARARGS, "generate `output_ids` by `input_ids`."}, | ||||
|     {"response", (PyCFunction)PyMNNLLM_response, METH_VARARGS, "response `query` without hsitory."}, | ||||
|     {"response", (PyCFunction)PyMNNLLM_response, METH_VARARGS, "response `query` - supports both text and multimodal input."}, | ||||
|     {"get_current_history", (PyCFunction)PyMNNLLM_getCurrentHistory, METH_VARARGS, "Get Current History."}, | ||||
|     {"erase_history", (PyCFunction)PyMNNLLM_eraseHistory, METH_VARARGS, "Erase History."}, | ||||
|     {"tokenizer_encode", (PyCFunction)PyMNNLLM_tokenizer_encode, METH_VARARGS, "tokenizer encode."}, | ||||
|  | @ -219,6 +299,7 @@ static PyMethodDef PyMNNLLM_methods[] = { | |||
|     {"create_lora", (PyCFunction)PyMNNLLM_create_lora, METH_VARARGS, "create_lora."}, | ||||
|     {"set_config", (PyCFunction)PyMNNLLM_set_config, METH_VARARGS, "set_config."}, | ||||
|     {"reset", (PyCFunction)PyMNNLLM_reset, METH_VARARGS, "reset."}, | ||||
|     {"get_statistics", (PyCFunction)PyMNNLLM_get_statistics, METH_VARARGS, "get performance statistics."}, | ||||
| #ifdef PYMNN_LLM_COLLECTION | ||||
|     {"enable_collection_mode", (PyCFunction)PyMNNLLM_enable_collection_mode, METH_VARARGS, "Enable data collection mode."}, | ||||
| #endif | ||||
|  | @ -274,7 +355,7 @@ static PyObject* PyMNNLLM_create_lora(LLM *self, PyObject *args) { | |||
|         Py_RETURN_NONE; | ||||
|     } | ||||
|     auto lora = self->llm->create_lora(path); | ||||
|     LLM *llm = (LLM *)PyObject_Call((PyObject*)&PyMNNLLM, PyTuple_New(0), NULL); | ||||
|     LLM *llm = (LLM *)PyObject_Call((PyObject*)PyType_FindTLSType(&PyMNNLLM), PyTuple_New(0), NULL); | ||||
|     if (!llm) { | ||||
|         return NULL; | ||||
|     } | ||||
|  | @ -288,10 +369,11 @@ static PyObject* PyMNNLLM_create(PyObject *self, PyObject *args) { | |||
|     } | ||||
|     const char* path = NULL; | ||||
|     int embedding_model = 0; | ||||
|     if (!PyArg_ParseTuple(args, "s|p", &path, &embedding_model)) { | ||||
|     if (!PyArg_ParseTuple(args, "s|i", &path, &embedding_model)) { | ||||
|         PyMNN_ERROR_LOG("Invalid arguments. Usage: create(path, embedding_model=False)"); | ||||
|         return NULL; | ||||
|     } | ||||
|     LLM *llm = (LLM *)PyObject_Call((PyObject*)&PyMNNLLM, PyTuple_New(0), NULL); | ||||
|     LLM *llm = (LLM *)PyObject_Call((PyObject*)PyType_FindTLSType(&PyMNNLLM), PyTuple_New(0), NULL); | ||||
|     if (!llm) { | ||||
|         return NULL; | ||||
|     } | ||||
|  |  | |||
|  | @ -398,6 +398,9 @@ static inline bool isPySequence(PyObject* obj) { | |||
|     // use isPySequence replace PySequence_Check
 | ||||
|     return PyTuple_Check(obj) || PyList_Check(obj) || PyBytes_Check(obj); | ||||
| } | ||||
| static inline bool isPyDict(PyObject* obj) { | ||||
|     return PyDict_Check(obj); | ||||
| } | ||||
| static inline int PySequenceSize(PyObject* obj) { | ||||
|     if (PyTuple_Check(obj)) return PyTuple_Size(obj); | ||||
|     if (PyList_Check(obj)) return PyList_Size(obj); | ||||
|  |  | |||
|  | @ -218,7 +218,7 @@ static void Sme2MNNPackForMatMul_B(float* destC, const float* sourceC, size_t h, | |||
| } | ||||
| 
 | ||||
| static void MNNScaleAndAddBiasFP16(FLOAT16* dst, const FLOAT16* src, const FLOAT16* bias, const FLOAT16* alpha, size_t planeNumber, | ||||
|                         size_t biasNumber) { | ||||
|                                    size_t biasNumber) { | ||||
|     for (int z = 0; z < biasNumber; ++z) { | ||||
|         FLOAT16* dstZ         = dst + planeNumber * 8 * z; | ||||
|         const FLOAT16* srcZ   = src + planeNumber * 8 * z; | ||||
|  | @ -412,7 +412,7 @@ static void MNNAxByClampBroadcastC8FP16(float* CF, const float* AF, const float* | |||
| } | ||||
| 
 | ||||
| void ARM82StrassenMerge(FLOAT16* c11, FLOAT16* c12, FLOAT16* c21, FLOAT16* c22, FLOAT16* xAddr, | ||||
|                           size_t cStride, size_t eSub, size_t hSub) { | ||||
|                         size_t cStride, size_t eSub, size_t hSub) { | ||||
|     const int pack = 8; | ||||
|     for (int y = 0; y < hSub; ++y) { | ||||
|         auto c11Y = c11 + y * cStride; | ||||
|  | @ -560,7 +560,7 @@ void MNNPackTransposeInt16C8(int16_t* dst, const int16_t* src, size_t area, size | |||
| } | ||||
| 
 | ||||
| static void _MNNDeconvRunForUnitDepthWise(const FLOAT16* dst, FLOAT16* src, const FLOAT16* weight, size_t fw, size_t fh, | ||||
|                                   size_t weight_y_step, size_t dilateX_step, size_t dilateY_step) { | ||||
|                                           size_t weight_y_step, size_t dilateX_step, size_t dilateY_step) { | ||||
|     int fx, fy; | ||||
|     auto src_z          = src; | ||||
|     auto weight_z = weight; | ||||
|  | @ -576,7 +576,7 @@ static void _MNNDeconvRunForUnitDepthWise(const FLOAT16* dst, FLOAT16* src, cons | |||
|     } | ||||
| } | ||||
| static void _MNNDeconvRunForLineDepthwise(const FLOAT16* dst, FLOAT16* src, const FLOAT16* weight, size_t width, size_t src_w_setup, | ||||
|                                   size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step) { | ||||
|                                           size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step) { | ||||
|     int dx; | ||||
|     for (dx = 0; dx < width; ++dx) { | ||||
|         auto dst_x = dst + dx * 8; | ||||
|  | @ -1423,6 +1423,730 @@ static void Sme2MNNPackForMatMul_A_FP16(float* destOrigin, float const** sourceG | |||
|     } | ||||
| } | ||||
| 
 | ||||
| #ifdef MNN_SUPPORT_TRANSFORMER_FUSE | ||||
// Pack one attention head into the (eP, lP)-tiled matmul-A layout while applying
// the softmax scale factor in the same pass.
//
// src layout: seqLen rows of headDim fp16 values, consecutive rows srcRowStride
//             elements apart.
// dst layout: [sOuter][dOuter][eP][lP] fp16 where sOuter = s / eP and
//             dOuter = d / lP, i.e. the packed format expected by the fp16
//             matmul kernels; units = {eP, lP}.
// NOTE(review): the pointers are declared float* but are reinterpreted as
// FLOAT16 below — callers must pass fp16 buffers. scale points at a single
// scalar broadcast over the whole head.
static void MNNAttenPackAndScaleSingleHead(float* dst, const float* srcHeadBase, size_t srcRowStride, const float* scale, const int32_t* units, size_t seqLen, size_t headDim) {
    const int32_t eP = units[0]; // rows per packed tile
    const int32_t lP = units[1]; // inner pack width along the depth (headDim) axis

    // Only lP = 1 (no depth packing) and lP = 2 (pair packing) are implemented.
    if (lP != 1 && lP != 2) {
        MNN_ERROR("This function only supports lP=1 or 2\n");
        return;
    }

    const float scaleVal = scale[0];
    const float16x8_t vScale = vdupq_n_f16(scaleVal);

    const size_t packedHeadDim = UP_DIV(headDim, lP);
    const size_t dstStrideDOuter = (size_t)eP * lP;                 // stride between depth groups
    const size_t dstStrideSOuter = packedHeadDim * dstStrideDOuter; // stride between row tiles

    for (int s = 0; s < seqLen; ++s) {
        const int sOuter = s / eP; // which eP-tile this row falls into
        const int sInner = s % eP; // row position inside the tile
        const FLOAT16* srcRowPtr = (FLOAT16*)srcHeadBase + s * srcRowStride;
        FLOAT16* dstBasePtr = (FLOAT16*)dst + sOuter * dstStrideSOuter + sInner * lP;

        if (lP == 1) {
            // Scale 8 values at once, then scatter them one per depth group.
            size_t d = 0;
            for (; d + 7 < headDim; d += 8) {
                float16x8_t sVec = vld1q_f16(srcRowPtr + d);
                sVec = vmulq_f16(sVec, vScale);

                dstBasePtr[(d + 0) * dstStrideDOuter] = sVec[0];
                dstBasePtr[(d + 1) * dstStrideDOuter] = sVec[1];
                dstBasePtr[(d + 2) * dstStrideDOuter] = sVec[2];
                dstBasePtr[(d + 3) * dstStrideDOuter] = sVec[3];
                dstBasePtr[(d + 4) * dstStrideDOuter] = sVec[4];
                dstBasePtr[(d + 5) * dstStrideDOuter] = sVec[5];
                dstBasePtr[(d + 6) * dstStrideDOuter] = sVec[6];
                dstBasePtr[(d + 7) * dstStrideDOuter] = sVec[7];
            }
            // Scalar tail for headDim not a multiple of 8.
            for (; d < headDim; ++d) {
                dstBasePtr[d * dstStrideDOuter] = srcRowPtr[d] * scaleVal;
            }
        } else { // lP == 2
            const FLOAT16* srcDPtr = srcRowPtr;
            FLOAT16* dstDPtr = dstBasePtr;
            size_t dRealSize = headDim;

            // Main loop: scale 16 fp16 values and scatter them as 8 adjacent
            // pairs, one pair per depth group. Each pair (two fp16 = 32 bits)
            // is moved with a single uint32 store.
            while (dRealSize >= 16) {
                float16x8_t s0 = vld1q_f16(srcDPtr);
                float16x8_t s1 = vld1q_f16(srcDPtr + 8);
                s0 = vmulq_f16(s0, vScale);
                s1 = vmulq_f16(s1, vScale);

                float16x4_t lowS0_f16 = vget_low_f16(s0);   // {s0, s1, s2, s3}
                float16x4_t highS0_f16 = vget_high_f16(s0); // {s4, s5, s6, s7}
                uint32x2_t lowS0_u32 = vreinterpret_u32_f16(lowS0_f16);
                uint32x2_t highS0_u32 = vreinterpret_u32_f16(highS0_f16);

                *((uint32_t*)(dstDPtr + 0 * dstStrideDOuter)) = vget_lane_u32(lowS0_u32, 0); // Store pair {s0, s1}
                *((uint32_t*)(dstDPtr + 1 * dstStrideDOuter)) = vget_lane_u32(lowS0_u32, 1); // Store pair {s2, s3}
                *((uint32_t*)(dstDPtr + 2 * dstStrideDOuter)) = vget_lane_u32(highS0_u32, 0); // Store pair {s4, s5}
                *((uint32_t*)(dstDPtr + 3 * dstStrideDOuter)) = vget_lane_u32(highS0_u32, 1); // Store pair {s6, s7}

                float16x4_t lowS1_f16 = vget_low_f16(s1);   // {s8, s9, s10, s11}
                float16x4_t highS1_f16 = vget_high_f16(s1); // {s12, s13, s14, s15}
                uint32x2_t lowS1_u32 = vreinterpret_u32_f16(lowS1_f16);
                uint32x2_t highS1_u32 = vreinterpret_u32_f16(highS1_f16);

                *((uint32_t*)(dstDPtr + 4 * dstStrideDOuter)) = vget_lane_u32(lowS1_u32, 0);
                *((uint32_t*)(dstDPtr + 5 * dstStrideDOuter)) = vget_lane_u32(lowS1_u32, 1);
                *((uint32_t*)(dstDPtr + 6 * dstStrideDOuter)) = vget_lane_u32(highS1_u32, 0);
                *((uint32_t*)(dstDPtr + 7 * dstStrideDOuter)) = vget_lane_u32(highS1_u32, 1);

                dRealSize -= 16;
                srcDPtr += 16;
                dstDPtr += 8 * dstStrideDOuter;
            }
            // Remainder loop: handle leftover depth values pair by pair,
            // zero-padding the final slot when headDim is odd.
            while (dRealSize > 0) {
                if (dRealSize >= 2) {
                    dstDPtr[0] = srcDPtr[0] * scaleVal;
                    dstDPtr[1] = srcDPtr[1] * scaleVal;

                    dRealSize -= 2;
                    srcDPtr += 2;
                    dstDPtr += dstStrideDOuter;
                } else { // dRealSize == 1
                    dstDPtr[0] = srcDPtr[0] * scaleVal;
                    dstDPtr[1] = (FLOAT16)0.0f; // Pad with zero
                    dRealSize = 0;
                }
            }
        }
    }
}
| 
 | ||||
// Online accumulation of one KV-block's partial attention output (flash-attention
// style running update), operating on fp16 data with fp32 accumulation.
//
// dst/src layout: depthQuad blocks of (plane, pack) fp16 values; assumes
// pack == 8 (one float16x8 vector per plane element) — TODO confirm callers
// never pass another pack size.
// scale[i]          : fp32 rescale factor applied to the running accumulator for row i.
// normalizeScale[i] : fp32 final softmax denominator for row i.
// Behavior by block index:
//   idx == 0          : dst = src (first block simply initializes the accumulator;
//                       `size`/`bytes` describe the raw buffer for this memcpy)
//   idx > 0           : dst = src + dst * scale, computed in fp32, stored back as fp16
//   idx == kvBlocks-1 : additionally dst *= 1 / normalizeScale (final normalization)
static void MNNFlashAttentionUpdateBlockOutput( float* dst, float* src, const float* scale, const float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes) {
    auto dstPtr = (float16_t*)dst;
    auto srcPtr = (float16_t*)src;
    const auto stride0 = plane * pack;

    if (idx == 0) {
        memcpy(dst, src, size * bytes);
    } else {
        for (int j = 0; j < depthQuad; ++j) {
            const auto baseOffset = j * stride0;
            int i = 0;
            // Main loop unrolled 4x over the plane dimension.
            const int plane4 = plane - (plane % 4);
            for (; i < plane4; i += 4) {

                auto pdst0 = dstPtr + baseOffset + (i + 0) * pack;
                auto psrc0 = srcPtr + baseOffset + (i + 0) * pack;
                auto pdst1 = dstPtr + baseOffset + (i + 1) * pack;
                auto psrc1 = srcPtr + baseOffset + (i + 1) * pack;
                auto pdst2 = dstPtr + baseOffset + (i + 2) * pack;
                auto psrc2 = srcPtr + baseOffset + (i + 2) * pack;
                auto pdst3 = dstPtr + baseOffset + (i + 3) * pack;
                auto psrc3 = srcPtr + baseOffset + (i + 3) * pack;

                float16x8_t src0 = vld1q_f16(psrc0);
                float16x8_t dst0 = vld1q_f16(pdst0);
                float16x8_t src1 = vld1q_f16(psrc1);
                float16x8_t dst1 = vld1q_f16(pdst1);
                float16x8_t src2 = vld1q_f16(psrc2);
                float16x8_t dst2 = vld1q_f16(pdst2);
                float16x8_t src3 = vld1q_f16(psrc3);
                float16x8_t dst3 = vld1q_f16(pdst3);

                float32x4_t svec0 = vdupq_n_f32(scale[i + 0]);
                float32x4_t svec1 = vdupq_n_f32(scale[i + 1]);
                float32x4_t svec2 = vdupq_n_f32(scale[i + 2]);
                float32x4_t svec3 = vdupq_n_f32(scale[i + 3]);

                // res = src + dst * scale, widened to fp32 for the multiply-add
                // to avoid fp16 accumulation error.
                float32x4_t res00 = vfmaq_f32(vcvt_f32_f16(vget_low_f16(src0)),  vcvt_f32_f16(vget_low_f16(dst0)),  svec0);
                float32x4_t res10 = vfmaq_f32(vcvt_f32_f16(vget_high_f16(src0)), vcvt_f32_f16(vget_high_f16(dst0)), svec0);

                float32x4_t res01 = vfmaq_f32(vcvt_f32_f16(vget_low_f16(src1)),  vcvt_f32_f16(vget_low_f16(dst1)),  svec1);
                float32x4_t res11 = vfmaq_f32(vcvt_f32_f16(vget_high_f16(src1)), vcvt_f32_f16(vget_high_f16(dst1)), svec1);

                float32x4_t res02 = vfmaq_f32(vcvt_f32_f16(vget_low_f16(src2)),  vcvt_f32_f16(vget_low_f16(dst2)),  svec2);
                float32x4_t res12 = vfmaq_f32(vcvt_f32_f16(vget_high_f16(src2)), vcvt_f32_f16(vget_high_f16(dst2)), svec2);

                float32x4_t res03 = vfmaq_f32(vcvt_f32_f16(vget_low_f16(src3)),  vcvt_f32_f16(vget_low_f16(dst3)),  svec3);
                float32x4_t res13 = vfmaq_f32(vcvt_f32_f16(vget_high_f16(src3)), vcvt_f32_f16(vget_high_f16(dst3)), svec3);

                vst1q_f16(pdst0, vcombine_f16(vcvt_f16_f32(res00), vcvt_f16_f32(res10)));
                vst1q_f16(pdst1, vcombine_f16(vcvt_f16_f32(res01), vcvt_f16_f32(res11)));
                vst1q_f16(pdst2, vcombine_f16(vcvt_f16_f32(res02), vcvt_f16_f32(res12)));
                vst1q_f16(pdst3, vcombine_f16(vcvt_f16_f32(res03), vcvt_f16_f32(res13)));
            }

            // Tail: remaining plane rows, one vector at a time.
            for (; i < plane; ++i) {
                auto pdst = dstPtr + baseOffset + i * pack;
                auto psrc = srcPtr + baseOffset + i * pack;

                float16x8_t srcF16 = vld1q_f16(psrc);
                float16x8_t dstF16 = vld1q_f16(pdst);
                float32x4_t svec = vdupq_n_f32(scale[i]);

                float32x4_t s0 = vcvt_f32_f16(vget_low_f16(srcF16));
                float32x4_t s1 = vcvt_f32_f16(vget_high_f16(srcF16));
                float32x4_t d0 = vcvt_f32_f16(vget_low_f16(dstF16));
                float32x4_t d1 = vcvt_f32_f16(vget_high_f16(dstF16));

                float32x4_t res0 = vfmaq_f32(s0, d0, svec);
                float32x4_t res1 = vfmaq_f32(s1, d1, svec);

                vst1q_f16(pdst, vcombine_f16(vcvt_f16_f32(res0), vcvt_f16_f32(res1)));
            }
        }
    }

    // Last KV block: divide every row by its softmax denominator
    // (multiply by the reciprocal), again accumulating in fp32.
    if (idx == kvBlocks - 1) {
        for (int j = 0; j < depthQuad; ++j) {
            const auto baseOffset = j * stride0;
            int i = 0;
            const int plane4 = plane - (plane % 4);
            for (; i < plane4; i += 4) {
                auto pdst0 = dstPtr + baseOffset + (i + 0) * pack;
                auto pdst1 = dstPtr + baseOffset + (i + 1) * pack;
                auto pdst2 = dstPtr + baseOffset + (i + 2) * pack;
                auto pdst3 = dstPtr + baseOffset + (i + 3) * pack;

                float16x8_t dst0 = vld1q_f16(pdst0);
                float16x8_t dst1 = vld1q_f16(pdst1);
                float16x8_t dst2 = vld1q_f16(pdst2);
                float16x8_t dst3 = vld1q_f16(pdst3);

                float32x4_t ns0 = vdupq_n_f32(1.0f / normalizeScale[i + 0]);
                float32x4_t ns1 = vdupq_n_f32(1.0f / normalizeScale[i + 1]);
                float32x4_t ns2 = vdupq_n_f32(1.0f / normalizeScale[i + 2]);
                float32x4_t ns3 = vdupq_n_f32(1.0f / normalizeScale[i + 3]);

                float32x4_t d00 = vmulq_f32(vcvt_f32_f16(vget_low_f16(dst0)),  ns0);
                float32x4_t d10 = vmulq_f32(vcvt_f32_f16(vget_high_f16(dst0)), ns0);
                float32x4_t d01 = vmulq_f32(vcvt_f32_f16(vget_low_f16(dst1)),  ns1);
                float32x4_t d11 = vmulq_f32(vcvt_f32_f16(vget_high_f16(dst1)), ns1);
                float32x4_t d02 = vmulq_f32(vcvt_f32_f16(vget_low_f16(dst2)),  ns2);
                float32x4_t d12 = vmulq_f32(vcvt_f32_f16(vget_high_f16(dst2)), ns2);
                float32x4_t d03 = vmulq_f32(vcvt_f32_f16(vget_low_f16(dst3)),  ns3);
                float32x4_t d13 = vmulq_f32(vcvt_f32_f16(vget_high_f16(dst3)), ns3);

                vst1q_f16(pdst0, vcombine_f16(vcvt_f16_f32(d00), vcvt_f16_f32(d10)));
                vst1q_f16(pdst1, vcombine_f16(vcvt_f16_f32(d01), vcvt_f16_f32(d11)));
                vst1q_f16(pdst2, vcombine_f16(vcvt_f16_f32(d02), vcvt_f16_f32(d12)));
                vst1q_f16(pdst3, vcombine_f16(vcvt_f16_f32(d03), vcvt_f16_f32(d13)));
            }

            for (; i < plane; ++i) {
                auto pdst = dstPtr + baseOffset + i * pack;
                float32x4_t nsvec = vdupq_n_f32(1.0f / normalizeScale[i]);

                float16x8_t dstF16 = vld1q_f16(pdst);
                float32x4_t d0 = vcvt_f32_f16(vget_low_f16(dstF16));
                float32x4_t d1 = vcvt_f32_f16(vget_high_f16(dstF16));

                d0 = vmulq_f32(d0, nsvec);
                d1 = vmulq_f32(d1, nsvec);

                vst1q_f16(pdst, vcombine_f16(vcvt_f16_f32(d0), vcvt_f16_f32(d1)));
            }
        }
    }
}
| 
 | ||||
| static void MNNAttenUnpackAndConvertFp16(float* dst, float* src, size_t depth, size_t planesize, int pack) { | ||||
|     // src: (UP_DIV(depth, pack), planesize, pack), float16
 | ||||
|     // dst: (planesize, depth), float32
 | ||||
|     // pack=8
 | ||||
|      | ||||
|     if (planesize == 1) { | ||||
|         MNNDequantizeFP16((int16_t*)src, dst, depth); | ||||
|         return; // no need to convert
 | ||||
|     } | ||||
|     const auto depthDiv8 = UP_DIV(depth, pack); | ||||
|     const auto srcStep = pack * planesize; | ||||
|     const auto dstStep = depth; | ||||
|      | ||||
|     auto remainDepth = depth % pack; | ||||
|     auto depthQuad = depthDiv8; | ||||
|     if (remainDepth > 0) { | ||||
|         depthQuad -= 1; // last quad is not full
 | ||||
|     } | ||||
|      | ||||
|     for (int i = 0; i < depthQuad; ++i) { | ||||
|         auto realsize = planesize; | ||||
|         auto srcPtr = (FLOAT16*)src + i * srcStep; | ||||
|         auto dstPtr = (float*)dst + i * pack; | ||||
|         while (realsize >= 8) { | ||||
|             float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack); | ||||
|             float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack); | ||||
|             float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack); | ||||
|             float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack); | ||||
|             float16x8_t s4_f16 = vld1q_f16(srcPtr + 4 * pack); | ||||
|             float16x8_t s5_f16 = vld1q_f16(srcPtr + 5 * pack); | ||||
|             float16x8_t s6_f16 = vld1q_f16(srcPtr + 6 * pack); | ||||
|             float16x8_t s7_f16 = vld1q_f16(srcPtr + 7 * pack); | ||||
|              | ||||
|             float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16)); | ||||
|             float32x4_t d01_f32 = vcvt_f32_f16(vget_high_f16(s0_f16)); | ||||
|             float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16)); | ||||
|             float32x4_t d11_f32 = vcvt_f32_f16(vget_high_f16(s1_f16)); | ||||
|             float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16)); | ||||
|             float32x4_t d21_f32 = vcvt_f32_f16(vget_high_f16(s2_f16)); | ||||
|             float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16)); | ||||
|             float32x4_t d31_f32 = vcvt_f32_f16(vget_high_f16(s3_f16)); | ||||
|             float32x4_t d40_f32 = vcvt_f32_f16(vget_low_f16(s4_f16)); | ||||
|             float32x4_t d41_f32 = vcvt_f32_f16(vget_high_f16(s4_f16)); | ||||
|             float32x4_t d50_f32 = vcvt_f32_f16(vget_low_f16(s5_f16)); | ||||
|             float32x4_t d51_f32 = vcvt_f32_f16(vget_high_f16(s5_f16)); | ||||
|             float32x4_t d60_f32 = vcvt_f32_f16(vget_low_f16(s6_f16)); | ||||
|             float32x4_t d61_f32 = vcvt_f32_f16(vget_high_f16(s6_f16)); | ||||
|             float32x4_t d70_f32 = vcvt_f32_f16(vget_low_f16(s7_f16)); | ||||
|             float32x4_t d71_f32 = vcvt_f32_f16(vget_high_f16(s7_f16)); | ||||
|              | ||||
|             vst1q_f32(dstPtr + 0 * dstStep, d00_f32); vst1q_f32(dstPtr + 0 * dstStep + 4, d01_f32); | ||||
|             vst1q_f32(dstPtr + 1 * dstStep, d10_f32); vst1q_f32(dstPtr + 1 * dstStep + 4, d11_f32); | ||||
|             vst1q_f32(dstPtr + 2 * dstStep, d20_f32); vst1q_f32(dstPtr + 2 * dstStep + 4, d21_f32); | ||||
|             vst1q_f32(dstPtr + 3 * dstStep, d30_f32); vst1q_f32(dstPtr + 3 * dstStep + 4, d31_f32); | ||||
|             vst1q_f32(dstPtr + 4 * dstStep, d40_f32); vst1q_f32(dstPtr + 4 * dstStep + 4, d41_f32); | ||||
|             vst1q_f32(dstPtr + 5 * dstStep, d50_f32); vst1q_f32(dstPtr + 5 * dstStep + 4, d51_f32); | ||||
|             vst1q_f32(dstPtr + 6 * dstStep, d60_f32); vst1q_f32(dstPtr + 6 * dstStep + 4, d61_f32); | ||||
|             vst1q_f32(dstPtr + 7 * dstStep, d70_f32); vst1q_f32(dstPtr + 7 * dstStep + 4, d71_f32); | ||||
|              | ||||
|             srcPtr += 8 * pack; | ||||
|             dstPtr += 8 * dstStep; | ||||
|             realsize -= 8; | ||||
|         } | ||||
|         if (realsize >= 4) { | ||||
|             float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack); | ||||
|             float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack); | ||||
|             float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack); | ||||
|             float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack); | ||||
|              | ||||
|             float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16)); | ||||
|             float32x4_t d01_f32 = vcvt_f32_f16(vget_high_f16(s0_f16)); | ||||
|             float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16)); | ||||
|             float32x4_t d11_f32 = vcvt_f32_f16(vget_high_f16(s1_f16)); | ||||
|             float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16)); | ||||
|             float32x4_t d21_f32 = vcvt_f32_f16(vget_high_f16(s2_f16)); | ||||
|             float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16)); | ||||
|             float32x4_t d31_f32 = vcvt_f32_f16(vget_high_f16(s3_f16)); | ||||
|              | ||||
|             vst1q_f32(dstPtr + 0 * dstStep, d00_f32); vst1q_f32(dstPtr + 0 * dstStep + 4, d01_f32); | ||||
|             vst1q_f32(dstPtr + 1 * dstStep, d10_f32); vst1q_f32(dstPtr + 1 * dstStep + 4, d11_f32); | ||||
|             vst1q_f32(dstPtr + 2 * dstStep, d20_f32); vst1q_f32(dstPtr + 2 * dstStep + 4, d21_f32); | ||||
|             vst1q_f32(dstPtr + 3 * dstStep, d30_f32); vst1q_f32(dstPtr + 3 * dstStep + 4, d31_f32); | ||||
|              | ||||
|             srcPtr += 4 * pack; | ||||
|             dstPtr += 4 * dstStep; | ||||
|             realsize -= 4; | ||||
|         } | ||||
|         while (realsize > 0) { | ||||
|             auto s0_fp16 = vld1q_f16(srcPtr); | ||||
|             auto s00_fp32 = vcvt_f32_f16(vget_low_f16(s0_fp16)); | ||||
|             auto s01_fp32 = vcvt_f32_f16(vget_high_f16(s0_fp16)); | ||||
|             vst1q_f32(dstPtr, s00_fp32); | ||||
|             vst1q_f32(dstPtr + 4, s01_fp32); | ||||
|             srcPtr += pack; | ||||
|             dstPtr += dstStep; | ||||
|             realsize--; | ||||
|         } | ||||
|     } | ||||
|      | ||||
|     // process remain depth < 8
 | ||||
|     if (remainDepth >= 4) { | ||||
|         auto realsize = planesize; | ||||
|         auto srcPtr = (FLOAT16*)src + (depthDiv8 - 1) * srcStep; | ||||
|         auto dstPtr = (float*)dst + (depthDiv8 - 1) * pack; | ||||
|         auto extraDepth = remainDepth - 4; | ||||
|          | ||||
|         float tmp0[4]; | ||||
|         float tmp1[4]; | ||||
|         float tmp2[4]; | ||||
|         float tmp3[4]; | ||||
|         float tmp4[4]; | ||||
|         float tmp5[4]; | ||||
|         float tmp6[4]; | ||||
|         float tmp7[4]; | ||||
|          | ||||
|         while (realsize >= 8) { | ||||
|             float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack); | ||||
|             float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack); | ||||
|             float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack); | ||||
|             float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack); | ||||
|             float16x8_t s4_f16 = vld1q_f16(srcPtr + 4 * pack); | ||||
|             float16x8_t s5_f16 = vld1q_f16(srcPtr + 5 * pack); | ||||
|             float16x8_t s6_f16 = vld1q_f16(srcPtr + 6 * pack); | ||||
|             float16x8_t s7_f16 = vld1q_f16(srcPtr + 7 * pack); | ||||
|              | ||||
|              | ||||
|             float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16)); | ||||
|             float32x4_t d01_f32 = vcvt_f32_f16(vget_high_f16(s0_f16)); | ||||
|             float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16)); | ||||
|             float32x4_t d11_f32 = vcvt_f32_f16(vget_high_f16(s1_f16)); | ||||
|             float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16)); | ||||
|             float32x4_t d21_f32 = vcvt_f32_f16(vget_high_f16(s2_f16)); | ||||
|             float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16)); | ||||
|             float32x4_t d31_f32 = vcvt_f32_f16(vget_high_f16(s3_f16)); | ||||
|             float32x4_t d40_f32 = vcvt_f32_f16(vget_low_f16(s4_f16)); | ||||
|             float32x4_t d41_f32 = vcvt_f32_f16(vget_high_f16(s4_f16)); | ||||
|             float32x4_t d50_f32 = vcvt_f32_f16(vget_low_f16(s5_f16)); | ||||
|             float32x4_t d51_f32 = vcvt_f32_f16(vget_high_f16(s5_f16)); | ||||
|             float32x4_t d60_f32 = vcvt_f32_f16(vget_low_f16(s6_f16)); | ||||
|             float32x4_t d61_f32 = vcvt_f32_f16(vget_high_f16(s6_f16)); | ||||
|             float32x4_t d70_f32 = vcvt_f32_f16(vget_low_f16(s7_f16)); | ||||
|             float32x4_t d71_f32 = vcvt_f32_f16(vget_high_f16(s7_f16)); | ||||
|              | ||||
|             vst1q_f32(dstPtr + 0 * dstStep, d00_f32); vst1q_f32(tmp0, d01_f32); | ||||
|             vst1q_f32(dstPtr + 1 * dstStep, d10_f32); vst1q_f32(tmp1, d11_f32); | ||||
|             vst1q_f32(dstPtr + 2 * dstStep, d20_f32); vst1q_f32(tmp2, d21_f32); | ||||
|             vst1q_f32(dstPtr + 3 * dstStep, d30_f32); vst1q_f32(tmp3, d31_f32); | ||||
|             vst1q_f32(dstPtr + 4 * dstStep, d40_f32); vst1q_f32(tmp4, d41_f32); | ||||
|             vst1q_f32(dstPtr + 5 * dstStep, d50_f32); vst1q_f32(tmp5, d51_f32); | ||||
|             vst1q_f32(dstPtr + 6 * dstStep, d60_f32); vst1q_f32(tmp6, d61_f32); | ||||
|             vst1q_f32(dstPtr + 7 * dstStep, d70_f32); vst1q_f32(tmp7, d71_f32); | ||||
|              | ||||
|             memcpy(dstPtr + 0 * dstStep + 4, tmp0, sizeof(float) * extraDepth); | ||||
|             memcpy(dstPtr + 1 * dstStep + 4, tmp1, sizeof(float) * extraDepth); | ||||
|             memcpy(dstPtr + 2 * dstStep + 4, tmp2, sizeof(float) * extraDepth); | ||||
|             memcpy(dstPtr + 3 * dstStep + 4, tmp3, sizeof(float) * extraDepth); | ||||
|             memcpy(dstPtr + 4 * dstStep + 4, tmp4, sizeof(float) * extraDepth); | ||||
|             memcpy(dstPtr + 5 * dstStep + 4, tmp5, sizeof(float) * extraDepth); | ||||
|             memcpy(dstPtr + 6 * dstStep + 4, tmp6, sizeof(float) * extraDepth); | ||||
|             memcpy(dstPtr + 7 * dstStep + 4, tmp7, sizeof(float) * extraDepth); | ||||
|              | ||||
|             srcPtr += 8 * pack; | ||||
|             dstPtr += 8 * dstStep; | ||||
|             realsize -= 8; | ||||
|         } | ||||
|         if (realsize >= 4) { | ||||
|             float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack); | ||||
|             float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack); | ||||
|             float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack); | ||||
|             float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack); | ||||
|              | ||||
|             float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16)); | ||||
|             float32x4_t d01_f32 = vcvt_f32_f16(vget_high_f16(s0_f16)); | ||||
|             float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16)); | ||||
|             float32x4_t d11_f32 = vcvt_f32_f16(vget_high_f16(s1_f16)); | ||||
|             float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16)); | ||||
|             float32x4_t d21_f32 = vcvt_f32_f16(vget_high_f16(s2_f16)); | ||||
|             float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16)); | ||||
|             float32x4_t d31_f32 = vcvt_f32_f16(vget_high_f16(s3_f16)); | ||||
|              | ||||
|             vst1q_f32(dstPtr + 0 * dstStep, d00_f32); vst1q_f32(tmp0, d01_f32); | ||||
|             vst1q_f32(dstPtr + 1 * dstStep, d10_f32); vst1q_f32(tmp1, d11_f32); | ||||
|             vst1q_f32(dstPtr + 2 * dstStep, d20_f32); vst1q_f32(tmp2, d21_f32); | ||||
|             vst1q_f32(dstPtr + 3 * dstStep, d30_f32); vst1q_f32(tmp3, d31_f32); | ||||
|              | ||||
|             memcpy(dstPtr + 0 * dstStep + 4, tmp0, sizeof(float) * extraDepth); | ||||
|             memcpy(dstPtr + 1 * dstStep + 4, tmp1, sizeof(float) * extraDepth); | ||||
|             memcpy(dstPtr + 2 * dstStep + 4, tmp2, sizeof(float) * extraDepth); | ||||
|             memcpy(dstPtr + 3 * dstStep + 4, tmp3, sizeof(float) * extraDepth); | ||||
|              | ||||
|             srcPtr += 4 * pack; | ||||
|             dstPtr += 4 * dstStep; | ||||
|             realsize -= 4; | ||||
|         } | ||||
|         while (realsize > 0) { | ||||
|             auto s0_fp16 = vld1q_f16(srcPtr); | ||||
|             auto d00_fp32 = vcvt_f32_f16(vget_low_f16(s0_fp16)); | ||||
|             auto d01_fp32 = vcvt_f32_f16(vget_high_f16(s0_fp16)); | ||||
|             vst1q_f32(dstPtr, d00_fp32); | ||||
|             vst1q_f32(tmp0, d01_fp32); | ||||
|             memcpy(dstPtr + 4, tmp0, sizeof(float) * extraDepth); | ||||
|             srcPtr += pack; | ||||
|             dstPtr += dstStep; | ||||
|             realsize--; | ||||
|         } | ||||
|     } | ||||
|      | ||||
|     if (remainDepth > 0 && remainDepth < 4) { | ||||
|         auto realsize = planesize; | ||||
|         auto srcPtr = (FLOAT16*)src + (depthDiv8 - 1) * srcStep; | ||||
|         auto dstPtr = (float*)dst + (depthDiv8 - 1) * pack; | ||||
|          | ||||
|         float tmp0[4]; | ||||
|         float tmp1[4]; | ||||
|         float tmp2[4]; | ||||
|         float tmp3[4]; | ||||
|         float tmp4[4]; | ||||
|         float tmp5[4]; | ||||
|         float tmp6[4]; | ||||
|         float tmp7[4]; | ||||
|          | ||||
|         while (realsize >= 8) { | ||||
|             float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack); | ||||
|             float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack); | ||||
|             float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack); | ||||
|             float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack); | ||||
|             float16x8_t s4_f16 = vld1q_f16(srcPtr + 4 * pack); | ||||
|             float16x8_t s5_f16 = vld1q_f16(srcPtr + 5 * pack); | ||||
|             float16x8_t s6_f16 = vld1q_f16(srcPtr + 6 * pack); | ||||
|             float16x8_t s7_f16 = vld1q_f16(srcPtr + 7 * pack); | ||||
|              | ||||
|             float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16)); | ||||
|             float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16)); | ||||
|             float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16)); | ||||
|             float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16)); | ||||
|             float32x4_t d40_f32 = vcvt_f32_f16(vget_low_f16(s4_f16)); | ||||
|             float32x4_t d50_f32 = vcvt_f32_f16(vget_low_f16(s5_f16)); | ||||
|             float32x4_t d60_f32 = vcvt_f32_f16(vget_low_f16(s6_f16)); | ||||
|             float32x4_t d70_f32 = vcvt_f32_f16(vget_low_f16(s7_f16)); | ||||
|              | ||||
|             vst1q_f32(tmp0, d00_f32); | ||||
|             vst1q_f32(tmp1, d10_f32); | ||||
|             vst1q_f32(tmp2, d20_f32); | ||||
|             vst1q_f32(tmp3, d30_f32); | ||||
|             vst1q_f32(tmp4, d40_f32); | ||||
|             vst1q_f32(tmp5, d50_f32); | ||||
|             vst1q_f32(tmp6, d60_f32); | ||||
|             vst1q_f32(tmp7, d70_f32); | ||||
|              | ||||
|             memcpy(dstPtr + 0 * dstStep, tmp0, sizeof(float) * remainDepth); | ||||
|             memcpy(dstPtr + 1 * dstStep, tmp1, sizeof(float) * remainDepth); | ||||
|             memcpy(dstPtr + 2 * dstStep, tmp2, sizeof(float) * remainDepth); | ||||
|             memcpy(dstPtr + 3 * dstStep, tmp3, sizeof(float) * remainDepth); | ||||
|             memcpy(dstPtr + 4 * dstStep, tmp4, sizeof(float) * remainDepth); | ||||
|             memcpy(dstPtr + 5 * dstStep, tmp5, sizeof(float) * remainDepth); | ||||
|             memcpy(dstPtr + 6 * dstStep, tmp6, sizeof(float) * remainDepth); | ||||
|             memcpy(dstPtr + 7 * dstStep, tmp7, sizeof(float) * remainDepth); | ||||
|              | ||||
|             srcPtr += 8 * pack; | ||||
|             dstPtr += 8 * dstStep; | ||||
|             realsize -= 8; | ||||
|         } | ||||
|         if (realsize >= 4) { | ||||
|             float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack); | ||||
|             float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack); | ||||
|             float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack); | ||||
|             float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack); | ||||
|              | ||||
|             float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16)); | ||||
|             float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16)); | ||||
|             float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16)); | ||||
|             float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16)); | ||||
|              | ||||
|             vst1q_f32(tmp0, d00_f32); | ||||
|             vst1q_f32(tmp1, d10_f32); | ||||
|             vst1q_f32(tmp2, d20_f32); | ||||
|             vst1q_f32(tmp3, d30_f32); | ||||
|              | ||||
|             memcpy(dstPtr + 0 * dstStep, tmp0, sizeof(float) * remainDepth); | ||||
|             memcpy(dstPtr + 1 * dstStep, tmp1, sizeof(float) * remainDepth); | ||||
|             memcpy(dstPtr + 2 * dstStep, tmp2, sizeof(float) * remainDepth); | ||||
|             memcpy(dstPtr + 3 * dstStep, tmp3, sizeof(float) * remainDepth); | ||||
|              | ||||
|             srcPtr += 4 * pack; | ||||
|             dstPtr += 4 * dstStep; | ||||
|             realsize -= 4; | ||||
|         } | ||||
|         while (realsize > 0) { | ||||
|             auto s0_f16 = vld1q_f16(srcPtr); | ||||
|             float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16)); | ||||
|             vst1q_f32(tmp0, d00_f32); | ||||
|             memcpy(dstPtr + 0 * dstStep, tmp0, sizeof(float) * remainDepth); | ||||
|             srcPtr += pack; | ||||
|             dstPtr += dstStep; | ||||
|             realsize--; | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
// Pack fp32 attention data into the matmul tile layout while converting to fp16,
// specialized for lP == 1.
//   src layout: [depth, planesize]                (float32)
//   dst layout: [depth/eP, planesize, eP]         (float16; dst is passed as
//               float* by the function-pointer interface but written as FLOAT16*)
// units[0] = eP (row pack unit), units[1] = lP (must be 1 here).
// Source row i becomes column (i % eP) of depth-tile (i / eP); consecutive
// plane elements are therefore strided by eP in dst.
static void MNNAttenPackAndConvertFp32LP1(float* dst, const float* src, const int32_t* units, size_t depth, size_t planesize) {
    int32_t eP = units[0];
    int32_t lP = units[1];
    
    if (lP != 1) {
        MNN_ERROR("This function only supports lP=1\n");
        return;
    }
    
    // Strides in fp16 elements: one packed row is eP wide, one depth tile
    // covers planesize rows.
    auto dstStride1 = eP;
    auto dstStride0 = planesize * dstStride1;
    
    for (int i = 0; i < depth; ++i) {
        size_t realsize = planesize;
        const float* srcPtr = src + i * planesize;
        FLOAT16* dstPtr = (FLOAT16*)dst + (i % eP) + (i / eP) * dstStride0;
        
        // Main loop: convert 16 floats at a time, scatter each fp16 lane to its
        // strided slot. Exact lane order defines the layout - do not reorder.
        while (realsize >= 16) {
            float32x4_t s0_f32 = vld1q_f32(srcPtr);
            float32x4_t s1_f32 = vld1q_f32(srcPtr + 4);
            float32x4_t s2_f32 = vld1q_f32(srcPtr + 8);
            float32x4_t s3_f32 = vld1q_f32(srcPtr + 12);
            
            float16x4_t d0_f16 = vcvt_f16_f32(s0_f32);
            float16x4_t d1_f16 = vcvt_f16_f32(s1_f32);
            float16x4_t d2_f16 = vcvt_f16_f32(s2_f32);
            float16x4_t d3_f16 = vcvt_f16_f32(s3_f32);
            
            vst1_lane_f16(dstPtr,                  d0_f16, 0);
            vst1_lane_f16(dstPtr + dstStride1,     d0_f16, 1);
            vst1_lane_f16(dstPtr + 2 * dstStride1, d0_f16, 2);
            vst1_lane_f16(dstPtr + 3 * dstStride1, d0_f16, 3);
            
            vst1_lane_f16(dstPtr + 4 * dstStride1, d1_f16, 0);
            vst1_lane_f16(dstPtr + 5 * dstStride1, d1_f16, 1);
            vst1_lane_f16(dstPtr + 6 * dstStride1, d1_f16, 2);
            vst1_lane_f16(dstPtr + 7 * dstStride1, d1_f16, 3);
            
            vst1_lane_f16(dstPtr + 8 * dstStride1,  d2_f16, 0);
            vst1_lane_f16(dstPtr + 9 * dstStride1,  d2_f16, 1);
            vst1_lane_f16(dstPtr + 10 * dstStride1, d2_f16, 2);
            vst1_lane_f16(dstPtr + 11 * dstStride1, d2_f16, 3);
            
            vst1_lane_f16(dstPtr + 12 * dstStride1, d3_f16, 0);
            vst1_lane_f16(dstPtr + 13 * dstStride1, d3_f16, 1);
            vst1_lane_f16(dstPtr + 14 * dstStride1, d3_f16, 2);
            vst1_lane_f16(dstPtr + 15 * dstStride1, d3_f16, 3);
            
            srcPtr += 16;
            dstPtr += 16 * dstStride1;
            realsize -= 16;
        }
        
        // Remainder: 8-wide step.
        if (realsize >= 8) {
            float32x4_t s0_f32 = vld1q_f32(srcPtr);
            float32x4_t s1_f32 = vld1q_f32(srcPtr + 4);
            
            float16x4_t d0_f16 = vcvt_f16_f32(s0_f32);
            float16x4_t d1_f16 = vcvt_f16_f32(s1_f32);
            
            vst1_lane_f16(dstPtr,              d0_f16, 0);
            vst1_lane_f16(dstPtr + dstStride1, d0_f16, 1);
            vst1_lane_f16(dstPtr + 2 * dstStride1, d0_f16, 2);
            vst1_lane_f16(dstPtr + 3 * dstStride1, d0_f16, 3);
            
            vst1_lane_f16(dstPtr + 4 * dstStride1, d1_f16, 0);
            vst1_lane_f16(dstPtr + 5 * dstStride1, d1_f16, 1);
            vst1_lane_f16(dstPtr + 6 * dstStride1, d1_f16, 2);
            vst1_lane_f16(dstPtr + 7 * dstStride1, d1_f16, 3);
            
            srcPtr += 8;
            dstPtr += 8 * dstStride1;
            realsize -= 8;
        }
        
        // Remainder: 4-wide step.
        if (realsize >= 4) {
            float32x4_t s0_f32 = vld1q_f32(srcPtr);
            float16x4_t d0_f16 = vcvt_f16_f32(s0_f32);
            
            vst1_lane_f16(dstPtr,              d0_f16, 0);
            vst1_lane_f16(dstPtr + dstStride1, d0_f16, 1);
            vst1_lane_f16(dstPtr + 2 * dstStride1, d0_f16, 2);
            vst1_lane_f16(dstPtr + 3 * dstStride1, d0_f16, 3);
            
            srcPtr += 4;
            dstPtr += 4 * dstStride1;
            realsize -= 4;
        }
        
        // Scalar tail: at most 3 elements, converted one by one.
        for (; realsize > 0; --realsize) {
            *dstPtr = (FLOAT16)(*srcPtr);
            srcPtr++;
            dstPtr += dstStride1;
        }
    }
}
| 
 | ||||
// Pack fp32 attention data into the matmul tile layout while converting to fp16.
//   src layout: [depth, planesize]                      (float32)
//   dst layout: [depth/eP, planesize/lP, eP, lP]        (float16)
// units[0] = eP, units[1] = lP; only lP = 1 or 2 is supported. The lP = 1 case
// is delegated to MNNAttenPackAndConvertFp32LP1; the code below handles lP = 2,
// where each eP-column holds a pair of consecutive fp16 values, so one 32-bit
// lane store writes both halves of a pair at once.
static void MNNAttenPackAndConvertFp32(float* dst, float* src, const int32_t* units, size_t depth, size_t planesize) {
    int32_t eP = units[0];
    int32_t lP = units[1]; // Now lP=1 or 2
    
    if (lP != 1 && lP != 2) {
        MNN_ERROR("This function only supports lP=1 or 2\n");
        return;
    }
    
    // src [depth, planesize] (float32)
    // dst [depth/eP, planesize/lP, eP, lP] (float16)
    
    if (lP == 1) {
        MNNAttenPackAndConvertFp32LP1(dst, src, units, depth, planesize);
        return;
    }
    
    // Strides in fp16 elements: one packed row is eP*lP wide, one depth tile
    // covers UP_DIV(planesize, lP) such rows (plane is padded up to lP).
    auto dstStride1 = eP * lP;
    auto dstStride0 = UP_DIV(planesize, lP) * dstStride1;
    
    for (int i = 0; i < depth; ++i) {
        size_t realsize = planesize;
        const float* srcPtr = src + i * planesize;
        // Row i lands at lP-wide column (i % eP) of depth-tile (i / eP). The
        // base is a multiple of lP fp16 elements off a float* buffer, so the
        // 32-bit lane stores below stay 4-byte aligned.
        FLOAT16* dstPtr = (FLOAT16*)dst + (i % eP) * lP + (i / eP) * dstStride0;
        
        // Main loop: 16 floats -> 8 fp16 pairs; each vst1_lane_u32 writes one
        // pair. Lane order defines the layout - do not reorder.
        while (realsize >= 16) {
            float32x4_t s0 = vld1q_f32(srcPtr);
            float32x4_t s1 = vld1q_f32(srcPtr + 4);
            float32x4_t s2 = vld1q_f32(srcPtr + 8);
            float32x4_t s3 = vld1q_f32(srcPtr + 12);
            
            float16x4_t h0 = vcvt_f16_f32(s0);
            float16x4_t h1 = vcvt_f16_f32(s1);
            float16x4_t h2 = vcvt_f16_f32(s2);
            float16x4_t h3 = vcvt_f16_f32(s3);
            
            vst1_lane_u32((uint32_t*)dstPtr, vreinterpret_u32_f16(h0), 0);
            vst1_lane_u32((uint32_t*)(dstPtr + dstStride1), vreinterpret_u32_f16(h0), 1);
            
            vst1_lane_u32((uint32_t*)(dstPtr + 2 * dstStride1), vreinterpret_u32_f16(h1), 0);
            vst1_lane_u32((uint32_t*)(dstPtr + 3 * dstStride1), vreinterpret_u32_f16(h1), 1);
            
            vst1_lane_u32((uint32_t*)(dstPtr + 4 * dstStride1), vreinterpret_u32_f16(h2), 0);
            vst1_lane_u32((uint32_t*)(dstPtr + 5 * dstStride1), vreinterpret_u32_f16(h2), 1);
            
            vst1_lane_u32((uint32_t*)(dstPtr + 6 * dstStride1), vreinterpret_u32_f16(h3), 0);
            vst1_lane_u32((uint32_t*)(dstPtr + 7 * dstStride1), vreinterpret_u32_f16(h3), 1);
            
            realsize -= 16;
            srcPtr += 16;
            dstPtr += 8 * dstStride1;
        }
        
        // Remainder: 8 floats -> 4 pairs.
        if (realsize >= 8) {
            float32x4_t s0 = vld1q_f32(srcPtr);
            float32x4_t s1 = vld1q_f32(srcPtr + 4);
            
            float16x4_t h0 = vcvt_f16_f32(s0);
            float16x4_t h1 = vcvt_f16_f32(s1);
            
            vst1_lane_u32((uint32_t*)dstPtr, vreinterpret_u32_f16(h0), 0);
            vst1_lane_u32((uint32_t*)(dstPtr + dstStride1), vreinterpret_u32_f16(h0), 1);
            
            vst1_lane_u32((uint32_t*)(dstPtr + 2 * dstStride1), vreinterpret_u32_f16(h1), 0);
            vst1_lane_u32((uint32_t*)(dstPtr + 3 * dstStride1), vreinterpret_u32_f16(h1), 1);
            
            realsize -= 8;
            srcPtr += 8;
            dstPtr += 4 * dstStride1;
        }
        
        // Remainder: 4 floats -> 2 pairs.
        if (realsize >= 4) {
            float32x4_t s0 = vld1q_f32(srcPtr);
            float16x4_t h0 = vcvt_f16_f32(s0);
            
            vst1_lane_u32((uint32_t*)dstPtr, vreinterpret_u32_f16(h0), 0);
            vst1_lane_u32((uint32_t*)(dstPtr + dstStride1), vreinterpret_u32_f16(h0), 1);
            
            realsize -= 4;
            srcPtr += 4;
            dstPtr += 2 * dstStride1;
        }
        
        // Remainder: one full pair (duplicate the 2 floats to fill the vector,
        // only lane 0 is stored).
        if (realsize >= 2) {
            float32x2_t s0 = vld1_f32(srcPtr);
            float16x4_t h0 = vcvt_f16_f32(vcombine_f32(s0, s0));
            
            vst1_lane_u32((uint32_t*)dstPtr, vreinterpret_u32_f16(h0), 0);
            
            realsize -= 2;
            srcPtr += 2;
            dstPtr += dstStride1;
        }
        
        // Odd tail: last element gets its pair slot zero-padded.
        if (realsize > 0) {
            dstPtr[0] = (FLOAT16)srcPtr[0];
            dstPtr[1] = (FLOAT16)0.0f;
        }
    }
}
| 
 | ||||
| #endif // MNN_SUPPORT_TRANSFORMER_FUSE
 | ||||
| 
 | ||||
| #ifdef MNN_LOW_MEMORY | ||||
| void MNNAbsMaxFP16(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack) { | ||||
|     if (pack == 4) { | ||||
|  | @ -1659,6 +2383,7 @@ static void MNNAsyQuantInfo_FP16(float* scale, float* bias, float* qscale, float | |||
|     } | ||||
| #endif | ||||
| } | ||||
| 
 | ||||
| #endif // MNN_LOW_MEMORY
 | ||||
| 
 | ||||
| static CoreFunctions* gInstance = nullptr; | ||||
|  | @ -1726,6 +2451,8 @@ bool Arm82Functions::init() { | |||
|     FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain, MNNPackedMatMulRemainFP16); | ||||
|     FUNC_PTR_ASSIGN(gInstance->MNNPackC4ForMatMul_A, Arm82MNNPackForMatMul_A); | ||||
|     FUNC_PTR_ASSIGN(gInstance->MNNPackForMatMul_B, Arm82MNNPackForMatMul_B); | ||||
| 
 | ||||
|     FUNC_PTR_ASSIGN(gInstance->MNNSoftmax, origin->MNNSoftmax); | ||||
| #if defined(__aarch64__) | ||||
|     gInstance->supportFp16arith = origin->supportFp16arith; | ||||
|     gInstance->supportSDot = origin->supportSDot; | ||||
|  | @ -1755,7 +2482,16 @@ bool Arm82Functions::init() { | |||
|     FUNC_PTR_ASSIGN(gInstance->MNNCountMaxMinValue, ARM82CountMinMaxValue); // return one min&max
 | ||||
|     FUNC_PTR_ASSIGN(gInstance->MNNSumByAxisLForMatmul_A, origin->MNNSumByAxisLForMatmul_A); | ||||
|     FUNC_PTR_ASSIGN(gInstance->MNNDepthwiseConvFastKernel, MNNDepthwiseConvFastKernelFP16); | ||||
| #endif | ||||
| #endif // __aarch64__
 | ||||
| 
 | ||||
| #ifdef MNN_SUPPORT_TRANSFORMER_FUSE | ||||
|     // Attention
 | ||||
|     FUNC_PTR_ASSIGN(gInstance->MNNAttenUnpackAndConvertFp16, MNNAttenUnpackAndConvertFp16); | ||||
|     FUNC_PTR_ASSIGN(gInstance->MNNAttenPackAndConvertFp32, MNNAttenPackAndConvertFp32); | ||||
|     FUNC_PTR_ASSIGN(gInstance->MNNAttenPackAndScaleSingleHead, MNNAttenPackAndScaleSingleHead); | ||||
|     FUNC_PTR_ASSIGN(gInstance->MNNFlashAttentionUpdateBlockOutput, MNNFlashAttentionUpdateBlockOutput); | ||||
| #endif // MNN_SUPPORT_TRANSFORMER_FUSE
 | ||||
| 
 | ||||
|     gInstance->MNNComputeMatMulForH_1 = _MNNComputeMatMulForH_1_FP16; | ||||
|     gInstance->MNNComputeMatMulForE_1 = _MNNComputeMatMulForE_1_FP16; | ||||
| 
 | ||||
|  |  | |||
|  | @ -24,10 +24,12 @@ | |||
| #define FLOAT16_T float | ||||
| #endif | ||||
| 
 | ||||
| #define MNN_FLASH_ATTENTION_BLOCK_SIZE 64 | ||||
| 
 | ||||
| namespace MNN { | ||||
| 
 | ||||
| template <typename T> | ||||
| void CPUAttention::pack_query(Tensor* query, char* pack_q, char* sum_q, int seq_len, int h, float q_scale) { | ||||
| void CPUAttention::pack_query(Tensor* query, int8_t* pack_q, int8_t* sum_q, int seq_len, int h, float q_scale) { | ||||
|     if (mUseGemmInt8) { // Shape of Query: numhead, [seqlen/eP8, headdim/lP8, eP8, lP8]
 | ||||
|         mMinQ[h] = query->host<T>()[h * mHeadDim]; | ||||
|         mMaxQ[h] = query->host<T>()[h * mHeadDim]; | ||||
|  | @ -75,21 +77,21 @@ void CPUAttention::pack_query(Tensor* query, char* pack_q, char* sum_q, int seq_ | |||
| } | ||||
| 
 | ||||
| template <typename T> | ||||
| void CPUAttention::unpack_QK(float * unpack_qk_dst, char * pack_qk_src, int seq_len, int kv_seq_len) { | ||||
| void CPUAttention::unpack_QK(float * unpack_qk_dst, int8_t * pack_qk_src, int seq_len, int kv_seq_len) { | ||||
|     float * dst = unpack_qk_dst; | ||||
|     T * src = (T *)(pack_qk_src); | ||||
|     // [kv_seq_len/unit, seq_len, unit] -> [seq_len, kv_seq_len]
 | ||||
|     // [kv_seq_len/mPack, seq_len, mPack] -> [seq_len, kv_seq_len]
 | ||||
|     for (int i = 0; i < seq_len; i++) { | ||||
|         for (int j = 0; j < kv_seq_len; j++) { | ||||
|             int out_index = j / unit; | ||||
|             int in_index  = j % unit; | ||||
|             dst[i * kv_seq_len + j] = src[out_index * seq_len * unit + i * unit + in_index]; | ||||
|             int out_index = j / mPack; | ||||
|             int in_index  = j % mPack; | ||||
|             dst[i * kv_seq_len + j] = src[out_index * seq_len * mPack + i * mPack + in_index]; | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| template <typename T> | ||||
| static void pack_QK(char * pack_qk_dst, float * qk_src, int seq_len, int kv_seq_len, int eP, int lP, int bytes) { | ||||
| static void pack_QK(int8_t * pack_qk_dst, float * qk_src, int seq_len, int kv_seq_len, int eP, int lP, int bytes) { | ||||
|     T * dst = reinterpret_cast<T*>(pack_qk_dst); | ||||
|     float * src = reinterpret_cast<float*>(qk_src); | ||||
|     // [seq_len, kv_seq_len] -> [seq_len/eP, kv_seq_len/lP, eP, lP]
 | ||||
|  | @ -108,38 +110,50 @@ static void pack_QK(char * pack_qk_dst, float * qk_src, int seq_len, int kv_seq_ | |||
| } | ||||
| 
 | ||||
| template <typename T> | ||||
| static void mask_QK(float * unpack_qk, int seq_len, int kv_seq_len, float mScale, float min_val, const Tensor* mask) { | ||||
|     if (mask == nullptr) { | ||||
|         for (int i = 0; i < kv_seq_len; i++) { | ||||
| static void mask_QK(float * unpack_qk, int seq_len, int kv_seq_len, float mScale, float min_val, const Tensor* maskTensor, int offset, int startIndx, int processedKvLen) { | ||||
|      | ||||
|     int endIndx = startIndx + processedKvLen; | ||||
|     if (maskTensor == nullptr) { | ||||
|         for (int i = 0; i < processedKvLen; i++) { | ||||
|             unpack_qk[i] = unpack_qk[i] * mScale; | ||||
|         } | ||||
|     } else if (mask->getType() == halide_type_of<float>()) { | ||||
|         return; | ||||
|     } | ||||
|     const int8_t* mask = maskTensor->host<int8_t>(); | ||||
|     halide_type_t htype = maskTensor->getType(); | ||||
|     int maskSize = maskTensor->elementSize(); | ||||
|      | ||||
|     if (htype == halide_type_of<float>()) { | ||||
|         // float mask
 | ||||
|         T* fpmask_ptr = mask->host<T>(); | ||||
|         if (mask->elementSize() == seq_len * kv_seq_len) { | ||||
|             // normal mask for all token
 | ||||
|             for (int i = 0; i < seq_len * kv_seq_len; i++) { | ||||
|                 unpack_qk[i] = unpack_qk[i] * mScale + fpmask_ptr[i]; | ||||
|             } | ||||
|         } else { | ||||
|             // square mask just for new generation token
 | ||||
|             int offset = kv_seq_len - seq_len; | ||||
|         T* fpmask_ptr = (T*)mask; | ||||
|         if (maskSize == seq_len * kv_seq_len) { // sliding attention, mask shape: [seq_len, kv_seq_len]
 | ||||
|             for (int i = 0; i < seq_len; ++i) { | ||||
|                 auto unpack_qki = unpack_qk + i * kv_seq_len; | ||||
|                 auto fpmask_ptri = fpmask_ptr + i * seq_len; | ||||
|                 for (int j = 0; j < offset; ++j) { | ||||
|                     unpack_qki[j] = unpack_qki[j] * mScale; | ||||
|                 auto unpack_qki = unpack_qk + i * processedKvLen; | ||||
|                 auto fpmask_ptri = fpmask_ptr + i * kv_seq_len; | ||||
|                 for (int j = startIndx; j < endIndx; ++j) { | ||||
|                     unpack_qki[j - startIndx] = unpack_qki[j - startIndx] * mScale + fpmask_ptri[j]; | ||||
|                 } | ||||
|                 for (int j = 0; j < seq_len; ++j) { | ||||
|                     unpack_qki[offset + j] = unpack_qki[offset + j] * mScale + fpmask_ptri[j]; | ||||
|             } | ||||
|         } else { // mask shape: [seq_len, seq_len]
 | ||||
|             for (int i = 0; i < seq_len; ++i) { | ||||
|                 auto unpack_qki = unpack_qk + i * processedKvLen; | ||||
|                 auto fpmask_ptri = fpmask_ptr + i * seq_len; | ||||
| 
 | ||||
|                 auto notMaskIndx = ALIMIN(endIndx, offset); | ||||
|                 auto stMaskIndx = ALIMAX(startIndx, offset); | ||||
|                 for (int j = startIndx; j < notMaskIndx; ++j) { | ||||
|                     unpack_qki[j - startIndx] = unpack_qki[j - startIndx] * mScale; | ||||
|                 } | ||||
|                 for (int j = stMaskIndx; j < endIndx; ++j) { | ||||
|                     unpack_qki[j - startIndx] = unpack_qki[j - startIndx] * mScale + fpmask_ptri[j - offset]; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } else { | ||||
|         // int mask
 | ||||
|         int* mask_ptr = mask->host<int>(); | ||||
|         for (int i = 0; i < seq_len * kv_seq_len; i++) { | ||||
|             if (mask_ptr[i]) { | ||||
|         int* mask_ptr = (int*)mask; | ||||
|         for (int i = 0; i < seq_len * processedKvLen; i++) { | ||||
|             if (mask_ptr[i / processedKvLen * kv_seq_len + i % processedKvLen]) { | ||||
|                 unpack_qk[i] = unpack_qk[i] * mScale; | ||||
|             } else { | ||||
|                 unpack_qk[i] = min_val; | ||||
|  | @ -148,41 +162,47 @@ static void mask_QK(float * unpack_qk, int seq_len, int kv_seq_len, float mScale | |||
|     } | ||||
| } | ||||
| 
 | ||||
| static void softmax_QK(float* softmax_qk_addr, float* unpack_qk_addr, int seq_len, int kv_seq_len) { | ||||
|     for (int i = 0; i < seq_len; i++) {  // softmax each row
 | ||||
|         MNNSoftmax(softmax_qk_addr + i * kv_seq_len, unpack_qk_addr + i * kv_seq_len, kv_seq_len); | ||||
|     } | ||||
| } | ||||
| typedef void(softmaxFunc)(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize); | ||||
| template <typename T> | ||||
| static void softmaxQK(float* softmax_qk_addr, float* unpack_qk_addr, float* runningMax, float* runningSum, float* diffScale, const float* sinkPtr, softmaxFunc* sffunc, int seq_len, int kv_seq_len, int headIdx, bool isLastKvBlock) { | ||||
|      | ||||
| static void sink_softmax_QK(float* softmax_qk_addr, float* unpack_qk_addr, int seq_len, int kv_seq_len, float sink) { | ||||
|     // TODO: opt
 | ||||
|     std::vector<float> buffer(2 * (kv_seq_len + 1)); | ||||
|     float* sinkSrc = buffer.data(); | ||||
|     float* sinkDst = buffer.data() + kv_seq_len + 1; | ||||
|     for (int i = 0; i < seq_len; i++) {  // softmax each row
 | ||||
|         ::memcpy(sinkSrc, unpack_qk_addr + i * kv_seq_len, kv_seq_len * sizeof(float)); | ||||
|         sinkSrc[kv_seq_len] = sink; | ||||
|         float rowMax = sink; | ||||
|         for (int j = 0; j < kv_seq_len; j++) { | ||||
|             rowMax = ALIMAX(rowMax, sinkSrc[j]); | ||||
|         } | ||||
|         for (int j = 0; j < kv_seq_len + 1; j++) { | ||||
|             sinkSrc[j] = sinkSrc[j] - rowMax; | ||||
|         } | ||||
|         MNNSoftmax(sinkDst, sinkSrc, kv_seq_len + 1); | ||||
|         ::memcpy(softmax_qk_addr + i * kv_seq_len, sinkDst, kv_seq_len * sizeof(float)); | ||||
|     // not sliding attention
 | ||||
|     if (sinkPtr == nullptr) { | ||||
|         sffunc(softmax_qk_addr, unpack_qk_addr, runningMax, runningSum, diffScale, seq_len, kv_seq_len); | ||||
|         return; | ||||
|     } | ||||
|      | ||||
|     float sink = ((T*)sinkPtr)[headIdx]; | ||||
|     if (!runningMax && !runningSum) { // Do not use flash attention
 | ||||
|          | ||||
|         for (int i = 0; i < seq_len; ++i) { | ||||
|             float exprOffset[4] = {1, 0, -sink, 1.f}; | ||||
|             MNNExp(softmax_qk_addr + i * kv_seq_len, unpack_qk_addr + i * kv_seq_len, exprOffset, kv_seq_len); | ||||
|             for (int j = 0; j < kv_seq_len; ++j) { | ||||
|                 softmax_qk_addr[i * kv_seq_len + j] /= exprOffset[3]; | ||||
|             } | ||||
|         } | ||||
|         return; | ||||
|     } | ||||
| 
 | ||||
|     // Use flash attention    
 | ||||
|     if (isLastKvBlock) { | ||||
|         for (int i = 0; i < seq_len; ++i) { | ||||
|             runningSum[i] += expf(sink - runningMax[i]); | ||||
|         } | ||||
|     } | ||||
|     MNNSoftmax(softmax_qk_addr, unpack_qk_addr, runningMax, runningSum, diffScale, seq_len, kv_seq_len); | ||||
| } | ||||
| 
 | ||||
| template <typename T> | ||||
| static void unpack_QKV(char* pack_qkv, char* unpack_qkv, int mNumHead, int mHeadDim, int unit, int seq_len) { | ||||
| static void unpack_QKV(int8_t* pack_qkv, int8_t* unpack_qkv, int mNumHead, int mHeadDim, int mPack, int seq_len) { | ||||
|     auto src_ptr = reinterpret_cast<T*>(pack_qkv); | ||||
|     auto dst_ptr = reinterpret_cast<T*>(unpack_qkv); | ||||
|     for (int i = 0; i < seq_len; i++) { | ||||
|         for (int j = 0; j < mHeadDim; j++) { | ||||
|             int a = j / unit; | ||||
|             int b = j % unit; | ||||
|             dst_ptr[i * mNumHead * mHeadDim + j] = src_ptr[a * seq_len * unit + i * unit + b]; | ||||
|             int a = j / mPack; | ||||
|             int b = j % mPack; | ||||
|             dst_ptr[i * mNumHead * mHeadDim + j] = src_ptr[a * seq_len * mPack + i * mPack + b]; | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | @ -191,10 +211,10 @@ ErrorCode CPUAttention::onResize(const std::vector<Tensor*>& inputs, const std:: | |||
|     auto core = static_cast<CPUBackend *>(backend())->functions(); | ||||
|     core->MNNGetMatMulPackMode(&eP, &lP, &hP); | ||||
|     mThreadNum = ((CPUBackend *)backend())->threadNumber(); | ||||
|     unit  = core->pack; | ||||
|     mPack  = core->pack; | ||||
|     bytes = core->bytes; | ||||
|     int qkvQuantOptions = static_cast<CPUBackend *>(backend())->getRuntime()->hint().qkvQuantOption; | ||||
|     mUseGemmInt8 = (qkvQuantOptions == 4); | ||||
|     mUseGemmInt8 = (qkvQuantOptions % 8 == 4); | ||||
|     if (mUseGemmInt8) { | ||||
|         static_cast<CPUBackend*>(backend())->int8Functions()->MNNGetGemmUnit(&hP8, &lP8, &eP8); | ||||
|     } | ||||
|  | @ -208,7 +228,7 @@ ErrorCode CPUAttention::onResize(const std::vector<Tensor*>& inputs, const std:: | |||
|     if (mUseGemmInt8) { | ||||
|         mPackQ.reset(Tensor::createDevice<int8_t>({mThreadNum, UP_DIV(seq_len, eP8), UP_DIV(mHeadDim, lP8), eP8 * lP8})); | ||||
|         mSumQ.reset(Tensor::createDevice<int32_t>({mThreadNum, UP_DIV(seq_len, eP8), eP8})); | ||||
|         mPackQKV.reset(Tensor::createDevice<float>({mThreadNum, UP_DIV(mHeadDim, unit), seq_len, unit})); | ||||
|         mPackQKV.reset(Tensor::createDevice<float>({mThreadNum, UP_DIV(mHeadDim, mPack), seq_len, mPack})); | ||||
|         backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC); | ||||
|         backend()->onAcquireBuffer(mSumQ.get(), Backend::DYNAMIC); | ||||
|         backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC); | ||||
|  | @ -220,18 +240,40 @@ ErrorCode CPUAttention::onResize(const std::vector<Tensor*>& inputs, const std:: | |||
|         mQueryScale.resize(mNumHead); | ||||
|         mQueryZeroPoint.resize(mNumHead); | ||||
|     } else { | ||||
|         mPackQ.reset(Tensor::createDevice<float>({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(mHeadDim, lP), eP})); | ||||
|         mPackQKV.reset(Tensor::createDevice<float>({mThreadNum, UP_DIV(mHeadDim, unit), seq_len, unit})); | ||||
|         mPackQ.reset(Tensor::createDevice<int8_t>({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(mHeadDim, lP), eP * bytes})); | ||||
|         mPackQKV.reset(Tensor::createDevice<int8_t>({mThreadNum, UP_DIV(mHeadDim, mPack), seq_len, mPack * bytes})); | ||||
|         backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC); | ||||
|         backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC); | ||||
|          | ||||
|         // flash attention
 | ||||
|         if (qkvQuantOptions / 8 == 1) { | ||||
|             mRunningMax.reset(Tensor::createDevice<int8_t>({mThreadNum, seq_len * 4})); | ||||
|             mRunningSum.reset(Tensor::createDevice<int8_t>({mThreadNum, seq_len * 4})); | ||||
|             mExpfDiffMax.reset(Tensor::createDevice<int8_t>({mThreadNum, seq_len * 4})); | ||||
|             mTempOut.reset(Tensor::createDevice<int8_t>({mThreadNum, UP_DIV(mHeadDim, mPack), seq_len, mPack * bytes})); | ||||
| 
 | ||||
|             backend()->onAcquireBuffer(mRunningMax.get(), Backend::DYNAMIC); | ||||
|             backend()->onAcquireBuffer(mRunningSum.get(), Backend::DYNAMIC); | ||||
|             backend()->onAcquireBuffer(mExpfDiffMax.get(), Backend::DYNAMIC); | ||||
|             backend()->onAcquireBuffer(mTempOut.get(), Backend::DYNAMIC); | ||||
|         } | ||||
| 
 | ||||
|         backend()->onReleaseBuffer(mPackQ.get(), Backend::DYNAMIC); | ||||
|         backend()->onReleaseBuffer(mPackQKV.get(), Backend::DYNAMIC); | ||||
| 
 | ||||
|         if (qkvQuantOptions / 8 == 1) { | ||||
|             backend()->onReleaseBuffer(mRunningMax.get(), Backend::DYNAMIC); | ||||
|             backend()->onReleaseBuffer(mRunningSum.get(), Backend::DYNAMIC); | ||||
|             backend()->onReleaseBuffer(mExpfDiffMax.get(), Backend::DYNAMIC); | ||||
|             backend()->onReleaseBuffer(mTempOut.get(), Backend::DYNAMIC); | ||||
|         } | ||||
|     } | ||||
|     return NO_ERROR; | ||||
| } | ||||
| 
 | ||||
| ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) { | ||||
|     auto core  = static_cast<CPUBackend *>(backend())->functions(); | ||||
|     auto qkvQuantOptions = static_cast<CPUBackend *>(backend())->getRuntime()->hint().qkvQuantOption; | ||||
|     auto query = inputs[0]; | ||||
|     auto key   = inputs[1]; | ||||
|     auto value = inputs[2]; | ||||
|  | @ -283,146 +325,146 @@ ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std: | |||
|     int max_len = mKVCacheManager->maxLength(); | ||||
|     bool quant_key = mKVCacheManager->config()->mQuantKey; | ||||
|     bool quant_value = mKVCacheManager->config()->mQuantValue; | ||||
| 
 | ||||
|     mBlockKV = (qkvQuantOptions / 8 == 1) ? ALIMIN(MNN_FLASH_ATTENTION_BLOCK_SIZE, kv_seq_len) : kv_seq_len; | ||||
|     int32_t units[2] = {eP, lP}; | ||||
| 
 | ||||
|     // Temporary tensors for intermediate results
 | ||||
|     std::shared_ptr<Tensor> packQK(Tensor::createDevice<float>({mThreadNum, UP_DIV(kv_seq_len, unit), seq_len, unit})); | ||||
|     std::shared_ptr<Tensor> unpackQK(Tensor::createDevice<int32_t>({mThreadNum, seq_len, kv_seq_len})); | ||||
|     std::shared_ptr<Tensor> softmMaxQ(Tensor::createDevice<int32_t>({mThreadNum, seq_len, kv_seq_len})); | ||||
|     std::shared_ptr<Tensor> newPackQK(Tensor::createDevice<float>({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(kv_seq_len, lP), eP})); | ||||
|     std::shared_ptr<Tensor> dequantV(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), kv_seq_len, hP})); | ||||
|     backend()->onAcquireBuffer(packQK.get(), Backend::STATIC); | ||||
|     std::shared_ptr<Tensor> unpackQK(Tensor::createDevice<int32_t>({mThreadNum, seq_len, mBlockKV})); | ||||
|     std::shared_ptr<Tensor> softmMaxQ(Tensor::createDevice<int32_t>({mThreadNum, seq_len, mBlockKV})); | ||||
|     std::shared_ptr<Tensor> newPackQK(Tensor::createDevice<int8_t>({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(mBlockKV, lP), eP * bytes})); | ||||
|     std::shared_ptr<Tensor> dequantV(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(mHeadDim, hP), kv_seq_len, hP * bytes})); | ||||
|     // mTempQKBlock.reset(Tensor::createDevice<int8_t>({mThreadNum, UP_DIV(mBlockKV, mPack), seq_len, mPack * bytes}));
 | ||||
|     std::shared_ptr<Tensor> tempQKBlock(Tensor::createDevice<int8_t>({mThreadNum, UP_DIV(mBlockKV, mPack), seq_len, mPack * bytes})); | ||||
|     backend()->onAcquireBuffer(unpackQK.get(), Backend::STATIC); | ||||
|     backend()->onAcquireBuffer(softmMaxQ.get(), Backend::STATIC); | ||||
|     backend()->onAcquireBuffer(newPackQK.get(), Backend::STATIC); | ||||
|     backend()->onAcquireBuffer(tempQKBlock.get(), Backend::STATIC); | ||||
|     if (quant_value) { | ||||
|         backend()->onAcquireBuffer(dequantV.get(), Backend::STATIC); | ||||
|         mKVCacheManager->onDequantValue(dequantV.get()); | ||||
|     } | ||||
|     const float* sinksPtr = sinks ? sinks->host<float>() : nullptr; | ||||
|     std::function<void(int)> mCompute = [=](int tId) { | ||||
|         auto pack_q      = mPackQ->host<char>() + tId * UP_DIV(seq_len, eP) * ROUND_UP(mHeadDim, lP) * eP * bytes; | ||||
|         auto pack_qk     = packQK->host<char>() + tId * UP_DIV(kv_seq_len, unit) * seq_len * unit * bytes; | ||||
|         char * sum_q     = nullptr; | ||||
|         auto unpack_qk   = unpackQK->host<float>() + tId * seq_len * kv_seq_len; | ||||
|         auto softmax_qk  = softmMaxQ->host<float>() + tId * seq_len * kv_seq_len; | ||||
|         auto new_pack_qk = newPackQK->host<char>() + tId * UP_DIV(seq_len, eP) * ROUND_UP(kv_seq_len, lP) * eP * bytes; | ||||
|         auto pack_qkv    = mPackQKV->host<char>() + tId * UP_DIV(mHeadDim, unit) * seq_len * unit * bytes; | ||||
|         auto qReordered      = mPackQ->host<int8_t>() + tId * mPackQ->stride(0); | ||||
|         auto qkPacked     = tempQKBlock->host<int8_t>() + tId * tempQKBlock->stride(0); | ||||
|         int8_t * sum_q     = nullptr; | ||||
|         auto qkFlatten   = unpackQK->host<float>() + tId * unpackQK->stride(0); | ||||
|         auto qkSoftmax  = softmMaxQ->host<float>() + tId * softmMaxQ->stride(0); | ||||
|         auto qkReordered = newPackQK->host<int8_t>() + tId * newPackQK->stride(0); | ||||
|         auto qkvPacked    = mPackQKV->host<int8_t>() + tId * mPackQKV->stride(0); | ||||
|         auto QxK         = quant_key ? core->MNNPackedMatMul_int8 : core->MNNPackedMatMul; | ||||
|         auto QxK_remain  = quant_key ? core->MNNPackedMatMulRemain_int8 : core->MNNPackedMatMulRemain; | ||||
| 
 | ||||
|         // Flash Attention
 | ||||
|         auto runningMax = mRunningMax ? (float*)(mRunningMax->host<int8_t>() + tId * mRunningMax->stride(0)) : nullptr; | ||||
|         auto runningSum = mRunningSum ? (float*)(mRunningSum->host<int8_t>() + tId * mRunningSum->stride(0)) : nullptr; | ||||
|         auto diffScale = mExpfDiffMax ? (float*)(mExpfDiffMax->host<int8_t>() + tId * mExpfDiffMax->stride(0)) : nullptr; | ||||
|         auto outputPacked = mTempOut ? mTempOut->host<int8_t>() + tId * mTempOut->stride(0) : qkvPacked; | ||||
|         int  head_index  = tId * tileCount; | ||||
|         int  kvBlocks = UP_DIV(kv_seq_len, mBlockKV); | ||||
| 
 | ||||
|         if (mUseGemmInt8) { | ||||
|             pack_q  = mPackQ->host<char>() + tId * UP_DIV(seq_len, eP8) * UP_DIV(mHeadDim, lP8) * eP8 * lP8; | ||||
|             sum_q   = mSumQ->host<char>() + tId * UP_DIV(seq_len, eP8) * eP8 * 4; | ||||
|             qReordered  = mPackQ->host<int8_t>() + tId * UP_DIV(seq_len, eP8) * UP_DIV(mHeadDim, lP8) * eP8 * lP8; | ||||
|             sum_q   = mSumQ->host<int8_t>() + tId * UP_DIV(seq_len, eP8) * eP8 * 4; | ||||
|         } | ||||
|         for (int h = head_index; h < head_index + tileCount && h < mNumHead; h++) { | ||||
|             int    kv_h            = h / group_size; | ||||
|             char * key_addr        = mKVCacheManager->addrOfKey(kv_h); | ||||
|             char * scale_addr      = mKVCacheManager->addrOfScale(kv_h); | ||||
|             char * zero_point_addr = mKVCacheManager->addrOfZeroPoint(kv_h); | ||||
|             char * key_sum_addr    = mKVCacheManager->addrOfKeySum(kv_h); | ||||
|             char * value_addr      = quant_value ? (dequantV->host<char>() + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(kv_seq_len, lP) * hP * bytes) : mKVCacheManager->addrOfValue(kv_h); | ||||
|             if (bytes == 2) { | ||||
|                 pack_query<FLOAT16_T>(query, pack_q, sum_q, seq_len, h, q_scale); | ||||
|             } else { | ||||
|                 pack_query<float>(query, pack_q, sum_q, seq_len, h, q_scale); | ||||
|             if (runningSum && runningMax) { | ||||
|                 memset(runningSum, 0, mRunningSum->stride(0)); | ||||
|                 if (sinksPtr == nullptr) { | ||||
|                     for (int k = 0; k < seq_len; ++k) { | ||||
|                         runningMax[k] = -std::numeric_limits<float>::infinity(); | ||||
|                     } | ||||
|                 } else { | ||||
|                     float sinkVal; | ||||
|                     if (bytes == 2) { | ||||
|                         sinkVal = ((FLOAT16_T*)sinksPtr)[h]; | ||||
|                     } else { | ||||
|                         sinkVal =sinksPtr[h]; | ||||
|                     } | ||||
|                     for (int k = 0; k < seq_len; ++k) { | ||||
|                         runningMax[k] = sinkVal; | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|             // query @ key
 | ||||
|             int    kv_h              = h / group_size; | ||||
|             int8_t * key_addr        = mKVCacheManager->addrOfKey(kv_h); | ||||
|             int8_t * scale_addr      = mKVCacheManager->addrOfScale(kv_h); | ||||
|             int8_t * zero_point_addr = mKVCacheManager->addrOfZeroPoint(kv_h); | ||||
|             int8_t * key_sum_addr    = mKVCacheManager->addrOfKeySum(kv_h); | ||||
|             int8_t * value_addr      = quant_value ? (dequantV->host<int8_t>() + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(kv_seq_len, lP) * hP * bytes) : mKVCacheManager->addrOfValue(kv_h); | ||||
|             if (mUseGemmInt8) { | ||||
|                 auto GemmInt8Kernel = static_cast<CPUBackend*>(backend())->int8Functions()->Int8GemmKernel; | ||||
|                 if (bytes == 2 && unit == 8) { | ||||
|                     GemmInt8Kernel = static_cast<CPUBackend*>(backend())->int8Functions()->MNNGemmInt8AddBiasScale_Unit_FP16; | ||||
|                 } | ||||
|                 std::vector<float> postScale(ROUND_UP(kv_seq_len, hP8), 0.0f); | ||||
|                 for (int i = 0; i < kv_seq_len; i++) { | ||||
|                     postScale[i] = ((float*)scale_addr)[i] * mQueryScale[h] * q_scale; | ||||
|                 } | ||||
|                 std::vector<float> weightQuantBias(ROUND_UP(kv_seq_len, hP8), 0.0f); | ||||
|                 for (int i = 0; i < kv_seq_len; i++) { | ||||
|                     weightQuantBias[i] = -((float*)scale_addr)[i] * ((float*)zero_point_addr)[i] * q_scale; | ||||
|                 } | ||||
|                 std::vector<float> biasFloat(ROUND_UP(kv_seq_len, hP8), 0.0f); | ||||
|                 for (int i = 0; i < kv_seq_len; i++) { | ||||
|                     biasFloat[i] = -mQueryScale[h] * mQueryZeroPoint[h] * ((float*)key_sum_addr)[i] * q_scale; | ||||
|                 } | ||||
|                 QuanPostTreatParameters post; | ||||
|                 post.bias = nullptr; | ||||
|                 post.biasFloat = biasFloat.data(); | ||||
|                 post.blockNum = 1; | ||||
|                 post.inputBias = nullptr; | ||||
|                 post.inputScale = nullptr; | ||||
|                 post.fp32minmax = nullptr; | ||||
|                 post.scale = postScale.data(); | ||||
|                 post.useInt8 = false; | ||||
|                 post.weightKernelSum = weightQuantBias.data(); | ||||
|                 int N = UP_DIV(seq_len, eP8); | ||||
|                 for (int i = 0; i < N; i++) { | ||||
|                     int realcount = ALIMIN(eP8, seq_len - i * eP8); | ||||
|                     post.srcKernelSum = (float*)((char*)sum_q + i * eP8 * 4); | ||||
|                     GemmInt8Kernel( | ||||
|                         (int8_t*)pack_qk + i * eP8 * unit * bytes, | ||||
|                         (int8_t*)pack_q + i * ROUND_UP(mHeadDim, lP8) * eP8, | ||||
|                         (int8_t*)key_addr, | ||||
|                         UP_DIV(mHeadDim, lP8), | ||||
|                         seq_len * unit * bytes, | ||||
|                         UP_DIV(kv_seq_len, unit), | ||||
|                         &post, | ||||
|                         realcount | ||||
|                     ); | ||||
|                 if (bytes == 2) { | ||||
|                     pack_query<FLOAT16_T>(query, qReordered, sum_q, seq_len, h, q_scale); | ||||
|                 } else { | ||||
|                     pack_query<float>(query, qReordered, sum_q, seq_len, h, q_scale); | ||||
|                 } | ||||
|             } else { | ||||
|                 core->MNNAttenPackAndScaleSingleHead((float*)qReordered, (float*)(query->host<int8_t>() + h * mHeadDim * bytes), mHeadDim * mNumHead, &q_scale, units, seq_len, mHeadDim); | ||||
|             } | ||||
|             else { | ||||
|             for (int i = 0; i < kvBlocks; ++i) { | ||||
|                 int subKvSeqLen = ALIMIN(mBlockKV, kv_seq_len - i * mBlockKV); | ||||
|                 auto keyPtr = key_addr + i * UP_DIV(mBlockKV, hP) * ROUND_UP(mHeadDim, lP) * hP * bytes; | ||||
|                 auto valuePtr = value_addr + i * UP_DIV(mBlockKV, lP) * hP * lP * bytes; | ||||
|                 // query @ key
 | ||||
|                 { | ||||
|                     int loop_e = seq_len / eP; | ||||
|                     int remain = seq_len % eP; | ||||
|                     auto qStride0 = ROUND_UP(mHeadDim, lP) * eP * bytes; | ||||
|                     size_t shapeParameters[7] = {(size_t)eP * bytes, ROUND_UP((size_t)mHeadDim, lP), (size_t)subKvSeqLen, (size_t)seq_len * mPack * bytes, 0, 0, 0}; | ||||
|                     for (int ei = 0 ; ei < loop_e; ei++) { | ||||
|                         QxK((float*)(qkPacked + (ei * eP * mPack) * bytes), (float*)(qReordered + ei * qStride0), (float*)keyPtr, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr); | ||||
|                     } | ||||
|                     QxK_remain((float*)(qkPacked + (loop_e * eP * mPack) * bytes), (float*)(qReordered + loop_e * qStride0), (float*)keyPtr, remain, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr); | ||||
|                 } | ||||
|                 // qk: [kv_seq_len/mPack, seq_len, mPack] -> [seq_len/eP, kv_seq_len, eP]
 | ||||
|                 { | ||||
|                     if(bytes == 2) { | ||||
|                         if (seq_len == 1) { | ||||
|                             core->MNNLowpToFp32((int16_t*)qkPacked, qkFlatten, seq_len * subKvSeqLen); | ||||
|                         } else { | ||||
|                             core->MNNAttenUnpackAndConvertFp16(qkFlatten, (float*)qkPacked, subKvSeqLen, seq_len, mPack); | ||||
|                         } | ||||
|                         mask_QK<FLOAT16_T>(qkFlatten, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask, kv_seq_len - seq_len, i * mBlockKV, subKvSeqLen); | ||||
|                         softmaxQK<FLOAT16_T>(qkSoftmax, qkFlatten, runningMax, runningSum, diffScale, sinksPtr, core->MNNSoftmax, seq_len, subKvSeqLen, h, i == kvBlocks - 1); | ||||
|                         core->MNNAttenPackAndConvertFp32((float*)qkReordered, qkSoftmax, units, seq_len, subKvSeqLen); | ||||
|                     } else { | ||||
|                         if (seq_len > 1) { | ||||
|                             int32_t areaOffset[2] = {seq_len, seq_len}; | ||||
|                             core->MNNUnpackCUnitTranspose(qkFlatten, (float*)qkPacked, seq_len, subKvSeqLen, areaOffset); | ||||
|                         } else { | ||||
|                             memcpy(qkFlatten, qkPacked, subKvSeqLen * sizeof(float)); | ||||
|                         } | ||||
|                         mask_QK<float>(qkFlatten, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask, kv_seq_len - seq_len, i * mBlockKV, subKvSeqLen); | ||||
|                         softmaxQK<float>(qkSoftmax, qkFlatten, runningMax, runningSum, diffScale, sinksPtr, core->MNNSoftmax, seq_len, subKvSeqLen, h, i == kvBlocks - 1); | ||||
|                         packKvCache((float*)qkReordered, qkSoftmax, seq_len, subKvSeqLen, eP); | ||||
|                     } | ||||
|                 } | ||||
|                 // qk @ v
 | ||||
|                 // TODO: update qkvPacked using diffScale
 | ||||
|                 size_t shapeParameters[7] = {(size_t)eP * bytes, ROUND_UP((size_t)subKvSeqLen, lP), (size_t)mHeadDim, (size_t)seq_len * mPack * bytes, 0, 0, 0}; | ||||
|                 size_t bExtraStride = (UP_DIV(max_len, lP) - UP_DIV(subKvSeqLen + i * mBlockKV, lP) + UP_DIV(i * mBlockKV, lP)) * hP * lP * bytes; | ||||
|                 shapeParameters[5] = quant_value ? 0 : bExtraStride; | ||||
|                 int loop_e = seq_len / eP; | ||||
|                 int remain = seq_len % eP; | ||||
|                 auto qStride0 = ROUND_UP(mHeadDim, lP) * eP * bytes; | ||||
|                 size_t shapeParameters[7] = {(size_t)eP * bytes, ROUND_UP((size_t)mHeadDim, lP), (size_t)kv_seq_len, (size_t)seq_len * unit * bytes, 0, 0, 0}; | ||||
|                 for (int i = 0 ; i < loop_e; i++) { | ||||
|                     QxK((float*)(pack_qk + (i * eP * unit) * bytes), (float*)(pack_q + i * qStride0), (float*)key_addr, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr); | ||||
|                 auto qkStride0 = ROUND_UP(subKvSeqLen, lP) * eP * bytes; | ||||
|                 for (int ei = 0 ; ei < loop_e; ei++) { | ||||
|                     core->MNNPackedMatMul((float*)(qkvPacked + (ei * eP * mPack) * bytes), (float*)(qkReordered + ei * qkStride0), (float*)valuePtr, shapeParameters, nullptr, nullptr, nullptr, nullptr); | ||||
|                 } | ||||
|                 QxK_remain((float*)(pack_qk + (loop_e * eP * unit) * bytes), (float*)(pack_q + loop_e * qStride0), (float*)key_addr, remain, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr); | ||||
|             } | ||||
|             // qk: [kv_seq_len/unit, seq_len, unit] -> [seq_len, kv_seq_len] -> [seq_len/eP, kv_seq_len, eP]
 | ||||
|             if (sinksPtr != nullptr) { | ||||
|                 if(bytes == 2) { | ||||
|                     unpack_QK<FLOAT16_T>(unpack_qk, pack_qk, seq_len, kv_seq_len); | ||||
|                     mask_QK<FLOAT16_T>(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask); | ||||
|                     sink_softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len, sinksPtr[h]); | ||||
|                     pack_QK<FLOAT16_T>(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP, lP, bytes); | ||||
|                 } else { | ||||
|                     unpack_QK<float>(unpack_qk, pack_qk, seq_len, kv_seq_len); | ||||
|                     mask_QK<float>(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask); | ||||
|                     sink_softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len, sinksPtr[h]); | ||||
|                     pack_QK<float>(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP, lP, bytes); | ||||
|                 } | ||||
|             } else { | ||||
|                 if(bytes == 2) { | ||||
|                     unpack_QK<FLOAT16_T>(unpack_qk, pack_qk, seq_len, kv_seq_len); | ||||
|                     mask_QK<FLOAT16_T>(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask); | ||||
|                     softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len); | ||||
|                     pack_QK<FLOAT16_T>(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP, lP, bytes); | ||||
|                 } else { | ||||
|                     unpack_QK<float>(unpack_qk, pack_qk, seq_len, kv_seq_len); | ||||
|                     mask_QK<float>(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask); | ||||
|                     softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len); | ||||
|                     pack_QK<float>(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP, lP, bytes); | ||||
|                 core->MNNPackedMatMulRemain((float*)(qkvPacked + (loop_e * eP * mPack) * bytes), (float*)(qkReordered + loop_e * qkStride0), (float*)valuePtr, remain, shapeParameters, nullptr, nullptr, nullptr, nullptr); | ||||
| 
 | ||||
|                 if (runningMax != nullptr && runningSum != nullptr && diffScale != nullptr) { | ||||
|                     core->MNNFlashAttentionUpdateBlockOutput((float*)outputPacked, (float*)qkvPacked, diffScale, runningSum, UP_DIV(mHeadDim, mPack), seq_len, mPack, i, kvBlocks, mPackQKV->stride(0) / bytes, bytes); | ||||
|                 } | ||||
|             } | ||||
|             // qk @ v
 | ||||
|             size_t shapeParameters[7] = {(size_t)eP * bytes, ROUND_UP((size_t)kv_seq_len, lP), (size_t)mHeadDim, (size_t)seq_len * unit * bytes, 0, 0, 0}; | ||||
|             size_t bExtraStride = (UP_DIV(max_len, lP) - UP_DIV(kv_seq_len, lP)) * hP * lP * bytes; | ||||
|             shapeParameters[5] = quant_value ? 0 : bExtraStride; | ||||
|             int loop_e = seq_len / eP; | ||||
|             int remain = seq_len % eP; | ||||
|             auto qkStride0 = ROUND_UP(kv_seq_len, lP) * eP * bytes; | ||||
|             for (int i = 0 ; i < loop_e; i++) { | ||||
|                 core->MNNPackedMatMul((float*)(pack_qkv + (i * eP * unit) * bytes), (float*)(new_pack_qk + i * qkStride0), (float*)value_addr, shapeParameters, nullptr, nullptr, nullptr, nullptr); | ||||
|             } | ||||
|             core->MNNPackedMatMulRemain((float*)(pack_qkv + (loop_e * eP * unit) * bytes), (float*)(new_pack_qk + loop_e * qkStride0), (float*)value_addr, remain, shapeParameters, nullptr, nullptr, nullptr, nullptr); | ||||
|             // unpack: [head_dim/unit, seq_len, unit] -> [seq_len, num_head, head_dim]
 | ||||
|             auto dst_ptr = outputs[0]->host<char>() + h * mHeadDim * bytes; | ||||
|             // unpack: [head_dim/mPack, seq_len, mPack] -> [seq_len, num_head, head_dim]
 | ||||
|             auto dst_ptr = outputs[0]->host<int8_t>() + h * mHeadDim * bytes; | ||||
|             if (bytes == 2) { | ||||
|                 unpack_QKV<int16_t>(pack_qkv, dst_ptr, mNumHead, mHeadDim, unit, seq_len); | ||||
|                 unpack_QKV<int16_t>((int8_t*)outputPacked, dst_ptr, mNumHead, mHeadDim, mPack, seq_len); | ||||
|             } else { | ||||
|                 unpack_QKV<float>(pack_qkv, dst_ptr, mNumHead, mHeadDim, unit, seq_len); | ||||
|                 unpack_QKV<float>((int8_t*)outputPacked, dst_ptr, mNumHead, mHeadDim, mPack, seq_len); | ||||
|             } | ||||
|              | ||||
|         } | ||||
|     }; | ||||
| 
 | ||||
|  | @ -431,10 +473,10 @@ ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std: | |||
|     } | ||||
|     MNN_CONCURRENCY_END(); | ||||
| 
 | ||||
|     backend()->onReleaseBuffer(packQK.get(), Backend::STATIC); | ||||
|     backend()->onReleaseBuffer(unpackQK.get(), Backend::STATIC); | ||||
|     backend()->onReleaseBuffer(softmMaxQ.get(), Backend::STATIC); | ||||
|     backend()->onReleaseBuffer(newPackQK.get(), Backend::STATIC); | ||||
|     backend()->onReleaseBuffer(tempQKBlock.get(), Backend::STATIC); | ||||
|     if (quant_value){ | ||||
|         backend()->onReleaseBuffer(dequantV.get(), Backend::STATIC); | ||||
|     } | ||||
|  | @ -460,9 +502,19 @@ CPUAttention::CPUAttention(Backend *backend, bool kv_cache) : Execution(backend) | |||
|     mPackQKV.reset(Tensor::createDevice<float>({1, 1, 1, 1})); | ||||
|     MNN::KVCacheManager::KVCacheConfig kvconfig; | ||||
|     int qkvQuantOptions = static_cast<CPUBackend *>(backend)->getRuntime()->hint().qkvQuantOption; | ||||
|     kvconfig.mUseInt8Kernel = (qkvQuantOptions == 4); | ||||
|     kvconfig.mQuantKey   = (qkvQuantOptions == 4) || (qkvQuantOptions & 1); | ||||
|     kvconfig.mQuantValue = (qkvQuantOptions == 4) || ((qkvQuantOptions >> 1) & 1); | ||||
|     kvconfig.mUseInt8Kernel = (qkvQuantOptions % 8 == 4); | ||||
| 
 | ||||
|     // qkvQuantOption % 8:
 | ||||
|     // 0: Do not quantize
 | ||||
|     // 1: Only quantize key, use int8 asymmetric quantization
 | ||||
|     // 2: Only quantize value, use fp8 quantization
 | ||||
|     // 3: quantize both key and value
 | ||||
|     // 4: quantize query, key and value, and use gemm int8 kernel to compute K*V
 | ||||
| 
 | ||||
|     // qkvQuantOption / 8:
 | ||||
|     // 1: use flash attention
 | ||||
|     kvconfig.mQuantKey   = (qkvQuantOptions % 8 == 4) || (qkvQuantOptions % 8 == 1) || (qkvQuantOptions % 8 == 3); | ||||
|     kvconfig.mQuantValue = (qkvQuantOptions % 8 == 4) || (qkvQuantOptions % 8 == 2); | ||||
|     kvconfig.mKVCacheDir = static_cast<CPUBackend *>(backend)->getRuntime()->hint().kvcacheDirPath; | ||||
|     kvconfig.mKVCacheSizeLimit = static_cast<CPUBackend *>(backend)->getRuntime()->hint().kvcacheSizeLimit; | ||||
|     kvconfig.mExpandChunk = 64; | ||||
|  |  | |||
|  | @ -30,15 +30,16 @@ private: | |||
|     bool mKVCache        = true; | ||||
|     bool mUseGemmInt8    = false; | ||||
|     int bytes = 4; | ||||
|     int mThreadNum = 1;; | ||||
|     int eP, lP, hP, unit; // float matmul packing
 | ||||
|     int mThreadNum = 1; | ||||
|     int mBlockKV = 512; | ||||
|     int eP, lP, hP, mPack; // float matmul packing
 | ||||
|     int eP8, lP8, hP8;    // GemmInt8 packing
 | ||||
|     int mNumHead, mKvNumHead, mHeadDim; | ||||
|     std::shared_ptr<Tensor> mPackQ, mPackQKV, mSumQ; | ||||
|     std::shared_ptr<Tensor> mPackQ, mPackQKV, mSumQ, mRunningMax, mRunningSum, mTempQKBlock, mTempOut, mExpfDiffMax; | ||||
|     std::shared_ptr<KVCacheManager> mKVCacheManager = nullptr; | ||||
|     std::vector<float> mMinQ, mMaxQ, mQueryScale, mQueryZeroPoint; | ||||
|     template <typename T> void pack_query(Tensor* query, char* pack_q, char* sum_q, int seq_len, int h, float q_scale); | ||||
|     template <typename T> void unpack_QK(float * unpack_qk_dst, char * pack_qk_src, int seq_len, int kv_seq_len); | ||||
|     template <typename T> void pack_query(Tensor* query, int8_t* pack_q, int8_t* sum_q, int seq_len, int h, float q_scale); | ||||
|     template <typename T> void unpack_QK(float * unpack_qk_dst, int8_t * pack_qk_src, int seq_len, int kv_seq_len); | ||||
|     KVMeta* mMeta; | ||||
| }; | ||||
| 
 | ||||
|  |  | |||
|  | @ -509,21 +509,21 @@ CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode p | |||
|         mDmaInfo.reset(new CPURuntime::DynamicAllocator); | ||||
|         mDmaInfo->mDynamicAllocator.reset(mRuntime->createDynamicBufferAlloctor(0)); | ||||
|         mDmaInfo->mCurrentDynamicAllocator = mDmaInfo->mDynamicAllocator.get(); | ||||
|         mDmaInfo->mCacheGroup.resize(MNN_CPU_MAX_BUFFER_INDEX); | ||||
|         for (int i=0; i<mDmaInfo->mCacheGroup.size(); ++i) { | ||||
|             mDmaInfo->mCacheGroup[i].reset(new CPUResizeCache); | ||||
|         } | ||||
|     } else { | ||||
|         mDmaInfo = dynamicAlloc; | ||||
|     } | ||||
|     mPrecisionMode = precision; | ||||
|     mCoreFunctions = MNNGetCoreFunctions(); | ||||
|     mInt8CoreFunctions = MNNGetInt8CoreFunctions(); | ||||
|     mCacheGroup.resize(MNN_CPU_MAX_BUFFER_INDEX); | ||||
|     for (int i=0; i<mCacheGroup.size(); ++i) { | ||||
|         mCacheGroup[i].reset(new CPUResizeCache); | ||||
|     } | ||||
|     mCache = mCacheGroup[0].get(); | ||||
|     mCache = mDmaInfo->mCacheGroup[0].get(); | ||||
| } | ||||
| 
 | ||||
| CPUBackend::~CPUBackend() { | ||||
|     mCacheGroup.clear(); | ||||
|     // Do nothing
 | ||||
| } | ||||
| void CPUBackend::_resetDynamicMemory() const { | ||||
|     mRuntime->pCurrentStatus = mDmaInfo->mDynamicAllocator->apply(); | ||||
|  | @ -560,7 +560,7 @@ bool CPUBackend::onSelectDynamicAllocator(int index, int maxIndex) { | |||
|         mRuntime->buffer(0)->release(); | ||||
|         mDmaInfo->mCurrentDynamicAllocator = mDmaInfo->mDynamicAllocator.get(); | ||||
|     } | ||||
|     mCache = mCacheGroup[index].get(); | ||||
|     mCache = mDmaInfo->mCacheGroup[index].get(); | ||||
|     return true; | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -23,12 +23,14 @@ | |||
| 
 | ||||
| namespace MNN { | ||||
| class WorkerThread; | ||||
| class CPUResizeCache; | ||||
| class CPURuntime : public Runtime { | ||||
| public: | ||||
|     struct DynamicAllocator { | ||||
|         std::shared_ptr<BufferAllocator> mDynamicAllocator; | ||||
|         std::shared_ptr<BufferAllocator> mDynamicAllocatorBackup; | ||||
|         BufferAllocator* mCurrentDynamicAllocator = nullptr; | ||||
|         std::vector<std::shared_ptr<CPUResizeCache>> mCacheGroup; | ||||
|     }; | ||||
|     friend class CPUBackend; | ||||
|     CPURuntime(const Backend::Info& info); | ||||
|  | @ -82,7 +84,6 @@ struct CoreFunctions; | |||
| struct CoreInt8Functions; | ||||
| struct MatmulRelatedFunctions; | ||||
| 
 | ||||
| class CPUResizeCache; | ||||
| class CPUMemObj : public Backend::MemObj { | ||||
| public: | ||||
|     CPUMemObj(BufferAllocator* allocator, MemChunk chunk, int size) : mAllocator(allocator), mChunk(chunk), mSize(size) {} | ||||
|  | @ -199,7 +200,6 @@ private: | |||
|     BackendConfig::MemoryMode mMemory; | ||||
|     static std::map<OpType, CPUBackend::Creator*>* gCreator; | ||||
|     CPUResizeCache* mCache; | ||||
|     std::vector<std::shared_ptr<CPUResizeCache>> mCacheGroup; | ||||
| }; | ||||
| /** execution cast wrapper. insert tensor cast dynamic. */ | ||||
| class CastWrapExecution : public Execution { | ||||
|  |  | |||
|  | @ -170,7 +170,7 @@ cpu_mask_t MNNGetCPUMask(const std::vector<int>& cpuIds) { | |||
| 
 | ||||
| // cpuinfo
 | ||||
| // Reference from: https://github.com/pytorch/cpuinfo
 | ||||
| #if defined(ENABLE_ARMV82) && defined(__arm__) | ||||
| #if (defined(ENABLE_ARMV82) && defined(__arm__)) || (defined(__ANDROID__) && defined(__aarch64__)) | ||||
| 
 | ||||
| /* As per include/sys/system_properties.h in Android NDK */ | ||||
| #define CPUINFO_HARDWARE_VALUE_MAX 64 | ||||
|  | @ -1180,7 +1180,7 @@ struct cpuinfo_arm_chipset cpuinfo_arm_android_decode_chipset(const struct cpuin | |||
|     // MNN_PRINT("chipset vendor, series, model is: %d, %d, %d\n", chipset.vendor, chipset.series, chipset.model);
 | ||||
|     return chipset; | ||||
| } | ||||
| static void _getInfoARMv7(MNNCPUInfo* cpuinfo_isa) { | ||||
| static void _getInfoArm(MNNCPUInfo* cpuinfo_isa) { | ||||
|     // Get White List And Black List
 | ||||
|     struct cpuinfo_arm_linux_processor* arm_linux_processors = NULL; | ||||
|     if (0 == cpuinfo_isa->groups.size()) { | ||||
|  | @ -1500,8 +1500,8 @@ static void _fillInfo(MNNCPUInfo* cpuinfo_isa) { | |||
| #if defined(__aarch64__) | ||||
|     _getInfoAux(cpuinfo_isa); | ||||
| #endif | ||||
| #if defined(ENABLE_ARMV82) && defined(__arm__) | ||||
|     _getInfoARMv7(cpuinfo_isa); | ||||
| #if (defined(ENABLE_ARMV82) && defined(__arm__)) || (defined(__ANDROID__) && defined(__aarch64__)) | ||||
|     _getInfoArm(cpuinfo_isa); | ||||
| #endif // #ifdef arm / arm64
 | ||||
| #endif // #ifdef __linux__
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -200,7 +200,7 @@ int CPUSoftmax::_softmaxCommon(const uint8_t *srcData, uint8_t *dstData) { | |||
|             for (int v=0; v<mInside; ++v) { | ||||
|                 //TODO: Fix x86 compute error and use the same function
 | ||||
| #ifdef MNN_USE_SSE | ||||
|                 MNNSoftmax(workDst+v*mChannel, workSrc+v*mChannel, mChannel); | ||||
|                 MNNSoftmax(workDst+v*mChannel, workSrc+v*mChannel, nullptr, nullptr, nullptr, 1, mChannel); | ||||
| #else | ||||
|                 ___MNNSoftmax(workDst+v*mChannel, workSrc+v*mChannel, mChannel, mulFunction); | ||||
| #endif | ||||
|  |  | |||
|  | @ -75,13 +75,13 @@ void KVCacheManager::resetKVCacheFileSize(size_t keySize, size_t valueSize) { | |||
| void KVCacheManager::mmapKVCache(size_t keySize, size_t valueSize) | ||||
| { | ||||
|     if (mMapKeyAddr == nullptr) { | ||||
|         mMapKeyAddr = (char *)MNNMmapFile(mKeyCacheFD, keySize); | ||||
|         mMapKeyAddr = (int8_t *)MNNMmapFile(mKeyCacheFD, keySize); | ||||
|         if (mMapKeyAddr == nullptr) { | ||||
|             MNN_PRINT("Failed to memory-map the kvcache!\n"); | ||||
|         } | ||||
|     } | ||||
|     if (mMapValueAddr == nullptr) { | ||||
|         mMapValueAddr = (char *)MNNMmapFile(mValueCacheFD, valueSize); | ||||
|         mMapValueAddr = (int8_t *)MNNMmapFile(mValueCacheFD, valueSize); | ||||
|         if (mMapValueAddr == nullptr) { | ||||
|             MNN_PRINT("Failed to memory-map the kvcache!\n"); | ||||
|         } | ||||
|  | @ -111,8 +111,8 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) { | |||
|         mBackend->onAcquireBuffer(new_key, Backend::STATIC); | ||||
|         for (int h = 0; h < mKvNumHead; h++) { | ||||
|             memcpy( | ||||
|                 new_key->host<char>() + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, | ||||
|                 mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, | ||||
|                 new_key->host<int8_t>() + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, | ||||
|                 mPastKey->host<int8_t>() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, | ||||
|                 UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8 | ||||
|             ); | ||||
|         } | ||||
|  | @ -123,8 +123,8 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) { | |||
|         mBackend->onAcquireBuffer(new_key, Backend::STATIC); | ||||
|         for (int h = 0; h < mKvNumHead; h++) { | ||||
|             memcpy( | ||||
|                 new_key->host<char>() + h * new_key->stride(0), | ||||
|                 mPastKey->host<char>() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP), | ||||
|                 new_key->host<int8_t>() + h * new_key->stride(0), | ||||
|                 mPastKey->host<int8_t>() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP), | ||||
|                 ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) | ||||
|             ); | ||||
|         } | ||||
|  | @ -135,12 +135,12 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) { | |||
|         mBackend->onAcquireBuffer(new_key, Backend::STATIC); | ||||
|         for (int h = 0; h < mKvNumHead; h++) { | ||||
|             memcpy( | ||||
|                 new_key->host<char>() + h * new_key->stride(0) * mBytes, | ||||
|                 mPastKey->host<char>() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes, | ||||
|                 new_key->host<int8_t>() + h * new_key->stride(0) * mBytes, | ||||
|                 mPastKey->host<int8_t>() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes, | ||||
|                 ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes | ||||
|             ); | ||||
|             if ((new_key->stride(0) - mPastKey->stride(0)) > 0) { | ||||
|                 memset(new_key->host<char>() + h * new_key->stride(0) * mBytes + mPastKey->stride(0) * mBytes, 0, (new_key->stride(0) - mPastKey->stride(0)) * mBytes); | ||||
|                 memset(new_key->host<int8_t>() + h * new_key->stride(0) * mBytes + mPastKey->stride(0) * mBytes, 0, (new_key->stride(0) - mPastKey->stride(0)) * mBytes); | ||||
|             } | ||||
|         } | ||||
|         mPastKey.reset(new_key); | ||||
|  | @ -152,8 +152,8 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) { | |||
|         for (int h = 0; h < mKvNumHead; h++) { | ||||
|             for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { | ||||
|                 memcpy( | ||||
|                     new_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP, | ||||
|                     mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP, | ||||
|                     new_value->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP, | ||||
|                     mPastValue->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP, | ||||
|                     ROUND_UP(oldMaxLength, lP) * hP | ||||
|                 ); | ||||
|             } | ||||
|  | @ -166,12 +166,12 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) { | |||
|         for (int h = 0; h < mKvNumHead; h++) { | ||||
|             for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { | ||||
|                 memcpy( | ||||
|                     new_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes, | ||||
|                     mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes, | ||||
|                     new_value->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes, | ||||
|                     mPastValue->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes, | ||||
|                     ROUND_UP(oldMaxLength, lP) * hP * mBytes | ||||
|                 ); | ||||
|                 if ((new_value->stride(1) - mPastValue->stride(1)) > 0) { | ||||
|                     memset(new_value->host<char>() + (h * new_value->stride(0) + i * new_value->stride(1)) * mBytes + mPastValue->stride(1) * mBytes, 0, (new_value->stride(1) - mPastValue->stride(1)) * mBytes); | ||||
|                     memset(new_value->host<int8_t>() + (h * new_value->stride(0) + i * new_value->stride(1)) * mBytes + mPastValue->stride(1) * mBytes, 0, (new_value->stride(1) - mPastValue->stride(1)) * mBytes); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|  | @ -189,7 +189,7 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) { | |||
|         for (int h = 0; h < mKvNumHead; h++) { | ||||
|             memcpy( | ||||
|                 mMapKeyAddr + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, | ||||
|                 mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, | ||||
|                 mPastKey->host<int8_t>() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, | ||||
|                 UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8 | ||||
|             ); | ||||
|         } | ||||
|  | @ -200,7 +200,7 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) { | |||
|         for (int h = 0; h < mKvNumHead; h++) { | ||||
|             memcpy( | ||||
|                 mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP, | ||||
|                 mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP, | ||||
|                 mPastKey->host<int8_t>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP, | ||||
|                 UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP | ||||
|             ); | ||||
|         } | ||||
|  | @ -214,7 +214,7 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) { | |||
|         for (int h = 0; h < mKvNumHead; h++) { | ||||
|             memcpy( | ||||
|                 mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes, | ||||
|                 mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes, | ||||
|                 mPastKey->host<int8_t>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes, | ||||
|                 UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes | ||||
|             ); | ||||
|         } | ||||
|  | @ -227,7 +227,7 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) { | |||
|             for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { | ||||
|                 memcpy( | ||||
|                     mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP, | ||||
|                     mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP, | ||||
|                     mPastValue->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP, | ||||
|                     ROUND_UP(oldMaxLength, lP) * hP | ||||
|                 ); | ||||
|             } | ||||
|  | @ -243,7 +243,7 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) { | |||
|             for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { | ||||
|                 memcpy( | ||||
|                     mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes, | ||||
|                     mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes, | ||||
|                     mPastValue->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes, | ||||
|                     ROUND_UP(oldMaxLength, lP) * hP * mBytes | ||||
|                 ); | ||||
|             } | ||||
|  | @ -282,8 +282,8 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o | |||
|         memset(old_value->host<uint8_t>(), 0, old_value->length(0) * old_value->stride(0) * mBytes); | ||||
|     } | ||||
|     mmapKVCache(oldKeySize, oldValueSize); | ||||
|     memcpy(old_key->host<char>(),   mMapKeyAddr,   oldKeySize); | ||||
|     memcpy(old_value->host<char>(), mMapValueAddr, oldValueSize); | ||||
|     memcpy(old_key->host<int8_t>(),   mMapKeyAddr,   oldKeySize); | ||||
|     memcpy(old_value->host<int8_t>(), mMapValueAddr, oldValueSize); | ||||
|     // Step 2: Resize the kvcache files and remap them
 | ||||
|     unmapKVCache(oldKeySize, oldValueSize); | ||||
|     resetKVCacheFileSize(keySize, valueSize); | ||||
|  | @ -293,7 +293,7 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o | |||
|         for (int h = 0; h < mKvNumHead; h++) { | ||||
|             memcpy( | ||||
|                 mMapKeyAddr + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, | ||||
|                 old_key->host<char>() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, | ||||
|                 old_key->host<int8_t>() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, | ||||
|                 UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8 | ||||
|             ); | ||||
|         } | ||||
|  | @ -301,7 +301,7 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o | |||
|         for (int h = 0; h < mKvNumHead; h++) { | ||||
|             memcpy( | ||||
|                 mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP, | ||||
|                 old_key->host<char>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP, | ||||
|                 old_key->host<int8_t>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP, | ||||
|                 UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP | ||||
|             ); | ||||
|         } | ||||
|  | @ -309,7 +309,7 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o | |||
|         for (int h = 0; h < mKvNumHead; h++) { | ||||
|             memcpy( | ||||
|                 mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes, | ||||
|                 old_key->host<char>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes, | ||||
|                 old_key->host<int8_t>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes, | ||||
|                 UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes | ||||
|             ); | ||||
|         } | ||||
|  | @ -319,7 +319,7 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o | |||
|             for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { | ||||
|                 memcpy( | ||||
|                     mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP, | ||||
|                     old_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP, | ||||
|                     old_value->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP, | ||||
|                     ROUND_UP(oldMaxLength, lP) * hP | ||||
|                 ); | ||||
|             } | ||||
|  | @ -329,7 +329,7 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o | |||
|             for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { | ||||
|                 memcpy( | ||||
|                     mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes, | ||||
|                     old_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes, | ||||
|                     old_value->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes, | ||||
|                     ROUND_UP(oldMaxLength, lP) * hP * mBytes | ||||
|                 ); | ||||
|             } | ||||
|  | @ -460,9 +460,9 @@ void KVCacheManager::onRealloc(const KVMeta* meta) { | |||
|             mBackend->onAcquireBuffer(new_zeroPoint, Backend::STATIC); | ||||
|             mBackend->onAcquireBuffer(new_sum, Backend::STATIC); | ||||
|             for (int h = 0; h < mKvNumHead; h++) { | ||||
|                 memcpy(new_scale->host<char>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyScale->host<char>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4); | ||||
|                 memcpy(new_zeroPoint->host<char>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyZeroPoint->host<char>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4); | ||||
|                 memcpy(new_sum->host<char>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeySum->host<char>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4); | ||||
|                 memcpy(new_scale->host<int8_t>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyScale->host<int8_t>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4); | ||||
|                 memcpy(new_zeroPoint->host<int8_t>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyZeroPoint->host<int8_t>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4); | ||||
|                 memcpy(new_sum->host<int8_t>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeySum->host<int8_t>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4); | ||||
|             } | ||||
|             mKeyScale.reset(new_scale); | ||||
|             mKeyZeroPoint.reset(new_zeroPoint); | ||||
|  | @ -473,8 +473,8 @@ void KVCacheManager::onRealloc(const KVMeta* meta) { | |||
|             mBackend->onAcquireBuffer(new_scale, Backend::STATIC); | ||||
|             mBackend->onAcquireBuffer(new_zeroPoint, Backend::STATIC); | ||||
|             for (int h = 0; h < mKvNumHead; h++) { | ||||
|                 memcpy(new_scale->host<char>() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyScale->host<char>() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes); | ||||
|                 memcpy(new_zeroPoint->host<char>() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyZeroPoint->host<char>() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes); | ||||
|                 memcpy(new_scale->host<int8_t>() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyScale->host<int8_t>() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes); | ||||
|                 memcpy(new_zeroPoint->host<int8_t>() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyZeroPoint->host<int8_t>() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes); | ||||
|             } | ||||
|             mKeyScale.reset(new_scale); | ||||
|             mKeyZeroPoint.reset(new_zeroPoint); | ||||
|  | @ -690,8 +690,8 @@ void KVCacheManager::onDequantValue(Tensor * dequantedValues) { | |||
|     int tileCount = UP_DIV(mKvNumHead, mThreadNum); | ||||
|     std::function<void(int)> dequant = [=](int tid) { | ||||
|         for (int kv_h = tid * tileCount; kv_h < (tid+1) * tileCount && kv_h < mKvNumHead; kv_h++) { | ||||
|             char * dst = dequantedValues->host<char>() + kv_h * UP_DIV(mHeadDim, hP) * mPastLength * hP * mBytes; | ||||
|             char * src = addrOfValue(kv_h); | ||||
|             int8_t * dst = dequantedValues->host<int8_t>() + kv_h * UP_DIV(mHeadDim, hP) * mPastLength * hP * mBytes; | ||||
|             int8_t * src = addrOfValue(kv_h); | ||||
|             for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { | ||||
|                 if (mBytes == 2) { | ||||
|                     core->MNNFp8ToFp16((uint16_t*)dst, (uint8_t*)src, mPastLength * hP); | ||||
|  |  | |||
|  | @ -46,8 +46,8 @@ private: | |||
|     std::shared_ptr<Tensor> mKeySum;                // numhead, [maxlen/hP8, hP8]
 | ||||
|     file_t mKeyCacheFD   = INVALID_FILE;            // The file descriptor of keys
 | ||||
|     file_t mValueCacheFD = INVALID_FILE;            // The file descriptor of values
 | ||||
|     char * mMapKeyAddr   = nullptr;                 // Memory-mapped address of keys
 | ||||
|     char * mMapValueAddr = nullptr;                 // Memory-mapped address of values
 | ||||
|     int8_t * mMapKeyAddr   = nullptr;                 // Memory-mapped address of keys
 | ||||
|     int8_t * mMapValueAddr = nullptr;                 // Memory-mapped address of values
 | ||||
|     bool mKVCacheInDisk  = false;                   // Whether the kvcache is in disk or in memory now
 | ||||
|     int  mPastLength     = 0;                       // Length of past kvcache
 | ||||
|     int  mMaxLength      = 0;                       // Capacity of current kvcache buffer (how many kv items can be stored at most)
 | ||||
|  | @ -104,15 +104,15 @@ public: | |||
|         return mMaxLength; | ||||
|     } | ||||
|     uint8_t* keyAddr() { | ||||
|         char * baseAddr = mKVCacheInDisk ? mMapKeyAddr : mPastKey->host<char>(); | ||||
|         int8_t * baseAddr = mKVCacheInDisk ? mMapKeyAddr : mPastKey->host<int8_t>(); | ||||
|         return (uint8_t*)baseAddr; | ||||
|     } | ||||
|     uint8_t* valudAddr() { | ||||
|         char * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host<char>(); | ||||
|         int8_t * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host<int8_t>(); | ||||
|         return (uint8_t*)baseAddr; | ||||
|     } | ||||
|     char * addrOfKey(int kv_h) { | ||||
|         char * baseAddr = mKVCacheInDisk ? mMapKeyAddr : mPastKey->host<char>(); | ||||
|     int8_t * addrOfKey(int kv_h) { | ||||
|         int8_t * baseAddr = mKVCacheInDisk ? mMapKeyAddr : mPastKey->host<int8_t>(); | ||||
|         if (mConfig.mUseInt8Kernel) { | ||||
|             return baseAddr + kv_h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8; | ||||
|         } else if (mConfig.mQuantKey) { | ||||
|  | @ -121,35 +121,35 @@ public: | |||
|             return baseAddr + kv_h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes; | ||||
|         } | ||||
|     } | ||||
|     char * addrOfValue(int kv_h) { | ||||
|         char * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host<char>(); | ||||
|     int8_t * addrOfValue(int kv_h) { | ||||
|         int8_t * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host<int8_t>(); | ||||
|         if (mConfig.mQuantValue) { | ||||
|             return baseAddr + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * hP; | ||||
|         } else { | ||||
|             return baseAddr + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * hP * mBytes; | ||||
|         } | ||||
|     } | ||||
|     char * addrOfScale(int kv_h) { | ||||
|     int8_t * addrOfScale(int kv_h) { | ||||
|         if (mConfig.mUseInt8Kernel) { | ||||
|             return mKeyScale->host<char>() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4; | ||||
|             return mKeyScale->host<int8_t>() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4; | ||||
|         } else if (mConfig.mQuantKey) { | ||||
|             return mKeyScale->host<char>() + kv_h * UP_DIV(mMaxLength, hP) * hP * mBytes; | ||||
|             return mKeyScale->host<int8_t>() + kv_h * UP_DIV(mMaxLength, hP) * hP * mBytes; | ||||
|         } else { | ||||
|             return nullptr; | ||||
|         } | ||||
|     } | ||||
|     char * addrOfZeroPoint(int kv_h) { | ||||
|     int8_t * addrOfZeroPoint(int kv_h) { | ||||
|         if (mConfig.mUseInt8Kernel) { | ||||
|             return mKeyZeroPoint->host<char>() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4; | ||||
|             return mKeyZeroPoint->host<int8_t>() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4; | ||||
|         } else if (mConfig.mQuantKey) { | ||||
|             return mKeyZeroPoint->host<char>() + kv_h * UP_DIV(mMaxLength, hP) * hP * mBytes; | ||||
|             return mKeyZeroPoint->host<int8_t>() + kv_h * UP_DIV(mMaxLength, hP) * hP * mBytes; | ||||
|         } else { | ||||
|             return nullptr; | ||||
|         } | ||||
|     } | ||||
|     char * addrOfKeySum(int kv_h) { | ||||
|     int8_t * addrOfKeySum(int kv_h) { | ||||
|         if (mConfig.mUseInt8Kernel) { | ||||
|             return mKeySum->host<char>() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4; | ||||
|             return mKeySum->host<int8_t>() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4; | ||||
|         }else { | ||||
|             return nullptr; | ||||
|         } | ||||
|  |  | |||
|  | @ -73,6 +73,386 @@ void MNNTranspose16Bit(int16_t* dstO, const int16_t* srcO, int32_t* dim) { | |||
|     } | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| #define EXP_APPROX_MIN_INPUT vdupq_n_f32(-88.0f) | ||||
| #define EXP_APPROX_MAX_INPUT vdupq_n_f32(88.0f) | ||||
| #define EXP_APPROX_LN2         vdupq_n_f32(0.69314718056f)  // ln(2)
 | ||||
| #define EXP_APPROX_LN2_INV     vdupq_n_f32(1.44269504089f)   // 1/ln(2)
 | ||||
| // Fourth-order polynomial approximation coefficients of exp(r):
 | ||||
| // P(x) = c4*x^4 + c3*x^3 + c2*x^2 + c1*x + c0
 | ||||
| #define EXP_APPROX_C4          vdupq_n_f32(0.0416624f) | ||||
| #define EXP_APPROX_C3          vdupq_n_f32(0.166665f) | ||||
| #define EXP_APPROX_C2          vdupq_n_f32(0.500000f) | ||||
| #define EXP_APPROX_C1          vdupq_n_f32(1.0f) | ||||
| #define EXP_APPROX_C0          vdupq_n_f32(1.0f) | ||||
| 
 | ||||
| #ifndef __aarch64__ | ||||
| static inline float32x4_t vrndaq_f32_compat(float32x4_t val) { | ||||
|     const float32x4_t v_zero = vdupq_n_f32(0.0f); | ||||
| 
 | ||||
|     float32x4_t v_truncated = vcvtq_f32_s32(vcvtq_s32_f32(val)); | ||||
| 
 | ||||
|     uint32x4_t v_is_positive_frac = vcgtq_f32(val, v_truncated); | ||||
|     uint32x4_t v_is_negative_frac = vcltq_f32(val, v_truncated); | ||||
| 
 | ||||
|     float32x4_t v_offset = vbslq_f32(v_is_positive_frac, vdupq_n_f32(1.0f), v_zero); | ||||
|     v_offset = vbslq_f32(v_is_negative_frac, vdupq_n_f32(-1.0f), v_offset); | ||||
| 
 | ||||
|     return vaddq_f32(v_truncated, v_offset); | ||||
| } | ||||
| 
 | ||||
| #endif | ||||
| 
 | ||||
| static inline float32x4_t expApprox(float32x4_t x) { | ||||
|     x = vminq_f32(vmaxq_f32(x, EXP_APPROX_MIN_INPUT), EXP_APPROX_MAX_INPUT); | ||||
|      | ||||
|     float32x4_t k_float; | ||||
|     float32x4_t r; | ||||
|     float32x4_t exp_r; | ||||
| 
 | ||||
| #if defined(__aarch64__) | ||||
|     // 1. x = k * ln(2) + r
 | ||||
|     k_float = vrndaq_f32(vmulq_f32(x, EXP_APPROX_LN2_INV)); | ||||
|      | ||||
|     // r = x - k * ln(2)
 | ||||
|     r = vfmsq_f32(x, k_float, EXP_APPROX_LN2); | ||||
| 
 | ||||
|     // 2. c0 + r*(c1 + r*(c2 + r*(c3 + r*c4)))  (Horner's method)
 | ||||
|     exp_r = vfmaq_f32(EXP_APPROX_C3, EXP_APPROX_C4, r); // c3 + c4*r
 | ||||
|     exp_r = vfmaq_f32(EXP_APPROX_C2, exp_r, r);         // c2 + r*(...)
 | ||||
|     exp_r = vfmaq_f32(EXP_APPROX_C1, exp_r, r);         // c1 + r*(...)
 | ||||
|     exp_r = vfmaq_f32(EXP_APPROX_C0, exp_r, r);         // c0 + r*(...)
 | ||||
| 
 | ||||
| #else | ||||
| 
 | ||||
|     k_float = vrndaq_f32_compat(vmulq_f32(x, EXP_APPROX_LN2_INV)); | ||||
|      | ||||
| 
 | ||||
|     r = vsubq_f32(x, vmulq_f32(k_float, EXP_APPROX_LN2)); | ||||
| 
 | ||||
|     // 2. c0 + r*(c1 + r*(c2 + r*(c3 + r*c4)))
 | ||||
|     exp_r = vmlaq_f32(EXP_APPROX_C3, EXP_APPROX_C4, r); // c3 + c4*r
 | ||||
|     exp_r = vmlaq_f32(EXP_APPROX_C2, exp_r, r);         // c2 + r*(...)
 | ||||
|     exp_r = vmlaq_f32(EXP_APPROX_C1, exp_r, r);         // c1 + r*(...)
 | ||||
|     exp_r = vmlaq_f32(EXP_APPROX_C0, exp_r, r);         // c0 + r*(...)
 | ||||
| 
 | ||||
| #endif | ||||
| 
 | ||||
|     int32x4_t k_int = vcvtq_s32_f32(k_float); | ||||
|     int32x4_t k_shifted = vshlq_n_s32(k_int, 23); | ||||
|     float32x4_t result = vreinterpretq_f32_s32(vaddq_s32(vreinterpretq_s32_f32(exp_r), k_shifted));    | ||||
|     return result; | ||||
| } | ||||
| 
 | ||||
| void MNNExpC8(float* dst, const float* src, float* offset, const float* parameters, size_t countC8) { | ||||
|     float32x4_t maxVec = vdupq_n_f32(offset[2]); | ||||
|     float32x4_t sumVec0 = vdupq_n_f32(0); | ||||
|     float32x4_t sumVec1 = vdupq_n_f32(0); | ||||
| 
 | ||||
|     float32x4_t c0 = vdupq_n_f32(offset[0]); | ||||
|     float32x4_t c1 = vdupq_n_f32(offset[1]); | ||||
| 
 | ||||
| 
 | ||||
|     for (int i = 0; i < countC8; ++i) { | ||||
|         float32x4_t srcVec0 = vld1q_f32(src); | ||||
|         float32x4_t srcVec1 = vld1q_f32(src + 4); | ||||
|         auto subVec0 = vaddq_f32(vmulq_f32(srcVec0, c0), maxVec); | ||||
|         auto subVec1 = vaddq_f32(vmulq_f32(srcVec1, c0), maxVec); | ||||
|         auto expVec0 = vaddq_f32(expApprox(subVec0), c1); | ||||
|         auto expVec1 = vaddq_f32(expApprox(subVec1), c1); | ||||
|         vst1q_f32(dst, expVec0); | ||||
|         vst1q_f32(dst + 4, expVec1); | ||||
|         sumVec0 = vaddq_f32(sumVec0, expVec0); | ||||
|         sumVec1 = vaddq_f32(sumVec1, expVec1); | ||||
| 
 | ||||
|         src += 8; | ||||
|         dst += 8; | ||||
|     } | ||||
| 
 | ||||
|     sumVec0 = vaddq_f32(sumVec0, sumVec1); | ||||
|     float32x2_t sumP = vpadd_f32(vget_low_f32(sumVec0), vget_high_f32(sumVec0)); | ||||
|     sumP = vpadd_f32(sumP, sumP); | ||||
|     offset[3] += vget_lane_f32(sumP, 0); | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| void MNNExp(float* destPtr, const float* srcPtr, float* offset, size_t size) { | ||||
|     float32x4_t maxVec = vdupq_n_f32(-offset[2]); | ||||
|     float32x4_t sumVec0 = vdupq_n_f32(0); | ||||
|     float32x4_t sumVec1 = vdupq_n_f32(0); | ||||
|     if (offset[0] == 1.f && offset[1] == 0.f) { | ||||
|         while (size >= 8) { | ||||
|             float32x4_t srcVec0 = vld1q_f32(srcPtr); | ||||
|             float32x4_t srcVec1 = vld1q_f32(srcPtr + 4); | ||||
|             auto subVec0 = vsubq_f32(srcVec0, maxVec); | ||||
|             auto subVec1 = vsubq_f32(srcVec1, maxVec); | ||||
|             auto expVec0 = expApprox(subVec0); | ||||
|             auto expVec1 = expApprox(subVec1); | ||||
|             vst1q_f32(destPtr, expVec0); | ||||
|             vst1q_f32(destPtr + 4, expVec1); | ||||
|             sumVec0 = vaddq_f32(sumVec0, expVec0); | ||||
|             sumVec1 = vaddq_f32(sumVec1, expVec1); | ||||
|             srcPtr += 8; | ||||
|             destPtr += 8; | ||||
|             size -= 8; | ||||
|              | ||||
|         } | ||||
|         while (size >= 4) { | ||||
|             float32x4_t srcVec0 = vld1q_f32(srcPtr); | ||||
|             auto subVec0 = vsubq_f32(srcVec0, maxVec); | ||||
|             auto expVec0 = expApprox(subVec0); | ||||
|             sumVec0 = vaddq_f32(sumVec0, expVec0); | ||||
|             vst1q_f32(destPtr, expVec0); | ||||
|             srcPtr += 4; | ||||
|             destPtr += 4; | ||||
|             size -= 4; | ||||
|         } | ||||
|         //merge
 | ||||
|         sumVec0 = vaddq_f32(sumVec0, sumVec1); | ||||
|         float32x2_t sumP = vpadd_f32(vget_low_f32(sumVec0), vget_high_f32(sumVec0)); | ||||
|         sumP = vpadd_f32(sumP, sumP); | ||||
|         auto newSum = vget_lane_f32(sumP, 0); | ||||
|         if (size > 0) { | ||||
|             float tmp[4]; | ||||
|             memcpy(tmp, srcPtr, size * sizeof(float)); | ||||
|             float32x4_t srcVec0 = vld1q_f32(tmp); | ||||
|             auto subVec0 = vsubq_f32(srcVec0, maxVec); | ||||
|             auto expVec0 = expApprox(subVec0); | ||||
|             vst1q_f32(tmp, expVec0); | ||||
|             for (int i = 0; i < size; ++i) { | ||||
|                 newSum += tmp[i]; | ||||
|                 destPtr[i] = tmp[i]; | ||||
|             } | ||||
|         } | ||||
|         offset[3] += newSum; | ||||
|     } else { | ||||
|         float32x4_t c0 = vdupq_n_f32(offset[0]); | ||||
|         float32x4_t c1 = vdupq_n_f32(offset[1]); | ||||
|         while (size >= 8) { | ||||
|             float32x4_t srcVec0 = vld1q_f32(srcPtr); | ||||
|             float32x4_t srcVec1 = vld1q_f32(srcPtr + 4); | ||||
|             auto subVec0 = vsubq_f32(vmulq_f32(srcVec0, c0), maxVec); | ||||
|             auto subVec1 = vsubq_f32(vmulq_f32(srcVec1, c0), maxVec); | ||||
|             auto expVec0 = vaddq_f32(expApprox(subVec0), c1); | ||||
|             auto expVec1 = vaddq_f32(expApprox(subVec1), c1); | ||||
|             vst1q_f32(destPtr, expVec0); | ||||
|             vst1q_f32(destPtr + 4, expVec1); | ||||
|             sumVec0 = vaddq_f32(sumVec0, expVec0); | ||||
|             sumVec1 = vaddq_f32(sumVec1, expVec1); | ||||
|             srcPtr += 8; | ||||
|             destPtr += 8; | ||||
|             size -= 8; | ||||
|              | ||||
|         } | ||||
|         while (size >= 4) { | ||||
|             float32x4_t srcVec0 = vld1q_f32(srcPtr); | ||||
|             auto subVec0 = vsubq_f32(vmulq_f32(srcVec0, c0), maxVec); | ||||
|             auto expVec0 = vaddq_f32(expApprox(subVec0), c1); | ||||
|             sumVec0 = vaddq_f32(sumVec0, expVec0); | ||||
|             vst1q_f32(destPtr, expVec0); | ||||
|             srcPtr += 4; | ||||
|             destPtr += 4; | ||||
|             size -= 4; | ||||
|         } | ||||
|         //merge
 | ||||
|         sumVec0 = vaddq_f32(sumVec0, sumVec1); | ||||
|         float32x2_t sumP = vpadd_f32(vget_low_f32(sumVec0), vget_high_f32(sumVec0)); | ||||
|         sumP = vpadd_f32(sumP, sumP); | ||||
|         auto newSum = vget_lane_f32(sumP, 0); | ||||
|         if (size > 0) { | ||||
|             float tmp[4]; | ||||
|             memcpy(tmp, srcPtr, size * sizeof(float)); | ||||
|             float32x4_t srcVec0 = vld1q_f32(tmp); | ||||
|             auto subVec0 = vsubq_f32(vmulq_f32(srcVec0, c0), maxVec); | ||||
|             auto expVec0 = vaddq_f32(expApprox(subVec0), c1); | ||||
|             vst1q_f32(tmp, expVec0); | ||||
|             for (int i = 0; i < size; ++i) { | ||||
|                 newSum += tmp[i]; | ||||
|                 destPtr[i] = tmp[i]; | ||||
|             } | ||||
|         } | ||||
|         offset[3] += newSum; | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| static inline void transposeAndStore4x4(const float* srcRowPtrs[4], float* dstColBase, size_t dstColStride) { | ||||
|     float32x4_t row0 = vld1q_f32(srcRowPtrs[0]); | ||||
|     float32x4_t row1 = vld1q_f32(srcRowPtrs[1]); | ||||
|     float32x4_t row2 = vld1q_f32(srcRowPtrs[2]); | ||||
|     float32x4_t row3 = vld1q_f32(srcRowPtrs[3]); | ||||
| 
 | ||||
|     // Step 1: Transpose 2x2 blocks of 2-element vectors
 | ||||
|     float32x4x2_t t01 = vtrnq_f32(row0, row1); | ||||
|     float32x4x2_t t23 = vtrnq_f32(row2, row3); | ||||
| 
 | ||||
|     // Step 2: Combine the results to get the full transpose
 | ||||
|     float32x4_t col0 = vcombine_f32(vget_low_f32(t01.val[0]), vget_low_f32(t23.val[0])); | ||||
|     float32x4_t col1 = vcombine_f32(vget_low_f32(t01.val[1]), vget_low_f32(t23.val[1])); | ||||
|     float32x4_t col2 = vcombine_f32(vget_high_f32(t01.val[0]), vget_high_f32(t23.val[0])); | ||||
|     float32x4_t col3 = vcombine_f32(vget_high_f32(t01.val[1]), vget_high_f32(t23.val[1])); | ||||
| 
 | ||||
|     vst1q_f32(dstColBase, col0); | ||||
|     vst1q_f32(dstColBase + dstColStride, col1); | ||||
|     vst1q_f32(dstColBase + 2 * dstColStride, col2); | ||||
|     vst1q_f32(dstColBase + 3 * dstColStride, col3); | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
// Repack a row-major [seqLen, kvSeqLen] matrix into MNN's packed layout
// [ceil(seqLen/eP), kvSeqLen, eP]: element (s, k) of the source lands at
// dst[(s / eP) * kvSeqLen * eP + k * eP + (s % eP)].
// Tile padding slots (when seqLen % eP != 0) are left untouched.
void packKvCache(float* dst, const float* src, size_t seqLen, size_t kvSeqLen, size_t eP) {
    if (seqLen == 0 || kvSeqLen == 0) {
        return;
    }

    const size_t packTileStride = kvSeqLen * eP; // floats per outer eP-row tile

    // Fast path: four source rows at a time, but only while the whole group
    // of four stays inside one eP tile. As soon as a group would straddle a
    // tile boundary we stop and let the scalar tail finish the remaining rows.
    size_t row = 0;
    for (; row + 4 <= seqLen; row += 4) {
        const size_t inner = row % eP;
        if (inner + 4 > eP) {
            break;
        }
        float* dstBase = dst + (row / eP) * packTileStride + inner;
        const float* r0 = src + row * kvSeqLen;
        const float* r1 = r0 + kvSeqLen;
        const float* r2 = r1 + kvSeqLen;
        const float* r3 = r2 + kvSeqLen;
        // 4x4 transpose, done column by column: the four rows' values for a
        // given k are adjacent in the packed layout.
        for (size_t col = 0; col < kvSeqLen; ++col) {
            float* out = dstBase + col * eP;
            out[0] = r0[col];
            out[1] = r1[col];
            out[2] = r2[col];
            out[3] = r3[col];
        }
    }

    // Scalar tail: one row at a time, scattering each element with stride eP.
    for (; row < seqLen; ++row) {
        const float* srcRow = src + row * kvSeqLen;
        float* dstBase = dst + (row / eP) * packTileStride + (row % eP);
        for (size_t col = 0; col < kvSeqLen; ++col) {
            dstBase[col * eP] = srcRow[col];
        }
    }
}
| 
 | ||||
// Softmax over `outside` independent rows of length `reduceSize`.
//
// Two modes:
//  - Streaming ("online"/flash-attention style) when runningMax, runningSum
//    and updateScale are all non-null: per row k, the row max is combined with
//    runningMax[k], unnormalized exp(x - newMax) values are written to
//    softmaxDst, and the per-row state is updated:
//        updateScale[k] = exp(oldMax - newMax)
//        runningSum[k]  = runningSum[k] * updateScale[k] + sum(exp(x - newMax))
//        runningMax[k]  = newMax
//    The caller is expected to apply the normalization/rescale later.
//  - Plain softmax otherwise: the row is normalized in place by 1/sum.
//
// exprOffset layout consumed by MNNExp (matches the assembly implementations:
// {alpha, beta, bias, sumAccumulator}):
//   dst[i] = beta + exp(alpha * src[i] + bias), and exprOffset[3] is
//   incremented by the sum of the computed outputs.
void MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) {
    for (int k = 0; k < outside; ++k) {
        auto source = input + k * reduceSize;
        auto dest = softmaxDst + k * reduceSize;

        // new max
        auto srcPtr = source;
        auto size = reduceSize;
        // Two accumulators to reduce the dependency chain in the 8-wide loop.
        float32x4_t maxVec0 = vdupq_n_f32(source[0]);
        auto maxVec1 = maxVec0;

        // Previous row max for streaming mode; otherwise seed with source[0].
        float oldMax = source[0];
        if (runningMax) {
            oldMax = runningMax[k];
        }

        while (size >= 8) {
            float32x4_t srcVec0 = vld1q_f32(srcPtr);
            float32x4_t srcVec1 = vld1q_f32(srcPtr + 4);

            maxVec0 = vmaxq_f32(maxVec0, srcVec0);
            maxVec1 = vmaxq_f32(maxVec1, srcVec1);

            srcPtr += 8;
            size -= 8;
        }

        while (size >= 4) {
            float32x4_t srcVec0 = vld1q_f32(srcPtr);
            maxVec0 = vmaxq_f32(maxVec0, srcVec0);
            srcPtr += 4;
            size -= 4;
        }

        // Horizontal max of the two vector accumulators via pairwise max.
        maxVec0 = vmaxq_f32(maxVec0, maxVec1);
        float32x2_t maxP = vpmax_f32(vget_low_f32(maxVec0), vget_high_f32(maxVec0));
        maxP = vpmax_f32(maxP, maxP);
        auto newMax = vget_lane_f32(maxP, 0);

        // Scalar tail (fewer than 4 remaining elements).
        while (size > 0) {
            newMax = ALIMAX(newMax, srcPtr[0]);
            srcPtr += 1;
            size -= 1;
        }

        newMax = ALIMAX(oldMax, newMax);
        srcPtr = source;
        auto destPtr = dest;
        size = reduceSize;

        // {alpha=1, beta=0, bias=-newMax, sum accumulator=0}
        float exprOffset[4] = {
            1.0f,
            0.0f,
            0.0f,
            0.0f
        };
        exprOffset[2] = -newMax;

        // expf(xi-newmax) & new sum
        MNNExp(destPtr, srcPtr, exprOffset, size);

        if (runningMax != nullptr && runningSum != nullptr && updateScale != nullptr) {
            // update runningSum, runningMax, scale=expf(oldMax-newMax)
            float newSum = exprOffset[3];
            runningSum[k] = runningSum[k] * expf(oldMax - newMax) + newSum;
            runningMax[k] = newMax;
            updateScale[k] = expf(oldMax - newMax);
        } else {
            // Normalize
            float sum = exprOffset[3];
            // Epsilon guards against division by zero for all-(-inf) rows.
            float scale = 1.0f / (sum + 1e-20f);
            int count = reduceSize;
            auto pDest = dest;

            float32x4_t scaleVec = vdupq_n_f32(scale);
            while (count >= 4) {
                float32x4_t data = vld1q_f32(pDest);
                data = vmulq_f32(data, scaleVec);
                vst1q_f32(pDest, data);

                pDest += 4;
                count -= 4;
            }

            while (count > 0) {
                *pDest *= scale;
                pDest++;
                count--;
            }
        }
    }
}
| 
 | ||||
| 
 | ||||
| #ifndef MNN_USE_NEON | ||||
| 
 | ||||
| void MNNPackedSparseMatMulEpx1(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, unsigned int* NNZMap, int* dataOffsetMap) { | ||||
|  |  | |||
|  | @ -1,115 +0,0 @@ | |||
| // | ||||
| //  MNNExpC8.S | ||||
| //  MNN | ||||
| // | ||||
| //  Created by MNN on 2019/01/18. | ||||
| //  Copyright © 2018, Alibaba Group Holding Limited | ||||
| // | ||||
| 
 | ||||
| #ifdef __arm__ | ||||
| #ifndef __aarch64__ | ||||
| 
 | ||||
| #include "MNNAsmGlobal.h" | ||||
| .text | ||||
| .align 5
 | ||||
| //void MNNExpC8(float* dest, const float* source, const float* offset, const float* parameters, size_t countC8) | ||||
| asm_function MNNExpC8 | ||||
| 
 | ||||
| //r0: dest, r1:source, r2: offset, r3:parameters, r4:countC8 | ||||
| push {r4, r5, lr} | ||||
| ldr r4, [sp, #12] | ||||
| vpush {q4-q7} | ||||
| vmov.i32 q7, #0 | ||||
| ldr r5, [r2, #0] | ||||
| vdup.32 q4, r5 // Alpha | ||||
| ldr r5, [r2, #4] | ||||
| vdup.32 q5, r5 // Beta | ||||
| ldr r5, [r2, #8] | ||||
| vdup.32 q6, r5 // Bias | ||||
| 
 | ||||
| 
 | ||||
| vld1.32 {q0, q1}, [r3] | ||||
| 
 | ||||
| vmov.i32 q2, #87 | ||||
| vcvt.f32.s32 q2, q2 | ||||
| vneg.f32 q3, q2 | ||||
| 
 | ||||
| Loop: | ||||
| 
 | ||||
| vld1.32 {q8, q9}, [r1]! | ||||
| vmul.f32 q8, q8, q4 | ||||
| vmul.f32 q9, q9, q4 | ||||
| vadd.f32 q8, q8, q6 | ||||
| vadd.f32 q9, q9, q6 | ||||
| 
 | ||||
| vmin.f32 q8, q8, q2 | ||||
| vmin.f32 q9, q9, q2 | ||||
| vmax.f32 q10, q8, q3 | ||||
| vmax.f32 q11, q9, q3 | ||||
| 
 | ||||
| vmul.f32 q8, q10, d0[1] | ||||
| vmul.f32 q9, q11, d0[1] | ||||
| vcvt.s32.f32 q8, q8 | ||||
| vcvt.s32.f32 q9, q9 | ||||
| 
 | ||||
| vcvt.f32.s32 q12, q8 | ||||
| vcvt.f32.s32 q13, q9 | ||||
| 
 | ||||
| //q10, q11: t | ||||
| vmls.f32 q10, q12, d0[0] | ||||
| vmls.f32 q11, q13, d0[0] | ||||
| 
 | ||||
| vmul.f32 q10, q10, d1[0] | ||||
| vmul.f32 q11, q11, d1[0] | ||||
| 
 | ||||
| .macro MLA_TWO z0 z1 z2 z3 | ||||
| vdup.32 \z1, \z0 | ||||
| vmla.f32 \z1, \z2, \z3 | ||||
| .endm | ||||
| 
 | ||||
| MLA_TWO d3[0], q12, q10, d3[1] | ||||
| MLA_TWO d3[0], q13, q11, d3[1] | ||||
| MLA_TWO d2[1], q14, q10, q12 | ||||
| MLA_TWO d2[1], q15, q11, q13 | ||||
| MLA_TWO d2[0], q12, q10, q14 | ||||
| MLA_TWO d2[0], q13, q11, q15 | ||||
| MLA_TWO d1[1], q14, q10, q12 | ||||
| MLA_TWO d1[1], q15, q11, q13 | ||||
| MLA_TWO d1[1], q12, q10, q14 | ||||
| MLA_TWO d1[1], q13, q11, q15 | ||||
| 
 | ||||
| //q12, q13 is expRemain | ||||
| 
 | ||||
| vmul.f32 q12, q12, q12 | ||||
| vmul.f32 q13, q13, q13 | ||||
| vmul.f32 q12, q12, q12 | ||||
| vmul.f32 q13, q13, q13 | ||||
| 
 | ||||
| vshl.i32 q8, q8, #23 | ||||
| vshl.i32 q9, q9, #23 | ||||
| vadd.i32 q12, q12, q8 | ||||
| vadd.i32 q13, q13, q9 | ||||
| 
 | ||||
| vadd.f32 q12, q12, q5 | ||||
| vadd.f32 q13, q13, q5 | ||||
| vadd.f32 q7, q12, q7 | ||||
| 
 | ||||
| vst1.32 {q12, q13}, [r0]! | ||||
| vadd.f32 q7, q13, q7 | ||||
| 
 | ||||
| subs r4, r4, #1 | ||||
| bne Loop | ||||
| add r5, r2, #12 | ||||
| vld1.32 {d0[0]}, [r5] | ||||
| vadd.f32 d14, d14, d15 | ||||
| vtrn.32 d14, d15 | ||||
| vadd.f32 d14, d14, d15 | ||||
| vadd.f32 d0, d14, d0 | ||||
| vst1.32 {d0[0]}, [r5] | ||||
| 
 | ||||
| vpop {q4-q7} | ||||
| pop {r4, r5, pc} | ||||
| 
 | ||||
| 
 | ||||
| #endif | ||||
| #endif | ||||
|  | @ -1,270 +0,0 @@ | |||
| // | ||||
| //  MNNSoftmax.S | ||||
| //  MNN | ||||
| // | ||||
| //  Created by MNN on 2021/07/05. | ||||
| //  Copyright © 2018, Alibaba Group Holding Limited | ||||
| // | ||||
| 
 | ||||
| #ifdef __arm__ | ||||
| #ifndef __aarch64__ | ||||
| 
 | ||||
| #include "MNNAsmGlobal.h" | ||||
| .text | ||||
| .align 5
 | ||||
| //void MNNSoftmax(float* dest, const float* source, size_t size) | ||||
| asm_function MNNSoftmax | ||||
|         push    {r4, r5, r6, r7, r8, lr} | ||||
|         bic     r3, r2, #3 | ||||
|         lsrs    r4, r2, #2 | ||||
|         vpush.64        {d8, d9, d10, d11, d12, d13, d14, d15} | ||||
|         sub     sp, sp, #8 | ||||
|         beq     Loop_2 | ||||
|         vld1.32 {d16-d17}, [r1] | ||||
|         cmp     r4, #1 | ||||
|         beq     Loop_3 | ||||
|         add     lr, r1, #16 | ||||
|         mov     ip, #1 | ||||
| Loop_4: | ||||
|         vld1.32 {d18-d19}, [lr]! | ||||
|         add     ip, ip, #1 | ||||
|         cmp     r4, ip | ||||
|         vmax.f32        q8, q8, q9 | ||||
|         bne     Loop_4 | ||||
| Loop_3: | ||||
|         vmov.32 ip, d16[0] | ||||
|         vmov    s15, ip | ||||
|         vmov.32 ip, d16[1] | ||||
|         vmov    s12, ip | ||||
|         vmov.32 ip, d17[0] | ||||
|         vcmpe.f32       s15, s12 | ||||
|         vmov    s13, ip | ||||
|         vmov.32 ip, d17[1] | ||||
|         vmrs    APSR_nzcv, FPSCR | ||||
|         vmov    s14, ip | ||||
|         vmovle.f32      s15, s12 | ||||
|         vcmpe.f32       s13, s15 | ||||
|         vmrs    APSR_nzcv, FPSCR | ||||
|         vmovpl.f32      s15, s13 | ||||
|         vcmpe.f32       s14, s15 | ||||
|         vmrs    APSR_nzcv, FPSCR | ||||
|         vmovpl.f32      s15, s14 | ||||
|         cmp     r2, r3 | ||||
|         bls     Loop_25 | ||||
| Loop_24: | ||||
|         add     lr, r1, r3, lsl #2 | ||||
|         mov     ip, r3 | ||||
| Loop_11: | ||||
|         vldmia.32       lr!, {s14} | ||||
|         add     ip, ip, #1 | ||||
|         vcmpe.f32       s14, s15 | ||||
|         vmrs    APSR_nzcv, FPSCR | ||||
|         vmovpl.f32      s15, s14 | ||||
|         cmp     r2, ip | ||||
|         bhi     Loop_11 | ||||
|         cmp     r4, #0 | ||||
|         beq     Loop_8 | ||||
| Loop_25: | ||||
|         vmov.f32        q11, #0.0  @ v4sf
 | ||||
|         mov     r5, r1 | ||||
|         vldr    d14, Loop_54 | ||||
|         vldr    d15, Loop_54+8 | ||||
|         mov     lr, r0 | ||||
|         vldr    d12, Loop_54+16 | ||||
|         vldr    d13, Loop_54+24 | ||||
|         mov     ip, #0 | ||||
|         vmov.f32        q12, #1.0e+0  @ v4sf
 | ||||
|         vstr.32 s15, [sp, #4] | ||||
|         vmov.f32        q5, #5.0e-1  @ v4sf
 | ||||
|         vldr    d8, Loop_54+32 | ||||
|         vldr    d9, Loop_54+40 | ||||
|         vldr    d0, Loop_54+48 | ||||
|         vldr    d1, Loop_54+56 | ||||
|         vldr    d2, Loop_54+64 | ||||
|         vldr    d3, Loop_54+72 | ||||
|         vldr    d4, Loop_54+80 | ||||
|         vldr    d5, Loop_54+88 | ||||
|         vldr    d30, Loop_54+96 | ||||
|         vldr    d31, Loop_54+104 | ||||
|         vmov.i32        q14, #8388608  @ v4si
 | ||||
|         vdup.32 q13, d7[1] | ||||
| Loop_12: | ||||
|         vld1.32 {d20-d21}, [r5]! | ||||
|         add     ip, ip, #1 | ||||
|         vmov.i32        q3, #127  @ v4si
 | ||||
|         cmp     r4, ip | ||||
|         vsub.f32        q10, q10, q13 | ||||
|         vmax.f32        q10, q10, q15 | ||||
|         vmin.f32        q10, q10, q2 | ||||
|         vmul.f32        q9, q10, q6 | ||||
|         vcvt.s32.f32    q9, q9 | ||||
|         vcvt.f32.s32    q8, q9 | ||||
|         vadd.i32        q9, q9, q3 | ||||
|         vmul.f32        q8, q8, q7 | ||||
|         vmul.i32        q9, q9, q14 | ||||
|         vsub.f32        q10, q10, q8 | ||||
|         vmul.f32        q8, q1, q10 | ||||
|         vadd.f32        q8, q8, q0 | ||||
|         vmul.f32        q8, q8, q10 | ||||
|         vadd.f32        q8, q8, q4 | ||||
|         vmul.f32        q8, q8, q10 | ||||
|         vadd.f32        q8, q8, q5 | ||||
|         vmul.f32        q8, q8, q10 | ||||
|         vadd.f32        q8, q8, q12 | ||||
|         vmul.f32        q8, q8, q10 | ||||
|         vadd.f32        q8, q8, q12 | ||||
|         vmul.f32        q9, q9, q8 | ||||
|         vadd.f32        q11, q9, q11 | ||||
|         vst1.32 {d18-d19}, [lr]! | ||||
|         bgt     Loop_12 | ||||
|         vmov.32 ip, d22[0] | ||||
|         vldr.32 s15, [sp, #4] | ||||
|         cmp     r2, r3 | ||||
|         vmov    s12, ip | ||||
|         vmov.32 ip, d22[1] | ||||
|         vmov    s11, ip | ||||
|         vmov.32 ip, d23[0] | ||||
|         vadd.f32        s12, s12, s11 | ||||
|         vmov    s13, ip | ||||
|         vmov.32 ip, d23[1] | ||||
|         vadd.f32        s12, s12, s13 | ||||
|         vmov    s14, ip | ||||
|         vadd.f32        s12, s12, s14 | ||||
|         bls     Loop_52 | ||||
| Loop_26: | ||||
|         lsl     ip, r3, #2 | ||||
|         add     r5, r1, r2, lsl #2 | ||||
|         add     lr, r1, ip | ||||
|         movw    r7, #13877 | ||||
|         movt    r7, 179 | ||||
|         movw    r6, #55317 | ||||
|         movt    r6, 32310 | ||||
|         vldr.32 s11, Loop_54+112 | ||||
|         vldr.32 s10, Loop_54+116 | ||||
|         add     r1, r0, ip | ||||
|         vldr.32 s9, Loop_54+120 | ||||
|         vmov.f32        s8, #5.0e-1 | ||||
|         vldr.32 s5, Loop_54+124 | ||||
|         vldr.32 s6, Loop_54+128 | ||||
|         vldr.32 s7, Loop_54+132 | ||||
| Loop_17: | ||||
|         vldmia.32       lr!, {s14} | ||||
|         vmov    s13, r6 | ||||
|         vsub.f32        s14, s14, s15 | ||||
|         vcmpe.f32       s14, s11 | ||||
|         vmrs    APSR_nzcv, FPSCR | ||||
|         vmovle  s13, r7 | ||||
|         ble     Loop_14 | ||||
|         vcmpe.f32       s14, s10 | ||||
|         vmrs    APSR_nzcv, FPSCR | ||||
|         bmi     Loop_53 | ||||
| Loop_14: | ||||
|         vadd.f32        s12, s12, s13 | ||||
|         cmp     r5, lr | ||||
|         vstmia.32       r1!, {s13} | ||||
|         bne     Loop_17 | ||||
|         vmov.f32        s15, #1.0e+0 | ||||
|         cmp     r4, #0 | ||||
|         vdiv.f32        s14, s15, s12 | ||||
|         beq     Loop_20 | ||||
| Loop_23: | ||||
|         vdup.32 q9, d7[0] | ||||
|         mov     ip, r0 | ||||
|         mov     r1, #0 | ||||
| Loop_19: | ||||
|         vld1.32 {d16-d17}, [ip] | ||||
|         add     r1, r1, #1 | ||||
|         cmp     r4, r1 | ||||
|         vmul.f32        q8, q8, q9 | ||||
|         vst1.32 {d16-d17}, [ip]! | ||||
|         bgt     Loop_19 | ||||
|         cmp     r2, r3 | ||||
|         bls     Loop_1 | ||||
|         lsl     ip, r3, #2 | ||||
| Loop_20: | ||||
|         add     r0, r0, ip | ||||
| Loop_21: | ||||
|         vldr.32 s15, [r0] | ||||
|         add     r3, r3, #1 | ||||
|         cmp     r2, r3 | ||||
|         vmul.f32        s15, s15, s14 | ||||
|         vstmia.32       r0!, {s15} | ||||
|         bhi     Loop_21 | ||||
| Loop_1: | ||||
|         add     sp, sp, #8 | ||||
|         vldm    sp!, {d8-d15} | ||||
|         pop     {r4, r5, r6, r7, r8, pc} | ||||
| Loop_2: | ||||
|         vldr.32 s15, Loop_54+136 | ||||
|         cmp     r2, r3 | ||||
|         bhi     Loop_24 | ||||
| Loop_8: | ||||
|         cmp     r2, r3 | ||||
|         vldrhi.32       s12, Loop_54+136 | ||||
|         bhi     Loop_26 | ||||
|         b       Loop_1 | ||||
| Loop_52: | ||||
|         vmov.f32        s15, #1.0e+0 | ||||
|         cmp     r4, #0 | ||||
|         vdiv.f32        s14, s15, s12 | ||||
|         bne     Loop_23 | ||||
|         b       Loop_1 | ||||
| Loop_53: | ||||
|         vdiv.f32        s4, s14, s9 | ||||
|         vmov.f32        s13, #1.0e+0 | ||||
|         vcvt.s32.f32    s4, s4 | ||||
|         vcvt.f32.s32    s3, s4 | ||||
|         vmov    r8, s4  @ int
 | ||||
|         vmov.f32        s4, s7 | ||||
|         vmls.f32        s14, s3, s9 | ||||
|         vmov.f32        s3, s6 | ||||
|         add     r8, r8, #127 | ||||
|         lsl     r8, r8, #23 | ||||
|         vmla.f32        s3, s14, s5 | ||||
|         vmla.f32        s4, s3, s14 | ||||
|         vmov.f32        s3, s8 | ||||
|         vmla.f32        s3, s4, s14 | ||||
|         vmov.f32        s4, s13 | ||||
|         vmla.f32        s4, s3, s14 | ||||
|         vmla.f32        s13, s4, s14 | ||||
|         vmov    s14, r8 | ||||
|         vmul.f32        s13, s13, s14 | ||||
|         b       Loop_14 | ||||
| Loop_54: | ||||
|         .word   1060205080
 | ||||
|         .word   1060205080
 | ||||
|         .word   1060205080
 | ||||
|         .word   1060205080
 | ||||
|         .word   1069066811
 | ||||
|         .word   1069066811
 | ||||
|         .word   1069066811
 | ||||
|         .word   1069066811
 | ||||
|         .word   1042983595
 | ||||
|         .word   1042983595
 | ||||
|         .word   1042983595
 | ||||
|         .word   1042983595
 | ||||
|         .word   1026206379
 | ||||
|         .word   1026206379
 | ||||
|         .word   1026206379
 | ||||
|         .word   1026206379
 | ||||
|         .word   1007192201
 | ||||
|         .word   1007192201
 | ||||
|         .word   1007192201
 | ||||
|         .word   1007192201
 | ||||
|         .word   1118699520
 | ||||
|         .word   1118699520
 | ||||
|         .word   1118699520
 | ||||
|         .word   1118699520
 | ||||
|         .word   -1028784128 | ||||
|         .word   -1028784128 | ||||
|         .word   -1028784128 | ||||
|         .word   -1028784128 | ||||
|         .word   -1028784128 | ||||
|         .word   1118699520
 | ||||
|         .word   1060205080
 | ||||
|         .word   1007192201
 | ||||
|         .word   1026206379
 | ||||
|         .word   1042983595
 | ||||
|         .word   0
 | ||||
| #endif | ||||
| #endif | ||||
|  | @ -1,109 +0,0 @@ | |||
| // | ||||
| //  MNNExpC8.S | ||||
| //  MNN | ||||
| // | ||||
| //  Created by MNN on 2019/01/18. | ||||
| //  Copyright © 2018, Alibaba Group Holding Limited | ||||
| // | ||||
| 
 | ||||
| #ifdef __aarch64__ | ||||
| 
 | ||||
| #include "MNNAsmGlobal.h" | ||||
| .text | ||||
| .align 5
 | ||||
| 
 | ||||
| //void MNNExpC8(float* dest, const float* source, const float* offset, const float* parameters, size_t countC8) | ||||
| asm_function MNNExpC8 | ||||
| 
 | ||||
| //x0: dest, x1:source, x2: offset, x3:parameters, x4:countC8 | ||||
| ldr w5, [x2, #0] | ||||
| ldr w6, [x2, #4] | ||||
| ldr w7, [x2, #8] | ||||
| 
 | ||||
| ld1 {v0.4s, v1.4s}, [x3] | ||||
| movi v2.4s, #23 | ||||
| movi v3.4s, #87 | ||||
| scvtf v3.4s, v3.4s | ||||
| fneg v4.4s, v3.4s | ||||
| dup v30.4s, w5 | ||||
| dup v31.4s, w6 | ||||
| dup v29.4s, w7 | ||||
| 
 | ||||
| // Summer | ||||
| movi v28.4s, #0 | ||||
| 
 | ||||
| Loop: | ||||
| 
 | ||||
| ld1 {v16.4s, v17.4s}, [x1], #32 | ||||
| fmul v16.4s, v16.4s, v30.4s | ||||
| fmul v17.4s, v17.4s, v30.4s | ||||
| fadd v16.4s, v16.4s, v29.4s | ||||
| fadd v17.4s, v17.4s, v29.4s | ||||
| fmin v16.4s, v16.4s, v3.4s | ||||
| fmin v17.4s, v17.4s, v3.4s | ||||
| fmax v18.4s, v16.4s, v4.4s | ||||
| fmax v19.4s, v17.4s, v4.4s | ||||
| 
 | ||||
| fmul v16.4s, v18.4s, v0.s[1] | ||||
| fmul v17.4s, v19.4s, v0.s[1] | ||||
| fcvtzs v16.4s, v16.4s | ||||
| fcvtzs v17.4s, v17.4s | ||||
| scvtf v20.4s, v16.4s | ||||
| scvtf v21.4s, v17.4s | ||||
| 
 | ||||
| //v18.4s, v19.4s: t | ||||
| fmls v18.4s, v20.4s, v0.s[0] | ||||
| fmls v19.4s, v21.4s, v0.s[0] | ||||
| 
 | ||||
| fmul v18.4s, v18.4s, v0.s[2] | ||||
| fmul v19.4s, v19.4s, v0.s[2] | ||||
| 
 | ||||
| .macro MLA_TWO z0 z1 z2 z3 | ||||
| dup \z1, \z0 | ||||
| fmla \z1, \z2, \z3 | ||||
| .endm | ||||
| 
 | ||||
| MLA_TWO v1.s[2], v20.4s, v18.4s, v1.s[3] | ||||
| MLA_TWO v1.s[2], v21.4s, v19.4s, v1.s[3] | ||||
| MLA_TWO v1.s[1], v22.4s, v18.4s, v20.4s | ||||
| MLA_TWO v1.s[1], v23.4s, v19.4s, v21.4s | ||||
| MLA_TWO v1.s[0], v20.4s, v18.4s, v22.4s | ||||
| MLA_TWO v1.s[0], v21.4s, v19.4s, v23.4s | ||||
| MLA_TWO v0.s[3], v22.4s, v18.4s, v20.4s | ||||
| MLA_TWO v0.s[3], v23.4s, v19.4s, v21.4s | ||||
| MLA_TWO v0.s[3], v20.4s, v18.4s, v22.4s | ||||
| MLA_TWO v0.s[3], v21.4s, v19.4s, v23.4s | ||||
| 
 | ||||
| //v20.4s, v21.4s is expRemain | ||||
| fmul v20.4s, v20.4s, v20.4s | ||||
| fmul v21.4s, v21.4s, v21.4s | ||||
| fmul v20.4s, v20.4s, v20.4s | ||||
| fmul v21.4s, v21.4s, v21.4s | ||||
| 
 | ||||
| ushl v16.4s, v16.4s, v2.4s | ||||
| ushl v17.4s, v17.4s, v2.4s | ||||
| add v20.4s, v20.4s, v16.4s | ||||
| add v21.4s, v21.4s, v17.4s | ||||
| 
 | ||||
| fadd v20.4s, v20.4s, v31.4s | ||||
| fadd v21.4s, v21.4s, v31.4s | ||||
| 
 | ||||
| st1 {v20.4s, v21.4s}, [x0], #32 | ||||
| fadd v28.4s, v28.4s, v20.4s | ||||
| fadd v28.4s, v28.4s, v21.4s | ||||
| 
 | ||||
| subs x4, x4, #1 | ||||
| bne Loop | ||||
| 
 | ||||
| // Bias | ||||
| add x7, x2, #12 | ||||
| ld1 {v27.s}[0], [x7] | ||||
| faddp v28.4s, v28.4s, v28.4s | ||||
| faddp v28.2s, v28.2s, v28.2s | ||||
| fadd v27.2s, v28.2s, v27.2s | ||||
| st1 {v27.s}[0], [x7] | ||||
| 
 | ||||
| ret | ||||
| 
 | ||||
| #endif | ||||
| 
 | ||||
|  | @ -1,204 +0,0 @@ | |||
| // | ||||
| //  MNNSoftmax.S | ||||
| //  MNN | ||||
| // | ||||
| //  Created by MNN on 2021/07/05. | ||||
| //  Copyright © 2018, Alibaba Group Holding Limited | ||||
| // | ||||
| 
 | ||||
| #ifdef __aarch64__ | ||||
| 
 | ||||
| #include "MNNAsmGlobal.h" | ||||
| .text | ||||
| .align 5
 | ||||
| 
 | ||||
| //void MNNSoftmax(float* dest, const float* source, size_t countC8) | ||||
| asm_function MNNSoftmax | ||||
|     stp x21, x22, [sp, #-32]! | ||||
|     stp x19, x20, [sp, #16] | ||||
|     sxtw    x8, w2 | ||||
|     lsr     w9, w2, #2 | ||||
|     and     x8, x8, #-4 | ||||
|     cbz     w9, Loop_5 | ||||
|     ldr     q0, [x1] | ||||
|     cmp     w9, #1 | ||||
|     b.eq    Loop_4 | ||||
|     add     x10, x1, #16 | ||||
|     sub     x11, x9, #1 | ||||
| Loop_3: | ||||
|     ldr     q1, [x10], #16 | ||||
|     subs    x11, x11, #1 | ||||
|     fmax    v0.4s, v0.4s, v1.4s | ||||
|     b.ne    Loop_3 | ||||
| Loop_4: | ||||
|     mov     s1, v0.s[1] | ||||
|     fcmp    s0, s1 | ||||
|     mov     s2, v0.s[2] | ||||
|     fcsel   s1, s0, s1, gt | ||||
|     fcmp    s1, s2 | ||||
|     fcsel   s1, s1, s2, gt | ||||
|     mov     s0, v0.s[3] | ||||
|     fcmp    s1, s0 | ||||
|     fcsel   s0, s1, s0, gt | ||||
|     cmp     w8, w2 | ||||
|     b.lo    Loop_6 | ||||
|     b       Loop_8 | ||||
| Loop_5: | ||||
|     fmov    s0, wzr | ||||
|     cmp     w8, w2 | ||||
|     b.hs    Loop_8 | ||||
| Loop_6: | ||||
|     add     x10, x1, w8, sxtw #2 | ||||
|     mov     w11, w8 | ||||
| Loop_7: | ||||
|     ldr     s1, [x10], #4 | ||||
|     add     w11, w11, #1 | ||||
|     fcmp    s0, s1 | ||||
|     fcsel   s0, s0, s1, gt | ||||
|     cmp     w11, w2 | ||||
|     b.lo    Loop_7 | ||||
| Loop_8: | ||||
|     cbz     w9, Loop_12 | ||||
|     mov     w10, #-1028784128 | ||||
|     mov     w12, #43579 | ||||
|     dup     v4.4s, w10 | ||||
|     mov     w10, #29208 | ||||
|     mov     w11, #1118699520 | ||||
|     movk    w12, #16312, lsl #16 | ||||
|     movk    w10, #48945, lsl #16 | ||||
|     dup     v5.4s, w11 | ||||
|     mov     w11, #34953 | ||||
|     dup     v6.4s, w12 | ||||
|     mov     w12, #43691 | ||||
|     dup     v7.4s, w10 | ||||
|     mov     w10, #43691 | ||||
|     movk    w11, #15368, lsl #16 | ||||
|     movk    w12, #15658, lsl #16 | ||||
|     movk    w10, #15914, lsl #16 | ||||
|     dup     v2.4s, v0.s[0] | ||||
|     movi    v1.16b, #0 | ||||
|     fmov    v3.4s, #1.0 | ||||
|     dup     v16.4s, w11 | ||||
|     dup     v17.4s, w12 | ||||
|     dup     v18.4s, w10 | ||||
|     movi    v19.4s, #63, lsl #24 | ||||
|     mov     x10, x9 | ||||
|     mov     x11, x1 | ||||
|     mov     x12, x0 | ||||
| Loop_10: | ||||
|     ldr     q20, [x11], #16 | ||||
|     subs    x10, x10, #1 | ||||
|     fsub    v20.4s, v20.4s, v2.4s | ||||
|     fmax    v20.4s, v20.4s, v4.4s | ||||
|     fmin    v20.4s, v20.4s, v5.4s | ||||
|     fmul    v21.4s, v20.4s, v6.4s | ||||
|     fcvtzs  v21.4s, v21.4s | ||||
|     scvtf   v22.4s, v21.4s | ||||
|     fmla    v20.4s, v22.4s, v7.4s | ||||
|     fmul    v22.4s, v20.4s, v16.4s | ||||
|     fadd    v22.4s, v22.4s, v17.4s | ||||
|     fmul    v22.4s, v20.4s, v22.4s | ||||
|     fadd    v22.4s, v22.4s, v18.4s | ||||
|     fmul    v22.4s, v20.4s, v22.4s | ||||
|     fadd    v22.4s, v22.4s, v19.4s | ||||
|     fmul    v22.4s, v20.4s, v22.4s | ||||
|     fadd    v22.4s, v22.4s, v3.4s | ||||
|     shl     v21.4s, v21.4s, #23 | ||||
|     fmul    v20.4s, v20.4s, v22.4s | ||||
|     add     v21.4s, v21.4s, v3.4s | ||||
|     fadd    v20.4s, v20.4s, v3.4s | ||||
|     fmul    v20.4s, v20.4s, v21.4s | ||||
|     fadd    v1.4s, v1.4s, v20.4s | ||||
|     str     q20, [x12], #16 | ||||
|     b.ne    Loop_10 | ||||
|     dup     v2.4s, v1.s[1] | ||||
|     dup     v3.4s, v1.s[2] | ||||
|     fadd    v2.4s, v1.4s, v2.4s | ||||
|     fadd    v2.4s, v3.4s, v2.4s | ||||
|     dup     v1.4s, v1.s[3] | ||||
|     fadd    v1.4s, v1.4s, v2.4s | ||||
|     b       Loop_13 | ||||
| Loop_12: | ||||
|     fmov    s1, wzr | ||||
| Loop_13: | ||||
|     cmp     w8, w2 | ||||
|     fmov    s2, #1.0 | ||||
|     b.hs    Loop_16 | ||||
|     lsl     x21, x8, #2 | ||||
|     mov     w12, #29208 | ||||
|     mov     w14, #34953 | ||||
|     mov     w15, #43691 | ||||
|     mov     w19, #43691 | ||||
|     mov     w10, #-1028784128 | ||||
|     mov     w11, #1118699520 | ||||
|     movk    w12, #16177, lsl #16 | ||||
|     mov     w13, #1065353216 | ||||
|     movk    w14, #15368, lsl #16 | ||||
|     movk    w15, #15658, lsl #16 | ||||
|     movk    w19, #15914, lsl #16 | ||||
|     add     x20, x1, x21 | ||||
|     add     x21, x0, x21 | ||||
|     fmov    s3, #0.5 | ||||
|     mov     w1, w8 | ||||
| Loop_15: | ||||
|     ldr     s4, [x20], #4 | ||||
|     fmov    s5, w10 | ||||
|     fmov    s6, w11 | ||||
|     fmov    s7, w12 | ||||
|     fsub    s4, s4, s0 | ||||
|     fmaxnm  s4, s4, s5 | ||||
|     fminnm  s4, s4, s6 | ||||
|     fdiv    s6, s4, s7 | ||||
|     fcvtzs  w3, s6 | ||||
|     scvtf   s6, w3 | ||||
|     fmul    s6, s6, s7 | ||||
|     fmov    s5, w14 | ||||
|     fsub    s4, s4, s6 | ||||
|     fmov    s7, w15 | ||||
|     fmul    s5, s4, s5 | ||||
|     fadd    s5, s5, s7 | ||||
|     fmov    s6, w19 | ||||
|     fmul    s5, s4, s5 | ||||
|     fadd    s5, s5, s6 | ||||
|     fmul    s5, s4, s5 | ||||
|     fadd    s5, s5, s3 | ||||
|     fmul    s5, s4, s5 | ||||
|     fadd    s5, s5, s2 | ||||
|     add     w3, w13, w3, lsl #23 | ||||
|     fmul    s4, s4, s5 | ||||
|     fmov    s7, w3 | ||||
|     fadd    s4, s4, s2 | ||||
|     add     w1, w1, #1 | ||||
|     fmul    s4, s4, s7 | ||||
|     cmp     w1, w2 | ||||
|     str     s4, [x21], #4 | ||||
|     fadd    s1, s1, s4 | ||||
|     b.lo    Loop_15 | ||||
| Loop_16: | ||||
|     fdiv    s0, s2, s1 | ||||
|     cbz     w9, Loop_19 | ||||
|     dup     v1.4s, v0.s[0] | ||||
|     mov     x10, x0 | ||||
| Loop_18: | ||||
|     ldr     q2, [x10] | ||||
|     subs    x9, x9, #1 | ||||
|     fmul    v2.4s, v1.4s, v2.4s | ||||
|     str     q2, [x10], #16 | ||||
|     b.ne    Loop_18 | ||||
| Loop_19: | ||||
|     cmp     w8, w2 | ||||
|     b.hs    Loop_22 | ||||
|     add     x9, x0, w8, sxtw #2 | ||||
|     Loop_21: | ||||
|     ldr     s1, [x9] | ||||
|     add     w8, w8, #1 | ||||
|     cmp     w8, w2 | ||||
|     fmul    s1, s0, s1 | ||||
|     str     s1, [x9], #4 | ||||
|     b.lo    Loop_21 | ||||
| Loop_22: | ||||
|     ldp x19, x20, [sp, #16] | ||||
|     ldp x21, x22, [sp], #32 | ||||
|     ret | ||||
| #endif | ||||
| 
 | ||||
|  | @ -379,10 +379,10 @@ LoopSzEnd_TILE_1: | |||
|     .inst 0x05b02324  // dup z4.q, z25.q[2] | ||||
|     .inst 0x05f02325  // dup z5.q, z25.q[3] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -404,10 +404,10 @@ LoopSzEnd_TILE_1: | |||
|     .inst 0x05f02325  // dup z5.q, z25.q[3] | ||||
|     .inst 0x05702346  // dup z6.q, z26.q[1] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -431,11 +431,11 @@ LoopSzEnd_TILE_1: | |||
|     .inst 0x05702346  // dup z6.q, z26.q[1] | ||||
|     .inst 0x05b02347  // dup z7.q, z26.q[2] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -461,11 +461,11 @@ TILE1_STORE48: | |||
|     .inst 0x05b02347  // dup z7.q, z26.q[2] | ||||
|     .inst 0x05f02348  // dup z8.q, z26.q[3] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -492,12 +492,12 @@ TILE1_STORE52: | |||
|     .inst 0x05b02347  // dup z7.q, z26.q[2] | ||||
|     .inst 0x05f02348  // dup z8.q, z26.q[3] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -526,12 +526,12 @@ TILE1_STORE56: | |||
|     .inst 0x05f02348  // dup z8.q, z26.q[3] | ||||
|     .inst 0x05702369  // dup z9.q, z27.q[1] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -562,13 +562,13 @@ TILE1_STORE60: | |||
|     .inst 0x05702369  // dup z9.q, z27.q[1] | ||||
|     .inst 0x05b0236a  // dup z10.q, z27.q[2] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -601,13 +601,13 @@ TILE1_STORE64: | |||
|     .inst 0x05b0236a  // dup z10.q, z27.q[2] | ||||
|     .inst 0x05f0236b  // dup z11.q, z27.q[3] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -640,15 +640,14 @@ TILE1_STORE68: | |||
|     .inst 0x05702369  // dup z9.q, z27.q[1] | ||||
|     .inst 0x05b0236a  // dup z10.q, z27.q[2] | ||||
|     .inst 0x05f0236b  // dup z11.q, z27.q[3] | ||||
|     .inst 0x0570238c  // dup z12.q, z28.q[1] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -667,7 +666,7 @@ TILE1_STORE68: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     b TILE1_Dz_End | ||||
|  | @ -685,14 +684,15 @@ TILE1_STORE72: | |||
|     .inst 0x05702369  // dup z9.q, z27.q[1] | ||||
|     .inst 0x05b0236a  // dup z10.q, z27.q[2] | ||||
|     .inst 0x05f0236b  // dup z11.q, z27.q[3] | ||||
|     .inst 0x0570238c  // dup z12.q, z28.q[1] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -711,7 +711,7 @@ TILE1_STORE72: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|  | @ -731,14 +731,15 @@ TILE1_STORE76: | |||
|     .inst 0x05b0236a  // dup z10.q, z27.q[2] | ||||
|     .inst 0x05f0236b  // dup z11.q, z27.q[3] | ||||
|     .inst 0x0570238c  // dup z12.q, z28.q[1] | ||||
|     .inst 0x05b0238d  // dup z13.q, z28.q[2] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -757,8 +758,8 @@ TILE1_STORE76: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|  | @ -779,14 +780,16 @@ TILE1_STORE80: | |||
|     .inst 0x05b0236a  // dup z10.q, z27.q[2] | ||||
|     .inst 0x05f0236b  // dup z11.q, z27.q[3] | ||||
|     .inst 0x0570238c  // dup z12.q, z28.q[1] | ||||
|     .inst 0x05b0238d  // dup z13.q, z28.q[2] | ||||
|     .inst 0x05f0238e  // dup z14.q, z28.q[3] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -805,12 +808,13 @@ TILE1_STORE80: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|     .inst 0xe400f5ed  // st1b {z13.b}, p5, [x15] | ||||
|     .inst 0xe40455ee  // st1b {z14.b}, p5, [x15, x4] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
| TILE1_STORE84: | ||||
|  | @ -827,14 +831,17 @@ TILE1_STORE84: | |||
|     .inst 0x05b0236a  // dup z10.q, z27.q[2] | ||||
|     .inst 0x05f0236b  // dup z11.q, z27.q[3] | ||||
|     .inst 0x0570238c  // dup z12.q, z28.q[1] | ||||
|     .inst 0x05b0238d  // dup z13.q, z28.q[2] | ||||
|     .inst 0x05f0238e  // dup z14.q, z28.q[3] | ||||
|     .inst 0x057023af  // dup z15.q, z29.q[1] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -853,12 +860,15 @@ TILE1_STORE84: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
|     add x8, x23, x4, lsl #3 // +20 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|     .inst 0xe400f5ed  // st1b {z13.b}, p5, [x15] | ||||
|     .inst 0xe40455ee  // st1b {z14.b}, p5, [x15, x4] | ||||
|     .inst 0xe400f51d  // st1b {z29.b}, p5, [x8] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
| TILE1_STORE88: | ||||
|  | @ -876,14 +886,16 @@ TILE1_STORE88: | |||
|     .inst 0x05f0236b  // dup z11.q, z27.q[3] | ||||
|     .inst 0x0570238c  // dup z12.q, z28.q[1] | ||||
|     .inst 0x05b0238d  // dup z13.q, z28.q[2] | ||||
|     .inst 0x05f0238e  // dup z14.q, z28.q[3] | ||||
|     .inst 0x057023af  // dup z15.q, z29.q[1] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -902,12 +914,16 @@ TILE1_STORE88: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
|     add x8, x23, x4, lsl #3 // +20 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|     .inst 0xe400f5ed  // st1b {z13.b}, p5, [x15] | ||||
|     .inst 0xe40455ee  // st1b {z14.b}, p5, [x15, x4] | ||||
|     .inst 0xe400f51d  // st1b {z29.b}, p5, [x8] | ||||
|     .inst 0xe404550f  // st1b {z15.b}, p5, [x8, x4] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
| TILE1_STORE92: | ||||
|  | @ -926,14 +942,17 @@ TILE1_STORE92: | |||
|     .inst 0x0570238c  // dup z12.q, z28.q[1] | ||||
|     .inst 0x05b0238d  // dup z13.q, z28.q[2] | ||||
|     .inst 0x05f0238e  // dup z14.q, z28.q[3] | ||||
|     .inst 0x057023af  // dup z15.q, z29.q[1] | ||||
|     .inst 0x05b023b0  // dup z16.q, z29.q[2] | ||||
|     .inst 0x05f023b1  // dup z17.q, z29.q[3] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -952,15 +971,18 @@ TILE1_STORE92: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x8, x23, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
|     add x8, x23, x4, lsl #3 // +20 | ||||
|     add x11, x5, x4, lsl #3 // +22 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|     .inst 0xe400f5ed  // st1b {z13.b}, p5, [x15] | ||||
|     .inst 0xe40455ee  // st1b {z14.b}, p5, [x15, x4] | ||||
|     .inst 0xe400f51d  // st1b {z29.b}, p5, [x8] | ||||
|     .inst 0xe404550f  // st1b {z15.b}, p5, [x8, x4] | ||||
|     .inst 0xe400f570  // st1b {z16.b}, p5, [x11] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
| TILE1_STORE96: | ||||
|  | @ -980,14 +1002,16 @@ TILE1_STORE96: | |||
|     .inst 0x05b0238d  // dup z13.q, z28.q[2] | ||||
|     .inst 0x05f0238e  // dup z14.q, z28.q[3] | ||||
|     .inst 0x057023af  // dup z15.q, z29.q[1] | ||||
|     .inst 0x05b023b0  // dup z16.q, z29.q[2] | ||||
|     .inst 0x05f023b1  // dup z17.q, z29.q[3] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -1006,9 +1030,10 @@ TILE1_STORE96: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x8, x23, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
|     add x8, x23, x4, lsl #3 // +20 | ||||
|     add x11, x5, x4, lsl #3 // +22 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|  | @ -1016,6 +1041,8 @@ TILE1_STORE96: | |||
|     .inst 0xe40455ee  // st1b {z14.b}, p5, [x15, x4] | ||||
|     .inst 0xe400f51d  // st1b {z29.b}, p5, [x8] | ||||
|     .inst 0xe404550f  // st1b {z15.b}, p5, [x8, x4] | ||||
|     .inst 0xe400f570  // st1b {z16.b}, p5, [x11] | ||||
|     .inst 0xe4045571  // st1b {z17.b}, p5, [x11, x4] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
| TILE1_STORE100: | ||||
|  | @ -1036,14 +1063,15 @@ TILE1_STORE100: | |||
|     .inst 0x05f0238e  // dup z14.q, z28.q[3] | ||||
|     .inst 0x057023af  // dup z15.q, z29.q[1] | ||||
|     .inst 0x05b023b0  // dup z16.q, z29.q[2] | ||||
|     .inst 0x05f023b1  // dup z17.q, z29.q[3] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -1062,10 +1090,11 @@ TILE1_STORE100: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x8, x23, x4, lsl #3 | ||||
|     add x11, x5, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
|     add x8, x23, x4, lsl #3 // +20 | ||||
|     add x11, x5, x4, lsl #3 // +22 | ||||
|     add x13, x9, x4, lsl #3 // +24 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|  | @ -1074,6 +1103,8 @@ TILE1_STORE100: | |||
|     .inst 0xe400f51d  // st1b {z29.b}, p5, [x8] | ||||
|     .inst 0xe404550f  // st1b {z15.b}, p5, [x8, x4] | ||||
|     .inst 0xe400f570  // st1b {z16.b}, p5, [x11] | ||||
|     .inst 0xe4045571  // st1b {z17.b}, p5, [x11, x4] | ||||
|     .inst 0xe400f5be  // st1b {z30.b}, p5, [x13] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
| TILE1_STORE104: | ||||
|  | @ -1095,75 +1126,15 @@ TILE1_STORE104: | |||
|     .inst 0x057023af  // dup z15.q, z29.q[1] | ||||
|     .inst 0x05b023b0  // dup z16.q, z29.q[2] | ||||
|     .inst 0x05f023b1  // dup z17.q, z29.q[3] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|     .inst 0xe400f521  // st1b {z1.b}, p5, [x9] | ||||
|     .inst 0xe4045522  // st1b {z2.b}, p5, [x9, x4] | ||||
|     .inst 0xe400f5f9  // st1b {z25.b}, p5, [x15] | ||||
|     .inst 0xe40455e3  // st1b {z3.b}, p5, [x15, x4] | ||||
|     .inst 0xe400f504  // st1b {z4.b}, p5, [x8] | ||||
|     .inst 0xe4045505  // st1b {z5.b}, p5, [x8, x4] | ||||
|     .inst 0xe400f57a  // st1b {z26.b}, p5, [x11] | ||||
|     .inst 0xe4045566  // st1b {z6.b}, p5, [x11, x4] | ||||
|     .inst 0xe400f5a7  // st1b {z7.b}, p5, [x13] | ||||
|     .inst 0xe40455a8  // st1b {z8.b}, p5, [x13, x4] | ||||
|     .inst 0xe400f6fb  // st1b {z27.b}, p5, [x23] | ||||
|     .inst 0xe40456e9  // st1b {z9.b}, p5, [x23, x4] | ||||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x8, x23, x4, lsl #3 | ||||
|     add x11, x5, x4, lsl #3 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|     .inst 0xe400f5ed  // st1b {z13.b}, p5, [x15] | ||||
|     .inst 0xe40455ee  // st1b {z14.b}, p5, [x15, x4] | ||||
|     .inst 0xe400f51d  // st1b {z29.b}, p5, [x8] | ||||
|     .inst 0xe404550f  // st1b {z15.b}, p5, [x8, x4] | ||||
|     .inst 0xe400f570  // st1b {z16.b}, p5, [x11] | ||||
|     .inst 0xe4045571  // st1b {z17.b}, p5, [x11, x4] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
| TILE1_STORE108: | ||||
|     .inst 0x05702300  // dup z0.q, z24.q[1] | ||||
|     .inst 0x05b02301  // dup z1.q, z24.q[2] | ||||
|     .inst 0x05f02302  // dup z2.q, z24.q[3] | ||||
|     .inst 0x05702323  // dup z3.q, z25.q[1] | ||||
|     .inst 0x05b02324  // dup z4.q, z25.q[2] | ||||
|     .inst 0x05f02325  // dup z5.q, z25.q[3] | ||||
|     .inst 0x05702346  // dup z6.q, z26.q[1] | ||||
|     .inst 0x05b02347  // dup z7.q, z26.q[2] | ||||
|     .inst 0x05f02348  // dup z8.q, z26.q[3] | ||||
|     .inst 0x05702369  // dup z9.q, z27.q[1] | ||||
|     .inst 0x05b0236a  // dup z10.q, z27.q[2] | ||||
|     .inst 0x05f0236b  // dup z11.q, z27.q[3] | ||||
|     .inst 0x0570238c  // dup z12.q, z28.q[1] | ||||
|     .inst 0x05b0238d  // dup z13.q, z28.q[2] | ||||
|     .inst 0x05f0238e  // dup z14.q, z28.q[3] | ||||
|     .inst 0x057023af  // dup z15.q, z29.q[1] | ||||
|     .inst 0x05b023b0  // dup z16.q, z29.q[2] | ||||
|     .inst 0x05f023b1  // dup z17.q, z29.q[3] | ||||
|     .inst 0x057023d2  // dup z18.q, z30.q[1] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -1182,11 +1153,11 @@ TILE1_STORE108: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x8, x23, x4, lsl #3 | ||||
|     add x11, x5, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
|     add x8, x23, x4, lsl #3 // +20 | ||||
|     add x11, x5, x4, lsl #3 // +22 | ||||
|     add x13, x9, x4, lsl #3 // +24 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|  | @ -1200,7 +1171,8 @@ TILE1_STORE108: | |||
|     .inst 0xe40455b2  // st1b {z18.b}, p5, [x13, x4] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
|     TILE1_STORE112: | ||||
|     /* oc=108 */ | ||||
|     TILE1_STORE108: | ||||
|     .inst 0x05702300  // dup z0.q, z24.q[1] | ||||
|     .inst 0x05b02301  // dup z1.q, z24.q[2] | ||||
|     .inst 0x05f02302  // dup z2.q, z24.q[3] | ||||
|  | @ -1222,13 +1194,13 @@ TILE1_STORE108: | |||
|     .inst 0x057023d2  // dup z18.q, z30.q[1] | ||||
|     .inst 0x05b023d3  // dup z19.q, z30.q[2] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -1247,12 +1219,12 @@ TILE1_STORE108: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x8, x23, x4, lsl #3 | ||||
|     add x11, x5, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
|     add x8, x23, x4, lsl #3 // +20 | ||||
|     add x11, x5, x4, lsl #3 // +22 | ||||
|     add x13, x9, x4, lsl #3 // +24 | ||||
|     add x23, x15, x4, lsl #3 // +26 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|  | @ -1267,6 +1239,77 @@ TILE1_STORE108: | |||
|     .inst 0xe400f6f3  // st1b {z19.b}, p5, [x23] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
|     /* oc=112 */ | ||||
|     TILE1_STORE112: | ||||
|     .inst 0x05702300  // dup z0.q, z24.q[1] | ||||
|     .inst 0x05b02301  // dup z1.q, z24.q[2] | ||||
|     .inst 0x05f02302  // dup z2.q, z24.q[3] | ||||
|     .inst 0x05702323  // dup z3.q, z25.q[1] | ||||
|     .inst 0x05b02324  // dup z4.q, z25.q[2] | ||||
|     .inst 0x05f02325  // dup z5.q, z25.q[3] | ||||
|     .inst 0x05702346  // dup z6.q, z26.q[1] | ||||
|     .inst 0x05b02347  // dup z7.q, z26.q[2] | ||||
|     .inst 0x05f02348  // dup z8.q, z26.q[3] | ||||
|     .inst 0x05702369  // dup z9.q, z27.q[1] | ||||
|     .inst 0x05b0236a  // dup z10.q, z27.q[2] | ||||
|     .inst 0x05f0236b  // dup z11.q, z27.q[3] | ||||
|     .inst 0x0570238c  // dup z12.q, z28.q[1] | ||||
|     .inst 0x05b0238d  // dup z13.q, z28.q[2] | ||||
|     .inst 0x05f0238e  // dup z14.q, z28.q[3] | ||||
|     .inst 0x057023af  // dup z15.q, z29.q[1] | ||||
|     .inst 0x05b023b0  // dup z16.q, z29.q[2] | ||||
|     .inst 0x05f023b1  // dup z17.q, z29.q[3] | ||||
|     .inst 0x057023d2  // dup z18.q, z30.q[1] | ||||
|     .inst 0x05b023d3  // dup z19.q, z30.q[2] | ||||
|     .inst 0x05f023d4  // dup z20.q, z30.q[3] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|     .inst 0xe400f521  // st1b {z1.b}, p5, [x9] | ||||
|     .inst 0xe4045522  // st1b {z2.b}, p5, [x9, x4] | ||||
|     .inst 0xe400f5f9  // st1b {z25.b}, p5, [x15] | ||||
|     .inst 0xe40455e3  // st1b {z3.b}, p5, [x15, x4] | ||||
|     .inst 0xe400f504  // st1b {z4.b}, p5, [x8] | ||||
|     .inst 0xe4045505  // st1b {z5.b}, p5, [x8, x4] | ||||
|     .inst 0xe400f57a  // st1b {z26.b}, p5, [x11] | ||||
|     .inst 0xe4045566  // st1b {z6.b}, p5, [x11, x4] | ||||
|     .inst 0xe400f5a7  // st1b {z7.b}, p5, [x13] | ||||
|     .inst 0xe40455a8  // st1b {z8.b}, p5, [x13, x4] | ||||
|     .inst 0xe400f6fb  // st1b {z27.b}, p5, [x23] | ||||
|     .inst 0xe40456e9  // st1b {z9.b}, p5, [x23, x4] | ||||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
|     add x8, x23, x4, lsl #3 // +20 | ||||
|     add x11, x5, x4, lsl #3 // +22 | ||||
|     add x13, x9, x4, lsl #3 // +24 | ||||
|     add x23, x15, x4, lsl #3 // +26 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|     .inst 0xe400f5ed  // st1b {z13.b}, p5, [x15] | ||||
|     .inst 0xe40455ee  // st1b {z14.b}, p5, [x15, x4] | ||||
|     .inst 0xe400f51d  // st1b {z29.b}, p5, [x8] | ||||
|     .inst 0xe404550f  // st1b {z15.b}, p5, [x8, x4] | ||||
|     .inst 0xe400f570  // st1b {z16.b}, p5, [x11] | ||||
|     .inst 0xe4045571  // st1b {z17.b}, p5, [x11, x4] | ||||
|     .inst 0xe400f5be  // st1b {z30.b}, p5, [x13] | ||||
|     .inst 0xe40455b2  // st1b {z18.b}, p5, [x13, x4] | ||||
|     .inst 0xe400f6f3  // st1b {z19.b}, p5, [x23] | ||||
|     .inst 0xe40456f4  // st1b {z20.b}, p5, [x23, x4] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
|     /* oc=116 */ | ||||
|     TILE1_STORE116: | ||||
|     .inst 0x05702300  // dup z0.q, z24.q[1] | ||||
|     .inst 0x05b02301  // dup z1.q, z24.q[2] | ||||
|  | @ -1290,13 +1333,13 @@ TILE1_STORE108: | |||
|     .inst 0x05b023d3  // dup z19.q, z30.q[2] | ||||
|     .inst 0x05f023d4  // dup z20.q, z30.q[3] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -1315,13 +1358,13 @@ TILE1_STORE108: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x8, x23, x4, lsl #3 | ||||
|     add x11, x5, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
|     add x8, x23, x4, lsl #3 // +20 | ||||
|     add x11, x5, x4, lsl #3 // +22 | ||||
|     add x13, x9, x4, lsl #3 // +24 | ||||
|     add x23, x15, x4, lsl #3 // +26 | ||||
|     add x5, x8, x4, lsl #3   // +28 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|  | @ -1338,6 +1381,7 @@ TILE1_STORE108: | |||
|     .inst 0xe400f4bf  // st1b {z31.b}, p5, [x5] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
|     /* oc=120 */ | ||||
|     TILE1_STORE120: | ||||
|     .inst 0x05702300  // dup z0.q, z24.q[1] | ||||
|     .inst 0x05b02301  // dup z1.q, z24.q[2] | ||||
|  | @ -1362,13 +1406,13 @@ TILE1_STORE108: | |||
|     .inst 0x05f023d4  // dup z20.q, z30.q[3] | ||||
|     .inst 0x057023f5  // dup z21.q, z31.q[1] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -1387,13 +1431,13 @@ TILE1_STORE108: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x8, x23, x4, lsl #3 | ||||
|     add x11, x5, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
|     add x8, x23, x4, lsl #3 // +20 | ||||
|     add x11, x5, x4, lsl #3 // +22 | ||||
|     add x13, x9, x4, lsl #3 // +24 | ||||
|     add x23, x15, x4, lsl #3 // +26 | ||||
|     add x5, x8, x4, lsl #3   // +28 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|  | @ -1411,6 +1455,7 @@ TILE1_STORE108: | |||
|     .inst 0xe40454b5  // st1b {z21.b}, p5, [x5, x4] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
|     /* oc=124 */ | ||||
|     TILE1_STORE124: | ||||
|     .inst 0x05702300  // dup z0.q, z24.q[1] | ||||
|     .inst 0x05b02301  // dup z1.q, z24.q[2] | ||||
|  | @ -1487,6 +1532,7 @@ TILE1_STORE108: | |||
|     .inst 0xe400f596  // st1b {z22.b}, p5, [x12] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
|     /* oc=128 */ | ||||
|     TILE1_STORE128: | ||||
|     .inst 0x05702300  // dup z0.q, z24.q[1] | ||||
|     .inst 0x05b02301  // dup z1.q, z24.q[2] | ||||
|  |  | |||
|  | @ -372,10 +372,10 @@ LoopSzEnd_TILE_1: | |||
|     .inst 0x05b02324  // dup z4.q, z25.q[2] | ||||
|     .inst 0x05f02325  // dup z5.q, z25.q[3] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -397,10 +397,10 @@ LoopSzEnd_TILE_1: | |||
|     .inst 0x05f02325  // dup z5.q, z25.q[3] | ||||
|     .inst 0x05702346  // dup z6.q, z26.q[1] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -424,11 +424,11 @@ LoopSzEnd_TILE_1: | |||
|     .inst 0x05702346  // dup z6.q, z26.q[1] | ||||
|     .inst 0x05b02347  // dup z7.q, z26.q[2] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -454,11 +454,11 @@ TILE1_STORE48: | |||
|     .inst 0x05b02347  // dup z7.q, z26.q[2] | ||||
|     .inst 0x05f02348  // dup z8.q, z26.q[3] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -485,12 +485,12 @@ TILE1_STORE52: | |||
|     .inst 0x05b02347  // dup z7.q, z26.q[2] | ||||
|     .inst 0x05f02348  // dup z8.q, z26.q[3] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -519,12 +519,12 @@ TILE1_STORE56: | |||
|     .inst 0x05f02348  // dup z8.q, z26.q[3] | ||||
|     .inst 0x05702369  // dup z9.q, z27.q[1] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -555,13 +555,13 @@ TILE1_STORE60: | |||
|     .inst 0x05702369  // dup z9.q, z27.q[1] | ||||
|     .inst 0x05b0236a  // dup z10.q, z27.q[2] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -594,13 +594,13 @@ TILE1_STORE64: | |||
|     .inst 0x05b0236a  // dup z10.q, z27.q[2] | ||||
|     .inst 0x05f0236b  // dup z11.q, z27.q[3] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -633,15 +633,14 @@ TILE1_STORE68: | |||
|     .inst 0x05702369  // dup z9.q, z27.q[1] | ||||
|     .inst 0x05b0236a  // dup z10.q, z27.q[2] | ||||
|     .inst 0x05f0236b  // dup z11.q, z27.q[3] | ||||
|     .inst 0x0570238c  // dup z12.q, z28.q[1] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -660,7 +659,7 @@ TILE1_STORE68: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     b TILE1_Dz_End | ||||
|  | @ -678,14 +677,15 @@ TILE1_STORE72: | |||
|     .inst 0x05702369  // dup z9.q, z27.q[1] | ||||
|     .inst 0x05b0236a  // dup z10.q, z27.q[2] | ||||
|     .inst 0x05f0236b  // dup z11.q, z27.q[3] | ||||
|     .inst 0x0570238c  // dup z12.q, z28.q[1] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -704,7 +704,7 @@ TILE1_STORE72: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|  | @ -724,14 +724,15 @@ TILE1_STORE76: | |||
|     .inst 0x05b0236a  // dup z10.q, z27.q[2] | ||||
|     .inst 0x05f0236b  // dup z11.q, z27.q[3] | ||||
|     .inst 0x0570238c  // dup z12.q, z28.q[1] | ||||
|     .inst 0x05b0238d  // dup z13.q, z28.q[2] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -750,8 +751,8 @@ TILE1_STORE76: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|  | @ -772,14 +773,16 @@ TILE1_STORE80: | |||
|     .inst 0x05b0236a  // dup z10.q, z27.q[2] | ||||
|     .inst 0x05f0236b  // dup z11.q, z27.q[3] | ||||
|     .inst 0x0570238c  // dup z12.q, z28.q[1] | ||||
|     .inst 0x05b0238d  // dup z13.q, z28.q[2] | ||||
|     .inst 0x05f0238e  // dup z14.q, z28.q[3] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -798,12 +801,13 @@ TILE1_STORE80: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|     .inst 0xe400f5ed  // st1b {z13.b}, p5, [x15] | ||||
|     .inst 0xe40455ee  // st1b {z14.b}, p5, [x15, x4] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
| TILE1_STORE84: | ||||
|  | @ -820,14 +824,16 @@ TILE1_STORE84: | |||
|     .inst 0x05b0236a  // dup z10.q, z27.q[2] | ||||
|     .inst 0x05f0236b  // dup z11.q, z27.q[3] | ||||
|     .inst 0x0570238c  // dup z12.q, z28.q[1] | ||||
|     .inst 0x05b0238d  // dup z13.q, z28.q[2] | ||||
|     .inst 0x05f0238e  // dup z14.q, z28.q[3] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -846,12 +852,15 @@ TILE1_STORE84: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
|     add x8, x23, x4, lsl #3 // +20 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|     .inst 0xe400f5ed  // st1b {z13.b}, p5, [x15] | ||||
|     .inst 0xe40455ee  // st1b {z14.b}, p5, [x15, x4] | ||||
|     .inst 0xe400f51d  // st1b {z29.b}, p5, [x8] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
| TILE1_STORE88: | ||||
|  | @ -869,14 +878,16 @@ TILE1_STORE88: | |||
|     .inst 0x05f0236b  // dup z11.q, z27.q[3] | ||||
|     .inst 0x0570238c  // dup z12.q, z28.q[1] | ||||
|     .inst 0x05b0238d  // dup z13.q, z28.q[2] | ||||
|     .inst 0x05f0238e  // dup z14.q, z28.q[3] | ||||
|     .inst 0x057023af  // dup z15.q, z29.q[1] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -895,12 +906,16 @@ TILE1_STORE88: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
|     add x8, x23, x4, lsl #3 // +20 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|     .inst 0xe400f5ed  // st1b {z13.b}, p5, [x15] | ||||
|     .inst 0xe40455ee  // st1b {z14.b}, p5, [x15, x4] | ||||
|     .inst 0xe400f51d  // st1b {z29.b}, p5, [x8] | ||||
|     .inst 0xe404550f  // st1b {z15.b}, p5, [x8, x4] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
| TILE1_STORE92: | ||||
|  | @ -919,14 +934,17 @@ TILE1_STORE92: | |||
|     .inst 0x0570238c  // dup z12.q, z28.q[1] | ||||
|     .inst 0x05b0238d  // dup z13.q, z28.q[2] | ||||
|     .inst 0x05f0238e  // dup z14.q, z28.q[3] | ||||
|     .inst 0x057023af  // dup z15.q, z29.q[1] | ||||
|     .inst 0x05b023b0  // dup z16.q, z29.q[2] | ||||
|     .inst 0x05f023b1  // dup z17.q, z29.q[3] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -945,15 +963,18 @@ TILE1_STORE92: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x8, x23, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
|     add x8, x23, x4, lsl #3 // +20 | ||||
|     add x11, x5, x4, lsl #3 // +22 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|     .inst 0xe400f5ed  // st1b {z13.b}, p5, [x15] | ||||
|     .inst 0xe40455ee  // st1b {z14.b}, p5, [x15, x4] | ||||
|     .inst 0xe400f51d  // st1b {z29.b}, p5, [x8] | ||||
|     .inst 0xe404550f  // st1b {z15.b}, p5, [x8, x4] | ||||
|     .inst 0xe400f570  // st1b {z16.b}, p5, [x11] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
| TILE1_STORE96: | ||||
|  | @ -973,14 +994,16 @@ TILE1_STORE96: | |||
|     .inst 0x05b0238d  // dup z13.q, z28.q[2] | ||||
|     .inst 0x05f0238e  // dup z14.q, z28.q[3] | ||||
|     .inst 0x057023af  // dup z15.q, z29.q[1] | ||||
|     .inst 0x05b023b0  // dup z16.q, z29.q[2] | ||||
|     .inst 0x05f023b1  // dup z17.q, z29.q[3] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -999,9 +1022,10 @@ TILE1_STORE96: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x8, x23, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
|     add x8, x23, x4, lsl #3 // +20 | ||||
|     add x11, x5, x4, lsl #3 // +22 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|  | @ -1009,6 +1033,8 @@ TILE1_STORE96: | |||
|     .inst 0xe40455ee  // st1b {z14.b}, p5, [x15, x4] | ||||
|     .inst 0xe400f51d  // st1b {z29.b}, p5, [x8] | ||||
|     .inst 0xe404550f  // st1b {z15.b}, p5, [x8, x4] | ||||
|     .inst 0xe400f570  // st1b {z16.b}, p5, [x11] | ||||
|     .inst 0xe4045571  // st1b {z17.b}, p5, [x11, x4] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
| TILE1_STORE100: | ||||
|  | @ -1029,14 +1055,15 @@ TILE1_STORE100: | |||
|     .inst 0x05f0238e  // dup z14.q, z28.q[3] | ||||
|     .inst 0x057023af  // dup z15.q, z29.q[1] | ||||
|     .inst 0x05b023b0  // dup z16.q, z29.q[2] | ||||
|     .inst 0x05f023b1  // dup z17.q, z29.q[3] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -1055,10 +1082,11 @@ TILE1_STORE100: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x8, x23, x4, lsl #3 | ||||
|     add x11, x5, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
|     add x8, x23, x4, lsl #3 // +20 | ||||
|     add x11, x5, x4, lsl #3 // +22 | ||||
|     add x13, x9, x4, lsl #3 // +24 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|  | @ -1067,6 +1095,8 @@ TILE1_STORE100: | |||
|     .inst 0xe400f51d  // st1b {z29.b}, p5, [x8] | ||||
|     .inst 0xe404550f  // st1b {z15.b}, p5, [x8, x4] | ||||
|     .inst 0xe400f570  // st1b {z16.b}, p5, [x11] | ||||
|     .inst 0xe4045571  // st1b {z17.b}, p5, [x11, x4] | ||||
|     .inst 0xe400f5be  // st1b {z30.b}, p5, [x13] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
| TILE1_STORE104: | ||||
|  | @ -1088,75 +1118,15 @@ TILE1_STORE104: | |||
|     .inst 0x057023af  // dup z15.q, z29.q[1] | ||||
|     .inst 0x05b023b0  // dup z16.q, z29.q[2] | ||||
|     .inst 0x05f023b1  // dup z17.q, z29.q[3] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|     .inst 0xe400f521  // st1b {z1.b}, p5, [x9] | ||||
|     .inst 0xe4045522  // st1b {z2.b}, p5, [x9, x4] | ||||
|     .inst 0xe400f5f9  // st1b {z25.b}, p5, [x15] | ||||
|     .inst 0xe40455e3  // st1b {z3.b}, p5, [x15, x4] | ||||
|     .inst 0xe400f504  // st1b {z4.b}, p5, [x8] | ||||
|     .inst 0xe4045505  // st1b {z5.b}, p5, [x8, x4] | ||||
|     .inst 0xe400f57a  // st1b {z26.b}, p5, [x11] | ||||
|     .inst 0xe4045566  // st1b {z6.b}, p5, [x11, x4] | ||||
|     .inst 0xe400f5a7  // st1b {z7.b}, p5, [x13] | ||||
|     .inst 0xe40455a8  // st1b {z8.b}, p5, [x13, x4] | ||||
|     .inst 0xe400f6fb  // st1b {z27.b}, p5, [x23] | ||||
|     .inst 0xe40456e9  // st1b {z9.b}, p5, [x23, x4] | ||||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x8, x23, x4, lsl #3 | ||||
|     add x11, x5, x4, lsl #3 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|     .inst 0xe400f5ed  // st1b {z13.b}, p5, [x15] | ||||
|     .inst 0xe40455ee  // st1b {z14.b}, p5, [x15, x4] | ||||
|     .inst 0xe400f51d  // st1b {z29.b}, p5, [x8] | ||||
|     .inst 0xe404550f  // st1b {z15.b}, p5, [x8, x4] | ||||
|     .inst 0xe400f570  // st1b {z16.b}, p5, [x11] | ||||
|     .inst 0xe4045571  // st1b {z17.b}, p5, [x11, x4] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
| TILE1_STORE108: | ||||
|     .inst 0x05702300  // dup z0.q, z24.q[1] | ||||
|     .inst 0x05b02301  // dup z1.q, z24.q[2] | ||||
|     .inst 0x05f02302  // dup z2.q, z24.q[3] | ||||
|     .inst 0x05702323  // dup z3.q, z25.q[1] | ||||
|     .inst 0x05b02324  // dup z4.q, z25.q[2] | ||||
|     .inst 0x05f02325  // dup z5.q, z25.q[3] | ||||
|     .inst 0x05702346  // dup z6.q, z26.q[1] | ||||
|     .inst 0x05b02347  // dup z7.q, z26.q[2] | ||||
|     .inst 0x05f02348  // dup z8.q, z26.q[3] | ||||
|     .inst 0x05702369  // dup z9.q, z27.q[1] | ||||
|     .inst 0x05b0236a  // dup z10.q, z27.q[2] | ||||
|     .inst 0x05f0236b  // dup z11.q, z27.q[3] | ||||
|     .inst 0x0570238c  // dup z12.q, z28.q[1] | ||||
|     .inst 0x05b0238d  // dup z13.q, z28.q[2] | ||||
|     .inst 0x05f0238e  // dup z14.q, z28.q[3] | ||||
|     .inst 0x057023af  // dup z15.q, z29.q[1] | ||||
|     .inst 0x05b023b0  // dup z16.q, z29.q[2] | ||||
|     .inst 0x05f023b1  // dup z17.q, z29.q[3] | ||||
|     .inst 0x057023d2  // dup z18.q, z30.q[1] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -1175,11 +1145,11 @@ TILE1_STORE108: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x8, x23, x4, lsl #3 | ||||
|     add x11, x5, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
|     add x8, x23, x4, lsl #3 // +20 | ||||
|     add x11, x5, x4, lsl #3 // +22 | ||||
|     add x13, x9, x4, lsl #3 // +24 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|  | @ -1193,7 +1163,8 @@ TILE1_STORE108: | |||
|     .inst 0xe40455b2  // st1b {z18.b}, p5, [x13, x4] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
|     TILE1_STORE112: | ||||
|     /* oc=108 */ | ||||
|     TILE1_STORE108: | ||||
|     .inst 0x05702300  // dup z0.q, z24.q[1] | ||||
|     .inst 0x05b02301  // dup z1.q, z24.q[2] | ||||
|     .inst 0x05f02302  // dup z2.q, z24.q[3] | ||||
|  | @ -1215,13 +1186,13 @@ TILE1_STORE108: | |||
|     .inst 0x057023d2  // dup z18.q, z30.q[1] | ||||
|     .inst 0x05b023d3  // dup z19.q, z30.q[2] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -1240,12 +1211,12 @@ TILE1_STORE108: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x8, x23, x4, lsl #3 | ||||
|     add x11, x5, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
|     add x8, x23, x4, lsl #3 // +20 | ||||
|     add x11, x5, x4, lsl #3 // +22 | ||||
|     add x13, x9, x4, lsl #3 // +24 | ||||
|     add x23, x15, x4, lsl #3 // +26 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|  | @ -1260,6 +1231,77 @@ TILE1_STORE108: | |||
|     .inst 0xe400f6f3  // st1b {z19.b}, p5, [x23] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
|     /* oc=112 */ | ||||
|     TILE1_STORE112: | ||||
|     .inst 0x05702300  // dup z0.q, z24.q[1] | ||||
|     .inst 0x05b02301  // dup z1.q, z24.q[2] | ||||
|     .inst 0x05f02302  // dup z2.q, z24.q[3] | ||||
|     .inst 0x05702323  // dup z3.q, z25.q[1] | ||||
|     .inst 0x05b02324  // dup z4.q, z25.q[2] | ||||
|     .inst 0x05f02325  // dup z5.q, z25.q[3] | ||||
|     .inst 0x05702346  // dup z6.q, z26.q[1] | ||||
|     .inst 0x05b02347  // dup z7.q, z26.q[2] | ||||
|     .inst 0x05f02348  // dup z8.q, z26.q[3] | ||||
|     .inst 0x05702369  // dup z9.q, z27.q[1] | ||||
|     .inst 0x05b0236a  // dup z10.q, z27.q[2] | ||||
|     .inst 0x05f0236b  // dup z11.q, z27.q[3] | ||||
|     .inst 0x0570238c  // dup z12.q, z28.q[1] | ||||
|     .inst 0x05b0238d  // dup z13.q, z28.q[2] | ||||
|     .inst 0x05f0238e  // dup z14.q, z28.q[3] | ||||
|     .inst 0x057023af  // dup z15.q, z29.q[1] | ||||
|     .inst 0x05b023b0  // dup z16.q, z29.q[2] | ||||
|     .inst 0x05f023b1  // dup z17.q, z29.q[3] | ||||
|     .inst 0x057023d2  // dup z18.q, z30.q[1] | ||||
|     .inst 0x05b023d3  // dup z19.q, z30.q[2] | ||||
|     .inst 0x05f023d4  // dup z20.q, z30.q[3] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|     .inst 0xe400f521  // st1b {z1.b}, p5, [x9] | ||||
|     .inst 0xe4045522  // st1b {z2.b}, p5, [x9, x4] | ||||
|     .inst 0xe400f5f9  // st1b {z25.b}, p5, [x15] | ||||
|     .inst 0xe40455e3  // st1b {z3.b}, p5, [x15, x4] | ||||
|     .inst 0xe400f504  // st1b {z4.b}, p5, [x8] | ||||
|     .inst 0xe4045505  // st1b {z5.b}, p5, [x8, x4] | ||||
|     .inst 0xe400f57a  // st1b {z26.b}, p5, [x11] | ||||
|     .inst 0xe4045566  // st1b {z6.b}, p5, [x11, x4] | ||||
|     .inst 0xe400f5a7  // st1b {z7.b}, p5, [x13] | ||||
|     .inst 0xe40455a8  // st1b {z8.b}, p5, [x13, x4] | ||||
|     .inst 0xe400f6fb  // st1b {z27.b}, p5, [x23] | ||||
|     .inst 0xe40456e9  // st1b {z9.b}, p5, [x23, x4] | ||||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
|     add x8, x23, x4, lsl #3 // +20 | ||||
|     add x11, x5, x4, lsl #3 // +22 | ||||
|     add x13, x9, x4, lsl #3 // +24 | ||||
|     add x23, x15, x4, lsl #3 // +26 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|     .inst 0xe400f5ed  // st1b {z13.b}, p5, [x15] | ||||
|     .inst 0xe40455ee  // st1b {z14.b}, p5, [x15, x4] | ||||
|     .inst 0xe400f51d  // st1b {z29.b}, p5, [x8] | ||||
|     .inst 0xe404550f  // st1b {z15.b}, p5, [x8, x4] | ||||
|     .inst 0xe400f570  // st1b {z16.b}, p5, [x11] | ||||
|     .inst 0xe4045571  // st1b {z17.b}, p5, [x11, x4] | ||||
|     .inst 0xe400f5be  // st1b {z30.b}, p5, [x13] | ||||
|     .inst 0xe40455b2  // st1b {z18.b}, p5, [x13, x4] | ||||
|     .inst 0xe400f6f3  // st1b {z19.b}, p5, [x23] | ||||
|     .inst 0xe40456f4  // st1b {z20.b}, p5, [x23, x4] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
|     /* oc=116 */ | ||||
|     TILE1_STORE116: | ||||
|     .inst 0x05702300  // dup z0.q, z24.q[1] | ||||
|     .inst 0x05b02301  // dup z1.q, z24.q[2] | ||||
|  | @ -1283,13 +1325,13 @@ TILE1_STORE108: | |||
|     .inst 0x05b023d3  // dup z19.q, z30.q[2] | ||||
|     .inst 0x05f023d4  // dup z20.q, z30.q[3] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -1308,13 +1350,13 @@ TILE1_STORE108: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x8, x23, x4, lsl #3 | ||||
|     add x11, x5, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
|     add x8, x23, x4, lsl #3 // +20 | ||||
|     add x11, x5, x4, lsl #3 // +22 | ||||
|     add x13, x9, x4, lsl #3 // +24 | ||||
|     add x23, x15, x4, lsl #3 // +26 | ||||
|     add x5, x8, x4, lsl #3   // +28 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|  | @ -1331,6 +1373,7 @@ TILE1_STORE108: | |||
|     .inst 0xe400f4bf  // st1b {z31.b}, p5, [x5] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
|     /* oc=120 */ | ||||
|     TILE1_STORE120: | ||||
|     .inst 0x05702300  // dup z0.q, z24.q[1] | ||||
|     .inst 0x05b02301  // dup z1.q, z24.q[2] | ||||
|  | @ -1355,13 +1398,13 @@ TILE1_STORE108: | |||
|     .inst 0x05f023d4  // dup z20.q, z30.q[3] | ||||
|     .inst 0x057023f5  // dup z21.q, z31.q[1] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #1 | ||||
|     add x15, x6, x4, lsl #2 | ||||
|     add x8, x9, x4, lsl #2 | ||||
|     add x11, x6, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #1 // +2 | ||||
|     add x15, x6, x4, lsl #2 // +4 | ||||
|     add x8, x9, x4, lsl #2 // +6 | ||||
|     add x11, x6, x4, lsl #3 // +8 | ||||
|     add x13, x9, x4, lsl #3 // +10 | ||||
|     add x23, x15, x4, lsl #3 // +12 | ||||
|     add x5, x8, x4, lsl #3   // +14 | ||||
| 
 | ||||
|     .inst 0xe400f4d8  // st1b {z24.b}, p5, [x6] | ||||
|     .inst 0xe40454c0  // st1b {z0.b}, p5, [x6, x4] | ||||
|  | @ -1380,13 +1423,13 @@ TILE1_STORE108: | |||
|     .inst 0xe400f4aa  // st1b {z10.b}, p5, [x5] | ||||
|     .inst 0xe40454ab  // st1b {z11.b}, p5, [x5, x4] | ||||
| 
 | ||||
|     add x9, x6, x4, lsl #4 | ||||
|     add x15, x13, x4, lsl #3 | ||||
|     add x8, x23, x4, lsl #3 | ||||
|     add x11, x5, x4, lsl #3 | ||||
|     add x13, x9, x4, lsl #3 | ||||
|     add x23, x15, x4, lsl #3 | ||||
|     add x5, x8, x4, lsl #3 | ||||
|     add x9, x6, x4, lsl #4 // +16 | ||||
|     add x15, x13, x4, lsl #3 // +18 | ||||
|     add x8, x23, x4, lsl #3 // +20 | ||||
|     add x11, x5, x4, lsl #3 // +22 | ||||
|     add x13, x9, x4, lsl #3 // +24 | ||||
|     add x23, x15, x4, lsl #3 // +26 | ||||
|     add x5, x8, x4, lsl #3   // +28 | ||||
| 
 | ||||
|     .inst 0xe400f53c  // st1b {z28.b}, p5, [x9] | ||||
|     .inst 0xe404552c  // st1b {z12.b}, p5, [x9, x4] | ||||
|  | @ -1404,6 +1447,7 @@ TILE1_STORE108: | |||
|     .inst 0xe40454b5  // st1b {z21.b}, p5, [x5, x4] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
|     /* oc=124 */ | ||||
|     TILE1_STORE124: | ||||
|     .inst 0x05702300  // dup z0.q, z24.q[1] | ||||
|     .inst 0x05b02301  // dup z1.q, z24.q[2] | ||||
|  | @ -1480,6 +1524,7 @@ TILE1_STORE108: | |||
|     .inst 0xe400f596  // st1b {z22.b}, p5, [x12] | ||||
|     b TILE1_Dz_End | ||||
| 
 | ||||
|     /* oc=128 */ | ||||
|     TILE1_STORE128: | ||||
|     .inst 0x05702300  // dup z0.q, z24.q[1] | ||||
|     .inst 0x05b02301  // dup z1.q, z24.q[2] | ||||
|  |  | |||
|  | @ -1390,6 +1390,89 @@ void MNNAccumulateSequenceNumber (float* dst, const float* src, int size) { | |||
|     *dst = sum; | ||||
| } | ||||
| 
 | ||||
| #ifdef MNN_SUPPORT_TRANSFORMER_FUSE | ||||
| 
 | ||||
| static void MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes) { | ||||
|     // source shape:                 [headDim/pack, seqLen, pack]
 | ||||
|     // scale & normalizeScale shape: [seqLen]
 | ||||
|     // dest shape:                   [headDim/pack, seqLen, pack]
 | ||||
|     auto stride0 = plane * pack; | ||||
| 
 | ||||
|     if (idx > 0) { | ||||
|         for (int j = 0; j < depthQuad; ++j) { | ||||
|             for (int i = 0; i < plane; ++i) { | ||||
|                 auto dataNew = Vec::load(src + j * stride0 + i * pack); | ||||
|                 auto dataOld = Vec::load(dst + j * stride0 + i * pack); | ||||
|                 auto s = Vec(scale[i]); | ||||
|                 dataNew = Vec::fma(dataNew, dataOld, s); | ||||
|                 Vec::save(dst + j * stride0 + i * pack, dataNew); | ||||
|             } | ||||
|         } | ||||
|     } else { | ||||
|         memcpy(dst, src, size * bytes); | ||||
|     } | ||||
|     if (idx == kvBlocks - 1) { // if last subBlock, exp(xi)/sum(exp(xi))
 | ||||
|         for (int j = 0; j < depthQuad; ++j) { | ||||
|             for (int i = 0; i < plane; ++i) { | ||||
|                 auto dataNew = Vec::load(dst + j * stride0 + i * pack); | ||||
|                 auto ns = Vec(1.0f / normalizeScale[i]); | ||||
|                 dataNew = dataNew * ns; | ||||
|                 Vec::save(dst + j * stride0 + i * pack, dataNew); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| static void MNNAttenPackAndScaleSingleHead(float* dst, const float* srcHeadBase, size_t srcRowStride, const float* scale, const int32_t* units, size_t seqLen, size_t headDim) { | ||||
|     const int32_t eP = units[0]; | ||||
|     const int32_t lP = units[1]; | ||||
| 
 | ||||
|     if (lP != 1) { | ||||
|         MNN_ERROR("This function only supports lP=1 or 2\n"); | ||||
|         return; | ||||
|     } | ||||
| 
 | ||||
|     const float scaleVal = scale[0]; | ||||
| #ifdef MNN_USE_NEON | ||||
|     const float32x4_t vScale = vdupq_n_f32(scaleVal); | ||||
| #endif | ||||
|     const size_t packedHeadDim = UP_DIV(headDim, lP); | ||||
|     const size_t dstStrideDOuter = (size_t)eP * lP; | ||||
|     const size_t dstStrideSOuter = packedHeadDim * dstStrideDOuter; | ||||
| 
 | ||||
|     for (int s = 0; s < seqLen; ++s) { | ||||
|         const int sOuter = s / eP; | ||||
|         const int sInner = s % eP; | ||||
|         const float* srcRowPtr = srcHeadBase + s * srcRowStride; | ||||
|         float* dstBasePtr = dst + sOuter * dstStrideSOuter + sInner * lP; | ||||
|          | ||||
|         size_t d = 0; | ||||
| #ifdef MNN_USE_NEON | ||||
|         for (; d + 7 < headDim; d += 8) { | ||||
|             float32x4_t sVec0 = vld1q_f32(srcRowPtr + d); | ||||
|             float32x4_t sVec1 = vld1q_f32(srcRowPtr + d + 4); | ||||
|             sVec0 = vmulq_f32(sVec0, vScale); | ||||
|             sVec1 = vmulq_f32(sVec1, vScale); | ||||
|              | ||||
|             dstBasePtr[(d + 0) * dstStrideDOuter] = sVec0[0]; | ||||
|             dstBasePtr[(d + 1) * dstStrideDOuter] = sVec0[1]; | ||||
|             dstBasePtr[(d + 2) * dstStrideDOuter] = sVec0[2]; | ||||
|             dstBasePtr[(d + 3) * dstStrideDOuter] = sVec0[3]; | ||||
|             dstBasePtr[(d + 4) * dstStrideDOuter] = sVec1[0]; | ||||
|             dstBasePtr[(d + 5) * dstStrideDOuter] = sVec1[1]; | ||||
|             dstBasePtr[(d + 6) * dstStrideDOuter] = sVec1[2]; | ||||
|             dstBasePtr[(d + 7) * dstStrideDOuter] = sVec1[3]; | ||||
|         } | ||||
| #else | ||||
|         for (; d < headDim; ++d) { | ||||
|             dstBasePtr[d * dstStrideDOuter] = srcRowPtr[d] * scaleVal; | ||||
|         } | ||||
| #endif | ||||
|     } | ||||
| } | ||||
| #endif // MNN_SUPPORT_TRANSFORMER_FUSE
 | ||||
| 
 | ||||
| 
 | ||||
| #ifndef MNN_USE_NEON | ||||
| void MNNGetMatMulPackMode(int* eP, int *lP, int* hP) { | ||||
|     *eP = 16; | ||||
|  | @ -2619,30 +2702,54 @@ void MNNExpC8(float* dest, const float* source, float* offset, const float* para | |||
|     offset[3] = summer; | ||||
| } | ||||
| 
 | ||||
| void MNNSoftmax(float* dest, const float* source, size_t size) { | ||||
|     float maxValue = ALIMAX(source[0], source[1]); | ||||
|     for (int i = 2; i < size; ++i) { | ||||
|         maxValue = ALIMAX(maxValue, source[i]); | ||||
|     } | ||||
|     float xLimit = 87, param = 0.6931471805599453, sumValue = 0.f; | ||||
|     for (int i = 0; i < size; ++i) { | ||||
|         auto x         = source[i] - maxValue; | ||||
|         x = x > -xLimit ? x : -xLimit; | ||||
|         x = x < xLimit ? x : xLimit; | ||||
| void MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) { | ||||
|     for (int k = 0; k < outside; ++k) { | ||||
|         auto source = input + k * reduceSize; | ||||
|         auto dest = softmaxDst + k * reduceSize; | ||||
| 
 | ||||
|         int div        = (x / param); | ||||
|         int div2       = (div + 127) << 23; | ||||
|         auto xReamin   = x - div * param; | ||||
|         float expBasic = *(float*)(&div2); | ||||
|         float oldMax = source[0]; | ||||
|         if (runningMax) { | ||||
|             oldMax = runningMax[k]; | ||||
|         } | ||||
| 
 | ||||
|         auto t         = xReamin; | ||||
|         auto expRemain = ((((1.0f / 120 * t + 1.0f / 24) * t + 1.0f / 6) * t + 0.5f) * t + 1.0f) * t + 1.0f; | ||||
|         dest[i]  = expBasic * expRemain; | ||||
|         sumValue += dest[i]; | ||||
|     } | ||||
|     sumValue = 1.f / sumValue; | ||||
|     for (int i = 0; i < size; ++i) { | ||||
|         dest[i] *= sumValue; | ||||
|         // find max value of current block
 | ||||
|         float blockMax =source[0]; | ||||
|         for (int i = 1; i < reduceSize; ++i) { | ||||
|             blockMax = ALIMAX(blockMax, source[i]); | ||||
|         } | ||||
|         float newMax = ALIMAX(oldMax, blockMax); | ||||
|          | ||||
|         // caculate block's expr(xi-newmax) and update runningMax
 | ||||
|         float xLimit = 87, param = 0.6931471805599453; | ||||
|         float blockSum = 0.f; | ||||
|         for (int i = 0; i < reduceSize; ++i) { | ||||
|             auto x         = source[i] - newMax; | ||||
|             x = x > -xLimit ? x : -xLimit; | ||||
|             x = x < xLimit ? x : xLimit; | ||||
| 
 | ||||
|             int div        = (x / param); | ||||
|             int div2       = (div + 127) << 23; | ||||
|             auto xReamin   = x - div * param; | ||||
|             float expBasic = *(float*)(&div2); | ||||
| 
 | ||||
|             auto t         = xReamin; | ||||
|             auto expRemain = ((((1.0f / 120 * t + 1.0f / 24) * t + 1.0f / 6) * t + 0.5f) * t + 1.0f) * t + 1.0f; | ||||
|             dest[i]  = expBasic * expRemain; | ||||
|             blockSum += dest[i]; | ||||
|         } | ||||
| 
 | ||||
|         if (runningMax != nullptr && runningSum != nullptr && updateScale != nullptr) { | ||||
|             // update runningSum, runningMax, scale=expf(oldMax-newMax)
 | ||||
|             runningSum[k] = runningSum[k] * expf(oldMax - newMax) + blockSum; | ||||
|             runningMax[k] = newMax; | ||||
|             updateScale[k] = expf(oldMax - newMax); | ||||
|         } else { | ||||
|             // Normalize
 | ||||
|             auto scale = 1.f / blockSum; | ||||
|             for (int i = 0; i < reduceSize; ++i) { | ||||
|                 dest[i] *= scale; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
|  | @ -2657,6 +2764,68 @@ void MNNReluInt8(int8_t* dst, const int8_t* src, size_t size, ssize_t zeroPoint) | |||
| } | ||||
| #endif // no MNN_USE_SSE
 | ||||
| 
 | ||||
| void MNNExp(float* dst, const float* src, float* offset, size_t dataSize) { | ||||
|     int countC8        = static_cast<int32_t>(dataSize) / 8; | ||||
|     int remain = static_cast<int32_t>(dataSize) % 8; | ||||
|     static const float parameters[] = { | ||||
|         (float)logf(2.0f), 1.0f / (float)logf(2.0f), 0.25f, 1.0f, 0.5f, 1.0f / 6.0f, 1.0f / 24.0f, 1.0f / 120.0f}; | ||||
|     if (countC8 > 0) { | ||||
|         // Align to eight so asm is easier to write
 | ||||
|         MNNExpC8(dst, src, offset, parameters, countC8); | ||||
|     } | ||||
|     if (remain > 0) { | ||||
|         auto param = parameters[0]; | ||||
|         float xLimit = 87; | ||||
|         float summer = offset[3]; | ||||
|         auto source = src + countC8 * 8; | ||||
|         auto dest = dst + countC8 * 8; | ||||
|         for (int i = 0; i < remain; ++i) { | ||||
|             auto x         = source[i] * offset[0] + offset[2]; | ||||
|             x = ALIMAX(x, -xLimit); | ||||
|             x = ALIMIN(x, xLimit); | ||||
|             int div        = (x * parameters[1]); | ||||
|             int div2       = (div + 127) << 23; | ||||
|             auto xReamin   = x - div * param; | ||||
|             float expBasic = *(float*)(&div2); | ||||
|             auto t = xReamin * 0.25f; | ||||
|             auto expRemain = | ||||
|             ((((parameters[7] * t + parameters[6]) * t + parameters[5]) * t + parameters[4]) * t + 1.0f) * t + | ||||
|                 1.0f; | ||||
|             expRemain = expRemain * expRemain; | ||||
|             expRemain = expRemain * expRemain; | ||||
|             dest[i] = expBasic * expRemain + offset[1]; | ||||
|             summer+= dest[i]; | ||||
|         } | ||||
|         offset[3] = summer; | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| void packKvCache(float* dst, const float* src, size_t seqLen, size_t kvSeqLen, size_t eP) { | ||||
|     if (seqLen == 0 || kvSeqLen == 0) { | ||||
|         return; | ||||
|     } | ||||
| 
 | ||||
|     const size_t dstSOuterStride = kvSeqLen * eP; | ||||
| 
 | ||||
|     for (size_t sBase = 0; sBase < seqLen; sBase += eP) { | ||||
|         const size_t numRowsInBlock = std::min(eP, seqLen - sBase); | ||||
|         const size_t sOuter = sBase / eP; | ||||
| 
 | ||||
|         float* dstSOuterBase = dst + sOuter * dstSOuterStride; | ||||
| 
 | ||||
|         for (size_t k = 0; k < kvSeqLen; ++k) { | ||||
|             float* dstColStart = dstSOuterBase + k * eP; | ||||
| 
 | ||||
|             for (size_t sInner = 0; sInner < numRowsInBlock; ++sInner) { | ||||
|                 const size_t s = sBase + sInner; | ||||
| 
 | ||||
|                 const float value = src[s * kvSeqLen + k]; | ||||
|                 dstColStart[sInner] = value; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| void MNNMaxFloat(float* input, float* maxBuffer, int32_t inputCountUnit) { | ||||
|     for (int i = 0; i < inputCountUnit; i++) { | ||||
|         for (int j = 0; j < UNIT; j++) { | ||||
|  | @ -3173,42 +3342,6 @@ void MNNPackTranspose(float* dst, const float* src, size_t area, size_t depth, i | |||
|     } | ||||
| } | ||||
| 
 | ||||
| void MNNExp(float* dst, const float* src, float* offset, size_t dataSize) { | ||||
|     int countC8        = static_cast<int32_t>(dataSize) / 8; | ||||
|     int remain = static_cast<int32_t>(dataSize) % 8; | ||||
|     static const float parameters[] = { | ||||
|         (float)logf(2.0f), 1.0f / (float)logf(2.0f), 0.25f, 1.0f, 0.5f, 1.0f / 6.0f, 1.0f / 24.0f, 1.0f / 120.0f}; | ||||
|     if (countC8 > 0) { | ||||
|         // Align to eight so asm is easier to write
 | ||||
|         MNNExpC8(dst, src, offset, parameters, countC8); | ||||
|     } | ||||
|     if (remain > 0) { | ||||
|         auto param = parameters[0]; | ||||
|         float xLimit = 87; | ||||
|         float summer = offset[3]; | ||||
|         auto source = src + countC8 * 8; | ||||
|         auto dest = dst + countC8 * 8; | ||||
|         for (int i = 0; i < remain; ++i) { | ||||
|             auto x         = source[i] * offset[0] + offset[2]; | ||||
|             x = ALIMAX(x, -xLimit); | ||||
|             x = ALIMIN(x, xLimit); | ||||
|             int div        = (x * parameters[1]); | ||||
|             int div2       = (div + 127) << 23; | ||||
|             auto xReamin   = x - div * param; | ||||
|             float expBasic = *(float*)(&div2); | ||||
|             auto t = xReamin * 0.25f; | ||||
|             auto expRemain = | ||||
|             ((((parameters[7] * t + parameters[6]) * t + parameters[5]) * t + parameters[4]) * t + 1.0f) * t + | ||||
|                 1.0f; | ||||
|             expRemain = expRemain * expRemain; | ||||
|             expRemain = expRemain * expRemain; | ||||
|             dest[i] = expBasic * expRemain + offset[1]; | ||||
|             summer+= dest[i]; | ||||
|         } | ||||
|         offset[3] = summer; | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| // Lambert's series with 7 divisions
 | ||||
| // reference from
 | ||||
| // https://varietyofsound.wordpress.com/2011/02/14/efficient-tanh-computation-using-lamberts-continued-fraction/
 | ||||
|  | @ -4011,6 +4144,7 @@ void MNNCoreFunctionInit() { | |||
|     gCoreFunction->chooseWinoDestUnrollTransform = WinogradFunction::chooseWinoDestUnrollTransform; | ||||
|     gCoreFunction->MNNDeconvRunForLineDepthwise = MNNDeconvRunForLineDepthwise; | ||||
|     gCoreFunction->MNNDeconvRunForUnitDepthWise = MNNDeconvRunForUnitDepthWise; | ||||
|     gCoreFunction->MNNSoftmax = MNNSoftmax; | ||||
| #ifdef MNN_USE_NEON | ||||
|     gCoreFunction->MNNDepthwiseConvFastKernel = MNNDepthwiseConvFastKernel; | ||||
| #endif | ||||
|  | @ -4019,6 +4153,12 @@ void MNNCoreFunctionInit() { | |||
| #ifdef MNN_SUPPORT_QUANT_EXTEND | ||||
|     gCoreFunction->MNNSelectUnaryFunctionForInt8 = CPUUnary::selectForInt8; | ||||
| #endif | ||||
| 
 | ||||
| #ifdef MNN_SUPPORT_TRANSFORMER_FUSE | ||||
|     gCoreFunction->MNNAttenPackAndScaleSingleHead = MNNAttenPackAndScaleSingleHead; | ||||
|     gCoreFunction->MNNFlashAttentionUpdateBlockOutput = MNNFlashAttentionUpdateBlockOutput; | ||||
| #endif | ||||
| 
 | ||||
|     gCoreFunction->MNNReluWithSlopeChannel = MNNReluWithSlopeChannel; | ||||
|     gCoreFunction->MNNPoolingAvg = (decltype(gCoreFunction->MNNPoolingAvg))(poolingAvg<float, Vec4, 4>); | ||||
|     // Set min value as 1 << 24
 | ||||
|  |  | |||
|  | @ -32,12 +32,16 @@ void MNNAbsMaxFP32_Pack4(const float* source, float* absmax, size_t src_depth_qu | |||
| void MNNAbsMaxFP32_Pack8(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack); | ||||
| void MNNQuantScaleFP32(float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch); | ||||
| void MNNDynamicUpdateConvBiasScale(float* newbias, float* oldbias, float* weightKernelSum, float* inputZero, size_t ocQuad); | ||||
| #endif | ||||
| #endif // MNN_LOW_MEMORY
 | ||||
| 
 | ||||
| #ifdef MNN_SME2 | ||||
| void MNNPackedMatMulRemainFP32_SME2(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); | ||||
| #endif | ||||
| 
 | ||||
| #endif // __aarch64__
 | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| #endif | ||||
| #endif | ||||
| void MNNFp32ToFp8(uint8_t* dst, const float* src, size_t size); | ||||
| void MNNFp8ToFp32(float* dst, const uint8_t* src, size_t size); | ||||
| void MNNFp16ToFp8(uint8_t* dst, const uint16_t* src, size_t size); | ||||
|  | @ -117,7 +121,7 @@ void MNNReluWithSlopeCommon(float* dst, const float* src, size_t size, float slo | |||
| void MNNHardSwishCommon(float* dst, const float* src, size_t size); | ||||
| void MNNGeluCommon(float* dst, const float* src, size_t size); | ||||
| void MNNGeluStandardCommon(float* dst, const float* src, size_t size); | ||||
| void MNNSoftmax(float* dest, const float* source, size_t size); | ||||
| void MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize); | ||||
| void MNNNorm(float* dest, const float* source, const float *gamma, const float *beta, float epsilon, size_t size, bool RMSNorm = false); | ||||
| 
 | ||||
| // Get Pack for MatMul's e , l , h , the pack number must be 1 or 4 * n
 | ||||
|  | @ -187,6 +191,8 @@ void MNNComputeMatMulForE_1(const float* A, const float* B, float* C, const floa | |||
| void MNNCopyC4Int16WithStride(const float* sourceF, float* destF, size_t srcStride, size_t dstStride, size_t count); | ||||
| void MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count); | ||||
| 
 | ||||
| void packKvCache(float* dst, const float* src, size_t seqLen, size_t kvSeqLen, size_t eP); | ||||
| 
 | ||||
| struct SumByAxisParams { | ||||
|     ssize_t kernelCountUnitDouble; | ||||
|     ssize_t unitColBufferSize; | ||||
|  | @ -401,6 +407,13 @@ struct CoreFunctions { | |||
|     void(*MNNSumWeightInt8)(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP); | ||||
|     void(*MNNSumWeightInt8SmeHp64)(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP); | ||||
| 
 | ||||
|     // Attention
 | ||||
|     void(*MNNAttenUnpackAndConvertFp16)(float* dst, float* src, size_t depth, size_t planesize, int pack); | ||||
|     void(*MNNAttenPackAndConvertFp32)(float* dst, float* src, const int32_t* units, size_t depth, size_t planesize); | ||||
|     void(*MNNAttenPackAndScaleSingleHead)(float* dst, const float* srcHeadBase, size_t srcRowStride, const float* scale, const int32_t* units, size_t seqLen, size_t headDim); | ||||
|     void(*MNNFlashAttentionUpdateBlockOutput)(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes); | ||||
|     void(*MNNSoftmax)(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize); | ||||
| 
 | ||||
|     MatmulRelatedFunctions int8MatmulRelatedFunctions; | ||||
|     MatmulRelatedFunctions sme2Int8MatmulRelatedFuncionsHp32; | ||||
| }; | ||||
|  |  | |||
|  | @ -805,17 +805,16 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector<Tensor*>& input | |||
|         return OUT_OF_MEMORY; | ||||
|     } | ||||
|     if (mOnlineReorderWeightSme && planeSize > 1) { // only prefill need
 | ||||
|         int dstHp = 32; // if mOnlineReorderWeightSme==true, UNIT is 64 when model loaded.
 | ||||
|         int weightlenNew = ROUND_UP(outC, dstHp) * mBlockNum * ROUND_UP(ic / mBlockNum, SRC_UNIT) * kernelCount * SRC_UNIT * dstHp; | ||||
|         int weightlenNew = ROUND_UP(outC, SME_DECODE_MAXHP) * mBlockNum * ROUND_UP(ic / mBlockNum, SRC_UNIT) * kernelCount; | ||||
|         if (mResourceInt8->mActBits == 4) { | ||||
|             weightlenNew /= 2; | ||||
|         } | ||||
|         mWeight4Prefill = bufferAlloc->alloc(weightlenNew + 2 * ROUND_UP(outC, dstHp) * QUANT_INFO_BYTES); | ||||
|         mWeight4Prefill = bufferAlloc->alloc(weightlenNew + 2 * mBlockNum * ROUND_UP(outC, SME_DECODE_MAXHP) * QUANT_INFO_BYTES); | ||||
|         if (mWeight4Prefill.invalid()) { | ||||
|             return OUT_OF_MEMORY; | ||||
|         } | ||||
|         if (mInputBlockNum > 1) { // only in this case, need to use weight_kernel_sum
 | ||||
|             mWeightKernelSum4Prefill = bufferAlloc->alloc(ROUND_UP(outC, 32) * mBlockNum * sizeof(float)); | ||||
|             mWeightKernelSum4Prefill = bufferAlloc->alloc(ROUND_UP(outC, SME_DECODE_MAXHP) * mBlockNum * sizeof(float)); | ||||
|             if (mWeightKernelSum4Prefill.invalid()) { | ||||
|                 return OUT_OF_MEMORY; | ||||
|             } | ||||
|  |  | |||
|  | @ -63,6 +63,7 @@ bool AVX2Functions::init(int cpuFlags) { | |||
|     // Dynamic Quant
 | ||||
|     coreFunction->MNNCountMaxMinValue = _AVX_MNNCountMinMaxValue; | ||||
| 
 | ||||
|     coreFunction->MNNSoftmax = _AVX_MNNSoftmax; | ||||
| 
 | ||||
|     // For Packed Functions
 | ||||
|     coreFunction->pack = 8; | ||||
|  |  | |||
|  | @ -24,7 +24,7 @@ struct FunctionGroup { | |||
|     int lP                                                                                       = 1; | ||||
|     int hP                                                                                       = 4; | ||||
|     void (*MNNExpC8)(float* dest, const float* source, float* offset, const float* parameters, size_t countC8) = _SSE_MNNExpC8; | ||||
|     void (*MNNSoftmax)(float* dest, const float* source, size_t size) = _SSE_MNNSoftmax; | ||||
|     void (*MNNSoftmax)(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) = _SSE_MNNSoftmax; | ||||
|     void (*MNNReluInt8)(int8_t* dst, const int8_t* src, size_t size, ssize_t zeroPoint) = _SSE_MNNReluInt8; | ||||
|     void (*MNNHardSwish)(float* dst, const float* src, size_t size) = _SSE_MNNHardSwish; | ||||
|     void (*MNNGelu)(float* dst, const float* src, size_t size, float* parameters) = _SSE_MNNGelu; | ||||
|  | @ -65,6 +65,8 @@ void MNNFunctionInit() { | |||
|         coreFunction->MNNPackForMatMul_B    = _SSE_MNNPackForMatMul_B; | ||||
|         // Dynamic Quant
 | ||||
|         coreFunction->MNNCountMaxMinValue = _SSE_MNNCountMinMaxValue; | ||||
| 
 | ||||
|         coreFunction->MNNSoftmax = _SSE_MNNSoftmax; | ||||
|     } | ||||
| #ifdef MNN_USE_AVX | ||||
|     if (cpuFlags & libyuv::kCpuHasAVX2) { | ||||
|  | @ -205,8 +207,8 @@ void MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count) { | |||
|     _SSE_MNNInt8ToInt16(dest, source, count); | ||||
| } | ||||
| 
 | ||||
| void MNNSoftmax(float* dest, const float* source, size_t size) { | ||||
|     gFunc.MNNSoftmax(dest, source, size); | ||||
| void MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) { | ||||
|     gFunc.MNNSoftmax(softmaxDst, input, runningMax, runningSum, updateScale, outside, reduceSize); | ||||
| } | ||||
| 
 | ||||
| void MNNNorm(float* dest, const float* source, const float *gamma, const float *beta, float epsilon, size_t size, bool RMSNorm) { | ||||
|  |  | |||
|  | @ -53,7 +53,7 @@ void _AVX_MNNAsyQuantInfo(float* scale, float* bias, float* qscale, float* qbias | |||
| void _AVX_MNNPackC4ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el); | ||||
| 
 | ||||
| void _AVX_MNNExpC8(float* dest, const float* source, float* offset, const float* parameters, size_t countC8); | ||||
| void _AVX_MNNSoftmax(float* dest, const float* source, size_t size); | ||||
| void _AVX_MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize); | ||||
| void _AVX_MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minV, ssize_t maxV, const float* zeroPoint, ssize_t quanParamVec); | ||||
| void _AVX_MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* scale, size_t sizeQuad, const float* zeroPoint, ssize_t quanParamVec); | ||||
| void _AVX_MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dstO, const int8_t* srcO, const int8_t* weightO, const QuanPostTreatParameters* parameters, size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, int8_t* idxOrder); | ||||
|  |  | |||
|  | @ -117,103 +117,71 @@ void _AVX_MNNExpC8(float* dest, const float* source, float* offset, const float* | |||
| } | ||||
| 
 | ||||
| 
 | ||||
| void _AVX_MNNSoftmax(float* dest, const float* source, size_t size) { | ||||
|     float tmpfloat8[8]; | ||||
|     int count  = size / 8; | ||||
|     int remain = count * 8; | ||||
|     // step 1: get maxValue
 | ||||
|     float maxValue = source[0]; | ||||
|     if (count > 0) { | ||||
|         auto maxVal = _mm256_loadu_ps(source); | ||||
|         for (int i = 1; i < count; i++) { | ||||
|             maxVal = _mm256_max_ps(maxVal, _mm256_loadu_ps(source + i * 8)); | ||||
|         } | ||||
|         _mm256_storeu_ps(tmpfloat8, maxVal); | ||||
|         maxValue = tmpfloat8[0] > tmpfloat8[1] ? tmpfloat8[0] : tmpfloat8[1]; | ||||
|         for (int i = 2; i < 8; i++) { | ||||
|             maxValue = maxValue > tmpfloat8[i] ? maxValue : tmpfloat8[i]; | ||||
|         } | ||||
|     } | ||||
|     for (int i = remain; i < size; i++) { | ||||
|         maxValue = maxValue > source[i] ? maxValue : source[i]; | ||||
|     } | ||||
| void _AVX_MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) { | ||||
|     const float xLimit = 87.0f; | ||||
|     const float param = 0.6931471805599453f; // ln(2)
 | ||||
|     const float inv_param = 1.0f / param; | ||||
|     const int32_t exp_offset = 127; | ||||
|     const float exp_scale = 8388608.0f; // 2^23
 | ||||
| 
 | ||||
|     // step 2: get exp(x - maxValue) and sum(exp(x - maxValue))
 | ||||
|     float sumValue = 0.f; | ||||
|     if (count > 0) { | ||||
|         auto sumVal = _mm256_set1_ps(0.f); | ||||
|         auto p0    = _mm256_set1_ps(0.6931471805599453); | ||||
|         auto p1    = _mm256_set1_ps(1.4426950408889634); | ||||
|         auto p2    = _mm256_set1_ps(1.f); | ||||
|         auto p3    = _mm256_set1_ps(1.f); | ||||
|         auto p4    = _mm256_set1_ps(0.5); | ||||
|         auto p5    = _mm256_set1_ps(0.1666666666666666); | ||||
|         auto p6    = _mm256_set1_ps(0.041666666666666664); | ||||
|         auto p7    = _mm256_set1_ps(0.008333333333333333); | ||||
|         auto xMax  = _mm256_set1_ps(87); | ||||
|         auto xMin  = _mm256_set1_ps(-87); | ||||
|         auto basic = _mm256_set1_epi32(1 << 23); | ||||
|         auto temp127 = _mm256_set1_epi32(127); | ||||
|         for (int i = 0; i < count; ++i) { | ||||
|             auto x            = _mm256_sub_ps(_mm256_loadu_ps(source + i * 8), _mm256_set1_ps(maxValue)); | ||||
|             x                 = _mm256_max_ps(x, xMin); | ||||
|             x                 = _mm256_min_ps(x, xMax); | ||||
|             auto div          = _mm256_mul_ps(x, p1); | ||||
|             auto divInt       = _mm256_cvtps_epi32(div); | ||||
|             div               = _mm256_cvtepi32_ps(divInt); | ||||
|             auto div2         = _mm256_add_epi32(divInt, temp127); | ||||
|             div2 = _mm256_mullo_epi32(div2, basic); | ||||
|             auto expBasic  = _mm256_castsi256_ps(div2); | ||||
|             auto xReamin   = _mm256_sub_ps(x, _mm256_mul_ps(div, p0)); | ||||
|             auto t         = xReamin; | ||||
|             auto c0        = _mm256_mul_ps(p7, t); | ||||
|             auto c1        = _mm256_add_ps(c0, p6); | ||||
|             auto c2        = _mm256_mul_ps(c1, t); | ||||
|             auto c3        = _mm256_add_ps(c2, p5); | ||||
|             auto c4        = _mm256_mul_ps(c3, t); | ||||
|             auto c5        = _mm256_add_ps(c4, p4); | ||||
|             auto c6        = _mm256_mul_ps(c5, t); | ||||
|             auto c7        = _mm256_add_ps(c6, p3); | ||||
|             auto c8        = _mm256_mul_ps(c7, t); | ||||
|             auto c9        = _mm256_add_ps(c8, p2); | ||||
|             auto expRemain = c9; | ||||
|             auto expRes    = _mm256_mul_ps(expBasic, expRemain); | ||||
|             sumVal         = _mm256_add_ps(expRes, sumVal); | ||||
|             _mm256_storeu_ps(dest + 8 * i, expRes); | ||||
|         } | ||||
|         _mm256_storeu_ps(tmpfloat8, sumVal); | ||||
|         for (int i = 0; i < 8; i++) { | ||||
|             sumValue += tmpfloat8[i]; | ||||
|         } | ||||
|     } | ||||
|     auto param = 0.6931471805599453; | ||||
|     float xLimit = 87; | ||||
|     for (int i = remain; i < size; i++) { | ||||
|         auto x         = source[i] - maxValue; | ||||
|         x = x > -xLimit ? x : -xLimit; | ||||
|         x = x < xLimit ? x : xLimit; | ||||
|     for (int k = 0; k < outside; ++k) { | ||||
|         float* source = input + k * reduceSize; | ||||
|         float* dest = softmaxDst + k * reduceSize; | ||||
| 
 | ||||
|         int div        = (x / param); | ||||
|         int div2       = (div + 127) << 23; | ||||
|         auto xReamin   = x - div * param; | ||||
|         float expBasic = *(float*)(&div2); | ||||
|         float tmpfloat8[8]; | ||||
|         int count  = reduceSize/ 8; | ||||
|         int remain = count * 8; | ||||
|         // step 1: get maxValue
 | ||||
|         float maxValue = source[0]; | ||||
| 
 | ||||
|         auto t         = xReamin; | ||||
|         auto expRemain = ((((1.0f / 120 * t + 1.0f / 24) * t + 1.0f / 6) * t + 0.5f) * t + 1.0f) * t + 1.0f; | ||||
|         dest[i]  = expBasic * expRemain; | ||||
|         sumValue += dest[i]; | ||||
|     } | ||||
|     // step 3: get x / sum and store
 | ||||
|     for (int i = 0; i < count; ++i) { | ||||
|         // using  1 / ((1 / x) * sum) instead x * (1 / sum) or x / sum for some bugs in intel cpu
 | ||||
|         auto x = _mm256_rcp_ps(_mm256_loadu_ps(dest + 8 * i)); | ||||
|         auto y = _mm256_set1_ps(sumValue); | ||||
|         auto z = _mm256_rcp_ps(_mm256_mul_ps(x, y)); | ||||
|         _mm256_storeu_ps(dest + 8 * i, z); | ||||
|     } | ||||
|     sumValue = 1.f / sumValue; | ||||
|     for (int i = remain; i < size; i++) { | ||||
|         dest[i] *= sumValue; | ||||
|         float oldMax = maxValue; | ||||
|         if (runningMax) { | ||||
|             oldMax = runningMax[k]; | ||||
|         } | ||||
| 
 | ||||
|         if (count > 0) { | ||||
|             auto maxVal = _mm256_loadu_ps(source); | ||||
|             for (int i = 1; i < count; i++) { | ||||
|                 maxVal = _mm256_max_ps(maxVal, _mm256_loadu_ps(source + i * 8)); | ||||
|             } | ||||
|             _mm256_storeu_ps(tmpfloat8, maxVal); | ||||
|             maxValue = tmpfloat8[0] > tmpfloat8[1] ? tmpfloat8[0] : tmpfloat8[1]; | ||||
|             for (int i = 2; i < 8; i++) { | ||||
|                 maxValue = maxValue > tmpfloat8[i] ? maxValue : tmpfloat8[i]; | ||||
|             } | ||||
|         } | ||||
|         for (int i = remain; i < reduceSize; i++) { | ||||
|             maxValue = maxValue > source[i] ? maxValue : source[i]; | ||||
|         } | ||||
| 
 | ||||
|         float newMax = ALIMAX(oldMax, maxValue); | ||||
| 
 | ||||
|         // step 2: get exp(x - newMax) and sum(exp(x - newMax))
 | ||||
|         float exprOffset[4] = {1.0f, 0.0f, 0.0f, 0.0f }; | ||||
|         exprOffset[2] = -newMax; | ||||
|         MNNExp(dest, source, exprOffset, reduceSize); | ||||
|         float sumValue = exprOffset[3]; | ||||
| 
 | ||||
|         if (runningMax != nullptr && runningSum != nullptr && updateScale != nullptr) { | ||||
|             // === Step 3: Update running variables ===
 | ||||
|             float scale = expf(oldMax - newMax); | ||||
|             runningSum[k] = runningSum[k] * scale + sumValue; | ||||
|             runningMax[k] = newMax; | ||||
|             updateScale[k] = scale; | ||||
|         } else { | ||||
|             // step 3: get x / sum and store
 | ||||
|             for (int i = 0; i < count; ++i) { | ||||
|                 // using  1 / ((1 / x) * sum) instead x * (1 / sum) or x / sum for some bugs in intel cpu
 | ||||
|                 auto x = _mm256_rcp_ps(_mm256_loadu_ps(dest + 8 * i)); | ||||
|                 auto y = _mm256_set1_ps(sumValue); | ||||
|                 auto z = _mm256_rcp_ps(_mm256_mul_ps(x, y)); | ||||
|                 _mm256_storeu_ps(dest + 8 * i, z); | ||||
|             } | ||||
|             auto scale = 1.f / sumValue; | ||||
|             for (int i = remain; i < reduceSize; i++) { | ||||
|                 dest[i] *= scale; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -48,8 +48,45 @@ void _AVX_MNNConvRunForLineDepthwise(float* dst, const float* src, const float* | |||
|                                 size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, size_t height, | ||||
|                                      size_t srcHStep, size_t dstHStep, const float* bias, const float* parameters); | ||||
| void _AVX_MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters); | ||||
| 
 | ||||
| #ifdef MNN_SUPPORT_TRANSFORMER_FUSE | ||||
| void _AVX_MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes); | ||||
| #endif  | ||||
| } | ||||
| 
 | ||||
| #ifdef MNN_SUPPORT_TRANSFORMER_FUSE | ||||
| void _AVX_MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes) { | ||||
|     // source shape:                 [headDim/pack, seqLen, pack]
 | ||||
|     // scale & normalizeScale shape: [seqLen]
 | ||||
|     // dest shape:                   [headDim/pack, seqLen, pack]
 | ||||
|     auto stride0 = plane * pack; | ||||
| 
 | ||||
|     if (idx > 0) { | ||||
|         for (int j = 0; j < depthQuad; ++j) { | ||||
|             for (int i = 0; i < plane; ++i) { | ||||
|                 auto dataNew = Vec::load(src + j * stride0 + i * pack); | ||||
|                 auto dataOld = Vec::load(dst + j * stride0 + i * pack); | ||||
|                 auto s = Vec(scale[i]); | ||||
|                 dataNew = Vec::fma(dataNew, dataOld, s); | ||||
|                 Vec::save(dst + j * stride0 + i * pack, dataNew); | ||||
|             } | ||||
|         } | ||||
|     } else { | ||||
|         memcpy(dst, src, size * bytes); | ||||
|     } | ||||
|     if (idx == kvBlocks - 1) { // if last subBlock, exp(xi)/sum(exp(xi))
 | ||||
|         for (int j = 0; j < depthQuad; ++j) { | ||||
|             for (int i = 0; i < plane; ++i) { | ||||
|                 auto dataNew = Vec::load(dst + j * stride0 + i * pack); | ||||
|                 auto ns = Vec(1.0f / normalizeScale[i]); | ||||
|                 dataNew = dataNew * ns; | ||||
|                 Vec::save(dst + j * stride0 + i * pack, dataNew); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
| #endif | ||||
| 
 | ||||
| 
 | ||||
| void _AVX_MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) { | ||||
|     for (int i = 0; i < count; ++i) { | ||||
|  | @ -728,4 +765,9 @@ void _AVX_ExtraInit(void* functions) { | |||
|     // sparse conv funcs
 | ||||
|     coreFunction->MNNGetSparseMatMulPackMode = _AVX_MNNGetSparseMatMulPackMode; | ||||
|     coreFunction->MNNAdjustOptimalSparseKernel = _AVX_MNNAdjustOptimalSparseKernel; | ||||
| 
 | ||||
|     // attention
 | ||||
| #ifdef MNN_SUPPORT_TRANSFORMER_FUSE | ||||
|     coreFunction->MNNFlashAttentionUpdateBlockOutput = _AVX_MNNFlashAttentionUpdateBlockOutput; | ||||
| #endif | ||||
| } | ||||
|  |  | |||
|  | @ -1064,6 +1064,39 @@ static void _AVX512_MNNAdjustOptimalSparseKernel(int& sparseBlockOC, MNN::CoreFu | |||
|     } | ||||
| } | ||||
| 
 | ||||
| #ifdef MNN_SUPPORT_TRANSFORMER_FUSE | ||||
| void _AVX512_MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes) { | ||||
|     // source shape:                 [headDim/pack, seqLen, pack]
 | ||||
|     // scale & normalizeScale shape: [seqLen]
 | ||||
|     // dest shape:                   [headDim/pack, seqLen, pack]
 | ||||
|     auto stride0 = plane * pack; | ||||
| 
 | ||||
|     if (idx > 0) { | ||||
|         for (int j = 0; j < depthQuad; ++j) { | ||||
|             for (int i = 0; i < plane; ++i) { | ||||
|                 auto dataNew = Vec::load(src + j * stride0 + i * pack); | ||||
|                 auto dataOld = Vec::load(dst + j * stride0 + i * pack); | ||||
|                 auto s = Vec(scale[i]); | ||||
|                 dataNew = Vec::fma(dataNew, dataOld, s); | ||||
|                 Vec::save(dst + j * stride0 + i * pack, dataNew); | ||||
|             } | ||||
|         } | ||||
|     } else { | ||||
|         memcpy(dst, src, size * bytes); | ||||
|     } | ||||
|     if (idx == kvBlocks - 1) { // if last subBlock, exp(xi)/sum(exp(xi))
 | ||||
|         for (int j = 0; j < depthQuad; ++j) { | ||||
|             for (int i = 0; i < plane; ++i) { | ||||
|                 auto dataNew = Vec::load(dst + j * stride0 + i * pack); | ||||
|                 auto ns = Vec(1.0f / normalizeScale[i]); | ||||
|                 dataNew = dataNew * ns; | ||||
|                 Vec::save(dst + j * stride0 + i * pack, dataNew); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
| #endif | ||||
| 
 | ||||
| 
 | ||||
| void _AVX512_ExtraInit(void* functions) { | ||||
|     auto coreFunction = static_cast<MNN::CoreFunctions*>(functions); | ||||
|  | @ -1100,4 +1133,8 @@ void _AVX512_ExtraInit(void* functions) { | |||
| 
 | ||||
|     coreFunction->MNNGetSparseMatMulPackMode = _AVX512_MNNGetSparseMatMulPackMode; | ||||
|     coreFunction->MNNAdjustOptimalSparseKernel = _AVX512_MNNAdjustOptimalSparseKernel; | ||||
| 
 | ||||
| #ifdef MNN_SUPPORT_TRANSFORMER_FUSE  | ||||
|     coreFunction->MNNFlashAttentionUpdateBlockOutput = _AVX512_MNNFlashAttentionUpdateBlockOutput; | ||||
| #endif | ||||
| } | ||||
|  |  | |||
|  | @ -81,7 +81,7 @@ void _SSE_MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count); | |||
| 
 | ||||
| void _SSE_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose); | ||||
| void _SSE_MNNReluInt8(int8_t* dst, const int8_t* src, size_t size, ssize_t zeroPoint); | ||||
| void _SSE_MNNSoftmax(float* dest, const float* source, size_t size); | ||||
| void _SSE_MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize); | ||||
| void _SSE_ExtraInit(void* functions); | ||||
| void _SSE_MNNNorm(float *dst, const float *src, const float *gamma, const float *beta, float epsilon, size_t size, bool RMSNorm); | ||||
| void _SSE_ImageProcessInit(void* functions, int cpuFlags); | ||||
|  |  | |||
|  | @ -69,100 +69,68 @@ void _SSE_MNNExpC8(float* dest, const float* source, float* offset, const float* | |||
|     offset[3] = total; | ||||
| } | ||||
| 
 | ||||
| void _SSE_MNNSoftmax(float* dest, const float* source, size_t size) { | ||||
|     float tmpfloat4[4]; | ||||
|     int count  = static_cast<int32_t>(size / 4); | ||||
|     int remain = count * 4; | ||||
|     // step 1: get maxValue
 | ||||
|     float maxValue = source[0]; | ||||
|     if (count > 0) { | ||||
|         auto maxVal = _mm_loadu_ps(source); | ||||
|         for (int i = 1; i < count; i++) { | ||||
|             maxVal = _mm_max_ps(maxVal, _mm_loadu_ps(source + i * 4)); | ||||
| void _SSE_MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) { | ||||
|     const float xLimit = 87.0f; | ||||
|     const float param = 0.6931471805599453f; // ln(2)
 | ||||
|     const float inv_param = 1.0f / param; | ||||
|     const int32_t exp_offset = 127; | ||||
|     const float exp_scale = 8388608.0f; // 2^23
 | ||||
| 
 | ||||
|     for (int k = 0; k < outside; ++k) { | ||||
|         float* source = input + k * reduceSize; | ||||
|         float* dest = softmaxDst + k * reduceSize; | ||||
| 
 | ||||
|         float tmpfloat4[4]; | ||||
|         int count  = static_cast<int32_t>(reduceSize / 4); | ||||
|         int remain = count * 4; | ||||
|         // step 1: get maxValue
 | ||||
|         float maxValue = source[0]; | ||||
|         float oldMax = maxValue; | ||||
|         if (runningMax) { | ||||
|             oldMax = runningMax[k]; | ||||
|         } | ||||
|         _mm_storeu_ps(tmpfloat4, maxVal); | ||||
|         maxValue = tmpfloat4[0] > tmpfloat4[1] ? tmpfloat4[0] : tmpfloat4[1]; | ||||
|         maxValue = maxValue > tmpfloat4[2] ? maxValue : tmpfloat4[2]; | ||||
|         maxValue = maxValue > tmpfloat4[3] ? maxValue : tmpfloat4[3]; | ||||
|     } | ||||
|     for (int i = remain; i < size; i++) { | ||||
|         maxValue = maxValue > source[i] ? maxValue : source[i]; | ||||
|     } | ||||
| 
 | ||||
|     // step 2: get exp(x - maxValue) and sum(exp(x - maxValue))
 | ||||
|     float sumValue = 0.f; | ||||
|     if (count > 0) { | ||||
|         auto sumVal = _mm_set1_ps(0.f); | ||||
|         auto p0    = _mm_set1_ps(0.6931471805599453); | ||||
|         auto p1    = _mm_set1_ps(1.4426950408889634); | ||||
|         auto p2    = _mm_set1_ps(1.f); | ||||
|         auto p3    = _mm_set1_ps(1.f); | ||||
|         auto p4    = _mm_set1_ps(0.5); | ||||
|         auto p5    = _mm_set1_ps(0.1666666666666666); | ||||
|         auto p6    = _mm_set1_ps(0.041666666666666664); | ||||
|         auto p7    = _mm_set1_ps(0.008333333333333333); | ||||
|         auto xMax  = _mm_set1_ps(87); | ||||
|         auto xMin  = _mm_set1_ps(-87); | ||||
|         // auto basic = _mm_set1_epi32(1 << 23);
 | ||||
|         for (int i = 0; i < count; ++i) { | ||||
|             auto x            = _mm_sub_ps(_mm_loadu_ps(source + i * 4), _mm_set1_ps(maxValue)); | ||||
|             x                 = _mm_max_ps(x, xMin); | ||||
|             x                 = _mm_min_ps(x, xMax); | ||||
|             auto div          = _mm_mul_ps(x, p1); | ||||
|             auto divInt       = _mm_cvtps_epi32(div); | ||||
|             div               = _mm_cvtepi32_ps(divInt); | ||||
|             auto div2         = _mm_add_epi32(divInt, _mm_set1_epi32(127)); | ||||
|             // div2 = _mm_mullo_epi32(div2, basic);
 | ||||
|             div2 = _mm_slli_epi32(div2, 23); | ||||
|             auto expBasic  = _mm_castsi128_ps(div2); | ||||
|             auto xReamin   = _mm_sub_ps(x, _mm_mul_ps(div, p0)); | ||||
|             auto t         = xReamin; | ||||
|             auto c0        = _mm_mul_ps(p7, t); | ||||
|             auto c1        = _mm_add_ps(c0, p6); | ||||
|             auto c2        = _mm_mul_ps(c1, t); | ||||
|             auto c3        = _mm_add_ps(c2, p5); | ||||
|             auto c4        = _mm_mul_ps(c3, t); | ||||
|             auto c5        = _mm_add_ps(c4, p4); | ||||
|             auto c6        = _mm_mul_ps(c5, t); | ||||
|             auto c7        = _mm_add_ps(c6, p3); | ||||
|             auto c8        = _mm_mul_ps(c7, t); | ||||
|             auto c9        = _mm_add_ps(c8, p2); | ||||
|             auto expRemain = c9; | ||||
|             auto expRes    = _mm_mul_ps(expBasic, expRemain); | ||||
|             sumVal         = _mm_add_ps(expRes, sumVal); | ||||
|             _mm_storeu_ps(dest + 4 * i, expRes); | ||||
|         if (count > 0) { | ||||
|             auto maxVal = _mm_loadu_ps(source); | ||||
|             for (int i = 1; i < count; i++) { | ||||
|                 maxVal = _mm_max_ps(maxVal, _mm_loadu_ps(source + i * 4)); | ||||
|             } | ||||
|             _mm_storeu_ps(tmpfloat4, maxVal); | ||||
|             maxValue = tmpfloat4[0] > tmpfloat4[1] ? tmpfloat4[0] : tmpfloat4[1]; | ||||
|             maxValue = maxValue > tmpfloat4[2] ? maxValue : tmpfloat4[2]; | ||||
|             maxValue = maxValue > tmpfloat4[3] ? maxValue : tmpfloat4[3]; | ||||
|         } | ||||
|         for (int i = remain; i < reduceSize; i++) { | ||||
|             maxValue = maxValue > source[i] ? maxValue : source[i]; | ||||
|         } | ||||
|         _mm_storeu_ps(tmpfloat4, sumVal); | ||||
|         sumValue = tmpfloat4[0] + tmpfloat4[1] + tmpfloat4[2] + tmpfloat4[3]; | ||||
|     } | ||||
|     auto param = 0.6931471805599453; | ||||
|     float xLimit = 87; | ||||
|     for (int i = remain; i < size; i++) { | ||||
|         auto x         = source[i] - maxValue; | ||||
|         x = x > -xLimit ? x : -xLimit; | ||||
|         x = x < xLimit ? x : xLimit; | ||||
| 
 | ||||
|         int div        = (x / param); | ||||
|         int div2       = (div + 127) << 23; | ||||
|         auto xReamin   = x - div * param; | ||||
|         float expBasic = *(float*)(&div2); | ||||
|         float newMax = ALIMAX(oldMax, maxValue); | ||||
| 
 | ||||
|         auto t         = xReamin; | ||||
|         auto expRemain = ((((1.0f / 120 * t + 1.0f / 24) * t + 1.0f / 6) * t + 0.5f) * t + 1.0f) * t + 1.0f; | ||||
|         dest[i]  = expBasic * expRemain; | ||||
|         sumValue += dest[i]; | ||||
|     } | ||||
|     // step 3: get x / sum and store
 | ||||
|     for (int i = 0; i < count; ++i) { | ||||
|         // using  1 / ((1 / x) * sum) instead x * (1 / sum) or x / sum for some bugs in intel cpu
 | ||||
|         auto x = _mm_rcp_ps(_mm_loadu_ps(dest + 4 * i)); | ||||
|         auto y = _mm_set1_ps(sumValue); | ||||
|         auto z = _mm_rcp_ps(_mm_mul_ps(x, y)); | ||||
|         _mm_storeu_ps(dest + 4 * i, z); | ||||
|     } | ||||
|     sumValue = 1.f / sumValue; | ||||
|     for (int i = remain; i < size; i++) { | ||||
|         dest[i] *= sumValue; | ||||
|         // step 2: get exp(x - newMax) and sum(exp(x - newMax))
 | ||||
|         float exprOffset[4] = {1.0f, 0.0f, 0.0f, 0.0f }; | ||||
|         exprOffset[2] = -newMax; | ||||
|         MNNExp(dest, source, exprOffset, reduceSize); | ||||
|         float sumValue = exprOffset[3]; | ||||
| 
 | ||||
|         if (runningMax != nullptr && runningSum != nullptr && updateScale != nullptr) { | ||||
|             // === Step 3: Update running variables ===
 | ||||
|             float scale = expf(oldMax - newMax); | ||||
|             runningSum[k] = runningSum[k] * scale + sumValue; | ||||
|             runningMax[k] = newMax; | ||||
|             updateScale[k] = scale; | ||||
|         } else { | ||||
|             // step 3: get x / sum and store
 | ||||
|             for (int i = 0; i < count; ++i) { | ||||
|                 // using  1 / ((1 / x) * sum) instead x * (1 / sum) or x / sum for some bugs in intel cpu
 | ||||
|                 auto x = _mm_rcp_ps(_mm_loadu_ps(dest + 4 * i)); | ||||
|                 auto y = _mm_set1_ps(sumValue); | ||||
|                 auto z = _mm_rcp_ps(_mm_mul_ps(x, y)); | ||||
|                 _mm_storeu_ps(dest + 4 * i, z); | ||||
|             } | ||||
|             auto scale = 1.f / sumValue; | ||||
|             for (int i = remain; i < reduceSize; i++) { | ||||
|                 dest[i] *= scale; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -360,26 +360,23 @@ bool BufferConvertor::convertToNC4HW4Buffer(const Tensor *buffer, const OpenCLBu | |||
|         default: | ||||
|             break; | ||||
|     } | ||||
|     if (mBufferToImageKernel.get() == nullptr || mBufferToImageKernelName != kernelName) { | ||||
|         mBufferToImageKernelName = kernelName; | ||||
|         std::set<std::string> buildOptions; | ||||
|         if(needTrans) { | ||||
|             //buildOptions.emplace("-DBUFFER_FORMAT_INP_TRANS");
 | ||||
|             kernelName += "_floatin"; | ||||
|         } | ||||
| #ifdef MNN_LOW_MEMORY | ||||
|         if (lowMemory) { | ||||
|             if (quantBit == 8) { | ||||
|                 // int8 case
 | ||||
|                 buildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT8"); | ||||
|             } else if (quantBit == 4){ | ||||
|                 // int4 case
 | ||||
|                 buildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT4"); | ||||
|             } else {/* More types to be supported. */} | ||||
|         } | ||||
| #endif | ||||
|         mBufferToImageKernel = runtime->buildKernelWithCache(kernelFile, kernelName, buildOptions, precision, buffer, image); | ||||
|     std::set<std::string> buildOptions; | ||||
|     if(needTrans) { | ||||
|         //buildOptions.emplace("-DBUFFER_FORMAT_INP_TRANS");
 | ||||
|         kernelName += "_floatin"; | ||||
|     } | ||||
| #ifdef MNN_LOW_MEMORY | ||||
|     if (lowMemory) { | ||||
|         if (quantBit == 8) { | ||||
|             // int8 case
 | ||||
|             buildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT8"); | ||||
|         } else if (quantBit == 4){ | ||||
|             // int4 case
 | ||||
|             buildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT4"); | ||||
|         } else {/* More types to be supported. */} | ||||
|     } | ||||
| #endif | ||||
|     mBufferToImageKernel = runtime->buildKernelWithCache(kernelFile, kernelName, buildOptions, precision, buffer, image); | ||||
|     auto kernel = mBufferToImageKernel->get(); | ||||
| 
 | ||||
|     uint32_t idx = 0; | ||||
|  |  | |||
|  | @ -43,7 +43,7 @@ public: | |||
|     explicit BufferConvertor(OpenCLRuntime *opencl_runtime) : mOpenCLRuntime(opencl_runtime) { | ||||
|     } | ||||
|     bool convertToNC4HW4Buffer(const Tensor *input, const OpenCLBufferFormat type, Tensor *output, int precision, | ||||
|                                bool needTrans, bool needWait = false, bool lowMemory = false, int quantBit = 0); | ||||
|                                bool needTrans, bool needWait = true, bool lowMemory = false, int quantBit = 0); | ||||
| 
 | ||||
| private: | ||||
|     OpenCLRuntime *mOpenCLRuntime; | ||||
|  |  | |||
|  | @ -48,7 +48,12 @@ CLRuntime::CLRuntime(const Backend::Info& info){ | |||
|         mMemory    = mInfo.user->memory; | ||||
|     } | ||||
|      | ||||
|         mOpenCLRuntime.reset(new OpenCLRuntime(platform_size, platform_id, device_id, context_ptr, hint())); | ||||
|     // protect
 | ||||
|     if(mPrecision > 2 || mPrecision < 0){ | ||||
|         mPrecision = BackendConfig::Precision_High; | ||||
|     } | ||||
|      | ||||
|     mOpenCLRuntime.reset(new OpenCLRuntime(platform_size, platform_id, device_id, context_ptr, hint())); | ||||
|      | ||||
|     //Whether runtimeError
 | ||||
|     mCLRuntimeError = mOpenCLRuntime->isCreateError(); | ||||
|  | @ -206,6 +211,10 @@ Backend* CLRuntime::onCreate(const BackendConfig* config, Backend* origin) const | |||
|         precision = config->precision; | ||||
|         memory = config->memory; | ||||
|     } | ||||
|     // protect
 | ||||
|     if(precision > 2 || precision < 0){ | ||||
|         precision = BackendConfig::Precision_High; | ||||
|     } | ||||
|     auto backend = new OpenCLBackend(precision, memory, mInfo.gpuMode, mImagePool, mBufferPool, this); | ||||
|     backend->setMetaPtr(pMeta); | ||||
|     return backend; | ||||
|  | @ -246,6 +255,10 @@ OpenCLBackend::OpenCLBackend(BackendConfig::PrecisionMode precision, BackendConf | |||
|     } else{ | ||||
|         mPrecision = BackendConfig::Precision_High; | ||||
|     } | ||||
|     // protect
 | ||||
|     if(mPrecision > 2 || mPrecision < 0){ | ||||
|         mPrecision = BackendConfig::Precision_High; | ||||
|     } | ||||
|     mMemory = memory; | ||||
|     // set tuneLevel, memtype, record mode
 | ||||
|     setGpuMode(gpuMode); | ||||
|  |  | |||
|  | @ -18,6 +18,8 @@ struct QnnContext { | |||
|     Qnn_LogHandle_t logHandle = nullptr; | ||||
|     Qnn_BackendHandle_t backendHandle = nullptr; | ||||
|     Qnn_DeviceHandle_t deviceHandle = nullptr; | ||||
|     int soc_id; | ||||
|     int dsp_arch; | ||||
| }; | ||||
| 
 | ||||
| static QnnContext gContext; | ||||
|  | @ -94,6 +96,7 @@ bool PluginShapeRaw::compute(InferShapeContext* ctx) { | |||
|             auto dst = ctx->output(i); | ||||
|             std::string key = prefix + std::to_string(i); | ||||
|             auto attr = ctx->getAttr(key.c_str()); | ||||
| 
 | ||||
|             if (nullptr == attr || nullptr == attr->tensor()) { | ||||
|                 MNN_ERROR("MNN_QNN: Failed to find raw shape %s.\n", key.c_str()); | ||||
|                 return false; | ||||
|  | @ -886,7 +889,12 @@ ErrorCode QnnBackend::onResizeEnd() { | |||
|     #ifdef QNN_VERBOSE | ||||
|     MNN_PRINT("start finalize\n"); | ||||
|     #endif | ||||
|     buildOutputDequant(); | ||||
|     finalizeGraph(); | ||||
|     for(auto func : mReleaseFunc){ | ||||
|         func(); | ||||
|     } | ||||
|     mReleaseFunc.clear(); | ||||
|     #ifdef QNN_VERBOSE | ||||
|     MNN_PRINT("end finalize\n"); | ||||
|     #endif | ||||
|  | @ -915,7 +923,14 @@ Backend::MemObj* QnnBackend::onAcquire(const Tensor* tensor, StorageType storage | |||
|     } | ||||
| 
 | ||||
|     Qnn_DataType_t tDataType; | ||||
|     MNN_ASSERT((tensor->getType().code == halide_type_float) || (tensor->getType().code == halide_type_int && tensor->getType().bits == 32)); | ||||
|     Qnn_QuantizeParams_t tQuantizeParams{}; | ||||
|     tQuantizeParams.encodingDefinition = QNN_DEFINITION_UNDEFINED; | ||||
|     tQuantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED; | ||||
|     Qnn_ScaleOffset_t tScaleOffsetEncoding; | ||||
|     tScaleOffsetEncoding.scale = 0.0f; | ||||
|     tScaleOffsetEncoding.offset = 0; | ||||
|     auto quant = TensorUtils::getDescribe(tensor)->quantAttr.get(); | ||||
|     //MNN_ASSERT((tensor->getType().code == halide_type_float) || (tensor->getType().code == halide_type_int && tensor->getType().bits == 32));
 | ||||
|     if (mUseFP16 && tensor->getType().code == halide_type_float) { | ||||
|         tDataType = QNN_DATATYPE_FLOAT_16; | ||||
|     } else if (tensor->getType().code == halide_type_float) { | ||||
|  | @ -926,13 +941,19 @@ Backend::MemObj* QnnBackend::onAcquire(const Tensor* tensor, StorageType storage | |||
|         MNN_PRINT("MNN_QNN: Not supported data type in <QnnBackend::onAcquire>.\n"); | ||||
|         return nullptr; | ||||
|     } | ||||
| 
 | ||||
|     Qnn_QuantizeParams_t tQuantizeParams{}; | ||||
|     tQuantizeParams.encodingDefinition = QNN_DEFINITION_UNDEFINED; | ||||
|     tQuantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED; | ||||
|     Qnn_ScaleOffset_t tScaleOffsetEncoding; | ||||
|     tScaleOffsetEncoding.scale = 0.0f; | ||||
|     tScaleOffsetEncoding.offset = 0; | ||||
|     if(quant != nullptr && TensorUtils::getDescribe(tensor)->type == DataType_DT_INT8) { | ||||
|         tQuantizeParams.encodingDefinition = QNN_DEFINITION_DEFINED; | ||||
|         tQuantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; | ||||
|         tScaleOffsetEncoding.scale = quant->scale; | ||||
|         if(quant->zero != 0){ | ||||
|             MNN_PRINT("MNN_QNN: Not supported asymmetric quant in <QnnBackend::onAcquire>.\n"); | ||||
|             return nullptr; | ||||
|         } | ||||
|         tDataType = QNN_DATATYPE_SFIXED_POINT_8; | ||||
|         if (isOutput) { | ||||
|             tType = QNN_TENSOR_TYPE_NATIVE; | ||||
|         } | ||||
|     } | ||||
|     tQuantizeParams.scaleOffsetEncoding = tScaleOffsetEncoding; | ||||
| 
 | ||||
|     std::unique_ptr<Tensor> tempTensor(new Tensor(tensor, gQnnTensorDimType, false)); | ||||
|  | @ -947,17 +968,41 @@ Backend::MemObj* QnnBackend::onAcquire(const Tensor* tensor, StorageType storage | |||
| 
 | ||||
|     Qnn_Tensor_t * qnnTensor = qnnTensorWrapper->getNativeTensor(); | ||||
|     CALL_QNN(mRuntime->mQnnInterface.tensorCreateGraphTensor(mQnnGraphHandle, qnnTensor)); | ||||
|     mQNNTensorWrappers.push_back(qnnTensorWrapper); | ||||
|     mTensorMap.insert({TensorUtils::getDescribe(tensor), mTensorCounter}); | ||||
| 
 | ||||
|     if (isInput) { | ||||
|         mInputTensorIndexes.push_back(mTensorCounter); | ||||
|         qnnTensorWrapper->alloc(); | ||||
|     } | ||||
|     if (isOutput) { | ||||
|         mOutputTensorIndexes.push_back(mTensorCounter); | ||||
|         qnnTensorWrapper->alloc(); | ||||
|         if(quant != nullptr && TensorUtils::getDescribe(tensor)->type == DataType_DT_INT8){ | ||||
|             mTensorCounter += 1; | ||||
|             std::shared_ptr<Tensor> stageTensor; | ||||
|             stageTensor.reset(Tensor::create<float>(tensor->shape(), nullptr, gQnnTensorDimType)); | ||||
|             tName = "QnnTensor_" + std::to_string(mTensorCounter); | ||||
|             tType = QNN_TENSOR_TYPE_APP_READ; | ||||
|             if (mUseFP16 && tensor->getType().code == halide_type_float) { | ||||
|                 tDataType = QNN_DATATYPE_FLOAT_16; | ||||
|             } else if (tensor->getType().code == halide_type_float) { | ||||
|                 tDataType = QNN_DATATYPE_FLOAT_32; | ||||
|             } else { | ||||
|                 MNN_PRINT("MNN_QNN: Not supported data type in <QnnBackend::onAcquire>.\n"); | ||||
|                 return nullptr; | ||||
|             } | ||||
|             Qnn_QuantizeParams_t tQuantizeParamstmp = QNN_QUANTIZE_PARAMS_INIT; | ||||
|             std::shared_ptr<QNNTensorWrapper> qnnOutputTensorWrapper = QNNTensorWrapper::create(tName, tType, tDataType, tDims, tQuantizeParamstmp); | ||||
|             CALL_QNN(mRuntime->mQnnInterface.tensorCreateGraphTensor(mQnnGraphHandle, qnnOutputTensorWrapper->getNativeTensor())); | ||||
|             mDeQuantOutputTensorMap.insert({TensorUtils::getDescribe(tensor), {tensor, stageTensor}}); | ||||
|             mQNNTensorWrappers.push_back(qnnOutputTensorWrapper); | ||||
|             mTensorMap.insert({TensorUtils::getDescribe(const_cast<const Tensor*>(stageTensor.get())), mTensorCounter}); | ||||
|             mOutputTensorIndexes.push_back(mTensorCounter); | ||||
|             qnnOutputTensorWrapper->alloc(); | ||||
|         } else{ | ||||
|             mOutputTensorIndexes.push_back(mTensorCounter); | ||||
|             qnnTensorWrapper->alloc(); | ||||
|         } | ||||
|     } | ||||
|     mQNNTensorWrappers.push_back(qnnTensorWrapper); | ||||
|     mTensorMap.insert({TensorUtils::getDescribe(tensor), mTensorCounter}); | ||||
| 
 | ||||
|     mTensorCounter += 1; | ||||
|     #ifdef QNN_VERBOSE | ||||
|  | @ -1029,7 +1074,13 @@ void QnnBackend::inputIO(const Tensor* srcTensor, const Tensor* dstTensor) const | |||
| } | ||||
| 
 | ||||
| void QnnBackend::outputIO(const Tensor* srcTensor, const Tensor* dstTensor) const { | ||||
|     int srcIndex = getTensorIdx(srcTensor); | ||||
|     auto iter = mDeQuantOutputTensorMap.find(TensorUtils::getDescribe(srcTensor)); | ||||
|     int srcIndex = -1; | ||||
|     if(iter != mDeQuantOutputTensorMap.end()){ | ||||
|         srcIndex = getTensorIdx(iter->second.second.get()); | ||||
|     } else{ | ||||
|         srcIndex = getTensorIdx(srcTensor); | ||||
|     } | ||||
|     std::shared_ptr<QNNTensorWrapper> srcQnnTensorWrapper = mQNNTensorWrappers[srcIndex]; | ||||
|     std::shared_ptr<Tensor> srcDataContainer = srcQnnTensorWrapper->getDataContainer(); | ||||
| 
 | ||||
|  | @ -1091,7 +1142,7 @@ void QnnBackend::finalizeGraph() { | |||
|         // set QNN_PROFILE_LEVEL_DETAILED
 | ||||
|         QnnProfile_Level_t profileLevel = QNN_PROFILE_LEVEL_DETAILED; | ||||
|         MNN_PRINT("[QNN Profile] Creating QNN Profile Handle with DETAILED level.\n"); | ||||
|         auto profile_err = mRuntime->mQnnInterface.profileCreate(mQnnContextHandle, profileLevel, &mQnnProfileHandle); | ||||
|         auto profile_err = mRuntime->mQnnInterface.profileCreate(mRuntime->mQnnContextHandle, profileLevel, &mQnnProfileHandle); | ||||
|         if (profile_err != QNN_SUCCESS || mQnnProfileHandle == nullptr) { | ||||
|             MNN_ERROR("[QNN Profile] Failed to create QNN Profile Handle, error: %d\n", (int)profile_err); | ||||
|             mQnnProfileHandle = nullptr; | ||||
|  | @ -1216,6 +1267,27 @@ void QnnBackend::clean() { | |||
|     mTensorMap.clear(); | ||||
|     mInputTensorIndexes.clear(); | ||||
|     mOutputTensorIndexes.clear(); | ||||
|     mDeQuantOutputTensorMap.clear(); | ||||
| } | ||||
| void QnnBackend::buildOutputDequant(){ | ||||
|     Qnn_OpConfigVersion_t mOpConfigVersion = QNN_OPCONFIG_VERSION_1; | ||||
|     std::string mNodeName; | ||||
|     std::string mPackageName = "qti.aisw"; | ||||
|     std::string mNodeType; | ||||
|     std::vector<Qnn_Param_t> mParams; | ||||
|     std::vector<Qnn_Tensor_t> mInputs; | ||||
|     std::vector<Qnn_Tensor_t> mOutputs; | ||||
|     for(auto iter : mDeQuantOutputTensorMap){ | ||||
|         mNodeType.clear(); | ||||
|         mParams.clear(); | ||||
|         mInputs.clear(); | ||||
|         mOutputs.clear(); | ||||
|         mNodeType = "Dequantize"; | ||||
|         std::string name = "Dequantize_I_" + std::to_string(getTensorIdx(iter.second.first)) + "_O_" + std::to_string(getTensorIdx(iter.second.second.get()));; | ||||
|         mInputs.push_back(*(getNativeTensor(iter.second.first))); // input
 | ||||
|         mOutputs.push_back(*(getNativeTensor(iter.second.second.get()))); // output
 | ||||
|         addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| QnnRuntime::QnnRuntime(const Backend::Info& info, QNN_INTERFACE_VER_TYPE qnnInterface, Qnn_LogHandle_t qnnLogHandle, Qnn_BackendHandle_t qnnBackendHandle, Qnn_DeviceHandle_t qnnDeviceHandle) { | ||||
|  | @ -1324,12 +1396,23 @@ bool QnnRuntime::registerCustomOpPackage(QNN_INTERFACE_VER_TYPE qnnInterface, Qn | |||
| 
 | ||||
| class QnnRuntimeCreator : public RuntimeCreator { | ||||
| public: | ||||
|     virtual Runtime* onCreate(const Backend::Info& info) const { | ||||
|     virtual Runtime* onCreate(const Backend::Info& info) const override { | ||||
|         return QnnRuntime::create(info); | ||||
|     } | ||||
|     virtual bool onValid(Backend::Info& info) const { | ||||
|     virtual bool onValid(Backend::Info& info) const override { | ||||
|         return true; | ||||
|     } | ||||
|     virtual bool onGetDeviceInfo(const std::string& deviceKey, std::string& deviceValue) const override { | ||||
|         if(deviceKey == "soc_id" && gContext.soc_id != 0) { | ||||
|             deviceValue = std::to_string(gContext.soc_id); | ||||
|             return true; | ||||
|         } | ||||
|         if(deviceKey == "dsp_arch" && gContext.dsp_arch != 0) { | ||||
|             deviceValue = "v" + std::to_string(gContext.dsp_arch); | ||||
|             return true; | ||||
|         } | ||||
|         return false; | ||||
|     } | ||||
| }; | ||||
| 
 | ||||
| } // end namespace QNN
 | ||||
|  | @ -1401,6 +1484,8 @@ void registerQNNRuntimeCreator() { | |||
| 
 | ||||
|     // Create Device.
 | ||||
|     Qnn_DeviceHandle_t deviceHandle = nullptr; | ||||
|     QnnHtpDevice_Arch_t dspArch = QNN_HTP_DEVICE_ARCH_NONE; | ||||
|     uint32_t socId = 0; | ||||
|     { | ||||
|         // Check whether the device API is supported.
 | ||||
|         bool supportDevice = QNN::checkCapability(qnnInterface, QNN_PROPERTY_GROUP_DEVICE); | ||||
|  | @ -1422,10 +1507,8 @@ void registerQNNRuntimeCreator() { | |||
|                     MNN_PRINT("[Warning]: deviceGetPlatformInfo Failed to query platform info"); | ||||
|                 } else { | ||||
|                     QnnDevice_HardwareDeviceInfo_t* hwDeviceInfo = backendPlatformInfoPtr->v1.hwDevices; | ||||
|                     QnnHtpDevice_Arch_t arch = hwDeviceInfo->v1.deviceInfoExtension->onChipDevice.arch; | ||||
|                     uint32_t socModel = hwDeviceInfo->v1.deviceInfoExtension->onChipDevice.socModel; | ||||
|                     MNN_PRINT("Qnn Device soc_id: %d\n", socModel); | ||||
|                     MNN_PRINT("Qnn Device dsp_arch: v%d\n", arch); | ||||
|                     dspArch = hwDeviceInfo->v1.deviceInfoExtension->onChipDevice.arch; | ||||
|                     socId = hwDeviceInfo->v1.deviceInfoExtension->onChipDevice.socModel; | ||||
|                 } | ||||
|             } | ||||
|         } else { | ||||
|  | @ -1478,6 +1561,8 @@ void registerQNNRuntimeCreator() { | |||
|     QNN::gContext.backendHandle = backendHandle; | ||||
|     QNN::gContext.deviceHandle = deviceHandle; | ||||
|     QNN::gContext.logHandle = logHandle; | ||||
|     QNN::gContext.soc_id = socId; | ||||
|     QNN::gContext.dsp_arch = dspArch; | ||||
| 
 | ||||
|     QNN::registerQNNOps(); | ||||
|     MNNInsertExtraRuntimeCreator(MNN_FORWARD_NN, new QNN::QnnRuntimeCreator, false); | ||||
|  |  | |||
|  | @ -78,6 +78,10 @@ public: | |||
|     std::shared_ptr<QNNTensorWrapper> getTensorWrapper(const Tensor * tensor); | ||||
|     bool useCache() const; | ||||
|     bool getUseFP16() const; | ||||
|     void buildOutputDequant(); | ||||
|     void pushReleaseFunc(std::function<void()> func){ | ||||
|         mReleaseFunc.push_back(func); | ||||
|     } | ||||
| 
 | ||||
| private: | ||||
|     void clean(); | ||||
|  | @ -109,8 +113,10 @@ private: | |||
|     mutable int mTensorCounter = 0; | ||||
|     mutable std::vector<std::shared_ptr<QNNTensorWrapper>> mQNNTensorWrappers; | ||||
|     mutable std::map<const Tensor::InsideDescribe::NativeInsideDescribe *, int> mTensorMap; | ||||
|     mutable std::map<const Tensor::InsideDescribe::NativeInsideDescribe *, std::pair<const Tensor*, std::shared_ptr<Tensor>>> mDeQuantOutputTensorMap; | ||||
|     std::vector<int> mInputTensorIndexes; | ||||
|     std::vector<int> mOutputTensorIndexes; | ||||
|     std::vector<std::function<void()>> mReleaseFunc; | ||||
| }; | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -134,6 +134,8 @@ void registerQNNOps() { | |||
|     #ifdef MNN_SUPPORT_TRANSFORMER_FUSE | ||||
|     ___QNNAttentionCreator__OpType_Attention__(); | ||||
|     #endif | ||||
|     ___QNNQuantCreator__OpType_FloatToInt8__(); | ||||
|     ___QNNDeQuantCreator__OpType_Int8ToFloat__(); | ||||
| } | ||||
| 
 | ||||
| Tensor::DimensionType gQnnTensorDimType = Tensor::TENSORFLOW; | ||||
|  |  | |||
|  | @ -104,6 +104,8 @@ extern void ___QNNMatMulCreator__OpType_MatMul__(); | |||
| #ifdef MNN_SUPPORT_TRANSFORMER_FUSE | ||||
| extern void ___QNNAttentionCreator__OpType_Attention__(); | ||||
| #endif | ||||
| extern void ___QNNQuantCreator__OpType_FloatToInt8__(); | ||||
| extern void ___QNNDeQuantCreator__OpType_Int8ToFloat__(); | ||||
| void registerQNNOps(); | ||||
| 
 | ||||
| extern Tensor::DimensionType gQnnTensorDimType; | ||||
|  |  | |||
|  | @ -25,7 +25,7 @@ std::shared_ptr<QNNTensorWrapper> QNNTensorWrapper::create(const std::string & n | |||
| 
 | ||||
| std::shared_ptr<QNNTensorWrapper> QNNTensorWrapper::createStaticTensor(const std::string & name, Qnn_DataType_t dataType, const std::vector<uint32_t> & dimensions, const void * buffer, Qnn_QuantizeParams_t quantizeParam) { | ||||
|     MNN_ASSERT(!name.empty() && !dimensions.empty() && buffer); | ||||
|     MNN_ASSERT(dataType == QNN_DATATYPE_SFIXED_POINT_8 || dataType == QNN_DATATYPE_INT_32); | ||||
|     MNN_ASSERT(dataType == QNN_DATATYPE_SFIXED_POINT_8 || dataType == QNN_DATATYPE_INT_32 || dataType == QNN_DATATYPE_SFIXED_POINT_32 || dataType == QNN_DATATYPE_UFIXED_POINT_8); | ||||
| 
 | ||||
|     std::shared_ptr<QNNTensorWrapper> tensorWrapper = QNNTensorWrapper::create(name, QNN_TENSOR_TYPE_STATIC, dataType, dimensions, quantizeParam); | ||||
|     uint32_t numElement = 1; | ||||
|  | @ -114,7 +114,8 @@ void * QNNTensorWrapper::alloc() { | |||
|         dims[i] = (int)mDimensions[i]; | ||||
|     } | ||||
| 
 | ||||
|     MNN_ASSERT(mQnnTensor.v1.dataType == QNN_DATATYPE_FLOAT_32 || mQnnTensor.v1.dataType == QNN_DATATYPE_FLOAT_16 || mQnnTensor.v1.dataType == QNN_DATATYPE_INT_32 ||  mQnnTensor.v1.dataType == QNN_DATATYPE_SFIXED_POINT_8); | ||||
|     MNN_ASSERT(mQnnTensor.v1.dataType == QNN_DATATYPE_FLOAT_32 || mQnnTensor.v1.dataType == QNN_DATATYPE_FLOAT_16 || mQnnTensor.v1.dataType == QNN_DATATYPE_INT_32 ||  mQnnTensor.v1.dataType == QNN_DATATYPE_SFIXED_POINT_8 || mQnnTensor.v1.dataType == QNN_DATATYPE_SFIXED_POINT_32 | ||||
|     || mQnnTensor.v1.dataType == QNN_DATATYPE_UFIXED_POINT_8); | ||||
|     halide_type_t halideType; | ||||
| 
 | ||||
|     halideType.lanes = 1; | ||||
|  | @ -134,6 +135,15 @@ void * QNNTensorWrapper::alloc() { | |||
|         case QNN_DATATYPE_SFIXED_POINT_8: | ||||
|             halideType.code = halide_type_int; | ||||
|             halideType.bits = 8; | ||||
|             break; | ||||
|         case QNN_DATATYPE_SFIXED_POINT_32: | ||||
|             halideType.code = halide_type_int; | ||||
|             halideType.bits = 32; | ||||
|             break; | ||||
|         case QNN_DATATYPE_UFIXED_POINT_8: | ||||
|             halideType.code = halide_type_int; | ||||
|             halideType.bits = 8; | ||||
|             break; | ||||
|         default: | ||||
|             break; | ||||
|     } | ||||
|  |  | |||
|  | @ -269,6 +269,10 @@ std::vector<std::string> QNNTranslator::TranslateTensor(const QNNCommandTensor& | |||
|     if (isParam) { | ||||
|         result.push_back(QNNTranslator::TranslateParamDataArray(dataNameSymbol, cmdT.dataType, cmdT.clientBuf)); | ||||
|     } | ||||
|     if(hasQuant){ | ||||
|         std::vector<std::string> linesQuantScaleOffset = TranslateQuantizeScaleOffsetDataArray(tensorNameSymbol, cmdT.quantizeParams, cmdT.rank, cmdT.dimensions); | ||||
|         APPEND_VECTOR(result, linesQuantScaleOffset); | ||||
|     } | ||||
|     result.push_back("  Qnn_Tensor_t " + tensorNameSymbol +   " = QNN_TENSOR_INIT;"); | ||||
|     result.push_back("  {"); | ||||
|     result.push_back("  " + tensorNameSymbol + ".version = QNN_TENSOR_VERSION_1;"); | ||||
|  | @ -456,11 +460,122 @@ std::string QNNTranslator::TranslateParamDataArray(const std::string & dataNameS | |||
|     return result; | ||||
| } | ||||
| 
 | ||||
| std::vector<std::string> QNNTranslator::TranslateQuantizeScaleOffsetDataArray(const std::string & tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParams, uint32_t rank, const uint32_t * dimensions){ | ||||
|     std::vector<std::string> result; | ||||
|     if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET){ | ||||
|         result.push_back("  Qnn_ScaleOffset_t " + tensorNameSymbol + "_axis_scale_offset[] = {"); | ||||
|         int totalnum = (quantizeParams.axisScaleOffsetEncoding.numScaleOffsets + 3) / 4; | ||||
|         for(int i = 0; i < totalnum; ++i){ | ||||
|             std::string line = "    "; | ||||
|             for(int j = 0; j < 4; ++j){ | ||||
|                 int index = i * 4 + j; | ||||
|                 if(index >= quantizeParams.axisScaleOffsetEncoding.numScaleOffsets) | ||||
|                     break; | ||||
|                 line += "{.scale= " + std::to_string(quantizeParams.axisScaleOffsetEncoding.scaleOffset[index].scale) + ", .offset= " + std::to_string(quantizeParams.axisScaleOffsetEncoding.scaleOffset[index].offset) + "}, "; | ||||
|             } | ||||
|             result.push_back(line); | ||||
|         } | ||||
|         result.push_back("  };"); | ||||
|     } | ||||
|      | ||||
|     if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET){ | ||||
|         result.push_back("  float " + tensorNameSymbol + "_bwaxis_scale[] = {"); | ||||
|         int totalnum = (quantizeParams.bwAxisScaleOffsetEncoding.numElements + 3) / 4; | ||||
|         for(int i = 0; i < totalnum; ++i){ | ||||
|             std::string line = "    "; | ||||
|             for(int j = 0; j < 4; ++j){ | ||||
|                 int index = i * 4 + j; | ||||
|                 if(index >= quantizeParams.bwAxisScaleOffsetEncoding.numElements) | ||||
|                     break; | ||||
|                 line += std::to_string(quantizeParams.bwAxisScaleOffsetEncoding.scales[index]) + ", "; | ||||
|             } | ||||
|             result.push_back(line); | ||||
|         } | ||||
|         result.push_back("  };"); | ||||
|         if(quantizeParams.bwAxisScaleOffsetEncoding.offsets != nullptr){ | ||||
|             result.push_back("  int32_t " + tensorNameSymbol + "_bwaxis_offset[] = {"); | ||||
|             for(int i = 0; i < totalnum; ++i){ | ||||
|                 std::string line = "    "; | ||||
|                 for(int j = 0; j < 4; ++j){ | ||||
|                     int index = i * 4 + j; | ||||
|                     if(index >= quantizeParams.bwAxisScaleOffsetEncoding.numElements) | ||||
|                         break; | ||||
|                     line += std::to_string(quantizeParams.bwAxisScaleOffsetEncoding.offsets[index]) + ", "; | ||||
|                 } | ||||
|                 result.push_back(line); | ||||
|             } | ||||
|             result.push_back("  };"); | ||||
|         } | ||||
|     } | ||||
|      | ||||
|     if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION){ | ||||
|         int axis = quantizeParams.blockwiseExpansion->axis; | ||||
|         int oc = dimensions[axis]; | ||||
|         int blockSize = quantizeParams.blockwiseExpansion->numBlocksPerAxis; | ||||
|         result.push_back("  Qnn_BlockwiseExpansion_t " + tensorNameSymbol + "_blockwiseExpansion = QNN_BLOCKWISE_EXPANSION_INIT;"); | ||||
|          | ||||
|         result.push_back("  Qnn_ScaleOffset_t " + tensorNameSymbol + "_blockwiseExpansionScaleOffset[] = {"); | ||||
|         int totalnum = (oc + 3) / 4; | ||||
|         for(int i = 0; i < totalnum; ++i){ | ||||
|             std::string line = "    "; | ||||
|             for(int j = 0; j < 4; ++j){ | ||||
|                 int index = i * 4 + j; | ||||
|                 if(index >= oc) | ||||
|                     break; | ||||
|                 line += "{.scale= " + std::to_string(quantizeParams.blockwiseExpansion->scaleOffsets[index].scale) + ", .offset= " + std::to_string(quantizeParams.blockwiseExpansion->scaleOffsets[index].offset) + "}, "; | ||||
|             } | ||||
|             result.push_back(line); | ||||
|         } | ||||
|         result.push_back("  };"); | ||||
|         if(quantizeParams.blockwiseExpansion->blockScaleStorageType == QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8){ | ||||
|             result.push_back("  uint8_t " + tensorNameSymbol + "_blockwiseExpansionBlockScale[] = {"); | ||||
|             totalnum = (oc * blockSize + 3) / 4; | ||||
|             for(int i = 0; i < totalnum; ++i){ | ||||
|                 std::string line = "    "; | ||||
|                 for(int j = 0; j < 4; ++j){ | ||||
|                     int index = i * 4 + j; | ||||
|                     if(index >= oc * blockSize) | ||||
|                         break; | ||||
|                     line += std::to_string(quantizeParams.blockwiseExpansion->blocksScale8[index]) + ", "; | ||||
|                 } | ||||
|                 result.push_back(line); | ||||
|             } | ||||
|             result.push_back("  };"); | ||||
|         }else{ | ||||
|             result.push_back("  uint16_t " + tensorNameSymbol + "_blockwiseExpansionBlockScale[] = {"); | ||||
|             totalnum = (oc * blockSize + 3) / 4; | ||||
|             for(int i = 0; i < totalnum; ++i){ | ||||
|                 std::string line = "    "; | ||||
|                 for(int j = 0; j < 4; ++j){ | ||||
|                     int index = i * 4 + j; | ||||
|                     if(index >= oc * blockSize) | ||||
|                         break; | ||||
|                     line += std::to_string(quantizeParams.blockwiseExpansion->blocksScale16[index]) + ", "; | ||||
|                 } | ||||
|                 result.push_back(line); | ||||
|             } | ||||
|             result.push_back("  };"); | ||||
|         } | ||||
|         result.push_back("  " + tensorNameSymbol + "_blockwiseExpansion.axis = " + std::to_string(quantizeParams.blockwiseExpansion->axis) + ";"); | ||||
|         result.push_back("  " + tensorNameSymbol + "_blockwiseExpansion.scaleOffsets = " + tensorNameSymbol + "_blockwiseExpansionScaleOffset;"); | ||||
|         result.push_back("  " + tensorNameSymbol + "_blockwiseExpansion.numBlocksPerAxis = " + std::to_string(quantizeParams.blockwiseExpansion->numBlocksPerAxis) + ";"); | ||||
|         result.push_back("  " + tensorNameSymbol + "_blockwiseExpansion.blockScaleBitwidth = " + std::to_string(quantizeParams.blockwiseExpansion->blockScaleBitwidth) + ";"); | ||||
|         if(quantizeParams.blockwiseExpansion->blockScaleStorageType == QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8){ | ||||
|             result.push_back("  " + tensorNameSymbol + "_blockwiseExpansion.blockScaleStorageType = QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8;"); | ||||
|             result.push_back("  " + tensorNameSymbol + "_blockwiseExpansion.blocksScale8 = " + tensorNameSymbol + "_blockwiseExpansionBlockScale;"); | ||||
|         }else{ | ||||
|             result.push_back("  " + tensorNameSymbol + "_blockwiseExpansion.blockScaleStorageType = QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_16;"); | ||||
|             result.push_back("  " + tensorNameSymbol + "_blockwiseExpansion.blocksScale16 = " + tensorNameSymbol + "_blockwiseExpansionBlockScale;"); | ||||
|         } | ||||
|     } | ||||
|     return result; | ||||
| } | ||||
| 
 | ||||
| // Currently, only support QNN_QUANTIZATION_ENCODING_UNDEFINED, QNN_QUANTIZATION_ENCODING_SCALE_OFFSET.
 | ||||
| std::vector<std::string> QNNTranslator::TranslateTensorQuantizeParams(const std::string tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParmas) { | ||||
| std::vector<std::string> QNNTranslator::TranslateTensorQuantizeParams(const std::string tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParams) { | ||||
|     std::vector<std::string> result; | ||||
| 
 | ||||
|     if (quantizeParmas.encodingDefinition == QNN_DEFINITION_UNDEFINED) { | ||||
|     if (quantizeParams.encodingDefinition == QNN_DEFINITION_UNDEFINED) { | ||||
|         result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.encodingDefinition = QNN_DEFINITION_UNDEFINED;"); | ||||
|         result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED;"); | ||||
|         result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.scaleOffsetEncoding.scale = 0.0f;"); | ||||
|  | @ -468,14 +583,43 @@ std::vector<std::string> QNNTranslator::TranslateTensorQuantizeParams(const std: | |||
|         return result; | ||||
|     } | ||||
| 
 | ||||
|     if (quantizeParmas.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParmas.quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { | ||||
|     if (quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { | ||||
|         result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.encodingDefinition = QNN_DEFINITION_DEFINED;"); | ||||
|         result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET;"); | ||||
|         result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.scaleOffsetEncoding.scale = " + std::to_string(quantizeParmas.scaleOffsetEncoding.scale) + ";"); | ||||
|         result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.scaleOffsetEncoding.offset = " + std::to_string(quantizeParmas.scaleOffsetEncoding.offset) + ";"); | ||||
|         result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.scaleOffsetEncoding.scale = " + std::to_string(quantizeParams.scaleOffsetEncoding.scale) + ";"); | ||||
|         result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.scaleOffsetEncoding.offset = " + std::to_string(quantizeParams.scaleOffsetEncoding.offset) + ";"); | ||||
|         return result; | ||||
|     } | ||||
|      | ||||
|     if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET){ | ||||
|         result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.encodingDefinition = QNN_DEFINITION_DEFINED;"); | ||||
|         result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET;"); | ||||
|         result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.axisScaleOffsetEncoding.axis = " + std::to_string(quantizeParams.axisScaleOffsetEncoding.axis) + ";"); | ||||
|         result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.axisScaleOffsetEncoding.numScaleOffsets = " + std::to_string(quantizeParams.axisScaleOffsetEncoding.numScaleOffsets) + ";"); | ||||
|         result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.axisScaleOffsetEncoding.scaleOffset = " + tensorNameSymbol + "_axis_scale_offset;"); | ||||
|         return result; | ||||
|     } | ||||
|      | ||||
|     if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET){ | ||||
|         result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.encodingDefinition = QNN_DEFINITION_DEFINED;"); | ||||
|         result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET;"); | ||||
|         result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.bwAxisScaleOffsetEncoding.axis = " + std::to_string(quantizeParams.bwAxisScaleOffsetEncoding.axis) + ";"); | ||||
|         result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.bwAxisScaleOffsetEncoding.bitwidth = " + std::to_string(quantizeParams.bwAxisScaleOffsetEncoding.bitwidth) + ";"); | ||||
|         result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.bwAxisScaleOffsetEncoding.numElements = " + std::to_string(quantizeParams.bwAxisScaleOffsetEncoding.numElements) + ";"); | ||||
|         result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.bwAxisScaleOffsetEncoding.scales = " + tensorNameSymbol + "_bwaxis_scale;"); | ||||
|         if(quantizeParams.bwAxisScaleOffsetEncoding.offsets != nullptr) | ||||
|             result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.bwAxisScaleOffsetEncoding.offset = " + tensorNameSymbol + "_bwaxis_offset;"); | ||||
|         return result; | ||||
|     } | ||||
|      | ||||
|     if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION){ | ||||
|         result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.encodingDefinition = QNN_DEFINITION_DEFINED;"); | ||||
|         result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION;"); | ||||
|         result.push_back("  " + tensorNameSymbol + ".v1.quantizeParams.blockwiseExpansion = &" + tensorNameSymbol + "_blockwiseExpansion;"); | ||||
|         return result; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     MNN_ERROR("MNN_QNN: Unknown QuantizeParams.\n"); | ||||
| 
 | ||||
|     return result; | ||||
|  |  | |||
|  | @ -69,7 +69,8 @@ private: | |||
|     static std::string MapDataType(Qnn_DataType_t dataType); | ||||
|     static std::string TranslateDimensionsArray(const std::string & dimensionsNameSymbol, uint32_t rank, const uint32_t * dimensions); | ||||
|     static std::string TranslateParamDataArray(const std::string & dataNameSymbol, Qnn_DataType_t dataType, const Qnn_ClientBuffer_t & clientBuf); | ||||
|     static std::vector<std::string> TranslateTensorQuantizeParams(const std::string tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParmas); | ||||
|     static std::vector<std::string> TranslateQuantizeScaleOffsetDataArray(const std::string & tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParams, uint32_t rank, const uint32_t * dimensions); | ||||
|     static std::vector<std::string> TranslateTensorQuantizeParams(const std::string tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParams); | ||||
|     static std::vector<std::string> TranslateTensorClientBuf(const std::string & tensorNameSymbol, const std::string & dataNameSymbol, const std::string & sname, const Qnn_ClientBuffer_t & clientBuf, bool hasClientBuf, bool isParam); | ||||
|     // Utility functions used by TranslateNode.
 | ||||
|     static std::vector<std::string> TranslateNodeParamArray(const std::string & nodeName,const std::string & paramArraySymbol, uint32_t numOfParams, const Qnn_Param_t * params); | ||||
|  |  | |||
|  | @ -11,6 +11,157 @@ | |||
| namespace MNN { | ||||
| namespace QNN { | ||||
| 
 | ||||
| void QNNConvDepthwise::isWeightQuantSupported(const Tensor *input, const int oc){ | ||||
|     Qnn_DataType_t dataType = mBackend->getNativeTensor(input)->v1.dataType; | ||||
|     if(mOp->main_as_Convolution2D()->quanParameter() == nullptr){ | ||||
|         mWeightQuant = false; | ||||
|         return; | ||||
|     }else{ | ||||
|         bool hasBais = false; | ||||
|         auto bias = mOp->main_as_Convolution2D()->bias(); | ||||
|         auto biasPtr = (float*)bias->data(); | ||||
|         for(int i = 0; i < oc; ++i){ | ||||
|             if(biasPtr[i] != 0.0f){ | ||||
|                 hasBais = true; | ||||
|                 break; | ||||
|             } | ||||
|         } | ||||
|          | ||||
|         std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon = ConvolutionCommon::load(mOp, this->backend(), false, true); | ||||
|         if(quanCommon->asymmetric || dataType == QNN_DATATYPE_FLOAT_16 || dataType == QNN_DATATYPE_FLOAT_32){ | ||||
|             // not support asymmetric and mBlockSize > 1 results incorrect now
 | ||||
|             mWeightQuant = false; | ||||
|             return; | ||||
|         } | ||||
|          | ||||
|         float inputScale = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.scale; | ||||
|         int inputOffset = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.offset; | ||||
|         if(inputOffset == 0){ | ||||
|             mWeightQuant = true; | ||||
|         }else{ | ||||
|             if(hasBais){ | ||||
|                 mWeightQuant = false; | ||||
|             }else{ | ||||
|                 mWeightQuant = true; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| ErrorCode QNNConvDepthwise::onEncodeQuantDequantDepthConv(Tensor *input, Tensor *output, const int n, const int ic, const int oc) { | ||||
|     auto conv2D     = mOp->main_as_Convolution2D(); | ||||
|     auto common     = conv2D->common(); | ||||
|     Qnn_DataType_t dataType = QNN_DATATYPE_FLOAT_32; | ||||
|     if(mBackend->getUseFP16()){ | ||||
|         dataType = QNN_DATATYPE_FLOAT_16; | ||||
|     } | ||||
|      | ||||
|     // create dequant input stage tensor
 | ||||
|     this->createStageTensor("DequantInput", dataType, getNHWCShape(input)); // mTempTensorWrappers[2]
 | ||||
|     this->createStageTensor("QuantOutput", dataType, getNHWCShape(output)); // mTempTensorWrappers[3]
 | ||||
|      | ||||
|     // add nodes
 | ||||
|     { | ||||
|         // dequant input
 | ||||
|         { | ||||
|             mParams.clear(); | ||||
|             mInputs.clear(); | ||||
|             mOutputs.clear(); | ||||
|             mNodeType = "Dequantize"; | ||||
|             std::string name = mNodeName + "_dequant_input"; | ||||
|          | ||||
|             mInputs.push_back(*(mBackend->getNativeTensor(input))); // input
 | ||||
|             mOutputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput
 | ||||
|             mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|         } | ||||
|          | ||||
|         if (common->relu() || common->relu6()) { | ||||
|             this->createStageTensor("ReluTensor", dataType, getNHWCShape(output)); // mTempTensorWrappers[4]
 | ||||
|             // Stage one
 | ||||
|             { | ||||
|                 mParams.clear(); | ||||
|                 mInputs.clear(); | ||||
|                 mOutputs.clear(); | ||||
|                 mNodeType = "DepthWiseConv2d"; | ||||
|                 std::string name = mNodeName + "_convDepthwise"; | ||||
|                 mParams.push_back(*(mParamTensorWrappers[0]->getNativeParam())); // stride
 | ||||
|                 mParams.push_back(*(mParamTensorWrappers[1]->getNativeParam())); // pad_amount
 | ||||
|                 mParams.push_back(*(mParamTensorWrappers[2]->getNativeParam())); // dilation
 | ||||
|          | ||||
|                 mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput
 | ||||
|                 mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // weight
 | ||||
|                 mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // bias
 | ||||
|          | ||||
|                 mOutputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // ReluTensor
 | ||||
|                 mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|             } | ||||
| 
 | ||||
|             // Stage two
 | ||||
|             { | ||||
|                 mParams.clear(); | ||||
|                 mInputs.clear(); | ||||
|                 mOutputs.clear(); | ||||
|                 mNodeType = common->relu6() ? "ReluMinMax" : "Relu"; | ||||
|                 std::string name = mNodeName + "_relu"; | ||||
|                 if (common->relu6()) { | ||||
|                     mParams.push_back(*(mParamScalarWrappers[1]->getNativeParam())); // min_value
 | ||||
|                     mParams.push_back(*(mParamScalarWrappers[2]->getNativeParam())); // max_value
 | ||||
|                 } | ||||
|                 mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // ReluTensor
 | ||||
|                 mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput
 | ||||
|                 mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|             } | ||||
| 
 | ||||
|         } else { | ||||
|             mParams.clear(); | ||||
|             mInputs.clear(); | ||||
|             mOutputs.clear(); | ||||
|             mNodeType = "DepthWiseConv2d"; | ||||
|             mParams.push_back(*(mParamTensorWrappers[0]->getNativeParam())); // stride
 | ||||
|             mParams.push_back(*(mParamTensorWrappers[1]->getNativeParam())); // pad_amount
 | ||||
|             mParams.push_back(*(mParamTensorWrappers[2]->getNativeParam())); // dilation
 | ||||
|              | ||||
|             mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput
 | ||||
|             mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // weight
 | ||||
|             mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // bias
 | ||||
|              | ||||
|             mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput
 | ||||
|             mBackend->addNodeToGraph(mOpConfigVersion, mNodeName.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|         } | ||||
|          | ||||
|         // Quant output
 | ||||
|         { | ||||
|             auto QuantOutputTensor = mTempTensorWrappers[3]->getNativeTensor(); | ||||
|             if(mBackend->getUseFP16()){ | ||||
|                 this->createStageTensor("CastOutput", QNN_DATATYPE_FLOAT_32, getNHWCShape(output)); | ||||
|                  | ||||
|                 mParams.clear(); | ||||
|                 mInputs.clear(); | ||||
|                 mOutputs.clear(); | ||||
|                 mNodeType = "Cast"; | ||||
|                 std::string name = mNodeName + "_Cast_Output"; | ||||
|                  | ||||
|                 mInputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput
 | ||||
|                 mOutputs.push_back(*(mTempTensorWrappers.back()->getNativeTensor())); // CastOutput
 | ||||
|                 mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|                 QuantOutputTensor = mTempTensorWrappers.back()->getNativeTensor(); | ||||
|             } | ||||
|             { | ||||
|                 mParams.clear(); | ||||
|                 mInputs.clear(); | ||||
|                 mOutputs.clear(); | ||||
|                 mNodeType = "Quantize"; | ||||
|                 std::string name = mNodeName + "_Quant_Output"; | ||||
|                  | ||||
|                 mInputs.push_back(*(QuantOutputTensor)); // stage tensor
 | ||||
|                 mOutputs.push_back(*(mBackend->getNativeTensor(output))); // output
 | ||||
|                 mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     return NO_ERROR; | ||||
| } | ||||
| 
 | ||||
| ErrorCode QNNConvDepthwise::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) { | ||||
|     auto conv2D     = mOp->main_as_Convolution2D(); | ||||
|     auto common     = conv2D->common(); | ||||
|  | @ -36,6 +187,7 @@ ErrorCode QNNConvDepthwise::onEncode(const std::vector<Tensor *> &inputs, const | |||
|     dilationH = common->dilateY(); dilationW = common->dilateX(); | ||||
| } | ||||
|      | ||||
|     isWeightQuantSupported(inputs[0], oc); | ||||
|     // create all tensors and params
 | ||||
| { | ||||
|     std::vector<uint32_t> strideData = {(uint32_t)strideH, (uint32_t)strideW}; | ||||
|  | @ -49,10 +201,24 @@ ErrorCode QNNConvDepthwise::onEncode(const std::vector<Tensor *> &inputs, const | |||
|         this->createParamScalar("max_value", 6.0f); | ||||
|     } | ||||
| 
 | ||||
|     this->createWeight(dataType, oc, kernelH, kernelW); | ||||
|     this->createBias(dataType, oc); | ||||
|     this->createWeightAndBias(dataType, inputs[0], oc, kernelH, kernelW); | ||||
|     // dequant input and quant output
 | ||||
|     if(mWeightQuant == false && dataType != QNN_DATATYPE_FLOAT_16 && dataType != QNN_DATATYPE_FLOAT_32){ | ||||
|         return this->onEncodeQuantDequantDepthConv(inputs[0], outputs[0], n, ic, oc); | ||||
|     } | ||||
|      | ||||
|     if (common->relu() || common->relu6()) { | ||||
|         this->createStageTensor("ReluTensor", dataType, getNHWCShape(outputs[0])); | ||||
|         Qnn_QuantizeParams_t quantize = DEFAULT_QUANTIZE_PARAMS; | ||||
|         Qnn_ScaleOffset_t tScaleOffsetEncoding; | ||||
|         auto quant = TensorUtils::getDescribe(outputs[0])->quantAttr.get(); | ||||
|         if(quant != nullptr && TensorUtils::getDescribe(outputs[0])->type == DataType_DT_INT8){ | ||||
|             quantize.encodingDefinition = QNN_DEFINITION_DEFINED; | ||||
|             quantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; | ||||
|             tScaleOffsetEncoding.scale = quant->scale; | ||||
|             tScaleOffsetEncoding.offset = quant->zero; | ||||
|             quantize.scaleOffsetEncoding = tScaleOffsetEncoding; | ||||
|         } | ||||
|         this->createStageTensor("ReluTensor", dataType, getNHWCShape(outputs[0]), quantize); | ||||
|     } | ||||
| } | ||||
| 
 | ||||
|  | @ -112,41 +278,138 @@ ErrorCode QNNConvDepthwise::onEncode(const std::vector<Tensor *> &inputs, const | |||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| void QNNConvDepthwise::createWeight(Qnn_DataType_t dataType, int oc, int kernelH, int kernelW) { | ||||
|     std::vector<float> weightData; | ||||
|     const float* source = nullptr; | ||||
|     int weightElementNum = 0; | ||||
|     std::shared_ptr<ConvolutionCommon::Int8Common> quanWeight; | ||||
|     ConvolutionCommon::getConvParameters(&quanWeight, mBackend, mOp, &source, &weightElementNum); | ||||
|     // oc ic h w ---> h w ic oc
 | ||||
|     weightData.resize(weightElementNum); | ||||
|     convertWeight(source, (float *) weightData.data(), oc, kernelH, kernelW); | ||||
|     this->createStaticFloatTensor("weight", dataType, {(uint32_t)kernelH, (uint32_t)kernelW, 1, (uint32_t)oc}, weightData.data()); | ||||
| } | ||||
| void QNNConvDepthwise::createWeightAndBias(Qnn_DataType_t dataType, const Tensor *input, int oc, int kernelH, int kernelW) { | ||||
|     if(mWeightQuant){ | ||||
|         Qnn_QuantizeParams_t weightQuantize{}; | ||||
|         std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon = ConvolutionCommon::load(mOp, this->backend(), false, true); | ||||
| 
 | ||||
|         // [TODO] Support asymmetric and other quantBits.
 | ||||
|         MNN_ASSERT(!quanCommon->asymmetric); | ||||
| 
 | ||||
| void QNNConvDepthwise::createBias(Qnn_DataType_t dataType, int oc) { | ||||
|     int biasElementNum = oc; | ||||
|     std::vector<float> biasData; | ||||
|     biasData.resize(biasElementNum, 0); | ||||
|     auto bias = mOp->main_as_Convolution2D()->bias(); | ||||
|     if (nullptr != bias) { | ||||
|         ::memcpy((void *)biasData.data(), (void *)bias->data(), biasElementNum * sizeof(float)); | ||||
|     } | ||||
|     this->createStaticFloatTensor("bias", dataType, {(uint32_t)oc}, biasData.data()); | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| // oc, h, w ---> h, w, oc
 | ||||
| void QNNConvDepthwise::convertWeight(const float * src, float * dst, int oc, int kernelH, int kernelW) { | ||||
|     for (int c = 0; c < oc; c++) { | ||||
|         for (int h = 0; h < kernelH; h++) { | ||||
|             for (int w = 0; w < kernelW; w++) { | ||||
|                 int srcOffset = w + kernelW * (h + kernelH * c); | ||||
|                 int dstOffset = c + oc * (w + kernelW * h); | ||||
|                 dst[dstOffset] = src[srcOffset]; | ||||
|         // create weight
 | ||||
|         const int8_t * source = quanCommon->weight.get(); | ||||
|         std::vector<int8_t> quantWeightData(oc * kernelH * kernelW, 0); | ||||
|         if(quanCommon->canUseInt4){ | ||||
|             for (int c = 0; c < oc; c++) { | ||||
|                 for (int h = 0; h < kernelH; h++) { | ||||
|                     for (int w = 0; w < kernelW; w++) { | ||||
|                         int srcOffset = w + kernelW * (h + kernelH * c); | ||||
|                         int dstOffset = c + oc * (w + kernelW * h); | ||||
|                         if(srcOffset % 2 == 0){ | ||||
|                             quantWeightData[dstOffset] = ((source[srcOffset / 2] >> 4) & 0x0f) - 8; | ||||
|                         }else{ | ||||
|                             quantWeightData[dstOffset] = (source[srcOffset / 2] & 0x0f) - 8; | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         }else{ | ||||
|             convertWeight(source, (int8_t *) quantWeightData.data(), oc, kernelH, kernelW); | ||||
|         } | ||||
| 
 | ||||
|         mDequantAlpha = quanCommon->alpha.get(); | ||||
|         if(quanCommon->canUseInt4){ | ||||
|             weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED; | ||||
|             weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET; | ||||
|             Qnn_BwAxisScaleOffset_t weightBWAxisScaleOffsetEncoding{}; | ||||
|             weightBWAxisScaleOffsetEncoding.bitwidth = 4; | ||||
|             weightBWAxisScaleOffsetEncoding.axis = 3; | ||||
|             weightBWAxisScaleOffsetEncoding.numElements = oc; | ||||
|             mScale.resize(oc); | ||||
|             std::vector<int32_t> OffsetData(oc); | ||||
|             for (int i = 0; i < oc; i++) { | ||||
|                 mScale[i] = mDequantAlpha[i]; | ||||
|             } | ||||
|             weightBWAxisScaleOffsetEncoding.scales = mScale.data(); | ||||
|             weightQuantize.bwAxisScaleOffsetEncoding = weightBWAxisScaleOffsetEncoding; | ||||
|              | ||||
|             this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)kernelH, (uint32_t)kernelW, 1, (uint32_t)oc}, (void *) quantWeightData.data(), weightQuantize); | ||||
|             std::function<void()> mReleaseWeightScaleOffset = [&](){ | ||||
|                 std::vector<float>().swap(mScale); | ||||
|             }; | ||||
|             mBackend->pushReleaseFunc(mReleaseWeightScaleOffset); | ||||
|         }else{ | ||||
|             weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED; | ||||
|             weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET; | ||||
|             Qnn_AxisScaleOffset_t weightAxisScaleOffsetEncoding{}; | ||||
|             weightAxisScaleOffsetEncoding.axis = 3; | ||||
|             weightAxisScaleOffsetEncoding.numScaleOffsets = oc; | ||||
|             mScaleOffsetData.resize(oc); | ||||
|             for (int i = 0; i < oc; i++) { | ||||
|                 mScaleOffsetData[i].scale = mDequantAlpha[i]; | ||||
|                 mScaleOffsetData[i].offset = 0; | ||||
|             } | ||||
|             weightAxisScaleOffsetEncoding.scaleOffset = mScaleOffsetData.data(); | ||||
|             weightQuantize.axisScaleOffsetEncoding = weightAxisScaleOffsetEncoding; | ||||
| 
 | ||||
|             this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)kernelH, (uint32_t)kernelW, 1, (uint32_t)oc}, (void *) quantWeightData.data(), weightQuantize); | ||||
|          | ||||
|             std::function<void()> mReleaseWeightScaleOffset = [&](){ | ||||
|                 std::vector<Qnn_ScaleOffset_t>().swap(mScaleOffsetData); | ||||
|             }; | ||||
|             mBackend->pushReleaseFunc(mReleaseWeightScaleOffset); | ||||
|         } | ||||
|         // create bias
 | ||||
|         { | ||||
|             float inputScale = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.scale; | ||||
|             int inputOffset = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.offset; | ||||
|             std::vector<int> biasData; | ||||
|             biasData.resize(oc, 0); | ||||
| 
 | ||||
|             Qnn_QuantizeParams_t biasQuantize{}; | ||||
|             biasQuantize.encodingDefinition = QNN_DEFINITION_DEFINED; | ||||
|             biasQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET; | ||||
|             Qnn_AxisScaleOffset_t biasAxisScaleOffsetEncoding{}; | ||||
|             biasAxisScaleOffsetEncoding.axis = 0; | ||||
|             biasAxisScaleOffsetEncoding.numScaleOffsets = oc; | ||||
|             mBiasScaleOffsetData.resize(oc); | ||||
| 
 | ||||
|             auto bias = mOp->main_as_Convolution2D()->bias(); | ||||
|             auto biasPtr = (float*)bias->data(); | ||||
|             if (nullptr != bias) { | ||||
|                 for(int i = 0; i < oc; ++i){ | ||||
|                     float biasScale = inputScale * mDequantAlpha[i]; | ||||
|                     mBiasScaleOffsetData[i].scale = 0.f; | ||||
|                     mBiasScaleOffsetData[i].offset = 0; | ||||
|                     if(fabs(biasPtr[i]) < 0.000001 || fabs(biasScale) < 0.000001){ | ||||
|                         biasData[i] = 0; | ||||
|                     } else{ | ||||
|                         biasData[i] = (int)(biasPtr[i] / (biasScale)); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|             biasAxisScaleOffsetEncoding.scaleOffset = mBiasScaleOffsetData.data(); | ||||
|             biasQuantize.axisScaleOffsetEncoding = biasAxisScaleOffsetEncoding; | ||||
| 
 | ||||
|             this->createStaticTensor("bias", QNN_DATATYPE_SFIXED_POINT_32, {(uint32_t)oc}, biasData.data(), biasQuantize); | ||||
|             std::function<void()> mReleaseBiasScaleOffset = [&](){ | ||||
|                 std::vector<Qnn_ScaleOffset_t>().swap(mBiasScaleOffsetData); | ||||
|             }; | ||||
|             mBackend->pushReleaseFunc(mReleaseBiasScaleOffset); | ||||
|         } | ||||
|     }else{ | ||||
|         Qnn_DataType_t floatDatatype = QNN_DATATYPE_FLOAT_32; | ||||
|         if(mBackend->getUseFP16()){ | ||||
|             floatDatatype = QNN_DATATYPE_FLOAT_16; | ||||
|         } | ||||
|         std::vector<float> weightData; | ||||
|         const float* source = nullptr; | ||||
|         int weightElementNum = 0; | ||||
|         std::shared_ptr<ConvolutionCommon::Int8Common> quanWeight; | ||||
|         ConvolutionCommon::getConvParameters(&quanWeight, mBackend, mOp, &source, &weightElementNum); | ||||
|         // oc ic h w ---> h w ic oc
 | ||||
|         weightData.resize(weightElementNum); | ||||
|         convertWeight(source, (float *) weightData.data(), oc, kernelH, kernelW); | ||||
|         this->createStaticFloatTensor("weight", floatDatatype, {(uint32_t)kernelH, (uint32_t)kernelW, 1, (uint32_t)oc}, weightData.data()); | ||||
|          | ||||
|         // create bias
 | ||||
|         std::vector<float> biasData; | ||||
|         biasData.resize(oc, 0); | ||||
|         auto bias = mOp->main_as_Convolution2D()->bias(); | ||||
|         if (nullptr != bias) { | ||||
|             ::memcpy((void *)biasData.data(), (void *)bias->data(), oc * sizeof(float)); | ||||
|         } | ||||
|         this->createStaticFloatTensor("bias", floatDatatype, {(uint32_t)oc}, biasData.data()); | ||||
|     } | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -19,10 +19,28 @@ class QNNConvDepthwise : public QNNCommonExecution { | |||
| public: | ||||
|     QNNConvDepthwise(Backend *backend, const Op *op) : QNNCommonExecution(backend, op) {} | ||||
|     virtual ErrorCode onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; | ||||
|     ErrorCode onEncodeQuantDequantDepthConv(Tensor *input, Tensor *output, const int n, const int ic, const int oc); | ||||
| private: | ||||
|     void createWeight(Qnn_DataType_t dataType, int oc, int kernelH, int kernelW); | ||||
|     void createBias(Qnn_DataType_t dataType, int oc); | ||||
|     void convertWeight(const float * src, float * dst, int oc, int kernelH, int kernelW); | ||||
| template <typename T> | ||||
|     void convertWeight(const T * src, T * dst, int oc, int kernelH, int kernelW) { | ||||
|         for (int c = 0; c < oc; c++) { | ||||
|             for (int h = 0; h < kernelH; h++) { | ||||
|                 for (int w = 0; w < kernelW; w++) { | ||||
|                     int srcOffset = w + kernelW * (h + kernelH * c); | ||||
|                     int dstOffset = c + oc * (w + kernelW * h); | ||||
|                     dst[dstOffset] = src[srcOffset]; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     void isWeightQuantSupported(const Tensor *input, const int oc); | ||||
|     void createWeightAndBias(Qnn_DataType_t dataType, const Tensor *input, int oc, int kernelH, int kernelW); | ||||
|     std::vector<float> mScale; | ||||
|     std::vector<Qnn_ScaleOffset_t> mScaleOffsetData; | ||||
|     std::vector<Qnn_ScaleOffset_t> mBiasScaleOffsetData; | ||||
|     std::vector<uint8_t> mBlockScale; | ||||
|     float *mDequantAlpha = nullptr; | ||||
|     bool mWeightQuant = false; | ||||
| }; | ||||
| 
 | ||||
| } // end namespace QNN
 | ||||
|  |  | |||
|  | @ -21,6 +21,63 @@ static std::pair<int, int> closest_factors(int n) { | |||
|     } | ||||
|     return {1, n};   | ||||
| } | ||||
| 
 | ||||
| void QNNConvolution::isWeightQuantSupported(const Tensor *input, const int ic, const int oc){ | ||||
|     Qnn_DataType_t dataType = mBackend->getNativeTensor(input)->v1.dataType; | ||||
|     if(mOp->main_as_Convolution2D()->quanParameter() == nullptr){ | ||||
|         mWeightQuant = false; | ||||
|         return; | ||||
|     }else{ | ||||
|         bool hasBias = false; | ||||
|         auto bias = mOp->main_as_Convolution2D()->bias(); | ||||
|         auto biasPtr = (float*)bias->data(); | ||||
|         for(int i = 0; i < oc; ++i){ | ||||
|             if(biasPtr[i] != 0.0f){ | ||||
|                 hasBias = true; | ||||
|                 break; | ||||
|             } | ||||
|         } | ||||
|          | ||||
|         std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon = ConvolutionCommon::load(mOp, this->backend(), false, true); | ||||
|         int totalCount = quanCommon->alpha.size(); | ||||
|         mBlockSize = totalCount / oc; | ||||
|         if(quanCommon->asymmetric){ | ||||
|             // not support asymmetric and mBlockSize > 1 results incorrect now
 | ||||
|             mWeightQuant = false; | ||||
|             return; | ||||
|         } | ||||
|          | ||||
|         if(dataType == QNN_DATATYPE_FLOAT_16 || dataType == QNN_DATATYPE_FLOAT_32){ | ||||
|             if(mIsMatMul && mBlockSize == 1){ | ||||
|                 mWeightQuant = true; | ||||
|             }else{ | ||||
|                 mWeightQuant = false; | ||||
|             } | ||||
|             return; | ||||
|         } | ||||
|          | ||||
|         float inputScale = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.scale; | ||||
|         int inputOffset = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.offset; | ||||
|         if(inputOffset == 0){ | ||||
|             mWeightQuant = true; | ||||
|         }else{ | ||||
|             if(hasBias){ | ||||
|                 mWeightQuant = false; | ||||
|             }else{ | ||||
|                 mWeightQuant = true; | ||||
|             } | ||||
|         } | ||||
|          | ||||
|         if(mBlockSize > 1 && mWeightQuant){ | ||||
|             if(mIs1x1Conv && hasBias == false && (ic / mBlockSize) >= 16){ | ||||
|                 mWeightQuant = true; | ||||
|             }else{ | ||||
|                 mWeightQuant = false; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| ErrorCode QNNConvolution::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) { | ||||
|     auto conv2D     = mOp->main_as_Convolution2D(); | ||||
|     auto common     = conv2D->common(); | ||||
|  | @ -46,32 +103,16 @@ ErrorCode QNNConvolution::onEncode(const std::vector<Tensor *> &inputs, const st | |||
|         dilationH = common->dilateY(); dilationW = common->dilateX(); | ||||
|         group = common->group(); | ||||
|     } | ||||
|     mIs1x1Conv = kernelW==1 && strideH==1 && \ | ||||
|                  strideW==1 && dilationH==1 && dilationW==1 && group==1 && \ | ||||
|                  padTop==0 && padBottom==0 && padLeft==0 && padRight==0; | ||||
|     mIsMatMul = ih==1 && iw==1 && oh==1 && ow==1 && mIs1x1Conv; | ||||
|     isWeightQuantSupported(inputs[0], ic, oc); | ||||
|      | ||||
|     const float * weightSource = nullptr; | ||||
|     std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon; | ||||
|     if (mOp->main_as_Convolution2D()->quanParameter()) { | ||||
|         bool forceFloat = (common->kernelX() == 1 && common->kernelY() == 1) ? false : true; | ||||
| 
 | ||||
|         quanCommon = ConvolutionCommon::load(mOp, this->backend(), forceFloat); | ||||
|         if (quanCommon->weightFloat.get() == nullptr) { | ||||
|             // [TODO] Support asymmetric and other quantBits.
 | ||||
|             // isQuantWeight && symmetric quantization && int8 quantization && 1x1 conv
 | ||||
|             if (quanCommon->asymmetric || quanCommon->canUseInt4) { | ||||
|                 return NOT_SUPPORT; | ||||
|             } | ||||
|             return this->onEncodeQuant(inputs[0], outputs[0], n, ih, iw, ic, oc, quanCommon); | ||||
|         } else { | ||||
|             weightSource = quanCommon->weightFloat.get(); | ||||
|         } | ||||
|     } else { | ||||
|         int weightElementNum; | ||||
|         ConvolutionCommon::getConvParameters(&quanCommon, mBackend, mOp, &weightSource, &weightElementNum); | ||||
|     if(mIsMatMul && mWeightQuant && (dataType == QNN_DATATYPE_FLOAT_16 || dataType == QNN_DATATYPE_FLOAT_32)){ | ||||
|         return onEncodeFpAIntBMatMul(inputs[0], outputs[0], n, ih, iw, ic, oc); | ||||
|     } | ||||
|      | ||||
|     #ifdef QNN_VERBOSE | ||||
|     MNN_PRINT("n:%d, ih:%d, iw:%d, ic:%d, oh:%d, ow:%d, oc:%d, kernelH:%d, kernelW:%d, dilationH:%d, dilationW:%d, strideH:%d, strideW:%d, group:%d, pad:%d %d %d %d\n", n, ih, iw, ic, oh, ow, oc, kernelH, kernelW, dilationH, \ | ||||
|         dilationW, strideH, strideW, group, padTop, padBottom, padLeft, padRight); | ||||
|     #endif | ||||
|     // create all tensors and params
 | ||||
|     { | ||||
|         std::vector<uint32_t> strideData = {(uint32_t)strideH, (uint32_t)strideW}; | ||||
|  | @ -85,12 +126,26 @@ ErrorCode QNNConvolution::onEncode(const std::vector<Tensor *> &inputs, const st | |||
|             this->createParamScalar("min_value", 0.0f); | ||||
|             this->createParamScalar("max_value", 6.0f); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|         this->createWeight(dataType, oc, ic, kernelH, kernelW, group, weightSource); | ||||
|         this->createBias(dataType, oc); | ||||
|         if (common->relu() || common->relu6()) { | ||||
|             this->createStageTensor("ReluTensor", dataType, getNHWCShape(outputs[0])); | ||||
|     this->createWeightAndBias(dataType, inputs[0], oc, ic, kernelH, kernelW, group); | ||||
|     // dequant input and quant output
 | ||||
|     if(mWeightQuant == false && dataType != QNN_DATATYPE_FLOAT_16 && dataType != QNN_DATATYPE_FLOAT_32){ | ||||
|         return this->onEncodeQuantDequantConv(inputs[0], outputs[0], n, ic, oc); | ||||
|     } | ||||
|      | ||||
|     if (common->relu() || common->relu6()) { | ||||
|         Qnn_QuantizeParams_t quantize = DEFAULT_QUANTIZE_PARAMS; | ||||
|         Qnn_ScaleOffset_t tScaleOffsetEncoding; | ||||
|         auto quant = TensorUtils::getDescribe(outputs[0])->quantAttr.get(); | ||||
|         if(quant != nullptr && TensorUtils::getDescribe(outputs[0])->type == DataType_DT_INT8){ | ||||
|             quantize.encodingDefinition = QNN_DEFINITION_DEFINED; | ||||
|             quantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; | ||||
|             tScaleOffsetEncoding.scale = quant->scale; | ||||
|             tScaleOffsetEncoding.offset = quant->zero; | ||||
|             quantize.scaleOffsetEncoding = tScaleOffsetEncoding; | ||||
|         } | ||||
|         this->createStageTensor("ReluTensor", dataType, getNHWCShape(outputs[0]), quantize); | ||||
|     } | ||||
| 
 | ||||
|     // add nodes
 | ||||
|  | @ -131,13 +186,34 @@ ErrorCode QNNConvolution::onEncode(const std::vector<Tensor *> &inputs, const st | |||
|             } | ||||
| 
 | ||||
|         } else { | ||||
|             bool isMatmul = ih==1 && iw==1 && oh==1 && ow==1 && kernelH==1 && kernelW==1 && strideH==1 && \ | ||||
|                 strideW==1 && dilationH==1 && dilationW==1 && group==1 && \ | ||||
|                 padTop==0 && padBottom==0 && padLeft==0 && padRight==0; | ||||
|             if(isMatmul && n > 1) { | ||||
|             if(mIsMatMul && n > 1) { | ||||
|                 auto num = closest_factors(n); | ||||
|                 this->createStageTensor("InputReshapeTensor", dataType, std::vector<int>({1, num.first, num.second, ic})); | ||||
|                 this->createStageTensor("OutputReshapeTensor", dataType, std::vector<int>({1, num.first, num.second, oc})); | ||||
|                 { | ||||
|                     Qnn_QuantizeParams_t quantize = DEFAULT_QUANTIZE_PARAMS; | ||||
|                     Qnn_ScaleOffset_t tScaleOffsetEncoding; | ||||
|                     auto quant = TensorUtils::getDescribe(inputs[0])->quantAttr.get(); | ||||
|                     if(quant != nullptr && TensorUtils::getDescribe(inputs[0])->type == DataType_DT_INT8){ | ||||
|                         quantize.encodingDefinition = QNN_DEFINITION_DEFINED; | ||||
|                         quantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; | ||||
|                         tScaleOffsetEncoding.scale = quant->scale; | ||||
|                         tScaleOffsetEncoding.offset = quant->zero; | ||||
|                         quantize.scaleOffsetEncoding = tScaleOffsetEncoding; | ||||
|                     } | ||||
|                     this->createStageTensor("InputReshapeTensor", dataType, std::vector<int>({1, num.first, num.second, ic}), quantize); | ||||
|                 } | ||||
|                 { | ||||
|                     Qnn_QuantizeParams_t quantize = DEFAULT_QUANTIZE_PARAMS; | ||||
|                     Qnn_ScaleOffset_t tScaleOffsetEncoding; | ||||
|                     auto quant = TensorUtils::getDescribe(outputs[0])->quantAttr.get(); | ||||
|                     if(quant != nullptr && TensorUtils::getDescribe(outputs[0])->type == DataType_DT_INT8){ | ||||
|                         quantize.encodingDefinition = QNN_DEFINITION_DEFINED; | ||||
|                         quantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET; | ||||
|                         tScaleOffsetEncoding.scale = quant->scale; | ||||
|                         tScaleOffsetEncoding.offset = quant->zero; | ||||
|                         quantize.scaleOffsetEncoding = tScaleOffsetEncoding; | ||||
|                     } | ||||
|                     this->createStageTensor("OutputReshapeTensor", dataType, std::vector<int>({1, num.first, num.second, oc}), quantize); | ||||
|                 } | ||||
|                 #ifdef QNN_VERBOSE | ||||
|                 MNN_PRINT("Matmul2Conv, start reshape batch:%d -> %dx%d\n", n, num.first, num.second); | ||||
|                 #endif | ||||
|  | @ -208,52 +284,273 @@ ErrorCode QNNConvolution::onEncode(const std::vector<Tensor *> &inputs, const st | |||
|     return NO_ERROR; | ||||
| } | ||||
| 
 | ||||
| ErrorCode QNNConvolution::onEncodeQuantDequantConv(Tensor *input, Tensor *output, const int n, const int ic, const int oc) { | ||||
|     auto conv2D     = mOp->main_as_Convolution2D(); | ||||
|     auto common     = conv2D->common(); | ||||
|     Qnn_DataType_t dataType = QNN_DATATYPE_FLOAT_32; | ||||
|     if(mBackend->getUseFP16()){ | ||||
|         dataType = QNN_DATATYPE_FLOAT_16; | ||||
|     } | ||||
|      | ||||
| ErrorCode QNNConvolution::onEncodeQuant(Tensor * input, Tensor * output, int n, int h, int w, int ic, int oc, std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon) { | ||||
|     // create dequant input stage tensor
 | ||||
|     this->createStageTensor("DequantInput", dataType, getNHWCShape(input)); // mTempTensorWrappers[2]
 | ||||
|     this->createStageTensor("QuantOutput", dataType, getNHWCShape(output)); // mTempTensorWrappers[3]
 | ||||
|      | ||||
|     // add nodes
 | ||||
|     { | ||||
|         // dequant input
 | ||||
|         { | ||||
|             mParams.clear(); | ||||
|             mInputs.clear(); | ||||
|             mOutputs.clear(); | ||||
|             mNodeType = "Dequantize"; | ||||
|             std::string name = mNodeName + "_dequant_input"; | ||||
|          | ||||
|             mInputs.push_back(*(mBackend->getNativeTensor(input))); // input
 | ||||
|             mOutputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput
 | ||||
|             mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|         } | ||||
|          | ||||
|         if (common->relu() || common->relu6()) { | ||||
|             this->createStageTensor("ReluTensor", dataType, getNHWCShape(output)); // mTempTensorWrappers[4]
 | ||||
|             // Stage one
 | ||||
|             { | ||||
|                 mParams.clear(); | ||||
|                 mInputs.clear(); | ||||
|                 mOutputs.clear(); | ||||
|                 mNodeType = "Conv2d"; | ||||
|                 std::string name = mNodeName + "_conv"; | ||||
|                 mParams.push_back(*(mParamTensorWrappers[0]->getNativeParam())); // stride
 | ||||
|                 mParams.push_back(*(mParamTensorWrappers[1]->getNativeParam())); // pad_amount
 | ||||
|                 mParams.push_back(*(mParamTensorWrappers[2]->getNativeParam())); // dilation
 | ||||
|                 mParams.push_back(*(mParamScalarWrappers[0]->getNativeParam())); // group
 | ||||
|          | ||||
|                 mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput
 | ||||
|                 mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // weight
 | ||||
|                 mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // bias
 | ||||
|          | ||||
|                 mOutputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // ReluTensor
 | ||||
|                 mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|             } | ||||
| 
 | ||||
|             // Stage two
 | ||||
|             { | ||||
|                 mParams.clear(); | ||||
|                 mInputs.clear(); | ||||
|                 mOutputs.clear(); | ||||
|                 mNodeType = common->relu6() ? "ReluMinMax" : "Relu"; | ||||
|                 std::string name = mNodeName + "_relu"; | ||||
|                 if (common->relu6()) { | ||||
|                     mParams.push_back(*(mParamScalarWrappers[1]->getNativeParam())); // min_value
 | ||||
|                     mParams.push_back(*(mParamScalarWrappers[2]->getNativeParam())); // max_value
 | ||||
|                 } | ||||
|                 mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // ReluTensor
 | ||||
|                 mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput
 | ||||
|                 mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|             } | ||||
| 
 | ||||
|         } else { | ||||
|             if(mIsMatMul && n > 1) { | ||||
|                 auto num = closest_factors(n); | ||||
|                 this->createStageTensor("InputReshapeTensor", dataType, std::vector<int>({1, num.first, num.second, ic})); // mTempTensorWrappers[4]
 | ||||
|                 this->createStageTensor("OutputReshapeTensor", dataType, std::vector<int>({1, num.first, num.second, oc})); // mTempTensorWrappers[5]
 | ||||
|                 #ifdef QNN_VERBOSE | ||||
|                 MNN_PRINT("Matmul2Conv, start reshape batch:%d -> %dx%d\n", n, num.first, num.second); | ||||
|                 #endif | ||||
|                 // reshape input
 | ||||
|                 { | ||||
|                     std::string name = mNodeName + "_input_reshape"; | ||||
|                     mParams.clear(); | ||||
|                     mInputs.clear(); | ||||
|                     mOutputs.clear(); | ||||
|                     mNodeType = "Reshape"; | ||||
| 
 | ||||
|                     mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput
 | ||||
|                     mOutputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // InputReshapeTensor
 | ||||
|                     mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|                 } | ||||
|                 // conv2d
 | ||||
|                 { | ||||
|                     std::string name = mNodeName; | ||||
|                     mParams.clear(); | ||||
|                     mInputs.clear(); | ||||
|                     mOutputs.clear(); | ||||
|                     mNodeType = "Conv2d"; | ||||
| 
 | ||||
|                     mParams.push_back(*(mParamTensorWrappers[0]->getNativeParam())); // stride
 | ||||
|                     mParams.push_back(*(mParamTensorWrappers[1]->getNativeParam())); // pad_amount
 | ||||
|                     mParams.push_back(*(mParamTensorWrappers[2]->getNativeParam())); // dilation
 | ||||
|                     mParams.push_back(*(mParamScalarWrappers[0]->getNativeParam())); // group
 | ||||
| 
 | ||||
|                     mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // InputReshapeTensor
 | ||||
|                     mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // weight
 | ||||
|                     mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // bias
 | ||||
| 
 | ||||
|                     mOutputs.push_back(*(mTempTensorWrappers[5]->getNativeTensor())); // OutputReshapeTensor
 | ||||
|                     mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|                 } | ||||
| 
 | ||||
|                 // reshape output
 | ||||
|                 { | ||||
|                     std::string name = mNodeName + "_output_reshape"; | ||||
|                     mParams.clear(); | ||||
|                     mInputs.clear(); | ||||
|                     mOutputs.clear(); | ||||
|                     mNodeType = "Reshape"; | ||||
| 
 | ||||
|                     mInputs.push_back(*(mTempTensorWrappers[5]->getNativeTensor())); // OutputReshapeTensor
 | ||||
|                     mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput
 | ||||
|                     mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|                 } | ||||
|             } else{ | ||||
|                 mParams.clear(); | ||||
|                 mInputs.clear(); | ||||
|                 mOutputs.clear(); | ||||
|                 mNodeType = "Conv2d"; | ||||
|                 mParams.push_back(*(mParamTensorWrappers[0]->getNativeParam())); // stride
 | ||||
|                 mParams.push_back(*(mParamTensorWrappers[1]->getNativeParam())); // pad_amount
 | ||||
|                 mParams.push_back(*(mParamTensorWrappers[2]->getNativeParam())); // dilation
 | ||||
|                 mParams.push_back(*(mParamScalarWrappers[0]->getNativeParam())); // group
 | ||||
|                  | ||||
|                 mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput
 | ||||
|                 mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // weight
 | ||||
|                 mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // bias
 | ||||
|                  | ||||
|                 mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput
 | ||||
|                 mBackend->addNodeToGraph(mOpConfigVersion, mNodeName.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|             } | ||||
|         } | ||||
|          | ||||
|         // Quant output
 | ||||
|         { | ||||
|             auto QuantOutputTensor = mTempTensorWrappers[3]->getNativeTensor(); | ||||
|             if(mBackend->getUseFP16()){ | ||||
|                 this->createStageTensor("CastOutput", QNN_DATATYPE_FLOAT_32, getNHWCShape(output)); | ||||
|                  | ||||
|                 mParams.clear(); | ||||
|                 mInputs.clear(); | ||||
|                 mOutputs.clear(); | ||||
|                 mNodeType = "Cast"; | ||||
|                 std::string name = mNodeName + "_Cast_Output"; | ||||
|                  | ||||
|                 mInputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput
 | ||||
|                 mOutputs.push_back(*(mTempTensorWrappers.back()->getNativeTensor())); // CastOutput
 | ||||
|                 mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|                 QuantOutputTensor = mTempTensorWrappers.back()->getNativeTensor(); | ||||
|             } | ||||
|             { | ||||
|                 mParams.clear(); | ||||
|                 mInputs.clear(); | ||||
|                 mOutputs.clear(); | ||||
|                 mNodeType = "Quantize"; | ||||
|                 std::string name = mNodeName + "_Quant_Output"; | ||||
|                  | ||||
|                 mInputs.push_back(*(QuantOutputTensor)); // stage tensor
 | ||||
|                 mOutputs.push_back(*(mBackend->getNativeTensor(output))); // output
 | ||||
|                 mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     return NO_ERROR; | ||||
| } | ||||
| 
 | ||||
| ErrorCode QNNConvolution::onEncodeFpAIntBMatMul(Tensor * input, Tensor * output, int n, int h, int w, int ic, int oc) { | ||||
|     // create parameters and stage tensors
 | ||||
|     auto conv2D     = mOp->main_as_Convolution2D(); | ||||
|     auto common     = conv2D->common(); | ||||
|     Qnn_DataType_t dataType = mBackend->getNativeTensor(input)->v1.dataType; | ||||
|     { | ||||
|         bool transposeWeightFlag = true; | ||||
|         this->createParamScalar("transpose_in1", transposeWeightFlag); | ||||
|          | ||||
|         std::vector<uint32_t> tempInputShape = {(uint32_t) n * h * w , (uint32_t) ic}; | ||||
|         std::vector<uint32_t> tempOutputShape = {(uint32_t) n * h * w , (uint32_t) oc}; | ||||
|         this->createStageTensor("tempInput", QNN_DATATYPE_FLOAT_16, tempInputShape); | ||||
|         this->createStageTensor("tempOutput", QNN_DATATYPE_FLOAT_16, tempOutputShape); | ||||
|         this->createStageTensor("tempInput", dataType, tempInputShape); | ||||
|         this->createStageTensor("tempOutput", dataType, tempOutputShape); | ||||
| 
 | ||||
|         // create weight
 | ||||
|         const int8_t * source = quanCommon->weight.get(); | ||||
|         std::vector<int8_t> quantWeightData(oc * ic, 0); | ||||
|         ::memcpy(quantWeightData.data(), source, oc * ic * sizeof(int8_t)); | ||||
| 
 | ||||
|         float * dequantAlpha = quanCommon->alpha.get(); | ||||
| 
 | ||||
|         Qnn_QuantizeParams_t weightQuantize{}; | ||||
|         weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED; | ||||
|         weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET; | ||||
|         Qnn_AxisScaleOffset_t weightAxisScaleOffsetEncoding{}; | ||||
|         weightAxisScaleOffsetEncoding.axis = 0; | ||||
|         weightAxisScaleOffsetEncoding.numScaleOffsets = oc; | ||||
|         std::vector<Qnn_ScaleOffset_t> scaleOffsetData(oc); | ||||
|         for (int i = 0; i < oc; i++) { | ||||
|             if (quanCommon->asymmetric) { | ||||
|                 scaleOffsetData[i].scale = dequantAlpha[2 * i + 1]; | ||||
|                 // scaleOffsetData[i].offset = (int) dequantAlpha[2 * i + 0];
 | ||||
|                 scaleOffsetData[i].offset = 0; | ||||
|             } else { | ||||
|                 scaleOffsetData[i].scale = dequantAlpha[i]; | ||||
|                 scaleOffsetData[i].offset = 0; | ||||
|         // create weight and bias
 | ||||
|         { | ||||
|             Qnn_QuantizeParams_t weightQuantize{}; | ||||
|             std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon = ConvolutionCommon::load(mOp, this->backend(), false, true); | ||||
|             MNN_ASSERT(!quanCommon->asymmetric); | ||||
|             const int8_t * source = quanCommon->weight.get(); | ||||
|             std::vector<int8_t> quantWeightData(oc * ic, 0); | ||||
|             if(quanCommon->canUseInt4){ | ||||
|                 for (int o = 0; o < oc; o++) { | ||||
|                     for (int i = 0; i < ic; i++) { | ||||
|                         uint32_t srcOffset = o * ic + i; | ||||
|                         uint32_t dstOffset = srcOffset; | ||||
|                         if(srcOffset % 2 == 0){ | ||||
|                             quantWeightData[dstOffset] = ((source[srcOffset / 2] >> 4) & 0x0f) - 8; | ||||
|                         }else{ | ||||
|                             quantWeightData[dstOffset] = (source[srcOffset / 2] & 0x0f) - 8; | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|             }else{ | ||||
|                 ::memcpy(quantWeightData.data(), source, oc * ic * sizeof(int8_t)); | ||||
|             } | ||||
|         } | ||||
|         weightAxisScaleOffsetEncoding.scaleOffset = scaleOffsetData.data(); | ||||
|         weightQuantize.axisScaleOffsetEncoding = weightAxisScaleOffsetEncoding; | ||||
|             mDequantAlpha = quanCommon->alpha.get(); | ||||
|             int totalCount = quanCommon->alpha.size(); | ||||
|             mBlockSize = totalCount / oc; | ||||
|             int blockNum = ic / mBlockSize; | ||||
|             if(quanCommon->canUseInt4){ | ||||
|                 weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED; | ||||
|                 weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET; | ||||
|                 Qnn_BwAxisScaleOffset_t weightBWAxisScaleOffsetEncoding{}; | ||||
|                 weightBWAxisScaleOffsetEncoding.bitwidth = 4; | ||||
|                 weightBWAxisScaleOffsetEncoding.axis = 0; | ||||
|                 weightBWAxisScaleOffsetEncoding.numElements = oc; | ||||
|                 mScale.resize(oc); | ||||
|                 std::vector<int32_t> OffsetData(oc); | ||||
|                 for (int i = 0; i < oc; i++) { | ||||
|                     mScale[i] = mDequantAlpha[i]; | ||||
|                 } | ||||
|                 weightBWAxisScaleOffsetEncoding.scales = mScale.data(); | ||||
|                 weightQuantize.bwAxisScaleOffsetEncoding = weightBWAxisScaleOffsetEncoding; | ||||
|                  | ||||
|         this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t) oc, (uint32_t) ic}, (void *) quantWeightData.data(), weightQuantize); | ||||
|                 this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)oc, (uint32_t)ic}, (void *) quantWeightData.data(), weightQuantize); | ||||
|                 std::function<void()> mReleaseWeightScaleOffset = [&](){ | ||||
|                     std::vector<float>().swap(mScale); | ||||
|                 }; | ||||
|                 mBackend->pushReleaseFunc(mReleaseWeightScaleOffset); | ||||
|             }else{ | ||||
|                 weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED; | ||||
|                 weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET; | ||||
|                 Qnn_AxisScaleOffset_t weightAxisScaleOffsetEncoding{}; | ||||
|                 weightAxisScaleOffsetEncoding.axis = 0; | ||||
|                 weightAxisScaleOffsetEncoding.numScaleOffsets = oc; | ||||
|                 mScaleOffsetData.resize(oc); | ||||
|                 for (int i = 0; i < oc; i++) { | ||||
|                     mScaleOffsetData[i].scale = mDequantAlpha[i]; | ||||
|                     mScaleOffsetData[i].offset = 0; | ||||
|                 } | ||||
|                 weightAxisScaleOffsetEncoding.scaleOffset = mScaleOffsetData.data(); | ||||
|                 weightQuantize.axisScaleOffsetEncoding = weightAxisScaleOffsetEncoding; | ||||
|                  | ||||
|                 this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)oc, (uint32_t)ic}, (void *) quantWeightData.data(), weightQuantize); | ||||
|                 std::function<void()> mReleaseWeightScaleOffset = [&](){ | ||||
|                     std::vector<Qnn_ScaleOffset_t>().swap(mScaleOffsetData); | ||||
|                 }; | ||||
|                 mBackend->pushReleaseFunc(mReleaseWeightScaleOffset); | ||||
|             } | ||||
|             //create bias
 | ||||
|             this->createBias(dataType, oc, input, quanCommon); | ||||
|         } | ||||
|          | ||||
|         if (common->relu6()) { | ||||
|             this->createParamScalar("min_value", 0.0f); | ||||
|             this->createParamScalar("max_value", 6.0f); | ||||
|         } | ||||
|         if (common->relu() || common->relu6()) { | ||||
|             this->createStageTensor("ReluTensor", dataType, getNHWCShape(output)); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     // Stage One: reshape input
 | ||||
|     { | ||||
|         mNodeType = "Reshape"; | ||||
|         std::string name = mNodeName + "_reshapeOutput"; | ||||
|         std::string name = mNodeName + "_reshapeInput"; | ||||
|         mParams.clear(); | ||||
|         mInputs.clear(); | ||||
|         mOutputs.clear(); | ||||
|  | @ -273,6 +570,7 @@ ErrorCode QNNConvolution::onEncodeQuant(Tensor * input, Tensor * output, int n, | |||
|         mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // tempInput
 | ||||
|         // mInputs.push_back(*(mBackend->getNativeTensor(input)));
 | ||||
|         mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // weight
 | ||||
|         mInputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // bias
 | ||||
|         mOutputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // tempOutput
 | ||||
|         // mOutputs.push_back(*(mBackend->getNativeTensor(output)));
 | ||||
|         mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|  | @ -286,48 +584,217 @@ ErrorCode QNNConvolution::onEncodeQuant(Tensor * input, Tensor * output, int n, | |||
|         mInputs.clear(); | ||||
|         mOutputs.clear(); | ||||
|         mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); | ||||
|         mOutputs.push_back(*(mBackend->getNativeTensor(output))); | ||||
|         if (common->relu() || common->relu6()){ | ||||
|             mOutputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); //ReluTensor
 | ||||
|         }else{ | ||||
|             mOutputs.push_back(*(mBackend->getNativeTensor(output))); | ||||
|         } | ||||
|         mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|     } | ||||
|      | ||||
|     // Stage Four: relu or relu6
 | ||||
|     if (common->relu() || common->relu6()){ | ||||
|         mNodeType.clear(); | ||||
|         mParams.clear(); | ||||
|         mInputs.clear(); | ||||
|         mOutputs.clear(); | ||||
|         mNodeType = common->relu6() ? "ReluMinMax" : "Relu"; | ||||
|         std::string name = mNodeName + "_relu"; | ||||
|         if (common->relu6()) { | ||||
|             mParams.push_back(*(mParamScalarWrappers[1]->getNativeParam())); // min_value
 | ||||
|             mParams.push_back(*(mParamScalarWrappers[2]->getNativeParam())); // max_value
 | ||||
|         } | ||||
|         mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // ReluTensor
 | ||||
|         mOutputs.push_back(*(mBackend->getNativeTensor(output))); // output
 | ||||
|         mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|     } | ||||
|     return NO_ERROR; | ||||
| } | ||||
| 
 | ||||
| bool QNNConvolution::createWeightAndBias(Qnn_DataType_t dataType, const Tensor *input, int oc, int ic, int kernelH, int kernelW, int group) { | ||||
|     if(mWeightQuant){ | ||||
|         Qnn_QuantizeParams_t weightQuantize{}; | ||||
|         std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon = ConvolutionCommon::load(mOp, this->backend(), false, true); | ||||
|         if(quanCommon->asymmetric) { | ||||
|             MNN_ERROR("[Error]: Qnn weight quant only support symmetric currently\n"); | ||||
|             return false; | ||||
|         } | ||||
|         const int8_t * source = quanCommon->weight.get(); | ||||
|         std::vector<int8_t> quantWeightData(oc * (ic / group) * kernelH * kernelW, 0); | ||||
|         if(quanCommon->canUseInt4){ | ||||
|             for (int o = 0; o < oc; o++) { | ||||
|                 for (int i = 0; i < ic/group; i++) { | ||||
|                     for (int h = 0; h < kernelH; h++) { | ||||
|                         for (int w = 0; w < kernelW; w++) { | ||||
|                             uint32_t srcOffset = w + kernelW * (h + kernelH * (i + ic/group * o)); | ||||
|                             uint32_t dstOffset = o + oc * (i + ic/group * (w + kernelW * h)); | ||||
|                             if(srcOffset % 2 == 0){ | ||||
|                                 quantWeightData[dstOffset] = ((source[srcOffset / 2] >> 4) & 0x0f) - 8; | ||||
|                             }else{ | ||||
|                                 quantWeightData[dstOffset] = (source[srcOffset / 2] & 0x0f) - 8; | ||||
|                             } | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         }else{ | ||||
|             convertWeight(source, (int8_t *) quantWeightData.data(), oc, ic/group, kernelH, kernelW); | ||||
|         } | ||||
|         mDequantAlpha = quanCommon->alpha.get(); | ||||
|         int totalCount = quanCommon->alpha.size(); | ||||
|         mBlockSize = totalCount / oc; | ||||
|         // Todo: result is wrong, need to verify
 | ||||
|         if(mBlockSize > 1){ | ||||
|             Qnn_QuantizeParams_t weightQuantize{}; | ||||
|             weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED; | ||||
|             weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION; | ||||
|              | ||||
|             weightBlockwiseExpansionEncoding.axis = 3; | ||||
|             weightBlockwiseExpansionEncoding.numBlocksPerAxis = mBlockSize; | ||||
|             weightBlockwiseExpansionEncoding.blockScaleBitwidth = 4; | ||||
|             weightBlockwiseExpansionEncoding.blockScaleStorageType = QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8; | ||||
|             mBlockScale.resize(oc * mBlockSize); | ||||
|             mScaleOffsetData.resize(oc); | ||||
|             for (int i = 0; i < oc; i++) { | ||||
|                 float maxscale = -MAXFLOAT; | ||||
|                 for(int j = 0; j < mBlockSize; ++j){ | ||||
|                     if(mDequantAlpha[i * mBlockSize + j] > maxscale){ | ||||
|                         maxscale = mDequantAlpha[i * mBlockSize + j]; | ||||
|                     } | ||||
|                 } | ||||
|                 float blockScale = maxscale / 16.0f; | ||||
|                 for(int j = 0; j < mBlockSize; ++j){ | ||||
|                     int quantBlock = round(mDequantAlpha[i * mBlockSize + j] / blockScale); | ||||
|                     mBlockScale[i * mBlockSize + j] = (uint8_t)std::min(std::max(quantBlock, 1), 16); | ||||
|                 } | ||||
|                 mScaleOffsetData[i].scale = blockScale; | ||||
|                 mScaleOffsetData[i].offset = 0; | ||||
|             } | ||||
|             weightBlockwiseExpansionEncoding.scaleOffsets = mScaleOffsetData.data(); | ||||
|             weightBlockwiseExpansionEncoding.blocksScale8 = mBlockScale.data(); | ||||
|             weightQuantize.blockwiseExpansion = &weightBlockwiseExpansionEncoding; | ||||
|             this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)kernelH, (uint32_t)kernelW, (uint32_t)ic / (uint32_t)group, (uint32_t)oc}, (void *) quantWeightData.data(), weightQuantize); | ||||
|             std::function<void()> mReleaseWeightScaleOffset = [&](){ | ||||
|                 std::vector<Qnn_ScaleOffset_t>().swap(mScaleOffsetData); | ||||
|             }; | ||||
|             mBackend->pushReleaseFunc(mReleaseWeightScaleOffset); | ||||
|             std::function<void()> mReleaseBlockScale = [&](){ | ||||
|                 std::vector<uint8_t>().swap(mBlockScale); | ||||
|             }; | ||||
|             mBackend->pushReleaseFunc(mReleaseBlockScale); | ||||
|         }else if(quanCommon->canUseInt4){ | ||||
|             weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED; | ||||
|             weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET; | ||||
|             Qnn_BwAxisScaleOffset_t weightBWAxisScaleOffsetEncoding{}; | ||||
|             weightBWAxisScaleOffsetEncoding.bitwidth = 4; | ||||
|             weightBWAxisScaleOffsetEncoding.axis = 3; | ||||
|             weightBWAxisScaleOffsetEncoding.numElements = oc; | ||||
|             mScale.resize(oc); | ||||
|             std::vector<int32_t> OffsetData(oc); | ||||
|             for (int i = 0; i < oc; i++) { | ||||
|                 mScale[i] = mDequantAlpha[i]; | ||||
|             } | ||||
|             weightBWAxisScaleOffsetEncoding.scales = mScale.data(); | ||||
|             weightQuantize.bwAxisScaleOffsetEncoding = weightBWAxisScaleOffsetEncoding; | ||||
|              | ||||
|             this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)kernelH, (uint32_t)kernelW, (uint32_t)ic / (uint32_t)group, (uint32_t)oc}, (void *) quantWeightData.data(), weightQuantize); | ||||
|             std::function<void()> mReleaseWeightScaleOffset = [&](){ | ||||
|                 std::vector<float>().swap(mScale); | ||||
|             }; | ||||
|             mBackend->pushReleaseFunc(mReleaseWeightScaleOffset); | ||||
|         }else{ | ||||
|             weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED; | ||||
|             weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET; | ||||
|             Qnn_AxisScaleOffset_t weightAxisScaleOffsetEncoding{}; | ||||
|             weightAxisScaleOffsetEncoding.axis = 3; | ||||
|             weightAxisScaleOffsetEncoding.numScaleOffsets = oc; | ||||
|             mScaleOffsetData.resize(oc); | ||||
|             for (int i = 0; i < oc; i++) { | ||||
|                 mScaleOffsetData[i].scale = mDequantAlpha[i]; | ||||
|                 mScaleOffsetData[i].offset = 0; | ||||
|             } | ||||
|             weightAxisScaleOffsetEncoding.scaleOffset = mScaleOffsetData.data(); | ||||
|             weightQuantize.axisScaleOffsetEncoding = weightAxisScaleOffsetEncoding; | ||||
| 
 | ||||
|             this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)kernelH, (uint32_t)kernelW, (uint32_t)ic / (uint32_t)group, (uint32_t)oc}, (void *) quantWeightData.data(), weightQuantize); | ||||
|             std::function<void()> mReleaseWeightScaleOffset = [&](){ | ||||
|                 std::vector<Qnn_ScaleOffset_t>().swap(mScaleOffsetData); | ||||
|             }; | ||||
|             mBackend->pushReleaseFunc(mReleaseWeightScaleOffset); | ||||
|         } | ||||
|         this->createBias(dataType, oc, input, quanCommon); | ||||
|     } else { | ||||
|         std::vector<float> weightData; | ||||
|         const float* source = nullptr; | ||||
|         int weightElementNum = 0; | ||||
|         std::shared_ptr<ConvolutionCommon::Int8Common> quanWeight; | ||||
|         ConvolutionCommon::getConvParameters(&quanWeight, mBackend, mOp, &source, &weightElementNum); | ||||
|         // oc ic h w ---> h w ic oc
 | ||||
|         weightData.resize(weightElementNum); | ||||
|         convertWeight(source, (float *) weightData.data(), oc, ic/group, kernelH, kernelW); | ||||
|         Qnn_DataType_t floatDatatype = QNN_DATATYPE_FLOAT_32; | ||||
|         if(mBackend->getUseFP16()){ | ||||
|             floatDatatype = QNN_DATATYPE_FLOAT_16; | ||||
|         } | ||||
|         this->createStaticFloatTensor("weight", floatDatatype, {(uint32_t)kernelH, (uint32_t)kernelW, (uint32_t)ic / (uint32_t)group, (uint32_t)oc}, weightData.data()); | ||||
|         this->createBias(dataType, oc, input, nullptr); | ||||
|     } | ||||
|     return NO_ERROR; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| void QNNConvolution::createWeight(Qnn_DataType_t dataType, int oc, int ic, int kernelH, int kernelW, int group, const float * source) { | ||||
|     std::vector<float> weightData; | ||||
|     int weightElementNum = oc * ic / group * kernelH * kernelW; | ||||
|     // oc ic/group h w ---> h w ic/group oc
 | ||||
|     weightData.resize(weightElementNum); | ||||
| 
 | ||||
|     convertWeight(source, (float *) weightData.data(), oc, ic/group, kernelH, kernelW); | ||||
|     this->createStaticFloatTensor("weight", dataType, {(uint32_t)kernelH, (uint32_t)kernelW, (uint32_t)ic / (uint32_t)group, (uint32_t)oc}, weightData.data()); | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| void QNNConvolution::createBias(Qnn_DataType_t dataType, int oc) { | ||||
| void QNNConvolution::createBias(Qnn_DataType_t dataType, int oc, const Tensor *input, std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon) { | ||||
|     int biasElementNum = oc; | ||||
|     std::vector<float> biasData; | ||||
|     biasData.resize(biasElementNum, 0); | ||||
|     auto bias = mOp->main_as_Convolution2D()->bias(); | ||||
|     if (nullptr != bias) { | ||||
|         ::memcpy((void *)biasData.data(), (void *)bias->data(), biasElementNum * sizeof(float)); | ||||
|     } | ||||
|     this->createStaticFloatTensor("bias", dataType, {(uint32_t)oc}, biasData.data()); | ||||
| } | ||||
|     if(dataType != QNN_DATATYPE_FLOAT_16 && dataType != QNN_DATATYPE_FLOAT_32 && mWeightQuant){ | ||||
|         mDequantAlpha = quanCommon->alpha.get(); | ||||
|         float inputScale = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.scale; | ||||
|         int inputOffset = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.offset; | ||||
|         std::vector<int> biasData; | ||||
|         biasData.resize(biasElementNum, 0); | ||||
| 
 | ||||
| // oc ic h w ---> h w ic oc
 | ||||
| void QNNConvolution::convertWeight(const float * src, float * dst, int oc, int ic, int kernelH, int kernelW) { | ||||
|     for (int o = 0; o < oc; o++) { | ||||
|         for (int i = 0; i < ic; i++) { | ||||
|             for (int h = 0; h < kernelH; h++) { | ||||
|                 for (int w = 0; w < kernelW; w++) { | ||||
|                     uint32_t srcOffset = w + kernelW * (h + kernelH * (i + ic * o)); | ||||
|                     uint32_t dstOffset = o + oc * (i + ic * (w + kernelW * h)); | ||||
|                     dst[dstOffset] = src[srcOffset]; | ||||
|         Qnn_QuantizeParams_t biasQuantize{}; | ||||
|         biasQuantize.encodingDefinition = QNN_DEFINITION_DEFINED; | ||||
|         biasQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET; | ||||
|         Qnn_AxisScaleOffset_t biasAxisScaleOffsetEncoding{}; | ||||
|         biasAxisScaleOffsetEncoding.axis = 0; | ||||
|         biasAxisScaleOffsetEncoding.numScaleOffsets = biasElementNum; | ||||
|         mBiasScaleOffsetData.resize(biasElementNum); | ||||
| 
 | ||||
|         auto bias = mOp->main_as_Convolution2D()->bias(); | ||||
|         auto biasPtr = (float*)bias->data(); | ||||
|         if (nullptr != bias) { | ||||
|             for(int i = 0; i < biasElementNum; ++i){ | ||||
|                 float biasScale = inputScale * mDequantAlpha[i]; | ||||
|                 mBiasScaleOffsetData[i].scale = biasScale; | ||||
|                 mBiasScaleOffsetData[i].offset = 0; | ||||
|                 if(fabs(biasPtr[i]) < 0.000001 || fabs(biasScale) < 0.000001){ | ||||
|                     biasData[i] = 0; | ||||
|                 } else{ | ||||
|                     biasData[i] = (int)(biasPtr[i] / biasScale); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|          | ||||
|         biasAxisScaleOffsetEncoding.scaleOffset = mBiasScaleOffsetData.data(); | ||||
|         biasQuantize.axisScaleOffsetEncoding = biasAxisScaleOffsetEncoding; | ||||
| 
 | ||||
|         this->createStaticTensor("bias", QNN_DATATYPE_SFIXED_POINT_32, {(uint32_t)biasElementNum}, biasData.data(), biasQuantize); | ||||
|         std::function<void()> mReleaseBiasScaleOffset = [&](){ | ||||
|             std::vector<Qnn_ScaleOffset_t>().swap(mBiasScaleOffsetData); | ||||
|         }; | ||||
|         mBackend->pushReleaseFunc(mReleaseBiasScaleOffset); | ||||
|     }else{ | ||||
|         std::vector<float> biasData; | ||||
|         biasData.resize(biasElementNum, 0); | ||||
|         auto bias = mOp->main_as_Convolution2D()->bias(); | ||||
|         if (nullptr != bias) { | ||||
|             ::memcpy((void *)biasData.data(), (void *)bias->data(), biasElementNum * sizeof(float)); | ||||
|         } | ||||
|         Qnn_DataType_t floatDatatype = QNN_DATATYPE_FLOAT_32; | ||||
|         if(mBackend->getUseFP16()){ | ||||
|             floatDatatype = QNN_DATATYPE_FLOAT_16; | ||||
|         } | ||||
|         this->createStaticFloatTensor("bias", floatDatatype, {(uint32_t)oc}, biasData.data()); | ||||
|     } | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -19,12 +19,37 @@ class QNNConvolution : public QNNCommonExecution { | |||
| public: | ||||
|     QNNConvolution(Backend *backend, const Op *op) : QNNCommonExecution(backend, op) {} | ||||
|     virtual ErrorCode onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; | ||||
|     ErrorCode onEncodeQuant(Tensor * input, Tensor * output, int n, int h, int w, int ic, int oc, std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon); | ||||
|     ErrorCode onEncodeFpAIntBMatMul(Tensor * input, Tensor * output, int n, int h, int w, int ic, int oc); | ||||
|     ErrorCode onEncodeQuantDequantConv(Tensor *input, Tensor *output, const int n, const int ic, const int oc); | ||||
| 
 | ||||
| private: | ||||
|     void createWeight(Qnn_DataType_t dataType, int oc, int ic, int kernelH, int kernelW, int group, const float * source); | ||||
|     void createBias(Qnn_DataType_t dataType, int oc); | ||||
|     void convertWeight(const float * src, float * dst, int oc, int ic, int kernelH, int kernelW); | ||||
|     template <typename T> | ||||
|     void convertWeight(const T * src, T * dst, int oc, int ic, int kernelH, int kernelW) { | ||||
|         for (int o = 0; o < oc; o++) { | ||||
|             for (int i = 0; i < ic; i++) { | ||||
|                 for (int h = 0; h < kernelH; h++) { | ||||
|                     for (int w = 0; w < kernelW; w++) { | ||||
|                         uint32_t srcOffset = w + kernelW * (h + kernelH * (i + ic * o)); | ||||
|                         uint32_t dstOffset = o + oc * (i + ic * (w + kernelW * h)); | ||||
|                         dst[dstOffset] = src[srcOffset]; | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     void isWeightQuantSupported(const Tensor *input, const int ic, const int oc); | ||||
|     bool createWeightAndBias(Qnn_DataType_t dataType, const Tensor *input, int oc, int ic, int kernelH, int kernelW, int group); | ||||
|     void createBias(Qnn_DataType_t dataType, int oc, const Tensor *input, std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon); | ||||
|     std::vector<float> mScale; | ||||
|     std::vector<Qnn_ScaleOffset_t> mScaleOffsetData; | ||||
|     std::vector<Qnn_ScaleOffset_t> mBiasScaleOffsetData; | ||||
|     std::vector<uint8_t> mBlockScale; | ||||
|     Qnn_BlockwiseExpansion_t weightBlockwiseExpansionEncoding = QNN_BLOCKWISE_EXPANSION_INIT; | ||||
|     float *mDequantAlpha = nullptr; | ||||
|     int mBlockSize = 1; | ||||
|     bool mWeightQuant = false; | ||||
|     bool mIsMatMul = false; | ||||
|     bool mIs1x1Conv = false; | ||||
| }; | ||||
| 
 | ||||
| } // end namespace QNN
 | ||||
|  |  | |||
|  | @ -0,0 +1,81 @@ | |||
| //
 | ||||
| //  QNNQuant.cpp
 | ||||
| //  MNN
 | ||||
| //
 | ||||
| //  Created by MNN on b'2025/05/29'.
 | ||||
| //  Copyright © 2018, Alibaba Group Holding Limited
 | ||||
| //
 | ||||
| 
 | ||||
| #include "QNNQuant.hpp" | ||||
| 
 | ||||
| namespace MNN { | ||||
| namespace QNN { | ||||
| 
 | ||||
| ErrorCode QNNQuant::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) { | ||||
|     this->createStageTensor("Cast", QNN_DATATYPE_FLOAT_32, getNHWCShape(outputs[0])); | ||||
|      // Stage one  fp16 -> fp32
 | ||||
|     { | ||||
|         mNodeType = "Cast"; | ||||
|         std::string name = mNodeName + "_Cast"; | ||||
|      | ||||
|         mInputs.push_back(*(mBackend->getNativeTensor(inputs[0]))); // input
 | ||||
|         mOutputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // stage tensor
 | ||||
|         mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|     } | ||||
| 
 | ||||
|     // Stage two  fp32 -> int8
 | ||||
|     { | ||||
|         mNodeType.clear(); | ||||
|         mParams.clear(); | ||||
|         mInputs.clear(); | ||||
|         mOutputs.clear(); | ||||
|         mNodeType = "Quantize"; | ||||
|         std::string name = mNodeName; | ||||
|      | ||||
|         mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // stage tensor
 | ||||
|         mOutputs.push_back(*(mBackend->getNativeTensor(outputs[0]))); // output
 | ||||
|         mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|     } | ||||
|     return NO_ERROR; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| class QNNQuantCreator : public QnnBackend::Creator { | ||||
| public: | ||||
|     virtual QNNCommonExecution * onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, const MNN::Op* op, | ||||
|                                 Backend* backend) const override { | ||||
|         return new QNNQuant(backend, op); | ||||
|     } | ||||
| }; | ||||
| 
 | ||||
| ErrorCode QNNDeQuant::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) { | ||||
|      // Stage one  int8 -> fp16
 | ||||
|     { | ||||
|         mNodeType.clear(); | ||||
|         mParams.clear(); | ||||
|         mInputs.clear(); | ||||
|         mOutputs.clear(); | ||||
|         mNodeType = "Dequantize"; | ||||
|         std::string name = mNodeName; | ||||
|      | ||||
|         mInputs.push_back(*(mBackend->getNativeTensor(inputs[0]))); // input
 | ||||
|         mOutputs.push_back(*(mBackend->getNativeTensor(outputs[0]))); // output
 | ||||
|         mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|     } | ||||
|     return NO_ERROR; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| class QNNDeQuantCreator : public QnnBackend::Creator { | ||||
| public: | ||||
|     virtual QNNCommonExecution * onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, const MNN::Op* op, | ||||
|                                 Backend* backend) const override { | ||||
|         return new QNNDeQuant(backend, op); | ||||
|     } | ||||
| }; | ||||
| 
 | ||||
| REGISTER_QNN_OP_CREATOR(QNNQuantCreator, OpType_FloatToInt8) | ||||
| REGISTER_QNN_OP_CREATOR(QNNDeQuantCreator, OpType_Int8ToFloat) | ||||
| 
 | ||||
| } // end namespace QNN
 | ||||
| } // end namespace MNN
 | ||||
|  | @ -0,0 +1,32 @@ | |||
| //
 | ||||
| //  QNNQuant.hpp
 | ||||
| //  MNN
 | ||||
| //
 | ||||
| //  Created by MNN on b'2025/05/29'.
 | ||||
| //  Copyright © 2018, Alibaba Group Holding Limited
 | ||||
| //
 | ||||
| 
 | ||||
| #ifndef MNN_QNNQUANT_HPP | ||||
| #define MNN_QNNQUANT_HPP | ||||
| 
 | ||||
| #include "QNNCommonExecution.hpp" | ||||
| 
 | ||||
| namespace MNN { | ||||
| namespace QNN { | ||||
| 
 | ||||
| class QNNQuant : public QNNCommonExecution { | ||||
| public: | ||||
|     QNNQuant(Backend *backend, const Op *op) : QNNCommonExecution(backend, op) {}; | ||||
|     virtual ErrorCode onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; | ||||
| }; | ||||
| 
 | ||||
| class QNNDeQuant : public QNNCommonExecution { | ||||
| public: | ||||
|     QNNDeQuant(Backend *backend, const Op *op) : QNNCommonExecution(backend, op) {}; | ||||
|     virtual ErrorCode onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; | ||||
| }; | ||||
| 
 | ||||
| } // end namespace QNN
 | ||||
| } // end namespace MNN
 | ||||
| 
 | ||||
| #endif // end MNN_QNNQUANT_HPP
 | ||||
|  | @ -28,10 +28,25 @@ ErrorCode QNNScale::onEncode(const std::vector<Tensor *> &inputs, const std::vec | |||
|         MNN_ASSERT(channel == mWeightData.size()); | ||||
| 
 | ||||
|         Qnn_DataType_t dataType = mBackend->getNativeTensor(inputs[0])->v1.dataType; | ||||
| 
 | ||||
|         this->createStaticFloatTensor("weight", dataType, {(uint32_t)channel}, mWeightData.data()); | ||||
|         this->createStaticFloatTensor("bias", dataType, {(uint32_t)channel}, mBiasData.data()); | ||||
|         this->createStageTensor("Stage", dataType, getNHWCShape(inputs[0])); | ||||
|         mNeedQuantDequant = dataType != QNN_DATATYPE_FLOAT_16 && dataType != QNN_DATATYPE_FLOAT_32; | ||||
|         if(mNeedQuantDequant){ | ||||
|             Qnn_DataType_t tempDataType = QNN_DATATYPE_FLOAT_32; | ||||
|             if(mBackend->getUseFP16()){ | ||||
|                 tempDataType = QNN_DATATYPE_FLOAT_16; | ||||
|             } | ||||
|             this->createStaticFloatTensor("weight", tempDataType, {(uint32_t)channel}, mWeightData.data()); | ||||
|             this->createStaticFloatTensor("bias", tempDataType, {(uint32_t)channel}, mBiasData.data()); | ||||
|             this->createStageTensor("Stage", tempDataType, getNHWCShape(inputs[0])); | ||||
|             this->createStageTensor("Stage_dequantize_input", tempDataType, getNHWCShape(inputs[0])); | ||||
|             this->createStageTensor("Stage_add_output", tempDataType, getNHWCShape(outputs[0])); | ||||
|             if(mBackend->getUseFP16()){ | ||||
|                 this->createStageTensor("Stage_cast_output", QNN_DATATYPE_FLOAT_32, getNHWCShape(outputs[0])); | ||||
|             } | ||||
|         }else{ | ||||
|             this->createStaticFloatTensor("weight", dataType, {(uint32_t)channel}, mWeightData.data()); | ||||
|             this->createStaticFloatTensor("bias", dataType, {(uint32_t)channel}, mBiasData.data()); | ||||
|             this->createStageTensor("Stage", dataType, getNHWCShape(inputs[0])); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     // add nodes
 | ||||
|  | @ -42,34 +57,98 @@ ErrorCode QNNScale::onEncode(const std::vector<Tensor *> &inputs, const std::vec | |||
| } | ||||
| 
 | ||||
| void QNNScale::mulWeight(Tensor * input) { | ||||
|     mNodeType = "ElementWiseMultiply"; | ||||
|     std::string name = mNodeName + "_mul"; | ||||
|     mParams.clear(); | ||||
|     mInputs.clear(); | ||||
|     mOutputs.clear(); | ||||
|     Qnn_DataType_t dataType = mBackend->getNativeTensor(input)->v1.dataType; | ||||
|     // need dequantize to float16
 | ||||
|     if(mNeedQuantDequant){ | ||||
|         mNodeType.clear(); | ||||
|         mParams.clear(); | ||||
|         mInputs.clear(); | ||||
|         mOutputs.clear(); | ||||
|         mNodeType = "Dequantize"; | ||||
|         std::string name = mNodeName + "_Dequantize"; | ||||
|      | ||||
|     mInputs.push_back(*(mBackend->getNativeTensor(input))); | ||||
|     mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); | ||||
|         mInputs.push_back(*(mBackend->getNativeTensor(input))); // input
 | ||||
|         mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); //Stage_dequantize_input
 | ||||
|         mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|     } | ||||
|     { | ||||
|         mNodeType.clear(); | ||||
|         mParams.clear(); | ||||
|         mInputs.clear(); | ||||
|         mOutputs.clear(); | ||||
|         mNodeType = "ElementWiseMultiply"; | ||||
|         std::string name = mNodeName + "_mul"; | ||||
|          | ||||
|     mOutputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); | ||||
|         if(mNeedQuantDequant){ | ||||
|             mInputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); //Stage_dequantize_input
 | ||||
|         }else{ | ||||
|             mInputs.push_back(*(mBackend->getNativeTensor(input))); | ||||
|         } | ||||
|         mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); | ||||
|          | ||||
|     mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|         mOutputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); | ||||
|          | ||||
|         mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| void QNNScale::addBias(Tensor * output) { | ||||
|     mNodeType = "ElementWiseAdd"; | ||||
|     std::string name = mNodeName + "_add"; | ||||
|     mParams.clear(); | ||||
|     mInputs.clear(); | ||||
|     mOutputs.clear(); | ||||
|     Qnn_DataType_t dataType = mBackend->getNativeTensor(output)->v1.dataType; | ||||
|     { | ||||
|         mNodeType.clear(); | ||||
|         mParams.clear(); | ||||
|         mInputs.clear(); | ||||
|         mOutputs.clear(); | ||||
|         mNodeType = "ElementWiseAdd"; | ||||
|         std::string name = mNodeName + "_add"; | ||||
|          | ||||
|     mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); | ||||
|     mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); | ||||
|         mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); | ||||
|         mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); | ||||
|          | ||||
|     mOutputs.push_back(*(mBackend->getNativeTensor(output))); | ||||
|         if(mNeedQuantDequant){ | ||||
|             mOutputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // Stage_add_output
 | ||||
|         }else{ | ||||
|             mOutputs.push_back(*(mBackend->getNativeTensor(output))); | ||||
|         } | ||||
|          | ||||
|     mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|         mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|     } | ||||
|      | ||||
|     // need quantize output
 | ||||
|     if(mNeedQuantDequant){ | ||||
|         // Stage one  fp16 -> fp32
 | ||||
|         if(mBackend->getUseFP16()){ | ||||
|            mNodeType.clear(); | ||||
|            mParams.clear(); | ||||
|            mInputs.clear(); | ||||
|            mOutputs.clear(); | ||||
|            mNodeType = "Cast"; | ||||
|            std::string name = mNodeName + "_Cast"; | ||||
|         | ||||
|            mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // Stage_add_output
 | ||||
|            mOutputs.push_back(*(mTempTensorWrappers[5]->getNativeTensor())); // Stage_cast_output
 | ||||
|            mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|        } | ||||
| 
 | ||||
|        // Stage two  fp32 -> int8
 | ||||
|        { | ||||
|            mNodeType.clear(); | ||||
|            mParams.clear(); | ||||
|            mInputs.clear(); | ||||
|            mOutputs.clear(); | ||||
|            mNodeType = "Quantize"; | ||||
|            std::string name = mNodeName + "_Quantize"; | ||||
|             | ||||
|            if(mBackend->getUseFP16()){ | ||||
|                mInputs.push_back(*(mTempTensorWrappers[5]->getNativeTensor())); // Stage_cast_output
 | ||||
|            }else{ | ||||
|                mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // Stage_add_output
 | ||||
|            } | ||||
|            mOutputs.push_back(*(mBackend->getNativeTensor(output))); // output
 | ||||
|            mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs); | ||||
|        } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| ErrorCode QNNScale::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) { | ||||
|  |  | |||
|  | @ -25,6 +25,7 @@ private: | |||
| private: | ||||
|     std::vector<float> mWeightData; | ||||
|     std::vector<float> mBiasData; | ||||
|     bool mNeedQuantDequant = false; | ||||
| }; | ||||
| 
 | ||||
| } // end namespace QNN
 | ||||
|  |  | |||
|  | @ -34,12 +34,17 @@ struct RuntimeHint { | |||
|     int cpuDecreaseRate = 50; | ||||
|     int dynamicQuantOption = 0; | ||||
| 
 | ||||
|     // qkvQuantOption % 8:
 | ||||
|     // 0: Do not quantize
 | ||||
|     // 1: Only quantize key, use int8 asymmetric quantization
 | ||||
|     // 2: Only quantize value, use fp8 quantization
 | ||||
|     // 3: quantize both key and value
 | ||||
|     // 4: quantize query, key and value, and use gemm int8 kernel to compute K*V
 | ||||
|     int qkvQuantOption = 0; | ||||
| 
 | ||||
|     // qkvQuantOption / 8:
 | ||||
|     // 1: use flash attention
 | ||||
| 
 | ||||
|     int qkvQuantOption = 8; | ||||
| 
 | ||||
|     // the kvcache size limit of each layer
 | ||||
|     // if the size of kvcache in memory exceeds the limit
 | ||||
|  | @ -403,6 +408,9 @@ public: | |||
|         info.mode = Backend::Info::DIRECT; | ||||
|         return true; | ||||
|     } | ||||
|     virtual bool onGetDeviceInfo(const std::string& deviceKey, std::string& deviceValue) const { | ||||
|         return false; | ||||
|     } | ||||
| protected: | ||||
|     /**
 | ||||
|      @brief deinitializer. | ||||
|  |  | |||
|  | @ -129,6 +129,13 @@ static int dumpModelInfo(const char* modelName) { | |||
|     } else { | ||||
|         MNN_PRINT("Model Version: %s \n", info->version.c_str()); | ||||
|     } | ||||
|     if (!info->metaData.empty()) { | ||||
|         MNN_PRINT("MetaData: Begin \n"); | ||||
|         for (auto& iter : info->metaData) { | ||||
|             MNN_PRINT("[Meta] %s : %s\n", iter.first.c_str(), iter.second.c_str()); | ||||
|         } | ||||
|         MNN_PRINT("MetaData: End \n"); | ||||
|     } | ||||
|     return 0; | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -130,7 +130,17 @@ bool CompleteSubGraph(const std::unordered_map<std::string, VARP>& inputs, const | |||
|     return true; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| static bool _hasDupName(std::unique_ptr<MNN::NetT>& originNet) { | ||||
|     std::set<std::string> names; | ||||
|     for (auto& tensorName : originNet->tensorName) { | ||||
|         if (names.find(tensorName) != names.end()) { | ||||
|             MNN_ERROR("Repeat name %s\n", tensorName.c_str()); | ||||
|             return true; | ||||
|         } | ||||
|         names.insert(tensorName); | ||||
|     } | ||||
|     return false; | ||||
| } | ||||
| void RunNetPass(const std::vector<std::string>& passes, std::unique_ptr<MNN::NetT>& originNet) { | ||||
|     for (auto pass : passes) { | ||||
|         auto convert = PostConverter::get(pass); | ||||
|  | @ -138,7 +148,14 @@ void RunNetPass(const std::vector<std::string>& passes, std::unique_ptr<MNN::Net | |||
|             LOG(INFO) << "Can't find pass of " << pass << "\n"; | ||||
|             continue; | ||||
|         } | ||||
|         auto originSize = originNet->oplists.size(); | ||||
|         bool valid = convert->onExecute(originNet); | ||||
| #ifdef DEBUG | ||||
|         auto hasDup = _hasDupName(originNet); | ||||
|         if (originSize != originNet->oplists.size() || hasDup) { | ||||
|             MNN_PRINT("%s: %d -> %d, dup: %d\n", pass.c_str(), originSize, originNet->oplists.size(), hasDup); | ||||
|         } | ||||
| #endif | ||||
|         if (!valid) { | ||||
|             LOG(INFO) << "Run " << pass << "Error\n"; | ||||
|         } | ||||
|  |  | |||
|  | @ -19,7 +19,7 @@ static EXPRP clipConvert(EXPRP expr, bool supportRelu6) { | |||
|     auto extraParam = op->main_as_Extra(); | ||||
|     // auto dataType = expr->outputInfo(0)->type.code;
 | ||||
|     auto maxValue  = std::numeric_limits<T>().max(); | ||||
|     auto minValue  = std::numeric_limits<T>().min(); | ||||
|     auto minValue  = std::numeric_limits<T>().lowest(); | ||||
|     if (nullptr != extraParam->attr()) { | ||||
|         const int attrSize = extraParam->attr()->size(); | ||||
|         for (int i = 0; i < attrSize; ++i) { | ||||
|  |  | |||
|  | @ -12,20 +12,70 @@ | |||
| class RemoveCopy : public PostConverter { | ||||
| public: | ||||
|     virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override { | ||||
|         auto config = Global<modelConfig>::Get(); | ||||
|         if (config->optimizeLevel < 1 || config->inSubGraph) { | ||||
|             return true; | ||||
|         std::set<std::string> netOutputNames; | ||||
|         for (auto& t : net->outputName) { | ||||
|             netOutputNames.insert(t); | ||||
|         } | ||||
|         for (auto iter = net->oplists.begin(); iter != net->oplists.end(); iter++) { | ||||
|             auto& op          = *iter; | ||||
|             if (op->type == MNN::OpType_Input) { | ||||
|                 for (auto o : op->outputIndexes) { | ||||
|                     netOutputNames.insert(net->tensorName[o]); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         auto config = Global<modelConfig>::Get(); | ||||
|         for (auto iter = net->oplists.begin(); iter != net->oplists.end();) { | ||||
|             auto& op          = *iter; | ||||
|             if (op->type != MNN::OpType_Identity) { | ||||
|             if (op->type != MNN::OpType_Identity || op->inputIndexes.size() != op->outputIndexes.size()) { | ||||
|                 iter++; | ||||
|                 continue; | ||||
|             } | ||||
|              | ||||
|             bool hasOutputName = false; | ||||
|             for (auto o : op->outputIndexes) { | ||||
|                 if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) { | ||||
|                     hasOutputName = true; | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|             bool hasOutputFromInput = false; | ||||
|             for (auto o : op->inputIndexes) { | ||||
|                 if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) { | ||||
|                     hasOutputFromInput = true; | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|             if (hasOutputFromInput && hasOutputName) { | ||||
|                 iter++; | ||||
|                 continue; | ||||
|             } | ||||
|             auto originInput  = op->inputIndexes; | ||||
|             auto originOutputs = op->outputIndexes; | ||||
|             MNN_ASSERT(originInput.size() == originOutputs.size()); | ||||
|             if (hasOutputName) { | ||||
|                 bool valid = true; | ||||
|                 for (int i=0; i<op->inputIndexes.size(); ++i) { | ||||
|                     auto o = op->outputIndexes[i]; | ||||
|                     auto originInput = op->inputIndexes[i]; | ||||
|                     if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) { | ||||
|                         if (netOutputNames.find(net->tensorName[originInput]) != netOutputNames.end()) { | ||||
|                             valid = false; | ||||
|                             break; | ||||
|                         } | ||||
|                         auto originName = net->tensorName[originInput]; | ||||
|                         net->tensorName[originInput] = net->tensorName[o]; | ||||
|                         net->tensorName[o] = originName; | ||||
|                     } | ||||
|                 } | ||||
|                 if (!valid) { | ||||
|                     continue; | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             std::map<int, int> replaceIndexes; | ||||
|             for (int i=0; i<op->inputIndexes.size();++i) { | ||||
|                 replaceIndexes.insert(std::make_pair(op->outputIndexes[i], op->inputIndexes[i])); | ||||
|                 net->tensorName[op->inputIndexes[i]] = net->tensorName[op->outputIndexes[i]]; | ||||
|             } | ||||
|             for (auto subIter = net->oplists.begin(); subIter != net->oplists.end(); subIter++) { | ||||
|                 auto& subOp = *subIter; | ||||
|  | @ -35,14 +85,6 @@ public: | |||
|                     } | ||||
|                 } | ||||
|             } | ||||
|             for (int v=0; v<op->inputIndexes.size(); ++v) { | ||||
|                 for (auto& o : net->outputName) { | ||||
|                     if (o == net->tensorName[op->inputIndexes[v]]) { | ||||
|                         o = net->tensorName[op->outputIndexes[v]]; | ||||
|                         break; | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|             iter = net->oplists.erase(iter); | ||||
|         } | ||||
|         return true; | ||||
|  |  | |||
|  | @ -0,0 +1,149 @@ | |||
| #include "RemoveTestNoUseOps.hpp" | ||||
| bool RemoveTestNoUseOps::onExecute(std::unique_ptr<MNN::NetT>& net) const { | ||||
|     const MNN::NetT* const netPtr = net.get(); | ||||
|     std::set<std::string> netOutputNames; | ||||
|     for (auto& t : net->outputName) { | ||||
|         netOutputNames.insert(t); | ||||
|     } | ||||
|     for (auto iter = net->oplists.begin(); iter != net->oplists.end(); iter++) { | ||||
|         auto& op          = *iter; | ||||
|         if (op->type == OpType_Input) { | ||||
|             for (auto o : op->outputIndexes) { | ||||
|                 netOutputNames.insert(net->tensorName[o]); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     std::unordered_set<int> removedInputs; | ||||
|     for (auto iter = net->oplists.begin(); iter != net->oplists.end();) { | ||||
|         auto& op          = *iter; | ||||
|         bool shouldDelete = shouldDeleteJudge(op.get(), netPtr); | ||||
|         if (!shouldDelete) { | ||||
|             iter++; | ||||
|             continue; | ||||
|         } | ||||
|         bool hasOutputName = false; | ||||
|         for (auto o : op->outputIndexes) { | ||||
|             if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) { | ||||
|                 hasOutputName = true; | ||||
|                 break; | ||||
|             } | ||||
|         } | ||||
|         bool hasOutputFromInput = false; | ||||
|         for (auto o : op->inputIndexes) { | ||||
|             if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) { | ||||
|                 hasOutputFromInput = true; | ||||
|                 break; | ||||
|             } | ||||
|         } | ||||
|         if (hasOutputFromInput && hasOutputName) { | ||||
|             iter++; | ||||
|             continue; | ||||
|         } | ||||
|         bool deleteOutput = shouldDeleteOutput(op.get()); | ||||
|         // Find the next op
 | ||||
|         if (op->outputIndexes.empty() || op->inputIndexes.empty()) { | ||||
|             iter = net->oplists.erase(iter); | ||||
|             continue; | ||||
|         } | ||||
|         auto originInput  = op->inputIndexes[0]; | ||||
|         auto originOutputs = op->outputIndexes; | ||||
|         if ((!deleteOutput) && hasOutputName) { | ||||
|             bool valid = true; | ||||
|             for (auto o : originOutputs) { | ||||
|                 if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) { | ||||
|                     if (netOutputNames.find(net->tensorName[originInput]) != netOutputNames.end()) { | ||||
|                         valid = false; | ||||
|                         break; | ||||
|                     } | ||||
|                     net->tensorName[originInput] = net->tensorName[o]; | ||||
|                 } | ||||
|             } | ||||
|             if (!valid) { | ||||
|                 continue; | ||||
|             } | ||||
|         } | ||||
|         for (auto subIter = net->oplists.begin(); subIter != net->oplists.end(); subIter++) { | ||||
|             auto& subOp = *subIter; | ||||
|             if (deleteOutput) { | ||||
|                 for (auto iter=subOp->inputIndexes.begin(); iter != subOp->inputIndexes.end();) { | ||||
|                     if (std::find(originOutputs.begin(), originOutputs.end(), *iter) != originOutputs.end()) { | ||||
|                         iter = subOp->inputIndexes.erase(iter); | ||||
|                         continue; | ||||
|                     } | ||||
|                     iter++; | ||||
|                 } | ||||
|             } else { | ||||
|                 for (int v = 0; v < subOp->inputIndexes.size(); ++v) { | ||||
|                     if (std::find(originOutputs.begin(), originOutputs.end(), subOp->inputIndexes[v]) != originOutputs.end()) { | ||||
|                         subOp->inputIndexes[v] = originInput; | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         bool removeUselessInput = shouldRemoveUnusefulInputs(op.get()); | ||||
|         if (removeUselessInput) { | ||||
|             for (int input : op->inputIndexes) { | ||||
|                 removedInputs.emplace(input); | ||||
|             } | ||||
|         } | ||||
|         iter = net->oplists.erase(iter); | ||||
|     } | ||||
| 
 | ||||
|     // Remove the op only if the reference counts of it's all outputs
 | ||||
|     // are reduced to be zero.
 | ||||
|     std::unordered_map<int, int/*reference count*/> uselessIndex; | ||||
|     for (const auto& op : net->oplists) { | ||||
|         for (int input : op->inputIndexes) { | ||||
|             auto it = uselessIndex.find(input); | ||||
|             if (it == uselessIndex.end()) { | ||||
|                 uselessIndex.emplace(input, 1); | ||||
|             } else { | ||||
|                 ++it->second; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|     // Set reference count 1 for all net outputs.
 | ||||
|     for (const auto& op : net->oplists) { | ||||
|         for (int output : op->outputIndexes) { | ||||
|             auto it = uselessIndex.find(output); | ||||
|             if (it == uselessIndex.end()) { | ||||
|                 if (removedInputs.count(output)) { | ||||
|                     uselessIndex.emplace(output, 0); | ||||
|                 } else { | ||||
|                     uselessIndex.emplace(output, 1); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     bool needIteration = false; | ||||
|     do { | ||||
|         needIteration = false; | ||||
|         for (auto iter = net->oplists.begin(); iter != net->oplists.end();) { | ||||
|             auto& op     = *iter; | ||||
|             bool useless = true; | ||||
|             for (auto index : op->outputIndexes) { | ||||
|                 if (uselessIndex.at(index) > 0) { | ||||
|                     useless = false; | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|             if (!useless) { | ||||
|                 iter++; | ||||
|                 continue; | ||||
|             } | ||||
|             if (!op->inputIndexes.empty()) { | ||||
|                 for (auto index : op->inputIndexes) { | ||||
|                     auto it = uselessIndex.find(index); | ||||
|                     MNN_ASSERT(it != uselessIndex.end()); | ||||
|                     --it->second; | ||||
|                 } | ||||
|                 needIteration = true; | ||||
|             } | ||||
|             iter = net->oplists.erase(iter); | ||||
|         } | ||||
|     } while (needIteration); | ||||
| 
 | ||||
|     return true; | ||||
| } | ||||
|  | @ -25,135 +25,5 @@ public: | |||
| 
 | ||||
|     virtual bool shouldDeleteOutput(const MNN::OpT* op) const = 0; | ||||
| 
 | ||||
|     virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override { | ||||
|         const MNN::NetT* const netPtr = net.get(); | ||||
|         std::set<std::string> netOutputNames; | ||||
|         for (auto& t : net->outputName) { | ||||
|             netOutputNames.insert(t); | ||||
|         } | ||||
|         std::unordered_set<int> removedInputs; | ||||
|         for (auto iter = net->oplists.begin(); iter != net->oplists.end();) { | ||||
|             auto& op          = *iter; | ||||
|             bool shouldDelete = shouldDeleteJudge(op.get(), netPtr); | ||||
|             if (!shouldDelete) { | ||||
|                 iter++; | ||||
|                 continue; | ||||
|             } | ||||
|             bool hasOutputName = false; | ||||
|             for (auto o : op->outputIndexes) { | ||||
|                 if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) { | ||||
|                     hasOutputName = true; | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|             bool hasOutputFromInput = false; | ||||
|             for (auto o : op->inputIndexes) { | ||||
|                 if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) { | ||||
|                     hasOutputFromInput = true; | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|             if (hasOutputFromInput && hasOutputName) { | ||||
|                 iter++; | ||||
|                 continue; | ||||
|             } | ||||
|             bool deleteOutput = shouldDeleteOutput(op.get()); | ||||
|             // Find the next op
 | ||||
|             if (op->outputIndexes.empty() || op->inputIndexes.empty()) { | ||||
|                 iter = net->oplists.erase(iter); | ||||
|                 continue; | ||||
|             } | ||||
|             auto originInput  = op->inputIndexes[0]; | ||||
|             auto originOutputs = op->outputIndexes; | ||||
|             if ((!deleteOutput) && hasOutputName) { | ||||
|                 for (auto o : originOutputs) { | ||||
|                     if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) { | ||||
|                         net->tensorName[originInput] = net->tensorName[o]; | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|             for (auto subIter = net->oplists.begin(); subIter != net->oplists.end(); subIter++) { | ||||
|                 auto& subOp = *subIter; | ||||
|                 if (deleteOutput) { | ||||
|                     for (auto iter=subOp->inputIndexes.begin(); iter != subOp->inputIndexes.end();) { | ||||
|                         if (std::find(originOutputs.begin(), originOutputs.end(), *iter) != originOutputs.end()) { | ||||
|                             iter = subOp->inputIndexes.erase(iter); | ||||
|                             continue; | ||||
|                         } | ||||
|                         iter++; | ||||
|                     } | ||||
|                 } else { | ||||
|                     for (int v = 0; v < subOp->inputIndexes.size(); ++v) { | ||||
|                         if (std::find(originOutputs.begin(), originOutputs.end(), subOp->inputIndexes[v]) != originOutputs.end()) { | ||||
|                             subOp->inputIndexes[v] = originInput; | ||||
|                         } | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|             bool removeUselessInput = shouldRemoveUnusefulInputs(op.get()); | ||||
|             if (removeUselessInput) { | ||||
|                 for (int input : op->inputIndexes) { | ||||
|                     removedInputs.emplace(input); | ||||
|                 } | ||||
|             } | ||||
|             iter = net->oplists.erase(iter); | ||||
|         } | ||||
| 
 | ||||
|         // Remove the op only if the reference counts of it's all outputs
 | ||||
|         // are reduced to be zero.
 | ||||
|         std::unordered_map<int, int/*reference count*/> uselessIndex; | ||||
|         for (const auto& op : net->oplists) { | ||||
|             for (int input : op->inputIndexes) { | ||||
|                 auto it = uselessIndex.find(input); | ||||
|                 if (it == uselessIndex.end()) { | ||||
|                     uselessIndex.emplace(input, 1); | ||||
|                 } else { | ||||
|                     ++it->second; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         // Set reference count 1 for all net outputs.
 | ||||
|         for (const auto& op : net->oplists) { | ||||
|             for (int output : op->outputIndexes) { | ||||
|                 auto it = uselessIndex.find(output); | ||||
|                 if (it == uselessIndex.end()) { | ||||
|                     if (removedInputs.count(output)) { | ||||
|                         uselessIndex.emplace(output, 0); | ||||
|                     } else { | ||||
|                         uselessIndex.emplace(output, 1); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         bool needIteration = false; | ||||
|         do { | ||||
|             needIteration = false; | ||||
|             for (auto iter = net->oplists.begin(); iter != net->oplists.end();) { | ||||
|                 auto& op     = *iter; | ||||
|                 bool useless = true; | ||||
|                 for (auto index : op->outputIndexes) { | ||||
|                     if (uselessIndex.at(index) > 0) { | ||||
|                         useless = false; | ||||
|                         break; | ||||
|                     } | ||||
|                 } | ||||
|                 if (!useless) { | ||||
|                     iter++; | ||||
|                     continue; | ||||
|                 } | ||||
|                 if (!op->inputIndexes.empty()) { | ||||
|                     for (auto index : op->inputIndexes) { | ||||
|                         auto it = uselessIndex.find(index); | ||||
|                         MNN_ASSERT(it != uselessIndex.end()); | ||||
|                         --it->second; | ||||
|                     } | ||||
|                     needIteration = true; | ||||
|                 } | ||||
|                 iter = net->oplists.erase(iter); | ||||
|             } | ||||
|         } while (needIteration); | ||||
| 
 | ||||
|         return true; | ||||
|     } | ||||
|     virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override; | ||||
| }; | ||||
|  |  | |||
|  | @ -33,6 +33,11 @@ public: | |||
|                 return true; | ||||
|             } | ||||
|         } | ||||
|         if (op->type == OpType_Identity) { | ||||
|             // Support 1->N
 | ||||
|             return op->inputIndexes.size() == 1 && op->outputIndexes.size() > 1; | ||||
|         } | ||||
| 
 | ||||
|         if (op->type == OpType_Cast) { | ||||
|             if (op->main.AsCastParam()->dstT == op->main.AsCastParam()->srcT) { | ||||
|                 return true; | ||||
|  |  | |||
|  | @ -37,11 +37,21 @@ int main(int argc, const char* argv[]) { | |||
|     /**
 | ||||
|     generate qnn .cpp and .bin | ||||
|     */ | ||||
| 
 | ||||
|     std::string dstModelName = dstMNN; | ||||
|     size_t pos = dstModelName.find_last_of("/\\"); | ||||
|     std::string dstModelPath; | ||||
|     if (pos == std::string::npos) { | ||||
|         // current path
 | ||||
|         dstModelPath = "./"; | ||||
|     } else { | ||||
|         dstModelPath = dstModelName.substr(0, pos); | ||||
|     } | ||||
|     std::string qnnModelPath = dstModelPath + "/" + qnnModelName; | ||||
|     MNN_PRINT("[Temp Product]: Qnn temp product generate at %s\n", qnnModelPath.c_str()); | ||||
|     MNN::ScheduleConfig config; | ||||
|     config.type = MNN_FORWARD_NN; | ||||
|     std::shared_ptr<Executor::RuntimeManager> rtmgr(Executor::RuntimeManager::createRuntimeManager(config)); | ||||
|     rtmgr->setCache(qnnModelName.c_str()); | ||||
|     rtmgr->setCache(qnnModelPath.c_str()); | ||||
|     MNN::Express::Module::Config mConfig; | ||||
|     mConfig.shapeMutable = false; | ||||
|     std::shared_ptr<MNN::Express::Module> m(MNN::Express::Module::load(inputNames, outputNames, srcMNN, rtmgr, &mConfig), MNN::Express::Module::destroy); | ||||
|  | @ -64,7 +74,7 @@ int main(int argc, const char* argv[]) { | |||
|     } | ||||
| 
 | ||||
|     int ret = 0; | ||||
|     std::string tarBinCmd = "cd " + qnnModelName + \ | ||||
|     std::string tarBinCmd = "cd " + qnnModelPath + \ | ||||
|         " && " + \ | ||||
|         "tar -cf " + qnnModelName + ".bin *.raw"; | ||||
|     ret = system(tarBinCmd.c_str()); | ||||
|  | @ -74,10 +84,10 @@ int main(int argc, const char* argv[]) { | |||
|     } | ||||
| 
 | ||||
|     std::string modelLibCmd = qnnSdkPath + "/bin/x86_64-linux-clang/qnn-model-lib-generator " + \ | ||||
|         "-c " + qnnModelName + "/" + qnnModelName + ".cpp " + \ | ||||
|         "-b " + qnnModelName + "/" + qnnModelName + ".bin " + \ | ||||
|         "-c " + qnnModelPath + "/" + qnnModelName + ".cpp " + \ | ||||
|         "-b " + qnnModelPath + "/" + qnnModelName + ".bin " + \ | ||||
|         "-t x86_64-linux-clang " + \ | ||||
|         "-o " + qnnModelName + "/lib "; | ||||
|         "-o " + qnnModelPath + "/lib "; | ||||
|     ret = system(modelLibCmd.c_str()); | ||||
|     if(ret) { | ||||
|         MNN_ERROR("[Error]: qnn-model-lib-generator error!\n"); | ||||
|  | @ -86,12 +96,13 @@ int main(int argc, const char* argv[]) { | |||
|         MNN_PRINT("[Pass]: qnn-model-lib-generator success!\n"); | ||||
|     } | ||||
| 
 | ||||
|     std::string qnnBin = dstModelPath + "/" + qnnModelName + ".bin"; | ||||
|     std::string binaryGenCmd = qnnSdkPath + "/bin/x86_64-linux-clang/qnn-context-binary-generator " + \ | ||||
|         "--model " + qnnModelName + "/lib/x86_64-linux-clang/lib" + qnnModelName + ".so " + \ | ||||
|         "--model " + qnnModelPath + "/lib/x86_64-linux-clang/lib" + qnnModelName + ".so " + \ | ||||
|         "--backend " + qnnSdkPath + "/lib/x86_64-linux-clang/libQnnHtp.so " + \ | ||||
|         "--binary_file " + qnnModelName + " " + \ | ||||
|         "--config_file " + qnnContextConfig + " " + \ | ||||
|         "--output_dir " + qnnModelName + "/binary"; | ||||
|         "--output_dir " + dstModelPath; | ||||
|     ret = system(binaryGenCmd.c_str()); | ||||
|     if(ret) { | ||||
|         MNN_ERROR("[Error]: qnn-context-binary-generator error!\n"); | ||||
|  | @ -123,6 +134,7 @@ int main(int argc, const char* argv[]) { | |||
|     } | ||||
| 
 | ||||
|     std::string npuPath = std::string("/") + qnnModelName + std::string(".bin"); | ||||
|   | ||||
|     MNN_PRINT("npu model path:%s\n", npuPath.c_str()); | ||||
|     /** Fuse to Op*/ | ||||
|     std::unique_ptr<MNN::OpT> op(new OpT); | ||||
|  | @ -204,7 +216,7 @@ int main(int argc, const char* argv[]) { | |||
|     } | ||||
|     for (int i=0; i<outputInfos.size(); ++i) { | ||||
|         attr.reset(new MNN::AttributeT); | ||||
|         attr->key = "o_" + std::to_string(i) + "_0"; | ||||
|         attr->key = "o_0_" + std::to_string(i); | ||||
|         attr->tensor.reset(new BlobT); | ||||
|         attr->tensor->dataType = OpCommonUtils::convertDataType(outputInfos[i].type); | ||||
|         attr->tensor->dims = outputInfos[i].dim; | ||||
|  | @ -240,7 +252,6 @@ int main(int argc, const char* argv[]) { | |||
|     outputOs.close(); | ||||
| 
 | ||||
|     MNN_PRINT("[All Pass]: npu model generator success!\n"); | ||||
|     std::string qnnBin = qnnModelName + "/binary/" + qnnModelName + ".bin"; | ||||
|     MNN_PRINT("[Output Product]:\nNew mnn model path: %s\nNpu model path: %s\n", dstMNN, qnnBin.c_str()); | ||||
|     return 0; | ||||
| } | ||||
|  |  | |||
|  | @ -43,6 +43,6 @@ for w in gWrong: | |||
|     print(w) | ||||
| print('TEST_NAME_MODULE: 模型测试\nTEST_CASE_AMOUNT_MODULE: {\"blocked\":0,\"failed\":%d,\"passed\":%d,\"skipped\":0}\n'%(len(gWrong), total_num - len(gWrong))) | ||||
| print('TEST_CASE={\"name\":\"Onnx转换测试\",\"failed\":%d,\"passed\":%d}\n'%(len(gWrong), total_num - len(gWrong))) | ||||
| print('Total Size: ', mnnsize,' MB, convert ', correct_num," model") | ||||
| print('Total Size: ', total_size,' MB, convert ', correct_num," model") | ||||
| if len(gWrong) > 0: | ||||
|     exit(1) | ||||
|  |  | |||
|  | @ -0,0 +1,168 @@ | |||
| import json | ||||
| import copy | ||||
| import argparse | ||||
| import os | ||||
| import subprocess | ||||
| import shutil # 导入 shutil 模块用于删除目录 | ||||
| 
 | ||||
def generate_all_configs(config_path, graph_name, qnn_sdk_root_path, src_model, executable_path, output_dir):
    """For every (soc_id, dsp_arch) combination, generate the QNN config files
    and invoke the C++ converter executable on the source model.

    Args:
        config_path: Directory containing the template files
            ``htp_backend_extensions.json`` and ``context_config.json``.
        graph_name: Base graph name (without the ``_{soc_id}_{dsp_arch}`` suffix).
        qnn_sdk_root_path: QNN SDK root; substituted for ``{QNN_SDK_ROOT}`` in
            the context template's ``shared_library_path``.
        src_model: Source ``.mnn`` model path passed to the converter.
        executable_path: Path of the C++ model-conversion executable.
        output_dir: Directory that receives all generated files (created if absent).
    """
    # --- 0. Preparation: make sure the main output directory exists ---
    os.makedirs(output_dir, exist_ok=True)
    print(f"所有生成的文件将被保存在主目录: '{output_dir}'")

    # Supported (soc_id, dsp_arch) combinations to generate configs for.
    combinations = [
        (36, 'v69'),
        (42, 'v69'),
        (43, 'v73'),
        (57, 'v75'),
        (69, 'v79'),
    ]

    # --- 1. Read the two JSON template files ---
    htp_template_file = os.path.join(config_path, "htp_backend_extensions.json")
    context_template_file = os.path.join(config_path, "context_config.json")

    try:
        with open(htp_template_file, 'r', encoding='utf-8') as f:
            base_htp_data = json.load(f)
        print(f"成功读取模板文件 '{htp_template_file}'。")

        with open(context_template_file, 'r', encoding='utf-8') as f:
            base_context_data = json.load(f)
        print(f"成功读取模板文件 '{context_template_file}'。")
    except FileNotFoundError as e:
        print(f"错误:模板文件未找到。请确保 '{e.filename}' 存在于指定的路径中。")
        return
    except json.JSONDecodeError as e:
        # Bug fix: the original interpolated `e.doc` (the ENTIRE file content)
        # into the message; report the parse error and position instead.
        print(f"错误:文件格式无效。请检查模板是否为有效的JSON。解析错误: {e}")
        return

    # --- 2. Iterate over combinations: generate files, then run the converter ---
    for soc_id, dsp_arch in combinations:
        print(f"\n{'='*15} 处理组合: soc_id={soc_id}, dsp_arch={dsp_arch} {'='*15}")

        new_graph_name = f"{graph_name}_{soc_id}_{dsp_arch}"

        # --- Part A: write htp_backend_extensions_<soc>_<arch>.json ---
        htp_config_data = copy.deepcopy(base_htp_data)
        try:
            htp_config_data["graphs"][0]["graph_names"] = [new_graph_name]
            htp_config_data["devices"][0]["soc_id"] = soc_id
            htp_config_data["devices"][0]["dsp_arch"] = dsp_arch
        except (IndexError, KeyError) as e:
            print(f"处理组合时出错: '{htp_template_file}' 结构不正确。错误: {e}")
            continue

        htp_output_filepath = os.path.join(
            output_dir, f"htp_backend_extensions_{soc_id}_{dsp_arch}.json")
        with open(htp_output_filepath, 'w', encoding='utf-8') as f:
            json.dump(htp_config_data, f, indent=4, ensure_ascii=False)
        print(f"-> 已生成配置文件: '{htp_output_filepath}'")

        # --- Part B: write context_config_<soc>_<arch>.json ---
        context_config_data = copy.deepcopy(base_context_data)
        try:
            # config_file_path is the path of the htp config written above,
            # exactly as joined here (relative iff output_dir is relative).
            context_config_data["backend_extensions"]["config_file_path"] = htp_output_filepath
            path_template = context_config_data["backend_extensions"]["shared_library_path"]
            context_config_data["backend_extensions"]["shared_library_path"] = \
                path_template.replace("{QNN_SDK_ROOT}", qnn_sdk_root_path)
        except KeyError as e:
            print(f"处理组合时出错: '{context_template_file}' 结构不正确,缺少键: {e}")
            continue

        context_output_filepath = os.path.join(
            output_dir, f"context_config_{soc_id}_{dsp_arch}.json")
        with open(context_output_filepath, 'w', encoding='utf-8') as f:
            json.dump(context_config_data, f, indent=4, ensure_ascii=False)
        print(f"-> 已生成关联文件: '{context_output_filepath}'")

        # --- Part C: invoke the C++ converter executable ---
        dst_model_filepath = os.path.join(
            output_dir, f"{graph_name}_{soc_id}_{dsp_arch}.mnn")

        # Per-graph temporary product directory; removed again in `finally`.
        graph_product_dir = os.path.join(output_dir, new_graph_name)
        os.makedirs(graph_product_dir, exist_ok=True)
        print(f"-> 已创建/确认子目录: '{graph_product_dir}'")

        command = [
            executable_path,
            src_model,
            dst_model_filepath,
            qnn_sdk_root_path,
            new_graph_name,
            context_output_filepath,
        ]

        print("--> 准备执行命令...")
        print(f"    $ {' '.join(command)}")

        try:
            result = subprocess.run(command, check=True, capture_output=True, text=True)
            print("--> 命令执行成功!")
            # Even on success, surface the converter's stdout (warnings etc.).
            if result.stdout:
                print("    --- C++程序输出 (stdout) ---")
                print(result.stdout.strip())
                print("    ------------------------------")

        except FileNotFoundError:
            print(f"!!! 命令执行失败: 可执行文件未找到 '{executable_path}'。请检查路径。")
            break  # Without the executable, no further combination can succeed.
        except subprocess.CalledProcessError as e:
            print(f"!!! 命令执行失败 (返回码: {e.returncode})")
            # Show whatever the converter printed before failing.
            if e.stdout:
                print("    --- C++程序输出 (stdout) ---")
                print(e.stdout.strip())
                print("    ------------------------------")
            if e.stderr:
                print("    --- C++程序错误 (stderr) ---")
                print(e.stderr.strip())
                print("    ------------------------------")
        except Exception as e:
            print(f"!!! 执行期间发生未知错误: {e}")

        finally:
            # --- 3. Cleanup: drop the per-graph temporary directory if created ---
            if os.path.exists(graph_product_dir):
                print(f"--> 清理临时文件和目录: '{graph_product_dir}'")
                shutil.rmtree(graph_product_dir)
            else:
                print("--> 无需清理,临时目录未创建。")
| 
 | ||||
# --- Script entry point: parse CLI options and run the generator ---
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="为多个组合创建子目录,生成QNN配置文件并调用模型转换工具。",
        formatter_class=argparse.RawTextHelpFormatter,
    )

    # Options controlling config-file generation.
    gen_group = parser.add_argument_group('文件生成参数')
    gen_group.add_argument("--config_path", required=True, help="[必需] 包含模板文件的目录路径。")
    gen_group.add_argument("--graph_name", required=True, help="[必需] 模型图的名称 (不含soc_id等后缀)。")
    gen_group.add_argument("--qnn_sdk_root_path", required=True, help="[必需] QNN SDK 的根路径。")

    # Options controlling the model-conversion step.
    exec_group = parser.add_argument_group('模型转换参数')
    exec_group.add_argument("--src_model", required=True, help="[必需] 源模型文件路径 (例如: my_model.mnn)。")
    exec_group.add_argument("--executable_path", required=True, help="[必需] C++模型转换可执行文件的路径。")
    exec_group.add_argument("--output_dir", default="./qnn_models", help="存放所有生成文件的输出目录 (默认: ./qnn_models)。")

    # Argument dest names match the generator's parameter names exactly,
    # so the parsed namespace can be expanded as keyword arguments.
    cli_args = parser.parse_args()
    generate_all_configs(**vars(cli_args))
|  | @ -51,6 +51,21 @@ if (LLM_SUPPORT_AUDIO AND MNN_BUILD_AUDIO) | |||
|     add_definitions(-DLLM_SUPPORT_AUDIO) | ||||
| endif() | ||||
| 
 | ||||
| IF(CMAKE_SYSTEM_NAME MATCHES "^Android" AND NOT MNN_BUILD_FOR_ANDROID_COMMAND) | ||||
| IF(NOT NATIVE_INCLUDE_OUTPUT) | ||||
|   set(NATIVE_INCLUDE_OUTPUT ".") | ||||
| ENDIF() | ||||
| add_custom_command( | ||||
|   TARGET llm | ||||
|   POST_BUILD | ||||
|   COMMAND ${CMAKE_COMMAND} | ||||
|   ARGS -E copy_directory ${CMAKE_CURRENT_LIST_DIR}/include ${NATIVE_INCLUDE_OUTPUT} | ||||
| ) | ||||
| ELSE() | ||||
| INSTALL(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/include/ DESTINATION include FILES_MATCHING PATTERN *.hpp) | ||||
| ENDIF() | ||||
| 
 | ||||
| 
 | ||||
| add_executable(llm_demo ${CMAKE_CURRENT_LIST_DIR}/demo/llm_demo.cpp) | ||||
| add_executable(embedding_demo ${CMAKE_CURRENT_LIST_DIR}/demo/embedding_demo.cpp) | ||||
| add_executable(reranker_demo ${CMAKE_CURRENT_LIST_DIR}/demo/reranker_demo.cpp) | ||||
|  |  | |||
|  | @ -37,6 +37,23 @@ struct TimePerformance; | |||
| using ChatMessage = std::pair<std::string, std::string>; // <role, content>
 | ||||
| using ChatMessages = std::vector<ChatMessage>; | ||||
| 
 | ||||
| struct MNN_PUBLIC PromptImagePart { | ||||
|     MNN::Express::VARP image_data; | ||||
|     int width; | ||||
|     int height; | ||||
| }; | ||||
| 
 | ||||
| struct MNN_PUBLIC PromptAudioPart { | ||||
|     std::string file_path; | ||||
|     MNN::Express::VARP waveform; | ||||
| }; | ||||
| 
 | ||||
| struct MNN_PUBLIC MultimodalPrompt { | ||||
|     std::string prompt_template; | ||||
|     std::map<std::string, PromptImagePart> images; | ||||
|     std::map<std::string, PromptAudioPart> audios; | ||||
| }; | ||||
| 
 | ||||
| enum TuneType { | ||||
|     // op encoder number for commit
 | ||||
|     OP_ENCODER_NUMBER = 0, | ||||
|  | @ -108,11 +125,17 @@ public: | |||
|     std::string dump_config(); | ||||
|     bool set_config(const std::string& content); | ||||
|     Llm* create_lora(const std::string& lora_path); | ||||
|     std::string get_statistics(); | ||||
|     // tokenier function
 | ||||
|     bool is_stop(int token); | ||||
|     std::string tokenizer_decode(int token); | ||||
|     virtual std::vector<int> tokenizer_encode(const std::string& query); | ||||
|     friend class Pipeline; | ||||
|     virtual std::vector<int> tokenizer_encode(const MultimodalPrompt& multimodal_input); | ||||
|     void response(const MultimodalPrompt& multimodal_input,  | ||||
|                   std::ostream* os = &std::cout,  | ||||
|                   const char* end_with = nullptr,  | ||||
|                   int max_new_tokens = -1); | ||||
|     const LlmContext* getContext() const { | ||||
|         return mContext.get(); | ||||
|     } | ||||
|  |  | |||
|  | @ -48,6 +48,7 @@ DiskEmbedding::DiskEmbedding(const std::shared_ptr<LlmConfig>& config, std::stri | |||
|         if (mQuantBit != 16) { | ||||
|             if (mQuantBlock == 0) { | ||||
|                 mBlockNum = 1; | ||||
|                 mQuantBlock = mHiddenSize; // be used for mDequantFunc.
 | ||||
|             } else { | ||||
|                 mBlockNum = mHiddenSize / mQuantBlock; | ||||
|             } | ||||
|  |  | |||
|  | @ -9,6 +9,7 @@ | |||
| #include <fstream> | ||||
| #include <iostream> | ||||
| #include <sstream> | ||||
| #include <iomanip> | ||||
| #include <unordered_set> | ||||
| 
 | ||||
| #include <MNN/AutoTime.hpp> | ||||
|  | @ -97,11 +98,44 @@ bool Llm::set_config(const std::string& content) { | |||
|     return res; | ||||
| } | ||||
| 
 | ||||
| std::string Llm::get_statistics() { | ||||
|     auto context = getContext(); | ||||
|     int prompt_len = context->prompt_len; | ||||
|     int decode_len = context->gen_seq_len; | ||||
|     int64_t vision_time = context->vision_us; | ||||
|     int64_t audio_time = context->audio_us; | ||||
|     int64_t prefill_time = context->prefill_us; | ||||
|     int64_t decode_time = context->decode_us; | ||||
|     int64_t sample_time = context->sample_us; | ||||
|     float vision_s = vision_time / 1e6; | ||||
|     float audio_s = audio_time / 1e6; | ||||
|     float prefill_s = prefill_time / 1e6; | ||||
|     float decode_s = decode_time / 1e6; | ||||
|     float sample_s = sample_time / 1e6; | ||||
|     float prefill_speed = (prefill_s > 0.0f) ? (prompt_len / prefill_s) : 0.0f; | ||||
|     float decode_speed = (decode_s > 0.0f) ? (decode_len / decode_s) : 0.0f; | ||||
| 
 | ||||
|     std::ostringstream json_stream; | ||||
|     json_stream << "{" | ||||
|                 << "\"prompt_tokens\":" << prompt_len << "," | ||||
|                 << "\"decode_tokens\":" << decode_len << "," | ||||
|                 << "\"vision_time\":" << std::fixed << std::setprecision(2) << vision_s << "," | ||||
|                 << "\"audio_time\":" << std::fixed << std::setprecision(2) << audio_s << "," | ||||
|                 << "\"prefill_time\":" << std::fixed << std::setprecision(2) << prefill_s << "," | ||||
|                 << "\"decode_time\":" << std::fixed << std::setprecision(2) << decode_s << "," | ||||
|                 << "\"sample_time\":" << std::fixed << std::setprecision(2) << sample_s << "," | ||||
|                 << "\"prefill_speed\":" << std::fixed << std::setprecision(2) << prefill_speed << "," | ||||
|                 << "\"decode_speed\":" << std::fixed << std::setprecision(2) << decode_speed | ||||
|                 << "}"; | ||||
| 
 | ||||
|     return json_stream.str(); | ||||
| } | ||||
| 
 | ||||
| void Llm::setRuntimeHint(std::shared_ptr<Express::Executor::RuntimeManager> &rtg) { | ||||
|     rtg->setHint(MNN::Interpreter::INIT_THREAD_NUMBER, 4); | ||||
| 
 | ||||
|     rtg->setHint(MNN::Interpreter::MEM_ALLOCATOR_TYPE, 0); | ||||
|     rtg->setHint(MNN::Interpreter::QKV_QUANT_OPTIONS, mConfig->quant_qkv()); | ||||
|     rtg->setHint(MNN::Interpreter::QKV_QUANT_OPTIONS, mConfig->config_.value("quant_qkv", 8)); | ||||
|     rtg->setHint(MNN::Interpreter::KVCACHE_SIZE_LIMIT, mConfig->kvcache_limit()); | ||||
|     if (mConfig->use_cached_mmap()) { | ||||
|         rtg->setHint(MNN::Interpreter::USE_CACHED_MMAP, 1); | ||||
|  | @ -113,6 +147,8 @@ void Llm::setRuntimeHint(std::shared_ptr<Express::Executor::RuntimeManager> &rtg | |||
|     if (mConfig->use_mmap()) { | ||||
|         rtg->setExternalPath(tmpPath, MNN::Interpreter::EXTERNAL_WEIGHT_DIR); | ||||
|     } | ||||
|     // set npu model dir
 | ||||
|     rtg->setExternalPath(mConfig->npu_model_dir(), 3); | ||||
|     auto dynamicOption = mConfig->dynamic_option(); | ||||
|     if (mConfig->dynamic_option()) { | ||||
|         rtg->setHint(MNN::Interpreter::DYNAMIC_QUANT_OPTIONS, mConfig->dynamic_option()); | ||||
|  | @ -246,8 +282,7 @@ void Llm::load() { | |||
|     if (needHiddenState) { | ||||
|         outputNames.emplace_back("hidden_states"); | ||||
|     } | ||||
|     // set npu model dir
 | ||||
|     mRuntimeManager->setExternalPath(mConfig->npu_model_dir(), 3); | ||||
| 
 | ||||
|     mRuntimeManager->setExternalFile(mConfig->llm_weight()); | ||||
|     mModules[0].reset(Module::load(inputNames, outputNames, model_path.c_str(), mRuntimeManager, &module_config)); | ||||
|     mRuntimeManager->setExternalFile(""); | ||||
|  | @ -594,6 +629,29 @@ std::vector<int> Llm::tokenizer_encode(const std::string& user_content) { | |||
|     return mTokenizer->encode(user_content); | ||||
| } | ||||
| 
 | ||||
| std::vector<int> Llm::tokenizer_encode(const MultimodalPrompt& multimodal_input) { | ||||
|     return mTokenizer->encode(multimodal_input.prompt_template); | ||||
| } | ||||
| 
 | ||||
| void Llm::response(const MultimodalPrompt& multimodal_input,  | ||||
|                    std::ostream* os, const char* end_with, int max_new_tokens) { | ||||
|     auto prompt = multimodal_input.prompt_template; | ||||
|     if (mConfig->use_template()) { | ||||
|         prompt = mPrompt->applyTemplate(prompt, true); | ||||
|     } | ||||
|      | ||||
|     int prompt_len = 0; | ||||
|     int decode_len = 0; | ||||
|     int64_t vision_time = 0; | ||||
|     int64_t audio_time = 0; | ||||
|     int64_t prefill_time = 0; | ||||
|     int64_t decode_time = 0; | ||||
|     int64_t sample_time = 0; | ||||
|      | ||||
|     std::vector<int> input_ids = tokenizer_encode(multimodal_input); | ||||
|     response(input_ids, os, end_with, max_new_tokens); | ||||
| } | ||||
| 
 | ||||
| std::vector<int> Llm::generate(MNN::Express::VARP input_embeds, int max_tokens) { | ||||
|     if (max_tokens < 0) { | ||||
|         max_tokens = mConfig->max_new_tokens(); | ||||
|  | @ -621,8 +679,8 @@ std::vector<int> Llm::generate(MNN::Express::VARP input_embeds, int max_tokens) | |||
|     } | ||||
|     { | ||||
|         std::ofstream outFile("logits.txt"); | ||||
|         auto temp = outputs[0]->readMap<float>(); | ||||
|         for (size_t i = 0; i < outputs[0]->getInfo()->size; ++i) { | ||||
|         auto temp = mGenerateParam->outputs[0]->readMap<float>(); | ||||
|         for (size_t i = 0; i < mGenerateParam->outputs[0]->getInfo()->size; ++i) { | ||||
|             outFile << temp[i] << " "; // 每个数字后加空格
 | ||||
|         } | ||||
|         outFile.close(); | ||||
|  |  | |||
|  | @ -359,10 +359,6 @@ public: | |||
|         return config_.value("memory", "low"); | ||||
|     } | ||||
| 
 | ||||
|     int quant_qkv() const { | ||||
|         return config_.value("quant_qkv", 0); | ||||
|     } | ||||
| 
 | ||||
|     int kvcache_limit() const { | ||||
|         return config_.value("kvcache_limit", -1); | ||||
|     } | ||||
|  |  | |||
|  | @ -9,6 +9,7 @@ | |||
| #define _USE_MATH_DEFINES | ||||
| #endif | ||||
| #include <regex> | ||||
| #include <algorithm> | ||||
| #include <MNN/AutoTime.hpp> | ||||
| #include <MNN/expr/ExecutorScope.hpp> | ||||
| #include "omni.hpp" | ||||
|  | @ -555,8 +556,16 @@ std::vector<int> Omni::minicpmVisionProcess(VARP image) { | |||
| std::vector<int> Omni::visionProcess(const std::string& file) { | ||||
| #ifdef LLM_SUPPORT_VISION | ||||
|     VARP image = MNN::CV::imread(file); | ||||
|     return visionProcess(image); | ||||
| #else | ||||
|     return std::vector<int>(0); | ||||
| #endif | ||||
| } | ||||
| 
 | ||||
| std::vector<int> Omni::visionProcess(VARP image) { | ||||
| #ifdef LLM_SUPPORT_VISION | ||||
|     if (image == nullptr) { | ||||
|         MNN_PRINT("Omni Can't open image: %s\n", file.c_str()); | ||||
|         MNN_PRINT("Omni Can't open image\n"); | ||||
|         return std::vector<int>(0); | ||||
|     } | ||||
|     Timer _t; | ||||
|  | @ -573,7 +582,7 @@ std::vector<int> Omni::visionProcess(const std::string& file) { | |||
|     } else { | ||||
|         imgIds = defaultVisionProcess(image); | ||||
|     } | ||||
|     mContext->vision_us = _t.durationInUs(); | ||||
|     mContext->vision_us += _t.durationInUs(); | ||||
|     // set vision number for image idx
 | ||||
|     mVisionNum += 1; | ||||
|     return imgIds; | ||||
|  | @ -591,9 +600,19 @@ std::vector<int> Omni::audioProcess(const std::string& file) { | |||
|         MNN_PRINT("Omni Can't open audio: %s\n", file.c_str()); | ||||
|         return std::vector<int>(0); | ||||
|     } | ||||
|     // int sample_rate      = load_res.second;
 | ||||
|     int wav_len          = waveform->getInfo()->dim[0]; | ||||
|     int hop_length       = 160; | ||||
|     return audioProcess(waveform); | ||||
| #else | ||||
|     return std::vector<int>(0); | ||||
| #endif | ||||
| } | ||||
| 
 | ||||
| std::vector<int> Omni::audioProcess(MNN::Express::VARP waveform) { | ||||
| #ifdef LLM_SUPPORT_AUDIO | ||||
|     if (waveform == nullptr) { | ||||
|         MNN_PRINT("Omni Can't process audio: waveform is null\n"); | ||||
|         return std::vector<int>(0); | ||||
|     } | ||||
|      | ||||
|     Timer _t; | ||||
|     auto input_features  = MNN::AUDIO::whisper_fbank(waveform); | ||||
|     VARP audio_embedding; | ||||
|  | @ -727,21 +746,30 @@ void Omni::addPositionIds(int t, int h, int w) { | |||
|     } | ||||
| } | ||||
| 
 | ||||
| std::vector<int> Omni::tokenizer_encode(const std::string& prompt) { | ||||
|     // split query
 | ||||
| std::vector<int> Omni::tokenizer_encode(const MultimodalPrompt& multimodal_input) { | ||||
|     std::string prompt = multimodal_input.prompt_template; | ||||
|     // MNN_PRINT("tokenizer_encode(MultimodalPrompt) prompt: %s", prompt.c_str());
 | ||||
|     std::regex multimode_regex("<(img|audio)>(.*?)</\\1>"); | ||||
|     std::string::const_iterator searchStart(prompt.cbegin()); | ||||
|     std::smatch match; | ||||
|     std::vector<std::string> img_infos; | ||||
|     std::vector<int> ids{}; | ||||
| 
 | ||||
|     mPositionIds.clear(); | ||||
|      | ||||
|     while (std::regex_search(searchStart, prompt.cend(), match, multimode_regex)) { | ||||
|         // std::cout << "img match: " << match[1].str() << std::endl;
 | ||||
|         auto txt_ids = mTokenizer->encode(match.prefix().str()); | ||||
|         addPositionIds(txt_ids.size()); | ||||
|         ids.insert(ids.end(), txt_ids.begin(), txt_ids.end()); | ||||
|         auto mul_ids = multimodeProcess(match[1].str(), match[2].str()); | ||||
|         std::string mode = match[1].str(); | ||||
|         std::string content = match[2].str(); | ||||
|         std::vector<int> mul_ids; | ||||
|         if (mode == "img") { | ||||
|             mul_ids = processImageContent(content, multimodal_input.images); | ||||
|             // MNN_PRINT("tokenizer_encode(MultimodalPrompt) image mul_ids size: %lu", mul_ids.size());
 | ||||
|         } else if (mode == "audio") { | ||||
|             mul_ids = processAudioContent(content, multimodal_input.audios); | ||||
|             // MNN_PRINT("tokenizer_encode(MultimodalPrompt) audio mul_ids size: %lu", mul_ids.size());
 | ||||
|         } | ||||
|          | ||||
|         ids.insert(ids.end(), mul_ids.begin(), mul_ids.end()); | ||||
|         searchStart = match.suffix().first; | ||||
|     } | ||||
|  | @ -753,6 +781,43 @@ std::vector<int> Omni::tokenizer_encode(const std::string& prompt) { | |||
|     return ids; | ||||
| } | ||||
| 
 | ||||
| std::vector<int> Omni::tokenizer_encode(const std::string& prompt) { | ||||
|     MultimodalPrompt multimodal_input; | ||||
|     multimodal_input.prompt_template = prompt; | ||||
|     return tokenizer_encode(multimodal_input); | ||||
| } | ||||
| 
 | ||||
| std::vector<int> Omni::processImageContent(const std::string& content, const std::map<std::string, PromptImagePart>& images) { | ||||
|     auto it = images.find(content); | ||||
|     if (it != images.end()) { | ||||
|         if (it->second.height > 0 && it->second.width > 0) { | ||||
|             mVisionHeight = it->second.height; | ||||
|             mVisionWidth = it->second.width; | ||||
|         } | ||||
|         // MNN_PRINT("processImageContent: using placeholder '%s' with size %dx%d", content.c_str(), mVisionWidth, mVisionHeight);
 | ||||
|         return visionProcess(it->second.image_data); | ||||
|     } | ||||
|     // MNN_PRINT("processImageContent: treating '%s' as file path or URL", content.c_str());
 | ||||
|     return multimodeProcess("img", content); | ||||
| } | ||||
| 
 | ||||
| std::vector<int> Omni::processAudioContent(const std::string& content, const std::map<std::string, PromptAudioPart>& audios) { | ||||
|     auto it = audios.find(content); | ||||
|     if (it != audios.end()) { | ||||
|         // MNN_PRINT("processAudioContent: using placeholder '%s'", content.c_str());
 | ||||
|         if (it->second.waveform.get() != nullptr) { | ||||
|             return audioProcess(it->second.waveform); | ||||
|         } else if (!it->second.file_path.empty()) { | ||||
|             return audioProcess(it->second.file_path); | ||||
|         } else { | ||||
|             MNN_PRINT("processAudioContent: audio_part has no valid input\n"); | ||||
|             return std::vector<int>(0); | ||||
|         } | ||||
|     } | ||||
|     // MNN_PRINT("processAudioContent: treating '%s' as file path", content.c_str());
 | ||||
|     return multimodeProcess("audio", content); | ||||
| } | ||||
| 
 | ||||
| VARP Omni::embedding(const std::vector<int>& input_ids) { | ||||
|     if (input_ids.size() == 1) { | ||||
|         return Llm::embedding(input_ids); | ||||
|  | @ -845,21 +910,28 @@ VARP Omni::gen_position_ids(int seq_len) { | |||
|     auto ptr = positionIds->writeMap<int>(); | ||||
|     if (mContext->gen_seq_len > 0) { | ||||
|         for (int i=0; i<seq_len; ++i) { | ||||
|             auto pos = mContext->gen_seq_len + mPositionIds.back() + i; | ||||
|             // auto pos = mContext->gen_seq_len + mPositionIds.back() + i;
 | ||||
|             auto pos = mContext->all_seq_len + i; | ||||
|             ptr[i + 0] = pos; | ||||
|             ptr[i + seq_len] = pos; | ||||
|             ptr[i + seq_len * 2] = pos; | ||||
|         } | ||||
|     } else { | ||||
|         for (int i = 0; i < seq_len; i++) { | ||||
|             ptr[i] = mPositionIds.mT[i]; | ||||
|             ptr[i + seq_len] = mPositionIds.mH[i]; | ||||
|             ptr[i + seq_len * 2] = mPositionIds.mW[i]; | ||||
|             ptr[i] = mPositionIds.mT[i] + mContext->all_seq_len; | ||||
|             ptr[i + seq_len] = mPositionIds.mH[i] + mContext->all_seq_len; | ||||
|             ptr[i + seq_len * 2] = mPositionIds.mW[i] + mContext->all_seq_len; | ||||
|         } | ||||
|         if (mTalker) { | ||||
|             mTalker->setPostionIds(mPositionIds); | ||||
|         } | ||||
|     } | ||||
|     // // dump position ids
 | ||||
|     // printf("position_ids = [");
 | ||||
|     // for (int i = 0; i < seq_len; i++) {
 | ||||
|     //     printf("%d ", ptr[i]);
 | ||||
|     // }
 | ||||
|     // printf("]\n");
 | ||||
|     return positionIds; | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -105,12 +105,14 @@ public: | |||
|     virtual void load() override; | ||||
|     virtual std::vector<Express::VARP> forwardRaw(Express::VARP hiddenState, Express::VARP mask, Express::VARP inputPos) override; | ||||
|     virtual std::vector<int> tokenizer_encode(const std::string& query) override; | ||||
|     virtual std::vector<int> tokenizer_encode(const MultimodalPrompt& multimodal_input) override; | ||||
|     virtual Express::VARP embedding(const std::vector<int>& input_ids) override; | ||||
|     virtual Express::VARP gen_position_ids(int seq_len) override; | ||||
|     virtual void response(const std::vector<int>& input_ids, std::ostream* os = &std::cout, const char* end_with = nullptr, int max_new_tokens = -1) override; | ||||
|     virtual void setWavformCallback(std::function<bool(const float*, size_t, bool)> callback) override; | ||||
|     virtual void generateWavform() override; | ||||
|     // some models preprocess function
 | ||||
|     std::vector<int> visionProcess(VARP image); | ||||
|     std::vector<int> defaultVisionProcess(VARP image); | ||||
|     std::vector<int> qwen2VisionProcess(VARP image); | ||||
|     std::vector<int> smolvlmVisionProcess(VARP image); | ||||
|  | @ -126,6 +128,9 @@ private: | |||
|     std::vector<int> multimodeProcess(const std::string& mode, std::string info); | ||||
|     std::vector<int> visionProcess(const std::string& file); | ||||
|     std::vector<int> audioProcess(const std::string& file); | ||||
|     std::vector<int> audioProcess(MNN::Express::VARP waveform); | ||||
|     std::vector<int> processImageContent(const std::string& content, const std::map<std::string, PromptImagePart>& images); | ||||
|     std::vector<int> processAudioContent(const std::string& content, const std::map<std::string, PromptAudioPart>& audios); | ||||
|     std::shared_ptr<Module> mVisionModule, mAudioModule; | ||||
|     std::vector<VARP> mVisionEmbeddings, mAudioEmbeddings; | ||||
|     std::shared_ptr<Talker> mTalker; | ||||
|  |  | |||
|  | @ -225,6 +225,7 @@ class LlmExporter(torch.nn.Module): | |||
|             self.llm_config['jinja'] = prompt_template['jinja'] | ||||
|         # load modules | ||||
|         ModelMapper.do_map(self, self.model, self.model_map['model']) | ||||
| 
 | ||||
|         # rebuild modules | ||||
|         if self.lm_ is None: | ||||
|             out_features, in_features = self.embed_.weight.shape | ||||
|  | @ -244,6 +245,7 @@ class LlmExporter(torch.nn.Module): | |||
|             self.embed = Embedding(embed_copy, self) | ||||
|         else: | ||||
|             self.embed = Embedding(self.embed_, self) | ||||
| 
 | ||||
|         # tie word embeddings | ||||
|         self.tie_word_embeddings = not self.args.seperate_embed and self.lm_.weight.equal(self.embed_.weight) | ||||
|         if self.tie_word_embeddings: | ||||
|  | @ -802,14 +804,21 @@ class LlmExporter(torch.nn.Module): | |||
|         if self.mnn_converter: | ||||
|             fuse_transformer = self.visual.transformer_fuse | ||||
|             native_group_conv = self.visual.group_conv_native | ||||
|             quant_bit_visual = self.visual.quant_bit | ||||
|             quant_block_visual = self.visual.quant_block | ||||
|             if self.args.transformer_fuse: | ||||
|                 fuse_transformer = True | ||||
|             if self.args.group_conv_native: | ||||
|                 native_group_conv = True | ||||
|             self.mnn_converter.export(vision_onnx, self.visual.quant_bit, | ||||
|                                       self.visual.quant_block, | ||||
|             if self.args.visual_quant_bit is not None: | ||||
|                 quant_bit_visual = self.args.visual_quant_bit | ||||
|             if self.args.visual_quant_block is not None: | ||||
|                 quant_block_visual = self.args.visual_quant_block | ||||
|             self.mnn_converter.export(vision_onnx, quant_bit_visual, | ||||
|                                       quant_block_visual, | ||||
|                                       transformer_fuse=fuse_transformer, | ||||
|                                       group_conv_native=native_group_conv) | ||||
|                                       group_conv_native=native_group_conv, | ||||
|                                       weight_sym=self.args.visual_sym) | ||||
| 
 | ||||
|     def export_audio(self): | ||||
|         if self.audio is None: | ||||
|  | @ -1237,6 +1246,9 @@ def export(path, | |||
|         'onnx_slim': onnx_slim, | ||||
|         'quant_bit': quant_bit, | ||||
|         'quant_block': quant_block, | ||||
|         'visual_quant_bit': visual_quant_bit, | ||||
|         'visual_quant_block': visual_quant_block, | ||||
|         'visual_sym': visual_sym, | ||||
|         'lm_quant_bit': lm_quant_bit, | ||||
|         'mnnconvert': mnnconvert, | ||||
|         'ppl': ppl, | ||||
|  | @ -1276,6 +1288,8 @@ def main(): | |||
|     parser.add_argument('--onnx_slim', action='store_true', help='Whether or not to use onnx-slim.') | ||||
|     parser.add_argument('--quant_bit', type=int, default=4, help='mnn quant bit, 4 or 8, default is 4.') | ||||
|     parser.add_argument('--quant_block', type=int, default=64, help='mnn quant block, 0 mean channle-wise, default is 64.') | ||||
|     parser.add_argument('--visual_quant_bit', type=int, default=None, help='mnn viusal quant bit, 4 or 8, default is setting in utils/vision.py by different vit model.') | ||||
|     parser.add_argument('--visual_quant_block', type=int, default=None, help='mnn quant block, default is setting in utils/vision.py by different vit model.') | ||||
|     parser.add_argument('--lm_quant_bit', type=int, default=None, help='mnn lm_head quant bit, 4 or 8, default is `quant_bit`.') | ||||
|     parser.add_argument('--mnnconvert', type=str, default='../../../build/MNNConvert', help='local mnnconvert path, if invalid, using pymnn.') | ||||
|     parser.add_argument('--ppl', action='store_true', help='Whether or not to get all logits of input tokens.') | ||||
|  | @ -1284,9 +1298,10 @@ def main(): | |||
|     parser.add_argument('--transformer_fuse', action='store_true', help='Whether or not to fuse vision transformer op.') | ||||
|     parser.add_argument('--group_conv_native', action='store_true', help='Whether or not to keep native group_conv.') | ||||
|     parser.add_argument('--smooth', action='store_true', help='Whether or not to use smooth quant.') | ||||
|     parser.add_argument('--sym', action='store_true', help='Whether or not to using symmetric quant (without zeropoint), defualt is False.') | ||||
|     parser.add_argument('--seperate_embed', action='store_true', help='For lm and embed shared model, whether or not to sepearte embed to avoid quant, defualt is False, if True, embed weight will be seperate to embeddingbf16.bin.') | ||||
|     parser.add_argument('--lora_split', action='store_true', help='Whether or not export lora split, defualt is False.') | ||||
|     parser.add_argument('--sym', action='store_true', help='Whether or not to using symmetric quant (without zeropoint), default is False.') | ||||
|     parser.add_argument('--visual_sym', action='store_true', help='Whether or not to using symmetric quant (without zeropoint) for visual model, default is False.') | ||||
|     parser.add_argument('--seperate_embed', action='store_true', help='For lm and embed shared model, whether or not to sepearte embed to avoid quant, default is False, if True, embed weight will be seperate to embeddingbf16.bin.') | ||||
|     parser.add_argument('--lora_split', action='store_true', help='Whether or not export lora split, default is False.') | ||||
|     parser.add_argument('--calib_data', type=str, default=None, help='calibration data path, defaut is `None` mean not use calib data.') | ||||
|     args = parser.parse_args() | ||||
| 
 | ||||
|  |  | |||
|  | @ -51,7 +51,7 @@ class MNNConveter: | |||
|             os.close(log_fd) | ||||
| 
 | ||||
|     @spinner_run(f'convert onnx model to ') | ||||
|     def onnx2mnn(self, onnx_path, mnn_path, args = [], transformer_fuse = True, group_conv_native = False, save_external_data = True): | ||||
|     def onnx2mnn(self, onnx_path, mnn_path, args = [], transformer_fuse = True, group_conv_native = False, weight_sym = False, save_external_data = True): | ||||
|         convert_args = [ | ||||
|             '', | ||||
|             '-f', | ||||
|  | @ -66,6 +66,8 @@ class MNNConveter: | |||
|             convert_args += ['--transformerFuse'] | ||||
|         if group_conv_native: | ||||
|             convert_args += ['--groupConvNative'] | ||||
|         if weight_sym: | ||||
|             convert_args += ['--weightQuantAsymmetric=0'] | ||||
|         if save_external_data: | ||||
|             convert_args += ['--saveExternalData'] | ||||
|         convert_args += args | ||||
|  | @ -112,7 +114,7 @@ class MNNConveter: | |||
|         self.convert(convert_args) | ||||
|         return mnn_path | ||||
| 
 | ||||
|     def export(self, onnx_path, quant_bit = None, quant_block = None, transformer_fuse = True, group_conv_native = False): | ||||
|     def export(self, onnx_path, quant_bit = None, quant_block = None, transformer_fuse = True, group_conv_native = False, weight_sym = None): | ||||
|         self.onnx_model_path = onnx_path | ||||
|         self.mnn_name = os.path.basename(onnx_path).replace('.onnx', '.mnn') | ||||
|         self.mnn_model_path = os.path.join(self.config.args.dst_path, self.mnn_name) | ||||
|  | @ -133,10 +135,10 @@ class MNNConveter: | |||
|                 ] | ||||
|             if quant_bit == 32: | ||||
|                 quant_args = [] | ||||
|             self.onnx2mnn(self.onnx_model_path, self.mnn_model_path, quant_args, transformer_fuse=transformer_fuse, group_conv_native=group_conv_native) | ||||
|             self.onnx2mnn(self.onnx_model_path, self.mnn_model_path, quant_args, transformer_fuse=transformer_fuse, group_conv_native=group_conv_native, weight_sym=weight_sym) | ||||
|         else: | ||||
|             mnn_json = f'{self.mnn_model_path}.json' | ||||
|             self.onnx2mnn(self.onnx_model_path, self.mnn_model_path, transformer_fuse=transformer_fuse, group_conv_native=group_conv_native) | ||||
|             self.onnx2mnn(self.onnx_model_path, self.mnn_model_path, transformer_fuse=transformer_fuse, group_conv_native=group_conv_native, weight_sym=weight_sym) | ||||
|             self.mnn2json(self.mnn_model_path, mnn_json) | ||||
|             self.rebuild(mnn_json) | ||||
|             self.json2mnn(mnn_json, self.mnn_model_path) | ||||
|  |  | |||
|  | @ -266,7 +266,6 @@ class Rotary(torch.nn.Module): | |||
| 
 | ||||
|         def get_theta(): | ||||
|             return 1.0 / (self.rope_theta ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float32) / self.rotary_dim)) | ||||
| 
 | ||||
|         # default rope type's theta | ||||
|         self.theta = get_theta() | ||||
|         # other type | ||||
|  | @ -278,7 +277,6 @@ class Rotary(torch.nn.Module): | |||
|                 rope_type = config.rope_scaling['type'] | ||||
|             elif 'rope_type' in config.rope_scaling: | ||||
|                 rope_type = config.rope_scaling['rope_type'] | ||||
| 
 | ||||
|             # gen theta for rope_type | ||||
|             if rope_type == 'dynamic': # NTK | ||||
|                 if 'alpha' in config.rope_scaling: # NTKAlpha in Hunyuan | ||||
|  | @ -286,7 +284,7 @@ class Rotary(torch.nn.Module): | |||
|                 else: # NTKScaling | ||||
|                     pass | ||||
|                 self.theta = get_theta() | ||||
|             elif rope_type == 'yarn': # YaRN in gpt-oss | ||||
|             elif rope_type == 'yarn': | ||||
|                 self.is_scaled = True | ||||
|                 self.theta, self.attention_scaling = _compute_yarn_parameters( | ||||
|                     rotary_dim=self.rotary_dim, | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue