mindspore.nn.TrainOneStepCell

查看源文件
class mindspore.nn.TrainOneStepCell(network, optimizer, sens=None, return_grad=False)[源代码]

训练网络封装类。

封装 networkoptimizer 。构建一个输入'*inputs'的用于训练的Cell。 执行函数 construct 中会构建反向图以更新网络参数。支持不同的并行训练模式。

参数:
  • network (Cell) - 训练网络。只支持单输出网络。

  • optimizer (Union[Cell]) - 用于更新网络参数的优化器。

  • sens (numbers.Number) - 反向传播的输入,缩放系数。默认值为 None ,取 1.0

  • return_grad (bool) - 是否返回梯度,若为 True ,则会在返回loss的同时以字典的形式返回梯度,字典的key为梯度对应的参数名,value为梯度值。默认值为 False

输入:
  • *inputs (Tuple(Tensor)) - shape为 \((N, \ldots)\) 的Tensor组成的元组。

输出:

Tensor,损失函数值,其shape通常为 \(()\)

异常:
  • TypeError - sens 不是numbers.Number。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore.nn as nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> #1) Using the WithLossCell provided by MindSpore
>>> loss_net = nn.WithLossCell(net, loss_fn)
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
>>>
>>> #2) Using user-defined WithLossCell
>>> class MyWithLossCell(nn.Cell):
...    def __init__(self, backbone, loss_fn):
...        super(MyWithLossCell, self).__init__(auto_prefix=False)
...        self._backbone = backbone
...        self._loss_fn = loss_fn
...
...    def construct(self, x, y, label):
...        out = self._backbone(x, y)
...        return self._loss_fn(out, label)
...
...    @property
...    def backbone_network(self):
...        return self._backbone
...
>>> loss_net = MyWithLossCell(net, loss_fn)
>>> train_net = nn.TrainOneStepCell(loss_net, optim)