mindspore.value_and_grad

mindspore.value_and_grad(fn, grad_position=0, weights=None, has_aux=False)[源代码]

生成求导函数,用于计算给定函数的正向计算结果和梯度。

函数求导包含以下三种场景:

  1. 对输入求导,此时 grad_position 非None,而 weights 是None;

  2. 对网络变量求导,此时 grad_position 是None,而 weights 非None;

  3. 同时对输入和网络变量求导,此时 grad_positionweights 都非None。

参数:
  • fn (Union[Cell, Function]) - 待求导的函数或网络。

  • grad_position (Union[NoneType, int, tuple[int]]) - 指定求导输入位置的索引。若为int类型,表示对单个输入求导;若为tuple类型,表示对tuple内索引的位置求导,其中索引从0开始;若是None,表示不对输入求导,这种场景下, weights 非None。默认值:0。

  • weights (Union[ParameterTuple, Parameter, list[Parameter]]) - 训练网络中需要返回梯度的网络变量。一般可通过 weights = net.trainable_params() 获取。默认值:None。

  • has_aux (bool) - 是否返回辅助参数的标志。若为True, fn 输出数量必须超过一个,其中只有 fn 第一个输出参与求导,其他输出值将直接返回。默认值:False。

返回:

Function,用于计算给定函数的梯度的求导函数。例如 out1, out2 = fn(*args) ,梯度函数将返回 ((out1, out2), gradient) 形式的结果, 若 has_aux 为True,那么 out2 不参与求导。

异常:
  • ValueError - 入参 grad_positionweights 同时为None。

  • TypeError - 入参类型不符合要求。

支持平台:

Ascend GPU CPU

样例:

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops, nn
>>> from mindspore import value_and_grad
>>>
>>> # Cell object to be differentiated
>>> class Net(nn.Cell):
...     def construct(self, x, y, z):
...         return x * y * z
>>> x = Tensor([1, 2], mindspore.float32)
>>> y = Tensor([-2, 3], mindspore.float32)
>>> z = Tensor([0, 3], mindspore.float32)
>>> net = Net()
>>> grad_fn = value_and_grad(net, grad_position=1)
>>> output, inputs_gradient = grad_fn(x, y, z)
>>> print(output)
[-0. 18.]
>>> print(inputs_gradient)
[0. 6.]
>>>
>>> # Function object to be differentiated
>>> def fn(x, y, z):
...     res = x * ops.exp(y) * ops.pow(z, 2)
...     return res, z
>>> x = Tensor(np.array([3, 3]).astype(np.float32))
>>> y = Tensor(np.array([0, 0]).astype(np.float32))
>>> z = Tensor(np.array([5, 5]).astype(np.float32))
>>> output, inputs_gradient = value_and_grad(fn, grad_position=(1, 2), weights=None, has_aux=True)(x, y, z)
>>> print(output)
(Tensor(shape=[2], dtype=Float32, value= [ 7.50000000e+01,  7.50000000e+01]),
 Tensor(shape=[2], dtype=Float32, value= [ 5.00000000e+00,  5.00000000e+00]))
>>> print(inputs_gradient)
(Tensor(shape=[2], dtype=Float32, value= [ 7.50000000e+01,  7.50000000e+01]),
 Tensor(shape=[2], dtype=Float32, value= [ 3.00000000e+01,  3.00000000e+01]))
>>>
>>> # For given network to be differentiated with both inputs and weights, there are 3 cases.
>>> net = nn.Dense(10, 1)
>>> loss_fn = nn.MSELoss()
>>> def forward(inputs, labels):
...     logits = net(inputs)
...     loss = loss_fn(logits, labels)
...     return loss, logits
>>> inputs = Tensor(np.random.randn(16, 10).astype(np.float32))
>>> labels = Tensor(np.random.randn(16, 1).astype(np.float32))
>>> weights = net.trainable_params()
>>>
>>> # Case 1: gradient with respect to inputs.
>>> # For has_aux is set True, only loss contributes to the gradient.
>>> grad_fn = value_and_grad(forward, grad_position=0, weights=None, has_aux=True)
>>> (loss, logits), inputs_gradient = grad_fn(inputs, labels)
>>> print(logits.shape)
(16, 1)
>>> print(inputs.shape, inputs_gradient.shape)
(16, 10) (16, 10)
>>>
>>> # Case 2: gradient with respect to weights.
>>> # For has_aux is set True, only loss contributes to the gradient.
>>> grad_fn = value_and_grad(forward, grad_position=None, weights=weights, has_aux=True)
>>> (loss, logits), params_gradient = grad_fn(inputs, labels)
>>> print(logits.shape)
(16, 1)
>>> print(len(weights), len(params_gradient))
2 2
>>>
>>> # Case 3: gradient with respect to inputs and weights.
>>> # For has_aux is set False, both loss and logits contribute to the gradient.
>>> grad_fn = value_and_grad(forward, grad_position=0, weights=weights, has_aux=False)
>>> (loss, logits), (inputs_gradient, params_gradient) = grad_fn(inputs, labels)
>>> print(logits.shape)
(16, 1)
>>> print(inputs.shape, inputs_gradient.shape)
(16, 10) (16, 10)
>>> print(len(weights), len(params_gradient))
2 2