//
//  distillTrainQuant.cpp
//  MNN
//
//  Created by MNN on 2020/02/19.
//  Copyright © 2018, Alibaba Group Holding Limited
//
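
// This demo quantizes a MobileNetV2 MNN model with quantization-aware training
// driven by distillation: the original float graph acts as the teacher, and a
// copy whose layers are turned into train-quant ops acts as the student. Each
// epoch trains the student against the teacher's logits blended with one-hot
// labels, saves an inference snapshot to "temp.quan.mnn", and evaluates it on
// the test set.
//
// Example invocation (paths are placeholders; see the usage string below):
//   ./runTrainDemo.out DistillTrainQuant mobilenet_v2.mnn \
//       path/to/train/images/ train.txt path/to/test/images/ test.txt 8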

#include <MNN/expr/Executor.hpp>
#include <cmath>
#include <sstream>
#include <fstream>
#include <iostream>
#include <vector>
#include <string>
#include "DemoUnit.hpp"
#include "NN.hpp"
#include "SGD.hpp"
#include "PipelineModule.hpp"
#define MNN_OPEN_TIME_TRACE
#include <MNN/AutoTime.hpp>
#include <functional>
#include "RandomGenerator.hpp"
#include "LearningRateScheduler.hpp"
#include "Loss.hpp"
#include "Transformer.hpp"
#include "DataLoader.hpp"
#include "ImageDataset.hpp"

using namespace MNN;
using namespace MNN::Express;
using namespace MNN::Train;
using namespace MNN::CV;

std::string gTrainImagePath;
std::string gTrainTxt;
std::string gTestImagePath;
std::string gTestTxt;
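
// Runs the quantized model over the test set and prints top-1 accuracy.
// Labels from the dataset are shifted by one, presumably because the
// 1001-class MobileNetV2 graph reserves index 0 for a background class.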
void _test(std::shared_ptr<Module> optmized, const ImageDataset::ImageConfig* config) {
    bool readAllImagesToMemory = false;
    DatasetPtr dataset = ImageDataset::create(gTestImagePath, gTestTxt, config, readAllImagesToMemory);

    const int batchSize  = 10;
    const int numWorkers = 0;
    std::shared_ptr<DataLoader> dataLoader(dataset.createLoader(batchSize, true, false, numWorkers));

    const int iterations = dataLoader->iterNumber();

    // const int usedSize = 1000;
    // const int iterations = usedSize / batchSize;

    int correct = 0;
    dataLoader->reset();
    optmized->setIsTraining(false);

    AUTOTIME;
    for (int i = 0; i < iterations; i++) {
        if ((i + 1) % 10 == 0) {
            std::cout << "test iteration: " << (i + 1) << std::endl;
        }
        auto data    = dataLoader->next();
        auto example = data[0];
        auto predict = optmized->forward(_Convert(example.first[0], NC4HW4));
        predict      = _Softmax(predict);
        predict      = _ArgMax(predict, 1); // (N, numClasses) --> (N)
        const int addToLabel = 1;
        auto label = example.second[0] + _Scalar<int32_t>(addToLabel);
        auto accu  = _Cast<int32_t>(_Equal(predict, label).sum({}));

        correct += accu->readMap<int32_t>()[0];
    }
    auto accu = (float)correct / dataLoader->size();
    // auto accu = (float)correct / usedSize;
    std::cout << "accuracy: " << accu << std::endl;
}
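
// Distillation training loop: "origin" is the float teacher and "optmized" is
// the train-quant student. After every epoch the student is frozen into an
// inference graph, saved to "temp.quan.mnn", and evaluated with _test().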
void _train(std::shared_ptr<Module> origin, std::shared_ptr<Module> optmized, std::string inputName, std::string outputName) {
    std::shared_ptr<SGD> sgd(new SGD);
    sgd->append(optmized->parameters());
    sgd->setMomentum(0.9f);
    sgd->setWeightDecay(0.00004f);

    auto converImagesToFormat = CV::RGB;
    int resizeHeight          = 224;
    int resizeWidth           = 224;
    std::vector<float> means  = {127.5, 127.5, 127.5};
    std::vector<float> scales = {1 / 127.5, 1 / 127.5, 1 / 127.5};
    std::vector<float> cropFraction = {0.875, 0.875}; // center crop fraction for height and width
    bool centerOrRandomCrop   = false; // true for random crop
    std::shared_ptr<ImageDataset::ImageConfig> datasetConfig(ImageDataset::ImageConfig::create(
        converImagesToFormat, resizeHeight, resizeWidth, scales, means, cropFraction, centerOrRandomCrop));
    bool readAllImagesToMemory = false;
    DatasetPtr dataset = ImageDataset::create(gTrainImagePath, gTrainTxt, datasetConfig.get(), readAllImagesToMemory);

    const int batchSize  = 32;
    const int numWorkers = 4;
    auto dataLoader = dataset.createLoader(batchSize, true, true, numWorkers);

    const int iterations = dataLoader->iterNumber();
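
    // For each batch: run the same NC4HW4 input through the float teacher and
    // the train-quant student, build a 1001-way one-hot target from the label,
    // and update the student with the distillation loss below.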
    for (int epoch = 0; epoch < 5; ++epoch) {
        AUTOTIME;
        dataLoader->reset();
        optmized->setIsTraining(true);
        origin->setIsTraining(false);
        Timer _100Time;
        int lastIndex     = 0;
        int moveBatchSize = 0;
        for (int i = 0; i < iterations; i++) {
            // AUTOTIME;
            auto trainData = dataLoader->next();
            auto example   = trainData[0].first[0];
            moveBatchSize += example->getInfo()->dim[0];
            auto nc4hw4example = _Convert(example, NC4HW4);
            auto teacherLogits = origin->forward(nc4hw4example);
            auto studentLogits = optmized->forward(nc4hw4example);

            // Compute the one-hot target; the label is shifted by one to match
            // the 1001-class output of the MobileNetV2 graph.
            auto labels = trainData[0].second[0];
            const int addToLabel = 1;
            auto newTarget = _OneHot(_Cast<int32_t>(_Squeeze(labels + _Scalar<int32_t>(addToLabel), {})),
                                     _Scalar<int>(1001), _Scalar<float>(1.0f), _Scalar<float>(0.0f));

            // Distillation loss blending the teacher's softened logits with the
            // one-hot target; the trailing constants appear to be the softmax
            // temperature (20) and the mixing weight (0.9).
            VARP loss = _DistillLoss(studentLogits, teacherLogits, newTarget, 20, 0.9);

            // float rate = LrScheduler::inv(basicRate, epoch * iterations + i, 0.0001, 0.75);
            float rate = 1e-5;
            sgd->setLearningRate(rate);
            if (moveBatchSize % (10 * batchSize) == 0 || i == iterations - 1) {
                std::cout << "epoch: " << (epoch);
                std::cout << " " << moveBatchSize << " / " << dataLoader->size();
                std::cout << " loss: " << loss->readMap<float>()[0];
                std::cout << " lr: " << rate;
                std::cout << " time: " << (float)_100Time.durationInUs() / 1000.0f << " ms / " << (i - lastIndex) << " iter" << std::endl;
                std::cout.flush();
                _100Time.reset();
                lastIndex = i;
            }
            sgd->step(loss);
        }
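
        // End of this epoch's gradient updates: switch the student to inference
        // mode, turn the training graph into an inference graph, and save a
        // quantized snapshot before evaluating it.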
        {
            AUTOTIME;
            dataLoader->reset();
            optmized->setIsTraining(false);
            {
                auto forwardInput = _Input({1, 3, 224, 224}, NC4HW4);
                forwardInput->setName(inputName);
                auto predict = optmized->forward(forwardInput);
                auto output  = _Softmax(predict);
                output->setName(outputName);
                Transformer::turnModelToInfer()->onExecute({output});
                Variable::save({output}, "temp.quan.mnn");
            }
        }

        _test(optmized, datasetConfig.get());
    }
}
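
// Demo entry point registered as "DistillTrainQuant". It loads the float
// MobileNetV2 model given on the command line, cuts the graph at the logits
// node before the softmax, builds a train-quant copy with the requested bit
// width (default 8), and runs the distillation training above.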
class DistillTrainQuant : public DemoUnit {
public:
    virtual int run(int argc, const char* argv[]) override {
        if (argc < 6) {
            MNN_PRINT("usage: ./runTrainDemo.out DistillTrainQuant /path/to/mobilenetV2Model path/to/train/images/ path/to/train/image/txt path/to/test/images/ path/to/test/image/txt [bits]\n");
            return 0;
        }

        gTrainImagePath = argv[2];
        gTrainTxt       = argv[3];
        gTestImagePath  = argv[4];
        gTestTxt        = argv[5];

        auto varMap = Variable::loadMap(argv[1]);
        if (varMap.empty()) {
            MNN_ERROR("Can not load model %s\n", argv[1]);
            return 0;
        }
        int bits = 8;
        if (argc > 6) {
            std::istringstream is(argv[6]);
            is >> bits;
        }
        if (bits < 2 || bits > 8) {
            MNN_ERROR("bits must be in [2, 8], falling back to the default of 8\n");
            bits = 8;
        }
        FUNC_PRINT(bits);

        auto inputOutputs = Variable::getInputAndOutput(varMap);
        auto inputs       = Variable::mapToSequence(inputOutputs.first);
        MNN_ASSERT(inputs.size() == 1);
        auto input = inputs[0];
        std::string inputName = input->name();
        auto inputInfo = input->getInfo();
        MNN_ASSERT(nullptr != inputInfo && inputInfo->order == NC4HW4);

        auto outputs = Variable::mapToSequence(inputOutputs.second);
        std::string originOutputName = outputs[0]->name();

        // Cut the graph at the logits node before the softmax, so the teacher
        // and student both expose raw logits for the distillation loss.
        std::string nodeBeforeSoftmax = "MobilenetV2/Predictions/Reshape";
        auto lastVar = varMap[nodeBeforeSoftmax];
        std::map<std::string, VARP> outputVarPair;
        outputVarPair[nodeBeforeSoftmax] = lastVar;

        auto logitsOutput = Variable::mapToSequence(outputVarPair);
        {
            auto exe = Executor::getGlobalExecutor();
            BackendConfig config;
            exe->setGlobalExecutorConfig(MNN_FORWARD_CPU, config, 4);
        }
        std::shared_ptr<Module> model(PipelineModule::extract(inputs, logitsOutput, true));
        // Turn the extracted graph's convolutions into train-quant (fake-quantized)
        // layers with the requested bit width.
        PipelineModule::turnQuantize(model.get(), bits);
        std::shared_ptr<Module> originModel(PipelineModule::extract(inputs, logitsOutput, false));
        _train(originModel, model, inputName, originOutputName);
        return 0;
    }
};

DemoUnitSetRegister(DistillTrainQuant, "DistillTrainQuant");