Function Differences with tf.keras.backend.batch_dot


tf.keras.backend.batch_dot

tf.keras.backend.batch_dot(x, y, axes=None)

For more information, see tf.keras.backend.batch_dot.

mindspore.ops.batch_dot

mindspore.ops.batch_dot(x1, x2, axes=None)

For more information, see mindspore.ops.batch_dot.

Differences

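TensorFlow: When the inputs x and y are batches of data, batch_dot returns the batch-wise dot product of x and y: for each sample in the batch, x and y are contracted along the dimensions given by axes.

MindSpore: The MindSpore API implements the same function as the Keras API; only the parameter names differ.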
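
To make "batch-wise dot product" concrete, here is a minimal pure-Python sketch for two 2-D batches stored as nested lists. The helper `per_sample_dot` is hypothetical and belongs to neither library; it only illustrates that each sample is contracted independently and does not model how either API shapes its output.

```python
def per_sample_dot(x, y):
    """Dot product of matching rows of x and y (one result per sample)."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same batch size")
    # Each row of x is paired only with the row of y at the same batch index.
    return [sum(a * b for a, b in zip(x_row, y_row)) for x_row, y_row in zip(x, y)]


x = [[1, 2], [3, 4]]   # batch of 2 samples, 2 features each
y = [[5, 6], [7, 8]]
result = per_sample_dot(x, y)
print(result)  # [17, 53]
```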

| Categories | Subcategories | TensorFlow | MindSpore | Differences |
| --- | --- | --- | --- | --- |
| Parameters | Parameter 1 | x | x1 | Same function, different parameter names |
| Parameters | Parameter 2 | y | x2 | Same function, different parameter names |
| Parameters | Parameter 3 | axes | axes | - |
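
Because the only difference is the parameter names, keyword arguments written for the Keras call can be renamed mechanically when migrating. The sketch below assumes a hypothetical helper, `to_mindspore_kwargs`, that is not part of TensorFlow or MindSpore; the placeholder strings stand in for real tensors.

```python
# Hypothetical migration helper; the mapping mirrors the parameter comparison above.
KERAS_TO_MINDSPORE = {"x": "x1", "y": "x2", "axes": "axes"}


def to_mindspore_kwargs(keras_kwargs):
    """Rename Keras batch_dot keyword arguments to their MindSpore names."""
    unknown = set(keras_kwargs) - set(KERAS_TO_MINDSPORE)
    if unknown:
        raise TypeError(f"unexpected arguments: {sorted(unknown)}")
    return {KERAS_TO_MINDSPORE[name]: value for name, value in keras_kwargs.items()}


# Placeholder values are used here instead of real tensors.
converted = to_mindspore_kwargs({"x": "x_tensor", "y": "y_tensor", "axes": (1, 2)})
print(converted)
```

The renamed dictionary could then be unpacked into the MindSpore call, for example `ops.batch_dot(**converted)` once the placeholders are real tensors.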

Code Example 1

When the axes parameter is not specified, the two APIs implement the same function and are used in the same way.

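# TensorFlow
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K

# Random integer data in [0, 10), stored as float32 variables.
x = K.variable(np.random.randint(10, size=(10, 12, 4, 5)), dtype=tf.float32)
y = K.variable(np.random.randint(10, size=(10, 12, 5, 8)), dtype=tf.float32)
# axes is omitted, so the default contraction axes are used.
output = K.batch_dot(x, y)
print(output.shape)
# (10, 12, 4, 12, 8)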

# MindSpore
import numpy as np
import mindspore
import mindspore.ops as ops
from mindspore import Tensor

x1 = Tensor(np.random.randint(10, size=(10, 12, 4, 5)), mindspore.float32)
x2 = Tensor(np.random.randint(10, size=(10, 12, 5, 8)), mindspore.float32)
output = ops.batch_dot(x1, x2)
print(output.shape)
# (10, 12, 4, 12, 8)
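
The five-dimensional result may look surprising next to a plain batched matrix multiplication, which would give (10, 12, 4, 8). The sketch below is one way to predict the shape. It assumes that only the first dimension is treated as the batch axis, that the default axes contract the last dimension of the first input with the second-to-last dimension of the second input, and that the remaining dimensions of both inputs are concatenated. These assumptions are consistent with the shapes printed in this document, but `batch_dot_output_shape` is a hypothetical helper, it accepts only None or a pair of axes, and it does not model any special handling applied when both inputs are 2-D.

```python
def batch_dot_output_shape(x_shape, y_shape, axes=None):
    """Predict the output shape of a batch dot product (illustrative only)."""
    x_ndim, y_ndim = len(x_shape), len(y_shape)
    if x_ndim < 2 or y_ndim < 2:
        raise ValueError("each input needs a batch axis plus at least one more axis")
    if x_shape[0] != y_shape[0]:
        raise ValueError("batch sizes differ")
    if axes is None:
        # Assumed default: last axis of x against the second-to-last axis of y
        # (or the last axis of y when y is 2-D).
        axes = (x_ndim - 1, y_ndim - 1 if y_ndim == 2 else y_ndim - 2)
    ax, ay = axes
    # Normalize negative axis indices.
    ax = ax + x_ndim if ax < 0 else ax
    ay = ay + y_ndim if ay < 0 else ay
    if ax == 0 or ay == 0:
        raise ValueError("the batch axis cannot be contracted")
    if x_shape[ax] != y_shape[ay]:
        raise ValueError("contracted dimensions must have the same size")
    # Keep the batch axis once, drop both contracted axes, concatenate the rest.
    x_rest = [d for i, d in enumerate(x_shape) if i != ax]
    y_rest = [d for i, d in enumerate(y_shape) if i not in (0, ay)]
    return tuple(x_rest + y_rest)


shape = batch_dot_output_shape((10, 12, 4, 5), (10, 12, 5, 8))
print(shape)  # (10, 12, 4, 12, 8)
```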

Code Example 2

When the axes parameter is specified, the two APIs also implement the same function and are used in the same way.

# TensorFlow
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K

x = K.variable(np.ones(shape=[2, 2]), dtype=tf.float32)
y = K.variable(np.ones(shape=[2, 3, 2]), dtype=tf.float32)
axes = (1, 2)
output = K.batch_dot(x, y, axes)
print(output.shape)
# (2, 3)

# MindSpore
import numpy as np
import mindspore
import mindspore.ops as ops
from mindspore import Tensor

x1 = Tensor(np.ones(shape=[2, 2]), mindspore.float32)
x2 = Tensor(np.ones(shape=[2, 3, 2]), mindspore.float32)
axes = (1, 2)
output = ops.batch_dot(x1, x2, axes)
print(output.shape)
# (2, 3)
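
To see what axes = (1, 2) contracts in this example, the computation can be written out by hand as out[b][j] = sum over k of x[b][k] * y[b][j][k]. The pure-Python sketch below uses nested lists and a hypothetical helper, `batch_dot_axes_1_2`; it only illustrates the arithmetic and is not how either library implements the operation.

```python
def batch_dot_axes_1_2(x, y):
    """Contract axis 1 of x (batch, k) with axis 2 of y (batch, j, k)."""
    return [
        [sum(xk * yk for xk, yk in zip(x_b, y_row)) for y_row in y_b]
        for x_b, y_b in zip(x, y)
    ]


x = [[1.0, 1.0] for _ in range(2)]                      # shape (2, 2)
y = [[[1.0, 1.0] for _ in range(3)] for _ in range(2)]  # shape (2, 3, 2)
out = batch_dot_axes_1_2(x, y)
print(len(out), len(out[0]))  # 2 3
print(out)  # [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]
```

Each output entry sums two products of ones, so every value is 2.0, and the result has the same (2, 3) shape printed by both APIs.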