mindspore.nn.PipelineCell

class mindspore.nn.PipelineCell(network, micro_size)[源代码]

将MiniBatch切分成更细粒度的MicroBatch,用于流水线并行的训练中。

Note

micro_size必须大于或等于流水线stage的个数。

参数:

  • network (Cell) - 要修饰的目标网络。

  • micro_size (int) - MicroBatch大小。

支持平台:

Ascend GPU

样例:

>>> net = Net()
>>> net = PipelineCell(net, 4)