mindspore.ops.MatrixSolve

查看源文件
class mindspore.ops.MatrixSolve(adjoint=False)[源代码]

求解线性方程组。

参数:
  • adjoint (bool,可选) - 指明是否使用矩阵的伴随矩阵进行求解。默认值: False ,使用转置矩阵进行求解。

输入:
  • matrix (Tensor) - Tensor,线性方程组系数组成的矩阵,其shape为 \((..., M, M)\)

  • rhs (Tensor) - Tensor,线性方程组结果值组成的矩阵,其shape为 \((..., M, K)\)rhsmatrix 的类型必须一致。

输出:

Tensor,线性方程组解组成的矩阵,与 rhs 的shape及类型均相同。

异常:
  • TypeError - 如果 adjoint 不是bool型。

  • TypeError - 如果 matrix 的类型不是以下之一: mstype.float16、mstype.float32、mstype.float64、mstype.complex64、mstype.complex128。

  • TypeError - 如果 rhsmatrix 的类型不一致。

  • ValueError - 如果 matrix 的秩小于2。

  • ValueError - 如果 matrixrhs 的维度不同。

  • ValueError - 如果 matrix 的最内两维不同。

  • ValueError - 如果 rhs 的最内两维与 matrix 不匹配。

支持平台:

Ascend CPU

样例:

>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> matrix = Tensor(np.array([[1.0  , 4.0],
...                       [2.0 , 7.0]]), mindspore.float32)
>>> rhs = Tensor(np.array([[1.0]  , [3.0]]), mindspore.float32)
>>> matrix_solve = ops.MatrixSolve(adjoint = False)
>>> output = matrix_solve(matrix, rhs)
>>> print(output)
[[5.0]
 [-1.0]]