mindspore.amp.auto_mixed_precision
- mindspore.amp.auto_mixed_precision(network, amp_level='O0', dtype=mstype.float16)
Returns a network processed with auto mixed precision.
This interface automatically performs mixed-precision processing on the input network. Precision conversion operations are added to the cells and operators of the processed network so that they compute in lower precision: mstype.float16 or mstype.bfloat16. Inputs and parameters of cells and operators are converted to the lower-precision float type, and calculation results are converted back to full precision, i.e. mstype.float32.
The amp_level and its corresponding lists determine which cells and operators are converted.
When amp_level is set to O0, no cells and operators are converted.
When amp_level is set to O1, cells and operators in whitelist will be converted to lower precision operations. For details on whitelist, refer to mindspore.amp.get_white_list().
When amp_level is set to O2, cells in blacklist will maintain full precision, and cells outside the list will be converted to low precision. For details on blacklist, refer to mindspore.amp.get_black_list().
When amp_level is set to O3, all cells will be converted to low precision.
When amp_level is set to auto, operators in auto_whitelist will be converted to lower precision operations, operators in auto_blacklist will be converted to full precision operations, operators in promote_list will be converted to the higher accuracy float type of the operator inputs, and operators not listed will run in the type defined by their inputs.
Operators in auto_whitelist are:
Conv2D, Conv3D, Conv2DTranspose, Conv3DTranspose, Convolution, MatMul, MatMulExt, BatchMatMul, BatchMatMulExt, PReLU, Einsum, Dense, Addmm
Operators in auto_blacklist are:
Pow, ACos, Asin, Cosh, Erfinv, Exp, Expm1, Log, Log1p, Reciprocal, Rsqrt, Sinh, Tan, Softplus, SoftplusExt, LayerNorm, LayerNormExt, BatchNorm, GroupNorm, KLDivLoss, SmoothL1Loss, MultilabelMarginLoss, SoftMarginLoss, TripletMarginLoss, MultiMarginLoss, BCEWithLogitsLoss, Pdist, Cdist, Renorm, ReduceProd, Softmax, LogSoftmax, CumProd, CumSum, CumsumExt, ProdExt, SumExt, Norm
Operators in promote_list are:
Addcdiv, Addcmul, Cross, _PyboostCrossPrim, Dot, GridSampler2D, GridSampler3D, BiasAdd
For details on automatic mixed precision, refer to Automatic Mix Precision.
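As a rough illustration of the auto rules described above, the following pure-Python sketch models how a compute type could be chosen for a single operator. It is not MindSpore code: resolve_auto_dtype, the string type names, and the shortened example sets are hypothetical and exist only for this example; the authoritative lists are the ones given above.

```python
# Illustrative model of the "auto" level rules; not MindSpore's implementation.
# The sets below are shortened excerpts of the lists documented above.
AUTO_WHITELIST = {"Conv2D", "MatMul", "BatchMatMul", "Dense"}
AUTO_BLACKLIST = {"Exp", "Log", "Softmax", "LayerNorm"}
PROMOTE_LIST = {"Addcmul", "Cross", "Dot", "BiasAdd"}


def resolve_auto_dtype(op_name, input_dtypes, low_dtype="float16"):
    """Return the float type an operator would compute in (hypothetical helper).

    Returns None when the operator is not listed, meaning no cast is
    inserted and the operator runs in whatever type its inputs already have.
    """
    if op_name in AUTO_WHITELIST:
        # Whitelisted operators are converted to the lower precision type.
        return low_dtype
    if op_name in AUTO_BLACKLIST:
        # Blacklisted operators are converted to full precision.
        return "float32"
    if op_name in PROMOTE_LIST:
        # Promoted operators use the most precise float type among their
        # inputs; this sketch only distinguishes float32 from 16-bit types.
        return "float32" if "float32" in input_dtypes else input_dtypes[0]
    return None


print(resolve_auto_dtype("MatMul", ["float32", "float32"]))
print(resolve_auto_dtype("Exp", ["float16"]))
print(resolve_auto_dtype("Addcmul", ["float16", "float32", "float16"]))
print(resolve_auto_dtype("ReLU", ["float16"]))
```

In this model the order of the checks does not matter, because the three example sets do not overlap.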
Note
Repeatedly calling mixed-precision interfaces, such as custom_mixed_precision and auto_mixed_precision, can result in a larger network hierarchy and slower performance.
If interfaces like Model and build_train_network are used to train a network that has already been converted by mixed-precision interfaces such as custom_mixed_precision and auto_mixed_precision, their amp_level needs to be set to O0 to avoid duplicated precision conversion.
When amp_level is set to auto, the output of the network may be in lower precision. In this case, you may need to manually convert the type to avoid type inconsistency errors in the loss function.
When amp_level is set to auto and cells in the network are configured with to_float, the precision specified by to_float takes effect first.
Warning
The auto level of amp_level is an experimental API that is subject to change or deletion.
- Parameters
network (Union[Cell, function]) – Definition of the network. Function type is supported only when amp_level is set to "auto".
amp_level (str) – Supports ["O0", "O1", "O2", "O3", "auto"]. Default: "O0".
"O0": Do not change.
"O1": Convert cells and operators in whitelist to lower precision operations, and keep full precision operations for the rest.
"O2": Keep full precision operations for cells and operators in blacklist, and convert the rest to lower precision operations.
"O3": Cast network to lower precision.
"auto": Operators in auto_whitelist will be converted to lower precision operations, operators in auto_blacklist will be converted to full precision, operators in promote_list will be converted to the higher accuracy float type of the operator inputs, and operators not listed will run in the type defined by their inputs.
dtype (Type) – The type used in lower precision calculations. Can be mstype.float16 or mstype.bfloat16. Default: mstype.float16.
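The four fixed levels listed for amp_level can be pictured as a membership test against the whitelist and blacklist. The sketch below is a plain-Python model under that reading, not MindSpore code: converts_to_low_precision and the two example sets are hypothetical stand-ins, and the real lists are the ones returned by mindspore.amp.get_white_list() and mindspore.amp.get_black_list().

```python
# Illustrative model of levels "O0" to "O3"; not MindSpore's implementation.
# Hypothetical example lists standing in for the real whitelist and blacklist.
EXAMPLE_WHITELIST = {"Conv2d", "Dense"}
EXAMPLE_BLACKLIST = {"BatchNorm2d", "LayerNorm"}


def converts_to_low_precision(cell_name, amp_level):
    """Return True if a cell of this type would run in lower precision."""
    if amp_level == "O0":
        # Nothing is converted.
        return False
    if amp_level == "O1":
        # Only whitelisted cells are converted.
        return cell_name in EXAMPLE_WHITELIST
    if amp_level == "O2":
        # Everything except blacklisted cells is converted.
        return cell_name not in EXAMPLE_BLACKLIST
    if amp_level == "O3":
        # The whole network is cast to lower precision.
        return True
    raise ValueError(f"unsupported amp_level for this sketch: {amp_level!r}")


# Show which of three example cell types are converted at each level.
for level in ("O0", "O1", "O2", "O3"):
    converted = [name for name in ("Conv2d", "ReLU", "BatchNorm2d")
                 if converts_to_low_precision(name, level)]
    print(level, converted)
```

The "auto" level is left out on purpose, since it works on operators and three separate lists rather than on a single whitelist or blacklist of cells.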
- Raises
TypeError – If network is not a Cell or a function.
ValueError – If dtype is not one of mstype.float16, mstype.bfloat16.
ValueError – If amp_level is not within the supported range.
Examples
>>> from mindspore import amp
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> network = LeNet5()
>>> amp_level = "O1"
>>> net = amp.auto_mixed_precision(network, amp_level)