mindspore.amp.FixedLossScaleManager

查看源文件
class mindspore.amp.FixedLossScaleManager(loss_scale=128.0, drop_overflow_update=True)[源代码]

损失缩放系数不变的管理器,继承自 mindspore.amp.LossScaleManager

参数:
  • loss_scale (float) - 梯度放大系数。注:如果将 drop_overflow_update 设为 False ,则定义优化器时需要将优化器的 loss_scale 设为相同的值。默认值: 128.0

  • drop_overflow_update (bool) - 出现溢出时,是否执行优化器。如果值为 True ,则出现溢出时不会执行优化器。默认值: True

样例:

>>> import mindspore as ms
>>> from mindspore import amp, nn
>>>
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> loss_scale = 1024.0
>>> loss_scale_manager = amp.FixedLossScaleManager(loss_scale, False)
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=loss_scale)
>>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
get_drop_overflow_update()[源代码]

返回 drop_overflow_update ,该值表示是否在发生溢出时放弃本轮参数更新。

返回:

bool, drop_overflow_update 的值。

get_loss_scale()[源代码]

获取loss scale值。

返回:

bool,loss_scale 的值。

get_update_cell()[源代码]

返回用于更新 loss_scale 值的 mindspore.nn.Cell 实例, mindspore.nn.TrainOneStepWithLossScaleCell 会调用该实例。该类使用固定的梯度放大系数,因此该实例不执行任何操作。

返回:

None或 Cell 。当 drop_overflow_update 为True时,返回 mindspore.nn.FixedLossScaleUpdateCell 实例,当 drop_overflow_update 为False时,返回None。

update_loss_scale(overflow)[源代码]

更新loss scale值。类 mindspore.amp.FixedLossScaleManager 中,该方法不执行任何操作。

参数:
  • overflow (bool) - 表示是否溢出。