mindspore.train.LearningRateScheduler

class mindspore.train.LearningRateScheduler(learning_rate_function)

Callback that changes the learning rate during training.

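Parameters:
  • learning_rate_function (Function) - The function that changes the learning rate during training. It is called as learning_rate_function(lr, cur_step_num) with the current learning rate and the current step number, and returns the learning rate to use.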
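Because the function receives the learning rate currently in effect, a rule like the one in the class example compounds over time. To schedule from a fixed base value instead, one option is a closure that ignores the incoming lr. The following is a minimal pure-Python sketch, assuming the calling convention learning_rate_function(lr, cur_step_num); make_step_decay and its parameters are hypothetical names, not part of MindSpore.

```python
from typing import Callable


def make_step_decay(base_lr: float, decay_rate: float,
                    decay_steps: int) -> Callable[[float, int], float]:
    """Build a hypothetical learning_rate_function for LearningRateScheduler.

    The returned function ignores the incoming lr and recomputes the rate from
    base_lr, so the schedule depends only on the step number.
    """
    def learning_rate_function(lr: float, cur_step_num: int) -> float:
        # Decay base_lr by decay_rate once per completed block of decay_steps steps.
        return base_lr * decay_rate ** (cur_step_num // decay_steps)

    return learning_rate_function


# Usage: halve the rate every 100 steps, starting from 0.1.
lr_fn = make_step_decay(0.1, 0.5, 100)
for step in (1, 99, 100, 250):
    print(step, lr_fn(0.1, step))
```

Under these assumptions, the returned function could be passed as LearningRateScheduler(make_step_decay(0.1, 0.5, 100)). Recomputing from base_lr keeps the schedule independent of how many times the rate has already been changed.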

Examples:

>>> import numpy as np
>>> from mindspore import nn
>>> from mindspore.train import Model, LearningRateScheduler
>>> from mindspore import dataset as ds
...
>>> # Multiply the current learning rate by 0.1 every 1000 steps.
>>> def learning_rate_function(lr, cur_step_num):
...     if cur_step_num % 1000 == 0:
...         lr = lr * 0.1
...     return lr
...
>>> lr = 0.1
>>> momentum = 0.9
>>> net = nn.Dense(10, 5)
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=lr, momentum=momentum)
>>> model = Model(net, loss_fn=loss, optimizer=optim)
...
>>> # 64 random samples with 10 features each and integer labels in [0, 5).
>>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
>>> dataset = ds.NumpySlicesDataset(data=data).batch(32)
>>> # Register the scheduler as a callback. With 64 samples and batch size 32,
>>> # one epoch is only 2 steps, so the step-1000 decay is never reached here.
>>> model.train(1, dataset, callbacks=[LearningRateScheduler(learning_rate_function)],
...             dataset_sink_mode=False)
step_end(run_context)

Changes the learning rate at the end of each step.
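The per-step behavior can be approximated in plain Python. This simplified simulation assumes that step_end calls learning_rate_function with the optimizer's current learning rate and a step number starting at 1, and that the returned value takes effect from the next step; simulate_schedule is a hypothetical helper, not a MindSpore API. It shows that the class example's rule compounds: the rate drops to about 0.01 after step 1000 and about 0.0001 after step 3000.

```python
from typing import Callable, List


def simulate_schedule(lr_fn: Callable[[float, int], float],
                      initial_lr: float, num_steps: int) -> List[float]:
    """Return the learning rate in effect after each simulated step_end call."""
    lr = initial_lr
    history = []
    for cur_step_num in range(1, num_steps + 1):  # assumption: steps are numbered from 1
        # ... the actual training step would run here using the current lr ...
        lr = lr_fn(lr, cur_step_num)  # pass the current lr; the result is used from the next step
        history.append(lr)
    return history


def learning_rate_function(lr, cur_step_num):
    # Same rule as the class example: multiply the current lr by 0.1 every 1000 steps.
    if cur_step_num % 1000 == 0:
        lr = lr * 0.1
    return lr


history = simulate_schedule(learning_rate_function, 0.1, 3000)
print(history[998], history[999], history[2999])
```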

Parameters: