mindspore.mint.addbmm

查看源文件
mindspore.mint.addbmm(input, batch1, batch2, *, beta=1, alpha=1)[源代码]

batch1batch2 应用批量矩阵乘法后进行规约加, input 和最终的结果相加。 alphabeta 分别是 batch1batch2 矩阵乘法和 input 的乘数。如果 beta 是0,那么 input 将会被忽略。

\[output = \beta input + \alpha (\sum_{i=0}^{b-1} {batch1_i @ batch2_i})\]

警告

这是一个实验性API,后续可能修改或删除。

参数:
  • input (Tensor) - 被添加的Tensor。

  • batch1 (Tensor) - 矩阵乘法中的第一个Tensor。

  • batch2 (Tensor) - 矩阵乘法中的第二个Tensor。

关键字参数:
  • beta (Union[int, float],可选) - input 的乘数。默认值: 1

  • alpha (Union[int, float],可选) - batch1 @ batch2 的乘数。默认值: 1

返回:

Tensor,和 input 具有相同的dtype。

异常:
  • TypeError - 如果 alphabeta 不是int或者float。

  • ValueError - 如果 batch1batch2 不能进行批量矩阵乘法。

  • ValueError - 如果 batch1batch2 的不是三维Tensor。

支持平台:

Ascend

样例:

>>> import numpy as np
>>> from mindspore import Tensor, mint
>>> m = np.ones((3, 3)).astype(np.float32)
>>> arr1 = np.arange(24).astype(np.float32).reshape((2, 3, 4))
>>> arr2 = np.arange(24).astype(np.float32).reshape((2, 4, 3))
>>> a = Tensor(arr1)
>>> b = Tensor(arr2)
>>> c = Tensor(m)
>>> output = mint.addbmm(c, a, b)
>>> print(output)
[[ 949. 1009. 1069.]
 [1285. 1377. 1469.]
 [1621. 1745. 1869.]]