# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Transform distributed checkpoint"""
from __future__ import absolute_import
import os
import glob
import copy
from collections import defaultdict
import numpy as np
import mindspore as ms
from mindspore.common import dtype as mstype
from mindspore.parallel._utils import _is_in_auto_parallel_mode
from mindspore.parallel._parallel_serialization import _rank_list_for_transform_parallel_checkpoint, \
_transform_parallel_checkpoint, _get_device_num_from_strategy, _make_dir, \
_extract_layout_map, _extract_src_dst_layout_map, _parameter_not_in_local_stage, _extract_pipeline_stage_num, \
_merge_protobuf_strategy, _merge_json_strategy, _extract_src_dst_layout_map_by_src
# Public API of this module: the names exported by ``from <this module> import *``.
__all__ = ["merge_pipeline_strategys", "rank_list_for_transform", "transform_checkpoint_by_rank",
           "transform_checkpoints", "sync_pipeline_shared_parameters", "load_segmented_checkpoints"]
def merge_pipeline_strategys(src_strategy_dirs, dst_strategy_file):
    """
    Merge parallel strategy between all pipeline stages in pipeline parallel mode.
    For more details about converting distributed Checkpoint, please refer to
    `Model Transformation <https://www.mindspore.cn/tutorials/experts/en/r2.3.1/parallel/model_transformation.html>`_.

    Note:
        Strategy file of each pipeline stage should be included in src_strategy_dirs.
        All strategy files in the directory must share one format: either all ``.ckpt`` (protobuf)
        or all ``.json``.

    Args:
        src_strategy_dirs (str): The directory of strategy files including all pipeline stage which is saved by
            'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'.
        dst_strategy_file (str): Path of the file the merged strategy is saved to. Its parent directory is
            created when it does not exist; a bare file name is written to the current directory.

    Raises:
        NotADirectoryError: `src_strategy_dirs` is not a directory.
        ValueError: `src_strategy_dirs` contains both '.ckpt' and '.json' strategy files, or neither.

    Examples:
        >>> import mindspore as ms
        >>> # src_strategy_dir/stra0.ckpt, src_strategy_dir/stra1.ckpt ... src_strategy_dir/stra127.ckpt
        >>> ms.merge_pipeline_strategys("./src_strategy_dir", "./dst_strategy.ckpt")
    """
    # Validate every input before touching the file system for the output.
    if not os.path.isdir(src_strategy_dirs):
        raise NotADirectoryError("src_strategy_dirs {} is not a directory.".format(src_strategy_dirs))
    # Sorted so that the merge order does not depend on the directory listing order.
    src_strategy_files_protobuf = sorted(glob.glob(os.path.join(src_strategy_dirs, "*.ckpt")))
    src_strategy_files_json = sorted(glob.glob(os.path.join(src_strategy_dirs, "*.json")))
    if src_strategy_files_protobuf and src_strategy_files_json:
        raise ValueError("The strategys format should be all '.ckpt' or all '.json'")
    if not src_strategy_files_protobuf and not src_strategy_files_json:
        raise ValueError("No strategy file ('.ckpt' or '.json') is found in src_strategy_dirs {}."
                         .format(src_strategy_dirs))

    # os.path.split yields an empty directory for a bare file name ("dst_strategy.ckpt"), which means the
    # current directory: there is nothing to create in that case.
    dst_strategy_dir, _ = os.path.split(dst_strategy_file)
    if dst_strategy_dir and not os.path.exists(dst_strategy_dir):
        _make_dir(dst_strategy_dir, "path")

    if src_strategy_files_protobuf:
        _merge_protobuf_strategy(src_strategy_files_protobuf, dst_strategy_file)
    else:
        _merge_json_strategy(src_strategy_files_json, dst_strategy_file)
