2019-04-17 10:49:11 +08:00
|
|
|
//
|
|
|
|
// CPUBinary.cpp
|
|
|
|
// MNN
|
|
|
|
//
|
|
|
|
// Created by MNN on 2018/08/02.
|
|
|
|
// Copyright © 2018, Alibaba Group Holding Limited
|
|
|
|
//
|
|
|
|
|
2020-01-15 13:33:47 +08:00
|
|
|
#include "CPUBinary.hpp"
|
2023-04-27 15:11:05 +08:00
|
|
|
#include "CPUBinaryInt8.hpp"
|
2020-01-15 13:33:47 +08:00
|
|
|
#include "CPUBackend.hpp"
|
|
|
|
#include "compute/CommonOptFunction.h"
|
|
|
|
#include "compute/ConvOpt.h"
|
2019-12-27 22:16:57 +08:00
|
|
|
#include "core/Macro.h"
|
2020-01-15 13:33:47 +08:00
|
|
|
#include "core/Concurrency.h"
|
2020-03-08 10:20:18 +08:00
|
|
|
#include "core/OpCommonUtils.hpp"
|
2021-01-07 17:08:34 +08:00
|
|
|
#include "BinaryUtils.hpp"
|
2021-06-11 17:17:13 +08:00
|
|
|
#include "math/Vec.hpp"
|
|
|
|
using Vec4 = MNN::Math::Vec<float, 4>;
|
2023-10-18 10:31:02 +08:00
|
|
|
using Vec4Int = MNN::Math::Vec<int32_t, 4>;
|
2021-06-11 17:17:13 +08:00
|
|
|
|
2019-04-17 10:49:11 +08:00
|
|
|
namespace MNN {
|
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
// Prepares per-shape state used by onExecute:
//  - mNeedBroadcastIndex: -1 when both inputs have the same raw element count,
//    0 when input0 holds exactly one element, otherwise 1 (input1 is the broadcast side).
//  - mTotalSize: output element count as reported by the CPU backend.
//  - mActivationExe: a CPURelu, created only when a ReLU (mActivationType == 1) is fused
//    onto a float output; integer outputs are clamped inline in onExecute instead.
//  - mThreadNum / mWorkDiv: number of worker slices and the length of each slice.
// Returns the fused ReLU's resize status if it fails, NO_ERROR otherwise.
ErrorCode CPUBinary::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    auto cpuBackend = static_cast<CPUBackend*>(backend());

    auto input0DataCount = TensorUtils::getRawSize(inputs[0]);
    auto input1DataCount = TensorUtils::getRawSize(inputs[1]);
    if (input1DataCount == input0DataCount) {
        mNeedBroadcastIndex = -1;
    } else if (input0DataCount == 1) {
        mNeedBroadcastIndex = 0;
    } else {
        mNeedBroadcastIndex = 1;
    }
    mTotalSize = cpuBackend->getTensorSize(outputs[0]);

    if (mActivationType == 1 && outputs[0]->getType().code == halide_type_float) {
        mActivationExe.reset(new CPURelu(backend(), 0.0));
        // Do not silently drop a failed ReLU resize: onExecute would later run it anyway.
        auto code = mActivationExe->onResize(outputs, outputs);
        if (NO_ERROR != code) {
            return code;
        }
    }

