mindspore.jvp

mindspore.jvp(fn, inputs, v, has_aux=False)

Compute the Jacobian-vector product (JVP) of the given network. The JVP is the quantity computed by forward-mode differentiation.
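As a rough illustration of forward-mode differentiation, the following pure-Python sketch carries each value together with its tangent through x ** 3 + y using dual numbers. The Dual class and the jvp_elementwise helper are hypothetical names introduced here; they are not part of MindSpore, and this is not a description of how mindspore.jvp is implemented. Treating the elements one by one works only because this particular function acts on each element independently.

```python
from dataclasses import dataclass


@dataclass
class Dual:
    """A value paired with its tangent (directional derivative)."""
    val: float
    tan: float

    def __add__(self, other):
        # Sum rule: values add and tangents add.
        return Dual(self.val + other.val, self.tan + other.tan)

    def __pow__(self, n):
        # Power rule for a constant exponent n.
        return Dual(self.val ** n, n * self.val ** (n - 1) * self.tan)


def f(x, y):
    # Same elementwise function as Net.construct in the Examples section.
    return x ** 3 + y


def jvp_elementwise(xs, ys, vxs, vys):
    """Return (outputs, jvps) for f applied elementwise to flat lists."""
    results = [f(Dual(x, vx), Dual(y, vy))
               for x, y, vx, vy in zip(xs, ys, vxs, vys)]
    return [r.val for r in results], [r.tan for r in results]


xs = [1.0, 2.0, 3.0, 4.0]
ys = [1.0, 2.0, 3.0, 4.0]
ones = [1.0, 1.0, 1.0, 1.0]
outputs, jvps = jvp_elementwise(xs, ys, ones, ones)
print(outputs)  # [2.0, 10.0, 30.0, 68.0]
print(jvps)     # [4.0, 13.0, 28.0, 49.0]
```

The printed values match the flattened results in the Examples section, since the tangent of x ** 3 + y in the direction (v, v) is 3 * x ** 2 * v + v.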

Parameters
  • fn (Union[Function, Cell]) – The function or net that takes Tensor inputs and returns a single Tensor or a tuple of Tensors.

  • inputs (Union[Tensor, tuple[Tensor], list[Tensor]]) – The inputs to fn.

  • v (Union[Tensor, tuple[Tensor], list[Tensor]]) – The vector in the Jacobian-vector product. The shape and type of v should be the same as those of inputs.

  • has_aux (bool) – If True, only the first output of fn contributes to the gradient of fn, while the other outputs are returned directly. In this case, fn must return more than one output. Default: False.
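In terms of the parameters above, the quantity being computed can be written as follows. This is the standard definition of a Jacobian-vector product, stated here as a reading aid rather than taken from the MindSpore source; the symbols J, ε, v_x and v_y are notation introduced for this note.

```latex
\[
\mathrm{jvp}
\;=\; J_{fn}(\mathrm{inputs})\, v
\;=\; \left.\frac{d}{d\varepsilon}\,
      fn(\mathrm{inputs} + \varepsilon\, v)\right|_{\varepsilon = 0}
\]
% Worked case: fn(x, y) = x^3 + y applied elementwise,
% with tangent v_x for x and tangent v_y for y.
\[
\mathrm{jvp} \;=\; 3x^{2} \odot v_x + v_y
\]
```

With x = [[1, 2], [3, 4]] and v_x = v_y = 1 in every entry, this gives [[4, 13], [28, 49]].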

Returns

  • net_output (Union[Tensor, tuple[Tensor]]) - The output of fn(inputs). When has_aux is True, net_output is the first output of fn(inputs).

  • jvp (Union[Tensor, tuple[Tensor]]) - The result of the Jacobian-vector product.

  • aux_value (Union[Tensor, tuple[Tensor]], optional) - Returned only when has_aux is True. It consists of the second through last outputs of fn(inputs). aux_value does not contribute to the gradient.
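The return layout described above can be mimicked with a small scalar stand-in. The sketch below uses a hypothetical helper, jvp_fd, that approximates the JVP with a central finite difference on plain Python floats. It is not part of MindSpore and does not reflect how mindspore.jvp computes its result. Returning a single auxiliary output bare and several as a tuple is an assumption based on the documented return type.

```python
def jvp_fd(fn, inputs, v, has_aux=False, eps=1e-6):
    """Hypothetical scalar stand-in that mimics the documented return layout.

    Approximates the JVP with a central finite difference; this is NOT how
    mindspore.jvp computes it.
    """
    def primary(*args):
        out = fn(*args)
        # With has_aux, only the first output is differentiated.
        return out[0] if has_aux else out

    plus = primary(*[x + eps * t for x, t in zip(inputs, v)])
    minus = primary(*[x - eps * t for x, t in zip(inputs, v)])
    tangent = (plus - minus) / (2 * eps)

    out = fn(*inputs)
    if not has_aux:
        return out, tangent
    # Assumption: one auxiliary output is returned bare, several as a tuple.
    aux = out[1] if len(out) == 2 else tuple(out[1:])
    return out[0], tangent, aux


def fn(x, y):
    # The first output is differentiated; the second is passed through as aux.
    return x ** 3 + y, y


net_output, jvp_out, aux = jvp_fd(fn, (2.0, 2.0), (1.0, 1.0), has_aux=True)
print(net_output, round(jvp_out, 4), aux)  # 10.0 13.0 2.0
```

The auxiliary value comes back unchanged and has no effect on the tangent, which is the behaviour the has_aux flag describes.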

Raises

  • TypeError – If inputs or v does not belong to the required types.

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> from mindspore import jvp
>>> from mindspore import Tensor
>>> import mindspore.nn as nn
>>> class Net(nn.Cell):
...     def construct(self, x, y):
...         return x**3 + y
>>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
>>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
>>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
>>> output = jvp(Net(), (x, y), (v, v))
>>> print(output[0])
[[ 2. 10.]
 [30. 68.]]
>>> print(output[1])
[[ 4. 13.]
 [28. 49.]]
>>>
>>> def fn(x, y):
...     return x ** 3 + y, y
>>> output, jvp_out, aux = jvp(fn, (x, y), (v, v), has_aux=True)
>>> print(output)
[[ 2. 10.]
 [30. 68.]]
>>> print(jvp_out)
[[ 4. 13.]
 [28. 49.]]
>>> print(aux)
[[1. 2.]
 [3. 4.]]