mindspore.nn.thor

mindspore.nn.thor(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32, use_nesterov=False, decay_filter=lambda x: ..., split_indices=None, enable_clip_grad=False, frequency=100)[source]

Updates gradients using the second-order THOR algorithm.

The Trace-based Hardware-driven layer-ORiented Natural Gradient Descent Computation (THOR) algorithm is proposed in:

THOR: Trace-based Hardware-driven layer-ORiented Natural Gradient Descent Computation

The update formulas are as follows:

\[\begin{split}\begin{array}{ll} \\ A_i = a_i{a_i}^T \\ G_i = D_{s_i}{ D_{s_i}}^T \\ m_i = \beta * m_i + ({G_i^{(k)}}+\lambda I)^{-1} g_i ({\overline A_{i-1}^{(k)}}+\lambda I)^{-1} \\ w_i = w_i - \alpha * m_i \\ \end{array}\end{split}\]

\(D_{s_i}\) represents the derivative of the loss function with respect to the output of the i-th layer, \(a_{i-1}\) represents the input of the i-th layer, which is the activation of the previous layer, \(\beta\) represents the momentum, \(I\) represents the identity matrix, \(\overline A\) represents the transpose of matrix A, \(\lambda\) represents the ‘damping’, \(g_i\) represents the gradients of the i-th layer, \(\otimes\) represents the Kronecker product, and \(\alpha\) represents the ‘learning rate’.
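
The formulas above can be sketched for a single dense layer with NumPy. This is only an illustration of the arithmetic under stated assumptions, not how MindSpore implements the optimizer: the function name `thor_layer_step`, the argument names, and the column-vector shapes are all hypothetical, and the factors are rebuilt from one sample on every call instead of being accumulated and refreshed every `frequency` steps.

```python
import numpy as np


def thor_layer_step(w, m, g, a_prev, d_s, lr, damping, momentum):
    """Apply one update of the formulas above to a single dense layer.

    Illustrative only: names and shapes are assumptions, not MindSpore internals.
    w, m, g : (n_out, n_in) weight, momentum buffer and gradient
    a_prev  : (n_in, 1) input of the layer
    d_s     : (n_out, 1) derivative of the loss w.r.t. the layer output
    """
    A = a_prev @ a_prev.T                      # A = a a^T
    G = d_s @ d_s.T                            # G = D_s D_s^T
    # Damping (lambda * I) keeps the rank-one factors invertible.
    A_inv = np.linalg.inv(A + damping * np.eye(A.shape[0]))
    G_inv = np.linalg.inv(G + damping * np.eye(G.shape[0]))
    m = momentum * m + G_inv @ g @ A_inv       # preconditioned gradient
    w = w - lr * m                             # weight update
    return w, m


rng = np.random.default_rng(0)
n_in, n_out = 3, 2
a_prev = rng.normal(size=(n_in, 1))
d_s = rng.normal(size=(n_out, 1))
g = d_s @ a_prev.T                             # single-sample gradient of a dense layer
w = rng.normal(size=(n_out, n_in))
m = np.zeros_like(w)
w, m = thor_layer_step(w, m, g, a_prev, d_s, lr=0.1, damping=0.03, momentum=0.9)
print(w.shape, m.shape)
```

When both factors are zero and the damping is 1.0, the two inverses reduce to the identity and the step becomes plain momentum SGD, which is a quick sanity check on the sketch.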

Parameters
  • net (Cell) – The training network.

  • learning_rate (Tensor) – A value for the learning rate.

  • damping (Tensor) – A value for the damping.

  • momentum (float) – Momentum for the moving average. It must be at least 0.0.

  • weight_decay (int, float) – Weight decay (L2 penalty). It must be equal to or greater than 0.0. Default: 0.0.

  • loss_scale (float) – A value for the loss scale. It must be greater than 0.0. In general, use the default value. Default: 1.0.

  • batch_size (int) – The size of a batch. Default: 32.

  • use_nesterov (bool) – Enable Nesterov momentum. Default: False.

  • decay_filter (function) – A function that determines which layers weight decay is applied to. It only takes effect when weight_decay > 0. Default: lambda x: x.name not in [].

  • split_indices (list) – Sets the allreduce fusion strategy by A/G layer indices. It only takes effect in distributed computing. Taking ResNet50 as an example, there are 54 layers of A and of G; when split_indices is set to [26, 53], A/G is divided into two allreduce groups, one covering layers 0~26 and the other layers 27~53. Default: None.

  • enable_clip_grad (bool) – Whether to clip the gradients. Default: False.

  • frequency (int) – The update interval of A/G and \(A^{-1}/G^{-1}\). When frequency equals N (N must be greater than 1), A/G and \(A^{-1}/G^{-1}\) are updated every N steps, and the other steps use the stale A/G and \(A^{-1}/G^{-1}\) to update weights. Default: 100.
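
As a rough illustration of how split_indices and frequency are described above, the following plain-Python sketch turns split indices into layer groups and lists the steps on which the factors would be refreshed. The helper names `allreduce_groups` and `refresh_steps` are hypothetical and not part of MindSpore, and the choice of 0-based steps with a refresh whenever `step % frequency == 0` is an assumption; the source only states that the update happens every N steps.

```python
def allreduce_groups(split_indices):
    """Turn split_indices into inclusive (first, last) layer ranges."""
    groups, start = [], 0
    for end in split_indices:
        groups.append((start, end))
        start = end + 1
    return groups


def refresh_steps(total_steps, frequency):
    """List the steps on which A/G and their inverses are recomputed.

    Assumes 0-based steps with a refresh whenever step % frequency == 0.
    """
    if frequency < 2:
        raise ValueError("frequency must be at least 2")
    return [step for step in range(total_steps) if step % frequency == 0]


print(allreduce_groups([26, 53]))   # [(0, 26), (27, 53)]
print(refresh_steps(10, 4))         # [0, 4, 8]
```

With frequency=4, steps 1 to 3 would reuse the factors computed at step 0, which matches the "stale A/G" behaviour described for the parameter.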

Inputs:
  • gradients (tuple[Tensor]) - The gradients of params, with the same shape as params.

Outputs:

tuple[bool], all elements are True.

Raises
  • TypeError – If learning_rate is not Tensor.

  • TypeError – If loss_scale or momentum is not a float.

  • TypeError – If weight_decay is neither float nor int.

  • TypeError – If use_nesterov is not a bool.

  • ValueError – If loss_scale is less than or equal to 0.

  • ValueError – If weight_decay or momentum is less than 0.

  • ValueError – If frequency is not int.

  • ValueError – If frequency is less than 2.

Supported Platforms:

Ascend GPU

Examples

>>> from mindspore.nn import thor
>>> from mindspore import Model
>>> from mindspore import FixedLossScaleManager
>>> from mindspore.train.callback import LossMonitor
>>> from mindspore.train.train_thor import ConvertModelUtils
>>> from mindspore import nn
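>>> from mindspore import Tensor
>>> from mindspore import dtype as mstype
>>>
>>> # Net and create_dataset are placeholders for a user-defined network and dataset pipeline.
>>> net = Net()
>>> dataset = create_dataset()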
>>> temp = Tensor([4e-4, 1e-4, 1e-5, 1e-5], mstype.float32)
>>> optim = thor(net, learning_rate=temp, damping=temp, momentum=0.9, loss_scale=128, frequency=4)
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> loss_scale = FixedLossScaleManager(128, drop_overflow_update=False)
>>> model = Model(net, loss_fn=loss, optimizer=optim, loss_scale_manager=loss_scale, metrics={'acc'},
...               amp_level="O2", keep_batchnorm_fp32=False)
>>> model = ConvertModelUtils.convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=optim,
...                                                 loss_scale_manager=loss_scale, metrics={'acc'},
...                                                 amp_level="O2", keep_batchnorm_fp32=False)
>>> loss_cb = LossMonitor()
>>> model.train(1, dataset, callbacks=loss_cb, sink_size=4, dataset_sink_mode=True)