| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  | #include "GenerateSubGraph.hpp"
 | 
					
						
							|  |  |  | #include "PostTreatUtils.hpp"
 | 
					
						
							|  |  |  | #include <MNN/MNNDefine.h>
 | 
					
						
							|  |  |  | #include "Program.hpp"
 | 
					
						
							|  |  |  | #include <MNN/expr/ExprCreator.hpp>
 | 
					
						
							|  |  |  | #include <sstream>
 | 
					
						
							|  |  |  | namespace MNN { | 
					
						
							|  |  |  | using NodeVector = std::vector<std::unique_ptr<OpT>>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | struct ClusterNode { | 
					
						
							|  |  |  |     std::string name; | 
					
						
							|  |  |  |     NodeVector nodes; | 
					
						
							|  |  |  |     bool hasLoop = false; | 
					
						
							|  |  |  |     bool hasSwitch = false; | 
					
						
							|  |  |  |     bool hasMerge = false; | 
					
						
							|  |  |  |     std::vector<std::shared_ptr<ClusterNode>> children; | 
					
						
							|  |  |  |     ClusterNode* parent = nullptr; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | static inline std::vector<std::string> RSplitString(const std::string& name, | 
					
						
							|  |  |  |                                                     const std::string& sp) { | 
					
						
							|  |  |  |     std::vector<std::string> splits; | 
					
						
							|  |  |  |     size_t pos = name.rfind(sp); | 
					
						
							|  |  |  |     if (pos != std::string::npos) { | 
					
						
							|  |  |  |         splits.push_back(name.substr(0, pos)); | 
					
						
							|  |  |  |         splits.push_back(name.substr(pos + 1)); | 
					
						
							|  |  |  |     } else { | 
					
						
							|  |  |  |         splits.push_back(name); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return std::move(splits); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | static void _makeClusterNode(const std::string& name, std::map<std::string, std::shared_ptr<ClusterNode>>& clusters, std::vector<std::shared_ptr<ClusterNode>>& rootClusters) { | 
					
						
							|  |  |  |     if (clusters.find(name) != clusters.end()) { | 
					
						
							|  |  |  |         return; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     std::shared_ptr<ClusterNode> newNode(new ClusterNode); | 
					
						
							|  |  |  |     newNode->name = name; | 
					
						
							|  |  |  |     clusters.emplace(name, newNode); | 
					
						
							|  |  |  |     auto parent = RSplitString(name, "/").at(0); | 
					
						
							|  |  |  |     if (parent == name) { | 
					
						
							|  |  |  |         rootClusters.emplace_back(newNode); | 
					
						
							|  |  |  |         return; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     _makeClusterNode(parent, clusters, rootClusters); | 
					
						
							|  |  |  |     newNode->parent = clusters[parent].get(); | 
					
						
							|  |  |  |     clusters[parent]->children.emplace_back(newNode); | 
					
						
							|  |  |  |     return; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | static void _mergeSubGraph(std::shared_ptr<ClusterNode> node) { | 
					
						
							|  |  |  |     for (auto c : node->children) { | 
					
						
							|  |  |  |         _mergeSubGraph(c); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     bool merge = false; | 
					
						
							|  |  |  |     auto children = std::move(node->children); | 
					
						
							|  |  |  |     node->children.clear(); | 
					
						
							|  |  |  |     for (auto c : children) { | 
					
						
							|  |  |  |         if (c->hasLoop || c->hasMerge) { | 
					
						
							|  |  |  |             // Can't merge
 | 
					
						
							|  |  |  |             node->children.emplace_back(c); | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (auto& o : c->nodes) { | 
					
						
							|  |  |  |             node->nodes.emplace_back(std::move(o)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         node->children.insert(node->children.end(), c->children.begin(), c->children.end()); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | static void _printSubGraph(std::shared_ptr<ClusterNode> node, int indent = 0) { | 
					
						
							|  |  |  |     for (int v=0; v<indent; ++v) { | 
					
						
							|  |  |  |         MNN_PRINT(" "); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     MNN_PRINT("%s\n", node->name.c_str()); | 
					
						
							|  |  |  |     for (auto c : node->children) { | 
					
						
							|  |  |  |         _printSubGraph(c, indent+4); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | static bool _isControlOp(const OpT* op) { | 
					
						
							|  |  |  |     std::set<std::string> controlOps{"Merge", "Switch", "LoopCond", "Enter", "Exit", "NextIteration"}; | 
					
						
							|  |  |  |     return op->type == OpType_Extra && controlOps.find(op->main.AsExtra()->type) != controlOps.end(); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | std::vector<std::unique_ptr<OpT>> _makeCond(std::shared_ptr<ClusterNode> cNode, MNN::NetT* netT, const std::map<std::string, int>& originTensorIndexes) { | 
					
						
							|  |  |  |     std::vector<std::unique_ptr<OpT>> res; | 
					
						
							|  |  |  |     std::unique_ptr<OpT> condOp(new OpT); | 
					
						
							|  |  |  |     condOp->type = OpType_If; | 
					
						
							|  |  |  |     condOp->main.type = OpParameter_IfParam; | 
					
						
							|  |  |  |     condOp->main.value = new IfParamT; | 
					
						
							|  |  |  |     condOp->name = cNode->name; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Find cond tensor
 | 
					
						
							|  |  |  |     std::set<int> condTensorIndexes; | 
					
						
							|  |  |  |     for (int i=0; i<cNode->nodes.size(); ++i) { | 
					
						
							|  |  |  |         auto& op = cNode->nodes[i]; | 
					
						
							|  |  |  |         if (op->type == OpType_Extra && op->main.AsExtra()->type == "Switch") { | 
					
						
							|  |  |  |             // Find outside condIndex
 | 
					
						
							|  |  |  |             auto originIndex = op->inputIndexes[1]; | 
					
						
							|  |  |  |             bool find = false; | 
					
						
							|  |  |  |             do { | 
					
						
							|  |  |  |                 for (auto& subop : cNode->nodes) { | 
					
						
							|  |  |  |                     for (auto out : subop->outputIndexes) { | 
					
						
							|  |  |  |                         if (out == originIndex) { | 
					
						
							|  |  |  |                             find = true; | 
					
						
							|  |  |  |                             break; | 
					
						
							|  |  |  |                         } | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                     if (find) { | 
					
						
							|  |  |  |                         break; | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } while (false); | 
					
						
							|  |  |  |             if (!find) { | 
					
						
							|  |  |  |                 condTensorIndexes.insert(originIndex); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     MNN_ASSERT(condTensorIndexes.size() > 0); | 
					
						
							|  |  |  |     int condTensorIndex = *condTensorIndexes.begin(); | 
					
						
							|  |  |  |     // Find dependency for condTensors
 | 
					
						
							|  |  |  |     if (condTensorIndexes.size() > 1) { | 
					
						
							|  |  |  |         MNN_ASSERT(cNode->parent != nullptr); | 
					
						
							|  |  |  |         for (auto index : condTensorIndexes) { | 
					
						
							|  |  |  |             bool valid = true; | 
					
						
							|  |  |  |             for (auto& op : cNode->parent->nodes) { | 
					
						
							|  |  |  |                 if (op->inputIndexes.size() > 1 && op->inputIndexes[1] == index) { | 
					
						
							|  |  |  |                     valid = false; | 
					
						
							|  |  |  |                     break; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             if (valid) { | 
					
						
							|  |  |  |                 condTensorIndex = index; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         // Remove Switch For Parent Switch
 | 
					
						
							|  |  |  |         bool needCheck = true; | 
					
						
							|  |  |  |         std::map<int, int> replaceTensor; | 
					
						
							|  |  |  |         needCheck = true; | 
					
						
							|  |  |  |         while (needCheck) { | 
					
						
							|  |  |  |             needCheck = false; | 
					
						
							|  |  |  |             auto nodes = std::move(cNode->nodes); | 
					
						
							|  |  |  |             for (int i = 0; i < nodes.size(); ++i) { | 
					
						
							|  |  |  |                 if ((!needCheck) && nodes[i]->type == OpType_Extra && nodes[i]->main.AsExtra()->type == "Switch") { | 
					
						
							|  |  |  |                     if (nodes[i]->inputIndexes[1] != condTensorIndex) { | 
					
						
							|  |  |  |                         // Once Time remove only one switch
 | 
					
						
							|  |  |  |                         for (auto output : nodes[i]->outputIndexes) { | 
					
						
							|  |  |  |                             replaceTensor.insert(std::make_pair(output, nodes[i]->inputIndexes[0])); | 
					
						
							|  |  |  |                         } | 
					
						
							|  |  |  |                         needCheck = true; | 
					
						
							|  |  |  |                         continue; | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |                 cNode->nodes.emplace_back(std::move(nodes[i])); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             for (auto& op : cNode->nodes) { | 
					
						
							|  |  |  |                 for (int i = 0; i < op->inputIndexes.size(); ++i) { | 
					
						
							|  |  |  |                     if (replaceTensor.find(op->inputIndexes[i]) != replaceTensor.end()) { | 
					
						
							|  |  |  |                         op->inputIndexes[i] = replaceTensor[op->inputIndexes[i]]; | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     //0: no use, 1: left, 2: right, -1: switch, -2: merge
 | 
					
						
							|  |  |  |     std::vector<int> opMask(cNode->nodes.size(), 0); | 
					
						
							|  |  |  |     std::vector<int> tensorMask(netT->tensorName.size(), 0); | 
					
						
							|  |  |  |     for (int i=0; i<cNode->nodes.size(); ++i) { | 
					
						
							|  |  |  |         if (opMask[i] != 0) { | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         auto& op = cNode->nodes[i]; | 
					
						
							|  |  |  |         if (op->type == OpType_Extra && op->main.AsExtra()->type == "Switch") { | 
					
						
							|  |  |  |             tensorMask[op->outputIndexes[0]] = 2; | 
					
						
							|  |  |  |             if (op->outputIndexes.size() > 1) { | 
					
						
							|  |  |  |                 tensorMask[op->outputIndexes[1]] = 1; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             opMask[i] = -1; | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (op->type == OpType_Extra && op->main.AsExtra()->type == "Merge") { | 
					
						
							|  |  |  |             tensorMask[op->outputIndexes[0]] = -2; | 
					
						
							|  |  |  |             opMask[i] = -2; | 
					
						
							|  |  |  |             condOp->outputIndexes.emplace_back(op->outputIndexes[0]); | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         bool valid = false; | 
					
						
							|  |  |  |         for (auto index : op->inputIndexes) { | 
					
						
							|  |  |  |             if (tensorMask[index] > 0) { | 
					
						
							|  |  |  |                 opMask[i] = tensorMask[index]; | 
					
						
							|  |  |  |                 valid = true; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (auto index : op->outputIndexes) { | 
					
						
							|  |  |  |             if (tensorMask[index] > 0) { | 
					
						
							|  |  |  |                 MNN_ASSERT(opMask[i] <= 0 || opMask[i] == tensorMask[index]); | 
					
						
							|  |  |  |                 opMask[i] = tensorMask[index]; | 
					
						
							|  |  |  |                 valid = true; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (valid) { | 
					
						
							|  |  |  |             for (auto index : op->inputIndexes) { | 
					
						
							|  |  |  |                 tensorMask[index] = opMask[i]; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             for (auto index : op->outputIndexes) { | 
					
						
							|  |  |  |                 tensorMask[index] = opMask[i]; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     // Remove Switch
 | 
					
						
							|  |  |  |     bool needCheck = true; | 
					
						
							|  |  |  |     std::map<int, int> replaceTensor; | 
					
						
							|  |  |  |     while (needCheck) { | 
					
						
							|  |  |  |         needCheck = false; | 
					
						
							|  |  |  |         auto nodes = std::move(cNode->nodes); | 
					
						
							|  |  |  |         for (int i = 0; i < nodes.size(); ++i) { | 
					
						
							|  |  |  |             if (nodes[i]->type == OpType_Extra && nodes[i]->main.AsExtra()->type == "Switch" && (!needCheck)) { | 
					
						
							|  |  |  |                 // Once Time remove only one switch
 | 
					
						
							|  |  |  |                 for (auto output : nodes[i]->outputIndexes) { | 
					
						
							|  |  |  |                     replaceTensor.insert(std::make_pair(output, nodes[i]->inputIndexes[0])); | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |                 needCheck = true; | 
					
						
							|  |  |  |                 continue; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             cNode->nodes.emplace_back(std::move(nodes[i])); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (auto& op : cNode->nodes) { | 
					
						
							|  |  |  |             for (int i = 0; i < op->inputIndexes.size(); ++i) { | 
					
						
							|  |  |  |                 if (replaceTensor.find(op->inputIndexes[i]) != replaceTensor.end()) { | 
					
						
							|  |  |  |                     op->inputIndexes[i] = replaceTensor[op->inputIndexes[i]]; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     std::map<int, Express::VARP> varMap; | 
					
						
							|  |  |  |     std::set<OpT*> invalidSet; | 
					
						
							|  |  |  |     std::vector<int> inputIndexes; | 
					
						
							|  |  |  |     std::set<int> extraInputIndexes; | 
					
						
							|  |  |  |     std::vector<int> leftOutputs; | 
					
						
							|  |  |  |     std::vector<int> rightOutputs; | 
					
						
							|  |  |  |     std::vector<std::string> mergeNames; | 
					
						
							|  |  |  |     for (auto& node : cNode->nodes) { | 
					
						
							|  |  |  |         if (node->type == OpType_Extra && node->main.AsExtra()->type == "Merge") { | 
					
						
							|  |  |  |             mergeNames.emplace_back(node->name); | 
					
						
							|  |  |  |             if (tensorMask[node->inputIndexes[0]] == 1) { | 
					
						
							|  |  |  |                 leftOutputs.emplace_back(node->inputIndexes[0]); | 
					
						
							|  |  |  |                 rightOutputs.emplace_back(node->inputIndexes[1]); | 
					
						
							|  |  |  |             } else { | 
					
						
							|  |  |  |                 leftOutputs.emplace_back(node->inputIndexes[1]); | 
					
						
							|  |  |  |                 rightOutputs.emplace_back(node->inputIndexes[0]); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         Express::Program::createUnit(varMap, inputIndexes, cNode->nodes, node.get(), netT, invalidSet, extraInputIndexes); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     auto makeSubGraph = [&](const std::vector<int>& index) { | 
					
						
							|  |  |  |         std::vector<Express::VARP> out; | 
					
						
							|  |  |  |         for (auto l : index) { | 
					
						
							|  |  |  |             auto iter = varMap.find(l); | 
					
						
							|  |  |  |             if (iter != varMap.end()) { | 
					
						
							|  |  |  |                 out.emplace_back(iter->second); | 
					
						
							|  |  |  |             } else { | 
					
						
							|  |  |  |                 auto tempInput = Express::_Input(); | 
					
						
							|  |  |  |                 tempInput->setName(netT->tensorName[l]); | 
					
						
							|  |  |  |                 out.emplace_back(tempInput); | 
					
						
							|  |  |  |                 extraInputIndexes.insert(l); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         std::unique_ptr<NetT> newT(new NetT); | 
					
						
							|  |  |  |         Express::Variable::save(out, newT.get()); | 
					
						
							|  |  |  |         std::unique_ptr<SubGraphProtoT> subGraph(new SubGraphProtoT); | 
					
						
							|  |  |  |         subGraph->tensors = std::move(newT->tensorName); | 
					
						
							|  |  |  |         subGraph->nodes = std::move(newT->oplists); | 
					
						
							|  |  |  |         for (int i = 0; i < subGraph->nodes.size(); ++i) { | 
					
						
							|  |  |  |             if (subGraph->nodes[i]->type == OpType_Input) { | 
					
						
							|  |  |  |                 subGraph->inputs.emplace_back(i); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (auto l : index) { | 
					
						
							|  |  |  |             auto& outputName = netT->tensorName[l]; | 
					
						
							|  |  |  |             for (int i = 0; i < subGraph->tensors.size(); ++i) { | 
					
						
							|  |  |  |                 if (subGraph->tensors[i] == outputName) { | 
					
						
							|  |  |  |                     subGraph->outputs.emplace_back(i); | 
					
						
							|  |  |  |                     break; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return subGraph; | 
					
						
							|  |  |  |     }; | 
					
						
							|  |  |  |     { | 
					
						
							|  |  |  |         auto leftGraph = makeSubGraph(leftOutputs); | 
					
						
							|  |  |  |         leftGraph->name = cNode->name + "/then"; | 
					
						
							|  |  |  |         condOp->main.AsIfParam()->then_graph = leftGraph->name; | 
					
						
							|  |  |  |         netT->subgraphs.emplace_back(std::move(leftGraph)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auto rightGraph = makeSubGraph(rightOutputs); | 
					
						
							|  |  |  |         rightGraph->name = cNode->name + "/else"; | 
					
						
							|  |  |  |         condOp->main.AsIfParam()->else_graph = rightGraph->name; | 
					
						
							|  |  |  |         netT->subgraphs.emplace_back(std::move(rightGraph)); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     condOp->inputIndexes.emplace_back(condTensorIndex); | 
					
						
							|  |  |  |     std::unique_ptr<StringVecT> inputT(new StringVecT); | 
					
						
							|  |  |  |     inputT->data.emplace_back(netT->tensorName[condTensorIndex]); | 
					
						
							|  |  |  |     condOp->main.AsIfParam()->aliases_inputs.emplace_back(std::move(inputT)); | 
					
						
							|  |  |  |     extraInputIndexes.erase(condTensorIndex); | 
					
						
							|  |  |  |     for (auto index : extraInputIndexes) { | 
					
						
							|  |  |  |         condOp->inputIndexes.emplace_back(index); | 
					
						
							|  |  |  |         std::unique_ptr<StringVecT> inputT(new StringVecT); | 
					
						
							|  |  |  |         inputT->data.emplace_back(netT->tensorName[index]); | 
					
						
							|  |  |  |         condOp->main.AsIfParam()->aliases_inputs.emplace_back(std::move(inputT)); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     for (int i = 0; i < leftOutputs.size(); ++i) { | 
					
						
							|  |  |  |         std::unique_ptr<StringVecT> outputPari(new StringVecT); | 
					
						
							|  |  |  |         outputPari->data.emplace_back(netT->tensorName[leftOutputs[i]]); | 
					
						
							|  |  |  |         outputPari->data.emplace_back(netT->tensorName[rightOutputs[i]]); | 
					
						
							|  |  |  |         condOp->main.AsIfParam()->aliases_outputs.emplace_back(std::move(outputPari)); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     // Compability for old usage
 | 
					
						
							|  |  |  |     for (int i = 0; i < condOp->outputIndexes.size(); ++i) { | 
					
						
							|  |  |  |         std::ostringstream newName; | 
					
						
							|  |  |  |         newName << condOp->name << ":" << i; | 
					
						
							|  |  |  |         netT->tensorName[condOp->outputIndexes[i]] = newName.str(); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     res.emplace_back(std::move(condOp)); | 
					
						
							|  |  |  |     cNode->nodes.clear(); | 
					
						
							|  |  |  |     return res; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | std::vector<std::unique_ptr<OpT>> _makeWhile(std::shared_ptr<ClusterNode> cNode, MNN::NetT* netT, const std::map<std::string, int>& originTensorIndexes) { | 
					
						
							|  |  |  |     std::vector<std::unique_ptr<OpT>> res; | 
					
						
							|  |  |  |     // Remove switch and find LoopCond
 | 
					
						
							|  |  |  |     int loopCond = -1; | 
					
						
							|  |  |  |     { | 
					
						
							|  |  |  |         std::map<int, int> replaceTensor; | 
					
						
							|  |  |  |         auto childs = std::move(cNode->nodes); | 
					
						
							|  |  |  |         for (auto& op : childs) { | 
					
						
							|  |  |  |             if (op->type == OpType_Extra && op->main.AsExtra()->type == "Switch") { | 
					
						
							|  |  |  |                 for (auto o : op->outputIndexes) { | 
					
						
							|  |  |  |                     replaceTensor.insert(std::make_pair(o, op->inputIndexes[0])); | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |                 continue; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             if (op->type == OpType_Extra && op->main.AsExtra()->type == "LoopCond") { | 
					
						
							|  |  |  |                 loopCond = op->outputIndexes[0]; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             cNode->nodes.emplace_back(std::move(op)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (auto& op : cNode->nodes) { | 
					
						
							|  |  |  |             for (int i = 0; i < op->inputIndexes.size(); ++i) { | 
					
						
							|  |  |  |                 if (replaceTensor.find(op->inputIndexes[i]) != replaceTensor.end()) { | 
					
						
							|  |  |  |                     op->inputIndexes[i] = replaceTensor[op->inputIndexes[i]]; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     MNN_ASSERT(loopCond != -1); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Generate Condition Graph
 | 
					
						
							|  |  |  |     std::map<int, Express::VARP> varMap; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // While Op
 | 
					
						
							|  |  |  |     std::unique_ptr<SubGraphProtoT> condGraph(new SubGraphProtoT); | 
					
						
							|  |  |  |     condGraph->name = cNode->name + "/cond"; | 
					
						
							|  |  |  |     std::unique_ptr<SubGraphProtoT> bodyGraph(new SubGraphProtoT); | 
					
						
							|  |  |  |     bodyGraph->name = cNode->name + "/body"; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     std::unique_ptr<OpT> whileOpU(new OpT); | 
					
						
							|  |  |  |     auto whileOp = whileOpU.get();// For easy to debug
 | 
					
						
							|  |  |  |     whileOp->type = OpType_While; | 
					
						
							|  |  |  |     whileOp->main.type = OpParameter_WhileParam; | 
					
						
							|  |  |  |     whileOp->main.value = new WhileParamT; | 
					
						
							|  |  |  |     whileOp->name = cNode->name; | 
					
						
							|  |  |  |     auto whileParam = whileOp->main.AsWhileParam(); | 
					
						
							|  |  |  |     whileParam->cond_graph = condGraph->name; | 
					
						
							|  |  |  |     whileParam->body_graph = bodyGraph->name; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |     std::set<int> extraInputIndexes; | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  |     // Remove Merge and find body
 | 
					
						
							|  |  |  |     std::vector<int> bodyUpdate; | 
					
						
							|  |  |  |     std::set<std::string> bodyOutputNames; | 
					
						
							|  |  |  |     { | 
					
						
							|  |  |  |         std::vector<std::pair<int, int>> updateIndexes; | 
					
						
							|  |  |  |         auto childs = std::move(cNode->nodes); | 
					
						
							|  |  |  |         std::map<int, int> replaceTensor; | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |         std::set<int> updateToTensors; | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  |         for (auto& op : childs) { | 
					
						
							|  |  |  |             if (op->type == OpType_Extra && op->main.AsExtra()->type == "Merge") { | 
					
						
							| 
									
										
										
										
											2021-04-08 15:34:23 +08:00
										 |  |  |                 int updateFromIdx = op->inputIndexes[1], updateToIdx = op->inputIndexes[0]; | 
					
						
							|  |  |  |                 // if tensor_x is at outside of loop and used by two op, and these two op
 | 
					
						
							|  |  |  |                 // are all update data, so need copy tensor_x to tensor_x_copy.
 | 
					
						
							|  |  |  |                 if (updateToTensors.find(updateToIdx) != updateToTensors.end()) { | 
					
						
							|  |  |  |                     std::unique_ptr<OpT> copyOp(new OpT); | 
					
						
							|  |  |  |                     copyOp->type = OpType_Concat; | 
					
						
							|  |  |  |                     copyOp->inputIndexes.push_back(updateToIdx); | 
					
						
							|  |  |  |                     auto opName = netT->tensorName[updateToIdx] + "_copy"; | 
					
						
							|  |  |  |                     updateToIdx = netT->tensorName.size(); | 
					
						
							|  |  |  |                     copyOp->outputIndexes.push_back(updateToIdx); | 
					
						
							|  |  |  |                     netT->tensorName.push_back(opName); | 
					
						
							|  |  |  |                     netT->tensorNumber++; | 
					
						
							|  |  |  |                     res.emplace_back(std::move(copyOp)); | 
					
						
							|  |  |  |                     extraInputIndexes.insert(updateToIdx); | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |                 updateToTensors.insert(updateToIdx); | 
					
						
							|  |  |  |                 updateIndexes.emplace_back(std::make_pair(updateFromIdx, updateToIdx)); | 
					
						
							|  |  |  |                 replaceTensor.insert(std::make_pair(op->outputIndexes[0], updateToIdx)); | 
					
						
							|  |  |  |                 bodyUpdate.emplace_back(updateFromIdx); | 
					
						
							|  |  |  |                 bodyOutputNames.insert(netT->tensorName[updateFromIdx]); | 
					
						
							| 
									
										
										
										
											2021-02-07 10:45:07 +08:00
										 |  |  |                 continue; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             cNode->nodes.emplace_back(std::move(op)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (auto& op : cNode->nodes) { | 
					
						
							|  |  |  |             for (int i = 0; i < op->inputIndexes.size(); ++i) { | 
					
						
							|  |  |  |                 if (replaceTensor.find(op->inputIndexes[i]) != replaceTensor.end()) { | 
					
						
							|  |  |  |                     op->inputIndexes[i] = replaceTensor[op->inputIndexes[i]]; | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (auto& p : updateIndexes) { | 
					
						
							|  |  |  |             if (replaceTensor.find(p.first) != replaceTensor.end()) { | 
					
						
							|  |  |  |                 p.first = replaceTensor[p.first]; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             if (replaceTensor.find(p.second) != replaceTensor.end()) { | 
					
						
							|  |  |  |                 p.second = replaceTensor[p.second]; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (auto& p : updateIndexes) { | 
					
						
							|  |  |  |             std::unique_ptr<StringVecT> updateName(new StringVecT); | 
					
						
							|  |  |  |             updateName->data.emplace_back(netT->tensorName[p.first]); | 
					
						
							|  |  |  |             updateName->data.emplace_back(netT->tensorName[p.second]); | 
					
						
							|  |  |  |             whileParam->aliases_updates.emplace_back(std::move(updateName)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Get output
 | 
					
						
							|  |  |  |     for (auto& op : cNode->nodes) { | 
					
						
							|  |  |  |         if (op->type != OpType_Extra) { | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (op->main.AsExtra()->type == "Exit") { | 
					
						
							|  |  |  |             whileOp->outputIndexes.emplace_back(op->outputIndexes[0]); | 
					
						
							|  |  |  |             whileParam->aliases_outputs.emplace_back(netT->tensorName[op->inputIndexes[0]]); | 
					
						
							|  |  |  |             bodyOutputNames.insert(netT->tensorName[op->inputIndexes[0]]); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Create Loop Cond
 | 
					
						
							|  |  |  |     std::set<OpT*> invalidSet; | 
					
						
							|  |  |  |     std::vector<int> inputIndexes; | 
					
						
							|  |  |  |     for (auto& node : cNode->nodes) { | 
					
						
							|  |  |  |         Express::Program::createUnit(varMap, inputIndexes, cNode->nodes, node.get(), netT, invalidSet, extraInputIndexes); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     for (auto index : extraInputIndexes) { | 
					
						
							|  |  |  |         std::unique_ptr<StringVecT> inputNames(new StringVecT); | 
					
						
							|  |  |  |         inputNames->data.emplace_back(netT->tensorName[index]); | 
					
						
							|  |  |  |         whileParam->aliases_inputs.emplace_back(std::move(inputNames)); | 
					
						
							|  |  |  |         whileOp->inputIndexes.emplace_back(index); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     { | 
					
						
							|  |  |  |         std::unique_ptr<NetT> condNet(new NetT); | 
					
						
							|  |  |  |         Express::Variable::save({varMap[loopCond]}, condNet.get()); | 
					
						
							|  |  |  |         for (auto& op : condNet->oplists) { | 
					
						
							|  |  |  |             if (op->type == OpType_Extra && op->main.AsExtra()->type == "LoopCond") { | 
					
						
							|  |  |  |                 condGraph->outputs.emplace_back(op->inputIndexes[0]); | 
					
						
							|  |  |  |                 continue; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             if (op->type == OpType_Input) { | 
					
						
							|  |  |  |                 condGraph->inputs.emplace_back(op->outputIndexes[0]); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             condGraph->nodes.emplace_back(std::move(op)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         condGraph->tensors = std::move(condNet->tensorName); | 
					
						
							|  |  |  |         MNN_ASSERT(condGraph->outputs.size() > 0); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     { | 
					
						
							|  |  |  |         std::unique_ptr<NetT> bodyNet(new NetT); | 
					
						
							|  |  |  |         std::vector<Express::VARP> bodyOutputs; | 
					
						
							|  |  |  |         for (auto b : bodyUpdate) { | 
					
						
							|  |  |  |             if (varMap.find(b) != varMap.end()) { | 
					
						
							|  |  |  |                 bodyOutputs.emplace_back(varMap[b]); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         Express::Variable::save(bodyOutputs, bodyNet.get()); | 
					
						
							|  |  |  |         for (auto& op : bodyNet->oplists) { | 
					
						
							|  |  |  |             if (op->type == OpType_Input) { | 
					
						
							|  |  |  |                 bodyGraph->inputs.emplace_back(op->outputIndexes[0]); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             for (auto o : op->outputIndexes) { | 
					
						
							|  |  |  |                 if (bodyOutputNames.find(bodyNet->tensorName[o]) != bodyOutputNames.end()) { | 
					
						
							|  |  |  |                     bodyGraph->outputs.emplace_back(o); | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             bodyGraph->nodes.emplace_back(std::move(op)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         bodyGraph->tensors = std::move(bodyNet->tensorName); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     { | 
					
						
							|  |  |  |         // Const op needed update turn to Input
 | 
					
						
							|  |  |  |         auto turnConst = [&](SubGraphProtoT* subGraph) { | 
					
						
							|  |  |  |             for (auto& s : whileParam->aliases_updates) { | 
					
						
							|  |  |  |                 auto& second = s->data[1]; | 
					
						
							|  |  |  |                 for (int i = 0; i < subGraph->nodes.size(); ++i) { | 
					
						
							|  |  |  |                     auto& op = subGraph->nodes[i]; | 
					
						
							|  |  |  |                     if (OpType_Const != op->type) { | 
					
						
							|  |  |  |                         continue; | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                     if (subGraph->tensors[op->outputIndexes[0]] == second) { | 
					
						
							|  |  |  |                         // Const move outside
 | 
					
						
							|  |  |  |                         auto opPtr = op.get(); | 
					
						
							|  |  |  |                         res.emplace_back(std::move(op)); | 
					
						
							|  |  |  |                         subGraph->nodes[i].reset(new OpT); | 
					
						
							|  |  |  |                         subGraph->nodes[i]->type = OpType_Input; | 
					
						
							|  |  |  |                         subGraph->nodes[i]->main.type = OpParameter_Input; | 
					
						
							|  |  |  |                         subGraph->nodes[i]->main.value = new InputT; | 
					
						
							|  |  |  |                         subGraph->nodes[i]->main.AsInput()->dims = opPtr->main.AsBlob()->dims; | 
					
						
							|  |  |  |                         subGraph->nodes[i]->main.AsInput()->dtype = opPtr->main.AsBlob()->dataType; | 
					
						
							|  |  |  |                         subGraph->nodes[i]->main.AsInput()->dformat = opPtr->main.AsBlob()->dataFormat; | 
					
						
							|  |  |  |                         subGraph->nodes[i]->outputIndexes = opPtr->outputIndexes; | 
					
						
							|  |  |  |                         opPtr->outputIndexes[0] = originTensorIndexes.find(second)->second; | 
					
						
							|  |  |  |                         std::unique_ptr<StringVecT> newVecT(new StringVecT); | 
					
						
							|  |  |  |                         newVecT->data.emplace_back(second); | 
					
						
							|  |  |  |                         whileParam->aliases_inputs.emplace_back(std::move(newVecT)); | 
					
						
							|  |  |  |                         whileOp->inputIndexes.emplace_back(opPtr->outputIndexes[0]); | 
					
						
							|  |  |  |                     } | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         }; | 
					
						
							|  |  |  |         turnConst(condGraph.get()); | 
					
						
							|  |  |  |         turnConst(bodyGraph.get()); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     //FUNC_PRINT_ALL(whileOp->name.c_str(), s);
 | 
					
						
							|  |  |  |     netT->subgraphs.emplace_back(std::move(condGraph)); | 
					
						
							|  |  |  |     netT->subgraphs.emplace_back(std::move(bodyGraph)); | 
					
						
							|  |  |  |     res.emplace_back(std::move(whileOpU)); | 
					
						
							|  |  |  |     cNode->nodes.clear(); | 
					
						
							|  |  |  |     return res; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | static std::vector<std::unique_ptr<OpT>> _makeSubGraph(std::shared_ptr<ClusterNode> cNode, MNN::NetT* netT, const std::map<std::string, int>& t) { | 
					
						
							|  |  |  |     // Make Subgraph In order, first make children, second make parent
 | 
					
						
							|  |  |  |     for (auto c : cNode->children) { | 
					
						
							|  |  |  |         auto opList = std::move(_makeSubGraph(c, netT, t)); | 
					
						
							|  |  |  |         for (auto&& op : opList) { | 
					
						
							|  |  |  |             cNode->nodes.emplace_back(std::move(op)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     if (cNode->hasLoop) { | 
					
						
							|  |  |  |         return _makeWhile(cNode, netT, t); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     if (cNode->hasMerge) { | 
					
						
							|  |  |  |         return _makeCond(cNode, netT, t); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return {}; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | int GenerateSubGraph(std::unique_ptr<MNN::NetT>& netT) { | 
					
						
							|  |  |  |     // Remove unuseful op before cluster
 | 
					
						
							|  |  |  |     std::vector<std::string> passes = { | 
					
						
							|  |  |  |         "RemoveUnusefulOp", | 
					
						
							|  |  |  |     }; | 
					
						
							|  |  |  |     for (auto pass : passes) { | 
					
						
							|  |  |  |         auto convert = PostConverter::get(pass); | 
					
						
							|  |  |  |         if (nullptr == convert) { | 
					
						
							|  |  |  |             continue; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         convert->onExecute(netT); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     bool hasControlFlow = false; | 
					
						
							|  |  |  |     for (auto& op : netT->oplists) { | 
					
						
							|  |  |  |         if (_isControlOp(op.get())) { | 
					
						
							|  |  |  |             hasControlFlow = true; | 
					
						
							|  |  |  |             break; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     if (!hasControlFlow) { | 
					
						
							|  |  |  |         return 0; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     MNN_PRINT("The modle has control flow, please use MNN::Module to run it\n"); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // We broadly divided all nodes into clusters by the prefix of the node
 | 
					
						
							|  |  |  |     // name, and each cluster belongs to one of the tree categories,
 | 
					
						
							|  |  |  |     // Normal, Condition or WhileLoop.
 | 
					
						
							|  |  |  |     // The nodes which have the same name prefix maybe belong to the same
 | 
					
						
							|  |  |  |     // cluster. The nodes that type is `Condition` maybe belong to a condition
 | 
					
						
							|  |  |  |     // subgraph. The nodes that type is `WhileLoop` maybe belong to a while loop
 | 
					
						
							|  |  |  |     // subgraph.
 | 
					
						
							|  |  |  |     std::map<std::string, std::shared_ptr<ClusterNode>> clusters; | 
					
						
							|  |  |  |     std::vector<std::shared_ptr<ClusterNode>> rootClusters; | 
					
						
							|  |  |  |     bool hasControlflow = false; | 
					
						
							|  |  |  |     for (auto& node : netT->oplists) { | 
					
						
							|  |  |  |         std::string name = RSplitString(node->name, "/").at(0); | 
					
						
							|  |  |  |         _makeClusterNode(name, clusters, rootClusters); | 
					
						
							|  |  |  |         auto it = clusters.find(name); | 
					
						
							|  |  |  |         if (node->type == OpType_Extra) { | 
					
						
							|  |  |  |             auto type = node->main.AsExtra()->type; | 
					
						
							|  |  |  |             if (type == "LoopCond") { | 
					
						
							|  |  |  |                 hasControlflow = true; | 
					
						
							|  |  |  |                 it->second->hasLoop = true; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             else if (type == "Switch") { | 
					
						
							|  |  |  |                 hasControlflow = true; | 
					
						
							|  |  |  |                 it->second->hasSwitch = true; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             else if (type == "Merge") { | 
					
						
							|  |  |  |                 hasControlflow = true; | 
					
						
							|  |  |  |                 it->second->hasMerge = true; | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         it->second->nodes.emplace_back(std::move(node)); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     netT->oplists.clear(); | 
					
						
							|  |  |  |     std::map<std::string, int> tensorNameMap; | 
					
						
							|  |  |  |     for (int i=0; i<netT->tensorName.size(); ++i) { | 
					
						
							|  |  |  |         tensorNameMap[netT->tensorName[i]] = i; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     for (auto n : rootClusters) { | 
					
						
							|  |  |  |         _mergeSubGraph(n); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | #ifdef MNN_PRINT_SUBGRAPH
 | 
					
						
							|  |  |  |     for (auto n : rootClusters) { | 
					
						
							|  |  |  |         _printSubGraph(n); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  |     for (auto n : rootClusters) { | 
					
						
							|  |  |  |         auto controlOp = _makeSubGraph(n, netT.get(), tensorNameMap); | 
					
						
							|  |  |  |         for (auto& c : n->nodes) { | 
					
						
							|  |  |  |             netT->oplists.emplace_back(std::move(c)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         for (auto& op : controlOp) { | 
					
						
							|  |  |  |             netT->oplists.emplace_back(std::move(op)); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return 0; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | } |