mindspore.ops.batch_dot

查看源文件
mindspore.ops.batch_dot(x1, x2, axes=None)[源代码]

当输入的两个Tensor是批量数据时,对其进行批量点积操作。其中 x1x2 的第一维度为batch维。

\[output = x1[batch, :] * x2[batch, :]\]
参数:
  • x1 (Tensor) - 第一个输入Tensor,数据类型为float32且 x1 的秩必须大于或等于2。

  • x2 (Tensor) - 第二个输入Tensor,数据类型为float32。 x2 的数据类型应与 x1 相同,x2 的秩必须大于或等于2。

  • axes (Union[int, tuple(int), list(int)]) - 指定为单值或长度为2的tuple和list,分别指定 ab 的维度。如果传递了单个值 N,则自动从输入 a 的shape中获取最后N个维度,从输入 b 的shape中获取最后N个维度,分别作为每个维度的轴。默认值: None

返回:

Tensor, x1x2 的批量点积。例如:输入 x1 的shape为 \((batch, d1, axes, d2)\)x2 shape为 \((batch, d3, axes, d4)\),则输出shape为 \((batch, d1, d2, d3, d4)\),其中d1和d2表示任意数字。

异常:
  • TypeError - x1x2 的类型不相同。

  • TypeError - x1x2 的数据类型不是float32。

  • ValueError - x1x2 的秩小于2。

  • ValueError - 在 axes 中使用了代表批量的维度。

  • ValueError - axes 的长度小于2。

  • ValueError - axes 不是其一:None,int,或(int, int)。

  • ValueError - 如果 axes 为负值,低于输入数组的维度。

  • ValueError - 如果 axes 的值高于输入数组的维度。

  • ValueError - x1x2 的第一维度batch维的大小不相同。

支持平台:

Ascend GPU CPU

样例:

>>> import mindspore
>>> from mindspore import Tensor, ops
>>> import numpy as np
>>> x1 = Tensor(np.ones(shape=[2, 2, 3]), mindspore.float32)
>>> x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
>>> axes = (-1, -2)
>>> output = ops.batch_dot(x1, x2, axes)
>>> print(output)
[[[3. 3.]
  [3. 3.]]
 [[3. 3.]
  [3. 3.]]]
>>> x1 = Tensor(np.ones(shape=[2, 2]), mindspore.float32)
>>> x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
>>> axes = (1, 2)
>>> output = ops.batch_dot(x1, x2, axes)
>>> print(output)
[[2. 2. 2.]
 [2. 2. 2.]]
>>> print(output.shape)
(2, 3)
>>> x1 = Tensor(np.ones(shape=[6, 2, 3, 4]), mindspore.float32)
>>> x2 = Tensor(np.ones(shape=[6, 5, 4, 8]), mindspore.float32)
>>> output = ops.batch_dot(x1, x2)
>>> print(output.shape)
(6, 2, 3, 5, 8)
>>> x1 = Tensor(np.ones(shape=[2, 2, 4]), mindspore.float32)
>>> x2 = Tensor(np.ones(shape=[2, 5, 4, 5]), mindspore.float32)
>>> output = ops.batch_dot(x1, x2)
>>> print(output.shape)
(2, 2, 5, 5)