mindspore.nn.MatrixSetDiag
- class mindspore.nn.MatrixSetDiag[源代码]
将输入的对角矩阵的对角线值置换为输入的对角线值。
假设 x 有 \(k+1\) 个维度 \([I,J,K,...,M,N]\) , diagonal 有 \(k\) 个维度 \([I, J, K, ..., min(M, N)]\) ,则输出秩为 \(k+1\) ,维度为 \([I, J, K, ..., M, N]\) 的Tensor,其中:
\[output[i, j, k, ..., m, n] = diagnoal[i, j, k, ..., n]\ for\ m == n\]\[output[i, j, k, ..., m, n] = x[i, j, k, ..., m, n]\ for\ m != n\]- 输入:
x (Tensor) - 输入的对角矩阵。秩为k+1,k大于等于1。支持如下数据类型:float32、float16、int32、int8和uint8。
diagonal (Tensor) - 输入的对角线值。必须与输入 x 的shape相同。秩为k,k大于等于1。
- 输出:
Tensor,shape和数据类型与输入 x 相同。
- 异常:
TypeError - x 或 diagonal 的数据类型不是float32、float16、int32、int8或uint8。
ValueError - x 的shape长度小于2。
ValueError - \(x\_shape[-2] < x\_shape[-1]\) 且 \(x\_shape[:-1] != diagonal\_shape\) 。
ValueError - \(x\_shape[-2] >= x\_shape[-1]\) 且 \(x\_shape[:-2] + x\_shape[-1:] != diagonal\_shape\) 。
- 支持平台:
Ascend
样例:
>>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32) >>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32) >>> matrix_set_diag = nn.MatrixSetDiag() >>> output = matrix_set_diag(x, diagonal) >>> print(output) [[[-1. 0.] [ 0. 2.]] [[-1. 0.] [ 0. 1.]] [[-1. 0.] [ 0. 1.]]]