Heterogeneous Storage


Overview

In recent years, Transformer-based large models have made rapid progress on a wide range of downstream tasks in Natural Language Processing and Computer Vision, and larger models often achieve higher accuracy on those tasks. Model sizes have grown from hundreds of millions to hundreds of billions of parameters. However, training such models consumes a large amount of compute and storage resources, and the training overhead is huge.

Large model training is limited by the size of device (video) memory: only a limited number of model parameters can be stored on a single card. With model parallelism, a large model can be split across different machines and, after adding the necessary inter-process communication, trained collaboratively in a cluster, so the trainable model size grows with the cluster size. However, once the model size exceeds the memory capacity of a single machine, the inter-machine communication overhead of model parallelism grows and resource utilization drops significantly. Training larger models on a single machine, and thereby avoiding inter-machine communication in model parallelism, has therefore become key to improving the performance of large model training.

Heterogeneous storage management can expand the storage available for model parameters by 10x to 100x, breaking the device memory limitation of large model training and enabling low-cost large model training. This tutorial explains the basic principles of heterogeneous storage management and introduces the related configuration parameters and how to use them. With this feature, developers can train larger models on the same hardware.

The related configuration and switch code is as follows:

import mindspore

offload_config = {"offload_param": "cpu",
                  "auto_offload": False,
                  "offload_cpu_size": "512GB",
                  "offload_disk_size": "1024GB",
                  "offload_path": "./offload/",
                  "host_mem_block_size":"1GB",
                  "enable_aio": True,
                  "enable_pinned_mem": True}
mindspore.set_context(mode=mindspore.GRAPH_MODE, memory_offload='ON', max_device_memory='30GB')
mindspore.set_offload_context(offload_config=offload_config)
  • memory_offload: Whether to enable heterogeneous storage, which temporarily copies idle data to Host-side memory when device memory runs out.

  • max_device_memory: Sets the maximum memory available to the device.

  • offload_config: the configuration options for heterogeneous storage, where:

    • "offload_param": "cpu": Model parameters are stored in CPU memory, loaded to the device only when they are needed during training, and offloaded back to CPU memory once they have been used.

    • "auto_offload": False: Disables the automatic offload strategy, so parameter data strictly follows the "offload_param" setting.

    • "offload_cpu_size": "512GB", "offload_disk_size": "1024GB": Set the amount of CPU memory and disk space, respectively, available for offload.

    • "offload_path": "./offload/": Sets the disk path used for offload files.

    • "enable_pinned_mem": True: Enables page-locked (pinned) memory, which speeds up copies between HBM and CPU memory.

    • "host_mem_block_size": "1GB": Sets the block size of the CPU page-locked memory pool.

    • "enable_aio": True: Enables asynchronous file IO, which speeds up copies between DDR (CPU memory) and disk. (Requires compiling with the -o option, and is only supported in Linux environments with aio installed.)
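As a rough check of the "10x to 100x" claim in the Overview, the capacities in this configuration can simply be added up. The sketch below uses hypothetical helpers, parse_size and storage_budget, which are not part of MindSpore; it assumes the size strings use binary (1024-based) units and ignores any space the framework itself reserves on each tier.

```python
import re

# Hypothetical helper, not a MindSpore API: converts size strings such as
# "512GB" into bytes, assuming binary (1024-based) units.
_UNITS = {"B": 1, "KB": 1024, "MB": 1024 ** 2, "GB": 1024 ** 3}


def parse_size(text: str) -> int:
    match = re.fullmatch(r"\s*(\d+)\s*(B|KB|MB|GB)\s*", text.upper())
    if match is None:
        raise ValueError(f"unrecognised size string: {text!r}")
    number, unit = match.groups()
    return int(number) * _UNITS[unit]


def storage_budget(max_device_memory: str, offload_config: dict) -> dict:
    """Add up device, CPU and disk capacity for an offload configuration."""
    device = parse_size(max_device_memory)
    cpu = parse_size(offload_config.get("offload_cpu_size", "0GB"))
    disk = parse_size(offload_config.get("offload_disk_size", "0GB"))
    total = device + cpu + disk
    return {"device": device, "cpu": cpu, "disk": disk,
            "total": total, "expansion": total / device}


# Same values as the configuration above.
offload_config = {"offload_param": "cpu", "auto_offload": False,
                  "offload_cpu_size": "512GB", "offload_disk_size": "1024GB",
                  "offload_path": "./offload/", "host_mem_block_size": "1GB",
                  "enable_aio": True, "enable_pinned_mem": True}
budget = storage_budget("30GB", offload_config)
print(f"total capacity: {budget['total'] // 1024 ** 3} GB "
      f"({budget['expansion']:.1f}x the device memory alone)")
```

With max_device_memory='30GB', this prints a total of 1566 GB, about 52.2x the device memory alone, which falls inside the stated range.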

Basic Principle

During training, the main stored data consists of parameters and intermediate results:

  • Parameters: data such as model weights and optimizer states, which must be kept throughout training.

  • Intermediate results: data generated by the forward, backward, and optimizer computations, which can be released once the computations that use them are complete.

With heterogeneous storage management, parameters or intermediate results that are not needed for computation at the moment can be copied to Host-side memory, or even to disk, during training, and copied back to the device when they are needed for computation again. In this way, the same hardware can train larger models.
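The principle can be illustrated with a toy simulation in plain Python. ToyOffloadManager and run_layers are hypothetical names invented for this sketch; the code only tracks byte counts and does not move real tensors or reflect how MindSpore schedules its copies.

```python
from dataclasses import dataclass, field


@dataclass
class ToyOffloadManager:
    """Toy model of parameter offload: weights live on the host and are copied
    to a size-limited 'device' only while the layer that uses them runs."""
    device_capacity: int                         # bytes available on the device
    host: dict = field(default_factory=dict)     # name -> size in bytes
    device: dict = field(default_factory=dict)   # name -> size in bytes
    peak: int = 0                                # highest device usage observed

    def used(self) -> int:
        return sum(self.device.values())

    def load(self, name: str) -> None:
        size = self.host[name]
        if self.used() + size > self.device_capacity:
            raise MemoryError(f"cannot fit {name} ({size} B) on device")
        self.device[name] = self.host.pop(name)  # copy host -> device
        self.peak = max(self.peak, self.used())

    def offload(self, name: str) -> None:
        self.host[name] = self.device.pop(name)  # copy device -> host, free device


def run_layers(manager: ToyOffloadManager, layer_order: list) -> None:
    for name in layer_order:
        manager.load(name)      # bring the weight in just before it is needed
        # ... the layer's computation would run here ...
        manager.offload(name)   # release device memory as soon as it is done


params = {"layer1": 400, "layer2": 300, "layer3": 500}  # sizes in bytes
manager = ToyOffloadManager(device_capacity=600, host=dict(params))
run_layers(manager, ["layer1", "layer2", "layer3"])
print("total parameter size:", sum(params.values()))
print("peak device usage:", manager.peak)
```

Here 1200 bytes of parameters are trained on a 600-byte device, because the peak device usage is only the largest single layer (500 bytes).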


Operation Practice

The following illustrates heterogeneous storage using Ascend as an example:

Example Code Description

Download the complete example code: memory_offload.

└─ sample_code
    ├─ memory_offload
       ├── train.py
       └── run.sh
    ...

train.py is the script that defines the network structure and the training process. run.sh is the execution script.

Configuring a Distributed Environment

Use the context interface to specify the run mode, run device, and card number. This sample uses data parallelism and initializes HCCL or NCCL communication with init.

import mindspore as ms
from mindspore.communication import init

ms.set_context(mode=ms.GRAPH_MODE)
ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True)
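# Limit device memory to 1GB so that the full network does not fit and offload is triggered.
ms.set_context(max_device_memory="1GB")
# args_opt holds the command-line arguments parsed earlier in train.py (not shown).
if args_opt.memory_offload == "ON":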
    ms.set_context(memory_offload="ON")
    offload_config = {"offload_path": args_opt.offload_path, "auto_offload": args_opt.auto_offload,
                      "offload_param": args_opt.offload_param, "offload_cpu_size": args_opt.offload_cpu_size,
                      "offload_disk_size": args_opt.offload_disk_size,
                      "host_mem_block_size": args_opt.host_mem_block_size,
                      "enable_aio": args_opt.enable_aio, "enable_pinned_mem": args_opt.enable_pinned_mem}
    print("=====offload_config====\n", offload_config, flush=True)
    ms.set_offload_context(offload_config=offload_config)
init()
ms.set_seed(1)

offload_config is the configuration dictionary for heterogeneous storage; see the configuration notes in the Overview of this chapter for details. Here max_device_memory is set to 1GB so that device memory cannot hold the full network, which triggers heterogeneous storage. The “1GB” value is only the borderline device memory size we measured on Ascend 910 and may vary from device to device.
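The snippet above reads its settings from args_opt, whose definition is not shown. A minimal sketch of how such an object could be built with argparse follows. The option names mirror the attributes used above, but the defaults are assumptions, and how run.sh maps its positional arguments (such as 96 and ON) onto these options is not shown in this tutorial.

```python
import argparse


def str2bool(value: str) -> bool:
    # argparse's type=bool treats any non-empty string (even "False") as True,
    # so boolean switches are parsed explicitly.
    lowered = value.lower()
    if lowered in ("true", "1", "yes", "on"):
        return True
    if lowered in ("false", "0", "no", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    # Option names follow the args_opt attributes used in train.py;
    # the default values are illustrative assumptions.
    parser = argparse.ArgumentParser(description="memory offload sample (sketch)")
    parser.add_argument("--batch_size", type=int, default=96)
    parser.add_argument("--memory_offload", choices=["ON", "OFF"], default="OFF")
    parser.add_argument("--offload_path", default="./offload/")
    parser.add_argument("--auto_offload", type=str2bool, default=False)
    parser.add_argument("--offload_param", default="cpu")
    parser.add_argument("--offload_cpu_size", default="512GB")
    parser.add_argument("--offload_disk_size", default="1024GB")
    parser.add_argument("--host_mem_block_size", default="1GB")
    parser.add_argument("--enable_aio", type=str2bool, default=False)
    parser.add_argument("--enable_pinned_mem", type=str2bool, default=True)
    return parser


args_opt = build_parser().parse_args(
    ["--batch_size", "96", "--memory_offload", "ON", "--auto_offload", "False"])
print(args_opt.batch_size, args_opt.memory_offload, args_opt.auto_offload)
```

The explicit str2bool converter matters here: with type=bool, passing "False" would silently enable the option.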

Loading the Dataset

This example trains on the CIFAR-10 dataset, so the corresponding CIFAR-10 data processing is used. Because the parallel mode is data parallel, num_shards and shard_id must also be configured. The code is as follows:

import os
import mindspore as ms
import mindspore.dataset as ds
from mindspore.communication import get_rank, get_group_size

def create_dataset(batch_size):
    dataset_path = os.getenv("DATA_PATH")
    rank_id = get_rank()
    rank_size = get_group_size()
    cifar_ds = ds.Cifar10Dataset(dataset_path, num_shards=rank_size, shard_id=rank_id)

    resize_height = 224
    resize_width = 224
    rescale = 1.0 / 255.0
    shift = 0.0

    random_crop_op = ds.vision.RandomCrop((32, 32), (4, 4, 4, 4))
    random_horizontal_op = ds.vision.RandomHorizontalFlip()
    resize_op = ds.vision.Resize((resize_height, resize_width))
    rescale_op = ds.vision.Rescale(rescale, shift)
    normalize_op = ds.vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    changeswap_op = ds.vision.HWC2CHW()
    type_cast_op = ds.transforms.TypeCast(ms.int32)

    c_trans = [random_crop_op, random_horizontal_op]
    c_trans += [resize_op, rescale_op, normalize_op, changeswap_op]
    cifar_ds = cifar_ds.map(operations=type_cast_op, input_columns="label")
    cifar_ds = cifar_ds.map(operations=c_trans, input_columns="image")
    cifar_ds = cifar_ds.shuffle(buffer_size=10)
    cifar_ds = cifar_ds.batch(batch_size=batch_size, drop_remainder=True)
    return cifar_ds

data_set = create_dataset(args_opt.batch_size)
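num_shards and shard_id give each rank a disjoint slice of the dataset. The sketch below models this with a round-robin split, which is an assumption (MindSpore's sampler may assign samples differently), and computes how many batches each rank sees with drop_remainder=True. The 50,000-sample count and 8 ranks are illustrative values, not taken from run.sh.

```python
def shard_indices(num_samples: int, num_shards: int, shard_id: int) -> list:
    # Round-robin split: sample i goes to shard i % num_shards.
    return list(range(shard_id, num_samples, num_shards))


def batches_with_drop_remainder(num_items: int, batch_size: int) -> tuple:
    # batch(..., drop_remainder=True) keeps only full batches.
    return num_items // batch_size, num_items % batch_size


num_samples, num_shards, batch_size = 50000, 8, 96   # illustrative values
shards = [shard_indices(num_samples, num_shards, r) for r in range(num_shards)]

# Every sample lands in exactly one shard.
assert sorted(i for shard in shards for i in shard) == list(range(num_samples))

full, dropped = batches_with_drop_remainder(len(shards[0]), batch_size)
print(f"{len(shards[0])} samples per rank -> {full} batches, {dropped} dropped")
```

Under these assumptions each rank gets 6250 samples, which form 65 full batches of 96 with 10 samples dropped per epoch.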

Defining the Network

The network is defined in the same way as for single-card training:

from mindspore import nn

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 16, 3)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.dense = nn.Dense(16, 10)

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.avgpool(x).squeeze()
        logits = self.dense(x)
        return logits

net = Network()
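To see why a 1GB limit is tight even for this small network, the forward activation sizes can be estimated by hand. This sketch assumes float32 tensors, the 224x224 input produced by create_dataset, and nn.Conv2d's default stride of 1 with pad_mode="same" (so height and width are preserved). It ignores gradients, workspace, and any buffer reuse or fusion by the graph compiler, so it is only a rough estimate.

```python
def conv_same_shape(shape, out_channels):
    # Stride 1 with pad_mode="same" keeps height and width unchanged.
    n, _, h, w = shape
    return (n, out_channels, h, w)


def nbytes(shape, itemsize=4):
    # float32 -> 4 bytes per element
    total = itemsize
    for dim in shape:
        total *= dim
    return total


batch = 96
shapes = {"input": (batch, 3, 224, 224)}
shapes["conv1"] = conv_same_shape(shapes["input"], 16)
shapes["relu"] = shapes["conv1"]
shapes["conv2"] = conv_same_shape(shapes["relu"], 16)
shapes["avgpool+squeeze"] = (batch, 16)
shapes["dense"] = (batch, 10)

for name, shape in shapes.items():
    print(f"{name:16s} {str(shape):22s} {nbytes(shape) / 2**20:8.1f} MiB")
total = sum(nbytes(shape) for shape in shapes.values())
print(f"total forward activations: {total / 2**20:.1f} MiB")
```

With batch_size=96, each 16-channel 224x224 activation takes 294 MiB, and the forward activations add up to roughly 937 MiB before any gradients are allocated.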

Training the Network

In this step, define the loss function, the optimizer, and the training process. This is written the same way as for ordinary data parallel training, calling the nn.DistributedGradReducer interface to aggregate gradients:

import mindspore as ms
from mindspore import nn

optimizer = nn.SGD(net.trainable_params(), 1e-2)
loss_fn = nn.CrossEntropyLoss()

def forward_fn(data, target):
    logits = net(data)
    loss = loss_fn(logits, target)
    return loss, logits

grad_fn = ms.value_and_grad(forward_fn, None, net.trainable_params(), has_aux=True)
grad_reducer = nn.DistributedGradReducer(optimizer.parameters)

i = 0
step = 10
for image, label in data_set:
    (loss_value, _), grads = grad_fn(image, label)
    grads = grad_reducer(grads)
    optimizer(grads)
    print("step: %s, loss is %s" % (i, loss_value))
    if i >= step:
        break
    i += 1
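With gradients_mean=True, the gradients aggregated by DistributedGradReducer are averaged across ranks, so every rank applies the same update. The plain-Python sketch below mimics that effect on small lists. allreduce_mean and sgd_step are hypothetical helpers; the real reducer performs an all-reduce over HCCL or NCCL on device tensors rather than looping in Python.

```python
def allreduce_mean(per_rank_grads):
    """per_rank_grads[r][p] is the gradient (a list of floats) of parameter p
    on rank r. Returns the element-wise mean, one copy per rank."""
    num_ranks = len(per_rank_grads)
    num_params = len(per_rank_grads[0])
    averaged = []
    for p in range(num_params):
        columns = zip(*(rank[p] for rank in per_rank_grads))  # same element across ranks
        averaged.append([sum(col) / num_ranks for col in columns])
    # After an all-reduce every rank holds an identical copy of the result.
    return [[list(g) for g in averaged] for _ in range(num_ranks)]


def sgd_step(weights, grads, lr):
    # Plain SGD without momentum or weight decay: w <- w - lr * g.
    return [[w - lr * g for w, g in zip(ws, gs)] for ws, gs in zip(weights, grads)]


rank_grads = [[[1.0, 2.0]],   # rank 0: one parameter with two elements
              [[3.0, 6.0]]]   # rank 1
reduced = allreduce_mean(rank_grads)
new_weights = sgd_step([[0.5, 0.5]], reduced[0], lr=1e-2)
print(reduced[0], new_weights)
```

Both ranks end up with the averaged gradient [2.0, 4.0], so their weights stay identical after the step.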

Running the Script

Next, run the corresponding script with the following command:

bash run.sh 96 OFF

When training with batch_size=96 without heterogeneous storage enabled, an out-of-memory error is reported because there is not enough device memory:

----------------------------------------------------
- Framework Error Message:
----------------------------------------------------
Out of Memory!!! Request memory size: 1088627200B, Memory Statistic:
Device HBM memory size: 32768M
MindSpore Used memory size: 1024M
MindSpore memory base address: 0x124140000000
Total Static Memory size: 56M
Total Dynamic memory size: 0M
Dynamic memory size of this graph: 0M

Please try to reduce 'batch_size' or check whether exists extra large shape. For more details, please refer to 'Out of Memory' at https://www.mindspore.cn .
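The numbers in this log explain the failure. The sketch below redoes the arithmetic, assuming that the "MindSpore Used memory size: 1024M" line reflects the 1GB max_device_memory limit and that M means MiB. It does not attempt to explain what the requested block contains.

```python
MiB = 1024 ** 2
request = 1_088_627_200   # "Request memory size" from the log, in bytes
cap = 1024 * MiB          # "MindSpore Used memory size: 1024M"
static = 56 * MiB         # "Total Static Memory size: 56M"

print(f"requested block:    {request / MiB:.1f} MiB")
print(f"exceeds the cap by: {(request - cap) / MiB:.1f} MiB")
print(f"free after static:  {(cap - static) / MiB:.1f} MiB")
```

The requested block (about 1038.2 MiB) is larger than the entire 1024 MiB budget on its own, and only 968 MiB remain once static memory is taken into account.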

With heterogeneous storage enabled, training with batch_size=96 runs normally:

bash run.sh 96 ON
step: 0, loss is 2.3294048
step: 1, loss is 2.3190398
step: 2, loss is 2.314652
step: 3, loss is 2.3037016
...

Automatically Generating offload Strategies

In addition to copying data strictly according to the user's "offload_param" configuration, MindSpore also supports automatically generating heterogeneous storage strategies. MindSpore analyzes the network's device memory usage and combines it with the user-configured "max_device_memory", "offload_cpu_size", "offload_disk_size", "hbm_ratio", and "cpu_ratio" to generate a heterogeneous storage strategy, and then moves data across the storage media according to that strategy.

import mindspore

offload_config = {"offload_path": "./offload/",
                  "auto_offload": True,
                  "offload_param": "cpu",
                  "offload_cpu_size": "512GB",
                  "offload_disk_size": "1024GB",
                  "host_mem_block_size":"1GB",
                  "enable_aio": True,
                  "enable_pinned_mem": True}
mindspore.set_context(mode=mindspore.GRAPH_MODE, memory_offload='ON', max_device_memory='30GB')
mindspore.set_offload_context(offload_config=offload_config)

In this example, "auto_offload": True is set, so "offload_param" only determines the initial storage location of the parameters; during computation, the framework adjusts where weights and intermediate results are stored according to the generated strategy.
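MindSpore's strategy generation is internal and is not described in detail here. As a hedged illustration of the general idea of tiered placement, the toy below greedily assigns the most frequently accessed tensors to device memory, then CPU memory, then disk. greedy_placement and the tensor names, sizes, and access counts are all hypothetical, and unlike the framework, this sketch computes a single static placement instead of moving data during computation.

```python
def greedy_placement(tensors, device_budget, cpu_budget, disk_budget):
    """Toy tiered placement: the most frequently used tensors go to the device
    first, then CPU memory, then disk. Not MindSpore's actual algorithm."""
    remaining = {"device": device_budget, "cpu": cpu_budget, "disk": disk_budget}
    placement = {}
    # Sort by access count (descending) so hot data gets the fastest tier.
    for name, size, accesses in sorted(tensors, key=lambda t: t[2], reverse=True):
        for tier in ("device", "cpu", "disk"):
            if size <= remaining[tier]:
                placement[name] = tier
                remaining[tier] -= size
                break
        else:
            raise MemoryError(f"{name} ({size} units) does not fit in any tier")
    return placement


# (name, size in GB, access count): hypothetical values
tensors = [("embedding", 6, 1), ("attn_weights", 4, 8),
           ("activations", 3, 10), ("optimizer_state", 8, 2)]
placement = greedy_placement(tensors, device_budget=8, cpu_budget=10, disk_budget=20)
print(placement)
```

A real strategy would also account for when each tensor is used, so data can be prefetched before it is needed; the access-count heuristic here is only a simple stand-in.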