mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
	
	
		
			55 lines
		
	
	
		
			1.8 KiB
		
	
	
	
		
			C++
		
	
	
	
		
		
			
		
	
	
			55 lines
		
	
	
		
			1.8 KiB
		
	
	
	
		
			C++
		
	
	
	
|  | //  ShapePlugin.cpp
 | ||
|  | //  MNN
 | ||
|  | //
 | ||
|  | //  Created by MNN on 2020/04/05.
 | ||
|  | //  Copyright © 2018, Alibaba Group Holding Limited
 | ||
|  | //
 | ||
|  | 
 | ||
|  | #ifdef MNN_WITH_PLUGIN
 | ||
|  | 
 | ||
|  | #include "core/SizeComputer.hpp"
 | ||
|  | #include "core/TensorUtils.hpp"
 | ||
|  | 
 | ||
|  | #include "MNN/plugin/PluginShapeInference.hpp"
 | ||
|  | 
 | ||
|  | namespace MNN { | ||
|  | 
 | ||
|  | static std::shared_ptr<plugin::InferShapeKernel> getInferShapeKernel( // NOLINT
 | ||
|  |     const std::string& name) {                                        // NOLINT
 | ||
|  |     return std::shared_ptr<plugin::InferShapeKernel>(                 // NOLINT
 | ||
|  |         plugin::InferShapeKernelRegister::get(name)); | ||
|  | } | ||
|  | 
 | ||
|  | class PluginSizeComputer : public SizeComputer { | ||
|  | public: | ||
|  |     virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs, | ||
|  |                                const std::vector<Tensor*>& outputs) const override { | ||
|  |         // Plugin op should has inputs or outputs, or both.
 | ||
|  |         MNN_CHECK(inputs.size() > 0 || outputs.size() > 0, // NOLINT
 | ||
|  |                   "Plugin op should has inputs or outputs, or both of them."); | ||
|  | 
 | ||
|  |         const Plugin* plugin_param = op->main_as<Plugin>(); | ||
|  | 
 | ||
|  |         std::shared_ptr<plugin::InferShapeKernel> kernel = // NOLINT
 | ||
|  |             getInferShapeKernel(plugin_param->type()->str()); | ||
|  |         MNN_CHECK(nullptr != kernel.get(), // NOLINT
 | ||
|  |                   "Shape inference kernel has not been registered for plugin op."); | ||
|  | 
 | ||
|  |         plugin::InferShapeContext ctx(inputs, outputs); | ||
|  |         for (const Attribute* attr : *(plugin_param->attr())) { | ||
|  |             ctx.setAttr(attr->key()->str(), attr); | ||
|  |         } | ||
|  |         bool status = kernel->compute(&ctx); | ||
|  |         if (!status) { | ||
|  |             MNN_ERROR("Plugin op infer shape failed with false returned."); | ||
|  |         } | ||
|  |         return status; | ||
|  |     } | ||
|  | }; | ||
|  | 
 | ||
|  | REGISTER_SHAPE(PluginSizeComputer, OpType_Plugin); | ||
|  | 
 | ||
|  | } // namespace MNN
 | ||
|  | 
 | ||
|  | #endif // MNN_WITH_PLUGIN
 |