2020-11-05 16:41:56 +08:00
|
|
|
#include "TRTInterp.hpp"
|
|
|
|
#include <core/TensorUtils.hpp>
|
|
|
|
#include "TRTBackend.hpp"
|
2021-02-07 10:45:07 +08:00
|
|
|
#include "schema/current/MNNPlugin_generated.h"
|
2020-11-05 16:41:56 +08:00
|
|
|
|
|
|
|
using namespace std;
|
|
|
|
|
|
|
|
namespace MNN {
|
|
|
|
|
2021-02-07 10:45:07 +08:00
|
|
|
/**
 * Compute the input/output coordinate scale for one spatial axis of an Interp op.
 *
 * @param inputSize  extent of the axis in the input tensor
 * @param outputSize extent of the axis in the output tensor
 * @param isAlign    align_corners semantics: when true the scale is
 *                   (inputSize - 1) / (outputSize - 1), otherwise inputSize / outputSize
 * @return the scale factor, or 0 when the denominator would be zero or negative
 *         (e.g. align_corners with a single output sample) instead of inf/NaN
 */
static float resizeScale(int inputSize, int outputSize, bool isAlign) {
    const int corner      = isAlign ? 1 : 0;
    const int denominator = outputSize - corner;
    if (denominator <= 0) {
        // align_corners with outputSize == 1 (or a degenerate empty output) has no
        // meaningful stride. Returning 0 follows the convention of PyTorch's
        // interpolate (scale 0 when output_size <= 1 with align_corners) and avoids
        // handing inf/NaN to the plugin.
        return 0.0f;
    }
    return static_cast<float>(inputSize - corner) / static_cast<float>(denominator);
}
|
2020-11-05 16:41:56 +08:00
|
|
|
|
2021-02-07 10:45:07 +08:00
|
|
|
// Build the TensorRT Interp execution on top of TRTCommonExecution.
// Interp parameters and tensor shapes are only read later, in onEncode(), so the
// input/output tensor lists passed here are intentionally unused.
TRTInterp::TRTInterp(Backend *b, const Op *op, const std::vector<Tensor *> &inputs,
                     const std::vector<Tensor *> &outputs)
    : TRTCommonExecution(b, op) {
}
|
2020-11-05 16:41:56 +08:00
|
|
|
|
2021-02-07 10:45:07 +08:00
|
|
|
/**
 * Encode the MNN Interp op into the TensorRT network as an MNN plugin layer.
 *
 * The Interp parameters from mOp and the shapes of mInputs[0] / mOutputs[0] are
 * packed into an InterpInfo plugin parameter. The plugin is then created and
 * appended to the backend's network, with xOp[0] as its only input.
 *
 * @param xOp TensorRT tensors for the op's inputs; only xOp[0] is fed to the plugin.
 * @return a one-element vector holding the plugin layer's output tensor, or
 *         {nullptr} when there is no input or the plugin/layer could not be created.
 */
std::vector<ITensor *> TRTInterp::onEncode(const std::vector<ITensor *> &xOp) {
#ifdef TRT_LOG
    MNN_PRINT("\n\nTRTInterp in\n\n");
#endif
    if (xOp.empty()) {
        // Without an input there is nothing to feed addPluginExt(); indexing xOp[0]
        // would be undefined behavior.
        MNN_PRINT("Interp has no input tensor!\n");
        return {nullptr};
    }

    auto plu = createPluginWithOutput(mOutputs);

    auto input  = mInputs[0];
    auto output = mOutputs[0];

    const int inputChannel = input->channel();
    const int inputBatch   = input->batch();
    const int inputHeight  = input->height();
    const int inputWidth   = input->width();
    const int outputHeight = output->height();
    const int outputWidth  = output->width();

    auto interpParam        = mOp->main_as_Interp();
    const bool alignCorners = interpParam->alignCorners();
    // TODO: interpParam->halfPixelCenters() is not passed to the plugin, so that
    // flag is currently ignored by this encoder.
    const int resizeType = interpParam->resizeType();
    if (resizeType != 1 && resizeType != 2) {
        // Only resize types 1 and 2 pass this check. Other types are still encoded,
        // as before. NOTE(review): consider failing the encode instead.
        MNN_PRINT("Interp Type not support!\n");
    }

    plu->main.type = MNNTRTPlugin::Parameter_InterpInfo;
    // Assumed to be owned (and freed) by the flatbuffers object-API union — confirm.
    plu->main.value = new MNNTRTPlugin::InterpInfoT;
    auto interp     = plu->main.AsInterpInfo();

    interp->inputChannel  = inputChannel;
    interp->heightScale   = resizeScale(inputHeight, outputHeight, alignCorners);
    interp->widthScale    = resizeScale(inputWidth, outputWidth, alignCorners);
    interp->channelBlocks = inputChannel * inputBatch;
    interp->outputWidth   = outputWidth;
    interp->outputH_N     = outputHeight * inputBatch;
    interp->inputHeight   = inputHeight;
    interp->inputWidth    = inputWidth;
    interp->outputHeight  = outputHeight;

    auto interpPlugin = static_cast<nvinfer1::IPluginExt *>(MNNTRTCreatePlugion(mOp, plu.get()));
    if (interpPlugin == nullptr) {
        // Dereferencing a null plugin in addPluginExt() would be undefined behavior.
        MNN_PRINT("Interp plugin create failed !!!\n");
        return {nullptr};
    }
    // Register the plugin with the backend before the layer-creation check below,
    // so it is registered on the failure path as well (as the original code did).
    mTrtBackend->pushReleaseLayer(interpPlugin);

    nvinfer1::IPluginLayer *plugin =
        mTrtBackend->getNetwork()->addPluginExt(xOp.data(), 1, *interpPlugin);
    if (plugin == nullptr) {
        MNN_PRINT("Interp plugin == nullptr !!!\n");
        // Do not dereference the missing layer. Keep the one-element shape of the
        // success path so a caller reading element 0 stays in bounds.
        return {nullptr};
    }
    return {plugin->getOutput(0)};
}
|
|
|
|
|
2021-02-07 10:45:07 +08:00
|
|
|
// Global registrar binding OpType_Interp to TypedCreator<TRTInterp>; it is
// constructed during static initialization of this translation unit.
// NOTE(review): names with a leading double underscore are reserved in C++. The
// name is kept here because the object has external linkage.
TRTCreatorRegister<TypedCreator<TRTInterp>> __interp_op{OpType_Interp};
|
2020-11-05 16:41:56 +08:00
|
|
|
|
2021-02-07 10:45:07 +08:00
|
|
|
}
|