| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | //
 | 
					
						
							|  |  |  | //  CPUBinary.cpp
 | 
					
						
							|  |  |  | //  MNN
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | //  Created by MNN on 2018/08/02.
 | 
					
						
							|  |  |  | //  Copyright © 2018, Alibaba Group Holding Limited
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 |  |  | #include "CPUBinary.hpp"
 | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  | #include "CPUBinaryInt8.hpp"
 | 
					
						
							| 
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 |  |  | #include "CPUBackend.hpp"
 | 
					
						
							|  |  |  | #include "compute/CommonOptFunction.h"
 | 
					
						
							|  |  |  | #include "compute/ConvOpt.h"
 | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | #include "core/Macro.h"
 | 
					
						
							| 
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 |  |  | #include "core/Concurrency.h"
 | 
					
						
							| 
									
										
										
										
											2020-03-08 10:20:18 +08:00
										 |  |  | #include "core/OpCommonUtils.hpp"
 | 
					
						
							| 
									
										
										
										
											2021-01-07 17:08:34 +08:00
										 |  |  | #include "BinaryUtils.hpp"
 | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  | #include "math/Vec.hpp"
 | 
					
						
							|  |  |  | using Vec4 = MNN::Math::Vec<float, 4>; | 
					
						
							| 
									
										
										
										
											2023-10-18 10:31:02 +08:00
										 |  |  | using Vec4Int = MNN::Math::Vec<int32_t, 4>; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | namespace MNN { | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  | ErrorCode CPUBinary::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) { | 
					
						
							| 
									
										
										
										
											2023-08-21 14:51:54 +08:00
										 |  |  |     auto input0DataCount = TensorUtils::getRawSize(inputs[0]); | 
					
						
							|  |  |  |     auto input1DataCount = TensorUtils::getRawSize(inputs[1]); | 
					
						
							| 
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 |  |  |     if (input1DataCount == input0DataCount) { | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |         mNeedBroadcastIndex = -1; | 
					
						
							|  |  |  |     } else if (input0DataCount == 1) { | 
					
						
							|  |  |  |         mNeedBroadcastIndex = 0; | 
					
						
							|  |  |  |     } else { | 
					
						
							|  |  |  |         mNeedBroadcastIndex = 1; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2023-08-21 14:51:54 +08:00
										 |  |  |     mTotalSize = ((CPUBackend*)backend())->getTensorSize(outputs[0]); | 
					
						
							| 
									
										
										
										
											2022-09-30 10:02:52 +08:00
										 |  |  |      | 
					
						
							|  |  |  |     if(mActivationType == 1 && outputs[0]->getType().code == halide_type_float) { | 
					
						
							|  |  |  |         mActivationExe.reset(new CPURelu(backend(), 0.0)); | 
					
						
							|  |  |  |         mActivationExe->onResize(outputs, outputs); | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     return NO_ERROR; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  | ErrorCode CPUBinary::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) { | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     auto input  = inputs[0]; | 
					
						
							|  |  |  |     auto input1 = inputs[1]; | 
					
						
							|  |  |  |     auto output = outputs[0]; | 
					
						
							| 
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 |  |  |      | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     auto schedule = ((CPUBackend*)backend())->multiThreadDivide(mTotalSize); | 
					
						
							|  |  |  |     auto input0Ptr = input->host<uint8_t>(); | 
					
						
							|  |  |  |     auto input1Ptr = input1->host<uint8_t>(); | 
					
						
							| 
									
										
										
										
											2022-09-30 10:02:52 +08:00
										 |  |  |      | 
					
						
							|  |  |  |     auto outputPtr = outputs[0]->host<uint8_t>(); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     int inpBytes = input->getType().bytes(); | 
					
						
							|  |  |  |     int outBytes = output->getType().bytes(); | 
					
						
							|  |  |  |     if (halide_type_float == input->getType().code) { | 
					
						
							|  |  |  |         inpBytes = static_cast<CPUBackend*>(backend())->functions()->bytes; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     if (halide_type_float == output->getType().code) { | 
					
						
							|  |  |  |         outBytes = static_cast<CPUBackend*>(backend())->functions()->bytes; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     auto precision = static_cast<CPUBackend*>(backend())->precisionMode(); | 
					
						
							|  |  |  |     MNN_CONCURRENCY_BEGIN(tId, schedule.second) { | 
					
						
							|  |  |  |         int start = schedule.first * (int)tId; | 
					
						
							|  |  |  |         int realSize = schedule.first; | 
					
						
							|  |  |  |         if (tId == schedule.second -1 ) { | 
					
						
							|  |  |  |             realSize = mTotalSize - start; | 
					
						
							| 
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 |  |  |         } | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |         if (realSize > 0) { | 
					
						
							|  |  |  |             auto inp0 = input0Ptr + start * inpBytes; | 
					
						
							|  |  |  |             auto inp1 = input1Ptr + start * inpBytes; | 
					
						
							|  |  |  |             if (mNeedBroadcastIndex == 0) { | 
					
						
							|  |  |  |                 inp0 = input0Ptr; | 
					
						
							|  |  |  |             } else if (mNeedBroadcastIndex == 1) { | 
					
						
							|  |  |  |                 inp1 = input1Ptr; | 
					
						
							| 
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 |  |  |             } | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |             auto out = outputPtr + start * outBytes; | 
					
						
							|  |  |  |             mProc(out, inp0, inp1, realSize, mNeedBroadcastIndex); | 
					
						
							| 
									
										
										
										
											2022-09-30 10:02:52 +08:00
										 |  |  |             if(mActivationType == 1 && output->getType().code == halide_type_int) { | 
					
						
							|  |  |  |                 for(int i=0; i<realSize; i++) { | 
					
						
							|  |  |  |                     auto val = ((int32_t *)out)[i]; | 
					
						
							|  |  |  |                     auto res = val > 0 ? val : 0; | 
					
						
							|  |  |  |                     ((int32_t *)out)[i] = res; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							| 
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     MNN_CONCURRENCY_END(); | 
					
						
							| 
									
										
										
										
											2022-09-30 10:02:52 +08:00
										 |  |  |      | 
					
						
							|  |  |  |     if(mActivationType == 1 && output->getType().code == halide_type_float) { | 
					
						
							| 
									
										
										
										
											2024-02-29 16:21:40 +08:00
										 |  |  |         mActivationExe->onExecute(outputs, outputs); | 
					
						
							| 
									
										
										
										
											2022-09-30 10:02:52 +08:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     return NO_ERROR; | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  | MNNBinaryExecute CPUBinary::selectForFloat(int type) { | 
					
						
							| 
									
										
										
										
											2023-10-18 10:31:02 +08:00
										 |  |  |     auto vecFunction = selectVector<Vec4, 4, float>(type); | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     if (nullptr != vecFunction) { | 
					
						
							|  |  |  |         return vecFunction; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     switch (type) { | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         case BinaryOpOperation_REALDIV: | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |             return execute<float, float, BinaryRealDiv<float, float, float>>; | 
					
						
							|  |  |  |         case BinaryOpOperation_FLOORDIV: | 
					
						
							|  |  |  |             return execute<float, float, BinaryFloorDiv<float, float, float>>; | 
					
						
							|  |  |  |         case BinaryOpOperation_FLOORMOD: | 
					
						
							|  |  |  |             return execute<float, float, BinaryFloorMod<float, float, float>>; | 
					
						
							| 
									
										
										
										
											2024-04-19 12:54:03 +08:00
										 |  |  |         case BinaryOpOperation_NOTEQUAL: | 
					
						
							|  |  |  |             return execute<float, int32_t, BinaryNotEqual<float, float, int32_t>>; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |         case BinaryOpOperation_POW: | 
					
						
							|  |  |  |             return execute<float, float, BinaryPow<float, float, float>>; | 
					
						
							|  |  |  |         case BinaryOpOperation_ATAN2: | 
					
						
							|  |  |  |             return execute<float, float, BinaryAtan2<float, float, float>>; | 
					
						
							|  |  |  |         case BinaryOpOperation_MOD: | 
					
						
							|  |  |  |             return execute<float, float, BinaryMod<float, float, float>>; | 
					
						
							| 
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 |  |  |         default: | 
					
						
							|  |  |  |             MNN_ASSERT(false); | 
					
						
							|  |  |  |             break; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     return nullptr; | 
					
						
							| 
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-12-30 15:18:58 +08:00
										 |  |  | MNNBinaryExecute CPUBinary::selectForInt(int type) { | 
					
						
							| 
									
										
										
										
											2023-10-18 10:31:02 +08:00
										 |  |  |     auto vecFunction = selectVector<Vec4Int, 4, int32_t>(type); | 
					
						
							|  |  |  |     if (nullptr != vecFunction) { | 
					
						
							|  |  |  |         return vecFunction; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     switch (type) { | 
					
						
							| 
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 |  |  |         case BinaryOpOperation_MUL: | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |             return execute<int32_t, int32_t, BinaryMul<int32_t, int32_t, int32_t>>; | 
					
						
							| 
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 |  |  |         case BinaryOpOperation_REALDIV: | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |             return execute<int32_t, int32_t, BinaryRealDiv<int32_t, int32_t, int32_t>>; | 
					
						
							| 
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 |  |  |         case BinaryOpOperation_FLOORDIV: | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |             return execute<int32_t, int32_t, BinaryFloorDiv<int32_t, int32_t, int32_t>>; | 
					
						
							| 
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 |  |  |             break; | 
					
						
							|  |  |  |         case BinaryOpOperation_FLOORMOD: | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |             return execute<int32_t, int32_t, BinaryFloorMod<int32_t, int32_t, int32_t>>; | 
					
						
							| 
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 |  |  |             break; | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  |         case BinaryOpOperation_LOGICALOR: | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |             return execute<int32_t, int32_t, BinaryLogicalOr<int32_t, int32_t, int32_t>>; | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  |             break; | 
					
						
							|  |  |  |         case BinaryOpOperation_NOTEQUAL: | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |             return execute<int32_t, int32_t, BinaryNotEqual<int32_t, int32_t, int32_t>>; | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  |             break; | 
					
						
							|  |  |  |         case BinaryOpOperation_MOD: | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |             return execute<int32_t, int32_t, BinaryModInt<int32_t, int32_t, int32_t>>; | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  |             break; | 
					
						
							| 
									
										
										
										
											2022-01-04 10:50:40 +08:00
										 |  |  |         case BinaryOpOperation_LOGICALXOR: | 
					
						
							|  |  |  |             return execute<int32_t, int32_t, BinaryLogicalXor<int32_t, int32_t, int32_t>>; | 
					
						
							|  |  |  |             break; | 
					
						
							|  |  |  |         case BinaryOpOperation_LEFTSHIFT: | 
					
						
							|  |  |  |             return execute<int32_t, int32_t, BinaryLeftShift<int32_t, int32_t, int32_t>>; | 
					
						
							|  |  |  |             break; | 
					
						
							|  |  |  |         case BinaryOpOperation_RIGHTSHIFT: | 
					
						
							|  |  |  |             return execute<int32_t, int32_t, BinaryRightShift<int32_t, int32_t, int32_t>>; | 
					
						
							|  |  |  |             break; | 
					
						
							|  |  |  |         case BinaryOpOperation_BITWISE_AND: | 
					
						
							|  |  |  |             return execute<int32_t, int32_t, BinaryBitwiseAnd<int32_t, int32_t, int32_t>>; | 
					
						
							|  |  |  |             break; | 
					
						
							|  |  |  |         case BinaryOpOperation_BITWISE_OR: | 
					
						
							|  |  |  |             return execute<int32_t, int32_t, BinaryBitwiseOr<int32_t, int32_t, int32_t>>; | 
					
						
							|  |  |  |             break; | 
					
						
							|  |  |  |         case BinaryOpOperation_BITWISE_XOR: | 
					
						
							|  |  |  |             return execute<int32_t, int32_t, BinaryBitwiseXor<int32_t, int32_t, int32_t>>; | 
					
						
							|  |  |  |             break; | 
					
						
							| 
									
										
										
										
											2022-12-24 09:42:39 +08:00
										 |  |  |         case BinaryOpOperation_POW: | 
					
						
							|  |  |  |             return execute<int32_t, int32_t, BinaryPow<int32_t, int32_t, int32_t>>; | 
					
						
							|  |  |  |             break; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         default: | 
					
						
							| 
									
										
										
										
											2022-12-24 09:42:39 +08:00
										 |  |  |             MNN_ERROR("Don't support binary - int compute for type %d\n", type); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |             MNN_ASSERT(false); | 
					
						
							|  |  |  |             break; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     return nullptr; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class CPUBinaryCreator : public CPUBackend::Creator { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, | 
					
						
							|  |  |  |                                 const MNN::Op* op, Backend* backend) const override { | 
					
						
							|  |  |  |         int32_t type = op->main_as_BinaryOp()->opType(); | 
					
						
							| 
									
										
										
										
											2020-05-15 20:32:30 +08:00
										 |  |  |         auto dataType = inputs[0]->getType(); | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |         auto core = static_cast<CPUBackend*>(backend)->functions(); | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  |         auto input0Ptr = inputs[0]->host<uint8_t>(); | 
					
						
							|  |  |  |         if (CPUBackend::getDataType(inputs[0]) == DataType_DT_INT8 || inputs[0]->getType().bytes() == 1) { | 
					
						
							| 
									
										
										
										
											2023-06-16 09:42:45 +08:00
										 |  |  |             if (CPUBackend::getDataType(inputs[1]) == DataType_DT_INT8 || inputs[1]->getType().bytes() == 1) { | 
					
						
							|  |  |  |                 if (CPUBackend::getDataType(outputs[0]) == DataType_DT_INT8 || outputs[0]->getType().bytes() == 1) { | 
					
						
							|  |  |  |                     auto func = CPUBinaryInt8::selectForInt8(type); | 
					
						
							|  |  |  |                     if (nullptr == func) { | 
					
						
							|  |  |  |                         return nullptr; | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                     return new CPUBinaryInt8(backend, func, op->main_as_BinaryOp()->activationType()); | 
					
						
							|  |  |  |                 } | 
					
						
							| 
									
										
										
										
											2023-04-27 15:11:05 +08:00
										 |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2019-06-17 20:10:35 +08:00
										 |  |  |         if (dataType.bits == 32) { | 
					
						
							|  |  |  |             if (dataType.code == halide_type_int) { | 
					
						
							| 
									
										
										
										
											2022-12-30 15:18:58 +08:00
										 |  |  |                 auto func = CPUBinary::selectForInt(type); | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |                 if (nullptr == func) { | 
					
						
							|  |  |  |                     return nullptr; | 
					
						
							|  |  |  |                 } | 
					
						
							| 
									
										
										
										
											2022-09-30 10:02:52 +08:00
										 |  |  |                 return new CPUBinary(backend, func, op->main_as_BinaryOp()->activationType()); | 
					
						
							| 
									
										
										
										
											2020-05-15 20:32:30 +08:00
										 |  |  |             } else if (dataType.code == halide_type_float) { | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |                 auto func = core->MNNSelectBinaryFunctionForFloat(type); | 
					
						
							|  |  |  |                 if (nullptr == func) { | 
					
						
							|  |  |  |                     return nullptr; | 
					
						
							|  |  |  |                 } | 
					
						
							| 
									
										
										
										
											2022-09-30 10:02:52 +08:00
										 |  |  |                 return new CPUBinary(backend, func, op->main_as_BinaryOp()->activationType()); | 
					
						
							| 
									
										
										
										
											2019-06-17 20:10:35 +08:00
										 |  |  |             } | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         } | 
					
						
							| 
									
										
										
										
											2020-05-15 20:32:30 +08:00
										 |  |  |         MNN_ERROR("CpuBinary: unsupported data type (bits: %d, code: %d)\n", | 
					
						
							|  |  |  |                   dataType.bits, dataType.code); | 
					
						
							| 
									
										
										
										
											2019-06-17 20:10:35 +08:00
										 |  |  |         return nullptr; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | REGISTER_CPU_OP_CREATOR(CPUBinaryCreator, OpType_BinaryOp); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } // namespace MNN
 |