比较与torch.Tensor.expand的差异
torch.Tensor.expand
torch.Tensor.expand(*sizes) -> Tensor
更多内容详见torch.Tensor.expand。
mindspore.Tensor.broadcast_to
mindspore.Tensor.broadcast_to(shape) -> Tensor
差异对比
MindSpore此API功能与PyTorch一致,参数支持的数据类型有差异。
PyTorch:sizes
为广播后的目标shape,其类型可以为 torch.Size
或者为由 int
构成的序列。
MindSpore:shape
为广播后的目标shape,其类型可以为 tuple[int]
。
分类 |
子类 |
PyTorch |
MindSpore |
差异 |
---|---|---|---|---|
参数 |
参数1 |
*sizes |
shape |
二者参数名不同,均表示广播后的目标shape。 |
代码示例
# PyTorch
import torch
x = torch.tensor([1, 2, 3])
output = x.expand(3, 3)
print(output)
# tensor([[1, 2, 3],
# [1, 2, 3],
# [1, 2, 3]])
# MindSpore
import mindspore
import numpy as np
from mindspore import Tensor
shape = (3, 3)
x = Tensor(np.array([1, 2, 3]))
output = x.broadcast_to(shape)
print(output)
# [[1 2 3]
# [1 2 3]
# [1 2 3]]