Function Differences with torch.nn.Flatten
torch.nn.Flatten
class torch.nn.Flatten(
start_dim=1,
end_dim=-1
)
For more information, see torch.nn.Flatten.
mindspore.nn.Flatten
class mindspore.nn.Flatten()(input)
For more information, see mindspore.nn.Flatten.
Differences
PyTorch: Flattens a contiguous range of dimensions, from start_dim to end_dim, into a single dimension. It is a module, so it is commonly used as a layer inside torch.nn.Sequential, although it can also be called directly.
MindSpore: Keeps only dimension 0 and flattens all remaining dimensions into one.
Code Example
import mindspore as ms
from mindspore import nn
import torch
import numpy as np
# In MindSpore, dimension 0 is kept and all remaining dimensions are flattened into one.
input_tensor = ms.Tensor(np.ones(shape=[1, 2, 3, 4]), ms.float32)
flatten = nn.Flatten()
output = flatten(input_tensor)
print(output.shape)
# Out:
# (1, 24)
# In PyTorch, start_dim and end_dim select the range of dimensions to flatten; the other dimensions are kept.
# Unlike the function torch.flatten, torch.nn.Flatten is a module, so it can be used as a layer in torch.nn.Sequential.
input_tensor = torch.tensor(np.ones(shape=[1, 2, 3, 4]), dtype=torch.float32)
flatten1 = torch.nn.Sequential(torch.nn.Flatten(start_dim=1))
output1 = flatten1(input_tensor)
print(output1.shape)
# Out:
# torch.Size([1, 24])
input_tensor = torch.tensor(np.ones(shape=[1, 2, 3, 4]), dtype=torch.float32)
flatten2 = torch.nn.Sequential(torch.nn.Flatten(start_dim=2))
output2 = flatten2(input_tensor)
print(output2.shape)
# Out:
# torch.Size([1, 2, 12])
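
The shapes printed above follow one simple rule: the dimensions from start_dim through end_dim are merged into a single dimension whose size is their product, and every other dimension is left alone. The following is a minimal pure-Python sketch of that rule, not code from either library. The helper name flattened_shape is hypothetical, and it assumes the input has at least one dimension. Calling it with the defaults (start_dim=1, end_dim=-1) reproduces the behavior described for mindspore.nn.Flatten.

```python
from math import prod  # requires Python 3.8+


def flattened_shape(shape, start_dim=1, end_dim=-1):
    """Hypothetical helper: shape after merging dims start_dim..end_dim (inclusive)."""
    ndim = len(shape)
    # Map negative indices to positive ones, e.g. -1 -> ndim - 1.
    start = start_dim % ndim
    end = end_dim % ndim
    if start > end:
        raise ValueError("start_dim must not come after end_dim")
    # The merged dimension holds the product of the sizes it replaces.
    merged = prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


# Defaults keep dimension 0 and flatten the rest.
print(flattened_shape((1, 2, 3, 4)))               # (1, 24)
# start_dim=2 keeps dimensions 0 and 1.
print(flattened_shape((1, 2, 3, 4), start_dim=2))  # (1, 2, 12)
```

The sketch only computes shapes; it does not move any data. It can serve as a quick check of the expected output shape when porting a torch.nn.Flatten call that uses a non-default start_dim or end_dim.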