mindspore.nn.MicroBatchInterleaved

class mindspore.nn.MicroBatchInterleaved(network, interleave_num=2)

This cell splits the input along dimension 0 into interleave_num pieces and then runs the wrapped cell on each piece. Application scenario: when the network uses model parallelism in semi-automatic parallel mode, the second piece can execute its communication operators while the first piece is running its forward computation, so that communication and computation overlap and performance improves.
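The splitting step can be sketched in plain Python, without MindSpore. This is only an illustration under stated assumptions: `split_batch`, `toy_network`, and `interleaved_forward` are hypothetical names, the batch is assumed to divide evenly by interleave_num, and averaging the per-piece outputs is an assumption made for the sketch (the description above does not say how the outputs are combined). The sketch runs the pieces sequentially and does not model the overlap of communication and computation.

```python
def split_batch(batch, interleave_num):
    """Split a batch (a list of samples) along dimension 0 into contiguous pieces."""
    n = len(batch)
    if n % interleave_num != 0:
        # Simplification for this sketch: require an even split.
        raise ValueError("batch size must be divisible by interleave_num")
    size = n // interleave_num
    return [batch[i * size:(i + 1) * size] for i in range(interleave_num)]


def toy_network(samples):
    """Hypothetical stand-in for the wrapped cell: returns a single scalar."""
    return sum(x * x for x in samples) / len(samples)


def interleaved_forward(batch, interleave_num=2):
    """Run the toy network on each piece and combine the results."""
    pieces = split_batch(batch, interleave_num)
    outputs = [toy_network(piece) for piece in pieces]
    # Assumption: per-piece outputs are combined by averaging.
    return sum(outputs) / interleave_num


batch = [1.0, 2.0, 3.0, 4.0]
pieces = split_batch(batch, 2)
result = interleaved_forward(batch, 2)
print(pieces)
print(result)
```

For a mean-style output such as the one in `toy_network`, averaging over equally sized pieces gives the same value as evaluating the whole batch at once, which is why the sketch combines the outputs that way.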

Note

The output of the input network must be a single tensor.

Parameters
  • network (Cell) – The target network to wrap.

  • interleave_num (int, optional) – The number of pieces into which the batch is split. Default: 2.

Inputs:

tuple[Tensor]. The same as the inputs of network.

Outputs:

Tensor. The output of the wrapped network.

Supported Platforms:

Ascend GPU

Examples

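>>> from mindspore.nn import MicroBatchInterleaved
>>> # Net is a user-defined Cell whose output is a single Tensor.
>>> net = Net()
>>> net = MicroBatchInterleaved(net, 2)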