比较与torch.dot的差异

查看源文件

torch.dot

torch.dot(input, other, *, out=None)

更多内容详见torch.dot

mindspore.ops.tensor_dot

mindspore.ops.tensor_dot(x1, x2, axes)

更多内容详见mindspore.ops.tensor_dot

使用方式

MindSpore此API功能与PyTorch不一致。

PyTorch:计算两个相同shape的tensor的点乘(内积),仅支持1D。支持的输入数据类型包括uint8、int8/16/32/64、float32/64。

MindSpore:计算两个tensor在任意轴上的点乘,支持任意维度的tensor,但指定的轴对应的形状要相等。当输入为1D, 轴设定为1时,和PyTorch的功能一致。支持的输入数据类型为float16或float32。

分类

子类

PyTorch

MindSpore

差异

参数

参数 1

input

x1

参数名不同

参数 2

other

x2

参数名不同

参数 3

out

-

详见通用差异参数表

参数 4

-

axes

当输入为1D,axes设定为1时,和PyTorch的功能一致

代码示例 1

输入的数据类型是int,输出的数据类型也是int。

import torch
input_x1 = torch.tensor([2, 3, 4], dtype=torch.int32)
input_x2 = torch.tensor([2, 1, 3], dtype=torch.int32)
output = torch.dot(input_x1, input_x2)
print(output)
# tensor(19, dtype=torch.int32)
print(output.dtype)
# torch.int32

# MindSpore目前无法支持该功能。

代码示例 2

输入的数据类型是float,输出的数据类型也是float。

import torch
input_x1 = torch.tensor([2, 3, 4], dtype=torch.float32)
input_x2 = torch.tensor([2, 1, 3], dtype=torch.float32)
output = torch.dot(input_x1, input_x2)
print(output)
# tensor(19.)
print(output.dtype)
# torch.float32

import mindspore as ms
import mindspore.ops as ops
import numpy as np
input_x1 = ms.Tensor(np.array([2, 3, 4]), ms.float32)
input_x2 = ms.Tensor(np.array([2, 1, 3]), ms.float32)
output = ops.tensor_dot(input_x1, input_x2, 1)
print(output)
# 19.0
print(output.dtype)
# Float32