mindspore.experimental.optim.Optimizer

查看源文件
class mindspore.experimental.optim.Optimizer(params, defaults)[源代码]

用于参数更新的优化器基类。

警告

这是一个实验性的优化器模块,需要和 LRScheduler 下的动态学习率接口配合使用。

参数:
  • params (Union[list(Parameter), list(dict)]) - 网络参数的列表或指定了参数组的列表。

  • defaults (dict) - 一个包含了优化器参数默认值的字典(当参数组未指定参数值时使用此默认值)。

异常:
  • TypeError - learning_rate 不是int、float、Tensor。

  • TypeError - parameters 的元素不是Parameter或字典。

  • TypeError - weight_decay 不是float或int。

  • ValueError - weight_decay 小于0。

  • ValueError - learning_rate 是一个Tensor,但是其shape大于1。

支持平台:

Ascend GPU CPU

样例:

>>> import numpy as np
>>> import mindspore
>>> from mindspore import nn, Tensor, Parameter
>>> from mindspore import ops
>>> from mindspore.experimental import optim
>>>
>>> class MySGD(optim.Optimizer):
...    def __init__(self, params, lr):
...        defaults = dict(lr=lr)
...        super(MySGD, self).__init__(params, defaults)
...
...    def construct(self, gradients):
...         for group_id, group in enumerate(self.param_groups):
...            id = self.group_start_id[group_id]
...            for i, param in enumerate(group["params"]):
...                next_param = param + gradients[id+i] * group["lr"]
...                ops.assign(param, next_param)
>>>
>>> net = nn.Dense(8, 2)
>>> data = Tensor(np.random.rand(20, 8).astype(np.float32))
>>> label = Tensor(np.random.rand(20, 2).astype(np.float32))
>>>
>>> optimizer = MySGD(net.trainable_params(), 0.01)
>>> optimizer.add_param_group({"params": Parameter([0.01, 0.02])})
>>>
>>> criterion = nn.MAELoss(reduction="mean")
>>>
>>> def forward_fn(data, label):
...    logits = net(data)
...    loss = criterion(logits, label)
...    return loss, logits
>>>
>>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
>>>
>>> def train_step(data, label):
...    (loss, _), grads = grad_fn(data, label)
...    optimizer(grads)
...    print(loss)
>>>
>>> train_step(data, label)
add_param_group(param_group)[源代码]

Optimizer.param_groups 属性添加一个参数组。

参数:
  • param_group (dict) - 指定了当前网络参数组的特定的优化器配置。