mindspore.nn.MatrixSetDiag
- class mindspore.nn.MatrixSetDiag[源代码]
将输入的对角矩阵的对角线值置换为输入的对角线值。
假设 x 有
个维度 , diagonal 有 个维度 ,则输出秩为 ,维度为 的Tensor,其中:输入:
x (Tensor) - 输入的对角矩阵。秩为k+1,k大于等于1。支持如下数据类型:float32、float16、int32、int8和uint8。
diagonal (Tensor) - 输入的对角线值。必须与输入 x 的shape相同。秩为k,k大于等于1。
输出:
Tensor,shape和数据类型与输入 x 相同。
异常:
TypeError - x 或 diagonal 的数据类型不是float32、float16、int32、int8或uint8。
ValueError - x 的shape长度小于2。
ValueError -
且 。ValueError -
且 。
- 支持平台:
Ascend
样例:
>>> x = Tensor([[[-1, 0], [0, 1]], [[-1, 0], [0, 1]], [[-1, 0], [0, 1]]], mindspore.float32) >>> diagonal = Tensor([[-1., 2.], [-1., 1.], [-1., 1.]], mindspore.float32) >>> matrix_set_diag = nn.MatrixSetDiag() >>> output = matrix_set_diag(x, diagonal) >>> print(output) [[[-1. 0.] [ 0. 2.]] [[-1. 0.] [ 0. 1.]] [[-1. 0.] [ 0. 1.]]]