| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | //
 | 
					
						
							|  |  |  | //  Module.cpp
 | 
					
						
							|  |  |  | //  MNN
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | //  Created by MNN on 2019/11/25.
 | 
					
						
							|  |  |  | //  Copyright © 2018, Alibaba Group Holding Limited
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include <MNN/expr/Module.hpp>
 | 
					
						
							|  |  |  | #include <MNN/expr/ExprCreator.hpp>
 | 
					
						
							| 
									
										
										
										
											2022-07-22 09:59:30 +08:00
										 |  |  | #include <MNN/expr/ExecutorScope.hpp>
 | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | #include "PipelineModule.hpp"
 | 
					
						
							|  |  |  | #include "core/FileLoader.hpp"
 | 
					
						
							| 
									
										
										
										
											2022-09-30 10:02:52 +08:00
										 |  |  | #include "backend/cpu/CPUBackend.hpp"
 | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  | #include "MNN_generated.h"
 | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  | #include "Utils.hpp"
 | 
					
						
							| 
									
										
										
										
											2022-06-24 18:30:05 +08:00
										 |  |  | #include "RuntimeAttr.hpp"
 | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-07-22 09:59:30 +08:00
										 |  |  | #include <MNN/AutoTime.hpp>
 | 
					
						
							| 
									
										
										
										
											2022-01-04 10:50:40 +08:00
										 |  |  | #ifdef MNN_INTERNAL_ENABLED
 | 
					
						
							|  |  |  | #include "internal/auth/ModelAuth.hpp"
 | 
					
						
							|  |  |  | #include "internal/logging/Log.hpp"
 | 
					
						
							|  |  |  | #include "internal/logging/LogHelper.hpp"
 | 
					
						
							|  |  |  | #endif // MNN_INTERNAL_ENABLED
 | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | namespace MNN { | 
					
						
							|  |  |  | namespace Express { | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-05-06 19:51:20 +08:00
										 |  |  | static Module* loadInternal(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const std::shared_ptr<MNN::Express::Executor::RuntimeManager> _rtMgr, const Module::Config* config, bool enforceAuth); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | class EmptyModule : public Module { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |     EmptyModule(const std::vector<Express::VARP>& parameters) { | 
					
						
							|  |  |  |         for (auto p : parameters) { | 
					
						
							|  |  |  |             addParameter(p); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     virtual ~EmptyModule() { | 
					
						
							|  |  |  |         // Do nothing
 | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override { | 
					
						
							|  |  |  |         return {}; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | protected: | 
					
						
							|  |  |  |     EmptyModule() = default; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Module* clone(Module::CloneContext* ctx) const override { | 
					
						
							|  |  |  |         EmptyModule* module(new EmptyModule); | 
					
						
							|  |  |  |         return this->cloneBaseTo(ctx, module); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2022-08-12 10:30:48 +08:00
										 |  |  | void Module::destroy(Module* m) { | 
					
						
							|  |  |  |     if (nullptr != m) { | 
					
						
							|  |  |  |         delete m; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | Module* Module::createEmpty(const std::vector<Express::VARP>& parameters) { | 
					
						
							|  |  |  |     return new EmptyModule(parameters); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | Express::VARP Module::forward(Express::VARP input) { | 
					
						
							|  |  |  |     return this->onForward({input})[0]; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | std::vector<Express::VARP> Module::parameters() const { | 
					
						
							|  |  |  |     std::vector<Express::VARP> result; | 
					
						
							|  |  |  |     _collectParameters(result); | 
					
						
							|  |  |  |     return result; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | bool Module::loadParameters(const std::vector<Express::VARP>& parameters) { | 
					
						
							|  |  |  |     std::vector<Express::VARP> result; | 
					
						
							|  |  |  |     _collectParameters(result); | 
					
						
							|  |  |  |     if (parameters.empty() || parameters.size() != result.size()) { | 
					
						
							|  |  |  |         MNN_ERROR("Error parameters, empty or parameter size not match \n"); | 
					
						
							|  |  |  |         return false; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     for (int i=0; i<parameters.size(); ++i) { | 
					
						
							|  |  |  |         if (nullptr != result[i].get()) { | 
					
						
							|  |  |  |             // Check Origin parameter's size
 | 
					
						
							|  |  |  |             auto dstInfo = result[i]->getInfo(); | 
					
						
							|  |  |  |             auto srcInfo = parameters[i]->getInfo(); | 
					
						
							|  |  |  |             if (dstInfo->dim.size() != srcInfo->dim.size() || dstInfo->order != srcInfo->order) { | 
					
						
							|  |  |  |                 MNN_ERROR("Error parameters %d, dim size or order not match \n", i); | 
					
						
							|  |  |  |                 return false; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             if (dstInfo->size != srcInfo->size || dstInfo->type != srcInfo->type) { | 
					
						
							|  |  |  |                 MNN_ERROR("Error parameters %d, size or type not match \n", i); | 
					
						
							|  |  |  |                 return false; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         Variable::replace(result[i], parameters[i]); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return true; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | void Module::setIsTraining(const bool isTraining) { | 
					
						
							|  |  |  |     mIsTraining = isTraining; | 
					
						
							|  |  |  |     for (auto c : mChildren) { | 
					
						
							|  |  |  |         c->setIsTraining(isTraining); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | bool Module::getIsTraining() { | 
					
						
							|  |  |  |     return mIsTraining; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void Module::registerModel(const std::vector<std::shared_ptr<Module>>& children) { | 
					
						
							|  |  |  |     mChildren.insert(mChildren.begin(), children.begin(), children.end()); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | int Module::addParameter(VARP parameter) { | 
					
						
							|  |  |  |     auto res = mParameters.size(); | 
					
						
							|  |  |  |     mParameters.emplace_back(parameter); | 
					
						
							|  |  |  |     return (int)res; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void Module::setParameter(Express::VARP parameter, int index) { | 
					
						
							|  |  |  |     if (index < 0 || index >= mParameters.size()) { | 
					
						
							|  |  |  |         MNN_ERROR("Module error: index out of range: %d - %d:\n", index, (int)mParameters.size()); | 
					
						
							|  |  |  |         return; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     mParameters[index] = parameter; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void Module::_collectParameters(std::vector<Express::VARP>& result) const { | 
					
						
							|  |  |  |     for (auto p : mParameters) { | 
					
						
							|  |  |  |         result.push_back(p); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     for (auto c : mChildren) { | 
					
						
							|  |  |  |         c->_collectParameters(result); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | void Module::clearCache() { | 
					
						
							|  |  |  |     for (auto c : mChildren) { | 
					
						
							|  |  |  |         c->clearCache(); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     this->onClearCache(); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-12-15 14:12:35 +08:00
										 |  |  | Module* Module::load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const char* fileName, const Module::Config* config) { | 
					
						
							| 
									
										
										
										
											2021-09-18 15:52:30 +08:00
										 |  |  |     return load(inputs, outputs, fileName, nullptr, config); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | Module* Module::load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const Module::Config* config) { | 
					
						
							|  |  |  |     return load(inputs, outputs, buffer, length, nullptr, config); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | Module* Module::load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const char* fileName, const std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtMgr, const Module::Config* config) { | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |     AutoStorage<uint8_t> buffer; | 
					
						
							|  |  |  |     { | 
					
						
							|  |  |  |         FileLoader loader(fileName); | 
					
						
							|  |  |  |         if (!loader.valid()) { | 
					
						
							|  |  |  |             MNN_ERROR("Error for open %s\n", fileName); | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |             return nullptr; | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         } | 
					
						
							|  |  |  |         loader.read(); | 
					
						
							|  |  |  |         if (!loader.valid()) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |             return nullptr; | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         } | 
					
						
							|  |  |  |         loader.merge(buffer); | 
					
						
							|  |  |  |         if (buffer.get() == nullptr) { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |             return nullptr; | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2021-09-18 15:52:30 +08:00
										 |  |  |     return load(inputs, outputs, buffer.get(), buffer.size(), rtMgr, config); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | } | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
// Top-level module produced by Module::load: wraps the actual pipeline
// submodule together with its Module::Info (exposed via Module::getInfo),
// and — when MNN_INTERNAL_ENABLED — gathers telemetry for internal logging.
class NetModule : public Module {
public:
    // m: the loaded submodule the forward call delegates to.
    // info: metadata (version, input shapes, runtime manager) kept alive here.
    // net: serialized net, used only to seed logging data; nullptr when cloning.
    // size: model size in bytes; costTime: load time in ms (logging only).
    NetModule(std::shared_ptr<Module> m, std::shared_ptr<Module::Info> info, const MNN::Net* net, size_t size, float costTime) {
        mModule = m;
        mInfo = info;
        setType("Net");
#ifdef MNN_INTERNAL_ENABLED
        if (nullptr != net) {
            mLogInfo = getBasicLoggingData();
            std::string uuid = std::string(net->mnn_uuid() ? net->mnn_uuid()->c_str() : "");
            mLogInfo.emplace("UUID", uuid);
            mLogInfo.emplace("ModelVersion", info->version);
            // Defaults when no runtime manager is attached.
            int backend = MNN_FORWARD_CPU;
            int precision = BackendConfig::Precision_Normal;
            int mode = 1;
            if (info->runTimeManager.get() != nullptr) {
                auto attr = info->runTimeManager->getInside();
                // NOTE(review): "Mode" is logged as the thread count here.
                mode = attr->mNumberThread;
                int backendTypes[MNN_FORWARD_ALL];
                info->runTimeManager->getInfo(Interpreter::BACKENDS, &backendTypes);
                // Only the first backend type is reported.
                backend = backendTypes[0];
                auto config = info->runTimeManager->getBnConfig();
                if (nullptr != config) {
                    precision = config->precision;
                }
            }
            mLogInfo.emplace("Backend",  std::to_string(backend));
            mLogInfo.emplace("Mode",  std::to_string(mode));
            mLogInfo.emplace("Precision", std::to_string(precision));
            if (shouldLog(FREQ_HIGH)) {
                // One-shot "model loaded" event with load time and size.
                std::map<std::string, std::string> metrics = mLogInfo;
                metrics.emplace("Time", std::to_string(costTime));
                auto sizeInMB = (float)size / 1024.0f / 1024.0f;
                metrics.emplace("ModelSize",  std::to_string(sizeInMB));
                metrics.emplace("API", "Express::Module::NetModule");
                logAsync(metrics);
            }
        }
#endif // MNN_INTERNAL_ENABLED
    }
    virtual ~ NetModule(){
        // Release the wrapped module and info first, then ask the current
        // executor to reclaim all cached resources.
        mModule.reset();
        mInfo.reset();
        auto exe = ExecutorScope::Current();
        exe->gc(Executor::FULL);
    }

    // Delegates to the wrapped module; with MNN_INTERNAL_ENABLED it also
    // times the call and logs flops/memory metrics, then clears the
    // submodule's per-call cache.
    virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override {

#ifdef MNN_INTERNAL_ENABLED
        auto glo = ExecutorScope::Current();
        Timer _time;
        // Reset the flop counter so the sum below covers only this call.
        glo->getDebugTools()->flops = 0.0f;
#endif
        auto outputs = mModule->onForward(inputs);
#ifdef MNN_INTERNAL_ENABLED
        do {
            if (outputs.empty()) {
                break;
            }
            if (!shouldLog(FREQ_LOW)) {
                break;
            }
            // Wait for async backends to finish so the measured time
            // covers the full computation, not just the dispatch.
            for (auto& v : outputs) {
                auto t = Utils::getTensor(v);
                t->wait(Tensor::MAP_TENSOR_READ, true);
            }
            auto metrics = mLogInfo;
            metrics.emplace("Time", std::to_string((float)_time.durationInUs() / 1000.0f));
            metrics.emplace("API", "NetModule::onForward");
            if (mInfo->runTimeManager.get() != nullptr) {
                float memory = 0.0f;
                mInfo->runTimeManager->getInfo(Interpreter::MEMORY, &memory);
                metrics.emplace("Flops", std::to_string(glo->getDebugTools()->flops));
                metrics.emplace("Memory", std::to_string(memory));
            }
            logAsync(metrics);
        } while(false);
#endif
        mModule->clearCache();
        return outputs;
    }
    // Deep-copies the wrapped submodule; passes net=nullptr so the clone
    // does not re-emit the "model loaded" log event, and copies the
    // already-built log info instead.
    virtual Module* clone(CloneContext* ctx) const override {
        std::shared_ptr<Module> submodule(mModule->clone(ctx));

        NetModule* module(new NetModule(submodule, mInfo, nullptr, 0, 0.0f));
#ifdef MNN_INTERNAL_ENABLED
        module->mLogInfo = mLogInfo;
#endif
        return this->cloneBaseTo(ctx, module);
    }
    // Backing accessor for Module::getInfo().
    const Module::Info* info() const {
        return mInfo.get();
    }

private:
    std::shared_ptr<Module> mModule;
    std::shared_ptr<Module::Info> mInfo;
#ifdef MNN_INTERNAL_ENABLED
    std::map<std::string, std::string> mLogInfo;
#endif
};
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | const Module::Info* Module::getInfo() const { | 
					
						
							|  |  |  |     if (mType != "Net") { | 
					
						
							|  |  |  |         MNN_ERROR("The Module is not load from buffer, can't get info\n"); | 
					
						
							|  |  |  |         return nullptr; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return ((NetModule*)(this))->info(); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
// Fill info->defaultFormat and info->inputs from the serialized net:
// scan every Input op, record its dims / dtype / layout keyed by tensor
// name, then match those against the user-requested input names. Names
// with no matching Input op leave a default-constructed slot.
static void _loadInputs(Module::Info* info, const std::vector<std::string>& inputs, const Net* net) {
    auto type = net->sourceType();
    // TensorFlow / TFLite models default to NHWC; all other sources to NCHW.
    if (type == NetSource_TENSORFLOW || type == NetSource_TFLITE) {
        info->defaultFormat = NHWC;
    } else {
        info->defaultFormat = NCHW;
    }
    info->inputs.resize(inputs.size());
    std::map<std::string, Variable::Info> allInputs;
    for (int i=0; i<net->oplists()->size(); ++i) {
        auto op = net->oplists()->GetAs<Op>(i);
        if (op->type() == OpType_Input && op->main_as_Input() != nullptr) {
            // The input's tensor name is taken from the op's first output index.
            auto name = net->tensorName()->GetAsString(op->outputIndexes()->data()[0])->str();
            auto inputInfo = op->main_as_Input();
            std::vector<int> dims;
            // dims() may be absent; an empty dim vector is kept in that case.
            if (nullptr != inputInfo->dims()) {
                dims.resize(inputInfo->dims()->size());
                for (int v=0; v<dims.size(); ++v) {
                    dims[v] = inputInfo->dims()->data()[v];
                }
            }
            // Convert serialized dtype/format enums to the Express-side types.
            auto dtype = Utils::revertDataType(inputInfo->dtype());
            Variable::Info vinfo;
            vinfo.dim = std::move(dims);
            vinfo.order = Utils::revertFormat(inputInfo->dformat());
            vinfo.type = dtype;
            vinfo.syncSize();
            allInputs.insert(std::make_pair(name, std::move(vinfo)));
        }
    }
    // Match requested names against the collected Input ops, in order.
    for (int i=0; i<inputs.size(); ++i) {
        auto iter = allInputs.find(inputs[i]);
        if (iter != allInputs.end()) {
            info->inputs[i] = iter->second;
        }
    }
}
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-18 11:30:27 +08:00
										 |  |  | Module* Module::load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const std::shared_ptr<MNN::Express::Executor::RuntimeManager> _rtMgr, const Module::Config* config) { | 
					
						
							| 
									
										
										
										
											2022-05-06 19:51:20 +08:00
										 |  |  |     return loadInternal(inputs, outputs, buffer, length, _rtMgr, config, true); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | static Module* loadInternal(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const std::shared_ptr<MNN::Express::Executor::RuntimeManager> _rtMgr, const Module::Config* config, bool enforceAuth) { | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |     // Check if runtime is valid
 | 
					
						
							| 
									
										
										
										
											2022-06-24 18:30:05 +08:00
										 |  |  |     if (nullptr != _rtMgr && _rtMgr->getInside()->mRuntime.first.empty()) { | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |         MNN_ERROR("Invalid runtime\n"); | 
					
						
							|  |  |  |         return nullptr; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2023-02-15 10:30:27 +08:00
										 |  |  |     bool checkMNNBuffer = true; | 
					
						
							|  |  |  |     if (nullptr != _rtMgr) { | 
					
						
							|  |  |  |         checkMNNBuffer = _rtMgr->getInside()->checkNetBuffer; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     if (checkMNNBuffer) { | 
					
						
							|  |  |  |         flatbuffers::Verifier verify(buffer, length); | 
					
						
							|  |  |  |         if (false == VerifyNetBuffer(verify)) { | 
					
						
							|  |  |  |             MNN_PRINT("Invalidate buffer to create MNN Module\n"); | 
					
						
							|  |  |  |             return nullptr; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     // Check Auto Inputs and Outputs
 | 
					
						
							|  |  |  |     auto net = GetNet(buffer); | 
					
						
							|  |  |  |     if (nullptr == net->oplists() || nullptr == net->tensorName()) { | 
					
						
							|  |  |  |         MNN_ERROR("Invalid net, for null oplist or tensorName\n"); | 
					
						
							|  |  |  |         return nullptr; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2022-07-22 09:59:30 +08:00
										 |  |  |     Timer _time; | 
					
						
							| 
									
										
										
										
											2022-05-06 19:51:20 +08:00
										 |  |  |     std::shared_ptr<Module::Info> info(new Module::Info); | 
					
						
							| 
									
										
										
										
											2022-06-27 10:51:38 +08:00
										 |  |  |     if (net->extraInfo() && net->extraInfo()->version()) { | 
					
						
							|  |  |  |         info->version = net->extraInfo()->version()->str(); | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2022-02-18 11:30:27 +08:00
										 |  |  |     auto rtMgr = _rtMgr; | 
					
						
							|  |  |  |     Module::Config defaultConfig; | 
					
						
							|  |  |  |     if (nullptr == config) { | 
					
						
							|  |  |  |         config = &defaultConfig; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     if(nullptr == rtMgr && config->backend != nullptr) { | 
					
						
							|  |  |  |         ScheduleConfig sche_config; | 
					
						
							|  |  |  |         sche_config.type = config->backend->type; | 
					
						
							|  |  |  |         sche_config.backendConfig = config->backend->config; | 
					
						
							|  |  |  |         rtMgr.reset(Executor::RuntimeManager::createRuntimeManager(sche_config)); | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2022-06-27 10:51:38 +08:00
										 |  |  |     info->inputNames = inputs; | 
					
						
							|  |  |  |     info->outputNames = outputs; | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     if ((!inputs.empty()) && (!outputs.empty())) { | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |         _loadInputs(info.get(), inputs, net); | 
					
						
							|  |  |  |         info->runTimeManager = rtMgr; | 
					
						
							|  |  |  |         std::shared_ptr<Module> m(PipelineModule::load(inputs, outputs, buffer, length, rtMgr, config)); | 
					
						
							| 
									
										
										
										
											2022-07-22 09:59:30 +08:00
										 |  |  |         return new NetModule(m, info, net, length, (float)_time.durationInUs() / 1000.0f); | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |     } | 
					
						
							|  |  |  |     std::set<int> inputIdx, outputIdx, realInput, realOutput; | 
					
						
							|  |  |  |     for (int i=0; i< net->oplists()->size(); ++i) { | 
					
						
							|  |  |  |         auto op = net->oplists()->GetAs<Op>(i); | 
					
						
							|  |  |  |         if (nullptr != op->inputIndexes()) { | 
					
						
							|  |  |  |             auto data = op->inputIndexes()->data(); | 
					
						
							|  |  |  |             auto size = op->inputIndexes()->size(); | 
					
						
							|  |  |  |             for (int j=0; j<size; ++j) { | 
					
						
							|  |  |  |                 inputIdx.insert(data[j]); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (nullptr != op->outputIndexes()) { | 
					
						
							|  |  |  |             auto data = op->outputIndexes()->data(); | 
					
						
							|  |  |  |             auto size = op->outputIndexes()->size(); | 
					
						
							|  |  |  |             for (int j=0; j<size; ++j) { | 
					
						
							|  |  |  |                 outputIdx.insert(data[j]); | 
					
						
							|  |  |  |                 if (op->type() == OpType_Input) { | 
					
						
							|  |  |  |                     realInput.insert(data[j]); | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2022-06-27 10:51:38 +08:00
										 |  |  |     if (info->inputNames.empty()) { | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |         for (auto index : realInput) { | 
					
						
							| 
									
										
										
										
											2022-06-27 10:51:38 +08:00
										 |  |  |             info->inputNames.emplace_back(net->tensorName()->GetAsString(index)->str()); | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2022-06-27 10:51:38 +08:00
										 |  |  |     if (info->outputNames.empty()) { | 
					
						
							| 
									
										
										
										
											2022-12-24 09:42:39 +08:00
										 |  |  |         if (nullptr != net->outputName()) { | 
					
						
							|  |  |  |             for (int i=0; i<net->outputName()->size(); ++i) { | 
					
						
							|  |  |  |                 info->outputNames.emplace_back(net->outputName()->GetAsString(i)->str()); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } else { | 
					
						
							|  |  |  |             std::set_difference(outputIdx.begin(), outputIdx.end(), inputIdx.begin(), inputIdx.end(), std::inserter(realOutput, realOutput.begin())); | 
					
						
							|  |  |  |             for (auto index : realOutput) { | 
					
						
							|  |  |  |                 info->outputNames.emplace_back(net->tensorName()->GetAsString(index)->str()); | 
					
						
							|  |  |  |             } | 
					
						
							| 
									
										
										
										
											2021-06-11 17:17:13 +08:00
										 |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2022-06-27 10:51:38 +08:00
										 |  |  |     std::shared_ptr<Module> m(PipelineModule::load(info->inputNames, info->outputNames, buffer, length, rtMgr, config)); | 
					
						
							|  |  |  |     _loadInputs(info.get(), info->inputNames, net); | 
					
						
							| 
									
										
										
										
											2021-11-30 10:10:53 +08:00
										 |  |  |     info->runTimeManager = rtMgr; | 
					
						
							| 
									
										
										
										
											2022-07-22 09:59:30 +08:00
										 |  |  |     return new NetModule(m, info, net, length, (float)_time.durationInUs() / 1000.0f); | 
					
						
							| 
									
										
										
										
											2020-12-15 14:12:35 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | EXPRP Module::CloneContext::getOrClone(EXPRP expr) { | 
					
						
							|  |  |  |     auto it = mExprMap.find(expr.get()); | 
					
						
							|  |  |  |     if (it == mExprMap.end()) { | 
					
						
							| 
									
										
										
										
											2020-12-17 16:14:25 +08:00
										 |  |  |         EXPRP replica; | 
					
						
							|  |  |  |         if (expr->get() == nullptr) { | 
					
						
							|  |  |  |             VARP var = Variable::create(expr); | 
					
						
							|  |  |  |             Variable::Info info(*var->getInfo()); | 
					
						
							|  |  |  |             replica = Expr::create(std::move(info), var->readMap<void>(), expr->inputType(), | 
					
						
							|  |  |  |                                    (expr->inputType() != VARP::CONSTANT) ? Expr::COPY : Expr::REF); | 
					
						
							|  |  |  |         } else { | 
					
						
							|  |  |  |             std::vector<VARP> inputs; | 
					
						
							|  |  |  |             for (auto& input: expr->inputs()) { | 
					
						
							|  |  |  |                 inputs.emplace_back(getOrClone(input)); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             replica = Expr::create(expr->extra(), std::move(inputs), expr->outputSize()); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         replica->setName(expr->name()); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         it = mExprMap.emplace(expr.get(), replica).first; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return it->second; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | VARP Module::CloneContext::getOrClone(VARP var) { | 
					
						
							|  |  |  |     auto it = mVarMap.find(var.get()); | 
					
						
							| 
									
										
										
										
											2020-12-17 16:14:25 +08:00
										 |  |  |     if (it == mVarMap.end()) { | 
					
						
							|  |  |  |         auto expr = var->expr(); | 
					
						
							|  |  |  |         VARP replica = Variable::create(getOrClone(expr.first), expr.second); | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  |         it = mVarMap.emplace(var.get(), replica).first; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return it->second; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | Module* Module::clone(const Module* module, const bool shareParams) { | 
					
						
							|  |  |  |     CloneContext context(shareParams); | 
					
						
							|  |  |  |     return module->clone(&context); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | Module* Module::cloneBaseTo(CloneContext* ctx, Module* module) const { | 
					
						
							|  |  |  |     for (const Express::VARP& var : mParameters) { | 
					
						
							|  |  |  |         module->mParameters.push_back(ctx->getOrClone(var)); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     module->mIsTraining = mIsTraining; | 
					
						
							|  |  |  |     module->mName = mName; | 
					
						
							|  |  |  |     module->mType = mType; | 
					
						
							|  |  |  |     return module; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-14 17:21:30 +08:00
										 |  |  | Module* Module::extract(std::vector<Express::VARP> inputs, std::vector<Express::VARP> outputs, bool fortrain, const std::map<std::string, SubGraph>& subGraph) { | 
					
						
							|  |  |  |     return new PipelineModule(inputs, outputs); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-05 16:41:56 +08:00
										 |  |  | } // namespace Express
 | 
					
						
							|  |  |  | } // namespace MNN
 |