mindspore.ops.matrix_diag_part

mindspore.ops.matrix_diag_part(x, k, padding_value, align='RIGHT_LEFT')

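Returns the diagonal part of the input tensor: a tensor containing the k[0]-th to k[1]-th diagonals of x (or the single k-th diagonal when k is one integer). Diagonals shorter than the longest extracted diagonal are padded with padding_value.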

In graph mode, k and padding_value must be constant tensors.
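To make "the d-th diagonal" concrete, here is a minimal pure-Python sketch (not MindSpore code; the helper name `diagonal` is hypothetical) for a single 2-D matrix stored as nested lists. It relies on the standard convention that element (i, j) lies on diagonal d exactly when j - i == d.

```python
def diagonal(matrix, d):
    """Hypothetical helper: return the d-th diagonal of a 2-D nested list.

    d > 0 selects a superdiagonal, d == 0 the main diagonal, d < 0 a subdiagonal.
    """
    rows, cols = len(matrix), len(matrix[0])
    # Element (i, j) is on diagonal d exactly when j - i == d, i.e. j = i + d.
    return [matrix[i][i + d] for i in range(rows) if 0 <= i + d < cols]


x = [[1, 2, 3, 4],
     [5, 6, 7, 8],
     [9, 8, 7, 6]]
print(diagonal(x, 1))   # [2, 7, 6]
print(diagonal(x, 0))   # [1, 6, 7]
print(diagonal(x, -1))  # [5, 8]
```

Diagonals of a non-square matrix have different lengths, which is why the operator needs padding_value.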

Parameters
  • x (Tensor) – The input tensor with rank r, where r >= 2.

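  • k (Union[int, Tensor], optional) – Diagonal offset(s). A positive value refers to a superdiagonal, 0 to the main diagonal, and a negative value to a subdiagonal. k can be a single integer, selecting one diagonal, or a pair of integers specifying the low and high ends of a matrix band, in which case k[0] must not be greater than k[1].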

  • padding_value (Union[int, float, Tensor], optional) – The value used to pad diagonals that are shorter than the longest extracted diagonal, so that every row of the output has the same length.

  • align (str, optional) –

    Specifies how superdiagonals and subdiagonals should be aligned. Supported values: "RIGHT_LEFT", "LEFT_RIGHT", "LEFT_LEFT", "RIGHT_RIGHT". Default: "RIGHT_LEFT".

    • When set to "RIGHT_LEFT", superdiagonals are aligned to the right (the row is padded on the left), while subdiagonals are aligned to the left (the row is padded on the right).

    • When set to "LEFT_RIGHT", superdiagonals are aligned to the left (the row is padded on the right), while subdiagonals are aligned to the right (the row is padded on the left).

    • When set to "LEFT_LEFT", both superdiagonals and subdiagonals are aligned to the left (the row is padded on the right).

    • When set to "RIGHT_RIGHT", both superdiagonals and subdiagonals are aligned to the right (the row is padded on the left).
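The alignment rules above can be sketched in pure Python as follows. This is an illustration, not MindSpore's implementation: the helper `pad_diagonal` is hypothetical, and it assumes the first half of the align string applies to superdiagonals and the second half to subdiagonals. It treats the main diagonal (d = 0) like a superdiagonal; this should not change any result, because when the band includes d = 0 the main diagonal is already the longest one and needs no padding.

```python
def pad_diagonal(diag, d, max_diag_len, padding_value, align="RIGHT_LEFT"):
    """Hypothetical helper: pad one extracted diagonal to max_diag_len.

    The first half of `align` applies to superdiagonals (d >= 0 here),
    the second half to subdiagonals (d < 0).
    """
    super_side, sub_side = align.split("_")
    side = super_side if d >= 0 else sub_side
    padding = [padding_value] * (max_diag_len - len(diag))
    # Aligning to the RIGHT means padding on the left, and vice versa.
    return padding + diag if side == "RIGHT" else diag + padding


print(pad_diagonal([4], 3, 3, 9))                # [9, 9, 4]
print(pad_diagonal([4], 3, 3, 9, "LEFT_RIGHT"))  # [4, 9, 9]
print(pad_diagonal([5, 8], -1, 3, 0))            # [5, 8, 0]
```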

Returns

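Tensor, with the same dtype as x. If x has shape [I, J, ..., M, N], the output has shape [I, J, ..., max_diag_len] when a single diagonal is extracted, or [I, J, ..., num_diags, max_diag_len] for a band, where num_diags = k[1] - k[0] + 1 and max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0)).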
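The following minimal pure-Python sketch shows how the output shape follows from the shape of x and from k. It assumes the band convention max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0)), which agrees with the (3, 3) result in the example below. The function name `matrix_diag_part_shape` is hypothetical.

```python
def matrix_diag_part_shape(x_shape, k):
    """Hypothetical helper: output shape for an input of shape (..., M, N)."""
    *batch, m, n = x_shape
    # A single int k is treated as the band (k, k).
    lo, hi = (k, k) if isinstance(k, int) else k
    max_diag_len = min(m + min(hi, 0), n + min(-lo, 0))
    num_diags = hi - lo + 1
    if num_diags == 1:
        # A single diagonal drops the num_diags dimension.
        return (*batch, max_diag_len)
    return (*batch, num_diags, max_diag_len)


print(matrix_diag_part_shape((3, 4), (1, 3)))  # (3, 3), matching the example
print(matrix_diag_part_shape((2, 3, 4), 0))    # (2, 3)
```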

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore
>>> x = mindspore.tensor([[1., 2., 3., 4.],
...                       [5., 6., 7., 8.],
...                       [9., 8., 7., 6.]])
>>> k = mindspore.tensor([1, 3], mindspore.int32)
>>> output = mindspore.ops.matrix_diag_part(x, k, mindspore.tensor(9.), align='RIGHT_LEFT')
>>> print(output)
[[9. 9. 4.]
 [9. 3. 8.]
 [2. 7. 6.]]
>>> print(output.shape)
(3, 3)
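To cross-check the printed result without MindSpore, the following pure-Python sketch reimplements the documented behavior for a single 2-D matrix (no batch dimensions). The function name `matrix_diag_part_ref` is hypothetical. The ordering of output rows, from diagonal k[1] down to k[0], is inferred from the example output above, and edge-case behavior may differ from the real operator.

```python
def matrix_diag_part_ref(matrix, k, padding_value, align="RIGHT_LEFT"):
    """Hypothetical pure-Python reference for one 2-D matrix (nested lists)."""
    rows, cols = len(matrix), len(matrix[0])
    lo, hi = (k, k) if isinstance(k, int) else k
    # Length of the longest diagonal in the band [lo, hi].
    max_len = min(rows + min(hi, 0), cols + min(-lo, 0))
    super_side, sub_side = align.split("_")
    out = []
    # Diagonals are emitted from k[1] (highest offset) down to k[0].
    for d in range(hi, lo - 1, -1):
        diag = [matrix[i][i + d] for i in range(rows) if 0 <= i + d < cols]
        pad = [padding_value] * (max_len - len(diag))
        side = super_side if d >= 0 else sub_side
        # RIGHT alignment pads on the left; LEFT alignment pads on the right.
        out.append(pad + diag if side == "RIGHT" else diag + pad)
    # A single diagonal is returned without the extra num_diags level.
    return out[0] if lo == hi else out


x = [[1., 2., 3., 4.],
     [5., 6., 7., 8.],
     [9., 8., 7., 6.]]
print(matrix_diag_part_ref(x, (1, 3), 9.))
# [[9.0, 9.0, 4.0], [9.0, 3.0, 8.0], [2.0, 7.0, 6.0]]
```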