2019-04-17 10:49:11 +08:00
|
|
|
//
|
|
|
|
// Session.cpp
|
|
|
|
// MNN
|
|
|
|
//
|
|
|
|
// Created by MNN on 2018/07/30.
|
|
|
|
// Copyright © 2018, Alibaba Group Holding Limited
|
|
|
|
//
|
|
|
|
|
2019-12-27 22:16:57 +08:00
|
|
|
#include "core/Session.hpp"
|
2023-04-11 11:12:00 +08:00
|
|
|
#include "core/WrapExecution.hpp"
|
2019-04-17 10:49:11 +08:00
|
|
|
#include <string.h>
|
2020-03-22 20:16:29 +08:00
|
|
|
#include <MNN/AutoTime.hpp>
|
2019-04-17 10:49:11 +08:00
|
|
|
#include <map>
|
|
|
|
#include <set>
|
2020-03-22 20:16:29 +08:00
|
|
|
#include "MNN_generated.h"
|
2019-12-27 22:16:57 +08:00
|
|
|
#include "core/AutoStorage.h"
|
2020-11-05 16:41:56 +08:00
|
|
|
#include "core/RuntimeFactory.hpp"
|
2019-12-27 22:16:57 +08:00
|
|
|
#include "core/TensorUtils.hpp"
|
2022-12-30 15:18:58 +08:00
|
|
|
#include "utils/InitNet.hpp"
|
2019-04-17 10:49:11 +08:00
|
|
|
|
|
|
|
namespace MNN {
|
2024-09-12 12:57:57 +08:00
|
|
|
// Create (or reuse) the backend pair for one pipeline:
// cache.first  - the primary compute backend created from the scheduled runtime;
// cache.second - the CPU fallback backend used for size computation and for
//                ops the primary backend does not support.
// Does nothing if the pipeline already has its backends.
void Session::createPipelineBackend(Schedule::PipelineInfo& iter, RuntimeInfo& runtime) {
    if (iter.first.cache.first != nullptr) {
        return;
    }
    auto rt = runtime.first.find(iter.first.info.type)->second.get();
    auto cpuRuntime = runtime.second;
    bool specialUsage = false;
    if (iter.first.info.user != nullptr) {
        specialUsage = iter.first.info.user->flags > 0;
    }
    iter.first.cache.first.reset(rt->onCreate(iter.first.info.user));
    if (iter.first.cache.first->type() == MNN_FORWARD_CPU && (!specialUsage)) {
        // The primary backend is already a plain CPU backend: share it as the fallback.
        iter.first.cache.second = iter.first.cache.first;
    } else {
        // Const Backend shouldn't be used as default backend
        // The session may be schedule multi-thread but const backend is the same
        // We need create a new backend to do size compute / not support op compute
        BackendConfig defaultConfig;
        defaultConfig.flags = 4;
        if (iter.first.info.user != nullptr) {
            // Don't change default Precision
            defaultConfig.memory = iter.first.info.user->memory;
            defaultConfig.power = iter.first.info.user->power;
        }
        Backend* origin = nullptr;
        if (cpuRuntime.get() == rt) {
            // The primary backend was created by the CPU runtime itself; let the
            // fallback backend share its resources.
            origin = iter.first.cache.first.get();
        }
        iter.first.cache.second.reset(cpuRuntime->onCreate(&defaultConfig, origin));
    }
}
|
2024-06-15 15:39:59 +08:00
|
|
|
// Dispatch a SessionMode value to the mode slot it configures.
// Each pair of enum values controls one independent aspect of the session;
// unrecognized values are ignored.
void Session::ModeGroup::setMode(Interpreter::SessionMode mode) {
    switch (mode) {
        case Interpreter::Session_Input_Inside:
        case Interpreter::Session_Input_User:
            inputMode = mode;
            break;
        case Interpreter::Session_Output_User:
        case Interpreter::Session_Output_Inside:
            outputMode = mode;
            break;
        case Interpreter::Session_Backend_Auto:
        case Interpreter::Session_Backend_Fix:
            backendMode = mode;
            break;
        case Interpreter::Session_Debug:
        case Interpreter::Session_Release:
            callBackMode = mode;
            break;
        case Interpreter::Session_Resize_Direct:
        case Interpreter::Session_Resize_Defer:
            resizeMode = mode;
            break;
        case Interpreter::Session_Memory_Collect:
        case Interpreter::Session_Memory_Cache:
            memoryUsageMode = mode;
            break;
        case Interpreter::Session_Codegen_Disable:
        case Interpreter::Session_Codegen_Enable:
            codegenMode = mode;
            break;
        default:
            break;
    }
}
|
|
|
|
// Store a numeric hint in the matching ModeGroup / RuntimeHint field.
// Most hints are forwarded verbatim to runtimeHint; a few (tuning number,
// geometry mask, strict model check) live directly on the ModeGroup.
// Unknown hint modes are silently ignored.
void Session::ModeGroup::setHint(Interpreter::HintMode mode, int hint) {
    switch (mode) {
        case Interpreter::MAX_TUNING_NUMBER:
            maxTuningNumber = hint;
            break;
        case Interpreter::MEM_ALLOCATOR_TYPE:
            runtimeHint.memoryAllocatorType = hint;
            break;
        case Interpreter::WINOGRAD_MEMORY_LEVEL:
            runtimeHint.winogradMemoryUsed = hint;
            break;
        case Interpreter::CPU_LITTLECORE_DECREASE_RATE:
            runtimeHint.cpuDecreaseRate = hint;
            break;
        case Interpreter::GEOMETRY_COMPUTE_MASK:
            geometryMask = hint;
            break;
        case Interpreter::STRICT_CHECK_MODEL:
            // Any positive value enables strict model-buffer validation.
            checkNetBuffer = hint > 0;
            break;
        case Interpreter::DYNAMIC_QUANT_OPTIONS:
            runtimeHint.dynamicQuantOption = hint;
            break;
        case Interpreter::QKV_QUANT_OPTIONS:
            runtimeHint.qkvQuantOption = hint;
            break;
        case Interpreter::KVCACHE_SIZE_LIMIT:
            runtimeHint.kvcacheSizeLimit = hint;
            break;
        case Interpreter::OP_ENCODER_NUMBER_FOR_COMMIT:
            // NOTE: "encorder" spelling comes from the RuntimeHint field name.
            runtimeHint.encorderNumForCommit = hint;
            break;
        case Interpreter::MMAP_FILE_SIZE:
            runtimeHint.mmapFileSize = hint;
            break;
        case Interpreter::USE_CACHED_MMAP:
            runtimeHint.useCachedMmap = hint;
            break;
        case Interpreter::INIT_THREAD_NUMBER:
            runtimeHint.initThreadNumber = hint;
            break;
        default:
            break;
    }
}
|
2024-08-24 15:46:21 +08:00
|
|
|
|
|
|
|
void Session::ModeGroup::setExternalPath(std::string path, int type) {
|
|
|
|
switch (type) {
|
|
|
|
case MNN::Interpreter::EXTERNAL_PATH_KVCACHE_DIR:
|
|
|
|
runtimeHint.kvcacheDirPath = path;
|
|
|
|
break;
|
2024-09-12 12:57:57 +08:00
|
|
|
case MNN::Interpreter::EXTERNAL_FEATUREMAP_DIR:
|
|
|
|
runtimeHint.midMemoryPath = path;
|
|
|
|
break;
|
|
|
|
case MNN::Interpreter::EXTERNAL_WEIGHT_DIR:
|
|
|
|
runtimeHint.weightMemoryPath = path;
|
|
|
|
break;
|
2024-08-24 15:46:21 +08:00
|
|
|
default:
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2022-01-04 10:50:40 +08:00
|
|
|
// Build a session from a schedule: create backends for each pipeline,
// then wrap every scheduled pipeline-info into a Pipeline object.
// Takes ownership of both the schedule info and the runtime info.
Session::Session(Schedule::ScheduleInfo&& info, const ModeGroup& mode, RuntimeInfo&& runtime) {
    mMode = mode;
    mRuntime = std::move(runtime);
    // An empty schedule has nothing to run: mark the session invalid.
    if (info.pipelineInfo.empty()) {
        mValid = false;
        return;
    }
    mInfo = std::move(info);
    for (auto& iter : mInfo.pipelineInfo) {
        // Make sure the pipeline has its (primary, CPU-fallback) backend pair.
        createPipelineBackend(iter, mRuntime);
        Pipeline::TuningAttr attr;
        attr.maxTuningNumber = mode.maxTuningNumber;
        attr.autoSetOpType = mode.backendMode == Interpreter::Session_Backend_Auto;
        auto rt = mRuntime.first.find(iter.first.info.type)->second.get();
        auto cpuRuntime = mRuntime.second;
        // The geometry-compute mask is only honored for loop-compiler runtimes.
        auto geoMask = mMode.geometryMask;
        if (rt->onGetCompilerType() != Runtime::Compiler_Loop) {
            geoMask = 0;
        }
        // NOTE: `iter` is moved into the Pipeline here and must not be used below.
        std::shared_ptr<Pipeline> newPipeline(new Pipeline( mInfo.externalWeightPath, std::move(iter), mode.inputMode == Interpreter::Session_Input_Inside, mode.outputMode == Interpreter::Session_Output_User, attr, rt, cpuRuntime.get(), geoMask));
        mPipelines.emplace_back(std::move(newPipeline));
    }
    mCallBackMode = mode.callBackMode;
    mMemoryUsageMode = mode.memoryUsageMode;
    mCodegenMode = mode.codegenMode;
}
|
|
|
|
|
|
|
|
Session::~Session() {
    // Flag cancellation first so pending async resize work can exit early.
    for (auto& iter : mRuntime.first) {
        iter.second->mCancelled = true;
    }
    waitAsyncResize();
    // NOTE(review): the teardown order below (tensors -> pipelines ->
    // runtimes) looks deliberate — later members appear to own resources
    // the earlier ones reference. Keep this order.
    mInfo.allTensors.clear();
    mPipelines.clear();
    mRuntime.first.clear();
    mRuntime.second = nullptr;
}
|
2022-12-30 15:18:58 +08:00
|
|
|
// Access the pipeline info at `index`.
// Bounds are checked only via MNN_ASSERT (debug builds); release builds
// rely on the caller passing a valid index.
Schedule::PipelineInfo& Session::getPipelineInfo(int index) const {
    MNN_ASSERT(index >= 0);
    // Cast avoids a signed/unsigned comparison warning; index is already
    // known non-negative from the assert above.
    MNN_ASSERT(index < static_cast<int>(mPipelines.size()));
    return mPipelines[index]->getPipelineInfo();
}
|
2020-11-05 16:41:56 +08:00
|
|
|
|
|
|
|
bool Session::loadCache(const void* buffer, size_t size) {
|
|
|
|
for (auto iter : mRuntime.first) {
|
|
|
|
auto res = iter.second->onSetCache(buffer, size);
|
|
|
|
if (res) {
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return false;
|
|
|
|
}
|
2022-01-04 10:50:40 +08:00
|
|
|
void Session::waitAsyncResize() {
|
|
|
|
for (auto& iter : mRuntime.first) {
|
|
|
|
iter.second->waitAsyncWork();
|
|
|
|
}
|
|
|
|
}
|
2020-11-05 16:41:56 +08:00
|
|
|
|
2023-01-11 15:08:58 +08:00
|
|
|
bool Session::hasAsyncWork() {
|
|
|
|
for (auto& iter : mRuntime.first) {
|
|
|
|
auto res = iter.second->hasAsyncWork();
|
|
|
|
if (res) {
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
2020-11-05 16:41:56 +08:00
|
|
|
std::pair<const void*, size_t> Session::getCache() {
|
2023-01-11 15:08:58 +08:00
|
|
|
// Set cancelled for quickly ending
|
|
|
|
for (auto& iter : mRuntime.first) {
|
|
|
|
iter.second->mCancelled = true;
|
|
|
|
}
|
2022-01-04 10:50:40 +08:00
|
|
|
waitAsyncResize();
|
2023-01-11 15:08:58 +08:00
|
|
|
|
2020-11-05 16:41:56 +08:00
|
|
|
for (auto iter : mRuntime.first) {
|
|
|
|
auto res = iter.second->onGetCache();
|
|
|
|
if (res.first != nullptr) {
|
|
|
|
return res;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return std::make_pair(nullptr, 0);
|
2019-04-17 10:49:11 +08:00
|
|
|
}
|
2021-11-30 10:10:53 +08:00
|
|
|
|
2019-04-17 10:49:11 +08:00
|
|
|
// Execute every pipeline in order; stop and report at the first failure.
// The session must have been resized successfully beforehand.
ErrorCode Session::run() const {
    if (mNeedResize) {
        MNN_ERROR("Can't run session because not resized\n");
        return COMPUTE_SIZE_ERROR;
    }
    for (auto& pipeline : mPipelines) {
        auto code = pipeline->execute();
        if (code != NO_ERROR) {
            return code;
        }
    }
    return NO_ERROR;
}
|
|
|
|
|
|
|
|
// Execute every pipeline, invoking `before` / `end` callbacks around each op.
// Stops at the first failing pipeline. (`sync` is currently unused here.)
ErrorCode Session::runWithCallBack(const TensorCallBackWithInfo& before, const TensorCallBackWithInfo& end,
                                   bool sync) const {
    if (mNeedResize) {
        MNN_ERROR("Can't run session because not resized\n");
        return COMPUTE_SIZE_ERROR;
    }
    for (auto& pipeline : mPipelines) {
        auto code = pipeline->executeCallBack(before, end);
        if (code != NO_ERROR) {
            return code;
        }
    }
    return NO_ERROR;
}
|
|
|
|
|
|
|
|
|
2022-12-30 15:18:58 +08:00
|
|
|
// Two-phase (re)size of the session:
//   1. If shapes changed (mNeedResize), re-encode every pipeline's commands.
//   2. If allocation is pending (mNeedMalloc), allocate memory for all
//      pipelines and optionally garbage-collect runtime pools.
// Either phase may be skipped when its flag is clear, so repeated calls
// are cheap no-ops once the session is fully prepared.
ErrorCode Session::resize() {
#ifdef LOG_VERBOSE
    for (auto& iter : mInfo.inputTensors) {
        auto& inputTensor = iter.second;
        MNN_PRINT("before resize, input name:%s, ptr:%p, hostPtr:%p, shape:", iter.first.c_str(), inputTensor, inputTensor->host<void>());
        inputTensor->printShape();
        MNN_PRINT("\n");
    }
#endif
    bool permitCodegen = mCodegenMode == Interpreter::Session_Codegen_Enable;

    bool firstMalloc = false;
    if (mNeedResize) {
        bool debug = mCallBackMode == Interpreter::Session_Debug;
        // Phase 1: re-encode command buffers for the new shapes.
        for (auto& iter : mPipelines) {
            auto error = iter->encode(debug, permitCodegen);
            if (NO_ERROR != error) {
                return error;
            }
        }
        mNeedResize = false;
        mNeedMalloc = true;
        firstMalloc = true;
    }
    if (mNeedMalloc) {
        // Set needResize = true for easy for judge in runSession when error
        mNeedResize = true;
        // Turn Pipeline to Command Buffer and Malloc resource
        // TODO: Separate Schedule and Malloc
        bool forbidReplace = permitCodegen;
        if (mInfo.constReplaceBackend != nullptr) {
            forbidReplace = true;
        }
        // Phase 2: allocate memory for every pipeline.
        for (auto& iter : mPipelines) {
            auto error = iter->allocMemory(firstMalloc, forbidReplace);
            if (NO_ERROR != error) {
                return error;
            }
        }
        // In collect mode, aggressively release cached runtime memory.
        if(mMemoryUsageMode == Interpreter::Session_Memory_Collect) {
            mRuntime.second->onGabageCollect(0);
            for (auto& iter : mRuntime.first) {
                iter.second->onGabageCollect(0);
            }
        }
        // Both phases succeeded: clear the flags set above.
        mNeedMalloc = false;
        mNeedResize = false;
    }

#ifdef LOG_VERBOSE
    MNN_PRINT("session after resize\n");
    for (auto& iter : mInfo.outputTensor) {
        auto& outputTensor = iter.second;
        MNN_PRINT("output name:%s, ptr:%p,shape:", iter.first.c_str(), outputTensor);
        outputTensor->printShape();
        MNN_PRINT("\n");
    }
#endif
    return NO_ERROR;
}
|
2024-04-19 11:58:21 +08:00
|
|
|
void Session::openResizeCheck() {
|
|
|
|
for (auto& iter : mPipelines) {
|
|
|
|
iter->openResizeCheck();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Ask every pipeline to fix (freeze) its resize cache.
// Stops and returns the first non-success code encountered.
ErrorCode Session::fixResizeCache() {
    for (auto& pipeline : mPipelines) {
        auto status = pipeline->fixResizeCache();
        if (status != NO_ERROR) {
            return status;
        }
    }
    return NO_ERROR;
}
|
|
|
|
|
2020-11-05 16:41:56 +08:00
|
|
|
// Query runtime information about the session and write it through `ptr`.
// The expected pointee type depends on `code` (float for MEMORY/FLOPS,
// int32_t array for BACKENDS, int for RESIZE_STATUS/THREAD_NUMBER).
// Returns true when the query was answered, false otherwise.
bool Session::getInfo(Interpreter::SessionInfoCode code, void* ptr) const {
    switch (code) {
        case Interpreter::MEMORY: {
            auto dst = (float*)ptr;
            // Sum memory across all runtimes, counting the CPU runtime
            // (mRuntime.second) exactly once.
            float summer = mRuntime.second->onGetMemoryInMB();
            for (auto& r : mRuntime.first) {
                if (r.second.get() != mRuntime.second.get()) {
                    summer += r.second->onGetMemoryInMB();
                }
            }
            *dst = summer;
            return true;
        } break;
        case Interpreter::BACKENDS: {
            // Caller must provide room for one int32_t per pipeline.
            int pos = 0;
            auto res = (int32_t*)ptr;
            for (auto& r : mPipelines) {
                auto type = r->getMainForwardType();
                res[pos++] = type;
            }
            return true;
        } break;
        case Interpreter::FLOPS: {
            float flo = 0.0f;
            for (auto& iter : mPipelines) {
                flo += iter->flops();
            }
            auto dst = (float*)ptr;
            *dst = flo;
            return true;
        } break;
        case Interpreter::RESIZE_STATUS: {
            // 2 = resize needed, 1 = malloc needed, 0 = ready to run.
            auto dst = (int*)ptr;
            if (mNeedResize) {
                *dst = 2;
            } else if (mNeedMalloc) {
                *dst = 1;
            } else {
                *dst = 0;
            }
            return true;
        } break;
        case Interpreter::THREAD_NUMBER: {
            auto dst = (int*)ptr;
            // No pipelines -> fall through to the final `return false`.
            if (mPipelines.empty()) {
                break;
            }
            *dst = mPipelines[0]->getPipelineInfo().first.info.numThread;
            return true;
        }
        // TODO: Support other debug info
        default:
            break;
    }
    return false;
}
|
2019-04-17 10:49:11 +08:00
|
|
|
|
|
|
|
// Look up the backend that owns `tensor`.
// The owning backend is recorded on the tensor's describe-origin,
// not on the Tensor object itself.
const Backend* Session::getBackEnd(const Tensor* tensor) const {
    auto origin = TensorUtils::getDescribeOrigin(tensor);
    return origin->getBackend();
}
|
|
|
|
|
|
|
|
// Look up an input tensor by name.
// A null name returns the first input (map iteration order); an unknown
// name logs an error and returns nullptr.
Tensor* Session::getInput(const char* name) const {
    if (name == nullptr) {
        return mInfo.inputTensors.begin()->second;
    }
    auto found = mInfo.inputTensors.find(name);
    if (found != mInfo.inputTensors.end()) {
        return found->second;
    }
    MNN_PRINT("Error: can't find input: %s\n", name);
    return nullptr;
}
|
2022-12-30 15:18:58 +08:00
|
|
|
// Fetch a tensor by its global index in the schedule's tensor table.
// NOTE(review): no bounds check — callers must pass a valid index.
Tensor* Session::getTensor(int index) const {
    return mInfo.allTensors[index].get();
}
|
2019-04-17 10:49:11 +08:00
|
|
|
|
|
|
|
// Look up an output tensor by name.
// A null name returns the first output (map iteration order); an unknown
// name logs an error and returns nullptr.
Tensor* Session::getOutput(const char* name) const {
    MNN_ASSERT(!mInfo.outputTensor.empty());
    if (name == nullptr) {
        return mInfo.outputTensor.begin()->second;
    }

    auto found = mInfo.outputTensor.find(name);
    if (found != mInfo.outputTensor.end()) {
        return found->second;
    }
    MNN_PRINT("Error: can't find output: %s\n", name);
    return nullptr;
}
|
|
|
|
|
|
|
|
// Expose the full name -> tensor map of session inputs (no copy).
const std::map<std::string, Tensor*>& Session::getInputAll() const {
    return mInfo.inputTensors;
}
|
|
|
|
|
|
|
|
// Expose the full name -> tensor map of session outputs (no copy).
const std::map<std::string, Tensor*>& Session::getOutputAll() const {
    return mInfo.outputTensor;
}
|
|
|
|
|
2019-06-17 20:10:35 +08:00
|
|
|
// Write the session's current (e.g. trained) parameter values back into the
// model's flatbuffer, in place. Only single-output float Const /
// TrainableParam ops are updated; everything else is skipped.
// Not supported while a resize is pending (tensor contents may be stale).
ErrorCode Session::updateToModel(Net* net) const {
    if (mNeedResize) {
        return NOT_SUPPORT;
    }
    int opSize = net->oplists()->size();
    for (int i = 0; i < opSize; ++i) {
        auto op = net->oplists()->GetAs<Op>(i);
        if (op->type() != OpType_Const && op->type() != OpType_TrainableParam) {
            continue;
        }
        // Only single-output parameter ops map 1:1 onto a tensor.
        if (!op->outputIndexes() || op->outputIndexes()->size() != 1) {
            continue;
        }
        auto index = op->outputIndexes()->data()[0];
        auto blob = op->main_as_Blob();
        if (blob->dataType() != DataType_DT_FLOAT) {
            continue;
        }
        std::shared_ptr<Tensor> tensor = mInfo.allTensors[index];
        // If the tensor lives on a device backend, copy it to host first.
        if (WrapExecution::needWrap(tensor.get(), nullptr)) {
            tensor.reset(Tensor::createHostTensorFromDevice(tensor.get(), true));
            if (tensor.get() == nullptr) {
                MNN_ERROR("failed to copy trained param from device to host\n");
                return INVALID_VALUE;
            }
        }
        // In-place mutation of the flatbuffer's float payload; the cast
        // drops const because the buffer is owned and writable here.
        ::memcpy((void*)blob->float32s()->data(), tensor->host<float>(), tensor->size());
    }

    return NO_ERROR;
}
|
|
|
|
|
2022-12-30 15:18:58 +08:00
|
|
|
static void initTensors(std::vector<std::shared_ptr<Tensor>>& tensors, const std::vector<std::shared_ptr<Tensor>>& tensorSrc) {
|
|
|
|
for (int i=0; i<tensors.size(); ++i) {
|
2024-08-24 15:46:21 +08:00
|
|
|
if (tensorSrc[i].get() == nullptr) {
|
|
|
|
continue;
|
|
|
|
}
|
2022-12-30 15:18:58 +08:00
|
|
|
// Init all tensor except for const
|
|
|
|
if (tensors[i].get() == nullptr) {
|
|
|
|
tensors[i].reset(new Tensor);
|
|
|
|
TensorUtils::getDescribe(tensors[i].get())->index = i;
|
|
|
|
}
|
|
|
|
auto srcDes = TensorUtils::getDescribe(tensorSrc[i].get());
|
|
|
|
if (srcDes->quantAttr != nullptr) {
|
|
|
|
TensorUtils::getDescribe(tensors[i].get())->quantAttr.reset(new QuantAttr);
|
|
|
|
*TensorUtils::getDescribe(tensors[i].get())->quantAttr = *srcDes->quantAttr;
|
|
|
|
}
|
2023-09-04 10:42:11 +08:00
|
|
|
if (TensorUtils::getDescribe(tensors[i].get())->usage != Tensor::InsideDescribe::CONSTANT) {
|
|
|
|
TensorUtils::copyShape(tensorSrc[i].get(), tensors[i].get(), true);
|
|
|
|
}
|
2022-12-30 15:18:58 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
Session* Session::clone(RuntimeInfo&& runtime, std::shared_ptr<Schedule::ScheduleInfo> sharedConst) {
|
|
|
|
// TODO: Currently only valid for Module api's onClone
|
|
|
|
Schedule::ScheduleInfo scheduleInfo;
|
|
|
|
scheduleInfo.defaultBackend = mInfo.defaultBackend;
|
|
|
|
scheduleInfo.pipelineInfo.resize(1);
|
2024-04-19 11:58:21 +08:00
|
|
|
scheduleInfo.externalWeightPath = mInfo.externalWeightPath;
|
2022-12-30 15:18:58 +08:00
|
|
|
Session::ModeGroup modes;
|
|
|
|
scheduleInfo.defaultBackend = sharedConst->defaultBackend;
|
2023-09-04 10:42:11 +08:00
|
|
|
scheduleInfo.constReplaceBackend = sharedConst->constReplaceBackend;
|
2022-12-30 15:18:58 +08:00
|
|
|
scheduleInfo.allTensors = sharedConst->allTensors;
|
|
|
|
initTensors(scheduleInfo.allTensors, mInfo.allTensors);
|
|
|
|
MNN_ASSERT(1 == mPipelines.size());
|
|
|
|
auto& srcPipelineInfo = mPipelines[0]->getPipelineInfo();
|
|
|
|
auto& opCaches = srcPipelineInfo.second;
|
|
|
|
auto& pipelineInfo = scheduleInfo.pipelineInfo[0];
|
|
|
|
pipelineInfo.first.info = srcPipelineInfo.first.info;
|
|
|
|
pipelineInfo.first.config = srcPipelineInfo.first.config;
|
|
|
|
pipelineInfo.first.info.user = &pipelineInfo.first.config;
|
|
|
|
auto& oplists = pipelineInfo.second;
|
|
|
|
oplists.resize(opCaches.size());
|
2024-09-12 12:57:57 +08:00
|
|
|
createPipelineBackend(pipelineInfo, runtime);
|
2022-12-30 15:18:58 +08:00
|
|
|
auto first = pipelineInfo.first.cache.first;
|
|
|
|
auto second = pipelineInfo.first.cache.second;
|
|
|
|
for (int i=0; i<opCaches.size(); ++i) {
|
|
|
|
auto& srcOpInfo = opCaches[i];
|
|
|
|
auto& opInfo = oplists[i];
|
|
|
|
opInfo.op = opCaches[i].op;
|
|
|
|
opInfo.type = srcOpInfo.type;
|
2024-04-23 13:54:38 +08:00
|
|
|
opInfo.computeCache.copyImmutable(srcOpInfo.computeCache);
|
2022-12-30 15:18:58 +08:00
|
|
|
auto op = opInfo.op;
|
|
|
|
if (nullptr != op->outputIndexes()) {
|
|
|
|
auto data = op->outputIndexes()->data();
|
|
|
|
for (int j = 0; j < op->outputIndexes()->size(); ++j) {
|
|
|
|
opInfo.outputs.push_back(scheduleInfo.allTensors[data[j]].get());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (nullptr != op->inputIndexes()) {
|
|
|
|
auto data = op->inputIndexes()->data();
|
|
|
|
for (int j = 0; j < op->inputIndexes()->size(); ++j) {
|
|
|
|
opInfo.inputs.push_back(scheduleInfo.allTensors[data[j]].get());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (int j=0; j<opInfo.inputs.size(); ++j) {
|
2023-09-04 10:42:11 +08:00
|
|
|
if (TensorUtils::getDescribe(opInfo.inputs[j])->usage != Tensor::InsideDescribe::CONSTANT) {
|
|
|
|
TensorUtils::getDescribe(opInfo.inputs[j])->usage = TensorUtils::getDescribe(srcOpInfo.inputs[j])->usage;
|
|
|
|
}
|
2022-12-30 15:18:58 +08:00
|
|
|
}
|
|
|
|
for (int j=0; j<opInfo.outputs.size(); ++j) {
|
|
|
|
TensorUtils::getDescribe(opInfo.outputs[j])->usage = TensorUtils::getDescribe(srcOpInfo.outputs[j])->usage;
|
|
|
|
}
|
|
|
|
// Clone cache
|
|
|
|
for (auto& iter : srcOpInfo.executionCache) {
|
|
|
|
Execution* copyExecution = nullptr;
|
|
|
|
bool valid = false;
|
|
|
|
if (first->type() == iter.second->backend()->type()) {
|
|
|
|
valid = iter.second->onClone(first.get(), iter.first, ©Execution);
|
|
|
|
} else {
|
|
|
|
valid = iter.second->onClone(second.get(), iter.first, ©Execution);
|
|
|
|
}
|
|
|
|
if (valid) {
|
|
|
|
std::shared_ptr<Execution> copyExeWrap(copyExecution);
|
|
|
|
opInfo.executionCache.insert(std::make_pair(iter.first, copyExeWrap));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto dst = new Session(std::move(scheduleInfo), mMode, std::move(runtime));
|
|
|
|
return dst;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
2019-04-17 10:49:11 +08:00
|
|
|
} // namespace MNN
|