mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			58 lines
		
	
	
		
			1.7 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			58 lines
		
	
	
		
			1.7 KiB
		
	
	
	
		
			C++
		
	
	
	
//
 | 
						|
//  linearRegression.cpp
 | 
						|
//  MNN
 | 
						|
//
 | 
						|
//  Created by MNN on 2019/11/22.
 | 
						|
//  Copyright © 2018, Alibaba Group Holding Limited
 | 
						|
//
 | 
						|
 | 
						|
#include <MNN/expr/ExprCreator.hpp>
 | 
						|
#include <random>
 | 
						|
#include "DemoUnit.hpp"
 | 
						|
#include "SGD.hpp"
 | 
						|
using namespace MNN::Express;
 | 
						|
using namespace MNN::Train;
 | 
						|
// File-scope random source; LinearRegress::run calls it to generate the
// synthetic input samples for each training batch.
std::random_device gRandom;
 | 
						|
class LinearRegress : public DemoUnit {
 | 
						|
public:
 | 
						|
    virtual int run(int argc, const char* argv[]) override {
 | 
						|
        VARP w = _TrainableParam(0.3f, {}, NHWC);
 | 
						|
        VARP b = _TrainableParam(0.1f, {}, NHWC);
 | 
						|
        std::shared_ptr<Module> _m(Module::createEmpty({w, b}));
 | 
						|
        std::shared_ptr<SGD> opt(new SGD(_m));
 | 
						|
        opt->setLearningRate(0.1f);
 | 
						|
 | 
						|
        const int number = 10;
 | 
						|
        const int limit  = 300;
 | 
						|
        for (int i = 0; i < limit; ++i) {
 | 
						|
            VARP x = _Input({number}, NHWC);
 | 
						|
            // Fill x
 | 
						|
            auto xPtr = x->writeMap<float>();
 | 
						|
            for (int v = 0; v < number; ++v) {
 | 
						|
                xPtr[v] = (gRandom() % 10000) / 10000.0f;
 | 
						|
            }
 | 
						|
            VARP label = _Input({number}, NHWC);
 | 
						|
            // Fill label
 | 
						|
            auto ptr = label->writeMap<float>();
 | 
						|
            for (int v = 0; v < number; ++v) {
 | 
						|
                ptr[v] = xPtr[v] * 0.8f + 0.7f;
 | 
						|
            }
 | 
						|
            VARP y = x * w + b;
 | 
						|
 | 
						|
            VARP diff = y - label;
 | 
						|
            VARP loss = (diff * diff).mean({});
 | 
						|
 | 
						|
            if (i == limit - 1) {
 | 
						|
                MNN_PRINT("w = %f, b = %f, Target w = 0.8f, Target b = 0.7f\n", w->readMap<float>()[0],
 | 
						|
                          b->readMap<float>()[0]);
 | 
						|
                Variable::save({y}, "linear.mnn");
 | 
						|
            } else {
 | 
						|
                opt->step(loss);
 | 
						|
            }
 | 
						|
        }
 | 
						|
        return 0;
 | 
						|
    }
 | 
						|
};
 | 
						|
 | 
						|
// NOTE(review): presumably registers LinearRegress with the demo runner under
// the name "LinearRegress" — the macro is defined outside this file; confirm
// against DemoUnit.hpp.
DemoUnitSetRegister(LinearRegress, "LinearRegress");
 |