Model Transformation


Overview

Background

When using MindSpore for distributed training, it is often necessary to transform the distributed Checkpoint obtained from training (that is, to perform model transformation) before carrying out the next steps, such as inference, fine-tuning, and multi-stage training. This tutorial introduces how to transform the Checkpoint obtained from distributed training so that resilient training and inference can be carried out when the distributed strategy or the number of cards in the cluster changes.

This feature supports only the SEMI_AUTO_PARALLEL and AUTO_PARALLEL modes.

Usage Scenarios

If any of the following scenarios applies, follow this tutorial to perform resilient training and inference:

  • Scenario 1: Training on M cards and fine-tuning on N cards, where M and N need not be multiples of each other.

  • Scenario 2: Training is divided into multiple phases, each with a different cluster size.

  • Scenario 3: Training on M cards and running inference on N cards, where M and N need not be multiples of each other.

  • Scenario 4: The network sharding strategy needs to be changed.

Related interfaces:

  1. mindspore.transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix, src_strategy_file, dst_strategy_file): Transforms the Checkpoint of a distributed network from a source sharding strategy to a target sharding strategy. src_checkpoints_dir is the directory containing the source Checkpoint files, whose subdirectories must be organized as rank_x/checkpoint_x.ckpt, where x is the corresponding rank id. dst_checkpoints_dir is the directory where the target Checkpoint files are saved, ckpt_prefix is the name prefix of the target Checkpoint files ("checkpoint_" in the examples in this tutorial), src_strategy_file is the name of the source sharding strategy file, and dst_strategy_file is the name of the target sharding strategy file.

  2. mindspore.rank_list_for_transform(rank_id, src_strategy_file, dst_strategy_file): Returns the list of source ranks whose Checkpoint files are required to produce the Checkpoint file of the target rank rank_id during distributed Checkpoint transformation.

  3. mindspore.transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name, src_strategy_file, dst_strategy_file): Transforms the Checkpoint of a distributed network from a source sharding strategy to a target sharding strategy for one specific rank. rank_id is the rank number of the Checkpoint to be transformed; checkpoint_files_map is a dictionary of source Checkpoints whose keys are rank numbers and whose values are the paths of the corresponding Checkpoint files; save_checkpoint_file_name is the path and name of the target Checkpoint for the current rank.

  4. mindspore.load_segmented_checkpoints(ckpt_file_dir): Loads all .ckpt checkpoint files in the specified ckpt_file_dir path and returns a combined parameter dict.
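To build intuition for what rank_list_for_transform returns, and why M and N need not be multiples of each other, the sketch below models one tensor dimension split evenly into src_parts slices at training time and dst_parts slices afterwards, and lists the source slices that overlap a given target slice. It is a simplified model under stated assumptions (a single sharded dimension, even splits, slice index equal to rank id), not MindSpore's implementation, which also has to handle device matrices, replicated slices, and optimizer sharding. The helper names are hypothetical.

```python
def slice_bounds(dim, parts, index):
    """Hypothetical helper: [start, end) range of slice `index` when `dim` is split evenly into `parts`."""
    if dim % parts != 0:
        raise ValueError("dim must be divisible by parts in this simplified model")
    size = dim // parts
    return index * size, (index + 1) * size


def overlapping_source_slices(dim, src_parts, dst_parts, dst_index):
    """Hypothetical helper: indices of source slices whose ranges intersect target slice `dst_index`."""
    dst_start, dst_end = slice_bounds(dim, dst_parts, dst_index)
    needed = []
    for src_index in range(src_parts):
        src_start, src_end = slice_bounds(dim, src_parts, src_index)
        # Two half-open ranges intersect when each one starts before the other ends.
        if src_start < dst_end and dst_start < src_end:
            needed.append(src_index)
    return needed


# A 24-element dimension sharded 8 ways during training and 3 ways afterwards.
for dst in range(3):
    print(dst, overlapping_source_slices(24, 8, 3, dst))
# 0 [0, 1, 2]
# 1 [2, 3, 4, 5]
# 2 [5, 6, 7]
```

Each target slice depends on a small group of source slices even though 8 is not a multiple of 3; a dependency list of this kind is what checkpoint_files_map supplies to transform_checkpoint_by_rank.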

Operation Practice

Taking training on 8 Ascend cards and fine-tuning on 4 cards as an example, the overall procedure is as follows:

  1. Perform training, configure the storage location of the model parameter sharding strategy file, and automatically generate the Checkpoint file and the model parameter sharding strategy file.

  2. Configure the storage location of the distributed strategy file and compile the fine-tuning network, which slices the model parameters and automatically generates the target sharding strategy file.

  3. Transform the saved Checkpoint files based on the strategy files of the source (training) network and the target network.

  4. After compiling the fine-tuned network, load the distributed Checkpoint file obtained from the transformation.

  5. Execute the fine-tuned network.

Note that the network must be compiled before a distributed Checkpoint is loaded.

Example Code Description

Download the complete example code: model_saving_loading.

The directory structure is as follows:

└─ sample_code
    ├─ model_saving_loading
       ├── model_transformation_infer.py
       ├── model_transformation_retrain.py
       ├── pipeline_train.py
       ├── pipeline_transformation_retrain.py
       ├── run_infer_convert.sh
       ├── run_infer.sh
       ├── run_retrain_convert.sh
       ├── run_retrain.sh
       ├── run_pipeline_train.sh
       ├── run_retrain_pipeline_convert.sh
       ├── run_retrain_pipeline.sh
       ...
    ...

The functions of each file are as follows:

  • model_transformation_infer.py: Scripts for inference after model transformation.

  • model_transformation_retrain.py: Scripts for second-stage training after model transformation.

  • pipeline_transformation_retrain.py: Scripts for second-stage training after pipeline parallel model transformation.

  • pipeline_train.py: Scripts for pipeline parallel training of networks.

  • run_infer_convert.sh: Script that performs model transformation for inference.

  • run_retrain_convert.sh: Script that performs model transformation for second-stage training.

  • run_retrain_pipeline_convert.sh: Scripts that perform pipeline parallel model transformation.

  • run_infer.sh: Scripts that perform model inference.

  • run_retrain.sh: Scripts that perform second-stage training.

  • run_pipeline_train.sh: Scripts that perform pipeline training.

  • run_retrain_pipeline.sh: Script that performs second-stage training after pipeline parallel model transformation.

Saving the Distributed Model

First, follow the Model Saving tutorial to perform 8-card distributed training in SEMI_AUTO_PARALLEL or AUTO_PARALLEL mode, and configure the storage path of the model sharding strategy file by setting the strategy_ckpt_config parameter of the set_auto_parallel_context interface. After training for a period of time, use the train.ModelCheckpoint callback to save the distributed Checkpoint.

At the end of the training, the source Checkpoint file directory as well as the source sharding strategy file will be generated at the current path:

src_checkpoints/
src_strategy.ckpt

Subdirectories within src_checkpoints must be stored in the rank_x/checkpoint_x.ckpt format.

That is, the directory structure of src_checkpoints must be organized as follows:

src_checkpoints
 ├─ rank_0
 |   └─ checkpoint_0.ckpt
 ├─ rank_1
 |   └─ checkpoint_1.ckpt
 ├─ rank_2
 |   └─ checkpoint_2.ckpt
 ├─ rank_3
 |   └─ checkpoint_3.ckpt
 ├─ rank_4
 |   └─ checkpoint_4.ckpt
 ├─ rank_5
 |   └─ checkpoint_5.ckpt
 ├─ rank_6
 |   └─ checkpoint_6.ckpt
 └─ rank_7
     └─ checkpoint_7.ckpt
...
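ModelCheckpoint names files with its own prefix and step suffix (for example checkpoint-3_1875.ckpt in the pipeline example later in this tutorial), so the saved files usually have to be renamed into this layout. The sketch below copies the most recently modified .ckpt file of each rank_x directory to rank_x/checkpoint_x.ckpt under a new root directory. It assumes the source directories are already named rank_x; reorganize_checkpoints is a hypothetical helper, not a MindSpore API.

```python
import os
import re
import shutil


def reorganize_checkpoints(src_root, dst_root):
    """Hypothetical helper: copy the newest .ckpt in each src_root/rank_x to dst_root/rank_x/checkpoint_x.ckpt."""
    pattern = re.compile(r"^rank_(\d+)$")
    result = {}
    for entry in sorted(os.listdir(src_root)):
        match = pattern.match(entry)
        rank_dir = os.path.join(src_root, entry)
        if not match or not os.path.isdir(rank_dir):
            continue  # skip anything that is not a rank_x directory
        rank_id = int(match.group(1))
        ckpts = [os.path.join(rank_dir, f) for f in os.listdir(rank_dir) if f.endswith(".ckpt")]
        if not ckpts:
            raise FileNotFoundError("no .ckpt file found in " + rank_dir)
        newest = max(ckpts, key=os.path.getmtime)  # .meta and other files are ignored
        target_dir = os.path.join(dst_root, "rank_{}".format(rank_id))
        os.makedirs(target_dir, exist_ok=True)
        target = os.path.join(target_dir, "checkpoint_{}.ckpt".format(rank_id))
        shutil.copyfile(newest, target)
        result[rank_id] = target
    return result
```

For example, reorganize_checkpoints("./src_checkpoints_pipeline", "./src_checkpoints") would build the rank_x/checkpoint_x.ckpt layout from per-rank training output. Copying instead of moving keeps the original files intact at the cost of extra disk space.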

Generating Target Strategy Files

Next, compile the network under the new card count or sharding strategy to generate the model sharding strategy file of the target network. In this example, the source strategy trains on 8 cards, the ops.MatMul() operator parallel strategy of layer1 is ((2, 1), (1, 2)), optimizer parallel is disabled, and the strategy file is named src_strategy.ckpt. The target strategy trains on 4 cards, the ops.MatMul() operator parallel strategy of layer1 is ((1, 4), (4, 1)), optimizer parallel is enabled, and the strategy file is named dst_strategy.ckpt.

Configuring Distributed Environment

Specify the run mode via the context interface, configure the path for saving the distributed strategy file via strategy_ckpt_config, enable optimizer parallel, and initialize HCCL or NCCL communication via init.

import mindspore as ms
from mindspore.communication import init

ms.set_context(mode=ms.GRAPH_MODE)
ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
ms.set_auto_parallel_context(strategy_ckpt_config={"save_file": args_opt.dst_strategy_file})
ms.set_auto_parallel_context(enable_parallel_optimizer=True)
init()
ms.set_seed(1)
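The excerpts in this tutorial read paths from an args_opt object that the example scripts parse from the command line. A minimal sketch of such a parser is shown below, using the attribute names that appear in this tutorial's excerpts. The default values are assumptions chosen to match the file names used here, and the actual example scripts may define their arguments differently.

```python
import argparse


def build_parser():
    # Attribute names match those used in this tutorial's excerpts; all defaults are assumptions.
    parser = argparse.ArgumentParser(description="Distributed Checkpoint transformation (sketch)")
    parser.add_argument("--src_checkpoints_dir", default="./src_checkpoints")
    parser.add_argument("--dst_checkpoints_dir", default="./dst_checkpoints")
    parser.add_argument("--src_strategy_file", default="./src_strategy.ckpt")
    parser.add_argument("--dst_strategy_file", default="./dst_strategy.ckpt")
    parser.add_argument("--src_strategy_dir", default="./src_pipeline_strategys")
    parser.add_argument("--only_compile", type=int, default=0)
    return parser


# An empty list is passed here only for demonstration; a real script calls parse_args() to read sys.argv.
args_opt = build_parser().parse_args([])
```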

Network Definition and Loading the Dataset

The network definition modifies the ops.MatMul() operator parallel strategy for layer1 in the original network:

import os
import mindspore as ms
from mindspore import nn, ops
import mindspore.dataset as ds
from mindspore.common.initializer import initializer

class Dense(nn.Cell):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.weight = ms.Parameter(initializer("normal", [in_channels, out_channels], ms.float32))
        self.bias = ms.Parameter(initializer("normal", [out_channels], ms.float32))
        self.matmul = ops.MatMul()
        self.add = ops.Add()

    def construct(self, x):
        x = self.matmul(x, self.weight)
        x = self.add(x, self.bias)
        return x

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = ops.Flatten()
        self.layer1 = Dense(28*28, 512)
        self.relu1 = ops.ReLU()
        self.layer2 = Dense(512, 512)
        self.relu2 = ops.ReLU()
        self.layer3 = Dense(512, 10)

    def construct(self, x):
        x = self.flatten(x)
        x = self.layer1(x)
        x = self.relu1(x)
        x = self.layer2(x)
        x = self.relu2(x)
        logits = self.layer3(x)
        return logits

net = Network()
net.layer1.matmul.shard(((1, 4), (4, 1)))
net.layer3.matmul.shard(((2, 2), (2, 1)))

def create_dataset(batch_size):
    dataset_path = os.getenv("DATA_PATH")
    dataset = ds.MnistDataset(dataset_path)
    image_transforms = [
        ds.vision.Rescale(1.0 / 255.0, 0),
        ds.vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        ds.vision.HWC2CHW()
    ]
    label_transform = ds.transforms.TypeCast(ms.int32)
    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset

data_set = create_dataset(32)
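The shard strategies determine the shape of the parameter slice that each card holds, which is why a Checkpoint saved under one strategy cannot be loaded directly under another. The sketch below computes slice shapes by dividing each tensor dimension by its sharding factor; for MatMul, the second tuple of the strategy applies to the weight. It is a simplified model that assumes even division and ignores the additional splitting that optimizer parallel may apply. slice_shape is a hypothetical helper.

```python
def slice_shape(full_shape, strategy):
    """Hypothetical helper: per-card slice shape of a tensor for one sharding strategy tuple."""
    if len(full_shape) != len(strategy):
        raise ValueError("strategy must have one factor per tensor dimension")
    shape = []
    for dim, factor in zip(full_shape, strategy):
        if dim % factor != 0:
            raise ValueError("dimension {} is not divisible by {}".format(dim, factor))
        shape.append(dim // factor)
    return tuple(shape)


layer1_weight = (28 * 28, 512)  # Dense(28*28, 512) weight: [in_channels, out_channels]
# Weight part (1, 2) of the source strategy ((2, 1), (1, 2)) described above.
print(slice_shape(layer1_weight, (1, 2)))  # (784, 256)
# Weight part (4, 1) of net.layer1.matmul.shard(((1, 4), (4, 1))) in the code above.
print(slice_shape(layer1_weight, (4, 1)))  # (196, 512)
```

A source slice of shape (784, 256) and a target slice of shape (196, 512) cover different regions of the full weight, so the transformation has to reassemble and re-cut the data rather than copy files.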

Performing Compilation on the Target Network

The distributed Checkpoint transformation depends on both the source and the target distributed strategy files. The source strategy file is saved while the network is trained under the original strategy, so the strategy file for the target strategy has to be obtained separately, by compiling the network under the target strategy. Compilation alone can be performed through the model.infer_train_layout interface.

import mindspore as ms
from mindspore import nn, ops

optimizer = nn.SGD(net.trainable_params(), 1e-2)
loss_fn = nn.CrossEntropyLoss()
model = ms.Model(net, loss_fn=loss_fn, optimizer=optimizer)
model.infer_train_layout(data_set)

When the target network is to perform inference, model.infer_train_layout is replaced with model.infer_predict_layout to perform compilation:

import numpy as np
import mindspore as ms

predict_data = ms.Tensor(np.random.randn(1, 28, 28).astype(np.float32))
model = ms.Model(net)
model.infer_predict_layout(predict_data)

After compilation, you can get the target sharding strategy file dst_strategy.ckpt.

Executing Distributed Checkpoint Transformation

In this step, call a distributed Checkpoint transformation interface. Two interfaces are provided for distributed Checkpoint transformation.

The first interface, transform_checkpoints, requires the user to place all checkpoints in a single directory, and the subdirectories must be named in the format "rank_0, rank_1, rank_2, …". The user calls this interface to transform the entire directory directly, which is easier to use, but the transformation requires a slightly higher memory overhead.

The second interface, transform_checkpoint_by_rank, produces the Checkpoint for a single target rank. It is more flexible and has lower memory overhead, and it must be used together with the rank_list_for_transform interface, which determines which source Checkpoints are needed to produce the target Checkpoint of that rank.

  1. Use the interface transform_checkpoints.

    import mindspore as ms
    
    ms.transform_checkpoints(args_opt.src_checkpoints_dir, args_opt.dst_checkpoints_dir, "checkpoint_", args_opt.src_strategy_file, args_opt.dst_strategy_file)
    

    The method is used in the example code model_transformation_retrain.py.

  2. Call the transform_checkpoint_by_rank interface to perform a parameter merge on the original Checkpoint corresponding to the current rank.

    Ensure that the subdirectory "rank_x" exists in dst_checkpoints_dir.

    import os
    import mindspore as ms
    from mindspore.communication import get_rank
    
    rank_list = ms.rank_list_for_transform(get_rank(), args_opt.src_strategy_file, args_opt.dst_strategy_file)
    checkpoint_file_map = {}
    for rank_id in rank_list:
        checkpoint_file_map[rank_id] = os.path.join(args_opt.src_checkpoints_dir, "rank_{}".format(rank_id), "checkpoint_{}.ckpt".format(rank_id))
    save_checkpoint_path = os.path.join(args_opt.dst_checkpoints_dir, "rank_{}".format(get_rank()), "checkpoint_{}.ckpt".format(get_rank()))
    # Create dst_checkpoints_dir/rank_x if it does not exist yet
    os.makedirs(os.path.dirname(save_checkpoint_path), exist_ok=True)
    ms.transform_checkpoint_by_rank(get_rank(), checkpoint_file_map, save_checkpoint_path, args_opt.src_strategy_file, args_opt.dst_strategy_file)
    

    The method is used in the example code model_transformation_infer.py.

After execution, a directory of transformed target Checkpoint files will be generated:

dst_checkpoints/

Checkpoint Conversion with Multiple Processes

When writing code to convert Checkpoint files, users often adopt serial logic like the following, which can make the conversion take several hours. This approach has a parallelism of 1 and is not recommended.

import os
import mindspore as ms

# Serial conversion: one process converts every target rank in turn (not recommended)
dst_device_num = 8
for tgt_rank in range(dst_device_num):
    rank_list = ms.rank_list_for_transform(tgt_rank, args_opt.src_strategy_file, args_opt.dst_strategy_file)
    checkpoint_file_map = {}
    for rank_id in rank_list:
        checkpoint_file_map[rank_id] = os.path.join(args_opt.src_checkpoints_dir, "rank_{}".format(rank_id), "checkpoint_{}.ckpt".format(rank_id))
    # Save under the target rank being converted, not the rank of the current process
    save_checkpoint_path = os.path.join(args_opt.dst_checkpoints_dir, "rank_{}".format(tgt_rank), "checkpoint_{}.ckpt".format(tgt_rank))
    os.makedirs(os.path.dirname(save_checkpoint_path), exist_ok=True)
    ms.transform_checkpoint_by_rank(tgt_rank, checkpoint_file_map, save_checkpoint_path, args_opt.src_strategy_file, args_opt.dst_strategy_file)

Checkpoint conversion can be accelerated with multiple processes. The change is as follows: each process decides which target rank's weight file to convert based on its own rank_id, so each process converts only one weight file.

Set the number of processes according to the memory needed to convert a single weight file and the total memory of the host; otherwise, the host may run out of memory.

import os
import mindspore as ms
from mindspore.communication import get_rank

# Each process converts only the target rank equal to its own rank id.
# Assumes init() has already been called, so get_rank() returns this process's rank.
rank_list = ms.rank_list_for_transform(get_rank(), args_opt.src_strategy_file, args_opt.dst_strategy_file)
checkpoint_file_map = {}
for rank_id in rank_list:
    checkpoint_file_map[rank_id] = os.path.join(args_opt.src_checkpoints_dir, "rank_{}".format(rank_id), "checkpoint_{}.ckpt".format(rank_id))
save_checkpoint_path = os.path.join(args_opt.dst_checkpoints_dir, "rank_{}".format(get_rank()), "checkpoint_{}.ckpt".format(get_rank()))
os.makedirs(os.path.dirname(save_checkpoint_path), exist_ok=True)
ms.transform_checkpoint_by_rank(get_rank(), checkpoint_file_map, save_checkpoint_path, args_opt.src_strategy_file, args_opt.dst_strategy_file)

In addition, change the startup method from single-process to multi-process. The following example performs the conversion with 4 processes:

mpirun -n 4 --output-filename log_output --merge-stderr-to-stdout python model_transformation_infer.py --only_compile=1

The code used in this tutorial already adopts the multi-process style. This section highlights how to avoid slow conversion in everyday use.
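Two small calculations help when planning a multi-process conversion: how many processes fit in host memory, and which target ranks each process handles if there are fewer processes than target ranks. The sketch below shows both. The memory figures are hypothetical inputs you would measure yourself, and the strided assignment (process p converts target ranks p, p + k, p + 2k, ...) is an assumed scheme for illustration; the example scripts instead run one process per target rank. Both function names are hypothetical.

```python
def max_conversion_processes(host_mem_gb, mem_per_conversion_gb, num_targets, reserve_gb=0.0):
    """Hypothetical helper: largest process count that fits in memory, capped at the number of target ranks."""
    usable = host_mem_gb - reserve_gb
    if mem_per_conversion_gb <= 0 or usable < mem_per_conversion_gb:
        raise ValueError("not enough memory for even one conversion process")
    return min(num_targets, int(usable // mem_per_conversion_gb))


def targets_for_process(process_rank, num_processes, num_targets):
    """Hypothetical strided assignment: process p converts target ranks p, p + k, p + 2k, ..."""
    return list(range(process_rank, num_targets, num_processes))


# Hypothetical figures: 8 target ranks, 512 GB host memory, about 100 GB per conversion, 64 GB reserved.
procs = max_conversion_processes(512, 100, 8, reserve_gb=64)
print(procs)  # 4
for p in range(procs):
    print(p, targets_for_process(p, procs, 8))
```

Under this scheme, each process would run the per-rank conversion shown above once for every target rank in its list, so no target rank is converted twice and none is skipped.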

Loading the Transformed Checkpoint Files

The network for the target strategy is compiled and the load_checkpoint interface is called to load the model parameter data from the transformed Checkpoint file.

Compile the network with the model.infer_train_layout (training) or model.infer_predict_layout (inference) interface, during which the weight shapes are sliced. Then call the load_checkpoint interface to load each card's model parameter data from its Checkpoint file.

When the target network is used for training:

import os
import mindspore as ms
from mindspore import train
from mindspore.communication import get_rank

save_checkpoint_path = os.path.join(args_opt.dst_checkpoints_dir, "rank_{}".format(get_rank()), "checkpoint_{}.ckpt".format(get_rank()))
loss_cb = train.LossMonitor(20)
# Compile first so that the parameter shapes are sliced for the current rank
model.infer_train_layout(data_set)
param_dict = ms.load_checkpoint(save_checkpoint_path)
ms.load_param_into_net(net, param_dict)
model.train(2, data_set, callbacks=[loss_cb])

  • save_checkpoint_path: The path of the Checkpoint model parameter file of the current rank to be loaded.

When the target network is used for inference:

import os
import mindspore as ms
from mindspore.communication import get_rank

save_checkpoint_path = os.path.join(args_opt.dst_checkpoints_dir, "rank_{}".format(get_rank()), "checkpoint_{}.ckpt".format(get_rank()))
param_dict = ms.load_checkpoint(save_checkpoint_path)
# Compile the inference network so that the parameter shapes are sliced for the current rank
model.infer_predict_layout(predict_data)
ms.load_param_into_net(net, param_dict)
predict_result = model.predict(predict_data)
print(predict_result)

  • predict_data: Tensor data for inference.

Running the Single-Machine 4-Card Scripts

Next, run the corresponding scripts to perform second-stage fine-tuning after model transformation. Taking a 4-card distributed script launched with mpirun as an example:

bash run_retrain_convert.sh
bash run_retrain.sh

Or, to run inference after model transformation:

bash run_infer_convert.sh
bash run_infer.sh

After execution, the log files are saved in the log_output directory, the target Checkpoint files in the dst_checkpoints folder, and the target strategy file in dst_strategy.ckpt. The directory structure is as follows:

├─ src_strategy.ckpt
├─ dst_strategy.ckpt
├─ log_output
|   └─ 1
|       ├─ rank.0
|       |   └─ stdout
|       ├─ rank.1
|       |   └─ stdout
|       ...
├─ dst_checkpoints
|   ├─ rank_0
|   |   └─ checkpoint_0.ckpt
|   ├─ rank_1
|   |   └─ checkpoint_1.ckpt
|   |   ...
|   ...
...

Part of the loss output from second-stage fine-tuning is saved in log_output/1/rank.*/stdout, for example:

epoch: 1, step: 20, loss is 0.10617774
epoch: 1, step: 40, loss is 0.06953259
epoch: 1, step: 60, loss is 0.08409108
epoch: 1, step: 80, loss is 0.08699021
epoch: 1, step: 100, loss is 0.07113413
...
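To compare loss values across ranks or runs, the log lines can be parsed with a regular expression. The sketch below assumes the "epoch: E, step: S, loss is L" format shown above; parse_losses is a hypothetical helper.

```python
import re

LOSS_LINE = re.compile(r"epoch:\s*(\d+),\s*step:\s*(\d+),\s*loss is\s*([-+0-9.eE]+)")


def parse_losses(lines):
    """Hypothetical helper: return (epoch, step, loss) tuples from LossMonitor-style log lines."""
    records = []
    for line in lines:
        match = LOSS_LINE.search(line)
        if match:
            records.append((int(match.group(1)), int(match.group(2)), float(match.group(3))))
    return records


sample = [
    "epoch: 1, step: 20, loss is 0.10617774",
    "epoch: 1, step: 40, loss is 0.06953259",
    "unrelated log line",
]
print(parse_losses(sample))  # [(1, 20, 0.10617774), (1, 40, 0.06953259)]
```

Because the function accepts any iterable of lines, an open file such as log_output/1/rank.0/stdout can be passed to it directly.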

In the case of an inference task, the results are saved in log_output/1/rank.*/stdout, as exemplified below:

[[ 0.05044775 -0.94413316  0.84689134 -0.2881832   0.66444755  1.0564336
  -0.04191193  0.25590348 -0.690101   -0.6532427 ]]

Pipeline Parallel Model Transformation

Pipeline parallelism slices a linear network into multiple sub-networks that run as a pipeline across multiple cards. As a result, the sharding strategy file saved for each subgraph is different, and only all of them together give the complete slicing information of the network. Therefore, compared with transformations along other dimensions, transforming along the pipeline parallel dimension requires one extra step: aggregate the sharding strategy files into a single file and use that file as the strategy file for the distributed Checkpoint transformation. Otherwise, the procedure is the same as in Executing Distributed Checkpoint Transformation.

Related interfaces:

mindspore.merge_pipeline_strategys(src_strategy_dirs, dst_strategy_file): Aggregates the sharding strategy files of all pipeline-parallel subgraphs in pipeline parallel mode. src_strategy_dirs is the directory containing the sharding strategy files of all pipeline-parallel subgraphs, which are saved through the mindspore.set_auto_parallel_context(strategy_ckpt_config) interface. dst_strategy_file is the path of the file where the aggregated sharding strategy is saved.

First, run 8-card pipeline parallel training with a pipeline parallel dimension of 2 and optimizer parallel enabled.

The training code is in pipeline_train.py. The network structure is the one from the Model Saving chapter with a pipeline parallel configuration (pipeline parallel dimension 2) added.

The core code is:

import mindspore as ms
from mindspore import nn, train
from mindspore.communication import init, get_rank
...
ms.set_context(mode=ms.GRAPH_MODE)
ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
ms.set_auto_parallel_context(pipeline_stages=2, enable_parallel_optimizer=True)
init()
ms.set_auto_parallel_context(strategy_ckpt_config={"save_file": "./src_pipeline_strategys/src_strategy_{}.ckpt".format(get_rank())})
ms.set_seed(1)
...
ckpt_config = train.CheckpointConfig(save_checkpoint_steps=1000, keep_checkpoint_max=1, integrated_save=False)
ckpoint_cb = train.ModelCheckpoint(prefix="checkpoint",
                                   directory="./src_checkpoints_pipeline/rank_{}".format(get_rank()),
                                   config=ckpt_config)
net_with_grads = nn.PipelineCell(nn.WithLossCell(net, loss_fn), 4)
model = ms.Model(net_with_grads, optimizer=optimizer)
model.train(3, data_set, callbacks=[loss_cb, ckpoint_cb])

Because the sharding strategy file differs from card to card, each card saves its own file in the "src_pipeline_strategys/src_strategy_x.ckpt" format.

The command to run the 8-card training script is:

bash run_pipeline_train.sh

After execution, the source Checkpoint file directory and the source sharding strategy file will be generated with the file directory structure:

├─ src_checkpoints_pipeline
|   ├─ rank_0
|   |   ├─ checkpoint-3_1875.ckpt
|   |   └─ checkpoint-graph.meta
|   ├─ rank_1
|   |   ├─ checkpoint-3_1875.ckpt
|   |   ...
|   ...
├─ src_pipeline_strategys
|   ├─ src_strategy_0.ckpt
|   ├─ src_strategy_1.ckpt
|   ├─ src_strategy_2.ckpt
|   ├─ src_strategy_3.ckpt
|   ├─ src_strategy_4.ckpt
|   ├─ src_strategy_5.ckpt
|   ├─ src_strategy_6.ckpt
|   └─ src_strategy_7.ckpt
...

Refer to the Performing Compilation on the Target Network section and compile the target network in the same way to obtain its sharding strategy file.

The next step performs a distributed Checkpoint transformation that involves the pipeline parallel dimension: first merge the sharding strategy files obtained from pipeline training with the merge_pipeline_strategys interface, and then perform the distributed Checkpoint transformation with the transform_checkpoints or transform_checkpoint_by_rank interface.

The example below uses transform_checkpoints; for how to use transform_checkpoint_by_rank, refer to Executing Distributed Checkpoint Transformation.

import mindspore as ms

ms.merge_pipeline_strategys(args_opt.src_strategy_dir, args_opt.src_strategy_file)
ms.transform_checkpoints(args_opt.src_checkpoints_dir, args_opt.dst_checkpoints_dir, "checkpoint_", args_opt.src_strategy_file, args_opt.dst_strategy_file)

Subdirectories within src_checkpoints_dir are required to be stored in the format "rank_x/checkpoint_x.ckpt".

The example script command to transform the entire Checkpoint directory is:

bash run_retrain_pipeline_convert.sh

After the transformation is completed, refer to the Loading the Transformed Checkpoint Files section to run the target distributed network, which has no pipeline dimension.

In the example, the script execution command for loading the transformed Checkpoint for second-stage fine-tuning training is:

bash run_retrain_pipeline.sh

After execution, you can see that the loss decreases from about 0.15:

epoch: 1, step: 20, loss is 0.15090162
epoch: 1, step: 40, loss is 0.13296325
epoch: 1, step: 60, loss is 0.14676111
epoch: 1, step: 80, loss is 0.11930083
epoch: 1, step: 100, loss is 0.0784434
epoch: 1, step: 120, loss is 0.10741685

Pipeline Parallel Subnetwork Model Transformation

The transform_checkpoints interface also supports transforming the checkpoint of a single subnetwork when the source strategy uses pipeline parallel. Unlike the pipeline model transformation above, the strategy file of a single pipeline subnetwork is used directly as the source strategy file, without first aggregating the strategy files. With subnetwork model transformation, each transformation produces the parameters of that subnetwork for all target ranks, and the suffix _part* in the output file name identifies which pipeline subnetwork the checkpoint file was converted from.

Take the training network from the Pipeline Parallel Model Transformation section above as an example; the training script and its usage are not described again. After execution, the source checkpoint file directory and the source strategy files are generated with the following directory structure:

├─ src_checkpoints_pipeline
|   ├─ rank_0
|   |   ├─ checkpoint-3_1875.ckpt
|   |   └─ checkpoint-graph.meta
|   ├─ rank_1
|   |   ├─ checkpoint-3_1875.ckpt
|   |   ...
|   ...
├─ src_pipeline_strategys
|   ├─ src_strategy_0.ckpt
|   ├─ src_strategy_1.ckpt
|   ├─ src_strategy_2.ckpt
|   ├─ src_strategy_3.ckpt
|   ├─ src_strategy_4.ckpt
|   ├─ src_strategy_5.ckpt
|   ├─ src_strategy_6.ckpt
|   └─ src_strategy_7.ckpt
...

Refer to the Performing Compilation on the Target Network section and compile the target network to obtain its strategy file.

In the training parallel strategy, the pipeline parallel dimension is 2, so the network is divided into two subnetworks. With 8 cards, ranks 0-3 run the first subnetwork and ranks 4-7 run the second, so one strategy file from each, src_strategy_0.ckpt and src_strategy_4.ckpt, is used for checkpoint transformation. The transform_checkpoints interface is called once per subnetwork:

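import os
import mindspore as ms

# One strategy file from each pipeline stage
stage_strategy_list = ['src_strategy_0.ckpt', 'src_strategy_4.ckpt']
for strategy_name in stage_strategy_list:
    # The per-stage strategy files live in the pipeline strategy directory
    src_strategy_file = os.path.join(args_opt.src_strategy_dir, strategy_name)
    ms.transform_checkpoints(args_opt.src_checkpoints_dir, args_opt.dst_checkpoints_dir, "checkpoint_", src_strategy_file, args_opt.dst_strategy_file)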
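The list above takes one strategy file from each pipeline stage by hand. Assuming ranks are assigned to stages in contiguous blocks of device_num / pipeline_stages ranks (ranks 0-3 form stage 0 and ranks 4-7 form stage 1 in this 8-card, 2-stage example), the list can be derived instead, as in the sketch below. stage_strategy_files is a hypothetical helper, and the file-name pattern is the one used in this tutorial.

```python
import os


def stage_strategy_files(strategy_dir, device_num, pipeline_stages, pattern="src_strategy_{}.ckpt"):
    """Hypothetical helper: one strategy file per pipeline stage, taken from the first rank of each stage."""
    if device_num % pipeline_stages != 0:
        raise ValueError("device_num must be divisible by pipeline_stages")
    ranks_per_stage = device_num // pipeline_stages
    return [os.path.join(strategy_dir, pattern.format(stage * ranks_per_stage))
            for stage in range(pipeline_stages)]


print(stage_strategy_files("./src_pipeline_strategys", 8, 2))
```

If the contiguous-block assumption does not hold for a given configuration, the hand-written list remains the safer choice.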

Currently, only multi-card to multi-card checkpoint transformation is supported: src_strategy_file=None and dst_strategy_file=None are not supported, and transformation to a target strategy that uses pipeline parallel is not supported.

The directory structure of the checkpoint file obtained after transformation is:

├─ dst_checkpoints_dir
|   ├─ rank_0
|   |   ├─ checkpoint_0_part0.ckpt
|   |   └─ checkpoint_0_part1.ckpt
|   ├─ rank_1
|   |   ├─ checkpoint_1_part0.ckpt
|   |   └─ checkpoint_1_part1.ckpt
|   ...

Because the checkpoint transformation is carried out per subnetwork, each subnetwork produces one checkpoint file for each target rank. The directory rank_* stores the checkpoint files transformed from all subnetworks for the corresponding target rank; for example, checkpoint_0_part0.ckpt is the checkpoint of target rank 0 transformed from subnetwork 0.

All the weight files under a rank_* path obtained through single-subnetwork model transformation together make up the complete weights of the target strategy for that rank, so all files under the path must be read when loading. The load_segmented_checkpoints interface loads and aggregates all checkpoint files under a specified path. The example code is as follows:

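import os
import mindspore as ms
from mindspore.communication import get_rank

net = Network()
# As with any distributed Checkpoint, compile the network for the target strategy before loading

# All checkpoint_x_part*.ckpt files of the current rank together form its complete weights
checkpoint_file_dir = os.path.join(args_opt.dst_checkpoints_dir, "rank_{}".format(get_rank()))
param_dict = ms.load_segmented_checkpoints(checkpoint_file_dir)
param_not_load, _ = ms.load_param_into_net(net, param_dict)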
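Because a missing subnetwork file would leave some parameters unloaded, it can be worth checking that every part is present before calling load_segmented_checkpoints. The sketch below assumes the checkpoint_x_partk.ckpt naming shown above and one part per pipeline stage; missing_parts is a hypothetical helper.

```python
import os
import re


def missing_parts(rank_dir, rank_id, num_parts):
    """Hypothetical helper: part indices k with no checkpoint_{rank_id}_part{k}.ckpt file in rank_dir."""
    pattern = re.compile(r"^checkpoint_{}_part(\d+)\.ckpt$".format(rank_id))
    found = set()
    for name in os.listdir(rank_dir):
        match = pattern.match(name)
        if match:
            found.add(int(match.group(1)))
    return [k for k in range(num_parts) if k not in found]
```

For the two-stage example, missing_parts(os.path.join(args_opt.dst_checkpoints_dir, "rank_0"), 0, 2) would return an empty list when both checkpoint_0_part0.ckpt and checkpoint_0_part1.ckpt exist; a non-empty result is a signal to rerun the transformation for the listed subnetworks.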