MNN/source/backend/nnapi/execution/NNAPIBinary.cpp

58 lines
2.0 KiB
C++
Raw Normal View History

2022-09-30 10:02:52 +08:00
//
// NNAPIBinary.cpp
// MNN
//
// Created by MNN on 2022/09/05.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "NNAPIBinary.hpp"
namespace MNN {
// Constructs the NNAPI binary executor.
// The backend `b` and the op description `op` are forwarded to NNAPICommonExecution;
// `inputs` and `outputs` are accepted but not used by this constructor.
NNAPIBinary::NNAPIBinary(MNN::Backend *b,
                         const MNN::Op *op,
                         const std::vector<Tensor *> &inputs,
                         const std::vector<MNN::Tensor *> &outputs)
    : NNAPICommonExecution(b, op) {
    // Nothing else to initialise here.
}
// Translates an MNN BinaryOp / Eltwise op into the matching NNAPI arithmetic operation.
//
// inputs  : the two operand tensors of the binary operation.
// outputs : the single result tensor.
// Returns : whatever buildOperation() returns on success, or NOT_SUPPORT when the
//           op kind, the arithmetic type, or the tensor arity cannot be expressed
//           with the ADD / SUB / MUL / DIV mapping below.
ErrorCode NNAPIBinary::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    MNN_ASSERT(inputs.size() == 2 && outputs.size() == 1);
    // NOTE(review): MNN_ASSERT may be compiled out in release builds, so also reject
    // an unexpected arity at runtime instead of emitting a malformed operation.
    if (inputs.size() != 2 || outputs.size() != 1) {
        MNN_ERROR("[NNAPI] Binary expects 2 inputs and 1 output, got %d and %d\n",
                  static_cast<int>(inputs.size()), static_cast<int>(outputs.size()));
        return NOT_SUPPORT;
    }

    // MNN binary type -> NNAPI operation code. Immutable, so build it only once.
    static const std::map<BinaryOpOperation, int> binary_map {
        {BinaryOpOperation_ADD, ANEURALNETWORKS_ADD},
        {BinaryOpOperation_SUB, ANEURALNETWORKS_SUB},
        {BinaryOpOperation_MUL, ANEURALNETWORKS_MUL},
        {BinaryOpOperation_DIV, ANEURALNETWORKS_DIV}
    };

    // Every path below either assigns binaryType or returns, so it is never read
    // while uninitialized.
    BinaryOpOperation binaryType;
    if (mOp->type() == OpType_BinaryOp) {
        binaryType = static_cast<BinaryOpOperation>(mOp->main_as_BinaryOp()->opType());
    } else if (mOp->type() == OpType_Eltwise) {
        // Eltwise ops are expressed through their BinaryOp equivalents.
        const auto elemType = mOp->main_as_Eltwise()->type();
        switch (elemType) {
            case EltwiseType_PROD:
                binaryType = BinaryOpOperation_MUL;
                break;
            case EltwiseType_SUM:
                binaryType = BinaryOpOperation_ADD;
                break;
            case EltwiseType_SUB:
                binaryType = BinaryOpOperation_SUB;
                break;
            case EltwiseType_MAXIMUM:
                // Not present in binary_map, so this is reported as unsupported below.
                binaryType = BinaryOpOperation_MAXIMUM;
                break;
            default:
                MNN_ERROR("[NNAPI] Binary not support eltwise type %d\n", static_cast<int>(elemType));
                return NOT_SUPPORT;
        }
    } else {
        MNN_ERROR("[NNAPI] Binary not support op type %d\n", static_cast<int>(mOp->type()));
        return NOT_SUPPORT;
    }

    const auto iter = binary_map.find(binaryType);
    if (iter == binary_map.end() || iter->second < 0) {
        MNN_ERROR("[NNAPI] Binary not support %s\n", MNN::EnumNameBinaryOpOperation(binaryType));
        return NOT_SUPPORT;
    }

    // The NNAPI operation takes the two operands followed by a fused-activation
    // scalar; no activation is fused here.
    auto inputIdxs = getTensorIdxs(inputs);
    inputIdxs.push_back(buildScalar(ANEURALNETWORKS_FUSED_NONE));
    return buildOperation(iter->second, inputIdxs, getTensorIdxs(outputs));
}
// Presumably registers NNAPIBinary as the NNAPI executor for OpType_BinaryOp ops
// (macro is defined outside this file — confirm).
REGISTER_NNAPI_OP_CREATOR(NNAPIBinary, OpType_BinaryOp)
2022-10-30 08:44:24 +08:00
// Presumably registers NNAPIBinary as the NNAPI executor for OpType_Eltwise ops as well
// (macro is defined outside this file — confirm); onResize handles both op types.
REGISTER_NNAPI_OP_CREATOR(NNAPIBinary, OpType_Eltwise)
2022-09-30 10:02:52 +08:00
} // namespace MNN