MNN/source/backend/cpu/bf16/BF16Unary.cpp

217 lines
7.0 KiB
C++

#include <vector>
#include "BF16Unary.hpp"
#include "VecHalf.hpp"
#include "math/Vec.hpp"
#include "backend/cpu/UnaryUtils.hpp"
#include "BF16Backend.hpp"
namespace MNN {
// Four-lane vector type that the kernels below load from and save to int16_t buffers.
using Vec4Half = MNN::Math::VecHalf<4>;
// Four-lane float vector type; Vec4Abs uses it to load a float[4] scratch array.
using Vec4 = MNN::Math::Vec<float, 4>;
// Functor for BF16VecUnary: returns the input vector multiplied by itself.
struct Vec4Square {
    Vec4Half operator()(Vec4Half &x) const {
        Vec4Half squared = x * x;
        return squared;
    }
};
// Functor for BF16VecUnary: returns the negation of the input vector.
struct Vec4Neg {
    Vec4Half operator()(Vec4Half &x) const {
        Vec4Half negated = -x;
        return negated;
    }
};
// Functor for BF16VecUnary: takes the absolute value of each of the four lanes.
struct Vec4Abs {
    Vec4Half operator()(Vec4Half &x) const {
        // Compute |lane| one lane at a time into a plain float scratch array.
        float lanes[4];
        for (int lane = 0; lane < 4; ++lane) {
            lanes[lane] = fabs(x.value[lane]);
        }
        // Re-pack through the float vector type and hand its storage to the result.
        Vec4Half result;
        result.value = Vec4::load(lanes).value;
        return result;
    }
};
// Applies a four-lane functor element-wise over a buffer of 16-bit values.
//
// Compute      functor type; invoked as Func(Vec4Half&) and its result is
//              written back with Vec4Half::save.
// dstRaw       destination buffer, accessed as int16_t; elementSize values are written.
// src0Raw      source buffer, accessed as int16_t; elementSize values are read.
// elementSize  number of 16-bit elements to process; values <= 0 do nothing.
template<typename Compute>
void BF16VecUnary(void *dstRaw, const void *src0Raw, int elementSize) {
    Compute Func;
    auto dst = (int16_t*)dstRaw;
    auto src0 = (int16_t*)src0Raw; // only read from; cast kept as in the original
    const int sizeDivUnit = elementSize / 4;
    const int remainCount = elementSize - sizeDivUnit * 4;
    // Main body: whole groups of four elements.
    for (int i = 0; i < sizeDivUnit; ++i) {
        Vec4Half a = Vec4Half::load(src0);
        Vec4Half::save(dst, Func(a));
        src0 += 4;
        dst += 4;
    }
    // Tail: stage the 1..3 leftover elements through four-element scratch buffers
    // so the full-width load/save never touches memory past the caller's buffers.
    if (remainCount > 0) {
        // Zero-fill so the unused lanes hold defined values; previously they were
        // read uninitialised and fed through the functor.
        int16_t tempSrc0[4] = {0};
        int16_t tempDst[4];
        ::memcpy(tempSrc0, src0, remainCount * sizeof(int16_t));
        Vec4Half a = Vec4Half::load(tempSrc0);
        Vec4Half::save(tempDst, Func(a));
        // Only the valid lanes are copied out.
        ::memcpy(dst, tempDst, remainCount * sizeof(int16_t));
    }
}
// Number of elements converted and processed per chunk by _Wrap.
#define BLOCK_SIZE 16
// Runs a float kernel over a 16-bit buffer in chunks of at most BLOCK_SIZE.
// Each chunk is converted with MNNLowpToFp32 into a float scratch array,
// passed to the kernel, and converted back with MNNFp32ToLowp.
//
// Compute   kernel type; invoked as kernel(float* out, const float* in, int count).
// outRaw    destination buffer, accessed as int16_t.
// inpRaw    source buffer, accessed as const int16_t.
// realSize  number of elements to process; values <= 0 do nothing.
template<typename Compute>
static void _Wrap(void* outRaw, const void* inpRaw, int realSize) {
    Compute kernel;
    float dstFloat[BLOCK_SIZE];
    float srcFloat[BLOCK_SIZE];
    auto functions = BF16Functions::get();
    auto dstCursor = (int16_t*)outRaw;
    auto srcCursor = (const int16_t*)inpRaw;
    // Full chunks come first; the final pass handles whatever is left (< BLOCK_SIZE).
    int left = realSize;
    while (left > 0) {
        const int count = left < BLOCK_SIZE ? left : BLOCK_SIZE;
        functions->MNNLowpToFp32(srcCursor, srcFloat, count);
        kernel(dstFloat, srcFloat, count);
        functions->MNNFp32ToLowp(dstFloat, dstCursor, count);
        srcCursor += count;
        dstCursor += count;
        left -= count;
    }
}
// Float kernel for UnaryOpOperation_EXP, used through _Wrap.
// NOTE(review): the input is first passed through MNNScaleAndAddBiasScalar with
// constants 0.0f and -1.0f, then MNNExp runs in place on the result. Presumably
// the first call negates the input and MNNExp evaluates exp(-x), giving exp(x)
// overall - confirm against the declarations of both helpers.
struct _Exp {
    void operator()(void* outRaw, const void* inpRaw, int realSize) const {
        float* const dst = static_cast<float*>(outRaw);
        const float* const src = static_cast<const float*>(inpRaw);
        MNNScaleAndAddBiasScalar(dst, src, 0.0f, -1.0f, realSize);
        MNNExp(dst, dst, realSize);
    }
};
// Float kernel for UnaryOpOperation_EXPM1, used through _Wrap.
// Runs the same MNNScaleAndAddBiasScalar + MNNExp sequence as _Exp and then
// subtracts 1.0f from every output element.
struct _ExpM1 {
    void operator()(void* outRaw, const void* inpRaw, int realSize) const {
        float* const dst = static_cast<float*>(outRaw);
        const float* const src = static_cast<const float*>(inpRaw);
        MNNScaleAndAddBiasScalar(dst, src, 0.0f, -1.0f, realSize);
        MNNExp(dst, dst, realSize);
        // Final "- 1" step, applied in place.
        for (int idx = 0; idx < realSize; ++idx) {
            dst[idx] -= 1.0f;
        }
    }
};
// Float kernel for UnaryOpOperation_TANH, used through _Wrap: forwards to MNNTanh.
struct _Tanh {
    void operator()(void* outRaw, const void* inpRaw, int realSize) const {
        float* const dst = static_cast<float*>(outRaw);
        const float* const src = static_cast<const float*>(inpRaw);
        MNNTanh(dst, src, realSize);
    }
};
// Float kernel for UnaryOpOperation_SIGMOID, used through _Wrap: forwards to MNNSigmoidLowp.
struct _Sigmoid {
    void operator()(void* outRaw, const void* inpRaw, int realSize) const {
        float* const dst = static_cast<float*>(outRaw);
        const float* const src = static_cast<const float*>(inpRaw);
        MNNSigmoidLowp(dst, src, realSize);
    }
};
// Float kernel for UnaryOpOperation_HARDSWISH, used through _Wrap: forwards to MNNHardSwishCommon.
struct _HardSwish {
    void operator()(void* outRaw, const void* inpRaw, int realSize) const {
        float* const dst = static_cast<float*>(outRaw);
        const float* const src = static_cast<const float*>(inpRaw);
        MNNHardSwishCommon(dst, src, realSize);
    }
};
// Adapts a scalar functor into a buffer kernel.
//
// Func  default-constructible functor invoked once per element.
// T     element type of both the input and the output buffer.
//
// operator() reads elementSize values of type T from inputPtr, applies the
// functor to each, and writes the results to outputPtr in order. A non-positive
// elementSize does nothing.
template <typename Func, typename T>
struct _Unary {
    void operator()(void* outputPtr, const void* inputPtr, int elementSize) const {
        Func op;
        const T* src = static_cast<const T*>(inputPtr);
        T* dst = static_cast<T*>(outputPtr);
        for (int left = elementSize; left > 0; --left, ++src, ++dst) {
            *dst = op(*src);
        }
    }
};
// Maps a UnaryOpOperation code to the BF16 kernel that implements it.
//
// type       one of the UnaryOpOperation_* values handled below.
// precision  accepted but not consulted by this function.
//
// Returns the matching kernel. For any other type, MNN_ASSERT(false) is hit
// and nullptr is returned.
MNNUnaryExecute BF16UnaryFloatSelect(int type, int precision) {
    switch (type) {
        // Operations computed directly on four-lane vectors.
        case UnaryOpOperation_ABS:        return BF16VecUnary<Vec4Abs>;
        case UnaryOpOperation_SQUARE:     return BF16VecUnary<Vec4Square>;
        case UnaryOpOperation_NEG:        return BF16VecUnary<Vec4Neg>;

        // Operations with a dedicated float kernel, run through _Wrap.
        case UnaryOpOperation_EXP:        return _Wrap<_Exp>;
        case UnaryOpOperation_EXPM1:      return _Wrap<_ExpM1>;
        case UnaryOpOperation_SIGMOID:    return _Wrap<_Sigmoid>;
        case UnaryOpOperation_TANH:       return _Wrap<_Tanh>;
        case UnaryOpOperation_HARDSWISH:  return _Wrap<_HardSwish>;

        // Operations evaluated one element at a time with a scalar float functor.
        case UnaryOpOperation_RSQRT:      return _Wrap<_Unary<UnaryRsqrt<float>, float>>;
        case UnaryOpOperation_SQRT:       return _Wrap<_Unary<UnarySqrt<float>, float>>;
        case UnaryOpOperation_RECIPROCAL: return _Wrap<_Unary<UnaryRecipocal<float>, float>>;
        case UnaryOpOperation_LOG:        return _Wrap<_Unary<UnaryLog<float>, float>>;
        case UnaryOpOperation_LOG1P:      return _Wrap<_Unary<UnaryLog1p<float>, float>>;
        case UnaryOpOperation_BNLL:       return _Wrap<_Unary<UnaryBNLL<float>, float>>;
        case UnaryOpOperation_CEIL:       return _Wrap<_Unary<UnaryCeil<float>, float>>;
        case UnaryOpOperation_FLOOR:      return _Wrap<_Unary<UnaryFloor<float>, float>>;
        case UnaryOpOperation_ROUND:      return _Wrap<_Unary<UnaryRound<float>, float>>;
        case UnaryOpOperation_SIGN:       return _Wrap<_Unary<UnarySign<float>, float>>;
        case UnaryOpOperation_SIN:        return _Wrap<_Unary<UnarySin<float>, float>>;
        case UnaryOpOperation_COS:        return _Wrap<_Unary<UnaryCos<float>, float>>;
        case UnaryOpOperation_TAN:        return _Wrap<_Unary<UnaryTan<float>, float>>;
        case UnaryOpOperation_ASIN:       return _Wrap<_Unary<UnaryAsin<float>, float>>;
        case UnaryOpOperation_ACOS:       return _Wrap<_Unary<UnaryAcos<float>, float>>;
        case UnaryOpOperation_ATAN:       return _Wrap<_Unary<UnaryATan<float>, float>>;
        case UnaryOpOperation_SINH:       return _Wrap<_Unary<UnarySinh<float>, float>>;
        case UnaryOpOperation_COSH:       return _Wrap<_Unary<UnaryCosh<float>, float>>;
        case UnaryOpOperation_ASINH:      return _Wrap<_Unary<UnaryAsinh<float>, float>>;
        case UnaryOpOperation_ACOSH:      return _Wrap<_Unary<UnaryAcosh<float>, float>>;
        case UnaryOpOperation_ATANH:      return _Wrap<_Unary<UnaryAtanh<float>, float>>;
        case UnaryOpOperation_ERF:        return _Wrap<_Unary<UnaryErf<float>, float>>;
        case UnaryOpOperation_ERFC:       return _Wrap<_Unary<UnaryErfc<float>, float>>;
        case UnaryOpOperation_ERFINV:     return _Wrap<_Unary<UnaryErfinv<float>, float>>;

        // No kernel for this operation.
        default:
            MNN_ASSERT(false);
            break;
    }
    return nullptr;
}
};