mirror of https://github.com/alibaba/MNN.git
				
				
				
			feat(convert): add ROIPoolingOnnx Convert(the onnx model file export from torchvision.ops.roi_pool)
This commit is contained in:
		
							parent
							
								
									19c2df11f5
								
							
						
					
					
						commit
						cca07fdf98
					
				| 
						 | 
				
			
			@ -0,0 +1,40 @@
 | 
			
		|||
//
 | 
			
		||||
//  ROIPoolingOnnx.cpp
 | 
			
		||||
//  MNNConverter
 | 
			
		||||
//
 | 
			
		||||
//  Created by MNN on 2021/10/27.
 | 
			
		||||
//  Copyright © 2018, Alibaba Group Holding Limited
 | 
			
		||||
//
 | 
			
		||||
 | 
			
		||||
#include "onnxOpConverter.hpp"
 | 
			
		||||
 | 
			
		||||
DECLARE_OP_CONVERTER(ROIPoolingOnnx);
 | 
			
		||||
 | 
			
		||||
MNN::OpType ROIPoolingOnnx::opType() { return MNN::OpType_ROIPooling; }
 | 
			
		||||
 | 
			
		||||
MNN::OpParameter ROIPoolingOnnx::type() { return MNN::OpParameter_RoiPooling; }
 | 
			
		||||
 | 
			
		||||
void ROIPoolingOnnx::run(MNN::OpT *dstOp, const onnx::NodeProto *onnxNode, OnnxScope *scope) {
 | 
			
		||||
    auto roiPool = new MNN::RoiPoolingT;
 | 
			
		||||
 | 
			
		||||
    const auto attrSize = onnxNode->attribute_size();
 | 
			
		||||
    for (int i = 0; i < attrSize; ++i) {
 | 
			
		||||
        const auto &attributeProto = onnxNode->attribute(i);
 | 
			
		||||
        const auto &attributeName  = attributeProto.name();
 | 
			
		||||
        if (attributeName == "output_size") {
 | 
			
		||||
            DCHECK(attributeProto.type() == ::onnx::AttributeProto_AttributeType_INTS) << "Node Attribute ERROR";
 | 
			
		||||
            DCHECK(attributeProto.ints_size() == 2) << "Node Attribute ERROR";
 | 
			
		||||
            roiPool->pooledHeight = attributeProto.ints(0);
 | 
			
		||||
            roiPool->pooledWidth  = attributeProto.ints(1);
 | 
			
		||||
        } else if (attributeName == "spatial_scale") {
 | 
			
		||||
            DCHECK(attributeProto.type() == ::onnx::AttributeProto_AttributeType_FLOAT) << "Node Attribute ERROR";
 | 
			
		||||
            roiPool->spatialScale = attributeProto.f();
 | 
			
		||||
        } else {
 | 
			
		||||
            DLOG(ERROR) << "TODO!";
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    dstOp->main.value = roiPool;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
REGISTER_CONVERTER(ROIPoolingOnnx, ROIPooling);
 | 
			
		||||
		Loading…
	
		Reference in New Issue