mindspore.nn.MatrixSetDiag

class mindspore.nn.MatrixSetDiag[源代码]

将输入的对角矩阵的对角线值置换为输入的对角线值。

假设 x\(k+1\) 个维度 \([I,J,K,...,M,N]\)diagonal\(k\) 个维度 \([I, J, K, ..., min(M, N)]\) ,则输出秩为 \(k+1\) ,维度为 \([I, J, K, ..., M, N]\) 的Tensor,其中:

\[output[i, j, k, ..., m, n] = diagnoal[i, j, k, ..., n]\ for\ m == n\]
\[output[i, j, k, ..., m, n] = x[i, j, k, ..., m, n]\ for\ m != n\]

输入:

  • x (Tensor) - 输入的对角矩阵。秩为k+1,k大于等于1。支持如下数据类型:float32、float16、int32、int8和uint8。

  • diagonal (Tensor) - 输入的对角线值。必须与输入 x 的shape相同。秩为k,k大于等于1。

输出:

Tensor,shape和数据类型与输入 x 相同。

异常:

  • TypeError - xdiagonal 的数据类型不是float32、float16、int32、int8或uint8。

  • ValueError - x 的shape长度小于2。

  • ValueError - \(x\_shape[-2] < x\_shape[-1]\)\(x\_shape[:-1] != diagonal\_shape\)

  • ValueError - \(x\_shape[-2] >= x\_shape[-1]\)\(x\_shape[:-2] + x\_shape[-1:] != diagonal\_shape\)

支持平台:

Ascend

样例:

>>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
>>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32)
>>> matrix_set_diag = nn.MatrixSetDiag()
>>> output = matrix_set_diag(x, diagonal)
>>> print(output)
[[[-1.  0.]
  [ 0.  2.]]
 [[-1.  0.]
  [ 0.  1.]]
 [[-1.  0.]
  [ 0.  1.]]]