mindspore.nn.Unflatten
- class mindspore.nn.Unflatten(axis, unflattened_size)
Unflattens dimension axis of a Tensor into multiple dimensions whose sizes are given by unflattened_size.
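The shape rule can be sketched in plain Python. The helper unflatten_shape below is hypothetical and not part of MindSpore. It assumes negative axis values count from the end (Python-style), and it raises ValueError on bad input, which may differ from the exception types MindSpore itself raises.

```python
from math import prod


def unflatten_shape(shape, axis, unflattened_size):
    """Return the output shape of unflattening `shape` at `axis`.

    Hypothetical helper for illustration only; not part of MindSpore.
    """
    shape = tuple(shape)
    sizes = tuple(unflattened_size)
    # Assumption of this sketch: negative axes count from the end.
    ax = axis + len(shape) if axis < 0 else axis
    if not 0 <= ax < len(shape):
        raise ValueError(f"axis {axis} is out of range for shape {shape}")
    # The new sizes must account for every element of the original dimension.
    if prod(sizes) != shape[ax]:
        raise ValueError(
            f"product of {sizes} is {prod(sizes)}, expected {shape[ax]}"
        )
    # Dimensions before and after the unflattened axis are unchanged.
    return shape[:ax] + sizes + shape[ax + 1:]


print(unflatten_shape((2, 10, 5), 1, (2, 5)))  # (2, 2, 5, 5)
```

The product check is what makes the operation a pure regrouping: no elements are added or dropped, so (2, 10, 5) can become (2, 2, 5, 5) but not (2, 3, 3, 5).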
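- Parameters
axis (int) - The dimension of the input Tensor to be unflattened.
unflattened_size (Union[tuple[int], list[int]]) - The new shape of the unflattened dimension. The product of its elements must equal input.shape[axis].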
- Inputs:
input (Tensor) - The input Tensor to be unflattened.
- Outputs:
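The unflattened Tensor. Dimension axis of input is replaced by the dimensions in unflattened_size; all other dimensions are unchanged.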
- Raises
- Supported Platforms:
Ascend
GPU
CPU
Examples
>>> import mindspore
>>> from mindspore import Tensor, nn
>>> import numpy as np
>>> input = Tensor(np.arange(0, 100).reshape(2, 10, 5), mindspore.float32)
>>> net = nn.Unflatten(1, (2, 5))
>>> output = net(input)
>>> print(f"before unflatten the input shape is {input.shape}")
before unflatten the input shape is (2, 10, 5)
>>> print(f"after unflatten the output shape is {output.shape}")
after unflatten the output shape is (2, 2, 5, 5)
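In the example, axis 1 (size 10) is split into (2, 5). Assuming Unflatten groups elements in row-major (C) order, as a reshape does, the element input[b, i, c] lands at output[b, j, k, c] with (j, k) = unravel(i, (2, 5)). The unravel helper below is hypothetical, written only to illustrate that index mapping; it is not a MindSpore function.

```python
def unravel(index, sizes):
    """Split a flat index along the unflattened axis into one index per new dim.

    Hypothetical helper; assumes row-major (C-order) grouping, like a reshape.
    """
    coords = []
    # Walk the new dimensions from last to first, peeling off one coordinate each time.
    for size in reversed(sizes):
        coords.append(index % size)
        index //= size
    return tuple(reversed(coords))


# Indices 0..4 along axis 1 map to j = 0, indices 5..9 map to j = 1.
for i in range(10):
    print(i, "->", unravel(i, (2, 5)))
```

Under this assumption, consecutive runs of 5 elements along axis 1 become the rows of the new (2, 5) block, so the values themselves are unchanged and only their indexing is regrouped.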