# Source code for mindspore.parallel.checkpoint_transform

# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transform distributed checkpoint"""
from __future__ import absolute_import

import os
import glob
import copy
from collections import defaultdict
import numpy as np
import mindspore as ms
from mindspore.common import dtype as mstype
from mindspore.parallel._utils import _is_in_auto_parallel_mode
from mindspore.parallel._parallel_serialization import _rank_list_for_transform_parallel_checkpoint, \
    _transform_parallel_checkpoint, _get_device_num_from_strategy, _make_dir, \
    _extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \
    _merge_protobuf_strategy, _merge_json_strategy


__all__ = ["merge_pipeline_strategys", "rank_list_for_transform", "transform_checkpoint_by_rank",
           "transform_checkpoints", "sync_pipeline_shared_parameters"]


def merge_pipeline_strategys(src_strategy_dirs, dst_strategy_file):
    """
    Merge parallel strategy between all pipeline stages in pipeline parallel mode.
    For more details about converting distributed Checkpoint, please refer to
    `Model Transformation <https://www.mindspore.cn/tutorials/experts/en/r2.3.0rc1/parallel/model_transformation.html>`_.

    Note:
        Strategy file of each pipeline stage should be included in src_strategy_dirs.

    Args:
        src_strategy_dirs (str): The directory of strategy files including all pipeline stage which is saved by
                                 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
        dst_strategy_file (str): The file merged strategy to save. Its parent directory is created if missing.

    Raises:
        NotADirectoryError: `src_strategy_dirs` is not a directory.
        ValueError: `src_strategy_dirs` contains both '.ckpt' and '.json' strategy files, or neither.

    Examples:
        >>> import mindspore as ms
        >>> # src_strategy_dir/stra0.ckpt, src_strategy_dir/stra1.ckpt ... src_strategy_dir/stra127.ckpt
        >>> ms.merge_pipeline_strategys("./src_strategy_dir", "./dst_strategy.ckpt")

    """
    # Validate all inputs before touching the file system, so a failing call
    # does not leave an empty output directory behind.
    if not os.path.isdir(src_strategy_dirs):
        raise NotADirectoryError("src_strategy_dirs {} is not a directory.".format(src_strategy_dirs))
    src_strategy_files_protobuf = glob.glob(os.path.join(src_strategy_dirs, "*.ckpt"))
    src_strategy_files_json = glob.glob(os.path.join(src_strategy_dirs, "*.json"))
    if src_strategy_files_protobuf and src_strategy_files_json:
        raise ValueError("The strategys format should be all '.ckpt' or all '.json'")
    if not src_strategy_files_protobuf and not src_strategy_files_json:
        raise ValueError("No '.ckpt' or '.json' strategy file is found in src_strategy_dirs {}."
                         .format(src_strategy_dirs))

    # A bare file name (e.g. "dst.ckpt") has an empty directory part; only create
    # a directory when one is actually given.
    dst_strategy_dir, _ = os.path.split(dst_strategy_file)
    if dst_strategy_dir and not os.path.exists(dst_strategy_dir):
        _make_dir(dst_strategy_dir, "path")

    if src_strategy_files_protobuf:
        _merge_protobuf_strategy(src_strategy_files_protobuf, dst_strategy_file)
    else:
        _merge_json_strategy(src_strategy_files_json, dst_strategy_file)
