//
//  TRTActivation.cpp
//  MNN
//
//  Created by MNN on 2019/09/11.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "TRTActivation.hpp"
#include <core/TensorUtils.hpp>
#include "TRTBackend.hpp"
#include "schema/current/MNNPlugin_generated.h"

using namespace std;

namespace MNN {

// Describe each output tensor (shape, element byte width, data type) so the
// TensorRT plugin created from this description knows what it must produce.
static std::shared_ptr<MNNTRTPlugin::PluginT> createPluginWithOutput(const std::vector<Tensor *> &outputs) {
    std::shared_ptr<MNNTRTPlugin::PluginT> plu(new MNNTRTPlugin::PluginT);
    plu->outputs.resize(outputs.size());
    for (size_t i = 0; i < outputs.size(); ++i) {
        auto shape = outputs[i]->shape();
        plu->outputs[i].reset(new MNNTRTPlugin::ShapeT);
        plu->outputs[i]->dim   = shape;
        plu->outputs[i]->bytes = outputs[i]->getType().bytes();
        plu->outputs[i]->type  = outputs[i]->getType().code;
    }
    return plu;
}

TRTActivation::TRTActivation(Backend *b, const Op *op, const std::vector<Tensor *> &inputs,
                             const std::vector<Tensor *> &outputs)
    : MNN::TRTCommonExecution(b, op) {
    // Nothing to prepare here; the activation type is resolved in onEncode.
}

std::vector<ITensor *> TRTActivation::onEncode(const std::vector<ITensor *> &xOp) {
    IActivationLayer *activationLayer{nullptr};
    if (mOp->type() == OpType_ReLU) {
        float slope = 0.0f;
        if (nullptr != mOp->main_as_Relu()) {
            slope = mOp->main_as_Relu()->slope();
        }
        if (slope == 0.0f) {
            // Plain ReLU maps directly onto TensorRT's built-in activation layer.
            activationLayer = mTrtBackend->getNetwork()->addActivation(*(xOp[0]), nvinfer1::ActivationType::kRELU);
        } else {
            // Leaky ReLU (non-zero slope) is not handled by IActivationLayer here,
            // so it is lowered to MNN's PReLU plugin instead.
            auto plu = createPluginWithOutput(mOutputs);
            auto preluPlugin = (nvinfer1::IPluginExt *)MNNTRTCreatePlugion(mOp, plu.get());
            nvinfer1::IPluginLayer *plugin =
                mTrtBackend->getNetwork()->addPluginExt(&xOp[0], 1, *((nvinfer1::IPluginExt *)preluPlugin));
            if (plugin == nullptr) {
                MNN_PRINT("plugin == nullptr !!!\n");
            }
            mTrtBackend->pushReleaseLayer(preluPlugin);
            return {plugin->getOutput(0)};
        }
    } else if (mOp->type() == OpType_Sigmoid) {
        activationLayer = mTrtBackend->getNetwork()->addActivation(*(xOp[0]), nvinfer1::ActivationType::kSIGMOID);
    } else if (mOp->type() == OpType_ReLU6) {
        // ReLU6 is expressed as a clip activation bounded to [0, 6].
        activationLayer = mTrtBackend->getNetwork()->addActivation(*(xOp[0]), nvinfer1::ActivationType::kCLIP);
        activationLayer->setAlpha(0.0f);
        activationLayer->setBeta(6.0f);
    } else {
        MNN_PRINT("activation does not support this type: %d\n", mOp->type());
    }

    // activationLayer->setName(mOp->name()->str().c_str());
    return {activationLayer->getOutput(0)};
}

TRTCreatorRegister<TypedCreator<TRTActivation>> __relu_op(OpType_ReLU);
TRTCreatorRegister<TypedCreator<TRTActivation>> __sigmoid_op(OpType_Sigmoid);
TRTCreatorRegister<TypedCreator<TRTActivation>> __relu6_op(OpType_ReLU6);

} // namespace MNN