mindspore.amp.auto_mixed_precision

查看源文件
mindspore.amp.auto_mixed_precision(network, amp_level='O0', dtype=mstype.float16)[源代码]

返回一个经过自动混合精度处理的网络。

该接口会对输入网络进行自动混合精度处理,处理后的网络里的Cell和算子增加了精度转换操作,以低精度进行计算,如 mstype.float16mstype.bfloat16 。 Cell和算子的输入和参数被转换成低精度浮点数,计算结果被转换回全精度浮点数,即 mstype.float32

amp_level 及其对应名单决定了哪些Cell和算子需要进行精度转换。

amp_level 配置为 O0 时,不对Cell和算子进行精度转换。

amp_level 配置为 O1 时,白名单内的Cell和算子会被转换为低精度运算。白名单的具体内容可参考 mindspore.amp.get_white_list()

amp_level 配置为 O2 时,黑名单内的Cell保持全精度运算,名单外的Cell会被转换为低精度运算。黑名单的具体内容可参考 mindspore.amp.get_black_list()

amp_level 配置为 O3 时,所有Cell和算子都转换为低精度运算。

amp_level 配置为 auto 时, auto_whitelist 名单里的算子会被转换为低精度运算, auto_blacklist 名单里的算子会被转换为全精度运算, promote_list 名单里的算子会被转换为算子输入中最高精度的浮点类型,名单外的算子使用输入的类型进行计算。

auto_whitelist 名单里的算子包括:

Conv2DConv3DConv2DTransposeConv3DTransposeConvolutionMatMulMatMulExtBatchMatMulBatchMatMulExtPReLUEinsumDenseAddmm

auto_blacklist 名单里的算子包括:

PowACosAsinCoshErfinvExpExpm1LogLog1pReciprocalRsqrtSinhTanSoftplusSoftplusExtLayerNormLayerNormExtBatchNormGroupNormKLDivLossSmoothL1LossMultilabelMarginLossSoftMarginLossTripletMarginLossMultiMarginLossBCEWithLogitsLossPdistCdistRenormReduceProdSoftmaxLogSoftmaxCumProdCumSumCumsumExtProdExtSumExtNorm

promote_list 名单里的算子包括:

AddcdivAddcmulCross_PyboostCrossPrimDotGridSampler2DGridSampler3DBiasAdd

关于自动混合精度的详细介绍,请参考 自动混合精度

说明

  • 重复调用混合精度接口,如 custom_mixed_precisionauto_mixed_precision ,可能导致网络层数增大,性能降低。

  • 如果使用 mindspore.train.Modelmindspore.amp.build_train_network() 等接口来训练经 过 custom_mixed_precisionauto_mixed_precision 等混合精度接口转换后的网络,则需要将 amp_level 配置 为 O0 以避免重复的精度转换。

  • amp_level 配置为 auto 时,网络输出的类型可能是低精度类型,此时可能需要手动转换类型以避免loss函数出现类型不一致的报错。

  • amp_level 配置为 auto ,而网络里的Cell配置了 to_float 时, to_float 指定的精度优先生效。

警告

auto 等级的 amp_level 是实验性API,后续可能修改或删除。

参数:
  • network (Union[Cell, function]) - 定义网络结构。仅当 amp_level 配置为 auto 时支持Function类型。

  • amp_level (str) - 支持["O0", "O1", "O2", "O3", "auto"]。默认值: "O0"

    • "O0" - 不变化。

    • "O1" - 仅将白名单内的Cell和算子转换为低精度运算,其余部分保持全精度运算。

    • "O2" - 黑名单内的Cell和算子保持全精度运算,其余部分都转换为低精度运算。

    • "O3" - 将网络全部转为低精度运算。

    • "auto" - 将 auto_whitelist 名单内的算子转换为低精度运算, auto_blacklist 名单内的算子转换为全精度运算, promote_list 名单内的算子转换为算子输入中最高精度的浮点类型,名单外的算子使用输入的类型进行计算。

  • dtype (Type) - 低精度计算时使用的数据类型,可以是 mstype.float16mstype.bfloat16 。默认值: mstype.float16

异常:
  • TypeError - network 不是Cell或函数。

  • ValueError - amp_level 不在支持范围内。

  • ValueError - dtype 既不是 mstype.float16 也不是 mstype.bfloat16

样例:

>>> from mindspore import amp
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> network = LeNet5()
>>> amp_level = "O1"
>>> net = amp.auto_mixed_precision(network, amp_level)