def _transform_checkpoint_by_stage(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix, src_strategy_file,
                                   dst_strategy_file=None):
    """
    Transform checkpoint for stage in src_strategy_file.

    Source checkpoints are looked up as ``src_checkpoints_dir/rank_<id>/*.ckpt`` for the ranks of the stage
    (``stage_id * src_stage_device_num`` up to, but excluding, ``(stage_id + 1) * src_stage_device_num``).
    One file per destination local rank is written to
    ``dst_checkpoints_dir/rank_<local_rank>/<ckpt_prefix><local_rank>_part<stage_id>.ckpt``.

    Args:
        src_checkpoints_dir (str): Directory containing the ``rank_<id>`` source checkpoint directories.
        dst_checkpoints_dir (str): Directory the transformed checkpoints are written to.
        ckpt_prefix (str): Prefix of the generated checkpoint file names.
        src_strategy_file (str): Strategy file of the source pipeline stage to transform.
        dst_strategy_file (str): Destination strategy file. Default: ``None``.
    """
    param_total_dict = defaultdict(dict)
    param_attr_dict = defaultdict(dict)
    param_type_dict = defaultdict(dict)
    src_strategy_list, dst_strategy_list, stage_id = _extract_src_dst_layout_map_by_src(src_strategy_file,
                                                                                         dst_strategy_file)
    # Device number of a stage: product of element 0 of the first layout entry (presumably the device
    # arrangement -- TODO confirm); a single device is assumed when there is no layout map.
    src_stage_device_num = np.prod(src_strategy_list.get(list(src_strategy_list.keys())[0])[0]) \
        if src_strategy_list is not None else 1
    dst_stage_device_num = np.prod(dst_strategy_list.get(list(dst_strategy_list.keys())[0])[0]) \
        if dst_strategy_list is not None else 1
    origin_dst_strategy_list = _extract_layout_map(dst_strategy_file)
    origin_src_strategy_list = _extract_layout_map(src_strategy_file)

    # Collect one checkpoint file per source rank of this stage.
    checkpoint_files_map = {}
    src_rank_id_start = stage_id * src_stage_device_num
    for local_rank in range(src_stage_device_num):
        rank_id = src_rank_id_start + local_rank
        checkpoint_file_name = os.path.join(src_checkpoints_dir, "rank_{}".format(rank_id), "*.ckpt")
        rank_ckpts = glob.glob(checkpoint_file_name)
        rank_ckpts.sort()
        for checkpoint_file in rank_ckpts:
            if not os.path.isfile(checkpoint_file):
                ms.log.warning("{} is not a checkpoint file.".format(checkpoint_file))
                continue
            # Files are visited in sorted order, so the last regular file of the directory is kept.
            checkpoint_files_map[rank_id] = checkpoint_file
        if rank_id not in checkpoint_files_map:
            # Not an error: the slice of this rank may be replicated elsewhere. Surface it, because a needed
            # slice that is missing only shows up later, inside the transformation.
            ms.log.warning("No checkpoint file is found for rank {} of pipeline stage {} in {}."
                           .format(rank_id, stage_id, src_checkpoints_dir))

    for rank, file_name in checkpoint_files_map.items():
        ckpt_dict = ms.load_checkpoint(file_name)
        # Slices are keyed by the rank's index within the stage.
        src_rank = rank % src_stage_device_num
        for param_name, param in ckpt_dict.items():
            # cut the parameter not in the pipeline stage.
            if _parameter_not_in_local_stage(param_name, origin_src_strategy_list, src_strategy_list) \
                    and _parameter_not_in_local_stage(param_name, origin_dst_strategy_list, dst_strategy_list):
                continue
            # Record the original dtype first; bfloat16 is upcast to float32 before asnumpy(), presumably
            # because numpy has no native bfloat16.
            param_type_dict[param_name][src_rank] = str(param.data.dtype)
            if param.data.dtype == mstype.bfloat16:
                param.set_dtype(mstype.float32)
            param_total_dict[param_name][src_rank] = param.data.asnumpy()
            param_attr_dict[param_name][src_rank] = (param.requires_grad, param.layerwise_parallel)

    for local_rank_id in range(dst_stage_device_num):
        transform_param_list = _transform_parallel_checkpoint(local_rank_id, param_total_dict,
                                                              param_attr_dict, src_strategy_list, dst_strategy_list,
                                                              param_type_dict)
        save_checkpoint_file = "{}{}_part{}.ckpt".format(ckpt_prefix, local_rank_id, stage_id)
        save_checkpoint_file_dir = os.path.join(dst_checkpoints_dir, "rank_{}".format(local_rank_id))
        if not os.path.exists(save_checkpoint_file_dir):
            _make_dir(save_checkpoint_file_dir, "path")
        save_checkpoint_file_name = os.path.join(save_checkpoint_file_dir, save_checkpoint_file)
        ms.save_checkpoint(transform_param_list, save_checkpoint_file_name)
