mindspore.nn.MicroBatchInterleaved
- class mindspore.nn.MicroBatchInterleaved(network, interleave_num=2)[source]
This cell splits the input along the 0th (batch) dimension into interleave_num pieces and then performs the computation of the wrapped cell on each piece. Application scenario: in semi-automatic parallel mode with model parallelism in the network, while the first slice of data is running its forward computation, the second slice can execute communication operators at the same time, overlapping communication with computation for performance acceleration.
- Parameters
  - network (Cell) – The target network to wrap.
  - interleave_num (int) – The number of pieces the input batch is split into. Default: 2.
- Inputs:
tuple[Tensor]. The same as the inputs of the wrapped network.
- Outputs:
The output of the wrapped network. The output of the input network should be a single Tensor.
- Supported Platforms:
Ascend
GPU
Examples
>>> import mindspore.nn as nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> net = nn.MicroBatchInterleaved(net, 2)
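The 0th-dimension split described above can be illustrated without MindSpore. The NumPy sketch below only models how a batch is divided into `interleave_num` micro-batches; the function name `micro_batch_split` is hypothetical, and the real cell additionally interleaves the communication and computation of the slices, which this sketch does not attempt to reproduce.

```python
import numpy as np

def micro_batch_split(batch, interleave_num=2):
    """Split a batch along axis 0 into interleave_num equal micro-batches.

    Hypothetical helper mirroring the 0th-dimension split described above;
    the batch size must be divisible by interleave_num.
    """
    if batch.shape[0] % interleave_num != 0:
        raise ValueError("batch size must be divisible by interleave_num")
    return np.split(batch, interleave_num, axis=0)

# A batch of 8 samples, each with 3 features, split into 2 micro-batches of 4.
batch = np.arange(8 * 3).reshape(8, 3)
micro_batches = micro_batch_split(batch, 2)
print([mb.shape for mb in micro_batches])  # → [(4, 3), (4, 3)]
```

Each micro-batch would then be fed through the wrapped cell in turn, so that one slice's forward computation can overlap with another slice's communication.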