mindspore.nn.TrainOneStepCell
- class mindspore.nn.TrainOneStepCell(network, optimizer, sens=None, return_grad=False)[源代码]
训练网络封装类。
封装 network 和 optimizer 。构建一个输入'*inputs'的用于训练的Cell。 执行函数 construct 中会构建反向图以更新网络参数。支持不同的并行训练模式。
- 参数:
network (Cell) - 训练网络。只支持单输出网络。
optimizer (Union[Cell]) - 用于更新网络参数的优化器。
sens (numbers.Number) - 反向传播的输入,缩放系数。默认值为
None
,取1.0
。return_grad (bool) - 是否返回梯度,若为
True
,则会在返回loss的同时以字典的形式返回梯度,字典的key为梯度对应的参数名,value为梯度值。默认值为False
。
- 输入:
*inputs (Tuple(Tensor)) - shape为 \((N, \ldots)\) 的Tensor组成的元组。
- 输出:
Tensor,损失函数值,其shape通常为 \(()\) 。
- 异常:
TypeError - sens 不是numbers.Number。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> import mindspore.nn as nn >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/r2.4.10/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits() >>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) >>> #1) Using the WithLossCell provided by MindSpore >>> loss_net = nn.WithLossCell(net, loss_fn) >>> train_net = nn.TrainOneStepCell(loss_net, optim) >>> >>> #2) Using user-defined WithLossCell >>> class MyWithLossCell(nn.Cell): ... def __init__(self, backbone, loss_fn): ... super(MyWithLossCell, self).__init__(auto_prefix=False) ... self._backbone = backbone ... self._loss_fn = loss_fn ... ... def construct(self, x, y, label): ... out = self._backbone(x, y) ... return self._loss_fn(out, label) ... ... @property ... def backbone_network(self): ... return self._backbone ... >>> loss_net = MyWithLossCell(net, loss_fn) >>> train_net = nn.TrainOneStepCell(loss_net, optim)