mindspore.train.LearningRateScheduler
- class mindspore.train.LearningRateScheduler(learning_rate_function)
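Changes the learning rate during training.
- Parameters:
learning_rate_function (Function) - The function that changes the learning rate during training. It is called with the current learning rate and the current step number and returns the learning rate to use.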
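A `learning_rate_function` has the signature used in the example below: it takes the current learning rate and the current step number and returns the learning rate to use next. The pure-Python sketch below does not import MindSpore. It only simulates how such a function is applied once per step, feeding each result back in as the next call's `lr`. `step_decay` and `simulate_schedule` are hypothetical names for illustration, and the sketch assumes step numbers start at 1.

```python
def step_decay(lr, cur_step_num):
    """Multiply the learning rate by 0.1 every 1000 steps (same rule as the example)."""
    if cur_step_num % 1000 == 0:
        lr = lr * 0.1
    return lr


def simulate_schedule(lr_function, initial_lr, num_steps):
    """Hypothetical helper: call lr_function at the end of each step,
    feed its result back in, and record the learning rate after every step."""
    lr = initial_lr
    history = []
    for step in range(1, num_steps + 1):  # assumption: step numbers start at 1
        lr = lr_function(lr, step)
        history.append(lr)
    return history


history = simulate_schedule(step_decay, 0.1, 3000)
# Learning rate after steps 999, 1000, 2000 and 3000.
print(history[998], history[999], history[1999], history[2999])
```

With the data in the example below (64 samples, batch size 32, one epoch), training runs only 2 steps, so the every-1000-steps rule never fires there. Use a smaller interval to observe a change.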
Examples:
>>> import numpy as np
>>> from mindspore import nn
>>> from mindspore.train import Model, LearningRateScheduler
>>> from mindspore import dataset as ds
...
>>> def learning_rate_function(lr, cur_step_num):
...     if cur_step_num % 1000 == 0:
...         lr = lr * 0.1
...     return lr
...
>>> lr = 0.1
>>> momentum = 0.9
>>> net = nn.Dense(10, 5)
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=lr, momentum=momentum)
>>> model = Model(net, loss_fn=loss, optimizer=optim)
...
>>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
>>> dataset = ds.NumpySlicesDataset(data=data).batch(32)
>>> model.train(1, dataset, callbacks=[LearningRateScheduler(learning_rate_function)],
...             dataset_sink_mode=False)
- step_end(run_context)
Changes the learning rate at the end of each step.
- Parameters:
run_context (RunContext) - Contains basic information about the model. For details, see mindspore.train.RunContext.
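Conceptually, `step_end` reads the optimizer's current learning rate and the current step number from `run_context`, passes them to `learning_rate_function`, and writes the result back if the value changed. The sketch below reproduces that flow in plain Python. It is an illustration under stated assumptions, not MindSpore's actual implementation. `FakeOptimizer`, `FakeCallbackParams`, `FakeRunContext`, `SketchLRScheduler` and `halve_every_2_steps` are all hypothetical stand-ins. The `original_args()`, `cur_step_num` and `optimizer` names are assumed to mirror the callback parameters that a real `RunContext` exposes.

```python
import math


class FakeOptimizer:
    """Hypothetical stand-in for an optimizer holding a scalar learning rate."""
    def __init__(self, learning_rate):
        self.learning_rate = learning_rate


class FakeCallbackParams:
    """Hypothetical stand-in for the callback parameters of a training run."""
    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.cur_step_num = 0


class FakeRunContext:
    """Hypothetical stand-in for RunContext: original_args() returns the params."""
    def __init__(self, cb_params):
        self._cb_params = cb_params

    def original_args(self):
        return self._cb_params


class SketchLRScheduler:
    """Simplified sketch of a callback that updates the learning rate at step end."""
    def __init__(self, learning_rate_function):
        self.learning_rate_function = learning_rate_function

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        lr = cb_params.optimizer.learning_rate
        new_lr = self.learning_rate_function(lr, cb_params.cur_step_num)
        # Write back (and report) only when the function actually changed the value.
        if not math.isclose(lr, new_lr, rel_tol=1e-10):
            cb_params.optimizer.learning_rate = new_lr
            print(f"At step {cb_params.cur_step_num}, learning rate changed to {new_lr}")


def halve_every_2_steps(lr, cur_step_num):
    """Hypothetical schedule: halve the learning rate on every even step."""
    return lr * 0.5 if cur_step_num % 2 == 0 else lr


opt = FakeOptimizer(0.1)
params = FakeCallbackParams(opt)
ctx = FakeRunContext(params)
scheduler = SketchLRScheduler(halve_every_2_steps)
for _ in range(4):            # simulate four training steps
    params.cur_step_num += 1  # assumption: the step counter advances before step_end runs
    scheduler.step_end(ctx)
print(opt.learning_rate)
```

The `math.isclose` check avoids rewriting the optimizer on steps where the function returns the same value. The learning rate changes only on steps 2 and 4 here.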