#include "GenerateSubGraph.hpp"
|
|
#include "PostTreatUtils.hpp"
|
|
#include <MNN/MNNDefine.h>
|
|
#include "Program.hpp"
|
|
#include <MNN/expr/ExprCreator.hpp>
|
|
#include <sstream>
|
|
namespace MNN {
|
|
using NodeVector = std::vector<std::unique_ptr<OpT>>;
|
|
|
|
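// A cluster groups ops whose names share a "/"-separated prefix, mirroring
// TensorFlow's name-scope convention for control flow (e.g. "while/cond/add"
// lives under "while/cond", which lives under "while"). The hasLoop/hasSwitch/
// hasMerge flags record which control ops were seen inside the cluster and
// decide whether it becomes a While subgraph, an If subgraph, or is merged
// back into its parent.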
struct ClusterNode {
    std::string name;
    NodeVector nodes;
    bool hasLoop   = false;
    bool hasSwitch = false;
    bool hasMerge  = false;
    std::vector<std::shared_ptr<ClusterNode>> children;
    ClusterNode* parent = nullptr;
};

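// Split `name` at the last occurrence of `sp`. Returns {prefix, suffix} when
// the separator is found, otherwise {name}. For example,
// RSplitString("a/b/c", "/") yields {"a/b", "c"}, so .at(0) is the enclosing
// name scope.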
static inline std::vector<std::string> RSplitString(const std::string& name,
                                                    const std::string& sp) {
    std::vector<std::string> splits;
    size_t pos = name.rfind(sp);
    if (pos != std::string::npos) {
        splits.push_back(name.substr(0, pos));
        splits.push_back(name.substr(pos + 1));
    } else {
        splits.push_back(name);
    }
    return splits; // return by value: std::move here would only inhibit copy elision
}

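// Ensure a ClusterNode exists for `name`, creating its ancestors recursively
// by stripping one "/"-component at a time. A name without "/" becomes a root
// cluster.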
static void _makeClusterNode(const std::string& name,
                             std::map<std::string, std::shared_ptr<ClusterNode>>& clusters,
                             std::vector<std::shared_ptr<ClusterNode>>& rootClusters) {
    if (clusters.find(name) != clusters.end()) {
        return;
    }
    std::shared_ptr<ClusterNode> newNode(new ClusterNode);
    newNode->name = name;
    clusters.emplace(name, newNode);
    auto parent = RSplitString(name, "/").at(0);
    if (parent == name) {
        rootClusters.emplace_back(newNode);
        return;
    }
    _makeClusterNode(parent, clusters, rootClusters);
    newNode->parent = clusters[parent].get();
    clusters[parent]->children.emplace_back(newNode);
}

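// Flatten the cluster tree bottom-up: a child that contains no Merge and no
// LoopCond op needs no subgraph of its own, so its ops and grandchildren are
// hoisted into the parent. Children that hold control flow stay separate
// subgraph candidates.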
static void _mergeSubGraph(std::shared_ptr<ClusterNode> node) {
    for (auto c : node->children) {
        _mergeSubGraph(c);
    }
    auto children = std::move(node->children);
    node->children.clear();
    for (auto c : children) {
        if (c->hasLoop || c->hasMerge) {
            // The child carries control flow and must stay a separate subgraph.
            node->children.emplace_back(c);
            continue;
        }
        // Hoist the child's ops and grandchildren into this node.
        for (auto& o : c->nodes) {
            node->nodes.emplace_back(std::move(o));
        }
        node->children.insert(node->children.end(), c->children.begin(), c->children.end());
    }
}

static void _printSubGraph(std::shared_ptr<ClusterNode> node, int indent = 0) {
    for (int v = 0; v < indent; ++v) {
        MNN_PRINT(" ");
    }
    MNN_PRINT("%s\n", node->name.c_str());
    for (auto c : node->children) {
        _printSubGraph(c, indent + 4);
    }
}

static bool _isControlOp(const OpT* op) {
    static const std::set<std::string> controlOps{"Merge", "Switch", "LoopCond",
                                                  "Enter", "Exit", "NextIteration"};
    return op->type == OpType_Extra && controlOps.find(op->main.AsExtra()->type) != controlOps.end();
}

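// Lower a cluster of TensorFlow-style Switch/Merge ops into a single MNN If op.
// Each Switch routes its data input to one of two outputs depending on a
// shared condition tensor; each Merge picks whichever branch produced a value.
// Ops reachable from Switch output 1 form the then-branch, ops on output 0 the
// else-branch; both are extracted as subgraphs and the Merges become the If
// op's outputs.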
std::vector<std::unique_ptr<OpT>> _makeCond(std::shared_ptr<ClusterNode> cNode, MNN::NetT* netT,
                                            const std::map<std::string, int>& originTensorIndexes) {
    std::vector<std::unique_ptr<OpT>> res;
    std::unique_ptr<OpT> condOp(new OpT);
    condOp->type       = OpType_If;
    condOp->main.type  = OpParameter_IfParam;
    condOp->main.value = new IfParamT;
    condOp->name       = cNode->name;

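    // The If condition is a Switch predicate that is produced outside this
    // cluster. When several candidates exist (nested conditions), prefer one
    // that no op in the parent cluster consumes as a predicate, and strip the
    // Switches that belong to the parent's condition.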
    // Find the condition tensor: for each Switch, check whether its second
    // input (the predicate) is produced by an op inside this cluster.
    std::set<int> condTensorIndexes;
    for (int i = 0; i < cNode->nodes.size(); ++i) {
        auto& op = cNode->nodes[i];
        if (op->type == OpType_Extra && op->main.AsExtra()->type == "Switch") {
            auto originIndex = op->inputIndexes[1];
            bool find        = false;
            for (auto& subop : cNode->nodes) {
                for (auto out : subop->outputIndexes) {
                    if (out == originIndex) {
                        find = true;
                        break;
                    }
                }
                if (find) {
                    break;
                }
            }
            if (!find) {
                // Produced outside the cluster: candidate condition tensor.
                condTensorIndexes.insert(originIndex);
            }
        }
    }
    MNN_ASSERT(condTensorIndexes.size() > 0);
    int condTensorIndex = *condTensorIndexes.begin();
    // Disambiguate when several candidate condition tensors exist.
    if (condTensorIndexes.size() > 1) {
        MNN_ASSERT(cNode->parent != nullptr);
        // Prefer a candidate that no parent-level Switch consumes as its predicate.
        for (auto index : condTensorIndexes) {
            bool valid = true;
            for (auto& op : cNode->parent->nodes) {
                if (op->inputIndexes.size() > 1 && op->inputIndexes[1] == index) {
                    valid = false;
                    break;
                }
            }
            if (valid) {
                condTensorIndex = index;
            }
        }
        // Remove the Switches that belong to the parent's condition, rewiring
        // their outputs to their data input. Remove one Switch per pass.
        std::map<int, int> replaceTensor;
        bool needCheck = true;
        while (needCheck) {
            needCheck  = false;
            auto nodes = std::move(cNode->nodes);
            for (int i = 0; i < nodes.size(); ++i) {
                if ((!needCheck) && nodes[i]->type == OpType_Extra && nodes[i]->main.AsExtra()->type == "Switch") {
                    if (nodes[i]->inputIndexes[1] != condTensorIndex) {
                        for (auto output : nodes[i]->outputIndexes) {
                            replaceTensor.insert(std::make_pair(output, nodes[i]->inputIndexes[0]));
                        }
                        needCheck = true;
                        continue;
                    }
                }
                cNode->nodes.emplace_back(std::move(nodes[i]));
            }
            for (auto& op : cNode->nodes) {
                for (int i = 0; i < op->inputIndexes.size(); ++i) {
                    if (replaceTensor.find(op->inputIndexes[i]) != replaceTensor.end()) {
                        op->inputIndexes[i] = replaceTensor[op->inputIndexes[i]];
                    }
                }
            }
        }
    }

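    // Classify every remaining op as then-branch or else-branch by propagating
    // the branch tag from Switch outputs through each op's input/output
    // tensors, relying on the topological order of the op list. Merge ops
    // terminate the propagation and become the If op's outputs.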
    // Mask per op/tensor: 0 = unassigned, 1 = then-branch (left),
    // 2 = else-branch (right), -1 = Switch, -2 = Merge.
    std::vector<int> opMask(cNode->nodes.size(), 0);
    std::vector<int> tensorMask(netT->tensorName.size(), 0);
    for (int i = 0; i < cNode->nodes.size(); ++i) {
        if (opMask[i] != 0) {
            continue;
        }
        auto& op = cNode->nodes[i];
        if (op->type == OpType_Extra && op->main.AsExtra()->type == "Switch") {
            // TensorFlow Switch: output 0 feeds the false branch, output 1 the true branch.
            tensorMask[op->outputIndexes[0]] = 2;
            if (op->outputIndexes.size() > 1) {
                tensorMask[op->outputIndexes[1]] = 1;
            }
            opMask[i] = -1;
            continue;
        }
        if (op->type == OpType_Extra && op->main.AsExtra()->type == "Merge") {
            tensorMask[op->outputIndexes[0]] = -2;
            opMask[i] = -2;
            condOp->outputIndexes.emplace_back(op->outputIndexes[0]);
            continue;
        }
        // Inherit the branch tag from any tagged input or output tensor, then
        // propagate it to all of this op's tensors.
        bool valid = false;
        for (auto index : op->inputIndexes) {
            if (tensorMask[index] > 0) {
                opMask[i] = tensorMask[index];
                valid     = true;
            }
        }
        for (auto index : op->outputIndexes) {
            if (tensorMask[index] > 0) {
                MNN_ASSERT(opMask[i] <= 0 || opMask[i] == tensorMask[index]);
                opMask[i] = tensorMask[index];
                valid     = true;
            }
        }
        if (valid) {
            for (auto index : op->inputIndexes) {
                tensorMask[index] = opMask[i];
            }
            for (auto index : op->outputIndexes) {
                tensorMask[index] = opMask[i];
            }
        }
    }

    // Remove the remaining Switches, rewiring their outputs to the data input.
    // Remove one Switch per pass.
    bool needCheck = true;
    std::map<int, int> replaceTensor;
    while (needCheck) {
        needCheck  = false;
        auto nodes = std::move(cNode->nodes);
        for (int i = 0; i < nodes.size(); ++i) {
            if ((!needCheck) && nodes[i]->type == OpType_Extra && nodes[i]->main.AsExtra()->type == "Switch") {
                for (auto output : nodes[i]->outputIndexes) {
                    replaceTensor.insert(std::make_pair(output, nodes[i]->inputIndexes[0]));
                }
                needCheck = true;
                continue;
            }
            cNode->nodes.emplace_back(std::move(nodes[i]));
        }
        for (auto& op : cNode->nodes) {
            for (int i = 0; i < op->inputIndexes.size(); ++i) {
                if (replaceTensor.find(op->inputIndexes[i]) != replaceTensor.end()) {
                    op->inputIndexes[i] = replaceTensor[op->inputIndexes[i]];
                }
            }
        }
    }

    std::map<int, Express::VARP> varMap;
    std::set<OpT*> invalidSet;
    std::vector<int> inputIndexes;
    std::set<int> extraInputIndexes;
    std::vector<int> leftOutputs;
    std::vector<int> rightOutputs;
    for (auto& node : cNode->nodes) {
        if (node->type == OpType_Extra && node->main.AsExtra()->type == "Merge") {
            // A Merge's two inputs are the then/else results; route each to the
            // matching branch's output list by its branch tag.
            if (tensorMask[node->inputIndexes[0]] == 1) {
                leftOutputs.emplace_back(node->inputIndexes[0]);
                rightOutputs.emplace_back(node->inputIndexes[1]);
            } else {
                leftOutputs.emplace_back(node->inputIndexes[1]);
                rightOutputs.emplace_back(node->inputIndexes[0]);
            }
            continue;
        }
        Express::Program::createUnit(varMap, inputIndexes, cNode->nodes, node.get(), netT, invalidSet,
                                     extraInputIndexes);
    }

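    // Build a subgraph from the requested output tensors: tensors produced
    // inside the branch are pulled from varMap, anything else becomes an extra
    // input of the If op.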
    auto makeSubGraph = [&](const std::vector<int>& index) {
        std::vector<Express::VARP> out;
        for (auto l : index) {
            auto iter = varMap.find(l);
            if (iter != varMap.end()) {
                out.emplace_back(iter->second);
            } else {
                // Tensor comes from outside the branch: model it as a subgraph input.
                auto tempInput = Express::_Input();
                tempInput->setName(netT->tensorName[l]);
                out.emplace_back(tempInput);
                extraInputIndexes.insert(l);
            }
        }
        std::unique_ptr<NetT> newT(new NetT);
        Express::Variable::save(out, newT.get());
        std::unique_ptr<SubGraphProtoT> subGraph(new SubGraphProtoT);
        subGraph->tensors = std::move(newT->tensorName);
        subGraph->nodes   = std::move(newT->oplists);
        for (int i = 0; i < subGraph->nodes.size(); ++i) {
            if (subGraph->nodes[i]->type == OpType_Input) {
                subGraph->inputs.emplace_back(i);
            }
        }
        // Map each requested output back to its tensor index inside the subgraph.
        for (auto l : index) {
            auto& outputName = netT->tensorName[l];
            for (int i = 0; i < subGraph->tensors.size(); ++i) {
                if (subGraph->tensors[i] == outputName) {
                    subGraph->outputs.emplace_back(i);
                    break;
                }
            }
        }
        return subGraph;
    };

    {
        auto leftGraph  = makeSubGraph(leftOutputs);
        leftGraph->name = cNode->name + "/then";
        condOp->main.AsIfParam()->then_graph = leftGraph->name;
        netT->subgraphs.emplace_back(std::move(leftGraph));

        auto rightGraph  = makeSubGraph(rightOutputs);
        rightGraph->name = cNode->name + "/else";
        condOp->main.AsIfParam()->else_graph = rightGraph->name;
        netT->subgraphs.emplace_back(std::move(rightGraph));
    }
    // The first If input is the condition tensor; the rest are the captured extras.
    condOp->inputIndexes.emplace_back(condTensorIndex);
    std::unique_ptr<StringVecT> inputT(new StringVecT);
    inputT->data.emplace_back(netT->tensorName[condTensorIndex]);
    condOp->main.AsIfParam()->aliases_inputs.emplace_back(std::move(inputT));
    extraInputIndexes.erase(condTensorIndex);
    for (auto index : extraInputIndexes) {
        condOp->inputIndexes.emplace_back(index);
        std::unique_ptr<StringVecT> inputT(new StringVecT);
        inputT->data.emplace_back(netT->tensorName[index]);
        condOp->main.AsIfParam()->aliases_inputs.emplace_back(std::move(inputT));
    }
    // Pair up the then/else output names for each Merge-derived output.
    for (int i = 0; i < leftOutputs.size(); ++i) {
        std::unique_ptr<StringVecT> outputPair(new StringVecT);
        outputPair->data.emplace_back(netT->tensorName[leftOutputs[i]]);
        outputPair->data.emplace_back(netT->tensorName[rightOutputs[i]]);
        condOp->main.AsIfParam()->aliases_outputs.emplace_back(std::move(outputPair));
    }
    // Compatibility with the old naming convention: rename outputs to "<name>:i".
    for (int i = 0; i < condOp->outputIndexes.size(); ++i) {
        std::ostringstream newName;
        newName << condOp->name << ":" << i;
        netT->tensorName[condOp->outputIndexes[i]] = newName.str();
    }
    res.emplace_back(std::move(condOp));
    cNode->nodes.clear();
    return res;
}

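// Lower a cluster of TensorFlow-style loop control ops into a single MNN While
// op referencing a cond subgraph and a body subgraph. Switches are rewired
// away, each Merge records one loop-carried update (input 0: the state tensor,
// input 1: the next-iteration value), the LoopCond input becomes the cond
// subgraph's output, and each Exit yields a While output.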
std::vector<std::unique_ptr<OpT>> _makeWhile(std::shared_ptr<ClusterNode> cNode, MNN::NetT* netT,
                                             const std::map<std::string, int>& originTensorIndexes) {
    std::vector<std::unique_ptr<OpT>> res;
    // Remove the Switch ops (rewiring their outputs to the data input) and
    // find the LoopCond op's output tensor.
    int loopCond = -1;
    {
        std::map<int, int> replaceTensor;
        auto nodes = std::move(cNode->nodes);
        for (auto& op : nodes) {
            if (op->type == OpType_Extra && op->main.AsExtra()->type == "Switch") {
                for (auto o : op->outputIndexes) {
                    replaceTensor.insert(std::make_pair(o, op->inputIndexes[0]));
                }
                continue;
            }
            if (op->type == OpType_Extra && op->main.AsExtra()->type == "LoopCond") {
                loopCond = op->outputIndexes[0];
            }
            cNode->nodes.emplace_back(std::move(op));
        }
        for (auto& op : cNode->nodes) {
            for (int i = 0; i < op->inputIndexes.size(); ++i) {
                if (replaceTensor.find(op->inputIndexes[i]) != replaceTensor.end()) {
                    op->inputIndexes[i] = replaceTensor[op->inputIndexes[i]];
                }
            }
        }
    }
    MNN_ASSERT(loopCond != -1);

    // Generate the cond/body subgraphs and the While op that references them.
    std::map<int, Express::VARP> varMap;

    std::unique_ptr<SubGraphProtoT> condGraph(new SubGraphProtoT);
    condGraph->name = cNode->name + "/cond";
    std::unique_ptr<SubGraphProtoT> bodyGraph(new SubGraphProtoT);
    bodyGraph->name = cNode->name + "/body";

    std::unique_ptr<OpT> whileOpU(new OpT);
    auto whileOp        = whileOpU.get(); // Raw pointer kept for easier debugging
    whileOp->type       = OpType_While;
    whileOp->main.type  = OpParameter_WhileParam;
    whileOp->main.value = new WhileParamT;
    whileOp->name       = cNode->name;
    auto whileParam     = whileOp->main.AsWhileParam();
    whileParam->cond_graph = condGraph->name;
    whileParam->body_graph = bodyGraph->name;

    std::set<int> extraInputIndexes;
    std::vector<int> bodyUpdate;
    std::set<std::string> bodyOutputNames;
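    // Each Merge op describes one loop variable: input 0 is the loop-carried
    // tensor (update target) and input 1 the value produced each iteration
    // (update source). When two Merges target the same outside tensor, a
    // single-input Concat is emitted in the outer graph as a copy so each
    // Merge updates a distinct tensor. The pairs are recorded as the While
    // op's aliases_updates and the Merge ops themselves are dropped.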
    {
        std::vector<std::pair<int, int>> updateIndexes;
        auto nodes = std::move(cNode->nodes);
        std::map<int, int> replaceTensor;
        std::set<int> updateToTensors;
        for (auto& op : nodes) {
            if (op->type == OpType_Extra && op->main.AsExtra()->type == "Merge") {
                int updateFromIdx = op->inputIndexes[1], updateToIdx = op->inputIndexes[0];
                if (updateToTensors.find(updateToIdx) != updateToTensors.end()) {
                    // The target tensor is already updated by another Merge: copy
                    // it (via a single-input Concat) so this Merge gets its own target.
                    std::unique_ptr<OpT> copyOp(new OpT);
                    copyOp->type = OpType_Concat;
                    copyOp->inputIndexes.push_back(updateToIdx);
                    auto opName = netT->tensorName[updateToIdx] + "_copy";
                    updateToIdx = netT->tensorName.size();
                    copyOp->outputIndexes.push_back(updateToIdx);
                    netT->tensorName.push_back(opName);
                    netT->tensorNumber++;
                    res.emplace_back(std::move(copyOp));
                    extraInputIndexes.insert(updateToIdx);
                }
                updateToTensors.insert(updateToIdx);
                updateIndexes.emplace_back(std::make_pair(updateFromIdx, updateToIdx));
                replaceTensor.insert(std::make_pair(op->outputIndexes[0], updateToIdx));
                bodyUpdate.emplace_back(updateFromIdx);
                bodyOutputNames.insert(netT->tensorName[updateFromIdx]);
                continue;
            }
            cNode->nodes.emplace_back(std::move(op));
        }
        for (auto& op : cNode->nodes) {
            for (int i = 0; i < op->inputIndexes.size(); ++i) {
                if (replaceTensor.find(op->inputIndexes[i]) != replaceTensor.end()) {
                    op->inputIndexes[i] = replaceTensor[op->inputIndexes[i]];
                }
            }
        }
        for (auto& p : updateIndexes) {
            if (replaceTensor.find(p.first) != replaceTensor.end()) {
                p.first = replaceTensor[p.first];
            }
            if (replaceTensor.find(p.second) != replaceTensor.end()) {
                p.second = replaceTensor[p.second];
            }
        }
        for (auto& p : updateIndexes) {
            std::unique_ptr<StringVecT> updateName(new StringVecT);
            updateName->data.emplace_back(netT->tensorName[p.first]);
            updateName->data.emplace_back(netT->tensorName[p.second]);
            whileParam->aliases_updates.emplace_back(std::move(updateName));
        }
    }

    // Collect the loop outputs: each Exit op's input is a body output and its
    // output becomes an output of the While op.
    for (auto& op : cNode->nodes) {
        if (op->type != OpType_Extra) {
            continue;
        }
        if (op->main.AsExtra()->type == "Exit") {
            whileOp->outputIndexes.emplace_back(op->outputIndexes[0]);
            whileParam->aliases_outputs.emplace_back(netT->tensorName[op->inputIndexes[0]]);
            bodyOutputNames.insert(netT->tensorName[op->inputIndexes[0]]);
        }
    }

    // Build expression units for the remaining ops, then register every tensor
    // that comes from outside the loop as a While input.
    std::set<OpT*> invalidSet;
    std::vector<int> inputIndexes;
    for (auto& node : cNode->nodes) {
        Express::Program::createUnit(varMap, inputIndexes, cNode->nodes, node.get(), netT, invalidSet,
                                     extraInputIndexes);
    }
    for (auto index : extraInputIndexes) {
        std::unique_ptr<StringVecT> inputNames(new StringVecT);
        inputNames->data.emplace_back(netT->tensorName[index]);
        whileParam->aliases_inputs.emplace_back(std::move(inputNames));
        whileOp->inputIndexes.emplace_back(index);
    }
    {
        // Extract the condition subgraph: everything LoopCond's input depends on.
        std::unique_ptr<NetT> condNet(new NetT);
        Express::Variable::save({varMap[loopCond]}, condNet.get());
        for (auto& op : condNet->oplists) {
            if (op->type == OpType_Extra && op->main.AsExtra()->type == "LoopCond") {
                // LoopCond acts as an identity: its input is the cond output.
                condGraph->outputs.emplace_back(op->inputIndexes[0]);
                continue;
            }
            if (op->type == OpType_Input) {
                condGraph->inputs.emplace_back(op->outputIndexes[0]);
            }
            condGraph->nodes.emplace_back(std::move(op));
        }
        condGraph->tensors = std::move(condNet->tensorName);
        MNN_ASSERT(condGraph->outputs.size() > 0);
    }
    {
        // Extract the body subgraph from the per-iteration update values.
        std::unique_ptr<NetT> bodyNet(new NetT);
        std::vector<Express::VARP> bodyOutputs;
        for (auto b : bodyUpdate) {
            if (varMap.find(b) != varMap.end()) {
                bodyOutputs.emplace_back(varMap[b]);
            }
        }
        Express::Variable::save(bodyOutputs, bodyNet.get());
        for (auto& op : bodyNet->oplists) {
            if (op->type == OpType_Input) {
                bodyGraph->inputs.emplace_back(op->outputIndexes[0]);
            }
            for (auto o : op->outputIndexes) {
                if (bodyOutputNames.find(bodyNet->tensorName[o]) != bodyOutputNames.end()) {
                    bodyGraph->outputs.emplace_back(o);
                }
            }
            bodyGraph->nodes.emplace_back(std::move(op));
        }
        bodyGraph->tensors = std::move(bodyNet->tensorName);
    }
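    // A loop variable whose update target is a Const captured inside an
    // extracted subgraph cannot be updated in place. Hoist the Const out to
    // the main graph, replace it with an Input node inside the subgraph, and
    // feed it through the While op's inputs instead.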
    {
        auto turnConst = [&](SubGraphProtoT* subGraph) {
            for (auto& s : whileParam->aliases_updates) {
                auto& second = s->data[1];
                for (int i = 0; i < subGraph->nodes.size(); ++i) {
                    auto& op = subGraph->nodes[i];
                    if (OpType_Const != op->type) {
                        continue;
                    }
                    if (subGraph->tensors[op->outputIndexes[0]] == second) {
                        // Move the Const into the outer graph and replace it with
                        // an Input node of the same shape/type inside the subgraph.
                        auto opPtr = op.get();
                        res.emplace_back(std::move(op));
                        subGraph->nodes[i].reset(new OpT);
                        subGraph->nodes[i]->type       = OpType_Input;
                        subGraph->nodes[i]->main.type  = OpParameter_Input;
                        subGraph->nodes[i]->main.value = new InputT;
                        subGraph->nodes[i]->main.AsInput()->dims    = opPtr->main.AsBlob()->dims;
                        subGraph->nodes[i]->main.AsInput()->dtype   = opPtr->main.AsBlob()->dataType;
                        subGraph->nodes[i]->main.AsInput()->dformat = opPtr->main.AsBlob()->dataFormat;
                        subGraph->nodes[i]->outputIndexes = opPtr->outputIndexes;
                        opPtr->outputIndexes[0] = originTensorIndexes.find(second)->second;
                        std::unique_ptr<StringVecT> newVecT(new StringVecT);
                        newVecT->data.emplace_back(second);
                        whileParam->aliases_inputs.emplace_back(std::move(newVecT));
                        whileOp->inputIndexes.emplace_back(opPtr->outputIndexes[0]);
                    }
                }
            }
        };
        turnConst(condGraph.get());
        turnConst(bodyGraph.get());
    }
    netT->subgraphs.emplace_back(std::move(condGraph));
    netT->subgraphs.emplace_back(std::move(bodyGraph));
    res.emplace_back(std::move(whileOpU));
    cNode->nodes.clear();
    return res;
}

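// Recursively lower a cluster: children first, so their generated If/While
// ops replace the raw control ops inside the parent, then the cluster itself.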
static std::vector<std::unique_ptr<OpT>> _makeSubGraph(std::shared_ptr<ClusterNode> cNode, MNN::NetT* netT,
                                                       const std::map<std::string, int>& t) {
    for (auto c : cNode->children) {
        auto opList = _makeSubGraph(c, netT, t);
        for (auto&& op : opList) {
            cNode->nodes.emplace_back(std::move(op));
        }
    }
    if (cNode->hasLoop) {
        return _makeWhile(cNode, netT, t);
    }
    if (cNode->hasMerge) {
        return _makeCond(cNode, netT, t);
    }
    return {};
}

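// Entry point: detect TensorFlow-style control-flow ops, cluster the op list
// by name prefix, and rewrite each control-flow cluster into MNN If/While ops
// with subgraphs.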
int GenerateSubGraph(std::unique_ptr<MNN::NetT>& netT) {
    // Remove useless ops before clustering.
    std::vector<std::string> passes = {
        "RemoveUnusefulOp",
    };
    for (auto pass : passes) {
        auto convert = PostConverter::get(pass);
        if (nullptr == convert) {
            continue;
        }
        convert->onExecute(netT);
    }
    bool hasControlFlow = false;
    for (auto& op : netT->oplists) {
        if (_isControlOp(op.get())) {
            hasControlFlow = true;
            break;
        }
    }
    if (!hasControlFlow) {
        return 0;
    }
    MNN_PRINT("The model has control flow, please use MNN::Module to run it\n");

    // Broadly divide all nodes into clusters by the prefix of the node name.
    // Each cluster belongs to one of three categories: Normal, Condition, or
    // WhileLoop. Nodes that share a name prefix may belong to the same
    // cluster; a cluster containing Merge ops becomes a condition subgraph,
    // and a cluster containing LoopCond ops becomes a while-loop subgraph.
    std::map<std::string, std::shared_ptr<ClusterNode>> clusters;
    std::vector<std::shared_ptr<ClusterNode>> rootClusters;
    for (auto& node : netT->oplists) {
        std::string name = RSplitString(node->name, "/").at(0);
        _makeClusterNode(name, clusters, rootClusters);
        auto it = clusters.find(name);
        if (node->type == OpType_Extra) {
            auto type = node->main.AsExtra()->type;
            if (type == "LoopCond") {
                it->second->hasLoop = true;
            } else if (type == "Switch") {
                it->second->hasSwitch = true;
            } else if (type == "Merge") {
                it->second->hasMerge = true;
            }
        }
        it->second->nodes.emplace_back(std::move(node));
    }
    netT->oplists.clear();
    std::map<std::string, int> tensorNameMap;
    for (int i = 0; i < netT->tensorName.size(); ++i) {
        tensorNameMap[netT->tensorName[i]] = i;
    }
    for (auto n : rootClusters) {
        _mergeSubGraph(n);
    }
#ifdef MNN_PRINT_SUBGRAPH
    for (auto n : rootClusters) {
        _printSubGraph(n);
    }
#endif
    for (auto n : rootClusters) {
        auto controlOp = _makeSubGraph(n, netT.get(), tensorNameMap);
        for (auto& c : n->nodes) {
            netT->oplists.emplace_back(std::move(c));
        }
        for (auto& op : controlOp) {
            netT->oplists.emplace_back(std::move(op));
        }
    }
    return 0;
}

} // namespace MNN