mindformers.core.LearningRateWiseLayer
- class mindformers.core.LearningRateWiseLayer(base_lr, lr_scale)
Learning Rate Wise Layer.
This approach lets each layer adapt its learning rate to its own needs, which can make training more efficient and effective. The learning rate for each layer is a base learning rate modulated by a scaling factor specific to that layer.
Initially, the learning rate for each layer is set based on a linear scaling strategy:
\[\eta_{t,l} = \eta_{\text{base}} \times \alpha_l\]

where \(\eta_{t,l}\) is the learning rate for layer \(l\) at time \(t\), \(\eta_{\text{base}}\) is the base learning rate, and \(\alpha_l\) is the scaling factor for layer \(l\).
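To make the scaling rule concrete, here is a minimal sketch in plain Python (the base rate and per-layer factors below are hypothetical illustration values, not API arguments):

>>> base_lr = 0.005                  # eta_base
>>> layer_scales = [1.0, 0.5, 0.25]  # alpha_l for three hypothetical layers
>>> [base_lr * alpha for alpha in layer_scales]  # eta_{t,l} = eta_base * alpha_l
[0.005, 0.0025, 0.00125]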
As training progresses, the learning rate for each layer is adjusted according to the following cosine annealing schedule:
\[\eta_{t,l} = \eta_{\text{end}} + \frac{1}{2}\left(\eta_{\text{base}} \times \alpha_l - \eta_{\text{end}}\right)\left(1 + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)\]

where \(T_{cur}\) is the number of epochs completed since the learning rate was last reset, and \(T_{max}\) is the total number of epochs before the next reset. \(\eta_{\text{end}}\) is the minimum learning rate at the end of training.
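A quick numeric check of this schedule in plain Python (here \(\eta_{\text{end}}\), \(\alpha_l\), and \(T_{max}\) are hypothetical illustration values, not arguments of this class):

>>> import math
>>> eta_base, eta_end, alpha_l, t_max = 0.005, 0.0, 0.5, 20
>>> for t_cur in (0, 10, 20):
...     eta = eta_end + 0.5 * (eta_base * alpha_l - eta_end) * (1 + math.cos(t_cur / t_max * math.pi))
...     print(round(eta, 6))
0.0025
0.00125
0.0

The scaled rate starts at \(\eta_{\text{base}} \times \alpha_l\), passes through its halfway value at the middle of the cycle, and decays to \(\eta_{\text{end}}\) at \(T_{cur} = T_{max}\).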
- Parameters
base_lr (mindspore.nn.learning_rate_schedule.LearningRateSchedule) – The base learning rate schedule.
lr_scale (float) – The value for learning rate scaling.
- Inputs:
global_step (int) - The global step.
- Outputs:
Learning rate for the current global step.
Examples
>>> import mindspore as ms
>>> from mindformers.core import LinearWithWarmUpLR
>>> from mindformers.core import LearningRateWiseLayer
>>>
>>> ms.set_context(mode=ms.GRAPH_MODE)
>>> total_steps = 20
>>> warmup_steps = 10
>>> learning_rate = 0.005
>>>
>>> linear_warmup = LinearWithWarmUpLR(learning_rate=learning_rate,
...                                    warmup_steps=warmup_steps,
...                                    total_steps=total_steps)
>>> learning_rate_wise_layer = LearningRateWiseLayer(linear_warmup, 0.5)
>>> print(learning_rate_wise_layer(1))
0.00025
>>> print(learning_rate_wise_layer(15))
0.00125
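These outputs are consistent with the wrapped schedule: at step 1 the linear warmup gives 0.005 × 1/10 = 0.0005, and at step 15 the linear decay gives 0.005 × (20 − 15)/(20 − 10) = 0.0025; LearningRateWiseLayer multiplies both by lr_scale = 0.5, yielding 0.00025 and 0.00125.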