mindspore.grad

mindspore.grad(fn, grad_position=0, weights=None, has_aux=False)

A wrapper that generates a gradient function for the input function or Cell.

Three typical cases are supported:

  1. Gradient with respect to inputs: grad_position is not None and weights is None.

  2. Gradient with respect to weights: grad_position is None and weights is not None.

  3. Gradient with respect to both inputs and weights: neither grad_position nor weights is None.
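How the three cases shape the result can be modelled in plain Python. The sketch below is a hypothetical toy, not MindSpore's implementation: `toy_grad`, `model`, `params`, `_diff` and `EPS` are illustrative names, gradients are approximated by central finite differences on scalar floats, and only a single int grad_position is modelled. It returns the input gradient alone, the weight gradients alone, or both, and rejects the case where both selectors are None.

```python
# Hypothetical pure-Python model of the three cases; NOT MindSpore's implementation.
# Gradients are approximated with central finite differences on scalar floats.
EPS = 1e-6
params = {"w": 3.0, "b": 1.0}  # stand-in for a network's trainable parameters


def model(x):
    # model(x) = w * x^2 + b, reading its "weights" from params
    return params["w"] * x * x + params["b"]


def _diff(f, v):
    """Central-difference derivative of a scalar function f at v."""
    return (f(v + EPS) - f(v - EPS)) / (2 * EPS)


def toy_grad(fn, grad_position=0, weights=None):
    # Only a single int grad_position is modelled here (no tuples).
    if grad_position is None and weights is None:
        raise ValueError("grad_position and weights cannot both be None")

    def grad_fn(*args):
        input_grad = None
        if grad_position is not None:
            def along_input(v):
                moved = list(args)
                moved[grad_position] = v
                return fn(*moved)
            input_grad = _diff(along_input, args[grad_position])

        weight_grads = None
        if weights is not None:
            grads = []
            for name in weights:
                saved = params[name]

                def along_weight(v):
                    params[name] = v
                    return fn(*args)
                grads.append(_diff(along_weight, saved))
                params[name] = saved  # restore the original parameter value
            weight_grads = tuple(grads)

        if weight_grads is None:
            return input_grad              # case 1: inputs only
        if input_grad is None:
            return weight_grads            # case 2: weights only
        return input_grad, weight_grads    # case 3: inputs and weights
    return grad_fn


print(toy_grad(model)(2.0))                     # d/dx = 2*w*x, about 12
print(toy_grad(model, None, ["w", "b"])(2.0))   # (d/dw, d/db) = (x^2, 1), about (4, 1)
print(toy_grad(model, 0, ["w", "b"])(2.0))      # both of the above
```

In this toy, the weight gradients come back as one entry per selected parameter, in the order given.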

Parameters
  • fn (Union[Cell, Function]) – The function or Cell to be differentiated.

  • grad_position (Union[NoneType, int, tuple[int]]) – Index of the input(s) to differentiate with respect to, starting from 0. If int, the gradient with respect to that single input is computed. If tuple, the gradients with respect to the selected inputs are computed. If None, no gradients with respect to inputs are computed, and weights must then be provided. Default: 0.

  • weights (Union[ParameterTuple, Parameter, list[Parameter]]) – The parameters of the training network whose gradients are required. They can be obtained with weights = net.trainable_params(). Default: None.

  • has_aux (bool) – If True, only the first output of fn contributes to the gradient; the remaining outputs are returned directly without being differentiated. In this case, fn must return more than one output. Default: False.
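The has_aux behavior can be illustrated with a small pure-Python sketch. It is a hypothetical model, not MindSpore code: `toy_grad_with_aux` and `f` are illustrative names, and the gradient is taken only with respect to the first argument by central finite differences. Only the first output is differentiated; the other outputs are passed through as a tuple, mirroring the (gradient, aux) structure described for has_aux=True.

```python
# Hypothetical sketch of has_aux semantics; not MindSpore's implementation.
def toy_grad_with_aux(fn, eps=1e-6):
    """Differentiate only fn's first output w.r.t. its first argument.

    The remaining outputs are returned unchanged as a tuple (the "aux" part).
    """
    def grad_fn(x, *rest):
        def first_output(v):
            return fn(v, *rest)[0]
        gradient = (first_output(x + eps) - first_output(x - eps)) / (2 * eps)
        aux = tuple(fn(x, *rest)[1:])  # not differentiated
        return gradient, aux
    return grad_fn


def f(x):
    # Two outputs: the first is differentiated, the second is auxiliary.
    return x ** 3, x + 100.0


gradient, aux = toy_grad_with_aux(f)(2.0)
print(gradient)  # d(x^3)/dx = 3x^2, about 12 at x = 2
print(aux)       # (102.0,)
```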

Returns

Function, the gradient function that computes the gradient of the input function or Cell. For example, given out1, out2 = fn(*args): when has_aux is True, the gradient function returns (gradient, (out2,)), where out2 does not contribute to the differentiation; otherwise, it returns only gradient.
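The shape of the returned gradient also follows grad_position: an int selects one input and yields one gradient, while a tuple yields a tuple with one gradient per selected input. The pure-Python sketch below is a hypothetical illustration, not MindSpore code; `partials` and `f` are made-up names, and the partial derivatives are approximated by central finite differences on scalar floats. The inputs (2.0, 3.0, 3.0) match the second element of the x * y * z example.

```python
# Hypothetical helper mirroring how grad_position shapes the result; not MindSpore code.
def partials(fn, args, grad_position, eps=1e-6):
    """Central-difference partial derivatives of a scalar fn(*args).

    An int grad_position yields a single value; a tuple yields a tuple,
    one entry per selected input. Positions are 0-based.
    """
    def one(i):
        hi, lo = list(args), list(args)
        hi[i] += eps
        lo[i] -= eps
        return (fn(*hi) - fn(*lo)) / (2 * eps)

    if isinstance(grad_position, int):
        return one(grad_position)
    return tuple(one(i) for i in grad_position)


def f(x, y, z):
    return x * y * z


single = partials(f, (2.0, 3.0, 3.0), 0)       # df/dx = y*z, about 9.0
pair = partials(f, (2.0, 3.0, 3.0), (1, 2))    # (df/dy, df/dz) = (x*z, x*y), about (6.0, 6.0)
print(single, pair)
```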

Raises
  • ValueError – If both grad_position and weights are None.

  • TypeError – If an argument is not of the required type.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> import mindspore
>>> import mindspore.nn as nn
>>> from mindspore import Tensor, ops
>>> from mindspore import grad
>>>
>>> # Cell object to be differentiated
>>> class Net(nn.Cell):
...     def construct(self, x, y, z):
...         return x * y * z
>>> x = Tensor([1, 2], mindspore.float32)
>>> y = Tensor([-2, 3], mindspore.float32)
>>> z = Tensor([0, 3], mindspore.float32)
>>> net = Net()
>>> output = grad(net, grad_position=(1, 2))(x, y, z)
>>> print(output)
(Tensor(shape=[2], dtype=Float32, value=[ 0.00000000e+00,  6.00000000e+00]),
 Tensor(shape=[2], dtype=Float32, value=[-2.00000000e+00,  6.00000000e+00]))
>>>
>>> # Function object to be differentiated
>>> def fn(x, y, z):
...     res = x * ops.exp(y) * ops.pow(z, 2)
...     return res, z
>>> x = Tensor([3, 3], mindspore.float32)
>>> y = Tensor([0, 0], mindspore.float32)
>>> z = Tensor([5, 5], mindspore.float32)
>>> gradient, aux = grad(fn, (1, 2), None, True)(x, y, z)
>>> print(gradient)
(Tensor(shape=[2], dtype=Float32, value= [ 7.50000000e+01,  7.50000000e+01]),
 Tensor(shape=[2], dtype=Float32, value= [ 3.00000000e+01,  3.00000000e+01]))
>>> print(aux)
(Tensor(shape=[2], dtype=Float32, value= [ 5.00000000e+00,  5.00000000e+00]),)
>>>
>>> # A given network can be differentiated with respect to inputs, weights, or both: 3 cases.
>>> net = nn.Dense(10, 1)
>>> loss_fn = nn.MSELoss()
>>> def forward(inputs, labels):
...     logits = net(inputs)
...     loss = loss_fn(logits, labels)
...     return loss, logits
>>> inputs = Tensor(np.random.randn(16, 10).astype(np.float32))
>>> labels = Tensor(np.random.randn(16, 1).astype(np.float32))
>>> weights = net.trainable_params()
>>> # Case 1: gradient with respect to inputs.
>>> # Aux value does not contribute to the gradient.
>>> grad_fn = grad(forward, grad_position=(0, 1), weights=None, has_aux=True)
>>> inputs_gradient, (aux_logits,) = grad_fn(inputs, labels)
>>> print(len(inputs_gradient))
2
>>> print(aux_logits.shape)
(16, 1)
>>>
>>> # Case 2: gradient with respect to weights.
>>> grad_fn = grad(forward, grad_position=None, weights=weights, has_aux=True)
>>> params_gradient, (aux_logits,) = grad_fn(inputs, labels)
>>> print(len(weights), len(params_gradient))
2 2
>>> print(aux_logits.shape)
(16, 1)
>>>
>>> # Case 3: gradient with respect to inputs and weights.
>>> grad_fn = grad(forward, grad_position=0, weights=weights, has_aux=False)
>>> inputs_gradient, params_gradient = grad_fn(inputs, labels)
>>> print(len(weights), len(params_gradient))
2 2
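The gradients printed in the first two examples can be checked by hand with the analytic partial derivatives. The sketch below uses plain Python lists in place of Tensors and does not call MindSpore; it simply recomputes each gradient elementwise. For x * y * z, the partials with respect to y and z are x * z and x * y. For res = x * exp(y) * z^2, they are x * exp(y) * z^2 and 2 * x * exp(y) * z.

```python
import math

# Example 1: f = x*y*z with grad_position=(1, 2) -> (df/dy, df/dz) = (x*z, x*y), elementwise.
x, y, z = [1.0, 2.0], [-2.0, 3.0], [0.0, 3.0]
grad_y = [a * c for a, c in zip(x, z)]
grad_z = [a * b for a, b in zip(x, y)]

# Example 2: res = x * exp(y) * z**2; only res is differentiated (has_aux=True).
x2, y2, z2 = [3.0, 3.0], [0.0, 0.0], [5.0, 5.0]
grad_y2 = [a * math.exp(b) * c ** 2 for a, b, c in zip(x2, y2, z2)]  # x*exp(y)*z^2
grad_z2 = [2 * a * math.exp(b) * c for a, b, c in zip(x2, y2, z2)]   # 2*x*exp(y)*z

print(grad_y, grad_z)    # [0.0, 6.0] [-2.0, 6.0]
print(grad_y2, grad_z2)  # [75.0, 75.0] [30.0, 30.0]
```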