mindspore.nn.LARS

class mindspore.nn.LARS(optimizer, epsilon=1e-05, coefficient=0.001, use_clip=False, lars_filter=lambda x: ...)

Implements the LARS algorithm.

LARS (Layer-wise Adaptive Rate Scaling) is an optimization technique for large-batch training that computes a separate local learning rate for each layer. For details, refer to the paper Large Batch Training of Convolutional Networks.

The update rules are as follows:

\[\begin{split}\begin{array}{ll}
\hline
&\textbf{Parameters}: \text{base learning rate } \gamma_{0}, \text{ momentum } m, \text{ weight decay } \lambda, \\
&\hspace{5mm}\text{LARS coefficient } \eta, \text{ number of steps } T \\
&\textbf{Init}: t=0,\ v=0, \text{ init weight } w_{0}^{l} \text{ for each layer } l \\
\hline
&\textbf{while } t<T \text{ for each layer } l \textbf{ do} \\
&\hspace{5mm}g_{t}^{l} \leftarrow \nabla L\left(w_{t}^{l}\right) \\
&\hspace{5mm}\gamma_{t} \leftarrow \gamma_{0} *\left(1-\frac{t}{T}\right)^{2} \\
&\hspace{5mm}\gamma^{l} \leftarrow \eta *\frac{\left\|w_{t}^{l}\right\|}{\left\|g_{t}^{l}\right\|+ \lambda\left\|w_{t}^{l}\right\|} \text{ (compute the local LR } \gamma^{l}\text{)} \\
&\hspace{5mm}v_{t+1}^{l} \leftarrow m v_{t}^{l}+\gamma_{t+1} * \gamma^{l} *\left(g_{t}^{l}+\lambda w_{t}^{l}\right) \\
&\hspace{5mm}w_{t+1}^{l} \leftarrow w_{t}^{l}-v_{t+1}^{l} \\
&\textbf{end while} \\
\hline
\end{array}\end{split}\]

\(w\) represents the network parameters, \(g\) represents gradients, \(t\) represents the current step, \(\lambda\) represents weight_decay in optimizer, \(\gamma\) represents learning_rate in optimizer, \(\eta\) represents coefficient.
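The update rules above can be sketched for a single layer in plain Python. This is only an illustration of the formulas, not the MindSpore implementation: the helper names `l2_norm`, `global_lr`, and `lars_step` are hypothetical, and adding `epsilon` to the denominator of the local learning rate is an assumption based on the description of the `epsilon` parameter.

```python
import math


def l2_norm(values):
    """Euclidean norm of a flat list of floats."""
    return math.sqrt(sum(x * x for x in values))


def global_lr(base_lr, step, total_steps):
    """Polynomial decay from the formula: gamma_t = gamma_0 * (1 - t/T)^2."""
    return base_lr * (1.0 - step / total_steps) ** 2


def lars_step(w, g, v, lr, momentum, weight_decay,
              coefficient=0.001, epsilon=1e-05):
    """One LARS update for one layer (hypothetical helper, not a MindSpore API)."""
    w_norm = l2_norm(w)
    g_norm = l2_norm(g)
    # Local LR: eta * ||w|| / (||g|| + lambda * ||w||).
    # Assumption: epsilon is added to the denominator for numerical stability.
    local_lr = coefficient * w_norm / (g_norm + weight_decay * w_norm + epsilon)
    # Velocity: v <- m * v + lr * local_lr * (g + lambda * w)
    new_v = [momentum * v_i + lr * local_lr * (g_i + weight_decay * w_i)
             for w_i, g_i, v_i in zip(w, g, v)]
    # Weights: w <- w - v
    new_w = [w_i - v_i for w_i, v_i in zip(w, new_v)]
    return new_w, new_v, local_lr


# Toy layer with ||w|| = 5 and ||g|| = 1; epsilon is set to 0 to match the formula.
w, g, v = [3.0, 4.0], [0.6, 0.8], [0.0, 0.0]
lr = global_lr(base_lr=0.1, step=0, total_steps=100)
new_w, new_v, local_lr = lars_step(w, g, v, lr, momentum=0.9,
                                   weight_decay=0.0, epsilon=0.0)
print(local_lr, new_v, new_w)
```

With these toy values the local learning rate is approximately `coefficient * 5 / 1 = 0.005`, so the effective step is much smaller than the global learning rate of 0.1.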

Parameters
  • optimizer (mindspore.nn.Optimizer) – The MindSpore optimizer to wrap. LARS modifies the gradients before they are passed to this optimizer.

  • epsilon (float) – Term added to the denominator to improve numerical stability. Default: 1e-05 .

  • coefficient (float) – Trust coefficient for calculating the local learning rate. Default: 0.001 .

  • use_clip (bool) – Whether to apply a clip operation when calculating the local learning rate. Default: False .

  • lars_filter (Function) – A function that determines which network parameters the LARS algorithm is applied to. Default: lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
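To see what the default lars_filter selects, the following sketch applies the same predicate to stand-in objects that only carry a `name` attribute. The parameter names (`conv1.weight`, `encoder.LayerNorm.gamma`, and so on) and the `custom_lars_filter` function are hypothetical and chosen purely for illustration; real parameters would be `mindspore.Parameter` instances.

```python
from types import SimpleNamespace


def default_lars_filter(param):
    """Same predicate as the documented default lambda."""
    return 'LayerNorm' not in param.name and 'bias' not in param.name


def custom_lars_filter(param):
    """Hypothetical custom filter: additionally skip parameters named 'fc...'."""
    return default_lars_filter(param) and not param.name.startswith('fc')


# Stand-ins for network parameters; only the `name` attribute is used.
names = ("conv1.weight", "conv1.bias", "encoder.LayerNorm.gamma", "fc1.weight")
params = [SimpleNamespace(name=n) for n in names]

selected_default = [p.name for p in params if default_lars_filter(p)]
selected_custom = [p.name for p in params if custom_lars_filter(p)]

print(selected_default)  # ['conv1.weight', 'fc1.weight']
print(selected_custom)   # ['conv1.weight']
```

A function with the same shape as `custom_lars_filter` could be passed as the `lars_filter` argument when only a subset of layers should use layer-wise scaling.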

Inputs:
  • gradients (tuple[Tensor]) - The gradients of the params in the optimizer. Each gradient has the same shape as the corresponding param.

Outputs:

Union[Tensor[bool], tuple[Parameter]], depending on the output of the wrapped optimizer.

Supported Platforms:

Ascend

Examples

>>> import mindspore as ms
>>> from mindspore import nn
>>>
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> opt = nn.Momentum(net.trainable_params(), 0.1, 0.9)
>>> opt_lars = nn.LARS(opt, epsilon=1e-08, coefficient=0.02)
>>> model = ms.train.Model(net, loss_fn=loss, optimizer=opt_lars, metrics=None)