MNN/source/core/Backend.cpp

166 lines
4.6 KiB
C++
Raw Normal View History

2019-04-17 10:49:11 +08:00
//
// Backend.cpp
// MNN
//
// Created by MNN on 2018/07/06.
// Copyright © 2018, Alibaba Group Holding Limited
//
2020-11-05 16:41:56 +08:00
#include "core/Backend.hpp"
2019-04-17 10:49:11 +08:00
#include <stdio.h>
#include <mutex>
#include "MNN_generated.h"
2020-11-05 16:41:56 +08:00
#include "backend/cpu/CPUTensorConvert.hpp"
2019-12-27 22:16:57 +08:00
#include "core/Macro.h"
#include "core/TensorUtils.hpp"
2024-02-29 16:21:40 +08:00
#include "geometry/GeometryComputer.hpp"
#include "shape/SizeComputer.hpp"
#ifdef MNN_INTERNAL_ENABLED
#include "internal/logging/Log.hpp"
#endif
2019-04-17 10:49:11 +08:00
namespace MNN {
2020-11-05 16:41:56 +08:00
static std::map<MNNForwardType, std::pair<const RuntimeCreator*, bool>>& GetExtraCreator() {
    // Process-wide registry of runtime creators keyed by forward type. The bool stored next to
    // each creator is its "needCheck" flag: when true, registerBackend() and
    // MNNGetExtraRuntimeCreator() probe the creator with onCreate() before trusting it.
    //
    // The map is heap-allocated and intentionally never deleted, so it is never destroyed during
    // static teardown (NOTE(review): presumably so registrations/lookups from other translation
    // units' static init/exit code stay safe — confirm). A function-local static provides the
    // same once-only, thread-safe initialization that std::call_once did (guaranteed since C++11).
    static auto* registry = new std::map<MNNForwardType, std::pair<const RuntimeCreator*, bool>>;
    return *registry;
}
// Registration hooks for the built-in runtimes. They are defined in other translation units and
// are called (inside a std::call_once) from registerBackend(). Only the CPU hook is always
// declared; each of the others is declared only when its MNN_*_ENABLED macro evaluates non-zero.
extern void registerCPURuntimeCreator();
#if MNN_COREML_ENABLED
extern void registerCoreMLRuntimeCreator();
#endif
#if MNN_NNAPI_ENABLED
extern void registerNNAPIRuntimeCreator();
#endif
#if MNN_OPENCL_ENABLED
namespace OpenCL {
// The OpenCL hook lives in the nested MNN::OpenCL namespace.
extern void registerOpenCLRuntimeCreator();
}
#endif
#if MNN_METAL_ENABLED
extern void registerMetalRuntimeCreator();
#endif
static std::once_flag s_flag;
void registerBackend() {
std::call_once(s_flag, [&]() {
#ifdef MNN_INTERNAL_ENABLED
LogInit();
#endif
registerCPURuntimeCreator();
#ifndef MNN_BUILD_MINI
SizeComputerSuite::init();
GeometryComputer::init();
#endif
#if MNN_COREML_ENABLED
registerCoreMLRuntimeCreator();
#endif
#ifdef MNN_NNAPI_ENABLED
registerNNAPIRuntimeCreator();
#endif
#if MNN_OPENCL_ENABLED
OpenCL::registerOpenCLRuntimeCreator();
#endif
#if MNN_METAL_ENABLED
registerMetalRuntimeCreator();
#endif
auto& gExtraCreator = GetExtraCreator();
for(auto iter = gExtraCreator.begin(); iter != gExtraCreator.end();){
if(!iter->second.second){
iter++;
}else{
Backend::Info info;
info.type = iter->first;
std::shared_ptr<Runtime> bn(iter->second.first->onCreate(info));
if (nullptr == bn.get()) {
iter = gExtraCreator.erase(iter);
MNN_ERROR("Error to use creator of %d, delete it\n", info.type);
}else{
iter++;
}
}
}
});
}
2020-11-05 16:41:56 +08:00
const RuntimeCreator* MNNGetExtraRuntimeCreator(MNNForwardType type) {
    // Make sure the built-in backends are registered before consulting the registry.
    registerBackend();

    auto& creators = GetExtraCreator();
    const auto found = creators.find(type);
    if (found == creators.end()) {
        // Nothing registered for this forward type.
        return nullptr;
    }

    const RuntimeCreator* creator = found->second.first;
    const bool needCheck = found->second.second;
    if (!needCheck) {
        // Trusted creator: hand it out without probing.
        return creator;
    }

    // Probe: build a throw-away Runtime and only return the creator if that succeeds.
    // The probe Runtime is released when `probe` goes out of scope.
    Backend::Info probeInfo;
    probeInfo.type = type;
    std::shared_ptr<Runtime> probe(creator->onCreate(probeInfo));
    return (probe.get() != nullptr) ? creator : nullptr;
}
bool MNNInsertExtraRuntimeCreator(MNNForwardType type, const RuntimeCreator* creator, bool needCheck) {
    // Register `creator` for `type`. Each forward type may be registered only once: a repeated
    // registration leaves the existing entry untouched, trips MNN_ASSERT, and returns false.
    // `needCheck` is stored with the creator; when true, the creator is probed with onCreate()
    // before being handed out (see registerBackend / MNNGetExtraRuntimeCreator).
    auto& registry = GetExtraCreator();
    const auto result = registry.emplace(type, std::make_pair(creator, needCheck));
    if (!result.second) {
        MNN_ASSERT(false && "duplicate type");
        return false;
    }
    return true;
}
/**
 * Copies the host data of srcTensor into dstTensor's host buffer using
 * CPUTensorConverter::convert (NOTE(review): presumably converting between the two tensors'
 * data layouts as needed — confirm against CPUTensorConverter).
 *
 * Both tensors are expected to have the same rank and element type; this is checked only via
 * MNN_ASSERT.
 *
 * @return false if either tensor pointer is null, if either tensor has no host buffer, or if the
 *         conversion reports an error; true on success.
 */
bool MNNCPUCopyBuffer(const Tensor* srcTensor, const Tensor* dstTensor) {
    if (nullptr == srcTensor || nullptr == dstTensor) {
        return false;
    }
    auto& srcBuffer = srcTensor->buffer();
    auto& dstBuffer = dstTensor->buffer();
    MNN_ASSERT(srcBuffer.dimensions == dstBuffer.dimensions);
    MNN_ASSERT(srcBuffer.type == dstBuffer.type);
    if (nullptr == srcBuffer.host || nullptr == dstBuffer.host) {
        return false;
    }
    auto code = CPUTensorConverter::convert(srcTensor, dstTensor);
    if (NO_ERROR != code) {
        MNN_ERROR("Error in CPUBackend::onCopyBuffer\n");
        // Report the failure to the caller instead of claiming the copy succeeded.
        return false;
    }
    return true;
}
bool Backend::onAcquireBuffer(const Tensor* tensor, StorageType storageType) {
auto mem = this->onAcquire(tensor, storageType);
if (nullptr == mem) {
return false;
}
2024-04-19 11:58:21 +08:00
if (mem == TensorUtils::getDescribeOrigin(tensor)->mem.get()) {
return true;
}
2024-04-19 11:58:21 +08:00
TensorUtils::getDescribeOrigin(tensor)->mem = mem;
return true;
}
bool Backend::onReleaseBuffer(const Tensor* tensor, StorageType storageType) {
2024-04-19 11:58:21 +08:00
TensorUtils::getDescribeOrigin(tensor)->mem = nullptr;
return true;
}
2022-12-30 15:18:58 +08:00
2022-01-04 10:50:40 +08:00
bool Runtime::hasAsyncWork() const {
    // True while mFuture refers to a shared state, i.e. once setAsyncWork() has stored a valid
    // future. std::future::wait() does not invalidate a future, so this remains true after
    // waitAsyncWork() returns.
    const bool holdsFuture = mFuture.valid();
    return holdsFuture;
}
void Runtime::setAsyncWork(std::future<int>&& future) {
mFuture = std::move(future);
}
void Runtime::waitAsyncWork() {
    // No future stored (or it was never set): nothing to wait for.
    if (!mFuture.valid()) {
        return;
    }
    // Block until the async job finishes. wait() does not consume the result, so mFuture stays valid.
    mFuture.wait();
}
} // namespace MNN