mindspore.nn.MicroBatchInterleaved

查看源文件
class mindspore.nn.MicroBatchInterleaved(network, interleave_num=2)[源代码]

这个函数的作用是将输入在第零维度拆成 interleave_num 份,然后执行包裹的cell的计算。 使用场景:当在半自动模式以及网络中存在模型并行时,第1份的切片数据的前向计算同时,第2份的数据将会进行模型并行的通信,以此来达到通信计算并发的性能加速。

参数:
  • network (Cell) - 需要封装的网络。

  • interleave_num (int,可选) - batch size的拆分份数,默认值: 2

输入:

tuple[Tensor],与传入的 network 的输入一致。

输出:

被封装后的网络。传入的 network 的输出只能是单个Tensor。

支持平台:

Ascend GPU

样例:

>>> import mindspore.nn as nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> net = nn.MicroBatchInterleaved(net, 2)