def _transform_checkpoints(src_checkpoints_dir, dst_checkpoints_dir, ckpt_prefix, src_strategy_file=None,
                           dst_strategy_file=None):
    """
    Transform checkpoints for all stages in src_strategy_file.

    Source checkpoints are looked up as ``src_checkpoints_dir/rank_<id>/*.ckpt``. For every destination rank
    ``r`` the result is written to ``dst_checkpoints_dir/rank_<r>/<ckpt_prefix><r>.ckpt``.

    Args:
        src_checkpoints_dir (str): Directory containing the ``rank_<id>`` source checkpoint directories.
        dst_checkpoints_dir (str): Directory the transformed checkpoints are written to.
        ckpt_prefix (str): Prefix of the generated checkpoint file names.
        src_strategy_file (str): Source strategy file. Default: ``None``.
        dst_strategy_file (str): Destination strategy file. Default: ``None``.

    Raises:
        ValueError: A source rank needed to build some destination rank has no checkpoint file.
    """
    # Map each source rank id to one checkpoint file found in its rank_<id> directory.
    checkpoints_rank_dir_list = os.path.join(src_checkpoints_dir, "rank_[0-9]*")
    all_checkpoint_files_map = {}
    for checkpoint_dir in glob.glob(checkpoints_rank_dir_list):
        if not os.path.isdir(checkpoint_dir):
            ms.log.warning("{} is not a directory.".format(checkpoint_dir))
            continue
        rank_id_str = checkpoint_dir.split('rank_')[-1]
        if not rank_id_str.isdigit():
            ms.log.warning("{} is not a expected directory, the directory should end with rank_0/rank_1.....".
                           format(checkpoint_dir))
            continue
        rank_id = int(rank_id_str)
        checkpoint_file_name = os.path.join(checkpoint_dir, "*.ckpt")
        rank_ckpts = glob.glob(checkpoint_file_name)
        rank_ckpts.sort()
        for checkpoint_file in rank_ckpts:
            if not os.path.isfile(checkpoint_file):
                ms.log.warning("{} is not a checkpoint file.".format(checkpoint_file))
                continue
            # Files are visited in sorted order, so the last regular file of the directory is kept.
            all_checkpoint_files_map[rank_id] = checkpoint_file
    # Group destination ranks by the exact list of source ranks they need (key "r0-r1-..."), so that each
    # distinct set of source checkpoints is loaded only once below.
    needed_rank_list_map = defaultdict(list)
    dst_stage_device_num = _get_device_num_from_strategy(dst_strategy_file)
    src_stage_device_num = _get_device_num_from_strategy(src_strategy_file)
    dst_stage_num = _extract_pipeline_stage_num(dst_strategy_file)
    dst_device_num = dst_stage_device_num * dst_stage_num
    origin_src_strategy_list = _extract_layout_map(src_strategy_file)
    origin_dst_strategy_list = _extract_layout_map(dst_strategy_file)
    for rank in range(dst_device_num):
        needed_rank_list = rank_list_for_transform(rank, src_strategy_file, dst_strategy_file)
        for needed_rank in needed_rank_list:
            if needed_rank not in all_checkpoint_files_map:
                raise ValueError("The checkpoint file of rank{} is needed for converting rank{}'s checkpoint, "
                                 "but it is missing.".format(needed_rank, rank))
        needed_rank_list_key = "-".join([str(r) for r in needed_rank_list])
        needed_rank_list_map[needed_rank_list_key].append(rank)
    for needed_rank_list_key, transform_rank_list in needed_rank_list_map.items():
        # Per-parameter dicts keyed by the source rank's index within its stage
        # (needed_rank % src_stage_device_num).
        param_total_dict = defaultdict(dict)
        param_attr_dict = defaultdict(dict)
        param_type_dict = defaultdict(dict)
        needed_rank_list = needed_rank_list_key.split("-")
        for needed_rank in needed_rank_list:
            ckpt_dict = ms.load_checkpoint(all_checkpoint_files_map.get(int(needed_rank)))
            for param_name, param in ckpt_dict.items():
                src_rank = int(needed_rank) % src_stage_device_num
                # Record the original dtype first; bfloat16 is upcast to float32 before asnumpy().
                param_type_dict[param_name][src_rank] = str(param.data.dtype)
                if param.data.dtype == mstype.bfloat16:
                    param.set_dtype(mstype.float32)
                param_total_dict[param_name][src_rank] = param.data.asnumpy()
                param_attr_dict[param_name][src_rank] = (param.requires_grad, param.layerwise_parallel)
        for transform_rank in transform_rank_list:
            # Work on a copy: parameters outside this rank's pipeline stage are popped from it below, and the
            # loaded data is reused by the other ranks of the group. NOTE(review): the deep copy presumably also
            # guards against _transform_parallel_checkpoint mutating the arrays -- confirm before replacing it.
            param_total_dict_copy = copy.deepcopy(param_total_dict)
            src_strategy_list, dst_strategy_list = _extract_src_dst_layout_map(transform_rank, src_strategy_file,
                                                                               dst_strategy_file)
            # cut the parameter not in the pipeline stage.
            for param in list(param_total_dict_copy.keys()):
                if _parameter_not_in_local_stage(param, origin_src_strategy_list, src_strategy_list) \
                        and _parameter_not_in_local_stage(param, origin_dst_strategy_list, dst_strategy_list):
                    param_total_dict_copy.pop(param)
            local_rank_id = transform_rank % dst_stage_device_num
            transform_param_list = _transform_parallel_checkpoint(local_rank_id, param_total_dict_copy,
                                                                  param_attr_dict, src_strategy_list, dst_strategy_list,
                                                                  param_type_dict)
            save_checkpoint_file = "{}{}.ckpt".format(ckpt_prefix, transform_rank)
            save_checkpoint_file_dir = os.path.join(dst_checkpoints_dir, "rank_{}".format(transform_rank))
            if not os.path.exists(save_checkpoint_file_dir):
                _make_dir(save_checkpoint_file_dir, "path")
            save_checkpoint_file_name = os.path.join(save_checkpoint_file_dir, save_checkpoint_file)
            ms.save_checkpoint(transform_param_list, save_checkpoint_file_name)
            # Release the per-rank copy before the next deep copy to limit peak memory.
            del param_total_dict_copy
        del param_total_dict
def _sync_params(name, param, layout):
    """
    Synchronize a single pipeline-shared parameter with its peer stage.

    Args:
        name (str): Parameter name, used in the warning messages.
        param (Parameter): Parameter that is sent, or overwritten with the received value.
        layout (Union[list, tuple]): Layout entry of the parameter. Entries 8-11 are read as
            ``(pipeline_shared, is_send, peer_rank, sr_tag)``.
    """
    if len(layout) < 10:
        ms.log.warning("The layout dict does not contain the pipeline_shared_param info %s", name)
        return
    pipeline_shared = layout[8]
    if not pipeline_shared:
        return
    # peer_rank and sr_tag live at indices 10 and 11, so a shared entry needs at least 12 elements; the
    # check above only guarantees 10, which previously led to an IndexError here.
    if len(layout) < 12:
        ms.log.warning("The layout dict does not contain the pipeline_shared_param info %s", name)
        return
    is_send = layout[9]
    peer_rank = layout[10]
    sr_tag = layout[11]

    class SharedParameterSyncCell(ms.nn.Cell):
        """Cell that sends `param` to `peer_rank`, or receives it from `peer_rank`."""
        def __init__(self, param, is_send, peer_rank, sr_tag):
            super().__init__()
            self.param = param
            self.is_send = is_send
            self.ret = ms.Tensor([0])
            from mindspore.ops import Send, Receive
            if self.is_send:
                self.send = Send(sr_tag=sr_tag, dest_rank=peer_rank)
            else:
                self.receive = Receive(sr_tag=sr_tag, src_rank=peer_rank, shape=param.shape, dtype=param.dtype)

        def construct(self):
            if self.is_send:
                out = self.send(self.param)
                return ms.ops.functional.depend(self.ret, out)
            self.param = self.receive(self.ret)
            return ms.ops.functional.depend(self.ret, self.param)

    sync_net = SharedParameterSyncCell(param, is_send, peer_rank, sr_tag)
    sync_net()
