mindspore.parallel.nn.MicroBatchInterleaved
- class mindspore.parallel.nn.MicroBatchInterleaved(network, interleave_num=2)
Implements static-graph parallel multi-copy splitting: the input batch is split into several micro-batches so that computation and communication can run concurrently. Application scenario: when model parallelism is used in semi-automatic parallel mode, the communication operators of one micro-batch can execute while another micro-batch is running its forward computation. Overlapping communication with computation in this way improves performance.
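The benefit of overlapping can be illustrated with a deliberately idealized timing model. This sketch is not MindSpore code and does not reflect how the graph compiler actually schedules operators. It assumes that each micro-batch communicates only after its own computation finishes, that computations and communications each run serially on their own stream, and that splitting adds no overhead. The function name overlapped_time is hypothetical.

```python
def overlapped_time(compute: float, comm: float, n: int) -> float:
    """Idealized finish time when one batch is split into n equal micro-batches.

    Illustrative assumptions only: a micro-batch's communication starts after its
    own computation, computations run back to back on one stream, communications
    run back to back on another stream, and splitting is free.
    """
    c, m = compute / n, comm / n    # per-micro-batch compute and communication cost
    compute_done = 0.0              # time at which the latest computation finishes
    comm_done = 0.0                 # time at which the latest communication finishes
    for _ in range(n):
        compute_done += c           # next micro-batch computes right after the previous one
        # communication waits for its own computation and for the previous communication
        comm_done = max(comm_done, compute_done) + m
    return comm_done


# Total compute cost 8 and communication cost 4 (arbitrary time units).
for n in (1, 2, 4):
    print(n, overlapped_time(8.0, 4.0, n))
```

Under these assumptions the unsplit batch (n = 1) takes 12 units, while two micro-batches finish at 10 because the second slice computes while the first slice communicates. Real gains depend on the operators, the parallel strategy, and the hardware.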
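- Parameters:
network (Cell): The target network to wrap.
interleave_num (int, optional): The number of micro-batches into which the input batch is split. Default: 2.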
- Inputs:
tuple[Tensor]. The same as the inputs of the wrapped network.
- Outputs:
The outputs of the wrapped network for each of the interleave_num micro-batches, summed together. The output of the wrapped network must be a single Tensor.
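The input splitting and output accumulation can be sketched in plain Python. This is a conceptual model, not MindSpore's implementation: tensors are modeled as lists of per-sample values, every input is sliced along the batch (first) dimension with the same slice index, and the hypothetical helpers split_batch, interleaved_forward, and toy_loss stand in for the real operators. Rejecting batch sizes that are not divisible by interleave_num is a choice made by this sketch, not documented MindSpore behavior.

```python
from typing import Callable, List, Tuple

Batch = List[float]  # one "tensor", modeled as a list of per-sample values


def split_batch(inputs: Tuple[Batch, ...], interleave_num: int, i: int) -> Tuple[Batch, ...]:
    """Return micro-batch i: the same slice of the batch dimension from every input."""
    size = len(inputs[0])
    if size % interleave_num != 0:
        raise ValueError("batch size must be divisible by interleave_num in this sketch")
    step = size // interleave_num
    return tuple(x[i * step:(i + 1) * step] for x in inputs)


def interleaved_forward(network: Callable[..., float],
                        inputs: Tuple[Batch, ...],
                        interleave_num: int = 2) -> float:
    """Run the network on each micro-batch and sum the per-micro-batch outputs."""
    output = 0.0
    for i in range(interleave_num):
        output += network(*split_batch(inputs, interleave_num, i))
    return output


def toy_loss(x: Batch, y: Batch) -> float:
    """Hypothetical network: sum of squared errors over the samples it receives."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


x = [1.0, 2.0, 3.0, 4.0]
y = [0.0, 0.0, 0.0, 0.0]
print(interleaved_forward(toy_loss, (x, y), 2))  # 5.0 + 25.0 = 30.0
print(toy_loss(x, y))                            # 30.0 on the unsplit batch
```

Because the per-micro-batch outputs are summed, a network that returns a summed loss gives the same value as on the unsplit batch, whereas a network that returns a mean loss gives interleave_num times the full-batch mean when the micro-batches are equal in size.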
- Supported Platforms:
Ascend
Examples
>>> from mindspore.parallel.nn import MicroBatchInterleaved
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> # Split each input batch into 2 micro-batches
>>> net = MicroBatchInterleaved(net, 2)