MNN/source/backend/cuda/execution/CutlassGemmParam.hpp

//
//  CutlassGemmParam.hpp
// MNN
//
// Created by MNN on 2020/08/22.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifndef CutlassGemmParam_hpp
#define CutlassGemmParam_hpp
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/thread/linear_combination_relu6.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_array.h"
#include "cutlass/gemm/device/gemm_batched.h"
namespace MNN {
namespace CUDA {
struct CutlassGemmInfo{
    int elh[3];    // GEMM problem size in MNN's e/l/h convention (A: e x l, B: l x h, C: e x h, i.e. M/K/N)
    int elhPad[3]; // the same sizes rounded up for aligned, vectorized access
};
using ElementAccumulator = float; // <- data type of accumulator
using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations
using ElementInputA = cutlass::half_t; // <- data type of elements in input matrix A
using ElementInputB = cutlass::half_t; // <- data type of elements in input matrix B
using ElementOutput_F16 = cutlass::half_t; // <- data type of elements in output matrix D
using ElementOutput_F32 = float; // <- data type of elements in output matrix D
// The code section below describes the matrix layouts of the input and output matrices: Row Major for
// Matrix A, Column Major for Matrix B, and Row Major for Matrix C
using LayoutInputA = cutlass::layout::RowMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::RowMajor;
// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = cutlass::arch::OpClassTensorOp;
// This code section describes CUDA SM architecture number
using SmArch70 = cutlass::arch::Sm70;
using SmArch75 = cutlass::arch::Sm75;
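// (Sm70 corresponds to Volta-class GPUs such as V100, Sm75 to Turing-class GPUs such as T4;
// the Sm70 kernels below use the 8x8x4 MMA shape, the Sm75 kernels the 16x8x8 shape.)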
// This code section describes the tile size a thread block will compute
using ShapeMMAThreadBlock =
    cutlass::gemm::GemmShape<64, 64, 64>; // <- threadblock tile M = 64, N = 64, K = 64
// This code section describes tile size a warp will compute
using ShapeMMAWarp = cutlass::gemm::GemmShape<32, 32, 64>; // <- warp tile M = 32, N = 32, K = 64
// This code section describes the size of MMA op
using ShapeMMAOp1688 = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 (used for Sm75)
using ShapeMMAOp884 = cutlass::gemm::GemmShape<8, 8, 4>; // <- MMA Op tile M = 8, N = 8, K = 4 (used for Sm70)
// This code section describes how threadblocks are scheduled on GPU
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- identity mapping of threadblocks to output tiles
// This code section describes the epilogue part of the kernel
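// Each epilogue fuses the output transform into the GEMM: LinearCombination computes
// D = alpha * accumulator + beta * C, and the Relu / Relu6 variants additionally clamp the
// result (to [0, +inf) and, presumably, [0, 6] respectively).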
using EpilogueOp_F16_Linear = cutlass::epilogue::thread::LinearCombination<
    ElementOutput_F16, // <- data type of output matrix
    128 / cutlass::sizeof_bits<ElementOutput_F16>::value, // <- number of elements per vectorized memory
                                                          // access (8 for half precision); this also becomes
                                                          // the vector width of math instructions in the epilogue
    ElementAccumulator, // <- data type of accumulator
    ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function
using EpilogueOp_F32_Linear = cutlass::epilogue::thread::LinearCombination<
    ElementOutput_F32, // <- data type of output matrix
    128 / cutlass::sizeof_bits<ElementOutput_F32>::value, // <- number of elements per vectorized memory
                                                          // access (4 for float); this also becomes
                                                          // the vector width of math instructions in the epilogue
    ElementAccumulator, // <- data type of accumulator
    ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function
using EpilogueOp_F16_Relu = cutlass::epilogue::thread::LinearCombinationRelu<
    ElementOutput_F16, // <- data type of output matrix
    128 / cutlass::sizeof_bits<ElementOutput_F16>::value, // <- number of elements per vectorized memory
                                                          // access (8 for half precision); this also becomes
                                                          // the vector width of math instructions in the epilogue
    ElementAccumulator, // <- data type of accumulator
    ElementComputeEpilogue>; // <- data type for alpha in linear combination function
using EpilogueOp_F32_Relu = cutlass::epilogue::thread::LinearCombinationRelu<
    ElementOutput_F32, // <- data type of output matrix
    128 / cutlass::sizeof_bits<ElementOutput_F32>::value, // <- number of elements per vectorized memory
                                                          // access (4 for float); this also becomes
                                                          // the vector width of math instructions in the epilogue
    ElementAccumulator, // <- data type of accumulator
    ElementComputeEpilogue>; // <- data type for alpha in linear combination function
using EpilogueOp_F16_Relu6 = cutlass::epilogue::thread::LinearCombinationRelu6<
    ElementOutput_F16, // <- data type of output matrix
    128 / cutlass::sizeof_bits<ElementOutput_F16>::value, // <- number of elements per vectorized memory
                                                          // access (8 for half precision); this also becomes
                                                          // the vector width of math instructions in the epilogue
    ElementAccumulator, // <- data type of accumulator
    ElementComputeEpilogue>; // <- data type for alpha in linear combination function
using EpilogueOp_F32_Relu6 = cutlass::epilogue::thread::LinearCombinationRelu6<
    ElementOutput_F32, // <- data type of output matrix
    128 / cutlass::sizeof_bits<ElementOutput_F32>::value, // <- number of elements per vectorized memory
                                                          // access (4 for float); this also becomes
                                                          // the vector width of math instructions in the epilogue
    ElementAccumulator, // <- data type of accumulator
    ElementComputeEpilogue>; // <- data type for alpha in linear combination function
// Number of pipeline stages to use
constexpr int NumStages = 2;
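// Two stages correspond to double-buffered global-to-shared-memory loads; deeper (multistage)
// pipelines require the Sm80 cp.async path, which the Sm70/Sm75 kernels here do not target.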
using Gemm_F16_Linear_Sm70 = cutlass::gemm::device::Gemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput_F16,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch70,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp884,
EpilogueOp_F16_Linear,
SwizzleThreadBlock,
NumStages>;
using Gemm_F16_Linear_Sm75 = cutlass::gemm::device::Gemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput_F16,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch75,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp1688,
EpilogueOp_F16_Linear,
SwizzleThreadBlock,
NumStages>;
using Gemm_F32_Linear_Sm70 = cutlass::gemm::device::Gemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput_F32,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch70,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp884,
EpilogueOp_F32_Linear,
SwizzleThreadBlock,
NumStages>;
using Gemm_F32_Linear_Sm75 = cutlass::gemm::device::Gemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput_F32,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch75,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp1688,
EpilogueOp_F32_Linear,
SwizzleThreadBlock,
NumStages>;
using Gemm_F16_Relu_Sm70 = cutlass::gemm::device::Gemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput_F16,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch70,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp884,
EpilogueOp_F16_Relu,
SwizzleThreadBlock,
NumStages>;
using Gemm_F16_Relu_Sm75 = cutlass::gemm::device::Gemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput_F16,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch75,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp1688,
EpilogueOp_F16_Relu,
SwizzleThreadBlock,
NumStages>;
using Gemm_F32_Relu_Sm70 = cutlass::gemm::device::Gemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput_F32,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch70,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp884,
EpilogueOp_F32_Relu,
SwizzleThreadBlock,
NumStages>;
using Gemm_F32_Relu_Sm75 = cutlass::gemm::device::Gemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput_F32,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch75,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp1688,
EpilogueOp_F32_Relu,
SwizzleThreadBlock,
NumStages>;
using Gemm_F16_Relu6_Sm70 = cutlass::gemm::device::Gemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput_F16,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch70,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp884,
EpilogueOp_F16_Relu6,
SwizzleThreadBlock,
NumStages>;
using Gemm_F16_Relu6_Sm75 = cutlass::gemm::device::Gemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput_F16,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch75,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp1688,
EpilogueOp_F16_Relu6,
SwizzleThreadBlock,
NumStages>;
using Gemm_F32_Relu6_Sm70 = cutlass::gemm::device::Gemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput_F32,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch70,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp884,
EpilogueOp_F32_Relu6,
SwizzleThreadBlock,
NumStages>;
using Gemm_F32_Relu6_Sm75 = cutlass::gemm::device::Gemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput_F32,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch75,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp1688,
EpilogueOp_F32_Relu6,
SwizzleThreadBlock,
NumStages>;
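
// The aliases above are instantiated by the convolution / matmul executions elsewhere in this
// backend. The function below is only a hedged usage sketch of how one of them is typically
// driven through CUTLASS's device-level API; the function name and every pointer/size
// parameter are hypothetical and not part of MNN.
inline cutlass::Status launchGemmF16ReluSm75Sketch(const cutlass::half_t* aPtr, int lda, // A: (m x k), RowMajor
                                                   const cutlass::half_t* bPtr, int ldb, // B: (k x n), ColumnMajor
                                                   cutlass::half_t* dPtr, int ldd,       // C/D: (m x n), RowMajor
                                                   int m, int n, int k,
                                                   float alpha = 1.0f, float beta = 0.0f,
                                                   cudaStream_t stream = nullptr) {
    // Pack the problem description; C and D share the same buffer here for simplicity.
    Gemm_F16_Relu_Sm75::Arguments args({m, n, k},
                                       {aPtr, lda},
                                       {bPtr, ldb},
                                       {dPtr, ldd},          // source operand C (scaled by beta)
                                       {dPtr, ldd},          // destination D
                                       {alpha, beta},        // epilogue scalars
                                       /*split_k_slices=*/1);
    Gemm_F16_Relu_Sm75 gemmOp;
    // With a single split-K slice no extra workspace is needed, so a null workspace is passed.
    return gemmOp(args, /*workspace=*/nullptr, stream);
}
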
// This code section describes how threadblocks are scheduled on GPU for the batched kernels
using BatchedSwizzleThreadBlock = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; // <- identity swizzle that additionally distributes the batch index over the grid
using ShapeBatchMMAThreadBlock =
    cutlass::gemm::GemmShape<64, 64, 64>; // <- threadblock tile M = 64, N = 64, K = 64
// This code section describes tile size a warp will compute
using ShapeBatchMMAWarp = cutlass::gemm::GemmShape<16, 64, 64>; // <- warp tile M = 16, N = 64, K = 64
using GemmBatched_F16_Linear_Sm75 = cutlass::gemm::device::GemmBatched<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput_F16,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch75,
ShapeBatchMMAThreadBlock,
ShapeBatchMMAWarp,
ShapeMMAOp1688,
EpilogueOp_F16_Linear,
BatchedSwizzleThreadBlock,
NumStages>;
using GemmBatched_F32_Linear_Sm75 = cutlass::gemm::device::GemmBatched<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput_F32,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch75,
ShapeBatchMMAThreadBlock,
ShapeBatchMMAWarp,
ShapeMMAOp1688,
EpilogueOp_F32_Linear,
BatchedSwizzleThreadBlock,
NumStages>;
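
// Hedged usage sketch for the batched variants (illustration only; the function name and all
// parameters are hypothetical). Besides the per-matrix TensorRefs, GemmBatched takes a batch
// stride for each operand (the element offset between consecutive matrices) and the batch count.
inline cutlass::Status launchBatchedGemmF16Sm75Sketch(const cutlass::half_t* aPtr, int lda, int64_t strideA,
                                                      const cutlass::half_t* bPtr, int ldb, int64_t strideB,
                                                      cutlass::half_t* dPtr, int ldd, int64_t strideD,
                                                      int m, int n, int k, int batchCount,
                                                      float alpha = 1.0f, float beta = 0.0f,
                                                      cudaStream_t stream = nullptr) {
    GemmBatched_F16_Linear_Sm75::Arguments args({m, n, k},
                                                {aPtr, lda}, strideA,
                                                {bPtr, ldb}, strideB,
                                                {dPtr, ldd}, strideD, // source operand C (scaled by beta)
                                                {dPtr, ldd}, strideD, // destination D
                                                {alpha, beta},        // epilogue scalars
                                                batchCount);
    GemmBatched_F16_Linear_Sm75 gemmOp;
    return gemmOp(args, /*workspace=*/nullptr, stream);
}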
} // namespace CUDA
} // namespace MNN
#endif /* CutlassGemmParam_hpp */