mindspore.parallel.nn.MicroBatchInterleaved

class mindspore.parallel.nn.MicroBatchInterleaved(network, interleave_num=2)[source]

Implements static graph multi-copy splitting so that computation and communication can run concurrently. Application scenario: in semi-automatic parallel mode, when the network contains model parallelism, the communication operators for one data slice can execute while the forward computation of another slice is still running, accelerating the network by overlapping communication and computation.
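Conceptually, the wrapper splits each input batch along axis 0 into interleave_num micro-batches, runs the wrapped network on each slice, and accumulates the per-slice results. The following NumPy sketch illustrates only this splitting and accumulation; it is an assumption-level illustration, not the actual MindSpore implementation, which works on static-graph tensors and additionally overlaps the communication of one slice with the computation of another.

```python
import numpy as np

def micro_batch_interleaved(network, inputs, interleave_num=2):
    """Conceptual sketch: split the batch into interleave_num slices,
    run the network on each slice, and sum the per-slice outputs."""
    batch = inputs.shape[0]
    assert batch % interleave_num == 0, "batch size must divide evenly"
    step = batch // interleave_num
    outputs = [network(inputs[i * step:(i + 1) * step])
               for i in range(interleave_num)]
    return sum(outputs)

# Toy "network" whose output is additive over the batch dimension,
# so the interleaved result matches the full-batch result.
net = lambda x: x.sum(axis=0)
data = np.arange(8.0).reshape(8, 1)
print(micro_batch_interleaved(net, data, 2))  # same as net(data)
```

For a network whose output is additive over the batch (such as a summed loss), the accumulated result equals the full-batch result, which is why the splitting is transparent to the caller.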

Parameters
  • network (Cell) – The target network to wrap.

  • interleave_num (int, optional) – The number of slices to split the batch size into. Default: 2 .

Inputs:

tuple[Tensor]. The same as the inputs of the wrapped network .

Outputs:

The output of the wrapped network. The output of the input network must be a single Tensor.

Supported Platforms:

Ascend

Examples

>>> from mindspore.parallel.nn import MicroBatchInterleaved
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> net = MicroBatchInterleaved(net, 2)