2024-10-21 14:32:47 +08:00
|
|
|
//
|
|
|
|
// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
|
|
|
|
//
|
|
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
//
|
|
|
|
|
2024-08-16 09:05:35 +08:00
|
|
|
#pragma once
|
|
|
|
|
|
|
|
#include <MNN/ErrorCode.hpp>
|
2025-01-07 10:18:20 +08:00
|
|
|
#include "core/Backend.hpp"
|
|
|
|
#include "core/Execution.hpp"
|
|
|
|
#include "core/TensorUtils.hpp"
|
|
|
|
#include "core/ConvolutionCommon.hpp"
|
|
|
|
#include "backend/cpu/CPUBackend.hpp"
|
|
|
|
#include "backend/cpu/CPURuntime.hpp"
|
|
|
|
#include "backend/cpu/compute/CommonOptFunction.h"
|
|
|
|
|
|
|
|
#include "mnn_kleidiai_util.h"
|
2024-08-16 09:05:35 +08:00
|
|
|
|
2025-04-24 08:32:00 +08:00
|
|
|
// Largest finite IEEE 754 binary16 (half-precision) value: (2 - 2^-10) * 2^15.
#define FLT16_MAX 65504.0f
// NOTE: despite the name, this is the most negative finite half-precision value
// (i.e. -FLT16_MAX), NOT the smallest positive normal value that C's FLT_MIN
// naming convention would suggest. Presumably used as the lower clamp bound for
// fp16 outputs — confirm at call sites.
// NOTE(review): C23 / ISO TS 18661-3 <float.h> may also define FLT16_MAX/FLT16_MIN
// (with FLT16_MIN meaning the smallest positive normal); avoid relying on these
// names in translation units that enable those extensions.
#define FLT16_MIN -65504.0f
|
|
|
|
|
2024-08-16 09:05:35 +08:00
|
|
|
namespace MNN {
|
|
|
|
class KleidiAI {
|
|
|
|
public:
|
2025-01-07 10:18:20 +08:00
|
|
|
// ===================================================================
|
|
|
|
// Enum definition
|
|
|
|
|
|
|
|
enum class AccelType {
|
|
|
|
/*
|
|
|
|
ASYM/SYM: Asymmetric/symmetric;
|
|
|
|
CHNLQT/BLKQT: channel wise/block wise;
|
|
|
|
*/
|
|
|
|
QINT = 0,
|
2025-07-08 16:22:23 +08:00
|
|
|
QI4_ASYM_CHNLQT_F32 = QINT,
|
|
|
|
QI4_ASYM_CHNLQT_F16,
|
|
|
|
QI4_ASYM_BLKQT_F32,
|
|
|
|
QI4_ASYM_BLKQT_F16,
|
|
|
|
QI4_SYM_CHNLQT_F32,
|
2025-01-07 10:18:20 +08:00
|
|
|
QI4_SYM_BLKQT,
|
|
|
|
QI8_ASYM_CHNLQT,
|
|
|
|
QI8_ASYM_BLKQT,
|
|
|
|
QI8_SYM_CHNLQT,
|
|
|
|
QI8_SYM_BLKQT,
|
|
|
|
QINT_END = QI8_SYM_BLKQT,
|
|
|
|
|
|
|
|
FLOAT,
|
|
|
|
FP16 = FLOAT,
|
|
|
|
FP32,
|
|
|
|
BF16,
|
|
|
|
FLOAT_END = BF16,
|
|
|
|
|
|
|
|
ACC_TYPE_NUMBER,
|
|
|
|
ACC_TYPE_ERROR = ACC_TYPE_NUMBER
|
|
|
|
};
|
|
|
|
|
|
|
|
// ===================================================================
|
|
|
|
// Some necessary data structures
|
|
|
|
typedef struct KernelParam {
|
|
|
|
size_t mKaiMstepGemv = 0;
|
|
|
|
size_t mKaiMstepGemm = 0;
|
|
|
|
size_t mKaiNStep = 0;
|
|
|
|
|
|
|
|
size_t mKaiMrGemv = 0;
|
|
|
|
size_t mKaiMrGemm = 0;
|
|
|
|
size_t mKaiNr = 0;
|
|
|
|
size_t mKaiKr = 0;
|
|
|
|
size_t mKaiSr = 0;
|
|
|
|
} KernelParam;
|
|
|
|
|
|
|
|
typedef struct KernelInfo {
|
|
|
|
bool mKernelSupport = false;
|
|
|
|
KernelParam mKernelParam;
|
|
|
|
} KernelInfo;
|
|
|
|
|
|
|
|
typedef struct StaticInfo {
|
|
|
|
bool mDot = false;
|
|
|
|
bool mI8mm = false;
|
|
|
|
bool mSme2 = false;
|
|
|
|
|
|
|
|
KernelInfo mKernelInfo[(size_t)AccelType::ACC_TYPE_NUMBER];
|
|
|
|
} StaticInfo;
|
|
|
|
|
|
|
|
|
|
|
|
typedef struct QIntInfo {
|
|
|
|
size_t mBits;
|
|
|
|
bool mAsymmetric; //Asymmetric quantized model.
|
|
|
|
size_t mBlockSize; //0: Per channel quant; others: Per block quant.
|
2025-07-08 16:22:23 +08:00
|
|
|
size_t mBytes; //4: float32; 2: float16.
|
2025-01-07 10:18:20 +08:00
|
|
|
|
2025-07-08 16:22:23 +08:00
|
|
|
QIntInfo(size_t bits = 4, bool asymmetric = false, size_t blockSize = 0, size_t bytes = 0) {
|
2025-01-07 10:18:20 +08:00
|
|
|
mBits = bits;
|
|
|
|
mAsymmetric = asymmetric;
|
|
|
|
mBlockSize = blockSize;
|
2025-07-08 16:22:23 +08:00
|
|
|
mBytes = bytes;
|
2024-08-16 09:05:35 +08:00
|
|
|
}
|
|
|
|
|
2025-01-07 10:18:20 +08:00
|
|
|
bool operator<(const QIntInfo& rhs) const {
|
|
|
|
if(mBits != rhs.mBits) {
|
|
|
|
return mBits < rhs.mBits;
|
|
|
|
}
|
|
|
|
|
|
|
|
if(mAsymmetric != rhs.mAsymmetric) {
|
|
|
|
return mAsymmetric < rhs.mAsymmetric;
|
|
|
|
}
|
|
|
|
|
2025-07-08 16:22:23 +08:00
|
|
|
if(mBytes != rhs.mBytes) {
|
|
|
|
return mBytes < rhs.mBytes;
|
|
|
|
}
|
|
|
|
|
2025-01-07 10:18:20 +08:00
|
|
|
bool lhsPerChannel = mBlockSize == 0 ? true : false;
|
|
|
|
bool rhsPerChannel = rhs.mBlockSize == 0 ? true : false;
|
|
|
|
return lhsPerChannel < rhsPerChannel;
|
2024-08-16 09:05:35 +08:00
|
|
|
}
|
2025-01-07 10:18:20 +08:00
|
|
|
} QIntInfo;
|
|
|
|
|
|
|
|
// ===================================================================
|
|
|
|
|
|
|
|
//Public static members.
|
|
|
|
static bool mKaiInitialized;
|
|
|
|
|
|
|
|
//Get instance.
|
2025-07-08 16:22:23 +08:00
|
|
|
static KleidiAI &getInstance(const MNNCPUInfo& gCPUInfo);
|
2025-01-07 10:18:20 +08:00
|
|
|
static KleidiAI &getInstance();
|
|
|
|
static void initKernelInfo();
|
2024-08-16 09:05:35 +08:00
|
|
|
|
|
|
|
~KleidiAI() {}
|
|
|
|
|
2025-01-07 10:18:20 +08:00
|
|
|
void printInfo(AccelType type);
|
|
|
|
|
|
|
|
//Check and set
|
|
|
|
bool canAccelerate();
|
|
|
|
bool canAccelerate(AccelType type);
|
2025-07-08 16:22:23 +08:00
|
|
|
bool canAccelerate(AccelType type, const Convolution2DCommon *common);
|
2025-01-07 10:18:20 +08:00
|
|
|
bool isLoaded(AccelType type);
|
|
|
|
void setLoaded(AccelType type) { mLoaded[(size_t)type] = true; }
|
2024-08-16 09:05:35 +08:00
|
|
|
|
|
|
|
//Get info
|
2025-07-08 16:22:23 +08:00
|
|
|
static AccelType getQIntAccelType(size_t bits, bool bAsymmetric, size_t blockSize, size_t bytes);
|
2025-01-07 10:18:20 +08:00
|
|
|
size_t getMr(AccelType type, size_t m = 1);
|
|
|
|
size_t getNr(AccelType type);
|
|
|
|
size_t getKr(AccelType type);
|
|
|
|
size_t getSr(AccelType type);
|
|
|
|
size_t getMStep(AccelType type, size_t m = 1);
|
|
|
|
size_t getNStep(AccelType type);
|
|
|
|
size_t getVecNumPerThread(size_t totalVec, size_t totalThread, size_t minStep);
|
|
|
|
//Get Static info
|
2025-02-11 14:23:54 +08:00
|
|
|
bool bSupportSme2() { return mStaticInfo.mSme2; }
|
2024-08-16 09:05:35 +08:00
|
|
|
|
|
|
|
//Lhs
|
2025-04-24 08:32:00 +08:00
|
|
|
size_t getLhsPackedSize(AccelType type, size_t m, size_t k);
|
2025-01-07 10:18:20 +08:00
|
|
|
size_t getLhsQuantedPackedSize(AccelType type, size_t m, size_t k, size_t bl);
|
|
|
|
size_t getLhsQuantedPackedOffset(AccelType type, size_t m, size_t mIdx, size_t k, size_t bl);
|
|
|
|
void runLhsPack(AccelType type, size_t m, size_t k, size_t mIdx, const void* lhs, size_t lhsStride, void* lhsPacked);
|
|
|
|
void runLhsQuantPack(AccelType type, size_t m, size_t k, size_t bl, size_t mr, const void* lhs, void* lhsQuantedPacked);
|
2024-08-16 09:05:35 +08:00
|
|
|
|
|
|
|
//Rhs
|
2025-01-07 10:18:20 +08:00
|
|
|
size_t getRhsPackedSize(AccelType type, size_t n, size_t k, size_t bl);
|
|
|
|
size_t getRhsPackedOffset(AccelType type, size_t nIdx, size_t k, size_t bl);
|
|
|
|
void runRhsPack(AccelType type, size_t numGroups, size_t n, size_t k, size_t bl, size_t rhsStride,
|
|
|
|
const void* rhs, const void* scale, const void* zeroPoint, const void* bias,
|
2025-05-19 16:33:41 +08:00
|
|
|
void* rhsPacked);
|
2025-01-09 09:58:26 +08:00
|
|
|
|
2024-08-16 09:05:35 +08:00
|
|
|
//Dst
|
2025-01-07 10:18:20 +08:00
|
|
|
size_t getDstOffset(size_t mIdx, size_t nIdx, size_t n, size_t elementSize) { return (nIdx * elementSize) + mIdx * (n * elementSize); }
|
2024-08-16 09:05:35 +08:00
|
|
|
|
|
|
|
//Matmul
|
2025-01-07 10:18:20 +08:00
|
|
|
void runMatmul(AccelType type, size_t m, size_t n, size_t k, size_t bl,
|
|
|
|
const void* lhsPacked, const void* rhsPacked, void* dst,
|
2025-01-09 09:58:26 +08:00
|
|
|
size_t dstStrideRow, size_t dstStrideCol,
|
|
|
|
const float scalarMax, const float scalarMin);
|
2024-08-16 09:05:35 +08:00
|
|
|
|
|
|
|
private:
|
2025-01-07 10:18:20 +08:00
|
|
|
KleidiAI() {}
|
|
|
|
|
|
|
|
static KleidiAI *mKaiInstance;
|
|
|
|
//Static info, never change after construct.
|
|
|
|
static StaticInfo mStaticInfo;
|
|
|
|
//Status, will change while pipeline is running.
|
|
|
|
bool mLoaded[(size_t)AccelType::ACC_TYPE_NUMBER] = { false };
|
|
|
|
bool mLinear = false; //All pipeline format has been set as NCHW.
|
|
|
|
};
|
|
|
|
|
|
|
|
// ===================================================================
|
|
|
|
// Inline functions
|
|
|
|
inline bool KleidiAI::canAccelerate() {
|
|
|
|
for(size_t type = 0; type < (size_t)AccelType::ACC_TYPE_NUMBER; type++) {
|
|
|
|
if(mStaticInfo.mKernelInfo[(size_t)type].mKernelSupport && isLoaded(static_cast<AccelType>(type))) {
|
|
|
|
return true;
|
2024-10-21 14:32:47 +08:00
|
|
|
}
|
2024-08-16 09:05:35 +08:00
|
|
|
}
|
2025-01-07 10:18:20 +08:00
|
|
|
return false;
|
|
|
|
}
|
2024-08-16 09:05:35 +08:00
|
|
|
|
2025-01-07 10:18:20 +08:00
|
|
|
// Returns whether the kernel family `type` is supported on this CPU.
// Out-of-range values (ACC_TYPE_ERROR and beyond) report false without
// touching the kernel table.
inline bool KleidiAI::canAccelerate(AccelType type) {
    const bool inRange = type < AccelType::ACC_TYPE_ERROR;
    return inRange && mStaticInfo.mKernelInfo[(size_t)type].mKernelSupport;
}
|
|
|
|
|
2025-07-08 16:22:23 +08:00
|
|
|
// Returns whether convolution `common` can run on kernel family `type`.
// Requirements checked here:
//   - `type` is a valid AccelType and `common` is non-null;
//   - group() == 1;
//   - QI4_ASYM_CHNLQT_F32 / QI4_ASYM_CHNLQT_F16 / QI8_ASYM_CHNLQT: inputCount() % 32 == 0;
//   - QI4_SYM_CHNLQT_F32: inputCount() is even;
//   - 1x1 kernel, zero padding, unit stride and unit dilation;
//   - the kernel family is supported on this CPU.
// A null `common` is rejected (false) instead of being dereferenced.
inline bool KleidiAI::canAccelerate(AccelType type, const Convolution2DCommon* common) {
    if (common == nullptr || type >= AccelType::ACC_TYPE_ERROR) {
        return false;
    }
    if (common->group() != 1) {
        return false;
    }
    if (type == AccelType::QI4_ASYM_CHNLQT_F32 || type == AccelType::QI4_ASYM_CHNLQT_F16 || type == AccelType::QI8_ASYM_CHNLQT) {
        if (common->inputCount() % 32 != 0) {
            return false;
        }
    }
    if (type == AccelType::QI4_SYM_CHNLQT_F32) {
        if (common->inputCount() % 2 != 0) {
            return false;
        }
    }
    // Only pointwise (1x1, no padding, stride 1, dilation 1) convolutions are accepted.
    const bool isPointwise = common->kernelX() == 1 && common->kernelY() == 1
                          && common->padX() == 0 && common->padY() == 0
                          && common->strideX() == 1 && common->strideY() == 1
                          && common->dilateX() == 1 && common->dilateY() == 1;
    if (!isPointwise) {
        return false;
    }
    return mStaticInfo.mKernelInfo[(size_t)type].mKernelSupport;
}
|
|
|
|
|
2025-01-07 10:18:20 +08:00
|
|
|
// Returns whether setLoaded() has been called for `type`.
// `type` must be below ACC_TYPE_NUMBER (checked by MNN_ASSERT).
inline bool KleidiAI::isLoaded(AccelType type) {
    MNN_ASSERT(type < AccelType::ACC_TYPE_NUMBER);
    const size_t slot = static_cast<size_t>(type);
    const bool loaded = mLoaded[slot];
    return loaded;
}
|
|
|
|
|
|
|
|
// Returns the mr parameter recorded for `type`: the GEMV value when m == 1,
// otherwise the GEMM value. `type` must be below ACC_TYPE_NUMBER; this is
// asserted, matching isLoaded(), so ACC_TYPE_ERROR cannot silently read past
// the kernel table.
inline size_t KleidiAI::getMr(AccelType type, size_t m) {
    MNN_ASSERT(type < AccelType::ACC_TYPE_NUMBER);
    const KernelParam& param = mStaticInfo.mKernelInfo[(size_t)type].mKernelParam;
    return (m == 1) ? param.mKaiMrGemv : param.mKaiMrGemm;
}
|
|
|
|
|
|
|
|
// Returns the nr parameter recorded for `type`.
// `type` must be below ACC_TYPE_NUMBER; this is asserted, matching isLoaded().
inline size_t KleidiAI::getNr(AccelType type) {
    MNN_ASSERT(type < AccelType::ACC_TYPE_NUMBER);
    const KernelParam& param = mStaticInfo.mKernelInfo[(size_t)type].mKernelParam;
    return param.mKaiNr;
}
|
|
|
|
|
|
|
|
// Returns the kr parameter recorded for `type`.
// `type` must be below ACC_TYPE_NUMBER; this is asserted, matching isLoaded().
inline size_t KleidiAI::getKr(AccelType type) {
    MNN_ASSERT(type < AccelType::ACC_TYPE_NUMBER);
    const KernelParam& param = mStaticInfo.mKernelInfo[(size_t)type].mKernelParam;
    return param.mKaiKr;
}
|
|
|
|
|
|
|
|
// Returns the sr parameter recorded for `type`.
// `type` must be below ACC_TYPE_NUMBER; this is asserted, matching isLoaded().
inline size_t KleidiAI::getSr(AccelType type) {
    MNN_ASSERT(type < AccelType::ACC_TYPE_NUMBER);
    const KernelParam& param = mStaticInfo.mKernelInfo[(size_t)type].mKernelParam;
    return param.mKaiSr;
}
|
|
|
|
|
|
|
|
// Returns the M-direction step recorded for `type`: the GEMV value when m == 1,
// otherwise the GEMM value. `type` must be below ACC_TYPE_NUMBER; this is
// asserted, matching isLoaded().
inline size_t KleidiAI::getMStep(AccelType type, size_t m) {
    MNN_ASSERT(type < AccelType::ACC_TYPE_NUMBER);
    const KernelParam& param = mStaticInfo.mKernelInfo[(size_t)type].mKernelParam;
    return (m == 1) ? param.mKaiMstepGemv : param.mKaiMstepGemm;
}
|
|
|
|
|
|
|
|
// Returns the N-direction step recorded for `type`.
// `type` must be below ACC_TYPE_NUMBER; this is asserted, matching isLoaded().
inline size_t KleidiAI::getNStep(AccelType type) {
    MNN_ASSERT(type < AccelType::ACC_TYPE_NUMBER);
    const KernelParam& param = mStaticInfo.mKernelInfo[(size_t)type].mKernelParam;
    return param.mKaiNStep;
}
|
|
|
|
|
|
|
|
inline size_t KleidiAI::getVecNumPerThread(size_t totalVec, size_t totalThread, size_t minStep) {
|
|
|
|
return kai_roundup((totalVec + totalThread - 1) / totalThread, minStep);
|
|
|
|
}
|
2024-08-16 09:05:35 +08:00
|
|
|
}
|