mindspore.ops.batch_dot

mindspore.ops.batch_dot(x1, x2, axes=None)

Computes the batch dot product between samples in two tensors that contain batch dimensions, i.e. the first dimension of x1 and x2 is the batch size.

\[output = x1[batch, :] * x2[batch, :]\]
Parameters
  • x1 (Tensor) – First input tensor, with dtype float32. The rank of x1 must be greater than or equal to 2.

  • x2 (Tensor) – Second input tensor, with dtype float32. The dtype of x2 must be the same as that of x1, and the rank of x2 must be greater than or equal to 2.

  • axes (Union[int, tuple(int), list(int)]) – Single value, or tuple/list of length 2 giving the dimensions to contract for x1 and x2 respectively. If a single value N is passed, the last N dimensions of the x1 shape and the last N dimensions of the x2 shape are picked, in order, as the axes for each. Default: None .

Returns

Tensor, the batch dot product of x1 and x2. For example, for x1 of shape \((batch, d1, axes, d2)\) and x2 of shape \((batch, d3, axes, d4)\), the output shape is \((batch, d1, d2, d3, d4)\), where d1, d2, d3 and d4 can be any number.
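The shape rule above can be sketched in plain Python. `batch_dot_output_shape` is a hypothetical helper, not part of MindSpore. It assumes that axes is None or a pair of ints, and that the default contracts the last axis of x1 with the second-to-last axis of x2. That default is inferred from the examples on this page rather than stated in the text. The single-int form of axes and rank-2 edge cases are deliberately left unmodeled.

```python
def batch_dot_output_shape(shape1, shape2, axes=None):
    """Hypothetical helper: predict the output shape of batch_dot.

    Models only axes=None or axes=(a1, a2); other cases are not covered.
    """
    r1, r2 = len(shape1), len(shape2)
    if r1 == 2 and r2 == 2:
        raise NotImplementedError("both inputs of rank 2 are not modeled here")
    if axes is None:
        if r2 < 3:
            raise NotImplementedError("default axes with rank-2 x2 not modeled")
        # Assumed default, inferred from the documented examples.
        axes = (r1 - 1, r2 - 2)
    a1, a2 = axes
    # Map negative axes to their non-negative equivalents.
    a1 = a1 + r1 if a1 < 0 else a1
    a2 = a2 + r2 if a2 < 0 else a2
    if shape1[a1] != shape2[a2]:
        raise ValueError("contracted dimensions must have the same size")
    # Keep batch plus the remaining dims of x1, then the remaining
    # non-batch dims of x2.
    kept1 = [d for i, d in enumerate(shape1) if i != a1]
    kept2 = [d for i, d in enumerate(shape2) if i not in (0, a2)]
    return tuple(kept1 + kept2)


print(batch_dot_output_shape((6, 2, 3, 4), (6, 5, 4, 8)))
```

With the shapes used in the Examples section, this sketch reproduces the documented output shapes, e.g. (6, 2, 3, 5, 8) for inputs of shape (6, 2, 3, 4) and (6, 5, 4, 8).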

Raises
  • TypeError – If the types of x1 and x2 are not the same.

  • TypeError – If the dtype of x1 or x2 is not float32.

  • ValueError – If the rank of x1 or x2 is less than 2.

  • ValueError – If the batch dimension is used in axes.

  • ValueError – If len(axes) is less than 2.

  • ValueError – If axes is not one of: None, int, (int, int).

  • ValueError – If a value in axes, after conversion from a negative int, is too low for the dimensions of the input arrays.

  • ValueError – If a value in axes is too high for the dimensions of the input arrays.

  • ValueError – If the batch sizes of x1 and x2 are not the same.
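Several of the ValueError conditions above can be checked on shapes alone, before any tensor is built. The following is a minimal sketch, not MindSpore's actual validation code: `precheck_batch_dot` is a hypothetical name, it handles only the two-element tuple/list form of axes, and its error messages are illustrative.

```python
def precheck_batch_dot(shape1, shape2, axes):
    """Hypothetical pre-check mirroring some documented ValueError cases.

    Returns the axes as non-negative ints when all checks pass.
    """
    if len(shape1) < 2 or len(shape2) < 2:
        raise ValueError("rank of x1 and x2 must be at least 2")
    if shape1[0] != shape2[0]:
        raise ValueError("batch sizes of x1 and x2 differ")
    if not isinstance(axes, (tuple, list)) or len(axes) != 2:
        raise ValueError("this sketch expects axes as a pair of ints")
    normalized = []
    for axis, rank in zip(axes, (len(shape1), len(shape2))):
        if axis < 0:
            axis += rank  # negative axes count from the end
        if axis < 0:
            raise ValueError("negative axis is too low for the input rank")
        if axis >= rank:
            raise ValueError("axis is too high for the input rank")
        if axis == 0:
            raise ValueError("the batch dimension cannot be used in axes")
        normalized.append(axis)
    return tuple(normalized)


print(precheck_batch_dot((2, 2, 3), (2, 3, 2), (-1, -2)))
```

For the first example on this page, axes = (-1, -2) on two rank-3 inputs normalizes to (2, 1).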

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore
>>> from mindspore import Tensor, ops
>>> import numpy as np
>>> x1 = Tensor(np.ones(shape=[2, 2, 3]), mindspore.float32)
>>> x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
>>> axes = (-1, -2)
>>> output = ops.batch_dot(x1, x2, axes)
>>> print(output)
[[[3. 3.]
  [3. 3.]]
 [[3. 3.]
  [3. 3.]]]
>>> x1 = Tensor(np.ones(shape=[2, 2]), mindspore.float32)
>>> x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
>>> axes = (1, 2)
>>> output = ops.batch_dot(x1, x2, axes)
>>> print(output)
[[2. 2. 2.]
 [2. 2. 2.]]
>>> print(output.shape)
(2, 3)
>>> x1 = Tensor(np.ones(shape=[6, 2, 3, 4]), mindspore.float32)
>>> x2 = Tensor(np.ones(shape=[6, 5, 4, 8]), mindspore.float32)
>>> output = ops.batch_dot(x1, x2)
>>> print(output.shape)
(6, 2, 3, 5, 8)
>>> x1 = Tensor(np.ones(shape=[2, 2, 4]), mindspore.float32)
>>> x2 = Tensor(np.ones(shape=[2, 5, 4, 5]), mindspore.float32)
>>> output = ops.batch_dot(x1, x2)
>>> print(output.shape)
(2, 2, 5, 5)
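As a cross-check that does not require MindSpore, the first three examples can be reproduced with NumPy's `einsum`. This is a sketch of the same contractions, not how `batch_dot` is implemented. The subscript strings are an interpretation of the axes arguments, with `b` standing for the shared batch dimension.

```python
import numpy as np

# First example: axes=(-1, -2) contracts the last axis of the first input
# with the second-to-last axis of the second input.
a = np.ones((2, 2, 3), dtype=np.float32)
b = np.ones((2, 3, 2), dtype=np.float32)
out1 = np.einsum("bij,bjk->bik", a, b)

# Second example: axes=(1, 2) contracts axis 1 of the rank-2 input
# with axis 2 of the rank-3 input.
c = np.ones((2, 2), dtype=np.float32)
d = np.ones((2, 3, 2), dtype=np.float32)
out2 = np.einsum("bi,bji->bj", c, d)

# Third example: default axes, read here as (last, second-to-last).
e = np.ones((6, 2, 3, 4), dtype=np.float32)
f = np.ones((6, 5, 4, 8), dtype=np.float32)
out3 = np.einsum("bijk,blkm->bijlm", e, f)

print(out1.shape, out2.shape, out3.shape)
```

The resulting shapes are (2, 2, 2), (2, 3) and (6, 2, 3, 5, 8), matching the outputs shown above. Every entry of `out1` is 3 and every entry of `out2` is 2, since each sums that many products of ones.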