比较与torch.dot的差异
torch.dot
torch.dot(input, other, *, out=None)
更多内容详见torch.dot。
mindspore.ops.tensor_dot
mindspore.ops.tensor_dot(x1, x2, axes)
更多内容详见mindspore.ops.tensor_dot。
使用方式
MindSpore此API功能与PyTorch不一致。
PyTorch:计算两个相同shape的tensor的点乘(内积),仅支持1D。支持的输入数据类型包括uint8、int8/16/32/64、float32/64。
MindSpore:计算两个tensor在任意轴上的点乘,支持任意维度的tensor,但指定的轴对应的形状要相等。当输入为1D, 轴设定为1时,和PyTorch的功能一致。支持的输入数据类型为float16或float32。
分类 |
子类 |
PyTorch |
MindSpore |
差异 |
---|---|---|---|---|
参数 |
参数 1 |
input |
x1 |
参数名不同 |
参数 2 |
other |
x2 |
参数名不同 |
|
参数 3 |
out |
- |
详见通用差异参数表 |
|
参数 4 |
- |
axes |
当输入为1D,axes设定为1时,和PyTorch的功能一致 |
代码示例 1
输入的数据类型是int,输出的数据类型也是int。
import torch
input_x1 = torch.tensor([2, 3, 4], dtype=torch.int32)
input_x2 = torch.tensor([2, 1, 3], dtype=torch.int32)
output = torch.dot(input_x1, input_x2)
print(output)
# tensor(19, dtype=torch.int32)
print(output.dtype)
# torch.int32
# MindSpore目前无法支持该功能。
代码示例 2
输入的数据类型是float,输出的数据类型也是float。
import torch
input_x1 = torch.tensor([2, 3, 4], dtype=torch.float32)
input_x2 = torch.tensor([2, 1, 3], dtype=torch.float32)
output = torch.dot(input_x1, input_x2)
print(output)
# tensor(19.)
print(output.dtype)
# torch.float32
import mindspore as ms
import mindspore.ops as ops
import numpy as np
input_x1 = ms.Tensor(np.array([2, 3, 4]), ms.float32)
input_x2 = ms.Tensor(np.array([2, 1, 3]), ms.float32)
output = ops.tensor_dot(input_x1, input_x2, 1)
print(output)
# 19.0
print(output.dtype)
# Float32