mirror of https://github.com/alibaba/MNN.git
//
//  MetalConvolutionGEMM.mm
//  MNN
//
//  Created by MNN on 2019/01/31.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#import "backend/metal/MetalConvolutionGEMM.hpp"
#import "core/Macro.h"
#import "backend/metal/MetalBackend.hpp"
#import "backend/metal/MetalConvolution.hpp"

#if MNN_METAL_ENABLED
namespace MNN {
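
// Convolution implemented as im2col + GEMM (matmul4x4) + col2im on the Metal
// backend. isValid() gates this path: kernels with a 1-sized dimension and
// grouped convolutions are rejected, layers with <= 64 output channels or an
// input volume (iw * ih * ic) of <= 16384 elements are too small to amortize
// the im2col overhead, an input-to-output volume ratio above 4 means the
// unfolded matrix would be dominated by redundant data, and both temporary
// matrices must each stay under 2 MiB (2 << 20 bytes).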
bool MetalConvolutionGEMM::isValid(const Convolution2D *conv, const Tensor *input) {
    auto common = conv->common();
    auto kx = common->kernelX(), ky = common->kernelY();
    if (kx == 1 || ky == 1 || common->group() != 1) {
        return false;
    }
    auto oc = common->outputCount();
    if (oc <= 64) {
        return false;
    }
    auto iw = input->width(), ih = input->height(), ic = input->channel();
    if (iw * ih * ic <= 16384) {
        return false;
    }
    auto sx = common->strideX(), ow = (iw - kx + 1) / sx;
    auto sy = common->strideY(), oh = (ih - ky + 1) / sy;
    if ((iw * ih * ic) / (ow * oh * oc) > 4) {
        return false;
    }

    auto unit = conv->quanParameter() != nullptr ? sizeof(float) : sizeof(metal_float);
    auto iz = UP_DIV(ic, 4), oz = UP_DIV(oc, 4), batch = input->batch();
    return UP_DIV(ow * oh * batch, 4) * kx * ky * iz * 16 * sizeof(metal_float) < (2 << 20) && // tmp input
           UP_DIV(ow * oh * batch, 4) * oz * 16 * unit < (2 << 20);                            // tmp output
}

MetalConvolutionGEMM::MetalConvolutionGEMM(Backend *backend, const Tensor *input, const MNN::Op *op)
    : MetalConvolutionCommon(backend, op) {
    loadWeight(op->main_as_Convolution2D());
}
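
// Resize pass: resolves padding, fills the shader constant buffers, sizes the
// temporary im2col/GEMM tensors, and chooses threadgroup sizes for the three
// compute pipelines (conv_im2col -> matmul4x4 -> conv_col2im) that onFloat
// encodes later.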
ErrorCode MetalConvolutionGEMM::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    // prepare
    auto backend = static_cast<MetalBackend *>(this->backend());
    auto context = (__bridge MNNMetalContext *)backend->context();
    auto input = inputs[0], output = outputs[0];
    auto iw = input->width(), ih = input->height(), iz = UP_DIV(input->channel(), 4);
    auto ow = output->width(), oh = output->height(), oz = UP_DIV(output->channel(), 4), ob = output->batch();

    // pad mode support: for SAME padding, the total padding is whatever keeps
    // the output size, i.e. (out - 1) * stride + dilated kernel extent - in,
    // split evenly between the two sides
    int padX = mPadX, padY = mPadY;
    if (mPadMode == PadMode_SAME) {
        int kernelWidthSize  = (mKernelX - 1) * mDilateX + 1;
        int kernelHeightSize = (mKernelY - 1) * mDilateY + 1;
        int pw = (ow - 1) * mStrideX + kernelWidthSize - iw;
        int ph = (oh - 1) * mStrideY + kernelHeightSize - ih;
        padX = pw / 2;
        padY = ph / 2;
    }
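
    // The order of the constants below must match the layout the conv_im2col
    // and conv_col2im shaders expect, since the array is memcpy'd verbatim
    // into the constant buffer they read; the matching struct is declared in
    // the corresponding .metal source.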
    // create const buffer
    int constants[] = {iw,
                       ih,
                       iw * ih,
                       iz,
                       ow,
                       oh,
                       ow * oh,
                       oz,
                       input->batch(),

                       mKernelX,
                       mKernelY,
                       mKernelX * mKernelY,
                       mStrideX,
                       mStrideY,
                       padX,
                       padY,
                       mDilateX,
                       mDilateY,
                       mActivationType};
    mConstBuffer.reset(sizeof(constants));
    ::memcpy(mConstBuffer.buffer().contents, constants, sizeof(constants));

    // create mat mul const buffer
    int shapes[] = {UP_DIV(ow * oh * ob, 4), oz, mKernelX * mKernelY * iz, 1};
    mShapeBuffer = [context newDeviceBuffer:sizeof(shapes) bytes:shapes access:CPUWriteOnly];

    // acquire space for source & dst
    int is = UP_DIV(ow * oh * ob, 4) * mKernelX * mKernelY * iz * 16 * sizeof(metal_float) / sizeof(uint8_t);
    int os = UP_DIV(ow * oh * ob, 4) * oz * 16 * sizeof(metal_float) / sizeof(uint8_t);
    mTempInput.reset(Tensor::createDevice<uint8_t>(std::vector<int>{is}));
    mTempOutput.reset(Tensor::createDevice<uint8_t>(std::vector<int>{os}));

    if (!backend->onAcquireBuffer(mTempInput.get(), Backend::DYNAMIC) ||
        !backend->onAcquireBuffer(mTempOutput.get(), Backend::DYNAMIC)) {
        return OUT_OF_MEMORY;
    }
    // release right away: the memory remains usable by this op at execution
    // time, and marking it released lets later ops reuse the same allocation
    backend->onReleaseBuffer(mTempInput.get(), Backend::DYNAMIC);
    backend->onReleaseBuffer(mTempOutput.get(), Backend::DYNAMIC);
    mPipelineGEMM   = [context pipelineWithName:@"matmul4x4"];
    mPipelineIm2Col = [context pipelineWithName:@"conv_im2col"];
    mPipelineCol2Im = [context pipelineWithName:@"conv_col2im"];
    NSUInteger gw = UP_DIV(output->width() * output->height() * output->batch(), 4);
    NSUInteger gh = UP_DIV(output->channel(), 4);
    mGemm   = [context computeBestGroupAndLocal:mPipelineGEMM threads:{gw, gh, 1}];
    mIm2Col = [context computeBestGroupAndLocal:mPipelineIm2Col threads:{(NSUInteger)ow, (NSUInteger)oh, (NSUInteger)iz * ob}];
    mCol2Im = [context computeBestGroupAndLocal:mPipelineCol2Im threads:{(NSUInteger)ow, (NSUInteger)oh, (NSUInteger)oz * ob}];
    return NO_ERROR;
}
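
// Execution encodes three back-to-back compute passes on the backend's shared
// command encoder: conv_im2col unfolds the padded input into a matrix,
// matmul4x4 multiplies it by the repacked weights, and conv_col2im folds the
// product back into the output tensor, consuming the bias buffer and the
// activation type carried in the constant buffer.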
ErrorCode MetalConvolutionGEMM::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    return onFloat(inputs[0], outputs[0]);
}

ErrorCode MetalConvolutionGEMM::onFloat(const Tensor *input, const Tensor *output) {
    auto backend = static_cast<MetalBackend *>(this->backend());
    auto encoder = backend->encoder();

    { // im2col
        [encoder setComputePipelineState:mPipelineIm2Col];
        [encoder setBuffer:(__bridge id<MTLBuffer>)(void *)input->deviceId() offset:0 atIndex:0];
        [encoder setBuffer:(__bridge id<MTLBuffer>)(void *)mTempInput->deviceId() offset:0 atIndex:1];
        [encoder setBuffer:mConstBuffer.buffer() offset:0 atIndex:2];
        [encoder dispatchThreadgroups:mIm2Col.first threadsPerThreadgroup:mIm2Col.second];
    }
    { // gemm
        [encoder setComputePipelineState:mPipelineGEMM];
        [encoder setBuffer:(__bridge id<MTLBuffer>)(void *)mTempInput->deviceId() offset:0 atIndex:0];
        [encoder setBuffer:(__bridge id<MTLBuffer>)(void *)mTempOutput->deviceId() offset:0 atIndex:1];
        [encoder setBuffer:mWeight offset:0 atIndex:2];
        [encoder setBuffer:mShapeBuffer offset:0 atIndex:3];
        [encoder dispatchThreadgroups:mGemm.first threadsPerThreadgroup:mGemm.second];
    }
    { // col2im
        [encoder setComputePipelineState:mPipelineCol2Im];
        [encoder setBuffer:(__bridge id<MTLBuffer>)(void *)mTempOutput->deviceId() offset:0 atIndex:0];
        [encoder setBuffer:(__bridge id<MTLBuffer>)(void *)output->deviceId() offset:0 atIndex:1];
        [encoder setBuffer:mBias offset:0 atIndex:2];
        [encoder setBuffer:mConstBuffer.buffer() offset:0 atIndex:3];
        [encoder dispatchThreadgroups:mCol2Im.first threadsPerThreadgroup:mCol2Im.second];
    }
    return NO_ERROR;
}
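
// Repacks convolution weights from the source layout [g][o][i][h][w] into the
// blocked layout [g][o/4][i/4][h][w][16] consumed by the matmul4x4 shader.
// Each 16-element block is a 4x4 tile indexed by (i % 4) * 4 + (o % 4); for
// example, the weight for output channel o = 5, input channel i = 6 lands in
// block (zo, zi) = (1, 1) at tile offset 2 * 4 + 1 = 9.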
template <typename FType, typename TType>
static id<MTLBuffer> weightInBlock(MNNMetalContext *context, int group, int oc, int ic, int kh, int kw,
                                   const FType *src) {
    auto oz     = UP_DIV(oc, 4);
    auto iz     = UP_DIV(ic, 4);
    auto buffer = [context newDeviceBuffer:oz * iz * kw * kh * 16 * sizeof(TType) access:CPUWriteOnly];
    auto dst    = (TType *)buffer.contents;

    for (int o = 0; o < oc; o++) {
        auto zo = o / 4, ro = o % 4;
        auto o_dst = dst + zo * iz * kh * kw * 16 + ro; // o/4 x 4
#pragma clang loop vectorize(enable)
        for (int i = 0; i < ic; i++) {
            auto zi = i / 4, ri = i % 4;
            auto i_dst = o_dst + zi * kh * kw * 16 + ri * 4; // i/4 x 4
#pragma clang loop vectorize(enable)
            for (int h = 0; h < kh; h++) {
#pragma clang loop vectorize(enable) unroll(enable)
                for (int w = 0; w < kw; w++) {
                    // to   [g][o/4][i/4][h][w][16]
                    // from [g][o][i][h][w]
                    i_dst[(h * kw + w) * 16] = *src++;
                }
            }
        }
    }
    return buffer;
}

id<MTLBuffer> MetalConvolutionGEMM::weightForFloat(int group, int oc, int ic, int kh, int kw, const float *src) {
    auto backend = static_cast<MetalBackend *>(this->backend());
    auto context = (__bridge MNNMetalContext *)backend->context();
    return weightInBlock<float, metal_float>(context, group, oc, ic, kh, kw, src);
}

} // namespace MNN
#endif /* MNN_METAL_ENABLED */