MNN/source/backend/cuda/execution/ConvSingleInputExecution.cu

307 lines
14 KiB
Plaintext
Raw Normal View History

2020-11-05 16:41:56 +08:00
//
//  ConvSingleInputExecution.cu
// MNN
//
// Created by MNN on 2020/08/22.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "ConvSingleInputExecution.hpp"
2022-02-18 11:30:27 +08:00
#include "Raster.cuh"
#include "MNNCUDADefine.hpp"
#include "MNNCUDAFunction.cuh"
2020-11-05 16:41:56 +08:00
2022-02-18 11:30:27 +08:00
// 16 / sizeof(int4)
2020-11-05 16:41:56 +08:00
namespace MNN {
namespace CUDA {
2022-02-18 11:30:27 +08:00
// Repacks dense FP32 convolution weights into a zero-padded, tiled FP16 buffer.
//
// Source (B):       indexed as B[(sH * ic + iz) * (kw * kh) + ik], i.e. [oc][ic][kh * kw].
// Destination (BP): flat index = ((hC * lDiv + lC) * ocPack + hR) * MATMULPACK + lR, where
//     sH = hC * ocPack + hR        is the output channel, and
//     sL = lC * MATMULPACK + lR    is the reduce index, itself laid out as
//                                  (icBlock * kernelCount + ik) * PACK_NUMBER + icInBlock.
// Elements whose output channel (sH >= oc) or input channel (iz >= ic) fall in the padding
// are written as zero.
//
// Launch: 1-D grid/block of any size; the grid-stride loop covers all hAlign * lAlign elements.
// Index math is done in size_t so large weight tensors cannot overflow 32-bit int.
__global__ void KernelReorder(const float* B, half* BP, int kw, int kh, int ic, int oc, int ocPack) {
    const int icC4        = UP_DIV(ic, PACK_NUMBER);
    const int kernelCount = kw * kh;
    const int l           = icC4 * kernelCount * PACK_NUMBER;
    const int lDiv        = UP_DIV(l, MATMULPACK);
    const int lAlign      = lDiv * MATMULPACK;
    const int hAlign      = UP_DIV(oc, ocPack) * ocPack;
    const size_t maxCount = (size_t)hAlign * (size_t)lAlign;
    const size_t stride   = (size_t)blockDim.x * (size_t)gridDim.x;

    for (size_t indexO = (size_t)blockIdx.x * blockDim.x + threadIdx.x; indexO < maxCount; indexO += stride) {
        // Decompose the flat destination index into its tile coordinates.
        const int lR      = (int)(indexO % MATMULPACK);
        const size_t tmp  = indexO / MATMULPACK;
        const int hR      = (int)(tmp % ocPack);
        const size_t tmp2 = tmp / ocPack;
        const int lC      = (int)(tmp2 % lDiv);
        const int hC      = (int)(tmp2 / lDiv);

        half* dst    = BP + indexO;
        const int sH = hC * ocPack + hR;
        const int sL = lC * MATMULPACK + lR;
        if (sH >= oc) {
            // Padding rows beyond the real output-channel count.
            *dst = 0.0f;
            continue;
        }

        // Split the reduce index into (input-channel block, kernel position, lane in block).
        const int sLR = sL % PACK_NUMBER;
        const int sLC = sL / PACK_NUMBER;
        const int iLC = sLC / kernelCount;
        const int ik  = sLC % kernelCount;
        const int iz  = iLC * PACK_NUMBER + sLR;
        if (iz >= ic) {
            // Padding lanes beyond the real input-channel count.
            *dst = 0.0f;
            continue;
        }

        const float* src = B + (size_t)sH * kernelCount * ic + ik + (size_t)iz * kernelCount;
        *dst = *src; // implicit FP32 -> FP16 conversion
    }
}
// Builds the shared, immutable resources of a convolution: kernel metadata, the reordered
// FP16 weight buffer (mFilter) and the device bias buffer (mBias).
//
// bn: backend that owns the device memory; it is cast to CUDABackend.
// op: the Convolution2D op that supplies the hyper-parameters, weights and bias.
//
// All temporary device storage is taken from the backend's static buffer pool and returned
// to it before the constructor ends.
ConvSingleInputExecution::Resource::Resource(Backend* bn, const MNN::Op* op) {
    mBackend         = bn;
    auto cudaBackend = static_cast<CUDABackend*>(bn);
    auto runtime     = cudaBackend->getCUDARuntime();
    auto staticPool  = cudaBackend->getStaticBufferPool();

    auto conv   = op->main_as_Convolution2D();
    auto common = conv->common();
    mKernelInfo.kernelX        = common->kernelX();
    mKernelInfo.kernelY        = common->kernelY();
    mKernelInfo.groups         = common->group();
    mKernelInfo.strideX        = common->strideX();
    mKernelInfo.strideY        = common->strideY();
    mKernelInfo.dilateX        = common->dilateX();
    mKernelInfo.dilateY        = common->dilateY();
    // 1: relu, 2: relu6, 0: no activation.
    mKernelInfo.activationType = common->relu() ? 1 : (common->relu6() ? 2 : 0);

    // weight host->device
    const float* filterDataPtr = nullptr;
    int weightSize             = 0;
    std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
    ConvolutionCommon::getConvParameters(&quanCommon, conv, &filterDataPtr, &weightSize);
    mKernelInfo.kernelN = common->outputCount();
    // Input-channel count is derived from the total weight size.
    mKernelInfo.kernelC = weightSize / mKernelInfo.kernelN / mKernelInfo.kernelX / mKernelInfo.kernelY;

    int icDiv = UP_DIV(mKernelInfo.kernelC, PACK_NUMBER);
    MatMulParam param;
    int e = 0;
    int l = mKernelInfo.kernelX * mKernelInfo.kernelY * icDiv * MATMULPACK;
    int h = mKernelInfo.kernelN;
    param.elh[0]     = e;
    param.elh[1]     = l;
    param.elh[2]     = h;
    param.elhPack[0] = UP_DIV(e, MATMULPACK);
    param.elhPack[1] = UP_DIV(l, MATMULPACK);
    param.elhPack[2] = UP_DIV(h, MATMULPACK);
    param.bStride[0] = 0;
    param.bStride[1] = 1;
    param.bStride[2] = l;

    // Scratch device storage for the raster-blit descriptor and its offset table.
    FuseRegion reg;
    const int maxOffsetNumber = 8;
    std::vector<int> offset(maxOffsetNumber);
    auto regionStorage    = staticPool->alloc(sizeof(FuseRegion));
    auto offsetGpuStorage = staticPool->alloc(sizeof(int) * maxOffsetNumber);
    auto regionGpu        = (FuseRegion*)((uint8_t*)regionStorage.first + regionStorage.second);
    auto offsetGpu        = (uint8_t*)offsetGpuStorage.first + offsetGpuStorage.second;

    // Reorder weight
    {
        auto tempCacheBuffer = staticPool->alloc(weightSize * sizeof(float));
        float* cacheWeight   = (float*)((uint8_t*)tempCacheBuffer.first + tempCacheBuffer.second);
        runtime->memcpy(cacheWeight, filterDataPtr, weightSize * sizeof(float), MNNMemcpyHostToDevice);

        // Stored as int16_t elements; the buffer is reinterpreted as half below.
        weightTensor.reset(Tensor::createDevice<int16_t>({param.elhPack[1] * param.elhPack[2] * (MATMULPACK * MATMULPACK)}));
        bn->onAcquireBuffer(weightTensor.get(), Backend::STATIC);
        mFilter = (void *)weightTensor.get()->buffer().device;

        auto& prop        = runtime->prop();
        int cores         = prop.multiProcessorCount;
        int threadNumbers = prop.maxThreadsPerBlock;

        // An even number of output-channel packs allows the 32-wide packing; otherwise fall
        // back to MATMULPACK. mUsePack is set on both paths so it never depends on whatever
        // default the member happens to have.
        const bool usePack = (param.elhPack[2] % 2 == 0);
        const int ocPack   = usePack ? 32 : MATMULPACK;
        mUsePack           = usePack;
        KernelReorder<<<cores, threadNumbers>>>((float*)cacheWeight, (half*)mFilter,
            mKernelInfo.kernelX, mKernelInfo.kernelY, mKernelInfo.kernelC, mKernelInfo.kernelN, ocPack);
        // Kernel launches do not return a status; surface launch-configuration errors here.
        cuda_check(cudaGetLastError());

        staticPool->free(tempCacheBuffer);
    }

    // Copy Bias
    int biasSize  = conv->bias()->size();
    int alignSize = UP_DIV(biasSize, PACK_NUMBER) * PACK_NUMBER;
    // The blit below writes reg.size[2] == alignSize elements into mBias, so allocate the
    // padded length; a biasSize-element buffer could be overrun when biasSize is not a
    // multiple of PACK_NUMBER.
    // NOTE(review): whether the backend already pads 1-D tensors is not visible here — confirm.
    biasTensor.reset(Tensor::createDevice<float>({alignSize}));
    bn->onAcquireBuffer(biasTensor.get(), Backend::STATIC);

    auto tempBiasStorage = staticPool->alloc(biasSize * sizeof(float));
    auto biasTemp        = (float*)((uint8_t*)tempBiasStorage.first + tempBiasStorage.second);
    cuda_check(cudaMemcpy(biasTemp, conv->bias()->data(), biasSize * sizeof(float), cudaMemcpyHostToDevice));

    // FP32 -> FP16
    mBias = (void *)biasTensor.get()->buffer().device;

    // Single 1 x 1 x alignSize region with unit stride on both sides.
    reg.size[0]      = 1;
    reg.size[1]      = 1;
    reg.size[2]      = alignSize;
    reg.srcStride[0] = 0;
    reg.srcStride[1] = 0;
    reg.srcStride[2] = 1;
    reg.dstStride[0] = 0;
    reg.dstStride[1] = 0;
    reg.dstStride[2] = 1;
    // offset[0..3] carry the source extent (1, 1, biasSize) plus a trailing 0;
    // offset[4..7] carry the destination extent (1, 1, alignSize) plus a trailing 0.
    // NOTE(review): exact meaning of each slot is defined by the FuseRasterBlit* helpers — confirm.
    offset[0] = 1;
    offset[1] = 1;
    offset[2] = biasSize;
    offset[3] = 0;
    offset[4] = 1;
    offset[5] = 1;
    offset[6] = reg.size[2];
    offset[7] = 0;
    reg.fuseNumber = 1;
    runtime->memcpy(regionGpu, &reg, sizeof(FuseRegion), MNNMemcpyHostToDevice, true);
    runtime->memcpy(offsetGpu, offset.data(), maxOffsetNumber * sizeof(int), MNNMemcpyHostToDevice, true);

    if (cudaBackend->useFp16()) {
        FuseRasterBlitFloatToHalf((uint8_t*)mBias, (uint8_t*)biasTemp, regionGpu, offsetGpu, runtime);
    } else {
        FuseRasterBlitCommon((uint8_t*)mBias, (uint8_t*)biasTemp, regionGpu, offsetGpu, runtime, 4);
    }

    staticPool->free(regionStorage);
    staticPool->free(offsetGpuStorage);
    staticPool->free(tempBiasStorage);
}
// Intentionally empty: this destructor performs no explicit cleanup.
ConvSingleInputExecution::Resource::~Resource() {
}
// Creates an execution that shares the given Resource (weights/bias) and reserves
// device-side storage for the MatMul and Im2Col parameter blocks from the static pool.
// The storage is filled in onResize and released in the destructor.
ConvSingleInputExecution::ConvSingleInputExecution(Backend* backend, const MNN::Op* op, std::shared_ptr<Resource> res) : Execution(backend), mOp(op) {
    mResource = res;
    // The previously fetched CUDA runtime handle was never used here and has been dropped.
    auto staticPool = static_cast<CUDABackend*>(backend)->getStaticBufferPool();
    mGpuMatMulParam = staticPool->alloc(sizeof(MatMulParam));
    mGpuIm2ColParam = staticPool->alloc(sizeof(ConvolutionCommon::Im2ColParameter));
}
// Returns the two parameter blocks allocated in the constructor to the static buffer pool.
ConvSingleInputExecution::~ConvSingleInputExecution() {
    auto cudaBackend = static_cast<CUDABackend*>(backend());
    auto pool        = cudaBackend->getStaticBufferPool();
    pool->free(mGpuMatMulParam);
    pool->free(mGpuIm2ColParam);
}
// Clones this execution onto backend `bn`, sharing the same Resource.
// Returns false when this execution is not valid. When `dst` is null the call only
// reports that cloning is possible and creates nothing.
bool ConvSingleInputExecution::onClone(Backend* bn, const Op* op, Execution** dst) {
    if (mValid) {
        if (nullptr != dst) {
            *dst = new ConvSingleInputExecution(bn, op, mResource);
        }
        return true;
    }
    return false;
}
2022-01-04 10:50:40 +08:00
2020-11-05 16:41:56 +08:00
// Recomputes the Im2Col and MatMul parameters for the current input/output shapes,
// uploads both parameter blocks to the device and reserves the FP16 Im2Col buffer.
ErrorCode ConvSingleInputExecution::onResize(const std::vector<Tensor*> &inputs, const std::vector<Tensor*> &outputs) {
    auto cudaBackend = static_cast<CUDABackend*>(backend());
    auto runtime     = cudaBackend->getCUDARuntime();
    auto src         = inputs[0];
    auto dst         = outputs[0];
    const int pack   = PACK_NUMBER;

    auto convCommon = mOp->main_as_Convolution2D()->common();
    auto pads       = ConvolutionCommon::convolutionPadFull(src, dst, convCommon);
    int ic          = src->channel();
    int icDiv       = UP_DIV(ic, PACK_NUMBER);

    // ---- Im2Col parameters ----
    mIm2ColParamter.dilateX   = convCommon->dilateX();
    mIm2ColParamter.dilateY   = convCommon->dilateY();
    mIm2ColParamter.strideX   = convCommon->strideX();
    mIm2ColParamter.strideY   = convCommon->strideY();
    mIm2ColParamter.icDiv4    = icDiv;
    mIm2ColParamter.kernelX   = convCommon->kernelX();
    mIm2ColParamter.kernelY   = convCommon->kernelY();
    mIm2ColParamter.padX      = std::get<0>(pads);
    mIm2ColParamter.padY      = std::get<1>(pads);
    mIm2ColParamter.ih        = src->height();
    mIm2ColParamter.iw        = src->width();
    mIm2ColParamter.oh        = dst->height();
    mIm2ColParamter.ow        = dst->width();
    mIm2ColParamter.srcZStep  = src->height() * src->width() * pack * src->batch();
    mIm2ColParamter.srcYStep  = src->width() * pack;
    mIm2ColParamter.packCUnit = pack;

    auto im2ColParamGpu = (uint8_t*)mGpuIm2ColParam.first + mGpuIm2ColParam.second;
    runtime->memcpy(im2ColParamGpu, &mIm2ColParamter, sizeof(ConvolutionCommon::Im2ColParameter), MNNMemcpyHostToDevice);

    // ---- MatMul parameters: e x l (im2col matrix) times l x h (weights) ----
    const int e = dst->height() * dst->width() * dst->batch();
    const int l = icDiv * mIm2ColParamter.kernelX * mIm2ColParamter.kernelY * MATMULPACK;
    const int h = dst->channel();
    mMatMulParam.elh[0]     = e;
    mMatMulParam.elh[1]     = l;
    mMatMulParam.elh[2]     = h;
    mMatMulParam.elhPack[0] = UP_DIV(e, MATMULPACK);
    mMatMulParam.elhPack[1] = UP_DIV(l, MATMULPACK);
    mMatMulParam.elhPack[2] = UP_DIV(h, MATMULPACK);
    mMatMulParam.cStride[0] = mIm2ColParamter.ow * mIm2ColParamter.oh * h;
    mMatMulParam.cStride[1] = 1;
    mMatMulParam.cStride[2] = mIm2ColParamter.ow * mIm2ColParamter.oh;

    // Output clamp range: unbounded by default, [0, +max] for relu, [0, 6] for relu6.
    float lowerBound = -FLT_MAX;
    float upperBound = FLT_MAX;
    if (convCommon->relu()) {
        lowerBound = 0.0f;
    }
    if (convCommon->relu6()) {
        lowerBound = 0.0f;
        upperBound = 6.0f;
    }
    mMatMulParam.minValue = lowerBound;
    mMatMulParam.maxValue = upperBound;

    auto matMulParamGpu = (uint8_t*)mGpuMatMulParam.first + mGpuMatMulParam.second;
    runtime->memcpy(matMulParamGpu, &mMatMulParam, sizeof(MatMulParam), MNNMemcpyHostToDevice);

    // ---- Im2Col scratch buffer (FP16), sized to the packed e x l extent ----
    const size_t im2ColBytes = (size_t)sizeof(__half) * (size_t)mMatMulParam.elhPack[0] * (size_t)mMatMulParam.elhPack[1] * (size_t)MATMULPACK * (size_t)MATMULPACK;
    auto pool                = cudaBackend->getBufferPool();
    auto buffer              = pool->alloc(im2ColBytes);
    mIm2ColBuffer            = (__half*)((uint8_t*)buffer.first + buffer.second);
    // The chunk is handed back to the pool right away while the pointer is kept for onExecute.
    pool->free(buffer);

    return NO_ERROR;
}
// Runs the convolution: Im2Col of the input into mIm2ColBuffer, followed by a packed GEMM
// against the reordered weights with the bias applied. Expects exactly one input and one output.
ErrorCode ConvSingleInputExecution::onExecute(const std::vector<Tensor*> &inputs, const std::vector<Tensor*> &outputs) {
    MNN_ASSERT(inputs.size() == 1);
    MNN_ASSERT(outputs.size() == 1);

    auto bn          = backend();
    auto cudaBackend = static_cast<CUDABackend*>(bn);
    auto runtime     = cudaBackend->getCUDARuntime();
    auto bytes       = cudaBackend->getBytes(inputs[0]);

    // Device addresses of the operands.
    const void *input_addr  = (const void*)inputs[0]->deviceId();
    const void *filter_addr = mResource->mFilter;
    const void *bias_addr   = mResource->mBias;
    void *output_addr       = (void*)outputs[0]->deviceId();

    // Device-side copies of the parameter blocks uploaded in onResize.
    auto gpuIm2Col = (const ConvolutionCommon::Im2ColParameter*)((uint8_t*)mGpuIm2ColParam.first + mGpuIm2ColParam.second);
    auto gpuMatMul = (const MatMulParam*)((uint8_t*)mGpuMatMulParam.first + mGpuMatMulParam.second);

    // Im2Col func
    Im2ColMain(runtime, &mMatMulParam, gpuMatMul, &mIm2ColParamter, gpuIm2Col, (const float*)input_addr, mIm2ColBuffer, bytes);

    // Pick the GEMM variant that matches how the weights were packed.
    if (!mResource->mUsePack) {
        GemmPackedFullMain(runtime, &mMatMulParam, gpuMatMul, (float*)output_addr, (const __half*)mIm2ColBuffer, (const __half*)filter_addr, (const half*)bias_addr, bytes);
    } else {
        GemmPacked16x32(runtime, &mMatMulParam, gpuMatMul, (float*)output_addr, (const __half*)mIm2ColBuffer, (const __half*)filter_addr, (const half*)bias_addr, bytes);
    }
    return NO_ERROR;
}
// Creator for ConvSingleInputExecution instances.
class CUDAConvolutionCreator : public CUDABackend::Creator {
public:
    // Returns a new execution with freshly built resources, or nullptr when the op carries
    // quantization parameters of type 1 or 2 together with scaleInt (IDST-int8).
    virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                const MNN::Op* op, Backend* backend) const override {
        auto quan = op->main_as_Convolution2D()->quanParameter();
        const bool isIdstInt8 = nullptr != quan &&
                                (1 == quan->type() || 2 == quan->type()) &&
                                quan->has_scaleInt();
        if (isIdstInt8) {
            // Don't support IDST-int8 because of error
            return nullptr;
        }
        std::shared_ptr<ConvSingleInputExecution::Resource> resource(new ConvSingleInputExecution::Resource(backend, op));
        return new ConvSingleInputExecution(backend, op, resource);
    }
};
// Namespace-scope registrar object constructed with OpType_Convolution; presumably its
// constructor registers CUDAConvolutionCreator for that op type (CUDACreatorRegister is
// defined elsewhere — confirm).
CUDACreatorRegister<CUDAConvolutionCreator> __ConvExecution(OpType_Convolution);
}// namespace CUDA
}// namespace MNN