
class mindspore.nn.Flatten

Flattens all dimensions of the input Tensor except the 0th dimension, collapsing them into a single dimension.

Inputs:

  • x (Tensor) - The input Tensor to be flattened. The data type must be numeric. The shape is \((N, *)\), where \(*\) means any number of additional dimensions; the shape cannot be \(()\), i.e. x must not be a scalar (0-D) Tensor.


Outputs:

Tensor of shape \((N, X)\), where \(X\) is the product of all dimensions after the 0th one.
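The shape rule can be sketched in plain Python to make the arithmetic concrete: keep \(N\) and multiply the remaining dimensions together to get \(X\). This is a minimal illustration, not MindSpore code; the helper name `flattened_shape` is hypothetical, and raising `ValueError` for a `()` shape is an assumption of this sketch rather than MindSpore's documented error.

```python
from math import prod


def flattened_shape(shape):
    """Hypothetical helper: return the (N, X) output shape for an input of shape (N, *)."""
    shape = tuple(shape)
    if len(shape) == 0:
        # The docs say the input shape can't be (); ValueError is this
        # sketch's assumption, not necessarily what MindSpore raises.
        raise ValueError("input must have at least one dimension")
    n = shape[0]
    # X is the product of every dimension after the 0th one.
    # prod(()) == 1, so this sketch maps a 1-D shape (N,) to (N, 1).
    x = prod(shape[1:])
    return (n, x)


print(flattened_shape((2, 2, 2)))       # (2, 4)
print(flattened_shape((8, 3, 32, 32)))  # (8, 3072)
```

Only the shapes are computed here; no data is moved, which mirrors the fact that flattening only changes how the same elements are indexed.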


Raises:

TypeError – If x is not a Tensor.

Supported Platforms:

Ascend GPU CPU


Examples:

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, nn
>>> x = Tensor(np.array([[[1.2, 1.2], [2.1, 2.1]], [[2.2, 2.2], [3.2, 3.2]]]), mindspore.float32)
>>> net = nn.Flatten()
>>> output = net(x)
>>> print(output)
[[1.2 1.2 2.1 2.1]
 [2.2 2.2 3.2 3.2]]
>>> print(f"before flatten the x shape is {x.shape}")
before flatten the x shape is (2, 2, 2)
>>> print(f"after flatten the output shape is {output.shape}")
after flatten the output shape is (2, 4)
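The result of the example above can be cross-checked without MindSpore: for this input, flattening gives the same values and shape as a NumPy `reshape` that keeps axis 0 and lets the remaining size be inferred with `-1`. This is a comparison sketch using NumPy, not MindSpore's implementation of `nn.Flatten`.

```python
import numpy as np

# Same data as the MindSpore example, built directly as a NumPy float32 array.
x = np.array([[[1.2, 1.2], [2.1, 2.1]], [[2.2, 2.2], [3.2, 3.2]]], dtype=np.float32)

# Keep axis 0 (N) and let reshape infer X, the product of the remaining axes.
output = x.reshape(x.shape[0], -1)

print(x.shape)       # (2, 2, 2)
print(output.shape)  # (2, 4)
print(output)
```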