mindspore.experimental.optim.lr_scheduler.StepLR
- class mindspore.experimental.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1)
Decays the learning rate of each parameter group by a factor of gamma every step_size epochs. Note that this decay can happen simultaneously with changes made to the learning rate from outside this scheduler.
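The decay rule can be sketched in plain Python. This is a minimal illustration, not MindSpore's implementation. The helper names `closed_form_lr` and `chained_lr` are hypothetical. The sketch assumes the PyTorch-style convention that epochs are counted from 0 and that decay fires whenever the epoch index is a positive multiple of step_size. The recursive form multiplies whatever value the learning rate currently holds. This is why a change made by outside code persists instead of being overwritten.

```python
# Minimal sketch (hypothetical helpers, not MindSpore code) of the StepLR rule.

def closed_form_lr(base_lr, step_size, gamma, epoch):
    """Learning rate at a given epoch index when nothing else touches the lr."""
    return base_lr * gamma ** (epoch // step_size)


def chained_lr(prev_lr, step_size, gamma, epoch):
    """Recursive form: multiply the *current* lr by gamma at each step_size boundary."""
    if epoch > 0 and epoch % step_size == 0:
        return prev_lr * gamma
    return prev_lr


base_lr, step_size, gamma = 0.05, 2, 0.1

# Without outside interference, the two forms agree at every epoch.
lr = base_lr
for epoch in range(6):
    lr = chained_lr(lr, step_size, gamma, epoch)
    assert abs(lr - closed_form_lr(base_lr, step_size, gamma, epoch)) < 1e-12

# An outside change (here, halving the lr at epoch 3) persists, because the
# recursive form decays whatever value the lr currently holds.
lr = base_lr
history = []
for epoch in range(6):
    lr = chained_lr(lr, step_size, gamma, epoch)
    if epoch == 3:
        lr *= 0.5  # hypothetical external adjustment
    history.append(lr)
print(history)
```

With the values used in the example, the closed form reproduces the 0.05 / 0.005 / 0.0005 schedule. In the second loop, the external halving carries through to every later epoch.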
Warning
This is an experimental learning rate scheduler module that is subject to change. It must be used with the optimizers in Experimental Optimizer.
- Parameters
optimizer (mindspore.experimental.optim.Optimizer) – Wrapped optimizer.
step_size (int) – Period of learning rate decay, in epochs.
gamma (float, optional) – Multiplicative factor of learning rate decay. Default: 0.1.
last_epoch (int, optional) – The index of the last epoch. Default: -1.
- Supported Platforms:
Ascend GPU CPU
Examples
>>> import mindspore
>>> from mindspore import nn
>>> from mindspore.experimental import optim
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
>>> optimizer = optim.Adam(net.trainable_params(), lr=0.05)
>>> # Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.05     if epoch < 2
>>> # lr = 0.005    if 2 <= epoch < 4
>>> # lr = 0.0005   if 4 <= epoch < 6
>>> scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
>>> def forward_fn(data, label):
...     logits = net(data)
...     loss = loss_fn(logits, label)
...     return loss, logits
>>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
>>> def train_step(data, label):
...     (loss, _), grads = grad_fn(data, label)
...     optimizer(grads)
...     return loss
>>> for epoch in range(6):
...     # Create the dataset taking MNIST as an example. Refer to
...     # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
...     for data, label in create_dataset():
...         train_step(data, label)
...     scheduler.step()
...     current_lr = scheduler.get_last_lr()
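The order of calls in the example can be checked without MindSpore by using a toy stand-in. The example calls scheduler.step() once per epoch, after the inner training loop. `ToyStepLR` is hypothetical and only models the step_size/gamma decay on a list of parameter-group dicts. It assumes, as in PyTorch-style schedulers, that construction leaves the epoch counter at 0.

```python
class ToyStepLR:
    """Hypothetical stand-in for StepLR acting on a list of param-group dicts."""

    def __init__(self, param_groups, step_size, gamma=0.1):
        self.param_groups = param_groups
        self.step_size = step_size
        self.gamma = gamma
        self.last_epoch = 0  # assumption: construction counts as epoch 0

    def step(self):
        # Advance one epoch; decay every group's lr at each step_size boundary.
        self.last_epoch += 1
        if self.last_epoch % self.step_size == 0:
            for group in self.param_groups:
                group["lr"] *= self.gamma

    def get_last_lr(self):
        return [group["lr"] for group in self.param_groups]


groups = [{"lr": 0.05}]
scheduler = ToyStepLR(groups, step_size=2, gamma=0.1)
lr_per_epoch = []
for epoch in range(6):
    lr_per_epoch.append(scheduler.get_last_lr()[0])  # lr used during this epoch
    # ... training steps for this epoch would run here ...
    scheduler.step()  # once per epoch, after the inner loop
print(lr_per_epoch)
```

The recorded values follow the schedule in the example's comments: two epochs at 0.05, two at 0.005, and two at 0.0005. Calling step() inside the inner loop instead would decay the rate per batch rather than per epoch.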