//
//  DeconvBaseKernel.cu
//  MNN
//
//  Created by MNN on 2023/04/24.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "DeconvBaseKernel.cuh"
#include "MNNCUDAFunction.cuh"
#include "MNNCUDADefine.hpp"

namespace MNN {
namespace CUDA {
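
// Weight reorder for deconvolution: the raw weights are laid out as
// [Ci, Co, KhKw] and are rewritten as [KhKw, Co, Cip], where Cip is the
// input-channel count rounded up to icPack. Slots in the padded channel
// range are zero-filled so the padding contributes nothing to the
// subsequent matrix multiply.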
template<typename T0, typename T1>
__global__ void DeconvKernelReorder(const T0* B, T1* BP, int kw, int kh, int ic, int oc, int icPack, int ocPack) {
    int kernelCount = kw * kh;
    int maxCount = kernelCount * icPack * oc;
    // l = Cip, h = Cop * KhKw
    // [Ci, Co, KhKw] -> [KhKw, Co, Cip]
    for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
        int l_idx   = index % icPack;
        int h_idx   = index / icPack;
        int oc_idx  = h_idx % oc;
        int khw_idx = h_idx / oc;

        if(l_idx >= ic) {
            BP[index] = (T1)0.0f;
            continue;
        }
        BP[index] = (T1)(B[(l_idx * oc + oc_idx) * kernelCount + khw_idx]);
    }
}
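
// DATA_CONVERT_COPY stores the float4 accumulator "val" at "dst_offset" of
// data_im: for precision 0 or 1 it writes the four lanes as a single float4
// store, while for precision 2 it converts the lanes to half precision and
// writes them as two half2 stores.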
#define DATA_CONVERT_COPY(precision) \
    if(precision == 1 || precision == 0) { *((float4*)((float*)data_im + dst_offset)) = val; }\
    else if(precision == 2) { float2 tmp; tmp.x = val.x; tmp.y = val.y; *(half2 *)((half *)(data_im + dst_offset)) = __float22half2_rn(tmp); \
        tmp.x = val.z; tmp.y = val.w; *(half2 *)((half *)(data_im + dst_offset + 2)) = __float22half2_rn(tmp);}
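
// Vectorized col2im: each thread produces four consecutive output channels of
// one output pixel (c_im = idx_ocp * 4). The column buffer (typically the
// GEMM output in the deconvolution pipeline) is addressed with channel_pack,
// i.e. channels rounded up to a multiple of 8, and the kernel walks every
// (h_col, w_col) column position whose receptive field covers the output
// location (h_im, w_im), accumulating in fp32 before DATA_CONVERT_COPY writes
// the result back at the requested precision.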
template <typename Stype, typename Dtype>
__global__ void Col2Im_Vec4(const int n, const Stype* data_col,
    const int batch, const int height, const int width, const int channels,
    const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w,
    const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int height_col, const int width_col,
    const Dtype* bias, Dtype* data_im,
    const int precision,
    DivModFast d_ocp,
    DivModFast d_ow,
    DivModFast d_oh,
    DivModFast d_ob
) {
    const int channel_pack = ((channels+7) / 8) * 8;

    for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < (n); index += blockDim.x * gridDim.x) {
        float4 val;
        val.x = 0.0;
        val.y = 0.0;
        val.z = 0.0;
        val.w = 0.0;
        int idx_ocp, idx_bhw, idx_bh, b_im, idx_w, idx_h;
        d_ocp.divmod(index, idx_bhw, idx_ocp);
        const int c_im = idx_ocp << 2;

        if(c_im >= channels) {
            continue;
        }

        d_ow.divmod(idx_bhw, idx_bh, idx_w);
        d_oh.divmod(idx_bh, b_im, idx_h);

        const int w_im = idx_w + pad_w;
        const int h_im = idx_h + pad_h;

        if(nullptr != bias) {
            if(precision == 2) {
                val.x += (float)bias[c_im];
                val.y += (float)bias[c_im+1];
                val.z += (float)bias[c_im+2];
                val.w += (float)bias[c_im+3];
            } else {
                val = *((float4*)((float*)bias + c_im));
            }
        }
        int kernel_extent_w = (kernel_w - 1) * dilation_w + 1;
        int kernel_extent_h = (kernel_h - 1) * dilation_h + 1;
        // compute the start and end of the output
        const int w_col_start =
            (w_im < kernel_extent_w) ? 0 : (w_im - kernel_extent_w) / stride_w + 1;
        const int w_col_end = min(w_im / stride_w + 1, width_col);
        const int h_col_start =
            (h_im < kernel_extent_h) ? 0 : (h_im - kernel_extent_h) / stride_h + 1;
        const int h_col_end = min(h_im / stride_h + 1, height_col);
        const int src_offset_add = kernel_h * kernel_w * batch * height_col * width_col;
        // TODO: use LCM of stride and dilation to avoid unnecessary loops
        for (int h_col = h_col_start; h_col < h_col_end; h_col += 1) {
            for (int w_col = w_col_start; w_col < w_col_end; w_col += 1) {
                int h_k = (h_im - h_col * stride_h);
                int w_k = (w_im - w_col * stride_w);
                if (h_k % dilation_h == 0 && w_k % dilation_w == 0) {
                    h_k /= dilation_h;
                    w_k /= dilation_w;

                    const int data_col_index = ((((b_im * height_col + h_col) * width_col + w_col) * kernel_h + h_k) * kernel_w + w_k) * channel_pack + c_im;
                    val.x += (float)data_col[data_col_index];
                    val.y += (float)data_col[data_col_index + 1];
                    val.z += (float)data_col[data_col_index + 2];
                    val.w += (float)data_col[data_col_index + 3];
                }
            }
        }
        int dst_offset = index << 2;
        DATA_CONVERT_COPY(precision);
    }
}
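
// Scalar col2im fallback used when the output-channel count is not a multiple
// of 4: each thread accumulates a single output element, again summing every
// column position that overlaps the output pixel and adding the bias once per
// channel.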
template <typename Stype, typename Dtype>
__global__ void Col2Im(const int n, const Stype* data_col,
    const int batch, const int height, const int width, const int channels,
    const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w,
    const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int height_col, const int width_col,
    const Dtype* bias, Dtype* data_im,
    const int precision,
    DivModFast d_ocp,
    DivModFast d_ow,
    DivModFast d_oh,
    DivModFast d_ob
) {
    const int channel_pack = ((channels+7) / 8) * 8;

    for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < (n); index += blockDim.x * gridDim.x) {
        Dtype val = 0;
        int idx_ocp, idx_tmp, idx_bh, b_im, idx_w, idx_h;
        d_ocp.divmod(index, idx_tmp, idx_ocp);
        const int c_im = idx_ocp;
        if(c_im >= channels) {
            data_im[index] = val;
            break;
        }

        d_ow.divmod(idx_tmp, idx_bh, idx_w);
        d_oh.divmod(idx_bh, b_im, idx_h);

        const int w_im = idx_w + pad_w;
        const int h_im = idx_h + pad_h;

        if(nullptr != bias) {
            val += (Dtype)bias[c_im];
        }
        int kernel_extent_w = (kernel_w - 1) * dilation_w + 1;
        int kernel_extent_h = (kernel_h - 1) * dilation_h + 1;
        // compute the start and end of the output
        const int w_col_start =
            (w_im < kernel_extent_w) ? 0 : (w_im - kernel_extent_w) / stride_w + 1;
        const int w_col_end = min(w_im / stride_w + 1, width_col);
        const int h_col_start =
            (h_im < kernel_extent_h) ? 0 : (h_im - kernel_extent_h) / stride_h + 1;
        const int h_col_end = min(h_im / stride_h + 1, height_col);
        // TODO: use LCM of stride and dilation to avoid unnecessary loops
        for (int h_col = h_col_start; h_col < h_col_end; h_col += 1) {
            for (int w_col = w_col_start; w_col < w_col_end; w_col += 1) {
                int h_k = (h_im - h_col * stride_h);
                int w_k = (w_im - w_col * stride_w);
                if (h_k % dilation_h == 0 && w_k % dilation_w == 0) {
                    h_k /= dilation_h;
                    w_k /= dilation_w;

                    const int data_col_index = ((((b_im * height_col + h_col) * width_col + w_col) * kernel_h + h_k) * kernel_w + w_k) * channels + c_im;

                    val += (Dtype)data_col[data_col_index];
                }
            }
        }
        data_im[index] = val;
    }
}
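
// Host-side launcher for the weight reorder. The precision flag selects the
// source/destination element types: precision 1 keeps the weights in fp32,
// precision 0 converts fp32 weights to fp16, and any other value treats both
// source and destination as fp16. The grid is sized as one block per SM with
// the maximum thread count per block; the kernel's grid-stride loop covers
// any remaining elements.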
void callWeightReorder(const void* input, void* output, const KernelInfo kernel_info, const int icPack, const int precision, CUDARuntime* runtime) {
    auto& prop = runtime->prop();
    int cores = prop.multiProcessorCount;
    int threadNumbers = prop.maxThreadsPerBlock;

    int ocPack = UP_DIV(kernel_info.kernelN, PACK_NUMBER) * PACK_NUMBER;
    if(precision == 1) {
        DeconvKernelReorder<<<cores, threadNumbers>>>((const float*)input, (float*)output,
            kernel_info.kernelX, kernel_info.kernelY, kernel_info.kernelC, kernel_info.kernelN, icPack, ocPack);
        checkKernelErrors;
    } else if(precision == 0) {
        DeconvKernelReorder<<<cores, threadNumbers>>>((const float*)input, (half*)output,
            kernel_info.kernelX, kernel_info.kernelY, kernel_info.kernelC, kernel_info.kernelN, icPack, ocPack);
        checkKernelErrors;
    } else {
        DeconvKernelReorder<<<cores, threadNumbers>>>((const half*)input, (half*)output,
            kernel_info.kernelX, kernel_info.kernelY, kernel_info.kernelC, kernel_info.kernelN, icPack, ocPack);
        checkKernelErrors;
    }
}
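
// Host-side launcher for the col2im step of the deconvolution. Note the
// naming: height_col/width_col come from the input shape (ih/iw) while the
// kernels iterate over output positions (oh/ow), since deconvolution scatters
// each input column back into the larger output image. One logical work item
// is created per output element, with the channel dimension padded to a
// multiple of 8.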
void callCol2ImFunc(const void* input, const void* bias, void* output, const Col2ImParameter* col2im_param, const int precision, CUDARuntime* runtime) {
    // Do Col2Im trans
    int height_col = col2im_param->ih;
    int width_col = col2im_param->iw;
    int num_kernels = col2im_param->ob * UP_DIV(col2im_param->oc, 8) * col2im_param->oh * col2im_param->ow * 8;

    int col2im_block_num = runtime->blocks_num(num_kernels);
    int col2im_thread_num = runtime->threads_num();

    DivModFast ocpD((UP_DIV(col2im_param->oc, 8) * 8));
    DivModFast ocpD_4((UP_DIV(col2im_param->oc, 8) * 2));
    DivModFast owD(col2im_param->ow);
    DivModFast ohD(col2im_param->oh);
    DivModFast obD(col2im_param->ob);

    // printf("col2im:%d, %d-%d-%d-%d-%d-%d\n %d-%d-%d-%d-%d-%d\n %d-%d\n", col2im_param->ob, col2im_param->oh, col2im_param->ow, col2im_param->oc, \
    //     col2im_param->ih, col2im_param->iw, col2im_param->ic, \
    //     col2im_param->padX, col2im_param->padY, col2im_param->kernelX, col2im_param->kernelY, col2im_param->strideX, col2im_param->strideY, \
    //     col2im_block_num, col2im_thread_num);
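
    // Dispatch: when the output-channel count is a multiple of 4, use the
    // vectorized kernel with one quarter of the work items (each covering 4
    // channels) and the matching ocpD_4 divisor; otherwise fall back to the
    // scalar Col2Im kernel below. In both paths the precision flag selects
    // the element types of the column buffer and the output tensor.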
    if(col2im_param->oc % 4 == 0) {
        num_kernels /= 4;
        col2im_block_num = runtime->blocks_num(num_kernels);
        col2im_thread_num = runtime->threads_num();
        if(precision == 1) {
            Col2Im_Vec4<<<col2im_block_num, col2im_thread_num>>>(
                num_kernels, (const float*)input, col2im_param->ob, col2im_param->oh, col2im_param->ow, col2im_param->oc,
                col2im_param->kernelY, col2im_param->kernelX, col2im_param->padY, col2im_param->padX,
                col2im_param->strideY, col2im_param->strideX, col2im_param->dilateY, col2im_param->dilateX,
                height_col, width_col, (const float*)bias, (float *)output, precision, ocpD_4, owD, ohD, obD);
            checkKernelErrors;
            return;
        } else if(precision == 0) {
            Col2Im_Vec4<<<col2im_block_num, col2im_thread_num>>>(
                num_kernels, (const half*)input, col2im_param->ob, col2im_param->oh, col2im_param->ow, col2im_param->oc,
                col2im_param->kernelY, col2im_param->kernelX, col2im_param->padY, col2im_param->padX,
                col2im_param->strideY, col2im_param->strideX, col2im_param->dilateY, col2im_param->dilateX,
                height_col, width_col, (const float*)bias, (float *)output, precision, ocpD_4, owD, ohD, obD);
            checkKernelErrors;
            return;
        } else {
            Col2Im_Vec4<<<col2im_block_num, col2im_thread_num>>>(
                num_kernels, (const half*)input, col2im_param->ob, col2im_param->oh, col2im_param->ow, col2im_param->oc,
                col2im_param->kernelY, col2im_param->kernelX, col2im_param->padY, col2im_param->padX,
                col2im_param->strideY, col2im_param->strideX, col2im_param->dilateY, col2im_param->dilateX,
                height_col, width_col, (const half*)bias, (half *)output, precision, ocpD_4, owD, ohD, obD);
            checkKernelErrors;
            return;
        }
    }
    if(precision == 1) {
        Col2Im<<<col2im_block_num, col2im_thread_num>>>(
            num_kernels, (const float*)input, col2im_param->ob, col2im_param->oh, col2im_param->ow, col2im_param->oc,
            col2im_param->kernelY, col2im_param->kernelX, col2im_param->padY, col2im_param->padX,
            col2im_param->strideY, col2im_param->strideX, col2im_param->dilateY, col2im_param->dilateX,
            height_col, width_col, (const float*)bias, (float *)output, precision, ocpD, owD, ohD, obD);
        checkKernelErrors;
        return;
    } else if(precision == 0) {
        Col2Im<<<col2im_block_num, col2im_thread_num>>>(
            num_kernels, (const half*)input, col2im_param->ob, col2im_param->oh, col2im_param->ow, col2im_param->oc,
            col2im_param->kernelY, col2im_param->kernelX, col2im_param->padY, col2im_param->padX,
            col2im_param->strideY, col2im_param->strideX, col2im_param->dilateY, col2im_param->dilateX,
            height_col, width_col, (const float*)bias, (float *)output, precision, ocpD, owD, ohD, obD);
        checkKernelErrors;
        return;
    } else {
        Col2Im<<<col2im_block_num, col2im_thread_num>>>(
            num_kernels, (const half*)input, col2im_param->ob, col2im_param->oh, col2im_param->ow, col2im_param->oc,
            col2im_param->kernelY, col2im_param->kernelX, col2im_param->padY, col2im_param->padX,
            col2im_param->strideY, col2im_param->strideX, col2im_param->dilateY, col2im_param->dilateX,
            height_col, width_col, (const half*)bias, (half *)output, precision, ocpD, owD, ohD, obD);
        checkKernelErrors;
        return;
    }
}

} //namespace CUDA
} //namespace MNN