# 比较与torch.diag的差异

[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.1/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/source_zh_cn/note/api_mapping/pytorch_diff/diag.md)

以下映射关系均可参考本文。

|     PyTorch APIs      |      MindSpore APIs       |
| :-------------------: | :-----------------------: |
|    torch.diag     |  mindspore.ops.diag   |
|   torch.Tensor.diag    |   mindspore.Tensor.diag    |

## torch.diag

```text
torch.diag(input, diagonal=0, *, out=None) -> Tensor
```

更多内容详见[torch.diag](https://pytorch.org/docs/1.8.1/generated/torch.diag.html)。

## mindspore.ops.diag

```text
mindspore.ops.diag(input) -> Tensor
```

更多内容详见[mindspore.ops.diag](https://www.mindspore.cn/docs/zh-CN/r2.1/api_python/ops/mindspore.ops.diag.html)。

## 差异对比

MindSpore此API功能与PyTorch不一致。

PyTorch:若输入为一维张量,用输入的对角线值构成的一维张量来构造对角线张量;若输入为矩阵,则返回由输入的对角线元素组成的一维张量。

MindSpore:MindSpore此API,若输入为一维张量,则实现与PyTorch相同的功能;若输入为矩阵,则不能实现与PyTorch相同的功能,且没有`diagonal`参数控制对角线的位置。

| 分类 | 子类 |PyTorch | MindSpore | 差异 |
| --- | --- | --- | --- |---|
|参数 | 参数1 | input | input |- |
| | 参数2 | diagonal | - | PyTorch中`diagonal`的取值用于控制对角线的位置,MindSpore无此参数 |
| | 参数3 | out | - | 详见[通用差异参数表](https://www.mindspore.cn/docs/zh-CN/r2.1/note/api_mapping/pytorch_api_mapping.html#通用差异参数表) |

### 代码示例1

> PyTorch的此API参数`x`支持多维张量和一维张量,且存在`diagonal`参数用于控制对角线的位置,而MindSpore此API不存在`diagonal`参数;当输入参数x为一维张量且`diagonal`为0时,两API实现相同的功能。

```python
# PyTorch
import torch
x = torch.tensor([1,2,3,4],dtype=int)
out = torch.diag(x)
out = out.detach().numpy()
print(out)
# [[1 0 0 0]
#  [0 2 0 0]
#  [0 0 3 0]
#  [0 0 0 4]]

# MindSpore
from mindspore import Tensor
import mindspore.ops as ops
input_x = Tensor([1, 2, 3, 4]).astype('int32')
output = ops.diag(input_x)
print(output)
# [[1 0 0 0]
#  [0 2 0 0]
#  [0 0 3 0]
#  [0 0 0 4]]

```

### 代码示例2

> 当输入参数`x`为一维张量且`diagonal`不为0时,PyTorch的此API可控制对角线的位置,而MindSpore的此API没有`diagonal`参数,可以将此API得到的输出使用mindspore.ops.pad进行处理,从而实现相同功能。

```python
# PyTorch
import torch
x = torch.tensor([1,2,3,4],dtype=int)
# diagonal大于0时的结果
out = torch.diag(x, diagonal=1)
out = out.detach().numpy()
print(out)
# [[0 1 0 0 0]
#  [0 0 2 0 0]
#  [0 0 0 3 0]
#  [0 0 0 0 4]
#  [0 0 0 0 0]]

# diagonal小于0时的结果
out = torch.diag(x, diagonal=-1)
out = out.detach().numpy()
print(out)
# [[0 0 0 0 0]
#  [1 0 0 0 0]
#  [0 2 0 0 0]
#  [0 0 3 0 0]
#  [0 0 0 4 0]]

# MindSpore
from mindspore import Tensor
import mindspore.ops as ops
input_x = Tensor([1, 2, 3, 4]).astype('int32')
output = ops.diag(input_x)
# MindSpore对应于diagonal大于0时的此API功能实现
padding = ((1, 0, 0, 1))
a = ops.pad(output, padding)
print(a)
# [[0 1 0 0 0]
#  [0 0 2 0 0]
#  [0 0 0 3 0]
#  [0 0 0 0 4]
#  [0 0 0 0 0]]

# MindSpore对应于diagonal小于0时的此API功能实现
padding = ((0, 1, 1, 0))
a = ops.pad(output, padding)
print(a)
# [[0 0 0 0 0]
#  [1 0 0 0 0]
#  [0 2 0 0 0]
#  [0 0 3 0 0]
#  [0 0 0 4 0]]
```

### 代码示例3

> PyTorch的此API输入为矩阵且使用`diagonal`时用于提取对角线组成的一维张量,MindSpore此API不支持此功能,使用mindspore.numpy.diag算子可实现此功能。

```python
# PyTorch
import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]],dtype=int)
# diagonal大于0时的结果
out = torch.diag(x, diagonal=1)
out = out.detach().numpy()
print(out)
# [2 6]

# diagonal为默认值0时的结果
out = torch.diag(x)
out = out.detach().numpy()
print(out)
# [1 5 9]

# diagonal小于0时的结果
out = torch.diag(x, diagonal=-1)
out = out.detach().numpy()
print(out)
# [4 8]

# MindSpore
from mindspore import Tensor
import mindspore.ops as ops
import mindspore.numpy as np
input_x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype('int32')
# 对应于diagonal大于0时的mindspore.numpy.diag的此功能实现
output = np.diag(input_x, k=1)
print(output)
# [2 6]

# 对应于diagonal默认为0时的mindspore.numpy.diag的此功能实现
output = np.diag(input_x)
print(output)
# [1 5 9]

# 对应于diagonal小于0时的mindspore.numpy.diag的此功能实现
output = np.diag(input_x, k=-1)
print(output)
# [4 8]
```