MNN/source/backend/cpu/BinaryUtils.hpp

443 lines
14 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include <math.h>
#include <algorithm>
#include "compute/CommonOptFunction.h"
#include "MNN_generated.h"
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryMax {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return std::max(x, y);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryMin {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return std::min(x, y);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryMul {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return x * y;
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryAdd {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return x + y;
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinarySub {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return x - y;
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryRealDiv {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return x / y;
}
};
/**
Ref from onnxruntime/onnxruntime/core/providers/cpu/math/element_wise_ops.cc :: Modulus
*/
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryModInt {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
auto res = x % y;
if ((res < 0 && y > 0) || (res > 0 && y < 0)) {
res += y;
}
return (_ErrorCode)res;
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryMod {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return fmodf(x, y);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryGreater {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return (_ErrorCode)((x > y) ? 1 : 0);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryLess {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return (_ErrorCode)((x < y) ? 1 : 0);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryGreaterEqual {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return (_ErrorCode)((x >= y) ? 1 : 0);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryLessEqual {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return (_ErrorCode)((x <= y) ? 1 : 0);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryEqual {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return (_ErrorCode)((x == y) ? 1 : 0);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryFloorDiv {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return floor(static_cast<double>(x) / y);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryFloorMod {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return x - floor(x / y) * y;
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinarySquaredDifference {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return (x - y) * (x - y);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryPow {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return pow(x, y);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryAtan2 {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return atan2(x, y);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryLogicalOr {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return (_ErrorCode)((x || y) ? 1 : 0);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryLogicalXor {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return (_ErrorCode)((x ^ y) ? 1 : 0);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryNotEqual {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return (_ErrorCode)((x != y) ? 1 : 0);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryLeftShift {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return (_ErrorCode)(x << y);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryBitwiseAnd {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return (_ErrorCode)(x & y);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryRightShift {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return (_ErrorCode)(x >> y);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryBitwiseOr {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return (_ErrorCode)(x | y);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryBitwiseXor {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return (_ErrorCode)(x ^ y);
}
};
template<typename Func, typename V, int pack, typename U, typename Tout>
void executeVec(void* outputRaw, const void* inputRaw0, const void* inputRaw1, int elementSize, int needBroadcastIndex) {
Func compute;
const int sizeDivUnit = elementSize / pack;
const int remainCount = elementSize - sizeDivUnit * pack;
auto src0 = (const U*)(inputRaw0);
auto src1 = (const U*)(inputRaw1);
auto dst = (Tout*)outputRaw;
if (-1 == needBroadcastIndex) {
if (sizeDivUnit > 0) {
for (int i = 0; i < sizeDivUnit; ++i) {
V a = V::load(src0);
V b = V::load(src1);
V::save(dst, compute(a, b));
src0 += pack;
src1 += pack;
dst += pack;
}
}
if (remainCount > 0) {
U tempSrc0[pack];
U tempSrc1[pack];
Tout tempDst[pack];
::memcpy(tempSrc0, src0, remainCount * sizeof(U));
::memcpy(tempSrc1, src1, remainCount * sizeof(U));
V a = V::load(tempSrc0);
V b = V::load(tempSrc1);
V::save(tempDst, compute(a, b));
::memcpy(dst, tempDst, remainCount * sizeof(U));
}
} else if (0 == needBroadcastIndex) {
const U srcValue0 = src0[0];
V a = V(srcValue0);
if (sizeDivUnit > 0) {
for (int i = 0; i < sizeDivUnit; ++i) {
const auto src1Ptr = src1;
auto dstPtr = dst;
V b = V::load(src1Ptr);
V::save(dstPtr, compute(a, b));
src1 += pack;
dst += pack;
}
}
if (remainCount > 0) {
U tempSrc1[pack];
Tout tempDst[pack];
::memcpy(tempSrc1, src1, remainCount * sizeof(U));
V b = V::load(tempSrc1);
V::save(tempDst, compute(a, b));
::memcpy(dst, tempDst, remainCount * sizeof(U));
}
} else {
const auto srcValue1 = static_cast<U>(src1[0]);
V b = V(srcValue1);
if (sizeDivUnit > 0) {
for (int i = 0; i < sizeDivUnit; ++i) {
const auto src0Ptr = src0;
auto dstPtr = dst;
V a = V::load(src0Ptr);
V::save(dstPtr, compute(a, b));
src0 += pack;
dst += pack;
}
}
if (remainCount > 0) {
U tempSrc0[pack];
Tout tempDst[pack];
::memcpy(tempSrc0, src0, remainCount * sizeof(U));
V a = V::load(tempSrc0);
V::save(tempDst, compute(a, b));
::memcpy(dst, tempDst, remainCount * sizeof(U));
}
}
}
template<typename Vec>
struct VecBinaryAdd {
Vec operator()(Vec& x, Vec& y) const {
return x + y;
}
};
template<typename Vec>
struct VecBinarySub {
Vec operator()(Vec& x, Vec& y) const {
return x - y;
}
};
template<typename Vec>
struct VecBinaryMul {
Vec operator()(Vec& x, Vec& y) const {
return x * y;
}
};
template<typename Vec>
struct VecBinaryMin {
Vec operator()(Vec& x, Vec& y) const {
return Vec::min(x, y);
}
};
template<typename Vec>
struct VecBinaryMax {
Vec operator()(Vec& x, Vec& y) const {
return Vec::max(x, y);
}
};
template<typename Vec>
struct VecBinarySqd {
Vec operator()(Vec& x, Vec& y) const {
return (x-y)*(x-y);
}
};
template<typename Vec>
struct VecBinaryLess {
Vec operator()(Vec& x, Vec& y) const {
return x < y;
}
};
template<typename Vec>
struct VecBinaryGreater {
Vec operator()(Vec& x, Vec& y) const {
return x > y;
}
};
template<typename Vec>
struct VecBinaryLessEqual {
Vec operator()(Vec& x, Vec& y) const {
return x <= y;
}
};
template<typename Vec>
struct VecBinaryGreaterEqual {
Vec operator()(Vec& x, Vec& y) const {
return x >= y;
}
};
template<typename Vec>
struct VecBinaryEqual {
Vec operator()(Vec& x, Vec& y) const {
return x == y;
}
};
namespace MNN {
template<typename Tin, typename Tout, typename Func>
void execute(void* outputRaw, const void* inputRaw0, const void* inputRaw1, int elementSize, int broadcastIndex) {
Func f;
const int input0DataCount = elementSize;
const int input1DataCount = elementSize;
const Tin* input0Data = (const Tin*)inputRaw0;
const Tin* input1Data = (const Tin*)inputRaw1;
Tout* outputData = (Tout*)outputRaw;
if (broadcastIndex == 0) { // data count == 1, not only mean scalar input, maybe of shape (1, 1, 1, ...,1)
for (int i = 0; i < input1DataCount; i++) {
outputData[i] = (Tout)(f(input0Data[0], input1Data[i]));
}
} else if (broadcastIndex == 1) {
for (int i = 0; i < input0DataCount; i++) {
outputData[i] = (Tout)(f(input0Data[i], input1Data[0]));
}
} else { // both input contains more than one elementwhich means no scalar input
for (int i = 0; i < input0DataCount; i++) {
outputData[i] = (Tout)(f(input0Data[i], input1Data[i]));
}
}
}
template<typename Tin, typename Tout, typename Func>
void executeInt8 (int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, ssize_t* inputScalesInt32, float* inputScalesFp32, const QuanPrePostParameters* params, size_t elementSize, size_t needBroadcast) {
Func f;
int size = static_cast<int>(elementSize);
#ifdef MNN_USE_NEON
size *= 4;
#endif
float inp0 = 0, inp1 = 0, output = 0;
#ifdef MNN_USE_SSE
const int offset = 128;
const uint8_t* inputData0 = (uint8_t*)inputRaw0;
const uint8_t* inputData1 = (uint8_t*)inputRaw1;
uint8_t* outputData = (uint8_t*)outputRaw;
#else
const int offset = 0;
const int8_t* inputData0 = (int8_t*)inputRaw0;
const int8_t* inputData1 = (int8_t*)inputRaw1;
int8_t* outputData = (int8_t*)outputRaw;
#endif
const int maxValue = static_cast<int32_t>(params->maxValue) + offset;
const int minValue = static_cast<int32_t>(params->minValue) + offset;
for (int i = 0; i < size; ++i) {
if (needBroadcast == 0) {
inp0 = (inputData0[0]- offset - params->inputZeroPoint[0]) * inputScalesFp32[0];
inp1 = (inputData1[i]- offset - params->inputZeroPoint[1]) * inputScalesFp32[1];
output = f(inp0, inp1);
} else if (needBroadcast == 1) {
inp0 = (inputData0[i] - offset - params->inputZeroPoint[0]) * inputScalesFp32[0];
inp1 = (inputData1[0] - offset - params->inputZeroPoint[1]) * inputScalesFp32[1];
output = f(inp0, inp1);
} else {
inp0 = (inputData0[i] - offset - params->inputZeroPoint[0]) * inputScalesFp32[0];
inp1 = (inputData1[i] - offset - params->inputZeroPoint[1]) * inputScalesFp32[1];
output = f(inp0, inp1);
}
int value = (int)roundf(output * inputScalesFp32[2]) + offset + static_cast<int32_t>(params->outputZeroPoint[0]);
if (value > maxValue) {
value = maxValue;
}
if (value < minValue) {
value = minValue;
}
outputData[i] = value;
}
}
template<typename V, int pack, typename U>
MNNBinaryExecute selectVector(int type) {
switch (type) {
case BinaryOpOperation_ADD:
return executeVec<VecBinaryAdd<V>, V, pack, U, U>;
case BinaryOpOperation_SUB:
return executeVec<VecBinarySub<V>, V, pack, U, U>;
case BinaryOpOperation_MUL:
return executeVec<VecBinaryMul<V>, V, pack, U, U>;
case BinaryOpOperation_MINIMUM:
return executeVec<VecBinaryMin<V>, V, pack, U, U>;
case BinaryOpOperation_MAXIMUM:
return executeVec<VecBinaryMax<V>, V, pack, U, U>;
case BinaryOpOperation_SquaredDifference:
return executeVec<VecBinarySqd<V>, V, pack, U, U>;
case BinaryOpOperation_LESS:
return executeVec<VecBinaryLess<V>, V, pack, U, int32_t>;
case BinaryOpOperation_LESS_EQUAL:
return executeVec<VecBinaryLessEqual<V>, V, pack, U, int32_t>;
case BinaryOpOperation_GREATER:
return executeVec<VecBinaryGreater<V>, V, pack, U, int32_t>;
case BinaryOpOperation_GREATER_EQUAL:
return executeVec<VecBinaryGreaterEqual<V>, V, pack, U, int32_t>;
case BinaryOpOperation_EQUAL:
return executeVec<VecBinaryEqual<V>, V, pack, U, int32_t>;
}
return nullptr;
}
};