比较与torch.flatten的功能差异
torch.flatten
torch.flatten(
input,
start_dim=0,
end_dim=-1
)
更多内容详见torch.flatten。
mindspore.ops.flatten
mindspore.ops.flatten(input, order='C', *, start_dim=1, end_dim=-1)
更多内容详见mindspore.ops.flatten。
使用方式
PyTorch:支持指定维度对元素进行展开,start_dim
默认为0,end_dim
默认为-1。
MindSpore:支持指定维度对元素进行展开,start_dim
默认为1,end_dim
默认为-1。通过 order
为”C”或”F”确定优先按行还是列展平。
分类 |
子类 |
PyTorch |
MindSpore |
差异 |
---|---|---|---|---|
参数 |
参数1 |
input |
input |
功能一致 |
参数2 |
- |
order |
展平优先顺序选项,PyTorch无此参数 |
|
参数3 |
start_dim |
start_dim |
功能一致 |
|
参数4 |
end_dim |
end_dim |
功能一致 |
代码示例
import mindspore as ms
import mindspore.ops as ops
import torch
import numpy as np
# MindSpore
input_tensor = ms.Tensor(np.ones(shape=[1, 2, 3, 4]), ms.float32)
output = ops.flatten(input_tensor)
print(output.shape)
# Out:
# (1, 24)
input_tensor = ms.Tensor(np.ones(shape=[1, 2, 3, 4]), ms.float32)
output = ops.flatten(input_tensor, start_dim=2)
print(output.shape)
# Out:
# (1, 2, 12)
# PyTorch
input_tensor = torch.Tensor(np.ones(shape=[1, 2, 3, 4]))
output1 = torch.flatten(input=input_tensor, start_dim=1)
print(output1.shape)
# Out:
# torch.Size([1, 24])
input_tensor = torch.Tensor(np.ones(shape=[1, 2, 3, 4]))
output2 = torch.flatten(input=input_tensor, start_dim=2)
print(output2.shape)
# Out:
# torch.Size([1, 2, 12])