mindspore.transform_checkpoint_by_rank
- mindspore.transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name, src_strategy_file=None, dst_strategy_file=None)[source]
Transform a network's distributed checkpoint from the source sharding strategy to the destination sharding strategy, one rank at a time. For more details about converting distributed checkpoints, please refer to Model Transformation.
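Converting by rank means each destination rank usually needs only the source checkpoints whose slices overlap its own slice. The toy sketch below illustrates that idea under a strong simplifying assumption: one parameter whose rows are split evenly along dimension 0 across ranks. Real strategy files can describe more complex layouts, and `shard_range` and `source_ranks_needed` are hypothetical helpers, not MindSpore APIs.

```python
# Toy model (an assumption for illustration, not MindSpore's algorithm):
# a parameter with `length` rows is split evenly along dim 0 across `n` ranks.


def shard_range(rank: int, n: int, length: int) -> range:
    """Rows owned by `rank` when `length` rows are split evenly across `n` ranks."""
    size = length // n
    return range(rank * size, (rank + 1) * size)


def source_ranks_needed(dst_rank: int, src_n: int, dst_n: int, length: int) -> list:
    """Source ranks whose slice overlaps the slice owned by `dst_rank`."""
    dst = shard_range(dst_rank, dst_n, length)
    needed = []
    for src_rank in range(src_n):
        src = shard_range(src_rank, src_n, length)
        # Two half-open ranges overlap if each starts before the other ends.
        if src.start < dst.stop and dst.start < src.stop:
            needed.append(src_rank)
    return needed


if __name__ == "__main__":
    # 16 rows: 8 source ranks (2 rows each) -> 4 destination ranks (4 rows each).
    for dst_rank in range(4):
        print(dst_rank, source_ranks_needed(dst_rank, 8, 4, 16))
```

In this toy layout, each of the 4 destination ranks needs exactly 2 of the 8 source checkpoints, which is why `checkpoint_files_map` only has to list the relevant source ranks.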
- Parameters
rank_id (int) – The rank whose distributed checkpoint is to be obtained after conversion.
checkpoint_files_map (dict) – A map whose keys are source rank ids and whose values are the corresponding checkpoint file names.
save_checkpoint_file_name (str) – The file name under which the converted checkpoint is saved. It must end with ".ckpt".
src_strategy_file (str) – Name of the source sharding strategy file saved by 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'. When src_strategy_file is None, the source sharding strategy is treated as having no sharding for any parameter. Default: None.
dst_strategy_file (str) – Name of the destination sharding strategy file saved by 'mindspore.set_auto_parallel_context(strategy_ckpt_save_file)'. When dst_strategy_file is None, the destination sharding strategy is treated as having no sharding for any parameter. Default: None.
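A minimal sketch of assembling `checkpoint_files_map` and `save_checkpoint_file_name` for a single destination rank, assuming the source checkpoints follow a per-rank path template. The helper `build_checkpoint_files_map`, the template strings, and the rank list are hypothetical; in practice the rank list would come from `ms.rank_list_for_transform`, as in the Examples section.

```python
def build_checkpoint_files_map(rank_list, template):
    """Map each source rank id (int) to the checkpoint file name for that rank.

    `template` is a hypothetical path pattern containing a `{rank}` field.
    """
    return {int(rank): template.format(rank=rank) for rank in rank_list}


# Hypothetical inputs: source ranks 0 and 1 feed destination rank 0.
rank_list = [0, 1]
files = build_checkpoint_files_map(
    rank_list, "./origin_checkpoint_rank{rank}/pangu{rank}-100_2.ckpt")
save_name = "./new_checkpoint_rank{rank}/pangu{rank}-100_2.ckpt".format(rank=0)
# The documented API requires the output file name to end with ".ckpt".
assert save_name.endswith(".ckpt")
print(files)
```

Using a named `{rank}` field avoids the mismatch between placeholder count and argument count that positional `{}` fields invite.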
- Raises
ValueError – src_strategy_file or dst_strategy_file is incorrect.
ValueError – An item in checkpoint_files_map is incorrect.
ValueError – save_checkpoint_file_name does not end with ".ckpt".
TypeError – checkpoint_files_map is not a dict.
TypeError – src_strategy_file or dst_strategy_file is not a string.
TypeError – rank_id is not an int.
TypeError – save_checkpoint_file_name is not a string.
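The documented error conditions can be checked before calling the API. The sketch below is a hypothetical pre-flight validator, `check_transform_args`, not part of MindSpore. This page does not say exactly what makes a strategy file or a map item "incorrect", so treating them as a missing file and as a non-int key or non-str value is an assumption.

```python
import os


def check_transform_args(rank_id, checkpoint_files_map, save_checkpoint_file_name,
                         src_strategy_file=None, dst_strategy_file=None):
    """Hypothetical pre-flight check mirroring the documented Raises list.

    The ValueError conditions for strategy files and map items are
    assumptions; MindSpore's own checks may differ.
    """
    if not isinstance(rank_id, int):
        raise TypeError("rank_id is not an int")
    if not isinstance(checkpoint_files_map, dict):
        raise TypeError("checkpoint_files_map is not a dict")
    if not isinstance(save_checkpoint_file_name, str):
        raise TypeError("save_checkpoint_file_name is not a string")
    if not save_checkpoint_file_name.endswith(".ckpt"):
        raise ValueError('save_checkpoint_file_name does not end with ".ckpt"')
    for strategy_file in (src_strategy_file, dst_strategy_file):
        if strategy_file is None:
            continue  # None means no parameter is sharded
        if not isinstance(strategy_file, str):
            raise TypeError("strategy file is not a string")
        # Assumption: an "incorrect" strategy file is one that does not exist.
        if not os.path.isfile(strategy_file):
            raise ValueError(f"strategy file not found: {strategy_file}")
    for rank, file_name in checkpoint_files_map.items():
        # Assumption: each item maps an int rank id to a str file name.
        if not isinstance(rank, int) or not isinstance(file_name, str):
            raise ValueError(f"bad checkpoint_files_map item: {rank!r}: {file_name!r}")
```

Running such a check once per destination rank surfaces argument mistakes before any checkpoint is loaded.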
Examples
>>> import mindspore as ms
>>> dst_device_num = 8
>>> for rank_id in range(dst_device_num):
...     # Source ranks whose checkpoints are needed to build this destination rank
...     rank_list = ms.rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
...     checkpoint_files_map = {}
...     for rank in rank_list:
...         checkpoint_files_map[rank] = "./origin_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank, rank)
...     save_checkpoint_file_name = "./new_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank_id, rank_id)
...     ms.transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name,
...                                     "./src_strategy.ckpt", "./dst_strategy.ckpt")