Weight Saving and Resumable Training
Weight Saving
Overview
Saving model weights is a critical step in training a deep learning model. Weight saving lets you store model parameters at any stage of training, so that after training is interrupted or completed you can resume training or proceed to evaluation or deployment. Saved weights also let you reproduce experiment results in different environments.
Directory Structure
During training, MindFormers generates two weight saving folders in the output directory: checkpoint and checkpoint_network.
| Folder | Description |
|---|---|
| checkpoint | Stores the weights, optimizer state, steps, and epochs in ckpt files for resuming training. |
| checkpoint_network | Stores only the weight parameters in ckpt files. This folder is suitable for loading pre-trained weights or for inference and evaluation, but not for resuming training. |
Directory Structure of checkpoint
The weight files in the checkpoint folder are saved in the following format:
checkpoint
├── rank_0
│   ├── meta.json
│   └── {prefix}-{epoch}_{step}.ckpt
...
└── rank_x
    ├── meta.json
    └── {prefix}-{epoch}_{step}.ckpt
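| File | Description |
|---|---|
| meta.json | Records information about the weights saved by this rank, which resumable training uses to locate the weight to load. |
| {prefix}-{epoch}_{step}.ckpt | Saved weight file. |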
Directory Structure of checkpoint_network
checkpoint_network
├── rank_0
│   └── {prefix}-{epoch}_{step}.ckpt
...
└── rank_x
    └── {prefix}-{epoch}_{step}.ckpt
| File | Description |
|---|---|
| {prefix}-{epoch}_{step}.ckpt | Saved weight file. |
Configuration and Usage
YAML Parameters
You can modify the configuration file to control weight saving. The main parameters are as follows.
| Parameter | Description |
|---|---|
| save_checkpoint_steps | Interval, in steps, at which weights are saved. If this parameter is not set, no weights are saved. |
| keep_checkpoint_max | Maximum number of weight files kept at the same time. When this limit is reached, the earliest weight file is deleted each time a new one is saved. |
You can modify the fields under CheckpointMonitor in the yaml configuration file to control the weight saving behavior. For example:
callbacks:
  ...
  - type: CheckpointMonitor
    prefix: "llama2_7b"
    save_checkpoint_steps: 500
    keep_checkpoint_max: 3
  ...
In the preceding example, weights are saved every 500 steps, and at most three weight files are kept at the same time.
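The rotation rule defined by save_checkpoint_steps and keep_checkpoint_max can be illustrated with a small simulation. This is a hypothetical sketch, not MindFormers code: the function name simulate_checkpoint_rotation is made up, and it tracks only the step at which each weight was saved rather than real ckpt files.

```python
from collections import deque


def simulate_checkpoint_rotation(total_steps, save_checkpoint_steps, keep_checkpoint_max):
    """Hypothetical sketch: return the steps whose weights are still kept.

    Not MindFormers code. It only mimics the documented rule that the
    earliest weight is deleted once keep_checkpoint_max would be exceeded.
    """
    kept = deque()
    for step in range(1, total_steps + 1):
        if step % save_checkpoint_steps == 0:
            kept.append(step)  # a weight is saved at this step
            if len(kept) > keep_checkpoint_max:
                kept.popleft()  # drop the earliest saved weight
    return list(kept)


print(simulate_checkpoint_rotation(2000, 500, 3))  # [1000, 1500, 2000]
```

With save_checkpoint_steps: 10 and keep_checkpoint_max: 3, as in the distributed example later on this page, 40 steps leave the weights from steps 20, 30, and 40.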
Resumable Training
Overview
MindFormers supports step-level resumable training: checkpoints are saved during training, and if training is interrupted, you can load a saved checkpoint to resume it. This feature is crucial for large-scale training tasks because it reduces the time and resources wasted by unexpected interruptions. In addition, when training is resumed with the same dataset but a different global batch size (for example, because the cluster or the configuration has changed), MindFormers automatically scales the number of resumed training steps and skipped data steps in the same proportion.
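One way to picture the proportional scaling is to keep the number of samples already consumed constant and convert it into steps of the new global batch size, rounding down. This is a hypothetical sketch under that assumption; scale_resume_steps is an illustrative name, not a MindFormers API.

```python
def scale_resume_steps(trained_steps, old_global_batch_size, new_global_batch_size):
    """Hypothetical sketch of proportional step scaling.

    Assumption: the number of samples already consumed is preserved and the
    converted step count is rounded down.
    """
    if old_global_batch_size <= 0 or new_global_batch_size <= 0:
        raise ValueError("global batch sizes must be positive")
    consumed_samples = trained_steps * old_global_batch_size
    return consumed_samples // new_global_batch_size


print(scale_resume_steps(40, 16, 32))  # 20
print(scale_resume_steps(3, 16, 32))   # 1 (rounded down)
```

Doubling the global batch size halves the step count, which matches the manual rule given in Scenario 3 below.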
Configuration and Usage
YAML Parameters
You can modify the configuration file to control resumable training. The main parameters are as follows. For details about other parameters, see the description of CheckpointMonitor.
| Parameter | Description |
|---|---|
| load_checkpoint | Path of the weights loaded for resumable training. It can be a folder path (used to load distributed weights) or the path of a specific weight file. The default value is an empty string, indicating that no weights are loaded. |
| resume_training | Specifies whether to enable resumable training. You can set it to True or to a weight file name. |
Based on the input parameters, there are four cases.
| load_checkpoint | resume_training | Description | Recommended |
|---|---|---|---|
| Weight file path | True | Resumes training from the weights specified by load_checkpoint. | √ |
| Weight file path | Weight file name | The file name specified by resume_training is ignored; training is resumed from the weights specified by load_checkpoint. | × |
| Weight folder path | True | Scenario 1: Single-node system, multi-node system with a shared directory, or ModelArts | √ |
| Weight folder path | Weight file name | Resumes training from the weight file specified by resume_training. | √ |
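The four combinations in the table can be summarized as a small decision function. This is a hypothetical sketch, not the MindFormers loader: resolve_resume_source and the is_folder flag are illustrative names, and the function only reports which source would be used rather than loading anything.

```python
def resolve_resume_source(load_checkpoint, is_folder, resume_training):
    """Hypothetical sketch of the four documented combinations.

    load_checkpoint: path taken from the YAML file.
    is_folder: whether load_checkpoint points to a folder (assumed known).
    resume_training: True, or a weight file name given as a string.
    """
    if not resume_training:
        return None  # resumable training is disabled
    if not is_folder:
        # Weight file path: any file name in resume_training is ignored.
        return ("file", load_checkpoint)
    if resume_training is True:
        # Weight folder path + True: locate the weights through meta.json.
        return ("meta.json", load_checkpoint)
    # Weight folder path + weight file name: load the named file.
    return ("named_file", resume_training)


print(resolve_resume_source("./output/checkpoint", True, True))
# ('meta.json', './output/checkpoint')
```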
In addition, you can modify the following parameters under the trainer field in the configuration file to use related functions.
| Parameter | Description |
|---|---|
| ignore_data_skip | Specifies whether to ignore the data skipping mechanism during resumable training and read the dataset from the beginning instead. Use this parameter when the dataset is changed for resumable training. If this parameter is set to True, no data is skipped. |
| data_skip_steps | Number of dataset steps to skip. Use this parameter when training is interrupted again after having been resumed with a changed dataset or global batch size. |
Fault Recovery Mechanism
If resume_training is set to True, the system automatically resumes training based on the weights recorded in meta.json. If the weight file of a rank is missing or damaged, the system rolls back to the latest available weight for recovery.
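Rolling back to the latest available weight can be read as choosing the newest step whose weights are intact on every rank, which is consistent with the fault recovery example later on this page (rank_3 loses its step-40 weight and training resumes from step 30). The sketch below is hypothetical: latest_common_step is an illustrative name, and how MindFormers actually detects missing or damaged files is not shown.

```python
def latest_common_step(saved_steps_per_rank):
    """Hypothetical sketch: newest step whose weights exist on every rank.

    saved_steps_per_rank maps a rank id to the steps for which that rank
    still has an intact weight file. Returns None if no step is shared.
    """
    if not saved_steps_per_rank:
        return None
    common = set.intersection(*(set(steps) for steps in saved_steps_per_rank.values()))
    return max(common) if common else None


steps = {0: [20, 30, 40], 1: [20, 30, 40], 2: [20, 30, 40], 3: [20, 30]}
print(latest_common_step(steps))  # 30
```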
In a distributed environment, resumable training requires the weights of all nodes to be in the same shared directory. You can set the shared path through the SHARED_PATHS environment variable.
Example of Distributed Training
The following example shows how to enable resumable training in single-device and multi-device environments. The example is based on the llama2_7b model.
For related configuration files, see configs/llama2/pretrain_llama2_7b.yaml.
Complete Training
Modify configs/llama2/pretrain_llama2_7b.yaml.
Configure the parallelism as required.
parallel_config:
  data_parallel: 1
  model_parallel: 2
  pipeline_stage: 2
  micro_batch_num: 2
Configure the model weight saving as required.
callbacks:
  ...
  - type: CheckpointMonitor
    prefix: "llama2_7b"
    save_checkpoint_steps: 10
    keep_checkpoint_max: 3
    integrated_save: False
    async_save: False
  ...
Prepare a dataset. The following uses wikitext2 as an example to describe how to start four-device distributed training.
bash scripts/msrun_launcher.sh "run_mindformer.py \
 --config configs/llama2/pretrain_llama2_7b.yaml \
 --train_dataset /path/to/wikitext2-llama2.mindrecord \
 --run_mode train \
 --use_parallel True" 4
After the fourth saving is complete, end the process. The structure of the rank_0 folder under checkpoint is as follows:
checkpoint/rank_0
├── llama2_7b_rank_0-10_2.ckpt
├── llama2_7b_rank_0-15_2.ckpt
├── llama2_7b_rank_0-20_2.ckpt
└── meta.json
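The file names follow the {prefix}-{epoch}_{step}.ckpt pattern. Assuming that, in sink mode, {epoch} counts sinks of sink_size steps and {step} is the step within that sink, the global step is (epoch - 1) * sink_size + step; with the default sink_size of 2 this maps the three files to steps 20, 30, and 40, which agrees with the four saves every 10 steps and the three files kept. This is a hypothetical sketch based on that assumption, not a MindFormers utility.

```python
import re

# Assumed pattern {prefix}-{epoch}_{step}.ckpt; the greedy prefix group
# matches up to the last "-" in the name.
CKPT_PATTERN = re.compile(r"^(?P<prefix>.+)-(?P<epoch>\d+)_(?P<step>\d+)\.ckpt$")


def global_step_from_name(file_name, sink_size=2):
    """Hypothetical sketch: convert a checkpoint file name to a global step."""
    match = CKPT_PATTERN.match(file_name)
    if match is None:
        raise ValueError(f"not a checkpoint file name: {file_name}")
    epoch = int(match.group("epoch"))
    step = int(match.group("step"))
    # Assumption: each sink epoch contains sink_size steps.
    return (epoch - 1) * sink_size + step


names = ["llama2_7b_rank_0-10_2.ckpt", "llama2_7b_rank_0-15_2.ckpt", "llama2_7b_rank_0-20_2.ckpt"]
print([global_step_from_name(name) for name in names])  # [20, 30, 40]
```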
Resumable Training
Modify the configuration and specify the resumable training weight file.
load_checkpoint: './output/checkpoint'
resume_training: True
Resume training.
bash scripts/msrun_launcher.sh "run_mindformer.py \
 --config configs/llama2/pretrain_llama2_7b.yaml \
 --train_dataset /path/to/wikitext2-llama2.mindrecord \
 --run_mode train \
 --use_parallel True" 4
If the first step number printed is 42, training has been resumed successfully. The latest saved weight file contains the state at step 40. The default value of sink_size is 2, meaning that information is printed every two steps, so the first printed step is 42.
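The arithmetic behind the expected first step can be written out as a one-line helper. This hypothetical sketch assumes, as stated above, that information is printed once every sink_size steps, so the first report after resuming from a saved step comes sink_size steps later; first_printed_step is an illustrative name.

```python
def first_printed_step(resumed_step, sink_size=2):
    """Hypothetical sketch: first step reported after resuming."""
    # One sink of sink_size steps must finish before anything is printed.
    return resumed_step + sink_size


print(first_printed_step(40))  # 42
print(first_printed_step(30))  # 32
```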
Resumable Training with the Dataset Changed
The dataset can change in three main scenarios during resumable training, and each requires changes to the configuration file. The following describes each scenario in turn: which step of the basic resumable training process must be modified, and how to change the configuration to achieve the expected effect.
Scenario 1: Training resumed with a new dataset (without skipping trained steps)
In this scenario, the new dataset is read from the beginning and no data or steps are skipped. You need to configure the file to ignore the previous data progress so that the model is trained on the new dataset from its first sample.
Configuration modification: Based on the first step of the basic resumable training process, add ignore_data_skip and set it to True, indicating that no data is skipped.
load_checkpoint: './output/checkpoint'
resume_training: True
trainer:
  ignore_data_skip: True
Expected result: The model is trained on the new dataset from its beginning, without skipping any steps.
Scenario 2: Training resumed with a new dataset, skipping trained steps
In this scenario, the model has already been partially trained on the new dataset (for example, 2 steps were performed before training was interrupted), and training should continue from where it was last interrupted. In this case, you must manually specify the number of steps to skip.
Configuration modification: Based on the first step of the basic resumable training process, set ignore_data_skip to False and use data_skip_steps to specify the number of already-trained steps to skip (for example, 2).
load_checkpoint: './output/checkpoint'
resume_training: True
trainer:
  ignore_data_skip: False
  data_skip_steps: 2
Expected result: The model skips the first 2 steps and continues training from step 3 on the new dataset.
Scenario 3: Training resumed with a new dataset and a changed global batch size
If the global batch size is changed (for example, doubled) when training is resumed on a new dataset, you must scale the number of already-trained steps when manually specifying the number of steps to skip: divide it by the scaling coefficient and round down. For example, if the global batch size is doubled, the number of steps to skip is halved.
Configuration modification: Based on Scenario 2, set data_skip_steps to the scaled number of steps. For example, if the global batch size is doubled, the number of steps to skip changes from 2 to 1 (2 / 2, rounded down).
load_checkpoint: './output/checkpoint'
resume_training: True
trainer:
  ignore_data_skip: False
  data_skip_steps: 1
Expected result: The model adjusts the number of skipped steps to the new global batch size and continues training from the specified position.
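The three scenarios differ only in the trainer section added to the basic resume configuration. The hypothetical helper below builds that configuration as a Python dict mirroring the YAML fragments above; build_resume_config and gbs_scale are illustrative names, and gbs_scale is assumed to be a positive integer scaling coefficient such as 2 for a doubled global batch size.

```python
def build_resume_config(scenario, trained_steps=0, gbs_scale=1):
    """Hypothetical helper mirroring the three documented scenarios."""
    config = {"load_checkpoint": "./output/checkpoint", "resume_training": True}
    if scenario == 1:
        # New dataset, read from the beginning: skip nothing.
        config["trainer"] = {"ignore_data_skip": True}
    elif scenario == 2:
        # New dataset, skip the steps already trained on it.
        config["trainer"] = {"ignore_data_skip": False, "data_skip_steps": trained_steps}
    elif scenario == 3:
        # Global batch size changed: divide by the coefficient, round down.
        if not isinstance(gbs_scale, int) or gbs_scale < 1:
            raise ValueError("gbs_scale must be a positive integer")
        config["trainer"] = {"ignore_data_skip": False,
                             "data_skip_steps": trained_steps // gbs_scale}
    else:
        raise ValueError(f"unknown scenario: {scenario}")
    return config


print(build_resume_config(3, trained_steps=2, gbs_scale=2)["trainer"])
# {'ignore_data_skip': False, 'data_skip_steps': 1}
```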
Fault Recovery Example
If some weight files are missing, the system automatically resumes training from the latest weight that is still available.
Delete the llama2_7b_rank_3-20_2.ckpt file from the rank_3 directory. The folder structure after the deletion is as follows:
checkpoint/rank_3
├── llama2_7b_rank_3-10_2.ckpt
├── llama2_7b_rank_3-15_2.ckpt
└── meta.json
Modify the configuration to enable fault recovery.
load_checkpoint: './output/checkpoint'
resume_training: True
Start distributed training.
bash scripts/msrun_launcher.sh "run_mindformer.py \
 --config configs/llama2/pretrain_llama2_7b.yaml \
 --train_dataset /path/to/wikitext2-llama2.mindrecord \
 --run_mode train \
 --use_parallel True" 4
If the first step number printed is 32, training has been resumed successfully. Because the step-40 weight under rank_3 was deleted, the previously saved weight, from step 30, is used automatically. The default value of sink_size is 2, meaning that information is printed every two steps, so the first printed step is 32.
Precautions
Data offloading: Distributed resumable training requires data offloading to be enabled, with sink_mode=True.
Weight file check: Ensure that the weights loaded for resumable training are the ones saved when training was interrupted, rather than weights saved elsewhere in the overall training process. Otherwise, an error is reported.