//
//  CUDABackend.cpp
//  MNN
//
//  Created by MNN on 2019/02/28.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include "backend/cuda/core/CUDABackend.hpp"
#include "MNN_generated.h"

#include <map>
#include <mutex>
#include "core/Macro.h"
#include "shape/SizeComputer.hpp"
#include "core/TensorUtils.hpp"
#include "execution/Raster.cuh"
#include "execution/Transpose.cuh"
#include "execution/MNNCUDADefine.hpp"
#include "execution/CastExecution.hpp"
#include "CUDATools.hpp"
#include "execution/FuseExecutionV2.hpp"
// #define MNN_CUDA_COPY_DEBUG

namespace MNN {
namespace CUDA {

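// Registry mapping each OpType to the creator of its CUDA execution, lazily created
// and guarded by std::call_once so concurrent initialization stays thread-safe.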
std::map<OpType, CUDABackend::Creator*>* gCreator() {
    static std::map<OpType, CUDABackend::Creator*>* creators = nullptr;
    static std::once_flag gOnce;
    std::call_once(gOnce, [&]() { creators = new std::map<OpType, CUDABackend::Creator*>; });
    return creators;
};

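// Adapter exposing CUDARuntime's raw alloc/free through the BufferAllocator::Allocator
// interface, so MNN's generic buffer pools can manage CUDA device memory.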
class CUDARuntimeAllocator : public BufferAllocator::Allocator {
public:
    CUDARuntimeAllocator(CUDARuntime* rt) : mRuntime(rt) {
        // Do nothing
    }
    virtual ~CUDARuntimeAllocator() = default;
    virtual MemChunk onAlloc(size_t size, size_t align) override {
        return MemChunk(mRuntime->alloc(size), 0);
    }
    virtual void onRelease(MemChunk ptr) override {
        mRuntime->free(ptr.first);
    }
private:
    CUDARuntime* mRuntime;
};

CUDARuntimeWrapper::CUDARuntimeWrapper(BackendConfig::PrecisionMode precision, BackendConfig::PowerMode power, BackendConfig::MemoryMode memory, int deviceId) {
    // TODO: Search CUDA Device info and use best one
    mCUDARuntime.reset(new CUDARuntime(deviceId));
#ifdef LOG_VERBOSE
    MNN_PRINT("create cuda runtime:%p\n", mCUDARuntime.get());
#endif
    if (mCUDARuntime.get()) {
        if (mCUDARuntime->isCreateError() == true) {
            mIsCreateError = true;
            return;
        }
        std::shared_ptr<BufferAllocator::Allocator> allocator(new CUDARuntimeAllocator(mCUDARuntime.get()));
        mBufferPool.reset(new EagerBufferAllocator(allocator));
    }
    mDefaultPrecision = precision;
    mDefaultMemory = memory;
}

CUDARuntimeWrapper::~CUDARuntimeWrapper() {
    // Do nothing
}
float CUDARuntimeWrapper::onGetMemoryInMB() {
    auto staticMemoryInMB = mBufferPool->totalSize() / 1024.0f / 1024.0f;
    return staticMemoryInMB;
}

std::pair<const void*, size_t> CUDARuntimeWrapper::onGetCache() { // make Cache
    return mCUDARuntime->makeCache();
}

bool CUDARuntimeWrapper::onSetCache(const void* buffer, size_t size) { // set Cache
    return mCUDARuntime->setCache(std::make_pair(buffer, size));
}

Backend* CUDARuntimeWrapper::onCreate(const BackendConfig* config, Backend* origin) const {
#ifdef LOG_VERBOSE
    MNN_PRINT("cudaruntime:%p, create CUDABackend\n", this);
#endif
    auto precision_mode = mDefaultPrecision;
    auto memory_mode = mDefaultMemory;
    if (nullptr != config) {
        precision_mode = config->precision;
        memory_mode = config->memory;
    }
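    // Map the configured precision mode onto the integer code used across the CUDA backend:
    // 0 = fp32 (Precision_Normal), 1 = high precision, 2 = fp16 (Precision_Low), 3 = bf16 (Precision_Low_BF16).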
    int precision = 0;
    if (precision_mode == BackendConfig::Precision_Low) {
        precision = 2;
    } else if (precision_mode == BackendConfig::Precision_Normal) {
        precision = 0;
    } else if (precision_mode == BackendConfig::Precision_Low_BF16) {
        precision = 3;
    } else {
        precision = 1;
    }

    auto backend = new CUDABackend(this, mBufferPool, mCUDARuntime, precision, memory_mode);
    backend->setMetaPtr(pMeta);
    return backend;
}

void CUDARuntimeWrapper::onGabageCollect(int level) {
    mBufferPool->release(false);
}

CUDABackend::CUDABackend(const Runtime* runtime,
                         std::shared_ptr<BufferAllocator> st,
                         std::shared_ptr<CUDARuntime> rt,
                         int precision, BackendConfig::MemoryMode memory)
    : Backend(MNN_FORWARD_CUDA) {
#ifdef LOG_VERBOSE
    MNN_PRINT("cuda backend create\n");
#endif
    mBufferPool.reset(new EagerBufferAllocator(BufferAllocator::Allocator::createRecurse(st.get())));
    mRuntime = runtime;
    mStaticBufferPool = st;
    mCUDARuntime = rt;
    mUseFp16AsFp32 = (precision == 2);
    mPrecision = precision;
    mMemory = memory;
}

CUDABackend::~CUDABackend() {
#ifdef LOG_VERBOSE
    MNN_PRINT("enter CUDABackend::~CUDABackend \n");
#endif
}

CUDARuntime* CUDABackend::getCUDARuntime() {
    MNN_ASSERT(nullptr != mCUDARuntime.get());
    return mCUDARuntime.get();
}
const Runtime* CUDABackend::getRuntime() {
    return (const Runtime*)mRuntime;
}
bool CUDABackend::useFp16() const {
    return mUseFp16AsFp32;
}

#ifdef MNN_CODEGEN_CUDA
std::map<std::pair<std::string, std::string>, CUmodule> CUDABackend::kernelCuModuleMap() {
    return mKernelCuModuleMap;
}
#endif

int CUDABackend::getPrecision() const {
    return mPrecision;
}

BackendConfig::MemoryMode CUDABackend::getMemoryMode() const {
    return mMemory;
}

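// Memory object handed back by onAcquire; returns the backing MemChunk to its
// BufferAllocator when destroyed.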
class CUDAMemObj : public Backend::MemObj {
public:
    CUDAMemObj(BufferAllocator* allocator, MemChunk points) {
        mPoint = std::move(points);
        mAllocator = allocator;
    }
    virtual ~CUDAMemObj() {
        mAllocator->free(mPoint);
    }
    MemChunk chunk() override {
        return mPoint;
    }
private:
    BufferAllocator* mAllocator;
    MemChunk mPoint;
};

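// Per-element byte size: float tensors take 2 bytes under fp16/bf16 precision, and
// int8-quantized tensors take 1 byte regardless of precision mode.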
size_t CUDABackend::getBytes(const Tensor* tensor) const {
    size_t bytes = tensor->getType().bytes();
    if (mPrecision == 2 || mPrecision == 3) { // Fp16 or Bf16
        if (halide_type_float == tensor->getType().code) {
            bytes = 2;
        }
    }
    auto quant = TensorUtils::getDescribe(tensor)->quantAttr.get();
    if (nullptr != quant && TensorUtils::getDescribe(tensor)->type == DataType_DT_INT8) {
        bytes = 1;
    }
    return bytes;
}
CPUResizeCache* CUDABackend::getCache() {
    return &mCache;
}

Backend::MemObj* CUDABackend::onAcquire(const Tensor* nativeTensor, StorageType storageType) {
    // MNN_PRINT("onAcquire CUDA memory for tensor:%p\n", nativeTensor);
#ifdef LOG_VERBOSE
    MNN_PRINT("Start CUDABackend::onAcquireBuffer !\n");
#endif
    BufferAllocator* allocator = nullptr;
    auto bytes = getBytes(nativeTensor);
    size_t mallocSize = realSize(nativeTensor) * bytes;

    MemChunk buffer;
    if (storageType == DYNAMIC_SEPERATE) {
        buffer = mBufferPool->alloc(mallocSize, true);
        allocator = mBufferPool.get();
    } else if (storageType == DYNAMIC) {
        buffer = mBufferPool->alloc(mallocSize, false);
        allocator = mBufferPool.get();
    } else {
        MNN_ASSERT(storageType == STATIC);
        buffer = mStaticBufferPool->alloc(mallocSize, false);
        allocator = mStaticBufferPool.get();
    }
    if (nullptr == buffer.first) {
        return nullptr;
    };
    auto host = buffer.ptr();
    ((Tensor*)nativeTensor)->buffer().device = (uint64_t)host;
    auto des = TensorUtils::getDescribe(nativeTensor);
    des->extra.offset = buffer.second;
    return new CUDAMemObj(allocator, buffer);
}

bool CUDABackend::onClearBuffer() {
    mCache.reset();
    mBufferPool->release(true);
    return true;
}

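// Element count used for allocation. For NC4HW4 tensors the channel axis (dimension 1) is
// rounded up to the pack size: PACK_NUMBER, or INT8_PACK_NUMBER for int8 / 1-byte data.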
size_t CUDABackend::realSize(const Tensor* tensor) {
    auto dim = TensorUtils::getDescribe(tensor)->dimensionFormat;
    int pack = 1;
    if (dim == MNN_DATA_FORMAT_NC4HW4) {
        pack = PACK_NUMBER;
        if (getDataType(tensor) == DataType_DT_INT8 || tensor->getType().bytes() == 1) {
            pack = INT8_PACK_NUMBER;
        }
    }
    size_t res = 1;
    for (int i = 0; i < tensor->dimensions(); ++i) {
        size_t l = tensor->length(i);
        if (1 == i) {
            l = UP_DIV(l, pack) * pack;
        }
        res *= l;
    }
    return res;
}

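// When the output tensor is int8-quantized, switch float convolution ops to their int8
// counterparts so the quantized kernels are chosen.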
static OpType _getRealOpType(OpType opType) {
    switch (opType) {
        case OpType_Convolution:
            return OpType_ConvInt8;
        case OpType_ConvolutionDepthwise:
            return OpType_DepthwiseConvInt8;
        case OpType_BinaryOp:
        default:
            return opType;
    }
}

#ifdef MNN_CODEGEN_CUDA
void CUDABackend::compile(CUmodule* dst, std::pair<string, string> code, std::vector<const char*> compile_params) {
    std::vector<const char*> param;
    auto ptx_code =
        CUDANVRTCCompile(code, param, mCUDARuntime->compute_capability(), false);

    MNN_CUDA_SAFE_CALL(cuModuleLoadDataEx(dst, ptx_code.c_str(), 0, 0, 0));
}
#endif

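// Create an Execution for the given op: pick the registered creator (switching to the
// int8 variant for quantized outputs) and, when codegen is enabled, JIT-compile any
// OpType_Extra kernel source through NVRTC and cache the resulting CUmodule.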
Execution* CUDABackend::onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                 const MNN::Op* op) {
    // #ifdef LOG_VERBOSE
    // MNN_PRINT("Start CUDABackend::onCreate useFp16:%d\n", useFp16());
    // #endif
    auto opType = op->type();
    if (outputs.size() > 0) {
        if (TensorUtils::getDescribe(outputs[0])->quantAttr != nullptr && TensorUtils::getDescribe(outputs[0])->type == DataType_DT_INT8) {
            opType = _getRealOpType(opType);
        }
    }
    // MNN_PRINT("CUDABackend support type %s\n", EnumNameOpType(opType));
    auto creators = gCreator();
    auto iter = creators->find(opType);
    if (iter == creators->end()) {
        if (nullptr != op->name()) {
            MNN_PRINT("CUDABackend Don't support type %s, %s\n", EnumNameOpType(opType), op->name()->c_str());
        } else {
            MNN_PRINT("CUDABackend Don't support type %s\n", EnumNameOpType(opType));
        }
        return NULL;
    }

#ifdef MNN_CODEGEN_CUDA
    if (op->type() == OpType_Extra) {
        if (!FuseExecutionV2::check(op)) {
            auto extra = op->main_as_Extra();
            std::string source(reinterpret_cast<const char*>(extra->info()->data()));
            auto kernel_name = extra->type()->c_str();
            std::string kernel_source = source;

            std::pair<std::string, std::string> kernelInfo = std::make_pair<std::string, std::string>(kernel_name, kernel_source.c_str());
            if (mKernelCuModuleMap.find(kernelInfo) == mKernelCuModuleMap.end()) {
                // printf("\n%s\n\n%s !!!!\n", kernel_source.c_str(), kernel_name);
                std::vector<const char*> param;
                bool includeHeadFile = mUseFp16AsFp32;
                auto ptx_code =
                    CUDANVRTCCompile(kernelInfo, param, mCUDARuntime->compute_capability(), includeHeadFile);

                MNN_CUDA_SAFE_CALL(cuModuleLoadDataEx(&mCuModule, ptx_code.c_str(), 0, 0, 0));
                mKernelCuModuleMap.insert(std::pair<std::pair<std::string, std::string>, CUmodule>(kernelInfo, mCuModule));
            }
        }
    }
#endif

    auto exe = iter->second->onCreate(inputs, outputs, op, this);
    if (NULL == exe) {
        if (nullptr != op->name()) {
            MNN_PRINT("CUDABackend The Creator Don't support type %s, %s\n", EnumNameOpType(opType), op->name()->c_str());
        } else {
            MNN_PRINT("CUDABackend The Creator Don't support type %s\n", EnumNameOpType(opType));
        }
        return NULL;
    }
#ifdef LOG_VERBOSE
    MNN_PRINT("End CUDABackend::onCreate \n");
#endif

    return exe;
}

void CUDABackend::onResizeBegin() {
}

ErrorCode CUDABackend::onResizeEnd() {
    return NO_ERROR;
}

void CUDABackend::onExecuteBegin() const {
}

void CUDABackend::onExecuteEnd() const {
}

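// Helpers for onCopyBuffer: derive per-layout strides and split a tensor's shape into
// (batch, channel, plane) according to its dimension format.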
static void _computeStride(MNN_DATA_FORMAT srcDimensionFormat, int* srcStride, int batch, int plane, int channel, int srcPack) {
    if (srcDimensionFormat == MNN_DATA_FORMAT_NC4HW4) {
        srcStride[0] = plane * srcPack;
        srcStride[1] = plane * batch * PACK_NUMBER;
        srcStride[2] = srcPack;
    } else if (srcDimensionFormat == MNN_DATA_FORMAT_NCHW) {
        srcStride[0] = channel * plane;
        srcStride[1] = plane * PACK_NUMBER;
        srcStride[2] = 1;
    } else {
        srcStride[0] = channel * plane;
        srcStride[1] = PACK_NUMBER;
        srcStride[2] = channel;
    }
}

static void _computeBCA(int& batch, int& plane, int& channel, MNN_DATA_FORMAT srcDimensionFormat, const Tensor* srcTensor) {
    if (srcTensor->dimensions() == 0) {
        batch = 1;
        plane = 1;
        channel = 1;
        return;
    }

    if (srcDimensionFormat != MNN_DATA_FORMAT_NHWC) {
        batch = srcTensor->length(0);
        channel = 1;
        if (srcTensor->dimensions() > 1) {
            channel = srcTensor->length(1);
        }
        plane = 1;
        for (int i = 2; i < srcTensor->dimensions(); ++i) {
            plane *= srcTensor->length(i);
        }
    } else {
        batch = srcTensor->length(0);
        channel = 1;
        if (srcTensor->dimensions() > 1) {
            channel = srcTensor->length(srcTensor->dimensions() - 1);
        }
        plane = 1;
        for (int i = 1; i < srcTensor->dimensions() - 1; ++i) {
            plane *= srcTensor->length(i);
        }
    }
}

static PackInfo _computePackInfo(MNN_DATA_FORMAT srcDimensionFormat, int batch, int plane, int channel) {
    PackInfo pack;
    pack.inside = plane;
    pack.axis = channel;
    pack.unit = PACK_NUMBER;
    pack.outside = batch;
    if (srcDimensionFormat == MNN_DATA_FORMAT_NHWC) {
        pack.axisStride = 1;
        pack.insideStride = channel;
    } else {
        pack.axisStride = plane;
        pack.insideStride = 1;
    }
    return pack;
}

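// Copy a tensor between host and device. When layouts and data types already match, a single
// memcpy is enough; otherwise the data is staged through device scratch buffers, cast between
// float and int8 when the data types differ, and converted with FormatConvert for the layout change.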
void CUDABackend::onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTensor) const {
    auto srcDimensionFormat = TensorUtils::getDescribe(srcTensor)->dimensionFormat;
    auto dstDimensionFormat = TensorUtils::getDescribe(dstTensor)->dimensionFormat;
    auto srcIndex = TensorUtils::getDescribe(srcTensor)->index;
    auto dstIndex = TensorUtils::getDescribe(dstTensor)->index;
    auto srcDevice = (srcTensor->deviceId() != 0 && srcTensor->deviceId() != 1);
    auto dstDevice = (dstTensor->deviceId() != 0 && dstTensor->deviceId() != 1);
    MNN_ASSERT(srcDevice || dstDevice);
    uint8_t* srcPtr = nullptr;
    MemChunk tempSrcStorage;
    auto bytes = getBytes(srcTensor);
    auto type = srcTensor->getType();

    // MNN_PRINT("%d-%d\n", srcTensor->dimensions(), dstTensor->dimensions());
    bool directCopy = ((srcDimensionFormat == dstDimensionFormat && dstDimensionFormat != MNN_DATA_FORMAT_NC4HW4) || srcTensor->dimensions() <= 1) && \
        (getDataType(srcTensor) == getDataType(dstTensor));
    if (mPrecision == 2 || mPrecision == 3) { // Fp16 or Bf16
        if (((!srcDevice) || (!dstDevice))) {
            if (type.code == halide_type_float) {
                directCopy = false;
            }
        }
    }

#ifdef MNN_CUDA_COPY_DEBUG
    checkKernelErrors;
    MNN_PRINT("CUDA Bn copy tensor ptr:%p -> ptr:%p deviceId:%d -> %d, hostPtr:%p -> %p, graphIndex: %d -> %d, format %d -> %d, directCopy: %d, dims: [",
              srcTensor, dstTensor, srcTensor->deviceId(), dstTensor->deviceId(), srcTensor->host<void>(), dstTensor->host<void>(), srcIndex, dstIndex, srcDimensionFormat, dstDimensionFormat, directCopy);

    for (int i = 0; i < srcTensor->dimensions(); ++i) {
        MNN_PRINT("%d ", srcTensor->length(i));
        if (srcDevice && !dstDevice) {
            MNN_PRINT("\n");
        }
    }
    MNN_PRINT("], ");
    MNN_PRINT("addr:%p %p\n", srcTensor->deviceId(), dstTensor->deviceId());
#endif

    // printf("MNN srcDevice:%d %llu, dstDevice:%d %llu, directCopy:%d\n", srcDevice, srcTensor->deviceId(), dstDevice, dstTensor->deviceId(), directCopy);
    if (directCopy) {
        auto gpuSize = realSize(srcTensor) * getBytes(srcTensor);
        if (srcDevice && dstDevice) {
            NVTX_PUSH("DtoD");
            mCUDARuntime->memcpy((void*)(dstTensor->deviceId()), (void*)(srcTensor->deviceId()), gpuSize,
                                 MNNMemcpyDeviceToDevice, true);
            NVTX_POP();
        } else if (srcDevice && (!dstDevice)) {
            NVTX_PUSH("DtoH");
            mCUDARuntime->memcpy((void*)(dstTensor->host<void>()), (void*)(srcTensor->deviceId()), gpuSize,
                                 MNNMemcpyDeviceToHost, true);
            NVTX_POP();
        } else if ((!srcDevice) && (dstDevice)) {
            NVTX_PUSH("HtoD");
            mCUDARuntime->memcpy((void*)(dstTensor->deviceId()), (void*)(srcTensor->host<void>()), gpuSize,
                                 MNNMemcpyHostToDevice, true);
            NVTX_POP();
        }
        return;
    }
    if (!srcDevice) {
        auto cpuSize = srcTensor->size();
        tempSrcStorage = mStaticBufferPool->alloc(cpuSize);
        srcPtr = tempSrcStorage.ptr();
        mCUDARuntime->memcpy(srcPtr, srcTensor->host<void>(), cpuSize, MNNMemcpyHostToDevice,
                             true);
    } else {
        srcPtr = (uint8_t*)srcTensor->deviceId();
    }
    uint8_t* dstPtr = nullptr;
    MemChunk tempDstStorage;
    if (!dstDevice) {
        auto cpuSize = dstTensor->size();
        tempDstStorage = mStaticBufferPool->alloc(cpuSize);
        dstPtr = tempDstStorage.ptr();
    } else {
        dstPtr = (uint8_t*)dstTensor->deviceId();
    }

    NVTX_PUSH("copy convert");
    // Format convert
    int batch, plane, channel;
    _computeBCA(batch, plane, channel, srcDimensionFormat, srcTensor);

    // for (int i=0; i<srcTensor->dimensions(); ++i) {
    //     MNN_PRINT("%d ", srcTensor->length(i));
    // }
    // MNN_PRINT("\n, batch:%d, plane:%d, channel:%d, dims:%d\n", batch, plane, channel, srcTensor->dimensions());
    // MNN_PRINT("oncopybuffer dateType:%d->%d format:%d->%d\n", getDataType(srcTensor), getDataType(dstTensor), srcDimensionFormat, dstDimensionFormat);

    std::unique_ptr<Tensor> wrapTensor;
    MemChunk wrapSrcStorage;
    if (getDataType(srcTensor) != getDataType(dstTensor)) {
        auto dimType = Tensor::CAFFE;
        switch (TensorUtils::getDescribe(srcTensor)->dimensionFormat) {
            case MNN_DATA_FORMAT_NCHW:
                break;
            case MNN_DATA_FORMAT_NC4HW4:
                dimType = Tensor::CAFFE_C4;
                break;
            case MNN_DATA_FORMAT_NHWC:
                dimType = Tensor::TENSORFLOW;
                break;
            default:
                break;
        }

        auto convertType = CastCreator::FlOAT_TO_INT8;
        if (getDataType(srcTensor) == DataType_DT_INT8) {
            convertType = CastCreator::INT8_TO_FlOAT;
        }

        wrapTensor.reset(Tensor::createDevice(srcTensor->shape(), dstTensor->getType(), dimType));
        wrapSrcStorage = mStaticBufferPool->alloc(realSize(wrapTensor.get()) * getBytes(dstTensor));
        // MNN_PRINT("warp:%d %d %d %d\n", realSize(wrapTensor.get()), getBytes(dstTensor), dstTensor->getType(), srcTensor->getDimensionType());
        wrapTensor.get()->buffer().device = (uint64_t)(wrapSrcStorage.ptr());

        auto dstType = getDataType(dstTensor);
        if (dstType != DataType_DT_FLOAT) {
            wrapTensor->setType(dstType);
        }

#ifdef LOG_VERBOSE
        MNN_PRINT("CPU backend copy tensor ptr:%p -> ptr:%p hostPtr:%p -> %p, format %d -> %d, dims: [",
                  srcTensor, dstTensor, srcTensor->host<void>(), dstTensor->host<void>(), TensorUtils::getDescribe(srcTensor)->dimensionFormat, TensorUtils::getDescribe(dstTensor)->dimensionFormat);
        for (int i = 0; i < srcTensor->dimensions(); ++i) {
            MNN_PRINT("%d ", srcTensor->length(i));
        }
        MNN_PRINT("]\n");
#endif

        auto code = CastCreator::cast(srcTensor, wrapTensor.get(), (Backend*)this, convertType);
        if (NO_ERROR != code) {
            MNN_ERROR("Error in CudaBackend::onCopyBuffer:cast\n");
        }
        srcTensor = wrapTensor.get();
        srcPtr = (uint8_t*)srcTensor->deviceId();
    }

    FormatConvert((float*)dstPtr, (float*)srcPtr, srcDimensionFormat, dstDimensionFormat, mCUDARuntime.get(), \
                  plane, batch, channel, srcTensor, \
                  mPrecision, srcDevice, dstDevice);

    if (!srcDevice) {
        mStaticBufferPool->free(tempSrcStorage);
    }
    if (!dstDevice) {
        auto cpuSize = dstTensor->size();
        mCUDARuntime->memcpy(dstTensor->host<void>(), dstPtr, cpuSize, MNNMemcpyDeviceToHost,
                             true);
        mStaticBufferPool->free(tempDstStorage);
    }
    NVTX_POP();
    return;
}

int CUDABackend::onSync(Tensor::MapType mtype, bool toCpu, const Tensor* dstTensor) {
    if (toCpu) {
        mCUDARuntime->device_sync();
    }
    return 0;
}

DataType CUDABackend::getDataType(const Tensor* tensor) {
    auto des = TensorUtils::getDescribe(tensor);
    if (nullptr == des->quantAttr.get()) {
        return DataType_DT_FLOAT;
    }
    return des->type;
}

ErrorCode CastWrapExecution::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    auto convertType = mRunType == DataType_DT_INT8 ? CastCreator::FlOAT_TO_INT8 : CastCreator::INT8_TO_FlOAT;
    auto cudaBackend = ((CUDABackend*)backend());
    CastCreator::cast(inputs[0], outputs[0], cudaBackend, convertType);
    return NO_ERROR;
}

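// Registers the creator used to build executions for an op type; duplicates are rejected.
// Typical usage (a hypothetical sketch, not code from this file): an op implementation defines a
// CUDABackend::Creator subclass and registers it from a static initializer, e.g.
//   static bool gRegistered = CUDABackend::addCreator(OpType_Convolution, new ConvolutionCreator);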
bool CUDABackend::addCreator(OpType t, Creator* c) {
    auto map = gCreator();
    if (map->find(t) != map->end()) {
        MNN_PRINT("Error: %d type has be added\n", t);
        return false;
    }
    map->insert(std::make_pair(t, c));
    return true;
}

} // namespace CUDA
} // namespace MNN