mirror of https://github.com/alibaba/MNN.git
Merge pull request #3747 from alibaba/feature/sync
MNN:Sync: Sync Internal 3.2.2
This commit is contained in:
commit
a739ea5870
|
@ -78,6 +78,7 @@ option(MNN_SUPPORT_BF16 "Enable MNN's bf16 op" OFF)
|
|||
option(MNN_LOW_MEMORY "Build MNN support low memory for weight quant model." OFF)
|
||||
option(MNN_CPU_WEIGHT_DEQUANT_GEMM "Build MNN CPU weight dequant related gemm kernels." OFF)
|
||||
option(MNN_BUILD_AUDIO "Build audio api in MNN." OFF)
|
||||
option(MNN_SME2 "Use Arm sme2 instructions" ON)
|
||||
|
||||
if (MNN_BUILD_MINI)
|
||||
set(MNN_SKIPBUILD_GEOMETRY ON)
|
||||
|
|
|
@ -100,7 +100,7 @@ ErrorCode CPULSTM::onResize(const std::vector<Tensor *> &inputs, const std::vect
|
|||
auto temp = tempBuffer->host<float>();
|
||||
auto dest = dst + n * UP_DIV(timeSteps, hP) * numFeatures * hP;
|
||||
MNNUnpackC4(temp, source, numFeatures, timeSteps);
|
||||
MNNPackForMatMul_B(dest, temp, timeSteps, numFeatures, true);
|
||||
MNNPackForMatMul_B(dest, temp, timeSteps, 1, numFeatures, true);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -159,7 +159,8 @@ void MNNPackC4ForMatMul_A_BF16(float* destOrigin, float const** sourceGroup, con
|
|||
return;
|
||||
}
|
||||
|
||||
void MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t l, bool transpose) {
|
||||
void MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose) {
|
||||
auto l = kernelsize * ic;
|
||||
auto hP = h / 4;
|
||||
auto hR = hP * 4;
|
||||
if (hR != h) {
|
||||
|
|
|
@ -71,7 +71,7 @@ static void _winograd(const DeconvolutionWithStride::ComputeUnit& unit, int thre
|
|||
el[3] = 0;
|
||||
size_t parameters[6];
|
||||
parameters[0] = eP * sizeof(float);
|
||||
parameters[1] = ic;
|
||||
parameters[1] = ROUND_UP(ic, lP);
|
||||
parameters[2] = oc;
|
||||
parameters[3] = eP * 4 * sizeof(float);
|
||||
parameters[4] = 0;
|
||||
|
@ -130,7 +130,7 @@ static void _gemmAndIm2col(const DeconvolutionWithStride::ComputeUnit& unit, int
|
|||
el[3] = 0;
|
||||
size_t parameters[6];
|
||||
parameters[0] = eP * sizeof(float);
|
||||
parameters[1] = ic;
|
||||
parameters[1] = ROUND_UP(ic, lP);
|
||||
parameters[2] = oc;
|
||||
parameters[3] = eP * 4 * sizeof(float);
|
||||
parameters[4] = 0;
|
||||
|
|
|
@ -62,6 +62,84 @@ static void dumpCmd(const Command* cmd) {
|
|||
MNN_PRINT("}\n");
|
||||
}
|
||||
|
||||
void mergeConvolutionAndPrelu(Node* root, MNNForwardType forwardType){
|
||||
if (root->cmd->op != nullptr && root->cmd->op->type() == OpType_Convolution && root->succ.size() == 1) {
|
||||
auto child = root->succ[0];
|
||||
if(child->cmd->op->type() == OpType_PReLU){
|
||||
if(root->cmd->op->externalPath() != nullptr){
|
||||
return;
|
||||
}
|
||||
std::shared_ptr<Command> cmdPlugin;
|
||||
auto inputs = root->cmd->inputs;
|
||||
auto outputs = root->cmd->outputs;
|
||||
auto convOp = root->cmd->op->main_as_Convolution2D();
|
||||
if(convOp->quanParameter() != nullptr || convOp->symmetricQuan() != nullptr || convOp->sparseParameter() != nullptr || convOp->external() != nullptr || convOp->common()->outputCount() != child->cmd->op->main_as_PRelu()->slopeCount()){
|
||||
return;
|
||||
}
|
||||
std::unique_ptr<OpT> fuseOp(new OpT);
|
||||
fuseOp->type = OpType_Extra;
|
||||
fuseOp->name = root->cmd->op->name()->str();
|
||||
ExtraT* extra_param = new ExtraT;
|
||||
extra_param->type = "ExtraConvolution2DPrelu";
|
||||
extra_param->attr.resize(2);
|
||||
// copy convolution2D param
|
||||
AttributeT* convAtr = new AttributeT;
|
||||
BlobT* convParamBlob = new BlobT;
|
||||
{
|
||||
std::unique_ptr<Convolution2DT> convolutionParam(convOp->UnPack());
|
||||
flatbuffers::FlatBufferBuilder builder;
|
||||
auto lastOffset = Convolution2D::Pack(builder, convolutionParam.get());
|
||||
builder.Finish(lastOffset);
|
||||
|
||||
const uint8_t* buffer_ptr = builder.GetBufferPointer();
|
||||
const size_t size = builder.GetSize();
|
||||
convParamBlob->uint8s.resize(size);
|
||||
::memcpy(convParamBlob->uint8s.data(), buffer_ptr, size);
|
||||
}
|
||||
convAtr->tensor.reset(convParamBlob);
|
||||
extra_param->attr[0].reset(convAtr);
|
||||
|
||||
// copy prelu param
|
||||
AttributeT* preluAtr = new AttributeT;
|
||||
BlobT* preluParamBlob = new BlobT;
|
||||
{
|
||||
std::unique_ptr<PReluT> preluParam(child->cmd->op->main_as_PRelu()->UnPack());
|
||||
flatbuffers::FlatBufferBuilder builder;
|
||||
auto lastOffset = PRelu::Pack(builder, preluParam.get());
|
||||
builder.Finish(lastOffset);
|
||||
const uint8_t* buffer_ptr = builder.GetBufferPointer();
|
||||
const size_t size = builder.GetSize();
|
||||
preluParamBlob->uint8s.resize(size);
|
||||
::memcpy(preluParamBlob->uint8s.data(), buffer_ptr, size);
|
||||
}
|
||||
preluAtr->tensor.reset(preluParamBlob);
|
||||
extra_param->attr[1].reset(preluAtr);
|
||||
|
||||
fuseOp->main.type = OpParameter_Extra;
|
||||
fuseOp->main.value = extra_param;
|
||||
flatbuffers::FlatBufferBuilder builder;
|
||||
auto lastOffset = Op::Pack(builder, fuseOp.get());
|
||||
builder.Finish(lastOffset);
|
||||
cmdPlugin = GeometryComputerUtils::makeCommand(builder, inputs, outputs);
|
||||
|
||||
root->cmd->op = cmdPlugin->op;
|
||||
root->cmd->inputs = cmdPlugin->inputs;
|
||||
root->cmd->outputs = cmdPlugin->outputs;
|
||||
root->cmd->buffer = cmdPlugin->buffer;
|
||||
child->cmd->op = nullptr;
|
||||
child->cmd->buffer.reset();
|
||||
for(auto &childNode : child->succ){
|
||||
for(auto &input : childNode->cmd->inputs){
|
||||
if(input == child->cmd->outputs[0]){
|
||||
input = root->cmd->outputs[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
root->succ = child->succ;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// is legal fused type
|
||||
bool isLegal(Command* cmd, MNNForwardType forwardType) {
|
||||
auto type = cmd->op->type();
|
||||
|
@ -369,6 +447,20 @@ bool opFuse(std::vector<Schedule::OpCacheInfo>& infos, MNNForwardType type, Back
|
|||
graph.push_back(std::move(node));
|
||||
}
|
||||
}
|
||||
|
||||
if(type == MNN_FORWARD_OPENCL){
|
||||
for(int i = 0; i < graph.size(); ++i){
|
||||
mergeConvolutionAndPrelu(graph[i].get(), type);
|
||||
}
|
||||
for(auto iter = graph.begin(); iter != graph.end();){
|
||||
if(iter->get()->cmd->op == nullptr){
|
||||
iter = graph.erase(iter);
|
||||
}else{
|
||||
++iter;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::queue<Node*> postDominateNodeQueue;
|
||||
// build dominate tree
|
||||
for (int i = static_cast<int>(graph.size()) - 1; i >= 0; i--) {
|
||||
|
|
|
@ -19,7 +19,8 @@ MNN使用CMake构建项目,CMake中的宏定义列表如下:
|
|||
| MNN_SUPPORT_QUNAT_EXTEND | 是否编译非核心算子的量化版本,默认为`ON` |
|
||||
| MNN_SUPPORT_DEPRECATED_OP | 是否支持Tflite的量化算子等已经废弃的算子,用于兼容历史模型(1.1.0版本之前),默认为`OFF` |
|
||||
| MNN_SUPPORT_DEPRECATED_OPV2 | 是否编译MNN更新到3.0之后已经废弃的算子,用于兼容历史模型(3.0.0版本之前),比如 Convolution3D 和 ConvTranspose3D在3.0.0 版本之后改由模型转换器转化为对应2D算子,不再需要运行时支持,默认为`ON` |
|
||||
| MNN_REDUCE_SIZE | 是否裁剪MNN库大小,去除求导相关算子,减少优化策略,默认为`OFF` ,开启时,MNN_SUPPORT_QUNAT_EXTEND / MNN_SUPPORT_DEPRECATED_OP / MNN_SUPPORT_DEPRECATED_OPV2 都会设成 OFF|
|
||||
| MNN_REDUCE_SIZE | 是否裁剪MNN库大小,去除求导相关算子,减少优化策略,默认为`OFF` ,开启时,MNN_SUPPORT_QUANT_EXTEND / MNN_SUPPORT_DEPRECATED_OP / MNN_SUPPORT_DEPRECATED_OPV2 都会设成 OFF|
|
||||
| MNN_SUPPORT_QUANT_EXTEND | 是否开启Binary/Unary等算子的量化计算支持,默认为`ON` |
|
||||
| MNN_DEBUG_MEMORY | 是否开启MNN内存调试,默认为`OFF` |
|
||||
| MNN_DEBUG_TENSOR_SIZE | 是否开启MNN tensor size调试,默认为`OFF` |
|
||||
| MNN_GPU_TRACE | 是否开启MNN GPU调试,默认为`OFF` |
|
||||
|
@ -43,6 +44,7 @@ MNN使用CMake构建项目,CMake中的宏定义列表如下:
|
|||
| MNN_OPENGL | 是否构建`OpenGL`后端,默认为`OFF` |
|
||||
| MNN_VULKAN | 是否构建`Vulkan`后端,默认为`OFF` |
|
||||
| MNN_ARM82 | 编译ARM架构时,是否构建`Armv8.2`后端,以支持FP16计算,默认为`ON` |
|
||||
| MNN_SME2 | 编译ARM架构时,是否构建`ArmSme2`后端,以支持使用sme2指令集计算,默认为`ON` |
|
||||
| MNN_SUPPORT_FP16_ARMV7 | 编译armeabi-v7a架构时,是否构建`Armv8.2`后端,以支持FP16计算,默认为`OFF` |
|
||||
| MNN_ONEDNN | 是否使用`oneDNN`,默认为`OFF` |
|
||||
| MNN_AVX2 | 在`MNN_USE_SSE`开启的基础上,是否增加AVX2指令的支持,默认为`ON` |
|
||||
|
@ -55,6 +57,8 @@ MNN使用CMake构建项目,CMake中的宏定义列表如下:
|
|||
| MNN_TENSORRT | 是否构建`TensorRT`后端,默认为`OFF` |
|
||||
| MNN_COREML | 是否构建`CoreML`后端,默认为`OFF` |
|
||||
| MNN_NNAPI | 是否构建`NNAPI`后端,默认为`OFF` |
|
||||
| MNN_QNN | 是否构建`QNN`后端,默认为`OFF` |
|
||||
| MNN_QNN_CONVERT_MODE | 在`MNN_QNN`开启的基础上,是否构建Convert模式的QNN后端,默认为`OFF` |
|
||||
| MNN_BUILD_BENCHMARK | 是否构建MNN的性能测试,默认为`OFF` |
|
||||
| MNN_BUILD_TEST | 是否构建MNN的单元测试,默认为`OFF` |
|
||||
| MNN_BUILD_FOR_ANDROID_COMMAND | 是否使用命令行构建`Android`,默认为`OFF` |
|
||||
|
|
|
@ -171,6 +171,7 @@
|
|||
- `rasterDemo.out` Raster示例
|
||||
- `nluDemo.out` nlu模型示例
|
||||
- `mergeInplaceForCPU` 将模型中可以Inplace计算的算子改成Inplace计算,可以减少内存占用,但限定CPU后端运行
|
||||
- `OpenCLProgramBuildTest.out` 测试OpenCL后端的Program在设备上是否能编译成功
|
||||
## 单元测试
|
||||
- 相关编译选项
|
||||
- `MNN_BUILD_TEST` 是否编译MNN单元测试
|
||||
|
|
|
@ -3132,6 +3132,20 @@ roialign
|
|||
```python
|
||||
TODO
|
||||
```
|
||||
|
||||
---
|
||||
### `jsonop(inputs, describe, output_number)`
|
||||
|
||||
jsonop
|
||||
|
||||
对于MNN模型支持的算子,但没有相应表达式透出的情况,可以使用jsonop接口,以json描述算子
|
||||
|
||||
参数:
|
||||
- `inputs` : List[Var] 输入变量数组,任意类型
|
||||
- `describe` : str ,算子的json描述
|
||||
- `output_number` : int, 算子输出数
|
||||
|
||||
|
||||
---
|
||||
**以下函数为框架开发者使用函数,普通用户不建议使用!**
|
||||
|
||||
|
|
|
@ -85,6 +85,8 @@ Usage:
|
|||
|
||||
--useGeluApproximation 在进行Gelu算子合并时,使用Gelu的近似算法,默认为1 ,也就是`true`
|
||||
|
||||
--useOriginRNNImpl LSTM和GRU算子是否使用原始算子实现,默认关闭。若开启,性能可能提升,但无法进行LSTM/GRU的量化
|
||||
|
||||
```
|
||||
|
||||
**说明1: 选项weightQuantBits,使用方式为 --weightQuantBits numBits,numBits可选2~8,此功能仅对conv/matmul/LSTM的float32权值进行量化,仅优化模型大小,加载模型后会解码为float32,量化位宽可选2~8,运行速度和float32模型一致。经内部测试8bit时精度基本无损,模型大小减小4倍。default: 0,即不进行权值量化。**
|
||||
|
|
|
@ -336,6 +336,7 @@ void Executor::RuntimeManager::setCache(std::string cacheName) {
|
|||
|
||||
mInside->mCache.reset(new Cache);
|
||||
mInside->mCache->cacheFile = cacheName;
|
||||
mInside->mInfo->onSetCachePath(cacheName.c_str(), 0);
|
||||
if (nullptr == mInside->mCache->cacheFile.c_str()) {
|
||||
MNN_ERROR("Empty cacheFile\n");
|
||||
return;
|
||||
|
|
|
@ -1339,6 +1339,18 @@ std::vector<EXPRP> Variable::getExecuteOrder(const std::vector<VARP>& outputs) {
|
|||
if (nullptr == output) {
|
||||
continue;
|
||||
}
|
||||
if (nullptr == output->expr().first) {
|
||||
continue;
|
||||
}
|
||||
auto op = output->expr().first->get();
|
||||
bool isConst = ((op && op->type() == OpType_Const) || (!op && output->expr().first->inputType() == VARP::CONSTANT));
|
||||
if (isConst) {
|
||||
if (!output->expr().first->visited()){
|
||||
output->expr().first->setVisited(true);
|
||||
sequence.emplace_back(output->expr().first);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
workStack.push(output->expr().first);
|
||||
}
|
||||
while (!workStack.empty()) {
|
||||
|
|
|
@ -128,6 +128,10 @@ VARP _Conv(VARP weight, VARP bias, VARP x, PaddingMode pad, INTS stride, INTS di
|
|||
std::unique_ptr<OpT> convOp(new OpT);
|
||||
convOp->type = OpType_Convolution;
|
||||
auto shape = weight->getInfo();
|
||||
if (shape == nullptr) {
|
||||
MNN_ERROR("Weight for convolution should have shape information.\n");
|
||||
return nullptr;
|
||||
}
|
||||
if (NHWC == shape->order) {
|
||||
weight = _Transpose(weight, {0, 3, 1, 2});
|
||||
shape = weight->getInfo();
|
||||
|
@ -268,6 +272,10 @@ VARP _Deconv(VARP weight, VARP bias, VARP x, PaddingMode pad, INTS stride, INTS
|
|||
std::unique_ptr<OpT> convOp(new OpT);
|
||||
convOp->type = OpType_Deconvolution;
|
||||
auto shape = weight->getInfo();
|
||||
if (shape == nullptr) {
|
||||
MNN_ERROR("weight's info is null\n");
|
||||
return nullptr;
|
||||
}
|
||||
auto channel = std::vector<int>{shape->dim[1], shape->dim[0]};
|
||||
auto kernelSize = std::vector<int>{shape->dim[3], shape->dim[2]};
|
||||
if (channel[1] * channel[0] == group) {
|
||||
|
@ -1048,6 +1056,10 @@ VARP _BatchToSpaceND(VARP input, VARP block_shape, VARP crops) {
|
|||
|
||||
auto info_block_shape = block_shape->getInfo();
|
||||
auto info_crops = crops->getInfo();
|
||||
if (info_block_shape == nullptr || info_crops == nullptr) {
|
||||
MNN_ERROR("BatchToSpaceND: block_shape or crops info is null.\n");
|
||||
return nullptr;
|
||||
}
|
||||
MNN_ASSERT(info_block_shape != nullptr);
|
||||
MNN_ASSERT(info_crops != nullptr);
|
||||
MNN_ASSERT(halide_type_int == info_block_shape->type.code);
|
||||
|
@ -1057,6 +1069,10 @@ VARP _BatchToSpaceND(VARP input, VARP block_shape, VARP crops) {
|
|||
blob_blockShape->dataFormat = (MNN_DATA_FORMAT)Utils::convertFormat(info_block_shape->order);
|
||||
blob_blockShape->dataType = (MNN::DataType)Utils::convertDataType(info_block_shape->type);
|
||||
auto data_block_shape = block_shape->readMap<int>();
|
||||
if (data_block_shape == nullptr) {
|
||||
MNN_ERROR("BatchToSpaceND: block_shape data is null.\n");
|
||||
return nullptr;
|
||||
}
|
||||
for (int i=0; i<info_block_shape->size; i++)
|
||||
{
|
||||
blob_blockShape->int32s.emplace_back(data_block_shape[i]);
|
||||
|
@ -1065,6 +1081,10 @@ VARP _BatchToSpaceND(VARP input, VARP block_shape, VARP crops) {
|
|||
blob_paddings->dataFormat = (MNN_DATA_FORMAT)Utils::convertFormat(info_crops->order);
|
||||
blob_paddings->dataType = (MNN::DataType)Utils::convertDataType(info_crops->type);
|
||||
auto data_crop = crops->readMap<int>();
|
||||
if (data_crop == nullptr) {
|
||||
MNN_ERROR("BatchToSpaceND: crops data is null.\n");
|
||||
return nullptr;
|
||||
}
|
||||
for (int i=0; i<info_crops->size; i++)
|
||||
{
|
||||
blob_paddings->int32s.emplace_back(data_crop[i]);
|
||||
|
@ -1178,6 +1198,10 @@ VARP _SpaceToBatchND(VARP input, VARP block_shape, VARP paddings) {
|
|||
auto param = new SpaceBatchT;
|
||||
auto info_block_shape = block_shape->getInfo();
|
||||
auto info_paddings = paddings->getInfo();
|
||||
if (info_block_shape == nullptr || info_paddings == nullptr) {
|
||||
MNN_ERROR("SpaceToBatchND: block_shape or paddings info is null.\n");
|
||||
return nullptr;
|
||||
}
|
||||
MNN_ASSERT(info_block_shape != nullptr);
|
||||
MNN_ASSERT(info_paddings != nullptr);
|
||||
MNN_ASSERT(halide_type_int == info_block_shape->type.code);
|
||||
|
@ -1187,6 +1211,10 @@ VARP _SpaceToBatchND(VARP input, VARP block_shape, VARP paddings) {
|
|||
blob_blockShape->dataFormat = (MNN::MNN_DATA_FORMAT)Utils::convertFormat(info_block_shape->order);
|
||||
blob_blockShape->dataType = (MNN::DataType)Utils::convertDataType(info_block_shape->type);
|
||||
auto data_block_shape = block_shape->readMap<int>();
|
||||
if (data_block_shape == nullptr) {
|
||||
MNN_ERROR("SpaceToBatchND: block_shape data is null.\n");
|
||||
return nullptr;
|
||||
}
|
||||
for (int i=0; i<info_block_shape->size; i++)
|
||||
{
|
||||
blob_blockShape->int32s.emplace_back(data_block_shape[i]);
|
||||
|
@ -1195,6 +1223,10 @@ VARP _SpaceToBatchND(VARP input, VARP block_shape, VARP paddings) {
|
|||
blob_paddings->dataFormat = (MNN::MNN_DATA_FORMAT)Utils::convertFormat(info_paddings->order);
|
||||
blob_paddings->dataType = (MNN::DataType)Utils::convertDataType(info_paddings->type);
|
||||
auto data_paddings = paddings->readMap<int>();
|
||||
if (data_paddings == nullptr) {
|
||||
MNN_ERROR("SpaceToBatchND: paddings data is null.\n");
|
||||
return nullptr;
|
||||
}
|
||||
for (int i=0; i<info_paddings->size; i++)
|
||||
{
|
||||
blob_paddings->int32s.emplace_back(data_paddings[i]);
|
||||
|
@ -1235,7 +1267,10 @@ std::vector <VARP> _Unstack(VARP value, int axis) {
|
|||
std::unique_ptr<OpT> op(new OpT);
|
||||
op->type = OpType_Unpack;
|
||||
auto info_value = value->getInfo();
|
||||
MNN_ASSERT(info_value != nullptr);
|
||||
if (info_value == nullptr) {
|
||||
MNN_ERROR("Unstack: value info is null.\n");
|
||||
return {};
|
||||
}
|
||||
auto dims = info_value->dim;
|
||||
auto dimsize = dims.size();
|
||||
MNN_ASSERT(dimsize >= 1);
|
||||
|
@ -1703,9 +1738,12 @@ VARP _FloatToInt8(VARP x, VARP scale, int8_t minValue, int8_t maxValue, int8_t z
|
|||
}
|
||||
|
||||
VARP _Int8ToFloat(VARP x, VARP scale) {
|
||||
auto xInfo = x->getInfo();
|
||||
auto scaleInfo = scale->getInfo();
|
||||
auto scalePtr = scale->readMap<float>();
|
||||
if (nullptr == scaleInfo) {
|
||||
MNN_ERROR("Error for Int8ToFloat because var not ready\n");
|
||||
return nullptr;
|
||||
}
|
||||
std::unique_ptr<OpT> op(new OpT);
|
||||
op->type = OpType_Int8ToFloat;
|
||||
op->main.type = OpParameter_QuantizedFloatParam;
|
||||
|
@ -1716,9 +1754,21 @@ VARP _Int8ToFloat(VARP x, VARP scale) {
|
|||
}
|
||||
|
||||
VARP _Int8ToFloat(VARP x, VARP scale, int8_t zeroPoint) {
|
||||
auto xInfo = x->getInfo();
|
||||
auto scaleInfo = scale->getInfo();
|
||||
if (nullptr == scaleInfo) {
|
||||
MNN_ERROR("Error for Int8ToFloat because var not ready\n");
|
||||
return nullptr;
|
||||
}
|
||||
if (scaleInfo->size <= 0) {
|
||||
MNN_ERROR("Error for Int8ToFloat because scale size is zero\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto scalePtr = scale->readMap<float>();
|
||||
if (nullptr == scalePtr) {
|
||||
MNN_ERROR("Error for Int8ToFloat because scale data is null\n");
|
||||
return nullptr;
|
||||
}
|
||||
std::unique_ptr<OpT> op(new OpT);
|
||||
op->type = OpType_Int8ToFloat;
|
||||
op->main.type = OpParameter_QuantizedFloatParam;
|
||||
|
@ -1786,13 +1836,10 @@ VARP _Sort(VARP x, int axis, bool arg, bool descend) {
|
|||
auto topk = new TopKV2T;
|
||||
topk->largest = descend;
|
||||
op->main.value = topk;
|
||||
auto shape = x->getInfo()->dim;
|
||||
axis = axis < 0 ? shape.size() + axis : axis;
|
||||
int k = x->getInfo()->dim[axis];
|
||||
std::vector<VARP> inputs {x, _Scalar(k)};
|
||||
if (axis + 1 != shape.size()) {
|
||||
inputs.push_back(_Scalar(axis));
|
||||
}
|
||||
auto posAxis = _Mod(_Scalar(axis), _Rank(x));
|
||||
auto K = _Slice(_Shape(x), _Unsqueeze(posAxis, {0}), _Unsqueeze(_Scalar<int32_t>(1), {0}));
|
||||
|
||||
std::vector<VARP> inputs {x, K, _Scalar(axis)};
|
||||
auto expr = Expr::create(op.get(), inputs, 2);
|
||||
return Variable::create(expr, arg);
|
||||
}
|
||||
|
|
|
@ -202,25 +202,52 @@ Executor::ComputeCache::~ComputeCache() {
|
|||
FUNC_PRINT(gInstanceCount);
|
||||
#endif
|
||||
}
|
||||
Executor::RuntimeExecuteWrap::RuntimeExecuteWrap(const RuntimeInfo& info) : mRt(info) {
|
||||
for (auto& iter : mRt.first) {
|
||||
iter.second->onConcurrencyBegin();
|
||||
}
|
||||
}
|
||||
Executor::RuntimeExecuteWrap::~RuntimeExecuteWrap() {
|
||||
for (auto& iter : mRt.first) {
|
||||
iter.second->onConcurrencyEnd();
|
||||
}
|
||||
}
|
||||
ErrorCode Executor::ComputeCache::compute() {
|
||||
std::stack<ComputeCache*> dfsStack;
|
||||
std::set<ComputeCache*> visited;
|
||||
dfsStack.push(this);
|
||||
ErrorCode code = NO_ERROR;
|
||||
auto globalExecutor = ExecutorScope::Current();
|
||||
auto& rt = globalExecutor->mRuntimeInfo;
|
||||
for (auto& iter : rt.first) {
|
||||
iter.second->onConcurrencyBegin();
|
||||
}
|
||||
auto debug = globalExecutor->getDebugTools();
|
||||
auto hasUnvisitInput = [&] (ComputeCache* cache) {
|
||||
for (auto c : cache->mInputs) {
|
||||
if (visited.find(c.get()) == visited.end()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
// Check need compute or not
|
||||
while (!dfsStack.empty()) {
|
||||
//printf("stcak = %d\n", dfsStack.size());
|
||||
auto cache = dfsStack.top();
|
||||
dfsStack.pop();
|
||||
for (auto& c : cache->mInputInside) {
|
||||
if (c->mContentDirty) {
|
||||
return CALL_BACK_STOP;
|
||||
}
|
||||
}
|
||||
if (hasUnvisitInput(cache)) {
|
||||
for (auto c : cache->mInputs) {
|
||||
dfsStack.push(c.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
// Compute
|
||||
visited.clear();
|
||||
dfsStack.push(this);
|
||||
ErrorCode code = NO_ERROR;
|
||||
auto glo = ExecutorScope::Current();
|
||||
RuntimeExecuteWrap wrap(glo->mRuntimeInfo);
|
||||
auto debug = glo->getDebugTools();
|
||||
while (!dfsStack.empty()) {
|
||||
auto cache = dfsStack.top();
|
||||
if (cache->mShapeDirty) {
|
||||
auto code = cache->resize();
|
||||
if (NO_ERROR != code) {
|
||||
|
@ -233,15 +260,7 @@ ErrorCode Executor::ComputeCache::compute() {
|
|||
dfsStack.pop();
|
||||
continue;
|
||||
}
|
||||
auto hasUnvisitInput = [&] () {
|
||||
for (auto c : cache->mInputs) {
|
||||
if (visited.find(c.get()) == visited.end()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
if (hasUnvisitInput()) {
|
||||
if (hasUnvisitInput(cache)) {
|
||||
for (auto c : cache->mInputs) {
|
||||
dfsStack.push(c.get());
|
||||
}
|
||||
|
@ -259,9 +278,6 @@ ErrorCode Executor::ComputeCache::compute() {
|
|||
cache->mContentDirty = false;
|
||||
}
|
||||
}
|
||||
for (auto& iter : rt.first) {
|
||||
iter.second->onConcurrencyEnd();
|
||||
}
|
||||
return NO_ERROR;
|
||||
}
|
||||
ErrorCode Executor::ComputeCache::resizeImpl() {
|
||||
|
|
|
@ -82,7 +82,13 @@ public:
|
|||
static Tensor* getTensor(VARP var);
|
||||
static EXPRP makeRaster(const std::vector<VARP>& vars, const std::vector<int>& regions, const std::vector<int>& shape, halide_type_t dataType, MNN_DATA_FORMAT format);
|
||||
};
|
||||
|
||||
class Executor::RuntimeExecuteWrap {
|
||||
public:
|
||||
RuntimeExecuteWrap(const RuntimeInfo& info);
|
||||
~ RuntimeExecuteWrap();
|
||||
private:
|
||||
const RuntimeInfo& mRt;
|
||||
};
|
||||
} // namespace Express
|
||||
} // namespace MNN
|
||||
#endif
|
||||
|
|
|
@ -216,12 +216,10 @@ public:
|
|||
auto glo = ExecutorScope::Current();
|
||||
glo->getDebugTools()->flops = 0.0f;
|
||||
#endif
|
||||
for (auto& iter : mInfo->runTimeManager->getInside()->mRuntime.first) {
|
||||
iter.second->onConcurrencyBegin();
|
||||
}
|
||||
auto outputs = mModule->onForward(inputs);
|
||||
for (auto& iter : mInfo->runTimeManager->getInside()->mRuntime.first) {
|
||||
iter.second->onConcurrencyEnd();
|
||||
std::vector<VARP> outputs;
|
||||
{
|
||||
Executor::RuntimeExecuteWrap wrap(mInfo->runTimeManager->getInside()->mRuntime);
|
||||
outputs = mModule->onForward(inputs);
|
||||
}
|
||||
#ifdef MNN_INTERNAL_ENABLED
|
||||
do {
|
||||
|
|
|
@ -775,6 +775,10 @@ Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::
|
|||
std::set<int> outputIndexes;
|
||||
std::map<int, int> stackMap;
|
||||
std::map<std::string, int> outputIndexesMap;
|
||||
std::set<std::string> outputNameSet;
|
||||
for (auto name : outputs) {
|
||||
outputIndexesMap.insert(std::make_pair(name, -1));
|
||||
}
|
||||
for (int i=0; i<net->tensorName()->size(); ++i) {
|
||||
auto tname = net->tensorName()->GetAsString(i)->str();
|
||||
for (int j=0; j<inputs.size(); ++j) {
|
||||
|
@ -784,20 +788,20 @@ Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::
|
|||
break;
|
||||
}
|
||||
}
|
||||
for (int j=0; j<outputs.size(); ++j) {
|
||||
if (tname == outputs[j]) {
|
||||
outputIndexes.emplace(i);
|
||||
outputIndexesMap.insert(std::make_pair(tname, i));
|
||||
break;
|
||||
}
|
||||
auto outputIter = outputIndexesMap.find(tname);
|
||||
if (outputIter != outputIndexesMap.end()) {
|
||||
outputIndexes.emplace(i);
|
||||
outputIter->second = i;
|
||||
}
|
||||
}
|
||||
if (outputIndexesMap.size() != outputs.size()) {
|
||||
MNN_ERROR("PipelineModule:: Can't find enough output from the model, finded is:\n");
|
||||
for (auto& iter : outputIndexesMap) {
|
||||
MNN_ERROR("[ %s ] ", iter.first.c_str());
|
||||
bool valid = true;
|
||||
for (auto& iter : outputIndexesMap) {
|
||||
if (iter.second == -1) {
|
||||
MNN_ERROR("PipelineModule:: Can't find output %s from the model:\n", iter.first.c_str());
|
||||
valid = false;
|
||||
}
|
||||
MNN_ERROR("\n");
|
||||
}
|
||||
if (!valid) {
|
||||
return nullptr;
|
||||
}
|
||||
bool divideSuccess = true;
|
||||
|
@ -810,6 +814,7 @@ Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::
|
|||
for (int i=0; i<subModulesInfo.size(); ++i) {
|
||||
subModules[i].reset(_createSubModule(bufferStorage, subModulesInfo[i], subGraphMap, sharedConst, *config, modRuntime));
|
||||
}
|
||||
bool needReplaceBackend = false;
|
||||
if (preReplaceConstTensor) {
|
||||
// Prereplace const tensor
|
||||
auto curBackend = sharedConst->constReplaceBackend.get();
|
||||
|
@ -838,6 +843,7 @@ Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::
|
|||
if (!tempRes) {
|
||||
continue;
|
||||
}
|
||||
needReplaceBackend = true;
|
||||
outDes->stageMask |= Tensor::InsideDescribe::CONVERTED_STAGE;
|
||||
WrapExecution::copyReplaceTensor(wrapTensor.get(), t.get());
|
||||
}
|
||||
|
@ -848,6 +854,11 @@ Module* PipelineModule::load(const std::vector<std::string>& inputs, const std::
|
|||
for (auto index : noneedComputeIndexes) {
|
||||
auto tensor = Tensor::clone(sharedConst->allTensors[index].get());
|
||||
auto constVar = Variable::create(Expr::create(tensor, true));
|
||||
auto back = TensorUtils::getDescribeOrigin(tensor)->getBackend();
|
||||
auto x =sharedConst->constReplaceBackend.get();
|
||||
if (needReplaceBackend && TensorUtils::getDescribeOrigin(tensor)->getBackend() == sharedConst->constReplaceBackend.get()) {
|
||||
constVar->expr().first->inside()->mHoldBackend = sharedConst->constReplaceBackend;
|
||||
}
|
||||
initVars.insert(std::make_pair(index, constVar));
|
||||
}
|
||||
auto result = new PipelineModule;
|
||||
|
|
|
@ -216,8 +216,9 @@ public:
|
|||
// Geometry Compute option, default is 0xFFFF
|
||||
GEOMETRY_COMPUTE_MASK = 4,
|
||||
|
||||
// 0: Close dynamic quant;
|
||||
// default 0
|
||||
// 1: For general convolution, use one scale&zeropoint to quant.
|
||||
// 2: use block-quant for input data.
|
||||
DYNAMIC_QUANT_OPTIONS = 5,
|
||||
|
||||
// For Mobile CPU with big-litter core, set decrease rate to let MNN divide task differential by CPU's performance
|
||||
|
@ -247,8 +248,11 @@ public:
|
|||
// Multi-Thread Load module, default is 0 (don't use other Thread)
|
||||
INIT_THREAD_NUMBER = 13,
|
||||
|
||||
// CPU core ids
|
||||
// Used CPU ids
|
||||
CPU_CORE_IDS = 14,
|
||||
|
||||
// set CPU threads to use when supports Arm sme2
|
||||
CPU_SME2_INSTRUCTIONS = 15
|
||||
};
|
||||
|
||||
enum ExternalPathType {
|
||||
|
|
|
@ -76,6 +76,6 @@ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \
|
|||
#define STR(x) STR_IMP(x)
|
||||
#define MNN_VERSION_MAJOR 3
|
||||
#define MNN_VERSION_MINOR 2
|
||||
#define MNN_VERSION_PATCH 1
|
||||
#define MNN_VERSION_PATCH 2
|
||||
#define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." STR(MNN_VERSION_PATCH)
|
||||
#endif /* MNNDefine_h */
|
||||
|
|
|
@ -27,6 +27,7 @@ struct ExecutorAttr;
|
|||
class MNN_PUBLIC Executor {
|
||||
public:
|
||||
class ComputeCache;
|
||||
class RuntimeExecuteWrap;
|
||||
struct DebugTools;
|
||||
/**Internal Usage Begin*/
|
||||
struct Requirement {
|
||||
|
|
|
@ -21,7 +21,7 @@ cd ../
|
|||
rm -rf ios_32
|
||||
mkdir ios_32
|
||||
cd ios_32
|
||||
cmake ../../../ -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../../../cmake/ios.toolchain.cmake -DMNN_METAL=ON -DARCHS="armv7;armv7s" -DENABLE_BITCODE=0 -DMNN_AAPL_FMWK=1 -DMNN_SEP_BUILD=0 -DMNN_BUILD_SHARED_LIBS=false -DMNN_USE_THREAD_POOL=OFF $1
|
||||
cmake ../../../ -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../../../cmake/ios.toolchain.cmake -DMNN_METAL=ON -DARCHS="armv7;armv7s" -DENABLE_BITCODE=0 -DMNN_AAPL_FMWK=1 -DMNN_SEP_BUILD=0 -DMNN_BUILD_SHARED_LIBS=false -DMNN_USE_THREAD_POOL=OFF $*
|
||||
echo "Building AArch32"
|
||||
make MNN -j16
|
||||
echo "End Building AArch32"
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
#!/bin/sh
|
||||
echo "Change directory to MNN_SOURCE_ROOT/project/ios before running this script"
|
||||
echo "Current PWD: ${PWD}"
|
||||
|
||||
rm -rf MNN-iOS-CPU-GPU
|
||||
mkdir MNN-iOS-CPU-GPU
|
||||
cd MNN-iOS-CPU-GPU
|
||||
# Static Begin
|
||||
mkdir Static
|
||||
cd Static
|
||||
|
||||
COMMON="-DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../../../cmake/ios.toolchain.cmake -DENABLE_BITCODE=0 -DMNN_AAPL_FMWK=1 -DMNN_SEP_BUILD=0 -DMNN_BUILD_SHARED_LIBS=false -DMNN_USE_THREAD_POOL=OFF"
|
||||
|
||||
rm -rf ios_64
|
||||
mkdir ios_64
|
||||
cd ios_64
|
||||
cmake ../../../ ${COMMON} -DMNN_METAL=ON -DARCHS="arm64" $*
|
||||
echo "Building AArch64"
|
||||
make MNN -j16
|
||||
echo "End Building AArch64"
|
||||
cd ../
|
||||
|
||||
rm -rf ios_32
|
||||
mkdir ios_32
|
||||
cd ios_32
|
||||
cmake ../../../ ${COMMON} -DMNN_METAL=OFF -DPLATFORM=SIMULATOR64 -DARCHS="x86_64" $*
|
||||
echo "Building Simulator64"
|
||||
make MNN -j16
|
||||
echo "End Building Simulator64"
|
||||
cd ../
|
||||
|
||||
mv ios_32/MNN.framework/MNN ios_32/MNN.framework/MNN_32
|
||||
|
||||
echo "Creating Fat Binary"
|
||||
lipo -create ios_32/MNN.framework/MNN_32 ios_64/MNN.framework/MNN -output ios_32/MNN.framework/MNN
|
||||
rm ios_32/MNN.framework/MNN_32
|
||||
echo "Patching Framework Headers"
|
||||
rm -rf ./MNN.framework
|
||||
cp -R ios_32/MNN.framework ./MNN.framework
|
||||
rm -rf ios_32
|
||||
rm -rf ios_64
|
|
@ -706,12 +706,15 @@
|
|||
9558333D29B0947300488807 /* MNNGelu.S in Sources */ = {isa = PBXBuildFile; fileRef = 9558333C29B0947300488807 /* MNNGelu.S */; };
|
||||
9558334729B09A2300488807 /* MNNGelu.S in Sources */ = {isa = PBXBuildFile; fileRef = 9558334629B09A2300488807 /* MNNGelu.S */; };
|
||||
9558334B29B09A7B00488807 /* MNNGeluFP16.S in Sources */ = {isa = PBXBuildFile; fileRef = 9558334A29B09A7B00488807 /* MNNGeluFP16.S */; settings = {COMPILER_FLAGS = "-march=armv8.2-a+fp16"; }; };
|
||||
955AD7522E1FB44E0099F26C /* MoEModule.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 955AD7512E1FB44E0099F26C /* MoEModule.cpp */; };
|
||||
955AD7532E1FB44E0099F26C /* MoEModule.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 955AD7502E1FB44E0099F26C /* MoEModule.hpp */; };
|
||||
9560EAD62BDE426A00C8D0B6 /* GeometryLayernorm.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 9560EAD52BDE426A00C8D0B6 /* GeometryLayernorm.cpp */; };
|
||||
956F52E12AB2D692004B13D9 /* ImageProcessUtils.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 956F52E02AB2D692004B13D9 /* ImageProcessUtils.cpp */; };
|
||||
956F52E32AB2D6A1004B13D9 /* ImageProcessUtils.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 956F52E22AB2D6A1004B13D9 /* ImageProcessUtils.hpp */; };
|
||||
95772DCF2C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM82.S in Sources */ = {isa = PBXBuildFile; fileRef = 95772DCD2C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM82.S */; };
|
||||
95772DD02C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM86.S in Sources */ = {isa = PBXBuildFile; fileRef = 95772DCE2C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM86.S */; };
|
||||
958375352A496E5C007C0A3E /* MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S in Sources */ = {isa = PBXBuildFile; fileRef = 958375342A496E5C007C0A3E /* MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S */; };
|
||||
959F15A92E2782F800C67803 /* CountMinMaxValue_FP16.S in Sources */ = {isa = PBXBuildFile; fileRef = 959F15A82E2782F800C67803 /* CountMinMaxValue_FP16.S */; settings = {COMPILER_FLAGS = "-march=armv8.2-a+fp16"; }; };
|
||||
C43C81FA251894A600A0FF84 /* CommonOptFunctionNeon.cpp in Sources */ = {isa = PBXBuildFile; fileRef = C43C81F8251894A500A0FF84 /* CommonOptFunctionNeon.cpp */; };
|
||||
C43C81FE251894BD00A0FF84 /* CPUPlugin.cpp in Sources */ = {isa = PBXBuildFile; fileRef = C43C81FB251894BD00A0FF84 /* CPUPlugin.cpp */; };
|
||||
C43C81FF251894BD00A0FF84 /* ThreadPool.cpp in Sources */ = {isa = PBXBuildFile; fileRef = C43C81FC251894BD00A0FF84 /* ThreadPool.cpp */; };
|
||||
|
@ -754,10 +757,6 @@
|
|||
CE072A282C91AF0700F190FD /* MNNC3ToXYZFast.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A252C91AF0700F190FD /* MNNC3ToXYZFast.S */; };
|
||||
CE072A2A2CAA50DE00F190FD /* MNNDepthwiseConvFastKernel.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A292CAA50DE00F190FD /* MNNDepthwiseConvFastKernel.S */; };
|
||||
CE072A2C2CAA510F00F190FD /* MNNDepthwiseConvFastKernelFP16.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A2B2CAA510F00F190FD /* MNNDepthwiseConvFastKernelFP16.S */; settings = {COMPILER_FLAGS = "-march=armv8.2-a+fp16"; }; };
|
||||
CE0AD4E42E1FB106002013A8 /* CountMinMaxValue_FP16.S in Sources */ = {isa = PBXBuildFile; fileRef = CE0AD4E32E1FB106002013A8 /* CountMinMaxValue_FP16.S */; settings = {COMPILER_FLAGS = "-march=armv8.2-a+fp16"; }; };
|
||||
CE0AD4E82E1FB152002013A8 /* MoEModule.hpp in Headers */ = {isa = PBXBuildFile; fileRef = CE0AD4E62E1FB152002013A8 /* MoEModule.hpp */; };
|
||||
CE0AD4E92E1FB152002013A8 /* ModuleInside.hpp in Headers */ = {isa = PBXBuildFile; fileRef = CE0AD4E52E1FB152002013A8 /* ModuleInside.hpp */; };
|
||||
CE0AD4EA2E1FB152002013A8 /* MoEModule.cpp in Sources */ = {isa = PBXBuildFile; fileRef = CE0AD4E72E1FB152002013A8 /* MoEModule.cpp */; };
|
||||
CE125CC82A52BF6B003698C9 /* MNNBilinearSampleC8.S in Sources */ = {isa = PBXBuildFile; fileRef = CE125CC62A52BF6B003698C9 /* MNNBilinearSampleC8.S */; };
|
||||
CE125CC92A52BF6B003698C9 /* MNNBilinearLineC8.S in Sources */ = {isa = PBXBuildFile; fileRef = CE125CC72A52BF6B003698C9 /* MNNBilinearLineC8.S */; };
|
||||
CE31C7C12D783CBB00741F49 /* WorkerThread.cpp in Sources */ = {isa = PBXBuildFile; fileRef = CE31C7C02D783CBB00741F49 /* WorkerThread.cpp */; };
|
||||
|
@ -1543,12 +1542,15 @@
|
|||
9558333C29B0947300488807 /* MNNGelu.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNGelu.S; sourceTree = "<group>"; };
|
||||
9558334629B09A2300488807 /* MNNGelu.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNGelu.S; sourceTree = "<group>"; };
|
||||
9558334A29B09A7B00488807 /* MNNGeluFP16.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNGeluFP16.S; path = ../../../arm82/asm/arm64/MNNGeluFP16.S; sourceTree = "<group>"; };
|
||||
955AD7502E1FB44E0099F26C /* MoEModule.hpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.h; path = MoEModule.hpp; sourceTree = "<group>"; };
|
||||
955AD7512E1FB44E0099F26C /* MoEModule.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = MoEModule.cpp; sourceTree = "<group>"; };
|
||||
9560EAD52BDE426A00C8D0B6 /* GeometryLayernorm.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = GeometryLayernorm.cpp; sourceTree = "<group>"; };
|
||||
956F52E02AB2D692004B13D9 /* ImageProcessUtils.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ImageProcessUtils.cpp; sourceTree = "<group>"; };
|
||||
956F52E22AB2D6A1004B13D9 /* ImageProcessUtils.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = ImageProcessUtils.hpp; sourceTree = "<group>"; };
|
||||
95772DCD2C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM82.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPackC4Int8ForMatMulA_ARM82.S; sourceTree = "<group>"; };
|
||||
95772DCE2C50F12A000FC1C3 /* MNNPackC4Int8ForMatMulA_ARM86.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNPackC4Int8ForMatMulA_ARM86.S; sourceTree = "<group>"; };
|
||||
958375342A496E5C007C0A3E /* MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S; path = arm/arm64/MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S; sourceTree = "<group>"; };
|
||||
959F15A82E2782F800C67803 /* CountMinMaxValue_FP16.S */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.asm; name = CountMinMaxValue_FP16.S; path = ../../../arm82/asm/arm64/CountMinMaxValue_FP16.S; sourceTree = "<group>"; };
|
||||
C43C81F8251894A500A0FF84 /* CommonOptFunctionNeon.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = CommonOptFunctionNeon.cpp; sourceTree = "<group>"; };
|
||||
C43C81FB251894BD00A0FF84 /* CPUPlugin.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = CPUPlugin.cpp; sourceTree = "<group>"; };
|
||||
C43C81FC251894BD00A0FF84 /* ThreadPool.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ThreadPool.cpp; sourceTree = "<group>"; };
|
||||
|
@ -1591,10 +1593,6 @@
|
|||
CE072A252C91AF0700F190FD /* MNNC3ToXYZFast.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNC3ToXYZFast.S; path = arm/arm64/MNNC3ToXYZFast.S; sourceTree = "<group>"; };
|
||||
CE072A292CAA50DE00F190FD /* MNNDepthwiseConvFastKernel.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNDepthwiseConvFastKernel.S; sourceTree = "<group>"; };
|
||||
CE072A2B2CAA510F00F190FD /* MNNDepthwiseConvFastKernelFP16.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNDepthwiseConvFastKernelFP16.S; path = ../../../arm82/asm/arm64/MNNDepthwiseConvFastKernelFP16.S; sourceTree = "<group>"; };
|
||||
CE0AD4E32E1FB106002013A8 /* CountMinMaxValue_FP16.S */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.asm; name = CountMinMaxValue_FP16.S; path = ../arm82/asm/arm64/CountMinMaxValue_FP16.S; sourceTree = "<group>"; };
|
||||
CE0AD4E52E1FB152002013A8 /* ModuleInside.hpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.h; path = ModuleInside.hpp; sourceTree = "<group>"; };
|
||||
CE0AD4E62E1FB152002013A8 /* MoEModule.hpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.h; path = MoEModule.hpp; sourceTree = "<group>"; };
|
||||
CE0AD4E72E1FB152002013A8 /* MoEModule.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = MoEModule.cpp; sourceTree = "<group>"; };
|
||||
CE125CC62A52BF6B003698C9 /* MNNBilinearSampleC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNBilinearSampleC8.S; sourceTree = "<group>"; };
|
||||
CE125CC72A52BF6B003698C9 /* MNNBilinearLineC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNBilinearLineC8.S; sourceTree = "<group>"; };
|
||||
CE31C7BF2D783CBB00741F49 /* WorkerThread.hpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.h; path = WorkerThread.hpp; sourceTree = "<group>"; };
|
||||
|
@ -1859,6 +1857,7 @@
|
|||
488873A8215B639D0079B12E /* source */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
CE482EF5288536DA007CD935 /* internal */,
|
||||
4DF87C482887D3560003E2D4 /* calib3d */,
|
||||
4D4CF4612760946500A36D9F /* imgproc */,
|
||||
4D9A931B26255BDA00F9B43C /* coreml */,
|
||||
|
@ -1926,7 +1925,6 @@
|
|||
48887410215B639D0079B12E /* cpu */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
CE0AD4E32E1FB106002013A8 /* CountMinMaxValue_FP16.S */,
|
||||
CEA3C8892D6D71E1003EFAD2 /* CPUStft.hpp */,
|
||||
CEA3C88A2D6D71E1003EFAD2 /* CPUStft.cpp */,
|
||||
CE072A242C91AF0700F190FD /* MNNC3ToC4Fast.S */,
|
||||
|
@ -2211,9 +2209,8 @@
|
|||
48C84B6F250F711600EE7666 /* module */ = {
|
||||
isa = PBXGroup;
|
||||
children = (
|
||||
CE0AD4E52E1FB152002013A8 /* ModuleInside.hpp */,
|
||||
CE0AD4E62E1FB152002013A8 /* MoEModule.hpp */,
|
||||
CE0AD4E72E1FB152002013A8 /* MoEModule.cpp */,
|
||||
955AD7502E1FB44E0099F26C /* MoEModule.hpp */,
|
||||
955AD7512E1FB44E0099F26C /* MoEModule.cpp */,
|
||||
48C84B71250F711600EE7666 /* PipelineModule.cpp */,
|
||||
48C84B72250F711600EE7666 /* Module.cpp */,
|
||||
48C84B73250F711600EE7666 /* WhileModule.hpp */,
|
||||
|
@ -2643,6 +2640,7 @@
|
|||
4896D37025FE2A6A00717702 /* MNNExpFP16.S */,
|
||||
4896D37125FE2A6A00717702 /* MNNPackedMatMulFP16.S */,
|
||||
4896D37225FE2A6A00717702 /* MNNPackedMatMulRemainFP16.S */,
|
||||
959F15A82E2782F800C67803 /* CountMinMaxValue_FP16.S */,
|
||||
11A01A0A258785FB00745FA7 /* MNNVectorTop1Float.S */,
|
||||
11A01A0B258785FB00745FA7 /* MNNVectorTop1Int32.S */,
|
||||
48034566254157DF004738E3 /* MNNNV21ToBGRAUnit.S */,
|
||||
|
@ -2892,6 +2890,7 @@
|
|||
CEA82BDC2A15F8AD002CBC95 /* IdstConvolutionInt8.hpp in Headers */,
|
||||
4DE4E82C275E307B0016A916 /* cv in Headers */,
|
||||
1F501F842397BA5B004E8721 /* ImageProcess.hpp in Headers */,
|
||||
CECF8C5D299CACFD00D3875B /* Log.hpp in Headers */,
|
||||
1F501F822397BA5B004E8721 /* Interpreter.hpp in Headers */,
|
||||
C4F906B327688C3A0026B847 /* NMSModule.hpp in Headers */,
|
||||
1F501F882397BA5B004E8721 /* Tensor.hpp in Headers */,
|
||||
|
@ -2902,6 +2901,7 @@
|
|||
48C84B98250F71E900EE7666 /* CPUSoftmax.hpp in Headers */,
|
||||
4882C8B8241A22B800DAC168 /* OpCommonUtils.hpp in Headers */,
|
||||
48608B54250632EC00CB1D71 /* GeometryComputer.hpp in Headers */,
|
||||
CECF8C7A299CAD9400D3875B /* sha1.h in Headers */,
|
||||
4894C6EC27016F7200D8BE79 /* CPUResizeCache.hpp in Headers */,
|
||||
92FF04A623AA0BFB00AC97F6 /* FileLoader.hpp in Headers */,
|
||||
48F34733273A7C8400C45394 /* ImageProcessFunction.hpp in Headers */,
|
||||
|
@ -2915,6 +2915,7 @@
|
|||
48925F352744AC0700919B37 /* CPUROIAlign.hpp in Headers */,
|
||||
92FF029623AA0B5A00AC97F6 /* CPUCast.hpp in Headers */,
|
||||
4D9A937826255BDA00F9B43C /* CoreMLBinary.hpp in Headers */,
|
||||
CECF8C85299CAD9400D3875B /* log_util.h in Headers */,
|
||||
4D6D7FD52656896600F80814 /* DenseConvolutionTiledExecutor.hpp in Headers */,
|
||||
4D9A936626255BDA00F9B43C /* CoreMLExecutor.h in Headers */,
|
||||
92FF027A23AA0B5A00AC97F6 /* CPUPool.hpp in Headers */,
|
||||
|
@ -2923,6 +2924,7 @@
|
|||
1F501F802397BA5B004E8721 /* MNNDefine.h in Headers */,
|
||||
19D0FE76285C66F200B74B1A /* MetalLayerNorm.hpp in Headers */,
|
||||
489D7A682550FDC800AD896A /* MetalReduction.hpp in Headers */,
|
||||
CECF8C86299CAD9400D3875B /* sds.h in Headers */,
|
||||
1F501F7F2397BA5B004E8721 /* HalideRuntime.h in Headers */,
|
||||
92FF029E23AA0B5A00AC97F6 /* CPUDeconvolutionDepthwise.hpp in Headers */,
|
||||
4D9A935B26255BDA00F9B43C /* NeuralNetwork.pb-c.h in Headers */,
|
||||
|
@ -2944,8 +2946,10 @@
|
|||
481C2DEE25FE2CD6001ED6DF /* Arm82Functions.hpp in Headers */,
|
||||
4894C6EA27016F7200D8BE79 /* UnaryUtils.hpp in Headers */,
|
||||
EBD4842A2485FF650083CE95 /* Arm82Interp.hpp in Headers */,
|
||||
CECF8C81299CAD9400D3875B /* log_util_imp.h in Headers */,
|
||||
92FF037623AA0B5A00AC97F6 /* CPUBinary.hpp in Headers */,
|
||||
4D9A935826255BDA00F9B43C /* FeatureTypes.pb-c.h in Headers */,
|
||||
CECF8C7C299CAD9400D3875B /* hmac-sha.h in Headers */,
|
||||
48608B53250632EC00CB1D71 /* GeometryComputerUtils.hpp in Headers */,
|
||||
489D7A732550FDC800AD896A /* MetalBackend.hpp in Headers */,
|
||||
92FF03AC23AA0B5A00AC97F6 /* ResizeFunction.h in Headers */,
|
||||
|
@ -2968,10 +2972,12 @@
|
|||
4D9A937526255BDA00F9B43C /* CoreMLCommonExecution.hpp in Headers */,
|
||||
4DF87C522887D3F20003E2D4 /* CPUSvd.hpp in Headers */,
|
||||
48747D4B245D9D24000B9709 /* RuntimeFactory.hpp in Headers */,
|
||||
CECF8C77299CAD9400D3875B /* log_builder.h in Headers */,
|
||||
4D9A937226255BDA00F9B43C /* CoreMLConvolution.hpp in Headers */,
|
||||
92FF038B23AA0B5A00AC97F6 /* CPUUnravelIndex.hpp in Headers */,
|
||||
4AF4FB26269ED235005BA97B /* SparseConvInt8TiledExecutor.hpp in Headers */,
|
||||
CEA49AA92AFD010900971CB7 /* MetalExecution.hpp in Headers */,
|
||||
955AD7532E1FB44E0099F26C /* MoEModule.hpp in Headers */,
|
||||
92FF03BC23AA0B5A00AC97F6 /* OptimizedComputer.hpp in Headers */,
|
||||
95278CE72B9F0999009E9B29 /* CPUDynamicQuant.hpp in Headers */,
|
||||
48C84BA0250F725600EE7666 /* InitNet.hpp in Headers */,
|
||||
|
@ -3005,6 +3011,7 @@
|
|||
92FF03CA23AA0B5A00AC97F6 /* CPUConvolutionDepthwise.hpp in Headers */,
|
||||
92FF04A923AA0BFB00AC97F6 /* Schedule.hpp in Headers */,
|
||||
489D7A9F2550FDC900AD896A /* MetalConvolutionCommon.hpp in Headers */,
|
||||
CECF8C80299CAD9400D3875B /* lz4.h in Headers */,
|
||||
92FF028623AA0B5A00AC97F6 /* CPUDeconvolution.hpp in Headers */,
|
||||
489D7A722550FDC800AD896A /* MetalReLU6.hpp in Headers */,
|
||||
92FF04B523AA0BFB00AC97F6 /* TensorUtils.hpp in Headers */,
|
||||
|
@ -3045,8 +3052,6 @@
|
|||
92FF026023AA0B5A00AC97F6 /* CPURNNSequenceGRU.hpp in Headers */,
|
||||
48747D4F245D9E13000B9709 /* CPURaster.hpp in Headers */,
|
||||
489D7A822550FDC900AD896A /* MetalPReLU.hpp in Headers */,
|
||||
CE0AD4E82E1FB152002013A8 /* MoEModule.hpp in Headers */,
|
||||
CE0AD4E92E1FB152002013A8 /* ModuleInside.hpp in Headers */,
|
||||
48C84B84250F711700EE7666 /* WhileModule.hpp in Headers */,
|
||||
92FF02A923AA0B5A00AC97F6 /* CPUCropAndResize.hpp in Headers */,
|
||||
4D6D7FD92656897200F80814 /* SparseConvolutionTiledExecutor.hpp in Headers */,
|
||||
|
@ -3058,20 +3063,24 @@
|
|||
92FF03A623AA0B5A00AC97F6 /* ConvolutionTiledExecutor.hpp in Headers */,
|
||||
92FF036523AA0B5A00AC97F6 /* CPUResize.hpp in Headers */,
|
||||
92FF04B423AA0BFB00AC97F6 /* MNNMemoryUtils.h in Headers */,
|
||||
CECF8C88299CAD9400D3875B /* log_api.h in Headers */,
|
||||
4A224A0D27D0C2D9000A9260 /* ConvolutionPackWinograd.hpp in Headers */,
|
||||
4A224A0E27D0C2D9000A9260 /* ConvolutionPackFreeWinograd.hpp in Headers */,
|
||||
4D9A937426255BDA00F9B43C /* CoreMLReduction.hpp in Headers */,
|
||||
48C84B8B250F711700EE7666 /* PipelineModule.hpp in Headers */,
|
||||
F41497D7278D8A21004A363A /* RuntimeAttr.hpp in Headers */,
|
||||
CECF8C5B299CACFD00D3875B /* LogHelper.hpp in Headers */,
|
||||
92FF04C123AA0BFB00AC97F6 /* Backend.hpp in Headers */,
|
||||
482BFBCD28351BA1009210E4 /* ShaderMap.hpp in Headers */,
|
||||
489D7A812550FDC900AD896A /* MetalPooling.hpp in Headers */,
|
||||
CECF8C7F299CAD9400D3875B /* md5.h in Headers */,
|
||||
92FF02A623AA0B5A00AC97F6 /* CPUQuantizedMaxPool.hpp in Headers */,
|
||||
92FF028023AA0B5A00AC97F6 /* CPUFloatToInt8.hpp in Headers */,
|
||||
92FF028723AA0B5A00AC97F6 /* CPUFixedPoint.hpp in Headers */,
|
||||
C43C8227251894F400A0FF84 /* Vec.hpp in Headers */,
|
||||
4819FB1D24C138DF0050BD09 /* GeometryConvUtils.hpp in Headers */,
|
||||
489D7A952550FDC900AD896A /* MetalMatMul.hpp in Headers */,
|
||||
CECF8C83299CAD9400D3875B /* log_define.h in Headers */,
|
||||
C48CAE2628900C4A00271A6D /* ConvInt8Winograd.hpp in Headers */,
|
||||
48F34730273A7C7300C45394 /* CPUImageProcess.hpp in Headers */,
|
||||
489D7A702550FDC800AD896A /* MetalRaster.hpp in Headers */,
|
||||
|
@ -3392,6 +3401,7 @@
|
|||
48F34734273A7C8400C45394 /* ImageProcessFunction.cpp in Sources */,
|
||||
6A131E4025823349002EC3D6 /* PluginKernel.cpp in Sources */,
|
||||
48958781268EBA6F00EA01A7 /* CPUSegmentMean.cpp in Sources */,
|
||||
CECF8C7B299CAD9400D3875B /* sha1.c in Sources */,
|
||||
4D9A937026255BDA00F9B43C /* CoreMLUnary.cpp in Sources */,
|
||||
92FF04A823AA0BFB00AC97F6 /* AutoTime.cpp in Sources */,
|
||||
92FF04AE23AA0BFB00AC97F6 /* Backend.cpp in Sources */,
|
||||
|
@ -3418,7 +3428,6 @@
|
|||
48925F342744AC0700919B37 /* CPUROIAlign.cpp in Sources */,
|
||||
4896D36925FE2A3D00717702 /* Arm82Unary.cpp in Sources */,
|
||||
4DCF53902892B17100B5B393 /* ShapeHistogram.cpp in Sources */,
|
||||
CE0AD4EA2E1FB152002013A8 /* MoEModule.cpp in Sources */,
|
||||
92FF043423AA0B7100AC97F6 /* ShapeStridedSlice.cpp in Sources */,
|
||||
4896D37825FE2A6B00717702 /* MNNExpFP16.S in Sources */,
|
||||
4D4CF46B2760946500A36D9F /* draw.cpp in Sources */,
|
||||
|
@ -3447,6 +3456,7 @@
|
|||
92FF03CE23AA0B5A00AC97F6 /* CPUOPRegister.cpp in Sources */,
|
||||
92FF02B323AA0B5A00AC97F6 /* CPUInstanceNorm.cpp in Sources */,
|
||||
4819FB2C24C1396A0050BD09 /* GeometryPoolGrad.cpp in Sources */,
|
||||
CECF8C7E299CAD9400D3875B /* log_builder.cpp in Sources */,
|
||||
92FF042223AA0B7100AC97F6 /* ShapeConcat.cpp in Sources */,
|
||||
4D6D7FD12656891400F80814 /* MNNPackedSparseMatMulEpx4.S in Sources */,
|
||||
4D5662CC299B76ED0031C1A1 /* MNNMaxPoolInt8.S in Sources */,
|
||||
|
@ -3520,15 +3530,17 @@
|
|||
92FF041A23AA0B7100AC97F6 /* ShapeFill.cpp in Sources */,
|
||||
EB45C776244D7C6600E28F44 /* MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S in Sources */,
|
||||
4AF4FB29269ED244005BA97B /* MNNPackedSparseQuantMatMulEpx1.S in Sources */,
|
||||
CE0AD4E42E1FB106002013A8 /* CountMinMaxValue_FP16.S in Sources */,
|
||||
4D759B2C25FF89EE0037B0B6 /* GeometryShape.cpp in Sources */,
|
||||
11A01A07258785EA00745FA7 /* MNNVectorTop1Float.S in Sources */,
|
||||
48747D6E245D9E33000B9709 /* GeometrySlice.cpp in Sources */,
|
||||
CE072A272C91AF0700F190FD /* MNNC3ToC4Fast.S in Sources */,
|
||||
CECF8C7D299CAD9400D3875B /* md5.c in Sources */,
|
||||
92FF041923AA0B7100AC97F6 /* ShapeQuantizedMaxPool.cpp in Sources */,
|
||||
92FF038A23AA0B5A00AC97F6 /* CPURange.cpp in Sources */,
|
||||
CE072A182C91AEE700F190FD /* MNNGRAYToC4Fast.S in Sources */,
|
||||
CE125CC92A52BF6B003698C9 /* MNNBilinearLineC8.S in Sources */,
|
||||
955AD7522E1FB44E0099F26C /* MoEModule.cpp in Sources */,
|
||||
959F15A92E2782F800C67803 /* CountMinMaxValue_FP16.S in Sources */,
|
||||
92FF03A123AA0B5A00AC97F6 /* Int8FunctionsOpt.cpp in Sources */,
|
||||
CE072A222C91AEE700F190FD /* MNNPackC2.S in Sources */,
|
||||
92FF026523AA0B5A00AC97F6 /* CPUQuantizedAvgPool.cpp in Sources */,
|
||||
|
@ -3588,8 +3600,10 @@
|
|||
92FF042E23AA0B7100AC97F6 /* ShapeProposal.cpp in Sources */,
|
||||
92FF025923AA0B5A00AC97F6 /* CPUPoolInt8.cpp in Sources */,
|
||||
92FF045B23AA0B7100AC97F6 /* ShapeShape.cpp in Sources */,
|
||||
CECF8C87299CAD9400D3875B /* sds.c in Sources */,
|
||||
9560EAD62BDE426A00C8D0B6 /* GeometryLayernorm.cpp in Sources */,
|
||||
4D6D7FD72656896D00F80814 /* SparseConvolutionTiledExecutor.cpp in Sources */,
|
||||
CECF8C82299CAD9400D3875B /* log_api.cpp in Sources */,
|
||||
92FF03A823AA0B5A00AC97F6 /* WinogradOptFunction.cpp in Sources */,
|
||||
4A224A0C27D0C2D9000A9260 /* ConvolutionPackWinograd.cpp in Sources */,
|
||||
92FF044123AA0B7100AC97F6 /* ShapeMoments.cpp in Sources */,
|
||||
|
@ -3597,6 +3611,7 @@
|
|||
4D9A936026255BDA00F9B43C /* Model.pb-c.c in Sources */,
|
||||
CE9AFED628E54E3300566949 /* CPUInterp3D.cpp in Sources */,
|
||||
C4F906B427688C3A0026B847 /* NMSModule.cpp in Sources */,
|
||||
CECF8C64299CAD8400D3875B /* LogHelper.mm in Sources */,
|
||||
48FA474523AA127B00172C3B /* Executor.cpp in Sources */,
|
||||
92FF02EA23AA0B5A00AC97F6 /* MNNGemmInt8AddBiasScale_16x4_Unit.S in Sources */,
|
||||
CE072A162C91AEE700F190FD /* MNNBGRAToBGR.S in Sources */,
|
||||
|
@ -3624,6 +3639,7 @@
|
|||
CE072A2C2CAA510F00F190FD /* MNNDepthwiseConvFastKernelFP16.S in Sources */,
|
||||
EBECA3A724643D5D0062C7A3 /* MNNQuantizeFP16_UNIT4.S in Sources */,
|
||||
92FF04A423AA0BFB00AC97F6 /* Interpreter.cpp in Sources */,
|
||||
CECF8C5C299CACFD00D3875B /* Log.cpp in Sources */,
|
||||
92FF045623AA0B7100AC97F6 /* ShapeReshape.cpp in Sources */,
|
||||
92FF032523AA0B5A00AC97F6 /* MNNConvDwF23SourceTransUnit.S in Sources */,
|
||||
92FF044423AA0B7100AC97F6 /* ShapeLSTM.cpp in Sources */,
|
||||
|
@ -3659,6 +3675,7 @@
|
|||
92FF02B623AA0B5A00AC97F6 /* CPUUnary.cpp in Sources */,
|
||||
92FF032723AA0B5A00AC97F6 /* MNNDeconvRunForUnitDepthWise.S in Sources */,
|
||||
CE7DC00028E2DE6B00797689 /* ShapeConvTranspose3D.cpp in Sources */,
|
||||
CECF8C78299CAD9400D3875B /* log_util_imp.cpp in Sources */,
|
||||
CE072A152C91AEE700F190FD /* MNNRGBAToGRAYFast.S in Sources */,
|
||||
92FF02CA23AA0B5A00AC97F6 /* MNNUnPackC4.S in Sources */,
|
||||
952298B22B4D39050043978B /* MetalLoop.mm in Sources */,
|
||||
|
@ -3683,11 +3700,13 @@
|
|||
92FF02FF23AA0B5A00AC97F6 /* MNNFloat2Int8.S in Sources */,
|
||||
4D9A937926255BDA00F9B43C /* CoreMLRaster.cpp in Sources */,
|
||||
48417FF224D13BF50056D9A7 /* GeometrySelect.cpp in Sources */,
|
||||
CECF8C84299CAD9400D3875B /* lz4.c in Sources */,
|
||||
489D7A7E2550FDC900AD896A /* MNNMetalContext.mm in Sources */,
|
||||
92FF033423AA0B5A00AC97F6 /* MNNUInt8ToInt16WithOffsetC4Common.S in Sources */,
|
||||
92FF036B23AA0B5A00AC97F6 /* CPUResize.cpp in Sources */,
|
||||
92FF02C723AA0B5A00AC97F6 /* MNNCopyC4WithStride.S in Sources */,
|
||||
92FF030923AA0B5A00AC97F6 /* MNNNV21ToBGRUnit.S in Sources */,
|
||||
CECF8C79299CAD9400D3875B /* hmac-sha.cpp in Sources */,
|
||||
92FF04C023AA0BFB00AC97F6 /* Tensor.cpp in Sources */,
|
||||
CEE9B95B2A3AA4D4006438F2 /* MNNBilinearLineC8.S in Sources */,
|
||||
92FF045D23AA0B7100AC97F6 /* ShapeCast.cpp in Sources */,
|
||||
|
@ -4016,7 +4035,7 @@
|
|||
CODE_SIGN_STYLE = Automatic;
|
||||
DEAD_CODE_STRIPPING = YES;
|
||||
DEFINES_MODULE = YES;
|
||||
DEVELOPMENT_TEAM = 6G7464HHUS;
|
||||
DEVELOPMENT_TEAM = Q48UX93J22;
|
||||
DYLIB_COMPATIBILITY_VERSION = 1;
|
||||
DYLIB_CURRENT_VERSION = 1;
|
||||
DYLIB_INSTALL_NAME_BASE = "@rpath";
|
||||
|
@ -4081,7 +4100,7 @@
|
|||
CODE_SIGN_STYLE = Automatic;
|
||||
DEAD_CODE_STRIPPING = YES;
|
||||
DEFINES_MODULE = YES;
|
||||
DEVELOPMENT_TEAM = 6G7464HHUS;
|
||||
DEVELOPMENT_TEAM = Q48UX93J22;
|
||||
DYLIB_COMPATIBILITY_VERSION = 1;
|
||||
DYLIB_CURRENT_VERSION = 1;
|
||||
DYLIB_INSTALL_NAME_BASE = "@rpath";
|
||||
|
@ -4144,7 +4163,7 @@
|
|||
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
|
||||
ASSETCATALOG_COMPILER_LAUNCHIMAGE_NAME = LaunchImage;
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
DEVELOPMENT_TEAM = 6G7464HHUS;
|
||||
DEVELOPMENT_TEAM = Q48UX93J22;
|
||||
GCC_ENABLE_CPP_EXCEPTIONS = NO;
|
||||
GCC_ENABLE_CPP_RTTI = NO;
|
||||
GCC_PREPROCESSOR_DEFINITIONS = "MNN_REDUCE_SIZE=1";
|
||||
|
@ -4160,8 +4179,7 @@
|
|||
IPHONEOS_DEPLOYMENT_TARGET = 9.0;
|
||||
LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
|
||||
OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)";
|
||||
PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcdeve;
|
||||
"PRODUCT_BUNDLE_IDENTIFIER[sdk=iphoneos*]" = com.taobao.mnn.abcdes;
|
||||
PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcdeve00;
|
||||
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||
TARGETED_DEVICE_FAMILY = "1,2";
|
||||
};
|
||||
|
@ -4173,7 +4191,7 @@
|
|||
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
|
||||
ASSETCATALOG_COMPILER_LAUNCHIMAGE_NAME = LaunchImage;
|
||||
CODE_SIGN_STYLE = Automatic;
|
||||
DEVELOPMENT_TEAM = 6G7464HHUS;
|
||||
DEVELOPMENT_TEAM = Q48UX93J22;
|
||||
GCC_ENABLE_CPP_EXCEPTIONS = NO;
|
||||
GCC_ENABLE_CPP_RTTI = NO;
|
||||
GCC_PREPROCESSOR_DEFINITIONS = "MNN_REDUCE_SIZE=1";
|
||||
|
@ -4189,8 +4207,7 @@
|
|||
IPHONEOS_DEPLOYMENT_TARGET = 9.0;
|
||||
LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
|
||||
OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)";
|
||||
PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcdeve;
|
||||
"PRODUCT_BUNDLE_IDENTIFIER[sdk=iphoneos*]" = com.taobao.mnn.abcdes;
|
||||
PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcdeve00;
|
||||
PRODUCT_NAME = "$(TARGET_NAME)";
|
||||
TARGETED_DEVICE_FAMILY = "1,2";
|
||||
};
|
||||
|
@ -4304,3 +4321,4 @@
|
|||
};
|
||||
rootObject = 0F1465AE1FA18D1000F9860A /* Project object */;
|
||||
}
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ option(PYMNN_TRAIN_API "MNN train API be exposed" OFF)
|
|||
option(PYMNN_INTERNAL_SERVING "Internal use only." OFF)
|
||||
option(PYMNN_OPENCV_API "MNN OpenCV API be exposed" ON)
|
||||
option(PYMNN_IMGCODECS "MNN IMGCODECS API be exposed" OFF)
|
||||
option(PYMNN_AUDIO_API "MNN Audio API be exposed" ON)
|
||||
option(PYMNN_AUDIO_API "MNN Audio API be exposed" OFF)
|
||||
option(PYMNN_OHOS_INTERNAL "compile for harmony internal." OFF)
|
||||
|
||||
if (PYMNN_OHOS_INTERNAL)
|
||||
|
@ -202,7 +202,11 @@ else()
|
|||
endif()
|
||||
export_headers(DIR ${CMAKE_SOURCE_DIR}/pip_package/MNN)
|
||||
else()
|
||||
target_link_libraries(mnnpybridge PRIVATE log MNN MNN_Express MNNOpenCV MNNAudio)
|
||||
if (PYMNN_AUDIO_API)
|
||||
target_link_libraries(mnnpybridge PRIVATE log MNN MNN_Express MNNOpenCV MNNAudio)
|
||||
else()
|
||||
target_link_libraries(mnnpybridge PRIVATE log MNN MNN_Express MNNOpenCV)
|
||||
endif()
|
||||
if(PYMNN_USE_ALINNPYTHON)
|
||||
target_link_libraries(mnnpybridge PRIVATE AliNNPython)
|
||||
endif()
|
||||
|
|
|
@ -88,7 +88,7 @@ def build_deps():
|
|||
if USE_OPENCL:
|
||||
extra_opts += ' -DMNN_OPENCL=ON'
|
||||
if USE_LLM:
|
||||
extra_opts += ' -DMNN_BUILD_LLM=ON -DMNN_LOW_MEMORY=ON -DMNN_SUPPORT_TRANSFORMER_FUSE=ON'
|
||||
extra_opts += ' -DMNN_BUILD_LLM=ON -DMNN_LOW_MEMORY=ON -DMNN_SUPPORT_TRANSFORMER_FUSE=ON -DLLM_SUPPORT_VISION=ON -DLLM_SUPPORT_AUDIO=ON'
|
||||
if USE_ARM82:
|
||||
extra_opts += ' -DMNN_ARM82=ON'
|
||||
extra_opts += ' -DMNN_USE_THREAD_POOL=OFF -DMNN_OPENMP=ON' if USE_OPENMP else ' -DMNN_USE_THREAD_POOL=ON -DMNN_OPENMP=OFF'
|
||||
|
|
|
@@ -1,5 +1,6 @@
#include <sstream>
#include "llm/llm.hpp"
#include "cpp/getLinearInput.hpp"

typedef struct {
    PyObject_HEAD

@@ -146,6 +147,61 @@ static PyObject* PyMNNLLM_reset(LLM *self, PyObject *args) {
    Py_RETURN_NONE;
}

static PyObject* PyMNNLLM_enable_collection_mode(LLM *self, PyObject *args) {
    if (self->is_embedding) {
        Py_RETURN_NONE;
    }

    int mode = 0;
    const char* output_file = NULL;
    float target_sparsity = 0.5;

    if (!PyArg_ParseTuple(args, "i|sf", &mode, &output_file, &target_sparsity)) {
        PyErr_SetString(PyExc_ValueError, "Invalid arguments. Usage: enable_collection_mode(mode, output_file=None, target_sparsity=0.5)");
        Py_RETURN_NONE;
    }

    std::string filename;

    switch (mode) {
        case 1: {
            // Threshold mode
            if (output_file == NULL) {
                filename = "thresholds.json";
            } else {
                filename = std::string(output_file);
            }

            MNN::LinearInput::initGetThreshold(filename, target_sparsity);
            MNN_PRINT("Enabled threshold collection mode. Output: %s, Sparsity: %.2f\n",
                      filename.c_str(), target_sparsity);

            break;
        }

        case 2: {
            // MaxValue mode
            if (output_file == NULL) {
                filename = "max_values.json";
            } else {
                filename = std::string(output_file);
            }

            MNN::LinearInput::initGetMaxValue(filename);
            MNN_PRINT("Enabled max value collection mode. Output: %s\n", filename.c_str());

            break;
        }

        default: {
            PyErr_SetString(PyExc_ValueError, "Invalid mode. Use 1 for threshold collection, 2 for max value collection");
            Py_RETURN_NONE;
        }
    }

    return toPyObj(true);
}

static PyMethodDef PyMNNLLM_methods[] = {
    {"load", (PyCFunction)PyMNNLLM_load, METH_VARARGS, "load model."},
    {"forward", (PyCFunction)PyMNNLLM_forward, METH_VARARGS, "forward `logits` by `input_ids`."},

@@ -159,6 +215,7 @@ static PyMethodDef PyMNNLLM_methods[] = {
    {"create_lora", (PyCFunction)PyMNNLLM_create_lora, METH_VARARGS, "create_lora."},
    {"set_config", (PyCFunction)PyMNNLLM_set_config, METH_VARARGS, "set_config."},
    {"reset", (PyCFunction)PyMNNLLM_reset, METH_VARARGS, "reset."},
    {"enable_collection_mode", (PyCFunction)PyMNNLLM_enable_collection_mode, METH_VARARGS, "Enable data collection mode."},
    {NULL} /* Sentinel */
};
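The new binding above is reached from Python through the pymnn LLM wrapper. Below is a minimal usage sketch; the MNN.llm module path, the create/load/response calls, and the model config path are assumptions based on the existing pymnn LLM API and are not part of this diff.

# Hypothetical driver for the enable_collection_mode binding added above.
# Assumes the usual pymnn LLM entry points (MNN.llm.create / load / response);
# the model path is illustrative.
import MNN.llm as llm

model = llm.create("./qwen2-1.5b-int4/config.json")
model.load()

# mode 1: collect per-layer activation thresholds targeting ~50% sparsity,
# written to thresholds.json (the same default the C++ side falls back to)
model.enable_collection_mode(1, "thresholds.json", 0.5)

# mode 2 would collect per-layer max values instead:
# model.enable_collection_mode(2, "max_values.json")

print(model.response("Hello"))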
@@ -30,7 +30,7 @@ void Arm82MNNPackForMatMul_A(float* destOrigin, float const** sourceGroup, const
// C(UP_DIV(h,8), e, h8) = B(UP_DIV(h,hP), l, hP) * A(l, eP), hP = 24
// parameter: [aStride, l, h, cStride, bExtraStride]
// aStride in parameter is deprecated (useless), but for code clean, just retain it
void MNNPackedMatMulFP16(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias);
void MNNPackedMatMulFP16(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);

// C(UP_DIV(h,8), e, h8) = B(UP_DIV(h,hP), l, hP) * A(l, e), hP = 24, e >= 1
// parameter: [aStride, l, h, cStride, bExtraStride]

@@ -52,10 +52,16 @@ void MNNDynamicQuantFP16_Pack8(const float* src, int8_t* dst, const float* scale
void MNNDynamicQuantFP16_Pack4(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, const float* bias, size_t pack);
void MNNGeneralIm2col_Arm82(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack);
void MNNGeneralIm2col_Arm86(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack);
#ifdef MNN_SME2
void MNNGeneralIm2col_Fp16Sme2(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack);
#endif
void MNNLocalMinMaxFP16_Pack4(float* dstMin, float* dstMax, const float* source, size_t blockNum, size_t blockLU, size_t EP, size_t LP, size_t loadDstBuffer);
void MNNLocalMinMaxFP16_Pack8(float* dstMin, float* dstMax, const float* source, size_t blockNum, size_t blockLU, size_t EP, size_t LP, size_t loadDstBuffer);
#endif // MNN_LOW_MEMORY
void CountMinMaxValue_FP16(float* source, float* minVal, float* maxVal, size_t sizeQuad);
#ifdef MNN_SME2
void MNNPackedMatMulRemainFP16_SME2(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
#endif
#endif

#if defined(__aarch64__)
@@ -72,6 +78,12 @@ void MNNConvRunForLineDepthwiseFP16(float* dst, const float* src, const float* w

namespace MNN {

static void Sme2MNNGetMatMulPackMode(int* eP, int *lP, int* hP) {
    *hP = 64;
    *eP = 16;
    *lP = 2;
}
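For reference, the eP/lP/hP values above fix the tile shapes that the SME2 packing routines below produce. The sizing sketch here is my own illustration (element counts only, independent of fp16/fp32 byte width), not code from the patch.

# Rough buffer sizing implied by the SME2 pack mode (eP=16, lP=2, hP=64):
# packed A spans ceil(e/eP)*eP rows by ceil(l/lP)*lP columns, packed B spans
# ceil(h/hP)*hP by ceil(l/lP)*lP; both are zero-padded.
def up_div(a, b):
    return (a + b - 1) // b

def sme2_packed_sizes(e, l, h, eP=16, lP=2, hP=64):
    a_elems = up_div(e, eP) * eP * up_div(l, lP) * lP
    b_elems = up_div(h, hP) * hP * up_div(l, lP) * lP
    return a_elems, b_elems

print(sme2_packed_sizes(100, 96, 512))  # (10752, 49152)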
|
||||
static void MNNMatrixAddFP16(FLOAT16* C, const FLOAT16* A, const FLOAT16* B, size_t widthC8, size_t cStride, size_t aStride, size_t bStride, size_t height) {
|
||||
for (int y = 0; y < height; ++y) {
|
||||
auto a = A + aStride * y, b = B + bStride * y;
|
||||
|
@ -103,18 +115,23 @@ static void ARM82CountMinMaxValue(float* source, float* minVal, float* maxVal, s
|
|||
max_ = ((__fp16*)maxVal)[0];
|
||||
min_ = ((__fp16*)minVal)[0];
|
||||
}
|
||||
if (remain > 0) {
|
||||
int16_t tmp[8] = {0};
|
||||
auto srcRemain = reinterpret_cast<uint8_t*>(source) + 8 * (size / 8) * 2;
|
||||
::memcpy(tmp, srcRemain, remain * 2);
|
||||
CountMinMaxValue_FP16((float*)tmp, (float*)tmp, (float*)((uint8_t*)tmp + 2), 1);
|
||||
max_ = ALIMAX(((__fp16*)tmp)[1], max_);
|
||||
min_ = ALIMIN(((__fp16*)tmp)[0], min_);
|
||||
auto srcPtr = reinterpret_cast<__fp16*>(source);
|
||||
while (remain) {
|
||||
max_ = ALIMAX(srcPtr[0], max_);
|
||||
min_ = ALIMIN(srcPtr[0], min_);
|
||||
srcPtr += 1;
|
||||
remain--;
|
||||
}
|
||||
reinterpret_cast<__fp16*>(minVal)[0] = min_;
|
||||
reinterpret_cast<__fp16*>(maxVal)[0] = max_;
|
||||
}
|
||||
}
|
||||
#ifdef MNN_SME2
|
||||
//(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias)
|
||||
static void MNNPackedMatMulFP16_SME2(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b) {
|
||||
MNNPackedMatMulRemainFP16_SME2(C, A, B, 16, parameter, postParameters, bias, k, b);
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
static void ARM82CountMinMaxValue(float* source, float* minVal, float* maxVal, size_t size) {
|
||||
auto srcPtr = (FLOAT16*)source;
|
||||
|
@@ -134,7 +151,8 @@ static void ARM82CountMinMaxValue(float* source, float* minVal, float* maxVal, s
}
#endif

static void Arm82MNNPackForMatMul_B(float* destC, const float* sourceC, size_t h, size_t l, bool transpose) {
static void Arm82MNNPackForMatMul_B(float* destC, const float* sourceC, size_t h, size_t kernelsize, size_t ic, bool transpose) {
    auto l = kernelsize * ic;
    auto dest = (int16_t*)destC;
    auto source = (int16_t*)sourceC;
    int ePack, lPack, hPack;
@@ -169,6 +187,36 @@ static void Arm82MNNPackForMatMul_B(float* destC, const float* sourceC, size_t h
    }
}

static void Sme2MNNPackForMatMul_B(float* destC, const float* sourceC, size_t h, size_t kernelsize, size_t ic, bool transpose) {
    auto dest = (int16_t*)destC;
    auto source = (int16_t*)sourceC;
    int LP = 2;
    int HP = 64;
    auto l = kernelsize * ic;
    memset(dest, 0, ROUND_UP(h, HP) * ROUND_UP(ic, LP) * kernelsize * sizeof(FLOAT16));
    auto stride0 = ROUND_UP(ic, LP) * kernelsize * HP;
    auto stride1 = HP * ROUND_UP(ic, LP);
    auto stride2 = HP * LP;

    size_t srcStride0 = l; // [h,k2,ic]->[hu,k2,ic/lp,hp,lp]
    size_t srcStride1 = 1;
    if (!transpose) { // [k2,ic,h]->[hu,k2,ic/lp,hp,lp]
        srcStride0 = 1;
        srcStride1 = h;
    }
    for (int y = 0; y < h; ++y) {
        auto yHu = y / HP;
        auto yHp = y % HP;
        for (int k = 0; k < kernelsize; ++k) {
            for (int x = 0; x < ic; ++x) {
                auto xLu = x / LP;
                auto xLp = x % LP;
                dest[yHu * stride0 + k * stride1 + xLu * stride2 + yHp * LP + xLp] = source[y * srcStride0 + (x + k * ic) * srcStride1];
            }
        }
    }
}

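A NumPy rendering of the layout that Sme2MNNPackForMatMul_B produces may make the index arithmetic easier to follow. This mirrors the transpose == true branch above and is an illustration, not the shipped code.

import numpy as np

# B[h, kernelsize * ic] -> B_packed[ceil(h/HP), kernelsize, ceil(ic/LP), HP, LP]
# with HP = 64, LP = 2; padded slots stay zero, matching the memset above.
def pack_B_sme2(B, kernelsize, ic, HP=64, LP=2):
    h = B.shape[0]
    out = np.zeros(((h + HP - 1) // HP, kernelsize, (ic + LP - 1) // LP, HP, LP), dtype=B.dtype)
    for y in range(h):
        for k in range(kernelsize):
            for x in range(ic):
                out[y // HP, k, x // LP, y % HP, x % LP] = B[y, k * ic + x]
    return out

# spot-check one element against the source layout
B = np.arange(70 * 2 * 5, dtype=np.float32).reshape(70, 10)
packed = pack_B_sme2(B, kernelsize=2, ic=5)
y, k, x = 65, 1, 3
assert packed[y // 64, k, x // 2, y % 64, x % 2] == B[y, k * 5 + x]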
static void MNNScaleAndAddBiasFP16(FLOAT16* dst, const FLOAT16* src, const FLOAT16* bias, const FLOAT16* alpha, size_t planeNumber,
|
||||
size_t biasNumber) {
|
||||
for (int z = 0; z < biasNumber; ++z) {
|
||||
|
@ -221,7 +269,7 @@ static void MNNGridSampleComputeCordFP16(FLOAT16* dst, const FLOAT16* src, size_
|
|||
::memcpy(dst, tempDst, areaRemain * 2 * sizeof(int16_t));
|
||||
}
|
||||
|
||||
static void MNNGridSampleComputeCord3DFp16(FLOAT* dst, const FLOAT* src, size_t inD, size_t inH, size_t inW, size_t outD, size_t outH, size_t outW, size_t strideD, size_t strideH, bool alignCorners) {
|
||||
static void MNNGridSampleComputeCord3DFp16(FLOAT* dst, const FLOAT* src, size_t inD, size_t inH, size_t inW, size_t outD, size_t outH, size_t outW, bool alignCorners) {
|
||||
float16x8_t zero = vdupq_n_f16(0);
|
||||
float16x8_t one = vdupq_n_f16(1);
|
||||
float16x8_t half = vdupq_n_f16(0.5f);
|
||||
|
@ -882,6 +930,499 @@ static void _ArmBasicMNNPackC4ForMatMul_A_L8(int8_t* destOrigin, int8_t const**
|
|||
}
|
||||
}
|
||||
|
||||
static void Sme2MNNPackForMatMul_A_FP16(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el) {
|
||||
int LP = 2;
|
||||
int pack = 8;
|
||||
int eDest = 16;
|
||||
// LP >= pack
|
||||
int number = info[0];
|
||||
int eReal = info[1];
|
||||
int offset = info[3];
|
||||
for (int n=0; n<number; ++n) {
|
||||
int eWork = el[4 * n + 0];
|
||||
int lWork = el[4 * n + 1];
|
||||
int eOffset = el[4 * n + 2];
|
||||
int lOffset = el[4 * n + 3];
|
||||
auto sourceN = (FLOAT16*)(sourceGroup[n]);
|
||||
auto destN = (FLOAT16*)destOrigin + lOffset * eDest + eOffset * LP;
|
||||
|
||||
auto srcStride0 = pack * offset;
|
||||
auto dstStride0 = eDest * LP;
|
||||
auto l = lWork;
|
||||
while (l > 7) {
|
||||
auto source = sourceN;
|
||||
auto dest = destN;
|
||||
l -= 8;
|
||||
auto e = eWork;
|
||||
if (e == eDest) {
|
||||
auto s0 = vld1q_f32((float*)(source)); // 00112233
|
||||
auto s1 = vld1q_f32((float*)(source + srcStride0));// 00112233
|
||||
auto s2 = vld1q_f32((float*)(source + 2 * srcStride0));
|
||||
auto s3 = vld1q_f32((float*)(source + 3 * srcStride0));
|
||||
|
||||
auto s4 = vld1q_f32((float*)(source + 4 * srcStride0));
|
||||
auto s5 = vld1q_f32((float*)(source + 5 * srcStride0));
|
||||
auto s6 = vld1q_f32((float*)(source + 6 * srcStride0));
|
||||
auto s7 = vld1q_f32((float*)(source + 7 * srcStride0));
|
||||
|
||||
auto s8 = vld1q_f32((float*)(source + 8 * srcStride0));
|
||||
auto s9 = vld1q_f32((float*)(source + 9 * srcStride0));
|
||||
auto s10 = vld1q_f32((float*)(source + 10 * srcStride0));
|
||||
auto s11 = vld1q_f32((float*)(source + 11 * srcStride0));
|
||||
|
||||
auto s12 = vld1q_f32((float*)(source + 12 * srcStride0));
|
||||
auto s13 = vld1q_f32((float*)(source + 13 * srcStride0));
|
||||
auto s14 = vld1q_f32((float*)(source + 14 * srcStride0));
|
||||
auto s15 = vld1q_f32((float*)(source + 15 * srcStride0));
|
||||
|
||||
auto zip1s01 = vzip1q_f32(s0, s1); // 00001111
|
||||
auto zip1s23 = vzip1q_f32(s2, s3); // 00001111
|
||||
auto zip1s45 = vzip1q_f32(s4, s5); // 00001111
|
||||
auto zip1s67 = vzip1q_f32(s6, s7); // 00001111
|
||||
auto zip1s89 = vzip1q_f32(s8, s9); // 00001111
|
||||
auto zip1s1011 = vzip1q_f32(s10, s11); // 00001111
|
||||
auto zip1s1213 = vzip1q_f32(s12, s13); // 00001111
|
||||
auto zip1s1415 = vzip1q_f32(s14, s15); // 00001111
|
||||
|
||||
auto zip2s01 = vzip2q_f32(s0, s1); // 22223333
|
||||
auto zip2s23 = vzip2q_f32(s2, s3); // 22223333
|
||||
auto zip2s45 = vzip2q_f32(s4, s5); // 22223333
|
||||
auto zip2s67 = vzip2q_f32(s6, s7); // 22223333
|
||||
auto zip2s89 = vzip2q_f32(s8, s9); // 22223333
|
||||
auto zip2s1011 = vzip2q_f32(s10, s11); // 22223333
|
||||
auto zip2s1213 = vzip2q_f32(s12, s13); // 22223333
|
||||
auto zip2s1415 = vzip2q_f32(s14, s15); // 22223333
|
||||
|
||||
auto zip1s0123_01 = vzip1q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 00000000
|
||||
auto zip1s4567_01 = vzip1q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);
|
||||
auto zip1s891011_01 = vzip1q_f64((float64x2_t)zip1s89, (float64x2_t)zip1s1011);
|
||||
auto zip1s12131415_01 = vzip1q_f64((float64x2_t)zip1s1213, (float64x2_t)zip1s1415);
|
||||
|
||||
auto zip2s0123_01 = vzip2q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 11111111
|
||||
auto zip2s4567_01 = vzip2q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);
|
||||
auto zip2s891011_01 = vzip2q_f64((float64x2_t)zip1s89, (float64x2_t)zip1s1011);
|
||||
auto zip2s12131415_01 = vzip2q_f64((float64x2_t)zip1s1213, (float64x2_t)zip1s1415);
|
||||
|
||||
auto zip1s0123_23 = vzip1q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 22222222
|
||||
auto zip1s4567_23 = vzip1q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);
|
||||
auto zip1s891011_23 = vzip1q_f64((float64x2_t)zip2s89, (float64x2_t)zip2s1011);
|
||||
auto zip1s12131415_23 = vzip1q_f64((float64x2_t)zip2s1213, (float64x2_t)zip2s1415);
|
||||
|
||||
auto zip2s0123_23 = vzip2q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 33333333
|
||||
auto zip2s4567_23 = vzip2q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);
|
||||
auto zip2s891011_23 = vzip2q_f64((float64x2_t)zip2s89, (float64x2_t)zip2s1011);
|
||||
auto zip2s12131415_23 = vzip2q_f64((float64x2_t)zip2s1213, (float64x2_t)zip2s1415);
|
||||
|
||||
vst1q_f64((float64_t*)dest, zip1s0123_01);
|
||||
vst1q_f64((float64_t*)(dest + 8), zip1s4567_01);
|
||||
vst1q_f64((float64_t*)(dest + 16), zip1s891011_01);
|
||||
vst1q_f64((float64_t*)(dest + 24), zip1s12131415_01);
|
||||
|
||||
vst1q_f64((float64_t*)(dest + dstStride0), zip2s0123_01);
|
||||
vst1q_f64((float64_t*)(dest + dstStride0 + 8), zip2s4567_01);
|
||||
vst1q_f64((float64_t*)(dest + dstStride0 + 16), zip2s891011_01);
|
||||
vst1q_f64((float64_t*)(dest + dstStride0 + 24), zip2s12131415_01);
|
||||
|
||||
vst1q_f64((float64_t*)(dest + 2 * dstStride0), zip1s0123_23);
|
||||
vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 8), zip1s4567_23);
|
||||
vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 16), zip1s891011_23);
|
||||
vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 24), zip1s12131415_23);
|
||||
|
||||
vst1q_f64((float64_t*)(dest + 3 * dstStride0), zip2s0123_23);
|
||||
vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 8), zip2s4567_23);
|
||||
vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 16), zip2s891011_23);
|
||||
vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 24), zip2s12131415_23);
|
||||
|
||||
// dest += (4 * dstStride0);
|
||||
// e -= eDest;
|
||||
sourceN += (eReal * pack);
|
||||
destN += (4 * dstStride0);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (e > 11) {
|
||||
auto s0 = vld1q_f32((float*)(source)); // 00112233
|
||||
auto s1 = vld1q_f32((float*)(source + srcStride0));// 00112233
|
||||
auto s2 = vld1q_f32((float*)(source + 2 * srcStride0));
|
||||
auto s3 = vld1q_f32((float*)(source + 3 * srcStride0));
|
||||
|
||||
auto s4 = vld1q_f32((float*)(source + 4 * srcStride0));
|
||||
auto s5 = vld1q_f32((float*)(source + 5 * srcStride0));
|
||||
auto s6 = vld1q_f32((float*)(source + 6 * srcStride0));
|
||||
auto s7 = vld1q_f32((float*)(source + 7 * srcStride0));
|
||||
|
||||
auto s8 = vld1q_f32((float*)(source + 8 * srcStride0));
|
||||
auto s9 = vld1q_f32((float*)(source + 9 * srcStride0));
|
||||
auto s10 = vld1q_f32((float*)(source + 10 * srcStride0));
|
||||
auto s11 = vld1q_f32((float*)(source + 11 * srcStride0));
|
||||
|
||||
auto zip1s01 = vzip1q_f32(s0, s1); // 00001111
|
||||
auto zip1s23 = vzip1q_f32(s2, s3); // 00001111
|
||||
auto zip1s45 = vzip1q_f32(s4, s5); // 00001111
|
||||
auto zip1s67 = vzip1q_f32(s6, s7); // 00001111
|
||||
auto zip1s89 = vzip1q_f32(s8, s9); // 00001111
|
||||
auto zip1s1011 = vzip1q_f32(s10, s11); // 00001111
|
||||
|
||||
auto zip2s01 = vzip2q_f32(s0, s1); // 22223333
|
||||
auto zip2s23 = vzip2q_f32(s2, s3); // 22223333
|
||||
auto zip2s45 = vzip2q_f32(s4, s5); // 22223333
|
||||
auto zip2s67 = vzip2q_f32(s6, s7); // 22223333
|
||||
auto zip2s89 = vzip2q_f32(s8, s9); // 22223333
|
||||
auto zip2s1011 = vzip2q_f32(s10, s11); // 22223333
|
||||
|
||||
auto zip1s0123_01 = vzip1q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 00000000
|
||||
auto zip1s4567_01 = vzip1q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);
|
||||
auto zip1s891011_01 = vzip1q_f64((float64x2_t)zip1s89, (float64x2_t)zip1s1011);
|
||||
|
||||
auto zip2s0123_01 = vzip2q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 11111111
|
||||
auto zip2s4567_01 = vzip2q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);
|
||||
auto zip2s891011_01 = vzip2q_f64((float64x2_t)zip1s89, (float64x2_t)zip1s1011);
|
||||
|
||||
auto zip1s0123_23 = vzip1q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 22222222
|
||||
auto zip1s4567_23 = vzip1q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);
|
||||
auto zip1s891011_23 = vzip1q_f64((float64x2_t)zip2s89, (float64x2_t)zip2s1011);
|
||||
|
||||
auto zip2s0123_23 = vzip2q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 33333333
|
||||
auto zip2s4567_23 = vzip2q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);
|
||||
auto zip2s891011_23 = vzip2q_f64((float64x2_t)zip2s89, (float64x2_t)zip2s1011);
|
||||
|
||||
vst1q_f64((float64_t*)dest, zip1s0123_01);
|
||||
vst1q_f64((float64_t*)(dest + 8), zip1s4567_01);
|
||||
vst1q_f64((float64_t*)(dest + 16), zip1s891011_01);
|
||||
|
||||
vst1q_f64((float64_t*)(dest + dstStride0), zip2s0123_01);
|
||||
vst1q_f64((float64_t*)(dest + dstStride0 + 8), zip2s4567_01);
|
||||
vst1q_f64((float64_t*)(dest + dstStride0 + 16), zip2s891011_01);
|
||||
|
||||
vst1q_f64((float64_t*)(dest + 2 * dstStride0), zip1s0123_23);
|
||||
vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 8), zip1s4567_23);
|
||||
vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 16), zip1s891011_23);
|
||||
|
||||
vst1q_f64((float64_t*)(dest + 3 * dstStride0), zip2s0123_23);
|
||||
vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 8), zip2s4567_23);
|
||||
vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 16), zip2s891011_23);
|
||||
|
||||
dest += 24;
|
||||
e -= 12;
|
||||
source += (12 * srcStride0);
|
||||
}
|
||||
|
||||
if (e > 7) {
|
||||
auto s0 = vld1q_f32((float*)(source)); // 00112233
|
||||
auto s1 = vld1q_f32((float*)(source + srcStride0));// 00112233
|
||||
auto s2 = vld1q_f32((float*)(source + 2 * srcStride0));
|
||||
auto s3 = vld1q_f32((float*)(source + 3 * srcStride0));
|
||||
|
||||
auto s4 = vld1q_f32((float*)(source + 4 * srcStride0));
|
||||
auto s5 = vld1q_f32((float*)(source + 5 * srcStride0));
|
||||
auto s6 = vld1q_f32((float*)(source + 6 * srcStride0));
|
||||
auto s7 = vld1q_f32((float*)(source + 7 * srcStride0));
|
||||
|
||||
auto zip1s01 = vzip1q_f32(s0, s1); // 00001111
|
||||
auto zip1s23 = vzip1q_f32(s2, s3); // 00001111
|
||||
auto zip1s45 = vzip1q_f32(s4, s5); // 00001111
|
||||
auto zip1s67 = vzip1q_f32(s6, s7); // 00001111
|
||||
|
||||
auto zip2s01 = vzip2q_f32(s0, s1); // 22223333
|
||||
auto zip2s23 = vzip2q_f32(s2, s3); // 22223333
|
||||
auto zip2s45 = vzip2q_f32(s4, s5); // 22223333
|
||||
auto zip2s67 = vzip2q_f32(s6, s7); // 22223333
|
||||
|
||||
auto zip1s0123_01 = vzip1q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 00000000
|
||||
auto zip1s4567_01 = vzip1q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);
|
||||
|
||||
auto zip2s0123_01 = vzip2q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 11111111
|
||||
auto zip2s4567_01 = vzip2q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);
|
||||
|
||||
auto zip1s0123_23 = vzip1q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 22222222
|
||||
auto zip1s4567_23 = vzip1q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);
|
||||
|
||||
auto zip2s0123_23 = vzip2q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 33333333
|
||||
auto zip2s4567_23 = vzip2q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);
|
||||
|
||||
vst1q_f64((float64_t*)dest, zip1s0123_01);
|
||||
vst1q_f64((float64_t*)(dest + 8), zip1s4567_01);
|
||||
|
||||
vst1q_f64((float64_t*)(dest + dstStride0), zip2s0123_01);
|
||||
vst1q_f64((float64_t*)(dest + dstStride0 + 8), zip2s4567_01);
|
||||
|
||||
vst1q_f64((float64_t*)(dest + 2 * dstStride0), zip1s0123_23);
|
||||
vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 8), zip1s4567_23);
|
||||
|
||||
vst1q_f64((float64_t*)(dest + 3 * dstStride0), zip2s0123_23);
|
||||
vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 8), zip2s4567_23);
|
||||
|
||||
dest += 16;
|
||||
e -= 8;
|
||||
source += (8 * srcStride0);
|
||||
}
|
||||
|
||||
if (e > 3) {
|
||||
auto s0 = vld1q_f32((float*)(source)); // 00112233
|
||||
auto s1 = vld1q_f32((float*)(source + srcStride0));// 00112233
|
||||
auto s2 = vld1q_f32((float*)(source + 2 * srcStride0));
|
||||
auto s3 = vld1q_f32((float*)(source + 3 * srcStride0));
|
||||
|
||||
auto zip1s01 = vzip1q_f32(s0, s1); // 00001111
|
||||
auto zip1s23 = vzip1q_f32(s2, s3); // 00001111
|
||||
|
||||
auto zip2s01 = vzip2q_f32(s0, s1); // 22223333
|
||||
auto zip2s23 = vzip2q_f32(s2, s3); // 22223333
|
||||
|
||||
auto zip1s0123_01 = vzip1q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 00000000
|
||||
|
||||
auto zip2s0123_01 = vzip2q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 11111111
|
||||
|
||||
auto zip1s0123_23 = vzip1q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 22222222
|
||||
|
||||
auto zip2s0123_23 = vzip2q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 33333333
|
||||
|
||||
vst1q_f64((float64_t*)dest, zip1s0123_01);
|
||||
vst1q_f64((float64_t*)(dest + dstStride0), zip2s0123_01);
|
||||
vst1q_f64((float64_t*)(dest + 2 * dstStride0), zip1s0123_23);
|
||||
vst1q_f64((float64_t*)(dest + 3 * dstStride0), zip2s0123_23);
|
||||
|
||||
dest += 8;
|
||||
e -= 4;
|
||||
source += (4 * srcStride0);
|
||||
}
|
||||
while (e > 0) {
|
||||
auto s0 = vld1q_f32((float*)(source)); // 00112233
|
||||
|
||||
((float*)dest)[0] = s0[0];
|
||||
((float*)(dest + dstStride0))[0] = s0[1];
|
||||
((float*)(dest + 2 * dstStride0))[0] = s0[2];
|
||||
((float*)(dest + 3 * dstStride0))[0] = s0[3];
|
||||
|
||||
dest += 2;
|
||||
e -= 1;
|
||||
source += srcStride0;
|
||||
}
|
||||
sourceN += (eReal * pack);
|
||||
destN += (4 * dstStride0);
|
||||
} // l>7
|
||||
|
||||
if (l > 3) {
|
||||
auto source = sourceN;
|
||||
auto dest = destN;
|
||||
l -= 4;
|
||||
auto e = eWork;
|
||||
if (e == eDest) {
|
||||
auto s0 = vld1_f32((float*)(source)); // 0011
|
||||
auto s1 = vld1_f32((float*)(source + srcStride0));// 0011
|
||||
auto s2 = vld1_f32((float*)(source + 2 * srcStride0));
|
||||
auto s3 = vld1_f32((float*)(source + 3 * srcStride0));
|
||||
|
||||
auto s4 = vld1_f32((float*)(source + 4 * srcStride0));
|
||||
auto s5 = vld1_f32((float*)(source + 5 * srcStride0));
|
||||
auto s6 = vld1_f32((float*)(source + 6 * srcStride0));
|
||||
auto s7 = vld1_f32((float*)(source + 7 * srcStride0));
|
||||
|
||||
auto s8 = vld1_f32((float*)(source + 8 * srcStride0));
|
||||
auto s9 = vld1_f32((float*)(source + 9 * srcStride0));
|
||||
auto s10 = vld1_f32((float*)(source + 10 * srcStride0));
|
||||
auto s11 = vld1_f32((float*)(source + 11 * srcStride0));
|
||||
|
||||
auto s12 = vld1_f32((float*)(source + 12 * srcStride0));
|
||||
auto s13 = vld1_f32((float*)(source + 13 * srcStride0));
|
||||
auto s14 = vld1_f32((float*)(source + 14 * srcStride0));
|
||||
auto s15 = vld1_f32((float*)(source + 15 * srcStride0));
|
||||
|
||||
auto zip1s01 = vzip1_f32(s0, s1); // 0000
|
||||
auto zip1s23 = vzip1_f32(s2, s3); // 0000
|
||||
auto zip1s45 = vzip1_f32(s4, s5); // 0000
|
||||
auto zip1s67 = vzip1_f32(s6, s7); // 0000
|
||||
auto zip1s89 = vzip1_f32(s8, s9); // 0000
|
||||
auto zip1s1011 = vzip1_f32(s10, s11); // 0000
|
||||
auto zip1s1213 = vzip1_f32(s12, s13); // 0000
|
||||
auto zip1s1415 = vzip1_f32(s14, s15); // 0000
|
||||
|
||||
auto zip2s01 = vzip2_f32(s0, s1); // 1111
|
||||
auto zip2s23 = vzip2_f32(s2, s3); // 1111
|
||||
auto zip2s45 = vzip2_f32(s4, s5); // 1111
|
||||
auto zip2s67 = vzip2_f32(s6, s7); // 1111
|
||||
auto zip2s89 = vzip2_f32(s8, s9); // 1111
|
||||
auto zip2s1011 = vzip2_f32(s10, s11); // 1111
|
||||
auto zip2s1213 = vzip2_f32(s12, s13); // 1111
|
||||
auto zip2s1415 = vzip2_f32(s14, s15); // 1111
|
||||
|
||||
vst1_f32((float32_t*)dest, zip1s01);
|
||||
vst1_f32((float32_t*)(dest + 4), zip1s23);
|
||||
vst1_f32((float32_t*)(dest + 8), zip1s45);
|
||||
vst1_f32((float32_t*)(dest + 12), zip1s67);
|
||||
vst1_f32((float32_t*)(dest + 16), zip1s89);
|
||||
vst1_f32((float32_t*)(dest + 20), zip1s1011);
|
||||
vst1_f32((float32_t*)(dest + 24), zip1s1213);
|
||||
vst1_f32((float32_t*)(dest + 28), zip1s1415);
|
||||
|
||||
vst1_f32((float32_t*)(dest + dstStride0), zip2s01);
|
||||
vst1_f32((float32_t*)(dest + dstStride0 + 4), zip2s23);
|
||||
vst1_f32((float32_t*)(dest + dstStride0 + 8), zip2s45);
|
||||
vst1_f32((float32_t*)(dest + dstStride0 + 12), zip2s67);
|
||||
vst1_f32((float32_t*)(dest + dstStride0 + 16), zip2s89);
|
||||
vst1_f32((float32_t*)(dest + dstStride0 + 20), zip2s1011);
|
||||
vst1_f32((float32_t*)(dest + dstStride0 + 24), zip2s1213);
|
||||
vst1_f32((float32_t*)(dest + dstStride0 + 28), zip2s1415);
|
||||
|
||||
|
||||
dest += 32;
|
||||
e -= eDest;
|
||||
}
|
||||
|
||||
if (e > 11) {
|
||||
auto s0 = vld1_f32((float*)(source)); // 0011
|
||||
auto s1 = vld1_f32((float*)(source + srcStride0));// 0011
|
||||
auto s2 = vld1_f32((float*)(source + 2 * srcStride0));
|
||||
auto s3 = vld1_f32((float*)(source + 3 * srcStride0));
|
||||
|
||||
auto s4 = vld1_f32((float*)(source + 4 * srcStride0));
|
||||
auto s5 = vld1_f32((float*)(source + 5 * srcStride0));
|
||||
auto s6 = vld1_f32((float*)(source + 6 * srcStride0));
|
||||
auto s7 = vld1_f32((float*)(source + 7 * srcStride0));
|
||||
|
||||
auto s8 = vld1_f32((float*)(source + 8 * srcStride0));
|
||||
auto s9 = vld1_f32((float*)(source + 9 * srcStride0));
|
||||
auto s10 = vld1_f32((float*)(source + 10 * srcStride0));
|
||||
auto s11 = vld1_f32((float*)(source + 11 * srcStride0));
|
||||
|
||||
auto zip1s01 = vzip1_f32(s0, s1); // 0000
|
||||
auto zip1s23 = vzip1_f32(s2, s3); // 0000
|
||||
auto zip1s45 = vzip1_f32(s4, s5); // 0000
|
||||
auto zip1s67 = vzip1_f32(s6, s7); // 0000
|
||||
auto zip1s89 = vzip1_f32(s8, s9); // 0000
|
||||
auto zip1s1011 = vzip1_f32(s10, s11); // 0000
|
||||
|
||||
auto zip2s01 = vzip2_f32(s0, s1); // 1111
|
||||
auto zip2s23 = vzip2_f32(s2, s3); // 1111
|
||||
auto zip2s45 = vzip2_f32(s4, s5); // 1111
|
||||
auto zip2s67 = vzip2_f32(s6, s7); // 1111
|
||||
auto zip2s89 = vzip2_f32(s8, s9); // 1111
|
||||
auto zip2s1011 = vzip2_f32(s10, s11); // 1111
|
||||
|
||||
vst1_f32((float32_t*)dest, zip1s01);
|
||||
vst1_f32((float32_t*)(dest + 4), zip1s23);
|
||||
vst1_f32((float32_t*)(dest + 8), zip1s45);
|
||||
vst1_f32((float32_t*)(dest + 12), zip1s67);
|
||||
vst1_f32((float32_t*)(dest + 16), zip1s89);
|
||||
vst1_f32((float32_t*)(dest + 20), zip1s1011);
|
||||
|
||||
vst1_f32((float32_t*)(dest + dstStride0), zip2s01);
|
||||
vst1_f32((float32_t*)(dest + dstStride0 + 4), zip2s23);
|
||||
vst1_f32((float32_t*)(dest + dstStride0 + 8), zip2s45);
|
||||
vst1_f32((float32_t*)(dest + dstStride0 + 12), zip2s67);
|
||||
vst1_f32((float32_t*)(dest + dstStride0 + 16), zip2s89);
|
||||
vst1_f32((float32_t*)(dest + dstStride0 + 20), zip2s1011);
|
||||
|
||||
dest += 24;
|
||||
e -= 12;
|
||||
source += (12 * srcStride0);
|
||||
}
|
||||
|
||||
if (e > 7) {
|
||||
auto s0 = vld1_f32((float*)(source)); // 0011
|
||||
auto s1 = vld1_f32((float*)(source + srcStride0));// 0011
|
||||
auto s2 = vld1_f32((float*)(source + 2 * srcStride0));
|
||||
auto s3 = vld1_f32((float*)(source + 3 * srcStride0));
|
||||
|
||||
auto s4 = vld1_f32((float*)(source + 4 * srcStride0));
|
||||
auto s5 = vld1_f32((float*)(source + 5 * srcStride0));
|
||||
auto s6 = vld1_f32((float*)(source + 6 * srcStride0));
|
||||
auto s7 = vld1_f32((float*)(source + 7 * srcStride0));
|
||||
|
||||
auto zip1s01 = vzip1_f32(s0, s1); // 0000
|
||||
auto zip1s23 = vzip1_f32(s2, s3); // 0000
|
||||
auto zip1s45 = vzip1_f32(s4, s5); // 0000
|
||||
auto zip1s67 = vzip1_f32(s6, s7); // 0000
|
||||
|
||||
auto zip2s01 = vzip2_f32(s0, s1); // 1111
|
||||
auto zip2s23 = vzip2_f32(s2, s3); // 1111
|
||||
auto zip2s45 = vzip2_f32(s4, s5); // 1111
|
||||
auto zip2s67 = vzip2_f32(s6, s7); // 1111
|
||||
|
||||
vst1_f32((float32_t*)dest, zip1s01);
|
||||
vst1_f32((float32_t*)(dest + 4), zip1s23);
|
||||
vst1_f32((float32_t*)(dest + 8), zip1s45);
|
||||
vst1_f32((float32_t*)(dest + 12), zip1s67);
|
||||
|
||||
vst1_f32((float32_t*)(dest + dstStride0), zip2s01);
|
||||
vst1_f32((float32_t*)(dest + dstStride0 + 4), zip2s23);
|
||||
vst1_f32((float32_t*)(dest + dstStride0 + 8), zip2s45);
|
||||
vst1_f32((float32_t*)(dest + dstStride0 + 12), zip2s67);
|
||||
|
||||
dest += 16;
|
||||
e -= 8;
|
||||
source += (8 * srcStride0);
|
||||
}
|
||||
|
||||
if (e > 3) {
|
||||
auto s0 = vld1_f32((float*)(source)); // 0011
|
||||
auto s1 = vld1_f32((float*)(source + srcStride0));// 0011
|
||||
auto s2 = vld1_f32((float*)(source + 2 * srcStride0));
|
||||
auto s3 = vld1_f32((float*)(source + 3 * srcStride0));
|
||||
|
||||
auto zip1s01 = vzip1_f32(s0, s1); // 0000
|
||||
auto zip1s23 = vzip1_f32(s2, s3); // 0000
|
||||
|
||||
auto zip2s01 = vzip2_f32(s0, s1); // 1111
|
||||
auto zip2s23 = vzip2_f32(s2, s3); // 1111
|
||||
|
||||
vst1_f32((float32_t*)dest, zip1s01);
|
||||
vst1_f32((float32_t*)(dest + 4), zip1s23);
|
||||
|
||||
vst1_f32((float32_t*)(dest + dstStride0), zip2s01);
|
||||
vst1_f32((float32_t*)(dest + dstStride0 + 4), zip2s23);
|
||||
|
||||
dest += 8;
|
||||
e -= 4;
|
||||
source += (4 * srcStride0);
|
||||
}
|
||||
if (e > 1) {
|
||||
auto s0 = vld1_f32((float*)(source)); // 0011
|
||||
auto s1 = vld1_f32((float*)(source + srcStride0));// 0011
|
||||
|
||||
auto zip1s01 = vzip1_f32(s0, s1); // 0000
|
||||
|
||||
auto zip2s01 = vzip2_f32(s0, s1); // 1111
|
||||
|
||||
vst1_f32((float32_t*)dest, zip1s01);
|
||||
|
||||
vst1_f32((float32_t*)(dest + dstStride0), zip2s01);
|
||||
|
||||
dest += 4;
|
||||
e -= 2;
|
||||
source += (2 * srcStride0);
|
||||
}
|
||||
if (e > 0) {
|
||||
auto s0 = vld1_f32((float*)(source)); // 0011
|
||||
|
||||
((float*)dest)[0] = s0[0];
|
||||
((float*)(dest + dstStride0))[0] = s0[1];
|
||||
}
|
||||
sourceN += 4;
|
||||
destN += (2 * dstStride0);
|
||||
}
|
||||
|
||||
auto source = (FLOAT16*)(sourceGroup[n]);
|
||||
auto dest = (FLOAT16*)destOrigin + lOffset * eDest + eOffset * LP;
|
||||
if (l > 0) {
|
||||
auto e = eWork;
|
||||
auto lRemain = lWork - l;
|
||||
// if e < eDest, packed A -> [LU, eDest, LP] eDest=eP
|
||||
for (int y=0; y<e; ++y) {
|
||||
auto yR = y % eDest;
|
||||
for (int x=lRemain; x<lWork; ++x) {
|
||||
auto xR = x % pack;
|
||||
auto xC = x / pack;
|
||||
auto xOut = x / LP;
|
||||
auto xIn = x % LP;
|
||||
dest[xOut * eDest * LP + yR * LP + xIn] = source[xC * eReal * pack + y * pack * offset + xR];
|
||||
}
|
||||
}
|
||||
l--;
|
||||
}
|
||||
}
|
||||
}
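The NEON zip cascade above is hard to read in isolation; the scalar fallback at the end of the function already spells out the destination layout, and the short reference below restates it in NumPy for a single source tile (an illustration with eP = 16, LP = 2, not the optimized path).

import numpy as np

# One A tile of shape [e, l] (e <= 16 rows in use) is written as
# [ceil(l/LP), eP, LP], so each GEMM micro-step reads a contiguous 16x2 block.
def pack_A_sme2_tile(A, eP=16, LP=2):
    e, l = A.shape
    out = np.zeros(((l + LP - 1) // LP, eP, LP), dtype=A.dtype)
    for y in range(e):
        for x in range(l):
            out[x // LP, y, x % LP] = A[y, x]
    return out

A = np.arange(13 * 9, dtype=np.float32).reshape(13, 9)  # e=13 < eP, odd l
packed = pack_A_sme2_tile(A)
assert packed[7 // 2, 5, 7 % 2] == A[5, 7]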
|
||||
|
||||
#ifdef MNN_LOW_MEMORY
|
||||
void MNNAbsMaxFP16(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack) {
|
||||
if (pack == 4) {
|
||||
|
@ -1024,7 +1565,7 @@ static void MNNAsyQuantInfo_FP16(float* scale, float* bias, float* qscale, float
|
|||
}
|
||||
|
||||
#ifdef __aarch64__
|
||||
if (DST_XUNIT == 12) { // Arm82, fp16: core->pack=8, SRC_UNIT=4
|
||||
if (DST_XUNIT == 12 || DST_XUNIT == 16) { // Arm82/SME2, fp16: core->pack=8, SRC_UNIT=4
|
||||
// max,min shape: [blockNum, EP]
|
||||
if (innerSide == 4) {
|
||||
for (int i = 0; i < kernelsize; ++i) {
|
||||
|
@@ -1037,12 +1578,22 @@ static void MNNAsyQuantInfo_FP16(float* scale, float* bias, float* qscale, float
}
}
// scale, bias
auto success = MNNAsyLocalQuantInfo_EP12_FP16(scale, bias, qscale, qbias, dstMin, dstMax, info);
if (!success) {
MNN_ERROR("Call error: MNNAsyLocalQuantInfo_EP12_FP16\n");
if (DST_XUNIT == 12) {
auto success = MNNAsyLocalQuantInfo_EP12_FP16(scale, bias, qscale, qbias, dstMin, dstMax, info);
if (!success) {
MNN_ERROR("Call error: MNNAsyLocalQuantInfo_EP12_FP16\n");
return;
}
return;
}
if (DST_XUNIT == 16) {
auto success = MNNAsyLocalQuantInfo_EP16_FP16(scale, bias, qscale, qbias, dstMin, dstMax, info);
if (!success) {
MNN_ERROR("Call error: MNNAsyLocalQuantInfo_EP16_FP16\n");
return;
}
return;
}
return;
}
if (DST_XUNIT == 10 && innerSide == 8) { // Arm86, fp16: core->pack=8, SRC_UNIT=8
// max,min shape: [blockNum, plane]
@ -1120,12 +1671,16 @@ bool Arm82Functions::init() {
|
|||
gInstance = new CoreFunctions;
|
||||
gArm82CoreInt8Functions = new CoreInt8Functions;
|
||||
*gArm82CoreInt8Functions = *MNNGetInt8CoreFunctions();
|
||||
gInstance->backendMatmulRelatedFunctions = origin->backendMatmulRelatedFunctions;
|
||||
gInstance->sme2MatmulRelatedFuncions = origin->sme2MatmulRelatedFuncions;
|
||||
{
|
||||
if (origin->supportSDot) {
|
||||
gArm82CoreInt8Functions->MNNPackC4Int8ForMatMul_A = _Arm82MNNPackC4ForMatMul_A<12, 4>;
|
||||
gInstance->supportSDot = true;
|
||||
}
|
||||
if (origin->supportI8mm) {
|
||||
gArm82CoreInt8Functions->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A_L8<10, 8>;
|
||||
gInstance->supportI8mm = true;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1165,12 +1720,16 @@ bool Arm82Functions::init() {
|
|||
FUNC_PTR_ASSIGN(gInstance->MNNAddC4WithStride, MNNAddC8WithStrideFP16);
|
||||
|
||||
// MatMul
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNGetMatMulPackMode, Arm82MNNGetMatMulPackMode);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul, MNNPackedMatMulFP16);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain, MNNPackedMatMulRemainFP16);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNPackC4ForMatMul_A, Arm82MNNPackForMatMul_A);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNPackForMatMul_B, Arm82MNNPackForMatMul_B);
|
||||
#if defined(__aarch64__)
|
||||
gInstance->supportFp16arith = origin->supportFp16arith;
|
||||
gInstance->supportSDot = origin->supportSDot;
|
||||
gInstance->supportI8mm = origin->supportI8mm;
|
||||
gInstance->supportSME2 = origin->supportSME2;
|
||||
#ifdef MNN_CPU_WEIGHT_DEQUANT_GEMM
|
||||
// Weight Dequant Gemm Kernels
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul_int8, MNNPackedMatMulFP16_int8);
|
||||
|
@ -1184,22 +1743,18 @@ bool Arm82Functions::init() {
|
|||
FUNC_PTR_ASSIGN(gInstance->MNNAsyQuantFunc, MNNAsyQuantFunc_Arm82);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNAsyQuantInfo, MNNAsyQuantInfo_FP16); // return 'plane' min&max
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNDynamicUpdateConvBiasScale, origin->MNNDynamicUpdateConvBiasScale);
|
||||
#ifdef __aarch64__
|
||||
|
||||
if (origin->supportSDot) {
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNGeneralIm2Col, MNNGeneralIm2col_Arm82);
|
||||
}
|
||||
if (origin->supportI8mm) {
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNGeneralIm2Col, MNNGeneralIm2col_Arm86);
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
#endif // MNN_LOW_MEMORY
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNCountMaxMinValue, ARM82CountMinMaxValue); // return one min&max
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNSumByAxisLForMatmul_A, origin->MNNSumByAxisLForMatmul_A);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNDepthwiseConvFastKernel, MNNDepthwiseConvFastKernelFP16);
|
||||
#endif
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNPackC4ForMatMul_A, Arm82MNNPackForMatMul_A);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNGetMatMulPackMode, Arm82MNNGetMatMulPackMode);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNPackForMatMul_B, Arm82MNNPackForMatMul_B);
|
||||
gInstance->MNNComputeMatMulForH_1 = _MNNComputeMatMulForH_1_FP16;
|
||||
gInstance->MNNComputeMatMulForE_1 = _MNNComputeMatMulForE_1_FP16;
|
||||
|
||||
|
@ -1219,6 +1774,43 @@ bool Arm82Functions::init() {
|
|||
|
||||
gInstance->MNNPoolingMax = (decltype(gInstance->MNNPoolingMax))(poolingMax<float16_t, Vec, 8, -65535>);
|
||||
gInstance->MNNPoolingAvg = (decltype(gInstance->MNNPoolingAvg))(poolingAvg<float16_t, Vec, 8>);
|
||||
|
||||
{
|
||||
gInstance->backendMatmulRelatedFunctions.MNNPackC4Int8ForMatMul_A = gArm82CoreInt8Functions->MNNPackC4Int8ForMatMul_A;
|
||||
gInstance->backendMatmulRelatedFunctions.MNNGeneralIm2Col = gInstance->MNNGeneralIm2Col;
|
||||
|
||||
gInstance->backendMatmulRelatedFunctions.MNNGetMatMulPackMode = gInstance->MNNGetMatMulPackMode;
|
||||
gInstance->backendMatmulRelatedFunctions.MNNPackedMatMul = gInstance->MNNPackedMatMul;
|
||||
gInstance->backendMatmulRelatedFunctions.MNNPackedMatMulRemain = gInstance->MNNPackedMatMulRemain;
|
||||
gInstance->backendMatmulRelatedFunctions.MNNPackC4ForMatMul_A = gInstance->MNNPackC4ForMatMul_A;
|
||||
gInstance->backendMatmulRelatedFunctions.MNNPackForMatMul_B = gInstance->MNNPackForMatMul_B;
|
||||
}
|
||||
#ifdef __aarch64__
|
||||
#ifdef MNN_SME2
|
||||
if (origin->supportSME2) {
|
||||
gArm82CoreInt8Functions->MNNPackC4Int8ForMatMul_A = _Arm82MNNPackC4ForMatMul_A<16, 4>;
|
||||
gInstance->sme2MatmulRelatedFuncions.MNNPackC4Int8ForMatMul_A = _Arm82MNNPackC4ForMatMul_A<16, 4>;
|
||||
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul, MNNPackedMatMulFP16_SME2);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain, MNNPackedMatMulRemainFP16_SME2);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNGetMatMulPackMode, Sme2MNNGetMatMulPackMode);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNPackC4ForMatMul_A, Sme2MNNPackForMatMul_A_FP16);
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNPackForMatMul_B, Sme2MNNPackForMatMul_B);
|
||||
|
||||
gInstance->sme2MatmulRelatedFuncions.MNNPackedMatMul = MNNPackedMatMulFP16_SME2;
|
||||
gInstance->sme2MatmulRelatedFuncions.MNNPackedMatMulRemain = MNNPackedMatMulRemainFP16_SME2;
|
||||
gInstance->sme2MatmulRelatedFuncions.MNNGetMatMulPackMode = Sme2MNNGetMatMulPackMode;
|
||||
gInstance->sme2MatmulRelatedFuncions.MNNPackC4ForMatMul_A = Sme2MNNPackForMatMul_A_FP16;
|
||||
gInstance->sme2MatmulRelatedFuncions.MNNPackForMatMul_B = Sme2MNNPackForMatMul_B;
|
||||
|
||||
#ifdef MNN_LOW_MEMORY
|
||||
FUNC_PTR_ASSIGN(gInstance->MNNGeneralIm2Col, MNNGeneralIm2col_Fp16Sme2);
|
||||
gInstance->sme2MatmulRelatedFuncions.MNNGeneralIm2Col = MNNGeneralIm2col_Fp16Sme2;
|
||||
#endif
|
||||
}
|
||||
#endif // MNN_SME2
|
||||
#endif // __aarch64__
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -249,6 +249,268 @@ void MNNNC8HW8TONHWC(float* dest, const FLOAT16* src, size_t plane, size_t chann
|
|||
#ifdef __aarch64__
|
||||
#ifdef MNN_LOW_MEMORY
|
||||
|
||||
bool MNNAsyLocalQuantInfo_EP16_FP16(float* scale, float* bias, float* qscale, float* qbias, const float* srcMin, const float* srcMax, const size_t* info) {
|
||||
// dequant scale/bias : [EU, blockNum, step]
|
||||
// quant scale/bias: [blockNum, EP]
|
||||
|
||||
auto blockNum = info[0];
|
||||
auto EP = info[1];
|
||||
auto DST_XUNIT = info[3];
|
||||
if (DST_XUNIT != 16) {
|
||||
MNN_ERROR("Call error: MNNAsyLocalQuantInfo_EP16_FP16\n");
|
||||
return false;
|
||||
}
|
||||
auto stride = EP * blockNum;
|
||||
|
||||
auto minfloat = vdupq_n_f32(1e-6);
|
||||
auto _255f = vdupq_n_f32(255.f);
|
||||
auto _128f = vdupq_n_f32(128.f);
|
||||
auto _0f = vdupq_n_f32(0.f);
|
||||
|
||||
auto minPtr = (FLOAT16*)srcMin;
|
||||
auto maxPtr = (FLOAT16*)srcMax;
|
||||
for (int k = 0; k < blockNum; ++k) {
|
||||
auto qind = k * EP;
|
||||
auto realDstCount = EP;
|
||||
auto scalePtr = scale + k * ALIMIN(EP, DST_XUNIT);
|
||||
auto biasPtr = bias + k * ALIMIN(EP, DST_XUNIT);
|
||||
while (realDstCount > DST_XUNIT - 1) {
|
||||
auto max0_fp16 = vld1_f16(maxPtr + qind);
|
||||
auto max1_fp16 = vld1_f16(maxPtr + qind + 4);
|
||||
auto max2_fp16 = vld1_f16(maxPtr + qind + 8);
|
||||
auto max3_fp16 = vld1_f16(maxPtr + qind + 12);
|
||||
auto min0_fp16 = vld1_f16(minPtr + qind);
|
||||
auto min1_fp16 = vld1_f16(minPtr + qind + 4);
|
||||
auto min2_fp16 = vld1_f16(minPtr + qind + 8);
|
||||
auto min3_fp16 = vld1_f16(minPtr + qind + 12);
|
||||
|
||||
// float16 -> float32
|
||||
auto max0 = vcvt_f32_f16(max0_fp16);
|
||||
auto max1 = vcvt_f32_f16(max1_fp16);
|
||||
auto max2 = vcvt_f32_f16(max2_fp16);
|
||||
auto max3 = vcvt_f32_f16(max3_fp16);
|
||||
|
||||
auto min0 = vcvt_f32_f16(min0_fp16);
|
||||
auto min1 = vcvt_f32_f16(min1_fp16);
|
||||
auto min2 = vcvt_f32_f16(min2_fp16);
|
||||
auto min3 = vcvt_f32_f16(min3_fp16);
|
||||
// diff
|
||||
auto diff0 = vsubq_f32(max0, min0);
|
||||
auto diff1 = vsubq_f32(max1, min1);
|
||||
auto diff2 = vsubq_f32(max2, min2);
|
||||
auto diff3 = vsubq_f32(max3, min3);
|
||||
|
||||
auto qscaleV0 = vdivq_f32(_255f, diff0);
|
||||
auto qscaleV1 = vdivq_f32(_255f, diff1);
|
||||
auto qscaleV2 = vdivq_f32(_255f, diff2);
|
||||
auto qscaleV3 = vdivq_f32(_255f, diff3);
|
||||
auto scaleV0 = vdivq_f32(diff0, _255f);
|
||||
auto scaleV1 = vdivq_f32(diff1, _255f);
|
||||
auto scaleV2 = vdivq_f32(diff2, _255f);
|
||||
auto scaleV3 = vdivq_f32(diff3, _255f);
|
||||
|
||||
auto qbiasV0 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min0), diff0), _128f));
|
||||
auto qbiasV1 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min1), diff1), _128f));
|
||||
auto qbiasV2 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min2), diff2), _128f));
|
||||
auto qbiasV3 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min3), diff3), _128f));
|
||||
auto biasV0 = vaddq_f32(vdivq_f32(vmulq_f32(diff0, _128f), _255f), min0);
|
||||
auto biasV1 = vaddq_f32(vdivq_f32(vmulq_f32(diff1, _128f), _255f), min1);
|
||||
auto biasV2 = vaddq_f32(vdivq_f32(vmulq_f32(diff2, _128f), _255f), min2);
|
||||
auto biasV3 = vaddq_f32(vdivq_f32(vmulq_f32(diff3, _128f), _255f), min3);
|
||||
|
||||
auto _0bic = vclezq_f32(diff0);
|
||||
auto _1bic = vclezq_f32(diff1);
|
||||
auto _2bic = vclezq_f32(diff2);
|
||||
auto _3bic = vclezq_f32(diff3);
|
||||
|
||||
qscaleV0 = vbslq_f32(_0bic, _0f, qscaleV0);
|
||||
qscaleV1 = vbslq_f32(_1bic, _0f, qscaleV1);
|
||||
qscaleV2 = vbslq_f32(_2bic, _0f, qscaleV2);
|
||||
qscaleV3 = vbslq_f32(_3bic, _0f, qscaleV3);
|
||||
|
||||
qbiasV0 = vrndaq_f32(vbslq_f32(_0bic, _0f, qbiasV0));
|
||||
qbiasV1 = vrndaq_f32(vbslq_f32(_1bic, _0f, qbiasV1));
|
||||
qbiasV2 = vrndaq_f32(vbslq_f32(_2bic, _0f, qbiasV2));
|
||||
qbiasV3 = vrndaq_f32(vbslq_f32(_3bic, _0f, qbiasV3));
|
||||
|
||||
scaleV0 = vbslq_f32(_0bic, _0f, scaleV0);
|
||||
scaleV1 = vbslq_f32(_1bic, _0f, scaleV1);
|
||||
scaleV2 = vbslq_f32(_2bic, _0f, scaleV2);
|
||||
scaleV3 = vbslq_f32(_3bic, _0f, scaleV3);
|
||||
|
||||
biasV0 = vbslq_f32(_0bic, max0, biasV0);
|
||||
biasV1 = vbslq_f32(_1bic, max1, biasV1);
|
||||
biasV2 = vbslq_f32(_2bic, max2, biasV2);
|
||||
biasV3 = vbslq_f32(_3bic, max3, biasV3);
|
||||
|
||||
vst1q_f32(qscale + qind, qscaleV0);
|
||||
vst1q_f32(qscale + qind + 4, qscaleV1);
|
||||
vst1q_f32(qscale + qind + 8, qscaleV2);
|
||||
vst1q_f32(qscale + qind + 12, qscaleV3);
|
||||
|
||||
vst1q_f32(qbias + qind, qbiasV0);
|
||||
vst1q_f32(qbias + qind + 4, qbiasV1);
|
||||
vst1q_f32(qbias + qind + 8, qbiasV2);
|
||||
vst1q_f32(qbias + qind + 12, qbiasV3);
|
||||
|
||||
vst1q_f32(scalePtr, scaleV0);
|
||||
vst1q_f32(scalePtr + 4, scaleV1);
|
||||
vst1q_f32(scalePtr + 8, scaleV2);
|
||||
vst1q_f32(scalePtr + 12, scaleV3);
|
||||
|
||||
vst1q_f32(biasPtr, biasV0);
|
||||
vst1q_f32(biasPtr + 4, biasV1);
|
||||
vst1q_f32(biasPtr + 8, biasV2);
|
||||
vst1q_f32(biasPtr + 12, biasV3);
|
||||
|
||||
realDstCount -= DST_XUNIT;
|
||||
qind += DST_XUNIT;
|
||||
scalePtr += (blockNum * DST_XUNIT);
|
||||
biasPtr += (blockNum * DST_XUNIT);
|
||||
}
|
||||
if (realDstCount == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto remainE = realDstCount;
|
||||
auto stride0 = remainE * blockNum;
|
||||
scalePtr = scale + (EP / DST_XUNIT) * blockNum * DST_XUNIT + k * remainE;
|
||||
biasPtr = bias + (EP / DST_XUNIT) * blockNum * DST_XUNIT + k * remainE;
|
||||
if (realDstCount > 7) {
|
||||
auto max0_fp16 = vld1_f16(maxPtr + qind);
|
||||
auto max1_fp16 = vld1_f16(maxPtr + qind + 4);
|
||||
auto min0_fp16 = vld1_f16(minPtr + qind);
|
||||
auto min1_fp16 = vld1_f16(minPtr + qind + 4);
|
||||
|
||||
// float16 -> float32
|
||||
auto max0 = vcvt_f32_f16(max0_fp16);
|
||||
auto max1 = vcvt_f32_f16(max1_fp16);
|
||||
|
||||
auto min0 = vcvt_f32_f16(min0_fp16);
|
||||
auto min1 = vcvt_f32_f16(min1_fp16);
|
||||
// diff
|
||||
auto diff0 = vsubq_f32(max0, min0);
|
||||
auto diff1 = vsubq_f32(max1, min1);
|
||||
|
||||
auto qscaleV0 = vdivq_f32(_255f, diff0);
|
||||
auto qscaleV1 = vdivq_f32(_255f, diff1);
|
||||
auto scaleV0 = vdivq_f32(diff0, _255f);
|
||||
auto scaleV1 = vdivq_f32(diff1, _255f);
|
||||
|
||||
auto qbiasV0 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min0), diff0), _128f));
|
||||
auto qbiasV1 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min1), diff1), _128f));
|
||||
auto biasV0 = vaddq_f32(vdivq_f32(vmulq_f32(diff0, _128f), _255f), min0);
|
||||
auto biasV1 = vaddq_f32(vdivq_f32(vmulq_f32(diff1, _128f), _255f), min1);
|
||||
|
||||
auto _0bic = vclezq_f32(diff0);
|
||||
auto _1bic = vclezq_f32(diff1);
|
||||
|
||||
qscaleV0 = vbslq_f32(_0bic, _0f, qscaleV0);
|
||||
qscaleV1 = vbslq_f32(_1bic, _0f, qscaleV1);
|
||||
|
||||
qbiasV0 = vrndaq_f32(vbslq_f32(_0bic, _0f, qbiasV0));
|
||||
qbiasV1 = vrndaq_f32(vbslq_f32(_1bic, _0f, qbiasV1));
|
||||
|
||||
scaleV0 = vbslq_f32(_0bic, _0f, scaleV0);
|
||||
scaleV1 = vbslq_f32(_1bic, _0f, scaleV1);
|
||||
|
||||
biasV0 = vbslq_f32(_0bic, max0, biasV0);
|
||||
biasV1 = vbslq_f32(_1bic, max1, biasV1);
|
||||
|
||||
vst1q_f32(qscale + qind, qscaleV0);
|
||||
vst1q_f32(qscale + qind + 4, qscaleV1);
|
||||
|
||||
vst1q_f32(qbias + qind, qbiasV0);
|
||||
vst1q_f32(qbias + qind + 4, qbiasV1);
|
||||
|
||||
vst1q_f32(scalePtr, scaleV0);
|
||||
vst1q_f32(scalePtr + 4, scaleV1);
|
||||
|
||||
vst1q_f32(biasPtr, biasV0);
|
||||
vst1q_f32(biasPtr + 4, biasV1);
|
||||
realDstCount -= 8;
|
||||
qind += 8;
|
||||
scalePtr += 8;
|
||||
biasPtr += 8;
|
||||
}
|
||||
if (realDstCount > 3) {
|
||||
auto max0_fp16 = vld1_f16(maxPtr + qind);
|
||||
auto min0_fp16 = vld1_f16(minPtr + qind);
|
||||
|
||||
// float16 -> float32
|
||||
auto max0 = vcvt_f32_f16(max0_fp16);
|
||||
auto min0 = vcvt_f32_f16(min0_fp16);
|
||||
// diff
|
||||
auto diff0 = vsubq_f32(max0, min0);
|
||||
|
||||
auto qscaleV0 = vdivq_f32(_255f, diff0);
|
||||
auto scaleV0 = vdivq_f32(diff0, _255f);
|
||||
|
||||
auto qbiasV0 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min0), diff0), _128f));
|
||||
auto biasV0 = vaddq_f32(vdivq_f32(vmulq_f32(diff0, _128f), _255f), min0);
|
||||
|
||||
auto _0bic = vclezq_f32(diff0);
|
||||
|
||||
qscaleV0 = vbslq_f32(_0bic, _0f, qscaleV0);
|
||||
qbiasV0 = vrndaq_f32(vbslq_f32(_0bic, _0f, qbiasV0));
|
||||
scaleV0 = vbslq_f32(_0bic, _0f, scaleV0);
|
||||
biasV0 = vbslq_f32(_0bic, max0, biasV0);
|
||||
|
||||
vst1q_f32(qscale + qind, qscaleV0);
|
||||
|
||||
vst1q_f32(qbias + qind, qbiasV0);
|
||||
|
||||
vst1q_f32(scalePtr, scaleV0);
|
||||
|
||||
vst1q_f32(biasPtr, biasV0);
|
||||
|
||||
realDstCount -= 4;
|
||||
qind += 4;
|
||||
scalePtr += 4;
|
||||
biasPtr += 4;
|
||||
}
|
||||
while (realDstCount > 0) {
|
||||
auto max0_fp16 = vld1_dup_f16(maxPtr + qind);
|
||||
auto min0_fp16 = vld1_dup_f16(minPtr + qind);
|
||||
|
||||
// float16->float32
|
||||
auto max0 = vcvt_f32_f16(max0_fp16);
|
||||
auto min0 = vcvt_f32_f16(min0_fp16);
|
||||
auto diff0 = vsubq_f32(max0, min0);
|
||||
|
||||
auto qscaleV0 = vdivq_f32(_255f, diff0);
|
||||
auto scaleV0 = vdivq_f32(diff0, _255f);
|
||||
|
||||
auto qbiasV0 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min0), diff0), _128f));
|
||||
auto biasV0 = vaddq_f32(vdivq_f32(vmulq_f32(diff0, _128f), _255f), min0);
|
||||
|
||||
auto _0bic = vclezq_f32(diff0);
|
||||
|
||||
qscaleV0 = vbslq_f32(_0bic, _0f, qscaleV0);
|
||||
|
||||
qbiasV0 = vrndaq_f32(vbslq_f32(_0bic, _0f, qbiasV0));
|
||||
|
||||
scaleV0 = vbslq_f32(_0bic, _0f, scaleV0);
|
||||
|
||||
biasV0 = vbslq_f32(_0bic, max0, biasV0);
|
||||
|
||||
vst1q_lane_f32(qscale + qind, qscaleV0, 0);
|
||||
|
||||
vst1q_lane_f32(qbias + qind, qbiasV0, 0);
|
||||
|
||||
vst1q_lane_f32(scalePtr, scaleV0, 0);
|
||||
|
||||
vst1q_lane_f32(biasPtr, biasV0, 0);
|
||||
|
||||
realDstCount -= 1;
|
||||
qind += 1;
|
||||
scalePtr += 1;
|
||||
biasPtr += 1;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
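The vector arithmetic above reduces to a small per-lane formula; the scalar model below is my own restatement for readability, with the same degenerate-range handling. Quantized codes land in the signed 8-bit range, and q * scale + bias recovers the input to within one quantization step.

# Scalar model of the asymmetric quant parameters computed by
# MNNAsyLocalQuantInfo_EP16_FP16 (the kernel does 16 lanes at a time in NEON).
def asy_quant_params(vmin, vmax):
    diff = vmax - vmin
    if diff <= 0.0:                       # all values in the block are equal
        return 0.0, 0.0, 0.0, vmax        # qscale, qbias, scale, bias
    qscale = 255.0 / diff
    qbias = round(-(255.0 * vmin / diff + 128.0))  # vrndaq rounding in the kernel
    scale = diff / 255.0
    bias = diff * 128.0 / 255.0 + vmin
    return qscale, qbias, scale, bias

qs, qb, s, b = asy_quant_params(-0.3, 1.7)
x = 0.9
q = round(x * qs + qb)            # quantized int8 code
assert -128 <= q <= 127
assert abs((q * s + b) - x) <= s  # dequantization error within one step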
|
||||
|
||||
bool MNNAsyLocalQuantInfo_EP12_FP16(float* scale, float* bias, float* qscale, float* qbias, const float* srcMin, const float* srcMax, const size_t* info) {
|
||||
// dequant scale/bias : [EU, blockNum, step]
|
||||
// quant scale/bias: [blockNum, EP]
|
||||
|
|
|
@@ -36,6 +36,7 @@ void MNNSlowCopy(T* dst, const U* src, size_t size) {
#ifdef __aarch64__
bool MNNAsyLocalQuantInfo_EP12_FP16(float* scale, float* bias, float* qscale, float* qbias, const float* srcMin, const float* srcMax, const size_t* info);
bool MNNAsyLocalQuantInfo_EP10_FP16(float* scale, float* bias, float* qscale, float* qbias, const float* srcMin, const float* srcMax, const size_t* info);
bool MNNAsyLocalQuantInfo_EP16_FP16(float* scale, float* bias, float* qscale, float* qbias, const float* srcMin, const float* srcMax, const size_t* info);
#endif
#endif
@@ -13,7 +13,15 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR ARCHS STREQUAL "arm64" OR AR
if (MNN_CPU_WEIGHT_DEQUANT_GEMM)
file(GLOB MNN_ARM82_SRCS_ASM ${MNN_ARM82_SRCS_ASM} ${CMAKE_CURRENT_LIST_DIR}/asm/arm64/normal_memory/*)
endif()
add_library(MNN_Arm82 OBJECT ${MNN_ARM82_SRCS} ${MNN_ARM82_SRCS_ASM})
if (MNN_SME2)
file(GLOB MNN_SME2_SRCS_ASM_FP16 ${MNN_SME2_SRCS_ASM_FP16} ${CMAKE_CURRENT_LIST_DIR}/asm/arm64/sme2_asm/*)
#set_source_files_properties(${MNN_SME2_SRCS_ASM_FP16} PROPERTIES COMPILE_OPTIONS "-fno-tree-vectorize;-march=armv8.6-a+sve+sve2+sme+sme2+fp16")
set_source_files_properties(${MNN_SME2_SRCS_ASM_FP16} PROPERTIES COMPILE_OPTIONS "-fno-tree-vectorize;-march=armv8.2-a+fp16")
endif()
set_source_files_properties(${MNN_ARM82_SRCS} PROPERTIES COMPILE_OPTIONS "-march=armv8.2-a+fp16")
set_source_files_properties(${MNN_ARM82_SRCS_ASM} PROPERTIES COMPILE_OPTIONS "-march=armv8.2-a+fp16")

add_library(MNN_Arm82 OBJECT ${MNN_ARM82_SRCS} ${MNN_ARM82_SRCS_ASM} ${MNN_SME2_SRCS_ASM_FP16})
if (MNN_LOW_MEMORY)
target_compile_options(MNN_Arm82 PRIVATE -DMNN_LOW_MEMORY)
endif()

File diff suppressed because it is too large
@@ -56,13 +56,19 @@ void CPUAttention::pack_query(Tensor* query, char* pack_q, char* sum_q, int seq_
}
}
else {
// target: [seq_len/eP, mHeadDim/lP, eP, lP]
T * query_src = query->host<T>();
T * query_dst = reinterpret_cast<T*>(pack_q);
auto stride0 = ROUND_UP(mHeadDim, lP) * eP;
auto stride1 = eP * lP;
if (mHeadDim % lP) {
memset(query_dst, 0, ROUND_UP(mHeadDim, lP) * bytes * ROUND_UP(seq_len, eP));
}
for (int i = 0; i < seq_len; i++) {
int out_index = i / eP;
int in_index = i % eP;
for (int j = 0; j < mHeadDim; j++) {
query_dst[out_index * mHeadDim * eP + j * eP + in_index] = query_src[i * mNumHead * mHeadDim + h * mHeadDim + j] * q_scale;
query_dst[out_index * stride0 + (j / lP) * stride1 + in_index * lP + (j % lP)] = query_src[i * mNumHead * mHeadDim + h * mHeadDim + j] * q_scale;
}
}
}

@@ -83,15 +89,20 @@ void CPUAttention::unpack_QK(float * unpack_qk_dst, char * pack_qk_src, int seq_
}

template <typename T>
static void pack_QK(char * pack_qk_dst, float * qk_src, int seq_len, int kv_seq_len, int eP) {
static void pack_QK(char * pack_qk_dst, float * qk_src, int seq_len, int kv_seq_len, int eP, int lP, int bytes) {
T * dst = reinterpret_cast<T*>(pack_qk_dst);
float * src = reinterpret_cast<float*>(qk_src);
// [seq_len, kv_seq_len] -> [seq_len/eP, kv_seq_len, eP]
// [seq_len, kv_seq_len] -> [seq_len/eP, kv_seq_len/lP, eP, lP]
auto stride0 = ROUND_UP(kv_seq_len, lP) * eP;
auto stride1 = eP * lP;
if (kv_seq_len % lP) {
memset(dst, 0, ROUND_UP(kv_seq_len, lP) * ROUND_UP(seq_len, eP) * bytes);
}
for (int i = 0; i < seq_len; i++) {
int out_index = i / eP;
int in_index = i % eP;
for (int j = 0; j < kv_seq_len; j++) {
dst[out_index * kv_seq_len * eP + j * eP + in_index] = src[i * kv_seq_len + j];
dst[out_index * stride0 + (j / lP) * stride1 + in_index * lP + (j % lP)] = src[i * kv_seq_len + j];
}
}
}

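The same eP x lP tiling appears here: where the old pack_QK produced [seq_len/eP, kv_seq_len, eP], the new version also pads the inner dimension to a multiple of lP. The NumPy restatement below is an illustration of the new layout only.

import numpy as np

# [seq_len, kv_seq_len] -> [ceil(seq_len/eP), ceil(kv_seq_len/lP), eP, lP],
# zero-padded like the memset in pack_QK when kv_seq_len % lP != 0.
def pack_qk_ref(src, eP, lP):
    seq_len, kv_seq_len = src.shape
    dst = np.zeros(((seq_len + eP - 1) // eP,
                    (kv_seq_len + lP - 1) // lP, eP, lP), dtype=src.dtype)
    for i in range(seq_len):
        for j in range(kv_seq_len):
            dst[i // eP, j // lP, i % eP, j % lP] = src[i, j]
    return dst

qk = np.random.rand(5, 7).astype(np.float32)
packed = pack_qk_ref(qk, eP=16, lP=2)  # eP/lP come from MNNGetMatMulPackMode
assert packed[4 // 16, 6 // 2, 4 % 16, 6 % 2] == qk[4, 6]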
@ -181,7 +192,7 @@ ErrorCode CPUAttention::onResize(const std::vector<Tensor*>& inputs, const std::
|
|||
mQueryScale.resize(mNumHead);
|
||||
mQueryZeroPoint.resize(mNumHead);
|
||||
} else {
|
||||
mPackQ.reset(Tensor::createDevice<float>({mThreadNum, UP_DIV(seq_len, eP), mHeadDim, eP}));
|
||||
mPackQ.reset(Tensor::createDevice<float>({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(mHeadDim, lP), eP}));
|
||||
mPackQKV.reset(Tensor::createDevice<float>({mThreadNum, UP_DIV(mHeadDim, unit), seq_len, unit}));
|
||||
backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC);
|
||||
backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC);
|
||||
|
@ -242,7 +253,7 @@ ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std:
|
|||
std::shared_ptr<Tensor> packQK(Tensor::createDevice<float>({mThreadNum, UP_DIV(kv_seq_len, unit), seq_len, unit}));
|
||||
std::shared_ptr<Tensor> unpackQK(Tensor::createDevice<int32_t>({mThreadNum, seq_len, kv_seq_len}));
|
||||
std::shared_ptr<Tensor> softmMaxQ(Tensor::createDevice<int32_t>({mThreadNum, seq_len, kv_seq_len}));
|
||||
std::shared_ptr<Tensor> newPackQK(Tensor::createDevice<float>({mThreadNum, UP_DIV(seq_len, eP), kv_seq_len, eP}));
|
||||
std::shared_ptr<Tensor> newPackQK(Tensor::createDevice<float>({mThreadNum, UP_DIV(seq_len, eP), ROUND_UP(kv_seq_len, lP), eP}));
|
||||
std::shared_ptr<Tensor> dequantV(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), kv_seq_len, hP}));
|
||||
backend()->onAcquireBuffer(packQK.get(), Backend::STATIC);
|
||||
backend()->onAcquireBuffer(unpackQK.get(), Backend::STATIC);
|
||||
|
@ -254,12 +265,12 @@ ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std:
|
|||
}
|
||||
|
||||
std::function<void(int)> mCompute = [=](int tId) {
|
||||
auto pack_q = mPackQ->host<char>() + tId * UP_DIV(seq_len, eP) * mHeadDim * eP * bytes;
|
||||
auto pack_q = mPackQ->host<char>() + tId * UP_DIV(seq_len, eP) * ROUND_UP(mHeadDim, lP) * eP * bytes;
|
||||
auto pack_qk = packQK->host<char>() + tId * UP_DIV(kv_seq_len, unit) * seq_len * unit * bytes;
|
||||
char * sum_q = nullptr;
|
||||
auto unpack_qk = unpackQK->host<float>() + tId * seq_len * kv_seq_len;
|
||||
auto softmax_qk = softmMaxQ->host<float>() + tId * seq_len * kv_seq_len;
|
||||
auto new_pack_qk = newPackQK->host<char>() + tId * UP_DIV(seq_len, eP) * kv_seq_len * eP * bytes;
|
||||
auto new_pack_qk = newPackQK->host<char>() + tId * UP_DIV(seq_len, eP) * ROUND_UP(kv_seq_len, lP) * eP * bytes;
|
||||
auto pack_qkv = mPackQKV->host<char>() + tId * UP_DIV(mHeadDim, unit) * seq_len * unit * bytes;
|
||||
auto QxK = quant_key ? core->MNNPackedMatMul_int8 : core->MNNPackedMatMul;
|
||||
auto QxK_remain = quant_key ? core->MNNPackedMatMulRemain_int8 : core->MNNPackedMatMulRemain;
|
||||
|
@ -274,7 +285,7 @@ ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std:
|
|||
char * scale_addr = mKVCacheManager->addrOfScale(kv_h);
|
||||
char * zero_point_addr = mKVCacheManager->addrOfZeroPoint(kv_h);
|
||||
char * key_sum_addr = mKVCacheManager->addrOfKeySum(kv_h);
|
||||
char * value_addr = quant_value ? (dequantV->host<char>() + kv_h * UP_DIV(mHeadDim, hP) * kv_seq_len * hP * bytes) : mKVCacheManager->addrOfValue(kv_h);
|
||||
char * value_addr = quant_value ? (dequantV->host<char>() + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(kv_seq_len, lP) * hP * bytes) : mKVCacheManager->addrOfValue(kv_h);
|
||||
if (bytes == 2) {
|
||||
pack_query<FLOAT16_T>(query, pack_q, sum_q, seq_len, h, q_scale);
|
||||
} else {
|
||||
|
@ -327,33 +338,36 @@ ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std:
|
|||
else {
|
||||
int loop_e = seq_len / eP;
|
||||
int remain = seq_len % eP;
|
||||
size_t shapeParameters[7] = {(size_t)eP * bytes, (size_t)mHeadDim, (size_t)kv_seq_len, (size_t)seq_len * unit * bytes, 0, 0, 0};
|
||||
auto qStride0 = ROUND_UP(mHeadDim, lP) * eP * bytes;
|
||||
size_t shapeParameters[7] = {(size_t)eP * bytes, ROUND_UP((size_t)mHeadDim, lP), (size_t)kv_seq_len, (size_t)seq_len * unit * bytes, 0, 0, 0};
|
||||
for (int i = 0 ; i < loop_e; i++) {
|
||||
QxK((float*)(pack_qk + (i * eP * unit) * bytes), (float*)(pack_q + (i * mHeadDim * eP) * bytes), (float*)key_addr, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr);
|
||||
QxK((float*)(pack_qk + (i * eP * unit) * bytes), (float*)(pack_q + i * qStride0), (float*)key_addr, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr);
|
||||
}
|
||||
QxK_remain((float*)(pack_qk + (loop_e * eP * unit) * bytes), (float*)(pack_q + (loop_e * mHeadDim * eP) * bytes), (float*)key_addr, remain, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr);
|
||||
QxK_remain((float*)(pack_qk + (loop_e * eP * unit) * bytes), (float*)(pack_q + loop_e * qStride0), (float*)key_addr, remain, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr);
|
||||
}
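Note on the shapeParameters array used by the two packed matmuls in this function: the field meanings below are inferred from how this diff fills them (here and in the deconvolution hunks above), not from kernel documentation, so treat the helper and its names as an illustrative sketch.

#include <array>
#include <cstddef>

// Hypothetical helper mirroring how this hunk populates the 7-entry parameter array:
// [0] bytes per e-row of the packed A tile, [1] L rounded up to lP, [2] H,
// [3] byte stride between output C tiles, [4] unused here, [5] extra B stride in bytes
// (skips the (max_len - kv_seq_len) padding rows of the cached V), [6] unused here.
static inline std::array<size_t, 7> makeGemmShape(size_t eP, size_t bytes, size_t L, size_t lP,
                                                  size_t H, size_t cStrideBytes, size_t bExtraBytes) {
    const size_t lAligned = (L + lP - 1) / lP * lP;
    return {{eP * bytes, lAligned, H, cStrideBytes, 0, bExtraBytes, 0}};
}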
|
||||
// qk: [kv_seq_len/unit, seq_len, unit] -> [seq_len, kv_seq_len] -> [seq_len/eP, ROUND_UP(kv_seq_len, lP), eP]
|
||||
if(bytes == 2) {
|
||||
unpack_QK<FLOAT16_T>(unpack_qk, pack_qk, seq_len, kv_seq_len);
|
||||
mask_QK<FLOAT16_T>(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask);
|
||||
softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len);
|
||||
pack_QK<FLOAT16_T>(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP);
|
||||
pack_QK<FLOAT16_T>(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP, lP, bytes);
|
||||
} else {
|
||||
unpack_QK<float>(unpack_qk, pack_qk, seq_len, kv_seq_len);
|
||||
mask_QK<float>(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits<float>::lowest(), mask);
|
||||
softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len);
|
||||
pack_QK<float>(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP);
|
||||
pack_QK<float>(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP, lP, bytes);
|
||||
}
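For orientation, a plain-C++ sketch of the repack step the comment above describes: [seq_len, kv_seq_len] into a [UP_DIV(seq_len, eP), ROUND_UP(kv_seq_len, lP), eP] buffer with zeroed padding. This is one plausible flattening that matches the declared newPackQK extent; the real pack_QK (which now also receives lP and bytes) may order eP and lP differently inside each tile.

#include <algorithm>
#include <cstddef>

static void packQKSketch(float* dst, const float* src, int seqLen, int kvSeqLen, int eP, int lP) {
    const int kvAligned = (kvSeqLen + lP - 1) / lP * lP;
    const int eBlocks   = (seqLen + eP - 1) / eP;
    std::fill(dst, dst + (size_t)eBlocks * kvAligned * eP, 0.0f); // zero the eP and lP tails
    for (int s = 0; s < seqLen; ++s) {
        const int eOut = s / eP, eIn = s % eP;
        for (int k = 0; k < kvSeqLen; ++k) {
            dst[((size_t)eOut * kvAligned + k) * eP + eIn] = src[(size_t)s * kvSeqLen + k];
        }
    }
}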
|
||||
// qk @ v
|
||||
size_t shapeParameters[7] = {(size_t)eP * bytes, (size_t)kv_seq_len, (size_t)mHeadDim, (size_t)seq_len * unit * bytes, 0, 0, 0};
|
||||
shapeParameters[5] = quant_value ? 0 : (max_len - kv_seq_len) * hP * bytes;
|
||||
size_t shapeParameters[7] = {(size_t)eP * bytes, ROUND_UP((size_t)kv_seq_len, lP), (size_t)mHeadDim, (size_t)seq_len * unit * bytes, 0, 0, 0};
|
||||
size_t bExtraStride = (UP_DIV(max_len, lP) - UP_DIV(kv_seq_len, lP)) * hP * lP * bytes;
|
||||
shapeParameters[5] = quant_value ? 0 : bExtraStride;
|
||||
int loop_e = seq_len / eP;
|
||||
int remain = seq_len % eP;
|
||||
auto qkStride0 = ROUND_UP(kv_seq_len, lP) * eP * bytes;
|
||||
for (int i = 0 ; i < loop_e; i++) {
|
||||
core->MNNPackedMatMul((float*)(pack_qkv + (i * eP * unit) * bytes), (float*)(new_pack_qk + (i * kv_seq_len * eP) * bytes), (float*)value_addr, shapeParameters, nullptr, nullptr, nullptr, nullptr);
|
||||
core->MNNPackedMatMul((float*)(pack_qkv + (i * eP * unit) * bytes), (float*)(new_pack_qk + i * qkStride0), (float*)value_addr, shapeParameters, nullptr, nullptr, nullptr, nullptr);
|
||||
}
|
||||
core->MNNPackedMatMulRemain((float*)(pack_qkv + (loop_e * eP * unit) * bytes), (float*)(new_pack_qk + (loop_e * kv_seq_len * eP) * bytes), (float*)value_addr, remain, shapeParameters, nullptr, nullptr, nullptr, nullptr);
|
||||
core->MNNPackedMatMulRemain((float*)(pack_qkv + (loop_e * eP * unit) * bytes), (float*)(new_pack_qk + loop_e * qkStride0), (float*)value_addr, remain, shapeParameters, nullptr, nullptr, nullptr, nullptr);
|
||||
// unpack: [head_dim/unit, seq_len, unit] -> [seq_len, num_head, head_dim]
|
||||
auto dst_ptr = outputs[0]->host<char>() + h * mHeadDim * bytes;
|
||||
if (bytes == 2) {
|
||||
|
@ -393,7 +407,7 @@ bool CPUAttention::onClone(Backend* bn, const Op* op, Execution** dst) {
|
|||
}
|
||||
|
||||
CPUAttention::CPUAttention(Backend *backend, bool kv_cache) : Execution(backend), mKVCache(kv_cache) {
|
||||
mMeta = (KVMeta*)(backend->getRuntime()->pMeta);
|
||||
mMeta = (KVMeta*)(backend->getMetaPtr());
|
||||
mPackQ.reset(Tensor::createDevice<float>({1, 1, 1, 1}));
|
||||
mPackQKV.reset(Tensor::createDevice<float>({1, 1, 1, 1}));
|
||||
MNN::KVCacheManager::KVCacheConfig kvconfig;
|
||||
|
|
|
@ -102,7 +102,7 @@ void CPURuntime::_bindCPUCore() const {
|
|||
}
|
||||
#endif
|
||||
#ifdef MNN_USE_THREAD_POOL
|
||||
if(mThreadPool) {
|
||||
if (nullptr != mThreadPool) {
|
||||
mThreadPool->active();
|
||||
mThreadPool->enqueue(std::make_pair([&](int i) {
|
||||
MNNSetSchedAffinity(lockCPUIndexes[i].first, lockCPUIndexes[i].second);
|
||||
|
@ -113,24 +113,12 @@ void CPURuntime::_bindCPUCore() const {
|
|||
#endif
|
||||
}
|
||||
|
||||
void CPURuntime::_resetThreadPool() const{
|
||||
void CPURuntime::_resetThreadPool() const {
|
||||
mThreadNumber = std::max(1, mThreadNumber);
|
||||
mThreadNumber = std::min(mThreadNumber, MAX_THREAD_NUMBER);
|
||||
#ifdef MNN_USE_THREAD_POOL
|
||||
if (mThreadPool) {
|
||||
mThreadPool->releaseWorkIndex(mTaskIndex);
|
||||
}
|
||||
if (mThreadNumber > 1) {
|
||||
mThreadNumber = ALIMIN(ThreadPool::init(mThreadNumber, mCpuMask, mThreadPool), mThreadNumber);
|
||||
if (mThreadPool) {
|
||||
mTaskIndex = mThreadPool->acquireWorkIndex();
|
||||
}
|
||||
if (-1 == mTaskIndex) {
|
||||
MNN_ERROR("The ThreadPool has been used to MNN_THREAD_POOL_MAX_TASKS, can't use thread pool\n");
|
||||
mThreadNumber = 1;
|
||||
}
|
||||
} else {
|
||||
mTaskIndex = -1;
|
||||
}
|
||||
#endif
|
||||
// Reset tid to rebind cpu if necessary
|
||||
|
@ -256,11 +244,7 @@ CPURuntime::CPURuntime(const Backend::Info& info) {
|
|||
}
|
||||
|
||||
CPURuntime::~CPURuntime() {
|
||||
#ifdef MNN_USE_THREAD_POOL
|
||||
if(mThreadPool) {
|
||||
mThreadPool->releaseWorkIndex(mTaskIndex);
|
||||
}
|
||||
#endif
|
||||
// Do nothing
|
||||
}
|
||||
float CPURuntime::onGetMemoryInMB() {
|
||||
auto staticMemoryInMB = mStaticAllocator->totalSize() / 1024.0f / 1024.0f;
|
||||
|
@ -336,11 +320,17 @@ Backend* CPURuntime::onCreate(const BackendConfig* config, Backend* origin) cons
|
|||
MNN_PRINT("cpu backend was created by runtime:%p\n", this);
|
||||
#endif
|
||||
CPUBackend* res = nullptr;
|
||||
auto initThreadNumber = hint().initThreadNumber;
|
||||
do {
|
||||
#ifdef MNN_USE_ARMV82
|
||||
auto core = MNNGetCoreFunctions();
|
||||
if (core->supportFp16arith && precision == BackendConfig::Precision_Low) {
|
||||
res = new Arm82Backend(this, memory);
|
||||
if (hint().useArmSme2Cores && res->threadNumber() <= 2 && core->supportSME2 && res->functions()->sme2MatmulRelatedFuncions.Int8GemmKernel) {
|
||||
res->mRelatedFunctions = &(res->functions()->sme2MatmulRelatedFuncions);
|
||||
} else {
|
||||
res->mRelatedFunctions = &(res->functions()->backendMatmulRelatedFunctions);
|
||||
}
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
|
@ -353,7 +343,7 @@ Backend* CPURuntime::onCreate(const BackendConfig* config, Backend* origin) cons
|
|||
#endif
|
||||
if (flags == MNN_CPU_USE_DEFAULT_BACKEND) {
|
||||
// By default, don't use multi-thread init
|
||||
res = new CPUBackend(this, precision, memory, MNN_FORWARD_CPU, 0);
|
||||
res = new CPUBackend(this, precision, memory, MNN_FORWARD_CPU);
|
||||
break;
|
||||
}
|
||||
#ifdef MNN_USE_SSE
|
||||
|
@ -365,6 +355,7 @@ Backend* CPURuntime::onCreate(const BackendConfig* config, Backend* origin) cons
|
|||
res = new CPUBackend(this, precision, memory, MNN_FORWARD_CPU, flags);
|
||||
} while (false);
|
||||
mSharedDmaInfo = nullptr;
|
||||
res->setMetaPtr(pMeta);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
@ -402,9 +393,13 @@ void CPURuntime::onGabageCollect(int level) {
|
|||
|
||||
void CPURuntime::onConcurrencyBegin() const {
|
||||
#ifdef MNN_USE_THREAD_POOL
|
||||
if (mTaskIndex < 0 && nullptr != mThreadPool) {
|
||||
mTaskIndex = mThreadPool->acquireWorkIndex();
|
||||
}
|
||||
if (mTaskIndex >= 0) {
|
||||
if (mThreadOpen == 0 && mThreadPool) {
|
||||
// mThreadOpen 0 -> 1, open ThreadPool
|
||||
// mThreadOpen 0 -> 1, active ThreadPool
|
||||
// For next onConcurrencyBegin, will only add mThreadOpen
|
||||
if (0 == mThreadOpen) {
|
||||
mThreadPool->active();
|
||||
}
|
||||
mThreadOpen++;
|
||||
|
@ -421,11 +416,12 @@ void CPURuntime::onConcurrencyBegin() const {
|
|||
void CPURuntime::onConcurrencyEnd() const {
|
||||
#ifdef MNN_USE_THREAD_POOL
|
||||
if (mTaskIndex >= 0) {
|
||||
MNN_ASSERT(mThreadOpen > 0);
|
||||
mThreadOpen--;
|
||||
mThreadOpen = mThreadOpen < 0 ? 0 : mThreadOpen;
|
||||
if (0 == mThreadOpen && mThreadPool) {
|
||||
if (0 == mThreadOpen) {
|
||||
mThreadPool->releaseWorkIndex(mTaskIndex);
|
||||
mThreadPool->deactive();
|
||||
mTaskIndex = -1;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
@ -460,10 +456,14 @@ CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode p
|
|||
#endif
|
||||
mMemory = memory;
|
||||
mRuntime = const_cast<CPURuntime*>(runtime);
|
||||
auto core = MNNGetCoreFunctions();
|
||||
mThreadNumber = mRuntime->mThreadNumber;
|
||||
#ifdef MNN_USE_THREAD_POOL
|
||||
mThreadPool = mRuntime->mThreadPool;
|
||||
#endif
|
||||
if (mRuntime->hint().useArmSme2Cores && core->supportSME2 && core->sme2MatmulRelatedFuncions.Int8GemmKernel) {
|
||||
mThreadNumber = 2;
|
||||
mRelatedFunctions = &core->sme2MatmulRelatedFuncions;
|
||||
} else {
|
||||
mRelatedFunctions = &core->backendMatmulRelatedFunctions;
|
||||
}
|
||||
// Compute Group Rate
|
||||
do {
|
||||
if (mThreadNumber <= 1 || mRuntime->mPower == BackendConfig::Power_Low) {
|
||||
|
|
|
@ -60,10 +60,12 @@ private:
|
|||
mutable int mThreadNumber;
|
||||
mutable std::vector<int> mCpuIds;
|
||||
mutable unsigned long mCpuMask;
|
||||
#ifdef MNN_USE_THREAD_POOL
|
||||
mutable ThreadPool* mThreadPool = nullptr;
|
||||
#endif
|
||||
#ifdef MNN_USE_THREAD_POOL
|
||||
mutable int mTaskIndex = -1;
|
||||
mutable int mThreadOpen = 0;
|
||||
mutable ThreadPool* mThreadPool = nullptr;
|
||||
#endif
|
||||
BackendConfig::MemoryMode mMemory;
|
||||
BackendConfig::PowerMode mPower;
|
||||
|
@ -82,6 +84,7 @@ private:
|
|||
};
|
||||
struct CoreFunctions;
|
||||
struct CoreInt8Functions;
|
||||
struct MatmulRelatedFunctions;
|
||||
|
||||
class CPUResizeCache;
|
||||
class CPUMemObj : public Backend::MemObj {
|
||||
|
@ -134,6 +137,10 @@ public:
|
|||
const CoreFunctions* functions() const {
|
||||
return mCoreFunctions;
|
||||
}
|
||||
|
||||
const MatmulRelatedFunctions* int8GemmFunctions() const {
|
||||
return mRelatedFunctions;
|
||||
}
|
||||
// Return element size for Tensor, considering pack
|
||||
size_t getTensorSize(const Tensor* tensor, bool multiBytes = false) const;
|
||||
const CoreInt8Functions* int8Functions() const {
|
||||
|
@ -183,12 +190,10 @@ protected:
|
|||
MemObj* allocBuffer(size_t size, Tensor* dest, StorageType storageType);
|
||||
CoreFunctions* mCoreFunctions;
|
||||
CoreInt8Functions* mInt8CoreFunctions;
|
||||
const MatmulRelatedFunctions* mRelatedFunctions;
|
||||
private:
|
||||
mutable std::shared_ptr<WorkerThread> mInitWorkQueue;
|
||||
int mThreadNumber;
|
||||
#ifdef MNN_USE_THREAD_POOL
|
||||
ThreadPool* mThreadPool = nullptr;
|
||||
#endif
|
||||
mutable int mThreadNumber = 1;
|
||||
std::vector<std::pair<float, int>> mGroupWithComputeRate;
|
||||
float mComputeI = 0.f;
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ static void _transformWeight(const uint8_t* tempWeight, uint8_t* dest, int outpu
|
|||
core->MNNPackCUnit((float*)dst, (const float*)src, fw*fh, outputCount, offset);
|
||||
}
|
||||
//printf("%d - %d - %d - %d\n", outputCount, srcCount, fh, fw);
|
||||
core->MNNPackForMatMul_B((float*)dest, (const float*)cache, outputC4 * fw * fh * core->pack, srcCount, false);
|
||||
core->MNNPackForMatMul_B((float*)dest, (const float*)cache, outputC4 * fw * fh * core->pack, 1, srcCount, false);
|
||||
}
|
||||
std::shared_ptr<DeconvolutionResource> CPUDeconvolution::makeResource(int srcCount, const Op *convOp, Backend* backend, bool dynamic) {
|
||||
auto core = static_cast<CPUBackend*>(backend)->functions();
|
||||
|
@ -253,7 +253,7 @@ ErrorCode CPUDeconvolutionOrigin::onResize(const std::vector<Tensor*>& inputs, c
|
|||
}
|
||||
::memset(tempOutPtr, 0, outputSize);
|
||||
|
||||
int l = mSrcCount;
|
||||
int l = ROUND_UP(mSrcCount, lP);
|
||||
int h = kernelCount * core->pack;
|
||||
auto weightPtr = weightTensor->host<uint8_t>();
|
||||
for (int index=tId; index < tileCount; index+=threadNumber) {
|
||||
|
|
|
@ -101,7 +101,7 @@ ErrorCode CPUMatMul::onResize(const std::vector<Tensor*>& inputs, const std::vec
|
|||
}
|
||||
|
||||
mPreFunctions.emplace_back(std::make_pair([BTPtrAlloc, l, h, this, core] (int tId, const float* APtr, const float* BPtr, const float* Bias, float* C) {
|
||||
core->MNNPackForMatMul_B((float*)BTPtrAlloc.ptr(), BPtr, h, l, mTransposeB);
|
||||
core->MNNPackForMatMul_B((float*)BTPtrAlloc.ptr(), BPtr, h, 1, l, mTransposeB);
|
||||
} , 1));
|
||||
bool useBias = false;
|
||||
MemChunk bdestAlloc;
|
||||
|
@ -194,7 +194,7 @@ void CPUMatMul::execute(const float* APtr, const float* BPtr, float* CPtr, const
|
|||
auto TC = mTempC.ptr() + tId * eP * hC4 * core->pack * core->bytes;
|
||||
size_t parameters[6];
|
||||
parameters[0] = eP * core->bytes;
|
||||
parameters[1] = mL;
|
||||
parameters[1] = lAlign;
|
||||
parameters[2] = mH;
|
||||
parameters[3] = eP * core->pack * core->bytes;
|
||||
parameters[4] = 0;
|
||||
|
@ -251,7 +251,7 @@ void CPUMatMul::execute(const float* APtr, const float* BPtr, float* CPtr, const
|
|||
int yy = lC;
|
||||
for (int x=0; x<xC; ++x) {
|
||||
::memset(TA + (yy * eP * lP + x * lP) * core->bytes, 0, lP * core->bytes);
|
||||
::memcpy(TA + (yy * eP * lP + x * lP) * core->bytes, (uint8_t*)APtr + ((x+xStart)*mL+yy*lP)*core->bytes, xC * core->bytes);
|
||||
::memcpy(TA + (yy * eP * lP + x * lP) * core->bytes, (uint8_t*)APtr + ((x+xStart)*mL+yy*lP)*core->bytes, lR * core->bytes);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -68,7 +68,7 @@ void CPURNNSequenceGRU::runRNNStep(const uint8_t* input, const int inputLength,
|
|||
mulFunction(resetGatePtr, rtPtr, hiddenStatePtr, numUnits, -1);
|
||||
// deal with recurrent bias and linear_before_reset parameter
|
||||
auto recurrentBiasAddedPtr = inputAndStatePtr + (inputLength + numUnits) * bytes;
|
||||
auto recurrentHiddenBiasPtr = recurrentBias->host<float>() + 2 * numUnits * bytes;
|
||||
auto recurrentHiddenBiasPtr = (float*)(recurrentBias->host<uint8_t>() + 2 * numUnits * bytes);
|
||||
addFunction(recurrentBiasAddedPtr, recurrentHiddenBiasPtr, candidateBias->host<float>(), numUnits, -1);
|
||||
mMatMulI2U->execute(inputAndState->host<float>(), candidateWeight->host<float>(), resetHt->host<float>(), nullptr);
|
||||
// reuse r_t memory as h_t'
|
||||
|
|
|
@ -20,6 +20,8 @@
|
|||
|
||||
using Vec4 = MNN::Math::Vec<float, 4>;
|
||||
namespace MNN {
|
||||
typedef void (*FP16ToFP32)(const int16_t* src, float* dst, size_t size);
|
||||
typedef void (*FP32ToFP16)(const float* src, int16_t* dst, size_t size);
|
||||
struct ReduceInfo {
|
||||
int reduceMask[3] = {0, 0, 0};
|
||||
int reduceNum = 0;
|
||||
|
@ -434,7 +436,7 @@ static void _zero(const Tensor::InsideDescribe::Region& slice, int bytes, uint8_
|
|||
}
|
||||
}
|
||||
}
|
||||
static bool _reduceblit(const Tensor::InsideDescribe::Region& slice, int bytes, const uint8_t* srcPtr, uint8_t* dstPtr) {
|
||||
static bool _reduceblit(const Tensor::InsideDescribe::Region& slice, int bytes, const uint8_t* srcPtr, uint8_t* dstPtr, FP16ToFP32 funcFp16ToFp32 = nullptr, FP32ToFP16 funcFp32ToFp16 = nullptr) {
|
||||
ReduceInfo reduceInfo;
|
||||
reduceInfo.compute(slice);
|
||||
auto normalIndex = reduceInfo.normalIndex;
|
||||
|
@ -443,21 +445,31 @@ static bool _reduceblit(const Tensor::InsideDescribe::Region& slice, int bytes,
|
|||
case 3:
|
||||
{
|
||||
float summer = 0.0f;
|
||||
float fp32Buffer[1];
|
||||
for (int z=0; z<slice.size[0]; ++z) {
|
||||
auto srcZ = srcPtr + z * slice.src.stride[0] * bytes;
|
||||
for (int y=0; y<slice.size[1]; ++y) {
|
||||
auto srcY = srcZ + y * slice.src.stride[1] * bytes;
|
||||
auto S = (float*)srcY;
|
||||
for (int x=0; x<slice.size[2]; ++x) {
|
||||
auto S = (float*)srcY;
|
||||
if (bytes == 2) {
|
||||
funcFp16ToFp32((int16_t*)srcY, fp32Buffer, 1);
|
||||
S = fp32Buffer;
|
||||
}
|
||||
summer += S[slice.src.stride[2] * x];
|
||||
}
|
||||
}
|
||||
}
|
||||
((float*)dstPtr)[0] = summer;
|
||||
if (bytes == 4) {
|
||||
((float*)dstPtr)[0] = summer;
|
||||
} else {
|
||||
funcFp32ToFp16(&summer, (int16_t*)dstPtr, 1);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
case 2:
|
||||
{
|
||||
float fp32Buffer[1];
|
||||
int sizeZ = slice.size[normalIndex[0]];
|
||||
int srcStrideZ = slice.src.stride[normalIndex[0]];
|
||||
int dstStrideZ = slice.dst.stride[normalIndex[0]];
|
||||
|
@ -473,12 +485,20 @@ static bool _reduceblit(const Tensor::InsideDescribe::Region& slice, int bytes,
|
|||
auto dstZ = dstPtr + z * dstStrideZ * bytes;
|
||||
for (int y=0; y<sizeY; ++y) {
|
||||
auto srcY = srcZ + y * srcStrideY * bytes;
|
||||
auto S = (float*)srcY;
|
||||
for (int x=0; x<sizeX; ++x) {
|
||||
auto S = (float*)srcY;
|
||||
if (bytes == 2) {
|
||||
funcFp16ToFp32((int16_t*)srcY, fp32Buffer, 1);
|
||||
S = fp32Buffer;
|
||||
}
|
||||
summer += S[srcStrideX * x];
|
||||
}
|
||||
}
|
||||
((float*)dstZ)[0] = summer;
|
||||
if (bytes == 4) {
|
||||
((float*)dstZ)[0] = summer;
|
||||
} else {
|
||||
funcFp32ToFp16(&summer, (int16_t*)dstZ, 1);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -493,6 +513,7 @@ static bool _reduceblit(const Tensor::InsideDescribe::Region& slice, int bytes,
|
|||
int sizeX = slice.size[reduceIndex[0]];
|
||||
int srcStrideX = slice.src.stride[reduceIndex[0]];
|
||||
int dstStrideX = slice.dst.stride[reduceIndex[0]];
|
||||
std::vector<float> fp32Buffer(sizeX);
|
||||
for (int z=0; z<sizeZ; ++z) {
|
||||
auto srcZ = srcPtr + z * srcStrideZ * bytes;
|
||||
auto dstZ = dstPtr + z * dstStrideZ * bytes;
|
||||
|
@ -500,11 +521,20 @@ static bool _reduceblit(const Tensor::InsideDescribe::Region& slice, int bytes,
|
|||
float summer = 0.0f;
|
||||
auto srcY = srcZ + y * srcStrideY * bytes;
|
||||
auto dstY = dstZ + y * dstStrideY * bytes;
|
||||
auto S = (float*)srcY;
|
||||
float* S = (float*)srcY;
|
||||
float* D = (float*)dstY;
|
||||
if (bytes == 2) {
|
||||
funcFp16ToFp32((int16_t*)srcY, fp32Buffer.data(), sizeX);
|
||||
S = fp32Buffer.data();
|
||||
}
|
||||
for (int x=0; x<sizeX; ++x) {
|
||||
summer += S[srcStrideX * x];
|
||||
}
|
||||
((float*)dstY)[0] = summer;
|
||||
if (bytes == 2) {
|
||||
funcFp32ToFp16(&summer, (int16_t*)dstY, 1);
|
||||
} else {
|
||||
D[0] = summer;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
@ -515,10 +545,10 @@ static bool _reduceblit(const Tensor::InsideDescribe::Region& slice, int bytes,
|
|||
return false;
|
||||
}
|
||||
|
||||
static void _blit(const Tensor::InsideDescribe::Region& slice, int bytes, const uint8_t* srcPtr, uint8_t* dstPtr, bool hasReduce) {
|
||||
static void _blit(const Tensor::InsideDescribe::Region& slice, int bytes, const uint8_t* srcPtr, uint8_t* dstPtr, bool hasReduce, FP16ToFP32 funcFp16ToFp32 = nullptr, FP32ToFP16 funcFp32ToFp16 = nullptr) {
|
||||
auto proc = _selectUnitProc(bytes, slice.src.stride[2], slice.dst.stride[2]);
|
||||
if (hasReduce) {
|
||||
if (_reduceblit(slice, bytes, srcPtr, dstPtr)) {
|
||||
if (_reduceblit(slice, bytes, srcPtr, dstPtr, funcFp16ToFp32, funcFp32ToFp16)) {
|
||||
return;
|
||||
}
|
||||
}
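The new funcFp16ToFp32 / funcFp32ToFp16 hooks let the reduce path accumulate in fp32 while the tensor stays in fp16 storage. As a reference for the semantics such a hook must provide, here is a plain scalar sketch (not MNN's optimized MNNLowpToFp32); it handles normals, subnormals and inf/NaN:

#include <cstddef>
#include <cstdint>
#include <cstring>

static void fp16ToFp32Ref(const int16_t* src, float* dst, size_t size) {
    for (size_t i = 0; i < size; ++i) {
        const uint16_t h    = (uint16_t)src[i];
        const uint32_t sign = (uint32_t)(h & 0x8000u) << 16;
        const uint32_t exp  = (h >> 10) & 0x1Fu;
        uint32_t man        = h & 0x3FFu;
        uint32_t bits;
        if (exp == 0) {
            if (man == 0) {
                bits = sign;                                   // signed zero
            } else {                                           // subnormal: renormalize
                int shift = 0;
                while ((man & 0x400u) == 0) { man <<= 1; ++shift; }
                man &= 0x3FFu;
                bits = sign | ((uint32_t)(113 - shift) << 23) | (man << 13);
            }
        } else if (exp == 0x1Fu) {
            bits = sign | 0x7F800000u | (man << 13);           // inf / NaN
        } else {
            bits = sign | ((exp + 112u) << 23) | (man << 13);  // normal: rebias 15 -> 127
        }
        float f;
        std::memcpy(&f, &bits, sizeof(f));
        dst[i] = f;
    }
}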
|
||||
|
@ -607,10 +637,10 @@ ErrorCode CPURaster::onExecute(const std::vector<Tensor *> &____inputs, const st
|
|||
auto sourceFormat = TensorUtils::getDescribe(realInput)->dimensionFormat;
|
||||
auto destFormat = TensorUtils::getDescribe(output)->dimensionFormat;
|
||||
auto channelC4 = UP_DIV(srcChannel, core->pack);
|
||||
int batchStrideC4 = channelC4 * core->pack * srcArea * bytes;
|
||||
int batchStride = srcChannel * srcArea * bytes;
|
||||
int inputBatchStride = batchStride;
|
||||
int outputBatchStride = batchStride;
|
||||
auto batchStrideC4 = channelC4 * core->pack * srcArea * bytes;
|
||||
auto batchStride = srcChannel * srcArea * bytes;
|
||||
auto inputBatchStride = batchStride;
|
||||
auto outputBatchStride = batchStride;
|
||||
if (MNN_DATA_FORMAT_NC4HW4 == sourceFormat) {
|
||||
if (realInput->dimensions() <= 1) {
|
||||
::memcpy(output->host<uint8_t>(), realInput->host<uint8_t>(), realInput->elementSize() * bytes);
|
||||
|
@ -670,7 +700,7 @@ ErrorCode CPURaster::onExecute(const std::vector<Tensor *> &____inputs, const st
|
|||
}
|
||||
auto srcPtr = iter.first->host<uint8_t>() + slice.src.offset * bytes;
|
||||
auto dstPtr = (uint8_t*)mOutputPtr + slice.dst.offset * bytes;
|
||||
_blit(slice, bytes, srcPtr, dstPtr, mHasReduce);
|
||||
_blit(slice, bytes, srcPtr, dstPtr, mHasReduce, core->MNNLowpToFp32, core->MNNFp32ToLowp);
|
||||
}
|
||||
}
|
||||
MNN_CONCURRENCY_END();
|
||||
|
@ -1088,11 +1118,11 @@ public:
|
|||
break;
|
||||
case BinaryOpOperation_MUL:
|
||||
for (int z=0; z<sizeZ; ++z) {
|
||||
auto srcZ = srcF + z * dstStride[0];
|
||||
auto dstZ = dstF + z * outputStride[0];
|
||||
auto srcZ = srcF + z * outputStride[0];
|
||||
auto dstZ = dstF + z * dstStride[0];
|
||||
for (int y=0; y<sizeY; ++y) {
|
||||
auto srcY = srcZ + z * dstStride[1];
|
||||
auto dstY = dstZ + z * outputStride[1];
|
||||
auto srcY = srcZ + y * outputStride[1];
|
||||
auto dstY = dstZ + y * dstStride[1];
|
||||
for (int x=0; x<sizeX; ++x) {
|
||||
auto dstOffset = x * dstStride[2];
|
||||
dstY[dstOffset] = dstY[dstOffset] * srcY[x];
|
||||
|
@ -1102,16 +1132,14 @@ public:
|
|||
break;
|
||||
case BinaryOpOperation_SUB:
|
||||
for (int z=0; z<sizeZ; ++z) {
|
||||
auto srcZ = srcF + z * dstStride[0];
|
||||
auto dstZ = dstF + z * outputStride[0];
|
||||
auto srcZ = srcF + z * outputStride[0];
|
||||
auto dstZ = dstF + z * dstStride[0];
|
||||
for (int y=0; y<sizeY; ++y) {
|
||||
auto srcY = srcZ + z * dstStride[1];
|
||||
auto dstY = dstZ + z * outputStride[1];
|
||||
auto srcY = srcZ + y * outputStride[1];
|
||||
auto dstY = dstZ + y * dstStride[1];
|
||||
for (int x=0; x<sizeX; ++x) {
|
||||
auto dstOffset = x * dstStride[2];
|
||||
auto D = dstY[dstOffset];
|
||||
auto S = srcY[x];
|
||||
dstY[dstOffset] = D - S;
|
||||
dstY[dstOffset] = dstY[dstOffset] - srcY[x];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -119,54 +119,60 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) {
|
|||
mPastKey.reset(new_key);
|
||||
}
|
||||
else if (mConfig.mQuantKey) {
|
||||
auto new_key = Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP});
|
||||
auto new_key = Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(mMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP});
|
||||
mBackend->onAcquireBuffer(new_key, Backend::STATIC);
|
||||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
memcpy(
|
||||
new_key->host<char>() + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP,
|
||||
mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP,
|
||||
UP_DIV(oldMaxLength, hP) * mHeadDim * hP
|
||||
new_key->host<char>() + h * new_key->stride(0),
|
||||
mPastKey->host<char>() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP),
|
||||
ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP)
|
||||
);
|
||||
}
|
||||
mPastKey.reset(new_key);
|
||||
}
|
||||
else {
|
||||
auto new_key = Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP});
|
||||
auto new_key = Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP});
|
||||
mBackend->onAcquireBuffer(new_key, Backend::STATIC);
|
||||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
memcpy(
|
||||
new_key->host<char>() + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes,
|
||||
mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes,
|
||||
UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes
|
||||
new_key->host<char>() + h * new_key->stride(0) * mBytes,
|
||||
mPastKey->host<char>() + h * ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes,
|
||||
ROUND_UP(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes
|
||||
);
|
||||
if ((new_key->stride(0) - mPastKey->stride(0)) > 0) {
|
||||
memset(new_key->host<char>() + h * new_key->stride(0) * mBytes + mPastKey->stride(0) * mBytes, 0, (new_key->stride(0) - mPastKey->stride(0)) * mBytes);
|
||||
}
|
||||
}
|
||||
mPastKey.reset(new_key);
|
||||
}
|
||||
/*=================================== Value ===================================*/
|
||||
if (mConfig.mQuantValue) {
|
||||
auto new_value = Tensor::createDevice<fp8_t>({mKvNumHead, UP_DIV(mHeadDim, hP), mMaxLength, hP});
|
||||
auto new_value = Tensor::createDevice<fp8_t>({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(mMaxLength, lP), hP, lP});
|
||||
mBackend->onAcquireBuffer(new_value, Backend::STATIC);
|
||||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
|
||||
memcpy(
|
||||
new_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP,
|
||||
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP,
|
||||
oldMaxLength * hP
|
||||
new_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP,
|
||||
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP,
|
||||
ROUND_UP(oldMaxLength, lP) * hP
|
||||
);
|
||||
}
|
||||
}
|
||||
mPastValue.reset(new_value);
|
||||
}
|
||||
else {
|
||||
auto new_value = Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), mMaxLength, hP});
|
||||
auto new_value = Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(mMaxLength, lP), hP, lP});
|
||||
mBackend->onAcquireBuffer(new_value, Backend::STATIC);
|
||||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
|
||||
memcpy(
|
||||
new_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP * mBytes,
|
||||
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP * mBytes,
|
||||
oldMaxLength * hP * mBytes
|
||||
new_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes,
|
||||
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes,
|
||||
ROUND_UP(oldMaxLength, lP) * hP * mBytes
|
||||
);
|
||||
if ((new_value->stride(1) - mPastValue->stride(1)) > 0) {
|
||||
memset(new_value->host<char>() + (h * new_value->stride(0) + i * new_value->stride(1)) * mBytes + mPastValue->stride(1) * mBytes, 0, (new_value->stride(1) - mPastValue->stride(1)) * mBytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
mPastValue.reset(new_value);
|
||||
|
@ -193,20 +199,23 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) {
|
|||
if (mConfig.mQuantKey) {
|
||||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
memcpy(
|
||||
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP,
|
||||
mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP,
|
||||
UP_DIV(oldMaxLength, hP) * mHeadDim * hP
|
||||
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP,
|
||||
mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP,
|
||||
UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP
|
||||
);
|
||||
}
|
||||
mBackend->onReleaseBuffer(mPastKey.get(), Backend::STATIC);
|
||||
mPastKey.reset();
|
||||
}
|
||||
else {
|
||||
if (mHeadDim % lP) {
|
||||
memset(mMapKeyAddr, 0, mKvNumHead * ROUND_UP(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes);
|
||||
}
|
||||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
memcpy(
|
||||
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes,
|
||||
mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes,
|
||||
UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes
|
||||
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes,
|
||||
mPastKey->host<char>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes,
|
||||
UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes
|
||||
);
|
||||
}
|
||||
mBackend->onReleaseBuffer(mPastKey.get(), Backend::STATIC);
|
||||
|
@ -217,9 +226,9 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) {
|
|||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
|
||||
memcpy(
|
||||
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP,
|
||||
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP,
|
||||
oldMaxLength * hP
|
||||
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP,
|
||||
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP,
|
||||
ROUND_UP(oldMaxLength, lP) * hP
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -227,12 +236,15 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) {
|
|||
mPastValue.reset();
|
||||
}
|
||||
else {
|
||||
if (lP > 1) {
|
||||
memset(mMapValueAddr, 0, mKvNumHead * ROUND_UP(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * mBytes);
|
||||
}
|
||||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
|
||||
memcpy(
|
||||
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP * mBytes,
|
||||
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP * mBytes,
|
||||
oldMaxLength * hP * mBytes
|
||||
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes,
|
||||
mPastValue->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes,
|
||||
ROUND_UP(oldMaxLength, lP) * hP * mBytes
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -250,17 +262,25 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o
|
|||
if (mConfig.mUseInt8Kernel) {
|
||||
old_key.reset(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(oldMaxLength, hP8), UP_DIV(mHeadDim, lP8), hP8 * lP8}));
|
||||
} else if (mConfig.mQuantKey) {
|
||||
old_key.reset(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(oldMaxLength, hP), mHeadDim, hP}));
|
||||
old_key.reset(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(oldMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP}));
|
||||
} else {
|
||||
old_key.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(oldMaxLength, hP), mHeadDim, hP}));
|
||||
old_key.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(oldMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP}));
|
||||
}
|
||||
if (mConfig.mQuantValue) {
|
||||
old_value.reset(Tensor::createDevice<fp8_t>({mKvNumHead, UP_DIV(mHeadDim, hP), oldMaxLength, hP}));
|
||||
old_value.reset(Tensor::createDevice<fp8_t>({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(oldMaxLength, lP), hP, lP}));
|
||||
} else {
|
||||
old_value.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), oldMaxLength, hP}));
|
||||
old_value.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(oldMaxLength, lP), hP, lP}));
|
||||
}
|
||||
mBackend->onAcquireBuffer(old_key.get(), Backend::STATIC);
|
||||
mBackend->onAcquireBuffer(old_value.get(), Backend::STATIC);
|
||||
if (mHeadDim % lP) {
|
||||
memset(old_key->host<uint8_t>(), 0, old_key->length(0) * old_key->stride(0) * mBytes);
|
||||
}
|
||||
if (lP > 1) {
|
||||
// Can't key this off mMaxLength % lP: mMaxLength may exceed seq_len during prefill, so the (mMaxLength - seq_len) tail of the buffer must be zeroed.
// The L dimension actually computed is seq_len.
|
||||
memset(old_value->host<uint8_t>(), 0, old_value->length(0) * old_value->stride(0) * mBytes);
|
||||
}
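The copies above and the allocation sizes later in this file all follow the same padded per-head extents; a small restatement as helpers (illustrative only, names assumed; the int8/fp8 variants simply drop elemBytes):

#include <cstddef>

static inline size_t roundUp(size_t x, size_t align) { return (x + align - 1) / align * align; }

static inline size_t keyCacheBytes(size_t kvNumHead, size_t maxLen, size_t headDim,
                                   size_t hP, size_t lP, size_t elemBytes) {
    return kvNumHead * roundUp(maxLen, hP) * roundUp(headDim, lP) * elemBytes;
}
static inline size_t valueCacheBytes(size_t kvNumHead, size_t maxLen, size_t headDim,
                                     size_t hP, size_t lP, size_t elemBytes) {
    return kvNumHead * roundUp(headDim, hP) * roundUp(maxLen, lP) * elemBytes;
}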
|
||||
mmapKVCache(oldKeySize, oldValueSize);
|
||||
memcpy(old_key->host<char>(), mMapKeyAddr, oldKeySize);
|
||||
memcpy(old_value->host<char>(), mMapValueAddr, oldValueSize);
|
||||
|
@ -280,17 +300,17 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o
|
|||
} else if (mConfig.mQuantKey) {
|
||||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
memcpy(
|
||||
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP,
|
||||
old_key->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP,
|
||||
UP_DIV(oldMaxLength, hP) * mHeadDim * hP
|
||||
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP,
|
||||
old_key->host<char>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP,
|
||||
UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP
|
||||
);
|
||||
}
|
||||
} else {
|
||||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
memcpy(
|
||||
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes,
|
||||
old_key->host<char>() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes,
|
||||
UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes
|
||||
mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes,
|
||||
old_key->host<char>() + h * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes,
|
||||
UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -298,9 +318,9 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o
|
|||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
|
||||
memcpy(
|
||||
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP,
|
||||
old_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP,
|
||||
oldMaxLength * hP
|
||||
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP,
|
||||
old_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP,
|
||||
ROUND_UP(oldMaxLength, lP) * hP
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -308,9 +328,9 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int o
|
|||
for (int h = 0; h < mKvNumHead; h++) {
|
||||
for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) {
|
||||
memcpy(
|
||||
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP * mBytes,
|
||||
old_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP * mBytes,
|
||||
oldMaxLength * hP * mBytes
|
||||
mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(mMaxLength, lP) * hP * mBytes,
|
||||
old_value->host<char>() + (h * UP_DIV(mHeadDim, hP) + i) * ROUND_UP(oldMaxLength, lP) * hP * mBytes,
|
||||
ROUND_UP(oldMaxLength, lP) * hP * mBytes
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -341,11 +361,11 @@ void KVCacheManager::onAlloc(int kv_seq_len) {
|
|||
if (mConfig.mUseInt8Kernel) {
|
||||
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8;
|
||||
} else if (mConfig.mQuantKey) {
|
||||
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP;
|
||||
keySize = (size_t)mKvNumHead * ROUND_UP(mMaxLength, hP) * ROUND_UP(mHeadDim, lP);
|
||||
} else {
|
||||
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes;
|
||||
keySize = (size_t)mKvNumHead * ROUND_UP(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * mBytes;
|
||||
}
|
||||
valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * mMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes);
|
||||
valueSize = (size_t)mKvNumHead * ROUND_UP(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * (mConfig.mQuantValue ? 1 : mBytes);
|
||||
/*============== Put the kvcache in disk ===========*/
|
||||
if (mConfig.mKVCacheSizeLimit != -1 && keySize + valueSize > mConfig.mKVCacheSizeLimit) {
|
||||
createKVCacheFile();
|
||||
|
@ -358,17 +378,23 @@ void KVCacheManager::onAlloc(int kv_seq_len) {
|
|||
if (mConfig.mUseInt8Kernel) {
|
||||
mPastKey.reset(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(mMaxLength, hP8), UP_DIV(mHeadDim, lP8), hP8 * lP8}));
|
||||
} else if (mConfig.mQuantKey) {
|
||||
mPastKey.reset(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP}));
|
||||
mPastKey.reset(Tensor::createDevice<int8_t>({mKvNumHead, UP_DIV(mMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP}));
|
||||
} else {
|
||||
mPastKey.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP}));
|
||||
mPastKey.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), UP_DIV(mHeadDim, lP), hP, lP}));
|
||||
}
|
||||
if (mConfig.mQuantValue) {
|
||||
mPastValue.reset(Tensor::createDevice<fp8_t>({mKvNumHead, UP_DIV(mHeadDim, hP), mMaxLength, hP}));
|
||||
mPastValue.reset(Tensor::createDevice<fp8_t>({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(mMaxLength, lP), hP, lP}));
|
||||
} else {
|
||||
mPastValue.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), mMaxLength, hP}));
|
||||
mPastValue.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mHeadDim, hP), UP_DIV(mMaxLength, lP), hP, lP}));
|
||||
}
|
||||
mBackend->onAcquireBuffer(mPastKey.get(), Backend::STATIC);
|
||||
mBackend->onAcquireBuffer(mPastValue.get(), Backend::STATIC);
|
||||
if (mHeadDim % lP) {
|
||||
memset(mPastKey->host<int8_t>(), 0, mPastKey->length(0) * mPastKey->stride(0) * mBytes);
|
||||
}
|
||||
if (lP > 1) { // Can't key this off mMaxLength % lP: mMaxLength may exceed seq_len during prefill, so the (mMaxLength - seq_len) tail of the buffer must be zeroed.
|
||||
memset(mPastValue->host<int8_t>(), 0, mPastValue->length(0) * mPastValue->stride(0) * mBytes);
|
||||
}
|
||||
}
|
||||
// scale, zero point and sum of key for quantization
|
||||
if (mConfig.mUseInt8Kernel) {
|
||||
|
@ -397,14 +423,14 @@ void KVCacheManager::onRealloc(const KVMeta* meta) {
|
|||
oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8;
|
||||
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8;
|
||||
} else if (mConfig.mQuantKey) {
|
||||
oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP) * mHeadDim * hP;
|
||||
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP;
|
||||
oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP;
|
||||
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP;
|
||||
} else {
|
||||
oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes;
|
||||
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes;
|
||||
oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes;
|
||||
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes;
|
||||
}
|
||||
oldValueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * oldMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes);
|
||||
valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * mMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes);
|
||||
oldValueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * ROUND_UP(oldMaxLength, lP) * hP * (mConfig.mQuantValue ? 1 : mBytes);
|
||||
valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * hP * (mConfig.mQuantValue ? 1 : mBytes);
|
||||
/*==== No limit for kvcache ====*/
|
||||
if (mConfig.mKVCacheSizeLimit == -1) {
|
||||
expandKVCacheInMem(oldMaxLength);
|
||||
|
@ -485,14 +511,14 @@ void KVCacheManager::onRealloc(const KVMeta* meta) {
|
|||
// mPastKey.reset(Tensor::createDevice<float>({mKvNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP}));
|
||||
|
||||
// Move K
|
||||
auto keyStride = UP_DIV(mMaxLength, align) * align * mHeadDim;
|
||||
auto dstKAddr = keyAddr() + dstStartAlign * mHeadDim * mBytes;
|
||||
auto srcKAddr = keyAddr() + startAlign * mHeadDim * mBytes;
|
||||
auto keyStride = UP_DIV(mMaxLength, align) * align * ROUND_UP(mHeadDim, lP);
|
||||
auto dstKAddr = keyAddr() + dstStartAlign * ROUND_UP(mHeadDim, lP) * mBytes;
|
||||
auto srcKAddr = keyAddr() + startAlign * ROUND_UP(mHeadDim, lP) * mBytes;
|
||||
for (int i=0; i<mKvNumHead; ++i) {
|
||||
auto dst = dstKAddr + i * keyStride * mBytes;
|
||||
auto src = srcKAddr + i * keyStride * mBytes;
|
||||
for (int j=0; j<sizeUnit; ++j) {
|
||||
::memcpy(dst + j * align * mHeadDim * mBytes, src + j * align * mHeadDim * mBytes, align * mHeadDim * mBytes);
|
||||
::memcpy(dst + j * align * ROUND_UP(mHeadDim, lP) * mBytes, src + j * align * ROUND_UP(mHeadDim, lP) * mBytes, align * ROUND_UP(mHeadDim, lP) * mBytes);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -503,8 +529,8 @@ void KVCacheManager::onRealloc(const KVMeta* meta) {
|
|||
auto srcVAddr = valudAddr() + startAlign * align * mBytes;
|
||||
auto number = mKvNumHead * UP_DIV(mHeadDim, align);
|
||||
for (int i=0; i<number; ++i) {
|
||||
auto dst = dstVAddr + i * mMaxLength * align * mBytes;
|
||||
auto src = srcVAddr + i * mMaxLength * align * mBytes;
|
||||
auto dst = dstVAddr + i * ROUND_UP(mMaxLength, lP) * align * mBytes;
|
||||
auto src = srcVAddr + i * ROUND_UP(mMaxLength, lP) * align * mBytes;
|
||||
for (int j=0; j<sizeUnit; ++j) {
|
||||
::memcpy(dst + j * align * align * mBytes, src + j * align * align * mBytes, align * align * mBytes);
|
||||
}
|
||||
|
@ -521,11 +547,11 @@ void KVCacheManager::onClear() {
|
|||
if (mConfig.mUseInt8Kernel) {
|
||||
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8;
|
||||
} else if (mConfig.mQuantKey) {
|
||||
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP;
|
||||
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP;
|
||||
} else {
|
||||
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes;
|
||||
keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes;
|
||||
}
|
||||
valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * mMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes);
|
||||
valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * hP * (mConfig.mQuantValue ? 1 : mBytes);
|
||||
unmapKVCache(keySize, valueSize);
|
||||
removeKVCacheFile();
|
||||
mKVCacheInDisk = false;
|
||||
|
@ -584,14 +610,16 @@ void KVCacheManager::pack_key(const Tensor* key, int seq_len, int kv_h) {
|
|||
}
|
||||
}
|
||||
}
|
||||
else { // [maxlen/hP, headdim, hP]
|
||||
else { // target: [maxlen/hP, headdim/lP, hP, lP]
|
||||
T * key_dst = reinterpret_cast<T*>(addrOfKey(kv_h));
|
||||
auto stride0 = ROUND_UP(mHeadDim, lP) * hP;
|
||||
auto stride1 = hP * lP;
|
||||
for (int i = 0; i < seq_len; i++) {
|
||||
T * key_src = key->host<T>() + i * mKvNumHead * mHeadDim + kv_h * mHeadDim;
|
||||
int out_index = (mPastLength + i) / hP;
|
||||
int in_index = (mPastLength + i) % hP;
|
||||
for (int j = 0; j < mHeadDim; j++) {
|
||||
key_dst[out_index * mHeadDim * hP + j * hP + in_index] = key_src[j];
|
||||
key_dst[out_index * stride0 + (j / lP) * stride1 + in_index * lP + (j % lP)] = key_src[j];
|
||||
}
|
||||
}
|
||||
}
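A standalone restatement of the index math in the loop above (illustrative helper, not MNN API), flattening (seqPos, dim) into the [maxlen/hP, headdim/lP, hP, lP] target; headDimAlignedLP stands for ROUND_UP(mHeadDim, lP):

#include <cstddef>

static inline size_t packedKeyIndex(size_t seqPos, size_t dim, size_t headDimAlignedLP,
                                    size_t hP, size_t lP) {
    const size_t stride0 = headDimAlignedLP * hP; // one block of hP sequence rows
    const size_t stride1 = hP * lP;               // one block of lP feature columns
    return (seqPos / hP) * stride0 + (dim / lP) * stride1 + (seqPos % hP) * lP + (dim % lP);
}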
|
||||
|
@ -618,13 +646,18 @@ void KVCacheManager::pack_value(const Tensor* value, int seq_len, int kv_h) { //
|
|||
MNNMemoryFreeAlign(buf);
|
||||
}
|
||||
else {
|
||||
// [mHeadDim/hP, mMaxLength/lP, hP, lP]
|
||||
auto stride0 = ROUND_UP(mMaxLength, lP) * hP;
|
||||
auto stride1 = hP * lP;
|
||||
T * value_dst = reinterpret_cast<T*>(addrOfValue(kv_h));
|
||||
for (int i = 0; i < seq_len; i++) {
|
||||
T * value_src = value->host<T>() + i * mKvNumHead * mHeadDim + kv_h * mHeadDim;
|
||||
int seqLenOut = (mPastLength + i) / lP;
|
||||
int seqLenIn = (mPastLength + i) % lP;
|
||||
for (int j = 0; j < mHeadDim; j++) {
|
||||
int out_index = j / hP;
|
||||
int in_index = j % hP;
|
||||
value_dst[out_index * mMaxLength * hP + (mPastLength + i) * hP + in_index] = value_src[j];
|
||||
value_dst[out_index * stride0 + seqLenOut * stride1 + in_index * lP + seqLenIn] = value_src[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -116,17 +116,17 @@ public:
|
|||
if (mConfig.mUseInt8Kernel) {
|
||||
return baseAddr + kv_h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8;
|
||||
} else if (mConfig.mQuantKey) {
|
||||
return baseAddr + kv_h * UP_DIV(mMaxLength, hP) * mHeadDim * hP;
|
||||
return baseAddr + kv_h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP;
|
||||
} else {
|
||||
return baseAddr + kv_h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes;
|
||||
return baseAddr + kv_h * UP_DIV(mMaxLength, hP) * ROUND_UP(mHeadDim, lP) * hP * mBytes;
|
||||
}
|
||||
}
|
||||
char * addrOfValue(int kv_h) {
|
||||
char * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host<char>();
|
||||
if (mConfig.mQuantValue) {
|
||||
return baseAddr + kv_h * UP_DIV(mHeadDim, hP) * mMaxLength * hP;
|
||||
return baseAddr + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * hP;
|
||||
} else {
|
||||
return baseAddr + kv_h * UP_DIV(mHeadDim, hP) * mMaxLength * hP * mBytes;
|
||||
return baseAddr + kv_h * UP_DIV(mHeadDim, hP) * ROUND_UP(mMaxLength, lP) * hP * mBytes;
|
||||
}
|
||||
}
|
||||
char * addrOfScale(int kv_h) {
|
||||
|
|
|
@ -103,7 +103,7 @@ if (MNN_KLEIDIAI)
|
|||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.c
|
||||
)
|
||||
|
||||
set_source_files_properties(${MNN_SOURCES_KLEIDIAI} PROPERTIES COMPILE_OPTIONS "-fno-tree-vectorize;-march=armv8.2-a+i8mm+dotprod+sve+sve2+fp16")
|
||||
set_source_files_properties(${MNN_SOURCES_KLEIDIAI} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+i8mm+dotprod+sve+sve2+fp16)
|
||||
set_source_files_properties(${KLEIDIAI_FILES_SME2} PROPERTIES COMPILE_OPTIONS "-fno-tree-vectorize;-march=armv8.2-a+sve+sve2")
|
||||
|
||||
endif()
|
||||
|
@ -121,7 +121,13 @@ if(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv7" OR ARCHS MATCHES "^armv7(;armv7s)?")
|
|||
endif()
|
||||
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR ARCHS STREQUAL "arm64" OR ARCHS STREQUAL "ARM64")
|
||||
message(STATUS "Enabling AArch64 Assemblies")
|
||||
add_library(MNNARM64 OBJECT ${MNN_AArch64_SRC} ${MNN_NEON_SRC} ${MNN_SOURCES_KLEIDIAI} ${KLEIDIAI_FILES_SME2})
|
||||
if (MNN_SME2)
|
||||
add_definitions(-DMNN_SME2)
|
||||
FILE(GLOB MNN_SME2_AArch64_SRC ${MNN_SME2_AArch64_SRC} ${CMAKE_CURRENT_LIST_DIR}/arm64/sme2_asm/*.[sS])
|
||||
#set_source_files_properties(${MNN_SME2_AArch64_SRC} PROPERTIES COMPILE_OPTIONS "-fno-tree-vectorize;-march=armv8.6-a+sve+sve2+sme+sme2+fp16")
|
||||
set_source_files_properties(${MNN_SME2_SRCS_ASM_FP16} PROPERTIES COMPILE_OPTIONS "-fno-tree-vectorize;-march=armv8.2-a+fp16")
|
||||
endif()
|
||||
add_library(MNNARM64 OBJECT ${MNN_AArch64_SRC} ${MNN_NEON_SRC} ${MNN_SOURCES_KLEIDIAI} ${KLEIDIAI_FILES_SME2} ${MNN_SME2_AArch64_SRC})
|
||||
target_include_directories(MNNARM64 PRIVATE ${CMAKE_CURRENT_LIST_DIR}/)
|
||||
list(APPEND MNN_OBJECTS_TO_LINK $<TARGET_OBJECTS:MNNARM64>)
|
||||
list(APPEND MNN_TARGETS MNNARM64)
|
||||
|
|
|
@ -907,9 +907,10 @@ void MNNGetMatMulPackMode(int* eP, int *lP, int* hP) {
|
|||
|
||||
// input shape is (l, h) when transpose=false, else input shape is (h, l)
|
||||
// output shape is (UP_DIV(h, 8), l, 8)
|
||||
void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose) {
|
||||
void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose) {
|
||||
auto hP = (int)h / 8;
|
||||
auto hR = (int)hP * 8;
|
||||
auto l = kernelsize * ic;
|
||||
if (hR != h) {
|
||||
::memset(dest, 0, UP_DIV(h, 8)*8*l*sizeof(float));
|
||||
}
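A scalar reference of the semantics the comment above describes (not the optimized kernel): pack B into [UP_DIV(h, hPack), l, hPack] with a zero-padded h tail. The new signature splits l into kernelsize * ic, but for the pure packing only their product matters, as the first line of the new body (auto l = kernelsize * ic) shows.

#include <algorithm>
#include <cstddef>

static void packBReference(float* dest, const float* source, size_t h, size_t l,
                           bool transpose, size_t hPack) {
    const size_t hBlocks = (h + hPack - 1) / hPack;
    std::fill(dest, dest + hBlocks * l * hPack, 0.0f); // zero the h tail
    for (size_t y = 0; y < h; ++y) {
        for (size_t x = 0; x < l; ++x) {
            // source is (l, h) when transpose == false, (h, l) when transpose == true
            const float v = transpose ? source[y * l + x] : source[x * h + y];
            dest[(y / hPack) * l * hPack + x * hPack + (y % hPack)] = v;
        }
    }
}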
|
||||
|
@ -952,7 +953,8 @@ void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bo
|
|||
}
|
||||
}
|
||||
#else
|
||||
void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose) {
|
||||
void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose) {
|
||||
auto l = kernelsize * ic;
|
||||
if (!transpose) {
|
||||
auto hP = h / 4;
|
||||
auto hR = hP * 4;
|
||||
|
|
|
@ -28,7 +28,7 @@ void NEON_MNNPackC4ForMatMul_A_BF16(float* destOrigin, float const** sourceGroup
|
|||
const int32_t* el);
|
||||
|
||||
|
||||
void NEON_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t l, bool transpose);
|
||||
void NEON_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose);
|
||||
|
||||
void NEON_MNNPackedMatMul_BF16(float* C, const float* A, const float* B, const size_t* parameter,
|
||||
const float* postParameters, const float* bias, const float* k, const float* b);
|
||||
|
@ -44,7 +44,7 @@ void MNNPackC4_BF16(float* dest, const float* source, size_t area, size_t depth,
|
|||
#ifdef __aarch64__
|
||||
void MNNPackC8_BF16(float* dest, const float* source, size_t l, size_t h);
|
||||
void ARMV86_MNNGetMatMulPackMode_BF16(int* eP, int* lP, int* hP);
|
||||
void ARMV86_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t l, bool transpose);
|
||||
void ARMV86_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose);
|
||||
void ARMV86_MNNPackC4ForMatMul_A_BF16(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el);
|
||||
void ARMV86_MNNPackedMatMul_BF16(float* C, const float* A, const float* B, const size_t* parameter,
|
||||
const float* postParameters, const float* bias, const float* k, const float* b);
|
||||
|
|
|
@ -90,8 +90,12 @@ sub \rg0, \rg0, \rg1
|
|||
.endm
|
||||
|
||||
.macro REVERT_WEIGHT_KERNEL_SUM rg0, rg1, rg2, rg3
|
||||
mul \rg1, \rg2, \rg3 // blocknum * ocUp4 * sizeof(float)
|
||||
sub \rg0, \rg0, \rg1, LSL #4 // revert weight kernel sum
|
||||
// y=UP_DIV(ocDiv4,(hp/pack))
|
||||
add \rg1, \rg3, #1
|
||||
lsr \rg1, \rg1, #1
|
||||
// blockNum * y * (hp * sizeof(float))
|
||||
mul \rg1, \rg2, \rg1
|
||||
sub \rg0, \rg0, \rg1, LSL #5 // revert weight kernel sum
|
||||
.endm
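As a sanity check on the revised pointer arithmetic, a scalar restatement of the new macro (the hp = 8 and pack = 4 values are inferred from the hard-coded shift and halving in the macro, not stated in this hunk):

#include <cstddef>

static inline size_t revertedWeightKernelSumBytes(size_t blockNum, size_t ocDiv4) {
    const size_t y = (ocDiv4 + 1) / 2; // UP_DIV(ocDiv4, hp / pack) with hp / pack == 2
    return blockNum * y * 32;          // hp * sizeof(float) == 8 * 4 == 32, i.e. the LSL #5
}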
|
||||
asm_function MNNGemmInt8AddBiasScale_ARMV82_Unit
|
||||
/*
|
||||
|
@ -435,9 +439,11 @@ L4_TILE12_BLOCKNUM:
|
|||
|
||||
|
||||
L4_Tile12Quan:
|
||||
ld1 {v0.4s}, [x2], #16 // scale
|
||||
ld1 {v0.4s}, [x2] // scale
|
||||
add x2, x2, #32
|
||||
ld1 {v2.4s, v3.4s, v4.4s}, [x8], #48 // x kernel sum
|
||||
ld1 {v5.4s}, [x2], #16 // weight quan zeropoint
|
||||
ld1 {v5.4s}, [x2] // weight quan zeropoint
|
||||
add x2, x2, #32
|
||||
Int32ToFloat v8, v9, v10, v11
|
||||
Int32ToFloat v12, v13, v14, v15
|
||||
Int32ToFloat v16, v17, v18, v19
|
||||
|
@ -470,7 +476,7 @@ L4_TILE12_BLOCKNUM:
|
|||
|
||||
cbz x27, L4_TILE12_ADD_DSTV
|
||||
ld1 {v0.4s, v1.4s, v2.4s}, [x27], #48 // input dequant bias
|
||||
ld1 {v3.4s}, [x28], #16 // weight kernel sum
|
||||
ld1 {v3.4s}, [x28] // weight kernel sum
|
||||
MLA_WEIGHTZERO v8, v0, v3, 0
|
||||
MLA_WEIGHTZERO v9, v0, v3, 1
|
||||
MLA_WEIGHTZERO v10, v0, v3, 2
|
||||
|
@ -483,6 +489,7 @@ L4_TILE12_BLOCKNUM:
|
|||
MLA_WEIGHTZERO v17, v2, v3, 1
|
||||
MLA_WEIGHTZERO v18, v2, v3, 2
|
||||
MLA_WEIGHTZERO v19, v2, v3, 3
|
||||
add x28, x28, #32
|
||||
|
||||
L4_TILE12_ADD_DSTV:
|
||||
cbz x19, L4_TILE12_ACCUM_BUFFER // x19=0: first block, do not add previous block result
|
||||
|
@ -782,9 +789,11 @@ L4_TILE8_BLOCKNUM:
|
|||
bne L4LoopSz_TILE_8
|
||||
|
||||
L4Tile8Quan:
|
||||
ld1 {v0.4s}, [x12], #16 // scale
|
||||
ld1 {v0.4s}, [x12] // scale
|
||||
add x12, x12, #32
|
||||
ld1 {v2.4s, v3.4s}, [x8], x22 // x kernel sum
|
||||
ld1 {v24.4s}, [x12], #16 // weight quan zeropoint
|
||||
ld1 {v24.4s}, [x12] // weight quan zeropoint
|
||||
add x12, x12, #32
|
||||
Int32ToFloat v8, v9, v10, v11
|
||||
Int32ToFloat v12, v13, v14, v15
|
||||
MUL_SCALE v0, v8, v9, v10, v11
|
||||
|
@ -808,7 +817,7 @@ L4_TILE8_BLOCKNUM:
|
|||
|
||||
cbz x27, L4_TILE8_ADD_DSTV
|
||||
ld1 {v2.4s, v3.4s}, [x27], x25
|
||||
ld1 {v24.4s}, [x28], #16
|
||||
ld1 {v24.4s}, [x28]
|
||||
MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3
|
||||
MLA_WEIGHTZERO v9, v2, v24, 1 // tile:1, oc:0-3
|
||||
MLA_WEIGHTZERO v10, v2, v24, 2 // tile:2, oc:0-3
|
||||
|
@ -817,6 +826,7 @@ L4_TILE8_BLOCKNUM:
|
|||
MLA_WEIGHTZERO v13, v3, v24, 1 // tile:5, oc:0-3
|
||||
MLA_WEIGHTZERO v14, v3, v24, 2 // tile:6, oc:0-3
|
||||
MLA_WEIGHTZERO v15, v3, v24, 3 // tile:7, oc:0-3
|
||||
add x28, x28, #32
|
||||
|
||||
L4_TILE8_ADD_DSTV:
|
||||
cbz x19, TILE8_L4_ACCUM_BUFFER
|
||||
|
@ -1055,9 +1065,11 @@ L4_TILE4_BLOCKNUM:
|
|||
bne L4LoopSz_TILE_4
|
||||
|
||||
L4Tile4Quan:
|
||||
ld1 {v0.4s}, [x12], #16 // scale
|
||||
ld1 {v0.4s}, [x12] // scale
|
||||
add x12, x12, #32
|
||||
ld1 {v2.4s}, [x8], x22 // x kernel sum
|
||||
ld1 {v24.4s}, [x12], #16 // weight quan zeropoint
|
||||
ld1 {v24.4s}, [x12] // weight quan zeropoint
|
||||
add x12, x12, #32
|
||||
Int32ToFloat v8, v9, v10, v11
|
||||
MUL_SCALE v0, v8, v9, v10, v11
|
||||
|
||||
|
@ -1074,11 +1086,12 @@ L4_TILE4_BLOCKNUM:
|
|||
|
||||
cbz x27, L4_TILE4_ADD_DSTV
|
||||
ld1 {v2.4s}, [x27], x25
|
||||
ld1 {v24.4s}, [x28], #16
|
||||
ld1 {v24.4s}, [x28]
|
||||
MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3
|
||||
MLA_WEIGHTZERO v9, v2, v24, 1 // tile:1, oc:0-3
|
||||
MLA_WEIGHTZERO v10, v2, v24, 2 // tile:2, oc:0-3
|
||||
MLA_WEIGHTZERO v11, v2, v24, 3 // tile:3, oc:0-3
|
||||
add x28, x28, #32
|
||||
|
||||
L4_TILE4_ADD_DSTV:
|
||||
cbz x19, TILE4_L4_ACCUM_BUFFER
|
||||
|
@ -1281,9 +1294,11 @@ L4_TILE1_BLOCKNUM:
|
|||
bne L4LoopSz_TILE_1
|
||||
|
||||
L4Tile1Quan:
|
||||
ld1 {v0.4s}, [x12], #16 // scale
|
||||
ld1 {v0.4s}, [x12] // scale
|
||||
add x12, x12, #32
|
||||
ld1 {v2.s}[0], [x8], x22 // x kernel sum
|
||||
ld1 {v24.4s}, [x12], #16 // weight quan zeropoint
|
||||
ld1 {v24.4s}, [x12] // weight quan zeropoint
|
||||
add x12, x12, #32
|
||||
scvtf v8.4s, v8.4s
|
||||
fmul v8.4s, v8.4s, v0.4s
|
||||
|
||||
|
@ -1297,8 +1312,9 @@ L4_TILE1_BLOCKNUM:
|
|||
|
||||
cbz x27, L4_TILE1_ADD_DSTV
|
||||
ld1 {v2.s}[0], [x27], x25
|
||||
ld1 {v24.4s}, [x28], #16
|
||||
ld1 {v24.4s}, [x28]
|
||||
MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3
|
||||
add x28, x28, #32
|
||||
|
||||
L4_TILE1_ADD_DSTV:
|
||||
cbz x19, L4_TILE1_L8_ACCUM_BUFFER
|
||||
|
|
|
@ -101,8 +101,12 @@ sub \rg0, \rg0, \rg1
|
|||
.endm
|
||||
|
||||
.macro REVERT_WEIGHT_KERNEL_SUM rg0, rg1, rg2, rg3
|
||||
mul \rg1, \rg2, \rg3 // blocknum * ocUp4 * sizeof(float)
|
||||
sub \rg0, \rg0, \rg1, LSL #4 // revert weight kernel sum
|
||||
// y=UP_DIV(ocDiv4,(hp/pack))
|
||||
add \rg1, \rg3, #1
|
||||
lsr \rg1, \rg1, #1
|
||||
// blockNum * y * (hp * sizeof(float))
|
||||
mul \rg1, \rg2, \rg1
|
||||
sub \rg0, \rg0, \rg1, LSL #5 // revert weight kernel sum
|
||||
.endm
|
||||
asm_function MNNGemmInt8AddBiasScale_ARMV86_Unit
|
||||
/*
|
||||
|
@ -500,10 +504,12 @@ L4_LoopSz_TILE_10:
|
|||
scvtf v9.4s, v9.4s
|
||||
|
||||
L4_Tile10Quan:
|
||||
ld1 {v20.4s}, [x2], #16 // weight scale
|
||||
ld1 {v20.4s}, [x2] // weight scale
|
||||
add x2, x2, #32
|
||||
ld1 {v22.4s, v23.4s}, [x8], #32 // x kernel sum
|
||||
ld1 {v24.d}[0], [x8], #8
|
||||
ld1 {v25.4s}, [x2], #16 // weight quan zeropoint
|
||||
ld1 {v25.4s}, [x2] // weight quan zeropoint
|
||||
add x2, x2, #32
|
||||
MUL_SCALE v20, v0, v1, v2, v3
|
||||
MUL_SCALE v20, v4, v5, v6, v7
|
||||
fmul v8.4s, v8.4s, v20.4s
|
||||
|
@ -534,7 +540,7 @@ L4_Tile10Quan:
|
|||
cbz x27, L4_TILE10_ADD_DSTV
|
||||
ld1 {v22.4s, v23.4s}, [x27], #32 // input dequant bias
|
||||
ld1 {v24.2s}, [x27], #8
|
||||
ld1 {v25.4s}, [x28], #16 // weight kernel sum
|
||||
ld1 {v25.4s}, [x28] // weight kernel sum
|
||||
MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3
|
||||
MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3
|
||||
MLA_WEIGHTZERO v2, v22, v25, 2 // tile:2, oc:0-3
|
||||
|
@ -545,6 +551,7 @@ L4_Tile10Quan:
|
|||
MLA_WEIGHTZERO v7, v23, v25, 3 // tile:7, oc:0-3
|
||||
MLA_WEIGHTZERO v8, v24, v25, 0 // tile:8, oc:0-3
|
||||
MLA_WEIGHTZERO v9, v24, v25, 1 // tile:9, oc:0-3
|
||||
add x28, x28, #32
|
||||
|
||||
L4_TILE10_ADD_DSTV:
|
||||
cbz x19, L4_TILE10_TEMP_BUFFER
|
||||
|
@ -902,9 +909,11 @@ LoopSz4End_TILE_8:
|
|||
Int32ToFloat v4, v5, v6, v7
|
||||
|
||||
L4_Tile8Quan:
|
||||
ld1 {v20.4s}, [x12], #16 // scale
|
||||
ld1 {v20.4s}, [x12] // scale
|
||||
add x12, x12, #32
|
||||
ld1 {v22.4s, v23.4s}, [x8] // x kernel sum
|
||||
ld1 {v25.4s}, [x12], #16 // weight quan zeropoint
|
||||
ld1 {v25.4s}, [x12] // weight quan zeropoint
|
||||
add x12, x12, #32
|
||||
add x8, x8, x22, LSR #1
|
||||
MUL_SCALE v20, v0, v1, v2, v3
|
||||
MUL_SCALE v20, v4, v5, v6, v7
|
||||
|
@ -928,7 +937,7 @@ L4_Tile8Quan:
|
|||
|
||||
cbz x27, L4_TILE8_ADD_DSTV
|
||||
ld1 {v22.4s, v23.4s}, [x27], x25 // input dequant bias
|
||||
ld1 {v25.4s}, [x28], #16 // weight kernel sum
|
||||
ld1 {v25.4s}, [x28] // weight kernel sum
|
||||
MLA_WEIGHTZERO v0, v22, v25, 0
|
||||
MLA_WEIGHTZERO v1, v22, v25, 1
|
||||
MLA_WEIGHTZERO v2, v22, v25, 2
|
||||
|
@ -937,6 +946,7 @@ L4_Tile8Quan:
|
|||
MLA_WEIGHTZERO v5, v23, v25, 1
|
||||
MLA_WEIGHTZERO v6, v23, v25, 2
|
||||
MLA_WEIGHTZERO v7, v23, v25, 3
|
||||
add x28, x28, #32
|
||||
|
||||
L4_TILE8_ADD_DSTV:
|
||||
cbz x19, L4_TILE8_TEMP_BUFFER
|
||||
|
@ -1182,9 +1192,11 @@ L4_LoopSz_TILE_4:
|
|||
Int32ToFloat v0, v1, v2, v3
|
||||
|
||||
L4_Tile4Quan:
|
||||
ld1 {v20.4s}, [x12], #16 // scale
|
||||
ld1 {v20.4s}, [x12] // scale
|
||||
add x12, x12, #32
|
||||
ld1 {v22.4s}, [x8] // x kernel sum
|
||||
ld1 {v25.4s}, [x12], #16 // weight quan zeropoint
|
||||
ld1 {v25.4s}, [x12] // weight quan zeropoint
|
||||
add x12, x12, #32
|
||||
MUL_SCALE v20, v0, v1, v2, v3
|
||||
add x8, x8, x22, LSR #1
|
||||
|
||||
|
@ -1202,11 +1214,12 @@ L4_Tile4Quan:
|
|||
|
||||
cbz x27, L4_TILE4_ADD_DSTV
|
||||
ld1 {v22.4s}, [x27], x25 // input dequant bias
|
||||
ld1 {v25.4s}, [x28], #16 // weight kernel sum
|
||||
ld1 {v25.4s}, [x28] // weight kernel sum
|
||||
MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3
|
||||
MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3
|
||||
MLA_WEIGHTZERO v2, v22, v25, 2 // tile:2, oc:0-3
|
||||
MLA_WEIGHTZERO v3, v22, v25, 3 // tile:3, oc:0-3
|
||||
add x28, x28, #32
|
||||
|
||||
L4_TILE4_ADD_DSTV:
|
||||
cbz x19, L4_TILE4_ACCUM_BUFFER
|
||||
|
@ -1414,9 +1427,11 @@ L4_LoopSz_TILE_2:
|
|||
scvtf v1.4s, v1.4s
|
||||
|
||||
L4_Tile2Quan:
|
||||
ld1 {v20.4s}, [x12], #16 // scale
|
||||
ld1 {v20.4s}, [x12] // scale
|
||||
add x12, x12, #32
|
||||
ld1 {v22.d}[0], [x8] // x kernel sum
|
||||
ld1 {v25.4s}, [x12], #16 // weight quan zeropoint
|
||||
ld1 {v25.4s}, [x12] // weight quan zeropoint
|
||||
add x12, x12, #32
|
||||
fmul v0.4s, v0.4s, v20.4s
|
||||
fmul v1.4s, v1.4s, v20.4s
|
||||
add x8, x8, x22, LSR #1
|
||||
|
@ -1434,9 +1449,10 @@ L4_Tile2Quan:
|
|||
|
||||
cbz x27, L4_TILE2_ADD_DSTV
|
||||
ld1 {v22.2s}, [x27], x25 // input dequant bias
|
||||
ld1 {v25.4s}, [x28], #16 // weight kernel sum
|
||||
ld1 {v25.4s}, [x28] // weight kernel sum
|
||||
MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3
|
||||
MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3
|
||||
MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3]
|
||||
add x28, x28, #32
|
||||
|
||||
L4_TILE2_ADD_DSTV:
|
||||
cbz x19, L4_TILE2_ACCUM_BUFFER
|
||||
|
@ -1636,9 +1652,11 @@ L4_LoopSzEnd_TILE_1:
|
|||
scvtf v25.4s, v25.4s
|
||||
|
||||
L4_Tile1Quan:
|
||||
ld1 {v0.4s}, [x12], #16 // scale
|
||||
ld1 {v0.4s}, [x12] // scale
|
||||
add x12, x12, #32
|
||||
ld1 {v6.s}[0], [x8] // x kernel sum
|
||||
ld1 {v8.4s}, [x12], #16 // weight quan zeropoint
|
||||
ld1 {v8.4s}, [x12] // weight quan zeropoint
|
||||
add x12, x12, #32
|
||||
fmul v25.4s, v25.4s, v0.4s
|
||||
add x8, x8, x22, LSR #1
|
||||
cbz x21, L4_TILE1_MLA
|
||||
|
@ -1652,8 +1670,9 @@ L4_Tile1Quan:
|
|||
|
||||
cbz x27, L4_TILE1_ADD_DSTV
|
||||
ld1 {v6.s}[0], [x27], x25 // input dequant bias
|
||||
ld1 {v8.4s}, [x28], #16 // weight kernel sum
|
||||
ld1 {v8.4s}, [x28] // weight kernel sum
|
||||
MLA_WEIGHTZERO v25, v6, v8, 0 // tile:0, oc:0-3
|
||||
add x28, x28, #32
|
||||
|
||||
L4_TILE1_ADD_DSTV:
|
||||
cbz x19, L4_TILE1_ACCUM_BUFFER
|
||||
|
|
|
@ -0,0 +1,337 @@
|
|||
#ifdef __aarch64__
|
||||
|
||||
#include "MNNAsmGlobal.h"
|
||||
|
||||
.text
|
||||
.align 5
|
||||
|
||||
.macro SET_0 s0, s1, s2, s3
|
||||
movi \s0\().4s, #0
|
||||
movi \s1\().4s, #0
|
||||
movi \s2\().4s, #0
|
||||
movi \s3\().4s, #0
|
||||
.endm
|
||||
|
||||
/*
|
||||
struct SumByAxisParams {
|
||||
ssize_t kernelCountUnitDouble;
|
||||
ssize_t col_buffer_unit_size;
|
||||
ssize_t DST_XUNIT;
|
||||
ssize_t SRC_UNIT;
|
||||
ssize_t blockNum;
|
||||
ssize_t oneScale;
|
||||
};
|
||||
*/
|
||||
|
||||
asm_function MNNSumByAxisLForMatmul_A_SME2
|
||||
// MNNSumByAxisLForMatmul_A_SME2(float_t* dest, int8_t* source, float* dequantScale, ssize_t realDstCount,
|
||||
// ssize_t kernelCountUnitDouble, ssize_t col_buffer_unit_size, ssize_t EP, ssize_t LP, ssize_t blockNum, ssize_t oneScale);
|
||||
// x0: dest, x1: source, x2: dequantScale, x3: realDstCount, x4: sumParams
|
||||
// x5: oneScale
|
||||
// Load from sp: x8: blockNum
|
||||
// EP=16, LP=4, HP=16
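/*
 Plain-C sketch of the reduction this kernel performs for one 16-element tile
 (an illustration under the parameter names above, not the shipped reference
 code; src()/dst() stand for the packed layouts the asm walks):

   for (int b = 0; b < blockNum; ++b) {
       for (int e = 0; e < 16; ++e) {                    // EP lanes of this tile
           int32_t acc = 0;
           for (int k = 0; k < kxky; ++k)
               for (int i = 0; i < icDiv4; ++i)
                   for (int l = 0; l < LP; ++l)          // the tail quad is masked via "Valid"
                       acc += src(e, b, k, i, l);        // int8 inputs
           float s = oneScale ? dequantScale[0]
                              : inputScale(e, b);        // per-batch/block quant scale
           dst(b, e) = s * (float)acc;                   // 16 floats stored per block
       }
   }
*/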
|
||||
|
||||
ldr x12, [x4, #48] // Valid
|
||||
ldr x8, [x4, #32] // blockNum
|
||||
ldr x5, [x4, #40] // oneScale
|
||||
ldr x14, [x4, #56] // kx*ky
|
||||
ldr x15, [x4, #72] // input block quant, 0:no, 1:yes
|
||||
ldr x4, [x4, #64] // LU
|
||||
|
||||
stp d14, d15, [sp, #(-16 * 5)]!
|
||||
stp d12, d13, [sp, #(16 * 1)]
|
||||
stp d10, d11, [sp, #(16 * 2)]
|
||||
stp d8, d9, [sp, #(16 * 3)]
|
||||
stp x20, x21, [sp, #(16 * 4)]
|
||||
|
||||
movi v31.16b, #1
|
||||
mov v29.16b, v31.16b
|
||||
ld1r {v30.4s}, [x2] // Dequant scale
|
||||
sdiv x4, x4, x8 // src_depth_quad per block
|
||||
cbz x12, Start
|
||||
mov x13, #0xFFFFFFFF
|
||||
lsl x12, x12, #3
|
||||
lsl x13, x13, x12
|
||||
dup v28.4s, w13
|
||||
bic v29.16b, v31.16b, v28.16b
|
||||
|
||||
Start:
|
||||
mov x13, x15 // input block quant, 0:no, 1:yes
|
||||
|
||||
TILE_16:
|
||||
cmp x3, #16
|
||||
blt Remain
|
||||
|
||||
mov x9, x8 // blockNum
|
||||
cbnz x13, TILE16_BLOCK_NUM
|
||||
ld1 {v13.4s, v14.4s, v15.4s, v16.4s}, [x2], #64 // batch quant scale
|
||||
|
||||
TILE16_BLOCK_NUM:
|
||||
mov x15, x14 // kx*ky
|
||||
movi v9.4s, #0
|
||||
movi v10.4s, #0
|
||||
movi v11.4s, #0
|
||||
movi v12.4s, #0
|
||||
|
||||
/* for range(kx*ky)...for range(ic/pack) */
|
||||
TILE16_BLOCK_INNER:
|
||||
sub x12, x4, #1 // icDiv4
|
||||
cbz x12, TILE16_LAST_QUAD
|
||||
|
||||
TILE16_PRE_QUAD:
|
||||
ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x1], #64 // E: 0,1,2,3,...,15
|
||||
.inst 0x4e8097e9 // sdot v9.4s, v31.16b, v0.16b
|
||||
.inst 0x4e8197ea // sdot v10.4s, v31.16b, v1.16b
|
||||
.inst 0x4e8297eb // sdot v11.4s, v31.16b, v2.16b
|
||||
.inst 0x4e8397ec // sdot v12.4s, v31.16b, v3.16b
|
||||
subs x12, x12, #1 // icDiv4--
|
||||
bne TILE16_PRE_QUAD
|
||||
|
||||
TILE16_LAST_QUAD:
|
||||
ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x1], #64 // E: 0,1,2,3,...,11
|
||||
.inst 0x4e8097a9 // sdot v9.4s, v29.16b, v0.16b
|
||||
.inst 0x4e8197aa // sdot v10.4s, v29.16b, v1.16b
|
||||
.inst 0x4e8297ab // sdot v11.4s, v29.16b, v2.16b
|
||||
.inst 0x4e8397ac // sdot v12.4s, v29.16b, v3.16b
|
||||
|
||||
subs x15, x15, #1
|
||||
bne TILE16_BLOCK_INNER
|
||||
|
||||
TILE16_BLOCK_INNER_END:
|
||||
subs x9, x9, #1 // blockNum--
|
||||
|
||||
scvtf v9.4s, v9.4s
|
||||
scvtf v10.4s, v10.4s
|
||||
scvtf v11.4s, v11.4s
|
||||
scvtf v12.4s, v12.4s
|
||||
|
||||
cbnz x5, TILE16_MUL_ONE_SCALE
|
||||
cbz x13, TILE16_MUL_BLOCK_SCALE
|
||||
ld1 {v13.4s, v14.4s, v15.4s, v16.4s}, [x2], #64 // batch quant scale, input block quant
|
||||
TILE16_MUL_BLOCK_SCALE:
|
||||
fmul v9.4s, v9.4s, v13.4s
|
||||
fmul v10.4s, v10.4s, v14.4s
|
||||
fmul v11.4s, v11.4s, v15.4s
|
||||
fmul v12.4s, v12.4s, v16.4s
|
||||
b TILE16_STORE
|
||||
|
||||
TILE16_MUL_ONE_SCALE:
|
||||
fmul v9.4s, v9.4s, v30.4s
|
||||
fmul v10.4s, v10.4s, v30.4s
|
||||
fmul v11.4s, v11.4s, v30.4s
|
||||
fmul v12.4s, v12.4s, v30.4s
|
||||
|
||||
TILE16_STORE:
|
||||
st1 {v9.4s, v10.4s, v11.4s, v12.4s}, [x0], #64
|
||||
bne TILE16_BLOCK_NUM
|
||||
|
||||
TILE16_END:
|
||||
subs x3, x3, #16 // realDstCount-=16
|
||||
bne TILE_16
|
||||
|
||||
|
||||
Remain: // remain realDstCount < EP
|
||||
cbz x3, End
|
||||
/* x11: Remain dstCount step for each block */
|
||||
lsl x11, x3, #2
|
||||
lsl x6, x3, #2 // x6=eDest * LP
|
||||
mov x20, x2
|
||||
|
||||
TILE_12:
|
||||
cmp x3, #12
|
||||
blt TILE_2
|
||||
|
||||
mov x7, x1
|
||||
mov x9, x8 // blockNum
|
||||
mov x10, x0 // tag dst address
|
||||
|
||||
mov x9, x8 // blockNum
|
||||
cbnz x13, TILE12_BLOCK_NUM
|
||||
ld1 {v13.4s, v14.4s, v15.4s}, [x2], #48 // batch quant scale
|
||||
|
||||
TILE12_BLOCK_NUM:
|
||||
mov x15, x14 // kx*ky
|
||||
movi v10.4s, #0
|
||||
movi v11.4s, #0
|
||||
movi v12.4s, #0
|
||||
/* for range(kx*ky)...for range(ic/pack) */
|
||||
TILE12_BLOCK_INNER:
|
||||
sub x12, x4, #1 // icDiv4
|
||||
cbz x12, TILE12_LAST_QUAD
|
||||
|
||||
TILE12_PRE_QUAD:
|
||||
ld1 {v0.16b, v1.16b, v2.16b}, [x7], x6 // E: 0,1,2,3,...,11
|
||||
.inst 0x4e8097ea // sdot v10.4s, v31.16b, v0.16b // sum LP axis for E0, E1, E2, E3
|
||||
.inst 0x4e8197eb // sdot v11.4s, v31.16b, v1.16b
|
||||
.inst 0x4e8297ec // sdot v12.4s, v31.16b, v2.16b
|
||||
subs x12, x12, #1 // icDiv4--
|
||||
bne TILE12_PRE_QUAD
|
||||
|
||||
TILE12_LAST_QUAD:
|
||||
ld1 {v0.16b, v1.16b, v2.16b}, [x7], x6 // E: 0,1,2,3,...,11
|
||||
.inst 0x4e8097aa // sdot v10.4s, v29.16b, v0.16b // sum LP axis for E0, E1, E2, E3
|
||||
.inst 0x4e8197ab // sdot v11.4s, v29.16b, v1.16b
|
||||
.inst 0x4e8297ac // sdot v12.4s, v29.16b, v2.16b
|
||||
|
||||
subs x15, x15, #1
|
||||
bne TILE12_BLOCK_INNER
|
||||
|
||||
TILE12_BLOCK_INNER_END:
|
||||
subs x9, x9, #1 // blockNum--
|
||||
|
||||
scvtf v10.4s, v10.4s
|
||||
scvtf v11.4s, v11.4s
|
||||
scvtf v12.4s, v12.4s
|
||||
|
||||
cbnz x5, TILE12_MUL_ONE_SCALE
|
||||
cbz x13, TILE12_MUL_BLOCK_SCALE
|
||||
ld1 {v13.4s, v14.4s, v15.4s}, [x2], x6 // batch quant scale, input block quant
|
||||
TILE12_MUL_BLOCK_SCALE:
|
||||
fmul v10.4s, v10.4s, v13.4s
|
||||
fmul v11.4s, v11.4s, v14.4s
|
||||
fmul v12.4s, v12.4s, v15.4s
|
||||
b TILE12_STORE
|
||||
|
||||
TILE12_MUL_ONE_SCALE:
|
||||
fmul v10.4s, v10.4s, v30.4s
|
||||
fmul v11.4s, v11.4s, v30.4s
|
||||
fmul v12.4s, v12.4s, v30.4s
|
||||
|
||||
TILE12_STORE:
|
||||
st1 {v10.4s, v11.4s, v12.4s}, [x10], x11
|
||||
bne TILE12_BLOCK_NUM
|
||||
|
||||
TILE12_END:
|
||||
subs x3, x3, #12 // realDstCount-=12
|
||||
add x1, x1, #48 // LP * 12 * sizeof(int8_t)
|
||||
add x0, x0, #48 // finish 12*sizeof(float)
|
||||
add x2, x20, #48 // x20 + 12 * sizeof(float)
|
||||
mov x20, x2
|
||||
|
||||
TILE_2: // realDstCount >= 1
|
||||
cmp x3, #2
|
||||
blt TILE_1
|
||||
|
||||
mov x7, x1
|
||||
mov x9, x8 // blockNum
|
||||
mov x10, x0 // tag dst address
|
||||
|
||||
cbnz x13, TILE2_BLOCK_NUM
|
||||
ld1 {v13.d}[0], [x2], #8 // batch quant scale
|
||||
|
||||
TILE2_BLOCK_NUM:
|
||||
mov x15, x14 // kx*ky
|
||||
movi v10.4s, #0
|
||||
|
||||
TILE2_BLOCK_INNER: // range(kxky)
|
||||
sub x12, x4, #1 // icDiv4
|
||||
cbz x12, TILE2_LAST_QUAD
|
||||
|
||||
TILE2_PRE_QUAD: // range(icDiv4)
|
||||
ld1 {v0.d}[0], [x7], x6 // E: 0,1
|
||||
subs x12, x12, #1
|
||||
.inst 0x4e8097ea // sdot v10.4s, v31.16b, v0.16b // sum LP axis for E0
|
||||
bne TILE2_PRE_QUAD
|
||||
|
||||
TILE2_LAST_QUAD:
|
||||
ld1 {v0.d}[0], [x7], x6 // E: 0,1
|
||||
.inst 0x4e8097aa // sdot v10.4s, v29.16b, v0.16b
|
||||
|
||||
subs x15, x15, #1 // kxky--
|
||||
bne TILE2_BLOCK_INNER
|
||||
|
||||
TILE2_BLOCK_INNER_END:
|
||||
scvtf v10.4s, v10.4s
|
||||
|
||||
cbnz x5, TILE2_MUL_ONE_SCALE
|
||||
cbz x13, TILE2_MUL_BLOCK_SCALE
|
||||
ld1 {v13.d}[0], [x2], x6 // batch quant scale
|
||||
TILE2_MUL_BLOCK_SCALE:
|
||||
fmul v10.4s, v10.4s, v13.4s
|
||||
b TILE2_STORE
|
||||
|
||||
TILE2_MUL_ONE_SCALE:
|
||||
fmul v10.4s, v10.4s, v30.4s
|
||||
|
||||
TILE2_STORE:
|
||||
subs x9, x9, #1 // blockNum--
|
||||
st1 {v10.d}[0], [x10], x11
|
||||
bne TILE2_BLOCK_NUM
|
||||
|
||||
TILE2_END:
|
||||
sub x3, x3, #2 // realDstCount-=2
|
||||
add x1, x1, #8 // LP * 2
|
||||
add x0, x0, #8 // finish remain 2
|
||||
add x2, x20, #8 // x20 + 2 * sizeof(float)
|
||||
mov x20, x2
|
||||
b TILE_2
|
||||
|
||||
|
||||
TILE_1: // realDstCount >= 1
|
||||
cmp x3, #1
|
||||
blt End
|
||||
|
||||
mov x7, x1
|
||||
mov x9, x8 // blockNum
|
||||
mov x10, x0
|
||||
|
||||
cbnz x13, TILE1_BLOCK_NUM
|
||||
ld1 {v13.s}[0], [x2], #4 // batch quant scale
|
||||
|
||||
TILE1_BLOCK_NUM:
|
||||
mov x15, x14 // kx*ky
|
||||
movi v10.4s, #0
|
||||
|
||||
TILE1_BLOCK_INNER:
|
||||
sub x12, x4, #1
|
||||
cbz x12, TILE1_LAST_QUAD
|
||||
|
||||
TILE1_PRE_QUAD:
|
||||
ld1 {v0.s}[0], [x7] // E: 0
|
||||
add x7, x7, x6
|
||||
.inst 0x4e8097ea // sdot v10.4s, v31.16b, v0.16b // sum LP axis for E0
|
||||
subs x12, x12, #1 // icDiv4--
|
||||
bne TILE1_PRE_QUAD
|
||||
|
||||
TILE1_LAST_QUAD:
|
||||
ld1 {v0.s}[0], [x7], x6 // E: 0
|
||||
.inst 0x4e8097aa // sdot v10.4s, v29.16b, v0.16b
|
||||
|
||||
subs x15, x15, #1 // kxky--
|
||||
bne TILE1_BLOCK_INNER
|
||||
|
||||
TILE1_BLOCK_INNER_END:
|
||||
scvtf v10.4s, v10.4s
|
||||
|
||||
cbnz x5, TILE1_MUL_ONE_SCALE
|
||||
cbz x13, TILE1_MUL_BLOCK_SCALE
|
||||
ld1 {v13.s}[0], [x2], x6 // batch quant scale
|
||||
TILE1_MUL_BLOCK_SCALE:
|
||||
fmul v10.4s, v10.4s, v13.4s
|
||||
b TILE1_STORE
|
||||
|
||||
TILE1_MUL_ONE_SCALE:
|
||||
fmul v10.4s, v10.4s, v30.4s
|
||||
|
||||
TILE1_STORE:
|
||||
subs x9, x9, #1 // blockNum--
|
||||
st1 {v10.s}[0], [x10], x11
|
||||
bne TILE1_BLOCK_NUM
|
||||
|
||||
TILE1_END:
|
||||
sub x3, x3, #1 // realDstCount-=1
|
||||
add x1, x1, #4 // LP * 1
|
||||
add x0, x0, #4 // finish remain 1
|
||||
add x2, x20, #4 // x20 + 1 * sizeof(float)
|
||||
mov x20, x2
|
||||
|
||||
b TILE_1
|
||||
|
||||
End:
|
||||
ldp x20, x21, [sp, #(16 * 4)]
|
||||
ldp d8, d9, [sp, #(16 * 3)]
|
||||
ldp d10, d11, [sp, #(16 * 2)]
|
||||
ldp d12, d13, [sp, #(16 * 1)]
|
||||
ldp d14, d15, [sp], #(16 * 5)
|
||||
ret
|
||||
#endif
|
|
@ -0,0 +1,151 @@
|
|||
#ifdef __aarch64__
|
||||
#include "MNNAsmGlobal.h"
|
||||
|
||||
.text
|
||||
.align 5
|
||||
|
||||
.macro SET_0 s0, s1, s2, s3
|
||||
movi \s0\().4s, #0
|
||||
movi \s1\().4s, #0
|
||||
movi \s2\().4s, #0
|
||||
movi \s3\().4s, #0
|
||||
.endm
|
||||
|
||||
.macro Int32_To_Float32 s0, s1, s2, s3
|
||||
scvtf \s0\().4s, \s0\().4s
|
||||
scvtf \s1\().4s, \s1\().4s
|
||||
scvtf \s2\().4s, \s2\().4s
|
||||
scvtf \s3\().4s, \s3\().4s
|
||||
.endm
|
||||
|
||||
asm_function MNNPermuteSumWeightInt4Sme2
|
||||
// void MNNPermuteSumWeightInt4Sme2(uint8_t* dest, uint8_t* source, size_t outside, size_t inside, float* kernelSum);
|
||||
// auto load: x0: dest, x1: source, x2: outside, x3: inside, x4: kernelSum
|
||||
|
||||
// inside = lu
|
||||
// outside = blocknum*hu
|
||||
// kernelSum shape: [hu, blockNum, hp]
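/*
 Scalar sketch of one byte of the permutation (illustrative only, not the
 reference code): each source byte packs two uint4 weights, the even-indexed
 one in the high nibble; lane() stands for the output-channel mapping implied
 by the pack layout.

   uint8_t b    = source[i];
   uint8_t even = b >> 4;                         // weight 2*i
   uint8_t odd  = b & 0x0F;                       // weight 2*i + 1
   acc[lane(i)] += even + odd;                    // unsigned kernel sum per output lane
   dest[i]      = (uint8_t)((odd << 4) | even);   // nibbles written back swapped

 After the inside (lu) loop the accumulators are converted to float and written
 to kernelSum in [hu, blockNum, hp] order.
*/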
|
||||
|
||||
|
||||
stp d14, d15, [sp, #-64]!
|
||||
stp d12, d13, [sp, #16]
|
||||
stp d10, d11, [sp, #32]
|
||||
stp d8, d9, [sp, #48]
|
||||
|
||||
movi v31.16b, #15
|
||||
movi v30.16b, #4
|
||||
movi v29.16b, #1
|
||||
|
||||
Loop: // blocknum*hu
|
||||
mov x6, x3 // lu
|
||||
|
||||
SET_0 v4, v5, v6, v7
|
||||
SET_0 v20, v21, v22, v23
|
||||
cmp x6, #2
|
||||
blt LoopLU
|
||||
|
||||
LoopLU2:
|
||||
ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x1], #64
|
||||
ushr v8.16b, v0.16b, #4 // v8: 0 2 ... 30
|
||||
and v9.16b, v0.16b, v31.16b // v9: 1 3 ... 31
|
||||
ushr v10.16b, v1.16b, #4
|
||||
and v11.16b, v1.16b, v31.16b
|
||||
|
||||
ushr v27.16b, v2.16b, #4
|
||||
and v28.16b, v2.16b, v31.16b
|
||||
ushr v24.16b, v3.16b, #4
|
||||
and v25.16b, v3.16b, v31.16b
|
||||
|
||||
zip1 v12.16b, v8.16b, v9.16b // v12: 0 1 2 3 ... 14 15
|
||||
zip2 v13.16b, v8.16b, v9.16b // v13: 16 17 18 19 ... 30 31
|
||||
zip1 v14.16b, v10.16b, v11.16b // v14: 32 33 34 35 ... 46 47
|
||||
zip2 v15.16b, v10.16b, v11.16b // v15: 48 49 50 51 ... 62 63
|
||||
|
||||
zip1 v16.16b, v27.16b, v28.16b
|
||||
zip2 v17.16b, v27.16b, v28.16b
|
||||
zip1 v18.16b, v24.16b, v25.16b
|
||||
zip2 v19.16b, v24.16b, v25.16b
|
||||
|
||||
// weight kernel sum
|
||||
.inst 0x6e8c97a4 // udot v4.4s, v29.16b, v12.16b
|
||||
.inst 0x6e8d97a5 // udot v5.4s, v29.16b, v13.16b
|
||||
.inst 0x6e8e97a6 // udot v6.4s, v29.16b, v14.16b
|
||||
.inst 0x6e8f97a7 // udot v7.4s, v29.16b, v15.16b
|
||||
|
||||
.inst 0x6e9097b4 // udot v20.4s, v29.16b, v16.16b
|
||||
.inst 0x6e9197b5 // udot v21.4s, v29.16b, v17.16b
|
||||
.inst 0x6e9297b6 // udot v22.4s, v29.16b, v18.16b
|
||||
.inst 0x6e9397b7 // udot v23.4s, v29.16b, v19.16b
|
||||
|
||||
sub x6, x6, #2
|
||||
// transpose
|
||||
ushl v9.16b, v9.16b, v30.16b
|
||||
ushl v11.16b, v11.16b, v30.16b
|
||||
ushl v28.16b, v28.16b, v30.16b
|
||||
ushl v25.16b, v25.16b, v30.16b
|
||||
|
||||
orr v0.16b, v8.16b, v9.16b
|
||||
orr v1.16b, v10.16b, v11.16b
|
||||
orr v2.16b, v27.16b, v28.16b
|
||||
orr v3.16b, v24.16b, v25.16b
|
||||
|
||||
st1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x0], #64
|
||||
|
||||
cmp x6, #2
|
||||
bge LoopLU2
|
||||
cbz x6, LUEnd
|
||||
|
||||
LoopLU:
|
||||
cbz x6, LUEnd
|
||||
|
||||
ld1 {v0.16b, v1.16b}, [x1], #32
|
||||
ushr v8.16b, v0.16b, #4 // v8: 0 2 ... 30
|
||||
and v9.16b, v0.16b, v31.16b // v9: 1 3 ... 31
|
||||
ushr v10.16b, v1.16b, #4
|
||||
and v11.16b, v1.16b, v31.16b
|
||||
|
||||
zip1 v12.16b, v8.16b, v9.16b // v12: 0 1 2 3 ... 14 15
|
||||
zip2 v13.16b, v8.16b, v9.16b // v13: 16 17 18 19 ... 30 31
|
||||
zip1 v14.16b, v10.16b, v11.16b // v14: 32 33 34 35 ... 46 47
|
||||
zip2 v15.16b, v10.16b, v11.16b // v15: 48 49 50 51 ... 62 63
|
||||
|
||||
|
||||
// weight kernel sum
|
||||
.inst 0x6e8c97a4 // udot v4.4s, v29.16b, v12.16b
|
||||
.inst 0x6e8d97a5 // udot v5.4s, v29.16b, v13.16b
|
||||
.inst 0x6e8e97a6 // udot v6.4s, v29.16b, v14.16b
|
||||
.inst 0x6e8f97a7 // udot v7.4s, v29.16b, v15.16b
|
||||
|
||||
// <<4
|
||||
ushl v9.16b, v9.16b, v30.16b
|
||||
ushl v11.16b, v11.16b, v30.16b
|
||||
|
||||
orr v0.16b, v8.16b, v9.16b
|
||||
orr v1.16b, v10.16b, v11.16b
|
||||
|
||||
st1 {v0.16b, v1.16b}, [x0], #32
|
||||
|
||||
LUEnd:
|
||||
|
||||
add v4.4s, v4.4s, v20.4s
|
||||
add v5.4s, v5.4s, v21.4s
|
||||
add v6.4s, v6.4s, v22.4s
|
||||
add v7.4s, v7.4s, v23.4s
|
||||
scvtf v4.4s, v4.4s
|
||||
scvtf v5.4s, v5.4s
|
||||
scvtf v6.4s, v6.4s
|
||||
scvtf v7.4s, v7.4s
|
||||
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x4], #64
|
||||
|
||||
subs x2, x2, #1 // outside--
|
||||
bne Loop
|
||||
|
||||
|
||||
End:
|
||||
ldp d8, d9, [sp, #48]
|
||||
ldp d10, d11, [sp, #32]
|
||||
ldp d12, d13, [sp, #16]
|
||||
ldp d14, d15, [sp], #64
|
||||
ret
|
||||
|
||||
#endif
|
|
@ -0,0 +1,127 @@
|
|||
#ifdef __aarch64__
|
||||
#include "MNNAsmGlobal.h"
|
||||
|
||||
.text
|
||||
.align 5
|
||||
|
||||
.macro SET_0 s0, s1, s2, s3
|
||||
movi \s0\().4s, #0
|
||||
movi \s1\().4s, #0
|
||||
movi \s2\().4s, #0
|
||||
movi \s3\().4s, #0
|
||||
.endm
|
||||
|
||||
.macro Int32_To_Float32 s0, s1, s2, s3
|
||||
scvtf \s0\().4s, \s0\().4s
|
||||
scvtf \s1\().4s, \s1\().4s
|
||||
scvtf \s2\().4s, \s2\().4s
|
||||
scvtf \s3\().4s, \s3\().4s
|
||||
.endm
|
||||
|
||||
asm_function MNNSumWeightInt8Sme2
|
||||
// void MNNSumWeightInt8Sme2(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP)
|
||||
// auto load: x0: dest, x1: source, x2: outside, x3: reduceAxis, x4: hP, x5: lP
|
||||
|
||||
// weight shape: [outside, reduceAxis, hP, lP]
|
||||
// outside = blocknum * hU
|
||||
// reduceAxis = kernelCount * lU
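/*
 Scalar sketch of the reduction (illustrative; hP = 16 and lP = 4 as the loads
 below assume):

   for (int o = 0; o < outside; ++o) {
       int32_t acc[16] = {0};
       for (int r = 0; r < reduceAxis; ++r)
           for (int h = 0; h < 16; ++h)
               for (int l = 0; l < 4; ++l)
                   acc[h] += (int8_t)source[((o * reduceAxis + r) * 16 + h) * 4 + l];
       for (int h = 0; h < 16; ++h)
           kernelsum[o * 16 + h] = (float)acc[h];
   }
*/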
|
||||
|
||||
stp d14, d15, [sp, #-64]!
|
||||
stp d12, d13, [sp, #16]
|
||||
stp d10, d11, [sp, #32]
|
||||
stp d8, d9, [sp, #48]
|
||||
|
||||
movi v31.16b, #1
|
||||
|
||||
Loop:
|
||||
mov x5, x3
|
||||
SET_0 v16, v17, v18, v19
|
||||
SET_0 v20, v21, v22, v23
|
||||
SET_0 v24, v25, v26, v27
|
||||
|
||||
LU3:
|
||||
cmp x5, #3
|
||||
blt LU2
|
||||
ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x1], #64
|
||||
ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [x1], #64
|
||||
ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [x1], #64
|
||||
|
||||
// kernel sum
|
||||
.inst 0x4e8097f0 // sdot v16.4s, v31.16b, v0.16b
|
||||
.inst 0x4e8197f1 // sdot v17.4s, v31.16b, v1.16b
|
||||
.inst 0x4e8297f2 // sdot v18.4s, v31.16b, v2.16b
|
||||
.inst 0x4e8397f3 // sdot v19.4s, v31.16b, v3.16b
|
||||
|
||||
.inst 0x4e8497f4 // sdot v20.4s, v31.16b, v4.16b
|
||||
.inst 0x4e8597f5 // sdot v21.4s, v31.16b, v5.16b
|
||||
.inst 0x4e8697f6 // sdot v22.4s, v31.16b, v6.16b
|
||||
.inst 0x4e8797f7 // sdot v23.4s, v31.16b, v7.16b
|
||||
|
||||
.inst 0x4e8897f8 // sdot v24.4s, v31.16b, v8.16b
|
||||
.inst 0x4e8997f9 // sdot v25.4s, v31.16b, v9.16b
|
||||
.inst 0x4e8a97fa // sdot v26.4s, v31.16b, v10.16b
|
||||
.inst 0x4e8b97fb // sdot v27.4s, v31.16b, v11.16b
|
||||
|
||||
|
||||
subs x5, x5, #3
|
||||
beq LUEnd
|
||||
b LU3
|
||||
|
||||
|
||||
LU2:
|
||||
cmp x5, #2
|
||||
blt LU1
|
||||
ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x1], #64
|
||||
ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [x1], #64
|
||||
|
||||
// kernel sum
|
||||
.inst 0x4e8097f0 // sdot v16.4s, v31.16b, v0.16b
|
||||
.inst 0x4e8197f1 // sdot v17.4s, v31.16b, v1.16b
|
||||
.inst 0x4e8297f2 // sdot v18.4s, v31.16b, v2.16b
|
||||
.inst 0x4e8397f3 // sdot v19.4s, v31.16b, v3.16b
|
||||
|
||||
.inst 0x4e8497f4 // sdot v20.4s, v31.16b, v4.16b
|
||||
.inst 0x4e8597f5 // sdot v21.4s, v31.16b, v5.16b
|
||||
.inst 0x4e8697f6 // sdot v22.4s, v31.16b, v6.16b
|
||||
.inst 0x4e8797f7 // sdot v23.4s, v31.16b, v7.16b
|
||||
|
||||
subs x5, x5, #2
|
||||
beq LUEnd
|
||||
b LU2
|
||||
|
||||
LU1: // outside
|
||||
cbz x5, LUEnd
|
||||
ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x1], #64
|
||||
// kernel sum
|
||||
.inst 0x4e8097f0 // sdot v16.4s, v31.16b, v0.16b
|
||||
.inst 0x4e8197f1 // sdot v17.4s, v31.16b, v1.16b
|
||||
.inst 0x4e8297f2 // sdot v18.4s, v31.16b, v2.16b
|
||||
.inst 0x4e8397f3 // sdot v19.4s, v31.16b, v3.16b
|
||||
|
||||
LUEnd:
|
||||
add v16.4s, v16.4s, v20.4s
|
||||
add v17.4s, v17.4s, v21.4s
|
||||
add v18.4s, v18.4s, v22.4s
|
||||
add v19.4s, v19.4s, v23.4s
|
||||
add v16.4s, v16.4s, v24.4s
|
||||
add v17.4s, v17.4s, v25.4s
|
||||
add v18.4s, v18.4s, v26.4s
|
||||
add v19.4s, v19.4s, v27.4s
|
||||
scvtf v16.4s, v16.4s
|
||||
scvtf v17.4s, v17.4s
|
||||
scvtf v18.4s, v18.4s
|
||||
scvtf v19.4s, v19.4s
|
||||
st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64
|
||||
|
||||
subs x2, x2, #1 // outside--
|
||||
bne Loop
|
||||
|
||||
|
||||
End:
|
||||
ldp d8, d9, [sp, #48]
|
||||
ldp d10, d11, [sp, #32]
|
||||
ldp d12, d13, [sp, #16]
|
||||
ldp d14, d15, [sp], #64
|
||||
ret
|
||||
|
||||
#endif
|
|
@ -448,6 +448,239 @@ static bool MNNAsyLocalQuantInfo_EP12_FP32(float* scale, float* bias, float* qsc
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool MNNAsyLocalQuantInfo_EP16_FP32(float* scale, float* bias, float* qscale, float* qbias, const float* srcMin, const float* srcMax, const size_t* info) {
|
||||
auto blockNum = info[0];
|
||||
auto EP = info[1];
|
||||
auto DST_XUNIT = info[3];
|
||||
if (DST_XUNIT != 16) {
|
||||
MNN_ERROR("Function called error\n");
|
||||
return false;
|
||||
}
|
||||
auto stride = EP * blockNum;
|
||||
// dequant scale/bias : [EU, blockNum, step]
|
||||
// quant scale/bias: [blockNum, EP]
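// Scalar view of the per-element math vectorized below (a readability sketch, not an extra code path):
//   float range = max - min;
//   if (range <= 0) { qscale = 0; qbias = 0; scale = 0; bias = max; }
//   else {
//       qscale = 255.f / range;                          // float -> int8 (stored with -128 offset)
//       qbias  = round(-(255.f * min / range + 128.f));  // vrndaq: round to nearest, ties away from zero
//       scale  = range / 255.f;                          // int8 -> float
//       bias   = range * 128.f / 255.f + min;
//   }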
|
||||
auto minfloat = vdupq_n_f32(1e-6);
|
||||
auto _255f = vdupq_n_f32(255.f);
|
||||
auto _128f = vdupq_n_f32(128.f);
|
||||
auto _0f = vdupq_n_f32(0.f);
|
||||
for (int k = 0; k < blockNum; ++k) {
|
||||
auto qind = k * EP;
|
||||
auto realDstCount = EP;
|
||||
auto scalePtr = scale + k * ALIMIN(EP, DST_XUNIT);
|
||||
auto biasPtr = bias + k * ALIMIN(EP, DST_XUNIT);
|
||||
while (realDstCount > DST_XUNIT - 1) {
|
||||
auto max0 = vld1q_f32(srcMax + qind);
|
||||
auto max1 = vld1q_f32(srcMax + qind + 4);
|
||||
auto max2 = vld1q_f32(srcMax + qind + 8);
|
||||
auto max3 = vld1q_f32(srcMax + qind + 12);
|
||||
auto min0 = vld1q_f32(srcMin + qind);
|
||||
auto min1 = vld1q_f32(srcMin + qind + 4);
|
||||
auto min2 = vld1q_f32(srcMin + qind + 8);
|
||||
auto min3 = vld1q_f32(srcMin + qind + 12);
|
||||
|
||||
auto diff0 = vsubq_f32(max0, min0);
|
||||
auto diff1 = vsubq_f32(max1, min1);
|
||||
auto diff2 = vsubq_f32(max2, min2);
|
||||
auto diff3 = vsubq_f32(max3, min3);
|
||||
|
||||
auto qscaleV0 = vdivq_f32(_255f, diff0);
|
||||
auto qscaleV1 = vdivq_f32(_255f, diff1);
|
||||
auto qscaleV2 = vdivq_f32(_255f, diff2);
|
||||
auto qscaleV3 = vdivq_f32(_255f, diff3);
|
||||
auto scaleV0 = vdivq_f32(diff0, _255f);
|
||||
auto scaleV1 = vdivq_f32(diff1, _255f);
|
||||
auto scaleV2 = vdivq_f32(diff2, _255f);
|
||||
auto scaleV3 = vdivq_f32(diff3, _255f);
|
||||
|
||||
auto qbiasV0 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min0), diff0), _128f));
|
||||
auto qbiasV1 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min1), diff1), _128f));
|
||||
auto qbiasV2 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min2), diff2), _128f));
|
||||
auto qbiasV3 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min3), diff3), _128f));
|
||||
auto biasV0 = vaddq_f32(vdivq_f32(vmulq_f32(diff0, _128f), _255f), min0);
|
||||
auto biasV1 = vaddq_f32(vdivq_f32(vmulq_f32(diff1, _128f), _255f), min1);
|
||||
auto biasV2 = vaddq_f32(vdivq_f32(vmulq_f32(diff2, _128f), _255f), min2);
|
||||
auto biasV3 = vaddq_f32(vdivq_f32(vmulq_f32(diff3, _128f), _255f), min3);
|
||||
|
||||
auto _0bic = vclezq_f32(diff0);
|
||||
auto _1bic = vclezq_f32(diff1);
|
||||
auto _2bic = vclezq_f32(diff2);
|
||||
auto _3bic = vclezq_f32(diff3);
|
||||
|
||||
qscaleV0 = vbslq_f32(_0bic, _0f, qscaleV0);
|
||||
qscaleV1 = vbslq_f32(_1bic, _0f, qscaleV1);
|
||||
qscaleV2 = vbslq_f32(_2bic, _0f, qscaleV2);
|
||||
qscaleV3 = vbslq_f32(_3bic, _0f, qscaleV3);
|
||||
|
||||
qbiasV0 = vrndaq_f32(vbslq_f32(_0bic, _0f, qbiasV0));
|
||||
qbiasV1 = vrndaq_f32(vbslq_f32(_1bic, _0f, qbiasV1));
|
||||
qbiasV2 = vrndaq_f32(vbslq_f32(_2bic, _0f, qbiasV2));
|
||||
qbiasV3 = vrndaq_f32(vbslq_f32(_3bic, _0f, qbiasV3));
|
||||
|
||||
scaleV0 = vbslq_f32(_0bic, _0f, scaleV0);
|
||||
scaleV1 = vbslq_f32(_1bic, _0f, scaleV1);
|
||||
scaleV2 = vbslq_f32(_2bic, _0f, scaleV2);
|
||||
scaleV3 = vbslq_f32(_3bic, _0f, scaleV3);
|
||||
|
||||
biasV0 = vbslq_f32(_0bic, max0, biasV0);
|
||||
biasV1 = vbslq_f32(_1bic, max1, biasV1);
|
||||
biasV2 = vbslq_f32(_2bic, max2, biasV2);
|
||||
biasV3 = vbslq_f32(_3bic, max3, biasV3);
|
||||
|
||||
vst1q_f32(qscale + qind, qscaleV0);
|
||||
vst1q_f32(qscale + qind + 4, qscaleV1);
|
||||
vst1q_f32(qscale + qind + 8, qscaleV2);
|
||||
vst1q_f32(qscale + qind + 12, qscaleV3);
|
||||
|
||||
vst1q_f32(qbias + qind, qbiasV0);
|
||||
vst1q_f32(qbias + qind + 4, qbiasV1);
|
||||
vst1q_f32(qbias + qind + 8, qbiasV2);
|
||||
vst1q_f32(qbias + qind + 12, qbiasV3);
|
||||
|
||||
vst1q_f32(scalePtr, scaleV0);
|
||||
vst1q_f32(scalePtr + 4, scaleV1);
|
||||
vst1q_f32(scalePtr + 8, scaleV2);
|
||||
vst1q_f32(scalePtr + 12, scaleV3);
|
||||
|
||||
vst1q_f32(biasPtr, biasV0);
|
||||
vst1q_f32(biasPtr + 4, biasV1);
|
||||
vst1q_f32(biasPtr + 8, biasV2);
|
||||
vst1q_f32(biasPtr + 12, biasV3);
|
||||
|
||||
realDstCount -= DST_XUNIT;
|
||||
qind += DST_XUNIT;
|
||||
scalePtr += (blockNum * DST_XUNIT);
|
||||
biasPtr += (blockNum * DST_XUNIT);
|
||||
}
|
||||
if (realDstCount == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto remainE = realDstCount;
|
||||
auto stride0 = remainE * blockNum;
|
||||
scalePtr = scale + (EP / DST_XUNIT) * blockNum * DST_XUNIT + k * remainE;
|
||||
biasPtr = bias + (EP / DST_XUNIT) * blockNum * DST_XUNIT + k * remainE;
|
||||
if (realDstCount > 7) {
|
||||
auto max0 = vld1q_f32(srcMax + qind);
|
||||
auto max1 = vld1q_f32(srcMax + qind + 4);
|
||||
auto min0 = vld1q_f32(srcMin + qind);
|
||||
auto min1 = vld1q_f32(srcMin + qind + 4);
|
||||
auto diff0 = vsubq_f32(max0, min0);
|
||||
auto diff1 = vsubq_f32(max1, min1);
|
||||
|
||||
auto qscaleV0 = vdivq_f32(_255f, diff0);
|
||||
auto qscaleV1 = vdivq_f32(_255f, diff1);
|
||||
auto scaleV0 = vdivq_f32(diff0, _255f);
|
||||
auto scaleV1 = vdivq_f32(diff1, _255f);
|
||||
|
||||
auto qbiasV0 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min0), diff0), _128f));
|
||||
auto qbiasV1 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min1), diff1), _128f));
|
||||
auto biasV0 = vaddq_f32(vdivq_f32(vmulq_f32(diff0, _128f), _255f), min0);
|
||||
auto biasV1 = vaddq_f32(vdivq_f32(vmulq_f32(diff1, _128f), _255f), min1);
|
||||
|
||||
auto _0bic = vclezq_f32(diff0);
|
||||
auto _1bic = vclezq_f32(diff1);
|
||||
|
||||
qscaleV0 = vbslq_f32(_0bic, _0f, qscaleV0);
|
||||
qscaleV1 = vbslq_f32(_1bic, _0f, qscaleV1);
|
||||
|
||||
qbiasV0 = vrndaq_f32(vbslq_f32(_0bic, _0f, qbiasV0));
|
||||
qbiasV1 = vrndaq_f32(vbslq_f32(_1bic, _0f, qbiasV1));
|
||||
|
||||
scaleV0 = vbslq_f32(_0bic, _0f, scaleV0);
|
||||
scaleV1 = vbslq_f32(_1bic, _0f, scaleV1);
|
||||
|
||||
biasV0 = vbslq_f32(_0bic, max0, biasV0);
|
||||
biasV1 = vbslq_f32(_1bic, max1, biasV1);
|
||||
|
||||
vst1q_f32(qscale + qind, qscaleV0);
|
||||
vst1q_f32(qscale + qind + 4, qscaleV1);
|
||||
|
||||
vst1q_f32(qbias + qind, qbiasV0);
|
||||
vst1q_f32(qbias + qind + 4, qbiasV1);
|
||||
|
||||
vst1q_f32(scalePtr, scaleV0);
|
||||
vst1q_f32(scalePtr + 4, scaleV1);
|
||||
|
||||
vst1q_f32(biasPtr, biasV0);
|
||||
vst1q_f32(biasPtr + 4, biasV1);
|
||||
|
||||
realDstCount -= 8;
|
||||
qind += 8;
|
||||
scalePtr += 8;
|
||||
biasPtr += 8;
|
||||
}
|
||||
if (realDstCount > 3) {
|
||||
auto max0 = vld1q_f32(srcMax + qind);
|
||||
auto min0 = vld1q_f32(srcMin + qind);
|
||||
auto diff0 = vsubq_f32(max0, min0);
|
||||
|
||||
auto qscaleV0 = vdivq_f32(_255f, diff0);
|
||||
auto scaleV0 = vdivq_f32(diff0, _255f);
|
||||
|
||||
auto qbiasV0 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min0), diff0), _128f));
|
||||
auto biasV0 = vaddq_f32(vdivq_f32(vmulq_f32(diff0, _128f), _255f), min0);
|
||||
|
||||
auto _0bic = vclezq_f32(diff0);
|
||||
|
||||
qscaleV0 = vbslq_f32(_0bic, _0f, qscaleV0);
|
||||
|
||||
qbiasV0 = vrndaq_f32(vbslq_f32(_0bic, _0f, qbiasV0));
|
||||
|
||||
scaleV0 = vbslq_f32(_0bic, _0f, scaleV0);
|
||||
|
||||
biasV0 = vbslq_f32(_0bic, max0, biasV0);
|
||||
|
||||
vst1q_f32(qscale + qind, qscaleV0);
|
||||
|
||||
vst1q_f32(qbias + qind, qbiasV0);
|
||||
|
||||
vst1q_f32(scalePtr, scaleV0);
|
||||
|
||||
vst1q_f32(biasPtr, biasV0);
|
||||
|
||||
realDstCount -= 4;
|
||||
qind += 4;
|
||||
scalePtr += 4;
|
||||
biasPtr += 4;
|
||||
}
|
||||
while (realDstCount > 0) {
|
||||
auto max0 = vld1q_dup_f32(srcMax + qind);
|
||||
auto min0 = vld1q_dup_f32(srcMin + qind);
|
||||
auto diff0 = vsubq_f32(max0, min0);
|
||||
|
||||
auto qscaleV0 = vdivq_f32(_255f, diff0);
|
||||
auto scaleV0 = vdivq_f32(diff0, _255f);
|
||||
|
||||
auto qbiasV0 = vnegq_f32(vaddq_f32(vdivq_f32(vmulq_f32(_255f, min0), diff0), _128f));
|
||||
auto biasV0 = vaddq_f32(vdivq_f32(vmulq_f32(diff0, _128f), _255f), min0);
|
||||
|
||||
auto _0bic = vclezq_f32(diff0);
|
||||
|
||||
qscaleV0 = vbslq_f32(_0bic, _0f, qscaleV0);
|
||||
|
||||
qbiasV0 = vrndaq_f32(vbslq_f32(_0bic, _0f, qbiasV0));
|
||||
|
||||
scaleV0 = vbslq_f32(_0bic, _0f, scaleV0);
|
||||
|
||||
biasV0 = vbslq_f32(_0bic, max0, biasV0);
|
||||
|
||||
vst1q_lane_f32(qscale + qind, qscaleV0, 0);
|
||||
|
||||
vst1q_lane_f32(qbias + qind, qbiasV0, 0);
|
||||
|
||||
vst1q_lane_f32(scalePtr, scaleV0, 0);
|
||||
|
||||
vst1q_lane_f32(biasPtr, biasV0, 0);
|
||||
|
||||
realDstCount -= 1;
|
||||
qind += 1;
|
||||
scalePtr += 1;
|
||||
biasPtr += 1;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
|
|
|
@ -90,8 +90,12 @@ sub \rg0, \rg0, \rg1
|
|||
.endm
|
||||
|
||||
.macro REVERT_WEIGHT_KERNEL_SUM rg0, rg1, rg2, rg3
|
||||
mul \rg1, \rg2, \rg3 // blocknum * ocUp4 * sizeof(float)
|
||||
sub \rg0, \rg0, \rg1, LSL #4 // revert weight kernel sum
|
||||
// y=UP_DIV(ocDiv4,(hp/pack))
|
||||
add \rg1, \rg3, #1
|
||||
lsr \rg1, \rg1, #1
|
||||
// blockNum * y * (hp * sizeof(float))
|
||||
mul \rg1, \rg2, \rg1
|
||||
sub \rg0, \rg0, \rg1, LSL #5 // revert weight kernel sum
|
||||
.endm
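/*
 Readability note (an interpretation of the arithmetic above, not authoritative):
 with pack = 4, hp = 8 and rg2 = blockNum, rg3 = ocDiv4, the two subtractions
 amount to

     ptr -= blockNum * ocDiv4 * pack * sizeof(float);   // LSL #4 == *16
     size_t y = (ocDiv4 + 1) / 2;                        // UP_DIV(ocDiv4, hp/pack)
     ptr -= blockNum * y * hp * sizeof(float);           // LSL #5 == *32

 i.e. rg0 is stepped back over everything consumed per block and per oc tile.
*/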
|
||||
|
||||
asm_function MNNGemmInt8AddBiasScale_ARMV82_w4_Unit
|
||||
|
@ -398,9 +402,11 @@ L4_TILE12_BLOCKNUM:
|
|||
|
||||
|
||||
L4_Tile12Quan:
|
||||
ld1 {v0.4s}, [x2], #16 // scale
|
||||
ld1 {v0.4s}, [x2] // scale
|
||||
add x2, x2, #32
|
||||
ld1 {v2.4s, v3.4s, v4.4s}, [x8], #48 // x kernel sum
|
||||
ld1 {v5.4s}, [x2], #16 // weight quan zeropoint
|
||||
ld1 {v5.4s}, [x2] // weight quan zeropoint
|
||||
add x2, x2, #32
|
||||
Int32ToFloat v8, v9, v10, v11
|
||||
Int32ToFloat v12, v13, v14, v15
|
||||
Int32ToFloat v16, v17, v18, v19
|
||||
|
@ -430,7 +436,7 @@ L4_TILE12_BLOCKNUM:
|
|||
|
||||
cbz x27, L4_TILE12_ADD_DSTV
|
||||
ld1 {v0.4s, v1.4s, v2.4s}, [x27], #48 // input dequant bias
|
||||
ld1 {v3.4s}, [x28], #16 // weight kernel sum
|
||||
ld1 {v3.4s}, [x28] // weight kernel sum
|
||||
MLA_WEIGHTZERO v8, v0, v3, 0
|
||||
MLA_WEIGHTZERO v9, v0, v3, 1
|
||||
MLA_WEIGHTZERO v10, v0, v3, 2
|
||||
|
@ -443,6 +449,7 @@ L4_TILE12_BLOCKNUM:
|
|||
MLA_WEIGHTZERO v17, v2, v3, 1
|
||||
MLA_WEIGHTZERO v18, v2, v3, 2
|
||||
MLA_WEIGHTZERO v19, v2, v3, 3
|
||||
add x28, x28, #32
|
||||
|
||||
L4_TILE12_ADD_DSTV:
|
||||
cbz x19, L4_TILE12_ACCUM_BUFFER // x19=0: first block, do not add previous block result
|
||||
|
@ -685,9 +692,11 @@ L4_TILE8_BLOCKNUM:
|
|||
bne L4LoopSz_TILE_8
|
||||
|
||||
L4Tile8Quan:
|
||||
ld1 {v0.4s}, [x12], #16 // scale
|
||||
ld1 {v0.4s}, [x12] // scale
|
||||
add x12, x12, #32
|
||||
ld1 {v2.4s, v3.4s}, [x8], x22 // x kernel sum
|
||||
ld1 {v24.4s}, [x12], #16 // weight quan zeropoint
|
||||
ld1 {v24.4s}, [x12] // weight quan zeropoint
|
||||
add x12, x12, #32
|
||||
Int32ToFloat v8, v9, v10, v11
|
||||
Int32ToFloat v12, v13, v14, v15
|
||||
MUL_SCALE v0, v8, v9, v10, v11
|
||||
|
@ -709,7 +718,7 @@ L4_TILE8_BLOCKNUM:
|
|||
|
||||
cbz x27, L4_TILE8_ADD_DSTV
|
||||
ld1 {v2.4s, v3.4s}, [x27], x25
|
||||
ld1 {v24.4s}, [x28], #16
|
||||
ld1 {v24.4s}, [x28]
|
||||
MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3
|
||||
MLA_WEIGHTZERO v9, v2, v24, 1 // tile:1, oc:0-3
|
||||
MLA_WEIGHTZERO v10, v2, v24, 2 // tile:2, oc:0-3
|
||||
|
@ -718,6 +727,7 @@ L4_TILE8_BLOCKNUM:
|
|||
MLA_WEIGHTZERO v13, v3, v24, 1 // tile:5, oc:0-3
|
||||
MLA_WEIGHTZERO v14, v3, v24, 2 // tile:6, oc:0-3
|
||||
MLA_WEIGHTZERO v15, v3, v24, 3 // tile:7, oc:0-3
|
||||
add x28, x28, #32
|
||||
|
||||
L4_TILE8_ADD_DSTV:
|
||||
cbz x19, TILE8_L4_ACCUM_BUFFER
|
||||
|
@ -905,9 +915,11 @@ L4_TILE4_BLOCKNUM:
|
|||
bne L4LoopSz_TILE_4
|
||||
|
||||
L4Tile4Quan:
|
||||
ld1 {v0.4s}, [x12], #16 // scale
|
||||
ld1 {v0.4s}, [x12] // scale
|
||||
add x12, x12, #32
|
||||
ld1 {v2.4s}, [x8], x22 // x kernel sum
|
||||
ld1 {v24.4s}, [x12], #16 // weight quan zeropoint
|
||||
ld1 {v24.4s}, [x12] // weight quan zeropoint
|
||||
add x12, x12, #32
|
||||
Int32ToFloat v8, v9, v10, v11
|
||||
MUL_SCALE v0, v8, v9, v10, v11
|
||||
|
||||
|
@ -922,11 +934,12 @@ L4_TILE4_BLOCKNUM:
|
|||
|
||||
cbz x27, L4_TILE4_ADD_DSTV
|
||||
ld1 {v2.4s}, [x27], x25
|
||||
ld1 {v24.4s}, [x28], #16
|
||||
ld1 {v24.4s}, [x28]
|
||||
MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3
|
||||
MLA_WEIGHTZERO v9, v2, v24, 1 // tile:1, oc:0-3
|
||||
MLA_WEIGHTZERO v10, v2, v24, 2 // tile:2, oc:0-3
|
||||
MLA_WEIGHTZERO v11, v2, v24, 3 // tile:3, oc:0-3
|
||||
add x28, x28, #32
|
||||
|
||||
L4_TILE4_ADD_DSTV:
|
||||
cbz x19, TILE4_L4_ACCUM_BUFFER
|
||||
|
@ -1115,9 +1128,11 @@ L4_TILE1_BLOCKNUM:
|
|||
bne L4LoopSz_TILE_1_lu1
|
||||
|
||||
L4Tile1Quan:
|
||||
ld1 {v0.4s}, [x12], #16 // scale
|
||||
ld1 {v0.4s}, [x12] // scale
|
||||
add x12, x12, #32
|
||||
ld1 {v2.s}[0], [x8], x22 // x kernel sum
|
||||
ld1 {v24.4s}, [x12], #16 // weight quan zeropoint
|
||||
ld1 {v24.4s}, [x12] // weight quan zeropoint
|
||||
add x12, x12, #32
|
||||
scvtf v8.4s, v8.4s
|
||||
fmul v8.4s, v8.4s, v0.4s
|
||||
|
||||
|
@ -1129,8 +1144,9 @@ L4_TILE1_BLOCKNUM:
|
|||
|
||||
cbz x27, L4_TILE1_ADD_DSTV
|
||||
ld1 {v2.s}[0], [x27], x25
|
||||
ld1 {v24.4s}, [x28], #16
|
||||
ld1 {v24.4s}, [x28]
|
||||
MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3
|
||||
add x28, x28, #32
|
||||
|
||||
L4_TILE1_ADD_DSTV:
|
||||
cbz x19, L4_TILE1_L8_ACCUM_BUFFER
|
||||
|
|
|
@ -81,8 +81,12 @@ sub \rg0, \rg0, \rg1
|
|||
.endm
|
||||
|
||||
.macro REVERT_WEIGHT_KERNEL_SUM rg0, rg1, rg2, rg3
|
||||
mul \rg1, \rg2, \rg3 // blocknum * ocUp4 * sizeof(float)
|
||||
sub \rg0, \rg0, \rg1, LSL #4 // revert weight kernel sum
|
||||
// y=UP_DIV(ocDiv4,(hp/pack))
|
||||
add \rg1, \rg3, #1
|
||||
lsr \rg1, \rg1, #1
|
||||
// blockNum * y * (hp * sizeof(float))
|
||||
mul \rg1, \rg2, \rg1
|
||||
sub \rg0, \rg0, \rg1, LSL #5 // revert weight kernel sum
|
||||
.endm
|
||||
|
||||
asm_function MNNGemmInt8AddBiasScale_ARMV86_w4_Unit
|
||||
|
@ -424,14 +428,16 @@ L4_LoopSz_TILE_10:
|
|||
scvtf v9.4s, v9.4s
|
||||
|
||||
L4_Tile10Quan:
|
||||
ld1 {v20.4s}, [x2], #16 // weight scale
|
||||
ld1 {v20.4s}, [x2] // weight scale
|
||||
add x2, x2, #32
|
||||
ld1 {v22.4s, v23.4s}, [x8], #32 // x kernel sum
|
||||
ld1 {v24.d}[0], [x8], #8
|
||||
ld1 {v25.4s}, [x2], #16 // weight quan zeropoint
|
||||
ld1 {v25.4s}, [x2] // weight quan zeropoint
|
||||
MUL_SCALE v20, v0, v1, v2, v3
|
||||
MUL_SCALE v20, v4, v5, v6, v7
|
||||
fmul v8.4s, v8.4s, v20.4s
|
||||
fmul v9.4s, v9.4s, v20.4s
|
||||
add x2, x2, #32
|
||||
|
||||
ld1 {v27.4s, v28.4s}, [x23], #32 // input dequant scale
|
||||
ld1 {v29.d}[0], [x23], x24
|
||||
|
@ -455,7 +461,7 @@ L4_Tile10Quan:
|
|||
cbz x27, L4_TILE10_ADD_DSTV
|
||||
ld1 {v22.4s, v23.4s}, [x27], #32 // input dequant bias
|
||||
ld1 {v24.2s}, [x27], #8
|
||||
ld1 {v25.4s}, [x28], #16 // weight kernel sum
|
||||
ld1 {v25.4s}, [x28] // weight kernel sum
|
||||
MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3
|
||||
MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3
|
||||
MLA_WEIGHTZERO v2, v22, v25, 2 // tile:2, oc:0-3
|
||||
|
@ -466,6 +472,7 @@ L4_Tile10Quan:
|
|||
MLA_WEIGHTZERO v7, v23, v25, 3 // tile:7, oc:0-3
|
||||
MLA_WEIGHTZERO v8, v24, v25, 0 // tile:8, oc:0-3
|
||||
MLA_WEIGHTZERO v9, v24, v25, 1 // tile:9, oc:0-3
|
||||
add x28, x28, #32
|
||||
|
||||
L4_TILE10_ADD_DSTV:
|
||||
cbz x19, L4_TILE10_TEMP_BUFFER
|
||||
|
@ -756,9 +763,11 @@ L4_LoopSz_TILE_8:
|
|||
Int32ToFloat v4, v5, v6, v7
|
||||
|
||||
L4_Tile8Quan:
|
||||
ld1 {v20.4s}, [x12], #16 // scale
|
||||
ld1 {v20.4s}, [x12] // scale
|
||||
add x12, x12, #32
|
||||
ld1 {v22.4s, v23.4s}, [x8] // x kernel sum
|
||||
ld1 {v25.4s}, [x12], #16 // weight quan zeropoint
|
||||
ld1 {v25.4s}, [x12] // weight quan zeropoint
|
||||
add x12, x12, #32
|
||||
add x8, x8, x22, LSR #1
|
||||
MUL_SCALE v20, v0, v1, v2, v3
|
||||
MUL_SCALE v20, v4, v5, v6, v7
|
||||
|
@ -779,7 +788,7 @@ L4_Tile8Quan:
|
|||
|
||||
cbz x27, L4_TILE8_ADD_DSTV
|
||||
ld1 {v22.4s, v23.4s}, [x27], x25 // input dequant bias
|
||||
ld1 {v25.4s}, [x28], #16 // weight kernel sum
|
||||
ld1 {v25.4s}, [x28] // weight kernel sum
|
||||
MLA_WEIGHTZERO v0, v22, v25, 0
|
||||
MLA_WEIGHTZERO v1, v22, v25, 1
|
||||
MLA_WEIGHTZERO v2, v22, v25, 2
|
||||
|
@ -788,6 +797,7 @@ L4_Tile8Quan:
|
|||
MLA_WEIGHTZERO v5, v23, v25, 1
|
||||
MLA_WEIGHTZERO v6, v23, v25, 2
|
||||
MLA_WEIGHTZERO v7, v23, v25, 3
|
||||
add x28, x28, #32
|
||||
|
||||
L4_TILE8_ADD_DSTV:
|
||||
cbz x19, L4_TILE8_TEMP_BUFFER
|
||||
|
@ -991,9 +1001,11 @@ L4_LoopSz_TILE_4:
|
|||
Int32ToFloat v0, v1, v2, v3
|
||||
|
||||
L4_Tile4Quan:
|
||||
ld1 {v20.4s}, [x12], #16 // scale
|
||||
ld1 {v20.4s}, [x12] // scale
|
||||
add x12, x12, #32
|
||||
ld1 {v22.4s}, [x8] // x kernel sum
|
||||
ld1 {v25.4s}, [x12], #16 // weight quan zeropoint
|
||||
ld1 {v25.4s}, [x12] // weight quan zeropoint
|
||||
add x12, x12, #32
|
||||
MUL_SCALE v20, v0, v1, v2, v3
|
||||
add x8, x8, x22, LSR #1
|
||||
|
||||
|
@ -1008,11 +1020,12 @@ L4_Tile4Quan:
|
|||
|
||||
cbz x27, L4_TILE4_ADD_DSTV
|
||||
ld1 {v22.4s}, [x27], x25 // input dequant bias
|
||||
ld1 {v25.4s}, [x28], #16 // weight kernel sum
|
||||
ld1 {v25.4s}, [x28] // weight kernel sum
|
||||
MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3
|
||||
MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3
|
||||
MLA_WEIGHTZERO v2, v22, v25, 2 // tile:2, oc:0-3
|
||||
MLA_WEIGHTZERO v3, v22, v25, 3 // tile:3, oc:0-3
|
||||
add x28, x28, #32
|
||||
|
||||
L4_TILE4_ADD_DSTV:
|
||||
cbz x19, L4_TILE4_ACCUM_BUFFER
|
||||
|
@ -1180,15 +1193,17 @@ L4_LoopSz_TILE_2:
|
|||
scvtf v1.4s, v1.4s
|
||||
|
||||
L4_Tile2Quan:
|
||||
ld1 {v20.4s}, [x12], #16 // scale
|
||||
ld1 {v20.4s}, [x12] // scale
|
||||
ld1 {v22.d}[0], [x8] // x kernel sum
|
||||
ld1 {v25.4s}, [x12], #16 // weight quan zeropoint
|
||||
add x12, x12, #32
|
||||
ld1 {v25.4s}, [x12] // weight quan zeropoint
|
||||
fmul v0.4s, v0.4s, v20.4s
|
||||
fmul v1.4s, v1.4s, v20.4s
|
||||
add x8, x8, x22, LSR #1
|
||||
ld1 {v27.d}[0], [x23], x25
|
||||
fmul v0.4s, v0.4s, v27.s[0]
|
||||
fmul v1.4s, v1.4s, v27.s[1]
|
||||
add x12, x12, #32
|
||||
|
||||
L4_TILE2_MLA:
|
||||
MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3
|
||||
|
@ -1196,9 +1211,10 @@ L4_Tile2Quan:
|
|||
|
||||
cbz x27, L4_TILE2_ADD_DSTV
|
||||
ld1 {v22.2s}, [x27], x25 // input dequant bias
|
||||
ld1 {v25.4s}, [x28], #16 // weight kernel sum
|
||||
ld1 {v25.4s}, [x28] // weight kernel sum
|
||||
MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3
|
||||
MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3
|
||||
add x28, x28, #32
|
||||
|
||||
L4_TILE2_ADD_DSTV:
|
||||
cbz x19, L4_TILE2_ACCUM_BUFFER
|
||||
|
@ -1460,11 +1476,13 @@ L4_LoopSzEnd_TILE_1:
|
|||
scvtf v25.4s, v25.4s
|
||||
|
||||
L4_Tile1Quan:
|
||||
ld1 {v0.4s}, [x12], #16 // scale
|
||||
ld1 {v0.4s}, [x12] // scale
|
||||
add x12, x12, #32 // hp*sizeof(float)
|
||||
ld1 {v6.s}[0], [x8] // x kernel sum
|
||||
ld1 {v8.4s}, [x12], #16 // weight quan zeropoint
|
||||
ld1 {v8.4s}, [x12] // weight quan zeropoint
|
||||
fmul v25.4s, v25.4s, v0.4s
|
||||
add x8, x8, x22, LSR #1
|
||||
add x12, x12, #32
|
||||
|
||||
ld1 {v10.s}[0], [x23], x25
|
||||
fmul v25.4s, v25.4s, v10.s[0]
|
||||
|
@ -1474,8 +1492,9 @@ L4_Tile1Quan:
|
|||
|
||||
cbz x27, L4_TILE1_ADD_DSTV
|
||||
ld1 {v6.s}[0], [x27], x25 // input dequant bias
|
||||
ld1 {v8.4s}, [x28], #16 // weight kernel sum
|
||||
ld1 {v8.4s}, [x28] // weight kernel sum
|
||||
MLA_WEIGHTZERO v25, v6, v8, 0 // tile:0, oc:0-3
|
||||
add x28, x28, #32
|
||||
|
||||
L4_TILE1_ADD_DSTV:
|
||||
cbz x19, L4_TILE1_ACCUM_BUFFER
|
||||
|
|
|
@ -0,0 +1,339 @@
|
|||
//
|
||||
// MNNGeneralIm2col_Fp16Sme2.S
|
||||
// MNN
|
||||
//
|
||||
// Created by MNN on 2024/12/25.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
|
||||
#ifdef __aarch64__
|
||||
|
||||
#include "MNNAsmGlobal.h"
|
||||
.text
|
||||
.align 5
|
||||
|
||||
//void MNNGeneralIm2col_Fp16Sme2(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack)
|
||||
asm_function MNNGeneralIm2col_Fp16Sme2
|
||||
|
||||
// x0:destOrigin, x1:sourceGroup, x2:info, x3:el, x4:LP, x5:pack
|
||||
stp d14, d15, [sp, #(-16 * 5)]!
|
||||
stp d12, d13, [sp, #(16 * 1)]
|
||||
stp d10, d11, [sp, #(16 * 2)]
|
||||
stp d8, d9, [sp, #(16 * 3)]
|
||||
stp x19, x20, [sp, #(16 * 4)]
|
||||
|
||||
// load el info
|
||||
ldr w6, [x2, #0] // number
|
||||
ldr w7, [x2, #4] // eReal
|
||||
ldr w15, [x2, #8] // eDest (< EP)
|
||||
ldr w9, [x2, #12] // offset (stride)
|
||||
ldr x14, [x1, #0] // src start
|
||||
lsl x9, x9, #4 // pack*offset*sizeof(float16_t)
|
||||
// stride
|
||||
lsl x19, x15, #3 // eDest*LP*sizeof(float16_t)
|
||||
lsl x7, x7, #4 // eReal*pack*sizeof(float16_t)
|
||||
mov x20, #3 // Sme2,LP=4
|
||||
|
||||
LoopNum:
|
||||
|
||||
ldr w10, [x3], #4 // e
|
||||
ldr w11, [x3], #4 // l
|
||||
ldr w12, [x3], #4 // eOffset
|
||||
ldr w13, [x3], #4 // lOffset
|
||||
// dst address: x2
|
||||
and x2, x13, x20 // lR
|
||||
sub x13, x13, x2 // lOffset-lR
|
||||
mul x13, x13, x15 // (lOffset-lR)*(eDest)
|
||||
add x13, x13, x2 // (lOffset-lR)*(eDest)+lR
|
||||
add x13, x13, x12, LSL #2 // + eoffset*lp
|
||||
add x2, x0, x13, LSL #1 // *sizeof(float16_t)
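/*
 The destination address computed just above corresponds to a [l/LP, eDest, LP]
 layout (LP = 4 here); the index arithmetic written out in C, as a sketch:

   int lR  = lOffset % LP;
   int idx = (lOffset - lR) * eDest   // full LP rows already laid out
           + eOffset * LP             // column offset inside the row group
           + lR;                      // remainder within the LP pack
   dst = destOrigin + idx;            // in elements, scaled by sizeof(float16_t)
*/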
|
||||
|
||||
lsl x8, x19, #1 // 2*(eDest*LP*sizeof(float16_t))
|
||||
cmp x11, #8
|
||||
blt LoopL4
|
||||
|
||||
LoopL8:
|
||||
mov x5, x2
|
||||
mov x4, x14
|
||||
mov x13, x10
|
||||
add x12, x2, x19 // eDest*LP*sizeof(float16_t)
|
||||
|
||||
cmp x13, #16
|
||||
blt LoopL8E12
|
||||
|
||||
|
||||
sub x8, x8, #64
|
||||
LoopL8E16:
|
||||
sub x13, x13, #16
|
||||
ld1 {v0.8h}, [x14], x9
|
||||
ld1 {v1.8h}, [x14], x9
|
||||
ld1 {v2.8h}, [x14], x9
|
||||
ld1 {v3.8h}, [x14], x9
|
||||
ld1 {v4.8h}, [x14], x9
|
||||
ld1 {v5.8h}, [x14], x9
|
||||
ld1 {v6.8h}, [x14], x9
|
||||
ld1 {v7.8h}, [x14], x9
|
||||
ld1 {v8.8h}, [x14], x9
|
||||
ld1 {v9.8h}, [x14], x9
|
||||
ld1 {v10.8h}, [x14], x9
|
||||
ld1 {v11.8h}, [x14], x9
|
||||
ld1 {v12.8h}, [x14], x9
|
||||
ld1 {v13.8h}, [x14], x9
|
||||
ld1 {v14.8h}, [x14], x9
|
||||
ld1 {v15.8h}, [x14], x9
|
||||
zip1 v16.2d, v0.2d, v1.2d
|
||||
zip1 v17.2d, v2.2d, v3.2d
|
||||
zip1 v18.2d, v4.2d, v5.2d
|
||||
zip1 v19.2d, v6.2d, v7.2d
|
||||
zip1 v20.2d, v8.2d, v9.2d
|
||||
zip1 v21.2d, v10.2d, v11.2d
|
||||
zip1 v22.2d, v12.2d, v13.2d
|
||||
zip1 v23.2d, v14.2d, v15.2d
|
||||
|
||||
|
||||
zip2 v24.2d, v0.2d, v1.2d
|
||||
zip2 v25.2d, v2.2d, v3.2d
|
||||
zip2 v26.2d, v4.2d, v5.2d
|
||||
zip2 v27.2d, v6.2d, v7.2d
|
||||
zip2 v28.2d, v8.2d, v9.2d
|
||||
zip2 v29.2d, v10.2d, v11.2d
|
||||
zip2 v30.2d, v12.2d, v13.2d
|
||||
zip2 v31.2d, v14.2d, v15.2d
|
||||
|
||||
st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x2], #64
|
||||
st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x2], x8
|
||||
st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x12], #64
|
||||
st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x12], x8
|
||||
cmp x13, #16
|
||||
bge LoopL8E16
|
||||
add x8, x8, #64
|
||||
|
||||
LoopL8E12:
|
||||
cmp x13, #12
|
||||
blt LoopL8E8
|
||||
ld1 {v0.8h}, [x14], x9
|
||||
ld1 {v1.8h}, [x14], x9
|
||||
ld1 {v2.8h}, [x14], x9
|
||||
ld1 {v3.8h}, [x14], x9
|
||||
ld1 {v4.8h}, [x14], x9
|
||||
ld1 {v5.8h}, [x14], x9
|
||||
ld1 {v6.8h}, [x14], x9
|
||||
ld1 {v7.8h}, [x14], x9
|
||||
ld1 {v8.8h}, [x14], x9
|
||||
ld1 {v9.8h}, [x14], x9
|
||||
ld1 {v10.8h}, [x14], x9
|
||||
ld1 {v11.8h}, [x14], x9
|
||||
zip1 v12.2d, v0.2d, v1.2d
|
||||
zip1 v13.2d, v2.2d, v3.2d
|
||||
zip1 v14.2d, v4.2d, v5.2d
|
||||
zip1 v15.2d, v6.2d, v7.2d
|
||||
zip1 v16.2d, v8.2d, v9.2d
|
||||
zip1 v17.2d, v10.2d, v11.2d
|
||||
zip2 v18.2d, v0.2d, v1.2d
|
||||
zip2 v19.2d, v2.2d, v3.2d
|
||||
zip2 v20.2d, v4.2d, v5.2d
|
||||
zip2 v21.2d, v6.2d, v7.2d
|
||||
zip2 v22.2d, v8.2d, v9.2d
|
||||
zip2 v23.2d, v10.2d, v11.2d
|
||||
|
||||
st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x2], #64
|
||||
st1 {v16.8h, v17.8h}, [x2], #32
|
||||
st1 {v18.8h, v19.8h, v20.8h, v21.8h}, [x12], #64
|
||||
st1 {v22.8h, v23.8h}, [x12], #32
|
||||
sub x13, x13, #12
|
||||
|
||||
LoopL8E8:
|
||||
cmp x13, #8
|
||||
blt LoopL8E4
|
||||
ld1 {v0.8h}, [x14], x9
|
||||
ld1 {v1.8h}, [x14], x9
|
||||
ld1 {v2.8h}, [x14], x9
|
||||
ld1 {v3.8h}, [x14], x9
|
||||
ld1 {v4.8h}, [x14], x9
|
||||
ld1 {v5.8h}, [x14], x9
|
||||
ld1 {v6.8h}, [x14], x9
|
||||
ld1 {v7.8h}, [x14], x9
|
||||
zip1 v12.2d, v0.2d, v1.2d
|
||||
zip1 v13.2d, v2.2d, v3.2d
|
||||
zip1 v14.2d, v4.2d, v5.2d
|
||||
zip1 v15.2d, v6.2d, v7.2d
|
||||
zip2 v18.2d, v0.2d, v1.2d
|
||||
zip2 v19.2d, v2.2d, v3.2d
|
||||
zip2 v20.2d, v4.2d, v5.2d
|
||||
zip2 v21.2d, v6.2d, v7.2d
|
||||
st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x2], #64
|
||||
st1 {v18.8h, v19.8h, v20.8h, v21.8h}, [x12], #64
|
||||
sub x13, x13, #8
|
||||
|
||||
LoopL8E4:
|
||||
cmp x13, #4
|
||||
blt LoopL8E2
|
||||
ld1 {v0.8h}, [x14], x9
|
||||
ld1 {v1.8h}, [x14], x9
|
||||
ld1 {v2.8h}, [x14], x9
|
||||
ld1 {v3.8h}, [x14], x9
|
||||
zip1 v12.2d, v0.2d, v1.2d
|
||||
zip1 v13.2d, v2.2d, v3.2d
|
||||
zip2 v18.2d, v0.2d, v1.2d
|
||||
zip2 v19.2d, v2.2d, v3.2d
|
||||
st1 {v12.8h, v13.8h}, [x2], #32
|
||||
st1 {v18.8h, v19.8h}, [x12], #32
|
||||
sub x13, x13, #4
|
||||
|
||||
LoopL8E2:
|
||||
cmp x13, #2
|
||||
blt LoopL8E1
|
||||
ld1 {v0.8h}, [x14], x9
|
||||
ld1 {v1.8h}, [x14], x9
|
||||
zip1 v12.2d, v0.2d, v1.2d
|
||||
zip2 v18.2d, v0.2d, v1.2d
|
||||
st1 {v12.8h}, [x2], #16
|
||||
st1 {v18.8h}, [x12], #16
|
||||
sub x13, x13, #2
|
||||
|
||||
LoopL8E1:
|
||||
cmp x13, #1
|
||||
blt EndL8LoopE
|
||||
ld1 {v0.8h}, [x14], x9
|
||||
st1 {v0.d}[0], [x2], #8
|
||||
st1 {v0.d}[1], [x12], #8
|
||||
|
||||
EndL8LoopE:
|
||||
sub x11, x11, #8
|
||||
cmp x11, #8
|
||||
add x2, x5, x8 // eDest*LP*2*sizeof(float16_t)
|
||||
add x14, x4, x7
|
||||
bge LoopL8
|
||||
cbz x11, EndLoopL
|
||||
|
||||
LoopL4:
|
||||
mov x5, x2
|
||||
mov x4, x14
|
||||
mov x13, x10
|
||||
|
||||
sub x8, x19, #64
|
||||
|
||||
cmp x13, #16
|
||||
blt LoopL4E12
|
||||
|
||||
LoopL4E16:
|
||||
sub x13, x13, #16
|
||||
ld1 {v0.8h}, [x14], x9
|
||||
ld1 {v1.8h}, [x14], x9
|
||||
ld1 {v2.8h}, [x14], x9
|
||||
ld1 {v3.8h}, [x14], x9
|
||||
ld1 {v4.8h}, [x14], x9
|
||||
ld1 {v5.8h}, [x14], x9
|
||||
ld1 {v6.8h}, [x14], x9
|
||||
ld1 {v7.8h}, [x14], x9
|
||||
ld1 {v8.8h}, [x14], x9
|
||||
ld1 {v9.8h}, [x14], x9
|
||||
ld1 {v10.8h}, [x14], x9
|
||||
ld1 {v11.8h}, [x14], x9
|
||||
ld1 {v12.8h}, [x14], x9
|
||||
ld1 {v13.8h}, [x14], x9
|
||||
ld1 {v14.8h}, [x14], x9
|
||||
ld1 {v15.8h}, [x14], x9
|
||||
|
||||
zip1 v16.2d, v0.2d, v1.2d
|
||||
zip1 v17.2d, v2.2d, v3.2d
|
||||
zip1 v18.2d, v4.2d, v5.2d
|
||||
zip1 v19.2d, v6.2d, v7.2d
|
||||
zip1 v20.2d, v8.2d, v9.2d
|
||||
zip1 v21.2d, v10.2d, v11.2d
|
||||
zip1 v22.2d, v12.2d, v13.2d
|
||||
zip1 v23.2d, v14.2d, v15.2d
|
||||
|
||||
st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x2], #64
|
||||
st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x2], x8
|
||||
cmp x13, #16
|
||||
bge LoopL4E16
|
||||
|
||||
LoopL4E12:
|
||||
cmp x13, #12
|
||||
blt LoopL4E8
|
||||
sub x13, x13, #12
|
||||
ld1 {v0.8h}, [x14], x9
|
||||
ld1 {v1.8h}, [x14], x9
|
||||
ld1 {v2.8h}, [x14], x9
|
||||
ld1 {v3.8h}, [x14], x9
|
||||
ld1 {v4.8h}, [x14], x9
|
||||
ld1 {v5.8h}, [x14], x9
|
||||
ld1 {v6.8h}, [x14], x9
|
||||
ld1 {v7.8h}, [x14], x9
|
||||
ld1 {v8.8h}, [x14], x9
|
||||
ld1 {v9.8h}, [x14], x9
|
||||
ld1 {v10.8h}, [x14], x9
|
||||
ld1 {v11.8h}, [x14], x9
|
||||
zip1 v12.2d, v0.2d, v1.2d
|
||||
zip1 v13.2d, v2.2d, v3.2d
|
||||
zip1 v14.2d, v4.2d, v5.2d
|
||||
zip1 v15.2d, v6.2d, v7.2d
|
||||
zip1 v16.2d, v8.2d, v9.2d
|
||||
zip1 v17.2d, v10.2d, v11.2d
|
||||
|
||||
st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x2], #64
|
||||
st1 {v16.8h, v17.8h}, [x2], #32
|
||||
|
||||
LoopL4E8:
|
||||
cmp x13, #8
|
||||
blt LoopL4E4
|
||||
ld1 {v0.8h}, [x14], x9
|
||||
ld1 {v1.8h}, [x14], x9
|
||||
ld1 {v2.8h}, [x14], x9
|
||||
ld1 {v3.8h}, [x14], x9
|
||||
ld1 {v4.8h}, [x14], x9
|
||||
ld1 {v5.8h}, [x14], x9
|
||||
ld1 {v6.8h}, [x14], x9
|
||||
ld1 {v7.8h}, [x14], x9
|
||||
zip1 v12.2d, v0.2d, v1.2d
|
||||
zip1 v13.2d, v2.2d, v3.2d
|
||||
zip1 v14.2d, v4.2d, v5.2d
|
||||
zip1 v15.2d, v6.2d, v7.2d
|
||||
st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x2], #64
|
||||
sub x13, x13, #8
|
||||
|
||||
LoopL4E4:
|
||||
cmp x13, #4
|
||||
blt LoopL4E2
|
||||
ld1 {v0.8h}, [x14], x9
|
||||
ld1 {v1.8h}, [x14], x9
|
||||
ld1 {v2.8h}, [x14], x9
|
||||
ld1 {v3.8h}, [x14], x9
|
||||
zip1 v12.2d, v0.2d, v1.2d
|
||||
zip1 v13.2d, v2.2d, v3.2d
|
||||
st1 {v12.8h, v13.8h}, [x2], #32
|
||||
sub x13, x13, #4
|
||||
|
||||
LoopL4E2:
|
||||
cmp x13, #2
|
||||
blt LoopL4E1
|
||||
ld1 {v0.8h}, [x14], x9
|
||||
ld1 {v1.8h}, [x14], x9
|
||||
zip1 v12.2d, v0.2d, v1.2d
|
||||
st1 {v12.8h}, [x2], #16
|
||||
sub x13, x13, #2
|
||||
|
||||
LoopL4E1:
|
||||
cmp x13, #1
|
||||
blt EndLoopL
|
||||
ld1 {v0.8h}, [x14], x9
|
||||
st1 {v0.d}[0], [x2], #8
|
||||
|
||||
EndLoopL:
|
||||
subs x6, x6, #1
|
||||
add x1, x1, #8
|
||||
ldr x14, [x1, #0]
|
||||
bne LoopNum
|
||||
|
||||
|
||||
End:
|
||||
ldp x19, x20, [sp, #(16 * 4)]
|
||||
ldp d8, d9, [sp, #(16 * 3)]
|
||||
ldp d10, d11, [sp, #(16 * 2)]
|
||||
ldp d12, d13, [sp, #(16 * 1)]
|
||||
ldp d14, d15, [sp], #(16 * 5)
|
||||
ret
|
||||
|
||||
#endif
|
||||
|
|
@ -0,0 +1,165 @@
|
|||
//
|
||||
// MNNGeneralIm2col_Fp32Sme2.S
|
||||
// MNN
|
||||
//
|
||||
// Created by MNN on 2025/01/13.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
|
||||
#ifdef __aarch64__
|
||||
|
||||
#include "MNNAsmGlobal.h"
|
||||
.text
|
||||
.align 5
|
||||
|
||||
//void MNNGeneralIm2col_Fp32Sme2(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack)
|
||||
asm_function MNNGeneralIm2col_Fp32Sme2
|
||||
|
||||
// x0:destOrigin, x1:sourceGroup, x2:info, x3:el, x4:LP, x5:pack
|
||||
stp d14, d15, [sp, #(-16 * 5)]!
|
||||
stp d12, d13, [sp, #(16 * 1)]
|
||||
stp d10, d11, [sp, #(16 * 2)]
|
||||
stp d8, d9, [sp, #(16 * 3)]
|
||||
stp x19, x20, [sp, #(16 * 4)]
|
||||
|
||||
// load el info
|
||||
ldr w6, [x2, #0] // number
|
||||
ldr w7, [x2, #4] // eReal
|
||||
ldr w15, [x2, #8] // eDest (< EP)
|
||||
ldr w9, [x2, #12] // offset (stride)
|
||||
ldr x14, [x1, #0] // src start
|
||||
lsl x9, x9, #4 // pack*offset*sizeof(float32_t)
|
||||
// stride
|
||||
lsl x19, x15, #4 // eDest*LP*sizeof(float32_t)
|
||||
lsl x7, x7, #4 // eReal*pack*sizeof(float32_t)
|
||||
mov x20, #3 // Arm82,LP=4
|
||||
|
||||
LoopNum:
|
||||
|
||||
ldr w10, [x3], #4 // e
|
||||
ldr w11, [x3], #4 // l
|
||||
ldr w12, [x3], #4 // eOffset
|
||||
ldr w13, [x3], #4 // lOffset
|
||||
// dst address: x2
|
||||
and x2, x13, x20 // lR
|
||||
sub x13, x13, x2 // lOffset-lR
|
||||
mul x13, x13, x15 // (lOffset-lR)*(eDest)
|
||||
add x13, x13, x2 // (lOffset-lR)*(eDest)+lR
|
||||
add x13, x13, x12, LSL #2 // + eoffset*lp
|
||||
add x2, x0, x13, LSL #2 // *sizeof(float32_t)
|
||||
|
||||
LoopL4:
|
||||
mov x5, x2
|
||||
mov x4, x14
|
||||
mov x13, x10
|
||||
|
||||
cmp x13, #16
|
||||
blt LoopL4E12
|
||||
|
||||
LoopL4E16:
|
||||
sub x13, x13, #16
|
||||
ld1 {v0.4s}, [x14], x9
|
||||
ld1 {v1.4s}, [x14], x9
|
||||
ld1 {v2.4s}, [x14], x9
|
||||
ld1 {v3.4s}, [x14], x9
|
||||
ld1 {v4.4s}, [x14], x9
|
||||
ld1 {v5.4s}, [x14], x9
|
||||
ld1 {v6.4s}, [x14], x9
|
||||
ld1 {v7.4s}, [x14], x9
|
||||
ld1 {v8.4s}, [x14], x9
|
||||
ld1 {v9.4s}, [x14], x9
|
||||
ld1 {v10.4s}, [x14], x9
|
||||
ld1 {v11.4s}, [x14], x9
|
||||
ld1 {v12.4s}, [x14], x9
|
||||
ld1 {v13.4s}, [x14], x9
|
||||
ld1 {v14.4s}, [x14], x9
|
||||
ld1 {v15.4s}, [x14], x9
|
||||
|
||||
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x2], #64
|
||||
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x2], #64
|
||||
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64
|
||||
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64
|
||||
cmp x13, #16
|
||||
bge LoopL4E16
|
||||
|
||||
LoopL4E12:
|
||||
cmp x13, #12
|
||||
blt LoopL4E8
|
||||
ld1 {v0.4s}, [x14], x9
|
||||
ld1 {v1.4s}, [x14], x9
|
||||
ld1 {v2.4s}, [x14], x9
|
||||
ld1 {v3.4s}, [x14], x9
|
||||
ld1 {v4.4s}, [x14], x9
|
||||
ld1 {v5.4s}, [x14], x9
|
||||
ld1 {v6.4s}, [x14], x9
|
||||
ld1 {v7.4s}, [x14], x9
|
||||
ld1 {v8.4s}, [x14], x9
|
||||
ld1 {v9.4s}, [x14], x9
|
||||
ld1 {v10.4s}, [x14], x9
|
||||
ld1 {v11.4s}, [x14], x9
|
||||
sub x13, x13, #12
|
||||
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x2], #64
|
||||
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x2], #64
|
||||
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64
|
||||
|
||||
LoopL4E8:
|
||||
cmp x13, #8
|
||||
blt LoopL4E4
|
||||
ld1 {v0.4s}, [x14], x9
|
||||
ld1 {v1.4s}, [x14], x9
|
||||
ld1 {v2.4s}, [x14], x9
|
||||
ld1 {v3.4s}, [x14], x9
|
||||
ld1 {v4.4s}, [x14], x9
|
||||
ld1 {v5.4s}, [x14], x9
|
||||
ld1 {v6.4s}, [x14], x9
|
||||
ld1 {v7.4s}, [x14], x9
|
||||
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x2], #64
|
||||
st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x2], #64
|
||||
sub x13, x13, #8
|
||||
|
||||
LoopL4E4:
|
||||
cmp x13, #4
|
||||
blt LoopL4E2
|
||||
ld1 {v0.4s}, [x14], x9
|
||||
ld1 {v1.4s}, [x14], x9
|
||||
ld1 {v2.4s}, [x14], x9
|
||||
ld1 {v3.4s}, [x14], x9
|
||||
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x2], #64
|
||||
sub x13, x13, #4
|
||||
|
||||
LoopL4E2:
|
||||
cmp x13, #2
|
||||
blt LoopL4E1
|
||||
ld1 {v0.4s}, [x14], x9
|
||||
ld1 {v1.4s}, [x14], x9
|
||||
st1 {v0.4s, v1.4s}, [x2], #32
|
||||
sub x13, x13, #2
|
||||
|
||||
LoopL4E1:
|
||||
cmp x13, #1
|
||||
blt EndL4LoopE
|
||||
ld1 {v0.4s}, [x14], x9
|
||||
st1 {v0.4s}, [x2], #16
|
||||
|
||||
EndL4LoopE:
|
||||
add x2, x5, x19 // eDest*LP*sizeof(float32_t)
|
||||
add x14, x4, x7
|
||||
subs x11, x11, #4
|
||||
bne LoopL4
|
||||
|
||||
EndLoopL:
|
||||
subs x6, x6, #1
|
||||
add x1, x1, #8
|
||||
ldr x14, [x1, #0]
|
||||
bne LoopNum
|
||||
|
||||
|
||||
End:
|
||||
ldp x19, x20, [sp, #(16 * 4)]
|
||||
ldp d8, d9, [sp, #(16 * 3)]
|
||||
ldp d10, d11, [sp, #(16 * 2)]
|
||||
ldp d12, d13, [sp, #(16 * 1)]
|
||||
ldp d14, d15, [sp], #(16 * 5)
|
||||
ret
|
||||
|
||||
#endif
|
(5 file diffs suppressed because they are too large)
|
@ -94,10 +94,11 @@ void ARMV86_MNNGetMatMulPackMode_BF16(int* eP, int* lP, int* hP) {
|
|||
*hP = HP;
|
||||
*lP = LP;
|
||||
}
|
||||
void ARMV86_MNNPackForMatMul_B_BF16(float* destF, const float* sourceF, size_t h, size_t l, bool transpose) {
|
||||
void ARMV86_MNNPackForMatMul_B_BF16(float* destF, const float* sourceF, size_t h, size_t kernelsize, size_t ic, bool transpose) {
|
||||
// [l, h] -> [h/hp, l/lp, hp, lp]
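// Index form of the destination layout above (an illustrative reading, not extra code):
// an element at l-index x, h-index y lands at
//   dest[((y / HP) * lCP + x / LP) * HP * LP + (y % HP) * LP + (x % LP)]
// with lCP = UP_DIV(l, LP) and l = kernelsize * ic.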
|
||||
auto dest = (int16_t*)destF;
|
||||
auto source = (const int32_t*)sourceF;
|
||||
auto l = kernelsize * ic;
|
||||
auto lCP = UP_DIV(l, LP);
|
||||
auto hCP = UP_DIV(h, HP);
|
||||
int sYstride = 1;
|
||||
|
@ -173,10 +174,11 @@ void ARMV86_MNNPackC4ForMatMul_A_BF16(float* destOrigin, float const** sourceGro
|
|||
#undef EP
|
||||
#undef HP
|
||||
#undef LP
|
||||
void NEON_MNNPackForMatMul_B_BF16(float* destFloat, const float* sourceFloat, size_t h, size_t l, bool transpose) {
|
||||
void NEON_MNNPackForMatMul_B_BF16(float* destFloat, const float* sourceFloat, size_t h, size_t kernelsize, size_t ic, bool transpose) {
|
||||
auto hP = (int)h / 8;
|
||||
auto hR = (int)hP * 8;
|
||||
int16_t* dest = (int16_t*)destFloat;
|
||||
auto l = kernelsize * ic;
|
||||
const float* source = sourceFloat;
|
||||
if (!transpose) {
|
||||
for (int y = 0; y < hP; ++y) {
|
||||
|
@ -246,9 +248,10 @@ void NEON_MNNPackForMatMul_B_BF16(float* destFloat, const float* sourceFloat, si
|
|||
}
|
||||
|
||||
#else
|
||||
void NEON_MNNPackForMatMul_B_BF16(float* destFloat, const float* sourceFloat, size_t h, size_t l, bool transpose) {
|
||||
void NEON_MNNPackForMatMul_B_BF16(float* destFloat, const float* sourceFloat, size_t h, size_t kernelsize, size_t ic, bool transpose) {
|
||||
int16_t* dest = (int16_t*)destFloat;
|
||||
const float* source = sourceFloat;
|
||||
auto l = kernelsize * ic;
|
||||
if (!transpose) {
|
||||
auto hP = h / 4;
|
||||
auto hR = hP * 4;
|
||||
|
|
|
@ -228,6 +228,7 @@ void MNNDynamicUpdateConvBiasScale(float* newbias, float* oldbias, float* weight
|
|||
#endif // not __aarch64__
|
||||
|
||||
static void MNNCountMaxMinValue(const float* source, float* minVal, float* maxVal, size_t size) {
|
||||
#ifndef MNN_USE_NEON
|
||||
int pack = 4;
|
||||
float max_ = source[0], min_ = source[0];
|
||||
for (int i = 1; i < size; ++i) {
|
||||
|
@ -240,6 +241,162 @@ static void MNNCountMaxMinValue(const float* source, float* minVal, float* maxVa
|
|||
}
|
||||
*minVal = min_;
|
||||
*maxVal = max_;
|
||||
#else
|
||||
auto sizeDiv4 = size / 4;
|
||||
auto remain = size - 4 * sizeDiv4;
|
||||
auto srcPtr = source;
|
||||
auto max0 = vdupq_n_f32(srcPtr[0]);
|
||||
auto min0 = vdupq_n_f32(srcPtr[0]);
|
||||
while (sizeDiv4 > 15) {
|
||||
sizeDiv4 -= 16;
|
||||
auto data0 = vld1q_f32(srcPtr);
|
||||
auto data1 = vld1q_f32(srcPtr + 4);
|
||||
auto data2 = vld1q_f32(srcPtr + 8);
|
||||
auto data3 = vld1q_f32(srcPtr + 12);
|
||||
auto data4 = vld1q_f32(srcPtr + 16);
|
||||
auto data5 = vld1q_f32(srcPtr + 20);
|
||||
auto data6 = vld1q_f32(srcPtr + 24);
|
||||
auto data7 = vld1q_f32(srcPtr + 28);
|
||||
auto data8 = vld1q_f32(srcPtr + 32);
|
||||
auto data9 = vld1q_f32(srcPtr + 36);
|
||||
auto data10 = vld1q_f32(srcPtr + 40);
|
||||
auto data11 = vld1q_f32(srcPtr + 44);
|
||||
auto data12 = vld1q_f32(srcPtr + 48);
|
||||
auto data13 = vld1q_f32(srcPtr + 52);
|
||||
auto data14 = vld1q_f32(srcPtr + 56);
|
||||
auto data15 = vld1q_f32(srcPtr + 60);
|
||||
|
||||
auto lmin0 = vminq_f32(data0, data1);
|
||||
auto lmin2 = vminq_f32(data2, data3);
|
||||
auto lmin4 = vminq_f32(data4, data5);
|
||||
auto lmin6 = vminq_f32(data6, data7);
|
||||
auto lmin8 = vminq_f32(data8, data9);
|
||||
auto lmin10 = vminq_f32(data10, data11);
|
||||
auto lmin12 = vminq_f32(data12, data13);
|
||||
auto lmin14 = vminq_f32(data14, data15);
|
||||
|
||||
auto lmax0 = vmaxq_f32(data0, data1);
|
||||
auto lmax2 = vmaxq_f32(data2, data3);
|
||||
auto lmax4 = vmaxq_f32(data4, data5);
|
||||
auto lmax6 = vmaxq_f32(data6, data7);
|
||||
auto lmax8 = vmaxq_f32(data8, data9);
|
||||
auto lmax10 = vmaxq_f32(data10, data11);
|
||||
auto lmax12 = vmaxq_f32(data12, data13);
|
||||
auto lmax14 = vmaxq_f32(data14, data15);
|
||||
|
||||
lmin0 = vminq_f32(lmin0, lmin2);
|
||||
lmin4 = vminq_f32(lmin4, lmin6);
|
||||
lmin8 = vminq_f32(lmin8, lmin10);
|
||||
lmin12 = vminq_f32(lmin12, lmin14);
|
||||
|
||||
lmax0 = vmaxq_f32(lmax0, lmax2);
|
||||
lmax4 = vmaxq_f32(lmax4, lmax6);
|
||||
lmax8 = vmaxq_f32(lmax8, lmax10);
|
||||
lmax12 = vmaxq_f32(lmax12, lmax14);
|
||||
|
||||
lmin0 = vminq_f32(lmin0, lmin8);
|
||||
lmin4 = vminq_f32(lmin4, lmin12);
|
||||
lmax0 = vmaxq_f32(lmax0, lmax8);
|
||||
lmax4 = vmaxq_f32(lmax4, lmax12);
|
||||
lmin0 = vminq_f32(lmin0, lmin4);
|
||||
lmax0 = vmaxq_f32(lmax0, lmax4);
|
||||
|
||||
max0 = vmaxq_f32(max0, lmax0);
|
||||
min0 = vminq_f32(min0, lmin0);
|
||||
srcPtr += 64;
|
||||
}
|
||||
if (sizeDiv4 > 7) {
|
||||
sizeDiv4 -= 8;
|
||||
auto data0 = vld1q_f32(srcPtr);
|
||||
auto data1 = vld1q_f32(srcPtr + 4);
|
||||
auto data2 = vld1q_f32(srcPtr + 8);
|
||||
auto data3 = vld1q_f32(srcPtr + 12);
|
||||
auto data4 = vld1q_f32(srcPtr + 16);
|
||||
auto data5 = vld1q_f32(srcPtr + 20);
|
||||
auto data6 = vld1q_f32(srcPtr + 24);
|
||||
auto data7 = vld1q_f32(srcPtr + 28);
|
||||
|
||||
auto lmin0 = vminq_f32(data0, data1);
|
||||
auto lmin2 = vminq_f32(data2, data3);
|
||||
auto lmin4 = vminq_f32(data4, data5);
|
||||
auto lmin6 = vminq_f32(data6, data7);
|
||||
|
||||
auto lmax0 = vmaxq_f32(data0, data1);
|
||||
auto lmax2 = vmaxq_f32(data2, data3);
|
||||
auto lmax4 = vmaxq_f32(data4, data5);
|
||||
auto lmax6 = vmaxq_f32(data6, data7);
|
||||
|
||||
lmin0 = vminq_f32(lmin0, lmin2);
|
||||
lmin4 = vminq_f32(lmin4, lmin6);
|
||||
|
||||
lmax0 = vmaxq_f32(lmax0, lmax2);
|
||||
lmax4 = vmaxq_f32(lmax4, lmax6);
|
||||
|
||||
lmin0 = vminq_f32(lmin0, lmin4);
|
||||
lmax0 = vmaxq_f32(lmax0, lmax4);
|
||||
|
||||
max0 = vmaxq_f32(max0, lmax0);
|
||||
min0 = vminq_f32(min0, lmin0);
|
||||
srcPtr += 32;
|
||||
}
|
||||
if (sizeDiv4 > 3) {
|
||||
sizeDiv4 -= 4;
|
||||
auto data0 = vld1q_f32(srcPtr);
|
||||
auto data1 = vld1q_f32(srcPtr + 4);
|
||||
auto data2 = vld1q_f32(srcPtr + 8);
|
||||
auto data3 = vld1q_f32(srcPtr + 12);
|
||||
|
||||
auto lmin0 = vminq_f32(data0, data1);
|
||||
auto lmin2 = vminq_f32(data2, data3);
|
||||
|
||||
auto lmax0 = vmaxq_f32(data0, data1);
|
||||
auto lmax2 = vmaxq_f32(data2, data3);
|
||||
|
||||
lmin0 = vminq_f32(lmin0, lmin2);
|
||||
lmax0 = vmaxq_f32(lmax0, lmax2);
|
||||
|
||||
max0 = vmaxq_f32(max0, lmax0);
|
||||
min0 = vminq_f32(min0, lmin0);
|
||||
srcPtr += 16;
|
||||
}
|
||||
if (sizeDiv4 > 1) {
|
||||
sizeDiv4 -= 2;
|
||||
auto data0 = vld1q_f32(srcPtr);
|
||||
auto data1 = vld1q_f32(srcPtr + 4);
|
||||
|
||||
auto lmin0 = vminq_f32(data0, data1);
|
||||
auto lmax0 = vmaxq_f32(data0, data1);
|
||||
|
||||
max0 = vmaxq_f32(max0, lmax0);
|
||||
min0 = vminq_f32(min0, lmin0);
|
||||
srcPtr += 8;
|
||||
}
|
||||
if (sizeDiv4 > 0) {
|
||||
sizeDiv4--;
|
||||
auto data0 = vld1q_f32(srcPtr);
|
||||
max0 = vmaxq_f32(max0, data0);
|
||||
min0 = vminq_f32(min0, data0);
|
||||
srcPtr += 4;
|
||||
}
|
||||
float temp0[4];
|
||||
float temp1[4];
|
||||
vst1q_f32(temp0, max0);
|
||||
vst1q_f32(temp1, min0);
|
||||
auto maxval = temp0[0];
|
||||
auto minval = temp1[0];
|
||||
for (int i = 1; i < 4; ++i) {
|
||||
maxval = ALIMAX(maxval, temp0[i]);
|
||||
minval = ALIMIN(minval, temp1[i]);
|
||||
}
|
||||
while (remain > 0) {
|
||||
maxval = ALIMAX(maxval, srcPtr[0]);
|
||||
minval = ALIMIN(minval, srcPtr[0]);
|
||||
remain--;
|
||||
srcPtr += 1;
|
||||
}
|
||||
minVal[0] = minval;
|
||||
maxVal[0] = maxval;
|
||||
#endif
|
||||
}
|
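A minimal usage sketch for the reduction above, given a float buffer src of size elements; the asymmetric 8-bit mapping here is illustrative, not the exact formula used by the quantization code elsewhere:
float minV = 0.f, maxV = 0.f;
MNNCountMaxMinValue(src, &minV, &maxV, size);   // NEON path reduces 64/32/16/8/4 floats per step, then a scalar tail
float scale = (maxV - minV) / 255.0f;
float zero  = minV;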
||||
|
||||
#ifdef MNN_LOW_MEMORY
|
||||
|
@ -389,18 +546,28 @@ static void MNNAsyQuantInfo_FP32(float* scale, float* bias, float* qscale, float
|
|||
// dequant scale/bias : [EU, blockNum, step], step=ALIMIN(step, EP), EU=UP_DIV(plane, EP)
|
||||
// quant scale/bias : [blockNum, plane]
|
||||
#ifdef __aarch64__
|
||||
if (DST_XUNIT == 12 && innerSide == 4) { // Arm82,fp32: SRC_UNIT=4, core->pack=4
|
||||
if ((DST_XUNIT == 12 || DST_XUNIT == 16) && innerSide == 4) { // Arm82 (EP=12) or SME2 (EP=16), fp32: SRC_UNIT=4, core->pack=4
|
||||
// max,min shape: [blockNum, EP]
|
||||
for (int i = 0; i < kernelsize; ++i) {
|
||||
MNNLocalMinMaxFP32_Pack4(dstMin, dstMax, src + i * stride0, blockNum, blockLU, plane, innerSide, i);
|
||||
}
|
||||
// scale, bias
|
||||
bool success = MNNAsyLocalQuantInfo_EP12_FP32(scale, bias, qscale, qbias, dstMin, dstMax, info);
|
||||
if (!success) {
|
||||
MNN_ERROR("Call error for:MNNAsyLocalQuantInfo_EP12\n");
|
||||
if (DST_XUNIT == 12) {
|
||||
bool success = MNNAsyLocalQuantInfo_EP12_FP32(scale, bias, qscale, qbias, dstMin, dstMax, info);
|
||||
if (!success) {
|
||||
MNN_ERROR("Call error for:MNNAsyLocalQuantInfo_EP12\n");
|
||||
return;
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (DST_XUNIT == 16) {
|
||||
bool success = MNNAsyLocalQuantInfo_EP16_FP32(scale, bias, qscale, qbias, dstMin, dstMax, info);
|
||||
if (!success) {
|
||||
MNN_ERROR("Call error for:MNNAsyLocalQuantInfo_EP16_FP32\n");
|
||||
return;
|
||||
}
|
||||
return;
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (DST_XUNIT == 10) { // Arm86,fp32: SRC_UNIT=8,core->pack=4
|
||||
// max,min shape: [blockNum, EP]
|
||||
|
@ -653,6 +820,73 @@ static void MNNReorderWeightInt4Arm82(uint8_t* dest, const uint8_t* source, int3
|
|||
}
|
||||
MNNPermuteSumWeightInt4Arm82(dest, dest, blocknum * hu, lu, kernelsum);
|
||||
}
|
||||
|
||||
static void MNNReorderWeightInt4Sme2(uint8_t* dest, const uint8_t* source, int32_t* shape, size_t size, float* kernelsum) {
|
||||
MNN_ASSERT(size > 4);
|
||||
// dst shape: [hu, blocknum, kernelCount, lu, hp, lp], kernelCount=1 in this case
|
||||
auto blocknum = shape[0];
|
||||
auto hu = shape[1];
|
||||
auto lu = shape[2];
|
||||
auto hp = shape[3];
|
||||
auto lp = shape[4];
|
||||
auto ic = blocknum * lu * lp;
|
||||
auto stride0 = blocknum * hp * lu * lp;
|
||||
auto stride1 = lu * hp * lp;
|
||||
auto stride2 = hp * lp;
|
||||
auto dstPtr = (int16_t*)dest;
|
||||
auto srcPtr = (int16_t*)source;
|
||||
int unitpacksize = sizeof(int16_t) / sizeof(uint8_t);
|
||||
for (int i = 0; i < hu; ++i) {
|
||||
for (int k = 0; k < hp; ++k) {
|
||||
for (int bl = 0; bl < blocknum; ++bl) {
|
||||
int j = 0;
|
||||
while (j + 7 < lu) {
|
||||
auto srcindex = ((i * hp + k) * ic + bl * (lu * lp) + j * lp) / unitpacksize;
|
||||
auto dstindex0 = (bl * stride1 + i * stride0 + j * stride2 + k * lp) / unitpacksize;
|
||||
auto dstindex1 = (bl * stride1 + i * stride0 + (j + 1) * stride2 + k * lp) / unitpacksize;
|
||||
auto dstindex2 = (bl * stride1 + i * stride0 + (j + 2) * stride2 + k * lp) / unitpacksize;
|
||||
auto dstindex3 = (bl * stride1 + i * stride0 + (j + 3) * stride2 + k * lp) / unitpacksize;
|
||||
auto dstindex4 = (bl * stride1 + i * stride0 + (j + 4) * stride2 + k * lp) / unitpacksize;
|
||||
auto dstindex5 = (bl * stride1 + i * stride0 + (j + 5) * stride2 + k * lp) / unitpacksize;
|
||||
auto dstindex6 = (bl * stride1 + i * stride0 + (j + 6) * stride2 + k * lp) / unitpacksize;
|
||||
auto dstindex7 = (bl * stride1 + i * stride0 + (j + 7) * stride2 + k * lp) / unitpacksize;
|
||||
j += 8;
|
||||
auto srcdata = vld1q_s16(srcPtr + srcindex);
|
||||
vst1q_lane_s16(dstPtr + dstindex0, srcdata, 0);
|
||||
vst1q_lane_s16(dstPtr + dstindex1, srcdata, 1);
|
||||
vst1q_lane_s16(dstPtr + dstindex2, srcdata, 2);
|
||||
vst1q_lane_s16(dstPtr + dstindex3, srcdata, 3);
|
||||
vst1q_lane_s16(dstPtr + dstindex4, srcdata, 4);
|
||||
vst1q_lane_s16(dstPtr + dstindex5, srcdata, 5);
|
||||
vst1q_lane_s16(dstPtr + dstindex6, srcdata, 6);
|
||||
vst1q_lane_s16(dstPtr + dstindex7, srcdata, 7);
|
||||
}
|
||||
while (j + 3 < lu) {
|
||||
auto srcindex = ((i * hp + k) * ic + bl * (lu * lp) + j * lp) / unitpacksize;
|
||||
auto dstindex0 = (bl * stride1 + i * stride0 + j * stride2 + k * lp) / unitpacksize;
|
||||
auto dstindex1 = (bl * stride1 + i * stride0 + (j + 1) * stride2 + k * lp) / unitpacksize;
|
||||
auto dstindex2 = (bl * stride1 + i * stride0 + (j + 2) * stride2 + k * lp) / unitpacksize;
|
||||
auto dstindex3 = (bl * stride1 + i * stride0 + (j + 3) * stride2 + k * lp) / unitpacksize;
|
||||
j += 4;
|
||||
auto srcdata = vld1_s16(srcPtr + srcindex);
|
||||
vst1_lane_s16(dstPtr + dstindex0, srcdata, 0);
|
||||
vst1_lane_s16(dstPtr + dstindex1, srcdata, 1);
|
||||
vst1_lane_s16(dstPtr + dstindex2, srcdata, 2);
|
||||
vst1_lane_s16(dstPtr + dstindex3, srcdata, 3);
|
||||
|
||||
}
|
||||
while (j < lu)
|
||||
{
|
||||
auto srcindex = ((i * hp + k) * ic + bl * (lu * lp) + j * lp) / 2;
|
||||
auto dstindex = (bl * stride1 + i * stride0 + j * stride2 + k * lp) / 2;
|
||||
dstPtr[dstindex] = srcPtr[srcindex];
|
||||
j++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
MNNPermuteSumWeightInt4Sme2(dest, dest, blocknum * hu, lu, kernelsum);
|
||||
}
|
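The lane stores above are 8x/4x-unrolled versions of the scalar tail; a scalar-only sketch of the same [hu*hp, blocknum*lu*lp] -> [hu, blocknum, lu, hp, lp] mapping, assuming lp == 4 (SME2 int4, SRC_UNIT == 4) so one int16_t carries exactly one lp-group of two packed int4 bytes:
for (int i = 0; i < hu; ++i)
    for (int k = 0; k < hp; ++k)
        for (int bl = 0; bl < blocknum; ++bl)
            for (int j = 0; j < lu; ++j) {
                auto srcIdx = ((i * hp + k) * ic + bl * lu * lp + j * lp) / 2;
                auto dstIdx = (i * stride0 + bl * stride1 + j * stride2 + k * lp) / 2;
                dstPtr[dstIdx] = srcPtr[srcIdx];
            }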
||||
#endif // __aarch64__
|
||||
|
||||
static void MNNSumWeightInt8(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP) {
|
||||
|
@ -1147,9 +1381,11 @@ void MNNGetSparseMatMulPackMode(int* eP, int *lP, int* hP) {
|
|||
return;
|
||||
}
|
||||
|
||||
void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose) {
|
||||
void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose) {
|
||||
// src: [h, kernelsize, ic]
|
||||
auto hP = h / 4;
|
||||
auto hR = hP * 4;
|
||||
auto l = kernelsize * ic;
|
||||
if (hR != h) {
|
||||
::memset(dest, 0, UP_DIV(h, 4)*4*l*sizeof(float));
|
||||
}
|
||||
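A hedged usage sketch of the six-argument form (values illustrative); for a plain GEMM the kernel count is 1, which is how the Convolution1x1Strassen call sites below use it:
size_t h = 64, ic = 128;                          // output rows, reduce dimension
std::vector<float> weight(h * ic);                // row-major [h, ic] source
std::vector<float> packed(UP_DIV(h, 4) * 4 * ic); // [UP_DIV(h, 4), ic, 4] destination
MNNPackForMatMul_B(packed.data(), weight.data(), h, /*kernelsize=*/1, /*ic=*/ic, /*transpose=*/true);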
|
@ -3632,6 +3868,54 @@ static void generalIm2col(float* destOrigin, float const** sourceGroup, const in
|
|||
}
|
||||
#endif // MNN_LOW_MEMORY
|
||||
|
||||
#ifdef MNN_SME2
|
||||
static void SME2MNNGetMatMulPackMode(int* eP, int *lP, int* hP) {
|
||||
*eP = 16;
|
||||
*lP = 1;
|
||||
*hP = 64;
|
||||
}
|
||||
static void MNNPackedMatMulFP32_SME2(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b) {
|
||||
MNNPackedMatMulRemainFP32_SME2(C, A, B, 16, parameter, postParameters, bias, k, b);
|
||||
return;
|
||||
}
|
||||
static void Sme2MNNPackForMatMul_B(float* destC, const float* sourceC, size_t h, size_t kernelsize, size_t ic, bool transpose) {
|
||||
// src: [h, kernelsize, ic]
|
||||
// dst: [h/hp, kernelsize, ic/lp, hp, lp]
|
||||
auto dest = (int32_t*)destC;
|
||||
auto source = (int32_t*)sourceC;
|
||||
int LP = 1;
|
||||
int HP = 64;
|
||||
auto l = kernelsize * ic;
|
||||
memset(dest, 0, ROUND_UP(h, HP) * ROUND_UP(ic, LP) * kernelsize * 4);
|
||||
auto stride0 = kernelsize * ROUND_UP(ic, LP) * HP;
|
||||
auto stride1 = ROUND_UP(ic, LP) * HP;
|
||||
auto stride2 = HP * LP;
|
||||
|
||||
auto srcStride0 = l; // [h,l]->[hu,lu,hp,lp]
|
||||
auto srcStride1 = 1;
|
||||
if (!transpose) { // [l,h]->[hu,lu,hp,lp]
|
||||
srcStride0 = 1;
|
||||
srcStride1 = h;
|
||||
}
|
||||
for (int y = 0; y < h; ++y) {
|
||||
auto yHu = y / HP;
|
||||
auto yHp = y % HP;
|
||||
for (int k = 0; k < kernelsize; ++k) {
|
||||
for (int x = 0; x < ic; ++x) {
|
||||
auto xLu = x / LP;
|
||||
auto xLp = x % LP;
|
||||
dest[yHu * stride0 + k * stride1 + xLu * stride2 + yHp * LP + xLp] = source[y * srcStride0 + (x + k * ic) * srcStride1];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
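A worked example for the mapping above with the HP = 64, LP = 1 values set in this function (h = 65, ic = 3, kernelsize = 1, transpose = true):
//   stride0 = 1 * ROUND_UP(3, 1) * 64 = 192, stride1 = 192, stride2 = 64
//   element (y = 64, k = 0, x = 2): yHu = 1, yHp = 0, xLu = 2, xLp = 0
//   dstIndex = 1 * 192 + 0 * 192 + 2 * 64 + 0 * 1 + 0 = 320
//   srcIndex = y * l + (x + k * ic) = 64 * 3 + 2 = 194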
||||
static void Sme2MNNPackForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el) {
|
||||
const int32_t infosme2[4] = {info[0], info[1], 16, info[3]};
|
||||
MNNPackC4ForMatMul_A(destOrigin, sourceGroup, infosme2, el);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace MNN {
|
||||
|
||||
static CoreFunctions* gCoreFunction = nullptr;
|
||||
|
@ -3742,19 +4026,19 @@ void MNNCoreFunctionInit() {
|
|||
gCoreFunction->supportFp16arith = gCPUInfo.fp16arith;
|
||||
gCoreFunction->supportSDot = gCPUInfo.dot;
|
||||
gCoreFunction->supportI8mm = gCPUInfo.i8mm;
|
||||
gCoreFunction->supportSME2 = gCPUInfo.sme2;
|
||||
gCoreFunction->MNNSumByAxisLForMatmul_A = MNNSumByAxisLForMatmul_A;
|
||||
gCoreFunction->MNNReorderWeightInt4 = MNNReorderWeightInt4;
|
||||
gCoreFunction->MNNSumWeightInt8 = MNNSumWeightInt8;
|
||||
#ifdef __aarch64__
|
||||
if (gCoreFunction->supportSDot) {
|
||||
gCoreFunction->MNNReorderWeightInt4 = MNNReorderWeightInt4Arm82;
|
||||
gCoreFunction->MNNSumWeightInt8 = MNNSumWeightInt8Arm82;
|
||||
}
|
||||
if (gCoreFunction->supportI8mm) {
|
||||
gCoreFunction->MNNReorderWeightInt4 = MNNReorderWeightInt4Arm86;
|
||||
gCoreFunction->MNNSumWeightInt8 = MNNSumWeightInt8Arm86;
|
||||
|
||||
}
|
||||
if (gCoreFunction->supportSDot) {
|
||||
gCoreFunction->MNNReorderWeightInt4 = MNNReorderWeightInt4Arm82;
|
||||
gCoreFunction->MNNSumWeightInt8 = MNNSumWeightInt8Arm82;
|
||||
}
|
||||
if (gCoreFunction->supportI8mm) {
|
||||
gCoreFunction->MNNReorderWeightInt4 = MNNReorderWeightInt4Arm86;
|
||||
gCoreFunction->MNNSumWeightInt8 = MNNSumWeightInt8Arm86;
|
||||
}
|
||||
#endif
|
||||
#ifdef MNN_CPU_WEIGHT_DEQUANT_GEMM
|
||||
// Weight Dequant Gemm Kernels
|
||||
|
@ -3778,6 +4062,42 @@ void MNNCoreFunctionInit() {
|
|||
}
|
||||
#endif
|
||||
#endif
|
||||
{ // backendMatmulRelatedFunctions
|
||||
gCoreFunction->backendMatmulRelatedFunctions.MNNReorderWeightInt4 = gCoreFunction->MNNReorderWeightInt4;
|
||||
gCoreFunction->backendMatmulRelatedFunctions.MNNSumWeightInt8 = gCoreFunction->MNNSumWeightInt8;
|
||||
gCoreFunction->backendMatmulRelatedFunctions.MNNGeneralIm2Col = gCoreFunction->MNNGeneralIm2Col;
|
||||
|
||||
gCoreFunction->backendMatmulRelatedFunctions.MNNPackedMatMul = gCoreFunction->MNNPackedMatMul;
|
||||
gCoreFunction->backendMatmulRelatedFunctions.MNNPackedMatMulRemain = gCoreFunction->MNNPackedMatMulRemain;
|
||||
gCoreFunction->backendMatmulRelatedFunctions.MNNGetMatMulPackMode = gCoreFunction->MNNGetMatMulPackMode;
|
||||
gCoreFunction->backendMatmulRelatedFunctions.MNNPackC4ForMatMul_A = gCoreFunction->MNNPackC4ForMatMul_A;
|
||||
gCoreFunction->backendMatmulRelatedFunctions.MNNPackForMatMul_B = gCoreFunction->MNNPackForMatMul_B;
|
||||
}
|
||||
#ifdef __aarch64__
|
||||
#ifdef MNN_SME2
|
||||
if (gCoreFunction->supportSME2) {
|
||||
gCoreFunction->MNNSumWeightInt8 = MNNSumWeightInt8Sme2;
|
||||
gCoreFunction->MNNReorderWeightInt4 = MNNReorderWeightInt4Sme2;
|
||||
gCoreFunction->MNNPackedMatMul = MNNPackedMatMulFP32_SME2;
|
||||
gCoreFunction->MNNPackedMatMulRemain = MNNPackedMatMulRemainFP32_SME2;
|
||||
gCoreFunction->MNNGetMatMulPackMode = SME2MNNGetMatMulPackMode;
|
||||
gCoreFunction->MNNPackC4ForMatMul_A = Sme2MNNPackForMatMul_A;
|
||||
gCoreFunction->MNNPackForMatMul_B = Sme2MNNPackForMatMul_B;
|
||||
|
||||
gCoreFunction->sme2MatmulRelatedFuncions.MNNSumWeightInt8 = MNNSumWeightInt8Sme2;
|
||||
gCoreFunction->sme2MatmulRelatedFuncions.MNNReorderWeightInt4 = MNNReorderWeightInt4Sme2;
|
||||
gCoreFunction->sme2MatmulRelatedFuncions.MNNPackedMatMul = MNNPackedMatMulFP32_SME2;
|
||||
gCoreFunction->sme2MatmulRelatedFuncions.MNNPackedMatMulRemain = MNNPackedMatMulRemainFP32_SME2;
|
||||
gCoreFunction->sme2MatmulRelatedFuncions.MNNGetMatMulPackMode = SME2MNNGetMatMulPackMode;
|
||||
gCoreFunction->sme2MatmulRelatedFuncions.MNNPackC4ForMatMul_A = Sme2MNNPackForMatMul_A;
|
||||
gCoreFunction->sme2MatmulRelatedFuncions.MNNPackForMatMul_B = Sme2MNNPackForMatMul_B;
|
||||
#ifdef MNN_LOW_MEMORY
|
||||
gCoreFunction->MNNGeneralIm2Col = MNNGeneralIm2col_Fp32Sme2;
|
||||
gCoreFunction->sme2MatmulRelatedFuncions.MNNGeneralIm2Col = MNNGeneralIm2col_Fp32Sme2;
|
||||
#endif // MNN_LOW_MEMORY
|
||||
}
|
||||
#endif // MNN_SME2
|
||||
#endif // __aarch64__
|
||||
MNNCoreInt8FunctionInit();
|
||||
MNNFunctionInit();
|
||||
}
|
||||
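A minimal usage sketch, assuming MNNCoreFunctionInit() has already run; it only reads back the tile sizes registered above:
auto core = MNNGetCoreFunctions();
int eP = 0, lP = 0, hP = 0;
core->MNNGetMatMulPackMode(&eP, &lP, &hP);   // 16 / 1 / 64 when the SME2 table was installed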
|
|
|
@ -19,10 +19,11 @@
|
|||
#include "backend/cpu/compute/Int8FunctionsOpt.h"
|
||||
|
||||
extern "C" {
|
||||
#ifdef MNN_LOW_MEMORY
|
||||
#ifdef __aarch64__
|
||||
#ifdef MNN_LOW_MEMORY
|
||||
void MNNGeneralIm2col_Fp32Arm82(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack);
|
||||
void MNNGeneralIm2col_Fp32Arm86(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack);
|
||||
void MNNGeneralIm2col_Fp32Sme2(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack);
|
||||
void MNNLocalMinMaxFP32_Pack4(float* dstMin, float* dstMax, const float* source, size_t blockNum, size_t blockLU, size_t EP, size_t LP, size_t loadDstBuffer);
|
||||
void MNNLocalMinMaxFP32_Pack8(float* dstMin, float* dstMax, const float* source, size_t blockNum, size_t blockLU, size_t EP, size_t LP, size_t loadDstBuffer);
|
||||
void MNNDynamicQuantFP32_Pack4(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, const float* bias, size_t pack);
|
||||
|
@ -31,6 +32,10 @@ void MNNAbsMaxFP32_Pack4(const float* source, float* absmax, size_t src_depth_qu
|
|||
void MNNAbsMaxFP32_Pack8(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack);
|
||||
void MNNQuantScaleFP32(float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch);
|
||||
void MNNDynamicUpdateConvBiasScale(float* newbias, float* oldbias, float* weightKernelSum, float* inputZero, size_t ocQuad);
|
||||
#endif
|
||||
#ifdef MNN_SME2
|
||||
void MNNPackedMatMulRemainFP32_SME2(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
|
||||
|
||||
#endif
|
||||
#endif
|
||||
void MNNFp32ToFp8(uint8_t* dst, const float* src, size_t size);
|
||||
|
@ -134,7 +139,7 @@ el: number * 4
|
|||
*/
|
||||
void MNNPackC4ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el);
|
||||
|
||||
void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose);
|
||||
void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose);
|
||||
|
||||
// parameters: e, l, h, CStride, AStride, BStride
|
||||
void MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
|
||||
|
@ -197,8 +202,10 @@ struct SumByAxisParams {
|
|||
#ifdef __aarch64__
|
||||
void MNNPermuteSumWeightInt4Arm86(uint8_t* dest, uint8_t* source, size_t outside, size_t inside, float* kernlesum);
|
||||
void MNNPermuteSumWeightInt4Arm82(uint8_t* dest, uint8_t* source, size_t outside, size_t inside, float* kernlesum);
|
||||
void MNNPermuteSumWeightInt4Sme2(uint8_t* dest, uint8_t* source, size_t outside, size_t inside, float* kernlesum);
|
||||
void MNNSumWeightInt8Arm86(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP);
|
||||
void MNNSumWeightInt8Arm82(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP);
|
||||
void MNNSumWeightInt8Sme2(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -211,6 +218,28 @@ typedef void(*MNNBinaryExecInt8)(int8_t* outputRaw, const int8_t* inputRaw0, con
|
|||
constexpr int InputTileMax = 14; // same value from DynamicGemm.h, cannot include from different backend code.
|
||||
|
||||
namespace MNN {
|
||||
struct MatmulRelatedFunctions {
|
||||
// coreFunctions
|
||||
void (*MNNSumWeightInt8)(float* kernelsum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP) = nullptr;
|
||||
void (*MNNReorderWeightInt4)(uint8_t* dest, const uint8_t* source, int32_t* shape, size_t size, float* kernelsum) = nullptr;
|
||||
void (*MNNPackedMatMul)(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b) = nullptr;
|
||||
void (*MNNPackedMatMulRemain)(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b) = nullptr;
|
||||
void (*MNNGetMatMulPackMode)(int* eP, int* lP, int* hP) = nullptr;
|
||||
void (*MNNPackC4ForMatMul_A)(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el) = nullptr;
|
||||
void (*MNNPackForMatMul_B)(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose) = nullptr;
|
||||
void(*MNNGeneralIm2Col)(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack) = nullptr;
|
||||
|
||||
// int8CoreFunctions
|
||||
void(*Int8GemmKernel)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount) = nullptr;
|
||||
void(*Int8GemmKernelFast)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount) = nullptr;
|
||||
void(*MNNGetGemmUnit)(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) = nullptr;
|
||||
void(*MNNPackC4Int8ForMatMul_A)(int8_t* destOrigin, int8_t const** sourceGroup, const int32_t* info, const int32_t* el) = nullptr;
|
||||
void(*MNNGemmInt8AddBiasScale_Unit_FP16)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount) = nullptr;
|
||||
void(*MNNGemmInt8AddBiasScale_w4_Unit_FP16)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount) = nullptr;
|
||||
void(*Int8GemmKernel_W4)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount) = nullptr;
|
||||
void(*MNNSumByAxisLForMatmul_A)(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams) = nullptr;
|
||||
};
|
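A hedged sketch of how this table is meant to be consumed: an executor copies it once so its kernels stay consistent for the lifetime of the op. The selection expression below is only illustrative; the real lookup goes through CPUBackend::int8GemmFunctions(), shown further down in this diff.
auto core = MNNGetCoreFunctions();
const MatmulRelatedFunctions& table = core->supportSME2 ? core->sme2MatmulRelatedFuncions
                                                        : core->backendMatmulRelatedFunctions;
MatmulRelatedFunctions funcs = table;        // copied by value
int eP = 0, lP = 0, hP = 0;
funcs.MNNGetMatMulPackMode(&eP, &lP, &hP);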
||||
|
||||
struct CoreFunctions {
|
||||
// fp8
|
||||
void (*MNNFp32ToFp8)(uint8_t* dst, const float* src, size_t size);
|
||||
|
@ -222,11 +251,12 @@ struct CoreFunctions {
|
|||
bool supportFp16arith = false;
|
||||
bool supportSDot = false;
|
||||
bool supportI8mm = false;
|
||||
bool supportSME2 = false;
|
||||
/**MatMul Pack and Functions*/
|
||||
void(*MNNGetMatMulPackMode)(int* eP, int *lP, int* hP);
|
||||
void(*MNNGetSparseMatMulPackMode)(int* eP, int *lP, int* hP);
|
||||
void(*MNNPackC4ForMatMul_A)(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el);
|
||||
void(*MNNPackForMatMul_B)(float* dest, const float* source, size_t h, size_t l, bool transpose);
|
||||
void(*MNNPackForMatMul_B)(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose);
|
||||
void(*MNNGeneralIm2Col)(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack);
|
||||
// parameters: e, l, h, CStride, AStride, BStride
|
||||
void(*MNNPackedMatMul)(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
|
||||
|
@ -365,6 +395,9 @@ struct CoreFunctions {
|
|||
void(*MNNSumByAxisLForMatmul_A)(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams);
|
||||
void(*MNNReorderWeightInt4)(uint8_t* dest, const uint8_t* source, int32_t* shape, size_t size, float* kernelsum);
|
||||
void(*MNNSumWeightInt8)(float* kernlesum, int8_t* source, size_t outside, size_t reduceAxis, size_t hP, size_t lP);
|
||||
|
||||
MatmulRelatedFunctions backendMatmulRelatedFunctions;
|
||||
MatmulRelatedFunctions sme2MatmulRelatedFuncions;
|
||||
};
|
||||
void MNNCoreFunctionInit();
|
||||
CoreFunctions* MNNGetCoreFunctions();
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include "core/Concurrency.h"
|
||||
#include "core/TensorUtils.hpp"
|
||||
|
||||
|
||||
#define QUANT_INFO_BYTES 4
|
||||
namespace MNN {
|
||||
|
||||
|
@ -175,10 +176,10 @@ void ConvInt8TiledExecutor::packWeightAndQuantInfo(int8_t* dstbuffer, const int8
|
|||
auto huPtr = dstbuffer + hU * blockNum * (stride1 + 2 * UNIT * infoBytes);
|
||||
int scaleCount = ALIMIN(ocUp4 - hU * UNIT, UNIT);
|
||||
for (int bl = 0; bl < blockNum; ++bl) {
|
||||
auto blockPtr = huPtr + bl * (stride1 + 2 * scaleCount * infoBytes);
|
||||
auto blockPtr = huPtr + bl * (stride1 + 2 * UNIT * infoBytes);
|
||||
memcpy(blockPtr, src0 + bl * stride1 + hU * stride0, stride1);
|
||||
memcpy(blockPtr + stride1, src1 + (bl * ocUp4 + hU * UNIT) * infoBytes, scaleCount * infoBytes);
|
||||
memcpy(blockPtr + stride1 + scaleCount * infoBytes, src2 + (bl * ocUp4 + hU * UNIT) * infoBytes, scaleCount * infoBytes);
|
||||
memcpy(blockPtr + stride1 + UNIT * infoBytes, src2 + (bl * ocUp4 + hU * UNIT) * infoBytes, scaleCount * infoBytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -209,15 +210,14 @@ static void _computeReorderQuantInfo(std::shared_ptr<CPUConvolution::ResourceInt
|
|||
auto weightKernelSum = resource->mWeightKernelSum->host<float>();
|
||||
::memset(weightKernelSum, 0, resource->mWeightKernelSum->size());
|
||||
|
||||
bool blockQuantInput = (resource->mWeightKernelSum->length(0) / QUANT_INFO_BYTES == ocUp4) ? false : true;
|
||||
bool blockQuantInput = (resource->mWeightKernelSum->length(0) / QUANT_INFO_BYTES == ocUpHp) ? false : true;
|
||||
int ocDiv4 = UP_DIV(outputCount, pack);
|
||||
// dstkernelsum: [hU,blocknum,min(hP, pack)]
|
||||
// resource->mWeightKernelSum: [hU,blocknum,hP]
|
||||
if (quantCommon->asymmetric) {
|
||||
for (int i = 0; i < outputCount; ++i) {
|
||||
float accum = 0.f;
|
||||
auto ocOutside = i / HP;
|
||||
auto ocInside = i % HP;
|
||||
int remain = ALIMIN(HP, ocUp4 - ocOutside * HP);
|
||||
for (int j = 0; j < blockNum; ++j) {
|
||||
int index = i * blockNum + j;
|
||||
int srcSumIndex = ocOutside * blockNum * HP + j * HP + ocInside; // ikernelsum: [hU,blocknum,hP]
|
||||
|
@ -229,7 +229,7 @@ static void _computeReorderQuantInfo(std::shared_ptr<CPUConvolution::ResourceInt
|
|||
accum += ((ikernelSum[srcSumIndex] - blockSize * 8)* quanInfoPtr[2 * index + 1] + blockSize * quanInfoPtr[2 * index]);
|
||||
}
|
||||
if (blockQuantInput) {
|
||||
int dstSumIndex = ocOutside * blockNum * HP + j * remain + ocInside;
|
||||
int dstSumIndex = ocOutside * blockNum * HP + j * HP + ocInside;
|
||||
weightKernelSum[dstSumIndex] = accum;
|
||||
accum = 0;
|
||||
}
|
||||
|
@ -243,7 +243,6 @@ static void _computeReorderQuantInfo(std::shared_ptr<CPUConvolution::ResourceInt
|
|||
float accum = 0.f;
|
||||
auto ocOutside = i / HP;
|
||||
auto ocInside = i % HP;
|
||||
int remain = ALIMIN(HP, ocUp4 - ocOutside * HP);
|
||||
for (int j = 0; j < blockNum; ++j) {
|
||||
int index = i * blockNum + j;
|
||||
int srcSumIndex = ocOutside * blockNum * HP + j * HP + ocInside; // ikernelsum: [hU,blocknum,hP]
|
||||
|
@ -255,7 +254,7 @@ static void _computeReorderQuantInfo(std::shared_ptr<CPUConvolution::ResourceInt
|
|||
accum += ((ikernelSum[srcSumIndex] - blockSize * 8) * quanInfoPtr[index]);
|
||||
}
|
||||
if (blockQuantInput) {
|
||||
int dstSumIndex = ocOutside * blockNum * HP + j * remain + ocInside;
|
||||
int dstSumIndex = ocOutside * blockNum * HP + j * HP + ocInside;
|
||||
weightKernelSum[dstSumIndex] = accum;
|
||||
accum = 0;
|
||||
}
|
||||
|
@ -286,12 +285,17 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O
|
|||
// backend info
|
||||
auto core = static_cast<CPUBackend*>(backend)->int8Functions();
|
||||
auto gcore = static_cast<CPUBackend*>(backend)->functions();
|
||||
const int threads = static_cast<CPUBackend*>(backend)->threadNumber();
|
||||
|
||||
mRelatedFunctions = *(static_cast<CPUBackend*>(backend)->int8GemmFunctions());
|
||||
|
||||
int UNIT, SRC_UNIT, DST_XUNIT;
|
||||
core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
|
||||
mRelatedFunctions.MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
|
||||
int pack = gcore->pack;
|
||||
|
||||
// compute info
|
||||
int ocUp4 = ROUND_UP(oc, pack);
|
||||
int ocUpHp = ROUND_UP(oc, ALIMAX(UNIT, pack));
|
||||
int lU = UP_DIV(ic / blockNum, SRC_UNIT) * kernelCount;
|
||||
int scaleSize = ocUp4 * blockNum;
|
||||
std::vector<int> shape = {blockNum, UP_DIV(oc, UNIT), lU, UNIT, SRC_UNIT};
|
||||
|
@ -316,15 +320,15 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O
|
|||
}
|
||||
}
|
||||
// buffer allocate
|
||||
auto quantlen = 2 * blockNum * ROUND_UP(oc, pack) * QUANT_INFO_BYTES;
|
||||
auto quantlen = 2 * blockNum * ROUND_UP(oc, UNIT) * QUANT_INFO_BYTES;
|
||||
auto weightlen = shape[0] * shape[1] * shape[2] * shape[3] * shape[4];
|
||||
mResourceInt8->mWeightInt8.reset(Tensor::createDevice<uint8_t>({weightlen + quantlen}));
|
||||
mResourceInt8->mOriginBias.reset(Tensor::createDevice<int32_t>({ocUp4})); // float
|
||||
mResourceInt8->mOriginBias.reset(Tensor::createDevice<int32_t>({ocUpHp})); // float
|
||||
auto dynamicOption = static_cast<CPUBackend*>(backend)->getRuntime()->hint().dynamicQuantOption; // input ic block quant.
|
||||
if (dynamicOption != 2) {
|
||||
mResourceInt8->mWeightKernelSum.reset(Tensor::createDevice<uint8_t>({QUANT_INFO_BYTES * ocUp4}));
|
||||
mResourceInt8->mWeightKernelSum.reset(Tensor::createDevice<uint8_t>({QUANT_INFO_BYTES * ocUpHp}));
|
||||
} else {
|
||||
mResourceInt8->mWeightKernelSum.reset(Tensor::createDevice<uint8_t>({blockNum * QUANT_INFO_BYTES * ocUp4}));
|
||||
mResourceInt8->mWeightKernelSum.reset(Tensor::createDevice<uint8_t>({blockNum * QUANT_INFO_BYTES * ocUpHp}));
|
||||
}
|
||||
|
||||
auto res = backend->onAcquireBuffer(mResourceInt8->mOriginBias.get(), Backend::STATIC);
|
||||
|
@ -341,7 +345,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O
|
|||
}
|
||||
|
||||
// read weight, weight's scale&bias, convolution bias
|
||||
::memset(mResourceInt8->mOriginBias->host<float>(), 0, ocUp4 * sizeof(float));
|
||||
::memset(mResourceInt8->mOriginBias->host<float>(), 0, ocUpHp * sizeof(float));
|
||||
if (!isDynamicQuant) {
|
||||
mResourceInt8->mDynamicQuant = false;
|
||||
|
||||
|
@ -375,7 +379,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O
|
|||
// reorder weight&scale&bias
|
||||
int bytes = 4;
|
||||
auto resourceInt8 = mResourceInt8;
|
||||
auto reorderFunction = [weightNotPacked, weightlen, oc, ic, kernelCount, UNIT, SRC_UNIT, shape, resourceInt8, bytes, L, ocUp4, quantlen, scaleAndBias]()
|
||||
auto reorderFunction = [weightNotPacked, weightlen, oc, ic, kernelCount, UNIT, SRC_UNIT, shape, resourceInt8, bytes, L, ocUp4, scaleAndBias, pack]()
|
||||
{
|
||||
AutoStorage<uint8_t> weightReordered(weightlen);
|
||||
if (!weightReordered.get()) {
|
||||
|
@ -384,22 +388,22 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O
|
|||
}
|
||||
int32_t info[6] = {1, oc, ic, kernelCount, UNIT, SRC_UNIT};
|
||||
ConvInt8TiledExecutor::reorderWeight((uint8_t*)weightReordered.get(), (uint8_t*)weightNotPacked, info, 0);
|
||||
int32_t params[6] = {shape[0], shape[1], shape[2], shape[3], shape[4], quantlen/ (2 * QUANT_INFO_BYTES * 1)};
|
||||
int32_t params[6] = {shape[0], shape[1], shape[2], shape[3], shape[4], ROUND_UP(oc, pack)};
|
||||
ConvInt8TiledExecutor::packWeightAndQuantInfo(resourceInt8->mWeightInt8->host<int8_t>(), (int8_t*)weightReordered.get(), (int8_t*)scaleAndBias.get(), params, QUANT_INFO_BYTES);
|
||||
return 0;
|
||||
};
|
||||
static_cast<CPUBackend*>(backend)->enqueueTask(reorderFunction);
|
||||
|
||||
// gemmInt8 kernel
|
||||
mGemmKernel = core->Int8GemmKernel;
|
||||
mGemmKernel = mRelatedFunctions.Int8GemmKernel;
|
||||
#ifdef MNN_USE_SSE
|
||||
int actBits = convOp->symmetricQuan()->nbits();
|
||||
if (actBits <= 7) {
|
||||
mGemmKernel = core->Int8GemmKernelFast;
|
||||
mGemmKernel = mRelatedFunctions.Int8GemmKernelFast;
|
||||
}
|
||||
#else
|
||||
if(convOp->symmetricQuan()->method() == QuantizeAlgo_OVERFLOW_AWARE){
|
||||
mGemmKernel = core->Int8GemmKernelFast;
|
||||
mGemmKernel = mRelatedFunctions.Int8GemmKernelFast;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -409,11 +413,12 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O
|
|||
// dynamic quant
|
||||
bool directReadInt4weight = (kernelCount == 1 && ROUND_UP(oc, UNIT) == oc && ROUND_UP(ic, SRC_UNIT) == ic);
|
||||
auto target = mResourceInt8;
|
||||
auto funcs = mRelatedFunctions;
|
||||
// Save bias
|
||||
if (convOp->bias()) {
|
||||
::memcpy(mResourceInt8->mOriginBias->host<float>(), convOp->bias()->data(), oc * sizeof(float));
|
||||
}
|
||||
auto function = [shape, UNIT, SRC_UNIT, quanCommon, weightlen, scaleSize, directReadInt4weight, blockNum, ic, oc, quantlen, kernelCount, pack, convOp, gcore, target]() -> int {
|
||||
auto function = [funcs, shape, UNIT, SRC_UNIT, DST_XUNIT, quanCommon, weightlen, scaleSize, directReadInt4weight, blockNum, ic, oc, kernelCount, pack, convOp, gcore, target]() -> int {
|
||||
auto sh = shape;
|
||||
AutoStorage<int8_t> weightReordered(weightlen);
|
||||
AutoStorage<int8_t> reorderedQuantInfo(2 * scaleSize * QUANT_INFO_BYTES);
|
||||
|
@ -427,7 +432,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O
|
|||
auto srcPtr = (uint8_t*)quanCommon->weight.get();
|
||||
auto dstPtr = (uint8_t*)weightReordered.get();
|
||||
::memset(dstPtr, 0, weightlen);
|
||||
gcore->MNNReorderWeightInt4(dstPtr, srcPtr, sh.data(), sh.size(), (float*)kernelsum.get());
|
||||
funcs.MNNReorderWeightInt4(dstPtr, srcPtr, sh.data(), sh.size(), (float*)kernelsum.get());
|
||||
} else { // int4 weight but oc/ic not packed
|
||||
auto weightLength = quanCommon->weight.size();
|
||||
int blocksize = ic * kernelCount / blockNum;
|
||||
|
@ -452,7 +457,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O
|
|||
MNN_ERROR("Weight reorder memory not enough!\n");
|
||||
return -1;
|
||||
}
|
||||
reorderWeight(packedInt8weight.get(), (uint8_t*)tmpWeight.data(), info, 0, (float*)kernelsum.get(), gcore->MNNSumWeightInt8);
|
||||
reorderWeight(packedInt8weight.get(), (uint8_t*)tmpWeight.data(), info, 0, (float*)kernelsum.get(), funcs.MNNSumWeightInt8);
|
||||
// pack two int4 to int8
|
||||
int leng = weightlen * 2;
|
||||
auto srcint4Ptr = (uint8_t*)packedInt8weight.get();
|
||||
|
@ -463,14 +468,21 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O
|
|||
auto src0 = srcint4Ptr + i * permuteUnit;
|
||||
auto dst0 = dstint4Ptr + i * halfPermuteStride;
|
||||
for (int j = 0; j < halfPermuteStride; ++j) {
|
||||
int s0 = src0[j];
|
||||
int s1 = src0[j + halfPermuteStride];
|
||||
int d = (s0) * 16 + (s1);
|
||||
int s0, s1, d;
|
||||
if (UNIT == 16 && SRC_UNIT == 4 && DST_XUNIT == 16) { // SME2
|
||||
s0 = src0[2 * j + 0];
|
||||
s1 = src0[2 * j + 1];
|
||||
d = s0 + (s1) * 16;
|
||||
} else {
|
||||
s0 = src0[j];
|
||||
s1 = src0[j + halfPermuteStride];
|
||||
d = (s0) * 16 + (s1);
|
||||
}
|
||||
dst0[j] = d;
|
||||
}
|
||||
}
|
||||
} else { // int8 weight
|
||||
reorderWeight((uint8_t*)weightReordered.get(), (uint8_t*)quanCommon->weight.get(), info, 0, (float*)kernelsum.get(), gcore->MNNSumWeightInt8);
|
||||
reorderWeight((uint8_t*)weightReordered.get(), (uint8_t*)quanCommon->weight.get(), info, 0, (float*)kernelsum.get(), funcs.MNNSumWeightInt8);
|
||||
}
|
||||
}
|
||||
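// Hedged illustration of the two nibble orders above for the 4-bit pair (s0 = 3, s1 = 10):
//   SME2 branch (UNIT == 16, SRC_UNIT == 4, DST_XUNIT == 16): d = 3 + 10 * 16 = 163,
//     s0 in the low nibble, read from adjacent source bytes src0[2*j] and src0[2*j + 1];
//   default branch: d = 3 * 16 + 10 = 58, s0 in the high nibble, read from the two
//     halves of the permute unit, src0[j] and src0[j + halfPermuteStride].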
/* 2. compute and order dequant scale&bias */
|
||||
|
@ -480,7 +492,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O
|
|||
}
|
||||
_computeReorderQuantInfo(target, quanCommon, oc, kernelCount * ic, pack, reorderedQuantInfo, (float*)kernelsum.get(), UNIT, notConvertInt4ToInt8);
|
||||
/* 3. put weight and quantInfo together */
|
||||
int32_t params[6] = {shape[0], shape[1], shape[2], shape[3], shape[4], quantlen / (2 * QUANT_INFO_BYTES * blockNum)};
|
||||
int32_t params[6] = {shape[0], shape[1], shape[2], shape[3], shape[4], ROUND_UP(oc, pack)};
|
||||
ConvInt8TiledExecutor::packWeightAndQuantInfo(target->mWeightInt8->host<int8_t>(), (int8_t*)weightReordered.get(), reorderedQuantInfo.get(), params, QUANT_INFO_BYTES);
|
||||
|
||||
return 0;
|
||||
|
@ -508,16 +520,11 @@ bool DenseConvInt8TiledExecutor::onClone(Backend* bn, const Op* op, Execution**
|
|||
return true;
|
||||
}
|
||||
|
||||
void DenseConvInt8TiledExecutor::getPackParameter(int* Unit, int* srcUnit, int* DestUnit, const CoreInt8Functions* core) {
|
||||
core->MNNGetGemmUnit(Unit, srcUnit, DestUnit);
|
||||
}
|
||||
|
||||
|
||||
ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
|
||||
// Initialize.
|
||||
mUseBatchQuan = false;
|
||||
mIm2ColBasedInt8 = true;
|
||||
|
||||
auto option = static_cast<CPUBackend*>(backend())->getRuntime()->hint().dynamicQuantOption;
|
||||
int batch = inputs[0]->batch();
|
||||
int inC = inputs[0]->channel();
|
||||
|
@ -526,8 +533,12 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector<Tensor*>& input
|
|||
auto planeSize = output->width() * output->height() * output->batch();
|
||||
auto core = static_cast<CPUBackend*>(backend())->int8Functions();
|
||||
auto gcore =static_cast<CPUBackend*>(backend())->functions();
|
||||
const int threads = static_cast<CPUBackend*>(backend())->threadNumber();
|
||||
|
||||
mRelatedFunctions = *(static_cast<CPUBackend*>(backend())->int8GemmFunctions());
|
||||
|
||||
int UNIT, SRC_UNIT, DST_XUNIT;
|
||||
core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
|
||||
mRelatedFunctions.MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
|
||||
int kernelCount = mCommon->kernelY() * mCommon->kernelX();
|
||||
bool fastway = (kernelCount == 1) && (output->width() == inputs[0]->width()) && (output->height() == inputs[0]->height()) && (mCommon->strideX() * mCommon->strideY()) == 1;
|
||||
if (inputPlane > 1) {
|
||||
|
@ -563,9 +574,9 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector<Tensor*>& input
|
|||
mIm2ColBasedInt8 = true;
|
||||
mUseBatchQuan = false;
|
||||
}
|
||||
ConvolutionTiledExecutor::setIm2ColParameter(mIm2ColParamter, mCommon, inputs[0], outputs[0], mPadX, mPadY, gcore, core);
|
||||
int matmulUnits[3] = {UNIT, SRC_UNIT, DST_XUNIT};
|
||||
ConvolutionTiledExecutor::setIm2ColParameter(mIm2ColParamter, mCommon, inputs[0], outputs[0], mPadX, mPadY, gcore, core, gcore->pack, matmulUnits);
|
||||
// input scale buffer
|
||||
const int threads = static_cast<CPUBackend*>(backend())->threadNumber();
|
||||
|
||||
// Im2col info
|
||||
int im2colBytes = 1;
|
||||
|
@ -668,22 +679,22 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector<Tensor*>& input
|
|||
|
||||
#ifdef MNN_LOW_MEMORY
|
||||
{ // Dynamic Quant kernels
|
||||
mGemmKernel = core->Int8GemmKernel;
|
||||
mGemmKernel = mRelatedFunctions.Int8GemmKernel;
|
||||
if (mResourceInt8->mActBits == 4) {
|
||||
mGemmKernel = core->Int8GemmKernel_W4;
|
||||
mGemmKernel = mRelatedFunctions.Int8GemmKernel_W4;
|
||||
}
|
||||
mQuantFunc = core->MNNFloat2Int8;
|
||||
if (gcore->bytes == 2 && gcore->pack == 8) {
|
||||
mGemmKernel = core->MNNGemmInt8AddBiasScale_Unit_FP16;
|
||||
mGemmKernel = mRelatedFunctions.MNNGemmInt8AddBiasScale_Unit_FP16;
|
||||
if (mResourceInt8->mActBits == 4) {
|
||||
mGemmKernel = core->MNNGemmInt8AddBiasScale_w4_Unit_FP16;
|
||||
mGemmKernel = mRelatedFunctions.MNNGemmInt8AddBiasScale_w4_Unit_FP16;
|
||||
}
|
||||
mQuantFunc = core->DynamicQuanInput_ARM82;
|
||||
mQuantAndReorderFunc = core->DynamicQuanInputAndReorder_ARM82;
|
||||
|
||||
}
|
||||
// A axisSum kernel
|
||||
mSumByAxisLFunc = gcore->MNNSumByAxisLForMatmul_A;
|
||||
mSumByAxisLFunc = mRelatedFunctions.MNNSumByAxisLForMatmul_A;
|
||||
}
|
||||
|
||||
mInputBlockNum = (option == 2) ? mBlockNum : 1;
|
||||
|
@ -804,8 +815,8 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inpu
|
|||
auto dynamicOption = static_cast<CPUBackend*>(backend())->getRuntime()->hint().dynamicQuantOption;
|
||||
|
||||
int UNIT, SRC_UNIT, DST_XUNIT;
|
||||
core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
|
||||
auto blitProc = core->MNNPackC4Int8ForMatMul_A;
|
||||
mRelatedFunctions.MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
|
||||
auto blitProc = mRelatedFunctions.MNNPackC4Int8ForMatMul_A;
|
||||
const int plane = output->batch() * mIm2ColParamter.oh * mIm2ColParamter.ow;
|
||||
const int batch = input->batch();
|
||||
const int PackUnit = gcore->pack;
|
||||
|
@ -960,6 +971,7 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inpu
|
|||
sumParams.kernelxy = kxky;
|
||||
sumParams.LU = UP_DIV(inputchannel, SRC_UNIT);
|
||||
sumParams.inputBlock = (mInputBlockNum > 1) ? 1 : 0;
|
||||
std::vector<float> fakeInputScales(DST_XUNIT, 1.f);
|
||||
|
||||
auto tileSplitFunction = [&](int tId, int eStartIndex, int eEndIndex, int estep) {
|
||||
auto ocDivThread = ocDiv4;
|
||||
|
@ -979,20 +991,30 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inpu
|
|||
inputBias = inputScale + mBatchQuantInfo->stride(0) / 2;
|
||||
ptrInputBias = inputBias;
|
||||
}
|
||||
} else {
|
||||
inputScale = (uint8_t*)fakeInputScales.data();
|
||||
ptrInputScale = inputScale;
|
||||
}
|
||||
if (mBlockNum > 1) {
|
||||
accumbuff = reinterpret_cast<float*>(mAccumBuffer->host<int8_t>() + tId * mAccumBuffer->stride(0) * sizeof(int32_t));
|
||||
}
|
||||
float* ptrY = nullptr;
|
||||
if ((dstBytes != 1)) {
|
||||
if (dstBytes != 1) {
|
||||
ptrY = mResourceInt8->mWeightKernelSum->host<float>();
|
||||
}
|
||||
QuanPostTreatParameters quanParam;
|
||||
quanParam.blockNum = mBlockNum;
|
||||
quanParam.weightKernelSum = ptrY;
|
||||
int32_t indices[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
|
||||
quanParam.indices = indices;
|
||||
if (dstBytes != 1) {
|
||||
quanParam.useInt8 = 0;
|
||||
quanParam.fp32minmax = reluPtr;
|
||||
#ifdef MNN_USE_SSE
|
||||
if (!mBatchQuantInfo.get()) {
|
||||
quanParam.weightKernelSum = nullptr;
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
quanParam.maxValue = mMutableResource->mClampMax;
|
||||
if (mResourceInt8->mRelu) {
|
||||
|
@ -1045,7 +1067,7 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inpu
|
|||
memset(im2colDst, 0, mTempIm2ColBuffer->stride(0));
|
||||
}
|
||||
info[2] = realDstCount;
|
||||
gcore->MNNGeneralIm2Col((float*)im2colDst, (float const**)srcPtr, info, el, SRC_UNIT, PackUnit); // im2colDst: [lu, realDstCount, lp]
|
||||
mRelatedFunctions.MNNGeneralIm2Col((float*)im2colDst, (float const**)srcPtr, info, el, SRC_UNIT, PackUnit); // im2colDst: [lu, realDstCount, lp]
|
||||
}
|
||||
ptrInputScale = mBatchQuantInfo->host<uint8_t>() + tId * mBatchQuantInfo->stride(0);
|
||||
if (dynamicOption == 2) {
|
||||
|
@ -1079,7 +1101,7 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inpu
|
|||
#endif
|
||||
if (mResourceInt8->mWeightAsymmetricQuant) {
|
||||
MNN_ASSERT(mBatchQuantInfo.get() && mBatchQuantInfo->host<float>());
|
||||
gcore->MNNSumByAxisLForMatmul_A(xKernelSumPtrTid, im2colDst, (float*)ptrInputScale, realDstCount, sumParams);
|
||||
mRelatedFunctions.MNNSumByAxisLForMatmul_A(xKernelSumPtrTid, im2colDst, (float*)ptrInputScale, realDstCount, sumParams);
|
||||
} else {
|
||||
memset(xKernelSumPtrTid, 0, mBlockNum * DST_XUNIT * mIm2ColCount * QUANT_INFO_BYTES);
|
||||
}
|
||||
|
@ -1094,7 +1116,6 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inpu
|
|||
}
|
||||
quanParam.srcKernelSum = ptrX;
|
||||
mGemmKernel(outputInTilePtr, im2colDst, weightPtrTid, blockL, dstZStep * dstBytes, ocDivThread, &quanParam, step);
|
||||
|
||||
ptrX += (step * mBlockNum);
|
||||
realDstCount-=step;
|
||||
outputInTilePtr += DST_XUNIT * PackUnit * dstBytes;
|
||||
|
@ -1155,7 +1176,7 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inpu
|
|||
info[0] = number;
|
||||
info[2] = work;
|
||||
if (number > 0) { // im2col
|
||||
gcore->MNNGeneralIm2Col((float*)im2colDstTmp, (float const**)srcPtr, info, el, SRC_UNIT, PackUnit); // im2colDst: [lu, realDstCount, lp]
|
||||
mRelatedFunctions.MNNGeneralIm2Col((float*)im2colDstTmp, (float const**)srcPtr, info, el, SRC_UNIT, PackUnit); // im2colDst: [lu, realDstCount, lp]
|
||||
}
|
||||
if (mUseBatchQuan || dynamicOption == 2) {
|
||||
if (dynamicOption == 2) {
|
||||
|
@ -1192,7 +1213,7 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inpu
|
|||
#endif
|
||||
if (mResourceInt8->mWeightAsymmetricQuant) {
|
||||
MNN_ASSERT(mBatchQuantInfo.get() && mBatchQuantInfo->host<float>());
|
||||
gcore->MNNSumByAxisLForMatmul_A(xKernelSumPtr, im2colDst, mBatchQuantInfo->host<float>(), plane, sumParams);
|
||||
mRelatedFunctions.MNNSumByAxisLForMatmul_A(xKernelSumPtr, im2colDst, mBatchQuantInfo->host<float>(), plane, sumParams);
|
||||
} else {
|
||||
memset(xKernelSumPtr, 0, mTileCount * mBlockNum * DST_XUNIT * mIm2ColCount * QUANT_INFO_BYTES);
|
||||
}
|
||||
|
@ -1211,9 +1232,16 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inpu
|
|||
quanParam.blockNum = mBlockNum;
|
||||
quanParam.weightKernelSum = ptrY;
|
||||
quanParam.biasFloat = reinterpret_cast<float*>(biasPtr + ocIndex * 4);
|
||||
int32_t indices[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
|
||||
quanParam.indices = indices;
|
||||
if (dstBytes != 1) {
|
||||
quanParam.useInt8 = 0;
|
||||
quanParam.fp32minmax = reluPtr;
|
||||
#ifdef MNN_USE_SSE
|
||||
if (!mBatchQuantInfo.get()) {
|
||||
quanParam.weightKernelSum = nullptr;
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
quanParam.maxValue = mMutableResource->mClampMax;
|
||||
if (mResourceInt8->mRelu) {
|
||||
|
@ -1230,6 +1258,8 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector<Tensor*>& inpu
|
|||
if (dynamicOption == 2) {
|
||||
inputBias = inputScale + mInputBlockNum * plane * QUANT_INFO_BYTES;
|
||||
}
|
||||
} else {
|
||||
inputScale = (uint8_t*)fakeInputScales.data();
|
||||
}
|
||||
if (mBlockNum > 1) {
|
||||
accumbuff = reinterpret_cast<float*>(mAccumBuffer->host<int8_t>() + tId * mAccumBuffer->stride(0) * sizeof(int32_t));
|
||||
|
|
|
@ -23,7 +23,6 @@ public:
|
|||
virtual ~ConvInt8TiledExecutor();
|
||||
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
|
||||
virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
|
||||
virtual void getPackParameter(int* Unit, int* SrcUnit, int* DestUnit, const CoreInt8Functions* core) = 0;
|
||||
static void packWeightAndQuantInfo(int8_t* dstbuffer, const int8_t* weight, const int8_t* quantInfo, int32_t* info, int infoBytes = 4);
|
||||
static void reorderWeight(uint8_t* dst, const uint8_t* src, int32_t* info, int32_t initval = 0, float* kernelsum = nullptr, weightSummerFuncion summerFunc = nullptr);
|
||||
static void initializeConvInt8QuantInfo(std::shared_ptr<CPUConvolution::ResourceInt8>& resourceInt8, const Convolution2D* conv2D);
|
||||
|
@ -56,7 +55,6 @@ public:
|
|||
virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
|
||||
virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
|
||||
virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;
|
||||
void getPackParameter(int* Unit, int* SrcUnit, int* DestUnit, const CoreInt8Functions* core) override;
|
||||
private:
|
||||
DenseConvInt8TiledExecutor(Backend* backend, const Op* op, const DenseConvInt8TiledExecutor& exe);
|
||||
|
||||
|
@ -84,6 +82,7 @@ private:
|
|||
bool mIm2ColBasedInt8;
|
||||
int mSizeInputBlockQuant;
|
||||
bool mToFuseInputbias2Bias;
|
||||
MatmulRelatedFunctions mRelatedFunctions;
|
||||
};
|
||||
|
||||
} // namespace MNN
|
||||
|
|
|
@ -519,6 +519,7 @@ ErrorCode ConvInt8Winograd::WinoExecution::onExecute(const std::vector<Tensor *>
|
|||
auto weight = mWinoResource->weight->host<int8_t>();
|
||||
std::vector<float> xkernelSum(DST_XUNIT, 0);
|
||||
std::vector<float> wKernelSum(dc_4 * pack, 0);
|
||||
std::vector<float> fakeInputScale(DST_XUNIT, 1.f);
|
||||
std::vector<float> reluThred = {-std::numeric_limits<float>().max(), std::numeric_limits<float>().max()};
|
||||
|
||||
auto tFunction = [&](int tId) {
|
||||
|
@ -560,7 +561,7 @@ ErrorCode ConvInt8Winograd::WinoExecution::onExecute(const std::vector<Tensor *>
|
|||
|
||||
quanParam.biasFloat = (mWinoResource->offsets->host<float>() + i * mWinoResource->offsets->stride(0));
|
||||
quanParam.scale = mWinoResource->scales->host<float>() + i * dc_4 * pack;
|
||||
quanParam.inputScale = nullptr;
|
||||
quanParam.inputScale = fakeInputScale.data();
|
||||
quanParam.bias = nullptr;
|
||||
quanParam.blockNum = 1;
|
||||
gemmFunc((int8_t*)_dstFloatPtr, _srcInt8Ptr, _weightInt8Ptr, mTempInputBuffer->length(2), xC * pack * sizeof(float), dc_4, &quanParam, DST_XUNIT);
|
||||
|
|
|
@ -52,10 +52,10 @@ Convolution1x1Strassen::Convolution1x1Strassen(const Convolution2DCommon *common
|
|||
return;
|
||||
}
|
||||
core->MNNFp32ToLowp(originWeight, tempTensor->host<int16_t>(), outputCount * mSrcCount);
|
||||
core->MNNPackForMatMul_B(mResource->mWeight->host<float>(), tempTensor->host<float>(), outputCount, mSrcCount, true);
|
||||
core->MNNPackForMatMul_B(mResource->mWeight->host<float>(), tempTensor->host<float>(), outputCount, 1, mSrcCount, true);
|
||||
b->onReleaseBuffer(tempTensor.get(), Backend::STATIC);
|
||||
} else {
|
||||
core->MNNPackForMatMul_B(mResource->mWeight->host<float>(), originWeight, outputCount, mSrcCount, true);
|
||||
core->MNNPackForMatMul_B(mResource->mWeight->host<float>(), originWeight, outputCount, 1, mSrcCount, true);
|
||||
}
|
||||
}
|
||||
Convolution1x1Strassen::Convolution1x1Strassen(std::shared_ptr<CPUConvolution::Resource> resource, const Convolution2DCommon *common, Backend* b) : CPUConvolution(common, b) {
|
||||
|
|
|
@ -114,7 +114,7 @@ ErrorCode ConvolutionPackFreeWinograd::onExecute(const std::vector<Tensor *> &in
|
|||
|
||||
std::vector<size_t> parameters(7);
|
||||
parameters[0] = eRemain * bytes;
|
||||
parameters[1] = input->channel();
|
||||
parameters[1] = ROUND_UP(input->channel(), lPack);
|
||||
parameters[2] = output->channel();
|
||||
parameters[3] = ePack * pack * bytes;
|
||||
parameters[4] = 0;
|
||||
|
|
|
@ -281,7 +281,7 @@ ErrorCode ConvolutionPackWinograd::onResize(const std::vector<Tensor *> &inputs,
|
|||
}
|
||||
int eRemain = (tFin-tSta) % ePack;
|
||||
std::vector<size_t> parameters(6);
|
||||
parameters[1] = input->channel();
|
||||
parameters[1] = ROUND_UP(input->channel(), lPack);
|
||||
parameters[2] = output->channel();
|
||||
parameters[4] = 0;
|
||||
parameters[5] = 0;
|
||||
|
|
|
@ -96,7 +96,8 @@ void ConvolutionTiledExecutor:: setIm2ColParameter(ConvolutionCommon::Im2ColPara
|
|||
if (pack == 0) {
|
||||
pack = floatCore->pack;
|
||||
}
|
||||
|
||||
int EP, LP, HP;
|
||||
floatCore->MNNGetMatMulPackMode(&EP, &LP, &HP);
|
||||
const auto kernelCount = convCommon->kernelX() * convCommon->kernelY();
|
||||
|
||||
dstIm2ColParamter.dilateX = convCommon->dilateX();
|
||||
|
@ -116,8 +117,8 @@ void ConvolutionTiledExecutor:: setIm2ColParameter(ConvolutionCommon::Im2ColPara
|
|||
dstIm2ColParamter.srcZStep = input->stride(1) * pack * input->batch();
|
||||
dstIm2ColParamter.srcYStep = input->stride(2) * pack;
|
||||
dstIm2ColParamter.packCUnit = pack;
|
||||
dstIm2ColParamter.ic = input->channel();
|
||||
dstIm2ColParamter.icup4 = input->channel(); // for float im2col.
|
||||
dstIm2ColParamter.ic = ROUND_UP(input->channel(), LP);
|
||||
dstIm2ColParamter.icup4 = ROUND_UP(input->channel(), ALIMIN(LP, pack)); // for float im2col.
|
||||
if (nullptr != int8Core) {
|
||||
// Compute Int8 Info and align ic
|
||||
int UNIT, SRC_UNIT, DynamicDestUnit;
|
||||
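A worked example of the new rounding (values illustrative): with input->channel() = 17, LP = 4 and pack = 4,
//   dstIm2ColParamter.ic    = ROUND_UP(17, 4)            = 20
//   dstIm2ColParamter.icup4 = ROUND_UP(17, ALIMIN(4, 4)) = 20
// previously both stayed at 17; the padded values follow the lP-aligned layout the packed matmul kernels expect.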
|
|
|
@ -24,7 +24,7 @@ namespace MNN {
|
|||
|
||||
void DenseConvolutionTiledExecutor::initWeight(float *dest, const float *source, float* cache, int depth, int outputCount, int kernelSize, const CoreFunctions* function) {
|
||||
ConvolutionTiledExecutor::initWeight(source, cache, depth, outputCount, kernelSize, function);
|
||||
function->MNNPackForMatMul_B(dest, cache, outputCount, kernelSize * depth, true);
|
||||
function->MNNPackForMatMul_B(dest, cache, outputCount, kernelSize, depth, true);
|
||||
|
||||
}
|
||||
bool DenseConvolutionTiledExecutor::initQuantizeResource(std::shared_ptr<ConvolutionCommon::Int8Common> int8Info, std::shared_ptr<CPUConvolution::Resource> resource, int hU, int hP, int lU, int lP, int outputCount, int srcChannel, int kernelSize, int bytes) {
|
||||
|
@ -177,7 +177,7 @@ DenseConvolutionTiledExecutor::DenseConvolutionTiledExecutor(const Convolution2D
|
|||
auto srcCount = (int)originWeightSize / outputCount / common->kernelX() / common->kernelY();
|
||||
auto lSize = srcCount * common->kernelX() * common->kernelY();
|
||||
auto hU = UP_DIV(outputCount, hP);
|
||||
auto lU = UP_DIV(lSize, lP);
|
||||
auto lU = UP_DIV(srcCount, lP) * common->kernelX() * common->kernelY();
|
||||
if (useInt8Weight) {
|
||||
// Quantize weight to int8
|
||||
auto allocSuccess = DenseConvolutionTiledExecutor::initQuantizeResource(int8Info, mResource, hU, hP, lU, lP, outputCount, srcCount, common->kernelX() * common->kernelY(), bytes);
|
||||
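A worked example of the lU change above: srcCount = 3, 3x3 kernel, lP = 4.
//   old: lU = UP_DIV(3 * 9, 4) = UP_DIV(27, 4) = 7
//   new: lU = UP_DIV(3, 4) * 9 = 1 * 9         = 9
// each kernel position now gets its own lP-aligned slice, matching the
// L = ROUND_UP(ic, lP) * kernelX * kernelY value used at resize time.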
|
@ -280,7 +280,7 @@ ErrorCode ConvolutionTiledExecutorMultiInput::onExecute(const std::vector<Tensor
|
|||
MNNTranspose32Bit((int32_t*)dO, (const int32_t*)sO, &dims[0]);
|
||||
}
|
||||
}
|
||||
function->MNNPackForMatMul_B(mTempWeight->host<float>(), mTempWeightCache->host<float>(), outputCount, kernelSize * depth, true);
|
||||
function->MNNPackForMatMul_B(mTempWeight->host<float>(), mTempWeightCache->host<float>(), outputCount, kernelSize, depth, true);
|
||||
return mProxy->onExecute(mInputs, outputs);
|
||||
}
|
||||
ErrorCode ConvolutionTiledExecutorMultiInput::onResize(const std::vector<Tensor*>& inputs,
|
||||
|
@ -292,7 +292,7 @@ ErrorCode ConvolutionTiledExecutorMultiInput::onResize(const std::vector<Tensor*
|
|||
function->MNNGetMatMulPackMode(&eP, &lP, &hP);
|
||||
auto kernelSize = depth * inputs[1]->stride(1);
|
||||
mTempWeight.reset(Tensor::createDevice<float>(
|
||||
{UP_DIV(outputCount, hP), UP_DIV(kernelSize, lP), lP * hP}));
|
||||
{UP_DIV(outputCount, hP), UP_DIV(depth, lP) * inputs[1]->stride(1), lP * hP}));
|
||||
if (function->bytes < 4) {
|
||||
mTempWeightCache.reset(Tensor::createDevice<int32_t>({2, outputCount * kernelSize}));
|
||||
} else {
|
||||
|
@@ -304,10 +304,11 @@ ErrorCode ConvolutionTiledExecutorMultiInput::onResize(const std::vector<Tensor*
if (!res) {
    return OUT_OF_MEMORY;
}
if (inputs.size() > 2 && inputs[2]->elementSize() % function->pack == 0) {
if (inputs.size() > 2 && inputs[2]->elementSize() % hP == 0) {
    mInputs = {inputs[0], mTempWeight.get(), inputs[2]};
} else {
    mTempBias.reset(Tensor::createDevice<float>({UP_DIV(outputCount, function->pack) * function->pack}));
    auto hPackedSize = ALIMAX(hP, function->pack);
    mTempBias.reset(Tensor::createDevice<float>({UP_DIV(outputCount, hPackedSize) * hPackedSize}));
    backend()->onAcquireBuffer(mTempBias.get(), Backend::DYNAMIC);
    mInputs = {inputs[0], mTempWeight.get(), mTempBias.get()};
}

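The resize above now sizes the temporary bias buffer by ALIMAX(hP, function->pack) rather than the pack unit alone. A hedged sketch of the idea (helper name makePaddedBias and the standalone signature are mine, not MNN's): the GEMM output tile writes hP output channels at a time, so the bias buffer has to be padded, and zero-filled, to the larger of the two units.

// Hedged sketch, not MNN's code.
#include <algorithm>
#include <vector>

std::vector<float> makePaddedBias(const float* bias, int outputCount, int hP, int pack) {
    const int unit = std::max(hP, pack);                 // ALIMAX(hP, pack)
    const int padded = (outputCount + unit - 1) / unit * unit;
    std::vector<float> buf(padded, 0.0f);                // tail stays zero
    std::copy(bias, bias + outputCount, buf.begin());
    return buf;
}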
@@ -445,7 +446,7 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector<Tensor*>& inputs
const uint8_t* dequantBias = nullptr;
auto ic = input->channel();
auto icC4 = UP_DIV(ic, unit);
auto L = ic * mCommon->kernelY() * mCommon->kernelX();
auto L = ROUND_UP(ic, lP) * mCommon->kernelY() * mCommon->kernelX();
auto tileC = std::max(unit, hP);
int blockSize = L;
int blockNum = 1;

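A quick numeric check of the padded L above, with assumed values rather than any real model: for ic = 3, lP = 4 and a 3x3 kernel, the unpadded L would be 3 * 9 = 27, while ROUND_UP(ic, lP) * kY * kX = 4 * 3 * 3 = 36, i.e. each kernel position now carries a full lP-aligned channel slice, matching the per-kernel-position packing shown earlier in this diff.

// Illustrative compile-time check only.
constexpr int roundUpInt(int v, int m) { return (v + m - 1) / m * m; }
static_assert(roundUpInt(3, 4) * 3 * 3 == 36, "per-kernel-position padded L");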
@@ -677,7 +678,14 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector<Tensor*>& inputs
if (number > 0) {
    packA((float *)gemmBuffer, srcPtr, info, el);
}

/*
for (int kk=0; kk < mIm2ColParameters.kernelX * mIm2ColParameters.kernelY; ++kk) {
    for (int xx=0; xx < ROUND_UP(input->channel(), lP) * eP; ++xx) {
        printf("%f ", ((__fp16*)gemmBuffer)[kk * ROUND_UP(input->channel(), lP) * eP + xx]);
        if (xx % (eP * lP) == (eP * lP -1)) printf("\n");
    }
}
*/
int finishedL = 0;
int wquantStride = 0;
int8_t* _weightPtr = reinterpret_cast<int8_t*>(weightPtr);

@@ -40,6 +40,7 @@ IdstConvolutionInt8::IdstConvolutionInt8(const Convolution2DCommon* convOp, Back
core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT);
int PackUnit = static_cast<CPUBackend*>(b)->functions()->pack;
int ocUp4 = ROUND_UP(biasSize, PackUnit);
int ocUpHp = ROUND_UP(biasSize, UNIT);
mBias.reset(ocUp4);
mBias.clear();
auto biasDest = mBias.get();

@@ -66,9 +67,9 @@ IdstConvolutionInt8::IdstConvolutionInt8(const Convolution2DCommon* convOp, Back
auto srcCount = mSrcCount;
std::vector<int> shape;
shape = {1, UP_DIV(outputCount, UNIT), UP_DIV(srcCount, SRC_UNIT) * kernelCount, UNIT, SRC_UNIT};
mFakeBias.reset(Tensor::createDevice<float>({ocUp4}));
mFakeBias.reset(Tensor::createDevice<float>({ocUpHp}));
int weightlen = shape[0] * shape[1] * shape[2] * shape[3] * shape[4];
int quantlen = 2 * ocUp4 * QUANT_INFO_BYTES;
int quantlen = 2 * ocUpHp * QUANT_INFO_BYTES;
mWeight.reset(Tensor::createDevice<int8_t>({weightlen + quantlen}));
mValid = b->onAcquireBuffer(mWeight.get(), Backend::STATIC);
mValid &= b->onAcquireBuffer(mFakeBias.get(), Backend::STATIC);

@@ -199,7 +200,7 @@ ErrorCode IdstConvolutionInt8::onExecute(const std::vector<Tensor*>& inputs, con
quanParam.srcKernelSum = fakeSrcKernleSum.data();
std::vector<float> fakeInputScale(DST_XUNIT, 1.f);
quanParam.inputScale = fakeInputScale.data();
std::vector<float> fakeWeightKernelsSum(ocC4 * PackUnit, 0.f);
std::vector<float> fakeWeightKernelsSum(ROUND_UP(output->channel(), UNIT__), 0.f);
quanParam.weightKernelSum = fakeWeightKernelsSum.data();
quanParam.inputBias = nullptr;
quanParam.blockNum = 1;

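For orientation, a hedged sketch of what the onExecute hunk above is doing: it stages "fake" per-run quantization metadata so a float-weight path can reuse the int8 GEMM post-treatment. The struct below is a stand-in with only the touched fields, and the helper name and sizes are illustrative assumptions, not MNN's execution path.

// Hedged sketch; QuanPostTreatParametersLike is a stand-in, not the real struct.
#include <vector>

struct QuanPostTreatParametersLike {
    const float* srcKernelSum = nullptr;
    const float* inputScale = nullptr;
    const float* weightKernelSum = nullptr;
    const float* inputBias = nullptr;
    int blockNum = 1;
};

void setupFakeQuanParams(QuanPostTreatParametersLike& q, int dstXUnit, int ocAligned,
                         std::vector<float>& srcSum, std::vector<float>& scale,
                         std::vector<float>& wsum) {
    srcSum.assign(dstXUnit, 0.f);   // fake per-tile source sums
    scale.assign(dstXUnit, 1.f);    // identity input scale
    wsum.assign(ocAligned, 0.f);    // zero weight-kernel sums, sized to the aligned oc
    q.srcKernelSum = srcSum.data();
    q.inputScale = scale.data();
    q.weightKernelSum = wsum.data();
    q.inputBias = nullptr;
    q.blockNum = 1;
}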
@@ -39,6 +39,7 @@ void MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3(int8_t* dst, const int8_t*
size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, int8_t* idxOrder=nullptr);
void MNNSumByAxisLForMatmul_A_ARM86(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams);
void MNNSumByAxisLForMatmul_A_ARM82(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams);
void MNNSumByAxisLForMatmul_A_SME2(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams);
#if defined(MNN_LOW_MEMORY)
// int4 weight gemmInt8 kernel
void MNNGemmInt8AddBiasScale_ARMV82_w4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad,

@@ -62,8 +63,16 @@ void MNNGemmInt8AddBiasScale_ARMV86_w4_Unit_FP16(int8_t* dst, const int8_t* src,
const QuanPostTreatParameters* post, size_t realDstCount);
void DynamicQuanInputAndReorder_ARM82(const float* src, int8_t* dst, size_t planeSize, const float* scale, ssize_t aMin,
    ssize_t aMax, const float* zeroPoint, size_t ocQuad, size_t offset);

#endif
#endif

#ifdef MNN_SME2
void MNNGemmInt8AddBiasScale_SME2_w4_Fp32(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount);
void MNNGemmInt8AddBiasScale_SME2_w8_Fp32(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount);
void MNNGemmInt8AddBiasScale_SME2_w4_Fp16(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount);
void MNNGemmInt8AddBiasScale_SME2_w8_Fp16(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount);
#endif
#endif // __aarch64__
}
#endif // MNN_USE_NEON

@@ -2224,6 +2233,12 @@ static void MNNGetGemmUnitI8mm(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) {
    *DST_XUNIT = 10;
}

static void MNNGetGemmUnitSme2(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) {
    *UNIT = 16;
    *SRC_UNIT = 4;
    *DST_XUNIT = 16;
}

template<int EP, int HP>
static void _ArmBasicMNNPackC4ForMatMul_A_L4(int8_t* destOrigin, int8_t const** sourceGroup, const int32_t* info, const int32_t* el) {
    int number = info[0];

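MNNGetGemmUnitSme2 above reports UNIT=16, SRC_UNIT=4, DST_XUNIT=16, and the IdstConvolutionInt8 hunk earlier in this diff reshapes the int8 weight to {1, UP_DIV(oc, UNIT), UP_DIV(ic, SRC_UNIT) * kernelCount, UNIT, SRC_UNIT}. The helper below just restates that shape computation for reference; the function name and the worked numbers are illustrative assumptions.

// Hedged helper, illustrative only.
#include <array>

std::array<int, 5> packedInt8WeightShape(int oc, int ic, int kernelCount,
                                         int UNIT, int SRC_UNIT) {
    auto upDiv = [](int a, int b) { return (a + b - 1) / b; };
    // {1, UP_DIV(oc, UNIT), UP_DIV(ic, SRC_UNIT) * kernelCount, UNIT, SRC_UNIT}
    return {1, upDiv(oc, UNIT), upDiv(ic, SRC_UNIT) * kernelCount, UNIT, SRC_UNIT};
}
// e.g. with the SME2 units (UNIT=16, SRC_UNIT=4), oc=32, ic=3 and a 3x3 kernel:
// {1, 2, 9, 16, 4} -> 1*2*9*16*4 = 1152 packed int8 weights.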
@@ -2341,6 +2356,7 @@ static CoreInt8Functions* gCoreFunc = nullptr;
void MNNCoreInt8FunctionInit() {
    /* CoreInt8Functions without sdot */
    gCoreFunc = new CoreInt8Functions;
    auto core = MNNGetCoreFunctions();

    // MatMul
    gCoreFunc->Int8GemmKernel = MNNGemmInt8AddBiasScale_16x4_Unit;

@@ -2374,7 +2390,6 @@ void MNNCoreInt8FunctionInit() {
#endif

#if defined(__aarch64__)
    auto core = MNNGetCoreFunctions();
    if (core->supportSDot) {
        // MatMul
        gCoreFunc->Int8GemmKernel = MNNGemmInt8AddBiasScale_ARMV82_Unit;

@@ -2401,8 +2416,10 @@ void MNNCoreInt8FunctionInit() {
        gCoreFunc->Int8GemmKernelFast = MNNGemmInt8AddBiasScale_ARMV86_Unit;
        gCoreFunc->MNNGetGemmUnit = MNNGetGemmUnitI8mm;
        core->MNNSumByAxisLForMatmul_A = MNNSumByAxisLForMatmul_A_ARM86;

#if defined(MNN_LOW_MEMORY)
        gCoreFunc->Int8GemmKernel_W4 = MNNGemmInt8AddBiasScale_ARMV86_w4_Unit;

#ifdef MNN_USE_ARMV82
        gCoreFunc->MNNGemmInt8AddBiasScale_Unit_FP16 = MNNGemmInt8AddBiasScale_ARMV86_Unit_FP16;
        gCoreFunc->MNNGemmInt8AddBiasScale_w4_Unit_FP16 = MNNGemmInt8AddBiasScale_ARMV86_w4_Unit_FP16;

@@ -2411,6 +2428,42 @@ void MNNCoreInt8FunctionInit() {
        // Im2Col
        gCoreFunc->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A<10, 8, 8>;
    }
#endif // __aarch64__
    {
        core->backendMatmulRelatedFunctions.Int8GemmKernel = gCoreFunc->Int8GemmKernel;
        core->backendMatmulRelatedFunctions.Int8GemmKernelFast = gCoreFunc->Int8GemmKernelFast;
        core->backendMatmulRelatedFunctions.Int8GemmKernel_W4 = gCoreFunc->Int8GemmKernel_W4;
        core->backendMatmulRelatedFunctions.MNNGemmInt8AddBiasScale_Unit_FP16 = gCoreFunc->MNNGemmInt8AddBiasScale_Unit_FP16;
        core->backendMatmulRelatedFunctions.MNNGemmInt8AddBiasScale_w4_Unit_FP16 = gCoreFunc->MNNGemmInt8AddBiasScale_w4_Unit_FP16;
        core->backendMatmulRelatedFunctions.MNNGetGemmUnit = gCoreFunc->MNNGetGemmUnit;
        core->backendMatmulRelatedFunctions.MNNPackC4Int8ForMatMul_A = gCoreFunc->MNNPackC4Int8ForMatMul_A;

        core->backendMatmulRelatedFunctions.MNNSumByAxisLForMatmul_A = core->MNNSumByAxisLForMatmul_A;
    }

#ifdef __aarch64__

#ifdef MNN_SME2
    if (core->supportSME2) {
        gCoreFunc->MNNGetGemmUnit = MNNGetGemmUnitSme2;
        gCoreFunc->Int8GemmKernel_W4 = MNNGemmInt8AddBiasScale_SME2_w4_Fp32;
        gCoreFunc->Int8GemmKernel = MNNGemmInt8AddBiasScale_SME2_w8_Fp32;
        gCoreFunc->MNNGemmInt8AddBiasScale_w4_Unit_FP16 = MNNGemmInt8AddBiasScale_SME2_w4_Fp16;
        gCoreFunc->MNNGemmInt8AddBiasScale_Unit_FP16 = MNNGemmInt8AddBiasScale_SME2_w8_Fp16;
        core->MNNSumByAxisLForMatmul_A = MNNSumByAxisLForMatmul_A_SME2;
        gCoreFunc->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A<16, 4, 16>;
        gCoreFunc->Int8GemmKernelFast = MNNGemmInt8AddBiasScale_SME2_w8_Fp32;

        core->sme2MatmulRelatedFuncions.MNNGetGemmUnit = MNNGetGemmUnitSme2;
        core->sme2MatmulRelatedFuncions.Int8GemmKernel_W4 = MNNGemmInt8AddBiasScale_SME2_w4_Fp32;
        core->sme2MatmulRelatedFuncions.Int8GemmKernel = MNNGemmInt8AddBiasScale_SME2_w8_Fp32;
        core->sme2MatmulRelatedFuncions.MNNGemmInt8AddBiasScale_w4_Unit_FP16 = MNNGemmInt8AddBiasScale_SME2_w4_Fp16;
        core->sme2MatmulRelatedFuncions.MNNGemmInt8AddBiasScale_Unit_FP16 = MNNGemmInt8AddBiasScale_SME2_w8_Fp16;
        core->sme2MatmulRelatedFuncions.MNNSumByAxisLForMatmul_A = MNNSumByAxisLForMatmul_A_SME2;
        core->sme2MatmulRelatedFuncions.MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A<16, 4, 16>;
        core->sme2MatmulRelatedFuncions.Int8GemmKernelFast = MNNGemmInt8AddBiasScale_SME2_w8_Fp32;
    }
#endif
#endif
    MNNInt8FunctionInit();
}

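The block above follows a fill-then-override dispatch pattern: the generic gCoreFunc entries are installed first, and when SME2 is detected at runtime the same slots are overwritten and mirrored into core->sme2MatmulRelatedFuncions. A minimal, self-contained sketch of that pattern is below; the table layout, function names and signatures are illustrative placeholders, not MNN's types.

// Hedged sketch of the dispatch pattern, with placeholder names.
#include <cstddef>
#include <cstdint>

struct GemmTable {
    void (*gemmInt8)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t n);
    void (*getUnit)(int* unit, int* srcUnit, int* dstXUnit);
};

void gemmInt8_generic(int8_t*, const int8_t*, const int8_t*, size_t);
void gemmInt8_sme2(int8_t*, const int8_t*, const int8_t*, size_t);
void getUnit_generic(int*, int*, int*);
void getUnit_sme2(int*, int*, int*);

void initGemmTable(GemmTable* table, bool supportSME2) {
    table->gemmInt8 = gemmInt8_generic;   // portable default
    table->getUnit  = getUnit_generic;
    if (supportSME2) {                    // runtime CPU-feature check
        table->gemmInt8 = gemmInt8_sme2;  // ISA-specific override
        table->getUnit  = getUnit_sme2;
    }
}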
@@ -52,6 +52,7 @@ struct QuanPostTreatParameters {
    const float* inputScale = nullptr;
    const float* inputBias = nullptr;
    float* accumBuffer = nullptr;
    int32_t* indices = nullptr;
};
struct QuanPrePostParameters{
    float* inputScale;

@@ -37,7 +37,7 @@ public:
    virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
    virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override;

    void getPackParameter(int* Unit, int* SrcUnit, int* DestUnit, const CoreInt8Functions* core) override;
    void getPackParameter(int* Unit, int* SrcUnit, int* DestUnit, const CoreInt8Functions* core);
    bool reorderWeight(Backend* b, const Convolution2DCommon* common, const std::shared_ptr<Tensor>& weightOrigin,
                       std::shared_ptr<Tensor>& weight, const SparseCommon* sparseCommon);

@@ -78,11 +78,11 @@ ErrorCode StrassenMatrixComputor::_generateTrivalMatMul(int e, int l, int h, con
auto matmulUnit = core->MNNPackedMatMul;
auto matmulRemain = core->MNNPackedMatMulRemain;
mFunctions.emplace_back(
std::make_pair([cStride, l, h, xCount, AT, BT, CT, COT, tileBufferBasic, unitNumber, bExtraStride, numberThread, eReal, eP, active, matmulUnit, matmulRemain, this](int tId) {
std::make_pair([cStride, l, h, xCount, AT, BT, CT, COT, tileBufferBasic, unitNumber, bExtraStride, numberThread, eReal, eP, lP, active, matmulUnit, matmulRemain, this](int tId) {
auto core = static_cast<CPUBackend*>(backend())->functions();
size_t parameters[7];
parameters[0] = xCount * core->bytes;
parameters[1] = l;
parameters[1] = ROUND_UP(l, lP);
parameters[2] = h;
parameters[3] = cStride;
parameters[4] = 0;

@@ -31,6 +31,7 @@ bool AVX2Backend::isValid() {
AVX2Backend::AVX2Backend(const CPURuntime* runtime, BackendConfig::MemoryMode memory, size_t flags) : CPUBackend(runtime, BackendConfig::Precision_Low, memory, MNN_FORWARD_CPU_EXTENSION, flags) {
    mCoreFunctions = AVX2Functions::get();
    mInt8CoreFunctions = AVX2Functions::getInt8();
    mRelatedFunctions = &(mCoreFunctions->backendMatmulRelatedFunctions);
}

AVX2Backend::~AVX2Backend() {

@@ -105,6 +105,18 @@ bool AVX2Functions::init(int cpuFlags) {
        sizeof(MNN::CoreFunctions::MNNPackedMatMulKernel) * AVX512_INPUT_TILE_MAX);
    }
#endif
    {
        coreFunction->backendMatmulRelatedFunctions.MNNGetMatMulPackMode = coreFunction->MNNGetMatMulPackMode;
        coreFunction->backendMatmulRelatedFunctions.MNNPackC4ForMatMul_A = coreFunction->MNNPackC4ForMatMul_A;
        coreFunction->backendMatmulRelatedFunctions.MNNPackForMatMul_B = coreFunction->MNNPackForMatMul_B;
        coreFunction->backendMatmulRelatedFunctions.MNNPackedMatMul = coreFunction->MNNPackedMatMul;
        coreFunction->backendMatmulRelatedFunctions.MNNPackedMatMulRemain = coreFunction->MNNPackedMatMulRemain;
        coreFunction->backendMatmulRelatedFunctions.Int8GemmKernel = gAVX2CoreInt8Functions->Int8GemmKernel;
        coreFunction->backendMatmulRelatedFunctions.Int8GemmKernelFast = gAVX2CoreInt8Functions->Int8GemmKernelFast;
        coreFunction->backendMatmulRelatedFunctions.Int8GemmKernel_W4 = gAVX2CoreInt8Functions->Int8GemmKernel_W4;
        coreFunction->backendMatmulRelatedFunctions.MNNGetGemmUnit = gAVX2CoreInt8Functions->MNNGetGemmUnit;
        coreFunction->backendMatmulRelatedFunctions.MNNPackC4Int8ForMatMul_A = gAVX2CoreInt8Functions->MNNPackC4Int8ForMatMul_A;
    }
    return true;
}
#endif

@@ -80,6 +80,13 @@ void MNNFunctionInit() {
    }
#endif
    _SSE_ImageProcessInit(coreFunction, cpuFlags);
    {
        coreFunction->backendMatmulRelatedFunctions.MNNGetMatMulPackMode = coreFunction->MNNGetMatMulPackMode;
        coreFunction->backendMatmulRelatedFunctions.MNNPackC4ForMatMul_A = coreFunction->MNNPackC4ForMatMul_A;
        coreFunction->backendMatmulRelatedFunctions.MNNPackForMatMul_B = coreFunction->MNNPackForMatMul_B;
        coreFunction->backendMatmulRelatedFunctions.MNNPackedMatMul = coreFunction->MNNPackedMatMul;
        coreFunction->backendMatmulRelatedFunctions.MNNPackedMatMulRemain = coreFunction->MNNPackedMatMulRemain;
    }
}

void MNNAvgPoolUint8(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputWidth, size_t kernelx, size_t kernely, size_t stridesx, ssize_t paddingx, ssize_t factor) {

@@ -131,6 +138,7 @@ void MNNMaxPoolInt8_(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputW
void MNNInt8FunctionInit() {
    auto cpuFlags = libyuv::InitCpuFlags();
    auto core = MNN::MNNGetInt8CoreFunctions();
    auto gcore = MNN::MNNGetCoreFunctions();
    core->MNNAvgPoolInt8 = MNNAvgPoolUint8;
    core->MNNMaxPoolInt8 = MNNMaxPoolInt8_;
    if (cpuFlags & libyuv::kCpuHasSSE41) {

@@ -143,6 +151,13 @@ void MNNInt8FunctionInit() {
        core->Int8GemmKernel_W4 = _SSE_MNNGemmInt8AddBiasScale_16x4_w4;
#endif
    }
    {
        gcore->backendMatmulRelatedFunctions.Int8GemmKernel = core->Int8GemmKernel;
        gcore->backendMatmulRelatedFunctions.Int8GemmKernelFast = core->Int8GemmKernelFast;
        gcore->backendMatmulRelatedFunctions.Int8GemmKernel_W4 = core->Int8GemmKernel_W4;
        gcore->backendMatmulRelatedFunctions.MNNGetGemmUnit = core->MNNGetGemmUnit;
        gcore->backendMatmulRelatedFunctions.MNNPackC4Int8ForMatMul_A = core->MNNPackC4Int8ForMatMul_A;
    }
}

@@ -62,7 +62,7 @@ void _AVX_MNNPackC4ForMatMul_A_BF16(float* destOrigin, float const** sourceGroup
void _AVX_MNNCountMinMaxValue(const float* source, float* minVal, float* maxVal, size_t size);

void _AVX_MNNGetMatMulPackMode_BF16(int* eP, int *lP, int* hP);
void _AVX_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t l, bool transpose);
void _AVX_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose);
void _AVX_MNNPackedSparseMatMul(float* C, const float* A, const float* B, unsigned int* NNZMap, int* dataOffsetMap, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias);
void _AVX_MNNComputeMatMulForH_1(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId);

@@ -72,7 +72,7 @@ void _AVX_MNNPackCUnit(float* dst, const float* src, size_t area, size_t depth,
void _AVX_MNNUnpackCUnit(float* dst, const float* src, size_t area, size_t depth, int* areaOffset);
void _AVX_MNNPackCUnitTranspose(float* dst, const float* src, size_t area, size_t depth, int* areaOffset);
void _AVX_MNNUnpackCUnitTranspose(float* dst, const float* src, size_t area, size_t depth, int* areaOffset);
void _AVX_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose);
void _AVX_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose);

void _AVX_ExtraInit(void* functions);
void _AVX_WinogradInit(void* functions);

@@ -336,7 +336,8 @@ _mm256_storeu_ps(temp2 + 7 * 8, t7);\
    }
}

void _AVX_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose) {
void _AVX_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose) {
    auto l = kernelsize * ic;
    int offset[2] = {
        (int)l,
        (int)l

@@ -348,10 +349,11 @@ void _AVX_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t
    MNNPackC4(dest, source, l, h, offset);
}

void _AVX_MNNPackForMatMul_B_EShort(float* dest, const float* source, size_t h, size_t l, bool transpose) {
void _AVX_MNNPackForMatMul_B_EShort(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose) {
    const int unit = 16;
    auto hP = h / unit;
    auto hR = hP * unit;
    auto l = kernelsize * ic;
    if (hR != h) {
        ::memset(dest, 0, UP_DIV(h, unit)*unit*l*sizeof(float));
    }

@@ -10,7 +10,8 @@
#include "FunctionSummary.hpp"
#include "core/Macro.h"

void _AVX_MNNPackForMatMul_B_BF16(float* destF, const float* sourceF, size_t h, size_t l, bool transpose) {
void _AVX_MNNPackForMatMul_B_BF16(float* destF, const float* sourceF, size_t h, size_t kernelsize, size_t ic, bool transpose) {
    auto l = kernelsize * ic;
    auto dest = (int16_t*)destF;
    auto source = (const int16_t*)sourceF;
    auto lC8 = UP_DIV(l, 8);

@@ -30,7 +30,7 @@ do { \

// ========= CommonOptFunction.cpp ===========
extern "C" {
void _AVX512_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose);
void _AVX512_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose);
void _AVX512_MNNPackC8ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el);

void _AVX512_MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);

@@ -327,7 +327,8 @@ void _AVX_MNNPackCUnitTranspose(float* dst, const float* src, size_t area, size_
void _AVX_MNNUnpackCUnitTranspose(float* dst, const float* src, size_t area, size_t depth, int* areaOffset);
}

void _AVX512_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose) {
void _AVX512_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose) {
    auto l = kernelsize * ic;
    int offset[2] = {
        (int)l,
        (int)l

@@ -286,6 +286,7 @@ void _AVX512_MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* sca
    auto sizeC4 = sizeQuad / 2;
    auto sizeRemain = sizeQuad % 2;
    auto zero = _mm256_set1_epi32(0);
    auto offset = _mm256_set1_ps(128.f);

    auto scaleValue0 = _mm256_set1_ps(scale[0]);
    auto scaleValue1 = scaleValue0;

@@ -293,11 +294,11 @@ void _AVX512_MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* sca
        scaleValue0 = _mm256_loadu_ps(scale);
        scaleValue1 = _mm256_loadu_ps(scale + 8);
    }
    auto zeroPointValue0 = _mm256_set1_ps(zeroPoint[0]) + _mm256_set1_ps(128.f);
    auto zeroPointValue0 = _mm256_add_ps(_mm256_set1_ps(zeroPoint[0]), offset);
    auto zeroPointValue1 = zeroPointValue0;
    if (quanParamVec >> 1) {
        zeroPointValue0 = _mm256_loadu_ps(zeroPoint) + _mm256_set1_ps(128.f);
        zeroPointValue1 = _mm256_loadu_ps(zeroPoint + 8) + _mm256_set1_ps(128.f);
        zeroPointValue0 = _mm256_add_ps(_mm256_loadu_ps(zeroPoint), offset);
        zeroPointValue1 = _mm256_add_ps(_mm256_loadu_ps(zeroPoint + 8), offset);
    }

    for (int i = 0; i < sizeC4; ++i) {

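The hunk above replaces operator+ on __m256 values with _mm256_add_ps. Worth noting why: arithmetic operators on __m256 are a GCC/Clang vector extension, so the intrinsic form is the portable spelling (MSVC only accepts the intrinsic). The motivation stated here is my reading of the change, but the portability fact itself holds; a minimal illustration:

#include <immintrin.h>

__m256 addOffset(__m256 v) {
    // Portable on GCC, Clang and MSVC:
    return _mm256_add_ps(v, _mm256_set1_ps(128.f));
    // Non-portable (GCC/Clang vector extension only):
    // return v + _mm256_set1_ps(128.f);
}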
@ -153,7 +153,6 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s
|
|||
int weight_step_Z = static_cast<int32_t>(src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H)) + (2 * sizeof(float) * GEMMINT8_AVX512_H);
|
||||
int weight_step_Y = static_cast<int32_t>(GEMMINT8_AVX512_L * GEMMINT8_AVX512_H);
|
||||
int weightPackStride = GEMMINT8_AVX512_L * PACK_UNIT;
|
||||
int weight_step_Z_remain = static_cast<int32_t>(src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H)) + (2 * sizeof(float) * dzR * PACK_UNIT);
|
||||
int source_step = realDst * PACK_UNIT;
|
||||
|
||||
if (realDst == GEMMINT8_AVX512_E) {
|
||||
|
@ -518,9 +517,9 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s
|
|||
__m512i D1 = _mm512_set1_epi32(0);
|
||||
__m512i D2 = _mm512_set1_epi32(0);
|
||||
__m512i D3 = _mm512_set1_epi32(0);
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + i * weightPackStride;
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + dzR * PACK_UNIT;
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z + i * weightPackStride;
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
|
||||
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
|
||||
|
||||
for (int sz = 0; sz < src_depth_quad; ++sz) {
|
||||
|
@ -569,7 +568,7 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s
|
|||
|
||||
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
|
||||
if (post->inputBias) {
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
|
||||
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
|
||||
bias00 = _mm512_mul_ps(inputbias0, wsum0);
|
||||
bias01 = _mm512_mul_ps(inputbias1, wsum0);
|
||||
|
@ -941,9 +940,9 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s
|
|||
__m512i D0 = _mm512_set1_epi32(0);
|
||||
__m512i D1 = _mm512_set1_epi32(0);
|
||||
__m512i D2 = _mm512_set1_epi32(0);
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + i * weightPackStride;
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + dzR * PACK_UNIT;
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z + i * weightPackStride;
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
|
||||
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
|
||||
|
||||
for (int sz = 0; sz < src_depth_quad; ++sz) {
|
||||
|
@ -984,7 +983,7 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s
|
|||
f2 = _mm512_mul_ps(f2, inputscale2);
|
||||
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
|
||||
if (post->inputBias) {
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
|
||||
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
|
||||
bias00 = _mm512_mul_ps(inputbias0, wsum0);
|
||||
bias01 = _mm512_mul_ps(inputbias1, wsum0);
|
||||
|
@ -1284,9 +1283,9 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s
|
|||
for (int bk = 0; bk < blockNum; ++bk) {
|
||||
__m512i D0 = _mm512_set1_epi32(0);
|
||||
__m512i D1 = _mm512_set1_epi32(0);
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + i * weightPackStride;
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + dzR * PACK_UNIT;
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z + i * weightPackStride;
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
|
||||
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
|
||||
|
||||
for (int sz = 0; sz < src_depth_quad; ++sz) {
|
||||
|
@ -1320,7 +1319,7 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s
|
|||
f1 = _mm512_mul_ps(f1, inputscale1);
|
||||
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
|
||||
if (post->inputBias) {
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
|
||||
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
|
||||
bias00 = _mm512_mul_ps(inputbias0, wsum0);
|
||||
bias01 = _mm512_mul_ps(inputbias1, wsum0);
|
||||
|
@ -1537,9 +1536,9 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s
|
|||
for (int i=0; i<dzR; ++i) {
|
||||
auto accum_x = accumbuff;
|
||||
for (int bk = 0; bk < blockNum; ++bk) {
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + i * weightPackStride;
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + dzR * PACK_UNIT;
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z + i * weightPackStride;
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
|
||||
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
|
||||
|
||||
__m512i D0 = _mm512_set1_epi32(0);
|
||||
|
@ -1566,7 +1565,7 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s
|
|||
f0 = _mm512_mul_ps(f0, inputscale0);
|
||||
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
|
||||
if (post->inputBias) {
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
|
||||
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
|
||||
bias00 = _mm512_mul_ps(inputbias0, wsum0);
|
||||
} else if (bk == 0) { // if input not block quant, only accum once!
|
||||
|
@ -1655,7 +1654,6 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_w4_Unit_VNNI(int8_t* dst, const int8_t
|
|||
int weight_step_Y = GEMMINT8_AVX512_L * GEMMINT8_AVX512_H / 2;
|
||||
int weight_step_Z = src_depth_quad * weight_step_Y + (2 * 4 * GEMMINT8_AVX512_H);
|
||||
int weightPackStride = GEMMINT8_AVX512_L / 2 * PACK_UNIT;
|
||||
int weight_step_Z_remain = src_depth_quad * weight_step_Y + (2 * 4 * dzR * PACK_UNIT);
|
||||
int source_step = realDst * PACK_UNIT;
|
||||
if (realDst == GEMMINT8_AVX512_E) {
|
||||
for (int dz = 0; dz < dzU; ++dz) {
|
||||
|
@ -1975,9 +1973,9 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_w4_Unit_VNNI(int8_t* dst, const int8_t
|
|||
__m512i D1 = _mm512_set1_epi32(0);
|
||||
__m512i D2 = _mm512_set1_epi32(0);
|
||||
__m512i D3 = _mm512_set1_epi32(0);
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + weightPackStride * suborder[i];
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + dzR * PACK_UNIT;
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z + weightPackStride * suborder[i];
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
|
||||
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
|
||||
|
||||
for (int sz = 0; sz < src_depth_quad; ++sz) {
|
||||
|
@ -2027,7 +2025,7 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_w4_Unit_VNNI(int8_t* dst, const int8_t
|
|||
f3 = _mm512_mul_ps(f3, inputscale3);
|
||||
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
|
||||
if (post->inputBias) {
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
|
||||
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
|
||||
bias00 = _mm512_mul_ps(inputbias0, wsum0);
|
||||
bias01 = _mm512_mul_ps(inputbias1, wsum0);
|
||||
|
@ -2349,9 +2347,9 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_w4_Unit_VNNI(int8_t* dst, const int8_t
|
|||
__m512i D0 = _mm512_set1_epi32(0);
|
||||
__m512i D1 = _mm512_set1_epi32(0);
|
||||
__m512i D2 = _mm512_set1_epi32(0);
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + weightPackStride * suborder[i];
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + dzR * PACK_UNIT;
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z + weightPackStride * suborder[i];
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
|
||||
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
|
||||
|
||||
for (int sz = 0; sz < src_depth_quad; ++sz) {
|
||||
|
@ -2394,7 +2392,7 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_w4_Unit_VNNI(int8_t* dst, const int8_t
|
|||
f2 = _mm512_mul_ps(f2, inputscale2);
|
||||
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
|
||||
if (post->inputBias) {
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
|
||||
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
|
||||
bias00 = _mm512_mul_ps(inputbias0, wsum0);
|
||||
bias01 = _mm512_mul_ps(inputbias1, wsum0);
|
||||
|
@ -2655,9 +2653,9 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_w4_Unit_VNNI(int8_t* dst, const int8_t
|
|||
for (int bk = 0; bk < blockNum; ++bk) {
|
||||
__m512i D0 = _mm512_set1_epi32(0);
|
||||
__m512i D1 = _mm512_set1_epi32(0);
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + weightPackStride * suborder[i];
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + dzR * PACK_UNIT;
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z + weightPackStride * suborder[i];
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
|
||||
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
|
||||
|
||||
for (int sz = 0; sz < src_depth_quad; ++sz) {
|
||||
|
@ -2693,7 +2691,7 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_w4_Unit_VNNI(int8_t* dst, const int8_t
|
|||
f1 = _mm512_mul_ps(f1, inputscale1);
|
||||
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
|
||||
if (post->inputBias) {
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
|
||||
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
|
||||
bias00 = _mm512_mul_ps(inputbias0, wsum0);
|
||||
bias01 = _mm512_mul_ps(inputbias1, wsum0);
|
||||
|
@ -2878,9 +2876,9 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_w4_Unit_VNNI(int8_t* dst, const int8_t
|
|||
for (int i=0; i<dzR; ++i) {
|
||||
auto accum_x = accumbuff;
|
||||
for (int bk = 0; bk < blockNum; ++bk) {
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + weightPackStride * suborder[i];
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + dzR * PACK_UNIT;
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z + weightPackStride * suborder[i];
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
|
||||
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
|
||||
|
||||
__m512i D0 = _mm512_set1_epi32(0);
|
||||
|
@ -2908,7 +2906,7 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_w4_Unit_VNNI(int8_t* dst, const int8_t
|
|||
f0 = _mm512_mul_ps(f0, inputscale0);
|
||||
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
|
||||
if (post->inputBias) {
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
|
||||
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
|
||||
bias00 = _mm512_mul_ps(inputbias0, wsum0);
|
||||
} else if (bk == 0) { // if input not block quant, only accum once!
|
||||
|
|
|
@ -139,7 +139,6 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s
|
|||
int weight_step_Z = static_cast<int32_t>(src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H)) + (2 * sizeof(float) * GEMMINT8_AVX512_H);
|
||||
int weight_step_Y = static_cast<int32_t>(GEMMINT8_AVX512_L * GEMMINT8_AVX512_H);
|
||||
int weightPackStride = GEMMINT8_AVX512_L * PACK_UNIT;
|
||||
int weight_step_Z_remain = static_cast<int32_t>(src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H)) + (2 * sizeof(float) * dzR * PACK_UNIT);
|
||||
int source_step = realDst * PACK_UNIT;
|
||||
|
||||
if (realDst == GEMMINT8_AVX512_E) {
|
||||
|
@ -504,9 +503,9 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s
|
|||
__m512i D1 = _mm512_set1_epi32(0);
|
||||
__m512i D2 = _mm512_set1_epi32(0);
|
||||
__m512i D3 = _mm512_set1_epi32(0);
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + i * weightPackStride;
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + dzR * PACK_UNIT;
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z + i * weightPackStride;
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
|
||||
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
|
||||
|
||||
for (int sz = 0; sz < src_depth_quad; ++sz) {
|
||||
|
@ -555,7 +554,7 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s
|
|||
|
||||
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
|
||||
if (post->inputBias) {
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
|
||||
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
|
||||
bias00 = _mm512_mul_ps(inputbias0, wsum0);
|
||||
bias01 = _mm512_mul_ps(inputbias1, wsum0);
|
||||
|
@ -927,9 +926,9 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s
|
|||
__m512i D0 = _mm512_set1_epi32(0);
|
||||
__m512i D1 = _mm512_set1_epi32(0);
|
||||
__m512i D2 = _mm512_set1_epi32(0);
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + i * weightPackStride;
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + dzR * PACK_UNIT;
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z + i * weightPackStride;
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
|
||||
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
|
||||
|
||||
for (int sz = 0; sz < src_depth_quad; ++sz) {
|
||||
|
@ -970,7 +969,7 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s
|
|||
f2 = _mm512_mul_ps(f2, inputscale2);
|
||||
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
|
||||
if (post->inputBias) {
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
|
||||
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
|
||||
bias00 = _mm512_mul_ps(inputbias0, wsum0);
|
||||
bias01 = _mm512_mul_ps(inputbias1, wsum0);
|
||||
|
@ -1270,9 +1269,9 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s
|
|||
for (int bk = 0; bk < blockNum; ++bk) {
|
||||
__m512i D0 = _mm512_set1_epi32(0);
|
||||
__m512i D1 = _mm512_set1_epi32(0);
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + i * weightPackStride;
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + dzR * PACK_UNIT;
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z + i * weightPackStride;
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
|
||||
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
|
||||
|
||||
for (int sz = 0; sz < src_depth_quad; ++sz) {
|
||||
|
@ -1306,7 +1305,7 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s
|
|||
f1 = _mm512_mul_ps(f1, inputscale1);
|
||||
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
|
||||
if (post->inputBias) {
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
|
||||
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
|
||||
bias00 = _mm512_mul_ps(inputbias0, wsum0);
|
||||
bias01 = _mm512_mul_ps(inputbias1, wsum0);
|
||||
|
@ -1523,9 +1522,9 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s
|
|||
for (int i=0; i<dzR; ++i) {
|
||||
auto accum_x = accumbuff;
|
||||
for (int bk = 0; bk < blockNum; ++bk) {
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + i * weightPackStride;
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + dzR * PACK_UNIT;
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z + i * weightPackStride;
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
|
||||
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
|
||||
|
||||
__m512i D0 = _mm512_set1_epi32(0);
|
||||
|
@ -1552,7 +1551,7 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s
|
|||
f0 = _mm512_mul_ps(f0, inputscale0);
|
||||
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
|
||||
if (post->inputBias) {
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
|
||||
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
|
||||
bias00 = _mm512_mul_ps(inputbias0, wsum0);
|
||||
} else if (bk == 0) { // if input not block quant, only accum once!
|
||||
|
@ -1640,7 +1639,6 @@ void MATMULCOREFUNC_NAME_W4(int8_t* dst, const int8_t* src, const int8_t* weight
|
|||
int weight_step_Y = GEMMINT8_AVX512_L * GEMMINT8_AVX512_H / 2;
|
||||
int weight_step_Z = src_depth_quad * weight_step_Y + (2 * 4 * GEMMINT8_AVX512_H);
|
||||
int weightPackStride = GEMMINT8_AVX512_L / 2 * PACK_UNIT;
|
||||
int weight_step_Z_remain = src_depth_quad * weight_step_Y + (2 * 4 * dzR * PACK_UNIT);
|
||||
int source_step = realDst * PACK_UNIT;
|
||||
if (realDst == GEMMINT8_AVX512_E) {
|
||||
for (int dz = 0; dz < dzU; ++dz) {
|
||||
|
@ -1960,9 +1958,9 @@ void MATMULCOREFUNC_NAME_W4(int8_t* dst, const int8_t* src, const int8_t* weight
|
|||
__m512i D1 = _mm512_set1_epi32(0);
|
||||
__m512i D2 = _mm512_set1_epi32(0);
|
||||
__m512i D3 = _mm512_set1_epi32(0);
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + weightPackStride * suborder[i];
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + dzR * PACK_UNIT;
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z + weightPackStride * suborder[i];
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
|
||||
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
|
||||
|
||||
for (int sz = 0; sz < src_depth_quad; ++sz) {
|
||||
|
@ -2012,7 +2010,7 @@ void MATMULCOREFUNC_NAME_W4(int8_t* dst, const int8_t* src, const int8_t* weight
|
|||
f3 = _mm512_mul_ps(f3, inputscale3);
|
||||
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
|
||||
if (post->inputBias) {
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
|
||||
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
|
||||
bias00 = _mm512_mul_ps(inputbias0, wsum0);
|
||||
bias01 = _mm512_mul_ps(inputbias1, wsum0);
|
||||
|
@ -2334,9 +2332,9 @@ void MATMULCOREFUNC_NAME_W4(int8_t* dst, const int8_t* src, const int8_t* weight
|
|||
__m512i D0 = _mm512_set1_epi32(0);
|
||||
__m512i D1 = _mm512_set1_epi32(0);
|
||||
__m512i D2 = _mm512_set1_epi32(0);
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + weightPackStride * suborder[i];
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + dzR * PACK_UNIT;
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z + weightPackStride * suborder[i];
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
|
||||
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
|
||||
|
||||
for (int sz = 0; sz < src_depth_quad; ++sz) {
|
||||
|
@ -2379,7 +2377,7 @@ void MATMULCOREFUNC_NAME_W4(int8_t* dst, const int8_t* src, const int8_t* weight
|
|||
f2 = _mm512_mul_ps(f2, inputscale2);
|
||||
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
|
||||
if (post->inputBias) {
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
|
||||
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
|
||||
bias00 = _mm512_mul_ps(inputbias0, wsum0);
|
||||
bias01 = _mm512_mul_ps(inputbias1, wsum0);
|
||||
|
@ -2640,9 +2638,9 @@ void MATMULCOREFUNC_NAME_W4(int8_t* dst, const int8_t* src, const int8_t* weight
|
|||
for (int bk = 0; bk < blockNum; ++bk) {
|
||||
__m512i D0 = _mm512_set1_epi32(0);
|
||||
__m512i D1 = _mm512_set1_epi32(0);
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + weightPackStride * suborder[i];
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + dzR * PACK_UNIT;
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z + weightPackStride * suborder[i];
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
|
||||
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
|
||||
|
||||
for (int sz = 0; sz < src_depth_quad; ++sz) {
|
||||
|
@ -2678,7 +2676,7 @@ void MATMULCOREFUNC_NAME_W4(int8_t* dst, const int8_t* src, const int8_t* weight
|
|||
f1 = _mm512_mul_ps(f1, inputscale1);
|
||||
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
|
||||
if (post->inputBias) {
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
|
||||
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
|
||||
bias00 = _mm512_mul_ps(inputbias0, wsum0);
|
||||
bias01 = _mm512_mul_ps(inputbias1, wsum0);
|
||||
|
@ -2863,9 +2861,9 @@ void MATMULCOREFUNC_NAME_W4(int8_t* dst, const int8_t* src, const int8_t* weight
|
|||
for (int i=0; i<dzR; ++i) {
|
||||
auto accum_x = accumbuff;
|
||||
for (int bk = 0; bk < blockNum; ++bk) {
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z_remain + weightPackStride * suborder[i];
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z_remain + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + dzR * PACK_UNIT;
|
||||
auto weightDzSub = weight_dz + bk * weight_step_Z + weightPackStride * suborder[i];
|
||||
auto scaleDz = (float*)(weight_dz + bk * weight_step_Z + src_depth_quad * weight_step_Y);
|
||||
auto biasDz = scaleDz + GEMMINT8_AVX512_H;
|
||||
const auto src_x = src + bk * src_depth_quad * GEMMINT8_AVX512_L * realDst;
|
||||
|
||||
__m512i D0 = _mm512_set1_epi32(0);
|
||||
|
@ -2893,7 +2891,7 @@ void MATMULCOREFUNC_NAME_W4(int8_t* dst, const int8_t* src, const int8_t* weight
|
|||
f0 = _mm512_mul_ps(f0, inputscale0);
|
||||
if ((post->useInt8 == 0) && post->weightKernelSum && (post->inputBias || (bk == 0))) {
|
||||
if (post->inputBias) {
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * dzR * PACK_UNIT;
|
||||
weightKernelSum_dz = post->weightKernelSum + dzU * blockNum * GEMMINT8_AVX512_H + bk * GEMMINT8_AVX512_H;
|
||||
auto wsum0 = _mm512_loadu_ps(weightKernelSum_dz + i * PACK_UNIT);
|
||||
bias00 = _mm512_mul_ps(inputbias0, wsum0);
|
||||
} else if (bk == 0) { // if input not block quant, only accum once!
|
||||
|
|
|
@@ -72,14 +72,14 @@ void _SSE_MNNConvRunForLineDepthwise(float* dst, const float* src, const float*
void _SSE_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step,
                                            size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst);
void _SSE_MNNExpC8(float* dest, const float* source, float* offset, const float* parameters, size_t countC8);
void _SSE_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose);
void _SSE_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose);
void _SSE_MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue, ssize_t maxValue, const float* zeroPoint, ssize_t quanParamVec);

void _SSE_MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* scale, size_t size, const float* zeroPoint, ssize_t quanParamVec);
void _SSE_MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dst, const int8_t* src, const int8_t* weight, const QuanPostTreatParameters* parameters, size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, int8_t* idxOrder=nullptr);
void _SSE_MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count);

void _SSE_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t l, bool transpose);
void _SSE_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose);
void _SSE_MNNReluInt8(int8_t* dst, const int8_t* src, size_t size, ssize_t zeroPoint);
void _SSE_MNNSoftmax(float* dest, const float* source, size_t size);
void _SSE_ExtraInit(void* functions);

@@ -153,7 +153,8 @@ void _SSE_GemmPostTreat(float* C, size_t eSize, const size_t* parameter, const f
    }
}

void _SSE_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose) {
void _SSE_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose) {
    auto l = kernelsize * ic;
    int offset[2] = {
        (int)l,
        (int)l

@@ -165,7 +166,8 @@ void _SSE_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t
    MNNPackC4(dest, source, l, h, offset);
}

void _SSE_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t l, bool transpose) {
void _SSE_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t kernelsize, size_t ic, bool transpose) {
    auto l = kernelsize * ic;
    int offset[] = {
        (int)l,
        (int)l

@@ -117,7 +117,7 @@ ENDIF()
include(FetchContent)

if(MNN_SUPPORT_TRANSFORMER_FUSE)
    set(CUTLASS_COMMIT_HASH "5c6bca04414e06ce74458ab0a2018e2b8272701c")
    set(CUTLASS_COMMIT_HASH "b995f933179c22d3fe0d871c3a53d11e4681950f")
    set(CUTLASS_VERSION_NAME "v4.0.0")
else()
    set(CUTLASS_COMMIT_HASH "319a389f42b776fae5701afcb943fc03be5b5c25")

@@ -100,7 +100,9 @@ Backend* CUDARuntimeWrapper::onCreate(const BackendConfig* config, Backend* orig
        precision = 1;
    }

    return new CUDABackend(this, mBufferPool, mCUDARuntime, precision, memory_mode);
    auto backend = new CUDABackend(this, mBufferPool, mCUDARuntime, precision, memory_mode);
    backend->setMetaPtr(pMeta);
    return backend;
}

void CUDARuntimeWrapper::onGabageCollect(int level) {

@@ -1,5 +1,6 @@
#include "AttentionExecution.hpp"
#include "core/TensorUtils.hpp"
#include "SoftmaxExecution.hpp"

namespace MNN {
namespace CUDA {

@@ -43,8 +44,8 @@ __global__ void compact_kv_cache_kernel(
    int copy_len = reserve_info[reserve_pair_idx * 2 + 1];
    int copy_dst_begin_offset = reserve_offsets[reserve_pair_idx]; // offset within the reserve region

    long long src_offset = past_kv_len_after_remove + copy_src_begin;
    long long dst_offset = past_kv_len_after_remove + copy_dst_begin_offset;
    int src_offset = past_kv_len_after_remove + copy_src_begin;
    int dst_offset = past_kv_len_after_remove + copy_dst_begin_offset;

    uint8_t* src_k_ptr = (uint8_t*)src_key_cache;
    uint8_t* dst_k_ptr = (uint8_t*)dst_key_cache;

@@ -54,13 +55,13 @@ __global__ void compact_kv_cache_kernel(
    // Key Cache: [L, B, H_kv, D] -> copy the L segment for each (B, H_kv, D)
    for (int l = 0; l < copy_len; ++l) {
        // Key: [L, B, H, D]
        long long k_src_idx = ((src_offset + l) * b + b_idx) * h_kv * d + h_kv_idx * d + d_idx;
        long long k_dst_idx = ((dst_offset + l) * b + b_idx) * h_kv * d + h_kv_idx * d + d_idx;
        int k_src_idx = ((src_offset + l) * b + b_idx) * h_kv * d + h_kv_idx * d + d_idx;
        int k_dst_idx = ((dst_offset + l) * b + b_idx) * h_kv * d + h_kv_idx * d + d_idx;
        memcpy(dst_k_ptr + k_dst_idx * element_size, src_k_ptr + k_src_idx * element_size, element_size);

        // Value: [B, H, D, L]
        long long v_src_idx = (((long long)b_idx * h_kv + h_kv_idx) * d + d_idx) * kv_cache_max_len + (src_offset + l);
        long long v_dst_idx = (((long long)b_idx * h_kv + h_kv_idx) * d + d_idx) * kv_cache_max_len + (dst_offset + l);
        int v_src_idx = ((b_idx * h_kv + h_kv_idx) * d + d_idx) * kv_cache_max_len + (src_offset + l);
        int v_dst_idx = ((b_idx * h_kv + h_kv_idx) * d + d_idx) * kv_cache_max_len + (dst_offset + l);
        memcpy(dst_v_ptr + v_dst_idx * element_size, src_v_ptr + v_src_idx * element_size, element_size);
    }
}

@@ -92,9 +93,9 @@ __global__ void copy_kv_to_cache_kernel(
    int h_kv_idx = bh_kv_idx % kv_num_head;

    // Source index into the K and V inputs (assumes input layout [B, L_k_new, H_kv, D])
    long long input_offset = (long long)b_idx * new_kv_seq_len * kv_num_head * head_dim +
                             (long long)l_idx_new * kv_num_head * head_dim +
                             (long long)h_kv_idx * head_dim;
    int input_offset = b_idx * new_kv_seq_len * kv_num_head * head_dim +
                       l_idx_new * kv_num_head * head_dim +
                       h_kv_idx * head_dim;

    T val_to_copy_k = key_input[input_offset + d_idx];
    T val_to_copy_v = value_input[input_offset + d_idx]; // read from the same offset, since the K and V inputs share the same shape

@@ -104,16 +105,16 @@ __global__ void copy_kv_to_cache_kernel(
    if (dest_seq_idx_cache >= allocated_kv_len) return; // bounds check

    // Key cache output: [L_kv_alloc, B, H_kv, D]
    long long key_cache_idx = (long long)dest_seq_idx_cache * batch_size * kv_num_head * head_dim +
                              (long long)b_idx * kv_num_head * head_dim +
                              (long long)h_kv_idx * head_dim +
    int key_cache_idx = dest_seq_idx_cache * batch_size * kv_num_head * head_dim +
                        b_idx * kv_num_head * head_dim +
                        h_kv_idx * head_dim +
                        d_idx;
    key_cache_output[key_cache_idx] = val_to_copy_k;

    // Value cache output: [B, H_kv, D, L_kv_alloc]
    long long value_cache_idx = (long long)b_idx * kv_num_head * head_dim * allocated_kv_len +
                                (long long)h_kv_idx * head_dim * allocated_kv_len +
                                (long long)d_idx * allocated_kv_len +
    int value_cache_idx = b_idx * kv_num_head * head_dim * allocated_kv_len +
                          h_kv_idx * head_dim * allocated_kv_len +
                          d_idx * allocated_kv_len +
                          dest_seq_idx_cache;
    value_cache_output[value_cache_idx] = val_to_copy_v;
}

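The kernels above address the key cache as [L_kv_alloc, B, H_kv, D] and the value cache as [B, H_kv, D, L_kv_alloc]. The host-side helpers below just restate those flat-index formulas for reference (names are mine); note that the products can exceed INT_MAX for very large caches, which is the arithmetic bound to keep in mind when the index type is narrowed from long long to int as in the hunks above.

// Hedged reference helpers mirroring the cache layouts above.
#include <cstdint>

inline int64_t keyCacheIndex(int l, int b, int h, int d,
                             int batch, int kvHeads, int headDim) {
    // key cache [L_kv_alloc, B, H_kv, D]
    return ((int64_t)l * batch + b) * kvHeads * headDim + (int64_t)h * headDim + d;
}

inline int64_t valueCacheIndex(int b, int h, int d, int l,
                               int kvHeads, int headDim, int kvCacheLen) {
    // value cache [B, H_kv, D, L_kv_alloc]
    return (((int64_t)b * kvHeads + h) * headDim + d) * (int64_t)kvCacheLen + l;
}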
@@ -148,15 +149,15 @@ __global__ void qk_kernel(
    int h_kv_idx = h_q_idx / param->group; // corresponding KV head index

    // Base pointer into Query at Q[b_idx, current_full_q_idx, h_q_idx, :]
    const T* q_ptr = query_input + (long long)b_idx * param->query_seq_len * param->head_num * param->head_dim +
                     (long long)current_full_q_idx * param->head_num * param->head_dim +
                     (long long)h_q_idx * param->head_dim;
    const T* q_ptr = query_input + b_idx * param->query_seq_len * param->head_num * param->head_dim +
                     current_full_q_idx * param->head_num * param->head_dim +
                     h_q_idx * param->head_dim;

    // Base pointer into the Key cache at K_cache[k_idx, b_idx, h_kv_idx, :]
    // Key Cache: [L_kv_alloc, B, H_kv, D]
    const T* k_ptr = key_cache + (long long)k_idx * param->batch * param->kv_head_num * param->head_dim +
                     (long long)b_idx * param->kv_head_num * param->head_dim +
                     (long long)h_kv_idx * param->head_dim;
    const T* k_ptr = key_cache + k_idx * param->batch * param->kv_head_num * param->head_dim +
                     b_idx * param->kv_head_num * param->head_dim +
                     h_kv_idx * param->head_dim;

    for (int d = 0; d < param->head_dim; ++d) {
        score_sum += static_cast<AccT>(q_ptr[d]) * static_cast<AccT>(k_ptr[d]);

@ -168,26 +169,38 @@ __global__ void qk_kernel(
|
|||
// current_full_q_idx is the index within the full query sequence (row index)
|
||||
// k_idx is the index within the currently valid key sequence (column index)
|
||||
// A float mask is laid out as L_q * L_q (param->query_seq_len * param->query_seq_len); an integer mask as L_q * L_k
|
||||
long long mask_idx = (long long)current_full_q_idx * (is_add_mask_flag ? param->query_seq_len : param->key_seq_len) + k_idx - param->key_seq_len + param->query_seq_len;
|
||||
|
||||
if (is_add_mask_flag) {
|
||||
// An additive mask is usually float
|
||||
if (k_idx >= param->key_seq_len - param->query_seq_len)
|
||||
score_sum += k_idx >= param->key_seq_len - param->query_seq_len ?
|
||||
static_cast<const AccT*>(mask_tensor_data)[mask_idx]
|
||||
: 0; // the first L_k - L_q mask entries are treated as 0
|
||||
} else {
|
||||
// A boolean (set-type) mask is usually int; 0 means masked out
|
||||
if (is_add_mask_flag) { // an additive mask is float
|
||||
int mask_idx = current_full_q_idx * param->query_seq_len + k_idx - param->key_seq_len + param->query_seq_len;
|
||||
if (k_idx >= param->key_seq_len - param->query_seq_len) { // the first L_k - L_q mask entries are treated as 0
|
||||
if (sizeof(T) == sizeof(__half)) {
|
||||
score_sum += __half2float(((const __half*)mask_tensor_data)[mask_idx]);
|
||||
} else {
|
||||
score_sum += static_cast<const AccT*>(mask_tensor_data)[mask_idx];
|
||||
}
|
||||
}
|
||||
} else { // a boolean (set-type) mask is int; 0 means masked out
|
||||
int mask_idx = current_full_q_idx * param->key_seq_len + k_idx;
|
||||
if (static_cast<const int*>(mask_tensor_data)[mask_idx] == 0) {
|
||||
score_sum = AccT(-1e9f);
|
||||
if (sizeof(T) == sizeof(__half)) {
|
||||
score_sum = AccT(-65504.0f);
|
||||
} else {
|
||||
score_sum = AccT(-1e9f);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (sizeof(T) == sizeof(__half)) {
|
||||
const AccT max_half_val = AccT(65504.0f);
|
||||
if (score_sum > max_half_val) score_sum = max_half_val;
|
||||
if (score_sum < -max_half_val) score_sum = -max_half_val;
|
||||
}
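// Minimal standalone sketch (not the kernel above) of the fp16-safe masking rule it applies:
// masked positions get -65504, the most negative finite half value, and additive scores are
// clamped to that range so the later narrowing cast to __half never produces -inf. Assumed
// semantics, mirroring the kernel: mask == 0 means "masked out".
#include <algorithm>
#include <cstdio>

constexpr float kHalfMaxFinite = 65504.0f; // largest finite IEEE fp16 value

float applyBoolMask(float score, int mask, bool fp16Output) {
    if (mask == 0) {
        score = fp16Output ? -kHalfMaxFinite : -1e9f;
    }
    if (fp16Output) { // keep the score finite when it is later stored as __half
        score = std::min(std::max(score, -kHalfMaxFinite), kHalfMaxFinite);
    }
    return score;
}

int main() {
    printf("%.1f\n", applyBoolMask(123.4f, 0, true)); // -65504.0
    printf("%.1f\n", applyBoolMask(123.4f, 1, true)); // 123.4
    return 0;
}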
|
||||
|
||||
// Output: qk_scores_output[b_idx, h_q_idx, q_idx_in_piece, k_idx]
|
||||
long long out_idx = (long long)b_idx * param->head_num * param->q_seq_piece_len * param->key_seq_len +
|
||||
(long long)h_q_idx * param->q_seq_piece_len * param->key_seq_len +
|
||||
(long long)q_idx_in_piece * param->key_seq_len +
|
||||
int out_idx = b_idx * param->head_num * param->q_seq_piece_len * param->key_seq_len +
|
||||
h_q_idx * param->q_seq_piece_len * param->key_seq_len +
|
||||
q_idx_in_piece * param->key_seq_len +
|
||||
k_idx;
|
||||
qk_scores_output[out_idx] = static_cast<T>(score_sum);
|
||||
}
|
||||
|
@ -209,8 +222,8 @@ __global__ void softmax_kernel(
|
|||
return;
|
||||
}
|
||||
|
||||
const T* current_row_scores = qk_scores + (long long)row_global_idx * param->key_seq_len;
|
||||
T* current_row_result = softmax_result + (long long)row_global_idx * param->key_seq_len;
|
||||
const T* current_row_scores = qk_scores + row_global_idx * param->key_seq_len;
|
||||
T* current_row_result = softmax_result + row_global_idx * param->key_seq_len;
|
||||
|
||||
// 1. Find the maximum value in the row
|
||||
AccT max_val = AccT(-1e9f); // or -FLT_MAX
|
||||
|
@ -229,7 +242,7 @@ __global__ void softmax_kernel(
|
|||
sum_exp += expf(static_cast<AccT>(current_row_scores[i]) - max_val);
|
||||
}
|
||||
// avoid division by zero
|
||||
AccT inv_sum_exp = (sum_exp == 0.0f) ? AccT(1e-10f) : (AccT(1.0f) / sum_exp);
|
||||
AccT inv_sum_exp = (sum_exp == 0.0f) ? AccT(1e-6f) : (AccT(1.0f) / sum_exp);
|
||||
|
||||
// 3. Compute the softmax
|
||||
for (int i = 0; i < param->key_seq_len; ++i) {
|
||||
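// Reference sketch of one softmax row, following the same three passes as softmax_kernel
// above (row max, exponential sum, normalize). Plain single-threaded C++ for clarity; the
// zero-sum guard differs slightly from the kernel's epsilon, so treat it as illustrative only.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

void softmaxRow(const float* scores, float* out, int len) {
    float maxVal = scores[0];
    for (int i = 1; i < len; ++i) maxVal = std::max(maxVal, scores[i]);           // 1. row max
    float sumExp = 0.f;
    for (int i = 0; i < len; ++i) sumExp += std::exp(scores[i] - maxVal);         // 2. exp sum
    float invSum = (sumExp == 0.f) ? 0.f : 1.f / sumExp;                          // guard against /0
    for (int i = 0; i < len; ++i) out[i] = std::exp(scores[i] - maxVal) * invSum; // 3. normalize
}

int main() {
    std::vector<float> s = {1.f, 2.f, 3.f}, p(3);
    softmaxRow(s.data(), p.data(), 3);
    printf("%f %f %f\n", p[0], p[1], p[2]); // ~0.09 0.24 0.67
    return 0;
}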
|
@ -264,24 +277,24 @@ __global__ void qkv_kernel(
|
|||
int h_kv_idx = h_q_idx / param->group; // corresponding KV head index
|
||||
|
||||
// Pointer to the softmax probabilities S[b_idx, h_q_idx, q_idx_in_piece, :]
|
||||
const T* prob_ptr = softmax_probs + (long long)b_idx * param->head_num * param->q_seq_piece_len * param->key_seq_len +
|
||||
(long long)h_q_idx * param->q_seq_piece_len * param->key_seq_len +
|
||||
(long long)q_idx_in_piece * param->key_seq_len;
|
||||
const T* prob_ptr = softmax_probs + b_idx * param->head_num * param->q_seq_piece_len * param->key_seq_len +
|
||||
h_q_idx * param->q_seq_piece_len * param->key_seq_len +
|
||||
q_idx_in_piece * param->key_seq_len;
|
||||
|
||||
// Base pointer into the value cache V[b_idx, h_kv_idx, d_idx, :]
|
||||
// Value cache layout: [B, H_kv, D, L_kv_alloc_max]
|
||||
const T* val_ptr_base = value_cache + (long long)b_idx * param->kv_head_num * param->head_dim * param->max_kv_len +
|
||||
(long long)h_kv_idx * param->head_dim * param->max_kv_len +
|
||||
(long long)d_idx * param->max_kv_len;
|
||||
const T* val_ptr_base = value_cache + b_idx * param->kv_head_num * param->head_dim * param->max_kv_len +
|
||||
h_kv_idx * param->head_dim * param->max_kv_len +
|
||||
d_idx * param->max_kv_len;
|
||||
|
||||
for (int k_s = 0; k_s < param->key_seq_len; ++k_s) { // accumulate along L_k_total (param->key_seq_len)
|
||||
weighted_sum += static_cast<AccT>(prob_ptr[k_s]) * static_cast<AccT>(val_ptr_base[k_s]);
|
||||
}
|
||||
|
||||
// Final output: O[b_idx, current_full_q_idx, h_q_idx, d_idx]
|
||||
long long out_idx = (long long)b_idx * param->query_seq_len * param->head_num * param->head_dim +
|
||||
(long long)current_full_q_idx * param->head_num * param->head_dim +
|
||||
(long long)h_q_idx * param->head_dim +
|
||||
int out_idx = b_idx * param->query_seq_len * param->head_num * param->head_dim +
|
||||
current_full_q_idx * param->head_num * param->head_dim +
|
||||
h_q_idx * param->head_dim +
|
||||
d_idx;
|
||||
attention_output[out_idx] = static_cast<T>(weighted_sum);
|
||||
}
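// Tiny sketch of the grouped-query-attention head mapping that qk_kernel and qkv_kernel rely
// on: param->group = head_num / kv_head_num query heads share one KV head. Example sizes
// below are assumed for illustration.
#include <cstdio>

int kvHeadFor(int hQIdx, int group) { return hQIdx / group; }

int main() {
    int headNum = 32, kvHeadNum = 8, group = headNum / kvHeadNum; // assumed example sizes
    for (int h = 0; h < headNum; ++h) {
        printf("query head %2d -> kv head %d\n", h, kvHeadFor(h, group));
    }
    return 0;
}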
|
||||
|
@ -293,11 +306,10 @@ AttentionExecution::AttentionExecution(Backend* backend, bool kv_cache_op_param)
|
|||
: Execution(backend), mIsKVCacheEnabled(kv_cache_op_param), mCudaBackend(static_cast<CUDABackend*>(backend)),
|
||||
mBatch(0), mQuerySeqLen(0), mNumHead(0), mHeadDim(0), mKvNumHead(0), mNewKvSeqLen(0),
|
||||
mQseqSplitNum(1), mHasMask(false), mIsAddMask(false), mParam_gpu(nullptr), mScale(1.0f) {
|
||||
mPrecision = halide_type_of<float>(); // default precision; may be changed in onResize
|
||||
mPrecision = 4; // default precision; may be changed in onResize
|
||||
if (mIsKVCacheEnabled) {
|
||||
mCache.reset(new SharedCache());
|
||||
mMeta = (KVMeta*)(mCudaBackend->getRuntime()->pMeta);
|
||||
|
||||
mMeta = (KVMeta*)(mCudaBackend->getMetaPtr());
|
||||
// printf("[Attention Constructor] Reading runtime. Address is: %p\n", (void*)(mCudaBackend->getRuntime()));
|
||||
// printf("[Attention Constructor] Reading pMeta as address. %p\n", (void*)(mCudaBackend->getRuntime()->pMeta));
|
||||
}
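// The diff replaces the halide_type_t member with a plain int holding the element size in
// bytes (4 = fp32, 2 = fp16). A minimal sketch of that convention with hypothetical helper
// names; the real code threads the same value through tensor creation and kernel dispatch.
#include <cstddef>
#include <cstdio>

size_t bufferBytes(int precisionBytes, size_t elementCount) {
    return static_cast<size_t>(precisionBytes) * elementCount;
}

int main() {
    const int kFp32 = 4, kFp16 = 2;
    const size_t cacheElems = 1 << 20;
    printf("fp32 cache: %zu bytes\n", bufferBytes(kFp32, cacheElems)); // 4 MiB
    printf("fp16 cache: %zu bytes\n", bufferBytes(kFp16, cacheElems)); // 2 MiB
    return 0;
}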
|
||||
|
@ -321,8 +333,12 @@ ErrorCode AttentionExecution::init_cache_tensors() {
|
|||
mCache->mMaxLength = 0;
|
||||
|
||||
// Create placeholder tensors; the actual allocation happens in reallocKVCache_gpu
|
||||
mCache->mPastKey.reset(Tensor::createDevice({1, 1, 1, 1}, mPrecision, Tensor::CAFFE));
|
||||
mCache->mPastValue.reset(Tensor::createDevice({1, 1, 1, 1}, mPrecision, Tensor::CAFFE));
|
||||
mCache->mPastKey.reset(mPrecision == 4
|
||||
? Tensor::createDevice<float>({1, 1, 1, 1})
|
||||
: Tensor::createDevice<uint16_t>({1, 1, 1, 1}));
|
||||
mCache->mPastValue.reset(mPrecision == 4
|
||||
? Tensor::createDevice<float>({1, 1, 1, 1})
|
||||
: Tensor::createDevice<uint16_t>({1, 1, 1, 1}));
|
||||
if (!mCache->mPastKey || !mCache->mPastValue) return MNN::OUT_OF_MEMORY;
|
||||
|
||||
bool res = mCudaBackend->onAcquireBuffer(mCache->mPastKey.get(), Backend::STATIC);
|
||||
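// Generic sketch of the "tiny placeholder now, grow on demand later" pattern the cache
// tensors above follow. Plain CUDA runtime calls here for illustration only; the real code
// goes through MNN's Backend buffer management and also preserves old contents (omitted here).
#include <cuda_runtime.h>
#include <algorithm>
#include <cstdio>

struct GrowableDeviceBuffer {
    void*  ptr = nullptr;
    size_t capacity = 0;
    // ensure(): keep the current allocation if it is large enough, otherwise reallocate.
    cudaError_t ensure(size_t bytes) {
        if (bytes <= capacity) return cudaSuccess;
        if (ptr != nullptr) cudaFree(ptr);
        capacity = std::max<size_t>(bytes, 1);
        return cudaMalloc(&ptr, capacity);
    }
    ~GrowableDeviceBuffer() { if (ptr != nullptr) cudaFree(ptr); }
};

int main() {
    GrowableDeviceBuffer kv;
    kv.ensure(1);          // placeholder, like the 1x1x1x1 tensors above
    kv.ensure(64 * 1024);  // grown later, like reallocKVCache_gpu
    printf("capacity = %zu bytes\n", kv.capacity);
    return 0;
}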
|
@ -349,9 +365,13 @@ ErrorCode AttentionExecution::reallocKVCache_gpu(int required_total_kv_len, int
|
|||
}
|
||||
|
||||
// Key Cache: [L_kv_max, B, H_kv, D]
|
||||
std::shared_ptr<Tensor> new_past_key_tensor(Tensor::createDevice({new_allocated_max_len, batch_size, kv_num_head, head_dim}, mPrecision, Tensor::CAFFE));
|
||||
std::shared_ptr<Tensor> new_past_key_tensor(mPrecision == 4
|
||||
? Tensor::createDevice<float>({new_allocated_max_len, batch_size, kv_num_head, head_dim})
|
||||
: Tensor::createDevice<uint16_t>({new_allocated_max_len, batch_size, kv_num_head, head_dim}));
|
||||
// Value Cache: [B, H_kv, D, L_kv_max]
|
||||
std::shared_ptr<Tensor> new_past_value_tensor(Tensor::createDevice({batch_size, kv_num_head, head_dim, new_allocated_max_len}, mPrecision, Tensor::CAFFE));
|
||||
std::shared_ptr<Tensor> new_past_value_tensor(mPrecision == 4
|
||||
? Tensor::createDevice<float>({batch_size, kv_num_head, head_dim, new_allocated_max_len})
|
||||
: Tensor::createDevice<uint16_t>({batch_size, kv_num_head, head_dim, new_allocated_max_len}));
|
||||
|
||||
if (!new_past_key_tensor || !new_past_value_tensor) return MNN::OUT_OF_MEMORY;
|
||||
|
||||
|
@ -360,7 +380,7 @@ ErrorCode AttentionExecution::reallocKVCache_gpu(int required_total_kv_len, int
|
|||
if(!resK || !resV) return MNN::OUT_OF_MEMORY;
|
||||
|
||||
if (needs_copy) {
|
||||
size_t element_size_bytes = mPrecision.bytes();
|
||||
size_t element_size_bytes = mPrecision;
|
||||
// Copy the key cache: copy the [old_past_len, B, H_kv, D] slice from the old buffer to the new one
|
||||
size_t key_bytes_to_copy = (size_t)old_past_len * batch_size * kv_num_head * head_dim * element_size_bytes;
|
||||
if (key_bytes_to_copy > 0) {
|
||||
|
@ -376,9 +396,9 @@ ErrorCode AttentionExecution::reallocKVCache_gpu(int required_total_kv_len, int
|
|||
for (int h_kv = 0; h_kv < kv_num_head; ++h_kv) {
|
||||
for (int d_s = 0; d_s < head_dim; ++d_s) {
|
||||
uint8_t* dst_ptr = getTensorDevicePtr<uint8_t>(new_past_value_tensor.get()) +
|
||||
((((long long)b * kv_num_head + h_kv) * head_dim + d_s) * new_allocated_max_len) * element_size_bytes;
|
||||
(((b * kv_num_head + h_kv) * head_dim + d_s) * new_allocated_max_len) * element_size_bytes;
|
||||
uint8_t* src_ptr = getTensorDevicePtr<uint8_t>(mCache->mPastValue.get()) +
|
||||
((((long long)b * kv_num_head + h_kv) * head_dim + d_s) * old_max_len) * element_size_bytes;
|
||||
(((b * kv_num_head + h_kv) * head_dim + d_s) * old_max_len) * element_size_bytes;
|
||||
if ((size_t)old_past_len * element_size_bytes > 0) {
|
||||
cudaMemcpyAsync(dst_ptr, src_ptr, (size_t)old_past_len * element_size_bytes, cudaMemcpyDeviceToDevice, stream);
|
||||
checkKernelErrors;
|
||||
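// The loop above issues one cudaMemcpyAsync per (b, h_kv, d) row because the value cache's
// innermost stride changes from old_max_len to new_allocated_max_len. A hedged alternative
// sketch of the same re-layout as a single strided copy via cudaMemcpy2DAsync (standard CUDA
// runtime API); copyValueCacheRows is a hypothetical helper, not the MNN code path.
#include <cuda_runtime.h>

cudaError_t copyValueCacheRows(void* dst, size_t newMaxLen,
                               const void* src, size_t oldMaxLen,
                               size_t rows,         // rows = B * H_kv * D
                               size_t validLen,     // old_past_len elements kept per row
                               size_t elementSize, cudaStream_t stream) {
    return cudaMemcpy2DAsync(dst, newMaxLen * elementSize,   // destination pitch
                             src, oldMaxLen * elementSize,   // source pitch
                             validLen * elementSize,         // bytes copied per row
                             rows, cudaMemcpyDeviceToDevice, stream);
}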
|
@ -409,7 +429,7 @@ ErrorCode AttentionExecution::reallocKVCache_gpu(int required_total_kv_len, cons
|
|||
return MNN::INVALID_VALUE;
|
||||
}
|
||||
|
||||
size_t element_size = mPrecision.bytes();
|
||||
size_t element_size = mPrecision;
|
||||
bool needs_realloc = required_total_kv_len > mCache->mMaxLength;
|
||||
|
||||
int past_len_after_remove = mCache->mPastLength - meta->remove;
|
||||
|
@ -437,14 +457,12 @@ ErrorCode AttentionExecution::reallocKVCache_gpu(int required_total_kv_len, cons
|
|||
if (needs_realloc) {
|
||||
// Optimized path: if the cache must grow, compact-copy the valid data of the old cache straight into the new one
|
||||
int new_allocated_max_len = ROUND_UP(required_total_kv_len, mExpandChunk);
|
||||
std::shared_ptr<Tensor> new_past_key_tensor(
|
||||
Tensor::createDevice({new_allocated_max_len, mBatch, mKvNumHead, mHeadDim},
|
||||
mPrecision, Tensor::CAFFE
|
||||
));
|
||||
std::shared_ptr<Tensor> new_past_value_tensor(
|
||||
Tensor::createDevice({mBatch, mKvNumHead, mHeadDim, new_allocated_max_len},
|
||||
mPrecision, Tensor::CAFFE
|
||||
));
|
||||
std::shared_ptr<Tensor> new_past_key_tensor(mPrecision == 4
|
||||
? Tensor::createDevice<float>({new_allocated_max_len, mBatch, mKvNumHead, mHeadDim})
|
||||
: Tensor::createDevice<uint16_t>({new_allocated_max_len, mBatch, mKvNumHead, mHeadDim}));
|
||||
std::shared_ptr<Tensor> new_past_value_tensor(mPrecision == 4
|
||||
? Tensor::createDevice<float>({mBatch, mKvNumHead, mHeadDim, new_allocated_max_len})
|
||||
: Tensor::createDevice<uint16_t>({mBatch, mKvNumHead, mHeadDim, new_allocated_max_len}));
|
||||
if(!mCudaBackend->onAcquireBuffer(new_past_key_tensor.get(), Backend::STATIC)
|
||||
|| !mCudaBackend->onAcquireBuffer(new_past_value_tensor.get(), Backend::STATIC)) {
|
||||
return MNN::OUT_OF_MEMORY;
|
||||
|
@ -466,8 +484,12 @@ ErrorCode AttentionExecution::reallocKVCache_gpu(int required_total_kv_len, cons
|
|||
mCache->mMaxLength = new_allocated_max_len;
|
||||
} else {
|
||||
// If no growth is needed, a temporary buffer is required for in-place compaction
|
||||
std::shared_ptr<Tensor> temp_key(Tensor::createDevice(mCache->mPastKey->shape(), mPrecision, Tensor::CAFFE));
|
||||
std::shared_ptr<Tensor> temp_value(Tensor::createDevice(mCache->mPastValue->shape(), mPrecision, Tensor::CAFFE));
|
||||
std::shared_ptr<Tensor> temp_key(mPrecision == 4
|
||||
? Tensor::createDevice<float>(mCache->mPastKey->shape())
|
||||
: Tensor::createDevice<uint16_t>(mCache->mPastKey->shape()));
|
||||
std::shared_ptr<Tensor> temp_value(mPrecision == 4
|
||||
? Tensor::createDevice<float>(mCache->mPastValue->shape())
|
||||
: Tensor::createDevice<uint16_t>(mCache->mPastValue->shape()));
|
||||
if(!mCudaBackend->onAcquireBuffer(temp_key.get(), Backend::STATIC)
|
||||
|| !mCudaBackend->onAcquireBuffer(temp_value.get(), Backend::STATIC)) {
|
||||
return MNN::OUT_OF_MEMORY;
|
||||
|
@ -508,14 +530,12 @@ ErrorCode AttentionExecution::reallocKVCache_gpu(int required_total_kv_len, cons
|
|||
int old_past_len_to_copy = mCache->mPastLength;
|
||||
int new_allocated_max_len = ROUND_UP(required_total_kv_len, mExpandChunk);
|
||||
|
||||
std::shared_ptr<Tensor> new_past_key_tensor(
|
||||
Tensor::createDevice({new_allocated_max_len, mBatch, mKvNumHead, mHeadDim},
|
||||
mPrecision, Tensor::CAFFE
|
||||
));
|
||||
std::shared_ptr<Tensor> new_past_value_tensor(
|
||||
Tensor::createDevice({mBatch, mKvNumHead, mHeadDim, new_allocated_max_len},
|
||||
mPrecision, Tensor::CAFFE
|
||||
));
|
||||
std::shared_ptr<Tensor> new_past_key_tensor(mPrecision == 4
|
||||
? Tensor::createDevice<float>({new_allocated_max_len, mBatch, mKvNumHead, mHeadDim})
|
||||
: Tensor::createDevice<uint16_t>({new_allocated_max_len, mBatch, mKvNumHead, mHeadDim}));
|
||||
std::shared_ptr<Tensor> new_past_value_tensor(mPrecision == 4
|
||||
? Tensor::createDevice<float>({mBatch, mKvNumHead, mHeadDim, new_allocated_max_len})
|
||||
: Tensor::createDevice<uint16_t>({mBatch, mKvNumHead, mHeadDim, new_allocated_max_len}));
|
||||
if(!mCudaBackend->onAcquireBuffer(new_past_key_tensor.get(), Backend::STATIC)
|
||||
|| !mCudaBackend->onAcquireBuffer(new_past_value_tensor.get(), Backend::STATIC)) {
|
||||
return MNN::OUT_OF_MEMORY;
|
||||
|
@ -554,7 +574,9 @@ ErrorCode AttentionExecution::ensureTempBuffers_gpu(int batch, int num_head, int
|
|||
bool qk_realloc = !mTempQK || mTempQK->shape() != qk_shape;
|
||||
if (qk_realloc) {
|
||||
if(mTempQK && mTempQK->deviceId() != 0) mCudaBackend->onReleaseBuffer(mTempQK.get(), Backend::STATIC);
|
||||
mTempQK.reset(Tensor::createDevice(qk_shape, mPrecision, Tensor::CAFFE));
|
||||
mTempQK.reset(mPrecision == 4
|
||||
? Tensor::createDevice<float>(qk_shape)
|
||||
: Tensor::createDevice<uint16_t>(qk_shape));
|
||||
if(!mTempQK || !mCudaBackend->onAcquireBuffer(mTempQK.get(), Backend::STATIC)) return MNN::OUT_OF_MEMORY;
|
||||
}
|
||||
|
||||
|
@ -562,7 +584,9 @@ ErrorCode AttentionExecution::ensureTempBuffers_gpu(int batch, int num_head, int
|
|||
bool softmax_realloc = !mTempSoftmax || mTempSoftmax->shape() != qk_shape;
|
||||
if (softmax_realloc) {
|
||||
if(mTempSoftmax && mTempSoftmax->deviceId() != 0) mCudaBackend->onReleaseBuffer(mTempSoftmax.get(), Backend::STATIC);
|
||||
mTempSoftmax.reset(Tensor::createDevice(qk_shape, mPrecision, Tensor::CAFFE));
|
||||
mTempSoftmax.reset(mPrecision == 4
|
||||
? Tensor::createDevice<float>(qk_shape)
|
||||
: Tensor::createDevice<uint16_t>(qk_shape));
|
||||
if(!mTempSoftmax || !mCudaBackend->onAcquireBuffer(mTempSoftmax.get(), Backend::STATIC)) return MNN::OUT_OF_MEMORY;
|
||||
}
|
||||
|
||||
|
@ -573,7 +597,9 @@ ErrorCode AttentionExecution::ensureTempBuffers_gpu(int batch, int num_head, int
|
|||
bool temp_k_realloc = !mTempK_current_step || mTempK_current_step->shape() != temp_k_shape;
|
||||
if(temp_k_realloc){
|
||||
if(mTempK_current_step && mTempK_current_step->deviceId() !=0) mCudaBackend->onReleaseBuffer(mTempK_current_step.get(), Backend::STATIC);
|
||||
mTempK_current_step.reset(Tensor::createDevice(temp_k_shape, mPrecision, Tensor::CAFFE));
|
||||
mTempK_current_step.reset(mPrecision == 4
|
||||
? Tensor::createDevice<float>(temp_k_shape)
|
||||
: Tensor::createDevice<uint16_t>(temp_k_shape) );
|
||||
if(!mTempK_current_step || !mCudaBackend->onAcquireBuffer(mTempK_current_step.get(), Backend::STATIC)) return MNN::OUT_OF_MEMORY;
|
||||
}
|
||||
|
||||
|
@ -582,7 +608,9 @@ ErrorCode AttentionExecution::ensureTempBuffers_gpu(int batch, int num_head, int
|
|||
bool temp_v_realloc = !mTempV_current_step || mTempV_current_step->shape() != temp_v_shape;
|
||||
if(temp_v_realloc){
|
||||
if(mTempV_current_step && mTempV_current_step->deviceId() !=0) mCudaBackend->onReleaseBuffer(mTempV_current_step.get(), Backend::STATIC);
|
||||
mTempV_current_step.reset(Tensor::createDevice(temp_v_shape, mPrecision, Tensor::CAFFE));
|
||||
mTempV_current_step.reset(mPrecision == 4
|
||||
? Tensor::createDevice<float>(temp_v_shape)
|
||||
: Tensor::createDevice<uint16_t>(temp_v_shape));
|
||||
if(!mTempV_current_step || !mCudaBackend->onAcquireBuffer(mTempV_current_step.get(), Backend::STATIC)) return MNN::OUT_OF_MEMORY;
|
||||
}
|
||||
}
|
||||
|
@ -602,11 +630,6 @@ bool AttentionExecution::onClone(Backend* bn, const Op* op, Execution** dst) {
|
|||
|
||||
// #define DEBUG_ATTENTION_VERBOSE
|
||||
|
||||
static inline float half_to_float_debug(const __half& h_val) {
|
||||
return __half2float(h_val);
|
||||
}
|
||||
|
||||
// Helper that prints a selected slice of a GPU tensor
|
||||
void print_gpu_tensor_debug(
|
||||
const MNN::Tensor* target_tensor, // tensor to print (may live on the GPU or the CPU)
|
||||
const char* name // tensor name used in the log
|
||||
|
@ -620,12 +643,77 @@ void print_gpu_tensor_debug(
|
|||
target_tensor->print();
|
||||
}
|
||||
|
||||
// Helper that prints a selected slice of a GPU tensor
|
||||
void print_gpu_tensor_debug(
|
||||
int mPrecision,
|
||||
const MNN::Tensor* target_tensor,
|
||||
const char* name
|
||||
) {
|
||||
if (!target_tensor) {
|
||||
printf("\n--- Tensor [%s] is null. ---\n", name);
|
||||
return;
|
||||
}
|
||||
|
||||
printf("\n--- Tensor [%s] ---\n", name);
|
||||
|
||||
const MNN::Tensor* tensor_to_print = nullptr;
|
||||
std::unique_ptr<MNN::Tensor, decltype(&MNN::Tensor::destroy)> host_tensor_holder(nullptr, &MNN::Tensor::destroy);
|
||||
|
||||
if (target_tensor->deviceId() != 0) {
|
||||
host_tensor_holder.reset(MNN::Tensor::createHostTensorFromDevice(target_tensor, true));
|
||||
if (!host_tensor_holder) {
|
||||
printf("Error: Failed to copy device tensor to host for printing.\n");
|
||||
return;
|
||||
}
|
||||
tensor_to_print = host_tensor_holder.get();
|
||||
} else {
|
||||
tensor_to_print = target_tensor;
|
||||
}
|
||||
|
||||
printf("Shape: [ ");
|
||||
const auto& shape = tensor_to_print->shape();
|
||||
for (int val : shape) printf("%d ", val);
|
||||
printf("], ");
|
||||
|
||||
const int total_elements = tensor_to_print->elementSize() / mPrecision;
|
||||
if (total_elements == 0) {
|
||||
printf("Data: (empty)\n");
|
||||
return;
|
||||
}
|
||||
printf("Data (all %d elements):\n", total_elements);
|
||||
|
||||
int elements_per_line = total_elements; // by default, do not wrap lines for 0-D or 1-D tensors
|
||||
const int dims = tensor_to_print->dimensions();
|
||||
if (dims > 1) elements_per_line = tensor_to_print->length(dims - 1);
|
||||
|
||||
if (mPrecision == 2) {
|
||||
const __half* data_ptr = tensor_to_print->host<__half>();
|
||||
for (int i = 0; i < total_elements; ++i) {
|
||||
printf("%8.4f ", __half2float(data_ptr[i]));
|
||||
if ((i + 1) % elements_per_line == 0 && (i + 1) < total_elements) printf("\n");
|
||||
}
|
||||
} else if (mPrecision == 4) {
|
||||
const float* data_ptr = tensor_to_print->host<float>();
|
||||
for (int i = 0; i < total_elements; ++i) {
|
||||
printf("%8.4f ", data_ptr[i]);
|
||||
if ((i + 1) % elements_per_line == 0 && (i + 1) < total_elements) printf("\n");
|
||||
}
|
||||
} else {
|
||||
target_tensor->print();
|
||||
}
|
||||
printf("\n--- End of Tensor [%s] ---\n", name);
|
||||
}
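// Standalone sketch of what the helper above does for the fp32 branch: copy a device buffer
// to the host and print it, wrapping lines at the innermost dimension. printDeviceFloats is
// a hypothetical debug utility, not part of MNN.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

void printDeviceFloats(const float* devPtr, int count, int elementsPerLine, const char* name) {
    std::vector<float> host(count);
    cudaMemcpy(host.data(), devPtr, count * sizeof(float), cudaMemcpyDeviceToHost);
    printf("--- %s ---\n", name);
    for (int i = 0; i < count; ++i) {
        printf("%8.4f ", host[i]);
        if ((i + 1) % elementsPerLine == 0) printf("\n");
    }
    printf("\n--- End of %s ---\n", name);
}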
|
||||
|
||||
// Initialize all parameters
|
||||
ErrorCode AttentionExecution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
|
||||
const auto* query_tensor = inputs[0]; // shape: [B, L_q, H_q, D]
|
||||
const auto* key_tensor = inputs[1]; // shape: [B, L_k_new, H_kv, D]
|
||||
|
||||
mPrecision = query_tensor->getType(); // fetch the tensor's halide_type_t
|
||||
if (mCudaBackend->useFp16()) {
|
||||
mPrecision = 2;
|
||||
} else {
|
||||
mPrecision = 4;
|
||||
}
|
||||
|
||||
mBatch = query_tensor->length(0);
|
||||
mQuerySeqLen = query_tensor->length(1);
|
||||
|
@ -671,7 +759,7 @@ ErrorCode AttentionExecution::onExecute(const std::vector<Tensor *> &inputs, con
|
|||
const Tensor* mask_input_tensor = mHasMask ? inputs[3] : nullptr;
|
||||
auto final_output_tensor = outputs[0];
|
||||
|
||||
// qk_kernel expects a [1, 1, L_q, L_k] mask by default, but some models emit a [1, 1, 1, 1] mask with value -0.00 to mean "no mask"
|
||||
// qk_kernel expects an [L_q, L_k] mask by default, but some models emit a [1, 1, 1, 1] mask with value -0.00 to mean "no mask"
|
||||
if (mIsKVCacheEnabled && mHasMask && mask_input_tensor && mask_input_tensor->elementSize() == 1) {
|
||||
mHasMask = false;
|
||||
}
|
||||
|
@ -758,15 +846,15 @@ ErrorCode AttentionExecution::onExecute(const std::vector<Tensor *> &inputs, con
|
|||
|
||||
// Copy the new K and V into the cache
|
||||
dim3 copy_blockDim(32, 8, 1); // tentative block size: 256 threads
|
||||
dim3 copy_gridDim(UP_DIV(mHeadDim, (int)copy_blockDim.x),
|
||||
UP_DIV(mNewKvSeqLen, (int)copy_blockDim.y),
|
||||
UP_DIV(mBatch * mKvNumHead, (int)copy_blockDim.z));
|
||||
if (mPrecision.bytes() == 4) { // float32
|
||||
dim3 copy_gridDim(UP_DIV(mHeadDim, copy_blockDim.x),
|
||||
UP_DIV(mNewKvSeqLen, copy_blockDim.y),
|
||||
UP_DIV(mBatch * mKvNumHead, copy_blockDim.z));
|
||||
if (mPrecision == 4) { // float32
|
||||
copy_kv_to_cache_kernel<float><<<copy_gridDim, copy_blockDim, 0, stream>>>(
|
||||
getTensorDevicePtr<float>(key_input_tensor), getTensorDevicePtr<float>(value_input_tensor),
|
||||
getTensorDevicePtr<float>(mCache->mPastKey.get()), getTensorDevicePtr<float>(mCache->mPastValue.get()),
|
||||
mBatch, mNewKvSeqLen, mKvNumHead, mHeadDim, param_cpu.past_kv_len, mCache->mMaxLength);
|
||||
} else if (mPrecision.bytes() == 2) { // float16
|
||||
} else if (mPrecision == 2) { // float16
|
||||
copy_kv_to_cache_kernel<__half><<<copy_gridDim, copy_blockDim, 0, stream>>>(
|
||||
getTensorDevicePtr<__half>(key_input_tensor), getTensorDevicePtr<__half>(value_input_tensor),
|
||||
getTensorDevicePtr<__half>(mCache->mPastKey.get()), getTensorDevicePtr<__half>(mCache->mPastValue.get()),
|
||||
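// The copy/qk/qkv launches above all size their grids with UP_DIV (ceiling division) over
// each problem axis and bounds-check inside the kernel. A minimal self-contained sketch of
// that pattern for the (head_dim, seq_len, batch * kv_heads) shape used by the copy kernel;
// zeroFillKernel is a stand-in body, not the MNN kernel.
#include <cuda_runtime.h>

static inline int upDiv(int a, int b) { return (a + b - 1) / b; } // same role as MNN's UP_DIV

__global__ void zeroFillKernel(float* data, int headDim, int seqLen, int batchHeads) {
    int d  = blockIdx.x * blockDim.x + threadIdx.x;
    int l  = blockIdx.y * blockDim.y + threadIdx.y;
    int bh = blockIdx.z * blockDim.z + threadIdx.z;
    if (d >= headDim || l >= seqLen || bh >= batchHeads) return;  // tail blocks are partial
    data[((long long)bh * seqLen + l) * headDim + d] = 0.f;
}

void launchZeroFill(float* data, int headDim, int seqLen, int batchHeads, cudaStream_t stream) {
    dim3 block(32, 8, 1);
    dim3 grid(upDiv(headDim, block.x), upDiv(seqLen, block.y), upDiv(batchHeads, block.z));
    zeroFillKernel<<<grid, block, 0, stream>>>(data, headDim, seqLen, batchHeads);
}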
|
@ -778,10 +866,10 @@ ErrorCode AttentionExecution::onExecute(const std::vector<Tensor *> &inputs, con
|
|||
effective_value_cache_ptr = getTensorDevicePtr(mCache->mPastValue.get());
|
||||
|
||||
#ifdef DEBUG_ATTENTION_VERBOSE
|
||||
if (current_total_kv_len_for_qk >= 0) {
|
||||
if (current_total_kv_len_for_qk >= 20) {
|
||||
// cudaStreamSynchronize(stream); // make sure copy_kv_to_cache_kernel has finished
|
||||
// if (mCache && mCache->mPastKey) print_gpu_tensor_debug(mCache->mPastKey.get(), "Key Cache (After copy_kv_to_cache)"); // L_kv_alloc, B, H_kv, D
|
||||
// if (mCache && mCache->mPastValue) print_gpu_tensor_debug(mCache->mPastValue.get(), "Value Cache (After copy_kv_to_cache)"); // B, H_kv, D, L_kv_alloc
|
||||
// if (mCache && mCache->mPastKey) print_gpu_tensor_debug(mPrecision, mCache->mPastKey.get(), "Key Cache (After copy_kv_to_cache)"); // L_kv_alloc, B, H_kv, D
|
||||
// if (mCache && mCache->mPastValue) print_gpu_tensor_debug(mPrecision, mCache->mPastValue.get(), "Value Cache (After copy_kv_to_cache)"); // B, H_kv, D, L_kv_alloc
|
||||
}
|
||||
#endif
|
||||
} else { // no KV cache: store this step's K/V in temporary tensors
|
||||
|
@ -796,15 +884,15 @@ ErrorCode AttentionExecution::onExecute(const std::vector<Tensor *> &inputs, con
|
|||
if(temp_err != MNN::NO_ERROR) return temp_err;
|
||||
|
||||
dim3 copy_blockDim(32, 8, 1);
|
||||
dim3 copy_gridDim(UP_DIV(mHeadDim, (int)copy_blockDim.x),
|
||||
UP_DIV(mNewKvSeqLen, (int)copy_blockDim.y),
|
||||
UP_DIV(mBatch * mKvNumHead, (int)copy_blockDim.z));
|
||||
if (mPrecision.bytes() == 4) {
|
||||
dim3 copy_gridDim(UP_DIV(mHeadDim, copy_blockDim.x),
|
||||
UP_DIV(mNewKvSeqLen, copy_blockDim.y),
|
||||
UP_DIV(mBatch * mKvNumHead, copy_blockDim.z));
|
||||
if (mPrecision == 4) {
|
||||
copy_kv_to_cache_kernel<float><<<copy_gridDim, copy_blockDim, 0, stream>>>(
|
||||
getTensorDevicePtr<float>(key_input_tensor), getTensorDevicePtr<float>(value_input_tensor),
|
||||
getTensorDevicePtr<float>(mTempK_current_step.get()), getTensorDevicePtr<float>(mTempV_current_step.get()),
|
||||
mBatch, mNewKvSeqLen, mKvNumHead, mHeadDim, 0, allocated_kv_len_for_value_stride);
|
||||
} else if (mPrecision.bytes() == 2) {
|
||||
} else if (mPrecision == 2) {
|
||||
copy_kv_to_cache_kernel<__half><<<copy_gridDim, copy_blockDim, 0, stream>>>(
|
||||
getTensorDevicePtr<__half>(key_input_tensor), getTensorDevicePtr<__half>(value_input_tensor),
|
||||
getTensorDevicePtr<__half>(mTempK_current_step.get()), getTensorDevicePtr<__half>(mTempV_current_step.get()),
|
||||
|
@ -817,8 +905,8 @@ ErrorCode AttentionExecution::onExecute(const std::vector<Tensor *> &inputs, con
|
|||
|
||||
#ifdef DEBUG_ATTENTION_VERBOSE
|
||||
// cudaStreamSynchronize(stream); // make sure copy_kv_to_cache_kernel has finished
|
||||
// if (mTempK_current_step) print_gpu_tensor_debug(mTempK_current_step.get(), "Temp K Current Step (Copied from Input K)");
|
||||
// if (mTempV_current_step) print_gpu_tensor_debug(mTempV_current_step.get(), "Temp V Current Step (Copied from Input V)");
|
||||
// if (mTempK_current_step) print_gpu_tensor_debug(mPrecision, mTempK_current_step.get(), "Temp K Current Step (Copied from Input K)");
|
||||
// if (mTempV_current_step) print_gpu_tensor_debug(mPrecision, mTempV_current_step.get(), "Temp V Current Step (Copied from Input V)");
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -844,16 +932,16 @@ ErrorCode AttentionExecution::onExecute(const std::vector<Tensor *> &inputs, con
|
|||
|
||||
// QK Kernel
|
||||
dim3 qk_blockDim(16, 16, 1); // 256 threads for now
|
||||
dim3 qk_gridDim(UP_DIV(current_total_kv_len_for_qk, (int)qk_blockDim.x),
|
||||
UP_DIV(current_piece_actual_len, (int)qk_blockDim.y),
|
||||
UP_DIV(mBatch * mNumHead, (int)qk_blockDim.z));
|
||||
dim3 qk_gridDim(UP_DIV(current_total_kv_len_for_qk, qk_blockDim.x),
|
||||
UP_DIV(current_piece_actual_len, qk_blockDim.y),
|
||||
UP_DIV(mBatch * mNumHead, qk_blockDim.z));
|
||||
|
||||
if (mPrecision.bytes() == 4) {
|
||||
if (mPrecision == 4) {
|
||||
qk_kernel<float><<<qk_gridDim, qk_blockDim, 0, stream>>>(
|
||||
getTensorDevicePtr<float>(query_input_tensor), static_cast<const float*>(effective_key_cache_ptr),
|
||||
getTensorDevicePtr<float>(mTempQK.get()), mask_ptr_device,
|
||||
mParam_gpu, q_seq_offset, mHasMask, mIsAddMask);
|
||||
} else if (mPrecision.bytes() == 2) {
|
||||
} else if (mPrecision == 2) {
|
||||
qk_kernel<__half><<<qk_gridDim, qk_blockDim, 0, stream>>>(
|
||||
getTensorDevicePtr<__half>(query_input_tensor), static_cast<const __half*>(effective_key_cache_ptr),
|
||||
getTensorDevicePtr<__half>(mTempQK.get()), mask_ptr_device,
|
||||
|
@ -862,28 +950,65 @@ ErrorCode AttentionExecution::onExecute(const std::vector<Tensor *> &inputs, con
|
|||
checkKernelErrors;
|
||||
|
||||
// Softmax Kernel
|
||||
int softmax_total_rows_for_piece = mBatch * mNumHead * current_piece_actual_len;
|
||||
dim3 softmax_blockDim(256, 1, 1); // one thread per row
|
||||
dim3 softmax_gridDim(UP_DIV(softmax_total_rows_for_piece, (int)softmax_blockDim.x), 1, 1);
|
||||
if (mPrecision.bytes() == 4) {
|
||||
softmax_kernel<float><<<softmax_gridDim, softmax_blockDim, 0, stream>>>(
|
||||
getTensorDevicePtr<float>(mTempQK.get()), getTensorDevicePtr<float>(mTempSoftmax.get()), mParam_gpu, current_piece_actual_len);
|
||||
} else if (mPrecision.bytes() == 2) {
|
||||
softmax_kernel<__half><<<softmax_gridDim, softmax_blockDim, 0, stream>>>(
|
||||
getTensorDevicePtr<__half>(mTempQK.get()), getTensorDevicePtr<__half>(mTempSoftmax.get()), mParam_gpu, current_piece_actual_len);
|
||||
} else { return MNN::NOT_SUPPORT; }
|
||||
|
||||
// int softmax_total_rows_for_piece = mBatch * mNumHead * current_piece_actual_len;
|
||||
// dim3 softmax_blockDim(256, 1, 1); // one thread per row
|
||||
// dim3 softmax_gridDim(UP_DIV(softmax_total_rows_for_piece, (int)softmax_blockDim.x), 1, 1);
|
||||
// if (mPrecision == 4) {
|
||||
// softmax_kernel<float><<<softmax_gridDim, softmax_blockDim, 0, stream>>>(
|
||||
// getTensorDevicePtr<float>(mTempQK.get()), getTensorDevicePtr<float>(mTempSoftmax.get()), mParam_gpu, current_piece_actual_len);
|
||||
// } else if (mPrecision == 2) {
|
||||
// softmax_kernel<__half><<<softmax_gridDim, softmax_blockDim, 0, stream>>>(
|
||||
// getTensorDevicePtr<__half>(mTempQK.get()), getTensorDevicePtr<__half>(mTempSoftmax.get()), mParam_gpu, current_piece_actual_len);
|
||||
// } else { return MNN::NOT_SUPPORT; }
|
||||
// checkKernelErrors;
|
||||
|
||||
const int axis = current_total_kv_len_for_qk;
|
||||
const int inside = 1;
|
||||
const int outside = mBatch * mNumHead * current_piece_actual_len;
|
||||
const int count = outside * inside;
|
||||
|
||||
const void* qk_scores_ptr = getTensorDevicePtr(mTempQK.get());
|
||||
void* softmax_result_ptr = getTensorDevicePtr(mTempSoftmax.get());
|
||||
|
||||
// Pick and launch the best kernel based on the axis length and the precision
|
||||
// Launch config: each thread block processes one softmax row
|
||||
// Grid size = total number of rows (count); block size = threads used to process one row in parallel
|
||||
if (mPrecision == 4) { // FP32
|
||||
const auto* input_ptr = static_cast<const float*>(qk_scores_ptr);
|
||||
auto* output_ptr = static_cast<float*>(softmax_result_ptr);
|
||||
if (axis <= 32) {
|
||||
// For short sequences, use warp-level primitives for maximum efficiency
|
||||
SOFTMAX_WARP_32<float><<<count, 32, 0, stream>>>(input_ptr, output_ptr, inside, axis, outside, count);
|
||||
} else {
|
||||
// For long sequences, use a shared-memory block-level parallel reduction
|
||||
constexpr int threads_per_block = 256;
|
||||
const int calc_multi_num = UP_DIV(axis, threads_per_block);
|
||||
SOFTMAX_AXIS_REDUCE<float><<<count, threads_per_block, 0, stream>>>(input_ptr, output_ptr, inside, axis, threads_per_block, calc_multi_num, outside, count);
|
||||
}
|
||||
} else { // FP16
|
||||
const auto* input_ptr = static_cast<const __half*>(qk_scores_ptr);
|
||||
auto* output_ptr = static_cast<__half*>(softmax_result_ptr);
|
||||
if (axis <= 32) {
|
||||
SOFTMAX_WARP_32<__half><<<count, 32, 0, stream>>>(input_ptr, output_ptr, inside, axis, outside, count);
|
||||
} else {
|
||||
constexpr int threads_per_block = 256;
|
||||
const int calc_multi_num = UP_DIV(axis, threads_per_block);
|
||||
SOFTMAX_AXIS_REDUCE<__half><<<count, threads_per_block, 0, stream>>>(input_ptr, output_ptr, inside, axis, threads_per_block, calc_multi_num, outside, count);
|
||||
}
|
||||
}
|
||||
checkKernelErrors;
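// SOFTMAX_WARP_32 / SOFTMAX_AXIS_REDUCE are only declared in this diff, so the following is
// a hedged sketch of the warp-per-row idea the short-axis path relies on, not MNN's kernel:
// each of the 32 lanes owns at most one element of a row (inside == 1), and the row max and
// exponential sum are reduced with __shfl_xor_sync butterflies. Launch: <<<rows, 32>>>.
#include <cuda_runtime.h>
#include <math.h>

__global__ void warpSoftmaxRowSketch(const float* in, float* out, int axis, int rows) {
    int row  = blockIdx.x;
    int lane = threadIdx.x;                       // blockDim.x == 32
    if (row >= rows) return;
    float v = (lane < axis) ? in[row * axis + lane] : -INFINITY;
    float m = v;
    for (int offset = 16; offset > 0; offset >>= 1) {
        m = fmaxf(m, __shfl_xor_sync(0xffffffff, m, offset));   // warp-wide max
    }
    float e = (lane < axis) ? expf(v - m) : 0.f;
    float s = e;
    for (int offset = 16; offset > 0; offset >>= 1) {
        s += __shfl_xor_sync(0xffffffff, s, offset);            // warp-wide sum
    }
    if (lane < axis) out[row * axis + lane] = e / s;
}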
|
||||
|
||||
// QKV Kernel
|
||||
dim3 qkv_blockDim(32, 8, 1); // 256 threads for now
|
||||
dim3 qkv_gridDim(UP_DIV(mHeadDim, (int)qkv_blockDim.x),
|
||||
UP_DIV(current_piece_actual_len, (int)qkv_blockDim.y),
|
||||
UP_DIV(mBatch * mNumHead, (int)qkv_blockDim.z));
|
||||
if (mPrecision.bytes() == 4) {
|
||||
dim3 qkv_gridDim(UP_DIV(mHeadDim, qkv_blockDim.x),
|
||||
UP_DIV(current_piece_actual_len, qkv_blockDim.y),
|
||||
UP_DIV(mBatch * mNumHead, qkv_blockDim.z));
|
||||
if (mPrecision == 4) {
|
||||
qkv_kernel<float><<<qkv_gridDim, qkv_blockDim, 0, stream>>>(
|
||||
getTensorDevicePtr<float>(mTempSoftmax.get()), static_cast<const float*>(effective_value_cache_ptr),
|
||||
getTensorDevicePtr<float>(final_output_tensor), mParam_gpu, q_seq_offset);
|
||||
} else if (mPrecision.bytes() == 2) {
|
||||
} else if (mPrecision == 2) {
|
||||
qkv_kernel<__half><<<qkv_gridDim, qkv_blockDim, 0, stream>>>(
|
||||
getTensorDevicePtr<__half>(mTempSoftmax.get()), static_cast<const __half*>(effective_value_cache_ptr),
|
||||
getTensorDevicePtr<__half>(final_output_tensor), mParam_gpu, q_seq_offset);
|
||||
|
@ -916,23 +1041,22 @@ ErrorCode AttentionExecution::onExecute(const std::vector<Tensor *> &inputs, con
|
|||
printf("--------------------------------------------------------------------\n");
|
||||
|
||||
if (current_total_kv_len_for_qk >= 20) {
|
||||
// Print the mask (if present)
|
||||
// The mask is usually [1, 1, Lq_full, Lk_total]
|
||||
// Print the mask
|
||||
if (mHasMask && mask_input_tensor) {
|
||||
// print_gpu_tensor_debug(mask_input_tensor, "Mask Input Tensor (Relevant Slice view)");
|
||||
// print_gpu_tensor_debug(mask_input_tensor, "Mask Input Tensor");
|
||||
}
|
||||
|
||||
// Print the QK^T scores; mTempQK shape: [B, H_q, L_q_piece, L_k_total]
|
||||
// print_gpu_tensor_debug(mTempQK.get(), "QK^T Scores (mTempQK for current piece)");
|
||||
// print_gpu_tensor_debug(mPrecision, mTempQK.get(), "QK^T Scores (mTempQK for current piece)");
|
||||
|
||||
// Print the softmax probabilities; mTempSoftmax shape: [B, H_q, L_q_piece, L_k_total]
|
||||
// print_gpu_tensor_debug(mTempSoftmax.get(), "Softmax Probs (mTempSoftmax for current piece)");
|
||||
// print_gpu_tensor_debug(mPrecision, mTempSoftmax.get(), "Softmax Probs (mTempSoftmax for current piece)");
|
||||
|
||||
// Print the attention output; final_output_tensor shape: [B, L_q_full, H_q, D] or [1, 1, N]
|
||||
if (final_output_tensor->dimensions()==3) { // 3-D, e.g. [1, 1, N]
|
||||
// print_gpu_tensor_debug(final_output_tensor, "Attention Output Slice (Final)");
|
||||
// print_gpu_tensor_debug(mPrecision, final_output_tensor, "Attention Output Slice (Final)");
|
||||
} else { // assume [B, Lq_full, Hq, D]
|
||||
// print_gpu_tensor_debug(final_output_tensor, "Attention Output Slice (Final)");
|
||||
// print_gpu_tensor_debug(mPrecision, final_output_tensor, "Attention Output Slice (Final)");
|
||||
}
|
||||
|
||||
printf("========================= END PIECE DEBUG DUMP =======================\n");
|
||||
|
|
|
@ -80,7 +80,7 @@ private:
|
|||
KVMeta* mMeta = nullptr;
|
||||
|
||||
const int mExpandChunk = 64; // chunk size used when growing the KV cache
|
||||
halide_type_t mPrecision; // precision (float or half)
|
||||
int mPrecision; // precision in bytes: 4 = float, 2 = half
|
||||
};
|
||||
#endif // MNN_SUPPORT_TRANSFORMER_FUSE
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@ public:
|
|||
|
||||
#ifdef MNN_LOW_MEMORY
|
||||
auto conv2dParams = op->main_as_Convolution2D();
|
||||
bool isMemoryLowWeightOnlyQuant = (conv2dParams->quanParameter() != nullptr && conv2dParams->quanParameter()->buffer() != nullptr);
|
||||
bool isMemoryLowWeightOnlyQuant = conv2dParams->quanParameter() && (conv2dParams->external() || conv2dParams->quanParameter()->buffer());
|
||||
isMemoryLowWeightOnlyQuant = isMemoryLowWeightOnlyQuant && (static_cast<CUDABackend*>(backend)->getMemoryMode() == BackendConfig::Memory_Low);
|
||||
isMemoryLowWeightOnlyQuant = isMemoryLowWeightOnlyQuant && ConvFpAIntBExecution::isValid(op->main_as_Convolution2D(), backend);
|
||||
if (isMemoryLowWeightOnlyQuant) {
|
||||
|
|
|
@ -124,20 +124,19 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con
|
|||
mFast = false;
|
||||
break;
|
||||
}
|
||||
if(!_directBlitC4(slice0, slice, output)) {
|
||||
mFast = false;
|
||||
if(!_directBlitC4(slice0, slice, output)) {
|
||||
mFast = false;
|
||||
break;
|
||||
}
|
||||
if (!OpCommonUtils::canBlitFast(slice, output, pack, false, true)) {
|
||||
mFast = false;
|
||||
mFast = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
//MNN_PRINT("raster fast:%d regionNum:%d\n\n\n", mFast, des->regions.size());
|
||||
if (mFast) {
|
||||
int dstStep = 1;
|
||||
for (int i=0; i< des->regions.size(); ++i) {
|
||||
int srcStep = 1;
|
||||
int dstStep = 1;
|
||||
auto& slice = des->regions[i];
|
||||
if(slice.dst.offset / (slice.size[2] * slice.size[1]) >= 1) {
|
||||
int batchChannel = slice.dst.offset / (slice.size[1] * slice.size[2]) + 1;
|
||||
|
@ -147,6 +146,11 @@ ErrorCode RasterExecution::onResize(const std::vector<Tensor *> &____inputs, con
|
|||
int tmp = slice.dst.stride[0] / slice.src.stride[0];
|
||||
dstStep = dstStep > tmp ? dstStep : tmp;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i=0; i< des->regions.size(); ++i) {
|
||||
int srcStep = 1;
|
||||
auto& slice = des->regions[i];
|
||||
if(slice.src.offset / (slice.size[2] * slice.size[1]) >= 1) {
|
||||
int batchChannel = slice.src.offset / (slice.size[1] * slice.size[2]) + 1;
|
||||
srcStep = srcStep > batchChannel ? srcStep : batchChannel;
|
||||
|
|
|
@ -18,6 +18,13 @@
|
|||
namespace MNN {
|
||||
namespace CUDA {
|
||||
|
||||
template <typename T>
|
||||
__global__ void SOFTMAX(const T *input, T *output, const int inside, const int axis, const int outside, const int count);
|
||||
template <typename T>
|
||||
__global__ void SOFTMAX_WARP_32(const T *input, T *output, const int inside, const int axis, const int outside, const int count);
|
||||
template <typename T>
|
||||
__global__ void SOFTMAX_AXIS_REDUCE(const T *input, T *output, const int inside, const int axis, const int per_block_size, const int calc_multi_num, const int outside, const int count);
|
||||
|
||||
class SoftmaxExecution : public Execution {
|
||||
public:
|
||||
SoftmaxExecution(int axis, Backend *backend);
|
||||
|
|
|
@ -10,7 +10,6 @@
|
|||
|
||||
#include "backend/cuda/core/CUDABackend.hpp"
|
||||
#include "core/Execution.hpp"
|
||||
#include "../CutlassGemmParam.hpp"
|
||||
#include "../MNNCUDADefine.hpp"
|
||||
#include "../MNNCUDAFunction.cuh"
|
||||
#include "../cutlass_common/CutlassConvCommonExecution.hpp"
|
||||
|
@ -27,12 +26,16 @@ public:
|
|||
void* mScale;
|
||||
void* mOffset;
|
||||
void* mBias;
|
||||
int mQuanC;
|
||||
std::shared_ptr<Tensor> weightTensor;
|
||||
std::shared_ptr<Tensor> scaleTensor;
|
||||
std::shared_ptr<Tensor> offsetTensor;
|
||||
std::shared_ptr<Tensor> biasTensor;
|
||||
Backend* mBackend = nullptr;
|
||||
bool mIsWeightInt4 = false;
|
||||
|
||||
std::shared_ptr<Tensor> mSumBQTensor;
|
||||
void* mSumBQ = nullptr;
|
||||
};
|
||||
static bool isValid(const Convolution2D* conv, Backend* backend);
|
||||
ConvFpAIntBExecution(Backend* backend, const MNN::Op* op, std::shared_ptr<Resource> res);
|
||||
|
@ -43,6 +46,9 @@ public:
|
|||
|
||||
private:
|
||||
std::shared_ptr<Resource> mResource;
|
||||
|
||||
std::shared_ptr<Tensor> mDequantFilterTensor;
|
||||
void* mDequantFilter = nullptr;
|
||||
};
|
||||
|
||||
} // namespace CUDA
|
||||
|
|
|
@ -100,7 +100,7 @@ void AttentionBufExecution::_init() {
|
|||
mCache.reset(new SharedCache);
|
||||
auto mtbn = static_cast<MetalBackend *>(backend());
|
||||
auto context = (__bridge MNNMetalContext *)mtbn->context();
|
||||
mMeta = (KVMeta*)(mtbn->getRuntime()->pMeta);
|
||||
mMeta = (KVMeta*)(mtbn->getMetaPtr());
|
||||
|
||||
mParamQKV = [context newDeviceBuffer:sizeof(Param) access:CPUWriteOnly];
|
||||
mParamSoftmax = [context newDeviceBuffer:4 * sizeof(int) access:CPUWriteOnly];
|
||||
|
|
|
@ -269,6 +269,7 @@ private:
|
|||
bool mUseFloatAsFp16;
|
||||
bool mIsIphone = false;
|
||||
BufferAllocator* mCurrentAllocator = nullptr;
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -1219,7 +1219,9 @@ Backend* MetalRuntime::onCreate(const BackendConfig* config, Backend* origin) co
|
|||
memory = config->memory;
|
||||
}
|
||||
bool useFp16AsFp32 = precision != BackendConfig::Precision_High;
|
||||
return new MetalBackend(mStatic, this, useFp16AsFp32, memory);
|
||||
auto backend = new MetalBackend(mStatic, this, useFp16AsFp32, memory);
|
||||
backend->setMetaPtr(pMeta);
|
||||
return backend;
|
||||
}
|
||||
|
||||
void MetalRuntime::onGabageCollect(int level) {
|
||||
|
|
|
@ -574,6 +574,7 @@ bool convertBufferToBuffer(Tensor *input, Tensor *output, OpenCLRuntime *runtime
|
|||
return true;
|
||||
}
|
||||
|
||||
#ifdef __ANDROID__
|
||||
bool convertBetweenAHDandCLmem(const Tensor *input, const Tensor *output, OpenCLRuntime *runtime, int precision, int memType, bool toDevice, bool toHost) {
|
||||
std::set<std::string> buildOptions;
|
||||
auto srcDimensionFormat = TensorUtils::getDescribe(input)->dimensionFormat;
|
||||
|
@ -584,25 +585,46 @@ bool convertBetweenAHDandCLmem(const Tensor *input, const Tensor *output, OpenCL
|
|||
|
||||
buildOptions.emplace("-DINPUT_FORMAT=" + std::to_string(srcDimensionFormat));
|
||||
buildOptions.emplace("-DOUTPUT_FORMAT=" + std::to_string(dstDimensionFormat));
|
||||
std::vector<int> outputShape;
|
||||
std::shared_ptr<KernelWrap> kernelW;
|
||||
if(toDevice){
|
||||
buildOptions.emplace("-DSHARED_TO_CL");
|
||||
kernelW = runtime->buildKernelWithCache("glmem_convert", "gl_to_cl", buildOptions, precision, nullptr, output);
|
||||
outputShape = tensorShapeFormat(output);
|
||||
} else if(toHost){
|
||||
buildOptions.emplace("-DCL_TO_SHARED");
|
||||
kernelW = runtime->buildKernelWithCache("glmem_convert", "cl_to_gl", buildOptions, precision, input, nullptr);
|
||||
outputShape = tensorShapeFormat(input);
|
||||
}else{
|
||||
MNN_PRINT("convertGLMemBetweenCLmem only support toDevice or toHost!\n");
|
||||
return false;
|
||||
}
|
||||
std::vector<int> outputShape = toDevice ? tensorShapeFormat(output): tensorShapeFormat(input);
|
||||
|
||||
int shape[4] = {outputShape[0], outputShape[3], outputShape[1], outputShape[2]};//N C H W
|
||||
uint32_t gws[3] = {static_cast<uint32_t>(UP_DIV(shape[3], 4)),
|
||||
static_cast<uint32_t>(UP_DIV(shape[1], 4)),
|
||||
static_cast<uint32_t>(shape[0] * shape[2])};
|
||||
std::shared_ptr<KernelWrap> kernelW;
|
||||
int format = AHARDWAREBUFFER_FORMAT_R8G8B8A8_UNORM;
|
||||
int stride = shape[3];
|
||||
AHardwareBuffer_Desc Desc = {};
|
||||
if(OpenCLSymbolsOperator::getOpenclSymbolsPtr()->isSupportAhardwareBufferFunc()){
|
||||
if(toDevice){
|
||||
MNNAHardwareBuffer_describe((AHardwareBuffer*)(((CLSharedMemReleaseBuffer*)TensorUtils::getSharedMem(input))->getSharedId()), &Desc);
|
||||
}else{
|
||||
MNNAHardwareBuffer_describe((AHardwareBuffer*)(((CLSharedMemReleaseBuffer*)TensorUtils::getSharedMem(output))->getSharedId()), &Desc);
|
||||
}
|
||||
format = Desc.format;
|
||||
stride = Desc.stride;
|
||||
}
|
||||
if(format == AHARDWAREBUFFER_FORMAT_R8G8B8A8_UNORM){
|
||||
if(toDevice){
|
||||
buildOptions.emplace("-DSHARED_TO_CL");
|
||||
kernelW = runtime->buildKernelWithCache("glmem_convert", "gl_to_cl", buildOptions, precision, nullptr, output);
|
||||
} else if(toHost){
|
||||
buildOptions.emplace("-DCL_TO_SHARED");
|
||||
kernelW = runtime->buildKernelWithCache("glmem_convert", "cl_to_gl", buildOptions, precision, input, nullptr);
|
||||
}
|
||||
}else if(format == AHARDWAREBUFFER_FORMAT_Y8Cb8Cr8_420){
|
||||
if(toDevice){
|
||||
buildOptions.emplace("-DSHARED_TO_CL");
|
||||
kernelW = runtime->buildKernelWithCache("glmem_convert", "yuv_to_cl", buildOptions, precision, nullptr, output);
|
||||
} else if(toHost){
|
||||
buildOptions.emplace("-DCL_TO_SHARED");
|
||||
kernelW = runtime->buildKernelWithCache("glmem_convert", "cl_to_yuv", buildOptions, precision, input, nullptr);
|
||||
}
|
||||
}else{
|
||||
MNN_PRINT("convertGLMemBetweenCLmem only support AHARDWAREBUFFER_FORMAT_R8G8B8A8_UNORM or AHARDWAREBUFFER_FORMAT_Y8Cb8Cr8_420!\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
auto Kernel = kernelW->get();
|
||||
uint32_t idx = 0;
|
||||
cl_int ret = CL_SUCCESS;
|
||||
|
@ -629,6 +651,7 @@ bool convertBetweenAHDandCLmem(const Tensor *input, const Tensor *output, OpenCL
|
|||
}
|
||||
}
|
||||
ret |= Kernel.setArg(idx++, sizeof(shape), shape);
|
||||
ret |= Kernel.setArg(idx++, stride);
|
||||
MNN_CHECK_CL_SUCCESS(ret, "setArg glmem_convert");
|
||||
|
||||
const uint32_t maxWorkGroupSize = static_cast<uint32_t>(runtime->getMaxWorkGroupSize(kernelW));
|
||||
|
@ -647,6 +670,7 @@ bool convertBetweenAHDandCLmem(const Tensor *input, const Tensor *output, OpenCL
|
|||
MNN_CHECK_CL_SUCCESS(res, "glmem_convert");
|
||||
return true;
|
||||
}
|
||||
#endif
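// Condensed sketch of the kernel-selection rule implemented above: the AHardwareBuffer format
// together with the copy direction decides which glmem_convert kernel and build option are
// used. pickGlmemKernel is a hypothetical helper for illustration, not MNN/OpenCL API.
#ifdef __ANDROID__
#include <android/hardware_buffer.h>
#include <cstdint>
#include <string>
#include <utility>

std::pair<std::string, std::string> pickGlmemKernel(uint32_t format, bool toDevice) {
    std::string option = toDevice ? "-DSHARED_TO_CL" : "-DCL_TO_SHARED";
    if (format == AHARDWAREBUFFER_FORMAT_R8G8B8A8_UNORM) {
        return {toDevice ? "gl_to_cl" : "cl_to_gl", option};
    }
    if (format == AHARDWAREBUFFER_FORMAT_Y8Cb8Cr8_420) {
        return {toDevice ? "yuv_to_cl" : "cl_to_yuv", option};
    }
    return {"", ""};  // unsupported format, mirroring the error branch above
}
#endif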
|
||||
|
||||
} // namespace OpenCL
|
||||
} // namespace MNN
|
||||
|
|
|
@ -34,7 +34,9 @@ bool convertNC4HW4BufferBetweenNC16HW16Buffer(const Tensor *input, Tensor *outpu
|
|||
#endif
|
||||
|
||||
bool convertBufferToBuffer(Tensor *input, Tensor *output, OpenCLRuntime *runtime, int input_precision, int output_precision, int backend_precison, bool toDevice, bool toHost, bool needWait = false, bool svmFlag = false);
|
||||
#ifdef __ANDROID__
|
||||
bool convertBetweenAHDandCLmem(const Tensor *input, const Tensor *output, OpenCLRuntime *runtime, int precision, int memType, bool toDevice, bool toHost);
|
||||
#endif
|
||||
|
||||
class BufferConvertor {
|
||||
public:
|
||||
|
|