mindspore.ops.matrix_set_diag

查看源文件
mindspore.ops.matrix_set_diag(x, diagonal, k=0, align='RIGHT_LEFT')[源代码]

返回一个tensor,使用输入 diagonal 中的元素值替换 x 矩阵的第 k[0] 条到第 k[1] 条对角线上的元素值。

参数:
  • x (Tensor) - 输入tensor,其秩不小于2。

  • diagonal (Tensor) - 输入对角线tensor。

  • k (Union[int, Tensor], 可选) - 对角线偏移。正值表示超对角线,负值表示次对角线。当k是2个整数,表示子对角线的上界和下界。默认 0

  • align (str, 可选) - 可选字符串,指定超对角线和次对角线的对齐方式。 可选 "RIGHT_LEFT""LEFT_RIGHT""LEFT_LEFT""RIGHT_RIGHT" 。 默认 "RIGHT_LEFT"

    • "RIGHT_LEFT" 表示将超对角线与右侧对齐(左侧填充行),将次对角线与左侧对齐(右侧填充行)。

    • "LEFT_RIGHT" 表示将超对角线与左侧对齐(右侧填充行),将次对角线与右侧对齐(左侧填充行)。

    • "LEFT_LEFT" 表示将超对角线和次对角线均与左侧对齐(右侧填充行)。

    • "RIGHT_RIGHT" 表示将超对角线和次对角线均与右侧对齐(左侧填充行)。

返回:

Tensor

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore
>>> x = mindspore.tensor([[7., 7., 7., 7.],
...                       [7., 7., 7., 7.],
...                       [7., 7., 7., 7.]])
>>> diagonal = mindspore.tensor([[0., 9., 1.],
...                    [6., 5., 8.],
...                    [1., 2., 3.],
...                    [4., 5., 0.]])
>>> k = mindspore.tensor(([-1, 2]), mindspore.int32)
>>> align = 'RIGHT_LEFT'
>>> output = ops.matrix_set_diag(x, diagonal, k, align)
>>> print(output)
[[1. 6. 9. 7.]
 [4. 2. 5. 1.]
 [7. 5. 3. 8.]]
>>> print(output.shape)
(3, 4)