mirror of https://github.com/alibaba/MNN.git
MNN:Sync: Sync Internal 3.2.3
This commit is contained in:
parent 8f175e2748
commit 318a3de860

@ -376,4 +376,6 @@ datasets/*
# qnn 3rdParty
source/backend/qnn/3rdParty/include
apps/Android/MnnLlmChat/release_outputs
project/android/.cxx
pymnn/android/.cxx/
pymnn/android/.cxx/abi_configuration_5u53tc49.json

@ -11,8 +11,6 @@

## News 🔥
- [2025/08/08] Now we support [gpt-oss-20b](./apps/Android/MnnLlmChat/README.md#releases).
- [2025/08/05] MNN Chat Android is available on [Google Play](https://play.google.com/store/apps/details?id=com.alibaba.mnnllm.android.release)!
- [2025/06/11] New app MNN TaoAvatar released: talk with a 3D avatar fully offline, with LLM, ASR, TTS, A2BS and NNR models all running locally on your device! [MNN TaoAvatar](./apps/Android/Mnn3dAvatar/README.md)
<p align="center">
  <img width="20%" alt="Icon" src="https://meta.alicdn.com/data/mnn/avatar/avatar_demo.gif" style="margin: 0 10px;">

@ -183,6 +183,21 @@ struct Info {

const Info* getInfo() const;
```

### Getting Device Information
Call the `getDeviceInfo` function to obtain `Device` information, for example:
```cpp
std::string soc_id, dsp_arch;
bool success = MNN::Express::Executor::RuntimeManager::getDeviceInfo("dsp_arch", MNN_FORWARD_NN, dsp_arch);
if (success) {
    MNN_PRINT("Device dsp_arch: %s\n", dsp_arch.c_str());
}

success = MNN::Express::Executor::RuntimeManager::getDeviceInfo("soc_id", MNN_FORWARD_NN, soc_id);
if (success) {
    MNN_PRINT("Device soc_id: %s\n", soc_id.c_str());
}
```

### Running Inference
Call `onForward` to run inference.

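A minimal usage sketch of the inference call follows; it assumes a `module` has already been created with `Module::load` using the `RuntimeManager` configured above, and the input shape and layout are placeholders:
```cpp
#include <MNN/expr/Module.hpp>
#include <MNN/expr/ExprCreator.hpp>

using namespace MNN::Express;

// Run one forward pass; `module` is assumed to have been loaded elsewhere.
// Shape and layout below are illustrative only.
std::vector<VARP> runOnce(Module* module) {
    VARP input = _Input({1, 3, 224, 224}, NCHW);
    float* inputPtr = input->writeMap<float>();
    // ... fill inputPtr with preprocessed data ...
    return module->onForward({input});
}
```
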
@ -58,6 +58,10 @@ adb push ${MNN_ROOT}/source/backend/qnn/3rdParty/lib/hexagon-v${HEXAGON_ARCH}/un

adb shell "cd /data/local/tmp && LD_LIBRARY_PATH=/data/local/tmp ADSP_LIBRARY_PATH=/data/local/tmp ./MyExe.out"
```

### QNN Quantization Notes
- Weight-only quantization (activations stay in floating point): only int8, channel-wise symmetric quantization of Linear weights is supported (see the sketch after this list).
- Quantizing both activations and weights: per-tensor symmetric quantization is supported for activations, with int8/int4 channel-wise symmetric quantization for weights.

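As a reference for the channel-wise symmetric scheme mentioned above, the sketch below shows how per-channel int8 scales can be computed. It is a generic illustration of the technique, not the QNN backend's actual implementation:
```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Channel-wise symmetric int8 quantization: one scale per output channel, zero point fixed at 0.
void quantizeChannelWiseSym(const float* weight, int outChannels, int innerSize,
                            std::vector<int8_t>& quantized, std::vector<float>& scales) {
    quantized.resize((size_t)outChannels * innerSize);
    scales.resize(outChannels);
    for (int oc = 0; oc < outChannels; ++oc) {
        const float* w = weight + (size_t)oc * innerSize;
        float maxAbs = 0.f;
        for (int i = 0; i < innerSize; ++i) {
            maxAbs = std::max(maxAbs, std::fabs(w[i]));
        }
        const float scale = maxAbs > 0.f ? maxAbs / 127.f : 1.f;
        scales[oc] = scale;
        for (int i = 0; i < innerSize; ++i) {
            float q = std::min(127.f, std::max(-127.f, w[i] / scale));
            quantized[(size_t)oc * innerSize + i] = (int8_t)std::lround(q);
        }
    }
}
```
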

## CoreML
Applicable to Mac / iOS / iPad

@ -388,3 +388,12 @@ npu model path:./qnn_smolvlm_model.bin

./ModuleBasic.out qnn_smolvlm_model.mnn dir 0 0 10
```
### Script for Generating QNN Models for Multiple Devices
tools/script/genQNNModelsFromMNN.py provides a script that generates QNN models for 8Gen1 ~ 8Elite devices.
```
# Usage example
cd mnn_path
cd build
python3 ../tools/script/genQNNModelsFromMNN.py --config_path ../source/backend/qnn/convertor/config_example/ --graph_name visual_qnn --qnn_sdk_root_path /mnt/2Tpartition/tianbu/QNN/qairt/2.37.0.250724/ --src_model visual.mnn --executable_path ./MNN2QNNModel
```
The QNN model artifacts for 8Gen1 ~ 8Elite devices will then be generated under the qnn_models folder.

@ -106,6 +106,10 @@ optional arguments:

                        mnn quant bit, 4 or 8, default is 4.
  --quant_block QUANT_BLOCK
                        mnn quant block, 0 means channel-wise, default is 128.
  --visual_quant_bit VISUAL_QUANT_BIT
                        mnn visual model quant bit, 4 or 8; the default is set in utils/vision.py per ViT model.
  --visual_quant_block VISUAL_QUANT_BLOCK
                        mnn visual model quant block, 0 means channel-wise; the default is set in utils/vision.py per ViT model.
  --lm_quant_bit LM_QUANT_BIT
                        mnn lm_head quant bit, 4 or 8, default is `quant_bit`.
  --mnnconvert MNNCONVERT

@ -113,10 +117,12 @@ optional arguments:

  --ppl                 Whether or not to get all logits of input tokens.
  --awq                 Whether or not to use awq quant.
  --sym                 Whether or not to use symmetric quant (without zeropoint), default is False.
  --visual_sym          Whether or not to use symmetric quant (without zeropoint) for the visual model, default is False.
  --seperate_embed      For models whose lm and embed weights are shared, whether or not to separate the embedding to avoid quantization, default is False; if True, the embedding weights are saved separately to embeddingbf16.bin.
  --lora_split          Whether or not to export lora split, default is False.
```

### Weight Loading
llmexport.py also supports LLM validation, which pulls in many dependencies. When such an environment is not available, MNN-LLM also provides tools that read weights directly from safetensors or gguf files, lowering memory requirements and speeding up conversion. Usage is as follows:

@ -166,6 +172,7 @@ python3 gguf2mnn.py --gguf ~/third/llama.cpp/build/ggml-model-Q4_K.gguf --mnn_di

```
-DLLM_SUPPORT_VISION=true -DMNN_BUILD_OPENCV=true -DMNN_IMGCODECS=true
```

- To enable audio support, add the related build macros:
```
-DLLM_SUPPORT_AUDIO=true -DMNN_BUILD_AUDIO=true

@ -195,6 +202,12 @@ cd project/android

mkdir build_64
../build_64.sh -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DMNN_ARM82=true -DMNN_OPENCL=true -DMNN_USE_LOGCAT=true
```
On Qualcomm devices, some vision models support NPU execution; add the `MNN_QNN` and `MNN_WITH_PLUGIN` macros to enable QNN.
```
cd project/android
mkdir build_64
../build_64.sh -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DMNN_ARM82=true -DMNN_OPENCL=true -DMNN_QNN=true -DMNN_WITH_PLUGIN=true -DMNN_USE_LOGCAT=true
```

#### iOS: see transformers/llm/engine/ios/README.md
```

@ -281,6 +281,17 @@ bool Executor::RuntimeManager::getInfo(Interpreter::SessionInfoCode code, void*
|
|||
return false;
|
||||
}
|
||||
|
||||
bool Executor::RuntimeManager::getDeviceInfo(const std::string& deviceKey, const MNNForwardType type, std::string& deviceValue) {
|
||||
auto creator = MNNGetExtraRuntimeCreator(type);
|
||||
if (creator != nullptr) {
|
||||
auto res = creator->onGetDeviceInfo(deviceKey, deviceValue);
|
||||
if(res) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Executor::RuntimeManager::RuntimeManager() {
|
||||
mInside = new RuntimeAttr;
|
||||
mInside->mContent.reset(new RuntimeAttr::Immutable);
|
||||
|
|
|
@ -130,6 +130,7 @@ public:
|
|||
void setHint(Interpreter::HintMode mode, int* value, size_t size);
|
||||
void setHintPtr(Interpreter::HintMode mode, void* value);
|
||||
bool getInfo(Interpreter::SessionInfoCode code, void* ptr);
|
||||
static bool getDeviceInfo(const std::string& deviceKey, const MNNForwardType type, std::string& deviceValue);
|
||||
BackendConfig* getBnConfig();
|
||||
const RuntimeAttr* getInside() const {
|
||||
return mInside;
|
||||
|
|
|
@ -25,3 +25,9 @@ set_target_properties( MNNOpenCV
|
|||
PROPERTIES IMPORTED_LOCATION
|
||||
${CMAKE_CURRENT_LIST_DIR}/libs/${ANDROID_ABI}/libMNNOpenCV.so
|
||||
)
|
||||
|
||||
add_library(MNN_LLM SHARED IMPORTED GLOBAL )
|
||||
set_target_properties(MNN_LLM
|
||||
PROPERTIES IMPORTED_LOCATION
|
||||
${CMAKE_CURRENT_LIST_DIR}/libs/${ANDROID_ABI}/libllm.so
|
||||
)
|
|
@ -3,4 +3,4 @@ distributionBase=GRADLE_USER_HOME
|
|||
distributionPath=wrapper/dists
|
||||
zipStoreBase=GRADLE_USER_HOME
|
||||
zipStorePath=wrapper/dists
|
||||
distributionUrl=http\://mtl-gradle-mirror.oss-cn-hangzhou.aliyuncs.com/gradle-4.6-all.zip
|
||||
distributionUrl=http://mtl-gradle-mirror.oss-cn-hangzhou.aliyuncs.com/gradle-6.7.1-all.zip
|
|
@ -20,23 +20,48 @@ afterEvaluate {
|
|||
android.libraryVariants.all {variant ->
|
||||
def name = variant.buildType.name
|
||||
def zipTask = tasks.create(name: "zipNative${name.capitalize()}", type: Zip) {
|
||||
from project.zipTree("${variant.packageLibrary.destinationDir.path}/${project.name}-${name}.aar")
|
||||
from project.zipTree("${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}.aar")
|
||||
from "$buildDir/native-packed"
|
||||
archiveName "${project.name}-${name}-tmp.aar"
|
||||
archiveName "${artifactName}-${name}-tmp.aar"
|
||||
destinationDir(variant.packageLibrary.destinationDir)
|
||||
inputs.dir("$buildDir/native-packed")
|
||||
inputs.file("${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}.aar")
|
||||
outputs.file("${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}.aar")
|
||||
}
|
||||
|
||||
zipTask.dependsOn("externalNativeBuild${name.capitalize()}")
|
||||
|
||||
// Ensure QNN dependencies are prepared when BUILD_QNN is enabled
|
||||
if (project.ext.has('BUILD_QNN') && project.ext.BUILD_QNN) {
|
||||
zipTask.dependsOn('prepareQnnDeps')
|
||||
}
|
||||
|
||||
zipTask.doLast {
|
||||
copy {
|
||||
from "${variant.packageLibrary.destinationDir.path}/${project.name}-${name}-tmp.aar"
|
||||
from "${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}-tmp.aar"
|
||||
into "${variant.packageLibrary.destinationDir.path}"
|
||||
rename{
|
||||
String fileName->
|
||||
fileName.replace('-tmp','')
|
||||
}
|
||||
}
|
||||
delete "${variant.packageLibrary.destinationDir.path}/${project.name}-${name}-tmp.aar"
|
||||
delete "${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}-tmp.aar"
|
||||
|
||||
println "Generated final AAR: ${variant.packageLibrary.destinationDir.path}/${artifactName}-${name}.aar"
|
||||
}
|
||||
tasks.findByName("assemble${name.capitalize()}").doLast {
|
||||
zipTask.execute()
|
||||
|
||||
def bundleTask = tasks.findByName("bundle${name.capitalize()}Aar")
|
||||
if (bundleTask != null) {
|
||||
zipTask.dependsOn(bundleTask)
|
||||
}
|
||||
|
||||
def assembleTask = tasks.findByName("assemble${name.capitalize()}")
|
||||
if (assembleTask != null) {
|
||||
assembleTask.dependsOn(zipTask)
|
||||
zipTask.mustRunAfter(bundleTask)
|
||||
}
|
||||
def packageTask = tasks.findByName("package${name.capitalize()}Aar")
|
||||
if (packageTask != null) {
|
||||
packageTask.finalizedBy(zipTask)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,174 @@
|
|||
// QNN Dependencies Preparation
|
||||
// This gradle file handles downloading and preparing QNN dependencies
|
||||
|
||||
// QNN configuration
|
||||
ext {
|
||||
// QNN download settings
|
||||
QNN_LIBS_URL = 'http://meta.alicdn.com/data/mnn/libs/qnn_inc_libs.zip'
|
||||
}
|
||||
|
||||
def qnnZipName = 'qnn_inc_libs.zip'
|
||||
def qnnZipFile = new File(buildDir, qnnZipName)
|
||||
def qnnTmpDir = new File(buildDir, 'qnn_tmp')
|
||||
|
||||
task prepareQnnDeps {
|
||||
group = 'build setup'
|
||||
description = 'Download and extract QNN include/lib into source/backend/qnn/3rdParty when BUILD_QNN is enabled.'
|
||||
onlyIf {
|
||||
project.ext.has('BUILD_QNN') && project.ext.BUILD_QNN
|
||||
}
|
||||
|
||||
// Define inputs and outputs for incremental build
|
||||
inputs.property("qnn_libs_url", { QNN_LIBS_URL })
|
||||
inputs.property("build_qnn", { project.ext.has('BUILD_QNN') && project.ext.BUILD_QNN })
|
||||
outputs.dir(new File(project.rootDir, '../../source/backend/qnn/3rdParty'))
|
||||
outputs.dir(new File(buildDir, 'native-packed/native/libs'))
|
||||
|
||||
doLast {
|
||||
println "BUILD_QNN is ON. Preparing QNN dependencies..."
|
||||
|
||||
def url = QNN_LIBS_URL
|
||||
println "Downloading QNN dependencies from ${url}"
|
||||
|
||||
if (!qnnZipFile.exists()) {
|
||||
println "Downloading ${qnnZipName} ..."
|
||||
qnnZipFile.parentFile.mkdirs()
|
||||
|
||||
try {
|
||||
new java.net.URL(url).withInputStream { inputStream ->
|
||||
qnnZipFile.withOutputStream { outputStream ->
|
||||
outputStream << inputStream
|
||||
}
|
||||
}
|
||||
println "Downloaded to: ${qnnZipFile.absolutePath}"
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Failed to download QNN dependencies from ${url}. Error: ${e.message}")
|
||||
}
|
||||
} else {
|
||||
println "Using cached zip: ${qnnZipFile.absolutePath}"
|
||||
}
|
||||
|
||||
// Clean temp dir and unpack
|
||||
project.delete(qnnTmpDir)
|
||||
qnnTmpDir.mkdirs()
|
||||
copy {
|
||||
from zipTree(qnnZipFile)
|
||||
into qnnTmpDir
|
||||
}
|
||||
|
||||
// Find the extracted QNN directory
|
||||
def extractedQnnDir = findQnnDirectory(qnnTmpDir)
|
||||
if (extractedQnnDir == null) {
|
||||
throw new RuntimeException("Failed to find QNN directory structure in ${qnnZipFile.name}")
|
||||
}
|
||||
|
||||
copyQnnFiles(extractedQnnDir)
|
||||
}
|
||||
}
|
||||
|
||||
def findQnnDirectory(File searchDir) {
|
||||
// Look for directory containing both include and jniLibs (or lib) directories
|
||||
def candidates = []
|
||||
|
||||
searchDir.eachDirRecurse { dir ->
|
||||
def hasInclude = new File(dir, 'include').exists()
|
||||
def hasLibs = new File(dir, 'jniLibs').exists() || new File(dir, 'lib').exists()
|
||||
|
||||
if (hasInclude && hasLibs) {
|
||||
candidates.add(dir)
|
||||
}
|
||||
}
|
||||
|
||||
if (candidates.isEmpty()) {
|
||||
// Fallback: look for directory that contains include
|
||||
searchDir.eachDirRecurse { dir ->
|
||||
if (new File(dir, 'include').exists()) {
|
||||
candidates.add(dir)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return candidates.isEmpty() ? null : candidates[0]
|
||||
}
|
||||
|
||||
def copyQnnFiles(File sourceDir) {
|
||||
// Resolve destination directories
|
||||
def qnnRoot = new File(project.rootDir, '../../source/backend/qnn/3rdParty')
|
||||
def includeDir = new File(qnnRoot, 'include')
|
||||
def libDir = new File(qnnRoot, 'lib')
|
||||
|
||||
includeDir.mkdirs()
|
||||
libDir.mkdirs()
|
||||
|
||||
// Copy include files
|
||||
def sourceInclude = new File(sourceDir, 'include')
|
||||
if (sourceInclude.exists()) {
|
||||
copy {
|
||||
from sourceInclude
|
||||
into includeDir
|
||||
}
|
||||
println "QNN includes copied to: ${includeDir.absolutePath}"
|
||||
} else {
|
||||
throw new RuntimeException("Include directory not found in ${sourceDir.absolutePath}")
|
||||
}
|
||||
|
||||
// Copy library files - try both jniLibs and lib directories
|
||||
def sourceLibs = new File(sourceDir, 'jniLibs')
|
||||
if (!sourceLibs.exists()) {
|
||||
sourceLibs = new File(sourceDir, 'lib')
|
||||
}
|
||||
|
||||
if (sourceLibs.exists()) {
|
||||
copy {
|
||||
from sourceLibs
|
||||
into libDir
|
||||
}
|
||||
println "QNN libs copied to: ${libDir.absolutePath}"
|
||||
|
||||
// Also copy QNN .so files to native-packed for AAR packaging
|
||||
copyQnnLibsToNativePacked(sourceLibs)
|
||||
} else {
|
||||
println "Warning: No lib/jniLibs directory found in ${sourceDir.absolutePath}"
|
||||
}
|
||||
|
||||
println "QNN dependencies preparation completed successfully."
|
||||
}
|
||||
|
||||
def copyQnnLibsToNativePacked(File sourceLibsDir) {
|
||||
// Create native-packed directory structure for QNN .so files
|
||||
def nativePackedLibsDir = new File(buildDir, 'native-packed/native/libs')
|
||||
nativePackedLibsDir.mkdirs()
|
||||
|
||||
println "Copying QNN .so files to native-packed for AAR packaging..."
|
||||
|
||||
// Copy all .so files from QNN libs to native-packed
|
||||
sourceLibsDir.eachFileRecurse { file ->
|
||||
if (file.name.endsWith('.so')) {
|
||||
// Determine the ABI directory (arm64-v8a, armeabi-v7a, etc.)
|
||||
def relativePath = sourceLibsDir.toPath().relativize(file.toPath())
|
||||
def targetFile = new File(nativePackedLibsDir, relativePath.toString())
|
||||
|
||||
// Create parent directories if they don't exist
|
||||
targetFile.parentFile.mkdirs()
|
||||
|
||||
// Copy the .so file
|
||||
copy {
|
||||
from file
|
||||
into targetFile.parentFile
|
||||
}
|
||||
|
||||
println "Copied QNN lib: ${file.name} -> ${targetFile.absolutePath}"
|
||||
}
|
||||
}
|
||||
|
||||
println "QNN .so files copied to native-packed directory"
|
||||
}
|
||||
|
||||
// Ensure preparation runs before compilation when enabled
|
||||
if (project.ext.has('BUILD_QNN') && project.ext.BUILD_QNN) {
|
||||
afterEvaluate {
|
||||
if (tasks.findByName('preBuild')) {
|
||||
preBuild.dependsOn prepareQnnDeps
|
||||
}
|
||||
}
|
||||
}
|
|
@ -480,7 +480,6 @@
|
|||
92FF02E023AA0B5A00AC97F6 /* MNNMatrixSub.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016023AA0B4E00AC97F6 /* MNNMatrixSub.S */; };
|
||||
92FF02E123AA0B5A00AC97F6 /* MNNPowC8.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016123AA0B4E00AC97F6 /* MNNPowC8.S */; };
|
||||
92FF02E223AA0B5A00AC97F6 /* MNNMatrixAdd.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016223AA0B4E00AC97F6 /* MNNMatrixAdd.S */; };
|
||||
92FF02E323AA0B5A00AC97F6 /* MNNExpC8.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016323AA0B4E00AC97F6 /* MNNExpC8.S */; };
|
||||
92FF02E523AA0B5A00AC97F6 /* MNNConvDwF23SourceTransUnit.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016523AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */; };
|
||||
92FF02E723AA0B5A00AC97F6 /* MNNDeconvRunForUnitDepthWise.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016723AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */; };
|
||||
92FF02E823AA0B5A00AC97F6 /* MNNSamplerC1BilinearOpt.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF016823AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */; };
|
||||
|
@ -520,7 +519,6 @@
|
|||
92FF032023AA0B5A00AC97F6 /* MNNMatrixSub.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A123AA0B4E00AC97F6 /* MNNMatrixSub.S */; };
|
||||
92FF032123AA0B5A00AC97F6 /* MNNPowC8.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A223AA0B4E00AC97F6 /* MNNPowC8.S */; };
|
||||
92FF032223AA0B5A00AC97F6 /* MNNMatrixAdd.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A323AA0B4E00AC97F6 /* MNNMatrixAdd.S */; };
|
||||
92FF032323AA0B5A00AC97F6 /* MNNExpC8.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A423AA0B4E00AC97F6 /* MNNExpC8.S */; };
|
||||
92FF032523AA0B5A00AC97F6 /* MNNConvDwF23SourceTransUnit.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A623AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */; };
|
||||
92FF032723AA0B5A00AC97F6 /* MNNDeconvRunForUnitDepthWise.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A823AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */; };
|
||||
92FF032823AA0B5A00AC97F6 /* MNNSamplerC1BilinearOpt.S in Sources */ = {isa = PBXBuildFile; fileRef = 92FF01A923AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */; };
|
||||
|
@ -1315,7 +1313,6 @@
|
|||
92FF016023AA0B4E00AC97F6 /* MNNMatrixSub.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNMatrixSub.S; sourceTree = "<group>"; };
|
||||
92FF016123AA0B4E00AC97F6 /* MNNPowC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPowC8.S; sourceTree = "<group>"; };
|
||||
92FF016223AA0B4E00AC97F6 /* MNNMatrixAdd.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNMatrixAdd.S; sourceTree = "<group>"; };
|
||||
92FF016323AA0B4E00AC97F6 /* MNNExpC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNExpC8.S; sourceTree = "<group>"; };
|
||||
92FF016523AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNConvDwF23SourceTransUnit.S; sourceTree = "<group>"; };
|
||||
92FF016723AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNDeconvRunForUnitDepthWise.S; sourceTree = "<group>"; };
|
||||
92FF016823AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNSamplerC1BilinearOpt.S; sourceTree = "<group>"; };
|
||||
|
@ -1355,7 +1352,6 @@
|
|||
92FF01A123AA0B4E00AC97F6 /* MNNMatrixSub.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNMatrixSub.S; sourceTree = "<group>"; };
|
||||
92FF01A223AA0B4E00AC97F6 /* MNNPowC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPowC8.S; sourceTree = "<group>"; };
|
||||
92FF01A323AA0B4E00AC97F6 /* MNNMatrixAdd.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNMatrixAdd.S; sourceTree = "<group>"; };
|
||||
92FF01A423AA0B4E00AC97F6 /* MNNExpC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNExpC8.S; sourceTree = "<group>"; };
|
||||
92FF01A623AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNConvDwF23SourceTransUnit.S; sourceTree = "<group>"; };
|
||||
92FF01A823AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNDeconvRunForUnitDepthWise.S; sourceTree = "<group>"; };
|
||||
92FF01A923AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNSamplerC1BilinearOpt.S; sourceTree = "<group>"; };
|
||||
|
@ -2591,7 +2587,6 @@
|
|||
92FF016023AA0B4E00AC97F6 /* MNNMatrixSub.S */,
|
||||
92FF016123AA0B4E00AC97F6 /* MNNPowC8.S */,
|
||||
92FF016223AA0B4E00AC97F6 /* MNNMatrixAdd.S */,
|
||||
92FF016323AA0B4E00AC97F6 /* MNNExpC8.S */,
|
||||
92FF016523AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */,
|
||||
92FF016723AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */,
|
||||
92FF016823AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */,
|
||||
|
@ -2677,7 +2672,6 @@
|
|||
92FF01A123AA0B4E00AC97F6 /* MNNMatrixSub.S */,
|
||||
92FF01A223AA0B4E00AC97F6 /* MNNPowC8.S */,
|
||||
92FF01A323AA0B4E00AC97F6 /* MNNMatrixAdd.S */,
|
||||
92FF01A423AA0B4E00AC97F6 /* MNNExpC8.S */,
|
||||
92FF01A623AA0B4E00AC97F6 /* MNNConvDwF23SourceTransUnit.S */,
|
||||
92FF01A823AA0B4E00AC97F6 /* MNNDeconvRunForUnitDepthWise.S */,
|
||||
92FF01A923AA0B4E00AC97F6 /* MNNSamplerC1BilinearOpt.S */,
|
||||
|
@ -3279,7 +3273,6 @@
|
|||
92FF02C223AA0B5A00AC97F6 /* MNNLoadU8AndSum.S in Sources */,
|
||||
4819FB2E24C1396A0050BD09 /* GeometryLSTM.cpp in Sources */,
|
||||
9558334B29B09A7B00488807 /* MNNGeluFP16.S in Sources */,
|
||||
92FF02E323AA0B5A00AC97F6 /* MNNExpC8.S in Sources */,
|
||||
92FF030223AA0B5A00AC97F6 /* MNNQuanToDestUint8.S in Sources */,
|
||||
489D7AA12550FDC900AD896A /* MetalUnary.mm in Sources */,
|
||||
92FF037323AA0B5A00AC97F6 /* CPUEltwiseInt8.cpp in Sources */,
|
||||
|
@ -3438,7 +3431,6 @@
|
|||
92FF030723AA0B5A00AC97F6 /* MNNBlitC1ToFloatRGBA.S in Sources */,
|
||||
92FF03A423AA0B5A00AC97F6 /* OptimizedComputer.cpp in Sources */,
|
||||
92FF032E23AA0B5A00AC97F6 /* MNNReluWithSlopeChannel.S in Sources */,
|
||||
92FF032323AA0B5A00AC97F6 /* MNNExpC8.S in Sources */,
|
||||
95278CEA2B9F09C0009E9B29 /* ShapeDynamicQuant.cpp in Sources */,
|
||||
92FF044C23AA0B7100AC97F6 /* ShapePool3D.cpp in Sources */,
|
||||
92FF029823AA0B5A00AC97F6 /* CPUTFQuantizedConv2D.cpp in Sources */,
|
||||
|
|
|
@ -17,6 +17,7 @@ option(PYMNN_INTERNAL_SERVING "Internal use only." OFF)
|
|||
option(PYMNN_OPENCV_API "MNN OpenCV API be exposed" ON)
|
||||
option(PYMNN_IMGCODECS "MNN IMGCODECS API be exposed" OFF)
|
||||
option(PYMNN_AUDIO_API "MNN Audio API be exposed" OFF)
|
||||
option(PYMNN_LLM_API "MNN LLM API be exposed" OFF)
|
||||
option(PYMNN_OHOS_INTERNAL "compile for harmony internal." OFF)
|
||||
option(PYMNN_LLM_COLLECTION "offline llm data collection." OFF)
|
||||
|
||||
|
@ -97,6 +98,10 @@ if(PYMNN_AUDIO_API)
|
|||
target_compile_definitions(mnnpybridge PRIVATE PYMNN_AUDIO_API)
|
||||
endif()
|
||||
|
||||
if(PYMNN_LLM_API)
|
||||
target_compile_definitions(mnnpybridge PRIVATE PYMNN_LLM_API)
|
||||
endif()
|
||||
|
||||
if(PYMNN_INTERNAL_SERVING)
|
||||
message(STATUS "mnnpybridge define PYMNN_INTERNAL_SERVING")
|
||||
target_compile_definitions(mnnpybridge PRIVATE PYMNN_INTERNAL_SERVING)
|
||||
|
@ -214,6 +219,9 @@ else()
|
|||
if(PYMNN_NUMPY_USABLE)
|
||||
target_link_libraries(mnnpybridge PRIVATE numpy_python)
|
||||
endif()
|
||||
if(PYMNN_LLM_API)
|
||||
target_link_libraries(mnnpybridge PRIVATE MNN_LLM)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
|
104 pymnn/src/llm.h
|
@ -1,6 +1,10 @@
|
|||
#include <sstream>
|
||||
#include <iostream>
|
||||
#ifdef BUILD_FOR_IOS
|
||||
#include "MNN/llm/llm.hpp"
|
||||
#else
|
||||
#include "llm/llm.hpp"
|
||||
|
||||
#endif
|
||||
#ifdef PYMNN_LLM_COLLECTION
|
||||
#include "cpp/getLinearInput.hpp"
|
||||
#endif
|
||||
|
@ -85,18 +89,87 @@ static PyObject* PyMNNLLM_getCurrentHistory(LLM *self, PyObject *args) {
|
|||
}
|
||||
static PyObject* PyMNNLLM_response(LLM *self, PyObject *args) {
|
||||
if (self->is_embedding) {
|
||||
MNN_PRINT("[MNNLLM] response: is_embedding\n");
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
const char* query = NULL;
|
||||
|
||||
PyObject* content = nullptr;
|
||||
int stream = 0;
|
||||
if (!PyArg_ParseTuple(args, "s|p", &query, &stream)) {
|
||||
int max_new_tokens = 2048;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "O|ii", &content, &stream, &max_new_tokens)) {
|
||||
MNN_PRINT("[MNNLLM] response: PyArg_ParseTuple failed\n");
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
std::ostringstream null_os;
|
||||
self->llm->response(query, stream ? &std::cout : &null_os);
|
||||
return string2Object(null_os.str());
|
||||
std::ostream* output_stream = stream ? &std::cout : &null_os;
|
||||
|
||||
if (isString(content)) {
|
||||
std::string text = object2String(content);
|
||||
MNN_PRINT("[MNNLLM] response: text=%s, stream=%d, max_new_tokens=%d\n", text.c_str(), stream, max_new_tokens);
|
||||
self->llm->response(text, output_stream, nullptr, max_new_tokens);
|
||||
} else if (isPyDict(content)) {
|
||||
MNN::Transformer::MultimodalPrompt multimodal_input;
|
||||
PyObject* text_obj = PyDict_GetItemString(content, "text");
|
||||
if (text_obj && isString(text_obj)) {
|
||||
multimodal_input.prompt_template = object2String(text_obj);
|
||||
}
|
||||
PyObject* images_obj = PyDict_GetItemString(content, "images");
|
||||
if (images_obj && PyList_Check(images_obj)) {
|
||||
Py_ssize_t img_count = PyList_Size(images_obj);
|
||||
for (Py_ssize_t i = 0; i < img_count; i++) {
|
||||
PyObject* img_dict = PyList_GetItem(images_obj, i);
|
||||
if (isPyDict(img_dict)) {
|
||||
PyObject* data_obj = PyDict_GetItemString(img_dict, "data");
|
||||
PyObject* width_obj = PyDict_GetItemString(img_dict, "width");
|
||||
PyObject* height_obj = PyDict_GetItemString(img_dict, "height");
|
||||
|
||||
if (data_obj && width_obj && height_obj) {
|
||||
MNN::Transformer::PromptImagePart image_part;
|
||||
image_part.image_data = toVar(data_obj);
|
||||
image_part.width = PyLong_AsLong(width_obj);
|
||||
image_part.height = PyLong_AsLong(height_obj);
|
||||
|
||||
std::string key = "image_" + std::to_string(i);
|
||||
multimodal_input.images[key] = image_part;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
PyObject* audios_obj = PyDict_GetItemString(content, "audios");
|
||||
if (audios_obj && PyList_Check(audios_obj)) {
|
||||
Py_ssize_t audio_count = PyList_Size(audios_obj);
|
||||
for (Py_ssize_t i = 0; i < audio_count; i++) {
|
||||
PyObject* audio_dict = PyList_GetItem(audios_obj, i);
|
||||
if (isPyDict(audio_dict)) {
|
||||
MNN::Transformer::PromptAudioPart audio_part;
|
||||
|
||||
PyObject* file_path_obj = PyDict_GetItemString(audio_dict, "file_path");
|
||||
if (file_path_obj && isString(file_path_obj)) {
|
||||
audio_part.file_path = object2String(file_path_obj);
|
||||
}
|
||||
|
||||
PyObject* waveform_obj = PyDict_GetItemString(audio_dict, "waveform");
|
||||
if (waveform_obj) {
|
||||
audio_part.waveform = toVar(waveform_obj);
|
||||
}
|
||||
|
||||
std::string key = "audio_" + std::to_string(i);
|
||||
multimodal_input.audios[key] = audio_part;
|
||||
}
|
||||
}
|
||||
}
|
||||
MNN_PRINT("[MNNLLM] response: multimodal, stream=%d, max_new_tokens=%d\n", stream, max_new_tokens);
|
||||
self->llm->response(multimodal_input, output_stream, nullptr, max_new_tokens);
|
||||
} else {
|
||||
PyMNN_ERROR("content must be str or dict");
|
||||
}
|
||||
std::string response_str = null_os.str();
|
||||
MNN_PRINT("[MNNLLM] response: %s\n", response_str.c_str());
|
||||
return string2Object(response_str);
|
||||
}
|
||||
|
||||
static PyObject* PyMNNLLM_tokenizer_encode(LLM *self, PyObject *args) {
|
||||
if (self->is_embedding) {
|
||||
Py_RETURN_NONE;
|
||||
|
@ -149,6 +222,14 @@ static PyObject* PyMNNLLM_reset(LLM *self, PyObject *args) {
|
|||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject* PyMNNLLM_get_statistics(LLM *self, PyObject *args) {
|
||||
if (self->is_embedding) {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
auto statistics = self->llm->get_statistics();
|
||||
return string2Object(statistics);
|
||||
}
|
||||
|
||||
#ifdef PYMNN_LLM_COLLECTION
|
||||
static PyObject* PyMNNLLM_enable_collection_mode(LLM *self, PyObject *args) {
|
||||
if (self->is_embedding) {
|
||||
|
@ -205,12 +286,11 @@ static PyObject* PyMNNLLM_enable_collection_mode(LLM *self, PyObject *args) {
|
|||
return toPyObj(true);
|
||||
}
|
||||
#endif
|
||||
|
||||
static PyMethodDef PyMNNLLM_methods[] = {
|
||||
{"load", (PyCFunction)PyMNNLLM_load, METH_VARARGS, "load model."},
|
||||
{"forward", (PyCFunction)PyMNNLLM_forward, METH_VARARGS, "forward `logits` by `input_ids`."},
|
||||
{"generate", (PyCFunction)PyMNNLLM_generate, METH_VARARGS, "generate `output_ids` by `input_ids`."},
|
||||
{"response", (PyCFunction)PyMNNLLM_response, METH_VARARGS, "response `query` without hsitory."},
|
||||
{"response", (PyCFunction)PyMNNLLM_response, METH_VARARGS, "response `query` - supports both text and multimodal input."},
|
||||
{"get_current_history", (PyCFunction)PyMNNLLM_getCurrentHistory, METH_VARARGS, "Get Current History."},
|
||||
{"erase_history", (PyCFunction)PyMNNLLM_eraseHistory, METH_VARARGS, "Erase History."},
|
||||
{"tokenizer_encode", (PyCFunction)PyMNNLLM_tokenizer_encode, METH_VARARGS, "tokenizer encode."},
|
||||
|
@ -219,6 +299,7 @@ static PyMethodDef PyMNNLLM_methods[] = {
|
|||
{"create_lora", (PyCFunction)PyMNNLLM_create_lora, METH_VARARGS, "create_lora."},
|
||||
{"set_config", (PyCFunction)PyMNNLLM_set_config, METH_VARARGS, "set_config."},
|
||||
{"reset", (PyCFunction)PyMNNLLM_reset, METH_VARARGS, "reset."},
|
||||
{"get_statistics", (PyCFunction)PyMNNLLM_get_statistics, METH_VARARGS, "get performance statistics."},
|
||||
#ifdef PYMNN_LLM_COLLECTION
|
||||
{"enable_collection_mode", (PyCFunction)PyMNNLLM_enable_collection_mode, METH_VARARGS, "Enable data collection mode."},
|
||||
#endif
|
||||
|
@ -274,7 +355,7 @@ static PyObject* PyMNNLLM_create_lora(LLM *self, PyObject *args) {
|
|||
Py_RETURN_NONE;
|
||||
}
|
||||
auto lora = self->llm->create_lora(path);
|
||||
LLM *llm = (LLM *)PyObject_Call((PyObject*)&PyMNNLLM, PyTuple_New(0), NULL);
|
||||
LLM *llm = (LLM *)PyObject_Call((PyObject*)PyType_FindTLSType(&PyMNNLLM), PyTuple_New(0), NULL);
|
||||
if (!llm) {
|
||||
return NULL;
|
||||
}
|
||||
|
@ -288,10 +369,11 @@ static PyObject* PyMNNLLM_create(PyObject *self, PyObject *args) {
|
|||
}
|
||||
const char* path = NULL;
|
||||
int embedding_model = 0;
|
||||
if (!PyArg_ParseTuple(args, "s|p", &path, &embedding_model)) {
|
||||
if (!PyArg_ParseTuple(args, "s|i", &path, &embedding_model)) {
|
||||
PyMNN_ERROR_LOG("Invalid arguments. Usage: create(path, embedding_model=False)");
|
||||
return NULL;
|
||||
}
|
||||
LLM *llm = (LLM *)PyObject_Call((PyObject*)&PyMNNLLM, PyTuple_New(0), NULL);
|
||||
LLM *llm = (LLM *)PyObject_Call((PyObject*)PyType_FindTLSType(&PyMNNLLM), PyTuple_New(0), NULL);
|
||||
if (!llm) {
|
||||
return NULL;
|
||||
}
|
||||
|
|
|
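For reference, the multimodal path that `PyMNNLLM_response` now wraps can also be driven directly from C++. The following is a minimal sketch based on the fields used in this diff (`prompt_template`, `images`, `PromptImagePart`); the prompt placeholder syntax and the image-loading step are illustrative assumptions:
```cpp
#include <iostream>
#include "llm/llm.hpp"

using namespace MNN::Transformer;

// Build a MultimodalPrompt with one image and stream the reply to stdout.
// `llm` is an already-loaded model; `imageData` holds the decoded pixels as a VARP.
void respondWithImage(Llm* llm, MNN::Express::VARP imageData, int width, int height) {
    MultimodalPrompt prompt;
    prompt.prompt_template = "<img>image_0</img>Describe this picture.";  // placeholder syntax, may differ
    PromptImagePart image;
    image.image_data = imageData;
    image.width = width;
    image.height = height;
    prompt.images["image_0"] = image;
    llm->response(prompt, &std::cout, nullptr, /*max_new_tokens*/ 512);
}
```
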
@ -398,6 +398,9 @@ static inline bool isPySequence(PyObject* obj) {
|
|||
// use isPySequence replace PySequence_Check
|
||||
return PyTuple_Check(obj) || PyList_Check(obj) || PyBytes_Check(obj);
|
||||
}
|
||||
static inline bool isPyDict(PyObject* obj) {
|
||||
return PyDict_Check(obj);
|
||||
}
|
||||
static inline int PySequenceSize(PyObject* obj) {
|
||||
if (PyTuple_Check(obj)) return PyTuple_Size(obj);
|
||||
if (PyList_Check(obj)) return PyList_Size(obj);
|
||||
|
|
File diff suppressed because it is too large
|
@ -24,10 +24,12 @@
|
|||
#define FLOAT16_T float
|
||||
#endif
|
||||
|
||||
#define MNN_FLASH_ATTENTION_BLOCK_SIZE 64
|
||||
|
||||
namespace MNN {
|
||||
|
||||
template <typename T>
|
||||
void CPUAttention::pack_query(Tensor* query, char* pack_q, char* sum_q, int seq_len, int h, float q_scale) {
|
||||
void CPUAttention::pack_query(Tensor* query, int8_t* pack_q, int8_t* sum_q, int seq_len, int h, float q_scale) {
|
||||
if (mUseGemmInt8) { // Shape of Query: numhead, [seqlen/eP8, headdim/lP8, eP8, lP8]
|
||||
mMinQ[h] = query->host<T>()[h * mHeadDim];
|
||||
mMaxQ[h] = query->host<T>()[h * mHeadDim];
|
||||
|
@ -75,21 +77,21 @@ void CPUAttention::pack_query(Tensor* query, char* pack_q, char* sum_q, int seq_
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void CPUAttention::unpack_QK(float * unpack_qk_dst, char * pack_qk_src, int seq_len, int kv_seq_len) {
|
||||
void CPUAttention::unpack_QK(float * unpack_qk_dst, int8_t * pack_qk_src, int seq_len, int kv_seq_len) {
|
||||
float * dst = unpack_qk_dst;
|
||||
T * src = (T *)(pack_qk_src);
|
||||
// [kv_seq_len/unit, seq_len, unit] -> [seq_len, kv_seq_len]
|
||||
// [kv_seq_len/mPack, seq_len, mPack] -> [seq_len, kv_seq_len]
|
||||
for (int i = 0; i < seq_len; i++) {
|
||||
for (int j = 0; j < kv_seq_len; j++) {
|
||||
int out_index = j / unit;
|
||||
int in_index = j % unit;
|
||||
dst[i * kv_seq_len + j] = src[out_index * seq_len * unit + i * unit + in_index];
|
||||
int out_index = j / mPack;
|
||||
int in_index = j % mPack;
|
||||
dst[i * kv_seq_len + j] = src[out_index * seq_len * mPack + i * mPack + in_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void pack_QK(char * pack_qk_dst, float * qk_src, int seq_len, int kv_seq_len, int eP, int lP, int bytes) {
|
||||
static void pack_QK(int8_t * pack_qk_dst, float * qk_src, int seq_len, int kv_seq_len, int eP, int lP, int bytes) {
|
||||
T * dst = reinterpret_cast<T*>(pack_qk_dst);
|
||||
float * src = reinterpret_cast<float*>(qk_src);
|
||||
// [seq_len, kv_seq_len] -> [seq_len/eP, kv_seq_len/lP, eP, lP]
|
||||
|
@ -108,38 +110,50 @@ static void pack_QK(char * pack_qk_dst, float * qk_src, int seq_len, int kv_seq_
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
static void mask_QK(float * unpack_qk, int seq_len, int kv_seq_len, float mScale, float min_val, const Tensor* mask) {
|
||||
if (mask == nullptr) {
|
||||
for (int i = 0; i < kv_seq_len; i++) {
|
||||
static void mask_QK(float * unpack_qk, int seq_len, int kv_seq_len, float mScale, float min_val, const Tensor* maskTensor, int offset, int startIndx, int processedKvLen) {
|
||||
|
||||
int endIndx = startIndx + processedKvLen;
|
||||
if (maskTensor == nullptr) {
|
||||
for (int i = 0; i < processedKvLen; i++) {
|
||||
unpack_qk[i] = unpack_qk[i] * mScale;
|
||||
}
|
||||
} else if (mask->getType() == halide_type_of<float>()) {
|
||||
return;
|
||||
}
|
||||
const int8_t* mask = maskTensor->host<int8_t>();
|
||||
halide_type_t htype = maskTensor->getType();
|
||||
int maskSize = maskTensor->elementSize();
|
||||
|
||||
if (htype == halide_type_of<float>()) {
|
||||
// float mask
|
||||
T* fpmask_ptr = mask->host<T>();
|
||||
if (mask->elementSize() == seq_len * kv_seq_len) {
|
||||
// normal mask for all token
|
||||
for (int i = 0; i < seq_len * kv_seq_len; i++) {
|
||||
unpack_qk[i] = unpack_qk[i] * mScale + fpmask_ptr[i];
|
||||
}
|
||||
} else {
|
||||
// square mask just for new generation token
|
||||
int offset = kv_seq_len - seq_len;
|
||||
T* fpmask_ptr = (T*)mask;
|
||||
if (maskSize == seq_len * kv_seq_len) { // sliding attention, mask shape: [seq_len, kv_seq_len]
|
||||
for (int i = 0; i < seq_len; ++i) {
|
||||
auto unpack_qki = unpack_qk + i * kv_seq_len;
|
||||
auto fpmask_ptri = fpmask_ptr + i * seq_len;
|
||||
for (int j = 0; j < offset; ++j) {
|
||||
unpack_qki[j] = unpack_qki[j] * mScale;
|
||||
auto unpack_qki = unpack_qk + i * processedKvLen;
|
||||
auto fpmask_ptri = fpmask_ptr + i * kv_seq_len;
|
||||
for (int j = startIndx; j < endIndx; ++j) {
|
||||
unpack_qki[j - startIndx] = unpack_qki[j - startIndx] * mScale + fpmask_ptri[j];
|
||||
}
|
||||
for (int j = 0; j < seq_len; ++j) {
|
||||
unpack_qki[offset + j] = unpack_qki[offset + j] * mScale + fpmask_ptri[j];
|
||||
}
|
||||
} else { // mask shape: [seq_len, seq_len]
|
||||
for (int i = 0; i < seq_len; ++i) {
|
||||
auto unpack_qki = unpack_qk + i * processedKvLen;
|
||||
auto fpmask_ptri = fpmask_ptr + i * seq_len;
|
||||
|
||||
auto notMaskIndx = ALIMIN(endIndx, offset);
|
||||
auto stMaskIndx = ALIMAX(startIndx, offset);
|
||||
for (int j = startIndx; j < notMaskIndx; ++j) {
|
||||
unpack_qki[j - startIndx] = unpack_qki[j - startIndx] * mScale;
|
||||
}
|
||||
for (int j = stMaskIndx; j < endIndx; ++j) {
|
||||
unpack_qki[j - startIndx] = unpack_qki[j - startIndx] * mScale + fpmask_ptri[j - offset];
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// int mask
|
||||
int* mask_ptr = mask->host<int>();
|
||||
for (int i = 0; i < seq_len * kv_seq_len; i++) {
|
||||
if (mask_ptr[i]) {
|
||||
int* mask_ptr = (int*)mask;
|
||||
for (int i = 0; i < seq_len * processedKvLen; i++) {
|
||||
if (mask_ptr[i / processedKvLen * kv_seq_len + i % processedKvLen]) {
|
||||
unpack_qk[i] = unpack_qk[i] * mScale;
|
||||
} else {
|
||||
unpack_qk[i] = min_val;
|
||||
|
@ -148,41 +162,47 @@ static void mask_QK(float * unpack_qk, int seq_len, int kv_seq_len, float mScale
|
|||
}
|
||||
}
|
||||
|
||||
static void softmax_QK(float* softmax_qk_addr, float* unpack_qk_addr, int seq_len, int kv_seq_len) {
|
||||
for (int i = 0; i < seq_len; i++) { // softmax each row
|
||||
MNNSoftmax(softmax_qk_addr + i * kv_seq_len, unpack_qk_addr + i * kv_seq_len, kv_seq_len);
|
||||
typedef void(softmaxFunc)(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize);
|
||||
template <typename T>
|
||||
static void softmaxQK(float* softmax_qk_addr, float* unpack_qk_addr, float* runningMax, float* runningSum, float* diffScale, const float* sinkPtr, softmaxFunc* sffunc, int seq_len, int kv_seq_len, int headIdx, bool isLastKvBlock) {
|
||||
|
||||
// not sliding attention
|
||||
if (sinkPtr == nullptr) {
|
||||
sffunc(softmax_qk_addr, unpack_qk_addr, runningMax, runningSum, diffScale, seq_len, kv_seq_len);
|
||||
return;
|
||||
}
|
||||
|
||||
float sink = ((T*)sinkPtr)[headIdx];
|
||||
if (!runningMax && !runningSum) { // Do not use flash attention
|
||||
|
||||
for (int i = 0; i < seq_len; ++i) {
|
||||
float exprOffset[4] = {1, 0, -sink, 1.f};
|
||||
MNNExp(softmax_qk_addr + i * kv_seq_len, unpack_qk_addr + i * kv_seq_len, exprOffset, kv_seq_len);
|
||||
for (int j = 0; j < kv_seq_len; ++j) {
|
||||
softmax_qk_addr[i * kv_seq_len + j] /= exprOffset[3];
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
static void sink_softmax_QK(float* softmax_qk_addr, float* unpack_qk_addr, int seq_len, int kv_seq_len, float sink) {
|
||||
// TODO: opt
|
||||
std::vector<float> buffer(2 * (kv_seq_len + 1));
|
||||
float* sinkSrc = buffer.data();
|
||||
float* sinkDst = buffer.data() + kv_seq_len + 1;
|
||||
for (int i = 0; i < seq_len; i++) { // softmax each row
|
||||
::memcpy(sinkSrc, unpack_qk_addr + i * kv_seq_len, kv_seq_len * sizeof(float));
|
||||
sinkSrc[kv_seq_len] = sink;
|
||||
float rowMax = sink;
|
||||
for (int j = 0; j < kv_seq_len; j++) {
|
||||
rowMax = ALIMAX(rowMax, sinkSrc[j]);
|
||||
// Use flash attention
|
||||
if (isLastKvBlock) {
|
||||
for (int i = 0; i < seq_len; ++i) {
|
||||
runningSum[i] += expf(sink - runningMax[i]);
|
||||
}
|
||||
for (int j = 0; j < kv_seq_len + 1; j++) {
|
||||
sinkSrc[j] = sinkSrc[j] - rowMax;
|
||||
}
|
||||
MNNSoftmax(sinkDst, sinkSrc, kv_seq_len + 1);
|
||||
::memcpy(softmax_qk_addr + i * kv_seq_len, sinkDst, kv_seq_len * sizeof(float));
|
||||
}
|
||||
MNNSoftmax(softmax_qk_addr, unpack_qk_addr, runningMax, runningSum, diffScale, seq_len, kv_seq_len);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void unpack_QKV(char* pack_qkv, char* unpack_qkv, int mNumHead, int mHeadDim, int unit, int seq_len) {
|
||||
static void unpack_QKV(int8_t* pack_qkv, int8_t* unpack_qkv, int mNumHead, int mHeadDim, int mPack, int seq_len) {
|
||||
auto src_ptr = reinterpret_cast<T*>(pack_qkv);
|
||||
auto dst_ptr = reinterpret_cast<T*>(unpack_qkv);
|
||||
for (int i = 0; i < seq_len; i++) {
|
||||
for (int j = 0; j < mHeadDim; j++) {
|
||||
int a = j / unit;
|
||||
int b = j % unit;
|
||||
dst_ptr[i * mNumHead * mHeadDim + j] = src_ptr[a * seq_len * unit + i * unit + b];
|
||||
int a = j / mPack;
|
||||
int b = j % mPack;
|
||||
dst_ptr[i * mNumHead * mHeadDim + j] = src_ptr[a * seq_len * mPack + i * mPack + b];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
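The new `softmaxQK` signature threads `runningMax`/`runningSum`/`updateScale` buffers through the kernel so attention can be evaluated block by block (flash-attention style). As a reference for what such a streaming softmax update computes, here is a minimal standalone sketch of the online-softmax recurrence; it illustrates the technique only and is not MNN's `MNNSoftmax` kernel:
```cpp
#include <algorithm>
#include <cmath>

// Online softmax over one row processed in blocks of scores.
// After each block, previously emitted outputs must be rescaled by `updateScale`;
// the final normalization divides by `runningSum` once the last block is done.
// Initialize runningMax to -infinity and runningSum to 0 before the first block.
void onlineSoftmaxBlock(float* dst, const float* scores, int blockLen,
                        float& runningMax, float& runningSum, float& updateScale) {
    float blockMax = runningMax;
    for (int j = 0; j < blockLen; ++j) {
        blockMax = std::max(blockMax, scores[j]);
    }
    updateScale = std::exp(runningMax - blockMax);  // rescale factor for earlier blocks
    runningSum *= updateScale;
    for (int j = 0; j < blockLen; ++j) {
        dst[j] = std::exp(scores[j] - blockMax);    // unnormalized probabilities
        runningSum += dst[j];
    }
    runningMax = blockMax;
}
```
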
@ -191,10 +211,10 @@ ErrorCode CPUAttention::onResize(const std::vector<Tensor*>& inputs, const std::
|
|||
auto core = static_cast<CPUBackend *>(backend())->functions();
|
||||
core->MNNGetMatMulPackMode(&eP, &lP, &hP);
|
||||
mThreadNum = ((CPUBackend *)backend())->threadNumber();
|
||||
unit = core->pack;
|
||||
mPack = core->pack;
|
||||
bytes = core->bytes;
|
||||
int qkvQuantOptions = static_cast<CPUBackend *>(backend())->getRuntime()->hint().qkvQuantOption;
|
||||
mUseGemmInt8 = (qkvQuantOptions == 4);
|
||||
mUseGemmInt8 = (qkvQuantOptions % 8 == 4);
|
||||
if (mUseGemmInt8) {
|
||||
static_cast<CPUBackend*>(backend())->int8Functions()->MNNGetGemmUnit(&hP8, &lP8, &eP8);
|
||||
}
|
||||
|
@ -208,7 +228,7 @@ ErrorCode CPUAttention::onResize(const std::vector<Tensor*>& inputs, const std::
|
|||
if (mUseGemmInt8) {
|
||||
mPackQ.reset(Tensor::createDevice<int8_t>({mThreadNum, UP_DIV(seq_len, eP8), UP_DIV(mHeadDim, lP8), eP8 * lP8}));
|
||||
mSumQ.reset(Tensor::createDevice<int32_t>({mThreadNum, UP_DIV(seq_len, eP8), eP8}));
|
||||
mPackQKV.reset(Tensor::createDevice<float>({mThreadNum, UP_DIV(mHeadDim, unit), seq_len, unit}));
|
||||
mPackQKV.reset(Tensor::createDevice<float>({mThreadNum, UP_DIV(mHeadDim, mPack), seq_len, mPack}));
|
||||
backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC);
|
||||
backend()->onAcquireBuffer(mSumQ.get(), Backend::DYNAMIC);
|
||||
backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC);
|
||||
|
@ -220,18 +240,40 @@ ErrorCode CPUAttention::onResize(const std::vector<Tensor*>& inputs, const std::
|
|||
mQueryScale.resize(mNumHead);
|
||||
mQueryZeroPoint.resize(mNumHead);
|
||||
} else {
|
||||
mPackQ.reset(Tensor::createDevice<float>({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(mHeadDim, lP), eP}));
|
||||
mPackQKV.reset(Tensor::createDevice<float>({mThreadNum, UP_DIV(mHeadDim, unit), seq_len, unit}));
|
||||
mPackQ.reset(Tensor::createDevice<int8_t>({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(mHeadDim, lP), eP * bytes}));
|
||||
mPackQKV.reset(Tensor::createDevice<int8_t>({mThreadNum, UP_DIV(mHeadDim, mPack), seq_len, mPack * bytes}));
|
||||
backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC);
|
||||
backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC);
|
||||
|
||||
// flash attention
|
||||
if (qkvQuantOptions / 8 == 1) {
|
||||
mRunningMax.reset(Tensor::createDevice<int8_t>({mThreadNum, seq_len * 4}));
|
||||
mRunningSum.reset(Tensor::createDevice<int8_t>({mThreadNum, seq_len * 4}));
|
||||
mExpfDiffMax.reset(Tensor::createDevice<int8_t>({mThreadNum, seq_len * 4}));
|
||||
mTempOut.reset(Tensor::createDevice<int8_t>({mThreadNum, UP_DIV(mHeadDim, mPack), seq_len, mPack * bytes}));
|
||||
|
||||
backend()->onAcquireBuffer(mRunningMax.get(), Backend::DYNAMIC);
|
||||
backend()->onAcquireBuffer(mRunningSum.get(), Backend::DYNAMIC);
|
||||
backend()->onAcquireBuffer(mExpfDiffMax.get(), Backend::DYNAMIC);
|
||||
backend()->onAcquireBuffer(mTempOut.get(), Backend::DYNAMIC);
|
||||
}
|
||||
|
||||
backend()->onReleaseBuffer(mPackQ.get(), Backend::DYNAMIC);
|
||||
backend()->onReleaseBuffer(mPackQKV.get(), Backend::DYNAMIC);
|
||||
|
||||
if (qkvQuantOptions / 8 == 1) {
|
||||
backend()->onReleaseBuffer(mRunningMax.get(), Backend::DYNAMIC);
|
||||
backend()->onReleaseBuffer(mRunningSum.get(), Backend::DYNAMIC);
|
||||
backend()->onReleaseBuffer(mExpfDiffMax.get(), Backend::DYNAMIC);
|
||||
backend()->onReleaseBuffer(mTempOut.get(), Backend::DYNAMIC);
|
||||
}
|
||||
}
|
||||
return NO_ERROR;
|
||||
}
|
||||
|
||||
ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
|
||||
auto core = static_cast<CPUBackend *>(backend())->functions();
|
||||
auto qkvQuantOptions = static_cast<CPUBackend *>(backend())->getRuntime()->hint().qkvQuantOption;
|
||||
auto query = inputs[0];
|
||||
auto key = inputs[1];
|
||||
auto value = inputs[2];
|
||||
|
@ -283,146 +325,146 @@ ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std:
|
|||
int max_len = mKVCacheManager->maxLength();
|
||||
bool quant_key = mKVCacheManager->config()->mQuantKey;
|
||||
bool quant_value = mKVCacheManager->config()->mQuantValue;
|
||||
|
||||
mBlockKV = (qkvQuantOptions / 8 == 1) ? ALIMIN(MNN_FLASH_ATTENTION_BLOCK_SIZE, kv_seq_len) : kv_seq_len;
|
||||
int32_t units[2] = {eP, lP};
|
||||
|
||||
// Temporary tensors for intermediate results
|
||||
std::shared_ptr<Tensor> packQK(Tensor::createDevice<float>({mThreadNum, UP_DIV(kv_seq_len, unit), seq_len, unit}));
|
||||
std::shared_ptr<Tensor> unpackQK(Tensor::createDevice<int32_t>({mThreadNum, seq_len, kv_seq_len}));
|
||||
std::shared_ptr<Tensor> softmMaxQ(Tensor::createDevice<int32_t>({mThreadNum, seq_len, kv_seq_len}));
|
||||
std::shared_ptr<Tensor> newPackQK(Tensor::createDevice<float>({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(kv_seq_len, lP), eP}));
|
||||
std::shared_ptr<Tensor> dequantV(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), kv_seq_len, hP}));
|
||||
backend()->onAcquireBuffer(packQK.get(), Backend::STATIC);
|
||||
std::shared_ptr<Tensor> unpackQK(Tensor::createDevice<int32_t>({mThreadNum, seq_len, mBlockKV}));
|
||||
std::shared_ptr<Tensor> softmMaxQ(Tensor::createDevice<int32_t>({mThreadNum, seq_len, mBlockKV}));
|
||||
std::shared_ptr<Tensor> newPackQK(Tensor::createDevice<int8_t>({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(mBlockKV, lP), eP * bytes}));
|
||||
std::shared_ptr<Tensor> dequantV(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(mHeadDim, hP), kv_seq_len, hP * bytes}));
|
||||
// mTempQKBlock.reset(Tensor::createDevice<int8_t>({mThreadNum, UP_DIV(mBlockKV, mPack), seq_len, mPack * bytes}));
|
||||
std::shared_ptr<Tensor> tempQKBlock(Tensor::createDevice<int8_t>({mThreadNum, UP_DIV(mBlockKV, mPack), seq_len, mPack * bytes}));
|
||||
backend()->onAcquireBuffer(unpackQK.get(), Backend::STATIC);
|
||||
backend()->onAcquireBuffer(softmMaxQ.get(), Backend::STATIC);
|
||||
backend()->onAcquireBuffer(newPackQK.get(), Backend::STATIC);
|
||||
backend()->onAcquireBuffer(tempQKBlock.get(), Backend::STATIC);
|
||||
if (quant_value) {
|
||||
backend()->onAcquireBuffer(dequantV.get(), Backend::STATIC);
|
||||
mKVCacheManager->onDequantValue(dequantV.get());
|
||||
}
|
||||
const float* sinksPtr = sinks ? sinks->host<float>() : nullptr;
|
||||
std::function<void(int)> mCompute = [=](int tId) {
|
||||
auto pack_q = mPackQ->host<char>() + tId * UP_DIV(seq_len, eP) * ROUND_UP(mHeadDim, lP) * eP * bytes;
|
||||
auto pack_qk = packQK->host<char>() + tId * UP_DIV(kv_seq_len, unit) * seq_len * unit * bytes;
|
||||
char * sum_q = nullptr;
|
||||
auto unpack_qk = unpackQK->host<float>() + tId * seq_len * kv_seq_len;
|
||||
auto softmax_qk = softmMaxQ->host<float>() + tId * seq_len * kv_seq_len;
|
||||
auto new_pack_qk = newPackQK->host<char>() + tId * UP_DIV(seq_len, eP) * ROUND_UP(kv_seq_len, lP) * eP * bytes;
|
||||
auto pack_qkv = mPackQKV->host<char>() + tId * UP_DIV(mHeadDim, unit) * seq_len * unit * bytes;
|
||||
auto qReordered = mPackQ->host<int8_t>() + tId * mPackQ->stride(0);
|
||||
auto qkPacked = tempQKBlock->host<int8_t>() + tId * tempQKBlock->stride(0);
|
||||
int8_t * sum_q = nullptr;
|
||||
auto qkFlatten = unpackQK->host<float>() + tId * unpackQK->stride(0);
|
||||
auto qkSoftmax = softmMaxQ->host<float>() + tId * softmMaxQ->stride(0);
|
||||
auto qkReordered = newPackQK->host<int8_t>() + tId * newPackQK->stride(0);
|
||||
auto qkvPacked = mPackQKV->host<int8_t>() + tId * mPackQKV->stride(0);
|
||||
auto QxK = quant_key ? core->MNNPackedMatMul_int8 : core->MNNPackedMatMul;
|
||||
auto QxK_remain = quant_key ? core->MNNPackedMatMulRemain_int8 : core->MNNPackedMatMulRemain;
|
||||
|
||||
// Flash Attention
|
||||
auto runningMax = mRunningMax ? (float*)(mRunningMax->host<int8_t>() + tId * mRunningMax->stride(0)) : nullptr;
|
||||
auto runningSum = mRunningSum ? (float*)(mRunningSum->host<int8_t>() + tId * mRunningSum->stride(0)) : nullptr;
|
||||
auto diffScale = mExpfDiffMax ? (float*)(mExpfDiffMax->host<int8_t>() + tId * mExpfDiffMax->stride(0)) : nullptr;
|
||||
auto outputPacked = mTempOut ? mTempOut->host<int8_t>() + tId * mTempOut->stride(0) : qkvPacked;
|
||||
int head_index = tId * tileCount;
|
||||
int kvBlocks = UP_DIV(kv_seq_len, mBlockKV);
|
||||
|
||||
if (mUseGemmInt8) {
|
||||
pack_q = mPackQ->host<char>() + tId * UP_DIV(seq_len, eP8) * UP_DIV(mHeadDim, lP8) * eP8 * lP8;
|
||||
sum_q = mSumQ->host<char>() + tId * UP_DIV(seq_len, eP8) * eP8 * 4;
|
||||
qReordered = mPackQ->host<int8_t>() + tId * UP_DIV(seq_len, eP8) * UP_DIV(mHeadDim, lP8) * eP8 * lP8;
|
||||
sum_q = mSumQ->host<int8_t>() + tId * UP_DIV(seq_len, eP8) * eP8 * 4;
|
||||
}
|
||||
for (int h = head_index; h < head_index + tileCount && h < mNumHead; h++) {
|
||||
int kv_h = h / group_size;
|
||||
char * key_addr = mKVCacheManager->addrOfKey(kv_h);
|
||||
char * scale_addr = mKVCacheManager->addrOfScale(kv_h);
|
||||
char * zero_point_addr = mKVCacheManager->addrOfZeroPoint(kv_h);
|
||||
char * key_sum_addr = mKVCacheManager->addrOfKeySum(kv_h);
|
||||
char * value_addr = quant_value ? (dequantV->host<char>() + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(kv_seq_len, lP) * hP * bytes) : mKVCacheManager->addrOfValue(kv_h);
|
||||
if (bytes == 2) {
|
||||
pack_query<FLOAT16_T>(query, pack_q, sum_q, seq_len, h, q_scale);
|
||||
} else {
|
||||
pack_query<float>(query, pack_q, sum_q, seq_len, h, q_scale);
|
||||
if (runningSum && runningMax) {
|
||||
memset(runningSum, 0, mRunningSum->stride(0));
|
||||
if (sinksPtr == nullptr) {
|
||||
for (int k = 0; k < seq_len; ++k) {
|
||||
runningMax[k] = -std::numeric_limits<float>::infinity();
|
||||
}
|
||||
} else {
|
||||
float sinkVal;
|
||||
if (bytes == 2) {
|
||||
sinkVal = ((FLOAT16_T*)sinksPtr)[h];
|
||||
} else {
|
||||
sinkVal = sinksPtr[h];
|
||||
}
|
||||
for (int k = 0; k < seq_len; ++k) {
|
||||
runningMax[k] = sinkVal;
|
||||
}
|
||||
}
|
||||
}
|
||||
// query @ key
|
||||
int kv_h = h / group_size;
|
||||
int8_t * key_addr = mKVCacheManager->addrOfKey(kv_h);
|
||||
int8_t * scale_addr = mKVCacheManager->addrOfScale(kv_h);
|
||||
int8_t * zero_point_addr = mKVCacheManager->addrOfZeroPoint(kv_h);
|
||||
int8_t * key_sum_addr = mKVCacheManager->addrOfKeySum(kv_h);
|
||||
int8_t * value_addr = quant_value ? (dequantV->host<int8_t>() + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(kv_seq_len, lP) * hP * bytes) : mKVCacheManager->addrOfValue(kv_h);
|
||||
if (mUseGemmInt8) {
|
||||
auto GemmInt8Kernel = static_cast<CPUBackend*>(backend())->int8Functions()->Int8GemmKernel;
|
||||
if (bytes == 2 && unit == 8) {
|
||||
GemmInt8Kernel = static_cast<CPUBackend*>(backend())->int8Functions()->MNNGemmInt8AddBiasScale_Unit_FP16;
|
||||
}
|
||||
std::vector<float> postScale(ROUND_UP(kv_seq_len, hP8), 0.0f);
|
||||
for (int i = 0; i < kv_seq_len; i++) {
|
||||
postScale[i] = ((float*)scale_addr)[i] * mQueryScale[h] * q_scale;
|
||||
}
|
||||
std::vector<float> weightQuantBias(ROUND_UP(kv_seq_len, hP8), 0.0f);
|
||||
for (int i = 0; i < kv_seq_len; i++) {
|
||||
weightQuantBias[i] = -((float*)scale_addr)[i] * ((float*)zero_point_addr)[i] * q_scale;
|
||||
}
|
||||
std::vector<float> biasFloat(ROUND_UP(kv_seq_len, hP8), 0.0f);
|
||||
for (int i = 0; i < kv_seq_len; i++) {
|
||||
biasFloat[i] = -mQueryScale[h] * mQueryZeroPoint[h] * ((float*)key_sum_addr)[i] * q_scale;
|
||||
}
|
||||
QuanPostTreatParameters post;
|
||||
post.bias = nullptr;
|
||||
post.biasFloat = biasFloat.data();
|
||||
post.blockNum = 1;
|
||||
post.inputBias = nullptr;
|
||||
post.inputScale = nullptr;
|
||||
post.fp32minmax = nullptr;
|
||||
post.scale = postScale.data();
|
||||
post.useInt8 = false;
|
||||
post.weightKernelSum = weightQuantBias.data();
|
||||
int N = UP_DIV(seq_len, eP8);
|
||||
for (int i = 0; i < N; i++) {
|
||||
int realcount = ALIMIN(eP8, seq_len - i * eP8);
|
||||
post.srcKernelSum = (float*)((char*)sum_q + i * eP8 * 4);
|
||||
GemmInt8Kernel(
|
||||
(int8_t*)pack_qk + i * eP8 * unit * bytes,
|
||||
(int8_t*)pack_q + i * ROUND_UP(mHeadDim, lP8) * eP8,
|
||||
(int8_t*)key_addr,
|
||||
UP_DIV(mHeadDim, lP8),
|
||||
seq_len * unit * bytes,
|
||||
UP_DIV(kv_seq_len, unit),
|
||||
&post,
|
||||
realcount
|
||||
);
|
||||
if (bytes == 2) {
|
||||
pack_query<FLOAT16_T>(query, qReordered, sum_q, seq_len, h, q_scale);
|
||||
} else {
|
||||
pack_query<float>(query, qReordered, sum_q, seq_len, h, q_scale);
|
||||
}
|
||||
} else {
|
||||
core->MNNAttenPackAndScaleSingleHead((float*)qReordered, (float*)(query->host<int8_t>() + h * mHeadDim * bytes), mHeadDim * mNumHead, &q_scale, units, seq_len, mHeadDim);
|
||||
}
|
||||
else {
|
||||
for (int i = 0; i < kvBlocks; ++i) {
|
||||
int subKvSeqLen = ALIMIN(mBlockKV, kv_seq_len - i * mBlockKV);
|
||||
auto keyPtr = key_addr + i * UP_DIV(mBlockKV, hP) * ROUND_UP(mHeadDim, lP) * hP * bytes;
|
||||
auto valuePtr = value_addr + i * UP_DIV(mBlockKV, lP) * hP * lP * bytes;
|
||||
// query @ key
|
||||
{
|
||||
int loop_e = seq_len / eP;
|
||||
int remain = seq_len % eP;
|
||||
auto qStride0 = ROUND_UP(mHeadDim, lP) * eP * bytes;
|
||||
size_t shapeParameters[7] = {(size_t)eP * bytes, ROUND_UP((size_t)mHeadDim, lP), (size_t)subKvSeqLen, (size_t)seq_len * mPack * bytes, 0, 0, 0};
|
||||
for (int ei = 0 ; ei < loop_e; ei++) {
|
||||
QxK((float*)(qkPacked + (ei * eP * mPack) * bytes), (float*)(qReordered + ei * qStride0), (float*)keyPtr, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr);
|
||||
}
|
||||
QxK_remain((float*)(qkPacked + (loop_e * eP * mPack) * bytes), (float*)(qReordered + loop_e * qStride0), (float*)keyPtr, remain, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr);
|
||||
}
|
||||
// qk: [kv_seq_len/mPack, seq_len, mPack] -> [seq_len/eP, kv_seq_len, eP]
|
||||
{
|
||||
if(bytes == 2) {
|
||||
if (seq_len == 1) {
|
||||
core->MNNLowpToFp32((int16_t*)qkPacked, qkFlatten, seq_len * subKvSeqLen);
|
||||
} else {
|
||||
core->MNNAttenUnpackAndConvertFp16(qkFlatten, (float*)qkPacked, subKvSeqLen, seq_len, mPack);
|
||||
}
|
||||
mask_QK<FLOAT16_T>(qkFlatten, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask, kv_seq_len - seq_len, i * mBlockKV, subKvSeqLen);
|
||||
softmaxQK<FLOAT16_T>(qkSoftmax, qkFlatten, runningMax, runningSum, diffScale, sinksPtr, core->MNNSoftmax, seq_len, subKvSeqLen, h, i == kvBlocks - 1);
|
||||
core->MNNAttenPackAndConvertFp32((float*)qkReordered, qkSoftmax, units, seq_len, subKvSeqLen);
|
||||
} else {
|
||||
if (seq_len > 1) {
|
||||
int32_t areaOffset[2] = {seq_len, seq_len};
|
||||
core->MNNUnpackCUnitTranspose(qkFlatten, (float*)qkPacked, seq_len, subKvSeqLen, areaOffset);
|
||||
} else {
|
||||
memcpy(qkFlatten, qkPacked, subKvSeqLen * sizeof(float));
|
||||
}
|
||||
mask_QK<float>(qkFlatten, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask, kv_seq_len - seq_len, i * mBlockKV, subKvSeqLen);
|
||||
softmaxQK<float>(qkSoftmax, qkFlatten, runningMax, runningSum, diffScale, sinksPtr, core->MNNSoftmax, seq_len, subKvSeqLen, h, i == kvBlocks - 1);
|
||||
packKvCache((float*)qkReordered, qkSoftmax, seq_len, subKvSeqLen, eP);
|
||||
}
|
||||
}
|
||||
// qk @ v
|
||||
// TODO: update qkvPacked using diffScale
|
||||
size_t shapeParameters[7] = {(size_t)eP * bytes, ROUND_UP((size_t)subKvSeqLen, lP), (size_t)mHeadDim, (size_t)seq_len * mPack * bytes, 0, 0, 0};
|
||||
size_t bExtraStride = (UP_DIV(max_len, lP) - UP_DIV(subKvSeqLen + i * mBlockKV, lP) + UP_DIV(i * mBlockKV, lP)) * hP * lP * bytes;
|
||||
shapeParameters[5] = quant_value ? 0 : bExtraStride;
|
||||
int loop_e = seq_len / eP;
|
||||
int remain = seq_len % eP;
|
||||
auto qStride0 = ROUND_UP(mHeadDim, lP) * eP * bytes;
|
||||
size_t shapeParameters[7] = {(size_t)eP * bytes, ROUND_UP((size_t)mHeadDim, lP), (size_t)kv_seq_len, (size_t)seq_len * unit * bytes, 0, 0, 0};
|
||||
for (int i = 0 ; i < loop_e; i++) {
|
||||
QxK((float*)(pack_qk + (i * eP * unit) * bytes), (float*)(pack_q + i * qStride0), (float*)key_addr, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr);
|
||||
auto qkStride0 = ROUND_UP(subKvSeqLen, lP) * eP * bytes;
|
||||
for (int ei = 0 ; ei < loop_e; ei++) {
|
||||
core->MNNPackedMatMul((float*)(qkvPacked + (ei * eP * mPack) * bytes), (float*)(qkReordered + ei * qkStride0), (float*)valuePtr, shapeParameters, nullptr, nullptr, nullptr, nullptr);
|
||||
}
|
||||
QxK_remain((float*)(pack_qk + (loop_e * eP * unit) * bytes), (float*)(pack_q + loop_e * qStride0), (float*)key_addr, remain, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr);
|
||||
}
|
||||
// qk: [kv_seq_len/unit, seq_len, unit] -> [seq_len, kv_seq_len] -> [seq_len/eP, kv_seq_len, eP]
|
||||
if (sinksPtr != nullptr) {
|
||||
if(bytes == 2) {
|
||||
unpack_QK<FLOAT16_T>(unpack_qk, pack_qk, seq_len, kv_seq_len);
|
||||
mask_QK<FLOAT16_T>(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask);
|
||||
sink_softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len, sinksPtr[h]);
|
||||
pack_QK<FLOAT16_T>(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP, lP, bytes);
|
||||
} else {
|
||||
unpack_QK<float>(unpack_qk, pack_qk, seq_len, kv_seq_len);
|
||||
mask_QK<float>(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask);
|
||||
sink_softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len, sinksPtr[h]);
|
||||
pack_QK<float>(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP, lP, bytes);
|
||||
}
|
||||
} else {
|
||||
if(bytes == 2) {
|
||||
unpack_QK<FLOAT16_T>(unpack_qk, pack_qk, seq_len, kv_seq_len);
|
||||
mask_QK<FLOAT16_T>(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask);
|
||||
softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len);
|
||||
pack_QK<FLOAT16_T>(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP, lP, bytes);
|
||||
} else {
|
||||
unpack_QK<float>(unpack_qk, pack_qk, seq_len, kv_seq_len);
|
||||
mask_QK<float>(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask);
|
||||
softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len);
|
||||
pack_QK<float>(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP, lP, bytes);
|
||||
core->MNNPackedMatMulRemain((float*)(qkvPacked + (loop_e * eP * mPack) * bytes), (float*)(qkReordered + loop_e * qkStride0), (float*)valuePtr, remain, shapeParameters, nullptr, nullptr, nullptr, nullptr);
|
||||
|
||||
if (runningMax != nullptr && runningSum != nullptr && diffScale != nullptr) {
|
||||
core->MNNFlashAttentionUpdateBlockOutput((float*)outputPacked, (float*)qkvPacked, diffScale, runningSum, UP_DIV(mHeadDim, mPack), seq_len, mPack, i, kvBlocks, mPackQKV->stride(0) / bytes, bytes);
|
||||
}
|
||||
}
|
||||
// qk @ v
|
||||
size_t shapeParameters[7] = {(size_t)eP * bytes, ROUND_UP((size_t)kv_seq_len, lP), (size_t)mHeadDim, (size_t)seq_len * unit * bytes, 0, 0, 0};
|
||||
size_t bExtraStride = (UP_DIV(max_len, lP) - UP_DIV(kv_seq_len, lP)) * hP * lP * bytes;
|
||||
shapeParameters[5] = quant_value ? 0 : bExtraStride;
|
||||
int loop_e = seq_len / eP;
|
||||
int remain = seq_len % eP;
|
||||
auto qkStride0 = ROUND_UP(kv_seq_len, lP) * eP * bytes;
|
||||
for (int i = 0 ; i < loop_e; i++) {
|
||||
core->MNNPackedMatMul((float*)(pack_qkv + (i * eP * unit) * bytes), (float*)(new_pack_qk + i * qkStride0), (float*)value_addr, shapeParameters, nullptr, nullptr, nullptr, nullptr);
|
||||
}
|
||||
core->MNNPackedMatMulRemain((float*)(pack_qkv + (loop_e * eP * unit) * bytes), (float*)(new_pack_qk + loop_e * qkStride0), (float*)value_addr, remain, shapeParameters, nullptr, nullptr, nullptr, nullptr);
|
||||
// unpack: [head_dim/unit, seq_len, unit] -> [seq_len, num_head, head_dim]
|
||||
auto dst_ptr = outputs[0]->host<char>() + h * mHeadDim * bytes;
|
||||
// unpack: [head_dim/mPack, seq_len, mPack] -> [seq_len, num_head, head_dim]
|
||||
auto dst_ptr = outputs[0]->host<int8_t>() + h * mHeadDim * bytes;
|
||||
if (bytes == 2) {
|
||||
unpack_QKV<int16_t>(pack_qkv, dst_ptr, mNumHead, mHeadDim, unit, seq_len);
|
||||
unpack_QKV<int16_t>((int8_t*)outputPacked, dst_ptr, mNumHead, mHeadDim, mPack, seq_len);
|
||||
} else {
|
||||
unpack_QKV<float>(pack_qkv, dst_ptr, mNumHead, mHeadDim, unit, seq_len);
|
||||
unpack_QKV<float>((int8_t*)outputPacked, dst_ptr, mNumHead, mHeadDim, mPack, seq_len);
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
|
@@ -431,10 +473,10 @@ ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std:
}
MNN_CONCURRENCY_END();

backend()->onReleaseBuffer(packQK.get(), Backend::STATIC);
backend()->onReleaseBuffer(unpackQK.get(), Backend::STATIC);
backend()->onReleaseBuffer(softmMaxQ.get(), Backend::STATIC);
backend()->onReleaseBuffer(newPackQK.get(), Backend::STATIC);
backend()->onReleaseBuffer(tempQKBlock.get(), Backend::STATIC);
if (quant_value){
backend()->onReleaseBuffer(dequantV.get(), Backend::STATIC);
}
@@ -460,9 +502,19 @@ CPUAttention::CPUAttention(Backend *backend, bool kv_cache) : Execution(backend)
mPackQKV.reset(Tensor::createDevice<float>({1, 1, 1, 1}));
MNN::KVCacheManager::KVCacheConfig kvconfig;
int qkvQuantOptions = static_cast<CPUBackend *>(backend)->getRuntime()->hint().qkvQuantOption;
kvconfig.mUseInt8Kernel = (qkvQuantOptions == 4);
kvconfig.mQuantKey = (qkvQuantOptions == 4) || (qkvQuantOptions & 1);
kvconfig.mQuantValue = (qkvQuantOptions == 4) || ((qkvQuantOptions >> 1) & 1);
kvconfig.mUseInt8Kernel = (qkvQuantOptions % 8 == 4);

// qkvQuantOption % 8:
// 0: Do not quantize
// 1: Only quantize key, use int8 asymmetric quantization
// 2: Only quantize value, use fp8 quantization
// 3: quantize both key and value
// 4: quantize query, key and value, and use gemm int8 kernel to compute K*V

// qkvQuantOption / 8:
// 1: use flash attention
kvconfig.mQuantKey = (qkvQuantOptions % 8 == 4) || (qkvQuantOptions % 8 == 1) || (qkvQuantOptions % 8 == 3);
kvconfig.mQuantValue = (qkvQuantOptions % 8 == 4) || (qkvQuantOptions % 8 == 2);
kvconfig.mKVCacheDir = static_cast<CPUBackend *>(backend)->getRuntime()->hint().kvcacheDirPath;
kvconfig.mKVCacheSizeLimit = static_cast<CPUBackend *>(backend)->getRuntime()->hint().kvcacheSizeLimit;
kvconfig.mExpandChunk = 64;
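The hint value packs two independent choices: the low part (`% 8`) selects the KV quantization mode and the high part (`/ 8`) toggles flash attention. A minimal decoding sketch, written only for this note and mirroring the assignments above (the struct and function names are invented, not part of the patch):

```cpp
// Decode qkvQuantOption the same way the constructor above does.
struct QkvQuantChoice {
    bool quantKey;
    bool quantValue;
    bool useInt8Kernel;
    bool useFlashAttention;
};

static QkvQuantChoice decodeQkvQuantOption(int option) {
    const int mode = option % 8;                 // quantization mode, see the comment block above
    QkvQuantChoice c;
    c.useInt8Kernel     = (mode == 4);
    c.quantKey          = (mode == 1) || (mode == 3) || (mode == 4);
    c.quantValue        = (mode == 2) || (mode == 4);
    c.useFlashAttention = (option / 8) == 1;     // e.g. option == 11 (8 + 3): quantize K and V, use flash attention
    return c;
}
```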
|
@@ -30,15 +30,16 @@ private:
bool mKVCache = true;
bool mUseGemmInt8 = false;
int bytes = 4;
int mThreadNum = 1;;
int eP, lP, hP, unit; // float matmul packing
int mThreadNum = 1;
int mBlockKV = 512;
int eP, lP, hP, mPack; // float matmul packing
int eP8, lP8, hP8; // GemmInt8 packing
int mNumHead, mKvNumHead, mHeadDim;
std::shared_ptr<Tensor> mPackQ, mPackQKV, mSumQ;
std::shared_ptr<Tensor> mPackQ, mPackQKV, mSumQ, mRunningMax, mRunningSum, mTempQKBlock, mTempOut, mExpfDiffMax;
std::shared_ptr<KVCacheManager> mKVCacheManager = nullptr;
std::vector<float> mMinQ, mMaxQ, mQueryScale, mQueryZeroPoint;
template <typename T> void pack_query(Tensor* query, char* pack_q, char* sum_q, int seq_len, int h, float q_scale);
template <typename T> void unpack_QK(float * unpack_qk_dst, char * pack_qk_src, int seq_len, int kv_seq_len);
template <typename T> void pack_query(Tensor* query, int8_t* pack_q, int8_t* sum_q, int seq_len, int h, float q_scale);
template <typename T> void unpack_QK(float * unpack_qk_dst, int8_t * pack_qk_src, int seq_len, int kv_seq_len);
KVMeta* mMeta;
};
|
||||
|
||||
|
|
|
@@ -509,21 +509,21 @@ CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode p
mDmaInfo.reset(new CPURuntime::DynamicAllocator);
mDmaInfo->mDynamicAllocator.reset(mRuntime->createDynamicBufferAlloctor(0));
mDmaInfo->mCurrentDynamicAllocator = mDmaInfo->mDynamicAllocator.get();
mDmaInfo->mCacheGroup.resize(MNN_CPU_MAX_BUFFER_INDEX);
for (int i=0; i<mDmaInfo->mCacheGroup.size(); ++i) {
mDmaInfo->mCacheGroup[i].reset(new CPUResizeCache);
}
} else {
mDmaInfo = dynamicAlloc;
}
mPrecisionMode = precision;
mCoreFunctions = MNNGetCoreFunctions();
mInt8CoreFunctions = MNNGetInt8CoreFunctions();
mCacheGroup.resize(MNN_CPU_MAX_BUFFER_INDEX);
for (int i=0; i<mCacheGroup.size(); ++i) {
mCacheGroup[i].reset(new CPUResizeCache);
}
mCache = mCacheGroup[0].get();
mCache = mDmaInfo->mCacheGroup[0].get();
}

CPUBackend::~CPUBackend() {
mCacheGroup.clear();
// Do nothing
}
void CPUBackend::_resetDynamicMemory() const {
mRuntime->pCurrentStatus = mDmaInfo->mDynamicAllocator->apply();

@@ -560,7 +560,7 @@ bool CPUBackend::onSelectDynamicAllocator(int index, int maxIndex) {
mRuntime->buffer(0)->release();
mDmaInfo->mCurrentDynamicAllocator = mDmaInfo->mDynamicAllocator.get();
}
mCache = mCacheGroup[index].get();
mCache = mDmaInfo->mCacheGroup[index].get();
return true;
}
|
||||
|
||||
|
|
|
@@ -23,12 +23,14 @@

namespace MNN {
class WorkerThread;
class CPUResizeCache;
class CPURuntime : public Runtime {
public:
struct DynamicAllocator {
std::shared_ptr<BufferAllocator> mDynamicAllocator;
std::shared_ptr<BufferAllocator> mDynamicAllocatorBackup;
BufferAllocator* mCurrentDynamicAllocator = nullptr;
std::vector<std::shared_ptr<CPUResizeCache>> mCacheGroup;
};
friend class CPUBackend;
CPURuntime(const Backend::Info& info);

@@ -82,7 +84,6 @@ struct CoreFunctions;
struct CoreInt8Functions;
struct MatmulRelatedFunctions;

class CPUResizeCache;
class CPUMemObj : public Backend::MemObj {
public:
CPUMemObj(BufferAllocator* allocator, MemChunk chunk, int size) : mAllocator(allocator), mChunk(chunk), mSize(size) {}

@@ -199,7 +200,6 @@ private:
BackendConfig::MemoryMode mMemory;
static std::map<OpType, CPUBackend::Creator*>* gCreator;
CPUResizeCache* mCache;
std::vector<std::shared_ptr<CPUResizeCache>> mCacheGroup;
};
/** execution cast wrapper. insert tensor cast dynamic. */
class CastWrapExecution : public Execution {
|
||||
|
|
|
@@ -170,7 +170,7 @@ cpu_mask_t MNNGetCPUMask(const std::vector<int>& cpuIds) {

// cpuinfo
// Reference from: https://github.com/pytorch/cpuinfo
#if defined(ENABLE_ARMV82) && defined(__arm__)
#if (defined(ENABLE_ARMV82) && defined(__arm__)) || (defined(__ANDROID__) && defined(__aarch64__))

/* As per include/sys/system_properties.h in Android NDK */
#define CPUINFO_HARDWARE_VALUE_MAX 64

@@ -1180,7 +1180,7 @@ struct cpuinfo_arm_chipset cpuinfo_arm_android_decode_chipset(const struct cpuin
// MNN_PRINT("chipset vendor, series, model is: %d, %d, %d\n", chipset.vendor, chipset.series, chipset.model);
return chipset;
}
static void _getInfoARMv7(MNNCPUInfo* cpuinfo_isa) {
static void _getInfoArm(MNNCPUInfo* cpuinfo_isa) {
// Get White List And Black List
struct cpuinfo_arm_linux_processor* arm_linux_processors = NULL;
if (0 == cpuinfo_isa->groups.size()) {

@@ -1500,8 +1500,8 @@ static void _fillInfo(MNNCPUInfo* cpuinfo_isa) {
#if defined(__aarch64__)
_getInfoAux(cpuinfo_isa);
#endif
#if defined(ENABLE_ARMV82) && defined(__arm__)
_getInfoARMv7(cpuinfo_isa);
#if (defined(ENABLE_ARMV82) && defined(__arm__)) || (defined(__ANDROID__) && defined(__aarch64__))
_getInfoArm(cpuinfo_isa);
#endif // #ifdef arm / arm64
#endif // #ifdef __linux__

@@ -200,7 +200,7 @@ int CPUSoftmax::_softmaxCommon(const uint8_t *srcData, uint8_t *dstData) {
for (int v=0; v<mInside; ++v) {
//TODO: Fix x86 compute error and use the same function
#ifdef MNN_USE_SSE
MNNSoftmax(workDst+v*mChannel, workSrc+v*mChannel, mChannel);
MNNSoftmax(workDst+v*mChannel, workSrc+v*mChannel, nullptr, nullptr, nullptr, 1, mChannel);
#else
___MNNSoftmax(workDst+v*mChannel, workSrc+v*mChannel, mChannel, mulFunction);
#endif
|
||||
|
|
|
@ -75,13 +75,13 @@ void KVCacheManager::resetKVCacheFileSize(size_t keySize, size_t valueSize) {
|
|||
void KVCacheManager::mmapKVCache(size_t keySize, size_t valueSize)
|
||||
{
|
||||
if (mMapKeyAddr == nullptr) {
|
||||
mMapKeyAddr = (char *)MNNMmapFile(mKeyCacheFD, keySize);
|
||||
mMapKeyAddr = (int8_t *)MNNMmapFile(mKeyCacheFD, keySize);
|
||||
if (mMapKeyAddr == nullptr) {
|
||||
MNN_PRINT("Failed to memory-map the kvcache!\n");
|
||||
}
|
||||
}
|
||||
if (mMapValueAddr == nullptr) {
|
||||
mMapValueAddr = (char *)MNNMmapFile(mValueCacheFD, valueSize);
|
||||
mMapValueAddr = (int8_t *)MNNMmapFile(mValueCacheFD, valueSize);
|
||||
if (mMapValueAddr == nullptr) {
|
||||
MNN_PRINT("Failed to memory-map the kvcache!\n");
|
||||
}
|
||||
|
@ -111,8 +111,8 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) {
|
|||
mBackend->onAcquireBuffer(new_key, Backend::STATIC);
|
||||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
memcpy(
|
||||
new_key->host<char>() + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
|
||||
mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
|
||||
new_key->host<int8_t>() + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
|
||||
mPastKey->host<int8_t>() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
|
||||
UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8
|
||||
);
|
||||
}
|
||||
|
@ -123,8 +123,8 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) {
|
|||
mBackend->onAcquireBuffer(new_key, Backend::STATIC);
|
||||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
memcpy(
|
||||
new_key->host<char>() + h * new_key->stride(0),
|
||||
mPastKey->host<char>() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP),
|
||||
new_key->host<int8_t>() + h * new_key->stride(0),
|
||||
mPastKey->host<int8_t>() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP),
|
||||
ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP)
|
||||
);
|
||||
}
|
||||
|
@ -135,12 +135,12 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) {
|
|||
mBackend->onAcquireBuffer(new_key, Backend::STATIC);
|
||||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
memcpy(
|
||||
new_key->host<char>() + h * new_key->stride(0) * mBytes,
|
||||
mPastKey->host<char>() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes,
|
||||
new_key->host<int8_t>() + h * new_key->stride(0) * mBytes,
|
||||
mPastKey->host<int8_t>() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes,
|
||||
ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes
|
||||
);
|
||||
if ((new_key->stride(0) - mPastKey->stride(0)) > 0) {
|
||||
memset(new_key->host<char>() + h * new_key->stride(0) * mBytes + mPastKey->stride(0) * mBytes, 0, (new_key->stride(0) - mPastKey->stride(0)) * mBytes);
|
||||
memset(new_key->host<int8_t>() + h * new_key->stride(0) * mBytes + mPastKey->stride(0) * mBytes, 0, (new_key->stride(0) - mPastKey->stride(0)) * mBytes);
|
||||
}
|
||||
}
|
||||
mPastKey.reset(new_key);
|
||||
|
@ -152,8 +152,8 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) {
|
|||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
|
||||
memcpy(
|
||||
new_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP,
|
||||
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP,
|
||||
new_value->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP,
|
||||
mPastValue->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP,
|
||||
ROUND_UP(oldMaxLength, lP) * hP
|
||||
);
|
||||
}
|
||||
|
@ -166,12 +166,12 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) {
|
|||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
|
||||
memcpy(
|
||||
new_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes,
|
||||
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes,
|
||||
new_value->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes,
|
||||
mPastValue->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes,
|
||||
ROUND_UP(oldMaxLength, lP) * hP * mBytes
|
||||
);
|
||||
if ((new_value->stride(1) - mPastValue->stride(1)) > 0) {
|
||||
memset(new_value->host<char>() + (h * new_value->stride(0) + i * new_value->stride(1)) * mBytes + mPastValue->stride(1) * mBytes, 0, (new_value->stride(1) - mPastValue->stride(1)) * mBytes);
|
||||
memset(new_value->host<int8_t>() + (h * new_value->stride(0) + i * new_value->stride(1)) * mBytes + mPastValue->stride(1) * mBytes, 0, (new_value->stride(1) - mPastValue->stride(1)) * mBytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -189,7 +189,7 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) {
|
|||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
memcpy(
|
||||
mMapKeyAddr + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
|
||||
mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
|
||||
mPastKey->host<int8_t>() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
|
||||
UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8
|
||||
);
|
||||
}
|
||||
|
@ -200,7 +200,7 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) {
|
|||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
memcpy(
|
||||
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP,
|
||||
mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP,
|
||||
mPastKey->host<int8_t>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP,
|
||||
UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP
|
||||
);
|
||||
}
|
||||
|
@ -214,7 +214,7 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) {
|
|||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
memcpy(
|
||||
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes,
|
||||
mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes,
|
||||
mPastKey->host<int8_t>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes,
|
||||
UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes
|
||||
);
|
||||
}
|
||||
|
@ -227,7 +227,7 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) {
|
|||
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
|
||||
memcpy(
|
||||
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP,
|
||||
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP,
|
||||
mPastValue->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP,
|
||||
ROUND_UP(oldMaxLength, lP) * hP
|
||||
);
|
||||
}
|
||||
|
@ -243,7 +243,7 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) {
|
|||
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
|
||||
memcpy(
|
||||
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes,
|
||||
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes,
|
||||
mPastValue->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes,
|
||||
ROUND_UP(oldMaxLength, lP) * hP * mBytes
|
||||
);
|
||||
}
|
||||
|
@ -282,8 +282,8 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o
|
|||
memset(old_value->host<uint8_t>(), 0, old_value->length(0) * old_value->stride(0) * mBytes);
|
||||
}
|
||||
mmapKVCache(oldKeySize, oldValueSize);
|
||||
memcpy(old_key->host<char>(), mMapKeyAddr, oldKeySize);
|
||||
memcpy(old_value->host<char>(), mMapValueAddr, oldValueSize);
|
||||
memcpy(old_key->host<int8_t>(), mMapKeyAddr, oldKeySize);
|
||||
memcpy(old_value->host<int8_t>(), mMapValueAddr, oldValueSize);
|
||||
// Step 2: Resize the kvcache files and remap them
|
||||
unmapKVCache(oldKeySize, oldValueSize);
|
||||
resetKVCacheFileSize(keySize, valueSize);
|
||||
|
@ -293,7 +293,7 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o
|
|||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
memcpy(
|
||||
mMapKeyAddr + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
|
||||
old_key->host<char>() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
|
||||
old_key->host<int8_t>() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8,
|
||||
UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8
|
||||
);
|
||||
}
|
||||
|
@ -301,7 +301,7 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o
|
|||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
memcpy(
|
||||
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP,
|
||||
old_key->host<char>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP,
|
||||
old_key->host<int8_t>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP,
|
||||
UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP
|
||||
);
|
||||
}
|
||||
|
@ -309,7 +309,7 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o
|
|||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
memcpy(
|
||||
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes,
|
||||
old_key->host<char>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes,
|
||||
old_key->host<int8_t>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes,
|
||||
UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes
|
||||
);
|
||||
}
|
||||
|
@ -319,7 +319,7 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o
|
|||
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
|
||||
memcpy(
|
||||
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP,
|
||||
old_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP,
|
||||
old_value->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP,
|
||||
ROUND_UP(oldMaxLength, lP) * hP
|
||||
);
|
||||
}
|
||||
|
@ -329,7 +329,7 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o
|
|||
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
|
||||
memcpy(
|
||||
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes,
|
||||
old_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes,
|
||||
old_value->host<int8_t>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes,
|
||||
ROUND_UP(oldMaxLength, lP) * hP * mBytes
|
||||
);
|
||||
}
|
||||
|
@ -460,9 +460,9 @@ void KVCacheManager::onRealloc(const KVMeta* meta) {
|
|||
mBackend->onAcquireBuffer(new_zeroPoint, Backend::STATIC);
|
||||
mBackend->onAcquireBuffer(new_sum, Backend::STATIC);
|
||||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
memcpy(new_scale->host<char>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyScale->host<char>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4);
|
||||
memcpy(new_zeroPoint->host<char>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyZeroPoint->host<char>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4);
|
||||
memcpy(new_sum->host<char>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeySum->host<char>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4);
|
||||
memcpy(new_scale->host<int8_t>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyScale->host<int8_t>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4);
|
||||
memcpy(new_zeroPoint->host<int8_t>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyZeroPoint->host<int8_t>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4);
|
||||
memcpy(new_sum->host<int8_t>() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeySum->host<int8_t>() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4);
|
||||
}
|
||||
mKeyScale.reset(new_scale);
|
||||
mKeyZeroPoint.reset(new_zeroPoint);
|
||||
|
@ -473,8 +473,8 @@ void KVCacheManager::onRealloc(const KVMeta* meta) {
|
|||
mBackend->onAcquireBuffer(new_scale, Backend::STATIC);
|
||||
mBackend->onAcquireBuffer(new_zeroPoint, Backend::STATIC);
|
||||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
memcpy(new_scale->host<char>() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyScale->host<char>() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes);
|
||||
memcpy(new_zeroPoint->host<char>() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyZeroPoint->host<char>() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes);
|
||||
memcpy(new_scale->host<int8_t>() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyScale->host<int8_t>() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes);
|
||||
memcpy(new_zeroPoint->host<int8_t>() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyZeroPoint->host<int8_t>() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes);
|
||||
}
|
||||
mKeyScale.reset(new_scale);
|
||||
mKeyZeroPoint.reset(new_zeroPoint);
|
||||
|
@ -690,8 +690,8 @@ void KVCacheManager::onDequantValue(Tensor * dequantedValues) {
|
|||
int tileCount = UP_DIV(mKvNumHead, mThreadNum);
|
||||
std::function<void(int)> dequant = [=](int tid) {
|
||||
for (int kv_h = tid * tileCount; kv_h < (tid+1) * tileCount && kv_h < mKvNumHead; kv_h++) {
|
||||
char * dst = dequantedValues->host<char>() + kv_h * UP_DIV(mHeadDim, hP) * mPastLength * hP * mBytes;
|
||||
char * src = addrOfValue(kv_h);
|
||||
int8_t * dst = dequantedValues->host<int8_t>() + kv_h * UP_DIV(mHeadDim, hP) * mPastLength * hP * mBytes;
|
||||
int8_t * src = addrOfValue(kv_h);
|
||||
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
|
||||
if (mBytes == 2) {
|
||||
core->MNNFp8ToFp16((uint16_t*)dst, (uint8_t*)src, mPastLength * hP);
|
||||
|
|
|
@ -46,8 +46,8 @@ private:
|
|||
std::shared_ptr<Tensor> mKeySum; // numhead, [maxlen/hP8, hP8]
|
||||
file_t mKeyCacheFD = INVALID_FILE; // The file descriptor of keys
|
||||
file_t mValueCacheFD = INVALID_FILE; // The file descriptor of values
|
||||
char * mMapKeyAddr = nullptr; // Memory-mapped address of keys
|
||||
char * mMapValueAddr = nullptr; // Memory-mapped address of values
|
||||
int8_t * mMapKeyAddr = nullptr; // Memory-mapped address of keys
|
||||
int8_t * mMapValueAddr = nullptr; // Memory-mapped address of values
|
||||
bool mKVCacheInDisk = false; // Whether the kvcache is in disk or in memory now
|
||||
int mPastLength = 0; // Length of past kvcache
|
||||
int mMaxLength = 0; // Capacity of current kvcache buffer (how many kv items can be stored at most)
|
||||
|
@ -104,15 +104,15 @@ public:
|
|||
return mMaxLength;
|
||||
}
|
||||
uint8_t* keyAddr() {
|
||||
char * baseAddr = mKVCacheInDisk ? mMapKeyAddr : mPastKey->host<char>();
|
||||
int8_t * baseAddr = mKVCacheInDisk ? mMapKeyAddr : mPastKey->host<int8_t>();
|
||||
return (uint8_t*)baseAddr;
|
||||
}
|
||||
uint8_t* valudAddr() {
|
||||
char * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host<char>();
|
||||
int8_t * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host<int8_t>();
|
||||
return (uint8_t*)baseAddr;
|
||||
}
|
||||
char * addrOfKey(int kv_h) {
|
||||
char * baseAddr = mKVCacheInDisk ? mMapKeyAddr : mPastKey->host<char>();
|
||||
int8_t * addrOfKey(int kv_h) {
|
||||
int8_t * baseAddr = mKVCacheInDisk ? mMapKeyAddr : mPastKey->host<int8_t>();
|
||||
if (mConfig.mUseInt8Kernel) {
|
||||
return baseAddr + kv_h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8;
|
||||
} else if (mConfig.mQuantKey) {
|
||||
|
@ -121,35 +121,35 @@ public:
|
|||
return baseAddr + kv_h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes;
|
||||
}
|
||||
}
|
||||
char * addrOfValue(int kv_h) {
|
||||
char * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host<char>();
|
||||
int8_t * addrOfValue(int kv_h) {
|
||||
int8_t * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host<int8_t>();
|
||||
if (mConfig.mQuantValue) {
|
||||
return baseAddr + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * hP;
|
||||
} else {
|
||||
return baseAddr + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * hP * mBytes;
|
||||
}
|
||||
}
|
||||
char * addrOfScale(int kv_h) {
|
||||
int8_t * addrOfScale(int kv_h) {
|
||||
if (mConfig.mUseInt8Kernel) {
|
||||
return mKeyScale->host<char>() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4;
|
||||
return mKeyScale->host<int8_t>() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4;
|
||||
} else if (mConfig.mQuantKey) {
|
||||
return mKeyScale->host<char>() + kv_h * UP_DIV(mMaxLength, hP) * hP * mBytes;
|
||||
return mKeyScale->host<int8_t>() + kv_h * UP_DIV(mMaxLength, hP) * hP * mBytes;
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
char * addrOfZeroPoint(int kv_h) {
|
||||
int8_t * addrOfZeroPoint(int kv_h) {
|
||||
if (mConfig.mUseInt8Kernel) {
|
||||
return mKeyZeroPoint->host<char>() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4;
|
||||
return mKeyZeroPoint->host<int8_t>() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4;
|
||||
} else if (mConfig.mQuantKey) {
|
||||
return mKeyZeroPoint->host<char>() + kv_h * UP_DIV(mMaxLength, hP) * hP * mBytes;
|
||||
return mKeyZeroPoint->host<int8_t>() + kv_h * UP_DIV(mMaxLength, hP) * hP * mBytes;
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
char * addrOfKeySum(int kv_h) {
|
||||
int8_t * addrOfKeySum(int kv_h) {
|
||||
if (mConfig.mUseInt8Kernel) {
|
||||
return mKeySum->host<char>() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4;
|
||||
return mKeySum->host<int8_t>() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4;
|
||||
}else {
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@@ -73,6 +73,386 @@ void MNNTranspose16Bit(int16_t* dstO, const int16_t* srcO, int32_t* dim) {
}
}

#define EXP_APPROX_MIN_INPUT vdupq_n_f32(-88.0f)
#define EXP_APPROX_MAX_INPUT vdupq_n_f32(88.0f)
#define EXP_APPROX_LN2 vdupq_n_f32(0.69314718056f) // ln(2)
#define EXP_APPROX_LN2_INV vdupq_n_f32(1.44269504089f) // 1/ln(2)
// Fourth-order polynomial approximation coefficients of exp(r):
// P(x) = c4*x^4 + c3*x^3 + c2*x^2 + c1*x + c0
#define EXP_APPROX_C4 vdupq_n_f32(0.0416624f)
#define EXP_APPROX_C3 vdupq_n_f32(0.166665f)
#define EXP_APPROX_C2 vdupq_n_f32(0.500000f)
#define EXP_APPROX_C1 vdupq_n_f32(1.0f)
#define EXP_APPROX_C0 vdupq_n_f32(1.0f)
|
||||
|
||||
#ifndef __aarch64__
|
||||
static inline float32x4_t vrndaq_f32_compat(float32x4_t val) {
|
||||
const float32x4_t v_zero = vdupq_n_f32(0.0f);
|
||||
|
||||
float32x4_t v_truncated = vcvtq_f32_s32(vcvtq_s32_f32(val));
|
||||
|
||||
uint32x4_t v_is_positive_frac = vcgtq_f32(val, v_truncated);
|
||||
uint32x4_t v_is_negative_frac = vcltq_f32(val, v_truncated);
|
||||
|
||||
float32x4_t v_offset = vbslq_f32(v_is_positive_frac, vdupq_n_f32(1.0f), v_zero);
|
||||
v_offset = vbslq_f32(v_is_negative_frac, vdupq_n_f32(-1.0f), v_offset);
|
||||
|
||||
return vaddq_f32(v_truncated, v_offset);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
static inline float32x4_t expApprox(float32x4_t x) {
|
||||
x = vminq_f32(vmaxq_f32(x, EXP_APPROX_MIN_INPUT), EXP_APPROX_MAX_INPUT);
|
||||
|
||||
float32x4_t k_float;
|
||||
float32x4_t r;
|
||||
float32x4_t exp_r;
|
||||
|
||||
#if defined(__aarch64__)
|
||||
// 1. x = k * ln(2) + r
|
||||
k_float = vrndaq_f32(vmulq_f32(x, EXP_APPROX_LN2_INV));
|
||||
|
||||
// r = x - k * ln(2)
|
||||
r = vfmsq_f32(x, k_float, EXP_APPROX_LN2);
|
||||
|
||||
// 2. c0 + r*(c1 + r*(c2 + r*(c3 + r*c4))) (Horner's method)
|
||||
exp_r = vfmaq_f32(EXP_APPROX_C3, EXP_APPROX_C4, r); // c3 + c4*r
|
||||
exp_r = vfmaq_f32(EXP_APPROX_C2, exp_r, r); // c2 + r*(...)
|
||||
exp_r = vfmaq_f32(EXP_APPROX_C1, exp_r, r); // c1 + r*(...)
|
||||
exp_r = vfmaq_f32(EXP_APPROX_C0, exp_r, r); // c0 + r*(...)
|
||||
|
||||
#else
|
||||
|
||||
k_float = vrndaq_f32_compat(vmulq_f32(x, EXP_APPROX_LN2_INV));
|
||||
|
||||
|
||||
r = vsubq_f32(x, vmulq_f32(k_float, EXP_APPROX_LN2));
|
||||
|
||||
// 2. c0 + r*(c1 + r*(c2 + r*(c3 + r*c4)))
|
||||
exp_r = vmlaq_f32(EXP_APPROX_C3, EXP_APPROX_C4, r); // c3 + c4*r
|
||||
exp_r = vmlaq_f32(EXP_APPROX_C2, exp_r, r); // c2 + r*(...)
|
||||
exp_r = vmlaq_f32(EXP_APPROX_C1, exp_r, r); // c1 + r*(...)
|
||||
exp_r = vmlaq_f32(EXP_APPROX_C0, exp_r, r); // c0 + r*(...)
|
||||
|
||||
#endif
|
||||
|
||||
int32x4_t k_int = vcvtq_s32_f32(k_float);
|
||||
int32x4_t k_shifted = vshlq_n_s32(k_int, 23);
|
||||
float32x4_t result = vreinterpretq_f32_s32(vaddq_s32(vreinterpretq_s32_f32(exp_r), k_shifted));
|
||||
return result;
|
||||
}
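In words, expApprox uses the standard range reduction exp(x) = 2^k * exp(r) with k = round(x / ln2) and |r| <= ln2/2, evaluates exp(r) with the fourth-order polynomial above, and folds 2^k back in by adding k to the float exponent bits. A scalar sketch of the same scheme, included only as an illustration of the math (it is not part of the patch):

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

// Scalar version of the same range reduction and polynomial, for reference only.
static inline float expApproxScalar(float x) {
    x = std::fmin(std::fmax(x, -88.0f), 88.0f);      // same clamp as EXP_APPROX_MIN/MAX_INPUT
    float k = std::nearbyint(x * 1.44269504089f);    // k ~ round(x / ln2)
    float r = x - k * 0.69314718056f;                // r = x - k*ln2, |r| <= ln2/2
    // P(r) ~ e^r with the EXP_APPROX_C0..C4 coefficients (Horner form)
    float p = 1.0f + r * (1.0f + r * (0.5f + r * (0.166665f + r * 0.0416624f)));
    // Multiply by 2^k by adding k to the exponent bits, like vshlq_n_s32(k_int, 23) above
    int32_t bits;
    std::memcpy(&bits, &p, sizeof(bits));
    bits += static_cast<int32_t>(k) << 23;
    std::memcpy(&p, &bits, sizeof(p));
    return p;
}
```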
|
||||
|
||||
void MNNExpC8(float* dst, const float* src, float* offset, const float* parameters, size_t countC8) {
|
||||
float32x4_t maxVec = vdupq_n_f32(offset[2]);
|
||||
float32x4_t sumVec0 = vdupq_n_f32(0);
|
||||
float32x4_t sumVec1 = vdupq_n_f32(0);
|
||||
|
||||
float32x4_t c0 = vdupq_n_f32(offset[0]);
|
||||
float32x4_t c1 = vdupq_n_f32(offset[1]);
|
||||
|
||||
|
||||
for (int i = 0; i < countC8; ++i) {
|
||||
float32x4_t srcVec0 = vld1q_f32(src);
|
||||
float32x4_t srcVec1 = vld1q_f32(src + 4);
|
||||
auto subVec0 = vaddq_f32(vmulq_f32(srcVec0, c0), maxVec);
|
||||
auto subVec1 = vaddq_f32(vmulq_f32(srcVec1, c0), maxVec);
|
||||
auto expVec0 = vaddq_f32(expApprox(subVec0), c1);
|
||||
auto expVec1 = vaddq_f32(expApprox(subVec1), c1);
|
||||
vst1q_f32(dst, expVec0);
|
||||
vst1q_f32(dst + 4, expVec1);
|
||||
sumVec0 = vaddq_f32(sumVec0, expVec0);
|
||||
sumVec1 = vaddq_f32(sumVec1, expVec1);
|
||||
|
||||
src += 8;
|
||||
dst += 8;
|
||||
}
|
||||
|
||||
sumVec0 = vaddq_f32(sumVec0, sumVec1);
|
||||
float32x2_t sumP = vpadd_f32(vget_low_f32(sumVec0), vget_high_f32(sumVec0));
|
||||
sumP = vpadd_f32(sumP, sumP);
|
||||
offset[3] += vget_lane_f32(sumP, 0);
|
||||
}
|
||||
|
||||
|
||||
void MNNExp(float* destPtr, const float* srcPtr, float* offset, size_t size) {
|
||||
float32x4_t maxVec = vdupq_n_f32(-offset[2]);
|
||||
float32x4_t sumVec0 = vdupq_n_f32(0);
|
||||
float32x4_t sumVec1 = vdupq_n_f32(0);
|
||||
if (offset[0] == 1.f && offset[1] == 0.f) {
|
||||
while (size >= 8) {
|
||||
float32x4_t srcVec0 = vld1q_f32(srcPtr);
|
||||
float32x4_t srcVec1 = vld1q_f32(srcPtr + 4);
|
||||
auto subVec0 = vsubq_f32(srcVec0, maxVec);
|
||||
auto subVec1 = vsubq_f32(srcVec1, maxVec);
|
||||
auto expVec0 = expApprox(subVec0);
|
||||
auto expVec1 = expApprox(subVec1);
|
||||
vst1q_f32(destPtr, expVec0);
|
||||
vst1q_f32(destPtr + 4, expVec1);
|
||||
sumVec0 = vaddq_f32(sumVec0, expVec0);
|
||||
sumVec1 = vaddq_f32(sumVec1, expVec1);
|
||||
srcPtr += 8;
|
||||
destPtr += 8;
|
||||
size -= 8;
|
||||
|
||||
}
|
||||
while (size >= 4) {
|
||||
float32x4_t srcVec0 = vld1q_f32(srcPtr);
|
||||
auto subVec0 = vsubq_f32(srcVec0, maxVec);
|
||||
auto expVec0 = expApprox(subVec0);
|
||||
sumVec0 = vaddq_f32(sumVec0, expVec0);
|
||||
vst1q_f32(destPtr, expVec0);
|
||||
srcPtr += 4;
|
||||
destPtr += 4;
|
||||
size -= 4;
|
||||
}
|
||||
//merge
|
||||
sumVec0 = vaddq_f32(sumVec0, sumVec1);
|
||||
float32x2_t sumP = vpadd_f32(vget_low_f32(sumVec0), vget_high_f32(sumVec0));
|
||||
sumP = vpadd_f32(sumP, sumP);
|
||||
auto newSum = vget_lane_f32(sumP, 0);
|
||||
if (size > 0) {
|
||||
float tmp[4];
|
||||
memcpy(tmp, srcPtr, size * sizeof(float));
|
||||
float32x4_t srcVec0 = vld1q_f32(tmp);
|
||||
auto subVec0 = vsubq_f32(srcVec0, maxVec);
|
||||
auto expVec0 = expApprox(subVec0);
|
||||
vst1q_f32(tmp, expVec0);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
newSum += tmp[i];
|
||||
destPtr[i] = tmp[i];
|
||||
}
|
||||
}
|
||||
offset[3] += newSum;
|
||||
} else {
|
||||
float32x4_t c0 = vdupq_n_f32(offset[0]);
|
||||
float32x4_t c1 = vdupq_n_f32(offset[1]);
|
||||
while (size >= 8) {
|
||||
float32x4_t srcVec0 = vld1q_f32(srcPtr);
|
||||
float32x4_t srcVec1 = vld1q_f32(srcPtr + 4);
|
||||
auto subVec0 = vsubq_f32(vmulq_f32(srcVec0, c0), maxVec);
|
||||
auto subVec1 = vsubq_f32(vmulq_f32(srcVec1, c0), maxVec);
|
||||
auto expVec0 = vaddq_f32(expApprox(subVec0), c1);
|
||||
auto expVec1 = vaddq_f32(expApprox(subVec1), c1);
|
||||
vst1q_f32(destPtr, expVec0);
|
||||
vst1q_f32(destPtr + 4, expVec1);
|
||||
sumVec0 = vaddq_f32(sumVec0, expVec0);
|
||||
sumVec1 = vaddq_f32(sumVec1, expVec1);
|
||||
srcPtr += 8;
|
||||
destPtr += 8;
|
||||
size -= 8;
|
||||
|
||||
}
|
||||
while (size >= 4) {
|
||||
float32x4_t srcVec0 = vld1q_f32(srcPtr);
|
||||
auto subVec0 = vsubq_f32(vmulq_f32(srcVec0, c0), maxVec);
|
||||
auto expVec0 = vaddq_f32(expApprox(subVec0), c1);
|
||||
sumVec0 = vaddq_f32(sumVec0, expVec0);
|
||||
vst1q_f32(destPtr, expVec0);
|
||||
srcPtr += 4;
|
||||
destPtr += 4;
|
||||
size -= 4;
|
||||
}
|
||||
//merge
|
||||
sumVec0 = vaddq_f32(sumVec0, sumVec1);
|
||||
float32x2_t sumP = vpadd_f32(vget_low_f32(sumVec0), vget_high_f32(sumVec0));
|
||||
sumP = vpadd_f32(sumP, sumP);
|
||||
auto newSum = vget_lane_f32(sumP, 0);
|
||||
if (size > 0) {
|
||||
float tmp[4];
|
||||
memcpy(tmp, srcPtr, size * sizeof(float));
|
||||
float32x4_t srcVec0 = vld1q_f32(tmp);
|
||||
auto subVec0 = vsubq_f32(vmulq_f32(srcVec0, c0), maxVec);
|
||||
auto expVec0 = vaddq_f32(expApprox(subVec0), c1);
|
||||
vst1q_f32(tmp, expVec0);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
newSum += tmp[i];
|
||||
destPtr[i] = tmp[i];
|
||||
}
|
||||
}
|
||||
offset[3] += newSum;
|
||||
}
|
||||
}
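As the code above shows, MNNExp writes dst[i] = exp(offset[0] * src[i] + offset[2]) + offset[1] and accumulates the sum of the written values into offset[3]; MNNSoftmax below relies on this by passing {1, 0, -max, 0}. A minimal usage sketch (illustrative only; the wrapper name is invented and it assumes the MNNExp definition above is visible):

```cpp
#include <cstddef>

// Compute dst[i] = exp(src[i] - maxValue) and return the sum of the results,
// using MNNExp's offset convention {scale, post-add, pre-add, sum-accumulator}.
static float expMinusMax(float* dst, const float* src, size_t size, float maxValue) {
    float offset[4] = {1.0f, 0.0f, -maxValue, 0.0f};
    MNNExp(dst, src, offset, size);
    return offset[3];   // sum of all dst[i], later used as the softmax denominator
}
```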
|
||||
|
||||
|
||||
static inline void transposeAndStore4x4(const float* srcRowPtrs[4], float* dstColBase, size_t dstColStride) {
|
||||
float32x4_t row0 = vld1q_f32(srcRowPtrs[0]);
|
||||
float32x4_t row1 = vld1q_f32(srcRowPtrs[1]);
|
||||
float32x4_t row2 = vld1q_f32(srcRowPtrs[2]);
|
||||
float32x4_t row3 = vld1q_f32(srcRowPtrs[3]);
|
||||
|
||||
// Step 1: Transpose 2x2 blocks of 2-element vectors
|
||||
float32x4x2_t t01 = vtrnq_f32(row0, row1);
|
||||
float32x4x2_t t23 = vtrnq_f32(row2, row3);
|
||||
|
||||
// Step 2: Combine the results to get the full transpose
|
||||
float32x4_t col0 = vcombine_f32(vget_low_f32(t01.val[0]), vget_low_f32(t23.val[0]));
|
||||
float32x4_t col1 = vcombine_f32(vget_low_f32(t01.val[1]), vget_low_f32(t23.val[1]));
|
||||
float32x4_t col2 = vcombine_f32(vget_high_f32(t01.val[0]), vget_high_f32(t23.val[0]));
|
||||
float32x4_t col3 = vcombine_f32(vget_high_f32(t01.val[1]), vget_high_f32(t23.val[1]));
|
||||
|
||||
vst1q_f32(dstColBase, col0);
|
||||
vst1q_f32(dstColBase + dstColStride, col1);
|
||||
vst1q_f32(dstColBase + 2 * dstColStride, col2);
|
||||
vst1q_f32(dstColBase + 3 * dstColStride, col3);
|
||||
}
|
||||
|
||||
|
||||
void packKvCache(float* dst, const float* src, size_t seqLen, size_t kvSeqLen, size_t eP) {
|
||||
if (seqLen == 0 || kvSeqLen == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// source [seqLen, kvSeqLen]
|
||||
// dest [seqLen/eP, kvSeqLen, eP]
|
||||
|
||||
const int kTileS = 4; // Tiling size for seqLen dimension
|
||||
const int kTileK = 4; // Tiling size for kvSeqLen dimension
|
||||
const size_t dstSOuterStride = kvSeqLen * eP;
|
||||
|
||||
int s = 0;
|
||||
for (; s + kTileS <= seqLen; s += kTileS) {
|
||||
const int sOuter = s / eP;
|
||||
const int sInner = s % eP;
|
||||
if (sInner + kTileS > eP) {
|
||||
break;
|
||||
}
|
||||
|
||||
float* dstSBase = dst + sOuter * dstSOuterStride + sInner;
|
||||
const float* srcRowPtrs[kTileS];
|
||||
srcRowPtrs[0] = src + (s + 0) * kvSeqLen;
|
||||
srcRowPtrs[1] = src + (s + 1) * kvSeqLen;
|
||||
srcRowPtrs[2] = src + (s + 2) * kvSeqLen;
|
||||
srcRowPtrs[3] = src + (s + 3) * kvSeqLen;
|
||||
|
||||
int k = 0;
|
||||
for (; k + kTileK <= kvSeqLen; k += kTileK) {
|
||||
const float* currentSrcPtrs[kTileS];
|
||||
currentSrcPtrs[0] = srcRowPtrs[0] + k;
|
||||
currentSrcPtrs[1] = srcRowPtrs[1] + k;
|
||||
currentSrcPtrs[2] = srcRowPtrs[2] + k;
|
||||
currentSrcPtrs[3] = srcRowPtrs[3] + k;
|
||||
float* dstKBase = dstSBase + k * eP;
|
||||
transposeAndStore4x4(currentSrcPtrs, dstKBase, eP);
|
||||
}
|
||||
|
||||
for (; k < kvSeqLen; ++k) {
|
||||
float buffer[kTileS] = {
|
||||
srcRowPtrs[0][k],
|
||||
srcRowPtrs[1][k],
|
||||
srcRowPtrs[2][k],
|
||||
srcRowPtrs[3][k]
|
||||
};
|
||||
vst1q_f32(dstSBase + k * eP, vld1q_f32(buffer));
|
||||
}
|
||||
}
|
||||
|
||||
for (; s < seqLen; ++s) {
|
||||
const int sOuter = s / eP;
|
||||
const int sInner = s % eP;
|
||||
const float* srcRow = src + s * kvSeqLen;
|
||||
float* dstSBase = dst + sOuter * dstSOuterStride + sInner;
|
||||
for (int k = 0; k < kvSeqLen; ++k) {
|
||||
dstSBase[k * eP] = srcRow[k];
|
||||
}
|
||||
}
|
||||
}
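The NEON 4x4 tile path and the scalar tail above implement the same element mapping from [seqLen, kvSeqLen] to [ceil(seqLen/eP), kvSeqLen, eP]. A plain scalar reference of that layout transform, added here only as an illustrative sketch (the function name is invented):

```cpp
#include <cstddef>

// Reference packing: dst[(s / eP) * kvSeqLen * eP + k * eP + (s % eP)] = src[s * kvSeqLen + k]
static void packKvCacheRef(float* dst, const float* src, size_t seqLen, size_t kvSeqLen, size_t eP) {
    for (size_t s = 0; s < seqLen; ++s) {
        for (size_t k = 0; k < kvSeqLen; ++k) {
            dst[(s / eP) * kvSeqLen * eP + k * eP + (s % eP)] = src[s * kvSeqLen + k];
        }
    }
}
```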
|
||||
|
||||
void MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) {
|
||||
for (int k = 0; k < outside; ++k) {
|
||||
auto source = input + k * reduceSize;
|
||||
auto dest = softmaxDst + k * reduceSize;
|
||||
|
||||
// new max
|
||||
auto srcPtr = source;
|
||||
auto size = reduceSize;
|
||||
float32x4_t maxVec0 = vdupq_n_f32(source[0]);
|
||||
auto maxVec1 = maxVec0;
|
||||
|
||||
float oldMax = source[0];
|
||||
if (runningMax) {
|
||||
oldMax = runningMax[k];
|
||||
}
|
||||
|
||||
while (size >= 8) {
|
||||
float32x4_t srcVec0 = vld1q_f32(srcPtr);
|
||||
float32x4_t srcVec1 = vld1q_f32(srcPtr + 4);
|
||||
|
||||
maxVec0 = vmaxq_f32(maxVec0, srcVec0);
|
||||
maxVec1 = vmaxq_f32(maxVec1, srcVec1);
|
||||
|
||||
srcPtr += 8;
|
||||
size -= 8;
|
||||
}
|
||||
|
||||
while (size >= 4) {
|
||||
float32x4_t srcVec0 = vld1q_f32(srcPtr);
|
||||
maxVec0 = vmaxq_f32(maxVec0, srcVec0);
|
||||
srcPtr += 4;
|
||||
size -= 4;
|
||||
}
|
||||
|
||||
maxVec0 = vmaxq_f32(maxVec0, maxVec1);
|
||||
float32x2_t maxP = vpmax_f32(vget_low_f32(maxVec0), vget_high_f32(maxVec0));
|
||||
maxP = vpmax_f32(maxP, maxP);
|
||||
auto newMax = vget_lane_f32(maxP, 0);
|
||||
|
||||
while (size > 0) {
|
||||
newMax = ALIMAX(newMax, srcPtr[0]);
|
||||
srcPtr += 1;
|
||||
size -= 1;
|
||||
}
|
||||
|
||||
newMax = ALIMAX(oldMax, newMax);
|
||||
srcPtr = source;
|
||||
auto destPtr = dest;
|
||||
size = reduceSize;
|
||||
|
||||
float exprOffset[4] = {
|
||||
1.0f,
|
||||
0.0f,
|
||||
0.0f,
|
||||
0.0f
|
||||
};
|
||||
exprOffset[2] = -newMax;
|
||||
|
||||
// expf(xi-newmax) & new sum
|
||||
MNNExp(destPtr, srcPtr, exprOffset, size);
|
||||
|
||||
if (runningMax != nullptr && runningSum != nullptr && updateScale != nullptr) {
|
||||
// update runningSum, runningMax, scale=expf(oldMax-newMax)
|
||||
float newSum = exprOffset[3];
|
||||
runningSum[k] = runningSum[k] * expf(oldMax - newMax) + newSum;
|
||||
runningMax[k] = newMax;
|
||||
updateScale[k] = expf(oldMax - newMax);
|
||||
} else {
|
||||
// Normalize
|
||||
float sum = exprOffset[3];
|
||||
float scale = 1.0f / (sum + 1e-20f);
|
||||
int count = reduceSize;
|
||||
auto pDest = dest;
|
||||
|
||||
float32x4_t scaleVec = vdupq_n_f32(scale);
|
||||
while (count >= 4) {
|
||||
float32x4_t data = vld1q_f32(pDest);
|
||||
data = vmulq_f32(data, scaleVec);
|
||||
vst1q_f32(pDest, data);
|
||||
|
||||
pDest += 4;
|
||||
count -= 4;
|
||||
}
|
||||
|
||||
while (count > 0) {
|
||||
*pDest *= scale;
|
||||
pDest++;
|
||||
count--;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
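When runningMax, runningSum and updateScale are non-null, the function performs the streaming-softmax update used by the flash-attention path instead of normalizing in place; the caller then rescales the output it has accumulated so far by updateScale. A scalar sketch of that recurrence for a single row of block scores (an illustration written for this note, not part of the patch):

```cpp
#include <algorithm>
#include <cmath>

// Streaming softmax statistics for one row, mirroring the update in MNNSoftmax above:
//   m' = max(m, max_i x[i]);  s' = s * exp(m - m') + sum_i exp(x[i] - m')
static void onlineSoftmaxUpdate(const float* x, int n, float& runningMax, float& runningSum, float& updateScale) {
    float newMax = runningMax;
    for (int i = 0; i < n; ++i) {
        newMax = std::max(newMax, x[i]);
    }
    float blockSum = 0.f;
    for (int i = 0; i < n; ++i) {
        blockSum += std::exp(x[i] - newMax);
    }
    updateScale = std::exp(runningMax - newMax);   // factor to rescale previously accumulated output
    runningSum  = runningSum * updateScale + blockSum;
    runningMax  = newMax;
}
```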
|
||||
|
||||
|
||||
#ifndef MNN_USE_NEON
|
||||
|
||||
void MNNPackedSparseMatMulEpx1(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, unsigned int* NNZMap, int* dataOffsetMap) {
|
||||
|
|
|
@ -1,115 +0,0 @@
|
|||
//
|
||||
// MNNExpC8.S
|
||||
// MNN
|
||||
//
|
||||
// Created by MNN on 2019/01/18.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
|
||||
#ifdef __arm__
|
||||
#ifndef __aarch64__
|
||||
|
||||
#include "MNNAsmGlobal.h"
|
||||
.text
|
||||
.align 5
|
||||
//void MNNExpC8(float* dest, const float* source, const float* offset, const float* parameters, size_t countC8)
|
||||
asm_function MNNExpC8
|
||||
|
||||
//r0: dest, r1:source, r2: offset, r3:parameters, r4:countC8
|
||||
push {r4, r5, lr}
|
||||
ldr r4, [sp, #12]
|
||||
vpush {q4-q7}
|
||||
vmov.i32 q7, #0
|
||||
ldr r5, [r2, #0]
|
||||
vdup.32 q4, r5 // Alpha
|
||||
ldr r5, [r2, #4]
|
||||
vdup.32 q5, r5 // Beta
|
||||
ldr r5, [r2, #8]
|
||||
vdup.32 q6, r5 // Bias
|
||||
|
||||
|
||||
vld1.32 {q0, q1}, [r3]
|
||||
|
||||
vmov.i32 q2, #87
|
||||
vcvt.f32.s32 q2, q2
|
||||
vneg.f32 q3, q2
|
||||
|
||||
Loop:
|
||||
|
||||
vld1.32 {q8, q9}, [r1]!
|
||||
vmul.f32 q8, q8, q4
|
||||
vmul.f32 q9, q9, q4
|
||||
vadd.f32 q8, q8, q6
|
||||
vadd.f32 q9, q9, q6
|
||||
|
||||
vmin.f32 q8, q8, q2
|
||||
vmin.f32 q9, q9, q2
|
||||
vmax.f32 q10, q8, q3
|
||||
vmax.f32 q11, q9, q3
|
||||
|
||||
vmul.f32 q8, q10, d0[1]
|
||||
vmul.f32 q9, q11, d0[1]
|
||||
vcvt.s32.f32 q8, q8
|
||||
vcvt.s32.f32 q9, q9
|
||||
|
||||
vcvt.f32.s32 q12, q8
|
||||
vcvt.f32.s32 q13, q9
|
||||
|
||||
//q10, q11: t
|
||||
vmls.f32 q10, q12, d0[0]
|
||||
vmls.f32 q11, q13, d0[0]
|
||||
|
||||
vmul.f32 q10, q10, d1[0]
|
||||
vmul.f32 q11, q11, d1[0]
|
||||
|
||||
.macro MLA_TWO z0 z1 z2 z3
|
||||
vdup.32 \z1, \z0
|
||||
vmla.f32 \z1, \z2, \z3
|
||||
.endm
|
||||
|
||||
MLA_TWO d3[0], q12, q10, d3[1]
|
||||
MLA_TWO d3[0], q13, q11, d3[1]
|
||||
MLA_TWO d2[1], q14, q10, q12
|
||||
MLA_TWO d2[1], q15, q11, q13
|
||||
MLA_TWO d2[0], q12, q10, q14
|
||||
MLA_TWO d2[0], q13, q11, q15
|
||||
MLA_TWO d1[1], q14, q10, q12
|
||||
MLA_TWO d1[1], q15, q11, q13
|
||||
MLA_TWO d1[1], q12, q10, q14
|
||||
MLA_TWO d1[1], q13, q11, q15
|
||||
|
||||
//q12, q13 is expRemain
|
||||
|
||||
vmul.f32 q12, q12, q12
|
||||
vmul.f32 q13, q13, q13
|
||||
vmul.f32 q12, q12, q12
|
||||
vmul.f32 q13, q13, q13
|
||||
|
||||
vshl.i32 q8, q8, #23
|
||||
vshl.i32 q9, q9, #23
|
||||
vadd.i32 q12, q12, q8
|
||||
vadd.i32 q13, q13, q9
|
||||
|
||||
vadd.f32 q12, q12, q5
|
||||
vadd.f32 q13, q13, q5
|
||||
vadd.f32 q7, q12, q7
|
||||
|
||||
vst1.32 {q12, q13}, [r0]!
|
||||
vadd.f32 q7, q13, q7
|
||||
|
||||
subs r4, r4, #1
|
||||
bne Loop
|
||||
add r5, r2, #12
|
||||
vld1.32 {d0[0]}, [r5]
|
||||
vadd.f32 d14, d14, d15
|
||||
vtrn.32 d14, d15
|
||||
vadd.f32 d14, d14, d15
|
||||
vadd.f32 d0, d14, d0
|
||||
vst1.32 {d0[0]}, [r5]
|
||||
|
||||
vpop {q4-q7}
|
||||
pop {r4, r5, pc}
|
||||
|
||||
|
||||
#endif
|
||||
#endif
|
|
@ -1,270 +0,0 @@
|
|||
//
|
||||
// MNNSoftmax.S
|
||||
// MNN
|
||||
//
|
||||
// Created by MNN on 2021/07/05.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
|
||||
#ifdef __arm__
|
||||
#ifndef __aarch64__
|
||||
|
||||
#include "MNNAsmGlobal.h"
|
||||
.text
|
||||
.align 5
|
||||
//void MNNSoftmax(float* dest, const float* source, size_t size)
|
||||
asm_function MNNSoftmax
|
||||
push {r4, r5, r6, r7, r8, lr}
|
||||
bic r3, r2, #3
|
||||
lsrs r4, r2, #2
|
||||
vpush.64 {d8, d9, d10, d11, d12, d13, d14, d15}
|
||||
sub sp, sp, #8
|
||||
beq Loop_2
|
||||
vld1.32 {d16-d17}, [r1]
|
||||
cmp r4, #1
|
||||
beq Loop_3
|
||||
add lr, r1, #16
|
||||
mov ip, #1
|
||||
Loop_4:
|
||||
vld1.32 {d18-d19}, [lr]!
|
||||
add ip, ip, #1
|
||||
cmp r4, ip
|
||||
vmax.f32 q8, q8, q9
|
||||
bne Loop_4
|
||||
Loop_3:
|
||||
vmov.32 ip, d16[0]
|
||||
vmov s15, ip
|
||||
vmov.32 ip, d16[1]
|
||||
vmov s12, ip
|
||||
vmov.32 ip, d17[0]
|
||||
vcmpe.f32 s15, s12
|
||||
vmov s13, ip
|
||||
vmov.32 ip, d17[1]
|
||||
vmrs APSR_nzcv, FPSCR
|
||||
vmov s14, ip
|
||||
vmovle.f32 s15, s12
|
||||
vcmpe.f32 s13, s15
|
||||
vmrs APSR_nzcv, FPSCR
|
||||
vmovpl.f32 s15, s13
|
||||
vcmpe.f32 s14, s15
|
||||
vmrs APSR_nzcv, FPSCR
|
||||
vmovpl.f32 s15, s14
|
||||
cmp r2, r3
|
||||
bls Loop_25
|
||||
Loop_24:
|
||||
add lr, r1, r3, lsl #2
|
||||
mov ip, r3
|
||||
Loop_11:
|
||||
vldmia.32 lr!, {s14}
|
||||
add ip, ip, #1
|
||||
vcmpe.f32 s14, s15
|
||||
vmrs APSR_nzcv, FPSCR
|
||||
vmovpl.f32 s15, s14
|
||||
cmp r2, ip
|
||||
bhi Loop_11
|
||||
cmp r4, #0
|
||||
beq Loop_8
|
||||
Loop_25:
|
||||
vmov.f32 q11, #0.0 @ v4sf
|
||||
mov r5, r1
|
||||
vldr d14, Loop_54
|
||||
vldr d15, Loop_54+8
|
||||
mov lr, r0
|
||||
vldr d12, Loop_54+16
|
||||
vldr d13, Loop_54+24
|
||||
mov ip, #0
|
||||
vmov.f32 q12, #1.0e+0 @ v4sf
|
||||
vstr.32 s15, [sp, #4]
|
||||
vmov.f32 q5, #5.0e-1 @ v4sf
|
||||
vldr d8, Loop_54+32
|
||||
vldr d9, Loop_54+40
|
||||
vldr d0, Loop_54+48
|
||||
vldr d1, Loop_54+56
|
||||
vldr d2, Loop_54+64
|
||||
vldr d3, Loop_54+72
|
||||
vldr d4, Loop_54+80
|
||||
vldr d5, Loop_54+88
|
||||
vldr d30, Loop_54+96
|
||||
vldr d31, Loop_54+104
|
||||
vmov.i32 q14, #8388608 @ v4si
|
||||
vdup.32 q13, d7[1]
|
||||
Loop_12:
|
||||
vld1.32 {d20-d21}, [r5]!
|
||||
add ip, ip, #1
|
||||
vmov.i32 q3, #127 @ v4si
|
||||
cmp r4, ip
|
||||
vsub.f32 q10, q10, q13
|
||||
vmax.f32 q10, q10, q15
|
||||
vmin.f32 q10, q10, q2
|
||||
vmul.f32 q9, q10, q6
|
||||
vcvt.s32.f32 q9, q9
|
||||
vcvt.f32.s32 q8, q9
|
||||
vadd.i32 q9, q9, q3
|
||||
vmul.f32 q8, q8, q7
|
||||
vmul.i32 q9, q9, q14
|
||||
vsub.f32 q10, q10, q8
|
||||
vmul.f32 q8, q1, q10
|
||||
vadd.f32 q8, q8, q0
|
||||
vmul.f32 q8, q8, q10
|
||||
vadd.f32 q8, q8, q4
|
||||
vmul.f32 q8, q8, q10
|
||||
vadd.f32 q8, q8, q5
|
||||
vmul.f32 q8, q8, q10
|
||||
vadd.f32 q8, q8, q12
|
||||
vmul.f32 q8, q8, q10
|
||||
vadd.f32 q8, q8, q12
|
||||
vmul.f32 q9, q9, q8
|
||||
vadd.f32 q11, q9, q11
|
||||
vst1.32 {d18-d19}, [lr]!
|
||||
bgt Loop_12
|
||||
vmov.32 ip, d22[0]
|
||||
vldr.32 s15, [sp, #4]
|
||||
cmp r2, r3
|
||||
vmov s12, ip
|
||||
vmov.32 ip, d22[1]
|
||||
vmov s11, ip
|
||||
vmov.32 ip, d23[0]
|
||||
vadd.f32 s12, s12, s11
|
||||
vmov s13, ip
|
||||
vmov.32 ip, d23[1]
|
||||
vadd.f32 s12, s12, s13
|
||||
vmov s14, ip
|
||||
vadd.f32 s12, s12, s14
|
||||
bls Loop_52
|
||||
Loop_26:
|
||||
lsl ip, r3, #2
|
||||
add r5, r1, r2, lsl #2
|
||||
add lr, r1, ip
|
||||
movw r7, #13877
|
||||
movt r7, 179
|
||||
movw r6, #55317
|
||||
movt r6, 32310
|
||||
vldr.32 s11, Loop_54+112
|
||||
vldr.32 s10, Loop_54+116
|
||||
add r1, r0, ip
|
||||
vldr.32 s9, Loop_54+120
|
||||
vmov.f32 s8, #5.0e-1
|
||||
vldr.32 s5, Loop_54+124
|
||||
vldr.32 s6, Loop_54+128
|
||||
vldr.32 s7, Loop_54+132
|
||||
Loop_17:
|
||||
vldmia.32 lr!, {s14}
|
||||
vmov s13, r6
|
||||
vsub.f32 s14, s14, s15
|
||||
vcmpe.f32 s14, s11
|
||||
vmrs APSR_nzcv, FPSCR
|
||||
vmovle s13, r7
|
||||
ble Loop_14
|
||||
vcmpe.f32 s14, s10
|
||||
vmrs APSR_nzcv, FPSCR
|
||||
bmi Loop_53
|
||||
Loop_14:
|
||||
vadd.f32 s12, s12, s13
|
||||
cmp r5, lr
|
||||
vstmia.32 r1!, {s13}
|
||||
bne Loop_17
|
||||
vmov.f32 s15, #1.0e+0
|
||||
cmp r4, #0
|
||||
vdiv.f32 s14, s15, s12
|
||||
beq Loop_20
|
||||
Loop_23:
|
||||
vdup.32 q9, d7[0]
|
||||
mov ip, r0
|
||||
mov r1, #0
|
||||
Loop_19:
|
||||
vld1.32 {d16-d17}, [ip]
|
||||
add r1, r1, #1
|
||||
cmp r4, r1
|
||||
vmul.f32 q8, q8, q9
|
||||
vst1.32 {d16-d17}, [ip]!
|
||||
bgt Loop_19
|
||||
cmp r2, r3
|
||||
bls Loop_1
|
||||
lsl ip, r3, #2
|
||||
Loop_20:
|
||||
add r0, r0, ip
|
||||
Loop_21:
|
||||
vldr.32 s15, [r0]
|
||||
add r3, r3, #1
|
||||
cmp r2, r3
|
||||
vmul.f32 s15, s15, s14
|
||||
vstmia.32 r0!, {s15}
|
||||
bhi Loop_21
|
||||
Loop_1:
|
||||
add sp, sp, #8
|
||||
vldm sp!, {d8-d15}
|
||||
pop {r4, r5, r6, r7, r8, pc}
|
||||
Loop_2:
|
||||
vldr.32 s15, Loop_54+136
|
||||
cmp r2, r3
|
||||
bhi Loop_24
|
||||
Loop_8:
|
||||
cmp r2, r3
|
||||
vldrhi.32 s12, Loop_54+136
|
||||
bhi Loop_26
|
||||
b Loop_1
|
||||
Loop_52:
|
||||
vmov.f32 s15, #1.0e+0
|
||||
cmp r4, #0
|
||||
vdiv.f32 s14, s15, s12
|
||||
bne Loop_23
|
||||
b Loop_1
|
||||
Loop_53:
|
||||
vdiv.f32 s4, s14, s9
|
||||
vmov.f32 s13, #1.0e+0
|
||||
vcvt.s32.f32 s4, s4
|
||||
vcvt.f32.s32 s3, s4
|
||||
vmov r8, s4 @ int
|
||||
vmov.f32 s4, s7
|
||||
vmls.f32 s14, s3, s9
|
||||
vmov.f32 s3, s6
|
||||
add r8, r8, #127
|
||||
lsl r8, r8, #23
|
||||
vmla.f32 s3, s14, s5
|
||||
vmla.f32 s4, s3, s14
|
||||
vmov.f32 s3, s8
|
||||
vmla.f32 s3, s4, s14
|
||||
vmov.f32 s4, s13
|
||||
vmla.f32 s4, s3, s14
|
||||
vmla.f32 s13, s4, s14
|
||||
vmov s14, r8
|
||||
vmul.f32 s13, s13, s14
|
||||
b Loop_14
|
||||
Loop_54:
|
||||
.word 1060205080
|
||||
.word 1060205080
|
||||
.word 1060205080
|
||||
.word 1060205080
|
||||
.word 1069066811
|
||||
.word 1069066811
|
||||
.word 1069066811
|
||||
.word 1069066811
|
||||
.word 1042983595
|
||||
.word 1042983595
|
||||
.word 1042983595
|
||||
.word 1042983595
|
||||
.word 1026206379
|
||||
.word 1026206379
|
||||
.word 1026206379
|
||||
.word 1026206379
|
||||
.word 1007192201
|
||||
.word 1007192201
|
||||
.word 1007192201
|
||||
.word 1007192201
|
||||
.word 1118699520
|
||||
.word 1118699520
|
||||
.word 1118699520
|
||||
.word 1118699520
|
||||
.word -1028784128
|
||||
.word -1028784128
|
||||
.word -1028784128
|
||||
.word -1028784128
|
||||
.word -1028784128
|
||||
.word 1118699520
|
||||
.word 1060205080
|
||||
.word 1007192201
|
||||
.word 1026206379
|
||||
.word 1042983595
|
||||
.word 0
|
||||
#endif
|
||||
#endif
|
|
@ -1,109 +0,0 @@
|
|||
//
|
||||
// MNNExpC8.S
|
||||
// MNN
|
||||
//
|
||||
// Created by MNN on 2019/01/18.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
|
||||
#ifdef __aarch64__
|
||||
|
||||
#include "MNNAsmGlobal.h"
|
||||
.text
|
||||
.align 5
|
||||
|
||||
//void MNNExpC8(float* dest, const float* source, const float* offset, const float* parameters, size_t countC8)
|
||||
asm_function MNNExpC8
|
||||
|
||||
//x0: dest, x1:source, x2: offset, x3:parameters, x4:countC8
|
||||
ldr w5, [x2, #0]
|
||||
ldr w6, [x2, #4]
|
||||
ldr w7, [x2, #8]
|
||||
|
||||
ld1 {v0.4s, v1.4s}, [x3]
|
||||
movi v2.4s, #23
|
||||
movi v3.4s, #87
|
||||
scvtf v3.4s, v3.4s
|
||||
fneg v4.4s, v3.4s
|
||||
dup v30.4s, w5
|
||||
dup v31.4s, w6
|
||||
dup v29.4s, w7
|
||||
|
||||
// Summer
|
||||
movi v28.4s, #0
|
||||
|
||||
Loop:
|
||||
|
||||
ld1 {v16.4s, v17.4s}, [x1], #32
|
||||
fmul v16.4s, v16.4s, v30.4s
|
||||
fmul v17.4s, v17.4s, v30.4s
|
||||
fadd v16.4s, v16.4s, v29.4s
|
||||
fadd v17.4s, v17.4s, v29.4s
|
||||
fmin v16.4s, v16.4s, v3.4s
|
||||
fmin v17.4s, v17.4s, v3.4s
|
||||
fmax v18.4s, v16.4s, v4.4s
|
||||
fmax v19.4s, v17.4s, v4.4s
|
||||
|
||||
fmul v16.4s, v18.4s, v0.s[1]
|
||||
fmul v17.4s, v19.4s, v0.s[1]
|
||||
fcvtzs v16.4s, v16.4s
|
||||
fcvtzs v17.4s, v17.4s
|
||||
scvtf v20.4s, v16.4s
|
||||
scvtf v21.4s, v17.4s
|
||||
|
||||
//v18.4s, v19.4s: t
|
||||
fmls v18.4s, v20.4s, v0.s[0]
|
||||
fmls v19.4s, v21.4s, v0.s[0]
|
||||
|
||||
fmul v18.4s, v18.4s, v0.s[2]
|
||||
fmul v19.4s, v19.4s, v0.s[2]
|
||||
|
||||
.macro MLA_TWO z0 z1 z2 z3
|
||||
dup \z1, \z0
|
||||
fmla \z1, \z2, \z3
|
||||
.endm
|
||||
|
||||
MLA_TWO v1.s[2], v20.4s, v18.4s, v1.s[3]
|
||||
MLA_TWO v1.s[2], v21.4s, v19.4s, v1.s[3]
|
||||
MLA_TWO v1.s[1], v22.4s, v18.4s, v20.4s
|
||||
MLA_TWO v1.s[1], v23.4s, v19.4s, v21.4s
|
||||
MLA_TWO v1.s[0], v20.4s, v18.4s, v22.4s
|
||||
MLA_TWO v1.s[0], v21.4s, v19.4s, v23.4s
|
||||
MLA_TWO v0.s[3], v22.4s, v18.4s, v20.4s
|
||||
MLA_TWO v0.s[3], v23.4s, v19.4s, v21.4s
|
||||
MLA_TWO v0.s[3], v20.4s, v18.4s, v22.4s
|
||||
MLA_TWO v0.s[3], v21.4s, v19.4s, v23.4s
|
||||
|
||||
//v20.4s, v21.4s is expRemain
|
||||
fmul v20.4s, v20.4s, v20.4s
|
||||
fmul v21.4s, v21.4s, v21.4s
|
||||
fmul v20.4s, v20.4s, v20.4s
|
||||
fmul v21.4s, v21.4s, v21.4s
|
||||
|
||||
ushl v16.4s, v16.4s, v2.4s
|
||||
ushl v17.4s, v17.4s, v2.4s
|
||||
add v20.4s, v20.4s, v16.4s
|
||||
add v21.4s, v21.4s, v17.4s
|
||||
|
||||
fadd v20.4s, v20.4s, v31.4s
|
||||
fadd v21.4s, v21.4s, v31.4s
|
||||
|
||||
st1 {v20.4s, v21.4s}, [x0], #32
|
||||
fadd v28.4s, v28.4s, v20.4s
|
||||
fadd v28.4s, v28.4s, v21.4s
|
||||
|
||||
subs x4, x4, #1
|
||||
bne Loop
|
||||
|
||||
// Bias
|
||||
add x7, x2, #12
|
||||
ld1 {v27.s}[0], [x7]
|
||||
faddp v28.4s, v28.4s, v28.4s
|
||||
faddp v28.2s, v28.2s, v28.2s
|
||||
fadd v27.2s, v28.2s, v27.2s
|
||||
st1 {v27.s}[0], [x7]
|
||||
|
||||
ret
|
||||
|
||||
#endif
|
||||
|
|
@ -1,204 +0,0 @@
|
|||
//
|
||||
// MNNSoftmax.S
|
||||
// MNN
|
||||
//
|
||||
// Created by MNN on 2021/07/05.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
|
||||
#ifdef __aarch64__
|
||||
|
||||
#include "MNNAsmGlobal.h"
|
||||
.text
|
||||
.align 5
|
||||
|
||||
//void MNNSoftmax(float* dest, const float* source, size_t countC8)
|
||||
asm_function MNNSoftmax
|
||||
stp x21, x22, [sp, #-32]!
|
||||
stp x19, x20, [sp, #16]
|
||||
sxtw x8, w2
|
||||
lsr w9, w2, #2
|
||||
and x8, x8, #-4
|
||||
cbz w9, Loop_5
|
||||
ldr q0, [x1]
|
||||
cmp w9, #1
|
||||
b.eq Loop_4
|
||||
add x10, x1, #16
|
||||
sub x11, x9, #1
|
||||
Loop_3:
|
||||
ldr q1, [x10], #16
|
||||
subs x11, x11, #1
|
||||
fmax v0.4s, v0.4s, v1.4s
|
||||
b.ne Loop_3
|
||||
Loop_4:
|
||||
mov s1, v0.s[1]
|
||||
fcmp s0, s1
|
||||
mov s2, v0.s[2]
|
||||
fcsel s1, s0, s1, gt
|
||||
fcmp s1, s2
|
||||
fcsel s1, s1, s2, gt
|
||||
mov s0, v0.s[3]
|
||||
fcmp s1, s0
|
||||
fcsel s0, s1, s0, gt
|
||||
cmp w8, w2
|
||||
b.lo Loop_6
|
||||
b Loop_8
|
||||
Loop_5:
|
||||
fmov s0, wzr
|
||||
cmp w8, w2
|
||||
b.hs Loop_8
|
||||
Loop_6:
|
||||
add x10, x1, w8, sxtw #2
|
||||
mov w11, w8
|
||||
Loop_7:
|
||||
ldr s1, [x10], #4
|
||||
add w11, w11, #1
|
||||
fcmp s0, s1
|
||||
fcsel s0, s0, s1, gt
|
||||
cmp w11, w2
|
||||
b.lo Loop_7
|
||||
Loop_8:
|
||||
cbz w9, Loop_12
|
||||
mov w10, #-1028784128
|
||||
mov w12, #43579
|
||||
dup v4.4s, w10
|
||||
mov w10, #29208
|
||||
mov w11, #1118699520
|
||||
movk w12, #16312, lsl #16
|
||||
movk w10, #48945, lsl #16
|
||||
dup v5.4s, w11
|
||||
mov w11, #34953
|
||||
dup v6.4s, w12
|
||||
mov w12, #43691
|
||||
dup v7.4s, w10
|
||||
mov w10, #43691
|
||||
movk w11, #15368, lsl #16
|
||||
movk w12, #15658, lsl #16
|
||||
movk w10, #15914, lsl #16
|
||||
dup v2.4s, v0.s[0]
|
||||
movi v1.16b, #0
|
||||
fmov v3.4s, #1.0
|
||||
dup v16.4s, w11
|
||||
dup v17.4s, w12
|
||||
dup v18.4s, w10
|
||||
movi v19.4s, #63, lsl #24
|
||||
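// Annotation (constants decoded from the immediates above): v4/v5 are the clamp
// bounds -87/+87, v6 = 1/ln2, v7 = -ln2, v16/v17/v18/v19 are the Taylor
// coefficients 1/120, 1/24, 1/6 and 0.5 used by the exp polynomial in Loop_10,
// v2 is the broadcast row maximum, v1 the running sum and v3 = 1.0f.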
mov x10, x9
|
||||
mov x11, x1
|
||||
mov x12, x0
|
||||
Loop_10:
|
||||
ldr q20, [x11], #16
|
||||
subs x10, x10, #1
|
||||
fsub v20.4s, v20.4s, v2.4s
|
||||
fmax v20.4s, v20.4s, v4.4s
|
||||
fmin v20.4s, v20.4s, v5.4s
|
||||
fmul v21.4s, v20.4s, v6.4s
|
||||
fcvtzs v21.4s, v21.4s
|
||||
scvtf v22.4s, v21.4s
|
||||
fmla v20.4s, v22.4s, v7.4s
|
||||
fmul v22.4s, v20.4s, v16.4s
|
||||
fadd v22.4s, v22.4s, v17.4s
|
||||
fmul v22.4s, v20.4s, v22.4s
|
||||
fadd v22.4s, v22.4s, v18.4s
|
||||
fmul v22.4s, v20.4s, v22.4s
|
||||
fadd v22.4s, v22.4s, v19.4s
|
||||
fmul v22.4s, v20.4s, v22.4s
|
||||
fadd v22.4s, v22.4s, v3.4s
|
||||
shl v21.4s, v21.4s, #23
|
||||
fmul v20.4s, v20.4s, v22.4s
|
||||
add v21.4s, v21.4s, v3.4s
|
||||
fadd v20.4s, v20.4s, v3.4s
|
||||
fmul v20.4s, v20.4s, v21.4s
|
||||
fadd v1.4s, v1.4s, v20.4s
|
||||
str q20, [x12], #16
|
||||
b.ne Loop_10
|
||||
dup v2.4s, v1.s[1]
|
||||
dup v3.4s, v1.s[2]
|
||||
fadd v2.4s, v1.4s, v2.4s
|
||||
fadd v2.4s, v3.4s, v2.4s
|
||||
dup v1.4s, v1.s[3]
|
||||
fadd v1.4s, v1.4s, v2.4s
|
||||
b Loop_13
|
||||
Loop_12:
|
||||
fmov s1, wzr
|
||||
Loop_13:
|
||||
cmp w8, w2
|
||||
fmov s2, #1.0
|
||||
b.hs Loop_16
|
||||
lsl x21, x8, #2
|
||||
mov w12, #29208
|
||||
mov w14, #34953
|
||||
mov w15, #43691
|
||||
mov w19, #43691
|
||||
mov w10, #-1028784128
|
||||
mov w11, #1118699520
|
||||
movk w12, #16177, lsl #16
|
||||
mov w13, #1065353216
|
||||
movk w14, #15368, lsl #16
|
||||
movk w15, #15658, lsl #16
|
||||
movk w19, #15914, lsl #16
|
||||
add x20, x1, x21
|
||||
add x21, x0, x21
|
||||
fmov s3, #0.5
|
||||
mov w1, w8
|
||||
Loop_15:
|
||||
ldr s4, [x20], #4
|
||||
fmov s5, w10
|
||||
fmov s6, w11
|
||||
fmov s7, w12
|
||||
fsub s4, s4, s0
|
||||
fmaxnm s4, s4, s5
|
||||
fminnm s4, s4, s6
|
||||
fdiv s6, s4, s7
|
||||
fcvtzs w3, s6
|
||||
scvtf s6, w3
|
||||
fmul s6, s6, s7
|
||||
fmov s5, w14
|
||||
fsub s4, s4, s6
|
||||
fmov s7, w15
|
||||
fmul s5, s4, s5
|
||||
fadd s5, s5, s7
|
||||
fmov s6, w19
|
||||
fmul s5, s4, s5
|
||||
fadd s5, s5, s6
|
||||
fmul s5, s4, s5
|
||||
fadd s5, s5, s3
|
||||
fmul s5, s4, s5
|
||||
fadd s5, s5, s2
|
||||
add w3, w13, w3, lsl #23
|
||||
fmul s4, s4, s5
|
||||
fmov s7, w3
|
||||
fadd s4, s4, s2
|
||||
add w1, w1, #1
|
||||
fmul s4, s4, s7
|
||||
cmp w1, w2
|
||||
str s4, [x21], #4
|
||||
fadd s1, s1, s4
|
||||
b.lo Loop_15
|
||||
Loop_16:
|
||||
fdiv s0, s2, s1
|
||||
cbz w9, Loop_19
|
||||
dup v1.4s, v0.s[0]
|
||||
mov x10, x0
|
||||
Loop_18:
|
||||
ldr q2, [x10]
|
||||
subs x9, x9, #1
|
||||
fmul v2.4s, v1.4s, v2.4s
|
||||
str q2, [x10], #16
|
||||
b.ne Loop_18
|
||||
Loop_19:
|
||||
cmp w8, w2
|
||||
b.hs Loop_22
|
||||
add x9, x0, w8, sxtw #2
|
||||
Loop_21:
|
||||
ldr s1, [x9]
|
||||
add w8, w8, #1
|
||||
cmp w8, w2
|
||||
fmul s1, s0, s1
|
||||
str s1, [x9], #4
|
||||
b.lo Loop_21
|
||||
Loop_22:
|
||||
ldp x19, x20, [sp, #16]
|
||||
ldp x21, x22, [sp], #32
|
||||
ret
|
||||
#endif
|
||||
|
|
@ -379,10 +379,10 @@ LoopSzEnd_TILE_1:
|
|||
.inst 0x05b02324 // dup z4.q, z25.q[2]
|
||||
.inst 0x05f02325 // dup z5.q, z25.q[3]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -404,10 +404,10 @@ LoopSzEnd_TILE_1:
|
|||
.inst 0x05f02325 // dup z5.q, z25.q[3]
|
||||
.inst 0x05702346 // dup z6.q, z26.q[1]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -431,11 +431,11 @@ LoopSzEnd_TILE_1:
|
|||
.inst 0x05702346 // dup z6.q, z26.q[1]
|
||||
.inst 0x05b02347 // dup z7.q, z26.q[2]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -461,11 +461,11 @@ TILE1_STORE48:
|
|||
.inst 0x05b02347 // dup z7.q, z26.q[2]
|
||||
.inst 0x05f02348 // dup z8.q, z26.q[3]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -492,12 +492,12 @@ TILE1_STORE52:
|
|||
.inst 0x05b02347 // dup z7.q, z26.q[2]
|
||||
.inst 0x05f02348 // dup z8.q, z26.q[3]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -526,12 +526,12 @@ TILE1_STORE56:
|
|||
.inst 0x05f02348 // dup z8.q, z26.q[3]
|
||||
.inst 0x05702369 // dup z9.q, z27.q[1]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -562,13 +562,13 @@ TILE1_STORE60:
|
|||
.inst 0x05702369 // dup z9.q, z27.q[1]
|
||||
.inst 0x05b0236a // dup z10.q, z27.q[2]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -601,13 +601,13 @@ TILE1_STORE64:
|
|||
.inst 0x05b0236a // dup z10.q, z27.q[2]
|
||||
.inst 0x05f0236b // dup z11.q, z27.q[3]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -640,15 +640,14 @@ TILE1_STORE68:
|
|||
.inst 0x05702369 // dup z9.q, z27.q[1]
|
||||
.inst 0x05b0236a // dup z10.q, z27.q[2]
|
||||
.inst 0x05f0236b // dup z11.q, z27.q[3]
|
||||
.inst 0x0570238c // dup z12.q, z28.q[1]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -667,7 +666,7 @@ TILE1_STORE68:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
b TILE1_Dz_End
|
||||
|
@ -685,14 +684,15 @@ TILE1_STORE72:
|
|||
.inst 0x05702369 // dup z9.q, z27.q[1]
|
||||
.inst 0x05b0236a // dup z10.q, z27.q[2]
|
||||
.inst 0x05f0236b // dup z11.q, z27.q[3]
|
||||
.inst 0x0570238c // dup z12.q, z28.q[1]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -711,7 +711,7 @@ TILE1_STORE72:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
|
@ -731,14 +731,15 @@ TILE1_STORE76:
|
|||
.inst 0x05b0236a // dup z10.q, z27.q[2]
|
||||
.inst 0x05f0236b // dup z11.q, z27.q[3]
|
||||
.inst 0x0570238c // dup z12.q, z28.q[1]
|
||||
.inst 0x05b0238d // dup z13.q, z28.q[2]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -757,8 +758,8 @@ TILE1_STORE76:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
|
@ -779,14 +780,16 @@ TILE1_STORE80:
|
|||
.inst 0x05b0236a // dup z10.q, z27.q[2]
|
||||
.inst 0x05f0236b // dup z11.q, z27.q[3]
|
||||
.inst 0x0570238c // dup z12.q, z28.q[1]
|
||||
.inst 0x05b0238d // dup z13.q, z28.q[2]
|
||||
.inst 0x05f0238e // dup z14.q, z28.q[3]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -805,12 +808,13 @@ TILE1_STORE80:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
|
||||
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
|
||||
b TILE1_Dz_End
|
||||
|
||||
TILE1_STORE84:
|
||||
|
@ -827,14 +831,17 @@ TILE1_STORE84:
|
|||
.inst 0x05b0236a // dup z10.q, z27.q[2]
|
||||
.inst 0x05f0236b // dup z11.q, z27.q[3]
|
||||
.inst 0x0570238c // dup z12.q, z28.q[1]
|
||||
.inst 0x05b0238d // dup z13.q, z28.q[2]
|
||||
.inst 0x05f0238e // dup z14.q, z28.q[3]
|
||||
.inst 0x057023af // dup z15.q, z29.q[1]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -853,12 +860,15 @@ TILE1_STORE84:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
add x8, x23, x4, lsl #3 // +20
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
|
||||
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
|
||||
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
|
||||
b TILE1_Dz_End
|
||||
|
||||
TILE1_STORE88:
|
||||
|
@ -876,14 +886,16 @@ TILE1_STORE88:
|
|||
.inst 0x05f0236b // dup z11.q, z27.q[3]
|
||||
.inst 0x0570238c // dup z12.q, z28.q[1]
|
||||
.inst 0x05b0238d // dup z13.q, z28.q[2]
|
||||
.inst 0x05f0238e // dup z14.q, z28.q[3]
|
||||
.inst 0x057023af // dup z15.q, z29.q[1]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -902,12 +914,16 @@ TILE1_STORE88:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
add x8, x23, x4, lsl #3 // +20
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
|
||||
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
|
||||
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
|
||||
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
|
||||
b TILE1_Dz_End
|
||||
|
||||
TILE1_STORE92:
|
||||
|
@ -926,14 +942,17 @@ TILE1_STORE92:
|
|||
.inst 0x0570238c // dup z12.q, z28.q[1]
|
||||
.inst 0x05b0238d // dup z13.q, z28.q[2]
|
||||
.inst 0x05f0238e // dup z14.q, z28.q[3]
|
||||
.inst 0x057023af // dup z15.q, z29.q[1]
|
||||
.inst 0x05b023b0 // dup z16.q, z29.q[2]
|
||||
.inst 0x05f023b1 // dup z17.q, z29.q[3]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -952,15 +971,18 @@ TILE1_STORE92:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x8, x23, x4, lsl #3
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
add x8, x23, x4, lsl #3 // +20
|
||||
add x11, x5, x4, lsl #3 // +22
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
|
||||
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
|
||||
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
|
||||
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
|
||||
.inst 0xe400f570 // st1b {z16.b}, p5, [x11]
|
||||
b TILE1_Dz_End
|
||||
|
||||
TILE1_STORE96:
|
||||
|
@ -980,14 +1002,16 @@ TILE1_STORE96:
|
|||
.inst 0x05b0238d // dup z13.q, z28.q[2]
|
||||
.inst 0x05f0238e // dup z14.q, z28.q[3]
|
||||
.inst 0x057023af // dup z15.q, z29.q[1]
|
||||
.inst 0x05b023b0 // dup z16.q, z29.q[2]
|
||||
.inst 0x05f023b1 // dup z17.q, z29.q[3]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -1006,9 +1030,10 @@ TILE1_STORE96:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x8, x23, x4, lsl #3
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
add x8, x23, x4, lsl #3 // +20
|
||||
add x11, x5, x4, lsl #3 // +22
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
|
@ -1016,6 +1041,8 @@ TILE1_STORE96:
|
|||
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
|
||||
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
|
||||
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
|
||||
.inst 0xe400f570 // st1b {z16.b}, p5, [x11]
|
||||
.inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4]
|
||||
b TILE1_Dz_End
|
||||
|
||||
TILE1_STORE100:
|
||||
|
@ -1036,14 +1063,15 @@ TILE1_STORE100:
|
|||
.inst 0x05f0238e // dup z14.q, z28.q[3]
|
||||
.inst 0x057023af // dup z15.q, z29.q[1]
|
||||
.inst 0x05b023b0 // dup z16.q, z29.q[2]
|
||||
.inst 0x05f023b1 // dup z17.q, z29.q[3]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -1062,10 +1090,11 @@ TILE1_STORE100:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x8, x23, x4, lsl #3
|
||||
add x11, x5, x4, lsl #3
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
add x8, x23, x4, lsl #3 // +20
|
||||
add x11, x5, x4, lsl #3 // +22
|
||||
add x13, x9, x4, lsl #3 // +24
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
|
@ -1074,6 +1103,8 @@ TILE1_STORE100:
|
|||
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
|
||||
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
|
||||
.inst 0xe400f570 // st1b {z16.b}, p5, [x11]
|
||||
.inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4]
|
||||
.inst 0xe400f5be // st1b {z30.b}, p5, [x13]
|
||||
b TILE1_Dz_End
|
||||
|
||||
TILE1_STORE104:
|
||||
|
@ -1095,75 +1126,15 @@ TILE1_STORE104:
|
|||
.inst 0x057023af // dup z15.q, z29.q[1]
|
||||
.inst 0x05b023b0 // dup z16.q, z29.q[2]
|
||||
.inst 0x05f023b1 // dup z17.q, z29.q[3]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
.inst 0xe400f521 // st1b {z1.b}, p5, [x9]
|
||||
.inst 0xe4045522 // st1b {z2.b}, p5, [x9, x4]
|
||||
.inst 0xe400f5f9 // st1b {z25.b}, p5, [x15]
|
||||
.inst 0xe40455e3 // st1b {z3.b}, p5, [x15, x4]
|
||||
.inst 0xe400f504 // st1b {z4.b}, p5, [x8]
|
||||
.inst 0xe4045505 // st1b {z5.b}, p5, [x8, x4]
|
||||
.inst 0xe400f57a // st1b {z26.b}, p5, [x11]
|
||||
.inst 0xe4045566 // st1b {z6.b}, p5, [x11, x4]
|
||||
.inst 0xe400f5a7 // st1b {z7.b}, p5, [x13]
|
||||
.inst 0xe40455a8 // st1b {z8.b}, p5, [x13, x4]
|
||||
.inst 0xe400f6fb // st1b {z27.b}, p5, [x23]
|
||||
.inst 0xe40456e9 // st1b {z9.b}, p5, [x23, x4]
|
||||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x8, x23, x4, lsl #3
|
||||
add x11, x5, x4, lsl #3
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
|
||||
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
|
||||
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
|
||||
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
|
||||
.inst 0xe400f570 // st1b {z16.b}, p5, [x11]
|
||||
.inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4]
|
||||
b TILE1_Dz_End
|
||||
|
||||
TILE1_STORE108:
|
||||
.inst 0x05702300 // dup z0.q, z24.q[1]
|
||||
.inst 0x05b02301 // dup z1.q, z24.q[2]
|
||||
.inst 0x05f02302 // dup z2.q, z24.q[3]
|
||||
.inst 0x05702323 // dup z3.q, z25.q[1]
|
||||
.inst 0x05b02324 // dup z4.q, z25.q[2]
|
||||
.inst 0x05f02325 // dup z5.q, z25.q[3]
|
||||
.inst 0x05702346 // dup z6.q, z26.q[1]
|
||||
.inst 0x05b02347 // dup z7.q, z26.q[2]
|
||||
.inst 0x05f02348 // dup z8.q, z26.q[3]
|
||||
.inst 0x05702369 // dup z9.q, z27.q[1]
|
||||
.inst 0x05b0236a // dup z10.q, z27.q[2]
|
||||
.inst 0x05f0236b // dup z11.q, z27.q[3]
|
||||
.inst 0x0570238c // dup z12.q, z28.q[1]
|
||||
.inst 0x05b0238d // dup z13.q, z28.q[2]
|
||||
.inst 0x05f0238e // dup z14.q, z28.q[3]
|
||||
.inst 0x057023af // dup z15.q, z29.q[1]
|
||||
.inst 0x05b023b0 // dup z16.q, z29.q[2]
|
||||
.inst 0x05f023b1 // dup z17.q, z29.q[3]
|
||||
.inst 0x057023d2 // dup z18.q, z30.q[1]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -1182,11 +1153,11 @@ TILE1_STORE108:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x8, x23, x4, lsl #3
|
||||
add x11, x5, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
add x8, x23, x4, lsl #3 // +20
|
||||
add x11, x5, x4, lsl #3 // +22
|
||||
add x13, x9, x4, lsl #3 // +24
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
|
@ -1200,7 +1171,8 @@ TILE1_STORE108:
|
|||
.inst 0xe40455b2 // st1b {z18.b}, p5, [x13, x4]
|
||||
b TILE1_Dz_End
|
||||
|
||||
TILE1_STORE112:
|
||||
/* oc=108 */
|
||||
TILE1_STORE108:
|
||||
.inst 0x05702300 // dup z0.q, z24.q[1]
|
||||
.inst 0x05b02301 // dup z1.q, z24.q[2]
|
||||
.inst 0x05f02302 // dup z2.q, z24.q[3]
|
||||
|
@ -1222,13 +1194,13 @@ TILE1_STORE108:
|
|||
.inst 0x057023d2 // dup z18.q, z30.q[1]
|
||||
.inst 0x05b023d3 // dup z19.q, z30.q[2]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -1247,12 +1219,12 @@ TILE1_STORE108:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x8, x23, x4, lsl #3
|
||||
add x11, x5, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
add x8, x23, x4, lsl #3 // +20
|
||||
add x11, x5, x4, lsl #3 // +22
|
||||
add x13, x9, x4, lsl #3 // +24
|
||||
add x23, x15, x4, lsl #3 // +26
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
|
@ -1267,6 +1239,77 @@ TILE1_STORE108:
|
|||
.inst 0xe400f6f3 // st1b {z19.b}, p5, [x23]
|
||||
b TILE1_Dz_End
|
||||
|
||||
/* oc=112 */
|
||||
TILE1_STORE112:
|
||||
.inst 0x05702300 // dup z0.q, z24.q[1]
|
||||
.inst 0x05b02301 // dup z1.q, z24.q[2]
|
||||
.inst 0x05f02302 // dup z2.q, z24.q[3]
|
||||
.inst 0x05702323 // dup z3.q, z25.q[1]
|
||||
.inst 0x05b02324 // dup z4.q, z25.q[2]
|
||||
.inst 0x05f02325 // dup z5.q, z25.q[3]
|
||||
.inst 0x05702346 // dup z6.q, z26.q[1]
|
||||
.inst 0x05b02347 // dup z7.q, z26.q[2]
|
||||
.inst 0x05f02348 // dup z8.q, z26.q[3]
|
||||
.inst 0x05702369 // dup z9.q, z27.q[1]
|
||||
.inst 0x05b0236a // dup z10.q, z27.q[2]
|
||||
.inst 0x05f0236b // dup z11.q, z27.q[3]
|
||||
.inst 0x0570238c // dup z12.q, z28.q[1]
|
||||
.inst 0x05b0238d // dup z13.q, z28.q[2]
|
||||
.inst 0x05f0238e // dup z14.q, z28.q[3]
|
||||
.inst 0x057023af // dup z15.q, z29.q[1]
|
||||
.inst 0x05b023b0 // dup z16.q, z29.q[2]
|
||||
.inst 0x05f023b1 // dup z17.q, z29.q[3]
|
||||
.inst 0x057023d2 // dup z18.q, z30.q[1]
|
||||
.inst 0x05b023d3 // dup z19.q, z30.q[2]
|
||||
.inst 0x05f023d4 // dup z20.q, z30.q[3]
|
||||
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
.inst 0xe400f521 // st1b {z1.b}, p5, [x9]
|
||||
.inst 0xe4045522 // st1b {z2.b}, p5, [x9, x4]
|
||||
.inst 0xe400f5f9 // st1b {z25.b}, p5, [x15]
|
||||
.inst 0xe40455e3 // st1b {z3.b}, p5, [x15, x4]
|
||||
.inst 0xe400f504 // st1b {z4.b}, p5, [x8]
|
||||
.inst 0xe4045505 // st1b {z5.b}, p5, [x8, x4]
|
||||
.inst 0xe400f57a // st1b {z26.b}, p5, [x11]
|
||||
.inst 0xe4045566 // st1b {z6.b}, p5, [x11, x4]
|
||||
.inst 0xe400f5a7 // st1b {z7.b}, p5, [x13]
|
||||
.inst 0xe40455a8 // st1b {z8.b}, p5, [x13, x4]
|
||||
.inst 0xe400f6fb // st1b {z27.b}, p5, [x23]
|
||||
.inst 0xe40456e9 // st1b {z9.b}, p5, [x23, x4]
|
||||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
add x8, x23, x4, lsl #3 // +20
|
||||
add x11, x5, x4, lsl #3 // +22
|
||||
add x13, x9, x4, lsl #3 // +24
|
||||
add x23, x15, x4, lsl #3 // +26
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
|
||||
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
|
||||
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
|
||||
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
|
||||
.inst 0xe400f570 // st1b {z16.b}, p5, [x11]
|
||||
.inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4]
|
||||
.inst 0xe400f5be // st1b {z30.b}, p5, [x13]
|
||||
.inst 0xe40455b2 // st1b {z18.b}, p5, [x13, x4]
|
||||
.inst 0xe400f6f3 // st1b {z19.b}, p5, [x23]
|
||||
.inst 0xe40456f4 // st1b {z20.b}, p5, [x23, x4]
|
||||
b TILE1_Dz_End
|
||||
|
||||
/* oc=116 */
|
||||
TILE1_STORE116:
|
||||
.inst 0x05702300 // dup z0.q, z24.q[1]
|
||||
.inst 0x05b02301 // dup z1.q, z24.q[2]
|
||||
|
@ -1290,13 +1333,13 @@ TILE1_STORE108:
|
|||
.inst 0x05b023d3 // dup z19.q, z30.q[2]
|
||||
.inst 0x05f023d4 // dup z20.q, z30.q[3]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -1315,13 +1358,13 @@ TILE1_STORE108:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x8, x23, x4, lsl #3
|
||||
add x11, x5, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
add x8, x23, x4, lsl #3 // +20
|
||||
add x11, x5, x4, lsl #3 // +22
|
||||
add x13, x9, x4, lsl #3 // +24
|
||||
add x23, x15, x4, lsl #3 // +26
|
||||
add x5, x8, x4, lsl #3 // +28
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
|
@ -1338,6 +1381,7 @@ TILE1_STORE108:
|
|||
.inst 0xe400f4bf // st1b {z31.b}, p5, [x5]
|
||||
b TILE1_Dz_End
|
||||
|
||||
/* oc=120 */
|
||||
TILE1_STORE120:
|
||||
.inst 0x05702300 // dup z0.q, z24.q[1]
|
||||
.inst 0x05b02301 // dup z1.q, z24.q[2]
|
||||
|
@ -1362,13 +1406,13 @@ TILE1_STORE108:
|
|||
.inst 0x05f023d4 // dup z20.q, z30.q[3]
|
||||
.inst 0x057023f5 // dup z21.q, z31.q[1]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -1387,13 +1431,13 @@ TILE1_STORE108:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x8, x23, x4, lsl #3
|
||||
add x11, x5, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
add x8, x23, x4, lsl #3 // +20
|
||||
add x11, x5, x4, lsl #3 // +22
|
||||
add x13, x9, x4, lsl #3 // +24
|
||||
add x23, x15, x4, lsl #3 // +26
|
||||
add x5, x8, x4, lsl #3 // +28
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
|
@ -1411,6 +1455,7 @@ TILE1_STORE108:
|
|||
.inst 0xe40454b5 // st1b {z21.b}, p5, [x5, x4]
|
||||
b TILE1_Dz_End
|
||||
|
||||
/* oc=124 */
|
||||
TILE1_STORE124:
|
||||
.inst 0x05702300 // dup z0.q, z24.q[1]
|
||||
.inst 0x05b02301 // dup z1.q, z24.q[2]
|
||||
|
@ -1487,6 +1532,7 @@ TILE1_STORE108:
|
|||
.inst 0xe400f596 // st1b {z22.b}, p5, [x12]
|
||||
b TILE1_Dz_End
|
||||
|
||||
/* oc=128 */
|
||||
TILE1_STORE128:
|
||||
.inst 0x05702300 // dup z0.q, z24.q[1]
|
||||
.inst 0x05b02301 // dup z1.q, z24.q[2]
|
||||
|
|
|
@ -372,10 +372,10 @@ LoopSzEnd_TILE_1:
|
|||
.inst 0x05b02324 // dup z4.q, z25.q[2]
|
||||
.inst 0x05f02325 // dup z5.q, z25.q[3]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -397,10 +397,10 @@ LoopSzEnd_TILE_1:
|
|||
.inst 0x05f02325 // dup z5.q, z25.q[3]
|
||||
.inst 0x05702346 // dup z6.q, z26.q[1]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -424,11 +424,11 @@ LoopSzEnd_TILE_1:
|
|||
.inst 0x05702346 // dup z6.q, z26.q[1]
|
||||
.inst 0x05b02347 // dup z7.q, z26.q[2]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -454,11 +454,11 @@ TILE1_STORE48:
|
|||
.inst 0x05b02347 // dup z7.q, z26.q[2]
|
||||
.inst 0x05f02348 // dup z8.q, z26.q[3]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -485,12 +485,12 @@ TILE1_STORE52:
|
|||
.inst 0x05b02347 // dup z7.q, z26.q[2]
|
||||
.inst 0x05f02348 // dup z8.q, z26.q[3]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -519,12 +519,12 @@ TILE1_STORE56:
|
|||
.inst 0x05f02348 // dup z8.q, z26.q[3]
|
||||
.inst 0x05702369 // dup z9.q, z27.q[1]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -555,13 +555,13 @@ TILE1_STORE60:
|
|||
.inst 0x05702369 // dup z9.q, z27.q[1]
|
||||
.inst 0x05b0236a // dup z10.q, z27.q[2]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -594,13 +594,13 @@ TILE1_STORE64:
|
|||
.inst 0x05b0236a // dup z10.q, z27.q[2]
|
||||
.inst 0x05f0236b // dup z11.q, z27.q[3]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -633,15 +633,14 @@ TILE1_STORE68:
|
|||
.inst 0x05702369 // dup z9.q, z27.q[1]
|
||||
.inst 0x05b0236a // dup z10.q, z27.q[2]
|
||||
.inst 0x05f0236b // dup z11.q, z27.q[3]
|
||||
.inst 0x0570238c // dup z12.q, z28.q[1]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -660,7 +659,7 @@ TILE1_STORE68:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
b TILE1_Dz_End
|
||||
|
@ -678,14 +677,15 @@ TILE1_STORE72:
|
|||
.inst 0x05702369 // dup z9.q, z27.q[1]
|
||||
.inst 0x05b0236a // dup z10.q, z27.q[2]
|
||||
.inst 0x05f0236b // dup z11.q, z27.q[3]
|
||||
.inst 0x0570238c // dup z12.q, z28.q[1]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -704,7 +704,7 @@ TILE1_STORE72:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
|
@ -724,14 +724,15 @@ TILE1_STORE76:
|
|||
.inst 0x05b0236a // dup z10.q, z27.q[2]
|
||||
.inst 0x05f0236b // dup z11.q, z27.q[3]
|
||||
.inst 0x0570238c // dup z12.q, z28.q[1]
|
||||
.inst 0x05b0238d // dup z13.q, z28.q[2]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -750,8 +751,8 @@ TILE1_STORE76:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
|
@ -772,14 +773,16 @@ TILE1_STORE80:
|
|||
.inst 0x05b0236a // dup z10.q, z27.q[2]
|
||||
.inst 0x05f0236b // dup z11.q, z27.q[3]
|
||||
.inst 0x0570238c // dup z12.q, z28.q[1]
|
||||
.inst 0x05b0238d // dup z13.q, z28.q[2]
|
||||
.inst 0x05f0238e // dup z14.q, z28.q[3]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -798,12 +801,13 @@ TILE1_STORE80:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
|
||||
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
|
||||
b TILE1_Dz_End
|
||||
|
||||
TILE1_STORE84:
|
||||
|
@ -820,14 +824,16 @@ TILE1_STORE84:
|
|||
.inst 0x05b0236a // dup z10.q, z27.q[2]
|
||||
.inst 0x05f0236b // dup z11.q, z27.q[3]
|
||||
.inst 0x0570238c // dup z12.q, z28.q[1]
|
||||
.inst 0x05b0238d // dup z13.q, z28.q[2]
|
||||
.inst 0x05f0238e // dup z14.q, z28.q[3]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -846,12 +852,15 @@ TILE1_STORE84:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
add x8, x23, x4, lsl #3 // +20
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
|
||||
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
|
||||
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
|
||||
b TILE1_Dz_End
|
||||
|
||||
TILE1_STORE88:
|
||||
|
@ -869,14 +878,16 @@ TILE1_STORE88:
|
|||
.inst 0x05f0236b // dup z11.q, z27.q[3]
|
||||
.inst 0x0570238c // dup z12.q, z28.q[1]
|
||||
.inst 0x05b0238d // dup z13.q, z28.q[2]
|
||||
.inst 0x05f0238e // dup z14.q, z28.q[3]
|
||||
.inst 0x057023af // dup z15.q, z29.q[1]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -895,12 +906,16 @@ TILE1_STORE88:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
add x8, x23, x4, lsl #3 // +20
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
|
||||
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
|
||||
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
|
||||
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
|
||||
b TILE1_Dz_End
|
||||
|
||||
TILE1_STORE92:
|
||||
|
@ -919,14 +934,17 @@ TILE1_STORE92:
|
|||
.inst 0x0570238c // dup z12.q, z28.q[1]
|
||||
.inst 0x05b0238d // dup z13.q, z28.q[2]
|
||||
.inst 0x05f0238e // dup z14.q, z28.q[3]
|
||||
.inst 0x057023af // dup z15.q, z29.q[1]
|
||||
.inst 0x05b023b0 // dup z16.q, z29.q[2]
|
||||
.inst 0x05f023b1 // dup z17.q, z29.q[3]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -945,15 +963,18 @@ TILE1_STORE92:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x8, x23, x4, lsl #3
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
add x8, x23, x4, lsl #3 // +20
|
||||
add x11, x5, x4, lsl #3 // +22
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
|
||||
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
|
||||
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
|
||||
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
|
||||
.inst 0xe400f570 // st1b {z16.b}, p5, [x11]
|
||||
b TILE1_Dz_End
|
||||
|
||||
TILE1_STORE96:
|
||||
|
@ -973,14 +994,16 @@ TILE1_STORE96:
|
|||
.inst 0x05b0238d // dup z13.q, z28.q[2]
|
||||
.inst 0x05f0238e // dup z14.q, z28.q[3]
|
||||
.inst 0x057023af // dup z15.q, z29.q[1]
|
||||
.inst 0x05b023b0 // dup z16.q, z29.q[2]
|
||||
.inst 0x05f023b1 // dup z17.q, z29.q[3]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -999,9 +1022,10 @@ TILE1_STORE96:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x8, x23, x4, lsl #3
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
add x8, x23, x4, lsl #3 // +20
|
||||
add x11, x5, x4, lsl #3 // +22
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
|
@ -1009,6 +1033,8 @@ TILE1_STORE96:
|
|||
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
|
||||
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
|
||||
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
|
||||
.inst 0xe400f570 // st1b {z16.b}, p5, [x11]
|
||||
.inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4]
|
||||
b TILE1_Dz_End
|
||||
|
||||
TILE1_STORE100:
|
||||
|
@ -1029,14 +1055,15 @@ TILE1_STORE100:
|
|||
.inst 0x05f0238e // dup z14.q, z28.q[3]
|
||||
.inst 0x057023af // dup z15.q, z29.q[1]
|
||||
.inst 0x05b023b0 // dup z16.q, z29.q[2]
|
||||
.inst 0x05f023b1 // dup z17.q, z29.q[3]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -1055,10 +1082,11 @@ TILE1_STORE100:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x8, x23, x4, lsl #3
|
||||
add x11, x5, x4, lsl #3
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
add x8, x23, x4, lsl #3 // +20
|
||||
add x11, x5, x4, lsl #3 // +22
|
||||
add x13, x9, x4, lsl #3 // +24
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
|
@ -1067,6 +1095,8 @@ TILE1_STORE100:
|
|||
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
|
||||
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
|
||||
.inst 0xe400f570 // st1b {z16.b}, p5, [x11]
|
||||
.inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4]
|
||||
.inst 0xe400f5be // st1b {z30.b}, p5, [x13]
|
||||
b TILE1_Dz_End
|
||||
|
||||
TILE1_STORE104:
|
||||
|
@ -1088,75 +1118,15 @@ TILE1_STORE104:
|
|||
.inst 0x057023af // dup z15.q, z29.q[1]
|
||||
.inst 0x05b023b0 // dup z16.q, z29.q[2]
|
||||
.inst 0x05f023b1 // dup z17.q, z29.q[3]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
.inst 0xe400f521 // st1b {z1.b}, p5, [x9]
|
||||
.inst 0xe4045522 // st1b {z2.b}, p5, [x9, x4]
|
||||
.inst 0xe400f5f9 // st1b {z25.b}, p5, [x15]
|
||||
.inst 0xe40455e3 // st1b {z3.b}, p5, [x15, x4]
|
||||
.inst 0xe400f504 // st1b {z4.b}, p5, [x8]
|
||||
.inst 0xe4045505 // st1b {z5.b}, p5, [x8, x4]
|
||||
.inst 0xe400f57a // st1b {z26.b}, p5, [x11]
|
||||
.inst 0xe4045566 // st1b {z6.b}, p5, [x11, x4]
|
||||
.inst 0xe400f5a7 // st1b {z7.b}, p5, [x13]
|
||||
.inst 0xe40455a8 // st1b {z8.b}, p5, [x13, x4]
|
||||
.inst 0xe400f6fb // st1b {z27.b}, p5, [x23]
|
||||
.inst 0xe40456e9 // st1b {z9.b}, p5, [x23, x4]
|
||||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x8, x23, x4, lsl #3
|
||||
add x11, x5, x4, lsl #3
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
|
||||
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
|
||||
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
|
||||
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
|
||||
.inst 0xe400f570 // st1b {z16.b}, p5, [x11]
|
||||
.inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4]
|
||||
b TILE1_Dz_End
|
||||
|
||||
TILE1_STORE108:
|
||||
.inst 0x05702300 // dup z0.q, z24.q[1]
|
||||
.inst 0x05b02301 // dup z1.q, z24.q[2]
|
||||
.inst 0x05f02302 // dup z2.q, z24.q[3]
|
||||
.inst 0x05702323 // dup z3.q, z25.q[1]
|
||||
.inst 0x05b02324 // dup z4.q, z25.q[2]
|
||||
.inst 0x05f02325 // dup z5.q, z25.q[3]
|
||||
.inst 0x05702346 // dup z6.q, z26.q[1]
|
||||
.inst 0x05b02347 // dup z7.q, z26.q[2]
|
||||
.inst 0x05f02348 // dup z8.q, z26.q[3]
|
||||
.inst 0x05702369 // dup z9.q, z27.q[1]
|
||||
.inst 0x05b0236a // dup z10.q, z27.q[2]
|
||||
.inst 0x05f0236b // dup z11.q, z27.q[3]
|
||||
.inst 0x0570238c // dup z12.q, z28.q[1]
|
||||
.inst 0x05b0238d // dup z13.q, z28.q[2]
|
||||
.inst 0x05f0238e // dup z14.q, z28.q[3]
|
||||
.inst 0x057023af // dup z15.q, z29.q[1]
|
||||
.inst 0x05b023b0 // dup z16.q, z29.q[2]
|
||||
.inst 0x05f023b1 // dup z17.q, z29.q[3]
|
||||
.inst 0x057023d2 // dup z18.q, z30.q[1]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -1175,11 +1145,11 @@ TILE1_STORE108:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x8, x23, x4, lsl #3
|
||||
add x11, x5, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
add x8, x23, x4, lsl #3 // +20
|
||||
add x11, x5, x4, lsl #3 // +22
|
||||
add x13, x9, x4, lsl #3 // +24
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
|
@ -1193,7 +1163,8 @@ TILE1_STORE108:
|
|||
.inst 0xe40455b2 // st1b {z18.b}, p5, [x13, x4]
|
||||
b TILE1_Dz_End
|
||||
|
||||
TILE1_STORE112:
|
||||
/* oc=108 */
|
||||
TILE1_STORE108:
|
||||
.inst 0x05702300 // dup z0.q, z24.q[1]
|
||||
.inst 0x05b02301 // dup z1.q, z24.q[2]
|
||||
.inst 0x05f02302 // dup z2.q, z24.q[3]
|
||||
|
@ -1215,13 +1186,13 @@ TILE1_STORE108:
|
|||
.inst 0x057023d2 // dup z18.q, z30.q[1]
|
||||
.inst 0x05b023d3 // dup z19.q, z30.q[2]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -1240,12 +1211,12 @@ TILE1_STORE108:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x8, x23, x4, lsl #3
|
||||
add x11, x5, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
add x8, x23, x4, lsl #3 // +20
|
||||
add x11, x5, x4, lsl #3 // +22
|
||||
add x13, x9, x4, lsl #3 // +24
|
||||
add x23, x15, x4, lsl #3 // +26
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
|
@ -1260,6 +1231,77 @@ TILE1_STORE108:
|
|||
.inst 0xe400f6f3 // st1b {z19.b}, p5, [x23]
|
||||
b TILE1_Dz_End
|
||||
|
||||
/* oc=112 */
|
||||
TILE1_STORE112:
|
||||
.inst 0x05702300 // dup z0.q, z24.q[1]
|
||||
.inst 0x05b02301 // dup z1.q, z24.q[2]
|
||||
.inst 0x05f02302 // dup z2.q, z24.q[3]
|
||||
.inst 0x05702323 // dup z3.q, z25.q[1]
|
||||
.inst 0x05b02324 // dup z4.q, z25.q[2]
|
||||
.inst 0x05f02325 // dup z5.q, z25.q[3]
|
||||
.inst 0x05702346 // dup z6.q, z26.q[1]
|
||||
.inst 0x05b02347 // dup z7.q, z26.q[2]
|
||||
.inst 0x05f02348 // dup z8.q, z26.q[3]
|
||||
.inst 0x05702369 // dup z9.q, z27.q[1]
|
||||
.inst 0x05b0236a // dup z10.q, z27.q[2]
|
||||
.inst 0x05f0236b // dup z11.q, z27.q[3]
|
||||
.inst 0x0570238c // dup z12.q, z28.q[1]
|
||||
.inst 0x05b0238d // dup z13.q, z28.q[2]
|
||||
.inst 0x05f0238e // dup z14.q, z28.q[3]
|
||||
.inst 0x057023af // dup z15.q, z29.q[1]
|
||||
.inst 0x05b023b0 // dup z16.q, z29.q[2]
|
||||
.inst 0x05f023b1 // dup z17.q, z29.q[3]
|
||||
.inst 0x057023d2 // dup z18.q, z30.q[1]
|
||||
.inst 0x05b023d3 // dup z19.q, z30.q[2]
|
||||
.inst 0x05f023d4 // dup z20.q, z30.q[3]
|
||||
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
.inst 0xe400f521 // st1b {z1.b}, p5, [x9]
|
||||
.inst 0xe4045522 // st1b {z2.b}, p5, [x9, x4]
|
||||
.inst 0xe400f5f9 // st1b {z25.b}, p5, [x15]
|
||||
.inst 0xe40455e3 // st1b {z3.b}, p5, [x15, x4]
|
||||
.inst 0xe400f504 // st1b {z4.b}, p5, [x8]
|
||||
.inst 0xe4045505 // st1b {z5.b}, p5, [x8, x4]
|
||||
.inst 0xe400f57a // st1b {z26.b}, p5, [x11]
|
||||
.inst 0xe4045566 // st1b {z6.b}, p5, [x11, x4]
|
||||
.inst 0xe400f5a7 // st1b {z7.b}, p5, [x13]
|
||||
.inst 0xe40455a8 // st1b {z8.b}, p5, [x13, x4]
|
||||
.inst 0xe400f6fb // st1b {z27.b}, p5, [x23]
|
||||
.inst 0xe40456e9 // st1b {z9.b}, p5, [x23, x4]
|
||||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
add x8, x23, x4, lsl #3 // +20
|
||||
add x11, x5, x4, lsl #3 // +22
|
||||
add x13, x9, x4, lsl #3 // +24
|
||||
add x23, x15, x4, lsl #3 // +26
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
.inst 0xe400f5ed // st1b {z13.b}, p5, [x15]
|
||||
.inst 0xe40455ee // st1b {z14.b}, p5, [x15, x4]
|
||||
.inst 0xe400f51d // st1b {z29.b}, p5, [x8]
|
||||
.inst 0xe404550f // st1b {z15.b}, p5, [x8, x4]
|
||||
.inst 0xe400f570 // st1b {z16.b}, p5, [x11]
|
||||
.inst 0xe4045571 // st1b {z17.b}, p5, [x11, x4]
|
||||
.inst 0xe400f5be // st1b {z30.b}, p5, [x13]
|
||||
.inst 0xe40455b2 // st1b {z18.b}, p5, [x13, x4]
|
||||
.inst 0xe400f6f3 // st1b {z19.b}, p5, [x23]
|
||||
.inst 0xe40456f4 // st1b {z20.b}, p5, [x23, x4]
|
||||
b TILE1_Dz_End
|
||||
|
||||
/* oc=116 */
|
||||
TILE1_STORE116:
|
||||
.inst 0x05702300 // dup z0.q, z24.q[1]
|
||||
.inst 0x05b02301 // dup z1.q, z24.q[2]
|
||||
|
@ -1283,13 +1325,13 @@ TILE1_STORE108:
|
|||
.inst 0x05b023d3 // dup z19.q, z30.q[2]
|
||||
.inst 0x05f023d4 // dup z20.q, z30.q[3]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -1308,13 +1350,13 @@ TILE1_STORE108:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x8, x23, x4, lsl #3
|
||||
add x11, x5, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
add x8, x23, x4, lsl #3 // +20
|
||||
add x11, x5, x4, lsl #3 // +22
|
||||
add x13, x9, x4, lsl #3 // +24
|
||||
add x23, x15, x4, lsl #3 // +26
|
||||
add x5, x8, x4, lsl #3 // +28
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
|
@ -1331,6 +1373,7 @@ TILE1_STORE108:
|
|||
.inst 0xe400f4bf // st1b {z31.b}, p5, [x5]
|
||||
b TILE1_Dz_End
|
||||
|
||||
/* oc=120 */
|
||||
TILE1_STORE120:
|
||||
.inst 0x05702300 // dup z0.q, z24.q[1]
|
||||
.inst 0x05b02301 // dup z1.q, z24.q[2]
|
||||
|
@ -1355,13 +1398,13 @@ TILE1_STORE108:
|
|||
.inst 0x05f023d4 // dup z20.q, z30.q[3]
|
||||
.inst 0x057023f5 // dup z21.q, z31.q[1]
|
||||
|
||||
add x9, x6, x4, lsl #1
|
||||
add x15, x6, x4, lsl #2
|
||||
add x8, x9, x4, lsl #2
|
||||
add x11, x6, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #1 // +2
|
||||
add x15, x6, x4, lsl #2 // +4
|
||||
add x8, x9, x4, lsl #2 // +6
|
||||
add x11, x6, x4, lsl #3 // +8
|
||||
add x13, x9, x4, lsl #3 // +10
|
||||
add x23, x15, x4, lsl #3 // +12
|
||||
add x5, x8, x4, lsl #3 // +14
|
||||
|
||||
.inst 0xe400f4d8 // st1b {z24.b}, p5, [x6]
|
||||
.inst 0xe40454c0 // st1b {z0.b}, p5, [x6, x4]
|
||||
|
@ -1380,13 +1423,13 @@ TILE1_STORE108:
|
|||
.inst 0xe400f4aa // st1b {z10.b}, p5, [x5]
|
||||
.inst 0xe40454ab // st1b {z11.b}, p5, [x5, x4]
|
||||
|
||||
add x9, x6, x4, lsl #4
|
||||
add x15, x13, x4, lsl #3
|
||||
add x8, x23, x4, lsl #3
|
||||
add x11, x5, x4, lsl #3
|
||||
add x13, x9, x4, lsl #3
|
||||
add x23, x15, x4, lsl #3
|
||||
add x5, x8, x4, lsl #3
|
||||
add x9, x6, x4, lsl #4 // +16
|
||||
add x15, x13, x4, lsl #3 // +18
|
||||
add x8, x23, x4, lsl #3 // +20
|
||||
add x11, x5, x4, lsl #3 // +22
|
||||
add x13, x9, x4, lsl #3 // +24
|
||||
add x23, x15, x4, lsl #3 // +26
|
||||
add x5, x8, x4, lsl #3 // +28
|
||||
|
||||
.inst 0xe400f53c // st1b {z28.b}, p5, [x9]
|
||||
.inst 0xe404552c // st1b {z12.b}, p5, [x9, x4]
|
||||
|
@ -1404,6 +1447,7 @@ TILE1_STORE108:
|
|||
.inst 0xe40454b5 // st1b {z21.b}, p5, [x5, x4]
|
||||
b TILE1_Dz_End
|
||||
|
||||
/* oc=124 */
|
||||
TILE1_STORE124:
|
||||
.inst 0x05702300 // dup z0.q, z24.q[1]
|
||||
.inst 0x05b02301 // dup z1.q, z24.q[2]
|
||||
|
@ -1480,6 +1524,7 @@ TILE1_STORE108:
|
|||
.inst 0xe400f596 // st1b {z22.b}, p5, [x12]
|
||||
b TILE1_Dz_End
|
||||
|
||||
/* oc=128 */
|
||||
TILE1_STORE128:
|
||||
.inst 0x05702300 // dup z0.q, z24.q[1]
|
||||
.inst 0x05b02301 // dup z1.q, z24.q[2]
|
||||
|
@ -1390,6 +1390,89 @@ void MNNAccumulateSequenceNumber (float* dst, const float* src, int size) {
    *dst = sum;
}

#ifdef MNN_SUPPORT_TRANSFORMER_FUSE

static void MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes) {
    // source shape: [headDim/pack, seqLen, pack]
    // scale & normalizeScale shape: [seqLen]
    // dest shape: [headDim/pack, seqLen, pack]
    auto stride0 = plane * pack;

    if (idx > 0) {
        for (int j = 0; j < depthQuad; ++j) {
            for (int i = 0; i < plane; ++i) {
                auto dataNew = Vec::load(src + j * stride0 + i * pack);
                auto dataOld = Vec::load(dst + j * stride0 + i * pack);
                auto s = Vec(scale[i]);
                dataNew = Vec::fma(dataNew, dataOld, s);
                Vec::save(dst + j * stride0 + i * pack, dataNew);
            }
        }
    } else {
        memcpy(dst, src, size * bytes);
    }
    if (idx == kvBlocks - 1) { // if last subBlock, exp(xi)/sum(exp(xi))
        for (int j = 0; j < depthQuad; ++j) {
            for (int i = 0; i < plane; ++i) {
                auto dataNew = Vec::load(dst + j * stride0 + i * pack);
                auto ns = Vec(1.0f / normalizeScale[i]);
                dataNew = dataNew * ns;
                Vec::save(dst + j * stride0 + i * pack, dataNew);
            }
        }
    }
}
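Element-wise, the update above adds the new KV block's partial output on top of the old accumulator rescaled by scale[i] = exp(oldMax - newMax), and after the last block divides by the running softmax sum. The sketch below is illustrative only (plain floats instead of the packed `Vec` lanes, a flat [headDim, seqLen] layout, names invented here); it is not code from this commit.

```cpp
// Illustrative scalar equivalent of MNNFlashAttentionUpdateBlockOutput (not part of the diff).
// acc: running output, blockOut: partial output of the current KV block,
// scale[s] = exp(oldMax_s - newMax_s), normalizeScale[s] = running softmax denominator.
void flashAttentionUpdateScalar(float* acc, const float* blockOut, const float* scale,
                                const float* normalizeScale, int headDim, int seqLen,
                                int blockIdx, int kvBlocks) {
    for (int d = 0; d < headDim; ++d) {
        for (int s = 0; s < seqLen; ++s) {
            float v = blockOut[d * seqLen + s];
            if (blockIdx > 0) {
                v += acc[d * seqLen + s] * scale[s]; // rescale old accumulator to the new max
            }
            acc[d * seqLen + s] = v;
        }
    }
    if (blockIdx == kvBlocks - 1) { // final block: divide by the softmax sum
        for (int d = 0; d < headDim; ++d) {
            for (int s = 0; s < seqLen; ++s) {
                acc[d * seqLen + s] /= normalizeScale[s];
            }
        }
    }
}
```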

static void MNNAttenPackAndScaleSingleHead(float* dst, const float* srcHeadBase, size_t srcRowStride, const float* scale, const int32_t* units, size_t seqLen, size_t headDim) {
    const int32_t eP = units[0];
    const int32_t lP = units[1];

    if (lP != 1) {
        MNN_ERROR("This function only supports lP=1 or 2\n");
        return;
    }

    const float scaleVal = scale[0];
#ifdef MNN_USE_NEON
    const float32x4_t vScale = vdupq_n_f32(scaleVal);
#endif
    const size_t packedHeadDim = UP_DIV(headDim, lP);
    const size_t dstStrideDOuter = (size_t)eP * lP;
    const size_t dstStrideSOuter = packedHeadDim * dstStrideDOuter;

    for (int s = 0; s < seqLen; ++s) {
        const int sOuter = s / eP;
        const int sInner = s % eP;
        const float* srcRowPtr = srcHeadBase + s * srcRowStride;
        float* dstBasePtr = dst + sOuter * dstStrideSOuter + sInner * lP;

        size_t d = 0;
#ifdef MNN_USE_NEON
        for (; d + 7 < headDim; d += 8) {
            float32x4_t sVec0 = vld1q_f32(srcRowPtr + d);
            float32x4_t sVec1 = vld1q_f32(srcRowPtr + d + 4);
            sVec0 = vmulq_f32(sVec0, vScale);
            sVec1 = vmulq_f32(sVec1, vScale);

            dstBasePtr[(d + 0) * dstStrideDOuter] = sVec0[0];
            dstBasePtr[(d + 1) * dstStrideDOuter] = sVec0[1];
            dstBasePtr[(d + 2) * dstStrideDOuter] = sVec0[2];
            dstBasePtr[(d + 3) * dstStrideDOuter] = sVec0[3];
            dstBasePtr[(d + 4) * dstStrideDOuter] = sVec1[0];
            dstBasePtr[(d + 5) * dstStrideDOuter] = sVec1[1];
            dstBasePtr[(d + 6) * dstStrideDOuter] = sVec1[2];
            dstBasePtr[(d + 7) * dstStrideDOuter] = sVec1[3];
        }
#else
        for (; d < headDim; ++d) {
            dstBasePtr[d * dstStrideDOuter] = srcRowPtr[d] * scaleVal;
        }
#endif
    }
}
#endif // MNN_SUPPORT_TRANSFORMER_FUSE
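With lP fixed to 1, the destination written above is effectively laid out as [ceil(seqLen / eP), headDim, eP]. The helper below is a hypothetical illustration (not in the diff) of the flat offset that receives source element (s, d) scaled by scale[0].

```cpp
#include <cstddef>

// Hypothetical helper: flat destination offset used by MNNAttenPackAndScaleSingleHead when lP == 1.
inline size_t packedQOffset(size_t s, size_t d, size_t eP, size_t headDim) {
    const size_t sOuter = s / eP;   // which eP-row panel
    const size_t sInner = s % eP;   // row inside the panel
    return (sOuter * headDim + d) * eP + sInner;
}
// dst[packedQOffset(s, d, eP, headDim)] == srcHeadBase[s * srcRowStride + d] * scale[0]
```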

#ifndef MNN_USE_NEON
void MNNGetMatMulPackMode(int* eP, int *lP, int* hP) {
    *eP = 16;
@ -2619,30 +2702,54 @@ void MNNExpC8(float* dest, const float* source, float* offset, const float* para
    offset[3] = summer;
}

void MNNSoftmax(float* dest, const float* source, size_t size) {
    float maxValue = ALIMAX(source[0], source[1]);
    for (int i = 2; i < size; ++i) {
        maxValue = ALIMAX(maxValue, source[i]);
    }
    float xLimit = 87, param = 0.6931471805599453, sumValue = 0.f;
    for (int i = 0; i < size; ++i) {
        auto x = source[i] - maxValue;
        x = x > -xLimit ? x : -xLimit;
        x = x < xLimit ? x : xLimit;
void MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) {
    for (int k = 0; k < outside; ++k) {
        auto source = input + k * reduceSize;
        auto dest = softmaxDst + k * reduceSize;

        int div = (x / param);
        int div2 = (div + 127) << 23;
        auto xReamin = x - div * param;
        float expBasic = *(float*)(&div2);
        float oldMax = source[0];
        if (runningMax) {
            oldMax = runningMax[k];
        }

        auto t = xReamin;
        auto expRemain = ((((1.0f / 120 * t + 1.0f / 24) * t + 1.0f / 6) * t + 0.5f) * t + 1.0f) * t + 1.0f;
        dest[i] = expBasic * expRemain;
        sumValue += dest[i];
    }
    sumValue = 1.f / sumValue;
    for (int i = 0; i < size; ++i) {
        dest[i] *= sumValue;
        // find max value of current block
        float blockMax =source[0];
        for (int i = 1; i < reduceSize; ++i) {
            blockMax = ALIMAX(blockMax, source[i]);
        }
        float newMax = ALIMAX(oldMax, blockMax);

        // caculate block's expr(xi-newmax) and update runningMax
        float xLimit = 87, param = 0.6931471805599453;
        float blockSum = 0.f;
        for (int i = 0; i < reduceSize; ++i) {
            auto x = source[i] - newMax;
            x = x > -xLimit ? x : -xLimit;
            x = x < xLimit ? x : xLimit;

            int div = (x / param);
            int div2 = (div + 127) << 23;
            auto xReamin = x - div * param;
            float expBasic = *(float*)(&div2);

            auto t = xReamin;
            auto expRemain = ((((1.0f / 120 * t + 1.0f / 24) * t + 1.0f / 6) * t + 0.5f) * t + 1.0f) * t + 1.0f;
            dest[i] = expBasic * expRemain;
            blockSum += dest[i];
        }

        if (runningMax != nullptr && runningSum != nullptr && updateScale != nullptr) {
            // update runningSum, runningMax, scale=expf(oldMax-newMax)
            runningSum[k] = runningSum[k] * expf(oldMax - newMax) + blockSum;
            runningMax[k] = newMax;
            updateScale[k] = expf(oldMax - newMax);
        } else {
            // Normalize
            auto scale = 1.f / blockSum;
            for (int i = 0; i < reduceSize; ++i) {
                dest[i] *= scale;
            }
        }
    }
}

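The replacement signature turns MNNSoftmax into the streaming (online) softmax used when attention is computed block by block over the KV cache: per row it keeps a running maximum m and running sum s, folding each block in as m_new = max(m_old, blockMax) and s_new = s_old * exp(m_old - m_new) + sum_i exp(x_i - m_new), and reports updateScale = exp(m_old - m_new) so the caller can rescale its partial attention output; passing null running buffers falls back to an ordinary normalized softmax. A hypothetical two-block driver (buffer names and sizes invented here) could look like the sketch below.

```cpp
#include <cfloat>

void onlineSoftmaxExample() {                              // hypothetical driver, not in the diff
    const int outside = 1, blockLen = 4;                   // one row, two KV blocks of 4 scores each
    float block0[4] = {0.1f, 0.7f, -0.3f, 0.5f};
    float block1[4] = {1.2f, 0.0f, 0.4f, -0.8f};
    float probs[4];                                        // receives exp(x - runningMax), unnormalized
    float runningMax[1] = {-FLT_MAX}, runningSum[1] = {0.f}, updateScale[1] = {0.f};

    MNNSoftmax(probs, block0, runningMax, runningSum, updateScale, outside, blockLen);
    // ... acc += probs * V(block 0)
    MNNSoftmax(probs, block1, runningMax, runningSum, updateScale, outside, blockLen);
    // updateScale[0] = exp(oldMax - newMax): acc = acc * updateScale[0] + probs * V(block 1)
    // after the last block: acc /= runningSum[0]
}
```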
@ -2657,6 +2764,68 @@ void MNNReluInt8(int8_t* dst, const int8_t* src, size_t size, ssize_t zeroPoint)
}
#endif // no MNN_USE_SSE

void MNNExp(float* dst, const float* src, float* offset, size_t dataSize) {
    int countC8 = static_cast<int32_t>(dataSize) / 8;
    int remain = static_cast<int32_t>(dataSize) % 8;
    static const float parameters[] = {
        (float)logf(2.0f), 1.0f / (float)logf(2.0f), 0.25f, 1.0f, 0.5f, 1.0f / 6.0f, 1.0f / 24.0f, 1.0f / 120.0f};
    if (countC8 > 0) {
        // Align to eight so asm is easier to write
        MNNExpC8(dst, src, offset, parameters, countC8);
    }
    if (remain > 0) {
        auto param = parameters[0];
        float xLimit = 87;
        float summer = offset[3];
        auto source = src + countC8 * 8;
        auto dest = dst + countC8 * 8;
        for (int i = 0; i < remain; ++i) {
            auto x = source[i] * offset[0] + offset[2];
            x = ALIMAX(x, -xLimit);
            x = ALIMIN(x, xLimit);
            int div = (x * parameters[1]);
            int div2 = (div + 127) << 23;
            auto xReamin = x - div * param;
            float expBasic = *(float*)(&div2);
            auto t = xReamin * 0.25f;
            auto expRemain =
                ((((parameters[7] * t + parameters[6]) * t + parameters[5]) * t + parameters[4]) * t + 1.0f) * t +
                1.0f;
            expRemain = expRemain * expRemain;
            expRemain = expRemain * expRemain;
            dest[i] = expBasic * expRemain + offset[1];
            summer+= dest[i];
        }
        offset[3] = summer;
    }
}
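MNNExp evaluates e^x by splitting x = d*ln2 + r, building 2^d directly in the float exponent bits via (d + 127) << 23, and approximating e^r with a fifth-order Taylor polynomial evaluated at r/4 and squared twice; the offset[] array carries the input scale/bias and the running sum consumed by softmax. A minimal scalar sketch of just the approximation (assuming <cmath> and <cstring>, ignoring the offset bookkeeping) is shown below.

```cpp
#include <cmath>
#include <cstring>

// Illustrative scalar form of the exp approximation above; not part of the diff.
static float expApprox(float x) {
    x = std::fmin(std::fmax(x, -87.0f), 87.0f);   // same clamp as xLimit = 87
    int d = (int)(x * 1.4426950408889634f);       // x / ln(2)
    float r = x - d * 0.6931471805599453f;        // remainder, so x = d*ln2 + r
    int bits = (d + 127) << 23;                   // 2^d written straight into the exponent field
    float twoPowD;
    std::memcpy(&twoPowD, &bits, sizeof(twoPowD));
    float t = 0.25f * r;
    float p = ((((t / 120.f + 1.f / 24.f) * t + 1.f / 6.f) * t + 0.5f) * t + 1.f) * t + 1.f;
    p = p * p;
    p = p * p;                                    // (e^{r/4})^4 = e^r
    return twoPowD * p;
}
```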

void packKvCache(float* dst, const float* src, size_t seqLen, size_t kvSeqLen, size_t eP) {
    if (seqLen == 0 || kvSeqLen == 0) {
        return;
    }

    const size_t dstSOuterStride = kvSeqLen * eP;

    for (size_t sBase = 0; sBase < seqLen; sBase += eP) {
        const size_t numRowsInBlock = std::min(eP, seqLen - sBase);
        const size_t sOuter = sBase / eP;

        float* dstSOuterBase = dst + sOuter * dstSOuterStride;

        for (size_t k = 0; k < kvSeqLen; ++k) {
            float* dstColStart = dstSOuterBase + k * eP;

            for (size_t sInner = 0; sInner < numRowsInBlock; ++sInner) {
                const size_t s = sBase + sInner;

                const float value = src[s * kvSeqLen + k];
                dstColStart[sInner] = value;
            }
        }
    }
}
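packKvCache re-tiles the [seqLen, kvSeqLen] score matrix into eP-row panels so that, inside a panel, the eP scores for one KV position sit contiguously. An equivalent flat-index statement of the copy (helper name invented here for illustration) is:

```cpp
#include <cstddef>

// dst[packedScoreOffset(s, k, kvSeqLen, eP)] == src[s * kvSeqLen + k] for every valid (s, k).
inline size_t packedScoreOffset(size_t s, size_t k, size_t kvSeqLen, size_t eP) {
    return (s / eP) * kvSeqLen * eP + k * eP + (s % eP);
}
```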

void MNNMaxFloat(float* input, float* maxBuffer, int32_t inputCountUnit) {
    for (int i = 0; i < inputCountUnit; i++) {
        for (int j = 0; j < UNIT; j++) {
@ -3173,42 +3342,6 @@ void MNNPackTranspose(float* dst, const float* src, size_t area, size_t depth, i
|
|||
}
|
||||
}
|
||||
|
||||
void MNNExp(float* dst, const float* src, float* offset, size_t dataSize) {
|
||||
int countC8 = static_cast<int32_t>(dataSize) / 8;
|
||||
int remain = static_cast<int32_t>(dataSize) % 8;
|
||||
static const float parameters[] = {
|
||||
(float)logf(2.0f), 1.0f / (float)logf(2.0f), 0.25f, 1.0f, 0.5f, 1.0f / 6.0f, 1.0f / 24.0f, 1.0f / 120.0f};
|
||||
if (countC8 > 0) {
|
||||
// Align to eight so asm is easier to write
|
||||
MNNExpC8(dst, src, offset, parameters, countC8);
|
||||
}
|
||||
if (remain > 0) {
|
||||
auto param = parameters[0];
|
||||
float xLimit = 87;
|
||||
float summer = offset[3];
|
||||
auto source = src + countC8 * 8;
|
||||
auto dest = dst + countC8 * 8;
|
||||
for (int i = 0; i < remain; ++i) {
|
||||
auto x = source[i] * offset[0] + offset[2];
|
||||
x = ALIMAX(x, -xLimit);
|
||||
x = ALIMIN(x, xLimit);
|
||||
int div = (x * parameters[1]);
|
||||
int div2 = (div + 127) << 23;
|
||||
auto xReamin = x - div * param;
|
||||
float expBasic = *(float*)(&div2);
|
||||
auto t = xReamin * 0.25f;
|
||||
auto expRemain =
|
||||
((((parameters[7] * t + parameters[6]) * t + parameters[5]) * t + parameters[4]) * t + 1.0f) * t +
|
||||
1.0f;
|
||||
expRemain = expRemain * expRemain;
|
||||
expRemain = expRemain * expRemain;
|
||||
dest[i] = expBasic * expRemain + offset[1];
|
||||
summer+= dest[i];
|
||||
}
|
||||
offset[3] = summer;
|
||||
}
|
||||
}
|
||||
|
||||
// Lambert's series with 7 divisions
|
||||
// reference from
|
||||
// https://varietyofsound.wordpress.com/2011/02/14/efficient-tanh-computation-using-lamberts-continued-fraction/
|
||||
|
@ -4011,6 +4144,7 @@ void MNNCoreFunctionInit() {
    gCoreFunction->chooseWinoDestUnrollTransform = WinogradFunction::chooseWinoDestUnrollTransform;
    gCoreFunction->MNNDeconvRunForLineDepthwise = MNNDeconvRunForLineDepthwise;
    gCoreFunction->MNNDeconvRunForUnitDepthWise = MNNDeconvRunForUnitDepthWise;
    gCoreFunction->MNNSoftmax = MNNSoftmax;
#ifdef MNN_USE_NEON
    gCoreFunction->MNNDepthwiseConvFastKernel = MNNDepthwiseConvFastKernel;
#endif

@ -4019,6 +4153,12 @@ void MNNCoreFunctionInit() {
#ifdef MNN_SUPPORT_QUANT_EXTEND
    gCoreFunction->MNNSelectUnaryFunctionForInt8 = CPUUnary::selectForInt8;
#endif

#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
    gCoreFunction->MNNAttenPackAndScaleSingleHead = MNNAttenPackAndScaleSingleHead;
    gCoreFunction->MNNFlashAttentionUpdateBlockOutput = MNNFlashAttentionUpdateBlockOutput;
#endif

    gCoreFunction->MNNReluWithSlopeChannel = MNNReluWithSlopeChannel;
    gCoreFunction->MNNPoolingAvg = (decltype(gCoreFunction->MNNPoolingAvg))(poolingAvg<float, Vec4, 4>);
    // Set min value as 1 << 24

@ -32,12 +32,16 @@ void MNNAbsMaxFP32_Pack4(const float* source, float* absmax, size_t src_depth_qu
void MNNAbsMaxFP32_Pack8(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack);
void MNNQuantScaleFP32(float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch);
void MNNDynamicUpdateConvBiasScale(float* newbias, float* oldbias, float* weightKernelSum, float* inputZero, size_t ocQuad);
#endif
#endif // MNN_LOW_MEMORY

#ifdef MNN_SME2
void MNNPackedMatMulRemainFP32_SME2(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
#endif

#endif // __aarch64__


#endif
#endif
void MNNFp32ToFp8(uint8_t* dst, const float* src, size_t size);
void MNNFp8ToFp32(float* dst, const uint8_t* src, size_t size);
void MNNFp16ToFp8(uint8_t* dst, const uint16_t* src, size_t size);

@ -117,7 +121,7 @@ void MNNReluWithSlopeCommon(float* dst, const float* src, size_t size, float slo
void MNNHardSwishCommon(float* dst, const float* src, size_t size);
void MNNGeluCommon(float* dst, const float* src, size_t size);
void MNNGeluStandardCommon(float* dst, const float* src, size_t size);
void MNNSoftmax(float* dest, const float* source, size_t size);
void MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize);
void MNNNorm(float* dest, const float* source, const float *gamma, const float *beta, float epsilon, size_t size, bool RMSNorm = false);

// Get Pack for MatMul's e , l , h , the pack number must be 1 or 4 * n

@ -187,6 +191,8 @@ void MNNComputeMatMulForE_1(const float* A, const float* B, float* C, const floa
void MNNCopyC4Int16WithStride(const float* sourceF, float* destF, size_t srcStride, size_t dstStride, size_t count);
void MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count);

void packKvCache(float* dst, const float* src, size_t seqLen, size_t kvSeqLen, size_t eP);

struct SumByAxisParams {
    ssize_t kernelCountUnitDouble;
    ssize_t unitColBufferSize;

@ -401,6 +407,13 @@ struct CoreFunctions {
    void(*MNNSumWeightInt8)(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP);
    void(*MNNSumWeightInt8SmeHp64)(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP);

    // Attention
    void(*MNNAttenUnpackAndConvertFp16)(float* dst, float* src, size_t depth, size_t planesize, int pack);
    void(*MNNAttenPackAndConvertFp32)(float* dst, float* src, const int32_t* units, size_t depth, size_t planesize);
    void(*MNNAttenPackAndScaleSingleHead)(float* dst, const float* srcHeadBase, size_t srcRowStride, const float* scale, const int32_t* units, size_t seqLen, size_t headDim);
    void(*MNNFlashAttentionUpdateBlockOutput)(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes);
    void(*MNNSoftmax)(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize);

    MatmulRelatedFunctions int8MatmulRelatedFunctions;
    MatmulRelatedFunctions sme2Int8MatmulRelatedFuncionsHp32;
};
@ -805,17 +805,16 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector<Tensor*>& input
|
|||
return OUT_OF_MEMORY;
|
||||
}
|
||||
if (mOnlineReorderWeightSme && planeSize > 1) { // only prefill need
|
||||
int dstHp = 32; // if mOnlineReorderWeightSme==true, UNIT is 64 when model loaded.
|
||||
int weightlenNew = ROUND_UP(outC, dstHp) * mBlockNum * ROUND_UP(ic / mBlockNum, SRC_UNIT) * kernelCount * SRC_UNIT * dstHp;
|
||||
int weightlenNew = ROUND_UP(outC, SME_DECODE_MAXHP) * mBlockNum * ROUND_UP(ic / mBlockNum, SRC_UNIT) * kernelCount;
|
||||
if (mResourceInt8->mActBits == 4) {
|
||||
weightlenNew /= 2;
|
||||
}
|
||||
mWeight4Prefill = bufferAlloc->alloc(weightlenNew + 2 * ROUND_UP(outC, dstHp) * QUANT_INFO_BYTES);
|
||||
mWeight4Prefill = bufferAlloc->alloc(weightlenNew + 2 * mBlockNum * ROUND_UP(outC, SME_DECODE_MAXHP) * QUANT_INFO_BYTES);
|
||||
if (mWeight4Prefill.invalid()) {
|
||||
return OUT_OF_MEMORY;
|
||||
}
|
||||
if (mInputBlockNum > 1) { // only in this case, need to use weight_kernel_sum
|
||||
mWeightKernelSum4Prefill = bufferAlloc->alloc(ROUND_UP(outC, 32) * mBlockNum * sizeof(float));
|
||||
mWeightKernelSum4Prefill = bufferAlloc->alloc(ROUND_UP(outC, SME_DECODE_MAXHP) * mBlockNum * sizeof(float));
|
||||
if (mWeightKernelSum4Prefill.invalid()) {
|
||||
return OUT_OF_MEMORY;
|
||||
}
|
||||

@ -62,7 +62,8 @@ bool AVX2Functions::init(int cpuFlags) {
    coreFunction->MNNComputeMatMulForH_1 = _AVX_MNNComputeMatMulForH_1;
    // Dynamic Quant
    coreFunction->MNNCountMaxMinValue = _AVX_MNNCountMinMaxValue;

    coreFunction->MNNSoftmax = _AVX_MNNSoftmax;

    // For Packed Functions
    coreFunction->pack = 8;
@ -24,7 +24,7 @@ struct FunctionGroup {
    int lP = 1;
    int hP = 4;
    void (*MNNExpC8)(float* dest, const float* source, float* offset, const float* parameters, size_t countC8) = _SSE_MNNExpC8;
    void (*MNNSoftmax)(float* dest, const float* source, size_t size) = _SSE_MNNSoftmax;
    void (*MNNSoftmax)(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) = _SSE_MNNSoftmax;
    void (*MNNReluInt8)(int8_t* dst, const int8_t* src, size_t size, ssize_t zeroPoint) = _SSE_MNNReluInt8;
    void (*MNNHardSwish)(float* dst, const float* src, size_t size) = _SSE_MNNHardSwish;
    void (*MNNGelu)(float* dst, const float* src, size_t size, float* parameters) = _SSE_MNNGelu;

@ -65,6 +65,8 @@ void MNNFunctionInit() {
    coreFunction->MNNPackForMatMul_B = _SSE_MNNPackForMatMul_B;
    // Dynamic Quant
    coreFunction->MNNCountMaxMinValue = _SSE_MNNCountMinMaxValue;

    coreFunction->MNNSoftmax = _SSE_MNNSoftmax;
}
#ifdef MNN_USE_AVX
    if (cpuFlags & libyuv::kCpuHasAVX2) {

@ -205,8 +207,8 @@ void MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count) {
    _SSE_MNNInt8ToInt16(dest, source, count);
}

void MNNSoftmax(float* dest, const float* source, size_t size) {
    gFunc.MNNSoftmax(dest, source, size);
void MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) {
    gFunc.MNNSoftmax(softmaxDst, input, runningMax, runningSum, updateScale, outside, reduceSize);
}

void MNNNorm(float* dest, const float* source, const float *gamma, const float *beta, float epsilon, size_t size, bool RMSNorm) {
|
@ -53,7 +53,7 @@ void _AVX_MNNAsyQuantInfo(float* scale, float* bias, float* qscale, float* qbias
|
|||
void _AVX_MNNPackC4ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el);
|
||||
|
||||
void _AVX_MNNExpC8(float* dest, const float* source, float* offset, const float* parameters, size_t countC8);
|
||||
void _AVX_MNNSoftmax(float* dest, const float* source, size_t size);
|
||||
void _AVX_MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize);
|
||||
void _AVX_MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minV, ssize_t maxV, const float* zeroPoint, ssize_t quanParamVec);
|
||||
void _AVX_MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* scale, size_t sizeQuad, const float* zeroPoint, ssize_t quanParamVec);
|
||||
void _AVX_MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dstO, const int8_t* srcO, const int8_t* weightO, const QuanPostTreatParameters* parameters, size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, int8_t* idxOrder);
|
||||
|
|
|
@ -117,103 +117,71 @@ void _AVX_MNNExpC8(float* dest, const float* source, float* offset, const float*
|
|||
}
|
||||
|
||||
|
||||
void _AVX_MNNSoftmax(float* dest, const float* source, size_t size) {
|
||||
float tmpfloat8[8];
|
||||
int count = size / 8;
|
||||
int remain = count * 8;
|
||||
// step 1: get maxValue
|
||||
float maxValue = source[0];
|
||||
if (count > 0) {
|
||||
auto maxVal = _mm256_loadu_ps(source);
|
||||
for (int i = 1; i < count; i++) {
|
||||
maxVal = _mm256_max_ps(maxVal, _mm256_loadu_ps(source + i * 8));
|
||||
}
|
||||
_mm256_storeu_ps(tmpfloat8, maxVal);
|
||||
maxValue = tmpfloat8[0] > tmpfloat8[1] ? tmpfloat8[0] : tmpfloat8[1];
|
||||
for (int i = 2; i < 8; i++) {
|
||||
maxValue = maxValue > tmpfloat8[i] ? maxValue : tmpfloat8[i];
|
||||
}
|
||||
}
|
||||
for (int i = remain; i < size; i++) {
|
||||
maxValue = maxValue > source[i] ? maxValue : source[i];
|
||||
}
|
||||
void _AVX_MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) {
|
||||
const float xLimit = 87.0f;
|
||||
const float param = 0.6931471805599453f; // ln(2)
|
||||
const float inv_param = 1.0f / param;
|
||||
const int32_t exp_offset = 127;
|
||||
const float exp_scale = 8388608.0f; // 2^23
|
||||
|
||||
// step 2: get exp(x - maxValue) and sum(exp(x - maxValue))
|
||||
float sumValue = 0.f;
|
||||
if (count > 0) {
|
||||
auto sumVal = _mm256_set1_ps(0.f);
|
||||
auto p0 = _mm256_set1_ps(0.6931471805599453);
|
||||
auto p1 = _mm256_set1_ps(1.4426950408889634);
|
||||
auto p2 = _mm256_set1_ps(1.f);
|
||||
auto p3 = _mm256_set1_ps(1.f);
|
||||
auto p4 = _mm256_set1_ps(0.5);
|
||||
auto p5 = _mm256_set1_ps(0.1666666666666666);
|
||||
auto p6 = _mm256_set1_ps(0.041666666666666664);
|
||||
auto p7 = _mm256_set1_ps(0.008333333333333333);
|
||||
auto xMax = _mm256_set1_ps(87);
|
||||
auto xMin = _mm256_set1_ps(-87);
|
||||
auto basic = _mm256_set1_epi32(1 << 23);
|
||||
auto temp127 = _mm256_set1_epi32(127);
|
||||
for (int i = 0; i < count; ++i) {
|
||||
auto x = _mm256_sub_ps(_mm256_loadu_ps(source + i * 8), _mm256_set1_ps(maxValue));
|
||||
x = _mm256_max_ps(x, xMin);
|
||||
x = _mm256_min_ps(x, xMax);
|
||||
auto div = _mm256_mul_ps(x, p1);
|
||||
auto divInt = _mm256_cvtps_epi32(div);
|
||||
div = _mm256_cvtepi32_ps(divInt);
|
||||
auto div2 = _mm256_add_epi32(divInt, temp127);
|
||||
div2 = _mm256_mullo_epi32(div2, basic);
|
||||
auto expBasic = _mm256_castsi256_ps(div2);
|
||||
auto xReamin = _mm256_sub_ps(x, _mm256_mul_ps(div, p0));
|
||||
auto t = xReamin;
|
||||
auto c0 = _mm256_mul_ps(p7, t);
|
||||
auto c1 = _mm256_add_ps(c0, p6);
|
||||
auto c2 = _mm256_mul_ps(c1, t);
|
||||
auto c3 = _mm256_add_ps(c2, p5);
|
||||
auto c4 = _mm256_mul_ps(c3, t);
|
||||
auto c5 = _mm256_add_ps(c4, p4);
|
||||
auto c6 = _mm256_mul_ps(c5, t);
|
||||
auto c7 = _mm256_add_ps(c6, p3);
|
||||
auto c8 = _mm256_mul_ps(c7, t);
|
||||
auto c9 = _mm256_add_ps(c8, p2);
|
||||
auto expRemain = c9;
|
||||
auto expRes = _mm256_mul_ps(expBasic, expRemain);
|
||||
sumVal = _mm256_add_ps(expRes, sumVal);
|
||||
_mm256_storeu_ps(dest + 8 * i, expRes);
|
||||
}
|
||||
_mm256_storeu_ps(tmpfloat8, sumVal);
|
||||
for (int i = 0; i < 8; i++) {
|
||||
sumValue += tmpfloat8[i];
|
||||
}
|
||||
}
|
||||
auto param = 0.6931471805599453;
|
||||
float xLimit = 87;
|
||||
for (int i = remain; i < size; i++) {
|
||||
auto x = source[i] - maxValue;
|
||||
x = x > -xLimit ? x : -xLimit;
|
||||
x = x < xLimit ? x : xLimit;
|
||||
for (int k = 0; k < outside; ++k) {
|
||||
float* source = input + k * reduceSize;
|
||||
float* dest = softmaxDst + k * reduceSize;
|
||||
|
||||
int div = (x / param);
|
||||
int div2 = (div + 127) << 23;
|
||||
auto xReamin = x - div * param;
|
||||
float expBasic = *(float*)(&div2);
|
||||
float tmpfloat8[8];
|
||||
int count = reduceSize/ 8;
|
||||
int remain = count * 8;
|
||||
// step 1: get maxValue
|
||||
float maxValue = source[0];
|
||||
|
||||
auto t = xReamin;
|
||||
auto expRemain = ((((1.0f / 120 * t + 1.0f / 24) * t + 1.0f / 6) * t + 0.5f) * t + 1.0f) * t + 1.0f;
|
||||
dest[i] = expBasic * expRemain;
|
||||
sumValue += dest[i];
|
||||
}
|
||||
// step 3: get x / sum and store
|
||||
for (int i = 0; i < count; ++i) {
|
||||
// using 1 / ((1 / x) * sum) instead x * (1 / sum) or x / sum for some bugs in intel cpu
|
||||
auto x = _mm256_rcp_ps(_mm256_loadu_ps(dest + 8 * i));
|
||||
auto y = _mm256_set1_ps(sumValue);
|
||||
auto z = _mm256_rcp_ps(_mm256_mul_ps(x, y));
|
||||
_mm256_storeu_ps(dest + 8 * i, z);
|
||||
}
|
||||
sumValue = 1.f / sumValue;
|
||||
for (int i = remain; i < size; i++) {
|
||||
dest[i] *= sumValue;
|
||||
float oldMax = maxValue;
|
||||
if (runningMax) {
|
||||
oldMax = runningMax[k];
|
||||
}
|
||||
|
||||
if (count > 0) {
|
||||
auto maxVal = _mm256_loadu_ps(source);
|
||||
for (int i = 1; i < count; i++) {
|
||||
maxVal = _mm256_max_ps(maxVal, _mm256_loadu_ps(source + i * 8));
|
||||
}
|
||||
_mm256_storeu_ps(tmpfloat8, maxVal);
|
||||
maxValue = tmpfloat8[0] > tmpfloat8[1] ? tmpfloat8[0] : tmpfloat8[1];
|
||||
for (int i = 2; i < 8; i++) {
|
||||
maxValue = maxValue > tmpfloat8[i] ? maxValue : tmpfloat8[i];
|
||||
}
|
||||
}
|
||||
for (int i = remain; i < reduceSize; i++) {
|
||||
maxValue = maxValue > source[i] ? maxValue : source[i];
|
||||
}
|
||||
|
||||
float newMax = ALIMAX(oldMax, maxValue);
|
||||
|
||||
// step 2: get exp(x - newMax) and sum(exp(x - newMax))
|
||||
float exprOffset[4] = {1.0f, 0.0f, 0.0f, 0.0f };
|
||||
exprOffset[2] = -newMax;
|
||||
MNNExp(dest, source, exprOffset, reduceSize);
|
||||
float sumValue = exprOffset[3];
|
||||
|
||||
if (runningMax != nullptr && runningSum != nullptr && updateScale != nullptr) {
|
||||
// === Step 3: Update running variables ===
|
||||
float scale = expf(oldMax - newMax);
|
||||
runningSum[k] = runningSum[k] * scale + sumValue;
|
||||
runningMax[k] = newMax;
|
||||
updateScale[k] = scale;
|
||||
} else {
|
||||
// step 3: get x / sum and store
|
||||
for (int i = 0; i < count; ++i) {
|
||||
// using 1 / ((1 / x) * sum) instead x * (1 / sum) or x / sum for some bugs in intel cpu
|
||||
auto x = _mm256_rcp_ps(_mm256_loadu_ps(dest + 8 * i));
|
||||
auto y = _mm256_set1_ps(sumValue);
|
||||
auto z = _mm256_rcp_ps(_mm256_mul_ps(x, y));
|
||||
_mm256_storeu_ps(dest + 8 * i, z);
|
||||
}
|
||||
auto scale = 1.f / sumValue;
|
||||
for (int i = remain; i < reduceSize; i++) {
|
||||
dest[i] *= scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -48,8 +48,45 @@ void _AVX_MNNConvRunForLineDepthwise(float* dst, const float* src, const float*
|
|||
size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, size_t height,
|
||||
size_t srcHStep, size_t dstHStep, const float* bias, const float* parameters);
|
||||
void _AVX_MNNAxByClampBroadcastUnit(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters);
|
||||
|
||||
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
|
||||
void _AVX_MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes);
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
|
||||
void _AVX_MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes) {
|
||||
// source shape: [headDim/pack, seqLen, pack]
|
||||
// scale & normalizeScale shape: [seqLen]
|
||||
// dest shape: [headDim/pack, seqLen, pack]
|
||||
auto stride0 = plane * pack;
|
||||
|
||||
if (idx > 0) {
|
||||
for (int j = 0; j < depthQuad; ++j) {
|
||||
for (int i = 0; i < plane; ++i) {
|
||||
auto dataNew = Vec::load(src + j * stride0 + i * pack);
|
||||
auto dataOld = Vec::load(dst + j * stride0 + i * pack);
|
||||
auto s = Vec(scale[i]);
|
||||
dataNew = Vec::fma(dataNew, dataOld, s);
|
||||
Vec::save(dst + j * stride0 + i * pack, dataNew);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
memcpy(dst, src, size * bytes);
|
||||
}
|
||||
if (idx == kvBlocks - 1) { // if last subBlock, exp(xi)/sum(exp(xi))
|
||||
for (int j = 0; j < depthQuad; ++j) {
|
||||
for (int i = 0; i < plane; ++i) {
|
||||
auto dataNew = Vec::load(dst + j * stride0 + i * pack);
|
||||
auto ns = Vec(1.0f / normalizeScale[i]);
|
||||
dataNew = dataNew * ns;
|
||||
Vec::save(dst + j * stride0 + i * pack, dataNew);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
void _AVX_MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) {
|
||||
for (int i = 0; i < count; ++i) {
|
||||
|
@ -728,4 +765,9 @@ void _AVX_ExtraInit(void* functions) {
|
|||
// sparse conv funcs
|
||||
coreFunction->MNNGetSparseMatMulPackMode = _AVX_MNNGetSparseMatMulPackMode;
|
||||
coreFunction->MNNAdjustOptimalSparseKernel = _AVX_MNNAdjustOptimalSparseKernel;
|
||||
|
||||
// attention
|
||||
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
|
||||
coreFunction->MNNFlashAttentionUpdateBlockOutput = _AVX_MNNFlashAttentionUpdateBlockOutput;
|
||||
#endif
|
||||
}
|
||||
|
|
|
@ -1064,6 +1064,39 @@ static void _AVX512_MNNAdjustOptimalSparseKernel(int& sparseBlockOC, MNN::CoreFu
|
|||
}
|
||||
}
|
||||
|
||||
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
|
||||
void _AVX512_MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, float* scale, float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes) {
|
||||
// source shape: [headDim/pack, seqLen, pack]
|
||||
// scale & normalizeScale shape: [seqLen]
|
||||
// dest shape: [headDim/pack, seqLen, pack]
|
||||
auto stride0 = plane * pack;
|
||||
|
||||
if (idx > 0) {
|
||||
for (int j = 0; j < depthQuad; ++j) {
|
||||
for (int i = 0; i < plane; ++i) {
|
||||
auto dataNew = Vec::load(src + j * stride0 + i * pack);
|
||||
auto dataOld = Vec::load(dst + j * stride0 + i * pack);
|
||||
auto s = Vec(scale[i]);
|
||||
dataNew = Vec::fma(dataNew, dataOld, s);
|
||||
Vec::save(dst + j * stride0 + i * pack, dataNew);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
memcpy(dst, src, size * bytes);
|
||||
}
|
||||
if (idx == kvBlocks - 1) { // if last subBlock, exp(xi)/sum(exp(xi))
|
||||
for (int j = 0; j < depthQuad; ++j) {
|
||||
for (int i = 0; i < plane; ++i) {
|
||||
auto dataNew = Vec::load(dst + j * stride0 + i * pack);
|
||||
auto ns = Vec(1.0f / normalizeScale[i]);
|
||||
dataNew = dataNew * ns;
|
||||
Vec::save(dst + j * stride0 + i * pack, dataNew);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
void _AVX512_ExtraInit(void* functions) {
|
||||
auto coreFunction = static_cast<MNN::CoreFunctions*>(functions);
|
||||
|
@ -1100,4 +1133,8 @@ void _AVX512_ExtraInit(void* functions) {
|
|||
|
||||
coreFunction->MNNGetSparseMatMulPackMode = _AVX512_MNNGetSparseMatMulPackMode;
|
||||
coreFunction->MNNAdjustOptimalSparseKernel = _AVX512_MNNAdjustOptimalSparseKernel;
|
||||
|
||||
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
|
||||
coreFunction->MNNFlashAttentionUpdateBlockOutput = _AVX512_MNNFlashAttentionUpdateBlockOutput;
|
||||
#endif
|
||||
}
|
||||

@ -81,7 +81,7 @@ void _SSE_MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count);

void _SSE_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose);
void _SSE_MNNReluInt8(int8_t* dst, const int8_t* src, size_t size, ssize_t zeroPoint);
void _SSE_MNNSoftmax(float* dest, const float* source, size_t size);
void _SSE_MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize);
void _SSE_ExtraInit(void* functions);
void _SSE_MNNNorm(float *dst, const float *src, const float *gamma, const float *beta, float epsilon, size_t size, bool RMSNorm);
void _SSE_ImageProcessInit(void* functions, int cpuFlags);
|
|
@ -69,100 +69,68 @@ void _SSE_MNNExpC8(float* dest, const float* source, float* offset, const float*
|
|||
offset[3] = total;
|
||||
}
|
||||
|
||||
void _SSE_MNNSoftmax(float* dest, const float* source, size_t size) {
|
||||
float tmpfloat4[4];
|
||||
int count = static_cast<int32_t>(size / 4);
|
||||
int remain = count * 4;
|
||||
// step 1: get maxValue
|
||||
float maxValue = source[0];
|
||||
if (count > 0) {
|
||||
auto maxVal = _mm_loadu_ps(source);
|
||||
for (int i = 1; i < count; i++) {
|
||||
maxVal = _mm_max_ps(maxVal, _mm_loadu_ps(source + i * 4));
|
||||
void _SSE_MNNSoftmax(float* softmaxDst, float* input, float* runningMax, float* runningSum, float* updateScale, int outside, int reduceSize) {
|
||||
const float xLimit = 87.0f;
|
||||
const float param = 0.6931471805599453f; // ln(2)
|
||||
const float inv_param = 1.0f / param;
|
||||
const int32_t exp_offset = 127;
|
||||
const float exp_scale = 8388608.0f; // 2^23
|
||||
|
||||
for (int k = 0; k < outside; ++k) {
|
||||
float* source = input + k * reduceSize;
|
||||
float* dest = softmaxDst + k * reduceSize;
|
||||
|
||||
float tmpfloat4[4];
|
||||
int count = static_cast<int32_t>(reduceSize / 4);
|
||||
int remain = count * 4;
|
||||
// step 1: get maxValue
|
||||
float maxValue = source[0];
|
||||
float oldMax = maxValue;
|
||||
if (runningMax) {
|
||||
oldMax = runningMax[k];
|
||||
}
|
||||
_mm_storeu_ps(tmpfloat4, maxVal);
|
||||
maxValue = tmpfloat4[0] > tmpfloat4[1] ? tmpfloat4[0] : tmpfloat4[1];
|
||||
maxValue = maxValue > tmpfloat4[2] ? maxValue : tmpfloat4[2];
|
||||
maxValue = maxValue > tmpfloat4[3] ? maxValue : tmpfloat4[3];
|
||||
}
|
||||
for (int i = remain; i < size; i++) {
|
||||
maxValue = maxValue > source[i] ? maxValue : source[i];
|
||||
}
|
||||
|
||||
// step 2: get exp(x - maxValue) and sum(exp(x - maxValue))
|
||||
float sumValue = 0.f;
|
||||
if (count > 0) {
|
||||
auto sumVal = _mm_set1_ps(0.f);
|
||||
auto p0 = _mm_set1_ps(0.6931471805599453);
|
||||
auto p1 = _mm_set1_ps(1.4426950408889634);
|
||||
auto p2 = _mm_set1_ps(1.f);
|
||||
auto p3 = _mm_set1_ps(1.f);
|
||||
auto p4 = _mm_set1_ps(0.5);
|
||||
auto p5 = _mm_set1_ps(0.1666666666666666);
|
||||
auto p6 = _mm_set1_ps(0.041666666666666664);
|
||||
auto p7 = _mm_set1_ps(0.008333333333333333);
|
||||
auto xMax = _mm_set1_ps(87);
|
||||
auto xMin = _mm_set1_ps(-87);
|
||||
// auto basic = _mm_set1_epi32(1 << 23);
|
||||
for (int i = 0; i < count; ++i) {
|
||||
auto x = _mm_sub_ps(_mm_loadu_ps(source + i * 4), _mm_set1_ps(maxValue));
|
||||
x = _mm_max_ps(x, xMin);
|
||||
x = _mm_min_ps(x, xMax);
|
||||
auto div = _mm_mul_ps(x, p1);
|
||||
auto divInt = _mm_cvtps_epi32(div);
|
||||
div = _mm_cvtepi32_ps(divInt);
|
||||
auto div2 = _mm_add_epi32(divInt, _mm_set1_epi32(127));
|
||||
// div2 = _mm_mullo_epi32(div2, basic);
|
||||
div2 = _mm_slli_epi32(div2, 23);
|
||||
auto expBasic = _mm_castsi128_ps(div2);
|
||||
auto xReamin = _mm_sub_ps(x, _mm_mul_ps(div, p0));
|
||||
auto t = xReamin;
|
||||
auto c0 = _mm_mul_ps(p7, t);
|
||||
auto c1 = _mm_add_ps(c0, p6);
|
||||
auto c2 = _mm_mul_ps(c1, t);
|
||||
auto c3 = _mm_add_ps(c2, p5);
|
||||
auto c4 = _mm_mul_ps(c3, t);
|
||||
auto c5 = _mm_add_ps(c4, p4);
|
||||
auto c6 = _mm_mul_ps(c5, t);
|
||||
auto c7 = _mm_add_ps(c6, p3);
|
||||
auto c8 = _mm_mul_ps(c7, t);
|
||||
auto c9 = _mm_add_ps(c8, p2);
|
||||
auto expRemain = c9;
|
||||
auto expRes = _mm_mul_ps(expBasic, expRemain);
|
||||
sumVal = _mm_add_ps(expRes, sumVal);
|
||||
_mm_storeu_ps(dest + 4 * i, expRes);
|
||||
if (count > 0) {
|
||||
auto maxVal = _mm_loadu_ps(source);
|
||||
for (int i = 1; i < count; i++) {
|
||||
maxVal = _mm_max_ps(maxVal, _mm_loadu_ps(source + i * 4));
|
||||
}
|
||||
_mm_storeu_ps(tmpfloat4, maxVal);
|
||||
maxValue = tmpfloat4[0] > tmpfloat4[1] ? tmpfloat4[0] : tmpfloat4[1];
|
||||
maxValue = maxValue > tmpfloat4[2] ? maxValue : tmpfloat4[2];
|
||||
maxValue = maxValue > tmpfloat4[3] ? maxValue : tmpfloat4[3];
|
||||
}
|
||||
for (int i = remain; i < reduceSize; i++) {
|
||||
maxValue = maxValue > source[i] ? maxValue : source[i];
|
||||
}
|
||||
_mm_storeu_ps(tmpfloat4, sumVal);
|
||||
sumValue = tmpfloat4[0] + tmpfloat4[1] + tmpfloat4[2] + tmpfloat4[3];
|
||||
}
|
||||
auto param = 0.6931471805599453;
|
||||
float xLimit = 87;
|
||||
for (int i = remain; i < size; i++) {
|
||||
auto x = source[i] - maxValue;
|
||||
x = x > -xLimit ? x : -xLimit;
|
||||
x = x < xLimit ? x : xLimit;
|
||||
|
||||
int div = (x / param);
|
||||
int div2 = (div + 127) << 23;
|
||||
auto xReamin = x - div * param;
|
||||
float expBasic = *(float*)(&div2);
|
||||
float newMax = ALIMAX(oldMax, maxValue);
|
||||
|
||||
auto t = xReamin;
|
||||
auto expRemain = ((((1.0f / 120 * t + 1.0f / 24) * t + 1.0f / 6) * t + 0.5f) * t + 1.0f) * t + 1.0f;
|
||||
dest[i] = expBasic * expRemain;
|
||||
sumValue += dest[i];
|
||||
}
|
||||
// step 3: get x / sum and store
|
||||
for (int i = 0; i < count; ++i) {
|
||||
// using 1 / ((1 / x) * sum) instead x * (1 / sum) or x / sum for some bugs in intel cpu
|
||||
auto x = _mm_rcp_ps(_mm_loadu_ps(dest + 4 * i));
|
||||
auto y = _mm_set1_ps(sumValue);
|
||||
auto z = _mm_rcp_ps(_mm_mul_ps(x, y));
|
||||
_mm_storeu_ps(dest + 4 * i, z);
|
||||
}
|
||||
sumValue = 1.f / sumValue;
|
||||
for (int i = remain; i < size; i++) {
|
||||
dest[i] *= sumValue;
|
||||
// step 2: get exp(x - newMax) and sum(exp(x - newMax))
|
||||
float exprOffset[4] = {1.0f, 0.0f, 0.0f, 0.0f };
|
||||
exprOffset[2] = -newMax;
|
||||
MNNExp(dest, source, exprOffset, reduceSize);
|
||||
float sumValue = exprOffset[3];
|
||||
|
||||
if (runningMax != nullptr && runningSum != nullptr && updateScale != nullptr) {
|
||||
// === Step 3: Update running variables ===
|
||||
float scale = expf(oldMax - newMax);
|
||||
runningSum[k] = runningSum[k] * scale + sumValue;
|
||||
runningMax[k] = newMax;
|
||||
updateScale[k] = scale;
|
||||
} else {
|
||||
// step 3: get x / sum and store
|
||||
for (int i = 0; i < count; ++i) {
|
||||
// using 1 / ((1 / x) * sum) instead x * (1 / sum) or x / sum for some bugs in intel cpu
|
||||
auto x = _mm_rcp_ps(_mm_loadu_ps(dest + 4 * i));
|
||||
auto y = _mm_set1_ps(sumValue);
|
||||
auto z = _mm_rcp_ps(_mm_mul_ps(x, y));
|
||||
_mm_storeu_ps(dest + 4 * i, z);
|
||||
}
|
||||
auto scale = 1.f / sumValue;
|
||||
for (int i = remain; i < reduceSize; i++) {
|
||||
dest[i] *= scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -324,9 +324,9 @@ bool BufferConvertor::convertToNC4HW4Buffer(const Tensor *buffer, const OpenCLBu
|
|||
auto formattedBufferShape = tensorShapeFormat(buffer);//NHWC
|
||||
std::vector<size_t> imageShape;
|
||||
getImageShape(formattedBufferShape, type, &imageShape);
|
||||
|
||||
|
||||
uint32_t gws[2] = {static_cast<uint32_t>(imageShape[0]), static_cast<uint32_t>(imageShape[1])};
|
||||
|
||||
|
||||
auto runtime = mOpenCLRuntime;
|
||||
std::string kernelName;
|
||||
std::string kernelFile = "buffer_convert_buf";
|
||||
|
@ -360,26 +360,23 @@ bool BufferConvertor::convertToNC4HW4Buffer(const Tensor *buffer, const OpenCLBu
|
|||
default:
|
||||
break;
|
||||
}
|
||||
if (mBufferToImageKernel.get() == nullptr || mBufferToImageKernelName != kernelName) {
|
||||
mBufferToImageKernelName = kernelName;
|
||||
std::set<std::string> buildOptions;
|
||||
if(needTrans) {
|
||||
//buildOptions.emplace("-DBUFFER_FORMAT_INP_TRANS");
|
||||
kernelName += "_floatin";
|
||||
}
|
||||
#ifdef MNN_LOW_MEMORY
|
||||
if (lowMemory) {
|
||||
if (quantBit == 8) {
|
||||
// int8 case
|
||||
buildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT8");
|
||||
} else if (quantBit == 4){
|
||||
// int4 case
|
||||
buildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT4");
|
||||
} else {/* More types to be supported. */}
|
||||
}
|
||||
#endif
|
||||
mBufferToImageKernel = runtime->buildKernelWithCache(kernelFile, kernelName, buildOptions, precision, buffer, image);
|
||||
std::set<std::string> buildOptions;
|
||||
if(needTrans) {
|
||||
//buildOptions.emplace("-DBUFFER_FORMAT_INP_TRANS");
|
||||
kernelName += "_floatin";
|
||||
}
|
||||
#ifdef MNN_LOW_MEMORY
|
||||
if (lowMemory) {
|
||||
if (quantBit == 8) {
|
||||
// int8 case
|
||||
buildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT8");
|
||||
} else if (quantBit == 4){
|
||||
// int4 case
|
||||
buildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT4");
|
||||
} else {/* More types to be supported. */}
|
||||
}
|
||||
#endif
|
||||
mBufferToImageKernel = runtime->buildKernelWithCache(kernelFile, kernelName, buildOptions, precision, buffer, image);
|
||||
auto kernel = mBufferToImageKernel->get();
|
||||
|
||||
uint32_t idx = 0;
|
||||
|
|
|
@ -43,7 +43,7 @@ public:
|
|||
explicit BufferConvertor(OpenCLRuntime *opencl_runtime) : mOpenCLRuntime(opencl_runtime) {
|
||||
}
|
||||
bool convertToNC4HW4Buffer(const Tensor *input, const OpenCLBufferFormat type, Tensor *output, int precision,
|
||||
bool needTrans, bool needWait = false, bool lowMemory = false, int quantBit = 0);
|
||||
bool needTrans, bool needWait = true, bool lowMemory = false, int quantBit = 0);
|
||||
|
||||
private:
|
||||
OpenCLRuntime *mOpenCLRuntime;
|
||||
|
|
|
@ -47,8 +47,13 @@ CLRuntime::CLRuntime(const Backend::Info& info){
|
|||
mPrecision = mInfo.user->precision;
|
||||
mMemory = mInfo.user->memory;
|
||||
}
|
||||
|
||||
mOpenCLRuntime.reset(new OpenCLRuntime(platform_size, platform_id, device_id, context_ptr, hint()));
|
||||
|
||||
// protect
|
||||
if(mPrecision > 2 || mPrecision < 0){
|
||||
mPrecision = BackendConfig::Precision_High;
|
||||
}
|
||||
|
||||
mOpenCLRuntime.reset(new OpenCLRuntime(platform_size, platform_id, device_id, context_ptr, hint()));
|
||||
|
||||
//Whether runtimeError
|
||||
mCLRuntimeError = mOpenCLRuntime->isCreateError();
|
||||
|
@ -206,6 +211,10 @@ Backend* CLRuntime::onCreate(const BackendConfig* config, Backend* origin) const
|
|||
precision = config->precision;
|
||||
memory = config->memory;
|
||||
}
|
||||
// protect
|
||||
if(precision > 2 || precision < 0){
|
||||
precision = BackendConfig::Precision_High;
|
||||
}
|
||||
auto backend = new OpenCLBackend(precision, memory, mInfo.gpuMode, mImagePool, mBufferPool, this);
|
||||
backend->setMetaPtr(pMeta);
|
||||
return backend;
|
||||
|
@ -246,6 +255,10 @@ OpenCLBackend::OpenCLBackend(BackendConfig::PrecisionMode precision, BackendConf
|
|||
} else{
|
||||
mPrecision = BackendConfig::Precision_High;
|
||||
}
|
||||
// protect
|
||||
if(mPrecision > 2 || mPrecision < 0){
|
||||
mPrecision = BackendConfig::Precision_High;
|
||||
}
|
||||
mMemory = memory;
|
||||
// set tuneLevel, memtype, record mode
|
||||
setGpuMode(gpuMode);
|
||||

@ -18,6 +18,8 @@ struct QnnContext {
    Qnn_LogHandle_t logHandle = nullptr;
    Qnn_BackendHandle_t backendHandle = nullptr;
    Qnn_DeviceHandle_t deviceHandle = nullptr;
    int soc_id;
    int dsp_arch;
};

static QnnContext gContext;
@ -94,6 +96,7 @@ bool PluginShapeRaw::compute(InferShapeContext* ctx) {
|
|||
auto dst = ctx->output(i);
|
||||
std::string key = prefix + std::to_string(i);
|
||||
auto attr = ctx->getAttr(key.c_str());
|
||||
|
||||
if (nullptr == attr || nullptr == attr->tensor()) {
|
||||
MNN_ERROR("MNN_QNN: Failed to find raw shape %s.\n", key.c_str());
|
||||
return false;
|
||||
|
@ -886,7 +889,12 @@ ErrorCode QnnBackend::onResizeEnd() {
|
|||
#ifdef QNN_VERBOSE
|
||||
MNN_PRINT("start finalize\n");
|
||||
#endif
|
||||
buildOutputDequant();
|
||||
finalizeGraph();
|
||||
for(auto func : mReleaseFunc){
|
||||
func();
|
||||
}
|
||||
mReleaseFunc.clear();
|
||||
#ifdef QNN_VERBOSE
|
||||
MNN_PRINT("end finalize\n");
|
||||
#endif
|
||||
|
@ -915,7 +923,14 @@ Backend::MemObj* QnnBackend::onAcquire(const Tensor* tensor, StorageType storage
|
|||
}
|
||||
|
||||
Qnn_DataType_t tDataType;
|
||||
MNN_ASSERT((tensor->getType().code == halide_type_float) || (tensor->getType().code == halide_type_int && tensor->getType().bits == 32));
|
||||
Qnn_QuantizeParams_t tQuantizeParams{};
|
||||
tQuantizeParams.encodingDefinition = QNN_DEFINITION_UNDEFINED;
|
||||
tQuantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED;
|
||||
Qnn_ScaleOffset_t tScaleOffsetEncoding;
|
||||
tScaleOffsetEncoding.scale = 0.0f;
|
||||
tScaleOffsetEncoding.offset = 0;
|
||||
auto quant = TensorUtils::getDescribe(tensor)->quantAttr.get();
|
||||
//MNN_ASSERT((tensor->getType().code == halide_type_float) || (tensor->getType().code == halide_type_int && tensor->getType().bits == 32));
|
||||
if (mUseFP16 && tensor->getType().code == halide_type_float) {
|
||||
tDataType = QNN_DATATYPE_FLOAT_16;
|
||||
} else if (tensor->getType().code == halide_type_float) {
|
||||
|
@ -926,13 +941,19 @@ Backend::MemObj* QnnBackend::onAcquire(const Tensor* tensor, StorageType storage
|
|||
MNN_PRINT("MNN_QNN: Not supported data type in <QnnBackend::onAcquire>.\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Qnn_QuantizeParams_t tQuantizeParams{};
|
||||
tQuantizeParams.encodingDefinition = QNN_DEFINITION_UNDEFINED;
|
||||
tQuantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED;
|
||||
Qnn_ScaleOffset_t tScaleOffsetEncoding;
|
||||
tScaleOffsetEncoding.scale = 0.0f;
|
||||
tScaleOffsetEncoding.offset = 0;
|
||||
if(quant != nullptr && TensorUtils::getDescribe(tensor)->type == DataType_DT_INT8) {
|
||||
tQuantizeParams.encodingDefinition = QNN_DEFINITION_DEFINED;
|
||||
tQuantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET;
|
||||
tScaleOffsetEncoding.scale = quant->scale;
|
||||
if(quant->zero != 0){
|
||||
MNN_PRINT("MNN_QNN: Not supported asymmetric quant in <QnnBackend::onAcquire>.\n");
|
||||
return nullptr;
|
||||
}
|
||||
tDataType = QNN_DATATYPE_SFIXED_POINT_8;
|
||||
if (isOutput) {
|
||||
tType = QNN_TENSOR_TYPE_NATIVE;
|
||||
}
|
||||
}
|
||||
tQuantizeParams.scaleOffsetEncoding = tScaleOffsetEncoding;
|
||||
|
||||
std::unique_ptr<Tensor> tempTensor(new Tensor(tensor, gQnnTensorDimType, false));
|
||||
|
@ -947,17 +968,41 @@ Backend::MemObj* QnnBackend::onAcquire(const Tensor* tensor, StorageType storage
|
|||
|
||||
Qnn_Tensor_t * qnnTensor = qnnTensorWrapper->getNativeTensor();
|
||||
CALL_QNN(mRuntime->mQnnInterface.tensorCreateGraphTensor(mQnnGraphHandle, qnnTensor));
|
||||
mQNNTensorWrappers.push_back(qnnTensorWrapper);
|
||||
mTensorMap.insert({TensorUtils::getDescribe(tensor), mTensorCounter});
|
||||
|
||||
if (isInput) {
|
||||
mInputTensorIndexes.push_back(mTensorCounter);
|
||||
qnnTensorWrapper->alloc();
|
||||
}
|
||||
if (isOutput) {
|
||||
mOutputTensorIndexes.push_back(mTensorCounter);
|
||||
qnnTensorWrapper->alloc();
|
||||
if(quant != nullptr && TensorUtils::getDescribe(tensor)->type == DataType_DT_INT8){
mTensorCounter += 1;
std::shared_ptr<Tensor> stageTensor;
stageTensor.reset(Tensor::create<float>(tensor->shape(), nullptr, gQnnTensorDimType));
tName = "QnnTensor_" + std::to_string(mTensorCounter);
tType = QNN_TENSOR_TYPE_APP_READ;
if (mUseFP16 && tensor->getType().code == halide_type_float) {
tDataType = QNN_DATATYPE_FLOAT_16;
} else if (tensor->getType().code == halide_type_float) {
tDataType = QNN_DATATYPE_FLOAT_32;
} else {
MNN_PRINT("MNN_QNN: Not supported data type in <QnnBackend::onAcquire>.\n");
return nullptr;
}
Qnn_QuantizeParams_t tQuantizeParamstmp = QNN_QUANTIZE_PARAMS_INIT;
std::shared_ptr<QNNTensorWrapper> qnnOutputTensorWrapper = QNNTensorWrapper::create(tName, tType, tDataType, tDims, tQuantizeParamstmp);
CALL_QNN(mRuntime->mQnnInterface.tensorCreateGraphTensor(mQnnGraphHandle, qnnOutputTensorWrapper->getNativeTensor()));
mDeQuantOutputTensorMap.insert({TensorUtils::getDescribe(tensor), {tensor, stageTensor}});
mQNNTensorWrappers.push_back(qnnOutputTensorWrapper);
mTensorMap.insert({TensorUtils::getDescribe(const_cast<const Tensor*>(stageTensor.get())), mTensorCounter});
mOutputTensorIndexes.push_back(mTensorCounter);
qnnOutputTensorWrapper->alloc();
} else{
mOutputTensorIndexes.push_back(mTensorCounter);
qnnTensorWrapper->alloc();
}
}
mQNNTensorWrappers.push_back(qnnTensorWrapper);
mTensorMap.insert({TensorUtils::getDescribe(tensor), mTensorCounter});

mTensorCounter += 1;
#ifdef QNN_VERBOSE
@@ -1029,7 +1074,13 @@ void QnnBackend::inputIO(const Tensor* srcTensor, const Tensor* dstTensor) const
}

void QnnBackend::outputIO(const Tensor* srcTensor, const Tensor* dstTensor) const {
int srcIndex = getTensorIdx(srcTensor);
auto iter = mDeQuantOutputTensorMap.find(TensorUtils::getDescribe(srcTensor));
int srcIndex = -1;
if(iter != mDeQuantOutputTensorMap.end()){
srcIndex = getTensorIdx(iter->second.second.get());
} else{
srcIndex = getTensorIdx(srcTensor);
}
std::shared_ptr<QNNTensorWrapper> srcQnnTensorWrapper = mQNNTensorWrappers[srcIndex];
std::shared_ptr<Tensor> srcDataContainer = srcQnnTensorWrapper->getDataContainer();

@@ -1091,7 +1142,7 @@ void QnnBackend::finalizeGraph() {
// set QNN_PROFILE_LEVEL_DETAILED
QnnProfile_Level_t profileLevel = QNN_PROFILE_LEVEL_DETAILED;
MNN_PRINT("[QNN Profile] Creating QNN Profile Handle with DETAILED level.\n");
auto profile_err = mRuntime->mQnnInterface.profileCreate(mQnnContextHandle, profileLevel, &mQnnProfileHandle);
auto profile_err = mRuntime->mQnnInterface.profileCreate(mRuntime->mQnnContextHandle, profileLevel, &mQnnProfileHandle);
if (profile_err != QNN_SUCCESS || mQnnProfileHandle == nullptr) {
MNN_ERROR("[QNN Profile] Failed to create QNN Profile Handle, error: %d\n", (int)profile_err);
mQnnProfileHandle = nullptr;
@@ -1216,6 +1267,27 @@ void QnnBackend::clean() {
mTensorMap.clear();
mInputTensorIndexes.clear();
mOutputTensorIndexes.clear();
mDeQuantOutputTensorMap.clear();
}
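// Appends one QNN "Dequantize" node per entry in mDeQuantOutputTensorMap, wiring each quantized
// graph output to the float staging tensor that was registered for it in onAcquire().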
void QnnBackend::buildOutputDequant(){
Qnn_OpConfigVersion_t mOpConfigVersion = QNN_OPCONFIG_VERSION_1;
std::string mNodeName;
std::string mPackageName = "qti.aisw";
std::string mNodeType;
std::vector<Qnn_Param_t> mParams;
std::vector<Qnn_Tensor_t> mInputs;
std::vector<Qnn_Tensor_t> mOutputs;
for(auto iter : mDeQuantOutputTensorMap){
mNodeType.clear();
mParams.clear();
mInputs.clear();
mOutputs.clear();
mNodeType = "Dequantize";
std::string name = "Dequantize_I_" + std::to_string(getTensorIdx(iter.second.first)) + "_O_" + std::to_string(getTensorIdx(iter.second.second.get()));
mInputs.push_back(*(getNativeTensor(iter.second.first))); // input
mOutputs.push_back(*(getNativeTensor(iter.second.second.get()))); // output
addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
}
}

QnnRuntime::QnnRuntime(const Backend::Info& info, QNN_INTERFACE_VER_TYPE qnnInterface, Qnn_LogHandle_t qnnLogHandle, Qnn_BackendHandle_t qnnBackendHandle, Qnn_DeviceHandle_t qnnDeviceHandle) {

@@ -1324,12 +1396,23 @@ bool QnnRuntime::registerCustomOpPackage(QNN_INTERFACE_VER_TYPE qnnInterface, Qn

class QnnRuntimeCreator : public RuntimeCreator {
public:
virtual Runtime* onCreate(const Backend::Info& info) const {
virtual Runtime* onCreate(const Backend::Info& info) const override {
return QnnRuntime::create(info);
}
virtual bool onValid(Backend::Info& info) const {
virtual bool onValid(Backend::Info& info) const override {
return true;
}
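// Reports the SoC id and Hexagon DSP arch cached in gContext during runtime registration;
// returns false when the key is unknown or the value was never queried from the device.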
virtual bool onGetDeviceInfo(const std::string& deviceKey, std::string& deviceValue) const override {
if(deviceKey == "soc_id" && gContext.soc_id != 0) {
deviceValue = std::to_string(gContext.soc_id);
return true;
}
if(deviceKey == "dsp_arch" && gContext.dsp_arch != 0) {
deviceValue = "v" + std::to_string(gContext.dsp_arch);
return true;
}
return false;
}
};

} // end namespace QNN

@@ -1401,6 +1484,8 @@ void registerQNNRuntimeCreator() {

// Create Device.
Qnn_DeviceHandle_t deviceHandle = nullptr;
QnnHtpDevice_Arch_t dspArch = QNN_HTP_DEVICE_ARCH_NONE;
uint32_t socId = 0;
{
// Check whether the device API is supported.
bool supportDevice = QNN::checkCapability(qnnInterface, QNN_PROPERTY_GROUP_DEVICE);

@@ -1422,10 +1507,8 @@ void registerQNNRuntimeCreator() {
MNN_PRINT("[Warning]: deviceGetPlatformInfo Failed to query platform info");
} else {
QnnDevice_HardwareDeviceInfo_t* hwDeviceInfo = backendPlatformInfoPtr->v1.hwDevices;
QnnHtpDevice_Arch_t arch = hwDeviceInfo->v1.deviceInfoExtension->onChipDevice.arch;
uint32_t socModel = hwDeviceInfo->v1.deviceInfoExtension->onChipDevice.socModel;
MNN_PRINT("Qnn Device soc_id: %d\n", socModel);
MNN_PRINT("Qnn Device dsp_arch: v%d\n", arch);
dspArch = hwDeviceInfo->v1.deviceInfoExtension->onChipDevice.arch;
socId = hwDeviceInfo->v1.deviceInfoExtension->onChipDevice.socModel;
}
}
} else {

@@ -1478,6 +1561,8 @@ void registerQNNRuntimeCreator() {
QNN::gContext.backendHandle = backendHandle;
QNN::gContext.deviceHandle = deviceHandle;
QNN::gContext.logHandle = logHandle;
QNN::gContext.soc_id = socId;
QNN::gContext.dsp_arch = dspArch;

QNN::registerQNNOps();
MNNInsertExtraRuntimeCreator(MNN_FORWARD_NN, new QNN::QnnRuntimeCreator, false);

@@ -78,6 +78,10 @@ public:
std::shared_ptr<QNNTensorWrapper> getTensorWrapper(const Tensor * tensor);
bool useCache() const;
bool getUseFP16() const;
void buildOutputDequant();
void pushReleaseFunc(std::function<void()> func){
mReleaseFunc.push_back(func);
}

private:
void clean();

@@ -109,8 +113,10 @@ private:
mutable int mTensorCounter = 0;
mutable std::vector<std::shared_ptr<QNNTensorWrapper>> mQNNTensorWrappers;
mutable std::map<const Tensor::InsideDescribe::NativeInsideDescribe *, int> mTensorMap;
mutable std::map<const Tensor::InsideDescribe::NativeInsideDescribe *, std::pair<const Tensor*, std::shared_ptr<Tensor>>> mDeQuantOutputTensorMap;
std::vector<int> mInputTensorIndexes;
std::vector<int> mOutputTensorIndexes;
std::vector<std::function<void()>> mReleaseFunc;
};

@@ -134,6 +134,8 @@ void registerQNNOps() {
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
___QNNAttentionCreator__OpType_Attention__();
#endif
___QNNQuantCreator__OpType_FloatToInt8__();
___QNNDeQuantCreator__OpType_Int8ToFloat__();
}

Tensor::DimensionType gQnnTensorDimType = Tensor::TENSORFLOW;

@@ -104,6 +104,8 @@ extern void ___QNNMatMulCreator__OpType_MatMul__();
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
extern void ___QNNAttentionCreator__OpType_Attention__();
#endif
extern void ___QNNQuantCreator__OpType_FloatToInt8__();
extern void ___QNNDeQuantCreator__OpType_Int8ToFloat__();
void registerQNNOps();

extern Tensor::DimensionType gQnnTensorDimType;

@@ -25,7 +25,7 @@ std::shared_ptr<QNNTensorWrapper> QNNTensorWrapper::create(const std::string & n

std::shared_ptr<QNNTensorWrapper> QNNTensorWrapper::createStaticTensor(const std::string & name, Qnn_DataType_t dataType, const std::vector<uint32_t> & dimensions, const void * buffer, Qnn_QuantizeParams_t quantizeParam) {
MNN_ASSERT(!name.empty() && !dimensions.empty() && buffer);
MNN_ASSERT(dataType == QNN_DATATYPE_SFIXED_POINT_8 || dataType == QNN_DATATYPE_INT_32);
MNN_ASSERT(dataType == QNN_DATATYPE_SFIXED_POINT_8 || dataType == QNN_DATATYPE_INT_32 || dataType == QNN_DATATYPE_SFIXED_POINT_32 || dataType == QNN_DATATYPE_UFIXED_POINT_8);

std::shared_ptr<QNNTensorWrapper> tensorWrapper = QNNTensorWrapper::create(name, QNN_TENSOR_TYPE_STATIC, dataType, dimensions, quantizeParam);
uint32_t numElement = 1;

@@ -114,7 +114,8 @@ void * QNNTensorWrapper::alloc() {
dims[i] = (int)mDimensions[i];
}

MNN_ASSERT(mQnnTensor.v1.dataType == QNN_DATATYPE_FLOAT_32 || mQnnTensor.v1.dataType == QNN_DATATYPE_FLOAT_16 || mQnnTensor.v1.dataType == QNN_DATATYPE_INT_32 || mQnnTensor.v1.dataType == QNN_DATATYPE_SFIXED_POINT_8);
MNN_ASSERT(mQnnTensor.v1.dataType == QNN_DATATYPE_FLOAT_32 || mQnnTensor.v1.dataType == QNN_DATATYPE_FLOAT_16 || mQnnTensor.v1.dataType == QNN_DATATYPE_INT_32 || mQnnTensor.v1.dataType == QNN_DATATYPE_SFIXED_POINT_8 || mQnnTensor.v1.dataType == QNN_DATATYPE_SFIXED_POINT_32
|| mQnnTensor.v1.dataType == QNN_DATATYPE_UFIXED_POINT_8);
halide_type_t halideType;

halideType.lanes = 1;

@@ -134,6 +135,15 @@ void * QNNTensorWrapper::alloc() {
case QNN_DATATYPE_SFIXED_POINT_8:
halideType.code = halide_type_int;
halideType.bits = 8;
break;
case QNN_DATATYPE_SFIXED_POINT_32:
halideType.code = halide_type_int;
halideType.bits = 32;
break;
case QNN_DATATYPE_UFIXED_POINT_8:
halideType.code = halide_type_int;
halideType.bits = 8;
break;
default:
break;
}

@ -269,6 +269,10 @@ std::vector<std::string> QNNTranslator::TranslateTensor(const QNNCommandTensor&
|
|||
if (isParam) {
|
||||
result.push_back(QNNTranslator::TranslateParamDataArray(dataNameSymbol, cmdT.dataType, cmdT.clientBuf));
|
||||
}
|
||||
if(hasQuant){
|
||||
std::vector<std::string> linesQuantScaleOffset = TranslateQuantizeScaleOffsetDataArray(tensorNameSymbol, cmdT.quantizeParams, cmdT.rank, cmdT.dimensions);
|
||||
APPEND_VECTOR(result, linesQuantScaleOffset);
|
||||
}
|
||||
result.push_back(" Qnn_Tensor_t " + tensorNameSymbol + " = QNN_TENSOR_INIT;");
|
||||
result.push_back(" {");
|
||||
result.push_back(" " + tensorNameSymbol + ".version = QNN_TENSOR_VERSION_1;");
|
||||
|
@ -456,11 +460,122 @@ std::string QNNTranslator::TranslateParamDataArray(const std::string & dataNameS
|
|||
return result;
|
||||
}
|
||||
|
||||
std::vector<std::string> QNNTranslator::TranslateQuantizeScaleOffsetDataArray(const std::string & tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParams, uint32_t rank, const uint32_t * dimensions){
|
||||
std::vector<std::string> result;
|
||||
if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET){
|
||||
result.push_back(" Qnn_ScaleOffset_t " + tensorNameSymbol + "_axis_scale_offset[] = {");
|
||||
int totalnum = (quantizeParams.axisScaleOffsetEncoding.numScaleOffsets + 3) / 4;
|
||||
for(int i = 0; i < totalnum; ++i){
|
||||
std::string line = " ";
|
||||
for(int j = 0; j < 4; ++j){
|
||||
int index = i * 4 + j;
|
||||
if(index >= quantizeParams.axisScaleOffsetEncoding.numScaleOffsets)
|
||||
break;
|
||||
line += "{.scale= " + std::to_string(quantizeParams.axisScaleOffsetEncoding.scaleOffset[index].scale) + ", .offset= " + std::to_string(quantizeParams.axisScaleOffsetEncoding.scaleOffset[index].offset) + "}, ";
|
||||
}
|
||||
result.push_back(line);
|
||||
}
|
||||
result.push_back(" };");
|
||||
}
|
||||
|
||||
if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET){
|
||||
result.push_back(" float " + tensorNameSymbol + "_bwaxis_scale[] = {");
|
||||
int totalnum = (quantizeParams.bwAxisScaleOffsetEncoding.numElements + 3) / 4;
|
||||
for(int i = 0; i < totalnum; ++i){
|
||||
std::string line = " ";
|
||||
for(int j = 0; j < 4; ++j){
|
||||
int index = i * 4 + j;
|
||||
if(index >= quantizeParams.bwAxisScaleOffsetEncoding.numElements)
|
||||
break;
|
||||
line += std::to_string(quantizeParams.bwAxisScaleOffsetEncoding.scales[index]) + ", ";
|
||||
}
|
||||
result.push_back(line);
|
||||
}
|
||||
result.push_back(" };");
|
||||
if(quantizeParams.bwAxisScaleOffsetEncoding.offsets != nullptr){
|
||||
result.push_back(" int32_t " + tensorNameSymbol + "_bwaxis_offset[] = {");
|
||||
for(int i = 0; i < totalnum; ++i){
|
||||
std::string line = " ";
|
||||
for(int j = 0; j < 4; ++j){
|
||||
int index = i * 4 + j;
|
||||
if(index >= quantizeParams.bwAxisScaleOffsetEncoding.numElements)
|
||||
break;
|
||||
line += std::to_string(quantizeParams.bwAxisScaleOffsetEncoding.offsets[index]) + ", ";
|
||||
}
|
||||
result.push_back(line);
|
||||
}
|
||||
result.push_back(" };");
|
||||
}
|
||||
}
|
||||
|
||||
if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION){
|
||||
int axis = quantizeParams.blockwiseExpansion->axis;
|
||||
int oc = dimensions[axis];
|
||||
int blockSize = quantizeParams.blockwiseExpansion->numBlocksPerAxis;
|
||||
result.push_back(" Qnn_BlockwiseExpansion_t " + tensorNameSymbol + "_blockwiseExpansion = QNN_BLOCKWISE_EXPANSION_INIT;");
|
||||
|
||||
result.push_back(" Qnn_ScaleOffset_t " + tensorNameSymbol + "_blockwiseExpansionScaleOffset[] = {");
|
||||
int totalnum = (oc + 3) / 4;
|
||||
for(int i = 0; i < totalnum; ++i){
|
||||
std::string line = " ";
|
||||
for(int j = 0; j < 4; ++j){
|
||||
int index = i * 4 + j;
|
||||
if(index >= oc)
|
||||
break;
|
||||
line += "{.scale= " + std::to_string(quantizeParams.blockwiseExpansion->scaleOffsets[index].scale) + ", .offset= " + std::to_string(quantizeParams.blockwiseExpansion->scaleOffsets[index].offset) + "}, ";
|
||||
}
|
||||
result.push_back(line);
|
||||
}
|
||||
result.push_back(" };");
|
||||
if(quantizeParams.blockwiseExpansion->blockScaleStorageType == QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8){
|
||||
result.push_back(" uint8_t " + tensorNameSymbol + "_blockwiseExpansionBlockScale[] = {");
|
||||
totalnum = (oc * blockSize + 3) / 4;
|
||||
for(int i = 0; i < totalnum; ++i){
|
||||
std::string line = " ";
|
||||
for(int j = 0; j < 4; ++j){
|
||||
int index = i * 4 + j;
|
||||
if(index >= oc * blockSize)
|
||||
break;
|
||||
line += std::to_string(quantizeParams.blockwiseExpansion->blocksScale8[index]) + ", ";
|
||||
}
|
||||
result.push_back(line);
|
||||
}
|
||||
result.push_back(" };");
|
||||
}else{
|
||||
result.push_back(" uint16_t " + tensorNameSymbol + "_blockwiseExpansionBlockScale[] = {");
|
||||
totalnum = (oc * blockSize + 3) / 4;
|
||||
for(int i = 0; i < totalnum; ++i){
|
||||
std::string line = " ";
|
||||
for(int j = 0; j < 4; ++j){
|
||||
int index = i * 4 + j;
|
||||
if(index >= oc * blockSize)
|
||||
break;
|
||||
line += std::to_string(quantizeParams.blockwiseExpansion->blocksScale16[index]) + ", ";
|
||||
}
|
||||
result.push_back(line);
|
||||
}
|
||||
result.push_back(" };");
|
||||
}
|
||||
result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.axis = " + std::to_string(quantizeParams.blockwiseExpansion->axis) + ";");
|
||||
result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.scaleOffsets = " + tensorNameSymbol + "_blockwiseExpansionScaleOffset;");
|
||||
result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.numBlocksPerAxis = " + std::to_string(quantizeParams.blockwiseExpansion->numBlocksPerAxis) + ";");
|
||||
result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.blockScaleBitwidth = " + std::to_string(quantizeParams.blockwiseExpansion->blockScaleBitwidth) + ";");
|
||||
if(quantizeParams.blockwiseExpansion->blockScaleStorageType == QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8){
|
||||
result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.blockScaleStorageType = QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8;");
|
||||
result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.blocksScale8 = " + tensorNameSymbol + "_blockwiseExpansionBlockScale;");
|
||||
}else{
|
||||
result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.blockScaleStorageType = QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_16;");
|
||||
result.push_back(" " + tensorNameSymbol + "_blockwiseExpansion.blocksScale16 = " + tensorNameSymbol + "_blockwiseExpansionBlockScale;");
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Currently, only support QNN_QUANTIZATION_ENCODING_UNDEFINED, QNN_QUANTIZATION_ENCODING_SCALE_OFFSET.
|
||||
std::vector<std::string> QNNTranslator::TranslateTensorQuantizeParams(const std::string tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParmas) {
|
||||
std::vector<std::string> QNNTranslator::TranslateTensorQuantizeParams(const std::string tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParams) {
|
||||
std::vector<std::string> result;
|
||||
|
||||
if (quantizeParmas.encodingDefinition == QNN_DEFINITION_UNDEFINED) {
|
||||
if (quantizeParams.encodingDefinition == QNN_DEFINITION_UNDEFINED) {
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.encodingDefinition = QNN_DEFINITION_UNDEFINED;");
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED;");
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.scaleOffsetEncoding.scale = 0.0f;");
|
||||
|
@ -468,13 +583,42 @@ std::vector<std::string> QNNTranslator::TranslateTensorQuantizeParams(const std:
|
|||
return result;
|
||||
}
|
||||
|
||||
if (quantizeParmas.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParmas.quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) {
|
||||
if (quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) {
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.encodingDefinition = QNN_DEFINITION_DEFINED;");
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET;");
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.scaleOffsetEncoding.scale = " + std::to_string(quantizeParmas.scaleOffsetEncoding.scale) + ";");
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.scaleOffsetEncoding.offset = " + std::to_string(quantizeParmas.scaleOffsetEncoding.offset) + ";");
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.scaleOffsetEncoding.scale = " + std::to_string(quantizeParams.scaleOffsetEncoding.scale) + ";");
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.scaleOffsetEncoding.offset = " + std::to_string(quantizeParams.scaleOffsetEncoding.offset) + ";");
|
||||
return result;
|
||||
}
|
||||
|
||||
if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET){
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.encodingDefinition = QNN_DEFINITION_DEFINED;");
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET;");
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.axisScaleOffsetEncoding.axis = " + std::to_string(quantizeParams.axisScaleOffsetEncoding.axis) + ";");
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.axisScaleOffsetEncoding.numScaleOffsets = " + std::to_string(quantizeParams.axisScaleOffsetEncoding.numScaleOffsets) + ";");
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.axisScaleOffsetEncoding.scaleOffset = " + tensorNameSymbol + "_axis_scale_offset;");
|
||||
return result;
|
||||
}
|
||||
|
||||
if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET){
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.encodingDefinition = QNN_DEFINITION_DEFINED;");
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET;");
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.bwAxisScaleOffsetEncoding.axis = " + std::to_string(quantizeParams.bwAxisScaleOffsetEncoding.axis) + ";");
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.bwAxisScaleOffsetEncoding.bitwidth = " + std::to_string(quantizeParams.bwAxisScaleOffsetEncoding.bitwidth) + ";");
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.bwAxisScaleOffsetEncoding.numElements = " + std::to_string(quantizeParams.bwAxisScaleOffsetEncoding.numElements) + ";");
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.bwAxisScaleOffsetEncoding.scales = " + tensorNameSymbol + "_bwaxis_scale;");
|
||||
if(quantizeParams.bwAxisScaleOffsetEncoding.offsets != nullptr)
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.bwAxisScaleOffsetEncoding.offset = " + tensorNameSymbol + "_bwaxis_offset;");
|
||||
return result;
|
||||
}
|
||||
|
||||
if(quantizeParams.encodingDefinition == QNN_DEFINITION_DEFINED && quantizeParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION){
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.encodingDefinition = QNN_DEFINITION_DEFINED;");
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION;");
|
||||
result.push_back(" " + tensorNameSymbol + ".v1.quantizeParams.blockwiseExpansion = &" + tensorNameSymbol + "_blockwiseExpansion;");
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
MNN_ERROR("MNN_QNN: Unknown QuantizeParams.\n");
|
||||
|
||||
|
|
|
@@ -69,7 +69,8 @@ private:
static std::string MapDataType(Qnn_DataType_t dataType);
static std::string TranslateDimensionsArray(const std::string & dimensionsNameSymbol, uint32_t rank, const uint32_t * dimensions);
static std::string TranslateParamDataArray(const std::string & dataNameSymbol, Qnn_DataType_t dataType, const Qnn_ClientBuffer_t & clientBuf);
static std::vector<std::string> TranslateTensorQuantizeParams(const std::string tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParmas);
static std::vector<std::string> TranslateQuantizeScaleOffsetDataArray(const std::string & tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParams, uint32_t rank, const uint32_t * dimensions);
static std::vector<std::string> TranslateTensorQuantizeParams(const std::string tensorNameSymbol, const Qnn_QuantizeParams_t & quantizeParams);
static std::vector<std::string> TranslateTensorClientBuf(const std::string & tensorNameSymbol, const std::string & dataNameSymbol, const std::string & sname, const Qnn_ClientBuffer_t & clientBuf, bool hasClientBuf, bool isParam);
// Utility functions used by TranslateNode.
static std::vector<std::string> TranslateNodeParamArray(const std::string & nodeName,const std::string & paramArraySymbol, uint32_t numOfParams, const Qnn_Param_t * params);

@ -11,6 +11,157 @@
|
|||
namespace MNN {
|
||||
namespace QNN {
|
||||
|
||||
void QNNConvDepthwise::isWeightQuantSupported(const Tensor *input, const int oc){
|
||||
Qnn_DataType_t dataType = mBackend->getNativeTensor(input)->v1.dataType;
|
||||
if(mOp->main_as_Convolution2D()->quanParameter() == nullptr){
|
||||
mWeightQuant = false;
|
||||
return;
|
||||
}else{
|
||||
bool hasBias = false;
|
||||
auto bias = mOp->main_as_Convolution2D()->bias();
|
||||
auto biasPtr = (float*)bias->data();
|
||||
for(int i = 0; i < oc; ++i){
|
||||
if(biasPtr[i] != 0.0f){
|
||||
hasBias = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon = ConvolutionCommon::load(mOp, this->backend(), false, true);
|
||||
if(quanCommon->asymmetric || dataType == QNN_DATATYPE_FLOAT_16 || dataType == QNN_DATATYPE_FLOAT_32){
|
||||
// not support asymmetric and mBlockSize > 1 results incorrect now
|
||||
mWeightQuant = false;
|
||||
return;
|
||||
}
|
||||
|
||||
float inputScale = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.scale;
|
||||
int inputOffset = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.offset;
|
||||
if(inputOffset == 0){
|
||||
mWeightQuant = true;
|
||||
}else{
|
||||
if(hasBias){
|
||||
mWeightQuant = false;
|
||||
}else{
|
||||
mWeightQuant = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ErrorCode QNNConvDepthwise::onEncodeQuantDequantDepthConv(Tensor *input, Tensor *output, const int n, const int ic, const int oc) {
|
||||
auto conv2D = mOp->main_as_Convolution2D();
|
||||
auto common = conv2D->common();
|
||||
Qnn_DataType_t dataType = QNN_DATATYPE_FLOAT_32;
|
||||
if(mBackend->getUseFP16()){
|
||||
dataType = QNN_DATATYPE_FLOAT_16;
|
||||
}
|
||||
|
||||
// create dequant input stage tensor
|
||||
this->createStageTensor("DequantInput", dataType, getNHWCShape(input)); // mTempTensorWrappers[2]
|
||||
this->createStageTensor("QuantOutput", dataType, getNHWCShape(output)); // mTempTensorWrappers[3]
|
||||
|
||||
// add nodes
|
||||
{
|
||||
// dequant input
|
||||
{
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mNodeType = "Dequantize";
|
||||
std::string name = mNodeName + "_dequant_input";
|
||||
|
||||
mInputs.push_back(*(mBackend->getNativeTensor(input))); // input
|
||||
mOutputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
}
|
||||
|
||||
if (common->relu() || common->relu6()) {
|
||||
this->createStageTensor("ReluTensor", dataType, getNHWCShape(output)); // mTempTensorWrappers[4]
|
||||
// Stage one
|
||||
{
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mNodeType = "DepthWiseConv2d";
|
||||
std::string name = mNodeName + "_convDepthwise";
|
||||
mParams.push_back(*(mParamTensorWrappers[0]->getNativeParam())); // stride
|
||||
mParams.push_back(*(mParamTensorWrappers[1]->getNativeParam())); // pad_amount
|
||||
mParams.push_back(*(mParamTensorWrappers[2]->getNativeParam())); // dilation
|
||||
|
||||
mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput
|
||||
mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // weight
|
||||
mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // bias
|
||||
|
||||
mOutputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // ReluTensor
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
}
|
||||
|
||||
// Stage two
|
||||
{
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mNodeType = common->relu6() ? "ReluMinMax" : "Relu";
|
||||
std::string name = mNodeName + "_relu";
|
||||
if (common->relu6()) {
|
||||
mParams.push_back(*(mParamScalarWrappers[1]->getNativeParam())); // min_value
|
||||
mParams.push_back(*(mParamScalarWrappers[2]->getNativeParam())); // max_value
|
||||
}
|
||||
mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // ReluTensor
|
||||
mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
}
|
||||
|
||||
} else {
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mNodeType = "DepthWiseConv2d";
|
||||
mParams.push_back(*(mParamTensorWrappers[0]->getNativeParam())); // stride
|
||||
mParams.push_back(*(mParamTensorWrappers[1]->getNativeParam())); // pad_amount
|
||||
mParams.push_back(*(mParamTensorWrappers[2]->getNativeParam())); // dilation
|
||||
|
||||
mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput
|
||||
mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // weight
|
||||
mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // bias
|
||||
|
||||
mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, mNodeName.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
}
|
||||
|
||||
// Quant output
|
||||
{
|
||||
auto QuantOutputTensor = mTempTensorWrappers[3]->getNativeTensor();
|
||||
if(mBackend->getUseFP16()){
|
||||
this->createStageTensor("CastOutput", QNN_DATATYPE_FLOAT_32, getNHWCShape(output));
|
||||
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mNodeType = "Cast";
|
||||
std::string name = mNodeName + "_Cast_Output";
|
||||
|
||||
mInputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput
|
||||
mOutputs.push_back(*(mTempTensorWrappers.back()->getNativeTensor())); // CastOutput
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
QuantOutputTensor = mTempTensorWrappers.back()->getNativeTensor();
|
||||
}
|
||||
{
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mNodeType = "Quantize";
|
||||
std::string name = mNodeName + "_Quant_Output";
|
||||
|
||||
mInputs.push_back(*(QuantOutputTensor)); // stage tensor
|
||||
mOutputs.push_back(*(mBackend->getNativeTensor(output))); // output
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
}
|
||||
}
|
||||
}
|
||||
return NO_ERROR;
|
||||
}
|
||||
|
||||
ErrorCode QNNConvDepthwise::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
|
||||
auto conv2D = mOp->main_as_Convolution2D();
|
||||
auto common = conv2D->common();
|
||||
|
@ -35,7 +186,8 @@ ErrorCode QNNConvDepthwise::onEncode(const std::vector<Tensor *> &inputs, const
|
|||
padTop = std::get<1>(pads); padBottom = std::get<3>(pads); padLeft = std::get<0>(pads); padRight = std::get<2>(pads);
|
||||
dilationH = common->dilateY(); dilationW = common->dilateX();
|
||||
}
|
||||
|
||||
|
||||
isWeightQuantSupported(inputs[0], oc);
|
||||
// create all tensors and params
|
||||
{
|
||||
std::vector<uint32_t> strideData = {(uint32_t)strideH, (uint32_t)strideW};
|
||||
|
@ -49,10 +201,24 @@ ErrorCode QNNConvDepthwise::onEncode(const std::vector<Tensor *> &inputs, const
|
|||
this->createParamScalar("max_value", 6.0f);
|
||||
}
|
||||
|
||||
this->createWeight(dataType, oc, kernelH, kernelW);
|
||||
this->createBias(dataType, oc);
|
||||
this->createWeightAndBias(dataType, inputs[0], oc, kernelH, kernelW);
|
||||
// dequant input and quant output
|
||||
if(mWeightQuant == false && dataType != QNN_DATATYPE_FLOAT_16 && dataType != QNN_DATATYPE_FLOAT_32){
|
||||
return this->onEncodeQuantDequantDepthConv(inputs[0], outputs[0], n, ic, oc);
|
||||
}
|
||||
|
||||
if (common->relu() || common->relu6()) {
|
||||
this->createStageTensor("ReluTensor", dataType, getNHWCShape(outputs[0]));
|
||||
Qnn_QuantizeParams_t quantize = DEFAULT_QUANTIZE_PARAMS;
|
||||
Qnn_ScaleOffset_t tScaleOffsetEncoding;
|
||||
auto quant = TensorUtils::getDescribe(outputs[0])->quantAttr.get();
|
||||
if(quant != nullptr && TensorUtils::getDescribe(outputs[0])->type == DataType_DT_INT8){
|
||||
quantize.encodingDefinition = QNN_DEFINITION_DEFINED;
|
||||
quantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET;
|
||||
tScaleOffsetEncoding.scale = quant->scale;
|
||||
tScaleOffsetEncoding.offset = quant->zero;
|
||||
quantize.scaleOffsetEncoding = tScaleOffsetEncoding;
|
||||
}
|
||||
this->createStageTensor("ReluTensor", dataType, getNHWCShape(outputs[0]), quantize);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -112,43 +278,140 @@ ErrorCode QNNConvDepthwise::onEncode(const std::vector<Tensor *> &inputs, const
|
|||
|
||||
|
||||
|
||||
void QNNConvDepthwise::createWeight(Qnn_DataType_t dataType, int oc, int kernelH, int kernelW) {
|
||||
std::vector<float> weightData;
|
||||
const float* source = nullptr;
|
||||
int weightElementNum = 0;
|
||||
std::shared_ptr<ConvolutionCommon::Int8Common> quanWeight;
|
||||
ConvolutionCommon::getConvParameters(&quanWeight, mBackend, mOp, &source, &weightElementNum);
|
||||
// oc ic h w ---> h w ic oc
|
||||
weightData.resize(weightElementNum);
|
||||
convertWeight(source, (float *) weightData.data(), oc, kernelH, kernelW);
|
||||
this->createStaticFloatTensor("weight", dataType, {(uint32_t)kernelH, (uint32_t)kernelW, 1, (uint32_t)oc}, weightData.data());
|
||||
}
|
||||
void QNNConvDepthwise::createWeightAndBias(Qnn_DataType_t dataType, const Tensor *input, int oc, int kernelH, int kernelW) {
|
||||
if(mWeightQuant){
|
||||
Qnn_QuantizeParams_t weightQuantize{};
|
||||
std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon = ConvolutionCommon::load(mOp, this->backend(), false, true);
|
||||
|
||||
// [TODO] Support asymmetric and other quantBits.
|
||||
MNN_ASSERT(!quanCommon->asymmetric);
|
||||
|
||||
void QNNConvDepthwise::createBias(Qnn_DataType_t dataType, int oc) {
|
||||
int biasElementNum = oc;
|
||||
std::vector<float> biasData;
|
||||
biasData.resize(biasElementNum, 0);
|
||||
auto bias = mOp->main_as_Convolution2D()->bias();
|
||||
if (nullptr != bias) {
|
||||
::memcpy((void *)biasData.data(), (void *)bias->data(), biasElementNum * sizeof(float));
|
||||
}
|
||||
this->createStaticFloatTensor("bias", dataType, {(uint32_t)oc}, biasData.data());
|
||||
}
|
||||
|
||||
|
||||
// oc, h, w ---> h, w, oc
|
||||
void QNNConvDepthwise::convertWeight(const float * src, float * dst, int oc, int kernelH, int kernelW) {
|
||||
for (int c = 0; c < oc; c++) {
|
||||
for (int h = 0; h < kernelH; h++) {
|
||||
for (int w = 0; w < kernelW; w++) {
|
||||
int srcOffset = w + kernelW * (h + kernelH * c);
|
||||
int dstOffset = c + oc * (w + kernelW * h);
|
||||
dst[dstOffset] = src[srcOffset];
|
||||
// create weight
|
||||
const int8_t * source = quanCommon->weight.get();
|
||||
std::vector<int8_t> quantWeightData(oc * kernelH * kernelW, 0);
|
||||
if(quanCommon->canUseInt4){
|
||||
for (int c = 0; c < oc; c++) {
|
||||
for (int h = 0; h < kernelH; h++) {
|
||||
for (int w = 0; w < kernelW; w++) {
|
||||
int srcOffset = w + kernelW * (h + kernelH * c);
|
||||
int dstOffset = c + oc * (w + kernelW * h);
|
||||
if(srcOffset % 2 == 0){
|
||||
quantWeightData[dstOffset] = ((source[srcOffset / 2] >> 4) & 0x0f) - 8;
|
||||
}else{
|
||||
quantWeightData[dstOffset] = (source[srcOffset / 2] & 0x0f) - 8;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}else{
|
||||
convertWeight(source, (int8_t *) quantWeightData.data(), oc, kernelH, kernelW);
|
||||
}
|
||||
|
||||
mDequantAlpha = quanCommon->alpha.get();
|
||||
if(quanCommon->canUseInt4){
|
||||
weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED;
|
||||
weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET;
|
||||
Qnn_BwAxisScaleOffset_t weightBWAxisScaleOffsetEncoding{};
|
||||
weightBWAxisScaleOffsetEncoding.bitwidth = 4;
|
||||
weightBWAxisScaleOffsetEncoding.axis = 3;
|
||||
weightBWAxisScaleOffsetEncoding.numElements = oc;
|
||||
mScale.resize(oc);
|
||||
std::vector<int32_t> OffsetData(oc);
|
||||
for (int i = 0; i < oc; i++) {
|
||||
mScale[i] = mDequantAlpha[i];
|
||||
}
|
||||
weightBWAxisScaleOffsetEncoding.scales = mScale.data();
|
||||
weightQuantize.bwAxisScaleOffsetEncoding = weightBWAxisScaleOffsetEncoding;
|
||||
|
||||
this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)kernelH, (uint32_t)kernelW, 1, (uint32_t)oc}, (void *) quantWeightData.data(), weightQuantize);
|
||||
std::function<void()> mReleaseWeightScaleOffset = [&](){
|
||||
std::vector<float>().swap(mScale);
|
||||
};
|
||||
mBackend->pushReleaseFunc(mReleaseWeightScaleOffset);
|
||||
}else{
|
||||
weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED;
|
||||
weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET;
|
||||
Qnn_AxisScaleOffset_t weightAxisScaleOffsetEncoding{};
|
||||
weightAxisScaleOffsetEncoding.axis = 3;
|
||||
weightAxisScaleOffsetEncoding.numScaleOffsets = oc;
|
||||
mScaleOffsetData.resize(oc);
|
||||
for (int i = 0; i < oc; i++) {
|
||||
mScaleOffsetData[i].scale = mDequantAlpha[i];
|
||||
mScaleOffsetData[i].offset = 0;
|
||||
}
|
||||
weightAxisScaleOffsetEncoding.scaleOffset = mScaleOffsetData.data();
|
||||
weightQuantize.axisScaleOffsetEncoding = weightAxisScaleOffsetEncoding;
|
||||
|
||||
this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)kernelH, (uint32_t)kernelW, 1, (uint32_t)oc}, (void *) quantWeightData.data(), weightQuantize);
|
||||
|
||||
std::function<void()> mReleaseWeightScaleOffset = [&](){
|
||||
std::vector<Qnn_ScaleOffset_t>().swap(mScaleOffsetData);
|
||||
};
|
||||
mBackend->pushReleaseFunc(mReleaseWeightScaleOffset);
|
||||
}
|
||||
// create bias
|
||||
{
|
||||
float inputScale = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.scale;
|
||||
int inputOffset = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.offset;
|
||||
std::vector<int> biasData;
|
||||
biasData.resize(oc, 0);
|
||||
|
||||
Qnn_QuantizeParams_t biasQuantize{};
|
||||
biasQuantize.encodingDefinition = QNN_DEFINITION_DEFINED;
|
||||
biasQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET;
|
||||
Qnn_AxisScaleOffset_t biasAxisScaleOffsetEncoding{};
|
||||
biasAxisScaleOffsetEncoding.axis = 0;
|
||||
biasAxisScaleOffsetEncoding.numScaleOffsets = oc;
|
||||
mBiasScaleOffsetData.resize(oc);
|
||||
|
||||
auto bias = mOp->main_as_Convolution2D()->bias();
|
||||
auto biasPtr = (float*)bias->data();
|
||||
if (nullptr != bias) {
|
||||
for(int i = 0; i < oc; ++i){
|
||||
float biasScale = inputScale * mDequantAlpha[i];
|
||||
mBiasScaleOffsetData[i].scale = 0.f;
|
||||
mBiasScaleOffsetData[i].offset = 0;
|
||||
if(fabs(biasPtr[i]) < 0.000001 || fabs(biasScale) < 0.000001){
|
||||
biasData[i] = 0;
|
||||
} else{
|
||||
biasData[i] = (int)(biasPtr[i] / (biasScale));
|
||||
}
|
||||
}
|
||||
}
|
||||
biasAxisScaleOffsetEncoding.scaleOffset = mBiasScaleOffsetData.data();
|
||||
biasQuantize.axisScaleOffsetEncoding = biasAxisScaleOffsetEncoding;
|
||||
|
||||
this->createStaticTensor("bias", QNN_DATATYPE_SFIXED_POINT_32, {(uint32_t)oc}, biasData.data(), biasQuantize);
|
||||
std::function<void()> mReleaseBiasScaleOffset = [&](){
|
||||
std::vector<Qnn_ScaleOffset_t>().swap(mBiasScaleOffsetData);
|
||||
};
|
||||
mBackend->pushReleaseFunc(mReleaseBiasScaleOffset);
|
||||
}
|
||||
}else{
|
||||
Qnn_DataType_t floatDatatype = QNN_DATATYPE_FLOAT_32;
|
||||
if(mBackend->getUseFP16()){
|
||||
floatDatatype = QNN_DATATYPE_FLOAT_16;
|
||||
}
|
||||
std::vector<float> weightData;
|
||||
const float* source = nullptr;
|
||||
int weightElementNum = 0;
|
||||
std::shared_ptr<ConvolutionCommon::Int8Common> quanWeight;
|
||||
ConvolutionCommon::getConvParameters(&quanWeight, mBackend, mOp, &source, &weightElementNum);
|
||||
// oc ic h w ---> h w ic oc
|
||||
weightData.resize(weightElementNum);
|
||||
convertWeight(source, (float *) weightData.data(), oc, kernelH, kernelW);
|
||||
this->createStaticFloatTensor("weight", floatDatatype, {(uint32_t)kernelH, (uint32_t)kernelW, 1, (uint32_t)oc}, weightData.data());
|
||||
|
||||
// create bias
|
||||
std::vector<float> biasData;
|
||||
biasData.resize(oc, 0);
|
||||
auto bias = mOp->main_as_Convolution2D()->bias();
|
||||
if (nullptr != bias) {
|
||||
::memcpy((void *)biasData.data(), (void *)bias->data(), oc * sizeof(float));
|
||||
}
|
||||
this->createStaticFloatTensor("bias", floatDatatype, {(uint32_t)oc}, biasData.data());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class QNNConvDepthwiseCreator : public QnnBackend::Creator {
|
||||
public:
|
||||
|
|
|
@@ -19,10 +19,28 @@ class QNNConvDepthwise : public QNNCommonExecution {
public:
QNNConvDepthwise(Backend *backend, const Op *op) : QNNCommonExecution(backend, op) {}
virtual ErrorCode onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
ErrorCode onEncodeQuantDequantDepthConv(Tensor *input, Tensor *output, const int n, const int ic, const int oc);
private:
void createWeight(Qnn_DataType_t dataType, int oc, int kernelH, int kernelW);
void createBias(Qnn_DataType_t dataType, int oc);
void convertWeight(const float * src, float * dst, int oc, int kernelH, int kernelW);
template <typename T>
void convertWeight(const T * src, T * dst, int oc, int kernelH, int kernelW) {
for (int c = 0; c < oc; c++) {
for (int h = 0; h < kernelH; h++) {
for (int w = 0; w < kernelW; w++) {
int srcOffset = w + kernelW * (h + kernelH * c);
int dstOffset = c + oc * (w + kernelW * h);
dst[dstOffset] = src[srcOffset];
}
}
}
}
void isWeightQuantSupported(const Tensor *input, const int oc);
void createWeightAndBias(Qnn_DataType_t dataType, const Tensor *input, int oc, int kernelH, int kernelW);
std::vector<float> mScale;
std::vector<Qnn_ScaleOffset_t> mScaleOffsetData;
std::vector<Qnn_ScaleOffset_t> mBiasScaleOffsetData;
std::vector<uint8_t> mBlockScale;
float *mDequantAlpha = nullptr;
bool mWeightQuant = false;
};

} // end namespace QNN

@ -21,6 +21,63 @@ static std::pair<int, int> closest_factors(int n) {
|
|||
}
|
||||
return {1, n};
|
||||
}
|
||||
|
||||
void QNNConvolution::isWeightQuantSupported(const Tensor *input, const int ic, const int oc){
|
||||
Qnn_DataType_t dataType = mBackend->getNativeTensor(input)->v1.dataType;
|
||||
if(mOp->main_as_Convolution2D()->quanParameter() == nullptr){
|
||||
mWeightQuant = false;
|
||||
return;
|
||||
}else{
|
||||
bool hasBias = false;
|
||||
auto bias = mOp->main_as_Convolution2D()->bias();
|
||||
auto biasPtr = (float*)bias->data();
|
||||
for(int i = 0; i < oc; ++i){
|
||||
if(biasPtr[i] != 0.0f){
|
||||
hasBias = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon = ConvolutionCommon::load(mOp, this->backend(), false, true);
|
||||
int totalCount = quanCommon->alpha.size();
|
||||
mBlockSize = totalCount / oc;
|
||||
if(quanCommon->asymmetric){
|
||||
// not support asymmetric and mBlockSize > 1 results incorrect now
|
||||
mWeightQuant = false;
|
||||
return;
|
||||
}
|
||||
|
||||
if(dataType == QNN_DATATYPE_FLOAT_16 || dataType == QNN_DATATYPE_FLOAT_32){
|
||||
if(mIsMatMul && mBlockSize == 1){
|
||||
mWeightQuant = true;
|
||||
}else{
|
||||
mWeightQuant = false;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
float inputScale = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.scale;
|
||||
int inputOffset = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.offset;
|
||||
if(inputOffset == 0){
|
||||
mWeightQuant = true;
|
||||
}else{
|
||||
if(hasBias){
|
||||
mWeightQuant = false;
|
||||
}else{
|
||||
mWeightQuant = true;
|
||||
}
|
||||
}
|
||||
|
||||
if(mBlockSize > 1 && mWeightQuant){
|
||||
if(mIs1x1Conv && hasBias == false && (ic / mBlockSize) >= 16){
|
||||
mWeightQuant = true;
|
||||
}else{
|
||||
mWeightQuant = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ErrorCode QNNConvolution::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
|
||||
auto conv2D = mOp->main_as_Convolution2D();
|
||||
auto common = conv2D->common();
|
||||
|
@ -46,32 +103,16 @@ ErrorCode QNNConvolution::onEncode(const std::vector<Tensor *> &inputs, const st
|
|||
dilationH = common->dilateY(); dilationW = common->dilateX();
|
||||
group = common->group();
|
||||
}
|
||||
|
||||
const float * weightSource = nullptr;
|
||||
std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
|
||||
if (mOp->main_as_Convolution2D()->quanParameter()) {
|
||||
bool forceFloat = (common->kernelX() == 1 && common->kernelY() == 1) ? false : true;
|
||||
|
||||
quanCommon = ConvolutionCommon::load(mOp, this->backend(), forceFloat);
|
||||
if (quanCommon->weightFloat.get() == nullptr) {
|
||||
// [TODO] Support asymmetric and other quantBits.
|
||||
// isQuantWeight && symmetric quantization && int8 quantization && 1x1 conv
|
||||
if (quanCommon->asymmetric || quanCommon->canUseInt4) {
|
||||
return NOT_SUPPORT;
|
||||
}
|
||||
return this->onEncodeQuant(inputs[0], outputs[0], n, ih, iw, ic, oc, quanCommon);
|
||||
} else {
|
||||
weightSource = quanCommon->weightFloat.get();
|
||||
}
|
||||
} else {
|
||||
int weightElementNum;
|
||||
ConvolutionCommon::getConvParameters(&quanCommon, mBackend, mOp, &weightSource, &weightElementNum);
|
||||
mIs1x1Conv = kernelW==1 && strideH==1 && \
|
||||
strideW==1 && dilationH==1 && dilationW==1 && group==1 && \
|
||||
padTop==0 && padBottom==0 && padLeft==0 && padRight==0;
|
||||
mIsMatMul = ih==1 && iw==1 && oh==1 && ow==1 && mIs1x1Conv;
|
||||
isWeightQuantSupported(inputs[0], ic, oc);
|
||||
|
||||
if(mIsMatMul && mWeightQuant && (dataType == QNN_DATATYPE_FLOAT_16 || dataType == QNN_DATATYPE_FLOAT_32)){
|
||||
return onEncodeFpAIntBMatMul(inputs[0], outputs[0], n, ih, iw, ic, oc);
|
||||
}
|
||||
|
||||
#ifdef QNN_VERBOSE
|
||||
MNN_PRINT("n:%d, ih:%d, iw:%d, ic:%d, oh:%d, ow:%d, oc:%d, kernelH:%d, kernelW:%d, dilationH:%d, dilationW:%d, strideH:%d, strideW:%d, group:%d, pad:%d %d %d %d\n", n, ih, iw, ic, oh, ow, oc, kernelH, kernelW, dilationH, \
|
||||
dilationW, strideH, strideW, group, padTop, padBottom, padLeft, padRight);
|
||||
#endif
|
||||
|
||||
// create all tensors and params
|
||||
{
|
||||
std::vector<uint32_t> strideData = {(uint32_t)strideH, (uint32_t)strideW};
|
||||
|
@ -85,12 +126,26 @@ ErrorCode QNNConvolution::onEncode(const std::vector<Tensor *> &inputs, const st
|
|||
this->createParamScalar("min_value", 0.0f);
|
||||
this->createParamScalar("max_value", 6.0f);
|
||||
}
|
||||
}
|
||||
|
||||
this->createWeight(dataType, oc, ic, kernelH, kernelW, group, weightSource);
|
||||
this->createBias(dataType, oc);
|
||||
if (common->relu() || common->relu6()) {
|
||||
this->createStageTensor("ReluTensor", dataType, getNHWCShape(outputs[0]));
|
||||
this->createWeightAndBias(dataType, inputs[0], oc, ic, kernelH, kernelW, group);
|
||||
// dequant input and quant output
|
||||
if(mWeightQuant == false && dataType != QNN_DATATYPE_FLOAT_16 && dataType != QNN_DATATYPE_FLOAT_32){
|
||||
return this->onEncodeQuantDequantConv(inputs[0], outputs[0], n, ic, oc);
|
||||
}
|
||||
|
||||
if (common->relu() || common->relu6()) {
|
||||
Qnn_QuantizeParams_t quantize = DEFAULT_QUANTIZE_PARAMS;
|
||||
Qnn_ScaleOffset_t tScaleOffsetEncoding;
|
||||
auto quant = TensorUtils::getDescribe(outputs[0])->quantAttr.get();
|
||||
if(quant != nullptr && TensorUtils::getDescribe(outputs[0])->type == DataType_DT_INT8){
|
||||
quantize.encodingDefinition = QNN_DEFINITION_DEFINED;
|
||||
quantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET;
|
||||
tScaleOffsetEncoding.scale = quant->scale;
|
||||
tScaleOffsetEncoding.offset = quant->zero;
|
||||
quantize.scaleOffsetEncoding = tScaleOffsetEncoding;
|
||||
}
|
||||
this->createStageTensor("ReluTensor", dataType, getNHWCShape(outputs[0]), quantize);
|
||||
}
|
||||
|
||||
// add nodes
|
||||
|
@ -131,13 +186,34 @@ ErrorCode QNNConvolution::onEncode(const std::vector<Tensor *> &inputs, const st
|
|||
}
|
||||
|
||||
} else {
|
||||
bool isMatmul = ih==1 && iw==1 && oh==1 && ow==1 && kernelH==1 && kernelW==1 && strideH==1 && \
|
||||
strideW==1 && dilationH==1 && dilationW==1 && group==1 && \
|
||||
padTop==0 && padBottom==0 && padLeft==0 && padRight==0;
|
||||
if(isMatmul && n > 1) {
|
||||
if(mIsMatMul && n > 1) {
|
||||
auto num = closest_factors(n);
|
||||
this->createStageTensor("InputReshapeTensor", dataType, std::vector<int>({1, num.first, num.second, ic}));
|
||||
this->createStageTensor("OutputReshapeTensor", dataType, std::vector<int>({1, num.first, num.second, oc}));
|
||||
{
|
||||
Qnn_QuantizeParams_t quantize = DEFAULT_QUANTIZE_PARAMS;
|
||||
Qnn_ScaleOffset_t tScaleOffsetEncoding;
|
||||
auto quant = TensorUtils::getDescribe(inputs[0])->quantAttr.get();
|
||||
if(quant != nullptr && TensorUtils::getDescribe(inputs[0])->type == DataType_DT_INT8){
|
||||
quantize.encodingDefinition = QNN_DEFINITION_DEFINED;
|
||||
quantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET;
|
||||
tScaleOffsetEncoding.scale = quant->scale;
|
||||
tScaleOffsetEncoding.offset = quant->zero;
|
||||
quantize.scaleOffsetEncoding = tScaleOffsetEncoding;
|
||||
}
|
||||
this->createStageTensor("InputReshapeTensor", dataType, std::vector<int>({1, num.first, num.second, ic}), quantize);
|
||||
}
|
||||
{
|
||||
Qnn_QuantizeParams_t quantize = DEFAULT_QUANTIZE_PARAMS;
|
||||
Qnn_ScaleOffset_t tScaleOffsetEncoding;
|
||||
auto quant = TensorUtils::getDescribe(outputs[0])->quantAttr.get();
|
||||
if(quant != nullptr && TensorUtils::getDescribe(outputs[0])->type == DataType_DT_INT8){
|
||||
quantize.encodingDefinition = QNN_DEFINITION_DEFINED;
|
||||
quantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_SCALE_OFFSET;
|
||||
tScaleOffsetEncoding.scale = quant->scale;
|
||||
tScaleOffsetEncoding.offset = quant->zero;
|
||||
quantize.scaleOffsetEncoding = tScaleOffsetEncoding;
|
||||
}
|
||||
this->createStageTensor("OutputReshapeTensor", dataType, std::vector<int>({1, num.first, num.second, oc}), quantize);
|
||||
}
|
||||
#ifdef QNN_VERBOSE
|
||||
MNN_PRINT("Matmul2Conv, start reshape batch:%d -> %dx%d\n", n, num.first, num.second);
|
||||
#endif
|
||||
|
@ -208,52 +284,273 @@ ErrorCode QNNConvolution::onEncode(const std::vector<Tensor *> &inputs, const st
|
|||
return NO_ERROR;
|
||||
}
|
||||
|
||||
ErrorCode QNNConvolution::onEncodeQuantDequantConv(Tensor *input, Tensor *output, const int n, const int ic, const int oc) {
|
||||
auto conv2D = mOp->main_as_Convolution2D();
|
||||
auto common = conv2D->common();
|
||||
Qnn_DataType_t dataType = QNN_DATATYPE_FLOAT_32;
|
||||
if(mBackend->getUseFP16()){
|
||||
dataType = QNN_DATATYPE_FLOAT_16;
|
||||
}
|
||||
|
||||
// create dequant input stage tensor
|
||||
this->createStageTensor("DequantInput", dataType, getNHWCShape(input)); // mTempTensorWrappers[2]
|
||||
this->createStageTensor("QuantOutput", dataType, getNHWCShape(output)); // mTempTensorWrappers[3]
|
||||
|
||||
// add nodes
|
||||
{
|
||||
// dequant input
|
||||
{
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mNodeType = "Dequantize";
|
||||
std::string name = mNodeName + "_dequant_input";
|
||||
|
||||
mInputs.push_back(*(mBackend->getNativeTensor(input))); // input
|
||||
mOutputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
}
|
||||
|
||||
if (common->relu() || common->relu6()) {
|
||||
this->createStageTensor("ReluTensor", dataType, getNHWCShape(output)); // mTempTensorWrappers[4]
|
||||
// Stage one
|
||||
{
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mNodeType = "Conv2d";
|
||||
std::string name = mNodeName + "_conv";
|
||||
mParams.push_back(*(mParamTensorWrappers[0]->getNativeParam())); // stride
|
||||
mParams.push_back(*(mParamTensorWrappers[1]->getNativeParam())); // pad_amount
|
||||
mParams.push_back(*(mParamTensorWrappers[2]->getNativeParam())); // dilation
|
||||
mParams.push_back(*(mParamScalarWrappers[0]->getNativeParam())); // group
|
||||
|
||||
mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput
|
||||
mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // weight
|
||||
mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // bias
|
||||
|
||||
mOutputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // ReluTensor
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
}
|
||||
|
||||
ErrorCode QNNConvolution::onEncodeQuant(Tensor * input, Tensor * output, int n, int h, int w, int ic, int oc, std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon) {
|
||||
// Stage two
|
||||
{
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mNodeType = common->relu6() ? "ReluMinMax" : "Relu";
|
||||
std::string name = mNodeName + "_relu";
|
||||
if (common->relu6()) {
|
||||
mParams.push_back(*(mParamScalarWrappers[1]->getNativeParam())); // min_value
|
||||
mParams.push_back(*(mParamScalarWrappers[2]->getNativeParam())); // max_value
|
||||
}
|
||||
mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // ReluTensor
|
||||
mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
}
|
||||
|
||||
} else {
|
||||
if(mIsMatMul && n > 1) {
|
||||
auto num = closest_factors(n);
|
||||
this->createStageTensor("InputReshapeTensor", dataType, std::vector<int>({1, num.first, num.second, ic})); // mTempTensorWrappers[4]
|
||||
this->createStageTensor("OutputReshapeTensor", dataType, std::vector<int>({1, num.first, num.second, oc})); // mTempTensorWrappers[5]
|
||||
#ifdef QNN_VERBOSE
|
||||
MNN_PRINT("Matmul2Conv, start reshape batch:%d -> %dx%d\n", n, num.first, num.second);
|
||||
#endif
|
||||
// reshape input
|
||||
{
|
||||
std::string name = mNodeName + "_input_reshape";
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mNodeType = "Reshape";
|
||||
|
||||
mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput
|
||||
mOutputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // InputReshapeTensor
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
}
|
||||
// conv2d
|
||||
{
|
||||
std::string name = mNodeName;
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mNodeType = "Conv2d";
|
||||
|
||||
mParams.push_back(*(mParamTensorWrappers[0]->getNativeParam())); // stride
|
||||
mParams.push_back(*(mParamTensorWrappers[1]->getNativeParam())); // pad_amount
|
||||
mParams.push_back(*(mParamTensorWrappers[2]->getNativeParam())); // dilation
|
||||
mParams.push_back(*(mParamScalarWrappers[0]->getNativeParam())); // group
|
||||
|
||||
mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // InputReshapeTensor
|
||||
mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // weight
|
||||
mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // bias
|
||||
|
||||
mOutputs.push_back(*(mTempTensorWrappers[5]->getNativeTensor())); // OutputReshapeTensor
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
}
|
||||
|
||||
// reshape output
|
||||
{
|
||||
std::string name = mNodeName + "_output_reshape";
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mNodeType = "Reshape";
|
||||
|
||||
mInputs.push_back(*(mTempTensorWrappers[5]->getNativeTensor())); // OutputReshapeTensor
|
||||
mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
}
|
||||
} else{
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mNodeType = "Conv2d";
|
||||
mParams.push_back(*(mParamTensorWrappers[0]->getNativeParam())); // stride
|
||||
mParams.push_back(*(mParamTensorWrappers[1]->getNativeParam())); // pad_amount
|
||||
mParams.push_back(*(mParamTensorWrappers[2]->getNativeParam())); // dilation
|
||||
mParams.push_back(*(mParamScalarWrappers[0]->getNativeParam())); // group
|
||||
|
||||
mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // DequantInput
|
||||
mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // weight
|
||||
mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // bias
|
||||
|
||||
mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, mNodeName.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
}
|
||||
}
|
||||
|
||||
// Quant output
|
||||
{
|
||||
auto QuantOutputTensor = mTempTensorWrappers[3]->getNativeTensor();
|
||||
if(mBackend->getUseFP16()){
|
||||
this->createStageTensor("CastOutput", QNN_DATATYPE_FLOAT_32, getNHWCShape(output));
|
||||
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mNodeType = "Cast";
|
||||
std::string name = mNodeName + "_Cast_Output";
|
||||
|
||||
mInputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // QuantOutput
|
||||
mOutputs.push_back(*(mTempTensorWrappers.back()->getNativeTensor())); // CastOutput
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
QuantOutputTensor = mTempTensorWrappers.back()->getNativeTensor();
|
||||
}
|
||||
{
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mNodeType = "Quantize";
|
||||
std::string name = mNodeName + "_Quant_Output";
|
||||
|
||||
mInputs.push_back(*(QuantOutputTensor)); // stage tensor
|
||||
mOutputs.push_back(*(mBackend->getNativeTensor(output))); // output
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
}
|
||||
}
|
||||
}
|
||||
return NO_ERROR;
|
||||
}
|
||||
|
||||
ErrorCode QNNConvolution::onEncodeFpAIntBMatMul(Tensor * input, Tensor * output, int n, int h, int w, int ic, int oc) {
|
||||
// create parameters and stage tensors
|
||||
auto conv2D = mOp->main_as_Convolution2D();
|
||||
auto common = conv2D->common();
|
||||
Qnn_DataType_t dataType = mBackend->getNativeTensor(input)->v1.dataType;
|
||||
{
|
||||
bool transposeWeightFlag = true;
|
||||
this->createParamScalar("transpose_in1", transposeWeightFlag);
|
||||
|
||||
|
||||
std::vector<uint32_t> tempInputShape = {(uint32_t) n * h * w , (uint32_t) ic};
|
||||
std::vector<uint32_t> tempOutputShape = {(uint32_t) n * h * w , (uint32_t) oc};
|
||||
this->createStageTensor("tempInput", QNN_DATATYPE_FLOAT_16, tempInputShape);
|
||||
this->createStageTensor("tempOutput", QNN_DATATYPE_FLOAT_16, tempOutputShape);
|
||||
this->createStageTensor("tempInput", dataType, tempInputShape);
|
||||
this->createStageTensor("tempOutput", dataType, tempOutputShape);
|
||||
|
||||
// create weight
|
||||
const int8_t * source = quanCommon->weight.get();
|
||||
std::vector<int8_t> quantWeightData(oc * ic, 0);
|
||||
::memcpy(quantWeightData.data(), source, oc * ic * sizeof(int8_t));
|
||||
|
||||
float * dequantAlpha = quanCommon->alpha.get();
|
||||
|
||||
Qnn_QuantizeParams_t weightQuantize{};
|
||||
weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED;
|
||||
weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET;
|
||||
Qnn_AxisScaleOffset_t weightAxisScaleOffsetEncoding{};
|
||||
weightAxisScaleOffsetEncoding.axis = 0;
|
||||
weightAxisScaleOffsetEncoding.numScaleOffsets = oc;
|
||||
std::vector<Qnn_ScaleOffset_t> scaleOffsetData(oc);
|
||||
for (int i = 0; i < oc; i++) {
|
||||
if (quanCommon->asymmetric) {
|
||||
scaleOffsetData[i].scale = dequantAlpha[2 * i + 1];
|
||||
// scaleOffsetData[i].offset = (int) dequantAlpha[2 * i + 0];
|
||||
scaleOffsetData[i].offset = 0;
|
||||
} else {
|
||||
scaleOffsetData[i].scale = dequantAlpha[i];
|
||||
scaleOffsetData[i].offset = 0;
|
||||
// create weight and bias
|
||||
{
|
||||
Qnn_QuantizeParams_t weightQuantize{};
|
||||
std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon = ConvolutionCommon::load(mOp, this->backend(), false, true);
|
||||
MNN_ASSERT(!quanCommon->asymmetric);
|
||||
const int8_t * source = quanCommon->weight.get();
|
||||
std::vector<int8_t> quantWeightData(oc * ic, 0);
|
||||
if(quanCommon->canUseInt4){
|
||||
for (int o = 0; o < oc; o++) {
|
||||
for (int i = 0; i < ic; i++) {
|
||||
uint32_t srcOffset = o * ic + i;
|
||||
uint32_t dstOffset = srcOffset;
|
||||
if(srcOffset % 2 == 0){
|
||||
quantWeightData[dstOffset] = ((source[srcOffset / 2] >> 4) & 0x0f) - 8;
|
||||
}else{
|
||||
quantWeightData[dstOffset] = (source[srcOffset / 2] & 0x0f) - 8;
|
||||
}
|
||||
}
|
||||
}
|
||||
}else{
|
||||
::memcpy(quantWeightData.data(), source, oc * ic * sizeof(int8_t));
|
||||
}
|
||||
mDequantAlpha = quanCommon->alpha.get();
|
||||
int totalCount = quanCommon->alpha.size();
|
||||
mBlockSize = totalCount / oc;
|
||||
int blockNum = ic / mBlockSize;
|
||||
if(quanCommon->canUseInt4){
|
||||
weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED;
|
||||
weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET;
|
||||
Qnn_BwAxisScaleOffset_t weightBWAxisScaleOffsetEncoding{};
|
||||
weightBWAxisScaleOffsetEncoding.bitwidth = 4;
|
||||
weightBWAxisScaleOffsetEncoding.axis = 0;
|
||||
weightBWAxisScaleOffsetEncoding.numElements = oc;
|
||||
mScale.resize(oc);
|
||||
std::vector<int32_t> OffsetData(oc);
|
||||
for (int i = 0; i < oc; i++) {
|
||||
mScale[i] = mDequantAlpha[i];
|
||||
}
|
||||
weightBWAxisScaleOffsetEncoding.scales = mScale.data();
|
||||
weightQuantize.bwAxisScaleOffsetEncoding = weightBWAxisScaleOffsetEncoding;
|
||||
|
||||
this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)oc, (uint32_t)ic}, (void *) quantWeightData.data(), weightQuantize);
|
||||
std::function<void()> mReleaseWeightScaleOffset = [&](){
|
||||
std::vector<float>().swap(mScale);
|
||||
};
|
||||
mBackend->pushReleaseFunc(mReleaseWeightScaleOffset);
|
||||
}else{
|
||||
weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED;
|
||||
weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET;
|
||||
Qnn_AxisScaleOffset_t weightAxisScaleOffsetEncoding{};
|
||||
weightAxisScaleOffsetEncoding.axis = 0;
|
||||
weightAxisScaleOffsetEncoding.numScaleOffsets = oc;
|
||||
mScaleOffsetData.resize(oc);
|
||||
for (int i = 0; i < oc; i++) {
|
||||
mScaleOffsetData[i].scale = mDequantAlpha[i];
|
||||
mScaleOffsetData[i].offset = 0;
|
||||
}
|
||||
weightAxisScaleOffsetEncoding.scaleOffset = mScaleOffsetData.data();
|
||||
weightQuantize.axisScaleOffsetEncoding = weightAxisScaleOffsetEncoding;
|
||||
|
||||
this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)oc, (uint32_t)ic}, (void *) quantWeightData.data(), weightQuantize);
|
||||
std::function<void()> mReleaseWeightScaleOffset = [&](){
|
||||
std::vector<Qnn_ScaleOffset_t>().swap(mScaleOffsetData);
|
||||
};
|
||||
mBackend->pushReleaseFunc(mReleaseWeightScaleOffset);
|
||||
}
|
||||
//create bias
|
||||
this->createBias(dataType, oc, input, quanCommon);
|
||||
}
|
||||
|
||||
if (common->relu6()) {
|
||||
this->createParamScalar("min_value", 0.0f);
|
||||
this->createParamScalar("max_value", 6.0f);
|
||||
}
|
||||
if (common->relu() || common->relu6()) {
|
||||
this->createStageTensor("ReluTensor", dataType, getNHWCShape(output));
|
||||
}
|
||||
weightAxisScaleOffsetEncoding.scaleOffset = scaleOffsetData.data();
|
||||
weightQuantize.axisScaleOffsetEncoding = weightAxisScaleOffsetEncoding;
|
||||
|
||||
this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t) oc, (uint32_t) ic}, (void *) quantWeightData.data(), weightQuantize);
|
||||
}
|
||||
|
||||
// Stage One: reshape input
|
||||
{
|
||||
mNodeType = "Reshape";
|
||||
std::string name = mNodeName + "_reshapeOutput";
|
||||
std::string name = mNodeName + "_reshapeInput";
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
|
@@ -273,6 +570,7 @@ ErrorCode QNNConvolution::onEncodeQuant(Tensor * input, Tensor * output, int n,
|
|||
mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // tempInput
|
||||
// mInputs.push_back(*(mBackend->getNativeTensor(input)));
|
||||
mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor())); // weight
|
||||
mInputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); // bias
|
||||
mOutputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor())); // tempOutput
|
||||
// mOutputs.push_back(*(mBackend->getNativeTensor(output)));
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
|
@@ -286,48 +584,217 @@ ErrorCode QNNConvolution::onEncodeQuant(Tensor * input, Tensor * output, int n,
|
|||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor()));
|
||||
mOutputs.push_back(*(mBackend->getNativeTensor(output)));
|
||||
if (common->relu() || common->relu6()){
|
||||
mOutputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); //ReluTensor
|
||||
}else{
|
||||
mOutputs.push_back(*(mBackend->getNativeTensor(output)));
|
||||
}
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
}
|
||||
|
||||
// Stage Four: relu or relu6
|
||||
if (common->relu() || common->relu6()){
|
||||
mNodeType.clear();
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mNodeType = common->relu6() ? "ReluMinMax" : "Relu";
|
||||
std::string name = mNodeName + "_relu";
|
||||
if (common->relu6()) {
|
||||
mParams.push_back(*(mParamScalarWrappers[1]->getNativeParam())); // min_value
|
||||
mParams.push_back(*(mParamScalarWrappers[2]->getNativeParam())); // max_value
|
||||
}
|
||||
mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // ReluTensor
|
||||
mOutputs.push_back(*(mBackend->getNativeTensor(output))); // output
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
}
|
||||
return NO_ERROR;
|
||||
}
|
||||
|
||||
bool QNNConvolution::createWeightAndBias(Qnn_DataType_t dataType, const Tensor *input, int oc, int ic, int kernelH, int kernelW, int group) {
|
||||
if(mWeightQuant){
|
||||
Qnn_QuantizeParams_t weightQuantize{};
|
||||
std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon = ConvolutionCommon::load(mOp, this->backend(), false, true);
|
||||
if(quanCommon->asymmetric) {
|
||||
MNN_ERROR("[Error]: Qnn weight quant only support symmetric currently\n");
|
||||
return false;
|
||||
}
|
||||
const int8_t * source = quanCommon->weight.get();
|
||||
std::vector<int8_t> quantWeightData(oc * (ic / group) * kernelH * kernelW, 0);
|
||||
if(quanCommon->canUseInt4){
|
||||
for (int o = 0; o < oc; o++) {
|
||||
for (int i = 0; i < ic/group; i++) {
|
||||
for (int h = 0; h < kernelH; h++) {
|
||||
for (int w = 0; w < kernelW; w++) {
|
||||
uint32_t srcOffset = w + kernelW * (h + kernelH * (i + ic/group * o));
|
||||
uint32_t dstOffset = o + oc * (i + ic/group * (w + kernelW * h));
|
||||
if(srcOffset % 2 == 0){
|
||||
quantWeightData[dstOffset] = ((source[srcOffset / 2] >> 4) & 0x0f) - 8;
|
||||
}else{
|
||||
quantWeightData[dstOffset] = (source[srcOffset / 2] & 0x0f) - 8;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}else{
|
||||
convertWeight(source, (int8_t *) quantWeightData.data(), oc, ic/group, kernelH, kernelW);
|
||||
}
|
||||
mDequantAlpha = quanCommon->alpha.get();
|
||||
int totalCount = quanCommon->alpha.size();
|
||||
mBlockSize = totalCount / oc;
|
||||
// TODO: result is wrong, needs to be verified
|
||||
if(mBlockSize > 1){
|
||||
Qnn_QuantizeParams_t weightQuantize{};
|
||||
weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED;
|
||||
weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION;
|
||||
|
||||
weightBlockwiseExpansionEncoding.axis = 3;
|
||||
weightBlockwiseExpansionEncoding.numBlocksPerAxis = mBlockSize;
|
||||
weightBlockwiseExpansionEncoding.blockScaleBitwidth = 4;
|
||||
weightBlockwiseExpansionEncoding.blockScaleStorageType = QNN_BLOCKWISE_EXPANSION_BITWIDTH_SCALE_STORAGE_8;
|
||||
mBlockScale.resize(oc * mBlockSize);
|
||||
mScaleOffsetData.resize(oc);
|
||||
for (int i = 0; i < oc; i++) {
|
||||
float maxscale = -MAXFLOAT;
|
||||
for(int j = 0; j < mBlockSize; ++j){
|
||||
if(mDequantAlpha[i * mBlockSize + j] > maxscale){
|
||||
maxscale = mDequantAlpha[i * mBlockSize + j];
|
||||
}
|
||||
}
|
||||
float blockScale = maxscale / 16.0f;
|
||||
for(int j = 0; j < mBlockSize; ++j){
|
||||
int quantBlock = round(mDequantAlpha[i * mBlockSize + j] / blockScale);
|
||||
mBlockScale[i * mBlockSize + j] = (uint8_t)std::min(std::max(quantBlock, 1), 16);
|
||||
}
|
||||
mScaleOffsetData[i].scale = blockScale;
|
||||
mScaleOffsetData[i].offset = 0;
|
||||
}
|
||||
weightBlockwiseExpansionEncoding.scaleOffsets = mScaleOffsetData.data();
|
||||
weightBlockwiseExpansionEncoding.blocksScale8 = mBlockScale.data();
|
||||
weightQuantize.blockwiseExpansion = &weightBlockwiseExpansionEncoding;
|
||||
this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)kernelH, (uint32_t)kernelW, (uint32_t)ic / (uint32_t)group, (uint32_t)oc}, (void *) quantWeightData.data(), weightQuantize);
|
||||
std::function<void()> mReleaseWeightScaleOffset = [&](){
|
||||
std::vector<Qnn_ScaleOffset_t>().swap(mScaleOffsetData);
|
||||
};
|
||||
mBackend->pushReleaseFunc(mReleaseWeightScaleOffset);
|
||||
std::function<void()> mReleaseBlockScale = [&](){
|
||||
std::vector<uint8_t>().swap(mBlockScale);
|
||||
};
|
||||
mBackend->pushReleaseFunc(mReleaseBlockScale);
|
||||
}else if(quanCommon->canUseInt4){
|
||||
weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED;
|
||||
weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET;
|
||||
Qnn_BwAxisScaleOffset_t weightBWAxisScaleOffsetEncoding{};
|
||||
weightBWAxisScaleOffsetEncoding.bitwidth = 4;
|
||||
weightBWAxisScaleOffsetEncoding.axis = 3;
|
||||
weightBWAxisScaleOffsetEncoding.numElements = oc;
|
||||
mScale.resize(oc);
|
||||
std::vector<int32_t> OffsetData(oc);
|
||||
for (int i = 0; i < oc; i++) {
|
||||
mScale[i] = mDequantAlpha[i];
|
||||
}
|
||||
weightBWAxisScaleOffsetEncoding.scales = mScale.data();
|
||||
weightQuantize.bwAxisScaleOffsetEncoding = weightBWAxisScaleOffsetEncoding;
|
||||
|
||||
this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)kernelH, (uint32_t)kernelW, (uint32_t)ic / (uint32_t)group, (uint32_t)oc}, (void *) quantWeightData.data(), weightQuantize);
|
||||
std::function<void()> mReleaseWeightScaleOffset = [&](){
|
||||
std::vector<float>().swap(mScale);
|
||||
};
|
||||
mBackend->pushReleaseFunc(mReleaseWeightScaleOffset);
|
||||
}else{
|
||||
weightQuantize.encodingDefinition = QNN_DEFINITION_DEFINED;
|
||||
weightQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET;
|
||||
Qnn_AxisScaleOffset_t weightAxisScaleOffsetEncoding{};
|
||||
weightAxisScaleOffsetEncoding.axis = 3;
|
||||
weightAxisScaleOffsetEncoding.numScaleOffsets = oc;
|
||||
mScaleOffsetData.resize(oc);
|
||||
for (int i = 0; i < oc; i++) {
|
||||
mScaleOffsetData[i].scale = mDequantAlpha[i];
|
||||
mScaleOffsetData[i].offset = 0;
|
||||
}
|
||||
weightAxisScaleOffsetEncoding.scaleOffset = mScaleOffsetData.data();
|
||||
weightQuantize.axisScaleOffsetEncoding = weightAxisScaleOffsetEncoding;
|
||||
|
||||
this->createStaticTensor("quantWeight", QNN_DATATYPE_SFIXED_POINT_8, {(uint32_t)kernelH, (uint32_t)kernelW, (uint32_t)ic / (uint32_t)group, (uint32_t)oc}, (void *) quantWeightData.data(), weightQuantize);
|
||||
std::function<void()> mReleaseWeightScaleOffset = [&](){
|
||||
std::vector<Qnn_ScaleOffset_t>().swap(mScaleOffsetData);
|
||||
};
|
||||
mBackend->pushReleaseFunc(mReleaseWeightScaleOffset);
|
||||
}
|
||||
this->createBias(dataType, oc, input, quanCommon);
|
||||
} else {
|
||||
std::vector<float> weightData;
|
||||
const float* source = nullptr;
|
||||
int weightElementNum = 0;
|
||||
std::shared_ptr<ConvolutionCommon::Int8Common> quanWeight;
|
||||
ConvolutionCommon::getConvParameters(&quanWeight, mBackend, mOp, &source, &weightElementNum);
|
||||
// oc ic h w ---> h w ic oc
|
||||
weightData.resize(weightElementNum);
|
||||
convertWeight(source, (float *) weightData.data(), oc, ic/group, kernelH, kernelW);
|
||||
Qnn_DataType_t floatDatatype = QNN_DATATYPE_FLOAT_32;
|
||||
if(mBackend->getUseFP16()){
|
||||
floatDatatype = QNN_DATATYPE_FLOAT_16;
|
||||
}
|
||||
this->createStaticFloatTensor("weight", floatDatatype, {(uint32_t)kernelH, (uint32_t)kernelW, (uint32_t)ic / (uint32_t)group, (uint32_t)oc}, weightData.data());
|
||||
this->createBias(dataType, oc, input, nullptr);
|
||||
}
|
||||
return NO_ERROR;
|
||||
}
|
||||
|
||||
|
||||
void QNNConvolution::createWeight(Qnn_DataType_t dataType, int oc, int ic, int kernelH, int kernelW, int group, const float * source) {
|
||||
std::vector<float> weightData;
|
||||
int weightElementNum = oc * ic / group * kernelH * kernelW;
|
||||
// oc ic/group h w ---> h w ic/group oc
|
||||
weightData.resize(weightElementNum);
|
||||
|
||||
convertWeight(source, (float *) weightData.data(), oc, ic/group, kernelH, kernelW);
|
||||
this->createStaticFloatTensor("weight", dataType, {(uint32_t)kernelH, (uint32_t)kernelW, (uint32_t)ic / (uint32_t)group, (uint32_t)oc}, weightData.data());
|
||||
}
|
||||
|
||||
|
||||
void QNNConvolution::createBias(Qnn_DataType_t dataType, int oc) {
|
||||
void QNNConvolution::createBias(Qnn_DataType_t dataType, int oc, const Tensor *input, std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon) {
|
||||
int biasElementNum = oc;
|
||||
std::vector<float> biasData;
|
||||
biasData.resize(biasElementNum, 0);
|
||||
auto bias = mOp->main_as_Convolution2D()->bias();
|
||||
if (nullptr != bias) {
|
||||
::memcpy((void *)biasData.data(), (void *)bias->data(), biasElementNum * sizeof(float));
|
||||
}
|
||||
this->createStaticFloatTensor("bias", dataType, {(uint32_t)oc}, biasData.data());
|
||||
}
|
||||
if(dataType != QNN_DATATYPE_FLOAT_16 && dataType != QNN_DATATYPE_FLOAT_32 && mWeightQuant){
|
||||
mDequantAlpha = quanCommon->alpha.get();
|
||||
float inputScale = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.scale;
|
||||
int inputOffset = mBackend->getNativeTensor(input)->v1.quantizeParams.scaleOffsetEncoding.offset;
|
||||
std::vector<int> biasData;
|
||||
biasData.resize(biasElementNum, 0);
|
||||
|
||||
// oc ic h w ---> h w ic oc
|
||||
void QNNConvolution::convertWeight(const float * src, float * dst, int oc, int ic, int kernelH, int kernelW) {
|
||||
for (int o = 0; o < oc; o++) {
|
||||
for (int i = 0; i < ic; i++) {
|
||||
for (int h = 0; h < kernelH; h++) {
|
||||
for (int w = 0; w < kernelW; w++) {
|
||||
uint32_t srcOffset = w + kernelW * (h + kernelH * (i + ic * o));
|
||||
uint32_t dstOffset = o + oc * (i + ic * (w + kernelW * h));
|
||||
dst[dstOffset] = src[srcOffset];
|
||||
Qnn_QuantizeParams_t biasQuantize{};
|
||||
biasQuantize.encodingDefinition = QNN_DEFINITION_DEFINED;
|
||||
biasQuantize.quantizationEncoding = QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET;
|
||||
Qnn_AxisScaleOffset_t biasAxisScaleOffsetEncoding{};
|
||||
biasAxisScaleOffsetEncoding.axis = 0;
|
||||
biasAxisScaleOffsetEncoding.numScaleOffsets = biasElementNum;
|
||||
mBiasScaleOffsetData.resize(biasElementNum);
|
||||
|
||||
auto bias = mOp->main_as_Convolution2D()->bias();
|
||||
auto biasPtr = (float*)bias->data();
|
||||
if (nullptr != bias) {
|
||||
for(int i = 0; i < biasElementNum; ++i){
|
||||
float biasScale = inputScale * mDequantAlpha[i];
|
||||
mBiasScaleOffsetData[i].scale = biasScale;
|
||||
mBiasScaleOffsetData[i].offset = 0;
|
||||
if(fabs(biasPtr[i]) < 0.000001 || fabs(biasScale) < 0.000001){
|
||||
biasData[i] = 0;
|
||||
} else{
|
||||
biasData[i] = (int)(biasPtr[i] / biasScale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
biasAxisScaleOffsetEncoding.scaleOffset = mBiasScaleOffsetData.data();
|
||||
biasQuantize.axisScaleOffsetEncoding = biasAxisScaleOffsetEncoding;
|
||||
|
||||
this->createStaticTensor("bias", QNN_DATATYPE_SFIXED_POINT_32, {(uint32_t)biasElementNum}, biasData.data(), biasQuantize);
|
||||
std::function<void()> mReleaseBiasScaleOffset = [&](){
|
||||
std::vector<Qnn_ScaleOffset_t>().swap(mBiasScaleOffsetData);
|
||||
};
|
||||
mBackend->pushReleaseFunc(mReleaseBiasScaleOffset);
|
||||
}else{
|
||||
std::vector<float> biasData;
|
||||
biasData.resize(biasElementNum, 0);
|
||||
auto bias = mOp->main_as_Convolution2D()->bias();
|
||||
if (nullptr != bias) {
|
||||
::memcpy((void *)biasData.data(), (void *)bias->data(), biasElementNum * sizeof(float));
|
||||
}
|
||||
Qnn_DataType_t floatDatatype = QNN_DATATYPE_FLOAT_32;
|
||||
if(mBackend->getUseFP16()){
|
||||
floatDatatype = QNN_DATATYPE_FLOAT_16;
|
||||
}
|
||||
this->createStaticFloatTensor("bias", floatDatatype, {(uint32_t)oc}, biasData.data());
|
||||
}
|
||||
}
|
||||
|
||||
|
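A hedged sketch of the arithmetic used above, not part of the diff itself: the blockwise-expansion branch folds each channel's per-block dequant scales into one base scale plus 4-bit block multipliers, and the quantized path in createBias stores the bias as int32 with scale = inputScale * weightScale. All numbers below are made-up examples.

```cpp
// Minimal standalone sketch of the blockwise scale decomposition and int32 bias quantization.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // Per-block dequant scales of one output channel (blockSize = 4 here, hypothetical values).
    std::vector<float> blockScales = {0.020f, 0.031f, 0.008f, 0.016f};

    // Base scale per channel = maxBlockScale / 16, so each block scale is re-expressed as
    // baseScale * m with an integer multiplier m clamped to [1, 16], as in the code above.
    float maxScale = *std::max_element(blockScales.begin(), blockScales.end());
    float baseScale = maxScale / 16.0f;
    for (float s : blockScales) {
        int m = (int)std::round(s / baseScale);
        m = std::min(std::max(m, 1), 16);
        printf("block scale %.4f ~= base %.5f * %d\n", s, baseScale, m);
    }

    // Bias quantization: bias is stored as int32 with scale inputScale * per-channel weight scale.
    float inputScale = 0.05f, weightScale = 0.02f, biasFloat = 1.3f;
    float biasScale = inputScale * weightScale;
    int biasInt = (int)(biasFloat / biasScale); // truncation, matching the code above
    printf("bias %.2f -> int32 %d at scale %.4f\n", biasFloat, biasInt, biasScale);
    return 0;
}
```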
|
|
@@ -19,12 +19,37 @@ class QNNConvolution : public QNNCommonExecution {
|
|||
public:
|
||||
QNNConvolution(Backend *backend, const Op *op) : QNNCommonExecution(backend, op) {}
|
||||
virtual ErrorCode onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
|
||||
ErrorCode onEncodeQuant(Tensor * input, Tensor * output, int n, int h, int w, int ic, int oc, std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon);
|
||||
ErrorCode onEncodeFpAIntBMatMul(Tensor * input, Tensor * output, int n, int h, int w, int ic, int oc);
|
||||
ErrorCode onEncodeQuantDequantConv(Tensor *input, Tensor *output, const int n, const int ic, const int oc);
|
||||
|
||||
private:
|
||||
void createWeight(Qnn_DataType_t dataType, int oc, int ic, int kernelH, int kernelW, int group, const float * source);
|
||||
void createBias(Qnn_DataType_t dataType, int oc);
|
||||
void convertWeight(const float * src, float * dst, int oc, int ic, int kernelH, int kernelW);
|
||||
template <typename T>
|
||||
void convertWeight(const T * src, T * dst, int oc, int ic, int kernelH, int kernelW) {
|
||||
for (int o = 0; o < oc; o++) {
|
||||
for (int i = 0; i < ic; i++) {
|
||||
for (int h = 0; h < kernelH; h++) {
|
||||
for (int w = 0; w < kernelW; w++) {
|
||||
uint32_t srcOffset = w + kernelW * (h + kernelH * (i + ic * o));
|
||||
uint32_t dstOffset = o + oc * (i + ic * (w + kernelW * h));
|
||||
dst[dstOffset] = src[srcOffset];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
void isWeightQuantSupported(const Tensor *input, const int ic, const int oc);
|
||||
bool createWeightAndBias(Qnn_DataType_t dataType, const Tensor *input, int oc, int ic, int kernelH, int kernelW, int group);
|
||||
void createBias(Qnn_DataType_t dataType, int oc, const Tensor *input, std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon);
|
||||
std::vector<float> mScale;
|
||||
std::vector<Qnn_ScaleOffset_t> mScaleOffsetData;
|
||||
std::vector<Qnn_ScaleOffset_t> mBiasScaleOffsetData;
|
||||
std::vector<uint8_t> mBlockScale;
|
||||
Qnn_BlockwiseExpansion_t weightBlockwiseExpansionEncoding = QNN_BLOCKWISE_EXPANSION_INIT;
|
||||
float *mDequantAlpha = nullptr;
|
||||
int mBlockSize = 1;
|
||||
bool mWeightQuant = false;
|
||||
bool mIsMatMul = false;
|
||||
bool mIs1x1Conv = false;
|
||||
};
|
||||
|
||||
} // end namespace QNN
|
||||
|
|
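The convertWeight template above relayouts weights from MNN's (oc, ic, kh, kw) order into QNN's (kh, kw, ic, oc) order. A small self-checking sketch of the same index mapping, with arbitrary example shapes:

```cpp
// Verify the OIHW -> HWIO index mapping used by convertWeight (example shapes only).
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    const int oc = 2, ic = 3, kh = 2, kw = 2;
    std::vector<float> src(oc * ic * kh * kw), dst(src.size());
    for (size_t i = 0; i < src.size(); ++i) src[i] = (float)i;

    for (int o = 0; o < oc; ++o)
        for (int i = 0; i < ic; ++i)
            for (int h = 0; h < kh; ++h)
                for (int w = 0; w < kw; ++w) {
                    uint32_t srcOffset = w + kw * (h + kh * (i + ic * o)); // OIHW
                    uint32_t dstOffset = o + oc * (i + ic * (w + kw * h)); // HWIO
                    dst[dstOffset] = src[srcOffset];
                }

    // Spot check: src element (o=1, i=2, h=1, w=0) must land at dst position (h=1, w=0, i=2, o=1).
    assert(dst[1 + oc * (2 + ic * (0 + kw * 1))] == src[0 + kw * (1 + kh * (2 + ic * 1))]);
    return 0;
}
```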
|
@@ -0,0 +1,81 @@
|
|||
//
|
||||
// QNNQuant.cpp
|
||||
// MNN
|
||||
//
|
||||
// Created by MNN on 2025/05/29.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
|
||||
#include "QNNQuant.hpp"
|
||||
|
||||
namespace MNN {
|
||||
namespace QNN {
|
||||
|
||||
ErrorCode QNNQuant::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
|
||||
this->createStageTensor("Cast", QNN_DATATYPE_FLOAT_32, getNHWCShape(outputs[0]));
|
||||
// Stage one fp16 -> fp32
|
||||
{
|
||||
mNodeType = "Cast";
|
||||
std::string name = mNodeName + "_Cast";
|
||||
|
||||
mInputs.push_back(*(mBackend->getNativeTensor(inputs[0]))); // input
|
||||
mOutputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // stage tensor
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
}
|
||||
|
||||
// Stage two fp32 -> int8
|
||||
{
|
||||
mNodeType.clear();
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mNodeType = "Quantize";
|
||||
std::string name = mNodeName;
|
||||
|
||||
mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor())); // stage tensor
|
||||
mOutputs.push_back(*(mBackend->getNativeTensor(outputs[0]))); // output
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
}
|
||||
return NO_ERROR;
|
||||
}
|
||||
|
||||
|
||||
class QNNQuantCreator : public QnnBackend::Creator {
|
||||
public:
|
||||
virtual QNNCommonExecution * onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, const MNN::Op* op,
|
||||
Backend* backend) const override {
|
||||
return new QNNQuant(backend, op);
|
||||
}
|
||||
};
|
||||
|
||||
ErrorCode QNNDeQuant::onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
|
||||
// Stage one int8 -> fp16
|
||||
{
|
||||
mNodeType.clear();
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mNodeType = "Dequantize";
|
||||
std::string name = mNodeName;
|
||||
|
||||
mInputs.push_back(*(mBackend->getNativeTensor(inputs[0]))); // input
|
||||
mOutputs.push_back(*(mBackend->getNativeTensor(outputs[0]))); // output
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
}
|
||||
return NO_ERROR;
|
||||
}
|
||||
|
||||
|
||||
class QNNDeQuantCreator : public QnnBackend::Creator {
|
||||
public:
|
||||
virtual QNNCommonExecution * onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, const MNN::Op* op,
|
||||
Backend* backend) const override {
|
||||
return new QNNDeQuant(backend, op);
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_QNN_OP_CREATOR(QNNQuantCreator, OpType_FloatToInt8)
|
||||
REGISTER_QNN_OP_CREATOR(QNNDeQuantCreator, OpType_Int8ToFloat)
|
||||
|
||||
} // end namespace QNN
|
||||
} // end namespace MNN
|
|
@@ -0,0 +1,32 @@
|
|||
//
|
||||
// QNNQuant.hpp
|
||||
// MNN
|
||||
//
|
||||
// Created by MNN on 2025/05/29.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
|
||||
#ifndef MNN_QNNQUANT_HPP
|
||||
#define MNN_QNNQUANT_HPP
|
||||
|
||||
#include "QNNCommonExecution.hpp"
|
||||
|
||||
namespace MNN {
|
||||
namespace QNN {
|
||||
|
||||
class QNNQuant : public QNNCommonExecution {
|
||||
public:
|
||||
QNNQuant(Backend *backend, const Op *op) : QNNCommonExecution(backend, op) {};
|
||||
virtual ErrorCode onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
|
||||
};
|
||||
|
||||
class QNNDeQuant : public QNNCommonExecution {
|
||||
public:
|
||||
QNNDeQuant(Backend *backend, const Op *op) : QNNCommonExecution(backend, op) {};
|
||||
virtual ErrorCode onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
|
||||
};
|
||||
|
||||
} // end namespace QNN
|
||||
} // end namespace MNN
|
||||
|
||||
#endif // end MNN_QNNQUANT_HPP
|
|
@@ -28,10 +28,25 @@ ErrorCode QNNScale::onEncode(const std::vector<Tensor *> &inputs, const std::vec
|
|||
MNN_ASSERT(channel == mWeightData.size());
|
||||
|
||||
Qnn_DataType_t dataType = mBackend->getNativeTensor(inputs[0])->v1.dataType;
|
||||
|
||||
this->createStaticFloatTensor("weight", dataType, {(uint32_t)channel}, mWeightData.data());
|
||||
this->createStaticFloatTensor("bias", dataType, {(uint32_t)channel}, mBiasData.data());
|
||||
this->createStageTensor("Stage", dataType, getNHWCShape(inputs[0]));
|
||||
mNeedQuantDequant = dataType != QNN_DATATYPE_FLOAT_16 && dataType != QNN_DATATYPE_FLOAT_32;
|
||||
if(mNeedQuantDequant){
|
||||
Qnn_DataType_t tempDataType = QNN_DATATYPE_FLOAT_32;
|
||||
if(mBackend->getUseFP16()){
|
||||
tempDataType = QNN_DATATYPE_FLOAT_16;
|
||||
}
|
||||
this->createStaticFloatTensor("weight", tempDataType, {(uint32_t)channel}, mWeightData.data());
|
||||
this->createStaticFloatTensor("bias", tempDataType, {(uint32_t)channel}, mBiasData.data());
|
||||
this->createStageTensor("Stage", tempDataType, getNHWCShape(inputs[0]));
|
||||
this->createStageTensor("Stage_dequantize_input", tempDataType, getNHWCShape(inputs[0]));
|
||||
this->createStageTensor("Stage_add_output", tempDataType, getNHWCShape(outputs[0]));
|
||||
if(mBackend->getUseFP16()){
|
||||
this->createStageTensor("Stage_cast_output", QNN_DATATYPE_FLOAT_32, getNHWCShape(outputs[0]));
|
||||
}
|
||||
}else{
|
||||
this->createStaticFloatTensor("weight", dataType, {(uint32_t)channel}, mWeightData.data());
|
||||
this->createStaticFloatTensor("bias", dataType, {(uint32_t)channel}, mBiasData.data());
|
||||
this->createStageTensor("Stage", dataType, getNHWCShape(inputs[0]));
|
||||
}
|
||||
}
|
||||
|
||||
// add nodes
|
||||
|
@@ -42,34 +57,98 @@ ErrorCode QNNScale::onEncode(const std::vector<Tensor *> &inputs, const std::vec
|
|||
}
|
||||
|
||||
void QNNScale::mulWeight(Tensor * input) {
|
||||
mNodeType = "ElementWiseMultiply";
|
||||
std::string name = mNodeName + "_mul";
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
|
||||
mInputs.push_back(*(mBackend->getNativeTensor(input)));
|
||||
mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor()));
|
||||
|
||||
mOutputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor()));
|
||||
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
Qnn_DataType_t dataType = mBackend->getNativeTensor(input)->v1.dataType;
|
||||
// need dequantize to float16
|
||||
if(mNeedQuantDequant){
|
||||
mNodeType.clear();
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mNodeType = "Dequantize";
|
||||
std::string name = mNodeName + "_Dequantize";
|
||||
|
||||
mInputs.push_back(*(mBackend->getNativeTensor(input))); // input
|
||||
mOutputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); //Stage_dequantize_input
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
}
|
||||
{
|
||||
mNodeType.clear();
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mNodeType = "ElementWiseMultiply";
|
||||
std::string name = mNodeName + "_mul";
|
||||
|
||||
if(mNeedQuantDequant){
|
||||
mInputs.push_back(*(mTempTensorWrappers[3]->getNativeTensor())); //Stage_dequantize_input
|
||||
}else{
|
||||
mInputs.push_back(*(mBackend->getNativeTensor(input)));
|
||||
}
|
||||
mInputs.push_back(*(mTempTensorWrappers[0]->getNativeTensor()));
|
||||
|
||||
mOutputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor()));
|
||||
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void QNNScale::addBias(Tensor * output) {
|
||||
mNodeType = "ElementWiseAdd";
|
||||
std::string name = mNodeName + "_add";
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
Qnn_DataType_t dataType = mBackend->getNativeTensor(output)->v1.dataType;
|
||||
{
|
||||
mNodeType.clear();
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mNodeType = "ElementWiseAdd";
|
||||
std::string name = mNodeName + "_add";
|
||||
|
||||
mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor()));
|
||||
mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor()));
|
||||
|
||||
if(mNeedQuantDequant){
|
||||
mOutputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // Stage_add_output
|
||||
}else{
|
||||
mOutputs.push_back(*(mBackend->getNativeTensor(output)));
|
||||
}
|
||||
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
}
|
||||
|
||||
// need quantize output
|
||||
if(mNeedQuantDequant){
|
||||
// Stage one fp16 -> fp32
|
||||
if(mBackend->getUseFP16()){
|
||||
mNodeType.clear();
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mNodeType = "Cast";
|
||||
std::string name = mNodeName + "_Cast";
|
||||
|
||||
mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // Stage_add_output
|
||||
mOutputs.push_back(*(mTempTensorWrappers[5]->getNativeTensor())); // Stage_cast_output
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
}
|
||||
|
||||
mInputs.push_back(*(mTempTensorWrappers[2]->getNativeTensor()));
|
||||
mInputs.push_back(*(mTempTensorWrappers[1]->getNativeTensor()));
|
||||
|
||||
mOutputs.push_back(*(mBackend->getNativeTensor(output)));
|
||||
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
// Stage two fp32 -> int8
|
||||
{
|
||||
mNodeType.clear();
|
||||
mParams.clear();
|
||||
mInputs.clear();
|
||||
mOutputs.clear();
|
||||
mNodeType = "Quantize";
|
||||
std::string name = mNodeName + "_Quantize";
|
||||
|
||||
if(mBackend->getUseFP16()){
|
||||
mInputs.push_back(*(mTempTensorWrappers[5]->getNativeTensor())); // Stage_cast_output
|
||||
}else{
|
||||
mInputs.push_back(*(mTempTensorWrappers[4]->getNativeTensor())); // Stage_add_output
|
||||
}
|
||||
mOutputs.push_back(*(mBackend->getNativeTensor(output))); // output
|
||||
mBackend->addNodeToGraph(mOpConfigVersion, name.c_str(), mPackageName.c_str(), mNodeType.c_str(), mParams, mInputs, mOutputs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ErrorCode QNNScale::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
|
||||
|
|
|
@@ -25,6 +25,7 @@ private:
|
|||
private:
|
||||
std::vector<float> mWeightData;
|
||||
std::vector<float> mBiasData;
|
||||
bool mNeedQuantDequant = false;
|
||||
};
|
||||
|
||||
} // end namespace QNN
|
||||
|
|
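When the input of Scale is an already-quantized tensor, the QNNScale changes above bracket the multiply/add with Dequantize and Quantize nodes, plus a Cast when the backend runs in FP16. A hedged sketch of that node-chain selection (QNN op-type names only, no backend calls):

```cpp
// Which QNN node types QNNScale emits, depending on the flags used above.
#include <cstdio>
#include <string>
#include <vector>

std::vector<std::string> scaleNodeChain(bool needQuantDequant, bool useFP16) {
    std::vector<std::string> chain;
    if (needQuantDequant) chain.push_back("Dequantize");   // int8 -> float stage tensor
    chain.push_back("ElementWiseMultiply");                 // x * weight
    chain.push_back("ElementWiseAdd");                      // + bias
    if (needQuantDequant) {
        if (useFP16) chain.push_back("Cast");               // fp16 -> fp32 before quantizing
        chain.push_back("Quantize");                        // float -> quantized output
    }
    return chain;
}

int main() {
    for (auto& n : scaleNodeChain(/*needQuantDequant=*/true, /*useFP16=*/true)) {
        printf("%s -> ", n.c_str());
    }
    printf("output\n");
    return 0;
}
```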
|
@@ -34,12 +34,17 @@ struct RuntimeHint {
|
|||
int cpuDecreaseRate = 50;
|
||||
int dynamicQuantOption = 0;
|
||||
|
||||
// qkvQuantOption % 8:
|
||||
// 0: Do not quantize
|
||||
// 1: Only quantize key, use int8 asymmetric quantization
|
||||
// 2: Only quantize value, use fp8 quantization
|
||||
// 3: quantize both key and value
|
||||
// 4: quantize query, key and value, and use gemm int8 kernel to compute K*V
|
||||
int qkvQuantOption = 0;
|
||||
|
||||
// qkvQuantOption / 8:
|
||||
// 1: use flash attention
|
||||
|
||||
int qkvQuantOption = 8;
|
||||
|
||||
// the kvcache size limit of each layer
|
||||
// if the size of kvcache in memory exceeds the limit
|
||||
|
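The comments above pack two independent choices into one integer: qkvQuantOption % 8 selects the KV/QKV quantization mode, and qkvQuantOption / 8 toggles flash attention. A minimal sketch of composing and decoding such a value; it only mirrors the documented encoding and does not call any MNN API:

```cpp
// Composing / decoding qkvQuantOption as documented in the comment block above.
#include <cstdio>

int main() {
    const int kUseFlashAttention   = 8; // contributes qkvQuantOption / 8 == 1
    const int kQuantQKVWithInt8Gemm = 4; // contributes qkvQuantOption % 8 == 4

    int qkvQuantOption = kUseFlashAttention + kQuantQKVWithInt8Gemm; // 12

    int quantMode = qkvQuantOption % 8;            // 0..4, meanings listed above
    bool flashAttention = (qkvQuantOption / 8) == 1;
    printf("quant mode = %d, flash attention = %d\n", quantMode, (int)flashAttention);
    return 0;
}
```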
@@ -403,6 +408,9 @@ public:
|
|||
info.mode = Backend::Info::DIRECT;
|
||||
return true;
|
||||
}
|
||||
virtual bool onGetDeviceInfo(const std::string& deviceKey, std::string& deviceValue) const {
|
||||
return false;
|
||||
}
|
||||
protected:
|
||||
/**
|
||||
@brief deinitializer.
|
||||
|
|
|
@@ -129,6 +129,13 @@ static int dumpModelInfo(const char* modelName) {
|
|||
} else {
|
||||
MNN_PRINT("Model Version: %s \n", info->version.c_str());
|
||||
}
|
||||
if (!info->metaData.empty()) {
|
||||
MNN_PRINT("MetaData: Begin \n");
|
||||
for (auto& iter : info->metaData) {
|
||||
MNN_PRINT("[Meta] %s : %s\n", iter.first.c_str(), iter.second.c_str());
|
||||
}
|
||||
MNN_PRINT("MetaData: End \n");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
|
|
@@ -130,7 +130,17 @@ bool CompleteSubGraph(const std::unordered_map<std::string, VARP>& inputs, const
|
|||
return true;
|
||||
}
|
||||
|
||||
|
||||
static bool _hasDupName(std::unique_ptr<MNN::NetT>& originNet) {
|
||||
std::set<std::string> names;
|
||||
for (auto& tensorName : originNet->tensorName) {
|
||||
if (names.find(tensorName) != names.end()) {
|
||||
MNN_ERROR("Repeat name %s\n", tensorName.c_str());
|
||||
return true;
|
||||
}
|
||||
names.insert(tensorName);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
void RunNetPass(const std::vector<std::string>& passes, std::unique_ptr<MNN::NetT>& originNet) {
|
||||
for (auto pass : passes) {
|
||||
auto convert = PostConverter::get(pass);
|
||||
|
@@ -138,7 +148,14 @@ void RunNetPass(const std::vector<std::string>& passes, std::unique_ptr<MNN::Net
|
|||
LOG(INFO) << "Can't find pass of " << pass << "\n";
|
||||
continue;
|
||||
}
|
||||
auto originSize = originNet->oplists.size();
|
||||
bool valid = convert->onExecute(originNet);
|
||||
#ifdef DEBUG
|
||||
auto hasDup = _hasDupName(originNet);
|
||||
if (originSize != originNet->oplists.size() || hasDup) {
|
||||
MNN_PRINT("%s: %d -> %d, dup: %d\n", pass.c_str(), originSize, originNet->oplists.size(), hasDup);
|
||||
}
|
||||
#endif
|
||||
if (!valid) {
|
||||
LOG(INFO) << "Run " << pass << "Error\n";
|
||||
}
|
||||
|
|
|
@@ -19,7 +19,7 @@ static EXPRP clipConvert(EXPRP expr, bool supportRelu6) {
|
|||
auto extraParam = op->main_as_Extra();
|
||||
// auto dataType = expr->outputInfo(0)->type.code;
|
||||
auto maxValue = std::numeric_limits<T>().max();
|
||||
auto minValue = std::numeric_limits<T>().min();
|
||||
auto minValue = std::numeric_limits<T>().lowest();
|
||||
if (nullptr != extraParam->attr()) {
|
||||
const int attrSize = extraParam->attr()->size();
|
||||
for (int i = 0; i < attrSize; ++i) {
|
||||
|
|
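The one-line change above (min() → lowest()) matters for floating-point instantiations of clipConvert: std::numeric_limits<float>::min() is the smallest positive normalized value, so the old default lower clip bound was a tiny positive number rather than the most negative float. A short demonstration:

```cpp
// Why lowest() is the correct default lower clip bound for floating-point types.
#include <cstdio>
#include <limits>

int main() {
    printf("float min()    = %g\n", std::numeric_limits<float>::min());    // ~1.17549e-38 (positive!)
    printf("float lowest() = %g\n", std::numeric_limits<float>::lowest()); // ~-3.40282e+38
    printf("int   min()    = %d\n", std::numeric_limits<int>::min());      // for integers, min() == lowest()
    return 0;
}
```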
|
@@ -12,20 +12,70 @@
|
|||
class RemoveCopy : public PostConverter {
|
||||
public:
|
||||
virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override {
|
||||
auto config = Global<modelConfig>::Get();
|
||||
if (config->optimizeLevel < 1 || config->inSubGraph) {
|
||||
return true;
|
||||
std::set<std::string> netOutputNames;
|
||||
for (auto& t : net->outputName) {
|
||||
netOutputNames.insert(t);
|
||||
}
|
||||
for (auto iter = net->oplists.begin(); iter != net->oplists.end(); iter++) {
|
||||
auto& op = *iter;
|
||||
if (op->type == MNN::OpType_Input) {
|
||||
for (auto o : op->outputIndexes) {
|
||||
netOutputNames.insert(net->tensorName[o]);
|
||||
}
|
||||
}
|
||||
}
|
||||
auto config = Global<modelConfig>::Get();
|
||||
for (auto iter = net->oplists.begin(); iter != net->oplists.end();) {
|
||||
auto& op = *iter;
|
||||
if (op->type != MNN::OpType_Identity) {
|
||||
if (op->type != MNN::OpType_Identity || op->inputIndexes.size() != op->outputIndexes.size()) {
|
||||
iter++;
|
||||
continue;
|
||||
}
|
||||
|
||||
bool hasOutputName = false;
|
||||
for (auto o : op->outputIndexes) {
|
||||
if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) {
|
||||
hasOutputName = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
bool hasOutputFromInput = false;
|
||||
for (auto o : op->inputIndexes) {
|
||||
if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) {
|
||||
hasOutputFromInput = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (hasOutputFromInput && hasOutputName) {
|
||||
iter++;
|
||||
continue;
|
||||
}
|
||||
auto originInput = op->inputIndexes;
|
||||
auto originOutputs = op->outputIndexes;
|
||||
MNN_ASSERT(originInput.size() == originOutputs.size());
|
||||
if (hasOutputName) {
|
||||
bool valid = true;
|
||||
for (int i=0; i<op->inputIndexes.size(); ++i) {
|
||||
auto o = op->outputIndexes[i];
|
||||
auto originInput = op->inputIndexes[i];
|
||||
if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) {
|
||||
if (netOutputNames.find(net->tensorName[originInput]) != netOutputNames.end()) {
|
||||
valid = false;
|
||||
break;
|
||||
}
|
||||
auto originName = net->tensorName[originInput];
|
||||
net->tensorName[originInput] = net->tensorName[o];
|
||||
net->tensorName[o] = originName;
|
||||
}
|
||||
}
|
||||
if (!valid) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
std::map<int, int> replaceIndexes;
|
||||
for (int i=0; i<op->inputIndexes.size();++i) {
|
||||
replaceIndexes.insert(std::make_pair(op->outputIndexes[i], op->inputIndexes[i]));
|
||||
net->tensorName[op->inputIndexes[i]] = net->tensorName[op->outputIndexes[i]];
|
||||
}
|
||||
for (auto subIter = net->oplists.begin(); subIter != net->oplists.end(); subIter++) {
|
||||
auto& subOp = *subIter;
|
||||
|
@@ -35,14 +85,6 @@ public:
|
|||
}
|
||||
}
|
||||
}
|
||||
for (int v=0; v<op->inputIndexes.size(); ++v) {
|
||||
for (auto& o : net->outputName) {
|
||||
if (o == net->tensorName[op->inputIndexes[v]]) {
|
||||
o = net->tensorName[op->outputIndexes[v]];
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
iter = net->oplists.erase(iter);
|
||||
}
|
||||
return true;
|
||||
|
|
|
@@ -0,0 +1,149 @@
|
|||
#include "RemoveTestNoUseOps.hpp"
|
||||
bool RemoveTestNoUseOps::onExecute(std::unique_ptr<MNN::NetT>& net) const {
|
||||
const MNN::NetT* const netPtr = net.get();
|
||||
std::set<std::string> netOutputNames;
|
||||
for (auto& t : net->outputName) {
|
||||
netOutputNames.insert(t);
|
||||
}
|
||||
for (auto iter = net->oplists.begin(); iter != net->oplists.end(); iter++) {
|
||||
auto& op = *iter;
|
||||
if (op->type == OpType_Input) {
|
||||
for (auto o : op->outputIndexes) {
|
||||
netOutputNames.insert(net->tensorName[o]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_set<int> removedInputs;
|
||||
for (auto iter = net->oplists.begin(); iter != net->oplists.end();) {
|
||||
auto& op = *iter;
|
||||
bool shouldDelete = shouldDeleteJudge(op.get(), netPtr);
|
||||
if (!shouldDelete) {
|
||||
iter++;
|
||||
continue;
|
||||
}
|
||||
bool hasOutputName = false;
|
||||
for (auto o : op->outputIndexes) {
|
||||
if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) {
|
||||
hasOutputName = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
bool hasOutputFromInput = false;
|
||||
for (auto o : op->inputIndexes) {
|
||||
if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) {
|
||||
hasOutputFromInput = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (hasOutputFromInput && hasOutputName) {
|
||||
iter++;
|
||||
continue;
|
||||
}
|
||||
bool deleteOutput = shouldDeleteOutput(op.get());
|
||||
// Find the next op
|
||||
if (op->outputIndexes.empty() || op->inputIndexes.empty()) {
|
||||
iter = net->oplists.erase(iter);
|
||||
continue;
|
||||
}
|
||||
auto originInput = op->inputIndexes[0];
|
||||
auto originOutputs = op->outputIndexes;
|
||||
if ((!deleteOutput) && hasOutputName) {
|
||||
bool valid = true;
|
||||
for (auto o : originOutputs) {
|
||||
if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) {
|
||||
if (netOutputNames.find(net->tensorName[originInput]) != netOutputNames.end()) {
|
||||
valid = false;
|
||||
break;
|
||||
}
|
||||
net->tensorName[originInput] = net->tensorName[o];
|
||||
}
|
||||
}
|
||||
if (!valid) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
for (auto subIter = net->oplists.begin(); subIter != net->oplists.end(); subIter++) {
|
||||
auto& subOp = *subIter;
|
||||
if (deleteOutput) {
|
||||
for (auto iter=subOp->inputIndexes.begin(); iter != subOp->inputIndexes.end();) {
|
||||
if (std::find(originOutputs.begin(), originOutputs.end(), *iter) != originOutputs.end()) {
|
||||
iter = subOp->inputIndexes.erase(iter);
|
||||
continue;
|
||||
}
|
||||
iter++;
|
||||
}
|
||||
} else {
|
||||
for (int v = 0; v < subOp->inputIndexes.size(); ++v) {
|
||||
if (std::find(originOutputs.begin(), originOutputs.end(), subOp->inputIndexes[v]) != originOutputs.end()) {
|
||||
subOp->inputIndexes[v] = originInput;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
bool removeUselessInput = shouldRemoveUnusefulInputs(op.get());
|
||||
if (removeUselessInput) {
|
||||
for (int input : op->inputIndexes) {
|
||||
removedInputs.emplace(input);
|
||||
}
|
||||
}
|
||||
iter = net->oplists.erase(iter);
|
||||
}
|
||||
|
||||
// Remove the op only if the reference counts of all its outputs
|
||||
// are reduced to zero.
|
||||
std::unordered_map<int, int/*reference count*/> uselessIndex;
|
||||
for (const auto& op : net->oplists) {
|
||||
for (int input : op->inputIndexes) {
|
||||
auto it = uselessIndex.find(input);
|
||||
if (it == uselessIndex.end()) {
|
||||
uselessIndex.emplace(input, 1);
|
||||
} else {
|
||||
++it->second;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Set reference count 1 for all net outputs.
|
||||
for (const auto& op : net->oplists) {
|
||||
for (int output : op->outputIndexes) {
|
||||
auto it = uselessIndex.find(output);
|
||||
if (it == uselessIndex.end()) {
|
||||
if (removedInputs.count(output)) {
|
||||
uselessIndex.emplace(output, 0);
|
||||
} else {
|
||||
uselessIndex.emplace(output, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool needIteration = false;
|
||||
do {
|
||||
needIteration = false;
|
||||
for (auto iter = net->oplists.begin(); iter != net->oplists.end();) {
|
||||
auto& op = *iter;
|
||||
bool useless = true;
|
||||
for (auto index : op->outputIndexes) {
|
||||
if (uselessIndex.at(index) > 0) {
|
||||
useless = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!useless) {
|
||||
iter++;
|
||||
continue;
|
||||
}
|
||||
if (!op->inputIndexes.empty()) {
|
||||
for (auto index : op->inputIndexes) {
|
||||
auto it = uselessIndex.find(index);
|
||||
MNN_ASSERT(it != uselessIndex.end());
|
||||
--it->second;
|
||||
}
|
||||
needIteration = true;
|
||||
}
|
||||
iter = net->oplists.erase(iter);
|
||||
}
|
||||
} while (needIteration);
|
||||
|
||||
return true;
|
||||
}
|
|
@@ -25,135 +25,5 @@ public:
|
|||
|
||||
virtual bool shouldDeleteOutput(const MNN::OpT* op) const = 0;
|
||||
|
||||
virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override {
|
||||
const MNN::NetT* const netPtr = net.get();
|
||||
std::set<std::string> netOutputNames;
|
||||
for (auto& t : net->outputName) {
|
||||
netOutputNames.insert(t);
|
||||
}
|
||||
std::unordered_set<int> removedInputs;
|
||||
for (auto iter = net->oplists.begin(); iter != net->oplists.end();) {
|
||||
auto& op = *iter;
|
||||
bool shouldDelete = shouldDeleteJudge(op.get(), netPtr);
|
||||
if (!shouldDelete) {
|
||||
iter++;
|
||||
continue;
|
||||
}
|
||||
bool hasOutputName = false;
|
||||
for (auto o : op->outputIndexes) {
|
||||
if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) {
|
||||
hasOutputName = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
bool hasOutputFromInput = false;
|
||||
for (auto o : op->inputIndexes) {
|
||||
if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) {
|
||||
hasOutputFromInput = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (hasOutputFromInput && hasOutputName) {
|
||||
iter++;
|
||||
continue;
|
||||
}
|
||||
bool deleteOutput = shouldDeleteOutput(op.get());
|
||||
// Find the next op
|
||||
if (op->outputIndexes.empty() || op->inputIndexes.empty()) {
|
||||
iter = net->oplists.erase(iter);
|
||||
continue;
|
||||
}
|
||||
auto originInput = op->inputIndexes[0];
|
||||
auto originOutputs = op->outputIndexes;
|
||||
if ((!deleteOutput) && hasOutputName) {
|
||||
for (auto o : originOutputs) {
|
||||
if (netOutputNames.find(net->tensorName[o]) != netOutputNames.end()) {
|
||||
net->tensorName[originInput] = net->tensorName[o];
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto subIter = net->oplists.begin(); subIter != net->oplists.end(); subIter++) {
|
||||
auto& subOp = *subIter;
|
||||
if (deleteOutput) {
|
||||
for (auto iter=subOp->inputIndexes.begin(); iter != subOp->inputIndexes.end();) {
|
||||
if (std::find(originOutputs.begin(), originOutputs.end(), *iter) != originOutputs.end()) {
|
||||
iter = subOp->inputIndexes.erase(iter);
|
||||
continue;
|
||||
}
|
||||
iter++;
|
||||
}
|
||||
} else {
|
||||
for (int v = 0; v < subOp->inputIndexes.size(); ++v) {
|
||||
if (std::find(originOutputs.begin(), originOutputs.end(), subOp->inputIndexes[v]) != originOutputs.end()) {
|
||||
subOp->inputIndexes[v] = originInput;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
bool removeUselessInput = shouldRemoveUnusefulInputs(op.get());
|
||||
if (removeUselessInput) {
|
||||
for (int input : op->inputIndexes) {
|
||||
removedInputs.emplace(input);
|
||||
}
|
||||
}
|
||||
iter = net->oplists.erase(iter);
|
||||
}
|
||||
|
||||
// Remove the op only if the reference counts of all its outputs
|
||||
// are reduced to zero.
|
||||
std::unordered_map<int, int/*reference count*/> uselessIndex;
|
||||
for (const auto& op : net->oplists) {
|
||||
for (int input : op->inputIndexes) {
|
||||
auto it = uselessIndex.find(input);
|
||||
if (it == uselessIndex.end()) {
|
||||
uselessIndex.emplace(input, 1);
|
||||
} else {
|
||||
++it->second;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Set reference count 1 for all net outputs.
|
||||
for (const auto& op : net->oplists) {
|
||||
for (int output : op->outputIndexes) {
|
||||
auto it = uselessIndex.find(output);
|
||||
if (it == uselessIndex.end()) {
|
||||
if (removedInputs.count(output)) {
|
||||
uselessIndex.emplace(output, 0);
|
||||
} else {
|
||||
uselessIndex.emplace(output, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool needIteration = false;
|
||||
do {
|
||||
needIteration = false;
|
||||
for (auto iter = net->oplists.begin(); iter != net->oplists.end();) {
|
||||
auto& op = *iter;
|
||||
bool useless = true;
|
||||
for (auto index : op->outputIndexes) {
|
||||
if (uselessIndex.at(index) > 0) {
|
||||
useless = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!useless) {
|
||||
iter++;
|
||||
continue;
|
||||
}
|
||||
if (!op->inputIndexes.empty()) {
|
||||
for (auto index : op->inputIndexes) {
|
||||
auto it = uselessIndex.find(index);
|
||||
MNN_ASSERT(it != uselessIndex.end());
|
||||
--it->second;
|
||||
}
|
||||
needIteration = true;
|
||||
}
|
||||
iter = net->oplists.erase(iter);
|
||||
}
|
||||
} while (needIteration);
|
||||
|
||||
return true;
|
||||
}
|
||||
virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override;
|
||||
};
|
||||
|
|
|
@@ -33,6 +33,11 @@ public:
|
|||
return true;
|
||||
}
|
||||
}
|
||||
if (op->type == OpType_Identity) {
|
||||
// Support 1->N
|
||||
return op->inputIndexes.size() == 1 && op->outputIndexes.size() > 1;
|
||||
}
|
||||
|
||||
if (op->type == OpType_Cast) {
|
||||
if (op->main.AsCastParam()->dstT == op->main.AsCastParam()->srcT) {
|
||||
return true;
|
||||
|
|
|
@@ -37,11 +37,21 @@ int main(int argc, const char* argv[]) {
|
|||
/**
|
||||
generate qnn .cpp and .bin
|
||||
*/
|
||||
|
||||
std::string dstModelName = dstMNN;
|
||||
size_t pos = dstModelName.find_last_of("/\\");
|
||||
std::string dstModelPath;
|
||||
if (pos == std::string::npos) {
|
||||
// current path
|
||||
dstModelPath = "./";
|
||||
} else {
|
||||
dstModelPath = dstModelName.substr(0, pos);
|
||||
}
|
||||
std::string qnnModelPath = dstModelPath + "/" + qnnModelName;
|
||||
MNN_PRINT("[Temp Product]: Qnn temp product generate at %s\n", qnnModelPath.c_str());
|
||||
MNN::ScheduleConfig config;
|
||||
config.type = MNN_FORWARD_NN;
|
||||
std::shared_ptr<Executor::RuntimeManager> rtmgr(Executor::RuntimeManager::createRuntimeManager(config));
|
||||
rtmgr->setCache(qnnModelName.c_str());
|
||||
rtmgr->setCache(qnnModelPath.c_str());
|
||||
MNN::Express::Module::Config mConfig;
|
||||
mConfig.shapeMutable = false;
|
||||
std::shared_ptr<MNN::Express::Module> m(MNN::Express::Module::load(inputNames, outputNames, srcMNN, rtmgr, &mConfig), MNN::Express::Module::destroy);
|
||||
|
@@ -64,7 +74,7 @@ int main(int argc, const char* argv[]) {
|
|||
}
|
||||
|
||||
int ret = 0;
|
||||
std::string tarBinCmd = "cd " + qnnModelName + \
|
||||
std::string tarBinCmd = "cd " + qnnModelPath + \
|
||||
" && " + \
|
||||
"tar -cf " + qnnModelName + ".bin *.raw";
|
||||
ret = system(tarBinCmd.c_str());
|
||||
|
@@ -74,10 +84,10 @@ int main(int argc, const char* argv[]) {
|
|||
}
|
||||
|
||||
std::string modelLibCmd = qnnSdkPath + "/bin/x86_64-linux-clang/qnn-model-lib-generator " + \
|
||||
"-c " + qnnModelName + "/" + qnnModelName + ".cpp " + \
|
||||
"-b " + qnnModelName + "/" + qnnModelName + ".bin " + \
|
||||
"-c " + qnnModelPath + "/" + qnnModelName + ".cpp " + \
|
||||
"-b " + qnnModelPath + "/" + qnnModelName + ".bin " + \
|
||||
"-t x86_64-linux-clang " + \
|
||||
"-o " + qnnModelName + "/lib ";
|
||||
"-o " + qnnModelPath + "/lib ";
|
||||
ret = system(modelLibCmd.c_str());
|
||||
if(ret) {
|
||||
MNN_ERROR("[Error]: qnn-model-lib-generator error!\n");
|
||||
|
@@ -86,12 +96,13 @@ int main(int argc, const char* argv[]) {
|
|||
MNN_PRINT("[Pass]: qnn-model-lib-generator success!\n");
|
||||
}
|
||||
|
||||
std::string qnnBin = dstModelPath + "/" + qnnModelName + ".bin";
|
||||
std::string binaryGenCmd = qnnSdkPath + "/bin/x86_64-linux-clang/qnn-context-binary-generator " + \
|
||||
"--model " + qnnModelName + "/lib/x86_64-linux-clang/lib" + qnnModelName + ".so " + \
|
||||
"--model " + qnnModelPath + "/lib/x86_64-linux-clang/lib" + qnnModelName + ".so " + \
|
||||
"--backend " + qnnSdkPath + "/lib/x86_64-linux-clang/libQnnHtp.so " + \
|
||||
"--binary_file " + qnnModelName + " " + \
|
||||
"--config_file " + qnnContextConfig + " " + \
|
||||
"--output_dir " + qnnModelName + "/binary";
|
||||
"--output_dir " + dstModelPath;
|
||||
ret = system(binaryGenCmd.c_str());
|
||||
if(ret) {
|
||||
MNN_ERROR("[Error]: qnn-context-binary-generator error!\n");
|
||||
|
@@ -100,7 +111,7 @@ int main(int argc, const char* argv[]) {
|
|||
MNN_PRINT("[Pass]: qnn-context-binary-generator success!\n");
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::vector<MNN::Express::Variable::Info> inputInfos(inputs.size());
|
||||
for (int i=0; i<inputInfos.size(); ++i) {
|
||||
inputInfos[i] = *inputs[i]->getInfo();
|
||||
|
@@ -122,7 +133,8 @@ int main(int argc, const char* argv[]) {
|
|||
dstNet->oplists.emplace_back(std::move(input));
|
||||
}
|
||||
|
||||
std::string npuPath = std::string("/") + qnnModelName + std::string(".bin");
|
||||
std::string npuPath = std::string("/") + qnnModelName + std::string(".bin");
|
||||
|
||||
MNN_PRINT("npu model path:%s\n", npuPath.c_str());
|
||||
/** Fuse to Op*/
|
||||
std::unique_ptr<MNN::OpT> op(new OpT);
|
||||
|
@@ -204,7 +216,7 @@ int main(int argc, const char* argv[]) {
|
|||
}
|
||||
for (int i=0; i<outputInfos.size(); ++i) {
|
||||
attr.reset(new MNN::AttributeT);
|
||||
attr->key = "o_" + std::to_string(i) + "_0";
|
||||
attr->key = "o_0_" + std::to_string(i);
|
||||
attr->tensor.reset(new BlobT);
|
||||
attr->tensor->dataType = OpCommonUtils::convertDataType(outputInfos[i].type);
|
||||
attr->tensor->dims = outputInfos[i].dim;
|
||||
|
@@ -240,7 +252,6 @@ int main(int argc, const char* argv[]) {
|
|||
outputOs.close();
|
||||
|
||||
MNN_PRINT("[All Pass]: npu model generator success!\n");
|
||||
std::string qnnBin = qnnModelName + "/binary/" + qnnModelName + ".bin";
|
||||
MNN_PRINT("[Output Product]:\nNew mnn model path: %s\nNpu model path: %s\n", dstMNN, qnnBin.c_str());
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@@ -43,6 +43,6 @@ for w in gWrong:
|
|||
print(w)
|
||||
print('TEST_NAME_MODULE: 模型测试\nTEST_CASE_AMOUNT_MODULE: {\"blocked\":0,\"failed\":%d,\"passed\":%d,\"skipped\":0}\n'%(len(gWrong), total_num - len(gWrong)))
|
||||
print('TEST_CASE={\"name\":\"Onnx转换测试\",\"failed\":%d,\"passed\":%d}\n'%(len(gWrong), total_num - len(gWrong)))
|
||||
print('Total Size: ', mnnsize,' MB, convert ', correct_num," model")
|
||||
print('Total Size: ', total_size,' MB, convert ', correct_num," model")
|
||||
if len(gWrong) > 0:
|
||||
exit(1)
|
||||
|
|
|
@@ -0,0 +1,168 @@
|
|||
import json
|
||||
import copy
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
import shutil # shutil is used for removing directories
|
||||
|
||||
def generate_all_configs(config_path, graph_name, qnn_sdk_root_path, src_model, executable_path, output_dir):
|
||||
"""
|
||||
Create a sub-directory for each combination, generate the config files, and call the C++ executable to convert the model.
|
||||
"""
|
||||
# --- 0. Preparation ---
|
||||
# Create the main output directory
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
print(f"All generated files will be saved under the main directory: '{output_dir}'")
|
||||
|
||||
# Define the (soc_id, dsp_arch) combinations
|
||||
combinations = [
|
||||
[36, 'v69'],
|
||||
[42, 'v69'],
|
||||
[43, 'v73'],
|
||||
[57, 'v75'],
|
||||
[69, 'v79']
|
||||
]
|
||||
|
||||
# --- 1. Read the template files ---
|
||||
htp_template_file = os.path.join(config_path, "htp_backend_extensions.json")
|
||||
context_template_file = os.path.join(config_path, "context_config.json")
|
||||
|
||||
try:
|
||||
with open(htp_template_file, 'r', encoding='utf-8') as f:
|
||||
base_htp_data = json.load(f)
|
||||
print(f"成功读取模板文件 '{htp_template_file}'。")
|
||||
|
||||
with open(context_template_file, 'r', encoding='utf-8') as f:
|
||||
base_context_data = json.load(f)
|
||||
print(f"成功读取模板文件 '{context_template_file}'。")
|
||||
except FileNotFoundError as e:
|
||||
print(f"错误:模板文件未找到。请确保 '{e.filename}' 存在于指定的路径中。")
|
||||
return
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"错误:文件格式无效。请检查 {e.doc} 是否为有效的JSON。")
|
||||
return
|
||||
|
||||
# --- 2. 遍历组合,生成文件并执行命令 ---
|
||||
for soc_id, dsp_arch in combinations:
|
||||
print(f"\n{'='*15} 处理组合: soc_id={soc_id}, dsp_arch={dsp_arch} {'='*15}")
|
||||
|
||||
# --- 新增步骤: 为当前组合创建专用的子目录 ---
|
||||
new_graph_name = f"{graph_name}_{soc_id}_{dsp_arch}"
|
||||
graph_specific_dir = output_dir
|
||||
|
||||
# --- Part A: 生成 htp_backend_extensions 文件 (路径更新) ---
|
||||
htp_config_data = copy.deepcopy(base_htp_data)
|
||||
try:
|
||||
htp_config_data["graphs"][0]["graph_names"] = [new_graph_name]
|
||||
htp_config_data["devices"][0]["soc_id"] = soc_id
|
||||
htp_config_data["devices"][0]["dsp_arch"] = dsp_arch
|
||||
except (IndexError, KeyError) as e:
|
||||
print(f"处理组合时出错: '{htp_template_file}' 结构不正确。错误: {e}")
|
||||
continue
|
||||
|
||||
htp_output_filename = f"htp_backend_extensions_{soc_id}_{dsp_arch}.json"
|
||||
# 更新路径,使其指向新的子目录
|
||||
htp_output_filepath = os.path.join(graph_specific_dir, htp_output_filename)
|
||||
with open(htp_output_filepath, 'w', encoding='utf-8') as f:
|
||||
json.dump(htp_config_data, f, indent=4, ensure_ascii=False)
|
||||
print(f"-> 已生成配置文件: '{htp_output_filepath}'")
|
||||
|
||||
# --- Part B: 生成 context_config 文件 (路径更新) ---
|
||||
context_config_data = copy.deepcopy(base_context_data)
|
||||
try:
|
||||
# 这里的 htp_output_filename 是相对路径,这是正确的,
|
||||
# 因为 context_config 和 htp_backend_extensions 在同一个目录中。
|
||||
context_config_data["backend_extensions"]["config_file_path"] = htp_output_filepath
|
||||
path_template = context_config_data["backend_extensions"]["shared_library_path"]
|
||||
new_lib_path = path_template.replace("{QNN_SDK_ROOT}", qnn_sdk_root_path)
|
||||
context_config_data["backend_extensions"]["shared_library_path"] = new_lib_path
|
||||
except KeyError as e:
|
||||
print(f"处理组合时出错: '{context_template_file}' 结构不正确,缺少键: {e}")
|
||||
continue
|
||||
|
||||
context_output_filename = f"context_config_{soc_id}_{dsp_arch}.json"
|
||||
# 更新路径,使其指向新的子目录
|
||||
context_output_filepath = os.path.join(graph_specific_dir, context_output_filename)
|
||||
with open(context_output_filepath, 'w', encoding='utf-8') as f:
|
||||
json.dump(context_config_data, f, indent=4, ensure_ascii=False)
|
||||
print(f"-> 已生成关联文件: '{context_output_filepath}'")
|
||||
|
||||
# --- Part C: 调用C++可执行命令 (路径更新) ---
|
||||
dst_model_filename = f"{graph_name}_{soc_id}_{dsp_arch}.mnn"
|
||||
# 更新路径,使其指向新的子目录
|
||||
dst_model_filepath = os.path.join(graph_specific_dir, dst_model_filename)
|
||||
|
||||
graph_product_dir = os.path.join(graph_specific_dir, new_graph_name)
|
||||
os.makedirs(graph_product_dir, exist_ok=True)
|
||||
print(f"-> 已创建/确认子目录: '{graph_product_dir}'")
|
||||
|
||||
command = [
|
||||
executable_path,
|
||||
src_model,
|
||||
dst_model_filepath,
|
||||
qnn_sdk_root_path,
|
||||
new_graph_name,
|
||||
context_output_filepath
|
||||
]
|
||||
|
||||
print("--> 准备执行命令...")
|
||||
print(f" $ {' '.join(command)}")
|
||||
|
||||
try:
|
||||
result = subprocess.run(command, check=True, capture_output=True, text=True)
|
||||
print("--> 命令执行成功!")
|
||||
# 即使成功,也打印 C++ 程序的输出,这对于查看警告等信息很有用
|
||||
if result.stdout:
|
||||
print(" --- C++程序输出 (stdout) ---")
|
||||
print(result.stdout.strip())
|
||||
print(" ------------------------------")
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"!!! 命令执行失败: 可执行文件未找到 '{executable_path}'。请检查路径。")
|
||||
break # 如果可执行文件找不到,直接退出循环
|
||||
except subprocess.CalledProcessError as e:
|
||||
# 这是关键的修改部分
|
||||
print(f"!!! 命令执行失败 (返回码: {e.returncode})")
|
||||
|
||||
# 检查并打印 C++ 程序在失败前产生的标准输出
|
||||
if e.stdout:
|
||||
print(" --- C++程序输出 (stdout) ---")
|
||||
print(e.stdout.strip())
|
||||
print(" ------------------------------")
|
||||
|
||||
# 检查并打印 C++ 程序在失败前产生的标准错误(错误日志通常在这里)
|
||||
if e.stderr:
|
||||
print(" --- C++程序错误 (stderr) ---")
|
||||
print(e.stderr.strip())
|
||||
print(" ------------------------------")
|
||||
except Exception as e:
|
||||
print(f"!!! 执行期间发生未知错误: {e}")
|
||||
|
||||
finally:
|
||||
# --- 步骤 3: 清理 ---
|
||||
# 检查目录是否存在,然后删除
|
||||
if os.path.exists(graph_product_dir):
|
||||
print(f"--> 清理临时文件和目录: '{graph_product_dir}'")
|
||||
shutil.rmtree(graph_product_dir)
|
||||
else:
|
||||
print("--> 无需清理,临时目录未创建。")
|
||||
|
||||
# --- 脚本执行入口 ---
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="为多个组合创建子目录,生成QNN配置文件并调用模型转换工具。",
|
||||
formatter_class=argparse.RawTextHelpFormatter
|
||||
)
|
||||
# ... (argparse部分保持完全不变) ...
|
||||
gen_group = parser.add_argument_group('文件生成参数')
|
||||
gen_group.add_argument("--config_path", required=True, help="[必需] 包含模板文件的目录路径。")
|
||||
gen_group.add_argument("--graph_name", required=True, help="[必需] 模型图的名称 (不含soc_id等后缀)。")
|
||||
gen_group.add_argument("--qnn_sdk_root_path", required=True, help="[必需] QNN SDK 的根路径。")
|
||||
|
||||
exec_group = parser.add_argument_group('模型转换参数')
|
||||
exec_group.add_argument("--src_model", required=True, help="[必需] 源模型文件路径 (例如: my_model.mnn)。")
|
||||
exec_group.add_argument("--executable_path", required=True, help="[必需] C++模型转换可执行文件的路径。")
|
||||
exec_group.add_argument("--output_dir", default="./qnn_models", help="存放所有生成文件的输出目录 (默认: ./qnn_models)。")
|
||||
|
||||
args = parser.parse_args()
|
||||
generate_all_configs(args.config_path, args.graph_name, args.qnn_sdk_root_path, args.src_model, args.executable_path, args.output_dir)
|
|
@@ -51,6 +51,21 @@ if (LLM_SUPPORT_AUDIO AND MNN_BUILD_AUDIO)
add_definitions(-DLLM_SUPPORT_AUDIO)
endif()

IF(CMAKE_SYSTEM_NAME MATCHES "^Android" AND NOT MNN_BUILD_FOR_ANDROID_COMMAND)
IF(NOT NATIVE_INCLUDE_OUTPUT)
set(NATIVE_INCLUDE_OUTPUT ".")
ENDIF()
add_custom_command(
TARGET llm
POST_BUILD
COMMAND ${CMAKE_COMMAND}
ARGS -E copy_directory ${CMAKE_CURRENT_LIST_DIR}/include ${NATIVE_INCLUDE_OUTPUT}
)
ELSE()
INSTALL(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/include/ DESTINATION include FILES_MATCHING PATTERN *.hpp)
ENDIF()

add_executable(llm_demo ${CMAKE_CURRENT_LIST_DIR}/demo/llm_demo.cpp)
add_executable(embedding_demo ${CMAKE_CURRENT_LIST_DIR}/demo/embedding_demo.cpp)
add_executable(reranker_demo ${CMAKE_CURRENT_LIST_DIR}/demo/reranker_demo.cpp)

@@ -37,6 +37,23 @@ struct TimePerformance;
using ChatMessage = std::pair<std::string, std::string>; // <role, content>
using ChatMessages = std::vector<ChatMessage>;

struct MNN_PUBLIC PromptImagePart {
MNN::Express::VARP image_data;
int width;
int height;
};

struct MNN_PUBLIC PromptAudioPart {
std::string file_path;
MNN::Express::VARP waveform;
};

struct MNN_PUBLIC MultimodalPrompt {
std::string prompt_template;
std::map<std::string, PromptImagePart> images;
std::map<std::string, PromptAudioPart> audios;
};

enum TuneType {
// op encoder number for commit
OP_ENCODER_NUMBER = 0,

@@ -108,11 +125,17 @@ public:
std::string dump_config();
bool set_config(const std::string& content);
Llm* create_lora(const std::string& lora_path);
std::string get_statistics();
// tokenier function
bool is_stop(int token);
std::string tokenizer_decode(int token);
virtual std::vector<int> tokenizer_encode(const std::string& query);
friend class Pipeline;
virtual std::vector<int> tokenizer_encode(const MultimodalPrompt& multimodal_input);
void response(const MultimodalPrompt& multimodal_input,
std::ostream* os = &std::cout,
const char* end_with = nullptr,
int max_new_tokens = -1);
const LlmContext* getContext() const {
return mContext.get();
}

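The llm.hpp hunk above introduces `PromptImagePart`, `PromptAudioPart` and `MultimodalPrompt`, plus `tokenizer_encode`/`response` overloads that accept them. A minimal usage sketch follows; the config path, image file and include paths are illustrative assumptions, and `Llm::createLLM` / `MNN::CV::imread` come from the existing MNN API rather than from this diff:

```cpp
#include <iostream>
#include <memory>
#include "llm/llm.hpp"   // include paths may differ depending on build layout
#include "cv/cv.hpp"

using namespace MNN::Transformer;

int main() {
    // Assumed: config.json of a vision-capable (Omni) model exported by llmexport.py.
    std::unique_ptr<Llm> llm(Llm::createLLM("config.json"));
    llm->load();

    MultimodalPrompt prompt;
    // The <img>...</img> placeholder is looked up by key in prompt.images.
    prompt.prompt_template = "<img>img_0</img>Describe this picture.";

    PromptImagePart image;
    image.image_data = MNN::CV::imread("demo.jpg"); // pre-decoded VARP
    image.width  = 0;   // when > 0, overrides the vision width/height used for preprocessing
    image.height = 0;
    prompt.images["img_0"] = image;

    llm->response(prompt, &std::cout, nullptr, /*max_new_tokens*/ 512);
    return 0;
}
```

Placeholders that do not match a key in `images`/`audios` fall back to the previous behaviour and are treated as file paths or URLs (see the `Omni::processImageContent` / `processAudioContent` hunks further down).
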
@@ -48,6 +48,7 @@ DiskEmbedding::DiskEmbedding(const std::shared_ptr<LlmConfig>& config, std::stri
if (mQuantBit != 16) {
if (mQuantBlock == 0) {
mBlockNum = 1;
mQuantBlock = mHiddenSize; // be used for mDequantFunc.
} else {
mBlockNum = mHiddenSize / mQuantBlock;
}

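The added `mQuantBlock = mHiddenSize;` line makes the channel-wise case (`quant_block == 0`) behave like block-wise quantization with a single block spanning the whole hidden dimension, so the dequantization routine can always index scales per block. A simplified, hypothetical sketch of that indexing (symmetric scales only; the repository's `mDequantFunc` covers more cases, e.g. 4-bit packing and zero points):

```cpp
// Hypothetical helper, for illustration only: dequantize one embedding row.
// quantBlock == 0 is treated as "one block covering the whole hidden size".
static void dequantRow(const int8_t* q, const float* scales,
                       float* out, int hiddenSize, int quantBlock) {
    int blockNum  = (quantBlock == 0) ? 1 : hiddenSize / quantBlock;
    int blockSize = hiddenSize / blockNum;
    for (int b = 0; b < blockNum; ++b) {
        float scale = scales[b];
        for (int i = 0; i < blockSize; ++i) {
            out[b * blockSize + i] = q[b * blockSize + i] * scale;
        }
    }
}
```
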
@@ -9,6 +9,7 @@
#include <fstream>
#include <iostream>
#include <sstream>
#include <iomanip>
#include <unordered_set>

#include <MNN/AutoTime.hpp>

@@ -97,11 +98,44 @@ bool Llm::set_config(const std::string& content) {
return res;
}

std::string Llm::get_statistics() {
auto context = getContext();
int prompt_len = context->prompt_len;
int decode_len = context->gen_seq_len;
int64_t vision_time = context->vision_us;
int64_t audio_time = context->audio_us;
int64_t prefill_time = context->prefill_us;
int64_t decode_time = context->decode_us;
int64_t sample_time = context->sample_us;
float vision_s = vision_time / 1e6;
float audio_s = audio_time / 1e6;
float prefill_s = prefill_time / 1e6;
float decode_s = decode_time / 1e6;
float sample_s = sample_time / 1e6;
float prefill_speed = (prefill_s > 0.0f) ? (prompt_len / prefill_s) : 0.0f;
float decode_speed = (decode_s > 0.0f) ? (decode_len / decode_s) : 0.0f;

std::ostringstream json_stream;
json_stream << "{"
<< "\"prompt_tokens\":" << prompt_len << ","
<< "\"decode_tokens\":" << decode_len << ","
<< "\"vision_time\":" << std::fixed << std::setprecision(2) << vision_s << ","
<< "\"audio_time\":" << std::fixed << std::setprecision(2) << audio_s << ","
<< "\"prefill_time\":" << std::fixed << std::setprecision(2) << prefill_s << ","
<< "\"decode_time\":" << std::fixed << std::setprecision(2) << decode_s << ","
<< "\"sample_time\":" << std::fixed << std::setprecision(2) << sample_s << ","
<< "\"prefill_speed\":" << std::fixed << std::setprecision(2) << prefill_speed << ","
<< "\"decode_speed\":" << std::fixed << std::setprecision(2) << decode_speed
<< "}";

return json_stream.str();
}

void Llm::setRuntimeHint(std::shared_ptr<Express::Executor::RuntimeManager> &rtg) {
rtg->setHint(MNN::Interpreter::INIT_THREAD_NUMBER, 4);

rtg->setHint(MNN::Interpreter::MEM_ALLOCATOR_TYPE, 0);
rtg->setHint(MNN::Interpreter::QKV_QUANT_OPTIONS, mConfig->quant_qkv());
rtg->setHint(MNN::Interpreter::QKV_QUANT_OPTIONS, mConfig->config_.value("quant_qkv", 8));
rtg->setHint(MNN::Interpreter::KVCACHE_SIZE_LIMIT, mConfig->kvcache_limit());
if (mConfig->use_cached_mmap()) {
rtg->setHint(MNN::Interpreter::USE_CACHED_MMAP, 1);

@@ -113,6 +147,8 @@ void Llm::setRuntimeHint(std::shared_ptr<Express::Executor::RuntimeManager> &rtg
if (mConfig->use_mmap()) {
rtg->setExternalPath(tmpPath, MNN::Interpreter::EXTERNAL_WEIGHT_DIR);
}
// set npu model dir
rtg->setExternalPath(mConfig->npu_model_dir(), 3);
auto dynamicOption = mConfig->dynamic_option();
if (mConfig->dynamic_option()) {
rtg->setHint(MNN::Interpreter::DYNAMIC_QUANT_OPTIONS, mConfig->dynamic_option());

@@ -246,8 +282,7 @@ void Llm::load() {
if (needHiddenState) {
outputNames.emplace_back("hidden_states");
}
// set npu model dir
mRuntimeManager->setExternalPath(mConfig->npu_model_dir(), 3);

mRuntimeManager->setExternalFile(mConfig->llm_weight());
mModules[0].reset(Module::load(inputNames, outputNames, model_path.c_str(), mRuntimeManager, &module_config));
mRuntimeManager->setExternalFile("");

@@ -594,6 +629,29 @@ std::vector<int> Llm::tokenizer_encode(const std::string& user_content) {
return mTokenizer->encode(user_content);
}

std::vector<int> Llm::tokenizer_encode(const MultimodalPrompt& multimodal_input) {
return mTokenizer->encode(multimodal_input.prompt_template);
}

void Llm::response(const MultimodalPrompt& multimodal_input,
std::ostream* os, const char* end_with, int max_new_tokens) {
auto prompt = multimodal_input.prompt_template;
if (mConfig->use_template()) {
prompt = mPrompt->applyTemplate(prompt, true);
}

int prompt_len = 0;
int decode_len = 0;
int64_t vision_time = 0;
int64_t audio_time = 0;
int64_t prefill_time = 0;
int64_t decode_time = 0;
int64_t sample_time = 0;

std::vector<int> input_ids = tokenizer_encode(multimodal_input);
response(input_ids, os, end_with, max_new_tokens);
}

std::vector<int> Llm::generate(MNN::Express::VARP input_embeds, int max_tokens) {
if (max_tokens < 0) {
max_tokens = mConfig->max_new_tokens();

@@ -621,8 +679,8 @@ std::vector<int> Llm::generate(MNN::Express::VARP input_embeds, int max_tokens)
}
{
std::ofstream outFile("logits.txt");
auto temp = outputs[0]->readMap<float>();
for (size_t i = 0; i < outputs[0]->getInfo()->size; ++i) {
auto temp = mGenerateParam->outputs[0]->readMap<float>();
for (size_t i = 0; i < mGenerateParam->outputs[0]->getInfo()->size; ++i) {
outFile << temp[i] << " "; // 每个数字后加空格
}
outFile.close();

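The new `Llm::get_statistics()` shown above serializes the per-turn counters kept in `LlmContext` (prompt and decode token counts plus vision, audio, prefill, decode and sample times) into a small JSON string, with times in seconds and speeds computed as tokens divided by the corresponding time. A hedged usage sketch (the example values in the comment are made up):

```cpp
// Assumes llm was created via Llm::createLLM(...) and loaded, as in the earlier sketch.
llm->response("Hello, who are you?");
std::string stats = llm->get_statistics();
// Example shape of the returned string (values illustrative):
// {"prompt_tokens":12,"decode_tokens":48,"vision_time":0.00,"audio_time":0.00,
//  "prefill_time":0.15,"decode_time":1.20,"sample_time":0.01,
//  "prefill_speed":80.00,"decode_speed":40.00}
std::cout << stats << std::endl;
```
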
@@ -359,10 +359,6 @@ public:
return config_.value("memory", "low");
}

int quant_qkv() const {
return config_.value("quant_qkv", 0);
}

int kvcache_limit() const {
return config_.value("kvcache_limit", -1);
}

@@ -9,6 +9,7 @@
#define _USE_MATH_DEFINES
#endif
#include <regex>
#include <algorithm>
#include <MNN/AutoTime.hpp>
#include <MNN/expr/ExecutorScope.hpp>
#include "omni.hpp"

@@ -555,8 +556,16 @@ std::vector<int> Omni::minicpmVisionProcess(VARP image) {
std::vector<int> Omni::visionProcess(const std::string& file) {
#ifdef LLM_SUPPORT_VISION
VARP image = MNN::CV::imread(file);
return visionProcess(image);
#else
return std::vector<int>(0);
#endif
}

std::vector<int> Omni::visionProcess(VARP image) {
#ifdef LLM_SUPPORT_VISION
if (image == nullptr) {
MNN_PRINT("Omni Can't open image: %s\n", file.c_str());
MNN_PRINT("Omni Can't open image\n");
return std::vector<int>(0);
}
Timer _t;

@@ -573,7 +582,7 @@ std::vector<int> Omni::visionProcess(const std::string& file) {
} else {
imgIds = defaultVisionProcess(image);
}
mContext->vision_us = _t.durationInUs();
mContext->vision_us += _t.durationInUs();
// set vision number for image idx
mVisionNum += 1;
return imgIds;

@@ -591,9 +600,19 @@ std::vector<int> Omni::audioProcess(const std::string& file) {
MNN_PRINT("Omni Can't open audio: %s\n", file.c_str());
return std::vector<int>(0);
}
// int sample_rate = load_res.second;
int wav_len = waveform->getInfo()->dim[0];
int hop_length = 160;
return audioProcess(waveform);
#else
return std::vector<int>(0);
#endif
}

std::vector<int> Omni::audioProcess(MNN::Express::VARP waveform) {
#ifdef LLM_SUPPORT_AUDIO
if (waveform == nullptr) {
MNN_PRINT("Omni Can't process audio: waveform is null\n");
return std::vector<int>(0);
}

Timer _t;
auto input_features = MNN::AUDIO::whisper_fbank(waveform);
VARP audio_embedding;

@@ -727,21 +746,30 @@ void Omni::addPositionIds(int t, int h, int w) {
}
}

std::vector<int> Omni::tokenizer_encode(const std::string& prompt) {
// split query
std::vector<int> Omni::tokenizer_encode(const MultimodalPrompt& multimodal_input) {
std::string prompt = multimodal_input.prompt_template;
// MNN_PRINT("tokenizer_encode(MultimodalPrompt) prompt: %s", prompt.c_str());
std::regex multimode_regex("<(img|audio)>(.*?)</\\1>");
std::string::const_iterator searchStart(prompt.cbegin());
std::smatch match;
std::vector<std::string> img_infos;
std::vector<int> ids{};

mPositionIds.clear();

while (std::regex_search(searchStart, prompt.cend(), match, multimode_regex)) {
// std::cout << "img match: " << match[1].str() << std::endl;
auto txt_ids = mTokenizer->encode(match.prefix().str());
addPositionIds(txt_ids.size());
ids.insert(ids.end(), txt_ids.begin(), txt_ids.end());
auto mul_ids = multimodeProcess(match[1].str(), match[2].str());
std::string mode = match[1].str();
std::string content = match[2].str();
std::vector<int> mul_ids;
if (mode == "img") {
mul_ids = processImageContent(content, multimodal_input.images);
// MNN_PRINT("tokenizer_encode(MultimodalPrompt) image mul_ids size: %lu", mul_ids.size());
} else if (mode == "audio") {
mul_ids = processAudioContent(content, multimodal_input.audios);
// MNN_PRINT("tokenizer_encode(MultimodalPrompt) audio mul_ids size: %lu", mul_ids.size());
}

ids.insert(ids.end(), mul_ids.begin(), mul_ids.end());
searchStart = match.suffix().first;
}

@@ -753,6 +781,43 @@ std::vector<int> Omni::tokenizer_encode(const std::string& prompt) {
return ids;
}

std::vector<int> Omni::tokenizer_encode(const std::string& prompt) {
MultimodalPrompt multimodal_input;
multimodal_input.prompt_template = prompt;
return tokenizer_encode(multimodal_input);
}

std::vector<int> Omni::processImageContent(const std::string& content, const std::map<std::string, PromptImagePart>& images) {
auto it = images.find(content);
if (it != images.end()) {
if (it->second.height > 0 && it->second.width > 0) {
mVisionHeight = it->second.height;
mVisionWidth = it->second.width;
}
// MNN_PRINT("processImageContent: using placeholder '%s' with size %dx%d", content.c_str(), mVisionWidth, mVisionHeight);
return visionProcess(it->second.image_data);
}
// MNN_PRINT("processImageContent: treating '%s' as file path or URL", content.c_str());
return multimodeProcess("img", content);
}

std::vector<int> Omni::processAudioContent(const std::string& content, const std::map<std::string, PromptAudioPart>& audios) {
auto it = audios.find(content);
if (it != audios.end()) {
// MNN_PRINT("processAudioContent: using placeholder '%s'", content.c_str());
if (it->second.waveform.get() != nullptr) {
return audioProcess(it->second.waveform);
} else if (!it->second.file_path.empty()) {
return audioProcess(it->second.file_path);
} else {
MNN_PRINT("processAudioContent: audio_part has no valid input\n");
return std::vector<int>(0);
}
}
// MNN_PRINT("processAudioContent: treating '%s' as file path", content.c_str());
return multimodeProcess("audio", content);
}

VARP Omni::embedding(const std::vector<int>& input_ids) {
if (input_ids.size() == 1) {
return Llm::embedding(input_ids);

@@ -845,21 +910,28 @@ VARP Omni::gen_position_ids(int seq_len) {
auto ptr = positionIds->writeMap<int>();
if (mContext->gen_seq_len > 0) {
for (int i=0; i<seq_len; ++i) {
auto pos = mContext->gen_seq_len + mPositionIds.back() + i;
// auto pos = mContext->gen_seq_len + mPositionIds.back() + i;
auto pos = mContext->all_seq_len + i;
ptr[i + 0] = pos;
ptr[i + seq_len] = pos;
ptr[i + seq_len * 2] = pos;
}
} else {
for (int i = 0; i < seq_len; i++) {
ptr[i] = mPositionIds.mT[i];
ptr[i + seq_len] = mPositionIds.mH[i];
ptr[i + seq_len * 2] = mPositionIds.mW[i];
ptr[i] = mPositionIds.mT[i] + mContext->all_seq_len;
ptr[i + seq_len] = mPositionIds.mH[i] + mContext->all_seq_len;
ptr[i + seq_len * 2] = mPositionIds.mW[i] + mContext->all_seq_len;
}
if (mTalker) {
mTalker->setPostionIds(mPositionIds);
}
}
// // dump position ids
// printf("position_ids = [");
// for (int i = 0; i < seq_len; i++) {
//     printf("%d ", ptr[i]);
// }
// printf("]\n");
return positionIds;
}

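In `Omni::tokenizer_encode(const MultimodalPrompt&)` above, `<img>...</img>` and `<audio>...</audio>` tags are matched by regex; the tag content is first looked up in the `images`/`audios` maps (a preloaded `waveform` VARP takes priority over `file_path`), and unmatched content falls back to `multimodeProcess`, i.e. the old file-path/URL handling. A short hedged sketch of the audio side (file name and placeholder key are illustrative):

```cpp
// Assumes llm is a loaded Omni model, as in the earlier sketches.
MultimodalPrompt prompt;
prompt.prompt_template = "<audio>clip_0</audio>What is said in this recording?";

PromptAudioPart audio;
audio.file_path = "question.wav";         // resolved via audioProcess(file)
// audio.waveform = precomputedWaveform;  // alternatively: a VARP waveform, which takes priority when set
prompt.audios["clip_0"] = audio;

llm->response(prompt);
```
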
@@ -105,12 +105,14 @@
virtual void load() override;
virtual std::vector<Express::VARP> forwardRaw(Express::VARP hiddenState, Express::VARP mask, Express::VARP inputPos) override;
virtual std::vector<int> tokenizer_encode(const std::string& query) override;
virtual std::vector<int> tokenizer_encode(const MultimodalPrompt& multimodal_input) override;
virtual Express::VARP embedding(const std::vector<int>& input_ids) override;
virtual Express::VARP gen_position_ids(int seq_len) override;
virtual void response(const std::vector<int>& input_ids, std::ostream* os = &std::cout, const char* end_with = nullptr, int max_new_tokens = -1) override;
virtual void setWavformCallback(std::function<bool(const float*, size_t, bool)> callback) override;
virtual void generateWavform() override;
// some models preprocess function
std::vector<int> visionProcess(VARP image);
std::vector<int> defaultVisionProcess(VARP image);
std::vector<int> qwen2VisionProcess(VARP image);
std::vector<int> smolvlmVisionProcess(VARP image);

@@ -126,6 +128,9 @@ private:
std::vector<int> multimodeProcess(const std::string& mode, std::string info);
std::vector<int> visionProcess(const std::string& file);
std::vector<int> audioProcess(const std::string& file);
std::vector<int> audioProcess(MNN::Express::VARP waveform);
std::vector<int> processImageContent(const std::string& content, const std::map<std::string, PromptImagePart>& images);
std::vector<int> processAudioContent(const std::string& content, const std::map<std::string, PromptAudioPart>& audios);
std::shared_ptr<Module> mVisionModule, mAudioModule;
std::vector<VARP> mVisionEmbeddings, mAudioEmbeddings;
std::shared_ptr<Talker> mTalker;

@@ -225,6 +225,7 @@ class LlmExporter(torch.nn.Module):
self.llm_config['jinja'] = prompt_template['jinja']
# load modules
ModelMapper.do_map(self, self.model, self.model_map['model'])

# rebuild modules
if self.lm_ is None:
out_features, in_features = self.embed_.weight.shape

@@ -244,6 +245,7 @@ class LlmExporter(torch.nn.Module):
self.embed = Embedding(embed_copy, self)
else:
self.embed = Embedding(self.embed_, self)

# tie word embeddings
self.tie_word_embeddings = not self.args.seperate_embed and self.lm_.weight.equal(self.embed_.weight)
if self.tie_word_embeddings:

@@ -802,14 +804,21 @@ class LlmExporter(torch.nn.Module):
if self.mnn_converter:
fuse_transformer = self.visual.transformer_fuse
native_group_conv = self.visual.group_conv_native
quant_bit_visual = self.visual.quant_bit
quant_block_visual = self.visual.quant_block
if self.args.transformer_fuse:
fuse_transformer = True
if self.args.group_conv_native:
native_group_conv = True
self.mnn_converter.export(vision_onnx, self.visual.quant_bit,
self.visual.quant_block,
if self.args.visual_quant_bit is not None:
quant_bit_visual = self.args.visual_quant_bit
if self.args.visual_quant_block is not None:
quant_block_visual = self.args.visual_quant_block
self.mnn_converter.export(vision_onnx, quant_bit_visual,
quant_block_visual,
transformer_fuse=fuse_transformer,
group_conv_native=native_group_conv)
group_conv_native=native_group_conv,
weight_sym=self.args.visual_sym)

def export_audio(self):
if self.audio is None:

@@ -1237,6 +1246,9 @@ def export(path,
'onnx_slim': onnx_slim,
'quant_bit': quant_bit,
'quant_block': quant_block,
'visual_quant_bit': visual_quant_bit,
'visual_quant_block': visual_quant_block,
'visual_sym': visual_sym,
'lm_quant_bit': lm_quant_bit,
'mnnconvert': mnnconvert,
'ppl': ppl,

@@ -1276,6 +1288,8 @@ def main():
parser.add_argument('--onnx_slim', action='store_true', help='Whether or not to use onnx-slim.')
parser.add_argument('--quant_bit', type=int, default=4, help='mnn quant bit, 4 or 8, default is 4.')
parser.add_argument('--quant_block', type=int, default=64, help='mnn quant block, 0 mean channle-wise, default is 64.')
parser.add_argument('--visual_quant_bit', type=int, default=None, help='mnn viusal quant bit, 4 or 8, default is setting in utils/vision.py by different vit model.')
parser.add_argument('--visual_quant_block', type=int, default=None, help='mnn quant block, default is setting in utils/vision.py by different vit model.')
parser.add_argument('--lm_quant_bit', type=int, default=None, help='mnn lm_head quant bit, 4 or 8, default is `quant_bit`.')
parser.add_argument('--mnnconvert', type=str, default='../../../build/MNNConvert', help='local mnnconvert path, if invalid, using pymnn.')
parser.add_argument('--ppl', action='store_true', help='Whether or not to get all logits of input tokens.')

@@ -1284,9 +1298,10 @@ def main():
parser.add_argument('--transformer_fuse', action='store_true', help='Whether or not to fuse vision transformer op.')
parser.add_argument('--group_conv_native', action='store_true', help='Whether or not to keep native group_conv.')
parser.add_argument('--smooth', action='store_true', help='Whether or not to use smooth quant.')
parser.add_argument('--sym', action='store_true', help='Whether or not to using symmetric quant (without zeropoint), defualt is False.')
parser.add_argument('--seperate_embed', action='store_true', help='For lm and embed shared model, whether or not to sepearte embed to avoid quant, defualt is False, if True, embed weight will be seperate to embeddingbf16.bin.')
parser.add_argument('--lora_split', action='store_true', help='Whether or not export lora split, defualt is False.')
parser.add_argument('--sym', action='store_true', help='Whether or not to using symmetric quant (without zeropoint), default is False.')
parser.add_argument('--visual_sym', action='store_true', help='Whether or not to using symmetric quant (without zeropoint) for visual model, default is False.')
parser.add_argument('--seperate_embed', action='store_true', help='For lm and embed shared model, whether or not to sepearte embed to avoid quant, default is False, if True, embed weight will be seperate to embeddingbf16.bin.')
parser.add_argument('--lora_split', action='store_true', help='Whether or not export lora split, default is False.')
parser.add_argument('--calib_data', type=str, default=None, help='calibration data path, defaut is `None` mean not use calib data.')
args = parser.parse_args()

@@ -51,7 +51,7 @@ class MNNConveter:
os.close(log_fd)

@spinner_run(f'convert onnx model to ')
def onnx2mnn(self, onnx_path, mnn_path, args = [], transformer_fuse = True, group_conv_native = False, save_external_data = True):
def onnx2mnn(self, onnx_path, mnn_path, args = [], transformer_fuse = True, group_conv_native = False, weight_sym = False, save_external_data = True):
convert_args = [
'',
'-f',

@@ -66,6 +66,8 @@ class MNNConveter:
convert_args += ['--transformerFuse']
if group_conv_native:
convert_args += ['--groupConvNative']
if weight_sym:
convert_args += ['--weightQuantAsymmetric=0']
if save_external_data:
convert_args += ['--saveExternalData']
convert_args += args

@@ -112,7 +114,7 @@ class MNNConveter:
self.convert(convert_args)
return mnn_path

def export(self, onnx_path, quant_bit = None, quant_block = None, transformer_fuse = True, group_conv_native = False):
def export(self, onnx_path, quant_bit = None, quant_block = None, transformer_fuse = True, group_conv_native = False, weight_sym = None):
self.onnx_model_path = onnx_path
self.mnn_name = os.path.basename(onnx_path).replace('.onnx', '.mnn')
self.mnn_model_path = os.path.join(self.config.args.dst_path, self.mnn_name)

@@ -133,10 +135,10 @@ class MNNConveter:
]
if quant_bit == 32:
quant_args = []
self.onnx2mnn(self.onnx_model_path, self.mnn_model_path, quant_args, transformer_fuse=transformer_fuse, group_conv_native=group_conv_native)
self.onnx2mnn(self.onnx_model_path, self.mnn_model_path, quant_args, transformer_fuse=transformer_fuse, group_conv_native=group_conv_native, weight_sym=weight_sym)
else:
mnn_json = f'{self.mnn_model_path}.json'
self.onnx2mnn(self.onnx_model_path, self.mnn_model_path, transformer_fuse=transformer_fuse, group_conv_native=group_conv_native)
self.onnx2mnn(self.onnx_model_path, self.mnn_model_path, transformer_fuse=transformer_fuse, group_conv_native=group_conv_native, weight_sym=weight_sym)
self.mnn2json(self.mnn_model_path, mnn_json)
self.rebuild(mnn_json)
self.json2mnn(mnn_json, self.mnn_model_path)

@@ -511,4 +513,4 @@ class MNNConveter:
}
if name.startswith('/expert/'):
post_reshape['main']['dims'] = [-1, oc]
return [pre_reshape, pre_convert, conv_op, post_convert, post_reshape]
return [pre_reshape, pre_convert, conv_op, post_convert, post_reshape]

@@ -266,7 +266,6 @@ class Rotary(torch.nn.Module):

def get_theta():
return 1.0 / (self.rope_theta ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float32) / self.rotary_dim))

# default rope type's theta
self.theta = get_theta()
# other type

@@ -278,7 +277,6 @@ class Rotary(torch.nn.Module):
rope_type = config.rope_scaling['type']
elif 'rope_type' in config.rope_scaling:
rope_type = config.rope_scaling['rope_type']

# gen theta for rope_type
if rope_type == 'dynamic': # NTK
if 'alpha' in config.rope_scaling: # NTKAlpha in Hunyuan

@@ -286,7 +284,7 @@ class Rotary(torch.nn.Module):
else: # NTKScaling
pass
self.theta = get_theta()
elif rope_type == 'yarn': # YaRN in gpt-oss
elif rope_type == 'yarn':
self.is_scaled = True
self.theta, self.attention_scaling = _compute_yarn_parameters(
rotary_dim=self.rotary_dim,

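For reference, the `get_theta()` helper kept in the hunks above computes the standard RoPE inverse frequencies; the `dynamic` (NTK) branch adjusts `rope_theta` and re-calls it, while the `yarn` branch replaces both the frequencies and the attention scaling via `_compute_yarn_parameters`. With $d$ denoting `rotary_dim`, the formula implemented by `get_theta()` is:

$$\theta_i = \texttt{rope\_theta}^{-2i/d}, \qquad i = 0, 1, \dots, \tfrac{d}{2} - 1.$$
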