// mirror of https://github.com/alibaba/MNN.git
#if defined(__ANDROID__) || defined(__aarch64__)
#include <math.h>
#include <float.h>
#include "Arm82Functions.hpp"
#include "Arm82OptFunc.hpp"
#include "Arm82WinogradOptFunc.hpp"
#include "Arm82Vec.hpp"
#include "Arm82Binary.hpp"
#include "Arm82Unary.hpp"
#include "Arm82Relu.hpp"
#include "backend/cpu/compute/CommonOptFunction.h"
#include "backend/cpu/CPUPool.hpp"
#include "backend/cpu/CPURuntime.hpp"

#define FLOAT FLOAT16
#define PACK 8
using Vec = MNN::Math::Vec<FLOAT16, 8>;

#include "backend/cpu/GridSampler.hpp"

#if defined(MNN_USE_NEON)
#include <arm_neon.h>
#endif

extern "C" {
// (UP_DIV(l,8), e, 8) -> (UP_DIV(e,eP), l, eP)
void Arm82MNNPackForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el);
// void MNNPackTransposeInt16C8(int16_t* dst, const int16_t* src, size_t area, size_t depth, int32_t* areaOffset);

// C(UP_DIV(h,8), e, h8) = B(UP_DIV(h,hP), l, hP) * A(l, eP), hP = 24
// parameter: [aStride, l, h, cStride, bExtraStride]
// aStride in parameter is deprecated (unused); it is kept only so the signature stays stable
void MNNPackedMatMulFP16(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);

// C(UP_DIV(h,8), e, h8) = B(UP_DIV(h,hP), l, hP) * A(l, e), hP = 24, e >= 1
// parameter: [aStride, l, h, cStride, bExtraStride]
void MNNPackedMatMulRemainFP16(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);

#ifdef MNN_CPU_WEIGHT_DEQUANT_GEMM
void MNNPackedMatMulFP16_int4(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
void MNNPackedMatMulRemainFP16_int4(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
void MNNPackedMatMulFP16_int8(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
void MNNPackedMatMulRemainFP16_int8(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
#endif

#ifdef __aarch64__
#ifdef MNN_LOW_MEMORY
void MNNAbsMaxFP16_Pack8(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack);
void MNNAbsMaxFP16_Pack4(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack);
void MNNQuantScaleFP16(float* sum, float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch);
void MNNDynamicQuantFP16_Pack8(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, const float* bias, size_t pack);
void MNNDynamicQuantFP16_Pack4(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, const float* bias, size_t pack);
void MNNGeneralIm2col_Arm82(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack);
void MNNGeneralIm2col_Arm86(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack);
#ifdef MNN_SME2
void MNNGeneralIm2col_Fp16Sme2(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el, int32_t LP, int32_t pack);
#endif
void MNNLocalMinMaxFP16_Pack4(float* dstMin, float* dstMax, const float* source, size_t blockNum, size_t blockLU, size_t EP, size_t LP, size_t loadDstBuffer);
void MNNLocalMinMaxFP16_Pack8(float* dstMin, float* dstMax, const float* source, size_t blockNum, size_t blockLU, size_t EP, size_t LP, size_t loadDstBuffer);
#endif // MNN_LOW_MEMORY
void CountMinMaxValue_FP16(float* source, float* minVal, float* maxVal, size_t sizeQuad);
#ifdef MNN_SME2
void MNNPackedMatMulRemainFP16_SME2(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b);
#endif
#endif

#if defined(__aarch64__)

void MNNDepthwiseConvFastKernelFP16(float* dst, const float* src, const float* weight, size_t width, size_t src_w_setup,
                                    size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, size_t height,
                                    size_t srcHStep, size_t dstHStep, const float* bias, const float* parameters);
#endif

void MNNConvRunForLineDepthwiseFP16(float* dst, const float* src, const float* weight, size_t width, size_t src_w_setup,
                                    size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, size_t height, size_t srcHStep, size_t dstHStep);
}

namespace MNN {

static void Sme2MNNGetMatMulPackMode(int* eP, int *lP, int* hP) {
    *hP = 64;
    *eP = 16;
    *lP = 2;
}
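
// Worked example (annotation added for clarity, not in the upstream source): with
// the SME2 tile shape above, an e x l input packs into UP_DIV(e, 16) row tiles and
// UP_DIV(l, 2) channel pairs, so e = 100, l = 7 becomes 7 row tiles by 4 channel
// pairs with the tails zero-padded. A minimal, hypothetical sizing sketch:
#if 0
static size_t sme2PackedASizeInBytes(int e, int l) {
    int eP, lP, hP;
    Sme2MNNGetMatMulPackMode(&eP, &lP, &hP); // eP = 16, lP = 2, hP = 64
    return (size_t)UP_DIV(e, eP) * eP * UP_DIV(l, lP) * lP * sizeof(FLOAT16);
}
#endif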

static void MNNMatrixAddFP16(FLOAT16* C, const FLOAT16* A, const FLOAT16* B, size_t widthC8, size_t cStride, size_t aStride, size_t bStride, size_t height) {
    for (int y = 0; y < height; ++y) {
        auto a = A + aStride * y, b = B + bStride * y;
        auto c = C + cStride * y;
        for (int x = 0; x < widthC8; ++x) {
            vst1q_f16(c + x * 8, vaddq_f16(vld1q_f16(a + x * 8), vld1q_f16(b + x * 8)));
        }
    }
}
static void MNNMatrixSubFP16(FLOAT16* C, const FLOAT16* A, const FLOAT16* B, size_t widthC8, size_t cStride, size_t aStride, size_t bStride, size_t height) {
    for (int y = 0; y < height; ++y) {
        auto a = A + aStride * y, b = B + bStride * y;
        auto c = C + cStride * y;
        for (int x = 0; x < widthC8; ++x) {
            vst1q_f16(c + x * 8, vsubq_f16(vld1q_f16(a + x * 8), vld1q_f16(b + x * 8)));
        }
    }
}
#if defined(__aarch64__)
static void ARM82CountMinMaxValue(float* source, float* minVal, float* maxVal, size_t size) {
    if (size % 8 == 0) {
        CountMinMaxValue_FP16(source, minVal, maxVal, size / 8);
    } else {
        auto remain = size - 8 * (size / 8);
        auto max_ = ((__fp16*)source)[0];
        auto min_ = max_;
        if (size >= 8) {
            CountMinMaxValue_FP16(source, minVal, maxVal, size / 8);
            max_ = ((__fp16*)maxVal)[0];
            min_ = ((__fp16*)minVal)[0];
        }
        // Scan the tail the vectorized pass above did not cover (the original
        // started from the buffer head, which skipped the last size % 8 values).
        auto srcPtr = reinterpret_cast<__fp16*>(source) + (size / 8) * 8;
        while (remain) {
            max_ = ALIMAX(srcPtr[0], max_);
            min_ = ALIMIN(srcPtr[0], min_);
            srcPtr += 1;
            remain--;
        }
        reinterpret_cast<__fp16*>(minVal)[0] = min_;
        reinterpret_cast<__fp16*>(maxVal)[0] = max_;
    }
}
#ifdef MNN_SME2
//(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias)
static void MNNPackedMatMulFP16_SME2(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b) {
    MNNPackedMatMulRemainFP16_SME2(C, A, B, 16, parameter, postParameters, bias, k, b);
}
#endif
#else
static void ARM82CountMinMaxValue(float* source, float* minVal, float* maxVal, size_t size) {
    auto srcPtr = (FLOAT16*)source;
    auto minPtr = (FLOAT16*)minVal;
    auto maxPtr = (FLOAT16*)maxVal;
    auto max_ = srcPtr[0], min_ = srcPtr[0];
    for (int i = 1; i < size; ++i) {
        if (max_ < srcPtr[i]) {
            max_ = srcPtr[i];
        }
        if (min_ > srcPtr[i]) {
            min_ = srcPtr[i];
        }
    }
    minPtr[0] = min_;
    maxPtr[0] = max_;
}
#endif

static void Arm82MNNPackForMatMul_B(float* destC, const float* sourceC, size_t h, size_t kernelsize, size_t ic, bool transpose) {
    auto l = kernelsize * ic;
    auto dest = (int16_t*)destC;
    auto source = (int16_t*)sourceC;
    int ePack, lPack, hPack;
    Arm82MNNGetMatMulPackMode(&ePack, &lPack, &hPack);
    auto hP = (int)h / hPack;
    auto hR = (int)hP * hPack;
    if (hR != h) {
        ::memset(dest, 0, UP_DIV(h, hPack) * hPack * l * sizeof(FLOAT16));
    }
    if (!transpose) {
        for (int y = 0; y < hP; ++y) {
            auto destY = dest + y * hPack * l;
            auto sourceY = source + y * hPack;
            for (int x = 0; x < l; ++x) {
                ::memcpy(destY + hPack * x, sourceY + x * h, hPack * sizeof(FLOAT16));
            }
        }
        auto hRemain = h - hR;
        if (hRemain > 0) {
            auto destY = dest + hP * hPack * l;
            auto sourceY = source + hP * hPack;
            for (int x = 0; x < l; ++x) {
                ::memcpy(destY + hPack * x, sourceY + x * h, hRemain * sizeof(FLOAT16));
            }
        }
        return;
    }
    for (int y = 0; y < h; ++y) {
        for (int x = 0; x < l; ++x) {
            dest[(y / hPack * l + x) * hPack + y % hPack] = source[y * l + x];
        }
    }
}
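
// Layout note (added annotation): when transpose is true the source is laid out
// [h, l] and element (y, x) lands at dest[(y / hPack * l + x) * hPack + y % hPack],
// i.e. B is re-tiled as [UP_DIV(h, hPack), l, hPack] with the last h-tile zero
// padded; e.g. for hPack = 8 and h = 10, rows 8..9 go to tile 1, inner offsets 0..1.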

static void Sme2MNNPackForMatMul_B(float* destC, const float* sourceC, size_t h, size_t kernelsize, size_t ic, bool transpose) {
    auto dest = (int16_t*)destC;
    auto source = (int16_t*)sourceC;
    int LP = 2;
    int HP = 64;
    auto l = kernelsize * ic;
    memset(dest, 0, ROUND_UP(h, HP) * ROUND_UP(ic, LP) * kernelsize * sizeof(FLOAT16));
    auto stride0 = ROUND_UP(ic, LP) * kernelsize * HP;
    auto stride1 = HP * ROUND_UP(ic, LP);
    auto stride2 = HP * LP;

    size_t srcStride0 = l; // [h,k2,ic] -> [hu,k2,ic/lp,hp,lp]
    size_t srcStride1 = 1;
    if (!transpose) { // [k2,ic,h] -> [hu,k2,ic/lp,hp,lp]
        srcStride0 = 1;
        srcStride1 = h;
    }
    for (int y = 0; y < h; ++y) {
        auto yHu = y / HP;
        auto yHp = y % HP;
        for (int k = 0; k < kernelsize; ++k) {
            for (int x = 0; x < ic; ++x) {
                auto xLu = x / LP;
                auto xLp = x % LP;
                dest[yHu * stride0 + k * stride1 + xLu * stride2 + yHp * LP + xLp] = source[y * srcStride0 + (x + k * ic) * srcStride1];
            }
        }
    }
}
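
// Worked example (added annotation): with HP = 64 and LP = 2, weight element
// (y, k, x) lands at
//   dest[(y / 64) * stride0 + k * stride1 + (x / 2) * 128 + (y % 64) * 2 + (x % 2)],
// i.e. [h, kernel, ic] is re-tiled as [hu, kernel, ic / 2, 64, 2], zero padded on
// both the h and ic tails.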

static void MNNScaleAndAddBiasFP16(FLOAT16* dst, const FLOAT16* src, const FLOAT16* bias, const FLOAT16* alpha, size_t planeNumber,
                                   size_t biasNumber) {
    for (int z = 0; z < biasNumber; ++z) {
        FLOAT16* dstZ = dst + planeNumber * 8 * z;
        const FLOAT16* srcZ = src + planeNumber * 8 * z;
        auto biasZ = vld1q_f16(bias + 8 * z), alphaZ = vld1q_f16(alpha + 8 * z);
        for (int p = 0; p < planeNumber; ++p) {
            FLOAT16* dstX = dstZ + 8 * p;
            const FLOAT16* srcX = srcZ + 8 * p;
            auto res = vaddq_f16(vmulq_f16(vld1q_f16(srcX), alphaZ), biasZ);
            vst1q_f16(dstX, res);
        }
    }
}

static void MNNGridSampleComputeCordFP16(FLOAT16* dst, const FLOAT16* src, size_t inH, size_t inW, size_t outH, size_t outW, size_t stride, bool alignCorners) {
    float16x8_t zero = vdupq_n_f16(0);
    float16x8_t one = vdupq_n_f16(1);
    float16x8_t half = vdupq_n_f16(0.5f);
    float16x8_t a = alignCorners ? one : zero;
    float16x8_t b = alignCorners ? zero : one;
    float16x8_t inW_sub_a = vsubq_f16(vdupq_n_f16(inW), a);
    float16x8_t inH_sub_a = vsubq_f16(vdupq_n_f16(inH), a);

    int area = outH * outW;
    int areaC8 = area / 8;
    int areaRemain = area - areaC8 * 8;
    for (int i = 0; i < areaC8; ++i) {
        auto cordH = vld2q_f16(src);
        // float16x8_t x = cordH.val[0];
        // float16x8_t y = cordH.val[1];
        cordH.val[0] = vmulq_f16(half, vsubq_f16(vmulq_f16(vaddq_f16(one, cordH.val[0]), inW_sub_a), b));
        cordH.val[1] = vmulq_f16(half, vsubq_f16(vmulq_f16(vaddq_f16(one, cordH.val[1]), inH_sub_a), b));
        vst2q_f16(dst, cordH);

        src += 16;
        dst += 16;
    }
    if (areaRemain == 0) {
        return;
    }

    // areaRemain
    FLOAT16 tempDst[16];
    ::memcpy(tempDst, src, areaRemain * 2 * sizeof(int16_t));
    auto cordH = vld2q_f16(tempDst);
    cordH.val[0] = vmulq_f16(half, vsubq_f16(vmulq_f16(vaddq_f16(one, cordH.val[0]), inW_sub_a), b));
    cordH.val[1] = vmulq_f16(half, vsubq_f16(vmulq_f16(vaddq_f16(one, cordH.val[1]), inH_sub_a), b));
    vst2q_f16(tempDst, cordH);
    ::memcpy(dst, tempDst, areaRemain * 2 * sizeof(int16_t));
}
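
// Derivation (added annotation): with a = alignCorners ? 1 : 0 and b = 1 - a, the
// kernel computes x' = 0.5f * ((1 + x) * (inW - a) - b), which expands to the two
// standard grid-sample coordinate mappings (likewise for y with inH):
//   alignCorners == true :  x' = (x + 1) * (inW - 1) / 2
//   alignCorners == false:  x' = ((x + 1) * inW - 1) / 2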

static void MNNGridSampleComputeCord3DFp16(FLOAT* dst, const FLOAT* src, size_t inD, size_t inH, size_t inW, size_t outD, size_t outH, size_t outW, bool alignCorners) {
    float16x8_t zero = vdupq_n_f16(0);
    float16x8_t one = vdupq_n_f16(1);
    float16x8_t half = vdupq_n_f16(0.5f);
    float16x8_t a = alignCorners ? one : zero;
    float16x8_t b = alignCorners ? zero : one;
    float16x8_t inW_sub_a = vsubq_f16(vdupq_n_f16(inW), a);
    float16x8_t inH_sub_a = vsubq_f16(vdupq_n_f16(inH), a);
    float16x8_t inD_sub_a = vsubq_f16(vdupq_n_f16(inD), a);
    size_t area = outH * outW * outD;
    size_t areaC8 = area / 8;
    size_t areaRemain = area - areaC8 * 8;

    for (int i = 0; i < areaC8; ++i) {
        auto cordH = vld3q_f16(src);
        // float16x8_t x = cordH.val[0];
        // float16x8_t y = cordH.val[1];
        cordH.val[0] = vmulq_f16(half, vsubq_f16(vmulq_f16(vaddq_f16(one, cordH.val[0]), inW_sub_a), b));
        cordH.val[1] = vmulq_f16(half, vsubq_f16(vmulq_f16(vaddq_f16(one, cordH.val[1]), inH_sub_a), b));
        cordH.val[2] = vmulq_f16(half, vsubq_f16(vmulq_f16(vaddq_f16(one, cordH.val[2]), inD_sub_a), b));
        vst3q_f16(dst, cordH);
        src += 24;
        dst += 24;
    }
    if (areaRemain == 0) {
        return;
    }

    // areaRemain
    FLOAT16 tempDst[24];
    ::memcpy(tempDst, src, areaRemain * 3 * sizeof(int16_t));
    auto cordH = vld3q_f16(tempDst);
    cordH.val[0] = vmulq_f16(half, vsubq_f16(vmulq_f16(vaddq_f16(one, cordH.val[0]), inW_sub_a), b));
    cordH.val[1] = vmulq_f16(half, vsubq_f16(vmulq_f16(vaddq_f16(one, cordH.val[1]), inH_sub_a), b));
    cordH.val[2] = vmulq_f16(half, vsubq_f16(vmulq_f16(vaddq_f16(one, cordH.val[2]), inD_sub_a), b));
    vst3q_f16(tempDst, cordH);
    ::memcpy(dst, tempDst, areaRemain * 3 * sizeof(int16_t));
}
static void MNNRoiPoolingMaxFP16(FLOAT16* dst, const FLOAT16* src, int hLen, int wLen, int iw) {
    Vec max = Vec(-65504.0f);
    for (int h = 0; h < hLen; h++, src += iw * 8) {
        for (int w = 0; w < wLen; w++) {
            Vec in = Vec::load(src + w * 8);
            max = Vec::max(max, in);
        }
    }
    Vec::save(dst, max);
}

static void MNNRoiAlignMaxFP16(FLOAT16* dst, const FLOAT16* src, const std::vector<std::vector<int>> &vecPos, const std::vector<std::vector<float>> &vecArea, int samplingRatioArea, int pooledHeight, int pooledWidth) {
    for (int h = 0; h < pooledHeight; ++h, dst += pooledWidth * 8) {
        int preCalcIdx = h * pooledWidth * samplingRatioArea;
        for (int w = 0; w < pooledWidth; ++w) {
            Vec res = Vec(-65504.0f);
            for (int i = 0; i < samplingRatioArea; ++i) {
                const std::vector<int>& pos = vecPos[preCalcIdx];
                const std::vector<float>& area = vecArea[preCalcIdx];

                Vec val0 = Vec::load(src + pos[0] * 8);
                Vec val1 = Vec::load(src + pos[1] * 8);
                Vec val2 = Vec::load(src + pos[2] * 8);
                Vec val3 = Vec::load(src + pos[3] * 8);
                Vec mla = val0 * area[0];
                mla = Vec::fma(mla, val1, area[1]);
                mla = Vec::fma(mla, val2, area[2]);
                mla = Vec::fma(mla, val3, area[3]);
                res = Vec::max(res, mla);
                preCalcIdx++;
            }
            Vec::save(dst + w * 8, res);
        }
    }
}

static void MNNRoiAlignAvgFP16(FLOAT16* dst, const FLOAT16* src, const std::vector<std::vector<int>> &vecPos, const std::vector<std::vector<float>> &vecArea, int samplingRatioArea, int pooledHeight, int pooledWidth) {
    float invSamplingCnt = 1.f / samplingRatioArea;
    for (int h = 0; h < pooledHeight; ++h, dst += pooledWidth * 8) {
        int preCalcIdx = h * pooledWidth * samplingRatioArea;
        for (int w = 0; w < pooledWidth; ++w) {
            Vec res = Vec(0.f);
            for (int i = 0; i < samplingRatioArea; ++i) {
                const std::vector<int>& pos = vecPos[preCalcIdx];
                const std::vector<float>& area = vecArea[preCalcIdx];

                Vec val0 = Vec::load(src + pos[0] * 8);
                Vec val1 = Vec::load(src + pos[1] * 8);
                Vec val2 = Vec::load(src + pos[2] * 8);
                Vec val3 = Vec::load(src + pos[3] * 8);
                Vec mla = val0 * area[0];
                mla = Vec::fma(mla, val1, area[1]);
                mla = Vec::fma(mla, val2, area[2]);
                mla = Vec::fma(mla, val3, area[3]);
                res += mla;
                preCalcIdx++;
            }
            res = res * invSamplingCnt;
            Vec::save(dst + w * 8, res);
        }
    }
}
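
// Note (added annotation): in both RoiAlign kernels each sampling point is
// bilinearly interpolated from its four precomputed neighbours (indices in vecPos,
// weights in vecArea) via Vec::fma, then reduced across samplingRatioArea samples
// by max (MNNRoiAlignMaxFP16) or by mean (MNNRoiAlignAvgFP16).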

static void MNNCopyC8WithStrideFP16(const FLOAT16* source, FLOAT16* dest, size_t srcStride, size_t dstStride, size_t count) {
    using Vec = MNN::Math::Vec<FLOAT16, 8>;
    for (int i = 0; i < count; ++i) {
        auto srcPtr = source + i * srcStride;
        auto dstPtr = dest + i * dstStride;
        Vec::save(dstPtr, Vec::load(srcPtr));
    }
}

static void MNNAddC8WithStrideFP16(const FLOAT16* source, FLOAT16* dest, size_t srcStride, size_t dstStride, size_t count) {
    for (int i = 0; i < count; ++i) {
        auto srcPtr = source + i * srcStride;
        auto dstPtr = dest + i * dstStride;
        auto value = Vec::load(dstPtr) + Vec::load(srcPtr);
        Vec::save(dstPtr, value);
    }
}

static void MNNAxByClampBroadcastC8FP16(float* CF, const float* AF, const float* BF, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters) {
    auto C = (FLOAT16*)CF;
    auto A = (FLOAT16*)AF;
    auto B = (FLOAT16*)BF;
    using Vec = MNN::Math::Vec<FLOAT16, 8>;
    auto minF = Vec(parameters[2]);
    auto maxF = Vec(parameters[3]);
    auto beta = Vec(parameters[1]);
    for (int y = 0; y < height; ++y) {
        auto a = A + aStride * y;
        auto b = B + 8 * y;
        auto bv = Vec::load(b);
        auto c = C + cStride * y;
        for (int x = 0; x < width; ++x) {
            auto av = Vec::load(a + 8 * x);
            auto cv = av + bv * beta;
            cv = Vec::min(cv, maxF);
            cv = Vec::max(cv, minF);
            Vec::save(c + 8 * x, cv);
        }
    }
}

void ARM82StrassenMerge(FLOAT16* c11, FLOAT16* c12, FLOAT16* c21, FLOAT16* c22, FLOAT16* xAddr,
                        size_t cStride, size_t eSub, size_t hSub) {
    const int pack = 8;
    for (int y = 0; y < hSub; ++y) {
        auto c11Y = c11 + y * cStride;
        auto c12Y = c12 + y * cStride;
        auto c22Y = c22 + y * cStride;
        auto c21Y = c21 + y * cStride;
        auto xY = xAddr + y * eSub * pack;
        for (int x = 0; x < eSub; ++x) {
            auto xv = vld1q_f16(xY + x * pack);
            auto c21v = vld1q_f16(c21Y + x * pack);
            auto c11v = vld1q_f16(c11Y + x * pack);
            auto c22v = vld1q_f16(c22Y + x * pack);
            auto c12v = vld1q_f16(c12Y + x * pack);
            c12v = c12v + xv;
            c21v = c12v + c21v;
            c12v = c22v + c12v;
            c22v = c22v + c21v;
            c12v = c11v + c12v;
            vst1q_f16(c12Y + x * pack, c12v);
            vst1q_f16(c22Y + x * pack, c22v);
            vst1q_f16(c21Y + x * pack, c21v);
        }
    }
}

void MNNUnpackTransposeInt16C8(int16_t* dst, const int16_t* src, size_t area, size_t depth, int32_t* areaOffset) {
    int srcAreaOffset = areaOffset[0];
    int c = (int)depth;
    int cDiv4 = c / 8;
    int cAlign = cDiv4 * 8;
    int areaDiv4 = area / 4;
    int areaAlign = areaDiv4 * 4;

    if (areaAlign > 0) {
        for (int ci = 0; ci < cDiv4; ++ci) {
            auto srcH = src + ci * 8 * srcAreaOffset;
            auto dstH = dst + ci * 8;
            for (int hi = 0; hi < areaAlign; hi += 4) {
                auto src0 = srcH + hi * 8;
                auto src1 = srcH + hi * 8 + 8;
                auto src2 = srcH + hi * 8 + 16;
                auto src3 = srcH + hi * 8 + 24;

                auto dst0 = dstH + hi * c;
                auto dst1 = dstH + hi * c + c;
                auto dst2 = dstH + hi * c + 2 * c;
                auto dst3 = dstH + hi * c + 3 * c;
                vst1q_s16(dst0, vld1q_s16(src0));
                vst1q_s16(dst1, vld1q_s16(src1));
                vst1q_s16(dst2, vld1q_s16(src2));
                vst1q_s16(dst3, vld1q_s16(src3));
            }
        }
    }
    if (areaAlign < area) {
        for (int ci = 0; ci < cDiv4; ++ci) {
            auto srcH = src + 8 * ci * srcAreaOffset;
            auto dstH = dst + ci * 8;
            for (int hi = areaAlign; hi < area; ++hi) {
                auto src0 = srcH + hi * 8;
                auto dst0 = dstH + hi * c;
                vst1q_s16(dst0, vld1q_s16(src0));
            }
        }
    }
    if (c == cAlign) {
        return;
    }

    int cRemain = c - cAlign;
    auto srcAlign = src + srcAreaOffset * cAlign;
    auto dstAlign = dst + cAlign;

    for (int hi = 0; hi < area; ++hi) {
        auto srcHeight = srcAlign + hi * 8;
        auto dstHeight = dstAlign + hi * c;

        for (int ci = 0; ci < cRemain; ++ci) {
            dstHeight[ci] = srcHeight[ci];
        }
    }
}

void MNNPackTransposeInt16C8(int16_t* dst, const int16_t* src, size_t area, size_t depth, int32_t* areaOffset) {
    if (depth == 8) {
        ::memcpy(dst, src, area * depth * sizeof(int16_t));
        return;
    }
    int dstAreaOffset = areaOffset[1];
    int c = (int)depth;
    int cDiv4 = c / 8;
    int cAlign = cDiv4 * 8;
    int areaDiv4 = area / 4;
    int areaAlign = areaDiv4 * 4;
    if (areaAlign > 0) {
        for (int ci = 0; ci < cDiv4; ++ci) {
            auto srcH = src + ci * 8;
            auto dstH = dst + ci * dstAreaOffset * 8;
            for (int hi = 0; hi < areaAlign; hi += 4) {
                auto src0 = srcH + hi * c;
                auto src1 = srcH + hi * c + c;
                auto src2 = srcH + hi * c + 2 * c;
                auto src3 = srcH + hi * c + 3 * c;
                auto dst0 = dstH + hi * 8;
                auto dst1 = dstH + hi * 8 + 8;
                auto dst2 = dstH + hi * 8 + 16;
                auto dst3 = dstH + hi * 8 + 24;
                vst1q_s16(dst0, vld1q_s16(src0));
                vst1q_s16(dst1, vld1q_s16(src1));
                vst1q_s16(dst2, vld1q_s16(src2));
                vst1q_s16(dst3, vld1q_s16(src3));
            }
        }
    }
    if (areaAlign < area) {
        for (int ci = 0; ci < cDiv4; ++ci) {
            auto srcH = src + ci * 8;
            auto dstH = dst + ci * dstAreaOffset * 8;
            for (int hi = areaAlign; hi < area; ++hi) {
                auto src0 = srcH + hi * c;
                auto dst0 = dstH + hi * 8;
                vst1q_s16(dst0, vld1q_s16(src0));
            }
        }
    }

    if (cAlign == c) {
        return;
    }

    int cRemain = c - cAlign;
    auto srcAlign = src + cAlign;
    auto dstAlign = dst + dstAreaOffset * cAlign;

    for (int hi = 0; hi < area; ++hi) {
        auto srcHeight = srcAlign + hi * c;
        auto dstHeight = dstAlign + hi * 8;
        for (int i = 0; i < 8; ++i) {
            dstHeight[i] = 0;
        }
        for (int ci = 0; ci < cRemain; ++ci) {
            dstHeight[ci] = srcHeight[ci];
        }
    }
}

static void _MNNDeconvRunForUnitDepthWise(const FLOAT16* dst, FLOAT16* src, const FLOAT16* weight, size_t fw, size_t fh,
                                          size_t weight_y_step, size_t dilateX_step, size_t dilateY_step) {
    int fx, fy;
    auto src_z = src;
    auto weight_z = weight;
    Vec dstV = Vec::load(dst);
    for (fy = 0; fy < fh; ++fy) {
        auto src_y = src_z + fy * dilateY_step;
        auto weight_y = weight_z + fy * weight_y_step;
        for (fx = 0; fx < fw; ++fx) {
            Vec weight_x = Vec::load(weight_y + 8 * fx);
            Vec src_x = Vec::load(src_y + fx * dilateX_step);
            Vec::save(src_y + fx * dilateX_step, src_x + weight_x * dstV);
        }
    }
}
static void _MNNDeconvRunForLineDepthwise(const FLOAT16* dst, FLOAT16* src, const FLOAT16* weight, size_t width, size_t src_w_setup,
                                          size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step) {
    int dx;
    for (dx = 0; dx < width; ++dx) {
        auto dst_x = dst + dx * 8;
        auto src_dx = src + src_w_setup * dx;
        _MNNDeconvRunForUnitDepthWise(dst_x, src_dx, weight, fw, fh, fw * 8, dilateX_step, dilateY_step);
    }
}
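
// Note (added annotation): unlike the forward depthwise kernels, these deconv
// helpers read one output vector from dst and scatter weight * dst accumulations
// back into the (larger) src buffer, so src is the write target here.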

static void _MNNComputeMatMulForH_1_FP16(const float* AF, const float* BF, float* CF, const float* biasPtrF, const MatMulParam* param, size_t tId) {
    auto A = (const FLOAT16*)AF;
    auto B = (const FLOAT16*)BF;
    auto C = (FLOAT16*)CF;
    auto biasPtr = (const FLOAT16*)biasPtrF;
    int e = param->e;
    int l = param->l;
    int numberThread = param->numberThread;
    float biasValue = 0.0f;
    if (nullptr != biasPtr) {
        biasValue = biasPtr[0];
    }
    if (param->ATranspose) {
        auto eC4 = e / 8;
        auto eR = e % 8;
        for (int y = tId; y < eC4; y += numberThread) {
            Vec sumValue = Vec(biasValue);
            auto srcY = A + y * 8;
            for (int x = 0; x < l; ++x) {
                sumValue = sumValue + Vec::load(srcY + x * e) * Vec(B[x]);
            }
            Vec::save(C + 8 * y, sumValue);
        }
        if (0 == tId && eR > 0) {
            Vec sumValue = Vec(biasValue);
            auto srcY = A + eC4 * 8;
            FLOAT16 AR[8];
            for (int x = 0; x < l; ++x) {
                ::memcpy(AR, srcY + x * e, eR * sizeof(int16_t));
                sumValue = sumValue + Vec::load(AR) * Vec(B[x]);
            }
            FLOAT16 CR[8];
            Vec::save(CR, sumValue);
            ::memcpy(C + 8 * eC4, CR, eR * sizeof(int16_t));
        }
        return;
    }
    auto lC4 = l / 8;
    auto lR = l % 8;
    for (int y = tId; y < e; y += numberThread) {
        Vec sumValue = Vec(biasValue);
        auto srcY = A + y * l;
        for (int x = 0; x < lC4; ++x) {
            sumValue = sumValue + Vec::load(srcY + 8 * x) * Vec::load(B + 8 * x);
        }
        if (lR > 0) {
            FLOAT16 AR[8] = {0, 0, 0, 0, 0, 0, 0, 0};
            FLOAT16 BR[8] = {0, 0, 0, 0, 0, 0, 0, 0};
            ::memcpy(AR, srcY + lC4 * 8, lR * sizeof(int16_t));
            ::memcpy(BR, B + 8 * lC4, lR * sizeof(int16_t));
            sumValue = sumValue + Vec::load(AR) * Vec::load(BR);
        }
        float sumSingle = sumValue[0] + sumValue[1] + sumValue[2] + sumValue[3] + sumValue[4] + sumValue[5] + sumValue[6] + sumValue[7];
        C[y] = sumSingle;
    }
}
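
// Note on the tails above (added annotation): the e % 8 and l % 8 remainders are
// staged through 8-element stack buffers with memcpy so that the 8-lane
// Vec::load / Vec::save never touch memory past the end of A, B, or C.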

static void _MNNComputeMatMulForE_1_FP16(const float* AF, const float* BF, float* CF, const float* biasPtrF, const MatMulParam* param, size_t tId) {
    auto l = param->l;
    auto h = param->h;
    auto numberThread = param->numberThread;
    auto lC4 = l / 8;
    auto lR = l % 8;
    auto A = (const FLOAT16*)AF;
    auto B = (const FLOAT16*)BF;
    auto C = (FLOAT16*)CF;
    auto biasPtr = (const FLOAT16*)biasPtrF;
    if (param->BTranspose) {
        for (int y = tId; y < h; y += numberThread) {
            Vec sumValue = Vec(0.0f);
            auto by = B + y * l;
            for (int x = 0; x < lC4; ++x) {
                sumValue = Vec::fma(sumValue, Vec::load(A + x * 8), Vec::load(by + x * 8));
            }
            if (lR > 0) {
                FLOAT16 AR[8] = {0, 0, 0, 0, 0, 0, 0, 0};
                FLOAT16 BR[8] = {0, 0, 0, 0, 0, 0, 0, 0};
                ::memcpy(AR, A + lC4 * 8, lR * sizeof(int16_t));
                ::memcpy(BR, by + 8 * lC4, lR * sizeof(int16_t));
                sumValue = sumValue + Vec::load(AR) * Vec::load(BR);
            }
            float sumRemain = sumValue[0] + sumValue[1] + sumValue[2] + sumValue[3] + sumValue[4] + sumValue[5] + sumValue[6] + sumValue[7];
            if (nullptr != biasPtr) {
                sumRemain += biasPtr[y];
            }
            C[y] = sumRemain;
        }
    } else {
        auto hC4 = h / 8;
        auto hR = h % 8;
        auto hC16 = hC4 / 4;
        auto hC4R = hC4 % 4;
        for (int y = tId; y < hC16; y += numberThread) {
            auto biasP = biasPtr + 8 * 4 * y;
            auto bs = B + 8 * 4 * y;
            Vec s0 = Vec(0.0f);
            Vec s1 = Vec(0.0f);
            Vec s2 = Vec(0.0f);
            Vec s3 = Vec(0.0f);
            if (biasPtr != nullptr) {
                s0 = Vec::load(biasP + 8 * 0);
                s1 = Vec::load(biasP + 8 * 1);
                s2 = Vec::load(biasP + 8 * 2);
                s3 = Vec::load(biasP + 8 * 3);
            }
            auto srcY = A + y * l * 8 * 4;
            for (int x = 0; x < l; ++x) {
                auto a = Vec(A[x]);
                s0 = Vec::fma(s0, a, Vec::load(bs + h * x + 0 * 8));
                s1 = Vec::fma(s1, a, Vec::load(bs + h * x + 1 * 8));
                s2 = Vec::fma(s2, a, Vec::load(bs + h * x + 2 * 8));
                s3 = Vec::fma(s3, a, Vec::load(bs + h * x + 3 * 8));
            }
            Vec::save(C + 4 * 8 * y + 8 * 0, s0);
            Vec::save(C + 4 * 8 * y + 8 * 1, s1);
            Vec::save(C + 4 * 8 * y + 8 * 2, s2);
            Vec::save(C + 4 * 8 * y + 8 * 3, s3);
        }

        for (int y = hC16 * 4 + tId; y < hC4; y += numberThread) {
            auto bs = B + 8 * y;
            Vec sumValue = Vec(0.0f);
            if (biasPtr != nullptr) {
                sumValue = Vec::load(biasPtr + 8 * y);
            }
            auto srcY = A + y * l * 8;
            for (int x = 0; x < l; ++x) {
                sumValue = Vec::fma(sumValue, Vec(A[x]), Vec::load(bs + h * x));
            }
            Vec::save(C + 8 * y, sumValue);
        }
        if (tId == 0 && hR > 0) {
            auto bs = B + 8 * hC4;
            Vec sumValue = Vec(0.0f);
            if (biasPtr != nullptr) {
                FLOAT16 biasTemp[8];
                ::memcpy(biasTemp, biasPtr + 8 * hC4, hR * sizeof(int16_t));
                sumValue = Vec::load(biasTemp);
            }
            auto srcY = A + 8 * hC4 * l;
            FLOAT16 bTemp[8];
            for (int x = 0; x < l; ++x) {
                ::memcpy(bTemp, bs + h * x, hR * sizeof(int16_t));
                sumValue = sumValue + Vec(A[x]) * Vec::load(bTemp);
            }
            FLOAT16 cTemp[8];
            Vec::save(cTemp, sumValue);
            ::memcpy(C + 8 * hC4, cTemp, hR * sizeof(int16_t));
        }
    }
}

template<int EP, int LP>
static void _Arm82MNNPackC4ForMatMul_A(int8_t* destOrigin, int8_t const** sourceGroup, const int32_t* info, const int32_t* el) {
    const int pack = 8;
    int number = info[0];
    int eReal = info[1];
    int xStride = info[3];
    int xS4 = xStride * pack / sizeof(int32_t);
    int PUNIT = pack / LP;
    int FLOATPACK = pack / sizeof(int32_t);
    int eOutsideStride = info[2] / sizeof(int32_t);
    int eDest = EP;
    int realDstCount = info[4];
    for (int n = 0; n < number; ++n) {
        int e = el[4 * n + 0];
        int l = el[4 * n + 1];
        int eOffset = el[4 * n + 2];
        int lOffset = el[4 * n + 3];
        int eC = eOffset / EP;
        int eR = eOffset % EP;
        int eS = eDest - eR;
        bool lastBag = false;
        int eOutsideStride4LastBag = eOutsideStride;
        if (realDstCount % EP > 0) {
            int jobsE = realDstCount - eOffset - e;
            if (jobsE == 0 || (jobsE < (realDstCount % EP))) {
                lastBag = true;
            }
        }
        auto source = (int32_t*)sourceGroup[n];
        auto dest = (int32_t*)(destOrigin + eC * info[2] + eR * LP + lOffset * EP);
        //printf("e=%d, l=%d, eOffset=%d, lOffset=%d, eDest=%d\n", e, l, eOffset, lOffset, eDest);
        l = l / 4; // Use float instead of int8 * 4
        if (lastBag && e + eR < EP) {
            int elast = ALIMAX(eR + e, realDstCount % EP);
            dest = (int32_t*)(destOrigin + lOffset * elast + eC * info[2] + eR * LP);
        }
        int offsetLC = lOffset / 4;
        for (int x = 0; x < l; ++x) {
            int eRemain = e;
            auto xR = x % PUNIT;
            auto xC = x / PUNIT;
            auto d = dest;
            auto s = source + xC * eReal * FLOATPACK + xR;
            if (eR > 0) {
                int eStep = ALIMIN(eRemain, eS);
                for (int yi = 0; yi < eStep; ++yi) {
                    d[yi] = s[yi * xS4];
                }
                eRemain -= eStep;
                if (!lastBag || eRemain >= EP) {
                    d += (eOutsideStride - eR);
                } else {
                    int eFill = ALIMAX(eRemain, realDstCount % EP); // maybe padding > 0
                    eOutsideStride4LastBag = eOutsideStride - (EP * 4 * offsetLC / sizeof(float));
                    d += (eOutsideStride4LastBag - eR + offsetLC * eFill);
                }
                s += eS * xS4;
            }
            while (eRemain > 0) {
                int eStep = ALIMIN(eDest, eRemain);
                for (int yi = 0; yi < eStep; ++yi) {
                    d[yi] = s[yi * xS4];
                }
                eRemain -= eStep;
                if (!lastBag || eRemain >= EP) {
                    d += eOutsideStride;
                } else {
                    int eFill = ALIMAX(eRemain, realDstCount % EP); // maybe padding > 0
                    eOutsideStride4LastBag = eOutsideStride - (EP * 4 * offsetLC / sizeof(float));
                    d += (eOutsideStride4LastBag + offsetLC * eFill);
                }
                s += eStep * xS4;
            }
            if (lastBag && e + eR < EP) {
                int efill = ALIMAX(e + eR, realDstCount % EP);
                dest += efill;
            } else {
                dest += eDest;
            }
            offsetLC++;
        }
    }
}
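
// Reading guide (added annotation): each el[] quadruple is consumed as
// { e, l, eOffset, lOffset } for one source region, and info[] as
// { number, eReal, e-tile stride in bytes, xStride, realDstCount }; the copies
// go through int32_t so that four bytes move per step ("float instead of
// int8 * 4"), which is why l is divided by 4 above.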

template<int EP, int HP>
static void _ArmBasicMNNPackC4ForMatMul_A_L8(int8_t* destOrigin, int8_t const** sourceGroup, const int32_t* info, const int32_t* el) {
    int number = info[0];
    int eReal = info[1];
    int eDest = EP;
    int offset = info[3];
    const int LP = 8;
    int eOutsideStride = info[2] / sizeof(int64_t);
    int realDstCount = info[4];
    for (int n = 0; n < number; ++n) {
        int e = el[4 * n + 0];
        int l = el[4 * n + 1];
        int eOffset = el[4 * n + 2];
        int lOffset = el[4 * n + 3];
        int eC = eOffset / EP;
        int eR = eOffset % EP;
        int eS = eDest - eR;
        bool lastBag = false;
        int eOutsideStride4LastBag = eOutsideStride;
        int eres = realDstCount - eOffset;
        if (realDstCount % EP > 0) {
            int jobsE = realDstCount - eOffset - e;
            if (jobsE == 0 || (jobsE < (realDstCount % EP))) {
                lastBag = true;
            }
        }
        auto dest = (int64_t*)(destOrigin + lOffset * eDest + eC * info[2] + eR * LP);
        auto source = (int64_t*)sourceGroup[n];
        int lRemain = l / LP;
        if (lastBag && e + eR < EP) {
            int elast = ALIMIN(ALIMAX(eR + e, realDstCount % EP), EP);
            dest = (int64_t*)(destOrigin + lOffset * elast + eC * info[2] + eR * LP);
        }
        int offsetLC = lOffset / LP;
        for (int x = 0; x < lRemain; ++x) {
            int eRemain = e;
            auto d = dest;
            auto s = source;
            if (1 == offset) {
                if (eR > 0) {
                    int eStep = ALIMIN(eRemain, eS);
                    ::memcpy(d, s, eStep * sizeof(int64_t));
                    eRemain -= eStep;
                    if (!lastBag || eRemain >= EP) {
                        d += (eOutsideStride - eR);
                    } else {
                        int eFill = ALIMAX(eRemain, realDstCount % EP); // maybe padding > 0
                        eOutsideStride4LastBag = eOutsideStride - (EP * offsetLC);
                        d += (eOutsideStride4LastBag - eR + offsetLC * eFill);
                    }
                    s += (eS * offset);
                }
                while (eRemain > 0) {
                    int eStep = ALIMIN(eDest, eRemain);
                    ::memcpy(d, s, eStep * sizeof(int64_t));
                    eRemain -= eStep;
                    if (!lastBag || eRemain >= EP) {
                        d += eOutsideStride;
                    } else {
                        int eFill = ALIMAX(eRemain, realDstCount % EP); // maybe padding > 0
                        eOutsideStride4LastBag = eOutsideStride - (EP * offsetLC);
                        d += (eOutsideStride4LastBag + offsetLC * eFill);
                    }
                    s += (eStep * offset);
                }
            } else {
                if (eR > 0) {
                    int eStep = ALIMIN(eRemain, eS);
                    for (int yi = 0; yi < eStep; ++yi) {
                        d[yi] = s[yi * offset];
                    }
                    eRemain -= eStep;
                    if (!lastBag || eRemain >= EP) {
                        d += (eOutsideStride - eR);
                    } else {
                        int eFill = ALIMAX(eRemain, realDstCount % EP); // maybe padding > 0
                        eOutsideStride4LastBag = eOutsideStride - (EP * offsetLC);
                        d += (eOutsideStride4LastBag - eR + offsetLC * eFill);
                    }
                    s += eS * offset;
                }
                while (eRemain > 0) {
                    int eStep = ALIMIN(eDest, eRemain);
                    for (int yi = 0; yi < eStep; ++yi) {
                        d[yi] = s[yi * offset];
                    }
                    eRemain -= eStep;
                    if (!lastBag || eRemain >= EP) {
                        d += eOutsideStride;
                    } else {
                        int eFill = ALIMAX(eRemain, realDstCount % EP); // maybe padding > 0
                        eOutsideStride4LastBag = eOutsideStride - (EP * offsetLC);
                        d += (eOutsideStride4LastBag + offsetLC * eFill);
                    }
                    s += eStep * offset;
                }
            }
            source += eReal;
            if (lastBag && e + eR < EP) { // eR=0; eR>0
                int efill = ALIMAX(e + eR, realDstCount % EP);
                dest += efill;
            } else {
                dest += eDest;
            }
            offsetLC++;
        }
    }
}

static void Sme2MNNPackForMatMul_A_FP16(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el) {
    int LP = 2;
    int pack = 8;
    int eDest = 16;
    // note: LP (= 2) is smaller than pack (= 8) here
    int number = info[0];
    int eReal = info[1];
    int offset = info[3];
    for (int n = 0; n < number; ++n) {
        int eWork = el[4 * n + 0];
        int lWork = el[4 * n + 1];
        int eOffset = el[4 * n + 2];
        int lOffset = el[4 * n + 3];
        auto sourceN = (FLOAT16*)(sourceGroup[n]);
        auto destN = (FLOAT16*)destOrigin + lOffset * eDest + eOffset * LP;

        auto srcStride0 = pack * offset;
        auto dstStride0 = eDest * LP;
        auto l = lWork;
        while (l > 7) {
            auto source = sourceN;
            auto dest = destN;
            l -= 8;
            auto e = eWork;
            if (e == eDest) {
                auto s0 = vld1q_f32((float*)(source)); // 00112233
                auto s1 = vld1q_f32((float*)(source + srcStride0)); // 00112233
                auto s2 = vld1q_f32((float*)(source + 2 * srcStride0));
                auto s3 = vld1q_f32((float*)(source + 3 * srcStride0));

                auto s4 = vld1q_f32((float*)(source + 4 * srcStride0));
                auto s5 = vld1q_f32((float*)(source + 5 * srcStride0));
                auto s6 = vld1q_f32((float*)(source + 6 * srcStride0));
                auto s7 = vld1q_f32((float*)(source + 7 * srcStride0));

                auto s8 = vld1q_f32((float*)(source + 8 * srcStride0));
                auto s9 = vld1q_f32((float*)(source + 9 * srcStride0));
                auto s10 = vld1q_f32((float*)(source + 10 * srcStride0));
                auto s11 = vld1q_f32((float*)(source + 11 * srcStride0));

                auto s12 = vld1q_f32((float*)(source + 12 * srcStride0));
                auto s13 = vld1q_f32((float*)(source + 13 * srcStride0));
                auto s14 = vld1q_f32((float*)(source + 14 * srcStride0));
                auto s15 = vld1q_f32((float*)(source + 15 * srcStride0));

                auto zip1s01 = vzip1q_f32(s0, s1); // 00001111
                auto zip1s23 = vzip1q_f32(s2, s3); // 00001111
                auto zip1s45 = vzip1q_f32(s4, s5); // 00001111
                auto zip1s67 = vzip1q_f32(s6, s7); // 00001111
                auto zip1s89 = vzip1q_f32(s8, s9); // 00001111
                auto zip1s1011 = vzip1q_f32(s10, s11); // 00001111
                auto zip1s1213 = vzip1q_f32(s12, s13); // 00001111
                auto zip1s1415 = vzip1q_f32(s14, s15); // 00001111

                auto zip2s01 = vzip2q_f32(s0, s1); // 22223333
                auto zip2s23 = vzip2q_f32(s2, s3); // 22223333
                auto zip2s45 = vzip2q_f32(s4, s5); // 22223333
                auto zip2s67 = vzip2q_f32(s6, s7); // 22223333
                auto zip2s89 = vzip2q_f32(s8, s9); // 22223333
                auto zip2s1011 = vzip2q_f32(s10, s11); // 22223333
                auto zip2s1213 = vzip2q_f32(s12, s13); // 22223333
                auto zip2s1415 = vzip2q_f32(s14, s15); // 22223333

                auto zip1s0123_01 = vzip1q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 00000000
                auto zip1s4567_01 = vzip1q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);
                auto zip1s891011_01 = vzip1q_f64((float64x2_t)zip1s89, (float64x2_t)zip1s1011);
                auto zip1s12131415_01 = vzip1q_f64((float64x2_t)zip1s1213, (float64x2_t)zip1s1415);

                auto zip2s0123_01 = vzip2q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 11111111
                auto zip2s4567_01 = vzip2q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);
                auto zip2s891011_01 = vzip2q_f64((float64x2_t)zip1s89, (float64x2_t)zip1s1011);
                auto zip2s12131415_01 = vzip2q_f64((float64x2_t)zip1s1213, (float64x2_t)zip1s1415);

                auto zip1s0123_23 = vzip1q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 22222222
                auto zip1s4567_23 = vzip1q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);
                auto zip1s891011_23 = vzip1q_f64((float64x2_t)zip2s89, (float64x2_t)zip2s1011);
                auto zip1s12131415_23 = vzip1q_f64((float64x2_t)zip2s1213, (float64x2_t)zip2s1415);

                auto zip2s0123_23 = vzip2q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 33333333
                auto zip2s4567_23 = vzip2q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);
                auto zip2s891011_23 = vzip2q_f64((float64x2_t)zip2s89, (float64x2_t)zip2s1011);
                auto zip2s12131415_23 = vzip2q_f64((float64x2_t)zip2s1213, (float64x2_t)zip2s1415);

                vst1q_f64((float64_t*)dest, zip1s0123_01);
                vst1q_f64((float64_t*)(dest + 8), zip1s4567_01);
                vst1q_f64((float64_t*)(dest + 16), zip1s891011_01);
                vst1q_f64((float64_t*)(dest + 24), zip1s12131415_01);

                vst1q_f64((float64_t*)(dest + dstStride0), zip2s0123_01);
                vst1q_f64((float64_t*)(dest + dstStride0 + 8), zip2s4567_01);
                vst1q_f64((float64_t*)(dest + dstStride0 + 16), zip2s891011_01);
                vst1q_f64((float64_t*)(dest + dstStride0 + 24), zip2s12131415_01);

                vst1q_f64((float64_t*)(dest + 2 * dstStride0), zip1s0123_23);
                vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 8), zip1s4567_23);
                vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 16), zip1s891011_23);
                vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 24), zip1s12131415_23);

                vst1q_f64((float64_t*)(dest + 3 * dstStride0), zip2s0123_23);
                vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 8), zip2s4567_23);
                vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 16), zip2s891011_23);
                vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 24), zip2s12131415_23);

                // dest += (4 * dstStride0);
                // e -= eDest;
                sourceN += (eReal * pack);
                destN += (4 * dstStride0);
                continue;
            }

            if (e > 11) {
                auto s0 = vld1q_f32((float*)(source)); // 00112233
                auto s1 = vld1q_f32((float*)(source + srcStride0)); // 00112233
                auto s2 = vld1q_f32((float*)(source + 2 * srcStride0));
                auto s3 = vld1q_f32((float*)(source + 3 * srcStride0));

                auto s4 = vld1q_f32((float*)(source + 4 * srcStride0));
                auto s5 = vld1q_f32((float*)(source + 5 * srcStride0));
                auto s6 = vld1q_f32((float*)(source + 6 * srcStride0));
                auto s7 = vld1q_f32((float*)(source + 7 * srcStride0));

                auto s8 = vld1q_f32((float*)(source + 8 * srcStride0));
                auto s9 = vld1q_f32((float*)(source + 9 * srcStride0));
                auto s10 = vld1q_f32((float*)(source + 10 * srcStride0));
                auto s11 = vld1q_f32((float*)(source + 11 * srcStride0));

                auto zip1s01 = vzip1q_f32(s0, s1); // 00001111
                auto zip1s23 = vzip1q_f32(s2, s3); // 00001111
                auto zip1s45 = vzip1q_f32(s4, s5); // 00001111
                auto zip1s67 = vzip1q_f32(s6, s7); // 00001111
                auto zip1s89 = vzip1q_f32(s8, s9); // 00001111
                auto zip1s1011 = vzip1q_f32(s10, s11); // 00001111

                auto zip2s01 = vzip2q_f32(s0, s1); // 22223333
                auto zip2s23 = vzip2q_f32(s2, s3); // 22223333
                auto zip2s45 = vzip2q_f32(s4, s5); // 22223333
                auto zip2s67 = vzip2q_f32(s6, s7); // 22223333
                auto zip2s89 = vzip2q_f32(s8, s9); // 22223333
                auto zip2s1011 = vzip2q_f32(s10, s11); // 22223333

                auto zip1s0123_01 = vzip1q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 00000000
                auto zip1s4567_01 = vzip1q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);
                auto zip1s891011_01 = vzip1q_f64((float64x2_t)zip1s89, (float64x2_t)zip1s1011);

                auto zip2s0123_01 = vzip2q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 11111111
                auto zip2s4567_01 = vzip2q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);
                auto zip2s891011_01 = vzip2q_f64((float64x2_t)zip1s89, (float64x2_t)zip1s1011);

                auto zip1s0123_23 = vzip1q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 22222222
                auto zip1s4567_23 = vzip1q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);
                auto zip1s891011_23 = vzip1q_f64((float64x2_t)zip2s89, (float64x2_t)zip2s1011);

                auto zip2s0123_23 = vzip2q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 33333333
                auto zip2s4567_23 = vzip2q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);
                auto zip2s891011_23 = vzip2q_f64((float64x2_t)zip2s89, (float64x2_t)zip2s1011);

                vst1q_f64((float64_t*)dest, zip1s0123_01);
                vst1q_f64((float64_t*)(dest + 8), zip1s4567_01);
                vst1q_f64((float64_t*)(dest + 16), zip1s891011_01);

                vst1q_f64((float64_t*)(dest + dstStride0), zip2s0123_01);
                vst1q_f64((float64_t*)(dest + dstStride0 + 8), zip2s4567_01);
                vst1q_f64((float64_t*)(dest + dstStride0 + 16), zip2s891011_01);

                vst1q_f64((float64_t*)(dest + 2 * dstStride0), zip1s0123_23);
                vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 8), zip1s4567_23);
                vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 16), zip1s891011_23);

                vst1q_f64((float64_t*)(dest + 3 * dstStride0), zip2s0123_23);
                vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 8), zip2s4567_23);
                vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 16), zip2s891011_23);

                dest += 24;
                e -= 12;
                source += (12 * srcStride0);
            }

            if (e > 7) {
                auto s0 = vld1q_f32((float*)(source)); // 00112233
                auto s1 = vld1q_f32((float*)(source + srcStride0)); // 00112233
                auto s2 = vld1q_f32((float*)(source + 2 * srcStride0));
                auto s3 = vld1q_f32((float*)(source + 3 * srcStride0));

                auto s4 = vld1q_f32((float*)(source + 4 * srcStride0));
                auto s5 = vld1q_f32((float*)(source + 5 * srcStride0));
                auto s6 = vld1q_f32((float*)(source + 6 * srcStride0));
                auto s7 = vld1q_f32((float*)(source + 7 * srcStride0));

                auto zip1s01 = vzip1q_f32(s0, s1); // 00001111
                auto zip1s23 = vzip1q_f32(s2, s3); // 00001111
                auto zip1s45 = vzip1q_f32(s4, s5); // 00001111
                auto zip1s67 = vzip1q_f32(s6, s7); // 00001111

                auto zip2s01 = vzip2q_f32(s0, s1); // 22223333
                auto zip2s23 = vzip2q_f32(s2, s3); // 22223333
                auto zip2s45 = vzip2q_f32(s4, s5); // 22223333
                auto zip2s67 = vzip2q_f32(s6, s7); // 22223333

                auto zip1s0123_01 = vzip1q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 00000000
                auto zip1s4567_01 = vzip1q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);

                auto zip2s0123_01 = vzip2q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 11111111
                auto zip2s4567_01 = vzip2q_f64((float64x2_t)zip1s45, (float64x2_t)zip1s67);

                auto zip1s0123_23 = vzip1q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 22222222
                auto zip1s4567_23 = vzip1q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);

                auto zip2s0123_23 = vzip2q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 33333333
                auto zip2s4567_23 = vzip2q_f64((float64x2_t)zip2s45, (float64x2_t)zip2s67);

                vst1q_f64((float64_t*)dest, zip1s0123_01);
                vst1q_f64((float64_t*)(dest + 8), zip1s4567_01);

                vst1q_f64((float64_t*)(dest + dstStride0), zip2s0123_01);
                vst1q_f64((float64_t*)(dest + dstStride0 + 8), zip2s4567_01);

                vst1q_f64((float64_t*)(dest + 2 * dstStride0), zip1s0123_23);
                vst1q_f64((float64_t*)(dest + 2 * dstStride0 + 8), zip1s4567_23);

                vst1q_f64((float64_t*)(dest + 3 * dstStride0), zip2s0123_23);
                vst1q_f64((float64_t*)(dest + 3 * dstStride0 + 8), zip2s4567_23);

                dest += 16;
                e -= 8;
                source += (8 * srcStride0);
            }

            if (e > 3) {
                auto s0 = vld1q_f32((float*)(source)); // 00112233
                auto s1 = vld1q_f32((float*)(source + srcStride0)); // 00112233
                auto s2 = vld1q_f32((float*)(source + 2 * srcStride0));
                auto s3 = vld1q_f32((float*)(source + 3 * srcStride0));

                auto zip1s01 = vzip1q_f32(s0, s1); // 00001111
                auto zip1s23 = vzip1q_f32(s2, s3); // 00001111

                auto zip2s01 = vzip2q_f32(s0, s1); // 22223333
                auto zip2s23 = vzip2q_f32(s2, s3); // 22223333

                auto zip1s0123_01 = vzip1q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 00000000

                auto zip2s0123_01 = vzip2q_f64((float64x2_t)zip1s01, (float64x2_t)zip1s23); // 11111111

                auto zip1s0123_23 = vzip1q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 22222222

                auto zip2s0123_23 = vzip2q_f64((float64x2_t)zip2s01, (float64x2_t)zip2s23); // 33333333

                vst1q_f64((float64_t*)dest, zip1s0123_01);
                vst1q_f64((float64_t*)(dest + dstStride0), zip2s0123_01);
                vst1q_f64((float64_t*)(dest + 2 * dstStride0), zip1s0123_23);
                vst1q_f64((float64_t*)(dest + 3 * dstStride0), zip2s0123_23);

                dest += 8;
                e -= 4;
                source += (4 * srcStride0);
            }
            while (e > 0) {
                auto s0 = vld1q_f32((float*)(source)); // 00112233

                ((float*)dest)[0] = s0[0];
                ((float*)(dest + dstStride0))[0] = s0[1];
                ((float*)(dest + 2 * dstStride0))[0] = s0[2];
                ((float*)(dest + 3 * dstStride0))[0] = s0[3];

                dest += 2;
                e -= 1;
                source += srcStride0;
            }
            sourceN += (eReal * pack);
            destN += (4 * dstStride0);
        } // l>7

        if (l > 3) {
            auto source = sourceN;
            auto dest = destN;
            l -= 4;
            auto e = eWork;
            if (e == eDest) {
                auto s0 = vld1_f32((float*)(source)); // 0011
                auto s1 = vld1_f32((float*)(source + srcStride0)); // 0011
                auto s2 = vld1_f32((float*)(source + 2 * srcStride0));
                auto s3 = vld1_f32((float*)(source + 3 * srcStride0));

                auto s4 = vld1_f32((float*)(source + 4 * srcStride0));
                auto s5 = vld1_f32((float*)(source + 5 * srcStride0));
                auto s6 = vld1_f32((float*)(source + 6 * srcStride0));
                auto s7 = vld1_f32((float*)(source + 7 * srcStride0));

                auto s8 = vld1_f32((float*)(source + 8 * srcStride0));
                auto s9 = vld1_f32((float*)(source + 9 * srcStride0));
                auto s10 = vld1_f32((float*)(source + 10 * srcStride0));
                auto s11 = vld1_f32((float*)(source + 11 * srcStride0));

                auto s12 = vld1_f32((float*)(source + 12 * srcStride0));
                auto s13 = vld1_f32((float*)(source + 13 * srcStride0));
                auto s14 = vld1_f32((float*)(source + 14 * srcStride0));
                auto s15 = vld1_f32((float*)(source + 15 * srcStride0));

                auto zip1s01 = vzip1_f32(s0, s1); // 0000
                auto zip1s23 = vzip1_f32(s2, s3); // 0000
                auto zip1s45 = vzip1_f32(s4, s5); // 0000
                auto zip1s67 = vzip1_f32(s6, s7); // 0000
                auto zip1s89 = vzip1_f32(s8, s9); // 0000
                auto zip1s1011 = vzip1_f32(s10, s11); // 0000
                auto zip1s1213 = vzip1_f32(s12, s13); // 0000
                auto zip1s1415 = vzip1_f32(s14, s15); // 0000

                auto zip2s01 = vzip2_f32(s0, s1); // 1111
                auto zip2s23 = vzip2_f32(s2, s3); // 1111
                auto zip2s45 = vzip2_f32(s4, s5); // 1111
                auto zip2s67 = vzip2_f32(s6, s7); // 1111
                auto zip2s89 = vzip2_f32(s8, s9); // 1111
                auto zip2s1011 = vzip2_f32(s10, s11); // 1111
                auto zip2s1213 = vzip2_f32(s12, s13); // 1111
                auto zip2s1415 = vzip2_f32(s14, s15); // 1111

                vst1_f32((float32_t*)dest, zip1s01);
                vst1_f32((float32_t*)(dest + 4), zip1s23);
                vst1_f32((float32_t*)(dest + 8), zip1s45);
                vst1_f32((float32_t*)(dest + 12), zip1s67);
                vst1_f32((float32_t*)(dest + 16), zip1s89);
                vst1_f32((float32_t*)(dest + 20), zip1s1011);
                vst1_f32((float32_t*)(dest + 24), zip1s1213);
                vst1_f32((float32_t*)(dest + 28), zip1s1415);

                vst1_f32((float32_t*)(dest + dstStride0), zip2s01);
                vst1_f32((float32_t*)(dest + dstStride0 + 4), zip2s23);
                vst1_f32((float32_t*)(dest + dstStride0 + 8), zip2s45);
                vst1_f32((float32_t*)(dest + dstStride0 + 12), zip2s67);
                vst1_f32((float32_t*)(dest + dstStride0 + 16), zip2s89);
                vst1_f32((float32_t*)(dest + dstStride0 + 20), zip2s1011);
                vst1_f32((float32_t*)(dest + dstStride0 + 24), zip2s1213);
                vst1_f32((float32_t*)(dest + dstStride0 + 28), zip2s1415);

                dest += 32;
                e -= eDest;
            }

            if (e > 11) {
                auto s0 = vld1_f32((float*)(source)); // 0011
                auto s1 = vld1_f32((float*)(source + srcStride0)); // 0011
                auto s2 = vld1_f32((float*)(source + 2 * srcStride0));
                auto s3 = vld1_f32((float*)(source + 3 * srcStride0));

                auto s4 = vld1_f32((float*)(source + 4 * srcStride0));
                auto s5 = vld1_f32((float*)(source + 5 * srcStride0));
                auto s6 = vld1_f32((float*)(source + 6 * srcStride0));
                auto s7 = vld1_f32((float*)(source + 7 * srcStride0));

                auto s8 = vld1_f32((float*)(source + 8 * srcStride0));
                auto s9 = vld1_f32((float*)(source + 9 * srcStride0));
                auto s10 = vld1_f32((float*)(source + 10 * srcStride0));
                auto s11 = vld1_f32((float*)(source + 11 * srcStride0));

                auto zip1s01 = vzip1_f32(s0, s1); // 0000
                auto zip1s23 = vzip1_f32(s2, s3); // 0000
                auto zip1s45 = vzip1_f32(s4, s5); // 0000
                auto zip1s67 = vzip1_f32(s6, s7); // 0000
                auto zip1s89 = vzip1_f32(s8, s9); // 0000
                auto zip1s1011 = vzip1_f32(s10, s11); // 0000

                auto zip2s01 = vzip2_f32(s0, s1); // 1111
                auto zip2s23 = vzip2_f32(s2, s3); // 1111
                auto zip2s45 = vzip2_f32(s4, s5); // 1111
                auto zip2s67 = vzip2_f32(s6, s7); // 1111
                auto zip2s89 = vzip2_f32(s8, s9); // 1111
                auto zip2s1011 = vzip2_f32(s10, s11); // 1111

                vst1_f32((float32_t*)dest, zip1s01);
                vst1_f32((float32_t*)(dest + 4), zip1s23);
                vst1_f32((float32_t*)(dest + 8), zip1s45);
                vst1_f32((float32_t*)(dest + 12), zip1s67);
                vst1_f32((float32_t*)(dest + 16), zip1s89);
                vst1_f32((float32_t*)(dest + 20), zip1s1011);

                vst1_f32((float32_t*)(dest + dstStride0), zip2s01);
                vst1_f32((float32_t*)(dest + dstStride0 + 4), zip2s23);
                vst1_f32((float32_t*)(dest + dstStride0 + 8), zip2s45);
                vst1_f32((float32_t*)(dest + dstStride0 + 12), zip2s67);
                vst1_f32((float32_t*)(dest + dstStride0 + 16), zip2s89);
                vst1_f32((float32_t*)(dest + dstStride0 + 20), zip2s1011);

                dest += 24;
                e -= 12;
                source += (12 * srcStride0);
            }

            if (e > 7) {
                auto s0 = vld1_f32((float*)(source)); // 0011
                auto s1 = vld1_f32((float*)(source + srcStride0)); // 0011
                auto s2 = vld1_f32((float*)(source + 2 * srcStride0));
                auto s3 = vld1_f32((float*)(source + 3 * srcStride0));

                auto s4 = vld1_f32((float*)(source + 4 * srcStride0));
                auto s5 = vld1_f32((float*)(source + 5 * srcStride0));
                auto s6 = vld1_f32((float*)(source + 6 * srcStride0));
                auto s7 = vld1_f32((float*)(source + 7 * srcStride0));

                auto zip1s01 = vzip1_f32(s0, s1); // 0000
                auto zip1s23 = vzip1_f32(s2, s3); // 0000
                auto zip1s45 = vzip1_f32(s4, s5); // 0000
                auto zip1s67 = vzip1_f32(s6, s7); // 0000

                auto zip2s01 = vzip2_f32(s0, s1); // 1111
                auto zip2s23 = vzip2_f32(s2, s3); // 1111
                auto zip2s45 = vzip2_f32(s4, s5); // 1111
                auto zip2s67 = vzip2_f32(s6, s7); // 1111

                vst1_f32((float32_t*)dest, zip1s01);
                vst1_f32((float32_t*)(dest + 4), zip1s23);
                vst1_f32((float32_t*)(dest + 8), zip1s45);
                vst1_f32((float32_t*)(dest + 12), zip1s67);

                vst1_f32((float32_t*)(dest + dstStride0), zip2s01);
                vst1_f32((float32_t*)(dest + dstStride0 + 4), zip2s23);
                vst1_f32((float32_t*)(dest + dstStride0 + 8), zip2s45);
                vst1_f32((float32_t*)(dest + dstStride0 + 12), zip2s67);

                dest += 16;
                e -= 8;
                source += (8 * srcStride0);
            }

            if (e > 3) {
                auto s0 = vld1_f32((float*)(source)); // 0011
                auto s1 = vld1_f32((float*)(source + srcStride0)); // 0011
                auto s2 = vld1_f32((float*)(source + 2 * srcStride0));
                auto s3 = vld1_f32((float*)(source + 3 * srcStride0));

                auto zip1s01 = vzip1_f32(s0, s1); // 0000
                auto zip1s23 = vzip1_f32(s2, s3); // 0000

                auto zip2s01 = vzip2_f32(s0, s1); // 1111
                auto zip2s23 = vzip2_f32(s2, s3); // 1111

                vst1_f32((float32_t*)dest, zip1s01);
                vst1_f32((float32_t*)(dest + 4), zip1s23);

                vst1_f32((float32_t*)(dest + dstStride0), zip2s01);
                vst1_f32((float32_t*)(dest + dstStride0 + 4), zip2s23);

                dest += 8;
                e -= 4;
                source += (4 * srcStride0);
            }
            if (e > 1) {
                auto s0 = vld1_f32((float*)(source)); // 0011
                auto s1 = vld1_f32((float*)(source + srcStride0)); // 0011

                auto zip1s01 = vzip1_f32(s0, s1); // 0000

                auto zip2s01 = vzip2_f32(s0, s1); // 1111

                vst1_f32((float32_t*)dest, zip1s01);

                vst1_f32((float32_t*)(dest + dstStride0), zip2s01);

                dest += 4;
                e -= 2;
                source += (2 * srcStride0);
            }
            if (e > 0) {
                auto s0 = vld1_f32((float*)(source)); // 0011

                ((float*)dest)[0] = s0[0];
                ((float*)(dest + dstStride0))[0] = s0[1];
            }
            sourceN += 4;
            destN += (2 * dstStride0);
        }

        auto source = (FLOAT16*)(sourceGroup[n]);
        auto dest = (FLOAT16*)destOrigin + lOffset * eDest + eOffset * LP;
        if (l > 0) {
            auto e = eWork;
            auto lRemain = lWork - l;
            // if e < eDest, packed A -> [LU, eDest, LP], eDest = eP
            for (int y = 0; y < e; ++y) {
                auto yR = y % eDest;
                for (int x = lRemain; x < lWork; ++x) {
                    auto xR = x % pack;
                    auto xC = x / pack;
                    auto xOut = x / LP;
                    auto xIn = x % LP;
                    dest[xOut * eDest * LP + yR * LP + xIn] = source[xC * eReal * pack + y * pack * offset + xR];
                }
            }
            l--;
        }
    }
}
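
// Implementation note (added annotation): the fp16 rows are deliberately loaded
// through float32 lanes, so each f32 lane carries one (l, l+1) half-precision
// pair; vzip1q_f32 / vzip2q_f32 interleave whole pairs across rows and the f64
// zips then merge pair columns, yielding the [lu, eP, lp] layout (lp = 2)
// without ever shuffling individual 16-bit elements.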
|
|
|
|
#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
static void MNNAttenPackAndScaleSingleHead(float* dst, const float* srcHeadBase, size_t srcRowStride, const float* scale, const int32_t* units, size_t seqLen, size_t headDim) {
    const int32_t eP = units[0];
    const int32_t lP = units[1];

    if (lP != 1 && lP != 2) {
        MNN_ERROR("This function only supports lP=1 or 2\n");
        return;
    }

    const float scaleVal = scale[0];
    const float16x8_t vScale = vdupq_n_f16(scaleVal);

    const size_t packedHeadDim = UP_DIV(headDim, lP);
    const size_t dstStrideDOuter = (size_t)eP * lP;
    const size_t dstStrideSOuter = packedHeadDim * dstStrideDOuter;

    for (int s = 0; s < seqLen; ++s) {
        const int sOuter = s / eP;
        const int sInner = s % eP;
        const FLOAT16* srcRowPtr = (FLOAT16*)srcHeadBase + s * srcRowStride;
        FLOAT16* dstBasePtr = (FLOAT16*)dst + sOuter * dstStrideSOuter + sInner * lP;

        if (lP == 1) {
            size_t d = 0;
            for (; d + 7 < headDim; d += 8) {
                float16x8_t sVec = vld1q_f16(srcRowPtr + d);
                sVec = vmulq_f16(sVec, vScale);

                dstBasePtr[(d + 0) * dstStrideDOuter] = sVec[0];
                dstBasePtr[(d + 1) * dstStrideDOuter] = sVec[1];
                dstBasePtr[(d + 2) * dstStrideDOuter] = sVec[2];
                dstBasePtr[(d + 3) * dstStrideDOuter] = sVec[3];
                dstBasePtr[(d + 4) * dstStrideDOuter] = sVec[4];
                dstBasePtr[(d + 5) * dstStrideDOuter] = sVec[5];
                dstBasePtr[(d + 6) * dstStrideDOuter] = sVec[6];
                dstBasePtr[(d + 7) * dstStrideDOuter] = sVec[7];
            }
            for (; d < headDim; ++d) {
                dstBasePtr[d * dstStrideDOuter] = srcRowPtr[d] * scaleVal;
            }
        } else { // lP == 2
            const FLOAT16* srcDPtr = srcRowPtr;
            FLOAT16* dstDPtr = dstBasePtr;
            size_t dRealSize = headDim;

            while (dRealSize >= 16) {
                float16x8_t s0 = vld1q_f16(srcDPtr);
                float16x8_t s1 = vld1q_f16(srcDPtr + 8);
                s0 = vmulq_f16(s0, vScale);
                s1 = vmulq_f16(s1, vScale);

                float16x4_t lowS0_f16 = vget_low_f16(s0);   // {s0, s1, s2, s3}
                float16x4_t highS0_f16 = vget_high_f16(s0); // {s4, s5, s6, s7}
                uint32x2_t lowS0_u32 = vreinterpret_u32_f16(lowS0_f16);
                uint32x2_t highS0_u32 = vreinterpret_u32_f16(highS0_f16);

                *((uint32_t*)(dstDPtr + 0 * dstStrideDOuter)) = vget_lane_u32(lowS0_u32, 0);  // Store pair {s0, s1}
                *((uint32_t*)(dstDPtr + 1 * dstStrideDOuter)) = vget_lane_u32(lowS0_u32, 1);  // Store pair {s2, s3}
                *((uint32_t*)(dstDPtr + 2 * dstStrideDOuter)) = vget_lane_u32(highS0_u32, 0); // Store pair {s4, s5}
                *((uint32_t*)(dstDPtr + 3 * dstStrideDOuter)) = vget_lane_u32(highS0_u32, 1); // Store pair {s6, s7}

                float16x4_t lowS1_f16 = vget_low_f16(s1);   // {s8, s9, s10, s11}
                float16x4_t highS1_f16 = vget_high_f16(s1); // {s12, s13, s14, s15}
                uint32x2_t lowS1_u32 = vreinterpret_u32_f16(lowS1_f16);
                uint32x2_t highS1_u32 = vreinterpret_u32_f16(highS1_f16);

                *((uint32_t*)(dstDPtr + 4 * dstStrideDOuter)) = vget_lane_u32(lowS1_u32, 0);
                *((uint32_t*)(dstDPtr + 5 * dstStrideDOuter)) = vget_lane_u32(lowS1_u32, 1);
                *((uint32_t*)(dstDPtr + 6 * dstStrideDOuter)) = vget_lane_u32(highS1_u32, 0);
                *((uint32_t*)(dstDPtr + 7 * dstStrideDOuter)) = vget_lane_u32(highS1_u32, 1);

                dRealSize -= 16;
                srcDPtr += 16;
                dstDPtr += 8 * dstStrideDOuter;
            }
            // Remainder loop with padding
            while (dRealSize > 0) {
                if (dRealSize >= 2) {
                    dstDPtr[0] = srcDPtr[0] * scaleVal;
                    dstDPtr[1] = srcDPtr[1] * scaleVal;

                    dRealSize -= 2;
                    srcDPtr += 2;
                    dstDPtr += dstStrideDOuter;
                } else { // dRealSize == 1
                    dstDPtr[0] = srcDPtr[0] * scaleVal;
                    dstDPtr[1] = (FLOAT16)0.0f; // Pad with zero
                    dRealSize = 0;
                }
            }
        }
    }
}
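
// Flash-attention style accumulation across KV blocks. A sketch of what the
// loops below compute per plane element i (accumulated in fp32):
//   idx == 0:              dst = src                    (first KV block)
//   idx > 0:               dst = src + dst * scale[i]   (rescale the old partial sum)
//   idx == kvBlocks - 1:   dst /= normalizeScale[i]     (final softmax normalization)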
static void MNNFlashAttentionUpdateBlockOutput(float* dst, float* src, const float* scale, const float* normalizeScale, int depthQuad, int plane, int pack, int idx, int kvBlocks, int size, int bytes) {
    auto dstPtr = (float16_t*)dst;
    auto srcPtr = (float16_t*)src;
    const auto stride0 = plane * pack;

    if (idx == 0) {
        memcpy(dst, src, size * bytes);
    } else {
        for (int j = 0; j < depthQuad; ++j) {
            const auto baseOffset = j * stride0;
            int i = 0;
            const int plane4 = plane - (plane % 4);
            for (; i < plane4; i += 4) {
                auto pdst0 = dstPtr + baseOffset + (i + 0) * pack;
                auto psrc0 = srcPtr + baseOffset + (i + 0) * pack;
                auto pdst1 = dstPtr + baseOffset + (i + 1) * pack;
                auto psrc1 = srcPtr + baseOffset + (i + 1) * pack;
                auto pdst2 = dstPtr + baseOffset + (i + 2) * pack;
                auto psrc2 = srcPtr + baseOffset + (i + 2) * pack;
                auto pdst3 = dstPtr + baseOffset + (i + 3) * pack;
                auto psrc3 = srcPtr + baseOffset + (i + 3) * pack;

                float16x8_t src0 = vld1q_f16(psrc0);
                float16x8_t dst0 = vld1q_f16(pdst0);
                float16x8_t src1 = vld1q_f16(psrc1);
                float16x8_t dst1 = vld1q_f16(pdst1);
                float16x8_t src2 = vld1q_f16(psrc2);
                float16x8_t dst2 = vld1q_f16(pdst2);
                float16x8_t src3 = vld1q_f16(psrc3);
                float16x8_t dst3 = vld1q_f16(pdst3);

                float32x4_t svec0 = vdupq_n_f32(scale[i + 0]);
                float32x4_t svec1 = vdupq_n_f32(scale[i + 1]);
                float32x4_t svec2 = vdupq_n_f32(scale[i + 2]);
                float32x4_t svec3 = vdupq_n_f32(scale[i + 3]);

                float32x4_t res00 = vfmaq_f32(vcvt_f32_f16(vget_low_f16(src0)), vcvt_f32_f16(vget_low_f16(dst0)), svec0);
                float32x4_t res10 = vfmaq_f32(vcvt_f32_f16(vget_high_f16(src0)), vcvt_f32_f16(vget_high_f16(dst0)), svec0);

                float32x4_t res01 = vfmaq_f32(vcvt_f32_f16(vget_low_f16(src1)), vcvt_f32_f16(vget_low_f16(dst1)), svec1);
                float32x4_t res11 = vfmaq_f32(vcvt_f32_f16(vget_high_f16(src1)), vcvt_f32_f16(vget_high_f16(dst1)), svec1);

                float32x4_t res02 = vfmaq_f32(vcvt_f32_f16(vget_low_f16(src2)), vcvt_f32_f16(vget_low_f16(dst2)), svec2);
                float32x4_t res12 = vfmaq_f32(vcvt_f32_f16(vget_high_f16(src2)), vcvt_f32_f16(vget_high_f16(dst2)), svec2);

                float32x4_t res03 = vfmaq_f32(vcvt_f32_f16(vget_low_f16(src3)), vcvt_f32_f16(vget_low_f16(dst3)), svec3);
                float32x4_t res13 = vfmaq_f32(vcvt_f32_f16(vget_high_f16(src3)), vcvt_f32_f16(vget_high_f16(dst3)), svec3);

                vst1q_f16(pdst0, vcombine_f16(vcvt_f16_f32(res00), vcvt_f16_f32(res10)));
                vst1q_f16(pdst1, vcombine_f16(vcvt_f16_f32(res01), vcvt_f16_f32(res11)));
                vst1q_f16(pdst2, vcombine_f16(vcvt_f16_f32(res02), vcvt_f16_f32(res12)));
                vst1q_f16(pdst3, vcombine_f16(vcvt_f16_f32(res03), vcvt_f16_f32(res13)));
            }

            for (; i < plane; ++i) {
                auto pdst = dstPtr + baseOffset + i * pack;
                auto psrc = srcPtr + baseOffset + i * pack;

                float16x8_t srcF16 = vld1q_f16(psrc);
                float16x8_t dstF16 = vld1q_f16(pdst);
                float32x4_t svec = vdupq_n_f32(scale[i]);

                float32x4_t s0 = vcvt_f32_f16(vget_low_f16(srcF16));
                float32x4_t s1 = vcvt_f32_f16(vget_high_f16(srcF16));
                float32x4_t d0 = vcvt_f32_f16(vget_low_f16(dstF16));
                float32x4_t d1 = vcvt_f32_f16(vget_high_f16(dstF16));

                float32x4_t res0 = vfmaq_f32(s0, d0, svec);
                float32x4_t res1 = vfmaq_f32(s1, d1, svec);

                vst1q_f16(pdst, vcombine_f16(vcvt_f16_f32(res0), vcvt_f16_f32(res1)));
            }
        }
    }

    if (idx == kvBlocks - 1) {
        for (int j = 0; j < depthQuad; ++j) {
            const auto baseOffset = j * stride0;
            int i = 0;
            const int plane4 = plane - (plane % 4);
            for (; i < plane4; i += 4) {
                auto pdst0 = dstPtr + baseOffset + (i + 0) * pack;
                auto pdst1 = dstPtr + baseOffset + (i + 1) * pack;
                auto pdst2 = dstPtr + baseOffset + (i + 2) * pack;
                auto pdst3 = dstPtr + baseOffset + (i + 3) * pack;

                float16x8_t dst0 = vld1q_f16(pdst0);
                float16x8_t dst1 = vld1q_f16(pdst1);
                float16x8_t dst2 = vld1q_f16(pdst2);
                float16x8_t dst3 = vld1q_f16(pdst3);

                float32x4_t ns0 = vdupq_n_f32(1.0f / normalizeScale[i + 0]);
                float32x4_t ns1 = vdupq_n_f32(1.0f / normalizeScale[i + 1]);
                float32x4_t ns2 = vdupq_n_f32(1.0f / normalizeScale[i + 2]);
                float32x4_t ns3 = vdupq_n_f32(1.0f / normalizeScale[i + 3]);

                float32x4_t d00 = vmulq_f32(vcvt_f32_f16(vget_low_f16(dst0)), ns0);
                float32x4_t d10 = vmulq_f32(vcvt_f32_f16(vget_high_f16(dst0)), ns0);
                float32x4_t d01 = vmulq_f32(vcvt_f32_f16(vget_low_f16(dst1)), ns1);
                float32x4_t d11 = vmulq_f32(vcvt_f32_f16(vget_high_f16(dst1)), ns1);
                float32x4_t d02 = vmulq_f32(vcvt_f32_f16(vget_low_f16(dst2)), ns2);
                float32x4_t d12 = vmulq_f32(vcvt_f32_f16(vget_high_f16(dst2)), ns2);
                float32x4_t d03 = vmulq_f32(vcvt_f32_f16(vget_low_f16(dst3)), ns3);
                float32x4_t d13 = vmulq_f32(vcvt_f32_f16(vget_high_f16(dst3)), ns3);

                vst1q_f16(pdst0, vcombine_f16(vcvt_f16_f32(d00), vcvt_f16_f32(d10)));
                vst1q_f16(pdst1, vcombine_f16(vcvt_f16_f32(d01), vcvt_f16_f32(d11)));
                vst1q_f16(pdst2, vcombine_f16(vcvt_f16_f32(d02), vcvt_f16_f32(d12)));
                vst1q_f16(pdst3, vcombine_f16(vcvt_f16_f32(d03), vcvt_f16_f32(d13)));
            }

            for (; i < plane; ++i) {
                auto pdst = dstPtr + baseOffset + i * pack;
                float32x4_t nsvec = vdupq_n_f32(1.0f / normalizeScale[i]);

                float16x8_t dstF16 = vld1q_f16(pdst);
                float32x4_t d0 = vcvt_f32_f16(vget_low_f16(dstF16));
                float32x4_t d1 = vcvt_f32_f16(vget_high_f16(dstF16));

                d0 = vmulq_f32(d0, nsvec);
                d1 = vmulq_f32(d1, nsvec);

                vst1q_f16(pdst, vcombine_f16(vcvt_f16_f32(d0), vcvt_f16_f32(d1)));
            }
        }
    }
}
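
// Unpack the C8-packed fp16 attention output back to a plain (planesize, depth)
// fp32 matrix. Full depth quads are written with two vst1q_f32 per row; the
// tail quad (depth % pack != 0) goes through small stack buffers plus memcpy
// so the stores never run past the end of a row.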
static void MNNAttenUnpackAndConvertFp16(float* dst, float* src, size_t depth, size_t planesize, int pack) {
    // src: (UP_DIV(depth, pack), planesize, pack), float16
    // dst: (planesize, depth), float32
    // pack = 8

    if (planesize == 1) {
        MNNDequantizeFP16((int16_t*)src, dst, depth);
        return; // no need to convert
    }
    const auto depthDiv8 = UP_DIV(depth, pack);
    const auto srcStep = pack * planesize;
    const auto dstStep = depth;

    auto remainDepth = depth % pack;
    auto depthQuad = depthDiv8;
    if (remainDepth > 0) {
        depthQuad -= 1; // last quad is not full
    }

    for (int i = 0; i < depthQuad; ++i) {
        auto realsize = planesize;
        auto srcPtr = (FLOAT16*)src + i * srcStep;
        auto dstPtr = (float*)dst + i * pack;
        while (realsize >= 8) {
            float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack);
            float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack);
            float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack);
            float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack);
            float16x8_t s4_f16 = vld1q_f16(srcPtr + 4 * pack);
            float16x8_t s5_f16 = vld1q_f16(srcPtr + 5 * pack);
            float16x8_t s6_f16 = vld1q_f16(srcPtr + 6 * pack);
            float16x8_t s7_f16 = vld1q_f16(srcPtr + 7 * pack);

            float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16));
            float32x4_t d01_f32 = vcvt_f32_f16(vget_high_f16(s0_f16));
            float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16));
            float32x4_t d11_f32 = vcvt_f32_f16(vget_high_f16(s1_f16));
            float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16));
            float32x4_t d21_f32 = vcvt_f32_f16(vget_high_f16(s2_f16));
            float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16));
            float32x4_t d31_f32 = vcvt_f32_f16(vget_high_f16(s3_f16));
            float32x4_t d40_f32 = vcvt_f32_f16(vget_low_f16(s4_f16));
            float32x4_t d41_f32 = vcvt_f32_f16(vget_high_f16(s4_f16));
            float32x4_t d50_f32 = vcvt_f32_f16(vget_low_f16(s5_f16));
            float32x4_t d51_f32 = vcvt_f32_f16(vget_high_f16(s5_f16));
            float32x4_t d60_f32 = vcvt_f32_f16(vget_low_f16(s6_f16));
            float32x4_t d61_f32 = vcvt_f32_f16(vget_high_f16(s6_f16));
            float32x4_t d70_f32 = vcvt_f32_f16(vget_low_f16(s7_f16));
            float32x4_t d71_f32 = vcvt_f32_f16(vget_high_f16(s7_f16));

            vst1q_f32(dstPtr + 0 * dstStep, d00_f32); vst1q_f32(dstPtr + 0 * dstStep + 4, d01_f32);
            vst1q_f32(dstPtr + 1 * dstStep, d10_f32); vst1q_f32(dstPtr + 1 * dstStep + 4, d11_f32);
            vst1q_f32(dstPtr + 2 * dstStep, d20_f32); vst1q_f32(dstPtr + 2 * dstStep + 4, d21_f32);
            vst1q_f32(dstPtr + 3 * dstStep, d30_f32); vst1q_f32(dstPtr + 3 * dstStep + 4, d31_f32);
            vst1q_f32(dstPtr + 4 * dstStep, d40_f32); vst1q_f32(dstPtr + 4 * dstStep + 4, d41_f32);
            vst1q_f32(dstPtr + 5 * dstStep, d50_f32); vst1q_f32(dstPtr + 5 * dstStep + 4, d51_f32);
            vst1q_f32(dstPtr + 6 * dstStep, d60_f32); vst1q_f32(dstPtr + 6 * dstStep + 4, d61_f32);
            vst1q_f32(dstPtr + 7 * dstStep, d70_f32); vst1q_f32(dstPtr + 7 * dstStep + 4, d71_f32);

            srcPtr += 8 * pack;
            dstPtr += 8 * dstStep;
            realsize -= 8;
        }
        if (realsize >= 4) {
            float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack);
            float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack);
            float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack);
            float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack);

            float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16));
            float32x4_t d01_f32 = vcvt_f32_f16(vget_high_f16(s0_f16));
            float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16));
            float32x4_t d11_f32 = vcvt_f32_f16(vget_high_f16(s1_f16));
            float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16));
            float32x4_t d21_f32 = vcvt_f32_f16(vget_high_f16(s2_f16));
            float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16));
            float32x4_t d31_f32 = vcvt_f32_f16(vget_high_f16(s3_f16));

            vst1q_f32(dstPtr + 0 * dstStep, d00_f32); vst1q_f32(dstPtr + 0 * dstStep + 4, d01_f32);
            vst1q_f32(dstPtr + 1 * dstStep, d10_f32); vst1q_f32(dstPtr + 1 * dstStep + 4, d11_f32);
            vst1q_f32(dstPtr + 2 * dstStep, d20_f32); vst1q_f32(dstPtr + 2 * dstStep + 4, d21_f32);
            vst1q_f32(dstPtr + 3 * dstStep, d30_f32); vst1q_f32(dstPtr + 3 * dstStep + 4, d31_f32);

            srcPtr += 4 * pack;
            dstPtr += 4 * dstStep;
            realsize -= 4;
        }
        while (realsize > 0) {
            auto s0_fp16 = vld1q_f16(srcPtr);
            auto s00_fp32 = vcvt_f32_f16(vget_low_f16(s0_fp16));
            auto s01_fp32 = vcvt_f32_f16(vget_high_f16(s0_fp16));
            vst1q_f32(dstPtr, s00_fp32);
            vst1q_f32(dstPtr + 4, s01_fp32);
            srcPtr += pack;
            dstPtr += dstStep;
            realsize--;
        }
    }

    // process remain depth < 8
    if (remainDepth >= 4) {
        auto realsize = planesize;
        auto srcPtr = (FLOAT16*)src + (depthDiv8 - 1) * srcStep;
        auto dstPtr = (float*)dst + (depthDiv8 - 1) * pack;
        auto extraDepth = remainDepth - 4;

        float tmp0[4];
        float tmp1[4];
        float tmp2[4];
        float tmp3[4];
        float tmp4[4];
        float tmp5[4];
        float tmp6[4];
        float tmp7[4];

        while (realsize >= 8) {
            float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack);
            float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack);
            float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack);
            float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack);
            float16x8_t s4_f16 = vld1q_f16(srcPtr + 4 * pack);
            float16x8_t s5_f16 = vld1q_f16(srcPtr + 5 * pack);
            float16x8_t s6_f16 = vld1q_f16(srcPtr + 6 * pack);
            float16x8_t s7_f16 = vld1q_f16(srcPtr + 7 * pack);

            float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16));
            float32x4_t d01_f32 = vcvt_f32_f16(vget_high_f16(s0_f16));
            float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16));
            float32x4_t d11_f32 = vcvt_f32_f16(vget_high_f16(s1_f16));
            float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16));
            float32x4_t d21_f32 = vcvt_f32_f16(vget_high_f16(s2_f16));
            float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16));
            float32x4_t d31_f32 = vcvt_f32_f16(vget_high_f16(s3_f16));
            float32x4_t d40_f32 = vcvt_f32_f16(vget_low_f16(s4_f16));
            float32x4_t d41_f32 = vcvt_f32_f16(vget_high_f16(s4_f16));
            float32x4_t d50_f32 = vcvt_f32_f16(vget_low_f16(s5_f16));
            float32x4_t d51_f32 = vcvt_f32_f16(vget_high_f16(s5_f16));
            float32x4_t d60_f32 = vcvt_f32_f16(vget_low_f16(s6_f16));
            float32x4_t d61_f32 = vcvt_f32_f16(vget_high_f16(s6_f16));
            float32x4_t d70_f32 = vcvt_f32_f16(vget_low_f16(s7_f16));
            float32x4_t d71_f32 = vcvt_f32_f16(vget_high_f16(s7_f16));

            vst1q_f32(dstPtr + 0 * dstStep, d00_f32); vst1q_f32(tmp0, d01_f32);
            vst1q_f32(dstPtr + 1 * dstStep, d10_f32); vst1q_f32(tmp1, d11_f32);
            vst1q_f32(dstPtr + 2 * dstStep, d20_f32); vst1q_f32(tmp2, d21_f32);
            vst1q_f32(dstPtr + 3 * dstStep, d30_f32); vst1q_f32(tmp3, d31_f32);
            vst1q_f32(dstPtr + 4 * dstStep, d40_f32); vst1q_f32(tmp4, d41_f32);
            vst1q_f32(dstPtr + 5 * dstStep, d50_f32); vst1q_f32(tmp5, d51_f32);
            vst1q_f32(dstPtr + 6 * dstStep, d60_f32); vst1q_f32(tmp6, d61_f32);
            vst1q_f32(dstPtr + 7 * dstStep, d70_f32); vst1q_f32(tmp7, d71_f32);

            memcpy(dstPtr + 0 * dstStep + 4, tmp0, sizeof(float) * extraDepth);
            memcpy(dstPtr + 1 * dstStep + 4, tmp1, sizeof(float) * extraDepth);
            memcpy(dstPtr + 2 * dstStep + 4, tmp2, sizeof(float) * extraDepth);
            memcpy(dstPtr + 3 * dstStep + 4, tmp3, sizeof(float) * extraDepth);
            memcpy(dstPtr + 4 * dstStep + 4, tmp4, sizeof(float) * extraDepth);
            memcpy(dstPtr + 5 * dstStep + 4, tmp5, sizeof(float) * extraDepth);
            memcpy(dstPtr + 6 * dstStep + 4, tmp6, sizeof(float) * extraDepth);
            memcpy(dstPtr + 7 * dstStep + 4, tmp7, sizeof(float) * extraDepth);

            srcPtr += 8 * pack;
            dstPtr += 8 * dstStep;
            realsize -= 8;
        }
        if (realsize >= 4) {
            float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack);
            float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack);
            float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack);
            float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack);

            float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16));
            float32x4_t d01_f32 = vcvt_f32_f16(vget_high_f16(s0_f16));
            float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16));
            float32x4_t d11_f32 = vcvt_f32_f16(vget_high_f16(s1_f16));
            float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16));
            float32x4_t d21_f32 = vcvt_f32_f16(vget_high_f16(s2_f16));
            float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16));
            float32x4_t d31_f32 = vcvt_f32_f16(vget_high_f16(s3_f16));

            vst1q_f32(dstPtr + 0 * dstStep, d00_f32); vst1q_f32(tmp0, d01_f32);
            vst1q_f32(dstPtr + 1 * dstStep, d10_f32); vst1q_f32(tmp1, d11_f32);
            vst1q_f32(dstPtr + 2 * dstStep, d20_f32); vst1q_f32(tmp2, d21_f32);
            vst1q_f32(dstPtr + 3 * dstStep, d30_f32); vst1q_f32(tmp3, d31_f32);

            memcpy(dstPtr + 0 * dstStep + 4, tmp0, sizeof(float) * extraDepth);
            memcpy(dstPtr + 1 * dstStep + 4, tmp1, sizeof(float) * extraDepth);
            memcpy(dstPtr + 2 * dstStep + 4, tmp2, sizeof(float) * extraDepth);
            memcpy(dstPtr + 3 * dstStep + 4, tmp3, sizeof(float) * extraDepth);

            srcPtr += 4 * pack;
            dstPtr += 4 * dstStep;
            realsize -= 4;
        }
        while (realsize > 0) {
            auto s0_fp16 = vld1q_f16(srcPtr);
            auto d00_fp32 = vcvt_f32_f16(vget_low_f16(s0_fp16));
            auto d01_fp32 = vcvt_f32_f16(vget_high_f16(s0_fp16));
            vst1q_f32(dstPtr, d00_fp32);
            vst1q_f32(tmp0, d01_fp32);
            memcpy(dstPtr + 4, tmp0, sizeof(float) * extraDepth);
            srcPtr += pack;
            dstPtr += dstStep;
            realsize--;
        }
    }

    if (remainDepth > 0 && remainDepth < 4) {
        auto realsize = planesize;
        auto srcPtr = (FLOAT16*)src + (depthDiv8 - 1) * srcStep;
        auto dstPtr = (float*)dst + (depthDiv8 - 1) * pack;

        float tmp0[4];
        float tmp1[4];
        float tmp2[4];
        float tmp3[4];
        float tmp4[4];
        float tmp5[4];
        float tmp6[4];
        float tmp7[4];

        while (realsize >= 8) {
            float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack);
            float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack);
            float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack);
            float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack);
            float16x8_t s4_f16 = vld1q_f16(srcPtr + 4 * pack);
            float16x8_t s5_f16 = vld1q_f16(srcPtr + 5 * pack);
            float16x8_t s6_f16 = vld1q_f16(srcPtr + 6 * pack);
            float16x8_t s7_f16 = vld1q_f16(srcPtr + 7 * pack);

            float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16));
            float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16));
            float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16));
            float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16));
            float32x4_t d40_f32 = vcvt_f32_f16(vget_low_f16(s4_f16));
            float32x4_t d50_f32 = vcvt_f32_f16(vget_low_f16(s5_f16));
            float32x4_t d60_f32 = vcvt_f32_f16(vget_low_f16(s6_f16));
            float32x4_t d70_f32 = vcvt_f32_f16(vget_low_f16(s7_f16));

            vst1q_f32(tmp0, d00_f32);
            vst1q_f32(tmp1, d10_f32);
            vst1q_f32(tmp2, d20_f32);
            vst1q_f32(tmp3, d30_f32);
            vst1q_f32(tmp4, d40_f32);
            vst1q_f32(tmp5, d50_f32);
            vst1q_f32(tmp6, d60_f32);
            vst1q_f32(tmp7, d70_f32);

            memcpy(dstPtr + 0 * dstStep, tmp0, sizeof(float) * remainDepth);
            memcpy(dstPtr + 1 * dstStep, tmp1, sizeof(float) * remainDepth);
            memcpy(dstPtr + 2 * dstStep, tmp2, sizeof(float) * remainDepth);
            memcpy(dstPtr + 3 * dstStep, tmp3, sizeof(float) * remainDepth);
            memcpy(dstPtr + 4 * dstStep, tmp4, sizeof(float) * remainDepth);
            memcpy(dstPtr + 5 * dstStep, tmp5, sizeof(float) * remainDepth);
            memcpy(dstPtr + 6 * dstStep, tmp6, sizeof(float) * remainDepth);
            memcpy(dstPtr + 7 * dstStep, tmp7, sizeof(float) * remainDepth);

            srcPtr += 8 * pack;
            dstPtr += 8 * dstStep;
            realsize -= 8;
        }
        if (realsize >= 4) {
            float16x8_t s0_f16 = vld1q_f16(srcPtr + 0 * pack);
            float16x8_t s1_f16 = vld1q_f16(srcPtr + 1 * pack);
            float16x8_t s2_f16 = vld1q_f16(srcPtr + 2 * pack);
            float16x8_t s3_f16 = vld1q_f16(srcPtr + 3 * pack);

            float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16));
            float32x4_t d10_f32 = vcvt_f32_f16(vget_low_f16(s1_f16));
            float32x4_t d20_f32 = vcvt_f32_f16(vget_low_f16(s2_f16));
            float32x4_t d30_f32 = vcvt_f32_f16(vget_low_f16(s3_f16));

            vst1q_f32(tmp0, d00_f32);
            vst1q_f32(tmp1, d10_f32);
            vst1q_f32(tmp2, d20_f32);
            vst1q_f32(tmp3, d30_f32);

            memcpy(dstPtr + 0 * dstStep, tmp0, sizeof(float) * remainDepth);
            memcpy(dstPtr + 1 * dstStep, tmp1, sizeof(float) * remainDepth);
            memcpy(dstPtr + 2 * dstStep, tmp2, sizeof(float) * remainDepth);
            memcpy(dstPtr + 3 * dstStep, tmp3, sizeof(float) * remainDepth);

            srcPtr += 4 * pack;
            dstPtr += 4 * dstStep;
            realsize -= 4;
        }
        while (realsize > 0) {
            auto s0_f16 = vld1q_f16(srcPtr);
            float32x4_t d00_f32 = vcvt_f32_f16(vget_low_f16(s0_f16));
            vst1q_f32(tmp0, d00_f32);
            memcpy(dstPtr + 0 * dstStep, tmp0, sizeof(float) * remainDepth);
            srcPtr += pack;
            dstPtr += dstStep;
            realsize--;
        }
    }
}
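
// lP == 1 case: each fp32 element is converted to fp16 and scattered into the
// packed layout with stride eP between consecutive plane positions, hence the
// lane-by-lane vst1_lane_f16 stores below instead of contiguous vector stores.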
static void MNNAttenPackAndConvertFp32LP1(float* dst, const float* src, const int32_t* units, size_t depth, size_t planesize) {
    int32_t eP = units[0];
    int32_t lP = units[1];

    if (lP != 1) {
        MNN_ERROR("This function only supports lP=1\n");
        return;
    }

    auto dstStride1 = eP;
    auto dstStride0 = planesize * dstStride1;

    for (int i = 0; i < depth; ++i) {
        size_t realsize = planesize;
        const float* srcPtr = src + i * planesize;
        FLOAT16* dstPtr = (FLOAT16*)dst + (i % eP) + (i / eP) * dstStride0;

        while (realsize >= 16) {
            float32x4_t s0_f32 = vld1q_f32(srcPtr);
            float32x4_t s1_f32 = vld1q_f32(srcPtr + 4);
            float32x4_t s2_f32 = vld1q_f32(srcPtr + 8);
            float32x4_t s3_f32 = vld1q_f32(srcPtr + 12);

            float16x4_t d0_f16 = vcvt_f16_f32(s0_f32);
            float16x4_t d1_f16 = vcvt_f16_f32(s1_f32);
            float16x4_t d2_f16 = vcvt_f16_f32(s2_f32);
            float16x4_t d3_f16 = vcvt_f16_f32(s3_f32);

            vst1_lane_f16(dstPtr, d0_f16, 0);
            vst1_lane_f16(dstPtr + dstStride1, d0_f16, 1);
            vst1_lane_f16(dstPtr + 2 * dstStride1, d0_f16, 2);
            vst1_lane_f16(dstPtr + 3 * dstStride1, d0_f16, 3);

            vst1_lane_f16(dstPtr + 4 * dstStride1, d1_f16, 0);
            vst1_lane_f16(dstPtr + 5 * dstStride1, d1_f16, 1);
            vst1_lane_f16(dstPtr + 6 * dstStride1, d1_f16, 2);
            vst1_lane_f16(dstPtr + 7 * dstStride1, d1_f16, 3);

            vst1_lane_f16(dstPtr + 8 * dstStride1, d2_f16, 0);
            vst1_lane_f16(dstPtr + 9 * dstStride1, d2_f16, 1);
            vst1_lane_f16(dstPtr + 10 * dstStride1, d2_f16, 2);
            vst1_lane_f16(dstPtr + 11 * dstStride1, d2_f16, 3);

            vst1_lane_f16(dstPtr + 12 * dstStride1, d3_f16, 0);
            vst1_lane_f16(dstPtr + 13 * dstStride1, d3_f16, 1);
            vst1_lane_f16(dstPtr + 14 * dstStride1, d3_f16, 2);
            vst1_lane_f16(dstPtr + 15 * dstStride1, d3_f16, 3);

            srcPtr += 16;
            dstPtr += 16 * dstStride1;
            realsize -= 16;
        }

        if (realsize >= 8) {
            float32x4_t s0_f32 = vld1q_f32(srcPtr);
            float32x4_t s1_f32 = vld1q_f32(srcPtr + 4);

            float16x4_t d0_f16 = vcvt_f16_f32(s0_f32);
            float16x4_t d1_f16 = vcvt_f16_f32(s1_f32);

            vst1_lane_f16(dstPtr, d0_f16, 0);
            vst1_lane_f16(dstPtr + dstStride1, d0_f16, 1);
            vst1_lane_f16(dstPtr + 2 * dstStride1, d0_f16, 2);
            vst1_lane_f16(dstPtr + 3 * dstStride1, d0_f16, 3);

            vst1_lane_f16(dstPtr + 4 * dstStride1, d1_f16, 0);
            vst1_lane_f16(dstPtr + 5 * dstStride1, d1_f16, 1);
            vst1_lane_f16(dstPtr + 6 * dstStride1, d1_f16, 2);
            vst1_lane_f16(dstPtr + 7 * dstStride1, d1_f16, 3);

            srcPtr += 8;
            dstPtr += 8 * dstStride1;
            realsize -= 8;
        }

        if (realsize >= 4) {
            float32x4_t s0_f32 = vld1q_f32(srcPtr);
            float16x4_t d0_f16 = vcvt_f16_f32(s0_f32);

            vst1_lane_f16(dstPtr, d0_f16, 0);
            vst1_lane_f16(dstPtr + dstStride1, d0_f16, 1);
            vst1_lane_f16(dstPtr + 2 * dstStride1, d0_f16, 2);
            vst1_lane_f16(dstPtr + 3 * dstStride1, d0_f16, 3);

            srcPtr += 4;
            dstPtr += 4 * dstStride1;
            realsize -= 4;
        }

        for (; realsize > 0; --realsize) {
            *dstPtr = (FLOAT16)(*srcPtr);
            srcPtr++;
            dstPtr += dstStride1;
        }
    }
}
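
// General entry: src is (depth, planesize) fp32, dst is packed
// [UP_DIV(depth, eP), UP_DIV(planesize, lP), eP, lP] fp16. For lP == 2 the
// loops below store each adjacent fp16 pair as a single 32-bit lane write;
// an odd tail element is zero-padded so the last pair stays complete.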
static void MNNAttenPackAndConvertFp32(float* dst, float* src, const int32_t* units, size_t depth, size_t planesize) {
    int32_t eP = units[0];
    int32_t lP = units[1]; // Now lP=1 or 2

    if (lP != 1 && lP != 2) {
        MNN_ERROR("This function only supports lP=1 or 2\n");
        return;
    }

    // src [depth, planesize] (float32)
    // dst [depth/eP, planesize/lP, eP, lP] (float16)

    if (lP == 1) {
        MNNAttenPackAndConvertFp32LP1(dst, src, units, depth, planesize);
        return;
    }

    auto dstStride1 = eP * lP;
    auto dstStride0 = UP_DIV(planesize, lP) * dstStride1;

    for (int i = 0; i < depth; ++i) {
        size_t realsize = planesize;
        const float* srcPtr = src + i * planesize;
        FLOAT16* dstPtr = (FLOAT16*)dst + (i % eP) * lP + (i / eP) * dstStride0;

        while (realsize >= 16) {
            float32x4_t s0 = vld1q_f32(srcPtr);
            float32x4_t s1 = vld1q_f32(srcPtr + 4);
            float32x4_t s2 = vld1q_f32(srcPtr + 8);
            float32x4_t s3 = vld1q_f32(srcPtr + 12);

            float16x4_t h0 = vcvt_f16_f32(s0);
            float16x4_t h1 = vcvt_f16_f32(s1);
            float16x4_t h2 = vcvt_f16_f32(s2);
            float16x4_t h3 = vcvt_f16_f32(s3);

            vst1_lane_u32((uint32_t*)dstPtr, vreinterpret_u32_f16(h0), 0);
            vst1_lane_u32((uint32_t*)(dstPtr + dstStride1), vreinterpret_u32_f16(h0), 1);

            vst1_lane_u32((uint32_t*)(dstPtr + 2 * dstStride1), vreinterpret_u32_f16(h1), 0);
            vst1_lane_u32((uint32_t*)(dstPtr + 3 * dstStride1), vreinterpret_u32_f16(h1), 1);

            vst1_lane_u32((uint32_t*)(dstPtr + 4 * dstStride1), vreinterpret_u32_f16(h2), 0);
            vst1_lane_u32((uint32_t*)(dstPtr + 5 * dstStride1), vreinterpret_u32_f16(h2), 1);

            vst1_lane_u32((uint32_t*)(dstPtr + 6 * dstStride1), vreinterpret_u32_f16(h3), 0);
            vst1_lane_u32((uint32_t*)(dstPtr + 7 * dstStride1), vreinterpret_u32_f16(h3), 1);

            realsize -= 16;
            srcPtr += 16;
            dstPtr += 8 * dstStride1;
        }

        if (realsize >= 8) {
            float32x4_t s0 = vld1q_f32(srcPtr);
            float32x4_t s1 = vld1q_f32(srcPtr + 4);

            float16x4_t h0 = vcvt_f16_f32(s0);
            float16x4_t h1 = vcvt_f16_f32(s1);

            vst1_lane_u32((uint32_t*)dstPtr, vreinterpret_u32_f16(h0), 0);
            vst1_lane_u32((uint32_t*)(dstPtr + dstStride1), vreinterpret_u32_f16(h0), 1);

            vst1_lane_u32((uint32_t*)(dstPtr + 2 * dstStride1), vreinterpret_u32_f16(h1), 0);
            vst1_lane_u32((uint32_t*)(dstPtr + 3 * dstStride1), vreinterpret_u32_f16(h1), 1);

            realsize -= 8;
            srcPtr += 8;
            dstPtr += 4 * dstStride1;
        }

        if (realsize >= 4) {
            float32x4_t s0 = vld1q_f32(srcPtr);
            float16x4_t h0 = vcvt_f16_f32(s0);

            vst1_lane_u32((uint32_t*)dstPtr, vreinterpret_u32_f16(h0), 0);
            vst1_lane_u32((uint32_t*)(dstPtr + dstStride1), vreinterpret_u32_f16(h0), 1);

            realsize -= 4;
            srcPtr += 4;
            dstPtr += 2 * dstStride1;
        }

        if (realsize >= 2) {
            float32x2_t s0 = vld1_f32(srcPtr);
            float16x4_t h0 = vcvt_f16_f32(vcombine_f32(s0, s0));

            vst1_lane_u32((uint32_t*)dstPtr, vreinterpret_u32_f16(h0), 0);

            realsize -= 2;
            srcPtr += 2;
            dstPtr += dstStride1;
        }

        if (realsize > 0) {
            dstPtr[0] = (FLOAT16)srcPtr[0];
            dstPtr[1] = (FLOAT16)0.0f;
        }
    }
}
#endif // MNN_SUPPORT_TRANSFORMER_FUSE
#ifdef MNN_LOW_MEMORY
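// Per-column absolute maximum over a (src_depth_quad, realSize, pack) fp16
// tensor. pack == 4 / pack == 8 dispatch to the assembly kernels declared at
// the top of this file; the scalar loops below are the generic fallback.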
void MNNAbsMaxFP16(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack) {
    if (pack == 4) {
        MNNAbsMaxFP16_Pack4(source, absmax, src_depth_quad, realSize, pack);
        return;
    }
    if (pack == 8) {
        MNNAbsMaxFP16_Pack8(source, absmax, src_depth_quad, realSize, pack);
        return;
    }
    // source: (src_depth_quad, realSize, pack)
    auto srcStep = pack * realSize;
    auto srcPtr = (FLOAT16*)source;
    auto dstPtr = (FLOAT16*)absmax;
    for (int i = 0; i < realSize; ++i) {
        FLOAT16 absmaxVal = 0; // absmaxVal >= 0
        for (int c = 0; c < src_depth_quad; ++c) {
            auto src = srcPtr + c * srcStep + i * pack;
            for (int k = 0; k < pack; ++k) {
                if (std::abs(src[k]) > absmaxVal) {
                    absmaxVal = std::abs(src[k]);
                }
            }
        }
        dstPtr[i] = absmaxVal;
    }
    return;
}
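
// Symmetric per-column quantization: dst[k] = round(src[k] * scale[i]).
// Note that the bias argument is not applied on this path: the Pack4/Pack8
// assembly kernels receive nullptr and the scalar fallback ignores it as well.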
static void MNNDynamicQuantFP16(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, int pack, const float* bias = nullptr) {
    if (pack == 8) {
        MNNDynamicQuantFP16_Pack8(src, dst, scale, src_depth_quad, realSize, nullptr, pack);
        return;
    }
    if (pack == 4) {
        MNNDynamicQuantFP16_Pack4(src, dst, scale, src_depth_quad, realSize, nullptr, pack);
        return;
    }
    int8_t* dstPtr = dst;
    auto srcPtr = (FLOAT16*)src;

    for (int i = 0; i < realSize; ++i) {
        auto scaleVal = static_cast<FLOAT16>(scale[i]);
        for (int c = 0; c < src_depth_quad; ++c) {
            auto srcZ = srcPtr + c * pack * realSize + i * pack;
            auto dstZ = dstPtr + c * pack * realSize + i * pack;
            for (int k = 0; k < pack; ++k) {
                int val = (int)roundf(srcZ[k] * scaleVal);
                dstZ[k] = val;
            }
        }
    }
    return;
}
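
// Asymmetric per-column quantization: dst = round(src * qscale + qbias).
// LP == 4 / LP == 8 reuse the dynamic-quant assembly kernels block by block;
// other layouts take the scalar loops below.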
static void MNNAsyQuantFunc_Arm82(int8_t* dst, const float* src, float* qscale, float* qbias, const size_t* info) {
    // input shape: [kernelsize, blockNum, blockLU, EP, LP]
    // qscale & qbias: [blockNum, EP]
    auto blockNum = info[0];
    auto EP = info[1];        // real area for data
    auto LP = info[2];        // innermost data layout, may come from backend's pack or gemmint8 units' SRC_UNIT
    auto DST_XUNIT = info[3]; // backend gemmint8 units
    auto SRC_UNIT = info[4];
    auto kernelsize = info[5];
    auto blockLU = info[6];
    auto stride0 = blockNum * blockLU * EP * LP;
    auto stride1 = blockLU * EP * LP;
    auto srcPtr = (FLOAT16*)src;
#ifdef __aarch64__
    if (LP == 4 || LP == 8) {
        for (int k = 0; k < kernelsize; ++k) {
            for (int i = 0; i < blockNum; ++i) {
                if (LP == 4) {
                    MNNDynamicQuantFP16_Pack4((float*)(srcPtr + k * stride0 + i * stride1), dst + k * stride0 + i * stride1, qscale + i * EP, blockLU, EP, qbias + i * EP, LP);
                }
                if (LP == 8) {
                    MNNDynamicQuantFP16_Pack8((float*)(srcPtr + k * stride0 + i * stride1), dst + k * stride0 + i * stride1, qscale + i * EP, blockLU, EP, qbias + i * EP, LP);
                }
            }
        }
        return;
    }
#endif
    for (int i = 0; i < EP; ++i) {
        for (int bk = 0; bk < blockNum; ++bk) {
            float quant_scale = qscale[i + bk * EP];
            float quant_bias = qbias[i + bk * EP];
            for (int n = 0; n < kernelsize; ++n) {
                for (int k = 0; k < blockLU; ++k) {
                    for (int j = 0; j < LP; ++j) {
                        int dataIndx = n * stride0 + bk * stride1 + k * EP * LP + i * LP + j;
                        auto data_ = static_cast<float>(srcPtr[dataIndx]);
                        int qval = static_cast<int32_t>(roundf(data_ * quant_scale + quant_bias));
                        dst[dataIndx] = qval;
                    }
                }
            }
        }
    }
}
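
// Derive asymmetric int8 quantization parameters from per-block min/max.
// For the single-scale case (info[7] == 1) the mapping is, with range = max - min:
//   qscale = 255 / range,  qbias = round(-min * 255 / range) - 128
//   scale  = range / 255,  bias  = -qbias * scale   (so x ~= q * scale + bias)
// Degenerate ranges (range <= 1e-7) collapse to scale = qscale = 0, bias = max.
// Per-block paths delegate to the MNNAsyLocalQuantInfo_EP*_FP16 kernels on
// aarch64 and fall back to the scalar loops below on aarch32.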
static void MNNAsyQuantInfo_FP16(float* scale, float* bias, float* qscale, float* qbias, float* dstMin, float* dstMax, float* src, const size_t* info) {
    auto blockNum = info[0];
    auto plane = info[1];     // real area for data
    auto innerSide = info[2]; // innermost data layout, may come from backend's pack or gemmint8 units' SRC_UNIT
    auto DST_XUNIT = info[3];
    auto kernelsize = info[5];
    auto blockLU = info[6];
    auto stride0 = blockNum * blockLU * plane * innerSide;
    auto stride1 = blockLU * plane * innerSide;
    auto srcPtr = (FLOAT16*)src;

    // input shape: [kernelsize, blockNum, blockLU, DST_XUNIT, SRC_UNIT] or [ic/core->pack, plane, core->pack]
    // dequant scale/bias: [EU, blockNum, step]
    // quant scale/bias: [blockNum, plane]
    if (info[7] == 1) { // scale&bias: [1]
        ARM82CountMinMaxValue(src, dstMin, dstMax, kernelsize * stride0);
        float maxval = *(FLOAT16*)dstMax;
        float minval = *(FLOAT16*)dstMin;
        if (info[8] == 1 && (maxval - minval) > 1e-7) {
            if (minval > 0.f) {
                minval = 0.f;
            } else if (maxval < 0.f) {
                maxval = 0.f;
            }
        }
        auto range = maxval - minval;
        if (range <= 1e-7) {
            scale[0] = 0.f;
            qscale[0] = 0.f;
            qbias[0] = 0.f;
            bias[0] = maxval;
        } else {
            qscale[0] = 255.f / range;
            scale[0] = range / 255.f;
            qbias[0] = roundf(-minval * 255.f / range) - 128.f;
            bias[0] = -qbias[0] * scale[0];
        }
        return;
    }

#ifdef __aarch64__
    if (DST_XUNIT == 12 || DST_XUNIT == 16) { // Arm82/SME2, fp16: core->pack=8, SRC_UNIT=4
        // max,min shape: [blockNum, EP]
        if (innerSide == 4) {
            for (int i = 0; i < kernelsize; ++i) {
                MNNLocalMinMaxFP16_Pack4(dstMin, dstMax, (float*)(srcPtr + i * stride0), blockNum, blockLU, plane, innerSide, i);
            }
        }
        if (innerSide == 8) {
            for (int i = 0; i < kernelsize; ++i) {
                MNNLocalMinMaxFP16_Pack8(dstMin, dstMax, (float*)(srcPtr + i * stride0), blockNum, blockLU, plane, innerSide, i);
            }
        }
        // scale, bias
        if (DST_XUNIT == 12) {
            auto success = MNNAsyLocalQuantInfo_EP12_FP16(scale, bias, qscale, qbias, dstMin, dstMax, info);
            if (!success) {
                MNN_ERROR("Call error: MNNAsyLocalQuantInfo_EP12_FP16\n");
                return;
            }
            return;
        }
        if (DST_XUNIT == 16) {
            auto success = MNNAsyLocalQuantInfo_EP16_FP16(scale, bias, qscale, qbias, dstMin, dstMax, info);
            if (!success) {
                MNN_ERROR("Call error: MNNAsyLocalQuantInfo_EP16_FP16\n");
                return;
            }
            return;
        }
    }
    if (DST_XUNIT == 10 && innerSide == 8) { // Arm86, fp16: core->pack=8, SRC_UNIT=8
        // max,min shape: [blockNum, plane]
        for (int i = 0; i < kernelsize; ++i) {
            MNNLocalMinMaxFP16_Pack8(dstMin, dstMax, (float*)(srcPtr + i * stride0), blockNum, blockLU, plane, innerSide, i);
        }
        // scale, bias
        auto success = MNNAsyLocalQuantInfo_EP10_FP16(scale, bias, qscale, qbias, dstMin, dstMax, info);
        if (!success) {
            MNN_ERROR("Call error: MNNAsyLocalQuantInfo_EP10_FP16\n");
            return;
        }
        return;
    }
#else
    // aarch32
    // max,min shape: [blockNum, plane]
    auto minPtr = (FLOAT16*)dstMin;
    auto maxPtr = (FLOAT16*)dstMax;
    for (int i = 0; i < plane; ++i) {
        for (int bk = 0; bk < blockNum; ++bk) {
            auto idx0 = i * innerSide + bk * stride1;
            auto max_ = srcPtr[idx0];
            auto min_ = max_;
            for (int n = 0; n < kernelsize; ++n) {
                for (int k = 0; k < blockLU; ++k) {
                    for (int j = 0; j < innerSide; ++j) {
                        auto dataIndx = idx0 + n * stride0 + k * (plane * innerSide) + j;
                        auto data_ = srcPtr[dataIndx];
                        max_ = ALIMAX(max_, data_);
                        min_ = ALIMIN(min_, data_);
                    }
                }
            }
            auto sindx = i + bk * plane;
            minPtr[sindx] = min_;
            maxPtr[sindx] = max_;
        }
    }
    // scale, bias
    for (int i = 0; i < plane; ++i) {
        auto step = ALIMIN(DST_XUNIT, plane - (i / DST_XUNIT) * DST_XUNIT);
        auto sind0 = (i / DST_XUNIT) * DST_XUNIT * blockNum + (i % DST_XUNIT);
        for (int k = 0; k < blockNum; ++k) {
            auto sind = sind0 + k * step;
            auto qind = i + k * plane;
            auto max_ = (float)maxPtr[qind];
            auto min_ = (float)minPtr[qind];
            auto range = max_ - min_;
            if (fabs(range) < 1e-7) {
                qscale[qind] = 0.f;
                qbias[qind] = 0.f;
                scale[sind] = 0.f;
                bias[sind] = max_;
            } else {
                qscale[qind] = 255.f / range;
                qbias[qind] = -min_ * 255.f / range - 128.0f;
                scale[sind] = range / 255.f;
                bias[sind] = min_ + (128.f / 255.f) * range;
            }
        }
    }
#endif
}

#endif // MNN_LOW_MEMORY

static CoreFunctions* gInstance = nullptr;
static CoreInt8Functions* gArm82CoreInt8Functions = nullptr;
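
// Build the fp16 CoreFunctions table: copy entries that are shared with the
// fp32 core where behaviour is identical, then override the pack- and
// byte-size-dependent ones with the FLOAT16 (pack = 8, bytes = 2)
// implementations from this file.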
bool Arm82Functions::init() {
    using Vec = MNN::Math::Vec<FLOAT16, 8>;
    auto origin = MNNGetCoreFunctions();
#define FUNC_PTR_ASSIGN(dst, src) dst = (decltype(dst))(src)
    gInstance = new CoreFunctions;
    gArm82CoreInt8Functions = new CoreInt8Functions;
    *gArm82CoreInt8Functions = *MNNGetInt8CoreFunctions();
    gInstance->int8MatmulRelatedFunctions = origin->int8MatmulRelatedFunctions;
    gInstance->sme2Int8MatmulRelatedFuncionsHp32 = origin->sme2Int8MatmulRelatedFuncionsHp32;
    {
        if (origin->supportSDot) {
            gArm82CoreInt8Functions->MNNPackC4Int8ForMatMul_A = _Arm82MNNPackC4ForMatMul_A<12, 4>;
            gInstance->supportSDot = true;
        }
        if (origin->supportI8mm) {
            gArm82CoreInt8Functions->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A_L8<10, 8>;
            gInstance->supportI8mm = true;
        }
    }

    FUNC_PTR_ASSIGN(gInstance->MNNFp32ToFp8, MNNFp32ToFp8);
    FUNC_PTR_ASSIGN(gInstance->MNNFp16ToFp8, MNNFp16ToFp8);
    FUNC_PTR_ASSIGN(gInstance->MNNFp8ToFp32, MNNFp8ToFp32);
    FUNC_PTR_ASSIGN(gInstance->MNNFp8ToFp16, MNNFp8ToFp16);

    FUNC_PTR_ASSIGN(gInstance->MNNFp32ToLowp, MNNQuantizeFP16);
    FUNC_PTR_ASSIGN(gInstance->MNNLowpToFp32, MNNDequantizeFP16);
    gInstance->bytes = 2;

    // Packed
    gInstance->pack = 8;
    FUNC_PTR_ASSIGN(gInstance->MNNPackCUnit, MNNPackC8FP16);
    FUNC_PTR_ASSIGN(gInstance->MNNUnpackCUnit, MNNUnPackC8FP16);
    FUNC_PTR_ASSIGN(gInstance->MNNPackCUnitTranspose, MNNPackTransposeInt16C8);
    FUNC_PTR_ASSIGN(gInstance->MNNUnpackCUnitTranspose, MNNUnpackTransposeInt16C8);
    FUNC_PTR_ASSIGN(gInstance->MNNConvRunForLineDepthwise, MNNConvRunForLineDepthwiseFP16);
    FUNC_PTR_ASSIGN(gInstance->MNNAxByClampBroadcastUnit, MNNAxByClampBroadcastC8FP16);
    FUNC_PTR_ASSIGN(gInstance->MNNMatrixSub, MNNMatrixSubFP16);
    FUNC_PTR_ASSIGN(gInstance->MNNMatrixAdd, MNNMatrixAddFP16);
    FUNC_PTR_ASSIGN(gInstance->MNNStrassenMergeCFunction, ARM82StrassenMerge);
    gInstance->MNNReorderWeightInt4 = origin->MNNReorderWeightInt4;
    gInstance->MNNSumWeightInt8 = origin->MNNSumWeightInt8;
    gInstance->MNNSumWeightInt8SmeHp64 = origin->MNNSumWeightInt8SmeHp64;
    gInstance->penalty = 2.0f;
    FUNC_PTR_ASSIGN(gInstance->MNNScaleAndAddBias, MNNScaleAndAddBiasFP16);
    FUNC_PTR_ASSIGN(gInstance->MNNGridSampleComputeCord, MNNGridSampleComputeCordFP16);
    FUNC_PTR_ASSIGN(gInstance->MNNGridSampleInterp, MNNGridSampleInterp);
    FUNC_PTR_ASSIGN(gInstance->MNNGridSampleInterpGrad, MNNGridSampleInterpGrad);
    FUNC_PTR_ASSIGN(gInstance->MNNGridSampleComputeCord3D, MNNGridSampleComputeCord3DFp16);
    FUNC_PTR_ASSIGN(gInstance->MNNGridSampleInterp3D, MNNGridSampleInterp3D);
    FUNC_PTR_ASSIGN(gInstance->MNNRoiPoolingMax, MNNRoiPoolingMaxFP16);
    FUNC_PTR_ASSIGN(gInstance->MNNRoiAlignMax, MNNRoiAlignMaxFP16);
    FUNC_PTR_ASSIGN(gInstance->MNNRoiAlignAvg, MNNRoiAlignAvgFP16);
    FUNC_PTR_ASSIGN(gInstance->MNNCopyC4WithStride, MNNCopyC8WithStrideFP16);
    FUNC_PTR_ASSIGN(gInstance->MNNAddC4WithStride, MNNAddC8WithStrideFP16);

    // MatMul
    FUNC_PTR_ASSIGN(gInstance->MNNGetMatMulPackMode, Arm82MNNGetMatMulPackMode);
    FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul, MNNPackedMatMulFP16);
    FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain, MNNPackedMatMulRemainFP16);
    FUNC_PTR_ASSIGN(gInstance->MNNPackC4ForMatMul_A, Arm82MNNPackForMatMul_A);
    FUNC_PTR_ASSIGN(gInstance->MNNPackForMatMul_B, Arm82MNNPackForMatMul_B);

    FUNC_PTR_ASSIGN(gInstance->MNNSoftmax, origin->MNNSoftmax);
#if defined(__aarch64__)
    gInstance->supportFp16arith = origin->supportFp16arith;
    gInstance->supportSDot = origin->supportSDot;
    gInstance->supportI8mm = origin->supportI8mm;
    gInstance->supportSME2 = origin->supportSME2;
#ifdef MNN_CPU_WEIGHT_DEQUANT_GEMM
    // Weight Dequant Gemm Kernels
    FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul_int8, MNNPackedMatMulFP16_int8);
    FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain_int8, MNNPackedMatMulRemainFP16_int8);
#endif
#ifdef MNN_LOW_MEMORY
    // Dynamic Quant Helper Functions
    FUNC_PTR_ASSIGN(gInstance->MNNAbsMax, MNNAbsMaxFP16);
    FUNC_PTR_ASSIGN(gInstance->MNNQuantScale, MNNQuantScaleFP16);
    FUNC_PTR_ASSIGN(gInstance->MNNDynamicQuant, MNNDynamicQuantFP16);
    FUNC_PTR_ASSIGN(gInstance->MNNAsyQuantFunc, MNNAsyQuantFunc_Arm82);
    FUNC_PTR_ASSIGN(gInstance->MNNAsyQuantInfo, MNNAsyQuantInfo_FP16); // return 'plane' min&max
    FUNC_PTR_ASSIGN(gInstance->MNNDynamicUpdateConvBiasScale, origin->MNNDynamicUpdateConvBiasScale);

    if (origin->supportSDot) {
        FUNC_PTR_ASSIGN(gInstance->MNNGeneralIm2Col, MNNGeneralIm2col_Arm82);
    }
    if (origin->supportI8mm) {
        FUNC_PTR_ASSIGN(gInstance->MNNGeneralIm2Col, MNNGeneralIm2col_Arm86);
    }
#endif // MNN_LOW_MEMORY
    FUNC_PTR_ASSIGN(gInstance->MNNCountMaxMinValue, ARM82CountMinMaxValue); // return one min&max
    FUNC_PTR_ASSIGN(gInstance->MNNSumByAxisLForMatmul_A, origin->MNNSumByAxisLForMatmul_A);
    FUNC_PTR_ASSIGN(gInstance->MNNDepthwiseConvFastKernel, MNNDepthwiseConvFastKernelFP16);
#endif // __aarch64__

#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
    // Attention
    FUNC_PTR_ASSIGN(gInstance->MNNAttenUnpackAndConvertFp16, MNNAttenUnpackAndConvertFp16);
    FUNC_PTR_ASSIGN(gInstance->MNNAttenPackAndConvertFp32, MNNAttenPackAndConvertFp32);
    FUNC_PTR_ASSIGN(gInstance->MNNAttenPackAndScaleSingleHead, MNNAttenPackAndScaleSingleHead);
    FUNC_PTR_ASSIGN(gInstance->MNNFlashAttentionUpdateBlockOutput, MNNFlashAttentionUpdateBlockOutput);
#endif // MNN_SUPPORT_TRANSFORMER_FUSE

    gInstance->MNNComputeMatMulForH_1 = _MNNComputeMatMulForH_1_FP16;
    gInstance->MNNComputeMatMulForE_1 = _MNNComputeMatMulForE_1_FP16;

    FUNC_PTR_ASSIGN(gInstance->chooseWinoSourceTransformPack, Arm82WinogradFunction::chooseWinoSourceTransformPack);
    FUNC_PTR_ASSIGN(gInstance->chooseWinoSourceUnrollTransform, Arm82WinogradFunction::chooseSourceUnrollTransform);
    FUNC_PTR_ASSIGN(gInstance->chooseWinoDestUnrollTransform, Arm82WinogradFunction::chooseWinoDestUnrollTransform);

    gInstance->MNNDeconvRunForLineDepthwise = (decltype(gInstance->MNNDeconvRunForLineDepthwise))_MNNDeconvRunForLineDepthwise;
    gInstance->MNNDeconvRunForUnitDepthWise = (decltype(gInstance->MNNDeconvRunForUnitDepthWise))_MNNDeconvRunForUnitDepthWise;

    // Binary and Unary
    gInstance->MNNSelectBinaryFunctionForFloat = Arm82BinaryFloat::select;
    gInstance->MNNSelectUnaryFunctionForFloat = Arm82Unary::select;

    // Relu with slope
    gInstance->MNNReluWithSlopeChannel = Arm82Relu::reluWithSlopeChannel;

    gInstance->MNNPoolingMax = (decltype(gInstance->MNNPoolingMax))(poolingMax<float16_t, Vec, 8, -65535>);
    gInstance->MNNPoolingAvg = (decltype(gInstance->MNNPoolingAvg))(poolingAvg<float16_t, Vec, 8>);

    {
        gInstance->int8MatmulRelatedFunctions.MNNPackC4Int8ForMatMul_A = gArm82CoreInt8Functions->MNNPackC4Int8ForMatMul_A;
        gInstance->int8MatmulRelatedFunctions.MNNGeneralIm2Col = gInstance->MNNGeneralIm2Col;
    }
#ifdef __aarch64__
#ifdef MNN_SME2
    if (origin->supportSME2) {
        gArm82CoreInt8Functions->MNNPackC4Int8ForMatMul_A = _Arm82MNNPackC4ForMatMul_A<16, 4>;
        gInstance->sme2Int8MatmulRelatedFuncionsHp32.MNNPackC4Int8ForMatMul_A = _Arm82MNNPackC4ForMatMul_A<16, 4>;

        FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul, MNNPackedMatMulFP16_SME2);
        FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain, MNNPackedMatMulRemainFP16_SME2);
        FUNC_PTR_ASSIGN(gInstance->MNNGetMatMulPackMode, Sme2MNNGetMatMulPackMode);
        FUNC_PTR_ASSIGN(gInstance->MNNPackC4ForMatMul_A, Sme2MNNPackForMatMul_A_FP16);
        FUNC_PTR_ASSIGN(gInstance->MNNPackForMatMul_B, Sme2MNNPackForMatMul_B);

#ifdef MNN_LOW_MEMORY
        FUNC_PTR_ASSIGN(gInstance->MNNGeneralIm2Col, MNNGeneralIm2col_Fp16Sme2);
        gInstance->sme2Int8MatmulRelatedFuncionsHp32.MNNGeneralIm2Col = MNNGeneralIm2col_Fp16Sme2;
#endif
    }
#endif // MNN_SME2
#endif // __aarch64__

    return true;
}

CoreFunctions* Arm82Functions::get() {
    return gInstance;
}
CoreInt8Functions* Arm82Functions::getInt8() {
    return gArm82CoreInt8Functions;
}
};
#endif