mindspore.nn.ChannelShuffle
- class mindspore.nn.ChannelShuffle(groups)[源代码]
将shape为
的Tensor的通道划分成 组,并按如下方式重新排列 ,同时在最终输出中保持原始Tensor的shape。- 参数:
groups (int) - 划分通道的组数,必须大于0。在上述公式中表示为
。
- 输入:
x (Tensor) - Tensor的shape
。
- 输出:
Tensor,数据类型和shape与 x 相同。
- 异常:
TypeError - groups 非正整数。
ValueError - groups 小于1。
ValueError - x 的维度小于3。
ValueError - x 的通道数不能被 groups 整除。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> import mindspore as ms >>> import numpy as np >>> channel_shuffle = ms.nn.ChannelShuffle(2) >>> x = ms.Tensor(np.arange(16).astype(np.int32).reshape(1, 4, 2, 2)) >>> print(x) [[[[ 0 1] [ 2 3]] [[ 4 5] [ 6 7]] [[ 8 9] [10 11]] [[12 13] [14 15]]]] >>> output = channel_shuffle(x) >>> print(output) [[[[ 0 1] [ 2 3]] [[ 8 9] [10 11]] [[ 4 5] [ 6 7]] [[12 13] [14 15]]]]