mindspore.amp.LossScaler
- class mindspore.amp.LossScaler[源代码]
使用混合精度时,用于管理损失缩放系数(loss scaler)的抽象类。
派生类需要实现该类的所有方法。训练过程中,scale 和 unscale 用于对损失值或梯度进行放大或缩小,以避免数据溢出;adjust 用于调整损失缩放系数 scale_value 的值。
关于使用 LossScaler 进行损失缩放,请查看 教程。
警告
这是一个实验性API,后续可能修改或删除。
样例:
>>> from mindspore.amp import LossScaler, _grad_scale_map, _grad_unscale_map >>> from mindspore import ops, Parameter, Tensor >>> from mindspore.common import dtype as mstype >>> >>> class MyLossScaler(LossScaler): ... def __init__(self, scale_value): ... self.scale_value = Parameter(Tensor(scale_value, dtype=mstype.float32), name="scale_value") ... ... def scale(self, inputs): ... inputs = mutable(inputs) ... return _grad_scale_map(self.scale_value, inputs) ... ... def unscale(self, inputs): ... inputs = mutable(inputs) ... return _grad_unscale_map(self.scale_value, inputs) ... ... def adjust(self, grads_finite): ... scale_mul_factor = self.scale_value * self.scale_factor ... scale_value = ops.select(grads_finite, scale_mul_factor, self.scale_value) ... ops.assign(self.scale_value, scale_value) ... return True >>> >>> loss_scaler = MyLossScaler(1024)
- abstract adjust(grads_finite)[源代码]
根据梯度是否为有效值(无溢出)对 scale_value 进行调整。
- 参数:
grads_finite (Tensor) - bool类型的标量Tensor,表示梯度是否为有效值(无溢出)。