mindspore.amp.auto_mixed_precision

mindspore.amp.auto_mixed_precision(network, amp_level='O0', dtype=mstype.float16)

Returns a network processed with auto mixed precision.

This interface automatically applies mixed-precision processing to the input network. Cells and operators in the processed network get precision-conversion operations so that they compute in a lower-precision type, mstype.float16 or mstype.bfloat16. Inputs and parameters of these cells and operators are converted to the lower-precision float type, and calculation results are converted back to full precision, i.e. mstype.float32.
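The following stdlib-only Python sketch illustrates the numeric effect of this cast-down, compute, cast-up pattern; it is not MindSpore code. It assumes IEEE 754 half precision as a stand-in for mstype.float16 (emulated with the `struct` module's `'e'` format) and a Python float as a stand-in for mstype.float32. The helpers `to_float16` and `mixed_precision_dot` are hypothetical.

```python
import struct


def to_float16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    # The 'e' format packs a float as binary16; unpacking returns the rounded value.
    return struct.unpack("<e", struct.pack("<e", x))[0]


def mixed_precision_dot(xs, ws):
    """Hypothetical illustration of cast-down / compute / cast-up for one dot product."""
    # 1. Inputs and parameters are cast to the lower-precision type.
    xs16 = [to_float16(x) for x in xs]
    ws16 = [to_float16(w) for w in ws]
    # 2. The computation runs in lower precision: every intermediate is rounded to float16.
    acc = 0.0
    for x, w in zip(xs16, ws16):
        acc = to_float16(acc + to_float16(x * w))
    # 3. The result is handed back as a full-precision value (a Python float here).
    return float(acc)


xs, ws = [0.1, 0.2, 0.3], [1.0, 1.0, 1.0]
full = sum(x * w for x, w in zip(xs, ws))
mixed = mixed_precision_dot(xs, ws)
print(full, mixed)
```

The two results differ slightly because float16 keeps only 10 explicit mantissa bits, so each cast and each intermediate step introduces a small rounding error.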

The amp_level and its corresponding lists determine which cells and operators are converted.

When amp_level is set to O0, no cells or operators are converted.

When amp_level is set to O1, cells and operators in the whitelist are converted to lower-precision operations. For details on the whitelist, refer to mindspore.amp.get_white_list().

When amp_level is set to O2, cells in the blacklist keep full precision, and all other cells are converted to lower precision. For details on the blacklist, refer to mindspore.amp.get_black_list().

When amp_level is set to O3, all cells are converted to lower precision.
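The O0 to O3 rules can be summarized as a small decision function. This is a hedged sketch, not MindSpore's implementation: `runs_in_low_precision` is hypothetical, and the `WHITE_LIST` and `BLACK_LIST` sets are placeholders; the real lists are returned by mindspore.amp.get_white_list() and mindspore.amp.get_black_list().

```python
# Placeholder lists for illustration only; the real ones are returned by
# mindspore.amp.get_white_list() and mindspore.amp.get_black_list().
WHITE_LIST = {"Conv2d", "Dense", "MatMul"}
BLACK_LIST = {"BatchNorm2d", "LayerNorm"}


def runs_in_low_precision(cell_type: str, amp_level: str) -> bool:
    """Hypothetical helper: would a cell of this type be converted at this level?"""
    if amp_level == "O0":
        return False                          # nothing is converted
    if amp_level == "O1":
        return cell_type in WHITE_LIST        # only whitelisted cells and operators
    if amp_level == "O2":
        return cell_type not in BLACK_LIST    # everything except the blacklist
    if amp_level == "O3":
        return True                           # the whole network
    raise ValueError(f"this sketch covers O0-O3 only, got {amp_level!r}")


cells = ["Conv2d", "BatchNorm2d", "ReLU"]
for level in ("O0", "O1", "O2", "O3"):
    print(level, [c for c in cells if runs_in_low_precision(c, level)])
# O0 []
# O1 ['Conv2d']
# O2 ['Conv2d', 'ReLU']
# O3 ['Conv2d', 'BatchNorm2d', 'ReLU']
```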

When amp_level is set to auto, operators in auto_whitelist are converted to lower-precision operations, operators in auto_blacklist are converted to full-precision operations, operators in promote_list are converted to the highest-precision float type among their inputs, and operators not in any list run in the type of their inputs.

Operators in auto_whitelist are:

Conv2D, Conv3D, Conv2DTranspose, Conv3DTranspose, Convolution, MatMul, MatMulExt, BatchMatMul, BatchMatMulExt, PReLU, Einsum, Dense, Addmm

Operators in auto_blacklist are:

Pow, ACos, Asin, Cosh, Erfinv, Exp, Expm1, Log, Log1p, Reciprocal, Rsqrt, Sinh, Tan, Softplus, SoftplusExt, LayerNorm, LayerNormExt, BatchNorm, GroupNorm, KLDivLoss, SmoothL1Loss, MultilabelMarginLoss, SoftMarginLoss, TripletMarginLoss, MultiMarginLoss, BCEWithLogitsLoss, Pdist, Cdist, Renorm, ReduceProd, Softmax, LogSoftmax, CumProd, CumSum, CumsumExt, ProdExt, SumExt, Norm

Operators in promote_list are:

Addcdiv, Addcmul, Cross, _PyboostCrossPrim, Dot, GridSampler2D, GridSampler3D, BiasAdd
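The auto rules above can be sketched as a function that picks the float type an operator computes in. This is an illustrative model, not MindSpore's implementation: `auto_compute_dtype` is hypothetical, the sets are subsets of the lists above, dtypes are plain strings, and it assumes dtype=mstype.float16 and that the inputs of an unlisted operator already share one type.

```python
# Subsets of the documented lists; dtypes are plain strings in this sketch.
AUTO_WHITELIST = {"Conv2D", "MatMul", "BatchMatMul", "Dense"}
AUTO_BLACKLIST = {"Exp", "Log", "LayerNorm", "Softmax", "SumExt"}
PROMOTE_LIST = {"Addcmul", "Dot", "BiasAdd"}

LOW, FULL = "float16", "float32"              # assumes dtype=mstype.float16
FLOAT_RANK = {"float16": 0, "float32": 1}     # higher rank = higher precision


def auto_compute_dtype(op_name: str, input_dtypes: list) -> str:
    """Hypothetical model of the 'auto' rules: the float type an operator runs in."""
    if op_name in AUTO_WHITELIST:
        return LOW                                            # cast to lower precision
    if op_name in AUTO_BLACKLIST:
        return FULL                                           # cast to full precision
    if op_name in PROMOTE_LIST:
        return max(input_dtypes, key=FLOAT_RANK.__getitem__)  # highest-precision input type
    # Unlisted operators keep their input type (assumed uniform in this sketch).
    return input_dtypes[0]


print(auto_compute_dtype("MatMul", ["float32", "float32"]))    # float16
print(auto_compute_dtype("ReLU", ["float16"]))                 # float16
print(auto_compute_dtype("BiasAdd", ["float16", "float32"]))   # float32
print(auto_compute_dtype("Softmax", ["float16"]))              # float32
```

In this model, an unlisted operator such as ReLU that receives float16 inputs stays in float16, which is one way the network output can end up in lower precision under auto.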

For details on automatic mixed precision, refer to Automatic Mix Precision.

Note

  • Repeatedly calling mixed-precision interfaces, such as custom_mixed_precision and auto_mixed_precision, can result in a larger network hierarchy and slower performance.

  • If a network converted by a mixed-precision interface such as custom_mixed_precision or auto_mixed_precision is trained with interfaces like Model or build_train_network, set amp_level to O0 in those interfaces to avoid duplicate precision conversion.

  • When amp_level is set to auto, the output of the network may be in lower precision. In this case, you may need to cast it manually to avoid type-mismatch errors in the loss function.

  • When amp_level is set to auto and cells in the network are configured with to_float, the precision specified by to_float takes precedence.

Warning

The auto level of amp_level is an experimental API that is subject to change or deletion.

Parameters
  • network (Union[Cell, function]) – Definition of the network. A function is supported only when amp_level is set to auto.

  • amp_level (str) –

    Supports ["O0", "O1", "O2", "O3", "auto"]. Default: "O0" .

    • "O0": Do not change.

    • "O1": Convert cells and operators in the whitelist to lower-precision operations, and keep full-precision operations for the rest.

    • "O2": Keep full-precision operations for cells and operators in the blacklist, and convert the rest to lower-precision operations.

    • "O3": Cast the whole network to lower precision.

    • "auto": Operators in auto_whitelist are converted to lower-precision operations, operators in auto_blacklist are converted to full precision, operators in promote_list are converted to the highest-precision float type among their inputs, and operators not in any list run in the type of their inputs.

  • dtype (Type) – The type used for lower-precision calculations. Can be mstype.float16 or mstype.bfloat16. Default: mstype.float16.
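To help choose between the two dtype values, the stdlib-only sketch below emulates rounding to each format; it is an illustration, not MindSpore code. `to_float16` uses the `struct` module's `'e'` (IEEE half) format, and `to_bfloat16` is a hypothetical helper that keeps the upper 16 bits of a float32 with round-to-nearest-even (NaN handling omitted).

```python
import struct


def to_float16(x: float) -> float:
    """Round x to IEEE 754 half precision (5 exponent bits, 10 mantissa bits)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bfloat16(x: float) -> float:
    """Hypothetical helper: round x to bfloat16 (8 exponent bits, 7 mantissa bits)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]   # float32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)                   # round to nearest, ties to even
    bits &= 0xFFFF0000                                    # keep the upper 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


print(to_float16(1.001), to_bfloat16(1.001))   # 1.0009765625 1.0
print(to_bfloat16(70000.0))                    # 70144.0
try:
    to_float16(70000.0)
except OverflowError:
    print("70000.0 exceeds the float16 maximum of 65504")
```

float16 has more mantissa bits (10 vs. 7) but a maximum finite value of 65504, whereas bfloat16 keeps the 8-bit exponent of float32 and therefore its range.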

Raises
  • TypeError – If network is not a Cell or a function.

  • ValueError – If dtype is not mstype.float16 or mstype.bfloat16.

  • ValueError – If amp_level is not within the supported range.

Examples

>>> from mindspore import amp
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> network = LeNet5()
>>> amp_level = "O1"
>>> net = amp.auto_mixed_precision(network, amp_level)