MNN/tools/train/source/demo/linearRegression.cpp

59 lines
1.7 KiB
C++
Raw Normal View History

2019-12-27 22:16:57 +08:00
//
// linearRegression.cpp
// MNN
//
// Created by MNN on 2019/11/22.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include <MNN/expr/ExprCreator.hpp>
#include <random>
#include "DemoUnit.hpp"
#include "SGD.hpp"
using namespace MNN::Express;
using namespace MNN::Train;
// File-scope random source; LinearRegress::run draws from it to generate the
// synthetic input samples for each training iteration.
std::random_device gRandom;
class LinearRegress : public DemoUnit {
public:
virtual int run(int argc, const char* argv[]) override {
2020-02-26 09:57:17 +08:00
VARP w = _TrainableParam(0.3f, {}, NHWC);
VARP b = _TrainableParam(0.1f, {}, NHWC);
2019-12-27 22:16:57 +08:00
std::shared_ptr<SGD> opt(new SGD);
opt->append({w, b});
opt->setLearningRate(0.1f);
const int number = 10;
const int limit = 300;
for (int i = 0; i < limit; ++i) {
VARP x = _Input({number}, NHWC);
// Fill x
auto xPtr = x->writeMap<float>();
for (int v = 0; v < number; ++v) {
xPtr[v] = (gRandom() % 10000) / 10000.0f;
}
VARP label = _Input({number}, NHWC);
// Fill label
auto ptr = label->writeMap<float>();
for (int v = 0; v < number; ++v) {
ptr[v] = xPtr[v] * 0.8f + 0.7f;
}
VARP y = x * w + b;
VARP diff = y - label;
VARP loss = (diff * diff).mean({});
if (i == limit - 1) {
MNN_PRINT("w = %f, b = %f, Target w = 0.8f, Target b = 0.7f\n", w->readMap<float>()[0],
b->readMap<float>()[0]);
Variable::save({y}, "linear.mnn");
} else {
opt->step(loss);
}
}
return 0;
}
};
// NOTE(review): presumably registers LinearRegress with the demo runner under
// the name "LinearRegress" — the macro is defined outside this file (likely
// DemoUnit.hpp); confirm there.
DemoUnitSetRegister(LinearRegress, "LinearRegress");