mindspore.mint.optim.SGD
- class mindspore.mint.optim.SGD(params, lr, momentum=0, dampening=0, weight_decay=0, nesterov=False, *, maximize=False)[源代码]
随机梯度下降算法。
如果nesterov为True:
如果nesterov为False:
需要注意的是,对于训练的第一步
。其中,p、v和u分别表示 parameters、accum 和 momentum。警告
这是一个实验性的优化器接口,后续可能修改或删除。需要和 LRScheduler 下的动态学习率接口配合使用。
- 参数:
params (Union[list(Parameter), list(dict)]) - 网络参数的列表或指定了参数组的列表。
lr (Union[bool, int, float, Tensor]) - 学习率。
momentum (Union[bool, int, float], 可选) - 动量值。默认值:
0
。weight_decay (Union[bool, int, float], 可选) - 权重衰减(L2 penalty),必须大于等于0。默认值:
0.
。dampening (Union[bool, int, float], 可选) - 动量的阻尼值。默认值:
0
。nesterov (bool, 可选) - 启用Nesterov动量。如果使用Nesterov,动量必须为正,阻尼必须等于0。默认值:
False
。
- 关键字参数:
maximize (bool, 可选) - 是否根据目标函数最大化网络参数。默认值:
False
。
- 输入:
gradients (tuple[Tensor]) - 网络权重的梯度。
- 异常:
ValueError - 学习率不是bool、int、float或Tensor。
ValueError - 学习率小于0。
ValueError - momentum 和 weight_decay 值小于0。
ValueError - momentum、 dampening 和 weight_decay 不是bool、int或float。
ValueError - nesterov 和 maximize 不是bool类型。
ValueError - nesterov 为True时, momentum 不为正或 dampening 不为0。
- 支持平台:
Ascend
样例:
>>> import mindspore >>> from mindspore import mint >>> from mindspore.mint import optim >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True) >>> optimizer = optim.SGD(net.trainable_params(), lr=0.1) >>> def forward_fn(data, label): ... logits = net(data) ... loss = loss_fn(logits, label) ... return loss, logits >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True) >>> def train_step(data, label): ... (loss, _), grads = grad_fn(data, label) ... optimizer(grads) ... return loss