mindspore.ops.matrix_set_diag

mindspore.ops.matrix_set_diag(x, diagonal, k=0, align='RIGHT_LEFT')

Return a tensor in which the elements on the k[0]-th through k[1]-th diagonals of the matrix x are replaced with the values from the input diagonal.

Parameters
  • x (Tensor) – The input tensor with rank r, where r >= 2.

  • diagonal (Tensor) – The tensor containing the values to write onto the selected diagonals.

  • k (Union[int, Tensor], optional) – Diagonal offsets. A positive value refers to a superdiagonal, and a negative value refers to a subdiagonal. k can be a single integer (for a single diagonal) or a pair of integers specifying the low and high ends of a matrix band. Default: 0.

  • align (str, optional) –

    Specifies how superdiagonals and subdiagonals should be aligned. Supported values: "RIGHT_LEFT", "LEFT_RIGHT", "LEFT_LEFT", "RIGHT_RIGHT". Default: "RIGHT_LEFT".

    • When set to "RIGHT_LEFT", superdiagonals are aligned to the right (padding the row on the left), while subdiagonals are aligned to the left (padding the row on the right).

    • When set to "LEFT_RIGHT", superdiagonals are aligned to the left (padding the row on the right), while subdiagonals are aligned to the right (padding the row on the left).

    • When set to "LEFT_LEFT", both superdiagonals and subdiagonals are aligned to the left (padding the row on the right).

    • When set to "RIGHT_RIGHT", both superdiagonals and subdiagonals are aligned to the right (padding the row on the left).
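
To make the alignment rules concrete, the following is a minimal pure-Python sketch for a single 2-D matrix stored as nested lists. The helper name set_diag_band and its explicit k_low / k_high arguments are hypothetical and not part of MindSpore. The sketch assumes that the first row of diagonal holds the k_high-th diagonal and the last row holds the k_low-th, which is the layout used in the Examples section of this page. It only illustrates the indexing and should not be read as the library's implementation.

```python
def set_diag_band(x, diagonal, k_low, k_high, align="RIGHT_LEFT"):
    """Illustrative reference for one 2-D matrix (no batching, no validation).

    Hypothetical helper: `diagonal` is assumed to have one row per diagonal
    in the band, ordered from k_high (top row) down to k_low (bottom row).
    """
    rows, cols = len(x), len(x[0])
    out = [list(r) for r in x]  # copy so the input is left untouched
    # First word applies to superdiagonals, second word to subdiagonals.
    super_align, sub_align = align.split("_")
    max_len = len(diagonal[0])  # every row is padded to this length

    for row_idx, k in enumerate(range(k_high, k_low - 1, -1)):
        # Number of matrix elements that lie on the k-th diagonal.
        diag_len = min(rows + min(k, 0), cols - max(k, 0))
        side = super_align if k >= 0 else sub_align
        # Right alignment means the padding sits at the start of the row.
        offset = max_len - diag_len if side == "RIGHT" else 0
        for i in range(diag_len):
            r = i - min(k, 0)  # subdiagonals start further down
            c = i + max(k, 0)  # superdiagonals start further right
            out[r][c] = diagonal[row_idx][offset + i]
    return out


x = [[7.0] * 4 for _ in range(3)]
diagonal = [[0.0, 9.0, 1.0],
            [6.0, 5.0, 8.0],
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 0.0]]
result = set_diag_band(x, diagonal, -1, 2, "RIGHT_LEFT")
for row in result:
    print(row)
```

With "RIGHT_LEFT", the top row [0., 9., 1.] contributes its last two entries to the 2nd superdiagonal, and the bottom row [4., 5., 0.] contributes its first two entries to the 1st subdiagonal. Passing "LEFT_RIGHT" to the same sketch swaps which end of those two rows is treated as padding.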

Returns

Tensor

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore
>>> from mindspore import ops
>>> x = mindspore.tensor([[7., 7., 7., 7.],
...                       [7., 7., 7., 7.],
...                       [7., 7., 7., 7.]])
>>> diagonal = mindspore.tensor([[0., 9., 1.],
...                              [6., 5., 8.],
...                              [1., 2., 3.],
...                              [4., 5., 0.]])
>>> k = mindspore.tensor([-1, 2], mindspore.int32)
>>> align = 'RIGHT_LEFT'
>>> output = ops.matrix_set_diag(x, diagonal, k, align)
>>> print(output)
[[1. 6. 9. 7.]
 [4. 2. 5. 1.]
 [7. 5. 3. 8.]]
>>> print(output.shape)
(3, 4)