2020-11-05 16:41:56 +08:00
|
|
|
//
|
|
|
|
// WhileModule.cpp
|
|
|
|
// MNN
|
|
|
|
//
|
|
|
|
// Created by MNN on b'2020/09/10'.
|
|
|
|
// Copyright © 2018, Alibaba Group Holding Limited
|
|
|
|
//
|
|
|
|
|
|
|
|
#include "WhileModule.hpp"
|
2021-01-06 16:29:37 +08:00
|
|
|
#include "StaticModule.hpp"
|
2020-11-05 16:41:56 +08:00
|
|
|
#include <MNN/expr/ExprCreator.hpp>
|
|
|
|
#include "MNN_generated.h"
|
|
|
|
//#define MNN_OPEN_TIME_TRACE
|
|
|
|
#include <MNN/AutoTime.hpp>
|
|
|
|
namespace MNN {
|
|
|
|
namespace Express {
|
|
|
|
// Linear search for `key` in `names`.
// Returns the index of the first element equal to `key`, or -1 when `key` is absent.
static int _findPos(const std::vector<std::string>& names, const std::string& key) {
    int position = 0;
    for (const auto& candidate : names) {
        if (candidate == key) {
            return position;
        }
        ++position;
    }
    return -1;
}
|
|
|
|
WhileModule* WhileModule::create(const Op* op, const std::map<std::string, SubGraph>& subGraph) {
|
|
|
|
auto module = new WhileModule;
|
2021-01-06 16:29:37 +08:00
|
|
|
module->setType("WhileModule");
|
2021-02-07 10:45:07 +08:00
|
|
|
std::shared_ptr<WhileModule::Info> info(new WhileModule::Info);
|
|
|
|
module->mInfo = info;
|
|
|
|
if (nullptr != op->name()) {
|
|
|
|
module->setName(op->name()->str());
|
|
|
|
}
|
2020-11-05 16:41:56 +08:00
|
|
|
auto whileParam = op->main_as_WhileParam();
|
|
|
|
auto& body = subGraph.find(whileParam->body_graph()->str())->second;
|
|
|
|
auto& cond = subGraph.find(whileParam->cond_graph()->str())->second;
|
|
|
|
module->mBody = body.m;
|
|
|
|
module->mCond = cond.m;
|
|
|
|
/** Compute map index
|
|
|
|
int mCondInputNumber;
|
|
|
|
int mBodyInputNumber;
|
|
|
|
|
|
|
|
// First mCondInputs' index, Second: inputs's index
|
|
|
|
std::vector<std::pair<int, int>> mInputForCond;
|
|
|
|
|
|
|
|
// First mBodyInputs' index, Second: inputs's index
|
|
|
|
std::vector<std::pair<int, int>> mInputForBody;
|
|
|
|
std::vector<int> mOutputFromBody;
|
|
|
|
std::vector<std::pair<int, int>> mUpdateForCond;
|
|
|
|
std::vector<std::pair<int, int>> mUpdateForBody;
|
|
|
|
std::vector<std::pair<int, int>> mCondUpdateForCond;
|
|
|
|
std::vector<std::pair<int, int>> mCondUpdateForBody;
|
|
|
|
*/
|
|
|
|
// Map Inputs
|
2021-02-07 10:45:07 +08:00
|
|
|
info->mBodyInputNumber = body.inputs.size();
|
|
|
|
info->mCondInputNumber = cond.inputs.size();
|
2020-11-05 16:41:56 +08:00
|
|
|
for (int i=0; i<whileParam->aliases_inputs()->size(); ++i) {
|
|
|
|
auto index = i;
|
|
|
|
auto data = whileParam->aliases_inputs()->GetAs<StringVec>(i);
|
|
|
|
for (int s=0; s<data->data()->size(); ++s) {
|
|
|
|
auto name = data->data()->GetAsString(s)->str();
|
|
|
|
auto bodyInputPos = _findPos(body.inputs, name);
|
|
|
|
if (bodyInputPos >= 0) {
|
2021-02-07 10:45:07 +08:00
|
|
|
info->mInputForBody.emplace_back(std::make_pair(bodyInputPos, i));
|
2020-11-05 16:41:56 +08:00
|
|
|
}
|
|
|
|
auto condInputPos = _findPos(cond.inputs, name);
|
|
|
|
if (condInputPos >= 0) {
|
2021-02-07 10:45:07 +08:00
|
|
|
info->mInputForCond.emplace_back(std::make_pair(condInputPos, i));
|
2020-11-05 16:41:56 +08:00
|
|
|
}
|
2021-02-07 10:45:07 +08:00
|
|
|
// if (bodyInputPos < 0 && condInputPos < 0) {
|
|
|
|
// MNN_ASSERT(false);
|
|
|
|
// }
|
2020-11-05 16:41:56 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
// Map update
|
|
|
|
auto update = whileParam->aliases_updates();
|
2021-01-06 16:29:37 +08:00
|
|
|
std::set<int> reusedTensors;
|
2020-11-05 16:41:56 +08:00
|
|
|
std::map<int, int> replaceOutputs;
|
|
|
|
for (int i=0; i<update->size(); ++i) {
|
|
|
|
auto data = update->GetAs<StringVec>(i);
|
|
|
|
int bodyInputPos = -1;
|
|
|
|
int condInputPos = -1;
|
|
|
|
int bodyOutputPos = -1;
|
|
|
|
int condOutputPos = -1;
|
|
|
|
MNN_ASSERT(2 == data->data()->size());
|
|
|
|
auto outputName = data->data()->GetAsString(0)->str();
|
|
|
|
auto inputName = data->data()->GetAsString(1)->str();
|
|
|
|
bodyInputPos = _findPos(body.inputs, inputName);
|
|
|
|
condInputPos = _findPos(cond.inputs, inputName);
|
|
|
|
bodyOutputPos = _findPos(body.outputs, outputName);
|
|
|
|
condOutputPos = _findPos(cond.outputs, outputName);
|
|
|
|
|
|
|
|
auto updateBodyOutputPos = _findPos(body.outputs, inputName);
|
|
|
|
|
|
|
|
MNN_ASSERT(bodyOutputPos == -1 || condOutputPos == -1);
|
|
|
|
if (condOutputPos >= 0) {
|
|
|
|
if (bodyInputPos >= 0) {
|
2021-02-07 10:45:07 +08:00
|
|
|
info->mCondUpdateForBody.emplace_back(std::make_pair(bodyInputPos, condOutputPos));
|
2020-11-05 16:41:56 +08:00
|
|
|
}
|
|
|
|
if (condInputPos >= 0) {
|
2021-02-07 10:45:07 +08:00
|
|
|
info->mCondUpdateForCond.emplace_back(std::make_pair(condInputPos, condOutputPos));
|
2020-11-05 16:41:56 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
if (bodyOutputPos >= 0) {
|
|
|
|
if (bodyInputPos >= 0) {
|
2021-01-06 16:29:37 +08:00
|
|
|
reusedTensors.insert(bodyOutputPos);
|
2021-02-07 10:45:07 +08:00
|
|
|
info->mUpdateForBody.emplace_back(std::make_pair(bodyInputPos, bodyOutputPos));
|
2020-11-05 16:41:56 +08:00
|
|
|
}
|
|
|
|
if (condInputPos >= 0) {
|
2021-02-07 10:45:07 +08:00
|
|
|
info->mUpdateForCond.emplace_back(std::make_pair(condInputPos, bodyOutputPos));
|
2020-11-05 16:41:56 +08:00
|
|
|
}
|
|
|
|
if (updateBodyOutputPos >= 0) {
|
|
|
|
replaceOutputs.insert(std::make_pair(updateBodyOutputPos, bodyOutputPos));
|
|
|
|
}
|
2021-02-07 10:45:07 +08:00
|
|
|
MNN_ASSERT(condInputPos >= 0 || bodyInputPos >= 0);
|
2020-11-05 16:41:56 +08:00
|
|
|
}
|
2021-02-07 10:45:07 +08:00
|
|
|
//MNN_ASSERT(bodyOutputPos >= 0 || condOutputPos >= 0);
|
2020-11-05 16:41:56 +08:00
|
|
|
}
|
2021-02-07 10:45:07 +08:00
|
|
|
MNN_ASSERT(!info->mUpdateForCond.empty());
|
2020-11-05 16:41:56 +08:00
|
|
|
// Map outputs
|
|
|
|
auto output = whileParam->aliases_outputs();
|
2021-02-07 10:45:07 +08:00
|
|
|
info->mOutputNumber = output->size();
|
2020-11-05 16:41:56 +08:00
|
|
|
for (int i=0; i<output->size(); ++i) {
|
|
|
|
auto data = output->GetAsString(i);
|
|
|
|
auto pos = _findPos(body.outputs, data->str());
|
2021-02-07 10:45:07 +08:00
|
|
|
auto posInput = _findPos(body.inputs, data->str());
|
|
|
|
//MNN_ASSERT(pos >= 0 || posInput >= 0);
|
2020-11-05 16:41:56 +08:00
|
|
|
if (replaceOutputs.find(pos) != replaceOutputs.end()) {
|
|
|
|
pos = replaceOutputs[pos];
|
|
|
|
}
|
2021-02-07 10:45:07 +08:00
|
|
|
if (pos >= 0) {
|
|
|
|
info->mOutputFromBody.emplace_back(std::make_pair(i, pos));
|
|
|
|
}
|
|
|
|
if (posInput >= 0) {
|
|
|
|
info->mOutputFromBodyInput.emplace_back(std::make_pair(i, posInput));
|
|
|
|
}
|
|
|
|
for (int j=0; j<whileParam->aliases_inputs()->size(); ++j) {
|
|
|
|
auto inputStrVec = whileParam->aliases_inputs()->GetAs<StringVec>(j);
|
|
|
|
bool find = false;
|
|
|
|
for (int k=0; k<inputStrVec->data()->size(); ++k) {
|
|
|
|
auto name = inputStrVec->data()->GetAsString(k)->str();
|
|
|
|
if (name == data->str()) {
|
|
|
|
find = true;
|
|
|
|
info->mOutputFromInput.emplace_back(j);
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (find) {
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
2020-11-05 16:41:56 +08:00
|
|
|
}
|
2021-01-06 16:29:37 +08:00
|
|
|
if (module->mBody->type() == "StaticModule") {
|
|
|
|
static_cast<StaticModule*>(module->mBody.get())->setReusedTensors(reusedTensors);
|
|
|
|
}
|
2020-11-05 16:41:56 +08:00
|
|
|
return module;
|
|
|
|
}
|
|
|
|
|
|
|
|
std::vector<Express::VARP> WhileModule::onForward(const std::vector<Express::VARP>& inputsI) {
|
2021-02-07 10:45:07 +08:00
|
|
|
std::vector<Express::VARP> condInputs(mInfo->mCondInputNumber);
|
|
|
|
std::vector<Express::VARP> bodyInputs(mInfo->mBodyInputNumber);
|
2020-11-05 16:41:56 +08:00
|
|
|
auto& inputs = inputsI;
|
2021-02-07 10:45:07 +08:00
|
|
|
for (auto& p : mInfo->mInputForCond) {
|
2020-11-05 16:41:56 +08:00
|
|
|
condInputs[p.first] = inputs[p.second];
|
|
|
|
}
|
2021-02-07 10:45:07 +08:00
|
|
|
for (auto& p : mInfo->mInputForBody) {
|
2020-11-05 16:41:56 +08:00
|
|
|
bodyInputs[p.first] = inputs[p.second];
|
|
|
|
}
|
|
|
|
|
2021-02-07 10:45:07 +08:00
|
|
|
std::vector<Express::VARP> outputs(mInfo->mOutputNumber);
|
|
|
|
for (int i = 0; i < mInfo->mOutputFromInput.size(); ++i) {
|
|
|
|
outputs[i] = inputs[mInfo->mOutputFromInput[i]];
|
|
|
|
}
|
|
|
|
int step = 0;
|
2020-11-05 16:41:56 +08:00
|
|
|
while (true) {
|
|
|
|
auto res = mCond->onForward(condInputs)[0];
|
|
|
|
auto resPtr = res->readMap<int>();
|
|
|
|
if (resPtr[0] <= 0) {
|
|
|
|
break;
|
|
|
|
}
|
2021-02-07 10:45:07 +08:00
|
|
|
step++;
|
|
|
|
//MNN_PRINT("%s - %d\n", name().c_str(), step);
|
2020-11-05 16:41:56 +08:00
|
|
|
auto bodyOutputs = mBody->onForward(bodyInputs);
|
|
|
|
Express::Variable::prepareCompute(bodyOutputs);
|
|
|
|
for (int i=0; i<bodyOutputs.size(); ++i) {
|
|
|
|
auto p = bodyOutputs[i];
|
2021-02-07 10:45:07 +08:00
|
|
|
if (p.get() == nullptr) {
|
|
|
|
continue;
|
|
|
|
}
|
2020-11-05 16:41:56 +08:00
|
|
|
if (p->expr().first->get() != nullptr) {
|
|
|
|
auto ptr = p->readMap<void>();
|
|
|
|
auto info = p->getInfo();
|
|
|
|
auto newV = Express::_Input(info->dim, info->order, info->type);
|
|
|
|
if (nullptr != ptr) {
|
|
|
|
::memcpy(newV->writeMap<void>(), ptr, info->type.bytes() * info->size);
|
|
|
|
}
|
|
|
|
bodyOutputs[i] = newV;
|
|
|
|
}
|
|
|
|
}
|
2021-02-07 10:45:07 +08:00
|
|
|
for (auto& p : mInfo->mUpdateForCond) {
|
2020-11-05 16:41:56 +08:00
|
|
|
condInputs[p.first] = bodyOutputs[p.second];
|
|
|
|
}
|
2021-02-07 10:45:07 +08:00
|
|
|
for (auto& p : mInfo->mUpdateForBody) {
|
2020-11-05 16:41:56 +08:00
|
|
|
bodyInputs[p.first] = bodyOutputs[p.second];
|
|
|
|
}
|
2021-02-07 10:45:07 +08:00
|
|
|
for (auto& p : mInfo->mCondUpdateForCond) {
|
2020-11-05 16:41:56 +08:00
|
|
|
condInputs[p.first] = res;
|
|
|
|
}
|
2021-02-07 10:45:07 +08:00
|
|
|
for (auto& p : mInfo->mCondUpdateForBody) {
|
2020-11-05 16:41:56 +08:00
|
|
|
bodyInputs[p.first] = res;
|
|
|
|
}
|
2021-02-07 10:45:07 +08:00
|
|
|
for (int i=0; i<mInfo->mOutputFromBody.size(); ++i) {
|
|
|
|
outputs[mInfo->mOutputFromBody[i].first] = bodyOutputs[mInfo->mOutputFromBody[i].second];
|
|
|
|
}
|
|
|
|
for (int i=0; i<mInfo->mOutputFromBodyInput.size(); ++i) {
|
|
|
|
outputs[mInfo->mOutputFromBodyInput[i].first] = bodyInputs[mInfo->mOutputFromBodyInput[i].second];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (auto o : outputs) {
|
|
|
|
MNN_ASSERT(nullptr != o);
|
2020-11-05 16:41:56 +08:00
|
|
|
}
|
|
|
|
return outputs;
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
 * Clone this loop module.
 *
 * Within this file, mInfo is written only by create() and read by onForward(), so the
 * copy shares the same mapping tables. The cond and body submodules are cloned through
 * `ctx` (cond first, then body), and the base-class state is copied by cloneBaseTo.
 */
Module* WhileModule::clone(CloneContext* ctx) const {
    auto copied = new WhileModule;
    copied->mInfo = mInfo;
    copied->mCond.reset(mCond->clone(ctx));
    copied->mBody.reset(mBody->clone(ctx));
    return cloneBaseTo(ctx, copied);
}
|
|
|
|
|
|
|
|
};
|
|
|
|
};
|