Function Differences with torch.nn.Flatten


torch.nn.Flatten

class torch.nn.Flatten(
    start_dim=1,
    end_dim=-1
)

For more information, see torch.nn.Flatten.
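The start_dim and end_dim parameters select an inclusive range of dimensions to merge. The minimal sketch below assumes a recent PyTorch release. It calls torch.nn.Flatten directly on a tensor of shape (2, 3, 4, 5), once with the defaults and once with end_dim=2. Each shape in the comments is the product of the merged dimensions.

```python
import torch

# A 4-D input, e.g. (batch, channels, height, width).
x = torch.ones(2, 3, 4, 5)

# torch.nn.Flatten is an nn.Module, so it can be called directly;
# wrapping it in torch.nn.Sequential is optional.
default_flatten = torch.nn.Flatten()                          # start_dim=1, end_dim=-1
partial_flatten = torch.nn.Flatten(start_dim=1, end_dim=2)    # merge dims 1..2 only

print(default_flatten(x).shape)  # torch.Size([2, 60]): 3 * 4 * 5 = 60
print(partial_flatten(x).shape)  # torch.Size([2, 12, 5]): 3 * 4 = 12
```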

mindspore.nn.Flatten

class mindspore.nn.Flatten()(input)

For more information, see mindspore.nn.Flatten.
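The behavior described for mindspore.nn.Flatten amounts to reshaping an input of shape (N, d1, d2, ...) to (N, d1 * d2 * ...). The NumPy sketch below mimics that shape rule, assuming row-major (C-order) element ordering as a plain reshape produces. batch_flatten is a hypothetical helper for illustration and is not a MindSpore API.

```python
import numpy as np

def batch_flatten(x: np.ndarray) -> np.ndarray:
    """Keep axis 0 and merge all remaining axes in row-major order.

    Hypothetical helper that mimics the output shape of mindspore.nn.Flatten
    as described here; it is not part of MindSpore.
    """
    if x.ndim < 1:
        raise ValueError("expected an input with at least one dimension")
    rest = int(np.prod(x.shape[1:]))  # product of an empty tuple is 1
    return x.reshape(x.shape[0], rest)

x = np.arange(24, dtype=np.float32).reshape(1, 2, 3, 4)
y = batch_flatten(x)
print(y.shape)  # (1, 24)
```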

Differences

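PyTorch: Flattens a contiguous range of dimensions, from start_dim to end_dim (inclusive), into a single dimension. With the defaults (start_dim=1, end_dim=-1), every dimension except the 0th is flattened. torch.nn.Flatten is an nn.Module, so it is commonly placed inside torch.nn.Sequential, but it can also be called directly.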

MindSpore: Only the 0th dimension is kept; all remaining dimensions are flattened into one. The range of dimensions to flatten cannot be specified.
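When porting PyTorch code that uses a start_dim other than 1, or an end_dim other than -1, one option is to compute the target shape and reshape explicitly. The sketch below does this in NumPy with a hypothetical flatten_range helper that follows the inclusive start_dim/end_dim convention of torch.nn.Flatten. The resulting target shape could then be passed to a reshape operation on the MindSpore side, for example mindspore.ops.reshape. That call is not shown here and should be checked against the MindSpore version in use.

```python
import numpy as np

def flatten_range(x: np.ndarray, start_dim: int = 1, end_dim: int = -1) -> np.ndarray:
    """Merge axes start_dim..end_dim (inclusive), like torch.nn.Flatten.

    Hypothetical helper for illustration; negative indices count from the end.
    """
    ndim = x.ndim
    start = start_dim + ndim if start_dim < 0 else start_dim
    end = end_dim + ndim if end_dim < 0 else end_dim
    if not 0 <= start <= end < ndim:
        raise ValueError(f"invalid range start_dim={start_dim}, end_dim={end_dim} for ndim={ndim}")
    merged = int(np.prod(x.shape[start:end + 1]))
    new_shape = x.shape[:start] + (merged,) + x.shape[end + 1:]
    return x.reshape(new_shape)

x = np.ones((1, 2, 3, 4), dtype=np.float32)
print(flatten_range(x).shape)               # (1, 24): same as the MindSpore behavior
print(flatten_range(x, start_dim=2).shape)  # (1, 2, 12)
```

With the defaults, the helper reproduces the fixed MindSpore behavior. Other ranges need the explicit reshape.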

Code Example

import mindspore as ms
from mindspore import nn
import torch
import numpy as np

# In MindSpore, only the 0th dimension is kept and the rest are flattened.
input_tensor = ms.Tensor(np.ones(shape=[1, 2, 3, 4]), ms.float32)
flatten = nn.Flatten()
output = flatten(input_tensor)
print(output.shape)
# Out:
# (1, 24)

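# In PyTorch, start_dim (and end_dim) select which dimensions are flattened.
# Unlike the function torch.flatten (whose start_dim defaults to 0), torch.nn.Flatten is a module
# whose start_dim defaults to 1; it can be placed in torch.nn.Sequential, as below, or called directly.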
input_tensor = torch.Tensor(np.ones(shape=[1, 2, 3, 4]))
flatten1 = torch.nn.Sequential(torch.nn.Flatten(start_dim=1))
output1 = flatten1(input_tensor)
print(output1.shape)
# Out:
# torch.Size([1, 24])

input_tensor = torch.Tensor(np.ones(shape=[1, 2, 3, 4]))
flatten2 = torch.nn.Sequential(torch.nn.Flatten(start_dim=2))
output2 = flatten2(input_tensor)
print(output2.shape)
# Out:
# torch.Size([1, 2, 12])