//
//  CPUConvolutionDepthwise.hpp
//  MNN
//
//  Created by MNN on 2018/07/20.
//  Copyright © 2018, Alibaba Group Holding Limited
//
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #ifndef CPUConvolutionDepthwise_hpp
 | 
					
						
							|  |  |  | #define CPUConvolutionDepthwise_hpp
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | #include "core/AutoStorage.h"
 | 
					
						
							|  |  |  | #include "backend/cpu/CPUConvolution.hpp"
 | 
					
						
							|  |  |  | #include "backend/cpu/compute/ConvolutionIntFactory.hpp"
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | namespace MNN { | 
					
						
							|  |  |  | class CPUConvolutionDepthwise : public Execution { | 
					
						
							|  |  |  | public: | 
					
						
							| 
									
										
										
										
											2019-06-17 20:10:35 +08:00
										 |  |  |     class BasicFloatExecution : public CPUConvolution { | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     public: | 
					
						
							| 
									
										
										
										
											2019-06-17 20:10:35 +08:00
										 |  |  |         BasicFloatExecution(const Convolution2DCommon *common, Backend *b) : CPUConvolution(common, b) { | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         virtual ~BasicFloatExecution() = default; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; | 
					
						
							|  |  |  |         virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     private: | 
					
						
							|  |  |  |         std::function<void(const float *, float *, int)> mExecutor; | 
					
						
							|  |  |  |         int mNumber = 1; | 
					
						
							|  |  |  |     }; | 
					
						
							| 
									
										
										
										
											2019-06-17 20:10:35 +08:00
										 |  |  |     class MultiInputFloatExecution : public BasicFloatExecution { | 
					
						
							|  |  |  |     public: | 
					
						
							|  |  |  |         MultiInputFloatExecution(const Convolution2DCommon *common, Backend *b) : BasicFloatExecution(common, b) { | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         virtual ~MultiInputFloatExecution() = default; | 
					
						
							|  |  |  |         virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; | 
					
						
							|  |  |  |         virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     private: | 
					
						
							|  |  |  |         std::unique_ptr<Tensor> mWeight; | 
					
						
							|  |  |  |         std::unique_ptr<Tensor> mBias; | 
					
						
							|  |  |  |         std::vector<Tensor *> mTempInputs; | 
					
						
							|  |  |  |     }; | 
					
						
							|  |  |  |     class FloatExecution : public CPUConvolution { | 
					
						
							|  |  |  |     public: | 
					
						
							|  |  |  |         FloatExecution(const Convolution2DCommon *common, Backend *b, const float *originWeight, | 
					
						
							|  |  |  |                        size_t originWeightSize, const float *bias, size_t biasSize); | 
					
						
							|  |  |  |         virtual ~FloatExecution(); | 
					
						
							|  |  |  |         virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, | 
					
						
							|  |  |  |                                     const std::vector<Tensor *> &outputs) override { | 
					
						
							|  |  |  |             return mOrigin->onExecute(mTempInputs, outputs); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override { | 
					
						
							|  |  |  |             mTempInputs = {inputs[0], mWeight.get(), mBias.get()}; | 
					
						
							|  |  |  |             return mOrigin->onResize(mTempInputs, outputs); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     private: | 
					
						
							|  |  |  |         std::unique_ptr<Tensor> mWeight; | 
					
						
							|  |  |  |         std::unique_ptr<Tensor> mBias; | 
					
						
							|  |  |  |         std::vector<Tensor *> mTempInputs; | 
					
						
							|  |  |  |         std::unique_ptr<BasicFloatExecution> mOrigin; | 
					
						
							|  |  |  |     }; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     class Int8Execution : public CPUConvolution { | 
					
						
							|  |  |  |     public: | 
					
						
							|  |  |  |         Int8Execution(const Convolution2DCommon *convOp, Backend *b, const ConvolutionIntFactory::Int8Common *common, | 
					
						
							|  |  |  |                       const float *bias, size_t biasSize); | 
					
						
							|  |  |  |         virtual ~Int8Execution() = default; | 
					
						
							|  |  |  |         virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; | 
					
						
							|  |  |  |         virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     private: | 
					
						
							|  |  |  |         AutoStorage<int8_t> mWeight; | 
					
						
							|  |  |  |         AutoStorage<float> mBias; | 
					
						
							|  |  |  |         AutoStorage<float> mAlpha; | 
					
						
							| 
									
										
										
										
											2019-07-11 13:56:52 +08:00
										 |  |  |         float mQuanScale[4]; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         Tensor mInputTempBuffer; | 
					
						
							|  |  |  |         const IDSTQuan *mQuan; | 
					
						
							| 
									
										
										
										
											2019-09-01 19:25:26 +08:00
										 |  |  |         std::function<void()> mRun; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     CPUConvolutionDepthwise(const Op *convOp, Backend *b); | 
					
						
							|  |  |  |     virtual ~CPUConvolutionDepthwise() = default; | 
					
						
							|  |  |  |     virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; | 
					
						
							|  |  |  |     virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | private: | 
					
						
							|  |  |  |     std::unique_ptr<Execution> mSubExecution; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | } // namespace MNN
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #endif /* CPUConvolutionDepthwise_hpp */
 |