mindspore.ops.grad

mindspore.ops.grad(fn, grad_position=0, weights=None, has_aux=False)

A wrapper function to generate the gradient function for the input function.

Three typical cases are supported:

  1. gradient with respect to inputs. In this case, grad_position is not None while weights is None.

  2. gradient with respect to weights. In this case, grad_position is None while weights is not None.

  3. gradient with respect to inputs and weights. In this case, grad_position and weights are not None.

Parameters
  • fn (Union(Cell, function)) – The function or Cell to differentiate.

  • grad_position (Union(NoneType, int, tuple[int])) – If int, get the gradient with respect to a single input. If tuple, get the gradients with respect to the selected inputs. Indexing of grad_position starts at 0. If None, no gradient with respect to any input is computed, and in this case weights is required. Default: 0.

  • weights (Union(ParameterTuple, Parameter, list(Parameter))) – The parameters of the training network whose gradients are to be computed. They can be obtained with weights = net.trainable_params(). Default: None.

  • has_aux (bool) – If True, only the first output of fn contributes to the gradient, and the other outputs are returned unchanged. fn must therefore return more than one output in this case. This is an experimental feature and is subject to change. Default: False.

Returns

Function, the gradient function that computes the gradient of the input function or Cell.

For example, given out1, out2 = fn(*args), the gradient function returns outputs like (gradient, out2) when has_aux is True, and only gradient otherwise.
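The return convention described above can be tried out without MindSpore using the sketch below. It is only an illustration and not the implementation of mindspore.ops.grad: numeric_grad is a hypothetical helper that approximates derivatives with central finite differences on plain Python floats, and it assumes the auxiliary outputs come back as a tuple, as in the Examples section. It has no weights argument.

```python
# Illustrative stand-in only: this is NOT mindspore.ops.grad. It mimics the
# documented return convention using central finite differences on floats.

def numeric_grad(fn, grad_position=0, has_aux=False, eps=1e-6):
    """Return a function approximating d(fn)/d(args[grad_position])."""
    single = isinstance(grad_position, int)
    positions = (grad_position,) if single else tuple(grad_position)

    def primary(*args):
        out = fn(*args)
        # With has_aux=True only the first output is differentiated.
        return out[0] if has_aux else out

    def grad_fn(*args):
        grads = []
        for pos in positions:
            hi = list(args)
            lo = list(args)
            hi[pos] += eps
            lo[pos] -= eps
            grads.append((primary(*hi) - primary(*lo)) / (2 * eps))
        # int position -> one gradient; tuple of positions -> tuple of gradients
        gradient = grads[0] if single else tuple(grads)
        if has_aux:
            # Assumption: remaining outputs are passed through as a tuple.
            return gradient, tuple(fn(*args)[1:])
        return gradient

    return grad_fn


def fn(x, y, z):
    # out1 is differentiated; out2 (here z) is an auxiliary output.
    return x * y * z * z, z


# Tuple grad_position with has_aux=True: (tuple of gradients, aux outputs).
gradient, aux = numeric_grad(fn, grad_position=(0, 1), has_aux=True)(3.0, 2.0, 5.0)

# Default grad_position=0 with has_aux=False: a single gradient.
single_gradient = numeric_grad(lambda x, y: x * y)(3.0, 2.0)
```

Here gradient holds approximations of y * z**2 and x * z**2, aux carries the untouched second output, and single_gradient approximates y.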

Raises
  • ValueError – If both grad_position and weights are None.

  • TypeError – If the type of any argument is not one of the required types.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> import mindspore
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, ops
>>> from mindspore.ops import grad
>>>
>>> # Cell object to be differentiated
>>> class Net(nn.Cell):
...     def construct(self, x, y, z):
...         return x * y * z
>>> x = Tensor([1, 2], mindspore.float32)
>>> y = Tensor([-2, 3], mindspore.float32)
>>> z = Tensor([0, 3], mindspore.float32)
>>> net = Net()
>>> output = grad(net, grad_position=(1, 2))(x, y, z)
>>> print(output)
(Tensor(shape=[2], dtype=Float32, value=[ 0.00000000e+00,  6.00000000e+00]),
 Tensor(shape=[2], dtype=Float32, value=[-2.00000000e+00,  6.00000000e+00]))
>>>
>>> # Function object to be differentiated
>>> def fn(x, y, z):
...     res = x * ops.exp(y) * ops.pow(z, 2)
...     return res, z
>>> x = Tensor([3, 3], mindspore.float32)
>>> y = Tensor([0, 0], mindspore.float32)
>>> z = Tensor([5, 5], mindspore.float32)
>>> gradient, aux = grad(fn, (1, 2), None, True)(x, y, z)
>>> print(gradient)
(Tensor(shape=[2], dtype=Float32, value= [ 7.50000000e+01,  7.50000000e+01]),
 Tensor(shape=[2], dtype=Float32, value= [ 3.00000000e+01,  3.00000000e+01]))
>>> print(aux)
(Tensor(shape=[2], dtype=Float32, value= [ 5.00000000e+00,  5.00000000e+00]),)
>>>
>>> # A network can be differentiated with respect to its inputs, its weights, or both, giving 3 cases.
>>> net = nn.Dense(10, 1)
>>> loss_fn = nn.MSELoss()
>>> def forward(inputs, labels):
...     logits = net(inputs)
...     loss = loss_fn(logits, labels)
...     return loss, logits
>>> inputs = Tensor(np.random.randn(16, 10).astype(np.float32))
>>> labels = Tensor(np.random.randn(16, 1).astype(np.float32))
>>> weights = net.trainable_params()
>>> # Case 1: gradient with respect to inputs.
>>> grad_fn = grad(forward, grad_position=0, weights=None, has_aux=True)
>>> inputs_gradient, (aux_logits,) = grad_fn(inputs, labels)
>>> print(aux_logits.shape)
(16, 1)
>>>
>>> # Case 2: gradient with respect to weights.
>>> grad_fn = grad(forward, grad_position=None, weights=weights, has_aux=True)
>>> params_gradient, (aux_logits,) = grad_fn(inputs, labels)
>>> print(aux_logits.shape)
(16, 1)
>>> print(len(weights), len(params_gradient))
2 2
>>>
>>> # Case 3: gradient with respect to inputs and weights.
>>> grad_fn = grad(forward, grad_position=0, weights=weights, has_aux=False)
>>> inputs_gradient, params_gradient = grad_fn(inputs, labels)
>>> print(len(weights), len(params_gradient))
2 2