MNN/source/backend/metal/MetalBackend.hpp

158 lines
4.9 KiB
C++
Raw Normal View History

2019-04-17 10:49:11 +08:00
//
// MetalBackend.hpp
// MNN
//
// Created by MNN on 2019/01/30.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifndef MetalBackend_hpp
#define MetalBackend_hpp
2019-12-27 22:16:57 +08:00
#include "core/Backend.hpp"
2019-04-17 10:49:11 +08:00
#include "MNN_generated.h"
#include "MetalDefine.h"
2020-11-05 16:41:56 +08:00
#include <map>
#include <memory>
#include <utility>
#include <vector>
2019-04-17 10:49:11 +08:00
#if MNN_METAL_ENABLED
namespace MNN {
2020-11-05 16:41:56 +08:00
/**
 * Metal runtime.
 * Derives from Runtime and owns the state that MetalBackend instances reach
 * through a `const MetalRuntime*`: an opaque context pointer, two
 * BufferAllocator instances and a host buffer handle.
 */
class MetalRuntime : public Runtime {
public:
    // MetalBackend needs access to the private allocators / context below.
    friend class MetalBackend;

    /**
     * Allocator for MTLBuffer objects.
     * NOTE(review): the member names suggest that live buffers are tracked
     * with their size and that released buffers are kept for reuse keyed by
     * size; the implementation is not part of this header - confirm there.
     */
    class BufferAllocator {
    public:
        /**
         * @param context opaque context pointer stored by the allocator.
         */
        BufferAllocator(void* context);
        ~BufferAllocator();

        /**
         * @brief obtain a buffer of the requested size.
         * @param size      requested size (presumably in bytes - confirm).
         * @param seperate  defaults to false; exact meaning is defined by the
         *                  implementation (name kept as spelled in the API).
         * @return buffer handle.
         */
        id<MTLBuffer> alloc(size_t size, bool seperate = false);

        /**
         * @brief hand a buffer back to the allocator.
         * @param buffer buffer handle to release.
         */
        void release(id<MTLBuffer> buffer);

        /** @brief clear the allocator's bookkeeping. */
        void clear();

    private:
        std::map<id<MTLBuffer>, size_t> mAllocated;            // buffer -> size
        std::multimap<size_t, id<MTLBuffer>> mReusableBuffers;  // size -> buffer
        void* mContext = nullptr;
    };

    MetalRuntime();
    virtual ~MetalRuntime();

    /** Runtime override: create a Backend. */
    virtual Backend* onCreate() const override;

    /**
     * Runtime override (method name is spelled this way by the base class).
     * @param level collection level; semantics are defined by Runtime.
     */
    virtual void onGabageCollect(int level) override;

    /**
     * @brief get the opaque context pointer held by this runtime.
     * @return the stored context pointer.
     */
    void* context() const {
        return mContext;
    }

