feat(convert): add ROIPoolingOnnx Convert(the onnx model file export from torchvision.ops.roi_pool)

This commit is contained in:
insta360 2021-10-27 15:06:45 +08:00 committed by wuhao
parent 19c2df11f5
commit cca07fdf98
1 changed files with 40 additions and 0 deletions

View File

@ -0,0 +1,40 @@
//
// ROIPoolingOnnx.cpp
// MNNConverter
//
// Created by MNN on 2021/10/27.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "onnxOpConverter.hpp"
DECLARE_OP_CONVERTER(ROIPoolingOnnx);
MNN::OpType ROIPoolingOnnx::opType() { return MNN::OpType_ROIPooling; }
MNN::OpParameter ROIPoolingOnnx::type() { return MNN::OpParameter_RoiPooling; }
void ROIPoolingOnnx::run(MNN::OpT *dstOp, const onnx::NodeProto *onnxNode, OnnxScope *scope) {
auto roiPool = new MNN::RoiPoolingT;
const auto attrSize = onnxNode->attribute_size();
for (int i = 0; i < attrSize; ++i) {
const auto &attributeProto = onnxNode->attribute(i);
const auto &attributeName = attributeProto.name();
if (attributeName == "output_size") {
DCHECK(attributeProto.type() == ::onnx::AttributeProto_AttributeType_INTS) << "Node Attribute ERROR";
DCHECK(attributeProto.ints_size() == 2) << "Node Attribute ERROR";
roiPool->pooledHeight = attributeProto.ints(0);
roiPool->pooledWidth = attributeProto.ints(1);
} else if (attributeName == "spatial_scale") {
DCHECK(attributeProto.type() == ::onnx::AttributeProto_AttributeType_FLOAT) << "Node Attribute ERROR";
roiPool->spatialScale = attributeProto.f();
} else {
DLOG(ERROR) << "TODO!";
}
}
dstOp->main.value = roiPool;
};
REGISTER_CONVERTER(ROIPoolingOnnx, ROIPooling);