def rank_list_for_transform(rank_id, src_strategy_file=None, dst_strategy_file=None):
    """
    List of original distributed checkpoint rank index for obtaining the target checkpoint of a rank_id
    during the distributed checkpoint conversion.
    For more details about converting distributed Checkpoint, please refer to
    `Model Transformation <https://www.mindspore.cn/tutorials/experts/en/r2.3.0rc1/parallel/model_transformation.html>`_.

    Args:
        rank_id (int): The rank of which distributed checkpoint needs to be obtained after conversion.
        src_strategy_file (str): Name of source sharding strategy file which saved by
                                 `mindspore.set_auto_parallel_context(strategy_ckpt_save_file)`.
                                 when the `src_strategy_file` is ``None``, it means that the source sharding strategy is
                                 without any sharding for each parameter. Default: ``None``.
        dst_strategy_file (str): Name of destination sharding strategy file which saved by
                                 `mindspore.set_auto_parallel_context(strategy_ckpt_save_file)`.
                                 when the `dst_strategy_file` is ``None``,
                                 it means that the destination sharding strategy
                                 is without any sharding for each parameter. Default: ``None``.

    Returns:
        List, the sorted rank list required for converting the distributed checkpoint of rank_id.

    Raises:
        ValueError: `src_strategy_file` or `dst_strategy_file` is incorrect.
        TypeError: `src_strategy_file` or `dst_strategy_file` is not a string.
        TypeError: `rank_id` is not an int.

    Examples:
        >>> import mindspore as ms
        >>> rank_id = 0
        >>> rank_list = ms.rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
        >>> checkpoint_files_map = {}
        >>> for rank in rank_list:
        ...     checkpoint_files_map[rank] = "./pangu{}-100_2.ckpt".format(rank)

    """
    if not isinstance(rank_id, int):
        raise TypeError("The rank_id should be a int.")
    if src_strategy_file is None:
        # An unsharded source only needs the checkpoint of rank 0.
        return [0]
    src_strategy_list, dst_strategy_list = _extract_src_dst_layout_map(rank_id, src_strategy_file, dst_strategy_file)
    # Check emptiness before peeking at the first layout entry; otherwise an empty
    # map fails with an IndexError instead of this explicit error.
    if not src_strategy_list:
        raise ValueError("The src_strategy_file is empty.")
    # The product of layout[0] of any parameter is taken as the number of devices
    # in one pipeline stage.
    src_stage_device_num = np.prod(next(iter(src_strategy_list.values()))[0])
    dst_stage_device_num = np.prod(next(iter(dst_strategy_list.values()))[0]) if dst_strategy_list else 1

    local_rank_id = rank_id % dst_stage_device_num if dst_stage_device_num > 1 else rank_id
    needed_rank_list_in_local_stage = _rank_list_for_transform_parallel_checkpoint(local_rank_id,
                                                                                    src_strategy_list,
                                                                                    dst_strategy_list)
    result_set = set()
    handled_pipeline_stage = set()
    for layout in src_strategy_list.values():
        # layout[6] holds the source pipeline stage ids of the parameter; map the
        # stage-local ranks back to global ranks once per stage.
        for src_pipeline_stage_id in layout[6]:
            if src_pipeline_stage_id in handled_pipeline_stage:
                continue
            src_rank_id_start = src_pipeline_stage_id * src_stage_device_num
            result_set.update(src_rank_id_start + rank for rank in needed_rank_list_in_local_stage)
            handled_pipeline_stage.add(src_pipeline_stage_id)
    # Sorted so the result (and keys derived from it) is deterministic.
    return sorted(result_set)
def transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name,
                                 src_strategy_file=None, dst_strategy_file=None):
    """
    Transform distributed checkpoint from source sharding strategy to destination sharding strategy by rank
    for a network. For more details about converting distributed Checkpoint, please refer to
    `Model Transformation <https://www.mindspore.cn/tutorials/experts/en/r2.3.0rc1/parallel/model_transformation.html>`_.

    Args:
        rank_id (int): The rank of which distributed checkpoint needs to be obtained after conversion.
        checkpoint_files_map (dict): The checkpoint files map whose key is the rank id and the value is
                                     the checkpoint file name.
        save_checkpoint_file_name (str): The file name to save the converted checkpoint.
        src_strategy_file (str): Name of source sharding strategy file which saved by
                                 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
                                 when the `src_strategy_file` is None, it means that the source sharding strategy is
                                 without any sharding for each parameter. Default: ``None``.
        dst_strategy_file (str): Name of destination sharding strategy file which saved by
                                 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
                                 when the `dst_strategy_file` is ``None``,
                                 it means that the destination sharding strategy
                                 is without any sharding for each parameter. Default: ``None``.

    Raises:
        ValueError: `src_strategy_file` or `dst_strategy_file` is incorrect.
        ValueError: item in `checkpoint_files_map` is incorrect.
        ValueError: `save_checkpoint_file_name` is not end with ".ckpt".
        TypeError: `checkpoint_files_map` is not a dict.
        TypeError: `src_strategy_file` or `dst_strategy_file` is not a string.
        TypeError: `rank_id` is not an int.
        TypeError: `save_checkpoint_file_name` is not a string.

    Examples:
        >>> import mindspore as ms
        >>> dst_device_num = 8
        >>> for rank_id in range(dst_device_num):
        ...     rank_list = ms.rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
        ...     checkpoint_files_map = {}
        ...     for rank in rank_list:
        ...         checkpoint_files_map[rank] = "./origin_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank, rank)
        ...     save_checkpoint_file_name = "./new_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank_id, rank_id)
        ...     ms.transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name,
        ...                                     "./src_strategy.ckpt", "./dst_strategy.ckpt")

    """
    if not isinstance(checkpoint_files_map, dict):
        raise TypeError("The checkpoint_files_map should be a dict.")
    if not isinstance(rank_id, int):
        raise TypeError("The rank_id should be a int.")
    if not isinstance(save_checkpoint_file_name, str):
        raise TypeError("The save_checkpoint_file_name should be a str.")
    if not save_checkpoint_file_name.endswith(".ckpt"):
        raise ValueError("The save_checkpoint_file_name {} should end with .ckpt".format(save_checkpoint_file_name))
    dst_strategy_dir = os.path.dirname(dst_strategy_file) if dst_strategy_file else ""
    if dst_strategy_dir and not os.path.exists(dst_strategy_dir):
        raise ValueError("The director of dst_strategy_file: {} is not exists.".format(dst_strategy_dir))
    for rank, local_file in checkpoint_files_map.items():
        if not os.path.exists(local_file):
            raise ValueError("Checkpoint file {} in rank {} not exits: ".format(local_file, rank))

    # Per-parameter maps keyed by the stage-local source rank.
    param_total_dict = defaultdict(dict)
    param_attr_dict = defaultdict(dict)
    param_type_dict = defaultdict(dict)
    src_strategy_list, dst_strategy_list = _extract_src_dst_layout_map(rank_id, src_strategy_file, dst_strategy_file)
    # Number of devices in one pipeline stage (product of layout[0] of any parameter).
    # Truthiness (not `is not None`) is used so an empty layout map falls back to 1
    # instead of failing with an IndexError.
    src_stage_device_num = np.prod(next(iter(src_strategy_list.values()))[0]) if src_strategy_list else 1
    dst_stage_device_num = np.prod(next(iter(dst_strategy_list.values()))[0]) if dst_strategy_list else 1
    origin_dst_strategy_list = _extract_layout_map(dst_strategy_file)
    origin_src_strategy_list = _extract_layout_map(src_strategy_file)
    for rank, file_name in checkpoint_files_map.items():
        ckpt_dict = ms.load_checkpoint(file_name)
        for param_name, param in ckpt_dict.items():
            # cut the parameter not in the pipeline stage.
            if _parameter_not_in_local_stage(param_name, origin_src_strategy_list, src_strategy_list) \
                    and _parameter_not_in_local_stage(param_name, origin_dst_strategy_list, dst_strategy_list):
                continue
            # src rank => local rank inside pipeline stage
            src_rank = rank % src_stage_device_num
            # Record the original dtype before a bfloat16 tensor is widened to float32
            # (presumably because numpy has no native bfloat16).
            param_type_dict[param_name][src_rank] = str(param.data.dtype)
            if param.data.dtype == mstype.bfloat16:
                param.set_dtype(mstype.float32)
            param_total_dict[param_name][src_rank] = param.data.asnumpy()
            param_attr_dict[param_name][src_rank] = (param.requires_grad, param.layerwise_parallel)
    local_rank_id = rank_id % dst_stage_device_num
    transform_param_list = _transform_parallel_checkpoint(local_rank_id, param_total_dict,
                                                          param_attr_dict, src_strategy_list, dst_strategy_list,
                                                          param_type_dict)
    ms.save_checkpoint(transform_param_list, save_checkpoint_file_name)
def transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix, src_strategy_file=None,
                          dst_strategy_file=None):
    """
    Transform distributed checkpoint from source sharding strategy to destination sharding strategy for a rank.
    For more details about converting distributed Checkpoint, please refer to
    `Model Transformation <https://www.mindspore.cn/tutorials/experts/en/r2.3.0rc1/parallel/model_transformation.html>`_.

    Note:
        The `src_checkpoints_dir` directory structure should be organized like "src_checkpoints_dir/rank_0/a.ckpt", the
        rank number should be set to a subdirectory and the checkpoint file is stored in this subdirectory. If multiple
        files exist in a rank directory, the last file in the lexicographic order would be selected.

    Args:
        src_checkpoints_dir (str): The source checkpoints directory.
        dst_checkpoints_dir (str): The destination checkpoints directory to save the converted checkpoints.
        ckpt_prefix (str): The destination checkpoint name prefix.
        src_strategy_file (str): Name of source sharding strategy file which saved by
                                 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
                                 when the `src_strategy_file` is ``None``, it means that the source sharding strategy is
                                 without any sharding for each parameter. Default: ``None``.
        dst_strategy_file (str): Name of destination sharding strategy file which saved by
                                 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
                                 when the `dst_strategy_file` is ``None``,
                                 it means that the destination sharding strategy
                                 is without any sharding for each parameter. Default: ``None``.

    Raises:
        ValueError: `src_strategy_file` or `dst_strategy_file` is incorrect.
        NotADirectoryError: `src_checkpoints_dir` or `dst_checkpoints_dir` is not a directory.
        ValueError: The checkpoint file is missing in `src_checkpoints_dir`.
        TypeError: `src_strategy_file` or `dst_strategy_file` is not a string.
        TypeError: `ckpt_prefix` is not a string.

    Examples:
        >>> import mindspore as ms
        >>> ms.transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, "dst_checkpoint",
        ...                          "./src_strategy.ckpt", "./dst_strategy.ckpt")

    """
    if not os.path.isdir(src_checkpoints_dir):
        raise NotADirectoryError("src_checkpoints_dir {} is not a directory.".format(src_checkpoints_dir))
    # Validate the prefix before creating the destination directory, so a bad
    # call has no file-system side effects.
    if not isinstance(ckpt_prefix, str):
        raise TypeError("The ckpt_prefix should be a str.")
    _make_dir(dst_checkpoints_dir, "path")

    # Collect one checkpoint file per source rank from "rank_<id>" subdirectories.
    checkpoints_rank_dir_list = os.path.join(src_checkpoints_dir, "rank_[0-9]*")
    all_checkpoint_files_map = {}
    for checkpoint_dir in glob.glob(checkpoints_rank_dir_list):
        if not os.path.isdir(checkpoint_dir):
            ms.log.warning("{} is not a directory.".format(checkpoint_dir))
            continue
        rank_id_str = checkpoint_dir.split('rank_')[-1]
        if not rank_id_str.isdigit():
            ms.log.warning("{} is not a expected directory, the directory should end with rank_0/rank_1.....".
                           format(checkpoint_dir))
            continue
        rank_id = int(rank_id_str)
        checkpoint_file_name = os.path.join(checkpoint_dir, "*.ckpt")
        rank_ckpts = glob.glob(checkpoint_file_name)
        rank_ckpts.sort()
        # Later files overwrite earlier ones, so the lexicographically last file wins.
        for checkpoint_file in rank_ckpts:
            if not os.path.isfile(checkpoint_file):
                ms.log.warning("{} is not a checkpoint file.".format(checkpoint_file))
                continue
            all_checkpoint_files_map[rank_id] = checkpoint_file

    needed_rank_list_map = defaultdict(list)
    dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_file)
    src_stage_device_num = _get_device_num_from_strategy(src_strategy_file)
    dst_stage_num = _extract_pipeline_stage_num(dst_strategy_file)
    dst_device_num = dst_stage_device_num * dst_stage_num
    origin_src_strategy_list = _extract_layout_map(src_strategy_file)
    origin_dst_strategy_list = _extract_layout_map(dst_strategy_file)
    # Group destination ranks by the list of source ranks they need, so each
    # group of source checkpoints is loaded only once.
    for rank in range(dst_device_num):
        needed_rank_list = rank_list_for_transform(rank, src_strategy_file, dst_strategy_file)
        for needed_rank in needed_rank_list:
            if needed_rank not in all_checkpoint_files_map:
                raise ValueError("The checkpoint file of rank{} is needed for converting rank{}'s checkpoint, "
                                 "but it is missing.".format(needed_rank, rank))
        needed_rank_list_key = "-".join([str(r) for r in needed_rank_list])
        needed_rank_list_map[needed_rank_list_key].append(rank)

    for needed_rank_list_key, transform_rank_list in needed_rank_list_map.items():
        param_total_dict = defaultdict(dict)
        param_attr_dict = defaultdict(dict)
        param_type_dict = defaultdict(dict)
        needed_rank_list = [int(r) for r in needed_rank_list_key.split("-")]
        for needed_rank in needed_rank_list:
            ckpt_dict = ms.load_checkpoint(all_checkpoint_files_map.get(needed_rank))
            # src rank => local rank inside pipeline stage
            src_rank = needed_rank % src_stage_device_num
            for param_name, param in ckpt_dict.items():
                # Record the original dtype before a bfloat16 tensor is widened to float32.
                param_type_dict[param_name][src_rank] = str(param.data.dtype)
                if param.data.dtype == mstype.bfloat16:
                    param.set_dtype(mstype.float32)
                param_total_dict[param_name][src_rank] = param.data.asnumpy()
                param_attr_dict[param_name][src_rank] = (param.requires_grad, param.layerwise_parallel)
        for transform_rank in transform_rank_list:
            # Each destination rank works on its own deep copy of the loaded tensors.
            param_total_dict_copy = copy.deepcopy(param_total_dict)
            src_strategy_list, dst_strategy_list = _extract_src_dst_layout_map(transform_rank, src_strategy_file,
                                                                               dst_strategy_file)
            # cut the parameter not in the pipeline stage.
            for param in list(param_total_dict_copy.keys()):
                if _parameter_not_in_local_stage(param, origin_src_strategy_list, src_strategy_list) \
                        and _parameter_not_in_local_stage(param, origin_dst_strategy_list, dst_strategy_list):
                    param_total_dict_copy.pop(param)

            local_rank_id = transform_rank % dst_stage_device_num
            transform_param_list = _transform_parallel_checkpoint(local_rank_id, param_total_dict_copy,
                                                                  param_attr_dict, src_strategy_list,
                                                                  dst_strategy_list, param_type_dict)
            save_checkpoint_file = "{}{}.ckpt".format(ckpt_prefix, transform_rank)
            save_checkpoint_file_dir = os.path.join(dst_checkpoints_dir, "rank_{}".format(transform_rank))
            if not os.path.exists(save_checkpoint_file_dir):
                _make_dir(save_checkpoint_file_dir, "path")
            save_checkpoint_file_name = os.path.join(save_checkpoint_file_dir, save_checkpoint_file)
            ms.save_checkpoint(transform_param_list, save_checkpoint_file_name)
            del param_total_dict_copy
        del param_total_dict
