mirror of https://github.com/alibaba/MNN.git
feat(convert): add ROIPoolingOnnx Convert(the onnx model file export from torchvision.ops.roi_pool)
This commit is contained in:
parent
19c2df11f5
commit
cca07fdf98
|
|
@ -0,0 +1,40 @@
|
|||
//
|
||||
// ROIPoolingOnnx.cpp
|
||||
// MNNConverter
|
||||
//
|
||||
// Created by MNN on 2021/10/27.
|
||||
// Copyright © 2018, Alibaba Group Holding Limited
|
||||
//
|
||||
|
||||
#include "onnxOpConverter.hpp"
|
||||
|
||||
DECLARE_OP_CONVERTER(ROIPoolingOnnx);
|
||||
|
||||
MNN::OpType ROIPoolingOnnx::opType() { return MNN::OpType_ROIPooling; }
|
||||
|
||||
MNN::OpParameter ROIPoolingOnnx::type() { return MNN::OpParameter_RoiPooling; }
|
||||
|
||||
void ROIPoolingOnnx::run(MNN::OpT *dstOp, const onnx::NodeProto *onnxNode, OnnxScope *scope) {
|
||||
auto roiPool = new MNN::RoiPoolingT;
|
||||
|
||||
const auto attrSize = onnxNode->attribute_size();
|
||||
for (int i = 0; i < attrSize; ++i) {
|
||||
const auto &attributeProto = onnxNode->attribute(i);
|
||||
const auto &attributeName = attributeProto.name();
|
||||
if (attributeName == "output_size") {
|
||||
DCHECK(attributeProto.type() == ::onnx::AttributeProto_AttributeType_INTS) << "Node Attribute ERROR";
|
||||
DCHECK(attributeProto.ints_size() == 2) << "Node Attribute ERROR";
|
||||
roiPool->pooledHeight = attributeProto.ints(0);
|
||||
roiPool->pooledWidth = attributeProto.ints(1);
|
||||
} else if (attributeName == "spatial_scale") {
|
||||
DCHECK(attributeProto.type() == ::onnx::AttributeProto_AttributeType_FLOAT) << "Node Attribute ERROR";
|
||||
roiPool->spatialScale = attributeProto.f();
|
||||
} else {
|
||||
DLOG(ERROR) << "TODO!";
|
||||
}
|
||||
}
|
||||
|
||||
dstOp->main.value = roiPool;
|
||||
};
|
||||
|
||||
REGISTER_CONVERTER(ROIPoolingOnnx, ROIPooling);
|
||||
Loading…
Reference in New Issue