mindspore.nn.MatrixSetDiag
- class mindspore.nn.MatrixSetDiag
Modifies the batched diagonal part of a batched tensor.
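Assume x has k+1 dimensions [I, J, K, ..., M, N] and diagonal has k dimensions [I, J, K, ..., min(M, N)]. Then the output is a tensor of rank k+1 with dimensions [I, J, K, ..., M, N] where:
output[i, j, k, ..., m, n] = diagonal[i, j, k, ..., n] for m == n
output[i, j, k, ..., m, n] = x[i, j, k, ..., m, n] for m != n
- Inputs: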
x (Tensor) - The batched tensor. Rank k+1, where k >= 1. It can be one of the following data types: float32, float16, int32, int8, and uint8.
diagonal (Tensor) - The diagonal values. Must have the same type as input x. Rank k, where k >= 1.
- Outputs:
Tensor, has the same type and shape as input x.
- Raises:
TypeError – If dtype of x or diagonal is not one of float32, float16, int32, int8 or uint8.
ValueError – If length of shape of x is less than 2.
ValueError – If x_shape[-2] < x_shape[-1] and x_shape[:-1] != diagonal_shape.
ValueError – If x_shape[-2] >= x_shape[-1] and x_shape[:-2] + x_shape[-1:] != diagonal_shape.
- Supported Platforms:
Ascend
Examples
>>> import mindspore
>>> from mindspore import Tensor, nn
>>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32)
>>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32)
>>> matrix_set_diag = nn.MatrixSetDiag()
>>> output = matrix_set_diag(x, diagonal)
>>> print(output)
[[[-1.  0.]
  [ 0.  2.]]
 [[-1.  0.]
  [ 0.  1.]]
 [[-1.  0.]
  [ 0.  1.]]]
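The behaviour can also be checked without an Ascend device. The following is a minimal pure-Python sketch of the semantics described in this entry: main-diagonal entries of each innermost matrix are taken from diagonal, and all other entries are copied from x. It is not the MindSpore implementation. The name matrix_set_diag_ref is hypothetical, the inputs are assumed to be nested lists rather than Tensors, and only the shape rules from the Raises section are mirrored (no dtype checks).

```python
def matrix_set_diag_ref(x, diagonal):
    """Reference sketch (hypothetical helper, not a MindSpore API).

    x is a nested list of rank k+1 (k >= 1); diagonal is a nested list of
    rank k whose last dimension is min(M, N) for innermost M x N matrices.
    Returns a new nested list; x is not modified.
    """
    # Mirror "ValueError - If length of shape of x is less than 2".
    if not isinstance(x, list) or not x or not isinstance(x[0], list):
        raise ValueError("x must have rank of at least 2")

    # Base case: x is a single M x N matrix (rows of scalars).
    if not isinstance(x[0][0], list):
        rows, cols = len(x), len(x[0])
        if len(diagonal) != min(rows, cols):
            raise ValueError("diagonal must have length min(M, N)")
        # Replace entries where m == n, keep everything else.
        return [
            [diagonal[m] if m == n else x[m][n] for n in range(cols)]
            for m in range(rows)
        ]

    # Batched case: leading dimensions of x and diagonal must agree.
    if len(x) != len(diagonal):
        raise ValueError("batch dimensions of x and diagonal differ")
    return [matrix_set_diag_ref(xi, di) for xi, di in zip(x, diagonal)]


# Same data as the example above, written as nested lists.
x = [[[-1.0, 0.0], [0.0, 1.0]] for _ in range(3)]
diagonal = [[-1.0, 2.0], [-1.0, 1.0], [-1.0, 1.0]]
print(matrix_set_diag_ref(x, diagonal))
```

For non-square matrices the diagonal length follows the shorter side, so a 2 x 3 matrix and a 3 x 2 matrix both take a diagonal of length 2.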