//
// CutlassConvCommonExecution.cu
// MNN
//
// Created by MNN on 2023/03/22.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "CutlassConvCommonExecution.hpp"
namespace MNN {
namespace CUDA {
// Constructs the execution for the given backend.
//
// The base class depends on the build configuration: with
// ENABLE_CUDA_TUNE_PARAM defined the object derives from
// CutlassGemmTuneCommonExecution, otherwise directly from Execution. In both
// cases the backend pointer is forwarded to the base and also cached in
// mBackendPtr.
//
// @param backend  Backend passed to the base class and stored in mBackendPtr.
CutlassConvCommonExecution::CutlassConvCommonExecution(Backend *backend) :
#ifdef ENABLE_CUDA_TUNE_PARAM
    CutlassGemmTuneCommonExecution(backend)
#else
    Execution(backend)
#endif
{
    // Assigned in the body rather than the initializer list, as in the
    // original; the member declaration order is not visible from this file.
    mBackendPtr = backend;
}
// Runs the GEMM operator that matches the current precision mode, GPU compute
// capability and activation type.
//
// Selection order (first match wins):
//   1. mBf16Infer (only when built with ENABLE_CUDA_BF16) -> *BF16BF16*Sm80 operators
//   2. mFp32Infer                                         -> mGemmCudaF32F32* operators
//   3. mGpuComputeCap < 70                                -> mGemmCudaF16* operators
//   4. mGpuComputeCap < 75                                -> *Sm70 operators
//   5. otherwise                                          -> tuned path or *Sm75 operators
//
// Within each tier mActivationType picks the variant: 1 -> *Relu, 2 -> *Relu6,
// any other value -> *Ln. For the FP16 tiers mFp16Fp32MixInfer chooses between
// the F16F32 and F16F16 operators.
//
// Exactly one operator is invoked per call and its cutlass::Status is handed to
// cutlass_check (defined elsewhere; how it reacts to a failure is not visible
// from this file).
//
// @return NO_ERROR on every path visible in this function.
ErrorCode CutlassConvCommonExecution::runCutlassGemmFunc() {
#ifdef ENABLE_CUDA_BF16
    if (mBf16Infer) {
        cutlass::Status status;
        if (mActivationType == 1) {
            status = mGemmBF16BF16ReluSm80();
        } else if (mActivationType == 2) {
            status = mGemmBF16BF16Relu6Sm80();
        } else {
            status = mGemmBF16BF16LnSm80();
        }
        cutlass_check(status);
        return NO_ERROR;
    }
#endif

    if (mFp32Infer) {
        cutlass::Status status;
        if (mActivationType == 1) {
            status = mGemmCudaF32F32Relu();
        } else if (mActivationType == 2) {
            status = mGemmCudaF32F32Relu6();
        } else {
            status = mGemmCudaF32F32Ln();
        }
        cutlass_check(status);
        return NO_ERROR;
    }

    if (mGpuComputeCap < 70) {
        // Compute capability below 70: use the mGemmCuda* FP16 operators.
        cutlass::Status status;
        if (mActivationType == 1) {
            status = mFp16Fp32MixInfer ? mGemmCudaF16F32Relu() : mGemmCudaF16F16Relu();
        } else if (mActivationType == 2) {
            status = mFp16Fp32MixInfer ? mGemmCudaF16F32Relu6() : mGemmCudaF16F16Relu6();
        } else {
            status = mFp16Fp32MixInfer ? mGemmCudaF16F32Ln() : mGemmCudaF16F16Ln();
        }
        cutlass_check(status);
        return NO_ERROR;
    }

    if (mGpuComputeCap < 75) {
        // Compute capability in [70, 75): use the *Sm70 operators.
        cutlass::Status status;
        if (mActivationType == 1) {
            status = mFp16Fp32MixInfer ? mGemmF16F32ReluSm70() : mGemmF16F16ReluSm70();
        } else if (mActivationType == 2) {
            status = mFp16Fp32MixInfer ? mGemmF16F32Relu6Sm70() : mGemmF16F16Relu6Sm70();
        } else {
            status = mFp16Fp32MixInfer ? mGemmF16F32LnSm70() : mGemmF16F16LnSm70();
        }
        cutlass_check(status);
        return NO_ERROR;
    }

    // Compute capability >= 75.
#ifdef ENABLE_CUDA_TUNE_PARAM
    if (mIsTuned) {
        // A tuned configuration exists: run it instead of the fixed *Sm75 operators.
        runGemmTensorCoreFloat16Infer(&mInfo);
    }
#endif
    if (!mIsTuned) {
        cutlass::Status status;
        if (mActivationType == 1) {
            status = mFp16Fp32MixInfer ? mGemmF16F32ReluSm75() : mGemmF16F16ReluSm75();
        } else if (mActivationType == 2) {
            status = mFp16Fp32MixInfer ? mGemmF16F32Relu6Sm75() : mGemmF16F16Relu6Sm75();
        } else {
            status = mFp16Fp32MixInfer ? mGemmF16F32LnSm75() : mGemmF16F16LnSm75();
        }
        cutlass_check(status);
    }
    return NO_ERROR;
}
} // namespace CUDA
} // namespace MNN