mindspore.nn.TrainOneStepWithLossScaleCell
- class mindspore.nn.TrainOneStepWithLossScaleCell(network, optimizer, scale_sense)
Network training with loss scaling.
This is a training step with loss scaling. It takes a network, an optimizer and a loss-scale update Cell (or a Tensor) as arguments. The loss scale can be updated on either the host side or the device side. To update it on the host side, pass a Tensor as scale_sense; otherwise, pass a Cell instance that updates the loss scale as scale_sense.
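The mechanics of one loss-scaled step can be illustrated with a pure-Python sketch on a scalar model. This is not MindSpore code: the function loss_scaled_step, the scalar loss (w * x - y) ** 2 and the plain SGD update are hypothetical stand-ins. The sketch computes the gradient of the scaled loss, divides the scale back out, and skips the parameter update when the unscaled gradient is not finite. Because Python floats are float64, overflow only shows up here for very large scales, whereas in float16 training it appears much sooner.

```python
import math


def loss_scaled_step(w, x, y, scale, lr=0.1):
    """Hypothetical loss-scaled SGD step on the scalar loss (w * x - y) ** 2.

    Returns (new_w, loss, overflow, scale); the last three mirror the
    (loss, overflow flag, loss scale) outputs of TrainOneStepWithLossScaleCell.
    """
    loss = (w * x - y) ** 2
    # Gradient of the *scaled* loss: d(scale * loss)/dw.
    scaled_grad = scale * 2.0 * (w * x - y) * x
    # Unscale before the optimizer step.
    grad = scaled_grad / scale
    overflow = not math.isfinite(grad)
    if not overflow:
        w = w - lr * grad  # apply the update only when no overflow occurred
    return w, loss, overflow, scale


w, loss, overflow, scale = loss_scaled_step(1.0, 2.0, 3.0, scale=1024.0)
print(loss, overflow)  # 1.0 False
```

With a Tensor scale_sense the scale stays fixed in this way between steps, and the host decides when to change it.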
- Parameters
network (Cell) – The training network. The network must have a single output.
optimizer (Cell) – Optimizer for updating the network parameters.
scale_sense (Union[Tensor, Cell]) – If this value is a Cell, TrainOneStepWithLossScaleCell calls it to update the loss scale. If this value is a Tensor, the loss scale can be modified with set_sense_scale, and its shape must be \(()\) or \((1,)\).
- Inputs:
*inputs (Tuple(Tensor)) - Tuple of input tensors with shape \((N, \ldots)\).
- Outputs:
Tuple of 3 Tensors: the loss, the overflow flag and the current loss scale value.
loss (Tensor) - A scalar, the loss value.
overflow (Tensor) - A scalar of type bool indicating whether an overflow occurred.
loss scale (Tensor) - The loss scale value, the shape is \(()\) or \((1,)\).
- Raises
TypeError – If scale_sense is neither Cell nor Tensor.
ValueError – If shape of scale_sense is neither \((1,)\) nor \(()\).
- Supported Platforms:
Ascend
GPU
Examples
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, Parameter, nn, ops
>>>
>>> class Net(nn.Cell):
...     def __init__(self, in_features, out_features):
...         super(Net, self).__init__()
...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
...                                 name='weight')
...         self.matmul = ops.MatMul()
...
...     def construct(self, x):
...         output = self.matmul(x, self.weight)
...         return output
...
>>> size, in_features, out_features = 16, 16, 10
>>> # 1) when the type of scale_sense is Cell:
>>> net = Net(in_features, out_features)
>>> loss = nn.MSELoss()
>>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> net_with_loss = nn.WithLossCell(net, loss)
>>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
>>> input = Tensor(np.ones([out_features, in_features]), mindspore.float32)
>>> labels = Tensor(np.ones([out_features,]), mindspore.float32)
>>> output = train_network(input, labels)
>>>
>>> # 2) when the type of scale_sense is Tensor:
>>> net = Net(in_features, out_features)
>>> loss = nn.MSELoss()
>>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> net_with_loss = nn.WithLossCell(net, loss)
>>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
>>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
>>> scaling_sens = Tensor([1024], dtype=mindspore.float32)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scaling_sens)
>>> output = train_network(inputs, label)
>>>
>>> # update scaling sens and train the network
>>> scaling_sens = Tensor([1], dtype=mindspore.float32)
>>> train_network.set_sense_scale(scaling_sens)
>>> output = train_network(inputs, label)
- get_overflow_status(status, compute_output)
Get floating-point overflow status.
Get the overflow result after the computation monitored for overflow has executed. User-defined training networks derived from this class can also call this method to handle overflow.
- Parameters
status (object) – Controls the execution order relative to start_overflow_check; set it to the first output of start_overflow_check.
compute_output – Overflow detection is performed on a specific computation; set compute_output to the output of that computation.
- Returns
bool, whether an overflow occurred.
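Conceptually, the overflow status answers whether any gradient value is infinite or NaN. A host-side sketch follows, assuming gradients are given as nested lists or tuples of Python floats; the name grads_overflow is hypothetical, and the device-side mechanism MindSpore actually uses on Ascend and GPU is not reproduced here.

```python
import math


def grads_overflow(grads):
    """Hypothetical host-side check: True if any gradient value is inf or NaN.

    `grads` is a nested list/tuple of floats standing in for gradient tensors.
    """
    for g in grads:
        if isinstance(g, (list, tuple)):
            # Recurse into nested "tensors".
            if grads_overflow(g):
                return True
        elif not math.isfinite(g):
            return True
    return False
```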
- process_loss_scale(overflow)
Calculate the loss scale according to the overflow status.
User-defined training networks derived from this class can also call this method to handle overflow.
- Parameters
overflow (bool) – Whether an overflow occurred.
- Returns
bool, the input overflow value.
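A dynamic loss scale is typically updated from this overflow flag: shrink the scale after an overflow and grow it after a run of overflow-free steps. The class below is a hypothetical sketch of that common scheme. Its parameter names mirror the nn.DynamicLossScaleUpdateCell arguments used in the example, but it is not a transcription of MindSpore's implementation, and the floor of 1 on the scale is an assumption of the sketch.

```python
class DynamicLossScale:
    """Hypothetical host-side dynamic loss-scale policy (illustration only)."""

    def __init__(self, loss_scale_value, scale_factor, scale_window):
        self.scale = float(loss_scale_value)
        self.factor = scale_factor
        self.window = scale_window
        self.good_steps = 0  # consecutive steps without overflow

    def update(self, overflow):
        if overflow:
            # Shrink the scale, but never below 1 (assumption of this sketch).
            self.scale = max(self.scale / self.factor, 1.0)
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps >= self.window:
                # Enough clean steps in a row: try a larger scale.
                self.scale *= self.factor
                self.good_steps = 0
        # Like process_loss_scale, hand the overflow flag back unchanged.
        return overflow
```

Growing only after scale_window clean steps keeps the scale from oscillating, while halving immediately on overflow limits how many steps are skipped.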
- set_sense_scale(sens)
If scale_sense was set to a Tensor, call this method to reassign its value.
- Parameters
sens (Tensor) – The new loss scale, with the same shape and dtype as the original scale_sense.
- start_overflow_check(pre_cond, compute_input)
Start floating-point overflow detection. Create and clear the overflow detection state.
Specify the arguments ‘pre_cond’ and ‘compute_input’ to make sure the overflow status is cleared at the right time. For example, to clear the status after the loss is computed and then detect overflow during gradient computation, pre_cond should be the output of the loss function, and compute_input should be the input of the gradient-computing function. User-defined training networks derived from this class can also call this method to handle overflow.
- Parameters
pre_cond (Tensor) – A precondition for starting overflow detection. It determines the execution order between clearing the overflow status and the preceding operations, ensuring that start_overflow_check clears the status only after the precondition has been computed.
compute_input (object) – The input of the subsequent computation on which overflow detection is performed. Set compute_input to the input of that computation to ensure the overflow status is cleared before it executes.
- Returns
Tuple[object, object]. The first output controls the execution order: it ensures that start_overflow_check is executed before get_overflow_status even after compilation optimizations, and it should be passed as the first input of get_overflow_status. The second output is the same as compute_input; it also controls the execution order and ensures that the overflow flag has been cleared when the function returns.
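The ordering requirement can be understood by treating the overflow status as a sticky flag: once set, it stays set until explicitly cleared. The sketch below models this with a hypothetical StickyOverflowFlag class and a hypothetical scaled_grads function in plain Python (neither is a MindSpore API), to show why the clear must come after earlier work such as the loss computation and before the gradient computation being checked.

```python
import math


class StickyOverflowFlag:
    """Hypothetical stand-in for a device-wide overflow status (illustration only)."""

    def __init__(self):
        self.flag = False

    def clear(self):
        # Analogue of the clearing done by start_overflow_check.
        self.flag = False

    def record(self, value):
        # Any non-finite result sets the flag; only clear() resets it.
        if not math.isfinite(value):
            self.flag = True
        return value

    def get(self):
        # Analogue of reading the status in get_overflow_status.
        return self.flag


def scaled_grads(values, scale, status):
    """Hypothetical gradient computation that reports every result to `status`."""
    return [status.record(v * scale) for v in values]


status = StickyOverflowFlag()
status.record(float("inf"))  # stale overflow left over from earlier work
status.clear()               # clear after that work, before the gradients
grads = scaled_grads([1.0, 2.0], 1024.0, status)
print(status.get())          # False: only the gradient computation was checked
```

Without the clear() call, the stale flag would make a clean gradient computation look like an overflow.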