mindspore.value_and_grad
- mindspore.value_and_grad(fn, grad_position=0, weights=None, has_aux=False, return_ids=False)
A wrapper function that generates a function to compute both the forward output and the gradient of the input function.
As for the gradient, three typical cases are included (a minimal sketch of the three call patterns follows this list):

- Gradient with respect to inputs. In this case, grad_position is not None while weights is None.
- Gradient with respect to weights. In this case, grad_position is None while weights is not None.
- Gradient with respect to both inputs and weights. In this case, neither grad_position nor weights is None.
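A minimal sketch of the three combinations, assuming a differentiable Cell or function net and params = net.trainable_params() (these names are illustrative, not part of the API):

>>> # Case 1: gradient with respect to inputs only (weights stays None)
>>> grad_fn = mindspore.value_and_grad(net, grad_position=0, weights=None)
>>> # Case 2: gradient with respect to weights only (grad_position is None)
>>> grad_fn = mindspore.value_and_grad(net, grad_position=None, weights=params)
>>> # Case 3: gradient with respect to both inputs and weights
>>> grad_fn = mindspore.value_and_grad(net, grad_position=0, weights=params)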
- Parameters
fn (Union[Cell, Function]) – Function to do GradOperation.
grad_position (Union[NoneType, int, tuple[int]], optional) – Index to specify which inputs to be differentiated. Default: 0.
If int, get the gradient with respect to a single input.
If tuple, get the gradients with respect to the selected inputs. grad_position begins with 0.
If None, no derivative of any input will be computed; in this case, weights is required.
weights (Union[ParameterTuple, Parameter, list[Parameter]], optional) – The parameters of the training network whose gradients need to be calculated. weights can be obtained through weights = net.trainable_params(). Default: None.
has_aux (bool, optional) – If True, only the first output of fn contributes to the gradient of fn, while the other outputs are returned directly. This means fn must return more than one output in this case. Default: False.
return_ids (bool, optional) – Whether the returned derivative function contains grad_position or weights information. If True, every gradient value in the returned derivative function is replaced with the pair [gradient, grad_position] or [gradient, weights]; see the sketch after this parameter list. Default: False.
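A hedged sketch of return_ids=True (the exact printed form is not reproduced here; per the description above, each gradient is paired with the position index or Parameter it corresponds to). x, y, z and net are the same objects as in the Examples section below:

>>> grad_fn = mindspore.value_and_grad(net, grad_position=(0, 1), return_ids=True)
>>> output, inputs_gradient = grad_fn(x, y, z)
>>> # inputs_gradient is expected to hold, for each requested position, the gradient
>>> # together with its grad_position index (e.g. [gradient_for_x, 0]), rather than
>>> # the bare gradient tensor returned when return_ids=False.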
- Returns
Function, the derivative function used to compute the gradient of a given function. For example, for out1, out2 = fn(*args), the gradient function will return outputs like ((out1, out2), gradient). When has_aux is set to True, only out1 contributes to the differentiation. If return_ids is True, all gradient values in the returned derivative function are replaced with the pair [gradient, grad_position] or [gradient, weights].
- Raises
ValueError – If both grad_position and weights are None.
TypeError – If the type of an argument does not belong to the required ones.
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops, nn
>>> from mindspore import value_and_grad
>>>
>>> # Cell object to be differentiated
>>> class Net(nn.Cell):
...     def construct(self, x, y, z):
...         return x * y * z
>>> x = Tensor([1, 2], mindspore.float32)
>>> y = Tensor([-2, 3], mindspore.float32)
>>> z = Tensor([0, 3], mindspore.float32)
>>> net = Net()
>>> grad_fn = value_and_grad(net, grad_position=1)
>>> output, inputs_gradient = grad_fn(x, y, z)
>>> print(output)
[-0. 18.]
>>> print(inputs_gradient)
[0. 6.]
>>>
>>> # Function object to be differentiated
>>> def fn(x, y, z):
...     res = x * ops.exp(y) * ops.pow(z, 2)
...     return res, z
>>> x = Tensor(np.array([3, 3]).astype(np.float32))
>>> y = Tensor(np.array([0, 0]).astype(np.float32))
>>> z = Tensor(np.array([5, 5]).astype(np.float32))
>>> output, inputs_gradient = value_and_grad(fn, grad_position=(1, 2), weights=None, has_aux=True)(x, y, z)
>>> print(output)
(Tensor(shape=[2], dtype=Float32, value= [ 7.50000000e+01, 7.50000000e+01]), Tensor(shape=[2], dtype=Float32, value= [ 5.00000000e+00, 5.00000000e+00]))
>>> print(inputs_gradient)
(Tensor(shape=[2], dtype=Float32, value= [ 7.50000000e+01, 7.50000000e+01]), Tensor(shape=[2], dtype=Float32, value= [ 3.00000000e+01, 3.00000000e+01]))
>>>
>>> # For a given network to be differentiated with respect to both inputs and weights, there are 3 cases.
>>> net = nn.Dense(10, 1)
>>> loss_fn = nn.MSELoss()
>>> def forward(inputs, labels):
...     logits = net(inputs)
...     loss = loss_fn(logits, labels)
...     return loss, logits
>>> inputs = Tensor(np.random.randn(16, 10).astype(np.float32))
>>> labels = Tensor(np.random.randn(16, 1).astype(np.float32))
>>> weights = net.trainable_params()
>>>
>>> # Case 1: gradient with respect to inputs.
>>> # Since has_aux is set to True, only loss contributes to the gradient.
>>> grad_fn = value_and_grad(forward, grad_position=0, weights=None, has_aux=True)
>>> (loss, logits), inputs_gradient = grad_fn(inputs, labels)
>>> print(logits.shape)
(16, 1)
>>> print(inputs.shape, inputs_gradient.shape)
(16, 10) (16, 10)
>>>
>>> # Case 2: gradient with respect to weights.
>>> # Since has_aux is set to True, only loss contributes to the gradient.
>>> grad_fn = value_and_grad(forward, grad_position=None, weights=weights, has_aux=True)
>>> (loss, logits), params_gradient = grad_fn(inputs, labels)
>>> print(logits.shape)
(16, 1)
>>> print(len(weights), len(params_gradient))
2 2
>>>
>>> # Case 3: gradient with respect to inputs and weights.
>>> # Since has_aux is set to False, both loss and logits contribute to the gradient.
>>> grad_fn = value_and_grad(forward, grad_position=0, weights=weights, has_aux=False)
>>> (loss, logits), (inputs_gradient, params_gradient) = grad_fn(inputs, labels)
>>> print(logits.shape)
(16, 1)
>>> print(inputs.shape, inputs_gradient.shape)
(16, 10) (16, 10)
>>> print(len(weights), len(params_gradient))
2 2