mindspore.nn.GradAccumulationCell

class mindspore.nn.GradAccumulationCell(network, micro_size)[源代码]

将MiniBatch切分成更细粒度的MicroBatch,用于半自动/全自动并行模式下的梯度累加训练中。

参数:
  • network (Cell) - 要修饰的目标网络。

  • micro_size (int) - MicroBatch大小。

支持平台:

Ascend GPU

样例:

>>> import mindspore.nn as nn
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/r2.4.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> net = nn.GradAccumulationCell(net, 4)