mindspore.amp.LossScaler
- class mindspore.amp.LossScaler
Abstract base class for loss scalers used in mixed-precision training.
Derived classes must implement all of its methods. During training, scale and unscale are used to scale and unscale the loss value and gradients to avoid overflow, and adjust is used to update the loss scale value.
For more information, refer to the tutorials.
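As a framework-free illustration of the scale and unscale round trip described above, the following sketch uses plain Python lists in place of tensors. The helper names `scale`, `unscale`, and `all_finite`, and the sample gradient values, are assumptions made for this illustration; they are not part of the MindSpore API.

```python
import math

def scale(values, scale_value):
    """Multiply each value (loss or gradient) by the loss scale."""
    return [v * scale_value for v in values]

def unscale(values, scale_value):
    """Divide each value by the loss scale to recover its true magnitude."""
    return [v / scale_value for v in values]

def all_finite(values):
    """Return True when no value is inf or NaN."""
    return all(math.isfinite(v) for v in values)

scale_value = 1024.0
true_grads = [1e-7, -3e-6, 2.5e-5]  # hypothetical small gradients

scaled_grads = scale(true_grads, scale_value)        # applied before backprop
restored_grads = unscale(scaled_grads, scale_value)  # applied before the optimizer step
grads_finite = all_finite(restored_grads)            # the flag passed to adjust
```

In a real derived class the same three steps would operate on tensors, and the finiteness flag would be handed to adjust to decide how the scale changes.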
Warning
This is an experimental API that is subject to change or deletion.
Examples
>>> from mindspore.amp import LossScaler, _grad_scale_map, _grad_unscale_map
>>> from mindspore import ops, Parameter, Tensor, mutable
>>> from mindspore.common import dtype as mstype
>>>
>>> class MyLossScaler(LossScaler):
...     def __init__(self, scale_value):
...         self.scale_value = Parameter(Tensor(scale_value, dtype=mstype.float32), name="scale_value")
...         # Growth factor read by adjust(); the value 2 is assumed for illustration.
...         self.scale_factor = 2
...
...     def scale(self, inputs):
...         inputs = mutable(inputs)
...         return _grad_scale_map(self.scale_value, inputs)
...
...     def unscale(self, inputs):
...         inputs = mutable(inputs)
...         return _grad_unscale_map(self.scale_value, inputs)
...
...     def adjust(self, grads_finite):
...         # Grow the scale when gradients are finite; otherwise keep it unchanged.
...         scale_mul_factor = self.scale_value * self.scale_factor
...         scale_value = ops.select(grads_finite, scale_mul_factor, self.scale_value)
...         ops.assign(self.scale_value, scale_value)
...         return True
>>>
>>> loss_scaler = MyLossScaler(1024)
- abstract adjust(grads_finite)
Adjust the scale_value depending on whether the gradients are finite.
- Parameters
grads_finite (Tensor) – A scalar bool Tensor indicating whether the gradients are finite.
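One common way to implement adjust is a dynamic policy: shrink the scale when an overflow is seen, and grow it after a run of consecutive finite steps. The sketch below shows that policy in plain Python with floats instead of tensors. The class name `DynamicScaleState` and the parameters `scale_factor` and `scale_window` are hypothetical names chosen for this illustration, and the policy is an assumption rather than something this abstract method prescribes.

```python
class DynamicScaleState:
    """Hypothetical holder for the loss scale and a finite-step counter."""

    def __init__(self, scale_value, scale_factor=2.0, scale_window=3):
        self.scale_value = float(scale_value)
        self.scale_factor = scale_factor  # assumed growth/shrink factor
        self.scale_window = scale_window  # assumed finite steps before growing
        self.finite_steps = 0

    def adjust(self, grads_finite):
        """Update and return the loss scale for one training step."""
        if grads_finite:
            self.finite_steps += 1
            if self.finite_steps >= self.scale_window:
                # Enough consecutive finite steps: grow the scale.
                self.scale_value *= self.scale_factor
                self.finite_steps = 0
        else:
            # Overflow: shrink the scale (never below 1) and restart the count.
            self.scale_value = max(self.scale_value / self.scale_factor, 1.0)
            self.finite_steps = 0
        return self.scale_value

state = DynamicScaleState(1024)
history = [state.adjust(flag) for flag in (True, True, True, False, True)]
```

Here the scale doubles after three finite steps and is halved by the overflow on the fourth step. A tensor-based version would replace the Python branches with select-style operations so that it can run inside a compiled graph.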