// MNN/tools/converter/source/optimizer/postconvert/RemoveTestNoUseOps.cpp

#include "RemoveTestNoUseOps.hpp"

#include <algorithm>
/**
 * Removes the ops selected by shouldDeleteJudge() from `net`, rewires their
 * consumers, then prunes ops whose outputs are no longer referenced.
 *
 * Phase 1 - for each op selected by shouldDeleteJudge():
 *   - "Protected" tensor names are the net's declared output names plus the
 *     output names of OpType_Input ops.
 *   - The op is kept if one of its inputs AND one of its outputs carry
 *     protected names.
 *   - Ops with no inputs or no outputs are erased without rewiring.
 *   - If shouldDeleteOutput() is true, every reference to the op's outputs is
 *     dropped from the input lists of all ops.
 *   - Otherwise every such reference is redirected to the op's first input.
 *     If an output carries a protected name, that name is moved onto the
 *     first input. When two or more outputs are protected, one tensor cannot
 *     take over both names, so the op is kept and nothing is renamed.
 *   - If shouldRemoveUnusefulInputs() is true, the erased op's inputs are
 *     remembered for phase 2.
 *
 * Phase 2 - reference-count pruning:
 *   - Each tensor's count is the number of input slots that consume it.
 *   - Produced-but-unconsumed tensors start at 1, so they are kept, unless
 *     they were remembered in phase 1; those start at 0.
 *   - Ops whose outputs all have a count of 0 (including ops with no outputs)
 *     are erased and their inputs' counts decremented.
 *   - This repeats until a pass erases no op that has inputs.
 *
 * @param net Graph to modify in place.
 * @return Always true.
 */
bool RemoveTestNoUseOps::onExecute(std::unique_ptr<MNN::NetT>& net) const {
    const MNN::NetT* const netPtr = net.get();

    // Names that must survive the pass: net outputs and outputs of Input ops.
    std::set<std::string> netOutputNames(net->outputName.begin(), net->outputName.end());
    for (const auto& op : net->oplists) {
        if (op->type != OpType_Input) {
            continue;
        }
        for (auto o : op->outputIndexes) {
            netOutputNames.insert(net->tensorName[o]);
        }
    }
    // True when tensor `index` currently carries a protected name.
    auto isProtected = [&](int index) {
        return netOutputNames.find(net->tensorName[index]) != netOutputNames.end();
    };

    // ---- Phase 1: delete selected ops and rewire their consumers ----
    std::unordered_set<int> removedInputs;
    for (auto iter = net->oplists.begin(); iter != net->oplists.end();) {
        auto& op = *iter;
        if (!shouldDeleteJudge(op.get(), netPtr)) {
            ++iter;
            continue;
        }
        const bool hasOutputName =
            std::any_of(op->outputIndexes.begin(), op->outputIndexes.end(), isProtected);
        const bool hasOutputFromInput =
            std::any_of(op->inputIndexes.begin(), op->inputIndexes.end(), isProtected);
        // Removing the op would merge two protected tensors into one: keep it.
        if (hasOutputFromInput && hasOutputName) {
            ++iter;
            continue;
        }
        const bool deleteOutput = shouldDeleteOutput(op.get());
        if (op->outputIndexes.empty() || op->inputIndexes.empty()) {
            iter = net->oplists.erase(iter);
            continue;
        }
        const auto originInput   = op->inputIndexes[0];
        const auto originOutputs = op->outputIndexes;
        auto isOriginOutput = [&](int index) {
            return std::find(originOutputs.begin(), originOutputs.end(), index) != originOutputs.end();
        };

        if (!deleteOutput && hasOutputName) {
            // The first input replaces the op's output, so it inherits the
            // output's protected name. No input is protected here (checked
            // above), so the hand-over only fails when several outputs are
            // protected. Validate BEFORE renaming so that a kept op never
            // leaves a duplicated tensor name behind.
            if (std::count_if(originOutputs.begin(), originOutputs.end(), isProtected) > 1) {
                ++iter;
                continue;
            }
            // hasOutputName guarantees exactly one protected output exists.
            const auto named = std::find_if(originOutputs.begin(), originOutputs.end(), isProtected);
            net->tensorName[originInput] = net->tensorName[*named];
        }

        for (auto& consumer : net->oplists) {
            auto& inputs = consumer->inputIndexes;
            if (deleteOutput) {
                // Drop every reference to the removed outputs.
                inputs.erase(std::remove_if(inputs.begin(), inputs.end(), isOriginOutput), inputs.end());
            } else {
                // Bypass the op: consumers read its first input directly.
                for (auto& index : inputs) {
                    if (isOriginOutput(index)) {
                        index = originInput;
                    }
                }
            }
        }
        if (shouldRemoveUnusefulInputs(op.get())) {
            removedInputs.insert(op->inputIndexes.begin(), op->inputIndexes.end());
        }
        iter = net->oplists.erase(iter);
    }

    // ---- Phase 2: remove an op only once the reference counts of all its
    // outputs have dropped to zero ----
    std::unordered_map<int, int /*reference count*/> uselessIndex;
    for (const auto& op : net->oplists) {
        for (int input : op->inputIndexes) {
            ++uselessIndex[input];
        }
    }
    // Outputs nobody consumes (graph sinks) start at 1 so they are kept,
    // except inputs remembered from ops removed in phase 1, which start at 0.
    // emplace() leaves tensors that already have a consumer count untouched.
    for (const auto& op : net->oplists) {
        for (int output : op->outputIndexes) {
            uselessIndex.emplace(output, removedInputs.count(output) ? 0 : 1);
        }
    }

    bool needIteration = false;
    do {
        needIteration = false;
        for (auto iter = net->oplists.begin(); iter != net->oplists.end();) {
            const auto& op = *iter;
            // An op with no outputs is also treated as useless.
            const bool useless = std::none_of(op->outputIndexes.begin(), op->outputIndexes.end(),
                                              [&](int index) { return uselessIndex.at(index) > 0; });
            if (!useless) {
                ++iter;
                continue;
            }
            if (!op->inputIndexes.empty()) {
                for (auto index : op->inputIndexes) {
                    auto it = uselessIndex.find(index);
                    MNN_ASSERT(it != uselessIndex.end());
                    // Guard so that release builds (MNN_ASSERT compiled out)
                    // never dereference end().
                    if (it != uselessIndex.end()) {
                        --it->second;
                    }
                }
                // Counts changed: producers of these inputs may now be useless.
                needIteration = true;
            }
            iter = net->oplists.erase(iter);
        }
    } while (needIteration);
    return true;
}