//
//  CPUBinary.cpp
//  MNN
//
//  Created by MNN on 2018/08/02.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "CPUBinary.hpp"
#include <math.h>
#include <algorithm>
#include "CPUBackend.hpp"
#include "compute/CommonOptFunction.h"
#include "compute/ConvOpt.h"
#include "core/Macro.h"
#include "core/Concurrency.h"
#include "core/OpCommonUtils.hpp"
namespace MNN {
#define MAX_DIM 6
CPUBinaryInt::CPUBinaryInt(Backend* b, int32_t type) : MNN::Execution(b), mType(type) {
    // nothing to do
}
CPUBinaryFloat::CPUBinaryFloat(Backend* b, int32_t type) : MNN::Execution(b), mType(type) {
    // nothing to do
}

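// onResize picks one of three execution strategies that onExecute will use for float tensors:
//  1. mElementProc != nullptr: both inputs have the same element count, or the broadcast reduces
//     to applying a contiguous vector along the innermost axis, so an optimized element-wise
//     kernel (MNNMatrixAdd/Prod/Sub/MaxCommon, for ADD / MUL / SUB / MAXIMUM) can be used.
//  2. mSupportScale: one operand is effectively a scalar (or a per-axis value) and the op is
//     ADD / SUB / MUL, so the work can be folded into a scale-and-bias pass.
//  3. Otherwise onExecute falls back to the generic templated _binaryOp below.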
ErrorCode CPUBinaryFloat::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    MNN_ASSERT(1 == outputs.size());
    const int input0DataCount = inputs[0]->elementSize();
    const int input1DataCount = inputs[1]->elementSize();
    const int outputDataCount = outputs[0]->elementSize();
    int maxCount = input0DataCount > input1DataCount ? input0DataCount : input1DataCount;
    mElementProc = nullptr;
    mSupportScale = false;
    if (outputs[0]->getType().code != halide_type_float || maxCount < 4 || (outputDataCount > input0DataCount && outputDataCount > input1DataCount)) {
        // Can't optimize
        return NO_ERROR;
    }
    auto eleProc = mElementProc; // still nullptr at this point
    switch (mType) {
        case BinaryOpOperation_MUL:
            eleProc = MNNMatrixProdCommon;
            break;
        case BinaryOpOperation_ADD:
            eleProc = MNNMatrixAddCommon;
            break;
        case BinaryOpOperation_MAXIMUM:
            eleProc = MNNMatrixMaxCommon;
            break;
        case BinaryOpOperation_SUB:
            eleProc = MNNMatrixSubCommon;
            break;
        default:
            break;
    }
    if (input1DataCount == input0DataCount) {
        mOutside = 1;
        mInside = input0DataCount;
        mElementProc = eleProc;
        return NO_ERROR;
    }
    if (input1DataCount == 1 || input0DataCount == 1) {
        mAxis = 1;
        mOutside = 1;
        switch (mType) {
            case BinaryOpOperation_MUL:
            case BinaryOpOperation_ADD:
            case BinaryOpOperation_SUB:
                mSupportScale = true;
                break;
            default:
                break;
        }
        return NO_ERROR;
    }
    if (nullptr == eleProc) {
        return NO_ERROR;
    }
    // For AddBias / Mul style broadcast patterns
    int dims[MAX_DIM];
    int stride[MAX_DIM];
    int iStride0[MAX_DIM];
    int iStride1[MAX_DIM];
    const Tensor* input0 = inputs[0];
    const Tensor* input1 = inputs[1];
    const Tensor* output = outputs[0];
    if (input0DataCount < input1DataCount) {
        input0 = inputs[1];
        input1 = inputs[0];
    }
    OpCommonUtils::broastCastComputeDim(dims, stride, iStride0, iStride1, input0, input1, output);
    int breakPos = -1;
    for (int i = 0; i < MAX_DIM; ++i) {
        if (iStride1[i] > 0) {
            if (breakPos >= 0) {
                // Failed to optimize
                return NO_ERROR;
            }
            breakPos = i;
        }
    }
    MNN_ASSERT(breakPos >= 0);
    //FUNC_PRINT(breakPos);
    mOutside = 1;
    mInside = 1;
    for (int i = 0; i < breakPos; ++i) {
        mOutside *= dims[i];
    }
    mAxis = dims[breakPos];
    for (int i = breakPos + 1; i < MAX_DIM; ++i) {
        mInside *= dims[i];
    }
    // Several platforms need memory aligned to 4 * sizeof(float)
    if (1 == mInside && mAxis >= 4) {
        mElementProc = eleProc;
        //MNN_PRINT("Open Optimize\n");
    } else if (BinaryOpOperation_MAXIMUM != mType && mInside >= 4) {
        mSupportScale = true;
    }
    //MNN_PRINT("%d, %d, %d\n", mInside, mAxis, mOutside);
    return NO_ERROR;
}

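// Generic fallback: applies the functor Func element by element, handling a scalar operand,
// identically shaped operands, and general broadcasting over at most MAX_DIM (6) dimensions
// using the strides produced by OpCommonUtils::broastCastComputeDim.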
template <typename Tin, typename Tout, typename Func>
static ErrorCode _binaryOp(Tensor* input0, Tensor* input1, Tensor* output) {
    Func f;
    const int input0DataCount = input0->elementSize();
    const int input1DataCount = input1->elementSize();
    const Tin* input0Data = input0->host<Tin>();
    const Tin* input1Data = input1->host<Tin>();
    Tout* outputData      = output->host<Tout>();

    if (input0DataCount == 1) { // an element count of 1 does not only mean a scalar input; the shape may also be (1, 1, ..., 1)
        for (int i = 0; i < input1DataCount; i++) {
            outputData[i] = static_cast<Tout>(f(input0Data[0], input1Data[i]));
        }
    } else if (input1DataCount == 1) {
        for (int i = 0; i < input0DataCount; i++) {
            outputData[i] = static_cast<Tout>(f(input0Data[i], input1Data[0]));
        }
    } else { // both inputs contain more than one element, i.e. neither is a scalar
        bool sameShape = true;
        {
            if (input0->dimensions() == input1->dimensions()) {
                for (int i = 0; i < input0->buffer().dimensions; i++) {
                    if (input0->buffer().dim[i].extent != input1->buffer().dim[i].extent) {
                        sameShape = false;
                        break;
                    }
                }
            } else {
                sameShape = false;
            }
        }
        if (sameShape) { // two inputs have the same shape, apply element-wise operation
            for (int i = 0; i < input0DataCount; i++) {
                outputData[i] = static_cast<Tout>(f(input0Data[i], input1Data[i]));
            }
        } else { // not the same shape, use broadcast
            MNN_ASSERT(output->dimensions() <= MAX_DIM);
            int dims[MAX_DIM];
            int stride[MAX_DIM];
            int iStride0[MAX_DIM];
            int iStride1[MAX_DIM];
            OpCommonUtils::broastCastComputeDim(dims, stride, iStride0, iStride1, input0, input1, output);
            for (int w = 0; w < dims[5]; ++w) {
                auto ow  = outputData + w * stride[5];
                auto i0w = input0Data + w * iStride0[5];
                auto i1w = input1Data + w * iStride1[5];
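// PTR(x, y, i) advances the output and the two input pointers from loop level `y` to the next
// inner loop level `x`, using the per-dimension strides. For example, PTR(v, w, 4) expands to:
//     auto ov  = ow  + v * stride[4];
//     auto i0v = i0w + v * iStride0[4];
//     auto i1v = i1w + v * iStride1[4]
// A broadcast dimension has a stride of 0 in iStride0/iStride1, so the same input element is
// reused along that axis.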
#define PTR(x, y, i)                      \
    auto o##x  = o##y + x * stride[i];    \
    auto i0##x = i0##y + x * iStride0[i]; \
    auto i1##x = i1##y + x * iStride1[i]

                for (int v = 0; v < dims[4]; ++v) {
                    PTR(v, w, 4);
                    for (int u = 0; u < dims[3]; ++u) {
                        PTR(u, v, 3);
                        for (int z = 0; z < dims[2]; ++z) {
                            PTR(z, u, 2);
                            for (int y = 0; y < dims[1]; ++y) {
                                PTR(y, z, 1);
                                for (int x = 0; x < dims[0]; ++x) {
                                    PTR(x, y, 0);
                                    *ox = static_cast<Tout>(f(*i0x, *i1x));
                                }
                            }
                        }
                    }
                }
            }
#undef MAX_DIM
#undef PTR
        }
        // broadcast compatibility was already checked during shape inference
    }

    return NO_ERROR;
}

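// Element-wise functors used by the generic _binaryOp path. Comparison and logical ops return
// 0/1 and are instantiated with an int32_t output type; arithmetic ops keep the input type.
// Notes: std::binary_function is deprecated in C++11 (removed in C++17) and its typedefs are
// not actually used here. For integer instantiations, FloorDiv/FloorMod rely on C++'s truncating
// integer division, so their results differ from a true floor for negative operands, and
// BinaryAtan2 computes atan(x / y), which does not recover the quadrant the way std::atan2 does.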
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryMax : std::binary_function<_Arg1, _Arg2, _ErrorCode> {
    _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
        return std::max(x, y);
    }
};

template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryMin : std::binary_function<_Arg1, _Arg2, _ErrorCode> {
    _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
        return std::min(x, y);
    }
};

template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryMul : std::binary_function<_Arg1, _Arg2, _ErrorCode> {
    _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
        return x * y;
    }
};

template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryAdd : std::binary_function<_Arg1, _Arg2, _ErrorCode> {
    _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
        return x + y;
    }
};

template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinarySub : std::binary_function<_Arg1, _Arg2, _ErrorCode> {
    _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
        return x - y;
    }
};

template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryRealDiv : std::binary_function<_Arg1, _Arg2, _ErrorCode> {
    _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
        return x / y;
    }
};

template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryMod : std::binary_function<_Arg1, _Arg2, _ErrorCode> {
    _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
        // Truncated modulo: x - trunc(x / y) * y, which equals x % y for integer arguments.
        return x - y * static_cast<_Arg1>(trunc(static_cast<double>(x) / static_cast<double>(y)));
    }
};

template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryGreater : std::binary_function<_Arg1, _Arg2, _ErrorCode> {
    _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
        return (_ErrorCode)((x > y) ? 1 : 0);
    }
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryLess : std::binary_function<_Arg1, _Arg2, _ErrorCode> {
    _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
        return (_ErrorCode)((x < y) ? 1 : 0);
    }
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryGreaterEqual : std::binary_function<_Arg1, _Arg2, _ErrorCode> {
    _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
        return (_ErrorCode)((x >= y) ? 1 : 0);
    }
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryLessEqual : std::binary_function<_Arg1, _Arg2, _ErrorCode> {
    _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
        return (_ErrorCode)((x <= y) ? 1 : 0);
    }
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryEqual : std::binary_function<_Arg1, _Arg2, _ErrorCode> {
    _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
        return (_ErrorCode)((x == y) ? 1 : 0);
    }
};
template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryFloorDiv : std::binary_function<_Arg1, _Arg2, _ErrorCode> {
    _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
        return floor(x / y);
    }
};

template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryFloorMod : std::binary_function<_Arg1, _Arg2, _ErrorCode> {
    _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
        return x - floor(x / y) * y;
    }
};

template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinarySquaredDifference : std::binary_function<_Arg1, _Arg2, _ErrorCode> {
    _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
        return (x - y) * (x - y);
    }
};

template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryPow : std::binary_function<_Arg1, _Arg2, _ErrorCode> {
    _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
        return pow(x, y);
    }
};

template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryAtan2 : std::binary_function<_Arg1, _Arg2, _ErrorCode> {
    _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
        return atan(x / y);
    }
};

template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryLogicalOr : std::binary_function<_Arg1, _Arg2, _ErrorCode> {
    _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
        return (_ErrorCode)((x || y) ? 1 : 0);
    }
};

template <typename _Arg1, typename _Arg2, typename _ErrorCode>
struct BinaryNotEqual : std::binary_function<_Arg1, _Arg2, _ErrorCode> {
    _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const {
        return (_ErrorCode)((x != y) ? 1 : 0);
    }
};

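// Helper for the mElementProc path: onExecute may have reordered the inputs so that the larger
// tensor comes first; `swap` restores the original operand order when invoking the kernel,
// which matters for non-commutative ops such as SUB.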
static void callEleFunc(void(*proc)(float* C, const float* A, const float* B, size_t width, size_t cStride, size_t aStride, size_t bStride, size_t height),
                        float* C, const float* A, const float* B, size_t size, bool swap) {
    if (swap) {
        proc(C, B, A, size, 0, 0, 0, 1);
    } else {
        proc(C, A, B, size, 0, 0, 0, 1);
    }
}

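// Float execution. If onResize enabled a fast path, the work is split across threads and runs
// either through the matrix kernel (mElementProc) or through MNNScaleAndAddBiasScalar
// (mSupportScale); otherwise every op type dispatches to the templated _binaryOp fallback.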
ErrorCode CPUBinaryFloat::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    auto input  = inputs[0];
    auto input1 = inputs[1];
    auto output = outputs[0];

    if (nullptr != mElementProc || mSupportScale) {
        auto numberThread = ((CPUBackend*)backend())->threadNumber();
        auto i1Size = input->elementSize();
        auto i2Size = input1->elementSize();
        bool swap = false;
        if (i1Size < i2Size) {
            auto temp = i2Size;
            i2Size = i1Size;
            i1Size = temp;
            input = inputs[1];
            input1 = inputs[0];
            swap = true;
        }
        auto size = i1Size;
        auto schedule = ((CPUBackend*)backend())->multiThreadDivide(size);
        int sizeDivide = schedule.first;
        int scheduleNumber = schedule.second;
        if (nullptr != mElementProc) {
            if (mOutside == 1) {
                MNN_CONCURRENCY_BEGIN(tId, scheduleNumber) {
                    int start = sizeDivide * (int)tId;
                    int realSize = sizeDivide;
                    if (tId == scheduleNumber - 1) {
                        realSize = size - start;
                    }
                    if (realSize > 0) {
                        mElementProc(output->host<float>() + start, input->host<float>() + start, input1->host<float>() + start, realSize, 0, 0, 0, 1);
                    }
                }
                MNN_CONCURRENCY_END();
            } else {
                MNN_CONCURRENCY_BEGIN(tId, numberThread) {
                    for (int y = tId; y < mOutside; y += numberThread) {
                        callEleFunc(mElementProc, output->host<float>() + y * mAxis, input->host<float>() + y * mAxis, input1->host<float>(), mAxis, swap);
                    }
                }
                MNN_CONCURRENCY_END();
            }
        } else {
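            // Scale-and-bias folding: with a scalar (or per-axis) operand s, the op is rewritten
            // as out = in * scale + bias so a single MNNScaleAndAddBiasScalar call can do it:
            //   MUL:          scale = s,  bias = 0
            //   ADD:          scale = 1,  bias = s
            //   SUB (x - s):  scale = 1,  bias = -s
            //   SUB (s - x):  scale = -1, bias = s   (the `swap` case)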
            if (mOutside == 1 && mAxis == 1) {
                float* inputPtr = input->host<float>();
                float scalar = input1->host<float>()[0];
                float scale = scalar;
                float bias = 0.0f;
                switch (mType) {
                    case BinaryOpOperation_ADD:
                        scale = 1.0f;
                        bias = scalar;
                        break;
                    case BinaryOpOperation_SUB:
                        if (!swap) {
                            scale = 1.0f;
                            bias = -scalar;
                        } else {
                            scale = -1.0f;
                            bias = scalar;
                        }
                        break;
                    default:
                        break;
                }

                MNN_CONCURRENCY_BEGIN(tId, scheduleNumber) {
                    int start = sizeDivide * (int)tId;
                    int realSize = sizeDivide;
                    if (tId == scheduleNumber - 1) {
                        realSize = size - start;
                    }
                    if (realSize > 0) {
                        MNNScaleAndAddBiasScalar(output->host<float>() + start, inputPtr + start, bias, scale, realSize);
                    }
                }
                MNN_CONCURRENCY_END();
            } else {
                float* inputPtr = input->host<float>();
                float* input1Ptr = input1->host<float>();
                auto total = mOutside * mAxis;
                MNN_CONCURRENCY_BEGIN(tId, numberThread) {
                    for (int index = tId; index < total; index += numberThread) {
                        auto axis = index % mAxis;
                        float scalar = input1Ptr[axis];
                        float scale = scalar;
                        float bias = 0.0f;
                        switch (mType) {
                            case BinaryOpOperation_ADD:
                                scale = 1.0f;
                                bias = scalar;
                                break;
                            case BinaryOpOperation_SUB:
                                if (!swap) {
                                    scale = 1.0f;
                                    bias = -scalar;
                                } else {
                                    scale = -1.0f;
                                    bias = scalar;
                                }
                                break;
                            default:
                                break;
                        }
                        MNNScaleAndAddBiasScalar(output->host<float>() + mInside * index, inputPtr + mInside * index, bias, scale, mInside);
                    }
                }
                MNN_CONCURRENCY_END();
            }

        }
        return NO_ERROR;
    }

    switch (mType) {
        case BinaryOpOperation_MUL:
            _binaryOp<float, float, BinaryMul<float, float, float>>(input, input1, output);
            break;
        case BinaryOpOperation_ADD:
            _binaryOp<float, float, BinaryAdd<float, float, float>>(input, input1, output);
            break;
        case BinaryOpOperation_SUB:
            _binaryOp<float, float, BinarySub<float, float, float>>(input, input1, output);
            break;

        case BinaryOpOperation_REALDIV:
            _binaryOp<float, float, BinaryRealDiv<float, float, float>>(input, input1, output);
            break;
        case BinaryOpOperation_MINIMUM:
            _binaryOp<float, float, BinaryMin<float, float, float>>(input, input1, output);
            break;
        case BinaryOpOperation_MAXIMUM:
            _binaryOp<float, float, BinaryMax<float, float, float>>(input, input1, output);
            break;
        case BinaryOpOperation_GREATER:
            _binaryOp<float, int32_t, BinaryGreater<float, float, int32_t>>(input, input1, output);
            break;
        case BinaryOpOperation_LESS:
            _binaryOp<float, int32_t, BinaryLess<float, float, int32_t>>(input, input1, output);
            break;
        case BinaryOpOperation_LESS_EQUAL:
            _binaryOp<float, int32_t, BinaryLessEqual<float, float, int32_t>>(input, input1, output);
            break;
        case BinaryOpOperation_GREATER_EQUAL:
            _binaryOp<float, int32_t, BinaryGreaterEqual<float, float, int32_t>>(input, input1, output);
            break;
        case BinaryOpOperation_EQUAL:
            _binaryOp<float, int32_t, BinaryEqual<float, float, int32_t>>(input, input1, output);
            break;
        case BinaryOpOperation_FLOORDIV:
            _binaryOp<float, float, BinaryFloorDiv<float, float, float>>(input, input1, output);
            break;
        case BinaryOpOperation_FLOORMOD:
            _binaryOp<float, float, BinaryFloorMod<float, float, float>>(input, input1, output);
            break;
        case BinaryOpOperation_POW:
            _binaryOp<float, float, BinaryPow<float, float, float>>(input, input1, output);
            break;
        case BinaryOpOperation_SquaredDifference:
            _binaryOp<float, float, BinarySquaredDifference<float, float, float>>(input, input1, output);
            break;
        case BinaryOpOperation_ATAN2:
            _binaryOp<float, float, BinaryAtan2<float, float, float>>(input, input1, output);
            break;
        case BinaryOpOperation_NOTEQUAL:
            _binaryOp<float, int32_t, BinaryNotEqual<float, float, int32_t>>(input, input1, output);
            break;
        case BinaryOpOperation_MOD:
            _binaryOp<float, float, BinaryMod<float, float, float>>(input, input1, output);
            break;
        default:
            MNN_ASSERT(false);
            break;
    }
    return NO_ERROR;
}

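// Integer execution always goes through the generic _binaryOp fallback; there is no vectorized
// fast path for int32 here.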
ErrorCode CPUBinaryInt::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    auto input  = inputs[0];
    auto input1 = inputs[1];
    auto output = outputs[0];
    switch (mType) {
        case BinaryOpOperation_MUL:
            _binaryOp<int32_t, int32_t, BinaryMul<int32_t, int32_t, int32_t>>(input, input1, output);
            break;
        case BinaryOpOperation_ADD:
            _binaryOp<int32_t, int32_t, BinaryAdd<int32_t, int32_t, int32_t>>(input, input1, output);
            break;
        case BinaryOpOperation_SUB:
            _binaryOp<int32_t, int32_t, BinarySub<int32_t, int32_t, int32_t>>(input, input1, output);
            break;

        case BinaryOpOperation_REALDIV:
            _binaryOp<int32_t, int32_t, BinaryRealDiv<int32_t, int32_t, int32_t>>(input, input1, output);
            break;
        case BinaryOpOperation_MINIMUM:
            _binaryOp<int32_t, int32_t, BinaryMin<int32_t, int32_t, int32_t>>(input, input1, output);
            break;
        case BinaryOpOperation_MAXIMUM:
            _binaryOp<int32_t, int32_t, BinaryMax<int32_t, int32_t, int32_t>>(input, input1, output);
            break;
        case BinaryOpOperation_GREATER:
            _binaryOp<int32_t, int32_t, BinaryGreater<int32_t, int32_t, int32_t>>(input, input1, output);
            break;
        case BinaryOpOperation_LESS:
            _binaryOp<int32_t, int32_t, BinaryLess<int32_t, int32_t, int32_t>>(input, input1, output);
            break;
        case BinaryOpOperation_LESS_EQUAL:
            _binaryOp<int32_t, int32_t, BinaryLessEqual<int32_t, int32_t, int32_t>>(input, input1, output);
            break;
        case BinaryOpOperation_GREATER_EQUAL:
            _binaryOp<int32_t, int32_t, BinaryGreaterEqual<int32_t, int32_t, int32_t>>(input, input1, output);
            break;
        case BinaryOpOperation_EQUAL:
            _binaryOp<int32_t, int32_t, BinaryEqual<int32_t, int32_t, int32_t>>(input, input1, output);
            break;
        case BinaryOpOperation_FLOORDIV:
            _binaryOp<int32_t, int32_t, BinaryFloorDiv<int32_t, int32_t, int32_t>>(input, input1, output);
            break;
        case BinaryOpOperation_FLOORMOD:
            _binaryOp<int32_t, int32_t, BinaryFloorMod<int32_t, int32_t, int32_t>>(input, input1, output);
            break;
        case BinaryOpOperation_SquaredDifference:
            _binaryOp<int32_t, int32_t, BinarySquaredDifference<int32_t, int32_t, int32_t>>(input, input1, output);
            break;
        case BinaryOpOperation_LOGICALOR:
            _binaryOp<int32_t, int32_t, BinaryLogicalOr<int32_t, int32_t, int32_t>>(input, input1, output);
            break;
        case BinaryOpOperation_NOTEQUAL:
            _binaryOp<int32_t, int32_t, BinaryNotEqual<int32_t, int32_t, int32_t>>(input, input1, output);
            break;
        case BinaryOpOperation_MOD:
            _binaryOp<int32_t, int32_t, BinaryMod<int32_t, int32_t, int32_t>>(input, input1, output);
            break;
        default:
            MNN_ASSERT(false);
            break;
    }
    return NO_ERROR;
}

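// Creator: dispatches on the first input's data type. Only 32-bit int and 32-bit float are
// handled; anything else logs an error and returns nullptr.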
class CPUBinaryCreator : public CPUBackend::Creator {
public:
    virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                const MNN::Op* op, Backend* backend) const override {
        // auto dataType   = outputs[0]->getType();
        int32_t type = op->main_as_BinaryOp()->opType();
        // auto dataType = op->main_as_BinaryOp()->T();
        auto dataType = inputs[0]->getType();
        if (dataType.bits == 32) {
            if (dataType.code == halide_type_int) {
                return new CPUBinaryInt(backend, type);
            } else if (dataType.code == halide_type_float) {
                return new CPUBinaryFloat(backend, type);
            }
        }
        MNN_ERROR("CpuBinary: unsupported data type (bits: %d, code: %d)\n",
                  dataType.bits, dataType.code);
        return nullptr;
    }
};

REGISTER_CPU_OP_CREATOR(CPUBinaryCreator, OpType_BinaryOp);

} // namespace MNN