mindspore.nn.AdamOffload

class mindspore.nn.AdamOffload(params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False, use_nesterov=False, weight_decay=0.0, loss_scale=1.0)

This optimizer offloads the Adam optimizer to the host CPU while keeping the parameter updates on the device, which minimizes device memory cost. Although this adds performance overhead, it allows larger models to be trained.

The Adam algorithm was proposed in Adam: A Method for Stochastic Optimization.

The update formulas are as follows:

\[\begin{split}\begin{array}{ll} \\ m_{t+1} = \beta_1 * m_{t} + (1 - \beta_1) * g \\ v_{t+1} = \beta_2 * v_{t} + (1 - \beta_2) * g * g \\ l = \alpha * \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \\ w_{t+1} = w_{t} - l * \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon} \end{array}\end{split}\]

\(m\) represents the first moment vector moment1, \(v\) represents the second moment vector moment2, \(g\) represents the gradients, \(l\) represents the scaling factor, \(\beta_1, \beta_2\) represent beta1 and beta2, \(t\) represents the update step, \(\beta_1^t\) and \(\beta_2^t\) represent beta1_power and beta2_power, \(\alpha\) represents learning_rate, \(w\) represents params, and \(\epsilon\) represents eps.
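As a cross-check of the formulas above, the following NumPy sketch performs one update step on plain arrays. It is a reference illustration only, not the offloaded MindSpore kernel: the function name adam_step is hypothetical, t is assumed to be the 1-based step count (so beta1**t stands in for beta1_power), and weight decay, loss scaling, use_nesterov, and use_locking are not modelled.

```python
import numpy as np


def adam_step(w, m, v, g, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update following the documented formulas (hypothetical helper).

    t is the 1-based update step, so beta1**t and beta2**t play the role of
    beta1_power and beta2_power. Weight decay, loss scaling, use_nesterov and
    use_locking are deliberately not modelled here.
    """
    m_next = beta1 * m + (1 - beta1) * g                  # m_{t+1}
    v_next = beta2 * v + (1 - beta2) * g * g              # v_{t+1}
    l = lr * np.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)   # scaling factor l
    w_next = w - l * m_next / (np.sqrt(v_next) + eps)     # w_{t+1}
    return w_next, m_next, v_next


# First step from zero-initialized moments.
w = np.array([1.0, -2.0])
m = np.zeros_like(w)
v = np.zeros_like(w)
g = np.array([0.5, -0.1])
w, m, v = adam_step(w, m, v, g, t=1)
```

With zero-initialized moments, the first step moves each weight by almost exactly learning_rate against the sign of its gradient, because the bias-correction factor in \(l\) cancels the \((1-\beta_1)\) and \((1-\beta_2)\) factors.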

Note

Currently, this optimizer supports only GRAPH_MODE.

When parameters are separated into groups, the weight decay of each group is applied to its parameters if that weight decay is positive. When parameters are not grouped, the weight_decay in the API is applied, if positive, to every parameter whose name does not contain ‘beta’ or ‘gamma’.

To improve the performance of parameter groups, a customized parameter order is supported.
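The name-based weight decay rule for ungrouped parameters can be sketched in plain Python. The helpers decay_mask and l2_decayed_grad are hypothetical; the second one assumes the conventional L2-penalty form, in which weight_decay * w is added to the gradient before the Adam update, and is not taken from the MindSpore source.

```python
def decay_mask(param_names, weight_decay):
    """Hypothetical helper: for each name, report whether the API-level
    weight_decay applies when parameters are NOT split into groups."""
    if weight_decay <= 0:
        # Non-positive weight decay is never applied.
        return [False] * len(param_names)
    # Parameters whose names contain 'beta' or 'gamma' are excluded.
    return ["beta" not in name and "gamma" not in name for name in param_names]


def l2_decayed_grad(weight, grad, weight_decay):
    """Assumed conventional L2 form: add weight_decay * weight to the gradient."""
    return grad + weight_decay * weight


names = ["conv1.weight", "bn1.gamma", "bn1.beta", "fc.bias"]
mask = decay_mask(names, 0.01)  # normalization scale/shift are skipped
```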

Parameters
  • params (Union[list[Parameter], list[dict]]) –

    When params is a list of Parameter objects to be updated, every element of params must be of class Parameter. When params is a list of dict, the keys “params”, “lr”, “weight_decay”, and “order_params” can be parsed.

    • params: Required. The value must be a list of Parameter.

    • lr: Optional. If “lr” is in the keys, its value is used as the learning rate of that group. If not, the learning_rate in the API is used.

    • weight_decay: Optional. If “weight_decay” is in the keys, its value is used as the weight decay of that group. If not, the weight_decay in the API is used.

    • order_params: Optional. If “order_params” is in the keys, its value must be the order of parameters, and the optimizer will follow that order. The dict must contain no other keys, and every parameter in ‘order_params’ must belong to one of the parameter groups.

  • learning_rate (Union[float, Tensor, Iterable, LearningRateSchedule]) – A value or a graph for the learning rate. When learning_rate is an Iterable or a 1-D Tensor, a dynamic learning rate is used and the i-th step takes the i-th value as its learning rate. When learning_rate is a LearningRateSchedule, a dynamic learning rate is also used, and the i-th learning rate is computed during training according to the formula of the LearningRateSchedule. When learning_rate is a float or a 0-D Tensor, a fixed learning rate is used. Other cases are not supported. A float learning rate must be greater than or equal to 0. An int learning_rate is converted to float. Default: 1e-3.

  • beta1 (float) – The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0). Default: 0.9.

  • beta2 (float) – The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0). Default: 0.999.

  • eps (float) – Term added to the denominator to improve numerical stability. Should be greater than 0. Default: 1e-8.

  • use_locking (bool) – Whether to enable a lock to protect the variable tensors while they are updated. If True, updates of the var, m, and v tensors are protected by a lock. If False, the result is unpredictable. Default: False.

  • use_nesterov (bool) – Whether to use the Nesterov Accelerated Gradient (NAG) algorithm to update the gradients. If True, the gradients are updated using NAG. If False, they are updated without NAG. Default: False.

  • weight_decay (float) – Weight decay (L2 penalty). It must be equal to or greater than 0. Default: 0.0.

  • loss_scale (float) – A floating point value for the loss scale. Should be greater than 0. In general, use the default value. This value needs to equal the loss_scale in FixedLossScaleManager only when FixedLossScaleManager is used for training and its drop_overflow_update is set to False. Refer to class mindspore.FixedLossScaleManager for more details. Default: 1.0.
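To make the per-group fallback rules for params concrete, here is a plain-Python sketch that resolves each parameter's effective learning rate and weight decay and the final parameter order. The helper resolve_groups is hypothetical, represents parameters by their names, and raises ValueError for a malformed order_params group as an illustrative choice; MindSpore's own checks, error types, and messages may differ.

```python
def resolve_groups(groups, learning_rate, weight_decay):
    """Hypothetical helper: map each parameter name to its effective
    (lr, weight_decay) and return the order the optimizer would follow.

    Parameters are represented by their names (strings) for simplicity.
    """
    settings = {}
    default_order = []
    order = None
    for group in groups:
        if "order_params" in group:
            # An order_params dict must not carry any other key.
            if set(group) != {"order_params"}:
                raise ValueError("order_params dict must not contain other keys")
            order = list(group["order_params"])
            continue
        # Fall back to the API-level values when a group omits a key.
        lr = group.get("lr", learning_rate)
        wd = group.get("weight_decay", weight_decay)
        for name in group["params"]:
            settings[name] = (lr, wd)
            default_order.append(name)
    if order is None:
        order = default_order
    # Every parameter listed in order_params must belong to some group.
    missing = [name for name in order if name not in settings]
    if missing:
        raise ValueError(f"parameters not in any group: {missing}")
    return settings, order


# Mirrors the grouped configuration used in the Examples section.
groups = [
    {"params": ["conv1.weight"], "weight_decay": 0.01},
    {"params": ["fc.weight", "fc.bias"], "lr": 0.01},
    {"order_params": ["fc.weight", "conv1.weight", "fc.bias"]},
]
settings, order = resolve_groups(groups, learning_rate=0.1, weight_decay=0.0)
```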
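The fixed-versus-dynamic rule for learning_rate can likewise be sketched for the float and list cases. lr_at_step is a hypothetical helper: it assumes a 0-based step index, rejects negative fixed values with ValueError as an illustrative choice, and does not model Tensor or LearningRateSchedule inputs.

```python
from numbers import Real


def lr_at_step(learning_rate, step):
    """Hypothetical helper: learning rate used at 0-based step `step`.

    A float (or an int, converted to float) is a fixed learning rate; a list
    or tuple is a dynamic learning rate whose i-th value is used at step i.
    """
    if isinstance(learning_rate, Real):
        lr = float(learning_rate)  # int values are converted to float
        if lr < 0:
            raise ValueError("learning_rate must be >= 0")
        return lr  # fixed: the same value at every step
    if isinstance(learning_rate, (list, tuple)):
        return float(learning_rate[step])  # dynamic: i-th value at step i
    raise TypeError("unsupported learning_rate type in this sketch")
```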

Inputs:
  • gradients (tuple[Tensor]) - The gradients of params, with the same shape as params.

Outputs:

Tensor[bool], with the value True.

Raises
  • TypeError – If learning_rate is not one of int, float, Tensor, Iterable, LearningRateSchedule.

  • TypeError – If an element of parameters is neither Parameter nor dict.

  • TypeError – If beta1, beta2, eps or loss_scale is not a float.

  • TypeError – If weight_decay is neither float nor int.

  • TypeError – If use_locking or use_nesterov is not a bool.

  • ValueError – If loss_scale or eps is less than or equal to 0.

  • ValueError – If beta1 or beta2 is not in range (0.0, 1.0).

  • ValueError – If weight_decay is less than 0.
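The argument checks listed under Raises can be restated as a short plain-Python function. validate_adam_args is hypothetical and is not MindSpore's validator; the order of the checks (and therefore which error wins when several apply) is an assumption, and the learning_rate and params checks are omitted.

```python
def validate_adam_args(beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
                       use_nesterov=False, weight_decay=0.0, loss_scale=1.0):
    """Hypothetical restatement of the scalar-argument checks under Raises."""
    # TypeError: beta1, beta2, eps and loss_scale must be floats.
    for name, value in (("beta1", beta1), ("beta2", beta2),
                        ("eps", eps), ("loss_scale", loss_scale)):
        if not isinstance(value, float):
            raise TypeError(f"{name} must be a float, got {type(value).__name__}")
    # TypeError: weight_decay must be a float or an int.
    if not isinstance(weight_decay, (float, int)):
        raise TypeError("weight_decay must be a float or an int")
    # TypeError: the two flags must be bools.
    for name, value in (("use_locking", use_locking), ("use_nesterov", use_nesterov)):
        if not isinstance(value, bool):
            raise TypeError(f"{name} must be a bool")
    # ValueError: loss_scale and eps must be strictly positive.
    if loss_scale <= 0 or eps <= 0:
        raise ValueError("loss_scale and eps must be greater than 0")
    # ValueError: beta1 and beta2 must lie in the open interval (0.0, 1.0).
    for name, value in (("beta1", beta1), ("beta2", beta2)):
        if not 0.0 < value < 1.0:
            raise ValueError(f"{name} must be in range (0.0, 1.0)")
    # ValueError: weight_decay must be non-negative.
    if weight_decay < 0:
        raise ValueError("weight_decay must be greater than or equal to 0")
```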

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore.nn as nn
>>> from mindspore import Model, context
>>> # AdamOffload currently supports only GRAPH_MODE.
>>> context.set_context(mode=context.GRAPH_MODE)
>>>
>>> # A minimal stand-in for a user-defined network, assumed to take 1x28x28 inputs.
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.conv = nn.Conv2d(1, 6, 5)
...         self.flatten = nn.Flatten()
...         self.fc = nn.Dense(6 * 28 * 28, 10)
...     def construct(self, x):
...         return self.fc(self.flatten(self.conv(x)))
...
>>> net = Net()
>>> # 1) All parameters use the same learning rate and weight decay
>>> optim = nn.AdamOffload(params=net.trainable_params())
>>>
>>> # 2) Use parameter groups and set different values
>>> conv_params = list(filter(lambda x: 'conv' in x.name, net.trainable_params()))
>>> no_conv_params = list(filter(lambda x: 'conv' not in x.name, net.trainable_params()))
>>> group_params = [{'params': conv_params, 'weight_decay': 0.01},
...                 {'params': no_conv_params, 'lr': 0.01},
...                 {'order_params': net.trainable_params()}]
>>> optim = nn.AdamOffload(group_params, learning_rate=0.1, weight_decay=0.0)
>>> # The parameters in conv_params use the default learning rate of 0.1 and a weight decay of 0.01.
>>> # The parameters in no_conv_params use a learning rate of 0.01 and the default weight decay of 0.0.
>>> # The optimizer follows the parameter order given by 'order_params'.
>>>
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> model = Model(net, loss_fn=loss, optimizer=optim)