比较与torch.Tensor.flatten的功能差异
torch.Tensor.flatten
torch.Tensor.flatten(input, start_dim=0, end_dim=-1)
更多内容详见torch.Tensor.flatten。
mindspore.Tensor.flatten
mindspore.Tensor.flatten(order="C")
更多内容详见mindspore.Tensor.flatten。
使用方式
torch.flatten
通过入参start_dim
,end_dim
限制需要扩展的维度范围。
mindspore.Tensor.flatten
通过order
为”C”或”F”确定优先按行还是列展平。
代码示例
import mindspore as ms
a = ms.Tensor([[1,2], [3,4]], ms.float32)
print(a.flatten())
# [1. 2. 3. 4.]
print(a.flatten('F'))
# [1. 3. 2. 4.]
import torch
b = torch.tensor([[[1, 2],[3, 4]],[[5, 6],[7, 8]]])
print(torch.Tensor.flatten(b))
# tensor([1, 2, 3, 4, 5, 6, 7, 8])
print(torch.Tensor.flatten(b, start_dim=1))
# tensor([[1, 2, 3, 4],
# [5, 6, 7, 8]])