mindspore.nn.DynamicLossScaleUpdateCell
- class mindspore.nn.DynamicLossScaleUpdateCell(loss_scale_value, scale_factor, scale_window)
Dynamic loss scale update cell.
In loss scaling training, the loss scale is initialized to loss_scale_value. At each training step, the loss scale is divided by scale_factor if an overflow occurs, and it is multiplied by scale_factor once scale_window consecutive steps have passed without an overflow. This cell is used for graph mode training, in which all logic is executed on the device side. (The other training mode is the normal, non-sink mode, in which some logic is executed on the host.)
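The update rule described above can be sketched in plain Python. This is only an illustration of the rule as stated here, not the MindSpore implementation: the function name `update_loss_scale` and the `good_steps` counter are hypothetical, and any clamping of the scale that a real implementation may apply is left out.

```python
def update_loss_scale(loss_scale, good_steps, overflow, scale_factor=2, scale_window=1000):
    """Return (new_loss_scale, new_good_steps) after one training step.

    good_steps is a hypothetical counter of consecutive steps without overflow.
    """
    if overflow:
        # Overflow: shrink the scale and restart the count of clean steps.
        return loss_scale / scale_factor, 0
    good_steps += 1
    if good_steps >= scale_window:
        # scale_window consecutive clean steps: grow the scale and restart the count.
        return loss_scale * scale_factor, 0
    # No overflow, window not yet reached: the scale stays the same.
    return loss_scale, good_steps


# One overflow step followed by three clean steps, with a small window for brevity.
scale, good = 4096.0, 0
scale, good = update_loss_scale(scale, good, overflow=True, scale_factor=2, scale_window=3)
after_overflow = scale
for _ in range(3):
    scale, good = update_loss_scale(scale, good, overflow=False, scale_factor=2, scale_window=3)
after_window = scale
```

With these values, the scale halves on the overflow step and returns to its starting value once three clean steps have passed.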
- Parameters
loss_scale_value - The initial loss scale value.
scale_factor - The factor by which the loss scale is divided on overflow and multiplied after scale_window consecutive steps without overflow.
scale_window - The number of consecutive steps without overflow required before the loss scale is increased.
- Inputs:
loss_scale (Tensor) - The loss scale value during training with shape \(()\).
overflow (bool) - Whether the overflow occurs or not.
- Outputs:
bool, the input overflow.
- Raises
TypeError – If dtype of inputs or label is neither float16 nor float32.
- Supported Platforms:
Ascend
GPU
Examples
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, Parameter, nn
>>> import mindspore.ops as ops
>>>
>>> # A single-layer network: one matrix multiplication with a trainable weight.
>>> class Net(nn.Cell):
...     def __init__(self, in_features, out_features):
...         super(Net, self).__init__()
...         self.weight = Parameter(Tensor(np.ones([in_features, out_features]).astype(np.float32)),
...                                 name='weight')
...         self.matmul = ops.MatMul()
...
...     def construct(self, x):
...         output = self.matmul(x, self.weight)
...         return output
...
>>> in_features, out_features = 16, 10
>>> net = Net(in_features, out_features)
>>> loss = nn.MSELoss()
>>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> net_with_loss = nn.WithLossCell(net, loss)
>>> # Pass the update cell as scale_sense so the loss scale is adjusted during training.
>>> manager = nn.DynamicLossScaleUpdateCell(loss_scale_value=2**12, scale_factor=2, scale_window=1000)
>>> train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, scale_sense=manager)
>>> input = Tensor(np.ones([out_features, in_features]), mindspore.float32)
>>> labels = Tensor(np.ones([out_features,]), mindspore.float32)
>>> output = train_network(input, labels)