mindspore.amp.auto_mixed_precision

mindspore.amp.auto_mixed_precision(network, amp_level='O0', dtype=mstype.float16)

Returns a network processed with auto mixed precision.

This interface automatically applies mixed-precision processing to the input network. Precision conversion operations are added to the cells and operators in the processed network so that they compute in lower precision: mstype.float16 or mstype.bfloat16. The inputs and parameters of these cells and operators are cast to the lower-precision float type, and the calculation results are cast back to full precision, i.e. mstype.float32.
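The cast-in, compute, cast-out pattern described above can be illustrated without MindSpore. The following is a minimal sketch using only the Python standard library; the helper names `to_float16`, `to_float32` and `low_precision_call` are hypothetical and are not part of the MindSpore API. It only mimics the rounding behaviour of a float16 computation and is not how the framework inserts its conversion operations.

```python
import struct


def to_float16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def to_float32(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def low_precision_call(fn, *inputs):
    """Hypothetical wrapper: cast inputs down, compute, cast the result back up.

    The low-precision computation is approximated by rounding the inputs and
    the result to float16; the final cast to float32 is then exact.
    """
    cast_inputs = [to_float16(v) for v in inputs]
    result = to_float16(fn(*cast_inputs))
    return to_float32(result)


# 0.1 is not exactly representable in float16, so the cast loses precision.
print(to_float16(0.1))
# 0.5 * 0.25 is exactly representable, so the round trip is lossless.
print(low_precision_call(lambda a, b: a * b, 0.5, 0.25))
```

The point of the sketch is that the value returned to the caller is a full-precision number again, but it carries only the accuracy of the lower-precision computation.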

The framework has a built-in whitelist and a built-in blacklist, and amp_level determines which cells and operators are converted.

The current built-in whitelist contents are:

[mindspore.nn.Conv1d, mindspore.nn.Conv2d, mindspore.nn.Conv3d, mindspore.nn.Conv1dTranspose, mindspore.nn.Conv2dTranspose, mindspore.nn.Conv3dTranspose, mindspore.nn.Dense, mindspore.nn.LSTMCell, mindspore.nn.RNNCell, mindspore.nn.GRUCell, mindspore.ops.Conv2D, mindspore.ops.Conv3D, mindspore.ops.Conv2DTranspose, mindspore.ops.Conv3DTranspose, mindspore.ops.MatMul, mindspore.ops.BatchMatMul, mindspore.ops.PReLU, mindspore.ops.ReLU, mindspore.ops.Ger]

The current built-in blacklist contents are:

[mindspore.nn.BatchNorm1d, mindspore.nn.BatchNorm2d, mindspore.nn.BatchNorm3d, mindspore.nn.LayerNorm]
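How amp_level combines with the two lists can be sketched in plain Python. This is only an illustration of the selection rule documented for the "O0" to "O3" levels, not the framework's implementation: the function `runs_in_lower_precision` is hypothetical, the lists are abbreviated to a few names from the built-in lists above, and `Softmax` stands in for any cell that appears in neither list.

```python
# Abbreviated copies of the built-in lists, using bare class names.
WHITELIST = {"Conv2d", "Dense", "MatMul"}
BLACKLIST = {"BatchNorm2d", "LayerNorm"}


def runs_in_lower_precision(name: str, amp_level: str) -> bool:
    """Return True if a cell/operator would be converted at this level."""
    if amp_level == "O0":
        return False                  # nothing is changed
    if amp_level == "O1":
        return name in WHITELIST      # only whitelisted entries are converted
    if amp_level == "O2":
        return name not in BLACKLIST  # everything except the blacklist
    if amp_level == "O3":
        return True                   # the whole network is cast
    raise ValueError(f"unsupported amp_level: {amp_level!r}")


layers = ["Conv2d", "BatchNorm2d", "Softmax", "Dense"]
for level in ("O0", "O1", "O2", "O3"):
    converted = [n for n in layers if runs_in_lower_precision(n, level)]
    print(level, converted)
```

The difference between "O1" and "O2" shows up for cells that are in neither list: under "O1" they stay in full precision, under "O2" they are converted.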

For details on automatic mixed precision, refer to Automatic Mix Precision.

Note

  • Repeatedly calling mixed-precision interfaces, such as custom_mixed_precision and auto_mixed_precision, can result in a larger network hierarchy and slower performance.

  • If interfaces such as Model and build_train_network are used to train a network that has already been converted by a mixed-precision interface such as custom_mixed_precision or auto_mixed_precision, their amp_level needs to be set to O0 to avoid duplicated precision conversion.

Parameters
  • network (Cell) – Definition of the network.

  • amp_level (str) –

    Supports ["O0", "O1", "O2", "O3"]. Default: "O0".

    • "O0": Do not change.

    • "O1": Convert cells and operators in the whitelist to lower precision operations, and keep full precision operations for the rest.

    • "O2": Keep full precision operations for cells and operators in the blacklist, and convert the rest to lower precision operations.

    • "O3": Cast the network to lower precision.

  • dtype (Type) – The type used in lower precision calculations. Can be mstype.float16 or mstype.bfloat16. Default: mstype.float16.

Raises
  • TypeError – If network is not a Cell.

  • ValueError – If dtype is neither mstype.float16 nor mstype.bfloat16.

  • ValueError – If amp_level is not one of the supported values.
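The three error conditions above can be mirrored in a small stand-alone check, for example to validate configuration values before a network is built. This is a sketch under stated assumptions: `check_amp_args` and the stand-in `Cell` class are hypothetical, dtypes are represented by plain strings, and the order in which the real interface performs its checks is not documented here.

```python
SUPPORTED_LEVELS = ("O0", "O1", "O2", "O3")
SUPPORTED_DTYPES = ("float16", "bfloat16")


class Cell:
    """Stand-in for mindspore.nn.Cell so the sketch runs without MindSpore."""


def check_amp_args(network, amp_level="O0", dtype="float16"):
    """Hypothetical pre-check mirroring the documented Raises section."""
    if not isinstance(network, Cell):
        raise TypeError("network must be a Cell")
    if dtype not in SUPPORTED_DTYPES:
        raise ValueError(f"dtype must be one of {SUPPORTED_DTYPES}, got {dtype!r}")
    if amp_level not in SUPPORTED_LEVELS:
        raise ValueError(f"amp_level must be one of {SUPPORTED_LEVELS}, got {amp_level!r}")


check_amp_args(Cell(), "O1", "bfloat16")  # valid arguments: no exception

for bad_args in [(object(),), (Cell(), "O4"), (Cell(), "O1", "float64")]:
    try:
        check_amp_args(*bad_args)
    except (TypeError, ValueError) as err:
        print(type(err).__name__, err)
```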

Examples

>>> from mindspore import amp
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> network = LeNet5()
>>> amp_level = "O1"
>>> net = amp.auto_mixed_precision(network, amp_level)