比较与torch.diag的功能差异
torch.diag
torch.diag(
input,
diagonal=0,
out=None
)
更多内容详见torch.diag。
mindspore.nn.MatrixDiag
class mindspore.nn.MatrixDiag()(x)
更多内容详见mindspore.nn.MatrixDiag。
使用方式
PyTorch:仅支持1D和2D,如果输入是1D,则将返回一个2D的对角矩阵,除对角线外,均置0。如果输入是2D,则返回该矩阵的对角线上的值。同时,它支持通过参数diagonal
指定对角线偏移量。
MindSpore:根据给定的值返回一个对角矩阵,对于k维的输入,将返回k+1维的对角矩阵。
代码示例
import mindspore
from mindspore import Tensor, nn
import torch
import numpy as np
x1 = np.random.randn(2)
x2 = np.random.randn(2, 3)
x3 = np.random.randn(2, 3, 4)
# In MindSpore, for the given k-dimension input, a k+1 dimension diagonal matrix will be returned.
matrix_diag = nn.MatrixDiag()
for n, x in enumerate([x1, x2, x3]):
try:
input_x = Tensor(x, mindspore.float32)
output = matrix_diag(input_x)
print('input shape: {}; output size: {}'.format(
str(n + 1), str(output.shape)
))
except Exception as e:
print('ERROR: ' + str(e))
# Out:
# input shape: 1; output size: (2, 2)
# input shape: 2; output size: (2, 3, 3)
# input shape: 3; output size: (2, 3, 4, 4)
# In torch, output for 1-dimension and 2-dimension input will be returned based on different rules.
# If the dimension of the input is greater than 2, it will raise error.
for n, x in enumerate([x1, x2, x3]):
try:
input_x = torch.tensor(x)
output = torch.diag(input_x)
print('input shape: {}; output size: {}'.format(
str(n + 1), str(output.shape)
))
except Exception as e:
print('ERROR: ' + str(e))
# Out:
# input shape: 1; output size: torch.Size([2, 2])
# input shape: 2; output size: torch.Size([2])
# ERROR: matrix or a vector expected