diff --git a/source/backend/metal/MNNMetalContext.h b/source/backend/metal/MNNMetalContext.h index 7417a02a..2dc55e27 100644 --- a/source/backend/metal/MNNMetalContext.h +++ b/source/backend/metal/MNNMetalContext.h @@ -121,6 +121,7 @@ typedef struct { threads:(MTLSize)threads threadsPerGroup:(MTLSize)threadsPerGroup bandwidth:(MNN::MetalBandwidth)bandwidth; +- (id)pipelineWithName:(NSString *)name; #if MNN_METAL_DEBUG /** @@ -147,6 +148,8 @@ typedef struct { * @brief print encoder */ - (void)printEncoder:(id)encoder; + + #endif @end diff --git a/source/backend/metal/MetalPooling.hpp b/source/backend/metal/MetalPooling.hpp index cc1697f6..04f009da 100644 --- a/source/backend/metal/MetalPooling.hpp +++ b/source/backend/metal/MetalPooling.hpp @@ -33,6 +33,9 @@ private: int mPadX; int mPadY; id mConstBuffer; + MTLSize mGroup; + MTLSize mLocal; + id mPipeline; }; } // namespace MNN diff --git a/source/backend/metal/MetalPooling.mm b/source/backend/metal/MetalPooling.mm index 47cba0f9..29dd08d2 100755 --- a/source/backend/metal/MetalPooling.mm +++ b/source/backend/metal/MetalPooling.mm @@ -59,7 +59,10 @@ ErrorCode MetalPooling::onResize(const std::vector &inputs, const std: ((int *)mConstBuffer.contents)[8] = strideHeight; ((int *)mConstBuffer.contents)[9] = padWidth; ((int *)mConstBuffer.contents)[10] = padHeight; - + auto ow = output->width(), oh = output->height(), slice = UP_DIV(output->channel(), 4) * output->batch(); + mLocal = MTLSizeMake(8, 8, 4); + mGroup = MTLSizeMake(UP_DIV(ow, 8), (NSUInteger)UP_DIV(oh, 8), (NSUInteger)UP_DIV(slice, 4)); + mPipeline = [context pipelineWithName:(mPoolType == PoolType_MAXPOOL) ? @"pooling_max" : @"pooling_avg"]; return NO_ERROR; } @@ -67,17 +70,12 @@ ErrorCode MetalPooling::onExecute(const std::vector &inputs, const std auto backend = static_cast(this->backend()); auto context = (__bridge MNNMetalContext *)backend->context(); auto input = inputs[0], output = outputs[0]; - auto ow = output->width(), oh = output->height(), slice = UP_DIV(output->channel(), 4) * output->batch(); - auto encoder = [context encoder]; - auto bandwidth = [context load:(mPoolType == PoolType_MAXPOOL) ? @"pooling_max" : @"pooling_avg" encoder:encoder]; - bandwidth.zAxisProtected = YES; + [encoder setComputePipelineState:mPipeline]; [encoder setBuffer:(__bridge id)(void *)input->deviceId() offset:0 atIndex:0]; [encoder setBuffer:(__bridge id)(void *)output->deviceId() offset:0 atIndex:1]; [encoder setBuffer:mConstBuffer offset:0 atIndex:2]; - [context dispatchEncoder:encoder - threads:{ (NSUInteger) ow, (NSUInteger)oh, (NSUInteger)slice } - bandwidth:bandwidth]; + [encoder dispatchThreadgroups:mGroup threadsPerThreadgroup:mLocal]; [encoder endEncoding]; MNN_PRINT_ENCODER(context, encoder); return NO_ERROR;