mindspore.ops.Ormqr
- class mindspore.ops.Ormqr(left=True, transpose=False)
Computes the matrix-matrix multiplication of a product of Householder matrices with a general matrix. Multiplies an (m, n) matrix C (given by other) with a matrix Q, where Q is represented by the Householder reflectors (x, tau) returned by mindspore.ops.geqrf(). Refer to mindspore.ops.ormqr() for more details.
Warning
This is an experimental API that is subject to change or deletion.
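As a rough illustration of what this operator computes, the sketch below applies the reflectors one at a time in NumPy. It assumes the LAPACK ormqr convention: Q = H_1 H_2 ... H_k with H_i = I - tau_i v_i v_i^H, where v_i has a 1 at position i, zeros above it, and the entries x[i+1:, i] below it; left selects op(Q) @ C versus C @ op(Q), and transpose selects op(Q) = Q^H. The function name ormqr_reference is hypothetical, the sketch handles a single unbatched matrix only, and it is not MindSpore's implementation.

```python
import numpy as np


def ormqr_reference(x, tau, other, left=True, transpose=False):
    """Hypothetical NumPy sketch: op(Q) @ C if left, else C @ op(Q).

    Assumes Q = H_1 H_2 ... H_k, H_i = I - tau[i] * v_i v_i^H (geqrf layout):
    v_i[i] = 1, v_i[:i] = 0, v_i[i+1:] = x[i+1:, i]. op(Q) = Q^H if transpose.
    Handles one unbatched matrix only.
    """
    x = np.asarray(x)
    tau = np.asarray(tau)
    other = np.asarray(other)
    c = other.astype(np.result_type(x, tau, other))  # working copy of C
    mn = x.shape[0]  # m when left=True, n when left=False
    k = tau.shape[0]

    # Q @ C = H_1 (H_2 (... H_k C)) and C @ Q^H = C H_k^H ... H_1^H:
    #   the last reflector is applied first.
    # Q^H @ C and C @ Q: the first reflector is applied first.
    order = reversed(range(k)) if left != transpose else range(k)

    for i in order:
        v = np.zeros(mn, dtype=c.dtype)
        v[i] = 1
        v[i + 1:] = x[i + 1:, i]
        t = np.conj(tau[i]) if transpose else tau[i]  # H_i^H uses conj(tau_i)
        if left:
            c -= t * np.outer(v, v.conj() @ c)  # H c = c - t v (v^H c)
        else:
            c -= t * np.outer(c @ v, v.conj())  # c H = c - t (c v) v^H
    return c


# Same data as the Examples section, in float64.
x = np.array([[-114.6, 10.9, 1.1],
              [-0.304, 38.07, 69.38],
              [-0.45, -0.17, 62.0]])
tau = np.array([1.55, 1.94, 3.0])
y = ormqr_reference(x, tau, x)  # left=True, transpose=False
print(y)
```

With the x, tau and other values from the Examples section, this sketch agrees with the printed output of ops.Ormqr up to float32 rounding. Applying the reflectors one by one avoids ever forming the dense Q, which is also how LAPACK-style ormqr routines work.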
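- Parameters
left (bool, optional) - Controls the order of multiplication. If True, computes op(Q) * C; if False, computes C * op(Q). Default: True.
transpose (bool, optional) - If True, Q is conjugate-transposed, so op(Q) = Q^H; if False, op(Q) = Q. Default: False.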
- Inputs:
x (Tensor) - Tensor of shape \((*, mn, k)\), where * is zero or more batch dimensions and mn depends on left: when left is True, mn equals m; otherwise, mn equals n.
tau (Tensor) - Tensor of shape \((*, min(mn, k))\), where * is zero or more batch dimensions. Its dtype must be the same as that of x.
other (Tensor) - Tensor of shape \((*, m, n)\), where * is zero or more batch dimensions. Its dtype must be the same as that of x.
- Outputs:
y (Tensor) - The output Tensor, with the same shape and data type as other.
- Raises
TypeError – If x, tau or other is not a Tensor.
TypeError – If the dtype of x, tau or other is not one of: float64, float32, complex64, complex128.
ValueError – If x or other has fewer than 2 dimensions.
ValueError – If rank(x) - rank(tau) != 1.
ValueError – If tau.shape[:-1] != x.shape[:-2].
ValueError – If other.shape[:-2] != x.shape[:-2].
ValueError – If left is True and other.shape[-2] < tau.shape[-1].
ValueError – If left is True and other.shape[-2] != x.shape[-2].
ValueError – If left is False and other.shape[-1] < tau.shape[-1].
ValueError – If left is False and other.shape[-1] != x.shape[-2].
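The shape rules in the Raises list above can be restated as a small pure-Python check on shape tuples. check_ormqr_shapes is a hypothetical helper written only to summarize the documented ValueError conditions; it is not part of the MindSpore API, and MindSpore's actual validation code and error messages may differ.

```python
def check_ormqr_shapes(x_shape, tau_shape, other_shape, left=True):
    """Hypothetical helper restating the documented ValueError shape rules."""
    x_shape, tau_shape, other_shape = tuple(x_shape), tuple(tau_shape), tuple(other_shape)
    if len(x_shape) < 2 or len(other_shape) < 2:
        raise ValueError("x and other must be at least 2-D")
    if len(x_shape) - len(tau_shape) != 1:
        raise ValueError("rank(x) - rank(tau) must be 1")
    if tau_shape[:-1] != x_shape[:-2]:
        raise ValueError("batch dims of tau must match batch dims of x")
    if other_shape[:-2] != x_shape[:-2]:
        raise ValueError("batch dims of other must match batch dims of x")
    # left=True: Q acts on the rows of other (size m);
    # left=False: Q acts on the columns of other (size n).
    mn = other_shape[-2] if left else other_shape[-1]
    if mn < tau_shape[-1]:
        raise ValueError("tau has more reflectors than the dimension Q acts on")
    if mn != x_shape[-2]:
        raise ValueError("x.shape[-2] must equal the dimension Q acts on")


# Shapes from the Examples section pass without error.
check_ormqr_shapes((3, 3), (3,), (3, 3))
```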
- Supported Platforms:
GPU
Examples
>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> x = Tensor(np.array([[-114.6, 10.9, 1.1], [-0.304, 38.07, 69.38], [-0.45, -0.17, 62]]), mindspore.float32)
>>> tau = Tensor(np.array([1.55, 1.94, 3.0]), mindspore.float32)
>>> other = Tensor(np.array([[-114.6, 10.9, 1.1],
...                          [-0.304, 38.07, 69.38],
...                          [-0.45, -0.17, 62]]), mindspore.float32)
>>> net = ops.Ormqr()
>>> y = net(x, tau, other)
>>> print(y)
[[ 63.82713 -13.823125 -116.28614 ]
 [ -53.659264 -28.157839 -70.42702 ]
 [ -79.54292 24.00183 -41.34253 ]]
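To see where the printed values come from, one can form Q explicitly in NumPy and multiply. This assumes the LAPACK-style Householder convention (Q = H_1 H_2 H_3, with each v_i built from the column of x below the diagonal and a unit entry on it); householder_q is a hypothetical helper name, and forming the dense Q like this is for illustration only.

```python
import numpy as np


def householder_q(x, tau):
    """Hypothetical helper: form Q = H_1 H_2 ... H_k explicitly (illustration only)."""
    mn = x.shape[0]
    dtype = np.result_type(x, tau)
    q = np.eye(mn, dtype=dtype)
    for i in range(tau.shape[0]):
        v = np.zeros(mn, dtype=dtype)
        v[i] = 1
        v[i + 1:] = x[i + 1:, i]
        # H_i = I - tau_i v_i v_i^H, accumulated on the right of Q.
        q = q @ (np.eye(mn, dtype=dtype) - tau[i] * np.outer(v, v.conj()))
    return q


# Same data as the example above, in float64.
x = np.array([[-114.6, 10.9, 1.1],
              [-0.304, 38.07, 69.38],
              [-0.45, -0.17, 62.0]])
tau = np.array([1.55, 1.94, 3.0])
other = x.copy()

q = householder_q(x, tau)
y = q @ other             # left=True,  transpose=False (the defaults)
y_t = q.conj().T @ other  # left=True,  transpose=True (assumed semantics)
y_right = other @ q       # left=False, transpose=False (assumed semantics)
print(y)
```

Under these assumptions, y agrees with the printed output of ops.Ormqr() up to float32 rounding.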