mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			59 lines
		
	
	
		
			1.7 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			59 lines
		
	
	
		
			1.7 KiB
		
	
	
	
		
			C++
		
	
	
	
| //
 | |
| //  linearRegression.cpp
 | |
| //  MNN
 | |
| //
 | |
| //  Created by MNN on 2019/11/22.
 | |
| //  Copyright © 2018, Alibaba Group Holding Limited
 | |
| //
 | |
| 
 | |
| #include <MNN/expr/ExprCreator.hpp>
 | |
| #include <random>
 | |
| #include "DemoUnit.hpp"
 | |
| #include "SGD.hpp"
 | |
| using namespace MNN::Express;
 | |
| using namespace MNN::Train;
 | |
| std::random_device gRandom; // File-scope random source; LinearRegress::run draws from it to fill each input batch.
 | |
| class LinearRegress : public DemoUnit {
 | |
| public:
 | |
|     virtual int run(int argc, const char* argv[]) override {
 | |
|         VARP w = _TrainableParam(0.3f, {}, NHWC);
 | |
|         VARP b = _TrainableParam(0.1f, {}, NHWC);
 | |
| 
 | |
|         std::shared_ptr<SGD> opt(new SGD);
 | |
|         opt->append({w, b});
 | |
|         opt->setLearningRate(0.1f);
 | |
| 
 | |
|         const int number = 10;
 | |
|         const int limit  = 300;
 | |
|         for (int i = 0; i < limit; ++i) {
 | |
|             VARP x = _Input({number}, NHWC);
 | |
|             // Fill x
 | |
|             auto xPtr = x->writeMap<float>();
 | |
|             for (int v = 0; v < number; ++v) {
 | |
|                 xPtr[v] = (gRandom() % 10000) / 10000.0f;
 | |
|             }
 | |
|             VARP label = _Input({number}, NHWC);
 | |
|             // Fill label
 | |
|             auto ptr = label->writeMap<float>();
 | |
|             for (int v = 0; v < number; ++v) {
 | |
|                 ptr[v] = xPtr[v] * 0.8f + 0.7f;
 | |
|             }
 | |
|             VARP y = x * w + b;
 | |
| 
 | |
|             VARP diff = y - label;
 | |
|             VARP loss = (diff * diff).mean({});
 | |
| 
 | |
|             if (i == limit - 1) {
 | |
|                 MNN_PRINT("w = %f, b = %f, Target w = 0.8f, Target b = 0.7f\n", w->readMap<float>()[0],
 | |
|                           b->readMap<float>()[0]);
 | |
|                 Variable::save({y}, "linear.mnn");
 | |
|             } else {
 | |
|                 opt->step(loss);
 | |
|             }
 | |
|         }
 | |
|         return 0;
 | |
|     }
 | |
| };
 | |
| 
 | |
| DemoUnitSetRegister(LinearRegress, "LinearRegress"); // NOTE(review): presumably registers this demo under the name "LinearRegress"; the macro is defined outside this file — confirm.
 |