比较与torch.autograd.backward和torch.autograd.grad的功能差异

查看源文件

torch.autograd.backward

torch.autograd.backward(
  tensors,
  grad_tensors=None,
  retain_graph=None,
  create_graph=False,
  grad_variables=None
)

更多内容详见torch.autograd.backward

torch.autograd.grad

torch.autograd.grad(
  outputs,
  inputs,
  grad_outputs=None,
  retain_graph=None,
  create_graph=False,
  only_inputs=True,
  allow_unused=False
)

更多内容详见torch.autograd.grad

mindspore.ops.GradOperation

class mindspore.ops.GradOperation(
  get_all=False,
  get_by_list=False,
  sens_param=False
)

更多内容详见mindspore.ops.GradOperation

使用方式

PyTorch:使用torch.autograd.backward计算给定Tensor关于叶子节点的梯度总和,反向传播计算Tensor的梯度时,只计算requires_grad=True的叶子节点的梯度。使用torch.autograd.grad计算并返回输出关于输入的梯度总和,如果only_inputs为True,仅返回与指定输入相关的梯度列表。

MindSpore:计算梯度,其中get_all为False时,只会对第一个输入求导,为True时,会对所有输入求导;get_by_list为False时,不会对权重求导,为True时,会对权重求导;sens_param对网络的输出值做缩放以改变最终梯度。

代码示例

import numpy as np
import mindspore.nn as nn
import mindspore as ms
from mindspore import ops

# In MindSpore:
class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.matmul = ops.MatMul()
        self.z = ms.Parameter(ms.Tensor(np.array([1.0], np.float32)), name='z')
    def construct(self, x, y):
        x = x * self.z
        out = self.matmul(x, y)
        return out

class GradNetWrtX(nn.Cell):
    def __init__(self, net):
        super(GradNetWrtX, self).__init__()
        self.net = net
        self.grad_op = ops.GradOperation()
    def construct(self, x, y):
        gradient_function = self.grad_op(self.net)
        return gradient_function(x, y)

x = ms.Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=ms.float32)
y = ms.Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=ms.float32)
output = GradNetWrtX(Net())(x, y)
print(output)
# Out:
# [[1.4100001 1.5999999 6.6      ]
#  [1.4100001 1.5999999 6.6      ]]

# In torch:
import torch
x = torch.tensor(2., requires_grad=True)
y = torch.tensor(3., requires_grad=True)
z = x * x * y
z.backward()
print(x.grad, y.grad)
# Out:
# tensor(12.) tensor(4.)

x = torch.tensor(2.).requires_grad_()
y = torch.tensor(3.).requires_grad_()
z = x * x * y
grad_x = torch.autograd.grad(outputs=z, inputs=x)
print(grad_x[0])
# Out:
# tensor(12.)