//
//  quantized.cpp
//  MNN
//
//  Created by MNN on 2019/07/01.
//  Copyright © 2018, Alibaba Group Holding Limited
//
#include <cstring>
#include <fstream>
#include <memory>
#include <sstream>
#include <string>

#include "calibration.hpp"
#include "logkit.h"
int main(int argc, const char* argv[]) {
|
|
|
|
|
if (argc < 4) {
|
2019-08-22 20:13:46 +08:00
|
|
|
DLOG(INFO) << "Usage: ./quantized.out src.mnn dst.mnn preTreatConfig.json\n";
|
2019-07-11 13:56:52 +08:00
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
const char* modelFile = argv[1];
|
|
|
|
|
const char* preTreatConfig = argv[3];
|
|
|
|
|
const char* dstFile = argv[2];
|
2019-08-22 20:13:46 +08:00
|
|
|
DLOG(INFO) << ">>> modelFile: " << modelFile;
|
|
|
|
|
DLOG(INFO) << ">>> preTreatConfig: " << preTreatConfig;
|
|
|
|
|
DLOG(INFO) << ">>> dstFile: " << dstFile;
|
2019-07-11 13:56:52 +08:00
|
|
|
std::unique_ptr<MNN::NetT> netT;
|
|
|
|
|
{
|
2022-03-01 14:33:13 +08:00
|
|
|
//std::ifstream input(modelFile);
|
|
|
|
|
std::ifstream input(modelFile, std::ifstream::in | std::ifstream::binary);
|
2019-07-11 13:56:52 +08:00
|
|
|
std::ostringstream outputOs;
|
|
|
|
|
outputOs << input.rdbuf();
|
|
|
|
|
netT = MNN::UnPackNet(outputOs.str().c_str());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// temp build net for inference
|
|
|
|
|
flatbuffers::FlatBufferBuilder builder(1024);
|
|
|
|
|
auto offset = MNN::Net::Pack(builder, netT.get());
|
|
|
|
|
builder.Finish(offset);
|
|
|
|
|
int size = builder.GetSize();
|
|
|
|
|
auto ocontent = builder.GetBufferPointer();
|
|
|
|
|
|
|
|
|
|
// model buffer for creating mnn Interpreter
|
|
|
|
|
std::unique_ptr<uint8_t> modelForInference(new uint8_t[size]);
|
|
|
|
|
memcpy(modelForInference.get(), ocontent, size);
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<uint8_t> modelOriginal(new uint8_t[size]);
|
|
|
|
|
memcpy(modelOriginal.get(), ocontent, size);
|
|
|
|
|
|
|
|
|
|
netT.reset();
|
|
|
|
|
netT = MNN::UnPackNet(modelOriginal.get());
|
|
|
|
|
|
|
|
|
|
// quantize model's weight
|
2019-08-22 20:13:46 +08:00
|
|
|
DLOG(INFO) << "Calibrate the feature and quantize model...";
|
2019-07-11 13:56:52 +08:00
|
|
|
std::shared_ptr<Calibration> calibration(
|
2021-04-28 18:02:10 +08:00
|
|
|
new Calibration(netT.get(), modelForInference.get(), size, preTreatConfig, std::string(modelFile), std::string(dstFile)));
|
2019-07-11 13:56:52 +08:00
|
|
|
calibration->runQuantizeModel();
|
2021-02-03 10:04:41 +08:00
|
|
|
calibration->dumpTensorScales(dstFile);
|
2019-08-22 20:13:46 +08:00
|
|
|
DLOG(INFO) << "Quantize model done!";
|
2019-07-11 13:56:52 +08:00
|
|
|
|
2020-12-11 11:23:31 +08:00
|
|
|
return 0;
|
2019-07-11 13:56:52 +08:00
|
|
|
}
|