    // Small workloads run on a single thread; otherwise split the output evenly across threads.
    // NOTE(review): the meaning of getTensorSize's `false` argument is defined by CPUBackend — confirm.
    const int threads = cpuBackend->threadNumber();
    if (cpuBackend->getTensorSize(outputs[0], false) < LAUNCH_MULTI_THREADS_WORKLOAD) {
        mThreadNum = 1;
        mWorkDiv   = mTotalSize;
    } else {
        mThreadNum = threads;
        mWorkDiv   = UP_DIV(mTotalSize, threads);
    }
    return NO_ERROR;
}
|
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
// Runs the selected binary kernel (mProc) over the output, split into mThreadNum slices of
// mWorkDiv elements each (both computed in onResize). The operand chosen by
// mNeedBroadcastIndex is read from its base pointer for every slice rather than offset.
// A fused ReLU (mActivationType == 1) is applied inline for int outputs and via
// mActivationExe for float outputs; the latter's status is returned.
ErrorCode CPUBinary::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    auto input  = inputs[0];
    auto input1 = inputs[1];
    auto output = outputs[0];

    auto input0Ptr = input->host<uint8_t>();
    auto input1Ptr = input1->host<uint8_t>();
    auto outputPtr = output->host<uint8_t>();

    // Byte width per element. For float tensors the width comes from the backend's core
    // functions rather than the tensor type — presumably to cover a reduced-precision
    // float layout (e.g. fp16); confirm against CPUBackend.
    auto core    = static_cast<CPUBackend*>(backend())->functions();
    int inpBytes = input->getType().bytes();
    int outBytes = output->getType().bytes();
    if (halide_type_float == input->getType().code) {
        inpBytes = core->bytes;
    }
    if (halide_type_float == output->getType().code) {
        outBytes = core->bytes;
    }

    // Loop-invariant: integer outputs get their fused ReLU clamped inside each worker slice.
    const bool applyIntRelu = (mActivationType == 1) && (output->getType().code == halide_type_int);

    MNN_CONCURRENCY_BEGIN(tId, mThreadNum) {
        int start    = tId * mWorkDiv;
        int realSize = ALIMIN(mWorkDiv, mTotalSize - start);
        if (realSize > 0) {
            auto inp0 = (mNeedBroadcastIndex == 0) ? input0Ptr : input0Ptr + start * inpBytes;
            auto inp1 = (mNeedBroadcastIndex == 1) ? input1Ptr : input1Ptr + start * inpBytes;
            auto out  = outputPtr + start * outBytes;
            mProc(out, inp0, inp1, realSize, mNeedBroadcastIndex);
            if (applyIntRelu) {
                auto outInt = reinterpret_cast<int32_t*>(out);
                for (int i = 0; i < realSize; ++i) {
                    outInt[i] = outInt[i] > 0 ? outInt[i] : 0;
                }
            }
        }
    }
    MNN_CONCURRENCY_END();

    if (mActivationType == 1 && output->getType().code == halide_type_float) {
        // The float ReLU was prepared in onResize; surface its status instead of discarding it.
        return mActivationExe->onExecute(outputs, outputs);
    }
    return NO_ERROR;
}
|
2019-04-17 10:49:11 +08:00
|
|
|
|
2021-06-11 17:17:13 +08:00
|
|
|
// Returns the float kernel for BinaryOpOperation `type`.
// A 4-lane SIMD kernel from selectVector is preferred. Otherwise a scalar fallback is used
// for the ops listed below; NOTEQUAL is the only one producing int32 output.
// Unknown ops trip MNN_ASSERT and yield nullptr.
MNNBinaryExecute CPUBinary::selectForFloat(int type) {
    if (auto simdKernel = selectVector<Vec4, 4, float>(type)) {
        return simdKernel;
    }
    switch (type) {
        case BinaryOpOperation_REALDIV:
            return execute<float, float, BinaryRealDiv<float, float, float>>;
        case BinaryOpOperation_FLOORDIV:
            return execute<float, float, BinaryFloorDiv<float, float, float>>;
        case BinaryOpOperation_FLOORMOD:
            return execute<float, float, BinaryFloorMod<float, float, float>>;
        case BinaryOpOperation_NOTEQUAL:
            return execute<float, int32_t, BinaryNotEqual<float, float, int32_t>>;
        case BinaryOpOperation_POW:
            return execute<float, float, BinaryPow<float, float, float>>;
        case BinaryOpOperation_ATAN2:
            return execute<float, float, BinaryAtan2<float, float, float>>;
        case BinaryOpOperation_MOD:
            return execute<float, float, BinaryMod<float, float, float>>;
        default:
            break;
    }
    // Reached only for ops with neither a SIMD nor a scalar float kernel.
    MNN_ASSERT(false);
    return nullptr;
}
|
|
|
|
|
2022-12-30 15:18:58 +08:00
|
|
|
// Returns the int32 kernel for BinaryOpOperation `type`.
// A 4-lane SIMD kernel from selectVector is preferred. Otherwise a scalar int32 -> int32
// fallback is used (MOD uses the integer-specific BinaryModInt functor).
// Unsupported ops log an error, trip MNN_ASSERT, and yield nullptr.
MNNBinaryExecute CPUBinary::selectForInt(int type) {
    if (auto simdKernel = selectVector<Vec4Int, 4, int32_t>(type)) {
        return simdKernel;
    }
    switch (type) {
        case BinaryOpOperation_MUL:
            return execute<int32_t, int32_t, BinaryMul<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_REALDIV:
            return execute<int32_t, int32_t, BinaryRealDiv<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_FLOORDIV:
            return execute<int32_t, int32_t, BinaryFloorDiv<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_FLOORMOD:
            return execute<int32_t, int32_t, BinaryFloorMod<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_LOGICALOR:
            return execute<int32_t, int32_t, BinaryLogicalOr<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_NOTEQUAL:
            return execute<int32_t, int32_t, BinaryNotEqual<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_MOD:
            return execute<int32_t, int32_t, BinaryModInt<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_LOGICALXOR:
            return execute<int32_t, int32_t, BinaryLogicalXor<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_LEFTSHIFT:
            return execute<int32_t, int32_t, BinaryLeftShift<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_RIGHTSHIFT:
            return execute<int32_t, int32_t, BinaryRightShift<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_BITWISE_AND:
            return execute<int32_t, int32_t, BinaryBitwiseAnd<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_BITWISE_OR:
            return execute<int32_t, int32_t, BinaryBitwiseOr<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_BITWISE_XOR:
            return execute<int32_t, int32_t, BinaryBitwiseXor<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_POW:
            return execute<int32_t, int32_t, BinaryPow<int32_t, int32_t, int32_t>>;
        default:
            break;
    }
    // Reached only for ops with neither a SIMD nor a scalar int32 kernel.
    MNN_ERROR("Don't support binary - int compute for type %d\n", type);
    MNN_ASSERT(false);
    return nullptr;
}
|
|
|
|
|
|
|
|
// Builds the CPU execution for a BinaryOp.
// Dispatch order:
//  1. (MNN_SUPPORT_QUANT_EXTEND only) both inputs and the output are int8-like
//     (tagged DT_INT8 or 1-byte elements) -> CPUBinaryInt8.
//  2. 32-bit int input0 -> CPUBinary with CPUBinary::selectForInt.
//  3. 32-bit float input0 -> CPUBinary with the core's float selector.
// Returns nullptr when no kernel matches; unsupported input0 types are also logged.
class CPUBinaryCreator : public CPUBackend::Creator {
public:
    virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                const MNN::Op* op, Backend* backend) const override {
        auto binaryParam = op->main_as_BinaryOp();
        int32_t type     = binaryParam->opType();
        auto dataType    = inputs[0]->getType();
        auto core        = static_cast<CPUBackend*>(backend)->functions();
#ifdef MNN_SUPPORT_QUANT_EXTEND
        // A tensor takes the int8 path when tagged DT_INT8 or stored with 1-byte elements.
        auto isInt8Like = [](Tensor* t) {
            return CPUBackend::getDataType(t) == DataType_DT_INT8 || t->getType().bytes() == 1;
        };
        if (isInt8Like(inputs[0]) && isInt8Like(inputs[1]) && isInt8Like(outputs[0])) {
            auto int8Kernel = CPUBinaryInt8::selectForInt8(type);
            if (nullptr == int8Kernel) {
                return nullptr;
            }
            return new CPUBinaryInt8(backend, int8Kernel, binaryParam->activationType());
        }
#endif
        const bool is32Bit = (32 == dataType.bits);
        if (is32Bit && halide_type_int == dataType.code) {
            auto intKernel = CPUBinary::selectForInt(type);
            return (nullptr == intKernel) ? nullptr
                                          : new CPUBinary(backend, intKernel, binaryParam->activationType());
        }
        if (is32Bit && halide_type_float == dataType.code) {
            auto floatKernel = core->MNNSelectBinaryFunctionForFloat(type);
            return (nullptr == floatKernel) ? nullptr
                                            : new CPUBinary(backend, floatKernel, binaryParam->activationType());
        }
        MNN_ERROR("CpuBinary: unsupported data type (bits: %d, code: %d)\n",
                  dataType.bits, dataType.code);
        return nullptr;
    }
};
|
|
|
|
|
|
|
|
// Associates CPUBinaryCreator with OpType_BinaryOp for the CPU backend through the
// REGISTER_CPU_OP_CREATOR macro (the macro itself is defined outside this file).
REGISTER_CPU_OP_CREATOR(CPUBinaryCreator, OpType_BinaryOp);
|
|
|
|
|
|
|
|
} // namespace MNN
|