mirror of https://github.com/alibaba/MNN.git
				
				
				
			[PATCH 31/78] [Metal:Speed] Optimized MetalPooling by moving operations into onResize
This commit is contained in:
		
							parent
							
								
									80963bd19a
								
							
						
					
					
						commit
						10fa9fc483
					
				|  | @ -121,6 +121,7 @@ typedef struct { | |||
|                 threads:(MTLSize)threads | ||||
|         threadsPerGroup:(MTLSize)threadsPerGroup | ||||
|               bandwidth:(MNN::MetalBandwidth)bandwidth; | ||||
| - (id<MTLComputePipelineState>)pipelineWithName:(NSString *)name; | ||||
| 
 | ||||
| #if MNN_METAL_DEBUG | ||||
| /**
 | ||||
|  | @ -147,6 +148,8 @@ typedef struct { | |||
|  * @brief print encoder | ||||
|  */ | ||||
| - (void)printEncoder:(id<MTLCommandEncoder>)encoder; | ||||
| 
 | ||||
| 
 | ||||
| #endif | ||||
| @end | ||||
| 
 | ||||
|  |  | |||
|  | @ -33,6 +33,9 @@ private: | |||
|     int mPadX; | ||||
|     int mPadY; | ||||
|     id<MTLBuffer> mConstBuffer; | ||||
|     MTLSize mGroup; | ||||
|     MTLSize mLocal; | ||||
|     id<MTLComputePipelineState> mPipeline; | ||||
| }; | ||||
| 
 | ||||
| } // namespace MNN
 | ||||
|  |  | |||
|  | @ -59,7 +59,10 @@ ErrorCode MetalPooling::onResize(const std::vector<Tensor *> &inputs, const std: | |||
|     ((int *)mConstBuffer.contents)[8]  = strideHeight; | ||||
|     ((int *)mConstBuffer.contents)[9]  = padWidth; | ||||
|     ((int *)mConstBuffer.contents)[10] = padHeight; | ||||
| 
 | ||||
|     auto ow = output->width(), oh = output->height(), slice = UP_DIV(output->channel(), 4) * output->batch(); | ||||
|     mLocal = MTLSizeMake(8, 8, 4); | ||||
|     mGroup = MTLSizeMake(UP_DIV(ow, 8), (NSUInteger)UP_DIV(oh, 8), (NSUInteger)UP_DIV(slice, 4)); | ||||
|     mPipeline = [context pipelineWithName:(mPoolType == PoolType_MAXPOOL) ? @"pooling_max" : @"pooling_avg"]; | ||||
|     return NO_ERROR; | ||||
| } | ||||
| 
 | ||||
|  | @ -67,17 +70,12 @@ ErrorCode MetalPooling::onExecute(const std::vector<Tensor *> &inputs, const std | |||
|     auto backend = static_cast<MetalBackend *>(this->backend()); | ||||
|     auto context = (__bridge MNNMetalContext *)backend->context(); | ||||
|     auto input = inputs[0], output = outputs[0]; | ||||
|     auto ow = output->width(), oh = output->height(), slice = UP_DIV(output->channel(), 4) * output->batch(); | ||||
| 
 | ||||
|     auto encoder   = [context encoder]; | ||||
|     auto bandwidth = [context load:(mPoolType == PoolType_MAXPOOL) ? @"pooling_max" : @"pooling_avg" encoder:encoder]; | ||||
|     bandwidth.zAxisProtected = YES; | ||||
|     [encoder setComputePipelineState:mPipeline]; | ||||
|     [encoder setBuffer:(__bridge id<MTLBuffer>)(void *)input->deviceId() offset:0 atIndex:0]; | ||||
|     [encoder setBuffer:(__bridge id<MTLBuffer>)(void *)output->deviceId() offset:0 atIndex:1]; | ||||
|     [encoder setBuffer:mConstBuffer offset:0 atIndex:2]; | ||||
|     [context dispatchEncoder:encoder | ||||
|                      threads:{ (NSUInteger) ow, (NSUInteger)oh, (NSUInteger)slice } | ||||
|                    bandwidth:bandwidth]; | ||||
|     [encoder dispatchThreadgroups:mGroup threadsPerThreadgroup:mLocal]; | ||||
|     [encoder endEncoding]; | ||||
|     MNN_PRINT_ENCODER(context, encoder); | ||||
|     return NO_ERROR; | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue