Dynamic Graph Parallelism

Overview

This tutorial demonstrates how to use the MindFormers dynamic graph parallel framework to train GPT models. The framework supports tensor parallelism, pipeline parallelism, sequence parallelism and other parallel modes, as well as a distributed optimizer and dynamic learning rates, helping developers quickly build and train GPT pre-training models.

Operating Practice

The GPT model training described below runs on the Ascend platform.

Sample Code Reference

The directory structure is as follows:

└─ gpt
    ├─ pretrain_gpt.py
    ├─ pretrain_gpt.sh
    └─ pretrain_gpt_7B.yaml
    ...

Here, pretrain_gpt.py is the script for environment configuration, model object creation and training; pretrain_gpt.sh is the startup script; and pretrain_gpt_7B.yaml is the configuration file.

Model Structure

GPT uses the Transformer model as its main architecture, and the network structure is mainly built around the basic building blocks of the Transformer.

The model constructor takes five parameters: config is the model configuration (the model_config section of the yaml file), num_tokentypes specifies the type of embedding, parallel_output specifies whether to return the output of each parallel tensor partition, and pre_process and post_process specify whether the current stage is the first and the last stage, respectively.

The get_language_model interface called in the constructor is built on the Transformer model; see the API documentation for get_language_model for details.

Note: The values returned by the dataset must correspond to the parameters required by the forward process (construct) defined by the model.

import enum

from mindformers.experimental.parallel_core.pynative.transformer.module import Module
from mindformers.experimental.parallel_core.pynative.transformer.language_model import get_language_model
from mindformers.experimental.parallel_core.pynative.transformer import ParallelLMLogits
from mindformers.experimental.parallel_core.pynative.training.loss_func import VocabParallelCrossEntropy


class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2
    no_mask = 3
    padding_causal = 4


attn_mask_type_mapping = {
    "padding": AttnMaskType.padding,
    "causal": AttnMaskType.causal,
}


class GPTModel(Module):
    def __init__(self,
                 config,
                 num_tokentypes=0,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True):
        super().__init__(config=config,\
                         share_embeddings_and_output_weights=not config.untie_embeddings_and_output_weights)

        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.untie_embeddings_and_output_weights = config.untie_embeddings_and_output_weights
        self.fp16_lm_cross_entropy = config.fp16_lm_cross_entropy

        self.set_model_key()
        encoder_attn_mask_type = None
        if config.encoder_attn_mask_type is not None:
            encoder_attn_mask_type = attn_mask_type_mapping.get(config.encoder_attn_mask_type)
            if encoder_attn_mask_type is None:
                raise ValueError(f"encoder_attn_mask_type must be one of {attn_mask_type_mapping.keys()}, but got "
                                 f"{config.encoder_attn_mask_type}")

        self.language_model, self._language_model_key = get_language_model(
            config=config,
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            encoder_attn_mask_type=encoder_attn_mask_type,
            pre_process=self.pre_process,
            post_process=self.post_process)

        if self.post_process:
            self.parallel_lm_logits = ParallelLMLogits(config=config,
                                                       bias=False,
                                                       compute_dtype=config.compute_dtype)
            self.loss = VocabParallelCrossEntropy()

        if not config.untie_embeddings_and_output_weights:
            self.initialize_word_embeddings()

    def set_input_tensor(self, input_tensor):
        """ set input_tensor to model """
        self.language_model.set_input_tensor(input_tensor)

    def set_model_key(self):
        """ set model key for differentiate PipelineCell process """
        self.model_key = "gpt3"

    def construct(self, input_ids, position_ids, attention_mask, loss_mask,
                  retriever_input_ids=None,
                  retriever_position_ids=None,
                  retriever_attn_mask=None,
                  labels=None, tokentype_ids=None, inference_params=None):
        """ gpt model forward """
        # use RoPE
        position_ids = None
        retriever_input_ids = None
        retriever_position_ids = None
        retriever_attn_mask = None
        lm_output = self.language_model(
            input_ids,
            position_ids,
            attention_mask,
            retriever_input_ids=retriever_input_ids,
            retriever_position_ids=retriever_position_ids,
            retriever_attn_mask=retriever_attn_mask,
            inference_params=inference_params)
        if self.post_process:
            return post_language_model_processing(
                self.parallel_lm_logits, self.loss,
                lm_output, labels,
                self.language_model.output_layer.weight if\
                    self.untie_embeddings_and_output_weights else self.shared_embedding_or_output_weight(),
                self.parallel_output,
                self.fp16_lm_cross_entropy,
                loss_mask)
        else:
            return lm_output
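
As a rough illustration of how the pre_process and post_process flags relate to pipeline stages, the sketch below derives both flags from a stage index. The helper stage_flags and its arguments are hypothetical and not part of MindFormers; the sketch only assumes, as described above, that pre_process marks the first stage and post_process marks the last stage.

```python
def stage_flags(stage_index, num_stages):
    """Return (pre_process, post_process) for one pipeline stage.

    Hypothetical helper for illustration only; not a MindFormers API.
    """
    if not 0 <= stage_index < num_stages:
        raise ValueError(f"stage_index must be in [0, {num_stages}), but got {stage_index}")
    pre_process = stage_index == 0                 # first stage
    post_process = stage_index == num_stages - 1   # last stage
    return pre_process, post_process


# With a single stage, that stage is both the first and the last one.
single = stage_flags(0, 1)
# With four stages, only the two ends have a flag set.
four = [stage_flags(i, 4) for i in range(4)]
```

With pipeline_model_parallel_size set to 1, as in the sample configuration, both flags are True, which matches the defaults of the GPTModel constructor.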

When post_process is set to True, the language model output lm_output is post-processed to produce the loss and the prediction logits.

import mindspore.common.dtype as mstype
from mindspore import ops

def post_language_model_processing(parallel_lm_logits, loss_fn, lm_output, labels, logit_weights,
                                   parallel_output, fp16_lm_cross_entropy, loss_mask):
    """ gpt model post process forward """
    output = parallel_lm_logits(lm_output, logit_weights, parallel_output)

    if labels is None:
        return output

    loss_mask = loss_mask.reshape(-1)

    if fp16_lm_cross_entropy:
        if output.dtype != mstype.float16:
            raise ValueError(f"When fp16_lm_cross_entropy=True, output should be float16, but got {output.dtype}")
        loss = loss_fn(output, labels, loss_mask)
    else:
        loss = loss_fn(output.astype(mstype.float32), labels)
    token_nums = loss_mask.sum()
    loss_mask = loss_mask.astype(mstype.float32)
    # Average the per-token loss over the positions selected by the mask.
    loss = ops.sum(loss * loss_mask) / loss_mask.sum()
    return loss, output, token_nums
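
The final reduction in post_language_model_processing is a masked mean: per-token losses are multiplied by the mask, summed, and divided by the number of unmasked tokens. The following framework-free sketch shows the same arithmetic on plain Python lists; the function name masked_mean_loss and the sample values are made up for illustration.

```python
def masked_mean_loss(per_token_loss, loss_mask):
    """Average per-token losses over positions where loss_mask is 1."""
    if len(per_token_loss) != len(loss_mask):
        raise ValueError("per_token_loss and loss_mask must have the same length")
    token_nums = sum(loss_mask)
    if token_nums == 0:
        raise ValueError("loss_mask selects no tokens")
    total = sum(value * mask for value, mask in zip(per_token_loss, loss_mask))
    return total / token_nums, token_nums


# Four token positions; the last one is padding and is masked out,
# so its large loss value does not affect the result.
loss, token_nums = masked_mean_loss([2.0, 4.0, 6.0, 100.0], [1, 1, 1, 0])
```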

Dynamic Graph Parallel Training Configuration

Configuration items for dynamic graph parallelism are read from a yaml file and grouped by type, including training configuration, parallel configuration and model configuration. The following subsections briefly describe the basic configurations needed for large model training.

training_config

training_config:
  seed: 42                                        # Random seed, fixed for reproducibility
  output_dir: './output'                          # Output directory for storing checkpoints, logs, etc.
  training_iters: 10                              # The number of training iterations
  log_interval: 1                                 # Frequency of log prints
  save_interval: null                             # Frequency of storing checkpoints
  loss_scale: 4096                                # Initial value of loss scale
  grad_clip_kwargs:
    grad_clip_type: "ClipGlobalNorm"              # Gradient clipping method, optional: "ClipGlobalNorm" or "GradClipByValue"
    clip_value: 1.0
  loss_reduction: "mean"                          # loss reduction methods, optional: "mean" or "sum"
  loss_func_kwargs:
    loss_func_type: "VocabParallelCrossEntropy"   # Loss function, optional: "VocabParallelCrossEntropy" or "CrossEntropyLoss"
  use_distributed_optimizer: True                 # Whether to use a distributed optimizer

parallel_config

parallel_config:
  tensor_model_parallel_size: 1                    # Tensor parallel
  pipeline_model_parallel_size: 1                  # Pipeline parallel
  expert_model_parallel_size: 1                    # Expert parallel
  virtual_pipeline_model_parallel_size: null       # Virtual pipeline parallel
  sequence_parallel: False                         # Sequence parallel

gpt_config

model_config:
  params_dtype: "float32"                          # Parameter initialization type
  compute_dtype: "bfloat16"                        # Types used in calculations
  position_embedding_type: 'rope'                  # Type of position embedding, optional: "rope" or "absolute"
  untie_embeddings_and_output_weights: True        # Whether the embedding layer and the head layer do not share weights
  # Configure the GPT 7B model
  num_layers: 6                                    # The number of Transformer layers
  hidden_size: 4096                                # Size of the hidden layer
  ffn_hidden_size: 11008                           # Size of feedforward neural network hidden layer
  num_attention_heads: 32                          # Number of attention heads

Configurations for the GPT model are currently available in three sizes: 7B, 13B and 70B.

7B:
  num_layers: 32
  hidden_size: 4096
  ffn_hidden_size: 11008
  num_attention_heads: 32
13B:
  num_layers: 40
  hidden_size: 5120
  ffn_hidden_size: 13824
  num_attention_heads: 40
70B:
  num_layers: 80
  hidden_size: 8192
  ffn_hidden_size: 28672
  num_attention_heads: 64
  group_query_attention: True
  num_query_groups: 8
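
As a quick sanity check on the three size configurations, the sketch below derives the per-head dimension and the number of query heads that share one key/value group. It assumes the usual Transformer convention that the head dimension equals hidden_size / num_attention_heads, and that a configuration without num_query_groups uses one group per head. The helper attention_shapes is hypothetical and not part of MindFormers.

```python
# Values copied from the size configurations above.
SIZES = {
    "7B": {"hidden_size": 4096, "num_attention_heads": 32},
    "13B": {"hidden_size": 5120, "num_attention_heads": 40},
    "70B": {"hidden_size": 8192, "num_attention_heads": 64, "num_query_groups": 8},
}


def attention_shapes(cfg):
    """Return (head_dim, query_heads_per_group) for one configuration."""
    hidden, heads = cfg["hidden_size"], cfg["num_attention_heads"]
    if hidden % heads != 0:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    # Assumption: without grouped-query attention, every head is its own group.
    groups = cfg.get("num_query_groups", heads)
    if heads % groups != 0:
        raise ValueError("num_attention_heads must be divisible by num_query_groups")
    return hidden // heads, heads // groups


shapes = {name: attention_shapes(cfg) for name, cfg in SIZES.items()}
```

Under these assumptions all three sizes use a head dimension of 128, and in the 70B configuration eight query heads share each key/value group.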

dataset_config

dataset_config:
  batch_size: 1                                    # Number of samples taken from the dataset in one iteration
  micro_batch_num: 2                               # Number of micro batches
  dataset_dir: './dataset'                         # Directory where the dataset is located
  shuffle: False                                   # Whether to shuffle the data

optimizer_config

optimizer_config:
  optimizer_type: "AdamW"                          # Optimizer types, optional: "AdamW", "Adam", "SGD", "Came", "mint.AdamW" and "SpeedAdamW"
  betas:                                           # Optimizer input parameters
    - 0.9
    - 0.95
  eps: 1.e-8
  learning_rate: 1.25e-6                           # Initial learning rate
  weight_decay: 1.e-1                              # Weight decay factor
  learning_rate_scheduler_kwargs:                  # Learning rate adjustment strategy
    warmup_steps: 200
    decay_steps: 2000
    use_cosine: True
    end_learning_rate: 1.25e-7
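
To make the learning_rate_scheduler_kwargs more concrete, the sketch below implements one common formulation of linear warmup followed by cosine decay. It is a generic illustration, not the MindFormers scheduler: the function lr_at is hypothetical, and it assumes that warmup starts from 0 and that decay_steps counts from step 0 and includes the warmup phase. The framework's own scheduler may treat these values differently.

```python
import math


def lr_at(step, base_lr=1.25e-6, end_lr=1.25e-7, warmup_steps=200, decay_steps=2000):
    """Learning rate at a given step: linear warmup, then cosine decay.

    Hypothetical helper; defaults are taken from the optimizer_config above.
    """
    if step < warmup_steps:
        # Linear ramp from 0 up to base_lr (assumed warmup shape).
        return base_lr * step / warmup_steps
    # Fraction of the decay phase completed, clamped to 1 after decay_steps.
    progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_lr + (base_lr - end_lr) * cosine


lr_start = lr_at(0)
lr_peak = lr_at(200)
lr_mid = lr_at(1100)
lr_final = lr_at(2000)
lr_after = lr_at(5000)
```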

Model Training Configuration Parsing

The yaml configuration file passed in is parsed in pretrain_gpt.py to obtain the training configuration, model configuration, optimizer configuration, parallel strategy configuration and dataset configuration.

import argparse
from mindformers.experimental.parallel_core.pynative.config import (
    init_configs_from_yaml
)

def get_arg_parser():
    """get argument parser"""
    parser = argparse.ArgumentParser(description="Train gpt model")
    parser.add_argument("--config_path", type=str, default="pretrain_gpt.yaml", help="The path to the config file.")
    parser.add_argument("--run_cmd", type=str, default="", help="running cmd.")
    parser.add_argument("--model_type", type=str, default="gpt_config", help="Input model config.")
    return parser
parser = get_arg_parser()
args = parser.parse_args()

all_config = init_configs_from_yaml(args.config_path)

training_config = all_config.training_config
model_config = all_config.model_config
optimizer_config = all_config.optimizer_config
parallel_config = all_config.parallel_config
dataset_config = all_config.dataset_config
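
The argument parser can be exercised without a real command line by passing an explicit argument list to parse_args. The sketch below reproduces only the --config_path argument from the parser above and uses nothing beyond the standard library.

```python
import argparse


def get_arg_parser():
    """Reduced copy of the parser above, keeping only --config_path."""
    parser = argparse.ArgumentParser(description="Train gpt model")
    parser.add_argument("--config_path", type=str, default="pretrain_gpt.yaml",
                        help="The path to the config file.")
    return parser


# No arguments: the parser falls back to its own default.
default_args = get_arg_parser().parse_args([])
# Explicit path, as passed by the startup script.
custom_args = get_arg_parser().parse_args(["--config_path", "pretrain_gpt_7B.yaml"])
```

Note that the parser default (pretrain_gpt.yaml) differs from the default used by pretrain_gpt.sh (pretrain_gpt_7B.yaml). Because the startup script always passes --config_path, the parser default only applies when pretrain_gpt.py is launched directly.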

Communication Configuration

The set_context interface specifies the run mode, the target device and the device ID. A parallel script also needs to set the parallel mode parallel_mode to data parallel mode and to initialize HCCL, NCCL or MCCL communication through init, depending on the device. Here the platform is Ascend, so device_target is set to Ascend. During debugging, set_context(pynative_synchronize=True) enables synchronous execution so that the location of an error is reported more accurately.

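import mindspore as ms
from mindspore.communication import init
# The two MindFormers import paths below are assumed from the package layout; adjust them to the installed version.
from mindformers.experimental.parallel_core.pynative.parallel_state import (
    get_data_parallel_world_size,
    initialize_model_parallel,
)
from mindformers.tools import logger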


def set_parallel_context(parallel_config):
    init()
    initialize_model_parallel(
        tensor_model_parallel_size=parallel_config.tensor_model_parallel_size,
        pipeline_model_parallel_size=parallel_config.pipeline_model_parallel_size,
        virtual_pipeline_model_parallel_size=parallel_config.virtual_pipeline_model_parallel_size,
    )
    logger.info(
        f"dp {get_data_parallel_world_size()} | "
        f"pp {parallel_config.pipeline_model_parallel_size} | "
        f"tp {parallel_config.tensor_model_parallel_size} | "
        f"sp {parallel_config.sequence_parallel} | "
        f"vpp {parallel_config.virtual_pipeline_model_parallel_size}"
    )


def set_seed(seed):
    # set global seed, np seed, and dataset seed
    ms.set_seed(seed)
    # set rng seed
    ms.manual_seed(seed)


ms.set_context(device_target="Ascend", mode=ms.PYNATIVE_MODE)
set_parallel_context(parallel_config)
set_seed(training_config.seed)

Creating Network Objects

Get the GPT model from the model library and create a network model object based on the configuration file.

Set different weight decay coefficients for different parameters via set_weight_decay. This function divides the parameters into two groups, one with the configured weight decay value applied and the other with a weight decay of 0, and returns a list of parameter groups, which is assigned to the group_params variable.

Call get_optimizer, passing in optimizer_config (optimizer configuration), training_config (training configuration), group_params (the parameter groups obtained above), network_with_loss (an object containing the model and the loss) and the gradient reduction operation (taken from training_config.loss_reduction). It returns an optimizer object, which is assigned to the optimizer variable.

Create a TrainOneStepCell object, which performs one optimization step during training. Pass network_with_loss, optimizer and the configurations as parameters and assign the result to the train_one_step_cell variable.

Complete code for creating network objects:

from mindformers.experimental.parallel_core.pynative.optimizer import get_optimizer
from mindformers.experimental.parallel_core.pynative.training import get_model
from mindformers.experimental.parallel_core.pynative.training import TrainOneStepCell
from mindformers.experimental.parallel_core.models import GPTModel


def decay_filter(x):
    return "norm" not in x.name.lower() and "bias" not in x.name.lower()


def set_weight_decay(params, weight_decay=1e-1):
    decay_params = list(filter(decay_filter, params))
    other_params = list(filter(lambda x: not decay_filter(x), params))
    group_params = []
    if decay_params:
        group_params.append({"params": decay_params, "weight_decay": weight_decay})
    if other_params:
        group_params.append({"params": other_params, "weight_decay": 0.0})
    return group_params


def model_provider_func(pre_process=True, post_process=True):
    network_with_loss = GPTModel(
        model_config, pre_process=pre_process, post_process=post_process
    )
    return network_with_loss

network_with_loss = get_model(model_provider_func, training_config)

group_params = set_weight_decay(network_with_loss.trainable_params(), optimizer_config.weight_decay)
optimizer = get_optimizer(
    optimizer_config,
    training_config,
    group_params,
    network_with_loss,
    grad_allreduce_op=training_config.loss_reduction
)

train_one_step_cell = TrainOneStepCell(network_with_loss, optimizer, None, training_config, model_config)
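
The grouping performed by set_weight_decay can be tried out without MindSpore, because decay_filter only inspects the name attribute of each parameter. The sketch below reuses the two functions from the code above and feeds them stand-in parameter objects; FakeParam and the parameter names are made up for illustration.

```python
from collections import namedtuple

# Stand-in for a framework parameter; only the name attribute is needed here.
FakeParam = namedtuple("FakeParam", ["name"])


def decay_filter(x):
    # Parameters of normalization layers and biases are excluded from weight decay.
    return "norm" not in x.name.lower() and "bias" not in x.name.lower()


def set_weight_decay(params, weight_decay=1e-1):
    decay_params = list(filter(decay_filter, params))
    other_params = list(filter(lambda x: not decay_filter(x), params))
    group_params = []
    if decay_params:
        group_params.append({"params": decay_params, "weight_decay": weight_decay})
    if other_params:
        group_params.append({"params": other_params, "weight_decay": 0.0})
    return group_params


params = [
    FakeParam("embedding.weight"),
    FakeParam("layers.0.attention.weight"),
    FakeParam("layers.0.attention.bias"),
    FakeParam("layers.0.input_norm.weight"),
]
groups = set_weight_decay(params, weight_decay=0.1)
decayed = [p.name for p in groups[0]["params"]]
not_decayed = [p.name for p in groups[1]["params"]]
```

Here the embedding and attention weights land in the group with weight_decay 0.1, while the bias and the normalization weight land in the group with weight_decay 0.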

Loading the Dataset and Performing Training

from dataset import get_dataset
from mindformers.experimental.parallel_core.pynative.training import train

train_dataset_iterator, val_dataset_iterator = get_dataset(dataset_config)
train(
    train_one_step_cell,
    train_dataset_iterator,
    training_config,
    val_dataset_iterator,
    metrics,
    evaluation,
)

Running the Training Script

bash pretrain_gpt.sh xx.yaml

If xx.yaml is not specified, it defaults to pretrain_gpt_7B.yaml.

The training script pretrain_gpt.sh is parsed in detail below:

Setting Environment Variables

HCCL_BUFFSIZE=200 sets the size of the buffer for sharing data between two NPUs to 200 MB. HCCL_EXEC_TIMEOUT=600 sets the wait time for synchronization between devices during execution to 600 seconds (10 minutes). ASCEND_RT_VISIBLE_DEVICES specifies the visible device IDs, here set to device 0.

export HCCL_BUFFSIZE=200
export HCCL_EXEC_TIMEOUT=600
export ASCEND_RT_VISIBLE_DEVICES='0'

Setting Port Number

port=8828

If a previous run exited abnormally and left processes occupying the port, you can use the following code to clean them up.

PIDS=$(sudo lsof -i :$port | awk 'NR>1 {print $2}')
if [ -n "$PIDS" ]; then
    for pid in $PIDS; do
        kill -9 $pid
        echo "Killed process $pid"
    done
else
    echo "No processes found listening on port $port."
fi

Setting Log Storage Path

Get the path to the directory where the current script is located and store it in the project_dir variable, and set the log path variable log_path=“msrun_log”. Delete the directory named msrun_log (if it exists) and recreate it.

project_dir=$(cd "$(dirname "$0")" || exit; pwd)
log_path="msrun_log"

rm -rf "${log_path}"
mkdir "${log_path}"

Setting the Number of Available Devices

# Calculate the number of devices
IFS=',' read -r -a devices <<< "$ASCEND_RT_VISIBLE_DEVICES"
work_num=${#devices[@]}

Getting the Configuration File

Get the configuration file path from the command line arguments. If no command line argument is provided, the default configuration file “pretrain_gpt_7B.yaml” is used.

config_path=$1
if [ -z "$config_path" ]; then
    config_path="pretrain_gpt_7B.yaml"
fi

Executing Training Scripts in msrun Mode

msrun --worker_num "$work_num" --local_worker_num="$work_num" --master_port=$port --log_dir="$log_path" --join=True --cluster_time_out=300 pretrain_gpt.py --config_path="${config_path}"

Running Results

Run the training script with the following command:

bash pretrain_gpt.sh

After execution, the log files are saved to the output directory. Part of its directory structure is shown below:

└─ output
    └─ log
        ├─ rank_0
        |   ├─ info.log
        |   └─ error.log
        ├─ rank_1
        |   ├─ info.log
        |   └─ error.log
    ...

The loss results are saved in output/log/rank_*/info.log. An example is shown below:

train: Epoch:0, Step:5, Loss: 10.341485, Finite_grads: True, Loss_scale: 4096.0, Learning_rate: (1.250000e-06,1.250000e-06,), Time: 1403.24 ms
train: Epoch:0, Step:6, Loss: 10.38118, Finite_grads: True, Loss_scale: 4096.0, Learning_rate: (1.250000e-06,1.250000e-06,), Time: 1378.19 ms
train: Epoch:0, Step:7, Loss: 10.165115, Finite_grads: True, Loss_scale: 4096.0, Learning_rate: (1.250000e-06,1.250000e-06,), Time: 1370.32 ms
train: Epoch:0, Step:8, Loss: 10.039211, Finite_grads: True, Loss_scale: 4096.0, Learning_rate: (1.250000e-06,1.250000e-06,), Time: 1386.89 ms
train: Epoch:0, Step:9, Loss: 10.040031, Finite_grads: True, Loss_scale: 4096.0, Learning_rate: (1.250000e-06,1.250000e-06,), Time: 1475.95 ms
...
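
To track the loss curve, the step and loss values can be extracted from lines in this format with a regular expression. The sketch below is a minimal example based only on the log format shown above; the function parse_losses is hypothetical, and the pattern would need adjusting if the log format differs.

```python
import re

# Matches the "Step:<int>, Loss: <float>" part of a training log line.
LOG_PATTERN = re.compile(r"Step:(\d+), Loss: ([0-9.]+)")


def parse_losses(lines):
    """Return a list of (step, loss) tuples from training log lines."""
    results = []
    for line in lines:
        match = LOG_PATTERN.search(line)
        if match:
            results.append((int(match.group(1)), float(match.group(2))))
    return results


# Two lines copied from the example log above.
sample = [
    "train: Epoch:0, Step:5, Loss: 10.341485, Finite_grads: True, Loss_scale: 4096.0, "
    "Learning_rate: (1.250000e-06,1.250000e-06,), Time: 1403.24 ms",
    "train: Epoch:0, Step:6, Loss: 10.38118, Finite_grads: True, Loss_scale: 4096.0, "
    "Learning_rate: (1.250000e-06,1.250000e-06,), Time: 1378.19 ms",
]
losses = parse_losses(sample)
```

In practice the lines would come from a file such as output/log/rank_0/info.log, read with open(...) and passed to parse_losses.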