mindspore.scipy.linalg.cho_factor

mindspore.scipy.linalg.cho_factor(a, lower=False, overwrite_a=False, check_finite=True)

Compute the Cholesky decomposition of a matrix, for use in cho_solve.

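Returns a matrix containing the Cholesky decomposition, \(a = L L^*\) or \(a = U^* U\), of a Hermitian positive-definite matrix a, where \(L\) is lower triangular, \(U\) is upper triangular, and \(^*\) denotes the conjugate transpose. The return value can be passed directly as the first parameter of cho_solve.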
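For a real symmetric matrix the conjugate transpose is the ordinary transpose, so the upper factor satisfies \(a = U^T U\) and the lower factor is \(L = U^T\). The following is a minimal plain-Python sketch of that factorization using the standard Cholesky recurrence. It is not MindSpore's implementation: `cho_factor_sketch` is a hypothetical name, and the sketch assumes a real symmetric positive-definite input given as nested lists. To mirror the packed layout of the return value, it overwrites only the requested triangle of a copy of a and leaves the other triangle holding the input entries.

```python
import math


def cho_factor_sketch(a, lower=False):
    """Hypothetical illustration of a packed Cholesky factorization.

    Assumes `a` is a real symmetric positive-definite matrix (nested lists).
    Returns (c, lower): c is a copy of `a` whose upper triangle holds U
    (a = U^T U) when lower is False, or whose lower triangle holds
    L = U^T (a = L L^T) when lower is True. The other triangle keeps the
    original entries of `a`.
    """
    n = len(a)
    u = [[0.0] * n for _ in range(n)]
    for j in range(n):
        # Diagonal: remove the contribution of the rows of U computed so far.
        d = a[j][j] - sum(u[k][j] ** 2 for k in range(j))
        if d <= 0:
            raise ValueError("matrix is not positive definite")
        u[j][j] = math.sqrt(d)
        # Remaining entries of row j of U.
        for i in range(j + 1, n):
            s = a[j][i] - sum(u[k][j] * u[k][i] for k in range(j))
            u[j][i] = s / u[j][j]

    c = [[float(v) for v in row] for row in a]  # copy of the input
    for i in range(n):
        for j in range(i, n):
            if lower:
                c[j][i] = u[i][j]  # L[j][i] = U[i][j] for a real matrix
            else:
                c[i][j] = u[i][j]
    return c, lower


a = [[9, 3, 1, 5], [3, 7, 5, 1], [1, 5, 9, 2], [5, 1, 2, 6]]
c, low = cho_factor_sketch(a)
for row in c:
    print([round(v, 7) for v in row])
```

Run on the matrix from the Examples section, the upper triangle of c should agree with the printed MindSpore output up to float32 rounding.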

Note

  • cho_factor is not yet supported on the Windows platform.

  • Only Tensors of dtype float32, float64, int32, or int64 are supported. If a Tensor of dtype int32 or int64 is passed, it is cast to mstype.float64.

Warning

The returned matrix also contains random data in the entries not used by the Cholesky decomposition. If you need these entries to be zero, use the function cholesky instead.
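When the packed result is already in hand, the unused triangle can also be cleared manually. The sketch below is a plain-Python illustration; `zero_unused` is a hypothetical helper, not part of MindSpore, and for NumPy arrays the same effect is what numpy.triu (upper factor) or numpy.tril (lower factor) provide. The 3×3 values are arbitrary placeholders chosen only to show which entries survive, not a real Cholesky factor.

```python
def zero_unused(c, lower):
    """Hypothetical helper: return a copy of the packed factor `c` with the
    triangle that does not hold the Cholesky factor set to zero.

    lower=False keeps the upper triangle (j >= i); lower=True keeps the
    lower triangle (j <= i). The diagonal is kept in both cases.
    """
    n = len(c)
    return [
        [c[i][j] if (j <= i if lower else j >= i) else 0.0 for j in range(n)]
        for i in range(n)
    ]


# Placeholder packed matrix (not a real factor).
c = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0],
     [7.0, 8.0, 9.0]]
print(zero_unused(c, lower=False))  # [[1.0, 2.0, 3.0], [0.0, 5.0, 6.0], [0.0, 0.0, 9.0]]
print(zero_unused(c, lower=True))   # [[1.0, 0.0, 0.0], [4.0, 5.0, 0.0], [7.0, 8.0, 9.0]]
```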

Parameters
  • a (Tensor) – Square matrix of shape (M, M) to be decomposed.

  • lower (bool, optional) – Whether to compute the upper (False) or lower (True) triangular Cholesky factorization. Default: False.

  • overwrite_a (bool, optional) – Whether to overwrite data in a (may improve performance). Default: False. This argument currently has no effect in MindSpore.

  • check_finite (bool, optional) – Whether to check that the input matrix contains only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs. Default: True. This argument currently has no effect in MindSpore.

Returns

  • Tensor, matrix whose upper or lower triangle contains the Cholesky factor of a. The remaining entries contain random data.

  • bool, flag indicating whether the factor is in the lower (True) or upper (False) triangle.
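The returned pair is what cho_solve takes as its first argument. To make that contract concrete, the following plain-Python sketch shows how a solver can use the flag to decide which triangle of the packed matrix to read, so the unused entries never matter. It is an illustration under stated assumptions, not MindSpore's cho_solve: `cho_solve_sketch` is a hypothetical name, and it handles only real symmetric matrices and a single right-hand side given as a list.

```python
import math


def cho_solve_sketch(c_and_lower, b):
    """Hypothetical sketch: solve a x = b given a packed factor (c, lower).

    Assumes a real symmetric positive-definite `a`, so a = U^T U with
    U upper triangular (and L = U^T). Only the triangle selected by
    `lower` is read; the other triangle of c is ignored.
    """
    c, lower = c_and_lower
    n = len(c)

    def u(i, j):
        # U[i][j] (i <= j), read from whichever triangle holds the factor.
        return c[j][i] if lower else c[i][j]

    # Forward substitution: U^T y = b.
    y = [0.0] * n
    for i in range(n):
        y[i] = (b[i] - sum(u(k, i) * y[k] for k in range(i))) / u(i, i)

    # Back substitution: U x = y.
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (y[i] - sum(u(i, k) * x[k] for k in range(i + 1, n))) / u(i, i)
    return x


# a = [[4, 2], [2, 3]] has upper factor U = [[2, 1], [0, sqrt(2)]].
# Upper layout: U in the upper triangle, a[1][0] left in the lower one.
c_up = [[2.0, 1.0], [2.0, math.sqrt(2.0)]]
# Lower layout: L = U^T in the lower triangle, a[0][1] left in the upper one.
c_low = [[2.0, 2.0], [1.0, math.sqrt(2.0)]]

b = [6.0, 5.0]  # chosen so that the exact solution is x = [1, 1]
x_up = cho_solve_sketch((c_up, False), b)
x_low = cho_solve_sketch((c_low, True), b)
print(x_up, x_low)  # both close to [1.0, 1.0]
```

Two triangular solves replace one general solve, which is why keeping the factor around pays off when the same matrix is solved against many right-hand sides.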

Raises

ValueError – If the input tensor a is not 2-D or is not a square matrix.

Supported Platforms:

CPU GPU

Examples

>>> import numpy as onp
>>> from mindspore.common import Tensor
>>> from mindspore.scipy.linalg import cho_factor
>>> a = Tensor(onp.array([[9, 3, 1, 5], [3, 7, 5, 1], [1, 5, 9, 2], [5, 1, 2, 6]]).astype(onp.float32))
>>> c, low = cho_factor(a)
>>> print(c)
[[ 3.          1.          0.33333334  1.6666666 ]
 [ 3.          2.4494898   1.9051585  -0.2721655 ]
 [ 1.          5.          2.2933078   0.8559526 ]
 [ 5.          1.          2.          1.5541857 ]]
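As a sanity check on the printed output, one can keep only the upper triangle of c (low is False here, so that triangle holds the factor \(U\)) and confirm that \(U^T U\) reproduces a up to float32 rounding. The plain-Python sketch below does this with the values copied from the output above; the tolerance of 1e-5 is an assumption chosen to absorb the rounding of the printed digits.

```python
# Input matrix and packed factor copied from the example above.
a = [[9, 3, 1, 5], [3, 7, 5, 1], [1, 5, 9, 2], [5, 1, 2, 6]]
c = [[3.0, 1.0, 0.33333334, 1.6666666],
     [3.0, 2.4494898, 1.9051585, -0.2721655],
     [1.0, 5.0, 2.2933078, 0.8559526],
     [5.0, 1.0, 2.0, 1.5541857]]

n = len(a)
# (U^T U)[i][j] = sum_k U[k][i] * U[k][j], reading only entries with k <= i, j,
# i.e. only the upper triangle of c; the lower triangle is never touched.
recon = [[sum(c[k][i] * c[k][j] for k in range(min(i, j) + 1)) for j in range(n)]
         for i in range(n)]
max_err = max(abs(recon[i][j] - a[i][j]) for i in range(n) for j in range(n))
print(max_err < 1e-5)  # True
```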