// File: MNN/tools/train/source/grad/OpGrad.hpp
//
// OpGrad.hpp
// MNN
//
// Created by MNN on 2019/05/05.
// Copyright © 2018, Alibaba Group Holding Limited
//
#ifndef OpGrad_hpp
#define OpGrad_hpp
#include <MNN/expr/Expr.hpp>
#include <MNN/expr/ExprCreator.hpp>
#include <MNN/expr/Optimizer.hpp>
#include <map>
#include <set>
#include <string>
#include <vector>
#include "MNN_generated.h"
namespace MNN {
class MNN_PUBLIC OpGrad {
public:
enum Type { LINEAR, SEMI_LINEAR, NO_LINEAR };
OpGrad() = default;
virtual ~OpGrad() = default;
Type type() const {
return mType;
}
virtual std::vector<Express::VARP> onGrad(Express::EXPRP expr,
2019-12-27 22:16:57 +08:00
const std::vector<Express::VARP>& backwardOutput) = 0;
static OpGrad* get(int type);
static void insert(int type, OpGrad* creator);
2020-02-26 09:57:17 +08:00
static std::map<Express::VARP, Express::VARP> grad(Express::VARP loss, const std::set<Express::VARP>& parameters, const std::string& blockExpr = "");
2019-12-27 22:16:57 +08:00
protected:
Type mType = LINEAR;
};
} // namespace MNN
#endif