mindspore.ops.SparseApplyAdagradV2
- class mindspore.ops.SparseApplyAdagradV2(lr, epsilon, use_locking=False, update_slots=True)
Updates relevant entries according to the Adagrad scheme. Compared with SparseApplyAdagrad, this operator has an additional epsilon attribute.
\[
\begin{array}{l}
accum += grad * grad \\
var -= lr * grad * \frac{1}{\sqrt{accum} + \epsilon}
\end{array}
\]
where \(\epsilon\) represents epsilon.
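To make the update concrete, the following pure-Python sketch applies the formula row by row to nested lists. It is an illustration, not MindSpore's kernel: the function name sparse_adagrad_v2_reference is hypothetical, treating update_slots=False as "skip the accum update" is an assumption, and duplicate entries in indices are simply applied one after another, which may differ from how the real operator handles them.

```python
import math


def sparse_adagrad_v2_reference(var, accum, grad, indices, lr, epsilon,
                                update_slots=True):
    """Illustrative row-wise Adagrad update (hypothetical helper, not MindSpore's kernel).

    var and accum are lists of rows that are modified in place; grad holds one
    row per entry of indices. Rows not named in indices are left unchanged.
    """
    for g_row, idx in zip(grad, indices):
        v_row, a_row = var[idx], accum[idx]
        for j, g in enumerate(g_row):
            if update_slots:  # assumption: update_slots=False skips this step
                a_row[j] += g * g  # accum += grad * grad
            # var -= lr * grad * 1 / (sqrt(accum) + epsilon)
            v_row[j] -= lr * g / (math.sqrt(a_row[j]) + epsilon)
    return var, accum


# Usage: update row 0 of a two-row var; row 1 is not listed in indices.
var = [[0.2], [1.0]]
accum = [[0.1], [0.5]]
sparse_adagrad_v2_reference(var, accum, grad=[[0.7]], indices=[0], lr=1e-8, epsilon=1e-6)
print(var, accum)
```

Only the rows named by indices change, which is what makes the update "sparse": the cost scales with the number of gradient rows, not with the size of var.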
The inputs var, accum and grad follow the implicit type conversion rules to keep their data types consistent: if they have different data types, the lower-priority data type is converted to the highest-priority data type among them.
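- Parameters
lr (float) - Learning rate, \(lr\) in the update formula.
epsilon (float) - A small value added to the denominator for numerical stability, \(\epsilon\) in the update formula.
use_locking (bool) - If True, the updates of var and accum are protected by a lock. Default: False.
update_slots (bool) - Whether to update accum in addition to var. Default: True.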
- Inputs:
var (Parameter) - Variable to be updated. The data type must be float16 or float32. The shape is \((N, *)\), where \(*\) means any number of additional dimensions.
accum (Parameter) - Accumulation to be updated. The shape and data type must be the same as var.
grad (Tensor) - Gradient, with the same data type as var. If the rank of var is greater than 1, \(grad.shape[1:] = var.shape[1:]\).
indices (Tensor) - A vector of indices into the first dimension of var and accum. The type must be int32 and \(indices.shape[0] = grad.shape[0]\).
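The shape constraints above can be summarized as a small check. This is a hypothetical helper (check_sparse_adagrad_v2_shapes) that works on plain shape tuples; it raises ValueError purely for illustration, and the exception types and messages MindSpore actually produces may differ.

```python
def check_sparse_adagrad_v2_shapes(var_shape, accum_shape, grad_shape, indices_shape):
    """Hypothetical check of the documented input shape constraints."""
    var_shape, accum_shape = tuple(var_shape), tuple(accum_shape)
    grad_shape, indices_shape = tuple(grad_shape), tuple(indices_shape)
    # accum must match var exactly.
    if accum_shape != var_shape:
        raise ValueError(f"accum shape {accum_shape} != var shape {var_shape}")
    # indices is a vector with one entry per row of grad.
    if len(indices_shape) != 1:
        raise ValueError("indices must be a 1-D vector")
    if indices_shape[0] != grad_shape[0]:
        raise ValueError("indices.shape[0] must equal grad.shape[0]")
    # For var of rank > 1, every grad row must look like a row of var.
    if len(var_shape) > 1 and grad_shape[1:] != var_shape[1:]:
        raise ValueError("grad.shape[1:] must equal var.shape[1:]")


# Usage: update 3 of the 5 rows of a (5, 4) var.
check_sparse_adagrad_v2_shapes((5, 4), (5, 4), (3, 4), (3,))  # passes silently
```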
- Outputs:
Tuple of 2 tensors, the updated parameters.
var (Tensor) - The same shape and data type as var.
accum (Tensor) - The same shape and data type as accum.
- Raises
TypeError – If lr or epsilon is not a float.
TypeError – If update_slots or use_locking is not a bool.
TypeError – If the dtype of var, accum or grad is neither float16 nor float32.
TypeError – If the dtype of indices is not int32.
RuntimeError – If the implicit data type conversion needed to make var, accum and grad consistent is not supported for a Parameter.
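The attribute and dtype rules in the Raises list amount to the checks sketched below: a TypeError is raised as soon as either of lr and epsilon is not a float, and likewise for the two bool flags. The helper names validate_attrs and validate_dtypes, and the use of dtype strings such as "float32", are assumptions for illustration; MindSpore performs its own validation internally with its own messages.

```python
FLOAT_DTYPES = ("float16", "float32")


def validate_attrs(lr, epsilon, use_locking=False, update_slots=True):
    """Hypothetical check of the operator attributes listed under Raises."""
    # Raised if *either* lr or epsilon is not a float.
    for name, value in (("lr", lr), ("epsilon", epsilon)):
        if not isinstance(value, float):
            raise TypeError(f"{name} must be a float, got {type(value).__name__}")
    # Raised if *either* flag is not a bool.
    for name, value in (("use_locking", use_locking), ("update_slots", update_slots)):
        if not isinstance(value, bool):
            raise TypeError(f"{name} must be a bool, got {type(value).__name__}")


def validate_dtypes(var_dtype, accum_dtype, grad_dtype, indices_dtype):
    """Hypothetical dtype check; dtypes are given as plain strings here."""
    for name, dtype in (("var", var_dtype), ("accum", accum_dtype), ("grad", grad_dtype)):
        if dtype not in FLOAT_DTYPES:
            raise TypeError(f"{name} dtype must be float16 or float32, got {dtype}")
    if indices_dtype != "int32":
        raise TypeError(f"indices dtype must be int32, got {indices_dtype}")


# Usage: the attributes used in the Examples section are accepted...
validate_attrs(lr=1e-8, epsilon=1e-6)
# ...while an integer learning rate is rejected.
try:
    validate_attrs(lr=1, epsilon=1e-6)
except TypeError as err:
    print(err)  # lr must be a float, got int
```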
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, Parameter, nn, ops
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.sparse_apply_adagrad_v2 = ops.SparseApplyAdagradV2(lr=1e-8, epsilon=1e-6)
...         self.var = Parameter(Tensor(np.array([[0.2]]).astype(np.float32)), name="var")
...         self.accum = Parameter(Tensor(np.array([[0.1]]).astype(np.float32)), name="accum")
...
...     def construct(self, grad, indices):
...         # Update the rows of var and accum selected by indices.
...         out = self.sparse_apply_adagrad_v2(self.var, self.accum, grad, indices)
...         return out
...
>>> net = Net()
>>> grad = Tensor(np.array([[0.7]]).astype(np.float32))
>>> indices = Tensor(np.array([0]), mindspore.int32)
>>> output = net(grad, indices)
>>> print(output)
(Tensor(shape=[1, 1], dtype=Float32, value=
[[ 1.99999988e-01]]), Tensor(shape=[1, 1], dtype=Float32, value=
[[ 5.89999974e-01]]))
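The printed values in the example can be reproduced by hand. The sketch below recomputes the single-row update in plain Python, rounding every intermediate result to float32 with the standard struct module. The helper f32 is hypothetical, and the assumption that the operator rounds at exactly these steps is not guaranteed, but the margins here are wide enough that the digits agree with the output above.

```python
import math
import struct


def f32(x):
    """Round a Python float to the nearest float32 value (hypothetical helper)."""
    return struct.unpack("f", struct.pack("f", x))[0]


# Values from the example, stored as float32.
lr, epsilon = f32(1e-8), f32(1e-6)
var, accum, grad = f32(0.2), f32(0.1), f32(0.7)

# accum += grad * grad, rounding each intermediate to float32
accum = f32(accum + f32(grad * grad))
# var -= lr * grad / (sqrt(accum) + epsilon)
denom = f32(f32(math.sqrt(accum)) + epsilon)
var = f32(var - f32(f32(lr * grad) / denom))

print(f"{var:.8e} {accum:.8e}")  # 1.99999988e-01 5.89999974e-01
```

With lr = 1e-8 the change to var (about 9e-9) is smaller than one float32 step near 0.2, so var moves down by a single unit in the last place, while accum grows by 0.7² = 0.49.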