def _sync_params(name, param, layout): """synchronize single parameter""" if len(layout) < 10: ms.log.warning("The layout dict does not contain the pipeline_shared_param info %s", name) return pipeline_shared = layout[8] if not pipeline_shared: return is_send = layout[9] peer_rank = layout[10] sr_tag = layout[11] class SharedParameterSyncCell(ms.nn.Cell): """synchronize cell""" def __init__(self, param, is_send, peer_rank, sr_tag): super().__init__() self.param = param self.is_send = is_send self.ret = ms.Tensor([0]) from mindspore.ops.operations._inner_ops import Send, Receive if self.is_send: self.send = Send(sr_tag=sr_tag, dest_rank=peer_rank) else: self.receive = Receive(sr_tag=sr_tag, src_rank=peer_rank, shape=param.shape, dtype=param.dtype) def construct(self): if self.is_send: out = self.send(self.param) return ms.ops.functional.depend(self.ret, out) self.param = self.receive(self.ret) return ms.ops.functional.depend(self.ret, self.param) sync_net = SharedParameterSyncCell(param, is_send, peer_rank, sr_tag) sync_net()
def sync_pipeline_shared_parameters(net):
    """synchronize pipeline parallel stage shared parameters.
    Parameters may be shared between different stages. For example, `embedding table` is
    shared by `WordEmbedding` layer and `LMHead` layer, which are usually split into different stages. It is necessary
    to perform synchronization after `embedding table` changes.

    Note:
        The network should be compiled before synchronize pipeline parallel stage shared parameters.
        The auto parallel context is temporarily switched to ``stand_alone`` during synchronization and is
        restored afterwards, even if synchronization fails.

    Args:
        net (nn.Cell): the inference network.

    Raises:
        TypeError: `net` is not a Cell.

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore import nn, ops, Parameter, Tensor, lazy_inline
        >>> class VocabEmbedding(nn.Cell):
        ...     def __init__(self, vocab_size, embedding_size):
        ...         super().__init__()
        ...         self.embedding_table = Parameter(Tensor(np.ones([vocab_size, embedding_size]), ms.float32),
        ...                                          name='embedding')
        ...         self.gather = ops.Gather()
        ...
        ...     def construct(self, x):
        ...         output = self.gather(self.embedding_table, x, 0)
        ...         output = output.squeeze(1)
        ...         return output, self.embedding_table.value()
        ...
        >>> class LMHead(nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.matmul = ops.MatMul(transpose_b=True)
        ...
        ...     def construct(self, state, embed):
        ...         return self.matmul(state, embed)
        ...
        >>> class Network(nn.Cell):
        ...     @lazy_inline
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.word_embedding = VocabEmbedding(vocab_size=4, embedding_size=4)
        ...         self.head = LMHead()
        ...
        ...     def construct(self, x):
        ...         x, embed = self.word_embedding(x)
        ...         x = self.head(x, embed)
        ...         return x
        >>>
        >>> net = Network()
        >>> net.word_embedding.pipeline_stage = 0
        >>> net.head.pipeline_stage = 1
        >>> x = Tensor(np.ones((8, 1)), ms.int32)
        >>> net.compile(x)
        >>> ms.sync_pipeline_shared_parameters(net)
        >>> print(net.word_embedding.embedding_table.asnumpy())
        [[1. 1. 1. 1.]
         [1. 1. 1. 1.]
         [1. 1. 1. 1.]
         [1. 1. 1. 1.]]
    """

    if not isinstance(net, ms.nn.Cell):
        ms.log.critical("Failed to synchronize pipeline shared parameters.")
        msg = ("For 'sync_pipeline_shared_parameters', the argument 'net' should be a Cell, "
               "but got {}.".format(type(net)))
        raise TypeError(msg)

    layout_dict = net.parameter_layout_dict
    if _is_in_auto_parallel_mode() and not layout_dict:
        from mindspore.common.api import _get_parameter_layout
        layout_dict = _get_parameter_layout()

    # switch to standalone mode
    parallel_mode = ms.context.get_auto_parallel_context("parallel_mode")
    full_batch = ms.context.get_auto_parallel_context("full_batch")
    ms.context.set_auto_parallel_context(parallel_mode="stand_alone", full_batch=False)

    try:
        # synchronize shared parameter
        for name, param in net.parameters_and_names():
            if name in layout_dict:
                _sync_params(name, param, layout_dict[name])
    finally:
        # restore parallel context, also when synchronization raises
        ms.context.set_auto_parallel_context(parallel_mode=parallel_mode, full_batch=full_batch)