mindspore.vjp
- mindspore.vjp(fn, *inputs, weights=None, has_aux=False)
Compute the vector-Jacobian product (VJP) of the given network. vjp corresponds to reverse-mode automatic differentiation.
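To make the definition concrete, the following pure-Python sketch (not MindSpore code) builds the Jacobian J of a function numerically and then forms the product v^T J directly. The helpers `jacobian` and `vjp_reference` are hypothetical names used only for illustration, and the Jacobian is approximated with central finite differences; reverse-mode differentiation, which vjp uses, yields the same v^T J without ever materializing J. The function `net` packs the flattened 2x2 tensors x and y from the Examples below into a single input vector, so one call recovers both gradients.

```python
from typing import Callable, List

Vector = List[float]


def jacobian(f: Callable[[Vector], Vector], x: Vector, eps: float = 1e-6) -> List[Vector]:
    """Approximate the m x n Jacobian of f at x with central differences."""
    fx = f(x)
    m, n = len(fx), len(x)
    jac = [[0.0] * n for _ in range(m)]
    for j in range(n):
        x_plus, x_minus = list(x), list(x)
        x_plus[j] += eps
        x_minus[j] -= eps
        f_plus, f_minus = f(x_plus), f(x_minus)
        for i in range(m):
            jac[i][j] = (f_plus[i] - f_minus[i]) / (2 * eps)
    return jac


def vjp_reference(f: Callable[[Vector], Vector], x: Vector, v: Vector) -> Vector:
    """Return v^T J, which has one entry per input element."""
    jac = jacobian(f, x)
    return [sum(v[i] * jac[i][j] for i in range(len(v))) for j in range(len(x))]


def net(z: Vector) -> Vector:
    # z holds the flattened x (first 4 entries) followed by the flattened y.
    x, y = z[:4], z[4:]
    return [xi ** 3 + yi for xi, yi in zip(x, y)]


z = [1.0, 2.0, 3.0, 4.0] + [1.0, 2.0, 3.0, 4.0]
v = [1.0] * 4
grad = vjp_reference(net, z, v)
print([round(g, 3) for g in grad])
# [3.0, 12.0, 27.0, 48.0, 1.0, 1.0, 1.0, 1.0]
```

The first four entries match 3 * x**2 (the gradient with respect to x) and the last four match the all-ones gradient with respect to y, i.e. the same values the first doctest example prints.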
- Parameters
fn (Union[Function, Cell]) – The function or network that takes Tensor inputs and returns a single Tensor or a tuple of Tensors.
inputs (Union[Tensor, tuple[Tensor], list[Tensor]]) – The inputs to fn.
weights (Union[ParameterTuple, Parameter, list[Parameter]]) – The parameters of the training network for which gradients are computed. weights can be obtained via weights = net.trainable_params(). Default: None.
has_aux (bool) – If True, only the first output of fn contributes to the gradient, and the remaining outputs are returned directly. In this case fn must return more than one output. Default: False.
- Returns
Forward outputs and a function that computes the vjp.
net_output (Union[Tensor, tuple[Tensor]]) - The output of fn(inputs). When has_aux is True, net_output is only the first output of fn(inputs).
vjp_fn (Function) - Computes the vector-Jacobian product. Its inputs are the vectors v, whose shape and type must match those of net_output.
aux_value (Union[Tensor, tuple[Tensor]], optional) - Returned only when has_aux is True. It contains the second through last outputs of fn(inputs) and does not contribute to the gradient.
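The return structure can be mirrored with a small stand-in. The sketch below defines a hypothetical `toy_vjp` (not part of MindSpore) that works on flat lists of floats instead of Tensors and approximates gradients with central differences rather than reverse mode. It only illustrates the contract described above: without has_aux it returns (net_output, vjp_fn); with has_aux=True it returns (net_output, vjp_fn, aux_value), differentiates only the first output, and passes the rest through unchanged. It assumes that, when has_aux is False, fn returns a single list.

```python
from typing import Callable, List

Vector = List[float]


def toy_vjp(fn: Callable, *inputs: Vector, has_aux: bool = False, eps: float = 1e-6):
    """Hypothetical stand-in mirroring the shape of vjp's return value."""

    def primary(*args):
        out = fn(*args)
        # With has_aux, only the first output is differentiated.
        return out[0] if has_aux else out

    outputs = fn(*inputs)
    if has_aux:
        if not isinstance(outputs, tuple) or len(outputs) < 2:
            raise ValueError("has_aux=True requires fn to return at least two outputs")
        net_output = outputs[0]
        # A single auxiliary output is returned as-is, several as a tuple.
        aux_value = outputs[1] if len(outputs) == 2 else outputs[1:]
    else:
        net_output = outputs

    def vjp_fn(v: Vector):
        grads = []
        for k, inp in enumerate(inputs):
            grad_k = []
            for j in range(len(inp)):
                plus = [list(a) for a in inputs]
                minus = [list(a) for a in inputs]
                plus[k][j] += eps
                minus[k][j] -= eps
                f_plus, f_minus = primary(*plus), primary(*minus)
                # Entry j of v^T J with respect to input k.
                grad_k.append(sum(vi * (fp - fm) / (2 * eps)
                                  for vi, fp, fm in zip(v, f_plus, f_minus)))
            grads.append(grad_k)
        return tuple(grads)

    if has_aux:
        return net_output, vjp_fn, aux_value
    return net_output, vjp_fn


x = [1.0, 2.0, 3.0, 4.0]
y = [1.0, 2.0, 3.0, 4.0]
v = [1.0, 1.0, 1.0, 1.0]


def fn(x, y):
    # The first output is differentiated; the second is auxiliary.
    return [2 * a + b for a, b in zip(x, y)], [b ** 3 for b in y]


outputs, vjp_fn, aux = toy_vjp(fn, x, y, has_aux=True)
print(outputs)  # [3.0, 6.0, 9.0, 12.0]
print(aux)      # [1.0, 8.0, 27.0, 64.0]
grad_x, grad_y = vjp_fn(v)
print([round(g, 3) for g in grad_x], [round(g, 3) for g in grad_y])
# [2.0, 2.0, 2.0, 2.0] [1.0, 1.0, 1.0, 1.0]
```

As in the second doctest example, y ** 3 appears only in aux and has no effect on grad_y, which comes solely from 2 * x + y.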
- Raises
TypeError – If inputs or v (the vector passed to vjp_fn) is not of a required type.
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import vjp
>>> from mindspore import Tensor
>>> class Net(nn.Cell):
...     def construct(self, x, y):
...         return x**3 + y
>>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
>>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
>>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
>>> outputs, vjp_fn = vjp(Net(), x, y)
>>> print(outputs)
[[ 2. 10.]
 [30. 68.]]
>>> gradient = vjp_fn(v)
>>> print(gradient)
(Tensor(shape=[2, 2], dtype=Float32, value=
[[ 3.00000000e+00, 1.20000000e+01],
 [ 2.70000000e+01, 4.80000000e+01]]), Tensor(shape=[2, 2], dtype=Float32, value=
[[ 1.00000000e+00, 1.00000000e+00],
 [ 1.00000000e+00, 1.00000000e+00]]))
>>> def fn(x, y):
...     return 2 * x + y, y ** 3
>>> outputs, vjp_fn, aux = vjp(fn, x, y, has_aux=True)
>>> gradient = vjp_fn(v)
>>> print(outputs)
[[ 3.  6.]
 [ 9. 12.]]
>>> print(aux)
[[ 1.  8.]
 [27. 64.]]
>>> print(gradient)
(Tensor(shape=[2, 2], dtype=Float32, value=
[[ 2.00000000e+00, 2.00000000e+00],
 [ 2.00000000e+00, 2.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
[[ 1.00000000e+00, 1.00000000e+00],
 [ 1.00000000e+00, 1.00000000e+00]]))