| 
									
										
										
										
											2020-04-14 21:43:02 +08:00
										 |  |  | //  ShapePlugin.cpp
 | 
					
						
							|  |  |  | //  MNN
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | //  Created by MNN on 2020/04/05.
 | 
					
						
							|  |  |  | //  Copyright © 2018, Alibaba Group Holding Limited
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | #include "shape/SizeComputer.hpp"
 | 
					
						
							| 
									
										
										
										
											2020-04-14 21:43:02 +08:00
										 |  |  | #include "core/TensorUtils.hpp"
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-12-15 14:12:35 +08:00
										 |  |  | #ifdef MNN_WITH_PLUGIN
 | 
					
						
							| 
									
										
										
										
											2020-04-14 21:43:02 +08:00
										 |  |  | #include "MNN/plugin/PluginShapeInference.hpp"
 | 
					
						
							| 
									
										
										
										
											2020-12-15 14:12:35 +08:00
										 |  |  | #endif  // MNN_WITH_PLUGIN
 | 
					
						
							| 
									
										
										
										
											2020-04-14 21:43:02 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | namespace MNN { | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-12-15 14:12:35 +08:00
										 |  |  | #ifdef MNN_WITH_PLUGIN
 | 
					
						
							| 
									
										
										
										
											2020-04-14 21:43:02 +08:00
										 |  |  | static std::shared_ptr<plugin::InferShapeKernel> getInferShapeKernel( // NOLINT
 | 
					
						
							|  |  |  |     const std::string& name) {                                        // NOLINT
 | 
					
						
							|  |  |  |     return std::shared_ptr<plugin::InferShapeKernel>(                 // NOLINT
 | 
					
						
							|  |  |  |         plugin::InferShapeKernelRegister::get(name)); | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2020-12-15 14:12:35 +08:00
										 |  |  | #endif  // MNN_WITH_PLUGIN
 | 
					
						
							| 
									
										
										
										
											2020-04-14 21:43:02 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | class PluginSizeComputer : public SizeComputer { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     virtual bool onComputeSize(const MNN::Op* op, const std::vector<Tensor*>& inputs, | 
					
						
							|  |  |  |                                const std::vector<Tensor*>& outputs) const override { | 
					
						
							|  |  |  |         // Plugin op should has inputs or outputs, or both.
 | 
					
						
							|  |  |  |         MNN_CHECK(inputs.size() > 0 || outputs.size() > 0, // NOLINT
 | 
					
						
							|  |  |  |                   "Plugin op should has inputs or outputs, or both of them."); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-12-15 14:12:35 +08:00
										 |  |  | #ifdef MNN_WITH_PLUGIN
 | 
					
						
							| 
									
										
										
										
											2020-04-14 21:43:02 +08:00
										 |  |  |         const Plugin* plugin_param = op->main_as<Plugin>(); | 
					
						
							|  |  |  |         std::shared_ptr<plugin::InferShapeKernel> kernel = // NOLINT
 | 
					
						
							|  |  |  |             getInferShapeKernel(plugin_param->type()->str()); | 
					
						
							|  |  |  |         MNN_CHECK(nullptr != kernel.get(), // NOLINT
 | 
					
						
							|  |  |  |                   "Shape inference kernel has not been registered for plugin op."); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         plugin::InferShapeContext ctx(inputs, outputs); | 
					
						
							|  |  |  |         for (const Attribute* attr : *(plugin_param->attr())) { | 
					
						
							|  |  |  |             ctx.setAttr(attr->key()->str(), attr); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         bool status = kernel->compute(&ctx); | 
					
						
							|  |  |  |         if (!status) { | 
					
						
							|  |  |  |             MNN_ERROR("Plugin op infer shape failed with false returned."); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return status; | 
					
						
							| 
									
										
										
										
											2020-12-15 14:12:35 +08:00
										 |  |  | #else
 | 
					
						
							|  |  |  |         MNN_ERROR("Plugin is not supported. Please recompile with `MNN_WITH_PLUGIN` enabled."); | 
					
						
							|  |  |  |         return false; | 
					
						
							|  |  |  | #endif  // MNN_WITH_PLUGIN
 | 
					
						
							| 
									
										
										
										
											2020-04-14 21:43:02 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | REGISTER_SHAPE(PluginSizeComputer, OpType_Plugin); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } // namespace MNN
 |