MNN/source/core/WrapExecution.cpp

//
// WrapExecution.cpp
// MNN
//
// Created by MNN on 2018/09/03.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "core/WrapExecution.hpp"
#include "core/TensorUtils.hpp"
#include "backend/cpu/CPUBackend.hpp"
#include "backend/cpu/compute/CommonOptFunction.h"
namespace MNN {
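// Decide whether `input` must be copied into `curBackend`'s memory before the
// wrapped execution can read it. No wrap is needed when both tensors already
// live on the same backend type, or when source and destination are both
// CPU-family backends with compatible element size and packing.
// Typical use (a sketch; callers in the scheduler look roughly like this):
//   if (WrapExecution::needWrap(inputTensor, backend)) {
//       exe.reset(new WrapExecution(cpuBackend, exe, isStatic));
//   }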
bool WrapExecution::needWrap(const Tensor* input, Backend* curBackend) {
    if (curBackend->type() == MNN_FORWARD_NN) {
        return false;
    }
    auto des = TensorUtils::getDescribe(input);
    auto bn  = des->backend;
    MNNForwardType type = MNN_FORWARD_CPU;
    int pack  = 4;
    int bytes = 4;
    if (nullptr != bn) {
        type = bn->type();
        if (type == MNN_FORWARD_CPU_EXTENSION) {
            auto core = static_cast<CPUBackend*>(bn)->functions();
            pack  = core->pack;
            bytes = core->bytes;
        }
    }
    if (type == curBackend->type()) {
        return false;
    }
    bool srcCpu = (type == MNN_FORWARD_CPU_EXTENSION || type == MNN_FORWARD_CPU);
    bool dstCpu = ((curBackend->type() == MNN_FORWARD_CPU_EXTENSION) || (curBackend->type() == MNN_FORWARD_CPU));
    if (srcCpu && dstCpu) {
        // Two CPU-family backends can share buffers directly as long as the
        // element size matches and either the pack unit matches or the tensor
        // is not stored in the packed NC4HW4 layout.
        auto dstCore = static_cast<CPUBackend*>(curBackend)->functions();
        if (dstCore->bytes == bytes) {
            if (dstCore->pack == pack || des->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) {
                return false;
            }
        }
    }
    return true;
}
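// Wrap `execution` so that inputs living on a different backend are copied
// onto `execution`'s backend before it runs. When `isStatic` is true,
// constant inputs are copied once at resize time rather than on every
// execution.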
WrapExecution::WrapExecution(Backend* CPUBackend, std::shared_ptr<Execution> execution, bool isStatic)
    : Execution(execution->backend()), mCPUBackend(CPUBackend), mExecution(execution) {
    mValid  = execution->valid();
    mStatic = isStatic;
}
Tensor* WrapExecution::_getCopyTensor(Tensor* inputTensor) {
    auto dstBackend = mExecution->backend();
    auto inputDes   = TensorUtils::getDescribe(inputTensor);
    auto srcBackend = inputDes->backend;
    if (nullptr == srcBackend) {
        srcBackend = mCPUBackend;
    }
    // CPU -> CPU or XPU -> XPU: same backend type, no copy needed
    if (srcBackend->type() == dstBackend->type()) {
        return inputTensor;
    }
    // Reuse the shadow tensor if this input has been wrapped already
    auto iter = mInputMaps.find(inputTensor);
    if (iter != mInputMaps.end()) {
        return std::get<2>(iter->second).get();
    }
    // CPU -> XPU: allocate the shadow tensor on the device backend
    if (srcBackend->type() == mCPUBackend->type()) {
        std::shared_ptr<Tensor> wrapTensor(new Tensor);
        TensorUtils::copyShape(inputTensor, wrapTensor.get(), true);
        TensorUtils::adjustTensorForCompability(wrapTensor.get());
        wrapTensor->buffer().type = inputTensor->buffer().type;
        TensorUtils::getDescribe(wrapTensor.get())->quantAttr = TensorUtils::getDescribe(inputTensor)->quantAttr;
        mInputMaps.insert(std::make_pair(inputTensor, std::make_tuple(dstBackend, dstBackend, wrapTensor)));
        return wrapTensor.get();
    }
    // XPU -> CPU: allocate the shadow tensor on the CPU backend
    if (dstBackend->type() == mCPUBackend->type()) {
        std::shared_ptr<Tensor> wrapTensor(new Tensor);
        TensorUtils::copyShape(inputTensor, wrapTensor.get(), true);
        wrapTensor->buffer().type = inputTensor->buffer().type;
        TensorUtils::adjustTensorForCompability(wrapTensor.get());
        TensorUtils::getDescribe(wrapTensor.get())->quantAttr = TensorUtils::getDescribe(inputTensor)->quantAttr;
        mInputMaps.insert(std::make_pair(inputTensor, std::make_tuple(mCPUBackend, srcBackend, wrapTensor)));
        return wrapTensor.get();
    }
    // XPU -> CPU -> XPU': bridge two different devices through a CPU-side
    // staging tensor
    std::shared_ptr<Tensor> midTensor(new Tensor);
    std::shared_ptr<Tensor> wrapTensor(new Tensor);
    TensorUtils::copyShape(inputTensor, midTensor.get(), true);
    TensorUtils::copyShape(inputTensor, wrapTensor.get(), true);
    TensorUtils::adjustTensorForCompability(wrapTensor.get());
    TensorUtils::adjustTensorForCompability(midTensor.get());
    TensorUtils::getDescribe(midTensor.get())->usage     = TensorUtils::getDescribe(inputTensor)->usage;
    TensorUtils::getDescribe(midTensor.get())->quantAttr = TensorUtils::getDescribe(inputTensor)->quantAttr;
    midTensor->buffer().type  = inputTensor->buffer().type;
    wrapTensor->buffer().type = inputTensor->buffer().type;
    mInputMaps.insert(std::make_pair(inputTensor, std::make_tuple(mCPUBackend, srcBackend, midTensor)));
    mInputMaps.insert(std::make_pair(midTensor.get(), std::make_tuple(dstBackend, dstBackend, wrapTensor)));
    return wrapTensor.get();
}
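// Build the wrapped input list, allocate the shadow tensors (each mInputMaps
// entry is a tuple of allocating backend, copying backend, shadow tensor),
// copy constant inputs up front in static mode, then resize the wrapped
// execution.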
ErrorCode WrapExecution::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    mWrapInputTensors.resize(inputs.size());
    mInputMaps.clear();
    auto dstBackend = mExecution->backend();
    for (int i = 0; i < inputs.size(); ++i) {
        auto inputTensor = inputs[i];
        auto des         = TensorUtils::getDescribe(inputTensor);
        if (des->memoryType == Tensor::InsideDescribe::MEMORY_VIRTUAL) {
            // Raster input: wrap each region's origin tensor instead of the
            // virtual tensor itself
            MNN_ASSERT(inputs.size() == 1);
            mWrapForRaster.reset(new Tensor);
            TensorUtils::copyShape(inputTensor, mWrapForRaster.get(), true);
            mWrapForRaster->buffer().type = inputTensor->buffer().type;
            auto wrapDes        = TensorUtils::getDescribe(mWrapForRaster.get());
            wrapDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
            wrapDes->regions    = des->regions;
            for (auto& r : wrapDes->regions) {
                r.origin = _getCopyTensor(r.origin);
            }
            mWrapInputTensors[i] = mWrapForRaster.get();
        } else {
            mWrapInputTensors[i] = _getCopyTensor(inputTensor);
        }
    }
    for (int i = 0; i < outputs.size(); ++i) {
        MNN_ASSERT(TensorUtils::getDescribe(outputs[i])->backend == dstBackend);
    }
    bool memoryAllocSuccess = true;
    // acquire memory, copy const tensors
    for (auto& iter : mInputMaps) {
        auto backend   = std::get<0>(iter.second);
        auto converter = std::get<1>(iter.second);
        auto src       = iter.first;
        auto dst       = std::get<2>(iter.second).get();
        if (TensorUtils::getDescribe(src)->usage == TensorUsage::CONSTANT && mStatic) {
            // Constant inputs in static mode are copied once here and kept in
            // a separate allocation for the lifetime of the execution
            memoryAllocSuccess = backend->onAcquireBuffer(dst, Backend::DYNAMIC_SEPERATE);
            if (memoryAllocSuccess) {
                converter->onCopyBuffer(src, dst);
                TensorUtils::getDescribe(dst)->usage = TensorUtils::getDescribe(src)->usage;
            }
        } else {
            memoryAllocSuccess = backend->onAcquireBuffer(dst, Backend::DYNAMIC);
        }
    }
    if (!memoryAllocSuccess) {
        return OUT_OF_MEMORY;
    }
    // do resize
    auto result = mExecution->onResize(mWrapInputTensors, outputs);
    // release memory
    for (auto& iter : mInputMaps) {
        auto backend = std::get<0>(iter.second);
        auto dst     = std::get<2>(iter.second).get();
        if (TensorUtils::getDescribe(dst)->usage == TensorUsage::CONSTANT && mStatic) {
            backend->onReleaseBuffer(dst, Backend::DYNAMIC_SEPERATE);
        } else {
            backend->onReleaseBuffer(dst, Backend::DYNAMIC);
        }
    }
    return result;
}
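// Copy non-constant inputs onto the target backend, then run the wrapped
// execution. Constant inputs were already copied during onResize in static
// mode.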
ErrorCode WrapExecution::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    MNN_ASSERT(mWrapInputTensors.size() == inputs.size());
    // copy variant tensors
    for (auto& iter : mInputMaps) {
        auto converter = std::get<1>(iter.second);
        auto src       = iter.first;
        auto dst       = std::get<2>(iter.second).get();
        if (TensorUtils::getDescribe(src)->usage != TensorUsage::CONSTANT || (!mStatic)) {
            converter->onCopyBuffer(src, dst);
        }
    }
    auto code = mExecution->onExecute(mWrapInputTensors, outputs);
    return code;
}
} // namespace MNN