mindspore.nn.MatrixDiagPart
- class mindspore.nn.MatrixDiagPart
Returns the batched diagonal part of a batched tensor, that is, the main diagonal of each of its innermost matrices.
If x has \(k\) dimensions \([I, J, K, \ldots, M, N]\), the output is a tensor of rank \(k-1\) with dimensions \([I, J, K, \ldots, \min(M, N)]\), where \(output[i, j, k, \ldots, n] = x[i, j, k, \ldots, n, n]\).
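The element-wise formula above can be checked against a small reference sketch. It uses NumPy as a stand-in for MindSpore tensors, and `matrix_diag_part_reference` is a hypothetical helper, not part of the MindSpore API. The `ValueError` for inputs with fewer than two dimensions is this sketch's own assumption, not documented MindSpore behavior.

```python
import numpy as np


def matrix_diag_part_reference(x):
    """Hypothetical NumPy reference: output[..., n] = x[..., n, n]."""
    x = np.asarray(x)
    if x.ndim < 2:
        # Assumption of this sketch: a matrix needs at least two dimensions.
        raise ValueError("x must have at least 2 dimensions")
    m, n = x.shape[-2:]
    d = min(m, n)
    # Batch dimensions are kept; the last two collapse into one of size min(M, N).
    out = np.empty(x.shape[:-2] + (d,), dtype=x.dtype)
    for i in range(d):
        # Copy the i-th diagonal entry of every matrix in the batch at once.
        out[..., i] = x[..., i, i]
    return out


# Same input as in the Examples section, as a NumPy array.
x = np.array([[[-1, 0], [0, 1]]] * 3, dtype=np.float32)
print(matrix_diag_part_reference(x))

# Non-square matrices: shape (2, 3, 4) gives an output of shape (2, 3).
y = np.arange(24, dtype=np.int32).reshape(2, 3, 4)
print(matrix_diag_part_reference(y).shape)
```

For comparison, `np.diagonal(x, axis1=-2, axis2=-1)` computes the same values in NumPy without an explicit loop. The loop is kept here only because it mirrors the index formula directly.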
- Inputs:
x (Tensor) - The batched tensor. Its data type must be one of float32, float16, int32, int8, or uint8.
- Outputs:
Tensor, with the same data type as the input x. Its shape is x.shape[:-2] + (min(x.shape[-2:]),).
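As a quick illustration of the output-shape rule above, the following sketch computes the expected output shape from an input shape using plain Python tuples. The helper name `expected_output_shape` is hypothetical and is not part of MindSpore.

```python
def expected_output_shape(input_shape):
    """Hypothetical helper: keep batch dims, replace (M, N) with min(M, N)."""
    input_shape = tuple(input_shape)
    batch, (m, n) = input_shape[:-2], input_shape[-2:]
    return batch + (min(m, n),)


print(expected_output_shape((3, 2, 2)))  # (3, 2), the shape in the Examples section
print(expected_output_shape((2, 3, 4)))  # (2, 3)
print(expected_output_shape((4, 5, 1)))  # (4, 1)
```

The output always has one dimension fewer than the input, matching the rank \(k-1\) stated above.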
- Raises:
TypeError – If the dtype of x is not one of float32, float16, int32, int8, or uint8.
- Supported Platforms:
Ascend
Examples
>>> import mindspore
>>> from mindspore import Tensor, nn
>>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
>>> matrix_diag_part = nn.MatrixDiagPart()
>>> output = matrix_diag_part(x)
>>> print(output)
[[-1.  1.]
 [-1.  1.]
 [-1.  1.]]