mindspore.ops.matmul

mindspore.ops.matmul(input, other)

Returns the matrix product of two tensors.

Note

The NumPy arguments out, casting, order, subok, signature, and extobj are not supported. On GPU and CPU, the supported dtypes are np.float16 and np.float32. The dtypes of input and other must be the same.
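The dtype rules in the Note can be expressed as a small check. This is a hypothetical helper (`check_matmul_dtypes` is not part of MindSpore) that uses plain dtype-name strings instead of real dtype objects. It raises `TypeError` on a dtype mismatch, as listed under Raises, and otherwise reports whether the shared dtype is one the Note lists for GPU or CPU. Ascend is left out because the Note gives no dtype list for it.

```python
# Supported dtypes per platform, copied from the Note. Ascend is omitted
# because the Note does not list its dtypes.
SUPPORTED_DTYPES = {
    "GPU": {"float16", "float32"},
    "CPU": {"float16", "float32"},
}


def check_matmul_dtypes(input_dtype: str, other_dtype: str, platform: str) -> bool:
    """Hypothetical helper: validate the dtype rules from the Note.

    Raises TypeError if the two dtypes differ, ValueError if the platform has
    no dtype list in the Note, and otherwise returns True when the shared
    dtype is listed as supported for that platform.
    """
    if input_dtype != other_dtype:
        raise TypeError(f"dtypes differ: {input_dtype!r} vs {other_dtype!r}")
    if platform not in SUPPORTED_DTYPES:
        raise ValueError(f"no dtype list for platform {platform!r}")
    return input_dtype in SUPPORTED_DTYPES[platform]


print(check_matmul_dtypes("float32", "float32", "GPU"))  # True
print(check_matmul_dtypes("float64", "float64", "CPU"))  # False
```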

Parameters
  • input (Tensor) – The first input tensor; scalars are not allowed. Its last dimension must have the same size as the second-to-last dimension of other, and the batch dimensions (all but the last two) of input and other must be broadcastable.

  • other (Tensor) – The second input tensor; scalars are not allowed. Its second-to-last dimension must have the same size as the last dimension of input, and the batch dimensions (all but the last two) of input and other must be broadcastable.
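The shape rules above can be made concrete with a pure-Python sketch that computes the output shape. It assumes `ops.matmul` follows the `numpy.matmul` conventions, which the Note's reference to NumPy arguments and Example case 2 suggest: a 1-D `input` is treated as a row vector, a 1-D `other` as a column vector, the added dimension is dropped from the result, and all dimensions except the last two are broadcast. `matmul_output_shape` and `broadcast_batch` are illustrative names, not MindSpore APIs.

```python
from typing import Sequence, Tuple


def broadcast_batch(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Broadcast two batch shapes right to left, NumPy style (assumed rule)."""
    result = []
    for i in range(1, max(len(a), len(b)) + 1):
        x = a[-i] if i <= len(a) else 1  # missing leading dims count as 1
        y = b[-i] if i <= len(b) else 1
        if x != y and x != 1 and y != 1:
            raise ValueError(f"batch dims {tuple(a)} and {tuple(b)} cannot be broadcast")
        result.append(y if x == 1 else x)
    return tuple(reversed(result))


def matmul_output_shape(input_shape: Sequence[int], other_shape: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: output shape of matmul under numpy.matmul rules."""
    if len(input_shape) == 0 or len(other_shape) == 0:
        raise ValueError("scalar inputs are not allowed")
    a, b = tuple(input_shape), tuple(other_shape)
    a_vec, b_vec = len(a) == 1, len(b) == 1
    if a_vec:
        a = (1,) + a  # treat a 1-D input as a 1 x n row
    if b_vec:
        b = b + (1,)  # treat a 1-D other as an n x 1 column
    if a[-1] != b[-2]:
        raise ValueError(
            f"last dim of input ({a[-1]}) != second-to-last dim of other ({b[-2]})"
        )
    out = broadcast_batch(a[:-2], b[:-2]) + (a[-2], b[-1])
    if a_vec:
        out = out[:-2] + out[-1:]  # drop the prepended row dimension
    if b_vec:
        out = out[:-1]  # drop the appended column dimension
    return out


print(matmul_output_shape((2, 3, 4), (4, 5)))  # (2, 3, 5)
print(matmul_output_shape((1, 2), (2,)))       # (1,)
print(matmul_output_shape((2,), (2,)))         # ()
```

For the shapes in Example case 1, the sketch gives `(2, 3, 5)`, matching `output.shape`; an empty tuple corresponds to the scalar result described under Returns.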

Returns

Tensor or scalar, the matrix product of the inputs. The result is a scalar only when both input and other are 1-D vectors.
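To see when the result collapses to a scalar, here is a minimal pure-Python reference product for 1-D and 2-D nested lists. It is a sketch, not MindSpore code: `ref_matmul` is a hypothetical name, and the rank-1 handling assumes NumPy's `matmul` convention of temporarily promoting a 1-D `input` to a row and a 1-D `other` to a column, then removing the added dimension.

```python
def ref_matmul(a, b):
    """Hypothetical reference matmul for 1-D / 2-D nested lists (NumPy-style rank-1 rules)."""
    a_vec = not isinstance(a[0], list)  # 1-D input?
    b_vec = not isinstance(b[0], list)  # 1-D other?
    A = [a] if a_vec else a                  # promote 1-D input to a 1 x n row
    B = [[v] for v in b] if b_vec else b     # promote 1-D other to an n x 1 column
    if len(A[0]) != len(B):
        raise ValueError("last dim of input must match second-to-last dim of other")
    # Plain triple-loop product of the (possibly promoted) 2-D operands.
    C = [
        [sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
        for i in range(len(A))
    ]
    if a_vec and b_vec:
        return C[0][0]               # both 1-D: a scalar
    if a_vec:
        return C[0]                  # drop the promoted row dimension
    if b_vec:
        return [row[0] for row in C]  # drop the promoted column dimension
    return C


print(ref_matmul([1.0, 1.0], [1.0, 1.0]))    # 2.0
print(ref_matmul([[1.0, 1.0]], [1.0, 1.0]))  # [2.0]
```

The second call mirrors Example case 2: a (1, 2) input times a 1-D other of length 2 yields a 1-element vector rather than a scalar.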

Raises
  • TypeError – If the dtype of input and the dtype of other are not the same.

  • ValueError – If the last dimension of input is not the same size as the second-to-last dimension of other, or if a scalar value is passed in.

  • ValueError – If the shapes of input and other cannot be broadcast together.

Supported Platforms:

Ascend, GPU, CPU

Examples

>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor, ops
>>> # case 1: a 2-D `other` is broadcast over the batch dimension of a 3-D `input`
>>> input = Tensor(np.arange(2*3*4).reshape(2, 3, 4), mindspore.float32)
>>> other = Tensor(np.arange(4*5).reshape(4, 5), mindspore.float32)
>>> output = ops.matmul(input, other)
>>> print(output)
[[[  70.   76.   82.   88.   94.]
  [ 190.  212.  234.  256.  278.]
  [ 310.  348.  386.  424.  462.]]
 [[ 430.  484.  538.  592.  646.]
  [ 550.  620.  690.  760.  830.]
  [ 670.  756.  842.  928. 1014.]]]
>>> print(output.shape)
(2, 3, 5)
>>> # case 2: the rank of `other` is 1
>>> input = Tensor(np.ones([1, 2]), mindspore.float32)
>>> other = Tensor(np.ones([2,]), mindspore.float32)
>>> output = ops.matmul(input, other)
>>> print(output)
[2.]
>>> print(output.shape)
(1,)