mindspore.amp.LossScaleManager

class mindspore.amp.LossScaleManager[源代码]

使用混合精度时,用于管理损失缩放系数(loss scale)的抽象类。

派生类需要实现该类的所有方法。 get_loss_scale 用于获取当前的梯度放大系数。 update_loss_scale 用于更新梯度放大系数,该方法将在训练过程中被调用。 get_update_cell 用于获取更新梯度放大系数的 mindspore.nn.Cell 实例,该实例将在训练过程中被调用。当前多使用 get_update_cell 方式。

例如:mindspore.amp.FixedLossScaleManagermindspore.amp.DynamicLossScaleManager

get_loss_scale()[源代码]

获取梯度放大系数(loss scale)的值。

get_update_cell()[源代码]

获取用于更新梯度放大系数的 mindspore.nn.Cell 实例。

update_loss_scale(overflow)[源代码]

根据 overflow 状态更新梯度放大系数(loss scale)。

参数:
  • overflow (bool) - 表示训练过程是否溢出。