mindspore.nn.Unflatten

class mindspore.nn.Unflatten(axis, unflattened_size)
Summary:

Unflattens one dimension of the input Tensor, selected by axis, into multiple dimensions whose sizes are given by unflattened_size.
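The shape bookkeeping can be sketched in plain Python. The dimension at `axis` is replaced by the entries of `unflattened_size`, and all other dimensions keep their order. The helper `unflatten_shape` below is hypothetical. It does not call MindSpore, assumes a non-negative `axis` and already-valid arguments, and only models the resulting shape.

```python
def unflatten_shape(input_shape, axis, unflattened_size):
    """Hypothetical helper: shape after unflattening `axis` (non-negative axis, valid args assumed)."""
    # Keep the dims before `axis`, splice in the new dims, then keep the dims after `axis`.
    return tuple(input_shape[:axis]) + tuple(unflattened_size) + tuple(input_shape[axis + 1:])


print(unflatten_shape((2, 10, 5), 1, (2, 5)))  # (2, 2, 5, 5)
print(unflatten_shape((2, 10, 5), 1, [5, 2]))  # (2, 5, 2, 5)
```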

Parameters
  • axis (int) – Specifies the dimension of the input Tensor to be unflattened.

  • unflattened_size (Union(tuple[int], list[int])) – The new shape of the unflattened dimension, given as a tuple or a list of ints. The product of unflattened_size must equal input_shape[axis].

Inputs:
  • input (Tensor) - The input Tensor to be unflattened.

Outputs:

The unflattened Tensor. Its shape is input_shape[:axis], followed by unflattened_size, followed by input_shape[axis + 1:].
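Assuming the layer behaves like a row-major (C-order) reshape, as NumPy's `reshape` does by default, each element keeps its position in the flattened order. The pure-Python sketch below uses a hypothetical helper, `output_index`, to show where an input element lands: its index along `axis` is split into sub-indices over `unflattened_size` with `divmod`, the innermost new dimension varying fastest. It assumes a non-negative `axis` and a valid `unflattened_size`.

```python
def output_index(index, axis, unflattened_size):
    """Hypothetical helper: map an input multi-index to its position after unflattening."""
    flat = index[axis]
    sub = []
    # Peel off sub-indices starting from the innermost (fastest-varying) new dimension.
    for size in reversed(unflattened_size):
        flat, remainder = divmod(flat, size)
        sub.append(remainder)
    sub.reverse()
    # Indices on the other axes are unchanged.
    return tuple(index[:axis]) + tuple(sub) + tuple(index[axis + 1:])


# Input shape (2, 10, 5), axis=1, unflattened_size=(2, 5): index 7 on axis 1 becomes (1, 2).
print(output_index((1, 7, 3), 1, (2, 5)))  # (1, 1, 2, 3)
```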

Raises
  • TypeError – If axis is not int.

  • TypeError – If unflattened_size is neither a tuple of ints nor a list of ints.

  • TypeError – If the product of unflattened_size does not equal input_shape[axis].
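The three conditions above can be mirrored in plain Python. The validator `check_unflatten_args` below is hypothetical: it is not MindSpore's internal check, and its error messages, as well as the point at which MindSpore actually performs each check, are assumptions. It raises `TypeError` for each documented condition, in the order they are listed.

```python
from math import prod


def check_unflatten_args(input_shape, axis, unflattened_size):
    """Hypothetical validator mirroring the documented Raises conditions."""
    if not isinstance(axis, int):
        raise TypeError(f"axis must be int, got {type(axis).__name__}")
    if not isinstance(unflattened_size, (tuple, list)) or not all(
        isinstance(s, int) for s in unflattened_size
    ):
        raise TypeError("unflattened_size must be a tuple of ints or a list of ints")
    # The new dimensions must hold exactly as many elements as the original one.
    if prod(unflattened_size) != input_shape[axis]:
        raise TypeError(
            f"product of unflattened_size ({prod(unflattened_size)}) "
            f"does not equal input_shape[axis] ({input_shape[axis]})"
        )


check_unflatten_args((2, 10, 5), 1, (2, 5))  # 2 * 5 == 10, so no error is raised
```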

Supported Platforms:

Ascend GPU CPU

Examples

>>> import mindspore
>>> from mindspore import Tensor, nn
>>> import numpy as np
>>> input = Tensor(np.arange(0, 100).reshape(2, 10, 5), mindspore.float32)
>>> net = nn.Unflatten(1, (2, 5))
>>> output = net(input)
>>> print(f"before unflatten the input shape is {input.shape}")
before unflatten the input shape is (2, 10, 5)
>>> print(f"after unflatten the output shape is {output.shape}")
after unflatten the output shape is (2, 2, 5, 5)