mindspore.experimental.optim.lr_scheduler.StepLR

class mindspore.experimental.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1)

Decays the learning rate of each parameter group by multiplying it by gamma once every step_size epochs. Note that this decay can happen simultaneously with changes made to the learning rate from outside this scheduler.
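When nothing outside the scheduler touches the learning rate, the rule above reduces to a closed form: after `epoch` epochs the rate is `base_lr * gamma ** (epoch // step_size)`. The pure-Python sketch below (no MindSpore required; the helper name `step_lr` is hypothetical, not part of the MindSpore API) computes that value so the schedule can be checked by hand.

```python
def step_lr(base_lr: float, step_size: int, gamma: float, epoch: int) -> float:
    """Closed-form StepLR rate at a given epoch index.

    Hypothetical helper: assumes no external changes to the learning rate.
    """
    if step_size <= 0:
        raise ValueError("step_size must be a positive integer")
    # The rate is multiplied by gamma once per completed block of step_size epochs.
    return base_lr * gamma ** (epoch // step_size)


# Same settings as the Examples section: base lr 0.05, step_size=2, gamma=0.1.
schedule = [step_lr(0.05, 2, 0.1, e) for e in range(6)]
print(schedule)
```

Epochs 0 and 1 use 0.05, epochs 2 and 3 use 0.005, and epochs 4 and 5 use 0.0005 (up to floating-point rounding), matching the comments in the example.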

Warning

This is an experimental lr scheduler module that is subject to change. This module must be used with the optimizers in Experimental Optimizer.

Parameters
  • optimizer (mindspore.experimental.optim.Optimizer) – Wrapped optimizer.

  • step_size (int) – Period of learning rate decay, in epochs.

  • gamma (float, optional) – Multiplicative factor of learning rate decay. Default: 0.1.

  • last_epoch (int, optional) – The index of the last epoch. Default: -1.
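The parameters interact through `last_epoch`: each call to `step()` advances it by one, and a decay is applied whenever it reaches a nonzero multiple of `step_size`. The toy class below is a minimal sketch, assuming MindSpore follows the same chainable rule as PyTorch's `StepLR` (multiply the *current* rate by `gamma`, rather than recomputing it from the initial rate) and the same initial step at construction. `ToyStepLR` is hypothetical and operates on a plain list standing in for the optimizer's parameter groups; it also shows why an external change to the rate is preserved by later decays.

```python
class ToyStepLR:
    """Hypothetical stand-in for StepLR acting on a list of group learning rates.

    Assumed rule (mirrors PyTorch's StepLR): on each step, increment last_epoch
    and multiply the current rate by gamma only when the new last_epoch is a
    nonzero multiple of step_size.
    """

    def __init__(self, group_lrs, step_size, gamma=0.1):
        self.group_lrs = group_lrs  # shared, mutable list standing in for param groups
        self.step_size = step_size
        self.gamma = gamma
        self.last_epoch = -1  # default: no epochs completed yet
        self.step()  # assumed initial step at construction: last_epoch -1 -> 0

    def step(self):
        self.last_epoch += 1
        if self.last_epoch != 0 and self.last_epoch % self.step_size == 0:
            for i, lr in enumerate(self.group_lrs):
                # Chainable: decay whatever the rate currently is.
                self.group_lrs[i] = lr * self.gamma
        self._last_lr = list(self.group_lrs)

    def get_last_lr(self):
        return self._last_lr


lrs = [0.05]
sched = ToyStepLR(lrs, step_size=2, gamma=0.1)
sched.step()   # last_epoch = 1: no decay
lrs[0] = 0.02  # an external change made between epochs
sched.step()   # last_epoch = 2: decays the current value (0.02), not 0.05
print(sched.get_last_lr())
```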

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore
>>> from mindspore import nn
>>> from mindspore.experimental import optim
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
>>> optimizer = optim.Adam(net.trainable_params(), lr=0.05)
>>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.05     if epoch < 2
>>> # lr = 0.005    if 2 <= epoch < 4
>>> # lr = 0.0005   if 4 <= epoch < 6
>>> scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
>>> def forward_fn(data, label):
...     logits = net(data)
...     loss = loss_fn(logits, label)
...     return loss, logits
>>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
>>> def train_step(data, label):
...     (loss, _), grads = grad_fn(data, label)
...     optimizer(grads)
...     return loss
>>> for epoch in range(6):
...     # Create the dataset taking MNIST as an example. Refer to
...     # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/mnist.py
...     for data, label in create_dataset():
...         train_step(data, label)
...     scheduler.step()
...     current_lr = scheduler.get_last_lr()
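Because `scheduler.step()` is called at the end of each epoch, the value read into `current_lr` is the rate for the *next* epoch. The pure-Python trace below reproduces the values the example's loop is expected to report, assuming (as in PyTorch) that constructing the scheduler sets `last_epoch` to 0 and that each `step()` advances it by one. It does not call MindSpore; the variable names are illustrative only.

```python
# Pure-Python trace of the example's `current_lr`, under the assumptions above.
base_lr, step_size, gamma = 0.05, 2, 0.1

last_epoch = 0  # assumed state right after constructing the scheduler
trace = []
for epoch in range(6):
    # Training for `epoch` runs here at base_lr * gamma ** (last_epoch // step_size).
    last_epoch += 1  # corresponds to scheduler.step()
    current_lr = base_lr * gamma ** (last_epoch // step_size)
    trace.append((epoch, current_lr))

for epoch, lr in trace:
    print(f"after epoch {epoch}: lr for next epoch = {lr:.5g}")
```

Under these assumptions the rate drops right after epochs 1, 3 and 5, so the value read after the final epoch (5e-05) would only apply if training continued.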