mindspore.amp.auto_mixed_precision
- mindspore.amp.auto_mixed_precision(network, amp_level='O0', dtype=mstype.float16)
Returns a network processed with auto mixed precision.
This interface automatically applies mixed-precision processing to the input network. Cells and operators in the processed network gain precision-conversion operations so that they compute in a lower precision, either mstype.float16 or mstype.bfloat16. Inputs and parameters of these cells and operators are converted to the lower-precision float type, and calculation results are converted back to full-precision float, i.e. mstype.float32. The framework has a set of built-in blacklists and whitelists, and amp_level determines which cells and operators are actually converted.
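As a rough illustration of the conversion described above, the following pure-Python sketch (it does not use MindSpore) emulates one low-precision operation: the inputs are rounded to float16, the dot product is accumulated in float16, and the result is cast back to float32. The helpers to_fp16, to_fp32 and low_precision_dot are hypothetical; the framework performs these conversions inside the network rather than through Python helpers, and real hardware kernels may accumulate in higher precision.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE float16 value (inf on overflow)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses to pack values beyond float16's range; saturate to inf.
        return float("inf") if x > 0 else float("-inf")


def to_fp32(x: float) -> float:
    """Round a Python float to the nearest IEEE float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def low_precision_dot(a, b):
    """Hypothetical mixed-precision dot product: cast inputs down, compute, cast result up."""
    acc = 0.0
    for x, y in zip(a, b):
        # Inputs are converted to float16 before the multiplication ...
        prod = to_fp16(to_fp16(x) * to_fp16(y))
        # ... and every partial sum is also rounded to float16.
        acc = to_fp16(acc + prod)
    # The result is converted back to full precision (float32).
    return to_fp32(acc)


a = [0.1] * 1000
b = [1.0] * 1000
print("float16 path:", low_precision_dot(a, b))
print("float32 reference:", to_fp32(sum(x * y for x, y in zip(a, b))))
```

The float16 path ends up noticeably away from the float32 reference of about 100 because every partial sum is rounded to float16; this kind of error is why numerically sensitive operations are usually kept in full precision.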
The current built-in whitelist contents are:
[mindspore.nn.Conv1d, mindspore.nn.Conv2d, mindspore.nn.Conv3d, mindspore.nn.Conv1dTranspose, mindspore.nn.Conv2dTranspose, mindspore.nn.Conv3dTranspose, mindspore.nn.Dense, mindspore.nn.LSTMCell, mindspore.nn.RNNCell, mindspore.nn.GRUCell, mindspore.ops.Conv2D, mindspore.ops.Conv3D, mindspore.ops.Conv2DTranspose, mindspore.ops.Conv3DTranspose, mindspore.ops.MatMul, mindspore.ops.BatchMatMul, mindspore.ops.PReLU, mindspore.ops.ReLU, mindspore.ops.Ger]
The current built-in blacklist contents are:
[mindspore.nn.BatchNorm1d, mindspore.nn.BatchNorm2d, mindspore.nn.BatchNorm3d, mindspore.nn.LayerNorm]
For details on automatic mixed precision, refer to Automatic Mix Precision.
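To make the interaction between amp_level and these two lists concrete, here is a hypothetical pure-Python model that only mirrors the documented rules (see the amp_level values under Parameters). Cell names are plain strings abbreviated from mindspore.nn and mindspore.ops, and WHITELIST, BLACKLIST and low_precision_cells are illustrative names, not MindSpore APIs; the real interface rewrites Cell objects and operators rather than filtering names.

```python
# Hypothetical model of how amp_level selects the cells/operators that compute
# in lower precision. It mirrors the documented rules, not MindSpore's source.
WHITELIST = {
    "nn.Conv1d", "nn.Conv2d", "nn.Conv3d",
    "nn.Conv1dTranspose", "nn.Conv2dTranspose", "nn.Conv3dTranspose",
    "nn.Dense", "nn.LSTMCell", "nn.RNNCell", "nn.GRUCell",
    "ops.Conv2D", "ops.Conv3D", "ops.Conv2DTranspose", "ops.Conv3DTranspose",
    "ops.MatMul", "ops.BatchMatMul", "ops.PReLU", "ops.ReLU", "ops.Ger",
}
BLACKLIST = {"nn.BatchNorm1d", "nn.BatchNorm2d", "nn.BatchNorm3d", "nn.LayerNorm"}


def low_precision_cells(cell_types, amp_level="O0"):
    """Return the entries of cell_types that would compute in lower precision."""
    if amp_level == "O0":
        return []                                              # network unchanged
    if amp_level == "O1":
        return [c for c in cell_types if c in WHITELIST]       # whitelist only
    if amp_level == "O2":
        return [c for c in cell_types if c not in BLACKLIST]   # everything but blacklist
    if amp_level == "O3":
        return list(cell_types)                                # whole network
    raise ValueError(f"unsupported amp_level: {amp_level!r}")


layers = ["nn.Conv2d", "nn.BatchNorm2d", "ops.ReLU", "nn.Flatten", "nn.Dense"]
for level in ("O0", "O1", "O2", "O3"):
    print(level, low_precision_cells(layers, level))
```

In this model, "O1" is the conservative choice (only cells known to benefit are converted), while "O2" converts aggressively and protects only the normalization layers that are sensitive to reduced precision.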
Note
Repeatedly calling mixed-precision interfaces, such as custom_mixed_precision and auto_mixed_precision, can result in a deeper network hierarchy and slower performance.
If interfaces such as Model and build_train_network are used to train a network that has already been converted by mixed-precision interfaces such as custom_mixed_precision and auto_mixed_precision, amp_level needs to be set to "O0" to avoid duplicated precision conversion.
- Parameters
network (Cell) – Definition of the network.
amp_level (str) – Supports ["O0", "O1", "O2", "O3"]. Default: "O0".
"O0": Do not change the network.
"O1": Convert cells and operators in the whitelist to lower-precision operations, and keep full-precision operations for the rest.
"O2": Keep full-precision operations for cells and operators in the blacklist, and convert the rest to lower-precision operations.
"O3": Cast the whole network to lower precision.
dtype (Type) – The type used in lower-precision calculations, either mstype.float16 or mstype.bfloat16. Default: mstype.float16.
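The choice of dtype is a range-versus-precision trade-off: float16 has 5 exponent bits and 10 mantissa bits (largest finite value 65504), while bfloat16 keeps float32's 8 exponent bits but only 7 mantissa bits. The stdlib-only sketch below emulates both roundings to show the difference; to_fp16 and to_bf16 are hypothetical helpers, and to_bf16 handles finite, non-NaN inputs only.

```python
import struct


def to_fp16(x: float) -> float:
    """Round to the nearest IEEE float16 value; values past its range become inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")


def to_bf16(x: float) -> float:
    """Round to the nearest bfloat16 value (finite, non-NaN inputs only).

    bfloat16 keeps the top 16 bits of a float32: the same 8-bit exponent,
    but only 7 explicit mantissa bits.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even, on the dropped bits
    bits &= 0xFFFF0000                   # keep sign, exponent and 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


for value in (0.1, 3.14159, 70000.0):
    print(value, "float16:", to_fp16(value), "bfloat16:", to_bf16(value))
```

With these inputs, 0.1 is represented more accurately in float16 (0.0999755859375 versus 0.10009765625), whereas 70000.0 overflows to inf in float16 but stays finite in bfloat16 (rounded to 70144.0).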
- Raises
TypeError – If network is not a Cell.
ValueError – If dtype is not one of mstype.float16, mstype.bfloat16.
ValueError – If amp_level is not within the supported range.
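The conditions listed under Raises can be summarized as a hypothetical argument check. This sketch runs without MindSpore: Cell is a stand-in class, the dtypes are stand-in strings for mstype.float16 and mstype.bfloat16, and check_amp_args is an illustrative name. The order in which the real interface performs these checks is an assumption.

```python
class Cell:
    """Stand-in for mindspore.nn.Cell so the sketch runs without MindSpore."""


SUPPORTED_LEVELS = ("O0", "O1", "O2", "O3")
SUPPORTED_DTYPES = ("float16", "bfloat16")  # stand-ins for mstype.float16 / mstype.bfloat16


def check_amp_args(network, amp_level="O0", dtype="float16"):
    """Hypothetical check mirroring the documented Raises conditions."""
    if not isinstance(network, Cell):
        raise TypeError(f"network must be a Cell, got {type(network).__name__}")
    if dtype not in SUPPORTED_DTYPES:
        raise ValueError(f"dtype must be one of {SUPPORTED_DTYPES}, got {dtype!r}")
    if amp_level not in SUPPORTED_LEVELS:
        raise ValueError(f"amp_level must be one of {SUPPORTED_LEVELS}, got {amp_level!r}")


check_amp_args(Cell(), "O2", "bfloat16")  # valid arguments: no exception
for bad_args in [(object(), "O1"), (Cell(), "O5"), (Cell(), "O1", "float32")]:
    try:
        check_amp_args(*bad_args)
    except (TypeError, ValueError) as err:
        print(type(err).__name__, "-", err)
```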
Examples
>>> from mindspore import amp
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> network = LeNet5()
>>> amp_level = "O1"
>>> net = amp.auto_mixed_precision(network, amp_level)