//
//  DeconvSingleInputExecution.cpp
//  MNN
//
//  Created by MNN on 2022/03/04.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "DeconvSingleInputExecution.hpp"
#include "MultiInputDeconvExecution.hpp"
#include "ConvBaseKernel.cuh"
#include "DeconvBaseKernel.cuh"

namespace MNN {
namespace CUDA {
DeconvSingleInputExecution::Resource::Resource(Backend* bn, const MNN::Op* op) {
    mBackend = bn;
    auto runtime = static_cast<CUDABackend*>(bn)->getCUDARuntime();

    auto conv = op->main_as_Convolution2D();
    auto common = conv->common();
    mKernelInfo.kernelX = common->kernelX();
    mKernelInfo.kernelY = common->kernelY();
    mKernelInfo.groups = common->group();
    mKernelInfo.strideX = common->strideX();
    mKernelInfo.strideY = common->strideY();
    mKernelInfo.dilateX = common->dilateX();
    mKernelInfo.dilateY = common->dilateY();
    mKernelInfo.activationType = common->relu() ? 1 : (common->relu6() ? 2 : 0);

    // Copy weight host -> device
    const float* filterDataPtr = nullptr;
    int weightSize = 0;
    std::shared_ptr<ConvolutionCommon::Int8Common> quanCommon;
    ConvolutionCommon::getConvParameters(&quanCommon, conv, &filterDataPtr, &weightSize);
    mKernelInfo.kernelN = common->outputCount();
    mKernelInfo.kernelC = weightSize / mKernelInfo.kernelN / mKernelInfo.kernelX / mKernelInfo.kernelY;
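    // The flat weight blob holds kernelN * kernelC * kernelY * kernelX floats,
    // so kernelC is recovered above from the total element count.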

    CutlassGemmInfo param;
    int e = mKernelInfo.kernelN * mKernelInfo.kernelX * mKernelInfo.kernelY;
    int l = mKernelInfo.kernelC;
    param.elh[0] = e;
    param.elh[1] = l;
    param.elhPad[0] = UP_DIV(e, PACK_NUMBER) * PACK_NUMBER;
    param.elhPad[1] = UP_DIV(l, PACK_NUMBER) * PACK_NUMBER;
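    // e and l are the GEMM dimensions of the reordered weight: e = kernelN *
    // kernelY * kernelX rows and l = kernelC columns. Both are rounded up to a
    // multiple of PACK_NUMBER (UP_DIV(20, 8) * 8 == 24, for instance), which
    // keeps the leading dimensions aligned for the CUTLASS kernels.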

    auto tempCacheBuffer = static_cast<CUDABackend*>(bn)->getStaticBufferPool()->alloc(weightSize * sizeof(float));
    float* cacheWeight = (float*)((uint8_t*)tempCacheBuffer.first + tempCacheBuffer.second);
    runtime->memcpy(cacheWeight, filterDataPtr, weightSize * sizeof(float), MNNMemcpyHostToDevice);
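
    // Precision level 1 forces fp32, so the reordered weight stays 4 bytes per
    // element (int32_t device tensor below); every other level stores it as
    // fp16 (int16_t). callWeightReorder receives the matching fp32-or-not flag.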
    // Reorder weight
    if(static_cast<CUDABackend*>(bn)->getPrecision() == 1) {
        weightTensor.reset(Tensor::createDevice<int32_t>({param.elh[0] * param.elhPad[1]}));
    } else {
        weightTensor.reset(Tensor::createDevice<int16_t>({param.elh[0] * param.elhPad[1]}));
    }
    bn->onAcquireBuffer(weightTensor.get(), Backend::STATIC);
    mFilter = (void *)weightTensor.get()->buffer().device;

    callWeightReorder((const void *)cacheWeight, (void *)mFilter, mKernelInfo, param.elhPad[1], (int)(static_cast<CUDABackend*>(bn)->getPrecision() == 1), runtime);

    static_cast<CUDABackend*>(bn)->getStaticBufferPool()->free(tempCacheBuffer);

    // Copy Bias
    int biasSize = conv->bias()->size();
    if(static_cast<CUDABackend*>(bn)->getPrecision() == 2) {
        // Pad the bias so the vectorized float->half conversion cannot read or
        // write past the end of the buffer
        int biasPackSize = UP_DIV(biasSize, PACK_NUMBER) * PACK_NUMBER;

        biasTensor.reset(Tensor::createDevice<float>({biasPackSize}));
        bn->onAcquireBuffer(biasTensor.get(), Backend::STATIC);
        mBias = (void *)biasTensor.get()->buffer().device;

        auto tempBiasBuffer = static_cast<CUDABackend*>(bn)->getStaticBufferPool()->alloc(biasPackSize * sizeof(float));
        float* cacheBias = (float*)((uint8_t*)tempBiasBuffer.first + tempBiasBuffer.second);
        cuda_check(cudaMemcpy(cacheBias, conv->bias()->data(), conv->bias()->size()*sizeof(float), cudaMemcpyHostToDevice));

        callFloat2Half((const void*)cacheBias, (void*)mBias, biasPackSize, runtime);
        // Release the staging buffer once the converted bias is in place
        static_cast<CUDABackend*>(bn)->getStaticBufferPool()->free(tempBiasBuffer);
    } else {
        biasTensor.reset(Tensor::createDevice<float>({biasSize}));
        bn->onAcquireBuffer(biasTensor.get(), Backend::STATIC);
        mBias = (void *)biasTensor.get()->buffer().device;
        cuda_check(cudaMemcpy(mBias, conv->bias()->data(), conv->bias()->size()*sizeof(float), cudaMemcpyHostToDevice));
    }
}

DeconvSingleInputExecution::Resource::~Resource() {
    // Do nothing
}
DeconvSingleInputExecution::DeconvSingleInputExecution(Backend* backend, const MNN::Op* op, std::shared_ptr<Resource> res) : CutlassDeconvCommonExecution(backend) {
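    // Precision level from the backend: 1 forces fp32, 2 forces fp16, and 0 is
    // the mixed path (fp16 arithmetic over fp32 tensors).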
    mResource = res;
    mOp = op;
    mPrecisonLevel = static_cast<CUDABackend*>(backend)->getPrecision();
    mFp16Infer = (mPrecisonLevel == 2);
    mFp32Infer = (mPrecisonLevel == 1);
    mFp16Fp32MixInfer = (mPrecisonLevel == 0);
}
DeconvSingleInputExecution::~DeconvSingleInputExecution() {
    // Do nothing
}
bool DeconvSingleInputExecution::onClone(Backend* bn, const Op* op, Execution** dst) {
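    // Clones share mResource, so the reordered weight and bias are uploaded
    // once and reused; a null dst only asks whether cloning is supported.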
    if (!mValid) {
        return false;
    }
    if (nullptr == dst) {
        return true;
    }
    auto dstExe = new DeconvSingleInputExecution(bn, op, mResource);
    *dst = dstExe;
    return true;
}
ErrorCode DeconvSingleInputExecution::onResize(const std::vector<Tensor*> &inputs, const std::vector<Tensor*> &outputs) {
    auto runtime = static_cast<CUDABackend*>(backend())->getCUDARuntime();
    auto input = inputs[0], output = outputs[0];
    auto bytes = static_cast<CUDABackend*>(backend())->getBytes(inputs[0]);
    auto convCommon = mOp->main_as_Convolution2D()->common();

    // Col2Im Param
    auto pad = ConvolutionCommon::convolutionTransposePad(input, output, mOp->main_as_Convolution2D()->common());
    mCol2ImParamter.dilateX = convCommon->dilateX();
    mCol2ImParamter.dilateY = convCommon->dilateY();
    mCol2ImParamter.strideX = convCommon->strideX();
    mCol2ImParamter.strideY = convCommon->strideY();
    mCol2ImParamter.ic = input->channel();
    mCol2ImParamter.oc = output->channel();
    mCol2ImParamter.kernelX = convCommon->kernelX();
    mCol2ImParamter.kernelY = convCommon->kernelY();
    mCol2ImParamter.padX = pad.first;
    mCol2ImParamter.padY = pad.second;

    mCol2ImParamter.ih = input->height();
    mCol2ImParamter.iw = input->width();
    mCol2ImParamter.oh = output->height();
    mCol2ImParamter.ow = output->width();
    mCol2ImParamter.ob = output->batch();

    mActivationType = convCommon->relu() ? 1 : convCommon->relu6() ? 2 : 0;

    // Matmul Param
    int e = output->channel() * mCol2ImParamter.kernelX * mCol2ImParamter.kernelY;
    int l = input->channel();
    int h = input->height() * input->width() * output->batch();
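    // The deconvolution runs as one GEMM plus a col2im scatter: the (e x l)
    // weight [oc*kh*kw, ic] times the (l x h) input [ic, batch*ih*iw] yields
    // per-position filter columns that col2im then accumulates into the output.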

    mGemmInfo.elh[0] = e;
    mGemmInfo.elh[1] = l;
    mGemmInfo.elh[2] = h;
    mGemmInfo.elhPad[0] = UP_DIV(e, PACK_NUMBER) * PACK_NUMBER;
    mGemmInfo.elhPad[1] = UP_DIV(l, PACK_NUMBER) * PACK_NUMBER;
    mGemmInfo.elhPad[2] = UP_DIV(h, PACK_NUMBER) * PACK_NUMBER;

    // Alloc temp cuda memory
    auto pool = static_cast<CUDABackend*>(backend())->getBufferPool();
    MemChunk buffer_input, buffer_im2col;
    if(mFp16Fp32MixInfer) {
        buffer_input = pool->alloc(sizeof(__half) * mGemmInfo.elhPad[1] * mGemmInfo.elh[2]);
        mInputBuffer = (void*)((uint8_t*)buffer_input.first + buffer_input.second);
    } else {
        mInputBuffer = (void*)input->deviceId();
    }
    buffer_im2col = pool->alloc(bytes * mGemmInfo.elh[0] * mGemmInfo.elhPad[2]);
    mIm2ColBuffer = (__half*)((uint8_t*)buffer_im2col.first + buffer_im2col.second);
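    // mIm2ColBuffer receives the raw GEMM output (the "columns"). In mixed mode
    // the fp32 input is first converted into the fp16 staging buffer
    // mInputBuffer (see onExecute); otherwise the GEMM reads the input in place.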
    if(mFp16Fp32MixInfer) {
        pool->free(buffer_input);
    }
    pool->free(buffer_im2col);
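    // Returning the chunks to the pool during onResize is the usual MNN
    // pattern: the addresses remain valid for this op's onExecute, while the
    // backing memory can be shared with executions resized afterwards.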
    if(mFp16Fp32MixInfer || mFp32Infer) {
        mZeroTensor.reset(Tensor::createDevice<uint32_t>({mGemmInfo.elhPad[2]}));
    } else {
        mZeroTensor.reset(Tensor::createDevice<uint16_t>({mGemmInfo.elhPad[2]}));
    }
    static_cast<CUDABackend*>(backend())->onAcquireBuffer(mZeroTensor.get(), Backend::STATIC);

    mZeroPtr = (void *)mZeroTensor.get()->buffer().device;
    cuda_check(cudaMemset(mZeroPtr, 0, mGemmInfo.elhPad[2]*bytes));
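    // Zero-filled strip handed to the GEMM epilogue in place of a bias (the
    // real bias is applied later by callCol2ImFunc); elements are 4 bytes wide
    // on the fp32 paths and 2 bytes on the pure fp16 path.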

    mFilterAddr = mResource->mFilter;
    mBiasAddr = mResource->mBias;
    mBackendPtr = mResource->mBackend;

    // Dispatch to the matching CUTLASS kernel path
    if(mFp32Infer){
        return callCutlassGemmCudaCoreFloat32(inputs, outputs);
    }

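    // Devices below sm_75 take the CUDA-core fp16 path; Turing (sm_75) and
    // newer use the tensor-core kernels.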
    mGpuComputeCap = runtime->compute_capability();
    //MNN_PRINT("Gpu smArch is sm_%d\n", mGpuComputeCap);
    if(mGpuComputeCap < 75) {
        return callCutlassGemmCudaCoreFloat16(inputs, outputs);
    }
    return callCutlassGemmTensorCore(inputs, outputs);
}
ErrorCode DeconvSingleInputExecution::onExecute(const std::vector<Tensor*> &inputs, const std::vector<Tensor*> &outputs) {
    //MNN_PRINT("cuda convSingleInput onExecute in, inputsize:%d %d\n", (int)inputs.size(), workspace_size_);
    MNN_ASSERT(inputs.size() == 1);
    MNN_ASSERT(outputs.size() == 1);

    auto runtime = static_cast<CUDABackend*>(backend())->getCUDARuntime();
    const void *input_addr = (const void*)inputs[0]->deviceId();
    const void *filter_addr = mResource->mFilter;
    const void *bias_addr = mResource->mBias;
    void *output_addr = (void*)outputs[0]->deviceId();

    // Rearrange and pack the input: mixed mode converts the fp32 input to fp16
    if(mFp16Fp32MixInfer) {
        size_t maxCount = mGemmInfo.elhPad[1] * mGemmInfo.elh[2];
        callFloat2Half((const void*)input_addr, (void*)mInputBuffer, maxCount, runtime);
    }

    // Run cutlass gemm forward
    runCutlassGemmFunc();

    // Run Col2Im: scatter the GEMM columns back into the output image, adding bias
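    // Precision levels 0 (mixed) and 1 (fp32) share the same col2im variant,
    // so level 0 is folded into 1 below; pure fp16 (2) keeps its own variant.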
    int convert_flag = mPrecisonLevel;
    if(convert_flag == 0) {
        convert_flag = 1;
    }
    callCol2ImFunc((const void*)mIm2ColBuffer, (const void*)bias_addr, (void *)output_addr, &mCol2ImParamter, convert_flag, runtime);

    return NO_ERROR;
}
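// Creator: when the weights (and optional bias) arrive as extra input tensors,
// dispatch to MultiInputDeconvExecution; the constant-weight case shares a
// pre-reordered Resource via DeconvSingleInputExecution.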
class CUDADeconvolutionCreator : public CUDABackend::Creator {
public:
    virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                const MNN::Op* op, Backend* backend) const override {
        if (nullptr != op->main_as_Convolution2D()->quanParameter()) {
            auto quan = op->main_as_Convolution2D()->quanParameter();
            if (1 == quan->type() || 2 == quan->type()) {
                MNN_PRINT("cuda Deconv quant type 1 or 2 not supported\n");
                return nullptr;
            }
        }

        if(inputs.size() == 2 || inputs.size() == 3) {
            return new MultiInputDeconvExecution(op, backend);
        } else if(inputs.size() == 1) {
            std::shared_ptr<DeconvSingleInputExecution::Resource> resource(new DeconvSingleInputExecution::Resource(backend, op));
            return new DeconvSingleInputExecution(backend, op, resource);
        } else {
            MNN_PRINT("Deconv inputs size:%d not supported\n", (int)inputs.size());
            return nullptr;
        }
    }
};
CUDACreatorRegister<CUDADeconvolutionCreator> __DeConvExecution(OpType_Deconvolution);
}// namespace CUDA
}// namespace MNN