mindspore.nn.TrainOneStepWithLossScaleCell

class mindspore.nn.TrainOneStepWithLossScaleCell(network, optimizer, scale_sense)

Network training with loss scaling.

This is a training step with loss scaling. It takes a network, an optimizer, and a scale_sense argument that is either a Tensor or a loss scale update Cell. The loss scale value can be updated on either the host side or the device side. TrainOneStepWithLossScaleCell is compiled into a graph that takes *inputs as input data. When scale_sense is a Tensor, it acts as the loss scaling value; to update the value on the host side, a Tensor must be provided. When scale_sense is not a Tensor, the loss scale update logic must be provided by a Cell passed as scale_sense.
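To make the idea of a loss-scaled step concrete, here is a minimal plain-Python sketch. It is not MindSpore code: the function `scaled_train_step`, the scalar model `weight * x`, and the learning rate `lr` are assumptions chosen for illustration. It scales the gradient by the loss scale, checks for overflow, and applies the update only when no overflow occurred.

```python
import math


def scaled_train_step(weight, x, target, loss_scale, lr=0.1):
    """One SGD step on loss = (weight * x - target) ** 2 with loss scaling.

    Hypothetical helper for illustration only; not part of MindSpore.
    """
    pred = weight * x
    loss = (pred - target) ** 2
    # Gradient of the scaled loss: d(loss_scale * loss) / d(weight).
    scaled_grad = loss_scale * 2.0 * (pred - target) * x
    # Overflow is detected on the scaled gradient.
    overflow = not math.isfinite(scaled_grad)
    if not overflow:
        # Undo the scaling before the optimizer update.
        grad = scaled_grad / loss_scale
        weight = weight - lr * grad
    # Mirrors the cell's outputs (loss, overflow, loss scale), plus the weight.
    return weight, loss, overflow, loss_scale
```

When the scaled gradient overflows, the weight is returned unchanged, which corresponds to skipping the optimizer update for that step.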

Parameters
  • network (Cell) – The training network. The network only supports single output.

  • optimizer (Cell) – Optimizer for updating the weights.

  • scale_sense (Union[Tensor, Cell]) – If this value is a Cell, it is the loss scaling update logic cell. If this value is a Tensor, it is the loss scaling value, with shape \(()\) or \((1,)\).

Inputs:
  • (*inputs) (Tuple(Tensor)) - Tuple of input tensors with shape \((N, \ldots)\).

Outputs:

Tuple of 3 Tensors: the loss, the overflow flag, and the current loss scaling value.

  • loss (Tensor) - Tensor with shape \(()\).

  • overflow (Tensor) - Tensor with shape \(()\), type is bool.

  • loss scaling value (Tensor) - Tensor with shape \(()\).

Raises
  • TypeError – If scale_sense is neither Cell nor Tensor.

  • ValueError – If shape of scale_sense is neither (1,) nor ().

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, Parameter, nn
>>> from mindspore.ops import operations as P
>>> from mindspore.nn.wrap.cell_wrapper import WithLossCell
>>> from mindspore.common import dtype as mstype
>>>
>>> class Net(nn.Cell):
...     def __init__(self, in_features, out_features):
...         super(Net, self).__init__()
...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
...                                 name='weight')
...         self.matmul = P.MatMul()
...
...     def construct(self, x):
...         output = self.matmul(x, self.weight)
...         return output
...
>>> size, in_features, out_features = 16, 16, 10
>>> #1) when the type of scale_sense is Cell:
>>> net = Net(in_features, out_features)
>>> loss = nn.MSELoss()
>>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> net_with_loss = WithLossCell(net, loss)
>>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
>>> inputs = Tensor(np.ones([out_features, in_features]), mindspore.float32)
>>> labels = Tensor(np.ones([out_features,]), mindspore.float32)
>>> output = train_network(inputs, labels)
>>>
>>> #2) when the type of scale_sense is Tensor:
>>> net = Net(in_features, out_features)
>>> loss = nn.MSELoss()
>>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> net_with_loss = WithLossCell(net, loss)
>>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
>>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
>>> scaling_sens = Tensor(np.full((1,), np.finfo(np.float32).max), dtype=mstype.float32)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scaling_sens)
>>> output = train_network(inputs, label)

get_overflow_status(status, compute_output)

Get floating-point overflow status.

Get overflow results after executing the target process for overflow detection.

Parameters
  • status (object) – A status instance used to detect the overflow.

  • compute_output – Overflow detection should be performed on a certain computation. Set compute_output as the output of the computation, to ensure overflow status is acquired before executing the computation.

Returns

bool, whether the overflow occurs or not.

process_loss_scale(overflow)

Calculate loss scale according to the overflow.

Parameters

overflow (bool) – Whether the overflow occurs or not.

Returns

bool, overflow value.
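
As an illustration of how a loss scale can be recalculated from the overflow flag, the following plain-Python sketch follows the commonly used dynamic loss scaling rule. It is an assumption for illustration, not the code of this method: the class `DynamicScaleSketch`, the counter `good_steps`, and the lower bound of 1.0 are hypothetical. The argument names mirror those passed to `nn.DynamicLossScaleUpdateCell` in the example above.

```python
class DynamicScaleSketch:
    """Hypothetical host-side analogue of a dynamic loss scale update rule."""

    def __init__(self, loss_scale_value=2.0 ** 12, scale_factor=2.0, scale_window=1000):
        self.loss_scale = float(loss_scale_value)
        self.scale_factor = float(scale_factor)
        self.scale_window = int(scale_window)
        self.good_steps = 0  # consecutive steps without overflow

    def update(self, overflow):
        if overflow:
            # Shrink the scale (assumed floor of 1.0) and restart the window.
            self.loss_scale = max(self.loss_scale / self.scale_factor, 1.0)
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps >= self.scale_window:
                # A full window without overflow: grow the scale.
                self.loss_scale *= self.scale_factor
                self.good_steps = 0
        # Like process_loss_scale, hand the overflow flag back to the caller.
        return overflow
```

With `scale_window=3`, for example, one overflow halves the scale, and three clean steps in a row double it again.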

set_sense_scale(sens)

Reassign the loss scaling value. If sens has already been set during training and needs to be changed, call this function again with the new value. sens must be a Tensor.

start_overflow_check(pre_cond, compute_input)

Start floating-point overflow detection. Create and clear the overflow detection state.

Specify the arguments pre_cond and compute_input to make sure the overflow status is cleared at the right time. For example, suppose the state must be cleared after the loss calculation, and overflow must then be detected during gradient calculation. In this case, pre_cond should be the output of the loss function, and compute_input should be the input of the gradient-computing function.

Parameters
  • pre_cond (object) – A precondition for starting overflow detection. It determines the execution order of the overflow state clearing and the preceding operations. It makes sure that start_overflow_check clears the status only after the precondition has been computed.

  • compute_input (object) – The input of the subsequent process. Overflow detection should be performed on a certain computation. Set compute_input as the input of that computation, to ensure the overflow status is cleared before the computation is executed.

Returns

Tuple[object, object]. The first value is False for the GPU backend, and an instance of NPUAllocFloatStatus for other backends. This status is used to detect overflow during overflow detection. The second value is the same as compute_input, but carries some information about the execution order.
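
The intended ordering (clear the status after the loss is computed, run the gradient computation, then read the status) can be sketched in plain Python. Everything here is a hypothetical stand-in, not the MindSpore API: `OverflowStatusSketch`, `start_check`, `record`, `get_status`, and `step` are names invented for illustration. In eager Python the order follows from the statement order; in a compiled graph it has to be enforced through the data dependencies that pre_cond, compute_input, and compute_output provide.

```python
import math


class OverflowStatusSketch:
    """Hypothetical stand-in for a device-side overflow flag."""

    def __init__(self):
        self.flag = False  # a freshly created status is cleared


def start_check(pre_cond, compute_input):
    # pre_cond has already been evaluated, so clearing happens after it.
    status = OverflowStatusSketch()
    return status, compute_input


def record(status, values):
    # The computation under detection: flag any non-finite result.
    if any(not math.isfinite(v) for v in values):
        status.flag = True
    return values


def get_status(status, compute_output):
    # compute_output is already available, so the computation has finished.
    return status.flag


def step(x, scale):
    loss = x * x                              # loss calculation (pre_cond)
    status, scale = start_check(loss, scale)  # clear the status after the loss
    grads = record(status, [scale * 2.0 * x]) # scaled gradient computation
    return loss, get_status(status, grads)    # read the status afterwards
```

Calling `step` with a very large `scale` makes the scaled gradient non-finite, so the returned flag is True.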