#include "MatMulExecution.hpp"

namespace MNN {
namespace CUDA {

template<typename T0, typename T1>
__global__ void PackPadFill(
    const T0* A, const T0* B,
    bool transA, bool transB,
    T1* tempA, T1* tempB, const int batchA, const int batchB,
    const int e, const int l, const int h,
    const int ep, const int lp, const int hp,
    DivModFast d_e, DivModFast d_l, DivModFast d_h,
    DivModFast d_lp, DivModFast d_lp2
) {
    T1 zero = (T1)0.0;

    if((char *)A != (char *)tempA) {
        if(transA) { // l * e, just transpose to e * lp
            const int maxCount = batchA * e * lp;
            for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
                int bIndex, lpIndex, eIndex, tmp;
                d_lp.divmod(index, tmp, lpIndex);
                d_e.divmod(tmp, bIndex, eIndex);

                if(lpIndex >= l) {
                    tempA[index] = zero;
                    continue;
                }
                tempA[index] = A[bIndex * e * l + lpIndex * e + eIndex];
            }
        } else { // e * l, just pack for l
            if ((l & 1) == 0) { // l is even: copy two elements per thread
                const int maxCount = batchA * e * (lp >> 1);
                for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
                    int lp2Index, eIndex, bIndex, tmp;
                    d_lp2.divmod(index, tmp, lp2Index);
                    d_e.divmod(tmp, bIndex, eIndex);

                    if(lp2Index + lp2Index >= l) {
                        tempA[index+index] = zero;
                        tempA[index+index+1] = zero;
                        continue;
                    }
                    tempA[index+index] = A[bIndex * e * l + eIndex * l + lp2Index + lp2Index];
                    tempA[index+index+1] = A[bIndex * e * l + eIndex * l + lp2Index + lp2Index + 1];
                }
            } else {
                const int maxCount = batchA * e * lp;
                for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
                    int lpIndex, eIndex, bIndex, tmp;
                    d_lp.divmod(index, tmp, lpIndex);
                    d_e.divmod(tmp, bIndex, eIndex);
                    if(lpIndex >= l || eIndex >= e) {
                        tempA[index] = zero;
                        continue;
                    }
                    tempA[index] = A[bIndex * e * l + eIndex * l + lpIndex];
                }
            }
        }
    }
    if((char *)B != (char *)tempB) {
        if(!transB) { // l * h
            const int maxCount = batchB * lp * h;
            if(h == hp) { // and h already packed, just pack for l -> lp * h
                for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
                    int lpIndex, hpIndex, bIndex, tmp;
                    d_h.divmod(index, tmp, hpIndex);
                    d_lp.divmod(tmp, bIndex, lpIndex);

                    if(lpIndex >= l || hpIndex >= h) {
                        tempB[index] = zero;
                        continue;
                    }
                    tempB[index] = B[bIndex * h * l + lpIndex * h + hpIndex];
                }
            } else { // and h not packed, just transpose and pack for l -> h * lp
                for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
                    int lpIndex, hIndex, bIndex, tmp;
                    d_lp.divmod(index, tmp, lpIndex);
                    d_h.divmod(tmp, bIndex, hIndex);

                    if(lpIndex >= l || hIndex >= h) {
                        tempB[index] = zero;
                        continue;
                    }
                    tempB[index] = B[bIndex * h * l + lpIndex * h + hIndex];
                }
            }
        } else { // h * l, just pack for l
            if((l & 1) == 0) { // l is even: copy two elements per thread
                const int maxCount = batchB * h * (lp >> 1);
                for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
                    int lp2Index, hIndex, bIndex, tmp;
                    d_lp2.divmod(index, tmp, lp2Index);
                    d_h.divmod(tmp, bIndex, hIndex);

                    if(lp2Index + lp2Index >= l) {
                        tempB[index+index] = zero;
                        tempB[index+index+1] = zero;
                        continue;
                    }
                    tempB[index+index] = B[bIndex * h * l + hIndex * l + lp2Index + lp2Index];
                    tempB[index+index+1] = B[bIndex * h * l + hIndex * l + lp2Index + lp2Index + 1];
                }
            } else {
                const int maxCount = batchB * h * lp;
                for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < maxCount; index += blockDim.x * gridDim.x) {
                    int lpIndex, hIndex, bIndex, tmp;
                    d_lp.divmod(index, tmp, lpIndex);
                    d_h.divmod(tmp, bIndex, hIndex);

                    if(lpIndex >= l || hIndex >= h) {
                        tempB[index] = zero;
                        continue;
                    }
                    tempB[index] = B[bIndex * h * l + hIndex * l + lpIndex];
                }
            }
        }
    }
}
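
// Note on PackPadFill: the grid-stride loops above copy A and B into the padded temp
// buffers and zero-fill the padded tail of the l dimension, so the extra columns do
// not change the GEMM result. After the kernel, tempA is row-major (batchA, e, lp) and
// tempB is (batchB, lp, h) when B was l x h with h already aligned, otherwise
// (batchB, h, lp). DivModFast decodes the flat thread index into (batch, row, col)
// coordinates through its divmod(index, quotient, remainder) helper.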

MatMulExecution::MatMulExecution(bool transposeA, bool transposeB, Backend *backend, int aS, int bS, int cS) : Execution(backend) {
    mTransposeA = transposeA;
    mTransposeB = transposeB;
    mBackend = backend;
    int precisionLevel = static_cast<CUDABackend*>(backend)->getPrecision();
    mFp16Infer = (precisionLevel == 2);
    mFp32Infer = (precisionLevel == 1);
    mFp16Fp32MixInfer = (precisionLevel == 0);
    mAs = aS;
    mBs = bS;
    mCs = cS;
}

MatMulExecution::~MatMulExecution() {
    // do nothing
}

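// setArguments builds the CUTLASS GEMM argument structs once device addresses are
// known. The branches below mirror the flags computed in onResize: mFp32Infer picks
// the FP32 CUDA-core kernels, devices below SM75 fall back to CUDA-core kernels, and
// SM75+ uses the tensor-core kernels; within each branch, mUseRRLayout, hAlignment and
// mConvertGemmSplitK select the Row-Row vs Row-Column layout and the split-K variant.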
void MatMulExecution::setArguments(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto runtime = static_cast<CUDABackend*>(backend())->getCUDARuntime();
    auto bytes = static_cast<CUDABackend*>(backend())->getBytes(inputs[0]);
    auto pool = static_cast<CUDABackend*>(backend())->getBufferPool();

    const Tensor* A = inputs[0];
    const Tensor* B = inputs[1];
    auto C = outputs[0];
    bool hAlignment = (mGemmInfo.elhPad[2] == mGemmInfo.elh[2]);

    ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
    ElementComputeEpilogue beta = ElementComputeEpilogue(0);

    // Split K dimension into 1 partitions
    cutlass::gemm::GemmCoord problem_size(mGemmInfo.elh[0], mGemmInfo.elh[2], mGemmInfo.elhPad[1]); // m n k

    if (inputs.size() > 2) {
        mBiasPtr = (void*)inputs[2]->deviceId();
        beta = ElementComputeEpilogue(1);
    }

    if(mFp32Infer) {
        if(mUseRRLayout) {
            typename GemmBatchedCuda_F32_F32_Linear_AlignCuda_Row_Row::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                {(ElementInput_F32 *)mTempMatA, mGemmInfo.elhPad[1]},  // Ptr + ldm
                (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1] * mAs),  // batch_stride_A
                {(ElementInput_F32 *)mTempMatB, mGemmInfo.elhPad[2]},  // Ptr + ldm
                (int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elhPad[2] * mBs),  // batch_stride_B
                {(ElementOutput_F32 *)mBiasPtr, 0},  // Ptr + ldm if ldm = 0, vector,
                (int64_t)(0),  // batch_stride_bias
                {(ElementOutput_F32 *)C->deviceId(), mGemmInfo.elhPad[2]},  // Ptr + ldm
                (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[2]),  // batch_stride_C
                {alpha, beta},  // <- tuple of alpha and beta
                mBatch};  // batch_count

            size_t workspace_size = GemmBatchedCuda_F32_F32_Linear_AlignCuda_Row_Row::get_workspace_size(arguments);
            if(workspace_size != 0) {
                workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
                mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
                mWorkspace = (void *)workspaceTensor.get()->buffer().device;
            }
            // Check the problem size is supported or not
            cutlass::Status status = mGemmBatchedCudaF32F32LnAlign1RR.can_implement(arguments);
            cutlass_check(status);

            // Initialize CUTLASS kernel with arguments and workspace pointer
            status = mGemmBatchedCudaF32F32LnAlign1RR.initialize(arguments, (uint8_t *)mWorkspace);
            cutlass_check(status);
        } else {
            typename GemmBatchedCuda_F32_F32_Linear_AlignCuda_Row_Column::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                {(ElementInput_F32 *)mTempMatA, mGemmInfo.elhPad[1]},  // Ptr + ldm
                (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1] * mAs),  // batch_stride_A
                {(ElementInput_F32 *)mTempMatB, mGemmInfo.elhPad[1]},  // Ptr + ldm
                (int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elh[2] * mBs),  // batch_stride_B
                {(ElementOutput_F32 *)mBiasPtr, 0},  // Ptr + ldm if ldm = 0, vector,
                (int64_t)(0),  // batch_stride_bias
                {(ElementOutput_F32 *)C->deviceId(), mGemmInfo.elh[2]},  // Ptr + ldm
                (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elh[2]),  // batch_stride_C
                {alpha, beta},  // <- tuple of alpha and beta
                mBatch};  // batch_count

            size_t workspace_size = GemmBatchedCuda_F32_F32_Linear_AlignCuda_Row_Column::get_workspace_size(arguments);
            if(workspace_size != 0) {
                workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
                mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
                mWorkspace = (void *)workspaceTensor.get()->buffer().device;
            }
            // Check the problem size is supported or not
            cutlass::Status status = mGemmBatchedCudaF32F32LnAlign1RC.can_implement(arguments);
            cutlass_check(status);

            // Initialize CUTLASS kernel with arguments and workspace pointer
            status = mGemmBatchedCudaF32F32LnAlign1RC.initialize(arguments, (uint8_t *)mWorkspace);
            cutlass_check(status);
        }
        return;
    }

    mGpuComputeCap = runtime->compute_capability();
    //MNN_PRINT("Gpu smArch is sm_%d\n", mGpuComputeCap);

    if(mGpuComputeCap < 75) {
        if(mFp16Infer) {
            if(mUseRRLayout) {
                typename GemmBatchedCuda_F16_F16_Linear_AlignCuda_Row_Row::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                    {(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]},  // Ptr + ldm
                    (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1] * mAs),  // batch_stride_A
                    {(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[2]},  // Ptr + ldm
                    (int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elhPad[2] * mBs),  // batch_stride_B
                    {(ElementOutput_F16 *)mBiasPtr, 0},  // Ptr + ldm if ldm = 0, vector,
                    (int64_t)(0),  // batch_stride_bias
                    {(ElementOutput_F16 *)C->deviceId(), mGemmInfo.elhPad[2]},  // Ptr + ldm
                    (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[2]),  // batch_stride_C
                    {alpha, beta},  // <- tuple of alpha and beta
                    mBatch};  // batch_count

                size_t workspace_size = GemmBatchedCuda_F16_F16_Linear_AlignCuda_Row_Row::get_workspace_size(arguments);
                if(workspace_size != 0) {
                    workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
                    mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
                    mWorkspace = (void *)workspaceTensor.get()->buffer().device;
                }
                // Check the problem size is supported or not
                cutlass::Status status = mGemmBatchedCudaF16F16LnAlign1RR.can_implement(arguments);
                cutlass_check(status);

                // Initialize CUTLASS kernel with arguments and workspace pointer
                status = mGemmBatchedCudaF16F16LnAlign1RR.initialize(arguments, (uint8_t *)mWorkspace);
                cutlass_check(status);
            } else {
                typename GemmBatchedCuda_F16_F16_Linear_AlignCuda_Row_Column::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                    {(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]},  // Ptr + ldm
                    (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1] * mAs),  // batch_stride_A
                    {(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[1]},  // Ptr + ldm
                    (int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elh[2] * mBs),  // batch_stride_B
                    {(ElementOutput_F16 *)mBiasPtr, 0},  // Ptr + ldm if ldm = 0, vector,
                    (int64_t)(0),  // batch_stride_bias
                    {(ElementOutput_F16 *)C->deviceId(), mGemmInfo.elh[2]},  // Ptr + ldm
                    (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elh[2]),  // batch_stride_C
                    {alpha, beta},  // <- tuple of alpha and beta
                    mBatch};  // batch_count

                size_t workspace_size = GemmBatchedCuda_F16_F16_Linear_AlignCuda_Row_Column::get_workspace_size(arguments);
                if(workspace_size != 0) {
                    workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
                    mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
                    mWorkspace = (void *)workspaceTensor.get()->buffer().device;
                }
                // Check the problem size is supported or not
                cutlass::Status status = mGemmBatchedCudaF16F16LnAlign1RC.can_implement(arguments);
                cutlass_check(status);

                // Initialize CUTLASS kernel with arguments and workspace pointer
                status = mGemmBatchedCudaF16F16LnAlign1RC.initialize(arguments, (uint8_t *)mWorkspace);
                cutlass_check(status);
            }
        } else {
            if(mUseRRLayout) {
                if(mNeedConvertMatAB) {
                    typename GemmBatchedCuda_F16_F32_Linear_AlignCuda_Row_Row::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                        {(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]},  // Ptr + ldm
                        (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1] * mAs),  // batch_stride_A
                        {(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[2]},  // Ptr + ldm
                        (int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elhPad[2] * mBs),  // batch_stride_B
                        {(ElementOutput_F32 *)mBiasPtr, 0},  // Ptr + ldm if ldm = 0, vector,
                        (int64_t)(0),  // batch_stride_bias
                        {(ElementOutput_F32 *)C->deviceId(), mGemmInfo.elhPad[2]},  // Ptr + ldm
                        (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[2]),  // batch_stride_C
                        {alpha, beta},  // <- tuple of alpha and beta
                        mBatch};  // batch_count

                    size_t workspace_size = GemmBatchedCuda_F16_F32_Linear_AlignCuda_Row_Row::get_workspace_size(arguments);
                    if(workspace_size != 0) {
                        workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
                        mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
                        mWorkspace = (void *)workspaceTensor.get()->buffer().device;
                    }
                    // Check the problem size is supported or not
                    cutlass::Status status = mGemmBatchedCudaF16F32LnAlign1RR.can_implement(arguments);
                    cutlass_check(status);

                    // Initialize CUTLASS kernel with arguments and workspace pointer
                    status = mGemmBatchedCudaF16F32LnAlign1RR.initialize(arguments, (uint8_t *)mWorkspace);
                    cutlass_check(status);
                } else {
                    typename GemmBatchedCuda_F32_F32_Linear_AlignCuda_Row_Row::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                        {(ElementInput_F32 *)mTempMatA, mGemmInfo.elhPad[1]},  // Ptr + ldm
                        (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1] * mAs),  // batch_stride_A
                        {(ElementInput_F32 *)mTempMatB, mGemmInfo.elhPad[2]},  // Ptr + ldm
                        (int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elhPad[2] * mBs),  // batch_stride_B
                        {(ElementOutput_F32 *)mBiasPtr, 0},  // Ptr + ldm if ldm = 0, vector,
                        (int64_t)(0),  // batch_stride_bias
                        {(ElementOutput_F32 *)C->deviceId(), mGemmInfo.elhPad[2]},  // Ptr + ldm
                        (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[2]),  // batch_stride_C
                        {alpha, beta},  // <- tuple of alpha and beta
                        mBatch};  // batch_count

                    size_t workspace_size = GemmBatchedCuda_F32_F32_Linear_AlignCuda_Row_Row::get_workspace_size(arguments);
                    if(workspace_size != 0) {
                        workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
                        mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
                        mWorkspace = (void *)workspaceTensor.get()->buffer().device;
                    }
                    // Check the problem size is supported or not
                    cutlass::Status status = mGemmBatchedCudaF32F32LnAlign1RR.can_implement(arguments);
                    cutlass_check(status);

                    // Initialize CUTLASS kernel with arguments and workspace pointer
                    status = mGemmBatchedCudaF32F32LnAlign1RR.initialize(arguments, (uint8_t *)mWorkspace);
                    cutlass_check(status);
                }
            } else {
                if(mNeedConvertMatAB) {
                    typename GemmBatchedCuda_F16_F32_Linear_AlignCuda_Row_Column::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                        {(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]},  // Ptr + ldm
                        (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1] * mAs),  // batch_stride_A
                        {(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[1]},  // Ptr + ldm
                        (int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elh[2] * mBs),  // batch_stride_B
                        {(ElementOutput_F32 *)mBiasPtr, 0},  // Ptr + ldm if ldm = 0, vector,
                        (int64_t)(0),  // batch_stride_bias
                        {(ElementOutput_F32 *)C->deviceId(), mGemmInfo.elh[2]},  // Ptr + ldm
                        (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elh[2]),  // batch_stride_C
                        {alpha, beta},  // <- tuple of alpha and beta
                        mBatch};  // batch_count

                    size_t workspace_size = GemmBatchedCuda_F16_F32_Linear_AlignCuda_Row_Column::get_workspace_size(arguments);
                    if(workspace_size != 0) {
                        workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
                        mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
                        mWorkspace = (void *)workspaceTensor.get()->buffer().device;
                    }
                    // Check the problem size is supported or not
                    cutlass::Status status = mGemmBatchedCudaF16F32LnAlign1RC.can_implement(arguments);
                    cutlass_check(status);

                    // Initialize CUTLASS kernel with arguments and workspace pointer
                    status = mGemmBatchedCudaF16F32LnAlign1RC.initialize(arguments, (uint8_t *)mWorkspace);
                    cutlass_check(status);
                } else {
                    typename GemmBatchedCuda_F32_F32_Linear_AlignCuda_Row_Column::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                        {(ElementInput_F32 *)mTempMatA, mGemmInfo.elhPad[1]},  // Ptr + ldm
                        (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1] * mAs),  // batch_stride_A
                        {(ElementInput_F32 *)mTempMatB, mGemmInfo.elhPad[1]},  // Ptr + ldm
                        (int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elh[2] * mBs),  // batch_stride_B
                        {(ElementOutput_F32 *)mBiasPtr, 0},  // Ptr + ldm if ldm = 0, vector,
                        (int64_t)(0),  // batch_stride_bias
                        {(ElementOutput_F32 *)C->deviceId(), mGemmInfo.elh[2]},  // Ptr + ldm
                        (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elh[2]),  // batch_stride_C
                        {alpha, beta},  // <- tuple of alpha and beta
                        mBatch};  // batch_count

                    size_t workspace_size = GemmBatchedCuda_F32_F32_Linear_AlignCuda_Row_Column::get_workspace_size(arguments);
                    if(workspace_size != 0) {
                        workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
                        mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
                        mWorkspace = (void *)workspaceTensor.get()->buffer().device;
                    }
                    // Check the problem size is supported or not
                    cutlass::Status status = mGemmBatchedCudaF32F32LnAlign1RC.can_implement(arguments);
                    cutlass_check(status);

                    // Initialize CUTLASS kernel with arguments and workspace pointer
                    status = mGemmBatchedCudaF32F32LnAlign1RC.initialize(arguments, (uint8_t *)mWorkspace);
                    cutlass_check(status);
                }
            }
        }
        return;
    }
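
    // From here on the device is SM75 or newer, so the tensor-core CUTLASS kernels
    // (AlignTensor / AlignCuda variants) are configured instead of the CUDA-core ones.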
    if(mFp16Infer) {
        if(mUseRRLayout) {
            typename GemmBatchedTensor_F16_F16_Linear_AlignTensor_Row_Row_Sm75::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                {(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]},  // Ptr + ldm
                (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1] * mAs),  // batch_stride_A
                {(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[2]},  // Ptr + ldm
                (int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elhPad[2] * mBs),  // batch_stride_B
                {(ElementOutput_F16 *)mBiasPtr, 0},  // Ptr + ldm if ldm = 0, vector,
                (int64_t)(0),  // batch_stride_bias
                {(ElementOutput_F16 *)C->deviceId(), mGemmInfo.elhPad[2]},  // Ptr + ldm
                (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[2]),  // batch_stride_C
                {alpha, beta},  // <- tuple of alpha and beta
                mBatch};  // batch_count

            size_t workspace_size = GemmBatchedTensor_F16_F16_Linear_AlignTensor_Row_Row_Sm75::get_workspace_size(arguments);
            if(workspace_size != 0) {
                workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
                mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
                mWorkspace = (void *)workspaceTensor.get()->buffer().device;
            }
            // Check the problem size is supported or not
            cutlass::Status status = mGemmBatchedF16F16LnAlign8RRSm75.can_implement(arguments);
            cutlass_check(status);

            // Initialize CUTLASS kernel with arguments and workspace pointer
            status = mGemmBatchedF16F16LnAlign8RRSm75.initialize(arguments, (uint8_t *)mWorkspace);
            cutlass_check(status);
        } else {
            if(hAlignment) {
                if(mConvertGemmSplitK) {
                    int split_k_slices = 16;
                    typename GemmTensor_F16_F16_Linear_AlignTensor_Sm75::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                        {(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]},  // Ptr + ldm
                        {(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[1]},  // Ptr + ldm
                        {(ElementOutput_F16 *)mBiasPtr, 0},  // Ptr + ldm if ldm = 0, vector,
                        {(ElementOutput_F16 *)C->deviceId(), mGemmInfo.elh[2]},  // Ptr + ldm
                        {alpha, beta},  // <- tuple of alpha and beta
                        split_k_slices};  // <- k-dimension split factor
                    size_t workspace_size = GemmTensor_F16_F16_Linear_AlignTensor_Sm75::get_workspace_size(arguments);

                    if(workspace_size != 0) {
                        workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
                        mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
                        mWorkspace = (void *)workspaceTensor.get()->buffer().device;
                    }

                    cutlass::Status status = mGemmF16F16LnAlign8Sm75.can_implement(arguments);
                    cutlass_check(status);

                    // Initialize CUTLASS kernel with arguments and workspace pointer
                    status = mGemmF16F16LnAlign8Sm75.initialize(arguments, (uint8_t *)mWorkspace);
                    cutlass_check(status);
                } else {
                    typename GemmBatchedTensor_F16_F16_Linear_AlignTensor_Row_Column_Sm75::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                        {(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]},  // Ptr + ldm
                        (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1] * mAs),  // batch_stride_A
                        {(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[1]},  // Ptr + ldm
                        (int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elh[2] * mBs),  // batch_stride_B
                        {(ElementOutput_F16 *)mBiasPtr, 0},  // Ptr + ldm if ldm = 0, vector,
                        (int64_t)(0),  // batch_stride_bias
                        {(ElementOutput_F16 *)C->deviceId(), mGemmInfo.elh[2]},  // Ptr + ldm
                        (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elh[2]),  // batch_stride_C
                        {alpha, beta},  // <- tuple of alpha and beta
                        mBatch};  // batch_count

                    size_t workspace_size = GemmBatchedTensor_F16_F16_Linear_AlignTensor_Row_Column_Sm75::get_workspace_size(arguments);

                    if(workspace_size != 0) {
                        workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
                        mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
                        mWorkspace = (void *)workspaceTensor.get()->buffer().device;
                    }
                    // Check the problem size is supported or not
                    cutlass::Status status = mGemmBatchedF16F16LnAlign8RCSm75.can_implement(arguments);
                    cutlass_check(status);

                    // Initialize CUTLASS kernel with arguments and workspace pointer
                    status = mGemmBatchedF16F16LnAlign8RCSm75.initialize(arguments, (uint8_t *)mWorkspace);
                    cutlass_check(status);
                }
            } else {
                if(mConvertGemmSplitK) {
                    int split_k_slices = 16;
                    typename GemmTensor_F16_F16_Linear_AlignCuda_Sm75::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                        {(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]},  // Ptr + ldm
                        {(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[1]},  // Ptr + ldm
                        {(ElementOutput_F16 *)mBiasPtr, 0},  // Ptr + ldm if ldm = 0, vector,
                        {(ElementOutput_F16 *)C->deviceId(), mGemmInfo.elh[2]},  // Ptr + ldm
                        {alpha, beta},  // <- tuple of alpha and beta
                        split_k_slices};  // <- k-dimension split factor
                    size_t workspace_size = GemmTensor_F16_F16_Linear_AlignCuda_Sm75::get_workspace_size(arguments);

                    if(workspace_size != 0) {
                        workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
                        mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
                        mWorkspace = (void *)workspaceTensor.get()->buffer().device;
                    }

                    cutlass::Status status = mGemmF16F16LnAlign1Sm75.can_implement(arguments);
                    cutlass_check(status);

                    // Initialize CUTLASS kernel with arguments and workspace pointer
                    status = mGemmF16F16LnAlign1Sm75.initialize(arguments, (uint8_t *)mWorkspace);
                    cutlass_check(status);
                } else {
                    typename GemmBatchedTensor_F16_F16_Linear_AlignCuda_Row_Column_Sm75::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                        {(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]},  // Ptr + ldm
                        (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1] * mAs),  // batch_stride_A
                        {(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[1]},  // Ptr + ldm
                        (int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elh[2] * mBs),  // batch_stride_B
                        {(ElementOutput_F16 *)mBiasPtr, 0},  // Ptr + ldm if ldm = 0, vector,
                        (int64_t)(0),  // batch_stride_bias
                        {(ElementOutput_F16 *)C->deviceId(), mGemmInfo.elh[2]},  // Ptr + ldm
                        (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elh[2]),  // batch_stride_C
                        {alpha, beta},  // <- tuple of alpha and beta
                        mBatch};  // batch_count

                    size_t workspace_size = GemmBatchedTensor_F16_F16_Linear_AlignCuda_Row_Column_Sm75::get_workspace_size(arguments);

                    if(workspace_size != 0) {
                        workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
                        mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
                        mWorkspace = (void *)workspaceTensor.get()->buffer().device;
                    }
                    // Check the problem size is supported or not
                    cutlass::Status status = mGemmBatchedF16F16LnAlign1RCSm75.can_implement(arguments);
                    cutlass_check(status);

                    // Initialize CUTLASS kernel with arguments and workspace pointer
                    status = mGemmBatchedF16F16LnAlign1RCSm75.initialize(arguments, (uint8_t *)mWorkspace);
                    cutlass_check(status);
                }
            }
        }
    } else {
        if(mUseRRLayout) {
            if(mNeedConvertMatAB) {
                typename GemmBatchedTensor_F16_F32_Linear_AlignTensor_Row_Row_Sm75::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                    {(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]},  // Ptr + ldm
                    (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1] * mAs),  // batch_stride_A
                    {(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[2]},  // Ptr + ldm
                    (int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elhPad[2] * mBs),  // batch_stride_B
                    {(ElementOutput_F32 *)mBiasPtr, 0},  // Ptr + ldm if ldm = 0, vector,
                    (int64_t)(0),  // batch_stride_bias
                    {(ElementOutput_F32 *)C->deviceId(), mGemmInfo.elhPad[2]},  // Ptr + ldm
                    (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[2]),  // batch_stride_C
                    {alpha, beta},  // <- tuple of alpha and beta
                    mBatch};  // batch_count

                size_t workspace_size = GemmBatchedTensor_F16_F32_Linear_AlignTensor_Row_Row_Sm75::get_workspace_size(arguments);
                if(workspace_size != 0) {
                    workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
                    mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
                    mWorkspace = (void *)workspaceTensor.get()->buffer().device;
                }
                // Check the problem size is supported or not
                cutlass::Status status = mGemmBatchedF16F32LnAlign8RRSm75.can_implement(arguments);
                cutlass_check(status);

                // Initialize CUTLASS kernel with arguments and workspace pointer
                status = mGemmBatchedF16F32LnAlign8RRSm75.initialize(arguments, (uint8_t *)mWorkspace);
                cutlass_check(status);
            } else {
                typename GemmBatchedTensor_F32_F32_Linear_AlignTensor_Row_Row_Sm75::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                    {(ElementInput_F32 *)mTempMatA, mGemmInfo.elhPad[1]},  // Ptr + ldm
                    (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1] * mAs),  // batch_stride_A
                    {(ElementInput_F32 *)mTempMatB, mGemmInfo.elhPad[2]},  // Ptr + ldm
                    (int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elhPad[2] * mBs),  // batch_stride_B
                    {(ElementOutput_F32 *)mBiasPtr, 0},  // Ptr + ldm if ldm = 0, vector,
                    (int64_t)(0),  // batch_stride_bias
                    {(ElementOutput_F32 *)C->deviceId(), mGemmInfo.elhPad[2]},  // Ptr + ldm
                    (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[2]),  // batch_stride_C
                    {alpha, beta},  // <- tuple of alpha and beta
                    mBatch};  // batch_count

                size_t workspace_size = GemmBatchedTensor_F32_F32_Linear_AlignTensor_Row_Row_Sm75::get_workspace_size(arguments);
                if(workspace_size != 0) {
                    workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
                    mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
                    mWorkspace = (void *)workspaceTensor.get()->buffer().device;
                }
                // Check the problem size is supported or not
                cutlass::Status status = mGemmBatchedF32F32LnAlign8RRSm75.can_implement(arguments);
                cutlass_check(status);

                // Initialize CUTLASS kernel with arguments and workspace pointer
                status = mGemmBatchedF32F32LnAlign8RRSm75.initialize(arguments, (uint8_t *)mWorkspace);
                cutlass_check(status);
            }
        } else {
            if(hAlignment) {
                if(mNeedConvertMatAB) {
                    if(mConvertGemmSplitK) {
                        int split_k_slices = 16;
                        typename GemmTensor_F16_F32_Linear_AlignTensor_Sm75::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                            {(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]},  // Ptr + ldm
                            {(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[1]},  // Ptr + ldm
                            {(ElementOutput_F32 *)mBiasPtr, 0},  // Ptr + ldm if ldm = 0, vector,
                            {(ElementOutput_F32 *)C->deviceId(), mGemmInfo.elh[2]},  // Ptr + ldm
                            {alpha, beta},  // <- tuple of alpha and beta
                            split_k_slices};  // <- k-dimension split factor
                        size_t workspace_size = GemmTensor_F16_F32_Linear_AlignTensor_Sm75::get_workspace_size(arguments);

                        if(workspace_size != 0) {
                            workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
                            mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
                            mWorkspace = (void *)workspaceTensor.get()->buffer().device;
                        }

                        cutlass::Status status = mGemmF16F32LnAlign8Sm75.can_implement(arguments);
                        cutlass_check(status);

                        // Initialize CUTLASS kernel with arguments and workspace pointer
                        status = mGemmF16F32LnAlign8Sm75.initialize(arguments, (uint8_t *)mWorkspace);
                        cutlass_check(status);
                    } else {
                        typename GemmBatchedTensor_F16_F32_Linear_AlignTensor_Row_Column_Sm75::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                            {(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]},  // Ptr + ldm
                            (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1] * mAs),  // batch_stride_A
                            {(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[1]},  // Ptr + ldm
                            (int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elh[2] * mBs),  // batch_stride_B
                            {(ElementOutput_F32 *)mBiasPtr, 0},  // Ptr + ldm if ldm = 0, vector,
                            (int64_t)(0),  // batch_stride_bias
                            {(ElementOutput_F32 *)C->deviceId(), mGemmInfo.elh[2]},  // Ptr + ldm
                            (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elh[2]),  // batch_stride_C
                            {alpha, beta},  // <- tuple of alpha and beta
                            mBatch};  // batch_count

                        size_t workspace_size = GemmBatchedTensor_F16_F32_Linear_AlignTensor_Row_Column_Sm75::get_workspace_size(arguments);

                        if(workspace_size != 0) {
                            workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
                            mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
                            mWorkspace = (void *)workspaceTensor.get()->buffer().device;
                        }
                        // Check the problem size is supported or not
                        cutlass::Status status = mGemmBatchedF16F32LnAlign8RCSm75.can_implement(arguments);
                        cutlass_check(status);

                        // Initialize CUTLASS kernel with arguments and workspace pointer
                        status = mGemmBatchedF16F32LnAlign8RCSm75.initialize(arguments, (uint8_t *)mWorkspace);
                        cutlass_check(status);
                    }
                } else {
                    if(mConvertGemmSplitK) {
                        int split_k_slices = 16;
                        typename GemmTensor_F32_F32_Linear_AlignTensor_Sm75::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                            {(ElementInput_F32 *)mTempMatA, mGemmInfo.elhPad[1]},  // Ptr + ldm
                            {(ElementInput_F32 *)mTempMatB, mGemmInfo.elhPad[1]},  // Ptr + ldm
                            {(ElementOutput_F32 *)mBiasPtr, 0},  // Ptr + ldm if ldm = 0, vector,
                            {(ElementOutput_F32 *)C->deviceId(), mGemmInfo.elh[2]},  // Ptr + ldm
                            {alpha, beta},  // <- tuple of alpha and beta
                            split_k_slices};  // <- k-dimension split factor
                        size_t workspace_size = GemmTensor_F32_F32_Linear_AlignTensor_Sm75::get_workspace_size(arguments);

                        if(workspace_size != 0) {
                            workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
                            mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
                            mWorkspace = (void *)workspaceTensor.get()->buffer().device;
                        }

                        cutlass::Status status = mGemmF32F32LnAlign8Sm75.can_implement(arguments);
                        cutlass_check(status);

                        // Initialize CUTLASS kernel with arguments and workspace pointer
                        status = mGemmF32F32LnAlign8Sm75.initialize(arguments, (uint8_t *)mWorkspace);
                        cutlass_check(status);
                    } else {
                        typename GemmBatchedTensor_F32_F32_Linear_AlignTensor_Row_Column_Sm75::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                            {(ElementInput_F32 *)mTempMatA, mGemmInfo.elhPad[1]},  // Ptr + ldm
                            (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1] * mAs),  // batch_stride_A
                            {(ElementInput_F32 *)mTempMatB, mGemmInfo.elhPad[1]},  // Ptr + ldm
                            (int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elh[2] * mBs),  // batch_stride_B
                            {(ElementOutput_F32 *)mBiasPtr, 0},  // Ptr + ldm if ldm = 0, vector,
                            (int64_t)(0),  // batch_stride_bias
                            {(ElementOutput_F32 *)C->deviceId(), mGemmInfo.elh[2]},  // Ptr + ldm
                            (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elh[2]),  // batch_stride_C
                            {alpha, beta},  // <- tuple of alpha and beta
                            mBatch};  // batch_count

                        size_t workspace_size = GemmBatchedTensor_F32_F32_Linear_AlignTensor_Row_Column_Sm75::get_workspace_size(arguments);

                        if(workspace_size != 0) {
                            workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
                            mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
                            mWorkspace = (void *)workspaceTensor.get()->buffer().device;
                        }
                        // Check the problem size is supported or not
                        cutlass::Status status = mGemmBatchedF32F32LnAlign8RCSm75.can_implement(arguments);
                        cutlass_check(status);

                        // Initialize CUTLASS kernel with arguments and workspace pointer
                        status = mGemmBatchedF32F32LnAlign8RCSm75.initialize(arguments, (uint8_t *)mWorkspace);
                        cutlass_check(status);
                    }
                }
            } else {
                if(mNeedConvertMatAB) {
                    if(mConvertGemmSplitK) {
                        int split_k_slices = 16;
                        typename GemmTensor_F16_F32_Linear_AlignCuda_Sm75::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                            {(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]},  // Ptr + ldm
                            {(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[1]},  // Ptr + ldm
                            {(ElementOutput_F32 *)mBiasPtr, 0},  // Ptr + ldm if ldm = 0, vector,
                            {(ElementOutput_F32 *)C->deviceId(), mGemmInfo.elh[2]},  // Ptr + ldm
                            {alpha, beta},  // <- tuple of alpha and beta
                            split_k_slices};  // <- k-dimension split factor
                        size_t workspace_size = GemmTensor_F16_F32_Linear_AlignCuda_Sm75::get_workspace_size(arguments);

                        if(workspace_size != 0) {
                            workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
                            mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
                            mWorkspace = (void *)workspaceTensor.get()->buffer().device;
                        }

                        cutlass::Status status = mGemmF16F32LnAlign1Sm75.can_implement(arguments);
                        cutlass_check(status);

                        // Initialize CUTLASS kernel with arguments and workspace pointer
                        status = mGemmF16F32LnAlign1Sm75.initialize(arguments, (uint8_t *)mWorkspace);
                        cutlass_check(status);
                    } else {
                        typename GemmBatchedTensor_F16_F32_Linear_AlignCuda_Row_Column_Sm75::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                            {(ElementInput_F16 *)mTempMatA, mGemmInfo.elhPad[1]},  // Ptr + ldm
                            (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1] * mAs),  // batch_stride_A
                            {(ElementInput_F16 *)mTempMatB, mGemmInfo.elhPad[1]},  // Ptr + ldm
                            (int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elh[2] * mBs),  // batch_stride_B
                            {(ElementOutput_F32 *)mBiasPtr, 0},  // Ptr + ldm if ldm = 0, vector,
                            (int64_t)(0),  // batch_stride_bias
                            {(ElementOutput_F32 *)C->deviceId(), mGemmInfo.elh[2]},  // Ptr + ldm
                            (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elh[2]),  // batch_stride_C
                            {alpha, beta},  // <- tuple of alpha and beta
                            mBatch};  // batch_count

                        size_t workspace_size = GemmBatchedTensor_F16_F32_Linear_AlignCuda_Row_Column_Sm75::get_workspace_size(arguments);

                        if(workspace_size != 0) {
                            workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
                            mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
                            mWorkspace = (void *)workspaceTensor.get()->buffer().device;
                        }
                        // Check the problem size is supported or not
                        cutlass::Status status = mGemmBatchedF16F32LnAlign1RCSm75.can_implement(arguments);
                        cutlass_check(status);

                        // Initialize CUTLASS kernel with arguments and workspace pointer
                        status = mGemmBatchedF16F32LnAlign1RCSm75.initialize(arguments, (uint8_t *)mWorkspace);
                        cutlass_check(status);
                    }
                } else {
                    if(mConvertGemmSplitK) {
                        int split_k_slices = 16;
                        typename GemmTensor_F32_F32_Linear_AlignCuda_Sm75::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                            {(ElementInput_F32 *)mTempMatA, mGemmInfo.elhPad[1]},  // Ptr + ldm
                            {(ElementInput_F32 *)mTempMatB, mGemmInfo.elhPad[1]},  // Ptr + ldm
                            {(ElementOutput_F32 *)mBiasPtr, 0},  // Ptr + ldm if ldm = 0, vector,
                            {(ElementOutput_F32 *)C->deviceId(), mGemmInfo.elh[2]},  // Ptr + ldm
                            {alpha, beta},  // <- tuple of alpha and beta
                            split_k_slices};  // <- k-dimension split factor
                        size_t workspace_size = GemmTensor_F32_F32_Linear_AlignCuda_Sm75::get_workspace_size(arguments);

                        if(workspace_size != 0) {
                            workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
                            mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
                            mWorkspace = (void *)workspaceTensor.get()->buffer().device;
                        }

                        cutlass::Status status = mGemmF32F32LnAlign1Sm75.can_implement(arguments);
                        cutlass_check(status);

                        // Initialize CUTLASS kernel with arguments and workspace pointer
                        status = mGemmF32F32LnAlign1Sm75.initialize(arguments, (uint8_t *)mWorkspace);
                        cutlass_check(status);
                    } else {
                        typename GemmBatchedTensor_F32_F32_Linear_AlignCuda_Row_Column_Sm75::Arguments arguments{problem_size,  // <- problem size of matrix multiplication
                            {(ElementInput_F32 *)mTempMatA, mGemmInfo.elhPad[1]},  // Ptr + ldm
                            (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elhPad[1] * mAs),  // batch_stride_A
                            {(ElementInput_F32 *)mTempMatB, mGemmInfo.elhPad[1]},  // Ptr + ldm
                            (int64_t)(mGemmInfo.elhPad[1] * mGemmInfo.elh[2] * mBs),  // batch_stride_B
                            {(ElementOutput_F32 *)mBiasPtr, 0},  // Ptr + ldm if ldm = 0, vector,
                            (int64_t)(0),  // batch_stride_bias
                            {(ElementOutput_F32 *)C->deviceId(), mGemmInfo.elh[2]},  // Ptr + ldm
                            (int64_t)(mGemmInfo.elh[0] * mGemmInfo.elh[2]),  // batch_stride_C
                            {alpha, beta},  // <- tuple of alpha and beta
                            mBatch};  // batch_count

                        size_t workspace_size = GemmBatchedTensor_F32_F32_Linear_AlignCuda_Row_Column_Sm75::get_workspace_size(arguments);

                        if(workspace_size != 0) {
                            workspaceTensor.reset(Tensor::createDevice<int8_t>({(int)workspace_size}));
                            mBackend->onAcquireBuffer(workspaceTensor.get(), Backend::STATIC);
                            mWorkspace = (void *)workspaceTensor.get()->buffer().device;
                        }
                        // Check the problem size is supported or not
                        cutlass::Status status = mGemmBatchedF32F32LnAlign1RCSm75.can_implement(arguments);
                        cutlass_check(status);

                        // Initialize CUTLASS kernel with arguments and workspace pointer
                        status = mGemmBatchedF32F32LnAlign1RCSm75.initialize(arguments, (uint8_t *)mWorkspace);
                        cutlass_check(status);
                    }
                }
            }
        }
    }
}

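// onResize computes the GEMM shape (e, l, h) and its padded counterparts, decides
// whether A/B must be repacked into temporary buffers, reserves those buffers from the
// backend pool, and prepares the CUTLASS arguments ahead of time when the output
// address is already known.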
ErrorCode MatMulExecution::onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto runtime = static_cast<CUDABackend*>(backend())->getCUDARuntime();
    auto bytes = static_cast<CUDABackend*>(backend())->getBytes(inputs[0]);

    const Tensor* A = inputs[0];
    const Tensor* B = inputs[1];
    auto C = outputs[0];
    auto dimensions = C->dimensions();
    mBatch = 1;
    for (int i = 0; i < dimensions - 2; ++i) {
        mBatch *= C->length(i);
    }
    auto e = C->length(dimensions-2);
    auto h = C->length(dimensions-1);
    auto w0 = inputs[0]->length(dimensions-1);
    auto h0 = inputs[0]->length(dimensions-2);

    auto l = w0;
    if (mTransposeA) {
        l = h0;
    }

    mGemmInfo.elh[0] = e;
    mGemmInfo.elh[1] = l;
    mGemmInfo.elh[2] = h;
    mGemmInfo.elhPad[0] = UP_DIV(e, PACK_NUMBER) * PACK_NUMBER;
    mGemmInfo.elhPad[1] = UP_DIV(l, PACK_NUMBER) * PACK_NUMBER;
    mGemmInfo.elhPad[2] = UP_DIV(h, PACK_NUMBER) * PACK_NUMBER;

    bool lAlignment = (mGemmInfo.elhPad[1] == mGemmInfo.elh[1]);
    bool hAlignment = (mGemmInfo.elhPad[2] == mGemmInfo.elh[2]);
    bool needBTranspose = (!mTransposeB && !hAlignment);

    mUseRRLayout = (!mTransposeB && hAlignment);
    mNeedATempBuffer = (mTransposeA || !lAlignment);
    mNeedBTempBuffer = (needBTranspose || !lAlignment);
    mNeedConvertMatAB = (mNeedATempBuffer || mNeedBTempBuffer);

    // MNN_PRINT("trAtrB:%d-%d, tmpAB:%d-%d inps:%d, bwlh:%d-%d-%d-%d\n", mTransposeA, mTransposeB, mNeedATempBuffer, mNeedBTempBuffer, inputs.size(), mBatch, mGemmInfo.elh[0], mGemmInfo.elh[1], mGemmInfo.elh[2]);

    auto pool = static_cast<CUDABackend*>(backend())->getBufferPool();
    MemChunk bufferAData, bufferBData;
    size_t convertBytes = 2;
    if(mFp32Infer) {
        convertBytes = 4;
    }
    if((mNeedConvertMatAB && mFp16Fp32MixInfer) || mNeedATempBuffer) {
        bufferAData = pool->alloc(convertBytes * mBatch * mAs * mGemmInfo.elh[0] * mGemmInfo.elhPad[1]);
        mTempMatA = (void*)bufferAData.ptr();
    } else {
        mTempMatA = (void *)A->deviceId();
    }

    if((mNeedConvertMatAB && mFp16Fp32MixInfer) || mNeedBTempBuffer) {
        bufferBData = pool->alloc(convertBytes * mBatch * mBs * mGemmInfo.elh[2] * mGemmInfo.elhPad[1]);
        mTempMatB = (void*)bufferBData.ptr();
    } else {
        mTempMatB = (void *)B->deviceId();
    }

    if(bufferAData.first != nullptr) {
        pool->free(bufferAData);
    }
    if(bufferBData.first != nullptr) {
        pool->free(bufferBData);
    }

    // Only two inputs means no bias; a fake address for mBiasPtr is fine because beta is zero.
    if(inputs.size() == 2) {
        mBiasPtr = (void*)B->deviceId();
    }
    //printf("MatMulAB:%p-%p-%p-%p\n", A->host<void*>(), A->deviceId(), B->host<void*>(), B->deviceId());

    mConvertGemmSplitK = ((mBatch == 1) && (mGemmInfo.elhPad[1] >= 16384));
    // Set the CUTLASS param arguments now if all device addresses are already known.
    mResizeSetArgument = (mTempMatA != nullptr && mTempMatB != nullptr && C->deviceId() != 0);
    if(mResizeSetArgument) {
        setArguments(inputs, outputs);
    }

    return NO_ERROR;
}

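// onExecute runs PackPadFill first when A/B need repacking, then invokes the CUTLASS
// GEMM object that setArguments initialized for the current precision, layout and
// compute capability.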
ErrorCode MatMulExecution::onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) {
    auto bytes = static_cast<CUDABackend*>(backend())->getBytes(inputs[0]);
    auto runtime = static_cast<CUDABackend*>(backend())->getCUDARuntime();
    bool hAlignment = (mGemmInfo.elhPad[2] == mGemmInfo.elh[2]);

    // PreProcess for alignment: repack and pad A/B into the temp buffers if needed.
    if(mNeedConvertMatAB) {
        int aBatch = mBatch;
        int bBatch = mBatch;
        if (mAs == 0) {
            aBatch = 1;
        }
        if (mBs == 0) {
            bBatch = 1;
        }
        DivModFast eD(mGemmInfo.elh[0]);
        DivModFast lD(mGemmInfo.elh[1]);
        DivModFast hD(mGemmInfo.elh[2]);
        DivModFast lpD((mGemmInfo.elhPad[1]));
        DivModFast lp2D((mGemmInfo.elhPad[1]/2));

        auto& prop = runtime->prop();
        int block_num = prop.multiProcessorCount;
        int block_size = prop.maxThreadsPerBlock;
        if(mFp32Infer) {
            PackPadFill<<<block_num, block_size>>>((const float*)inputs[0]->deviceId(), (const float*)inputs[1]->deviceId(),
                mTransposeA, mTransposeB, (float*)mTempMatA, (float*)mTempMatB,
                aBatch, bBatch, mGemmInfo.elh[0], mGemmInfo.elh[1], mGemmInfo.elh[2],
                mGemmInfo.elhPad[0], mGemmInfo.elhPad[1], mGemmInfo.elhPad[2],
                eD, lD, hD, lpD, lp2D);
            checkKernelErrors;
        } else if(mFp16Fp32MixInfer) {
            PackPadFill<<<block_num, block_size>>>((const float*)inputs[0]->deviceId(), (const float*)inputs[1]->deviceId(),
                mTransposeA, mTransposeB, (half*)mTempMatA, (half*)mTempMatB,
                aBatch, bBatch, mGemmInfo.elh[0], mGemmInfo.elh[1], mGemmInfo.elh[2],
                mGemmInfo.elhPad[0], mGemmInfo.elhPad[1], mGemmInfo.elhPad[2],
                eD, lD, hD, lpD, lp2D);
            checkKernelErrors;
        } else {
            PackPadFill<<<block_num, block_size>>>((const half*)inputs[0]->deviceId(), (const half*)inputs[1]->deviceId(),
                mTransposeA, mTransposeB, (half*)mTempMatA, (half*)mTempMatB,
                aBatch, bBatch, mGemmInfo.elh[0], mGemmInfo.elh[1], mGemmInfo.elh[2],
                mGemmInfo.elhPad[0], mGemmInfo.elhPad[1], mGemmInfo.elhPad[2],
                eD, lD, hD, lpD, lp2D);
            checkKernelErrors;
        }
    }

    if(!mResizeSetArgument) {
        // Set the CUTLASS arguments here if they could not be set during onResize.
        //printf("argment onexecute set\n");
        if(!mNeedConvertMatAB) {
            mTempMatA = (void *)inputs[0]->deviceId();
            mTempMatB = (void *)inputs[1]->deviceId();
        }
        setArguments(inputs, outputs);
    }

    if(mFp32Infer) {
        if(mUseRRLayout) {
            cutlass::Status status = mGemmBatchedCudaF32F32LnAlign1RR();
            cutlass_check(status);
        } else {
            cutlass::Status status = mGemmBatchedCudaF32F32LnAlign1RC();
            cutlass_check(status);
        }
        return NO_ERROR;
    }

    if(mGpuComputeCap < 75) {
        if (mFp16Fp32MixInfer) {
            if(mUseRRLayout) {
                if(mNeedConvertMatAB) {
                    cutlass::Status status = mGemmBatchedCudaF16F32LnAlign1RR();
                    cutlass_check(status);
                } else {
                    cutlass::Status status = mGemmBatchedCudaF32F32LnAlign1RR();
                    cutlass_check(status);
                }
            } else {
                if(mNeedConvertMatAB) {
                    cutlass::Status status = mGemmBatchedCudaF16F32LnAlign1RC();
                    cutlass_check(status);
                } else {
                    cutlass::Status status = mGemmBatchedCudaF32F32LnAlign1RC();
                    cutlass_check(status);
                }
            }
        } else {
            if(mUseRRLayout) {
                cutlass::Status status = mGemmBatchedCudaF16F16LnAlign1RR();
                cutlass_check(status);
            } else {
                cutlass::Status status = mGemmBatchedCudaF16F16LnAlign1RC();
                cutlass_check(status);
            }
        }
        return NO_ERROR;
    }

    if (mFp16Fp32MixInfer) {
        if(mUseRRLayout) {
            if(mNeedConvertMatAB) {
                cutlass::Status status = mGemmBatchedF16F32LnAlign8RRSm75();
                cutlass_check(status);
            } else {
                cutlass::Status status = mGemmBatchedF32F32LnAlign8RRSm75();
                cutlass_check(status);
            }
        } else {
            if(hAlignment) {
                if(mNeedConvertMatAB) {
                    if(mConvertGemmSplitK) {
                        cutlass::Status status = mGemmF16F32LnAlign8Sm75();
                        cutlass_check(status);
                    } else {
                        cutlass::Status status = mGemmBatchedF16F32LnAlign8RCSm75();
                        cutlass_check(status);
                    }
                } else {
                    if(mConvertGemmSplitK) {
                        cutlass::Status status = mGemmF32F32LnAlign8Sm75();
                        cutlass_check(status);
                    } else {
                        cutlass::Status status = mGemmBatchedF32F32LnAlign8RCSm75();
                        cutlass_check(status);
                    }
                }
            } else {
                if(mNeedConvertMatAB) {
                    if(mConvertGemmSplitK) {
                        cutlass::Status status = mGemmF16F32LnAlign1Sm75();
                        cutlass_check(status);
                    } else {
                        cutlass::Status status = mGemmBatchedF16F32LnAlign1RCSm75();
                        cutlass_check(status);
                    }
                } else {
                    if(mConvertGemmSplitK) {
                        cutlass::Status status = mGemmF32F32LnAlign1Sm75();
                        cutlass_check(status);
                    } else {
                        cutlass::Status status = mGemmBatchedF32F32LnAlign1RCSm75();
                        cutlass_check(status);
                    }
                }
            }
        }
    } else {
        if(mUseRRLayout) {
            cutlass::Status status = mGemmBatchedF16F16LnAlign8RRSm75();
            cutlass_check(status);
        } else {
            if(hAlignment) {
                if(mConvertGemmSplitK) {
                    cutlass::Status status = mGemmF16F16LnAlign8Sm75();
                    cutlass_check(status);
                } else {
                    cutlass::Status status = mGemmBatchedF16F16LnAlign8RCSm75();
                    cutlass_check(status);
                }
            } else {
                if(mConvertGemmSplitK) {
                    cutlass::Status status = mGemmF16F16LnAlign1Sm75();
                    cutlass_check(status);
                } else {
                    cutlass::Status status = mGemmBatchedF16F16LnAlign1RCSm75();
                    cutlass_check(status);
                }
            }
        }
    }
    // printf("normal:%d rrlayout:%d convertab:%d halign:%d\n", mFp16Fp32MixInfer, mUseRRLayout, mNeedConvertMatAB, hAlignment);
    return NO_ERROR;
}

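// Register the creator so that MatMul ops are lowered to MatMulExecution on the CUDA
// backend, using the transpose flags stored in the op parameters.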
class MatMulCreator : public CUDABackend::Creator {
public:
    virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                const MNN::Op* op, Backend* backend) const override {
        auto param = op->main_as_MatMul();
        return new MatMulExecution(param->transposeA(), param->transposeB(), backend);
    }
};

static CUDACreatorRegister<MatMulCreator> __init(OpType_MatMul);

} // namespace CUDA
} // namespace MNN