mindspore.nn.MatrixDiag

class mindspore.nn.MatrixDiag[源代码]

根据对角线值返回一批对角矩阵。

假设 x\(k\) 个维度 \([I, J, K, ..., N]\) ,则输出秩为 \(k+1\) 且维度为 \([I, J, K, ..., N, N]\) 的Tensor,其中: \(output[i, j, k, ..., m, n] = 1\{m=n\} * x[i, j, k, ..., n]\)

输入:
  • x (Tensor) - 输入任意维度的对角线值。支持的数据类型包括:float32、float16、int32、int8和uint8。

输出:

Tensor,shape与输入 x 相同。Shape必须为 \(x.shape + (x.shape[-1], )\)

异常:
  • TypeError - x 的数据类型不是float32、float16、int32、int8或uint8。

支持平台:

Ascend

样例:

>>> x = Tensor(np.array([1, -1]), mindspore.float32)
>>> matrix_diag = nn.MatrixDiag()
>>> output = matrix_diag(x)
>>> print(x.shape)
(2,)
>>> print(output)
[[ 1.  0.]
 [ 0. -1.]]
>>> print(output.shape)
(2, 2)
>>> x = Tensor(np.array([[1, -1], [1, -1]]), mindspore.float32)
>>> matrix_diag = nn.MatrixDiag()
>>> output = matrix_diag(x)
>>> print(x.shape)
(2, 2)
>>> print(output)
[[[ 1.  0.]
  [ 0. -1.]]
 [[ 1.  0.]
  [ 0. -1.]]]
>>> print(output.shape)
(2, 2, 2)
>>> x = Tensor(np.array([[1, -1, 1], [1, -1, 1]]), mindspore.float32)
>>> matrix_diag = nn.MatrixDiag()
>>> output = matrix_diag(x)
>>> print(x.shape)
(2, 3)
>>> print(output)
[[[ 1.  0.  0.]
  [ 0. -1.  0.]
  [ 0.  0.  1.]]
 [[ 1.  0.  0.]
  [ 0. -1.  0.]
  [ 0.  0.  1.]]]
>>> print(output.shape)
(2, 3, 3)