mindspore.nn.PipelineCell
- class mindspore.nn.PipelineCell(network, micro_size)[源代码]
将MiniBatch切分成更细粒度的MicroBatch,用于流水线并行的训练中。
说明
micro_size必须大于或等于流水线stage的个数。
- 参数:
network (Cell) - 要修饰的目标网络。
micro_size (int) - MicroBatch大小。
- 支持平台:
Ascend
GPU
样例:
>>> import mindspore.nn as nn >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/r2.4.1/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> net = nn.PipelineCell(net, 4)