mindspore.nn.MatrixDiag

class mindspore.nn.MatrixDiag

Returns a batched diagonal tensor built from the given batched diagonal values.

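Assume x has k dimensions [I, J, K, ..., N]. The output is then a tensor of rank k+1 with dimensions [I, J, K, ..., N, N], where:

$$\text{output}[i, j, k, \ldots, m, n] = \mathbb{1}\{m = n\} \cdot x[i, j, k, \ldots, n]$$

Here $\mathbb{1}\{m = n\}$ is 1 when $m = n$ and 0 otherwise.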
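To make the formula concrete, here is a minimal NumPy sketch of it. It is an illustration under the assumption that NumPy semantics match the formula, not MindSpore's implementation, and the name `matrix_diag_reference` is hypothetical. An identity matrix plays the role of the indicator $\mathbb{1}\{m = n\}$, and broadcasting multiplies it by x along the last axis.

```python
import numpy as np


def matrix_diag_reference(x: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy sketch of output[..., m, n] = 1{m = n} * x[..., n]."""
    x = np.asarray(x)
    n = x.shape[-1]
    # The identity matrix is the indicator 1{m = n}; keep the input dtype.
    eye = np.eye(n, dtype=x.dtype)
    # x[..., None, :] has shape (..., 1, N); broadcasting against (N, N)
    # yields (..., N, N) with entry [..., m, n] = eye[m, n] * x[..., n].
    return x[..., None, :] * eye


x = np.array([[1, -1, 1], [1, -1, 1]], dtype=np.float32)
out = matrix_diag_reference(x)
print(out.shape)  # (2, 3, 3)
```

Broadcasting avoids an explicit Python loop over the batch dimensions I, J, K, ...; unlike `nn.MatrixDiag`, this sketch runs anywhere NumPy is installed, which can help when checking expected outputs without Ascend hardware.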

Inputs:
  • x (Tensor) - The diagonal values. It can be one of the following data types: float32, float16, int32, int8, and uint8. The shape is (*, N), where * means any number of additional batch dimensions; the diagonal values lie along the last axis.

Outputs:

Tensor, has the same type as input x. Its shape is x.shape + (x.shape[-1],).
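Another way to read the output-shape rule: allocate zeros of shape x.shape + (x.shape[-1],) and write x onto the diagonal of the last two axes. The NumPy sketch below does exactly that (the name `matrix_diag_scatter` is hypothetical, not a MindSpore API) and checks that reading the diagonal back with `np.diagonal` recovers the input.

```python
import numpy as np


def matrix_diag_scatter(x: np.ndarray) -> np.ndarray:
    """Hypothetical sketch: scatter x onto the diagonal of a zero tensor."""
    x = np.asarray(x)
    n = x.shape[-1]
    # Output shape follows the documented rule: x.shape + (x.shape[-1],).
    out = np.zeros(x.shape + (n,), dtype=x.dtype)
    idx = np.arange(n)
    # The paired index arrays select positions [..., k, k] for k in 0..N-1.
    out[..., idx, idx] = x
    return out


x = np.array([[1, -1], [1, -1]], dtype=np.float32)
out = matrix_diag_scatter(x)
assert out.shape == x.shape + (x.shape[-1],)
assert out.dtype == x.dtype
# Extracting the diagonal of the last two axes recovers the input.
assert np.array_equal(np.diagonal(out, axis1=-2, axis2=-1), x)
```

Because the two advanced indices are adjacent, `out[..., idx, idx]` has shape x.shape, so the assignment needs no reshaping.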

Raises

TypeError – If the dtype of x is not one of float32, float16, int32, int8 or uint8.
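The TypeError condition is a whitelist check on the input dtype. The sketch below mirrors it, assuming NumPy dtypes as stand-ins for the MindSpore ones; `check_matrix_diag_dtype` and `SUPPORTED_DTYPES` are hypothetical names, not part of MindSpore.

```python
import numpy as np

# Assumption: NumPy dtypes stand in for mindspore.float32, float16, int32, int8, uint8.
SUPPORTED_DTYPES = {np.dtype(t) for t in (np.float32, np.float16, np.int32, np.int8, np.uint8)}


def check_matrix_diag_dtype(x: np.ndarray) -> None:
    """Hypothetical helper mirroring the documented TypeError condition."""
    dtype = np.asarray(x).dtype
    if dtype not in SUPPORTED_DTYPES:
        raise TypeError(
            "MatrixDiag input dtype must be one of float32, float16, int32, "
            f"int8 or uint8, got {dtype}"
        )


check_matrix_diag_dtype(np.array([1, -1], dtype=np.float32))  # accepted
try:
    check_matrix_diag_dtype(np.array([1, -1], dtype=np.float64))
except TypeError as err:
    print(err)
```

NumPy's default float64 (and, on most platforms, its default int64) falls outside this set; the Examples section builds each Tensor with an explicit mindspore.float32.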

Supported Platforms:

Ascend

Examples

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, nn
>>> x = Tensor(np.array([1, -1]), mindspore.float32)
>>> matrix_diag = nn.MatrixDiag()
>>> output = matrix_diag(x)
>>> print(x.shape)
(2,)
>>> print(output)
[[ 1.  0.]
 [ 0. -1.]]
>>> print(output.shape)
(2, 2)
>>> x = Tensor(np.array([[1, -1], [1, -1]]), mindspore.float32)
>>> matrix_diag = nn.MatrixDiag()
>>> output = matrix_diag(x)
>>> print(x.shape)
(2, 2)
>>> print(output)
[[[ 1.  0.]
  [ 0. -1.]]
 [[ 1.  0.]
  [ 0. -1.]]]
>>> print(output.shape)
(2, 2, 2)
>>> x = Tensor(np.array([[1, -1, 1], [1, -1, 1]]), mindspore.float32)
>>> matrix_diag = nn.MatrixDiag()
>>> output = matrix_diag(x)
>>> print(x.shape)
(2, 3)
>>> print(output)
[[[ 1.  0.  0.]
  [ 0. -1.  0.]
  [ 0.  0.  1.]]
 [[ 1.  0.  0.]
  [ 0. -1.  0.]
  [ 0.  0.  1.]]]
>>> print(output.shape)
(2, 3, 3)