MNN/source/geometry/GeometryComputer.cpp

//
// GeometryComputer.cpp
// MNN
//
// Created by MNN on 2020/04/01.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include <mutex>
#include "geometry/GeometryComputer.hpp"
#include "core/Backend.hpp"
#include "core/OpCommonUtils.hpp"
#include "shape/SizeComputer.hpp"
#include "core/TensorUtils.hpp"
namespace MNN {
GeometryComputer::Context::~Context() {
    // Do nothing
}
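// Pre-build a flatbuffer Op of type OpType_Raster once; every raster command
// created by this context shares this buffer instead of rebuilding it.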
GeometryComputer::Context::Context(std::shared_ptr<Backend> allocBackend, MNNForwardType type) {
    mBackend = allocBackend;
    flatbuffers::FlatBufferBuilder builder(32);
    OpBuilder opBuilder(builder);
    opBuilder.add_type(OpType_Raster);
    auto lastOffset = opBuilder.Finish();
    builder.Finish(lastOffset);
    mRasterOp.reset(new BufferStorage);
    mRasterOp->storage = builder.ReleaseRaw(mRasterOp->allocated_size, mRasterOp->offset);
    mForwardType = type;
}
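// Collect the Raster commands of an already-generated command buffer so that
// getRasterCacheCreate can reuse them instead of allocating new ones.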
void GeometryComputer::Context::pushCache(const CommandBuffer& buffer) {
    for (auto cmd : buffer.command) {
        if (cmd->op->type() == OpType_Raster) {
            mRasterCmdCache.emplace_back(cmd);
        }
    }
}
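// Release constant tensors that were not registered for a specific op (see allocConst).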
void GeometryComputer::Context::clear() {
    mTempConstTensors.clear();
}
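// Return the constant tensors already allocated for this op. On the first lookup the op is
// registered with an empty list, so later allocConst calls for the same op attach to it.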
const std::vector<std::shared_ptr<Tensor>>& GeometryComputer::Context::searchConst(const Op* op) {
    auto iter = mConstTensors.find(op);
    if (iter == mConstTensors.end()) {
        mConstTensors.insert(std::make_pair(op, std::vector<std::shared_ptr<Tensor>>{}));
        return mEmpty;
    }
    return iter->second;
}
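// Allocate a constant tensor on the context backend. It is owned by mConstTensors when the op
// was registered via searchConst, otherwise by mTempConstTensors, which is released on clear().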
std::shared_ptr<Tensor> GeometryComputer::Context::allocConst(const Op* key, const std::vector<int>& shape,
                                                              halide_type_t type, Tensor::DimensionType dimType) {
    std::shared_ptr<Tensor> tensor(Tensor::createDevice(shape, type, dimType));
    TensorUtils::getDescribe(tensor.get())->usage = Tensor::InsideDescribe::CONSTANT;
    auto res = mBackend->onAcquireBuffer(tensor.get(), Backend::STATIC);
    if (!res) {
        return nullptr;
    }
    TensorUtils::getDescribe(tensor.get())->setBackend(mBackend.get());
    auto iter = mConstTensors.find(key);
    if (iter != mConstTensors.end()) {
        iter->second.emplace_back(tensor);
    } else {
        mTempConstTensors.emplace_back(tensor);
    }
    return tensor;
}
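// Allocate static storage for an existing tensor and mark it as constant.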
bool GeometryComputer::Context::allocTensor(Tensor* tensor) {
    auto res = mBackend->onAcquireBuffer(tensor, Backend::STATIC);
    if (!res) {
        return false;
    }
    TensorUtils::getDescribe(tensor)->usage = Tensor::InsideDescribe::CONSTANT;
    TensorUtils::getDescribe(tensor)->setBackend(mBackend.get());
    return true;
}
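// A tensor with any zero-length dimension holds no data, so no raster command is needed for it.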
inline bool _hasZeroDim(const Tensor* t) {
    for (int i = 0; i < t->dimensions(); ++i) {
        if (t->length(i) <= 0) {
            return true;
        }
    }
    return false;
}
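// Depth-first pass over a virtual tensor's regions: fuse chains of single-region virtual
// inputs, resolve the inputs first, then emit a raster command for src itself.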
void GeometryComputer::Context::getRasterCacheCreateRecursive(Tensor* src, CommandBuffer& cmd) {
    auto srcDes = TensorUtils::getDescribe(src);
    if (srcDes->memoryType != Tensor::InsideDescribe::MEMORY_VIRTUAL) {
        return;
    }
    if (_hasZeroDim(src)) {
        return;
    }
    for (auto& input : srcDes->regions) {
        MNN_ASSERT(input.origin != src);
        auto inputDes = TensorUtils::getDescribe(input.origin);
        while (inputDes->memoryType == Tensor::InsideDescribe::MEMORY_VIRTUAL) {
            if (1 != inputDes->regions.size()) {
                break;
            }
            bool merge = TensorUtils::fuseRegion(inputDes->regions[0], input);
            if (!merge) {
                break;
            }
            inputDes = TensorUtils::getDescribe(input.origin);
        }
        getRasterCacheCreateRecursive(input.origin, cmd);
    }
    getRasterCacheCreate(src, cmd);
}
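// Convert one virtual tensor into a concrete Raster command: mark it backend-owned, then reuse
// a cached raster command if available, otherwise build a new one from the prebuilt raster op.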
void GeometryComputer::Context::getRasterCacheCreate(Tensor* src, CommandBuffer& cmdBuffer) {
    auto srcDes = TensorUtils::getDescribe(src);
    if (srcDes->memoryType != Tensor::InsideDescribe::MEMORY_VIRTUAL) {
        return;
    }
    srcDes->memoryType = Tensor::InsideDescribe::MEMORY_BACKEND;
    if (mRasterCmdCache.empty()) {
        SharedPtr<Command> cmdP(new Command);
        auto& cmd = *cmdP;
        cmd.op = flatbuffers::GetRoot<Op>(mRasterOp->buffer());
        cmd.buffer = mRasterOp;
        cmd.outputs = {src};
        TensorUtils::setRasterInputs(cmdP.get());
        cmdBuffer.command.emplace_back(std::move(cmdP));
        return;
    }
    auto iter = mRasterCmdCache.begin() + ((int)mRasterCmdCache.size() - 1);
    auto cmdP = *iter;
    mRasterCmdCache.erase(iter);
    cmdP->outputs[0] = src;
    TensorUtils::setRasterInputs(cmdP.get());
    cmdBuffer.command.emplace_back(std::move(cmdP));
}
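// The default geometry can only be reused on recompute when the previous result
// consists of exactly one command.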
bool DefaultGeometryComputer::onRecompute(const Op* op, const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                          Context& context, CommandBuffer& cmd) const {
    if (1 != cmd.command.size()) {
        return false;
    }
    return true;
}
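// Fallback geometry: forward the op unchanged as a single command that consumes
// the original inputs and produces the original outputs.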
bool DefaultGeometryComputer::onCompute(const Op* op, const std::vector<Tensor*>& originInputs,
                                        const std::vector<Tensor*>& outputs, GeometryComputer::Context& context,
                                        CommandBuffer& res) const {
    auto inputs = originInputs;
    // Last Command
    SharedPtr<Command> cmdP(new Command);
    auto& cmd = *cmdP;
    cmd.op      = op;
    cmd.inputs  = std::move(inputs);
    cmd.outputs = std::move(outputs);
    res.command.emplace_back(std::move(cmdP));
    return true;
}
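// Registry mapping op types to geometry computers. Separate tables serve the Compiler_Geometry
// and Compiler_Loop modes; unmatched types fall back to DefaultGeometryComputer.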
class GeometryComputerManager {
public:
    GeometryComputer* search(int type, Runtime::CompilerType compType) {
        if (Runtime::Compiler_Origin == compType) {
            return &mDefault;
        }
        if (Runtime::Compiler_Loop == compType) {
            auto iter = mLoopTable[type].get();
            if (iter != nullptr) {
                return iter;
            }
        }
        // Geometry
        auto iter = mTable[type].get();
        if (iter != nullptr) {
            // FUNC_PRINT(type);
            return iter;
        }
        return &mDefault;
    }
    static void init() {
        gInstance = new GeometryComputerManager;
        gInstance->mTable.resize(OpType_MAX + 1);
        gInstance->mLoopTable.resize(OpType_MAX + 1);
    }
    static GeometryComputerManager* get() {
        return gInstance;
    }
    void insert(std::shared_ptr<GeometryComputer> c, int type, Runtime::CompilerType compType) {
        if (Runtime::Compiler_Geometry == compType) {
            mTable[type] = c;
        } else if (Runtime::Compiler_Loop == compType) {
            mLoopTable[type] = c;
        }
    }
private:
    std::vector<std::shared_ptr<GeometryComputer>> mTable;
    std::vector<std::shared_ptr<GeometryComputer>> mLoopTable;
    static GeometryComputerManager* gInstance;
    DefaultGeometryComputer mDefault;
};
GeometryComputerManager* GeometryComputerManager::gInstance;
void GeometryComputer::registerGeometryComputer(std::shared_ptr<GeometryComputer> comp, std::vector<int> type, Runtime::CompilerType compType) {
    auto ins = GeometryComputerManager::get();
    for (auto t : type) {
        ins->insert(comp, t, compType);
    }
}
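// Create the registry on first use and register the geometry computers via registerGeometryOps().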
void GeometryComputer::init() {
    if (nullptr == GeometryComputerManager::get()) {
        GeometryComputerManager::init();
        registerGeometryOps();
    }
}
const GeometryComputer* GeometryComputer::search(int type, Runtime::CompilerType compType) {
    return GeometryComputerManager::get()->search(type, compType);
}
} // namespace MNN