mindspore.nn.TrainOneStepCell

class mindspore.nn.TrainOneStepCell(network, optimizer, sens=None, return_grad=False)

Network training wrapper class.

Wraps the network with the optimizer. The resulting Cell is trained with the input *inputs. The backward graph is created in the construct function to compute the gradients and update the parameters. Different parallel modes are supported for training.
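Conceptually, each call to the wrapped Cell runs a forward pass to get the loss, a backward pass to get the parameter gradients, and an optimizer update, and then returns the loss. The framework-free sketch below illustrates that sequence for a one-parameter model pred = w * x with squared-error loss and plain SGD. The helpers `loss_fn`, `grad_fn` and `one_train_step` and the dict-based parameter store are hypothetical; they do not reflect MindSpore's internals, which use automatic differentiation and an optimizer Cell.

```python
# Framework-free sketch of one training step (hypothetical helpers, not MindSpore code).
# Model: pred = w * x, loss = (pred - label) ** 2, optimizer: plain SGD.

def loss_fn(params, x, label):
    """Forward pass: return the scalar loss."""
    pred = params["w"] * x
    return (pred - label) ** 2


def grad_fn(params, x, label):
    """Backward pass: analytic gradient of the loss with respect to w."""
    pred = params["w"] * x
    return {"w": 2.0 * (pred - label) * x}


def one_train_step(params, lr, *inputs):
    """Compute the loss and gradients, update params in place, return the loss."""
    loss = loss_fn(params, *inputs)   # loss at the current (pre-update) parameters
    grads = grad_fn(params, *inputs)
    for name, g in grads.items():
        params[name] -= lr * g        # optimizer update
    return loss


params = {"w": 0.0}
loss = one_train_step(params, 0.1, 2.0, 4.0)
print(loss, params["w"])
```

As in the description above, the returned value is the loss computed before the parameters are updated.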

Parameters
  • network (Cell) – The training network. The network must return a single output (typically the loss).

  • optimizer (Union[Cell]) – Optimizer for updating the network parameters.

  • sens (numbers.Number) – The scaling value fed as the initial gradient of backpropagation. Default: None, which means 1.0.

  • return_grad (bool) – Whether to return the gradients. If True, the gradients are returned as a dict together with the loss; each key is the name of a parameter and each value is that parameter's gradient. Default: False.
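Because sens seeds backpropagation, every parameter gradient is scaled by it (the default of 1.0 leaves the gradients unchanged), and with return_grad=True the gradients are returned alongside the loss, keyed by parameter name. The pure-Python sketch below models both behaviours for a two-parameter model; `train_step`, `lr` and the parameter names `w` and `b` are hypothetical, and nothing in the sketch undoes the scaling, so the update size grows with sens.

```python
# Pure-Python model of the `sens` and `return_grad` semantics (hypothetical names).
# Model: pred = w * x + b, loss = (pred - label) ** 2, optimizer: plain SGD.

def train_step(params, x, label, lr=0.1, sens=None, return_grad=False):
    sens = 1.0 if sens is None else sens      # None behaves like 1.0
    pred = params["w"] * x + params["b"]
    loss = (pred - label) ** 2
    dloss = 2.0 * (pred - label)              # d(loss)/d(pred)
    # Backpropagation is seeded with `sens`, so each gradient is scaled by it.
    grads = {"w": sens * dloss * x, "b": sens * dloss}
    for name, g in grads.items():
        params[name] -= lr * g                # plain SGD update
    if return_grad:
        return loss, grads                    # dict: parameter name -> gradient
    return loss


loss, grads = train_step({"w": 1.0, "b": 0.0}, 3.0, 5.0, sens=2.0, return_grad=True)
print(loss, grads)
```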

Inputs:
  • *inputs (Tuple(Tensor)) - Tuple of input tensors with shape \((N, \ldots)\).

Outputs:

Tensor, the loss value, whose shape is usually \(()\). If return_grad is True, the gradient dict described under Parameters is returned together with the loss.

Raises

TypeError – If sens is not a numbers.Number.
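The documented TypeError comes from a type check on sens. A minimal sketch of such a check, assuming a hypothetical helper `resolve_sens` that also maps the default None to 1.0 as described under Parameters (this is an illustration, not MindSpore's actual validation code):

```python
import numbers


def resolve_sens(sens=None):
    """Hypothetical validator: return the effective sens value or raise TypeError."""
    if sens is None:
        return 1.0  # documented default
    if not isinstance(sens, numbers.Number):
        raise TypeError(f"sens must be a numbers.Number, but got {type(sens).__name__}")
    return sens


print(resolve_sens(), resolve_sens(1024))
try:
    resolve_sens("1.0")  # a string is not a numbers.Number
except TypeError as err:
    print("TypeError:", err)
```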

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore.nn as nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> # 1) Using the WithLossCell provided by MindSpore
>>> loss_net = nn.WithLossCell(net, loss_fn)
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
>>>
>>> # 2) Using a user-defined WithLossCell
>>> class MyWithLossCell(nn.Cell):
...    def __init__(self, backbone, loss_fn):
...        super(MyWithLossCell, self).__init__(auto_prefix=False)
...        self._backbone = backbone
...        self._loss_fn = loss_fn
...
...    def construct(self, x, label):
...        out = self._backbone(x)
...        return self._loss_fn(out, label)
...
...    @property
...    def backbone_network(self):
...        return self._backbone
...
>>> loss_net = MyWithLossCell(net, loss_fn)
>>> train_net = nn.TrainOneStepCell(loss_net, optim)