MNN/docs/pymnn/optim.md

1.7 KiB
Raw Permalink Blame History

optim

module optim

optim时优化器模块提供了一个优化器基类Optimizer,并提供了SGDADAM优化器实现;主要用于训练阶段迭代优化


optim Types


optim.Regularization_Method

优化器的正则化方法提供了L1和L2正则化方法

  • 类型:Enum
  • 枚举值:
    • L1
    • L2
    • L1L2

SGD(module, lr, momentum, weight_decay, regularization_method)

创建一个SGD优化器

参数:

  • module:_Module 模型实例
  • lr:float 学习率
  • momentum:float 动量默认为0.9
  • weight_decay:float 权重衰减默认为0.0
  • regularization_method:RegularizationMethod 正则化方法默认为L2正则化

返回SGD优化器实例

返回类型:Optimizer

示例:

model = Net()
sgd = optim.SGD(model, 0.001, 0.9, 0.0005, optim.Regularization_Method.L2)
# feed some date to the model, then get the loss
loss = ...
sgd.step(loss) # backward and update parameters in the model

ADAM(module, lr, momentum, momentum2, weight_decay, eps, regularization_method)

创建一个ADAM优化器

参数:

  • module:_Module 模型实例
  • lr:float 学习率
  • momentum:float 动量默认为0.9
  • momentum2:float 动量2默认为0.999
  • weight_decay:float 权重衰减默认为0.0
  • eps:float 正则化阈值默认为1e-8
  • regularization_method:RegularizationMethod 正则化方法默认为L2正则化

返回ADAM优化器实例

返回类型:Optimizer

示例:

model = Net()
sgd = optim.ADAM(model, 0.001)
# feed some date to the model, then get the loss
loss = ...
sgd.step(loss) # backward and update parameters in the model