mindspore.nn.ChannelShuffle
- class mindspore.nn.ChannelShuffle(groups)[源代码]
将shape的为 \((*, C, H, W)\) 的Tensor的通道划分成 \(g\) 组,并将其以 \((*, C \frac g, g, H, W)\) 的shape重新排列, 同时保持Tensor原有的shape。
- 参数:
groups (int) - 划分通道的组数。取值范围是 \((0, \inf)\) 。在上述公式中表示为 \(g\) 。
- 输入:
x (Tensor) - Tensor的shape \((*, C_{in}, H_{in}, W_{in})\) 。
- 输出:
Tensor,数据类型和shape与 x 相同。
- 异常:
TypeError - groups 非整数。
ValueError - groups 小于1。
ValueError - x 的维度小于3。
ValueError - Tensor的通道数不能被 groups 整除。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> channel_shuffle = nn.ChannelShuffle(2) >>> x = Tensor(np.arange(16).astype(np.int32).reshape(1, 4, 2, 2)) >>> print(x) [[[[0 1], [2 3]], [[4 5], [6 7]], [[8 9], [10 11]], [[12 13], [14 15]], ]] >>> output = channel_shuffle(x) >>> print(output) [[[[0 1], [2 3]], [[8 9], [10 11]], [[4 5], [6 7]], [[12 13], [14 15]], ]]