mindspore.jvp
- mindspore.jvp(fn, inputs, v, has_aux=False)
Compute the Jacobian-vector product (JVP) of the given network. For the calculation procedure of JVP, see forward-mode differentiation.
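Forward-mode differentiation carries a tangent along with each primal value, so a single forward pass yields both f(x) and J·v. The sketch below illustrates that idea with dual numbers on plain Python scalars. `Dual` and `jvp_dual` are hypothetical illustrations, not MindSpore APIs, and this sketch does not describe how MindSpore implements `jvp` internally.

```python
class Dual:
    """Dual number primal + tangent*eps with eps**2 == 0 (hypothetical illustration)."""

    def __init__(self, primal, tangent=0.0):
        self.primal = primal
        self.tangent = tangent

    @staticmethod
    def _lift(other):
        # Plain numbers are constants, so their tangent is zero.
        return other if isinstance(other, Dual) else Dual(other)

    def __add__(self, other):
        other = Dual._lift(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = Dual._lift(other)
        # Product rule: (a + a'eps)(b + b'eps) = ab + (a'b + ab')eps
        return Dual(self.primal * other.primal,
                    self.tangent * other.primal + self.primal * other.tangent)

    __rmul__ = __mul__

    def __pow__(self, n):
        # Power rule for a constant integer exponent n.
        return Dual(self.primal ** n, n * self.primal ** (n - 1) * self.tangent)


def jvp_dual(fn, inputs, v):
    """Return (fn(inputs), J @ v) for a scalar-valued fn of scalar inputs (hypothetical helper)."""
    duals = [Dual(x, t) for x, t in zip(inputs, v)]
    out = fn(*duals)
    return out.primal, out.tangent


def f(x, y):
    return x ** 3 + y


value, tangent = jvp_dual(f, (2.0, 2.0), (1.0, 1.0))
print(value, tangent)  # 10.0 13.0
```

Here the tangent 13.0 is 3·2²·1 + 1, the directional derivative of x³ + y along (1, 1); it matches the second entries of both printed arrays in the Examples section, which apply the same function elementwise.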
- Parameters
fn (Union[Function, Cell]) – The function or network that takes Tensor inputs and returns a single Tensor or a tuple of Tensors.
inputs (Union[Tensor, tuple[Tensor], list[Tensor]]) – The inputs to fn.
v (Union[Tensor, tuple[Tensor], list[Tensor]]) – The vector in the Jacobian-vector product. v must have the same shape and type as inputs.
has_aux (bool) – If True, only the first output of fn contributes to the gradient of fn, while the other outputs are returned directly. In this case, fn must return more than one output. Default: False.
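To show what `has_aux` changes, the sketch below swaps in central finite differences in place of true forward-mode differentiation: only the first output of `fn` is differentiated, and the remaining outputs are passed through untouched. `numerical_jvp` is a hypothetical helper on Python floats, not a MindSpore API; its JVP values are approximations, and the `ValueError` it raises when `fn` returns too few outputs is a choice of this sketch only.

```python
def numerical_jvp(fn, inputs, v, has_aux=False, eps=1e-5):
    """Approximate (output, J @ v[, aux]) of fn at `inputs` with central differences.

    Hypothetical helper on Python floats. With has_aux=True, fn must return a tuple
    of at least two values: only the first is differentiated, the rest is passed through.
    """
    outputs = fn(*inputs)
    if has_aux and (not isinstance(outputs, tuple) or len(outputs) < 2):
        raise ValueError("has_aux=True requires fn to return more than one output")

    def primary(result):
        # With has_aux, only the first output takes part in differentiation.
        return result[0] if has_aux else result

    plus = [x + eps * t for x, t in zip(inputs, v)]
    minus = [x - eps * t for x, t in zip(inputs, v)]
    jvp_value = (primary(fn(*plus)) - primary(fn(*minus))) / (2 * eps)

    if has_aux:
        # Mirror the Examples section: a single auxiliary output comes back unwrapped.
        aux = outputs[1] if len(outputs) == 2 else outputs[1:]
        return outputs[0], jvp_value, aux
    return outputs, jvp_value


def fn(x, y):
    return x ** 3 + y, y


out, jvp_out, aux = numerical_jvp(fn, (3.0, 3.0), (1.0, 1.0), has_aux=True)
print(out, round(jvp_out, 6), aux)  # 30.0 28.0 3.0
```

The auxiliary output y never enters the difference quotient, which is the behavior the has_aux parameter describes.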
- Returns
net_output (Union[Tensor, tuple[Tensor]]) - The output of fn(inputs). In particular, when has_aux is True, net_output is the first output of fn(inputs).
jvp (Union[Tensor, tuple[Tensor]]) - The result of the Jacobian-vector product.
aux_value (Union[Tensor, tuple[Tensor]], optional) - Returned only when has_aux is True. It contains the second through last outputs of fn(inputs) and does not contribute to the gradient.
- Raises
TypeError – inputs or v is not of a required type.
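The documented TypeError amounts to a structural check on inputs and v. The sketch below is one way to express that contract; `check_jvp_args` and its `tensor_type` parameter are hypothetical, Python floats stand in for `mindspore.Tensor` so it runs without MindSpore, and MindSpore's actual validation logic may differ.

```python
def check_jvp_args(inputs, v, tensor_type=float):
    """Hypothetical type check for the `inputs` and `v` arguments of a JVP call.

    Each argument must be a single tensor or a tuple/list of tensors. `tensor_type`
    stands in for mindspore.Tensor so that the sketch runs without MindSpore.
    Returns both arguments normalized to tuples.
    """
    normalized = []
    for name, arg in (("inputs", inputs), ("v", v)):
        items = tuple(arg) if isinstance(arg, (tuple, list)) else (arg,)
        if not all(isinstance(item, tensor_type) for item in items):
            raise TypeError(
                f"{name} must be a Tensor or a tuple/list of Tensors, got {arg!r}")
        normalized.append(items)
    return tuple(normalized)


print(check_jvp_args([1.0, 2.0], (3.0, 4.0)))  # ((1.0, 2.0), (3.0, 4.0))
try:
    check_jvp_args("not a tensor", 1.0)
except TypeError as err:
    print("TypeError:", err)
```

Normalizing both arguments to tuples lets later code pair each input with its tangent regardless of whether the caller passed a single Tensor, a tuple, or a list.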
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import numpy as np
>>> from mindspore import jvp
>>> from mindspore import Tensor
>>> import mindspore.nn as nn
>>> class Net(nn.Cell):
...     def construct(self, x, y):
...         return x**3 + y
>>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
>>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
>>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
>>> output = jvp(Net(), (x, y), (v, v))
>>> print(output[0])
[[ 2. 10.]
 [30. 68.]]
>>> print(output[1])
[[ 4. 13.]
 [28. 49.]]
>>>
>>> def fn(x, y):
...     return x ** 3 + y, y
>>> output, jvp_out, aux = jvp(fn, (x, y), (v, v), has_aux=True)
>>> print(output)
[[ 2. 10.]
 [30. 68.]]
>>> print(jvp_out)
[[ 4. 13.]
 [28. 49.]]
>>> print(aux)
[[1. 2.]
 [3. 4.]]
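Because Net acts elementwise, its Jacobian is diagonal and the printed values can be checked by hand: with the same tangent v for both x and y, the JVP of x³ + y is 3x²·v + v. The plain-Python sketch below (no MindSpore required) reproduces both arrays from that closed form; it is a verification aid only, not how `jvp` computes its result.

```python
# Closed form for the example: f(x, y) = x**3 + y, and with tangents (v, v)
# the JVP is 3 * x**2 * v + v, applied elementwise.
x = [[1.0, 2.0], [3.0, 4.0]]
y = [[1.0, 2.0], [3.0, 4.0]]
v = [[1.0, 1.0], [1.0, 1.0]]

output = [[xi ** 3 + yi for xi, yi in zip(xr, yr)] for xr, yr in zip(x, y)]
jvp_out = [[3 * xi ** 2 * vi + vi for xi, vi in zip(xr, vr)] for xr, vr in zip(x, v)]
print(output)   # [[2.0, 10.0], [30.0, 68.0]]
print(jvp_out)  # [[4.0, 13.0], [28.0, 49.0]]
```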