mirror of https://github.com/alibaba/MNN.git
369 lines
12 KiB
C++
369 lines
12 KiB
C++
#ifdef MNN_USE_SSE
|
|
#include "../x86_x64/sse/FunctionSummary.hpp"
|
|
#include "../x86_x64/avx/FunctionSummary.hpp"
|
|
#include "../x86_x64/avxfma/FunctionSummary.hpp"
|
|
#include "../x86_x64/avx512/FunctionSummary.hpp"
|
|
#endif
|
|
#include "core/Macro.h"
|
|
#if defined(MNN_USE_NEON)
|
|
#include "../arm/FunctionSummary.hpp"
|
|
#endif
|
|
|
|
#include "BF16Functions.hpp"
|
|
#include "../compute/CommonOptFunction.h"
|
|
#include "../CPURuntime.hpp"
|
|
#include "VecHalf.hpp"
|
|
#include "math/Vec.hpp"
|
|
using BFVec4 = MNN::Math::VecHalf<4>;
|
|
using Vec4 = MNN::Math::Vec<float, 4>;
|
|
namespace MNN {
|
|
// The Function Will be Called in init
|
|
void registerBF16Backend() {
|
|
BF16Functions::init();
|
|
}
|
|
// just for reference BF16 converting of c++ code, not for arm or sse.
|
|
inline int16_t MNNFP32ToBF16(float fp32Value) {
|
|
int32_t* s32Value = (int32_t*)(&fp32Value);
|
|
return (int16_t)((*s32Value) >> 16);
|
|
}
|
|
inline float MNNLowpToFp32(int16_t s16Value) {
|
|
int32_t s32Value = ((int32_t)s16Value) << 16;
|
|
float* fp32Value = (float*)(&s32Value);
|
|
return *fp32Value;
|
|
}
|
|
|
|
static void _MNNFp32ToLowp(const float* src, int16_t* dst, size_t size) {
|
|
int sizeC4 = size / 4;
|
|
for (int i = 0; i < sizeC4; ++i) {
|
|
auto srcV = Vec4::load(src);
|
|
auto dstV = BFVec4(std::move(srcV.value));
|
|
BFVec4::save(dst, dstV);
|
|
src+=4;
|
|
dst+=4;
|
|
}
|
|
int sizeRemain = size % 4;
|
|
if (sizeRemain > 0) {
|
|
float srcTemp[4];
|
|
int64_t dstTemp[1];
|
|
::memcpy(srcTemp, src, sizeRemain * sizeof(float));
|
|
auto srcV = Vec4::load(srcTemp);
|
|
auto dstV = BFVec4(std::move(srcV.value));
|
|
BFVec4::save((int16_t*)dstTemp, dstV);
|
|
::memcpy(dst, dstTemp, sizeRemain * sizeof(int16_t));
|
|
}
|
|
}
|
|
static void _MNNLowpToFp32(const int16_t* src, float* dst, size_t size) {
|
|
int sizeC4 = size / 4;
|
|
for (int i = 0; i < sizeC4; ++i) {
|
|
auto srcV = BFVec4::load(src);
|
|
auto dstV = Vec4(std::move(srcV.value));
|
|
Vec4::save(dst, dstV);
|
|
src+=4;
|
|
dst+=4;
|
|
}
|
|
int sizeRemain = size % 4;
|
|
if (sizeRemain > 0) {
|
|
int64_t srcTemp[2];
|
|
float dstTemp[4];
|
|
::memcpy(srcTemp, src, sizeRemain * sizeof(int16_t));
|
|
auto srcV = BFVec4::load((int16_t*)srcTemp);
|
|
auto dstV = Vec4(std::move(srcV.value));
|
|
Vec4::save(dstTemp, dstV);
|
|
::memcpy(dst, dstTemp, sizeRemain * sizeof(float));
|
|
}
|
|
}
|
|
|
|
#if defined(MNN_USE_NEON)
|
|
// todo: search for proper value for bf16
|
|
void NEON_MNNGetMatMulPackMode_BF16(int* eP, int* lP, int* hP) {
|
|
*eP = 12;
|
|
*lP = 1;
|
|
#ifdef __aarch64__
|
|
*hP = 8;
|
|
#else
|
|
*hP = 4;
|
|
#endif
|
|
}
|
|
|
|
#ifdef __aarch64__
|
|
#define EP 12
|
|
#define HP 8
|
|
#define LP 4
|
|
void ARMV86_MNNGetMatMulPackMode_BF16(int* eP, int* lP, int* hP) {
|
|
*eP = EP;
|
|
*hP = HP;
|
|
*lP = LP;
|
|
}
|
|
void ARMV86_MNNPackForMatMul_B_BF16(float* destF, const float* sourceF, size_t h, size_t kernelsize, size_t ic, bool transpose) {
|
|
// [l, h] -> [h/hp, l/lp, hp, lp]
|
|
auto dest = (int16_t*)destF;
|
|
auto source = (const int32_t*)sourceF;
|
|
auto l = kernelsize * ic;
|
|
auto lCP = UP_DIV(l, LP);
|
|
auto hCP = UP_DIV(h, HP);
|
|
int sYstride = 1;
|
|
int sXstride = h;
|
|
if (transpose) {
|
|
sYstride = l;
|
|
sXstride = 1;
|
|
}
|
|
::memset(dest, 0, lCP * hCP * sizeof(int16_t) * HP * LP);
|
|
for (int y = 0; y < h; ++y) {
|
|
int yC = y / HP;
|
|
int yR = y % HP;
|
|
for (int x = 0; x < l; ++x) {
|
|
int xC = x / LP;
|
|
int xR = x % LP;
|
|
dest[xR + yR * LP + xC * HP * LP + yC * HP * LP * lCP] = source[sXstride * x + sYstride * y] >> 16;
|
|
}
|
|
}
|
|
}
|
|
void ARMV86_MNNPackC4ForMatMul_A_BF16(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el) {
|
|
// [l/4, e, 4] -> [l/4, ep, 4]
|
|
int number = info[0];
|
|
int eReal = info[1];
|
|
int eDest = info[2];
|
|
int offset = info[3];
|
|
if (1 == number) {
|
|
int l = el[1];
|
|
if (l % 8 != 0) {
|
|
auto lAigin = UP_DIV(l, LP) * LP;
|
|
::memset(destOrigin, 0, eDest * lAigin * sizeof(int16_t));
|
|
}
|
|
}
|
|
|
|
for (int n=0; n<number; ++n) {
|
|
int e = el[4 * n + 0];
|
|
int l = el[4 * n + 1];
|
|
int eOffset = el[4 * n + 2];
|
|
int lOffset = el[4 * n + 3];
|
|
auto lC4 = l / 4;
|
|
auto lDiv = UP_DIV(l, 4);
|
|
auto lOC = lOffset / LP;
|
|
auto lOR = lOffset % LP;
|
|
auto source = (const float*)(sourceGroup[n]);
|
|
auto sourceInt = (const int32_t*)(source);
|
|
auto dest = ((int16_t*)destOrigin) + eOffset * LP + lOC * eDest * LP;
|
|
if (lOR == 0) {
|
|
// [l/4, e, 4] -> [l/4, ep, 4]
|
|
for (int x = 0; x < lDiv; ++x) {
|
|
auto destX = dest + x * eDest * 4;
|
|
auto srcX = source + x * eReal * 4;
|
|
for (int y = 0; y < e; ++y) {
|
|
auto srcV = Vec4::load(srcX + y * offset * 4);
|
|
auto dstV = BFVec4(std::move(srcV.value));
|
|
BFVec4::save((int16_t*)(destX + 4*y), dstV);
|
|
}
|
|
}
|
|
continue;
|
|
}
|
|
for (int x = 0; x < l; ++x) {
|
|
auto dl = lOR + x;
|
|
auto dlC = dl / LP;
|
|
auto dlR = dl % LP;
|
|
auto xC = x / LP;
|
|
auto xR = x % LP;
|
|
auto destX = dest + dlC * eDest * LP + dlR;
|
|
auto srcX = sourceInt + xC * eReal * LP + xR;
|
|
for (int y = 0; y < e; ++y) {
|
|
destX[y * 4] = srcX[y * 4 * offset] >> 16;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
#undef EP
|
|
#undef HP
|
|
#undef LP
|
|
void NEON_MNNPackForMatMul_B_BF16(float* destFloat, const float* sourceFloat, size_t h, size_t kernelsize, size_t ic, bool transpose) {
|
|
auto hP = (int)h / 8;
|
|
auto hR = (int)hP * 8;
|
|
int16_t* dest = (int16_t*)destFloat;
|
|
auto l = kernelsize * ic;
|
|
const float* source = sourceFloat;
|
|
if (!transpose) {
|
|
for (int y = 0; y < hP; ++y) {
|
|
auto destY = dest + y * 8 * l;
|
|
auto sourceY = source + y * 8;
|
|
for (int x = 0; x < l; ++x) {
|
|
auto s0 = Vec4::load(sourceY + x * h + 0);
|
|
auto s1 = Vec4::load(sourceY + x * h + 4);
|
|
auto d0 = BFVec4(std::move(s0.value));
|
|
auto d1 = BFVec4(std::move(s1.value));
|
|
BFVec4::save(destY + 8 * x + 0, d0);
|
|
BFVec4::save(destY + 8 * x + 4, d1);
|
|
}
|
|
}
|
|
auto hRemain = h - hR;
|
|
if (hRemain > 0) {
|
|
auto destY = dest + hP * 8 * l;
|
|
auto sourceY = source + hP * 8;
|
|
float sTmp[8];
|
|
::memset(sTmp, 0, sizeof(sTmp));
|
|
for (int x = 0; x < l; ++x) {
|
|
::memcpy(sTmp, sourceY + x * h, hRemain * sizeof(float));
|
|
auto s0 = Vec4::load(sTmp + 0);
|
|
auto s1 = Vec4::load(sTmp + 4);
|
|
auto d0 = BFVec4(std::move(s0.value));
|
|
auto d1 = BFVec4(std::move(s1.value));
|
|
BFVec4::save(destY + 8 * x + 0, d0);
|
|
BFVec4::save(destY + 8 * x + 4, d1);
|
|
}
|
|
}
|
|
return;
|
|
}
|
|
if (hR != h) {
|
|
::memset(dest, 0, UP_DIV(h, 8) * 8 * l * sizeof(int16_t));
|
|
}
|
|
auto sourceInt32 = (const int32_t*)source;
|
|
#if 0
|
|
// Origin C++ code
|
|
for (int y = 0; y < h; ++y) {
|
|
auto yR = y % 8;
|
|
auto yC = y / 8;
|
|
for (int x = 0; x < l; ++x) {
|
|
dest[x * 8 + yR + yC * 8 * l] = sourceInt32[x + y * l] >> 16;
|
|
}
|
|
}
|
|
return;
|
|
#endif
|
|
int lC8 = (int)l / 8;
|
|
auto lR = lC8 * 8;
|
|
if (hP > 0 && lC8 > 0) {
|
|
MNNPackC8_BF16(destFloat, sourceFloat, l, h);
|
|
}
|
|
for (int y = hR; y < h; ++y) {
|
|
auto yR = y % 8;
|
|
auto yC = hP;
|
|
for (int x = 0; x < l; ++x) {
|
|
dest[x * 8 + yR + yC * 8 * l] = sourceInt32[x + y * l] >> 16;
|
|
}
|
|
}
|
|
for (int y = 0; y < hR; ++y) {
|
|
auto yR = y % 8;
|
|
auto yC = y / 8;
|
|
for (int x = lR; x < l; ++x) {
|
|
dest[x * 8 + yR + yC * 8 * l] = sourceInt32[x + y * l] >> 16;
|
|
}
|
|
}
|
|
}
|
|
|
|
#else
|
|
void NEON_MNNPackForMatMul_B_BF16(float* destFloat, const float* sourceFloat, size_t h, size_t kernelsize, size_t ic, bool transpose) {
|
|
int16_t* dest = (int16_t*)destFloat;
|
|
const float* source = sourceFloat;
|
|
auto l = kernelsize * ic;
|
|
if (!transpose) {
|
|
auto hP = h / 4;
|
|
auto hR = hP * 4;
|
|
if (hR != h) {
|
|
::memset(dest, 0, UP_DIV(h, 4) * 4 * l * sizeof(int16_t));
|
|
}
|
|
for (int y = 0; y < hP; ++y) {
|
|
auto destY = dest + y * 4 * l;
|
|
auto sourceY = source + y * 4;
|
|
for (int x = 0; x < l; ++x) {
|
|
auto s0 = Vec4::load(sourceY + x * h + 0);
|
|
auto d0 = BFVec4(std::move(s0.value));
|
|
BFVec4::save(destY + 4 * x + 0, d0);
|
|
}
|
|
}
|
|
auto hRemain = h - hR;
|
|
if (hRemain > 0) {
|
|
auto destY = dest + hP * 4 * l;
|
|
auto sourceY = source + hP * 4;
|
|
for (int x = 0; x < l; ++x) {
|
|
auto s0 = Vec4::load(sourceY + x * h + 0);
|
|
auto d0 = BFVec4(std::move(s0.value));
|
|
BFVec4::save(destY + 4 * x + 0, d0);
|
|
}
|
|
}
|
|
return;
|
|
}
|
|
#if 0
|
|
auto sourceInt32 = (const int32_t*)source;
|
|
// Origin C++ code
|
|
::memset(dest, 0, UP_DIV(h, 4) * 4 * l * sizeof(int16_t));
|
|
|
|
for (int y = 0; y < h; ++y) {
|
|
auto yR = y % 4;
|
|
auto yC = y / 4;
|
|
for (int x = 0; x < l; ++x) {
|
|
dest[x * 4 + yR + yC * 4 * l] = sourceInt32[x + y * l] >> 16;
|
|
}
|
|
}
|
|
return;
|
|
#endif
|
|
int offset[2] = {
|
|
(int)l,
|
|
(int)l,
|
|
};
|
|
MNNPackC4_BF16(destFloat, sourceFloat, l, h, offset);
|
|
}
|
|
#endif // __aarch64__
|
|
#endif
|
|
|
|
#if 0
|
|
void MNNPackC4ForMatMul_ABF16(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el) {
|
|
int number = info[0];
|
|
int eReal = info[1];
|
|
int eDest = info[2];
|
|
int offset = info[3];
|
|
for (int n=0; n<number; ++n) {
|
|
int e = el[4 * n + 0];
|
|
int l = el[4 * n + 1];
|
|
int eOffset = el[4 * n + 2];
|
|
int lOffset = el[4 * n + 3];
|
|
auto dest = (int16_t*)destOrigin + lOffset * eDest + eOffset;
|
|
auto source = (int32_t*)sourceGroup[n];
|
|
|
|
for (int y=0; y<e; ++y) {
|
|
auto yR = y % eDest;
|
|
for (int x=0; x<l; ++x) {
|
|
auto xR = x % 4;
|
|
auto xC = x / 4;
|
|
dest[(x) * eDest + yR] = source[xC * eReal * 4 + y * 4 * offset + xR] >> 16;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
#endif
|
|
|
|
static CoreFunctions* gInstance = nullptr;
|
|
bool BF16Functions::init() {
|
|
#if !defined(MNN_USE_NEON)
|
|
return false;
|
|
#else
|
|
gInstance = new CoreFunctions;
|
|
*gInstance = *MNNGetCoreFunctions();
|
|
gInstance->MNNFp32ToLowp = _MNNFp32ToLowp;
|
|
gInstance->MNNLowpToFp32 = _MNNLowpToFp32;
|
|
gInstance->matmulBytes = 2;
|
|
|
|
gInstance->MNNPackForMatMul_B = NEON_MNNPackForMatMul_B_BF16;
|
|
gInstance->MNNGetMatMulPackMode = NEON_MNNGetMatMulPackMode_BF16;
|
|
gInstance->MNNPackC4ForMatMul_A = NEON_MNNPackC4ForMatMul_A_BF16;
|
|
gInstance->MNNPackedMatMul = NEON_MNNPackedMatMul_BF16;
|
|
gInstance->MNNPackedMatMulRemain = NEON_MNNPackedMatMulRemain_BF16;
|
|
#ifdef __aarch64__
|
|
const MNNCPUInfo& gCPUInfo = *MNNGetCPUInfo();
|
|
gInstance->supportFp16arith = gCPUInfo.fp16arith;
|
|
gInstance->supportSDot = gCPUInfo.dot;
|
|
gInstance->supportI8mm = gCPUInfo.i8mm;
|
|
if (gInstance->supportI8mm) {
|
|
gInstance->MNNPackForMatMul_B = ARMV86_MNNPackForMatMul_B_BF16;
|
|
gInstance->MNNPackC4ForMatMul_A = ARMV86_MNNPackC4ForMatMul_A_BF16;
|
|
gInstance->MNNGetMatMulPackMode = ARMV86_MNNGetMatMulPackMode_BF16;
|
|
gInstance->MNNPackedMatMul = ARMV86_MNNPackedMatMul_BF16;
|
|
gInstance->MNNPackedMatMulRemain = ARMV86_MNNPackedMatMulRemain_BF16;
|
|
}
|
|
#endif
|
|
// TODO: raw cpu version of bf16
|
|
return true;
|
|
#endif
|
|
}
|
|
|
|
CoreFunctions* BF16Functions::get() {
|
|
return gInstance;
|
|
}
|
|
};
|