mindspore.ops.cholesky_solve

mindspore.ops.cholesky_solve(input, input2, upper=False)

Computes the solution of a system of linear equations whose coefficient matrix is symmetric positive-definite, given the Cholesky decomposition factor input2 of that matrix.

If upper is True, input2 is treated as an upper triangular Cholesky factor and the output is:

\[output = (input2^{T} * input2)^{-1} * input\]

If upper is False, input2 is treated as a lower triangular Cholesky factor and the output is:

\[output = (input2 * input2^{T})^{-1} * input\]
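Both formulas reduce to two triangular solves: with a lower factor \(L\), first solve \(L y = input\) by forward substitution, then \(L^{T} x = y\) by back substitution; an upper factor \(U\) is handled by taking \(L = U^{T}\). The pure-Python sketch below illustrates that arithmetic for a single, unbatched matrix. cholesky_solve_ref and its helpers are hypothetical names, not part of MindSpore, and the sketch says nothing about how the MindSpore kernels are actually implemented.

```python
from typing import List

Matrix = List[List[float]]  # row-major: a list of rows


def transpose(m: Matrix) -> Matrix:
    return [list(row) for row in zip(*m)]


def forward_sub(lower: Matrix, b: Matrix) -> Matrix:
    """Solve lower @ y = b for y, one right-hand-side column at a time."""
    n, cols = len(lower), len(b[0])
    y = [[0.0] * cols for _ in range(n)]
    for c in range(cols):
        for i in range(n):
            acc = sum(lower[i][k] * y[k][c] for k in range(i))
            y[i][c] = (b[i][c] - acc) / lower[i][i]
    return y


def back_sub(upper: Matrix, y: Matrix) -> Matrix:
    """Solve upper @ x = y for x, starting from the last row."""
    n, cols = len(upper), len(y[0])
    x = [[0.0] * cols for _ in range(n)]
    for c in range(cols):
        for i in reversed(range(n)):
            acc = sum(upper[i][k] * x[k][c] for k in range(i + 1, n))
            x[i][c] = (y[i][c] - acc) / upper[i][i]
    return x


def cholesky_solve_ref(b: Matrix, factor: Matrix, upper: bool = False) -> Matrix:
    """Hypothetical reference: x = (L L^T)^{-1} b, with L = factor (or factor^T if upper)."""
    lower = transpose(factor) if upper else factor
    y = forward_sub(lower, b)             # L y = b
    return back_sub(transpose(lower), y)  # L^T x = y


# Same data as the example further down this page.
identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
L = [[2.0, 0.0, 0.0], [4.0, 1.0, 0.0], [-1.0, 1.0, 2.0]]
# Both calls print [[5.8125, -2.625, 0.625], [-2.625, 1.25, -0.25], [0.625, -0.25, 0.25]]
print(cholesky_solve_ref(identity, L))
print(cholesky_solve_ref(identity, transpose(L), upper=True))
```

Solving column by column keeps the sketch short; a real implementation would vectorize over columns and batch entries.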

Warning

This is an experimental API that is subject to change or deletion.

Parameters
  • input (Tensor) – Right-hand side of the system, a tensor of shape \((*, N, M)\) where \(*\) is zero or one batch dimension (that is, a 2D matrix or a 3D batch of matrices), with float32 or float64 data type.

  • input2 (Tensor) – Cholesky factor of the coefficient matrix, a tensor of shape \((*, N, N)\) (a 2D square matrix or a 3D batch of square matrices) that is upper or lower triangular as selected by upper, with float32 or float64 data type. input and input2 must have the same data type.

  • upper (bool, optional) – Whether to treat the Cholesky factor input2 as an upper triangular matrix (True) or a lower triangular matrix (False). Default: False.
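Substituting \(U = L^{T}\) shows that an upper factor and the transpose of a lower factor describe the same coefficient matrix, so upper only tells the operator which triangular convention input2 follows. Under that substitution, passing \(U\) with upper=True and \(L\) with upper=False should yield the same output:

\[U^{T} * U = (L^{T})^{T} * L^{T} = L * L^{T}\]

\[(U^{T} * U)^{-1} * input = (L * L^{T})^{-1} * input\]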

Returns

Tensor, has the same shape and data type as input.

Raises
  • TypeError – If upper is not a bool.

  • TypeError – If the dtype of input or input2 is neither float32 nor float64.

  • TypeError – If input is not a Tensor.

  • TypeError – If input2 is not a Tensor.

  • ValueError – If input and input2 have different batch sizes.

  • ValueError – If input and input2 have different numbers of rows.

  • ValueError – If input is not a 2D or 3D tensor.

  • ValueError – If input2 is not a 2D or 3D tensor of square matrices.
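The shape conditions in the ValueError entries above can be summarized as a small check on shape tuples. check_cholesky_solve_shapes is a hypothetical helper for illustration only: the order of the checks, the messages, and the assumption that a 2D input must be paired with a 2D input2 (treated here as a batch-size mismatch) are assumptions, not MindSpore behavior confirmed by this page.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def check_cholesky_solve_shapes(input_shape: Shape, input2_shape: Shape) -> None:
    """Hypothetical helper: raise ValueError for the shape problems listed under Raises."""
    if len(input_shape) not in (2, 3):
        raise ValueError("input must be a 2D or 3D tensor")
    if len(input2_shape) not in (2, 3) or input2_shape[-1] != input2_shape[-2]:
        raise ValueError("input2 must be 2D or 3D with square trailing matrices")
    if input_shape[:-2] != input2_shape[:-2]:
        # Assumption: a rank mismatch (2D vs 3D) is rejected here as well.
        raise ValueError("input and input2 have different batch sizes")
    if input_shape[-2] != input2_shape[-2]:
        raise ValueError("input and input2 have different numbers of rows")


check_cholesky_solve_shapes((3, 2), (3, 3))        # (N, M) with (N, N): accepted
check_cholesky_solve_shapes((4, 3, 2), (4, 3, 3))  # batch of 4: accepted
try:
    check_cholesky_solve_shapes((4, 3, 2), (5, 3, 3))
except ValueError as err:
    print(err)  # input and input2 have different batch sizes
```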

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> input1 = Tensor(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), mindspore.float32)
>>> input2 = Tensor(np.array([[2, 0, 0], [4, 1, 0], [-1, 1, 2]]), mindspore.float32)
>>> out = ops.cholesky_solve(input1, input2, upper=False)
>>> print(out)
[[ 5.8125 -2.625   0.625 ]
 [-2.625   1.25   -0.25  ]
 [ 0.625  -0.25    0.25  ]]
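Because input1 in the example is the identity matrix, the printed result should be \((input2 * input2^{T})^{-1}\). The plain-Python check below (no MindSpore required) rebuilds that coefficient matrix from the factor and multiplies it by the printed values; getting the identity back confirms the example output.

```python
L = [[2, 0, 0], [4, 1, 0], [-1, 1, 2]]  # input2 from the example (lower triangular)
printed = [[5.8125, -2.625, 0.625], [-2.625, 1.25, -0.25], [0.625, -0.25, 0.25]]
n = len(L)

# Coefficient matrix A = L * L^T
A = [[sum(L[i][k] * L[j][k] for k in range(n)) for j in range(n)] for i in range(n)]
# A times the printed solution should be the identity (input1)
product = [[sum(A[i][k] * printed[k][j] for k in range(n)) for j in range(n)] for i in range(n)]

print(A)        # [[4, 8, -2], [8, 17, -3], [-2, -3, 6]]
print(product)  # [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
```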