Gradient Operation



Overview

GradOperation is used to generate the gradient of the input function. The get_all, get_by_list, and sens_param parameters control how the gradient is calculated. For details, see the MindSpore API documentation.

The following is an example of using GradOperation.

First-order Derivation

The first-order derivative interface of MindSpore is mindspore.ops.GradOperation(get_all=False, get_by_list=False, sens_param=False). When get_all is set to False, only the derivative with respect to the first input is computed; when get_all is set to True, derivatives with respect to all inputs are computed. When get_by_list is set to False, weight derivatives are not computed; when get_by_list is set to True, they are. When sens_param is set to True, a sensitivity value is passed in to scale the network output and change the final gradient, so its dimensions must be consistent with the output dimensions. The following uses the first-order derivation of the MatMul operator for in-depth analysis.

For details about the complete sample code, see First-order Derivation Sample Code.

Input Derivation

The input derivation code is as follows:

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore import ParameterTuple, Parameter
from mindspore import dtype as mstype
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.matmul = ops.MatMul()
        self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
    def construct(self, x, y):
        x = x * self.z
        out = self.matmul(x, y)
        return out

class GradNetWrtX(nn.Cell):
    def __init__(self, net):
        super(GradNetWrtX, self).__init__()
        self.net = net
        self.grad_op = ops.GradOperation()
    def construct(self, x, y):
        gradient_function = self.grad_op(self.net)
        return gradient_function(x, y)

x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
output = GradNetWrtX(Net())(x, y)
print(output)

The output is as follows:

[[4.5099998 2.7 3.6000001]
 [4.5099998 2.7 3.6000001]]

To facilitate analysis, inputs x, y, and z can be expressed as follows:

x = Tensor([[x1, x2, x3], [x4, x5, x6]])  
y = Tensor([[y1, y2, y3], [y4, y5, y6], [y7, y8, y9]])
z = Tensor([z])

The following forward result can be obtained based on the definition of the MatMul operator:

\(output = [[(x1 \cdot y1 + x2 \cdot y4 + x3 \cdot y7) \cdot z, (x1 \cdot y2 + x2 \cdot y5 + x3 \cdot y8) \cdot z, (x1 \cdot y3 + x2 \cdot y6 + x3 \cdot y9) \cdot z]\),

\([(x4 \cdot y1 + x5 \cdot y4 + x6 \cdot y7) \cdot z, (x4 \cdot y2 + x5 \cdot y5 + x6 \cdot y8) \cdot z, (x4 \cdot y3 + x5 \cdot y6 + x6 \cdot y9) \cdot z]]\)

MindSpore uses reverse-mode automatic differentiation [3] during gradient computation: the elements of the output are summed and the derivative of that sum with respect to the input x is then computed.

(1) Summation formula:

\(\sum{output} = [(x1 \cdot y1 + x2 \cdot y4 + x3 \cdot y7) + (x1 \cdot y2 + x2 \cdot y5 + x3 \cdot y8) + (x1 \cdot y3 + x2 \cdot y6 + x3 \cdot y9) +\)

\((x4 \cdot y1 + x5 \cdot y4 + x6 \cdot y7) + (x4 \cdot y2 + x5 \cdot y5 + x6 \cdot y8) + (x4 \cdot y3 + x5 \cdot y6 + x6 \cdot y9)] \cdot z\)

(2) Derivation formula:

\(\frac{\mathrm{d}(\sum{output})}{\mathrm{d}x} = [[(y1 + y2 + y3) \cdot z, (y4 + y5 + y6) \cdot z, (y7 + y8 + y9) \cdot z], [(y1 + y2 + y3) \cdot z, (y4 + y5 + y6) \cdot z, (y7 + y8 + y9) \cdot z]]\)

(3) Computation result:

\(\frac{\mathrm{d}(\sum{output})}{\mathrm{d}x} = [[4.5099998 \quad 2.7 \quad 3.6000001] [4.5099998 \quad 2.7 \quad 3.6000001]]\)
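
As a quick cross-check of the analytic result, the same gradient can be reproduced with plain NumPy: since \(\sum{output} = z \cdot \sum{(xy)}\), the gradient with respect to x is \(z \cdot \mathbf{1}y^T\), where \(\mathbf{1}\) is an all-ones matrix with the shape of the output. A minimal sketch, independent of MindSpore:

import numpy as np

x_np = np.array([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], np.float32)
y_np = np.array([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], np.float32)
z_np = np.float32(1.0)

# d(sum(z * x @ y)) / dx = z * ones_like(x @ y) @ y.T
grad_x = z_np * np.ones((2, 3), np.float32) @ y_np.T
print(grad_x)  # rows of [4.51, 2.7, 3.6], matching the GradOperation output above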

To compute the derivatives with respect to both the x and y inputs, you only need to set self.grad_op = ops.GradOperation(get_all=True) in GradNetWrtX, as sketched below.
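
A minimal sketch of that variant, reusing Net, x, and y from above (the wrapper name GradNetWrtXY is only illustrative); gradient_function then returns a tuple containing the gradients with respect to x and y:

class GradNetWrtXY(nn.Cell):
    def __init__(self, net):
        super(GradNetWrtXY, self).__init__()
        self.net = net
        self.grad_op = ops.GradOperation(get_all=True)
    def construct(self, x, y):
        gradient_function = self.grad_op(self.net)
        return gradient_function(x, y)

# returns a tuple of gradients, one per input
dx, dy = GradNetWrtXY(Net())(x, y)
print(dx)
print(dy)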

Weight Derivation

If the derivation of weights is considered, change GradNetWrtX to the following:

class GradNetWrtX(nn.Cell):
    def __init__(self, net):
        super(GradNetWrtX, self).__init__()
        self.net = net
        self.params = ParameterTuple(net.trainable_params())
        self.grad_op = ops.GradOperation(get_by_list=True)
    def construct(self, x, y):
        gradient_function = self.grad_op(self.net, self.params)
        return gradient_function(x, y)
output = GradNetWrtX(Net())(x, y)
print(output)

The output is as follows:

(Tensor(shape=[1], dtype=Float32, value= [ 2.15359993e+01]),)

The derivation formula is changed to:

\(\frac{\mathrm{d}(\sum{output})}{\mathrm{d}z} = (x1 \cdot y1 + x2 \cdot y4 + x3 \cdot y7) + (x1 \cdot y2 + x2 \cdot y5 + x3 \cdot y8) + (x1 \cdot y3 + x2 \cdot y6 + x3 \cdot y9) + \)

\((x4 \cdot y1 + x5 \cdot y4 + x6 \cdot y7) + (x4 \cdot y2 + x5 \cdot y5 + x6 \cdot y8) + (x4 \cdot y3 + x5 \cdot y6 + x6 \cdot y9)\)

Computation result:

\(\frac{\mathrm{d}(\sum{output})}{\mathrm{d}z} = [2.15359993e+01]\)
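
As a cross-check, \(\frac{\mathrm{d}(\sum{output})}{\mathrm{d}z}\) is simply the sum of all elements of \(xy\); a quick NumPy sketch reproduces this value:

import numpy as np

x_np = np.array([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], np.float32)
y_np = np.array([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], np.float32)

# d(sum(z * x @ y)) / dz = sum of all elements of x @ y
print((x_np @ y_np).sum())  # approximately 21.536, matching the weight gradient above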

Gradient Value Scaling

You can use the sens_param parameter to control the scaling of the gradient value.

class GradNetWrtX(nn.Cell):
    def __init__(self, net):
        super(GradNetWrtX, self).__init__()
        self.net = net
        self.grad_op = ops.GradOperation(sens_param=True)
        self.grad_wrt_output = Tensor([[0.1, 0.6, 0.2], [0.8, 1.3, 1.1]], dtype=mstype.float32)
    def construct(self, x, y):
        gradient_function = self.grad_op(self.net)
        return gradient_function(x, y, self.grad_wrt_output)
output = GradNetWrtX(Net())(x, y)  
print(output)

The output is as follows:

[[2.211 0.51 1.49 ]
 [5.588 2.68 4.07 ]]

self.grad_wrt_output may be denoted in the following form:

self.grad_wrt_output = Tensor([[s1, s2, s3], [s4, s5, s6]])

The output value after scaling is the product of the original output value and the corresponding element of self.grad_wrt_output:

\(output = [[(x1 \cdot y1 + x2 \cdot y4 + x3 \cdot y7) \cdot z \cdot s1, (x1 \cdot y2 + x2 \cdot y5 + x3 \cdot y8) \cdot z \cdot s2, (x1 \cdot y3 + x2 \cdot y6 + x3 \cdot y9) \cdot z \cdot s3], \)

\([(x4 \cdot y1 + x5 \cdot y4 + x6 \cdot y7) \cdot z \cdot s4, (x4 \cdot y2 + x5 \cdot y5 + x6 \cdot y8) \cdot z \cdot s5, (x4 \cdot y3 + x5 \cdot y6 + x6 \cdot y9) \cdot z \cdot s6]]\)

The derivation formula is changed to compute the derivative of the sum of the scaled output values with respect to each element of x.

\(\frac{\mathrm{d}(\sum{output})}{\mathrm{d}x} = [[(s1 \cdot y1 + s2 \cdot y2 + s3 \cdot y3) \cdot z, (s1 \cdot y4 + s2 \cdot y5 + s3 \cdot y6) \cdot z, (s1 \cdot y7 + s2 \cdot y8 + s3 \cdot y9) \cdot z], \)

\([(s4 \cdot y1 + s5 \cdot y2 + s6 \cdot y3) \cdot z, (s4 \cdot y4 + s5 \cdot y5 + s6 \cdot y6) \cdot z, (s4 \cdot y7 + s5 \cdot y8 + s6 \cdot y9) \cdot z]]\)
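
Substituting the first element as a worked instance of this formula:

\(\frac{\mathrm{d}(\sum{output})}{\mathrm{d}x1} = (s1 \cdot y1 + s2 \cdot y2 + s3 \cdot y3) \cdot z = (0.1 \cdot 0.11 + 0.6 \cdot 3.3 + 0.2 \cdot 1.1) \cdot 1.0 = 2.211\)

which matches output[0][0] of the scaled-gradient result above.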

To compute the derivative of a single output (for example, output[0][0]) with respect to the input, you can set the scaling value of the corresponding position to 1 and the scaling values of all other positions to 0. You can also change the network structure as follows (since the output is then a scalar, use the GradNetWrtX without sens_param from Input Derivation):

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.matmul = ops.MatMul()
        self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
    def construct(self, x, y):
        x = x * self.z
        out = self.matmul(x, y)
        return out[0][0]
output = GradNetWrtX(Net())(x, y)  
print(output)

The output is as follows:

[[0.11 1.1 1.1]
 [0.   0.  0. ]]
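
Alternatively, as noted above, the same single-output gradient can be obtained without modifying Net by keeping the sens_param version of GradNetWrtX and passing a one-hot sensitivity tensor. A minimal sketch; only grad_wrt_output changes, the wrapper is otherwise identical to the one in Gradient Value Scaling:

class GradNetWrtX(nn.Cell):
    def __init__(self, net):
        super(GradNetWrtX, self).__init__()
        self.net = net
        self.grad_op = ops.GradOperation(sens_param=True)
        # 1 at the position of output[0][0], 0 everywhere else
        self.grad_wrt_output = Tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]], dtype=mstype.float32)
    def construct(self, x, y):
        gradient_function = self.grad_op(self.net)
        return gradient_function(x, y, self.grad_wrt_output)

Applied to the original Net (the one returning the full 2x3 matrix), this yields the same [[0.11 1.1 1.1] [0. 0. 0.]] gradient shown above.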

Stop Gradient

We can use stop_gradient to disable gradient calculation for certain operators. For example:

import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore import ParameterTuple, Parameter
from mindspore import dtype as mstype
from mindspore.ops import stop_gradient

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.matmul = ops.MatMul()

    def construct(self, x, y):
        out1 = self.matmul(x, y)
        out2 = self.matmul(x, y)
        out2 = stop_gradient(out2)
        out = out1 + out2
        return out

class GradNetWrtX(nn.Cell):
    def __init__(self, net):
        super(GradNetWrtX, self).__init__()
        self.net = net
        self.grad_op = ops.GradOperation()

    def construct(self, x, y):
        gradient_function = self.grad_op(self.net)
        return gradient_function(x, y)

x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
output = GradNetWrtX(Net())(x, y)
print(output)

The output is as follows:

    [[4.5, 2.7, 3.6],
     [4.5, 2.7, 3.6]]

Here, stop_gradient is applied to out2, so out2 makes no contribution to the gradient. If we delete out2 = stop_gradient(out2), the result is:

    [[9.0, 5.4, 7.2],
     [9.0, 5.4, 7.2]]

When stop_gradient is not applied to out2, it contributes to the gradient in the same way as out1, so each element of the result is doubled.

High-order Derivation

High-order differentiation is used in domains such as AI-supported scientific computing and second-order optimization. For example, in molecular dynamics simulation, when the potential energy is trained using a neural network [1], the loss function needs the derivative of the network output with respect to the input, and backpropagation then involves a second-order cross derivative of the loss function with respect to the input and the weights. In addition, second-order derivatives of the output with respect to the input appear when differential equations are solved with AI (for example, PINNs [2]). Another example is second-order optimization: to make a neural network converge quickly, the second-order derivative of the loss function with respect to the weights is computed, as in Newton's method.

MindSpore supports high-order derivatives by computing derivatives multiple times. The following examples describe how to compute them.

For details about the complete sample code, see High-order Derivation Sample Code.

Single-input Single-output High-order Derivative

For example, the second-order derivative of the Sin operator (which is -Sin) is implemented as follows:

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.sin = ops.Sin()
    def construct(self, x):
        out = self.sin(x)
        return out

class Grad(nn.Cell):
    def __init__(self, network):
        super(Grad, self).__init__()
        self.grad = ops.GradOperation()
        self.network = network
    def construct(self, x):
        gout = self.grad(self.network)(x)
        return gout
class GradSec(nn.Cell):
    def __init__(self, network):
        super(GradSec, self).__init__()
        self.grad = ops.GradOperation()
        self.network = network
    def construct(self, x):
        gout = self.grad(self.network)(x)
        return gout

net = Net()
firstgrad = Grad(net) # first order
secondgrad = GradSec(firstgrad) # second order
x_train = Tensor(np.array([1.0], dtype=np.float32))
output = secondgrad(x_train)
print(output)

The output is as follows:

[-0.841471]

Single-input Multi-output High-order Derivative

For example, for a multiplication operation that produces multiple outputs, the high-order derivative is computed as follows:

import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore import dtype as mstype
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.mul = ops.Mul()
    def construct(self, x):
        out = self.mul(x, x)
        return out

class Grad(nn.Cell):
    def __init__(self, network):
        super(Grad, self).__init__()
        self.grad = ops.GradOperation(sens_param=False)
        self.network = network
    def construct(self, x):
        gout = self.grad(self.network)(x)
        return gout
class GradSec(nn.Cell):
    def __init__(self, network):
        super(GradSec, self).__init__()
        self.grad = ops.GradOperation(sens_param=False)
        self.network = network
    def construct(self, x):
        gout = self.grad(self.network)(x)
        return gout

net = Net()
firstgrad = Grad(net) # first order
secondgrad = GradSec(firstgrad) # second order
x = Tensor([0.1, 0.2, 0.3], dtype=mstype.float32)
output = secondgrad(x)
print(output)

The output is as follows:

[2. 2. 2.]

Multi-input Multi-output High-order Derivative

For example, if a neural network has multiple inputs x and y, second-order derivatives dxdx, dydy, dxdy, and dydx may be obtained by using a gradient scaling mechanism as follows:

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.mul = ops.Mul()

    def construct(self, x, y):
        x_square = self.mul(x, x)
        x_square_y = self.mul(x_square, y)
        return x_square_y

class Grad(nn.Cell):
    def __init__(self, network):
        super(Grad, self).__init__()
        self.grad = ops.GradOperation(get_all=True, sens_param=False)
        self.network = network
    def construct(self, x, y):
        gout = self.grad(self.network)(x, y) # return dx, dy
        return gout

class GradSec(nn.Cell):
    def __init__(self, network):
        super(GradSec, self).__init__()
        self.grad = ops.GradOperation(get_all=True, sens_param=True)
        self.network = network
        self.sens1 = Tensor(np.array([1]).astype('float32'))
        self.sens2 = Tensor(np.array([0]).astype('float32'))
    def construct(self, x, y):
        dxdx, dxdy = self.grad(self.network)(x, y, (self.sens1,self.sens2))
        dydx, dydy = self.grad(self.network)(x, y, (self.sens2,self.sens1))
        return dxdx, dxdy, dydx, dydy

net = Net()
firstgrad = Grad(net) # first order
secondgrad = GradSec(firstgrad) # second order
x_train = Tensor(np.array([4],dtype=np.float32))
y_train = Tensor(np.array([5],dtype=np.float32))
dxdx, dxdy, dydx, dydy = secondgrad(x_train, y_train)
print(dxdx, dxdy, dydx, dydy)

The output is as follows:

[10] [8.] [8.] [0.]

Specifically, the results of the first-order derivative computation are dx and dy. To compute dxdx, only the first-order derivative dx needs to be retained, so the scaling values corresponding to dx and dy are set to 1 and 0 respectively, that is, self.grad(self.network)(x, y, (self.sens1, self.sens2)). Similarly, to compute dydy, only the first-order derivative dy is retained, and the scaling values corresponding to dx and dy are set to 0 and 1 respectively, that is, self.grad(self.network)(x, y, (self.sens2, self.sens1)).
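
These values can be verified by hand. The network computes \(f(x, y) = x^2 y\), so the first-order derivatives are \(\frac{\partial f}{\partial x} = 2xy\) and \(\frac{\partial f}{\partial y} = x^2\), and the second-order derivatives at \(x = 4, y = 5\) are:

\(\frac{\partial^2 f}{\partial x^2} = 2y = 10, \quad \frac{\partial^2 f}{\partial x \partial y} = \frac{\partial^2 f}{\partial y \partial x} = 2x = 8, \quad \frac{\partial^2 f}{\partial y^2} = 0\)

which matches the printed result [10] [8.] [8.] [0.].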

Support for Second-order Differential Operators

CPU supports the following operators: Square, Exp, Neg, Mul, and MatMul.

GPU supports the following operators: Pow, Log, Square, Exp, Neg, Mul, Div, MatMul, Sin, Cos, Tan and Atanh.

Ascend supports the following operators: Pow, Log, Square, Exp, Neg, Mul, Div, MatMul, Sin, Cos, Tan, Sinh, Cosh and Atanh.

Jvp and Vjp Interface

Besides the GradOperation interface, which is based on reverse-mode automatic differentiation, MindSpore also provides two additional gradient interfaces: Jvp and Vjp. Jvp is for forward-mode AD and Vjp is for reverse-mode AD.

Jvp

Jvp (Jacobian-vector product) uses forward-mode AD and is better suited to networks whose input dimension is smaller than their output dimension. Unlike reverse-mode AD, forward-mode AD obtains the network output and the gradient at the same time, so it requires less memory than reverse-mode AD. More information about the difference between forward-mode AD and reverse-mode AD can be found in MindSpore Automatic Differentiation.

The example code is as follows:

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore import dtype as mstype
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.sin = ops.Sin()
        self.cos = ops.Cos()

    def construct(self, x, y):
        a = self.sin(x)
        b = self.cos(y)
        out = a + b
        return out

class GradNet(nn.Cell):
    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net
        self.grad_op = nn.Jvp(net)

    def construct(self, x, y, v):
        output = self.grad_op(x, y, (v, v))
        return output

x = Tensor([0.8, 0.6, 0.2], dtype=mstype.float32)
y = Tensor([0.7, 0.4, 0.3], dtype=mstype.float32)
v = Tensor([1, 1, 1], dtype=mstype.float32)
output = GradNet(Net())(x, y, v)
print(output)

The output is:

(Tensor(shape=[3], dtype=Float32, value= [ 1.48219836e+00, 1.48570347e+00, 1.15400589e+00]),
 Tensor(shape=[3], dtype=Float32, value= [ 5.24890423e-02, 4.35917288e-01, 6.84546351e-01]))

Vjp

Vjp (vector-Jacobian product) uses reverse-mode AD. The output of Vjp is the network output together with the reverse-mode gradient. It is better suited to networks whose input dimension is greater than their output dimension. More information about the difference between forward-mode AD and reverse-mode AD can be found in MindSpore Automatic Differentiation.

The example code is as follows:

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore import dtype as mstype
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.sin = ops.Sin()
        self.cos = ops.Cos()

    def construct(self, x, y):
        a = self.sin(x)
        b = self.cos(y)
        out = a + b
        return out

class GradNet(nn.Cell):
    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net
        self.grad_op = nn.Vjp(net)

    def construct(self, x, y, v):
        output = self.grad_op(x, y, v)
        return output

x = Tensor([0.8, 0.6, 0.2], dtype=mstype.float32)
y = Tensor([0.7, 0.4, 0.3], dtype=mstype.float32)
v = Tensor([1, 1, 1], dtype=mstype.float32)
output = GradNet(Net())(x, y, v)
print(output)

The output is:

(Tensor(shape=[3], dtype=Float32, value= [ 1.48219836e+00, 1.48570347e+00, 1.15400589e+00]),
 (Tensor(shape=[3], dtype=Float32, value= [ 6.96706712e-01, 8.25335622e-01, 9.80066597e-01]),
  Tensor(shape=[3], dtype=Float32, value= [-6.44217670e-01, -3.89418334e-01, -2.95520216e-01])))

Functional Interfaces grad, jvp and vjp

Automatic differentiation plays an important role in the field of scientific computing, where functional interfaces are generally used. To improve the usability of automatic differentiation, MindSpore provides functional counterparts of GradOperation, Jvp, and Vjp: grad, jvp, and vjp. The functional interfaces do not require object initialization, which better fits users' habits.

functional grad

grad is used to generate the gradient of the input function. The grad_position and sens_param parameters control how the gradient is calculated. The default value of grad_position is 0, which means the derivative with respect to the first input is computed. When grad_position is set to an int or a tuple, the derivatives with respect to the inputs at the corresponding positions are computed. When sens_param is set to True, a sensitivity value scales the output of the network to change the final gradient. The default value of sens_param is False.

Example:

The grad_position parameter controls which inputs are differentiated.

import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import grad
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
    def construct(self, x, y, z):
        return x*y*z

x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
net = Net()
output = grad(net, grad_position=(1, 2))(x, y, z)
print(output)

results:

(Tensor(shape=[2, 2], dtype=Float32, value=
[[ 0.00000000e+00,  6.00000000e+00],
 [ 1.50000000e+01, -4.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
[[-2.00000000e+00,  6.00000000e+00],
 [-3.00000000e+00,  8.00000000e+00]]))

Example:

The sens_param parameter decides whether to scale the output value of the network to change the final gradient.

import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import grad
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
    def construct(self, x, y, z):
        return x**2 + y**2 + z**2, x*y*z

x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
v = Tensor(np.array([[-1, 3], [2, 1]]).astype(np.float32))
net = Net()
output = grad(net, grad_position=(1, 2), sens_param=True)(x, y, z, (v, v))
print(output)

result:

(Tensor(shape=[2, 2], dtype=Float32, value=
[[ 4.00000000e+00,  3.60000000e+01],
 [ 2.60000000e+01,  0.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
[[ 2.00000000e+00,  3.60000000e+01],
 [ 1.40000000e+01,  6.00000000e+00]]))

functional jvp

jvp corresponds to forward-mode automatic differentiation and returns both the result of the network and its differentiation. The first element of the output tuple is the network result and the second is the forward-mode gradient.

Example:

import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops import jvp
from mindspore import Tensor
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
    def construct(self, x, y):
        return x**3 + y

x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
output = jvp(Net(), (x, y), (v, v))
print(output)

results:

(Tensor(shape=[2, 2], dtype=Float32, value=
[[ 2.00000000e+00, 1.00000000e+01],
[ 3.00000000e+01, 6.80000000e+01]]), Tensor(shape=[2, 2], dtype=Float32, value=
[[ 4.00000000e+00, 1.30000000e+01],
[ 2.80000000e+01, 4.90000000e+01]]))

functional vjp

vjp corresponds to reverse-mode automatic differentiation and returns both the result of the network and its differentiation. The first element of the output tuple is the network result and the second is the reverse-mode gradient.

Example:

import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops import vjp
from mindspore import Tensor
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
    def construct(self, x, y):
        return x**3 + y

x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
output = vjp(Net(), (x, y), v)
print(output)

results:

(Tensor(shape=[2, 2], dtype=Float32, value=
[[ 2.00000000e+00, 1.00000000e+01],
[ 3.00000000e+01, 6.80000000e+01]]), (Tensor(shape=[2, 2], dtype=Float32, value=
[[ 3.00000000e+00, 1.20000000e+01],
[ 2.70000000e+01, 4.80000000e+01]]), Tensor(shape=[2, 2], dtype=Float32, value=
[[ 1.00000000e+00, 1.00000000e+00],
[ 1.00000000e+00, 1.00000000e+00]])))

References

[1] Zhang L, Han J, Wang H, et al. Deep potential molecular dynamics: a scalable model with the accuracy of quantum mechanics[J]. Physical review letters, 2018, 120(14): 143001.

[2] Raissi M, Perdikaris P, Karniadakis G E. Physics informed deep learning (part i): Data-driven solutions of nonlinear partial differential equations[J]. arXiv preprint arXiv:1711.10561, 2017.

[3] Baydin A G, Pearlmutter B A, Radul A A, et al. Automatic differentiation in machine learning: a survey[J]. The Journal of Machine Learning Research, 2017, 18(1): 5595-5637.