    /**
     * @brief get a host buffer for the given size.
     * @param size requested size (presumably in bytes - confirm).
     * @return buffer handle.
     */
    id<MTLBuffer> getHostBuffer(size_t size) const;

private:
    void* mContext = nullptr;
    std::shared_ptr<BufferAllocator> mStatic;
    std::shared_ptr<BufferAllocator> mDynamic;
    // mutable so that const member functions may replace the handle.
    mutable id<MTLBuffer> mHostBuffer = nullptr;
};
2019-04-17 10:49:11 +08:00
/** Metal backend */
class MetalBackend final : public Backend {
public:
/** Metal execution creator */
class Creator {
public:
/**
* @brief create execution for given input, op on metal backend.
* @param inputs given input tensors.
* @param op given op.
* @param backend metal backend.
* @return created execution if supported, NULL otherwise.
*/
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const MNN::Op *op, Backend *backend) const = 0;
2019-04-17 10:49:11 +08:00
};
/**
* @brief register creator for given op type.
* @param type given op type.
* @param creator registering creator.
*/
static void addCreator(OpType type, Creator *creator);
2020-11-05 16:41:56 +08:00
class AutoBuffer {
public:
AutoBuffer(const MetalRuntime* runtime) {
mRuntime = runtime;
}
~ AutoBuffer();
void reset(size_t length);
id<MTLBuffer> buffer() const {
return mBuffer;
}
private:
const MetalRuntime* mRuntime = nullptr;
id<MTLBuffer> mBuffer = nil;
};
const MetalRuntime* runtime() const {
return mRuntime;
}
2019-04-17 10:49:11 +08:00
public:
2020-11-05 16:41:56 +08:00
MetalBackend(const MetalRuntime* runtime);
2019-04-17 10:49:11 +08:00
virtual ~MetalBackend();
virtual bool onAcquireBuffer(const Tensor *Tensor, StorageType storageType) override;
virtual bool onReleaseBuffer(const Tensor *Tensor, StorageType storageType) override;
virtual bool onClearBuffer() override;
virtual void onCopyBuffer(const Tensor *srcTensor, const Tensor *dstTensor) const override;
virtual Execution *onCreate(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs,
const MNN::Op *op) override;
virtual void onExecuteBegin() const override;
virtual void onExecuteEnd() const override;
virtual std::pair<float, bool> onMeasure(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
const MNN::Op* op) override;
2019-04-17 10:49:11 +08:00
public:
/**
* @brief get metal context object
* @return metal context object pointer
*/
2020-11-05 16:41:56 +08:00
void *context() const;
2019-04-17 10:49:11 +08:00
/**
* @brief copy buffer content to dest tensor
* @param srcTensor source tensor
* @param dstTensor destined tensor
* @param encoder command encoder
*/
void onCopyBuffer(const Tensor *srcTensor, const Tensor *dstTensor,
id<MTLComputeCommandEncoder> encoder, id<MTLBuffer> shape) const;
2019-04-17 10:49:11 +08:00
void flushEncoder() const;
id<MTLComputeCommandEncoder> encoder() const;
2019-04-17 10:49:11 +08:00
private:
2020-11-05 16:41:56 +08:00
const MetalRuntime* mRuntime;
std::vector<id<MTLBuffer>> mHoldBuffers;
mutable id<MTLComputeCommandEncoder> mComputeEncoder = nil;
2019-04-17 10:49:11 +08:00
private:
id<MTLBuffer> getHostBuffer(size_t size) const;
void onCopyHostToDevice(const Tensor *src, const Tensor *dst) const;
void onCopyDeviceToHost(const Tensor *src, const Tensor *dst) const;
void onCopyDeviceToDevice(const Tensor *src, const Tensor *dst, id<MTLComputeCommandEncoder> encoder, id<MTLBuffer> shape) const;
2019-04-17 10:49:11 +08:00
};
2020-11-05 16:41:56 +08:00
2019-04-17 10:49:11 +08:00
/**
 * Metal creator register.
 * Constructing an instance allocates one T with `new` and passes it to
 * MetalBackend::addCreator for the given op type.
 */
template <class T>
class MetalCreatorRegister {
public:
    /**
     * @brief initializer. register a newly allocated T creator for given op type.
     * @param type  given op type.
     */
    MetalCreatorRegister(OpType type) {
        // The pointer is handed to addCreator and not kept or deleted here.
        MetalBackend::addCreator(type, new T);
    }
};
} // namespace MNN
// Defines a function named ___<name>__<opType>__() whose body allocates a
// <name> with `new` and registers it through MetalBackend::addCreator for
// <opType>. Nothing is registered until that generated function is called.
// MetalBackend is referenced unqualified, so the macro has to be expanded in
// a scope where that name resolves (e.g. inside namespace MNN).
#define REGISTER_METAL_OP_CREATOR(name, opType) \
void ___##name##__##opType##__() { \
MetalBackend::addCreator(opType, new name); \
}
2019-04-17 10:49:11 +08:00
#endif /* MNN_METAL_ENABLED */
#endif /* MetalBackend_hpp */