mindspore.vjp
- mindspore.vjp(fn, inputs, weights=None, has_aux=False)[源代码]
计算给定网络的向量雅可比积(vector-jacobian-product, VJP)。VJP对应 反向模式自动微分。
- 参数:
fn (Union[Function, Cell]) - 待求导的函数或网络。以Tensor为入参,返回Tensor或Tensor数组。
inputs (Union[Tensor, tuple[Tensor], list[Tensor]]) - 输入网络 fn 的入参。
weights (Union[ParameterTuple, Parameter, list[Parameter]]) - 训练网络中需要返回梯度的网络变量。一般可通过 weights = net.trainable_params() 获取。默认值:
None
。has_aux (bool) - 若 has_aux 为
True
,只有 fn 的第一个输出参与 fn 的求导,其他输出将直接返回。此时, fn 的输出数量必须超过一个。默认值:False
。
- 返回:
正向输出和计算 vjp 的功能。
net_output (Union[Tensor, tuple[Tensor]]) - fn(inputs) 的输出。特别是当 has_aux 设置为
True
时, net_output 是 fn(inputs) 的第一个输出。vjp_fn (Function) - 用于求解向量雅可比积的函数。接收shape和type与 net_output 一致的输入。
aux_value (Union[Tensor, tuple[Tensor]], 可选) - 若 has_aux 为True,才返回 aux_value 。 aux_value 是 fn(inputs) 的第一个除外的其他输出,且不参与 fn 的求导。
- 异常:
TypeError - inputs 或 v 类型不符合要求。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> import numpy as np >>> import mindspore.nn as nn >>> from mindspore import vjp >>> from mindspore import Tensor >>> class Net(nn.Cell): ... def construct(self, x, y): ... return x**3 + y >>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32)) >>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32)) >>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32)) >>> outputs, vjp_fn = vjp(Net(), x, y) >>> print(outputs) [[ 2. 10.] [30. 68.]] >>> gradient = vjp_fn(v) >>> print(gradient) (Tensor(shape=[2, 2], dtype=Float32, value= [[ 3.00000000e+00, 1.20000000e+01], [ 2.70000000e+01, 4.80000000e+01]]), Tensor(shape=[2, 2], dtype=Float32, value= [[ 1.00000000e+00, 1.00000000e+00], [ 1.00000000e+00, 1.00000000e+00]])) >>> def fn(x, y): ... return 2 * x + y, y ** 3 >>> outputs, vjp_fn, aux = vjp(fn, x, y, has_aux=True) >>> gradient = vjp_fn(v) >>> print(outputs) [[ 3. 6.] [ 9. 12.]] >>> print(aux) [[ 1. 8.] [27. 64.]] >>> print(gradient) (Tensor(shape=[2, 2], dtype=Float32, value= [[ 2.00000000e+00, 2.00000000e+00], [ 2.00000000e+00, 2.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value= [[ 1.00000000e+00, 1.00000000e+00], [ 1.00000000e+00, 1.00000000e+00]]))