mindformers.core.LinearWithWarmUpLR

class mindformers.core.LinearWithWarmUpLR(learning_rate: float, total_steps: int, warmup_steps: int = None, warmup_lr_init: float = 0., warmup_ratio: float = None, **kwargs)

Linear learning rate with warm-up.

The LinearWithWarmUpLR scheduler uses a linear warm-up strategy to gradually increase the learning rate for each parameter group, then decreases the learning rate linearly once the warm-up phase ends.

During the warm-up phase, the learning rate increases linearly from a smaller initial value to the base learning rate, as described by the following formula:

\[\eta_t = \eta_{\text{warmup}} + t \times \frac{\eta_{\text{base}} - \eta_{\text{warmup}}}{\text{warmup_steps}}\]

where \(t\) is the current global step, \(\eta_{\text{warmup}}\) is the initial learning rate of the warm-up phase (warmup_lr_init), and \(\eta_{\text{base}}\) is the base learning rate reached at the end of the warm-up phase (learning_rate).
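A minimal pure-Python sketch of the warm-up formula above, assuming \(t\) is the integer global step and that \(\eta_{\text{warmup}}\) and \(\eta_{\text{base}}\) correspond to the warmup_lr_init and learning_rate arguments. The function name `warmup_lr` is hypothetical and is not part of the mindformers API.

```python
def warmup_lr(t: int, base_lr: float, warmup_lr_init: float, warmup_steps: int) -> float:
    """Learning rate at global step t during warm-up (hypothetical helper).

    Follows eta_t = eta_warmup + t * (eta_base - eta_warmup) / warmup_steps,
    assuming eta_warmup = warmup_lr_init and eta_base = base_lr.
    """
    if warmup_steps <= 0:
        raise ValueError("warmup_steps must be positive")
    increment = (base_lr - warmup_lr_init) / warmup_steps  # per-step increase
    return warmup_lr_init + t * increment


# Same values as the example on this page: base_lr=0.005, warmup_steps=10.
print(warmup_lr(1, 0.005, 0.0, 10))   # approximately 0.0005
print(warmup_lr(10, 0.005, 0.0, 10))  # approximately base_lr (0.005) at the end of warm-up
```

Each step adds the same increment, so the learning rate travels from warmup_lr_init to the base value in exactly warmup_steps steps.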

After the warm-up phase, the learning rate is adjusted according to the following linear schedule:

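\[\eta_t = \eta_{\text{base}} - (t - \text{warmup_steps}) \times \frac{\eta_{\text{base}} - \eta_{\text{end}}}{\text{total_steps} - \text{warmup_steps}}\]

where \(\eta_{\text{end}}\) is the minimum learning rate at the end of training, \(\text{total_steps}\) is the total number of training steps, and \(\text{warmup_steps}\) is the number of steps in the warm-up phase. The learning rate therefore falls linearly from \(\eta_{\text{base}}\) at \(t = \text{warmup_steps}\) to \(\eta_{\text{end}}\) at \(t = \text{total_steps}\).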
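The decay phase can be sketched the same way. This sketch assumes \(t\) is the global step and measures progress through the decay phase as t - warmup_steps. With \(\eta_{\text{end}} = 0\), that reading reproduces the example output on this page (0.0025 at step 15 with learning_rate=0.005, warmup_steps=10, total_steps=20). The helper name `decay_lr` and its `end_lr` argument are hypothetical; the constructor documented here has no end-learning-rate parameter.

```python
def decay_lr(t: int, base_lr: float, end_lr: float, total_steps: int, warmup_steps: int) -> float:
    """Learning rate at global step t after warm-up (hypothetical helper).

    Implements eta_t = eta_base - (t - warmup_steps) * (eta_base - eta_end)
                                  / (total_steps - warmup_steps).
    """
    decay_steps = total_steps - warmup_steps
    if decay_steps <= 0:
        raise ValueError("total_steps must exceed warmup_steps")
    steps_into_decay = t - warmup_steps             # 0 when warm-up has just ended
    decrement = (base_lr - end_lr) / decay_steps    # per-step decrease
    return base_lr - steps_into_decay * decrement


print(decay_lr(15, 0.005, 0.0, 20, 10))  # approximately 0.0025
print(decay_lr(20, 0.005, 0.0, 20, 10))  # approximately end_lr (0.0) at the final step
```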

Linear warm-up raises the learning rate gradually instead of starting at its full value, and the subsequent linear decay lowers it steadily over the rest of training; together, the two phases help keep training stable.

Parameters
  • learning_rate (float) – Base learning rate, reached at the end of the warm-up phase.

  • total_steps (int) – Total number of training steps.

  • warmup_steps (int) – Number of warm-up steps. Default: None.

  • warmup_lr_init (float) – Learning rate at the start of the warm-up phase. Default: 0.

  • warmup_ratio (float) – Ratio of total training steps used for warm-up. Default: None.
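The parameter list does not state how warmup_steps and warmup_ratio interact. One plausible reading, given here only as an assumption, is that a ratio is converted to a step count as int(total_steps * warmup_ratio). The helper `resolve_warmup_steps` below is hypothetical; the library's actual precedence and rounding rules may differ.

```python
from typing import Optional


def resolve_warmup_steps(total_steps: int,
                         warmup_steps: Optional[int] = None,
                         warmup_ratio: Optional[float] = None) -> int:
    """Hypothetical helper: turn warmup_steps / warmup_ratio into a step count.

    ASSUMPTION: a ratio, when given, takes precedence and is truncated with int();
    with neither argument set, warm-up is assumed to last 0 steps.
    """
    if warmup_ratio is not None:
        if not 0.0 <= warmup_ratio <= 1.0:
            raise ValueError("warmup_ratio must lie in [0, 1]")
        return int(total_steps * warmup_ratio)
    return warmup_steps if warmup_steps is not None else 0


print(resolve_warmup_steps(20, warmup_steps=10))    # 10
print(resolve_warmup_steps(20, warmup_ratio=0.25))  # 5
print(resolve_warmup_steps(20))                     # 0
```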

Inputs:
  • global_step (int) - The global step.

Outputs:
  • Learning rate.

Examples

>>> import mindspore as ms
>>> from mindformers.core import LinearWithWarmUpLR
>>>
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> total_steps = 20
>>> warmup_steps = 10
>>> learning_rate = 0.005
>>>
>>> linear_warmup = LinearWithWarmUpLR(learning_rate=learning_rate,
...                                    warmup_steps=warmup_steps,
...                                    total_steps=total_steps)
>>> print(linear_warmup(1))
0.0005
>>> print(linear_warmup(15))
0.0025
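To sanity-check the printed values without MindSpore, both phases can be combined into one pure-Python function. This sketch assumes the learning rate decays to an end value of 0 and that the decay phase is measured from the end of warm-up (step - warmup_steps). The name `linear_with_warmup` is hypothetical, this is not the mindformers implementation, and it is only meant for 0 <= step <= total_steps.

```python
def linear_with_warmup(step: int, learning_rate: float, total_steps: int,
                       warmup_steps: int, warmup_lr_init: float = 0.0,
                       end_lr: float = 0.0) -> float:
    """Hypothetical reference for the warm-up + linear decay schedule."""
    if step < warmup_steps:
        # Warm-up: linear ramp from warmup_lr_init up to learning_rate.
        return warmup_lr_init + step * (learning_rate - warmup_lr_init) / warmup_steps
    # Decay: linear ramp from learning_rate down to end_lr at total_steps.
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return learning_rate - progress * (learning_rate - end_lr)


# Same configuration as the MindSpore example above.
for step in (1, 10, 15, 20):
    print(step, round(linear_with_warmup(step, 0.005, 20, 10), 6))
```

After rounding, steps 1 and 15 give 0.0005 and 0.0025, the same values printed by the example; the schedule peaks at 0.005 at step 10 and reaches 0.0 at step 20.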