mindspore.scipy.linalg.solve_triangular
- mindspore.scipy.linalg.solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False, overwrite_b=False, debug=None, check_finite=True)
Solve the linear system \(a x = b\) for \(x\), assuming \(a\) is a triangular matrix.
Note
solve_triangular is currently only used in mindscience scientific computing scenarios and does not support other usage scenarios.
solve_triangular is not yet supported on the Windows platform.
- Parameters
a (Tensor) – A triangular matrix of shape \((*, M, M)\) where \(*\) is zero or more batch dimensions.
b (Tensor) – A Tensor of shape \((*, M)\) or \((*, M, N)\). Right-hand side matrix in \(a x = b\).
trans (Union[int, str], optional) – Type of system to solve. Default: 0.
trans       system
0 or 'N'    \(a x = b\)
1 or 'T'    \(a^T x = b\)
2 or 'C'    \(a^H x = b\)
lower (bool, optional) – Use only data contained in the lower triangle of a. Default: False.
unit_diagonal (bool, optional) – If True, diagonal elements of \(a\) are assumed to be 1 and will not be referenced. Default: False.
overwrite_b (bool, optional) – Not implemented now. Default: False.
debug (Any, optional) – Not implemented now. Default: None.
check_finite (bool, optional) – Not implemented now. Default: True.
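The effect of lower and unit_diagonal can be seen in a plain-Python substitution loop. The sketch below is only an illustration of the arithmetic for a single, non-batched system with a 1-D right-hand side; the helper name solve_triangular_ref is hypothetical and this is not how MindSpore implements the operator.

```python
def solve_triangular_ref(a, b, lower=False, unit_diagonal=False):
    """Plain-Python reference for a x = b with triangular a (nested lists).

    Hypothetical helper for illustration; not MindSpore's implementation.
    """
    m = len(a)
    x = [0.0] * m
    # Forward substitution walks rows top-down, back substitution bottom-up.
    rows = range(m) if lower else range(m - 1, -1, -1)
    for i in rows:
        # Only the selected triangle is read; the other triangle is ignored.
        cols = range(i) if lower else range(i + 1, m)
        acc = b[i] - sum(a[i][j] * x[j] for j in cols)
        # With unit_diagonal=True the diagonal is taken to be 1 and never read.
        x[i] = acc if unit_diagonal else acc / a[i][i]
    return x


a = [[3.0, 0.0, 0.0, 0.0],
     [2.0, 1.0, 0.0, 0.0],
     [1.0, 0.0, 1.0, 0.0],
     [1.0, 1.0, 1.0, 1.0]]
b = [3.0, 1.0, 3.0, 4.0]

x = solve_triangular_ref(a, b, lower=True)
print(x)  # [1.0, -1.0, 2.0, 2.0]
```

Passing unit_diagonal=True on the same inputs skips the division by a[0][0] = 3, so the first component becomes 3.0 instead of 1.0 and the remaining components change accordingly.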
- Returns
Tensor of shape \((*, M)\) or \((*, M, N)\), which is the solution to the system \(a x = b\). Shape of \(x\) matches \(b\).
- Raises
ValueError – If a has fewer than 2 dimensions.
ValueError – If a is not a square matrix.
TypeError – If the dtypes of a and b are not the same.
ValueError – If the shapes of a and b do not match.
ValueError – If trans is not in set {0, 1, 2, 'N', 'T', 'C'}.
- Supported Platforms:
Ascend
CPU
Examples
>>> import numpy as onp
>>> import mindspore
>>> from mindspore import Tensor
>>> from mindspore.scipy.linalg import solve_triangular
>>> a = Tensor(onp.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]], onp.float32))
>>> b = Tensor(onp.array([3, 1, 3, 4], onp.float32))
>>> x = solve_triangular(a, b, lower=True, unit_diagonal=False, trans='N')
>>> print(x)
[ 1. -1.  2.  2.]
>>> print(a @ x)  # Check the result
[3. 1. 3. 4.]
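To make the meaning of trans=1 (or 'T') concrete, the following plain-Python sketch solves \(a^T x = b\) for the same lower-triangular a by transposing it explicitly and applying back substitution. The helpers transpose and back_substitute are hypothetical names used only for this cross-check; the sketch does not call MindSpore.

```python
def transpose(a):
    """Return the transpose of a square matrix given as nested lists."""
    return [list(col) for col in zip(*a)]


def back_substitute(u, b):
    """Solve u x = b for upper-triangular u by back substitution."""
    m = len(u)
    x = [0.0] * m
    for i in range(m - 1, -1, -1):
        acc = b[i] - sum(u[i][j] * x[j] for j in range(i + 1, m))
        x[i] = acc / u[i][i]
    return x


a = [[3.0, 0.0, 0.0, 0.0],
     [2.0, 1.0, 0.0, 0.0],
     [1.0, 0.0, 1.0, 0.0],
     [1.0, 1.0, 1.0, 1.0]]
b = [3.0, 1.0, 3.0, 4.0]

# The transpose of a lower-triangular matrix is upper-triangular,
# so a^T x = b is solved from the last row upwards.
a_t = transpose(a)
x = back_substitute(a_t, b)
print(x)  # [2.0, -3.0, -1.0, 4.0]

# Check the result: a^T x should reproduce b.
check = [sum(a_t[i][j] * x[j] for j in range(len(x))) for i in range(len(x))]
print(check)  # [3.0, 1.0, 3.0, 4.0]
```

For real-valued input, 'T' and 'C' describe the same system, since the conjugate transpose of a real matrix equals its transpose.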