mindspore.nn.TrainOneStepCell
- class mindspore.nn.TrainOneStepCell(network, optimizer, sens=None, return_grad=False)
Network training wrapper class.
Wraps the network together with the optimizer. The resulting Cell is trained with the inputs ‘*inputs’. The backward graph is created in the construct function, and the resulting gradients are passed to the optimizer to update the parameters. Different parallel modes are supported for training.
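To make the step described above concrete, here is a framework-free sketch in plain Python of what one training step does conceptually: a forward pass to get the loss, a backward pass whose gradients are scaled by sens, and an optimizer update. The names loss_and_grads, sgd_update, and train_one_step, the linear model, and the plain SGD rule are hypothetical illustrations, not MindSpore APIs; in MindSpore the gradients come from automatic differentiation rather than hand-derived formulas.

```python
# Framework-free sketch of one training step (not MindSpore's implementation).
# Hypothetical model: prediction = w * x + b, loss = mean squared error.

def loss_and_grads(params, xs, ys, sens=1.0):
    """Forward pass plus a hand-derived backward pass.

    `sens` acts as the initial gradient fed into backpropagation, so every
    gradient is scaled by it; the returned loss itself is not scaled.
    """
    w, b = params["w"], params["b"]
    n = len(xs)
    residuals = [w * x + b - y for x, y in zip(xs, ys)]
    loss = sum(r * r for r in residuals) / n
    grad_w = sens * sum(2 * r * x for r, x in zip(residuals, xs)) / n
    grad_b = sens * sum(2 * r for r in residuals) / n
    return loss, {"w": grad_w, "b": grad_b}


def sgd_update(params, grads, learning_rate):
    """Plain SGD, standing in for the optimizer Cell."""
    for name, grad in grads.items():
        params[name] -= learning_rate * grad


def train_one_step(params, xs, ys, learning_rate=0.1, sens=1.0):
    """Forward, backward, update; return the loss computed before the update."""
    loss, grads = loss_and_grads(params, xs, ys, sens)
    sgd_update(params, grads, learning_rate)
    return loss


params = {"w": 0.0, "b": 0.0}
xs, ys = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]
first = train_one_step(params, xs, ys)
second = train_one_step(params, xs, ys)
print(first, second)
```

Returning the pre-update loss matches the documented output: each call yields the loss of the step that was just trained on.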
- Parameters
network (Cell) – The training network. The network supports only a single output.
optimizer (Union[Cell]) – Optimizer for updating the network parameters.
sens (numbers.Number) – The scaling number used to fill the input of backpropagation, that is, the initial gradient of the loss. Default: None, which means 1.0.
return_grad (bool) – Whether to return the gradients. If True, the gradients are returned as a dict together with the loss; each key is the name of a parameter and each value is that parameter's gradient. Default: False.
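The parameter rules for sens and return_grad, together with the TypeError listed under Raises, can be summarized in a short plain-Python sketch. The helpers resolve_sens and pack_step_output and the sample parameter names are hypothetical and only mirror the documented behavior; this is not MindSpore's source, and in MindSpore the loss and gradient values would be Tensors rather than floats.

```python
import numbers


def resolve_sens(sens=None):
    """Hypothetical helper mirroring the documented rules for `sens`.

    None falls back to 1.0; any other value must be a numbers.Number.
    """
    if sens is None:
        return 1.0
    if not isinstance(sens, numbers.Number):
        raise TypeError(f"sens must be a numbers.Number, but got {type(sens).__name__}")
    return sens


def pack_step_output(loss, grads, parameter_names, return_grad=False):
    """Return the loss alone, or (loss, {parameter_name: gradient})."""
    if not return_grad:
        return loss
    grad_dict = dict(zip(parameter_names, grads))
    return loss, grad_dict


names = ["conv1.weight", "fc1.bias"]  # hypothetical parameter names
print(resolve_sens())  # None falls back to the default scale
print(pack_step_output(0.25, (0.1, -0.2), names, return_grad=True))
```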
- Inputs:
*inputs (Tuple(Tensor)) - Tuple of input tensors with shape \((N, \ldots)\).
- Outputs:
Tensor, the loss value; its shape is usually \(()\), i.e. a scalar.
- Raises
TypeError – If sens is not a numbers.Number.
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import mindspore.nn as nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> # sparse=True accepts integer class labels; reduction='mean' yields a scalar loss.
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> # 1) Using the WithLossCell provided by MindSpore
>>> loss_net = nn.WithLossCell(net, loss_fn)
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
>>>
>>> # 2) Using a user-defined loss cell
>>> class MyWithLossCell(nn.Cell):
...     def __init__(self, backbone, loss_fn):
...         super(MyWithLossCell, self).__init__(auto_prefix=False)
...         self._backbone = backbone
...         self._loss_fn = loss_fn
...
...     def construct(self, data, label):
...         # LeNet5 takes a single input tensor, so only `data` is forwarded.
...         out = self._backbone(data)
...         return self._loss_fn(out, label)
...
...     @property
...     def backbone_network(self):
...         return self._backbone
...
>>> loss_net = MyWithLossCell(net, loss_fn)
>>> train_net = nn.TrainOneStepCell(loss_net, optim)
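The example composes two wrappers: a loss cell around the backbone, then TrainOneStepCell around the loss cell and the optimizer, with ‘*inputs’ forwarded unchanged to the loss cell. The plain-Python sketch below mimics that composition and calls the train wrapper in a loop. Backbone, LossWrapper, and TrainStep are hypothetical stand-ins for LeNet5, WithLossCell, and TrainOneStepCell; gradients are derived by hand instead of by automatic differentiation, and plain SGD replaces nn.Momentum.

```python
# Hypothetical plain-Python analogue of the wrapping pattern:
# backbone -> loss wrapper (like WithLossCell) -> train wrapper (like TrainOneStepCell).

class Backbone:
    """Hypothetical one-parameter model: prediction = w * x."""

    def __init__(self):
        self.params = {"w": 0.0}

    def __call__(self, x):
        return self.params["w"] * x


class LossWrapper:
    """Like a loss cell: takes (data, label) and returns a single scalar loss."""

    def __init__(self, backbone):
        self.backbone = backbone

    def __call__(self, data, label):
        residuals = [self.backbone(x) - y for x, y in zip(data, label)]
        return sum(r * r for r in residuals) / len(residuals)

    def grads(self, data, label):
        # Hand-derived d(loss)/dw; MindSpore would use automatic differentiation.
        w = self.backbone.params["w"]
        n = len(data)
        return {"w": sum(2 * (w * x - y) * x for x, y in zip(data, label)) / n}


class TrainStep:
    """Like a train cell: forwards *inputs to the loss wrapper, then updates params."""

    def __init__(self, loss_net, learning_rate=0.05):
        self.loss_net = loss_net
        self.learning_rate = learning_rate

    def __call__(self, *inputs):
        loss = self.loss_net(*inputs)
        for name, grad in self.loss_net.grads(*inputs).items():
            self.loss_net.backbone.params[name] -= self.learning_rate * grad
        return loss


net = Backbone()
train_net = TrainStep(LossWrapper(net))
data, label = [1.0, 2.0, 3.0], [3.0, 6.0, 9.0]
losses = [train_net(data, label) for _ in range(50)]
print(losses[0], losses[-1], net.params["w"])
```

With the real classes, the corresponding loop would call train_net(data, label) once per batch, and each call returns the loss for that batch.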