mindspore.nn.AdaSumByDeltaWeightWrapCell
- class mindspore.nn.AdaSumByDeltaWeightWrapCell(optimizer)
Enables AdaSum in auto_parallel or semi_auto_parallel mode. This implementation of the Adaptive Summation (AdaSum) algorithm operates on the difference between the weights before and after the optimizer update. See the paper AdaSum: Scaling Distributed Training with Adaptive Summation.
\[\begin{split}\begin{array}{ll}
w_{t+1}=w_{t} - \alpha \cdot Adasum(g_{1}, g_{2}) \\
w_{t+1}=w_{t} - \alpha \cdot [(1 - \frac{g_2^{T}\cdot g_1}{2\cdot \left \| g_1 \right \|^2 })\cdot g_1 + (1 - \frac{g_1^{T}\cdot g_2}{2\cdot \left \| g_2 \right \|^2 })\cdot g_2] \\
\end{array}\end{split}\]
In this implementation, \(g\) represents the weight difference before and after the optimizer update, and the subscripts denote different devices in the data parallel dimension.
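The pairwise combination rule above can be sketched in plain NumPy. This is only an illustration of the formula for two update vectors, not MindSpore's implementation: the function name `adasum` and the `eps` guard against zero-norm inputs are assumptions made for this sketch.

```python
import numpy as np

def adasum(g1, g2, eps=1e-12):
    """Combine two update vectors with the pairwise AdaSum rule (illustrative only)."""
    g1 = np.asarray(g1, dtype=np.float64)
    g2 = np.asarray(g2, dtype=np.float64)
    dot = np.dot(g1.ravel(), g2.ravel())   # g1^T . g2
    n1 = np.dot(g1.ravel(), g1.ravel())    # ||g1||^2
    n2 = np.dot(g2.ravel(), g2.ravel())    # ||g2||^2
    # Assumption: when a vector has (near) zero norm, leave its coefficient at 1.
    c1 = 1.0 - dot / (2.0 * n1) if n1 > eps else 1.0
    c2 = 1.0 - dot / (2.0 * n2) if n2 > eps else 1.0
    return c1 * g1 + c2 * g2

# Orthogonal updates are summed; identical updates are averaged.
orthogonal = adasum([1.0, 0.0], [0.0, 1.0])
identical = adasum([1.0, 2.0], [1.0, 2.0])
```

The two calls show the limiting behavior of the rule: when the updates are orthogonal the result is their sum, and when they are identical the result is a single copy of the update (their average).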
Note
When using AdaSum, the number of training cards must be a power of 2, and at least 16 cards are required. Optimizer sharding and pipeline parallelism are currently not supported with AdaSum. AdaSumByDeltaWeightWrapCell is recommended in semi_auto_parallel or auto_parallel mode; in data parallel mode, use mindspore.boost to apply AdaSum instead.
- Parameters
optimizer (Union[Cell]) – Optimizer for updating the weights. The construct function of the optimizer requires only one input.
- Inputs:
grads (Tuple(Tensor)) - Tuple of gradients, the same as the input of the wrapped optimizer.
- Raises
RuntimeError – If parallel_mode is stand_alone. AdaSum is supported only in distributed scenarios.
RuntimeError – If optimizer parallel (optimizer sharding) is enabled when using AdaSum.
RuntimeError – If pipeline parallel is enabled when using AdaSum.
RuntimeError – If device_num is not a power of 2 or is less than 16.
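The device-count condition above can be checked before launching a job. The helper below is a hypothetical pre-flight check written for illustration; it is not part of MindSpore and only mirrors the stated rule (a power of 2, and at least 16).

```python
def is_valid_adasum_device_num(device_num: int) -> bool:
    """Return True if device_num is a power of 2 and at least 16 (hypothetical helper)."""
    is_power_of_two = device_num > 0 and (device_num & (device_num - 1)) == 0
    return is_power_of_two and device_num >= 16

# 8 is a power of 2 but too small; 24 is large enough but not a power of 2.
checks = {n: is_valid_adasum_device_num(n) for n in (8, 16, 24, 32)}
```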
- Supported Platforms:
Ascend
GPU
Examples
>>> from mindspore import nn, Model
>>> from mindspore.nn import AdaSumByDeltaWeightWrapCell
>>> # Net is a user-defined network (an nn.Cell subclass) defined elsewhere.
>>> net = Net()
>>> optim = AdaSumByDeltaWeightWrapCell(nn.Momentum(params=net.trainable_params(),
...                                                 learning_rate=0.1, momentum=0.9))
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)