mindspore.amp.auto_mixed_precision
- mindspore.amp.auto_mixed_precision(network, amp_level='O0')[源代码]
返回一个经过自动混合精度处理的网络。
该接口会对输入网络进行自动混合精度处理,处理后的网络里的Cell和算子增加了精度转换操作,以float16精度进行计算。 Cell和算子的输入和参数被转换成float16类型,计算结果被转换回float32类型。
框架内置了一组黑名单和白名单, amp_level 决定了具体对哪些Cell和算子进行精度转换:
当 amp_level=”O0” 时,不进行精度转换。
当 amp_level=”O1” 时,仅将白名单内的Cell和算子进行精度转换。
当 amp_level=”O2” 时,将除了黑名单内的其他Cell和算子都进行精度转换。
当 amp_level=”O3” 时,将网络里的所有Cell和算子都进行精度转换。
当前的内置白名单内容为:
[
mindspore.nn.Conv1d
,mindspore.nn.Conv2d
,mindspore.nn.Conv3d
,mindspore.nn.Conv1dTranspose
,mindspore.nn.Conv2dTranspose
,mindspore.nn.Conv3dTranspose
,mindspore.nn.Dense
,mindspore.nn.LSTMCell
,mindspore.nn.RNNCell
,mindspore.nn.GRUCell
,mindspore.ops.Conv2D
,mindspore.ops.Conv3D
,mindspore.ops.Conv2DTranspose
,mindspore.ops.Conv3DTranspose
,mindspore.ops.MatMul
,mindspore.ops.BatchMatMul
,mindspore.ops.PReLU
,mindspore.ops.ReLU
,mindspore.ops.Ger
]当前的内置黑名单内容为:
[
mindspore.nn.BatchNorm1d
,mindspore.nn.BatchNorm2d
,mindspore.nn.BatchNorm3d
,mindspore.nn.LayerNorm
]关于自动混合精度的详细介绍,请参考 自动混合精度 。
- 参数:
network (Cell) - 定义网络结构。
amp_level (str) - 支持[“O0”, “O1”, “O2”, “O3”]。默认值:”O0”。
“O0” - 不变化。
“O1” - 将白名单内的Cell和算子转换为float16精度运算,其余部分保持float32精度运算。
“O2” - 将黑名单内的Cell和算子保持float32精度运算,其余部分转换为float16精度运算。
“O3” - 将网络全部转为float16精度。
- 异常:
ValueError - amp_level 不在支持范围内。
样例:
>>> from mindspore import amp, nn >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/r2.0/docs/mindspore/code/lenet.py >>> network = LeNet5() >>> amp_level = "O1" >>> net = amp.auto_mixed_precision(network, amp_level)