mindspore.vjp

查看源文件
mindspore.vjp(fn, inputs, weights=None, has_aux=False)[源代码]

计算给定网络的向量雅可比积(vector-jacobian-product, VJP)。VJP对应 反向模式自动微分

参数:
  • fn (Union[Function, Cell]) - 待求导的函数或网络。以Tensor为入参,返回Tensor或Tensor数组。

  • inputs (Union[Tensor, tuple[Tensor], list[Tensor]]) - 输入网络 fn 的入参。

  • weights (Union[ParameterTuple, Parameter, list[Parameter]]) - 训练网络中需要返回梯度的网络变量。一般可通过 weights = net.trainable_params() 获取。默认值: None

  • has_aux (bool) - 若 has_auxTrue ,只有 fn 的第一个输出参与 fn 的求导,其他输出将直接返回。此时, fn 的输出数量必须超过一个。默认值: False

返回:
  • net_output (Union[Tensor, tuple[Tensor]]) - 输入网络的正向计算结果。

  • vjp_fn (Function) - 用于求解向量雅可比积的函数。接收shape和type与 net_out 一致的输入。

  • aux_value (Union[Tensor, tuple[Tensor]], 可选) - 若 has_aux 为True,才返回 aux_valueaux_valuefn(inputs) 的第一个除外的其他输出,且不参与 fn 的求导。

异常:
  • TypeError - inputsv 类型不符合要求。

支持平台:

Ascend GPU CPU

样例:

>>> import numpy as np
>>> import mindspore.nn as nn
>>> from mindspore import vjp
>>> from mindspore import Tensor
>>> class Net(nn.Cell):
...     def construct(self, x, y):
...         return x**3 + y
>>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
>>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
>>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
>>> outputs, vjp_fn = vjp(Net(), x, y)
>>> print(outputs)
[[ 2. 10.]
 [30. 68.]]
>>> gradient = vjp_fn(v)
>>> print(gradient)
(Tensor(shape=[2, 2], dtype=Float32, value=
[[ 3.00000000e+00,  1.20000000e+01],
 [ 2.70000000e+01,  4.80000000e+01]]), Tensor(shape=[2, 2], dtype=Float32, value=
[[ 1.00000000e+00,  1.00000000e+00],
 [ 1.00000000e+00,  1.00000000e+00]]))
>>> def fn(x, y):
...     return 2 * x + y, y ** 3
>>> outputs, vjp_fn, aux = vjp(fn, x, y, has_aux=True)
>>> gradient = vjp_fn(v)
>>> print(outputs)
[[ 3.  6.]
 [ 9. 12.]]
>>> print(aux)
[[ 1.  8.]
 [27. 64.]]
>>> print(gradient)
(Tensor(shape=[2, 2], dtype=Float32, value=
[[ 2.00000000e+00,  2.00000000e+00],
 [ 2.00000000e+00,  2.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
[[ 1.00000000e+00,  1.00000000e+00],
 [ 1.00000000e+00,  1.00000000e+00]]))