| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | //
 | 
					
						
							|  |  |  | //  CPUDeconvolution.hpp
 | 
					
						
							|  |  |  | //  MNN
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | //  Created by MNN on 2018/07/20.
 | 
					
						
							|  |  |  | //  Copyright © 2018, Alibaba Group Holding Limited
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #ifndef CPUDeconvolution_hpp
 | 
					
						
							|  |  |  | #define CPUDeconvolution_hpp
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 |  |  | #include "CPUConvolution.hpp"
 | 
					
						
							|  |  |  | #include "compute/StrassenMatmulComputor.hpp"
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | namespace MNN { | 
					
						
							| 
									
										
										
										
											2019-06-17 20:10:35 +08:00
										 |  |  | class CPUDeconvolutionBasic : public CPUConvolution { | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | public: | 
					
						
							| 
									
										
										
										
											2019-06-17 20:10:35 +08:00
										 |  |  |     CPUDeconvolutionBasic(const Tensor *input, const Op *convOp, Backend *b); | 
					
						
							|  |  |  |     virtual ~CPUDeconvolutionBasic() = default; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | protected: | 
					
						
							|  |  |  |     int mSrcCount; | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |     std::vector<float> mPostParameters; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-06-17 20:10:35 +08:00
										 |  |  | class CPUDeconvolutionCommon : public CPUDeconvolutionBasic { | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | public: | 
					
						
							| 
									
										
										
										
											2019-06-17 20:10:35 +08:00
										 |  |  |     CPUDeconvolutionCommon(const Tensor *input, const Op *convOp, Backend *b); | 
					
						
							|  |  |  |     virtual ~CPUDeconvolutionCommon(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | protected: | 
					
						
							|  |  |  |     std::shared_ptr<Tensor> mBias; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class CPUDeconvolutionOrigin : public CPUDeconvolutionBasic { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     CPUDeconvolutionOrigin(const Tensor *input, const Op *convOp, Backend *b) | 
					
						
							|  |  |  |         : CPUDeconvolutionBasic(input, convOp, b) { | 
					
						
							| 
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 |  |  |         // Do nothing
 | 
					
						
							| 
									
										
										
										
											2019-06-17 20:10:35 +08:00
										 |  |  |     } | 
					
						
							|  |  |  |     virtual ~CPUDeconvolutionOrigin() = default; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; | 
					
						
							|  |  |  |     virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | private: | 
					
						
							| 
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 |  |  |     std::shared_ptr<StrassenMatrixComputor> mMatMul; | 
					
						
							| 
									
										
										
										
											2020-07-04 01:21:30 +08:00
										 |  |  |     std::vector<std::pair<std::function<void(float*, int)>, int>> mPostFunctions; | 
					
						
							| 
									
										
										
										
											2019-06-17 20:10:35 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class CPUDeconvolution : public CPUDeconvolutionCommon { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     CPUDeconvolution(const Tensor *input, const Op *convOp, Backend *b); | 
					
						
							|  |  |  |     virtual ~CPUDeconvolution(); | 
					
						
							|  |  |  |     virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override { | 
					
						
							|  |  |  |         mOrigin->onExecute(mTempInputs, outputs); | 
					
						
							|  |  |  |         return NO_ERROR; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override { | 
					
						
							|  |  |  |         mTempInputs = {inputs[0], mWeight.get(), mBias.get()}; | 
					
						
							|  |  |  |         return mOrigin->onResize(mTempInputs, outputs); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | private: | 
					
						
							|  |  |  |     std::shared_ptr<Tensor> mWeight; | 
					
						
							|  |  |  |     std::vector<Tensor *> mTempInputs; | 
					
						
							|  |  |  |     std::shared_ptr<CPUDeconvolutionOrigin> mOrigin; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | } // namespace MNN
 | 
					
						
							|  |  |  | #endif /* CPUDeconvolution_hpp */
 |