mindspore.nn.ChannelShuffle
- class mindspore.nn.ChannelShuffle(groups)[源代码]
将shape为 \((*, C, H, W)\) 的Tensor的通道划分成 \(g\) 组,得到shape为 \((*, C \frac g, g, H, W)\) 的Tensor,并沿着 \(C\) 和 \(\frac{g}{}\), \(g\) 对应轴进行转置,将Tensor还原成原有的shape。
- 参数:
groups (int) - 划分通道的组数,必须大于0。在上述公式中表示为 \(g\) 。
- 输入:
x (Tensor) - Tensor的shape \((*, C_{in}, H_{in}, W_{in})\) 。
- 输出:
Tensor,数据类型和shape与 x 相同。
- 异常:
TypeError - groups 非正整数。
ValueError - groups 小于1。
ValueError - x 的维度小于3。
ValueError - Tensor的通道数不能被 groups 整除。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> channel_shuffle = nn.ChannelShuffle(2) >>> x = Tensor(np.arange(16).astype(np.int32).reshape(1, 4, 2, 2)) >>> print(x) [[[[ 0 1] [ 2 3]] [[ 4 5] [ 6 7]] [[ 8 9] [10 11]] [[12 13] [14 15]]]] >>> output = channel_shuffle(x) >>> print(output) [[[[ 0 1] [ 2 3]] [[ 8 9] [10 11]] [[ 4 5] [ 6 7]] [[12 13] [14 15]]]]