mindspore.parallel.nn.GradAccumulation

class mindspore.parallel.nn.GradAccumulation(network, micro_size)[source]

Implementation of parallel gradient accumulation for static graphs.

Parameters
  • network (Cell) – The target network to wrap.

  • micro_size (int) – MicroBatch size, i.e. the number of micro-batches over which gradients are accumulated before each parameter update. Must be a positive integer.

Raises
  • TypeError – If the type of network is not Cell.

  • TypeError – If the type of micro_size is not int.

  • ValueError – If micro_size is less than or equal to 0.

Supported Platforms:

Ascend

Examples

>>> from mindspore.parallel.nn import GradAccumulation
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> net = GradAccumulation(net, 4)
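
The idea behind gradient accumulation can be illustrated outside MindSpore with a minimal NumPy sketch (this is a hypothetical illustration of the technique, not the GradAccumulation implementation): splitting a batch into micro-batches, summing their gradients, and applying one averaged update gives the same result as a single full-batch step.

```python
import numpy as np

def grad(w, x, y):
    # Gradient of the squared-error loss 0.5 * mean((w*x - y)**2) w.r.t. w.
    return np.mean((w * x - y) * x)

rng = np.random.default_rng(0)
x = rng.normal(size=8)
y = 3.0 * x
w, lr, micro_size = 0.0, 0.1, 4

# One full-batch gradient step.
w_full = w - lr * grad(w, x, y)

# Accumulate gradients over micro-batches, then apply a single update.
acc = 0.0
for xb, yb in zip(np.split(x, micro_size), np.split(y, micro_size)):
    acc += grad(w, xb, yb)
w_acc = w - lr * acc / micro_size

print(np.isclose(w_full, w_acc))  # the two updates agree
```

Because each micro-batch here has equal size, the average of the micro-batch gradients equals the full-batch gradient, so the accumulated update matches the full-batch update while only one micro-batch needs to be resident at a time.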