Function Differences with torch.flatten
torch.flatten
torch.flatten(
input,
start_dim=0,
end_dim=-1
)
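The start_dim/end_dim rule can be illustrated without PyTorch installed. The following is a minimal NumPy sketch, not PyTorch's implementation. The helper name flatten_like_torch is hypothetical, and the sketch reproduces only the output shape, not PyTorch's view/copy semantics.

```python
import numpy as np


def flatten_like_torch(x: np.ndarray, start_dim: int = 0, end_dim: int = -1) -> np.ndarray:
    """Hypothetical NumPy helper mimicking the shape rule of torch.flatten.

    Assumes x has at least one dimension; 0-d inputs are not handled here.
    """
    ndim = x.ndim
    if ndim == 0:
        raise ValueError("this sketch expects an input with at least one dimension")
    for d in (start_dim, end_dim):
        if not -ndim <= d < ndim:
            raise IndexError(f"dimension {d} out of range for a {ndim}-D input")
    # Map negative indices (e.g. end_dim=-1) to their positive equivalents.
    start, end = start_dim % ndim, end_dim % ndim
    if start > end:
        raise ValueError("start_dim must not come after end_dim")
    # Dimensions start..end (inclusive) are merged into one; all others are kept.
    merged = int(np.prod(x.shape[start:end + 1]))
    return x.reshape(x.shape[:start] + (merged,) + x.shape[end + 1:])


x = np.ones((1, 2, 3, 4))
print(flatten_like_torch(x).shape)                          # (24,)
print(flatten_like_torch(x, start_dim=1).shape)             # (1, 24)
print(flatten_like_torch(x, start_dim=2).shape)             # (1, 2, 12)
print(flatten_like_torch(x, start_dim=1, end_dim=2).shape)  # (1, 6, 4)
```

With the defaults, every dimension is merged and the result is 1-D. Narrowing the range keeps the dimensions outside it unchanged.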
For more information, see torch.flatten.
mindspore.ops.Flatten
class mindspore.ops.Flatten(*args, **kwargs)(input_x)
For more information, see mindspore.ops.Flatten.
Differences
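PyTorch: Flattens a contiguous range of dimensions, from start_dim to end_dim (inclusive), into a single dimension. With the default arguments (start_dim=0, end_dim=-1), all dimensions are flattened and the result is a 1-D tensor.
MindSpore: Keeps the 0th dimension and flattens all remaining dimensions into one, so an input of shape (N, ...) produces an output of shape (N, X), where X is the product of the remaining dimensions. For inputs with at least two dimensions, this matches torch.flatten(input, start_dim=1).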
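To make the correspondence concrete, here is a minimal NumPy sketch of the shape rule of mindspore.ops.Flatten next to the start_dim=1 case of torch.flatten. It assumes an input of rank 2 or higher, and the helper name flatten_like_ms_ops is hypothetical; it is not MindSpore's implementation.

```python
import numpy as np


def flatten_like_ms_ops(x: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy helper mimicking the shape rule of mindspore.ops.Flatten.

    Assumes x has rank >= 2.
    """
    # Keep axis 0 and let reshape infer the size of the single merged axis.
    return x.reshape(x.shape[0], -1)


x = np.ones((1, 2, 3, 4))
ms_style = flatten_like_ms_ops(x)
# torch.flatten(x, start_dim=1) keeps dim 0 and merges dims 1..-1: the same rule.
torch_style = x.reshape(x.shape[:1] + (int(np.prod(x.shape[1:])),))
print(ms_style.shape, torch_style.shape)  # (1, 24) (1, 24)
```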
Code Example
import mindspore as ms
import mindspore.ops as ops
import torch
import numpy as np
# In MindSpore, only the 0th dimension is kept and all remaining dimensions are flattened into one.
input_tensor = ms.Tensor(np.ones(shape=[1, 2, 3, 4]), ms.float32)
flatten = ops.Flatten()
output = flatten(input_tensor)
print(output.shape)
# Out:
# (1, 24)
# In PyTorch, start_dim specifies the first dimension to flatten; the dimensions before it are kept.
input_tensor = torch.Tensor(np.ones(shape=[1, 2, 3, 4]))
output1 = torch.flatten(input=input_tensor, start_dim=1)
print(output1.shape)
# Out:
# torch.Size([1, 24])
input_tensor = torch.Tensor(np.ones(shape=[1, 2, 3, 4]))
output2 = torch.flatten(input=input_tensor, start_dim=2)
print(output2.shape)
# Out:
# torch.Size([1, 2, 12])
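Because ops.Flatten always keeps exactly one leading dimension, reproducing a result such as output2 (start_dim=2) on the MindSpore side would need a reshape to the target shape instead. The sketch below computes that target shape in NumPy; the helper name keep_leading_dims is hypothetical, it assumes 0 < k < x.ndim, and the same shape tuple could then be passed to a reshape on a MindSpore Tensor (check the reshape API of your MindSpore version).

```python
import numpy as np


def keep_leading_dims(x: np.ndarray, k: int) -> np.ndarray:
    """Hypothetical helper: keep the first k axes and merge the rest into one.

    Matches the output shape of torch.flatten(x, start_dim=k) for 0 < k < x.ndim.
    """
    # -1 lets reshape infer the size of the merged trailing axis.
    return x.reshape(x.shape[:k] + (-1,))


x = np.ones((1, 2, 3, 4))
print(keep_leading_dims(x, 1).shape)  # (1, 24), same as ops.Flatten
print(keep_leading_dims(x, 2).shape)  # (1, 2, 12), same as output2 above
```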