2020-01-15 13:33:47 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								//
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								//  MnistUtils.cpp
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								//  MNN
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								//
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								//  Created by MNN on 2020/01/08.
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								//  Copyright © 2018, Alibaba Group Holding Limited
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								//
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include "MnistUtils.hpp"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include <MNN/expr/Executor.hpp>
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include <MNN/expr/Optimizer.hpp>
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include <cmath>
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include <iostream>
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include <vector>
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include "DataLoader.hpp"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include "DemoUnit.hpp"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include "MnistDataset.hpp"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include "NN.hpp"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include "SGD.hpp"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#define MNN_OPEN_TIME_TRACE
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include <MNN/AutoTime.hpp>
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include "ADAM.hpp"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include "LearningRateScheduler.hpp"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include "Loss.hpp"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include "RandomGenerator.hpp"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								#include "Transformer.hpp"
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								using namespace MNN;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								using namespace MNN::Express;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								using namespace MNN::Train;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								void MnistUtils::train(std::shared_ptr<Module> model, std::string root) {
							 | 
						
					
						
							
								
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        // Load snapshot
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        auto para = Variable::load("mnist.snapshot.mnn");
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        model->loadParameters(para);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    }
							 | 
						
					
						
							
								
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    auto exe = Executor::getGlobalExecutor();
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    BackendConfig config;
							 | 
						
					
						
							
								
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    exe->setGlobalExecutorConfig(MNN_FORWARD_CPU, config, 4);
							 | 
						
					
						
							
								
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    std::shared_ptr<SGD> sgd(new SGD);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    sgd->append(model->parameters());
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    sgd->setMomentum(0.9f);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    // sgd->setMomentum2(0.99f);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    sgd->setWeightDecay(0.0005f);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    auto dataset = MnistDataset::create(root, MnistDataset::Mode::TRAIN);
							 | 
						
					
						
							
								
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    // the stack transform, stack [1, 28, 28] to [n, 1, 28, 28]
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    const size_t batchSize  = 64;
							 | 
						
					
						
							
								
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    const size_t numWorkers = 0;
							 | 
						
					
						
							
								
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    bool shuffle            = true;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    auto dataLoader = std::shared_ptr<DataLoader>(dataset.createLoader(batchSize, true, shuffle, numWorkers));
							 | 
						
					
						
							
								
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    size_t iterations = dataLoader->iterNumber();
							 | 
						
					
						
							
								
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    auto testDataset            = MnistDataset::create(root, MnistDataset::Mode::TEST);
							 | 
						
					
						
							
								
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    const size_t testBatchSize  = 20;
							 | 
						
					
						
							
								
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    const size_t testNumWorkers = 0;
							 | 
						
					
						
							
								
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    shuffle                     = false;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    auto testDataLoader = std::shared_ptr<DataLoader>(testDataset.createLoader(testBatchSize, true, shuffle, testNumWorkers));
							 | 
						
					
						
							
								
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								    size_t testIterations = testDataLoader->iterNumber();
							 | 
						
					
						
							
								
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    for (int epoch = 0; epoch < 50; ++epoch) {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        model->clearCache();
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        exe->gc(Executor::FULL);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        exe->resetProfile();
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            AUTOTIME;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            dataLoader->reset();
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            model->setIsTraining(true);
							 | 
						
					
						
							
								
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            Timer _100Time;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            int lastIndex = 0;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            int moveBatchSize = 0;
							 | 
						
					
						
							
								
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            for (int i = 0; i < iterations; i++) {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                // AUTOTIME;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                auto trainData  = dataLoader->next();
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                auto example    = trainData[0];
							 | 
						
					
						
							
								
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                auto cast       = _Cast<float>(example.first[0]);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                example.first[0] = cast * _Const(1.0f / 255.0f);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                moveBatchSize += example.first[0]->getInfo()->dim[0];
							 | 
						
					
						
							
								
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                // Compute One-Hot
							 | 
						
					
						
							
								
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                auto newTarget = _OneHot(_Cast<int32_t>(example.second[0]), _Scalar<int>(10), _Scalar<float>(1.0f),
							 | 
						
					
						
							
								
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                                         _Scalar<float>(0.0f));
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                auto predict = model->forward(example.first[0]);
							 | 
						
					
						
							
								
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                auto loss    = _CrossEntropy(predict, newTarget);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                float rate   = LrScheduler::inv(0.01, epoch * iterations + i, 0.0001, 0.75);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                sgd->setLearningRate(rate);
							 | 
						
					
						
							
								
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                if (moveBatchSize % (10 * batchSize) == 0 || i == iterations - 1) {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    std::cout << "epoch: " << (epoch);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    std::cout << "  " << moveBatchSize << " / " << dataLoader->size();
							 | 
						
					
						
							
								
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    std::cout << " loss: " << loss->readMap<float>()[0];
							 | 
						
					
						
							
								
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								                    std::cout << " lr: " << rate;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    std::cout << " time: " << (float)_100Time.durationInUs() / 1000.0f << " ms / " << (i - lastIndex) <<  " iter"  << std::endl;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    std::cout.flush();
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    _100Time.reset();
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                    lastIndex = i;
							 | 
						
					
						
							
								
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                sgd->step(loss);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        }
							 | 
						
					
						
							
								
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        Variable::save(model->parameters(), "mnist.snapshot.mnn");
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            model->setIsTraining(false);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            auto forwardInput = _Input({1, 1, 28, 28}, NC4HW4);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            forwardInput->setName("data");
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            auto predict = model->forward(forwardInput);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            predict->setName("prob");
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            Transformer::turnModelToInfer()->onExecute({predict});
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            Variable::save({predict}, "temp.mnist.mnn");
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        }
							 | 
						
					
						
							
								
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        int correct = 0;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        testDataLoader->reset();
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        model->setIsTraining(false);
							 | 
						
					
						
							
								
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        int moveBatchSize = 0;
							 | 
						
					
						
							
								
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        for (int i = 0; i < testIterations; i++) {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            auto data       = testDataLoader->next();
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            auto example    = data[0];
							 | 
						
					
						
							
								
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            moveBatchSize += example.first[0]->getInfo()->dim[0];
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            if ((i + 1) % 100 == 0) {
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								                std::cout << "test: " << moveBatchSize << " / " << testDataLoader->size() << std::endl;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            auto cast       = _Cast<float>(example.first[0]);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            example.first[0] = cast * _Const(1.0f / 255.0f);
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            auto predict    = model->forward(example.first[0]);
							 | 
						
					
						
							
								
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            predict         = _ArgMax(predict, 1);
							 | 
						
					
						
							
								
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								            auto accu       = _Cast<int32_t>(_Equal(predict, _Cast<int32_t>(example.second[0]))).sum({});
							 | 
						
					
						
							
								
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            correct += accu->readMap<int32_t>()[0];
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        }
							 | 
						
					
						
							
								
									
										
										
										
											2020-02-26 09:57:17 +08:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        auto accu = (float)correct / (float)testDataLoader->size();
							 | 
						
					
						
							
								
									
										
										
										
											2020-01-15 13:33:47 +08:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        std::cout << "epoch: " << epoch << "  accuracy: " << accu << std::endl;
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        exe->dumpProfile();
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								}
							 |