mindspore.parallel.nn.GradAccumulation
- class mindspore.parallel.nn.GradAccumulation(net, micro_size)[源代码]
使能GradAccumulation实现梯度累加。
- 参数:
net (Cell) - 将进行梯度累加的网络。
micro_size (int) - MicroBatchSize。
- 异常:
TypeError - net 不是cell类型输入。
TypeError - micro_size 不是整数类型。
ValueError - micro_size 值异常,为0或者负数。
- 支持平台:
Ascend
样例:
>>> from mindspore.parallel.nn import GradAccumulation >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> net = GradAccumulation(net, 4)