Function Differences with torch.nn.Flatten
torch.nn.Flatten
class torch.nn.Flatten(
start_dim=1,
end_dim=-1
)
For more information, see torch.nn.Flatten.
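The sketch below computes the output shape that the start_dim/end_dim rule produces. Dimensions start_dim through end_dim (inclusive) are merged into one, and negative values count from the end. The helper name flatten_output_shape is hypothetical and is not part of PyTorch. It models shapes only, not tensor data.

```python
def flatten_output_shape(shape, start_dim=1, end_dim=-1):
    """Shape after merging dims start_dim..end_dim (inclusive); hypothetical helper."""
    ndim = len(shape)
    # Negative dims count from the end, as in Python indexing.
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim
    if not 0 <= start <= end < ndim:
        raise ValueError(f"invalid start_dim={start_dim}, end_dim={end_dim} for rank {ndim}")
    merged = 1
    for size in shape[start:end + 1]:
        merged *= size
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


print(flatten_output_shape((1, 2, 3, 4)))               # (1, 24)
print(flatten_output_shape((1, 2, 3, 4), start_dim=2))  # (1, 2, 12)
print(flatten_output_shape((1, 2, 3, 4), 0, 1))         # (2, 3, 4)
```

The defaults (start_dim=1, end_dim=-1) keep dimension 0 and merge everything else. This is the only behavior the MindSpore operator offers.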
mindspore.nn.Flatten
class mindspore.nn.Flatten()(input)
For more information, see mindspore.nn.Flatten.
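The MindSpore behavior described here keeps dimension 0 and merges all the others. That is equivalent to a reshape to (N, product of the remaining sizes). The NumPy sketch below illustrates this. keep_first_dim_flatten is a hypothetical name. The sketch accepts only inputs of rank 2 or higher; that restriction is an assumption of the sketch, not a statement about how MindSpore handles edge cases.

```python
import math

import numpy as np


def keep_first_dim_flatten(x: np.ndarray) -> np.ndarray:
    """Keep axis 0 and merge the remaining axes (hypothetical, rank >= 2 only)."""
    if x.ndim < 2:
        raise ValueError("this sketch expects an input of rank 2 or higher")
    # Target shape: (N, product of the remaining axis sizes).
    return x.reshape(x.shape[0], math.prod(x.shape[1:]))


x = np.ones((1, 2, 3, 4), dtype=np.float32)
print(keep_first_dim_flatten(x).shape)  # (1, 24)
```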
Usage
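PyTorch: Flattens the contiguous range of dimensions from start_dim to end_dim (inclusive) into a single dimension. With the defaults (start_dim=1, end_dim=-1), dimension 0 is kept and all remaining dimensions are flattened. torch.nn.Flatten is an ordinary nn.Module. It can be called directly on a tensor or placed inside a container such as torch.nn.Sequential.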
MindSpore: Always keeps dimension 0 and flattens all remaining dimensions. The range of dimensions to flatten cannot be specified.
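Sometimes the behavior of torch.nn.Flatten(start_dim=..., end_dim=...) has to be reproduced with a MindSpore nn.Flatten that only keeps dimension 0. One possible workaround is to compute the target shape and reshape directly. The NumPy sketch below shows the shape arithmetic. range_flatten and _normalize_dim are hypothetical helpers. Passing the same shape tuple to MindSpore's Tensor.reshape is an assumption to verify against the MindSpore version in use.

```python
import math

import numpy as np


def _normalize_dim(dim: int, ndim: int) -> int:
    """Convert a possibly negative axis index into a non-negative one."""
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} is out of range for a rank-{ndim} array")
    return dim + ndim if dim < 0 else dim


def range_flatten(x: np.ndarray, start_dim: int = 1, end_dim: int = -1) -> np.ndarray:
    """Merge axes start_dim..end_dim (inclusive) of x into one axis (hypothetical helper)."""
    start = _normalize_dim(start_dim, x.ndim)
    end = _normalize_dim(end_dim, x.ndim)
    if start > end:
        raise ValueError("start_dim must not come after end_dim")
    merged = math.prod(x.shape[start:end + 1])
    new_shape = x.shape[:start] + (merged,) + x.shape[end + 1:]
    # reshape preserves row-major (C-order) element order.
    return x.reshape(new_shape)


x = np.ones((1, 2, 3, 4), dtype=np.float32)
print(range_flatten(x).shape)               # (1, 24)
print(range_flatten(x, start_dim=2).shape)  # (1, 2, 12)
```

Only the target-shape computation is framework-specific. The reshape itself is a plain row-major reshape, which is why the result matches the start_dim=2 output in the PyTorch example.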
Code Example
import mindspore as ms
from mindspore import nn
import torch
import numpy as np
# In MindSpore, dimension 0 is always kept and the remaining dimensions are flattened.
input_tensor = ms.Tensor(np.ones(shape=[1, 2, 3, 4]), ms.float32)
flatten = nn.Flatten()
output = flatten(input_tensor)
print(output.shape)
# Out:
# (1, 24)
# In PyTorch, the range of dimensions to flatten is set with start_dim and end_dim.
# Unlike the functional torch.flatten, torch.nn.Flatten is an nn.Module: it can be
# called directly, and wrapping it in torch.nn.Sequential is optional.
input_tensor = torch.ones(1, 2, 3, 4)
flatten1 = torch.nn.Flatten(start_dim=1)
output1 = flatten1(input_tensor)
print(output1.shape)
# Out:
# torch.Size([1, 24])
flatten2 = torch.nn.Flatten(start_dim=2)
output2 = flatten2(input_tensor)
print(output2.shape)
# Out:
# torch.Size([1, 2, 12])
# Placing the module inside torch.nn.Sequential gives the same result.
flatten3 = torch.nn.Sequential(torch.nn.Flatten(start_dim=2))
print(flatten3(input_tensor).shape)
# Out:
# torch.Size([1, 2, 12])