比较与torch.diag的差异

查看源文件

以下映射关系均可参考本文。

PyTorch APIs

MindSpore APIs

torch.diag

mindspore.ops.diag

torch.Tensor.diag

mindspore.Tensor.diag

torch.diag

torch.diag(input, diagonal=0, *, out=None) -> Tensor

更多内容详见torch.diag

mindspore.ops.diag

mindspore.ops.diag(input) -> Tensor

更多内容详见mindspore.ops.diag

差异对比

MindSpore此API功能与PyTorch不一致。

PyTorch:若输入为一维张量,用输入的对角线值构成的一维张量来构造对角线张量;若输入为矩阵,则返回由输入的对角线元素组成的一维张量。

MindSpore:MindSpore此API,若输入为一维张量,则实现与PyTorch相同的功能;若输入为矩阵,则不能实现与PyTorch相同的功能,且没有diagonal参数控制对角线的位置。

分类

子类

PyTorch

MindSpore

差异

参数

参数1

input

input

-

参数2

diagonal

-

PyTorch中diagonal的取值用于控制对角线的位置,MindSpore无此参数

参数3

out

-

详见通用差异参数表

代码示例1

PyTorch的此API参数x支持多维张量和一维张量,且存在diagonal参数用于控制对角线的位置,而MindSpore此API不存在diagonal参数;当输入参数x为一维张量且diagonal为0时,两API实现相同的功能。

# PyTorch
import torch
x = torch.tensor([1,2,3,4],dtype=int)
out = torch.diag(x)
out = out.detach().numpy()
print(out)
# [[1 0 0 0]
#  [0 2 0 0]
#  [0 0 3 0]
#  [0 0 0 4]]

# MindSpore
from mindspore import Tensor
import mindspore.ops as ops
input_x = Tensor([1, 2, 3, 4]).astype('int32')
output = ops.diag(input_x)
print(output)
# [[1 0 0 0]
#  [0 2 0 0]
#  [0 0 3 0]
#  [0 0 0 4]]

代码示例2

当输入参数x为一维张量且diagonal不为0时,PyTorch的此API可控制对角线的位置,而MindSpore的此API没有diagonal参数,可以将此API得到的输出使用mindspore.ops.pad进行处理,从而实现相同功能。

# PyTorch
import torch
x = torch.tensor([1,2,3,4],dtype=int)
# diagonal大于0时的结果
out = torch.diag(x, diagonal=1)
out = out.detach().numpy()
print(out)
# [[0 1 0 0 0]
#  [0 0 2 0 0]
#  [0 0 0 3 0]
#  [0 0 0 0 4]
#  [0 0 0 0 0]]

# diagonal小于0时的结果
out = torch.diag(x, diagonal=-1)
out = out.detach().numpy()
print(out)
# [[0 0 0 0 0]
#  [1 0 0 0 0]
#  [0 2 0 0 0]
#  [0 0 3 0 0]
#  [0 0 0 4 0]]

# MindSpore
from mindspore import Tensor
import mindspore.ops as ops
input_x = Tensor([1, 2, 3, 4]).astype('int32')
output = ops.diag(input_x)
# MindSpore对应于diagonal大于0时的此API功能实现
padding = ((1, 0, 0, 1))
a = ops.pad(output, padding)
print(a)
# [[0 1 0 0 0]
#  [0 0 2 0 0]
#  [0 0 0 3 0]
#  [0 0 0 0 4]
#  [0 0 0 0 0]]

# MindSpore对应于diagonal小于0时的此API功能实现
padding = ((0, 1, 1, 0))
a = ops.pad(output, padding)
print(a)
# [[0 0 0 0 0]
#  [1 0 0 0 0]
#  [0 2 0 0 0]
#  [0 0 3 0 0]
#  [0 0 0 4 0]]

代码示例3

PyTorch的此API输入为矩阵且使用diagonal时用于提取对角线组成的一维张量,MindSpore此API不支持此功能,使用mindspore.numpy.diag算子可实现此功能。

# PyTorch
import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]],dtype=int)
# diagonal大于0时的结果
out = torch.diag(x, diagonal=1)
out = out.detach().numpy()
print(out)
# [2 6]

# diagonal为默认值0时的结果
out = torch.diag(x)
out = out.detach().numpy()
print(out)
# [1 5 9]

# diagonal小于0时的结果
out = torch.diag(x, diagonal=-1)
out = out.detach().numpy()
print(out)
# [4 8]

# MindSpore
from mindspore import Tensor
import mindspore.ops as ops
import mindspore.numpy as np
input_x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype('int32')
# 对应于diagonal大于0时的mindspore.numpy.diag的此功能实现
output = np.diag(input_x, k=1)
print(output)
# [2 6]

# 对应于diagonal默认为0时的mindspore.numpy.diag的此功能实现
output = np.diag(input_x)
print(output)
# [1 5 9]

# 对应于diagonal小于0时的mindspore.numpy.diag的此功能实现
output = np.diag(input_x, k=-1)
print(output)
# [4 8]