mindspore.LearningRateScheduler
class mindspore.LearningRateScheduler(learning_rate_function)
Callback that changes the learning rate during training.
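To see how a per-step schedule evolves, the sketch below simulates it in plain Python without MindSpore. It assumes the callback calls the user function once per training step, passing the current learning rate and the 1-based step number, and adopts the returned value as the new rate. The helper `simulate_schedule` is hypothetical and only illustrates that assumed contract; it is not part of the MindSpore API.

```python
# Hypothetical pure-Python simulation of a per-step learning-rate schedule.
# Assumption: the callback invokes the user function once per step with
# (current learning rate, 1-based step number) and uses the returned value
# as the new learning rate.

def learning_rate_function(lr, cur_step_num):
    # Same rule as in the example: multiply the rate by 0.1 every 1000 steps.
    if cur_step_num % 1000 == 0:
        lr = lr * 0.1
    return lr


def simulate_schedule(lr_fn, initial_lr, num_steps):
    """Return the learning rate produced after each step 1..num_steps."""
    lr = initial_lr
    history = []
    for step in range(1, num_steps + 1):
        lr = lr_fn(lr, step)  # the returned value replaces the current rate
        history.append(lr)
    return history


if __name__ == "__main__":
    history = simulate_schedule(learning_rate_function, 0.1, 3000)
    print(history[998], history[999], history[1999], history[2999])
```

Under this assumption the rate stays at 0.1 through step 999 and is divided by 10 at steps 1000, 2000, and 3000.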
Parameters:
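learning_rate_function (Function) - The function that changes the learning rate during training. As in the example below, it takes the current learning rate and the current step number as arguments and returns the new learning rate.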
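Because the function only has to map `(lr, cur_step_num)` to a new rate, it can also compute the rate from the step number alone. The sketch below is a hypothetical linear-warmup function with that two-argument signature; `BASE_LR` and `WARMUP_STEPS` are illustrative assumptions, not MindSpore defaults.

```python
# Hypothetical learning_rate_function implementing linear warmup.
# BASE_LR and WARMUP_STEPS are illustrative values, not MindSpore defaults.
BASE_LR = 0.1
WARMUP_STEPS = 100


def warmup_learning_rate(lr, cur_step_num):
    """Ramp the rate linearly toward BASE_LR, then hold it constant.

    The incoming ``lr`` is ignored: the new rate depends only on the
    step number, so the schedule does not accumulate rounding drift.
    """
    if cur_step_num < WARMUP_STEPS:
        return BASE_LR * cur_step_num / WARMUP_STEPS
    return BASE_LR
```

A stateless function like this is easier to test than one that rescales the incoming `lr`, since each step's rate can be checked in isolation.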
Examples:
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import nn
>>> from mindspore import LearningRateScheduler
>>> from mindspore import dataset as ds
>>> # Multiply the learning rate by 0.1 every 1000 steps.
>>> def learning_rate_function(lr, cur_step_num):
...     if cur_step_num % 1000 == 0:
...         lr = lr * 0.1
...     return lr
...
>>> lr = 0.1
>>> momentum = 0.9
>>> net = nn.Dense(10, 5)
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=lr, momentum=momentum)
>>> model = ms.Model(net, loss_fn=loss, optimizer=optim)
>>> # 64 random samples with 10 features and labels in [0, 5), batched by 32.
>>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
>>> dataset = ds.NumpySlicesDataset(data=data).batch(32)
>>> model.train(1, dataset, callbacks=[LearningRateScheduler(learning_rate_function)],
...             dataset_sink_mode=False)