mindspore.amp.build_train_network

查看源文件
mindspore.amp.build_train_network(network, optimizer, loss_fn=None, level='O0', boost_level='O0', **kwargs)[源代码]

构建混合精度训练网络。

说明

参数:
  • network (Cell) - 定义网络结构。

  • optimizer (mindspore.nn.Optimizer) - 定义优化器,用于更新权重参数。

  • loss_fn (Union[None, Cell]) - 定义损失函数。如果为None, network 中应该包含损失函数。默认值: None

  • level (str) - 支持['O0', 'O1', 'O2', 'O3', 'auto']。默认值: 'O0'

    level 的详细配置信息可参考 mindspore.amp.auto_mixed_precision()

    level 配置的 keep_batchnorm_fp32cast_model_typeloss_scale_manager 可能会被 kwargs 里的配置覆盖。

  • boost_level (str) - mindspore.boost 中参数 level 的选项,设置boost的训练模式级别。支持['O0', 'O1', 'O2']。默认值: 'O0'

    • 'O0' - 不变化。

    • 'O1' - 开启boost模式,性能提升20%左右,准确率与原始准确率相同。

    • 'O2' - 开启boost模式,性能提升30%左右,准确率降低小于3%。如果设置了'O1'或'O2'模式,boost相关库将自动生效。

  • cast_model_type (mindspore.dtype) - 支持float16,float32。如果设置了该参数,网络将被转化为设置的数据类型,而不会根据设置的level进行转换。

  • keep_batchnorm_fp32 (bool) - 当网络被设置为float16时,配置为True,则BatchNorm将保持在float32下运行。设置level不会影响该属性。

  • loss_scale_manager (Union[None, LossScaleManager]) - 如果不为None,必须是 mindspore.amp.LossScaleManager 的子类,用于缩放损失系数(loss scale)。设置level不会影响该属性。

异常:

样例:

>>> from mindspore import amp, nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> network = LeNet5()
>>> net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
>>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
>>> amp_level="O3"
>>> net = amp.build_train_network(network, net_opt, net_loss, amp_level)