mindspore.ops.ApplyMomentum

查看源文件
class mindspore.ops.ApplyMomentum(use_nesterov=False, use_locking=False, gradient_scale=1.0)[源代码]

使用动量算法的优化器。

更多详细信息,请参阅论文 On the importance of initialization and momentum in deep learning

输入的 variableaccumulationgradient 的输入遵循隐式类型转换规则,使数据类型一致。如果它们具有不同的数据类型,则低精度数据类型将转换为相对最高精度的数据类型。

有关公式和用法的更多详细信息,请参阅 mindspore.nn.Momentum

参数:
  • use_locking (bool) - 是否对参数更新加锁保护。默认值: False

  • use_nesterov (bool) - 是否使用nesterov动量。默认值: False

  • gradient_scale (float) - 梯度的缩放比例。默认值: 1.0

输入:
  • variable (Union[Parameter, Tensor]) - 要更新的权重。数据类型必须为float。

  • accumulation (Union[Parameter, Tensor]) - 按动量权重计算的累加梯度值,数据类型与 variable 相同。

  • learning_rate (Union[Number, Tensor]) - 学习率,必须是float或为float数据类型的Scalar的Tensor。

  • gradient (Tensor) - 梯度,数据类型与 variable 相同。

  • momentum (Union[Number, Tensor]) - 动量,必须是float或为float数据类型的Scalar的Tensor。

输出:

Tensor,更新后的参数。

异常:
  • TypeError - 如果 use_lockinguse_nesterov 不是bool,或 gradient_scale 不是float。

  • TypeError - 如果 varaccumgrad 不支持数据类型转换。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, nn, ops, Parameter
>>> class Net(nn.Cell):
...    def __init__(self):
...        super(Net, self).__init__()
...        self.apply_momentum = ops.ApplyMomentum()
...        self.variable = Parameter(Tensor(np.array([[0.6, 0.4],
...                                            [0.1, 0.5]]).astype(np.float32)), name="variable")
...        self.accumulate = Parameter(Tensor(np.array([[0.6, 0.5],
...                                            [0.2, 0.6]]).astype(np.float32)), name="accumulate")
...    def construct(self, lr, grad, moment):
...        out = self.apply_momentum(self.variable, self.accumulate, lr, grad, moment)
...        return out
>>> net = Net()
>>> lr = Tensor(0.1, mindspore.float32)
>>> moment = Tensor(0.9, mindspore.float32)
>>> grad = Tensor(np.array([[0.3, 0.7], [0.1, 0.8]]).astype(np.float32))
>>> output = net(lr, grad, moment)
>>> print(output)
[[0.51600003 0.285     ]
[0.072      0.366     ]]