MNN/tools/converter/source/optimizer/PostConverter.cpp

//
// PostConverter.cpp
// MNNConverter
//
// Created by MNN on 2019/01/31.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "PostConverter.hpp"
#include "PostTreatUtils.hpp"
#include "Program.hpp"
#include "TemplateMerge.hpp"
#include "Optimizer.hpp"
using namespace MNN::Express;
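
// Optimize the converted NetT in three stages:
//   1. run source-level post-converter passes directly on the op list;
//   2. rebuild the net as an expression program and apply template-merge
//      optimizers (source-specific TFExtra / OnnxExtra, then generic Merge);
//   3. run graph-level passes (fusion, format conversion, re-indexing) on
//      the regenerated net, then log its inputs and outputs.
// Returns nullptr if the input net has no ops.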
std::unique_ptr<MNN::NetT> optimizeNet(std::unique_ptr<MNN::NetT>& originNet) {
    if (originNet->oplists.empty()) {
        return nullptr;
    }
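    // Stage 1: source-level passes that normalize ops coming from
    // Caffe / ONNX before the graph-level optimization below.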
    std::vector<std::string> postConvertPass = {
        // Separate tensors for inplace ops
        "RemoveInplace",
        // Remove useless ops such as NoOp, Identity, Dropout, Seq2Out
        "RemoveUnusefulOp",
        // Turn InnerProduct from Caffe / ONNX into Convolution
        "TransformInnerProduct",
        // Turn Im2Seq from Caffe into Reshape
        "TransformIm2Seq",
        // Turn Caffe's ShuffleChannel into a composed op
        "TransformShuffleChannel",
        // Turn ONNX's Pad into TensorFlow's Pad
        "TransformOnnxPad",
    };
    for (const auto& pass : postConvertPass) {
        auto convert = PostConverter::get(pass);
        if (nullptr == convert) {
            LOG(INFO) << "Can't find pass " << pass << "\n";
            continue;
        }
        bool valid = convert->onExecute(originNet);
        if (!valid) {
            LOG(INFO) << "Run " << pass << " error\n";
        }
    }
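    // Stage 2: rebuild the net as an expression program so that the
    // template-based merge optimizers can pattern-match subgraphs.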
    auto program = MNN::Express::Program::create(originNet.get(), true);
    std::vector<std::string> optimizePass = {
        "Merge",
    };
    switch (originNet->sourceType) {
        case MNN::NetSource_TENSORFLOW:
            optimizePass.insert(optimizePass.begin(), "TFExtra");
            break;
        case MNN::NetSource_ONNX:
            optimizePass.insert(optimizePass.begin(), "OnnxExtra");
            break;
        default:
            break;
    }
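    // Source-specific extra-op passes run before the generic Merge pass;
    // each pass is applied over the program's output variables.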
    for (const auto& pass : optimizePass) {
        auto& merge = MNN::Express::TemplateMerge::getInstance(pass);
        merge.onExecute(program->outputs());
    }
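    // A net with control flow or unresolved extra ops cannot be fully
    // serialized, so emit C++ source (model.cpp) that rebuilds it instead.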
    if (program->needGenerateCode()) {
        MNN_PRINT("The model has control flow / extra ops, please compile the generated model.cpp\n");
        std::ofstream code("model.cpp");
        code << "#include \"Expr.hpp\"\n";
        code << "#include \"ExprCreator.hpp\"\n";
        code << "using namespace MNN::Express;\n";
        code << "void extraCall(std::map<std::string, VARP>& varMap) {\n";
        program->emit(code);
        code << "}\n";
    }
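    // Serialize the optimized program back into a fresh NetT, preserving
    // the original source type and bizCode.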
    std::unique_ptr<MNN::NetT> newNet(new MNN::NetT);
    {
        auto outputs = program->outputs();
        newNet->sourceType = originNet->sourceType;
        newNet->bizCode    = originNet->bizCode;
        Variable::save(outputs, newNet.get());
    }
    std::vector<std::string> afterProgramConvert = {
        // Turn BatchNorm into Scale for inference
        "TransformBatchNormal",
        // Merge Scale into Convolution
        "MergeToConvolution",
        // Convert some binary ops (add, mul, sub, ...) into element-wise ops (sum, sub) according to their inputs
        "ConvertBinaryToElementwise",
        // Turn group convolution into Slice - Convolution - Concat
        "TransformGroupConvolution",
        // Add tensor format converters for NC4HW4 - NHWC / NC4HW4 - NCHW
        "AddTensorFormatConverter",
        // Remove unused tensors and re-index the rest
        "ReIndexTensor",
    };
    for (const auto& pass : afterProgramConvert) {
        auto convert = PostConverter::get(pass);
        if (nullptr == convert) {
            LOG(INFO) << "Can't find pass " << pass << "\n";
            continue;
        }
        bool valid = convert->onExecute(newNet);
        if (!valid) {
            LOG(INFO) << "Run " << pass << " error\n";
        }
    }
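    // Collect every tensor index consumed as an op input; any tensor that
    // is produced but never consumed is reported as a network output.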
    std::set<int> inputSet;
    for (auto& op : newNet->oplists) {
        if (op->type == MNN::OpType_Input) {
            LOG(INFO) << "Inputs: " << newNet->tensorName[op->outputIndexes[0]];
            continue;
        }
        for (auto index : op->inputIndexes) {
            inputSet.insert(index);
        }
    }
    for (auto& op : newNet->oplists) {
        bool hasInput = false;
        for (auto index : op->outputIndexes) {
            if (inputSet.find(index) != inputSet.end()) {
                hasInput = true;
                break;
            }
        }
        if (!hasInput) {
            for (auto index : op->outputIndexes) {
                LOG(INFO) << "Outputs: " << newNet->tensorName[index]
                          << ", Type = " << MNN::EnumNameOpType(op->type);
            }
        }
    }
    return newNet;
}