MNN:Sync: Sync Internal 3.2.3

This commit is contained in:
xiaying 2025-08-22 18:04:08 +08:00
parent 8f175e2748
commit 318a3de860
82 changed files with 5003 additions and 2240 deletions

.gitignore vendored
View File

@ -376,4 +376,6 @@ datasets/*
# qnn 3rdParty
source/backend/qnn/3rdParty/include
apps/Android/MnnLlmChat/release_outputs
project/android/.cxx
pymnn/android/.cxx/
pymnn/android/.cxx/abi_configuration_5u53tc49.json

View File

@ -11,8 +11,6 @@
## News 🔥
- [2025/08/08] Now we support [gpt-oss-20b](./apps/Android/MnnLlmChat/README.md#releases).
- [2025/08/05] MNN Chat Android is available in [GooglePlay](https://play.google.com/store/apps/details?id=com.alibaba.mnnllm.android.release) !
- [2025/06/11] New App MNN TaoAvatar released, you can talk with 3DAvatar offline with LLM, ASR, TTS, A2BS and NNR models all run local on your device!! [MNN TaoAvatar](./apps/Android/Mnn3dAvatar/README.md)
<p align="center">
<img width="20%" alt="Icon" src="https://meta.alicdn.com/data/mnn/avatar/avatar_demo.gif" style="margin: 0 10px;">

View File

@ -183,6 +183,21 @@ struct Info {
const Info* getInfo() const;
```
### Getting device information
Call the `getDeviceInfo` function to query `Device` information, for example:
```cpp
std::string soc_id, dsp_arch;
bool success = MNN::Express::Executor::RuntimeManager::getDeviceInfo("dsp_arch", MNN_FORWARD_NN, dsp_arch);
if(success) {
MNN_PRINT("Device dsp_arch: %s\n", dsp_arch.c_str());
}
success = MNN::Express::Executor::RuntimeManager::getDeviceInfo("soc_id", MNN_FORWARD_NN, soc_id);
if(success) {
MNN_PRINT("Device soc_id: %s\n", soc_id.c_str());
}
```
### Running inference
Call `onForward` to run inference, as sketched below.
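A minimal sketch of a single forward pass; it assumes `module` has already been created via `Module::load` and `input` is a `VARP` prepared to match the model's input:
```cpp
// Minimal sketch: run one forward pass on an existing module.
// Assumes `module` (std::shared_ptr<MNN::Express::Module>) and `input` (VARP) are set up elsewhere.
std::vector<MNN::Express::VARP> inputs = {input};
auto outputs = module->onForward(inputs);
if (!outputs.empty()) {
    auto ptr = outputs[0]->readMap<float>(); // map the first output for reading
    MNN_PRINT("first output value: %f\n", ptr[0]);
}
```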

View File

@ -58,6 +58,10 @@ adb push ${MNN_ROOT}/source/backend/qnn/3rdParty/lib/hexagon-v${HEXAGON_ARCH}/un
adb shell "cd /data/local/tmp && LD_LIBRARY_PATH=/data/local/tmp ADSP_LIBRARY_PATH=/data/local/tmp ./MyExe.out"
```
### QNN quantization support
- Weight-only quantization (activations stay in floating point): only int8, channel-wise symmetric quantization of Linear weights is supported.
- Quantizing both activations and weights: activations use per-tensor symmetric quantization, and weights use int8/int4, channel-wise symmetric quantization.
## CoreML
Applicable to Mac / iOS / iPad

View File

@ -388,3 +388,12 @@ npu model path:./qnn_smolvlm_model.bin
./ModuleBasic.out qnn_smolvlm_model.mnn dir 0 0 10
```
### Script for generating QNN models for multiple devices
tools/script/genQNNModelsFromMNN.py provides a script that generates QNN models for 8Gen1 ~ 8Elite devices:
```
# Usage example
cd mnn_path
cd build
python3 ../tools/script/genQNNModelsFromMNN.py --config_path ../source/backend/qnn/convertor/config_example/ --graph_name visual_qnn --qnn_sdk_root_path /mnt/2Tpartition/tianbu/QNN/qairt/2.37.0.250724/ --src_model visual.mnn --executable_path ./MNN2QNNModel
```
The QNN model artifacts for 8Gen1 ~ 8Elite devices will then be generated under the qnn_models folder.

View File

@ -106,6 +106,10 @@ optional arguments:
mnn quant bit, 4 or 8, default is 4.
--quant_block QUANT_BLOCK
mnn quant block, 0 means channel-wise, default is 128.
--visual_quant_bit VISUAL_QUANT_BIT
mnn visual model quant bit, 4 or 8, default is set in utils/vision.py per ViT model.
--visual_quant_block VISUAL_QUANT_BLOCK
mnn visual model quant block, 0 means channel-wise, default is set in utils/vision.py per ViT model.
--lm_quant_bit LM_QUANT_BIT
mnn lm_head quant bit, 4 or 8, default is `quant_bit`.
--mnnconvert MNNCONVERT
@ -113,10 +117,12 @@ optional arguments:
--ppl Whether or not to get all logits of input tokens.
--awq Whether or not to use awq quant.
--sym Whether or not to use symmetric quant (without zeropoint), default is False.
--visual_sym Whether or not to use symmetric quant (without zeropoint) for the visual model, default is False.
--seperate_embed For lm and embed shared models, whether or not to separate embed to avoid quantization; default is False. If True, the embed weight will be separated into embeddingbf16.bin.
--lora_split Whether or not to export lora split, default is False.
```
### Reading weights
llmexport.py also supports LLM validation, which pulls in quite a few dependencies. When such an environment is unavailable, MNN-LLM additionally provides tools that read weights directly from safetensors or gguf files, lowering memory requirements and speeding up conversion. Usage is as follows:
@ -166,6 +172,7 @@ python3 gguf2mnn.py --gguf ~/third/llama.cpp/build/ggml-model-Q4_K.gguf --mnn_di
```
-DLLM_SUPPORT_VISION=true -DMNN_BUILD_OPENCV=true -DMNN_IMGCODECS=true
```
- To enable audio support, add the related compile macros
```
-DLLM_SUPPORT_AUDIO=true -DMNN_BUILD_AUDIO=true
@ -195,6 +202,12 @@ cd project/android
mkdir build_64
../build_64.sh -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DMNN_ARM82=true -DMNN_OPENCL=true -DMNN_USE_LOGCAT=true
```
On Qualcomm devices, some vision models can run on the NPU; add the `MNN_QNN` and `MNN_WITH_PLUGIN` macros to enable QNN support.
```
cd project/android
mkdir build_64
../build_64.sh -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DMNN_ARM82=true -DMNN_OPENCL=true -DMNN_QNN=true -DMNN_WITH_PLUGIN=true -DMNN_USE_LOGCAT=true
```
#### iOS: see transformers/llm/engine/ios/README.md
```

View File

@ -281,6 +281,17 @@ bool Executor::RuntimeManager::getInfo(Interpreter::SessionInfoCode code, void*
return false;
}
bool Executor::RuntimeManager::getDeviceInfo(const std::string& deviceKey, const MNNForwardType type, std::string& deviceValue) {
auto creator = MNNGetExtraRuntimeCreator(type);
if (creator != nullptr) {
auto res = creator->onGetDeviceInfo(deviceKey, deviceValue);
if(res) {
return true;
}
}
return false;
}
Executor::RuntimeManager::RuntimeManager() {
mInside = new RuntimeAttr;
mInside->mContent.reset(new RuntimeAttr::Immutable);

View File

@ -130,6 +130,7 @@ public:
void setHint(Interpreter::HintMode mode, int* value, size_t size);
void setHintPtr(Interpreter::HintMode mode, void* value);
bool getInfo(Interpreter::SessionInfoCode code, void* ptr);
static bool getDeviceInfo(const std::string& deviceKey, const MNNForwardType type, std::string& deviceValue);
BackendConfig* getBnConfig();
const RuntimeAttr* getInside() const {
return mInside;

View File

@ -25,3 +25,9 @@ set_target_properties( MNNOpenCV
PROPERTIES IMPORTED_LOCATION
${CMAKE_CURRENT_LIST_DIR}/libs/${ANDROID_ABI}/libMNNOpenCV.so
)
add_library(MNN_LLM SHARED IMPORTED GLOBAL )
set_target_properties(MNN_LLM
PROPERTIES IMPORTED_LOCATION
${CMAKE_CURRENT_LIST_DIR}/libs/${ANDROID_ABI}/libllm.so
)

View File

@ -3,4 +3,4 @@ distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
distributionUrl=http\://mtl-gradle-mirror.oss-cn-hangzhou.aliyuncs.com/gradle-4.6-all.zip
distributionUrl=http://mtl-gradle-mirror.oss-cn-hangzhou.aliyuncs.com/gradle-6.7.1-all.zip

View File

@ -20,23 +20,48 @@ afterEvaluate {
android.libraryVariants.all {variant ->
def name = variant.buildType.name
def zipTask = tasks.create(name: "zipNative${name.capitalize()}", type: Zip) {
from project.zipTree("${variant.packageLibrary.destinationDir.path}/${project.name}-${name}.aar")
from project.zipTree("${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}.aar")
from "$buildDir/native-packed"
archiveName "${project.name}-${name}-tmp.aar"
archiveName "${artifactName}-${name}-tmp.aar"
destinationDir(variant.packageLibrary.destinationDir)
inputs.dir("$buildDir/native-packed")
inputs.file("${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}.aar")
outputs.file("${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}.aar")
}
zipTask.dependsOn("externalNativeBuild${name.capitalize()}")
// Ensure QNN dependencies are prepared when BUILD_QNN is enabled
if (project.ext.has('BUILD_QNN') && project.ext.BUILD_QNN) {
zipTask.dependsOn('prepareQnnDeps')
}
zipTask.doLast {
copy {
from "${variant.packageLibrary.destinationDir.path}/${project.name}-${name}-tmp.aar"
from "${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}-tmp.aar"
into "${variant.packageLibrary.destinationDir.path}"
rename{
String fileName->
fileName.replace('-tmp','')
}
}
delete "${variant.packageLibrary.destinationDir.path}/${project.name}-${name}-tmp.aar"
delete "${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}-tmp.aar"
println "Generated final AAR: ${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}.aar"
}
tasks.findByName("assemble${name.capitalize()}").doLast {
zipTask.execute()
def bundleTask = tasks.findByName("bundle${name.capitalize()}Aar")
if (bundleTask != null) {
zipTask.dependsOn(bundleTask)
}
def assembleTask = tasks.findByName("assemble${name.capitalize()}")
if (assembleTask != null) {
assembleTask.dependsOn(zipTask)
zipTask.mustRunAfter(bundleTask)
}
def packageTask = tasks.findByName("package${name.capitalize()}Aar")
if (packageTask != null) {
packageTask.finalizedBy(zipTask)
}
}

View File

@ -0,0 +1,174 @@
// QNN Dependencies Preparation
// This gradle file handles downloading and preparing QNN dependencies
// QNN configuration
ext {
// QNN download settings
QNN_LIBS_URL = 'http://meta.alicdn.com/data/mnn/libs/qnn_inc_libs.zip'
}
def qnnZipName = 'qnn_inc_libs.zip'
def qnnZipFile = new File(buildDir, qnnZipName)
def qnnTmpDir = new File(buildDir, 'qnn_tmp')
task prepareQnnDeps {
group = 'build setup'
description = 'Download and extract QNN include/lib into source/backend/qnn/3rdParty when BUILD_QNN is enabled.'
onlyIf {
project.ext.has('BUILD_QNN') && project.ext.BUILD_QNN
}
// Define inputs and outputs for incremental build
inputs.property("qnn_libs_url", { QNN_LIBS_URL })
inputs.property("build_qnn", { project.ext.has('BUILD_QNN') && project.ext.BUILD_QNN })
outputs.dir(new File(project.rootDir, '../../source/backend/qnn/3rdParty'))
outputs.dir(new File(buildDir, 'native-packed/native/libs'))
doLast {
println "BUILD_QNN is ON. Preparing QNN dependencies..."
def url = QNN_LIBS_URL
println "Downloading QNN dependencies from ${url}"
if (!qnnZipFile.exists()) {
println "Downloading ${qnnZipName} ..."
qnnZipFile.parentFile.mkdirs()
try {
new java.net.URL(url).withInputStream { inputStream ->
qnnZipFile.withOutputStream { outputStream ->
outputStream << inputStream
}
}
println "Downloaded to: ${qnnZipFile.absolutePath}"
} catch (Exception e) {
throw new RuntimeException("Failed to download QNN dependencies from ${url}. Error: ${e.message}")
}
} else {
println "Using cached zip: ${qnnZipFile.absolutePath}"
}
// Clean temp dir and unpack
project.delete(qnnTmpDir)
qnnTmpDir.mkdirs()
copy {
from zipTree(qnnZipFile)
into qnnTmpDir
}
// Find the extracted QNN directory
def extractedQnnDir = findQnnDirectory(qnnTmpDir)
if (extractedQnnDir == null) {
throw new RuntimeException("Failed to find QNN directory structure in ${qnnZipFile.name}")
}
copyQnnFiles(extractedQnnDir)
}
}
def findQnnDirectory(File searchDir) {
// Look for directory containing both include and jniLibs (or lib) directories
def candidates = []
searchDir.eachDirRecurse { dir ->
def hasInclude = new File(dir, 'include').exists()
def hasLibs = new File(dir, 'jniLibs').exists() || new File(dir, 'lib').exists()
if (hasInclude && hasLibs) {
candidates.add(dir)
}
}
if (candidates.isEmpty()) {
// Fallback: look for directory that contains include
searchDir.eachDirRecurse { dir ->
if (new File(dir, 'include').exists()) {
candidates.add(dir)
}
}
}
return candidates.isEmpty() ? null : candidates[0]
}
def copyQnnFiles(File sourceDir) {
// Resolve destination directories
def qnnRoot = new File(project.rootDir, '../../source/backend/qnn/3rdParty')
def includeDir = new File(qnnRoot, 'include')
def libDir = new File(qnnRoot, 'lib')
includeDir.mkdirs()
libDir.mkdirs()
// Copy include files
def sourceInclude = new File(sourceDir, 'include')
if (sourceInclude.exists()) {
copy {
from sourceInclude
into includeDir
}
println "QNN includes copied to: ${includeDir.absolutePath}"
} else {
throw new RuntimeException("Include directory not found in ${sourceDir.absolutePath}")
}
// Copy library files - try both jniLibs and lib directories
def sourceLibs = new File(sourceDir, 'jniLibs')
if (!sourceLibs.exists()) {
sourceLibs = new File(sourceDir, 'lib')
}
if (sourceLibs.exists()) {
copy {
from sourceLibs
into libDir
}
println "QNN libs copied to: ${libDir.absolutePath}"
// Also copy QNN .so files to native-packed for AAR packaging
copyQnnLibsToNativePacked(sourceLibs)
} else {
println "Warning: No lib/jniLibs directory found in ${sourceDir.absolutePath}"
}
println "QNN dependencies preparation completed successfully."
}
def copyQnnLibsToNativePacked(File sourceLibsDir) {
// Create native-packed directory structure for QNN .so files
def nativePackedLibsDir = new File(buildDir, 'native-packed/native/libs')
nativePackedLibsDir.mkdirs()
println "Copying QNN .so files to native-packed for AAR packaging..."
// Copy all .so files from QNN libs to native-packed
sourceLibsDir.eachFileRecurse { file ->
if (file.name.endsWith('.so')) {
// Determine the ABI directory (arm64-v8a, armeabi-v7a, etc.)
def relativePath = sourceLibsDir.toPath().relativize(file.toPath())
def targetFile = new File(nativePackedLibsDir, relativePath.toString())
// Create parent directories if they don't exist
targetFile.parentFile.mkdirs()
// Copy the .so file
copy {
from file
into targetFile.parentFile
}
println "Copied QNN lib: ${file.name} -> ${targetFile.absolutePath}"
}
}
println "QNN .so files copied to native-packed directory"
}
// Ensure preparation runs before compilation when enabled
if (project.ext.has('BUILD_QNN') && project.ext.BUILD_QNN) {
afterEvaluate {
if (tasks.findByName('preBuild')) {
preBuild.dependsOn prepareQnnDeps
}
}
}

View File

@ -480,7 +480,6 @@
92FF02E023AA0B5A00AC97F6 /* MNNMatrixSub.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016023AA0B4E00AC97F6 /* MNNMatrixSub.S */; };
92FF02E123AA0B5A00AC97F6 /* MNNPowC8.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016123AA0B4E00AC97F6 /* MNNPowC8.S */; };
92FF02E223AA0B5A00AC97F6 /* MNNMatrixAdd.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016223AA0B4E00AC97F6 /* MNNMatrixAdd.S */; };
92FF02E323AA0B5A00AC97F6 /* MNNExpC8.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016323AA0B4E00AC97F6 /* MNNExpC8.S */; };
92FF02E523AA0B5A00AC97F6 /* MNNConvDwF23SourceTransUnit.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016523AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */; };
92FF02E723AA0B5A00AC97F6 /* MNNDeconvRunForUnitDepthWise.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016723AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */; };
92FF02E823AA0B5A00AC97F6 /* MNNSamplerC1BilinearOpt.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016823AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */; };
@ -520,7 +519,6 @@
92FF032023AA0B5A00AC97F6 /* MNNMatrixSub.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A123AA0B4E00AC97F6 /* MNNMatrixSub.S */; };
92FF032123AA0B5A00AC97F6 /* MNNPowC8.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A223AA0B4E00AC97F6 /* MNNPowC8.S */; };
92FF032223AA0B5A00AC97F6 /* MNNMatrixAdd.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A323AA0B4E00AC97F6 /* MNNMatrixAdd.S */; };
92FF032323AA0B5A00AC97F6 /* MNNExpC8.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A423AA0B4E00AC97F6 /* MNNExpC8.S */; };
92FF032523AA0B5A00AC97F6 /* MNNConvDwF23SourceTransUnit.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A623AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */; };
92FF032723AA0B5A00AC97F6 /* MNNDeconvRunForUnitDepthWise.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A823AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */; };
92FF032823AA0B5A00AC97F6 /* MNNSamplerC1BilinearOpt.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A923AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */; };
@ -1315,7 +1313,6 @@
92FF016023AA0B4E00AC97F6 /* MNNMatrixSub.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNMatrixSub.S; sourceTree = "<group>"; };
92FF016123AA0B4E00AC97F6 /* MNNPowC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPowC8.S; sourceTree = "<group>"; };
92FF016223AA0B4E00AC97F6 /* MNNMatrixAdd.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNMatrixAdd.S; sourceTree = "<group>"; };
92FF016323AA0B4E00AC97F6 /* MNNExpC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNExpC8.S; sourceTree = "<group>"; };
92FF016523AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNConvDwF23SourceTransUnit.S; sourceTree = "<group>"; };
92FF016723AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNDeconvRunForUnitDepthWise.S; sourceTree = "<group>"; };
92FF016823AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNSamplerC1BilinearOpt.S; sourceTree = "<group>"; };
@ -1355,7 +1352,6 @@
92FF01A123AA0B4E00AC97F6 /* MNNMatrixSub.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNMatrixSub.S; sourceTree = "<group>"; };
92FF01A223AA0B4E00AC97F6 /* MNNPowC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPowC8.S; sourceTree = "<group>"; };
92FF01A323AA0B4E00AC97F6 /* MNNMatrixAdd.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNMatrixAdd.S; sourceTree = "<group>"; };
92FF01A423AA0B4E00AC97F6 /* MNNExpC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNExpC8.S; sourceTree = "<group>"; };
92FF01A623AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNConvDwF23SourceTransUnit.S; sourceTree = "<group>"; };
92FF01A823AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNDeconvRunForUnitDepthWise.S; sourceTree = "<group>"; };
92FF01A923AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNSamplerC1BilinearOpt.S; sourceTree = "<group>"; };
@ -2591,7 +2587,6 @@
92FF016023AA0B4E00AC97F6 /* MNNMatrixSub.S */,
92FF016123AA0B4E00AC97F6 /* MNNPowC8.S */,
92FF016223AA0B4E00AC97F6 /* MNNMatrixAdd.S */,
92FF016323AA0B4E00AC97F6 /* MNNExpC8.S */,
92FF016523AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */,
92FF016723AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */,
92FF016823AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */,
@ -2677,7 +2672,6 @@
92FF01A123AA0B4E00AC97F6 /* MNNMatrixSub.S */,
92FF01A223AA0B4E00AC97F6 /* MNNPowC8.S */,
92FF01A323AA0B4E00AC97F6 /* MNNMatrixAdd.S */,
92FF01A423AA0B4E00AC97F6 /* MNNExpC8.S */,
92FF01A623AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */,
92FF01A823AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */,
92FF01A923AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */,
@ -3279,7 +3273,6 @@
92FF02C223AA0B5A00AC97F6 /* MNNLoadU8AndSum.S in Sources */,
4819FB2E24C1396A0050BD09 /* GeometryLSTM.cpp in Sources */,
9558334B29B09A7B00488807 /* MNNGeluFP16.S in Sources */,
92FF02E323AA0B5A00AC97F6 /* MNNExpC8.S in Sources */,
92FF030223AA0B5A00AC97F6 /* MNNQuanToDestUint8.S in Sources */,
489D7AA12550FDC900AD896A /* MetalUnary.mm in Sources */,
92FF037323AA0B5A00AC97F6 /* CPUEltwiseInt8.cpp in Sources */,
@ -3438,7 +3431,6 @@
92FF030723AA0B5A00AC97F6 /* MNNBlitC1ToFloatRGBA.S in Sources */,
92FF03A423AA0B5A00AC97F6 /* OptimizedComputer.cpp in Sources */,
92FF032E23AA0B5A00AC97F6 /* MNNReluWithSlopeChannel.S in Sources */,
92FF032323AA0B5A00AC97F6 /* MNNExpC8.S in Sources */,
95278CEA2B9F09C0009E9B29 /* ShapeDynamicQuant.cpp in Sources */,
92FF044C23AA0B7100AC97F6 /* ShapePool3D.cpp in Sources */,
92FF029823AA0B5A00AC97F6 /* CPUTFQuantizedConv2D.cpp in Sources */,

View File

@ -17,6 +17,7 @@ option(PYMNN_INTERNAL_SERVING "Internal use only." OFF)
option(PYMNN_OPENCV_API "MNN OpenCV API be exposed" ON)
option(PYMNN_IMGCODECS "MNN IMGCODECS API be exposed" OFF)
option(PYMNN_AUDIO_API "MNN Audio API be exposed" OFF)
option(PYMNN_LLM_API "MNN LLM API be exposed" OFF)
option(PYMNN_OHOS_INTERNAL "compile for harmony internal." OFF)
option(PYMNN_LLM_COLLECTION "offline llm data collection." OFF)
@ -97,6 +98,10 @@ if(PYMNN_AUDIO_API)
target_compile_definitions(mnnpybridge PRIVATE PYMNN_AUDIO_API)
endif()
if(PYMNN_LLM_API)
target_compile_definitions(mnnpybridge PRIVATE PYMNN_LLM_API)
endif()
if(PYMNN_INTERNAL_SERVING)
message(STATUS "mnnpybridge define PYMNN_INTERNAL_SERVING")
target_compile_definitions(mnnpybridge PRIVATE PYMNN_INTERNAL_SERVING)
@ -214,6 +219,9 @@ else()
if(PYMNN_NUMPY_USABLE)
target_link_libraries(mnnpybridge PRIVATE numpy_python)
endif()
if(PYMNN_LLM_API)
target_link_libraries(mnnpybridge PRIVATE MNN_LLM)
endif()
endif()
endif()

View File

@ -1,6 +1,10 @@
#include <sstream>
#include <iostream>
#ifdef BUILD_FOR_IOS
#include "MNN/llm/llm.hpp"
#else
#include "llm/llm.hpp"
#endif
#ifdef PYMNN_LLM_COLLECTION
#include "cpp/getLinearInput.hpp"
#endif
@ -85,18 +89,87 @@ static PyObject* PyMNNLLM_getCurrentHistory(LLM *self, PyObject *args) {
}
static PyObject* PyMNNLLM_response(LLM *self, PyObject *args) {
if (self->is_embedding) {
MNN_PRINT("[MNNLLM] response: is_embedding\n");
Py_RETURN_NONE;
}
const char* query = NULL;
int stream = 0;
if (!PyArg_ParseTuple(args, "s|p", &query, &stream)) {
Py_RETURN_NONE;
}
std::ostringstream null_os;
self->llm->response(query, stream ? &std::cout : &null_os);
return string2Object(null_os.str());
}
PyObject* content = nullptr;
int stream = 0;
int max_new_tokens = 2048;
if (!PyArg_ParseTuple(args, "O|ii", &content, &stream, &max_new_tokens)) {
MNN_PRINT("[MNNLLM] response: PyArg_ParseTuple failed\n");
Py_RETURN_NONE;
}
std::ostringstream null_os;
std::ostream* output_stream = stream ? &std::cout : &null_os;
if (isString(content)) {
std::string text = object2String(content);
MNN_PRINT("[MNNLLM] response: text=%s, stream=%d, max_new_tokens=%d\n", text.c_str(), stream, max_new_tokens);
self->llm->response(text, output_stream, nullptr, max_new_tokens);
} else if (isPyDict(content)) {
MNN::Transformer::MultimodalPrompt multimodal_input;
PyObject* text_obj = PyDict_GetItemString(content, "text");
if (text_obj && isString(text_obj)) {
multimodal_input.prompt_template = object2String(text_obj);
}
PyObject* images_obj = PyDict_GetItemString(content, "images");
if (images_obj && PyList_Check(images_obj)) {
Py_ssize_t img_count = PyList_Size(images_obj);
for (Py_ssize_t i = 0; i < img_count; i++) {
PyObject* img_dict = PyList_GetItem(images_obj, i);
if (isPyDict(img_dict)) {
PyObject* data_obj = PyDict_GetItemString(img_dict, "data");
PyObject* width_obj = PyDict_GetItemString(img_dict, "width");
PyObject* height_obj = PyDict_GetItemString(img_dict, "height");
if (data_obj && width_obj && height_obj) {
MNN::Transformer::PromptImagePart image_part;
image_part.image_data = toVar(data_obj);
image_part.width = PyLong_AsLong(width_obj);
image_part.height = PyLong_AsLong(height_obj);
std::string key = "image_" + std::to_string(i);
multimodal_input.images[key] = image_part;
}
}
}
}
PyObject* audios_obj = PyDict_GetItemString(content, "audios");
if (audios_obj && PyList_Check(audios_obj)) {
Py_ssize_t audio_count = PyList_Size(audios_obj);
for (Py_ssize_t i = 0; i < audio_count; i++) {
PyObject* audio_dict = PyList_GetItem(audios_obj, i);
if (isPyDict(audio_dict)) {
MNN::Transformer::PromptAudioPart audio_part;
PyObject* file_path_obj = PyDict_GetItemString(audio_dict, "file_path");
if (file_path_obj && isString(file_path_obj)) {
audio_part.file_path = object2String(file_path_obj);
}
PyObject* waveform_obj = PyDict_GetItemString(audio_dict, "waveform");
if (waveform_obj) {
audio_part.waveform = toVar(waveform_obj);
}
std::string key = "audio_" + std::to_string(i);
multimodal_input.audios[key] = audio_part;
}
}
}
MNN_PRINT("[MNNLLM] response: multimodal, stream=%d, max_new_tokens=%d\n", stream, max_new_tokens);
self->llm->response(multimodal_input, output_stream, nullptr, max_new_tokens);
} else {
PyMNN_ERROR("content must be str or dict");
}
std::string response_str = null_os.str();
MNN_PRINT("[MNNLLM] response: %s\n", response_str.c_str());
return string2Object(response_str);
}
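For context, a rough C++ sketch of the `MNN::Transformer::MultimodalPrompt` structure this binding populates, not the authoritative API: the `Llm::createLLM`/`load` setup and the `<img>` placeholder convention in the template are assumptions, while the field names mirror the binding code above.
```cpp
// Illustrative sketch (not part of the binding): build a multimodal prompt in C++.
// Assumes `llm` was created via MNN::Transformer::Llm::createLLM(config) and load() succeeded,
// and `imageVar` is a VARP holding the decoded RGB pixel data.
MNN::Transformer::MultimodalPrompt prompt;
prompt.prompt_template = "<img>image_0</img>Describe this picture."; // placeholder convention is an assumption
MNN::Transformer::PromptImagePart image;
image.image_data = imageVar;
image.width  = 448;
image.height = 448;
prompt.images["image_0"] = image; // key format "image_<i>" matches the binding above
// Stream the answer to stdout, capped at 2048 new tokens (the binding's default).
llm->response(prompt, &std::cout, nullptr, 2048);
```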
static PyObject* PyMNNLLM_tokenizer_encode(LLM *self, PyObject *args) {
if (self->is_embedding) {
Py_RETURN_NONE;
@ -149,6 +222,14 @@ static PyObject* PyMNNLLM_reset(LLM *self, PyObject *args) {
Py_RETURN_NONE;
}
static PyObject* PyMNNLLM_get_statistics(LLM *self, PyObject *args) {
if (self->is_embedding) {
Py_RETURN_NONE;
}
auto statistics = self->llm->get_statistics();
return string2Object(statistics);
}
#ifdef PYMNN_LLM_COLLECTION
static PyObject* PyMNNLLM_enable_collection_mode(LLM *self, PyObject *args) {
if (self->is_embedding) {
@ -205,12 +286,11 @@ static PyObject* PyMNNLLM_enable_collection_mode(LLM *self, PyObject *args) {
return toPyObj(true);
}
#endif
static PyMethodDef PyMNNLLM_methods[] = {
{"load", (PyCFunction)PyMNNLLM_load, METH_VARARGS, "load model."},
{"forward", (PyCFunction)PyMNNLLM_forward, METH_VARARGS, "forward `logits` by `input_ids`."},
{"generate", (PyCFunction)PyMNNLLM_generate, METH_VARARGS, "generate `output_ids` by `input_ids`."},
{"response", (PyCFunction)PyMNNLLM_response, METH_VARARGS, "response `query` without hsitory."},
{"response", (PyCFunction)PyMNNLLM_response, METH_VARARGS, "response `query` - supports both text and multimodal input."},
{"get_current_history", (PyCFunction)PyMNNLLM_getCurrentHistory, METH_VARARGS, "Get Current History."},
{"erase_history", (PyCFunction)PyMNNLLM_eraseHistory, METH_VARARGS, "Erase History."},
{"tokenizer_encode", (PyCFunction)PyMNNLLM_tokenizer_encode, METH_VARARGS, "tokenizer encode."},
@ -219,6 +299,7 @@ static PyMethodDef PyMNNLLM_methods[] = {
{"create_lora", (PyCFunction)PyMNNLLM_create_lora, METH_VARARGS, "create_lora."},
{"set_config", (PyCFunction)PyMNNLLM_set_config, METH_VARARGS, "set_config."},
{"reset", (PyCFunction)PyMNNLLM_reset, METH_VARARGS, "reset."},
{"get_statistics", (PyCFunction)PyMNNLLM_get_statistics, METH_VARARGS, "get performance statistics."},
#ifdef PYMNN_LLM_COLLECTION
{"enable_collection_mode", (PyCFunction)PyMNNLLM_enable_collection_mode, METH_VARARGS, "Enable data collection mode."},
#endif
@ -274,7 +355,7 @@ static PyObject* PyMNNLLM_create_lora(LLM *self, PyObject *args) {
Py_RETURN_NONE;
}
auto lora = self->llm->create_lora(path);
LLM *llm = (LLM *)PyObject_Call((PyObject*)&PyMNNLLM, PyTuple_New(0), NULL);
LLM *llm = (LLM *)PyObject_Call((PyObject*)PyType_FindTLSType(&PyMNNLLM), PyTuple_New(0), NULL);
if (!llm) {
return NULL;
}
@ -288,10 +369,11 @@ static PyObject* PyMNNLLM_create(PyObject *self, PyObject *args) {
}
const char* path = NULL;
int embedding_model = 0;
if (!PyArg_ParseTuple(args, "s|p", &path, &embedding_model)) {
if (!PyArg_ParseTuple(args, "s|i", &path, &embedding_model)) {
PyMNN_ERROR_LOG("Invalid arguments. Usage: create(path, embedding_model=False)");
return NULL;
}
LLM *llm = (LLM *)PyObject_Call((PyObject*)&PyMNNLLM, PyTuple_New(0), NULL);
LLM *llm = (LLM *)PyObject_Call((PyObject*)PyType_FindTLSType(&PyMNNLLM), PyTuple_New(0), NULL);
if (!llm) {
return NULL;
}

View File

@ -398,6 +398,9 @@ static inline bool isPySequence(PyObject* obj) {
// use isPySequence replace PySequence_Check
return PyTuple_Check(obj) || PyList_Check(obj) || PyBytes_Check(obj);
}
static inline bool isPyDict(PyObject* obj) {
return PyDict_Check(obj);
}
static inline int PySequenceSize(PyObject* obj) {
if (PyTuple_Check(obj)) return PyTuple_Size(obj);
if (PyList_Check(obj)) return PyList_Size(obj);

View File

@ -1423,6 +1423,730 @@ static void Sme2MNNPackForMatMul_A_FP16(float* destOrigin, float const** sourceG
}
}
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
static void MNNAttenPackAndScaleSingleHead(float* dst, const float* srcHeadBase, size_t srcRowStride, const float* scale, const int32_t* units, size_t seqLen, size_t headDim) {
const int32_t eP = units[0];
const int32_t lP = units[1];
if (lP != 1 && lP != 2) {
MNN_ERROR("This function only supports lP=1 or 2\n");
return;
}
const float scaleVal = scale[0];
const float16x8_t vScale = vdupq_n_f16(scaleVal);
const size_t packedHeadDim = UP_DIV(headDim, lP);
const size_t dstStrideDOuter = (size_t)eP * lP;
const size_t dstStrideSOuter = packedHeadDim * dstStrideDOuter;
for (int s = 0; s < seqLen; ++s) {
const int sOuter = s / eP;
const int sInner = s % eP;
const FLOAT16* srcRowPtr = (FLOAT16*)srcHeadBase + s * srcRowStride;
FLOAT16* dstBasePtr = (FLOAT16*)dst + sOuter * dstStrideSOuter + sInner * lP;
if (lP == 1) {
size_t d = 0;
for (; d + 7 < headDim; d += 8) {
float16x8_t sVec = vld1q_f16(srcRowPtr + d);
sVec = vmulq_f16(sVec, vScale);
dstBasePtr[(d + 0) * dstStrideDOuter] = sVec[0];
dstBasePtr[(d + 1) * dstStrideDOuter] = sVec[1];
dstBasePtr[(d + 2) * dstStrideDOuter] = sVec[2];
dstBasePtr[(d + 3) * dstStrideDOuter] = sVec[3];
dstBasePtr[(d + 4) * dstStrideDOuter] = sVec[4];
dstBasePtr[(d + 5) * dstStrideDOuter] = sVec[5];
dstBasePtr[(d + 6) * dstStrideDOuter] = sVec[6];
dstBasePtr[(d + 7) * dstStrideDOuter] = sVec[7];
}
for (; d < headDim; ++d) {
dstBasePtr[d * dstStrideDOuter] = srcRowPtr[d] * scaleVal;
}
} else { // lP == 2
const FLOAT16* srcDPtr = srcRowPtr;
FLOAT16* dstDPtr = dstBasePtr;
size_t dRealSize = headDim;
while (dRealSize >= 16) {
float16x8_t s0 = vld1q_f16(srcDPtr);
float16x8_t s1 = vld1q_f16(srcDPtr + 8);
s0 = vmulq_f16(s0, vScale);
s1 = vmulq_f16(s1, vScale);
float16x4_t lowS0_f16 = vget_low_f16(s0); // {s0, s1, s2, s3}
float16x4_t highS0_f16 = vget_high_f16(s0); // {s4, s5, s6, s7}
uint32x2_t lowS0_u32 = vreinterpret_u32_f16(lowS0_f16);
uint32x2_t highS0_u32 = vreinterpret_u32_f16(highS0_f16);
*((uint32_t*)(dstDPtr + 0 * dstStrideDOuter)) = vget_lane_u32(lowS0_u32, 0); // Store pair {s0, s1}
*((uint32_t*)(dstDPtr + 1 * dstStrideDOuter)) = vget_lane_u32(lowS0_u32, 1); // Store pair {s2, s3}
*((uint32_t*)(dstDPtr + 2 * dstStrideDOuter)) = vget_lane_u32(highS0_u32, 0); // Store pair {s4, s5}
*((uint32_t*)(dstDPtr + 3 * dstStrideDOuter)) = vget_lane_u32(highS0_u32, 1); // Store pair {s6, s7}
float16x4_t lowS1_f16 = vget_low_f16(s1); // {s8, s9, s10, s11}
float16x4_t highS1_f16 = vget_high_f16(s1); // {s12, s13, s14, s15}
uint32x2_t lowS1_u32 = vreinterpret_u32_f16(lowS1_f16);
uint32x2_t highS1_u32 = vreinterpret_u32_f16(highS1_f16);
*((uint32_t*)(dstDPtr + 4 * dstStrideDOuter)) = vget_lane_u32(lowS1_u32, 0);
*((uint32_t*)(dstDPtr + 5 * dstStrideDOuter)) = vget_lane_u32(lowS1_u32, 1);
*((uint32_t*)(dstDPtr + 6 * dstStrideDOuter)) = vget_lane_u32(highS1_u32, 0);
*((uint32_t*)(dstDPtr + 7 * dstStrideDOuter)) = vget_lane_u32(highS1_u32, 1);
dRealSize -= 16;
srcDPtr += 16;
dstDPtr += 8 * dstStrideDOuter;
}
// Remainder loop with padding
while (dRealSize > 0) {
if (dRealSize >= 2) {
dstDPtr[0] = srcDPtr[0] * scaleVal;
dstDPtr[1] = srcDPtr[1] * scaleVal;
dRealSize -= 2;
srcDPtr += 2;
dstDPtr += dstStrideDOuter;
} else { // dRealSize == 1
dstDPtr[0] = srcDPtr[0] * scaleVal;
dstDPtr[1] = (FLOAT16)0.0f; // Pad with zero
dRealSize = 0;
}
}
}
}
}
static void MNNFlashAttentionUpdateBlockOutput( float* dst, float* src, const float* scale, const float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes) {
auto dstPtr = (float16_t*)dst;
auto srcPtr = (float16_t*)src;
const auto stride0 = plane * pack;
if (idx == 0) {
memcpy(dst, src, size * bytes);
} else {
for (int j = 0; j < depthQuad; ++j) {
const auto baseOffset = j * stride0;
int i = 0;
const int plane4 = plane - (plane % 4);
for (; i < plane4; i += 4) {
auto pdst0 = dstPtr + baseOffset + (i + 0) * pack;
auto psrc0 = srcPtr + baseOffset + (i + 0) * pack;
auto pdst1 = dstPtr + baseOffset + (i + 1) * pack;
auto psrc1 = srcPtr + baseOffset + (i + 1) * pack;
auto pdst2 = dstPtr + baseOffset + (i + 2) * pack;
auto psrc2 = srcPtr + baseOffset + (i + 2) * pack;
auto pdst3 = dstPtr + baseOffset + (i + 3) * pack;
auto psrc3 = srcPtr + baseOffset + (i + 3) * pack;
float16x8_t src0 = vld1q_f16(psrc0);
float16x8_t dst0 = vld1q_f16(pdst0);
float16x8_t src1 = vld1q_f16(psrc1);
float16x8_t dst1 = vld1q_f16(pdst1);
float16x8_t src2 = vld1q_f16(psrc2);
float16x8_t dst2 = vld1q_f16(pdst2);
float16x8_t src3 = vld1q_f16(psrc3);
float16x8_t dst3 = vld1q_f16(pdst3);
float32x4_t svec0 = vdupq_n_f32(scale[i + 0]);
float32x4_t svec1 = vdupq_n_f32(scale[i + 1]);
float32x4_t svec2 = vdupq_n_f32(scale[i + 2]);
float32x4_t svec3 = vdupq_n_f32(scale[i + 3]);
float32x4_t res00 = vfmaq_f32(vcvt_f32_f16(vget_low_f16(src0)), vcvt_f32_f16(vget_low_f16(dst0)), svec0);
float32x4_t res10 = vfmaq_f32(vcvt_f32_f16(vget_high_f16(src0)), vcvt_f32_f16(vget_high_f16(dst0)), svec0);
float32x4_t res01 = vfmaq_f32(vcvt_f32_f16(vget_low_f16(src1)), vcvt_f32_f16(vget_low_f16(dst1)), svec1);
float32x4_t res11 = vfmaq_f32(vcvt_f32_f16(vget_high_f16(src1)), vcvt_f32_f16(vget_high_f16(dst1)), svec1);
float32x4_t res02 = vfmaq_f32(vcvt_f32_f16(vget_low_f16(src2)), vcvt_f32_f16(vget_low_f16(dst2)), svec2);
float32x4_t res12 = vfmaq_f32(vcvt_f32_f16(vget_high_f16(src2)), vcvt_f32_f16(vget_high_f16(dst2)), svec2);
float32x4_t res03 = vfmaq_f32(vcvt_f32_f16(vget_low_f16(src3)), vcvt_f32_f16(vget_low_f16(dst3)), svec3);
float32x4_t res13 = vfmaq_f32(vcvt_f32_f16(vget_high_f16(src3)), vcvt_f32_f16(vget_high_f16(dst3)), svec3);
vst1q_f16(pdst0, vcombine_f16(vcvt_f16_f32(res00), vcvt_f16_f32(res10)));
vst1q_f16(pdst1, vcombine_f16(vcvt_f16_f32(res01), vcvt_f16_f32(res11)));
vst1q_f16(pdst2, vcombine_f16(vcvt_f16_f32(res02), vcvt_f16_f32(res12)));
vst1q_f16(pdst3, vcombine_f16(vcvt_f16_f32(res03), vcvt_f16_f32(res13)));
}
for (; i < plane; ++i) {
auto pdst = dstPtr + baseOffset + i * pack;
auto psrc = srcPtr + baseOffset + i * pack;
float16x8_t srcF16 = vld1q_f16(psrc);
float16x8_t dstF16 = vld1q_f16(pdst);
float32x4_t svec = vdupq_n_f32(scale[i]);
float32x4_t s0 = vcvt_f32_f16(vget_low_f16(srcF16));
float32x4_t s1 = vcvt_f32_f16(vget_high_f16(srcF16));
float32x4_t d0 = vcvt_f32_f16(vget_low_f16(dstF16));
float32x4_t d1 = vcvt_f32_f16(vget_high_f16(dstF16));
float32x4_t res0 = vfmaq_f32(s0, d0, svec);
float32x4_t res1 = vfmaq_f32(s1, d1, svec);
vst1q_f16(pdst, vcombine_f16(vcvt_f16_f32(res0), vcvt_f16_f32(res1)));
}
}
}
if (idx == kvBlocks - 1) {
for (int j = 0; j < depthQuad; ++j) {
const auto baseOffset = j * stride0;
int i = 0;
const int plane4 = plane - (plane % 4);
for (; i < plane4; i += 4) {
auto pdst0 = dstPtr + baseOffset + (i + 0) * pack;
auto pdst1 = dstPtr + baseOffset + (i + 1) * pack;
auto pdst2 = dstPtr + baseOffset + (i + 2) * pack;
auto pdst3 = dstPtr + baseOffset + (i + 3) * pack;
float16x8_t dst0 = vld1q_f16(pdst0);
float16x8_t dst1 = vld1q_f16(pdst1);
float16x8_t dst2 = vld1q_f16(pdst2);
float16x8_t dst3 = vld1q_f16(pdst3);
float32x4_t ns0 = vdupq_n_f32(1.0f / normalizeScale[i + 0]);
float32x4_t ns1 = vdupq_n_f32(1.0f / normalizeScale[i + 1]);
float32x4_t ns2 = vdupq_n_f32(1.0f / normalizeScale[i + 2]);
float32x4_t ns3 = vdupq_n_f32(1.0f / normalizeScale[i + 3]);
float32x4_t d00 = vmulq_f32(vcvt_f32_f16(vget_low_f16(dst0)), ns0);
float32x4_t d10 = vmulq_f32(vcvt_f32_f16(vget_high_f16(dst0)), ns0);
float32x4_t d01 = vmulq_f32(vcvt_f32_f16(vget_low_f16(dst1)), ns1);
float32x4_t d11 = vmulq_f32(vcvt_f32_f16(vget_high_f16(dst1)), ns1);
float32x4_t d02 = vmulq_f32(vcvt_f32_f16(vget_low_f16(dst2)), ns2);
float32x4_t d12 = vmulq_f32(vcvt_f32_f16(vget_high_f16(dst2)), ns2);
float32x4_t d03 = vmulq_f32(vcvt_f32_f16(vget_low_f16(dst3)), ns3);
float32x4_t d13 = vmulq_f32(vcvt_f32_f16(vget_high_f16(dst3)), ns3);
vst1q_f16(pdst0, vcombine_f16(vcvt_f16_f32(d00), vcvt_f16_f32(d10)));
vst1q_f16(pdst1, vcombine_f16(vcvt_f16_f32(d01), vcvt_f16_f32(d11)));
vst1q_f16(pdst2, vcombine_f16(vcvt_f16_f32(d02), vcvt_f16_f32(d12)));
vst1q_f16(pdst3, vcombine_f16(vcvt_f16_f32(d03), vcvt_f16_f32(d13)));
}
for (; i < plane; ++i) {
auto pdst = dstPtr + baseOffset + i * pack;
float32x4_t nsvec = vdupq_n_f32(1.0f / normalizeScale[i]);
float16x8_t dstF16 = vld1q_f16(pdst);
float32x4_t d0 = vcvt_f32_f16(vget_low_f16(dstF16));
float32x4_t d1 = vcvt_f32_f16(vget_high_f16(dstF16));
d0 = vmulq_f32(d0, nsvec);
d1 = vmulq_f32(d1, nsvec);
vst1q_f16(pdst, vcombine_f16(vcvt_f16_f32(d0), vcvt_f16_f32(d1)));
}
}
}
}
static void MNNAttenUnpackAndConvertFp16(float* dst, float* src, size_t depth, size_t planesize, int pack) {
// src: (UP_DIV(depth, pack), planesize, pack), float16
// dst: (planesize, depth), float32
// pack=8
if (planesize == 1) {
MNNDequantizeFP16((int16_t*)src, dst, depth);
return; // no need to convert
}
const auto depthDiv8 = UP_DIV(depth, pack);
const auto srcStep = pack * planesize;
const auto dstStep = depth;
auto remainDepth = depth % pack;
auto depthQuad = depthDiv8;
if (remainDepth > 0) {
depthQuad -= 1; // last quad is not full
}
for (int i = 0; i < depthQuad; ++i) {
auto realsize = planesize;
auto srcPtr = (FLOAT16*)src + i * srcStep;
auto dstPtr = (float*)dst + i * pack;
while (realsize >= 8) {
float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack);
float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack);
float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack);
float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack);
float16x8_t s4_f16 = vld1q_f16(srcPtr + 4 * pack);
float16x8_t s5_f16 = vld1q_f16(srcPtr + 5 * pack);
float16x8_t s6_f16 = vld1q_f16(srcPtr + 6 * pack);
float16x8_t s7_f16 = vld1q_f16(srcPtr + 7 * pack);
float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16));
float32x4_t d01_f32 = vcvt_f32_f16(vget_high_f16(s0_f16));
float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16));
float32x4_t d11_f32 = vcvt_f32_f16(vget_high_f16(s1_f16));
float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16));
float32x4_t d21_f32 = vcvt_f32_f16(vget_high_f16(s2_f16));
float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16));
float32x4_t d31_f32 = vcvt_f32_f16(vget_high_f16(s3_f16));
float32x4_t d40_f32 = vcvt_f32_f16(vget_low_f16(s4_f16));
float32x4_t d41_f32 = vcvt_f32_f16(vget_high_f16(s4_f16));
float32x4_t d50_f32 = vcvt_f32_f16(vget_low_f16(s5_f16));
float32x4_t d51_f32 = vcvt_f32_f16(vget_high_f16(s5_f16));
float32x4_t d60_f32 = vcvt_f32_f16(vget_low_f16(s6_f16));
float32x4_t d61_f32 = vcvt_f32_f16(vget_high_f16(s6_f16));
float32x4_t d70_f32 = vcvt_f32_f16(vget_low_f16(s7_f16));
float32x4_t d71_f32 = vcvt_f32_f16(vget_high_f16(s7_f16));
vst1q_f32(dstPtr + 0 * dstStep, d00_f32); vst1q_f32(dstPtr + 0 * dstStep + 4, d01_f32);
vst1q_f32(dstPtr + 1 * dstStep, d10_f32); vst1q_f32(dstPtr + 1 * dstStep + 4, d11_f32);
vst1q_f32(dstPtr + 2 * dstStep, d20_f32); vst1q_f32(dstPtr + 2 * dstStep + 4, d21_f32);
vst1q_f32(dstPtr + 3 * dstStep, d30_f32); vst1q_f32(dstPtr + 3 * dstStep + 4, d31_f32);
vst1q_f32(dstPtr + 4 * dstStep, d40_f32); vst1q_f32(dstPtr + 4 * dstStep + 4, d41_f32);
vst1q_f32(dstPtr + 5 * dstStep, d50_f32); vst1q_f32(dstPtr + 5 * dstStep + 4, d51_f32);
vst1q_f32(dstPtr + 6 * dstStep, d60_f32); vst1q_f32(dstPtr + 6 * dstStep + 4, d61_f32);
vst1q_f32(dstPtr + 7 * dstStep, d70_f32); vst1q_f32(dstPtr + 7 * dstStep + 4, d71_f32);
srcPtr += 8 * pack;
dstPtr += 8 * dstStep;
realsize -= 8;
}
if (realsize >= 4) {
float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack);
float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack);
float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack);
float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack);
float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16));
float32x4_t d01_f32 = vcvt_f32_f16(vget_high_f16(s0_f16));
float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16));
float32x4_t d11_f32 = vcvt_f32_f16(vget_high_f16(s1_f16));
float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16));
float32x4_t d21_f32 = vcvt_f32_f16(vget_high_f16(s2_f16));
float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16));
float32x4_t d31_f32 = vcvt_f32_f16(vget_high_f16(s3_f16));
vst1q_f32(dstPtr + 0 * dstStep, d00_f32); vst1q_f32(dstPtr + 0 * dstStep + 4, d01_f32);
vst1q_f32(dstPtr + 1 * dstStep, d10_f32); vst1q_f32(dstPtr + 1 * dstStep + 4, d11_f32);
vst1q_f32(dstPtr + 2 * dstStep, d20_f32); vst1q_f32(dstPtr + 2 * dstStep + 4, d21_f32);
vst1q_f32(dstPtr + 3 * dstStep, d30_f32); vst1q_f32(dstPtr + 3 * dstStep + 4, d31_f32);
srcPtr += 4 * pack;
dstPtr += 4 * dstStep;
realsize -= 4;
}
while (realsize > 0) {
auto s0_fp16 = vld1q_f16(srcPtr);
auto s00_fp32 = vcvt_f32_f16(vget_low_f16(s0_fp16));
auto s01_fp32 = vcvt_f32_f16(vget_high_f16(s0_fp16));
vst1q_f32(dstPtr, s00_fp32);
vst1q_f32(dstPtr + 4, s01_fp32);
srcPtr += pack;
dstPtr += dstStep;
realsize--;
}
}
// process remain depth < 8
if (remainDepth >= 4) {
auto realsize = planesize;
auto srcPtr = (FLOAT16*)src + (depthDiv8 - 1) * srcStep;
auto dstPtr = (float*)dst + (depthDiv8 - 1) * pack;
auto extraDepth = remainDepth - 4;
float tmp0[4];
float tmp1[4];
float tmp2[4];
float tmp3[4];
float tmp4[4];
float tmp5[4];
float tmp6[4];
float tmp7[4];
while (realsize >= 8) {
float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack);
float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack);
float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack);
float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack);
float16x8_t s4_f16 = vld1q_f16(srcPtr + 4 * pack);
float16x8_t s5_f16 = vld1q_f16(srcPtr + 5 * pack);
float16x8_t s6_f16 = vld1q_f16(srcPtr + 6 * pack);
float16x8_t s7_f16 = vld1q_f16(srcPtr + 7 * pack);
float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16));
float32x4_t d01_f32 = vcvt_f32_f16(vget_high_f16(s0_f16));
float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16));
float32x4_t d11_f32 = vcvt_f32_f16(vget_high_f16(s1_f16));
float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16));
float32x4_t d21_f32 = vcvt_f32_f16(vget_high_f16(s2_f16));
float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16));
float32x4_t d31_f32 = vcvt_f32_f16(vget_high_f16(s3_f16));
float32x4_t d40_f32 = vcvt_f32_f16(vget_low_f16(s4_f16));
float32x4_t d41_f32 = vcvt_f32_f16(vget_high_f16(s4_f16));
float32x4_t d50_f32 = vcvt_f32_f16(vget_low_f16(s5_f16));
float32x4_t d51_f32 = vcvt_f32_f16(vget_high_f16(s5_f16));
float32x4_t d60_f32 = vcvt_f32_f16(vget_low_f16(s6_f16));
float32x4_t d61_f32 = vcvt_f32_f16(vget_high_f16(s6_f16));
float32x4_t d70_f32 = vcvt_f32_f16(vget_low_f16(s7_f16));
float32x4_t d71_f32 = vcvt_f32_f16(vget_high_f16(s7_f16));
vst1q_f32(dstPtr + 0 * dstStep, d00_f32); vst1q_f32(tmp0, d01_f32);
vst1q_f32(dstPtr + 1 * dstStep, d10_f32); vst1q_f32(tmp1, d11_f32);
vst1q_f32(dstPtr + 2 * dstStep, d20_f32); vst1q_f32(tmp2, d21_f32);
vst1q_f32(dstPtr + 3 * dstStep, d30_f32); vst1q_f32(tmp3, d31_f32);
vst1q_f32(dstPtr + 4 * dstStep, d40_f32); vst1q_f32(tmp4, d41_f32);
vst1q_f32(dstPtr + 5 * dstStep, d50_f32); vst1q_f32(tmp5, d51_f32);
vst1q_f32(dstPtr + 6 * dstStep, d60_f32); vst1q_f32(tmp6, d61_f32);
vst1q_f32(dstPtr + 7 * dstStep, d70_f32); vst1q_f32(tmp7, d71_f32);
memcpy(dstPtr + 0 * dstStep + 4, tmp0, sizeof(float) * extraDepth);
memcpy(dstPtr + 1 * dstStep + 4, tmp1, sizeof(float) * extraDepth);
memcpy(dstPtr + 2 * dstStep + 4, tmp2, sizeof(float) * extraDepth);
memcpy(dstPtr + 3 * dstStep + 4, tmp3, sizeof(float) * extraDepth);
memcpy(dstPtr + 4 * dstStep + 4, tmp4, sizeof(float) * extraDepth);
memcpy(dstPtr + 5 * dstStep + 4, tmp5, sizeof(float) * extraDepth);
memcpy(dstPtr + 6 * dstStep + 4, tmp6, sizeof(float) * extraDepth);
memcpy(dstPtr + 7 * dstStep + 4, tmp7, sizeof(float) * extraDepth);
srcPtr += 8 * pack;
dstPtr += 8 * dstStep;
realsize -= 8;
}
if (realsize >= 4) {
float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack);
float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack);
float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack);
float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack);
float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16));
float32x4_t d01_f32 = vcvt_f32_f16(vget_high_f16(s0_f16));
float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16));
float32x4_t d11_f32 = vcvt_f32_f16(vget_high_f16(s1_f16));
float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16));
float32x4_t d21_f32 = vcvt_f32_f16(vget_high_f16(s2_f16));
float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16));
float32x4_t d31_f32 = vcvt_f32_f16(vget_high_f16(s3_f16));
vst1q_f32(dstPtr + 0 * dstStep, d00_f32); vst1q_f32(tmp0, d01_f32);
vst1q_f32(dstPtr + 1 * dstStep, d10_f32); vst1q_f32(tmp1, d11_f32);
vst1q_f32(dstPtr + 2 * dstStep, d20_f32); vst1q_f32(tmp2, d21_f32);
vst1q_f32(dstPtr + 3 * dstStep, d30_f32); vst1q_f32(tmp3, d31_f32);
memcpy(dstPtr + 0 * dstStep + 4, tmp0, sizeof(float) * extraDepth);
memcpy(dstPtr + 1 * dstStep + 4, tmp1, sizeof(float) * extraDepth);
memcpy(dstPtr + 2 * dstStep + 4, tmp2, sizeof(float) * extraDepth);
memcpy(dstPtr + 3 * dstStep + 4, tmp3, sizeof(float) * extraDepth);
srcPtr += 4 * pack;
dstPtr += 4 * dstStep;
realsize -= 4;
}
while (realsize > 0) {
auto s0_fp16 = vld1q_f16(srcPtr);
auto d00_fp32 = vcvt_f32_f16(vget_low_f16(s0_fp16));
auto d01_fp32 = vcvt_f32_f16(vget_high_f16(s0_fp16));
vst1q_f32(dstPtr, d00_fp32);
vst1q_f32(tmp0, d01_fp32);
memcpy(dstPtr + 4, tmp0, sizeof(float) * extraDepth);
srcPtr += pack;
dstPtr += dstStep;
realsize--;
}
}
if (remainDepth > 0 && remainDepth < 4) {
auto realsize = planesize;
auto srcPtr = (FLOAT16*)src + (depthDiv8 - 1) * srcStep;
auto dstPtr = (float*)dst + (depthDiv8 - 1) * pack;
float tmp0[4];
float tmp1[4];
float tmp2[4];
float tmp3[4];
float tmp4[4];
float tmp5[4];
float tmp6[4];
float tmp7[4];
while (realsize >= 8) {
float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack);
float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack);
float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack);
float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack);
float16x8_t s4_f16 = vld1q_f16(srcPtr + 4 * pack);
float16x8_t s5_f16 = vld1q_f16(srcPtr + 5 * pack);
float16x8_t s6_f16 = vld1q_f16(srcPtr + 6 * pack);
float16x8_t s7_f16 = vld1q_f16(srcPtr + 7 * pack);
float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16));
float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16));
float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16));
float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16));
float32x4_t d40_f32 = vcvt_f32_f16(vget_low_f16(s4_f16));
float32x4_t d50_f32 = vcvt_f32_f16(vget_low_f16(s5_f16));
float32x4_t d60_f32 = vcvt_f32_f16(vget_low_f16(s6_f16));
float32x4_t d70_f32 = vcvt_f32_f16(vget_low_f16(s7_f16));
vst1q_f32(tmp0, d00_f32);
vst1q_f32(tmp1, d10_f32);
vst1q_f32(tmp2, d20_f32);
vst1q_f32(tmp3, d30_f32);
vst1q_f32(tmp4, d40_f32);
vst1q_f32(tmp5, d50_f32);
vst1q_f32(tmp6, d60_f32);
vst1q_f32(tmp7, d70_f32);
memcpy(dstPtr + 0 * dstStep, tmp0, sizeof(float) * remainDepth);
memcpy(dstPtr + 1 * dstStep, tmp1, sizeof(float) * remainDepth);
memcpy(dstPtr + 2 * dstStep, tmp2, sizeof(float) * remainDepth);
memcpy(dstPtr + 3 * dstStep, tmp3, sizeof(float) * remainDepth);
memcpy(dstPtr + 4 * dstStep, tmp4, sizeof(float) * remainDepth);
memcpy(dstPtr + 5 * dstStep, tmp5, sizeof(float) * remainDepth);
memcpy(dstPtr + 6 * dstStep, tmp6, sizeof(float) * remainDepth);
memcpy(dstPtr + 7 * dstStep, tmp7, sizeof(float) * remainDepth);
srcPtr += 8 * pack;
dstPtr += 8 * dstStep;
realsize -= 8;
}
if (realsize >= 4) {
float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack);
float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack);
float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack);
float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack);
float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16));
float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16));
float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16));
float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16));
vst1q_f32(tmp0, d00_f32);
vst1q_f32(tmp1, d10_f32);
vst1q_f32(tmp2, d20_f32);
vst1q_f32(tmp3, d30_f32);
memcpy(dstPtr + 0 * dstStep, tmp0, sizeof(float) * remainDepth);
memcpy(dstPtr + 1 * dstStep, tmp1, sizeof(float) * remainDepth);
memcpy(dstPtr + 2 * dstStep, tmp2, sizeof(float) * remainDepth);
memcpy(dstPtr + 3 * dstStep, tmp3, sizeof(float) * remainDepth);
srcPtr += 4 * pack;
dstPtr += 4 * dstStep;
realsize -= 4;
}
while (realsize > 0) {
auto s0_f16 = vld1q_f16(srcPtr);
float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16));
vst1q_f32(tmp0, d00_f32);
memcpy(dstPtr + 0 * dstStep, tmp0, sizeof(float) * remainDepth);
srcPtr += pack;
dstPtr += dstStep;
realsize--;
}
}
}
static void MNNAttenPackAndConvertFp32LP1(float* dst, const float* src, const int32_t* units, size_t depth, size_t planesize) {
int32_t eP = units[0];
int32_t lP = units[1];
if (lP != 1) {
MNN_ERROR("This function only supports lP=1\n");
return;
}
auto dstStride1 = eP;
auto dstStride0 = planesize * dstStride1;
for (int i = 0; i < depth; ++i) {
size_t realsize = planesize;
const float* srcPtr = src + i * planesize;
FLOAT16* dstPtr = (FLOAT16*)dst + (i % eP) + (i / eP) * dstStride0;
while (realsize >= 16) {
float32x4_t s0_f32 = vld1q_f32(srcPtr);
float32x4_t s1_f32 = vld1q_f32(srcPtr + 4);
float32x4_t s2_f32 = vld1q_f32(srcPtr + 8);
float32x4_t s3_f32 = vld1q_f32(srcPtr + 12);
float16x4_t d0_f16 = vcvt_f16_f32(s0_f32);
float16x4_t d1_f16 = vcvt_f16_f32(s1_f32);
float16x4_t d2_f16 = vcvt_f16_f32(s2_f32);
float16x4_t d3_f16 = vcvt_f16_f32(s3_f32);
vst1_lane_f16(dstPtr, d0_f16, 0);
vst1_lane_f16(dstPtr + dstStride1, d0_f16, 1);
vst1_lane_f16(dstPtr + 2 * dstStride1, d0_f16, 2);
vst1_lane_f16(dstPtr + 3 * dstStride1, d0_f16, 3);
vst1_lane_f16(dstPtr + 4 * dstStride1, d1_f16, 0);
vst1_lane_f16(dstPtr + 5 * dstStride1, d1_f16, 1);
vst1_lane_f16(dstPtr + 6 * dstStride1, d1_f16, 2);
vst1_lane_f16(dstPtr + 7 * dstStride1, d1_f16, 3);
vst1_lane_f16(dstPtr + 8 * dstStride1, d2_f16, 0);
vst1_lane_f16(dstPtr + 9 * dstStride1, d2_f16, 1);
vst1_lane_f16(dstPtr + 10 * dstStride1, d2_f16, 2);
vst1_lane_f16(dstPtr + 11 * dstStride1, d2_f16, 3);
vst1_lane_f16(dstPtr + 12 * dstStride1, d3_f16, 0);
vst1_lane_f16(dstPtr + 13 * dstStride1, d3_f16, 1);
vst1_lane_f16(dstPtr + 14 * dstStride1, d3_f16, 2);
vst1_lane_f16(dstPtr + 15 * dstStride1, d3_f16, 3);
srcPtr += 16;
dstPtr += 16 * dstStride1;
realsize -= 16;
}
if (realsize >= 8) {
float32x4_t s0_f32 = vld1q_f32(srcPtr);
float32x4_t s1_f32 = vld1q_f32(srcPtr + 4);
float16x4_t d0_f16 = vcvt_f16_f32(s0_f32);
float16x4_t d1_f16 = vcvt_f16_f32(s1_f32);
vst1_lane_f16(dstPtr, d0_f16, 0);
vst1_lane_f16(dstPtr + dstStride1, d0_f16, 1);
vst1_lane_f16(dstPtr + 2 * dstStride1, d0_f16, 2);
vst1_lane_f16(dstPtr + 3 * dstStride1, d0_f16, 3);
vst1_lane_f16(dstPtr + 4 * dstStride1, d1_f16, 0);
vst1_lane_f16(dstPtr + 5 * dstStride1, d1_f16, 1);
vst1_lane_f16(dstPtr + 6 * dstStride1, d1_f16, 2);
vst1_lane_f16(dstPtr + 7 * dstStride1, d1_f16, 3);
srcPtr += 8;
dstPtr += 8 * dstStride1;
realsize -= 8;
}
if (realsize >= 4) {
float32x4_t s0_f32 = vld1q_f32(srcPtr);
float16x4_t d0_f16 = vcvt_f16_f32(s0_f32);
vst1_lane_f16(dstPtr, d0_f16, 0);
vst1_lane_f16(dstPtr + dstStride1, d0_f16, 1);
vst1_lane_f16(dstPtr + 2 * dstStride1, d0_f16, 2);
vst1_lane_f16(dstPtr + 3 * dstStride1, d0_f16, 3);
srcPtr += 4;
dstPtr += 4 * dstStride1;
realsize -= 4;
}
for (; realsize > 0; --realsize) {
*dstPtr = (FLOAT16)(*srcPtr);
srcPtr++;
dstPtr += dstStride1;
}
}
}
static void MNNAttenPackAndConvertFp32(float* dst, float* src, const int32_t* units, size_t depth, size_t planesize) {
int32_t eP = units[0];
int32_t lP = units[1]; // Now lP=1 or 2
if (lP != 1 && lP != 2) {
MNN_ERROR("This function only supports lP=1 or 2\n");
return;
}
// src [depth, planesize] (float32)
// dst [depth/eP, planesize/lP, eP, lP] (float16)
if (lP == 1) {
MNNAttenPackAndConvertFp32LP1(dst, src, units, depth, planesize);
return;
}
auto dstStride1 = eP * lP;
auto dstStride0 = UP_DIV(planesize, lP) * dstStride1;
for (int i = 0; i < depth; ++i) {
size_t realsize = planesize;
const float* srcPtr = src + i * planesize;
FLOAT16* dstPtr = (FLOAT16*)dst + (i % eP) * lP + (i / eP) * dstStride0;
while (realsize >= 16) {
float32x4_t s0 = vld1q_f32(srcPtr);
float32x4_t s1 = vld1q_f32(srcPtr + 4);
float32x4_t s2 = vld1q_f32(srcPtr + 8);
float32x4_t s3 = vld1q_f32(srcPtr + 12);
float16x4_t h0 = vcvt_f16_f32(s0);
float16x4_t h1 = vcvt_f16_f32(s1);
float16x4_t h2 = vcvt_f16_f32(s2);
float16x4_t h3 = vcvt_f16_f32(s3);
vst1_lane_u32((uint32_t*)dstPtr, vreinterpret_u32_f16(h0), 0);
vst1_lane_u32((uint32_t*)(dstPtr + dstStride1), vreinterpret_u32_f16(h0), 1);
vst1_lane_u32((uint32_t*)(dstPtr + 2 * dstStride1), vreinterpret_u32_f16(h1), 0);
vst1_lane_u32((uint32_t*)(dstPtr + 3 * dstStride1), vreinterpret_u32_f16(h1), 1);
vst1_lane_u32((uint32_t*)(dstPtr + 4 * dstStride1), vreinterpret_u32_f16(h2), 0);
vst1_lane_u32((uint32_t*)(dstPtr + 5 * dstStride1), vreinterpret_u32_f16(h2), 1);
vst1_lane_u32((uint32_t*)(dstPtr + 6 * dstStride1), vreinterpret_u32_f16(h3), 0);
vst1_lane_u32((uint32_t*)(dstPtr + 7 * dstStride1), vreinterpret_u32_f16(h3), 1);
realsize -= 16;
srcPtr += 16;
dstPtr += 8 * dstStride1;
}
if (realsize >= 8) {
float32x4_t s0 = vld1q_f32(srcPtr);
float32x4_t s1 = vld1q_f32(srcPtr + 4);
float16x4_t h0 = vcvt_f16_f32(s0);
float16x4_t h1 = vcvt_f16_f32(s1);
vst1_lane_u32((uint32_t*)dstPtr, vreinterpret_u32_f16(h0), 0);
vst1_lane_u32((uint32_t*)(dstPtr + dstStride1), vreinterpret_u32_f16(h0), 1);
vst1_lane_u32((uint32_t*)(dstPtr + 2 * dstStride1), vreinterpret_u32_f16(h1), 0);
vst1_lane_u32((uint32_t*)(dstPtr + 3 * dstStride1), vreinterpret_u32_f16(h1), 1);
realsize -= 8;
srcPtr += 8;
dstPtr += 4 * dstStride1;
}
if (realsize >= 4) {
float32x4_t s0 = vld1q_f32(srcPtr);
float16x4_t h0 = vcvt_f16_f32(s0);
vst1_lane_u32((uint32_t*)dstPtr, vreinterpret_u32_f16(h0), 0);
vst1_lane_u32((uint32_t*)(dstPtr + dstStride1), vreinterpret_u32_f16(h0), 1);
realsize -= 4;
srcPtr += 4;
dstPtr += 2 * dstStride1;
}
if (realsize >= 2) {
float32x2_t s0 = vld1_f32(srcPtr);
float16x4_t h0 = vcvt_f16_f32(vcombine_f32(s0, s0));
vst1_lane_u32((uint32_t*)dstPtr, vreinterpret_u32_f16(h0), 0);
realsize -= 2;
srcPtr += 2;
dstPtr += dstStride1;
}
if (realsize > 0) {
dstPtr[0] = (FLOAT16)srcPtr[0];
dstPtr[1] = (FLOAT16)0.0f;
}
}
}
#endif // MNN_SUPPORT_TRANSFORMER_FUSE
#ifdef MNN_LOW_MEMORY
void MNNAbsMaxFP16(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack) {
if (pack == 4) {
@ -1659,6 +2383,7 @@ static void MNNAsyQuantInfo_FP16(float* scale, float* bias, float* qscale, float
}
#endif
}
#endif // MNN_LOW_MEMORY
static CoreFunctions* gInstance = nullptr;
@ -1726,6 +2451,8 @@ bool Arm82Functions::init() {
FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain, MNNPackedMatMulRemainFP16);
FUNC_PTR_ASSIGN(gInstance->MNNPackC4ForMatMul_A, Arm82MNNPackForMatMul_A);
FUNC_PTR_ASSIGN(gInstance->MNNPackForMatMul_B, Arm82MNNPackForMatMul_B);
FUNC_PTR_ASSIGN(gInstance->MNNSoftmax, origin->MNNSoftmax);
#if defined(__aarch64__)
gInstance->supportFp16arith = origin->supportFp16arith;
gInstance->supportSDot = origin->supportSDot;
@ -1755,7 +2482,16 @@ bool Arm82Functions::init() {
FUNC_PTR_ASSIGN(gInstance->MNNCountMaxMinValue, ARM82CountMinMaxValue); // return one min&max
FUNC_PTR_ASSIGN(gInstance->MNNSumByAxisLForMatmul_A, origin->MNNSumByAxisLForMatmul_A);
FUNC_PTR_ASSIGN(gInstance->MNNDepthwiseConvFastKernel, MNNDepthwiseConvFastKernelFP16);
#endif
#endif // __aarch64__
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
// Attention
FUNC_PTR_ASSIGN(gInstance->MNNAttenUnpackAndConvertFp16, MNNAttenUnpackAndConvertFp16);
FUNC_PTR_ASSIGN(gInstance->MNNAttenPackAndConvertFp32, MNNAttenPackAndConvertFp32);
FUNC_PTR_ASSIGN(gInstance->MNNAttenPackAndScaleSingleHead, MNNAttenPackAndScaleSingleHead);
FUNC_PTR_ASSIGN(gInstance->MNNFlashAttentionUpdateBlockOutput, MNNFlashAttentionUpdateBlockOutput);
#endif // MNN_SUPPORT_TRANSFORMER_FUSE
gInstance->MNNComputeMatMulForH_1 = _MNNComputeMatMulForH_1_FP16;
gInstance->MNNComputeMatMulForE_1 = _MNNComputeMatMulForE_1_FP16;

View File

@ -24,10 +24,12 @@
#define FLOAT16_T float
#endif
#define MNN_FLASH_ATTENTION_BLOCK_SIZE 64
namespace MNN {
template <typename T>
void CPUAttention::pack_query(Tensor* query, char* pack_q, char* sum_q, int seq_len, int h, float q_scale) {
void CPUAttention::pack_query(Tensor* query, int8_t* pack_q, int8_t* sum_q, int seq_len, int h, float q_scale) {
if (mUseGemmInt8) { // Shape of Query: numhead, [seqlen/eP8, headdim/lP8, eP8, lP8]
mMinQ[h] = query->host<T>()[h * mHeadDim];
mMaxQ[h] = query->host<T>()[h * mHeadDim];
@ -75,21 +77,21 @@ void CPUAttention::pack_query(Tensor* query, char* pack_q, char* sum_q, int seq_
}
template <typename T>
void CPUAttention::unpack_QK(float * unpack_qk_dst, char * pack_qk_src, int seq_len, int kv_seq_len) {
void CPUAttention::unpack_QK(float * unpack_qk_dst, int8_t * pack_qk_src, int seq_len, int kv_seq_len) {
float * dst = unpack_qk_dst;
T * src = (T *)(pack_qk_src);
// [kv_seq_len/unit, seq_len, unit] -> [seq_len, kv_seq_len]
// [kv_seq_len/mPack, seq_len, mPack] -> [seq_len, kv_seq_len]
for (int i = 0; i < seq_len; i++) {
for (int j = 0; j < kv_seq_len; j++) {
int out_index = j / unit;
int in_index = j % unit;
dst[i * kv_seq_len + j] = src[out_index * seq_len * unit + i * unit + in_index];
int out_index = j / mPack;
int in_index = j % mPack;
dst[i * kv_seq_len + j] = src[out_index * seq_len * mPack + i * mPack + in_index];
}
}
}
template <typename T>
static void pack_QK(char * pack_qk_dst, float * qk_src, int seq_len, int kv_seq_len, int eP, int lP, int bytes) {
static void pack_QK(int8_t * pack_qk_dst, float * qk_src, int seq_len, int kv_seq_len, int eP, int lP, int bytes) {
T * dst = reinterpret_cast<T*>(pack_qk_dst);
float * src = reinterpret_cast<float*>(qk_src);
// [seq_len, kv_seq_len] -> [seq_len/eP, kv_seq_len/lP, eP, lP]
@ -108,38 +110,50 @@ static void pack_QK(char * pack_qk_dst, float * qk_src, int seq_len, int kv_seq_
}
template <typename T>
static void mask_QK(float * unpack_qk, int seq_len, int kv_seq_len, float mScale, float min_val, const Tensor* mask) {
if (mask == nullptr) {
for (int i = 0; i < kv_seq_len; i++) {
static void mask_QK(float * unpack_qk, int seq_len, int kv_seq_len, float mScale, float min_val, const Tensor* maskTensor, int offset, int startIndx, int processedKvLen) {
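// [startIndx, startIndx + processedKvLen) is the absolute kv range covered by the current block; offset (= kv_seq_len - seq_len) is where the square [seq_len, seq_len] mask begins along the kv axis.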
int endIndx = startIndx + processedKvLen;
if (maskTensor == nullptr) {
for (int i = 0; i < processedKvLen; i++) {
unpack_qk[i] = unpack_qk[i] * mScale;
}
} else if (mask->getType() == halide_type_of<float>()) {
return;
}
const int8_t* mask = maskTensor->host<int8_t>();
halide_type_t htype = maskTensor->getType();
int maskSize = maskTensor->elementSize();
if (htype == halide_type_of<float>()) {
// float mask
T* fpmask_ptr = mask->host<T>();
if (mask->elementSize() == seq_len * kv_seq_len) {
// normal mask for all token
for (int i = 0; i < seq_len * kv_seq_len; i++) {
unpack_qk[i] = unpack_qk[i] * mScale + fpmask_ptr[i];
}
} else {
// square mask just for new generation token
int offset = kv_seq_len - seq_len;
T* fpmask_ptr = (T*)mask;
if (maskSize == seq_len * kv_seq_len) { // sliding attention, mask shape: [seq_len, kv_seq_len]
for (int i = 0; i < seq_len; ++i) {
auto unpack_qki = unpack_qk + i * kv_seq_len;
auto fpmask_ptri = fpmask_ptr + i * seq_len;
for (int j = 0; j < offset; ++j) {
unpack_qki[j] = unpack_qki[j] * mScale;
auto unpack_qki = unpack_qk + i * processedKvLen;
auto fpmask_ptri = fpmask_ptr + i * kv_seq_len;
for (int j = startIndx; j < endIndx; ++j) {
unpack_qki[j - startIndx] = unpack_qki[j - startIndx] * mScale + fpmask_ptri[j];
}
for (int j = 0; j < seq_len; ++j) {
unpack_qki[offset + j] = unpack_qki[offset + j] * mScale + fpmask_ptri[j];
}
} else { // mask shape: [seq_len, seq_len]
for (int i = 0; i < seq_len; ++i) {
auto unpack_qki = unpack_qk + i * processedKvLen;
auto fpmask_ptri = fpmask_ptr + i * seq_len;
auto notMaskIndx = ALIMIN(endIndx, offset);
auto stMaskIndx = ALIMAX(startIndx, offset);
for (int j = startIndx; j < notMaskIndx; ++j) {
unpack_qki[j - startIndx] = unpack_qki[j - startIndx] * mScale;
}
for (int j = stMaskIndx; j < endIndx; ++j) {
unpack_qki[j - startIndx] = unpack_qki[j - startIndx] * mScale + fpmask_ptri[j - offset];
}
}
}
} else {
// int mask
int* mask_ptr = mask->host<int>();
for (int i = 0; i < seq_len * kv_seq_len; i++) {
if (mask_ptr[i]) {
int* mask_ptr = (int*)mask;
for (int i = 0; i < seq_len * processedKvLen; i++) {
if (mask_ptr[i / processedKvLen * kv_seq_len + i % processedKvLen]) {
unpack_qk[i] = unpack_qk[i] * mScale;
} else {
unpack_qk[i] = min_val;
@ -148,41 +162,47 @@ static void mask_QK(float * unpack_qk, int seq_len, int kv_seq_len, float mScale
}
}
static void softmax_QK(float* softmax_qk_addr, float* unpack_qk_addr, int seq_len, int kv_seq_len) {
for (int i = 0; i < seq_len; i++) { // softmax each row
MNNSoftmax(softmax_qk_addr + i * kv_seq_len, unpack_qk_addr + i * kv_seq_len, kv_seq_len);
}
typedef void(softmaxFunc)(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize);
template <typename T>
static void softmaxQK(float* softmax_qk_addr, float* unpack_qk_addr, float* runningMax, float* runningSum, float* diffScale, const float* sinkPtr, softmaxFunc* sffunc, int seq_len, int kv_seq_len, int headIdx, bool isLastKvBlock) {
// not sliding attention
if (sinkPtr == nullptr) {
sffunc(softmax_qk_addr, unpack_qk_addr, runningMax, runningSum, diffScale, seq_len, kv_seq_len);
return;
}
static void sink_softmax_QK(float* softmax_qk_addr, float* unpack_qk_addr, int seq_len, int kv_seq_len, float sink) {
// TODO: opt
std::vector<float> buffer(2 * (kv_seq_len + 1));
float* sinkSrc = buffer.data();
float* sinkDst = buffer.data() + kv_seq_len + 1;
for (int i = 0; i < seq_len; i++) { // softmax each row
::memcpy(sinkSrc, unpack_qk_addr + i * kv_seq_len, kv_seq_len * sizeof(float));
sinkSrc[kv_seq_len] = sink;
float rowMax = sink;
for (int j = 0; j < kv_seq_len; j++) {
rowMax = ALIMAX(rowMax, sinkSrc[j]);
float sink = ((T*)sinkPtr)[headIdx];
if (!runningMax && !runningSum) { // Do not use flash attention
for (int i = 0; i < seq_len; ++i) {
float exprOffset[4] = {1, 0, -sink, 1.f};
MNNExp(softmax_qk_addr + i * kv_seq_len, unpack_qk_addr + i * kv_seq_len, exprOffset, kv_seq_len);
for (int j = 0; j < kv_seq_len; ++j) {
softmax_qk_addr[i * kv_seq_len + j] /= exprOffset[3];
}
for (int j = 0; j < kv_seq_len + 1; j++) {
sinkSrc[j] = sinkSrc[j] - rowMax;
}
MNNSoftmax(sinkDst, sinkSrc, kv_seq_len + 1);
::memcpy(softmax_qk_addr + i * kv_seq_len, sinkDst, kv_seq_len * sizeof(float));
return;
}
// Use flash attention
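// The attention sink only contributes exp(sink - runningMax) to the softmax denominator (it produces no value output), so it is folded into runningSum once, on the last kv block.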
if (isLastKvBlock) {
for (int i = 0; i < seq_len; ++i) {
runningSum[i] += expf(sink - runningMax[i]);
}
}
MNNSoftmax(softmax_qk_addr, unpack_qk_addr, runningMax, runningSum, diffScale, seq_len, kv_seq_len);
}
template <typename T>
static void unpack_QKV(char* pack_qkv, char* unpack_qkv, int mNumHead, int mHeadDim, int unit, int seq_len) {
static void unpack_QKV(int8_t* pack_qkv, int8_t* unpack_qkv, int mNumHead, int mHeadDim, int mPack, int seq_len) {
auto src_ptr = reinterpret_cast<T*>(pack_qkv);
auto dst_ptr = reinterpret_cast<T*>(unpack_qkv);
for (int i = 0; i < seq_len; i++) {
for (int j = 0; j < mHeadDim; j++) {
int a = j / unit;
int b = j % unit;
dst_ptr[i * mNumHead * mHeadDim + j] = src_ptr[a * seq_len * unit + i * unit + b];
int a = j / mPack;
int b = j % mPack;
dst_ptr[i * mNumHead * mHeadDim + j] = src_ptr[a * seq_len * mPack + i * mPack + b];
}
}
}
@ -191,10 +211,10 @@ ErrorCode CPUAttention::onResize(const std::vector<Tensor*>& inputs, const std::
auto core = static_cast<CPUBackend *>(backend())->functions();
core->MNNGetMatMulPackMode(&eP, &lP, &hP);
mThreadNum = ((CPUBackend *)backend())->threadNumber();
unit = core->pack;
mPack = core->pack;
bytes = core->bytes;
int qkvQuantOptions = static_cast<CPUBackend *>(backend())->getRuntime()->hint().qkvQuantOption;
mUseGemmInt8 = (qkvQuantOptions == 4);
mUseGemmInt8 = (qkvQuantOptions % 8 == 4);
if (mUseGemmInt8) {
static_cast<CPUBackend*>(backend())->int8Functions()->MNNGetGemmUnit(&hP8, &lP8, &eP8);
}
@ -208,7 +228,7 @@ ErrorCode CPUAttention::onResize(const std::vector<Tensor*>& inputs, const std::
if (mUseGemmInt8) {
mPackQ.reset(Tensor::createDevice<int8_t>({mThreadNum, UP_DIV(seq_len, eP8), UP_DIV(mHeadDim, lP8), eP8 * lP8}));
mSumQ.reset(Tensor::createDevice<int32_t>({mThreadNum, UP_DIV(seq_len, eP8), eP8}));
mPackQKV.reset(Tensor::createDevice<float>({mThreadNum, UP_DIV(mHeadDim, unit), seq_len, unit}));
mPackQKV.reset(Tensor::createDevice<float>({mThreadNum, UP_DIV(mHeadDim, mPack), seq_len, mPack}));
backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC);
backend()->onAcquireBuffer(mSumQ.get(), Backend::DYNAMIC);
backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC);
@ -220,18 +240,40 @@ ErrorCode CPUAttention::onResize(const std::vector<Tensor*>& inputs, const std::
mQueryScale.resize(mNumHead);
mQueryZeroPoint.resize(mNumHead);
} else {
mPackQ.reset(Tensor::createDevice<float>({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(mHeadDim, lP), eP}));
mPackQKV.reset(Tensor::createDevice<float>({mThreadNum, UP_DIV(mHeadDim, unit), seq_len, unit}));
mPackQ.reset(Tensor::createDevice<int8_t>({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(mHeadDim, lP), eP * bytes}));
mPackQKV.reset(Tensor::createDevice<int8_t>({mThreadNum, UP_DIV(mHeadDim, mPack), seq_len, mPack * bytes}));
backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC);
backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC);
// flash attention
if (qkvQuantOptions / 8 == 1) {
mRunningMax.reset(Tensor::createDevice<int8_t>({mThreadNum, seq_len * 4}));
mRunningSum.reset(Tensor::createDevice<int8_t>({mThreadNum, seq_len * 4}));
mExpfDiffMax.reset(Tensor::createDevice<int8_t>({mThreadNum, seq_len * 4}));
mTempOut.reset(Tensor::createDevice<int8_t>({mThreadNum, UP_DIV(mHeadDim, mPack), seq_len, mPack * bytes}));
backend()->onAcquireBuffer(mRunningMax.get(), Backend::DYNAMIC);
backend()->onAcquireBuffer(mRunningSum.get(), Backend::DYNAMIC);
backend()->onAcquireBuffer(mExpfDiffMax.get(), Backend::DYNAMIC);
backend()->onAcquireBuffer(mTempOut.get(), Backend::DYNAMIC);
}
backend()->onReleaseBuffer(mPackQ.get(), Backend::DYNAMIC);
backend()->onReleaseBuffer(mPackQKV.get(), Backend::DYNAMIC);
if (qkvQuantOptions / 8 == 1) {
backend()->onReleaseBuffer(mRunningMax.get(), Backend::DYNAMIC);
backend()->onReleaseBuffer(mRunningSum.get(), Backend::DYNAMIC);
backend()->onReleaseBuffer(mExpfDiffMax.get(), Backend::DYNAMIC);
backend()->onReleaseBuffer(mTempOut.get(), Backend::DYNAMIC);
}
}
return NO_ERROR;
}
ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
auto core = static_cast<CPUBackend *>(backend())->functions();
auto qkvQuantOptions = static_cast<CPUBackend *>(backend())->getRuntime()->hint().qkvQuantOption;
auto query = inputs[0];
auto key = inputs[1];
auto value = inputs[2];
@ -283,146 +325,146 @@ ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std:
int max_len = mKVCacheManager->maxLength();
bool quant_key = mKVCacheManager->config()->mQuantKey;
bool quant_value = mKVCacheManager->config()->mQuantValue;
mBlockKV = (qkvQuantOptions / 8 == 1) ? ALIMIN(MNN_FLASH_ATTENTION_BLOCK_SIZE, kv_seq_len) : kv_seq_len;
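// With flash attention (qkvQuantOptions / 8 == 1) the kv cache is processed in blocks of at most MNN_FLASH_ATTENTION_BLOCK_SIZE positions; otherwise the whole kv_seq_len is handled in one pass.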
int32_t units[2] = {eP, lP};
// Temporary tensors for intermediate results
std::shared_ptr<Tensor> packQK(Tensor::createDevice<float>({mThreadNum, UP_DIV(kv_seq_len, unit), seq_len, unit}));
std::shared_ptr<Tensor> unpackQK(Tensor::createDevice<int32_t>({mThreadNum, seq_len, kv_seq_len}));
std::shared_ptr<Tensor> softmMaxQ(Tensor::createDevice<int32_t>({mThreadNum, seq_len, kv_seq_len}));
std::shared_ptr<Tensor> newPackQK(Tensor::createDevice<float>({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(kv_seq_len, lP), eP}));
std::shared_ptr<Tensor> dequantV(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), kv_seq_len, hP}));
backend()->onAcquireBuffer(packQK.get(), Backend::STATIC);
std::shared_ptr<Tensor> unpackQK(Tensor::createDevice<int32_t>({mThreadNum, seq_len, mBlockKV}));
std::shared_ptr<Tensor> softmMaxQ(Tensor::createDevice<int32_t>({mThreadNum, seq_len, mBlockKV}));
std::shared_ptr<Tensor> newPackQK(Tensor::createDevice<int8_t>({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(mBlockKV, lP), eP * bytes}));
std::shared_ptr<Tensor> dequantV(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(mHeadDim, hP), kv_seq_len, hP * bytes}));
// mTempQKBlock.reset(Tensor::createDevice<int8_t>({mThreadNum, UP_DIV(mBlockKV, mPack), seq_len, mPack * bytes}));
std::shared_ptr<Tensor> tempQKBlock(Tensor::createDevice<int8_t>({mThreadNum, UP_DIV(mBlockKV, mPack), seq_len, mPack * bytes}));
backend()->onAcquireBuffer(unpackQK.get(), Backend::STATIC);
backend()->onAcquireBuffer(softmMaxQ.get(), Backend::STATIC);
backend()->onAcquireBuffer(newPackQK.get(), Backend::STATIC);
backend()->onAcquireBuffer(tempQKBlock.get(), Backend::STATIC);
if (quant_value) {
backend()->onAcquireBuffer(dequantV.get(), Backend::STATIC);
mKVCacheManager->onDequantValue(dequantV.get());
}
const float* sinksPtr = sinks ? sinks->host<float>() : nullptr;
std::function<void(int)> mCompute = [=](int tId) {
auto pack_q = mPackQ->host<char>() + tId * UP_DIV(seq_len, eP) * ROUND_UP(mHeadDim, lP) * eP * bytes;
auto pack_qk = packQK->host<char>() + tId * UP_DIV(kv_seq_len, unit) * seq_len * unit * bytes;
char * sum_q = nullptr;
auto unpack_qk = unpackQK->host<float>() + tId * seq_len * kv_seq_len;
auto softmax_qk = softmMaxQ->host<float>() + tId * seq_len * kv_seq_len;
auto new_pack_qk = newPackQK->host<char>() + tId * UP_DIV(seq_len, eP) * ROUND_UP(kv_seq_len, lP) * eP * bytes;
auto pack_qkv = mPackQKV->host<char>() + tId * UP_DIV(mHeadDim, unit) * seq_len * unit * bytes;
auto qReordered = mPackQ->host<int8_t>() + tId * mPackQ->stride(0);
auto qkPacked = tempQKBlock->host<int8_t>() + tId * tempQKBlock->stride(0);
int8_t * sum_q = nullptr;
auto qkFlatten = unpackQK->host<float>() + tId * unpackQK->stride(0);
auto qkSoftmax = softmMaxQ->host<float>() + tId * softmMaxQ->stride(0);
auto qkReordered = newPackQK->host<int8_t>() + tId * newPackQK->stride(0);
auto qkvPacked = mPackQKV->host<int8_t>() + tId * mPackQKV->stride(0);
auto QxK = quant_key ? core->MNNPackedMatMul_int8 : core->MNNPackedMatMul;
auto QxK_remain = quant_key ? core->MNNPackedMatMulRemain_int8 : core->MNNPackedMatMulRemain;
// Flash Attention
auto runningMax = mRunningMax ? (float*)(mRunningMax->host<int8_t>() + tId * mRunningMax->stride(0)) : nullptr;
auto runningSum = mRunningSum ? (float*)(mRunningSum->host<int8_t>() + tId * mRunningSum->stride(0)) : nullptr;
auto diffScale = mExpfDiffMax ? (float*)(mExpfDiffMax->host<int8_t>() + tId * mExpfDiffMax->stride(0)) : nullptr;
auto outputPacked = mTempOut ? mTempOut->host<int8_t>() + tId * mTempOut->stride(0) : qkvPacked;
int head_index = tId * tileCount;
int kvBlocks = UP_DIV(kv_seq_len, mBlockKV);
if (mUseGemmInt8) {
pack_q = mPackQ->host<char>() + tId * UP_DIV(seq_len, eP8) * UP_DIV(mHeadDim, lP8) * eP8 * lP8;
sum_q = mSumQ->host<char>() + tId * UP_DIV(seq_len, eP8) * eP8 * 4;
qReordered = mPackQ->host<int8_t>() + tId * UP_DIV(seq_len, eP8) * UP_DIV(mHeadDim, lP8) * eP8 * lP8;
sum_q = mSumQ->host<int8_t>() + tId * UP_DIV(seq_len, eP8) * eP8 * 4;
}
for (int h = head_index; h < head_index + tileCount && h < mNumHead; h++) {
int kv_h = h / group_size;
char * key_addr = mKVCacheManager->addrOfKey(kv_h);
char * scale_addr = mKVCacheManager->addrOfScale(kv_h);
char * zero_point_addr = mKVCacheManager->addrOfZeroPoint(kv_h);
char * key_sum_addr = mKVCacheManager->addrOfKeySum(kv_h);
char * value_addr = quant_value ? (dequantV->host<char>() + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(kv_seq_len, lP) * hP * bytes) : mKVCacheManager->addrOfValue(kv_h);
if (bytes == 2) {
pack_query<FLOAT16_T>(query, pack_q, sum_q, seq_len, h, q_scale);
if (runningSum && runningMax) {
memset(runningSum, 0, mRunningSum->stride(0));
if (sinksPtr == nullptr) {
for (int k = 0; k < seq_len; ++k) {
runningMax[k] = -std::numeric_limits<float>::infinity();
}
} else {
pack_query<float>(query, pack_q, sum_q, seq_len, h, q_scale);
float sinkVal;
if (bytes == 2) {
sinkVal = ((FLOAT16_T*)sinksPtr)[h];
} else {
sinkVal = sinksPtr[h];
}
// query @ key
for (int k = 0; k < seq_len; ++k) {
runningMax[k] = sinkVal;
}
}
}
int kv_h = h / group_size;
int8_t * key_addr = mKVCacheManager->addrOfKey(kv_h);
int8_t * scale_addr = mKVCacheManager->addrOfScale(kv_h);
int8_t * zero_point_addr = mKVCacheManager->addrOfZeroPoint(kv_h);
int8_t * key_sum_addr = mKVCacheManager->addrOfKeySum(kv_h);
int8_t * value_addr = quant_value ? (dequantV->host<int8_t>() + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(kv_seq_len, lP) * hP * bytes) : mKVCacheManager->addrOfValue(kv_h);
if (mUseGemmInt8) {
auto GemmInt8Kernel = static_cast<CPUBackend*>(backend())->int8Functions()->Int8GemmKernel;
if (bytes == 2 && unit == 8) {
GemmInt8Kernel = static_cast<CPUBackend*>(backend())->int8Functions()->MNNGemmInt8AddBiasScale_Unit_FP16;
if (bytes == 2) {
pack_query<FLOAT16_T>(query, qReordered, sum_q, seq_len, h, q_scale);
} else {
pack_query<float>(query, qReordered, sum_q, seq_len, h, q_scale);
}
std::vector<float> postScale(ROUND_UP(kv_seq_len, hP8), 0.0f);
for (int i = 0; i < kv_seq_len; i++) {
postScale[i] = ((float*)scale_addr)[i] * mQueryScale[h] * q_scale;
} else {
core->MNNAttenPackAndScaleSingleHead((float*)qReordered, (float*)(query->host<int8_t>() + h * mHeadDim * bytes), mHeadDim * mNumHead, &q_scale, units, seq_len, mHeadDim);
}
std::vector<float> weightQuantBias(ROUND_UP(kv_seq_len, hP8), 0.0f);
for (int i = 0; i < kv_seq_len; i++) {
weightQuantBias[i] = -((float*)scale_addr)[i] * ((float*)zero_point_addr)[i] * q_scale;
}
std::vector<float> biasFloat(ROUND_UP(kv_seq_len, hP8), 0.0f);
for (int i = 0; i < kv_seq_len; i++) {
biasFloat[i] = -mQueryScale[h] * mQueryZeroPoint[h] * ((float*)key_sum_addr)[i] * q_scale;
}
QuanPostTreatParameters post;
post.bias = nullptr;
post.biasFloat = biasFloat.data();
post.blockNum = 1;
post.inputBias = nullptr;
post.inputScale = nullptr;
post.fp32minmax = nullptr;
post.scale = postScale.data();
post.useInt8 = false;
post.weightKernelSum = weightQuantBias.data();
int N = UP_DIV(seq_len, eP8);
for (int i = 0; i < N; i++) {
int realcount = ALIMIN(eP8, seq_len - i * eP8);
post.srcKernelSum = (float*)((char*)sum_q + i * eP8 * 4);
GemmInt8Kernel(
(int8_t*)pack_qk + i * eP8 * unit * bytes,
(int8_t*)pack_q + i * ROUND_UP(mHeadDim, lP8) * eP8,
(int8_t*)key_addr,
UP_DIV(mHeadDim, lP8),
seq_len * unit * bytes,
UP_DIV(kv_seq_len, unit),
&post,
realcount
);
}
}
else {
for (int i = 0; i < kvBlocks; ++i) {
int subKvSeqLen = ALIMIN(mBlockKV, kv_seq_len - i * mBlockKV);
auto keyPtr = key_addr + i * UP_DIV(mBlockKV, hP) * ROUND_UP(mHeadDim, lP) * hP * bytes;
auto valuePtr = value_addr + i * UP_DIV(mBlockKV, lP) * hP * lP * bytes;
// query @ key
{
int loop_e = seq_len / eP;
int remain = seq_len % eP;
auto qStride0 = ROUND_UP(mHeadDim, lP) * eP * bytes;
size_t shapeParameters[7] = {(size_t)eP * bytes, ROUND_UP((size_t)mHeadDim, lP), (size_t)kv_seq_len, (size_t)seq_len * unit * bytes, 0, 0, 0};
for (int i = 0 ; i < loop_e; i++) {
QxK((float*)(pack_qk + (i * eP * unit) * bytes), (float*)(pack_q + i * qStride0), (float*)key_addr, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr);
size_t shapeParameters[7] = {(size_t)eP * bytes, ROUND_UP((size_t)mHeadDim, lP), (size_t)subKvSeqLen, (size_t)seq_len * mPack * bytes, 0, 0, 0};
for (int ei = 0 ; ei < loop_e; ei++) {
QxK((float*)(qkPacked + (ei * eP * mPack) * bytes), (float*)(qReordered + ei * qStride0), (float*)keyPtr, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr);
}
QxK_remain((float*)(pack_qk + (loop_e * eP * unit) * bytes), (float*)(pack_q + loop_e * qStride0), (float*)key_addr, remain, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr);
QxK_remain((float*)(qkPacked + (loop_e * eP * mPack) * bytes), (float*)(qReordered + loop_e * qStride0), (float*)keyPtr, remain, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr);
}
// qk: [kv_seq_len/unit, seq_len, unit] -> [seq_len, kv_seq_len] -> [seq_len/eP, kv_seq_len, eP]
if (sinksPtr != nullptr) {
// qk: [kv_seq_len/mPack, seq_len, mPack] -> [seq_len/eP, kv_seq_len, eP]
{
if(bytes == 2) {
unpack_QK<FLOAT16_T>(unpack_qk, pack_qk, seq_len, kv_seq_len);
mask_QK<FLOAT16_T>(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask);
sink_softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len, sinksPtr[h]);
pack_QK<FLOAT16_T>(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP, lP, bytes);
if (seq_len == 1) {
core->MNNLowpToFp32((int16_t*)qkPacked, qkFlatten, seq_len * subKvSeqLen);
} else {
unpack_QK<float>(unpack_qk, pack_qk, seq_len, kv_seq_len);
mask_QK<float>(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask);
sink_softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len, sinksPtr[h]);
pack_QK<float>(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP, lP, bytes);
core->MNNAttenUnpackAndConvertFp16(qkFlatten, (float*)qkPacked, subKvSeqLen, seq_len, mPack);
}
mask_QK<FLOAT16_T>(qkFlatten, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask, kv_seq_len - seq_len, i * mBlockKV, subKvSeqLen);
softmaxQK<FLOAT16_T>(qkSoftmax, qkFlatten, runningMax, runningSum, diffScale, sinksPtr, core->MNNSoftmax, seq_len, subKvSeqLen, h, i == kvBlocks - 1);
core->MNNAttenPackAndConvertFp32((float*)qkReordered, qkSoftmax, units, seq_len, subKvSeqLen);
} else {
if(bytes == 2) {
unpack_QK<FLOAT16_T>(unpack_qk, pack_qk, seq_len, kv_seq_len);
mask_QK<FLOAT16_T>(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask);
softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len);
pack_QK<FLOAT16_T>(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP, lP, bytes);
if (seq_len > 1) {
int32_t areaOffset[2] = {seq_len, seq_len};
core->MNNUnpackCUnitTranspose(qkFlatten, (float*)qkPacked, seq_len, subKvSeqLen, areaOffset);
} else {
unpack_QK<float>(unpack_qk, pack_qk, seq_len, kv_seq_len);
mask_QK<float>(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask);
softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len);
pack_QK<float>(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP, lP, bytes);
memcpy(qkFlatten, qkPacked, subKvSeqLen * sizeof(float));
}
mask_QK<float>(qkFlatten, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask, kv_seq_len - seq_len, i * mBlockKV, subKvSeqLen);
softmaxQK<float>(qkSoftmax, qkFlatten, runningMax, runningSum, diffScale, sinksPtr, core->MNNSoftmax, seq_len, subKvSeqLen, h, i == kvBlocks - 1);
packKvCache((float*)qkReordered, qkSoftmax, seq_len, subKvSeqLen, eP);
}
}
// qk @ v
size_t shapeParameters[7] = {(size_t)eP * bytes, ROUND_UP((size_t)kv_seq_len, lP), (size_t)mHeadDim, (size_t)seq_len * unit * bytes, 0, 0, 0};
size_t bExtraStride = (UP_DIV(max_len, lP) - UP_DIV(kv_seq_len, lP)) * hP * lP * bytes;
// TODO: update qkvPacked using diffScale
size_t shapeParameters[7] = {(size_t)eP * bytes, ROUND_UP((size_t)subKvSeqLen, lP), (size_t)mHeadDim, (size_t)seq_len * mPack * bytes, 0, 0, 0};
size_t bExtraStride = (UP_DIV(max_len, lP) - UP_DIV(subKvSeqLen + i * mBlockKV, lP) + UP_DIV(i * mBlockKV, lP)) * hP * lP * bytes;
shapeParameters[5] = quant_value ? 0 : bExtraStride;
int loop_e = seq_len / eP;
int remain = seq_len % eP;
auto qkStride0 = ROUND_UP(kv_seq_len, lP) * eP * bytes;
for (int i = 0 ; i < loop_e; i++) {
core->MNNPackedMatMul((float*)(pack_qkv + (i * eP * unit) * bytes), (float*)(new_pack_qk + i * qkStride0), (float*)value_addr, shapeParameters, nullptr, nullptr, nullptr, nullptr);
auto qkStride0 = ROUND_UP(subKvSeqLen, lP) * eP * bytes;
for (int ei = 0 ; ei < loop_e; ei++) {
core->MNNPackedMatMul((float*)(qkvPacked + (ei * eP * mPack) * bytes), (float*)(qkReordered + ei * qkStride0), (float*)valuePtr, shapeParameters, nullptr, nullptr, nullptr, nullptr);
}
core->MNNPackedMatMulRemain((float*)(pack_qkv + (loop_e * eP * unit) * bytes), (float*)(new_pack_qk + loop_e * qkStride0), (float*)value_addr, remain, shapeParameters, nullptr, nullptr, nullptr, nullptr);
// unpack: [head_dim/unit, seq_len, unit] -> [seq_len, num_head, head_dim]
auto dst_ptr = outputs[0]->host<char>() + h * mHeadDim * bytes;
core->MNNPackedMatMulRemain((float*)(qkvPacked + (loop_e * eP * mPack) * bytes), (float*)(qkReordered + loop_e * qkStride0), (float*)valuePtr, remain, shapeParameters, nullptr, nullptr, nullptr, nullptr);
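// Flash attention path: merge this kv block's partial output into outputPacked using diffScale (= exp(oldMax - newMax)) and runningSum (see MNNFlashAttentionUpdateBlockOutput).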
if (runningMax != nullptr && runningSum != nullptr && diffScale != nullptr) {
core->MNNFlashAttentionUpdateBlockOutput((float*)outputPacked, (float*)qkvPacked, diffScale, runningSum, UP_DIV(mHeadDim, mPack), seq_len, mPack, i, kvBlocks, mPackQKV->stride(0) / bytes, bytes);
}
}
// unpack: [head_dim/mPack, seq_len, mPack] -> [seq_len, num_head, head_dim]
auto dst_ptr = outputs[0]->host<int8_t>() + h * mHeadDim * bytes;
if (bytes == 2) {
unpack_QKV<int16_t>(pack_qkv, dst_ptr, mNumHead, mHeadDim, unit, seq_len);
unpack_QKV<int16_t>((int8_t*)outputPacked, dst_ptr, mNumHead, mHeadDim, mPack, seq_len);
} else {
unpack_QKV<float>(pack_qkv, dst_ptr, mNumHead, mHeadDim, unit, seq_len);
unpack_QKV<float>((int8_t*)outputPacked, dst_ptr, mNumHead, mHeadDim, mPack, seq_len);
}
}
};
@ -431,10 +473,10 @@ ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std:
}
MNN_CONCURRENCY_END();
backend()->onReleaseBuffer(packQK.get(), Backend::STATIC);
backend()->onReleaseBuffer(unpackQK.get(), Backend::STATIC);
backend()->onReleaseBuffer(softmMaxQ.get(), Backend::STATIC);
backend()->onReleaseBuffer(newPackQK.get(), Backend::STATIC);
backend()->onReleaseBuffer(tempQKBlock.get(), Backend::STATIC);
if (quant_value){
backend()->onReleaseBuffer(dequantV.get(), Backend::STATIC);
}
@ -460,9 +502,19 @@ CPUAttention::CPUAttention(Backend *backend, bool kv_cache) : Execution(backend)
mPackQKV.reset(Tensor::createDevice<float>({1, 1, 1, 1}));
MNN::KVCacheManager::KVCacheConfig kvconfig;
int qkvQuantOptions = static_cast<CPUBackend *>(backend)->getRuntime()->hint().qkvQuantOption;
kvconfig.mUseInt8Kernel = (qkvQuantOptions == 4);
kvconfig.mQuantKey = (qkvQuantOptions == 4) || (qkvQuantOptions & 1);
kvconfig.mQuantValue = (qkvQuantOptions == 4) || ((qkvQuantOptions >> 1) & 1);
kvconfig.mUseInt8Kernel = (qkvQuantOptions % 8 == 4);
// qkvQuantOption % 8:
// 0: Do not quantize
// 1: Only quantize key, use int8 asymmetric quantization
// 2: Only quantize value, use fp8 quantization
// 3: quantize both key and value
// 4: quantize query, key and value, and use the int8 gemm kernel to compute Q*K
// qkvQuantOption / 8:
// 1: use flash attention
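// e.g. (illustrative): qkvQuantOption == 9 -> 9 % 8 == 1 (quantize key) and 9 / 8 == 1 (flash attention enabled)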
kvconfig.mQuantKey = (qkvQuantOptions % 8 == 4) || (qkvQuantOptions % 8 == 1) || (qkvQuantOptions % 8 == 3);
kvconfig.mQuantValue = (qkvQuantOptions % 8 == 4) || (qkvQuantOptions % 8 == 2);
kvconfig.mKVCacheDir = static_cast<CPUBackend *>(backend)->getRuntime()->hint().kvcacheDirPath;
kvconfig.mKVCacheSizeLimit = static_cast<CPUBackend *>(backend)->getRuntime()->hint().kvcacheSizeLimit;
kvconfig.mExpandChunk = 64;

View File

@ -30,15 +30,16 @@ private:
bool mKVCache = true;
bool mUseGemmInt8 = false;
int bytes = 4;
int mThreadNum = 1;;
int eP, lP, hP, unit; // float matmul packing
int mThreadNum = 1;
int mBlockKV = 512;
int eP, lP, hP, mPack; // float matmul packing
int eP8, lP8, hP8; // GemmInt8 packing
int mNumHead, mKvNumHead, mHeadDim;
std::shared_ptr<Tensor> mPackQ, mPackQKV, mSumQ;
std::shared_ptr<Tensor> mPackQ, mPackQKV, mSumQ, mRunningMax, mRunningSum, mTempQKBlock, mTempOut, mExpfDiffMax;
std::shared_ptr<KVCacheManager> mKVCacheManager = nullptr;
std::vector<float> mMinQ, mMaxQ, mQueryScale, mQueryZeroPoint;
template <typename T> void pack_query(Tensor* query, char* pack_q, char* sum_q, int seq_len, int h, float q_scale);
template <typename T> void unpack_QK(float * unpack_qk_dst, char * pack_qk_src, int seq_len, int kv_seq_len);
template <typename T> void pack_query(Tensor* query, int8_t* pack_q, int8_t* sum_q, int seq_len, int h, float q_scale);
template <typename T> void unpack_QK(float * unpack_qk_dst, int8_t * pack_qk_src, int seq_len, int kv_seq_len);
KVMeta* mMeta;
};

View File

@ -509,21 +509,21 @@ CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode p
mDmaInfo.reset(new CPURuntime::DynamicAllocator);
mDmaInfo->mDynamicAllocator.reset(mRuntime->createDynamicBufferAlloctor(0));
mDmaInfo->mCurrentDynamicAllocator = mDmaInfo->mDynamicAllocator.get();
mDmaInfo->mCacheGroup.resize(MNN_CPU_MAX_BUFFER_INDEX);
for (int i=0; i<mDmaInfo->mCacheGroup.size(); ++i) {
mDmaInfo->mCacheGroup[i].reset(new CPUResizeCache);
}
} else {
mDmaInfo = dynamicAlloc;
}
mPrecisionMode = precision;
mCoreFunctions = MNNGetCoreFunctions();
mInt8CoreFunctions = MNNGetInt8CoreFunctions();
mCacheGroup.resize(MNN_CPU_MAX_BUFFER_INDEX);
for (int i=0; i<mCacheGroup.size(); ++i) {
mCacheGroup[i].reset(new CPUResizeCache);
}
mCache = mCacheGroup[0].get();
mCache = mDmaInfo->mCacheGroup[0].get();
}
CPUBackend::~CPUBackend() {
mCacheGroup.clear();
// Do nothing
}
void CPUBackend::_resetDynamicMemory() const {
mRuntime->pCurrentStatus = mDmaInfo->mDynamicAllocator->apply();
@ -560,7 +560,7 @@ bool CPUBackend::onSelectDynamicAllocator(int index, int maxIndex) {
mRuntime->buffer(0)->release();
mDmaInfo->mCurrentDynamicAllocator = mDmaInfo->mDynamicAllocator.get();
}
mCache = mCacheGroup[index].get();
mCache = mDmaInfo->mCacheGroup[index].get();
return true;
}

View File

@ -23,12 +23,14 @@
namespace MNN {
class WorkerThread;
class CPUResizeCache;
class CPURuntime : public Runtime {
public:
struct DynamicAllocator {
std::shared_ptr<BufferAllocator> mDynamicAllocator;
std::shared_ptr<BufferAllocator> mDynamicAllocatorBackup;
BufferAllocator* mCurrentDynamicAllocator = nullptr;
std::vector<std::shared_ptr<CPUResizeCache>> mCacheGroup;
};
friend class CPUBackend;
CPURuntime(const Backend::Info& info);
@ -82,7 +84,6 @@ struct CoreFunctions;
struct CoreInt8Functions;
struct MatmulRelatedFunctions;
class CPUResizeCache;
class CPUMemObj : public Backend::MemObj {
public:
CPUMemObj(BufferAllocator* allocator, MemChunk chunk, int size) : mAllocator(allocator), mChunk(chunk), mSize(size) {}
@ -199,7 +200,6 @@ private:
BackendConfig::MemoryMode mMemory;
static std::map<OpType, CPUBackend::Creator*>* gCreator;
CPUResizeCache* mCache;
std::vector<std::shared_ptr<CPUResizeCache>> mCacheGroup;
};
/** execution cast wrapper. insert tensor cast dynamic. */
class CastWrapExecution : public Execution {

View File

@ -170,7 +170,7 @@ cpu_mask_t MNNGetCPUMask(const std::vector<int>& cpuIds) {
// cpuinfo
// Reference from: https://github.com/pytorch/cpuinfo
#if defined(ENABLE_ARMV82) && defined(__arm__)
#if (defined(ENABLE_ARMV82) && defined(__arm__)) || (defined(__ANDROID__) && defined(__aarch64__))
/* As per include/sys/system_properties.h in Android NDK */
#define CPUINFO_HARDWARE_VALUE_MAX 64
@ -1180,7 +1180,7 @@ struct cpuinfo_arm_chipset cpuinfo_arm_android_decode_chipset(const struct cpuin
// MNN_PRINT("chipset vendor, series, model is: %d, %d, %d\n", chipset.vendor, chipset.series, chipset.model);
return chipset;
}
static void _getInfoARMv7(MNNCPUInfo* cpuinfo_isa) {
static void _getInfoArm(MNNCPUInfo* cpuinfo_isa) {
// Get White List And Black List
struct cpuinfo_arm_linux_processor* arm_linux_processors = NULL;
if (0 == cpuinfo_isa->groups.size()) {
@ -1500,8 +1500,8 @@ static void _fillInfo(MNNCPUInfo* cpuinfo_isa) {
#if defined(__aarch64__)
_getInfoAux(cpuinfo_isa);
#endif
#if defined(ENABLE_ARMV82) && defined(__arm__)
_getInfoARMv7(cpuinfo_isa);
#if (defined(ENABLE_ARMV82) && defined(__arm__)) || (defined(__ANDROID__) && defined(__aarch64__))
_getInfoArm(cpuinfo_isa);
#endif // #ifdef arm / arm64
#endif // #ifdef __linux__

View File

@ -200,7 +200,7 @@ int CPUSoftmax::_softmaxCommon(const uint8_t *srcData, uint8_t *dstData) {
for (int v=0; v<mInside; ++v) {
//TODO: Fix x86 compute error and use the same function
#ifdef MNN_USE_SSE
MNNSoftmax(workDst+v*mChannel, workSrc+v*mChannel, mChannel);
MNNSoftmax(workDst+v*mChannel, workSrc+v*mChannel, nullptr, nullptr, nullptr, 1, mChannel);
#else
___MNNSoftmax(workDst+v*mChannel, workSrc+v*mChannel, mChannel, mulFunction);
#endif

View File

@ -75,13 +75,13 @@ void KVCacheManager::resetKVCacheFileSize(size_t keySize, size_t valueSize) {
void KVCacheManager::mmapKVCache(size_t keySize, size_t valueSize)
{
if (mMapKeyAddr == nullptr) {
mMapKeyAddr = (char *)MNNMmapFile(mKeyCacheFD, keySize);
mMapKeyAddr = (int8_t *)MNNMmapFile(mKeyCacheFD, keySize);
if (mMapKeyAddr == nullptr) {
MNN_PRINT("Failed to memory-map the kvcache!\n");
}
}
if (mMapValueAddr == nullptr) {
mMapValueAddr = (char *)MNNMmapFile(mValueCacheFD, valueSize);
mMapValueAddr = (int8_t *)MNNMmapFile(mValueCacheFD, valueSize);
if (mMapValueAddr == nullptr) {
MNN_PRINT("Failed to memory-map the kvcache!\n");
}
@ -111,8 +111,8 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) {
mBackend->onAcquireBuffer(new_key, Backend::STATIC);
for (int h = 0; h < mKvNumHead; h++) {
memcpy(
new_key->host<char>() + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
new_key->host<int8_t>() + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
mPastKey->host<int8_t>() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8
);
}
@ -123,8 +123,8 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) {
mBackend->onAcquireBuffer(new_key, Backend::STATIC);
for (int h = 0; h < mKvNumHead; h++) {
memcpy(
new_key->host<char>() + h * new_key->stride(0),
mPastKey->host<char>() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP),
new_key->host<int8_t>() + h * new_key->stride(0),
mPastKey->host<int8_t>() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP),
ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP)
);
}
@ -135,12 +135,12 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) {
mBackend->onAcquireBuffer(new_key, Backend::STATIC);
for (int h = 0; h < mKvNumHead; h++) {
memcpy(
new_key->host<char>() + h * new_key->stride(0) * mBytes,
mPastKey->host<char>() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes,
new_key->host<int8_t>() + h * new_key->stride(0) * mBytes,
mPastKey->host<int8_t>() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes,
ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes
);
if ((new_key->stride(0) - mPastKey->stride(0)) > 0) {
memset(new_key->host<char>() + h * new_key->stride(0) * mBytes + mPastKey->stride(0) * mBytes, 0, (new_key->stride(0) - mPastKey->stride(0)) * mBytes);
memset(new_key->host<int8_t>() + h * new_key->stride(0) * mBytes + mPastKey->stride(0) * mBytes, 0, (new_key->stride(0) - mPastKey->stride(0)) * mBytes);
}
}
mPastKey.reset(new_key);
@ -152,8 +152,8 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) {
for (int h = 0; h < mKvNumHead; h++) {
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
memcpy(
new_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP,
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP,
new_value->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP,
mPastValue->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP,
ROUND_UP(oldMaxLength, lP) * hP
);
}
@ -166,12 +166,12 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) {
for (int h = 0; h < mKvNumHead; h++) {
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
memcpy(
new_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes,
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes,
new_value->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes,
mPastValue->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes,
ROUND_UP(oldMaxLength, lP) * hP * mBytes
);
if ((new_value->stride(1) - mPastValue->stride(1)) > 0) {
memset(new_value->host<char>() + (h * new_value->stride(0) + i * new_value->stride(1)) * mBytes + mPastValue->stride(1) * mBytes, 0, (new_value->stride(1) - mPastValue->stride(1)) * mBytes);
memset(new_value->host<int8_t>() + (h * new_value->stride(0) + i * new_value->stride(1)) * mBytes + mPastValue->stride(1) * mBytes, 0, (new_value->stride(1) - mPastValue->stride(1)) * mBytes);
}
}
}
@ -189,7 +189,7 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) {
for (int h = 0; h < mKvNumHead; h++) {
memcpy(
mMapKeyAddr + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
mPastKey->host<int8_t>() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8
);
}
@ -200,7 +200,7 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) {
for (int h = 0; h < mKvNumHead; h++) {
memcpy(
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP,
mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP,
mPastKey->host<int8_t>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP,
UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP
);
}
@ -214,7 +214,7 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) {
for (int h = 0; h < mKvNumHead; h++) {
memcpy(
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes,
mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes,
mPastKey->host<int8_t>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes,
UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes
);
}
@ -227,7 +227,7 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) {
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
memcpy(
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP,
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP,
mPastValue->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP,
ROUND_UP(oldMaxLength, lP) * hP
);
}
@ -243,7 +243,7 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) {
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
memcpy(
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes,
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes,
mPastValue->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes,
ROUND_UP(oldMaxLength, lP) * hP * mBytes
);
}
@ -282,8 +282,8 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o
memset(old_value->host<uint8_t>(), 0, old_value->length(0) * old_value->stride(0) * mBytes);
}
mmapKVCache(oldKeySize, oldValueSize);
memcpy(old_key->host<char>(), mMapKeyAddr, oldKeySize);
memcpy(old_value->host<char>(), mMapValueAddr, oldValueSize);
memcpy(old_key->host<int8_t>(), mMapKeyAddr, oldKeySize);
memcpy(old_value->host<int8_t>(), mMapValueAddr, oldValueSize);
// Step 2: Resize the kvcache files and remap them
unmapKVCache(oldKeySize, oldValueSize);
resetKVCacheFileSize(keySize, valueSize);
@ -293,7 +293,7 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o
for (int h = 0; h < mKvNumHead; h++) {
memcpy(
mMapKeyAddr + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
old_key->host<char>() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
old_key->host<int8_t>() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8
);
}
@ -301,7 +301,7 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o
for (int h = 0; h < mKvNumHead; h++) {
memcpy(
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP,
old_key->host<char>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP,
old_key->host<int8_t>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP,
UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP
);
}
@ -309,7 +309,7 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o
for (int h = 0; h < mKvNumHead; h++) {
memcpy(
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes,
old_key->host<char>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes,
old_key->host<int8_t>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes,
UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes
);
}
@ -319,7 +319,7 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
memcpy(
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP,
old_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP,
old_value->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP,
ROUND_UP(oldMaxLength, lP) * hP
);
}
@ -329,7 +329,7 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
memcpy(
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes,
old_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes,
old_value->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes,
ROUND_UP(oldMaxLength, lP) * hP * mBytes
);
}
@ -460,9 +460,9 @@ void KVCacheManager::onRealloc(const KVMeta* meta) {
mBackend->onAcquireBuffer(new_zeroPoint, Backend::STATIC);
mBackend->onAcquireBuffer(new_sum, Backend::STATIC);
for (int h = 0; h < mKvNumHead; h++) {
memcpy(new_scale->host<char>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyScale->host<char>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4);
memcpy(new_zeroPoint->host<char>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyZeroPoint->host<char>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4);
memcpy(new_sum->host<char>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeySum->host<char>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4);
memcpy(new_scale->host<int8_t>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyScale->host<int8_t>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4);
memcpy(new_zeroPoint->host<int8_t>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyZeroPoint->host<int8_t>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4);
memcpy(new_sum->host<int8_t>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeySum->host<int8_t>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4);
}
mKeyScale.reset(new_scale);
mKeyZeroPoint.reset(new_zeroPoint);
@ -473,8 +473,8 @@ void KVCacheManager::onRealloc(const KVMeta* meta) {
mBackend->onAcquireBuffer(new_scale, Backend::STATIC);
mBackend->onAcquireBuffer(new_zeroPoint, Backend::STATIC);
for (int h = 0; h < mKvNumHead; h++) {
memcpy(new_scale->host<char>() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyScale->host<char>() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes);
memcpy(new_zeroPoint->host<char>() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyZeroPoint->host<char>() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes);
memcpy(new_scale->host<int8_t>() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyScale->host<int8_t>() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes);
memcpy(new_zeroPoint->host<int8_t>() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyZeroPoint->host<int8_t>() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes);
}
mKeyScale.reset(new_scale);
mKeyZeroPoint.reset(new_zeroPoint);
@ -690,8 +690,8 @@ void KVCacheManager::onDequantValue(Tensor * dequantedValues) {
int tileCount = UP_DIV(mKvNumHead, mThreadNum);
std::function<void(int)> dequant = [=](int tid) {
for (int kv_h = tid * tileCount; kv_h < (tid+1) * tileCount && kv_h < mKvNumHead; kv_h++) {
char * dst = dequantedValues->host<char>() + kv_h * UP_DIV(mHeadDim, hP) * mPastLength * hP * mBytes;
char * src = addrOfValue(kv_h);
int8_t * dst = dequantedValues->host<int8_t>() + kv_h * UP_DIV(mHeadDim, hP) * mPastLength * hP * mBytes;
int8_t * src = addrOfValue(kv_h);
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
if (mBytes == 2) {
core->MNNFp8ToFp16((uint16_t*)dst, (uint8_t*)src, mPastLength * hP);

View File

@ -46,8 +46,8 @@ private:
std::shared_ptr<Tensor> mKeySum; // numhead, [maxlen/hP8, hP8]
file_t mKeyCacheFD = INVALID_FILE; // The file descriptor of keys
file_t mValueCacheFD = INVALID_FILE; // The file descriptor of values
char * mMapKeyAddr = nullptr; // Memory-mapped address of keys
char * mMapValueAddr = nullptr; // Memory-mapped address of values
int8_t * mMapKeyAddr = nullptr; // Memory-mapped address of keys
int8_t * mMapValueAddr = nullptr; // Memory-mapped address of values
bool mKVCacheInDisk = false; // Whether the kvcache is in disk or in memory now
int mPastLength = 0; // Length of past kvcache
int mMaxLength = 0; // Capacity of current kvcache buffer (how many kv items can be stored at most)
@ -104,15 +104,15 @@ public:
return mMaxLength;
}
uint8_t* keyAddr() {
char * baseAddr = mKVCacheInDisk ? mMapKeyAddr : mPastKey->host<char>();
int8_t * baseAddr = mKVCacheInDisk ? mMapKeyAddr : mPastKey->host<int8_t>();
return (uint8_t*)baseAddr;
}
uint8_t* valudAddr() {
char * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host<char>();
int8_t * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host<int8_t>();
return (uint8_t*)baseAddr;
}
char * addrOfKey(int kv_h) {
char * baseAddr = mKVCacheInDisk ? mMapKeyAddr : mPastKey->host<char>();
int8_t * addrOfKey(int kv_h) {
int8_t * baseAddr = mKVCacheInDisk ? mMapKeyAddr : mPastKey->host<int8_t>();
if (mConfig.mUseInt8Kernel) {
return baseAddr + kv_h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8;
} else if (mConfig.mQuantKey) {
@ -121,35 +121,35 @@ public:
return baseAddr + kv_h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes;
}
}
char * addrOfValue(int kv_h) {
char * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host<char>();
int8_t * addrOfValue(int kv_h) {
int8_t * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host<int8_t>();
if (mConfig.mQuantValue) {
return baseAddr + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * hP;
} else {
return baseAddr + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * hP * mBytes;
}
}
char * addrOfScale(int kv_h) {
int8_t * addrOfScale(int kv_h) {
if (mConfig.mUseInt8Kernel) {
return mKeyScale->host<char>() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4;
return mKeyScale->host<int8_t>() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4;
} else if (mConfig.mQuantKey) {
return mKeyScale->host<char>() + kv_h * UP_DIV(mMaxLength, hP) * hP * mBytes;
return mKeyScale->host<int8_t>() + kv_h * UP_DIV(mMaxLength, hP) * hP * mBytes;
} else {
return nullptr;
}
}
char * addrOfZeroPoint(int kv_h) {
int8_t * addrOfZeroPoint(int kv_h) {
if (mConfig.mUseInt8Kernel) {
return mKeyZeroPoint->host<char>() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4;
return mKeyZeroPoint->host<int8_t>() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4;
} else if (mConfig.mQuantKey) {
return mKeyZeroPoint->host<char>() + kv_h * UP_DIV(mMaxLength, hP) * hP * mBytes;
return mKeyZeroPoint->host<int8_t>() + kv_h * UP_DIV(mMaxLength, hP) * hP * mBytes;
} else {
return nullptr;
}
}
char * addrOfKeySum(int kv_h) {
int8_t * addrOfKeySum(int kv_h) {
if (mConfig.mUseInt8Kernel) {
return mKeySum->host<char>() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4;
return mKeySum->host<int8_t>() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4;
}else {
return nullptr;
}

View File

@ -73,6 +73,386 @@ void MNNTranspose16Bit(int16_t* dstO, const int16_t* srcO, int32_t* dim) {
}
}
#define EXP_APPROX_MIN_INPUT vdupq_n_f32(-88.0f)
#define EXP_APPROX_MAX_INPUT vdupq_n_f32(88.0f)
#define EXP_APPROX_LN2 vdupq_n_f32(0.69314718056f) // ln(2)
#define EXP_APPROX_LN2_INV vdupq_n_f32(1.44269504089f) // 1/ln(2)
// Fourth-order polynomial approximation coefficients of exp(r):
// P(x) = c4*x^4 + c3*x^3 + c2*x^2 + c1*x + c0
#define EXP_APPROX_C4 vdupq_n_f32(0.0416624f)
#define EXP_APPROX_C3 vdupq_n_f32(0.166665f)
#define EXP_APPROX_C2 vdupq_n_f32(0.500000f)
#define EXP_APPROX_C1 vdupq_n_f32(1.0f)
#define EXP_APPROX_C0 vdupq_n_f32(1.0f)
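// exp(x) is evaluated as 2^k * P(r) with x = k*ln(2) + r; the 2^k factor is applied by adding k to the float exponent bits (see expApprox).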
#ifndef __aarch64__
static inline float32x4_t vrndaq_f32_compat(float32x4_t val) {
const float32x4_t v_zero = vdupq_n_f32(0.0f);
float32x4_t v_truncated = vcvtq_f32_s32(vcvtq_s32_f32(val));
uint32x4_t v_is_positive_frac = vcgtq_f32(val, v_truncated);
uint32x4_t v_is_negative_frac = vcltq_f32(val, v_truncated);
float32x4_t v_offset = vbslq_f32(v_is_positive_frac, vdupq_n_f32(1.0f), v_zero);
v_offset = vbslq_f32(v_is_negative_frac, vdupq_n_f32(-1.0f), v_offset);
return vaddq_f32(v_truncated, v_offset);
}
#endif
static inline float32x4_t expApprox(float32x4_t x) {
x = vminq_f32(vmaxq_f32(x, EXP_APPROX_MIN_INPUT), EXP_APPROX_MAX_INPUT);
float32x4_t k_float;
float32x4_t r;
float32x4_t exp_r;
#if defined(__aarch64__)
// 1. x = k * ln(2) + r
k_float = vrndaq_f32(vmulq_f32(x, EXP_APPROX_LN2_INV));
// r = x - k * ln(2)
r = vfmsq_f32(x, k_float, EXP_APPROX_LN2);
// 2. c0 + r*(c1 + r*(c2 + r*(c3 + r*c4))) (Horner's method)
exp_r = vfmaq_f32(EXP_APPROX_C3, EXP_APPROX_C4, r); // c3 + c4*r
exp_r = vfmaq_f32(EXP_APPROX_C2, exp_r, r); // c2 + r*(...)
exp_r = vfmaq_f32(EXP_APPROX_C1, exp_r, r); // c1 + r*(...)
exp_r = vfmaq_f32(EXP_APPROX_C0, exp_r, r); // c0 + r*(...)
#else
k_float = vrndaq_f32_compat(vmulq_f32(x, EXP_APPROX_LN2_INV));
r = vsubq_f32(x, vmulq_f32(k_float, EXP_APPROX_LN2));
// 2. c0 + r*(c1 + r*(c2 + r*(c3 + r*c4)))
exp_r = vmlaq_f32(EXP_APPROX_C3, EXP_APPROX_C4, r); // c3 + c4*r
exp_r = vmlaq_f32(EXP_APPROX_C2, exp_r, r); // c2 + r*(...)
exp_r = vmlaq_f32(EXP_APPROX_C1, exp_r, r); // c1 + r*(...)
exp_r = vmlaq_f32(EXP_APPROX_C0, exp_r, r); // c0 + r*(...)
#endif
int32x4_t k_int = vcvtq_s32_f32(k_float);
int32x4_t k_shifted = vshlq_n_s32(k_int, 23);
float32x4_t result = vreinterpretq_f32_s32(vaddq_s32(vreinterpretq_s32_f32(exp_r), k_shifted));
return result;
}
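// MNNExpC8: dst[i] = exp(offset[0] * src[i] + offset[2]) + offset[1] for countC8 groups of 8 floats; the sum of the outputs is accumulated into offset[3].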
void MNNExpC8(float* dst, const float* src, float* offset, const float* parameters, size_t countC8) {
float32x4_t maxVec = vdupq_n_f32(offset[2]);
float32x4_t sumVec0 = vdupq_n_f32(0);
float32x4_t sumVec1 = vdupq_n_f32(0);
float32x4_t c0 = vdupq_n_f32(offset[0]);
float32x4_t c1 = vdupq_n_f32(offset[1]);
for (int i = 0; i < countC8; ++i) {
float32x4_t srcVec0 = vld1q_f32(src);
float32x4_t srcVec1 = vld1q_f32(src + 4);
auto subVec0 = vaddq_f32(vmulq_f32(srcVec0, c0), maxVec);
auto subVec1 = vaddq_f32(vmulq_f32(srcVec1, c0), maxVec);
auto expVec0 = vaddq_f32(expApprox(subVec0), c1);
auto expVec1 = vaddq_f32(expApprox(subVec1), c1);
vst1q_f32(dst, expVec0);
vst1q_f32(dst + 4, expVec1);
sumVec0 = vaddq_f32(sumVec0, expVec0);
sumVec1 = vaddq_f32(sumVec1, expVec1);
src += 8;
dst += 8;
}
sumVec0 = vaddq_f32(sumVec0, sumVec1);
float32x2_t sumP = vpadd_f32(vget_low_f32(sumVec0), vget_high_f32(sumVec0));
sumP = vpadd_f32(sumP, sumP);
offset[3] += vget_lane_f32(sumP, 0);
}
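// MNNExp: same contract as MNNExpC8 (dst[i] = exp(offset[0] * src[i] + offset[2]) + offset[1], sum accumulated into offset[3]) but for an arbitrary element count.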
void MNNExp(float* destPtr, const float* srcPtr, float* offset, size_t size) {
float32x4_t maxVec = vdupq_n_f32(-offset[2]);
float32x4_t sumVec0 = vdupq_n_f32(0);
float32x4_t sumVec1 = vdupq_n_f32(0);
if (offset[0] == 1.f && offset[1] == 0.f) {
while (size >= 8) {
float32x4_t srcVec0 = vld1q_f32(srcPtr);
float32x4_t srcVec1 = vld1q_f32(srcPtr + 4);
auto subVec0 = vsubq_f32(srcVec0, maxVec);
auto subVec1 = vsubq_f32(srcVec1, maxVec);
auto expVec0 = expApprox(subVec0);
auto expVec1 = expApprox(subVec1);
vst1q_f32(destPtr, expVec0);
vst1q_f32(destPtr + 4, expVec1);
sumVec0 = vaddq_f32(sumVec0, expVec0);
sumVec1 = vaddq_f32(sumVec1, expVec1);
srcPtr += 8;
destPtr += 8;
size -= 8;
}
while (size >= 4) {
float32x4_t srcVec0 = vld1q_f32(srcPtr);
auto subVec0 = vsubq_f32(srcVec0, maxVec);
auto expVec0 = expApprox(subVec0);
sumVec0 = vaddq_f32(sumVec0, expVec0);
vst1q_f32(destPtr, expVec0);
srcPtr += 4;
destPtr += 4;
size -= 4;
}
//merge
sumVec0 = vaddq_f32(sumVec0, sumVec1);
float32x2_t sumP = vpadd_f32(vget_low_f32(sumVec0), vget_high_f32(sumVec0));
sumP = vpadd_f32(sumP, sumP);
auto newSum = vget_lane_f32(sumP, 0);
if (size > 0) {
float tmp[4];
memcpy(tmp, srcPtr, size * sizeof(float));
float32x4_t srcVec0 = vld1q_f32(tmp);
auto subVec0 = vsubq_f32(srcVec0, maxVec);
auto expVec0 = expApprox(subVec0);
vst1q_f32(tmp, expVec0);
for (int i = 0; i < size; ++i) {
newSum += tmp[i];
destPtr[i] = tmp[i];
}
}
offset[3] += newSum;
} else {
float32x4_t c0 = vdupq_n_f32(offset[0]);
float32x4_t c1 = vdupq_n_f32(offset[1]);
while (size >= 8) {
float32x4_t srcVec0 = vld1q_f32(srcPtr);
float32x4_t srcVec1 = vld1q_f32(srcPtr + 4);
auto subVec0 = vsubq_f32(vmulq_f32(srcVec0, c0), maxVec);
auto subVec1 = vsubq_f32(vmulq_f32(srcVec1, c0), maxVec);
auto expVec0 = vaddq_f32(expApprox(subVec0), c1);
auto expVec1 = vaddq_f32(expApprox(subVec1), c1);
vst1q_f32(destPtr, expVec0);
vst1q_f32(destPtr + 4, expVec1);
sumVec0 = vaddq_f32(sumVec0, expVec0);
sumVec1 = vaddq_f32(sumVec1, expVec1);
srcPtr += 8;
destPtr += 8;
size -= 8;
}
while (size >= 4) {
float32x4_t srcVec0 = vld1q_f32(srcPtr);
auto subVec0 = vsubq_f32(vmulq_f32(srcVec0, c0), maxVec);
auto expVec0 = vaddq_f32(expApprox(subVec0), c1);
sumVec0 = vaddq_f32(sumVec0, expVec0);
vst1q_f32(destPtr, expVec0);
srcPtr += 4;
destPtr += 4;
size -= 4;
}
//merge
sumVec0 = vaddq_f32(sumVec0, sumVec1);
float32x2_t sumP = vpadd_f32(vget_low_f32(sumVec0), vget_high_f32(sumVec0));
sumP = vpadd_f32(sumP, sumP);
auto newSum = vget_lane_f32(sumP, 0);
if (size > 0) {
float tmp[4];
memcpy(tmp, srcPtr, size * sizeof(float));
float32x4_t srcVec0 = vld1q_f32(tmp);
auto subVec0 = vsubq_f32(vmulq_f32(srcVec0, c0), maxVec);
auto expVec0 = vaddq_f32(expApprox(subVec0), c1);
vst1q_f32(tmp, expVec0);
for (int i = 0; i < size; ++i) {
newSum += tmp[i];
destPtr[i] = tmp[i];
}
}
offset[3] += newSum;
}
}
static inline void transposeAndStore4x4(const float* srcRowPtrs[4], float* dstColBase, size_t dstColStride) {
float32x4_t row0 = vld1q_f32(srcRowPtrs[0]);
float32x4_t row1 = vld1q_f32(srcRowPtrs[1]);
float32x4_t row2 = vld1q_f32(srcRowPtrs[2]);
float32x4_t row3 = vld1q_f32(srcRowPtrs[3]);
// Step 1: Transpose 2x2 blocks of 2-element vectors
float32x4x2_t t01 = vtrnq_f32(row0, row1);
float32x4x2_t t23 = vtrnq_f32(row2, row3);
// Step 2: Combine the results to get the full transpose
float32x4_t col0 = vcombine_f32(vget_low_f32(t01.val[0]), vget_low_f32(t23.val[0]));
float32x4_t col1 = vcombine_f32(vget_low_f32(t01.val[1]), vget_low_f32(t23.val[1]));
float32x4_t col2 = vcombine_f32(vget_high_f32(t01.val[0]), vget_high_f32(t23.val[0]));
float32x4_t col3 = vcombine_f32(vget_high_f32(t01.val[1]), vget_high_f32(t23.val[1]));
vst1q_f32(dstColBase, col0);
vst1q_f32(dstColBase + dstColStride, col1);
vst1q_f32(dstColBase + 2 * dstColStride, col2);
vst1q_f32(dstColBase + 3 * dstColStride, col3);
}
void packKvCache(float* dst, const float* src, size_t seqLen, size_t kvSeqLen, size_t eP) {
if (seqLen == 0 || kvSeqLen == 0) {
return;
}
// source [seqLen, kvSeqLen]
// dest [seqLen/eP, kvSeqLen, eP]
const int kTileS = 4; // Tiling size for seqLen dimension
const int kTileK = 4; // Tiling size for kvSeqLen dimension
const size_t dstSOuterStride = kvSeqLen * eP;
int s = 0;
for (; s + kTileS <= seqLen; s += kTileS) {
const int sOuter = s / eP;
const int sInner = s % eP;
if (sInner + kTileS > eP) {
break;
}
float* dstSBase = dst + sOuter * dstSOuterStride + sInner;
const float* srcRowPtrs[kTileS];
srcRowPtrs[0] = src + (s + 0) * kvSeqLen;
srcRowPtrs[1] = src + (s + 1) * kvSeqLen;
srcRowPtrs[2] = src + (s + 2) * kvSeqLen;
srcRowPtrs[3] = src + (s + 3) * kvSeqLen;
int k = 0;
for (; k + kTileK <= kvSeqLen; k += kTileK) {
const float* currentSrcPtrs[kTileS];
currentSrcPtrs[0] = srcRowPtrs[0] + k;
currentSrcPtrs[1] = srcRowPtrs[1] + k;
currentSrcPtrs[2] = srcRowPtrs[2] + k;
currentSrcPtrs[3] = srcRowPtrs[3] + k;
float* dstKBase = dstSBase + k * eP;
transposeAndStore4x4(currentSrcPtrs, dstKBase, eP);
}
for (; k < kvSeqLen; ++k) {
float buffer[kTileS] = {
srcRowPtrs[0][k],
srcRowPtrs[1][k],
srcRowPtrs[2][k],
srcRowPtrs[3][k]
};
vst1q_f32(dstSBase + k * eP, vld1q_f32(buffer));
}
}
for (; s < seqLen; ++s) {
const int sOuter = s / eP;
const int sInner = s % eP;
const float* srcRow = src + s * kvSeqLen;
float* dstSBase = dst + sOuter * dstSOuterStride + sInner;
for (int k = 0; k < kvSeqLen; ++k) {
dstSBase[k * eP] = srcRow[k];
}
}
}
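// MNNSoftmax: when runningMax/runningSum/updateScale are provided, it performs a streaming (flash-attention style) update: writes exp(x - newMax) to softmaxDst, rescales runningSum by exp(oldMax - newMax) and records that factor in updateScale; otherwise it computes a fully normalized softmax for each of the `outside` rows.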
void MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) {
for (int k = 0; k < outside; ++k) {
auto source = input + k * reduceSize;
auto dest = softmaxDst + k * reduceSize;
// new max
auto srcPtr = source;
auto size = reduceSize;
float32x4_t maxVec0 = vdupq_n_f32(source[0]);
auto maxVec1 = maxVec0;
float oldMax = source[0];
if (runningMax) {
oldMax = runningMax[k];
}
while (size >= 8) {
float32x4_t srcVec0 = vld1q_f32(srcPtr);
float32x4_t srcVec1 = vld1q_f32(srcPtr + 4);
maxVec0 = vmaxq_f32(maxVec0, srcVec0);
maxVec1 = vmaxq_f32(maxVec1, srcVec1);
srcPtr += 8;
size -= 8;
}
while (size >= 4) {
float32x4_t srcVec0 = vld1q_f32(srcPtr);
maxVec0 = vmaxq_f32(maxVec0, srcVec0);
srcPtr += 4;
size -= 4;
}
maxVec0 = vmaxq_f32(maxVec0, maxVec1);
float32x2_t maxP = vpmax_f32(vget_low_f32(maxVec0), vget_high_f32(maxVec0));
maxP = vpmax_f32(maxP, maxP);
auto newMax = vget_lane_f32(maxP, 0);
while (size > 0) {
newMax = ALIMAX(newMax, srcPtr[0]);
srcPtr += 1;
size -= 1;
}
newMax = ALIMAX(oldMax, newMax);
srcPtr = source;
auto destPtr = dest;
size = reduceSize;
float exprOffset[4] = {
1.0f,
0.0f,
0.0f,
0.0f
};
exprOffset[2] = -newMax;
// expf(xi-newmax) & new sum
MNNExp(destPtr, srcPtr, exprOffset, size);
if (runningMax != nullptr && runningSum != nullptr && updateScale != nullptr) {
// update runningSum, runningMax, scale=expf(oldMax-newMax)
float newSum = exprOffset[3];
runningSum[k] = runningSum[k] * expf(oldMax - newMax) + newSum;
runningMax[k] = newMax;
updateScale[k] = expf(oldMax - newMax);
} else {
// Normalize
float sum = exprOffset[3];
float scale = 1.0f / (sum + 1e-20f);
int count = reduceSize;
auto pDest = dest;
float32x4_t scaleVec = vdupq_n_f32(scale);
while (count >= 4) {
float32x4_t data = vld1q_f32(pDest);
data = vmulq_f32(data, scaleVec);
vst1q_f32(pDest, data);
pDest += 4;
count -= 4;
}
while (count > 0) {
*pDest *= scale;
pDest++;
count--;
}
}
}
}
#ifndef MNN_USE_NEON
void MNNPackedSparseMatMulEpx1(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, unsigned int* NNZMap, int* dataOffsetMap) {
@ -1,115 +0,0 @@
//
// MNNExpC8.S
// MNN
//
// Created by MNN on 2019/01/18.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef __arm__
#ifndef __aarch64__
#include "MNNAsmGlobal.h"
.text
.align 5
//void MNNExpC8(float* dest, const float* source, const float* offset, const float* parameters, size_t countC8)
asm_function MNNExpC8
//r0: dest, r1:source, r2: offset, r3:parameters, r4:countC8
push {r4, r5, lr}
ldr r4, [sp, #12]
vpush {q4-q7}
vmov.i32 q7, #0
ldr r5, [r2, #0]
vdup.32 q4, r5 // Alpha
ldr r5, [r2, #4]
vdup.32 q5, r5 // Beta
ldr r5, [r2, #8]
vdup.32 q6, r5 // Bias
vld1.32 {q0, q1}, [r3]
vmov.i32 q2, #87
vcvt.f32.s32 q2, q2
vneg.f32 q3, q2
Loop:
vld1.32 {q8, q9}, [r1]!
vmul.f32 q8, q8, q4
vmul.f32 q9, q9, q4
vadd.f32 q8, q8, q6
vadd.f32 q9, q9, q6
vmin.f32 q8, q8, q2
vmin.f32 q9, q9, q2
vmax.f32 q10, q8, q3
vmax.f32 q11, q9, q3
vmul.f32 q8, q10, d0[1]
vmul.f32 q9, q11, d0[1]
vcvt.s32.f32 q8, q8
vcvt.s32.f32 q9, q9
vcvt.f32.s32 q12, q8
vcvt.f32.s32 q13, q9
//q10, q11: t
vmls.f32 q10, q12, d0[0]
vmls.f32 q11, q13, d0[0]
vmul.f32 q10, q10, d1[0]
vmul.f32 q11, q11, d1[0]
.macro MLA_TWO z0 z1 z2 z3
vdup.32 \z1, \z0
vmla.f32 \z1, \z2, \z3
.endm
MLA_TWO d3[0], q12, q10, d3[1]
MLA_TWO d3[0], q13, q11, d3[1]
MLA_TWO d2[1], q14, q10, q12
MLA_TWO d2[1], q15, q11, q13
MLA_TWO d2[0], q12, q10, q14
MLA_TWO d2[0], q13, q11, q15
MLA_TWO d1[1], q14, q10, q12
MLA_TWO d1[1], q15, q11, q13
MLA_TWO d1[1], q12, q10, q14
MLA_TWO d1[1], q13, q11, q15
//q12, q13 is expRemain
vmul.f32 q12, q12, q12
vmul.f32 q13, q13, q13
vmul.f32 q12, q12, q12
vmul.f32 q13, q13, q13
vshl.i32 q8, q8, #23
vshl.i32 q9, q9, #23
vadd.i32 q12, q12, q8
vadd.i32 q13, q13, q9
vadd.f32 q12, q12, q5
vadd.f32 q13, q13, q5
vadd.f32 q7, q12, q7
vst1.32 {q12, q13}, [r0]!
vadd.f32 q7, q13, q7
subs r4, r4, #1
bne Loop
add r5, r2, #12
vld1.32 {d0[0]}, [r5]
vadd.f32 d14, d14, d15
vtrn.32 d14, d15
vadd.f32 d14, d14, d15
vadd.f32 d0, d14, d0
vst1.32 {d0[0]}, [r5]
vpop {q4-q7}
pop {r4, r5, pc}
#endif
#endif
@ -1,270 +0,0 @@
//
// MNNSoftmax.S
// MNN
//
// Created by MNN on 2021/07/05.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef __arm__
#ifndef __aarch64__
#include "MNNAsmGlobal.h"
.text
.align 5
//void MNNSoftmax(float* dest, const float* source, size_t size)
asm_function MNNSoftmax
push {r4, r5, r6, r7, r8, lr}
bic r3, r2, #3
lsrs r4, r2, #2
vpush.64 {d8, d9, d10, d11, d12, d13, d14, d15}
sub sp, sp, #8
beq Loop_2
vld1.32 {d16-d17}, [r1]
cmp r4, #1
beq Loop_3
add lr, r1, #16
mov ip, #1
Loop_4:
vld1.32 {d18-d19}, [lr]!
add ip, ip, #1
cmp r4, ip
vmax.f32 q8, q8, q9
bne Loop_4
Loop_3:
vmov.32 ip, d16[0]
vmov s15, ip
vmov.32 ip, d16[1]
vmov s12, ip
vmov.32 ip, d17[0]
vcmpe.f32 s15, s12
vmov s13, ip
vmov.32 ip, d17[1]
vmrs APSR_nzcv, FPSCR
vmov s14, ip
vmovle.f32 s15, s12
vcmpe.f32 s13, s15
vmrs APSR_nzcv, FPSCR
vmovpl.f32 s15, s13
vcmpe.f32 s14, s15
vmrs APSR_nzcv, FPSCR
vmovpl.f32 s15, s14
cmp r2, r3
bls Loop_25
Loop_24:
add lr, r1, r3, lsl #2
mov ip, r3
Loop_11:
vldmia.32 lr!, {s14}
add ip, ip, #1
vcmpe.f32 s14, s15
vmrs APSR_nzcv, FPSCR
vmovpl.f32 s15, s14
cmp r2, ip
bhi Loop_11
cmp r4, #0
beq Loop_8
Loop_25:
vmov.f32 q11, #0.0 @ v4sf
mov r5, r1
vldr d14, Loop_54
vldr d15, Loop_54+8
mov lr, r0
vldr d12, Loop_54+16
vldr d13, Loop_54+24
mov ip, #0
vmov.f32 q12, #1.0e+0 @ v4sf
vstr.32 s15, [sp, #4]
vmov.f32 q5, #5.0e-1 @ v4sf
vldr d8, Loop_54+32
vldr d9, Loop_54+40
vldr d0, Loop_54+48
vldr d1, Loop_54+56
vldr d2, Loop_54+64
vldr d3, Loop_54+72
vldr d4, Loop_54+80
vldr d5, Loop_54+88
vldr d30, Loop_54+96
vldr d31, Loop_54+104
vmov.i32 q14, #8388608 @ v4si
vdup.32 q13, d7[1]
Loop_12:
vld1.32 {d20-d21}, [r5]!
add ip, ip, #1
vmov.i32 q3, #127 @ v4si
cmp r4, ip
vsub.f32 q10, q10, q13
vmax.f32 q10, q10, q15
vmin.f32 q10, q10, q2
vmul.f32 q9, q10, q6
vcvt.s32.f32 q9, q9
vcvt.f32.s32 q8, q9
vadd.i32 q9, q9, q3
vmul.f32 q8, q8, q7
vmul.i32 q9, q9, q14
vsub.f32 q10, q10, q8
vmul.f32 q8, q1, q10
vadd.f32 q8, q8, q0
vmul.f32 q8, q8, q10
vadd.f32 q8, q8, q4
vmul.f32 q8, q8, q10
vadd.f32 q8, q8, q5
vmul.f32 q8, q8, q10
vadd.f32 q8, q8, q12
vmul.f32 q8, q8, q10
vadd.f32 q8, q8, q12
vmul.f32 q9, q9, q8
vadd.f32 q11, q9, q11
vst1.32 {d18-d19}, [lr]!
bgt Loop_12
vmov.32 ip, d22[0]
vldr.32 s15, [sp, #4]
cmp r2, r3
vmov s12, ip
vmov.32 ip, d22[1]
vmov s11, ip
vmov.32 ip, d23[0]
vadd.f32 s12, s12, s11
vmov s13, ip
vmov.32 ip, d23[1]
vadd.f32 s12, s12, s13
vmov s14, ip
vadd.f32 s12, s12, s14
bls Loop_52
Loop_26:
lsl ip, r3, #2
add r5, r1, r2, lsl #2
add lr, r1, ip
movw r7, #13877
movt r7, 179
movw r6, #55317
movt r6, 32310
vldr.32 s11, Loop_54+112
vldr.32 s10, Loop_54+116
add r1, r0, ip
vldr.32 s9, Loop_54+120
vmov.f32 s8, #5.0e-1
vldr.32 s5, Loop_54+124
vldr.32 s6, Loop_54+128
vldr.32 s7, Loop_54+132
Loop_17:
vldmia.32 lr!, {s14}
vmov s13, r6
vsub.f32 s14, s14, s15
vcmpe.f32 s14, s11
vmrs APSR_nzcv, FPSCR
vmovle s13, r7
ble Loop_14
vcmpe.f32 s14, s10
vmrs APSR_nzcv, FPSCR
bmi Loop_53
Loop_14:
vadd.f32 s12, s12, s13
cmp r5, lr
vstmia.32 r1!, {s13}
bne Loop_17
vmov.f32 s15, #1.0e+0
cmp r4, #0
vdiv.f32 s14, s15, s12
beq Loop_20
Loop_23:
vdup.32 q9, d7[0]
mov ip, r0
mov r1, #0
Loop_19:
vld1.32 {d16-d17}, [ip]
add r1, r1, #1
cmp r4, r1
vmul.f32 q8, q8, q9
vst1.32 {d16-d17}, [ip]!
bgt Loop_19
cmp r2, r3
bls Loop_1
lsl ip, r3, #2
Loop_20:
add r0, r0, ip
Loop_21:
vldr.32 s15, [r0]
add r3, r3, #1
cmp r2, r3
vmul.f32 s15, s15, s14
vstmia.32 r0!, {s15}
bhi Loop_21
Loop_1:
add sp, sp, #8
vldm sp!, {d8-d15}
pop {r4, r5, r6, r7, r8, pc}
Loop_2:
vldr.32 s15, Loop_54+136
cmp r2, r3
bhi Loop_24
Loop_8:
cmp r2, r3
vldrhi.32 s12, Loop_54+136
bhi Loop_26
b Loop_1
Loop_52:
vmov.f32 s15, #1.0e+0
cmp r4, #0
vdiv.f32 s14, s15, s12
bne Loop_23
b Loop_1
Loop_53:
vdiv.f32 s4, s14, s9
vmov.f32 s13, #1.0e+0
vcvt.s32.f32 s4, s4
vcvt.f32.s32 s3, s4
vmov r8, s4 @ int
vmov.f32 s4, s7
vmls.f32 s14, s3, s9
vmov.f32 s3, s6
add r8, r8, #127
lsl r8, r8, #23
vmla.f32 s3, s14, s5
vmla.f32 s4, s3, s14
vmov.f32 s3, s8
vmla.f32 s3, s4, s14
vmov.f32 s4, s13
vmla.f32 s4, s3, s14
vmla.f32 s13, s4, s14
vmov s14, r8
vmul.f32 s13, s13, s14
b Loop_14
Loop_54:
.word 1060205080
.word 1060205080
.word 1060205080
.word 1060205080
.word 1069066811
.word 1069066811
.word 1069066811
.word 1069066811
.word 1042983595
.word 1042983595
.word 1042983595
.word 1042983595
.word 1026206379
.word 1026206379
.word 1026206379
.word 1026206379
.word 1007192201
.word 1007192201
.word 1007192201
.word 1007192201
.word 1118699520
.word 1118699520
.word 1118699520
.word 1118699520
.word -1028784128
.word -1028784128
.word -1028784128
.word -1028784128
.word -1028784128
.word 1118699520
.word 1060205080
.word 1007192201
.word 1026206379
.word 1042983595
.word 0
#endif
#endif
@ -1,109 +0,0 @@
//
// MNNExpC8.S
// MNN
//
// Created by MNN on 2019/01/18.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef __aarch64__
#include "MNNAsmGlobal.h"
.text
.align 5
//void MNNExpC8(float* dest, const float* source, const float* offset, const float* parameters, size_t countC8)
asm_function MNNExpC8
//x0: dest, x1:source, x2: offset, x3:parameters, x4:countC8
ldr w5, [x2, #0]
ldr w6, [x2, #4]
ldr w7, [x2, #8]
ld1 {v0.4s, v1.4s}, [x3]
movi v2.4s, #23
movi v3.4s, #87
scvtf v3.4s, v3.4s
fneg v4.4s, v3.4s
dup v30.4s, w5
dup v31.4s, w6
dup v29.4s, w7
// Summer
movi v28.4s, #0
Loop:
ld1 {v16.4s, v17.4s}, [x1], #32
fmul v16.4s, v16.4s, v30.4s
fmul v17.4s, v17.4s, v30.4s
fadd v16.4s, v16.4s, v29.4s
fadd v17.4s, v17.4s, v29.4s
fmin v16.4s, v16.4s, v3.4s
fmin v17.4s, v17.4s, v3.4s
fmax v18.4s, v16.4s, v4.4s
fmax v19.4s, v17.4s, v4.4s
fmul v16.4s, v18.4s, v0.s[1]
fmul v17.4s, v19.4s, v0.s[1]
fcvtzs v16.4s, v16.4s
fcvtzs v17.4s, v17.4s
scvtf v20.4s, v16.4s
scvtf v21.4s, v17.4s
//v18.4s, v19.4s: t
fmls v18.4s, v20.4s, v0.s[0]
fmls v19.4s, v21.4s, v0.s[0]
fmul v18.4s, v18.4s, v0.s[2]
fmul v19.4s, v19.4s, v0.s[2]
.macro MLA_TWO z0 z1 z2 z3
dup \z1, \z0
fmla \z1, \z2, \z3
.endm
MLA_TWO v1.s[2], v20.4s, v18.4s, v1.s[3]
MLA_TWO v1.s[2], v21.4s, v19.4s, v1.s[3]
MLA_TWO v1.s[1], v22.4s, v18.4s, v20.4s
MLA_TWO v1.s[1], v23.4s, v19.4s, v21.4s
MLA_TWO v1.s[0], v20.4s, v18.4s, v22.4s
MLA_TWO v1.s[0], v21.4s, v19.4s, v23.4s
MLA_TWO v0.s[3], v22.4s, v18.4s, v20.4s
MLA_TWO v0.s[3], v23.4s, v19.4s, v21.4s
MLA_TWO v0.s[3], v20.4s, v18.4s, v22.4s
MLA_TWO v0.s[3], v21.4s, v19.4s, v23.4s
//v20.4s, v21.4s is expRemain
fmul v20.4s, v20.4s, v20.4s
fmul v21.4s, v21.4s, v21.4s
fmul v20.4s, v20.4s, v20.4s
fmul v21.4s, v21.4s, v21.4s
ushl v16.4s, v16.4s, v2.4s
ushl v17.4s, v17.4s, v2.4s
add v20.4s, v20.4s, v16.4s
add v21.4s, v21.4s, v17.4s
fadd v20.4s, v20.4s, v31.4s
fadd v21.4s, v21.4s, v31.4s
st1 {v20.4s, v21.4s}, [x0], #32
fadd v28.4s, v28.4s, v20.4s
fadd v28.4s, v28.4s, v21.4s
subs x4, x4, #1
bne Loop
// Bias
add x7, x2, #12
ld1 {v27.s}[0], [x7]
faddp v28.4s, v28.4s, v28.4s
faddp v28.2s, v28.2s, v28.2s
fadd v27.2s, v28.2s, v27.2s
st1 {v27.s}[0], [x7]
ret
#endif
@ -1,204 +0,0 @@
//
// MNNSoftmax.S
// MNN
//
// Created by MNN on 2021/07/05.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef __aarch64__
#include "MNNAsmGlobal.h"
.text
.align 5
//void MNNSoftmax(float* dest, const float* source, size_t countC8)
asm_function MNNSoftmax
stp x21, x22, [sp, #-32]!
stp x19, x20, [sp, #16]
sxtw x8, w2
lsr w9, w2, #2
and x8, x8, #-4
cbz w9, Loop_5
ldr q0, [x1]
cmp w9, #1
b.eq Loop_4
add x10, x1, #16
sub x11, x9, #1
Loop_3:
ldr q1, [x10], #16
subs x11, x11, #1
fmax v0.4s, v0.4s, v1.4s
b.ne Loop_3
Loop_4:
mov s1, v0.s[1]
fcmp s0, s1
mov s2, v0.s[2]
fcsel s1, s0, s1, gt
fcmp s1, s2
fcsel s1, s1, s2, gt
mov s0, v0.s[3]
fcmp s1, s0
fcsel s0, s1, s0, gt
cmp w8, w2
b.lo Loop_6
b Loop_8
Loop_5:
fmov s0, wzr
cmp w8, w2
b.hs Loop_8
Loop_6:
add x10, x1, w8, sxtw #2
mov w11, w8
Loop_7:
ldr s1, [x10], #4
add w11, w11, #1
fcmp s0, s1
fcsel s0, s0, s1, gt
cmp w11, w2
b.lo Loop_7
Loop_8:
cbz w9, Loop_12
mov w10, #-1028784128
mov w12, #43579
dup v4.4s, w10
mov w10, #29208
mov w11, #1118699520
movk w12, #16312, lsl #16
movk w10, #48945, lsl #16
dup v5.4s, w11
mov w11, #34953
dup v6.4s, w12
mov w12, #43691
dup v7.4s, w10
mov w10, #43691
movk w11, #15368, lsl #16
movk w12, #15658, lsl #16
movk w10, #15914, lsl #16
dup v2.4s, v0.s[0]
movi v1.16b, #0
fmov v3.4s, #1.0
dup v16.4s, w11
dup v17.4s, w12
dup v18.4s, w10
movi v19.4s, #63, lsl #24
mov x10, x9
mov x11, x1
mov x12, x0
Loop_10:
ldr q20, [x11], #16
subs x10, x10, #1
fsub v20.4s, v20.4s, v2.4s
fmax v20.4s, v20.4s, v4.4s
fmin v20.4s, v20.4s, v5.4s
fmul v21.4s, v20.4s, v6.4s
fcvtzs v21.4s, v21.4s
scvtf v22.4s, v21.4s
fmla v20.4s, v22.4s, v7.4s
fmul v22.4s, v20.4s, v16.4s
fadd v22.4s, v22.4s, v17.4s
fmul v22.4s, v20.4s, v22.4s
fadd v22.4s, v22.4s, v18.4s
fmul v22.4s, v20.4s, v22.4s
fadd v22.4s, v22.4s, v19.4s
fmul v22.4s, v20.4s, v22.4s
fadd v22.4s, v22.4s, v3.4s
shl v21.4s, v21.4s, #23
fmul v20.4s, v20.4s, v22.4s
add v21.4s, v21.4s, v3.4s
fadd v20.4s, v20.4s, v3.4s
fmul v20.4s, v20.4s, v21.4s
fadd v1.4s, v1.4s, v20.4s
str q20, [x12], #16
b.ne Loop_10
dup v2.4s, v1.s[1]
dup v3.4s, v1.s[2]
fadd v2.4s, v1.4s, v2.4s
fadd v2.4s, v3.4s, v2.4s
dup v1.4s, v1.s[3]
fadd v1.4s, v1.4s, v2.4s
b Loop_13
Loop_12:
fmov s1, wzr
Loop_13:
cmp w8, w2
fmov s2, #1.0
b.hs Loop_16
lsl x21, x8, #2
mov w12, #29208
mov w14, #34953
mov w15, #43691
mov w19, #43691
mov w10, #-1028784128
mov w11, #1118699520
movk w12, #16177, lsl #16
mov w13, #1065353216
movk w14, #15368, lsl #16
movk w15, #15658, lsl #16
movk w19, #15914, lsl #16
add x20, x1, x21
add x21, x0, x21
fmov s3, #0.5
mov w1, w8
Loop_15:
ldr s4, [x20], #4
fmov s5, w10
fmov s6, w11
fmov s7, w12
fsub s4, s4, s0
fmaxnm s4, s4, s5
fminnm s4, s4, s6
fdiv s6, s4, s7
fcvtzs w3, s6
scvtf s6, w3
fmul s6, s6, s7
fmov s5, w14
fsub s4, s4, s6
fmov s7, w15
fmul s5, s4, s5
fadd s5, s5, s7
fmov s6, w19
fmul s5, s4, s5
fadd s5, s5, s6
fmul s5, s4, s5
fadd s5, s5, s3
fmul s5, s4, s5
fadd s5, s5, s2
add w3, w13, w3, lsl #23
fmul s4, s4, s5
fmov s7, w3
fadd s4, s4, s2
add w1, w1, #1
fmul s4, s4, s7
cmp w1, w2
str s4, [x21], #4
fadd s1, s1, s4
b.lo Loop_15
Loop_16:
fdiv s0, s2, s1
cbz w9, Loop_19
dup v1.4s, v0.s[0]
mov x10, x0
Loop_18:
ldr q2, [x10]
subs x9, x9, #1
fmul v2.4s, v1.4s, v2.4s
str q2, [x10], #16
b.ne Loop_18
Loop_19:
cmp w8, w2
b.hs Loop_22
add x9, x0, w8, sxtw #2
Loop_21:
ldr s1, [x9]
add w8, w8, #1
cmp w8, w2
fmul s1, s0, s1
str s1, [x9], #4
b.lo Loop_21
Loop_22:
ldp x19, x20, [sp, #16]
ldp x21, x22, [sp], #32
ret
#endif
@ -379,10 +379,10 @@ LoopSzEnd_TILE_1:
.inst 0x05b02324 // dup z4.q, z25.q[2]
.inst 0x05f02325 // dup z5.q, z25.q[3]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -404,10 +404,10 @@ LoopSzEnd_TILE_1:
.inst 0x05f02325 // dup z5.q, z25.q[3]
.inst 0x05702346 // dup z6.q, z26.q[1]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -431,11 +431,11 @@ LoopSzEnd_TILE_1:
.inst 0x05702346 // dup z6.q, z26.q[1]
.inst 0x05b02347 // dup z7.q, z26.q[2]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -461,11 +461,11 @@ TILE1_STORE48:
.inst 0x05b02347 // dup z7.q, z26.q[2]
.inst 0x05f02348 // dup z8.q, z26.q[3]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -492,12 +492,12 @@ TILE1_STORE52:
.inst 0x05b02347 // dup z7.q, z26.q[2]
.inst 0x05f02348 // dup z8.q, z26.q[3]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -526,12 +526,12 @@ TILE1_STORE56:
.inst 0x05f02348 // dup z8.q, z26.q[3]
.inst 0x05702369 // dup z9.q, z27.q[1]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -562,13 +562,13 @@ TILE1_STORE60:
.inst 0x05702369 // dup z9.q, z27.q[1]
.inst 0x05b0236a // dup z10.q, z27.q[2]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -601,13 +601,13 @@ TILE1_STORE64:
.inst 0x05b0236a // dup z10.q, z27.q[2]
.inst 0x05f0236b // dup z11.q, z27.q[3]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -640,15 +640,14 @@ TILE1_STORE68:
.inst 0x05702369 // dup z9.q, z27.q[1]
.inst 0x05b0236a // dup z10.q, z27.q[2]
.inst 0x05f0236b // dup z11.q, z27.q[3]
.inst 0x0570238c // dup z12.q, z28.q[1]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -667,7 +666,7 @@ TILE1_STORE68:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x9, x6, x4, lsl #4 // +16
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
b TILE1_Dz_End
@ -685,14 +684,15 @@ TILE1_STORE72:
.inst 0x05702369 // dup z9.q, z27.q[1]
.inst 0x05b0236a // dup z10.q, z27.q[2]
.inst 0x05f0236b // dup z11.q, z27.q[3]
.inst 0x0570238c // dup z12.q, z28.q[1]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -711,7 +711,7 @@ TILE1_STORE72:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x9, x6, x4, lsl #4 // +16
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
@ -731,14 +731,15 @@ TILE1_STORE76:
.inst 0x05b0236a // dup z10.q, z27.q[2]
.inst 0x05f0236b // dup z11.q, z27.q[3]
.inst 0x0570238c // dup z12.q, z28.q[1]
.inst 0x05b0238d // dup z13.q, z28.q[2]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -757,8 +758,8 @@ TILE1_STORE76:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
@ -779,14 +780,16 @@ TILE1_STORE80:
.inst 0x05b0236a // dup z10.q, z27.q[2]
.inst 0x05f0236b // dup z11.q, z27.q[3]
.inst 0x0570238c // dup z12.q, z28.q[1]
.inst 0x05b0238d // dup z13.q, z28.q[2]
.inst 0x05f0238e // dup z14.q, z28.q[3]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -805,12 +808,13 @@ TILE1_STORE80:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
b TILE1_Dz_End
TILE1_STORE84:
@ -827,14 +831,17 @@ TILE1_STORE84:
.inst 0x05b0236a // dup z10.q, z27.q[2]
.inst 0x05f0236b // dup z11.q, z27.q[3]
.inst 0x0570238c // dup z12.q, z28.q[1]
.inst 0x05b0238d // dup z13.q, z28.q[2]
.inst 0x05f0238e // dup z14.q, z28.q[3]
.inst 0x057023af // dup z15.q, z29.q[1]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -853,12 +860,15 @@ TILE1_STORE84:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
add x8, x23, x4, lsl #3 // +20
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
b TILE1_Dz_End
TILE1_STORE88:
@ -876,14 +886,16 @@ TILE1_STORE88:
.inst 0x05f0236b // dup z11.q, z27.q[3]
.inst 0x0570238c // dup z12.q, z28.q[1]
.inst 0x05b0238d // dup z13.q, z28.q[2]
.inst 0x05f0238e // dup z14.q, z28.q[3]
.inst 0x057023af // dup z15.q, z29.q[1]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -902,12 +914,16 @@ TILE1_STORE88:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
add x8, x23, x4, lsl #3 // +20
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
b TILE1_Dz_End
TILE1_STORE92:
@ -926,14 +942,17 @@ TILE1_STORE92:
.inst 0x0570238c // dup z12.q, z28.q[1]
.inst 0x05b0238d // dup z13.q, z28.q[2]
.inst 0x05f0238e // dup z14.q, z28.q[3]
.inst 0x057023af // dup z15.q, z29.q[1]
.inst 0x05b023b0 // dup z16.q, z29.q[2]
.inst 0x05f023b1 // dup z17.q, z29.q[3]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -952,15 +971,18 @@ TILE1_STORE92:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x8, x23, x4, lsl #3
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
add x8, x23, x4, lsl #3 // +20
add x11, x5, x4, lsl #3 // +22
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
.inst 0xe400f570 // st1b {z16.b}, p5, [x11]
b TILE1_Dz_End
TILE1_STORE96:
@ -980,14 +1002,16 @@ TILE1_STORE96:
.inst 0x05b0238d // dup z13.q, z28.q[2]
.inst 0x05f0238e // dup z14.q, z28.q[3]
.inst 0x057023af // dup z15.q, z29.q[1]
.inst 0x05b023b0 // dup z16.q, z29.q[2]
.inst 0x05f023b1 // dup z17.q, z29.q[3]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -1006,9 +1030,10 @@ TILE1_STORE96:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x8, x23, x4, lsl #3
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
add x8, x23, x4, lsl #3 // +20
add x11, x5, x4, lsl #3 // +22
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
@ -1016,6 +1041,8 @@ TILE1_STORE96:
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
.inst 0xe400f570 // st1b {z16.b}, p5, [x11]
.inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4]
b TILE1_Dz_End
TILE1_STORE100:
@ -1036,14 +1063,15 @@ TILE1_STORE100:
.inst 0x05f0238e // dup z14.q, z28.q[3]
.inst 0x057023af // dup z15.q, z29.q[1]
.inst 0x05b023b0 // dup z16.q, z29.q[2]
.inst 0x05f023b1 // dup z17.q, z29.q[3]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -1062,10 +1090,11 @@ TILE1_STORE100:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x8, x23, x4, lsl #3
add x11, x5, x4, lsl #3
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
add x8, x23, x4, lsl #3 // +20
add x11, x5, x4, lsl #3 // +22
add x13, x9, x4, lsl #3 // +24
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
@ -1074,6 +1103,8 @@ TILE1_STORE100:
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
.inst 0xe400f570 // st1b {z16.b}, p5, [x11]
.inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4]
.inst 0xe400f5be // st1b {z30.b}, p5, [x13]
b TILE1_Dz_End
TILE1_STORE104:
@ -1095,75 +1126,15 @@ TILE1_STORE104:
.inst 0x057023af // dup z15.q, z29.q[1]
.inst 0x05b023b0 // dup z16.q, z29.q[2]
.inst 0x05f023b1 // dup z17.q, z29.q[3]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
.inst 0xe400f521 // st1b {z1.b}, p5, [x9]
.inst 0xe4045522 // st1b {z2.b}, p5, [x9, x4]
.inst 0xe400f5f9 // st1b {z25.b}, p5, [x15]
.inst 0xe40455e3 // st1b {z3.b}, p5, [x15, x4]
.inst 0xe400f504 // st1b {z4.b}, p5, [x8]
.inst 0xe4045505 // st1b {z5.b}, p5, [x8, x4]
.inst 0xe400f57a // st1b {z26.b}, p5, [x11]
.inst 0xe4045566 // st1b {z6.b}, p5, [x11, x4]
.inst 0xe400f5a7 // st1b {z7.b}, p5, [x13]
.inst 0xe40455a8 // st1b {z8.b}, p5, [x13, x4]
.inst 0xe400f6fb // st1b {z27.b}, p5, [x23]
.inst 0xe40456e9 // st1b {z9.b}, p5, [x23, x4]
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x8, x23, x4, lsl #3
add x11, x5, x4, lsl #3
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
.inst 0xe400f570 // st1b {z16.b}, p5, [x11]
.inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4]
b TILE1_Dz_End
TILE1_STORE108:
.inst 0x05702300 // dup z0.q, z24.q[1]
.inst 0x05b02301 // dup z1.q, z24.q[2]
.inst 0x05f02302 // dup z2.q, z24.q[3]
.inst 0x05702323 // dup z3.q, z25.q[1]
.inst 0x05b02324 // dup z4.q, z25.q[2]
.inst 0x05f02325 // dup z5.q, z25.q[3]
.inst 0x05702346 // dup z6.q, z26.q[1]
.inst 0x05b02347 // dup z7.q, z26.q[2]
.inst 0x05f02348 // dup z8.q, z26.q[3]
.inst 0x05702369 // dup z9.q, z27.q[1]
.inst 0x05b0236a // dup z10.q, z27.q[2]
.inst 0x05f0236b // dup z11.q, z27.q[3]
.inst 0x0570238c // dup z12.q, z28.q[1]
.inst 0x05b0238d // dup z13.q, z28.q[2]
.inst 0x05f0238e // dup z14.q, z28.q[3]
.inst 0x057023af // dup z15.q, z29.q[1]
.inst 0x05b023b0 // dup z16.q, z29.q[2]
.inst 0x05f023b1 // dup z17.q, z29.q[3]
.inst 0x057023d2 // dup z18.q, z30.q[1]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -1182,11 +1153,11 @@ TILE1_STORE108:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x8, x23, x4, lsl #3
add x11, x5, x4, lsl #3
add x13, x9, x4, lsl #3
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
add x8, x23, x4, lsl #3 // +20
add x11, x5, x4, lsl #3 // +22
add x13, x9, x4, lsl #3 // +24
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
@ -1200,7 +1171,8 @@ TILE1_STORE108:
.inst 0xe40455b2 // st1b {z18.b}, p5, [x13, x4]
b TILE1_Dz_End
TILE1_STORE112:
/* oc=108 */
TILE1_STORE108:
.inst 0x05702300 // dup z0.q, z24.q[1]
.inst 0x05b02301 // dup z1.q, z24.q[2]
.inst 0x05f02302 // dup z2.q, z24.q[3]
@ -1222,13 +1194,13 @@ TILE1_STORE108:
.inst 0x057023d2 // dup z18.q, z30.q[1]
.inst 0x05b023d3 // dup z19.q, z30.q[2]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -1247,12 +1219,12 @@ TILE1_STORE108:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x8, x23, x4, lsl #3
add x11, x5, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
add x8, x23, x4, lsl #3 // +20
add x11, x5, x4, lsl #3 // +22
add x13, x9, x4, lsl #3 // +24
add x23, x15, x4, lsl #3 // +26
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
@ -1267,6 +1239,77 @@ TILE1_STORE108:
.inst 0xe400f6f3 // st1b {z19.b}, p5, [x23]
b TILE1_Dz_End
/* oc=112 */
TILE1_STORE112:
.inst 0x05702300 // dup z0.q, z24.q[1]
.inst 0x05b02301 // dup z1.q, z24.q[2]
.inst 0x05f02302 // dup z2.q, z24.q[3]
.inst 0x05702323 // dup z3.q, z25.q[1]
.inst 0x05b02324 // dup z4.q, z25.q[2]
.inst 0x05f02325 // dup z5.q, z25.q[3]
.inst 0x05702346 // dup z6.q, z26.q[1]
.inst 0x05b02347 // dup z7.q, z26.q[2]
.inst 0x05f02348 // dup z8.q, z26.q[3]
.inst 0x05702369 // dup z9.q, z27.q[1]
.inst 0x05b0236a // dup z10.q, z27.q[2]
.inst 0x05f0236b // dup z11.q, z27.q[3]
.inst 0x0570238c // dup z12.q, z28.q[1]
.inst 0x05b0238d // dup z13.q, z28.q[2]
.inst 0x05f0238e // dup z14.q, z28.q[3]
.inst 0x057023af // dup z15.q, z29.q[1]
.inst 0x05b023b0 // dup z16.q, z29.q[2]
.inst 0x05f023b1 // dup z17.q, z29.q[3]
.inst 0x057023d2 // dup z18.q, z30.q[1]
.inst 0x05b023d3 // dup z19.q, z30.q[2]
.inst 0x05f023d4 // dup z20.q, z30.q[3]
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
.inst 0xe400f521 // st1b {z1.b}, p5, [x9]
.inst 0xe4045522 // st1b {z2.b}, p5, [x9, x4]
.inst 0xe400f5f9 // st1b {z25.b}, p5, [x15]
.inst 0xe40455e3 // st1b {z3.b}, p5, [x15, x4]
.inst 0xe400f504 // st1b {z4.b}, p5, [x8]
.inst 0xe4045505 // st1b {z5.b}, p5, [x8, x4]
.inst 0xe400f57a // st1b {z26.b}, p5, [x11]
.inst 0xe4045566 // st1b {z6.b}, p5, [x11, x4]
.inst 0xe400f5a7 // st1b {z7.b}, p5, [x13]
.inst 0xe40455a8 // st1b {z8.b}, p5, [x13, x4]
.inst 0xe400f6fb // st1b {z27.b}, p5, [x23]
.inst 0xe40456e9 // st1b {z9.b}, p5, [x23, x4]
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
add x8, x23, x4, lsl #3 // +20
add x11, x5, x4, lsl #3 // +22
add x13, x9, x4, lsl #3 // +24
add x23, x15, x4, lsl #3 // +26
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
.inst 0xe400f570 // st1b {z16.b}, p5, [x11]
.inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4]
.inst 0xe400f5be // st1b {z30.b}, p5, [x13]
.inst 0xe40455b2 // st1b {z18.b}, p5, [x13, x4]
.inst 0xe400f6f3 // st1b {z19.b}, p5, [x23]
.inst 0xe40456f4 // st1b {z20.b}, p5, [x23, x4]
b TILE1_Dz_End
/* oc=116 */
TILE1_STORE116:
.inst 0x05702300 // dup z0.q, z24.q[1]
.inst 0x05b02301 // dup z1.q, z24.q[2]
@ -1290,13 +1333,13 @@ TILE1_STORE108:
.inst 0x05b023d3 // dup z19.q, z30.q[2]
.inst 0x05f023d4 // dup z20.q, z30.q[3]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -1315,13 +1358,13 @@ TILE1_STORE108:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x8, x23, x4, lsl #3
add x11, x5, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
add x8, x23, x4, lsl #3 // +20
add x11, x5, x4, lsl #3 // +22
add x13, x9, x4, lsl #3 // +24
add x23, x15, x4, lsl #3 // +26
add x5, x8, x4, lsl #3 // +28
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
@ -1338,6 +1381,7 @@ TILE1_STORE108:
.inst 0xe400f4bf // st1b {z31.b}, p5, [x5]
b TILE1_Dz_End
/* oc=120 */
TILE1_STORE120:
.inst 0x05702300 // dup z0.q, z24.q[1]
.inst 0x05b02301 // dup z1.q, z24.q[2]
@ -1362,13 +1406,13 @@ TILE1_STORE108:
.inst 0x05f023d4 // dup z20.q, z30.q[3]
.inst 0x057023f5 // dup z21.q, z31.q[1]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -1387,13 +1431,13 @@ TILE1_STORE108:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x8, x23, x4, lsl #3
add x11, x5, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
add x8, x23, x4, lsl #3 // +20
add x11, x5, x4, lsl #3 // +22
add x13, x9, x4, lsl #3 // +24
add x23, x15, x4, lsl #3 // +26
add x5, x8, x4, lsl #3 // +28
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
@ -1411,6 +1455,7 @@ TILE1_STORE108:
.inst 0xe40454b5 // st1b {z21.b}, p5, [x5, x4]
b TILE1_Dz_End
/* oc=124 */
TILE1_STORE124:
.inst 0x05702300 // dup z0.q, z24.q[1]
.inst 0x05b02301 // dup z1.q, z24.q[2]
@ -1487,6 +1532,7 @@ TILE1_STORE108:
.inst 0xe400f596 // st1b {z22.b}, p5, [x12]
b TILE1_Dz_End
/* oc=128 */
TILE1_STORE128:
.inst 0x05702300 // dup z0.q, z24.q[1]
.inst 0x05b02301 // dup z1.q, z24.q[2]
@ -372,10 +372,10 @@ LoopSzEnd_TILE_1:
.inst 0x05b02324 // dup z4.q, z25.q[2]
.inst 0x05f02325 // dup z5.q, z25.q[3]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -397,10 +397,10 @@ LoopSzEnd_TILE_1:
.inst 0x05f02325 // dup z5.q, z25.q[3]
.inst 0x05702346 // dup z6.q, z26.q[1]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -424,11 +424,11 @@ LoopSzEnd_TILE_1:
.inst 0x05702346 // dup z6.q, z26.q[1]
.inst 0x05b02347 // dup z7.q, z26.q[2]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -454,11 +454,11 @@ TILE1_STORE48:
.inst 0x05b02347 // dup z7.q, z26.q[2]
.inst 0x05f02348 // dup z8.q, z26.q[3]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -485,12 +485,12 @@ TILE1_STORE52:
.inst 0x05b02347 // dup z7.q, z26.q[2]
.inst 0x05f02348 // dup z8.q, z26.q[3]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -519,12 +519,12 @@ TILE1_STORE56:
.inst 0x05f02348 // dup z8.q, z26.q[3]
.inst 0x05702369 // dup z9.q, z27.q[1]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -555,13 +555,13 @@ TILE1_STORE60:
.inst 0x05702369 // dup z9.q, z27.q[1]
.inst 0x05b0236a // dup z10.q, z27.q[2]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -594,13 +594,13 @@ TILE1_STORE64:
.inst 0x05b0236a // dup z10.q, z27.q[2]
.inst 0x05f0236b // dup z11.q, z27.q[3]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -633,15 +633,14 @@ TILE1_STORE68:
.inst 0x05702369 // dup z9.q, z27.q[1]
.inst 0x05b0236a // dup z10.q, z27.q[2]
.inst 0x05f0236b // dup z11.q, z27.q[3]
.inst 0x0570238c // dup z12.q, z28.q[1]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -660,7 +659,7 @@ TILE1_STORE68:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x9, x6, x4, lsl #4 // +16
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
b TILE1_Dz_End
@ -678,14 +677,15 @@ TILE1_STORE72:
.inst 0x05702369 // dup z9.q, z27.q[1]
.inst 0x05b0236a // dup z10.q, z27.q[2]
.inst 0x05f0236b // dup z11.q, z27.q[3]
.inst 0x0570238c // dup z12.q, z28.q[1]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -704,7 +704,7 @@ TILE1_STORE72:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x9, x6, x4, lsl #4 // +16
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
@ -724,14 +724,15 @@ TILE1_STORE76:
.inst 0x05b0236a // dup z10.q, z27.q[2]
.inst 0x05f0236b // dup z11.q, z27.q[3]
.inst 0x0570238c // dup z12.q, z28.q[1]
.inst 0x05b0238d // dup z13.q, z28.q[2]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -750,8 +751,8 @@ TILE1_STORE76:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
@ -772,14 +773,16 @@ TILE1_STORE80:
.inst 0x05b0236a // dup z10.q, z27.q[2]
.inst 0x05f0236b // dup z11.q, z27.q[3]
.inst 0x0570238c // dup z12.q, z28.q[1]
.inst 0x05b0238d // dup z13.q, z28.q[2]
.inst 0x05f0238e // dup z14.q, z28.q[3]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -798,12 +801,13 @@ TILE1_STORE80:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
b TILE1_Dz_End
TILE1_STORE84:
@ -820,14 +824,16 @@ TILE1_STORE84:
.inst 0x05b0236a // dup z10.q, z27.q[2]
.inst 0x05f0236b // dup z11.q, z27.q[3]
.inst 0x0570238c // dup z12.q, z28.q[1]
.inst 0x05b0238d // dup z13.q, z28.q[2]
.inst 0x05f0238e // dup z14.q, z28.q[3]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -846,12 +852,15 @@ TILE1_STORE84:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
add x8, x23, x4, lsl #3 // +20
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
b TILE1_Dz_End
TILE1_STORE88:
@ -869,14 +878,16 @@ TILE1_STORE88:
.inst 0x05f0236b // dup z11.q, z27.q[3]
.inst 0x0570238c // dup z12.q, z28.q[1]
.inst 0x05b0238d // dup z13.q, z28.q[2]
.inst 0x05f0238e // dup z14.q, z28.q[3]
.inst 0x057023af // dup z15.q, z29.q[1]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -895,12 +906,16 @@ TILE1_STORE88:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
add x8, x23, x4, lsl #3 // +20
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
b TILE1_Dz_End
TILE1_STORE92:
@ -919,14 +934,17 @@ TILE1_STORE92:
.inst 0x0570238c // dup z12.q, z28.q[1]
.inst 0x05b0238d // dup z13.q, z28.q[2]
.inst 0x05f0238e // dup z14.q, z28.q[3]
.inst 0x057023af // dup z15.q, z29.q[1]
.inst 0x05b023b0 // dup z16.q, z29.q[2]
.inst 0x05f023b1 // dup z17.q, z29.q[3]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -945,15 +963,18 @@ TILE1_STORE92:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x8, x23, x4, lsl #3
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
add x8, x23, x4, lsl #3 // +20
add x11, x5, x4, lsl #3 // +22
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
.inst 0xe400f570 // st1b {z16.b}, p5, [x11]
b TILE1_Dz_End
TILE1_STORE96:
@ -973,14 +994,16 @@ TILE1_STORE96:
.inst 0x05b0238d // dup z13.q, z28.q[2]
.inst 0x05f0238e // dup z14.q, z28.q[3]
.inst 0x057023af // dup z15.q, z29.q[1]
.inst 0x05b023b0 // dup z16.q, z29.q[2]
.inst 0x05f023b1 // dup z17.q, z29.q[3]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -999,9 +1022,10 @@ TILE1_STORE96:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x8, x23, x4, lsl #3
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
add x8, x23, x4, lsl #3 // +20
add x11, x5, x4, lsl #3 // +22
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
@ -1009,6 +1033,8 @@ TILE1_STORE96:
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
.inst 0xe400f570 // st1b {z16.b}, p5, [x11]
.inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4]
b TILE1_Dz_End
TILE1_STORE100:
@ -1029,14 +1055,15 @@ TILE1_STORE100:
.inst 0x05f0238e // dup z14.q, z28.q[3]
.inst 0x057023af // dup z15.q, z29.q[1]
.inst 0x05b023b0 // dup z16.q, z29.q[2]
.inst 0x05f023b1 // dup z17.q, z29.q[3]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -1055,10 +1082,11 @@ TILE1_STORE100:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x8, x23, x4, lsl #3
add x11, x5, x4, lsl #3
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
add x8, x23, x4, lsl #3 // +20
add x11, x5, x4, lsl #3 // +22
add x13, x9, x4, lsl #3 // +24
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
@ -1067,6 +1095,8 @@ TILE1_STORE100:
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
.inst 0xe400f570 // st1b {z16.b}, p5, [x11]
.inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4]
.inst 0xe400f5be // st1b {z30.b}, p5, [x13]
b TILE1_Dz_End
TILE1_STORE104:
@ -1088,75 +1118,15 @@ TILE1_STORE104:
.inst 0x057023af // dup z15.q, z29.q[1]
.inst 0x05b023b0 // dup z16.q, z29.q[2]
.inst 0x05f023b1 // dup z17.q, z29.q[3]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
.inst 0xe400f521 // st1b {z1.b}, p5, [x9]
.inst 0xe4045522 // st1b {z2.b}, p5, [x9, x4]
.inst 0xe400f5f9 // st1b {z25.b}, p5, [x15]
.inst 0xe40455e3 // st1b {z3.b}, p5, [x15, x4]
.inst 0xe400f504 // st1b {z4.b}, p5, [x8]
.inst 0xe4045505 // st1b {z5.b}, p5, [x8, x4]
.inst 0xe400f57a // st1b {z26.b}, p5, [x11]
.inst 0xe4045566 // st1b {z6.b}, p5, [x11, x4]
.inst 0xe400f5a7 // st1b {z7.b}, p5, [x13]
.inst 0xe40455a8 // st1b {z8.b}, p5, [x13, x4]
.inst 0xe400f6fb // st1b {z27.b}, p5, [x23]
.inst 0xe40456e9 // st1b {z9.b}, p5, [x23, x4]
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x8, x23, x4, lsl #3
add x11, x5, x4, lsl #3
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
.inst 0xe400f570 // st1b {z16.b}, p5, [x11]
.inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4]
b TILE1_Dz_End
TILE1_STORE108:
.inst 0x05702300 // dup z0.q, z24.q[1]
.inst 0x05b02301 // dup z1.q, z24.q[2]
.inst 0x05f02302 // dup z2.q, z24.q[3]
.inst 0x05702323 // dup z3.q, z25.q[1]
.inst 0x05b02324 // dup z4.q, z25.q[2]
.inst 0x05f02325 // dup z5.q, z25.q[3]
.inst 0x05702346 // dup z6.q, z26.q[1]
.inst 0x05b02347 // dup z7.q, z26.q[2]
.inst 0x05f02348 // dup z8.q, z26.q[3]
.inst 0x05702369 // dup z9.q, z27.q[1]
.inst 0x05b0236a // dup z10.q, z27.q[2]
.inst 0x05f0236b // dup z11.q, z27.q[3]
.inst 0x0570238c // dup z12.q, z28.q[1]
.inst 0x05b0238d // dup z13.q, z28.q[2]
.inst 0x05f0238e // dup z14.q, z28.q[3]
.inst 0x057023af // dup z15.q, z29.q[1]
.inst 0x05b023b0 // dup z16.q, z29.q[2]
.inst 0x05f023b1 // dup z17.q, z29.q[3]
.inst 0x057023d2 // dup z18.q, z30.q[1]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -1175,11 +1145,11 @@ TILE1_STORE108:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x8, x23, x4, lsl #3
add x11, x5, x4, lsl #3
add x13, x9, x4, lsl #3
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
add x8, x23, x4, lsl #3 // +20
add x11, x5, x4, lsl #3 // +22
add x13, x9, x4, lsl #3 // +24
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
@ -1193,7 +1163,8 @@ TILE1_STORE108:
.inst 0xe40455b2 // st1b {z18.b}, p5, [x13, x4]
b TILE1_Dz_End
TILE1_STORE112:
/* oc=108 */
TILE1_STORE108:
.inst 0x05702300 // dup z0.q, z24.q[1]
.inst 0x05b02301 // dup z1.q, z24.q[2]
.inst 0x05f02302 // dup z2.q, z24.q[3]
@ -1215,13 +1186,13 @@ TILE1_STORE108:
.inst 0x057023d2 // dup z18.q, z30.q[1]
.inst 0x05b023d3 // dup z19.q, z30.q[2]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -1240,12 +1211,12 @@ TILE1_STORE108:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x8, x23, x4, lsl #3
add x11, x5, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
add x8, x23, x4, lsl #3 // +20
add x11, x5, x4, lsl #3 // +22
add x13, x9, x4, lsl #3 // +24
add x23, x15, x4, lsl #3 // +26
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
@ -1260,6 +1231,77 @@ TILE1_STORE108:
.inst 0xe400f6f3 // st1b {z19.b}, p5, [x23]
b TILE1_Dz_End
/* oc=112 */
TILE1_STORE112:
.inst 0x05702300 // dup z0.q, z24.q[1]
.inst 0x05b02301 // dup z1.q, z24.q[2]
.inst 0x05f02302 // dup z2.q, z24.q[3]
.inst 0x05702323 // dup z3.q, z25.q[1]
.inst 0x05b02324 // dup z4.q, z25.q[2]
.inst 0x05f02325 // dup z5.q, z25.q[3]
.inst 0x05702346 // dup z6.q, z26.q[1]
.inst 0x05b02347 // dup z7.q, z26.q[2]
.inst 0x05f02348 // dup z8.q, z26.q[3]
.inst 0x05702369 // dup z9.q, z27.q[1]
.inst 0x05b0236a // dup z10.q, z27.q[2]
.inst 0x05f0236b // dup z11.q, z27.q[3]
.inst 0x0570238c // dup z12.q, z28.q[1]
.inst 0x05b0238d // dup z13.q, z28.q[2]
.inst 0x05f0238e // dup z14.q, z28.q[3]
.inst 0x057023af // dup z15.q, z29.q[1]
.inst 0x05b023b0 // dup z16.q, z29.q[2]
.inst 0x05f023b1 // dup z17.q, z29.q[3]
.inst 0x057023d2 // dup z18.q, z30.q[1]
.inst 0x05b023d3 // dup z19.q, z30.q[2]
.inst 0x05f023d4 // dup z20.q, z30.q[3]
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
.inst 0xe400f521 // st1b {z1.b}, p5, [x9]
.inst 0xe4045522 // st1b {z2.b}, p5, [x9, x4]
.inst 0xe400f5f9 // st1b {z25.b}, p5, [x15]
.inst 0xe40455e3 // st1b {z3.b}, p5, [x15, x4]
.inst 0xe400f504 // st1b {z4.b}, p5, [x8]
.inst 0xe4045505 // st1b {z5.b}, p5, [x8, x4]
.inst 0xe400f57a // st1b {z26.b}, p5, [x11]
.inst 0xe4045566 // st1b {z6.b}, p5, [x11, x4]
.inst 0xe400f5a7 // st1b {z7.b}, p5, [x13]
.inst 0xe40455a8 // st1b {z8.b}, p5, [x13, x4]
.inst 0xe400f6fb // st1b {z27.b}, p5, [x23]
.inst 0xe40456e9 // st1b {z9.b}, p5, [x23, x4]
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
add x8, x23, x4, lsl #3 // +20
add x11, x5, x4, lsl #3 // +22
add x13, x9, x4, lsl #3 // +24
add x23, x15, x4, lsl #3 // +26
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
.inst 0xe400f570 // st1b {z16.b}, p5, [x11]
.inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4]
.inst 0xe400f5be // st1b {z30.b}, p5, [x13]
.inst 0xe40455b2 // st1b {z18.b}, p5, [x13, x4]
.inst 0xe400f6f3 // st1b {z19.b}, p5, [x23]
.inst 0xe40456f4 // st1b {z20.b}, p5, [x23, x4]
b TILE1_Dz_End
/* oc=116 */
TILE1_STORE116:
.inst 0x05702300 // dup z0.q, z24.q[1]
.inst 0x05b02301 // dup z1.q, z24.q[2]
@ -1283,13 +1325,13 @@ TILE1_STORE108:
.inst 0x05b023d3 // dup z19.q, z30.q[2]
.inst 0x05f023d4 // dup z20.q, z30.q[3]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -1308,13 +1350,13 @@ TILE1_STORE108:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x8, x23, x4, lsl #3
add x11, x5, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
add x8, x23, x4, lsl #3 // +20
add x11, x5, x4, lsl #3 // +22
add x13, x9, x4, lsl #3 // +24
add x23, x15, x4, lsl #3 // +26
add x5, x8, x4, lsl #3 // +28
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
@ -1331,6 +1373,7 @@ TILE1_STORE108:
.inst 0xe400f4bf // st1b {z31.b}, p5, [x5]
b TILE1_Dz_End
/* oc=120 */
TILE1_STORE120:
.inst 0x05702300 // dup z0.q, z24.q[1]
.inst 0x05b02301 // dup z1.q, z24.q[2]
@ -1355,13 +1398,13 @@ TILE1_STORE108:
.inst 0x05f023d4 // dup z20.q, z30.q[3]
.inst 0x057023f5 // dup z21.q, z31.q[1]
add x9, x6, x4, lsl #1
add x15, x6, x4, lsl #2
add x8, x9, x4, lsl #2
add x11, x6, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #1 // +2
add x15, x6, x4, lsl #2 // +4
add x8, x9, x4, lsl #2 // +6
add x11, x6, x4, lsl #3 // +8
add x13, x9, x4, lsl #3 // +10
add x23, x15, x4, lsl #3 // +12
add x5, x8, x4, lsl #3 // +14
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
@ -1380,13 +1423,13 @@ TILE1_STORE108:
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
add x9, x6, x4, lsl #4
add x15, x13, x4, lsl #3
add x8, x23, x4, lsl #3
add x11, x5, x4, lsl #3
add x13, x9, x4, lsl #3
add x23, x15, x4, lsl #3
add x5, x8, x4, lsl #3
add x9, x6, x4, lsl #4 // +16
add x15, x13, x4, lsl #3 // +18
add x8, x23, x4, lsl #3 // +20
add x11, x5, x4, lsl #3 // +22
add x13, x9, x4, lsl #3 // +24
add x23, x15, x4, lsl #3 // +26
add x5, x8, x4, lsl #3 // +28
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
@ -1404,6 +1447,7 @@ TILE1_STORE108:
.inst 0xe40454b5 // st1b {z21.b}, p5, [x5, x4]
b TILE1_Dz_End
/* oc=124 */
TILE1_STORE124:
.inst 0x05702300 // dup z0.q, z24.q[1]
.inst 0x05b02301 // dup z1.q, z24.q[2]
@ -1480,6 +1524,7 @@ TILE1_STORE108:
.inst 0xe400f596 // st1b {z22.b}, p5, [x12]
b TILE1_Dz_End
/* oc=128 */
TILE1_STORE128:
.inst 0x05702300 // dup z0.q, z24.q[1]
.inst 0x05b02301 // dup z1.q, z24.q[2]

View File

@ -1390,6 +1390,89 @@ void MNNAccumulateSequenceNumber (float* dst, const float* src, int size) {
*dst = sum;
}
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
static void MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes) {
// source shape: [headDim/pack, seqLen, pack]
// scale & normalizeScale shape: [seqLen]
// dest shape: [headDim/pack, seqLen, pack]
auto stride0 = plane * pack;
if (idx > 0) {
for (int j = 0; j < depthQuad; ++j) {
for (int i = 0; i < plane; ++i) {
auto dataNew = Vec::load(src + j * stride0 + i * pack);
auto dataOld = Vec::load(dst + j * stride0 + i * pack);
auto s = Vec(scale[i]);
dataNew = Vec::fma(dataNew, dataOld, s);
Vec::save(dst + j * stride0 + i * pack, dataNew);
}
}
} else {
memcpy(dst, src, size * bytes);
}
if (idx == kvBlocks - 1) { // if last subBlock, exp(xi)/sum(exp(xi))
for (int j = 0; j < depthQuad; ++j) {
for (int i = 0; i < plane; ++i) {
auto dataNew = Vec::load(dst + j * stride0 + i * pack);
auto ns = Vec(1.0f / normalizeScale[i]);
dataNew = dataNew * ns;
Vec::save(dst + j * stride0 + i * pack, dataNew);
}
}
}
}
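The scalar fallback above merges each KV block's partial attention output into the accumulator and only normalizes after the last block. A minimal per-element sketch of the same update, assuming `Vec::fma(a, b, c)` evaluates to `a + b * c` as the loop implies; this is a restatement for clarity, not the routine MNN actually dispatches:

```cpp
// Per-element view of the block-output update above (pack = 1):
//   idx == 0            : acc = partial
//   idx > 0             : acc = partial + acc * scale[i],  scale = exp(oldMax - newMax)
//   idx == kvBlocks - 1 : acc /= normalize[i]              (the running softmax sum)
static void flashAttentionUpdateRef(float* acc, const float* partial, const float* scale,
                                    const float* normalize, int plane, int idx, int kvBlocks) {
    for (int i = 0; i < plane; ++i) {
        acc[i] = (idx == 0) ? partial[i] : partial[i] + acc[i] * scale[i];
        if (idx == kvBlocks - 1) {
            acc[i] /= normalize[i];
        }
    }
}
```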
static void MNNAttenPackAndScaleSingleHead(float* dst, const float* srcHeadBase, size_t srcRowStride, const float* scale, const int32_t* units, size_t seqLen, size_t headDim) {
const int32_t eP = units[0];
const int32_t lP = units[1];
if (lP != 1) {
MNN_ERROR("This function only supports lP=1 or 2\n");
return;
}
const float scaleVal = scale[0];
#ifdef MNN_USE_NEON
const float32x4_t vScale = vdupq_n_f32(scaleVal);
#endif
const size_t packedHeadDim = UP_DIV(headDim, lP);
const size_t dstStrideDOuter = (size_t)eP * lP;
const size_t dstStrideSOuter = packedHeadDim * dstStrideDOuter;
for (int s = 0; s < seqLen; ++s) {
const int sOuter = s / eP;
const int sInner = s % eP;
const float* srcRowPtr = srcHeadBase + s * srcRowStride;
float* dstBasePtr = dst + sOuter * dstStrideSOuter + sInner * lP;
size_t d = 0;
#ifdef MNN_USE_NEON
for (; d + 7 < headDim; d += 8) {
float32x4_t sVec0 = vld1q_f32(srcRowPtr + d);
float32x4_t sVec1 = vld1q_f32(srcRowPtr + d + 4);
sVec0 = vmulq_f32(sVec0, vScale);
sVec1 = vmulq_f32(sVec1, vScale);
dstBasePtr[(d + 0) * dstStrideDOuter] = sVec0[0];
dstBasePtr[(d + 1) * dstStrideDOuter] = sVec0[1];
dstBasePtr[(d + 2) * dstStrideDOuter] = sVec0[2];
dstBasePtr[(d + 3) * dstStrideDOuter] = sVec0[3];
dstBasePtr[(d + 4) * dstStrideDOuter] = sVec1[0];
dstBasePtr[(d + 5) * dstStrideDOuter] = sVec1[1];
dstBasePtr[(d + 6) * dstStrideDOuter] = sVec1[2];
dstBasePtr[(d + 7) * dstStrideDOuter] = sVec1[3];
}
#else
for (; d < headDim; ++d) {
dstBasePtr[d * dstStrideDOuter] = srcRowPtr[d] * scaleVal;
}
#endif
}
}
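With `lP == 1`, the only case the guard above accepts, the packed layout reduces to eP-wide column tiles. A plain reference loop for that mapping (a sketch equivalent to the scalar tail above; the NEON path only vectorizes the multiply by `scale[0]`):

```cpp
// Reference mapping, not the dispatched routine: dst is tiled
// [ceil(seqLen / eP), headDim, eP]; src is row-major with stride srcRowStride.
static void packAndScaleRef(float* dst, const float* src, int srcRowStride,
                            float scale, int eP, int seqLen, int headDim) {
    for (int s = 0; s < seqLen; ++s) {
        for (int d = 0; d < headDim; ++d) {
            dst[(s / eP) * headDim * eP + d * eP + (s % eP)] = src[s * srcRowStride + d] * scale;
        }
    }
}
```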
#endif // MNN_SUPPORT_TRANSFORMER_FUSE
#ifndef MNN_USE_NEON
void MNNGetMatMulPackMode(int* eP, int *lP, int* hP) {
*eP = 16;
@ -2619,14 +2702,28 @@ void MNNExpC8(float* dest, const float* source, float* offset, const float* para
offset[3] = summer;
}
void MNNSoftmax(float* dest, const float* source, size_t size) {
float maxValue = ALIMAX(source[0], source[1]);
for (int i = 2; i < size; ++i) {
maxValue = ALIMAX(maxValue, source[i]);
void MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) {
for (int k = 0; k < outside; ++k) {
auto source = input + k * reduceSize;
auto dest = softmaxDst + k * reduceSize;
float oldMax = source[0];
if (runningMax) {
oldMax = runningMax[k];
}
float xLimit = 87, param = 0.6931471805599453, sumValue = 0.f;
for (int i = 0; i < size; ++i) {
auto x = source[i] - maxValue;
// find max value of current block
        float blockMax = source[0];
for (int i = 1; i < reduceSize; ++i) {
blockMax = ALIMAX(blockMax, source[i]);
}
float newMax = ALIMAX(oldMax, blockMax);
        // calculate block's exp(xi - newMax) and update runningMax
float xLimit = 87, param = 0.6931471805599453;
float blockSum = 0.f;
for (int i = 0; i < reduceSize; ++i) {
auto x = source[i] - newMax;
x = x > -xLimit ? x : -xLimit;
x = x < xLimit ? x : xLimit;
@ -2638,11 +2735,21 @@ void MNNSoftmax(float* dest, const float* source, size_t size) {
auto t = xReamin;
auto expRemain = ((((1.0f / 120 * t + 1.0f / 24) * t + 1.0f / 6) * t + 0.5f) * t + 1.0f) * t + 1.0f;
dest[i] = expBasic * expRemain;
sumValue += dest[i];
blockSum += dest[i];
}
if (runningMax != nullptr && runningSum != nullptr && updateScale != nullptr) {
// update runningSum, runningMax, scale=expf(oldMax-newMax)
runningSum[k] = runningSum[k] * expf(oldMax - newMax) + blockSum;
runningMax[k] = newMax;
updateScale[k] = expf(oldMax - newMax);
} else {
// Normalize
auto scale = 1.f / blockSum;
for (int i = 0; i < reduceSize; ++i) {
dest[i] *= scale;
}
}
sumValue = 1.f / sumValue;
for (int i = 0; i < size; ++i) {
dest[i] *= sumValue;
}
}
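The rewritten MNNSoftmax above streams softmax block by block: when `runningMax`, `runningSum` and `updateScale` are supplied it only writes unnormalized `exp(x - newMax)` and refreshes the running statistics, leaving the final division to the flash-attention output update. The recurrence it implements is the usual online-softmax update (a restatement of the code, not an addition to it):

$$
m_{\text{new}} = \max\bigl(m_{\text{old}},\ \max_i x_i\bigr), \qquad
s_{\text{new}} = s_{\text{old}}\, e^{\,m_{\text{old}} - m_{\text{new}}} + \sum_i e^{\,x_i - m_{\text{new}}}, \qquad
\text{scale} = e^{\,m_{\text{old}} - m_{\text{new}}}
$$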
@ -2657,6 +2764,68 @@ void MNNReluInt8(int8_t* dst, const int8_t* src, size_t size, ssize_t zeroPoint)
}
#endif // no MNN_USE_SSE
void MNNExp(float* dst, const float* src, float* offset, size_t dataSize) {
int countC8 = static_cast<int32_t>(dataSize) / 8;
int remain = static_cast<int32_t>(dataSize) % 8;
static const float parameters[] = {
(float)logf(2.0f), 1.0f / (float)logf(2.0f), 0.25f, 1.0f, 0.5f, 1.0f / 6.0f, 1.0f / 24.0f, 1.0f / 120.0f};
if (countC8 > 0) {
// Align to eight so asm is easier to write
MNNExpC8(dst, src, offset, parameters, countC8);
}
if (remain > 0) {
auto param = parameters[0];
float xLimit = 87;
float summer = offset[3];
auto source = src + countC8 * 8;
auto dest = dst + countC8 * 8;
for (int i = 0; i < remain; ++i) {
auto x = source[i] * offset[0] + offset[2];
x = ALIMAX(x, -xLimit);
x = ALIMIN(x, xLimit);
int div = (x * parameters[1]);
int div2 = (div + 127) << 23;
auto xReamin = x - div * param;
float expBasic = *(float*)(&div2);
auto t = xReamin * 0.25f;
auto expRemain =
((((parameters[7] * t + parameters[6]) * t + parameters[5]) * t + parameters[4]) * t + 1.0f) * t +
1.0f;
expRemain = expRemain * expRemain;
expRemain = expRemain * expRemain;
dest[i] = expBasic * expRemain + offset[1];
            summer += dest[i];
}
offset[3] = summer;
}
}
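MNNExp keeps its four-element `offset` convention, which the streaming softmax paths depend on: each output is `exp(src[i] * offset[0] + offset[2]) + offset[1]`, and the sum of the outputs is accumulated into `offset[3]`. A small usage sketch of how a softmax block drives it (mirroring the AVX/SSE code further down, and assuming the `MNNExp` defined above is in scope):

```cpp
// Compute exp(x - newMax) for one block and return the block sum.
// offset = {input scale, output bias, input bias, running-sum accumulator}.
static float expBlockSum(float* dest, const float* source, int reduceSize, float newMax) {
    float expOffset[4] = {1.0f, 0.0f, -newMax, 0.0f};
    MNNExp(dest, source, expOffset, reduceSize);
    return expOffset[3]; // sum of exp(source[i] - newMax)
}
```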
void packKvCache(float* dst, const float* src, size_t seqLen, size_t kvSeqLen, size_t eP) {
if (seqLen == 0 || kvSeqLen == 0) {
return;
}
const size_t dstSOuterStride = kvSeqLen * eP;
for (size_t sBase = 0; sBase < seqLen; sBase += eP) {
const size_t numRowsInBlock = std::min(eP, seqLen - sBase);
const size_t sOuter = sBase / eP;
float* dstSOuterBase = dst + sOuter * dstSOuterStride;
for (size_t k = 0; k < kvSeqLen; ++k) {
float* dstColStart = dstSOuterBase + k * eP;
for (size_t sInner = 0; sInner < numRowsInBlock; ++sInner) {
const size_t s = sBase + sInner;
const float value = src[s * kvSeqLen + k];
dstColStart[sInner] = value;
}
}
}
}
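packKvCache re-tiles a row-major [seqLen, kvSeqLen] score matrix into [ceil(seqLen/eP), kvSeqLen, eP] blocks, so element (s, k) lands at `dst[(s / eP) * kvSeqLen * eP + k * eP + (s % eP)]`. A tiny worked example under that reading (the tail slots of the last block are simply not written):

```cpp
static void packKvCacheExample() {
    float src[3 * 2] = {1, 2, 3, 4, 5, 6}; // row-major [seqLen = 3, kvSeqLen = 2]
    float dst[2 * 2 * 2] = {0};            // [ceil(3 / 2) = 2, kvSeqLen = 2, eP = 2]
    packKvCache(dst, src, /*seqLen*/ 3, /*kvSeqLen*/ 2, /*eP*/ 2);
    // dst == {1, 3, 2, 4, 5, 0, 6, 0}: each column of an eP-row block is now contiguous.
}
```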
void MNNMaxFloat(float* input, float* maxBuffer, int32_t inputCountUnit) {
for (int i = 0; i < inputCountUnit; i++) {
for (int j = 0; j < UNIT; j++) {
@ -3173,42 +3342,6 @@ void MNNPackTranspose(float* dst, const float* src, size_t area, size_t depth, i
}
}
void MNNExp(float* dst, const float* src, float* offset, size_t dataSize) {
int countC8 = static_cast<int32_t>(dataSize) / 8;
int remain = static_cast<int32_t>(dataSize) % 8;
static const float parameters[] = {
(float)logf(2.0f), 1.0f / (float)logf(2.0f), 0.25f, 1.0f, 0.5f, 1.0f / 6.0f, 1.0f / 24.0f, 1.0f / 120.0f};
if (countC8 > 0) {
// Align to eight so asm is easier to write
MNNExpC8(dst, src, offset, parameters, countC8);
}
if (remain > 0) {
auto param = parameters[0];
float xLimit = 87;
float summer = offset[3];
auto source = src + countC8 * 8;
auto dest = dst + countC8 * 8;
for (int i = 0; i < remain; ++i) {
auto x = source[i] * offset[0] + offset[2];
x = ALIMAX(x, -xLimit);
x = ALIMIN(x, xLimit);
int div = (x * parameters[1]);
int div2 = (div + 127) << 23;
auto xReamin = x - div * param;
float expBasic = *(float*)(&div2);
auto t = xReamin * 0.25f;
auto expRemain =
((((parameters[7] * t + parameters[6]) * t + parameters[5]) * t + parameters[4]) * t + 1.0f) * t +
1.0f;
expRemain = expRemain * expRemain;
expRemain = expRemain * expRemain;
dest[i] = expBasic * expRemain + offset[1];
summer+= dest[i];
}
offset[3] = summer;
}
}
// Lambert's series with 7 divisions
// reference from
// https://varietyofsound.wordpress.com/2011/02/14/efficient-tanh-computation-using-lamberts-continued-fraction/
@ -4011,6 +4144,7 @@ void MNNCoreFunctionInit() {
gCoreFunction->chooseWinoDestUnrollTransform = WinogradFunction::chooseWinoDestUnrollTransform;
gCoreFunction->MNNDeconvRunForLineDepthwise = MNNDeconvRunForLineDepthwise;
gCoreFunction->MNNDeconvRunForUnitDepthWise = MNNDeconvRunForUnitDepthWise;
gCoreFunction->MNNSoftmax = MNNSoftmax;
#ifdef MNN_USE_NEON
gCoreFunction->MNNDepthwiseConvFastKernel = MNNDepthwiseConvFastKernel;
#endif
@ -4019,6 +4153,12 @@ void MNNCoreFunctionInit() {
#ifdef MNN_SUPPORT_QUANT_EXTEND
gCoreFunction->MNNSelectUnaryFunctionForInt8 = CPUUnary::selectForInt8;
#endif
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
gCoreFunction->MNNAttenPackAndScaleSingleHead = MNNAttenPackAndScaleSingleHead;
gCoreFunction->MNNFlashAttentionUpdateBlockOutput = MNNFlashAttentionUpdateBlockOutput;
#endif
gCoreFunction->MNNReluWithSlopeChannel = MNNReluWithSlopeChannel;
gCoreFunction->MNNPoolingAvg = (decltype(gCoreFunction->MNNPoolingAvg))(poolingAvg<float, Vec4, 4>);
// Set min value as 1 << 24

View File

@ -32,12 +32,16 @@ void MNNAbsMaxFP32_Pack4(const float* source, float* absmax, size_t src_depth_qu
void MNNAbsMaxFP32_Pack8(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack);
void MNNQuantScaleFP32(float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch);
void MNNDynamicUpdateConvBiasScale(float* newbias, float* oldbias, float* weightKernelSum, float* inputZero, size_t ocQuad);
#endif
#endif // MNN_LOW_MEMORY
#ifdef MNN_SME2
void MNNPackedMatMulRemainFP32_SME2(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
#endif
#endif // __aarch64__
#endif
#endif
void MNNFp32ToFp8(uint8_t* dst, const float* src, size_t size);
void MNNFp8ToFp32(float* dst, const uint8_t* src, size_t size);
void MNNFp16ToFp8(uint8_t* dst, const uint16_t* src, size_t size);
@ -117,7 +121,7 @@ void MNNReluWithSlopeCommon(float* dst, const float* src, size_t size, float slo
void MNNHardSwishCommon(float* dst, const float* src, size_t size);
void MNNGeluCommon(float* dst, const float* src, size_t size);
void MNNGeluStandardCommon(float* dst, const float* src, size_t size);
void MNNSoftmax(float* dest, const float* source, size_t size);
void MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize);
void MNNNorm(float* dest, const float* source, const float *gamma, const float *beta, float epsilon, size_t size, bool RMSNorm = false);
// Get Pack for MatMul's e , l , h , the pack number must be 1 or 4 * n
@ -187,6 +191,8 @@ void MNNComputeMatMulForE_1(const float* A, const float* B, float* C, const floa
void MNNCopyC4Int16WithStride(const float* sourceF, float* destF, size_t srcStride, size_t dstStride, size_t count);
void MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count);
void packKvCache(float* dst, const float* src, size_t seqLen, size_t kvSeqLen, size_t eP);
struct SumByAxisParams {
ssize_t kernelCountUnitDouble;
ssize_t unitColBufferSize;
@ -401,6 +407,13 @@ struct CoreFunctions {
void(*MNNSumWeightInt8)(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP);
void(*MNNSumWeightInt8SmeHp64)(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP);
// Attention
void(*MNNAttenUnpackAndConvertFp16)(float* dst, float* src, size_t depth, size_t planesize, int pack);
void(*MNNAttenPackAndConvertFp32)(float* dst, float* src, const int32_t* units, size_t depth, size_t planesize);
void(*MNNAttenPackAndScaleSingleHead)(float* dst, const float* srcHeadBase, size_t srcRowStride, const float* scale, const int32_t* units, size_t seqLen, size_t headDim);
void(*MNNFlashAttentionUpdateBlockOutput)(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes);
void(*MNNSoftmax)(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize);
MatmulRelatedFunctions int8MatmulRelatedFunctions;
MatmulRelatedFunctions sme2Int8MatmulRelatedFuncionsHp32;
};

View File

@ -805,17 +805,16 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector<Tensor*>& input
return OUT_OF_MEMORY;
}
if (mOnlineReorderWeightSme && planeSize > 1) { // only prefill need
int dstHp = 32; // if mOnlineReorderWeightSme==true, UNIT is 64 when model loaded.
int weightlenNew = ROUND_UP(outC, dstHp) * mBlockNum * ROUND_UP(ic / mBlockNum, SRC_UNIT) * kernelCount * SRC_UNIT * dstHp;
int weightlenNew = ROUND_UP(outC, SME_DECODE_MAXHP) * mBlockNum * ROUND_UP(ic / mBlockNum, SRC_UNIT) * kernelCount;
if (mResourceInt8->mActBits == 4) {
weightlenNew /= 2;
}
mWeight4Prefill = bufferAlloc->alloc(weightlenNew + 2 * ROUND_UP(outC, dstHp) * QUANT_INFO_BYTES);
mWeight4Prefill = bufferAlloc->alloc(weightlenNew + 2 * mBlockNum * ROUND_UP(outC, SME_DECODE_MAXHP) * QUANT_INFO_BYTES);
if (mWeight4Prefill.invalid()) {
return OUT_OF_MEMORY;
}
if (mInputBlockNum > 1) { // only in this case, need to use weight_kernel_sum
mWeightKernelSum4Prefill = bufferAlloc->alloc(ROUND_UP(outC, 32) * mBlockNum * sizeof(float));
mWeightKernelSum4Prefill = bufferAlloc->alloc(ROUND_UP(outC, SME_DECODE_MAXHP) * mBlockNum * sizeof(float));
if (mWeightKernelSum4Prefill.invalid()) {
return OUT_OF_MEMORY;
}

View File

@ -63,6 +63,7 @@ bool AVX2Functions::init(int cpuFlags) {
// Dynamic Quant
coreFunction->MNNCountMaxMinValue = _AVX_MNNCountMinMaxValue;
coreFunction->MNNSoftmax = _AVX_MNNSoftmax;
// For Packed Functions
coreFunction->pack = 8;

View File

@ -24,7 +24,7 @@ struct FunctionGroup {
int lP = 1;
int hP = 4;
void (*MNNExpC8)(float* dest, const float* source, float* offset, const float* parameters, size_t countC8) = _SSE_MNNExpC8;
void (*MNNSoftmax)(float* dest, const float* source, size_t size) = _SSE_MNNSoftmax;
void (*MNNSoftmax)(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) = _SSE_MNNSoftmax;
void (*MNNReluInt8)(int8_t* dst, const int8_t* src, size_t size, ssize_t zeroPoint) = _SSE_MNNReluInt8;
void (*MNNHardSwish)(float* dst, const float* src, size_t size) = _SSE_MNNHardSwish;
void (*MNNGelu)(float* dst, const float* src, size_t size, float* parameters) = _SSE_MNNGelu;
@ -65,6 +65,8 @@ void MNNFunctionInit() {
coreFunction->MNNPackForMatMul_B = _SSE_MNNPackForMatMul_B;
// Dynamic Quant
coreFunction->MNNCountMaxMinValue = _SSE_MNNCountMinMaxValue;
coreFunction->MNNSoftmax = _SSE_MNNSoftmax;
}
#ifdef MNN_USE_AVX
if (cpuFlags & libyuv::kCpuHasAVX2) {
@ -205,8 +207,8 @@ void MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count) {
_SSE_MNNInt8ToInt16(dest, source, count);
}
void MNNSoftmax(float* dest, const float* source, size_t size) {
gFunc.MNNSoftmax(dest, source, size);
void MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) {
gFunc.MNNSoftmax(softmaxDst, input, runningMax, runningSum, updateScale, outside, reduceSize);
}
void MNNNorm(float* dest, const float* source, const float *gamma, const float *beta, float epsilon, size_t size, bool RMSNorm) {

View File

@ -53,7 +53,7 @@ void _AVX_MNNAsyQuantInfo(float* scale, float* bias, float* qscale, float* qbias
void _AVX_MNNPackC4ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el);
void _AVX_MNNExpC8(float* dest, const float* source, float* offset, const float* parameters, size_t countC8);
void _AVX_MNNSoftmax(float* dest, const float* source, size_t size);
void _AVX_MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize);
void _AVX_MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minV, ssize_t maxV, const float* zeroPoint, ssize_t quanParamVec);
void _AVX_MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* scale, size_t sizeQuad, const float* zeroPoint, ssize_t quanParamVec);
void _AVX_MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dstO, const int8_t* srcO, const int8_t* weightO, const QuanPostTreatParameters* parameters, size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, int8_t* idxOrder);

View File

@ -117,12 +117,28 @@ void _AVX_MNNExpC8(float* dest, const float* source, float* offset, const float*
}
void _AVX_MNNSoftmax(float* dest, const float* source, size_t size) {
void _AVX_MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) {
const float xLimit = 87.0f;
const float param = 0.6931471805599453f; // ln(2)
const float inv_param = 1.0f / param;
const int32_t exp_offset = 127;
const float exp_scale = 8388608.0f; // 2^23
for (int k = 0; k < outside; ++k) {
float* source = input + k * reduceSize;
float* dest = softmaxDst + k * reduceSize;
float tmpfloat8[8];
int count = size / 8;
    int count = reduceSize / 8;
int remain = count * 8;
// step 1: get maxValue
float maxValue = source[0];
float oldMax = maxValue;
if (runningMax) {
oldMax = runningMax[k];
}
if (count > 0) {
auto maxVal = _mm256_loadu_ps(source);
for (int i = 1; i < count; i++) {
@ -134,75 +150,25 @@ void _AVX_MNNSoftmax(float* dest, const float* source, size_t size) {
maxValue = maxValue > tmpfloat8[i] ? maxValue : tmpfloat8[i];
}
}
for (int i = remain; i < size; i++) {
for (int i = remain; i < reduceSize; i++) {
maxValue = maxValue > source[i] ? maxValue : source[i];
}
// step 2: get exp(x - maxValue) and sum(exp(x - maxValue))
float sumValue = 0.f;
if (count > 0) {
auto sumVal = _mm256_set1_ps(0.f);
auto p0 = _mm256_set1_ps(0.6931471805599453);
auto p1 = _mm256_set1_ps(1.4426950408889634);
auto p2 = _mm256_set1_ps(1.f);
auto p3 = _mm256_set1_ps(1.f);
auto p4 = _mm256_set1_ps(0.5);
auto p5 = _mm256_set1_ps(0.1666666666666666);
auto p6 = _mm256_set1_ps(0.041666666666666664);
auto p7 = _mm256_set1_ps(0.008333333333333333);
auto xMax = _mm256_set1_ps(87);
auto xMin = _mm256_set1_ps(-87);
auto basic = _mm256_set1_epi32(1 << 23);
auto temp127 = _mm256_set1_epi32(127);
for (int i = 0; i < count; ++i) {
auto x = _mm256_sub_ps(_mm256_loadu_ps(source + i * 8), _mm256_set1_ps(maxValue));
x = _mm256_max_ps(x, xMin);
x = _mm256_min_ps(x, xMax);
auto div = _mm256_mul_ps(x, p1);
auto divInt = _mm256_cvtps_epi32(div);
div = _mm256_cvtepi32_ps(divInt);
auto div2 = _mm256_add_epi32(divInt, temp127);
div2 = _mm256_mullo_epi32(div2, basic);
auto expBasic = _mm256_castsi256_ps(div2);
auto xReamin = _mm256_sub_ps(x, _mm256_mul_ps(div, p0));
auto t = xReamin;
auto c0 = _mm256_mul_ps(p7, t);
auto c1 = _mm256_add_ps(c0, p6);
auto c2 = _mm256_mul_ps(c1, t);
auto c3 = _mm256_add_ps(c2, p5);
auto c4 = _mm256_mul_ps(c3, t);
auto c5 = _mm256_add_ps(c4, p4);
auto c6 = _mm256_mul_ps(c5, t);
auto c7 = _mm256_add_ps(c6, p3);
auto c8 = _mm256_mul_ps(c7, t);
auto c9 = _mm256_add_ps(c8, p2);
auto expRemain = c9;
auto expRes = _mm256_mul_ps(expBasic, expRemain);
sumVal = _mm256_add_ps(expRes, sumVal);
_mm256_storeu_ps(dest + 8 * i, expRes);
}
_mm256_storeu_ps(tmpfloat8, sumVal);
for (int i = 0; i < 8; i++) {
sumValue += tmpfloat8[i];
}
}
auto param = 0.6931471805599453;
float xLimit = 87;
for (int i = remain; i < size; i++) {
auto x = source[i] - maxValue;
x = x > -xLimit ? x : -xLimit;
x = x < xLimit ? x : xLimit;
float newMax = ALIMAX(oldMax, maxValue);
int div = (x / param);
int div2 = (div + 127) << 23;
auto xReamin = x - div * param;
float expBasic = *(float*)(&div2);
// step 2: get exp(x - newMax) and sum(exp(x - newMax))
float exprOffset[4] = {1.0f, 0.0f, 0.0f, 0.0f };
exprOffset[2] = -newMax;
MNNExp(dest, source, exprOffset, reduceSize);
float sumValue = exprOffset[3];
auto t = xReamin;
auto expRemain = ((((1.0f / 120 * t + 1.0f / 24) * t + 1.0f / 6) * t + 0.5f) * t + 1.0f) * t + 1.0f;
dest[i] = expBasic * expRemain;
sumValue += dest[i];
}
if (runningMax != nullptr && runningSum != nullptr && updateScale != nullptr) {
// === Step 3: Update running variables ===
float scale = expf(oldMax - newMax);
runningSum[k] = runningSum[k] * scale + sumValue;
runningMax[k] = newMax;
updateScale[k] = scale;
} else {
// step 3: get x / sum and store
for (int i = 0; i < count; ++i) {
            // use 1 / ((1 / x) * sum) instead of x * (1 / sum) or x / sum, to work around bugs on some Intel CPUs
@ -211,9 +177,11 @@ void _AVX_MNNSoftmax(float* dest, const float* source, size_t size) {
auto z = _mm256_rcp_ps(_mm256_mul_ps(x, y));
_mm256_storeu_ps(dest + 8 * i, z);
}
sumValue = 1.f / sumValue;
for (int i = remain; i < size; i++) {
dest[i] *= sumValue;
auto scale = 1.f / sumValue;
for (int i = remain; i < reduceSize; i++) {
dest[i] *= scale;
}
}
}
}

View File

@ -48,8 +48,45 @@ void _AVX_MNNConvRunForLineDepthwise(float* dst, const float* src, const float*
size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, size_t height,
size_t srcHStep, size_t dstHStep, const float* bias, const float* parameters);
void _AVX_MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters);
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
void _AVX_MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes);
#endif
}
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
void _AVX_MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes) {
// source shape: [headDim/pack, seqLen, pack]
// scale & normalizeScale shape: [seqLen]
// dest shape: [headDim/pack, seqLen, pack]
auto stride0 = plane * pack;
if (idx > 0) {
for (int j = 0; j < depthQuad; ++j) {
for (int i = 0; i < plane; ++i) {
auto dataNew = Vec::load(src + j * stride0 + i * pack);
auto dataOld = Vec::load(dst + j * stride0 + i * pack);
auto s = Vec(scale[i]);
dataNew = Vec::fma(dataNew, dataOld, s);
Vec::save(dst + j * stride0 + i * pack, dataNew);
}
}
} else {
memcpy(dst, src, size * bytes);
}
if (idx == kvBlocks - 1) { // if last subBlock, exp(xi)/sum(exp(xi))
for (int j = 0; j < depthQuad; ++j) {
for (int i = 0; i < plane; ++i) {
auto dataNew = Vec::load(dst + j * stride0 + i * pack);
auto ns = Vec(1.0f / normalizeScale[i]);
dataNew = dataNew * ns;
Vec::save(dst + j * stride0 + i * pack, dataNew);
}
}
}
}
#endif
void _AVX_MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) {
for (int i = 0; i < count; ++i) {
@ -728,4 +765,9 @@ void _AVX_ExtraInit(void* functions) {
// sparse conv funcs
coreFunction->MNNGetSparseMatMulPackMode = _AVX_MNNGetSparseMatMulPackMode;
coreFunction->MNNAdjustOptimalSparseKernel = _AVX_MNNAdjustOptimalSparseKernel;
// attention
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
coreFunction->MNNFlashAttentionUpdateBlockOutput = _AVX_MNNFlashAttentionUpdateBlockOutput;
#endif
}

View File

@ -1064,6 +1064,39 @@ static void _AVX512_MNNAdjustOptimalSparseKernel(int& sparseBlockOC, MNN::CoreFu
}
}
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
void _AVX512_MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes) {
// source shape: [headDim/pack, seqLen, pack]
// scale & normalizeScale shape: [seqLen]
// dest shape: [headDim/pack, seqLen, pack]
auto stride0 = plane * pack;
if (idx > 0) {
for (int j = 0; j < depthQuad; ++j) {
for (int i = 0; i < plane; ++i) {
auto dataNew = Vec::load(src + j * stride0 + i * pack);
auto dataOld = Vec::load(dst + j * stride0 + i * pack);
auto s = Vec(scale[i]);
dataNew = Vec::fma(dataNew, dataOld, s);
Vec::save(dst + j * stride0 + i * pack, dataNew);
}
}
} else {
memcpy(dst, src, size * bytes);
}
if (idx == kvBlocks - 1) { // if last subBlock, exp(xi)/sum(exp(xi))
for (int j = 0; j < depthQuad; ++j) {
for (int i = 0; i < plane; ++i) {
auto dataNew = Vec::load(dst + j * stride0 + i * pack);
auto ns = Vec(1.0f / normalizeScale[i]);
dataNew = dataNew * ns;
Vec::save(dst + j * stride0 + i * pack, dataNew);
}
}
}
}
#endif
void _AVX512_ExtraInit(void* functions) {
auto coreFunction = static_cast<MNN::CoreFunctions*>(functions);
@ -1100,4 +1133,8 @@ void _AVX512_ExtraInit(void* functions) {
coreFunction->MNNGetSparseMatMulPackMode = _AVX512_MNNGetSparseMatMulPackMode;
coreFunction->MNNAdjustOptimalSparseKernel = _AVX512_MNNAdjustOptimalSparseKernel;
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
coreFunction->MNNFlashAttentionUpdateBlockOutput = _AVX512_MNNFlashAttentionUpdateBlockOutput;
#endif
}

View File

@ -81,7 +81,7 @@ void _SSE_MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count);
void _SSE_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose);
void _SSE_MNNReluInt8(int8_t* dst, const int8_t* src, size_t size, ssize_t zeroPoint);
void _SSE_MNNSoftmax(float* dest, const float* source, size_t size);
void _SSE_MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize);
void _SSE_ExtraInit(void* functions);
void _SSE_MNNNorm(float *dst, const float *src, const float *gamma, const float *beta, float epsilon, size_t size, bool RMSNorm);
void _SSE_ImageProcessInit(void* functions, int cpuFlags);

View File

@ -69,12 +69,26 @@ void _SSE_MNNExpC8(float* dest, const float* source, float* offset, const float*
offset[3] = total;
}
void _SSE_MNNSoftmax(float* dest, const float* source, size_t size) {
void _SSE_MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) {
const float xLimit = 87.0f;
const float param = 0.6931471805599453f; // ln(2)
const float inv_param = 1.0f / param;
const int32_t exp_offset = 127;
const float exp_scale = 8388608.0f; // 2^23
for (int k = 0; k < outside; ++k) {
float* source = input + k * reduceSize;
float* dest = softmaxDst + k * reduceSize;
float tmpfloat4[4];
int count = static_cast<int32_t>(size / 4);
int count = static_cast<int32_t>(reduceSize / 4);
int remain = count * 4;
// step 1: get maxValue
float maxValue = source[0];
float oldMax = maxValue;
if (runningMax) {
oldMax = runningMax[k];
}
if (count > 0) {
auto maxVal = _mm_loadu_ps(source);
for (int i = 1; i < count; i++) {
@ -85,73 +99,25 @@ void _SSE_MNNSoftmax(float* dest, const float* source, size_t size) {
maxValue = maxValue > tmpfloat4[2] ? maxValue : tmpfloat4[2];
maxValue = maxValue > tmpfloat4[3] ? maxValue : tmpfloat4[3];
}
for (int i = remain; i < size; i++) {
for (int i = remain; i < reduceSize; i++) {
maxValue = maxValue > source[i] ? maxValue : source[i];
}
// step 2: get exp(x - maxValue) and sum(exp(x - maxValue))
float sumValue = 0.f;
if (count > 0) {
auto sumVal = _mm_set1_ps(0.f);
auto p0 = _mm_set1_ps(0.6931471805599453);
auto p1 = _mm_set1_ps(1.4426950408889634);
auto p2 = _mm_set1_ps(1.f);
auto p3 = _mm_set1_ps(1.f);
auto p4 = _mm_set1_ps(0.5);
auto p5 = _mm_set1_ps(0.1666666666666666);
auto p6 = _mm_set1_ps(0.041666666666666664);
auto p7 = _mm_set1_ps(0.008333333333333333);
auto xMax = _mm_set1_ps(87);
auto xMin = _mm_set1_ps(-87);
// auto basic = _mm_set1_epi32(1 << 23);
for (int i = 0; i < count; ++i) {
auto x = _mm_sub_ps(_mm_loadu_ps(source + i * 4), _mm_set1_ps(maxValue));
x = _mm_max_ps(x, xMin);
x = _mm_min_ps(x, xMax);
auto div = _mm_mul_ps(x, p1);
auto divInt = _mm_cvtps_epi32(div);
div = _mm_cvtepi32_ps(divInt);
auto div2 = _mm_add_epi32(divInt, _mm_set1_epi32(127));
// div2 = _mm_mullo_epi32(div2, basic);
div2 = _mm_slli_epi32(div2, 23);
auto expBasic = _mm_castsi128_ps(div2);
auto xReamin = _mm_sub_ps(x, _mm_mul_ps(div, p0));
auto t = xReamin;
auto c0 = _mm_mul_ps(p7, t);
auto c1 = _mm_add_ps(c0, p6);
auto c2 = _mm_mul_ps(c1, t);
auto c3 = _mm_add_ps(c2, p5);
auto c4 = _mm_mul_ps(c3, t);
auto c5 = _mm_add_ps(c4, p4);
auto c6 = _mm_mul_ps(c5, t);
auto c7 = _mm_add_ps(c6, p3);
auto c8 = _mm_mul_ps(c7, t);
auto c9 = _mm_add_ps(c8, p2);
auto expRemain = c9;
auto expRes = _mm_mul_ps(expBasic, expRemain);
sumVal = _mm_add_ps(expRes, sumVal);
_mm_storeu_ps(dest + 4 * i, expRes);
}
_mm_storeu_ps(tmpfloat4, sumVal);
sumValue = tmpfloat4[0] + tmpfloat4[1] + tmpfloat4[2] + tmpfloat4[3];
}
auto param = 0.6931471805599453;
float xLimit = 87;
for (int i = remain; i < size; i++) {
auto x = source[i] - maxValue;
x = x > -xLimit ? x : -xLimit;
x = x < xLimit ? x : xLimit;
float newMax = ALIMAX(oldMax, maxValue);
int div = (x / param);
int div2 = (div + 127) << 23;
auto xReamin = x - div * param;
float expBasic = *(float*)(&div2);
// step 2: get exp(x - newMax) and sum(exp(x - newMax))
float exprOffset[4] = {1.0f, 0.0f, 0.0f, 0.0f };
exprOffset[2] = -newMax;
MNNExp(dest, source, exprOffset, reduceSize);
float sumValue = exprOffset[3];
auto t = xReamin;
auto expRemain = ((((1.0f / 120 * t + 1.0f / 24) * t + 1.0f / 6) * t + 0.5f) * t + 1.0f) * t + 1.0f;
dest[i] = expBasic * expRemain;
sumValue += dest[i];
}
if (runningMax != nullptr && runningSum != nullptr && updateScale != nullptr) {
// === Step 3: Update running variables ===
float scale = expf(oldMax - newMax);
runningSum[k] = runningSum[k] * scale + sumValue;
runningMax[k] = newMax;
updateScale[k] = scale;
} else {
// step 3: get x / sum and store
for (int i = 0; i < count; ++i) {
            // use 1 / ((1 / x) * sum) instead of x * (1 / sum) or x / sum, to work around bugs on some Intel CPUs
@ -160,9 +126,11 @@ void _SSE_MNNSoftmax(float* dest, const float* source, size_t size) {
auto z = _mm_rcp_ps(_mm_mul_ps(x, y));
_mm_storeu_ps(dest + 4 * i, z);
}
sumValue = 1.f / sumValue;
for (int i = remain; i < size; i++) {
dest[i] *= sumValue;
auto scale = 1.f / sumValue;
for (int i = remain; i < reduceSize; i++) {
dest[i] *= scale;
}
}
}
}

View File

@ -360,8 +360,6 @@ bool BufferConvertor::convertToNC4HW4Buffer(const Tensor *buffer, const OpenCLBu
default:
break;
}
if (mBufferToImageKernel.get() == nullptr || mBufferToImageKernelName != kernelName) {
mBufferToImageKernelName = kernelName;
std::set<std::string> buildOptions;
if(needTrans) {
//buildOptions.emplace("-DBUFFER_FORMAT_INP_TRANS");
@ -379,7 +377,6 @@ bool BufferConvertor::convertToNC4HW4Buffer(const Tensor *buffer, const OpenCLBu
}
#endif
mBufferToImageKernel = runtime->buildKernelWithCache(kernelFile, kernelName, buildOptions, precision, buffer, image);
}
auto kernel = mBufferToImageKernel->get();
uint32_t idx = 0;

View File

@ -43,7 +43,7 @@ public:
explicit BufferConvertor(OpenCLRuntime *opencl_runtime) : mOpenCLRuntime(opencl_runtime) {
}
bool convertToNC4HW4Buffer(const Tensor *input, const OpenCLBufferFormat type, Tensor *output, int precision,
bool needTrans, bool needWait = false, bool lowMemory = false, int quantBit = 0);
bool needTrans, bool needWait = true, bool lowMemory = false, int quantBit = 0);
private:
OpenCLRuntime *mOpenCLRuntime;

View File

@ -48,6 +48,11 @@ CLRuntime::CLRuntime(const Backend::Info& info){
mMemory = mInfo.user->memory;
}
    // protect against out-of-range precision values
if(mPrecision > 2 || mPrecision < 0){
mPrecision = BackendConfig::Precision_High;
}
mOpenCLRuntime.reset(new OpenCLRuntime(platform_size, platform_id, device_id, context_ptr, hint()));
//Whether runtimeError
@ -206,6 +211,10 @@ Backend* CLRuntime::onCreate(const BackendConfig* config, Backend* origin) const
precision = config->precision;
memory = config->memory;
}
    // protect against out-of-range precision values
if(precision > 2 || precision < 0){
precision = BackendConfig::Precision_High;
}
auto backend = new OpenCLBackend(precision, memory, mInfo.gpuMode, mImagePool, mBufferPool, this);
backend->setMetaPtr(pMeta);
return backend;
@ -246,6 +255,10 @@ OpenCLBackend::OpenCLBackend(BackendConfig::PrecisionMode precision, BackendConf
} else{
mPrecision = BackendConfig::Precision_High;
}
    // protect against out-of-range precision values
if(mPrecision > 2 || mPrecision < 0){
mPrecision = BackendConfig::Precision_High;
}
mMemory = memory;
// set tuneLevel, memtype, record mode
setGpuMode(gpuMode);

View File

@ -18,6 +18,8 @@ struct QnnContext {
Qnn_LogHandle_t logHandle = nullptr;
Qnn_BackendHandle_t backendHandle = nullptr;
Qnn_DeviceHandle_t deviceHandle = nullptr;
int soc_id;
int dsp_arch;
};
static QnnContext gContext;
@ -94,6 +96,7 @@ bool PluginShapeRaw::compute(InferShapeContext* ctx) {
auto dst = ctx->output(i);
std::string key = prefix + std::to_string(i);
auto attr = ctx->getAttr(key.c_str());
if (nullptr == attr || nullptr == attr->tensor()) {
MNN_ERROR("MNN_QNN: Failed to find raw shape %s.\n", key.c_str());
return false;
@ -886,7 +889,12 @@ ErrorCode QnnBackend::onResizeEnd() {
#ifdef QNN_VERBOSE
MNN_PRINT("start finalize\n");
#endif
buildOutputDequant();
finalizeGraph();
for(auto func : mReleaseFunc){
func();
}
mReleaseFunc.clear();
#ifdef QNN_VERBOSE
MNN_PRINT("end finalize\n");
#endif
@ -915,7 +923,14 @@ Backend::MemObj* QnnBackend::onAcquire(const Tensor* tensor, StorageType storage
}
Qnn_DataType_t tDataType;
MNN_ASSERT((tensor->getType().code == halide_type_float) || (tensor->getType().code == halide_type_int && tensor->getType().bits == 32));
Qnn_QuantizeParams_t tQuantizeParams{};
tQuantizeParams.encodingDefinition = QNN_DEFINITION_UNDEFINED;
tQuantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED;
Qnn_ScaleOffset_t tScaleOffsetEncoding;
tScaleOffsetEncoding.scale = 0.0f;
tScaleOffsetEncoding.offset = 0;
auto quant = TensorUtils::getDescribe(tensor)->quantAttr.get();
//MNN_ASSERT((tensor->getType().code == halide_type_float) || (tensor->getType().code == halide_type_int && tensor->getType().bits == 32));
if (mUseFP16 && tensor->getType().code == halide_type_float) {
tDataType = QNN_DATATYPE_FLOAT_16;
} else if (tensor->getType().code == halide_type_float) {
@ -926,13 +941,19 @@ Backend::MemObj* QnnBackend::onAcquire(const Tensor* tensor, StorageType storage
MNN_PRINT("MNN_QNN: Not supported data type in <QnnBackend::onAcquire>.\n");
return nullptr;
}
Qnn_QuantizeParams_t tQuantizeParams{};
tQuantizeParams.encodingDefinition = QNN_DEFINITION_UNDEFINED;
tQuantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED;
Qnn_ScaleOffset_t tScaleOffsetEncoding;
tScaleOffsetEncoding.scale = 0.0f;
tScaleOffsetEncoding.offset = 0;
if(quant != nullptr && TensorUtils::getDescribe(tensor)->type == DataType_DT_INT8) {
tQuantizeParams.encodingDefinition = QNN_DEFINITION_DEFINED;
tQuantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET;
tScaleOffsetEncoding.scale = quant->scale;
if(quant->zero != 0){
MNN_PRINT("MNN_QNN: Not supported asymmetric quant in <QnnBackend::onAcquire>.\n");
return nullptr;
}
tDataType = QNN_DATATYPE_SFIXED_POINT_8;
if (isOutput) {
tType = QNN_TENSOR_TYPE_NATIVE;
}
}
tQuantizeParams.scaleOffsetEncoding = tScaleOffsetEncoding;
std::unique_ptr<Tensor> tempTensor(new Tensor(tensor, gQnnTensorDimType, false));
@ -947,17 +968,41 @@ Backend::MemObj* QnnBackend::onAcquire(const Tensor* tensor, StorageType storage
Qnn_Tensor_t * qnnTensor = qnnTensorWrapper->getNativeTensor();
CALL_QNN(mRuntime->mQnnInterface.tensorCreateGraphTensor(mQnnGraphHandle, qnnTensor));
mQNNTensorWrappers.push_back(qnnTensorWrapper);
mTensorMap.insert({TensorUtils::getDescribe(tensor), mTensorCounter});
if (isInput) {
mInputTensorIndexes.push_back(mTensorCounter);
qnnTensorWrapper->alloc();
}
if (isOutput) {
if(quant != nullptr && TensorUtils::getDescribe(tensor)->type == DataType_DT_INT8){
mTensorCounter += 1;
std::shared_ptr<Tensor> stageTensor;
stageTensor.reset(Tensor::create<float>(tensor->shape(), nullptr, gQnnTensorDimType));
tName = "QnnTensor_" + std::to_string(mTensorCounter);
tType = QNN_TENSOR_TYPE_APP_READ;
if (mUseFP16 && tensor->getType().code == halide_type_float) {
tDataType = QNN_DATATYPE_FLOAT_16;
} else if (tensor->getType().code == halide_type_float) {
tDataType = QNN_DATATYPE_FLOAT_32;
} else {
MNN_PRINT("MNN_QNN: Not supported data type in <QnnBackend::onAcquire>.\n");
return nullptr;
}
Qnn_QuantizeParams_t tQuantizeParamstmp = QNN_QUANTIZE_PARAMS_INIT;
std::shared_ptr<QNNTensorWrapper> qnnOutputTensorWrapper = QNNTensorWrapper::create(tName, tType, tDataType, tDims, tQuantizeParamstmp);
CALL_QNN(mRuntime->mQnnInterface.tensorCreateGraphTensor(mQnnGraphHandle, qnnOutputTensorWrapper->getNativeTensor()));
mDeQuantOutputTensorMap.insert({TensorUtils::getDescribe(tensor), {tensor, stageTensor}});
mQNNTensorWrappers.push_back(qnnOutputTensorWrapper);
mTensorMap.insert({TensorUtils::getDescribe(const_cast<const Tensor*>(stageTensor.get())), mTensorCounter});
mOutputTensorIndexes.push_back(mTensorCounter);
qnnOutputTensorWrapper->alloc();
} else{
mOutputTensorIndexes.push_back(mTensorCounter);
qnnTensorWrapper->alloc();
}
mQNNTensorWrappers.push_back(qnnTensorWrapper);
mTensorMap.insert({TensorUtils::getDescribe(tensor), mTensorCounter});
}
mTensorCounter += 1;
#ifdef QNN_VERBOSE
@ -1029,7 +1074,13 @@ void QnnBackend::inputIO(const Tensor* srcTensor, const Tensor* dstTensor) const
}
void QnnBackend::outputIO(const Tensor* srcTensor, const Tensor* dstTensor) const {
int srcIndex = getTensorIdx(srcTensor);
auto iter = mDeQuantOutputTensorMap.find(TensorUtils::getDescribe(srcTensor));
int srcIndex = -1;
if(iter != mDeQuantOutputTensorMap.end()){
srcIndex = getTensorIdx(iter->second.second.get());
} else{
srcIndex = getTensorIdx(srcTensor);
}
std::shared_ptr<QNNTensorWrapper> srcQnnTensorWrapper = mQNNTensorWrappers[srcIndex];
std::shared_ptr<Tensor> srcDataContainer = srcQnnTensorWrapper->getDataContainer();
@ -1091,7 +1142,7 @@ void QnnBackend::finalizeGraph() {
// set QNN_PROFILE_LEVEL_DETAILED
QnnProfile_Level_t profileLevel = QNN_PROFILE_LEVEL_DETAILED;
MNN_PRINT("[QNN Profile] Creating QNN Profile Handle with DETAILED level.\n");
auto profile_err = mRuntime->mQnnInterface.profileCreate(mQnnContextHandle, profileLevel, &mQnnProfileHandle);
auto profile_err = mRuntime->mQnnInterface.profileCreate(mRuntime->mQnnContextHandle, profileLevel, &mQnnProfileHandle);
if (profile_err != QNN_SUCCESS || mQnnProfileHandle == nullptr) {
MNN_ERROR("[QNN Profile] Failed to create QNN Profile Handle, error: %d\n", (int)profile_err);
mQnnProfileHandle = nullptr;
@ -1216,6 +1267,27 @@ void QnnBackend::clean() {
mTensorMap.clear();
mInputTensorIndexes.clear();
mOutputTensorIndexes.clear();
mDeQuantOutputTensorMap.clear();
}
void QnnBackend::buildOutputDequant(){
Qnn_OpConfigVersion_t mOpConfigVersion = QNN_OPCONFIG_VERSION_1;
std::string mNodeName;
std::string mPackageName = "qti.aisw";
std::string mNodeType;
std::vector<Qnn_Param_t> mParams;
std::vector<Qnn_Tensor_t> mInputs;
std::vector<Qnn_Tensor_t> mOutputs;
for(auto iter : mDeQuantOutputTensorMap){
mNodeType.clear();
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = "Dequantize";
std::string name = "Dequantize_I_" + std::to_string(getTensorIdx(iter.second.first)) + "_O_" + std::to_string(getTensorIdx(iter.second.second.get()));;
mInputs.push_back(*(getNativeTensor(iter.second.first))); // input
mOutputs.push_back(*(getNativeTensor(iter.second.second.get()))); // output
addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
}
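buildOutputDequant appends a Dequantize node for every int8 graph output so the application keeps receiving float data: the staged float tensor created in onAcquire becomes the tensor QNN actually reads back, while the int8 tensor stays internal to the graph. Because onAcquire rejects non-zero zero points, the conversion the node performs amounts to a plain symmetric per-tensor scale; reference math only, the real work happens inside the QNN graph:

```cpp
#include <cstdint>
static inline float dequantSymmetric(int8_t q, float scale) {
    return scale * static_cast<float>(q); // zero point is forced to 0 by onAcquire
}
```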
QnnRuntime::QnnRuntime(const Backend::Info& info, QNN_INTERFACE_VER_TYPE qnnInterface, Qnn_LogHandle_t qnnLogHandle, Qnn_BackendHandle_t qnnBackendHandle, Qnn_DeviceHandle_t qnnDeviceHandle) {
@ -1324,12 +1396,23 @@ bool QnnRuntime::registerCustomOpPackage(QNN_INTERFACE_VER_TYPE qnnInterface, Qn
class QnnRuntimeCreator : public RuntimeCreator {
public:
virtual Runtime* onCreate(const Backend::Info& info) const {
virtual Runtime* onCreate(const Backend::Info& info) const override {
return QnnRuntime::create(info);
}
virtual bool onValid(Backend::Info& info) const {
virtual bool onValid(Backend::Info& info) const override {
return true;
}
virtual bool onGetDeviceInfo(const std::string& deviceKey, std::string& deviceValue) const override {
if(deviceKey == "soc_id" && gContext.soc_id != 0) {
deviceValue = std::to_string(gContext.soc_id);
return true;
}
if(deviceKey == "dsp_arch" && gContext.dsp_arch != 0) {
deviceValue = "v" + std::to_string(gContext.dsp_arch);
return true;
}
return false;
}
};
} // end namespace QNN
@ -1401,6 +1484,8 @@ void registerQNNRuntimeCreator() {
// Create Device.
Qnn_DeviceHandle_t deviceHandle = nullptr;
QnnHtpDevice_Arch_t dspArch = QNN_HTP_DEVICE_ARCH_NONE;
uint32_t socId = 0;
{
// Check whether the device API is supported.
bool supportDevice = QNN::checkCapability(qnnInterface, QNN_PROPERTY_GROUP_DEVICE);
@ -1422,10 +1507,8 @@ void registerQNNRuntimeCreator() {
MNN_PRINT("[Warning]: deviceGetPlatformInfo Failed to query platform info");
} else {
QnnDevice_HardwareDeviceInfo_t* hwDeviceInfo = backendPlatformInfoPtr->v1.hwDevices;
QnnHtpDevice_Arch_t arch = hwDeviceInfo->v1.deviceInfoExtension->onChipDevice.arch;
uint32_t socModel = hwDeviceInfo->v1.deviceInfoExtension->onChipDevice.socModel;
MNN_PRINT("Qnn Device soc_id: %d\n", socModel);
MNN_PRINT("Qnn Device dsp_arch: v%d\n", arch);
dspArch = hwDeviceInfo->v1.deviceInfoExtension->onChipDevice.arch;
socId = hwDeviceInfo->v1.deviceInfoExtension->onChipDevice.socModel;
}
}
} else {
@ -1478,6 +1561,8 @@ void registerQNNRuntimeCreator() {
QNN::gContext.backendHandle = backendHandle;
QNN::gContext.deviceHandle = deviceHandle;
QNN::gContext.logHandle = logHandle;
QNN::gContext.soc_id = socId;
QNN::gContext.dsp_arch = dspArch;
QNN::registerQNNOps();
MNNInsertExtraRuntimeCreator(MNN_FORWARD_NN, new QNN::QnnRuntimeCreator, false);

View File

@ -78,6 +78,10 @@ public:
std::shared_ptr<QNNTensorWrapper> getTensorWrapper(const Tensor * tensor);
bool useCache() const;
bool getUseFP16() const;
void buildOutputDequant();
void pushReleaseFunc(std::function<void()> func){
mReleaseFunc.push_back(func);
}
private:
void clean();
@ -109,8 +113,10 @@ private:
mutable int mTensorCounter = 0;
mutable std::vector<std::shared_ptr<QNNTensorWrapper>> mQNNTensorWrappers;
mutable std::map<const Tensor::InsideDescribe::NativeInsideDescribe *, int> mTensorMap;
mutable std::map<const Tensor::InsideDescribe::NativeInsideDescribe *, std::pair<const Tensor*, std::shared_ptr<Tensor>>> mDeQuantOutputTensorMap;
std::vector<int> mInputTensorIndexes;
std::vector<int> mOutputTensorIndexes;
std::vector<std::function<void()>> mReleaseFunc;
};

View File

@ -134,6 +134,8 @@ void registerQNNOps() {
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
___QNNAttentionCreator__OpType_Attention__();
#endif
___QNNQuantCreator__OpType_FloatToInt8__();
___QNNDeQuantCreator__OpType_Int8ToFloat__();
}
Tensor::DimensionType gQnnTensorDimType = Tensor::TENSORFLOW;

View File

@ -104,6 +104,8 @@ extern void ___QNNMatMulCreator__OpType_MatMul__();
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
extern void ___QNNAttentionCreator__OpType_Attention__();
#endif
extern void ___QNNQuantCreator__OpType_FloatToInt8__();
extern void ___QNNDeQuantCreator__OpType_Int8ToFloat__();
void registerQNNOps();
extern Tensor::DimensionType gQnnTensorDimType;

View File

@ -25,7 +25,7 @@ std::shared_ptr<QNNTensorWrapper> QNNTensorWrapper::create(const std::string & n
std::shared_ptr<QNNTensorWrapper> QNNTensorWrapper::createStaticTensor(const std::string & name, Qnn_DataType_t dataType, const std::vector<uint32_t> & dimensions, const void * buffer, Qnn_QuantizeParams_t quantizeParam) {
MNN_ASSERT(!name.empty() && !dimensions.empty() && buffer);
MNN_ASSERT(dataType == QNN_DATATYPE_SFIXED_POINT_8 || dataType == QNN_DATATYPE_INT_32);
MNN_ASSERT(dataType == QNN_DATATYPE_SFIXED_POINT_8 || dataType == QNN_DATATYPE_INT_32 || dataType == QNN_DATATYPE_SFIXED_POINT_32 || dataType == QNN_DATATYPE_UFIXED_POINT_8);
std::shared_ptr<QNNTensorWrapper> tensorWrapper = QNNTensorWrapper::create(name, QNN_TENSOR_TYPE_STATIC, dataType, dimensions, quantizeParam);
uint32_t numElement = 1;
@ -114,7 +114,8 @@ void * QNNTensorWrapper::alloc() {
dims[i] = (int)mDimensions[i];
}
MNN_ASSERT(mQnnTensor.v1.dataType == QNN_DATATYPE_FLOAT_32 || mQnnTensor.v1.dataType == QNN_DATATYPE_FLOAT_16 || mQnnTensor.v1.dataType == QNN_DATATYPE_INT_32 || mQnnTensor.v1.dataType == QNN_DATATYPE_SFIXED_POINT_8);
MNN_ASSERT(mQnnTensor.v1.dataType == QNN_DATATYPE_FLOAT_32 || mQnnTensor.v1.dataType == QNN_DATATYPE_FLOAT_16 || mQnnTensor.v1.dataType == QNN_DATATYPE_INT_32 || mQnnTensor.v1.dataType == QNN_DATATYPE_SFIXED_POINT_8 || mQnnTensor.v1.dataType == QNN_DATATYPE_SFIXED_POINT_32
|| mQnnTensor.v1.dataType == QNN_DATATYPE_UFIXED_POINT_8);
halide_type_t halideType;
halideType.lanes = 1;
@ -134,6 +135,15 @@ void * QNNTensorWrapper::alloc() {
case QNN_DATATYPE_SFIXED_POINT_8:
halideType.code = halide_type_int;
halideType.bits = 8;
break;
case QNN_DATATYPE_SFIXED_POINT_32:
halideType.code = halide_type_int;
halideType.bits = 32;
break;
case QNN_DATATYPE_UFIXED_POINT_8:
halideType.code = halide_type_int;
halideType.bits = 8;
break;
default:
break;
}

View File

@ -269,6 +269,10 @@ std::vector<std::string> QNNTranslator::TranslateTensor(const QNNCommandTensor&
if (isParam) {
result.push_back(QNNTranslator::TranslateParamDataArray(dataNameSymbol, cmdT.dataType, cmdT.clientBuf));
}
if(hasQuant){
std::vector<std::string> linesQuantScaleOffset = TranslateQuantizeScaleOffsetDataArray(tensorNameSymbol, cmdT.quantizeParams, cmdT.rank, cmdT.dimensions);
APPEND_VECTOR(result, linesQuantScaleOffset);
}
result.push_back(" Qnn_Tensor_t " + tensorNameSymbol + " = QNN_TENSOR_INIT;");
result.push_back(" {");
result.push_back(" " + tensorNameSymbol + ".version = QNN_TENSOR_VERSION_1;");
@ -456,11 +460,122 @@ std::string QNNTranslator::TranslateParamDataArray(const std::string & dataNameS
return result;
}
std::vector<std::string> QNNTranslator::TranslateQuantizeScaleOffsetDataArray(const std::string & tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParams, uint32_t rank, const uint32_t * dimensions){
std::vector<std::string> result;
if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET){
result.push_back(" Qnn_ScaleOffset_t " + tensorNameSymbol + "_axis_scale_offset[] = {");
int totalnum = (quantizeParams.axisScaleOffsetEncoding.numScaleOffsets + 3) / 4;
for(int i = 0; i < totalnum; ++i){
std::string line = " ";
for(int j = 0; j < 4; ++j){
int index = i * 4 + j;
if(index >= quantizeParams.axisScaleOffsetEncoding.numScaleOffsets)
break;
line += "{.scale= " + std::to_string(quantizeParams.axisScaleOffsetEncoding.scaleOffset[index].scale) + ", .offset= " + std::to_string(quantizeParams.axisScaleOffsetEncoding.scaleOffset[index].offset) + "}, ";
}
result.push_back(line);
}
result.push_back(" };");
}
if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET){
result.push_back(" float " + tensorNameSymbol + "_bwaxis_scale[] = {");
int totalnum = (quantizeParams.bwAxisScaleOffsetEncoding.numElements + 3) / 4;
for(int i = 0; i < totalnum; ++i){
std::string line = " ";
for(int j = 0; j < 4; ++j){
int index = i * 4 + j;
if(index >= quantizeParams.bwAxisScaleOffsetEncoding.numElements)
break;
line += std::to_string(quantizeParams.bwAxisScaleOffsetEncoding.scales[index]) + ", ";
}
result.push_back(line);
}
result.push_back(" };");
if(quantizeParams.bwAxisScaleOffsetEncoding.offsets != nullptr){
result.push_back(" int32_t " + tensorNameSymbol + "_bwaxis_offset[] = {");
for(int i = 0; i < totalnum; ++i){
std::string line = " ";
for(int j = 0; j < 4; ++j){
int index = i * 4 + j;
if(index >= quantizeParams.bwAxisScaleOffsetEncoding.numElements)
break;
line += std::to_string(quantizeParams.bwAxisScaleOffsetEncoding.offsets[index]) + ", ";
}
result.push_back(line);
}
result.push_back(" };");
}
}
if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION){
int axis = quantizeParams.blockwiseExpansion->axis;
int oc = dimensions[axis];
int blockSize = quantizeParams.blockwiseExpansion->numBlocksPerAxis;
result.push_back(" Qnn_BlockwiseExpansion_t " + tensorNameSymbol + "_blockwiseExpansion = QNN_BLOCKWISE_EXPANSION_INIT;");
result.push_back(" Qnn_ScaleOffset_t " + tensorNameSymbol + "_blockwiseExpansionScaleOffset[] = {");
int totalnum = (oc + 3) / 4;
for(int i = 0; i < totalnum; ++i){
std::string line = " ";
for(int j = 0; j < 4; ++j){
int index = i * 4 + j;
if(index >= oc)
break;
line += "{.scale= " + std::to_string(quantizeParams.blockwiseExpansion->scaleOffsets[index].scale) + ", .offset= " + std::to_string(quantizeParams.blockwiseExpansion->scaleOffsets[index].offset) + "}, ";
}
result.push_back(line);
}
result.push_back(" };");
if(quantizeParams.blockwiseExpansion->blockScaleStorageType == QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8){
result.push_back(" uint8_t " + tensorNameSymbol + "_blockwiseExpansionBlockScale[] = {");
totalnum = (oc * blockSize + 3) / 4;
for(int i = 0; i < totalnum; ++i){
std::string line = " ";
for(int j = 0; j < 4; ++j){
int index = i * 4 + j;
if(index >= oc * blockSize)
break;
line += std::to_string(quantizeParams.blockwiseExpansion->blocksScale8[index]) + ", ";
}
result.push_back(line);
}
result.push_back(" };");
}else{
result.push_back(" uint16_t " + tensorNameSymbol + "_blockwiseExpansionBlockScale[] = {");
totalnum = (oc * blockSize + 3) / 4;
for(int i = 0; i < totalnum; ++i){
std::string line = " ";
for(int j = 0; j < 4; ++j){
int index = i * 4 + j;
if(index >= oc * blockSize)
break;
line += std::to_string(quantizeParams.blockwiseExpansion->blocksScale16[index]) + ", ";
}
result.push_back(line);
}
result.push_back(" };");
}
result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.axis = " + std::to_string(quantizeParams.blockwiseExpansion->axis) + ";");
result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.scaleOffsets = " + tensorNameSymbol + "_blockwiseExpansionScaleOffset;");
result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.numBlocksPerAxis = " + std::to_string(quantizeParams.blockwiseExpansion->numBlocksPerAxis) + ";");
result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.blockScaleBitwidth = " + std::to_string(quantizeParams.blockwiseExpansion->blockScaleBitwidth) + ";");
if(quantizeParams.blockwiseExpansion->blockScaleStorageType == QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8){
result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.blockScaleStorageType = QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8;");
result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.blocksScale8 = " + tensorNameSymbol + "_blockwiseExpansionBlockScale;");
}else{
result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.blockScaleStorageType = QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_16;");
result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.blocksScale16 = " + tensorNameSymbol + "_blockwiseExpansionBlockScale;");
}
}
return result;
}
// Currently support QNN_QUANTIZATION_ENCODING_UNDEFINED, SCALE_OFFSET, AXIS_SCALE_OFFSET, BW_AXIS_SCALE_OFFSET and BLOCKWISE_EXPANSION.
std::vector<std::string> QNNTranslator::TranslateTensorQuantizeParams(const std::string tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParmas) {
std::vector<std::string> QNNTranslator::TranslateTensorQuantizeParams(const std::string tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParams) {
std::vector<std::string> result;
if (quantizeParmas.encodingDefinition == QNN_DEFINITION_UNDEFINED) {
if (quantizeParams.encodingDefinition == QNN_DEFINITION_UNDEFINED) {
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.encodingDefinition = QNN_DEFINITION_UNDEFINED;");
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED;");
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.scaleOffsetEncoding.scale = 0.0f;");
@ -468,14 +583,43 @@ std::vector<std::string> QNNTranslator::TranslateTensorQuantizeParams(const std:
return result;
}
if (quantizeParmas.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParmas.quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) {
if (quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) {
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.encodingDefinition = QNN_DEFINITION_DEFINED;");
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET;");
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.scaleOffsetEncoding.scale = " + std::to_string(quantizeParmas.scaleOffsetEncoding.scale) + ";");
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.scaleOffsetEncoding.offset = " + std::to_string(quantizeParmas.scaleOffsetEncoding.offset) + ";");
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.scaleOffsetEncoding.scale = " + std::to_string(quantizeParams.scaleOffsetEncoding.scale) + ";");
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.scaleOffsetEncoding.offset = " + std::to_string(quantizeParams.scaleOffsetEncoding.offset) + ";");
return result;
}
if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET){
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.encodingDefinition = QNN_DEFINITION_DEFINED;");
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET;");
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.axisScaleOffsetEncoding.axis = " + std::to_string(quantizeParams.axisScaleOffsetEncoding.axis) + ";");
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.axisScaleOffsetEncoding.numScaleOffsets = " + std::to_string(quantizeParams.axisScaleOffsetEncoding.numScaleOffsets) + ";");
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.axisScaleOffsetEncoding.scaleOffset = " + tensorNameSymbol + "_axis_scale_offset;");
return result;
}
if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET){
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.encodingDefinition = QNN_DEFINITION_DEFINED;");
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET;");
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.bwAxisScaleOffsetEncoding.axis = " + std::to_string(quantizeParams.bwAxisScaleOffsetEncoding.axis) + ";");
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.bwAxisScaleOffsetEncoding.bitwidth = " + std::to_string(quantizeParams.bwAxisScaleOffsetEncoding.bitwidth) + ";");
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.bwAxisScaleOffsetEncoding.numElements = " + std::to_string(quantizeParams.bwAxisScaleOffsetEncoding.numElements) + ";");
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.bwAxisScaleOffsetEncoding.scales = " + tensorNameSymbol + "_bwaxis_scale;");
if(quantizeParams.bwAxisScaleOffsetEncoding.offsets != nullptr)
            result.push_back("    " + tensorNameSymbol + ".v1.quantizeParams.bwAxisScaleOffsetEncoding.offsets = " + tensorNameSymbol + "_bwaxis_offset;");
return result;
}
if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION){
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.encodingDefinition = QNN_DEFINITION_DEFINED;");
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION;");
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.blockwiseExpansion = &" + tensorNameSymbol + "_blockwiseExpansion;");
return result;
}
MNN_ERROR("MNN_QNN: Unknown QuantizeParams.\n");
return result;

View File

@ -69,7 +69,8 @@ private:
static std::string MapDataType(Qnn_DataType_t dataType);
static std::string TranslateDimensionsArray(const std::string & dimensionsNameSymbol, uint32_t rank, const uint32_t * dimensions);
static std::string TranslateParamDataArray(const std::string & dataNameSymbol, Qnn_DataType_t dataType, const Qnn_ClientBuffer_t & clientBuf);
static std::vector<std::string> TranslateTensorQuantizeParams(const std::string tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParmas);
static std::vector<std::string> TranslateQuantizeScaleOffsetDataArray(const std::string & tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParams, uint32_t rank, const uint32_t * dimensions);
static std::vector<std::string> TranslateTensorQuantizeParams(const std::string tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParams);
static std::vector<std::string> TranslateTensorClientBuf(const std::string & tensorNameSymbol, const std::string & dataNameSymbol, const std::string & sname, const Qnn_ClientBuffer_t & clientBuf, bool hasClientBuf, bool isParam);
// Utility functions used by TranslateNode.
static std::vector<std::string> TranslateNodeParamArray(const std::string & nodeName,const std::string & paramArraySymbol, uint32_t numOfParams, const Qnn_Param_t * params);

View File

@ -11,6 +11,157 @@
namespace MNN {
namespace QNN {
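// Decide whether this depthwise convolution can feed its quantized weights to QNN directly.
// Weight quantization is disabled when no quanParameter is stored, when the stored weights are
// asymmetric, when the activation stays in floating point, or when a non-zero bias meets a
// non-zero input offset (the bias cannot be folded correctly in that case).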
void QNNConvDepthwise::isWeightQuantSupported(const Tensor *input, const int oc){
Qnn_DataType_t dataType = mBackend->getNativeTensor(input)->v1.dataType;
if(mOp->main_as_Convolution2D()->quanParameter() == nullptr){
mWeightQuant = false;
return;
}else{
        bool hasBias = false;
auto bias = mOp->main_as_Convolution2D()->bias();
auto biasPtr = (float*)bias->data();
for(int i = 0; i < oc; ++i){
if(biasPtr[i] != 0.0f){
                hasBias = true;
break;
}
}
std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon = ConvolutionCommon::load(mOp, this->backend(), false, true);
if(quanCommon->asymmetric || dataType == QNN_DATATYPE_FLOAT_16 || dataType == QNN_DATATYPE_FLOAT_32){
            // Asymmetric quantization is not supported; float activations keep the float-weight path.
mWeightQuant = false;
return;
}
float inputScale = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.scale;
int inputOffset = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.offset;
if(inputOffset == 0){
mWeightQuant = true;
}else{
            if(hasBias){
mWeightQuant = false;
}else{
mWeightQuant = true;
}
}
}
}
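// Fallback path when weight quantization is not used on a quantized graph: dequantize the int8
// input to fp16/fp32, run the float depthwise convolution (plus optional ReLU/ReLU6), then cast
// if needed and re-quantize the result into the int8 output tensor.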
ErrorCode QNNConvDepthwise::onEncodeQuantDequantDepthConv(Tensor *input, Tensor *output, const int n, const int ic, const int oc) {
auto conv2D = mOp->main_as_Convolution2D();
auto common = conv2D->common();
Qnn_DataType_t dataType = QNN_DATATYPE_FLOAT_32;
if(mBackend->getUseFP16()){
dataType = QNN_DATATYPE_FLOAT_16;
}
    // create the dequant-input and quant-output stage tensors
this->createStageTensor("DequantInput", dataType, getNHWCShape(input)); // mTempTensorWrappers[2]
this->createStageTensor("QuantOutput", dataType, getNHWCShape(output)); // mTempTensorWrappers[3]
// add nodes
{
// dequant input
{
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = "Dequantize";
std::string name = mNodeName + "_dequant_input";
mInputs.push_back(*(mBackend->getNativeTensor(input))); // input
mOutputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
if (common->relu() || common->relu6()) {
this->createStageTensor("ReluTensor", dataType, getNHWCShape(output)); // mTempTensorWrappers[4]
// Stage one
{
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = "DepthWiseConv2d";
std::string name = mNodeName + "_convDepthwise";
mParams.push_back(*(mParamTensorWrappers[0]->getNativeParam())); // stride
mParams.push_back(*(mParamTensorWrappers[1]->getNativeParam())); // pad_amount
mParams.push_back(*(mParamTensorWrappers[2]->getNativeParam())); // dilation
mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput
mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // weight
mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // bias
mOutputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // ReluTensor
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
// Stage two
{
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = common->relu6() ? "ReluMinMax" : "Relu";
std::string name = mNodeName + "_relu";
if (common->relu6()) {
mParams.push_back(*(mParamScalarWrappers[1]->getNativeParam())); // min_value
mParams.push_back(*(mParamScalarWrappers[2]->getNativeParam())); // max_value
}
mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // ReluTensor
mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
} else {
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = "DepthWiseConv2d";
mParams.push_back(*(mParamTensorWrappers[0]->getNativeParam())); // stride
mParams.push_back(*(mParamTensorWrappers[1]->getNativeParam())); // pad_amount
mParams.push_back(*(mParamTensorWrappers[2]->getNativeParam())); // dilation
mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput
mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // weight
mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // bias
mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput
mBackend->addNodeToGraph(mOpConfigVersion, mNodeName.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
// Quant output
{
auto QuantOutputTensor = mTempTensorWrappers[3]->getNativeTensor();
if(mBackend->getUseFP16()){
this->createStageTensor("CastOutput", QNN_DATATYPE_FLOAT_32, getNHWCShape(output));
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = "Cast";
std::string name = mNodeName + "_Cast_Output";
mInputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput
mOutputs.push_back(*(mTempTensorWrappers.back()->getNativeTensor())); // CastOutput
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
QuantOutputTensor = mTempTensorWrappers.back()->getNativeTensor();
}
{
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = "Quantize";
std::string name = mNodeName + "_Quant_Output";
mInputs.push_back(*(QuantOutputTensor)); // stage tensor
mOutputs.push_back(*(mBackend->getNativeTensor(output))); // output
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
}
}
return NO_ERROR;
}
ErrorCode QNNConvDepthwise::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
auto conv2D = mOp->main_as_Convolution2D();
auto common = conv2D->common();
@ -36,6 +187,7 @@ ErrorCode QNNConvDepthwise::onEncode(const std::vector<Tensor *> &inputs, const
dilationH = common->dilateY(); dilationW = common->dilateX();
}
isWeightQuantSupported(inputs[0], oc);
// create all tensors and params
{
std::vector<uint32_t> strideData = {(uint32_t)strideH, (uint32_t)strideW};
@ -49,10 +201,24 @@ ErrorCode QNNConvDepthwise::onEncode(const std::vector<Tensor *> &inputs, const
this->createParamScalar("max_value", 6.0f);
}
this->createWeight(dataType, oc, kernelH, kernelW);
this->createBias(dataType, oc);
this->createWeightAndBias(dataType, inputs[0], oc, kernelH, kernelW);
// dequant input and quant output
if(mWeightQuant == false && dataType != QNN_DATATYPE_FLOAT_16 && dataType != QNN_DATATYPE_FLOAT_32){
return this->onEncodeQuantDequantDepthConv(inputs[0], outputs[0], n, ic, oc);
}
if (common->relu() || common->relu6()) {
this->createStageTensor("ReluTensor", dataType, getNHWCShape(outputs[0]));
Qnn_QuantizeParams_t quantize = DEFAULT_QUANTIZE_PARAMS;
Qnn_ScaleOffset_t tScaleOffsetEncoding;
auto quant = TensorUtils::getDescribe(outputs[0])->quantAttr.get();
if(quant != nullptr && TensorUtils::getDescribe(outputs[0])->type == DataType_DT_INT8){
quantize.encodingDefinition = QNN_DEFINITION_DEFINED;
quantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET;
tScaleOffsetEncoding.scale = quant->scale;
tScaleOffsetEncoding.offset = quant->zero;
quantize.scaleOffsetEncoding = tScaleOffsetEncoding;
}
this->createStageTensor("ReluTensor", dataType, getNHWCShape(outputs[0]), quantize);
}
}
@ -112,7 +278,120 @@ ErrorCode QNNConvDepthwise::onEncode(const std::vector<Tensor *> &inputs, const
void QNNConvDepthwise::createWeight(Qnn_DataType_t dataType, int oc, int kernelH, int kernelW) {
void QNNConvDepthwise::createWeightAndBias(Qnn_DataType_t dataType, const Tensor *input, int oc, int kernelH, int kernelW) {
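    // Build the static weight and bias tensors: int8/int4 weights keep per-channel (or bit-width)
    // scale encodings and the bias is re-quantized to int32; the float path emits fp16/fp32 tensors.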
if(mWeightQuant){
Qnn_QuantizeParams_t weightQuantize{};
std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon = ConvolutionCommon::load(mOp, this->backend(), false, true);
// [TODO] Support asymmetric and other quantBits.
MNN_ASSERT(!quanCommon->asymmetric);
// create weight
const int8_t * source = quanCommon->weight.get();
std::vector<int8_t> quantWeightData(oc * kernelH * kernelW, 0);
if(quanCommon->canUseInt4){
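            // Two int4 weights are packed per byte (high nibble = even source index); shift them
            // back to the signed range [-8, 7] while transposing (oc, h, w) -> (h, w, oc).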
for (int c = 0; c < oc; c++) {
for (int h = 0; h < kernelH; h++) {
for (int w = 0; w < kernelW; w++) {
int srcOffset = w + kernelW * (h + kernelH * c);
int dstOffset = c + oc * (w + kernelW * h);
if(srcOffset % 2 == 0){
quantWeightData[dstOffset] = ((source[srcOffset / 2] >> 4) & 0x0f) - 8;
}else{
quantWeightData[dstOffset] = (source[srcOffset / 2] & 0x0f) - 8;
}
}
}
}
}else{
convertWeight(source, (int8_t *) quantWeightData.data(), oc, kernelH, kernelW);
}
mDequantAlpha = quanCommon->alpha.get();
if(quanCommon->canUseInt4){
weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED;
weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET;
Qnn_BwAxisScaleOffset_t weightBWAxisScaleOffsetEncoding{};
weightBWAxisScaleOffsetEncoding.bitwidth = 4;
weightBWAxisScaleOffsetEncoding.axis = 3;
weightBWAxisScaleOffsetEncoding.numElements = oc;
mScale.resize(oc);
std::vector<int32_t> OffsetData(oc);
for (int i = 0; i < oc; i++) {
mScale[i] = mDequantAlpha[i];
}
weightBWAxisScaleOffsetEncoding.scales = mScale.data();
weightQuantize.bwAxisScaleOffsetEncoding = weightBWAxisScaleOffsetEncoding;
this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)kernelH, (uint32_t)kernelW, 1, (uint32_t)oc}, (void *) quantWeightData.data(), weightQuantize);
std::function<void()> mReleaseWeightScaleOffset = [&](){
std::vector<float>().swap(mScale);
};
mBackend->pushReleaseFunc(mReleaseWeightScaleOffset);
}else{
weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED;
weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET;
Qnn_AxisScaleOffset_t weightAxisScaleOffsetEncoding{};
weightAxisScaleOffsetEncoding.axis = 3;
weightAxisScaleOffsetEncoding.numScaleOffsets = oc;
mScaleOffsetData.resize(oc);
for (int i = 0; i < oc; i++) {
mScaleOffsetData[i].scale = mDequantAlpha[i];
mScaleOffsetData[i].offset = 0;
}
weightAxisScaleOffsetEncoding.scaleOffset = mScaleOffsetData.data();
weightQuantize.axisScaleOffsetEncoding = weightAxisScaleOffsetEncoding;
this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)kernelH, (uint32_t)kernelW, 1, (uint32_t)oc}, (void *) quantWeightData.data(), weightQuantize);
std::function<void()> mReleaseWeightScaleOffset = [&](){
std::vector<Qnn_ScaleOffset_t>().swap(mScaleOffsetData);
};
mBackend->pushReleaseFunc(mReleaseWeightScaleOffset);
}
// create bias
{
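            // Quantize the float bias to int32 using per-channel scale = inputScale * weightScale[i];
            // near-zero biases or scales map to 0 to avoid dividing by a tiny value.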
float inputScale = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.scale;
int inputOffset = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.offset;
std::vector<int> biasData;
biasData.resize(oc, 0);
Qnn_QuantizeParams_t biasQuantize{};
biasQuantize.encodingDefinition = QNN_DEFINITION_DEFINED;
biasQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET;
Qnn_AxisScaleOffset_t biasAxisScaleOffsetEncoding{};
biasAxisScaleOffsetEncoding.axis = 0;
biasAxisScaleOffsetEncoding.numScaleOffsets = oc;
mBiasScaleOffsetData.resize(oc);
            auto bias = mOp->main_as_Convolution2D()->bias();
            if (nullptr != bias) {
                auto biasPtr = (float*)bias->data();
for(int i = 0; i < oc; ++i){
float biasScale = inputScale * mDequantAlpha[i];
mBiasScaleOffsetData[i].scale = 0.f;
mBiasScaleOffsetData[i].offset = 0;
if(fabs(biasPtr[i]) < 0.000001 || fabs(biasScale) < 0.000001){
biasData[i] = 0;
} else{
biasData[i] = (int)(biasPtr[i] / (biasScale));
}
}
}
biasAxisScaleOffsetEncoding.scaleOffset = mBiasScaleOffsetData.data();
biasQuantize.axisScaleOffsetEncoding = biasAxisScaleOffsetEncoding;
this->createStaticTensor("bias", QNN_DATATYPE_SFIXED_POINT_32, {(uint32_t)oc}, biasData.data(), biasQuantize);
std::function<void()> mReleaseBiasScaleOffset = [&](){
std::vector<Qnn_ScaleOffset_t>().swap(mBiasScaleOffsetData);
};
mBackend->pushReleaseFunc(mReleaseBiasScaleOffset);
}
}else{
Qnn_DataType_t floatDatatype = QNN_DATATYPE_FLOAT_32;
if(mBackend->getUseFP16()){
floatDatatype = QNN_DATATYPE_FLOAT_16;
}
std::vector<float> weightData;
const float* source = nullptr;
int weightElementNum = 0;
@ -121,32 +400,16 @@ void QNNConvDepthwise::createWeight(Qnn_DataType_t dataType, int oc, int kernelH
// oc ic h w ---> h w ic oc
weightData.resize(weightElementNum);
convertWeight(source, (float *) weightData.data(), oc, kernelH, kernelW);
this->createStaticFloatTensor("weight", dataType, {(uint32_t)kernelH, (uint32_t)kernelW, 1, (uint32_t)oc}, weightData.data());
}
this->createStaticFloatTensor("weight", floatDatatype, {(uint32_t)kernelH, (uint32_t)kernelW, 1, (uint32_t)oc}, weightData.data());
void QNNConvDepthwise::createBias(Qnn_DataType_t dataType, int oc) {
int biasElementNum = oc;
// create bias
std::vector<float> biasData;
biasData.resize(biasElementNum, 0);
biasData.resize(oc, 0);
auto bias = mOp->main_as_Convolution2D()->bias();
if (nullptr != bias) {
::memcpy((void *)biasData.data(), (void *)bias->data(), biasElementNum * sizeof(float));
}
this->createStaticFloatTensor("bias", dataType, {(uint32_t)oc}, biasData.data());
}
// oc, h, w ---> h, w, oc
void QNNConvDepthwise::convertWeight(const float * src, float * dst, int oc, int kernelH, int kernelW) {
for (int c = 0; c < oc; c++) {
for (int h = 0; h < kernelH; h++) {
for (int w = 0; w < kernelW; w++) {
int srcOffset = w + kernelW * (h + kernelH * c);
int dstOffset = c + oc * (w + kernelW * h);
dst[dstOffset] = src[srcOffset];
}
::memcpy((void *)biasData.data(), (void *)bias->data(), oc * sizeof(float));
}
this->createStaticFloatTensor("bias", floatDatatype, {(uint32_t)oc}, biasData.data());
}
}

View File

@ -19,10 +19,28 @@ class QNNConvDepthwise : public QNNCommonExecution {
public:
QNNConvDepthwise(Backend *backend, const Op *op) : QNNCommonExecution(backend, op) {}
virtual ErrorCode onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
ErrorCode onEncodeQuantDequantDepthConv(Tensor *input, Tensor *output, const int n, const int ic, const int oc);
private:
void createWeight(Qnn_DataType_t dataType, int oc, int kernelH, int kernelW);
void createBias(Qnn_DataType_t dataType, int oc);
void convertWeight(const float * src, float * dst, int oc, int kernelH, int kernelW);
template <typename T>
void convertWeight(const T * src, T * dst, int oc, int kernelH, int kernelW) {
for (int c = 0; c < oc; c++) {
for (int h = 0; h < kernelH; h++) {
for (int w = 0; w < kernelW; w++) {
int srcOffset = w + kernelW * (h + kernelH * c);
int dstOffset = c + oc * (w + kernelW * h);
dst[dstOffset] = src[srcOffset];
}
}
}
}
void isWeightQuantSupported(const Tensor *input, const int oc);
void createWeightAndBias(Qnn_DataType_t dataType, const Tensor *input, int oc, int kernelH, int kernelW);
std::vector<float> mScale;
std::vector<Qnn_ScaleOffset_t> mScaleOffsetData;
std::vector<Qnn_ScaleOffset_t> mBiasScaleOffsetData;
std::vector<uint8_t> mBlockScale;
float *mDequantAlpha = nullptr;
bool mWeightQuant = false;
};
} // end namespace QNN

View File

@ -21,6 +21,63 @@ static std::pair<int, int> closest_factors(int n) {
}
return {1, n};
}
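// Decide whether this convolution can consume its quantized weights directly. The rules mirror the
// depthwise case, with two extras: float activations only keep quantized weights for the plain
// matmul (1x1, per-channel scale) case, and block-wise scales (mBlockSize > 1) are only kept for
// bias-free 1x1 convolutions with at least 16 input channels per block.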
void QNNConvolution::isWeightQuantSupported(const Tensor *input, const int ic, const int oc){
Qnn_DataType_t dataType = mBackend->getNativeTensor(input)->v1.dataType;
if(mOp->main_as_Convolution2D()->quanParameter() == nullptr){
mWeightQuant = false;
return;
}else{
bool hasBias = false;
auto bias = mOp->main_as_Convolution2D()->bias();
auto biasPtr = (float*)bias->data();
for(int i = 0; i < oc; ++i){
if(biasPtr[i] != 0.0f){
hasBias = true;
break;
}
}
std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon = ConvolutionCommon::load(mOp, this->backend(), false, true);
int totalCount = quanCommon->alpha.size();
mBlockSize = totalCount / oc;
if(quanCommon->asymmetric){
            // Asymmetric quantization is not supported; block-wise scales (mBlockSize > 1) are further restricted below.
mWeightQuant = false;
return;
}
if(dataType == QNN_DATATYPE_FLOAT_16 || dataType == QNN_DATATYPE_FLOAT_32){
if(mIsMatMul && mBlockSize == 1){
mWeightQuant = true;
}else{
mWeightQuant = false;
}
return;
}
float inputScale = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.scale;
int inputOffset = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.offset;
if(inputOffset == 0){
mWeightQuant = true;
}else{
if(hasBias){
mWeightQuant = false;
}else{
mWeightQuant = true;
}
}
if(mBlockSize > 1 && mWeightQuant){
if(mIs1x1Conv && hasBias == false && (ic / mBlockSize) >= 16){
mWeightQuant = true;
}else{
mWeightQuant = false;
}
}
}
}
ErrorCode QNNConvolution::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
auto conv2D = mOp->main_as_Convolution2D();
auto common = conv2D->common();
@ -46,32 +103,16 @@ ErrorCode QNNConvolution::onEncode(const std::vector<Tensor *> &inputs, const st
dilationH = common->dilateY(); dilationW = common->dilateX();
group = common->group();
}
    mIs1x1Conv = kernelH==1 && kernelW==1 && strideH==1 && \
strideW==1 && dilationH==1 && dilationW==1 && group==1 && \
padTop==0 && padBottom==0 && padLeft==0 && padRight==0;
mIsMatMul = ih==1 && iw==1 && oh==1 && ow==1 && mIs1x1Conv;
isWeightQuantSupported(inputs[0], ic, oc);
const float * weightSource = nullptr;
std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
if (mOp->main_as_Convolution2D()->quanParameter()) {
bool forceFloat = (common->kernelX() == 1 && common->kernelY() == 1) ? false : true;
quanCommon = ConvolutionCommon::load(mOp, this->backend(), forceFloat);
if (quanCommon->weightFloat.get() == nullptr) {
// [TODO] Support asymmetric and other quantBits.
// isQuantWeight && symmetric quantization && int8 quantization && 1x1 conv
if (quanCommon->asymmetric || quanCommon->canUseInt4) {
return NOT_SUPPORT;
}
return this->onEncodeQuant(inputs[0], outputs[0], n, ih, iw, ic, oc, quanCommon);
} else {
weightSource = quanCommon->weightFloat.get();
}
} else {
int weightElementNum;
ConvolutionCommon::getConvParameters(&quanCommon, mBackend, mOp, &weightSource, &weightElementNum);
if(mIsMatMul && mWeightQuant && (dataType == QNN_DATATYPE_FLOAT_16 || dataType == QNN_DATATYPE_FLOAT_32)){
return onEncodeFpAIntBMatMul(inputs[0], outputs[0], n, ih, iw, ic, oc);
}
#ifdef QNN_VERBOSE
MNN_PRINT("n:%d, ih:%d, iw:%d, ic:%d, oh:%d, ow:%d, oc:%d, kernelH:%d, kernelW:%d, dilationH:%d, dilationW:%d, strideH:%d, strideW:%d, group:%d, pad:%d %d %d %d\n", n, ih, iw, ic, oh, ow, oc, kernelH, kernelW, dilationH, \
dilationW, strideH, strideW, group, padTop, padBottom, padLeft, padRight);
#endif
// create all tensors and params
{
std::vector<uint32_t> strideData = {(uint32_t)strideH, (uint32_t)strideW};
@ -85,12 +126,26 @@ ErrorCode QNNConvolution::onEncode(const std::vector<Tensor *> &inputs, const st
this->createParamScalar("min_value", 0.0f);
this->createParamScalar("max_value", 6.0f);
}
this->createWeight(dataType, oc, ic, kernelH, kernelW, group, weightSource);
this->createBias(dataType, oc);
if (common->relu() || common->relu6()) {
this->createStageTensor("ReluTensor", dataType, getNHWCShape(outputs[0]));
}
this->createWeightAndBias(dataType, inputs[0], oc, ic, kernelH, kernelW, group);
// dequant input and quant output
if(mWeightQuant == false && dataType != QNN_DATATYPE_FLOAT_16 && dataType != QNN_DATATYPE_FLOAT_32){
return this->onEncodeQuantDequantConv(inputs[0], outputs[0], n, ic, oc);
}
if (common->relu() || common->relu6()) {
Qnn_QuantizeParams_t quantize = DEFAULT_QUANTIZE_PARAMS;
Qnn_ScaleOffset_t tScaleOffsetEncoding;
auto quant = TensorUtils::getDescribe(outputs[0])->quantAttr.get();
if(quant != nullptr && TensorUtils::getDescribe(outputs[0])->type == DataType_DT_INT8){
quantize.encodingDefinition = QNN_DEFINITION_DEFINED;
quantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET;
tScaleOffsetEncoding.scale = quant->scale;
tScaleOffsetEncoding.offset = quant->zero;
quantize.scaleOffsetEncoding = tScaleOffsetEncoding;
}
this->createStageTensor("ReluTensor", dataType, getNHWCShape(outputs[0]), quantize);
}
// add nodes
@ -131,13 +186,34 @@ ErrorCode QNNConvolution::onEncode(const std::vector<Tensor *> &inputs, const st
}
} else {
bool isMatmul = ih==1 && iw==1 && oh==1 && ow==1 && kernelH==1 && kernelW==1 && strideH==1 && \
strideW==1 && dilationH==1 && dilationW==1 && group==1 && \
padTop==0 && padBottom==0 && padLeft==0 && padRight==0;
if(isMatmul && n > 1) {
if(mIsMatMul && n > 1) {
auto num = closest_factors(n);
this->createStageTensor("InputReshapeTensor", dataType, std::vector<int>({1, num.first, num.second, ic}));
this->createStageTensor("OutputReshapeTensor", dataType, std::vector<int>({1, num.first, num.second, oc}));
{
Qnn_QuantizeParams_t quantize = DEFAULT_QUANTIZE_PARAMS;
Qnn_ScaleOffset_t tScaleOffsetEncoding;
auto quant = TensorUtils::getDescribe(inputs[0])->quantAttr.get();
if(quant != nullptr && TensorUtils::getDescribe(inputs[0])->type == DataType_DT_INT8){
quantize.encodingDefinition = QNN_DEFINITION_DEFINED;
quantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET;
tScaleOffsetEncoding.scale = quant->scale;
tScaleOffsetEncoding.offset = quant->zero;
quantize.scaleOffsetEncoding = tScaleOffsetEncoding;
}
this->createStageTensor("InputReshapeTensor", dataType, std::vector<int>({1, num.first, num.second, ic}), quantize);
}
{
Qnn_QuantizeParams_t quantize = DEFAULT_QUANTIZE_PARAMS;
Qnn_ScaleOffset_t tScaleOffsetEncoding;
auto quant = TensorUtils::getDescribe(outputs[0])->quantAttr.get();
if(quant != nullptr && TensorUtils::getDescribe(outputs[0])->type == DataType_DT_INT8){
quantize.encodingDefinition = QNN_DEFINITION_DEFINED;
quantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET;
tScaleOffsetEncoding.scale = quant->scale;
tScaleOffsetEncoding.offset = quant->zero;
quantize.scaleOffsetEncoding = tScaleOffsetEncoding;
}
this->createStageTensor("OutputReshapeTensor", dataType, std::vector<int>({1, num.first, num.second, oc}), quantize);
}
#ifdef QNN_VERBOSE
MNN_PRINT("Matmul2Conv, start reshape batch:%d -> %dx%d\n", n, num.first, num.second);
#endif
@ -208,52 +284,273 @@ ErrorCode QNNConvolution::onEncode(const std::vector<Tensor *> &inputs, const st
return NO_ERROR;
}
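// Fallback when weight quantization is not used on a quantized graph: dequantize the int8 input,
// run the float Conv2d (reshaping a batched 1x1 "matmul" conv to a single image when needed, plus
// optional ReLU/ReLU6), then cast if needed and re-quantize the result into the int8 output tensor.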
ErrorCode QNNConvolution::onEncodeQuantDequantConv(Tensor *input, Tensor *output, const int n, const int ic, const int oc) {
auto conv2D = mOp->main_as_Convolution2D();
auto common = conv2D->common();
Qnn_DataType_t dataType = QNN_DATATYPE_FLOAT_32;
if(mBackend->getUseFP16()){
dataType = QNN_DATATYPE_FLOAT_16;
}
ErrorCode QNNConvolution::onEncodeQuant(Tensor * input, Tensor * output, int n, int h, int w, int ic, int oc, std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon) {
    // create the dequant-input and quant-output stage tensors
this->createStageTensor("DequantInput", dataType, getNHWCShape(input)); // mTempTensorWrappers[2]
this->createStageTensor("QuantOutput", dataType, getNHWCShape(output)); // mTempTensorWrappers[3]
// add nodes
{
// dequant input
{
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = "Dequantize";
std::string name = mNodeName + "_dequant_input";
mInputs.push_back(*(mBackend->getNativeTensor(input))); // input
mOutputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
if (common->relu() || common->relu6()) {
this->createStageTensor("ReluTensor", dataType, getNHWCShape(output)); // mTempTensorWrappers[4]
// Stage one
{
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = "Conv2d";
std::string name = mNodeName + "_conv";
mParams.push_back(*(mParamTensorWrappers[0]->getNativeParam())); // stride
mParams.push_back(*(mParamTensorWrappers[1]->getNativeParam())); // pad_amount
mParams.push_back(*(mParamTensorWrappers[2]->getNativeParam())); // dilation
mParams.push_back(*(mParamScalarWrappers[0]->getNativeParam())); // group
mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput
mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // weight
mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // bias
mOutputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // ReluTensor
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
// Stage two
{
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = common->relu6() ? "ReluMinMax" : "Relu";
std::string name = mNodeName + "_relu";
if (common->relu6()) {
mParams.push_back(*(mParamScalarWrappers[1]->getNativeParam())); // min_value
mParams.push_back(*(mParamScalarWrappers[2]->getNativeParam())); // max_value
}
mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // ReluTensor
mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
} else {
if(mIsMatMul && n > 1) {
auto num = closest_factors(n);
this->createStageTensor("InputReshapeTensor", dataType, std::vector<int>({1, num.first, num.second, ic})); // mTempTensorWrappers[4]
this->createStageTensor("OutputReshapeTensor", dataType, std::vector<int>({1, num.first, num.second, oc})); // mTempTensorWrappers[5]
#ifdef QNN_VERBOSE
MNN_PRINT("Matmul2Conv, start reshape batch:%d -> %dx%d\n", n, num.first, num.second);
#endif
// reshape input
{
std::string name = mNodeName + "_input_reshape";
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = "Reshape";
mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput
mOutputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // InputReshapeTensor
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
// conv2d
{
std::string name = mNodeName;
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = "Conv2d";
mParams.push_back(*(mParamTensorWrappers[0]->getNativeParam())); // stride
mParams.push_back(*(mParamTensorWrappers[1]->getNativeParam())); // pad_amount
mParams.push_back(*(mParamTensorWrappers[2]->getNativeParam())); // dilation
mParams.push_back(*(mParamScalarWrappers[0]->getNativeParam())); // group
mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // InputReshapeTensor
mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // weight
mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // bias
mOutputs.push_back(*(mTempTensorWrappers[5]->getNativeTensor())); // OutputReshapeTensor
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
// reshape output
{
std::string name = mNodeName + "_output_reshape";
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = "Reshape";
mInputs.push_back(*(mTempTensorWrappers[5]->getNativeTensor())); // OutputReshapeTensor
mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
} else{
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = "Conv2d";
mParams.push_back(*(mParamTensorWrappers[0]->getNativeParam())); // stride
mParams.push_back(*(mParamTensorWrappers[1]->getNativeParam())); // pad_amount
mParams.push_back(*(mParamTensorWrappers[2]->getNativeParam())); // dilation
mParams.push_back(*(mParamScalarWrappers[0]->getNativeParam())); // group
mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput
mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // weight
mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // bias
mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput
mBackend->addNodeToGraph(mOpConfigVersion, mNodeName.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
}
// Quant output
{
auto QuantOutputTensor = mTempTensorWrappers[3]->getNativeTensor();
if(mBackend->getUseFP16()){
this->createStageTensor("CastOutput", QNN_DATATYPE_FLOAT_32, getNHWCShape(output));
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = "Cast";
std::string name = mNodeName + "_Cast_Output";
mInputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput
mOutputs.push_back(*(mTempTensorWrappers.back()->getNativeTensor())); // CastOutput
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
QuantOutputTensor = mTempTensorWrappers.back()->getNativeTensor();
}
{
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = "Quantize";
std::string name = mNodeName + "_Quant_Output";
mInputs.push_back(*(QuantOutputTensor)); // stage tensor
mOutputs.push_back(*(mBackend->getNativeTensor(output))); // output
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
}
}
return NO_ERROR;
}
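// Float-activation x quantized-weight path for 1x1 "matmul" convolutions: the NHWC input is
// flattened to [n*h*w, ic]; the flattened input, the statically quantized int8/int4 weight
// (marked transposed via transpose_in1) and the quantized bias feed the matmul node, and its
// output is reshaped back to NHWC with an optional ReLU/ReLU6.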
ErrorCode QNNConvolution::onEncodeFpAIntBMatMul(Tensor * input, Tensor * output, int n, int h, int w, int ic, int oc) {
// create parameters and stage tensors
auto conv2D = mOp->main_as_Convolution2D();
auto common = conv2D->common();
Qnn_DataType_t dataType = mBackend->getNativeTensor(input)->v1.dataType;
{
bool transposeWeightFlag = true;
this->createParamScalar("transpose_in1", transposeWeightFlag);
std::vector<uint32_t> tempInputShape = {(uint32_t) n * h * w , (uint32_t) ic};
std::vector<uint32_t> tempOutputShape = {(uint32_t) n * h * w , (uint32_t) oc};
this->createStageTensor("tempInput", QNN_DATATYPE_FLOAT_16, tempInputShape);
this->createStageTensor("tempOutput", QNN_DATATYPE_FLOAT_16, tempOutputShape);
this->createStageTensor("tempInput", dataType, tempInputShape);
this->createStageTensor("tempOutput", dataType, tempOutputShape);
// create weight
// create weight and bias
{
Qnn_QuantizeParams_t weightQuantize{};
std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon = ConvolutionCommon::load(mOp, this->backend(), false, true);
MNN_ASSERT(!quanCommon->asymmetric);
const int8_t * source = quanCommon->weight.get();
std::vector<int8_t> quantWeightData(oc * ic, 0);
if(quanCommon->canUseInt4){
for (int o = 0; o < oc; o++) {
for (int i = 0; i < ic; i++) {
uint32_t srcOffset = o * ic + i;
uint32_t dstOffset = srcOffset;
if(srcOffset % 2 == 0){
quantWeightData[dstOffset] = ((source[srcOffset / 2] >> 4) & 0x0f) - 8;
}else{
quantWeightData[dstOffset] = (source[srcOffset / 2] & 0x0f) - 8;
}
}
}
}else{
::memcpy(quantWeightData.data(), source, oc * ic * sizeof(int8_t));
}
mDequantAlpha = quanCommon->alpha.get();
int totalCount = quanCommon->alpha.size();
mBlockSize = totalCount / oc;
int blockNum = ic / mBlockSize;
if(quanCommon->canUseInt4){
weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED;
weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET;
Qnn_BwAxisScaleOffset_t weightBWAxisScaleOffsetEncoding{};
weightBWAxisScaleOffsetEncoding.bitwidth = 4;
weightBWAxisScaleOffsetEncoding.axis = 0;
weightBWAxisScaleOffsetEncoding.numElements = oc;
mScale.resize(oc);
std::vector<int32_t> OffsetData(oc);
for (int i = 0; i < oc; i++) {
mScale[i] = mDequantAlpha[i];
}
weightBWAxisScaleOffsetEncoding.scales = mScale.data();
weightQuantize.bwAxisScaleOffsetEncoding = weightBWAxisScaleOffsetEncoding;
float * dequantAlpha = quanCommon->alpha.get();
Qnn_QuantizeParams_t weightQuantize{};
this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)oc, (uint32_t)ic}, (void *) quantWeightData.data(), weightQuantize);
std::function<void()> mReleaseWeightScaleOffset = [&](){
std::vector<float>().swap(mScale);
};
mBackend->pushReleaseFunc(mReleaseWeightScaleOffset);
}else{
weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED;
weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET;
Qnn_AxisScaleOffset_t weightAxisScaleOffsetEncoding{};
weightAxisScaleOffsetEncoding.axis = 0;
weightAxisScaleOffsetEncoding.numScaleOffsets = oc;
std::vector<Qnn_ScaleOffset_t> scaleOffsetData(oc);
mScaleOffsetData.resize(oc);
for (int i = 0; i < oc; i++) {
if (quanCommon->asymmetric) {
scaleOffsetData[i].scale = dequantAlpha[2 * i + 1];
// scaleOffsetData[i].offset = (int) dequantAlpha[2 * i + 0];
scaleOffsetData[i].offset = 0;
} else {
scaleOffsetData[i].scale = dequantAlpha[i];
scaleOffsetData[i].offset = 0;
mScaleOffsetData[i].scale = mDequantAlpha[i];
mScaleOffsetData[i].offset = 0;
}
}
weightAxisScaleOffsetEncoding.scaleOffset = scaleOffsetData.data();
weightAxisScaleOffsetEncoding.scaleOffset = mScaleOffsetData.data();
weightQuantize.axisScaleOffsetEncoding = weightAxisScaleOffsetEncoding;
this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)oc, (uint32_t)ic}, (void *) quantWeightData.data(), weightQuantize);
std::function<void()> mReleaseWeightScaleOffset = [&](){
std::vector<Qnn_ScaleOffset_t>().swap(mScaleOffsetData);
};
mBackend->pushReleaseFunc(mReleaseWeightScaleOffset);
}
//create bias
this->createBias(dataType, oc, input, quanCommon);
}
if (common->relu6()) {
this->createParamScalar("min_value", 0.0f);
this->createParamScalar("max_value", 6.0f);
}
if (common->relu() || common->relu6()) {
this->createStageTensor("ReluTensor", dataType, getNHWCShape(output));
}
}
// Stage One: reshape input
{
mNodeType = "Reshape";
std::string name = mNodeName + "_reshapeOutput";
std::string name = mNodeName + "_reshapeInput";
mParams.clear();
mInputs.clear();
mOutputs.clear();
@ -273,6 +570,7 @@ ErrorCode QNNConvolution::onEncodeQuant(Tensor * input, Tensor * output, int n,
mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // tempInput
// mInputs.push_back(*(mBackend->getNativeTensor(input)));
mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // weight
mInputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // bias
mOutputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // tempOutput
// mOutputs.push_back(*(mBackend->getNativeTensor(output)));
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
@ -286,48 +584,217 @@ ErrorCode QNNConvolution::onEncodeQuant(Tensor * input, Tensor * output, int n,
mInputs.clear();
mOutputs.clear();
mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor()));
if (common->relu() || common->relu6()){
mOutputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); //ReluTensor
}else{
mOutputs.push_back(*(mBackend->getNativeTensor(output)));
}
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
// Stage Four: relu or relu6
if (common->relu() || common->relu6()){
mNodeType.clear();
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = common->relu6() ? "ReluMinMax" : "Relu";
std::string name = mNodeName + "_relu";
if (common->relu6()) {
mParams.push_back(*(mParamScalarWrappers[1]->getNativeParam())); // min_value
mParams.push_back(*(mParamScalarWrappers[2]->getNativeParam())); // max_value
}
mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // ReluTensor
mOutputs.push_back(*(mBackend->getNativeTensor(output))); // output
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
return NO_ERROR;
}
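// Build the static weight and bias tensors. When mWeightQuant is set, weights stay int8/int4 with
// per-channel, bit-width, or block-wise scale encodings and the bias is re-quantized to int32;
// otherwise both are emitted as fp16/fp32 tensors.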
bool QNNConvolution::createWeightAndBias(Qnn_DataType_t dataType, const Tensor *input, int oc, int ic, int kernelH, int kernelW, int group) {
if(mWeightQuant){
Qnn_QuantizeParams_t weightQuantize{};
std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon = ConvolutionCommon::load(mOp, this->backend(), false, true);
if(quanCommon->asymmetric) {
MNN_ERROR("[Error]: Qnn weight quant only support symmetric currently\n");
return false;
}
const int8_t * source = quanCommon->weight.get();
std::vector<int8_t> quantWeightData(oc * (ic / group) * kernelH * kernelW, 0);
if(quanCommon->canUseInt4){
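            // Unpack two 4-bit weights per byte (high nibble = even source index), shift to the
            // signed range [-8, 7], and transpose (oc, ic/group, h, w) -> (h, w, ic/group, oc).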
for (int o = 0; o < oc; o++) {
for (int i = 0; i < ic/group; i++) {
for (int h = 0; h < kernelH; h++) {
for (int w = 0; w < kernelW; w++) {
uint32_t srcOffset = w + kernelW * (h + kernelH * (i + ic/group * o));
uint32_t dstOffset = o + oc * (i + ic/group * (w + kernelW * h));
if(srcOffset % 2 == 0){
quantWeightData[dstOffset] = ((source[srcOffset / 2] >> 4) & 0x0f) - 8;
}else{
quantWeightData[dstOffset] = (source[srcOffset / 2] & 0x0f) - 8;
}
}
}
}
}
}else{
convertWeight(source, (int8_t *) quantWeightData.data(), oc, ic/group, kernelH, kernelW);
}
mDequantAlpha = quanCommon->alpha.get();
int totalCount = quanCommon->alpha.size();
mBlockSize = totalCount / oc;
        // TODO: results on this block-wise path are still wrong and need verification.
if(mBlockSize > 1){
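            // Block-wise expansion: keep one fp32 scale per output channel (maxAlpha / 16) plus a
            // 4-bit multiplier in [1, 16] per block, so that channelScale * multiplier approximates
            // the original per-block scale.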
Qnn_QuantizeParams_t weightQuantize{};
weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED;
weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION;
weightBlockwiseExpansionEncoding.axis = 3;
weightBlockwiseExpansionEncoding.numBlocksPerAxis = mBlockSize;
weightBlockwiseExpansionEncoding.blockScaleBitwidth = 4;
weightBlockwiseExpansionEncoding.blockScaleStorageType = QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8;
mBlockScale.resize(oc * mBlockSize);
mScaleOffsetData.resize(oc);
for (int i = 0; i < oc; i++) {
float maxscale = -MAXFLOAT;
for(int j = 0; j < mBlockSize; ++j){
if(mDequantAlpha[i * mBlockSize + j] > maxscale){
maxscale = mDequantAlpha[i * mBlockSize + j];
}
}
float blockScale = maxscale / 16.0f;
for(int j = 0; j < mBlockSize; ++j){
int quantBlock = round(mDequantAlpha[i * mBlockSize + j] / blockScale);
mBlockScale[i * mBlockSize + j] = (uint8_t)std::min(std::max(quantBlock, 1), 16);
}
mScaleOffsetData[i].scale = blockScale;
mScaleOffsetData[i].offset = 0;
}
weightBlockwiseExpansionEncoding.scaleOffsets = mScaleOffsetData.data();
weightBlockwiseExpansionEncoding.blocksScale8 = mBlockScale.data();
weightQuantize.blockwiseExpansion = &weightBlockwiseExpansionEncoding;
this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)kernelH, (uint32_t)kernelW, (uint32_t)ic / (uint32_t)group, (uint32_t)oc}, (void *) quantWeightData.data(), weightQuantize);
std::function<void()> mReleaseWeightScaleOffset = [&](){
std::vector<Qnn_ScaleOffset_t>().swap(mScaleOffsetData);
};
mBackend->pushReleaseFunc(mReleaseWeightScaleOffset);
std::function<void()> mReleaseBlockScale = [&](){
std::vector<uint8_t>().swap(mBlockScale);
};
mBackend->pushReleaseFunc(mReleaseBlockScale);
}else if(quanCommon->canUseInt4){
weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED;
weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET;
Qnn_BwAxisScaleOffset_t weightBWAxisScaleOffsetEncoding{};
weightBWAxisScaleOffsetEncoding.bitwidth = 4;
weightBWAxisScaleOffsetEncoding.axis = 3;
weightBWAxisScaleOffsetEncoding.numElements = oc;
mScale.resize(oc);
std::vector<int32_t> OffsetData(oc);
for (int i = 0; i < oc; i++) {
mScale[i] = mDequantAlpha[i];
}
weightBWAxisScaleOffsetEncoding.scales = mScale.data();
weightQuantize.bwAxisScaleOffsetEncoding = weightBWAxisScaleOffsetEncoding;
this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)kernelH, (uint32_t)kernelW, (uint32_t)ic / (uint32_t)group, (uint32_t)oc}, (void *) quantWeightData.data(), weightQuantize);
std::function<void()> mReleaseWeightScaleOffset = [&](){
std::vector<float>().swap(mScale);
};
mBackend->pushReleaseFunc(mReleaseWeightScaleOffset);
}else{
weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED;
weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET;
Qnn_AxisScaleOffset_t weightAxisScaleOffsetEncoding{};
weightAxisScaleOffsetEncoding.axis = 3;
weightAxisScaleOffsetEncoding.numScaleOffsets = oc;
mScaleOffsetData.resize(oc);
for (int i = 0; i < oc; i++) {
mScaleOffsetData[i].scale = mDequantAlpha[i];
mScaleOffsetData[i].offset = 0;
}
weightAxisScaleOffsetEncoding.scaleOffset = mScaleOffsetData.data();
weightQuantize.axisScaleOffsetEncoding = weightAxisScaleOffsetEncoding;
this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)kernelH, (uint32_t)kernelW, (uint32_t)ic / (uint32_t)group, (uint32_t)oc}, (void *) quantWeightData.data(), weightQuantize);
std::function<void()> mReleaseWeightScaleOffset = [&](){
std::vector<Qnn_ScaleOffset_t>().swap(mScaleOffsetData);
};
mBackend->pushReleaseFunc(mReleaseWeightScaleOffset);
}
this->createBias(dataType, oc, input, quanCommon);
} else {
std::vector<float> weightData;
const float* source = nullptr;
int weightElementNum = 0;
std::shared_ptr<ConvolutionCommon::Int8Common> quanWeight;
ConvolutionCommon::getConvParameters(&quanWeight, mBackend, mOp, &source, &weightElementNum);
// oc ic h w ---> h w ic oc
weightData.resize(weightElementNum);
convertWeight(source, (float *) weightData.data(), oc, ic/group, kernelH, kernelW);
Qnn_DataType_t floatDatatype = QNN_DATATYPE_FLOAT_32;
if(mBackend->getUseFP16()){
floatDatatype = QNN_DATATYPE_FLOAT_16;
}
this->createStaticFloatTensor("weight", floatDatatype, {(uint32_t)kernelH, (uint32_t)kernelW, (uint32_t)ic / (uint32_t)group, (uint32_t)oc}, weightData.data());
this->createBias(dataType, oc, input, nullptr);
}
    return true;
}
void QNNConvolution::createWeight(Qnn_DataType_t dataType, int oc, int ic, int kernelH, int kernelW, int group, const float * source) {
std::vector<float> weightData;
int weightElementNum = oc * ic / group * kernelH * kernelW;
// oc ic/group h w ---> h w ic/group oc
weightData.resize(weightElementNum);
void QNNConvolution::createBias(Qnn_DataType_t dataType, int oc, const Tensor *input, std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon) {
int biasElementNum = oc;
if(dataType != QNN_DATATYPE_FLOAT_16 && dataType != QNN_DATATYPE_FLOAT_32 && mWeightQuant){
mDequantAlpha = quanCommon->alpha.get();
float inputScale = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.scale;
int inputOffset = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.offset;
std::vector<int> biasData;
biasData.resize(biasElementNum, 0);
convertWeight(source, (float *) weightData.data(), oc, ic/group, kernelH, kernelW);
this->createStaticFloatTensor("weight", dataType, {(uint32_t)kernelH, (uint32_t)kernelW, (uint32_t)ic / (uint32_t)group, (uint32_t)oc}, weightData.data());
Qnn_QuantizeParams_t biasQuantize{};
biasQuantize.encodingDefinition = QNN_DEFINITION_DEFINED;
biasQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET;
Qnn_AxisScaleOffset_t biasAxisScaleOffsetEncoding{};
biasAxisScaleOffsetEncoding.axis = 0;
biasAxisScaleOffsetEncoding.numScaleOffsets = biasElementNum;
mBiasScaleOffsetData.resize(biasElementNum);
        auto bias = mOp->main_as_Convolution2D()->bias();
        if (nullptr != bias) {
            auto biasPtr = (float*)bias->data();
for(int i = 0; i < biasElementNum; ++i){
float biasScale = inputScale * mDequantAlpha[i];
mBiasScaleOffsetData[i].scale = biasScale;
mBiasScaleOffsetData[i].offset = 0;
if(fabs(biasPtr[i]) < 0.000001 || fabs(biasScale) < 0.000001){
biasData[i] = 0;
} else{
biasData[i] = (int)(biasPtr[i] / biasScale);
}
}
}
biasAxisScaleOffsetEncoding.scaleOffset = mBiasScaleOffsetData.data();
biasQuantize.axisScaleOffsetEncoding = biasAxisScaleOffsetEncoding;
void QNNConvolution::createBias(Qnn_DataType_t dataType, int oc) {
int biasElementNum = oc;
this->createStaticTensor("bias", QNN_DATATYPE_SFIXED_POINT_32, {(uint32_t)biasElementNum}, biasData.data(), biasQuantize);
std::function<void()> mReleaseBiasScaleOffset = [&](){
std::vector<Qnn_ScaleOffset_t>().swap(mBiasScaleOffsetData);
};
mBackend->pushReleaseFunc(mReleaseBiasScaleOffset);
}else{
std::vector<float> biasData;
biasData.resize(biasElementNum, 0);
auto bias = mOp->main_as_Convolution2D()->bias();
if (nullptr != bias) {
::memcpy((void *)biasData.data(), (void *)bias->data(), biasElementNum * sizeof(float));
}
this->createStaticFloatTensor("bias", dataType, {(uint32_t)oc}, biasData.data());
}
// oc ic h w ---> h w ic oc
void QNNConvolution::convertWeight(const float * src, float * dst, int oc, int ic, int kernelH, int kernelW) {
for (int o = 0; o < oc; o++) {
for (int i = 0; i < ic; i++) {
for (int h = 0; h < kernelH; h++) {
for (int w = 0; w < kernelW; w++) {
uint32_t srcOffset = w + kernelW * (h + kernelH * (i + ic * o));
uint32_t dstOffset = o + oc * (i + ic * (w + kernelW * h));
dst[dstOffset] = src[srcOffset];
}
}
Qnn_DataType_t floatDatatype = QNN_DATATYPE_FLOAT_32;
if(mBackend->getUseFP16()){
floatDatatype = QNN_DATATYPE_FLOAT_16;
}
this->createStaticFloatTensor("bias", floatDatatype, {(uint32_t)oc}, biasData.data());
}
}

View File

@ -19,12 +19,37 @@ class QNNConvolution : public QNNCommonExecution {
public:
QNNConvolution(Backend *backend, const Op *op) : QNNCommonExecution(backend, op) {}
virtual ErrorCode onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
ErrorCode onEncodeQuant(Tensor * input, Tensor * output, int n, int h, int w, int ic, int oc, std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon);
ErrorCode onEncodeFpAIntBMatMul(Tensor * input, Tensor * output, int n, int h, int w, int ic, int oc);
ErrorCode onEncodeQuantDequantConv(Tensor *input, Tensor *output, const int n, const int ic, const int oc);
private:
void createWeight(Qnn_DataType_t dataType, int oc, int ic, int kernelH, int kernelW, int group, const float * source);
void createBias(Qnn_DataType_t dataType, int oc);
void convertWeight(const float * src, float * dst, int oc, int ic, int kernelH, int kernelW);
template <typename T>
void convertWeight(const T * src, T * dst, int oc, int ic, int kernelH, int kernelW) {
for (int o = 0; o < oc; o++) {
for (int i = 0; i < ic; i++) {
for (int h = 0; h < kernelH; h++) {
for (int w = 0; w < kernelW; w++) {
uint32_t srcOffset = w + kernelW * (h + kernelH * (i + ic * o));
uint32_t dstOffset = o + oc * (i + ic * (w + kernelW * h));
dst[dstOffset] = src[srcOffset];
}
}
}
}
}
void isWeightQuantSupported(const Tensor *input, const int ic, const int oc);
bool createWeightAndBias(Qnn_DataType_t dataType, const Tensor *input, int oc, int ic, int kernelH, int kernelW, int group);
void createBias(Qnn_DataType_t dataType, int oc, const Tensor *input, std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon);
std::vector<float> mScale;
std::vector<Qnn_ScaleOffset_t> mScaleOffsetData;
std::vector<Qnn_ScaleOffset_t> mBiasScaleOffsetData;
std::vector<uint8_t> mBlockScale;
Qnn_BlockwiseExpansion_t weightBlockwiseExpansionEncoding = QNN_BLOCKWISE_EXPANSION_INIT;
float *mDequantAlpha = nullptr;
int mBlockSize = 1;
bool mWeightQuant = false;
bool mIsMatMul = false;
bool mIs1x1Conv = false;
};
} // end namespace QNN

View File

@ -0,0 +1,81 @@
//
// QNNQuant.cpp
// MNN
//
//  Created by MNN on 2025/05/29.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "QNNQuant.hpp"
namespace MNN {
namespace QNN {
ErrorCode QNNQuant::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
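    // FloatToInt8: cast the (possibly fp16) activation to fp32 first, then let a QNN Quantize node
    // write the int8 output tensor.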
this->createStageTensor("Cast", QNN_DATATYPE_FLOAT_32, getNHWCShape(outputs[0]));
// Stage one fp16 -> fp32
{
mNodeType = "Cast";
std::string name = mNodeName + "_Cast";
mInputs.push_back(*(mBackend->getNativeTensor(inputs[0]))); // input
mOutputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // stage tensor
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
// Stage two fp32 -> int8
{
mNodeType.clear();
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = "Quantize";
std::string name = mNodeName;
mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // stage tensor
mOutputs.push_back(*(mBackend->getNativeTensor(outputs[0]))); // output
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
return NO_ERROR;
}
class QNNQuantCreator : public QnnBackend::Creator {
public:
virtual QNNCommonExecution * onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, const MNN::Op* op,
Backend* backend) const override {
return new QNNQuant(backend, op);
}
};
ErrorCode QNNDeQuant::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
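    // Int8ToFloat needs no extra cast stage: the Dequantize node already produces the graph's
    // native float precision.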
// Stage one int8 -> fp16
{
mNodeType.clear();
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = "Dequantize";
std::string name = mNodeName;
mInputs.push_back(*(mBackend->getNativeTensor(inputs[0]))); // input
mOutputs.push_back(*(mBackend->getNativeTensor(outputs[0]))); // output
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
return NO_ERROR;
}
class QNNDeQuantCreator : public QnnBackend::Creator {
public:
virtual QNNCommonExecution * onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, const MNN::Op* op,
Backend* backend) const override {
return new QNNDeQuant(backend, op);
}
};
REGISTER_QNN_OP_CREATOR(QNNQuantCreator, OpType_FloatToInt8)
REGISTER_QNN_OP_CREATOR(QNNDeQuantCreator, OpType_Int8ToFloat)
} // end namespace QNN
} // end namespace MNN

View File

@ -0,0 +1,32 @@
//
// QNNQuant.hpp
// MNN
//
//  Created by MNN on 2025/05/29.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifndef MNN_QNNQUANT_HPP
#define MNN_QNNQUANT_HPP
#include "QNNCommonExecution.hpp"
namespace MNN {
namespace QNN {
class QNNQuant : public QNNCommonExecution {
public:
QNNQuant(Backend *backend, const Op *op) : QNNCommonExecution(backend, op) {};
virtual ErrorCode onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
};
class QNNDeQuant : public QNNCommonExecution {
public:
QNNDeQuant(Backend *backend, const Op *op) : QNNCommonExecution(backend, op) {};
virtual ErrorCode onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
};
} // end namespace QNN
} // end namespace MNN
#endif // end MNN_QNNQUANT_HPP

View File

@ -28,11 +28,26 @@ ErrorCode QNNScale::onEncode(const std::vector<Tensor *> &inputs, const std::vec
MNN_ASSERT(channel == mWeightData.size());
Qnn_DataType_t dataType = mBackend->getNativeTensor(inputs[0])->v1.dataType;
mNeedQuantDequant = dataType != QNN_DATATYPE_FLOAT_16 && dataType != QNN_DATATYPE_FLOAT_32;
if(mNeedQuantDequant){
Qnn_DataType_t tempDataType = QNN_DATATYPE_FLOAT_32;
if(mBackend->getUseFP16()){
tempDataType = QNN_DATATYPE_FLOAT_16;
}
this->createStaticFloatTensor("weight", tempDataType, {(uint32_t)channel}, mWeightData.data());
this->createStaticFloatTensor("bias", tempDataType, {(uint32_t)channel}, mBiasData.data());
this->createStageTensor("Stage", tempDataType, getNHWCShape(inputs[0]));
this->createStageTensor("Stage_dequantize_input", tempDataType, getNHWCShape(inputs[0]));
this->createStageTensor("Stage_add_output", tempDataType, getNHWCShape(outputs[0]));
if(mBackend->getUseFP16()){
this->createStageTensor("Stage_cast_output", QNN_DATATYPE_FLOAT_32, getNHWCShape(outputs[0]));
}
}else{
this->createStaticFloatTensor("weight", dataType, {(uint32_t)channel}, mWeightData.data());
this->createStaticFloatTensor("bias", dataType, {(uint32_t)channel}, mBiasData.data());
this->createStageTensor("Stage", dataType, getNHWCShape(inputs[0]));
}
}
// add nodes
this->mulWeight(inputs[0]);
@ -42,36 +57,100 @@ ErrorCode QNNScale::onEncode(const std::vector<Tensor *> &inputs, const std::vec
}
void QNNScale::mulWeight(Tensor * input) {
mNodeType = "ElementWiseMultiply";
std::string name = mNodeName + "_mul";
Qnn_DataType_t dataType = mBackend->getNativeTensor(input)->v1.dataType;
// need dequantize to float16
if(mNeedQuantDequant){
mNodeType.clear();
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = "Dequantize";
std::string name = mNodeName + "_Dequantize";
mInputs.push_back(*(mBackend->getNativeTensor(input))); // input
mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); //Stage_dequantize_input
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
{
mNodeType.clear();
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = "ElementWiseMultiply";
std::string name = mNodeName + "_mul";
if(mNeedQuantDequant){
mInputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); //Stage_dequantize_input
}else{
mInputs.push_back(*(mBackend->getNativeTensor(input)));
}
mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor()));
mOutputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor()));
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
}
void QNNScale::addBias(Tensor * output) {
mNodeType = "ElementWiseAdd";
std::string name = mNodeName + "_add";
Qnn_DataType_t dataType = mBackend->getNativeTensor(output)->v1.dataType;
{
mNodeType.clear();
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = "ElementWiseAdd";
std::string name = mNodeName + "_add";
mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor()));
mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor()));
if(mNeedQuantDequant){
mOutputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // Stage_add_output
}else{
mOutputs.push_back(*(mBackend->getNativeTensor(output)));
}
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
// need quantize output
if(mNeedQuantDequant){
// Stage one fp16 -> fp32
if(mBackend->getUseFP16()){
mNodeType.clear();
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = "Cast";
std::string name = mNodeName + "_Cast";
mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // Stage_add_output
mOutputs.push_back(*(mTempTensorWrappers[5]->getNativeTensor())); // Stage_cast_output
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
// Stage two fp32 -> int8
{
mNodeType.clear();
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = "Quantize";
std::string name = mNodeName + "_Quantize";
if(mBackend->getUseFP16()){
mInputs.push_back(*(mTempTensorWrappers[5]->getNativeTensor())); // Stage_cast_output
}else{
mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // Stage_add_output
}
mOutputs.push_back(*(mBackend->getNativeTensor(output))); // output
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
}
}
ErrorCode QNNScale::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
std::string nodeNameBase = "Scale";
nodeNameBase += "_";

View File

@ -25,6 +25,7 @@ private:
private:
std::vector<float> mWeightData;
std::vector<float> mBiasData;
bool mNeedQuantDequant = false;
};
} // end namespace QNN

View File

@ -34,12 +34,17 @@ struct RuntimeHint {
int cpuDecreaseRate = 50;
int dynamicQuantOption = 0;
// qkvQuantOption % 8:
// 0: Do not quantize
// 1: Only quantize key, use int8 asymmetric quantization
// 2: Only quantize value, use fp8 quantization
// 3: quantize both key and value
// 4: quantize query, key and value, and use gemm int8 kernel to compute K*V
int qkvQuantOption = 0;
// qkvQuantOption / 8:
// 1: use flash attention
int qkvQuantOption = 8;
// the kvcache size limit of each layer
// if the size of kvcache in memory exceeds the limit
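The flag layout above packs two independent choices into one hint: the low three bits select the KV quantization scheme and the next bit toggles flash attention (hence the new default of 8). A minimal sketch of composing the value and passing it as a runtime hint, assuming configuration goes through `Executor::RuntimeManager::setHint` as the LLM engine does elsewhere in this commit; the chosen combination is illustrative only:
```cpp
#include <memory>
#include <MNN/Interpreter.hpp>
#include <MNN/expr/Executor.hpp>

// Sketch: qkvQuantOption % 8 picks the KV quantization scheme,
// qkvQuantOption / 8 == 1 enables flash attention.
static int makeQkvQuantOption(bool flashAttention, int kvQuantScheme /* 0..4 */) {
    return (flashAttention ? 8 : 0) + kvQuantScheme;
}

void configureQkv(std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtmgr) {
    // e.g. flash attention + int8 asymmetric key quantization
    rtmgr->setHint(MNN::Interpreter::QKV_QUANT_OPTIONS, makeQkvQuantOption(true, 1));
}
```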
@ -403,6 +408,9 @@ public:
info.mode = Backend::Info::DIRECT;
return true;
}
virtual bool onGetDeviceInfo(const std::string& deviceKey, std::string& deviceValue) const {
return false;
}
protected:
/**
@brief deinitializer.

View File

@ -129,6 +129,13 @@ static int dumpModelInfo(const char* modelName) {
} else {
MNN_PRINT("Model Version: %s \n", info->version.c_str());
}
if (!info->metaData.empty()) {
MNN_PRINT("MetaData: Begin \n");
for (auto& iter : info->metaData) {
MNN_PRINT("[Meta] %s : %s\n", iter.first.c_str(), iter.second.c_str());
}
MNN_PRINT("MetaData: End \n");
}
return 0;
}

View File

@ -130,7 +130,17 @@ bool CompleteSubGraph(const std::unordered_map<std::string, VARP>& inputs, const
return true;
}
static bool _hasDupName(std::unique_ptr<MNN::NetT>& originNet) {
std::set<std::string> names;
for (auto& tensorName : originNet->tensorName) {
if (names.find(tensorName) != names.end()) {
MNN_ERROR("Repeat name %s\n", tensorName.c_str());
return true;
}
names.insert(tensorName);
}
return false;
}
void RunNetPass(const std::vector<std::string>& passes, std::unique_ptr<MNN::NetT>& originNet) {
for (auto pass : passes) {
auto convert = PostConverter::get(pass);
@ -138,7 +148,14 @@ void RunNetPass(const std::vector<std::string>& passes, std::unique_ptr<MNN::Net
LOG(INFO) << "Can't find pass of " << pass << "\n";
continue;
}
auto originSize = originNet->oplists.size();
bool valid = convert->onExecute(originNet);
#ifdef DEBUG
auto hasDup = _hasDupName(originNet);
if (originSize != originNet->oplists.size() || hasDup) {
MNN_PRINT("%s: %d -> %d, dup: %d\n", pass.c_str(), originSize, originNet->oplists.size(), hasDup);
}
#endif
if (!valid) {
LOG(INFO) << "Run " << pass << "Error\n";
}

View File

@ -19,7 +19,7 @@ static EXPRP clipConvert(EXPRP expr, bool supportRelu6) {
auto extraParam = op->main_as_Extra();
// auto dataType = expr->outputInfo(0)->type.code;
auto maxValue = std::numeric_limits<T>().max();
auto minValue = std::numeric_limits<T>().min();
auto minValue = std::numeric_limits<T>().lowest();
if (nullptr != extraParam->attr()) {
const int attrSize = extraParam->attr()->size();
for (int i = 0; i < attrSize; ++i) {
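The switch to `lowest()` above matters because, for floating-point `T`, `std::numeric_limits<T>::min()` is the smallest positive normal value rather than the most negative one, so the old code clipped against a near-zero lower bound. A tiny standalone check, independent of the converter code:
```cpp
#include <cstdio>
#include <limits>

int main() {
    // min() vs lowest() for float: only lowest() is a usable clip lower bound.
    std::printf("min()    = %g\n", std::numeric_limits<float>::min());    // ~1.17549e-38 (smallest positive normal)
    std::printf("lowest() = %g\n", std::numeric_limits<float>::lowest()); // ~-3.40282e+38 (most negative finite)
    return 0;
}
```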

View File

@ -12,20 +12,70 @@
class RemoveCopy : public PostConverter {
public:
virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override {
auto config = Global<modelConfig>::Get();
if (config->optimizeLevel < 1 || config->inSubGraph) {
return true;
std::set<std::string> netOutputNames;
for (auto& t : net->outputName) {
netOutputNames.insert(t);
}
for (auto iter = net->oplists.begin(); iter != net->oplists.end(); iter++) {
auto& op = *iter;
if (op->type == MNN::OpType_Input) {
for (auto o : op->outputIndexes) {
netOutputNames.insert(net->tensorName[o]);
}
}
}
auto config = Global<modelConfig>::Get();
for (auto iter = net->oplists.begin(); iter != net->oplists.end();) {
auto& op = *iter;
if (op->type != MNN::OpType_Identity) {
if (op->type != MNN::OpType_Identity || op->inputIndexes.size() != op->outputIndexes.size()) {
iter++;
continue;
}
bool hasOutputName = false;
for (auto o : op->outputIndexes) {
if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) {
hasOutputName = true;
break;
}
}
bool hasOutputFromInput = false;
for (auto o : op->inputIndexes) {
if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) {
hasOutputFromInput = true;
break;
}
}
if (hasOutputFromInput && hasOutputName) {
iter++;
continue;
}
auto originInput = op->inputIndexes;
auto originOutputs = op->outputIndexes;
MNN_ASSERT(originInput.size() == originOutputs.size());
if (hasOutputName) {
bool valid = true;
for (int i=0; i<op->inputIndexes.size(); ++i) {
auto o = op->outputIndexes[i];
auto originInput = op->inputIndexes[i];
if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) {
if (netOutputNames.find(net->tensorName[originInput]) != netOutputNames.end()) {
valid = false;
break;
}
auto originName = net->tensorName[originInput];
net->tensorName[originInput] = net->tensorName[o];
net->tensorName[o] = originName;
}
}
if (!valid) {
continue;
}
}
std::map<int, int> replaceIndexes;
for (int i=0; i<op->inputIndexes.size();++i) {
replaceIndexes.insert(std::make_pair(op->outputIndexes[i], op->inputIndexes[i]));
net->tensorName[op->inputIndexes[i]] = net->tensorName[op->outputIndexes[i]];
}
for (auto subIter = net->oplists.begin(); subIter != net->oplists.end(); subIter++) {
auto& subOp = *subIter;
@ -35,14 +85,6 @@ public:
}
}
}
for (int v=0; v<op->inputIndexes.size(); ++v) {
for (auto& o : net->outputName) {
if (o == net->tensorName[op->inputIndexes[v]]) {
o = net->tensorName[op->outputIndexes[v]];
break;
}
}
}
iter = net->oplists.erase(iter);
}
return true;

View File

@ -0,0 +1,149 @@
#include "RemoveTestNoUseOps.hpp"
bool RemoveTestNoUseOps::onExecute(std::unique_ptr<MNN::NetT>& net) const {
const MNN::NetT* const netPtr = net.get();
std::set<std::string> netOutputNames;
for (auto& t : net->outputName) {
netOutputNames.insert(t);
}
for (auto iter = net->oplists.begin(); iter != net->oplists.end(); iter++) {
auto& op = *iter;
if (op->type == OpType_Input) {
for (auto o : op->outputIndexes) {
netOutputNames.insert(net->tensorName[o]);
}
}
}
std::unordered_set<int> removedInputs;
for (auto iter = net->oplists.begin(); iter != net->oplists.end();) {
auto& op = *iter;
bool shouldDelete = shouldDeleteJudge(op.get(), netPtr);
if (!shouldDelete) {
iter++;
continue;
}
bool hasOutputName = false;
for (auto o : op->outputIndexes) {
if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) {
hasOutputName = true;
break;
}
}
bool hasOutputFromInput = false;
for (auto o : op->inputIndexes) {
if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) {
hasOutputFromInput = true;
break;
}
}
if (hasOutputFromInput && hasOutputName) {
iter++;
continue;
}
bool deleteOutput = shouldDeleteOutput(op.get());
// Find the next op
if (op->outputIndexes.empty() || op->inputIndexes.empty()) {
iter = net->oplists.erase(iter);
continue;
}
auto originInput = op->inputIndexes[0];
auto originOutputs = op->outputIndexes;
if ((!deleteOutput) && hasOutputName) {
bool valid = true;
for (auto o : originOutputs) {
if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) {
if (netOutputNames.find(net->tensorName[originInput]) != netOutputNames.end()) {
valid = false;
break;
}
net->tensorName[originInput] = net->tensorName[o];
}
}
if (!valid) {
continue;
}
}
for (auto subIter = net->oplists.begin(); subIter != net->oplists.end(); subIter++) {
auto& subOp = *subIter;
if (deleteOutput) {
for (auto iter=subOp->inputIndexes.begin(); iter != subOp->inputIndexes.end();) {
if (std::find(originOutputs.begin(), originOutputs.end(), *iter) != originOutputs.end()) {
iter = subOp->inputIndexes.erase(iter);
continue;
}
iter++;
}
} else {
for (int v = 0; v < subOp->inputIndexes.size(); ++v) {
if (std::find(originOutputs.begin(), originOutputs.end(), subOp->inputIndexes[v]) != originOutputs.end()) {
subOp->inputIndexes[v] = originInput;
}
}
}
}
bool removeUselessInput = shouldRemoveUnusefulInputs(op.get());
if (removeUselessInput) {
for (int input : op->inputIndexes) {
removedInputs.emplace(input);
}
}
iter = net->oplists.erase(iter);
}
// Remove the op only if the reference counts of all of its outputs
// are reduced to zero.
std::unordered_map<int, int/*reference count*/> uselessIndex;
for (const auto& op : net->oplists) {
for (int input : op->inputIndexes) {
auto it = uselessIndex.find(input);
if (it == uselessIndex.end()) {
uselessIndex.emplace(input, 1);
} else {
++it->second;
}
}
}
// Set reference count 1 for all net outputs.
for (const auto& op : net->oplists) {
for (int output : op->outputIndexes) {
auto it = uselessIndex.find(output);
if (it == uselessIndex.end()) {
if (removedInputs.count(output)) {
uselessIndex.emplace(output, 0);
} else {
uselessIndex.emplace(output, 1);
}
}
}
}
bool needIteration = false;
do {
needIteration = false;
for (auto iter = net->oplists.begin(); iter != net->oplists.end();) {
auto& op = *iter;
bool useless = true;
for (auto index : op->outputIndexes) {
if (uselessIndex.at(index) > 0) {
useless = false;
break;
}
}
if (!useless) {
iter++;
continue;
}
if (!op->inputIndexes.empty()) {
for (auto index : op->inputIndexes) {
auto it = uselessIndex.find(index);
MNN_ASSERT(it != uselessIndex.end());
--it->second;
}
needIteration = true;
}
iter = net->oplists.erase(iter);
}
} while (needIteration);
return true;
}

View File

@ -25,135 +25,5 @@ public:
virtual bool shouldDeleteOutput(const MNN::OpT* op) const = 0;
virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override {
const MNN::NetT* const netPtr = net.get();
std::set<std::string> netOutputNames;
for (auto& t : net->outputName) {
netOutputNames.insert(t);
}
std::unordered_set<int> removedInputs;
for (auto iter = net->oplists.begin(); iter != net->oplists.end();) {
auto& op = *iter;
bool shouldDelete = shouldDeleteJudge(op.get(), netPtr);
if (!shouldDelete) {
iter++;
continue;
}
bool hasOutputName = false;
for (auto o : op->outputIndexes) {
if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) {
hasOutputName = true;
break;
}
}
bool hasOutputFromInput = false;
for (auto o : op->inputIndexes) {
if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) {
hasOutputFromInput = true;
break;
}
}
if (hasOutputFromInput && hasOutputName) {
iter++;
continue;
}
bool deleteOutput = shouldDeleteOutput(op.get());
// Find the next op
if (op->outputIndexes.empty() || op->inputIndexes.empty()) {
iter = net->oplists.erase(iter);
continue;
}
auto originInput = op->inputIndexes[0];
auto originOutputs = op->outputIndexes;
if ((!deleteOutput) && hasOutputName) {
for (auto o : originOutputs) {
if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) {
net->tensorName[originInput] = net->tensorName[o];
}
}
}
for (auto subIter = net->oplists.begin(); subIter != net->oplists.end(); subIter++) {
auto& subOp = *subIter;
if (deleteOutput) {
for (auto iter=subOp->inputIndexes.begin(); iter != subOp->inputIndexes.end();) {
if (std::find(originOutputs.begin(), originOutputs.end(), *iter) != originOutputs.end()) {
iter = subOp->inputIndexes.erase(iter);
continue;
}
iter++;
}
} else {
for (int v = 0; v < subOp->inputIndexes.size(); ++v) {
if (std::find(originOutputs.begin(), originOutputs.end(), subOp->inputIndexes[v]) != originOutputs.end()) {
subOp->inputIndexes[v] = originInput;
}
}
}
}
bool removeUselessInput = shouldRemoveUnusefulInputs(op.get());
if (removeUselessInput) {
for (int input : op->inputIndexes) {
removedInputs.emplace(input);
}
}
iter = net->oplists.erase(iter);
}
// Remove the op only if the reference counts of all of its outputs
// are reduced to zero.
std::unordered_map<int, int/*reference count*/> uselessIndex;
for (const auto& op : net->oplists) {
for (int input : op->inputIndexes) {
auto it = uselessIndex.find(input);
if (it == uselessIndex.end()) {
uselessIndex.emplace(input, 1);
} else {
++it->second;
}
}
}
// Set reference count 1 for all net outputs.
for (const auto& op : net->oplists) {
for (int output : op->outputIndexes) {
auto it = uselessIndex.find(output);
if (it == uselessIndex.end()) {
if (removedInputs.count(output)) {
uselessIndex.emplace(output, 0);
} else {
uselessIndex.emplace(output, 1);
}
}
}
}
bool needIteration = false;
do {
needIteration = false;
for (auto iter = net->oplists.begin(); iter != net->oplists.end();) {
auto& op = *iter;
bool useless = true;
for (auto index : op->outputIndexes) {
if (uselessIndex.at(index) > 0) {
useless = false;
break;
}
}
if (!useless) {
iter++;
continue;
}
if (!op->inputIndexes.empty()) {
for (auto index : op->inputIndexes) {
auto it = uselessIndex.find(index);
MNN_ASSERT(it != uselessIndex.end());
--it->second;
}
needIteration = true;
}
iter = net->oplists.erase(iter);
}
} while (needIteration);
return true;
}
virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override;
};

View File

@ -33,6 +33,11 @@ public:
return true;
}
}
if (op->type == OpType_Identity) {
// Support 1->N
return op->inputIndexes.size() == 1 && op->outputIndexes.size() > 1;
}
if (op->type == OpType_Cast) {
if (op->main.AsCastParam()->dstT == op->main.AsCastParam()->srcT) {
return true;

View File

@ -37,11 +37,21 @@ int main(int argc, const char* argv[]) {
/**
generate qnn .cpp and .bin
*/
std::string dstModelName = dstMNN;
size_t pos = dstModelName.find_last_of("/\\");
std::string dstModelPath;
if (pos == std::string::npos) {
// current path
dstModelPath = "./";
} else {
dstModelPath = dstModelName.substr(0, pos);
}
std::string qnnModelPath = dstModelPath + "/" + qnnModelName;
MNN_PRINT("[Temp Product]: Qnn temp product generate at %s\n", qnnModelPath.c_str());
MNN::ScheduleConfig config;
config.type = MNN_FORWARD_NN;
std::shared_ptr<Executor::RuntimeManager> rtmgr(Executor::RuntimeManager::createRuntimeManager(config));
rtmgr->setCache(qnnModelName.c_str());
rtmgr->setCache(qnnModelPath.c_str());
MNN::Express::Module::Config mConfig;
mConfig.shapeMutable = false;
std::shared_ptr<MNN::Express::Module> m(MNN::Express::Module::load(inputNames, outputNames, srcMNN, rtmgr, &mConfig), MNN::Express::Module::destroy);
@ -64,7 +74,7 @@ int main(int argc, const char* argv[]) {
}
int ret = 0;
std::string tarBinCmd = "cd " + qnnModelName + \
std::string tarBinCmd = "cd " + qnnModelPath + \
" && " + \
"tar -cf " + qnnModelName + ".bin *.raw";
ret = system(tarBinCmd.c_str());
@ -74,10 +84,10 @@ int main(int argc, const char* argv[]) {
}
std::string modelLibCmd = qnnSdkPath + "/bin/x86_64-linux-clang/qnn-model-lib-generator " + \
"-c " + qnnModelName + "/" + qnnModelName + ".cpp " + \
"-b " + qnnModelName + "/" + qnnModelName + ".bin " + \
"-c " + qnnModelPath + "/" + qnnModelName + ".cpp " + \
"-b " + qnnModelPath + "/" + qnnModelName + ".bin " + \
"-t x86_64-linux-clang " + \
"-o " + qnnModelName + "/lib ";
"-o " + qnnModelPath + "/lib ";
ret = system(modelLibCmd.c_str());
if(ret) {
MNN_ERROR("[Error]: qnn-model-lib-generator error!\n");
@ -86,12 +96,13 @@ int main(int argc, const char* argv[]) {
MNN_PRINT("[Pass]: qnn-model-lib-generator success!\n");
}
std::string qnnBin = dstModelPath + "/" + qnnModelName + ".bin";
std::string binaryGenCmd = qnnSdkPath + "/bin/x86_64-linux-clang/qnn-context-binary-generator " + \
"--model " + qnnModelName + "/lib/x86_64-linux-clang/lib" + qnnModelName + ".so " + \
"--model " + qnnModelPath + "/lib/x86_64-linux-clang/lib" + qnnModelName + ".so " + \
"--backend " + qnnSdkPath + "/lib/x86_64-linux-clang/libQnnHtp.so " + \
"--binary_file " + qnnModelName + " " + \
"--config_file " + qnnContextConfig + " " + \
"--output_dir " + qnnModelName + "/binary";
"--output_dir " + dstModelPath;
ret = system(binaryGenCmd.c_str());
if(ret) {
MNN_ERROR("[Error]: qnn-context-binary-generator error!\n");
@ -123,6 +134,7 @@ int main(int argc, const char* argv[]) {
}
std::string npuPath = std::string("/") + qnnModelName + std::string(".bin");
MNN_PRINT("npu model path:%s\n", npuPath.c_str());
/** Fuse to Op*/
std::unique_ptr<MNN::OpT> op(new OpT);
@ -204,7 +216,7 @@ int main(int argc, const char* argv[]) {
}
for (int i=0; i<outputInfos.size(); ++i) {
attr.reset(new MNN::AttributeT);
attr->key = "o_" + std::to_string(i) + "_0";
attr->key = "o_0_" + std::to_string(i);
attr->tensor.reset(new BlobT);
attr->tensor->dataType = OpCommonUtils::convertDataType(outputInfos[i].type);
attr->tensor->dims = outputInfos[i].dim;
@ -240,7 +252,6 @@ int main(int argc, const char* argv[]) {
outputOs.close();
MNN_PRINT("[All Pass]: npu model generator success!\n");
std::string qnnBin = qnnModelName + "/binary/" + qnnModelName + ".bin";
MNN_PRINT("[Output Product]:\nNew mnn model path: %s\nNpu model path: %s\n", dstMNN, qnnBin.c_str());
return 0;
}

View File

@ -43,6 +43,6 @@ for w in gWrong:
print(w)
print('TEST_NAME_MODULE: 模型测试\nTEST_CASE_AMOUNT_MODULE: {\"blocked\":0,\"failed\":%d,\"passed\":%d,\"skipped\":0}\n'%(len(gWrong), total_num - len(gWrong)))
print('TEST_CASE={\"name\":\"Onnx转换测试\",\"failed\":%d,\"passed\":%d}\n'%(len(gWrong), total_num - len(gWrong)))
print('Total Size: ', mnnsize,' MB, convert ', correct_num," model")
print('Total Size: ', total_size,' MB, convert ', correct_num," model")
if len(gWrong) > 0:
exit(1)

View File

@ -0,0 +1,168 @@
import json
import copy
import argparse
import os
import subprocess
import shutil  # used to remove temporary directories
def generate_all_configs(config_path, graph_name, qnn_sdk_root_path, src_model, executable_path, output_dir):
"""
为每个组合创建子目录生成配置文件并调用C++可执行文件进行模型转换
"""
# --- 0. 准备工作 ---
# 创建主输出目录
os.makedirs(output_dir, exist_ok=True)
print(f"所有生成的文件将被保存在主目录: '{output_dir}'")
# 定义组合
combinations = [
[36, 'v69'],
[42, 'v69'],
[43, 'v73'],
[57, 'v75'],
[69, 'v79']
]
# --- 1. Read the template files ---
htp_template_file = os.path.join(config_path, "htp_backend_extensions.json")
context_template_file = os.path.join(config_path, "context_config.json")
try:
with open(htp_template_file, 'r', encoding='utf-8') as f:
base_htp_data = json.load(f)
print(f"Successfully read template file '{htp_template_file}'")
with open(context_template_file, 'r', encoding='utf-8') as f:
base_context_data = json.load(f)
print(f"Successfully read template file '{context_template_file}'")
except FileNotFoundError as e:
print(f"Error: template file not found. Please make sure '{e.filename}' exists under the given path.")
return
except json.JSONDecodeError as e:
print(f"Error: invalid file format. Please check whether {e.doc} is valid JSON.")
return
# --- 2. Iterate over the combinations, generate files and run the command ---
for soc_id, dsp_arch in combinations:
print(f"\n{'='*15} Processing combination: soc_id={soc_id}, dsp_arch={dsp_arch} {'='*15}")
# --- Extra step: create a dedicated sub-directory for the current combination ---
new_graph_name = f"{graph_name}_{soc_id}_{dsp_arch}"
graph_specific_dir = output_dir
# --- Part A: generate the htp_backend_extensions file (paths updated) ---
htp_config_data = copy.deepcopy(base_htp_data)
try:
htp_config_data["graphs"][0]["graph_names"] = [new_graph_name]
htp_config_data["devices"][0]["soc_id"] = soc_id
htp_config_data["devices"][0]["dsp_arch"] = dsp_arch
except (IndexError, KeyError) as e:
print(f"处理组合时出错: '{htp_template_file}' 结构不正确。错误: {e}")
continue
htp_output_filename = f"htp_backend_extensions_{soc_id}_{dsp_arch}.json"
# Update the path so that it points to the new sub-directory
htp_output_filepath = os.path.join(graph_specific_dir, htp_output_filename)
with open(htp_output_filepath, 'w', encoding='utf-8') as f:
json.dump(htp_config_data, f, indent=4, ensure_ascii=False)
print(f"-> Generated config file: '{htp_output_filepath}'")
# --- Part B: generate the context_config file (paths updated) ---
context_config_data = copy.deepcopy(base_context_data)
try:
# htp_output_filename here is a relative path, which is correct,
# because context_config and htp_backend_extensions live in the same directory.
context_config_data["backend_extensions"]["config_file_path"] = htp_output_filepath
path_template = context_config_data["backend_extensions"]["shared_library_path"]
new_lib_path = path_template.replace("{QNN_SDK_ROOT}", qnn_sdk_root_path)
context_config_data["backend_extensions"]["shared_library_path"] = new_lib_path
except KeyError as e:
print(f"处理组合时出错: '{context_template_file}' 结构不正确,缺少键: {e}")
continue
context_output_filename = f"context_config_{soc_id}_{dsp_arch}.json"
# Update the path so that it points to the new sub-directory
context_output_filepath = os.path.join(graph_specific_dir, context_output_filename)
with open(context_output_filepath, 'w', encoding='utf-8') as f:
json.dump(context_config_data, f, indent=4, ensure_ascii=False)
print(f"-> Generated companion file: '{context_output_filepath}'")
# --- Part C: invoke the C++ conversion executable (paths updated) ---
dst_model_filename = f"{graph_name}_{soc_id}_{dsp_arch}.mnn"
# Update the path so that it points to the new sub-directory
dst_model_filepath = os.path.join(graph_specific_dir, dst_model_filename)
graph_product_dir = os.path.join(graph_specific_dir, new_graph_name)
os.makedirs(graph_product_dir, exist_ok=True)
print(f"-> 已创建/确认子目录: '{graph_product_dir}'")
command = [
executable_path,
src_model,
dst_model_filepath,
qnn_sdk_root_path,
new_graph_name,
context_output_filepath
]
print("--> 准备执行命令...")
print(f" $ {' '.join(command)}")
try:
result = subprocess.run(command, check=True, capture_output=True, text=True)
print("--> 命令执行成功!")
# 即使成功,也打印 C++ 程序的输出,这对于查看警告等信息很有用
if result.stdout:
print(" --- C++程序输出 (stdout) ---")
print(result.stdout.strip())
print(" ------------------------------")
except FileNotFoundError:
print(f"!!! 命令执行失败: 可执行文件未找到 '{executable_path}'。请检查路径。")
break # 如果可执行文件找不到,直接退出循环
except subprocess.CalledProcessError as e:
# This is the key error-handling part
print(f"!!! Command failed (return code: {e.returncode})")
# Check and print any stdout the C++ program produced before failing
if e.stdout:
print(" --- C++ program output (stdout) ---")
print(e.stdout.strip())
print(" ------------------------------")
# Check and print any stderr the C++ program produced before failing (error logs usually end up here)
if e.stderr:
print(" --- C++ program errors (stderr) ---")
print(e.stderr.strip())
print(" ------------------------------")
except Exception as e:
print(f"!!! 执行期间发生未知错误: {e}")
finally:
# --- Step 3: clean up ---
# Check whether the directory exists, then remove it
if os.path.exists(graph_product_dir):
print(f"--> Cleaning up temporary files and directory: '{graph_product_dir}'")
shutil.rmtree(graph_product_dir)
else:
print("--> 无需清理,临时目录未创建。")
# --- 脚本执行入口 ---
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="为多个组合创建子目录生成QNN配置文件并调用模型转换工具。",
formatter_class=argparse.RawTextHelpFormatter
)
# ... (the argparse section is kept unchanged) ...
gen_group = parser.add_argument_group('File generation arguments')
gen_group.add_argument("--config_path", required=True, help="[required] Path to the directory containing the template files.")
gen_group.add_argument("--graph_name", required=True, help="[required] Name of the model graph (without the soc_id suffix etc.).")
gen_group.add_argument("--qnn_sdk_root_path", required=True, help="[required] Root path of the QNN SDK.")
exec_group = parser.add_argument_group('Model conversion arguments')
exec_group.add_argument("--src_model", required=True, help="[required] Path to the source model file (e.g. my_model.mnn).")
exec_group.add_argument("--executable_path", required=True, help="[required] Path to the C++ model conversion executable.")
exec_group.add_argument("--output_dir", default="./qnn_models", help="Output directory for all generated files (default: ./qnn_models).")
args = parser.parse_args()
generate_all_configs(args.config_path, args.graph_name, args.qnn_sdk_root_path, args.src_model, args.executable_path, args.output_dir)

View File

@ -51,6 +51,21 @@ if (LLM_SUPPORT_AUDIO AND MNN_BUILD_AUDIO)
add_definitions(-DLLM_SUPPORT_AUDIO)
endif()
IF(CMAKE_SYSTEM_NAME MATCHES "^Android" AND NOT MNN_BUILD_FOR_ANDROID_COMMAND)
IF(NOT NATIVE_INCLUDE_OUTPUT)
set(NATIVE_INCLUDE_OUTPUT ".")
ENDIF()
add_custom_command(
TARGET llm
POST_BUILD
COMMAND ${CMAKE_COMMAND}
ARGS -E copy_directory ${CMAKE_CURRENT_LIST_DIR}/include ${NATIVE_INCLUDE_OUTPUT}
)
ELSE()
INSTALL(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/include/ DESTINATION include FILES_MATCHING PATTERN *.hpp)
ENDIF()
add_executable(llm_demo ${CMAKE_CURRENT_LIST_DIR}/demo/llm_demo.cpp)
add_executable(embedding_demo ${CMAKE_CURRENT_LIST_DIR}/demo/embedding_demo.cpp)
add_executable(reranker_demo ${CMAKE_CURRENT_LIST_DIR}/demo/reranker_demo.cpp)

View File

@ -37,6 +37,23 @@ struct TimePerformance;
using ChatMessage = std::pair<std::string, std::string>; // <role, content>
using ChatMessages = std::vector<ChatMessage>;
struct MNN_PUBLIC PromptImagePart {
MNN::Express::VARP image_data;
int width;
int height;
};
struct MNN_PUBLIC PromptAudioPart {
std::string file_path;
MNN::Express::VARP waveform;
};
struct MNN_PUBLIC MultimodalPrompt {
std::string prompt_template;
std::map<std::string, PromptImagePart> images;
std::map<std::string, PromptAudioPart> audios;
};
enum TuneType {
// op encoder number for commit
OP_ENCODER_NUMBER = 0,
@ -108,11 +125,17 @@ public:
std::string dump_config();
bool set_config(const std::string& content);
Llm* create_lora(const std::string& lora_path);
std::string get_statistics();
// tokenier function
bool is_stop(int token);
std::string tokenizer_decode(int token);
virtual std::vector<int> tokenizer_encode(const std::string& query);
friend class Pipeline;
virtual std::vector<int> tokenizer_encode(const MultimodalPrompt& multimodal_input);
void response(const MultimodalPrompt& multimodal_input,
std::ostream* os = &std::cout,
const char* end_with = nullptr,
int max_new_tokens = -1);
const LlmContext* getContext() const {
return mContext.get();
}
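
A minimal sketch of driving the new multimodal entry points declared above, assuming the `<img>...</img>` placeholder syntax handled by the Omni tokenizer and the `MNN::CV::imread` helper used elsewhere in this commit; the header paths, model handle, image file, and the `img0` key are illustrative only:
```cpp
#include <iostream>
#include <llm/llm.hpp>   // assumed include path for the engine headers
#include <cv/cv.hpp>     // assumed include path for MNN::CV::imread

using namespace MNN::Transformer;

void askAboutImage(Llm* llm) {
    MultimodalPrompt input;
    input.prompt_template = "<img>img0</img>Describe this picture.";

    PromptImagePart image;
    image.image_data = MNN::CV::imread("example.jpg");
    image.width  = 0;  // 0 keeps the vision processor's default resolution
    image.height = 0;
    input.images["img0"] = image;

    llm->response(input, &std::cout, nullptr, /*max_new_tokens*/ 128);
}
```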

View File

@ -48,6 +48,7 @@ DiskEmbedding::DiskEmbedding(const std::shared_ptr<LlmConfig>& config, std::stri
if (mQuantBit != 16) {
if (mQuantBlock == 0) {
mBlockNum = 1;
mQuantBlock = mHiddenSize; // be used for mDequantFunc.
} else {
mBlockNum = mHiddenSize / mQuantBlock;
}

View File

@ -9,6 +9,7 @@
#include <fstream>
#include <iostream>
#include <sstream>
#include <iomanip>
#include <unordered_set>
#include <MNN/AutoTime.hpp>
@ -97,11 +98,44 @@ bool Llm::set_config(const std::string& content) {
return res;
}
std::string Llm::get_statistics() {
auto context = getContext();
int prompt_len = context->prompt_len;
int decode_len = context->gen_seq_len;
int64_t vision_time = context->vision_us;
int64_t audio_time = context->audio_us;
int64_t prefill_time = context->prefill_us;
int64_t decode_time = context->decode_us;
int64_t sample_time = context->sample_us;
float vision_s = vision_time / 1e6;
float audio_s = audio_time / 1e6;
float prefill_s = prefill_time / 1e6;
float decode_s = decode_time / 1e6;
float sample_s = sample_time / 1e6;
float prefill_speed = (prefill_s > 0.0f) ? (prompt_len / prefill_s) : 0.0f;
float decode_speed = (decode_s > 0.0f) ? (decode_len / decode_s) : 0.0f;
std::ostringstream json_stream;
json_stream << "{"
<< "\"prompt_tokens\":" << prompt_len << ","
<< "\"decode_tokens\":" << decode_len << ","
<< "\"vision_time\":" << std::fixed << std::setprecision(2) << vision_s << ","
<< "\"audio_time\":" << std::fixed << std::setprecision(2) << audio_s << ","
<< "\"prefill_time\":" << std::fixed << std::setprecision(2) << prefill_s << ","
<< "\"decode_time\":" << std::fixed << std::setprecision(2) << decode_s << ","
<< "\"sample_time\":" << std::fixed << std::setprecision(2) << sample_s << ","
<< "\"prefill_speed\":" << std::fixed << std::setprecision(2) << prefill_speed << ","
<< "\"decode_speed\":" << std::fixed << std::setprecision(2) << decode_speed
<< "}";
return json_stream.str();
}
void Llm::setRuntimeHint(std::shared_ptr<Express::Executor::RuntimeManager> &rtg) {
rtg->setHint(MNN::Interpreter::INIT_THREAD_NUMBER, 4);
rtg->setHint(MNN::Interpreter::MEM_ALLOCATOR_TYPE, 0);
rtg->setHint(MNN::Interpreter::QKV_QUANT_OPTIONS, mConfig->quant_qkv());
rtg->setHint(MNN::Interpreter::QKV_QUANT_OPTIONS, mConfig->config_.value("quant_qkv", 8));
rtg->setHint(MNN::Interpreter::KVCACHE_SIZE_LIMIT, mConfig->kvcache_limit());
if (mConfig->use_cached_mmap()) {
rtg->setHint(MNN::Interpreter::USE_CACHED_MMAP, 1);
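
The `get_statistics()` helper added above packs the per-request counters into a plain JSON string (times in seconds, speeds in tokens per second). A short usage sketch; the include path is assumed and the printed figures are illustrative, not measured:
```cpp
#include <iostream>
#include <string>
#include <llm/llm.hpp>  // assumed include path

void logStatistics(MNN::Transformer::Llm* llm) {
    // Valid after a response()/generate() call has populated the context counters.
    std::string stats = llm->get_statistics();
    // Illustrative output shape (values are made up):
    // {"prompt_tokens":37,"decode_tokens":128,"vision_time":0.00,"audio_time":0.00,
    //  "prefill_time":0.42,"decode_time":5.96,"sample_time":0.03,
    //  "prefill_speed":88.10,"decode_speed":21.48}
    std::cout << stats << std::endl;
}
```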
@ -113,6 +147,8 @@ void Llm::setRuntimeHint(std::shared_ptr<Express::Executor::RuntimeManager> &rtg
if (mConfig->use_mmap()) {
rtg->setExternalPath(tmpPath, MNN::Interpreter::EXTERNAL_WEIGHT_DIR);
}
// set npu model dir
rtg->setExternalPath(mConfig->npu_model_dir(), 3);
auto dynamicOption = mConfig->dynamic_option();
if (mConfig->dynamic_option()) {
rtg->setHint(MNN::Interpreter::DYNAMIC_QUANT_OPTIONS, mConfig->dynamic_option());
@ -246,8 +282,7 @@ void Llm::load() {
if (needHiddenState) {
outputNames.emplace_back("hidden_states");
}
// set npu model dir
mRuntimeManager->setExternalPath(mConfig->npu_model_dir(), 3);
mRuntimeManager->setExternalFile(mConfig->llm_weight());
mModules[0].reset(Module::load(inputNames, outputNames, model_path.c_str(), mRuntimeManager, &module_config));
mRuntimeManager->setExternalFile("");
@ -594,6 +629,29 @@ std::vector<int> Llm::tokenizer_encode(const std::string& user_content) {
return mTokenizer->encode(user_content);
}
std::vector<int> Llm::tokenizer_encode(const MultimodalPrompt& multimodal_input) {
return mTokenizer->encode(multimodal_input.prompt_template);
}
void Llm::response(const MultimodalPrompt& multimodal_input,
std::ostream* os, const char* end_with, int max_new_tokens) {
auto prompt = multimodal_input.prompt_template;
if (mConfig->use_template()) {
prompt = mPrompt->applyTemplate(prompt, true);
}
int prompt_len = 0;
int decode_len = 0;
int64_t vision_time = 0;
int64_t audio_time = 0;
int64_t prefill_time = 0;
int64_t decode_time = 0;
int64_t sample_time = 0;
std::vector<int> input_ids = tokenizer_encode(multimodal_input);
response(input_ids, os, end_with, max_new_tokens);
}
std::vector<int> Llm::generate(MNN::Express::VARP input_embeds, int max_tokens) {
if (max_tokens < 0) {
max_tokens = mConfig->max_new_tokens();
@ -621,8 +679,8 @@ std::vector<int> Llm::generate(MNN::Express::VARP input_embeds, int max_tokens)
}
{
std::ofstream outFile("logits.txt");
auto temp = outputs[0]->readMap<float>();
for (size_t i = 0; i < outputs[0]->getInfo()->size; ++i) {
auto temp = mGenerateParam->outputs[0]->readMap<float>();
for (size_t i = 0; i < mGenerateParam->outputs[0]->getInfo()->size; ++i) {
outFile << temp[i] << " "; // separate each number with a space
}
outFile.close();

View File

@ -359,10 +359,6 @@ public:
return config_.value("memory", "low");
}
int quant_qkv() const {
return config_.value("quant_qkv", 0);
}
int kvcache_limit() const {
return config_.value("kvcache_limit", -1);
}

View File

@ -9,6 +9,7 @@
#define _USE_MATH_DEFINES
#endif
#include <regex>
#include <algorithm>
#include <MNN/AutoTime.hpp>
#include <MNN/expr/ExecutorScope.hpp>
#include "omni.hpp"
@ -555,8 +556,16 @@ std::vector<int> Omni::minicpmVisionProcess(VARP image) {
std::vector<int> Omni::visionProcess(const std::string& file) {
#ifdef LLM_SUPPORT_VISION
VARP image = MNN::CV::imread(file);
return visionProcess(image);
#else
return std::vector<int>(0);
#endif
}
std::vector<int> Omni::visionProcess(VARP image) {
#ifdef LLM_SUPPORT_VISION
if (image == nullptr) {
MNN_PRINT("Omni Can't open image: %s\n", file.c_str());
MNN_PRINT("Omni Can't open image\n");
return std::vector<int>(0);
}
Timer _t;
@ -573,7 +582,7 @@ std::vector<int> Omni::visionProcess(const std::string& file) {
} else {
imgIds = defaultVisionProcess(image);
}
mContext->vision_us = _t.durationInUs();
mContext->vision_us += _t.durationInUs();
// set vision number for image idx
mVisionNum += 1;
return imgIds;
@ -591,9 +600,19 @@ std::vector<int> Omni::audioProcess(const std::string& file) {
MNN_PRINT("Omni Can't open audio: %s\n", file.c_str());
return std::vector<int>(0);
}
// int sample_rate = load_res.second;
int wav_len = waveform->getInfo()->dim[0];
int hop_length = 160;
return audioProcess(waveform);
#else
return std::vector<int>(0);
#endif
}
std::vector<int> Omni::audioProcess(MNN::Express::VARP waveform) {
#ifdef LLM_SUPPORT_AUDIO
if (waveform == nullptr) {
MNN_PRINT("Omni Can't process audio: waveform is null\n");
return std::vector<int>(0);
}
Timer _t;
auto input_features = MNN::AUDIO::whisper_fbank(waveform);
VARP audio_embedding;
@ -727,21 +746,30 @@ void Omni::addPositionIds(int t, int h, int w) {
}
}
std::vector<int> Omni::tokenizer_encode(const std::string& prompt) {
// split query
std::vector<int> Omni::tokenizer_encode(const MultimodalPrompt& multimodal_input) {
std::string prompt = multimodal_input.prompt_template;
// MNN_PRINT("tokenizer_encode(MultimodalPrompt) prompt: %s", prompt.c_str());
std::regex multimode_regex("<(img|audio)>(.*?)</\\1>");
std::string::const_iterator searchStart(prompt.cbegin());
std::smatch match;
std::vector<std::string> img_infos;
std::vector<int> ids{};
mPositionIds.clear();
while (std::regex_search(searchStart, prompt.cend(), match, multimode_regex)) {
// std::cout << "img match: " << match[1].str() << std::endl;
auto txt_ids = mTokenizer->encode(match.prefix().str());
addPositionIds(txt_ids.size());
ids.insert(ids.end(), txt_ids.begin(), txt_ids.end());
auto mul_ids = multimodeProcess(match[1].str(), match[2].str());
std::string mode = match[1].str();
std::string content = match[2].str();
std::vector<int> mul_ids;
if (mode == "img") {
mul_ids = processImageContent(content, multimodal_input.images);
// MNN_PRINT("tokenizer_encode(MultimodalPrompt) image mul_ids size: %lu", mul_ids.size());
} else if (mode == "audio") {
mul_ids = processAudioContent(content, multimodal_input.audios);
// MNN_PRINT("tokenizer_encode(MultimodalPrompt) audio mul_ids size: %lu", mul_ids.size());
}
ids.insert(ids.end(), mul_ids.begin(), mul_ids.end());
searchStart = match.suffix().first;
}
@ -753,6 +781,43 @@ std::vector<int> Omni::tokenizer_encode(const std::string& prompt) {
return ids;
}
std::vector<int> Omni::tokenizer_encode(const std::string& prompt) {
MultimodalPrompt multimodal_input;
multimodal_input.prompt_template = prompt;
return tokenizer_encode(multimodal_input);
}
std::vector<int> Omni::processImageContent(const std::string& content, const std::map<std::string, PromptImagePart>& images) {
auto it = images.find(content);
if (it != images.end()) {
if (it->second.height > 0 && it->second.width > 0) {
mVisionHeight = it->second.height;
mVisionWidth = it->second.width;
}
// MNN_PRINT("processImageContent: using placeholder '%s' with size %dx%d", content.c_str(), mVisionWidth, mVisionHeight);
return visionProcess(it->second.image_data);
}
// MNN_PRINT("processImageContent: treating '%s' as file path or URL", content.c_str());
return multimodeProcess("img", content);
}
std::vector<int> Omni::processAudioContent(const std::string& content, const std::map<std::string, PromptAudioPart>& audios) {
auto it = audios.find(content);
if (it != audios.end()) {
// MNN_PRINT("processAudioContent: using placeholder '%s'", content.c_str());
if (it->second.waveform.get() != nullptr) {
return audioProcess(it->second.waveform);
} else if (!it->second.file_path.empty()) {
return audioProcess(it->second.file_path);
} else {
MNN_PRINT("processAudioContent: audio_part has no valid input\n");
return std::vector<int>(0);
}
}
// MNN_PRINT("processAudioContent: treating '%s' as file path", content.c_str());
return multimodeProcess("audio", content);
}
VARP Omni::embedding(const std::vector<int>& input_ids) {
if (input_ids.size() == 1) {
return Llm::embedding(input_ids);
@ -845,21 +910,28 @@ VARP Omni::gen_position_ids(int seq_len) {
auto ptr = positionIds->writeMap<int>();
if (mContext->gen_seq_len > 0) {
for (int i=0; i<seq_len; ++i) {
auto pos = mContext->gen_seq_len + mPositionIds.back() + i;
// auto pos = mContext->gen_seq_len + mPositionIds.back() + i;
auto pos = mContext->all_seq_len + i;
ptr[i + 0] = pos;
ptr[i + seq_len] = pos;
ptr[i + seq_len * 2] = pos;
}
} else {
for (int i = 0; i < seq_len; i++) {
ptr[i] = mPositionIds.mT[i];
ptr[i + seq_len] = mPositionIds.mH[i];
ptr[i + seq_len * 2] = mPositionIds.mW[i];
ptr[i] = mPositionIds.mT[i] + mContext->all_seq_len;
ptr[i + seq_len] = mPositionIds.mH[i] + mContext->all_seq_len;
ptr[i + seq_len * 2] = mPositionIds.mW[i] + mContext->all_seq_len;
}
if (mTalker) {
mTalker->setPostionIds(mPositionIds);
}
}
// // dump position ids
// printf("position_ids = [");
// for (int i = 0; i < seq_len; i++) {
// printf("%d ", ptr[i]);
// }
// printf("]\n");
return positionIds;
}

View File

@ -105,12 +105,14 @@ public:
virtual void load() override;
virtual std::vector<Express::VARP> forwardRaw(Express::VARP hiddenState, Express::VARP mask, Express::VARP inputPos) override;
virtual std::vector<int> tokenizer_encode(const std::string& query) override;
virtual std::vector<int> tokenizer_encode(const MultimodalPrompt& multimodal_input) override;
virtual Express::VARP embedding(const std::vector<int>& input_ids) override;
virtual Express::VARP gen_position_ids(int seq_len) override;
virtual void response(const std::vector<int>& input_ids, std::ostream* os = &std::cout, const char* end_with = nullptr, int max_new_tokens = -1) override;
virtual void setWavformCallback(std::function<bool(const float*, size_t, bool)> callback) override;
virtual void generateWavform() override;
// some models preprocess function
std::vector<int> visionProcess(VARP image);
std::vector<int> defaultVisionProcess(VARP image);
std::vector<int> qwen2VisionProcess(VARP image);
std::vector<int> smolvlmVisionProcess(VARP image);
@ -126,6 +128,9 @@ private:
std::vector<int> multimodeProcess(const std::string& mode, std::string info);
std::vector<int> visionProcess(const std::string& file);
std::vector<int> audioProcess(const std::string& file);
std::vector<int> audioProcess(MNN::Express::VARP waveform);
std::vector<int> processImageContent(const std::string& content, const std::map<std::string, PromptImagePart>& images);
std::vector<int> processAudioContent(const std::string& content, const std::map<std::string, PromptAudioPart>& audios);
std::shared_ptr<Module> mVisionModule, mAudioModule;
std::vector<VARP> mVisionEmbeddings, mAudioEmbeddings;
std::shared_ptr<Talker> mTalker;

View File

@ -225,6 +225,7 @@ class LlmExporter(torch.nn.Module):
self.llm_config['jinja'] = prompt_template['jinja']
# load modules
ModelMapper.do_map(self, self.model, self.model_map['model'])
# rebuild modules
if self.lm_ is None:
out_features, in_features = self.embed_.weight.shape
@ -244,6 +245,7 @@ class LlmExporter(torch.nn.Module):
self.embed = Embedding(embed_copy, self)
else:
self.embed = Embedding(self.embed_, self)
# tie word embeddings
self.tie_word_embeddings = not self.args.seperate_embed and self.lm_.weight.equal(self.embed_.weight)
if self.tie_word_embeddings:
@ -802,14 +804,21 @@ class LlmExporter(torch.nn.Module):
if self.mnn_converter:
fuse_transformer = self.visual.transformer_fuse
native_group_conv = self.visual.group_conv_native
quant_bit_visual = self.visual.quant_bit
quant_block_visual = self.visual.quant_block
if self.args.transformer_fuse:
fuse_transformer = True
if self.args.group_conv_native:
native_group_conv = True
self.mnn_converter.export(vision_onnx, self.visual.quant_bit,
self.visual.quant_block,
if self.args.visual_quant_bit is not None:
quant_bit_visual = self.args.visual_quant_bit
if self.args.visual_quant_block is not None:
quant_block_visual = self.args.visual_quant_block
self.mnn_converter.export(vision_onnx, quant_bit_visual,
quant_block_visual,
transformer_fuse=fuse_transformer,
group_conv_native=native_group_conv)
group_conv_native=native_group_conv,
weight_sym=self.args.visual_sym)
def export_audio(self):
if self.audio is None:
@ -1237,6 +1246,9 @@ def export(path,
'onnx_slim': onnx_slim,
'quant_bit': quant_bit,
'quant_block': quant_block,
'visual_quant_bit': visual_quant_bit,
'visual_quant_block': visual_quant_block,
'visual_sym': visual_sym,
'lm_quant_bit': lm_quant_bit,
'mnnconvert': mnnconvert,
'ppl': ppl,
@ -1276,6 +1288,8 @@ def main():
parser.add_argument('--onnx_slim', action='store_true', help='Whether or not to use onnx-slim.')
parser.add_argument('--quant_bit', type=int, default=4, help='mnn quant bit, 4 or 8, default is 4.')
parser.add_argument('--quant_block', type=int, default=64, help='mnn quant block, 0 mean channle-wise, default is 64.')
parser.add_argument('--visual_quant_bit', type=int, default=None, help='mnn visual quant bit, 4 or 8, default is set in utils/vision.py per vit model.')
parser.add_argument('--visual_quant_block', type=int, default=None, help='mnn visual quant block, default is set in utils/vision.py per vit model.')
parser.add_argument('--lm_quant_bit', type=int, default=None, help='mnn lm_head quant bit, 4 or 8, default is `quant_bit`.')
parser.add_argument('--mnnconvert', type=str, default='../../../build/MNNConvert', help='local mnnconvert path, if invalid, using pymnn.')
parser.add_argument('--ppl', action='store_true', help='Whether or not to get all logits of input tokens.')
@ -1284,9 +1298,10 @@ def main():
parser.add_argument('--transformer_fuse', action='store_true', help='Whether or not to fuse vision transformer op.')
parser.add_argument('--group_conv_native', action='store_true', help='Whether or not to keep native group_conv.')
parser.add_argument('--smooth', action='store_true', help='Whether or not to use smooth quant.')
parser.add_argument('--sym', action='store_true', help='Whether or not to using symmetric quant (without zeropoint), defualt is False.')
parser.add_argument('--seperate_embed', action='store_true', help='For lm and embed shared model, whether or not to sepearte embed to avoid quant, defualt is False, if True, embed weight will be seperate to embeddingbf16.bin.')
parser.add_argument('--lora_split', action='store_true', help='Whether or not export lora split, defualt is False.')
parser.add_argument('--sym', action='store_true', help='Whether or not to using symmetric quant (without zeropoint), default is False.')
parser.add_argument('--visual_sym', action='store_true', help='Whether or not to using symmetric quant (without zeropoint) for visual model, default is False.')
parser.add_argument('--seperate_embed', action='store_true', help='For lm and embed shared model, whether or not to sepearte embed to avoid quant, default is False, if True, embed weight will be seperate to embeddingbf16.bin.')
parser.add_argument('--lora_split', action='store_true', help='Whether or not export lora split, default is False.')
parser.add_argument('--calib_data', type=str, default=None, help='calibration data path, default is `None`, meaning calib data is not used.')
args = parser.parse_args()

View File

@ -51,7 +51,7 @@ class MNNConveter:
os.close(log_fd)
@spinner_run(f'convert onnx model to ')
def onnx2mnn(self, onnx_path, mnn_path, args = [], transformer_fuse = True, group_conv_native = False, save_external_data = True):
def onnx2mnn(self, onnx_path, mnn_path, args = [], transformer_fuse = True, group_conv_native = False, weight_sym = False, save_external_data = True):
convert_args = [
'',
'-f',
@ -66,6 +66,8 @@ class MNNConveter:
convert_args += ['--transformerFuse']
if group_conv_native:
convert_args += ['--groupConvNative']
if weight_sym:
convert_args += ['--weightQuantAsymmetric=0']
if save_external_data:
convert_args += ['--saveExternalData']
convert_args += args
@ -112,7 +114,7 @@ class MNNConveter:
self.convert(convert_args)
return mnn_path
def export(self, onnx_path, quant_bit = None, quant_block = None, transformer_fuse = True, group_conv_native = False):
def export(self, onnx_path, quant_bit = None, quant_block = None, transformer_fuse = True, group_conv_native = False, weight_sym = None):
self.onnx_model_path = onnx_path
self.mnn_name = os.path.basename(onnx_path).replace('.onnx', '.mnn')
self.mnn_model_path = os.path.join(self.config.args.dst_path, self.mnn_name)
@ -133,10 +135,10 @@ class MNNConveter:
]
if quant_bit == 32:
quant_args = []
self.onnx2mnn(self.onnx_model_path, self.mnn_model_path, quant_args, transformer_fuse=transformer_fuse, group_conv_native=group_conv_native)
self.onnx2mnn(self.onnx_model_path, self.mnn_model_path, quant_args, transformer_fuse=transformer_fuse, group_conv_native=group_conv_native, weight_sym=weight_sym)
else:
mnn_json = f'{self.mnn_model_path}.json'
self.onnx2mnn(self.onnx_model_path, self.mnn_model_path, transformer_fuse=transformer_fuse, group_conv_native=group_conv_native)
self.onnx2mnn(self.onnx_model_path, self.mnn_model_path, transformer_fuse=transformer_fuse, group_conv_native=group_conv_native, weight_sym=weight_sym)
self.mnn2json(self.mnn_model_path, mnn_json)
self.rebuild(mnn_json)
self.json2mnn(mnn_json, self.mnn_model_path)

View File

@ -266,7 +266,6 @@ class Rotary(torch.nn.Module):
def get_theta():
return 1.0 / (self.rope_theta ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float32) / self.rotary_dim))
# default rope type's theta
self.theta = get_theta()
# other type
@ -278,7 +277,6 @@ class Rotary(torch.nn.Module):
rope_type = config.rope_scaling['type']
elif 'rope_type' in config.rope_scaling:
rope_type = config.rope_scaling['rope_type']
# gen theta for rope_type
if rope_type == 'dynamic': # NTK
if 'alpha' in config.rope_scaling: # NTKAlpha in Hunyuan
@ -286,7 +284,7 @@ class Rotary(torch.nn.Module):
else: # NTKScaling
pass
self.theta = get_theta()
elif rope_type == 'yarn': # YaRN in gpt-oss
elif rope_type == 'yarn':
self.is_scaled = True
self.theta, self.attention_scaling = _compute_yarn_parameters(
rotary_dim=self.rotary_dim,