mindspore.nn.MatrixDiag
- class mindspore.nn.MatrixDiag[源代码]
根据对角线值返回一批对角矩阵。
假设 x 有 \(k\) 个维度 \([I, J, K, ..., N]\) ,则输出秩为 \(k+1\) 且维度为 \([I, J, K, ..., N, N]\) 的Tensor,其中: \(output[i, j, k, ..., m, n] = 1\{m=n\} * x[i, j, k, ..., n]\) 。
- 输入:
x (Tensor) - 输入任意维度的对角线值。支持的数据类型包括:float32、float16、int32、int8和uint8。
- 输出:
Tensor,shape与输入 x 相同。Shape必须为 \(x.shape + (x.shape[-1], )\) 。
- 异常:
TypeError - x 的数据类型不是float32、float16、int32、int8或uint8。
- 支持平台:
Ascend
样例:
>>> x = Tensor(np.array([1, -1]), mindspore.float32) >>> matrix_diag = nn.MatrixDiag() >>> output = matrix_diag(x) >>> print(x.shape) (2,) >>> print(output) [[ 1. 0.] [ 0. -1.]] >>> print(output.shape) (2, 2) >>> x = Tensor(np.array([[1, -1], [1, -1]]), mindspore.float32) >>> matrix_diag = nn.MatrixDiag() >>> output = matrix_diag(x) >>> print(x.shape) (2, 2) >>> print(output) [[[ 1. 0.] [ 0. -1.]] [[ 1. 0.] [ 0. -1.]]] >>> print(output.shape) (2, 2, 2) >>> x = Tensor(np.array([[1, -1, 1], [1, -1, 1]]), mindspore.float32) >>> matrix_diag = nn.MatrixDiag() >>> output = matrix_diag(x) >>> print(x.shape) (2, 3) >>> print(output) [[[ 1. 0. 0.] [ 0. -1. 0.] [ 0. 0. 1.]] [[ 1. 0. 0.] [ 0. -1. 0.] [ 0. 0. 1.]]] >>> print(output.shape) (2, 3, 3)