mirror of https://github.com/alibaba/MNN.git
547 lines
20 KiB
C++
547 lines
20 KiB
C++
//
|
|
// Module.cpp
|
|
// MNN
|
|
//
|
|
// Created by MNN on 2019/11/25.
|
|
// Copyright © 2018, Alibaba Group Holding Limited
|
|
//
|
|
|
|
#include <MNN/expr/Module.hpp>
|
|
#include <MNN/expr/ExprCreator.hpp>
|
|
#include <MNN/expr/ExecutorScope.hpp>
|
|
#include "core/OpCommonUtils.hpp"
|
|
#include "PipelineModule.hpp"
|
|
#include "core/FileLoader.hpp"
|
|
#include "backend/cpu/CPUBackend.hpp"
|
|
#include "MNN_generated.h"
|
|
#include "Utils.hpp"
|
|
#include "RuntimeAttr.hpp"
|
|
#include "ModuleInside.hpp"
|
|
#include <MNN/AutoTime.hpp>
|
|
#ifdef MNN_INTERNAL_ENABLED
|
|
#include "internal/auth/ModelAuth.hpp"
|
|
#include "internal/logging/Log.hpp"
|
|
#include "internal/logging/LogHelper.hpp"
|
|
#endif // MNN_INTERNAL_ENABLED
|
|
|
|
namespace MNN {
|
|
namespace Express {
|
|
// Builds a RuntimeManager for callers that did not supply one.
// When `config` carries a backend description its type and backend config are
// used; otherwise the settings are taken from the current scoped Executor.
// Returns whatever Executor::RuntimeManager::createRuntimeManager returns.
static MNN::Express::Executor::RuntimeManager* _createDefaultRuntimeManager(const Module::Config* config) {
    ScheduleConfig schedule;
    const bool hasBackend = (nullptr != config) && (nullptr != config->backend);
    if (!hasBackend) {
        // Keep the executor alive while its attributes are referenced.
        auto executor = ExecutorScope::Current();
        auto attr     = executor->getAttr();
        schedule.type          = attr->firstType;
        schedule.numThread     = attr->numThread;
        schedule.backendConfig = &attr->config;
        return Executor::RuntimeManager::createRuntimeManager(schedule);
    }
    schedule.type          = config->backend->type;
    schedule.backendConfig = config->backend->config;
    return Executor::RuntimeManager::createRuntimeManager(schedule);
}
|
|
|
|
static Module* loadInternal(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const std::shared_ptr<MNN::Express::Executor::RuntimeManager> _rtMgr, const Module::Config* config);
|
|
|
|
class EmptyModule : public Module {
|
|
public:
|
|
EmptyModule(const std::vector<Express::VARP>& parameters) {
|
|
for (auto p : parameters) {
|
|
addParameter(p);
|
|
}
|
|
}
|
|
virtual ~EmptyModule() {
|
|
// Do nothing
|
|
}
|
|
virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override {
|
|
return {};
|
|
}
|
|
|
|
protected:
|
|
EmptyModule() = default;
|
|
|
|
Module* clone(Module::CloneContext* ctx) const override {
|
|
EmptyModule* module(new EmptyModule);
|
|
return this->cloneBaseTo(ctx, module);
|
|
}
|
|
};
|
|
// Releases a module previously handed out as a raw pointer. Null is accepted.
void Module::destroy(Module* m) {
    // `delete` on a null pointer is a no-op, so no explicit guard is required.
    delete m;
}
|
|
|
|
// Creates a heap-allocated EmptyModule holding `parameters`; the caller owns the result.
Module* Module::createEmpty(const std::vector<Express::VARP>& parameters) {
    Module* empty = new EmptyModule(parameters);
    return empty;
}
|
|
|
|
// Single-input / single-output convenience wrapper around onForward.
// Returns the first output of onForward({input}), or an empty VARP when the
// module produced no outputs (e.g. EmptyModule), instead of indexing into an
// empty vector.
Express::VARP Module::forward(Express::VARP input) {
    auto outputs = this->onForward({input});
    if (outputs.empty()) {
        MNN_ERROR("Module forward error: no output\n");
        return Express::VARP();
    }
    return outputs[0];
}
|
|
std::vector<Express::VARP> Module::parameters() const {
|
|
std::vector<Express::VARP> result;
|
|
_collectParameters(result);
|
|
return result;
|
|
}
|
|
// Replaces the module's parameters (in _collectParameters order) with `parameters`.
//
// Returns false when `parameters` is empty, when its size differs from the
// number of collected parameters, when a source entry is null or has no info,
// or when dimension count / order / size / type of an entry do not match the
// existing parameter. Entries are replaced one by one, so a failure at index i
// leaves entries [0, i) already replaced.
bool Module::loadParameters(const std::vector<Express::VARP>& parameters) {
    std::vector<Express::VARP> result;
    _collectParameters(result);
    if (parameters.empty() || parameters.size() != result.size()) {
        MNN_ERROR("Error parameters, empty or parameter size not match \n");
        return false;
    }
    for (size_t i = 0; i < parameters.size(); ++i) {
        if (nullptr != result[i].get()) {
            // A null source cannot be compared against the existing parameter.
            if (nullptr == parameters[i].get()) {
                MNN_ERROR("Error parameters %d, source parameter is null \n", (int)i);
                return false;
            }
            // Check Origin parameter's size
            auto dstInfo = result[i]->getInfo();
            auto srcInfo = parameters[i]->getInfo();
            if (nullptr == dstInfo || nullptr == srcInfo) {
                MNN_ERROR("Error parameters %d, can't get info \n", (int)i);
                return false;
            }
            if (dstInfo->dim.size() != srcInfo->dim.size() || dstInfo->order != srcInfo->order) {
                MNN_ERROR("Error parameters %d, dim size or order not match \n", (int)i);
                return false;
            }
            if (dstInfo->size != srcInfo->size || dstInfo->type != srcInfo->type) {
                MNN_ERROR("Error parameters %d, size or type not match \n", (int)i);
                return false;
            }
        }
        Variable::replace(result[i], parameters[i]);
    }
    return true;
}
|
|
void Module::setIsTraining(const bool isTraining) {
|
|
mIsTraining = isTraining;
|
|
for (auto c : mChildren) {
|
|
c->setIsTraining(isTraining);
|
|
}
|
|
}
|
|
|
|
bool Module::getIsTraining() {
|
|
return mIsTraining;
|
|
}
|
|
|
|
// Adds `children` as sub-modules. They are placed in front of the children
// that are already registered, keeping their relative order.
void Module::registerModel(const std::vector<std::shared_ptr<Module>>& children) {
    auto front = mChildren.begin();
    mChildren.insert(front, children.begin(), children.end());
}
|
|
// Appends `parameter` to this module's own parameter list and returns its index.
int Module::addParameter(VARP parameter) {
    const int index = (int)mParameters.size();
    mParameters.push_back(parameter);
    return index;
}
|
|
|
|
// Overwrites the parameter stored at `index`. An out-of-range index is
// reported through MNN_ERROR and leaves the parameter list untouched.
void Module::setParameter(Express::VARP parameter, int index) {
    const bool inRange = (index >= 0) && (index < (int)mParameters.size());
    if (inRange) {
        mParameters[index] = parameter;
        return;
    }
    MNN_ERROR("Module error: index out of range: %d - %d:\n", index, (int)mParameters.size());
}
|
|
|
|
// Appends this module's own parameters to `result`, then recurses into each
// child in registration order (depth-first).
void Module::_collectParameters(std::vector<Express::VARP>& result) const {
    result.insert(result.end(), mParameters.begin(), mParameters.end());
    for (const auto& child : mChildren) {
        child->_collectParameters(result);
    }
}
|
|
void Module::clearCache() {
|
|
for (auto c : mChildren) {
|
|
c->clearCache();
|
|
}
|
|
this->onClearCache();
|
|
}
|
|
|
|
// Loads a module from a model file without a caller-supplied runtime manager;
// forwards to the overload that accepts one, passing an empty pointer.
Module* Module::load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const char* fileName, const Module::Config* config) {
    const std::shared_ptr<MNN::Express::Executor::RuntimeManager> noRuntime;
    return load(inputs, outputs, fileName, noRuntime, config);
}
|
|
|
|
// Loads a module from an in-memory model buffer without a caller-supplied
// runtime manager; forwards to the overload that accepts one, passing an empty pointer.
Module* Module::load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const Module::Config* config) {
    const std::shared_ptr<MNN::Express::Executor::RuntimeManager> noRuntime;
    return load(inputs, outputs, buffer, length, noRuntime, config);
}
|
|
|
|
class NetModule : public Module {
|
|
public:
|
|
NetModule(std::shared_ptr<Module> m, std::shared_ptr<Module::Info> info, const MNN::Net* net, size_t size, float costTime) {
|
|
mChildren = {m};
|
|
auto mModule = mChildren[0];
|
|
mInfo = info;
|
|
setType("Net");
|
|
#ifdef MNN_INTERNAL_ENABLED
|
|
if (nullptr != net) {
|
|
mLogInfo = logBasicInfo();
|
|
std::string uuid = std::string(net->mnn_uuid() ? net->mnn_uuid()->c_str() : "");
|
|
mLogInfo.emplace("UUID", uuid);
|
|
mLogInfo.emplace("ModelVersion", info->version);
|
|
int backend = MNN_FORWARD_CPU;
|
|
int precision = BackendConfig::Precision_Normal;
|
|
int mode = 1;
|
|
if (info->runTimeManager.get() != nullptr) {
|
|
auto attr = info->runTimeManager->getInside();
|
|
mode = attr->mContent->mNumberThread;
|
|
int backendTypes[MNN_FORWARD_ALL];
|
|
info->runTimeManager->getInfo(Interpreter::BACKENDS, &backendTypes);
|
|
backend = backendTypes[0];
|
|
auto config = info->runTimeManager->getBnConfig();
|
|
if (nullptr != config) {
|
|
precision = config->precision;
|
|
}
|
|
}
|
|
mLogInfo.emplace("Backend", std::to_string(backend));
|
|
mLogInfo.emplace("Mode", std::to_string(mode));
|
|
mLogInfo.emplace("Precision", std::to_string(precision));
|
|
if (shouldLog(FREQ_HIGH)) {
|
|
std::map<std::string, std::string> metrics = mLogInfo;
|
|
metrics.emplace("Time", std::to_string(costTime));
|
|
auto sizeInMB = (float)size / 1024.0f / 1024.0f;
|
|
metrics.emplace("ModelSize", std::to_string(sizeInMB));
|
|
metrics.emplace("API", "Express::Module::NetModule");
|
|
logAsync(metrics);
|
|
}
|
|
}
|
|
#endif // MNN_INTERNAL_ENABLED
|
|
}
|
|
virtual ~ NetModule(){
|
|
mChildren.clear();
|
|
mInfo.reset();
|
|
auto exe = ExecutorScope::Current();
|
|
exe->gc(Executor::FULL);
|
|
}
|
|
|
|
virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override {
|
|
auto mModule = mChildren[0];
|
|
// Reset resize staus
|
|
mInfo->runTimeManager->getInside()->mResizeStatus = 0;
|
|
#ifdef MNN_INTERNAL_ENABLED
|
|
Timer _time;
|
|
auto glo = ExecutorScope::Current();
|
|
glo->getDebugTools()->flops = 0.0f;
|
|
#endif
|
|
std::vector<VARP> outputs;
|
|
{
|
|
Executor::RuntimeExecuteWrap wrap(mInfo->runTimeManager->getInside()->mRuntime);
|
|
outputs = mModule->onForward(inputs);
|
|
}
|
|
#ifdef MNN_INTERNAL_ENABLED
|
|
do {
|
|
if (outputs.empty()) {
|
|
break;
|
|
}
|
|
if (!shouldLog(FREQ_LOW)) {
|
|
break;
|
|
}
|
|
for (auto& v : outputs) {
|
|
auto t = Utils::getTensor(v);
|
|
t->wait(Tensor::MAP_TENSOR_READ, true);
|
|
}
|
|
auto metrics = mLogInfo;
|
|
metrics.emplace("Time", std::to_string((float)_time.durationInUs() / 1000.0f));
|
|
metrics.emplace("API", "NetModule::onForward");
|
|
if (mInfo->runTimeManager.get() != nullptr) {
|
|
float memory = 0.0f;
|
|
mInfo->runTimeManager->getInfo(Interpreter::MEMORY, &memory);
|
|
metrics.emplace("Flops", std::to_string(glo->getDebugTools()->flops));
|
|
metrics.emplace("Memory", std::to_string(memory));
|
|
}
|
|
logAsync(metrics);
|
|
MNN_PRINT("Cost time with log: %f\n", (float)_time.durationInUs() / 1000.0f);
|
|
} while(false);
|
|
#endif
|
|
|
|
mModule->clearCache();
|
|
return outputs;
|
|
}
|
|
virtual Module* clone(CloneContext* ctx) const override {
|
|
auto mModule = mChildren[0];
|
|
auto origin = mInfo->runTimeManager->getInside();
|
|
ScheduleConfig config;
|
|
config.type = origin->mRuntime.first.begin()->first;
|
|
config.numThread = origin->mContent->mNumberThread;
|
|
std::shared_ptr<Executor::RuntimeManager> newRt (Executor::RuntimeManager::createRuntimeManager(config));
|
|
const_cast<RuntimeAttr*>(newRt->getInside())->mContent = origin->mContent;
|
|
std::shared_ptr<Module::Info> newInfo(new Module::Info);
|
|
*newInfo = *mInfo;
|
|
ctx->pRuntimeManager = newRt;
|
|
newInfo->runTimeManager = newRt;
|
|
std::shared_ptr<Module> submodule(mModule->clone(ctx));
|
|
NetModule* module(new NetModule(submodule, newInfo, nullptr, 0, 0.0f));
|
|
#ifdef MNN_INTERNAL_ENABLED
|
|
module->mLogInfo = mLogInfo;
|
|
#endif
|
|
return this->cloneBaseTo(ctx, module);
|
|
}
|
|
const Module::Info* info() const {
|
|
return mInfo.get();
|
|
}
|
|
|
|
private:
|
|
std::shared_ptr<Module::Info> mInfo;
|
|
#ifdef MNN_INTERNAL_ENABLED
|
|
std::map<std::string, std::string> mLogInfo;
|
|
#endif
|
|
};
|
|
|
|
// Returns the model description of a module created by Module::load.
// Only modules whose type string is "Net" (NetModule) carry one; for any
// other module an error is logged and nullptr is returned.
const Module::Info* Module::getInfo() const {
    if (mType == "Net") {
        return static_cast<const NetModule*>(this)->info();
    }
    MNN_ERROR("The Module is not load from buffer, can't get info\n");
    return nullptr;
}
|
|
|
|
// Fills info->defaultFormat and info->inputs from the model.
//
// info->defaultFormat is NHWC for TensorFlow / TFLite sourced models and NCHW
// otherwise. info->inputs is resized to inputs.size(); entry i receives the
// description (dims, order, type) of the Input op whose output tensor is named
// inputs[i]. Names that match no Input op keep a default-constructed entry.
// Models lacking an op list or tensor-name table, and Input ops without a
// valid output index, are skipped instead of being dereferenced.
static void _loadInputs(Module::Info* info, const std::vector<std::string>& inputs, const Net* net) {
    auto type = net->sourceType();
    if (type == NetSource_TENSORFLOW || type == NetSource_TFLITE) {
        info->defaultFormat = NHWC;
    } else {
        info->defaultFormat = NCHW;
    }
    info->inputs.resize(inputs.size());

    auto oplists     = net->oplists();
    auto tensorNames = net->tensorName();
    if (nullptr == oplists || nullptr == tensorNames) {
        // The model may be incomplete, avoid crash
        return;
    }
    const int opCount     = static_cast<int>(oplists->size());
    const int tensorCount = static_cast<int>(tensorNames->size());

    // Collect the description of every Input op, keyed by its output tensor name.
    std::map<std::string, Variable::Info> allInputs;
    for (int i = 0; i < opCount; ++i) {
        auto op = oplists->GetAs<Op>(i);
        if (op->type() != OpType_Input || op->main_as_Input() == nullptr) {
            continue;
        }
        auto outputIndexes = op->outputIndexes();
        if (nullptr == outputIndexes || outputIndexes->size() == 0) {
            continue;
        }
        const int tensorIndex = outputIndexes->data()[0];
        if (tensorIndex < 0 || tensorIndex >= tensorCount) {
            continue;
        }
        auto name = tensorNames->GetAsString(tensorIndex)->str();
        auto inputInfo = op->main_as_Input();
        std::vector<int> dims;
        if (nullptr != inputInfo->dims()) {
            dims.resize(inputInfo->dims()->size());
            for (size_t v = 0; v < dims.size(); ++v) {
                dims[v] = inputInfo->dims()->data()[v];
            }
        }
        auto dtype = Utils::revertDataType(inputInfo->dtype());
        Variable::Info vinfo;
        vinfo.dim = std::move(dims);
        vinfo.order = Utils::revertFormat(inputInfo->dformat());
        vinfo.type = dtype;
        vinfo.syncSize();
        allInputs.insert(std::make_pair(name, std::move(vinfo)));
    }

    // Copy the description of each requested input, when the model declares it.
    for (size_t i = 0; i < inputs.size(); ++i) {
        auto iter = allInputs.find(inputs[i]);
        if (iter != allInputs.end()) {
            info->inputs[i] = iter->second;
        }
    }
}
|
|
|
|
// Loads a module from the model file `fileName`.
//
// inputs / outputs : tensor names; either may be empty, in which case
//                    loadInternal derives them from the model.
// _rtMgr           : runtime to use; when empty a default one is created from `config`.
// config           : optional load configuration (may be null).
//
// When the runtime has no external file configured, "<fileName>.weight" is
// set for the duration of the load and cleared again afterwards.
// Returns nullptr when the file name is null, the file cannot be opened or
// read, no runtime can be created, or the model fails to load.
Module* Module::load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const char* fileName, const std::shared_ptr<MNN::Express::Executor::RuntimeManager> _rtMgr, const Module::Config* config) {
    if (nullptr == fileName) {
        MNN_ERROR("Error for open model: file name is null\n");
        return nullptr;
    }
    AutoStorage<uint8_t> buffer;
    {
        // Scoped so the loader is released as soon as the content is merged.
        FileLoader loader(fileName, true);
        if (!loader.valid()) {
            MNN_ERROR("Error for open %s\n", fileName);
            return nullptr;
        }
        loader.read();
        if (!loader.valid()) {
            return nullptr;
        }
        loader.merge(buffer);
        if (buffer.get() == nullptr) {
            return nullptr;
        }
    }
    auto rtMgr = _rtMgr;
    if (nullptr == rtMgr.get()) {
        rtMgr.reset(_createDefaultRuntimeManager(config));
        if (nullptr == rtMgr.get()) {
            // Runtime creation can fail; rtMgr is dereferenced below.
            MNN_ERROR("Invalid runtime\n");
            return nullptr;
        }
    }
    bool needReset = false;
    if (rtMgr->getInside()->mContent->mExternalFile.empty()) {
        // Set Default externalFile
        rtMgr->setExternalFile(std::string(fileName) + ".weight");
        needReset = true;
    }
    auto res = loadInternal(inputs, outputs, buffer.get(), buffer.size(), rtMgr, config);
    if (needReset) {
        // Restore the "no external file" state this function found.
        rtMgr->setExternalFile("");
    }
    return res;
}
|
|
|
|
// Loads a module from an in-memory model buffer. When `_rtMgr` is empty a
// default runtime manager is created from `config`; a failed creation is
// left for loadInternal to report.
Module* Module::load(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const std::shared_ptr<MNN::Express::Executor::RuntimeManager> _rtMgr, const Module::Config* config) {
    std::shared_ptr<MNN::Express::Executor::RuntimeManager> runtime = _rtMgr;
    if (!runtime) {
        runtime.reset(_createDefaultRuntimeManager(config));
    }
    return loadInternal(inputs, outputs, buffer, length, runtime, config);
}
|
|
|
|
// Shared implementation behind every Module::load overload.
//
// Validates the runtime and (when the runtime's checkNetBuffer mode is on)
// the model buffer, reads uuid / version / meta data / bizCode into a
// Module::Info, builds the graph through PipelineModule::load and wraps it in
// a NetModule.
//
// When both `inputs` and `outputs` are given they are used as-is. Otherwise
// missing input names are taken from the model's Input ops, and missing
// output names from the model's outputName table or, when that is absent,
// from tensors that are produced by some op but consumed by none.
//
// Returns nullptr on an invalid runtime, a null or invalid buffer, or when
// the pipeline cannot be built. The caller owns the returned module.
static Module* loadInternal(const std::vector<std::string>& inputs, const std::vector<std::string>& outputs, const uint8_t* buffer, size_t length, const std::shared_ptr<MNN::Express::Executor::RuntimeManager> _rtMgr, const Module::Config* config) {
    // Check if runtime is valid
    if (nullptr == _rtMgr || _rtMgr->getInside()->mRuntime.first.empty()) {
        MNN_ERROR("Invalid runtime\n");
        return nullptr;
    }
    if (nullptr == buffer) {
        MNN_ERROR("Invalid model buffer\n");
        return nullptr;
    }
    // _rtMgr is known to be non-null from here on.
    const bool checkMNNBuffer = _rtMgr->getInside()->mContent->modes.checkNetBuffer;
    if (checkMNNBuffer && !OpCommonUtils::checkNet(buffer, length)) {
        return nullptr;
    }
    auto net = GetNet(buffer);
    Timer _time;
    std::shared_ptr<Module::Info> info(new Module::Info);
    if (net->mnn_uuid()) {
        info->uuid = net->mnn_uuid()->str();
    }
    if (net->extraInfo()) {
        if (net->extraInfo()->version()) {
            info->version = net->extraInfo()->version()->str();
        }
        // Get Meta
        if (net->extraInfo()->buffer()) {
            auto extra = flatbuffers::GetRoot<Extra>(net->extraInfo()->buffer()->data());
            if (nullptr != extra->attr()) {
                const int attrCount = static_cast<int>(extra->attr()->size());
                for (int i = 0; i < attrCount; ++i) {
                    auto attr = extra->attr()->GetAs<Attribute>(i);
                    if (nullptr != attr->key() && nullptr != attr->s()) {
                        // The model may be incomplete, avoid crash
                        info->metaData.insert(std::make_pair(attr->key()->str(), attr->s()->str()));
                    }
                }
            }
        }
    }
    if (net->bizCode()) {
        info->bizCode = net->bizCode()->str();
    }
    auto rtMgr = _rtMgr;
    Module::Config defaultConfig;
    if (nullptr == config) {
        config = &defaultConfig;
    }
    info->inputNames = inputs;
    info->outputNames = outputs;
    if ((!inputs.empty()) && (!outputs.empty())) {
        // Fast path: the caller named both ends of the graph.
        _loadInputs(info.get(), inputs, net);
        info->runTimeManager = rtMgr;
        std::shared_ptr<Module> m(PipelineModule::load(inputs, outputs, buffer, length, rtMgr, config));
        if (nullptr == m) {
            return nullptr;
        }
        return new NetModule(m, info, net, length, (float)_time.durationInUs() / 1000.0f);
    }

    // Appends the name of tensor `index` to `names`; silently skips when the
    // model has no tensor-name table or the index is out of range.
    auto appendTensorName = [net](std::vector<std::string>& names, int index) {
        auto tensorNames = net->tensorName();
        if (nullptr == tensorNames || index < 0 || index >= static_cast<int>(tensorNames->size())) {
            return;
        }
        names.emplace_back(tensorNames->GetAsString(index)->str());
    };

    // Scan all ops: which tensors are consumed, which are produced, and which
    // are produced by Input ops (in op order).
    std::set<int> inputIdx, outputIdx;
    std::vector<int> realInput;
    auto oplists = net->oplists();
    const int opCount = (nullptr != oplists) ? static_cast<int>(oplists->size()) : 0;
    for (int i = 0; i < opCount; ++i) {
        auto op = oplists->GetAs<Op>(i);
        if (nullptr != op->inputIndexes()) {
            auto data = op->inputIndexes()->data();
            const int size = static_cast<int>(op->inputIndexes()->size());
            for (int j = 0; j < size; ++j) {
                inputIdx.insert(data[j]);
            }
        }
        if (nullptr != op->outputIndexes()) {
            auto data = op->outputIndexes()->data();
            const int size = static_cast<int>(op->outputIndexes()->size());
            for (int j = 0; j < size; ++j) {
                outputIdx.insert(data[j]);
                if (op->type() == OpType_Input) {
                    realInput.emplace_back(data[j]);
                }
            }
        }
    }
    if (info->inputNames.empty()) {
        for (auto index : realInput) {
            appendTensorName(info->inputNames, index);
        }
    }
    if (info->outputNames.empty()) {
        if (nullptr != net->outputName()) {
            const int outputCount = static_cast<int>(net->outputName()->size());
            for (int i = 0; i < outputCount; ++i) {
                info->outputNames.emplace_back(net->outputName()->GetAsString(i)->str());
            }
        } else {
            // Outputs are tensors that are produced but never consumed.
            std::set<int> realOutput;
            std::set_difference(outputIdx.begin(), outputIdx.end(), inputIdx.begin(), inputIdx.end(), std::inserter(realOutput, realOutput.begin()));
            for (auto index : realOutput) {
                appendTensorName(info->outputNames, index);
            }
        }
    }
    std::shared_ptr<Module> m(PipelineModule::load(info->inputNames, info->outputNames, buffer, length, rtMgr, config));
    if (nullptr == m) {
        // Nothing else to fill in: `info` is discarded together with the failed load.
        return nullptr;
    }
    _loadInputs(info.get(), info->inputNames, net);
    info->runTimeManager = rtMgr;
    return new NetModule(m, info, net, length, (float)_time.durationInUs() / 1000.0f);
}
|
|
|
|
// Returns the replica of `expr` recorded in this context, creating and
// recording it on first use.
// - An expr without an op (expr->get() == nullptr) is rebuilt from its info
//   and data: constants reference the original memory (Expr::REF), every
//   other input type copies it (Expr::COPY).
// - An expr with an op is rebuilt from its extra data after cloning each of
//   its inputs through this same context.
// The replica keeps the original's name.
EXPRP Module::CloneContext::getOrClone(EXPRP expr) {
    auto found = mExprMap.find(expr.get());
    if (found != mExprMap.end()) {
        return found->second;
    }
    EXPRP replica;
    if (nullptr != expr->get()) {
        std::vector<VARP> clonedInputs;
        for (auto& input : expr->inputs()) {
            clonedInputs.emplace_back(getOrClone(input));
        }
        replica = Expr::create(expr->extra(), std::move(clonedInputs), expr->outputSize());
    } else {
        VARP var = Variable::create(expr);
        Variable::Info info(*var->getInfo());
        replica = Expr::create(std::move(info), var->readMap<void>(), expr->inputType(),
                               (expr->inputType() != VARP::CONSTANT) ? Expr::COPY : Expr::REF);
    }
    replica->setName(expr->name());
    return mExprMap.emplace(expr.get(), replica).first->second;
}
|
|
|
|
// Returns the replica of `var` recorded in this context, creating it on first
// use from the replica of its source expr and the same output index.
VARP Module::CloneContext::getOrClone(VARP var) {
    auto found = mVarMap.find(var.get());
    if (found != mVarMap.end()) {
        return found->second;
    }
    auto source = var->expr();
    VARP replica = Variable::create(getOrClone(source.first), source.second);
    return mVarMap.emplace(var.get(), replica).first->second;
}
|
|
|
|
// Clones `module` through a fresh CloneContext constructed with `shareParams`.
// Returns nullptr when `module` is null (mirroring Module::destroy, which also
// accepts null) or when the module's own clone fails. The caller owns the result.
Module* Module::clone(const Module* module, const bool shareParams) {
    if (nullptr == module) {
        return nullptr;
    }
    CloneContext context(shareParams);
    return module->clone(&context);
}
|
|
|
|
// Copies the base-class state of this module into `module`: training flag,
// name, type, and a clone (through `ctx`) of each own parameter, appended in
// order. Children are not touched. Returns `module` for chaining.
Module* Module::cloneBaseTo(CloneContext* ctx, Module* module) const {
    module->mIsTraining = mIsTraining;
    module->mName       = mName;
    module->mType       = mType;
    for (auto iter = mParameters.begin(); iter != mParameters.end(); ++iter) {
        module->mParameters.push_back(ctx->getOrClone(*iter));
    }
    return module;
}
|
|
|
|
// Builds a PipelineModule spanning `inputs` to `outputs`; the caller owns it.
// `fortrain` and `subGraph` are accepted but not used by this implementation.
Module* Module::extract(std::vector<Express::VARP> inputs, std::vector<Express::VARP> outputs, bool fortrain, const std::map<std::string, SubGraph>& subGraph) {
    Module* pipeline = new PipelineModule(inputs, outputs);
    return pipeline;
}
|
|
// Runs onOptimize(stage) on this module and then, depth-first, on each child.
// Stops at the first non-zero code and returns it; returns 0 when every call succeeded.
int Module::traceOrOptimize(Interpreter::SessionMode stage) {
    auto code = this->onOptimize(stage);
    for (auto iter = mChildren.begin(); 0 == code && iter != mChildren.end(); ++iter) {
        code = (*iter)->traceOrOptimize(stage);
    }
    return code;
}
|
|
|
|
|
|
} // namespace Express
|
|
} // namespace MNN
|