MNN/tools/train/source/demo/dataLoaderDemo.cpp

118 lines
4.1 KiB
C++
Raw Normal View History

2019-12-27 22:16:57 +08:00
//
// dataLoaderDemo.cpp
// MNN
//
// Created by MNN on 2019/11/20.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include <iostream>
#include "DataLoader.hpp"
#include "DemoUnit.hpp"
#include "MNN_generated.h"
#include "MnistDataset.hpp"
2020-02-26 09:57:17 +08:00
#include "LambdaTransform.hpp"
#include "RandomSampler.hpp"
#include "Sampler.hpp"
#include "StackTransform.hpp"
#include "Transform.hpp"
#include "TransformDataset.hpp"
2019-12-27 22:16:57 +08:00
#ifdef MNN_USE_OPENCV
#include <opencv2/opencv.hpp> // use opencv to show pictures
using namespace cv;
#endif
using namespace std;
2020-02-26 09:57:17 +08:00
using namespace MNN;
using namespace MNN::Train;
2019-12-27 22:16:57 +08:00
/*
* this is an demo for how to use the DataLoader
*/
class DataLoaderDemo : public DemoUnit {
public:
// this function is an example to use the lambda transform
// here we use lambda transform to normalize data from 0~255 to 0~1
static Example func(Example example) {
// // an easier way to do this
2020-02-26 09:57:17 +08:00
auto cast = _Cast(example.first[0], halide_type_of<float>());
example.first[0] = _Multiply(cast, _Const(1.0f / 255.0f));
2019-12-27 22:16:57 +08:00
return example;
}
virtual int run(int argc, const char* argv[]) override {
if (argc != 2) {
cout << "usage: ./runTrainDemo.out DataLoaderDemo /path/to/unzipped/mnist/data/" << endl;
return 0;
}
std::string root = argv[1];
// train data loader
const size_t trainDatasetSize = 60000;
2020-02-26 09:57:17 +08:00
auto trainDatasetOrigin = MnistDataset::create(root, MnistDataset::Mode::TRAIN);
auto trainDataset = trainDatasetOrigin.mDataset;
2019-12-27 22:16:57 +08:00
// the lambda transform for one example, we also can do it in batch
auto trainTransform = std::make_shared<LambdaTransform>(func);
// // the stack transform, stack [1, 28, 28] to [n, 1, 28, 28]
// auto trainTransform = std::make_shared<StackTransform>();
const int trainBatchSize = 7;
const int trainNumWorkers = 4;
auto trainDataLoader =
2020-02-26 09:57:17 +08:00
std::shared_ptr<DataLoader>(DataLoader::makeDataLoader(trainDataset, {trainTransform}, trainBatchSize, true, trainNumWorkers));
2019-12-27 22:16:57 +08:00
// test data loader
const size_t testDatasetSize = 10000;
2020-02-26 09:57:17 +08:00
auto testDatasetOrigin = MnistDataset::create(root, MnistDataset::Mode::TEST);
auto testDataset = testDatasetOrigin.mDataset;
2019-12-27 22:16:57 +08:00
// the lambda transform for one example, we also can do it in batch
auto testTransform = std::make_shared<LambdaTransform>(func);
// // the stack transform, stack [1, 28, 28] to [n, 1, 28, 28]
// auto testTransform = std::make_shared<StackTransform>();
const int testBatchSize = 3;
const int testNumWorkers = 4;
auto testDataLoader =
2020-02-26 09:57:17 +08:00
std::shared_ptr<DataLoader>(DataLoader::makeDataLoader(testDataset, {testTransform}, testBatchSize, false, testNumWorkers));
2019-12-27 22:16:57 +08:00
const size_t iterations = testDatasetSize / testBatchSize;
for (int i = 0; i < iterations; i++) {
auto trainData = trainDataLoader->next();
auto testData = testDataLoader->next();
2020-02-26 09:57:17 +08:00
auto data = trainData[0].first[0]->readMap<float>();
auto label = trainData[0].second[0]->readMap<uint8_t>();
2019-12-27 22:16:57 +08:00
cout << "index: " << i << " train label: " << int(label[0]) << endl;
#ifdef MNN_USE_OPENCV
// only show the first picture in the batch
imshow("train", Mat(28, 28, CV_32FC1, (void*)data));
#endif
2020-02-26 09:57:17 +08:00
data = testData[0].first[0]->readMap<float>();
label = testData[0].second[0]->readMap<uint8_t>();
2019-12-27 22:16:57 +08:00
cout << "index: " << i << " test label: " << int(label[0]) << endl;
#ifdef MNN_USE_OPENCV
// only show the first picture in the batch
imshow("test", Mat(28, 28, CV_32FC1, (void*)data));
waitKey(-1);
#endif
2019-12-27 22:16:57 +08:00
}
// this will reset the sampler's internal state, not necessary here
trainDataLoader->reset();
// this will reset the sampler's internal state, necessary here, because the test dataset is exhausted
testDataLoader->reset();
return 0;
}
};
DemoUnitSetRegister(DataLoaderDemo, "DataLoaderDemo");