MNN/source/backend/cpu/BinaryUtils.hpp

131 lines
3.7 KiB
C++
Raw Normal View History

#include <math.h>
#include <algorithm>
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
2021-04-08 15:34:23 +08:00
struct BinaryMax {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return std::max(x, y);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
2021-04-08 15:34:23 +08:00
struct BinaryMin {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return std::min(x, y);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
2021-04-08 15:34:23 +08:00
struct BinaryMul {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return x * y;
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
2021-04-08 15:34:23 +08:00
struct BinaryAdd {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return x + y;
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
2021-04-08 15:34:23 +08:00
struct BinarySub {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return x - y;
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
2021-04-08 15:34:23 +08:00
struct BinaryRealDiv {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return x / y;
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
2021-04-08 15:34:23 +08:00
struct BinaryMod {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return x - x / y;
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
2021-04-08 15:34:23 +08:00
struct BinaryGreater {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return (_ErrorCode)((x > y) ? 1 : 0);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
2021-04-08 15:34:23 +08:00
struct BinaryLess {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return (_ErrorCode)((x < y) ? 1 : 0);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
2021-04-08 15:34:23 +08:00
struct BinaryGreaterEqual {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return (_ErrorCode)((x >= y) ? 1 : 0);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
2021-04-08 15:34:23 +08:00
struct BinaryLessEqual {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return (_ErrorCode)((x <= y) ? 1 : 0);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
2021-04-08 15:34:23 +08:00
struct BinaryEqual {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return (_ErrorCode)((x == y) ? 1 : 0);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
2021-04-08 15:34:23 +08:00
struct BinaryFloorDiv {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
2021-02-07 10:45:07 +08:00
return floor(static_cast<float>(x) / y);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
2021-04-08 15:34:23 +08:00
struct BinaryFloorMod {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return x - floor(x / y) * y;
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
2021-04-08 15:34:23 +08:00
struct BinarySquaredDifference {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return (x - y) * (x - y);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
2021-04-08 15:34:23 +08:00
struct BinaryPow {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return pow(x, y);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
2021-04-08 15:34:23 +08:00
struct BinaryAtan2 {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return atan(x / y);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
2021-04-08 15:34:23 +08:00
struct BinaryLogicalOr {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return (_ErrorCode)((x || y) ? 1 : 0);
}
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
2021-04-08 15:34:23 +08:00
struct BinaryNotEqual {
_ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
return (_ErrorCode)((x != y) ? 1 : 0);
}
};