mindspore.amp.LossScaler

查看源文件
class mindspore.amp.LossScaler[源代码]

使用混合精度时,用于管理损失缩放系数(loss scaler)的抽象类。

派生类需要实现该类的所有方法。训练过程中,scaleunscale 用于对损失值或梯度进行放大或缩小,以避免数据溢出;adjust 用于调整损失缩放系数 scale_value 的值。

关于使用 LossScaler 进行损失缩放,请查看 教程

警告

这是一个实验性API,后续可能修改或删除。

样例:

>>> from mindspore.amp import LossScaler, _grad_scale_map, _grad_unscale_map
>>> from mindspore import ops, Parameter, Tensor
>>> from mindspore.common import dtype as mstype
>>>
>>> class MyLossScaler(LossScaler):
...     def __init__(self, scale_value):
...         self.scale_value = Parameter(Tensor(scale_value, dtype=mstype.float32), name="scale_value")
...
...     def scale(self, inputs):
...         inputs = mutable(inputs)
...         return _grad_scale_map(self.scale_value, inputs)
...
...     def unscale(self, inputs):
...         inputs = mutable(inputs)
...         return _grad_unscale_map(self.scale_value, inputs)
...
...     def adjust(self, grads_finite):
...         scale_mul_factor = self.scale_value * self.scale_factor
...         scale_value = ops.select(grads_finite, scale_mul_factor, self.scale_value)
...         ops.assign(self.scale_value, scale_value)
...         return True
>>>
>>> loss_scaler = MyLossScaler(1024)
abstract adjust(grads_finite)[源代码]

根据梯度是否为有效值(无溢出)对 scale_value 进行调整。

参数:
  • grads_finite (Tensor) - bool类型的标量Tensor,表示梯度是否为有效值(无溢出)。

abstract scale(inputs)[源代码]

对inputs进行scale,inputs *= scale_value

参数:
  • inputs (Union(Tensor, tuple(Tensor))) - 损失值或梯度。

abstract unscale(inputs)[源代码]

对inputs进行unscale,inputs /= scale_value

参数:
  • inputs (Union(Tensor, tuple(Tensor))) - 损失值或梯度。