//
// CutlassConvCommonExecution.cu
// MNN
//
// Created by MNN on 2023/03/22.
// Copyright © 2018, Alibaba Group Holding Limited
//

#include "CutlassConvCommonExecution.hpp"

namespace MNN {
namespace CUDA {

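// With ENABLE_CUDA_TUNE_PARAM the execution derives from
// CutlassGemmTuneCommonExecution so the tuned-kernel path in
// runCutlassGemmFunc() is available; otherwise it is a plain Execution.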
CutlassConvCommonExecution::CutlassConvCommonExecution(Backend *backend) :
#ifdef ENABLE_CUDA_TUNE_PARAM
    CutlassGemmTuneCommonExecution(backend)
#else
    Execution(backend)
#endif
{
    mBackendPtr = backend;
}

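// Launch the prebuilt CUTLASS GEMM that matches the current configuration:
// the precision mode (BF16, full FP32, or FP16 with mFp16Fp32MixInfer
// selecting the F16F32 mixed variants), the activation type (1 = ReLU,
// 2 = ReLU6, anything else takes the plain "Ln" linear epilogue), and the
// GPU compute capability.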
ErrorCode CutlassConvCommonExecution::runCutlassGemmFunc() {
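    // BF16 inference (only compiled with ENABLE_CUDA_BF16): BF16 kernels
    // built for SM80 (Ampere).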
#ifdef ENABLE_CUDA_BF16
    if(mBf16Infer) {
        if(mActivationType == 1) {
            cutlass::Status status = mGemmBF16BF16ReluSm80();
            cutlass_check(status);
        } else if(mActivationType == 2) {
            cutlass::Status status = mGemmBF16BF16Relu6Sm80();
            cutlass_check(status);
        } else {
            cutlass::Status status = mGemmBF16BF16LnSm80();
            cutlass_check(status);
        }
        return NO_ERROR;
    }
#endif

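    // Full-FP32 inference: plain FP32 kernels, chosen by activation type.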
    if(mFp32Infer) {
        if(mActivationType == 1) {
            cutlass::Status status = mGemmCudaF32F32Relu();
            cutlass_check(status);
        } else if(mActivationType == 2) {
            cutlass::Status status = mGemmCudaF32F32Relu6();
            cutlass_check(status);
        } else {
            cutlass::Status status = mGemmCudaF32F32Ln();
            cutlass_check(status);
        }
        return NO_ERROR;
    }

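    // FP16 inference, dispatched on compute capability: GPUs below SM70 have
    // no FP16 tensor cores, so the CUDA-core (mGemmCuda*) kernels are used;
    // SM70-74 (Volta) uses the *Sm70 kernels; SM75 and newer falls through to
    // the Turing-and-later path below.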
    if(mGpuComputeCap < 70) {
        if(mActivationType == 1) {
            if(mFp16Fp32MixInfer) {
                cutlass::Status status = mGemmCudaF16F32Relu();
                cutlass_check(status);
            } else {
                cutlass::Status status = mGemmCudaF16F16Relu();
                cutlass_check(status);
            }
        } else if(mActivationType == 2) {
            if(mFp16Fp32MixInfer) {
                cutlass::Status status = mGemmCudaF16F32Relu6();
                cutlass_check(status);
            } else {
                cutlass::Status status = mGemmCudaF16F16Relu6();
                cutlass_check(status);
            }
        } else {
            if(mFp16Fp32MixInfer) {
                cutlass::Status status = mGemmCudaF16F32Ln();
                cutlass_check(status);
            } else {
                cutlass::Status status = mGemmCudaF16F16Ln();
                cutlass_check(status);
            }
        }

        return NO_ERROR;
    } else if(mGpuComputeCap < 75) {
        if(mActivationType == 1) {
            if(mFp16Fp32MixInfer) {
                cutlass::Status status = mGemmF16F32ReluSm70();
                cutlass_check(status);
            } else {
                cutlass::Status status = mGemmF16F16ReluSm70();
                cutlass_check(status);
            }
        } else if(mActivationType == 2) {
            if(mFp16Fp32MixInfer) {
                cutlass::Status status = mGemmF16F32Relu6Sm70();
                cutlass_check(status);
            } else {
                cutlass::Status status = mGemmF16F16Relu6Sm70();
                cutlass_check(status);
            }
        } else {
            if(mFp16Fp32MixInfer) {
                cutlass::Status status = mGemmF16F32LnSm70();
                cutlass_check(status);
            } else {
                cutlass::Status status = mGemmF16F16LnSm70();
                cutlass_check(status);
            }
        }

        return NO_ERROR;
    }

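    // SM75 and newer: if ENABLE_CUDA_TUNE_PARAM is on and a tuned kernel has
    // been picked (mIsTuned), run it through runGemmTensorCoreFloat16Infer;
    // otherwise fall back to the default *Sm75 kernels.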
#ifdef ENABLE_CUDA_TUNE_PARAM
    if(mIsTuned) {
        runGemmTensorCoreFloat16Infer(&mInfo);
    }
#endif
    if(!mIsTuned) {
        if(mActivationType == 1) {
            if(mFp16Fp32MixInfer) {
                cutlass::Status status = mGemmF16F32ReluSm75();
                cutlass_check(status);
            } else {
                cutlass::Status status = mGemmF16F16ReluSm75();
                cutlass_check(status);
            }
        } else if(mActivationType == 2) {
            if(mFp16Fp32MixInfer) {
                cutlass::Status status = mGemmF16F32Relu6Sm75();
                cutlass_check(status);
            } else {
                cutlass::Status status = mGemmF16F16Relu6Sm75();
                cutlass_check(status);
            }
        } else {
            if(mFp16Fp32MixInfer) {
                cutlass::Status status = mGemmF16F32LnSm75();
                cutlass_check(status);
            } else {
                cutlass::Status status = mGemmF16F16LnSm75();
                cutlass_check(status);
            }
        }
    }
    return NO_ERROR;
}

} // namespace CUDA
} // namespace MNN