mirror of https://github.com/alibaba/MNN.git
118 lines
4.1 KiB
C++
118 lines
4.1 KiB
C++
//
|
|
// dataLoaderDemo.cpp
|
|
// MNN
|
|
//
|
|
// Created by MNN on 2019/11/20.
|
|
// Copyright © 2018, Alibaba Group Holding Limited
|
|
//
|
|
|
|
#include <iostream>
|
|
#include "DataLoader.hpp"
|
|
#include "DemoUnit.hpp"
|
|
#include "MNN_generated.h"
|
|
#include "MnistDataset.hpp"
|
|
#include "LambdaTransform.hpp"
|
|
#include "RandomSampler.hpp"
|
|
#include "Sampler.hpp"
|
|
#include "StackTransform.hpp"
|
|
#include "Transform.hpp"
|
|
#include "TransformDataset.hpp"
|
|
|
|
#ifdef MNN_USE_OPENCV
|
|
#include <opencv2/opencv.hpp> // use opencv to show pictures
|
|
using namespace cv;
|
|
#endif
|
|
|
|
using namespace std;
|
|
using namespace MNN;
|
|
using namespace MNN::Train;
|
|
/*
|
|
* This is a demo of how to use the DataLoader.
|
|
*/
|
|
|
|
class DataLoaderDemo : public DemoUnit {
|
|
public:
|
|
// this function is an example to use the lambda transform
|
|
// here we use lambda transform to normalize data from 0~255 to 0~1
|
|
static Example func(Example example) {
|
|
// // an easier way to do this
|
|
auto cast = _Cast(example.first[0], halide_type_of<float>());
|
|
example.first[0] = _Multiply(cast, _Const(1.0f / 255.0f));
|
|
return example;
|
|
}
|
|
virtual int run(int argc, const char* argv[]) override {
|
|
if (argc != 2) {
|
|
cout << "usage: ./runTrainDemo.out DataLoaderDemo /path/to/unzipped/mnist/data/" << endl;
|
|
return 0;
|
|
}
|
|
|
|
std::string root = argv[1];
|
|
|
|
// train data loader
|
|
const size_t trainDatasetSize = 60000;
|
|
auto trainDatasetOrigin = MnistDataset::create(root, MnistDataset::Mode::TRAIN);
|
|
auto trainDataset = trainDatasetOrigin.mDataset;
|
|
|
|
// the lambda transform for one example, we also can do it in batch
|
|
auto trainTransform = std::make_shared<LambdaTransform>(func);
|
|
|
|
// // the stack transform, stack [1, 28, 28] to [n, 1, 28, 28]
|
|
// auto trainTransform = std::make_shared<StackTransform>();
|
|
|
|
const int trainBatchSize = 7;
|
|
const int trainNumWorkers = 4;
|
|
|
|
auto trainDataLoader =
|
|
std::shared_ptr<DataLoader>(DataLoader::makeDataLoader(trainDataset, {trainTransform}, trainBatchSize, true, trainNumWorkers));
|
|
|
|
// test data loader
|
|
const size_t testDatasetSize = 10000;
|
|
auto testDatasetOrigin = MnistDataset::create(root, MnistDataset::Mode::TEST);
|
|
auto testDataset = testDatasetOrigin.mDataset;
|
|
|
|
// the lambda transform for one example, we also can do it in batch
|
|
auto testTransform = std::make_shared<LambdaTransform>(func);
|
|
|
|
// // the stack transform, stack [1, 28, 28] to [n, 1, 28, 28]
|
|
// auto testTransform = std::make_shared<StackTransform>();
|
|
|
|
const int testBatchSize = 3;
|
|
const int testNumWorkers = 4;
|
|
|
|
auto testDataLoader =
|
|
std::shared_ptr<DataLoader>(DataLoader::makeDataLoader(testDataset, {testTransform}, testBatchSize, false, testNumWorkers));
|
|
|
|
const size_t iterations = testDatasetSize / testBatchSize;
|
|
|
|
for (int i = 0; i < iterations; i++) {
|
|
auto trainData = trainDataLoader->next();
|
|
auto testData = testDataLoader->next();
|
|
|
|
auto data = trainData[0].first[0]->readMap<float>();
|
|
auto label = trainData[0].second[0]->readMap<uint8_t>();
|
|
|
|
cout << "index: " << i << " train label: " << int(label[0]) << endl;
|
|
#ifdef MNN_USE_OPENCV
|
|
// only show the first picture in the batch
|
|
imshow("train", Mat(28, 28, CV_32FC1, (void*)data));
|
|
#endif
|
|
data = testData[0].first[0]->readMap<float>();
|
|
label = testData[0].second[0]->readMap<uint8_t>();
|
|
|
|
cout << "index: " << i << " test label: " << int(label[0]) << endl;
|
|
#ifdef MNN_USE_OPENCV
|
|
// only show the first picture in the batch
|
|
imshow("test", Mat(28, 28, CV_32FC1, (void*)data));
|
|
waitKey(-1);
|
|
#endif
|
|
}
|
|
// this will reset the sampler's internal state, not necessary here
|
|
trainDataLoader->reset();
|
|
|
|
// this will reset the sampler's internal state, necessary here, because the test dataset is exhausted
|
|
testDataLoader->reset();
|
|
return 0;
|
|
}
|
|
};
|
|
// Registers DataLoaderDemo under the name "DataLoaderDemo".
// NOTE(review): presumably this is the name the demo runner matches against its first
// command-line argument (see the usage text printed by run) -- confirm in DemoUnit.hpp.
DemoUnitSetRegister(DataLoaderDemo, "DataLoaderDemo");
|