Replace the KleidiAI macro with the hint config

Signed-off-by: yanzhang <yanzhang.wang@arm.com>
Change-Id: I9dcc4f68e1e67a11266b66589006650b197cde1e
yanzhang 2025-08-01 17:19:36 +08:00
parent 5952e33570
commit f699ca2842
27 changed files with 193 additions and 109 deletions
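In short: the MNN_KLEIDIAI_ENABLED compile-time guards are removed from the CPU backend sources, and KleidiAI execution is now opted into at runtime through a new Interpreter hint. A minimal sketch of the caller-side opt-in, based only on the API that appears in the diff below (the helper name and model path are placeholders):

#include <MNN/Interpreter.hpp>
#include <MNN/MNNForwardType.h>
#include <memory>

// Sketch: create an interpreter session with KleidiAI requested via the new hint.
std::shared_ptr<MNN::Interpreter> createSessionWithKleidiAI(const char* modelPath) {
    auto net = std::shared_ptr<MNN::Interpreter>(MNN::Interpreter::createFromFile(modelPath),
                                                 MNN::Interpreter::destroy);
    // New in this commit: KleidiAI is requested per session instead of being baked in at build time.
    net->setSessionHint(MNN::Interpreter::HintMode::CPU_ENABLE_KLEIDIAI, true);
    MNN::ScheduleConfig config;
    config.type = MNN_FORWARD_CPU;
    net->createSession(config);
    return net;
}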


@ -117,7 +117,7 @@ static inline uint64_t getTimeInUs() {
}
std::vector<float> doBench(Model& model, int loop, int warmup = 10, int forward = MNN_FORWARD_CPU, bool only_inference = true,
int numberThread = 4, int precision = 2, float sparsity = 0.0f, int sparseBlockOC = 1, bool testQuantModel=false) {
int numberThread = 4, int precision = 2, float sparsity = 0.0f, int sparseBlockOC = 1, bool testQuantModel=false, bool enableKleidiAI=false) {
auto revertor = std::unique_ptr<Revert>(new Revert(model.model_file.c_str()));
if (testQuantModel) {
revertor->initialize(0, sparseBlockOC, false, true);
@ -130,6 +130,7 @@ std::vector<float> doBench(Model& model, int loop, int warmup = 10, int forward
auto net = std::shared_ptr<MNN::Interpreter>(MNN::Interpreter::createFromBuffer(modelBuffer, bufferSize), MNN::Interpreter::destroy);
revertor.reset();
net->setSessionMode(MNN::Interpreter::Session_Release);
net->setSessionHint(MNN::Interpreter::HintMode::CPU_ENABLE_KLEIDIAI, enableKleidiAI);
MNN::ScheduleConfig config;
config.numThread = numberThread;
config.type = static_cast<MNNForwardType>(forward);
@ -392,8 +393,9 @@ int main(int argc, const char* argv[]) {
int precision = 2;
float sparsity = 0.0f;
int sparseBlockOC = 1;
bool enableKleidiAI = false;
if (argc <= 2) {
std::cout << "Usage: " << argv[0] << " models_folder [loop_count] [warmup] [forwardtype] [numberThread] [precision] [weightSparsity] [testQuantizedModel]" << std::endl;
std::cout << "Usage: " << argv[0] << " models_folder [loop_count] [warmup] [forwardtype] [numberThread] [precision] [weightSparsity] [testQuantizedModel] [enableKleidiAI]" << std::endl;
return 1;
}
if (argc >= 3) {
@ -420,8 +422,11 @@ int main(int argc, const char* argv[]) {
if(argc >= 10) {
testQuantizedModel = atoi(argv[9]);
}
if (argc >= 11) {
enableKleidiAI = atoi(argv[10]) > 0 ? true : false;
}
std::cout << "Forward type: " << forwardType(forward) << " thread=" << numberThread << " precision=" <<precision << " sparsity=" <<sparsity << " sparseBlockOC=" << sparseBlockOC << " testQuantizedModel=" << testQuantizedModel << std::endl;
std::cout << "Forward type: " << forwardType(forward) << " thread=" << numberThread << " precision=" <<precision << " sparsity=" <<sparsity << " sparseBlockOC=" << sparseBlockOC << " testQuantizedModel=" << testQuantizedModel << " enableKleidiAI=" << enableKleidiAI << std::endl;
std::vector<Model> models = findModelFiles(argv[1]);
std::cout << "--------> Benchmarking... loop = " << argv[2] << ", warmup = " << warmup << std::endl;
@ -441,10 +446,10 @@ int main(int argc, const char* argv[]) {
}
for (auto& m : models) {
std::vector<float> costs = doBench(m, loop, warmup, forward, false, numberThread, precision, sparsity, sparseBlockOC, false);
std::vector<float> costs = doBench(m, loop, warmup, forward, false, numberThread, precision, sparsity, sparseBlockOC, false, enableKleidiAI);
displayStats(m.name.c_str(), costs, false);
if (testQuantizedModel) {
costs = doBench(m, loop, warmup, forward, false, numberThread, precision, sparsity, sparseBlockOC, true);
costs = doBench(m, loop, warmup, forward, false, numberThread, precision, sparsity, sparseBlockOC, true, enableKleidiAI);
displayStats(m.name, costs, 1);
}
}


@ -252,7 +252,10 @@ public:
CPU_CORE_IDS = 14,
// set CPU threads to use when supports Arm sme2
CPU_SME2_INSTRUCTIONS = 15
CPU_SME2_INSTRUCTIONS = 15,
// Enable KleidiAI
CPU_ENABLE_KLEIDIAI = 16
};
enum ExternalPathType {


@ -21,10 +21,6 @@
#include "ThreadPool.hpp"
#endif
#ifdef MNN_KLEIDIAI_ENABLED
#include "arm/mnn_kleidiai.h"
#endif
namespace MNN {
class WorkerThread;
class CPURuntime : public Runtime {


@ -20,7 +20,6 @@ if (MNN_CPU_WEIGHT_DEQUANT_GEMM)
endif()
if (MNN_KLEIDIAI)
add_definitions(-DMNN_KLEIDIAI_ENABLED=1)
# Disable the KleidiAI tests
set(KLEIDIAI_BUILD_TESTS OFF)
# Fetch KleidiAI sources:


@ -23,14 +23,98 @@
#include "core/OpCommonUtils.hpp"
#include "backend/cpu/OneDNNConvolution.hpp"
#include "backend/cpu/compute/ConvInt8TiledExecutor.hpp"
#ifdef MNN_KLEIDIAI_ENABLED
// KleidiAI-related headers
#include "backend/cpu/compute/KleidiAIConvInt8.hpp"
#include "backend/cpu/compute/KleidiAIConvolution.hpp"
#include "backend/cpu/compute/KleidiAIDenseConvolution.hpp"
#endif //MNN_KLEIDIAI_ENABLED
namespace MNN {
static Execution* _createKleidiAIUnit(const Tensor* input, const Tensor* output, Backend* backend, const Op* op,
const float* originWeight, size_t originWeightSize, const float* bias,
size_t biasSize, std::shared_ptr<ConvolutionCommon::Int8Common> weightQuantInfo,
bool supportSparse, bool lowMemory) {
auto cpuBackend = (CPUBackend*)backend;
auto conv2d = op->main_as_Convolution2D();
auto common = conv2d->common();
bool fastWay = common->kernelY() == 1 && common->kernelX() == 1 && output->width() == input->width() &&
output->height() == input->height() && common->strideX() == 1 && common->strideY() == 1;
#ifdef MNN_LOW_MEMORY
if (lowMemory && nullptr != weightQuantInfo.get() && originWeightSize == 0) {
if (cpuBackend->memoryMode() == BackendConfig::Memory_Low) {
do {
if (!weightQuantInfo->canUseInt4) {
break;
}
auto convOp = op->main_as_Convolution2D();
auto core = static_cast<CPUBackend*>(backend)->functions();
int oc = convOp->common()->outputCount();
int ic = convOp->common()->inputCount();
int blockNum = 1;
int dequantCnt = weightQuantInfo->alphaSize;
if (weightQuantInfo->asymmetric) {
dequantCnt /= 2;
}
blockNum = dequantCnt / oc;
bool bAsym = weightQuantInfo->asymmetric;
size_t blkSize = blockNum == 1 ? 0 : ic / blockNum;
KleidiAI::AccelType accelType = KleidiAI::getQIntAccelType(4, bAsym, blkSize, core->bytes);
KleidiAI& kai = KleidiAI::getInstance(*MNNGetCPUInfo());
if (!kai.canAccelerate(accelType, convOp->common())) {
break;
}
if (!kai.isLoaded(accelType)) {
kai.setLoaded(accelType);
kai.printInfo(accelType);
}
return new KleidiAIConvInt8(backend, op, weightQuantInfo, true, kai, accelType, blockNum);
} while (0);
}
// Quantized weights are not supported on this path yet; fall back.
return nullptr;
}
#else
if (cpuBackend->memoryMode() == BackendConfig::Memory_Low) {
if (MNNGetCPUInfo()->sme2 && !weightQuantInfo) {
return new KleidiAIDenseConvolution(common, backend, originWeight, originWeightSize, bias, biasSize,
weightQuantInfo);
}
// Nothing to do; fall back.
return nullptr;
}
#endif
// This differs from the original implementation: it corresponds to the Strassen path,
// which is only built without MNN_REDUCE_SIZE. For KleidiAI that distinction
// does not matter here.
if (fastWay && cpuBackend->functions()->matmulBytes == 0) {
auto bytes = cpuBackend->functions()->bytes;
auto accelType = (bytes == 2) ? KleidiAI::AccelType::FP16 : KleidiAI::AccelType::FP32;
KleidiAI& kai = KleidiAI::getInstance(*MNNGetCPUInfo());
if (kai.canAccelerate(accelType)) {
return new KleidiAIConvolution(common, backend, originWeight, originWeightSize, bias, biasSize);
}
}
if (MNNGetCPUInfo()->sme2 && !weightQuantInfo) {
return new KleidiAIDenseConvolution(common, backend, originWeight, originWeightSize, bias, biasSize,
weightQuantInfo);
}
return nullptr;
}
static Execution* _createUnit(const Tensor* input, const Tensor* output, Backend* backend,
const Op* op, const float* originWeight, size_t originWeightSize, const float* bias, size_t biasSize, std::shared_ptr<ConvolutionCommon::Int8Common> weightQuantInfo, bool supportSparse, bool lowMemory) {
auto cpuBackend = (CPUBackend*)backend;
@ -48,47 +132,22 @@ static Execution* _createUnit(const Tensor* input, const Tensor* output, Backend
}
}
#endif
if (cpuBackend->getRuntime()->hint().enableKleidiAI) {
auto execution = _createKleidiAIUnit(input, output, backend, op, originWeight, originWeightSize, bias, biasSize,
weightQuantInfo, supportSparse, lowMemory);
if (execution) {
return execution;
}
}
bool fastWay = common->kernelY() == 1 && common->kernelX() == 1
&& output->width() == input->width() && output->height() == input->height()
&& common->strideX() == 1 && common->strideY() == 1;
#ifdef MNN_LOW_MEMORY
if (lowMemory && nullptr != weightQuantInfo.get() && originWeightSize == 0) {
if (cpuBackend->memoryMode() == BackendConfig::Memory_Low) {
#ifdef MNN_KLEIDIAI_ENABLED
do {
if (!weightQuantInfo->canUseInt4) {
break;
}
auto convOp = op->main_as_Convolution2D();
auto core = static_cast<CPUBackend*>(backend)->functions();
int oc = convOp->common()->outputCount();
int ic = convOp->common()->inputCount();
int blockNum = 1;
int dequantCnt = weightQuantInfo->alphaSize;
if (weightQuantInfo->asymmetric) {
dequantCnt /= 2;
}
blockNum = dequantCnt / oc;
bool bAsym = weightQuantInfo->asymmetric;
size_t blkSize = blockNum == 1 ? 0 : ic / blockNum;
KleidiAI::AccelType accelType = KleidiAI::getQIntAccelType(4, bAsym, blkSize, core->bytes);
KleidiAI& kai = KleidiAI::getInstance(*MNNGetCPUInfo());
if(!kai.isLoaded(accelType)) {
kai.setLoaded(accelType);
kai.printInfo(accelType);
}
if(!kai.canAccelerate(accelType, convOp->common())){
break;
}
return new KleidiAIConvInt8(backend, op, weightQuantInfo, true, kai, accelType, blockNum);
} while (0);
#endif
return new DenseConvInt8TiledExecutor(backend, op, weightQuantInfo, true);
} else {
return new DenseConvolutionTiledExecutor(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
@ -96,37 +155,16 @@ static Execution* _createUnit(const Tensor* input, const Tensor* output, Backend
}
#else
if (cpuBackend->memoryMode() == BackendConfig::Memory_Low) {
#ifdef MNN_KLEIDIAI_ENABLED
if (MNNGetCPUInfo()->sme2 && !weightQuantInfo) {
return new KleidiAIDenseConvolution(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
}
#endif
return new DenseConvolutionTiledExecutor(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
}
#endif
#ifndef MNN_REDUCE_SIZE
if (fastWay && cpuBackend->functions()->matmulBytes == 0) {
#ifdef MNN_KLEIDIAI_ENABLED
auto bytes = cpuBackend->functions()->bytes;
auto accelType = (bytes==2) ? KleidiAI::AccelType::FP16 : KleidiAI::AccelType::FP32;
KleidiAI& kai = KleidiAI::getInstance(*MNNGetCPUInfo());
if (kai.canAccelerate(accelType)){
return new KleidiAIConvolution(common, backend, originWeight, originWeightSize, bias, biasSize);
}
#endif //MNN_KLEIDIAI_ENABLED
return new Convolution1x1Strassen(common, backend, originWeight, originWeightSize, bias, biasSize);
}
#endif
#ifdef MNN_KLEIDIAI_ENABLED
if (MNNGetCPUInfo()->sme2 && !weightQuantInfo) {
return new KleidiAIDenseConvolution(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo);
}
#endif
if (cpuBackend->getRuntime()->hint().winogradMemoryUsed == 0 || (!ConvolutionWinogradBridge::canUseWinograd(common))) {
return new DenseConvolutionTiledExecutor(common, backend, originWeight, originWeightSize, bias, biasSize, nullptr);
}


@ -4,7 +4,6 @@
// SPDX-License-Identifier: Apache-2.0
//
#ifdef MNN_KLEIDIAI_ENABLED
#include "KleidiAIConvInt8.hpp"
#include "core/Macro.h"
#include "core/BufferAllocator.hpp"
@ -303,4 +302,3 @@ ErrorCode KleidiAIConvInt8::onExecute(const std::vector<Tensor*>& inputs, const
}
} // namespace MNN
#endif //MNN_KLEIDIAI_ENABLED


@ -6,10 +6,8 @@
#ifndef KleidiAIConvInt8_hpp
#define KleidiAIConvInt8_hpp
#ifdef MNN_KLEIDIAI_ENABLED
#include "backend/cpu/CPUConvolution.hpp"
#include "Int8FunctionsOpt.h"
#include "CommonOptFunction.h"
#include "backend/cpu/arm/mnn_kleidiai.h"
namespace MNN {
class KleidiAIConvInt8 : public CPUConvolution {
@ -31,5 +29,4 @@ private:
};
} // namespace MNN
#endif // MNN_KLEIDIAI_ENABLED
#endif /* KleidiAIConvInt8_hpp */
#endif /* KleidiAIConvInt8_hpp */


@ -4,18 +4,14 @@
// SPDX-License-Identifier: Apache-2.0
//
#ifdef MNN_KLEIDIAI_ENABLED
#include "KleidiAIConvolution.hpp"
#include <string.h>
#include "core/BufferAllocator.hpp"
#include "backend/cpu/CPUBackend.hpp"
#include "core/Concurrency.h"
#include "core/TensorUtils.hpp"
#include "backend/cpu/CPUTensorConvert.hpp"
namespace MNN {
#ifndef MNN_REDUCE_SIZE
KleidiAIConvolution::KleidiAIConvolution(const Convolution2DCommon *common, Backend *b, const float *originWeight,
size_t originWeightSize, const float *bias, size_t biasSize)
: CPUConvolution(common, b) {
@ -228,7 +224,4 @@ ErrorCode KleidiAIConvolution::onExecute(const std::vector<Tensor *> &inputs, co
return NO_ERROR;
}
#endif
} // namespace MNN
#endif //MNN_KLEIDIAI_ENABLED


@ -6,12 +6,9 @@
#ifndef KleidiAIConvolution_hpp
#define KleidiAIConvolution_hpp
#ifdef MNN_KLEIDIAI_ENABLED
#include <functional>
#include "backend/cpu/CPUConvolution.hpp"
#include "backend/cpu/arm/mnn_kleidiai.h"
namespace MNN {
#ifndef MNN_REDUCE_SIZE
class KleidiAIConvolution : public CPUConvolution{
public:
KleidiAIConvolution(const Convolution2DCommon *common, Backend *b, const float *originWeight, size_t originWeightSize, const float *bias, size_t biasSize);
@ -30,8 +27,5 @@ class KleidiAIConvolution : public CPUConvolution{
KleidiAI::AccelType mAccelType = KleidiAI::AccelType::ACC_TYPE_NUMBER;
std::vector<float> mPostParameters;
};
#endif //MNN_KLEIDIAI_ENABLED
} // namespace MNN
#endif
#endif /* KleidiAIConvolution_hpp */


@ -1,4 +1,3 @@
#if MNN_KLEIDIAI_ENABLED
#include "KleidiAIDenseConvolution.hpp"
#include <numeric>
@ -9,6 +8,7 @@
#include "backend/cpu/CPUTensorConvert.hpp"
#include "core/Macro.h"
#include "core/TensorUtils.hpp"
#include "core/Concurrency.h"
#include "kai_imatmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa.h"
#include "kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h"
#include "kai_lhs_imatmul_pack_x16p2vlx2_x16p_sme.h"
@ -304,10 +304,15 @@ ErrorCode KleidiAIDenseConvolutionImpl::onResize(const std::vector<Tensor*>& inp
.dilatedWidth = mCommon->dilateX(),
};
mFunction.second = [=](int tid) {
int threadNum = static_cast<CPUBackend*>(backend())->threadNumber();
mFunction.second = [=](int tId) {
// Convert NC4HW4 to NHWC
auto inputShape = input->shape(); // TODO check for NC4HW4, should be the NCHW
CPUTensorConverter::convert(input, &mInputNHWC, core);
// CPUTensorConverter::convert(input, &mInputNHWC, core);
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
CPUTensorConverter::convert(input, &mInputNHWC, core, tId, threadNum);
};
MNN_CONCURRENCY_END();
// Lhs packing
if (bytes == 4) {
int blockSize = kai_get_m_step_lhs_imatmul_pack_x32p2vlx1_x32p_sme();
@ -348,7 +353,11 @@ ErrorCode KleidiAIDenseConvolutionImpl::onResize(const std::vector<Tensor*>& inp
}
// Convert NHWC to NC4HW4
CPUTensorConverter::convert(&mOutputNHWC, output, core);
// CPUTensorConverter::convert(&mOutputNHWC, output, core);
MNN_CONCURRENCY_BEGIN(tId, threadNum) {
CPUTensorConverter::convert(&mOutputNHWC, output, core, tId, threadNum);
};
MNN_CONCURRENCY_END();
};
return NO_ERROR;
}
@ -359,4 +368,4 @@ ErrorCode KleidiAIDenseConvolutionImpl::onExecute(const std::vector<Tensor*>& in
return NO_ERROR;
}
} // namespace MNN
#endif


@ -1,5 +1,3 @@
#if MNN_KLEIDIAI_ENABLED
#ifndef KleidiAIDenseConvolution_hpp
#define KleidiAIDenseConvolution_hpp
@ -242,4 +240,3 @@ private:
} // namespace MNN
#endif /* KleidiAIDenseConvolution_hpp */
#endif


@ -62,6 +62,8 @@ struct RuntimeHint {
// whether to use Arm sme2 cores when threads>1
bool useArmSme2Cores = true;
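// whether to enable KleidiAI kernels (set via Interpreter hint CPU_ENABLE_KLEIDIAI)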
bool enableKleidiAI = false;
// Use CPU Ids
std::vector<int> cpuIds;
};
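For the Module/Express path, the same switch lives on RuntimeHint and is forwarded to the runtime, as the test runner later in this commit does. A small sketch of that pattern (the function name is an assumption and the headers are indicative; the calls mirror the diff's run.cpp):

#include <MNN/expr/Executor.hpp>
#include <MNN/expr/ExecutorScope.hpp>

// Sketch: turn on KleidiAI for the current Express executor.
static void enableKleidiAIOnCurrentExecutor() {
    MNN::RuntimeHint hint;
    hint.enableKleidiAI = true;  // the field added above
    MNN::Express::ExecutorScope::Current()->getRuntime().second->setRuntimeHint(hint);
}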


@ -109,6 +109,9 @@ void Session::ModeGroup::setHint(Interpreter::HintMode hint, int value) {
case Interpreter::HintMode::INIT_THREAD_NUMBER:
runtimeHint.initThreadNumber = value;
break;
case Interpreter::HintMode::CPU_ENABLE_KLEIDIAI:
runtimeHint.enableKleidiAI = value > 0 ? true : false;
break;
default:
break;
}


@ -19,10 +19,6 @@
#undef CONSTANT
#endif // CONSTANT
#ifdef MNN_KLEIDIAI_ENABLED
#include "../backend/cpu/arm/mnn_kleidiai.h"
#endif
namespace MNN {
struct TensorArrayAttr {
// array size is dynamic or not


@ -92,3 +92,7 @@ MNNForwardType getCurrentType() {
return attr->firstType;
}
std::shared_ptr<MNN::Express::Executor> cloneCurrentExecutor() {
auto attr = MNN::Express::ExecutorScope::Current()->getAttr();
return MNN::Express::Executor::newExecutor(getCurrentType(), attr->config, attr->numThread);
}
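The tests below then wrap each case in a freshly cloned executor so per-runtime state (including the new hint) does not leak between cases. The recurring pattern is roughly the following (the wrapper name is illustrative only):

#include <MNN/expr/ExecutorScope.hpp>
#include "TestUtils.h"

// Sketch of the pattern added to the individual tests in this commit.
bool runWithIsolatedExecutor(int precision) {
    auto executor = cloneCurrentExecutor();       // copy of the current executor's type, config and thread count
    MNN::Express::ExecutorScope scope(executor);  // expressions built below run on the clone
    // ... build VARPs / modules and verify results here ...
    return true;
}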


@ -104,6 +104,8 @@ inline float keepFP32Precision(float fp32Value) {
}
MNNForwardType getCurrentType();
std::shared_ptr<MNN::Express::Executor> cloneCurrentExecutor();
using ConvertFP32 = float(*)(float fp32Value);
const static std::vector<ConvertFP32> FP32Converter = {


@ -49,6 +49,12 @@ public:
for (int i = 0; i < channel * height * width; ++i){
inputData[i] = (rand() % 10) * 0.1;
}
MNN::BackendConfig config;
config.precision = (MNN::BackendConfig::PrecisionMode)MNN::BackendConfig::Precision_Normal;
config.memory = (MNN::BackendConfig::MemoryMode)MNN::BackendConfig::Memory_Normal;
std::shared_ptr<Executor> executor(Executor::newExecutor(getCurrentType(), config, 4));
ExecutorScope scope(executor);
auto net = _createModel();
auto x = _Input({1, channel, height, width}, NCHW, halide_type_of<float>());
@ -65,11 +71,7 @@ public:
// clone model
MNN::BackendConfig config;
config.precision = (MNN::BackendConfig::PrecisionMode)MNN::BackendConfig::Precision_Normal;
config.memory = (MNN::BackendConfig::MemoryMode)MNN::BackendConfig::Memory_Normal;
std::shared_ptr<Executor> executor(Executor::newExecutor(getCurrentType(), config, 4));
ExecutorScope scope(executor);
std::unique_ptr<Module> tempModule(Module::clone(net.get()));
auto xClone = _Input({1, channel, height, width}, NCHW, halide_type_of<float>());


@ -14,12 +14,16 @@
#include <MNN/expr/Module.hpp>
#include "MNNTestSuite.h"
#include "MNN_generated.h"
#include "TestUtils.h"
using namespace MNN;
using namespace MNN::Express;
class GatherExprTest : public MNNTestCase {
public:
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
std::unique_ptr<MNN::OpT> gatherOp(new MNN::OpT);
gatherOp->type = MNN::OpType_GatherND;
auto parameter = _Input({2, 2}, NHWC, halide_type_of<int32_t>());
@ -224,7 +228,8 @@ public:
class GatherNdReComputeTest : public MNNTestCase {
public:
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
const float inpudata[] = {-1.0, -2.0, 3.0, 4.0};
const int indices_data[] = {0, 0, 1, 1};
auto params = _Const(inpudata, {2, 2}, NHWC, halide_type_of<float>());


@ -22,6 +22,8 @@ public:
MNN_ERROR("Currently don't test not cpu mmap\n");
return true;
}
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
auto x = _Input({1, 3, 224, 224}, NC4HW4, halide_type_of<float>());
x->setName("x");
auto y = _Conv(1.0f, 0.01f, x, {3, 16}, {5, 5});


@ -11,6 +11,7 @@
#include <random>
#include "MNNTestSuite.h"
#include "MNN_generated.h"
#include "TestUtils.h"
#include <MNN/expr/Expr.hpp>
#include <MNN/expr/ExprCreator.hpp>
#include <MNN/expr/Module.hpp>
@ -66,6 +67,8 @@ static void _originMatMul(float* C, const float* A, const float* B, int e, int l
class MatMulTest : public MNNTestCase {
public:
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
int e = 5, h = 4, l = 6;
if (true) {
// Test MatMul


@ -2,6 +2,7 @@
#include <MNN/expr/ExprCreator.hpp>
#include <MNN/expr/Module.hpp>
#include "MNNTestSuite.h"
#include "TestUtils.h"
using namespace MNN;
using namespace MNN::Express;
@ -15,6 +16,8 @@ public:
return summer;
}
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
std::vector<VARP> empty;
// Make Net
auto x = _Input({1, 3, 2, 2}, NCHW, halide_type_of<float>());


@ -177,6 +177,8 @@ MNNTestSuiteRegister(ModuleTest, "expr/ModuleTest");
class ModuleWrongInputTest : public MNNTestCase {
public:
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
std::vector<int8_t> buffer;
// construct
{
@ -244,6 +246,8 @@ MNNTestSuiteRegister(ModuleWrongInputTest, "expr/ModuleWrongInputTest");
class RefTest : public MNNTestCase {
public:
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
std::vector<int8_t> buffer;
// construct
{
@ -318,6 +322,8 @@ public:
}
}
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
std::vector<int8_t> buffer;
#ifdef MNN_REDUCE_SIZE
return true;
@ -1039,6 +1045,8 @@ MNNTestSuiteRegister(ConstMemoryReplaceTest, "expr/ConstMemoryReplaceTest");
class MutlThreadConstReplaceTest : public MNNTestCase {
public:
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
auto func = [precision](VARP y, int thread) {
flatbuffers::FlatBufferBuilder builderOutput(1024);
{
@ -1499,6 +1507,8 @@ MNNTestSuiteRegister(ExecutorResetLoadModuleTest, "expr/ExecutorResetLoadModuleT
class SequenceForwardResizeTest : public MNNTestCase {
public:
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
// Make Model include convolution in shape compute and content compute
auto x = _Input({1, 3, 24, 24}, NCHW, halide_type_of<float>());
x->setName("x");
@ -1606,6 +1616,8 @@ MNNTestSuiteRegister(SequenceForwardResizeTest, "expr/SequenceForwardResizeTest"
class InputModuleTest : public MNNTestCase {
public:
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
auto y = _mobileNetV1Expr(nullptr, false);
std::unique_ptr<MNN::NetT> net(new NetT);
Variable::save({y}, net.get());


@ -35,6 +35,8 @@ static std::shared_ptr<Module> _createModel() {
class RasterOutputTest : public MNNTestCase {
public:
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
auto net = _createModel();
auto x = _Input({1, 3, 224, 224}, NCHW, halide_type_of<int>());
auto y = _Transpose(x, {0, 1, 3, 2});


@ -66,6 +66,11 @@ int main(int argc, char* argv[]) {
dynamicOption = atoi(argv[7]);
FUNC_PRINT(dynamicOption);
}
bool enableKleidiAI = false;
if (argc > 8) {
enableKleidiAI = atoi(argv[8]) > 0 ? true : false;
FUNC_PRINT(enableKleidiAI);
}
auto exe = MNN::Express::Executor::newExecutor(type, config, thread);
if (exe == nullptr) {
MNN_ERROR("Can't create executor with type:%d, exit!\n", type);
@ -76,6 +81,7 @@ int main(int argc, char* argv[]) {
// set hint
MNN::RuntimeHint hint;
hint.dynamicQuantOption = dynamicOption;
hint.enableKleidiAI = enableKleidiAI;
scope.Current()->getRuntime().second->setRuntimeHint(hint);
MNNTestSuite::get()->pStaus.memory = memory;
MNNTestSuite::get()->pStaus.precision = precision;


@ -202,6 +202,8 @@ class BinaryBroadcastTest : public MNNTestCase {
virtual ~BinaryBroadcastTest() = default;
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
bool resultNCHW = testDimensionFormat(NCHW, precision);
bool resultNHWC = testDimensionFormat(NHWC, precision);


@ -17,6 +17,8 @@ class ReverseTest : public MNNTestCase {
public:
virtual ~ReverseTest() = default;
virtual bool run(int precision) {
auto executor = cloneCurrentExecutor();
ExecutorScope scope(executor);
std::shared_ptr<MNN::Express::Module> net;
{
auto input = _Input({3, 2, 3}, NCHW);


@ -109,7 +109,7 @@ static inline std::vector<int> parseIntList(const std::string& str, char delim)
int main(int argc, char *argv[]) {
if (argc < 3) {
MNN_PRINT("=======================================================================================================================================\n");
MNN_ERROR("Usage: ./ModuleBasic.out ${test.mnn} ${Dir} [runMask] [forwardType] [runLoops] [numberThread] [precision | memory] [cacheFile] [cpuIds]\n");
MNN_ERROR("Usage: ./ModuleBasic.out ${test.mnn} ${Dir} [runMask] [forwardType] [runLoops] [numberThread] [precision | memory] [cacheFile] [cpuIds] [enableKleidiAI]\n");
MNN_PRINT("=======================================================================================================================================\n");
return 0;
}
@ -247,11 +247,16 @@ int main(int argc, char *argv[]) {
for (auto id : cpuIds) {
MNN_PRINT("%d ", id);
}
bool enableKleidiAI = false;
if (argc > 10) {
enableKleidiAI = atoi(argv[10]) > 0 ? true : false;
}
MNN_PRINT("\n");
FUNC_PRINT(precision);
FUNC_PRINT(memory);
FUNC_PRINT(power);
FUNC_PRINT_ALL(cacheFileName, s);
FUNC_PRINT(enableKleidiAI);
// create session
MNN::ScheduleConfig config;
config.type = type;
@ -320,6 +325,10 @@ int main(int argc, char *argv[]) {
rtmgr->setHint(Interpreter::DYNAMIC_QUANT_OPTIONS, 2);
}
if (enableKleidiAI) {
rtmgr->setHint(Interpreter::CPU_ENABLE_KLEIDIAI, true);
}
// rtmgr->setHint(Interpreter::CPU_SME2_INSTRUCTIONS, false);
if (runMask & 2048) {