mindspore.ops.LuUnpack
- class mindspore.ops.LuUnpack(unpack_data=True, unpack_pivots=True)
Unpacks LU_data and LU_pivots into the matrices P, L and U, where P is a permutation matrix, L is a unit lower triangular matrix and U is an upper triangular matrix. LU_data and LU_pivots are typically produced by the LU factorization of a matrix.
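As a rough illustration of the data half of this unpacking, the NumPy sketch below splits packed LU data into L and U. It is not MindSpore's implementation: `unpack_lu_data` is a hypothetical helper, and it assumes the common packed layout (the multipliers of L strictly below the diagonal with an implicit unit diagonal, U on and above the diagonal), which matches the L and U printed in the Examples section.

```python
import numpy as np


def unpack_lu_data(lu_data):
    """Hypothetical NumPy sketch: split packed LU data of shape (..., M, N).

    Returns L with shape (..., M, K) and U with shape (..., K, N), K = min(M, N).
    np.tril / np.triu operate on the last two axes, so batch dims are kept.
    """
    m, n = lu_data.shape[-2:]
    k = min(m, n)
    # Strictly lower entries of the first K columns, plus the implicit unit diagonal.
    lower = np.tril(lu_data[..., :, :k], k=-1) + np.eye(m, k, dtype=lu_data.dtype)
    # Upper entries (diagonal included) of the first K rows.
    upper = np.triu(lu_data[..., :k, :])
    return lower, upper


# Usage on the first batch element of the example data.
lu = np.array([[[-0.3806, -0.4872, 0.5536],
                [-0.1287, 0.6508, -0.2396],
                [0.2583, 0.5239, 0.6902]]], dtype=np.float32)
L, U = unpack_lu_data(lu)
print(L[0])
print(U[0])
```

Slicing to K = min(M, N) columns and rows keeps the sketch valid for rectangular inputs as well as square ones.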
Warning
This is an experimental API that is subject to change or deletion.
Refer to mindspore.ops.lu_unpack() for more details.
- Parameters:
unpack_data (bool, optional) – Whether to unpack LU_data into L and U. If False, the returned L and U are None. Default: True.
unpack_pivots (bool, optional) – Whether to unpack LU_pivots into the permutation matrix P. If False, the returned P is None. Default: True.
- Inputs:
LU_data (Tensor) - The packed LU factorization data, with shape \((*, M, N)\), where \(*\) denotes any number of batch dimensions. Supported data types: int8, uint8, int16, int32, int64, float16, float32, float64. The rank of LU_data must be at least 2.
LU_pivots (Tensor) - The packed LU factorization pivots, with shape \((*, \min(M, N))\), where \(*\) denotes the batch dimensions. Supported data types: int8, uint8, int16, int32, int64. The pivot values are 1-based row indices.
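The pivots describe a sequence of row interchanges rather than a permutation directly. The NumPy sketch below assumes LAPACK-style semantics (at step i, row i was swapped with row LU_pivots[i] - 1); `pivots_to_permutation` is a hypothetical helper, not part of MindSpore. Under that assumption it reproduces the permutation matrices printed in the Examples section.

```python
import numpy as np


def pivots_to_permutation(lu_pivots, m):
    """Hypothetical sketch: 1-based pivots (..., K) -> permutation matrices (..., M, M).

    The returned P satisfies A = P @ L @ U (float64 here for simplicity).
    """
    batch_shape = lu_pivots.shape[:-1]
    flat_pivots = lu_pivots.reshape(-1, lu_pivots.shape[-1])
    perms = np.empty((flat_pivots.shape[0], m, m))
    for b, piv in enumerate(flat_pivots):
        order = np.arange(m)
        # Replay the interchanges: step i swapped row i with row piv[i] - 1.
        for i, p in enumerate(piv):
            j = int(p) - 1
            order[i], order[j] = order[j], order[i]
        # Row i of (P^T @ A) is row order[i] of A, so P is the transpose.
        perms[b] = np.eye(m)[order].T
    return perms.reshape(*batch_shape, m, m)


# Usage with the pivots from the example.
P = pivots_to_permutation(np.array([[1, 3, 3], [2, 3, 3]]), 3)
print(P)
```

Replaying the swaps on an index array, instead of multiplying M x M swap matrices, keeps each batch element at O(K) work before the final one-hot lookup.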
- Outputs:
pivots (Tensor) - The permutation matrix P of the LU factorization, with shape \((*, M, M)\) and the same dtype as LU_data.
L (Tensor) - The unit lower triangular matrix L of the LU factorization, with the same dtype as LU_data.
U (Tensor) - The upper triangular matrix U of the LU factorization, with the same dtype as LU_data.
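Under the convention that the permutation matrices in the Examples section are consistent with, the three outputs reconstruct the factored matrix as A = P @ L @ U. The self-contained NumPy sketch below checks this for a square matrix. `lu_factor_packed` and `unpack_square` are hypothetical helpers: the first is a textbook partial-pivoting LU written only to produce packed data and 1-based pivots in the same layout as LU_data and LU_pivots, and neither is ops.lu or any other MindSpore routine.

```python
import numpy as np


def lu_factor_packed(a):
    """Hypothetical partial-pivoting LU of a square matrix.

    Returns packed data (unit-diagonal L below the diagonal, U on and above it)
    and 1-based pivots, i.e. the packed format that LuUnpack consumes.
    """
    lu = np.array(a, dtype=np.float64)
    n = lu.shape[0]
    pivots = np.zeros(n, dtype=np.int32)
    for k in range(n):
        p = k + int(np.argmax(np.abs(lu[k:, k])))  # row with the largest pivot
        pivots[k] = p + 1                            # store as 1-based
        lu[[k, p]] = lu[[p, k]]                      # swap whole rows, multipliers included
        if lu[k, k] != 0:
            lu[k + 1:, k] /= lu[k, k]                # column k of L
            lu[k + 1:, k + 1:] -= np.outer(lu[k + 1:, k], lu[k, k + 1:])
    return lu, pivots


def unpack_square(lu, pivots):
    """Hypothetical unpacking for the square case: returns P, L, U."""
    n = lu.shape[0]
    lower = np.tril(lu, k=-1) + np.eye(n)
    upper = np.triu(lu)
    order = np.arange(n)
    for i, p in enumerate(pivots):
        j = int(p) - 1
        order[i], order[j] = order[j], order[i]
    return np.eye(n)[order].T, lower, upper


rng = np.random.default_rng(0)
a = rng.standard_normal((4, 4))
lu, pivots = lu_factor_packed(a)
P, L, U = unpack_square(lu, pivots)
print(np.allclose(P @ L @ U, a))  # True: the unpacked factors reconstruct A
```

Swapping whole rows (including multipliers already stored in earlier columns) is what makes the packed data and the recorded pivots agree with each other, which is why the reconstruction holds.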
- Supported Platforms:
GPU
CPU
Examples
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> from mindspore import dtype as mstype
>>> LU_data = Tensor(np.array([[[-0.3806, -0.4872, 0.5536],
...                             [-0.1287, 0.6508, -0.2396],
...                             [ 0.2583, 0.5239, 0.6902]],
...                            [[ 0.6706, -1.1782, 0.4574],
...                             [-0.6401, -0.4779, 0.6701],
...                             [ 0.1015, -0.5363, 0.6165]]]), mstype.float32)
>>> LU_pivots = Tensor(np.array([[1, 3, 3],
...                              [2, 3, 3]]), mstype.int32)
>>> lu_unpack = ops.LuUnpack()
>>> pivots, L, U = lu_unpack(LU_data, LU_pivots)
>>> print(pivots)
[[[1. 0. 0.]
  [0. 0. 1.]
  [0. 1. 0.]]
 [[0. 0. 1.]
  [1. 0. 0.]
  [0. 1. 0.]]]
>>> print(L)
[[[ 1.      0.      0.    ]
  [-0.1287  1.      0.    ]
  [ 0.2583  0.5239  1.    ]]
 [[ 1.      0.      0.    ]
  [-0.6401  1.      0.    ]
  [ 0.1015 -0.5363  1.    ]]]
>>> print(U)
[[[-0.3806 -0.4872  0.5536]
  [ 0.      0.6508 -0.2396]
  [ 0.      0.      0.6902]]
 [[ 0.6706 -1.1782  0.4574]
  [ 0.     -0.4779  0.6701]
  [ 0.      0.      0.6165]]]