mindspore.nn.PipelineCell

查看源文件
class mindspore.nn.PipelineCell(network, micro_size)[源代码]

将MiniBatch切分成更细粒度的MicroBatch,用于流水线并行的训练中。

说明

micro_size必须大于或等于流水线stage的个数。

参数:
  • network (Cell) - 要修饰的目标网络。

  • micro_size (int) - MicroBatch大小。

支持平台:

Ascend GPU

样例:

>>> import mindspore.nn as nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.3.q1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> net = nn.PipelineCell(net, 4)