mindspore.experimental.optim.lr_scheduler.SequentialLR
- class mindspore.experimental.optim.lr_scheduler.SequentialLR(optimizer, schedulers, milestones, last_epoch=-1)[source]
Concatenates multiple learning rate adjustment strategies given in schedulers, switching to the next strategy at each point listed in milestones.
Warning
This is an experimental lr scheduler module that is subject to change. This module must be used with optimizers in Experimental Optimizer.
- Parameters
optimizer (mindspore.experimental.optim.Optimizer) – Wrapped optimizer.
schedulers (list[mindspore.experimental.optim.lr_scheduler.LRScheduler]) – List of learning rate schedulers.
milestones (list) – List of integer milestone points; determines which learning rate adjustment strategy is applied at each epoch.
last_epoch (int, optional) – The number of times the step() method of the current learning rate adjustment strategy has been executed. Default: -1.
- Raises
ValueError – The optimizer of a scheduler in schedulers is different from the optimizer passed in.
ValueError – The optimizer of a scheduler in schedulers is different from the optimizer of schedulers[0].
ValueError – The length of milestones is not equal to the length of schedulers minus 1.
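The milestone semantics and the length constraint above can be sketched with plain Python (this is an illustrative model, not MindSpore's actual implementation): epochs before milestones[0] use schedulers[0], and epochs in [milestones[i-1], milestones[i]) use schedulers[i], so n schedulers need exactly n - 1 milestones.

```python
from bisect import bisect_right

def active_scheduler_index(milestones, epoch):
    """Sketch: index of the scheduler in charge at a given epoch.

    bisect_right counts how many milestones have already been passed,
    which is exactly the index of the scheduler currently in effect.
    """
    return bisect_right(milestones, epoch)

# With milestones=[2], epochs 0 and 1 use scheduler 0,
# and epochs >= 2 use scheduler 1.
print([active_scheduler_index([2], e) for e in range(4)])  # [0, 0, 1, 1]
```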
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> from mindspore.experimental import optim
>>> from mindspore import nn
>>> net = nn.Dense(3, 2)
>>> optimizer = optim.Adam(net.trainable_params(), 0.1)
>>> scheduler1 = optim.lr_scheduler.ConstantLR(optimizer, factor=0.1, total_iters=2)
>>> scheduler2 = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
>>> scheduler = optim.lr_scheduler.SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[2])
>>> for i in range(6):
...     scheduler.step()
...     current_lr = scheduler.get_last_lr()
...     print(current_lr)
[Tensor(shape=[], dtype=Float32, value= 0.01)]
[Tensor(shape=[], dtype=Float32, value= 0.1)]
[Tensor(shape=[], dtype=Float32, value= 0.09)]
[Tensor(shape=[], dtype=Float32, value= 0.081)]
[Tensor(shape=[], dtype=Float32, value= 0.0729)]
[Tensor(shape=[], dtype=Float32, value= 0.06561)]