mindspore.mint.bmm

查看源文件
mindspore.mint.bmm(input, mat2)[源代码]

基于batch维度的两个Tensor的矩阵乘法,仅支持三维。

\[\text{output} = \text{input} @ \text{mat2}\]
参数:
  • input (Tensor) - 输入相乘的第一个Tensor。必须是三维Tensor,shape为 \((b, n, m)\)

  • mat2 (Tensor) - 输入相乘的第二个Tensor。必须是三维Tensor,shape为 \((b, m, p)\)

返回:

Tensor,输出Tensor的shape为 \((b, n, p)\) 。其中每个矩阵是输入批次中相应矩阵的乘积。

异常:
  • ValueError - inputmat2 的维度不为3。

  • ValueError - input 第三维的长度不等于 mat2 第二维的长度。

  • ValueError - input 的 batch 维长度不等于 mat2 的 batch 维长度。

支持平台:

Ascend

样例:

>>> import mindspore
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore import mint
>>> a = Tensor(np.ones(shape=[2, 3, 4]), mindspore.float32)
>>> b = Tensor(np.ones(shape=[2, 4, 5]), mindspore.float32)
>>> output = mint.bmm(a, b)
>>> print(output)
[[[4. 4. 4. 4. 4.]
  [4. 4. 4. 4. 4.]
  [4. 4. 4. 4. 4.]]
 [[4. 4. 4. 4. 4.]
  [4. 4. 4. 4. 4.]
  [4. 4. 4. 4. 4.]]]