//
// Session.cpp
// MNN
//
// Created by MNN on 2018/07/30.
// Copyright © 2018, Alibaba Group Holding Limited
//

#include "core/Session.hpp"
#include <string.h>
#include <MNN/AutoTime.hpp>
#include <map>
#include <set>
#include "MNN_generated.h"
#include "core/AutoStorage.h"
#include "core/RuntimeFactory.hpp"
#include "core/TensorUtils.hpp"
#include "core/WrapExecution.hpp"

using namespace std;
namespace MNN {
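// Build a session from a schedule: one Pipeline per scheduled unit. Each
// pipeline gets a primary backend created by the runtime chosen for its
// forward type, plus a CPU backend used as a fallback for ops the primary
// backend cannot execute (when the primary backend is already CPU, the two
// are shared).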
Session::Session(Schedule::ScheduleInfo&& info, Interpreter::SessionMode callBackMode,
                 Interpreter::SessionMode inputMode, RuntimeInfo&& runtime) {
    mRuntime = std::move(runtime);
    if (info.pipelineInfo.empty()) {
        mValid = false;
        return;
    }
    Backend::Info defaultInfo;
    defaultInfo.type      = MNN_FORWARD_CPU;
    defaultInfo.numThread = 1;
    mTensors = std::move(info.allTensors);
    for (auto& iter : info.pipelineInfo) {
        auto rt         = mRuntime.first.find(iter.first.type)->second.get();
        auto cpuRuntime = mRuntime.second;
        std::shared_ptr<Backend> first(rt->onCreate());
        std::shared_ptr<Backend> second;
        if (first->type() == MNN_FORWARD_CPU) {
            second = first;
        } else {
            second.reset(cpuRuntime->onCreate());
        }
        std::shared_ptr<Pipeline> newPipeline(new Pipeline(std::move(iter.second), first, second,
                                                           inputMode == Interpreter::Session_Input_Inside,
                                                           rt->onGetCompilerType() == Runtime::Compiler_Geometry));
        mPipelines.emplace_back(std::move(newPipeline));
    }
    mInputs       = std::move(info.inputTensors);
    mOutputs      = std::move(info.outputTensor);
    mCallBackMode = callBackMode;
}
Session::~Session() {
    for (auto& t : mTensors) {
        TensorUtils::clearHandleData(t.second.get());
    }
    mPipelines.clear();
    mRuntime.first.clear();
    mTensors.clear();
    mRuntime.second = nullptr;
}
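// Cache load/store is delegated to the runtimes: the first runtime that
// accepts the buffer (or reports a non-empty cache) wins, so with a single
// scheduled runtime this behaves as a plain pass-through.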
bool Session::loadCache(const void* buffer, size_t size) {
    for (auto iter : mRuntime.first) {
        auto res = iter.second->onSetCache(buffer, size);
        if (res) {
            return true;
        }
    }
    return false;
}
std::pair<const void*, size_t> Session::getCache() {
    for (auto iter : mRuntime.first) {
        auto res = iter.second->onGetCache();
        if (res.first != nullptr) {
            return res;
        }
    }
    return std::make_pair(nullptr, 0);
}
void Session::cloneExecution(const std::map<const Op*, std::shared_ptr<Execution>>& cache, int pipelineIndex) {
    mPipelines[pipelineIndex]->cloneExecution(cache);
}
const std::map<const Op*, std::shared_ptr<Execution>>& Session::getExecution(int pipelineIndex) {
    return mPipelines[pipelineIndex]->getCache();
}
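// Both run paths refuse to execute while mNeedResize is set: tensor shapes
// and memory are only valid after resize() has completed. A minimal usage
// sketch, assuming the usual Interpreter-driven flow (error handling omitted):
//
//   interpreter->resizeSession(session);              // encode + allocate
//   ErrorCode code = interpreter->runSession(session);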
ErrorCode Session::run() const {
    if (mNeedResize) {
        MNN_ERROR("Can't run session because it hasn't been resized\n");
        return COMPUTE_SIZE_ERROR;
    }
    for (auto& iter : mPipelines) {
        auto error = iter->execute();
        if (NO_ERROR != error) {
            return error;
        }
    }
    return NO_ERROR;
}
ErrorCode Session::runWithCallBack(const TensorCallBackWithInfo& before, const TensorCallBackWithInfo& end,
                                   bool sync) const {
    if (mNeedResize) {
        MNN_ERROR("Can't run session because it hasn't been resized\n");
        return COMPUTE_SIZE_ERROR;
    }
    for (auto& iter : mPipelines) {
        auto error = iter->executeCallBack(before, end);
        if (NO_ERROR != error) {
            return error;
        }
    }
    return NO_ERROR;
}
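// Reset the per-tensor descriptors (use counts, backend ownership, and
// geometry regions) so a dynamic resize can re-plan memory from scratch.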
void Session::_clearCache() {
    for (auto& t : mTensors) {
        auto describe = TensorUtils::getDescribe(t.second.get());
        TensorUtils::clearHandleData(t.second.get());
        describe->useCount = 0;
        describe->backend  = nullptr;
        describe->regions.clear();
    }
}
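// resize() re-encodes every pipeline and (re)allocates memory. The runtimes
// are asked to collect garbage at a high level (100) before planning and at
// a low level (0) afterwards. (onGabageCollect is the runtime API's own
// spelling.)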
ErrorCode Session::resize(bool isStatic) {
    for (auto& iter : mRuntime.first) {
        iter.second->onGabageCollect(100);
    }
    if (!isStatic) {
        _clearCache();
    }
    bool debug = mCallBackMode == Interpreter::Session_Debug;
    // Turn each pipeline into a command buffer and allocate its resources.
    // TODO: Separate schedule and malloc
    for (auto& iter : mPipelines) {
        auto error = iter->encode(isStatic);
        if (NO_ERROR != error) {
            return error;
        }
        error = iter->allocMemory(debug);
        if (NO_ERROR != error) {
            return error;
        }
    }
    mNeedResize = false;
    for (auto& iter : mRuntime.first) {
        iter.second->onGabageCollect(0);
    }
    return NO_ERROR;
}
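// Usage sketch for the MEMORY query, assuming the public
// Interpreter::getSessionInfo entry point forwards here (the float receives
// the summed footprint in MB across all runtimes):
//
//   float memoryMB = 0.0f;
//   interpreter->getSessionInfo(session, MNN::Interpreter::MEMORY, &memoryMB);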
bool Session::getInfo(Interpreter::SessionInfoCode code, void* ptr) const {
    switch (code) {
        case Interpreter::MEMORY: {
            auto dst     = (float*)ptr;
            float summer = mRuntime.second->onGetMemoryInMB();
            for (auto& r : mRuntime.first) {
                summer += r.second->onGetMemoryInMB();
            }
            *dst = summer;
            return true;
        } break;
        // TODO: Support other debug info
        default:
            break;
    }
    return false;
}
const Backend* Session::getBackEnd(const Tensor* tensor) const {
    return TensorUtils::getDescribe(tensor)->backend;
}
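// Lookup helpers for named I/O tensors. With a null name they fall back to
// the first entry, which is convenient for single-input/single-output models.
// A usage sketch via the public API ("prob" is a hypothetical output name):
//
//   MNN::Tensor* input  = interpreter->getSessionInput(session, nullptr);
//   MNN::Tensor* output = interpreter->getSessionOutput(session, "prob");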
Tensor* Session::getInput(const char* name) const {
    MNN_ASSERT(!mInputs.empty());
    if (nullptr == name) {
        return mInputs.begin()->second;
    }
    auto iter = mInputs.find(name);
    if (iter == mInputs.end()) {
        MNN_PRINT("Error: can't find input: %s\n", name);
        return nullptr;
    }
    return iter->second;
}
Tensor* Session::getOutput(const char* name) const {
    MNN_ASSERT(!mOutputs.empty());
    if (nullptr == name) {
        return mOutputs.begin()->second;
    }
    auto iter = mOutputs.find(name);
    if (iter == mOutputs.end()) {
        MNN_PRINT("Error: can't find output: %s\n", name);
        return nullptr;
    }
    return iter->second;
}
const std::map<std::string, Tensor*>& Session::getInputAll() const {
    return mInputs;
}
const std::map<std::string, Tensor*>& Session::getOutputAll() const {
    return mOutputs;
}
ErrorCode Session::releaseCache() {
    return NO_ERROR;
}
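// Write trained parameters back into the flatbuffer Net: for TRAIN usage this
// targets TrainableParam ops, for inference usage Const ops. Only ops with a
// single float output blob are updated, and device-resident tensors are
// copied back to host before the memcpy below.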
ErrorCode Session::updateToModel(Net* net) const {
    int opSize = net->oplists()->size();
    for (int i = 0; i < opSize; ++i) {
        auto op = net->oplists()->GetAs<Op>(i);
        if ((net->usage() == Usage_INFERENCE || net->usage() == Usage_INFERENCE_STATIC) && op->type() != OpType_Const) {
            continue;
        }
        if (net->usage() == Usage_TRAIN && op->type() != OpType_TrainableParam) {
            continue;
        }
        if (!op->outputIndexes() || op->outputIndexes()->size() != 1) {
            continue;
        }
        auto index = op->outputIndexes()->data()[0];
        auto blob  = op->main_as_Blob();
        if (blob->dataType() != DataType_DT_FLOAT) {
            continue;
        }
        std::shared_ptr<Tensor> tensor = mTensors[index].second;
        if (tensor->host<void>() == nullptr && tensor->deviceId() != 0) {
            // The tensor lives on a device backend; copy it back to host first.
            tensor.reset(Tensor::createHostTensorFromDevice(tensor.get(), true));
            if (tensor.get() == nullptr) {
                MNN_ERROR("failed to copy trained param from device to host\n");
                return INVALID_VALUE;
            }
        }
        ::memcpy((void*)blob->float32s()->data(), tensor->host<float>(), tensor->size());
    }
    return NO_ERROR;
}
} // namespace MNN