mindspore.nn.LARS
- class mindspore.nn.LARS(optimizer, epsilon=1e-05, coefficient=0.001, use_clip=False, lars_filter=<lambda x: "LayerNorm" not in x.name and "bias" not in x.name>)
Implements the LARS algorithm with the LARSUpdate operator.
LARS (Layer-wise Adaptive Rate Scaling) is an optimization technique for training with large batch sizes. Refer to the paper "Large Batch Training of Convolutional Networks".
The updating formulas are as follows:
\[\begin{split}\begin{array}{l}
\lambda = \frac{\theta * \| \omega \|}{\| g_{t} \| + \delta * \| \omega \|} \\
\lambda = \begin{cases} \min(\frac{\lambda}{\alpha}, 1) & \text{ if } clip = True \\ \lambda & \text{ otherwise } \end{cases} \\
g_{t+1} = \lambda * (g_{t} + \delta * \omega)
\end{array}\end{split}\]
\(\theta\) represents coefficient, \(\omega\) represents parameters, \(g\) represents gradients, \(t\) represents updating step, \(\delta\) represents weight_decay, \(\alpha\) represents learning_rate, \(clip\) represents use_clip.
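The formulas above can be traced by hand with a small pure-Python sketch. This is only an illustration of the arithmetic, not the LARSUpdate operator itself: the function name `lars_scale_gradient`, the helper `l2_norm`, and the fallback to a ratio of 1.0 when the denominator is zero are assumptions made for this example, and `epsilon` is left out because it does not appear in the formulas.

```python
import math


def l2_norm(values):
    """Euclidean norm of a flat list of floats."""
    return math.sqrt(sum(v * v for v in values))


def lars_scale_gradient(weights, grads, coefficient=0.001, weight_decay=0.0,
                        learning_rate=0.1, use_clip=False):
    """Hypothetical helper: apply the three formulas to one parameter."""
    w_norm = l2_norm(weights)
    g_norm = l2_norm(grads)
    denominator = g_norm + weight_decay * w_norm
    # Assumption (not part of the formulas): fall back to 1.0 to avoid dividing by zero.
    trust_ratio = coefficient * w_norm / denominator if denominator > 0 else 1.0
    if use_clip:
        # lambda = min(lambda / learning_rate, 1)
        trust_ratio = min(trust_ratio / learning_rate, 1.0)
    # g_{t+1} = lambda * (g_t + weight_decay * w)
    return [trust_ratio * (g + weight_decay * w) for g, w in zip(grads, weights)]


weights = [3.0, 4.0]   # ||w|| = 5
grads = [0.6, 0.8]     # ||g|| = 1
scaled = lars_scale_gradient(weights, grads, coefficient=0.02)
clipped = lars_scale_gradient(weights, grads, coefficient=0.02,
                              learning_rate=0.05, use_clip=True)
print(scaled, clipped)
```

With these inputs the trust ratio is about 0.02 * 5 / 1 = 0.1, so the gradient is scaled to roughly a tenth of its size. With `use_clip=True` and a learning rate of 0.05, the ratio divided by the learning rate is about 2, which is clipped to 1, leaving the gradient unchanged.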
- Parameters
optimizer (Optimizer) – The MindSpore optimizer to wrap; its gradients are modified by LARS.
epsilon (float) – Term added to the denominator to improve numerical stability. Default: 1e-05.
coefficient (float) – Trust coefficient for calculating the local learning rate. Default: 0.001.
use_clip (bool) – Whether to use clip operation for calculating the local learning rate. Default: False.
lars_filter (Function) – A function that determines whether to apply the LARS algorithm to a parameter. Default: lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name.
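To see which parameters the default lars_filter would select, the predicate can be run against a few names. The sketch below uses a hypothetical `FakeParam` stand-in instead of a real MindSpore Parameter, since the filter only reads the `name` attribute; the parameter names are invented for illustration.

```python
from collections import namedtuple

# Hypothetical stand-in for a MindSpore Parameter; only `name` is used by the filter.
FakeParam = namedtuple("FakeParam", ["name"])


def default_lars_filter(x):
    """Same predicate as the documented default lambda."""
    return "LayerNorm" not in x.name and "bias" not in x.name


# Illustrative parameter names, not taken from any real network.
params = [
    FakeParam("conv1.weight"),
    FakeParam("conv1.bias"),
    FakeParam("LayerNorm.gamma"),
]

selected = [p.name for p in params if default_lars_filter(p)]
print(selected)  # ['conv1.weight']
```

Only the convolution weight passes the filter; the bias and the LayerNorm parameter are excluded from LARS scaling under the default.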
- Inputs:
gradients (tuple[Tensor]) - The gradients of params in the optimizer. Their shapes are the same as those of the params in the optimizer.
- Outputs:
Union[Tensor[bool], tuple[Parameter]], depending on the output of optimizer.
- Supported Platforms:
Ascend
Examples
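>>> import mindspore.nn as nn
>>> from mindspore.train.model import Model
>>> # Net is a user-defined network (an nn.Cell subclass), not shown here.
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> opt = nn.Momentum(net.trainable_params(), 0.1, 0.9)
>>> opt_lars = nn.LARS(opt, epsilon=1e-08, coefficient=0.02)
>>> model = Model(net, loss_fn=loss, optimizer=opt_lars, metrics=None)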