mindspore.nn.Adagrad

class mindspore.nn.Adagrad(params, accum=0.1, learning_rate=0.001, update_slots=True, loss_scale=1.0, weight_decay=0.0)

Implements the Adagrad algorithm.

Adagrad is an adaptive-gradient algorithm for online learning and stochastic optimization. Refer to the paper Efficient Learning using Forward-Backward Splitting. Adagrad adapts the learning rate of each parameter individually, which helps when different parameters receive gradient updates with very different frequencies (for example, parameters tied to sparse features).
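To make the per-parameter adaptation concrete, the following NumPy sketch (an illustration only, not MindSpore code) tracks the effective step size of each parameter, which is the learning rate divided by the square root of that parameter's accumulated squared gradients. It assumes two hypothetical parameters: one receives a gradient at every step, the other only every fourth step. The accumulator starts at 0.1, matching the default accum.

```python
import numpy as np

def effective_step_sizes(grad_history, lr=0.001, accum=0.1):
    """Return lr / sqrt(accumulated squared gradients) for each parameter.

    grad_history has shape (steps, num_params); a zero entry is treated as
    "no gradient at this step" (an illustrative assumption).
    """
    state_sum = np.full(grad_history.shape[1], accum, dtype=np.float64)
    for g in grad_history:
        state_sum += g ** 2  # accumulate squared gradients per parameter
    return lr / np.sqrt(state_sum)

steps = 20
grads = np.zeros((steps, 2))
grads[:, 0] = 1.0     # parameter 0: gradient at every step
grads[::4, 1] = 1.0   # parameter 1: gradient only every 4th step
sizes = effective_step_sizes(grads)
print(sizes)  # parameter 1 keeps a larger effective step size
```

The rarely updated parameter accumulates less squared gradient, so its effective step size shrinks more slowly.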

The update rule is given by the following pseudocode:

\[\begin{aligned}
&\textbf{Parameters}: \text{learning rate } \gamma, \ \text{params } w_0, \ \text{weight decay } \lambda, \ \text{initial accumulator value } accum \\
&\textbf{Init}: state\_sum_0 \leftarrow accum \\
&\textbf{for} \ t=1 \ \textbf{to} \ \ldots \ \textbf{do} \\
&\hspace{5mm} g_t \leftarrow \nabla_{w} f_t (w_{t-1}) \\
&\hspace{5mm} \textbf{if} \ \lambda \neq 0 \\
&\hspace{10mm} g_t \leftarrow g_t + \lambda w_{t-1} \\
&\hspace{5mm} state\_sum_t \leftarrow state\_sum_{t-1} + g^2_t \\
&\hspace{5mm} w_t \leftarrow w_{t-1} - \gamma \frac{g_t}{\sqrt{state\_sum_t} + \epsilon} \\
&\textbf{return} \ w_t
\end{aligned}\]

\(state\_sum\) is the accumulated sum of squared gradients, whose starting value is given by accum. \(g\) denotes the gradients (grads), \(\lambda\) is weight_decay, \(\gamma\) is learning_rate, and \(w\) denotes params.
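The pseudocode above can be transcribed into a minimal NumPy sketch of one Adagrad step for a single parameter tensor. This is an illustration under stated assumptions, not MindSpore's implementation: the function name adagrad_step is hypothetical, \(\epsilon\) is exposed as an eps argument defaulting to 0.0 because its value is not given here, and update_slots=False is modeled as reusing the accumulator without updating it (also an assumption).

```python
import numpy as np

def adagrad_step(w, state_sum, grad, lr=0.001, weight_decay=0.0,
                 update_slots=True, eps=0.0):
    """One Adagrad step following the pseudocode (illustrative sketch).

    Returns (new_w, new_state_sum); the input arrays are not modified.
    eps defaults to 0.0 here as an assumption.
    """
    g = grad
    if weight_decay != 0:
        g = g + weight_decay * w                  # g_t <- g_t + lambda * w_{t-1}
    if update_slots:
        state_sum = state_sum + g ** 2            # state_sum_t <- state_sum_{t-1} + g_t^2
    w = w - lr * g / (np.sqrt(state_sum) + eps)   # w_t <- w_{t-1} - gamma * g_t / (sqrt(state_sum_t) + eps)
    return w, state_sum

# Usage: the accumulator starts at accum (0.1 by default), not at zero.
w0 = np.array([1.0, -2.0])
s0 = np.full_like(w0, 0.1)
w1, s1 = adagrad_step(w0, s0, grad=np.array([0.3, 0.4]), lr=0.1)
```

Returning new arrays instead of updating in place keeps the sketch easy to test; an optimizer would normally update its parameter and accumulator buffers in place.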

Note

If parameters are not grouped, the optimizer's weight_decay is applied to network parameters whose names do not contain 'beta' or 'gamma'. Users can group parameters to change the weight decay strategy: when parameters are grouped, each group can set its own weight_decay, and a group that does not set one uses the optimizer's weight_decay.
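The default (ungrouped) rule in the note amounts to a name filter. The plain-Python sketch below mirrors only the 'beta'/'gamma' substring rule stated above; the helper params_receiving_decay and the parameter names are hypothetical, and no MindSpore call is made.

```python
def params_receiving_decay(param_names):
    """Return the names that the default weight-decay rule would decay.

    Names containing 'beta' or 'gamma' (typically normalization offsets
    and scales) are excluded; everything else is decayed.
    """
    return [name for name in param_names
            if 'beta' not in name and 'gamma' not in name]

# Hypothetical parameter names for illustration.
names = ['conv1.weight', 'bn1.gamma', 'bn1.beta', 'fc1.weight', 'fc1.bias']
print(params_receiving_decay(names))
# ['conv1.weight', 'fc1.weight', 'fc1.bias']
```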

Parameters
  • params (Union[list[Parameter], list[dict]]) –

    Must be a list of Parameter or a list of dict. When params is a list of dict, the keys "params", "lr", "weight_decay", "grad_centralization" and "order_params" can be parsed.

    • params: Required. Parameters in current group. The value must be a list of Parameter.

    • lr: Optional. If "lr" is in the keys, the corresponding learning rate is used; otherwise, the optimizer's learning_rate is used. Both fixed and dynamic learning rates are supported.

    • weight_decay: Optional. If "weight_decay" is in the keys, the corresponding weight decay is used; otherwise, the optimizer's weight_decay is used. Note that weight decay can be a constant value or a Cell; it is a Cell only when dynamic weight decay is applied. Dynamic weight decay works like a dynamic learning rate: users customize a weight decay schedule that takes only the global step as input, and during training the optimizer calls the instance of WeightDecaySchedule to get the weight decay value of the current step.

    • grad_centralization: Optional. Must be a bool. If "grad_centralization" is in the keys, the given value is used; otherwise, grad_centralization defaults to False. This setting only takes effect on convolution layers.

    • order_params: Optional. When parameters are grouped, this is usually used to keep the order in which parameters appear in the network, which improves performance. The value should be the parameters whose order the optimizer will follow. If "order_params" is in the keys, the other keys are ignored, and each element of 'order_params' must belong to one of the parameter groups.

  • accum (float) – The starting value of the accumulator \(state\_sum\); must be zero or positive. Default: 0.1 .

  • learning_rate (Union[float, int, Tensor, Iterable, LearningRateSchedule]) –

    Default: 0.001 .

    • float: The fixed learning rate value. Must be equal to or greater than 0.

    • int: The fixed learning rate value. Must be equal to or greater than 0. It will be converted to float.

    • Tensor: Its value should be a scalar or a 1-D vector. A scalar gives a fixed learning rate. A vector gives a dynamic learning rate, where the i-th step takes the i-th value as the learning rate.

    • Iterable: Learning rate is dynamic. The i-th step will take the i-th value as the learning rate.

    • LearningRateSchedule: Learning rate is dynamic. During training, the optimizer calls the instance of LearningRateSchedule with the step as input to get the learning rate of the current step.

  • update_slots (bool) – Whether the accumulator \(state\_sum\) is updated. Default: True .

  • loss_scale (float) – Value for the loss scale. It must be greater than 0.0. In general, use the default value. Only when FixedLossScaleManager is used for training and its drop_overflow_update is set to False does this value need to equal the loss_scale in FixedLossScaleManager. Refer to class mindspore.amp.FixedLossScaleManager for more details. Default: 1.0 .

  • weight_decay (Union[float, int, Cell]) –

    Weight decay (L2 penalty). Default: 0.0 .

    • float: The fixed weight decay value. Must be equal to or greater than 0.

    • int: The fixed weight decay value. Must be equal to or greater than 0. It will be converted to float.

    • Cell: Weight decay is dynamic. During training, the optimizer calls the instance of the Cell with the step as input to get the weight decay value of the current step.
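How a group's settings fall back to the optimizer-level values, and how fixed and dynamic values are looked up per step, can be sketched in plain Python. This is a hypothetical illustration rather than MindSpore code: resolve_group, value_at_step and halving_decay are made-up helpers, strings stand in for Parameter objects, a plain callable stands in for a LearningRateSchedule or weight-decay Cell, and "order_params" is not modeled. Because the behavior past the end of a per-step list is not described above, the sketch raises an error in that case instead of guessing.

```python
from numbers import Real

def value_at_step(value, step):
    """Look up a fixed or dynamic hyperparameter at a 0-based step."""
    if isinstance(value, Real):
        return float(value)              # fixed value
    if callable(value):
        return float(value(step))        # schedule called with the step
    values = list(value)                 # per-step list: step i takes the i-th value
    if step >= len(values):
        raise IndexError('step is beyond the provided per-step values')
    return float(values[step])

def resolve_group(group, learning_rate, weight_decay, step):
    """Return the settings one parameter group would use at `step`."""
    return {
        'params': group['params'],
        'lr': value_at_step(group.get('lr', learning_rate), step),
        'weight_decay': value_at_step(group.get('weight_decay', weight_decay), step),
        'grad_centralization': group.get('grad_centralization', False),
    }

def halving_decay(step):
    """Hypothetical dynamic weight decay: 0.01, halved every 100 steps."""
    return 0.01 * 0.5 ** (step // 100)

groups = [{'params': ['conv1.weight'], 'weight_decay': halving_decay, 'grad_centralization': True},
          {'params': ['fc1.weight'], 'lr': [0.01, 0.01, 0.005]}]
at_step_2 = [resolve_group(g, learning_rate=0.1, weight_decay=0.0, step=2) for g in groups]
```

Group settings override the optimizer-level ones key by key, so each dict only needs the keys that differ from the optimizer's arguments.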

Inputs:
  • grads (tuple[Tensor]) - The gradients of params in the optimizer, with the same shapes as the corresponding params.

Outputs:

Tensor[bool], the value is True .

Raises
  • TypeError – If learning_rate is not one of int, float, Tensor, Iterable, LearningRateSchedule.

  • TypeError – If element of parameters is neither Parameter nor dict.

  • TypeError – If accum or loss_scale is not a float.

  • TypeError – If update_slots is not a bool.

  • TypeError – If weight_decay is neither float nor int.

  • ValueError – If loss_scale is less than or equal to 0.

  • ValueError – If accum or weight_decay is less than 0.
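The argument checks listed under Raises can be mirrored by a small plain-Python validator. The function check_adagrad_args is hypothetical and is not MindSpore's own validation code; it covers only accum, update_slots, loss_scale and the numeric form of weight_decay, and rejecting bool where a number is expected is an assumption of this sketch.

```python
def check_adagrad_args(accum=0.1, update_slots=True, loss_scale=1.0, weight_decay=0.0):
    """Raise TypeError/ValueError following the Raises list (illustrative sketch)."""
    for name, value in (('accum', accum), ('loss_scale', loss_scale)):
        if not isinstance(value, float):
            raise TypeError(f'{name} must be a float, got {type(value).__name__}')
    if not isinstance(update_slots, bool):
        raise TypeError('update_slots must be a bool')
    # Only the numeric form of weight_decay is checked here; bool is rejected.
    if isinstance(weight_decay, bool) or not isinstance(weight_decay, (float, int)):
        raise TypeError('weight_decay must be a float or int')
    if loss_scale <= 0:
        raise ValueError('loss_scale must be greater than 0')
    if accum < 0 or weight_decay < 0:
        raise ValueError('accum and weight_decay must not be negative')
```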

Supported Platforms:

Ascend GPU CPU

Examples

>>> from mindspore import train
>>> import mindspore.nn as nn
>>>
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> # 1) All parameters use the same learning rate and weight decay
>>> optim = nn.Adagrad(params=net.trainable_params())
>>>
>>> # 2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01, 'grad_centralization': True},
...                 {'params': no_conv_params, 'lr': 0.01},
...                 {'order_params': net.trainable_params()}]
>>> optim = nn.Adagrad(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # Parameters in conv_params use the default learning rate of 0.1, a weight decay of 0.01
>>> # and grad centralization enabled.
>>> # Parameters in no_conv_params use a learning rate of 0.01, the default weight decay of 0.0
>>> # and grad centralization disabled (the default).
>>> # The optimizer follows the parameter order given by 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = train.Model(net, loss_fn=loss, optimizer=optim)