// Mirrored from https://github.com/alibaba/MNN.git (C++, 124 lines, 4.8 KiB at time of capture).
//
//  EltwiseExecution.cpp
//  MNN
//
//  Created by MNN on 2019/02/28.
//  Copyright © 2018, Alibaba Group Holding Limited
//
|
|
#include "EltwiseExecution.hpp"
|
|
|
|
#include <Macro.h>
|
|
#include <string.h>
|
|
#include "TensorUtils.hpp"
|
|
|
|
namespace MNN {
|
|
namespace OpenCL {
|
|
|
|
/**
 * Builds an element-wise execution for the given operator expression.
 *
 * @param inputs   input tensors; not read by this constructor.
 * @param compute  expression text substituted for the OPERATOR define,
 *                 e.g. "in0+in1".
 * @param backend  backend forwarded to CommonExecution.
 */
EltwiseExecution::EltwiseExecution(const std::vector<Tensor *> &inputs, const std::string &compute, Backend *backend)
    : CommonExecution(backend) {
    // The expression is handed to the kernel build as a -D option.
    const std::string operatorDefine = "-DOPERATOR=" + compute;
    mBuildOptions.emplace(operatorDefine);
}
|
|
|
|
/**
 * Prepares one "binary" kernel launch per pairwise combination of inputs.
 *
 * Unit 0 combines inputs[0] with inputs[1] into outputs[0]. Each further
 * input i (i >= 2) is combined with the current contents of outputs[0],
 * writing back into outputs[0], using unit i - 1.
 *
 * @param inputs   at least two input tensors (asserted).
 * @param outputs  outputs[0] receives the result and defines the launch shape.
 * @return NO_ERROR always.
 */
ErrorCode EltwiseExecution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    MNN_ASSERT(inputs.size() >= 2);
    // N inputs need N - 1 pairwise kernel launches.
    mUnits.resize(inputs.size() - 1);

    auto shape = tensorShapeFormat(outputs[0]);
    // Last dimension is divided by 4, rounded up — presumably channels packed
    // four per image pixel; confirm against the "binary" kernel.
    int packedShape[]  = {shape[0], shape[1], shape[2], UP_DIV(shape[3], 4)};
    int widthHeight[]  = {shape[2], shape[1]};
    int input1Stride[] = {1, 1, 1, 1};

    auto imageWidth  = packedShape[2] * packedShape[3];
    auto imageHeight = packedShape[0] * packedShape[1];

    // Global size is the image extent rounded up to a multiple of the 16x16 local size.
    cl::NDRange localSize  = {16, 16};
    cl::NDRange globalSize = {static_cast<uint32_t>(UP_DIV(imageWidth, 16)) * 16,
                              static_cast<uint32_t>(UP_DIV(imageHeight, 16)) * 16};

    auto runTime = static_cast<OpenCLBackend *>(backend())->getOpenCLRuntime();

    // Builds and configures the kernel of one unit. The arrays are captured by
    // reference so setArg still receives them as int[4] / int[2] objects.
    auto configureUnit = [&](size_t unitIndex, Tensor *first, Tensor *second) {
        auto &unit  = mUnits[unitIndex];
        unit.kernel = runTime->buildKernel("binary", "binary", mBuildOptions);
        unit.kernel.setArg(0, openCLImage(first));
        unit.kernel.setArg(1, openCLImage(second));
        unit.kernel.setArg(2, openCLImage(outputs[0]));
        unit.kernel.setArg(3, packedShape);
        unit.kernel.setArg(4, widthHeight);
        unit.kernel.setArg(5, input1Stride);
        unit.globalWorkSize = globalSize;
        unit.localWorkSize  = localSize;
    };

    configureUnit(0, inputs[0], inputs[1]);

    // NOTE(review): for the remaining inputs the new input is bound as argument 0
    // and the accumulated output as argument 1; if those map to in0/in1 this is
    // only order-independent for commutative operators — confirm in the kernel.
    for (size_t inputIndex = 2; inputIndex < inputs.size(); ++inputIndex) {
        configureUnit(inputIndex - 1, inputs[inputIndex], outputs[0]);
    }

    return NO_ERROR;
}
|
|
/**
 * Creator that maps Eltwise and BinaryOp operators onto EltwiseExecution.
 *
 * Returns nullptr for any op type, operation, or input shape combination it
 * does not handle.
 */
class EltwiseCreator : public OpenCLBackend::Creator {
public:
    /**
     * @param inputs   input tensors of the op.
     * @param outputs  output tensors of the op (not inspected here).
     * @param op       op description; its type selects the Eltwise or BinaryOp path.
     * @param backend  backend passed through to the created execution.
     * @return a newly allocated EltwiseExecution, or nullptr when unsupported.
     */
    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
                                const MNN::Op *op, Backend *backend) const override {
        const auto opType = op->type();

        if (opType == OpType_Eltwise) {
            const char *expression = nullptr;
            switch (op->main_as_Eltwise()->type()) {
                case EltwiseType_SUM:
                    expression = "in0+in1";
                    break;
                case EltwiseType_PROD:
                    expression = "in0*in1";
                    break;
                case EltwiseType_MAXIMUM:
                    expression = "fmax(in0, in1)";
                    break;
                default:
                    break;
            }
            if (expression == nullptr) {
                return nullptr;
            }
            return new EltwiseExecution(inputs, expression, backend);
        }

        if (opType != OpType_BinaryOp) {
            return nullptr;
        }

        MNN_ASSERT(inputs.size() > 1);

        // Broadcasting is not supported: every input must have the same rank
        // and the same length along every dimension as the first input.
        const auto reference = inputs[0];
        const auto rank      = reference->dimensions();
        for (size_t index = 1; index < inputs.size(); ++index) {
            const auto candidate = inputs[index];
            if (candidate->dimensions() != rank) {
                return nullptr;
            }
            for (int axis = 0; axis < rank; ++axis) {
                if (candidate->length(axis) != reference->length(axis)) {
                    return nullptr;
                }
            }
        }

        const char *expression = nullptr;
        switch (op->main_as_BinaryOp()->opType()) {
            case BinaryOpOperation_ADD:
                expression = "in0+in1";
                break;
            case BinaryOpOperation_SUB:
                expression = "in0-in1";
                break;
            case BinaryOpOperation_MUL:
                expression = "in0*in1";
                break;
            case BinaryOpOperation_POW:
                expression = "pow(in0, in1)";
                break;
            case BinaryOpOperation_DIV:
                expression = "in0/in1";
                break;
            case BinaryOpOperation_MAXIMUM:
                expression = "fmax(in0,in1)";
                break;
            case BinaryOpOperation_MINIMUM:
                expression = "fmin(in0,in1)";
                break;
            default:
                break;
        }
        if (expression == nullptr) {
            return nullptr;
        }
        return new EltwiseExecution(inputs, expression, backend);
    }
};
|
|
|
|
// File-scope OpenCLCreatorRegister instances, one per op type served by
// EltwiseCreator. Presumably their construction performs the registration
// with the OpenCL backend — confirm in OpenCLCreatorRegister.
OpenCLCreatorRegister<EltwiseCreator> __eltwise_op(OpType_Eltwise);
OpenCLCreatorRegister<EltwiseCreator> __binary_op(OpType_BinaryOp);
|
|
|
|
} // namespace OpenCL
|
|
} // namespace MNN
|