mindspore.amp.auto_mixed_precision
- mindspore.amp.auto_mixed_precision(network, amp_level='O0')
Returns a network processed with auto mixed precision.
This interface automatically applies mixed-precision processing to the input network. Precision conversion (cast) operations are added to the converted cells and operators so that they compute in float16: their inputs and parameters are cast to float16, and their calculation results are cast back to float32. Which cells and operators are converted depends on amp_level.
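As a rough, framework-free illustration of this cast pattern, the sketch below emulates a single converted operator in plain Python. Inputs and weights are rounded to IEEE 754 half precision, the result is rounded to half precision, and it is then widened back to single precision. It relies only on the standard-library `struct` formats `'e'` (float16) and `'f'` (float32). The helpers `to_float16`, `to_float32` and `fp16_dot` are hypothetical and not part of MindSpore, and the sketch does not reflect how MindSpore kernels accumulate internally.

```python
import struct


def to_float16(x: float) -> float:
    """Emulate a cast to float16 by packing to IEEE 754 half precision and back."""
    # struct's 'e' format rounds to the nearest representable half-precision value
    # (and raises OverflowError for magnitudes beyond the float16 range).
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_float32(x: float) -> float:
    """Emulate a cast to float32 by packing to IEEE 754 single precision and back."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def fp16_dot(inputs, weights):
    """Hypothetical converted operator: float16 computation, float32 result."""
    # 1. Cast the operator's inputs and parameters to float16.
    inputs16 = [to_float16(x) for x in inputs]
    weights16 = [to_float16(w) for w in weights]
    # 2. Compute, then round the result to float16 (the operator's output precision).
    #    Python accumulates in double precision here; real float16 kernels may differ.
    result16 = to_float16(sum(x * w for x, w in zip(inputs16, weights16)))
    # 3. Cast the result back to float32 for the rest of the network.
    return to_float32(result16)


print(fp16_dot([1.0, 2.0], [3.0, 4.0]))  # 11.0 (all values exactly representable)
print(fp16_dot([0.1, 0.2], [0.3, 0.4]))  # differs slightly from 0.11 due to float16 rounding
```

The second call shows why mixed precision is applied selectively: values such as 0.1 are not exactly representable in float16, so some accuracy is traded for speed and memory.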
The framework has a built-in whitelist and blacklist, and amp_level determines which cells and operators are converted:
When amp_level="O0", no precision conversion is performed.
When amp_level="O1", only the cells and operators in the whitelist are converted.
When amp_level="O2", all cells and operators except those in the blacklist are converted.
When amp_level="O3", all cells and operators in the network are converted.
The current built-in whitelist contents are:
[
mindspore.nn.Conv1d,
mindspore.nn.Conv2d,
mindspore.nn.Conv3d,
mindspore.nn.Conv1dTranspose,
mindspore.nn.Conv2dTranspose,
mindspore.nn.Conv3dTranspose,
mindspore.nn.Dense,
mindspore.nn.LSTMCell,
mindspore.nn.RNNCell,
mindspore.nn.GRUCell,
mindspore.ops.Conv2D,
mindspore.ops.Conv3D,
mindspore.ops.Conv2DTranspose,
mindspore.ops.Conv3DTranspose,
mindspore.ops.MatMul,
mindspore.ops.BatchMatMul,
mindspore.ops.PReLU,
mindspore.ops.ReLU,
mindspore.ops.Ger
]
The current built-in blacklist contents are:
[
mindspore.nn.BatchNorm1d,
mindspore.nn.BatchNorm2d,
mindspore.nn.BatchNorm3d,
mindspore.nn.LayerNorm
]
For details on automatic mixed precision, refer to Automatic Mix Precision.
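The selection rules for each amp_level can be summarized by a small decision function. The sketch below is a simplified model, assuming cells and operators are identified by their fully qualified type names compared as plain strings (no subclass handling). The helper `converts_to_float16` is hypothetical; it does not reproduce how MindSpore traverses the network or inserts the cast operations.

```python
# Built-in lists copied from the documentation (fully qualified names, as strings).
WHITELIST = frozenset({
    "mindspore.nn.Conv1d", "mindspore.nn.Conv2d", "mindspore.nn.Conv3d",
    "mindspore.nn.Conv1dTranspose", "mindspore.nn.Conv2dTranspose",
    "mindspore.nn.Conv3dTranspose", "mindspore.nn.Dense",
    "mindspore.nn.LSTMCell", "mindspore.nn.RNNCell", "mindspore.nn.GRUCell",
    "mindspore.ops.Conv2D", "mindspore.ops.Conv3D",
    "mindspore.ops.Conv2DTranspose", "mindspore.ops.Conv3DTranspose",
    "mindspore.ops.MatMul", "mindspore.ops.BatchMatMul",
    "mindspore.ops.PReLU", "mindspore.ops.ReLU", "mindspore.ops.Ger",
})
BLACKLIST = frozenset({
    "mindspore.nn.BatchNorm1d", "mindspore.nn.BatchNorm2d",
    "mindspore.nn.BatchNorm3d", "mindspore.nn.LayerNorm",
})


def converts_to_float16(type_name: str, amp_level: str) -> bool:
    """Hypothetical helper: would a cell/operator of this type compute in float16?"""
    if amp_level == "O0":
        return False                        # no precision conversion at all
    if amp_level == "O1":
        return type_name in WHITELIST       # only whitelisted types
    if amp_level == "O2":
        return type_name not in BLACKLIST   # everything except blacklisted types
    if amp_level == "O3":
        return True                         # the whole network
    raise ValueError(f"Unsupported amp_level: {amp_level!r}")


candidates = ("mindspore.nn.Dense", "mindspore.nn.BatchNorm2d", "mindspore.ops.Softmax")
for level in ("O0", "O1", "O2", "O3"):
    print(level, [name for name in candidates if converts_to_float16(name, level)])
```

Under this model, an operator that appears in neither list (Softmax here) stays in float32 for "O1" but is converted for "O2", which is the practical difference between the two levels.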
- Parameters
network (Cell) – Definition of the network.
amp_level (str) –
Supports ["O0", "O1", "O2", "O3"]. Default: "O0".
"O0": Do not change the network.
"O1": Convert the cells and operators in the whitelist to float16 precision operations, and keep float32 precision operations for the rest.
"O2": Keep float32 precision operations for the cells and operators in the blacklist, and convert the rest to float16 precision operations.
"O3": Cast the entire network to float16.
- Raises
ValueError – If amp_level is not one of the supported values.
Examples
>>> from mindspore import amp
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
>>> network = LeNet5()
>>> amp_level = "O1"
>>> net = amp.auto_mixed_precision(network, amp_level)