2021-01-07 17:08:34 +08:00
|
|
|
|
#include <math.h>
|
|
|
|
|
#include <algorithm>
|
2021-06-11 17:17:13 +08:00
|
|
|
|
#include "compute/CommonOptFunction.h"
|
|
|
|
|
#include "MNN_generated.h"
|
2021-01-07 17:08:34 +08:00
|
|
|
|
|
|
|
|
|
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
|
2021-04-08 15:34:23 +08:00
|
|
|
|
struct BinaryMax {
|
2021-01-07 17:08:34 +08:00
|
|
|
|
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
|
|
|
|
|
return std::max(x, y);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
|
2021-04-08 15:34:23 +08:00
|
|
|
|
struct BinaryMin {
|
2021-01-07 17:08:34 +08:00
|
|
|
|
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
|
|
|
|
|
return std::min(x, y);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
|
2021-04-08 15:34:23 +08:00
|
|
|
|
struct BinaryMul {
|
2021-01-07 17:08:34 +08:00
|
|
|
|
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
|
|
|
|
|
return x * y;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
|
2021-04-08 15:34:23 +08:00
|
|
|
|
struct BinaryAdd {
|
2021-01-07 17:08:34 +08:00
|
|
|
|
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
|
|
|
|
|
return x + y;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
|
2021-04-08 15:34:23 +08:00
|
|
|
|
struct BinarySub {
|
2021-01-07 17:08:34 +08:00
|
|
|
|
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
|
|
|
|
|
return x - y;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
|
2021-04-08 15:34:23 +08:00
|
|
|
|
struct BinaryRealDiv {
|
2021-01-07 17:08:34 +08:00
|
|
|
|
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
|
|
|
|
|
return x / y;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
|
2021-04-08 15:34:23 +08:00
|
|
|
|
struct BinaryMod {
|
2021-01-07 17:08:34 +08:00
|
|
|
|
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
|
|
|
|
|
return x - x / y;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
|
2021-04-08 15:34:23 +08:00
|
|
|
|
struct BinaryGreater {
|
2021-01-07 17:08:34 +08:00
|
|
|
|
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
|
|
|
|
|
return (_ErrorCode)((x > y) ? 1 : 0);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
|
2021-04-08 15:34:23 +08:00
|
|
|
|
struct BinaryLess {
|
2021-01-07 17:08:34 +08:00
|
|
|
|
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
|
|
|
|
|
return (_ErrorCode)((x < y) ? 1 : 0);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
|
2021-04-08 15:34:23 +08:00
|
|
|
|
struct BinaryGreaterEqual {
|
2021-01-07 17:08:34 +08:00
|
|
|
|
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
|
|
|
|
|
return (_ErrorCode)((x >= y) ? 1 : 0);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
|
2021-04-08 15:34:23 +08:00
|
|
|
|
struct BinaryLessEqual {
|
2021-01-07 17:08:34 +08:00
|
|
|
|
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
|
|
|
|
|
return (_ErrorCode)((x <= y) ? 1 : 0);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
|
2021-04-08 15:34:23 +08:00
|
|
|
|
struct BinaryEqual {
|
2021-01-07 17:08:34 +08:00
|
|
|
|
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
|
|
|
|
|
return (_ErrorCode)((x == y) ? 1 : 0);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
|
2021-04-08 15:34:23 +08:00
|
|
|
|
struct BinaryFloorDiv {
|
2021-01-07 17:08:34 +08:00
|
|
|
|
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
|
2021-02-07 10:45:07 +08:00
|
|
|
|
return floor(static_cast<float>(x) / y);
|
2021-01-07 17:08:34 +08:00
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
|
2021-04-08 15:34:23 +08:00
|
|
|
|
struct BinaryFloorMod {
|
2021-01-07 17:08:34 +08:00
|
|
|
|
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
|
|
|
|
|
return x - floor(x / y) * y;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
|
2021-04-08 15:34:23 +08:00
|
|
|
|
struct BinarySquaredDifference {
|
2021-01-07 17:08:34 +08:00
|
|
|
|
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
|
|
|
|
|
return (x - y) * (x - y);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
|
2021-04-08 15:34:23 +08:00
|
|
|
|
struct BinaryPow {
|
2021-01-07 17:08:34 +08:00
|
|
|
|
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
|
|
|
|
|
return pow(x, y);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
|
2021-04-08 15:34:23 +08:00
|
|
|
|
struct BinaryAtan2 {
|
2021-01-07 17:08:34 +08:00
|
|
|
|
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
|
|
|
|
|
return atan(x / y);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
|
2021-04-08 15:34:23 +08:00
|
|
|
|
struct BinaryLogicalOr {
|
2021-01-07 17:08:34 +08:00
|
|
|
|
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
|
|
|
|
|
return (_ErrorCode)((x || y) ? 1 : 0);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
|
2021-04-08 15:34:23 +08:00
|
|
|
|
struct BinaryNotEqual {
|
2021-01-07 17:08:34 +08:00
|
|
|
|
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
|
|
|
|
|
return (_ErrorCode)((x != y) ? 1 : 0);
|
|
|
|
|
}
|
|
|
|
|
};
|
2021-06-11 17:17:13 +08:00
|
|
|
|
|
|
|
|
|
template<typename Func, typename V, int pack>
|
|
|
|
|
void executeVec(void* outputRaw, const void* inputRaw0, const void* inputRaw1, int elementSize, int needBroadcastIndex) {
|
|
|
|
|
Func compute;
|
|
|
|
|
const int sizeDivUnit = elementSize / pack;
|
|
|
|
|
const int remainCount = elementSize - sizeDivUnit * pack;
|
|
|
|
|
auto src0 = (const float*)(inputRaw0);
|
|
|
|
|
auto src1 = (const float*)(inputRaw1);
|
|
|
|
|
auto dst = (float*)outputRaw;
|
|
|
|
|
|
|
|
|
|
if (-1 == needBroadcastIndex) {
|
|
|
|
|
if (sizeDivUnit > 0) {
|
|
|
|
|
for (int i = 0; i < sizeDivUnit; ++i) {
|
|
|
|
|
V a = V::load(src0);
|
|
|
|
|
V b = V::load(src1);
|
|
|
|
|
V::save(dst, compute(a, b));
|
|
|
|
|
src0 += pack;
|
|
|
|
|
src1 += pack;
|
|
|
|
|
dst += pack;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (remainCount > 0) {
|
|
|
|
|
float tempSrc0[pack];
|
|
|
|
|
float tempSrc1[pack];
|
|
|
|
|
float tempDst[pack];
|
|
|
|
|
::memcpy(tempSrc0, src0, remainCount * sizeof(float));
|
|
|
|
|
::memcpy(tempSrc1, src1, remainCount * sizeof(float));
|
|
|
|
|
V a = V::load(tempSrc0);
|
|
|
|
|
V b = V::load(tempSrc1);
|
|
|
|
|
V::save(tempDst, compute(a, b));
|
|
|
|
|
::memcpy(dst, tempDst, remainCount * sizeof(float));
|
|
|
|
|
}
|
|
|
|
|
} else if (0 == needBroadcastIndex) {
|
|
|
|
|
const float srcValue0 = src0[0];
|
|
|
|
|
V a = V(srcValue0);
|
|
|
|
|
if (sizeDivUnit > 0) {
|
|
|
|
|
for (int i = 0; i < sizeDivUnit; ++i) {
|
|
|
|
|
const auto src1Ptr = src1;
|
|
|
|
|
auto dstPtr = dst;
|
|
|
|
|
V b = V::load(src1Ptr);
|
|
|
|
|
V::save(dstPtr, compute(a, b));
|
|
|
|
|
src1 += pack;
|
|
|
|
|
dst += pack;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (remainCount > 0) {
|
|
|
|
|
float tempSrc1[pack];
|
|
|
|
|
float tempDst[pack];
|
|
|
|
|
::memcpy(tempSrc1, src1, remainCount * sizeof(float));
|
|
|
|
|
V b = V::load(tempSrc1);
|
|
|
|
|
V::save(tempDst, compute(a, b));
|
|
|
|
|
::memcpy(dst, tempDst, remainCount * sizeof(float));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
const float srcValue1 = src1[0];
|
|
|
|
|
V b = V(srcValue1);
|
|
|
|
|
if (sizeDivUnit > 0) {
|
|
|
|
|
for (int i = 0; i < sizeDivUnit; ++i) {
|
|
|
|
|
const auto src0Ptr = src0;
|
|
|
|
|
auto dstPtr = dst;
|
|
|
|
|
V a = V::load(src0Ptr);
|
|
|
|
|
V::save(dstPtr, compute(a, b));
|
|
|
|
|
src0 += pack;
|
|
|
|
|
dst += pack;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (remainCount > 0) {
|
|
|
|
|
float tempSrc0[pack];
|
|
|
|
|
float tempDst[pack];
|
|
|
|
|
::memcpy(tempSrc0, src0, remainCount * sizeof(float));
|
|
|
|
|
V a = V::load(tempSrc0);
|
|
|
|
|
V::save(tempDst, compute(a, b));
|
|
|
|
|
::memcpy(dst, tempDst, remainCount * sizeof(float));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<typename Vec>
|
|
|
|
|
struct VecBinaryAdd {
|
|
|
|
|
Vec operator()(Vec& x, Vec& y) const {
|
|
|
|
|
return x + y;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template<typename Vec>
|
|
|
|
|
struct VecBinarySub {
|
|
|
|
|
Vec operator()(Vec& x, Vec& y) const {
|
|
|
|
|
return x - y;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template<typename Vec>
|
|
|
|
|
struct VecBinaryMul {
|
|
|
|
|
Vec operator()(Vec& x, Vec& y) const {
|
|
|
|
|
return x * y;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template<typename Vec>
|
|
|
|
|
struct VecBinaryMin {
|
|
|
|
|
Vec operator()(Vec& x, Vec& y) const {
|
|
|
|
|
return Vec::min(x, y);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template<typename Vec>
|
|
|
|
|
struct VecBinaryMax {
|
|
|
|
|
Vec operator()(Vec& x, Vec& y) const {
|
|
|
|
|
return Vec::max(x, y);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template<typename Vec>
|
|
|
|
|
struct VecBinarySqd {
|
|
|
|
|
Vec operator()(Vec& x, Vec& y) const {
|
|
|
|
|
return (x-y)*(x-y);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
namespace MNN {
|
|
|
|
|
template<typename Tin, typename Tout, typename Func>
|
|
|
|
|
void execute(void* outputRaw, const void* inputRaw0, const void* inputRaw1, int elementSize, int broadcastIndex) {
|
|
|
|
|
Func f;
|
|
|
|
|
const int input0DataCount = elementSize;
|
|
|
|
|
const int input1DataCount = elementSize;
|
|
|
|
|
const Tin* input0Data = (const Tin*)inputRaw0;
|
|
|
|
|
const Tin* input1Data = (const Tin*)inputRaw1;
|
|
|
|
|
Tout* outputData = (Tout*)outputRaw;
|
|
|
|
|
|
|
|
|
|
if (broadcastIndex == 0) { // data count == 1, not only mean scalar input, maybe of shape (1, 1, 1, ...,1)
|
|
|
|
|
for (int i = 0; i < input1DataCount; i++) {
|
|
|
|
|
outputData[i] = (Tout)(f(input0Data[0], input1Data[i]));
|
|
|
|
|
}
|
|
|
|
|
} else if (broadcastIndex == 1) {
|
|
|
|
|
for (int i = 0; i < input0DataCount; i++) {
|
|
|
|
|
outputData[i] = (Tout)(f(input0Data[i], input1Data[0]));
|
|
|
|
|
}
|
|
|
|
|
} else { // both input contains more than one element,which means no scalar input
|
|
|
|
|
for (int i = 0; i < input0DataCount; i++) {
|
|
|
|
|
outputData[i] = (Tout)(f(input0Data[i], input1Data[i]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<typename V, int pack>
|
|
|
|
|
MNNBinaryExecute selectVector(int type) {
|
|
|
|
|
switch (type) {
|
|
|
|
|
case BinaryOpOperation_ADD:
|
|
|
|
|
return executeVec<VecBinaryAdd<V>, V, pack>;
|
|
|
|
|
case BinaryOpOperation_SUB:
|
|
|
|
|
return executeVec<VecBinarySub<V>, V, pack>;
|
|
|
|
|
case BinaryOpOperation_MUL:
|
|
|
|
|
return executeVec<VecBinaryMul<V>, V, pack>;
|
|
|
|
|
case BinaryOpOperation_MINIMUM:
|
|
|
|
|
return executeVec<VecBinaryMin<V>, V, pack>;
|
|
|
|
|
case BinaryOpOperation_MAXIMUM:
|
|
|
|
|
return executeVec<VecBinaryMax<V>, V, pack>;
|
|
|
|
|
case BinaryOpOperation_SquaredDifference:
|
|
|
|
|
return executeVec<VecBinarySqd<V>, V, pack>;
|
|
|
|
|
}
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
};
|