def sync_pipeline_shared_parameters(net):
    """Synchronize pipeline parallel stage shared parameters.

    Parameters may be shared between different stages. For example, `embedding table` is
    shared by `WordEmbedding` layer and `LMHead` layer, which are usually split into different stages. It is necessary
    to perform synchronization after `embedding table` changes.

    The auto parallel context is switched to stand-alone mode while the parameters are exchanged, and the
    previous ``parallel_mode`` and ``full_batch`` values are restored afterwards, even if synchronization fails.

    Note:
        The network should be compiled before synchronizing pipeline parallel stage shared parameters.

    Args:
        net (nn.Cell): the inference network.

    Raises:
        TypeError: `net` is not a Cell.

    Supported Platforms:
        ``Ascend``

    Examples:
        .. note::
            Before running the following examples, you need to configure the communication environment variables.

            For the Ascend device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
            Startup <https://www.mindspore.cn/tutorials/experts/en/r2.3.1/parallel/dynamic_cluster.html>`_ .

        >>> import numpy as np
        >>> import mindspore as ms
        >>> import mindspore.communication.management as D
        >>> from mindspore import lazy_inline, context, nn, ops, Parameter, Tensor
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> class Embedding(nn.Cell):
        ...     def __init__(self, shape):
        ...         super().__init__()
        ...         self.w = Parameter(Tensor(np.ones(shape), ms.float32), name='w')
        ...         self.matmul = ops.MatMul().shard(((1, 1), (1, 1)))
        ...     def construct(self, x):
        ...         return self.matmul(x, self.w), self.w
        ...
        >>> class LMHead(nn.Cell):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.matmul = ops.MatMul(transpose_b=True).shard(((1, 1), (1, 1)))
        ...     def construct(self, x, w):
        ...         return self.matmul(x, w)
        ...
        >>> class Network(nn.Cell):
        ...     @lazy_inline
        ...     def __init__(self):
        ...         super().__init__()
        ...         shape = (4, 4)
        ...         self.word_embedding = Embedding(shape)
        ...         self.lm_head = LMHead()
        ...         self.word_embedding.pipeline_stage = 0
        ...         self.lm_head.pipeline_stage = 1
        ...     def construct(self, x):
        ...         x, embed = self.word_embedding(x)
        ...         return self.lm_head(x, embed)
        ...
        >>> class PipelineCellInference(nn.Cell):
        ...     def __init__(self, network, micro_batch_num):
        ...         super().__init__()
        ...         self.network = network
        ...         self.micro_batch_num = micro_batch_num
        ...         self.concat = ops.Concat()
        ...     def construct(self, x):
        ...         ret = ()
        ...         for i in range(self.micro_batch_num):
        ...             micro_batch_size = x.shape[0] // self.micro_batch_num
        ...             start = micro_batch_size * i
        ...             end = micro_batch_size * (i + 1)
        ...             micro_input = x[start:end]
        ...             y = self.network(micro_input)
        ...             ret = ret + (y,)
        ...         ret = self.concat(ret)
        ...         return ret
        ...
        >>> D.init()
        >>> context.set_auto_parallel_context(parallel_mode='semi_auto_parallel', full_batch=True, pipeline_stages=2)
        >>> net = Network()
        >>> net = PipelineCellInference(net, 2)
        >>> net.set_train(False)
        >>> x = Tensor(np.ones((2, 4)), ms.float32)
        >>> net.compile(x)
        >>> ms.sync_pipeline_shared_parameters(net)
        >>> print(net.network.word_embedding.w.asnumpy())
        [[1. 1. 1. 1.]
         [1. 1. 1. 1.]
         [1. 1. 1. 1.]
         [1. 1. 1. 1.]]
    """
    if not isinstance(net, ms.nn.Cell):
        ms.log.critical("Failed to synchronize pipeline shared parameters.")
        msg = ("For 'sync_pipeline_shared_parameters', the argument 'net' should be a Cell, "
               "but got {}.".format(type(net)))
        raise TypeError(msg)

    layout_dict = net.parameter_layout_dict
    if _is_in_auto_parallel_mode() and not layout_dict:
        from mindspore.common.api import _get_parameter_layout
        layout_dict = _get_parameter_layout()

    # switch to standalone mode, remembering the caller's settings
    parallel_mode = ms.context.get_auto_parallel_context("parallel_mode")
    full_batch = ms.context.get_auto_parallel_context("full_batch")
    ms.context.set_auto_parallel_context(parallel_mode="stand_alone", full_batch=False)
    try:
        # synchronize shared parameter
        for name, param in net.parameters_and_names():
            if name in layout_dict:
                _sync_params(name, param, layout_dict[name])
    finally:
        # restore parallel context even when a synchronization step raises, so the caller is not left in
        # stand-alone mode
        ms.context.set_auto_parallel_context(parallel_mode=parallel_mode, full_batch=full_batch)
