mindspore.vjp

mindspore.vjp(fn, *inputs, weights=None, has_aux=False)

Compute the vector-Jacobian product (VJP) of the given network. vjp corresponds to reverse-mode automatic differentiation.
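For a function f from R^n to R^m with Jacobian J (an m×n matrix), the VJP with a vector v of shape (m,) is vᵀJ, a vector of shape (n,). The sketch below uses plain NumPy rather than MindSpore, and the matrix A and function f are made up for illustration. For a linear map f(x) = A @ x the Jacobian is A itself, so the VJP can be formed explicitly and compared with A.T @ v, which is what a reverse-mode pass computes without ever materializing J.

```python
import numpy as np

# Hypothetical linear map f(x) = A @ x from R^3 to R^2; its Jacobian is A.
A = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])


def f(x):
    return A @ x


def jacobian(x):
    # For a linear map the Jacobian does not depend on x.
    return A


x = np.array([1.0, -1.0, 2.0])
v = np.array([1.0, 0.5])           # same shape as f(x)

vjp_explicit = v @ jacobian(x)     # v^T J, shape (3,)
vjp_reverse = A.T @ v              # adjoint applied to v, no explicit J needed
# Both equal [3.0, 4.5, 6.0].
```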

Parameters
  • fn (Union[Function, Cell]) – The function or network that takes Tensor inputs and returns a single Tensor or a tuple of Tensors.

  • inputs (Union[Tensor, tuple[Tensor], list[Tensor]]) – The inputs to fn.

  • weights (Union[ParameterTuple, Parameter, list[Parameter]]) – The parameters of the training network with respect to which gradients are computed. They can be obtained with weights = net.trainable_params(). Default: None.

  • has_aux (bool) – If True, only the first output of fn contributes to the gradient, while the remaining outputs are returned as-is. In this case fn must return more than one output. Default: False.
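As a math-only illustration of differentiating with respect to parameters (plain NumPy, not MindSpore, and not a statement about how vjp_fn packages weight gradients), consider a hypothetical bias-free dense layer out = W @ x. Its VJP with respect to W is the outer product of v and x, and with respect to x it is W.T @ v. The sketch checks the weight VJP against a central finite-difference estimate.

```python
import numpy as np

# Hypothetical bias-free dense layer: out = W @ x.
W = np.array([[0.5, -1.0],
              [2.0, 1.5],
              [1.0, 0.0]])         # shape (3, 2)
x = np.array([2.0, -1.0])          # shape (2,)
v = np.array([1.0, 2.0, -1.0])     # same shape as out, (3,)


def layer(W, x):
    return W @ x


# Analytic VJPs of the scalar <v, W @ x>.
grad_W = np.outer(v, x)            # shape (3, 2), same as W
grad_x = W.T @ v                   # shape (2,), same as x

# Central finite-difference estimate of grad_W, one entry at a time.
eps = 1e-6
fd_W = np.zeros_like(W)
for i in range(W.shape[0]):
    for j in range(W.shape[1]):
        W_plus = W.copy()
        W_plus[i, j] += eps
        W_minus = W.copy()
        W_minus[i, j] -= eps
        fd_W[i, j] = (v @ layer(W_plus, x) - v @ layer(W_minus, x)) / (2 * eps)
```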

Returns

Forward outputs and a function that computes the VJP.

  • net_output (Union[Tensor, tuple[Tensor]]) - The output of fn(inputs). When has_aux is True, net_output is only the first output of fn(inputs).

  • vjp_fn (Function) - A function that computes the vector-Jacobian product. Its inputs are the vectors v, whose shapes and types must match those of net_output.

  • aux_value (Union[Tensor, tuple[Tensor]], optional) - Returned only when has_aux is True. It holds the second through last outputs of fn(inputs) and does not contribute to the gradient.
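To make the return layout concrete, here is a hypothetical pure-NumPy stand-in, toy_vjp (not part of MindSpore), that returns (net_output, vjp_fn) or, with has_aux=True, (net_output, vjp_fn, aux_value). It assumes float arrays, estimates gradients with central finite differences, and supports only a single-array net_output, so it illustrates the contract described above rather than how MindSpore computes gradients.

```python
import numpy as np


def toy_vjp(fn, *inputs, has_aux=False, eps=1e-6):
    """Hypothetical stand-in mimicking the return layout of mindspore.vjp.

    Gradients are central finite differences on float arrays; the
    differentiated output must be a single array (not a tuple).
    """
    outputs = fn(*inputs)
    if has_aux:
        # Only the first output is differentiated; the rest pass through.
        net_output = outputs[0]
        rest = outputs[1:]
        aux_value = rest[0] if len(rest) == 1 else rest

        def primary(*args):
            return fn(*args)[0]
    else:
        net_output = outputs
        primary = fn

    def vjp_fn(v):
        # Gradient of the scalar <v, primary(*inputs)> w.r.t. each input.
        grads = []
        for k, inp in enumerate(inputs):
            g = np.zeros_like(inp)
            for idx in np.ndindex(inp.shape):
                plus = [a.copy() for a in inputs]
                minus = [a.copy() for a in inputs]
                plus[k][idx] += eps
                minus[k][idx] -= eps
                diff = primary(*plus) - primary(*minus)
                g[idx] = np.sum(v * diff) / (2 * eps)
            grads.append(g)
        return tuple(grads)

    if has_aux:
        return net_output, vjp_fn, aux_value
    return net_output, vjp_fn


# Usage mirroring the has_aux example (float64 arrays instead of Tensors).
x = np.array([[1.0, 2.0], [3.0, 4.0]])
y = np.array([[1.0, 2.0], [3.0, 4.0]])
v = np.ones_like(x)
out, vjp_fn, aux = toy_vjp(lambda a, b: (2 * a + b, b ** 3), x, y, has_aux=True)
grad_x, grad_y = vjp_fn(v)  # approximately all 2s and all 1s; b ** 3 is ignored
```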

Raises

TypeError – inputs or v (the vector passed to vjp_fn) does not belong to the required types.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import vjp
>>> from mindspore import Tensor
>>> class Net(nn.Cell):
...     def construct(self, x, y):
...         return x**3 + y
>>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
>>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
>>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
>>> outputs, vjp_fn = vjp(Net(), x, y)
>>> print(outputs)
[[ 2. 10.]
 [30. 68.]]
>>> gradient = vjp_fn(v)
>>> print(gradient)
(Tensor(shape=[2, 2], dtype=Float32, value=
[[ 3.00000000e+00,  1.20000000e+01],
 [ 2.70000000e+01,  4.80000000e+01]]), Tensor(shape=[2, 2], dtype=Float32, value=
[[ 1.00000000e+00,  1.00000000e+00],
 [ 1.00000000e+00,  1.00000000e+00]]))
>>> def fn(x, y):
...     return 2 * x + y, y ** 3
>>> outputs, vjp_fn, aux = vjp(fn, x, y, has_aux=True)
>>> gradient = vjp_fn(v)
>>> print(outputs)
[[ 3.  6.]
 [ 9. 12.]]
>>> print(aux)
[[ 1.  8.]
 [27. 64.]]
>>> print(gradient)
(Tensor(shape=[2, 2], dtype=Float32, value=
[[ 2.00000000e+00,  2.00000000e+00],
 [ 2.00000000e+00,  2.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
[[ 1.00000000e+00,  1.00000000e+00],
 [ 1.00000000e+00,  1.00000000e+00]]))
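The printed numbers can be reproduced by hand. Both example functions act elementwise, so each Jacobian is diagonal and the VJP reduces to multiplying v by the elementwise derivative: 3x² and 1 for Net, and 2 and 1 for the first output of fn (the auxiliary output y**3 is not differentiated). A plain NumPy check, independent of MindSpore:

```python
import numpy as np

x = np.array([[1, 2], [3, 4]], dtype=np.float32)
y = np.array([[1, 2], [3, 4]], dtype=np.float32)
v = np.ones((2, 2), dtype=np.float32)

# Net: f(x, y) = x**3 + y, so d/dx = 3x^2 and d/dy = 1 (elementwise).
out1 = x ** 3 + y                  # [[2, 10], [30, 68]]
grad1 = (3 * x ** 2 * v, 1 * v)    # ([[3, 12], [27, 48]], all ones)

# fn with has_aux=True: only 2*x + y is differentiated; y**3 is aux.
out2 = 2 * x + y                   # [[3, 6], [9, 12]]
aux2 = y ** 3                      # [[1, 8], [27, 64]]
grad2 = (2 * v, 1 * v)             # (all twos, all ones)
```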