比较与torch.Tensor.flatten的功能差异

查看源文件

torch.Tensor.flatten

torch.Tensor.flatten(input, start_dim=0, end_dim=-1)

更多内容详见torch.Tensor.flatten

mindspore.Tensor.flatten

mindspore.Tensor.flatten(order="C")

更多内容详见mindspore.Tensor.flatten

使用方式

torch.flatten通过入参start_dimend_dim限制需要扩展的维度范围。

mindspore.Tensor.flatten通过order为”C”或”F”确定优先按行还是列展平。

代码示例

import mindspore as ms

a = ms.Tensor([[1,2], [3,4]], ms.float32)
print(a.flatten())
# [1. 2. 3. 4.]
print(a.flatten('F'))
# [1. 3. 2. 4.]

import torch

b = torch.tensor([[[1, 2],[3, 4]],[[5, 6],[7, 8]]])
print(torch.Tensor.flatten(b))
# tensor([1, 2, 3, 4, 5, 6, 7, 8])
print(torch.Tensor.flatten(b, start_dim=1))
# tensor([[1, 2, 3, 4],
#         [5, 6, 7, 8]])