mindspore.nn.TrainOneStepWithLossScaleCell

class mindspore.nn.TrainOneStepWithLossScaleCell(network, optimizer, scale_sense)

Network training with loss scaling.

This is a training step with loss scaling. It takes a network, an optimizer, and a loss scale update Cell (or a Tensor) as arguments. The loss scale value can be updated on either the host side or the device side. To update it on the host side, pass a Tensor as scale_sense. Otherwise, pass a Cell instance that updates the loss scale as scale_sense.
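Loss scaling exists because small gradients underflow to zero in float16. The sketch below emulates float16 storage in plain Python with the standard library's half-precision `struct` format `'e'`. It is an illustration of the general technique, not MindSpore code. The helper names `to_fp16` and `scaled_gradient`, the gradient value `1e-8`, and the scale `1024` are all assumptions chosen for the demonstration.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest float16 value; values too large become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses to pack out-of-range values; mimic fp16 overflow to inf.
        return math.copysign(math.inf, x)


def scaled_gradient(true_grad: float, loss_scale: float) -> float:
    """Store a scaled gradient in fp16, then unscale it in full precision."""
    stored = to_fp16(true_grad * loss_scale)  # gradient as held in fp16
    return stored / loss_scale                # unscaled before the update


small_grad = 1e-8
print(scaled_gradient(small_grad, 1.0))     # underflows to 0.0 in fp16
print(scaled_gradient(small_grad, 1024.0))  # survives, close to 1e-8
print(to_fp16(1e5))                         # too large for fp16, becomes inf
```

The last line shows the opposite risk: if the scale is too large, scaled values overflow to inf, which is why the cell also reports an overflow flag.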

Parameters
  • network (Cell) – The training network. The network only supports single output.

  • optimizer (Cell) – Optimizer for updating the network parameters.

  • scale_sense (Union[Tensor, Cell]) – If this value is a Cell, it is called by TrainOneStepWithLossScaleCell to update the loss scale. If this value is a Tensor, the loss scale can be modified with set_sense_scale, and its shape must be () or (1,).

Inputs:
  • *inputs (Tuple(Tensor)) - Tuple of input tensors with shape (N, ...).

Outputs:

Tuple of 3 Tensors: the loss, the overflow flag, and the current loss scale value.

  • loss (Tensor) - A scalar, the loss value.

  • overflow (Tensor) - A scalar of type bool, indicating whether overflow occurred.

  • loss scale (Tensor) - The loss scale value, the shape is () or (1,).

Raises
  • TypeError – If scale_sense is neither Cell nor Tensor.

  • ValueError – If shape of scale_sense is neither (1,) nor ().
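The documented contract for scale_sense (Cell or Tensor, and a Tensor of shape () or (1,)) can be summarized by the sketch below. It is not MindSpore's source: `FakeTensor`, `FakeCell`, and `check_scale_sense` are hypothetical stand-ins so the checks run without MindSpore installed.

```python
class FakeTensor:
    """Hypothetical stand-in for mindspore.Tensor: only carries a shape."""

    def __init__(self, shape):
        self.shape = tuple(shape)


class FakeCell:
    """Hypothetical stand-in for an nn.Cell that updates the loss scale."""


def check_scale_sense(scale_sense):
    """Mirror the documented Raises section for scale_sense (sketch only)."""
    if isinstance(scale_sense, FakeCell):
        return "cell"  # loss scale is updated by calling the Cell
    if isinstance(scale_sense, FakeTensor):
        if scale_sense.shape not in [(), (1,)]:
            raise ValueError(
                f"scale_sense shape must be () or (1,), got {scale_sense.shape}")
        return "tensor"  # loss scale is changed later with set_sense_scale
    raise TypeError(
        f"scale_sense must be a Cell or a Tensor, got {type(scale_sense).__name__}")
```

A plain Python float such as `1024.0` is rejected with TypeError under this sketch, matching the documented behavior for values that are neither a Cell nor a Tensor.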

Supported Platforms:

Ascend GPU

Examples

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, Parameter, nn, ops
>>>
>>> class Net(nn.Cell):
...     def __init__(self, in_features, out_features):
...         super(Net, self).__init__()
...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
...                                 name='weight')
...         self.matmul = ops.MatMul()
...
...     def construct(self, x):
...         output = self.matmul(x, self.weight)
...         return output
...
>>> size, in_features, out_features = 16, 16, 10
>>> #1) when the type of scale_sense is Cell:
>>> net = Net(in_features, out_features)
>>> loss_fn = nn.MSELoss()
>>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> net_with_loss = nn.WithLossCell(net, loss_fn)
>>> input = Tensor(np.ones([out_features, in_features]), mindspore.float32)
>>> labels = Tensor(np.ones([out_features,]), mindspore.float32)
>>> loss = net_with_loss(input, labels)
>>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
>>> status = Tensor([0] * 8, mindspore.int32)
>>> scaling_sens = train_network.scale_sense
>>> scaling_sens_filled = ops.ones_like(loss) * ops.cast(scaling_sens, ops.dtype(loss))
>>> grads = train_network.grad(train_network.network, train_network.weights)(input, labels, scaling_sens_filled)
>>> grads = train_network.grad_reducer(grads)
>>> cond = train_network.get_overflow_status(status, grads)
>>> overflow = train_network.process_loss_scale(cond)
>>>
>>> #2) when the type of scale_sense is Tensor:
>>> net = Net(in_features, out_features)
>>> loss_fn = nn.MSELoss()
>>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> net_with_loss = nn.WithLossCell(net, loss_fn)
>>> inputs = Tensor(np.ones([size, in_features]).astype(np.float32))
>>> label = Tensor(np.zeros([size, out_features]).astype(np.float32))
>>> scaling_sens = Tensor([1024], dtype=mindspore.float32)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=scaling_sens)
>>> output = train_network(inputs, label)
>>>
>>> # update scaling sens and train the network
>>> scaling_sens = Tensor([1], dtype=mindspore.float32)
>>> train_network.set_sense_scale(scaling_sens)
>>> output = train_network(inputs, label)
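In the first example, `manager` updates the loss scale on the device. The plain-Python sketch below approximates the widely used dynamic rule that `nn.DynamicLossScaleUpdateCell` is assumed to follow: shrink the scale on overflow (never below 1) and grow it after `scale_window` consecutive clean steps. The class name `DynamicLossScale` is hypothetical, and MindSpore's exact bookkeeping may differ; the defaults simply mirror the example's `loss_scale_value=2**12, scale_factor=2, scale_window=1000`.

```python
class DynamicLossScale:
    """Common dynamic loss-scaling rule (assumed semantics, not MindSpore source)."""

    def __init__(self, loss_scale_value=2.0**12, scale_factor=2.0, scale_window=1000):
        self.loss_scale = loss_scale_value
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.good_steps = 0  # consecutive steps without overflow

    def update(self, overflow: bool) -> bool:
        if overflow:
            # Shrink the scale (never below 1) and restart the clean-step count.
            self.loss_scale = max(self.loss_scale / self.scale_factor, 1.0)
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps >= self.scale_window:
                # A long run without overflow: try a larger scale.
                self.loss_scale *= self.scale_factor
                self.good_steps = 0
        return overflow


scale = DynamicLossScale()
scale.update(overflow=True)  # 4096.0 is halved to 2048.0
```

Halving on every overflow reacts quickly to instability, while growing only after a long clean window keeps the scale from oscillating.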

get_overflow_status(status, compute_output)

Get floating-point overflow status.

Get the overflow result after executing the target process for overflow detection. User-defined training networks based on this class can also call this method to process overflow.

Parameters
  • status (object) – Controls the execution order relative to start_overflow_check. It should be set to the first output of start_overflow_check.

  • compute_output – Overflow detection is performed on a particular computation. Set compute_output to the output of that computation.

Returns

bool, whether overflow occurred.
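On the device, the overflow status is collected by the framework and the implementation is not shown here. A host-side equivalent of the same check, sketched in plain Python under the hypothetical name `has_overflow`, treats any inf or NaN element in the gradients as an overflow.

```python
import math
from typing import Iterable, Sequence


def has_overflow(grads: Iterable[Sequence[float]]) -> bool:
    """Return True if any gradient element is inf or NaN (host-side analogue)."""
    return any(not math.isfinite(g) for grad in grads for g in grad)


print(has_overflow([[1.0, 2.0], [3.0]]))           # all finite
print(has_overflow([[1.0, float("inf")], [3.0]]))  # one inf element
```

Checking NaN as well as inf matters because an inf in one layer often turns into NaN downstream (for example `inf - inf`).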

process_loss_scale(overflow)

Calculate the loss scale according to the overflow status.

User-defined training networks based on this class can also call this method to process overflow.

Parameters

overflow (bool) – Whether overflow occurred.

Returns

bool, the input overflow value.
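Because process_loss_scale returns the overflow flag unchanged, the caller decides what to do with it. TrainOneStepWithLossScaleCell is understood to skip the optimizer call when overflow is set, so the corrupted gradients of that step are discarded. The plain-SGD sketch below shows that pattern; `apply_sgd_if_finite` and the learning rate are hypothetical, not MindSpore APIs.

```python
from typing import List


def apply_sgd_if_finite(params: List[float], grads: List[float],
                        overflow: bool, lr: float = 0.1) -> List[float]:
    """Skip the parameter update when the step overflowed (hypothetical helper)."""
    if overflow:
        return list(params)  # keep the old parameters; the step is discarded
    return [p - lr * g for p, g in zip(params, grads)]


params = apply_sgd_if_finite([1.0, 2.0], [1.0, 1.0], overflow=False)
```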

set_sense_scale(sens)

If scale_sense was set as a Tensor, call this method to reassign its value.

Parameters

sens (Tensor) – The new loss scale value. Its shape and type must be the same as those of the original scale_sense.

start_overflow_check(pre_cond, compute_input)

Start floating-point overflow detection. Create and clear the overflow detection state.

Specify the arguments pre_cond and compute_input to make sure the overflow status is cleared at the right time. For example, suppose the state must be cleared after the loss is calculated, and overflow must then be detected during gradient calculation. In this case, pre_cond should be the output of the loss function, and compute_input should be the input of the gradient computation. User-defined training networks based on this class can also call this method to process overflow.

Parameters
  • pre_cond (Tensor) – A precondition for starting overflow detection. It determines the execution order of overflow state clearing and the preceding computation, ensuring that start_overflow_check clears the status only after the precondition has been computed.

  • compute_input (object) – The input of the subsequent computation on which overflow detection is performed. Set compute_input to the input of that computation to ensure the overflow status is cleared before the computation executes.

Returns

Tuple[object, object]. The first output controls the execution order: it ensures that start_overflow_check is executed before get_overflow_status even after compilation optimizations are applied, and it should be used as the first input of get_overflow_status. The second output is the same as compute_input; it is also used to control the execution order and ensures that the overflow flag is cleared when the function returns.
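In graph mode, operations without data dependencies may be reordered, which is why start_overflow_check returns values that must be threaded into later calls. The eager-Python sketch below only illustrates the intended order (clear after the loss exists, run the watched computation, then read the flag); Python statements already run in order, so it does not reproduce the dependency mechanism itself. `OverflowStatus` and `checked_step` are hypothetical names.

```python
import math
from typing import Callable, List, Tuple


class OverflowStatus:
    """Hypothetical overflow register: cleared at start, set by watched ops."""

    def __init__(self) -> None:
        self.flag = False

    def clear(self) -> None:
        self.flag = False

    def record(self, values: List[float]) -> None:
        if any(not math.isfinite(v) for v in values):
            self.flag = True


def checked_step(status: OverflowStatus, loss: float,
                 compute: Callable[[float], List[float]]) -> Tuple[List[float], bool]:
    # 1) Like start_overflow_check: clear only once the loss (pre_cond) exists.
    status.clear()
    # 2) The watched computation (e.g. gradients) runs after the clear.
    outputs = compute(loss)
    status.record(outputs)
    # 3) Like get_overflow_status: read the flag after the computation finished.
    return outputs, status.flag
```

Clearing first matters because a flag left over from a previous step would otherwise be reported as a new overflow.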