mirror of https://github.com/alibaba/MNN.git
Compare commits
3 Commits: b95d87c004 ... 10fb71261b
| Author | SHA1 | Date |
|---|---|---|
| | 10fb71261b | |
| | 9a085992ea | |
| | 8e7a63d622 | |
```diff
@@ -97,12 +97,12 @@ static Execution* _createUnit(const Tensor* input, const Tensor* output, Backend
 #else
     if (cpuBackend->memoryMode() == BackendConfig::Memory_Low) {
 #ifdef MNN_KLEIDIAI_ENABLED
-        if (MNNGetCPUInfo()->sme2 && !weigthQauntInfo && cpuBackend->functions()->bytes == 4) {
+        if (MNNGetCPUInfo()->sme2 && !weightQuantInfo) {
            return new KleidiAIDenseConvolution(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
        }
 #else
        return new DenseConvolutionTiledExecutor(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
 #endif
        return new DenseConvolutionTiledExecutor(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
    }
 #endif
```
```diff
@@ -122,7 +122,7 @@ static Execution* _createUnit(const Tensor* input, const Tensor* output, Backend
 #endif

 #ifdef MNN_KLEIDIAI_ENABLED
-    if (MNNGetCPUInfo()->sme2 && !weightQuantInfo && cpuBackend->functions()->bytes == 4) {
+    if (MNNGetCPUInfo()->sme2 && !weightQuantInfo) {
        return new KleidiAIDenseConvolution(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
    }
 #endif
```
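The two hunks above relax the same dispatch gate: `KleidiAIDenseConvolution` was previously reserved for fp32 (`cpuBackend->functions()->bytes == 4`), and with the fp16 kernels added below, SME2 support plus non-quantized weights is now sufficient. A minimal sketch of the resulting selection logic, with `CpuInfo` and the parameters standing in for MNN's real types:

```cpp
// Sketch only: stand-ins for MNNGetCPUInfo() and _createUnit's parameters.
struct CpuInfo {
    bool sme2; // CPU supports the Arm SME2 extension
};

// After this change the element size (4 = fp32, 2 = fp16) no longer matters
// for selection; both sizes are handled inside KleidiAIDenseConvolution.
bool useKleidiAIDenseConvolution(const CpuInfo& cpu, const void* weightQuantInfo) {
    return cpu.sme2 && weightQuantInfo == nullptr;
}
```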
```diff
@@ -9,8 +9,11 @@
 #include "backend/cpu/CPUTensorConvert.hpp"
 #include "core/Macro.h"
 #include "core/TensorUtils.hpp"
+#include "kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h"
 #include "kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h"
+#include "kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h"
 #include "kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h"
+#include "kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme.h"
 #include "kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h"

 namespace MNN {
```
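Reading the include list, each fp32 micro-kernel used by this file gains an fp16 counterpart (the `b` RHS variants also fold in the bias, as the call sites below pass `bias`):

| Role | fp32 (`bytes == 4`) | fp16 (`bytes == 2`) |
|---|---|---|
| LHS pack | `kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme` | `kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme` |
| RHS pack (kxn) | `kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme` | `kai_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme` |
| Indirect matmul + clamp | `kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa` | `kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa` |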
```diff
@@ -26,8 +29,11 @@ static void initWeight(const T* weight, const T* bias, T* cache, T* output, cons
     if (bytes == 4) {
         kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(outputCount, kh * kw, srcCount, outputCount * sizeof(T),
                                                             cache, bias, output);
+    } else if (bytes == 2) {
+        kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(outputCount, kh * kw, srcCount, outputCount * sizeof(T),
+                                                            cache, bias, output);
     } else {
-        MNN_ERROR("Not fp32, should not be called here\n");
+        MNN_ERROR("Not fp32 and fp16, should not be called here\n");
         abort();
     }
 }
```
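The remaining hunks all repeat the element-size dispatch visible here. A minimal standalone sketch of the pattern, where `packF32`/`packF16` are hypothetical stand-ins for the paired `kai_*` entry points:

```cpp
#include <cstdio>
#include <cstdlib>

// Hypothetical stand-ins for the paired kai_* entry points.
static void packF32() { /* fp32 (x32/f32) kernel */ }
static void packF16() { /* fp16 (x16/f16) kernel */ }

// Every KleidiAI call site in this patch branches the same way: 4 bytes per
// element selects the x32/f32 kernel, 2 bytes the x16/f16 kernel, and any
// other size is a configuration error that aborts.
static void dispatchOnBytes(int bytes) {
    if (bytes == 4) {
        packF32();
    } else if (bytes == 2) {
        packF16();
    } else {
        std::fprintf(stderr, "Not fp32 and fp16, should not be called here\n");
        std::abort();
    }
}
```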
```diff
@@ -49,8 +55,11 @@ KleidiAIDenseConvolution::KleidiAIDenseConvolution(const Convolution2DCommon* co
     if (core->bytes == 4) {
         kai_rhs_packed_size = kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(
             outputCount, common->kernelY() * common->kernelX(), srcCount);
+    } else if (core->bytes == 2) {
+        kai_rhs_packed_size = kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(
+            outputCount, common->kernelY() * common->kernelX(), srcCount);
     } else {
-        MNN_ERROR("Not fp32, should not be called here\n");
+        MNN_ERROR("Not fp32 and fp16, should not be called here\n");
         abort();
     }
     mResource->mWeight.reset(Tensor::createDevice<uint8_t>({kai_rhs_packed_size}));
```
```diff
@@ -76,8 +85,17 @@ KleidiAIDenseConvolution::KleidiAIDenseConvolution(const Convolution2DCommon* co
     if (core->bytes == 4) {
         MNN::initWeight(originWeight, bias, cache->host<float>(), mResource->mWeight->host<float>(), oihwShape,
                         core->bytes);
+    } else if (core->bytes == 2) {
+        for (int i = 0; i < outputCount; i++) {
+            mResource->mBias->host<__fp16>()[i] = (__fp16)(bias[i]);
+        }
+        ConvertOIHWToHWIO(cache->host<__fp16>(), originWeight,
+                          {outputCount, srcCount, common->kernelY(), common->kernelX()});
+        kai_run_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(
+            outputCount, common->kernelY() * common->kernelX(), srcCount, outputCount * sizeof(__fp16),
+            cache->host<__fp16>(), mResource->mBias->host<__fp16>(), mResource->mWeight->host<__fp16>());
     } else {
-        MNN_ERROR("Not fp32, should not be called here\n");
+        MNN_ERROR("Not fp32 and fp16, should not be called here\n");
         abort();
     }
```
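On the fp16 path, the float bias is narrowed to `__fp16` element by element and the weights are reordered before packing. `ConvertOIHWToHWIO` is MNN's own helper; a reference sketch of what it is assumed to do here, including the float-to-half narrowing implied by the mixed argument types at the call site:

```cpp
#include <vector>

// Illustration only: reorder weights from OIHW to HWIO so each row of the
// kxn RHS holds one (h, w, i) tap across all output channels, converting the
// element type on the way (the call site passes const float* in, __fp16* out).
template <typename TDst, typename TSrc>
void ConvertOIHWToHWIO(TDst* dst, const TSrc* src, const std::vector<int>& oihw) {
    const int O = oihw[0], I = oihw[1], H = oihw[2], W = oihw[3];
    for (int o = 0; o < O; ++o) {
        for (int i = 0; i < I; ++i) {
            for (int h = 0; h < H; ++h) {
                for (int w = 0; w < W; ++w) {
                    dst[((h * W + w) * I + i) * O + o] =
                        static_cast<TDst>(src[((o * I + i) * H + h) * W + w]);
                }
            }
        }
    }
}
```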
```diff
@@ -135,8 +153,11 @@ ErrorCode KleidiAIDenseConvolutionMultiInput::onExecute(const std::vector<Tensor
     if (function->bytes == 4) {
         initWeight(source, mInputs[2]->host<float>(), cache, mTempWeight->host<float>(), inputs[1]->shape(),
                    function->bytes);
+    } else if (function->bytes == 2) {
+        initWeight(reinterpret_cast<const __fp16*>(source), mInputs[2]->host<__fp16>(),
+                   reinterpret_cast<__fp16*>(cache), mTempWeight->host<__fp16>(), inputs[1]->shape(), function->bytes);
     } else {
-        MNN_ERROR("Not fp32, should not be called here\n");
+        MNN_ERROR("Not fp32 and fp16, should not be called here\n");
         abort();
     }
     return mProxy->onExecute(mInputs, outputs);
```
```diff
@@ -150,8 +171,12 @@ ErrorCode KleidiAIDenseConvolutionMultiInput::onResize(const std::vector<Tensor*
         int kai_rhs_packed_size = kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(
             outputCount, inputs[1]->stride(1), depth);
         mTempWeight.reset(Tensor::createDevice<uint8_t>({kai_rhs_packed_size}));
+    } else if (function->bytes == 2) {
+        int kai_rhs_packed_size = kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x16p2vlx2b_x16_x16_sme(
+            outputCount, inputs[1]->stride(1), depth);
+        mTempWeight.reset(Tensor::createDevice<uint8_t>({kai_rhs_packed_size}));
     } else {
-        MNN_ERROR("Not fp32, should not be called here\n");
+        MNN_ERROR("Not fp32 and fp16, should not be called here\n");
         abort();
     }
     mTempWeightCache.reset(Tensor::createDevice<float>(
```
```diff
@@ -206,8 +231,11 @@ ErrorCode KleidiAIDenseConvolutionImpl::onResize(const std::vector<Tensor*>& inp
     if (core->bytes == 4) {
         mTempBufferTranspose.buffer().dim[0].extent =
             kai_get_lhs_packed_size_lhs_imatmul_pack_x32p2vlx1_x32p_sme(outputNhwSize, kernelSize, ic);
+    } else if (core->bytes == 2) {
+        mTempBufferTranspose.buffer().dim[0].extent =
+            kai_get_lhs_packed_size_lhs_imatmul_pack_x16p2vlx2_x16p_sme(outputNhwSize, kernelSize, ic);
     } else {
-        MNN_ERROR("Not fp32, should not be called here\n");
+        MNN_ERROR("Not fp32 and fp16, should not be called here\n");
         abort();
     }
     TensorUtils::setLinearLayout(&mTempBufferTranspose);
```
```diff
@@ -289,8 +317,16 @@ ErrorCode KleidiAIDenseConvolutionImpl::onResize(const std::vector<Tensor*>& inp
         kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme(outputNhwSize, kernelSize, ic, table.data.data(), 0,
                                                     mPadBuffer.host<uint8_t>(),
                                                     mTempBufferTranspose.host<uint8_t>());
+    } else if (bytes == 2) {
+        int blockSize = kai_get_m_step_lhs_imatmul_pack_x16p2vlx2_x16p_sme();
+        ::memset(mPadBuffer.host<__fp16>(), 0, params.inputChannel * sizeof(__fp16));
+        auto table = IndirectionTable<__fp16>(mInputNHWC.shape(), params, mInputNHWC.host<__fp16>(),
+                                              mPadBuffer.host<__fp16>(), blockSize);
+        kai_run_lhs_imatmul_pack_x16p2vlx2_x16p_sme(outputNhwSize, kernelSize, ic, table.data.data(), 0,
+                                                    mPadBuffer.host<uint8_t>(),
+                                                    mTempBufferTranspose.host<uint8_t>());
     } else {
-        MNN_ERROR("Not fp32, should not be called here\n");
+        MNN_ERROR("Not fp32 and fp16, should not be called here\n");
         abort();
     }
```
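Mirroring the fp32 branch, the fp16 branch zeroes a per-channel pad row and builds an `IndirectionTable` blocked by the kernel's `m_step`. The table type is MNN's; a simplified 1-D sketch of the idea behind such a table:

```cpp
#include <cstddef>
#include <vector>

// Illustration only: a 1-D analogue of MNN's IndirectionTable<T>. For every
// (output position, kernel tap) pair it stores a pointer to the input row
// that tap reads; taps falling outside the image point at a shared
// zero-filled pad row, which is why the call site memsets mPadBuffer first.
// The real table also handles 2-D kernels, strides, dilation and the m_step
// blocking returned by kai_get_m_step_lhs_imatmul_pack_x16p2vlx2_x16p_sme().
template <typename T>
std::vector<const T*> buildIndirection1D(const T* input, int inputLen, int channels,
                                         int kernel, int pad, const T* zeroRow) {
    std::vector<const T*> table;
    const int outputLen = inputLen + 2 * pad - kernel + 1;
    table.reserve(static_cast<std::size_t>(outputLen) * kernel);
    for (int out = 0; out < outputLen; ++out) {
        for (int tap = 0; tap < kernel; ++tap) {
            const int in = out + tap - pad;
            table.push_back((in < 0 || in >= inputLen)
                                ? zeroRow
                                : input + static_cast<std::size_t>(in) * channels);
        }
    }
    return table;
}
```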
```diff
@@ -300,8 +336,14 @@ ErrorCode KleidiAIDenseConvolutionImpl::onResize(const std::vector<Tensor*>& inp
             outputNhwSize, outputChannel, kernelSize, ic, mTempBufferTranspose.host<uint8_t>(),
             weight->host<uint8_t>(), mOutputNHWC.host<uint8_t>(), outputChannel * sizeof(float), postParameters[2],
             postParameters[3]);
+    } else if (bytes == 2) {
+        float max = postParameters[3] > 65504.f ? 65504.f : postParameters[3];
+        kai_run_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa(
+            outputNhwSize, outputChannel, kernelSize, ic, mTempBufferTranspose.host<uint8_t>(),
+            weight->host<uint8_t>(), mOutputNHWC.host<uint8_t>(), outputChannel * sizeof(__fp16), postParameters[2],
+            max);
     } else {
-        MNN_ERROR("Not fp32, should not be called here\n");
+        MNN_ERROR("Not fp32 and fp16, should not be called here\n");
         abort();
     }
```
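The one fp16-specific twist here is the clamp bound: 65504 is the largest finite half-precision value, so a larger requested maximum (presumably an effectively-unbounded float when no ReLU6-style cap applies) is reduced before being handed to the f16 kernel. The guard is equivalent to:

```cpp
#include <algorithm>

// 65504 = (2 - 2^-10) * 2^15, the largest finite IEEE fp16 value. Anything
// above it would become +inf in half precision, so the activation's upper
// clamp is capped before the f16 kernel consumes it.
inline float clampMaxForF16(float requestedMax) {
    return std::min(requestedMax, 65504.f);
}
```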