比较与torch.Tensor.expand的差异

查看源文件

torch.Tensor.expand

torch.Tensor.expand(*sizes) -> Tensor

更多内容详见torch.Tensor.expand

mindspore.Tensor.broadcast_to

mindspore.Tensor.broadcast_to(shape) -> Tensor

更多内容详见mindspore.Tensor.broadcast_to

差异对比

MindSpore此API功能与PyTorch一致,参数支持的数据类型有差异。

PyTorch:sizes 为广播后的目标shape,其类型可以为 torch.Size 或者为由 int 构成的序列。

MindSpore:shape 为广播后的目标shape,其类型可以为 tuple[int]

分类

子类

PyTorch

MindSpore

差异

参数

参数1

*sizes

shape

二者参数名不同,均表示广播后的目标shape。 sizes 的类型可以为 torch.Size 或者为由 int 构成的序列,shape 的类型可以为 tuple[int]

代码示例

# PyTorch
import torch

x = torch.tensor([1, 2, 3])
output = x.expand(3, 3)
print(output)
# tensor([[1, 2, 3],
#         [1, 2, 3],
#         [1, 2, 3]])

# MindSpore
import mindspore
import numpy as np
from mindspore import Tensor

shape = (3, 3)
x = Tensor(np.array([1, 2, 3]))
output = x.broadcast_to(shape)
print(output)
# [[1 2 3]
#  [1 2 3]
#  [1 2 3]]