mindspore.experimental.optim.lr_scheduler.CyclicLR
- class mindspore.experimental.optim.lr_scheduler.CyclicLR(optimizer, base_lr, max_lr, step_size_up=2000, step_size_down=None, mode='triangular', gamma=1., scale_fn=None, scale_mode='cycle', last_epoch=-1)[source]
Sets the learning rate of each parameter group according to cyclical learning rate policy (CLR). The policy cycles the learning rate between two boundaries with a constant frequency, as detailed in the paper Cyclical Learning Rates for Training Neural Networks. The distance between the two boundaries can be scaled on a per-iteration or per-cycle basis.
This class has three built-in policies, as put forth in the paper:
"triangular": A basic triangular cycle without amplitude scaling.
"triangular2": A basic triangular cycle that scales initial amplitude by half each cycle.
"exp_range": A cycle that scales initial amplitude by
at each cycle iteration.
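To make these policies concrete, here is a minimal plain-Python sketch of the CLR computation from the paper. cyclic_lr is an illustrative helper, not part of the MindSpore API, and it assumes step_size_down equals step_size_up; with the default base_lr=0.01, max_lr=0.1, and step_size_up=2000 used in the Examples below, iteration 1 reproduces the first learning rate, 0.010045.

>>> import math
>>> def cyclic_lr(iteration, base_lr=0.01, max_lr=0.1,
...               step_size_up=2000, mode='triangular', gamma=1.0):
...     # 1-based cycle index and position within the current cycle,
...     # assuming step_size_down == step_size_up.
...     cycle = math.floor(1 + iteration / (2 * step_size_up))
...     x = abs(iteration / step_size_up - 2 * cycle + 1)
...     # Amplitude scaling for the three built-in policies; for 'exp_range'
...     # the exponent is the total training iteration count.
...     scale = {'triangular': 1.0,
...              'triangular2': 1.0 / (2 ** (cycle - 1)),
...              'exp_range': gamma ** iteration}[mode]
...     return base_lr + (max_lr - base_lr) * max(0.0, 1.0 - x) * scale
>>> round(cyclic_lr(1), 6)
0.010045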
Warning
This is an experimental lr scheduler module that is subject to change. This module must be used with optimizers in Experimental Optimizer.
- Parameters
optimizer (mindspore.experimental.optim.Optimizer) – Wrapped optimizer.
base_lr (Union(float, list)) – Initial learning rate, which is the lower boundary in the cycle for each parameter group.
max_lr (Union(float, list)) – Upper learning rate boundary in the cycle for each parameter group. Functionally, it defines the cycle amplitude (max_lr - base_lr). The lr at any cycle is the sum of base_lr and some scaling of the amplitude.
step_size_up (int, optional) – Number of training iterations in the increasing half of a cycle. Default: 2000.
step_size_down (int, optional) – Number of training iterations in the decreasing half of a cycle. If step_size_down is None, it is set to step_size_up. Default: None.
mode (str, optional) – One of {'triangular', 'triangular2', 'exp_range'}. Values correspond to the policies detailed above. If scale_fn is not None, this argument is ignored. Default: 'triangular'.
gamma (float, optional) – Constant in the 'exp_range' scaling function: gamma**(cycle iterations). Default: 1.0.
scale_fn (function, optional) – Custom scaling policy defined by a single-argument lambda function, where 0 <= scale_fn(x) <= 1 for all x >= 0. If specified, 'mode' is ignored. Default: None.
scale_mode (str, optional) – One of {'cycle', 'iterations'}. Defines whether scale_fn is evaluated on the cycle number or on cycle iterations (training iterations since the start of the cycle). Illegal inputs fall back to 'iterations'. Default: 'cycle'.
last_epoch (int, optional) – The index of the last epoch. Default: -1.
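As a sketch of the custom-scaling hook described above: with scale_mode='cycle', scale_fn receives the 1-based cycle number, so halving the returned scale each cycle mimics the built-in 'triangular2' policy. The setup mirrors the Examples section below; when scale_fn is given, mode is ignored.

>>> from mindspore import nn
>>> from mindspore.experimental import optim
>>> net = nn.Dense(3, 2)
>>> optimizer = optim.SGD(net.trainable_params(), lr=0.1, momentum=0.9)
>>> # Halve the cycle amplitude each cycle via the scale_fn hook.
>>> scheduler = optim.lr_scheduler.CyclicLR(
...     optimizer, base_lr=0.01, max_lr=0.1,
...     scale_fn=lambda cycle: 1.0 / (2.0 ** (cycle - 1)),
...     scale_mode='cycle')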
- Raises
ValueError – When base_lr is a list or tuple and its length is not equal to the number of param groups.
ValueError – When max_lr is a list or tuple and its length is not equal to the number of param groups.
ValueError – When mode is not in ['triangular', 'triangular2', 'exp_range'] and scale_fn is None.
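When the optimizer has multiple parameter groups, base_lr and max_lr may be given as lists with one entry per group; a length mismatch triggers the ValueError above. A sketch, assuming the parameter-group dict format of mindspore.experimental.optim:

>>> from mindspore import nn
>>> from mindspore.experimental import optim
>>> backbone, head = nn.Dense(3, 3), nn.Dense(3, 2)
>>> optimizer = optim.SGD([{'params': backbone.trainable_params()},
...                        {'params': head.trainable_params()}], lr=0.1)
>>> # One boundary per parameter group: len(base_lr) and len(max_lr)
>>> # must both equal the number of param groups (here, 2).
>>> scheduler = optim.lr_scheduler.CyclicLR(
...     optimizer, base_lr=[0.001, 0.01], max_lr=[0.01, 0.1])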
- Supported Platforms:
Ascend GPU CPU
Examples
>>> from mindspore.experimental import optim
>>> from mindspore import nn
>>> net = nn.Dense(3, 2)
>>> optimizer = optim.SGD(net.trainable_params(), lr=0.1, momentum=0.9)
>>> scheduler = optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
>>> expect_list = [[0.010045], [0.01009], [0.010135], [0.01018], [0.010225]]
>>>
>>> for i in range(5):
...     scheduler.step()
...     current_lr = scheduler.get_last_lr()
...     print(current_lr)
[Tensor(shape=[], dtype=Float32, value= 0.010045)]
[Tensor(shape=[], dtype=Float32, value= 0.01009)]
[Tensor(shape=[], dtype=Float32, value= 0.010135)]
[Tensor(shape=[], dtype=Float32, value= 0.01018)]
[Tensor(shape=[], dtype=Float32, value= 0.010225)]