mindspore.nn.LARS

class mindspore.nn.LARS(optimizer, epsilon=1e-05, coefficient=0.001, use_clip=False, lars_filter=lambda x: ...)[source]

Implements the LARS algorithm.

LARS is an optimization algorithm that applies a layer-wise adaptive learning rate for large-batch training. Refer to the paper LARGE BATCH TRAINING OF CONVOLUTIONAL NETWORKS.

The updating formulas are as follows:

$$
\begin{array}{l}
\text{Parameters: base learning rate } \gamma_0,\ \text{momentum } m,\ \text{weight decay } \lambda,\ \text{LARS coefficient } \eta,\ \text{number of steps } T \\
\text{Init: } t = 0,\ v = 0,\ \text{init weight } w_0^l \text{ for each layer } l \\
\textbf{while } t < T \textbf{ for each layer } l \textbf{ do} \\
\quad g_t^l \leftarrow \nabla L(w_t^l) \\
\quad \gamma_t \leftarrow \gamma_0 \left(1 - \frac{t}{T}\right)^2 \\
\quad \gamma^l \leftarrow \eta\, \frac{\left\| w_t^l \right\|}{\left\| g_t^l \right\| + \lambda \left\| w_t^l \right\|} \quad \text{(compute the local LR } \gamma^l\text{)} \\
\quad v_{t+1}^l \leftarrow m v_t^l + \gamma_{t+1} \gamma^l \left(g_t^l + \lambda w_t^l\right) \\
\quad w_{t+1}^l \leftarrow w_t^l - v_{t+1}^l \\
\textbf{end while}
\end{array}
$$

w represents the network parameters, g represents the gradients, t represents the current step, λ represents weight_decay in optimizer, γ₀ represents learning_rate in optimizer, m represents momentum in optimizer, and η represents coefficient.
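
A minimal sketch of one LARS step for a single layer, assuming plain NumPy arrays; the function and variable names below are illustrative, not part of the MindSpore API, and the eps term in the denominator mirrors the epsilon parameter described below:

>>> import numpy as np
>>>
>>> def lars_step(w, g, v, lr, m=0.9, weight_decay=1e-4, eta=0.001, eps=1e-5):
...     # Trust ratio between the weight norm and the regularized gradient norm.
...     w_norm = np.linalg.norm(w)
...     g_norm = np.linalg.norm(g)
...     local_lr = eta * w_norm / (g_norm + weight_decay * w_norm + eps)
...     # Momentum update on the LARS-scaled, weight-decayed gradient.
...     v = m * v + lr * local_lr * (g + weight_decay * w)
...     return w - v, v
>>>
>>> w, v = np.ones(4, np.float32), np.zeros(4, np.float32)
>>> g = np.full(4, 0.5, np.float32)
>>> w, v = lars_step(w, g, v, lr=0.1)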

Parameters
  • optimizer (mindspore.nn.Optimizer) – The MindSpore optimizer to wrap; LARS modifies its gradients before they are applied.

  • epsilon (float) – Term added to the denominator to improve numerical stability. Default: 1e-05.

  • coefficient (float) – Trust coefficient for calculating the local learning rate. Default: 0.001.

  • use_clip (bool) – Whether to use the clip operation when calculating the local learning rate. Default: False.

  • lars_filter (Function) – A function that determines which network parameters the LARS algorithm is applied to (a customization sketch follows this list). Default: lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
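
A hedged sketch of a custom lars_filter that also excludes BatchNorm parameters; the name substrings checked here ('bias', 'beta', 'gamma') are assumptions about how the target network names its parameters:

>>> # Hypothetical filter: exclude bias and normalization parameters from LARS.
>>> # The substrings are assumptions about the network's parameter naming.
>>> custom_filter = lambda x: all(s not in x.name for s in ('bias', 'beta', 'gamma'))
>>> # opt_lars = nn.LARS(opt, lars_filter=custom_filter)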

Inputs:
  • gradients (tuple[Tensor]) - The gradients of params in the optimizer, with the same shape as the params in the optimizer.

Supported Platforms:

Ascend

Examples

>>> import mindspore as ms
>>> from mindspore import nn
>>>
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.5.0/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> opt = nn.Momentum(net.trainable_params(), 0.1, 0.9)
>>> opt_lars = nn.LARS(opt, epsilon=1e-08, coefficient=0.02)
>>> model = ms.train.Model(net, loss_fn=loss, optimizer=opt_lars, metrics=None)
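
A hedged sketch of driving the wrapped optimizer manually instead of through Model; mindspore.value_and_grad and the callable-optimizer pattern are standard MindSpore 2.x usage, while the input shapes below are assumptions based on LeNet5:

>>> import numpy as np
>>> from mindspore import Tensor
>>>
>>> # Hypothetical batch: LeNet5 expects 32x32 single-channel images; labels are
>>> # one-hot because SoftmaxCrossEntropyWithLogits defaults to sparse=False.
>>> data = Tensor(np.random.rand(32, 1, 32, 32).astype(np.float32))
>>> label = Tensor(np.eye(10)[np.random.randint(0, 10, 32)].astype(np.float32))
>>>
>>> def forward_fn(data, label):
...     logits = net(data)
...     return loss(logits, label).mean()
>>> grad_fn = ms.value_and_grad(forward_fn, None, opt_lars.parameters)
>>> loss_value, grads = grad_fn(data, label)
>>> _ = opt_lars(grads)  # apply the LARS-scaled updates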