mindspore.mint.matmul
- mindspore.mint.matmul(input, other)[源代码]
计算两个数组的矩阵乘积。
说明
不支持NumPy参数 out 、 casting 、 order 、 subok 、 signature 、 extobj 。 input 和 other 的数据类型必须一致。在Ascend平台上,input 和 other 的秩必须在 1 到 6 之间。
- 参数:
input (Tensor) - 输入Tensor,不支持Scalar, input 的最后一维度和 other 的倒数第二维度相等,且 input 和 other 彼此支持广播。
other (Tensor) - 输入Tensor,不支持Scalar, input 的最后一维度和 other 的倒数第二维度相等,且 input 和 other 彼此支持广播。
- 返回:
Tensor或Scalar,输入的矩阵乘积。仅当 input 和 other 为一维向量时,输出为Scalar。
- 异常:
TypeError - input 的dtype和 other 的dtype不一致。
ValueError - input 的最后一维度和 other 的倒数第二维度不相等,或者输入的是Scalar。
ValueError - input 和 other 彼此不能广播。
RuntimeError - 在Ascend平台上, input 或 other 的秩小于 1 或大于 6。
- 支持平台:
Ascend
样例:
>>> import mindspore >>> import numpy as np >>> from mindspore import Tensor, mint >>> # case 1 : Reasonable application of broadcast mechanism >>> input = Tensor(np.arange(2*3*4).reshape(2, 3, 4), mindspore.float32) >>> other = Tensor(np.arange(4*5).reshape(4, 5), mindspore.float32) >>> output = mint.matmul(input, other) >>> print(output) [[[ 70. 76. 82. 88. 94.] [ 190. 212. 234. 256. 278.] [ 310. 348. 386. 424. 462.]] [[ 430. 484. 538. 592. 646.] [ 550. 620. 690. 760. 830.] [ 670. 756. 842. 928. 1014.]]] >>> print(output.shape) (2, 3, 5) >>> # case 2 : the rank of `input` is 1 >>> input = Tensor(np.ones([1, 2]), mindspore.float32) >>> other = Tensor(np.ones([2,]), mindspore.float32) >>> output = mint.matmul(input, other) >>> print(output) [2.] >>> print(output.shape) (1,)