//
//  MetalConvolutionCommon.hpp
//  MNN
//
//  Created by MNN on 2019/02/25.
//  Copyright © 2018, Alibaba Group Holding Limited
//
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #ifndef MetalConvolutionCommon_hpp
 | 
					
						
							|  |  |  | #define MetalConvolutionCommon_hpp
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-12 16:07:38 +08:00
										 |  |  | #import "core/ConvolutionCommon.hpp"
 | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | #import "MetalBackend.hpp"
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | #import "MNNMetalContext.h"
 | 
					
						
							|  |  |  | #if MNN_METAL_ENABLED
 | 
					
						
							|  |  |  | namespace MNN { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class MetalConvolutionCommon : public Execution { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     MetalConvolutionCommon(Backend *backend, const MNN::Op *op); | 
					
						
							|  |  |  |     virtual ~MetalConvolutionCommon() = default; | 
					
						
							|  |  |  |     virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; | 
					
						
							|  |  |  |     virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | protected: | 
					
						
							|  |  |  |     void loadWeight(const MNN::Convolution2D *conv); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     virtual ErrorCode onFloat(const Tensor *input, const Tensor *output)     = 0; | 
					
						
							|  |  |  |     virtual id<MTLBuffer> weightForFloat(int group, int oc, int ic, int kh, int kw, const float *src); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | private: | 
					
						
							| 
									
										
										
										
											2020-03-12 16:07:38 +08:00
										 |  |  |     id<MTLBuffer> weightForConv(const Convolution2D *, ConvolutionCommon::Int8Common *, bool); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | protected: | 
					
						
							|  |  |  |     bool mDepthwise     = false; | 
					
						
							|  |  |  |     int mGroups         = 0; | 
					
						
							|  |  |  |     int mKernelX        = 0; | 
					
						
							|  |  |  |     int mKernelY        = 0; | 
					
						
							|  |  |  |     PadMode mPadMode    = PadMode_CAFFE; | 
					
						
							|  |  |  |     int mPadX           = 0; | 
					
						
							|  |  |  |     int mPadY           = 0; | 
					
						
							|  |  |  |     int mStrideX        = 0; | 
					
						
							|  |  |  |     int mStrideY        = 0; | 
					
						
							|  |  |  |     int mDilateX        = 0; | 
					
						
							|  |  |  |     int mDilateY        = 0; | 
					
						
							|  |  |  |     int mActivationType = 0; | 
					
						
							| 
									
										
										
										
											2021-09-18 15:52:30 +08:00
										 |  |  |     const MNN::Op *mOp  = nullptr; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     id<MTLBuffer> mWeight      = nil; | 
					
						
							|  |  |  |     id<MTLBuffer> mBias        = nil; | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |     id<MTLBuffer> mConstBuffer = nil; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } // namespace MNN
 | 
					
						
							|  |  |  | #endif /* MNN_METAL_ENABLED */
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #endif /* MetalConvolutionCommon_hpp */
 |