Merge pull request #3747 from alibaba/feature/sync

MNN:Sync: Sync Internal 3.2.2
jxt1234 2025-07-23 15:45:53 +08:00 committed by GitHub
commit a739ea5870
231 changed files with 21706 additions and 1989 deletions

View File

@ -78,6 +78,7 @@ option(MNN_SUPPORT_BF16 "Enable MNN's bf16 op" OFF)
option(MNN_LOW_MEMORY "Build MNN support low memory for weight quant model." OFF)
option(MNN_CPU_WEIGHT_DEQUANT_GEMM "Build MNN CPU weight dequant related gemm kernels." OFF)
option(MNN_BUILD_AUDIO "Build audio api in MNN." OFF)
option(MNN_SME2 "Use Arm sme2 instructions" ON)
if (MNN_BUILD_MINI)
set(MNN_SKIPBUILD_GEOMETRY ON)

View File

@ -100,7 +100,7 @@ ErrorCode CPULSTM::onResize(const std::vector<Tensor *> &inputs, const std::vect
auto temp = tempBuffer->host<float>();
auto dest = dst + n * UP_DIV(timeSteps, hP) * numFeatures * hP;
MNNUnpackC4(temp, source, numFeatures, timeSteps);
MNNPackForMatMul_B(dest, temp, timeSteps, numFeatures, true);
MNNPackForMatMul_B(dest, temp, timeSteps, 1, numFeatures, true);
}
};

View File

@ -159,7 +159,8 @@ void MNNPackC4ForMatMul_A_BF16(float* destOrigin, float const** sourceGroup, con
return;
}
void MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t l, bool transpose) {
void MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose) {
auto l = kernelsize * ic;
auto hP = h / 4;
auto hR = hP * 4;
if (hR != h) {

View File

@ -71,7 +71,7 @@ static void _winograd(const DeconvolutionWithStride::ComputeUnit& unit, int thre
el[3] = 0;
size_t parameters[6];
parameters[0] = eP * sizeof(float);
parameters[1] = ic;
parameters[1] = ROUND_UP(ic, lP);
parameters[2] = oc;
parameters[3] = eP * 4 * sizeof(float);
parameters[4] = 0;
@ -130,7 +130,7 @@ static void _gemmAndIm2col(const DeconvolutionWithStride::ComputeUnit& unit, int
el[3] = 0;
size_t parameters[6];
parameters[0] = eP * sizeof(float);
parameters[1] = ic;
parameters[1] = ROUND_UP(ic, lP);
parameters[2] = oc;
parameters[3] = eP * 4 * sizeof(float);
parameters[4] = 0;

View File

@ -62,6 +62,84 @@ static void dumpCmd(const Command* cmd) {
MNN_PRINT("}\n");
}
void mergeConvolutionAndPrelu(Node* root, MNNForwardType forwardType){
if (root->cmd->op != nullptr && root->cmd->op->type() == OpType_Convolution && root->succ.size() == 1) {
auto child = root->succ[0];
if(child->cmd->op->type() == OpType_PReLU){
if(root->cmd->op->externalPath() != nullptr){
return;
}
std::shared_ptr<Command> cmdPlugin;
auto inputs = root->cmd->inputs;
auto outputs = root->cmd->outputs;
auto convOp = root->cmd->op->main_as_Convolution2D();
if(convOp->quanParameter() != nullptr || convOp->symmetricQuan() != nullptr || convOp->sparseParameter() != nullptr || convOp->external() != nullptr || convOp->common()->outputCount() != child->cmd->op->main_as_PRelu()->slopeCount()){
return;
}
std::unique_ptr<OpT> fuseOp(new OpT);
fuseOp->type = OpType_Extra;
fuseOp->name = root->cmd->op->name()->str();
ExtraT* extra_param = new ExtraT;
extra_param->type = "ExtraConvolution2DPrelu";
extra_param->attr.resize(2);
// copy convolution2D param
AttributeT* convAtr = new AttributeT;
BlobT* convParamBlob = new BlobT;
{
std::unique_ptr<Convolution2DT> convolutionParam(convOp->UnPack());
flatbuffers::FlatBufferBuilder builder;
auto lastOffset = Convolution2D::Pack(builder, convolutionParam.get());
builder.Finish(lastOffset);
const uint8_t* buffer_ptr = builder.GetBufferPointer();
const size_t size = builder.GetSize();
convParamBlob->uint8s.resize(size);
::memcpy(convParamBlob->uint8s.data(), buffer_ptr, size);
}
convAtr->tensor.reset(convParamBlob);
extra_param->attr[0].reset(convAtr);
// copy prelu param
AttributeT* preluAtr = new AttributeT;
BlobT* preluParamBlob = new BlobT;
{
std::unique_ptr<PReluT> preluParam(child->cmd->op->main_as_PRelu()->UnPack());
flatbuffers::FlatBufferBuilder builder;
auto lastOffset = PRelu::Pack(builder, preluParam.get());
builder.Finish(lastOffset);
const uint8_t* buffer_ptr = builder.GetBufferPointer();
const size_t size = builder.GetSize();
preluParamBlob->uint8s.resize(size);
::memcpy(preluParamBlob->uint8s.data(), buffer_ptr, size);
}
preluAtr->tensor.reset(preluParamBlob);
extra_param->attr[1].reset(preluAtr);
fuseOp->main.type = OpParameter_Extra;
fuseOp->main.value = extra_param;
flatbuffers::FlatBufferBuilder builder;
auto lastOffset = Op::Pack(builder, fuseOp.get());
builder.Finish(lastOffset);
cmdPlugin = GeometryComputerUtils::makeCommand(builder, inputs, outputs);
root->cmd->op = cmdPlugin->op;
root->cmd->inputs = cmdPlugin->inputs;
root->cmd->outputs = cmdPlugin->outputs;
root->cmd->buffer = cmdPlugin->buffer;
child->cmd->op = nullptr;
child->cmd->buffer.reset();
for(auto &childNode : child->succ){
for(auto &input : childNode->cmd->inputs){
if(input == child->cmd->outputs[0]){
input = root->cmd->outputs[0];
}
}
}
root->succ = child->succ;
}
}
}
// is legal fused type
bool isLegal(Command* cmd, MNNForwardType forwardType) {
auto type = cmd->op->type();
@ -369,6 +447,20 @@ bool opFuse(std::vector<Schedule::OpCacheInfo>& infos, MNNForwardType type, Back
graph.push_back(std::move(node));
}
}
if(type == MNN_FORWARD_OPENCL){
for(int i = 0; i < graph.size(); ++i){
mergeConvolutionAndPrelu(graph[i].get(), type);
}
for(auto iter = graph.begin(); iter != graph.end();){
if(iter->get()->cmd->op == nullptr){
iter = graph.erase(iter);
}else{
++iter;
}
}
}
std::queue<Node*> postDominateNodeQueue;
// build dominate tree
for (int i = static_cast<int>(graph.size()) - 1; i >= 0; i--) {

View File

@ -19,7 +19,8 @@ MNN uses CMake to build the project; the macro definitions available in CMake are listed below
| MNN_SUPPORT_QUNAT_EXTEND | Whether to build quantized versions of non-core operators, default `ON` |
| MNN_SUPPORT_DEPRECATED_OP | Whether to support deprecated operators such as TFLite quantized ops, for compatibility with legacy models (before version 1.1.0), default `OFF` |
| MNN_SUPPORT_DEPRECATED_OPV2 | Whether to build operators deprecated since MNN 3.0, for compatibility with legacy models (before version 3.0.0). For example, after 3.0.0 Convolution3D and ConvTranspose3D are converted to the corresponding 2D operators by the model converter and no longer need runtime support. Default `ON` |
| MNN_REDUCE_SIZE | Whether to trim the MNN library size by removing gradient-related operators and reducing optimization strategies, default `OFF`. When enabled, MNN_SUPPORT_QUNAT_EXTEND / MNN_SUPPORT_DEPRECATED_OP / MNN_SUPPORT_DEPRECATED_OPV2 are all set to OFF |
| MNN_REDUCE_SIZE | Whether to trim the MNN library size by removing gradient-related operators and reducing optimization strategies, default `OFF`. When enabled, MNN_SUPPORT_QUANT_EXTEND / MNN_SUPPORT_DEPRECATED_OP / MNN_SUPPORT_DEPRECATED_OPV2 are all set to OFF |
| MNN_SUPPORT_QUANT_EXTEND | Whether to enable quantized computation for Binary/Unary and similar operators, default `ON` |
| MNN_DEBUG_MEMORY | Whether to enable MNN memory debugging, default `OFF` |
| MNN_DEBUG_TENSOR_SIZE | Whether to enable MNN tensor size debugging, default `OFF` |
| MNN_GPU_TRACE | Whether to enable MNN GPU debug tracing, default `OFF` |
@ -43,6 +44,7 @@ MNN uses CMake to build the project; the macro definitions available in CMake are listed below
| MNN_OPENGL | Whether to build the `OpenGL` backend, default `OFF` |
| MNN_VULKAN | Whether to build the `Vulkan` backend, default `OFF` |
| MNN_ARM82 | When building for ARM, whether to build the `Armv8.2` backend to support FP16 computation, default `ON` |
| MNN_SME2 | When building for ARM, whether to build the `ArmSme2` backend to support computation with the SME2 instruction set, default `ON` |
| MNN_SUPPORT_FP16_ARMV7 | When building for armeabi-v7a, whether to build the `Armv8.2` backend to support FP16 computation, default `OFF` |
| MNN_ONEDNN | Whether to use `oneDNN`, default `OFF` |
| MNN_AVX2 | When `MNN_USE_SSE` is enabled, whether to add AVX2 instruction support, default `ON` |
@ -55,6 +57,8 @@ MNN uses CMake to build the project; the macro definitions available in CMake are listed below
| MNN_TENSORRT | Whether to build the `TensorRT` backend, default `OFF` |
| MNN_COREML | Whether to build the `CoreML` backend, default `OFF` |
| MNN_NNAPI | Whether to build the `NNAPI` backend, default `OFF` |
| MNN_QNN | Whether to build the `QNN` backend, default `OFF` |
| MNN_QNN_CONVERT_MODE | When `MNN_QNN` is enabled, whether to build the QNN backend in Convert mode, default `OFF` |
| MNN_BUILD_BENCHMARK | Whether to build the MNN benchmarks, default `OFF` |
| MNN_BUILD_TEST | Whether to build the MNN unit tests, default `OFF` |
| MNN_BUILD_FOR_ANDROID_COMMAND | Whether to build for `Android` from the command line, default `OFF` |

View File

@ -171,6 +171,7 @@
- `rasterDemo.out` Raster example
- `nluDemo.out` NLU model example
- `mergeInplaceForCPU` Rewrites operators in a model that can be computed in place to use in-place computation; this reduces memory usage but restricts execution to the CPU backend
- `OpenCLProgramBuildTest.out` Tests whether the OpenCL backend's programs compile successfully on the device
## Unit Tests
- Related build options
- `MNN_BUILD_TEST` Whether to build the MNN unit tests

View File

@ -3132,6 +3132,20 @@ roialign
```python
TODO
```
---
### `jsonop(inputs, describe, output_number)`
jsonop
For an operator that is supported by the MNN model format but has no corresponding expression exposed, the jsonop interface can be used to describe the operator in JSON (see the sketch below).
Parameters:
- `inputs` : List[Var] array of input variables, any type
- `describe` : str JSON description of the operator
- `output_number` : int, number of outputs of the operator
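As an illustration, a minimal sketch assuming the binding is exposed as `MNN.expr.jsonop` with the signature above; the JSON fields in `describe` are only illustrative of MNN's Op serialization and are not taken from this diff:
```python
import json
import MNN.expr as F

x = F.const([-1.0, 2.0, -3.0, 4.0], [4])

# Illustrative description of a unary ABS op; the real schema follows MNN's
# JSON serialization of an Op and may name these fields differently.
describe = json.dumps({
    "type": "UnaryOp",
    "main_type": "UnaryOp",
    "main": {"opType": "ABS"}
})

# One input variable, one output.
out = F.jsonop([x], describe, 1)
print(out)
```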
---
**The functions below are intended for framework developers and are not recommended for ordinary users!**

View File

@ -85,6 +85,8 @@ Usage:
--useGeluApproximation Use the approximate Gelu algorithm when fusing Gelu operators, default 1, i.e. `true`
--useOriginRNNImpl Whether LSTM and GRU operators use the original operator implementation, disabled by default. If enabled, performance may improve, but LSTM/GRU quantization becomes unavailable
```
**Note 1: The weightQuantBits option is used as --weightQuantBits numBits, where numBits can be 2~8. This feature only quantizes the float32 weights of conv/matmul/LSTM and only reduces model size; the weights are decoded back to float32 after the model is loaded, so runtime speed is the same as the float32 model. Internal tests show that at 8 bits accuracy is essentially lossless and the model size shrinks by 4x. Default: 0, i.e. no weight quantization.**
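For illustration, a minimal Python sketch of invoking the converter with this option; the MNNConvert path and model file names are placeholders, and the other flags are the converter's usual ones rather than anything introduced by this change:
```python
import subprocess

MNNCONVERT = "./build/MNNConvert"   # placeholder path to the converter binary

# Quantize the float32 weights of conv/matmul/LSTM to 8 bits to shrink the
# model; weights are decoded back to float32 at load time, so speed is unchanged.
subprocess.run([
    MNNCONVERT,
    "-f", "ONNX",
    "--modelFile", "model.onnx",     # placeholder input model
    "--MNNModel", "model_w8.mnn",    # placeholder output model
    "--bizCode", "demo",
    "--weightQuantBits", "8",
], check=True)
```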

View File

@ -336,6 +336,7 @@ void Executor::RuntimeManager::setCache(std::string cacheName) {
mInside->mCache.reset(new Cache);
mInside->mCache->cacheFile = cacheName;
mInside->mInfo->onSetCachePath(cacheName.c_str(), 0);
if (nullptr == mInside->mCache->cacheFile.c_str()) {
MNN_ERROR("Empty cacheFile\n");
return;

View File

@ -1339,6 +1339,18 @@ std::vector<EXPRP> Variable::getExecuteOrder(const std::vector<VARP>& outputs) {
if (nullptr == output) {
continue;
}
if (nullptr == output->expr().first) {
continue;
}
auto op = output->expr().first->get();
bool isConst = ((op && op->type() == OpType_Const) || (!op && output->expr().first->inputType() == VARP::CONSTANT));
if (isConst) {
if (!output->expr().first->visited()){
output->expr().first->setVisited(true);
sequence.emplace_back(output->expr().first);
continue;
}
}
workStack.push(output->expr().first);
}
while (!workStack.empty()) {

View File

@ -128,6 +128,10 @@ VARP _Conv(VARP weight, VARP bias, VARP x, PaddingMode pad, INTS stride, INTS di
std::unique_ptr<OpT> convOp(new OpT);
convOp->type = OpType_Convolution;
auto shape = weight->getInfo();
if (shape == nullptr) {
MNN_ERROR("Weight for convolution should have shape information.\n");
return nullptr;
}
if (NHWC == shape->order) {
weight = _Transpose(weight, {0, 3, 1, 2});
shape = weight->getInfo();
@ -268,6 +272,10 @@ VARP _Deconv(VARP weight, VARP bias, VARP x, PaddingMode pad, INTS stride, INTS
std::unique_ptr<OpT> convOp(new OpT);
convOp->type = OpType_Deconvolution;
auto shape = weight->getInfo();
if (shape == nullptr) {
MNN_ERROR("weight's info is null\n");
return nullptr;
}
auto channel = std::vector<int>{shape->dim[1], shape->dim[0]};
auto kernelSize = std::vector<int>{shape->dim[3], shape->dim[2]};
if (channel[1] * channel[0] == group) {
@ -1048,6 +1056,10 @@ VARP _BatchToSpaceND(VARP input, VARP block_shape, VARP crops) {
auto info_block_shape = block_shape->getInfo();
auto info_crops = crops->getInfo();
if (info_block_shape == nullptr || info_crops == nullptr) {
MNN_ERROR("BatchToSpaceND: block_shape or crops info is null.\n");
return nullptr;
}
MNN_ASSERT(info_block_shape != nullptr);
MNN_ASSERT(info_crops != nullptr);
MNN_ASSERT(halide_type_int == info_block_shape->type.code);
@ -1057,6 +1069,10 @@ VARP _BatchToSpaceND(VARP input, VARP block_shape, VARP crops) {
blob_blockShape->dataFormat = (MNN_DATA_FORMAT)Utils::convertFormat(info_block_shape->order);
blob_blockShape->dataType = (MNN::DataType)Utils::convertDataType(info_block_shape->type);
auto data_block_shape = block_shape->readMap<int>();
if (data_block_shape == nullptr) {
MNN_ERROR("BatchToSpaceND: block_shape data is null.\n");
return nullptr;
}
for (int i=0; i<info_block_shape->size; i++)
{
blob_blockShape->int32s.emplace_back(data_block_shape[i]);
@ -1065,6 +1081,10 @@ VARP _BatchToSpaceND(VARP input, VARP block_shape, VARP crops) {
blob_paddings->dataFormat = (MNN_DATA_FORMAT)Utils::convertFormat(info_crops->order);
blob_paddings->dataType = (MNN::DataType)Utils::convertDataType(info_crops->type);
auto data_crop = crops->readMap<int>();
if (data_crop == nullptr) {
MNN_ERROR("BatchToSpaceND: crops data is null.\n");
return nullptr;
}
for (int i=0; i<info_crops->size; i++)
{
blob_paddings->int32s.emplace_back(data_crop[i]);
@ -1178,6 +1198,10 @@ VARP _SpaceToBatchND(VARP input, VARP block_shape, VARP paddings) {
auto param = new SpaceBatchT;
auto info_block_shape = block_shape->getInfo();
auto info_paddings = paddings->getInfo();
if (info_block_shape == nullptr || info_paddings == nullptr) {
MNN_ERROR("SpaceToBatchND: block_shape or paddings info is null.\n");
return nullptr;
}
MNN_ASSERT(info_block_shape != nullptr);
MNN_ASSERT(info_paddings != nullptr);
MNN_ASSERT(halide_type_int == info_block_shape->type.code);
@ -1187,6 +1211,10 @@ VARP _SpaceToBatchND(VARP input, VARP block_shape, VARP paddings) {
blob_blockShape->dataFormat = (MNN::MNN_DATA_FORMAT)Utils::convertFormat(info_block_shape->order);
blob_blockShape->dataType = (MNN::DataType)Utils::convertDataType(info_block_shape->type);
auto data_block_shape = block_shape->readMap<int>();
if (data_block_shape == nullptr) {
MNN_ERROR("SpaceToBatchND: block_shape data is null.\n");
return nullptr;
}
for (int i=0; i<info_block_shape->size; i++)
{
blob_blockShape->int32s.emplace_back(data_block_shape[i]);
@ -1195,6 +1223,10 @@ VARP _SpaceToBatchND(VARP input, VARP block_shape, VARP paddings) {
blob_paddings->dataFormat = (MNN::MNN_DATA_FORMAT)Utils::convertFormat(info_paddings->order);
blob_paddings->dataType = (MNN::DataType)Utils::convertDataType(info_paddings->type);
auto data_paddings = paddings->readMap<int>();
if (data_paddings == nullptr) {
MNN_ERROR("SpaceToBatchND: paddings data is null.\n");
return nullptr;
}
for (int i=0; i<info_paddings->size; i++)
{
blob_paddings->int32s.emplace_back(data_paddings[i]);
@ -1235,7 +1267,10 @@ std::vector <VARP> _Unstack(VARP value, int axis) {
std::unique_ptr<OpT> op(new OpT);
op->type = OpType_Unpack;
auto info_value = value->getInfo();
MNN_ASSERT(info_value != nullptr);
if (info_value == nullptr) {
MNN_ERROR("Unstack: value info is null.\n");
return {};
}
auto dims = info_value->dim;
auto dimsize = dims.size();
MNN_ASSERT(dimsize >= 1);
@ -1703,9 +1738,12 @@ VARP _FloatToInt8(VARP x, VARP scale, int8_t minValue, int8_t maxValue, int8_t z
}
VARP _Int8ToFloat(VARP x, VARP scale) {
auto xInfo = x->getInfo();
auto scaleInfo = scale->getInfo();
auto scalePtr = scale->readMap<float>();
if (nullptr == scaleInfo) {
MNN_ERROR("Error for Int8ToFloat because var not ready\n");
return nullptr;
}
std::unique_ptr<OpT> op(new OpT);
op->type = OpType_Int8ToFloat;
op->main.type = OpParameter_QuantizedFloatParam;
@ -1716,9 +1754,21 @@ VARP _Int8ToFloat(VARP x, VARP scale) {
}
VARP _Int8ToFloat(VARP x, VARP scale, int8_t zeroPoint) {
auto xInfo = x->getInfo();
auto scaleInfo = scale->getInfo();
if (nullptr == scaleInfo) {
MNN_ERROR("Error for Int8ToFloat because var not ready\n");
return nullptr;
}
if (scaleInfo->size <= 0) {
MNN_ERROR("Error for Int8ToFloat because scale size is zero\n");
return nullptr;
}
auto scalePtr = scale->readMap<float>();
if (nullptr == scalePtr) {
MNN_ERROR("Error for Int8ToFloat because scale data is null\n");
return nullptr;
}
std::unique_ptr<OpT> op(new OpT);
op->type = OpType_Int8ToFloat;
op->main.type = OpParameter_QuantizedFloatParam;
@ -1786,13 +1836,10 @@ VARP _Sort(VARP x, int axis, bool arg, bool descend) {
auto topk = new TopKV2T;
topk->largest = descend;
op->main.value = topk;
auto shape = x->getInfo()->dim;
axis = axis < 0 ? shape.size() + axis : axis;
int k = x->getInfo()->dim[axis];
std::vector<VARP> inputs {x, _Scalar(k)};
if (axis + 1 != shape.size()) {
inputs.push_back(_Scalar(axis));
}
auto posAxis = _Mod(_Scalar(axis), _Rank(x));
auto K = _Slice(_Shape(x), _Unsqueeze(posAxis, {0}), _Unsqueeze(_Scalar<int32_t>(1), {0}));
std::vector<VARP> inputs {x, K, _Scalar(axis)};
auto expr = Expr::create(op.get(), inputs, 2);
return Variable::create(expr, arg);
}

View File

@ -202,25 +202,52 @@ Executor::ComputeCache::~ComputeCache() {
FUNC_PRINT(gInstanceCount);
#endif
}
Executor::RuntimeExecuteWrap::RuntimeExecuteWrap(const RuntimeInfo& info) : mRt(info) {
for (auto& iter : mRt.first) {
iter.second->onConcurrencyBegin();
}
}
Executor::RuntimeExecuteWrap::~RuntimeExecuteWrap() {
for (auto& iter : mRt.first) {
iter.second->onConcurrencyEnd();
}
}
ErrorCode Executor::ComputeCache::compute() {
std::stack<ComputeCache*> dfsStack;
std::set<ComputeCache*> visited;
dfsStack.push(this);
ErrorCode code = NO_ERROR;
auto globalExecutor = ExecutorScope::Current();
auto& rt = globalExecutor->mRuntimeInfo;
for (auto& iter : rt.first) {
iter.second->onConcurrencyBegin();
}
auto debug = globalExecutor->getDebugTools();
auto hasUnvisitInput = [&] (ComputeCache* cache) {
for (auto c : cache->mInputs) {
if (visited.find(c.get()) == visited.end()) {
return true;
}
}
return false;
};
// Check need compute or not
while (!dfsStack.empty()) {
//printf("stcak = %d\n", dfsStack.size());
auto cache = dfsStack.top();
dfsStack.pop();
for (auto& c : cache->mInputInside) {
if (c->mContentDirty) {
return CALL_BACK_STOP;
}
}
if (hasUnvisitInput(cache)) {
for (auto c : cache->mInputs) {
dfsStack.push(c.get());
}
}
}
// Compute
visited.clear();
dfsStack.push(this);
ErrorCode code = NO_ERROR;
auto glo = ExecutorScope::Current();
RuntimeExecuteWrap wrap(glo->mRuntimeInfo);
auto debug = glo->getDebugTools();
while (!dfsStack.empty()) {
auto cache = dfsStack.top();
if (cache->mShapeDirty) {
auto code = cache->resize();
if (NO_ERROR != code) {
@ -233,15 +260,7 @@ ErrorCode Executor::ComputeCache::compute() {
dfsStack.pop();
continue;
}
auto hasUnvisitInput = [&] () {
for (auto c : cache->mInputs) {
if (visited.find(c.get()) == visited.end()) {
return true;
}
}
return false;
};
if (hasUnvisitInput()) {
if (hasUnvisitInput(cache)) {
for (auto c : cache->mInputs) {
dfsStack.push(c.get());
}
@ -259,9 +278,6 @@ ErrorCode Executor::ComputeCache::compute() {
cache->mContentDirty = false;
}
}
for (auto& iter : rt.first) {
iter.second->onConcurrencyEnd();
}
return NO_ERROR;
}
ErrorCode Executor::ComputeCache::resizeImpl() {

View File

@ -82,7 +82,13 @@ public:
static Tensor* getTensor(VARP var);
static EXPRP makeRaster(const std::vector<VARP>& vars, const std::vector<int>& regions, const std::vector<int>& shape, halide_type_t dataType, MNN_DATA_FORMAT format);
};
class Executor::RuntimeExecuteWrap {
public:
RuntimeExecuteWrap(const RuntimeInfo& info);
~RuntimeExecuteWrap();
private:
const RuntimeInfo& mRt;
};
} // namespace Express
} // namespace MNN
#endif

View File

@ -216,12 +216,10 @@ public:
auto glo = ExecutorScope::Current();
glo->getDebugTools()->flops = 0.0f;
#endif
for (auto& iter : mInfo->runTimeManager->getInside()->mRuntime.first) {
iter.second->onConcurrencyBegin();
}
auto outputs = mModule->onForward(inputs);
for (auto& iter : mInfo->runTimeManager->getInside()->mRuntime.first) {
iter.second->onConcurrencyEnd();
std::vector<VARP> outputs;
{
Executor::RuntimeExecuteWrap wrap(mInfo->runTimeManager->getInside()->mRuntime);
outputs = mModule->onForward(inputs);
}
#ifdef MNN_INTERNAL_ENABLED
do {

View File

@ -775,6 +775,10 @@ Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::
std::set<int> outputIndexes;
std::map<int, int> stackMap;
std::map<std::string, int> outputIndexesMap;
std::set<std::string> outputNameSet;
for (auto name : outputs) {
outputIndexesMap.insert(std::make_pair(name, -1));
}
for (int i=0; i<net->tensorName()->size(); ++i) {
auto tname = net->tensorName()->GetAsString(i)->str();
for (int j=0; j<inputs.size(); ++j) {
@ -784,20 +788,20 @@ Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::
break;
}
}
for (int j=0; j<outputs.size(); ++j) {
if (tname == outputs[j]) {
outputIndexes.emplace(i);
outputIndexesMap.insert(std::make_pair(tname, i));
break;
}
auto outputIter = outputIndexesMap.find(tname);
if (outputIter != outputIndexesMap.end()) {
outputIndexes.emplace(i);
outputIter->second = i;
}
}
if (outputIndexesMap.size() != outputs.size()) {
MNN_ERROR("PipelineModule:: Can't find enough output from the model, finded is:\n");
for (auto& iter : outputIndexesMap) {
MNN_ERROR("[ %s ] ", iter.first.c_str());
bool valid = true;
for (auto& iter : outputIndexesMap) {
if (iter.second == -1) {
MNN_ERROR("PipelineModule:: Can't find output %s from the model:\n", iter.first.c_str());
valid = false;
}
MNN_ERROR("\n");
}
if (!valid) {
return nullptr;
}
bool divideSuccess = true;
@ -810,6 +814,7 @@ Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::
for (int i=0; i<subModulesInfo.size(); ++i) {
subModules[i].reset(_createSubModule(bufferStorage, subModulesInfo[i], subGraphMap, sharedConst, *config, modRuntime));
}
bool needReplaceBackend = false;
if (preReplaceConstTensor) {
// Prereplace const tensor
auto curBackend = sharedConst->constReplaceBackend.get();
@ -838,6 +843,7 @@ Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::
if (!tempRes) {
continue;
}
needReplaceBackend = true;
outDes->stageMask |= Tensor::InsideDescribe::CONVERTED_STAGE;
WrapExecution::copyReplaceTensor(wrapTensor.get(), t.get());
}
@ -848,6 +854,11 @@ Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::
for (auto index : noneedComputeIndexes) {
auto tensor = Tensor::clone(sharedConst->allTensors[index].get());
auto constVar = Variable::create(Expr::create(tensor, true));
auto back = TensorUtils::getDescribeOrigin(tensor)->getBackend();
auto x =sharedConst->constReplaceBackend.get();
if (needReplaceBackend && TensorUtils::getDescribeOrigin(tensor)->getBackend() == sharedConst->constReplaceBackend.get()) {
constVar->expr().first->inside()->mHoldBackend = sharedConst->constReplaceBackend;
}
initVars.insert(std::make_pair(index, constVar));
}
auto result = new PipelineModule;

View File

@ -216,8 +216,9 @@ public:
// Geometry Compute option, default is 0xFFFF
GEOMETRY_COMPUTE_MASK = 4,
// 0: Close dynamic quant;
// default 0
// 1: For general convolution, use one scale&zeropoint to quant.
// 2: use block-quant for input data.
DYNAMIC_QUANT_OPTIONS = 5,
// For Mobile CPU with big-litter core, set decrease rate to let MNN divide task differential by CPU's performance
@ -247,8 +248,11 @@ public:
// Multi-Thread Load module, default is 0 (don't use other Thread)
INIT_THREAD_NUMBER = 13,
// CPU core ids
// Used CPU ids
CPU_CORE_IDS = 14,
// set CPU threads to use when supports Arm sme2
CPU_SME2_INSTRUCTIONS = 15
};
enum ExternalPathType {

View File

@ -76,6 +76,6 @@ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \
#define STR(x) STR_IMP(x)
#define MNN_VERSION_MAJOR 3
#define MNN_VERSION_MINOR 2
#define MNN_VERSION_PATCH 1
#define MNN_VERSION_PATCH 2
#define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." STR(MNN_VERSION_PATCH)
#endif /* MNNDefine_h */

View File

@ -27,6 +27,7 @@ struct ExecutorAttr;
class MNN_PUBLIC Executor {
public:
class ComputeCache;
class RuntimeExecuteWrap;
struct DebugTools;
/**Internal Usage Begin*/
struct Requirement {

View File

@ -21,7 +21,7 @@ cd ../
rm -rf ios_32
mkdir ios_32
cd ios_32
cmake ../../../ -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../../../cmake/ios.toolchain.cmake -DMNN_METAL=ON -DARCHS="armv7;armv7s" -DENABLE_BITCODE=0 -DMNN_AAPL_FMWK=1 -DMNN_SEP_BUILD=0 -DMNN_BUILD_SHARED_LIBS=false -DMNN_USE_THREAD_POOL=OFF $1
cmake ../../../ -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../../../cmake/ios.toolchain.cmake -DMNN_METAL=ON -DARCHS="armv7;armv7s" -DENABLE_BITCODE=0 -DMNN_AAPL_FMWK=1 -DMNN_SEP_BUILD=0 -DMNN_BUILD_SHARED_LIBS=false -DMNN_USE_THREAD_POOL=OFF $*
echo "Building AArch32"
make MNN -j16
echo "End Building AArch32"

View File

@ -0,0 +1,41 @@
#!/bin/sh
echo "Change directory to MNN_SOURCE_ROOT/project/ios before running this script"
echo "Current PWD: ${PWD}"
rm -rf MNN-iOS-CPU-GPU
mkdir MNN-iOS-CPU-GPU
cd MNN-iOS-CPU-GPU
# Static Begin
mkdir Static
cd Static
COMMON="-DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../../../cmake/ios.toolchain.cmake -DENABLE_BITCODE=0 -DMNN_AAPL_FMWK=1 -DMNN_SEP_BUILD=0 -DMNN_BUILD_SHARED_LIBS=false -DMNN_USE_THREAD_POOL=OFF"
rm -rf ios_64
mkdir ios_64
cd ios_64
cmake ../../../ ${COMMON} -DMNN_METAL=ON -DARCHS="arm64" $*
echo "Building AArch64"
make MNN -j16
echo "End Building AArch64"
cd ../
rm -rf ios_32
mkdir ios_32
cd ios_32
cmake ../../../ ${COMMON} -DMNN_METAL=OFF -DPLATFORM=SIMULATOR64 -DARCHS="x86_64" $*
echo "Building Simulator64"
make MNN -j16
echo "End Building Simulator64"
cd ../
mv ios_32/MNN.framework/MNN ios_32/MNN.framework/MNN_32
echo "Creating Fat Binary"
lipo -create ios_32/MNN.framework/MNN_32 ios_64/MNN.framework/MNN -output ios_32/MNN.framework/MNN
rm ios_32/MNN.framework/MNN_32
echo "Patching Framework Headers"
rm -rf ./MNN.framework
cp -R ios_32/MNN.framework ./MNN.framework
rm -rf ios_32
rm -rf ios_64

View File

@ -706,12 +706,15 @@
9558333D29B0947300488807 /* MNNGelu.S in Sources */ = {isa = PBXBuildFile; fileRef = 9558333C29B0947300488807 /* MNNGelu.S */; };
9558334729B09A2300488807 /* MNNGelu.S in Sources */ = {isa = PBXBuildFile; fileRef = 9558334629B09A2300488807 /* MNNGelu.S */; };
9558334B29B09A7B00488807 /* MNNGeluFP16.S in Sources */ = {isa = PBXBuildFile; fileRef = 9558334A29B09A7B00488807 /* MNNGeluFP16.S */; settings = {COMPILER_FLAGS = "-march=armv8.2-a+fp16"; }; };
955AD7522E1FB44E0099F26C /* MoEModule.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 955AD7512E1FB44E0099F26C /* MoEModule.cpp */; };
955AD7532E1FB44E0099F26C /* MoEModule.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 955AD7502E1FB44E0099F26C /* MoEModule.hpp */; };
9560EAD62BDE426A00C8D0B6 /* GeometryLayernorm.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 9560EAD52BDE426A00C8D0B6 /* GeometryLayernorm.cpp */; };
956F52E12AB2D692004B13D9 /* ImageProcessUtils.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 956F52E02AB2D692004B13D9 /* ImageProcessUtils.cpp */; };
956F52E32AB2D6A1004B13D9 /* ImageProcessUtils.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 956F52E22AB2D6A1004B13D9 /* ImageProcessUtils.hpp */; };
95772DCF2C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM82.S in Sources */ = {isa = PBXBuildFile; fileRef = 95772DCD2C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM82.S */; };
95772DD02C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM86.S in Sources */ = {isa = PBXBuildFile; fileRef = 95772DCE2C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM86.S */; };
958375352A496E5C007C0A3E /* MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S in Sources */ = {isa = PBXBuildFile; fileRef = 958375342A496E5C007C0A3E /* MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S */; };
959F15A92E2782F800C67803 /* CountMinMaxValue_FP16.S in Sources */ = {isa = PBXBuildFile; fileRef = 959F15A82E2782F800C67803 /* CountMinMaxValue_FP16.S */; settings = {COMPILER_FLAGS = "-march=armv8.2-a+fp16"; }; };
C43C81FA251894A600A0FF84 /* CommonOptFunctionNeon.cpp in Sources */ = {isa = PBXBuildFile; fileRef = C43C81F8251894A500A0FF84 /* CommonOptFunctionNeon.cpp */; };
C43C81FE251894BD00A0FF84 /* CPUPlugin.cpp in Sources */ = {isa = PBXBuildFile; fileRef = C43C81FB251894BD00A0FF84 /* CPUPlugin.cpp */; };
C43C81FF251894BD00A0FF84 /* ThreadPool.cpp in Sources */ = {isa = PBXBuildFile; fileRef = C43C81FC251894BD00A0FF84 /* ThreadPool.cpp */; };
@ -754,10 +757,6 @@
CE072A282C91AF0700F190FD /* MNNC3ToXYZFast.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A252C91AF0700F190FD /* MNNC3ToXYZFast.S */; };
CE072A2A2CAA50DE00F190FD /* MNNDepthwiseConvFastKernel.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A292CAA50DE00F190FD /* MNNDepthwiseConvFastKernel.S */; };
CE072A2C2CAA510F00F190FD /* MNNDepthwiseConvFastKernelFP16.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A2B2CAA510F00F190FD /* MNNDepthwiseConvFastKernelFP16.S */; settings = {COMPILER_FLAGS = "-march=armv8.2-a+fp16"; }; };
CE0AD4E42E1FB106002013A8 /* CountMinMaxValue_FP16.S in Sources */ = {isa = PBXBuildFile; fileRef = CE0AD4E32E1FB106002013A8 /* CountMinMaxValue_FP16.S */; settings = {COMPILER_FLAGS = "-march=armv8.2-a+fp16"; }; };
CE0AD4E82E1FB152002013A8 /* MoEModule.hpp in Headers */ = {isa = PBXBuildFile; fileRef = CE0AD4E62E1FB152002013A8 /* MoEModule.hpp */; };
CE0AD4E92E1FB152002013A8 /* ModuleInside.hpp in Headers */ = {isa = PBXBuildFile; fileRef = CE0AD4E52E1FB152002013A8 /* ModuleInside.hpp */; };
CE0AD4EA2E1FB152002013A8 /* MoEModule.cpp in Sources */ = {isa = PBXBuildFile; fileRef = CE0AD4E72E1FB152002013A8 /* MoEModule.cpp */; };
CE125CC82A52BF6B003698C9 /* MNNBilinearSampleC8.S in Sources */ = {isa = PBXBuildFile; fileRef = CE125CC62A52BF6B003698C9 /* MNNBilinearSampleC8.S */; };
CE125CC92A52BF6B003698C9 /* MNNBilinearLineC8.S in Sources */ = {isa = PBXBuildFile; fileRef = CE125CC72A52BF6B003698C9 /* MNNBilinearLineC8.S */; };
CE31C7C12D783CBB00741F49 /* WorkerThread.cpp in Sources */ = {isa = PBXBuildFile; fileRef = CE31C7C02D783CBB00741F49 /* WorkerThread.cpp */; };
@ -1543,12 +1542,15 @@
9558333C29B0947300488807 /* MNNGelu.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNGelu.S; sourceTree = "<group>"; };
9558334629B09A2300488807 /* MNNGelu.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNGelu.S; sourceTree = "<group>"; };
9558334A29B09A7B00488807 /* MNNGeluFP16.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNGeluFP16.S; path = ../../../arm82/asm/arm64/MNNGeluFP16.S; sourceTree = "<group>"; };
955AD7502E1FB44E0099F26C /* MoEModule.hpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.h; path = MoEModule.hpp; sourceTree = "<group>"; };
955AD7512E1FB44E0099F26C /* MoEModule.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = MoEModule.cpp; sourceTree = "<group>"; };
9560EAD52BDE426A00C8D0B6 /* GeometryLayernorm.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = GeometryLayernorm.cpp; sourceTree = "<group>"; };
956F52E02AB2D692004B13D9 /* ImageProcessUtils.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ImageProcessUtils.cpp; sourceTree = "<group>"; };
956F52E22AB2D6A1004B13D9 /* ImageProcessUtils.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = ImageProcessUtils.hpp; sourceTree = "<group>"; };
95772DCD2C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM82.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPackC4Int8ForMatMulA_ARM82.S; sourceTree = "<group>"; };
95772DCE2C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM86.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPackC4Int8ForMatMulA_ARM86.S; sourceTree = "<group>"; };
958375342A496E5C007C0A3E /* MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S; path = arm/arm64/MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S; sourceTree = "<group>"; };
959F15A82E2782F800C67803 /* CountMinMaxValue_FP16.S */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.asm; name = CountMinMaxValue_FP16.S; path = ../../../arm82/asm/arm64/CountMinMaxValue_FP16.S; sourceTree = "<group>"; };
C43C81F8251894A500A0FF84 /* CommonOptFunctionNeon.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = CommonOptFunctionNeon.cpp; sourceTree = "<group>"; };
C43C81FB251894BD00A0FF84 /* CPUPlugin.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = CPUPlugin.cpp; sourceTree = "<group>"; };
C43C81FC251894BD00A0FF84 /* ThreadPool.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ThreadPool.cpp; sourceTree = "<group>"; };
@ -1591,10 +1593,6 @@
CE072A252C91AF0700F190FD /* MNNC3ToXYZFast.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNC3ToXYZFast.S; path = arm/arm64/MNNC3ToXYZFast.S; sourceTree = "<group>"; };
CE072A292CAA50DE00F190FD /* MNNDepthwiseConvFastKernel.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNDepthwiseConvFastKernel.S; sourceTree = "<group>"; };
CE072A2B2CAA510F00F190FD /* MNNDepthwiseConvFastKernelFP16.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNDepthwiseConvFastKernelFP16.S; path = ../../../arm82/asm/arm64/MNNDepthwiseConvFastKernelFP16.S; sourceTree = "<group>"; };
CE0AD4E32E1FB106002013A8 /* CountMinMaxValue_FP16.S */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.asm; name = CountMinMaxValue_FP16.S; path = ../arm82/asm/arm64/CountMinMaxValue_FP16.S; sourceTree = "<group>"; };
CE0AD4E52E1FB152002013A8 /* ModuleInside.hpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.h; path = ModuleInside.hpp; sourceTree = "<group>"; };
CE0AD4E62E1FB152002013A8 /* MoEModule.hpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.h; path = MoEModule.hpp; sourceTree = "<group>"; };
CE0AD4E72E1FB152002013A8 /* MoEModule.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = MoEModule.cpp; sourceTree = "<group>"; };
CE125CC62A52BF6B003698C9 /* MNNBilinearSampleC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNBilinearSampleC8.S; sourceTree = "<group>"; };
CE125CC72A52BF6B003698C9 /* MNNBilinearLineC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNBilinearLineC8.S; sourceTree = "<group>"; };
CE31C7BF2D783CBB00741F49 /* WorkerThread.hpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.h; path = WorkerThread.hpp; sourceTree = "<group>"; };
@ -1859,6 +1857,7 @@
488873A8215B639D0079B12E /* source */ = {
isa = PBXGroup;
children = (
CE482EF5288536DA007CD935 /* internal */,
4DF87C482887D3560003E2D4 /* calib3d */,
4D4CF4612760946500A36D9F /* imgproc */,
4D9A931B26255BDA00F9B43C /* coreml */,
@ -1926,7 +1925,6 @@
48887410215B639D0079B12E /* cpu */ = {
isa = PBXGroup;
children = (
CE0AD4E32E1FB106002013A8 /* CountMinMaxValue_FP16.S */,
CEA3C8892D6D71E1003EFAD2 /* CPUStft.hpp */,
CEA3C88A2D6D71E1003EFAD2 /* CPUStft.cpp */,
CE072A242C91AF0700F190FD /* MNNC3ToC4Fast.S */,
@ -2211,9 +2209,8 @@
48C84B6F250F711600EE7666 /* module */ = {
isa = PBXGroup;
children = (
CE0AD4E52E1FB152002013A8 /* ModuleInside.hpp */,
CE0AD4E62E1FB152002013A8 /* MoEModule.hpp */,
CE0AD4E72E1FB152002013A8 /* MoEModule.cpp */,
955AD7502E1FB44E0099F26C /* MoEModule.hpp */,
955AD7512E1FB44E0099F26C /* MoEModule.cpp */,
48C84B71250F711600EE7666 /* PipelineModule.cpp */,
48C84B72250F711600EE7666 /* Module.cpp */,
48C84B73250F711600EE7666 /* WhileModule.hpp */,
@ -2643,6 +2640,7 @@
4896D37025FE2A6A00717702 /* MNNExpFP16.S */,
4896D37125FE2A6A00717702 /* MNNPackedMatMulFP16.S */,
4896D37225FE2A6A00717702 /* MNNPackedMatMulRemainFP16.S */,
959F15A82E2782F800C67803 /* CountMinMaxValue_FP16.S */,
11A01A0A258785FB00745FA7 /* MNNVectorTop1Float.S */,
11A01A0B258785FB00745FA7 /* MNNVectorTop1Int32.S */,
48034566254157DF004738E3 /* MNNNV21ToBGRAUnit.S */,
@ -2892,6 +2890,7 @@
CEA82BDC2A15F8AD002CBC95 /* IdstConvolutionInt8.hpp in Headers */,
4DE4E82C275E307B0016A916 /* cv in Headers */,
1F501F842397BA5B004E8721 /* ImageProcess.hpp in Headers */,
CECF8C5D299CACFD00D3875B /* Log.hpp in Headers */,
1F501F822397BA5B004E8721 /* Interpreter.hpp in Headers */,
C4F906B327688C3A0026B847 /* NMSModule.hpp in Headers */,
1F501F882397BA5B004E8721 /* Tensor.hpp in Headers */,
@ -2902,6 +2901,7 @@
48C84B98250F71E900EE7666 /* CPUSoftmax.hpp in Headers */,
4882C8B8241A22B800DAC168 /* OpCommonUtils.hpp in Headers */,
48608B54250632EC00CB1D71 /* GeometryComputer.hpp in Headers */,
CECF8C7A299CAD9400D3875B /* sha1.h in Headers */,
4894C6EC27016F7200D8BE79 /* CPUResizeCache.hpp in Headers */,
92FF04A623AA0BFB00AC97F6 /* FileLoader.hpp in Headers */,
48F34733273A7C8400C45394 /* ImageProcessFunction.hpp in Headers */,
@ -2915,6 +2915,7 @@
48925F352744AC0700919B37 /* CPUROIAlign.hpp in Headers */,
92FF029623AA0B5A00AC97F6 /* CPUCast.hpp in Headers */,
4D9A937826255BDA00F9B43C /* CoreMLBinary.hpp in Headers */,
CECF8C85299CAD9400D3875B /* log_util.h in Headers */,
4D6D7FD52656896600F80814 /* DenseConvolutionTiledExecutor.hpp in Headers */,
4D9A936626255BDA00F9B43C /* CoreMLExecutor.h in Headers */,
92FF027A23AA0B5A00AC97F6 /* CPUPool.hpp in Headers */,
@ -2923,6 +2924,7 @@
1F501F802397BA5B004E8721 /* MNNDefine.h in Headers */,
19D0FE76285C66F200B74B1A /* MetalLayerNorm.hpp in Headers */,
489D7A682550FDC800AD896A /* MetalReduction.hpp in Headers */,
CECF8C86299CAD9400D3875B /* sds.h in Headers */,
1F501F7F2397BA5B004E8721 /* HalideRuntime.h in Headers */,
92FF029E23AA0B5A00AC97F6 /* CPUDeconvolutionDepthwise.hpp in Headers */,
4D9A935B26255BDA00F9B43C /* NeuralNetwork.pb-c.h in Headers */,
@ -2944,8 +2946,10 @@
481C2DEE25FE2CD6001ED6DF /* Arm82Functions.hpp in Headers */,
4894C6EA27016F7200D8BE79 /* UnaryUtils.hpp in Headers */,
EBD4842A2485FF650083CE95 /* Arm82Interp.hpp in Headers */,
CECF8C81299CAD9400D3875B /* log_util_imp.h in Headers */,
92FF037623AA0B5A00AC97F6 /* CPUBinary.hpp in Headers */,
4D9A935826255BDA00F9B43C /* FeatureTypes.pb-c.h in Headers */,
CECF8C7C299CAD9400D3875B /* hmac-sha.h in Headers */,
48608B53250632EC00CB1D71 /* GeometryComputerUtils.hpp in Headers */,
489D7A732550FDC800AD896A /* MetalBackend.hpp in Headers */,
92FF03AC23AA0B5A00AC97F6 /* ResizeFunction.h in Headers */,
@ -2968,10 +2972,12 @@
4D9A937526255BDA00F9B43C /* CoreMLCommonExecution.hpp in Headers */,
4DF87C522887D3F20003E2D4 /* CPUSvd.hpp in Headers */,
48747D4B245D9D24000B9709 /* RuntimeFactory.hpp in Headers */,
CECF8C77299CAD9400D3875B /* log_builder.h in Headers */,
4D9A937226255BDA00F9B43C /* CoreMLConvolution.hpp in Headers */,
92FF038B23AA0B5A00AC97F6 /* CPUUnravelIndex.hpp in Headers */,
4AF4FB26269ED235005BA97B /* SparseConvInt8TiledExecutor.hpp in Headers */,
CEA49AA92AFD010900971CB7 /* MetalExecution.hpp in Headers */,
955AD7532E1FB44E0099F26C /* MoEModule.hpp in Headers */,
92FF03BC23AA0B5A00AC97F6 /* OptimizedComputer.hpp in Headers */,
95278CE72B9F0999009E9B29 /* CPUDynamicQuant.hpp in Headers */,
48C84BA0250F725600EE7666 /* InitNet.hpp in Headers */,
@ -3005,6 +3011,7 @@
92FF03CA23AA0B5A00AC97F6 /* CPUConvolutionDepthwise.hpp in Headers */,
92FF04A923AA0BFB00AC97F6 /* Schedule.hpp in Headers */,
489D7A9F2550FDC900AD896A /* MetalConvolutionCommon.hpp in Headers */,
CECF8C80299CAD9400D3875B /* lz4.h in Headers */,
92FF028623AA0B5A00AC97F6 /* CPUDeconvolution.hpp in Headers */,
489D7A722550FDC800AD896A /* MetalReLU6.hpp in Headers */,
92FF04B523AA0BFB00AC97F6 /* TensorUtils.hpp in Headers */,
@ -3045,8 +3052,6 @@
92FF026023AA0B5A00AC97F6 /* CPURNNSequenceGRU.hpp in Headers */,
48747D4F245D9E13000B9709 /* CPURaster.hpp in Headers */,
489D7A822550FDC900AD896A /* MetalPReLU.hpp in Headers */,
CE0AD4E82E1FB152002013A8 /* MoEModule.hpp in Headers */,
CE0AD4E92E1FB152002013A8 /* ModuleInside.hpp in Headers */,
48C84B84250F711700EE7666 /* WhileModule.hpp in Headers */,
92FF02A923AA0B5A00AC97F6 /* CPUCropAndResize.hpp in Headers */,
4D6D7FD92656897200F80814 /* SparseConvolutionTiledExecutor.hpp in Headers */,
@ -3058,20 +3063,24 @@
92FF03A623AA0B5A00AC97F6 /* ConvolutionTiledExecutor.hpp in Headers */,
92FF036523AA0B5A00AC97F6 /* CPUResize.hpp in Headers */,
92FF04B423AA0BFB00AC97F6 /* MNNMemoryUtils.h in Headers */,
CECF8C88299CAD9400D3875B /* log_api.h in Headers */,
4A224A0D27D0C2D9000A9260 /* ConvolutionPackWinograd.hpp in Headers */,
4A224A0E27D0C2D9000A9260 /* ConvolutionPackFreeWinograd.hpp in Headers */,
4D9A937426255BDA00F9B43C /* CoreMLReduction.hpp in Headers */,
48C84B8B250F711700EE7666 /* PipelineModule.hpp in Headers */,
F41497D7278D8A21004A363A /* RuntimeAttr.hpp in Headers */,
CECF8C5B299CACFD00D3875B /* LogHelper.hpp in Headers */,
92FF04C123AA0BFB00AC97F6 /* Backend.hpp in Headers */,
482BFBCD28351BA1009210E4 /* ShaderMap.hpp in Headers */,
489D7A812550FDC900AD896A /* MetalPooling.hpp in Headers */,
CECF8C7F299CAD9400D3875B /* md5.h in Headers */,
92FF02A623AA0B5A00AC97F6 /* CPUQuantizedMaxPool.hpp in Headers */,
92FF028023AA0B5A00AC97F6 /* CPUFloatToInt8.hpp in Headers */,
92FF028723AA0B5A00AC97F6 /* CPUFixedPoint.hpp in Headers */,
C43C8227251894F400A0FF84 /* Vec.hpp in Headers */,
4819FB1D24C138DF0050BD09 /* GeometryConvUtils.hpp in Headers */,
489D7A952550FDC900AD896A /* MetalMatMul.hpp in Headers */,
CECF8C83299CAD9400D3875B /* log_define.h in Headers */,
C48CAE2628900C4A00271A6D /* ConvInt8Winograd.hpp in Headers */,
48F34730273A7C7300C45394 /* CPUImageProcess.hpp in Headers */,
489D7A702550FDC800AD896A /* MetalRaster.hpp in Headers */,
@ -3392,6 +3401,7 @@
48F34734273A7C8400C45394 /* ImageProcessFunction.cpp in Sources */,
6A131E4025823349002EC3D6 /* PluginKernel.cpp in Sources */,
48958781268EBA6F00EA01A7 /* CPUSegmentMean.cpp in Sources */,
CECF8C7B299CAD9400D3875B /* sha1.c in Sources */,
4D9A937026255BDA00F9B43C /* CoreMLUnary.cpp in Sources */,
92FF04A823AA0BFB00AC97F6 /* AutoTime.cpp in Sources */,
92FF04AE23AA0BFB00AC97F6 /* Backend.cpp in Sources */,
@ -3418,7 +3428,6 @@
48925F342744AC0700919B37 /* CPUROIAlign.cpp in Sources */,
4896D36925FE2A3D00717702 /* Arm82Unary.cpp in Sources */,
4DCF53902892B17100B5B393 /* ShapeHistogram.cpp in Sources */,
CE0AD4EA2E1FB152002013A8 /* MoEModule.cpp in Sources */,
92FF043423AA0B7100AC97F6 /* ShapeStridedSlice.cpp in Sources */,
4896D37825FE2A6B00717702 /* MNNExpFP16.S in Sources */,
4D4CF46B2760946500A36D9F /* draw.cpp in Sources */,
@ -3447,6 +3456,7 @@
92FF03CE23AA0B5A00AC97F6 /* CPUOPRegister.cpp in Sources */,
92FF02B323AA0B5A00AC97F6 /* CPUInstanceNorm.cpp in Sources */,
4819FB2C24C1396A0050BD09 /* GeometryPoolGrad.cpp in Sources */,
CECF8C7E299CAD9400D3875B /* log_builder.cpp in Sources */,
92FF042223AA0B7100AC97F6 /* ShapeConcat.cpp in Sources */,
4D6D7FD12656891400F80814 /* MNNPackedSparseMatMulEpx4.S in Sources */,
4D5662CC299B76ED0031C1A1 /* MNNMaxPoolInt8.S in Sources */,
@ -3520,15 +3530,17 @@
92FF041A23AA0B7100AC97F6 /* ShapeFill.cpp in Sources */,
EB45C776244D7C6600E28F44 /* MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S in Sources */,
4AF4FB29269ED244005BA97B /* MNNPackedSparseQuantMatMulEpx1.S in Sources */,
CE0AD4E42E1FB106002013A8 /* CountMinMaxValue_FP16.S in Sources */,
4D759B2C25FF89EE0037B0B6 /* GeometryShape.cpp in Sources */,
11A01A07258785EA00745FA7 /* MNNVectorTop1Float.S in Sources */,
48747D6E245D9E33000B9709 /* GeometrySlice.cpp in Sources */,
CE072A272C91AF0700F190FD /* MNNC3ToC4Fast.S in Sources */,
CECF8C7D299CAD9400D3875B /* md5.c in Sources */,
92FF041923AA0B7100AC97F6 /* ShapeQuantizedMaxPool.cpp in Sources */,
92FF038A23AA0B5A00AC97F6 /* CPURange.cpp in Sources */,
CE072A182C91AEE700F190FD /* MNNGRAYToC4Fast.S in Sources */,
CE125CC92A52BF6B003698C9 /* MNNBilinearLineC8.S in Sources */,
955AD7522E1FB44E0099F26C /* MoEModule.cpp in Sources */,
959F15A92E2782F800C67803 /* CountMinMaxValue_FP16.S in Sources */,
92FF03A123AA0B5A00AC97F6 /* Int8FunctionsOpt.cpp in Sources */,
CE072A222C91AEE700F190FD /* MNNPackC2.S in Sources */,
92FF026523AA0B5A00AC97F6 /* CPUQuantizedAvgPool.cpp in Sources */,
@ -3588,8 +3600,10 @@
92FF042E23AA0B7100AC97F6 /* ShapeProposal.cpp in Sources */,
92FF025923AA0B5A00AC97F6 /* CPUPoolInt8.cpp in Sources */,
92FF045B23AA0B7100AC97F6 /* ShapeShape.cpp in Sources */,
CECF8C87299CAD9400D3875B /* sds.c in Sources */,
9560EAD62BDE426A00C8D0B6 /* GeometryLayernorm.cpp in Sources */,
4D6D7FD72656896D00F80814 /* SparseConvolutionTiledExecutor.cpp in Sources */,
CECF8C82299CAD9400D3875B /* log_api.cpp in Sources */,
92FF03A823AA0B5A00AC97F6 /* WinogradOptFunction.cpp in Sources */,
4A224A0C27D0C2D9000A9260 /* ConvolutionPackWinograd.cpp in Sources */,
92FF044123AA0B7100AC97F6 /* ShapeMoments.cpp in Sources */,
@ -3597,6 +3611,7 @@
4D9A936026255BDA00F9B43C /* Model.pb-c.c in Sources */,
CE9AFED628E54E3300566949 /* CPUInterp3D.cpp in Sources */,
C4F906B427688C3A0026B847 /* NMSModule.cpp in Sources */,
CECF8C64299CAD8400D3875B /* LogHelper.mm in Sources */,
48FA474523AA127B00172C3B /* Executor.cpp in Sources */,
92FF02EA23AA0B5A00AC97F6 /* MNNGemmInt8AddBiasScale_16x4_Unit.S in Sources */,
CE072A162C91AEE700F190FD /* MNNBGRAToBGR.S in Sources */,
@ -3624,6 +3639,7 @@
CE072A2C2CAA510F00F190FD /* MNNDepthwiseConvFastKernelFP16.S in Sources */,
EBECA3A724643D5D0062C7A3 /* MNNQuantizeFP16_UNIT4.S in Sources */,
92FF04A423AA0BFB00AC97F6 /* Interpreter.cpp in Sources */,
CECF8C5C299CACFD00D3875B /* Log.cpp in Sources */,
92FF045623AA0B7100AC97F6 /* ShapeReshape.cpp in Sources */,
92FF032523AA0B5A00AC97F6 /* MNNConvDwF23SourceTransUnit.S in Sources */,
92FF044423AA0B7100AC97F6 /* ShapeLSTM.cpp in Sources */,
@ -3659,6 +3675,7 @@
92FF02B623AA0B5A00AC97F6 /* CPUUnary.cpp in Sources */,
92FF032723AA0B5A00AC97F6 /* MNNDeconvRunForUnitDepthWise.S in Sources */,
CE7DC00028E2DE6B00797689 /* ShapeConvTranspose3D.cpp in Sources */,
CECF8C78299CAD9400D3875B /* log_util_imp.cpp in Sources */,
CE072A152C91AEE700F190FD /* MNNRGBAToGRAYFast.S in Sources */,
92FF02CA23AA0B5A00AC97F6 /* MNNUnPackC4.S in Sources */,
952298B22B4D39050043978B /* MetalLoop.mm in Sources */,
@ -3683,11 +3700,13 @@
92FF02FF23AA0B5A00AC97F6 /* MNNFloat2Int8.S in Sources */,
4D9A937926255BDA00F9B43C /* CoreMLRaster.cpp in Sources */,
48417FF224D13BF50056D9A7 /* GeometrySelect.cpp in Sources */,
CECF8C84299CAD9400D3875B /* lz4.c in Sources */,
489D7A7E2550FDC900AD896A /* MNNMetalContext.mm in Sources */,
92FF033423AA0B5A00AC97F6 /* MNNUInt8ToInt16WithOffsetC4Common.S in Sources */,
92FF036B23AA0B5A00AC97F6 /* CPUResize.cpp in Sources */,
92FF02C723AA0B5A00AC97F6 /* MNNCopyC4WithStride.S in Sources */,
92FF030923AA0B5A00AC97F6 /* MNNNV21ToBGRUnit.S in Sources */,
CECF8C79299CAD9400D3875B /* hmac-sha.cpp in Sources */,
92FF04C023AA0BFB00AC97F6 /* Tensor.cpp in Sources */,
CEE9B95B2A3AA4D4006438F2 /* MNNBilinearLineC8.S in Sources */,
92FF045D23AA0B7100AC97F6 /* ShapeCast.cpp in Sources */,
@ -4016,7 +4035,7 @@
CODE_SIGN_STYLE = Automatic;
DEAD_CODE_STRIPPING = YES;
DEFINES_MODULE = YES;
DEVELOPMENT_TEAM = 6G7464HHUS;
DEVELOPMENT_TEAM = Q48UX93J22;
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
@ -4081,7 +4100,7 @@
CODE_SIGN_STYLE = Automatic;
DEAD_CODE_STRIPPING = YES;
DEFINES_MODULE = YES;
DEVELOPMENT_TEAM = 6G7464HHUS;
DEVELOPMENT_TEAM = Q48UX93J22;
DYLIB_COMPATIBILITY_VERSION = 1;
DYLIB_CURRENT_VERSION = 1;
DYLIB_INSTALL_NAME_BASE = "@rpath";
@ -4144,7 +4163,7 @@
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
ASSETCATALOG_COMPILER_LAUNCHIMAGE_NAME = LaunchImage;
CODE_SIGN_STYLE = Automatic;
DEVELOPMENT_TEAM = 6G7464HHUS;
DEVELOPMENT_TEAM = Q48UX93J22;
GCC_ENABLE_CPP_EXCEPTIONS = NO;
GCC_ENABLE_CPP_RTTI = NO;
GCC_PREPROCESSOR_DEFINITIONS = "MNN_REDUCE_SIZE=1";
@ -4160,8 +4179,7 @@
IPHONEOS_DEPLOYMENT_TARGET = 9.0;
LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)";
PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcdeve;
"PRODUCT_BUNDLE_IDENTIFIER[sdk=iphoneos*]" = com.taobao.mnn.abcdes;
PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcdeve00;
PRODUCT_NAME = "$(TARGET_NAME)";
TARGETED_DEVICE_FAMILY = "1,2";
};
@ -4173,7 +4191,7 @@
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
ASSETCATALOG_COMPILER_LAUNCHIMAGE_NAME = LaunchImage;
CODE_SIGN_STYLE = Automatic;
DEVELOPMENT_TEAM = 6G7464HHUS;
DEVELOPMENT_TEAM = Q48UX93J22;
GCC_ENABLE_CPP_EXCEPTIONS = NO;
GCC_ENABLE_CPP_RTTI = NO;
GCC_PREPROCESSOR_DEFINITIONS = "MNN_REDUCE_SIZE=1";
@ -4189,8 +4207,7 @@
IPHONEOS_DEPLOYMENT_TARGET = 9.0;
LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)";
PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcdeve;
"PRODUCT_BUNDLE_IDENTIFIER[sdk=iphoneos*]" = com.taobao.mnn.abcdes;
PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcdeve00;
PRODUCT_NAME = "$(TARGET_NAME)";
TARGETED_DEVICE_FAMILY = "1,2";
};
@ -4304,3 +4321,4 @@
};
rootObject = 0F1465AE1FA18D1000F9860A /* Project object */;
}

View File

@ -16,7 +16,7 @@ option(PYMNN_TRAIN_API "MNN train API be exposed" OFF)
option(PYMNN_INTERNAL_SERVING "Internal use only." OFF)
option(PYMNN_OPENCV_API "MNN OpenCV API be exposed" ON)
option(PYMNN_IMGCODECS "MNN IMGCODECS API be exposed" OFF)
option(PYMNN_AUDIO_API "MNN Audio API be exposed" ON)
option(PYMNN_AUDIO_API "MNN Audio API be exposed" OFF)
option(PYMNN_OHOS_INTERNAL "compile for harmony internal." OFF)
if (PYMNN_OHOS_INTERNAL)
@ -202,7 +202,11 @@ else()
endif()
export_headers(DIR ${CMAKE_SOURCE_DIR}/pip_package/MNN)
else()
target_link_libraries(mnnpybridge PRIVATE log MNN MNN_Express MNNOpenCV MNNAudio)
if (PYMNN_AUDIO_API)
target_link_libraries(mnnpybridge PRIVATE log MNN MNN_Express MNNOpenCV MNNAudio)
else()
target_link_libraries(mnnpybridge PRIVATE log MNN MNN_Express MNNOpenCV)
endif()
if(PYMNN_USE_ALINNPYTHON)
target_link_libraries(mnnpybridge PRIVATE AliNNPython)
endif()

View File

@ -88,7 +88,7 @@ def build_deps():
if USE_OPENCL:
extra_opts += ' -DMNN_OPENCL=ON'
if USE_LLM:
extra_opts += ' -DMNN_BUILD_LLM=ON -DMNN_LOW_MEMORY=ON -DMNN_SUPPORT_TRANSFORMER_FUSE=ON'
extra_opts += ' -DMNN_BUILD_LLM=ON -DMNN_LOW_MEMORY=ON -DMNN_SUPPORT_TRANSFORMER_FUSE=ON -DLLM_SUPPORT_VISION=ON -DLLM_SUPPORT_AUDIO=ON'
if USE_ARM82:
extra_opts += ' -DMNN_ARM82=ON'
extra_opts += ' -DMNN_USE_THREAD_POOL=OFF -DMNN_OPENMP=ON' if USE_OPENMP else ' -DMNN_USE_THREAD_POOL=ON -DMNN_OPENMP=OFF'

View File

@ -1,5 +1,6 @@
#include <sstream>
#include "llm/llm.hpp"
#include "cpp/getLinearInput.hpp"
typedef struct {
PyObject_HEAD
@ -146,6 +147,61 @@ static PyObject* PyMNNLLM_reset(LLM *self, PyObject *args) {
Py_RETURN_NONE;
}
static PyObject* PyMNNLLM_enable_collection_mode(LLM *self, PyObject *args) {
if (self->is_embedding) {
Py_RETURN_NONE;
}
int mode = 0;
const char* output_file = NULL;
float target_sparsity = 0.5;
if (!PyArg_ParseTuple(args, "i|sf", &mode, &output_file, &target_sparsity)) {
PyErr_SetString(PyExc_ValueError, "Invalid arguments. Usage: enable_collection_mode(mode, output_file=None, target_sparsity=0.5)");
Py_RETURN_NONE;
}
std::string filename;
switch (mode) {
case 1: {
// Threshold mode
if (output_file == NULL) {
filename = "thresholds.json";
} else {
filename = std::string(output_file);
}
MNN::LinearInput::initGetThreshold(filename, target_sparsity);
MNN_PRINT("Enabled threshold collection mode. Output: %s, Sparsity: %.2f\n",
filename.c_str(), target_sparsity);
break;
}
case 2: {
// MaxValue mode
if (output_file == NULL) {
filename = "max_values.json";
} else {
filename = std::string(output_file);
}
MNN::LinearInput::initGetMaxValue(filename);
MNN_PRINT("Enabled max value collection mode. Output: %s\n", filename.c_str());
break;
}
default: {
PyErr_SetString(PyExc_ValueError, "Invalid mode. Use 1 for threshold collection, 2 for max value collection");
Py_RETURN_NONE;
}
}
return toPyObj(true);
}
static PyMethodDef PyMNNLLM_methods[] = {
{"load", (PyCFunction)PyMNNLLM_load, METH_VARARGS, "load model."},
{"forward", (PyCFunction)PyMNNLLM_forward, METH_VARARGS, "forward `logits` by `input_ids`."},
@ -159,6 +215,7 @@ static PyMethodDef PyMNNLLM_methods[] = {
{"create_lora", (PyCFunction)PyMNNLLM_create_lora, METH_VARARGS, "create_lora."},
{"set_config", (PyCFunction)PyMNNLLM_set_config, METH_VARARGS, "set_config."},
{"reset", (PyCFunction)PyMNNLLM_reset, METH_VARARGS, "reset."},
{"enable_collection_mode", (PyCFunction)PyMNNLLM_enable_collection_mode, METH_VARARGS, "Enable data collection mode."},
{NULL} /* Sentinel */
};
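For reference, a usage sketch of the new `enable_collection_mode` binding from Python; the `MNN.llm.create`/`load`/`response` calls and the config path are assumptions based on the existing pymnn LLM API rather than part of this diff:
```python
import MNN.llm as llm

model = llm.create("./model/config.json")   # placeholder config path
model.load()

# Mode 1: collect activation thresholds for linear-layer inputs and write them
# to thresholds.json, targeting a sparsity of 0.5.
model.enable_collection_mode(1, "thresholds.json", 0.5)

# Run some prompts so the collection hooks observe real activations.
model.response("Hello, MNN!")

# Mode 2 records per-tensor max values instead:
# model.enable_collection_mode(2, "max_values.json")
```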

View File

@ -30,7 +30,7 @@ void Arm82MNNPackForMatMul_A(float* destOrigin, float const** sourceGroup, const
// C(UP_DIV(h,8), e, h8) = B(UP_DIV(h,hP), l, hP) * A(l, eP), hP = 24
// parameter: [aStride, l, h, cStride, bExtraStride]
// aStride in parameter is deprecated (useless), but for code clean, just retain it
void MNNPackedMatMulFP16(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias);
void MNNPackedMatMulFP16(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
// C(UP_DIV(h,8), e, h8) = B(UP_DIV(h,hP), l, hP) * A(l, e), hP = 24, e >= 1
// parameter: [aStride, l, h, cStride, bExtraStride]
@ -52,10 +52,16 @@ void MNNDynamicQuantFP16_Pack8(const float* src, int8_t* dst, const float* scale
void MNNDynamicQuantFP16_Pack4(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, const float* bias, size_t pack);
void MNNGeneralIm2col_Arm82(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack);
void MNNGeneralIm2col_Arm86(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack);
#ifdef MNN_SME2
void MNNGeneralIm2col_Fp16Sme2(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack);
#endif
void MNNLocalMinMaxFP16_Pack4(float* dstMin, float* dstMax, const float* source, size_t blockNum, size_t blockLU, size_t EP, size_t LP, size_t loadDstBuffer);
void MNNLocalMinMaxFP16_Pack8(float* dstMin, float* dstMax, const float* source, size_t blockNum, size_t blockLU, size_t EP, size_t LP, size_t loadDstBuffer);
#endif // MNN_LOW_MEMORY
void CountMinMaxValue_FP16(float* source, float* minVal, float* maxVal, size_t sizeQuad);
#ifdef MNN_SME2
void MNNPackedMatMulRemainFP16_SME2(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
#endif
#endif
#if defined(__aarch64__)
@ -72,6 +78,12 @@ void MNNConvRunForLineDepthwiseFP16(float* dst, const float* src, const float* w
namespace MNN {
static void Sme2MNNGetMatMulPackMode(int* eP, int *lP, int* hP) {
*hP = 64;
*eP = 16;
*lP = 2;
}
static void MNNMatrixAddFP16(FLOAT16* C, const FLOAT16* A, const FLOAT16* B, size_t widthC8, size_t cStride, size_t aStride, size_t bStride, size_t height) {
for (int y = 0; y < height; ++y) {
auto a = A + aStride * y, b = B + bStride * y;
@ -103,18 +115,23 @@ static void ARM82CountMinMaxValue(float* source, float* minVal, float* maxVal, s
max_ = ((__fp16*)maxVal)[0];
min_ = ((__fp16*)minVal)[0];
}
if (remain > 0) {
int16_t tmp[8] = {0};
auto srcRemain = reinterpret_cast<uint8_t*>(source) + 8 * (size / 8) * 2;
::memcpy(tmp, srcRemain, remain * 2);
CountMinMaxValue_FP16((float*)tmp, (float*)tmp, (float*)((uint8_t*)tmp + 2), 1);
max_ = ALIMAX(((__fp16*)tmp)[1], max_);
min_ = ALIMIN(((__fp16*)tmp)[0], min_);
auto srcPtr = reinterpret_cast<__fp16*>(source);
while (remain) {
max_ = ALIMAX(srcPtr[0], max_);
min_ = ALIMIN(srcPtr[0], min_);
srcPtr += 1;
remain--;
}
reinterpret_cast<__fp16*>(minVal)[0] = min_;
reinterpret_cast<__fp16*>(maxVal)[0] = max_;
}
}
#ifdef MNN_SME2
//(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias)
static void MNNPackedMatMulFP16_SME2(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b) {
MNNPackedMatMulRemainFP16_SME2(C, A, B, 16, parameter, postParameters, bias, k, b);
}
#endif
#else
static void ARM82CountMinMaxValue(float* source, float* minVal, float* maxVal, size_t size) {
auto srcPtr = (FLOAT16*)source;
@ -134,7 +151,8 @@ static void ARM82CountMinMaxValue(float* source, float* minVal, float* maxVal, s
}
#endif
static void Arm82MNNPackForMatMul_B(float* destC, const float* sourceC, size_t h, size_t l, bool transpose) {
static void Arm82MNNPackForMatMul_B(float* destC, const float* sourceC, size_t h, size_t kernelsize, size_t ic, bool transpose) {
auto l = kernelsize * ic;
auto dest = (int16_t*)destC;
auto source = (int16_t*)sourceC;
int ePack, lPack, hPack;
@ -169,6 +187,36 @@ static void Arm82MNNPackForMatMul_B(float* destC, const float* sourceC, size_t h
}
}
static void Sme2MNNPackForMatMul_B(float* destC, const float* sourceC, size_t h, size_t kernelsize, size_t ic ,bool transpose) {
auto dest = (int16_t*)destC;
auto source = (int16_t*)sourceC;
int LP = 2;
int HP = 64;
auto l = kernelsize * ic;
memset(dest, 0, ROUND_UP(h, HP) * ROUND_UP(ic, LP) * kernelsize * sizeof(FLOAT16));
auto stride0 = ROUND_UP(ic, LP) * kernelsize * HP;
auto stride1 = HP * ROUND_UP(ic, LP);
auto stride2 = HP * LP;
size_t srcStride0 = l; // [h,k2,ic]->[hu,k2,ic/lp,hp,lp]
size_t srcStride1 = 1;
if (!transpose) { // [k2,ic,h]->[hu,k2,ic/lp,hp,lp]
srcStride0 = 1;
srcStride1 = h;
}
for (int y = 0; y < h; ++y) {
auto yHu = y / HP;
auto yHp = y % HP;
for (int k = 0; k < kernelsize; ++k) {
for (int x = 0; x < ic; ++x) {
auto xLu = x / LP;
auto xLp = x % LP;
dest[yHu * stride0 + k * stride1 + xLu * stride2 + yHp * LP + xLp] = source[y * srcStride0 + (x + k * ic) * srcStride1];
}
}
}
}
static void MNNScaleAndAddBiasFP16(FLOAT16* dst, const FLOAT16* src, const FLOAT16* bias, const FLOAT16* alpha, size_t planeNumber,
size_t biasNumber) {
for (int z = 0; z < biasNumber; ++z) {
@ -221,7 +269,7 @@ static void MNNGridSampleComputeCordFP16(FLOAT16* dst, const FLOAT16* src, size_
::memcpy(dst, tempDst, areaRemain * 2 * sizeof(int16_t));
}
static void MNNGridSampleComputeCord3DFp16(FLOAT* dst, const FLOAT* src, size_t inD, size_t inH, size_t inW, size_t outD, size_t outH, size_t outW, size_t strideD, size_t strideH, bool alignCorners) {
static void MNNGridSampleComputeCord3DFp16(FLOAT* dst, const FLOAT* src, size_t inD, size_t inH, size_t inW, size_t outD, size_t outH, size_t outW, bool alignCorners) {
float16x8_t zero = vdupq_n_f16(0);
float16x8_t one = vdupq_n_f16(1);
float16x8_t half = vdupq_n_f16(0.5f);
@ -882,6 +930,499 @@ static void _ArmBasicMNNPackC4ForMatMul_A_L8(int8_t* destOrigin, int8_t const**
}
}
static void Sme2MNNPackForMatMul_A_FP16(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el) {
int LP = 2;
int pack = 8;
int eDest = 16;
// Note: LP (= 2) is smaller than pack (= 8) here; the loops below regroup pack-wide source blocks into LP-wide column pairs.
int number = info[0];
int eReal = info[1];
int offset = info[3];
for (int n=0; n<number; ++n) {
int eWork = el[4 * n + 0];
int lWork = el[4 * n + 1];
int eOffset = el[4 * n + 2];
int lOffset = el[4 * n + 3];
auto sourceN = (FLOAT16*)(sourceGroup[n]);
auto destN = (FLOAT16*)destOrigin + lOffset * eDest + eOffset * LP;
auto srcStride0 = pack * offset;
auto dstStride0 = eDest * LP;
auto l = lWork;
while (l > 7) {
auto source = sourceN;
auto dest = destN;
l -= 8;
auto e = eWork;
if (e == eDest) {
auto s0 = vld1q_f32((float*)(source)); // 00112233
auto s1 = vld1q_f32((float*)(source + srcStride0));// 00112233
auto s2 = vld1q_f32((float*)(source + 2 * srcStride0));
auto s3 = vld1q_f32((float*)(source + 3 * srcStride0));
auto s4 = vld1q_f32((float*)(source + 4 * srcStride0));
auto s5 = vld1q_f32((float*)(source + 5 * srcStride0));
auto s6 = vld1q_f32((float*)(source + 6 * srcStride0));
auto s7 = vld1q_f32((float*)(source + 7 * srcStride0));
auto s8 = vld1q_f32((float*)(source + 8 * srcStride0));
auto s9 = vld1q_f32((float*)(source + 9 * srcStride0));
auto s10 = vld1q_f32((float*)(source + 10 * srcStride0));
auto s11 = vld1q_f32((float*)(source + 11 * srcStride0));
auto s12 = vld1q_f32((float*)(source + 12 * srcStride0));
auto s13 = vld1q_f32((float*)(source + 13 * srcStride0));
auto s14 = vld1q_f32((float*)(source + 14 * srcStride0));
auto s15 = vld1q_f32((float*)(source + 15 * srcStride0));
auto zip1s01 = vzip1q_f32(s0, s1); // 00001111
auto zip1s23 = vzip1q_f32(s2, s3); // 00001111
auto zip1s45 = vzip1q_f32(s4, s5); // 00001111
auto zip1s67 = vzip1q_f32(s6, s7); // 00001111
auto zip1s89 = vzip1q_f32(s8, s9); // 00001111
auto zip1s1011 = vzip1q_f32(s10, s11); // 00001111
auto zip1s1213 = vzip1q_f32(s12, s13); // 00001111
auto zip1s1415 = vzip1q_f32(s14, s15); // 00001111
auto zip2s01 = vzip2q_f32(s0, s1); // 22223333
auto zip2s23 = vzip2q_f32(s2, s3); // 22223333
auto zip2s45 = vzip2q_f32(s4, s5); // 22223333
auto zip2s67 = vzip2q_f32(s6, s7); // 22223333
auto zip2s89 = vzip2q_f32(s8, s9); // 22223333
auto zip2s1011 = vzip2q_f32(s10, s11); // 22223333
auto zip2s1213 = vzip2q_f32(s12, s13); // 22223333
auto zip2s1415 = vzip2q_f32(s14, s15); // 22223333
auto zip1s0123_01 = vzip1q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 00000000
auto zip1s4567_01 = vzip1q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);
auto zip1s891011_01 = vzip1q_f64((float64x2_t)zip1s89, (float64x2_t)zip1s1011);
auto zip1s12131415_01 = vzip1q_f64((float64x2_t)zip1s1213, (float64x2_t)zip1s1415);
auto zip2s0123_01 = vzip2q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 11111111
auto zip2s4567_01 = vzip2q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);
auto zip2s891011_01 = vzip2q_f64((float64x2_t)zip1s89, (float64x2_t)zip1s1011);
auto zip2s12131415_01 = vzip2q_f64((float64x2_t)zip1s1213, (float64x2_t)zip1s1415);
auto zip1s0123_23 = vzip1q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 22222222
auto zip1s4567_23 = vzip1q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);
auto zip1s891011_23 = vzip1q_f64((float64x2_t)zip2s89, (float64x2_t)zip2s1011);
auto zip1s12131415_23 = vzip1q_f64((float64x2_t)zip2s1213, (float64x2_t)zip2s1415);
auto zip2s0123_23 = vzip2q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 33333333
auto zip2s4567_23 = vzip2q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);
auto zip2s891011_23 = vzip2q_f64((float64x2_t)zip2s89, (float64x2_t)zip2s1011);
auto zip2s12131415_23 = vzip2q_f64((float64x2_t)zip2s1213, (float64x2_t)zip2s1415);
vst1q_f64((float64_t*)dest, zip1s0123_01);
vst1q_f64((float64_t*)(dest + 8), zip1s4567_01);
vst1q_f64((float64_t*)(dest + 16), zip1s891011_01);
vst1q_f64((float64_t*)(dest + 24), zip1s12131415_01);
vst1q_f64((float64_t*)(dest + dstStride0), zip2s0123_01);
vst1q_f64((float64_t*)(dest + dstStride0 + 8), zip2s4567_01);
vst1q_f64((float64_t*)(dest + dstStride0 + 16), zip2s891011_01);
vst1q_f64((float64_t*)(dest + dstStride0 + 24), zip2s12131415_01);
vst1q_f64((float64_t*)(dest + 2 * dstStride0), zip1s0123_23);
vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 8), zip1s4567_23);
vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 16), zip1s891011_23);
vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 24), zip1s12131415_23);
vst1q_f64((float64_t*)(dest + 3 * dstStride0), zip2s0123_23);
vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 8), zip2s4567_23);
vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 16), zip2s891011_23);
vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 24), zip2s12131415_23);
// dest += (4 * dstStride0);
// e -= eDest;
sourceN += (eReal * pack);
destN += (4 * dstStride0);
continue;
}
if (e > 11) {
auto s0 = vld1q_f32((float*)(source)); // 00112233
auto s1 = vld1q_f32((float*)(source + srcStride0));// 00112233
auto s2 = vld1q_f32((float*)(source + 2 * srcStride0));
auto s3 = vld1q_f32((float*)(source + 3 * srcStride0));
auto s4 = vld1q_f32((float*)(source + 4 * srcStride0));
auto s5 = vld1q_f32((float*)(source + 5 * srcStride0));
auto s6 = vld1q_f32((float*)(source + 6 * srcStride0));
auto s7 = vld1q_f32((float*)(source + 7 * srcStride0));
auto s8 = vld1q_f32((float*)(source + 8 * srcStride0));
auto s9 = vld1q_f32((float*)(source + 9 * srcStride0));
auto s10 = vld1q_f32((float*)(source + 10 * srcStride0));
auto s11 = vld1q_f32((float*)(source + 11 * srcStride0));
auto zip1s01 = vzip1q_f32(s0, s1); // 00001111
auto zip1s23 = vzip1q_f32(s2, s3); // 00001111
auto zip1s45 = vzip1q_f32(s4, s5); // 00001111
auto zip1s67 = vzip1q_f32(s6, s7); // 00001111
auto zip1s89 = vzip1q_f32(s8, s9); // 00001111
auto zip1s1011 = vzip1q_f32(s10, s11); // 00001111
auto zip2s01 = vzip2q_f32(s0, s1); // 22223333
auto zip2s23 = vzip2q_f32(s2, s3); // 22223333
auto zip2s45 = vzip2q_f32(s4, s5); // 22223333
auto zip2s67 = vzip2q_f32(s6, s7); // 22223333
auto zip2s89 = vzip2q_f32(s8, s9); // 22223333
auto zip2s1011 = vzip2q_f32(s10, s11); // 22223333
auto zip1s0123_01 = vzip1q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 00000000
auto zip1s4567_01 = vzip1q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);
auto zip1s891011_01 = vzip1q_f64((float64x2_t)zip1s89, (float64x2_t)zip1s1011);
auto zip2s0123_01 = vzip2q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 11111111
auto zip2s4567_01 = vzip2q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);
auto zip2s891011_01 = vzip2q_f64((float64x2_t)zip1s89, (float64x2_t)zip1s1011);
auto zip1s0123_23 = vzip1q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 22222222
auto zip1s4567_23 = vzip1q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);
auto zip1s891011_23 = vzip1q_f64((float64x2_t)zip2s89, (float64x2_t)zip2s1011);
auto zip2s0123_23 = vzip2q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 33333333
auto zip2s4567_23 = vzip2q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);
auto zip2s891011_23 = vzip2q_f64((float64x2_t)zip2s89, (float64x2_t)zip2s1011);
vst1q_f64((float64_t*)dest, zip1s0123_01);
vst1q_f64((float64_t*)(dest + 8), zip1s4567_01);
vst1q_f64((float64_t*)(dest + 16), zip1s891011_01);
vst1q_f64((float64_t*)(dest + dstStride0), zip2s0123_01);
vst1q_f64((float64_t*)(dest + dstStride0 + 8), zip2s4567_01);
vst1q_f64((float64_t*)(dest + dstStride0 + 16), zip2s891011_01);
vst1q_f64((float64_t*)(dest + 2 * dstStride0), zip1s0123_23);
vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 8), zip1s4567_23);
vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 16), zip1s891011_23);
vst1q_f64((float64_t*)(dest + 3 * dstStride0), zip2s0123_23);
vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 8), zip2s4567_23);
vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 16), zip2s891011_23);
dest += 24;
e -= 12;
source += (12 * srcStride0);
}
if (e > 7) {
auto s0 = vld1q_f32((float*)(source)); // 00112233
auto s1 = vld1q_f32((float*)(source + srcStride0));// 00112233
auto s2 = vld1q_f32((float*)(source + 2 * srcStride0));
auto s3 = vld1q_f32((float*)(source + 3 * srcStride0));
auto s4 = vld1q_f32((float*)(source + 4 * srcStride0));
auto s5 = vld1q_f32((float*)(source + 5 * srcStride0));
auto s6 = vld1q_f32((float*)(source + 6 * srcStride0));
auto s7 = vld1q_f32((float*)(source + 7 * srcStride0));
auto zip1s01 = vzip1q_f32(s0, s1); // 00001111
auto zip1s23 = vzip1q_f32(s2, s3); // 00001111
auto zip1s45 = vzip1q_f32(s4, s5); // 00001111
auto zip1s67 = vzip1q_f32(s6, s7); // 00001111
auto zip2s01 = vzip2q_f32(s0, s1); // 22223333
auto zip2s23 = vzip2q_f32(s2, s3); // 22223333
auto zip2s45 = vzip2q_f32(s4, s5); // 22223333
auto zip2s67 = vzip2q_f32(s6, s7); // 22223333
auto zip1s0123_01 = vzip1q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 00000000
auto zip1s4567_01 = vzip1q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);
auto zip2s0123_01 = vzip2q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 11111111
auto zip2s4567_01 = vzip2q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);
auto zip1s0123_23 = vzip1q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 22222222
auto zip1s4567_23 = vzip1q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);
auto zip2s0123_23 = vzip2q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 33333333
auto zip2s4567_23 = vzip2q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);
vst1q_f64((float64_t*)dest, zip1s0123_01);
vst1q_f64((float64_t*)(dest + 8), zip1s4567_01);
vst1q_f64((float64_t*)(dest + dstStride0), zip2s0123_01);
vst1q_f64((float64_t*)(dest + dstStride0 + 8), zip2s4567_01);
vst1q_f64((float64_t*)(dest + 2 * dstStride0), zip1s0123_23);
vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 8), zip1s4567_23);
vst1q_f64((float64_t*)(dest + 3 * dstStride0), zip2s0123_23);
vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 8), zip2s4567_23);
dest += 16;
e -= 8;
source += (8 * srcStride0);
}
if (e > 3) {
auto s0 = vld1q_f32((float*)(source)); // 00112233
auto s1 = vld1q_f32((float*)(source + srcStride0));// 00112233
auto s2 = vld1q_f32((float*)(source + 2 * srcStride0));
auto s3 = vld1q_f32((float*)(source + 3 * srcStride0));
auto zip1s01 = vzip1q_f32(s0, s1); // 00001111
auto zip1s23 = vzip1q_f32(s2, s3); // 00001111
auto zip2s01 = vzip2q_f32(s0, s1); // 22223333
auto zip2s23 = vzip2q_f32(s2, s3); // 22223333
auto zip1s0123_01 = vzip1q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 00000000
auto zip2s0123_01 = vzip2q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 11111111
auto zip1s0123_23 = vzip1q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 22222222
auto zip2s0123_23 = vzip2q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 33333333
vst1q_f64((float64_t*)dest, zip1s0123_01);
vst1q_f64((float64_t*)(dest + dstStride0), zip2s0123_01);
vst1q_f64((float64_t*)(dest + 2 * dstStride0), zip1s0123_23);
vst1q_f64((float64_t*)(dest + 3 * dstStride0), zip2s0123_23);
dest += 8;
e -= 4;
source += (4 * srcStride0);
}
while (e > 0) {
auto s0 = vld1q_f32((float*)(source)); // 00112233
((float*)dest)[0] = s0[0];
((float*)(dest + dstStride0))[0] = s0[1];
((float*)(dest + 2 * dstStride0))[0] = s0[2];
((float*)(dest + 3 * dstStride0))[0] = s0[3];
dest += 2;
e -= 1;
source += srcStride0;
}
sourceN += (eReal * pack);
destN += (4 * dstStride0);
} // l>7
if (l > 3) {
auto source = sourceN;
auto dest = destN;
l -= 4;
auto e = eWork;
if (e == eDest) {
auto s0 = vld1_f32((float*)(source)); // 0011
auto s1 = vld1_f32((float*)(source + srcStride0));// 0011
auto s2 = vld1_f32((float*)(source + 2 * srcStride0));
auto s3 = vld1_f32((float*)(source + 3 * srcStride0));
auto s4 = vld1_f32((float*)(source + 4 * srcStride0));
auto s5 = vld1_f32((float*)(source + 5 * srcStride0));
auto s6 = vld1_f32((float*)(source + 6 * srcStride0));
auto s7 = vld1_f32((float*)(source + 7 * srcStride0));
auto s8 = vld1_f32((float*)(source + 8 * srcStride0));
auto s9 = vld1_f32((float*)(source + 9 * srcStride0));
auto s10 = vld1_f32((float*)(source + 10 * srcStride0));
auto s11 = vld1_f32((float*)(source + 11 * srcStride0));
auto s12 = vld1_f32((float*)(source + 12 * srcStride0));
auto s13 = vld1_f32((float*)(source + 13 * srcStride0));
auto s14 = vld1_f32((float*)(source + 14 * srcStride0));
auto s15 = vld1_f32((float*)(source + 15 * srcStride0));
auto zip1s01 = vzip1_f32(s0, s1); // 0000
auto zip1s23 = vzip1_f32(s2, s3); // 0000
auto zip1s45 = vzip1_f32(s4, s5); // 0000
auto zip1s67 = vzip1_f32(s6, s7); // 0000
auto zip1s89 = vzip1_f32(s8, s9); // 0000
auto zip1s1011 = vzip1_f32(s10, s11); // 0000
auto zip1s1213 = vzip1_f32(s12, s13); // 0000
auto zip1s1415 = vzip1_f32(s14, s15); // 0000
auto zip2s01 = vzip2_f32(s0, s1); // 1111
auto zip2s23 = vzip2_f32(s2, s3); // 1111
auto zip2s45 = vzip2_f32(s4, s5); // 1111
auto zip2s67 = vzip2_f32(s6, s7); // 1111
auto zip2s89 = vzip2_f32(s8, s9); // 1111
auto zip2s1011 = vzip2_f32(s10, s11); // 1111
auto zip2s1213 = vzip2_f32(s12, s13); // 1111
auto zip2s1415 = vzip2_f32(s14, s15); // 1111
vst1_f32((float32_t*)dest, zip1s01);
vst1_f32((float32_t*)(dest + 4), zip1s23);
vst1_f32((float32_t*)(dest + 8), zip1s45);
vst1_f32((float32_t*)(dest + 12), zip1s67);
vst1_f32((float32_t*)(dest + 16), zip1s89);
vst1_f32((float32_t*)(dest + 20), zip1s1011);
vst1_f32((float32_t*)(dest + 24), zip1s1213);
vst1_f32((float32_t*)(dest + 28), zip1s1415);
vst1_f32((float32_t*)(dest + dstStride0), zip2s01);
vst1_f32((float32_t*)(dest + dstStride0 + 4), zip2s23);
vst1_f32((float32_t*)(dest + dstStride0 + 8), zip2s45);
vst1_f32((float32_t*)(dest + dstStride0 + 12), zip2s67);
vst1_f32((float32_t*)(dest + dstStride0 + 16), zip2s89);
vst1_f32((float32_t*)(dest + dstStride0 + 20), zip2s1011);
vst1_f32((float32_t*)(dest + dstStride0 + 24), zip2s1213);
vst1_f32((float32_t*)(dest + dstStride0 + 28), zip2s1415);
dest += 32;
e -= eDest;
}
if (e > 11) {
auto s0 = vld1_f32((float*)(source)); // 0011
auto s1 = vld1_f32((float*)(source + srcStride0));// 0011
auto s2 = vld1_f32((float*)(source + 2 * srcStride0));
auto s3 = vld1_f32((float*)(source + 3 * srcStride0));
auto s4 = vld1_f32((float*)(source + 4 * srcStride0));
auto s5 = vld1_f32((float*)(source + 5 * srcStride0));
auto s6 = vld1_f32((float*)(source + 6 * srcStride0));
auto s7 = vld1_f32((float*)(source + 7 * srcStride0));
auto s8 = vld1_f32((float*)(source + 8 * srcStride0));
auto s9 = vld1_f32((float*)(source + 9 * srcStride0));
auto s10 = vld1_f32((float*)(source + 10 * srcStride0));
auto s11 = vld1_f32((float*)(source + 11 * srcStride0));
auto zip1s01 = vzip1_f32(s0, s1); // 0000
auto zip1s23 = vzip1_f32(s2, s3); // 0000
auto zip1s45 = vzip1_f32(s4, s5); // 0000
auto zip1s67 = vzip1_f32(s6, s7); // 0000
auto zip1s89 = vzip1_f32(s8, s9); // 0000
auto zip1s1011 = vzip1_f32(s10, s11); // 0000
auto zip2s01 = vzip2_f32(s0, s1); // 1111
auto zip2s23 = vzip2_f32(s2, s3); // 1111
auto zip2s45 = vzip2_f32(s4, s5); // 1111
auto zip2s67 = vzip2_f32(s6, s7); // 1111
auto zip2s89 = vzip2_f32(s8, s9); // 1111
auto zip2s1011 = vzip2_f32(s10, s11); // 1111
vst1_f32((float32_t*)dest, zip1s01);
vst1_f32((float32_t*)(dest + 4), zip1s23);
vst1_f32((float32_t*)(dest + 8), zip1s45);
vst1_f32((float32_t*)(dest + 12), zip1s67);
vst1_f32((float32_t*)(dest + 16), zip1s89);
vst1_f32((float32_t*)(dest + 20), zip1s1011);
vst1_f32((float32_t*)(dest + dstStride0), zip2s01);
vst1_f32((float32_t*)(dest + dstStride0 + 4), zip2s23);
vst1_f32((float32_t*)(dest + dstStride0 + 8), zip2s45);
vst1_f32((float32_t*)(dest + dstStride0 + 12), zip2s67);
vst1_f32((float32_t*)(dest + dstStride0 + 16), zip2s89);
vst1_f32((float32_t*)(dest + dstStride0 + 20), zip2s1011);
dest += 24;
e -= 12;
source += (12 * srcStride0);
}
if (e > 7) {
auto s0 = vld1_f32((float*)(source)); // 0011
auto s1 = vld1_f32((float*)(source + srcStride0));// 0011
auto s2 = vld1_f32((float*)(source + 2 * srcStride0));
auto s3 = vld1_f32((float*)(source + 3 * srcStride0));
auto s4 = vld1_f32((float*)(source + 4 * srcStride0));
auto s5 = vld1_f32((float*)(source + 5 * srcStride0));
auto s6 = vld1_f32((float*)(source + 6 * srcStride0));
auto s7 = vld1_f32((float*)(source + 7 * srcStride0));
auto zip1s01 = vzip1_f32(s0, s1); // 0000
auto zip1s23 = vzip1_f32(s2, s3); // 0000
auto zip1s45 = vzip1_f32(s4, s5); // 0000
auto zip1s67 = vzip1_f32(s6, s7); // 0000
auto zip2s01 = vzip2_f32(s0, s1); // 1111
auto zip2s23 = vzip2_f32(s2, s3); // 1111
auto zip2s45 = vzip2_f32(s4, s5); // 1111
auto zip2s67 = vzip2_f32(s6, s7); // 1111
vst1_f32((float32_t*)dest, zip1s01);
vst1_f32((float32_t*)(dest + 4), zip1s23);
vst1_f32((float32_t*)(dest + 8), zip1s45);
vst1_f32((float32_t*)(dest + 12), zip1s67);
vst1_f32((float32_t*)(dest + dstStride0), zip2s01);
vst1_f32((float32_t*)(dest + dstStride0 + 4), zip2s23);
vst1_f32((float32_t*)(dest + dstStride0 + 8), zip2s45);
vst1_f32((float32_t*)(dest + dstStride0 + 12), zip2s67);
dest += 16;
e -= 8;
source += (8 * srcStride0);
}
if (e > 3) {
auto s0 = vld1_f32((float*)(source)); // 0011
auto s1 = vld1_f32((float*)(source + srcStride0));// 0011
auto s2 = vld1_f32((float*)(source + 2 * srcStride0));
auto s3 = vld1_f32((float*)(source + 3 * srcStride0));
auto zip1s01 = vzip1_f32(s0, s1); // 0000
auto zip1s23 = vzip1_f32(s2, s3); // 0000
auto zip2s01 = vzip2_f32(s0, s1); // 1111
auto zip2s23 = vzip2_f32(s2, s3); // 1111
vst1_f32((float32_t*)dest, zip1s01);
vst1_f32((float32_t*)(dest + 4), zip1s23);
vst1_f32((float32_t*)(dest + dstStride0), zip2s01);
vst1_f32((float32_t*)(dest + dstStride0 + 4), zip2s23);
dest += 8;
e -= 4;
source += (4 * srcStride0);
}
if (e > 1) {
auto s0 = vld1_f32((float*)(source)); // 0011
auto s1 = vld1_f32((float*)(source + srcStride0));// 0011
auto zip1s01 = vzip1_f32(s0, s1); // 0000
auto zip2s01 = vzip2_f32(s0, s1); // 1111
vst1_f32((float32_t*)dest, zip1s01);
vst1_f32((float32_t*)(dest + dstStride0), zip2s01);
dest += 4;
e -= 2;
source += (2 * srcStride0);
}
if (e > 0) {
auto s0 = vld1_f32((float*)(source)); // 0011
((float*)dest)[0] = s0[0];
((float*)(dest + dstStride0))[0] = s0[1];
}
sourceN += 4;
destN += (2 * dstStride0);
}
auto source = (FLOAT16*)(sourceGroup[n]);
auto dest = (FLOAT16*)destOrigin + lOffset * eDest + eOffset * LP;
if (l > 0) {
auto e = eWork;
auto lRemain = lWork - l;
// if e < eDest, packed A -> [LU, eDest, LP] eDest=eP
for (int y=0; y<e; ++y) {
auto yR = y % eDest;
for (int x=lRemain; x<lWork; ++x) {
auto xR = x % pack;
auto xC = x / pack;
auto xOut = x / LP;
auto xIn = x % LP;
dest[xOut * eDest * LP + yR * LP + xIn] = source[xC * eReal * pack + y * pack * offset + xR];
}
}
l--;
}
}
}
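// Illustrative scalar reference for the vectorized packing above. It applies the index mapping
// of the tail loop at the end of the function to every (plane y, channel x) pair: destination
// layout [LU, eDest, LP] with eDest = eP = 16 and LP = 2, source in the pack = 8 blocked layout.
// The NEON branches implement the same mapping several columns at a time; this reference
// (name ours) only documents that mapping and is not a drop-in replacement.
static void Sme2PackForMatMul_A_FP16_Reference(float* destOrigin, float const** sourceGroup,
                                               const int32_t* info, const int32_t* el) {
    const int LP = 2, pack = 8, eDest = 16;
    int number = info[0];
    int eReal  = info[1];
    int offset = info[3];
    for (int n = 0; n < number; ++n) {
        int eWork   = el[4 * n + 0];
        int lWork   = el[4 * n + 1];
        int eOffset = el[4 * n + 2];
        int lOffset = el[4 * n + 3];
        auto src = (const FLOAT16*)(sourceGroup[n]);
        auto dst = (FLOAT16*)destOrigin + lOffset * eDest + eOffset * LP;
        for (int x = 0; x < lWork; ++x) {
            for (int y = 0; y < eWork; ++y) {
                dst[(x / LP) * eDest * LP + (y % eDest) * LP + (x % LP)] =
                    src[(x / pack) * eReal * pack + y * pack * offset + (x % pack)];
            }
        }
    }
}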
#ifdef MNN_LOW_MEMORY
void MNNAbsMaxFP16(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack) {
if (pack == 4) {
@ -1024,7 +1565,7 @@ static void MNNAsyQuantInfo_FP16(float* scale, float* bias, float* qscale, float
}
#ifdef __aarch64__
if (DST_XUNIT == 12) { // Arm82, fp16: core->pack=8, SRC_UNIT=4
if (DST_XUNIT == 12 || DST_XUNIT == 16) { // Arm82/SME2, fp16: core->pack=8, SRC_UNIT=4
// max,min shape: [blockNum, EP]
if (innerSide == 4) {
for (int i = 0; i < kernelsize; ++i) {
@ -1037,12 +1578,22 @@ static void MNNAsyQuantInfo_FP16(float* scale, float* bias, float* qscale, float
}
}
// scale, bias
auto success = MNNAsyLocalQuantInfo_EP12_FP16(scale, bias, qscale, qbias, dstMin, dstMax, info);
if (!success) {
MNN_ERROR("Call error: MNNAsyLocalQuantInfo_EP12_FP16\n");
if (DST_XUNIT == 12) {
auto success = MNNAsyLocalQuantInfo_EP12_FP16(scale, bias, qscale, qbias, dstMin, dstMax, info);
if (!success) {
MNN_ERROR("Call error: MNNAsyLocalQuantInfo_EP12_FP16\n");
return;
}
return;
}
if (DST_XUNIT == 16) {
auto success = MNNAsyLocalQuantInfo_EP16_FP16(scale, bias, qscale, qbias, dstMin, dstMax, info);
if (!success) {
MNN_ERROR("Call error: MNNAsyLocalQuantInfo_EP16_FP16\n");
return;
}
return;
}
return;
}
if (DST_XUNIT == 10 && innerSide == 8) { // Arm86, fp16: core->pack=8, SRC_UNIT=8
// max,min shape: [blockNum, plane]
@ -1120,12 +1671,16 @@ bool Arm82Functions::init() {
gInstance = new CoreFunctions;
gArm82CoreInt8Functions = new CoreInt8Functions;
*gArm82CoreInt8Functions = *MNNGetInt8CoreFunctions();
gInstance->backendMatmulRelatedFunctions = origin->backendMatmulRelatedFunctions;
gInstance->sme2MatmulRelatedFuncions = origin->sme2MatmulRelatedFuncions;
{
if (origin->supportSDot) {
gArm82CoreInt8Functions->MNNPackC4Int8ForMatMul_A = _Arm82MNNPackC4ForMatMul_A<12, 4>;
gInstance->supportSDot = true;
}
if (origin->supportI8mm) {
gArm82CoreInt8Functions->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A_L8<10, 8>;
gInstance->supportI8mm = true;
}
}
@ -1165,12 +1720,16 @@ bool Arm82Functions::init() {
FUNC_PTR_ASSIGN(gInstance->MNNAddC4WithStride, MNNAddC8WithStrideFP16);
// MatMul
FUNC_PTR_ASSIGN(gInstance->MNNGetMatMulPackMode, Arm82MNNGetMatMulPackMode);
FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul, MNNPackedMatMulFP16);
FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain, MNNPackedMatMulRemainFP16);
FUNC_PTR_ASSIGN(gInstance->MNNPackC4ForMatMul_A, Arm82MNNPackForMatMul_A);
FUNC_PTR_ASSIGN(gInstance->MNNPackForMatMul_B, Arm82MNNPackForMatMul_B);
#if defined(__aarch64__)
gInstance->supportFp16arith = origin->supportFp16arith;
gInstance->supportSDot = origin->supportSDot;
gInstance->supportI8mm = origin->supportI8mm;
gInstance->supportSME2 = origin->supportSME2;
#ifdef MNN_CPU_WEIGHT_DEQUANT_GEMM
// Weight Dequant Gemm Kernels
FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul_int8, MNNPackedMatMulFP16_int8);
@ -1184,22 +1743,18 @@ bool Arm82Functions::init() {
FUNC_PTR_ASSIGN(gInstance->MNNAsyQuantFunc, MNNAsyQuantFunc_Arm82);
FUNC_PTR_ASSIGN(gInstance->MNNAsyQuantInfo, MNNAsyQuantInfo_FP16); // return 'plane' min&max
FUNC_PTR_ASSIGN(gInstance->MNNDynamicUpdateConvBiasScale, origin->MNNDynamicUpdateConvBiasScale);
#ifdef __aarch64__
if (origin->supportSDot) {
FUNC_PTR_ASSIGN(gInstance->MNNGeneralIm2Col, MNNGeneralIm2col_Arm82);
}
if (origin->supportI8mm) {
FUNC_PTR_ASSIGN(gInstance->MNNGeneralIm2Col, MNNGeneralIm2col_Arm86);
}
#endif
#endif
#endif // MNN_LOW_MEMORY
FUNC_PTR_ASSIGN(gInstance->MNNCountMaxMinValue, ARM82CountMinMaxValue); // return one min&max
FUNC_PTR_ASSIGN(gInstance->MNNSumByAxisLForMatmul_A, origin->MNNSumByAxisLForMatmul_A);
FUNC_PTR_ASSIGN(gInstance->MNNDepthwiseConvFastKernel, MNNDepthwiseConvFastKernelFP16);
#endif
FUNC_PTR_ASSIGN(gInstance->MNNPackC4ForMatMul_A, Arm82MNNPackForMatMul_A);
FUNC_PTR_ASSIGN(gInstance->MNNGetMatMulPackMode, Arm82MNNGetMatMulPackMode);
FUNC_PTR_ASSIGN(gInstance->MNNPackForMatMul_B, Arm82MNNPackForMatMul_B);
gInstance->MNNComputeMatMulForH_1 = _MNNComputeMatMulForH_1_FP16;
gInstance->MNNComputeMatMulForE_1 = _MNNComputeMatMulForE_1_FP16;
@ -1219,6 +1774,43 @@ bool Arm82Functions::init() {
gInstance->MNNPoolingMax = (decltype(gInstance->MNNPoolingMax))(poolingMax<float16_t, Vec, 8, -65535>);
gInstance->MNNPoolingAvg = (decltype(gInstance->MNNPoolingAvg))(poolingAvg<float16_t, Vec, 8>);
{
gInstance->backendMatmulRelatedFunctions.MNNPackC4Int8ForMatMul_A = gArm82CoreInt8Functions->MNNPackC4Int8ForMatMul_A;
gInstance->backendMatmulRelatedFunctions.MNNGeneralIm2Col = gInstance->MNNGeneralIm2Col;
gInstance->backendMatmulRelatedFunctions.MNNGetMatMulPackMode = gInstance->MNNGetMatMulPackMode;
gInstance->backendMatmulRelatedFunctions.MNNPackedMatMul = gInstance->MNNPackedMatMul;
gInstance->backendMatmulRelatedFunctions.MNNPackedMatMulRemain = gInstance->MNNPackedMatMulRemain;
gInstance->backendMatmulRelatedFunctions.MNNPackC4ForMatMul_A = gInstance->MNNPackC4ForMatMul_A;
gInstance->backendMatmulRelatedFunctions.MNNPackForMatMul_B = gInstance->MNNPackForMatMul_B;
}
#ifdef __aarch64__
#ifdef MNN_SME2
if (origin->supportSME2) {
gArm82CoreInt8Functions->MNNPackC4Int8ForMatMul_A = _Arm82MNNPackC4ForMatMul_A<16, 4>;
gInstance->sme2MatmulRelatedFuncions.MNNPackC4Int8ForMatMul_A = _Arm82MNNPackC4ForMatMul_A<16, 4>;
FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul, MNNPackedMatMulFP16_SME2);
FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain, MNNPackedMatMulRemainFP16_SME2);
FUNC_PTR_ASSIGN(gInstance->MNNGetMatMulPackMode, Sme2MNNGetMatMulPackMode);
FUNC_PTR_ASSIGN(gInstance->MNNPackC4ForMatMul_A, Sme2MNNPackForMatMul_A_FP16);
FUNC_PTR_ASSIGN(gInstance->MNNPackForMatMul_B, Sme2MNNPackForMatMul_B);
gInstance->sme2MatmulRelatedFuncions.MNNPackedMatMul = MNNPackedMatMulFP16_SME2;
gInstance->sme2MatmulRelatedFuncions.MNNPackedMatMulRemain = MNNPackedMatMulRemainFP16_SME2;
gInstance->sme2MatmulRelatedFuncions.MNNGetMatMulPackMode = Sme2MNNGetMatMulPackMode;
gInstance->sme2MatmulRelatedFuncions.MNNPackC4ForMatMul_A = Sme2MNNPackForMatMul_A_FP16;
gInstance->sme2MatmulRelatedFuncions.MNNPackForMatMul_B = Sme2MNNPackForMatMul_B;
#ifdef MNN_LOW_MEMORY
FUNC_PTR_ASSIGN(gInstance->MNNGeneralIm2Col, MNNGeneralIm2col_Fp16Sme2);
gInstance->sme2MatmulRelatedFuncions.MNNGeneralIm2Col = MNNGeneralIm2col_Fp16Sme2;
#endif
}
#endif // MNN_SME2
#endif // __aarch64__
return true;
}

View File

@ -249,6 +249,268 @@ void MNNNC8HW8TONHWC(float* dest, const FLOAT16* src, size_t plane, size_t chann
#ifdef __aarch64__
#ifdef MNN_LOW_MEMORY
bool MNNAsyLocalQuantInfo_EP16_FP16(float* scale, float* bias, float* qscale, float* qbias, const float* srcMin, const float* srcMax, const size_t* info) {
// dequant scale/bias : [EU, blockNum, step]
// quant scale/bias: [blockNum, EP]
auto blockNum = info[0];
auto EP = info[1];
auto DST_XUNIT = info[3];
if (DST_XUNIT != 16) {
MNN_ERROR("Call error: MNNAsyLocalQuantInfo_EP16_FP16\n");
return false;
}
auto stride = EP * blockNum;
auto minfloat = vdupq_n_f32(1e-6);
auto _255f = vdupq_n_f32(255.f);
auto _128f = vdupq_n_f32(128.f);
auto _0f = vdupq_n_f32(0.f);
auto minPtr = (FLOAT16*)srcMin;
auto maxPtr = (FLOAT16*)srcMax;
for (int k = 0; k < blockNum; ++k) {
auto qind = k * EP;
auto realDstCount = EP;
auto scalePtr = scale + k * ALIMIN(EP, DST_XUNIT);
auto biasPtr = bias + k * ALIMIN(EP, DST_XUNIT);
while (realDstCount > DST_XUNIT - 1) {
auto max0_fp16 = vld1_f16(maxPtr + qind);
auto max1_fp16 = vld1_f16(maxPtr + qind + 4);
auto max2_fp16 = vld1_f16(maxPtr + qind + 8);
auto max3_fp16 = vld1_f16(maxPtr + qind + 12);
auto min0_fp16 = vld1_f16(minPtr + qind);
auto min1_fp16 = vld1_f16(minPtr + qind + 4);
auto min2_fp16 = vld1_f16(minPtr + qind + 8);
auto min3_fp16 = vld1_f16(minPtr + qind + 12);
// float16 -> float32
auto max0 = vcvt_f32_f16(max0_fp16);
auto max1 = vcvt_f32_f16(max1_fp16);
auto max2 = vcvt_f32_f16(max2_fp16);
auto max3 = vcvt_f32_f16(max3_fp16);
auto min0 = vcvt_f32_f16(min0_fp16);
auto min1 = vcvt_f32_f16(min1_fp16);
auto min2 = vcvt_f32_f16(min2_fp16);
auto min3 = vcvt_f32_f16(min3_fp16);
// diff
auto diff0 = vsubq_f32(max0, min0);
auto diff1 = vsubq_f32(max1, min1);
auto diff2 = vsubq_f32(max2, min2);
auto diff3 = vsubq_f32(max3, min3);
auto qscaleV0 = vdivq_f32(_255f, diff0);
auto qscaleV1 = vdivq_f32(_255f, diff1);
auto qscaleV2 = vdivq_f32(_255f, diff2);
auto qscaleV3 = vdivq_f32(_255f, diff3);
auto scaleV0 = vdivq_f32(diff0, _255f);
auto scaleV1 = vdivq_f32(diff1, _255f);
auto scaleV2 = vdivq_f32(diff2, _255f);
auto scaleV3 = vdivq_f32(diff3, _255f);
auto qbiasV0 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min0), diff0), _128f));
auto qbiasV1 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min1), diff1), _128f));
auto qbiasV2 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min2), diff2), _128f));
auto qbiasV3 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min3), diff3), _128f));
auto biasV0 = vaddq_f32(vdivq_f32(vmulq_f32(diff0, _128f), _255f), min0);
auto biasV1 = vaddq_f32(vdivq_f32(vmulq_f32(diff1, _128f), _255f), min1);
auto biasV2 = vaddq_f32(vdivq_f32(vmulq_f32(diff2, _128f), _255f), min2);
auto biasV3 = vaddq_f32(vdivq_f32(vmulq_f32(diff3, _128f), _255f), min3);
auto _0bic = vclezq_f32(diff0);
auto _1bic = vclezq_f32(diff1);
auto _2bic = vclezq_f32(diff2);
auto _3bic = vclezq_f32(diff3);
qscaleV0 = vbslq_f32(_0bic, _0f, qscaleV0);
qscaleV1 = vbslq_f32(_1bic, _0f, qscaleV1);
qscaleV2 = vbslq_f32(_2bic, _0f, qscaleV2);
qscaleV3 = vbslq_f32(_3bic, _0f, qscaleV3);
qbiasV0 = vrndaq_f32(vbslq_f32(_0bic, _0f, qbiasV0));
qbiasV1 = vrndaq_f32(vbslq_f32(_1bic, _0f, qbiasV1));
qbiasV2 = vrndaq_f32(vbslq_f32(_2bic, _0f, qbiasV2));
qbiasV3 = vrndaq_f32(vbslq_f32(_3bic, _0f, qbiasV3));
scaleV0 = vbslq_f32(_0bic, _0f, scaleV0);
scaleV1 = vbslq_f32(_1bic, _0f, scaleV1);
scaleV2 = vbslq_f32(_2bic, _0f, scaleV2);
scaleV3 = vbslq_f32(_3bic, _0f, scaleV3);
biasV0 = vbslq_f32(_0bic, max0, biasV0);
biasV1 = vbslq_f32(_1bic, max1, biasV1);
biasV2 = vbslq_f32(_2bic, max2, biasV2);
biasV3 = vbslq_f32(_3bic, max3, biasV3);
vst1q_f32(qscale + qind, qscaleV0);
vst1q_f32(qscale + qind + 4, qscaleV1);
vst1q_f32(qscale + qind + 8, qscaleV2);
vst1q_f32(qscale + qind + 12, qscaleV3);
vst1q_f32(qbias + qind, qbiasV0);
vst1q_f32(qbias + qind + 4, qbiasV1);
vst1q_f32(qbias + qind + 8, qbiasV2);
vst1q_f32(qbias + qind + 12, qbiasV3);
vst1q_f32(scalePtr, scaleV0);
vst1q_f32(scalePtr + 4, scaleV1);
vst1q_f32(scalePtr + 8, scaleV2);
vst1q_f32(scalePtr + 12, scaleV3);
vst1q_f32(biasPtr, biasV0);
vst1q_f32(biasPtr + 4, biasV1);
vst1q_f32(biasPtr + 8, biasV2);
vst1q_f32(biasPtr + 12, biasV3);
realDstCount -= DST_XUNIT;
qind += DST_XUNIT;
scalePtr += (blockNum * DST_XUNIT);
biasPtr += (blockNum * DST_XUNIT);
}
if (realDstCount == 0) {
continue;
}
auto remainE = realDstCount;
auto stride0 = remainE * blockNum;
scalePtr = scale + (EP / DST_XUNIT) * blockNum * DST_XUNIT + k * remainE;
biasPtr = bias + (EP / DST_XUNIT) * blockNum * DST_XUNIT + k * remainE;
if (realDstCount > 7) {
auto max0_fp16 = vld1_f16(maxPtr + qind);
auto max1_fp16 = vld1_f16(maxPtr + qind + 4);
auto min0_fp16 = vld1_f16(minPtr + qind);
auto min1_fp16 = vld1_f16(minPtr + qind + 4);
// float16 -> float32
auto max0 = vcvt_f32_f16(max0_fp16);
auto max1 = vcvt_f32_f16(max1_fp16);
auto min0 = vcvt_f32_f16(min0_fp16);
auto min1 = vcvt_f32_f16(min1_fp16);
// diff
auto diff0 = vsubq_f32(max0, min0);
auto diff1 = vsubq_f32(max1, min1);
auto qscaleV0 = vdivq_f32(_255f, diff0);
auto qscaleV1 = vdivq_f32(_255f, diff1);
auto scaleV0 = vdivq_f32(diff0, _255f);
auto scaleV1 = vdivq_f32(diff1, _255f);
auto qbiasV0 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min0), diff0), _128f));
auto qbiasV1 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min1), diff1), _128f));
auto biasV0 = vaddq_f32(vdivq_f32(vmulq_f32(diff0, _128f), _255f), min0);
auto biasV1 = vaddq_f32(vdivq_f32(vmulq_f32(diff1, _128f), _255f), min1);
auto _0bic = vclezq_f32(diff0);
auto _1bic = vclezq_f32(diff1);
qscaleV0 = vbslq_f32(_0bic, _0f, qscaleV0);
qscaleV1 = vbslq_f32(_1bic, _0f, qscaleV1);
qbiasV0 = vrndaq_f32(vbslq_f32(_0bic, _0f, qbiasV0));
qbiasV1 = vrndaq_f32(vbslq_f32(_1bic, _0f, qbiasV1));
scaleV0 = vbslq_f32(_0bic, _0f, scaleV0);
scaleV1 = vbslq_f32(_1bic, _0f, scaleV1);
biasV0 = vbslq_f32(_0bic, max0, biasV0);
biasV1 = vbslq_f32(_1bic, max1, biasV1);
vst1q_f32(qscale + qind, qscaleV0);
vst1q_f32(qscale + qind + 4, qscaleV1);
vst1q_f32(qbias + qind, qbiasV0);
vst1q_f32(qbias + qind + 4, qbiasV1);
vst1q_f32(scalePtr, scaleV0);
vst1q_f32(scalePtr + 4, scaleV1);
vst1q_f32(biasPtr, biasV0);
vst1q_f32(biasPtr + 4, biasV1);
realDstCount -= 8;
qind += 8;
scalePtr += 8;
biasPtr += 8;
}
if (realDstCount > 3) {
auto max0_fp16 = vld1_f16(maxPtr + qind);
auto min0_fp16 = vld1_f16(minPtr + qind);
// float16 -> float32
auto max0 = vcvt_f32_f16(max0_fp16);
auto min0 = vcvt_f32_f16(min0_fp16);
// diff
auto diff0 = vsubq_f32(max0, min0);
auto qscaleV0 = vdivq_f32(_255f, diff0);
auto scaleV0 = vdivq_f32(diff0, _255f);
auto qbiasV0 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min0), diff0), _128f));
auto biasV0 = vaddq_f32(vdivq_f32(vmulq_f32(diff0, _128f), _255f), min0);
auto _0bic = vclezq_f32(diff0);
qscaleV0 = vbslq_f32(_0bic, _0f, qscaleV0);
qbiasV0 = vrndaq_f32(vbslq_f32(_0bic, _0f, qbiasV0));
scaleV0 = vbslq_f32(_0bic, _0f, scaleV0);
biasV0 = vbslq_f32(_0bic, max0, biasV0);
vst1q_f32(qscale + qind, qscaleV0);
vst1q_f32(qbias + qind, qbiasV0);
vst1q_f32(scalePtr, scaleV0);
vst1q_f32(biasPtr, biasV0);
realDstCount -= 4;
qind += 4;
scalePtr += 4;
biasPtr += 4;
}
while (realDstCount > 0) {
auto max0_fp16 = vld1_dup_f16(maxPtr + qind);
auto min0_fp16 = vld1_dup_f16(minPtr + qind);
// float16->float32
auto max0 = vcvt_f32_f16(max0_fp16);
auto min0 = vcvt_f32_f16(min0_fp16);
auto diff0 = vsubq_f32(max0, min0);
auto qscaleV0 = vdivq_f32(_255f, diff0);
auto scaleV0 = vdivq_f32(diff0, _255f);
auto qbiasV0 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min0), diff0), _128f));
auto biasV0 = vaddq_f32(vdivq_f32(vmulq_f32(diff0, _128f), _255f), min0);
auto _0bic = vclezq_f32(diff0);
qscaleV0 = vbslq_f32(_0bic, _0f, qscaleV0);
qbiasV0 = vrndaq_f32(vbslq_f32(_0bic, _0f, qbiasV0));
scaleV0 = vbslq_f32(_0bic, _0f, scaleV0);
biasV0 = vbslq_f32(_0bic, max0, biasV0);
vst1q_lane_f32(qscale + qind, qscaleV0, 0);
vst1q_lane_f32(qbias + qind, qbiasV0, 0);
vst1q_lane_f32(scalePtr, scaleV0, 0);
vst1q_lane_f32(biasPtr, biasV0, 0);
realDstCount -= 1;
qind += 1;
scalePtr += 1;
biasPtr += 1;
}
}
return true;
}
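// Illustrative scalar form of the per-element math vectorized above: asymmetric 8-bit
// quantization of the range [min, max] onto [-128, 127]. roundf (from <math.h>) stands in
// for vrndaq_f32's round-to-nearest-ties-away behaviour, and the degenerate max == min case
// falls back to scale 0 / bias = max exactly as the vclezq/vbslq selects do. Helper name ours.
static inline void asyLocalQuantScalar(float minV, float maxV,
                                       float* scale, float* bias,
                                       float* qscale, float* qbias) {
    float diff = maxV - minV;
    if (diff <= 0.f) {
        *qscale = 0.f;
        *qbias  = 0.f;
        *scale  = 0.f;
        *bias   = maxV;
        return;
    }
    *qscale = 255.f / diff;
    *qbias  = roundf(-(255.f * minV / diff) - 128.f);
    *scale  = diff / 255.f;
    *bias   = 128.f * diff / 255.f + minV;
}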
bool MNNAsyLocalQuantInfo_EP12_FP16(float* scale, float* bias, float* qscale, float* qbias, const float* srcMin, const float* srcMax, const size_t* info) {
// dequant scale/bias : [EU, blockNum, step]
// quant scale/bias: [blockNum, EP]

View File

@ -36,6 +36,7 @@ void MNNSlowCopy(T* dst, const U* src, size_t size) {
#ifdef __aarch64__
bool MNNAsyLocalQuantInfo_EP12_FP16(float* scale, float* bias, float* qscale, float* qbias, const float* srcMin, const float* srcMax, const size_t* info);
bool MNNAsyLocalQuantInfo_EP10_FP16(float* scale, float* bias, float* qscale, float* qbias, const float* srcMin, const float* srcMax, const size_t* info);
bool MNNAsyLocalQuantInfo_EP16_FP16(float* scale, float* bias, float* qscale, float* qbias, const float* srcMin, const float* srcMax, const size_t* info);
#endif
#endif

View File

@ -13,7 +13,15 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR ARCHS STREQUAL "arm64" OR AR
if (MNN_CPU_WEIGHT_DEQUANT_GEMM)
file(GLOB MNN_ARM82_SRCS_ASM ${MNN_ARM82_SRCS_ASM} ${CMAKE_CURRENT_LIST_DIR}/asm/arm64/normal_memory/*)
endif()
add_library(MNN_Arm82 OBJECT ${MNN_ARM82_SRCS} ${MNN_ARM82_SRCS_ASM})
if (MNN_SME2)
file(GLOB MNN_SME2_SRCS_ASM_FP16 ${MNN_SME2_SRCS_ASM_FP16} ${CMAKE_CURRENT_LIST_DIR}/asm/arm64/sme2_asm/*)
#set_source_files_properties(${MNN_SME2_SRCS_ASM_FP16} PROPERTIES COMPILE_OPTIONS "-fno-tree-vectorize;-march=armv8.6-a+sve+sve2+sme+sme2+fp16")
set_source_files_properties(${MNN_SME2_SRCS_ASM_FP16} PROPERTIES COMPILE_OPTIONS "-fno-tree-vectorize;-march=armv8.2-a+fp16")
endif()
set_source_files_properties(${MNN_ARM82_SRCS} PROPERTIES COMPILE_OPTIONS "-march=armv8.2-a+fp16")
set_source_files_properties(${MNN_ARM82_SRCS_ASM} PROPERTIES COMPILE_OPTIONS "-march=armv8.2-a+fp16")
add_library(MNN_Arm82 OBJECT ${MNN_ARM82_SRCS} ${MNN_ARM82_SRCS_ASM} ${MNN_SME2_SRCS_ASM_FP16})
if (MNN_LOW_MEMORY)
target_compile_options(MNN_Arm82 PRIVATE -DMNN_LOW_MEMORY)
endif()

File diff suppressed because it is too large

View File

@ -56,13 +56,19 @@ void CPUAttention::pack_query(Tensor* query, char* pack_q, char* sum_q, int seq_
}
}
else {
// target: [seq_len/eP, mHeadDim/lP, eP, lP]
T * query_src = query->host<T>();
T * query_dst = reinterpret_cast<T*>(pack_q);
auto stride0 = ROUND_UP(mHeadDim, lP) * eP;
auto stride1 = eP * lP;
if (mHeadDim % lP) {
memset(query_dst, 0, ROUND_UP(mHeadDim, lP) * bytes * ROUND_UP(seq_len, eP));
}
for (int i = 0; i < seq_len; i++) {
int out_index = i / eP;
int in_index = i % eP;
for (int j = 0; j < mHeadDim; j++) {
query_dst[out_index * mHeadDim * eP + j * eP + in_index] = query_src[i * mNumHead * mHeadDim + h * mHeadDim + j] * q_scale;
query_dst[out_index * stride0 + (j / lP) * stride1 + in_index * lP + (j % lP)] = query_src[i * mNumHead * mHeadDim + h * mHeadDim + j] * q_scale;
}
}
}
@ -83,15 +89,20 @@ void CPUAttention::unpack_QK(float * unpack_qk_dst, char * pack_qk_src, int seq_
}
template <typename T>
static void pack_QK(char * pack_qk_dst, float * qk_src, int seq_len, int kv_seq_len, int eP) {
static void pack_QK(char * pack_qk_dst, float * qk_src, int seq_len, int kv_seq_len, int eP, int lP, int bytes) {
T * dst = reinterpret_cast<T*>(pack_qk_dst);
float * src = reinterpret_cast<float*>(qk_src);
// [seq_len, kv_seq_len] -> [seq_len/eP, kv_seq_len, eP]
// [seq_len, kv_seq_len] -> [seq_len/eP, kv_seq_len/lP, eP, lP]
auto stride0 = ROUND_UP(kv_seq_len, lP) * eP;
auto stride1 = eP * lP;
if (kv_seq_len % lP) {
memset(dst, 0, ROUND_UP(kv_seq_len, lP) * ROUND_UP(seq_len, eP) * bytes);
}
for (int i = 0; i < seq_len; i++) {
int out_index = i / eP;
int in_index = i % eP;
for (int j = 0; j < kv_seq_len; j++) {
dst[out_index * kv_seq_len * eP + j * eP + in_index] = src[i * kv_seq_len + j];
dst[out_index * stride0 + (j / lP) * stride1 + in_index * lP + (j % lP)] = src[i * kv_seq_len + j];
}
}
}
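// Illustrative index helper for the layout produced above:
// [seq_len, kv_seq_len] -> [seq_len/eP][kv_seq_len/lP][eP][lP]; the memset covers the padded
// tail columns so the packed matmul can read full lP groups. Helper name ours.
static inline size_t packQKIndex(int i, int j, int kv_seq_len, int eP, int lP) {
    size_t stride0 = (size_t)((kv_seq_len + lP - 1) / lP) * lP * eP; // one eP tile of rows
    size_t stride1 = (size_t)eP * lP;                                // one lP column group
    return (i / eP) * stride0 + (j / lP) * stride1 + (i % eP) * lP + (j % lP);
}
// e.g. eP = 16, lP = 2, kv_seq_len = 5: element (i = 17, j = 3) lands at
// 1 * 96 + 1 * 32 + 1 * 2 + 1 = 131.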
@ -181,7 +192,7 @@ ErrorCode CPUAttention::onResize(const std::vector<Tensor*>& inputs, const std::
mQueryScale.resize(mNumHead);
mQueryZeroPoint.resize(mNumHead);
} else {
mPackQ.reset(Tensor::createDevice<float>({mThreadNum, UP_DIV(seq_len, eP), mHeadDim, eP}));
mPackQ.reset(Tensor::createDevice<float>({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(mHeadDim, lP), eP}));
mPackQKV.reset(Tensor::createDevice<float>({mThreadNum, UP_DIV(mHeadDim, unit), seq_len, unit}));
backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC);
backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC);
@ -242,7 +253,7 @@ ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std:
std::shared_ptr<Tensor> packQK(Tensor::createDevice<float>({mThreadNum, UP_DIV(kv_seq_len, unit), seq_len, unit}));
std::shared_ptr<Tensor> unpackQK(Tensor::createDevice<int32_t>({mThreadNum, seq_len, kv_seq_len}));
std::shared_ptr<Tensor> softmMaxQ(Tensor::createDevice<int32_t>({mThreadNum, seq_len, kv_seq_len}));
std::shared_ptr<Tensor> newPackQK(Tensor::createDevice<float>({mThreadNum, UP_DIV(seq_len, eP), kv_seq_len, eP}));
std::shared_ptr<Tensor> newPackQK(Tensor::createDevice<float>({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(kv_seq_len, lP), eP}));
std::shared_ptr<Tensor> dequantV(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), kv_seq_len, hP}));
backend()->onAcquireBuffer(packQK.get(), Backend::STATIC);
backend()->onAcquireBuffer(unpackQK.get(), Backend::STATIC);
@ -254,12 +265,12 @@ ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std:
}
std::function<void(int)> mCompute = [=](int tId) {
auto pack_q = mPackQ->host<char>() + tId * UP_DIV(seq_len, eP) * mHeadDim * eP * bytes;
auto pack_q = mPackQ->host<char>() + tId * UP_DIV(seq_len, eP) * ROUND_UP(mHeadDim, lP) * eP * bytes;
auto pack_qk = packQK->host<char>() + tId * UP_DIV(kv_seq_len, unit) * seq_len * unit * bytes;
char * sum_q = nullptr;
auto unpack_qk = unpackQK->host<float>() + tId * seq_len * kv_seq_len;
auto softmax_qk = softmMaxQ->host<float>() + tId * seq_len * kv_seq_len;
auto new_pack_qk = newPackQK->host<char>() + tId * UP_DIV(seq_len, eP) * kv_seq_len * eP * bytes;
auto new_pack_qk = newPackQK->host<char>() + tId * UP_DIV(seq_len, eP) * ROUND_UP(kv_seq_len, lP) * eP * bytes;
auto pack_qkv = mPackQKV->host<char>() + tId * UP_DIV(mHeadDim, unit) * seq_len * unit * bytes;
auto QxK = quant_key ? core->MNNPackedMatMul_int8 : core->MNNPackedMatMul;
auto QxK_remain = quant_key ? core->MNNPackedMatMulRemain_int8 : core->MNNPackedMatMulRemain;
@ -274,7 +285,7 @@ ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std:
char * scale_addr = mKVCacheManager->addrOfScale(kv_h);
char * zero_point_addr = mKVCacheManager->addrOfZeroPoint(kv_h);
char * key_sum_addr = mKVCacheManager->addrOfKeySum(kv_h);
char * value_addr = quant_value ? (dequantV->host<char>() + kv_h * UP_DIV(mHeadDim, hP) * kv_seq_len * hP * bytes) : mKVCacheManager->addrOfValue(kv_h);
char * value_addr = quant_value ? (dequantV->host<char>() + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(kv_seq_len, lP) * hP * bytes) : mKVCacheManager->addrOfValue(kv_h);
if (bytes == 2) {
pack_query<FLOAT16_T>(query, pack_q, sum_q, seq_len, h, q_scale);
} else {
@ -327,33 +338,36 @@ ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std:
else {
int loop_e = seq_len / eP;
int remain = seq_len % eP;
size_t shapeParameters[7] = {(size_t)eP * bytes, (size_t)mHeadDim, (size_t)kv_seq_len, (size_t)seq_len * unit * bytes, 0, 0, 0};
auto qStride0 = ROUND_UP(mHeadDim, lP) * eP * bytes;
size_t shapeParameters[7] = {(size_t)eP * bytes, ROUND_UP((size_t)mHeadDim, lP), (size_t)kv_seq_len, (size_t)seq_len * unit * bytes, 0, 0, 0};
for (int i = 0 ; i < loop_e; i++) {
QxK((float*)(pack_qk + (i * eP * unit) * bytes), (float*)(pack_q + (i * mHeadDim * eP) * bytes), (float*)key_addr, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr);
QxK((float*)(pack_qk + (i * eP * unit) * bytes), (float*)(pack_q + i * qStride0), (float*)key_addr, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr);
}
QxK_remain((float*)(pack_qk + (loop_e * eP * unit) * bytes), (float*)(pack_q + (loop_e * mHeadDim * eP) * bytes), (float*)key_addr, remain, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr);
QxK_remain((float*)(pack_qk + (loop_e * eP * unit) * bytes), (float*)(pack_q + loop_e * qStride0), (float*)key_addr, remain, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr);
}
// qk: [kv_seq_len/unit, seq_len, unit] -> [seq_len, kv_seq_len] -> [seq_len/eP, kv_seq_len, eP]
if(bytes == 2) {
unpack_QK<FLOAT16_T>(unpack_qk, pack_qk, seq_len, kv_seq_len);
mask_QK<FLOAT16_T>(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask);
softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len);
pack_QK<FLOAT16_T>(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP);
pack_QK<FLOAT16_T>(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP, lP, bytes);
} else {
unpack_QK<float>(unpack_qk, pack_qk, seq_len, kv_seq_len);
mask_QK<float>(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask);
softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len);
pack_QK<float>(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP);
pack_QK<float>(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP, lP, bytes);
}
// qk @ v
size_t shapeParameters[7] = {(size_t)eP * bytes, (size_t)kv_seq_len, (size_t)mHeadDim, (size_t)seq_len * unit * bytes, 0, 0, 0};
shapeParameters[5] = quant_value ? 0 : (max_len - kv_seq_len) * hP * bytes;
size_t shapeParameters[7] = {(size_t)eP * bytes, ROUND_UP((size_t)kv_seq_len, lP), (size_t)mHeadDim, (size_t)seq_len * unit * bytes, 0, 0, 0};
size_t bExtraStride = (UP_DIV(max_len, lP) - UP_DIV(kv_seq_len, lP)) * hP * lP * bytes;
shapeParameters[5] = quant_value ? 0 : bExtraStride;
int loop_e = seq_len / eP;
int remain = seq_len % eP;
auto qkStride0 = ROUND_UP(kv_seq_len, lP) * eP * bytes;
for (int i = 0 ; i < loop_e; i++) {
core->MNNPackedMatMul((float*)(pack_qkv + (i * eP * unit) * bytes), (float*)(new_pack_qk + (i * kv_seq_len * eP) * bytes), (float*)value_addr, shapeParameters, nullptr, nullptr, nullptr, nullptr);
core->MNNPackedMatMul((float*)(pack_qkv + (i * eP * unit) * bytes), (float*)(new_pack_qk + i * qkStride0), (float*)value_addr, shapeParameters, nullptr, nullptr, nullptr, nullptr);
}
core->MNNPackedMatMulRemain((float*)(pack_qkv + (loop_e * eP * unit) * bytes), (float*)(new_pack_qk + (loop_e * kv_seq_len * eP) * bytes), (float*)value_addr, remain, shapeParameters, nullptr, nullptr, nullptr, nullptr);
core->MNNPackedMatMulRemain((float*)(pack_qkv + (loop_e * eP * unit) * bytes), (float*)(new_pack_qk + loop_e * qkStride0), (float*)value_addr, remain, shapeParameters, nullptr, nullptr, nullptr, nullptr);
// unpack: [head_dim/unit, seq_len, unit] -> [seq_len, num_head, head_dim]
auto dst_ptr = outputs[0]->host<char>() + h * mHeadDim * bytes;
if (bytes == 2) {
@ -393,7 +407,7 @@ bool CPUAttention::onClone(Backend* bn, const Op* op, Execution** dst) {
}
CPUAttention::CPUAttention(Backend *backend, bool kv_cache) : Execution(backend), mKVCache(kv_cache) {
mMeta = (KVMeta*)(backend->getRuntime()->pMeta);
mMeta = (KVMeta*)(backend->getMetaPtr());
mPackQ.reset(Tensor::createDevice<float>({1, 1, 1, 1}));
mPackQKV.reset(Tensor::createDevice<float>({1, 1, 1, 1}));
MNN::KVCacheManager::KVCacheConfig kvconfig;

View File

@ -102,7 +102,7 @@ void CPURuntime::_bindCPUCore() const {
}
#endif
#ifdef MNN_USE_THREAD_POOL
if(mThreadPool) {
if (nullptr != mThreadPool) {
mThreadPool->active();
mThreadPool->enqueue(std::make_pair([&](int i) {
MNNSetSchedAffinity(lockCPUIndexes[i].first, lockCPUIndexes[i].second);
@ -113,24 +113,12 @@ void CPURuntime::_bindCPUCore() const {
#endif
}
void CPURuntime::_resetThreadPool() const{
void CPURuntime::_resetThreadPool() const {
mThreadNumber = std::max(1, mThreadNumber);
mThreadNumber = std::min(mThreadNumber, MAX_THREAD_NUMBER);
#ifdef MNN_USE_THREAD_POOL
if (mThreadPool) {
mThreadPool->releaseWorkIndex(mTaskIndex);
}
if (mThreadNumber > 1) {
mThreadNumber = ALIMIN(ThreadPool::init(mThreadNumber, mCpuMask, mThreadPool), mThreadNumber);
if (mThreadPool) {
mTaskIndex = mThreadPool->acquireWorkIndex();
}
if (-1 == mTaskIndex) {
MNN_ERROR("The ThreadPool has been used to MNN_THREAD_POOL_MAX_TASKS, can't use thread pool\n");
mThreadNumber = 1;
}
} else {
mTaskIndex = -1;
}
#endif
// Reset tid to rebind cpu if necessary
@ -256,11 +244,7 @@ CPURuntime::CPURuntime(const Backend::Info& info) {
}
CPURuntime:: ~ CPURuntime() {
#ifdef MNN_USE_THREAD_POOL
if(mThreadPool) {
mThreadPool->releaseWorkIndex(mTaskIndex);
}
#endif
// Do nothing
}
float CPURuntime::onGetMemoryInMB() {
auto staticMemoryInMB = mStaticAllocator->totalSize() / 1024.0f / 1024.0f;
@ -336,11 +320,17 @@ Backend* CPURuntime::onCreate(const BackendConfig* config, Backend* origin) cons
MNN_PRINT("cpu backend was created by runtime:%p\n", this);
#endif
CPUBackend* res = nullptr;
auto initThreadNumber = hint().initThreadNumber;
do {
#ifdef MNN_USE_ARMV82
auto core = MNNGetCoreFunctions();
if (core->supportFp16arith && precision == BackendConfig::Precision_Low) {
res = new Arm82Backend(this, memory);
if (hint().useArmSme2Cores && res->threadNumber() <= 2 && core->supportSME2 && res->functions()->sme2MatmulRelatedFuncions.Int8GemmKernel) {
res->mRelatedFunctions = &(res->functions()->sme2MatmulRelatedFuncions);
} else {
res->mRelatedFunctions = &(res->functions()->backendMatmulRelatedFunctions);
}
break;
}
#endif
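// Illustrative helper mirroring the selection above: the SME2 matmul table is only chosen when
// the runtime hint enables it, at most two threads are requested, the CPU reports SME2, and the
// SME2 int8 gemm kernel was actually registered; otherwise the regular backend table is kept.
// The free-function form and its name are ours, for illustration only.
static const MatmulRelatedFunctions* pickMatmulFunctions(const CoreFunctions* core,
                                                         bool useArmSme2Cores,
                                                         int threadNumber) {
    bool useSme2 = useArmSme2Cores && threadNumber <= 2 && core->supportSME2
                   && core->sme2MatmulRelatedFuncions.Int8GemmKernel != nullptr;
    return useSme2 ? &core->sme2MatmulRelatedFuncions
                   : &core->backendMatmulRelatedFunctions;
}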
@ -353,7 +343,7 @@ Backend* CPURuntime::onCreate(const BackendConfig* config, Backend* origin) cons
#endif
if (flags == MNN_CPU_USE_DEFAULT_BACKEND) {
// Default don't use multi-thread init
res = new CPUBackend(this, precision, memory, MNN_FORWARD_CPU, 0);
res = new CPUBackend(this, precision, memory, MNN_FORWARD_CPU);
break;
}
#ifdef MNN_USE_SSE
@ -365,6 +355,7 @@ Backend* CPURuntime::onCreate(const BackendConfig* config, Backend* origin) cons
res = new CPUBackend(this, precision, memory, MNN_FORWARD_CPU, flags);
} while (false);
mSharedDmaInfo = nullptr;
res->setMetaPtr(pMeta);
return res;
}
@ -402,9 +393,13 @@ void CPURuntime::onGabageCollect(int level) {
void CPURuntime::onConcurrencyBegin() const {
#ifdef MNN_USE_THREAD_POOL
if (mTaskIndex < 0 && nullptr != mThreadPool) {
mTaskIndex = mThreadPool->acquireWorkIndex();
}
if (mTaskIndex >= 0) {
if (mThreadOpen == 0 && mThreadPool) {
// mThreadOpen 0 -> 1, open ThreadPool
// mThreadOpen 0 -> 1: activate the ThreadPool
// Subsequent onConcurrencyBegin calls only increment mThreadOpen
if (0 == mThreadOpen) {
mThreadPool->active();
}
mThreadOpen++;
@ -421,11 +416,12 @@ void CPURuntime::onConcurrencyBegin() const {
void CPURuntime::onConcurrencyEnd() const {
#ifdef MNN_USE_THREAD_POOL
if (mTaskIndex >= 0) {
MNN_ASSERT(mThreadOpen > 0);
mThreadOpen--;
mThreadOpen = mThreadOpen < 0 ? 0 : mThreadOpen;
if (0 == mThreadOpen && mThreadPool) {
if (0 == mThreadOpen) {
mThreadPool->releaseWorkIndex(mTaskIndex);
mThreadPool->deactive();
mTaskIndex = -1;
}
}
#endif
@ -460,10 +456,14 @@ CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode p
#endif
mMemory = memory;
mRuntime = const_cast<CPURuntime*>(runtime);
auto core = MNNGetCoreFunctions();
mThreadNumber = mRuntime->mThreadNumber;
#ifdef MNN_USE_THREAD_POOL
mThreadPool = mRuntime->mThreadPool;
#endif
if (mRuntime->hint().useArmSme2Cores && core->supportSME2 && core->sme2MatmulRelatedFuncions.Int8GemmKernel) {
mThreadNumber = 2;
mRelatedFunctions = &core->sme2MatmulRelatedFuncions;
} else {
mRelatedFunctions = &core->backendMatmulRelatedFunctions;
}
// Compute Group Rate
do {
if (mThreadNumber <= 1 || mRuntime->mPower == BackendConfig::Power_Low) {

View File

@ -60,10 +60,12 @@ private:
mutable int mThreadNumber;
mutable std::vector<int> mCpuIds;
mutable unsigned long mCpuMask;
#ifdef MNN_USE_THREAD_POOL
mutable ThreadPool* mThreadPool = nullptr;
#endif
#ifdef MNN_USE_THREAD_POOL
mutable int mTaskIndex = -1;
mutable int mThreadOpen = 0;
mutable ThreadPool* mThreadPool = nullptr;
#endif
BackendConfig::MemoryMode mMemory;
BackendConfig::PowerMode mPower;
@ -82,6 +84,7 @@ private:
};
struct CoreFunctions;
struct CoreInt8Functions;
struct MatmulRelatedFunctions;
class CPUResizeCache;
class CPUMemObj : public Backend::MemObj {
@ -134,6 +137,10 @@ public:
const CoreFunctions* functions() const {
return mCoreFunctions;
}
const MatmulRelatedFunctions* int8GemmFunctions() const {
return mRelatedFunctions;
}
// Return element size for Tensor, considering pack
size_t getTensorSize(const Tensor* tensor, bool multiBytes = false) const;
const CoreInt8Functions* int8Functions() const {
@ -183,12 +190,10 @@ protected:
MemObj* allocBuffer(size_t size, Tensor* dest, StorageType storageType);
CoreFunctions* mCoreFunctions;
CoreInt8Functions* mInt8CoreFunctions;
const MatmulRelatedFunctions* mRelatedFunctions;
private:
mutable std::shared_ptr<WorkerThread> mInitWorkQueue;
int mThreadNumber;
#ifdef MNN_USE_THREAD_POOL
ThreadPool* mThreadPool = nullptr;
#endif
mutable int mThreadNumber = 1;
std::vector<std::pair<float, int>> mGroupWithComputeRate;
float mComputeI = 0.f;

View File

@ -53,7 +53,7 @@ static void _transformWeight(const uint8_t* tempWeight, uint8_t* dest, int outpu
core->MNNPackCUnit((float*)dst, (const float*)src, fw*fh, outputCount, offset);
}
//printf("%d - %d - %d - %d\n", outputCount, srcCount, fh, fw);
core->MNNPackForMatMul_B((float*)dest, (const float*)cache, outputC4 * fw * fh * core->pack, srcCount, false);
core->MNNPackForMatMul_B((float*)dest, (const float*)cache, outputC4 * fw * fh * core->pack, 1, srcCount, false);
}
std::shared_ptr<DeconvolutionResource> CPUDeconvolution::makeResource(int srcCount, const Op *convOp, Backend* backend, bool dynamic) {
auto core = static_cast<CPUBackend*>(backend)->functions();
@ -253,7 +253,7 @@ ErrorCode CPUDeconvolutionOrigin::onResize(const std::vector<Tensor*>& inputs, c
}
::memset(tempOutPtr, 0, outputSize);
int l = mSrcCount;
int l = ROUND_UP(mSrcCount, lP);
int h = kernelCount * core->pack;
auto weightPtr = weightTensor->host<uint8_t>();
for (int index=tId; index < tileCount; index+=threadNumber) {

View File

@ -101,7 +101,7 @@ ErrorCode CPUMatMul::onResize(const std::vector<Tensor*>& inputs, const std::vec
}
mPreFunctions.emplace_back(std::make_pair([BTPtrAlloc, l, h, this, core] (int tId, const float* APtr, const float* BPtr, const float* Bias, float* C) {
core->MNNPackForMatMul_B((float*)BTPtrAlloc.ptr(), BPtr, h, l, mTransposeB);
core->MNNPackForMatMul_B((float*)BTPtrAlloc.ptr(), BPtr, h, 1, l, mTransposeB);
} , 1));
bool useBias = false;
MemChunk bdestAlloc;
@ -194,7 +194,7 @@ void CPUMatMul::execute(const float* APtr, const float* BPtr, float* CPtr, const
auto TC = mTempC.ptr() + tId * eP * hC4 * core->pack * core->bytes;
size_t parameters[6];
parameters[0] = eP * core->bytes;
parameters[1] = mL;
parameters[1] = lAlign;
parameters[2] = mH;
parameters[3] = eP * core->pack * core->bytes;
parameters[4] = 0;
@ -251,7 +251,7 @@ void CPUMatMul::execute(const float* APtr, const float* BPtr, float* CPtr, const
int yy = lC;
for (int x=0; x<xC; ++x) {
::memset(TA + (yy * eP * lP + x * lP) * core->bytes, 0, lP * core->bytes);
::memcpy(TA + (yy * eP * lP + x * lP) * core->bytes, (uint8_t*)APtr + ((x+xStart)*mL+yy*lP)*core->bytes, xC * core->bytes);
::memcpy(TA + (yy * eP * lP + x * lP) * core->bytes, (uint8_t*)APtr + ((x+xStart)*mL+yy*lP)*core->bytes, lR * core->bytes);
}
}
} else {

View File

@ -68,7 +68,7 @@ void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength,
mulFunction(resetGatePtr, rtPtr, hiddenStatePtr, numUnits, -1);
// deal with recurrent bias and linear_before_reset parameter
auto recurrentBiasAddedPtr = inputAndStatePtr + (inputLength + numUnits) * bytes;
auto recurrentHiddenBiasPtr = recurrentBias->host<float>() + 2 * numUnits * bytes;
auto recurrentHiddenBiasPtr = (float*)(recurrentBias->host<uint8_t>() + 2 * numUnits * bytes);
addFunction(recurrentBiasAddedPtr, recurrentHiddenBiasPtr, candidateBias->host<float>(), numUnits, -1);
mMatMulI2U->execute(inputAndState->host<float>(), candidateWeight->host<float>(), resetHt->host<float>(), nullptr);
// reuse r_t memory as h_t'

View File

@ -20,6 +20,8 @@
using Vec4 = MNN::Math::Vec<float, 4>;
namespace MNN {
typedef void (*FP16ToFP32)(const int16_t* src, float* dst, size_t size);
typedef void (*FP32ToFP16)(const float* src, int16_t* dst, size_t size);
struct ReduceInfo {
int reduceMask[3] = {0, 0, 0};
int reduceNum = 0;
@ -434,7 +436,7 @@ static void _zero(const Tensor::InsideDescribe::Region& slice, int bytes, uint8_
}
}
}
static bool _reduceblit(const Tensor::InsideDescribe::Region& slice, int bytes, const uint8_t* srcPtr, uint8_t* dstPtr) {
static bool _reduceblit(const Tensor::InsideDescribe::Region& slice, int bytes, const uint8_t* srcPtr, uint8_t* dstPtr, FP16ToFP32 funcFp16ToFp32 = nullptr, FP32ToFP16 funcFp32ToFp16 = nullptr) {
ReduceInfo reduceInfo;
reduceInfo.compute(slice);
auto normalIndex = reduceInfo.normalIndex;
@ -443,21 +445,31 @@ static bool _reduceblit(const Tensor::InsideDescribe::Region& slice, int bytes,
case 3:
{
float summer = 0.0f;
float fp32Buffer[1];
for (int z=0; z<slice.size[0]; ++z) {
auto srcZ = srcPtr + z * slice.src.stride[0] * bytes;
for (int y=0; y<slice.size[1]; ++y) {
auto srcY = srcZ + y * slice.src.stride[1] * bytes;
auto S = (float*)srcY;
for (int x=0; x<slice.size[2]; ++x) {
auto S = (float*)srcY;
if (bytes == 2) {
funcFp16ToFp32((int16_t*)srcY, fp32Buffer, 1);
S = fp32Buffer;
}
summer += S[slice.src.stride[2] * x];
}
}
}
((float*)dstPtr)[0] = summer;
if (bytes == 4) {
((float*)dstPtr)[0] = summer;
} else {
funcFp32ToFp16(&summer, (int16_t*)dstPtr, 1);
}
return true;
}
case 2:
{
float fp32Buffer[1];
int sizeZ = slice.size[normalIndex[0]];
int srcStrideZ = slice.src.stride[normalIndex[0]];
int dstStrideZ = slice.dst.stride[normalIndex[0]];
@ -473,12 +485,20 @@ static bool _reduceblit(const Tensor::InsideDescribe::Region& slice, int bytes,
auto dstZ = dstPtr + z * dstStrideZ * bytes;
for (int y=0; y<sizeY; ++y) {
auto srcY = srcZ + y * srcStrideY * bytes;
auto S = (float*)srcY;
for (int x=0; x<sizeX; ++x) {
auto S = (float*)srcY;
if (bytes == 2) {
funcFp16ToFp32((int16_t*)srcY, fp32Buffer, 1);
S = fp32Buffer;
}
summer += S[srcStrideX * x];
}
}
((float*)dstZ)[0] = summer;
if (bytes == 4) {
((float*)dstZ)[0] = summer;
} else {
funcFp32ToFp16(&summer, (int16_t*)dstZ, 1);
}
}
return true;
}
@ -493,6 +513,7 @@ static bool _reduceblit(const Tensor::InsideDescribe::Region& slice, int bytes,
int sizeX = slice.size[reduceIndex[0]];
int srcStrideX = slice.src.stride[reduceIndex[0]];
int dstStrideX = slice.dst.stride[reduceIndex[0]];
std::vector<float> fp32Buffer(sizeX);
for (int z=0; z<sizeZ; ++z) {
auto srcZ = srcPtr + z * srcStrideZ * bytes;
auto dstZ = dstPtr + z * dstStrideZ * bytes;
@ -500,11 +521,20 @@ static bool _reduceblit(const Tensor::InsideDescribe::Region& slice, int bytes,
float summer = 0.0f;
auto srcY = srcZ + y * srcStrideY * bytes;
auto dstY = dstZ + y * dstStrideY * bytes;
auto S = (float*)srcY;
float* S = (float*)srcY;
float* D = (float*)dstY;
if (bytes == 2) {
funcFp16ToFp32((int16_t*)srcY, fp32Buffer.data(), sizeX);
S = fp32Buffer.data();
}
for (int x=0; x<sizeX; ++x) {
summer += S[srcStrideX * x];
}
((float*)dstY)[0] = summer;
if (bytes == 2) {
funcFp32ToFp16(&summer, (int16_t*)dstY, 1);
} else {
D[0] = summer;
}
}
}
return true;
@ -515,10 +545,10 @@ static bool _reduceblit(const Tensor::InsideDescribe::Region& slice, int bytes,
return false;
}
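// Illustrative sketch of the conversion-hook pattern introduced above: a strided reduce-sum
// that reads fp16 or fp32 input through the FP16ToFP32/FP32ToFP16 callbacks and writes the
// result back in the same precision. Converting the element at its strided offset keeps the
// fp16 path indexing identical to the fp32 path. Simplified standalone form, not the exact
// kernel; the function name is ours.
static void stridedReduceSum(const uint8_t* src, uint8_t* dst, int count, int srcStride,
                             int bytes, FP16ToFP32 toFp32, FP32ToFP16 toFp16) {
    float summer = 0.0f;
    for (int x = 0; x < count; ++x) {
        if (bytes == 2) {
            float v;
            toFp32((const int16_t*)src + x * srcStride, &v, 1); // convert one fp16 element
            summer += v;
        } else {
            summer += ((const float*)src)[x * srcStride];
        }
    }
    if (bytes == 2) {
        toFp16(&summer, (int16_t*)dst, 1);
    } else {
        ((float*)dst)[0] = summer;
    }
}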
static void _blit(const Tensor::InsideDescribe::Region& slice, int bytes, const uint8_t* srcPtr, uint8_t* dstPtr, bool hasReduce) {
static void _blit(const Tensor::InsideDescribe::Region& slice, int bytes, const uint8_t* srcPtr, uint8_t* dstPtr, bool hasReduce, FP16ToFP32 funcFp16ToFp32 = nullptr, FP32ToFP16 funcFp32ToFp16 = nullptr) {
auto proc = _selectUnitProc(bytes, slice.src.stride[2], slice.dst.stride[2]);
if (hasReduce) {
if (_reduceblit(slice, bytes, srcPtr, dstPtr)) {
if (_reduceblit(slice, bytes, srcPtr, dstPtr, funcFp16ToFp32, funcFp32ToFp16)) {
return;
}
}
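For reference, a minimal C++ sketch (not part of the patch) of the fp16-aware reduce-sum that _reduceblit performs on one strided row; it reuses the FP16ToFP32/FP32ToFP16 converter typedefs declared earlier, and the concrete function pointers are whatever the backend supplies at the call site.

#include <cstdint>
#include <cstddef>
// same converter signatures as declared above
typedef void (*FP16ToFP32)(const int16_t* src, float* dst, size_t size);
typedef void (*FP32ToFP16)(const float* src, int16_t* dst, size_t size);

// Hedged sketch: sum `size` elements spaced by srcStride, widening each fp16 element
// to fp32 before accumulating, then narrowing the result if the output is fp16.
static void reduceRowSum(const uint8_t* src, uint8_t* dst, int size, int srcStride,
                         int bytes, FP16ToFP32 toFp32, FP32ToFP16 toFp16) {
    float summer = 0.0f;
    float tmp;
    for (int x = 0; x < size; ++x) {
        const uint8_t* p = src + (size_t)srcStride * x * bytes;
        if (bytes == 2) {                         // lowp input: widen one element first
            toFp32((const int16_t*)p, &tmp, 1);
            summer += tmp;
        } else {
            summer += ((const float*)p)[0];
        }
    }
    if (bytes == 4) {
        ((float*)dst)[0] = summer;
    } else {
        toFp16(&summer, (int16_t*)dst, 1);        // narrow the accumulated sum back to lowp
    }
}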
@ -607,10 +637,10 @@ ErrorCode CPURaster::onExecute(const std::vector<Tensor *> &____inputs, const st
auto sourceFormat = TensorUtils::getDescribe(realInput)->dimensionFormat;
auto destFormat = TensorUtils::getDescribe(output)->dimensionFormat;
auto channelC4 = UP_DIV(srcChannel, core->pack);
int batchStrideC4 = channelC4 * core->pack * srcArea * bytes;
int batchStride = srcChannel * srcArea * bytes;
int inputBatchStride = batchStride;
int outputBatchStride = batchStride;
auto batchStrideC4 = channelC4 * core->pack * srcArea * bytes;
auto batchStride = srcChannel * srcArea * bytes;
auto inputBatchStride = batchStride;
auto outputBatchStride = batchStride;
if (MNN_DATA_FORMAT_NC4HW4 == sourceFormat) {
if (realInput->dimensions() <= 1) {
::memcpy(output->host<uint8_t>(), realInput->host<uint8_t>(), realInput->elementSize() * bytes);
@ -670,7 +700,7 @@ ErrorCode CPURaster::onExecute(const std::vector<Tensor *> &____inputs, const st
}
auto srcPtr = iter.first->host<uint8_t>() + slice.src.offset * bytes;
auto dstPtr = (uint8_t*)mOutputPtr + slice.dst.offset * bytes;
_blit(slice, bytes, srcPtr, dstPtr, mHasReduce);
_blit(slice, bytes, srcPtr, dstPtr, mHasReduce, core->MNNLowpToFp32, core->MNNFp32ToLowp);
}
}
MNN_CONCURRENCY_END();
@ -1088,11 +1118,11 @@ public:
break;
case BinaryOpOperation_MUL:
for (int z=0; z<sizeZ; ++z) {
auto srcZ = srcF + z * dstStride[0];
auto dstZ = dstF + z * outputStride[0];
auto srcZ = srcF + z * outputStride[0];
auto dstZ = dstF + z * dstStride[0];
for (int y=0; y<sizeY; ++y) {
auto srcY = srcZ + z * dstStride[1];
auto dstY = dstZ + z * outputStride[1];
auto srcY = srcZ + y * outputStride[1];
auto dstY = dstZ + y * dstStride[1];
for (int x=0; x<sizeX; ++x) {
auto dstOffset = x * dstStride[2];
dstY[dstOffset] = dstY[dstOffset] * srcY[x];
@ -1102,16 +1132,14 @@ public:
break;
case BinaryOpOperation_SUB:
for (int z=0; z<sizeZ; ++z) {
auto srcZ = srcF + z * dstStride[0];
auto dstZ = dstF + z * outputStride[0];
auto srcZ = srcF + z * outputStride[0];
auto dstZ = dstF + z * dstStride[0];
for (int y=0; y<sizeY; ++y) {
auto srcY = srcZ + z * dstStride[1];
auto dstY = dstZ + z * outputStride[1];
auto srcY = srcZ + y * outputStride[1];
auto dstY = dstZ + y * dstStride[1];
for (int x=0; x<sizeX; ++x) {
auto dstOffset = x * dstStride[2];
auto D = dstY[dstOffset];
auto S = srcY[x];
dstY[dstOffset] = D - S;
dstY[dstOffset] = dstY[dstOffset] - srcY[x];
}
}
}

View File

@ -119,54 +119,60 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) {
mPastKey.reset(new_key);
}
else if (mConfig.mQuantKey) {
auto new_key = Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP});
auto new_key = Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(mMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP});
mBackend->onAcquireBuffer(new_key, Backend::STATIC);
for (int h = 0; h < mKvNumHead; h++) {
memcpy(
new_key->host<char>() + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP,
mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP,
UP_DIV(oldMaxLength, hP) * mHeadDim * hP
new_key->host<char>() + h * new_key->stride(0),
mPastKey->host<char>() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP),
ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP)
);
}
mPastKey.reset(new_key);
}
else {
auto new_key = Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP});
auto new_key = Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP});
mBackend->onAcquireBuffer(new_key, Backend::STATIC);
for (int h = 0; h < mKvNumHead; h++) {
memcpy(
new_key->host<char>() + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes,
mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes,
UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes
new_key->host<char>() + h * new_key->stride(0) * mBytes,
mPastKey->host<char>() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes,
ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes
);
if ((new_key->stride(0) - mPastKey->stride(0)) > 0) {
memset(new_key->host<char>() + h * new_key->stride(0) * mBytes + mPastKey->stride(0) * mBytes, 0, (new_key->stride(0) - mPastKey->stride(0)) * mBytes);
}
}
mPastKey.reset(new_key);
}
/*=================================== Value ===================================*/
if (mConfig.mQuantValue) {
auto new_value = Tensor::createDevice<fp8_t>({mKvNumHead, UP_DIV(mHeadDim, hP), mMaxLength, hP});
auto new_value = Tensor::createDevice<fp8_t>({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(mMaxLength, lP), hP, lP});
mBackend->onAcquireBuffer(new_value, Backend::STATIC);
for (int h = 0; h < mKvNumHead; h++) {
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
memcpy(
new_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP,
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP,
oldMaxLength * hP
new_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP,
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP,
ROUND_UP(oldMaxLength, lP) * hP
);
}
}
mPastValue.reset(new_value);
}
else {
auto new_value = Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), mMaxLength, hP});
auto new_value = Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(mMaxLength, lP), hP, lP});
mBackend->onAcquireBuffer(new_value, Backend::STATIC);
for (int h = 0; h < mKvNumHead; h++) {
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
memcpy(
new_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP * mBytes,
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP * mBytes,
oldMaxLength * hP * mBytes
new_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes,
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes,
ROUND_UP(oldMaxLength, lP) * hP * mBytes
);
if ((new_value->stride(1) - mPastValue->stride(1)) > 0) {
memset(new_value->host<char>() + (h * new_value->stride(0) + i * new_value->stride(1)) * mBytes + mPastValue->stride(1) * mBytes, 0, (new_value->stride(1) - mPastValue->stride(1)) * mBytes);
}
}
}
mPastValue.reset(new_value);
@ -193,20 +199,23 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) {
if (mConfig.mQuantKey) {
for (int h = 0; h < mKvNumHead; h++) {
memcpy(
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP,
mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP,
UP_DIV(oldMaxLength, hP) * mHeadDim * hP
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP,
mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP,
UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP
);
}
mBackend->onReleaseBuffer(mPastKey.get(), Backend::STATIC);
mPastKey.reset();
}
else {
if (mHeadDim % lP) {
memset(mMapKeyAddr, 0, mKvNumHead * ROUND_UP(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes );
}
for (int h = 0; h < mKvNumHead; h++) {
memcpy(
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes,
mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes,
UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes,
mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes,
UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes
);
}
mBackend->onReleaseBuffer(mPastKey.get(), Backend::STATIC);
@ -217,9 +226,9 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) {
for (int h = 0; h < mKvNumHead; h++) {
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
memcpy(
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP,
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP,
oldMaxLength * hP
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP,
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP,
ROUND_UP(oldMaxLength, lP) * hP
);
}
}
@ -227,12 +236,15 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) {
mPastValue.reset();
}
else {
if (lP > 1) {
memset(mMapValueAddr, 0, mKvNumHead * ROUND_UP(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * mBytes);
}
for (int h = 0; h < mKvNumHead; h++) {
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
memcpy(
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP * mBytes,
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP * mBytes,
oldMaxLength * hP * mBytes
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes,
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes,
ROUND_UP(oldMaxLength, lP) * hP * mBytes
);
}
}
@ -250,17 +262,25 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o
if (mConfig.mUseInt8Kernel) {
old_key.reset(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(oldMaxLength, hP8), UP_DIV(mHeadDim, lP8), hP8 * lP8}));
} else if (mConfig.mQuantKey) {
old_key.reset(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(oldMaxLength, hP), mHeadDim, hP}));
old_key.reset(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(oldMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP}));
} else {
old_key.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(oldMaxLength, hP), mHeadDim, hP}));
old_key.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(oldMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP}));
}
if (mConfig.mQuantValue) {
old_value.reset(Tensor::createDevice<fp8_t>({mKvNumHead, UP_DIV(mHeadDim, hP), oldMaxLength, hP}));
old_value.reset(Tensor::createDevice<fp8_t>({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(oldMaxLength, lP), hP, lP}));
} else {
old_value.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), oldMaxLength, hP}));
old_value.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(oldMaxLength, lP), hP, lP}));
}
mBackend->onAcquireBuffer(old_key.get(), Backend::STATIC);
mBackend->onAcquireBuffer(old_value.get(), Backend::STATIC);
if (mHeadDim % lP) {
memset(old_key->host<uint8_t>(), 0, old_key->length(0) * old_key->stride(0) * mBytes);
}
if (lP > 1) {
// The condition can't be (mMaxLength % lP): mMaxLength may exceed seq_len during prefill,
// so the trailing (mMaxLength - seq_len) region must be zeroed.
// The L dimension actually computed is seq_len.
memset(old_value->host<uint8_t>(), 0, old_value->length(0) * old_value->stride(0) * mBytes);
}
mmapKVCache(oldKeySize, oldValueSize);
memcpy(old_key->host<char>(), mMapKeyAddr, oldKeySize);
memcpy(old_value->host<char>(), mMapValueAddr, oldValueSize);
@ -280,17 +300,17 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o
} else if (mConfig.mQuantKey) {
for (int h = 0; h < mKvNumHead; h++) {
memcpy(
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP,
old_key->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP,
UP_DIV(oldMaxLength, hP) * mHeadDim * hP
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP,
old_key->host<char>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP,
UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP
);
}
} else {
for (int h = 0; h < mKvNumHead; h++) {
memcpy(
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes,
old_key->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes,
UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes,
old_key->host<char>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes,
UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes
);
}
}
@ -298,9 +318,9 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o
for (int h = 0; h < mKvNumHead; h++) {
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
memcpy(
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP,
old_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP,
oldMaxLength * hP
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP,
old_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP,
ROUND_UP(oldMaxLength, lP) * hP
);
}
}
@ -308,9 +328,9 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o
for (int h = 0; h < mKvNumHead; h++) {
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
memcpy(
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP * mBytes,
old_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP * mBytes,
oldMaxLength * hP * mBytes
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes,
old_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes,
ROUND_UP(oldMaxLength, lP) * hP * mBytes
);
}
}
@ -341,11 +361,11 @@ void KVCacheManager::onAlloc(int kv_seq_len) {
if (mConfig.mUseInt8Kernel) {
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8;
} else if (mConfig.mQuantKey) {
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP;
keySize = (size_t)mKvNumHead * ROUND_UP(mMaxLength, hP) * ROUND_UP(mHeadDim, lP);
} else {
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes;
keySize = (size_t)mKvNumHead * ROUND_UP(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes;
}
valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * mMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes);
valueSize = (size_t)mKvNumHead * ROUND_UP(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * (mConfig.mQuantValue ? 1 : mBytes);
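As a hedged worked example of the two size formulas above, with hypothetical shapes (8 kv heads, mMaxLength = 512, mHeadDim = 128, hP = 8, lP = 4, fp16 storage, no quantization) both buffers come out to exactly 1 MiB:

#include <cstddef>
constexpr size_t roundUp(size_t x, size_t m) { return (x + m - 1) / m * m; }
// hypothetical shapes for illustration only
constexpr size_t kvHeads = 8, maxLen = 512, headDim = 128, hP = 8, lP = 4, bytes = 2;
constexpr size_t keySize   = kvHeads * roundUp(maxLen, hP) * roundUp(headDim, lP) * bytes;   // 1,048,576
constexpr size_t valueSize = kvHeads * roundUp(headDim, hP) * roundUp(maxLen, lP) * bytes;   // 1,048,576
static_assert(keySize == (1u << 20) && valueSize == (1u << 20), "1 MiB each for these shapes");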
/*============== Put the kvcache in disk ===========*/
if (mConfig.mKVCacheSizeLimit != -1 && keySize + valueSize > mConfig.mKVCacheSizeLimit) {
createKVCacheFile();
@ -358,17 +378,23 @@ void KVCacheManager::onAlloc(int kv_seq_len) {
if (mConfig.mUseInt8Kernel) {
mPastKey.reset(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(mMaxLength, hP8), UP_DIV(mHeadDim, lP8), hP8 * lP8}));
} else if (mConfig.mQuantKey) {
mPastKey.reset(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP}));
mPastKey.reset(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(mMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP}));
} else {
mPastKey.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP}));
mPastKey.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP}));
}
if (mConfig.mQuantValue) {
mPastValue.reset(Tensor::createDevice<fp8_t>({mKvNumHead, UP_DIV(mHeadDim, hP), mMaxLength, hP}));
mPastValue.reset(Tensor::createDevice<fp8_t>({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(mMaxLength, lP), hP, lP}));
} else {
mPastValue.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), mMaxLength, hP}));
mPastValue.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(mMaxLength, lP), hP, lP}));
}
mBackend->onAcquireBuffer(mPastKey.get(), Backend::STATIC);
mBackend->onAcquireBuffer(mPastValue.get(), Backend::STATIC);
if (mHeadDim % lP) {
memset(mPastKey->host<int8_t>(), 0, mPastKey->length(0) * mPastKey->stride(0) * mBytes);
}
if (lP > 1) { // can't be (mMaxLength % lP): mMaxLength may exceed seq_len during prefill, so the trailing (mMaxLength - seq_len) region must be zeroed.
memset(mPastValue->host<int8_t>(), 0, mPastValue->length(0) * mPastValue->stride(0) * mBytes);
}
}
// scale, zero point and sum of key for quantization
if (mConfig.mUseInt8Kernel) {
@ -397,14 +423,14 @@ void KVCacheManager::onRealloc(const KVMeta* meta) {
oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8;
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8;
} else if (mConfig.mQuantKey) {
oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP) * mHeadDim * hP;
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP;
oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP;
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP;
} else {
oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes;
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes;
oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes;
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes;
}
oldValueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * oldMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes);
valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * mMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes);
oldValueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * ROUND_UP(oldMaxLength, lP) * hP * (mConfig.mQuantValue ? 1 : mBytes);
valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * hP * (mConfig.mQuantValue ? 1 : mBytes);
/*==== No limit for kvcache ====*/
if (mConfig.mKVCacheSizeLimit == -1) {
expandKVCacheInMem(oldMaxLength);
@ -485,14 +511,14 @@ void KVCacheManager::onRealloc(const KVMeta* meta) {
// mPastKey.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP}));
// Move K
auto keyStride = UP_DIV(mMaxLength, align) * align * mHeadDim;
auto dstKAddr = keyAddr() + dstStartAlign * mHeadDim * mBytes;
auto srcKAddr = keyAddr() + startAlign * mHeadDim * mBytes;
auto keyStride = UP_DIV(mMaxLength, align) * align * ROUND_UP(mHeadDim, lP);
auto dstKAddr = keyAddr() + dstStartAlign * ROUND_UP(mHeadDim, lP) * mBytes;
auto srcKAddr = keyAddr() + startAlign * ROUND_UP(mHeadDim, lP) * mBytes;
for (int i=0; i<mKvNumHead; ++i) {
auto dst = dstKAddr + i * keyStride * mBytes;
auto src = srcKAddr + i * keyStride * mBytes;
for (int j=0; j<sizeUnit; ++j) {
::memcpy(dst + j * align * mHeadDim * mBytes, src + j * align * mHeadDim * mBytes, align * mHeadDim * mBytes);
::memcpy(dst + j * align * ROUND_UP(mHeadDim, lP) * mBytes, src + j * align * ROUND_UP(mHeadDim, lP) * mBytes, align * ROUND_UP(mHeadDim, lP) * mBytes);
}
}
@ -503,8 +529,8 @@ void KVCacheManager::onRealloc(const KVMeta* meta) {
auto srcVAddr = valudAddr() + startAlign * align * mBytes;
auto number = mKvNumHead * UP_DIV(mHeadDim, align);
for (int i=0; i<number; ++i) {
auto dst = dstVAddr + i * mMaxLength * align * mBytes;
auto src = srcVAddr + i * mMaxLength * align * mBytes;
auto dst = dstVAddr + i * ROUND_UP(mMaxLength, lP) * align * mBytes;
auto src = srcVAddr + i * ROUND_UP(mMaxLength, lP) * align * mBytes;
for (int j=0; j<sizeUnit; ++j) {
::memcpy(dst + j * align * align * mBytes, src + j * align * align * mBytes, align * align * mBytes);
}
@ -521,11 +547,11 @@ void KVCacheManager::onClear() {
if (mConfig.mUseInt8Kernel) {
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8;
} else if (mConfig.mQuantKey) {
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP;
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP;
} else {
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes;
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes;
}
valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * mMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes);
valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * hP * (mConfig.mQuantValue ? 1 : mBytes);
unmapKVCache(keySize, valueSize);
removeKVCacheFile();
mKVCacheInDisk = false;
@ -584,14 +610,16 @@ void KVCacheManager::pack_key(const Tensor* key, int seq_len, int kv_h) {
}
}
}
else { // [maxlen/hP, headdim, hP]
else { // target: [maxlen/hP, headdim/lP, hP, lP]
T * key_dst = reinterpret_cast<T*>(addrOfKey(kv_h));
auto stride0 = ROUND_UP(mHeadDim, lP) * hP;
auto stride1 = hP * lP;
for (int i = 0; i < seq_len; i++) {
T * key_src = key->host<T>() + i * mKvNumHead * mHeadDim + kv_h * mHeadDim;
int out_index = (mPastLength + i) / hP;
int in_index = (mPastLength + i) % hP;
for (int j = 0; j < mHeadDim; j++) {
key_dst[out_index * mHeadDim * hP + j * hP + in_index] = key_src[j];
key_dst[out_index * stride0 + (j / lP) * stride1 + in_index * lP + (j % lP)] = key_src[j];
}
}
}
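A hedged helper (names illustrative, not part of the class) that isolates the index arithmetic of the loop above for the packed key layout [maxlen/hP, headdim/lP, hP, lP]:

#include <cstddef>
// Sketch: flat offset of key element (seqIdx, j) inside one head's buffer.
static inline size_t packedKeyOffset(int seqIdx, int j, int hP, int lP, int headDim) {
    size_t headDimUp = (size_t)((headDim + lP - 1) / lP) * lP;   // ROUND_UP(headDim, lP)
    size_t stride0   = headDimUp * hP;                           // one hP block of rows
    size_t stride1   = (size_t)hP * lP;                          // one lP column group
    return (size_t)(seqIdx / hP) * stride0 + (size_t)(j / lP) * stride1
         + (size_t)(seqIdx % hP) * lP + (size_t)(j % lP);
}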
@ -618,13 +646,18 @@ void KVCacheManager::pack_value(const Tensor* value, int seq_len, int kv_h) { //
MNNMemoryFreeAlign(buf);
}
else {
// [mHeadDim/hP, mMaxLength/lP, hP, lP]
auto stride0 = ROUND_UP(mMaxLength, lP) * hP;
auto stride1 = hP * lP;
T * value_dst = reinterpret_cast<T*>(addrOfValue(kv_h));
for (int i = 0; i < seq_len; i++) {
T * value_src = value->host<T>() + i * mKvNumHead * mHeadDim + kv_h * mHeadDim;
int seqLenOut = (mPastLength + i) / lP;
int seqLenIn = (mPastLength + i) % lP;
for (int j = 0; j < mHeadDim; j++) {
int out_index = j / hP;
int in_index = j % hP;
value_dst[out_index * mMaxLength * hP + (mPastLength + i) * hP + in_index] = value_src[j];
value_dst[out_index * stride0 + seqLenOut * stride1 + in_index * lP + seqLenIn] = value_src[j];
}
}
}
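The value side mirrors this with the two axes swapped; a hedged offset helper (illustrative only) for the [mHeadDim/hP, mMaxLength/lP, hP, lP] layout shown above:

#include <cstddef>
// Sketch: flat offset of value element (seqIdx, j) inside one head's buffer.
static inline size_t packedValueOffset(int seqIdx, int j, int hP, int lP, int maxLength) {
    size_t maxLenUp = (size_t)((maxLength + lP - 1) / lP) * lP;  // ROUND_UP(mMaxLength, lP)
    size_t stride0  = maxLenUp * hP;
    size_t stride1  = (size_t)hP * lP;
    return (size_t)(j / hP) * stride0 + (size_t)(seqIdx / lP) * stride1
         + (size_t)(j % hP) * lP + (size_t)(seqIdx % lP);
}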

View File

@ -116,17 +116,17 @@ public:
if (mConfig.mUseInt8Kernel) {
return baseAddr + kv_h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8;
} else if (mConfig.mQuantKey) {
return baseAddr + kv_h * UP_DIV(mMaxLength, hP) * mHeadDim * hP;
return baseAddr + kv_h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP;
} else {
return baseAddr + kv_h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes;
return baseAddr + kv_h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes;
}
}
char * addrOfValue(int kv_h) {
char * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host<char>();
if (mConfig.mQuantValue) {
return baseAddr + kv_h * UP_DIV(mHeadDim, hP) * mMaxLength * hP;
return baseAddr + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * hP;
} else {
return baseAddr + kv_h * UP_DIV(mHeadDim, hP) * mMaxLength * hP * mBytes;
return baseAddr + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * hP * mBytes;
}
}
char * addrOfScale(int kv_h) {

View File

@ -103,7 +103,7 @@ if (MNN_KLEIDIAI)
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c
)
set_source_files_properties(${MNN_SOURCES_KLEIDIAI} PROPERTIES COMPILE_OPTIONS "-fno-tree-vectorize;-march=armv8.2-a+i8mm+dotprod+sve+sve2+fp16")
set_source_files_properties(${MNN_SOURCES_KLEIDIAI} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+i8mm+dotprod+sve+sve2+fp16)
set_source_files_properties(${KLEIDIAI_FILES_SME2} PROPERTIES COMPILE_OPTIONS "-fno-tree-vectorize;-march=armv8.2-a+sve+sve2")
endif()
@ -121,7 +121,13 @@ if(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv7" OR ARCHS MATCHES "^armv7(;armv7s)?")
endif()
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR ARCHS STREQUAL "arm64" OR ARCHS STREQUAL "ARM64")
message(STATUS "Enabling AArch64 Assemblies")
add_library(MNNARM64 OBJECT ${MNN_AArch64_SRC} ${MNN_NEON_SRC} ${MNN_SOURCES_KLEIDIAI} ${KLEIDIAI_FILES_SME2})
if (MNN_SME2)
add_definitions(-DMNN_SME2)
FILE(GLOB MNN_SME2_AArch64_SRC ${MNN_SME2_AArch64_SRC} ${CMAKE_CURRENT_LIST_DIR}/arm64/sme2_asm/*.[sS])
#set_source_files_properties(${MNN_SME2_AArch64_SRC} PROPERTIES COMPILE_OPTIONS "-fno-tree-vectorize;-march=armv8.6-a+sve+sve2+sme+sme2+fp16")
set_source_files_properties(${MNN_SME2_SRCS_ASM_FP16} PROPERTIES COMPILE_OPTIONS "-fno-tree-vectorize;-march=armv8.2-a+fp16")
endif()
add_library(MNNARM64 OBJECT ${MNN_AArch64_SRC} ${MNN_NEON_SRC} ${MNN_SOURCES_KLEIDIAI} ${KLEIDIAI_FILES_SME2} ${MNN_SME2_AArch64_SRC})
target_include_directories(MNNARM64 PRIVATE ${CMAKE_CURRENT_LIST_DIR}/)
list(APPEND MNN_OBJECTS_TO_LINK $<TARGET_OBJECTS:MNNARM64>)
list(APPEND MNN_TARGETS MNNARM64)

View File

@ -907,9 +907,10 @@ void MNNGetMatMulPackMode(int* eP, int *lP, int* hP) {
// input shape is (l, h) when transpose=false, else input shape is (h, l)
// output shape is (UP_DIV(h, 8), l, 8)
void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose) {
void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose) {
auto hP = (int)h / 8;
auto hR = (int)hP * 8;
auto l = kernelsize * ic;
if (hR != h) {
::memset(dest, 0, UP_DIV(h, 8)*8*l*sizeof(float));
}
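A hedged usage sketch of the new signature (the wrapper below is illustrative, not from the patch): for a dense matmul there is no spatial kernel, so kernelsize is 1 and the packed extent l == ic, which reproduces the old (h, l) behaviour.

#include <cstddef>
void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose);

// Sketch: pack a plain (non-convolution) weight matrix after the signature change.
void packWeightForMatMul(float* dest, const float* weight, size_t h, size_t ic, bool transpose) {
    const size_t kernelSize = 1;                                  // dense matmul: no kx*ky factor
    MNNPackForMatMul_B(dest, weight, h, kernelSize, ic, transpose);
}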
@ -952,7 +953,8 @@ void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bo
}
}
#else
void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose) {
void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose) {
auto l = kernelsize * ic;
if (!transpose) {
auto hP = h / 4;
auto hR = hP * 4;

View File

@ -28,7 +28,7 @@ void NEON_MNNPackC4ForMatMul_A_BF16(float* destOrigin, float const** sourceGroup
const int32_t* el);
void NEON_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t l, bool transpose);
void NEON_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose);
void NEON_MNNPackedMatMul_BF16(float* C, const float* A, const float* B, const size_t* parameter,
const float* postParameters, const float* bias, const float* k, const float* b);
@ -44,7 +44,7 @@ void MNNPackC4_BF16(float* dest, const float* source, size_t area, size_t depth,
#ifdef __aarch64__
void MNNPackC8_BF16(float* dest, const float* source, size_t l, size_t h);
void ARMV86_MNNGetMatMulPackMode_BF16(int* eP, int* lP, int* hP);
void ARMV86_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t l, bool transpose);
void ARMV86_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose);
void ARMV86_MNNPackC4ForMatMul_A_BF16(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el);
void ARMV86_MNNPackedMatMul_BF16(float* C, const float* A, const float* B, const size_t* parameter,
const float* postParameters, const float* bias, const float* k, const float* b);

View File

@ -90,8 +90,12 @@ sub \rg0, \rg0, \rg1
.endm
.macro REVERT_WEIGHT_KERNEL_SUM rg0, rg1, rg2, rg3
mul \rg1, \rg2, \rg3 // blocknum * ocUp4 * sizeof(float)
sub \rg0, \rg0, \rg1, LSL #4 // revert weight kernel sum
// y=UP_DIV(ocDiv4,(hp/pack))
add \rg1, \rg3, #1
lsr \rg1, \rg1, #1
// blockNum * y * (hp * sizeof(float))
mul \rg1, \rg2, \rg1
sub \rg0, \rg0, \rg1, LSL #5 // revert weight kernel sum
.endm
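A hedged C view of the byte count the reworked macro rewinds, with register meanings taken from its own comments (rg2 = blockNum, rg3 = ocDiv4) and hp = 8, pack = 4 assumed:

#include <cstddef>
// Sketch only: mirrors the add #1 / lsr #1 / mul / LSL #5 sequence above.
static inline size_t revertWeightKernelSumBytes(size_t blockNum, size_t ocDiv4) {
    size_t y = (ocDiv4 + 1) / 2;      // UP_DIV(ocDiv4, hp / pack)
    return (blockNum * y) << 5;       // blockNum * y * hp * sizeof(float)
}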
asm_function MNNGemmInt8AddBiasScale_ARMV82_Unit
/*
@ -435,9 +439,11 @@ L4_TILE12_BLOCKNUM:
L4_Tile12Quan:
ld1 {v0.4s}, [x2], #16 // scale
ld1 {v0.4s}, [x2] // scale
add x2, x2, #32
ld1 {v2.4s, v3.4s, v4.4s}, [x8], #48 // x kernel sum
ld1 {v5.4s}, [x2], #16 // weight quan zeropoint
ld1 {v5.4s}, [x2] // weight quan zeropoint
add x2, x2, #32
Int32ToFloat v8, v9, v10, v11
Int32ToFloat v12, v13, v14, v15
Int32ToFloat v16, v17, v18, v19
@ -470,7 +476,7 @@ L4_TILE12_BLOCKNUM:
cbz x27, L4_TILE12_ADD_DSTV
ld1 {v0.4s, v1.4s, v2.4s}, [x27], #48 // input dequant bias
ld1 {v3.4s}, [x28], #16 // weight kernel sum
ld1 {v3.4s}, [x28] // weight kernel sum
MLA_WEIGHTZERO v8, v0, v3, 0
MLA_WEIGHTZERO v9, v0, v3, 1
MLA_WEIGHTZERO v10, v0, v3, 2
@ -483,6 +489,7 @@ L4_TILE12_BLOCKNUM:
MLA_WEIGHTZERO v17, v2, v3, 1
MLA_WEIGHTZERO v18, v2, v3, 2
MLA_WEIGHTZERO v19, v2, v3, 3
add x28, x28, #32
L4_TILE12_ADD_DSTV:
cbz x19, L4_TILE12_ACCUM_BUFFER // x19=0: first block, do not add previous block result
@ -782,9 +789,11 @@ L4_TILE8_BLOCKNUM:
bne L4LoopSz_TILE_8
L4Tile8Quan:
ld1 {v0.4s}, [x12], #16 // scale
ld1 {v0.4s}, [x12] // scale
add x12, x12, #32
ld1 {v2.4s, v3.4s}, [x8], x22 // x kernel sum
ld1 {v24.4s}, [x12], #16 // weight quan zeropoint
ld1 {v24.4s}, [x12] // weight quan zeropoint
add x12, x12, #32
Int32ToFloat v8, v9, v10, v11
Int32ToFloat v12, v13, v14, v15
MUL_SCALE v0, v8, v9, v10, v11
@ -808,7 +817,7 @@ L4_TILE8_BLOCKNUM:
cbz x27, L4_TILE8_ADD_DSTV
ld1 {v2.4s, v3.4s}, [x27], x25
ld1 {v24.4s}, [x28], #16
ld1 {v24.4s}, [x28]
MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3
MLA_WEIGHTZERO v9, v2, v24, 1 // tile:1, oc:0-3
MLA_WEIGHTZERO v10, v2, v24, 2 // tile:2, oc:0-3
@ -817,6 +826,7 @@ L4_TILE8_BLOCKNUM:
MLA_WEIGHTZERO v13, v3, v24, 1 // tile:5, oc:0-3
MLA_WEIGHTZERO v14, v3, v24, 2 // tile:6, oc:0-3
MLA_WEIGHTZERO v15, v3, v24, 3 // tile:7, oc:0-3
add x28, x28, #32
L4_TILE8_ADD_DSTV:
cbz x19, TILE8_L4_ACCUM_BUFFER
@ -1055,9 +1065,11 @@ L4_TILE4_BLOCKNUM:
bne L4LoopSz_TILE_4
L4Tile4Quan:
ld1 {v0.4s}, [x12], #16 // scale
ld1 {v0.4s}, [x12] // scale
add x12, x12, #32
ld1 {v2.4s}, [x8], x22 // x kernel sum
ld1 {v24.4s}, [x12], #16 // weight quan zeropoint
ld1 {v24.4s}, [x12] // weight quan zeropoint
add x12, x12, #32
Int32ToFloat v8, v9, v10, v11
MUL_SCALE v0, v8, v9, v10, v11
@ -1074,11 +1086,12 @@ L4_TILE4_BLOCKNUM:
cbz x27, L4_TILE4_ADD_DSTV
ld1 {v2.4s}, [x27], x25
ld1 {v24.4s}, [x28], #16
ld1 {v24.4s}, [x28]
MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3
MLA_WEIGHTZERO v9, v2, v24, 1 // tile:1, oc:0-3
MLA_WEIGHTZERO v10, v2, v24, 2 // tile:2, oc:0-3
MLA_WEIGHTZERO v11, v2, v24, 3 // tile:3, oc:0-3
add x28, x28, #32
L4_TILE4_ADD_DSTV:
cbz x19, TILE4_L4_ACCUM_BUFFER
@ -1281,9 +1294,11 @@ L4_TILE1_BLOCKNUM:
bne L4LoopSz_TILE_1
L4Tile1Quan:
ld1 {v0.4s}, [x12], #16 // scale
ld1 {v0.4s}, [x12] // scale
add x12, x12, #32
ld1 {v2.s}[0], [x8], x22 // x kernel sum
ld1 {v24.4s}, [x12], #16 // weight quan zeropoint
ld1 {v24.4s}, [x12] // weight quan zeropoint
add x12, x12, #32
scvtf v8.4s, v8.4s
fmul v8.4s, v8.4s, v0.4s
@ -1297,8 +1312,9 @@ L4_TILE1_BLOCKNUM:
cbz x27, L4_TILE1_ADD_DSTV
ld1 {v2.s}[0], [x27], x25
ld1 {v24.4s}, [x28], #16
ld1 {v24.4s}, [x28]
MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3
add x28, x28, #32
L4_TILE1_ADD_DSTV:
cbz x19, L4_TILE1_L8_ACCUM_BUFFER

View File

@ -101,8 +101,12 @@ sub \rg0, \rg0, \rg1
.endm
.macro REVERT_WEIGHT_KERNEL_SUM rg0, rg1, rg2, rg3
mul \rg1, \rg2, \rg3 // blocknum * ocUp4 * sizeof(float)
sub \rg0, \rg0, \rg1, LSL #4 // revert weight kernel sum
// y=UP_DIV(ocDiv4,(hp/pack))
add \rg1, \rg3, #1
lsr \rg1, \rg1, #1
// blockNum * y * (hp * sizeof(float))
mul \rg1, \rg2, \rg1
sub \rg0, \rg0, \rg1, LSL #5 // revert weight kernel sum
.endm
asm_function MNNGemmInt8AddBiasScale_ARMV86_Unit
/*
@ -500,10 +504,12 @@ L4_LoopSz_TILE_10:
scvtf v9.4s, v9.4s
L4_Tile10Quan:
ld1 {v20.4s}, [x2], #16 // weight scale
ld1 {v20.4s}, [x2] // weight scale
add x2, x2, #32
ld1 {v22.4s, v23.4s}, [x8], #32 // x kernel sum
ld1 {v24.d}[0], [x8], #8
ld1 {v25.4s}, [x2], #16 // weight quan zeropoint
ld1 {v25.4s}, [x2] // weight quan zeropoint
add x2, x2, #32
MUL_SCALE v20, v0, v1, v2, v3
MUL_SCALE v20, v4, v5, v6, v7
fmul v8.4s, v8.4s, v20.4s
@ -534,7 +540,7 @@ L4_Tile10Quan:
cbz x27, L4_TILE10_ADD_DSTV
ld1 {v22.4s, v23.4s}, [x27], #32 // input dequant bias
ld1 {v24.2s}, [x27], #8
ld1 {v25.4s}, [x28], #16 // weight kernel sum
ld1 {v25.4s}, [x28] // weight kernel sum
MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3
MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3
MLA_WEIGHTZERO v2, v22, v25, 2 // tile:2, oc:0-3
@ -545,6 +551,7 @@ L4_Tile10Quan:
MLA_WEIGHTZERO v7, v23, v25, 3 // tile:7, oc:0-3
MLA_WEIGHTZERO v8, v24, v25, 0 // tile:8, oc:0-3
MLA_WEIGHTZERO v9, v24, v25, 1 // tile:9, oc:0-3
add x28, x28, #32
L4_TILE10_ADD_DSTV:
cbz x19, L4_TILE10_TEMP_BUFFER
@ -902,9 +909,11 @@ LoopSz4End_TILE_8:
Int32ToFloat v4, v5, v6, v7
L4_Tile8Quan:
ld1 {v20.4s}, [x12], #16 // scale
ld1 {v20.4s}, [x12] // scale
add x12, x12, #32
ld1 {v22.4s, v23.4s}, [x8] // x kernel sum
ld1 {v25.4s}, [x12], #16 // weight quan zeropoint
ld1 {v25.4s}, [x12] // weight quan zeropoint
add x12, x12, #32
add x8, x8, x22, LSR #1
MUL_SCALE v20, v0, v1, v2, v3
MUL_SCALE v20, v4, v5, v6, v7
@ -928,7 +937,7 @@ L4_Tile8Quan:
cbz x27, L4_TILE8_ADD_DSTV
ld1 {v22.4s, v23.4s}, [x27], x25 // input dequant bias
ld1 {v25.4s}, [x28], #16 // weight kernel sum
ld1 {v25.4s}, [x28] // weight kernel sum
MLA_WEIGHTZERO v0, v22, v25, 0
MLA_WEIGHTZERO v1, v22, v25, 1
MLA_WEIGHTZERO v2, v22, v25, 2
@ -937,6 +946,7 @@ L4_Tile8Quan:
MLA_WEIGHTZERO v5, v23, v25, 1
MLA_WEIGHTZERO v6, v23, v25, 2
MLA_WEIGHTZERO v7, v23, v25, 3
add x28, x28, #32
L4_TILE8_ADD_DSTV:
cbz x19, L4_TILE8_TEMP_BUFFER
@ -1182,9 +1192,11 @@ L4_LoopSz_TILE_4:
Int32ToFloat v0, v1, v2, v3
L4_Tile4Quan:
ld1 {v20.4s}, [x12], #16 // scale
ld1 {v20.4s}, [x12] // scale
add x12, x12, #32
ld1 {v22.4s}, [x8] // x kernel sum
ld1 {v25.4s}, [x12], #16 // weight quan zeropoint
ld1 {v25.4s}, [x12] // weight quan zeropoint
add x12, x12, #32
MUL_SCALE v20, v0, v1, v2, v3
add x8, x8, x22, LSR #1
@ -1202,11 +1214,12 @@ L4_Tile4Quan:
cbz x27, L4_TILE4_ADD_DSTV
ld1 {v22.4s}, [x27], x25 // input dequant bias
ld1 {v25.4s}, [x28], #16 // weight kernel sum
ld1 {v25.4s}, [x28] // weight kernel sum
MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3
MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3
MLA_WEIGHTZERO v2, v22, v25, 2 // tile:2, oc:0-3
MLA_WEIGHTZERO v3, v22, v25, 3 // tile:3, oc:0-3
add x28, x28, #32
L4_TILE4_ADD_DSTV:
cbz x19, L4_TILE4_ACCUM_BUFFER
@ -1414,9 +1427,11 @@ L4_LoopSz_TILE_2:
scvtf v1.4s, v1.4s
L4_Tile2Quan:
ld1 {v20.4s}, [x12], #16 // scale
ld1 {v20.4s}, [x12] // scale
add x12, x12, #32
ld1 {v22.d}[0], [x8] // x kernel sum
ld1 {v25.4s}, [x12], #16 // weight quan zeropoint
ld1 {v25.4s}, [x12] // weight quan zeropoint
add x12, x12, #32
fmul v0.4s, v0.4s, v20.4s
fmul v1.4s, v1.4s, v20.4s
add x8, x8, x22, LSR #1
@ -1434,9 +1449,10 @@ L4_Tile2Quan:
cbz x27, L4_TILE2_ADD_DSTV
ld1 {v22.2s}, [x27], x25 // input dequant bias
ld1 {v25.4s}, [x28], #16 // weight kernel sum
ld1 {v25.4s}, [x28] // weight kernel sum
MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3
MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3
MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3
add x28, x28, #32
L4_TILE2_ADD_DSTV:
cbz x19, L4_TILE2_ACCUM_BUFFER
@ -1636,9 +1652,11 @@ L4_LoopSzEnd_TILE_1:
scvtf v25.4s, v25.4s
L4_Tile1Quan:
ld1 {v0.4s}, [x12], #16 // scale
ld1 {v0.4s}, [x12] // scale
add x12, x12, #32
ld1 {v6.s}[0], [x8] // x kernel sum
ld1 {v8.4s}, [x12], #16 // weight quan zeropoint
ld1 {v8.4s}, [x12] // weight quan zeropoint
add x12, x12, #32
fmul v25.4s, v25.4s, v0.4s
add x8, x8, x22, LSR #1
cbz x21, L4_TILE1_MLA
@ -1652,8 +1670,9 @@ L4_Tile1Quan:
cbz x27, L4_TILE1_ADD_DSTV
ld1 {v6.s}[0], [x27], x25 // input dequant bias
ld1 {v8.4s}, [x28], #16 // weight kernel sum
ld1 {v8.4s}, [x28] // weight kernel sum
MLA_WEIGHTZERO v25, v6, v8, 0 // tile:0, oc:0-3
add x28, x28, #32
L4_TILE1_ADD_DSTV:
cbz x19, L4_TILE1_ACCUM_BUFFER

View File

@ -0,0 +1,337 @@
#ifdef __aarch64__
#include "MNNAsmGlobal.h"
.text
.align 5
.macro SET_0 s0, s1, s2, s3
movi \s0\().4s, #0
movi \s1\().4s, #0
movi \s2\().4s, #0
movi \s3\().4s, #0
.endm
/*
struct SumByAxisParams {
ssize_t kernelCountUnitDouble;
ssize_t col_buffer_unit_size;
ssize_t DST_XUNIT;
ssize_t SRC_UNIT;
ssize_t blockNum;
ssize_t oneScale;
};
*/
asm_function MNNSumByAxisLForMatmul_A_SME2
// MNNSumByAxisLForMatmul_A_SME2(float_t* dest, int8_t* source, float* dequantScale, ssize_t realDstCount,
// ssize_t kernelCountUnitDouble, ssize_t col_buffer_unit_size, ssize_t EP, ssize_t LP, ssize_t blockNum, ssize_t oneScale);
// x0: dest, x1: source, x2: dequantScale, x3: realDstCount, x4: sumParams
// x5: oneScale
// Load from sp: x8: blockNum
// EP=16, LP=4, HP=16
ldr x12, [x4, #48] // Valid
ldr x8, [x4, #32] // blockNum
ldr x5, [x4, #40] // oneScale
ldr x14, [x4, #56] // kx*ky
ldr x15, [x4, #72] // input block quant, 0:no, 1:yes
ldr x4, [x4, #64] // LU
stp d14, d15, [sp, #(-16 * 5)]!
stp d12, d13, [sp, #(16 * 1)]
stp d10, d11, [sp, #(16 * 2)]
stp d8, d9, [sp, #(16 * 3)]
stp x20, x21, [sp, #(16 * 4)]
movi v31.16b, #1
mov v29.16b, v31.16b
ld1r {v30.4s}, [x2] // Dequant scale
sdiv x4, x4, x8 // src_depth_quad per block
cbz x12, Start
mov x13, #0xFFFFFFFF
lsl x12, x12, #3
lsl x13, x13, x12
dup v28.4s, w13
bic v29.16b, v31.16b, v28.16b
Start:
mov x13, x15 // input block quant, 0:no, 1:yes
TILE_16:
cmp x3, #16
blt Remain
mov x9, x8 // blockNum
cbnz x13, TILE16_BLOCK_NUM
ld1 {v13.4s, v14.4s, v15.4s, v16.4s}, [x2], #64 // batch quant scale
TILE16_BLOCK_NUM:
mov x15, x14 // kx*ky
movi v9.4s, #0
movi v10.4s, #0
movi v11.4s, #0
movi v12.4s, #0
/* for range(kx*ky)...for range(ic/pack) */
TILE16_BLOCK_INNER:
sub x12, x4, #1 // icDiv4
cbz x12, TILE16_LAST_QUAD
TILE16_PRE_QUAD:
ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x1], #64 // E: 0,1,2,3,...,15
.inst 0x4e8097e9 // sdot v9.4s, v31.16b, v0.16b
.inst 0x4e8197ea // sdot v10.4s, v31.16b, v1.16b
.inst 0x4e8297eb // sdot v11.4s, v31.16b, v2.16b
.inst 0x4e8397ec // sdot v12.4s, v31.16b, v3.16b
subs x12, x12, #1 // icDiv4--
bne TILE16_PRE_QUAD
TILE16_LAST_QUAD:
ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x1], #64 // E: 0,1,2,3,...,11
.inst 0x4e8097a9 // sdot v9.4s, v29.16b, v0.16b
.inst 0x4e8197aa // sdot v10.4s, v29.16b, v1.16b
.inst 0x4e8297ab // sdot v11.4s, v29.16b, v2.16b
.inst 0x4e8397ac // sdot v12.4s, v29.16b, v3.16b
subs x15, x15, #1
bne TILE16_BLOCK_INNER
TILE16_BLOCK_INNER_END:
subs x9, x9, #1 // blockNum--
scvtf v9.4s, v9.4s
scvtf v10.4s, v10.4s
scvtf v11.4s, v11.4s
scvtf v12.4s, v12.4s
cbnz x5, TILE16_MUL_ONE_SCALE
cbz x13, TILE16_MUL_BLOCK_SCALE
ld1 {v13.4s, v14.4s, v15.4s, v16.4s}, [x2], #64 // batch quant scale, input block quant
TILE16_MUL_BLOCK_SCALE:
fmul v9.4s, v9.4s, v13.4s
fmul v10.4s, v10.4s, v14.4s
fmul v11.4s, v11.4s, v15.4s
fmul v12.4s, v12.4s, v16.4s
b TILE16_STORE
TILE16_MUL_ONE_SCALE:
fmul v9.4s, v9.4s, v30.4s
fmul v10.4s, v10.4s, v30.4s
fmul v11.4s, v11.4s, v30.4s
fmul v12.4s, v12.4s, v30.4s
TILE16_STORE:
st1 {v9.4s, v10.4s, v11.4s, v12.4s}, [x0], #64
bne TILE16_BLOCK_NUM
TILE16_END:
subs x3, x3, #16 // realDstCount-=16
bne TILE_16
Remain: // remain realDstCount < EP
cbz x3, End
/* x11: Remain dstCount step for each block */
lsl x11, x3, #2
lsl x6, x3, #2 // x6=eDest * LP
mov x20, x2
TILE_12:
cmp x3, #12
blt TILE_2
mov x7, x1
mov x9, x8 // blockNum
mov x10, x0 // tag dst address
mov x9, x8 // blockNum
cbnz x13, TILE12_BLOCK_NUM
ld1 {v13.4s, v14.4s, v15.4s}, [x2], #48 // batch quant scale
TILE12_BLOCK_NUM:
mov x15, x14 // kx*ky
movi v10.4s, #0
movi v11.4s, #0
movi v12.4s, #0
/* for range(kx*ky)...for range(ic/pack) */
TILE12_BLOCK_INNER:
sub x12, x4, #1 // icDiv4
cbz x12, TILE12_LAST_QUAD
TILE12_PRE_QUAD:
ld1 {v0.16b, v1.16b, v2.16b}, [x7], x6 // E: 0,1,2,3,...,11
.inst 0x4e8097ea // sdot v10.4s, v31.16b, v0.16b // sum LP axis for E0, E1, E2, E3
.inst 0x4e8197eb // sdot v11.4s, v31.16b, v1.16b
.inst 0x4e8297ec // sdot v12.4s, v31.16b, v2.16b
subs x12, x12, #1 // icDiv4--
bne TILE12_PRE_QUAD
TILE12_LAST_QUAD:
ld1 {v0.16b, v1.16b, v2.16b}, [x7], x6 // E: 0,1,2,3,...,11
.inst 0x4e8097aa // sdot v10.4s, v29.16b, v0.16b // sum LP axis for E0, E1, E2, E3
.inst 0x4e8197ab // sdot v11.4s, v29.16b, v1.16b
.inst 0x4e8297ac // sdot v12.4s, v29.16b, v2.16b
subs x15, x15, #1
bne TILE12_BLOCK_INNER
TILE12_BLOCK_INNER_END:
subs x9, x9, #1 // blockNum--
scvtf v10.4s, v10.4s
scvtf v11.4s, v11.4s
scvtf v12.4s, v12.4s
cbnz x5, TILE12_MUL_ONE_SCALE
cbz x13, TILE12_MUL_BLOCK_SCALE
ld1 {v13.4s, v14.4s, v15.4s}, [x2], x6 // batch quant scale, input block quant
TILE12_MUL_BLOCK_SCALE:
fmul v10.4s, v10.4s, v13.4s
fmul v11.4s, v11.4s, v14.4s
fmul v12.4s, v12.4s, v15.4s
b TILE12_STORE
TILE12_MUL_ONE_SCALE:
fmul v10.4s, v10.4s, v30.4s
fmul v11.4s, v11.4s, v30.4s
fmul v12.4s, v12.4s, v30.4s
TILE12_STORE:
st1 {v10.4s, v11.4s, v12.4s}, [x10], x11
bne TILE12_BLOCK_NUM
TILE12_END:
subs x3, x3, #12 // realDstCount-=12
add x1, x1, #48 // LP * 12 * sizeof(int8_t)
add x0, x0, #48 // finish 12*sizeof(float)
add x2, x20, #48 // x20 + 12 * sizeof(float)
mov x20, x2
TILE_2: // realDstCount >= 1
cmp x3, #2
blt TILE_1
mov x7, x1
mov x9, x8 // blockNum
mov x10, x0 // tag dst address
cbnz x13, TILE2_BLOCK_NUM
ld1 {v13.d}[0], [x2], #8 // batch quant scale
TILE2_BLOCK_NUM:
mov x15, x14 // kx*ky
movi v10.4s, #0
TILE2_BLOCK_INNER: // range(kxky)
sub x12, x4, #1 // icDiv4
cbz x12, TILE2_LAST_QUAD
TILE2_PRE_QUAD: // range(icDiv4)
ld1 {v0.d}[0], [x7], x6 // E: 0,1
subs x12, x12, #1
.inst 0x4e8097ea // sdot v10.4s, v31.16b, v0.16b // sum LP axis for E0
bne TILE2_PRE_QUAD
TILE2_LAST_QUAD:
ld1 {v0.d}[0], [x7], x6 // E: 0,1
.inst 0x4e8097aa // sdot v10.4s, v29.16b, v0.16b
subs x15, x15, #1 // kxky--
bne TILE2_BLOCK_INNER
TILE2_BLOCK_INNER_END:
scvtf v10.4s, v10.4s
cbnz x5, TILE2_MUL_ONE_SCALE
cbz x13, TILE2_MUL_BLOCK_SCALE
ld1 {v13.d}[0], [x2], x6 // batch quant scale
TILE2_MUL_BLOCK_SCALE:
fmul v10.4s, v10.4s, v13.4s
b TILE2_STORE
TILE2_MUL_ONE_SCALE:
fmul v10.4s, v10.4s, v30.4s
TILE2_STORE:
subs x9, x9, #1 // blockNum--
st1 {v10.d}[0], [x10], x11
bne TILE2_BLOCK_NUM
TILE2_END:
sub x3, x3, #2 // realDstCount-=2
add x1, x1, #8 // LP * 2
add x0, x0, #8 // finish remain 2
add x2, x20, #8 // x20 + 2 * sizeof(float)
mov x20, x2
b TILE_2
TILE_1: // realDstCount >= 1
cmp x3, #1
blt End
mov x7, x1
mov x9, x8 // blockNum
mov x10, x0
cbnz x13, TILE1_BLOCK_NUM
ld1 {v13.s}[0], [x2], #4 // batch quant scale
TILE1_BLOCK_NUM:
mov x15, x14 // kx*ky
movi v10.4s, #0
TILE1_BLOCK_INNER:
sub x12, x4, #1
cbz x12, TILE1_LAST_QUAD
TILE1_PRE_QUAD:
ld1 {v0.s}[0], [x7] // E: 0
add x7, x7, x6
.inst 0x4e8097ea // sdot v10.4s, v31.16b, v0.16b // sum LP axis for E0
subs x12, x12, #1 // icDiv4--
bne TILE1_PRE_QUAD
TILE1_LAST_QUAD:
ld1 {v0.s}[0], [x7], x6 // E: 0
.inst 0x4e8097aa // sdot v10.4s, v29.16b, v0.16b
subs x15, x15, #1 // kxky--
bne TILE1_BLOCK_INNER
TILE1_BLOCK_INNER_END:
scvtf v10.4s, v10.4s
cbnz x5, TILE1_MUL_ONE_SCALE
cbz x13, TILE1_MUL_BLOCK_SCALE
ld1 {v13.s}[0], [x2], x6 // batch quant scale
TILE1_MUL_BLOCK_SCALE:
fmul v10.4s, v10.4s, v13.4s
b TILE1_STORE
TILE1_MUL_ONE_SCALE:
fmul v10.4s, v10.4s, v30.4s
TILE1_STORE:
subs x9, x9, #1 // blockNum--
st1 {v10.s}[0], [x10], x11
bne TILE1_BLOCK_NUM
TILE1_END:
sub x3, x3, #1 // realDstCount-=1
add x1, x1, #4 // LP * 1
add x0, x0, #4 // finish remain 1
add x2, x20, #4 // x20 + 1 * sizeof(float)
mov x20, x2
b TILE_1
End:
ldp x20, x21, [sp, #(16 * 4)]
ldp d8, d9, [sp, #(16 * 3)]
ldp d10, d11, [sp, #(16 * 2)]
ldp d12, d13, [sp, #(16 * 1)]
ldp d14, d15, [sp], #(16 * 5)
ret
#endif
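A heavily hedged scalar sketch of one 16-wide tile of this kernel in the oneScale path with no partial-lane mask; the source is assumed to be packed [block][kxKy][luPerBlock][16][LP], which is what the contiguous 64-byte loads above suggest.

#include <cstdint>
#include <cstddef>
// Sketch: per destination column, sum the int8 inputs along the packed L axis for
// each block, then scale by the single dequant scale; dest holds 16 floats per block.
static void sumByAxisLTile16Ref(float* dest, const int8_t* source, float dequantScale,
                                int blockNum, int kxKy, int luPerBlock) {
    const int EP = 16, LP = 4;
    for (int b = 0; b < blockNum; ++b) {
        for (int e = 0; e < EP; ++e) {
            int32_t acc = 0;
            for (int k = 0; k < kxKy; ++k) {
                for (int lu = 0; lu < luPerBlock; ++lu) {
                    const int8_t* p = source + ((((size_t)b * kxKy + k) * luPerBlock + lu) * EP + e) * LP;
                    for (int l = 0; l < LP; ++l) acc += p[l];
                }
            }
            dest[(size_t)b * EP + e] = acc * dequantScale;   // stored per block, 16 floats at a time
        }
    }
}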

View File

@ -0,0 +1,151 @@
#ifdef __aarch64__
#include "MNNAsmGlobal.h"
.text
.align 5
.macro SET_0 s0, s1, s2, s3
movi \s0\().4s, #0
movi \s1\().4s, #0
movi \s2\().4s, #0
movi \s3\().4s, #0
.endm
.macro Int32_To_Float32 s0, s1, s2, s3
scvtf \s0\().4s, \s0\().4s
scvtf \s1\().4s, \s1\().4s
scvtf \s2\().4s, \s2\().4s
scvtf \s3\().4s, \s3\().4s
.endm
asm_function MNNPermuteSumWeightInt4Sme2
// void MNNPermuteSumWeightInt4Sme2(uint8_t* dest, uint8_t* source, size_t outside, size_t inside, float* kernelSum);
// auto load: x0: dest, x1: source, x2: outside, x3: inside, x4: kernelSum
// inside = lu
// outside = blocknum*hu
// kernelSum shape: [hu, blockNum, hp]
stp d14, d15, [sp, #-64]!
stp d12, d13, [sp, #16]
stp d10, d11, [sp, #32]
stp d8, d9, [sp, #48]
movi v31.16b, #15
movi v30.16b, #4
movi v29.16b, #1
Loop: // blocknum*hu
mov x6, x3 // lu
SET_0 v4, v5, v6, v7
SET_0 v20, v21, v22, v23
cmp x6, #2
blt LoopLU
LoopLU2:
ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x1], #64
ushr v8.16b, v0.16b, #4 // v8: 0 2 ... 30
and v9.16b, v0.16b, v31.16b // v9: 1 3 ... 31
ushr v10.16b, v1.16b, #4
and v11.16b, v1.16b, v31.16b
ushr v27.16b, v2.16b, #4
and v28.16b, v2.16b, v31.16b
ushr v24.16b, v3.16b, #4
and v25.16b, v3.16b, v31.16b
zip1 v12.16b, v8.16b, v9.16b // v12: 0 1 2 3 ... 14 15
zip2 v13.16b, v8.16b, v9.16b // v13: 16 17 18 19 ... 30 31
zip1 v14.16b, v10.16b, v11.16b // v14: 32 33 34 35 ... 46 47
zip2 v15.16b, v10.16b, v11.16b // v15: 48 49 50 51 ... 62 63
zip1 v16.16b, v27.16b, v28.16b
zip2 v17.16b, v27.16b, v28.16b
zip1 v18.16b, v24.16b, v25.16b
zip2 v19.16b, v24.16b, v25.16b
// weight kernel sum
.inst 0x6e8c97a4 // udot v4.4s, v29.16b, v12.16b
.inst 0x6e8d97a5 // udot v5.4s, v29.16b, v13.16b
.inst 0x6e8e97a6 // udot v6.4s, v29.16b, v14.16b
.inst 0x6e8f97a7 // udot v7.4s, v29.16b, v15.16b
.inst 0x6e9097b4 // udot v20.4s, v29.16b, v16.16b
.inst 0x6e9197b5 // udot v21.4s, v29.16b, v17.16b
.inst 0x6e9297b6 // udot v22.4s, v29.16b, v18.16b
.inst 0x6e9397b7 // udot v23.4s, v29.16b, v19.16b
sub x6, x6, #2
// transpose
ushl v9.16b, v9.16b, v30.16b
ushl v11.16b, v11.16b, v30.16b
ushl v28.16b, v28.16b, v30.16b
ushl v25.16b, v25.16b, v30.16b
orr v0.16b, v8.16b, v9.16b
orr v1.16b, v10.16b, v11.16b
orr v2.16b, v27.16b, v28.16b
orr v3.16b, v24.16b, v25.16b
st1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x0], #64
cmp x6, #2
bge LoopLU2
cbz x6, LUEnd
LoopLU:
cbz x6, LUEnd
ld1 {v0.16b, v1.16b}, [x1], #32
ushr v8.16b, v0.16b, #4 // v8: 0 2 ... 30
and v9.16b, v0.16b, v31.16b // v9: 1 3 ... 31
ushr v10.16b, v1.16b, #4
and v11.16b, v1.16b, v31.16b
zip1 v12.16b, v8.16b, v9.16b // v12: 0 1 2 3 ... 14 15
zip2 v13.16b, v8.16b, v9.16b // v13: 16 17 18 19 ... 30 31
zip1 v14.16b, v10.16b, v11.16b // v14: 32 33 34 35 ... 46 47
zip2 v15.16b, v10.16b, v11.16b // v15: 48 49 50 51 ... 62 63
// weight kernel sum
.inst 0x6e8c97a4 // udot v4.4s, v29.16b, v12.16b
.inst 0x6e8d97a5 // udot v5.4s, v29.16b, v13.16b
.inst 0x6e8e97a6 // udot v6.4s, v29.16b, v14.16b
.inst 0x6e8f97a7 // udot v7.4s, v29.16b, v15.16b
// <<4
ushl v9.16b, v9.16b, v30.16b
ushl v11.16b, v11.16b, v30.16b
orr v0.16b, v8.16b, v9.16b
orr v1.16b, v10.16b, v11.16b
st1 {v0.16b, v1.16b}, [x0], #32
LUEnd:
add v4.4s, v4.4s, v20.4s
add v5.4s, v5.4s, v21.4s
add v6.4s, v6.4s, v22.4s
add v7.4s, v7.4s, v23.4s
scvtf v4.4s, v4.4s
scvtf v5.4s, v5.4s
scvtf v6.4s, v6.4s
scvtf v7.4s, v7.4s
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x4], #64
subs x2, x2, #1 // outside--
bne Loop
End:
ldp d8, d9, [sp, #48]
ldp d10, d11, [sp, #32]
ldp d12, d13, [sp, #16]
ldp d14, d15, [sp], #64
ret
#endif
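A hedged scalar sketch of what this kernel does per outside step, assuming hp = 16 and lp = 4 as in the loads above: each source byte holds two unsigned 4-bit weights (even element in the high nibble, odd element in the low nibble), the output byte swaps the nibbles, and every weight is accumulated into its output channel's kernel sum.

#include <cstdint>
#include <cstddef>
static void permuteSumWeightInt4Ref(uint8_t* dest, const uint8_t* source,
                                    size_t outside, size_t inside, float* kernelSum) {
    const int hp = 16, lp = 4, bytesPerStep = hp * lp / 2;       // 32 bytes per (outside, inside) step
    for (size_t o = 0; o < outside; ++o) {
        float sum[16] = {0.f};
        for (size_t i = 0; i < inside; ++i) {
            const uint8_t* src = source + (o * inside + i) * bytesPerStep;
            uint8_t* dst = dest + (o * inside + i) * bytesPerStep;
            for (int j = 0; j < bytesPerStep; ++j) {
                uint8_t even = src[j] >> 4, odd = src[j] & 0x0F;
                dst[j] = (uint8_t)((odd << 4) | even);           // nibble swap, as the ushl/orr above
                sum[(2 * j) / lp]     += even;                   // elements 2j and 2j+1 share a channel
                sum[(2 * j + 1) / lp] += odd;
            }
        }
        for (int h = 0; h < hp; ++h) kernelSum[o * hp + h] = sum[h];
    }
}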

View File

@ -0,0 +1,127 @@
#ifdef __aarch64__
#include "MNNAsmGlobal.h"
.text
.align 5
.macro SET_0 s0, s1, s2, s3
movi \s0\().4s, #0
movi \s1\().4s, #0
movi \s2\().4s, #0
movi \s3\().4s, #0
.endm
.macro Int32_To_Float32 s0, s1, s2, s3
scvtf \s0\().4s, \s0\().4s
scvtf \s1\().4s, \s1\().4s
scvtf \s2\().4s, \s2\().4s
scvtf \s3\().4s, \s3\().4s
.endm
asm_function MNNSumWeightInt8Sme2
// void MNNSumWeightInt8Sme2(float* kernlesum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP)
// auto load: x0: dest, x1: source, x2: outside, x3: reduceAxis, x4: hP, x5: lP
// weight shape: [outside, reduceAxis, hP, lP]
// outside = blocknum * hU
// reduceAxis = kernelCount * lU
stp d14, d15, [sp, #-64]!
stp d12, d13, [sp, #16]
stp d10, d11, [sp, #32]
stp d8, d9, [sp, #48]
movi v31.16b, #1
Loop:
mov x5, x3
SET_0 v16, v17, v18, v19
SET_0 v20, v21, v22, v23
SET_0 v24, v25, v26, v27
LU3:
cmp x5, #3
blt LU2
ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x1], #64
ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [x1], #64
ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [x1], #64
// kernel sum
.inst 0x4e8097f0 // sdot v16.4s, v31.16b, v0.16b
.inst 0x4e8197f1 // sdot v17.4s, v31.16b, v1.16b
.inst 0x4e8297f2 // sdot v18.4s, v31.16b, v2.16b
.inst 0x4e8397f3 // sdot v19.4s, v31.16b, v3.16b
.inst 0x4e8497f4 // sdot v20.4s, v31.16b, v4.16b
.inst 0x4e8597f5 // sdot v21.4s, v31.16b, v5.16b
.inst 0x4e8697f6 // sdot v22.4s, v31.16b, v6.16b
.inst 0x4e8797f7 // sdot v23.4s, v31.16b, v7.16b
.inst 0x4e8897f8 // sdot v24.4s, v31.16b, v8.16b
.inst 0x4e8997f9 // sdot v25.4s, v31.16b, v9.16b
.inst 0x4e8a97fa // sdot v26.4s, v31.16b, v10.16b
.inst 0x4e8b97fb // sdot v27.4s, v31.16b, v11.16b
subs x5, x5, #3
beq LUEnd
b LU3
LU2:
cmp x5, #2
blt LU1
ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x1], #64
ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [x1], #64
// kernel sum
.inst 0x4e8097f0 // sdot v16.4s, v31.16b, v0.16b
.inst 0x4e8197f1 // sdot v17.4s, v31.16b, v1.16b
.inst 0x4e8297f2 // sdot v18.4s, v31.16b, v2.16b
.inst 0x4e8397f3 // sdot v19.4s, v31.16b, v3.16b
.inst 0x4e8497f4 // sdot v20.4s, v31.16b, v4.16b
.inst 0x4e8597f5 // sdot v21.4s, v31.16b, v5.16b
.inst 0x4e8697f6 // sdot v22.4s, v31.16b, v6.16b
.inst 0x4e8797f7 // sdot v23.4s, v31.16b, v7.16b
subs x5, x5, #2
beq LUEnd
b LU2
LU1: // outside
cbz x5, LUEnd
ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x1], #64
// kernel sum
.inst 0x4e8097f0 // sdot v16.4s, v31.16b, v0.16b
.inst 0x4e8197f1 // sdot v17.4s, v31.16b, v1.16b
.inst 0x4e8297f2 // sdot v18.4s, v31.16b, v2.16b
.inst 0x4e8397f3 // sdot v19.4s, v31.16b, v3.16b
LUEnd:
add v16.4s, v16.4s, v20.4s
add v17.4s, v17.4s, v21.4s
add v18.4s, v18.4s, v22.4s
add v19.4s, v19.4s, v23.4s
add v16.4s, v16.4s, v24.4s
add v17.4s, v17.4s, v25.4s
add v18.4s, v18.4s, v26.4s
add v19.4s, v19.4s, v27.4s
scvtf v16.4s, v16.4s
scvtf v17.4s, v17.4s
scvtf v18.4s, v18.4s
scvtf v19.4s, v19.4s
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64
subs x2, x2, #1 // outside--
bne Loop
End:
ldp d8, d9, [sp, #48]
ldp d10, d11, [sp, #32]
ldp d12, d13, [sp, #16]
ldp d14, d15, [sp], #64
ret
#endif
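A hedged scalar reference for this kernel, following the layout stated in its comments (weight shape [outside, reduceAxis, hP, lP], kernelSum shape [outside, hP]):

#include <cstdint>
#include <cstddef>
// Sketch: sum the int8 weights over reduceAxis and lP for every output channel.
static void sumWeightInt8Ref(float* kernelSum, const int8_t* source,
                             size_t outside, size_t reduceAxis, size_t hP, size_t lP) {
    for (size_t o = 0; o < outside; ++o) {
        for (size_t h = 0; h < hP; ++h) {
            int32_t acc = 0;
            for (size_t r = 0; r < reduceAxis; ++r) {
                const int8_t* p = source + ((o * reduceAxis + r) * hP + h) * lP;
                for (size_t l = 0; l < lP; ++l) acc += p[l];
            }
            kernelSum[o * hP + h] = (float)acc;
        }
    }
}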

View File

@ -448,6 +448,239 @@ static bool MNNAsyLocalQuantInfo_EP12_FP32(float* scale, float* bias, float* qsc
}
return true;
}
static bool MNNAsyLocalQuantInfo_EP16_FP32(float* scale, float* bias, float* qscale, float* qbias, const float* srcMin, const float* srcMax, const size_t* info) {
auto blockNum = info[0];
auto EP = info[1];
auto DST_XUNIT = info[3];
if (DST_XUNIT != 16) {
MNN_ERROR("Function called error\n");
return false;
}
auto stride = EP * blockNum;
// dequant scale/bias : [EU, blockNum, step]
// quant scale/bias: [blockNum, EP]
auto minfloat = vdupq_n_f32(1e-6);
auto _255f = vdupq_n_f32(255.f);
auto _128f = vdupq_n_f32(128.f);
auto _0f = vdupq_n_f32(0.f);
for (int k = 0; k < blockNum; ++k) {
auto qind = k * EP;
auto realDstCount = EP;
auto scalePtr = scale + k * ALIMIN(EP, DST_XUNIT);
auto biasPtr = bias + k * ALIMIN(EP, DST_XUNIT);
while (realDstCount > DST_XUNIT - 1) {
auto max0 = vld1q_f32(srcMax + qind);
auto max1 = vld1q_f32(srcMax + qind + 4);
auto max2 = vld1q_f32(srcMax + qind + 8);
auto max3 = vld1q_f32(srcMax + qind + 12);
auto min0 = vld1q_f32(srcMin + qind);
auto min1 = vld1q_f32(srcMin + qind + 4);
auto min2 = vld1q_f32(srcMin + qind + 8);
auto min3 = vld1q_f32(srcMin + qind + 12);
auto diff0 = vsubq_f32(max0, min0);
auto diff1 = vsubq_f32(max1, min1);
auto diff2 = vsubq_f32(max2, min2);
auto diff3 = vsubq_f32(max3, min3);
auto qscaleV0 = vdivq_f32(_255f, diff0);
auto qscaleV1 = vdivq_f32(_255f, diff1);
auto qscaleV2 = vdivq_f32(_255f, diff2);
auto qscaleV3 = vdivq_f32(_255f, diff3);
auto scaleV0 = vdivq_f32(diff0, _255f);
auto scaleV1 = vdivq_f32(diff1, _255f);
auto scaleV2 = vdivq_f32(diff2, _255f);
auto scaleV3 = vdivq_f32(diff3, _255f);
auto qbiasV0 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min0), diff0), _128f));
auto qbiasV1 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min1), diff1), _128f));
auto qbiasV2 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min2), diff2), _128f));
auto qbiasV3 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min3), diff3), _128f));
auto biasV0 = vaddq_f32(vdivq_f32(vmulq_f32(diff0, _128f), _255f), min0);
auto biasV1 = vaddq_f32(vdivq_f32(vmulq_f32(diff1, _128f), _255f), min1);
auto biasV2 = vaddq_f32(vdivq_f32(vmulq_f32(diff2, _128f), _255f), min2);
auto biasV3 = vaddq_f32(vdivq_f32(vmulq_f32(diff3, _128f), _255f), min3);
auto _0bic = vclezq_f32(diff0);
auto _1bic = vclezq_f32(diff1);
auto _2bic = vclezq_f32(diff2);
auto _3bic = vclezq_f32(diff3);
qscaleV0 = vbslq_f32(_0bic, _0f, qscaleV0);
qscaleV1 = vbslq_f32(_1bic, _0f, qscaleV1);
qscaleV2 = vbslq_f32(_2bic, _0f, qscaleV2);
qscaleV3 = vbslq_f32(_3bic, _0f, qscaleV3);
qbiasV0 = vrndaq_f32(vbslq_f32(_0bic, _0f, qbiasV0));
qbiasV1 = vrndaq_f32(vbslq_f32(_1bic, _0f, qbiasV1));
qbiasV2 = vrndaq_f32(vbslq_f32(_2bic, _0f, qbiasV2));
qbiasV3 = vrndaq_f32(vbslq_f32(_3bic, _0f, qbiasV3));
scaleV0 = vbslq_f32(_0bic, _0f, scaleV0);
scaleV1 = vbslq_f32(_1bic, _0f, scaleV1);
scaleV2 = vbslq_f32(_2bic, _0f, scaleV2);
scaleV3 = vbslq_f32(_3bic, _0f, scaleV3);
biasV0 = vbslq_f32(_0bic, max0, biasV0);
biasV1 = vbslq_f32(_1bic, max1, biasV1);
biasV2 = vbslq_f32(_2bic, max2, biasV2);
biasV3 = vbslq_f32(_3bic, max3, biasV3);
vst1q_f32(qscale + qind, qscaleV0);
vst1q_f32(qscale + qind + 4, qscaleV1);
vst1q_f32(qscale + qind + 8, qscaleV2);
vst1q_f32(qscale + qind + 12, qscaleV3);
vst1q_f32(qbias + qind, qbiasV0);
vst1q_f32(qbias + qind + 4, qbiasV1);
vst1q_f32(qbias + qind + 8, qbiasV2);
vst1q_f32(qbias + qind + 12, qbiasV3);
vst1q_f32(scalePtr, scaleV0);
vst1q_f32(scalePtr + 4, scaleV1);
vst1q_f32(scalePtr + 8, scaleV2);
vst1q_f32(scalePtr + 12, scaleV3);
vst1q_f32(biasPtr, biasV0);
vst1q_f32(biasPtr + 4, biasV1);
vst1q_f32(biasPtr + 8, biasV2);
vst1q_f32(biasPtr + 12, biasV3);
realDstCount -= DST_XUNIT;
qind += DST_XUNIT;
scalePtr += (blockNum * DST_XUNIT);
biasPtr += (blockNum * DST_XUNIT);
}
if (realDstCount == 0) {
continue;
}
auto remainE = realDstCount;
auto stride0 = remainE * blockNum;
scalePtr = scale + (EP / DST_XUNIT) * blockNum * DST_XUNIT + k * remainE;
biasPtr = bias + (EP / DST_XUNIT) * blockNum * DST_XUNIT + k * remainE;
if (realDstCount > 7) {
auto max0 = vld1q_f32(srcMax + qind);
auto max1 = vld1q_f32(srcMax + qind + 4);
auto min0 = vld1q_f32(srcMin + qind);
auto min1 = vld1q_f32(srcMin + qind + 4);
auto diff0 = vsubq_f32(max0, min0);
auto diff1 = vsubq_f32(max1, min1);
auto qscaleV0 = vdivq_f32(_255f, diff0);
auto qscaleV1 = vdivq_f32(_255f, diff1);
auto scaleV0 = vdivq_f32(diff0, _255f);
auto scaleV1 = vdivq_f32(diff1, _255f);
auto qbiasV0 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min0), diff0), _128f));
auto qbiasV1 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min1), diff1), _128f));
auto biasV0 = vaddq_f32(vdivq_f32(vmulq_f32(diff0, _128f), _255f), min0);
auto biasV1 = vaddq_f32(vdivq_f32(vmulq_f32(diff1, _128f), _255f), min1);
auto _0bic = vclezq_f32(diff0);
auto _1bic = vclezq_f32(diff1);
qscaleV0 = vbslq_f32(_0bic, _0f, qscaleV0);
qscaleV1 = vbslq_f32(_1bic, _0f, qscaleV1);
qbiasV0 = vrndaq_f32(vbslq_f32(_0bic, _0f, qbiasV0));
qbiasV1 = vrndaq_f32(vbslq_f32(_1bic, _0f, qbiasV1));
scaleV0 = vbslq_f32(_0bic, _0f, scaleV0);
scaleV1 = vbslq_f32(_1bic, _0f, scaleV1);
biasV0 = vbslq_f32(_0bic, max0, biasV0);
biasV1 = vbslq_f32(_1bic, max1, biasV1);
vst1q_f32(qscale + qind, qscaleV0);
vst1q_f32(qscale + qind + 4, qscaleV1);
vst1q_f32(qbias + qind, qbiasV0);
vst1q_f32(qbias + qind + 4, qbiasV1);
vst1q_f32(scalePtr, scaleV0);
vst1q_f32(scalePtr + 4, scaleV1);
vst1q_f32(biasPtr, biasV0);
vst1q_f32(biasPtr + 4, biasV1);
realDstCount -= 8;
qind += 8;
scalePtr += 8;
biasPtr += 8;
}
if (realDstCount > 3) {
auto max0 = vld1q_f32(srcMax + qind);
auto min0 = vld1q_f32(srcMin + qind);
auto diff0 = vsubq_f32(max0, min0);
auto qscaleV0 = vdivq_f32(_255f, diff0);
auto scaleV0 = vdivq_f32(diff0, _255f);
auto qbiasV0 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min0), diff0), _128f));
auto biasV0 = vaddq_f32(vdivq_f32(vmulq_f32(diff0, _128f), _255f), min0);
auto _0bic = vclezq_f32(diff0);
qscaleV0 = vbslq_f32(_0bic, _0f, qscaleV0);
qbiasV0 = vrndaq_f32(vbslq_f32(_0bic, _0f, qbiasV0));
scaleV0 = vbslq_f32(_0bic, _0f, scaleV0);
biasV0 = vbslq_f32(_0bic, max0, biasV0);
vst1q_f32(qscale + qind, qscaleV0);
vst1q_f32(qbias + qind, qbiasV0);
vst1q_f32(scalePtr, scaleV0);
vst1q_f32(biasPtr, biasV0);
realDstCount -= 4;
qind += 4;
scalePtr += 4;
biasPtr += 4;
}
while (realDstCount > 0) {
auto max0 = vld1q_dup_f32(srcMax + qind);
auto min0 = vld1q_dup_f32(srcMin + qind);
auto diff0 = vsubq_f32(max0, min0);
auto qscaleV0 = vdivq_f32(_255f, diff0);
auto scaleV0 = vdivq_f32(diff0, _255f);
auto qbiasV0 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min0), diff0), _128f));
auto biasV0 = vaddq_f32(vdivq_f32(vmulq_f32(diff0, _128f), _255f), min0);
auto _0bic = vclezq_f32(diff0);
qscaleV0 = vbslq_f32(_0bic, _0f, qscaleV0);
qbiasV0 = vrndaq_f32(vbslq_f32(_0bic, _0f, qbiasV0));
scaleV0 = vbslq_f32(_0bic, _0f, scaleV0);
biasV0 = vbslq_f32(_0bic, max0, biasV0);
vst1q_lane_f32(qscale + qind, qscaleV0, 0);
vst1q_lane_f32(qbias + qind, qbiasV0, 0);
vst1q_lane_f32(scalePtr, scaleV0, 0);
vst1q_lane_f32(biasPtr, biasV0, 0);
realDstCount -= 1;
qind += 1;
scalePtr += 1;
biasPtr += 1;
}
}
return true;
}
#endif
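The per-channel parameters computed above reduce to the scalar formulas below (a hedged sketch mapping [min, max] onto 255 quant levels centred at -128; the degenerate max == min channel falls back to zeros and bias = max, matching the vbslq selects):

// Sketch of one channel's asymmetric quant parameters.
struct AsyQuantParam { float qscale, qbias, scale, bias; };
static AsyQuantParam asyQuantParam(float minV, float maxV) {
    float diff = maxV - minV;
    if (diff <= 0.f) {                                // all values equal in this channel
        return {0.f, 0.f, 0.f, maxV};
    }
    AsyQuantParam p;
    p.qscale = 255.f / diff;                          // float -> int8: q = x * qscale + qbias
    p.qbias  = -(255.f * minV / diff + 128.f);        // the kernel rounds this to nearest, ties away
    p.scale  = diff / 255.f;                          // int8 -> float: x ~= q * scale + bias
    p.bias   = diff * 128.f / 255.f + minV;
    return p;
}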

View File

@ -90,8 +90,12 @@ sub \rg0, \rg0, \rg1
.endm
.macro REVERT_WEIGHT_KERNEL_SUM rg0, rg1, rg2, rg3
mul \rg1, \rg2, \rg3 // blocknum * ocUp4 * sizeof(float)
sub \rg0, \rg0, \rg1, LSL #4 // revert weight kernel sum
// y=UP_DIV(ocDiv4,(hp/pack))
add \rg1, \rg3, #1
lsr \rg1, \rg1, #1
// blockNum * y * (hp * sizeof(float))
mul \rg1, \rg2, \rg1
sub \rg0, \rg0, \rg1, LSL #5 // revert weight kernel sum
.endm
asm_function MNNGemmInt8AddBiasScale_ARMV82_w4_Unit
@ -398,9 +402,11 @@ L4_TILE12_BLOCKNUM:
L4_Tile12Quan:
ld1 {v0.4s}, [x2], #16 // scale
ld1 {v0.4s}, [x2] // scale
add x2, x2, #32
ld1 {v2.4s, v3.4s, v4.4s}, [x8], #48 // x kernel sum
ld1 {v5.4s}, [x2], #16 // weight quan zeropoint
ld1 {v5.4s}, [x2] // weight quan zeropoint
add x2, x2, #32
Int32ToFloat v8, v9, v10, v11
Int32ToFloat v12, v13, v14, v15
Int32ToFloat v16, v17, v18, v19
@ -430,7 +436,7 @@ L4_TILE12_BLOCKNUM:
cbz x27, L4_TILE12_ADD_DSTV
ld1 {v0.4s, v1.4s, v2.4s}, [x27], #48 // input dequant bias
ld1 {v3.4s}, [x28], #16 // weight kernel sum
ld1 {v3.4s}, [x28] // weight kernel sum
MLA_WEIGHTZERO v8, v0, v3, 0
MLA_WEIGHTZERO v9, v0, v3, 1
MLA_WEIGHTZERO v10, v0, v3, 2
@ -443,6 +449,7 @@ L4_TILE12_BLOCKNUM:
MLA_WEIGHTZERO v17, v2, v3, 1
MLA_WEIGHTZERO v18, v2, v3, 2
MLA_WEIGHTZERO v19, v2, v3, 3
add x28, x28, #32
L4_TILE12_ADD_DSTV:
cbz x19, L4_TILE12_ACCUM_BUFFER // x19=0: first block, do not add previous block result
@ -685,9 +692,11 @@ L4_TILE8_BLOCKNUM:
bne L4LoopSz_TILE_8
L4Tile8Quan:
ld1 {v0.4s}, [x12], #16 // scale
ld1 {v0.4s}, [x12] // scale
add x12, x12, #32
ld1 {v2.4s, v3.4s}, [x8], x22 // x kernel sum
ld1 {v24.4s}, [x12], #16 // weight quan zeropoint
ld1 {v24.4s}, [x12] // weight quan zeropoint
add x12, x12, #32
Int32ToFloat v8, v9, v10, v11
Int32ToFloat v12, v13, v14, v15
MUL_SCALE v0, v8, v9, v10, v11
@ -709,7 +718,7 @@ L4_TILE8_BLOCKNUM:
cbz x27, L4_TILE8_ADD_DSTV
ld1 {v2.4s, v3.4s}, [x27], x25
ld1 {v24.4s}, [x28], #16
ld1 {v24.4s}, [x28]
MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3
MLA_WEIGHTZERO v9, v2, v24, 1 // tile:1, oc:0-3
MLA_WEIGHTZERO v10, v2, v24, 2 // tile:2, oc:0-3
@ -718,6 +727,7 @@ L4_TILE8_BLOCKNUM:
MLA_WEIGHTZERO v13, v3, v24, 1 // tile:5, oc:0-3
MLA_WEIGHTZERO v14, v3, v24, 2 // tile:6, oc:0-3
MLA_WEIGHTZERO v15, v3, v24, 3 // tile:7, oc:0-3
add x28, x28, #32
L4_TILE8_ADD_DSTV:
cbz x19, TILE8_L4_ACCUM_BUFFER
@ -905,9 +915,11 @@ L4_TILE4_BLOCKNUM:
bne L4LoopSz_TILE_4
L4Tile4Quan:
ld1 {v0.4s}, [x12], #16 // scale
ld1 {v0.4s}, [x12] // scale
add x12, x12, #32
ld1 {v2.4s}, [x8], x22 // x kernel sum
ld1 {v24.4s}, [x12], #16 // weight quan zeropoint
ld1 {v24.4s}, [x12] // weight quan zeropoint
add x12, x12, #32
Int32ToFloat v8, v9, v10, v11
MUL_SCALE v0, v8, v9, v10, v11
@ -922,11 +934,12 @@ L4_TILE4_BLOCKNUM:
cbz x27, L4_TILE4_ADD_DSTV
ld1 {v2.4s}, [x27], x25
ld1 {v24.4s}, [x28], #16
ld1 {v24.4s}, [x28]
MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3
MLA_WEIGHTZERO v9, v2, v24, 1 // tile:1, oc:0-3
MLA_WEIGHTZERO v10, v2, v24, 2 // tile:2, oc:0-3
MLA_WEIGHTZERO v11, v2, v24, 3 // tile:3, oc:0-3
add x28, x28, #32
L4_TILE4_ADD_DSTV:
cbz x19, TILE4_L4_ACCUM_BUFFER
@ -1115,9 +1128,11 @@ L4_TILE1_BLOCKNUM:
bne L4LoopSz_TILE_1_lu1
L4Tile1Quan:
ld1 {v0.4s}, [x12], #16 // scale
ld1 {v0.4s}, [x12] // scale
add x12, x12, #32
ld1 {v2.s}[0], [x8], x22 // x kernel sum
ld1 {v24.4s}, [x12], #16 // weight quan zeropoint
ld1 {v24.4s}, [x12] // weight quan zeropoint
add x12, x12, #32
scvtf v8.4s, v8.4s
fmul v8.4s, v8.4s, v0.4s
@ -1129,8 +1144,9 @@ L4_TILE1_BLOCKNUM:
cbz x27, L4_TILE1_ADD_DSTV
ld1 {v2.s}[0], [x27], x25
ld1 {v24.4s}, [x28], #16
ld1 {v24.4s}, [x28]
MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3
add x28, x28, #32
L4_TILE1_ADD_DSTV:
cbz x19, L4_TILE1_L8_ACCUM_BUFFER

View File

@ -81,8 +81,12 @@ sub \rg0, \rg0, \rg1
.endm
.macro REVERT_WEIGHT_KERNEL_SUM rg0, rg1, rg2, rg3
mul \rg1, \rg2, \rg3 // blocknum * ocUp4 * sizeof(float)
sub \rg0, \rg0, \rg1, LSL #4 // revert weight kernel sum
// y=UP_DIV(ocDiv4,(hp/pack))
add \rg1, \rg3, #1
lsr \rg1, \rg1, #1
// blockNum * y * (hp * sizeof(float))
mul \rg1, \rg2, \rg1
sub \rg0, \rg0, \rg1, LSL #5 // revert weight kernel sum
.endm
asm_function MNNGemmInt8AddBiasScale_ARMV86_w4_Unit
@ -424,14 +428,16 @@ L4_LoopSz_TILE_10:
scvtf v9.4s, v9.4s
L4_Tile10Quan:
ld1 {v20.4s}, [x2], #16 // weight scale
ld1 {v20.4s}, [x2] // weight scale
add x2, x2, #32
ld1 {v22.4s, v23.4s}, [x8], #32 // x kernel sum
ld1 {v24.d}[0], [x8], #8
ld1 {v25.4s}, [x2], #16 // weight quan zeropoint
ld1 {v25.4s}, [x2] // weight quan zeropoint
MUL_SCALE v20, v0, v1, v2, v3
MUL_SCALE v20, v4, v5, v6, v7
fmul v8.4s, v8.4s, v20.4s
fmul v9.4s, v9.4s, v20.4s
add x2, x2, #32
ld1 {v27.4s, v28.4s}, [x23], #32 // input dequant scale
ld1 {v29.d}[0], [x23], x24
@ -455,7 +461,7 @@ L4_Tile10Quan:
cbz x27, L4_TILE10_ADD_DSTV
ld1 {v22.4s, v23.4s}, [x27], #32 // input dequant bias
ld1 {v24.2s}, [x27], #8
ld1 {v25.4s}, [x28], #16 // weight kernel sum
ld1 {v25.4s}, [x28] // weight kernel sum
MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3
MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3
MLA_WEIGHTZERO v2, v22, v25, 2 // tile:2, oc:0-3
@ -466,6 +472,7 @@ L4_Tile10Quan:
MLA_WEIGHTZERO v7, v23, v25, 3 // tile:7, oc:0-3
MLA_WEIGHTZERO v8, v24, v25, 0 // tile:8, oc:0-3
MLA_WEIGHTZERO v9, v24, v25, 1 // tile:9, oc:0-3
add x28, x28, #32
L4_TILE10_ADD_DSTV:
cbz x19, L4_TILE10_TEMP_BUFFER
@ -756,9 +763,11 @@ L4_LoopSz_TILE_8:
Int32ToFloat v4, v5, v6, v7
L4_Tile8Quan:
ld1 {v20.4s}, [x12], #16 // scale
ld1 {v20.4s}, [x12] // scale
add x12, x12, #32
ld1 {v22.4s, v23.4s}, [x8] // x kernel sum
ld1 {v25.4s}, [x12], #16 // weight quan zeropoint
ld1 {v25.4s}, [x12] // weight quan zeropoint
add x12, x12, #32
add x8, x8, x22, LSR #1
MUL_SCALE v20, v0, v1, v2, v3
MUL_SCALE v20, v4, v5, v6, v7
@ -779,7 +788,7 @@ L4_Tile8Quan:
cbz x27, L4_TILE8_ADD_DSTV
ld1 {v22.4s, v23.4s}, [x27], x25 // input dequant bias
ld1 {v25.4s}, [x28], #16 // weight kernel sum
ld1 {v25.4s}, [x28] // weight kernel sum
MLA_WEIGHTZERO v0, v22, v25, 0
MLA_WEIGHTZERO v1, v22, v25, 1
MLA_WEIGHTZERO v2, v22, v25, 2
@ -788,6 +797,7 @@ L4_Tile8Quan:
MLA_WEIGHTZERO v5, v23, v25, 1
MLA_WEIGHTZERO v6, v23, v25, 2
MLA_WEIGHTZERO v7, v23, v25, 3
add x28, x28, #32
L4_TILE8_ADD_DSTV:
cbz x19, L4_TILE8_TEMP_BUFFER
@ -991,9 +1001,11 @@ L4_LoopSz_TILE_4:
Int32ToFloat v0, v1, v2, v3
L4_Tile4Quan:
ld1 {v20.4s}, [x12], #16 // scale
ld1 {v20.4s}, [x12] // scale
add x12, x12, #32
ld1 {v22.4s}, [x8] // x kernel sum
ld1 {v25.4s}, [x12], #16 // weight quan zeropoint
ld1 {v25.4s}, [x12] // weight quan zeropoint
add x12, x12, #32
MUL_SCALE v20, v0, v1, v2, v3
add x8, x8, x22, LSR #1
@ -1008,11 +1020,12 @@ L4_Tile4Quan:
cbz x27, L4_TILE4_ADD_DSTV
ld1 {v22.4s}, [x27], x25 // input dequant bias
ld1 {v25.4s}, [x28], #16 // weight kernel sum
ld1 {v25.4s}, [x28] // weight kernel sum
MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3
MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3
MLA_WEIGHTZERO v2, v22, v25, 2 // tile:2, oc:0-3
MLA_WEIGHTZERO v3, v22, v25, 3 // tile:3, oc:0-3
add x28, x28, #32
L4_TILE4_ADD_DSTV:
cbz x19, L4_TILE4_ACCUM_BUFFER
@ -1180,15 +1193,17 @@ L4_LoopSz_TILE_2:
scvtf v1.4s, v1.4s
L4_Tile2Quan:
ld1 {v20.4s}, [x12], #16 // scale
ld1 {v20.4s}, [x12] // scale
ld1 {v22.d}[0], [x8] // x kernel sum
ld1 {v25.4s}, [x12], #16 // weight quan zeropoint
add x12, x12, #32
ld1 {v25.4s}, [x12] // weight quan zeropoint
fmul v0.4s, v0.4s, v20.4s
fmul v1.4s, v1.4s, v20.4s
add x8, x8, x22, LSR #1
ld1 {v27.d}[0], [x23], x25
fmul v0.4s, v0.4s, v27.s[0]
fmul v1.4s, v1.4s, v27.s[1]
add x12, x12, #32
L4_TILE2_MLA:
MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3
@ -1196,9 +1211,10 @@ L4_Tile2Quan:
cbz x27, L4_TILE2_ADD_DSTV
ld1 {v22.2s}, [x27], x25 // input dequant bias
ld1 {v25.4s}, [x28], #16 // weight kernel sum
ld1 {v25.4s}, [x28] // weight kernel sum
MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3
MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3
add x28, x28, #32
L4_TILE2_ADD_DSTV:
cbz x19, L4_TILE2_ACCUM_BUFFER
@ -1460,11 +1476,13 @@ L4_LoopSzEnd_TILE_1:
scvtf v25.4s, v25.4s
L4_Tile1Quan:
ld1 {v0.4s}, [x12], #16 // scale
ld1 {v0.4s}, [x12] // scale
add x12, x12, #32 // hp*sizeof(float)
ld1 {v6.s}[0], [x8] // x kernel sum
ld1 {v8.4s}, [x12], #16 // weight quan zeropoint
ld1 {v8.4s}, [x12] // weight quan zeropoint
fmul v25.4s, v25.4s, v0.4s
add x8, x8, x22, LSR #1
add x12, x12, #32
ld1 {v10.s}[0], [x23], x25
fmul v25.4s, v25.4s, v10.s[0]
@ -1474,8 +1492,9 @@ L4_Tile1Quan:
cbz x27, L4_TILE1_ADD_DSTV
ld1 {v6.s}[0], [x27], x25 // input dequant bias
ld1 {v8.4s}, [x28], #16 // weight kernel sum
ld1 {v8.4s}, [x28] // weight kernel sum
MLA_WEIGHTZERO v25, v6, v8, 0 // tile:0, oc:0-3
add x28, x28, #32
L4_TILE1_ADD_DSTV:
cbz x19, L4_TILE1_ACCUM_BUFFER

View File

@ -0,0 +1,339 @@
//
// MNNGeneralIm2col_Fp16Sme2.S
// MNN
//
// Created by MNN on 2024/12/25.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef __aarch64__
#include "MNNAsmGlobal.h"
.text
.align 5
//void MNNGeneralIm2col_Fp16Sme2(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack)
asm_function MNNGeneralIm2col_Fp16Sme2
// x0:destOrigin, x1:sourceGroup, x2:info, x3:el, x4:LP, x5:pack
stp d14, d15, [sp, #(-16 * 5)]!
stp d12, d13, [sp, #(16 * 1)]
stp d10, d11, [sp, #(16 * 2)]
stp d8, d9, [sp, #(16 * 3)]
stp x19, x20, [sp, #(16 * 4)]
// load el info
ldr w6, [x2, #0] // number
ldr w7, [x2, #4] // eReal
ldr w15, [x2, #8] // eDest (< EP)
ldr w9, [x2, #12] // offset (stride)
ldr x14, [x1, #0] // src start
lsl x9, x9, #4 // pack*offset*sizeof(float16_t)
// stride
lsl x19, x15, #3 // eDest*LP*sizeof(float16_t)
lsl x7, x7, #4 // eReal*pack*sizeof(float16_t)
mov x20, #3 // Sme2,LP=4
LoopNum:
ldr w10, [x3], #4 // e
ldr w11, [x3], #4 // l
ldr w12, [x3], #4 // eOffset
ldr w13, [x3], #4 // lOffset
// dst address: x2
and x2, x13, x20 // lR
sub x13, x13, x2 // lOffset-lR
mul x13, x13, x15 // (lOffset-lR)*(eDest)
add x13, x13, x2 // (lOffset-lR)*(eDest)+lR
add x13, x13, x12, LSL #2 // + eoffset*lp
add x2, x0, x13, LSL #1 // *sizeof(float16_t)
lsl x8, x19, #1 // 2*(eDest*LP*sizeof(float16_t))
cmp x11, #8
blt LoopL4
LoopL8:
mov x5, x2
mov x4, x14
mov x13, x10
add x12, x2, x19 // eDest*LP*sizeof(float16_t)
cmp x13, #16
blt LoopL8E12
sub x8, x8, #64
LoopL8E16:
sub x13, x13, #16
ld1 {v0.8h}, [x14], x9
ld1 {v1.8h}, [x14], x9
ld1 {v2.8h}, [x14], x9
ld1 {v3.8h}, [x14], x9
ld1 {v4.8h}, [x14], x9
ld1 {v5.8h}, [x14], x9
ld1 {v6.8h}, [x14], x9
ld1 {v7.8h}, [x14], x9
ld1 {v8.8h}, [x14], x9
ld1 {v9.8h}, [x14], x9
ld1 {v10.8h}, [x14], x9
ld1 {v11.8h}, [x14], x9
ld1 {v12.8h}, [x14], x9
ld1 {v13.8h}, [x14], x9
ld1 {v14.8h}, [x14], x9
ld1 {v15.8h}, [x14], x9
zip1 v16.2d, v0.2d, v1.2d
zip1 v17.2d, v2.2d, v3.2d
zip1 v18.2d, v4.2d, v5.2d
zip1 v19.2d, v6.2d, v7.2d
zip1 v20.2d, v8.2d, v9.2d
zip1 v21.2d, v10.2d, v11.2d
zip1 v22.2d, v12.2d, v13.2d
zip1 v23.2d, v14.2d, v15.2d
zip2 v24.2d, v0.2d, v1.2d
zip2 v25.2d, v2.2d, v3.2d
zip2 v26.2d, v4.2d, v5.2d
zip2 v27.2d, v6.2d, v7.2d
zip2 v28.2d, v8.2d, v9.2d
zip2 v29.2d, v10.2d, v11.2d
zip2 v30.2d, v12.2d, v13.2d
zip2 v31.2d, v14.2d, v15.2d
st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x2], #64
st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x2], x8
st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x12], #64
st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x12], x8
cmp x13, #16
bge LoopL8E16
add x8, x8, #64
LoopL8E12:
cmp x13, #12
blt LoopL8E8
ld1 {v0.8h}, [x14], x9
ld1 {v1.8h}, [x14], x9
ld1 {v2.8h}, [x14], x9
ld1 {v3.8h}, [x14], x9
ld1 {v4.8h}, [x14], x9
ld1 {v5.8h}, [x14], x9
ld1 {v6.8h}, [x14], x9
ld1 {v7.8h}, [x14], x9
ld1 {v8.8h}, [x14], x9
ld1 {v9.8h}, [x14], x9
ld1 {v10.8h}, [x14], x9
ld1 {v11.8h}, [x14], x9
zip1 v12.2d, v0.2d, v1.2d
zip1 v13.2d, v2.2d, v3.2d
zip1 v14.2d, v4.2d, v5.2d
zip1 v15.2d, v6.2d, v7.2d
zip1 v16.2d, v8.2d, v9.2d
zip1 v17.2d, v10.2d, v11.2d
zip2 v18.2d, v0.2d, v1.2d
zip2 v19.2d, v2.2d, v3.2d
zip2 v20.2d, v4.2d, v5.2d
zip2 v21.2d, v6.2d, v7.2d
zip2 v22.2d, v8.2d, v9.2d
zip2 v23.2d, v10.2d, v11.2d
st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x2], #64
st1 {v16.8h, v17.8h}, [x2], #32
st1 {v18.8h, v19.8h, v20.8h, v21.8h}, [x12], #64
st1 {v22.8h, v23.8h}, [x12], #32
sub x13, x13, #12
LoopL8E8:
cmp x13, #8
blt LoopL8E4
ld1 {v0.8h}, [x14], x9
ld1 {v1.8h}, [x14], x9
ld1 {v2.8h}, [x14], x9
ld1 {v3.8h}, [x14], x9
ld1 {v4.8h}, [x14], x9
ld1 {v5.8h}, [x14], x9
ld1 {v6.8h}, [x14], x9
ld1 {v7.8h}, [x14], x9
zip1 v12.2d, v0.2d, v1.2d
zip1 v13.2d, v2.2d, v3.2d
zip1 v14.2d, v4.2d, v5.2d
zip1 v15.2d, v6.2d, v7.2d
zip2 v18.2d, v0.2d, v1.2d
zip2 v19.2d, v2.2d, v3.2d
zip2 v20.2d, v4.2d, v5.2d
zip2 v21.2d, v6.2d, v7.2d
st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x2], #64
st1 {v18.8h, v19.8h, v20.8h, v21.8h}, [x12], #64
sub x13, x13, #8
LoopL8E4:
cmp x13, #4
blt LoopL8E2
ld1 {v0.8h}, [x14], x9
ld1 {v1.8h}, [x14], x9
ld1 {v2.8h}, [x14], x9
ld1 {v3.8h}, [x14], x9
zip1 v12.2d, v0.2d, v1.2d
zip1 v13.2d, v2.2d, v3.2d
zip2 v18.2d, v0.2d, v1.2d
zip2 v19.2d, v2.2d, v3.2d
st1 {v12.8h, v13.8h}, [x2], #32
st1 {v18.8h, v19.8h}, [x12], #32
sub x13, x13, #4
LoopL8E2:
cmp x13, #2
blt LoopL8E1
ld1 {v0.8h}, [x14], x9
ld1 {v1.8h}, [x14], x9
zip1 v12.2d, v0.2d, v1.2d
zip2 v18.2d, v0.2d, v1.2d
st1 {v12.8h}, [x2], #16
st1 {v18.8h}, [x12], #16
sub x13, x13, #2
LoopL8E1:
cmp x13, #1
blt EndL8LoopE
ld1 {v0.8h}, [x14], x9
st1 {v0.d}[0], [x2], #8
st1 {v0.d}[1], [x12], #8
EndL8LoopE:
sub x11, x11, #8
cmp x11, #8
add x2, x5, x8 // eDest*LP*2*sizeof(float16_t)
add x14, x4, x7
bge LoopL8
cbz x11, EndLoopL
LoopL4:
mov x5, x2
mov x4, x14
mov x13, x10
sub x8, x19, #64
cmp x13, #16
blt LoopL4E12
LoopL4E16:
sub x13, x13, #16
ld1 {v0.8h}, [x14], x9
ld1 {v1.8h}, [x14], x9
ld1 {v2.8h}, [x14], x9
ld1 {v3.8h}, [x14], x9
ld1 {v4.8h}, [x14], x9
ld1 {v5.8h}, [x14], x9
ld1 {v6.8h}, [x14], x9
ld1 {v7.8h}, [x14], x9
ld1 {v8.8h}, [x14], x9
ld1 {v9.8h}, [x14], x9
ld1 {v10.8h}, [x14], x9
ld1 {v11.8h}, [x14], x9
ld1 {v12.8h}, [x14], x9
ld1 {v13.8h}, [x14], x9
ld1 {v14.8h}, [x14], x9
ld1 {v15.8h}, [x14], x9
zip1 v16.2d, v0.2d, v1.2d
zip1 v17.2d, v2.2d, v3.2d
zip1 v18.2d, v4.2d, v5.2d
zip1 v19.2d, v6.2d, v7.2d
zip1 v20.2d, v8.2d, v9.2d
zip1 v21.2d, v10.2d, v11.2d
zip1 v22.2d, v12.2d, v13.2d
zip1 v23.2d, v14.2d, v15.2d
st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x2], #64
st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x2], x8
cmp x13, #16
bge LoopL4E16
LoopL4E12:
cmp x13, #12
blt LoopL4E8
sub x13, x13, #12
ld1 {v0.8h}, [x14], x9
ld1 {v1.8h}, [x14], x9
ld1 {v2.8h}, [x14], x9
ld1 {v3.8h}, [x14], x9
ld1 {v4.8h}, [x14], x9
ld1 {v5.8h}, [x14], x9
ld1 {v6.8h}, [x14], x9
ld1 {v7.8h}, [x14], x9
ld1 {v8.8h}, [x14], x9
ld1 {v9.8h}, [x14], x9
ld1 {v10.8h}, [x14], x9
ld1 {v11.8h}, [x14], x9
zip1 v12.2d, v0.2d, v1.2d
zip1 v13.2d, v2.2d, v3.2d
zip1 v14.2d, v4.2d, v5.2d
zip1 v15.2d, v6.2d, v7.2d
zip1 v16.2d, v8.2d, v9.2d
zip1 v17.2d, v10.2d, v11.2d
st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x2], #64
st1 {v16.8h, v17.8h}, [x2], #32
LoopL4E8:
cmp x13, #8
blt LoopL4E4
ld1 {v0.8h}, [x14], x9
ld1 {v1.8h}, [x14], x9
ld1 {v2.8h}, [x14], x9
ld1 {v3.8h}, [x14], x9
ld1 {v4.8h}, [x14], x9
ld1 {v5.8h}, [x14], x9
ld1 {v6.8h}, [x14], x9
ld1 {v7.8h}, [x14], x9
zip1 v12.2d, v0.2d, v1.2d
zip1 v13.2d, v2.2d, v3.2d
zip1 v14.2d, v4.2d, v5.2d
zip1 v15.2d, v6.2d, v7.2d
st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x2], #64
sub x13, x13, #8
LoopL4E4:
cmp x13, #4
blt LoopL4E2
ld1 {v0.8h}, [x14], x9
ld1 {v1.8h}, [x14], x9
ld1 {v2.8h}, [x14], x9
ld1 {v3.8h}, [x14], x9
zip1 v12.2d, v0.2d, v1.2d
zip1 v13.2d, v2.2d, v3.2d
st1 {v12.8h, v13.8h}, [x2], #32
sub x13, x13, #4
LoopL4E2:
cmp x13, #2
blt LoopL4E1
ld1 {v0.8h}, [x14], x9
ld1 {v1.8h}, [x14], x9
zip1 v12.2d, v0.2d, v1.2d
st1 {v12.8h}, [x2], #16
sub x13, x13, #2
LoopL4E1:
cmp x13, #1
blt EndLoopL
ld1 {v0.8h}, [x14], x9
st1 {v0.d}[0], [x2], #8
EndLoopL:
subs x6, x6, #1
add x1, x1, #8
ldr x14, [x1, #0]
bne LoopNum
End:
ldp x19, x20, [sp, #(16 * 4)]
ldp d8, d9, [sp, #(16 * 3)]
ldp d10, d11, [sp, #(16 * 2)]
ldp d12, d13, [sp, #(16 * 1)]
ldp d14, d15, [sp], #(16 * 5)
ret
#endif

View File

@ -0,0 +1,165 @@
//
// MNNGeneralIm2col_Fp32Sme2.S
// MNN
//
// Created by MNN on 2025/01/13.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifdef __aarch64__
#include "MNNAsmGlobal.h"
.text
.align 5
//void MNNGeneralIm2col_Fp32Sme2(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack)
asm_function MNNGeneralIm2col_Fp32Sme2
// x0:destOrigin, x1:sourceGroup, x2:info, x3:el, x4:LP, x5:pack
stp d14, d15, [sp, #(-16 * 5)]!
stp d12, d13, [sp, #(16 * 1)]
stp d10, d11, [sp, #(16 * 2)]
stp d8, d9, [sp, #(16 * 3)]
stp x19, x20, [sp, #(16 * 4)]
// load el info
ldr w6, [x2, #0] // number
ldr w7, [x2, #4] // eReal
ldr w15, [x2, #8] // eDest (< EP)
ldr w9, [x2, #12] // offset (stride)
ldr x14, [x1, #0] // src start
lsl x9, x9, #4 // pack*offset*sizeof(float32_t)
// stride
lsl x19, x15, #4 // eDest*LP*sizeof(float32_t)
lsl x7, x7, #4 // eReal*pack*sizeof(float32_t)
mov x20, #3 // Sme2,LP=4
LoopNum:
ldr w10, [x3], #4 // e
ldr w11, [x3], #4 // l
ldr w12, [x3], #4 // eOffset
ldr w13, [x3], #4 // lOffset
// dst address: x2
and x2, x13, x20 // lR
sub x13, x13, x2 // lOffset-lR
mul x13, x13, x15 // (lOffset-lR)*(eDest)
add x13, x13, x2 // (lOffset-lR)*(eDest)+lR
add x13, x13, x12, LSL #2 // + eoffset*lp
add x2, x0, x13, LSL #2 // *sizeof(float32_t)
LoopL4:
mov x5, x2
mov x4, x14
mov x13, x10
cmp x13, #16
blt LoopL4E12
LoopL4E16:
sub x13, x13, #16
ld1 {v0.4s}, [x14], x9
ld1 {v1.4s}, [x14], x9
ld1 {v2.4s}, [x14], x9
ld1 {v3.4s}, [x14], x9
ld1 {v4.4s}, [x14], x9
ld1 {v5.4s}, [x14], x9
ld1 {v6.4s}, [x14], x9
ld1 {v7.4s}, [x14], x9
ld1 {v8.4s}, [x14], x9
ld1 {v9.4s}, [x14], x9
ld1 {v10.4s}, [x14], x9
ld1 {v11.4s}, [x14], x9
ld1 {v12.4s}, [x14], x9
ld1 {v13.4s}, [x14], x9
ld1 {v14.4s}, [x14], x9
ld1 {v15.4s}, [x14], x9
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x2], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x2], #64
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64
cmp x13, #16
bge LoopL4E16
LoopL4E12:
cmp x13, #12
blt LoopL4E8
ld1 {v0.4s}, [x14], x9
ld1 {v1.4s}, [x14], x9
ld1 {v2.4s}, [x14], x9
ld1 {v3.4s}, [x14], x9
ld1 {v4.4s}, [x14], x9
ld1 {v5.4s}, [x14], x9
ld1 {v6.4s}, [x14], x9
ld1 {v7.4s}, [x14], x9
ld1 {v8.4s}, [x14], x9
ld1 {v9.4s}, [x14], x9
ld1 {v10.4s}, [x14], x9
ld1 {v11.4s}, [x14], x9
sub x13, x13, #12
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x2], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x2], #64
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64
LoopL4E8:
cmp x13, #8
blt LoopL4E4
ld1 {v0.4s}, [x14], x9
ld1 {v1.4s}, [x14], x9
ld1 {v2.4s}, [x14], x9
ld1 {v3.4s}, [x14], x9
ld1 {v4.4s}, [x14], x9
ld1 {v5.4s}, [x14], x9
ld1 {v6.4s}, [x14], x9
ld1 {v7.4s}, [x14], x9
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x2], #64
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x2], #64
sub x13, x13, #8
LoopL4E4:
cmp x13, #4
blt LoopL4E2
ld1 {v0.4s}, [x14], x9
ld1 {v1.4s}, [x14], x9
ld1 {v2.4s}, [x14], x9
ld1 {v3.4s}, [x14], x9
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x2], #64
sub x13, x13, #4
LoopL4E2:
cmp x13, #2
blt LoopL4E1
ld1 {v0.4s}, [x14], x9
ld1 {v1.4s}, [x14], x9
st1 {v0.4s, v1.4s}, [x2], #32
sub x13, x13, #2
LoopL4E1:
cmp x13, #1
blt EndL4LoopE
ld1 {v0.4s}, [x14], x9
st1 {v0.4s}, [x2], #16
EndL4LoopE:
add x2, x5, x19 // eDest*LP*sizeof(float32_t)
add x14, x4, x7
subs x11, x11, #4
bne LoopL4
EndLoopL:
subs x6, x6, #1
add x1, x1, #8
ldr x14, [x1, #0]
bne LoopNum
End:
ldp x19, x20, [sp, #(16 * 4)]
ldp d8, d9, [sp, #(16 * 3)]
ldp d10, d11, [sp, #(16 * 2)]
ldp d12, d13, [sp, #(16 * 1)]
ldp d14, d15, [sp], #(16 * 5)
ret
#endif
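
For readers who prefer C, the loop above is a plain strided copy into an [l/4, eDest, 4] layout; the sketch below assumes LP == pack == 4 and the same (info, el) convention as the other MNNGeneralIm2col_* kernels (l is already a multiple of 4 at this point):

#include <stdint.h>
#include <string.h>
// reference sketch of MNNGeneralIm2col_Fp32Sme2 (illustrative, derived from the register comments above)
static void generalIm2colFp32Sme2Ref(float* destOrigin, float const** sourceGroup,
                                     const int32_t* info, const int32_t* el) {
    const int number = info[0], eReal = info[1], eDest = info[2], offset = info[3];
    const int LP = 4, pack = 4;
    for (int n = 0; n < number; ++n) {
        const int e = el[4 * n + 0], l = el[4 * n + 1];
        const int eOffset = el[4 * n + 2], lOffset = el[4 * n + 3];
        const float* src = sourceGroup[n];
        const int lR = lOffset % LP;
        float* dst = destOrigin + (lOffset - lR) * eDest + lR + eOffset * LP;
        for (int y = 0; y < l; y += LP) {              // one 4-wide slice of l per pass
            for (int x = 0; x < e; ++x) {
                memcpy(dst + x * LP, src + x * offset * pack, LP * sizeof(float));
            }
            dst += eDest * LP;                         // next slice in the [l/4, eDest, 4] layout
            src += eReal * pack;
        }
    }
}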

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -94,10 +94,11 @@ void ARMV86_MNNGetMatMulPackMode_BF16(int* eP, int* lP, int* hP) {
*hP = HP;
*lP = LP;
}
void ARMV86_MNNPackForMatMul_B_BF16(float* destF, const float* sourceF, size_t h, size_t l, bool transpose) {
void ARMV86_MNNPackForMatMul_B_BF16(float* destF, const float* sourceF, size_t h, size_t kernelsize, size_t ic, bool transpose) {
// [l, h] -> [h/hp, l/lp, hp, lp]
auto dest = (int16_t*)destF;
auto source = (const int32_t*)sourceF;
auto l = kernelsize * ic;
auto lCP = UP_DIV(l, LP);
auto hCP = UP_DIV(h, HP);
int sYstride = 1;
@ -173,10 +174,11 @@ void ARMV86_MNNPackC4ForMatMul_A_BF16(float* destOrigin, float const** sourceGro
#undef EP
#undef HP
#undef LP
void NEON_MNNPackForMatMul_B_BF16(float* destFloat, const float* sourceFloat, size_t h, size_t l, bool transpose) {
void NEON_MNNPackForMatMul_B_BF16(float* destFloat, const float* sourceFloat, size_t h, size_t kernelsize, size_t ic, bool transpose) {
auto hP = (int)h / 8;
auto hR = (int)hP * 8;
int16_t* dest = (int16_t*)destFloat;
auto l = kernelsize * ic;
const float* source = sourceFloat;
if (!transpose) {
for (int y = 0; y < hP; ++y) {
@ -246,9 +248,10 @@ void NEON_MNNPackForMatMul_B_BF16(float* destFloat, const float* sourceFloat, si
}
#else
void NEON_MNNPackForMatMul_B_BF16(float* destFloat, const float* sourceFloat, size_t h, size_t l, bool transpose) {
void NEON_MNNPackForMatMul_B_BF16(float* destFloat, const float* sourceFloat, size_t h, size_t kernelsize, size_t ic, bool transpose) {
int16_t* dest = (int16_t*)destFloat;
const float* source = sourceFloat;
auto l = kernelsize * ic;
if (!transpose) {
auto hP = h / 4;
auto hR = hP * 4;

View File

@ -228,6 +228,7 @@ void MNNDynamicUpdateConvBiasScale(float* newbias, float* oldbias, float* weight
#endif // not __aarch64__
static void MNNCountMaxMinValue(const float* source, float* minVal, float* maxVal, size_t size) {
#ifndef MNN_USE_NEON
int pack = 4;
float max_ = source[0], min_ = source[0];
for (int i = 1; i < size; ++i) {
@ -240,6 +241,162 @@ static void MNNCountMaxMinValue(const float* source, float* minVal, float* maxVa
}
*minVal = min_;
*maxVal = max_;
#else
auto sizeDiv4 = size / 4;
auto remain = size - 4 * sizeDiv4;
auto srcPtr = source;
auto max0 = vdupq_n_f32(srcPtr[0]);
auto min0 = vdupq_n_f32(srcPtr[0]);
while (sizeDiv4 > 15) {
sizeDiv4 -= 16;
auto data0 = vld1q_f32(srcPtr);
auto data1 = vld1q_f32(srcPtr + 4);
auto data2 = vld1q_f32(srcPtr + 8);
auto data3 = vld1q_f32(srcPtr + 12);
auto data4 = vld1q_f32(srcPtr + 16);
auto data5 = vld1q_f32(srcPtr + 20);
auto data6 = vld1q_f32(srcPtr + 24);
auto data7 = vld1q_f32(srcPtr + 28);
auto data8 = vld1q_f32(srcPtr + 32);
auto data9 = vld1q_f32(srcPtr + 36);
auto data10 = vld1q_f32(srcPtr + 40);
auto data11 = vld1q_f32(srcPtr + 44);
auto data12 = vld1q_f32(srcPtr + 48);
auto data13 = vld1q_f32(srcPtr + 52);
auto data14 = vld1q_f32(srcPtr + 56);
auto data15 = vld1q_f32(srcPtr + 60);
auto lmin0 = vminq_f32(data0, data1);
auto lmin2 = vminq_f32(data2, data3);
auto lmin4 = vminq_f32(data4, data5);
auto lmin6 = vminq_f32(data6, data7);
auto lmin8 = vminq_f32(data8, data9);
auto lmin10 = vminq_f32(data10, data11);
auto lmin12 = vminq_f32(data12, data13);
auto lmin14 = vminq_f32(data14, data15);
auto lmax0 = vmaxq_f32(data0, data1);
auto lmax2 = vmaxq_f32(data2, data3);
auto lmax4 = vmaxq_f32(data4, data5);
auto lmax6 = vmaxq_f32(data6, data7);
auto lmax8 = vmaxq_f32(data8, data9);
auto lmax10 = vmaxq_f32(data10, data11);
auto lmax12 = vmaxq_f32(data12, data13);
auto lmax14 = vmaxq_f32(data14, data15);
lmin0 = vminq_f32(lmin0, lmin2);
lmin4 = vminq_f32(lmin4, lmin6);
lmin8 = vminq_f32(lmin8, lmin10);
lmin12 = vminq_f32(lmin12, lmin14);
lmax0 = vmaxq_f32(lmax0, lmax2);
lmax4 = vmaxq_f32(lmax4, lmax6);
lmax8 = vmaxq_f32(lmax8, lmax10);
lmax12 = vmaxq_f32(lmax12, lmax14);
lmin0 = vminq_f32(lmin0, lmin8);
lmin4 = vminq_f32(lmin4, lmin12);
lmax0 = vmaxq_f32(lmax0, lmax8);
lmax4 = vmaxq_f32(lmax4, lmax12);
lmin0 = vminq_f32(lmin0, lmin4);
lmax0 = vmaxq_f32(lmax0, lmax4);
max0 = vmaxq_f32(max0, lmax0);
min0 = vminq_f32(min0, lmin0);
srcPtr += 64;
}
if (sizeDiv4 > 7) {
sizeDiv4 -= 8;
auto data0 = vld1q_f32(srcPtr);
auto data1 = vld1q_f32(srcPtr + 4);
auto data2 = vld1q_f32(srcPtr + 8);
auto data3 = vld1q_f32(srcPtr + 12);
auto data4 = vld1q_f32(srcPtr + 16);
auto data5 = vld1q_f32(srcPtr + 20);
auto data6 = vld1q_f32(srcPtr + 24);
auto data7 = vld1q_f32(srcPtr + 28);
auto lmin0 = vminq_f32(data0, data1);
auto lmin2 = vminq_f32(data2, data3);
auto lmin4 = vminq_f32(data4, data5);
auto lmin6 = vminq_f32(data6, data7);
auto lmax0 = vmaxq_f32(data0, data1);
auto lmax2 = vmaxq_f32(data2, data3);
auto lmax4 = vmaxq_f32(data4, data5);
auto lmax6 = vmaxq_f32(data6, data7);
lmin0 = vminq_f32(lmin0, lmin2);
lmin4 = vminq_f32(lmin4, lmin6);
lmax0 = vmaxq_f32(lmax0, lmax2);
lmax4 = vmaxq_f32(lmax4, lmax6);
lmin0 = vminq_f32(lmin0, lmin4);
lmax0 = vmaxq_f32(lmax0, lmax4);
max0 = vmaxq_f32(max0, lmax0);
min0 = vminq_f32(min0, lmin0);
srcPtr += 32;
}
if (sizeDiv4 > 3) {
sizeDiv4 -= 4;
auto data0 = vld1q_f32(srcPtr);
auto data1 = vld1q_f32(srcPtr + 4);
auto data2 = vld1q_f32(srcPtr + 8);
auto data3 = vld1q_f32(srcPtr + 12);
auto lmin0 = vminq_f32(data0, data1);
auto lmin2 = vminq_f32(data2, data3);
auto lmax0 = vmaxq_f32(data0, data1);
auto lmax2 = vmaxq_f32(data2, data3);
lmin0 = vminq_f32(lmin0, lmin2);
lmax0 = vmaxq_f32(lmax0, lmax2);
max0 = vmaxq_f32(max0, lmax0);
min0 = vminq_f32(min0, lmin0);
srcPtr += 16;
}
if (sizeDiv4 > 1) {
sizeDiv4 -= 2;
auto data0 = vld1q_f32(srcPtr);
auto data1 = vld1q_f32(srcPtr + 4);
auto lmin0 = vminq_f32(data0, data1);
auto lmax0 = vmaxq_f32(data0, data1);
max0 = vmaxq_f32(max0, lmax0);
min0 = vminq_f32(min0, lmin0);
srcPtr += 8;
}
if (sizeDiv4 > 0) {
sizeDiv4--;
auto data0 = vld1q_f32(srcPtr);
max0 = vmaxq_f32(max0, data0);
min0 = vminq_f32(min0, data0);
srcPtr += 4;
}
float temp0[4];
float temp1[4];
vst1q_f32(temp0, max0);
vst1q_f32(temp1, min0);
auto maxval = temp0[0];
auto minval = temp1[0];
for (int i = 1; i < 4; ++i) {
maxval = ALIMAX(maxval, temp0[i]);
minval = ALIMIN(minval, temp1[i]);
}
while (remain > 0) {
maxval = ALIMAX(maxval, srcPtr[0]);
minval = ALIMIN(minval, srcPtr[0]);
remain--;
srcPtr += 1;
}
minVal[0] = minval;
maxVal[0] = maxval;
#endif
}
#ifdef MNN_LOW_MEMORY
@ -389,18 +546,28 @@ static void MNNAsyQuantInfo_FP32(float* scale, float* bias, float* qscale, float
// dequant scale/bias : [EU, blockNum, step], step=ALIMIN(step, EP), EU=UP_DIV(plane, EP)
// quant scale/bias : [blockNum, plane]
#ifdef __aarch64__
if (DST_XUNIT == 12 && innerSide == 4) { // Arm82,fp32: SRC_UNIT=4, core->pack=4
if ((DST_XUNIT == 12 || DST_XUNIT == 16) && innerSide == 4) { // Arm82/Sme2, fp32: SRC_UNIT=4, core->pack=4
// max,min shape: [blockNum, EP]
for (int i = 0; i < kernelsize; ++i) {
MNNLocalMinMaxFP32_Pack4(dstMin, dstMax, src + i * stride0, blockNum, blockLU, plane, innerSide, i);
}
// scale, bias
bool success = MNNAsyLocalQuantInfo_EP12_FP32(scale, bias, qscale, qbias, dstMin, dstMax, info);
if (!success) {
MNN_ERROR("Call error for:MNNAsyLocalQuantInfo_EP12\n");
if (DST_XUNIT == 12) {
bool success = MNNAsyLocalQuantInfo_EP12_FP32(scale, bias, qscale, qbias, dstMin, dstMax, info);
if (!success) {
MNN_ERROR("Call error for:MNNAsyLocalQuantInfo_EP12\n");
return;
}
return;
}
if (DST_XUNIT == 16) {
bool success = MNNAsyLocalQuantInfo_EP16_FP32(scale, bias, qscale, qbias, dstMin, dstMax, info);
if (!success) {
MNN_ERROR("Call error for:MNNAsyLocalQuantInfo_EP16_FP32\n");
return;
}
return;
}
return;
}
if (DST_XUNIT == 10) { // Arm86,fp32: SRC_UNIT=8,core->pack=4
// max,min shape: [blockNum, EP]
@ -653,6 +820,73 @@ static void MNNReorderWeightInt4Arm82(uint8_t* dest, const uint8_t* source, int3
}
MNNPermuteSumWeightInt4Arm82(dest, dest, blocknum * hu, lu, kernelsum);
}
static void MNNReorderWeightInt4Sme2(uint8_t* dest, const uint8_t* source, int32_t* shape, size_t size, float* kernelsum) {
MNN_ASSERT(size > 4);
// dst shape: [hu, blocknum, kernelCount, lu, hp, lp], kernelCount=1 in this case
auto blocknum = shape[0];
auto hu = shape[1];
auto lu = shape[2];
auto hp = shape[3];
auto lp = shape[4];
auto ic = blocknum * lu * lp;
auto stride0 = blocknum * hp * lu * lp;
auto stride1 = lu * hp * lp;
auto stride2 = hp * lp;
auto dstPtr = (int16_t*)dest;
auto srcPtr = (int16_t*)source;
int unitpacksize = sizeof(int16_t) / sizeof(uint8_t);
for (int i = 0; i < hu; ++i) {
for (int k = 0; k < hp; ++k) {
for (int bl = 0; bl < blocknum; ++bl) {
int j = 0;
while (j + 7 < lu) {
auto srcindex = ((i * hp + k) * ic + bl * (lu * lp) + j * lp) / unitpacksize;
auto dstindex0 = (bl * stride1 + i * stride0 + j * stride2 + k * lp) / unitpacksize;
auto dstindex1 = (bl * stride1 + i * stride0 + (j + 1) * stride2 + k * lp) / unitpacksize;
auto dstindex2 = (bl * stride1 + i * stride0 + (j + 2) * stride2 + k * lp) / unitpacksize;
auto dstindex3 = (bl * stride1 + i * stride0 + (j + 3) * stride2 + k * lp) / unitpacksize;
auto dstindex4 = (bl * stride1 + i * stride0 + (j + 4) * stride2 + k * lp) / unitpacksize;
auto dstindex5 = (bl * stride1 + i * stride0 + (j + 5) * stride2 + k * lp) / unitpacksize;
auto dstindex6 = (bl * stride1 + i * stride0 + (j + 6) * stride2 + k * lp) / unitpacksize;
auto dstindex7 = (bl * stride1 + i * stride0 + (j + 7) * stride2 + k * lp) / unitpacksize;
j += 8;
auto srcdata = vld1q_s16(srcPtr + srcindex);
vst1q_lane_s16(dstPtr + dstindex0, srcdata, 0);
vst1q_lane_s16(dstPtr + dstindex1, srcdata, 1);
vst1q_lane_s16(dstPtr + dstindex2, srcdata, 2);
vst1q_lane_s16(dstPtr + dstindex3, srcdata, 3);
vst1q_lane_s16(dstPtr + dstindex4, srcdata, 4);
vst1q_lane_s16(dstPtr + dstindex5, srcdata, 5);
vst1q_lane_s16(dstPtr + dstindex6, srcdata, 6);
vst1q_lane_s16(dstPtr + dstindex7, srcdata, 7);
}
while (j + 3 < lu) {
auto srcindex = ((i * hp + k) * ic + bl * (lu * lp) + j * lp) / unitpacksize;
auto dstindex0 = (bl * stride1 + i * stride0 + j * stride2 + k * lp) / unitpacksize;
auto dstindex1 = (bl * stride1 + i * stride0 + (j + 1) * stride2 + k * lp) / unitpacksize;
auto dstindex2 = (bl * stride1 + i * stride0 + (j + 2) * stride2 + k * lp) / unitpacksize;
auto dstindex3 = (bl * stride1 + i * stride0 + (j + 3) * stride2 + k * lp) / unitpacksize;
j += 4;
auto srcdata = vld1_s16(srcPtr + srcindex);
vst1_lane_s16(dstPtr + dstindex0, srcdata, 0);
vst1_lane_s16(dstPtr + dstindex1, srcdata, 1);
vst1_lane_s16(dstPtr + dstindex2, srcdata, 2);
vst1_lane_s16(dstPtr + dstindex3, srcdata, 3);
}
while (j < lu)
{
auto srcindex = ((i * hp + k) * ic + bl * (lu * lp) + j * lp) / 2;
auto dstindex = (bl * stride1 + i * stride0 + j * stride2 + k * lp) / 2;
dstPtr[dstindex] = srcPtr[srcindex];
j++;
}
}
}
}
MNNPermuteSumWeightInt4Sme2(dest, dest, blocknum * hu, lu, kernelsum);
}
#endif // __aarch64__
static void MNNSumWeightInt8(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP) {
@ -1147,9 +1381,11 @@ void MNNGetSparseMatMulPackMode(int* eP, int *lP, int* hP) {
return;
}
void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose) {
void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose) {
// src: [h, kernelsize, ic]
auto hP = h / 4;
auto hR = hP * 4;
auto l = kernelsize * ic;
if (hR != h) {
::memset(dest, 0, UP_DIV(h, 4)*4*l*sizeof(float));
}
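
Only the signature of the generic MNNPackForMatMul_B changes here (callers now pass kernelsize and ic separately, with l = kernelsize * ic); assuming the usual [UP_DIV(h, 4), l, 4] destination layout of this kernel, the transpose == true case behaves like the following sketch:

#include <string.h>
// sketch of the new signature's semantics for transpose == true (source is [h, kernelsize * ic])
static void packForMatMul_B_ref(float* dest, const float* source,
                                size_t h, size_t kernelsize, size_t ic) {
    size_t l = kernelsize * ic;
    size_t hC4 = (h + 3) / 4;
    memset(dest, 0, hC4 * 4 * l * sizeof(float));
    for (size_t y = 0; y < h; ++y) {
        for (size_t x = 0; x < l; ++x) {
            dest[(y / 4) * l * 4 + x * 4 + (y % 4)] = source[y * l + x];
        }
    }
}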
@ -3632,6 +3868,54 @@ static void generalIm2col(float* destOrigin, float const** sourceGroup, const in
}
#endif // MNN_LOW_MEMORY
#ifdef MNN_SME2
static void SME2MNNGetMatMulPackMode(int* eP, int *lP, int* hP) {
*eP = 16;
*lP = 1;
*hP = 64;
}
static void MNNPackedMatMulFP32_SME2(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b) {
MNNPackedMatMulRemainFP32_SME2(C, A, B, 16, parameter, postParameters, bias, k, b);
return;
}
static void Sme2MNNPackForMatMul_B(float* destC, const float* sourceC, size_t h, size_t kernelsize, size_t ic, bool transpose) {
// src: [h, kernelsize, ic]
// dst: [h/hp, kernelsize, ic/lp, hp, lp]
auto dest = (int32_t*)destC;
auto source = (int32_t*)sourceC;
int LP = 1;
int HP = 64;
auto l = kernelsize * ic;
memset(dest, 0, ROUND_UP(h, HP) * ROUND_UP(ic, LP) * kernelsize * 4);
auto stride0 = kernelsize * ROUND_UP(ic, LP) * HP;
auto stride1 = ROUND_UP(ic, LP) * HP;
auto stride2 = HP * LP;
auto srcStride0 = l; // [h,l]->[hu,lu,hp,lp]
auto srcStride1 = 1;
if (!transpose) { // [l,h]->[hu,lu,hp,lp]
srcStride0 = 1;
srcStride1 = h;
}
for (int y = 0; y < h; ++y) {
auto yHu = y / HP;
auto yHp = y % HP;
for (int k = 0; k < kernelsize; ++k) {
for (int x = 0; x < ic; ++x) {
auto xLu = x / LP;
auto xLp = x % LP;
dest[yHu * stride0 + k * stride1 + xLu * stride2 + yHp * LP + xLp] = source[y * srcStride0 + (x + k * ic) * srcStride1];
}
}
}
}
static void Sme2MNNPackForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el) {
const int32_t infosme2[4] = {info[0], info[1], 16, info[3]};
MNNPackC4ForMatMul_A(destOrigin, sourceGroup, infosme2, el);
return;
}
#endif
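
With the pack mode above (eP = 16, lP = 1, hP = 64), the buffer that Sme2MNNPackForMatMul_B fills holds ROUND_UP(h, 64) * kernelsize * ic floats, matching its memset; a small sizing sketch (helper name is illustrative):

// illustrative: bytes needed for the SME2 packed-B layout [h/hp, kernelsize, ic/lp, hp, lp]
static size_t sme2PackedBBytes(size_t h, size_t kernelsize, size_t ic) {
    const size_t HP = 64, LP = 1;                 // values from SME2MNNGetMatMulPackMode
    size_t hUp = (h + HP - 1) / HP * HP;          // ROUND_UP(h, HP)
    size_t icUp = (ic + LP - 1) / LP * LP;        // ROUND_UP(ic, LP), a no-op for LP = 1
    return hUp * kernelsize * icUp * sizeof(float);
}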
namespace MNN {
static CoreFunctions* gCoreFunction = nullptr;
@ -3742,19 +4026,19 @@ void MNNCoreFunctionInit() {
gCoreFunction->supportFp16arith = gCPUInfo.fp16arith;
gCoreFunction->supportSDot = gCPUInfo.dot;
gCoreFunction->supportI8mm = gCPUInfo.i8mm;
gCoreFunction->supportSME2 = gCPUInfo.sme2;
gCoreFunction->MNNSumByAxisLForMatmul_A = MNNSumByAxisLForMatmul_A;
gCoreFunction->MNNReorderWeightInt4 = MNNReorderWeightInt4;
gCoreFunction->MNNSumWeightInt8 = MNNSumWeightInt8;
#ifdef __aarch64__
if (gCoreFunction->supportSDot) {
gCoreFunction->MNNReorderWeightInt4 = MNNReorderWeightInt4Arm82;
gCoreFunction->MNNSumWeightInt8 = MNNSumWeightInt8Arm82;
}
if (gCoreFunction->supportI8mm) {
gCoreFunction->MNNReorderWeightInt4 = MNNReorderWeightInt4Arm86;
gCoreFunction->MNNSumWeightInt8 = MNNSumWeightInt8Arm86;
}
if (gCoreFunction->supportSDot) {
gCoreFunction->MNNReorderWeightInt4 = MNNReorderWeightInt4Arm82;
gCoreFunction->MNNSumWeightInt8 = MNNSumWeightInt8Arm82;
}
if (gCoreFunction->supportI8mm) {
gCoreFunction->MNNReorderWeightInt4 = MNNReorderWeightInt4Arm86;
gCoreFunction->MNNSumWeightInt8 = MNNSumWeightInt8Arm86;
}
#endif
#ifdef MNN_CPU_WEIGHT_DEQUANT_GEMM
// Weight Dequant Gemm Kernels
@ -3778,6 +4062,42 @@ void MNNCoreFunctionInit() {
}
#endif
#endif
{ // backendMatmulRelatedFunctions
gCoreFunction->backendMatmulRelatedFunctions.MNNReorderWeightInt4 = gCoreFunction->MNNReorderWeightInt4;
gCoreFunction->backendMatmulRelatedFunctions.MNNSumWeightInt8 = gCoreFunction->MNNSumWeightInt8;
gCoreFunction->backendMatmulRelatedFunctions.MNNGeneralIm2Col = gCoreFunction->MNNGeneralIm2Col;
gCoreFunction->backendMatmulRelatedFunctions.MNNPackedMatMul = gCoreFunction->MNNPackedMatMul;
gCoreFunction->backendMatmulRelatedFunctions.MNNPackedMatMulRemain = gCoreFunction->MNNPackedMatMulRemain;
gCoreFunction->backendMatmulRelatedFunctions.MNNGetMatMulPackMode = gCoreFunction->MNNGetMatMulPackMode;
gCoreFunction->backendMatmulRelatedFunctions.MNNPackC4ForMatMul_A = gCoreFunction->MNNPackC4ForMatMul_A;
gCoreFunction->backendMatmulRelatedFunctions.MNNPackForMatMul_B = gCoreFunction->MNNPackForMatMul_B;
}
#ifdef __aarch64__
#ifdef MNN_SME2
if (gCoreFunction->supportSME2) {
gCoreFunction->MNNSumWeightInt8 = MNNSumWeightInt8Sme2;
gCoreFunction->MNNReorderWeightInt4 = MNNReorderWeightInt4Sme2;
gCoreFunction->MNNPackedMatMul = MNNPackedMatMulFP32_SME2;
gCoreFunction->MNNPackedMatMulRemain = MNNPackedMatMulRemainFP32_SME2;
gCoreFunction->MNNGetMatMulPackMode = SME2MNNGetMatMulPackMode;
gCoreFunction->MNNPackC4ForMatMul_A = Sme2MNNPackForMatMul_A;
gCoreFunction->MNNPackForMatMul_B = Sme2MNNPackForMatMul_B;
gCoreFunction->sme2MatmulRelatedFuncions.MNNSumWeightInt8 = MNNSumWeightInt8Sme2;
gCoreFunction->sme2MatmulRelatedFuncions.MNNReorderWeightInt4 = MNNReorderWeightInt4Sme2;
gCoreFunction->sme2MatmulRelatedFuncions.MNNPackedMatMul = MNNPackedMatMulFP32_SME2;
gCoreFunction->sme2MatmulRelatedFuncions.MNNPackedMatMulRemain = MNNPackedMatMulRemainFP32_SME2;
gCoreFunction->sme2MatmulRelatedFuncions.MNNGetMatMulPackMode = SME2MNNGetMatMulPackMode;
gCoreFunction->sme2MatmulRelatedFuncions.MNNPackC4ForMatMul_A = Sme2MNNPackForMatMul_A;
gCoreFunction->sme2MatmulRelatedFuncions.MNNPackForMatMul_B = Sme2MNNPackForMatMul_B;
#ifdef MNN_LOW_MEMORY
gCoreFunction->MNNGeneralIm2Col = MNNGeneralIm2col_Fp32Sme2;
gCoreFunction->sme2MatmulRelatedFuncions.MNNGeneralIm2Col = MNNGeneralIm2col_Fp32Sme2;
#endif // MNN_LOW_MEMORY
}
#endif // MNN_SME2
#endif // __aarch64__
MNNCoreInt8FunctionInit();
MNNFunctionInit();
}

View File

@ -19,10 +19,11 @@
#include "backend/cpu/compute/Int8FunctionsOpt.h"
extern "C" {
#ifdef MNN_LOW_MEMORY
#ifdef __aarch64__
#ifdef MNN_LOW_MEMORY
void MNNGeneralIm2col_Fp32Arm82(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack);
void MNNGeneralIm2col_Fp32Arm86(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack);
void MNNGeneralIm2col_Fp32Sme2(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack);
void MNNLocalMinMaxFP32_Pack4(float* dstMin, float* dstMax, const float* source, size_t blockNum, size_t blockLU, size_t EP, size_t LP, size_t loadDstBuffer);
void MNNLocalMinMaxFP32_Pack8(float* dstMin, float* dstMax, const float* source, size_t blockNum, size_t blockLU, size_t EP, size_t LP, size_t loadDstBuffer);
void MNNDynamicQuantFP32_Pack4(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, const float* bias, size_t pack);
@ -31,6 +32,10 @@ void MNNAbsMaxFP32_Pack4(const float* source, float* absmax, size_t src_depth_qu
void MNNAbsMaxFP32_Pack8(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack);
void MNNQuantScaleFP32(float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch);
void MNNDynamicUpdateConvBiasScale(float* newbias, float* oldbias, float* weightKernelSum, float* inputZero, size_t ocQuad);
#endif
#ifdef MNN_SME2
void MNNPackedMatMulRemainFP32_SME2(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
#endif
#endif
void MNNFp32ToFp8(uint8_t* dst, const float* src, size_t size);
@ -134,7 +139,7 @@ el: number * 4
*/
void MNNPackC4ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el);
void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose);
void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose);
// parameters: e, l, h, CStride, AStride, BStride
void MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
@ -197,8 +202,10 @@ struct SumByAxisParams {
#ifdef __aarch64__
void MNNPermuteSumWeightInt4Arm86(uint8_t* dest, uint8_t* source, size_t outside, size_t inside, float* kernlesum);
void MNNPermuteSumWeightInt4Arm82(uint8_t* dest, uint8_t* source, size_t outside, size_t inside, float* kernlesum);
void MNNPermuteSumWeightInt4Sme2(uint8_t* dest, uint8_t* source, size_t outside, size_t inside, float* kernlesum);
void MNNSumWeightInt8Arm86(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP);
void MNNSumWeightInt8Arm82(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP);
void MNNSumWeightInt8Sme2(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP);
#endif
}
@ -211,6 +218,28 @@ typedef void(*MNNBinaryExecInt8)(int8_t* outputRaw, const int8_t* inputRaw0, con
constexpr int InputTileMax = 14; // same value from DynamicGemm.h, cannot include from different backend code.
namespace MNN {
struct MatmulRelatedFunctions {
// coreFunctions
void (*MNNSumWeightInt8)(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP) = nullptr;
void (*MNNReorderWeightInt4)(uint8_t* dest, const uint8_t* source, int32_t* shape, size_t size, float* kernelsum) = nullptr;
void (*MNNPackedMatMul)(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b) = nullptr;
void (*MNNPackedMatMulRemain)(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b) = nullptr;
void (*MNNGetMatMulPackMode)(int* eP, int* lP, int* hP) = nullptr;
void (*MNNPackC4ForMatMul_A)(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el) = nullptr;
void (*MNNPackForMatMul_B)(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose) = nullptr;
void(*MNNGeneralIm2Col)(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack) = nullptr;
// int8CoreFunctions
void(*Int8GemmKernel)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount) = nullptr;
void(*Int8GemmKernelFast)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount) = nullptr;
void(*MNNGetGemmUnit)(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) = nullptr;
void(*MNNPackC4Int8ForMatMul_A)(int8_t* destOrigin, int8_t const** sourceGroup, const int32_t* info, const int32_t* el) = nullptr;
void(*MNNGemmInt8AddBiasScale_Unit_FP16)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount) = nullptr;
void(*MNNGemmInt8AddBiasScale_w4_Unit_FP16)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount) = nullptr;
void(*Int8GemmKernel_W4)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount) = nullptr;
void(*MNNSumByAxisLForMatmul_A)(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams) = nullptr;
};
struct CoreFunctions {
// fp8
void (*MNNFp32ToFp8)(uint8_t* dst, const float* src, size_t size);
@ -222,11 +251,12 @@ struct CoreFunctions {
bool supportFp16arith = false;
bool supportSDot = false;
bool supportI8mm = false;
bool supportSME2 = false;
/**MatMul Pack and Functions*/
void(*MNNGetMatMulPackMode)(int* eP, int *lP, int* hP);
void(*MNNGetSparseMatMulPackMode)(int* eP, int *lP, int* hP);
void(*MNNPackC4ForMatMul_A)(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el);
void(*MNNPackForMatMul_B)(float* dest, const float* source, size_t h, size_t l, bool transpose);
void(*MNNPackForMatMul_B)(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose);
void(*MNNGeneralIm2Col)(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack);
// parameters: e, l, h, CStride, AStride, BStride
void(*MNNPackedMatMul)(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
@ -365,6 +395,9 @@ struct CoreFunctions {
void(*MNNSumByAxisLForMatmul_A)(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams);
void(*MNNReorderWeightInt4)(uint8_t* dest, const uint8_t* source, int32_t* shape, size_t size, float* kernelsum);
void(*MNNSumWeightInt8)(float* kernlesum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP);
MatmulRelatedFunctions backendMatmulRelatedFunctions;
MatmulRelatedFunctions sme2MatmulRelatedFuncions;
};
void MNNCoreFunctionInit();
CoreFunctions* MNNGetCoreFunctions();
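
The two MatmulRelatedFunctions members added to CoreFunctions let a backend swap the whole matmul kernel set at once; how that choice might look from the caller side is sketched below (illustrative only, the real selection lives in CPUBackend and is not part of this header):

// illustrative: choosing between the default and the SME2 matmul function tables
static const MNN::MatmulRelatedFunctions* pickMatmulFunctions(const MNN::CoreFunctions* core,
                                                              bool allowSme2) {
    if (allowSme2 && core->supportSME2) {
        return &core->sme2MatmulRelatedFuncions;      // member name as declared above
    }
    return &core->backendMatmulRelatedFunctions;
}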

View File

@ -15,6 +15,7 @@
#include "core/Concurrency.h"
#include "core/TensorUtils.hpp"
#define QUANT_INFO_BYTES 4
namespace MNN {
@ -175,10 +176,10 @@ void ConvInt8TiledExecutor::packWeightAndQuantInfo(int8_t* dstbuffer, const int8
auto huPtr = dstbuffer + hU * blockNum * (stride1 + 2 * UNIT * infoBytes);
int scaleCount = ALIMIN(ocUp4 - hU * UNIT, UNIT);
for (int bl = 0; bl < blockNum; ++bl) {
auto blockPtr = huPtr + bl * (stride1 + 2 * scaleCount * infoBytes);
auto blockPtr = huPtr + bl * (stride1 + 2 * UNIT * infoBytes);
memcpy(blockPtr, src0 + bl * stride1 + hU * stride0, stride1);
memcpy(blockPtr + stride1, src1 + (bl * ocUp4 + hU * UNIT) * infoBytes, scaleCount * infoBytes);
memcpy(blockPtr + stride1 + scaleCount * infoBytes, src2 + (bl * ocUp4 + hU * UNIT) * infoBytes, scaleCount * infoBytes);
memcpy(blockPtr + stride1 + UNIT * infoBytes, src2 + (bl * ocUp4 + hU * UNIT) * infoBytes, scaleCount * infoBytes);
}
}
}
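
The switch from scaleCount to UNIT above gives every (hU, block) cell in the packed buffer a fixed size, so the scale and bias offsets no longer depend on how many output channels the last cell actually holds; the resulting per-cell layout (byte offsets, derived from the memcpy calls above):

// per (hU, block) cell after this change: cellBytes = stride1 + 2 * UNIT * infoBytes
//   weights : [0, stride1)
//   scales  : [stride1, stride1 + UNIT * infoBytes)       only scaleCount entries are valid
//   biases  : [stride1 + UNIT * infoBytes, cellBytes)     only scaleCount entries are valid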
@ -209,15 +210,14 @@ static void _computeReorderQuantInfo(std::shared_ptr<CPUConvolution::ResourceInt
auto weightKernelSum = resource->mWeightKernelSum->host<float>();
::memset(weightKernelSum, 0, resource->mWeightKernelSum->size());
bool blockQuantInput = (resource->mWeightKernelSum->length(0) / QUANT_INFO_BYTES == ocUp4) ? false : true;
bool blockQuantInput = (resource->mWeightKernelSum->length(0) / QUANT_INFO_BYTES == ocUpHp) ? false : true;
int ocDiv4 = UP_DIV(outputCount, pack);
// dstkernelsum: [hU,blocknum,min(hP, pack)]
// resource->mWeightKernelSum: [hU,blocknum,hP]
if (quantCommon->asymmetric) {
for (int i = 0; i < outputCount; ++i) {
float accum = 0.f;
auto ocOutside = i / HP;
auto ocInside = i % HP;
int remain = ALIMIN(HP, ocUp4 - ocOutside * HP);
for (int j = 0; j < blockNum; ++j) {
int index = i * blockNum + j;
int srcSumIndex = ocOutside * blockNum * HP + j * HP + ocInside; // ikernelsum: [hU,blocknum,hP]
@ -229,7 +229,7 @@ static void _computeReorderQuantInfo(std::shared_ptr<CPUConvolution::ResourceInt
accum += ((ikernelSum[srcSumIndex] - blockSize * 8)* quanInfoPtr[2 * index + 1] + blockSize * quanInfoPtr[2 * index]);
}
if (blockQuantInput) {
int dstSumIndex = ocOutside * blockNum * HP + j * remain + ocInside;
int dstSumIndex = ocOutside * blockNum * HP + j * HP + ocInside;
weightKernelSum[dstSumIndex] = accum;
accum = 0;
}
@ -243,7 +243,6 @@ static void _computeReorderQuantInfo(std::shared_ptr<CPUConvolution::ResourceInt
float accum = 0.f;
auto ocOutside = i / HP;
auto ocInside = i % HP;
int remain = ALIMIN(HP, ocUp4 - ocOutside * HP);
for (int j = 0; j < blockNum; ++j) {
int index = i * blockNum + j;
int srcSumIndex = ocOutside * blockNum * HP + j * HP + ocInside; // ikernelsum: [hU,blocknum,hP]
@ -255,7 +254,7 @@ static void _computeReorderQuantInfo(std::shared_ptr<CPUConvolution::ResourceInt
accum += ((ikernelSum[srcSumIndex] - blockSize * 8) * quanInfoPtr[index]);
}
if (blockQuantInput) {
int dstSumIndex = ocOutside * blockNum * HP + j * remain + ocInside;
int dstSumIndex = ocOutside * blockNum * HP + j * HP + ocInside;
weightKernelSum[dstSumIndex] = accum;
accum = 0;
}
@ -286,12 +285,17 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O
// backend info
auto core = static_cast<CPUBackend*>(backend)->int8Functions();
auto gcore = static_cast<CPUBackend*>(backend)->functions();
const int threads = static_cast<CPUBackend*>(backend)->threadNumber();
mRelatedFunctions = *(static_cast<CPUBackend*>(backend)->int8GemmFunctions());
int UNIT, SRC_UNIT, DST_XUNIT;
core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
mRelatedFunctions.MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
int pack = gcore->pack;
// compute info
int ocUp4 = ROUND_UP(oc, pack);
int ocUpHp = ROUND_UP(oc, ALIMAX(UNIT, pack));
int lU = UP_DIV(ic / blockNum, SRC_UNIT) * kernelCount;
int scaleSize = ocUp4 * blockNum;
std::vector<int> shape = {blockNum, UP_DIV(oc, UNIT), lU, UNIT, SRC_UNIT};
@ -316,15 +320,15 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O
}
}
// buffer allocate
auto quantlen = 2 * blockNum * ROUND_UP(oc, pack) * QUANT_INFO_BYTES;
auto quantlen = 2 * blockNum * ROUND_UP(oc, UNIT) * QUANT_INFO_BYTES;
auto weightlen = shape[0] * shape[1] * shape[2] * shape[3] * shape[4];
mResourceInt8->mWeightInt8.reset(Tensor::createDevice<uint8_t>({weightlen + quantlen}));
mResourceInt8->mOriginBias.reset(Tensor::createDevice<int32_t>({ocUp4})); // float
mResourceInt8->mOriginBias.reset(Tensor::createDevice<int32_t>({ocUpHp})); // float
auto dynamicOption = static_cast<CPUBackend*>(backend)->getRuntime()->hint().dynamicQuantOption; // input ic block quant.
if (dynamicOption != 2) {
mResourceInt8->mWeightKernelSum.reset(Tensor::createDevice<uint8_t>({QUANT_INFO_BYTES * ocUp4}));
mResourceInt8->mWeightKernelSum.reset(Tensor::createDevice<uint8_t>({QUANT_INFO_BYTES * ocUpHp}));
} else {
mResourceInt8->mWeightKernelSum.reset(Tensor::createDevice<uint8_t>({blockNum * QUANT_INFO_BYTES * ocUp4}));
mResourceInt8->mWeightKernelSum.reset(Tensor::createDevice<uint8_t>({blockNum * QUANT_INFO_BYTES * ocUpHp}));
}
auto res = backend->onAcquireBuffer(mResourceInt8->mOriginBias.get(), Backend::STATIC);
@ -341,7 +345,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O
}
// read weight, weight's scale&bias, convolution bias
::memset(mResourceInt8->mOriginBias->host<float>(), 0, ocUp4 * sizeof(float));
::memset(mResourceInt8->mOriginBias->host<float>(), 0, ocUpHp * sizeof(float));
if (!isDynamicQuant) {
mResourceInt8->mDynamicQuant = false;
@ -375,7 +379,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O
// reorder weight&scale&bias
int bytes = 4;
auto resourceInt8 = mResourceInt8;
auto reorderFunction = [weightNotPacked, weightlen, oc, ic, kernelCount, UNIT, SRC_UNIT, shape, resourceInt8, bytes, L, ocUp4, quantlen, scaleAndBias]()
auto reorderFunction = [weightNotPacked, weightlen, oc, ic, kernelCount, UNIT, SRC_UNIT, shape, resourceInt8, bytes, L, ocUp4, scaleAndBias, pack]()
{
AutoStorage<uint8_t> weightReordered(weightlen);
if (!weightReordered.get()) {
@ -384,22 +388,22 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O
}
int32_t info[6] = {1, oc, ic, kernelCount, UNIT, SRC_UNIT};
ConvInt8TiledExecutor::reorderWeight((uint8_t*)weightReordered.get(), (uint8_t*)weightNotPacked, info, 0);
int32_t params[6] = {shape[0], shape[1], shape[2], shape[3], shape[4], quantlen/ (2 * QUANT_INFO_BYTES * 1)};
int32_t params[6] = {shape[0], shape[1], shape[2], shape[3], shape[4], ROUND_UP(oc, pack)};
ConvInt8TiledExecutor::packWeightAndQuantInfo(resourceInt8->mWeightInt8->host<int8_t>(), (int8_t*)weightReordered.get(), (int8_t*)scaleAndBias.get(), params, QUANT_INFO_BYTES);
return 0;
};
static_cast<CPUBackend*>(backend)->enqueueTask(reorderFunction);
// gemmInt8 kernel
mGemmKernel = core->Int8GemmKernel;
mGemmKernel = mRelatedFunctions.Int8GemmKernel;
#ifdef MNN_USE_SSE
int actBits = convOp->symmetricQuan()->nbits();
if (actBits <= 7) {
mGemmKernel = core->Int8GemmKernelFast;
mGemmKernel = mRelatedFunctions.Int8GemmKernelFast;
}
#else
if(convOp->symmetricQuan()->method() == QuantizeAlgo_OVERFLOW_AWARE){
mGemmKernel = core->Int8GemmKernelFast;
mGemmKernel = mRelatedFunctions.Int8GemmKernelFast;
}
#endif
@ -409,11 +413,12 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O
// dynamic quant
bool directReadInt4weight = (kernelCount == 1 && ROUND_UP(oc, UNIT) == oc && ROUND_UP(ic, SRC_UNIT) == ic);
auto target = mResourceInt8;
auto funcs = mRelatedFunctions;
// Save bias
if (convOp->bias()) {
::memcpy(mResourceInt8->mOriginBias->host<float>(), convOp->bias()->data(), oc * sizeof(float));
}
auto function = [shape, UNIT, SRC_UNIT, quanCommon, weightlen, scaleSize, directReadInt4weight, blockNum, ic, oc, quantlen, kernelCount, pack, convOp, gcore, target]() -> int {
auto function = [funcs, shape, UNIT, SRC_UNIT, DST_XUNIT, quanCommon, weightlen, scaleSize, directReadInt4weight, blockNum, ic, oc, kernelCount, pack, convOp, gcore, target]() -> int {
auto sh = shape;
AutoStorage<int8_t> weightReordered(weightlen);
AutoStorage<int8_t> reorderedQuantInfo(2 * scaleSize * QUANT_INFO_BYTES);
@ -427,7 +432,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O
auto srcPtr = (uint8_t*)quanCommon->weight.get();
auto dstPtr = (uint8_t*)weightReordered.get();
::memset(dstPtr, 0, weightlen);
gcore->MNNReorderWeightInt4(dstPtr, srcPtr, sh.data(), sh.size(), (float*)kernelsum.get());
funcs.MNNReorderWeightInt4(dstPtr, srcPtr, sh.data(), sh.size(), (float*)kernelsum.get());
} else { // int4 weight but oc/ic not packed
auto weightLength = quanCommon->weight.size();
int blocksize = ic * kernelCount / blockNum;
@ -452,7 +457,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O
MNN_ERROR("Weight reorder memory not enough!\n");
return -1;
}
reorderWeight(packedInt8weight.get(), (uint8_t*)tmpWeight.data(), info, 0, (float*)kernelsum.get(), gcore->MNNSumWeightInt8);
reorderWeight(packedInt8weight.get(), (uint8_t*)tmpWeight.data(), info, 0, (float*)kernelsum.get(), funcs.MNNSumWeightInt8);
// pack two int4 to int8
int leng = weightlen * 2;
auto srcint4Ptr = (uint8_t*)packedInt8weight.get();
@ -463,14 +468,21 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O
auto src0 = srcint4Ptr + i * permuteUnit;
auto dst0 = dstint4Ptr + i * halfPermuteStride;
for (int j = 0; j < halfPermuteStride; ++j) {
int s0 = src0[j];
int s1 = src0[j + halfPermuteStride];
int d = (s0) * 16 + (s1);
int s0, s1, d;
if (UNIT == 16 && SRC_UNIT == 4 && DST_XUNIT == 16) { // SME2
s0 = src0[2 * j + 0];
s1 = src0[2 * j + 1];
d = s0 + (s1) * 16;
} else {
s0 = src0[j];
s1 = src0[j + halfPermuteStride];
d = (s0) * 16 + (s1);
}
dst0[j] = d;
}
}
} else { // int8 weight
reorderWeight((uint8_t*)weightReordered.get(), (uint8_t*)quanCommon->weight.get(), info, 0, (float*)kernelsum.get(), gcore->MNNSumWeightInt8);
reorderWeight((uint8_t*)weightReordered.get(), (uint8_t*)quanCommon->weight.get(), info, 0, (float*)kernelsum.get(), funcs.MNNSumWeightInt8);
}
}
/* 2. compute and order dequant scale&bias */
@ -480,7 +492,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O
}
_computeReorderQuantInfo(target, quanCommon, oc, kernelCount * ic, pack, reorderedQuantInfo, (float*)kernelsum.get(), UNIT, notConvertInt4ToInt8);
/* 3. put weight and quantInfo together */
int32_t params[6] = {shape[0], shape[1], shape[2], shape[3], shape[4], quantlen / (2 * QUANT_INFO_BYTES * blockNum)};
int32_t params[6] = {shape[0], shape[1], shape[2], shape[3], shape[4], ROUND_UP(oc, pack)};
ConvInt8TiledExecutor::packWeightAndQuantInfo(target->mWeightInt8->host<int8_t>(), (int8_t*)weightReordered.get(), reorderedQuantInfo.get(), params, QUANT_INFO_BYTES);
return 0;
@ -508,16 +520,11 @@ bool DenseConvInt8TiledExecutor::onClone(Backend* bn, const Op* op, Execution**
return true;
}
void DenseConvInt8TiledExecutor::getPackParameter(int* Unit, int* srcUnit, int* DestUnit, const CoreInt8Functions* core) {
core->MNNGetGemmUnit(Unit, srcUnit, DestUnit);
}
ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
// Initialize.
mUseBatchQuan = false;
mIm2ColBasedInt8 = true;
auto option = static_cast<CPUBackend*>(backend())->getRuntime()->hint().dynamicQuantOption;
int batch = inputs[0]->batch();
int inC = inputs[0]->channel();
@ -526,8 +533,12 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector<Tensor*>& input
auto planeSize = output->width() * output->height() * output->batch();
auto core = static_cast<CPUBackend*>(backend())->int8Functions();
auto gcore =static_cast<CPUBackend*>(backend())->functions();
const int threads = static_cast<CPUBackend*>(backend())->threadNumber();
mRelatedFunctions = *(static_cast<CPUBackend*>(backend())->int8GemmFunctions());
int UNIT, SRC_UNIT, DST_XUNIT;
core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
mRelatedFunctions.MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
int kernelCount = mCommon->kernelY() * mCommon->kernelX();
bool fastway = (kernelCount == 1) && (output->width() == inputs[0]->width()) && (output->height() == inputs[0]->height()) && (mCommon->strideX() * mCommon->strideY()) == 1;
if (inputPlane > 1) {
@ -563,9 +574,9 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector<Tensor*>& input
mIm2ColBasedInt8 = true;
mUseBatchQuan = false;
}
ConvolutionTiledExecutor::setIm2ColParameter(mIm2ColParamter, mCommon, inputs[0], outputs[0], mPadX, mPadY, gcore, core);
int matmulUnits[3] = {UNIT, SRC_UNIT, DST_XUNIT};
ConvolutionTiledExecutor::setIm2ColParameter(mIm2ColParamter, mCommon, inputs[0], outputs[0], mPadX, mPadY, gcore, core, gcore->pack, matmulUnits);
// input scale buffer
const int threads = static_cast<CPUBackend*>(backend())->threadNumber();
// Im2col info
int im2colBytes = 1;
@ -668,22 +679,22 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector<Tensor*>& input
#ifdef MNN_LOW_MEMORY
{ // Dynamic Quant kernels
mGemmKernel = core->Int8GemmKernel;
mGemmKernel = mRelatedFunctions.Int8GemmKernel;
if (mResourceInt8->mActBits == 4) {
mGemmKernel = core->Int8GemmKernel_W4;
mGemmKernel = mRelatedFunctions.Int8GemmKernel_W4;
}
mQuantFunc = core->MNNFloat2Int8;
if (gcore->bytes == 2 && gcore->pack == 8) {
mGemmKernel = core->MNNGemmInt8AddBiasScale_Unit_FP16;
mGemmKernel = mRelatedFunctions.MNNGemmInt8AddBiasScale_Unit_FP16;
if (mResourceInt8->mActBits == 4) {
mGemmKernel = core->MNNGemmInt8AddBiasScale_w4_Unit_FP16;
mGemmKernel = mRelatedFunctions.MNNGemmInt8AddBiasScale_w4_Unit_FP16;
}
mQuantFunc = core->DynamicQuanInput_ARM82;
mQuantAndReorderFunc = core->DynamicQuanInputAndReorder_ARM82;
}
// A axisSum kernel
mSumByAxisLFunc = gcore->MNNSumByAxisLForMatmul_A;
mSumByAxisLFunc = mRelatedFunctions.MNNSumByAxisLForMatmul_A;
}
mInputBlockNum = (option == 2) ? mBlockNum : 1;
@ -804,8 +815,8 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inpu
auto dynamicOption = static_cast<CPUBackend*>(backend())->getRuntime()->hint().dynamicQuantOption;
int UNIT, SRC_UNIT, DST_XUNIT;
core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
auto blitProc = core->MNNPackC4Int8ForMatMul_A;
mRelatedFunctions.MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
auto blitProc = mRelatedFunctions.MNNPackC4Int8ForMatMul_A;
const int plane = output->batch() * mIm2ColParamter.oh * mIm2ColParamter.ow;
const int batch = input->batch();
const int PackUnit = gcore->pack;
@ -960,6 +971,7 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inpu
sumParams.kernelxy = kxky;
sumParams.LU = UP_DIV(inputchannel, SRC_UNIT);
sumParams.inputBlock = (mInputBlockNum > 1) ? 1 : 0;
std::vector<float> fakeInputScales(DST_XUNIT, 1.f);
auto tileSplitFunction = [&](int tId, int eStartIndex, int eEndIndex, int estep) {
auto ocDivThread = ocDiv4;
@ -979,20 +991,30 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inpu
inputBias = inputScale + mBatchQuantInfo->stride(0) / 2;
ptrInputBias = inputBias;
}
} else {
inputScale = (uint8_t*)fakeInputScales.data();
ptrInputScale = inputScale;
}
if (mBlockNum > 1) {
accumbuff = reinterpret_cast<float*>(mAccumBuffer->host<int8_t>() + tId * mAccumBuffer->stride(0) * sizeof(int32_t));
}
float* ptrY = nullptr;
if ((dstBytes != 1)) {
if (dstBytes != 1) {
ptrY = mResourceInt8->mWeightKernelSum->host<float>();
}
QuanPostTreatParameters quanParam;
quanParam.blockNum = mBlockNum;
quanParam.weightKernelSum = ptrY;
int32_t indices[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
quanParam.indices = indices;
if (dstBytes != 1) {
quanParam.useInt8 = 0;
quanParam.fp32minmax = reluPtr;
#ifdef MNN_USE_SSE
if (!mBatchQuantInfo.get()) {
quanParam.weightKernelSum = nullptr;
}
#endif
} else {
quanParam.maxValue = mMutableResource->mClampMax;
if (mResourceInt8->mRelu) {
@ -1045,7 +1067,7 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inpu
memset(im2colDst, 0, mTempIm2ColBuffer->stride(0));
}
info[2] = realDstCount;
gcore->MNNGeneralIm2Col((float*)im2colDst, (float const**)srcPtr, info, el, SRC_UNIT, PackUnit); // im2colDst: [lu, realDstCount, lp]
mRelatedFunctions.MNNGeneralIm2Col((float*)im2colDst, (float const**)srcPtr, info, el, SRC_UNIT, PackUnit); // im2colDst: [lu, realDstCount, lp]
}
ptrInputScale = mBatchQuantInfo->host<uint8_t>() + tId * mBatchQuantInfo->stride(0);
if (dynamicOption == 2) {
@ -1079,7 +1101,7 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inpu
#endif
if (mResourceInt8->mWeightAsymmetricQuant) {
MNN_ASSERT(mBatchQuantInfo.get() && mBatchQuantInfo->host<float>());
gcore->MNNSumByAxisLForMatmul_A(xKernelSumPtrTid, im2colDst, (float*)ptrInputScale, realDstCount, sumParams);
mRelatedFunctions.MNNSumByAxisLForMatmul_A(xKernelSumPtrTid, im2colDst, (float*)ptrInputScale, realDstCount, sumParams);
} else {
memset(xKernelSumPtrTid, 0, mBlockNum * DST_XUNIT * mIm2ColCount * QUANT_INFO_BYTES);
}
@ -1094,7 +1116,6 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inpu
}
quanParam.srcKernelSum = ptrX;
mGemmKernel(outputInTilePtr, im2colDst, weightPtrTid, blockL, dstZStep * dstBytes, ocDivThread, &quanParam, step);
ptrX += (step * mBlockNum);
realDstCount-=step;
outputInTilePtr += DST_XUNIT * PackUnit * dstBytes;
@ -1155,7 +1176,7 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inpu
info[0] = number;
info[2] = work;
if (number > 0) { // im2col
gcore->MNNGeneralIm2Col((float*)im2colDstTmp, (float const**)srcPtr, info, el, SRC_UNIT, PackUnit); // im2colDst: [lu, realDstCount, lp]
mRelatedFunctions.MNNGeneralIm2Col((float*)im2colDstTmp, (float const**)srcPtr, info, el, SRC_UNIT, PackUnit); // im2colDst: [lu, realDstCount, lp]
}
if (mUseBatchQuan || dynamicOption == 2) {
if (dynamicOption == 2) {
@ -1192,7 +1213,7 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inpu
#endif
if (mResourceInt8->mWeightAsymmetricQuant) {
MNN_ASSERT(mBatchQuantInfo.get() && mBatchQuantInfo->host<float>());
gcore->MNNSumByAxisLForMatmul_A(xKernelSumPtr, im2colDst, mBatchQuantInfo->host<float>(), plane, sumParams);
mRelatedFunctions.MNNSumByAxisLForMatmul_A(xKernelSumPtr, im2colDst, mBatchQuantInfo->host<float>(), plane, sumParams);
} else {
memset(xKernelSumPtr, 0, mTileCount * mBlockNum * DST_XUNIT * mIm2ColCount * QUANT_INFO_BYTES);
}
@ -1211,9 +1232,16 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inpu
quanParam.blockNum = mBlockNum;
quanParam.weightKernelSum = ptrY;
quanParam.biasFloat = reinterpret_cast<float*>(biasPtr + ocIndex * 4);
int32_t indices[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
quanParam.indices = indices;
if (dstBytes != 1) {
quanParam.useInt8 = 0;
quanParam.fp32minmax = reluPtr;
#ifdef MNN_USE_SSE
if (!mBatchQuantInfo.get()) {
quanParam.weightKernelSum = nullptr;
}
#endif
} else {
quanParam.maxValue = mMutableResource->mClampMax;
if (mResourceInt8->mRelu) {
@ -1230,6 +1258,8 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inpu
if (dynamicOption == 2) {
inputBias = inputScale + mInputBlockNum * plane * QUANT_INFO_BYTES;
}
} else {
inputScale = (uint8_t*)fakeInputScales.data();
}
if (mBlockNum > 1) {
accumbuff = reinterpret_cast<float*>(mAccumBuffer->host<int8_t>() + tId * mAccumBuffer->stride(0) * sizeof(int32_t));
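The running theme of this file's changes: gemm, pack, and sum kernels are no longer fetched from the shared core/gcore tables at each call site but from mRelatedFunctions, a MatmulRelatedFunctions copy selected per backend (see the registration blocks further below). A compact standalone sketch of that snapshot-and-dispatch pattern, with placeholder kernels and values wherever the diff does not show them:

#include <cstdio>

// Reduced stand-in for MatmulRelatedFunctions: only the one entry used below.
struct GemmTable {
    void (*getGemmUnit)(int* unit, int* srcUnit, int* dstXUnit);
};

static void getGemmUnitSme2(int* u, int* s, int* d)    { *u = 16; *s = 4; *d = 16; } // values from MNNGetGemmUnitSme2
static void getGemmUnitDefault(int* u, int* s, int* d) { *u = 4;  *s = 4; *d = 4;  } // placeholder values only

struct ExecutorSketch {
    GemmTable mRelatedFunctions;                       // snapshot taken once at construction / resize
    explicit ExecutorSketch(const GemmTable& table) : mRelatedFunctions(table) {}
    void onResize() const {
        int UNIT, SRC_UNIT, DST_XUNIT;
        mRelatedFunctions.getGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT); // mirrors the call sites above
        std::printf("UNIT=%d SRC_UNIT=%d DST_XUNIT=%d\n", UNIT, SRC_UNIT, DST_XUNIT);
    }
};

int main() {
    ExecutorSketch{GemmTable{getGemmUnitDefault}}.onResize();
    ExecutorSketch{GemmTable{getGemmUnitSme2}}.onResize();
    return 0;
}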

View File

@ -23,7 +23,6 @@ public:
virtual ~ConvInt8TiledExecutor();
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
virtual void getPackParameter(int* Unit, int* SrcUnit, int* DestUnit, const CoreInt8Functions* core) = 0;
static void packWeightAndQuantInfo(int8_t* dstbuffer, const int8_t* weight, const int8_t* quantInfo, int32_t* info, int infoBytes = 4);
static void reorderWeight(uint8_t* dst, const uint8_t* src, int32_t* info, int32_t initval = 0, float* kernelsum = nullptr, weightSummerFuncion summerFunc = nullptr);
static void initializeConvInt8QuantInfo(std::shared_ptr<CPUConvolution::ResourceInt8>& resourceInt8, const Convolution2D* conv2D);
@ -56,7 +55,6 @@ public:
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
void getPackParameter(int* Unit, int* SrcUnit, int* DestUnit, const CoreInt8Functions* core) override;
private:
DenseConvInt8TiledExecutor(Backend* backend, const Op* op, const DenseConvInt8TiledExecutor& exe);
@ -84,6 +82,7 @@ private:
bool mIm2ColBasedInt8;
int mSizeInputBlockQuant;
bool mToFuseInputbias2Bias;
MatmulRelatedFunctions mRelatedFunctions;
};
} // namespace MNN

View File

@ -519,6 +519,7 @@ ErrorCode ConvInt8Winograd::WinoExecution::onExecute(const std::vector<Tensor *>
auto weight = mWinoResource->weight->host<int8_t>();
std::vector<float> xkernelSum(DST_XUNIT, 0);
std::vector<float> wKernelSum(dc_4 * pack, 0);
std::vector<float> fakeInputScale(DST_XUNIT, 1.f);
std::vector<float> reluThred = {-std::numeric_limits<float>().max(), std::numeric_limits<float>().max()};
auto tFunction = [&](int tId) {
@ -560,7 +561,7 @@ ErrorCode ConvInt8Winograd::WinoExecution::onExecute(const std::vector<Tensor *>
quanParam.biasFloat = (mWinoResource->offsets->host<float>() + i * mWinoResource->offsets->stride(0));
quanParam.scale = mWinoResource->scales->host<float>() + i * dc_4 * pack;
quanParam.inputScale = nullptr;
quanParam.inputScale = fakeInputScale.data();
quanParam.bias = nullptr;
quanParam.blockNum = 1;
gemmFunc((int8_t*)_dstFloatPtr, _srcInt8Ptr, _weightInt8Ptr, mTempInputBuffer->length(2), xC * pack * sizeof(float), dc_4, &quanParam, DST_XUNIT);
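As in the dense path above (fakeInputScales), the winograd execution now hands the gemm kernel a vector of 1.0f instead of a null inputScale, so the kernel's per-tile input-scale multiply stays a no-op rather than a special case. A toy illustration of why the identity scale is safe (the epilogue formula is a typical int8 dequant, not MNN's exact kernel):

#include <cassert>
#include <cstdint>
#include <vector>

// Dequantize one accumulator the way an int8 gemm epilogue typically does:
// float = acc * weightScale * inputScale.
static float dequant(int32_t acc, float weightScale, float inputScale) {
    return static_cast<float>(acc) * weightScale * inputScale;
}

int main() {
    const int DST_XUNIT = 16;                        // example tile width
    std::vector<float> fakeInputScale(DST_XUNIT, 1.f);
    // Multiplying by 1.f reproduces the "no input scale" result exactly,
    // so one kernel path serves both the dynamically quantized and plain cases.
    assert(dequant(-42, 0.05f, fakeInputScale[0]) == static_cast<float>(-42) * 0.05f);
    return 0;
}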

View File

@ -52,10 +52,10 @@ Convolution1x1Strassen::Convolution1x1Strassen(const Convolution2DCommon *common
return;
}
core->MNNFp32ToLowp(originWeight, tempTensor->host<int16_t>(), outputCount * mSrcCount);
core->MNNPackForMatMul_B(mResource->mWeight->host<float>(), tempTensor->host<float>(), outputCount, mSrcCount, true);
core->MNNPackForMatMul_B(mResource->mWeight->host<float>(), tempTensor->host<float>(), outputCount, 1, mSrcCount, true);
b->onReleaseBuffer(tempTensor.get(), Backend::STATIC);
} else {
core->MNNPackForMatMul_B(mResource->mWeight->host<float>(), originWeight, outputCount, mSrcCount, true);
core->MNNPackForMatMul_B(mResource->mWeight->host<float>(), originWeight, outputCount, 1, mSrcCount, true);
}
}
Convolution1x1Strassen::Convolution1x1Strassen(std::shared_ptr<CPUConvolution::Resource> resource, const Convolution2DCommon *common, Backend* b) : CPUConvolution(common, b) {
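MNNPackForMatMul_B's reduce dimension is now split into (kernelsize, ic) at every call site; the x86 packers shown later simply recombine them as l = kernelsize * ic, while packers whose lP > 1 can pad ic per kernel position. A minimal adapter in that shape (placeholder packing body and names, not MNN's layout):

#include <cstddef>
#include <cstring>

// Placeholder flat packer: real implementations reorder B into an
// [UP_DIV(h, hP)] x [l] x [hP] tile layout; a plain copy is enough for the sketch.
static void packFlat(float* dest, const float* source, size_t h, size_t l) {
    std::memcpy(dest, source, h * l * sizeof(float));
}

// New-style entry point matching the updated call sites: the callee decides
// how (and whether) to pad ic for each of the kernelsize positions.
void packForMatMulB(float* dest, const float* source, size_t h,
                    size_t kernelsize, size_t ic, bool /*transpose*/) {
    const size_t l = kernelsize * ic;   // 1x1 call sites pass kernelsize == 1, so l == ic
    packFlat(dest, source, h, l);
}

int main() {
    const float src[2 * 3] = {1, 2, 3, 4, 5, 6};    // h = 2 rows, kernelsize = 1, ic = 3
    float dst[2 * 3] = {};
    packForMatMulB(dst, src, /*h=*/2, /*kernelsize=*/1, /*ic=*/3, /*transpose=*/true);
    return dst[5] == 6.f ? 0 : 1;
}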

View File

@ -114,7 +114,7 @@ ErrorCode ConvolutionPackFreeWinograd::onExecute(const std::vector<Tensor *> &in
std::vector<size_t> parameters(7);
parameters[0] = eRemain * bytes;
parameters[1] = input->channel();
parameters[1] = ROUND_UP(input->channel(), lPack);
parameters[2] = output->channel();
parameters[3] = ePack * pack * bytes;
parameters[4] = 0;

View File

@ -281,7 +281,7 @@ ErrorCode ConvolutionPackWinograd::onResize(const std::vector<Tensor *> &inputs,
}
int eRemain = (tFin-tSta) % ePack;
std::vector<size_t> parameters(6);
parameters[1] = input->channel();
parameters[1] = ROUND_UP(input->channel(), lPack);
parameters[2] = output->channel();
parameters[4] = 0;
parameters[5] = 0;

View File

@ -96,7 +96,8 @@ void ConvolutionTiledExecutor:: setIm2ColParameter(ConvolutionCommon::Im2ColPara
if (pack == 0) {
pack = floatCore->pack;
}
int EP, LP, HP;
floatCore->MNNGetMatMulPackMode(&EP, &LP, &HP);
const auto kernelCount = convCommon->kernelX() * convCommon->kernelY();
dstIm2ColParamter.dilateX = convCommon->dilateX();
@ -116,8 +117,8 @@ void ConvolutionTiledExecutor:: setIm2ColParameter(ConvolutionCommon::Im2ColPara
dstIm2ColParamter.srcZStep = input->stride(1) * pack * input->batch();
dstIm2ColParamter.srcYStep = input->stride(2) * pack;
dstIm2ColParamter.packCUnit = pack;
dstIm2ColParamter.ic = input->channel();
dstIm2ColParamter.icup4 = input->channel(); // for float im2col.
dstIm2ColParamter.ic = ROUND_UP(input->channel(), LP);
dstIm2ColParamter.icup4 = ROUND_UP(input->channel(), ALIMIN(LP, pack)); // for float im2col.
if (nullptr != int8Core) {
// Compute Int8 Info and align ic
int UNIT, SRC_UNIT, DynamicDestUnit;
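The im2col parameters now carry channel counts pre-rounded to the matmul packing: ic is rounded up to the gemm's LP, while icup4 (used by the float im2col path) only needs ALIMIN(LP, pack) alignment. A quick check of the difference with hypothetical sizes:

#include <algorithm>
#include <cstdio>

static int roundUp(int x, int m) { return ((x + m - 1) / m) * m; }   // ROUND_UP

int main() {
    // Hypothetical sizes: 17 input channels, gemm pack LP = 8, backend pack = 4.
    const int channel = 17, LP = 8, pack = 4;
    const int ic    = roundUp(channel, LP);                  // 24: padded to the matmul's l-pack
    const int icup4 = roundUp(channel, std::min(LP, pack));  // 20: float im2col only needs pack alignment
    std::printf("ic=%d icup4=%d\n", ic, icup4);
    return 0;
}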

View File

@ -24,7 +24,7 @@ namespace MNN {
void DenseConvolutionTiledExecutor::initWeight(float *dest, const float *source, float* cache, int depth, int outputCount, int kernelSize, const CoreFunctions* function) {
ConvolutionTiledExecutor::initWeight(source, cache, depth, outputCount, kernelSize, function);
function->MNNPackForMatMul_B(dest, cache, outputCount, kernelSize * depth, true);
function->MNNPackForMatMul_B(dest, cache, outputCount, kernelSize, depth, true);
}
bool DenseConvolutionTiledExecutor::initQuantizeResource(std::shared_ptr<ConvolutionCommon::Int8Common> int8Info, std::shared_ptr<CPUConvolution::Resource> resource, int hU, int hP, int lU, int lP, int outputCount, int srcChannel, int kernelSize, int bytes) {
@ -177,7 +177,7 @@ DenseConvolutionTiledExecutor::DenseConvolutionTiledExecutor(const Convolution2D
auto srcCount = (int)originWeightSize / outputCount / common->kernelX() / common->kernelY();
auto lSize = srcCount * common->kernelX() * common->kernelY();
auto hU = UP_DIV(outputCount, hP);
auto lU = UP_DIV(lSize, lP);
auto lU = UP_DIV(srcCount, lP) * common->kernelX() * common->kernelY();
if (useInt8Weight) {
// Quantize weight to int8
auto allocSuccess = DenseConvolutionTiledExecutor::initQuantizeResource(int8Info, mResource, hU, hP, lU, lP, outputCount, srcCount, common->kernelX() * common->kernelY(), bytes);
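The l-block count for the packed weight is now computed per kernel position rather than over the flattened l = srcCount * kx * ky, so each kernel tap starts on an lP boundary (matching the ROUND_UP(ic, lP) used by the packers and im2col parameters above). With small hypothetical shapes the difference is visible:

#include <cstdio>

static int upDiv(int x, int m) { return (x + m - 1) / m; }   // UP_DIV

int main() {
    // Hypothetical shapes: 3 input channels, 3x3 kernel, lP = 4.
    const int srcCount = 3, kx = 3, ky = 3, lP = 4;
    const int lSize = srcCount * kx * ky;                     // 27
    const int lU_flat      = upDiv(lSize, lP);                // 7: old flat rounding
    const int lU_perKernel = upDiv(srcCount, lP) * kx * ky;   // 9: ic padded per kernel position
    std::printf("flat=%d perKernel=%d\n", lU_flat, lU_perKernel);
    return 0;
}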
@ -280,7 +280,7 @@ ErrorCode ConvolutionTiledExecutorMultiInput::onExecute(const std::vector<Tensor
MNNTranspose32Bit((int32_t*)dO, (const int32_t*)sO, &dims[0]);
}
}
function->MNNPackForMatMul_B(mTempWeight->host<float>(), mTempWeightCache->host<float>(), outputCount, kernelSize * depth, true);
function->MNNPackForMatMul_B(mTempWeight->host<float>(), mTempWeightCache->host<float>(), outputCount, kernelSize, depth, true);
return mProxy->onExecute(mInputs, outputs);
}
ErrorCode ConvolutionTiledExecutorMultiInput::onResize(const std::vector<Tensor*>& inputs,
@ -292,7 +292,7 @@ ErrorCode ConvolutionTiledExecutorMultiInput::onResize(const std::vector<Tensor*
function->MNNGetMatMulPackMode(&eP, &lP, &hP);
auto kernelSize = depth * inputs[1]->stride(1);
mTempWeight.reset(Tensor::createDevice<float>(
{UP_DIV(outputCount, hP), UP_DIV(kernelSize, lP), lP * hP}));
{UP_DIV(outputCount, hP), UP_DIV(depth, lP) * inputs[1]->stride(1), lP * hP}));
if (function->bytes < 4) {
mTempWeightCache.reset(Tensor::createDevice<int32_t>({2, outputCount * kernelSize}));
} else {
@ -304,10 +304,11 @@ ErrorCode ConvolutionTiledExecutorMultiInput::onResize(const std::vector<Tensor*
if (!res) {
return OUT_OF_MEMORY;
}
if (inputs.size() > 2 && inputs[2]->elementSize() % function->pack == 0) {
if (inputs.size() > 2 && inputs[2]->elementSize() % hP == 0) {
mInputs = {inputs[0], mTempWeight.get(), inputs[2]};
} else {
mTempBias.reset(Tensor::createDevice<float>({UP_DIV(outputCount, function->pack) * function->pack}));
auto hPackedSize = ALIMAX(hP, function->pack);
mTempBias.reset(Tensor::createDevice<float>({UP_DIV(outputCount, hPackedSize) * hPackedSize}));
backend()->onAcquireBuffer(mTempBias.get(), Backend::DYNAMIC);
mInputs = {inputs[0], mTempWeight.get(), mTempBias.get()};
}
@ -445,7 +446,7 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector<Tensor*>& inputs
const uint8_t* dequantBias = nullptr;
auto ic = input->channel();
auto icC4 = UP_DIV(ic, unit);
auto L = ic * mCommon->kernelY() * mCommon->kernelX();
auto L = ROUND_UP(ic, lP) * mCommon->kernelY() * mCommon->kernelX();
auto tileC = std::max(unit, hP);
int blockSize = L;
int blockNum = 1;
@ -677,7 +678,14 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector<Tensor*>& inputs
if (number > 0) {
packA((float *)gemmBuffer, srcPtr, info, el);
}
/*
for (int kk=0; kk < mIm2ColParameters.kernelX * mIm2ColParameters.kernelY; ++kk) {
for (int xx=0; xx < ROUND_UP(input->channel(), lP) * eP; ++xx) {
printf("%f ", ((__fp16*)gemmBuffer)[kk * ROUND_UP(input->channel(), lP) * eP + xx]);
if (xx % (eP * lP) == (eP * lP -1)) printf("\n");
}
}
*/
int finishedL = 0;
int wquantStride = 0;
int8_t* _weightPtr = reinterpret_cast<int8_t*>(weightPtr);

View File

@ -40,6 +40,7 @@ IdstConvolutionInt8::IdstConvolutionInt8(const Convolution2DCommon* convOp, Back
core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
int PackUnit = static_cast<CPUBackend*>(b)->functions()->pack;
int ocUp4 = ROUND_UP(biasSize, PackUnit);
int ocUpHp = ROUND_UP(biasSize, UNIT);
mBias.reset(ocUp4);
mBias.clear();
auto biasDest = mBias.get();
@ -66,9 +67,9 @@ IdstConvolutionInt8::IdstConvolutionInt8(const Convolution2DCommon* convOp, Back
auto srcCount = mSrcCount;
std::vector<int> shape;
shape = {1, UP_DIV(outputCount, UNIT), UP_DIV(srcCount, SRC_UNIT) * kernelCount, UNIT, SRC_UNIT};
mFakeBias.reset(Tensor::createDevice<float>({ocUp4}));
mFakeBias.reset(Tensor::createDevice<float>({ocUpHp}));
int weightlen = shape[0] * shape[1] * shape[2] * shape[3] * shape[4];
int quantlen = 2 * ocUp4 * QUANT_INFO_BYTES;
int quantlen = 2 * ocUpHp * QUANT_INFO_BYTES;
mWeight.reset(Tensor::createDevice<int8_t>({weightlen + quantlen}));
mValid = b->onAcquireBuffer(mWeight.get(), Backend::STATIC);
mValid &= b->onAcquireBuffer(mFakeBias.get(), Backend::STATIC);
@ -199,7 +200,7 @@ ErrorCode IdstConvolutionInt8::onExecute(const std::vector<Tensor*>& inputs, con
quanParam.srcKernelSum = fakeSrcKernleSum.data();
std::vector<float> fakeInputScale(DST_XUNIT, 1.f);
quanParam.inputScale = fakeInputScale.data();
std::vector<float> fakeWeightKernelsSum(ocC4 * PackUnit, 0.f);
std::vector<float> fakeWeightKernelsSum(ROUND_UP(output->channel(), UNIT__), 0.f);
quanParam.weightKernelSum = fakeWeightKernelsSum.data();
quanParam.inputBias = nullptr;
quanParam.blockNum = 1;

View File

@ -39,6 +39,7 @@ void MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3(int8_t* dst, const int8_t*
size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, int8_t* idxOrder=nullptr);
void MNNSumByAxisLForMatmul_A_ARM86(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams);
void MNNSumByAxisLForMatmul_A_ARM82(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams);
void MNNSumByAxisLForMatmul_A_SME2(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams);
#if defined(MNN_LOW_MEMORY)
// int4 weight gemmInt8 kernel
void MNNGemmInt8AddBiasScale_ARMV82_w4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad,
@ -62,8 +63,16 @@ void MNNGemmInt8AddBiasScale_ARMV86_w4_Unit_FP16(int8_t* dst, const int8_t* src,
const QuanPostTreatParameters* post, size_t realDstCount);
void DynamicQuanInputAndReorder_ARM82(const float* src, int8_t* dst, size_t planeSize, const float* scale, ssize_t aMin,
ssize_t aMax, const float* zeroPoint, size_t ocQuad, size_t offset);
#endif
#endif
#ifdef MNN_SME2
void MNNGemmInt8AddBiasScale_SME2_w4_Fp32(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount);
void MNNGemmInt8AddBiasScale_SME2_w8_Fp32(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount);
void MNNGemmInt8AddBiasScale_SME2_w4_Fp16(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount);
void MNNGemmInt8AddBiasScale_SME2_w8_Fp16(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount);
#endif
#endif // __aarch64__
}
#endif // MNN_USE_NEON
@ -2224,6 +2233,12 @@ static void MNNGetGemmUnitI8mm(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) {
*DST_XUNIT = 10;
}
static void MNNGetGemmUnitSme2(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) {
*UNIT = 16;
*SRC_UNIT = 4;
*DST_XUNIT = 16;
}
template<int EP, int HP>
static void _ArmBasicMNNPackC4ForMatMul_A_L4(int8_t* destOrigin, int8_t const** sourceGroup, const int32_t* info, const int32_t* el) {
int number = info[0];
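MNNGetGemmUnitSme2 advertises a 16x4 weight tile with 16 destination columns per gemm call (versus DST_XUNIT = 10 for the i8mm path above). Plugged into the five-dimensional packed-weight shape shown earlier for IdstConvolutionInt8, the storage for a hypothetical layer works out as follows (oc, ic, and kernel size are assumptions for the arithmetic):

#include <cstdio>

static int upDiv(int x, int m) { return (x + m - 1) / m; }

int main() {
    // SME2 tile sizes registered above; layer dimensions are hypothetical.
    const int UNIT = 16, SRC_UNIT = 4, DST_XUNIT = 16;
    const int oc = 32, ic = 17, kernelCount = 9;
    // Same 5-D layout as the shape[] used in IdstConvolutionInt8 above:
    // {1, UP_DIV(oc, UNIT), UP_DIV(ic, SRC_UNIT) * kernelCount, UNIT, SRC_UNIT}
    const long weightElems = 1L * upDiv(oc, UNIT) * upDiv(ic, SRC_UNIT) * kernelCount * UNIT * SRC_UNIT;
    std::printf("packed int8 weight elems = %ld (tile %dx%d, %d dst cols per call)\n",
                weightElems, UNIT, SRC_UNIT, DST_XUNIT);
    return 0;
}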
@ -2341,6 +2356,7 @@ static CoreInt8Functions* gCoreFunc = nullptr;
void MNNCoreInt8FunctionInit() {
/* CoreInt8Functions without sdot */
gCoreFunc = new CoreInt8Functions;
auto core = MNNGetCoreFunctions();
// MatMul
gCoreFunc->Int8GemmKernel = MNNGemmInt8AddBiasScale_16x4_Unit;
@ -2374,7 +2390,6 @@ void MNNCoreInt8FunctionInit() {
#endif
#if defined(__aarch64__)
auto core = MNNGetCoreFunctions();
if (core->supportSDot) {
// MatMul
gCoreFunc->Int8GemmKernel = MNNGemmInt8AddBiasScale_ARMV82_Unit;
@ -2401,8 +2416,10 @@ void MNNCoreInt8FunctionInit() {
gCoreFunc->Int8GemmKernelFast = MNNGemmInt8AddBiasScale_ARMV86_Unit;
gCoreFunc->MNNGetGemmUnit = MNNGetGemmUnitI8mm;
core->MNNSumByAxisLForMatmul_A = MNNSumByAxisLForMatmul_A_ARM86;
#if defined(MNN_LOW_MEMORY)
gCoreFunc->Int8GemmKernel_W4 = MNNGemmInt8AddBiasScale_ARMV86_w4_Unit;
#ifdef MNN_USE_ARMV82
gCoreFunc->MNNGemmInt8AddBiasScale_Unit_FP16 = MNNGemmInt8AddBiasScale_ARMV86_Unit_FP16;
gCoreFunc->MNNGemmInt8AddBiasScale_w4_Unit_FP16 = MNNGemmInt8AddBiasScale_ARMV86_w4_Unit_FP16;
@ -2411,6 +2428,42 @@ void MNNCoreInt8FunctionInit() {
// Im2Col
gCoreFunc->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A<10, 8, 8>;
}
#endif // __aarch64__
{
core->backendMatmulRelatedFunctions.Int8GemmKernel = gCoreFunc->Int8GemmKernel;
core->backendMatmulRelatedFunctions.Int8GemmKernelFast = gCoreFunc->Int8GemmKernelFast;
core->backendMatmulRelatedFunctions.Int8GemmKernel_W4 = gCoreFunc->Int8GemmKernel_W4;
core->backendMatmulRelatedFunctions.MNNGemmInt8AddBiasScale_Unit_FP16 = gCoreFunc->MNNGemmInt8AddBiasScale_Unit_FP16;
core->backendMatmulRelatedFunctions.MNNGemmInt8AddBiasScale_w4_Unit_FP16 = gCoreFunc->MNNGemmInt8AddBiasScale_w4_Unit_FP16;
core->backendMatmulRelatedFunctions.MNNGetGemmUnit = gCoreFunc->MNNGetGemmUnit;
core->backendMatmulRelatedFunctions.MNNPackC4Int8ForMatMul_A = gCoreFunc->MNNPackC4Int8ForMatMul_A;
core->backendMatmulRelatedFunctions.MNNSumByAxisLForMatmul_A = core->MNNSumByAxisLForMatmul_A;
}
#ifdef __aarch64__
#ifdef MNN_SME2
if (core->supportSME2) {
gCoreFunc->MNNGetGemmUnit = MNNGetGemmUnitSme2;
gCoreFunc->Int8GemmKernel_W4 = MNNGemmInt8AddBiasScale_SME2_w4_Fp32;
gCoreFunc->Int8GemmKernel = MNNGemmInt8AddBiasScale_SME2_w8_Fp32;
gCoreFunc->MNNGemmInt8AddBiasScale_w4_Unit_FP16 = MNNGemmInt8AddBiasScale_SME2_w4_Fp16;
gCoreFunc->MNNGemmInt8AddBiasScale_Unit_FP16 = MNNGemmInt8AddBiasScale_SME2_w8_Fp16;
core->MNNSumByAxisLForMatmul_A = MNNSumByAxisLForMatmul_A_SME2;
gCoreFunc->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A<16, 4, 16>;
gCoreFunc->Int8GemmKernelFast = MNNGemmInt8AddBiasScale_SME2_w8_Fp32;
core->sme2MatmulRelatedFuncions.MNNGetGemmUnit = MNNGetGemmUnitSme2;
core->sme2MatmulRelatedFuncions.Int8GemmKernel_W4 = MNNGemmInt8AddBiasScale_SME2_w4_Fp32;
core->sme2MatmulRelatedFuncions.Int8GemmKernel = MNNGemmInt8AddBiasScale_SME2_w8_Fp32;
core->sme2MatmulRelatedFuncions.MNNGemmInt8AddBiasScale_w4_Unit_FP16 = MNNGemmInt8AddBiasScale_SME2_w4_Fp16;
core->sme2MatmulRelatedFuncions.MNNGemmInt8AddBiasScale_Unit_FP16 = MNNGemmInt8AddBiasScale_SME2_w8_Fp16;
core->sme2MatmulRelatedFuncions.MNNSumByAxisLForMatmul_A = MNNSumByAxisLForMatmul_A_SME2;
core->sme2MatmulRelatedFuncions.MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A<16, 4, 16>;
core->sme2MatmulRelatedFuncions.Int8GemmKernelFast = MNNGemmInt8AddBiasScale_SME2_w8_Fp32;
}
#endif
#endif
MNNInt8FunctionInit();
}

View File

@ -52,6 +52,7 @@ struct QuanPostTreatParameters {
const float* inputScale = nullptr;
const float* inputBias = nullptr;
float* accumBuffer = nullptr;
int32_t* indices = nullptr;
};
struct QuanPrePostParameters{
float* inputScale;

View File

@ -37,7 +37,7 @@ public:
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
void getPackParameter(int* Unit, int* SrcUnit, int* DestUnit, const CoreInt8Functions* core) override;
void getPackParameter(int* Unit, int* SrcUnit, int* DestUnit, const CoreInt8Functions* core);
bool reorderWeight(Backend* b, const Convolution2DCommon* common, const std::shared_ptr<Tensor>& weightOrigin,
std::shared_ptr<Tensor>& weight, const SparseCommon* sparseCommon);

View File

@ -78,11 +78,11 @@ ErrorCode StrassenMatrixComputor::_generateTrivalMatMul(int e, int l, int h, con
auto matmulUnit = core->MNNPackedMatMul;
auto matmulRemain = core->MNNPackedMatMulRemain;
mFunctions.emplace_back(
std::make_pair([cStride, l, h, xCount, AT, BT, CT, COT, tileBufferBasic, unitNumber, bExtraStride, numberThread, eReal, eP, active, matmulUnit, matmulRemain, this](int tId) {
std::make_pair([cStride, l, h, xCount, AT, BT, CT, COT, tileBufferBasic, unitNumber, bExtraStride, numberThread, eReal, eP, lP, active, matmulUnit, matmulRemain, this](int tId) {
auto core = static_cast<CPUBackend*>(backend())->functions();
size_t parameters[7];
parameters[0] = xCount * core->bytes;
parameters[1] = l;
parameters[1] = ROUND_UP(l, lP);
parameters[2] = h;
parameters[3] = cStride;
parameters[4] = 0;

View File

@ -31,6 +31,7 @@ bool AVX2Backend::isValid() {
AVX2Backend::AVX2Backend(const CPURuntime* runtime, BackendConfig::MemoryMode memory, size_t flags) : CPUBackend(runtime, BackendConfig::Precision_Low, memory, MNN_FORWARD_CPU_EXTENSION, flags) {
mCoreFunctions = AVX2Functions::get();
mInt8CoreFunctions = AVX2Functions::getInt8();
mRelatedFunctions = &(mCoreFunctions->backendMatmulRelatedFunctions);
}
AVX2Backend::~AVX2Backend() {

View File

@ -105,6 +105,18 @@ bool AVX2Functions::init(int cpuFlags) {
sizeof(MNN::CoreFunctions::MNNPackedMatMulKernel) * AVX512_INPUT_TILE_MAX);
}
#endif
{
coreFunction->backendMatmulRelatedFunctions.MNNGetMatMulPackMode = coreFunction->MNNGetMatMulPackMode;
coreFunction->backendMatmulRelatedFunctions.MNNPackC4ForMatMul_A = coreFunction->MNNPackC4ForMatMul_A;
coreFunction->backendMatmulRelatedFunctions.MNNPackForMatMul_B = coreFunction->MNNPackForMatMul_B;
coreFunction->backendMatmulRelatedFunctions.MNNPackedMatMul = coreFunction->MNNPackedMatMul;
coreFunction->backendMatmulRelatedFunctions.MNNPackedMatMulRemain = coreFunction->MNNPackedMatMulRemain;
coreFunction->backendMatmulRelatedFunctions.Int8GemmKernel = gAVX2CoreInt8Functions->Int8GemmKernel;
coreFunction->backendMatmulRelatedFunctions.Int8GemmKernelFast = gAVX2CoreInt8Functions->Int8GemmKernelFast;
coreFunction->backendMatmulRelatedFunctions.Int8GemmKernel_W4 = gAVX2CoreInt8Functions->Int8GemmKernel_W4;
coreFunction->backendMatmulRelatedFunctions.MNNGetGemmUnit = gAVX2CoreInt8Functions->MNNGetGemmUnit;
coreFunction->backendMatmulRelatedFunctions.MNNPackC4Int8ForMatMul_A = gAVX2CoreInt8Functions->MNNPackC4Int8ForMatMul_A;
}
return true;
}
#endif

View File

@ -80,6 +80,13 @@ void MNNFunctionInit() {
}
#endif
_SSE_ImageProcessInit(coreFunction, cpuFlags);
{
coreFunction->backendMatmulRelatedFunctions.MNNGetMatMulPackMode = coreFunction->MNNGetMatMulPackMode;
coreFunction->backendMatmulRelatedFunctions.MNNPackC4ForMatMul_A = coreFunction->MNNPackC4ForMatMul_A;
coreFunction->backendMatmulRelatedFunctions.MNNPackForMatMul_B = coreFunction->MNNPackForMatMul_B;
coreFunction->backendMatmulRelatedFunctions.MNNPackedMatMul = coreFunction->MNNPackedMatMul;
coreFunction->backendMatmulRelatedFunctions.MNNPackedMatMulRemain = coreFunction->MNNPackedMatMulRemain;
}
}
void MNNAvgPoolUint8(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputWidth, size_t kernelx, size_t kernely, size_t stridesx, ssize_t paddingx, ssize_t factor) {
@ -131,6 +138,7 @@ void MNNMaxPoolInt8_(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputW
void MNNInt8FunctionInit() {
auto cpuFlags = libyuv::InitCpuFlags();
auto core = MNN::MNNGetInt8CoreFunctions();
auto gcore = MNN::MNNGetCoreFunctions();
core->MNNAvgPoolInt8 = MNNAvgPoolUint8;
core->MNNMaxPoolInt8 = MNNMaxPoolInt8_;
if (cpuFlags & libyuv::kCpuHasSSE41) {
@ -143,6 +151,13 @@ void MNNInt8FunctionInit() {
core->Int8GemmKernel_W4 = _SSE_MNNGemmInt8AddBiasScale_16x4_w4;
#endif
}
{
gcore->backendMatmulRelatedFunctions.Int8GemmKernel = core->Int8GemmKernel;
gcore->backendMatmulRelatedFunctions.Int8GemmKernelFast = core->Int8GemmKernelFast;
gcore->backendMatmulRelatedFunctions.Int8GemmKernel_W4 = core->Int8GemmKernel_W4;
gcore->backendMatmulRelatedFunctions.MNNGetGemmUnit = core->MNNGetGemmUnit;
gcore->backendMatmulRelatedFunctions.MNNPackC4Int8ForMatMul_A = core->MNNPackC4Int8ForMatMul_A;
}
}

View File

@ -62,7 +62,7 @@ void _AVX_MNNPackC4ForMatMul_A_BF16(float* destOrigin, float const** sourceGroup
void _AVX_MNNCountMinMaxValue(const float* source, float* minVal, float* maxVal, size_t size);
void _AVX_MNNGetMatMulPackMode_BF16(int* eP, int *lP, int* hP);
void _AVX_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t l, bool transpose);
void _AVX_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose);
void _AVX_MNNPackedSparseMatMul(float* C, const float* A, const float* B, unsigned int* NNZMap, int* dataOffsetMap, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias);
void _AVX_MNNComputeMatMulForH_1(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId);
@ -72,7 +72,7 @@ void _AVX_MNNPackCUnit(float* dst, const float* src, size_t area, size_t depth,
void _AVX_MNNUnpackCUnit(float* dst, const float* src, size_t area, size_t depth, int* areaOffset);
void _AVX_MNNPackCUnitTranspose(float* dst, const float* src, size_t area, size_t depth, int* areaOffset);
void _AVX_MNNUnpackCUnitTranspose(float* dst, const float* src, size_t area, size_t depth, int* areaOffset);
void _AVX_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose);
void _AVX_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose);
void _AVX_ExtraInit(void* functions);
void _AVX_WinogradInit(void* functions);

View File

@ -336,7 +336,8 @@ _mm256_storeu_ps(temp2 + 7 * 8, t7);\
}
}
void _AVX_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose) {
void _AVX_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose) {
auto l = kernelsize * ic;
int offset[2] = {
(int)l,
(int)l
@ -348,10 +349,11 @@ void _AVX_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t
MNNPackC4(dest, source, l, h, offset);
}
void _AVX_MNNPackForMatMul_B_EShort(float* dest, const float* source, size_t h, size_t l, bool transpose) {
void _AVX_MNNPackForMatMul_B_EShort(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose) {
const int unit = 16;
auto hP = h / unit;
auto hR = hP * unit;
auto l = kernelsize * ic;
if (hR != h) {
::memset(dest, 0, UP_DIV(h, unit)*unit*l*sizeof(float));
}

View File

@ -10,7 +10,8 @@
#include "FunctionSummary.hpp"
#include "core/Macro.h"
void _AVX_MNNPackForMatMul_B_BF16(float* destF, const float* sourceF, size_t h, size_t l, bool transpose) {
void _AVX_MNNPackForMatMul_B_BF16(float* destF, const float* sourceF, size_t h, size_t kernelsize, size_t ic, bool transpose) {
auto l = kernelsize * ic;
auto dest = (int16_t*)destF;
auto source = (const int16_t*)sourceF;
auto lC8 = UP_DIV(l, 8);

View File

@ -30,7 +30,7 @@ do { \
// ========= CommonOptFunction.cpp ===========
extern "C" {
void _AVX512_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose);
void _AVX512_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose);
void _AVX512_MNNPackC8ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el);
void _AVX512_MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);

View File

@ -327,7 +327,8 @@ void _AVX_MNNPackCUnitTranspose(float* dst, const float* src, size_t area, size_
void _AVX_MNNUnpackCUnitTranspose(float* dst, const float* src, size_t area, size_t depth, int* areaOffset);
}
void _AVX512_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose) {
void _AVX512_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose) {
auto l = kernelsize * ic;
int offset[2] = {
(int)l,
(int)l

View File

@ -286,6 +286,7 @@ void _AVX512_MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* sca
auto sizeC4 = sizeQuad / 2;
auto sizeRemain = sizeQuad % 2;
auto zero = _mm256_set1_epi32(0);
auto offset = _mm256_set1_ps(128.f);
auto scaleValue0 = _mm256_set1_ps(scale[0]);
auto scaleValue1 = scaleValue0;
@ -293,11 +294,11 @@ void _AVX512_MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* sca
scaleValue0 = _mm256_loadu_ps(scale);
scaleValue1 = _mm256_loadu_ps(scale + 8);
}
auto zeroPointValue0 = _mm256_set1_ps(zeroPoint[0]) + _mm256_set1_ps(128.f);
auto zeroPointValue0 = _mm256_add_ps(_mm256_set1_ps(zeroPoint[0]), offset);
auto zeroPointValue1 = zeroPointValue0;
if (quanParamVec >> 1) {
zeroPointValue0 = _mm256_loadu_ps(zeroPoint) + _mm256_set1_ps(128.f);
zeroPointValue1 = _mm256_loadu_ps(zeroPoint + 8) + _mm256_set1_ps(128.f);
zeroPointValue0 = _mm256_add_ps(_mm256_loadu_ps(zeroPoint), offset);
zeroPointValue1 = _mm256_add_ps(_mm256_loadu_ps(zeroPoint + 8), offset);
}
for (int i = 0; i < sizeC4; ++i) {
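The rewrite from `a + b` on __m256 values to _mm256_add_ps keeps the arithmetic identical but drops the dependence on the GCC/Clang vector-extension operator, which stricter compilers (notably MSVC) reject; the hoisted offset constant also avoids re-materializing 128.f per branch. The portable spelling in isolation:

#include <immintrin.h>
#include <cstdio>

int main() {
    const float zeroPoint = 3.f;
    __m256 offset = _mm256_set1_ps(128.f);
    // Explicit intrinsic instead of the GCC/Clang-only `__m256 + __m256`.
    __m256 zp = _mm256_add_ps(_mm256_set1_ps(zeroPoint), offset);
    float out[8];
    _mm256_storeu_ps(out, zp);
    std::printf("%f\n", out[0]);   // 131.0
    return 0;
}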

View File

@ -153,7 +153,6 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s
int weight_step_Z = static_cast<int32_t>(src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H)) + (2 * sizeof(float) * GEMMINT8_AVX512_H);
int weight_step_Y = static_cast<int32_t>(GEMMINT8_AVX512_L * GEMMINT8_AVX512_H);
int weightPackStride = GEMMINT8_AVX512_L * PACK_UNIT;
int weight_step_Z_remain = static_cast<int32_t>(src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H)) + (2 * sizeof(float) * dzR * PACK_UNIT);
int source_step = realDst * PACK_UNIT;
if (realDst == GEMMINT8_AVX512_E) {
@ -518,9 +517,9 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s
__m512i D1 = _mm512_set1_epi32(0);
__m512i D2 = _mm512_set1_epi32(0);
__m512i D3 = _mm512_set1_epi32(0);
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + i * weightPackStride;
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + dzR * PACK_UNIT;
auto weightDzSub = weight_dz + bk * weight_step_Z + i * weightPackStride;
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
for (int sz = 0; sz < src_depth_quad; ++sz) {
@ -569,7 +568,7 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
bias01 = _mm512_mul_ps(inputbias1, wsum0);
@ -941,9 +940,9 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s
__m512i D0 = _mm512_set1_epi32(0);
__m512i D1 = _mm512_set1_epi32(0);
__m512i D2 = _mm512_set1_epi32(0);
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + i * weightPackStride;
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + dzR * PACK_UNIT;
auto weightDzSub = weight_dz + bk * weight_step_Z + i * weightPackStride;
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
for (int sz = 0; sz < src_depth_quad; ++sz) {
@ -984,7 +983,7 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s
f2 = _mm512_mul_ps(f2, inputscale2);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
bias01 = _mm512_mul_ps(inputbias1, wsum0);
@ -1284,9 +1283,9 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s
for (int bk = 0; bk < blockNum; ++bk) {
__m512i D0 = _mm512_set1_epi32(0);
__m512i D1 = _mm512_set1_epi32(0);
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + i * weightPackStride;
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + dzR * PACK_UNIT;
auto weightDzSub = weight_dz + bk * weight_step_Z + i * weightPackStride;
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
for (int sz = 0; sz < src_depth_quad; ++sz) {
@ -1320,7 +1319,7 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s
f1 = _mm512_mul_ps(f1, inputscale1);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
bias01 = _mm512_mul_ps(inputbias1, wsum0);
@ -1537,9 +1536,9 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s
for (int i=0; i<dzR; ++i) {
auto accum_x = accumbuff;
for (int bk = 0; bk < blockNum; ++bk) {
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + i * weightPackStride;
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + dzR * PACK_UNIT;
auto weightDzSub = weight_dz + bk * weight_step_Z + i * weightPackStride;
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
__m512i D0 = _mm512_set1_epi32(0);
@ -1566,7 +1565,7 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s
f0 = _mm512_mul_ps(f0, inputscale0);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
} else if (bk == 0) { // if input not block quant, only accum once!
@ -1655,7 +1654,6 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_w4_Unit_VNNI(int8_t* dst, const int8_t
int weight_step_Y = GEMMINT8_AVX512_L * GEMMINT8_AVX512_H / 2;
int weight_step_Z = src_depth_quad * weight_step_Y + (2 * 4 * GEMMINT8_AVX512_H);
int weightPackStride = GEMMINT8_AVX512_L / 2 * PACK_UNIT;
int weight_step_Z_remain = src_depth_quad * weight_step_Y + (2 * 4 * dzR * PACK_UNIT);
int source_step = realDst * PACK_UNIT;
if (realDst == GEMMINT8_AVX512_E) {
for (int dz = 0; dz < dzU; ++dz) {
@ -1975,9 +1973,9 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_w4_Unit_VNNI(int8_t* dst, const int8_t
__m512i D1 = _mm512_set1_epi32(0);
__m512i D2 = _mm512_set1_epi32(0);
__m512i D3 = _mm512_set1_epi32(0);
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + weightPackStride * suborder[i];
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + dzR * PACK_UNIT;
auto weightDzSub = weight_dz + bk * weight_step_Z + weightPackStride * suborder[i];
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
for (int sz = 0; sz < src_depth_quad; ++sz) {
@ -2027,7 +2025,7 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_w4_Unit_VNNI(int8_t* dst, const int8_t
f3 = _mm512_mul_ps(f3, inputscale3);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
bias01 = _mm512_mul_ps(inputbias1, wsum0);
@ -2349,9 +2347,9 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_w4_Unit_VNNI(int8_t* dst, const int8_t
__m512i D0 = _mm512_set1_epi32(0);
__m512i D1 = _mm512_set1_epi32(0);
__m512i D2 = _mm512_set1_epi32(0);
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + weightPackStride * suborder[i];
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + dzR * PACK_UNIT;
auto weightDzSub = weight_dz + bk * weight_step_Z + weightPackStride * suborder[i];
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
for (int sz = 0; sz < src_depth_quad; ++sz) {
@ -2394,7 +2392,7 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_w4_Unit_VNNI(int8_t* dst, const int8_t
f2 = _mm512_mul_ps(f2, inputscale2);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
bias01 = _mm512_mul_ps(inputbias1, wsum0);
@ -2655,9 +2653,9 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_w4_Unit_VNNI(int8_t* dst, const int8_t
for (int bk = 0; bk < blockNum; ++bk) {
__m512i D0 = _mm512_set1_epi32(0);
__m512i D1 = _mm512_set1_epi32(0);
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + weightPackStride * suborder[i];
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + dzR * PACK_UNIT;
auto weightDzSub = weight_dz + bk * weight_step_Z + weightPackStride * suborder[i];
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
for (int sz = 0; sz < src_depth_quad; ++sz) {
@ -2693,7 +2691,7 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_w4_Unit_VNNI(int8_t* dst, const int8_t
f1 = _mm512_mul_ps(f1, inputscale1);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
bias01 = _mm512_mul_ps(inputbias1, wsum0);
@ -2878,9 +2876,9 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_w4_Unit_VNNI(int8_t* dst, const int8_t
for (int i=0; i<dzR; ++i) {
auto accum_x = accumbuff;
for (int bk = 0; bk < blockNum; ++bk) {
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + weightPackStride * suborder[i];
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + dzR * PACK_UNIT;
auto weightDzSub = weight_dz + bk * weight_step_Z + weightPackStride * suborder[i];
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
__m512i D0 = _mm512_set1_epi32(0);
@ -2908,7 +2906,7 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_w4_Unit_VNNI(int8_t* dst, const int8_t
f0 = _mm512_mul_ps(f0, inputscale0);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
} else if (bk == 0) { // if input not block quant, only accum once!

View File

@ -139,7 +139,6 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s
int weight_step_Z = static_cast<int32_t>(src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H)) + (2 * sizeof(float) * GEMMINT8_AVX512_H);
int weight_step_Y = static_cast<int32_t>(GEMMINT8_AVX512_L * GEMMINT8_AVX512_H);
int weightPackStride = GEMMINT8_AVX512_L * PACK_UNIT;
int weight_step_Z_remain = static_cast<int32_t>(src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H)) + (2 * sizeof(float) * dzR * PACK_UNIT);
int source_step = realDst * PACK_UNIT;
if (realDst == GEMMINT8_AVX512_E) {
@ -504,9 +503,9 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s
__m512i D1 = _mm512_set1_epi32(0);
__m512i D2 = _mm512_set1_epi32(0);
__m512i D3 = _mm512_set1_epi32(0);
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + i * weightPackStride;
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + dzR * PACK_UNIT;
auto weightDzSub = weight_dz + bk * weight_step_Z + i * weightPackStride;
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
for (int sz = 0; sz < src_depth_quad; ++sz) {
@ -555,7 +554,7 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
bias01 = _mm512_mul_ps(inputbias1, wsum0);
@ -927,9 +926,9 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s
__m512i D0 = _mm512_set1_epi32(0);
__m512i D1 = _mm512_set1_epi32(0);
__m512i D2 = _mm512_set1_epi32(0);
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + i * weightPackStride;
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + dzR * PACK_UNIT;
auto weightDzSub = weight_dz + bk * weight_step_Z + i * weightPackStride;
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
for (int sz = 0; sz < src_depth_quad; ++sz) {
@ -970,7 +969,7 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s
f2 = _mm512_mul_ps(f2, inputscale2);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
bias01 = _mm512_mul_ps(inputbias1, wsum0);
@ -1270,9 +1269,9 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s
for (int bk = 0; bk < blockNum; ++bk) {
__m512i D0 = _mm512_set1_epi32(0);
__m512i D1 = _mm512_set1_epi32(0);
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + i * weightPackStride;
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + dzR * PACK_UNIT;
auto weightDzSub = weight_dz + bk * weight_step_Z + i * weightPackStride;
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
for (int sz = 0; sz < src_depth_quad; ++sz) {
@ -1306,7 +1305,7 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s
f1 = _mm512_mul_ps(f1, inputscale1);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
bias01 = _mm512_mul_ps(inputbias1, wsum0);
@ -1523,9 +1522,9 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s
for (int i=0; i<dzR; ++i) {
auto accum_x = accumbuff;
for (int bk = 0; bk < blockNum; ++bk) {
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + i * weightPackStride;
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + dzR * PACK_UNIT;
auto weightDzSub = weight_dz + bk * weight_step_Z + i * weightPackStride;
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
__m512i D0 = _mm512_set1_epi32(0);
@ -1552,7 +1551,7 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s
f0 = _mm512_mul_ps(f0, inputscale0);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
} else if (bk == 0) { // if input not block quant, only accum once!
@ -1640,7 +1639,6 @@ void MATMULCOREFUNC_NAME_W4(int8_t* dst, const int8_t* src, const int8_t* weight
int weight_step_Y = GEMMINT8_AVX512_L * GEMMINT8_AVX512_H / 2;
int weight_step_Z = src_depth_quad * weight_step_Y + (2 * 4 * GEMMINT8_AVX512_H);
int weightPackStride = GEMMINT8_AVX512_L / 2 * PACK_UNIT;
int weight_step_Z_remain = src_depth_quad * weight_step_Y + (2 * 4 * dzR * PACK_UNIT);
int source_step = realDst * PACK_UNIT;
if (realDst == GEMMINT8_AVX512_E) {
for (int dz = 0; dz < dzU; ++dz) {
@ -1960,9 +1958,9 @@ void MATMULCOREFUNC_NAME_W4(int8_t* dst, const int8_t* src, const int8_t* weight
__m512i D1 = _mm512_set1_epi32(0);
__m512i D2 = _mm512_set1_epi32(0);
__m512i D3 = _mm512_set1_epi32(0);
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + weightPackStride * suborder[i];
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + dzR * PACK_UNIT;
auto weightDzSub = weight_dz + bk * weight_step_Z + weightPackStride * suborder[i];
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
for (int sz = 0; sz < src_depth_quad; ++sz) {
@ -2012,7 +2010,7 @@ void MATMULCOREFUNC_NAME_W4(int8_t* dst, const int8_t* src, const int8_t* weight
f3 = _mm512_mul_ps(f3, inputscale3);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
bias01 = _mm512_mul_ps(inputbias1, wsum0);
@ -2334,9 +2332,9 @@ void MATMULCOREFUNC_NAME_W4(int8_t* dst, const int8_t* src, const int8_t* weight
__m512i D0 = _mm512_set1_epi32(0);
__m512i D1 = _mm512_set1_epi32(0);
__m512i D2 = _mm512_set1_epi32(0);
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + weightPackStride * suborder[i];
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + dzR * PACK_UNIT;
auto weightDzSub = weight_dz + bk * weight_step_Z + weightPackStride * suborder[i];
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
for (int sz = 0; sz < src_depth_quad; ++sz) {
@ -2379,7 +2377,7 @@ void MATMULCOREFUNC_NAME_W4(int8_t* dst, const int8_t* src, const int8_t* weight
f2 = _mm512_mul_ps(f2, inputscale2);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
bias01 = _mm512_mul_ps(inputbias1, wsum0);
@ -2640,9 +2638,9 @@ void MATMULCOREFUNC_NAME_W4(int8_t* dst, const int8_t* src, const int8_t* weight
for (int bk = 0; bk < blockNum; ++bk) {
__m512i D0 = _mm512_set1_epi32(0);
__m512i D1 = _mm512_set1_epi32(0);
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + weightPackStride * suborder[i];
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + dzR * PACK_UNIT;
auto weightDzSub = weight_dz + bk * weight_step_Z + weightPackStride * suborder[i];
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
for (int sz = 0; sz < src_depth_quad; ++sz) {
@ -2678,7 +2676,7 @@ void MATMULCOREFUNC_NAME_W4(int8_t* dst, const int8_t* src, const int8_t* weight
f1 = _mm512_mul_ps(f1, inputscale1);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
bias01 = _mm512_mul_ps(inputbias1, wsum0);
@ -2863,9 +2861,9 @@ void MATMULCOREFUNC_NAME_W4(int8_t* dst, const int8_t* src, const int8_t* weight
for (int i=0; i<dzR; ++i) {
auto accum_x = accumbuff;
for (int bk = 0; bk < blockNum; ++bk) {
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + weightPackStride * suborder[i];
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + dzR * PACK_UNIT;
auto weightDzSub = weight_dz + bk * weight_step_Z + weightPackStride * suborder[i];
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
__m512i D0 = _mm512_set1_epi32(0);
@ -2893,7 +2891,7 @@ void MATMULCOREFUNC_NAME_W4(int8_t* dst, const int8_t* src, const int8_t* weight
f0 = _mm512_mul_ps(f0, inputscale0);
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
if (post->inputBias) {
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
bias00 = _mm512_mul_ps(inputbias0, wsum0);
} else if (bk == 0) { // if input not block quant, only accum once!

View File

@ -72,14 +72,14 @@ void _SSE_MNNConvRunForLineDepthwise(float* dst, const float* src, const float*
void _SSE_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step,
size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst);
void _SSE_MNNExpC8(float* dest, const float* source, float* offset, const float* parameters, size_t countC8);
void _SSE_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose);
void _SSE_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose);
void _SSE_MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue, ssize_t maxValue, const float* zeroPoint, ssize_t quanParamVec);
void _SSE_MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* scale, size_t size, const float* zeroPoint, ssize_t quanParamVec);
void _SSE_MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dst, const int8_t* src, const int8_t* weight, const QuanPostTreatParameters* parameters, size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, int8_t* idxOrder=nullptr);
void _SSE_MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count);
void _SSE_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t l, bool transpose);
void _SSE_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose);
void _SSE_MNNReluInt8(int8_t* dst, const int8_t* src, size_t size, ssize_t zeroPoint);
void _SSE_MNNSoftmax(float* dest, const float* source, size_t size);
void _SSE_ExtraInit(void* functions);

View File

@ -153,7 +153,8 @@ void _SSE_GemmPostTreat(float* C, size_t eSize, const size_t* parameter, const f
}
}
void _SSE_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose) {
void _SSE_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose) {
auto l = kernelsize * ic;
int offset[2] = {
(int)l,
(int)l
@ -165,7 +166,8 @@ void _SSE_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t
MNNPackC4(dest, source, l, h, offset);
}
void _SSE_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t l, bool transpose) {
void _SSE_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose) {
auto l = kernelsize * ic;
int offset[] = {
(int)l,
(int)l
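A minimal usage sketch of the widened packing signature, assuming a caller that previously passed the reduce length l now passes kernelsize and ic separately (caller and buffer names are illustrative only); the effective reduce extent is still l = kernelsize * ic, and the split presumably lets implementations align ic on their own.

#include <cstddef>

void _SSE_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose);

void packWeightExample(float* packed, const float* weights, size_t h, size_t kernelsize, size_t ic) {
    // Previously: _SSE_MNNPackForMatMul_B(packed, weights, h, kernelsize * ic, true);
    _SSE_MNNPackForMatMul_B(packed, weights, h, kernelsize, ic, /*transpose=*/true);
}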

View File

@ -117,7 +117,7 @@ ENDIF()
include(FetchContent)
if(MNN_SUPPORT_TRANSFORMER_FUSE)
set(CUTLASS_COMMIT_HASH "5c6bca04414e06ce74458ab0a2018e2b8272701c")
set(CUTLASS_COMMIT_HASH "b995f933179c22d3fe0d871c3a53d11e4681950f")
set(CUTLASS_VERSION_NAME "v4.0.0")
else()
set(CUTLASS_COMMIT_HASH "319a389f42b776fae5701afcb943fc03be5b5c25")

View File

@ -100,7 +100,9 @@ Backend* CUDARuntimeWrapper::onCreate(const BackendConfig* config, Backend* orig
precision = 1;
}
return new CUDABackend(this, mBufferPool, mCUDARuntime, precision, memory_mode);
auto backend = new CUDABackend(this, mBufferPool, mCUDARuntime, precision, memory_mode);
backend->setMetaPtr(pMeta);
return backend;
}
void CUDARuntimeWrapper::onGabageCollect(int level) {
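A minimal sketch of the meta-pointer plumbing this change introduces (the Metal backend later in this diff follows the same pattern): the runtime forwards its pMeta to the backend at creation, and executions read it via getMetaPtr() instead of dereferencing the runtime. The classes below are simplified stand-ins, not MNN's real interfaces.

struct KVMeta; // opaque KV-cache metadata

class BackendSketch {
public:
    void setMetaPtr(const void* meta) { mMeta = meta; }
    const void* getMetaPtr() const { return mMeta; }
private:
    const void* mMeta = nullptr;
};

class RuntimeSketch {
public:
    const void* pMeta = nullptr; // filled in by the session
    BackendSketch* onCreate() const {
        auto backend = new BackendSketch();
        backend->setMetaPtr(pMeta); // hand the meta pointer over at creation time
        return backend;
    }
};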

View File

@ -1,5 +1,6 @@
#include "AttentionExecution.hpp"
#include "core/TensorUtils.hpp"
#include "SoftmaxExecution.hpp"
namespace MNN {
namespace CUDA {
@ -43,8 +44,8 @@ __global__ void compact_kv_cache_kernel(
int copy_len = reserve_info[reserve_pair_idx * 2 + 1];
int copy_dst_begin_offset = reserve_offsets[reserve_pair_idx]; // offset within the reserve region
long long src_offset = past_kv_len_after_remove + copy_src_begin;
long long dst_offset = past_kv_len_after_remove + copy_dst_begin_offset;
int src_offset = past_kv_len_after_remove + copy_src_begin;
int dst_offset = past_kv_len_after_remove + copy_dst_begin_offset;
uint8_t* src_k_ptr = (uint8_t*)src_key_cache;
uint8_t* dst_k_ptr = (uint8_t*)dst_key_cache;
@ -54,13 +55,13 @@ __global__ void compact_kv_cache_kernel(
// Key Cache: [L, B, H_kv, D] -> copy the L-long segment for every (B, H_kv, D)
for (int l = 0; l < copy_len; ++l) {
// Key: [L, B, H, D]
long long k_src_idx = ((src_offset + l) * b + b_idx) * h_kv * d + h_kv_idx * d + d_idx;
long long k_dst_idx = ((dst_offset + l) * b + b_idx) * h_kv * d + h_kv_idx * d + d_idx;
int k_src_idx = ((src_offset + l) * b + b_idx) * h_kv * d + h_kv_idx * d + d_idx;
int k_dst_idx = ((dst_offset + l) * b + b_idx) * h_kv * d + h_kv_idx * d + d_idx;
memcpy(dst_k_ptr + k_dst_idx * element_size, src_k_ptr + k_src_idx * element_size, element_size);
// Value: [B, H, D, L]
long long v_src_idx = (((long long)b_idx * h_kv + h_kv_idx) * d + d_idx) * kv_cache_max_len + (src_offset + l);
long long v_dst_idx = (((long long)b_idx * h_kv + h_kv_idx) * d + d_idx) * kv_cache_max_len + (dst_offset + l);
int v_src_idx = ((b_idx * h_kv + h_kv_idx) * d + d_idx) * kv_cache_max_len + (src_offset + l);
int v_dst_idx = ((b_idx * h_kv + h_kv_idx) * d + d_idx) * kv_cache_max_len + (dst_offset + l);
memcpy(dst_v_ptr + v_dst_idx * element_size, src_v_ptr + v_src_idx * element_size, element_size);
}
}
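The switch from long long to int indices in this kernel assumes every flattened cache coordinate stays below 2^31. A quick host-side bound check under that assumption (helper name and numbers are illustrative):

#include <cstdint>

// True when all flattened key/value-cache indices fit in a 32-bit signed int.
bool cacheIndicesFitInt32(int64_t kv_cache_max_len, int64_t batch, int64_t kv_heads, int64_t head_dim) {
    return kv_cache_max_len * batch * kv_heads * head_dim < (int64_t(1) << 31);
}

// Example: 32768 tokens * 1 batch * 8 KV heads * 128 dims = 33,554,432, well below 2^31.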
@ -92,9 +93,9 @@ __global__ void copy_kv_to_cache_kernel(
int h_kv_idx = bh_kv_idx % kv_num_head;
// Source index into the input K and V (assuming input layout [B, L_k_new, H_kv, D])
long long input_offset = (long long)b_idx * new_kv_seq_len * kv_num_head * head_dim +
(long long)l_idx_new * kv_num_head * head_dim +
(long long)h_kv_idx * head_dim;
int input_offset = b_idx * new_kv_seq_len * kv_num_head * head_dim +
l_idx_new * kv_num_head * head_dim +
h_kv_idx * head_dim;
T val_to_copy_k = key_input[input_offset + d_idx];
T val_to_copy_v = value_input[input_offset + d_idx]; // read from the same offset because the K and V inputs share the same shape
@ -104,16 +105,16 @@ __global__ void copy_kv_to_cache_kernel(
if (dest_seq_idx_cache >= allocated_kv_len) return; // bounds check
// Key cache output: [L_kv_alloc, B, H_kv, D]
long long key_cache_idx = (long long)dest_seq_idx_cache * batch_size * kv_num_head * head_dim +
(long long)b_idx * kv_num_head * head_dim +
(long long)h_kv_idx * head_dim +
int key_cache_idx = dest_seq_idx_cache * batch_size * kv_num_head * head_dim +
b_idx * kv_num_head * head_dim +
h_kv_idx * head_dim +
d_idx;
key_cache_output[key_cache_idx] = val_to_copy_k;
// Value cache output: [B, H_kv, D, L_kv_alloc]
long long value_cache_idx = (long long)b_idx * kv_num_head * head_dim * allocated_kv_len +
(long long)h_kv_idx * head_dim * allocated_kv_len +
(long long)d_idx * allocated_kv_len +
int value_cache_idx = b_idx * kv_num_head * head_dim * allocated_kv_len +
h_kv_idx * head_dim * allocated_kv_len +
d_idx * allocated_kv_len +
dest_seq_idx_cache;
value_cache_output[value_cache_idx] = val_to_copy_v;
}
@ -148,15 +149,15 @@ __global__ void qk_kernel(
int h_kv_idx = h_q_idx / param->group; // corresponding KV head index
// Base address of the query element Q[b_idx, current_full_q_idx, h_q_idx, :]
const T* q_ptr = query_input + (long long)b_idx * param->query_seq_len * param->head_num * param->head_dim +
(long long)current_full_q_idx * param->head_num * param->head_dim +
(long long)h_q_idx * param->head_dim;
const T* q_ptr = query_input + b_idx * param->query_seq_len * param->head_num * param->head_dim +
current_full_q_idx * param->head_num * param->head_dim +
h_q_idx * param->head_dim;
// Base address of the key element K_cache[k_idx, b_idx, h_kv_idx, :]
// Key Cache: [L_kv_alloc, B, H_kv, D]
const T* k_ptr = key_cache + (long long)k_idx * param->batch * param->kv_head_num * param->head_dim +
(long long)b_idx * param->kv_head_num * param->head_dim +
(long long)h_kv_idx * param->head_dim;
const T* k_ptr = key_cache + k_idx * param->batch * param->kv_head_num * param->head_dim +
b_idx * param->kv_head_num * param->head_dim +
h_kv_idx * param->head_dim;
for (int d = 0; d < param->head_dim; ++d) {
score_sum += static_cast<AccT>(q_ptr[d]) * static_cast<AccT>(k_ptr[d]);
@ -168,26 +169,38 @@ __global__ void qk_kernel(
// current_full_q_idx is the index within the full query sequence (the row index)
// k_idx is the index within the currently valid key sequence (the column index)
// A float mask is laid out as L_q * L_q (param->query_seq_len * param->query_seq_len); an integer mask is laid out as L_q * L_k
long long mask_idx = (long long)current_full_q_idx * (is_add_mask_flag ? param->query_seq_len : param->key_seq_len) + k_idx - param->key_seq_len + param->query_seq_len;
if (is_add_mask_flag) {
// Additive masks are usually float
if (k_idx >= param->key_seq_len - param->query_seq_len)
score_sum += k_idx >= param->key_seq_len - param->query_seq_len ?
static_cast<const AccT*>(mask_tensor_data)[mask_idx]
: 0; // the first L_k - L_q mask entries are all treated as 0
} else {
// Set-style masks are usually int; 0 means masked out
if (is_add_mask_flag) { // additive masks are float
int mask_idx = current_full_q_idx * param->query_seq_len + k_idx - param->key_seq_len + param->query_seq_len;
if (k_idx >= param->key_seq_len - param->query_seq_len) { // the first L_k - L_q mask entries are treated as 0
if (sizeof(T) == sizeof(__half)) {
score_sum += __half2float(((const __half*)mask_tensor_data)[mask_idx]);
} else {
score_sum += static_cast<const AccT*>(mask_tensor_data)[mask_idx];
}
}
} else { // set-style masks are int; 0 means masked out
int mask_idx = current_full_q_idx * param->key_seq_len + k_idx;
if (static_cast<const int*>(mask_tensor_data)[mask_idx] == 0) {
score_sum = AccT(-1e9f);
if (sizeof(T) == sizeof(__half)) {
score_sum = AccT(-65504.0f);
} else {
score_sum = AccT(-1e9f);
}
}
}
}
if (sizeof(T) == sizeof(__half)) {
const AccT max_half_val = AccT(65504.0f);
if (score_sum > max_half_val) score_sum = max_half_val;
if (score_sum < -max_half_val) score_sum = -max_half_val;
}
// Output: qk_scores_output[b_idx, h_q_idx, q_idx_in_piece, k_idx]
long long out_idx = (long long)b_idx * param->head_num * param->q_seq_piece_len * param->key_seq_len +
(long long)h_q_idx * param->q_seq_piece_len * param->key_seq_len +
(long long)q_idx_in_piece * param->key_seq_len +
int out_idx = b_idx * param->head_num * param->q_seq_piece_len * param->key_seq_len +
h_q_idx * param->q_seq_piece_len * param->key_seq_len +
q_idx_in_piece * param->key_seq_len +
k_idx;
qk_scores_output[out_idx] = static_cast<T>(score_sum);
}
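A host-side sketch of the saturation added for fp16 above: accumulated scores are clamped to the finite half range (±65504) before the narrowing cast, so masked-out positions stay a large finite negative instead of overflowing to infinity. This is plain C++ for illustration, not the device code.

#include <algorithm>

constexpr float kHalfMax = 65504.0f; // largest finite IEEE binary16 value

inline float clampToHalfRange(float score) {
    return std::min(std::max(score, -kHalfMax), kHalfMax);
}

// e.g. a masked score of -1e9f becomes -65504.0f, which still vanishes after exp() in the softmax.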
@ -209,8 +222,8 @@ __global__ void softmax_kernel(
return;
}
const T* current_row_scores = qk_scores + (long long)row_global_idx * param->key_seq_len;
T* current_row_result = softmax_result + (long long)row_global_idx * param->key_seq_len;
const T* current_row_scores = qk_scores + row_global_idx * param->key_seq_len;
T* current_row_result = softmax_result + row_global_idx * param->key_seq_len;
// 1. Find the maximum value in the row
AccT max_val = AccT(-1e9f); // or -FLT_MAX
@ -229,7 +242,7 @@ __global__ void softmax_kernel(
sum_exp += expf(static_cast<AccT>(current_row_scores[i]) - max_val);
}
// Avoid division by zero
AccT inv_sum_exp = (sum_exp == 0.0f) ? AccT(1e-10f) : (AccT(1.0f) / sum_exp);
AccT inv_sum_exp = (sum_exp == 0.0f) ? AccT(1e-6f) : (AccT(1.0f) / sum_exp);
// 3. Compute the softmax
for (int i = 0; i < param->key_seq_len; ++i) {
@ -264,24 +277,24 @@ __global__ void qkv_kernel(
int h_kv_idx = h_q_idx / param->group; // corresponding KV head index
// Softmax probability pointer S[b_idx, h_q_idx, q_idx_in_piece, :]
const T* prob_ptr = softmax_probs + (long long)b_idx * param->head_num * param->q_seq_piece_len * param->key_seq_len +
(long long)h_q_idx * param->q_seq_piece_len * param->key_seq_len +
(long long)q_idx_in_piece * param->key_seq_len;
const T* prob_ptr = softmax_probs + b_idx * param->head_num * param->q_seq_piece_len * param->key_seq_len +
h_q_idx * param->q_seq_piece_len * param->key_seq_len +
q_idx_in_piece * param->key_seq_len;
// Base address into the value cache V[b_idx, h_kv_idx, d_idx, :]
// Value cache layout: [B, H_kv, D, L_kv_alloc_max]
const T* val_ptr_base = value_cache + (long long)b_idx * param->kv_head_num * param->head_dim * param->max_kv_len +
(long long)h_kv_idx * param->head_dim * param->max_kv_len +
(long long)d_idx * param->max_kv_len;
const T* val_ptr_base = value_cache + b_idx * param->kv_head_num * param->head_dim * param->max_kv_len +
h_kv_idx * param->head_dim * param->max_kv_len +
d_idx * param->max_kv_len;
for (int k_s = 0; k_s < param->key_seq_len; ++k_s) { // sum along L_k_total (param->key_seq_len)
weighted_sum += static_cast<AccT>(prob_ptr[k_s]) * static_cast<AccT>(val_ptr_base[k_s]);
}
// Final output: O[b_idx, current_full_q_idx, h_q_idx, d_idx]
long long out_idx = (long long)b_idx * param->query_seq_len * param->head_num * param->head_dim +
(long long)current_full_q_idx * param->head_num * param->head_dim +
(long long)h_q_idx * param->head_dim +
int out_idx = b_idx * param->query_seq_len * param->head_num * param->head_dim +
current_full_q_idx * param->head_num * param->head_dim +
h_q_idx * param->head_dim +
d_idx;
attention_output[out_idx] = static_cast<T>(weighted_sum);
}
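A short sketch of the grouped-query head mapping these kernels use (h_kv_idx = h_q_idx / param->group): each group of consecutive query heads shares one KV head. Function name and numbers are illustrative.

// group = num_query_heads / num_kv_heads, assumed to divide evenly.
inline int kvHeadForQueryHead(int h_q_idx, int group) {
    return h_q_idx / group;
}

// Example: 32 query heads and 8 KV heads give group = 4, so query heads 0..3 read KV head 0, heads 4..7 read KV head 1, and so on.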
@ -293,11 +306,10 @@ AttentionExecution::AttentionExecution(Backend* backend, bool kv_cache_op_param)
: Execution(backend), mIsKVCacheEnabled(kv_cache_op_param), mCudaBackend(static_cast<CUDABackend*>(backend)),
mBatch(0), mQuerySeqLen(0), mNumHead(0), mHeadDim(0), mKvNumHead(0), mNewKvSeqLen(0),
mQseqSplitNum(1), mHasMask(false), mIsAddMask(false), mParam_gpu(nullptr), mScale(1.0f) {
mPrecision = halide_type_of<float>(); // default precision; may be changed in onResize
mPrecision = 4; // default precision; may be changed in onResize
if (mIsKVCacheEnabled) {
mCache.reset(new SharedCache());
mMeta = (KVMeta*)(mCudaBackend->getRuntime()->pMeta);
mMeta = (KVMeta*)(mCudaBackend->getMetaPtr());
// printf("[Attention Constructor] Reading runtime. Address is: %p\n", (void*)(mCudaBackend->getRuntime()));
// printf("[Attention Constructor] Reading pMeta as address. %p\n", (void*)(mCudaBackend->getRuntime()->pMeta));
}
@ -321,8 +333,12 @@ ErrorCode AttentionExecution::init_cache_tensors() {
mCache->mMaxLength = 0;
// Create placeholder tensors; the actual allocation happens in reallocKVCache_gpu
mCache->mPastKey.reset(Tensor::createDevice({1, 1, 1, 1}, mPrecision, Tensor::CAFFE));
mCache->mPastValue.reset(Tensor::createDevice({1, 1, 1, 1}, mPrecision, Tensor::CAFFE));
mCache->mPastKey.reset(mPrecision == 4
? Tensor::createDevice<float>({1, 1, 1, 1})
: Tensor::createDevice<uint16_t>({1, 1, 1, 1}));
mCache->mPastValue.reset(mPrecision == 4
? Tensor::createDevice<float>({1, 1, 1, 1})
: Tensor::createDevice<uint16_t>({1, 1, 1, 1}));
if (!mCache->mPastKey || !mCache->mPastValue) return MNN::OUT_OF_MEMORY;
bool res = mCudaBackend->onAcquireBuffer(mCache->mPastKey.get(), Backend::STATIC);
@ -349,9 +365,13 @@ ErrorCode AttentionExecution::reallocKVCache_gpu(int required_total_kv_len, int
}
// Key Cache: [L_kv_max, B, H_kv, D]
std::shared_ptr<Tensor> new_past_key_tensor(Tensor::createDevice({new_allocated_max_len, batch_size, kv_num_head, head_dim}, mPrecision, Tensor::CAFFE));
std::shared_ptr<Tensor> new_past_key_tensor(mPrecision == 4
? Tensor::createDevice<float>({new_allocated_max_len, batch_size, kv_num_head, head_dim})
: Tensor::createDevice<uint16_t>({new_allocated_max_len, batch_size, kv_num_head, head_dim}));
// Value Cache: [B, H_kv, D, L_kv_max]
std::shared_ptr<Tensor> new_past_value_tensor(Tensor::createDevice({batch_size, kv_num_head, head_dim, new_allocated_max_len}, mPrecision, Tensor::CAFFE));
std::shared_ptr<Tensor> new_past_value_tensor(mPrecision == 4
? Tensor::createDevice<float>({batch_size, kv_num_head, head_dim, new_allocated_max_len})
: Tensor::createDevice<uint16_t>({batch_size, kv_num_head, head_dim, new_allocated_max_len}));
if (!new_past_key_tensor || !new_past_value_tensor) return MNN::OUT_OF_MEMORY;
@ -360,7 +380,7 @@ ErrorCode AttentionExecution::reallocKVCache_gpu(int required_total_kv_len, int
if(!resK || !resV) return MNN::OUT_OF_MEMORY;
if (needs_copy) {
size_t element_size_bytes = mPrecision.bytes();
size_t element_size_bytes = mPrecision;
// Copy the key cache: copy the [old_past_len, B, H_kv, D] slice from the old buffer into the new one
size_t key_bytes_to_copy = (size_t)old_past_len * batch_size * kv_num_head * head_dim * element_size_bytes;
if (key_bytes_to_copy > 0) {
@ -376,9 +396,9 @@ ErrorCode AttentionExecution::reallocKVCache_gpu(int required_total_kv_len, int
for (int h_kv = 0; h_kv < kv_num_head; ++h_kv) {
for (int d_s = 0; d_s < head_dim; ++d_s) {
uint8_t* dst_ptr = getTensorDevicePtr<uint8_t>(new_past_value_tensor.get()) +
((((long long)b * kv_num_head + h_kv) * head_dim + d_s) * new_allocated_max_len) * element_size_bytes;
(((b * kv_num_head + h_kv) * head_dim + d_s) * new_allocated_max_len) * element_size_bytes;
uint8_t* src_ptr = getTensorDevicePtr<uint8_t>(mCache->mPastValue.get()) +
((((long long)b * kv_num_head + h_kv) * head_dim + d_s) * old_max_len) * element_size_bytes;
(((b * kv_num_head + h_kv) * head_dim + d_s) * old_max_len) * element_size_bytes;
if ((size_t)old_past_len * element_size_bytes > 0) {
cudaMemcpyAsync(dst_ptr, src_ptr, (size_t)old_past_len * element_size_bytes, cudaMemcpyDeviceToDevice, stream);
checkKernelErrors;
@ -409,7 +429,7 @@ ErrorCode AttentionExecution::reallocKVCache_gpu(int required_total_kv_len, cons
return MNN::INVALID_VALUE;
}
size_t element_size = mPrecision.bytes();
size_t element_size = mPrecision;
bool needs_realloc = required_total_kv_len > mCache->mMaxLength;
int past_len_after_remove = mCache->mPastLength - meta->remove;
@ -437,14 +457,12 @@ ErrorCode AttentionExecution::reallocKVCache_gpu(int required_total_kv_len, cons
if (needs_realloc) {
// Optimized path: when the cache must grow, compact the old cache's valid data directly into the new cache
int new_allocated_max_len = ROUND_UP(required_total_kv_len, mExpandChunk);
std::shared_ptr<Tensor> new_past_key_tensor(
Tensor::createDevice({new_allocated_max_len, mBatch, mKvNumHead, mHeadDim},
mPrecision, Tensor::CAFFE
));
std::shared_ptr<Tensor> new_past_value_tensor(
Tensor::createDevice({mBatch, mKvNumHead, mHeadDim, new_allocated_max_len},
mPrecision, Tensor::CAFFE
));
std::shared_ptr<Tensor> new_past_key_tensor(mPrecision == 4
? Tensor::createDevice<float>({new_allocated_max_len, mBatch, mKvNumHead, mHeadDim})
: Tensor::createDevice<uint16_t>({new_allocated_max_len, mBatch, mKvNumHead, mHeadDim}));
std::shared_ptr<Tensor> new_past_value_tensor(mPrecision == 4
? Tensor::createDevice<float>({mBatch, mKvNumHead, mHeadDim, new_allocated_max_len})
: Tensor::createDevice<uint16_t>({mBatch, mKvNumHead, mHeadDim, new_allocated_max_len}));
if(!mCudaBackend->onAcquireBuffer(new_past_key_tensor.get(), Backend::STATIC)
|| !mCudaBackend->onAcquireBuffer(new_past_value_tensor.get(), Backend::STATIC)) {
return MNN::OUT_OF_MEMORY;
@ -466,8 +484,12 @@ ErrorCode AttentionExecution::reallocKVCache_gpu(int required_total_kv_len, cons
mCache->mMaxLength = new_allocated_max_len;
} else {
// If no growth is needed, a temporary buffer is required to compact the cache in place
std::shared_ptr<Tensor> temp_key(Tensor::createDevice(mCache->mPastKey->shape(), mPrecision, Tensor::CAFFE));
std::shared_ptr<Tensor> temp_value(Tensor::createDevice(mCache->mPastValue->shape(), mPrecision, Tensor::CAFFE));
std::shared_ptr<Tensor> temp_key(mPrecision == 4
? Tensor::createDevice<float>(mCache->mPastKey->shape())
: Tensor::createDevice<uint16_t>(mCache->mPastKey->shape()));
std::shared_ptr<Tensor> temp_value(mPrecision == 4
? Tensor::createDevice<float>(mCache->mPastValue->shape())
: Tensor::createDevice<uint16_t>(mCache->mPastValue->shape()));
if(!mCudaBackend->onAcquireBuffer(temp_key.get(), Backend::STATIC)
|| !mCudaBackend->onAcquireBuffer(temp_value.get(), Backend::STATIC)) {
return MNN::OUT_OF_MEMORY;
@ -508,14 +530,12 @@ ErrorCode AttentionExecution::reallocKVCache_gpu(int required_total_kv_len, cons
int old_past_len_to_copy = mCache->mPastLength;
int new_allocated_max_len = ROUND_UP(required_total_kv_len, mExpandChunk);
std::shared_ptr<Tensor> new_past_key_tensor(
Tensor::createDevice({new_allocated_max_len, mBatch, mKvNumHead, mHeadDim},
mPrecision, Tensor::CAFFE
));
std::shared_ptr<Tensor> new_past_value_tensor(
Tensor::createDevice({mBatch, mKvNumHead, mHeadDim, new_allocated_max_len},
mPrecision, Tensor::CAFFE
));
std::shared_ptr<Tensor> new_past_key_tensor(mPrecision == 4
? Tensor::createDevice<float>({new_allocated_max_len, mBatch, mKvNumHead, mHeadDim})
: Tensor::createDevice<uint16_t>({new_allocated_max_len, mBatch, mKvNumHead, mHeadDim}));
std::shared_ptr<Tensor> new_past_value_tensor(mPrecision == 4
? Tensor::createDevice<float>({mBatch, mKvNumHead, mHeadDim, new_allocated_max_len})
: Tensor::createDevice<uint16_t>({mBatch, mKvNumHead, mHeadDim, new_allocated_max_len}));
if(!mCudaBackend->onAcquireBuffer(new_past_key_tensor.get(), Backend::STATIC)
|| !mCudaBackend->onAcquireBuffer(new_past_value_tensor.get(), Backend::STATIC)) {
return MNN::OUT_OF_MEMORY;
@ -554,7 +574,9 @@ ErrorCode AttentionExecution::ensureTempBuffers_gpu(int batch, int num_head, int
bool qk_realloc = !mTempQK || mTempQK->shape() != qk_shape;
if (qk_realloc) {
if(mTempQK && mTempQK->deviceId() != 0) mCudaBackend->onReleaseBuffer(mTempQK.get(), Backend::STATIC);
mTempQK.reset(Tensor::createDevice(qk_shape, mPrecision, Tensor::CAFFE));
mTempQK.reset(mPrecision == 4
? Tensor::createDevice<float>(qk_shape)
: Tensor::createDevice<uint16_t>(qk_shape));
if(!mTempQK || !mCudaBackend->onAcquireBuffer(mTempQK.get(), Backend::STATIC)) return MNN::OUT_OF_MEMORY;
}
@ -562,7 +584,9 @@ ErrorCode AttentionExecution::ensureTempBuffers_gpu(int batch, int num_head, int
bool softmax_realloc = !mTempSoftmax || mTempSoftmax->shape() != qk_shape;
if (softmax_realloc) {
if(mTempSoftmax && mTempSoftmax->deviceId() != 0) mCudaBackend->onReleaseBuffer(mTempSoftmax.get(), Backend::STATIC);
mTempSoftmax.reset(Tensor::createDevice(qk_shape, mPrecision, Tensor::CAFFE));
mTempSoftmax.reset(mPrecision == 4
? Tensor::createDevice<float>(qk_shape)
: Tensor::createDevice<uint16_t>(qk_shape));
if(!mTempSoftmax || !mCudaBackend->onAcquireBuffer(mTempSoftmax.get(), Backend::STATIC)) return MNN::OUT_OF_MEMORY;
}
@ -573,7 +597,9 @@ ErrorCode AttentionExecution::ensureTempBuffers_gpu(int batch, int num_head, int
bool temp_k_realloc = !mTempK_current_step || mTempK_current_step->shape() != temp_k_shape;
if(temp_k_realloc){
if(mTempK_current_step && mTempK_current_step->deviceId() !=0) mCudaBackend->onReleaseBuffer(mTempK_current_step.get(), Backend::STATIC);
mTempK_current_step.reset(Tensor::createDevice(temp_k_shape, mPrecision, Tensor::CAFFE));
mTempK_current_step.reset(mPrecision == 4
? Tensor::createDevice<float>(temp_k_shape)
: Tensor::createDevice<uint16_t>(temp_k_shape) );
if(!mTempK_current_step || !mCudaBackend->onAcquireBuffer(mTempK_current_step.get(), Backend::STATIC)) return MNN::OUT_OF_MEMORY;
}
@ -582,7 +608,9 @@ ErrorCode AttentionExecution::ensureTempBuffers_gpu(int batch, int num_head, int
bool temp_v_realloc = !mTempV_current_step || mTempV_current_step->shape() != temp_v_shape;
if(temp_v_realloc){
if(mTempV_current_step && mTempV_current_step->deviceId() !=0) mCudaBackend->onReleaseBuffer(mTempV_current_step.get(), Backend::STATIC);
mTempV_current_step.reset(Tensor::createDevice(temp_v_shape, mPrecision, Tensor::CAFFE));
mTempV_current_step.reset(mPrecision == 4
? Tensor::createDevice<float>(temp_v_shape)
: Tensor::createDevice<uint16_t>(temp_v_shape));
if(!mTempV_current_step || !mCudaBackend->onAcquireBuffer(mTempV_current_step.get(), Backend::STATIC)) return MNN::OUT_OF_MEMORY;
}
}
@ -602,11 +630,6 @@ bool AttentionExecution::onClone(Backend* bn, const Op* op, Execution** dst) {
// #define DEBUG_ATTENTION_VERBOSE
static inline float half_to_float_debug(const __half& h_val) {
return __half2float(h_val);
}
// Helper to print a specified slice of a GPU tensor
void print_gpu_tensor_debug(
const MNN::Tensor* target_tensor, // the tensor to print (may live on the GPU or the CPU)
const char* name // tensor name, used for logging
@ -620,12 +643,77 @@ void print_gpu_tensor_debug(
target_tensor->print();
}
// Helper to print a specified slice of a GPU tensor
void print_gpu_tensor_debug(
int mPrecision,
const MNN::Tensor* target_tensor,
const char* name
) {
if (!target_tensor) {
printf("\n--- Tensor [%s] is null. ---\n", name);
return;
}
printf("\n--- Tensor [%s] ---\n", name);
const MNN::Tensor* tensor_to_print = nullptr;
std::unique_ptr<MNN::Tensor, decltype(&MNN::Tensor::destroy)> host_tensor_holder(nullptr, &MNN::Tensor::destroy);
if (target_tensor->deviceId() != 0) {
host_tensor_holder.reset(MNN::Tensor::createHostTensorFromDevice(target_tensor, true));
if (!host_tensor_holder) {
printf("Error: Failed to copy device tensor to host for printing.\n");
return;
}
tensor_to_print = host_tensor_holder.get();
} else {
tensor_to_print = target_tensor;
}
printf("Shape: [ ");
const auto& shape = tensor_to_print->shape();
for (int val : shape) printf("%d ", val);
printf("], ");
const int total_elements = tensor_to_print->elementSize() / mPrecision;
if (total_elements == 0) {
printf("Data: (empty)\n");
return;
}
printf("Data (all %d elements):\n", total_elements);
int elements_per_line = total_elements; // by default, no line breaks for 0-D/1-D tensors
const int dims = tensor_to_print->dimensions();
if (dims > 1) elements_per_line = tensor_to_print->length(dims - 1);
if (mPrecision == 2) {
const __half* data_ptr = tensor_to_print->host<__half>();
for (int i = 0; i < total_elements; ++i) {
printf("%8.4f ", __half2float(data_ptr[i]));
if ((i + 1) % elements_per_line == 0 && (i + 1) < total_elements) printf("\n");
}
} else if (mPrecision == 4) {
const float* data_ptr = tensor_to_print->host<float>();
for (int i = 0; i < total_elements; ++i) {
printf("%8.4f ", data_ptr[i]);
if ((i + 1) % elements_per_line == 0 && (i + 1) < total_elements) printf("\n");
}
} else {
target_tensor->print();
}
printf("\n--- End of Tensor [%s] ---\n", name);
}
// Initialize all parameters
ErrorCode AttentionExecution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
const auto* query_tensor = inputs[0]; // shape: [B, L_q, H_q, D]
const auto* key_tensor = inputs[1]; // shape: [B, L_k_new, H_kv, D]
mPrecision = query_tensor->getType(); // fetch the tensor's halide_type_t
if (mCudaBackend->useFp16()) {
mPrecision = 2;
} else {
mPrecision = 4;
}
mBatch = query_tensor->length(0);
mQuerySeqLen = query_tensor->length(1);
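The mPrecision == 4 ? createDevice<float> : createDevice<uint16_t> ternary now recurs throughout this file. A sketch of how it could be folded into one helper, assuming mPrecision is the element byte size (the helper name is hypothetical; the patch itself keeps the inline ternaries):

#include <memory>
#include <vector>
#include <MNN/Tensor.hpp>

// 4 -> float32 storage, anything else (2) -> float16 stored as uint16_t.
static std::shared_ptr<MNN::Tensor> createDeviceTensorForPrecision(int precisionBytes, const std::vector<int>& shape) {
    return std::shared_ptr<MNN::Tensor>(precisionBytes == 4
        ? MNN::Tensor::createDevice<float>(shape)
        : MNN::Tensor::createDevice<uint16_t>(shape));
}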
@ -671,7 +759,7 @@ ErrorCode AttentionExecution::onExecute(const std::vector<Tensor *> &inputs, con
const Tensor* mask_input_tensor = mHasMask ? inputs[3] : nullptr;
auto final_output_tensor = outputs[0];
// By default qk_kernel expects a [1, 1, L_q, L_k] mask, but some models pass a [1, 1, 1, 1] mask whose value is -0.00, meaning "no mask"
// By default qk_kernel expects an [L_q, L_k] mask, but some models pass a [1, 1, 1, 1] mask whose value is -0.00, meaning "no mask"
if (mIsKVCacheEnabled && mHasMask && mask_input_tensor && mask_input_tensor->elementSize() == 1) {
mHasMask = false;
}
@ -758,15 +846,15 @@ ErrorCode AttentionExecution::onExecute(const std::vector<Tensor *> &inputs, con
// Copy the new K and V into the cache
dim3 copy_blockDim(32, 8, 1); // tentative block size: 256 threads
dim3 copy_gridDim(UP_DIV(mHeadDim, (int)copy_blockDim.x),
UP_DIV(mNewKvSeqLen, (int)copy_blockDim.y),
UP_DIV(mBatch * mKvNumHead, (int)copy_blockDim.z));
if (mPrecision.bytes() == 4) { // float32
dim3 copy_gridDim(UP_DIV(mHeadDim, copy_blockDim.x),
UP_DIV(mNewKvSeqLen, copy_blockDim.y),
UP_DIV(mBatch * mKvNumHead, copy_blockDim.z));
if (mPrecision == 4) { // float32
copy_kv_to_cache_kernel<float><<<copy_gridDim, copy_blockDim, 0, stream>>>(
getTensorDevicePtr<float>(key_input_tensor), getTensorDevicePtr<float>(value_input_tensor),
getTensorDevicePtr<float>(mCache->mPastKey.get()), getTensorDevicePtr<float>(mCache->mPastValue.get()),
mBatch, mNewKvSeqLen, mKvNumHead, mHeadDim, param_cpu.past_kv_len, mCache->mMaxLength);
} else if (mPrecision.bytes() == 2) { // float16
} else if (mPrecision == 2) { // float16
copy_kv_to_cache_kernel<__half><<<copy_gridDim, copy_blockDim, 0, stream>>>(
getTensorDevicePtr<__half>(key_input_tensor), getTensorDevicePtr<__half>(value_input_tensor),
getTensorDevicePtr<__half>(mCache->mPastKey.get()), getTensorDevicePtr<__half>(mCache->mPastValue.get()),
@ -778,10 +866,10 @@ ErrorCode AttentionExecution::onExecute(const std::vector<Tensor *> &inputs, con
effective_value_cache_ptr = getTensorDevicePtr(mCache->mPastValue.get());
#ifdef DEBUG_ATTENTION_VERBOSE
if (current_total_kv_len_for_qk >= 0) {
if (current_total_kv_len_for_qk >= 20) {
// cudaStreamSynchronize(stream); // make sure copy_kv_to_cache_kernel has finished
// if (mCache && mCache->mPastKey) print_gpu_tensor_debug(mCache->mPastKey.get(), "Key Cache (After copy_kv_to_cache)"); // L_kv_alloc, B, H_kv, D
// if (mCache && mCache->mPastValue) print_gpu_tensor_debug(mCache->mPastValue.get(), "Value Cache (After copy_kv_to_cache)"); // B, H_kv, D, L_kv_alloc
// if (mCache && mCache->mPastKey) print_gpu_tensor_debug(mPrecision, mCache->mPastKey.get(), "Key Cache (After copy_kv_to_cache)"); // L_kv_alloc, B, H_kv, D
// if (mCache && mCache->mPastValue) print_gpu_tensor_debug(mPrecision, mCache->mPastValue.get(), "Value Cache (After copy_kv_to_cache)"); // B, H_kv, D, L_kv_alloc
}
#endif
} else { // No KV cache: store K/V in the current step's temporary tensors
@ -796,15 +884,15 @@ ErrorCode AttentionExecution::onExecute(const std::vector<Tensor *> &inputs, con
if(temp_err != MNN::NO_ERROR) return temp_err;
dim3 copy_blockDim(32, 8, 1);
dim3 copy_gridDim(UP_DIV(mHeadDim, (int)copy_blockDim.x),
UP_DIV(mNewKvSeqLen, (int)copy_blockDim.y),
UP_DIV(mBatch * mKvNumHead, (int)copy_blockDim.z));
if (mPrecision.bytes() == 4) {
dim3 copy_gridDim(UP_DIV(mHeadDim, copy_blockDim.x),
UP_DIV(mNewKvSeqLen, copy_blockDim.y),
UP_DIV(mBatch * mKvNumHead, copy_blockDim.z));
if (mPrecision == 4) {
copy_kv_to_cache_kernel<float><<<copy_gridDim, copy_blockDim, 0, stream>>>(
getTensorDevicePtr<float>(key_input_tensor), getTensorDevicePtr<float>(value_input_tensor),
getTensorDevicePtr<float>(mTempK_current_step.get()), getTensorDevicePtr<float>(mTempV_current_step.get()),
mBatch, mNewKvSeqLen, mKvNumHead, mHeadDim, 0, allocated_kv_len_for_value_stride);
} else if (mPrecision.bytes() == 2) {
} else if (mPrecision == 2) {
copy_kv_to_cache_kernel<__half><<<copy_gridDim, copy_blockDim, 0, stream>>>(
getTensorDevicePtr<__half>(key_input_tensor), getTensorDevicePtr<__half>(value_input_tensor),
getTensorDevicePtr<__half>(mTempK_current_step.get()), getTensorDevicePtr<__half>(mTempV_current_step.get()),
@ -817,8 +905,8 @@ ErrorCode AttentionExecution::onExecute(const std::vector<Tensor *> &inputs, con
#ifdef DEBUG_ATTENTION_VERBOSE
// cudaStreamSynchronize(stream); // make sure copy_kv_to_cache_kernel has finished
// if (mTempK_current_step) print_gpu_tensor_debug(mTempK_current_step.get(), "Temp K Current Step (Copied from Input K)");
// if (mTempV_current_step) print_gpu_tensor_debug(mTempV_current_step.get(), "Temp V Current Step (Copied from Input V)");
// if (mTempK_current_step) print_gpu_tensor_debug(mPrecision, mTempK_current_step.get(), "Temp K Current Step (Copied from Input K)");
// if (mTempV_current_step) print_gpu_tensor_debug(mPrecision, mTempV_current_step.get(), "Temp V Current Step (Copied from Input V)");
#endif
}
@ -844,16 +932,16 @@ ErrorCode AttentionExecution::onExecute(const std::vector<Tensor *> &inputs, con
// QK Kernel
dim3 qk_blockDim(16, 16, 1); // 256 threads for now
dim3 qk_gridDim(UP_DIV(current_total_kv_len_for_qk, (int)qk_blockDim.x),
UP_DIV(current_piece_actual_len, (int)qk_blockDim.y),
UP_DIV(mBatch * mNumHead, (int)qk_blockDim.z));
dim3 qk_gridDim(UP_DIV(current_total_kv_len_for_qk, qk_blockDim.x),
UP_DIV(current_piece_actual_len, qk_blockDim.y),
UP_DIV(mBatch * mNumHead, qk_blockDim.z));
if (mPrecision.bytes() == 4) {
if (mPrecision == 4) {
qk_kernel<float><<<qk_gridDim, qk_blockDim, 0, stream>>>(
getTensorDevicePtr<float>(query_input_tensor), static_cast<const float*>(effective_key_cache_ptr),
getTensorDevicePtr<float>(mTempQK.get()), mask_ptr_device,
mParam_gpu, q_seq_offset, mHasMask, mIsAddMask);
} else if (mPrecision.bytes() == 2) {
} else if (mPrecision == 2) {
qk_kernel<__half><<<qk_gridDim, qk_blockDim, 0, stream>>>(
getTensorDevicePtr<__half>(query_input_tensor), static_cast<const __half*>(effective_key_cache_ptr),
getTensorDevicePtr<__half>(mTempQK.get()), mask_ptr_device,
@ -862,28 +950,65 @@ ErrorCode AttentionExecution::onExecute(const std::vector<Tensor *> &inputs, con
checkKernelErrors;
// Softmax Kernel
int softmax_total_rows_for_piece = mBatch * mNumHead * current_piece_actual_len;
dim3 softmax_blockDim(256, 1, 1); // one thread per row
dim3 softmax_gridDim(UP_DIV(softmax_total_rows_for_piece, (int)softmax_blockDim.x), 1, 1);
if (mPrecision.bytes() == 4) {
softmax_kernel<float><<<softmax_gridDim, softmax_blockDim, 0, stream>>>(
getTensorDevicePtr<float>(mTempQK.get()), getTensorDevicePtr<float>(mTempSoftmax.get()), mParam_gpu, current_piece_actual_len);
} else if (mPrecision.bytes() == 2) {
softmax_kernel<__half><<<softmax_gridDim, softmax_blockDim, 0, stream>>>(
getTensorDevicePtr<__half>(mTempQK.get()), getTensorDevicePtr<__half>(mTempSoftmax.get()), mParam_gpu, current_piece_actual_len);
} else { return MNN::NOT_SUPPORT; }
// int softmax_total_rows_for_piece = mBatch * mNumHead * current_piece_actual_len;
// dim3 softmax_blockDim(256, 1, 1); // one thread per row
// dim3 softmax_gridDim(UP_DIV(softmax_total_rows_for_piece, (int)softmax_blockDim.x), 1, 1);
// if (mPrecision == 4) {
// softmax_kernel<float><<<softmax_gridDim, softmax_blockDim, 0, stream>>>(
// getTensorDevicePtr<float>(mTempQK.get()), getTensorDevicePtr<float>(mTempSoftmax.get()), mParam_gpu, current_piece_actual_len);
// } else if (mPrecision == 2) {
// softmax_kernel<__half><<<softmax_gridDim, softmax_blockDim, 0, stream>>>(
// getTensorDevicePtr<__half>(mTempQK.get()), getTensorDevicePtr<__half>(mTempSoftmax.get()), mParam_gpu, current_piece_actual_len);
// } else { return MNN::NOT_SUPPORT; }
// checkKernelErrors;
const int axis = current_total_kv_len_for_qk;
const int inside = 1;
const int outside = mBatch * mNumHead * current_piece_actual_len;
const int count = outside * inside;
const void* qk_scores_ptr = getTensorDevicePtr(mTempQK.get());
void* softmax_result_ptr = getTensorDevicePtr(mTempSoftmax.get());
// Choose and launch the best kernel based on the axis length and the precision
// Launch configuration: each thread block handles one softmax row
// Grid dimension = total number of rows (count); block dimension = threads used to process one row in parallel
if (mPrecision == 4) { // FP32
const auto* input_ptr = static_cast<const float*>(qk_scores_ptr);
auto* output_ptr = static_cast<float*>(softmax_result_ptr);
if (axis <= 32) {
// For short sequences, use warp-level primitives for maximum efficiency
SOFTMAX_WARP_32<float><<<count, 32, 0, stream>>>(input_ptr, output_ptr, inside, axis, outside, count);
} else {
// For long sequences, use a shared-memory block-level parallel reduction
constexpr int threads_per_block = 256;
const int calc_multi_num = UP_DIV(axis, threads_per_block);
SOFTMAX_AXIS_REDUCE<float><<<count, threads_per_block, 0, stream>>>(input_ptr, output_ptr, inside, axis, threads_per_block, calc_multi_num, outside, count);
}
} else { // FP16
const auto* input_ptr = static_cast<const __half*>(qk_scores_ptr);
auto* output_ptr = static_cast<__half*>(softmax_result_ptr);
if (axis <= 32) {
SOFTMAX_WARP_32<__half><<<count, 32, 0, stream>>>(input_ptr, output_ptr, inside, axis, outside, count);
} else {
constexpr int threads_per_block = 256;
const int calc_multi_num = UP_DIV(axis, threads_per_block);
SOFTMAX_AXIS_REDUCE<__half><<<count, threads_per_block, 0, stream>>>(input_ptr, output_ptr, inside, axis, threads_per_block, calc_multi_num, outside, count);
}
}
checkKernelErrors;
// QKV Kernel
dim3 qkv_blockDim(32, 8, 1); // 256 threads for now
dim3 qkv_gridDim(UP_DIV(mHeadDim, (int)qkv_blockDim.x),
UP_DIV(current_piece_actual_len, (int)qkv_blockDim.y),
UP_DIV(mBatch * mNumHead, (int)qkv_blockDim.z));
if (mPrecision.bytes() == 4) {
dim3 qkv_gridDim(UP_DIV(mHeadDim, qkv_blockDim.x),
UP_DIV(current_piece_actual_len, qkv_blockDim.y),
UP_DIV(mBatch * mNumHead, qkv_blockDim.z));
if (mPrecision == 4) {
qkv_kernel<float><<<qkv_gridDim, qkv_blockDim, 0, stream>>>(
getTensorDevicePtr<float>(mTempSoftmax.get()), static_cast<const float*>(effective_value_cache_ptr),
getTensorDevicePtr<float>(final_output_tensor), mParam_gpu, q_seq_offset);
} else if (mPrecision.bytes() == 2) {
} else if (mPrecision == 2) {
qkv_kernel<__half><<<qkv_gridDim, qkv_blockDim, 0, stream>>>(
getTensorDevicePtr<__half>(mTempSoftmax.get()), static_cast<const __half*>(effective_value_cache_ptr),
getTensorDevicePtr<__half>(final_output_tensor), mParam_gpu, q_seq_offset);
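For reference, SOFTMAX_WARP_32 and SOFTMAX_AXIS_REDUCE both parallelize the usual numerically stable row softmax; per the comments above, the warp variant does the max/sum reductions with warp-level primitives and the reduce variant with shared memory. A plain single-row C++ version of the math (host-side reference only, not the CUDA kernels):

#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> softmaxRow(const std::vector<float>& scores) {
    const float maxVal = *std::max_element(scores.begin(), scores.end());
    std::vector<float> out(scores.size());
    float sumExp = 0.f;
    for (size_t i = 0; i < scores.size(); ++i) {
        out[i] = std::exp(scores[i] - maxVal); // subtract the row max to avoid overflow
        sumExp += out[i];
    }
    for (float& v : out) v /= sumExp; // normalize
    return out;
}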
@ -916,23 +1041,22 @@ ErrorCode AttentionExecution::onExecute(const std::vector<Tensor *> &inputs, con
printf("--------------------------------------------------------------------\n");
if (current_total_kv_len_for_qk >= 20) {
// Print the mask (if present)
// The mask is usually [1, 1, Lq_full, Lk_total]
// Print the mask
if (mHasMask && mask_input_tensor) {
// print_gpu_tensor_debug(mask_input_tensor, "Mask Input Tensor (Relevant Slice view)");
// print_gpu_tensor_debug(mask_input_tensor, "Mask Input Tensor");
}
// Print the QK^T scores; mTempQK shape: [B, H_q, L_q_piece, L_k_total]
// print_gpu_tensor_debug(mTempQK.get(), "QK^T Scores (mTempQK for current piece)");
// print_gpu_tensor_debug(mPrecision, mTempQK.get(), "QK^T Scores (mTempQK for current piece)");
// Print the softmax probabilities; mTempSoftmax shape: [B, H_q, L_q_piece, L_k_total]
// print_gpu_tensor_debug(mTempSoftmax.get(), "Softmax Probs (mTempSoftmax for current piece)");
// print_gpu_tensor_debug(mPrecision, mTempSoftmax.get(), "Softmax Probs (mTempSoftmax for current piece)");
// Print the attention output; final_output_tensor shape: [B, L_q_full, H_q, D] or [1, 1, N]
if (final_output_tensor->dimensions()==3) { // 3-D, e.g. [1, 1, N]
// print_gpu_tensor_debug(final_output_tensor, "Attention Output Slice (Final)");
// print_gpu_tensor_debug(mPrecision, final_output_tensor, "Attention Output Slice (Final)");
} else { // assumed to be [B, Lq_full, Hq, D]
// print_gpu_tensor_debug(final_output_tensor, "Attention Output Slice (Final)");
// print_gpu_tensor_debug(mPrecision, final_output_tensor, "Attention Output Slice (Final)");
}
printf("========================= END PIECE DEBUG DUMP =======================\n");

View File

@ -80,7 +80,7 @@ private:
KVMeta* mMeta = nullptr;
const int mExpandChunk = 64; // chunk size used when growing the KV cache
halide_type_t mPrecision; // precision (float or half)
int mPrecision; // precision (float or half)
};
#endif // MNN_SUPPORT_TRANSFORMER_FUSE

View File

@ -39,7 +39,7 @@ public:
#ifdef MNN_LOW_MEMORY
auto conv2dParams = op->main_as_Convolution2D();
bool isMemoryLowWeightOnlyQuant = (conv2dParams->quanParameter() != nullptr && conv2dParams->quanParameter()->buffer() != nullptr);
bool isMemoryLowWeightOnlyQuant = conv2dParams->quanParameter() && (conv2dParams->external() || conv2dParams->quanParameter()->buffer());
isMemoryLowWeightOnlyQuant = isMemoryLowWeightOnlyQuant && (static_cast<CUDABackend*>(backend)->getMemoryMode() == BackendConfig::Memory_Low);
isMemoryLowWeightOnlyQuant = isMemoryLowWeightOnlyQuant && ConvFpAIntBExecution::isValid(op->main_as_Convolution2D(), backend);
if (isMemoryLowWeightOnlyQuant) {

View File

@ -124,20 +124,19 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con
mFast = false;
break;
}
if(!_directBlitC4(slice0, slice, output)) {
mFast = false;
if(!_directBlitC4(slice0, slice, output)) {
mFast = false;
break;
}
if (!OpCommonUtils::canBlitFast(slice, output, pack, false, true)) {
mFast = false;
mFast = false;
break;
}
}
//MNN_PRINT("raster fast:%d regionNum:%d\n\n\n", mFast, des->regions.size());
if (mFast) {
int dstStep = 1;
for (int i=0; i< des->regions.size(); ++i) {
int srcStep = 1;
int dstStep = 1;
auto& slice = des->regions[i];
if(slice.dst.offset / (slice.size[2] * slice.size[1]) >= 1) {
int batchChannel = slice.dst.offset / (slice.size[1] * slice.size[2]) + 1;
@ -147,6 +146,11 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con
int tmp = slice.dst.stride[0] / slice.src.stride[0];
dstStep = dstStep > tmp ? dstStep : tmp;
}
}
for (int i=0; i< des->regions.size(); ++i) {
int srcStep = 1;
auto& slice = des->regions[i];
if(slice.src.offset / (slice.size[2] * slice.size[1]) >= 1) {
int batchChannel = slice.src.offset / (slice.size[1] * slice.size[2]) + 1;
srcStep = srcStep > batchChannel ? srcStep : batchChannel;

View File

@ -18,6 +18,13 @@
namespace MNN {
namespace CUDA {
template <typename T>
__global__ void SOFTMAX(const T *input, T *output, const int inside, const int axis, const int outside, const int count);
template <typename T>
__global__ void SOFTMAX_WARP_32(const T *input, T *output, const int inside, const int axis, const int outside, const int count);
template <typename T>
__global__ void SOFTMAX_AXIS_REDUCE(const T *input, T *output, const int inside, const int axis, const int per_block_size, const int calc_multi_num, const int outside, const int count);
class SoftmaxExecution : public Execution {
public:
SoftmaxExecution(int axis, Backend *backend);

View File

@ -10,7 +10,6 @@
#include "backend/cuda/core/CUDABackend.hpp"
#include "core/Execution.hpp"
#include "../CutlassGemmParam.hpp"
#include "../MNNCUDADefine.hpp"
#include "../MNNCUDAFunction.cuh"
#include "../cutlass_common/CutlassConvCommonExecution.hpp"
@ -27,12 +26,16 @@ public:
void* mScale;
void* mOffset;
void* mBias;
int mQuanC;
std::shared_ptr<Tensor> weightTensor;
std::shared_ptr<Tensor> scaleTensor;
std::shared_ptr<Tensor> offsetTensor;
std::shared_ptr<Tensor> biasTensor;
Backend* mBackend = nullptr;
bool mIsWeightInt4 = false;
std::shared_ptr<Tensor> mSumBQTensor;
void* mSumBQ = nullptr;
};
static bool isValid(const Convolution2D* conv, Backend* backend);
ConvFpAIntBExecution(Backend* backend, const MNN::Op* op, std::shared_ptr<Resource> res);
@ -43,6 +46,9 @@ public:
private:
std::shared_ptr<Resource> mResource;
std::shared_ptr<Tensor> mDequantFilterTensor;
void* mDequantFilter = nullptr;
};
} // namespace CUDA

View File

@ -100,7 +100,7 @@ void AttentionBufExecution::_init() {
mCache.reset(new SharedCache);
auto mtbn = static_cast<MetalBackend *>(backend());
auto context = (__bridge MNNMetalContext *)mtbn->context();
mMeta = (KVMeta*)(mtbn->getRuntime()->pMeta);
mMeta = (KVMeta*)(mtbn->getMetaPtr());
mParamQKV = [context newDeviceBuffer:sizeof(Param) access:CPUWriteOnly];
mParamSoftmax = [context newDeviceBuffer:4 * sizeof(int) access:CPUWriteOnly];

View File

@ -269,6 +269,7 @@ private:
bool mUseFloatAsFp16;
bool mIsIphone = false;
BufferAllocator* mCurrentAllocator = nullptr;
};

View File

@ -1219,7 +1219,9 @@ Backend* MetalRuntime::onCreate(const BackendConfig* config, Backend* origin) co
memory = config->memory;
}
bool useFp16AsFp32 = precision != BackendConfig::Precision_High;
return new MetalBackend(mStatic, this, useFp16AsFp32, memory);
auto backend = new MetalBackend(mStatic, this, useFp16AsFp32, memory);
backend->setMetaPtr(pMeta);
return backend;
}
void MetalRuntime::onGabageCollect(int level) {

View File

@ -574,6 +574,7 @@ bool convertBufferToBuffer(Tensor *input, Tensor *output, OpenCLRuntime *runtime
return true;
}
#ifdef __ANDROID__
bool convertBetweenAHDandCLmem(const Tensor *input, const Tensor *output, OpenCLRuntime *runtime, int precision, int memType, bool toDevice, bool toHost) {
std::set<std::string> buildOptions;
auto srcDimensionFormat = TensorUtils::getDescribe(input)->dimensionFormat;
@ -584,25 +585,46 @@ bool convertBetweenAHDandCLmem(const Tensor *input, const Tensor *output, OpenCL
buildOptions.emplace("-DINPUT_FORMAT=" + std::to_string(srcDimensionFormat));
buildOptions.emplace("-DOUTPUT_FORMAT=" + std::to_string(dstDimensionFormat));
std::vector<int> outputShape;
std::shared_ptr<KernelWrap> kernelW;
if(toDevice){
buildOptions.emplace("-DSHARED_TO_CL");
kernelW = runtime->buildKernelWithCache("glmem_convert", "gl_to_cl", buildOptions, precision, nullptr, output);
outputShape = tensorShapeFormat(output);
} else if(toHost){
buildOptions.emplace("-DCL_TO_SHARED");
kernelW = runtime->buildKernelWithCache("glmem_convert", "cl_to_gl", buildOptions, precision, input, nullptr);
outputShape = tensorShapeFormat(input);
}else{
MNN_PRINT("convertGLMemBetweenCLmem only support toDevice or toHost!\n");
return false;
}
std::vector<int> outputShape = toDevice ? tensorShapeFormat(output): tensorShapeFormat(input);
int shape[4] = {outputShape[0], outputShape[3], outputShape[1], outputShape[2]};//N C H W
uint32_t gws[3] = {static_cast<uint32_t>(UP_DIV(shape[3], 4)),
static_cast<uint32_t>(UP_DIV(shape[1], 4)),
static_cast<uint32_t>(shape[0] * shape[2])};
std::shared_ptr<KernelWrap> kernelW;
int format = AHARDWAREBUFFER_FORMAT_R8G8B8A8_UNORM;
int stride = shape[3];
AHardwareBuffer_Desc Desc = {};
if(OpenCLSymbolsOperator::getOpenclSymbolsPtr()->isSupportAhardwareBufferFunc()){
if(toDevice){
MNNAHardwareBuffer_describe((AHardwareBuffer*)(((CLSharedMemReleaseBuffer*)TensorUtils::getSharedMem(input))->getSharedId()), &Desc);
}else{
MNNAHardwareBuffer_describe((AHardwareBuffer*)(((CLSharedMemReleaseBuffer*)TensorUtils::getSharedMem(output))->getSharedId()), &Desc);
}
format = Desc.format;
stride = Desc.stride;
}
if(format == AHARDWAREBUFFER_FORMAT_R8G8B8A8_UNORM){
if(toDevice){
buildOptions.emplace("-DSHARED_TO_CL");
kernelW = runtime->buildKernelWithCache("glmem_convert", "gl_to_cl", buildOptions, precision, nullptr, output);
} else if(toHost){
buildOptions.emplace("-DCL_TO_SHARED");
kernelW = runtime->buildKernelWithCache("glmem_convert", "cl_to_gl", buildOptions, precision, input, nullptr);
}
}else if(format == AHARDWAREBUFFER_FORMAT_Y8Cb8Cr8_420){
if(toDevice){
buildOptions.emplace("-DSHARED_TO_CL");
kernelW = runtime->buildKernelWithCache("glmem_convert", "yuv_to_cl", buildOptions, precision, nullptr, output);
} else if(toHost){
buildOptions.emplace("-DCL_TO_SHARED");
kernelW = runtime->buildKernelWithCache("glmem_convert", "cl_to_yuv", buildOptions, precision, input, nullptr);
}
}else{
MNN_PRINT("convertGLMemBetweenCLmem only support AHARDWAREBUFFER_FORMAT_R8G8B8A8_UNORM or AHARDWAREBUFFER_FORMAT_Y8Cb8Cr8_420!\n");
return false;
}
auto Kernel = kernelW->get();
uint32_t idx = 0;
cl_int ret = CL_SUCCESS;
@ -629,6 +651,7 @@ bool convertBetweenAHDandCLmem(const Tensor *input, const Tensor *output, OpenCL
}
}
ret |= Kernel.setArg(idx++, sizeof(shape), shape);
ret |= Kernel.setArg(idx++, stride);
MNN_CHECK_CL_SUCCESS(ret, "setArg glmem_convert");
const uint32_t maxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(kernelW));
@ -647,6 +670,7 @@ bool convertBetweenAHDandCLmem(const Tensor *input, const Tensor *output, OpenCL
MNN_CHECK_CL_SUCCESS(res, "glmem_convert");
return true;
}
#endif
} // namespace OpenCL
} // namespace MNN
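A minimal sketch of the format probe the new branch depends on, written against the NDK API directly; the code above goes through MNN's dynamically resolved MNNAHardwareBuffer_describe wrapper instead, so this is only an illustration of the idea:

#include <android/hardware_buffer.h>

// Pick the conversion kernel family for a shared AHardwareBuffer.
// Returns true for the RGBA path; AHARDWAREBUFFER_FORMAT_Y8Cb8Cr8_420 selects the yuv_to_cl / cl_to_yuv kernels instead.
bool isRgbaSharedBuffer(AHardwareBuffer* buffer) {
    AHardwareBuffer_Desc desc = {};
    AHardwareBuffer_describe(buffer, &desc); // fills format, stride, width, height, ...
    return desc.format == AHARDWAREBUFFER_FORMAT_R8G8B8A8_UNORM;
}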

View File

@ -34,7 +34,9 @@ bool convertNC4HW4BufferBetweenNC16HW16Buffer(const Tensor *input, Tensor *outpu
#endif
bool convertBufferToBuffer(Tensor *input, Tensor *output, OpenCLRuntime *runtime, int input_precision, int output_precision, int backend_precison, bool toDevice, bool toHost, bool needWait = false, bool svmFlag = false);
#ifdef __ANDROID__
bool convertBetweenAHDandCLmem(const Tensor *input, const Tensor *output, OpenCLRuntime *runtime, int precision, int memType, bool toDevice, bool toHost);
#endif
class BufferConvertor {
public:

Some files were not shown because too many files have changed in this diff.