Gradient Accumulation Algorithm


Overview

This tutorial introduces gradient accumulation, a training technique that addresses the OOM (Out of Memory) problem that occurs when memory is insufficient to train with a large Batch size or to load a large network model.

Gradient Accumulation Principle

Gradient accumulation is a way of training a neural network in which the data samples are split by Batch size into several small Mini-Batches that are then computed sequentially.

Before discussing gradient accumulation further, let us review the computation process of a neural network.

Deep learning models are made up of many interconnected neural network units, and sample data propagates forward through all of the network layers. After passing through all the layers, the network outputs the predicted values of the samples, and the loss function then computes the loss value (error) for each sample. Through backpropagation, the neural network computes the gradient of the loss value with respect to the model parameters. Finally, this gradient information is used to update the parameters of the network model.

The optimizer applies a mathematical rule to update the weight parameters of the network model. Take the simple stochastic gradient descent (SGD) algorithm as an example.

Assume the loss function is:

\[Loss(\theta)=\frac{1}{2}\left(h(x^{k})-y^{k}\right)^{2}\]

When building a model, the optimizer is used to minimize the loss. Using the SGD algorithm, the weight parameters are updated from the loss as follows:

\[\theta_{i}=\theta_{i-1}-lr * grad_{i}\]

where \(\theta\) is a trainable parameter (weight or bias) of the network model, \(lr\) is the learning rate, and \(grad_{i}\) is the gradient of the loss with respect to the network model parameter.
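For intuition, here is a tiny numeric sketch of this update (the values are made up purely for illustration):

lr = 0.01                                          # learning rate
theta = [0.5, -1.2]                                # trainable parameters theta_{i-1}
grad = [0.2, -0.4]                                 # gradients of the loss w.r.t. theta
theta = [t - lr * g for t, g in zip(theta, grad)]  # theta_i -> [0.498, -1.196]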

Gradient accumulation computes the forward and backward passes of the neural network model but does not update its parameters immediately; instead, the gradients obtained at each step are accumulated, and the accumulated gradient is finally used to update the parameters.

\[accumulated=\sum_{i=0}^{N} grad_{i}\]

While the model variables are not updated, the data of the original Batch size is effectively divided into several Mini-Batches, so each step uses a smaller set of samples.

The variables are not updated for N steps, so all Mini-Batches compute their gradients with the same model variables. This ensures that the same gradient and weight information is obtained as when using the original, unsplit Batch size.

\[\theta_{i}=\theta_{i-1}-lr * \sum_{i=0}^{N} grad_{i}\]

Eventually, accumulating the gradients of the previous steps yields the same sum of gradients as using the global Batch size.
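As a concrete illustration, the following minimal sketch accumulates the gradients of N Mini-Batches before applying a single update. It uses a toy one-parameter least-squares model with made-up data and is not the tutorial's MindSpore implementation:

import numpy as np

np.random.seed(0)
x, y = np.random.randn(64), np.random.randn(64)  # a toy "global batch" of 64 samples
theta, lr, N = 0.0, 0.01, 4                      # parameter, learning rate, accumulation steps

accumulated = 0.0
for i, (xb, yb) in enumerate(zip(np.split(x, N), np.split(y, N)), start=1):
    grad = np.sum(2 * (theta * xb - yb) * xb)    # gradient of the summed squared error on this Mini-Batch
    accumulated += grad                          # accumulate instead of updating immediately
    if i % N == 0:                               # update only after all N Mini-Batches
        theta -= lr * accumulated
        accumulated = 0.0                        # reset the accumulation buffer

Summing the Mini-Batch gradients here gives exactly the gradient that would be computed on the full 64-sample batch, which is the equivalence described above.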

In actual projects, there are two points that need attention regarding parameter tuning and the algorithm:

  1. Learning rate: Under certain conditions, the larger the Batch size, the better the training effect. Gradient accumulation simulates the effect of a larger Batch size: if the number of accumulation steps is 4, the Batch size is effectively increased by 4 times. Empirically, the learning rate needs to be scaled up appropriately when using gradient accumulation (see the sketch after this list).

  2. Batch Norm: When the number of accumulation steps is 4, the 4-times-larger Batch size is only simulated. Compared with a real 4x Batch size, the data distribution is not exactly the same: the mean and variance computed by Batch Norm on each small Mini-Batch differ from those of the full simulated batch, so some implementations use Group Norm instead of Batch Norm.
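For example, a rough sketch of both adjustments (the learning rate scaling rule, group count, and channel sizes below are assumptions for illustration, not values prescribed by this tutorial):

from mindspore import nn

accumulate_step = 4
base_lr = 1e-2
# 1. Scale the learning rate roughly with the simulated Batch size increase,
#    then fine-tune it empirically.
lr = base_lr * accumulate_step

# 2. Use a batch-size-independent normalization, e.g. GroupNorm instead of
#    BatchNorm, since BatchNorm statistics are computed on the small Mini-Batch.
norm = nn.GroupNorm(num_groups=8, num_channels=64)  # instead of nn.BatchNorm2d(64)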

Gradient Accumulation Implementation

Based on the MindSpore functional auto-differentiation mechanism, the gradient function returns the gradients corresponding to the trainable parameters after the forward and backward passes are completed. Therefore, we design a gradient accumulation class Accumulator to store the gradient values produced at each step. Below is a sample implementation of Accumulator, which maintains two internal attributes with the same shape as the model's trainable parameters, namely inner_grads and zeros. inner_grads stores the accumulated gradient values, while zeros is used to clear inner_grads after the optimizer update. Accumulator also maintains an internal counter that is incremented after each forward and backward pass; whether the number of accumulation steps has been reached is determined by taking the counter modulo accumulate_step.

import mindspore as ms
from mindspore import Tensor, Parameter, ops

@ms.jit_class
class Accumulator():
    def __init__(self, optimizer, accumulate_step, clip_norm=1.0):
        self.optimizer = optimizer
        self.clip_norm = clip_norm
        self.inner_grads = optimizer.parameters.clone(prefix="accumulate_", init='zeros')
        self.zeros = optimizer.parameters.clone(prefix="zeros_", init='zeros')
        self.counter = Parameter(Tensor(1, ms.int32), 'counter_')
        assert accumulate_step > 0
        self.accumulate_step = accumulate_step
        self.map = ops.HyperMap()

    def __call__(self, grads):
        # Accumulate the gradients obtained in a single step to the inner_grads of the Accumulator
        self.map(ops.partial(ops.assign_add), self.inner_grads, grads)
        if self.counter % self.accumulate_step == 0:
            # If the accumulated number of steps is reached, parameter optimization update is performed
            self.optimizer(self.inner_grads)
            # Clear inner_grads after completing the parameter optimization update
            self.map(ops.partial(ops.assign), self.inner_grads, self.zeros)
        # The number of steps plus one
        ops.assign_add(self.counter, Tensor(1, ms.int32))

        return True

ms.jit_class is a MindSpore just-in-time compilation decorator that allows an ordinary Python class to be used within compiled computational graphs.

Next, we verify the effect of gradient accumulation by using the handwritten digit recognition model in Quick Start.

from mindspore import nn
from mindspore import value_and_grad
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset

from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
      "notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)


def datapipe(path, batch_size):
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW()
    ]
    label_transform = transforms.TypeCast(ms.int32)

    dataset = MnistDataset(path)
    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip (10.3 MB)

file_sizes: 100%|██████████████████████████| 10.8M/10.8M [00:06<00:00, 1.67MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./

Suppose the batch_size=64 configured in Quick Start leads to insufficient device memory. We therefore set the number of accumulation steps to 2 and perform gradient accumulation by executing two steps with batch_size=32.

First, we use Accumulator, passing in the instantiated optimizer and configuring the number of accumulation steps. Then we define the forward computation function forward_fn; because gradient accumulation is used, the loss value needs to be scaled accordingly.

accumulate_step = 2

loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)
accumulator = Accumulator(optimizer, accumulate_step)

def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    # loss divided by accumulate_step
    return loss / accumulate_step
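Since nn.CrossEntropyLoss averages the loss over each Mini-Batch by default, dividing each Mini-Batch loss by accumulate_step means that, by the linearity of the gradient, the accumulated gradient equals the gradient of the loss averaged over the full simulated batch:

\[\sum_{i=1}^{N}\nabla\frac{Loss_{i}}{N}=\nabla\left(\frac{1}{N}\sum_{i=1}^{N}Loss_{i}\right)\]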

Next, we use the value_and_grad function for the functional transformation and construct the single-step training function train_step. Here, the instantiated accumulator performs the gradient accumulation; since the optimizer is an internal attribute of the accumulator, it does not need to be called separately.

grad_fn = value_and_grad(forward_fn, None, model.trainable_params())

@ms.jit
def train_step(data, label):
    loss, grads = grad_fn(data, label)
    accumulator(grads)
    return loss

Next, we define the training and evaluation logic and perform training validation.

def train_loop(model, dataset, loss_fn, optimizer):
    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")
def test_loop(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

Next, the same 3 epochs of training are performed. Note that, according to our assumptions, the dataset must be set to batch_size=32, and gradients are accumulated every two steps.

train_dataset = datapipe('MNIST_Data/train', 32)
test_dataset = datapipe('MNIST_Data/test', 32)

Start the training and validation. Because the batch_size is halved, the number of training steps per epoch doubles. The final accuracy is consistent with the result in Quick Start, both around 92.0%.

epochs = 3
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(model, train_dataset, loss_fn, optimizer)
    test_loop(model, test_dataset, loss_fn)
print("Done!")
Epoch 1
-------------------------------
loss: 1.150851  [  0/1875]
loss: 1.149633  [100/1875]
loss: 1.145340  [200/1875]
loss: 1.140591  [300/1875]
loss: 1.134244  [400/1875]
loss: 1.125991  [500/1875]
loss: 1.100611  [600/1875]
loss: 1.051961  [700/1875]
loss: 0.925877  [800/1875]
loss: 0.879966  [900/1875]
loss: 0.750192  [1000/1875]
loss: 0.617844  [1100/1875]
loss: 0.470084  [1200/1875]
loss: 0.560856  [1300/1875]
loss: 0.359766  [1400/1875]
loss: 0.502521  [1500/1875]
loss: 0.299145  [1600/1875]
loss: 0.383266  [1700/1875]
loss: 0.239381  [1800/1875]
Test:
 Accuracy: 84.8%, Avg loss: 0.528309

Epoch 2
-------------------------------
loss: 0.390662  [  0/1875]
loss: 0.250778  [100/1875]
loss: 0.570571  [200/1875]
loss: 0.196102  [300/1875]
loss: 0.297634  [400/1875]
loss: 0.192528  [500/1875]
loss: 0.231240  [600/1875]
loss: 0.144425  [700/1875]
loss: 0.113696  [800/1875]
loss: 0.233481  [900/1875]
loss: 0.212078  [1000/1875]
loss: 0.144562  [1100/1875]
loss: 0.220822  [1200/1875]
loss: 0.197890  [1300/1875]
loss: 0.283782  [1400/1875]
loss: 0.219684  [1500/1875]
loss: 0.155505  [1600/1875]
loss: 0.255665  [1700/1875]
loss: 0.155548  [1800/1875]
Test:
 Accuracy: 90.1%, Avg loss: 0.340294

Epoch 3
-------------------------------
loss: 0.176077  [  0/1875]
loss: 0.204260  [100/1875]
loss: 0.339903  [200/1875]
loss: 0.221457  [300/1875]
loss: 0.244668  [400/1875]
loss: 0.089163  [500/1875]
loss: 0.159595  [600/1875]
loss: 0.211632  [700/1875]
loss: 0.096592  [800/1875]
loss: 0.081018  [900/1875]
loss: 0.190852  [1000/1875]
loss: 0.139729  [1100/1875]
loss: 0.049344  [1200/1875]
loss: 0.122041  [1300/1875]
loss: 0.198622  [1400/1875]
loss: 0.133956  [1500/1875]
loss: 0.144801  [1600/1875]
loss: 0.076985  [1700/1875]
loss: 0.103241  [1800/1875]
Test:
 Accuracy: 92.0%, Avg loss: 0.281193

Done!