mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			217 lines
		
	
	
		
			7.0 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			217 lines
		
	
	
		
			7.0 KiB
		
	
	
	
		
			C++
		
	
	
	
#include <vector>
 | 
						|
#include "BF16Unary.hpp"
 | 
						|
#include "VecHalf.hpp"
 | 
						|
#include "math/Vec.hpp"
 | 
						|
#include "backend/cpu/UnaryUtils.hpp"
 | 
						|
#include "BF16Backend.hpp"
 | 
						|
namespace MNN {
 | 
						|
 | 
						|
using Vec4Half = MNN::Math::VecHalf<4>;
 | 
						|
using Vec4 = MNN::Math::Vec<float, 4>;
 | 
						|
 | 
						|
struct Vec4Square {
 | 
						|
    Vec4Half operator()(Vec4Half &x) const {
 | 
						|
        return x * x;
 | 
						|
    }
 | 
						|
};
 | 
						|
 | 
						|
struct Vec4Neg {
 | 
						|
    Vec4Half operator()(Vec4Half &x) const {
 | 
						|
        return -x;
 | 
						|
    }
 | 
						|
};
 | 
						|
 | 
						|
struct Vec4Abs {
 | 
						|
    Vec4Half operator()(Vec4Half &x) const {
 | 
						|
        float v[4];
 | 
						|
        v[0] = fabs(x.value[0]);
 | 
						|
        v[1] = fabs(x.value[1]);
 | 
						|
        v[2] = fabs(x.value[2]);
 | 
						|
        v[3] = fabs(x.value[3]);
 | 
						|
        auto c = Vec4::load(v);
 | 
						|
        Vec4Half value;
 | 
						|
        value.value = std::move(c.value);
 | 
						|
        return value;
 | 
						|
    }
 | 
						|
};
 | 
						|
 | 
						|
template<typename Compute>
 | 
						|
void BF16VecUnary(void *dstRaw, const void *src0Raw, int elementSize) {
 | 
						|
    Compute Func;
 | 
						|
    auto dst = (int16_t*)dstRaw;
 | 
						|
    auto src0 = (int16_t*)src0Raw;
 | 
						|
    const int sizeDivUnit = elementSize / 4;
 | 
						|
    const int remainCount = elementSize - sizeDivUnit * 4;
 | 
						|
 | 
						|
    if (sizeDivUnit > 0) {
 | 
						|
        for (int i = 0; i < sizeDivUnit; ++i) {
 | 
						|
            Vec4Half a = Vec4Half::load(src0);
 | 
						|
            Vec4Half::save(dst, Func(a));
 | 
						|
            src0 += 4;
 | 
						|
            dst += 4;
 | 
						|
        }
 | 
						|
    }
 | 
						|
    if (remainCount > 0) {
 | 
						|
        int16_t tempSrc0[4];
 | 
						|
        int16_t tempDst[4];
 | 
						|
        ::memcpy(tempSrc0, src0, remainCount * sizeof(int16_t));
 | 
						|
        Vec4Half a = Vec4Half::load(tempSrc0);
 | 
						|
        Vec4Half::save(tempDst, Func(a));
 | 
						|
        ::memcpy(dst, tempDst, remainCount * sizeof(int16_t));
 | 
						|
    }
 | 
						|
}
 | 
						|
#define BLOCK_SIZE 16
 | 
						|
template<typename Compute>
 | 
						|
static void _Wrap(void* outRaw, const void* inpRaw, int realSize) {
 | 
						|
    Compute execute;
 | 
						|
    float out[BLOCK_SIZE];
 | 
						|
    float inp[BLOCK_SIZE];
 | 
						|
    int b = realSize / BLOCK_SIZE;
 | 
						|
    int remain = realSize % BLOCK_SIZE;
 | 
						|
    auto bf16F = BF16Functions::get();
 | 
						|
    auto outR = (int16_t*)outRaw;
 | 
						|
    auto inpR = (const int16_t*)inpRaw;
 | 
						|
    for (int i=0; i<b; ++i) {
 | 
						|
        bf16F->MNNLowpToFp32(inpR, inp, BLOCK_SIZE);
 | 
						|
        execute(out, inp, BLOCK_SIZE);
 | 
						|
        bf16F->MNNFp32ToLowp(out, outR, BLOCK_SIZE);
 | 
						|
        outR += BLOCK_SIZE;
 | 
						|
        inpR += BLOCK_SIZE;
 | 
						|
    }
 | 
						|
    if (remain > 0) {
 | 
						|
        bf16F->MNNLowpToFp32(inpR, inp, remain);
 | 
						|
        execute(out, inp, remain);
 | 
						|
        bf16F->MNNFp32ToLowp(out, outR, remain);
 | 
						|
    }
 | 
						|
}
 | 
						|
 | 
						|
struct _Exp {
 | 
						|
    void operator()(void* outRaw, const void* inpRaw, int realSize) const {
 | 
						|
        auto out = (float*)outRaw;
 | 
						|
        auto inp = (const float*)inpRaw;
 | 
						|
        MNNScaleAndAddBiasScalar(out, inp, 0.0f, -1.0f, realSize);
 | 
						|
        MNNExp(out, out, realSize);
 | 
						|
    }
 | 
						|
};
 | 
						|
struct _ExpM1 {
 | 
						|
    void operator()(void* outRaw, const void* inpRaw, int realSize) const {
 | 
						|
        auto out = (float*)outRaw;
 | 
						|
        auto inp = (const float*)inpRaw;
 | 
						|
        MNNScaleAndAddBiasScalar(out, inp, 0.0f, -1.0f, realSize);
 | 
						|
        MNNExp(out, out, realSize);
 | 
						|
        for (int i=0; i<realSize; ++i) {
 | 
						|
            out[i] = out[i] - 1.0f;
 | 
						|
        }
 | 
						|
    }
 | 
						|
};
 | 
						|
 | 
						|
struct _Tanh {
 | 
						|
    void operator()(void* outRaw, const void* inpRaw, int realSize) const {
 | 
						|
        auto out = (float*)outRaw;
 | 
						|
        auto inp = (const float*)inpRaw;
 | 
						|
        MNNTanh(out, inp, realSize);
 | 
						|
    }
 | 
						|
};
 | 
						|
struct _Sigmoid {
 | 
						|
    void operator()(void* outRaw, const void* inpRaw, int realSize) const {
 | 
						|
        auto out = (float*)outRaw;
 | 
						|
        auto inp = (const float*)inpRaw;
 | 
						|
        MNNSigmoidLowp(out, inp, realSize);
 | 
						|
    }
 | 
						|
};
 | 
						|
 | 
						|
struct _HardSwish {
 | 
						|
    void operator()(void* outRaw, const void* inpRaw, int realSize) const {
 | 
						|
        auto out = (float*)outRaw;
 | 
						|
        auto inp = (const float*)inpRaw;
 | 
						|
        MNNHardSwishCommon(out, inp, realSize);
 | 
						|
    }
 | 
						|
};
 | 
						|
 | 
						|
template <typename Func, typename T>
 | 
						|
struct _Unary {
 | 
						|
    void operator()(void* outputPtr, const void* inputPtr, int elementSize) const {
 | 
						|
        Func f;
 | 
						|
        const T *inputData = (T*)inputPtr;
 | 
						|
        T *outputData      = (T *)outputPtr;
 | 
						|
        for (int i=0; i<elementSize; ++i) {
 | 
						|
            outputData[i] = f(inputData[i]);
 | 
						|
        }
 | 
						|
    }
 | 
						|
};
 | 
						|
 | 
						|
MNNUnaryExecute BF16UnaryFloatSelect(int type, int precision) {
 | 
						|
    switch (type) {
 | 
						|
        case UnaryOpOperation_ABS:
 | 
						|
            return BF16VecUnary<Vec4Abs>;
 | 
						|
        case UnaryOpOperation_SQUARE:
 | 
						|
            return BF16VecUnary<Vec4Square>;
 | 
						|
        case UnaryOpOperation_NEG:
 | 
						|
            return BF16VecUnary<Vec4Neg>;
 | 
						|
        case UnaryOpOperation_RSQRT:
 | 
						|
            return _Wrap<_Unary<UnaryRsqrt<float>, float>>;
 | 
						|
        case UnaryOpOperation_EXP:
 | 
						|
            return _Wrap<_Exp>;
 | 
						|
        case UnaryOpOperation_COS:
 | 
						|
            return _Wrap<_Unary<UnaryCos<float>, float>>;
 | 
						|
        case UnaryOpOperation_SIN:
 | 
						|
            return _Wrap<_Unary<UnarySin<float>, float>>;
 | 
						|
        case UnaryOpOperation_SIGMOID:
 | 
						|
            return _Wrap<_Sigmoid>;
 | 
						|
        case UnaryOpOperation_TANH:
 | 
						|
            return _Wrap<_Tanh>;
 | 
						|
        case UnaryOpOperation_TAN:
 | 
						|
            return _Wrap<_Unary<UnaryTan<float>, float>>;
 | 
						|
        case UnaryOpOperation_ATAN:
 | 
						|
            return _Wrap<_Unary<UnaryATan<float>, float>>;
 | 
						|
        case UnaryOpOperation_SQRT:
 | 
						|
            return _Wrap<_Unary<UnarySqrt<float>, float>>;
 | 
						|
        case UnaryOpOperation_CEIL:
 | 
						|
            return _Wrap<_Unary<UnaryCeil<float>, float>>;
 | 
						|
        case UnaryOpOperation_RECIPROCAL:
 | 
						|
            return _Wrap<_Unary<UnaryRecipocal<float>, float>>;
 | 
						|
        case UnaryOpOperation_LOG1P:
 | 
						|
            return _Wrap<_Unary<UnaryLog1p<float>, float>>;
 | 
						|
        case UnaryOpOperation_LOG:
 | 
						|
            return _Wrap<_Unary<UnaryLog<float>, float>>;
 | 
						|
        case UnaryOpOperation_FLOOR:
 | 
						|
            return _Wrap<_Unary<UnaryFloor<float>, float>>;
 | 
						|
        case UnaryOpOperation_BNLL:
 | 
						|
            return _Wrap<_Unary<UnaryBNLL<float>, float>>;
 | 
						|
        case UnaryOpOperation_ACOSH:
 | 
						|
            return _Wrap<_Unary<UnaryAcosh<float>, float>>;
 | 
						|
        case UnaryOpOperation_SINH:
 | 
						|
            return _Wrap<_Unary<UnarySinh<float>, float>>;
 | 
						|
        case UnaryOpOperation_ASINH:
 | 
						|
            return _Wrap<_Unary<UnaryAsinh<float>, float>>;
 | 
						|
        case UnaryOpOperation_ATANH:
 | 
						|
            return _Wrap<_Unary<UnaryAtanh<float>, float>>;
 | 
						|
        case UnaryOpOperation_SIGN:
 | 
						|
            return _Wrap<_Unary<UnarySign<float>, float>>;
 | 
						|
        case UnaryOpOperation_ROUND:
 | 
						|
            return _Wrap<_Unary<UnaryRound<float>, float>>;
 | 
						|
        case UnaryOpOperation_COSH:
 | 
						|
            return _Wrap<_Unary<UnaryCosh<float>, float>>;
 | 
						|
        case UnaryOpOperation_ERF:
 | 
						|
            return _Wrap<_Unary<UnaryErf<float>, float>>;
 | 
						|
        case UnaryOpOperation_ERFC:
 | 
						|
            return _Wrap<_Unary<UnaryErfc<float>, float>>;
 | 
						|
        case UnaryOpOperation_ERFINV:
 | 
						|
            return _Wrap<_Unary<UnaryErfinv<float>, float>>;
 | 
						|
        case UnaryOpOperation_EXPM1:
 | 
						|
            return _Wrap<_ExpM1>;
 | 
						|
        case UnaryOpOperation_ASIN:
 | 
						|
            return _Wrap<_Unary<UnaryAsin<float>, float>>;
 | 
						|
        case UnaryOpOperation_ACOS:
 | 
						|
            return _Wrap<_Unary<UnaryAcos<float>, float>>;
 | 
						|
        case UnaryOpOperation_HARDSWISH:
 | 
						|
            return _Wrap<_HardSwish>;
 | 
						|
        default:
 | 
						|
            MNN_ASSERT(false);
 | 
						|
            break;
 | 
						|
    }
 | 
						|
    return nullptr;
 | 
						|
}
 | 
						|
 | 
						|
};
 |