| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | //
 | 
					
						
							|  |  |  | //  Session.cpp
 | 
					
						
							|  |  |  | //  MNN
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | //  Created by MNN on 2018/07/30.
 | 
					
						
							|  |  |  | //  Copyright © 2018, Alibaba Group Holding Limited
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | #include "core/Session.hpp"
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | #include <string.h>
 | 
					
						
							| 
									
										
										
										
											2020-03-22 20:16:29 +08:00
										 |  |  | #include <MNN/AutoTime.hpp>
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | #include <map>
 | 
					
						
							|  |  |  | #include <set>
 | 
					
						
							| 
									
										
										
										
											2020-03-22 20:16:29 +08:00
										 |  |  | #include "MNN_generated.h"
 | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | #include "core/AutoStorage.h"
 | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | #include "core/RuntimeFactory.hpp"
 | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  | #include "core/TensorUtils.hpp"
 | 
					
						
							|  |  |  | #include "core/WrapExecution.hpp"
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | using namespace std; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace MNN { | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | Session::Session(Schedule::ScheduleInfo&& info, Interpreter::SessionMode callBackMode, | 
					
						
							|  |  |  |                  Interpreter::SessionMode inputMode, RuntimeInfo&& runtime) { | 
					
						
							|  |  |  |     mRuntime = std::move(runtime); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     if (info.pipelineInfo.empty()) { | 
					
						
							|  |  |  |         mValid = false; | 
					
						
							|  |  |  |         return; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |     Backend::Info defaultInfo; | 
					
						
							|  |  |  |     defaultInfo.type      = MNN_FORWARD_CPU; | 
					
						
							|  |  |  |     defaultInfo.numThread = 1; | 
					
						
							|  |  |  |     mTensors              = std::move(info.allTensors); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     for (auto& iter : info.pipelineInfo) { | 
					
						
							| 
									
										
										
										
											2020-11-18 10:48:38 +08:00
										 |  |  |         auto rt    = mRuntime.first.find(iter.first.type)->second.get(); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         auto cpuRuntime = mRuntime.second; | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         std::shared_ptr<Backend> first(rt->onCreate(iter.first.user)); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         std::shared_ptr<Backend> second; | 
					
						
							|  |  |  |         if (first->type() == MNN_FORWARD_CPU) { | 
					
						
							|  |  |  |             second = first; | 
					
						
							|  |  |  |         } else { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |             BackendConfig defaultConfig; | 
					
						
							|  |  |  |             second.reset(cpuRuntime->onCreate(&defaultConfig)); | 
					
						
							| 
									
										
										
										
											2020-03-22 20:16:29 +08:00
										 |  |  |         } | 
					
						
							| 
									
										
										
										
											2020-11-18 10:48:38 +08:00
										 |  |  |         std::shared_ptr<Pipeline> newPipeline(new Pipeline(std::move(iter.second), first, second, inputMode == Interpreter::Session_Input_Inside, rt->onGetCompilerType() == Runtime::Compiler_Geometry)); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         mPipelines.emplace_back(std::move(newPipeline)); | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |     mInputs       = std::move(info.inputTensors); | 
					
						
							|  |  |  |     mOutputs      = std::move(info.outputTensor); | 
					
						
							|  |  |  |     mCallBackMode = callBackMode; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | Session::~Session() { | 
					
						
							|  |  |  |     for (auto& t : mTensors) { | 
					
						
							|  |  |  |         TensorUtils::clearHandleData(t.second.get()); | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2019-09-01 19:25:26 +08:00
										 |  |  |     mPipelines.clear(); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |     mRuntime.first.clear(); | 
					
						
							| 
									
										
										
										
											2019-09-01 19:25:26 +08:00
										 |  |  |     mTensors.clear(); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |     mRuntime.second = nullptr; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | bool Session::loadCache(const void* buffer, size_t size) { | 
					
						
							|  |  |  |     for (auto iter : mRuntime.first) { | 
					
						
							|  |  |  |         auto res = iter.second->onSetCache(buffer, size); | 
					
						
							|  |  |  |         if (res) { | 
					
						
							|  |  |  |             return true; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return false; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | std::pair<const void*, size_t> Session::getCache() { | 
					
						
							|  |  |  |     for (auto iter : mRuntime.first) { | 
					
						
							|  |  |  |         auto res = iter.second->onGetCache(); | 
					
						
							|  |  |  |         if (res.first != nullptr) { | 
					
						
							|  |  |  |             return res; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return std::make_pair(nullptr, 0); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | } | 
					
						
							| 
									
										
										
										
											2021-01-06 16:29:37 +08:00
										 |  |  | void Session::cloneExecution(const std::map<const Op*, std::shared_ptr<Execution>>& cache, int pipelineIndex) { | 
					
						
							|  |  |  |     mPipelines[pipelineIndex]->cloneExecution(cache); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | const std::map<const Op*, std::shared_ptr<Execution>>& Session::getExecution(int pipelineIndex) { | 
					
						
							|  |  |  |     return mPipelines[pipelineIndex]->getCache(); | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | ErrorCode Session::run() const { | 
					
						
							| 
									
										
										
										
											2019-08-08 14:42:14 +08:00
										 |  |  |     if (mNeedResize) { | 
					
						
							| 
									
										
										
										
											2020-03-22 20:16:29 +08:00
										 |  |  |         MNN_ERROR("Can't run session because not resized\n"); | 
					
						
							| 
									
										
										
										
											2019-08-08 14:42:14 +08:00
										 |  |  |         return COMPUTE_SIZE_ERROR; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     for (auto& iter : mPipelines) { | 
					
						
							|  |  |  |         auto error = iter->execute(); | 
					
						
							|  |  |  |         if (NO_ERROR != error) { | 
					
						
							|  |  |  |             return error; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return NO_ERROR; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ErrorCode Session::runWithCallBack(const TensorCallBackWithInfo& before, const TensorCallBackWithInfo& end, | 
					
						
							|  |  |  |                                    bool sync) const { | 
					
						
							| 
									
										
										
										
											2019-08-08 14:42:14 +08:00
										 |  |  |     if (mNeedResize) { | 
					
						
							| 
									
										
										
										
											2020-03-22 20:16:29 +08:00
										 |  |  |         MNN_ERROR("Can't run session because not resized\n"); | 
					
						
							| 
									
										
										
										
											2019-08-08 14:42:14 +08:00
										 |  |  |         return COMPUTE_SIZE_ERROR; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     for (auto& iter : mPipelines) { | 
					
						
							|  |  |  |         auto error = iter->executeCallBack(before, end); | 
					
						
							|  |  |  |         if (NO_ERROR != error) { | 
					
						
							|  |  |  |             return error; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return NO_ERROR; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void Session::_clearCache() { | 
					
						
							|  |  |  |     for (auto& t : mTensors) { | 
					
						
							|  |  |  |         auto describe = TensorUtils::getDescribe(t.second.get()); | 
					
						
							|  |  |  |         TensorUtils::clearHandleData(t.second.get()); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         describe->useCount = 0; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         describe->backend  = nullptr; | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         describe->regions.clear(); | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | ErrorCode Session::resize(bool isStatic) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |     if (mNeedResize) { | 
					
						
							|  |  |  |         if (!isStatic) { | 
					
						
							|  |  |  |             _clearCache(); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         } | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         bool debug = mCallBackMode == Interpreter::Session_Debug; | 
					
						
							|  |  |  |         for (auto& iter : mPipelines) { | 
					
						
							|  |  |  |             auto error = iter->encode(isStatic, debug); | 
					
						
							|  |  |  |             if (NO_ERROR != error) { | 
					
						
							|  |  |  |                 return error; | 
					
						
							|  |  |  |             } | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |         } | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         mNeedResize = false; | 
					
						
							|  |  |  |         mNeedMalloc = true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     if (mNeedMalloc) { | 
					
						
							|  |  |  |         // Set needResize = true for easy for judge in runSession when error
 | 
					
						
							|  |  |  |         mNeedResize = true; | 
					
						
							|  |  |  |         // Turn Pipeline to Command Buffer and Malloc resource
 | 
					
						
							|  |  |  |         // TODO: Seperate Schedule and Malloc
 | 
					
						
							|  |  |  |         for (auto& iter : mPipelines) { | 
					
						
							|  |  |  |             auto error = iter->allocMemory(); | 
					
						
							|  |  |  |             if (NO_ERROR != error) { | 
					
						
							|  |  |  |                 return error; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (auto& iter : mRuntime.first) { | 
					
						
							|  |  |  |             iter.second->onGabageCollect(0); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         mNeedMalloc = false; | 
					
						
							|  |  |  |         mNeedResize = false; | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     } | 
					
						
							|  |  |  |     return NO_ERROR; | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | bool Session::getInfo(Interpreter::SessionInfoCode code, void* ptr) const { | 
					
						
							|  |  |  |     switch (code) { | 
					
						
							|  |  |  |         case Interpreter::MEMORY: { | 
					
						
							|  |  |  |             auto dst     = (float*)ptr; | 
					
						
							|  |  |  |             float summer = mRuntime.second->onGetMemoryInMB(); | 
					
						
							|  |  |  |             for (auto& r : mRuntime.first) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |                 if (r.second.get() != mRuntime.second.get()) { | 
					
						
							|  |  |  |                     summer += r.second->onGetMemoryInMB(); | 
					
						
							|  |  |  |                 } | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |             } | 
					
						
							|  |  |  |             *dst = summer; | 
					
						
							|  |  |  |             return true; | 
					
						
							|  |  |  |         } break; | 
					
						
							|  |  |  |         // TODO: Support other debug info
 | 
					
						
							|  |  |  |         default: | 
					
						
							|  |  |  |             break; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return false; | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | const Backend* Session::getBackEnd(const Tensor* tensor) const { | 
					
						
							|  |  |  |     return TensorUtils::getDescribe(tensor)->backend; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | Tensor* Session::getInput(const char* name) const { | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  |     //MNN_ASSERT(!mInputs.empty());
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  |     if (nullptr == name) { | 
					
						
							|  |  |  |         return mInputs.begin()->second; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     auto iter = mInputs.find(name); | 
					
						
							|  |  |  |     if (iter == mInputs.end()) { | 
					
						
							|  |  |  |         MNN_PRINT("Error: can't find input: %s\n", name); | 
					
						
							|  |  |  |         return nullptr; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return iter->second; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | Tensor* Session::getOutput(const char* name) const { | 
					
						
							|  |  |  |     MNN_ASSERT(!mOutputs.empty()); | 
					
						
							|  |  |  |     if (nullptr == name) { | 
					
						
							|  |  |  |         return mOutputs.begin()->second; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     auto iter = mOutputs.find(name); | 
					
						
							|  |  |  |     if (iter == mOutputs.end()) { | 
					
						
							|  |  |  |         MNN_PRINT("Error: can't find output: %s\n", name); | 
					
						
							|  |  |  |         return nullptr; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return iter->second; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | const std::map<std::string, Tensor*>& Session::getInputAll() const { | 
					
						
							|  |  |  |     return mInputs; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | const std::map<std::string, Tensor*>& Session::getOutputAll() const { | 
					
						
							|  |  |  |     return mOutputs; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-06-17 20:10:35 +08:00
										 |  |  | ErrorCode Session::updateToModel(Net* net) const { | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  |     if (mNeedResize) { | 
					
						
							|  |  |  |         return NOT_SUPPORT; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2019-06-17 20:10:35 +08:00
										 |  |  |     int opSize = net->oplists()->size(); | 
					
						
							|  |  |  |     for (int i = 0; i < opSize; ++i) { | 
					
						
							|  |  |  |         auto op = net->oplists()->GetAs<Op>(i); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         if ((net->usage() == Usage_INFERENCE || net->usage() == Usage_INFERENCE_STATIC) && op->type() != OpType_Const) { | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (net->usage() == Usage_TRAIN && op->type() != OpType_TrainableParam) { | 
					
						
							| 
									
										
										
										
											2019-06-17 20:10:35 +08:00
										 |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (!op->outputIndexes() || op->outputIndexes()->size() != 1) { | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         auto index = op->outputIndexes()->data()[0]; | 
					
						
							|  |  |  |         auto blob  = op->main_as_Blob(); | 
					
						
							|  |  |  |         if (blob->dataType() != DataType_DT_FLOAT) { | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2019-12-27 22:16:57 +08:00
										 |  |  |         std::shared_ptr<Tensor> tensor = mTensors[index].second; | 
					
						
							|  |  |  |         if (tensor->host<void>() == nullptr && tensor->deviceId() != 0) { | 
					
						
							|  |  |  |             tensor.reset(Tensor::createHostTensorFromDevice(tensor.get(), true)); | 
					
						
							|  |  |  |             if (tensor.get() == nullptr) { | 
					
						
							|  |  |  |                 MNN_ERROR("failed to copy trained param from device to host\n"); | 
					
						
							|  |  |  |                 return INVALID_VALUE; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         ::memcpy((void*)blob->float32s()->data(), tensor->host<float>(), tensor->size()); | 
					
						
							| 
									
										
										
										
											2019-06-17 20:10:35 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return NO_ERROR; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-04-17 10:49:11 +08:00
										 |  |  | } // namespace MNN
 |