mindspore.nn.MatrixSetDiag

class mindspore.nn.MatrixSetDiag[源代码]

将输入的对角矩阵的对角线值置换为输入的对角线值。

假设 xk+1 个维度 [I,J,K,...,M,N]diagonalk 个维度 [I,J,K,...,min(M,N)] ,则输出秩为 k+1 ,维度为 [I,J,K,...,M,N] 的Tensor,其中:

output[i,j,k,...,m,n]=diagnoal[i,j,k,...,n] for m==n
output[i,j,k,...,m,n]=x[i,j,k,...,m,n] for m!=n

输入:

  • x (Tensor) - 输入的对角矩阵。秩为k+1,k大于等于1。支持如下数据类型:float32、float16、int32、int8和uint8。

  • diagonal (Tensor) - 输入的对角线值。必须与输入 x 的shape相同。秩为k,k大于等于1。

输出:

Tensor,shape和数据类型与输入 x 相同。

异常:

  • TypeError - xdiagonal 的数据类型不是float32、float16、int32、int8或uint8。

  • ValueError - x 的shape长度小于2。

  • ValueError - x_shape[2]<x_shape[1]x_shape[:1]!=diagonal_shape

  • ValueError - x_shape[2]>=x_shape[1]x_shape[:2]+x_shape[1:]!=diagonal_shape

支持平台:

Ascend

样例:

>>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
>>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32)
>>> matrix_set_diag = nn.MatrixSetDiag()
>>> output = matrix_set_diag(x, diagonal)
>>> print(output)
[[[-1.  0.]
  [ 0.  2.]]
 [[-1.  0.]
  [ 0.  1.]]
 [[-1.  0.]
  [ 0.  1.]]]