mindspore.nn.ChannelShuffle

class mindspore.nn.ChannelShuffle(groups)[源代码]

将shape的为 \((*, C, H, W)\) 的Tensor的通道划分成 \(g\) 组,并将其以 \((*, C \frac g, g, H, W)\) 的shape重新排列, 同时保持Tensor原有的shape。

参数:
  • groups (int) - 划分通道的组数。取值范围是 \((0, \inf)\) 。在上述公式中表示为 \(g\)

输入:
  • x (Tensor) - Tensor的shape \((*, C_{in}, H_{in}, W_{in})\)

输出:

Tensor,数据类型和shape与 x 相同。

异常:
  • TypeError - groups 非整数。

  • ValueError - groups 小于1。

  • ValueError - x 的维度小于3。

  • ValueError - Tensor的通道数不能被 groups 整除。

支持平台:

Ascend GPU CPU

样例:

>>> channel_shuffle = nn.ChannelShuffle(2)
>>> x = Tensor(np.arange(16).astype(np.int32).reshape(1, 4, 2, 2))
>>> print(x)
[[[[0 1],
   [2 3]],
  [[4 5],
   [6 7]],
  [[8 9],
   [10 11]],
  [[12 13],
   [14 15]],
 ]]
>>> output = channel_shuffle(x)
>>> print(output)
[[[[0 1],
   [2 3]],
  [[8 9],
   [10 11]],
  [[4 5],
   [6 7]],
  [[12 13],
   [14 15]],
 ]]