mindspore.experimental
实验性模块。
实验性优化器
接口名 |
概述 |
支持平台 |
用于参数更新的优化器基类。 |
|
|
Adadelta算法的实现。 |
|
|
Adagrad算法的实现。 |
|
|
Adaptive Moment Estimation (Adam)算法的实现。 |
|
|
Adamax算法的实现(基于无穷范数的Adam算法)。 |
|
|
Adaptive Moment Estimation Weight Decay(AdamW)算法的实现。 |
|
|
Averaged Stochastic Gradient Descent 算法的实现。 |
|
|
NAdam算法的实现。 |
|
|
RAdam 算法的实现。 |
|
|
RMSprop 算法的实现。 |
|
|
Rprop 算法的实现。 |
|
|
随机梯度下降算法。 |
|
LRScheduler类
本模块中的动态学习率都是LRScheduler的子类,此模块仅与mindspore.experimental.optim下的优化器配合使用,使用时将优化器实例传递给LRScheduler类。在训练过程中,LRScheduler子类通过调用 step 方法进行学习率的动态改变。
import mindspore
from mindspore import nn
from mindspore.experimental import optim
# Define the network structure of LeNet5. Refer to
# https://gitee.com/mindspore/docs/blob/r2.3.0rc2/docs/mindspore/code/lenet.py
net = LeNet5()
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
optimizer = optim.Adam(net.trainable_params(), lr=0.05)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
def forward_fn(data, label):
logits = net(data)
loss = loss_fn(logits, label)
return loss, logits
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
def train_step(data, label):
(loss, _), grads = grad_fn(data, label)
optimizer(grads)
return loss
for epoch in range(6):
# Create the dataset taking MNIST as an example. Refer to
# https://gitee.com/mindspore/docs/blob/r2.3.0rc2/docs/mindspore/code/mnist.py
for data, label in create_dataset(need_download=False):
train_step(data, label)
scheduler.step()
接口名 |
概述 |
支持平台 |
动态学习率的基类。 |
|
|
将每个参数组的学习率按照衰减因子 factor 进行衰减,直到 last_epoch 达到 total_iters。 |
|
|
使用余弦退火对优化器参数组的学习率进行改变。 |
|
|
|
使用余弦退火热重启对优化器参数组的学习率进行改变。 |
|
根据循环学习率策略(CLR)设置每个参数组的学习率。 |
|
|
每个epoch呈指数衰减的学习率,即乘以 gamma 。 |
|
|
将每个参数组的学习率设定为初始学习率乘以指定的 lr_lambda 函数。 |
|
|
线性减小学习率乘法因子 ,并将每个参数组的学习率按照此乘法因子进行衰减,直到 last_epoch 数达到 total_iters。 |
|
|
将每个参数组当前的学习率按照传入的 lr_lambda 函数乘以指定的乘法因子。 |
|
|
当epoch/step达到 milestones 时,将每个参数组的学习率按照乘法因子 gamma 进行变化。 |
|
|
每个epoch,学习率通过多项式拟合来调整。 |
|
|
当指标停止改进时降低学习率。 |
|
|
SequentialLR 接收一个将被顺序调用的学习率调度器列表 schedulers,及指定的间隔列表 milestone,milestone 设定了每个epoch哪个调度器被调用。 |
|
|
每 step_size 个epoch按 gamma 衰减每个参数组的学习率。 |
|