MNN/source/backend/cuda/core/CUDABackend.hpp

//
// CUDABackend.hpp
// MNN
//
// Created by MNN on 2019/01/31.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifndef CUDABackend_hpp
#define CUDABackend_hpp
#include <set>
#include <vector>
#include "MNN_generated.h"
#include "backend/cuda/core/runtime/CUDARuntime.hpp"
#include "core/Backend.hpp"
#include "core/Macro.h"
#include "core/ConvolutionCommon.hpp"
#include "core/BufferAllocator.hpp"
namespace MNN {
namespace CUDA {
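/** Runtime wrapper for the CUDA backend: sets up the shared CUDARuntime and a buffer
    allocator, and creates CUDABackend instances on top of them. isCreateError()
    reports whether runtime creation failed. */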
class MNN_PUBLIC CUDARuntimeWrapper : public Runtime {
public:
    CUDARuntimeWrapper(BackendConfig::PrecisionMode precision, BackendConfig::PowerMode power);
    virtual ~CUDARuntimeWrapper();
    virtual Backend *onCreate() const override;
    virtual void onGabageCollect(int level) override;
    bool isCreateError() const {
        return mIsCreateError;
    }
private:
    std::shared_ptr<BufferAllocator> mBufferPool;
    std::shared_ptr<CUDARuntime> mCUDARuntime;
    bool mIsCreateError{false};
};
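
/** Backend implementation for CUDA devices. Manages a dynamic buffer pool and a static
    buffer pool for device memory, creates per-op Execution objects through registered
    Creators, and copies tensors between host and device in onCopyBuffer. */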
class CUDABackend final : public Backend {
public:
    CUDABackend(std::shared_ptr<BufferAllocator> st, std::shared_ptr<CUDARuntime> rt);
    ~CUDABackend();
    CUDARuntime *getCUDARuntime();
    virtual bool onAcquireBuffer(const Tensor *nativeTensor, StorageType storageType) override;
    virtual bool onReleaseBuffer(const Tensor *nativeTensor, StorageType storageType) override;
    virtual bool onClearBuffer() override;
    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
                                const MNN::Op *op) override;
    virtual void onExecuteBegin() const override;
    virtual void onExecuteEnd() const override;
    virtual void onCopyBuffer(const Tensor *srcTensor, const Tensor *dstTensor) const override;
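
    /** Factory interface for per-op Execution objects; implementations are registered
        per OpType via addCreator() and used by the backend's onCreate(). */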
    class Creator {
    public:
        virtual ~Creator() = default;
        virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &output,
                                    const MNN::Op *op, Backend *backend) const = 0;
    };
    static bool addCreator(OpType t, Creator *c);
    BufferAllocator *getBufferPool() const {
        return mBufferPool.get();
    }
    BufferAllocator *getStaticBufferPool() const {
        return mStaticBufferPool.get();
    }
    virtual std::pair<float, bool> onMeasure(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
                                             const MNN::Op *op) override;
    static size_t realSize(const Tensor *tensor);
private:
    std::shared_ptr<BufferAllocator> mBufferPool;
    std::shared_ptr<BufferAllocator> mStaticBufferPool;
    std::shared_ptr<CUDARuntime> mCUDARuntime;
};
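
/** Helper that registers a Creator for the given OpType at static-initialization time.
    Illustrative usage (names are placeholders):
        static CUDACreatorRegister<TypedCreator<SomeExecution>> gSomeOp(OpType_XXX); */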
template <class T>
class CUDACreatorRegister {
public:
    CUDACreatorRegister(OpType type) {
        T *t = new T;
        CUDABackend::addCreator(type, t);
    }
    ~CUDACreatorRegister() = default;
};
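
/** Generic Creator that simply constructs T(inputs, op, backend) for the requested op. */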
template <typename T>
class TypedCreator : public CUDABackend::Creator {
public:
    virtual ~TypedCreator() = default;
    virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
                                const MNN::Op *op, Backend *backend) const override {
        return new T(inputs, op, backend);
    }
};
} // namespace CUDA
} // namespace MNN
#endif /* CUDABackend_hpp */