Source code for mindspore.parallel.checkpoint_transform
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transform distributed checkpoint"""
from __future__ import absolute_import
import os
import glob
import copy
from collections import defaultdict
import numpy as np
import mindspore as ms
from mindspore.parallel._parallel_serialization import _rank_list_for_transform_parallel_checkpoint, \
_transform_parallel_checkpoint, _get_device_num_from_strategy, _make_dir, \
_extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \
_merge_protobuf_strategy, _merge_json_strategy
__all__ = ["merge_pipeline_strategys", "rank_list_for_transform", "transform_checkpoint_by_rank",
"transform_checkpoints"]


def merge_pipeline_strategys(src_strategy_dirs, dst_strategy_file):
    """
    Merge the parallel strategy files of all pipeline stages in pipeline parallel mode.

    For more details about converting distributed checkpoints, please refer to
    `Distributed Resilience Training and
    Inference <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/resilience_train_and_predict.html>`_.

    Note:
        The strategy file of every pipeline stage should be included in `src_strategy_dirs`.

    Args:
        src_strategy_dirs (str): The directory containing the strategy files of all pipeline stages, which are
            saved by 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
        dst_strategy_file (str): The file to which the merged strategy is saved.

    Raises:
        NotADirectoryError: `src_strategy_dirs` is not a directory.

    Examples:
        >>> # src_strategy_dir/stra0.ckpt, src_strategy_dir/stra1.ckpt ... src_strategy_dir/stra127.ckpt
        >>> merge_pipeline_strategys("./src_strategy_dir", "./dst_strategy.ckpt")
    """
dst_strategy_dir, _ = os.path.split(dst_strategy_file)
if not os.path.exists(dst_strategy_dir):
_make_dir(dst_strategy_dir, "path")
if not os.path.isdir(src_strategy_dirs):
raise NotADirectoryError("src_strategy_dirs {} is not a directory.".format(src_strategy_dirs))
src_strategy_files_protobuf = glob.glob(os.path.join(src_strategy_dirs, "*.ckpt"))
src_strategy_files_json = glob.glob(os.path.join(src_strategy_dirs, "*.json"))
if src_strategy_files_protobuf and src_strategy_files_json:
        raise ValueError("The strategy file format should be all '.ckpt' or all '.json'.")
is_protobuf = len(src_strategy_files_protobuf) > 0
if is_protobuf:
_merge_protobuf_strategy(src_strategy_files_protobuf, dst_strategy_file)
else:
_merge_json_strategy(src_strategy_files_json, dst_strategy_file)
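

# Illustrative sketch (not part of the public API): merge the per-stage strategy files that each
# pipeline stage saved via 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'. The
# directory "./src_strategy_dir" and its contents (stra0.ckpt ... straN.ckpt) are hypothetical.
def _example_merge_pipeline_strategys():
    """Merge every pipeline stage's strategy file into a single strategy file (illustrative only)."""
    src_strategy_dir = "./src_strategy_dir"
    # Optional sanity check: every pipeline stage should have written one strategy file here.
    stage_files = glob.glob(os.path.join(src_strategy_dir, "*.ckpt"))
    print("found {} pipeline-stage strategy files".format(len(stage_files)))
    merge_pipeline_strategys(src_strategy_dir, "./dst_strategy.ckpt")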


def rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=None):
    """
    List the source distributed checkpoint ranks that are required to obtain the target checkpoint of `rank_id`
    during distributed checkpoint conversion. For more details about converting distributed checkpoints,
    please refer to `Distributed Resilience Training and
    Inference <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/resilience_train_and_predict.html>`_.

    Args:
        rank_id (int): The rank for which the distributed checkpoint is obtained after conversion.
        src_strategy_file (str): Name of the source sharding strategy file, which is saved by
            'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
            When `src_strategy_file` is None, it means that the source sharding strategy is
            without any sharding for each parameter. Default: None.
        dst_strategy_file (str): Name of the destination sharding strategy file, which is saved by
            'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
            When `dst_strategy_file` is None, it means that the destination sharding strategy
            is without any sharding for each parameter. Default: None.

    Returns:
        List, the rank list required for converting the distributed checkpoint of `rank_id`.

    Raises:
        ValueError: `src_strategy_file` or `dst_strategy_file` is incorrect.
        TypeError: `src_strategy_file` or `dst_strategy_file` is not a string.
        TypeError: `rank_id` is not an int.

    Examples:
        >>> rank_id = 0
        >>> rank_list = rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
        >>> checkpoint_files_map = {}
        >>> for rank in rank_list:
        ...     checkpoint_files_map[rank] = "./pangu{}-100_2.ckpt".format(rank)
    """
    if not isinstance(rank_id, int):
        raise TypeError("The rank_id should be an int.")
    if src_strategy_file is None:
        return [0]
    src_strategy_list, dst_strategy_list = _extract_src_dst_layout_map(rank_id, src_strategy_file, dst_strategy_file)
    if not src_strategy_list:
        raise ValueError("The src_strategy_file is empty.")
    # The device number within one pipeline stage is derived from the device arrangement of a parameter layout.
    src_stage_device_num = np.prod(src_strategy_list.get(list(src_strategy_list.keys())[0])[0])
    dst_stage_device_num = np.prod(dst_strategy_list.get(list(dst_strategy_list.keys())[0])[0]) if dst_strategy_list \
        is not None else 1
local_rank_id = rank_id % dst_stage_device_num if dst_stage_device_num > 1 else rank_id
needed_rank_list_in_local_stage = _rank_list_for_transform_parallel_checkpoint(local_rank_id,
src_strategy_list, dst_strategy_list)
result_set = set()
handled_pipeline_stage = []
for _, layout in src_strategy_list.items():
for src_pipeline_stage_id in layout[6]:
if src_pipeline_stage_id in handled_pipeline_stage:
continue
src_rank_id_start = src_pipeline_stage_id * src_stage_device_num
result_set.update([src_rank_id_start + rank for rank in needed_rank_list_in_local_stage])
handled_pipeline_stage.append(src_pipeline_stage_id)
result_list = list(result_set)
result_list.sort(reverse=True)
return result_list
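

# Illustrative sketch (not part of the public API): build the checkpoint-file map that a later call
# to transform_checkpoint_by_rank expects, keyed by the source ranks returned above. The per-rank
# file naming pattern "./origin_checkpoint_rank*/pangu*-100_2.ckpt" is hypothetical.
def _example_build_checkpoint_files_map(rank_id):
    """Map each required source rank to its checkpoint file for the given destination rank (illustrative only)."""
    rank_list = rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
    checkpoint_files_map = {}
    for rank in rank_list:
        # One checkpoint file per source rank, stored under a per-rank directory (hypothetical layout).
        checkpoint_files_map[rank] = "./origin_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank, rank)
    return checkpoint_files_map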


def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name,
                                 src_strategy_file=None, dst_strategy_file=None):
    """
    Transform a distributed checkpoint from the source sharding strategy to the destination sharding strategy
    for the given rank of a network. For more details about converting distributed checkpoints, please refer to
    `Distributed Resilience Training and
    Inference <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/resilience_train_and_predict.html>`_.

    Args:
        rank_id (int): The rank for which the distributed checkpoint is obtained after conversion.
        checkpoint_files_map (dict): The checkpoint files map whose key is the rank id and whose value is
            the checkpoint file name.
        save_checkpoint_file_name (str): The file name to save the converted checkpoint.
        src_strategy_file (str): Name of the source sharding strategy file, which is saved by
            'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
            When `src_strategy_file` is None, it means that the source sharding strategy is
            without any sharding for each parameter. Default: None.
        dst_strategy_file (str): Name of the destination sharding strategy file, which is saved by
            'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
            When `dst_strategy_file` is None, it means that the destination sharding strategy
            is without any sharding for each parameter. Default: None.

    Raises:
        ValueError: `src_strategy_file` or `dst_strategy_file` is incorrect.
        ValueError: An item in `checkpoint_files_map` is incorrect.
        ValueError: `save_checkpoint_file_name` does not end with ".ckpt".
        TypeError: `checkpoint_files_map` is not a dict.
        TypeError: `src_strategy_file` or `dst_strategy_file` is not a string.
        TypeError: `rank_id` is not an int.
        TypeError: `save_checkpoint_file_name` is not a string.

    Examples:
        >>> dst_device_num = 8
        >>> for rank_id in range(dst_device_num):
        ...     rank_list = rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
        ...     checkpoint_files_map = {}
        ...     for rank in rank_list:
        ...         checkpoint_files_map[rank] = "./origin_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank, rank)
        ...     save_checkpoint_file_name = "./new_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank_id, rank_id)
        ...     transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name,
        ...                                  "./src_strategy.ckpt", "./dst_strategy.ckpt")
    """
    if not isinstance(checkpoint_files_map, dict):
        raise TypeError("The checkpoint_files_map should be a dict.")
    if not isinstance(rank_id, int):
        raise TypeError("The rank_id should be an int.")
    if not isinstance(save_checkpoint_file_name, str):
        raise TypeError("The save_checkpoint_file_name should be a str.")
    if not save_checkpoint_file_name.endswith(".ckpt"):
        raise ValueError("The save_checkpoint_file_name {} should end with '.ckpt'.".format(save_checkpoint_file_name))
    if dst_strategy_file and os.path.dirname(dst_strategy_file) and not os.path.exists(
            os.path.dirname(dst_strategy_file)):
        raise ValueError("The directory of dst_strategy_file: {} does not exist.".
                         format(os.path.dirname(dst_strategy_file)))
    for rank, local_file in checkpoint_files_map.items():
        if not os.path.exists(local_file):
            raise ValueError("The checkpoint file {} of rank {} does not exist.".format(local_file, rank))
param_total_dict = defaultdict(dict)
param_attr_dict = defaultdict(dict)
src_strategy_list, dst_strategy_list = _extract_src_dst_layout_map(rank_id, src_strategy_file, dst_strategy_file)
# src rank => local rank inside pipeline stage
src_stage_device_num = np.prod(src_strategy_list.get(list(src_strategy_list.keys())[0])[0]) if src_strategy_list \
is not None else 1
dst_stage_device_num = np.prod(dst_strategy_list.get(list(dst_strategy_list.keys())[0])[0]) if dst_strategy_list \
is not None else 1
origin_dst_strategy_list = _extract_layout_map(dst_strategy_file)
origin_src_strategy_list = _extract_layout_map(src_strategy_file)
for rank, file_name in checkpoint_files_map.items():
ckpt_dict = ms.load_checkpoint(file_name)
for param_name, param in ckpt_dict.items():
            # Skip parameters that belong to neither the source nor the destination local pipeline stage.
if _parameter_not_in_local_stage(param_name, origin_src_strategy_list, src_strategy_list) \
and _parameter_not_in_local_stage(param_name, origin_dst_strategy_list, dst_strategy_list):
continue
src_rank = rank % src_stage_device_num
param_total_dict[param_name][src_rank] = param.data.asnumpy()
param_attr_dict[param_name][src_rank] = (param.requires_grad, param.layerwise_parallel)
local_rank_id = rank_id % dst_stage_device_num
transform_param_list = _transform_parallel_checkpoint(local_rank_id, param_total_dict,
param_attr_dict, src_strategy_list, dst_strategy_list)
ms.save_checkpoint(transform_param_list, save_checkpoint_file_name)
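

# Illustrative sketch (not part of the public API): convert every destination rank in turn with
# rank_list_for_transform and transform_checkpoint_by_rank, mirroring the docstring example above.
# The 8-device target and the "origin_checkpoint_rank*/new_checkpoint_rank*" layout are hypothetical.
def _example_transform_all_ranks(dst_device_num=8):
    """Transform each destination rank's checkpoint from the source sharding strategy (illustrative only)."""
    for rank_id in range(dst_device_num):
        rank_list = rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
        checkpoint_files_map = {rank: "./origin_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank, rank)
                                for rank in rank_list}
        save_checkpoint_file_name = "./new_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank_id, rank_id)
        transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name,
                                     "./src_strategy.ckpt", "./dst_strategy.ckpt")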


def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix, src_strategy_file=None,
                          dst_strategy_file=None):
    """
    Transform distributed checkpoints from the source sharding strategy to the destination sharding strategy
    for each destination rank. For more details about converting distributed checkpoints, please refer to
    `Distributed Resilience Training and
    Inference <https://www.mindspore.cn/tutorials/experts/en/r2.0/parallel/resilience_train_and_predict.html>`_.

    Note:
        The `src_checkpoints_dir` directory structure should be organized like "src_checkpoints_dir/rank_0/a.ckpt":
        the rank number is used as a subdirectory name and the checkpoint file is stored in that subdirectory.
        If multiple files exist in a rank directory, the last file in lexicographic order is selected.

    Args:
        src_checkpoints_dir (str): The source checkpoints directory.
        dst_checkpoints_dir (str): The destination checkpoints directory to save the converted checkpoints.
        ckpt_prefix (str): The destination checkpoint name prefix.
        src_strategy_file (str): Name of the source sharding strategy file, which is saved by
            'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
            When `src_strategy_file` is None, it means that the source sharding strategy is
            without any sharding for each parameter. Default: None.
        dst_strategy_file (str): Name of the destination sharding strategy file, which is saved by
            'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
            When `dst_strategy_file` is None, it means that the destination sharding strategy
            is without any sharding for each parameter. Default: None.

    Raises:
        ValueError: `src_strategy_file` or `dst_strategy_file` is incorrect.
        NotADirectoryError: `src_checkpoints_dir` or `dst_checkpoints_dir` is not a directory.
        ValueError: A required checkpoint file is missing in `src_checkpoints_dir`.
        TypeError: `src_strategy_file` or `dst_strategy_file` is not a string.

    Examples:
        >>> transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, "dst_checkpoint",
        ...                       "./src_strategy.ckpt", "./dst_strategy.ckpt")
    """
if not os.path.isdir(src_checkpoints_dir):
raise NotADirectoryError("src_checkpoints_dir {} is not a directory.".format(src_checkpoints_dir))
_make_dir(dst_checkpoints_dir, "path")
if not isinstance(ckpt_prefix, str):
raise TypeError("The ckpt_prefix should be a str.")
checkpoints_rank_dir_list = os.path.join(src_checkpoints_dir, "rank_[0-9]*")
all_checkpoint_files_map = {}
for checkpoint_dir in glob.glob(checkpoints_rank_dir_list):
if not os.path.isdir(checkpoint_dir):
ms.log.warning("{} is not a directory.".format(checkpoint_dir))
continue
rank_id_str = checkpoint_dir.split('rank_')[-1]
if not rank_id_str.isdigit():
            ms.log.warning("{} is not an expected directory; the directory name should end with rank_0, rank_1, ...".
                           format(checkpoint_dir))
continue
rank_id = int(rank_id_str)
checkpoint_file_name = os.path.join(checkpoint_dir, "*.ckpt")
rank_ckpts = glob.glob(checkpoint_file_name)
rank_ckpts.sort()
for checkpoint_file in rank_ckpts:
if not os.path.isfile(checkpoint_file):
ms.log.warning("{} is not a checkpoint file.".format(checkpoint_file))
continue
all_checkpoint_files_map[rank_id] = checkpoint_file
needed_rank_list_map = defaultdict(list)
dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_file)
src_stage_device_num = _get_device_num_from_strategy(src_strategy_file)
dst_stage_num = _extract_pipeline_stage_num(dst_strategy_file)
dst_device_num = dst_stage_device_num * dst_stage_num
origin_src_strategy_list = _extract_layout_map(src_strategy_file)
origin_dst_strategy_list = _extract_layout_map(dst_strategy_file)
for rank in range(dst_device_num):
needed_rank_list = rank_list_for_transform(rank, src_strategy_file, dst_strategy_file)
for needed_rank in needed_rank_list:
if needed_rank not in all_checkpoint_files_map:
raise ValueError("The checkpoint file of rank{} is needed for converting rank{}'s checkpoint, "
"but it is missing.".format(needed_rank, rank))
needed_rank_list_key = "-".join([str(r) for r in needed_rank_list])
needed_rank_list_map[needed_rank_list_key].append(rank)
for needed_rank_list_key, transform_rank_list in needed_rank_list_map.items():
param_total_dict = defaultdict(dict)
param_attr_dict = defaultdict(dict)
needed_rank_list = needed_rank_list_key.split("-")
for needed_rank in needed_rank_list:
ckpt_dict = ms.load_checkpoint(all_checkpoint_files_map.get(int(needed_rank)))
for param_name, param in ckpt_dict.items():
src_rank = int(needed_rank) % src_stage_device_num
param_total_dict[param_name][src_rank] = param.data.asnumpy()
param_attr_dict[param_name][src_rank] = (param.requires_grad, param.layerwise_parallel)
for transform_rank in transform_rank_list:
param_total_dict_copy = copy.deepcopy(param_total_dict)
src_strategy_list, dst_strategy_list = _extract_src_dst_layout_map(transform_rank, src_strategy_file,
dst_strategy_file)
            # Remove parameters that belong to neither the source nor the destination local pipeline stage.
for param in list(param_total_dict_copy.keys()):
if _parameter_not_in_local_stage(param, origin_src_strategy_list, src_strategy_list) \
and _parameter_not_in_local_stage(param, origin_dst_strategy_list, dst_strategy_list):
param_total_dict_copy.pop(param)
local_rank_id = transform_rank % dst_stage_device_num
transform_param_list = _transform_parallel_checkpoint(local_rank_id, param_total_dict_copy,
param_attr_dict, src_strategy_list, dst_strategy_list)
save_checkpoint_file = "{}{}.ckpt".format(ckpt_prefix, transform_rank)
save_checkpoint_file_dir = os.path.join(dst_checkpoints_dir, "rank_{}".format(transform_rank))
if not os.path.exists(save_checkpoint_file_dir):
_make_dir(save_checkpoint_file_dir, "path")
save_checkpoint_file_name = os.path.join(save_checkpoint_file_dir, save_checkpoint_file)
ms.save_checkpoint(transform_param_list, save_checkpoint_file_name)
del param_total_dict_copy
del param_total_dict
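

# Illustrative end-to-end sketch (not part of the public API): first merge the strategy files that the
# pipeline stages saved, then convert the whole set of distributed checkpoints to the destination
# sharding strategy in one call. All paths and the "dst_checkpoint" prefix below are hypothetical.
def _example_transform_checkpoints_end_to_end():
    """Merge pipeline-stage strategies and transform all distributed checkpoints (illustrative only)."""
    # Merge the per-stage strategy files into a single source strategy file.
    merge_pipeline_strategys("./src_pipeline_strategys", "./src_strategy.ckpt")
    # src_checkpoints_dir must follow the "rank_x/xxx.ckpt" layout described in transform_checkpoints.
    transform_checkpoints("./src_checkpoints", "./dst_checkpoints", "dst_checkpoint",
                          "./src_strategy.ckpt", "./dst_strategy.ckpt")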