# Gradient Derivation

## Automatic Differentiation Comparison

MindSpore and PyTorch both provide automatic differentiation: once the forward network is defined, a simple interface call performs automatic backpropagation and gradient updates. Note, however, that MindSpore and PyTorch build the backward graph with different logic, and this difference is also reflected in the API design.
PyTorch's automatic differentiation:
```python
# torch.autograd:
# backward accumulates gradients; after each update they must be
# cleared (e.g. via the optimizer)
import torch
from torch.autograd import Variable  # Variable is deprecated; a tensor with requires_grad=True suffices

x = Variable(torch.ones(2, 2), requires_grad=True)
x = x * 2
y = x - 1
y.backward(x)
```
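To make the accumulation comment above concrete, here is a minimal sketch (not part of the original example) showing how repeated `backward()` calls add into `.grad` rather than overwrite it, and how clearing the gradient restores independent passes:

```python
import torch

x = torch.tensor(1.0, requires_grad=True)

# First backward pass: d(x**2)/dx = 2x = 2.0
(x**2).backward()
print(x.grad)  # tensor(2.)

# Second backward pass without clearing: gradients accumulate, 2.0 + 2.0
(x**2).backward()
print(x.grad)  # tensor(4.)

# Clearing the gradient (optimizer.zero_grad() in training loops) resets it
x.grad = None
(x**2).backward()
print(x.grad)  # tensor(2.)
```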
MindSpore's automatic differentiation:

```python
# ms.grad:
# the grad interface takes the forward graph as input and returns the backward graph
import mindspore as ms
from mindspore import nn

class GradNetWrtX(nn.Cell):
    def __init__(self, net):
        super(GradNetWrtX, self).__init__()
        self.net = net

    def construct(self, x, y):
        gradient_function = ms.grad(self.net)
        return gradient_function(x, y)
```
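As a usage illustration, here is a minimal sketch that reuses the `GradNetWrtX` wrapper defined above on a hypothetical two-input network `Net` (a plain `MatMul`, not part of the original example); by default `ms.grad` differentiates with respect to the first input:

```python
import numpy as np
import mindspore as ms
from mindspore import nn, ops

# Hypothetical forward network: a single matrix multiplication
class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.matmul = ops.MatMul()

    def construct(self, x, y):
        return self.matmul(x, y)

x = ms.Tensor(np.ones((2, 3)), ms.float32)
y = ms.Tensor(np.ones((3, 4)), ms.float32)

# GradNetWrtX is the wrapper defined in the block above
grad_net = GradNetWrtX(Net())
grads = grad_net(x, y)
print(grads.shape)  # (2, 3): gradient w.r.t. the first input x
```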
PyTorch:

```python
# Before backward is called, x.grad and y.grad are None;
# after backward, x.grad and y.grad hold the computed derivatives
import torch

print("=== tensor.backward ===")
x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)
z = x**2 + y
print("x.grad before backward", x.grad)
print("y.grad before backward", y.grad)
z.backward()
print("z", z)
print("x.grad", x.grad)
print("y.grad", y.grad)

print("=== torch.autograd.backward ===")
x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)
z = x**2 + y
torch.autograd.backward(z)
print("z", z)
print("x.grad", x.grad)
print("y.grad", y.grad)
```

Result:

```text
=== tensor.backward ===
x.grad before backward None
y.grad before backward None
z tensor(3., grad_fn=<AddBackward0>)
x.grad tensor(2.)
y.grad tensor(1.)
=== torch.autograd.backward ===
z tensor(3., grad_fn=<AddBackward0>)
x.grad tensor(2.)
y.grad tensor(1.)
```

MindSpore:

```python
import mindspore

print("=== mindspore.grad ===")
x = mindspore.Tensor(1.0)
y = mindspore.Tensor(2.0)

def net(x, y):
    return x**2 + y

out = mindspore.grad(net, grad_position=0)(x, y)
print("out", out)
out1 = mindspore.grad(net, grad_position=1)(x, y)
print("out1", out1)
```

Result:

```text
=== mindspore.grad ===
out 2.0
out1 1.0
```
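Beyond a single index, `grad_position` also accepts a tuple of input positions, so the gradients for several inputs can be obtained in one call; a minimal sketch of that variant:

```python
import mindspore

x = mindspore.Tensor(1.0)
y = mindspore.Tensor(2.0)

def net(x, y):
    return x**2 + y

# Passing a tuple returns the gradients for all listed inputs at once:
# d(x**2 + y)/dx = 2x = 2.0 and d(x**2 + y)/dy = 1.0
grads = mindspore.grad(net, grad_position=(0, 1))(x, y)
print(grads)
```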
PyTorch:

```python
# torch.autograd.backward does not support multiple outputs directly
import torch

print("=== torch.autograd.backward does not support multiple outputs ===")
x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(2.0, requires_grad=True)
z = x**2 + y
torch.autograd.backward(z)
print("z", z)
print("x.grad", x.grad)
print("y.grad", y.grad)
```

Result:

```text
=== torch.autograd.backward does not support multiple outputs ===
z tensor(3., grad_fn=<AddBackward0>)
x.grad tensor(2.)
y.grad tensor(1.)
```

MindSpore:

```python
# mindspore.grad supports networks with multiple outputs;
# by default the gradients of all outputs are accumulated
import mindspore

print("=== mindspore.grad multiple outputs ===")
x = mindspore.Tensor(1.0)
y = mindspore.Tensor(2.0)

def net(x, y):
    return x**2 + y, x

# w.r.t. x: d(x**2 + y)/dx + dx/dx = 2x + 1 = 3.0
out = mindspore.grad(net, grad_position=0)(x, y)
print("out", out)
# w.r.t. y: d(x**2 + y)/dy + dx/dy = 1.0 + 0 = 1.0
out1 = mindspore.grad(net, grad_position=1)(x, y)
print("out1", out1)
```

Result:

```text
=== mindspore.grad multiple outputs ===
out 3.0
out1 1.0
```
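When some outputs should not contribute to the gradient, `mindspore.grad` provides the `has_aux` parameter: with `has_aux=True`, only the first output is differentiated and the remaining outputs are returned unchanged. A minimal sketch, assuming the same two-output `net` as above:

```python
import mindspore

x = mindspore.Tensor(1.0)
y = mindspore.Tensor(2.0)

def net(x, y):
    return x**2 + y, x  # the second output is treated as auxiliary

# has_aux=True: only the first output contributes to the gradient
grad_fn = mindspore.grad(net, grad_position=0, has_aux=True)
out, aux = grad_fn(x, y)
print("out", out)  # 2.0: d(x**2 + y)/dx only; the aux output x is not accumulated
print("aux", aux)  # the auxiliary output x, returned as-is
```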