def load_segmented_checkpoints(ckpt_file_dir, net=None, strict_load=False, filter_prefix=None,
                               dec_key=None, dec_mode="AES-GCM", specify_prefix=None, choice_func=None):
    """
    Load checkpoint info from all checkpoint files in a specified directory. Every ``*.ckpt`` file in
    `ckpt_file_dir` is loaded one by one, in sorted file-name order, and the combined dictionary is returned.
    When several files contain the same parameter name, the file that sorts last wins.

    Note:
        - `specify_prefix` and `filter_prefix` do not affect each other.
        - If none of the parameters are loaded from checkpoint file, it will throw ValueError.
        - `specify_prefix` and `filter_prefix` are in the process of being deprecated,
          `choice_func` is recommended instead.
          And using either of those two args will override `choice_func` at the same time.

    Args:
        ckpt_file_dir (str): Checkpoint file directory.
        net (Cell): The network where the parameters will be loaded. Default: ``None`` .
        strict_load (bool): Whether to strict load the parameter into net. If ``False`` , it will load parameter
            into net when parameter name's suffix in checkpoint file is the same as the
            parameter in the network. When the types are inconsistent perform type conversion
            on the parameters of the same type, such as float32 to float16. Default: ``False`` .
        filter_prefix (Union[str, list[str], tuple[str]]): Deprecated(see `choice_func`). Parameters starting with the
            filter_prefix will not be loaded. Default: ``None`` .
        dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is ``None`` , the decryption
            is not required. Default: ``None`` .
        dec_mode (str): This parameter is valid only when dec_key is not set to ``None`` . Specifies the decryption
            mode, currently supports ``"AES-GCM"`` and ``"AES-CBC"`` and ``"SM4-CBC"`` .
            Default: ``"AES-GCM"`` .
        specify_prefix (Union[str, list[str], tuple[str]]): Deprecated(see `choice_func`). Parameters starting with the
            specify_prefix will be loaded. Default: ``None`` .
        choice_func (Union[None, function]) : Input value of the function is a Parameter name of type string,
            and the return value is a bool. If returns ``True`` , the Parameter
            that matches the custom condition will be loaded. If returns ``False`` , the Parameter that
            matches the custom condition will be removed. Default: ``None`` .

    Returns:
        Dict, key is parameter name, value is a Parameter or string. When the `append_dict` parameter of
        :func:`mindspore.save_checkpoint` and the `append_info` parameter of :class:`mindspore.train.CheckpointConfig`
        are used to save the checkpoint, `append_dict` and `append_info` are dict types, and their value are string,
        then the return value obtained by loading checkpoint is string, and in other cases the return value is
        Parameter.

    Raises:
        TypeError: Input ckpt_file_dir is not a string.
        ValueError: Checkpoint file directory doesn't exist. Or it's not a directory.
        ValueError: Checkpoint file directory contains no ``*.ckpt`` file.
        ValueError: Checkpoint file's format is incorrect.
        ValueError: Parameter's dict is None after load checkpoint file.
        TypeError: The type of `specify_prefix` or `filter_prefix` is incorrect.
    """
    if not isinstance(ckpt_file_dir, str):
        raise TypeError("The ckpt_file_dir should be a str.")
    if not os.path.isdir(ckpt_file_dir):
        raise ValueError("The ckpt_file_dir: {} doesn't exist. Or it's not a directory".
                         format(ckpt_file_dir))
    # glob returns files in file-system order; sort so that the load order, and therefore which file wins
    # for a duplicated parameter name, is deterministic.
    rank_ckpts = sorted(glob.glob(os.path.join(ckpt_file_dir, "*.ckpt")))
    if not rank_ckpts:
        raise ValueError("No checkpoint file ('*.ckpt') is found in ckpt_file_dir: {}.".format(ckpt_file_dir))
    parameter_dict = {}
    for checkpoint_file in rank_ckpts:
        parameter_dict.update(ms.load_checkpoint(checkpoint_file, net, strict_load, filter_prefix, dec_key,
                                                 dec_mode, specify_prefix, choice_func))
    return parameter_dict