mindspore.ops.flatten

mindspore.ops.flatten(input_x)

Flattens a tensor into two dimensions, keeping the batch size on the 0-th axis unchanged.

Parameters

input_x (Tensor) – Tensor of shape \((N, \ldots)\) to be flattened, where \(N\) is the batch size.

Returns

Tensor, the shape of the output tensor is \((N, X)\), where \(X\) is the product of the remaining dimensions.
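As a worked illustration of the return shape, the short sketch that follows computes \((N, X)\) from an input shape in plain Python. `flattened_shape` is a hypothetical helper, not a MindSpore API. Its result for a 1-D input of shape (N,) follows the empty-product convention (X = 1) and is an assumption, not documented MindSpore behavior.

```python
import math


def flattened_shape(shape):
    """Return the (N, X) output shape for an input shape (N, ...).

    Hypothetical helper for illustration only; not part of MindSpore.
    """
    if len(shape) < 1:
        raise ValueError("input must have at least one dimension")
    n = shape[0]
    # X is the product of every dimension after the batch axis.
    # math.prod(()) == 1, so a 1-D shape (N,) maps to (N, 1) in this sketch;
    # this edge case is an assumption, not taken from the MindSpore docs.
    x = math.prod(shape[1:])
    return (n, x)


print(flattened_shape((1, 2, 3, 4)))    # (1, 24), since 2 * 3 * 4 = 24
print(flattened_shape((8, 3, 32, 32)))  # (8, 3072), since 3 * 32 * 32 = 3072
```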

Raises
  • TypeError – If input_x is not a Tensor.

  • ValueError – If the length of the shape of input_x (its rank) is less than 1.
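The shape contract and the Raises list above can be mirrored in a small NumPy reference, which may help when checking expected shapes without a MindSpore installation. `flatten_reference` is hypothetical: it uses `np.ndarray` as a stand-in for Tensor and assumes NumPy's default row-major (C-order) element layout. It is a sketch of the documented behavior, not MindSpore's implementation.

```python
import math

import numpy as np


def flatten_reference(x):
    """NumPy sketch of the documented behavior of mindspore.ops.flatten.

    Hypothetical reference for illustration only; it mirrors the documented
    shape contract and Raises list, not MindSpore's actual implementation.
    """
    if not isinstance(x, np.ndarray):
        # Mirrors the documented TypeError for non-tensor input.
        raise TypeError(f"expected an array, got {type(x).__name__}")
    if x.ndim < 1:
        # Mirrors the documented ValueError for a shape of length < 1.
        raise ValueError("input must have at least one dimension")
    # Keep axis 0 and collapse all remaining axes (row-major order).
    return x.reshape(x.shape[0], math.prod(x.shape[1:]))


out = flatten_reference(np.ones((2, 3, 4), dtype=np.float32))
print(out.shape)  # (2, 12)

# A Python list and a 0-d array trigger the two documented error cases.
for bad in ([1.0, 2.0], np.array(1.0)):
    try:
        flatten_reference(bad)
    except (TypeError, ValueError) as err:
        print(type(err).__name__, "-", err)
```

Computing the trailing size explicitly with `math.prod`, rather than passing -1 to `reshape`, keeps a zero-sized batch such as shape (0, 3, 4) working, because NumPy cannot infer a -1 dimension when the array has no elements.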

Supported Platforms:

Ascend GPU CPU

Examples

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, ops
>>> input_x = Tensor(np.ones(shape=[1, 2, 3, 4]), mindspore.float32)
>>> output = ops.flatten(input_x)
>>> print(output.shape)
(1, 24)