mirror of https://github.com/alibaba/MNN.git
193 lines
7.1 KiB
C++
193 lines
7.1 KiB
C++
//
|
|
// cli.cpp
|
|
// MNNConverter
|
|
//
|
|
// Created by MNN on 2019/01/31.
|
|
// Copyright © 2018, Alibaba Group Holding Limited
|
|
//
|
|
|
|
#include "cli.hpp"
|
|
#if defined(_MSC_VER)
|
|
#include <Windows.h>
|
|
#undef min
|
|
#undef max
|
|
#else
|
|
#include <unistd.h>
|
|
#endif
|
|
#include <MNN/VCS.h>
|
|
#include "config.hpp"
|
|
#include "logkit.h"
|
|
|
|
/**
|
|
* Print Command Line Banner
|
|
*/
|
|
void Cli::printProjectBanner() {
|
|
// print project detail
|
|
// auto config = ProjectConfig::obtainSingletonInstance();
|
|
|
|
std::cout << "\nMNNConverter Version: " << ProjectConfig::version << " - MNN @ 2018\n\n" << std::endl;
|
|
}
|
|
|
|
cxxopts::Options Cli::initializeMNNConvertArgs(modelConfig &modelPath, int argc, char **argv) {
|
|
cxxopts::Options options("MNNConvert");
|
|
|
|
options.positional_help("[optional args]").show_positional_help();
|
|
|
|
options.allow_unrecognised_options().add_options()("h, help", "Convert Other Model Format To MNN Model\n")(
|
|
"v, version", "show current version")("f, framework", "model type, ex: [TF,CAFFE,ONNX,TFLITE,MNN]",
|
|
cxxopts::value<std::string>())(
|
|
"modelFile", "tensorflow Pb or caffeModel, ex: *.pb,*caffemodel", cxxopts::value<std::string>())(
|
|
"prototxt", "only used for caffe, ex: *.prototxt", cxxopts::value<std::string>())(
|
|
"MNNModel", "MNN model, ex: *.mnn", cxxopts::value<std::string>())(
|
|
"fp16", "save Conv's weight/bias in half_float data type")(
|
|
"benchmarkModel",
|
|
"Do NOT save big size data, such as Conv's weight,BN's gamma,beta,mean and variance etc. Only used to test "
|
|
"the cost of the model")("bizCode", "MNN Model Flag, ex: MNN", cxxopts::value<std::string>())(
|
|
"debug", "Enable debugging mode.")(
|
|
"forTraining", "whether or not to save training ops BN and Dropout, default: false", cxxopts::value<bool>())(
|
|
"weightQuantBits", "save conv/matmul/LSTM float weights to int8 type, only optimize for model size, 2-8 bits, default: 0, which means no weight quant", cxxopts::value<int>())(
|
|
"weightQuantAsymmetric", "the default weight-quant uses SYMMETRIC quant method, which is compatible with old MNN versions. "
|
|
"you can try set --weightQuantAsymmetric to use asymmetric quant method to improve accuracy of the weight-quant model in some cases, "
|
|
"but asymmetric quant model cannot run on old MNN versions. You will need to upgrade MNN to new version to solve this problem. default: false", cxxopts::value<bool>())(
|
|
"compressionParamsFile",
|
|
"The path of the compression parameters that stores activation, "
|
|
"weight scales and zero points for quantization or information "
|
|
"for sparsity.", cxxopts::value<std::string>())(
|
|
"saveStaticModel", "save static model with fix shape, default: false", cxxopts::value<bool>())(
|
|
"inputConfigFile", "set input config file for static model, ex: ~/config.txt", cxxopts::value<std::string>());
|
|
|
|
auto result = options.parse(argc, argv);
|
|
|
|
if (result.count("help")) {
|
|
std::cout << options.help({""}) << std::endl;
|
|
exit(EXIT_SUCCESS);
|
|
}
|
|
|
|
if (result.count("version")) {
|
|
std::cout << ProjectConfig::version << std::endl;
|
|
exit(EXIT_SUCCESS);
|
|
}
|
|
|
|
modelPath.model = modelPath.MAX_SOURCE;
|
|
// model source
|
|
if (result.count("framework")) {
|
|
const std::string frameWork = result["framework"].as<std::string>();
|
|
if ("TF" == frameWork) {
|
|
modelPath.model = modelConfig::TENSORFLOW;
|
|
} else if ("CAFFE" == frameWork) {
|
|
modelPath.model = modelConfig::CAFFE;
|
|
} else if ("ONNX" == frameWork) {
|
|
modelPath.model = modelConfig::ONNX;
|
|
} else if ("MNN" == frameWork) {
|
|
modelPath.model = modelConfig::MNN;
|
|
} else if ("TFLITE" == frameWork) {
|
|
modelPath.model = modelConfig::TFLITE;
|
|
} else {
|
|
std::cout << "Framework Input ERROR or Not Support This Model Type Now!" << std::endl;
|
|
std::cout << options.help({""}) << std::endl;
|
|
exit(EXIT_FAILURE);
|
|
}
|
|
} else {
|
|
std::cout << options.help({""}) << std::endl;
|
|
exit(EXIT_FAILURE);
|
|
}
|
|
|
|
// model file path
|
|
if (result.count("modelFile")) {
|
|
const std::string modelFile = result["modelFile"].as<std::string>();
|
|
if (CommonKit::FileIsExist(modelFile)) {
|
|
modelPath.modelFile = modelFile;
|
|
} else {
|
|
DLOG(INFO) << "Model File Does Not Exist! ==> " << modelFile;
|
|
exit(EXIT_FAILURE);
|
|
}
|
|
} else {
|
|
std::cout << options.help({""}) << std::endl;
|
|
exit(EXIT_FAILURE);
|
|
}
|
|
|
|
// prototxt file path
|
|
if (result.count("prototxt")) {
|
|
const std::string prototxt = result["prototxt"].as<std::string>();
|
|
if (CommonKit::FileIsExist(prototxt)) {
|
|
modelPath.prototxtFile = prototxt;
|
|
} else {
|
|
DLOG(INFO) << "Model File Does Not Exist!";
|
|
exit(EXIT_FAILURE);
|
|
}
|
|
} else {
|
|
// caffe model must have this option
|
|
if (modelPath.model == modelPath.CAFFE) {
|
|
std::cout << options.help({""}) << std::endl;
|
|
exit(EXIT_FAILURE);
|
|
}
|
|
}
|
|
|
|
// MNN model output path
|
|
if (result.count("MNNModel")) {
|
|
const std::string MNNModelPath = result["MNNModel"].as<std::string>();
|
|
modelPath.MNNModel = MNNModelPath;
|
|
} else {
|
|
std::cout << options.help({""}) << std::endl;
|
|
exit(EXIT_FAILURE);
|
|
}
|
|
|
|
// add MNN bizCode
|
|
if (result.count("bizCode")) {
|
|
const std::string bizCode = result["bizCode"].as<std::string>();
|
|
modelPath.bizCode = bizCode;
|
|
} else {
|
|
std::cout << options.help({""}) << std::endl;
|
|
exit(EXIT_FAILURE);
|
|
}
|
|
|
|
// input config file path
|
|
if (result.count("inputConfigFile")) {
|
|
const std::string inputConfigFile = result["inputConfigFile"].as<std::string>();
|
|
modelPath.inputConfigFile = inputConfigFile;
|
|
}
|
|
|
|
// benchmarkModel
|
|
if (result.count("benchmarkModel")) {
|
|
modelPath.benchmarkModel = true;
|
|
modelPath.bizCode = "benchmark";
|
|
}
|
|
// half float
|
|
if (result.count("fp16")) {
|
|
modelPath.saveHalfFloat = true;
|
|
}
|
|
if (result.count("forTraining")) {
|
|
modelPath.forTraining = true;
|
|
}
|
|
if (result.count("weightQuantBits")) {
|
|
modelPath.weightQuantBits = result["weightQuantBits"].as<int>();
|
|
}
|
|
if (result.count("weightQuantAsymmetric")) {
|
|
modelPath.weightQuantAsymmetric = true;
|
|
}
|
|
if (result.count("saveStaticModel")) {
|
|
modelPath.saveStaticModel = true;
|
|
}
|
|
|
|
// Int8 calibration table path.
|
|
if (result.count("compressionParamsFile")) {
|
|
modelPath.compressionParamsFile =
|
|
result["compressionParamsFile"].as<std::string>();
|
|
}
|
|
|
|
return options;
|
|
}
|
|
|
|
bool CommonKit::FileIsExist(string path) {
|
|
#if defined(_MSC_VER)
|
|
if (INVALID_FILE_ATTRIBUTES != GetFileAttributes(path.c_str()) && GetLastError() != ERROR_FILE_NOT_FOUND) {
|
|
return true;
|
|
}
|
|
#else
|
|
if ((access(path.c_str(), F_OK)) != -1) {
|
|
return true;
|
|
}
|
|
#endif
|
|
return false;
|
|
}
|