mindspore.transform_checkpoint_by_rank

mindspore.transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name, src_strategy_file=None, dst_strategy_file=None)

Transform the distributed checkpoint of a network, rank by rank, from the source sharding strategy to the destination sharding strategy. For more details about converting distributed checkpoints, please refer to Model Transformation.

Parameters
  • rank_id (int) – The rank whose distributed checkpoint is to be obtained after conversion.

  • checkpoint_files_map (dict) – A map of checkpoint files whose keys are rank ids and whose values are the corresponding checkpoint file names.

  • save_checkpoint_file_name (str) – The file name to save the converted checkpoint.

  • src_strategy_file (str) – Name of the source sharding strategy file, which is saved by ‘mindspore.set_auto_parallel_context(strategy_ckpt_save_file)’. When src_strategy_file is None, the source sharding strategy applies no sharding to any parameter. Default: None.

  • dst_strategy_file (str) – Name of the destination sharding strategy file, which is saved by ‘mindspore.set_auto_parallel_context(strategy_ckpt_save_file)’. When dst_strategy_file is None, the destination sharding strategy applies no sharding to any parameter. Default: None.

Raises
  • ValueError – src_strategy_file or dst_strategy_file is incorrect.

  • ValueError – item in checkpoint_files_map is incorrect.

  • ValueError – save_checkpoint_file_name does not end with “.ckpt”.

  • TypeError – checkpoint_files_map is not a dict.

  • TypeError – src_strategy_file or dst_strategy_file is not a string.

  • TypeError – rank_id is not an int.

  • TypeError – save_checkpoint_file_name is not a string.
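
The type and file-name conditions listed above can be checked before any checkpoint is loaded. The following is a minimal sketch in plain Python, assuming only the conditions stated in this section; `check_transform_args` is a hypothetical helper, not part of MindSpore, and it does not validate the contents of the strategy files or of the map items.

```python
def check_transform_args(rank_id, checkpoint_files_map, save_checkpoint_file_name,
                         src_strategy_file=None, dst_strategy_file=None):
    """Hypothetical pre-flight check mirroring the documented Raises conditions."""
    if not isinstance(rank_id, int):
        raise TypeError("rank_id must be an int")
    if not isinstance(checkpoint_files_map, dict):
        raise TypeError("checkpoint_files_map must be a dict")
    if not isinstance(save_checkpoint_file_name, str):
        raise TypeError("save_checkpoint_file_name must be a string")
    # None is allowed for both strategy files (it means "no sharding").
    for name, value in (("src_strategy_file", src_strategy_file),
                        ("dst_strategy_file", dst_strategy_file)):
        if value is not None and not isinstance(value, str):
            raise TypeError(name + " must be a string or None")
    if not save_checkpoint_file_name.endswith(".ckpt"):
        raise ValueError('save_checkpoint_file_name must end with ".ckpt"')


# Passes silently when the arguments satisfy the documented conditions.
check_transform_args(0, {0: "./rank0.ckpt"}, "./new_rank0.ckpt")
```

Running such a check first lets a misnamed output file fail early, before the conversion is attempted.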

Examples

>>> import mindspore as ms
>>> dst_device_num = 8
>>> for rank_id in range(dst_device_num):
...     rank_list = ms.rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
...     checkpoint_files_map = {}
...     for rank in rank_list:
...         checkpoint_files_map[rank] = "./origin_checkpoint_rank{0}/pangu{0}-100_2.ckpt".format(rank)
...     save_checkpoint_file_name = "./new_checkpoint_rank{0}/pangu{0}-100_2.ckpt".format(rank_id)
...     ms.transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name,
...                                  "./src_strategy.ckpt", "./dst_strategy.ckpt")
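
The rank-to-file mapping passed as checkpoint_files_map can also be built with a small helper. This is a minimal sketch, assuming each rank's checkpoint path follows a template with a named `{rank}` placeholder; `build_checkpoint_files_map` and the template string are hypothetical and not part of MindSpore. In practice the rank list would come from `ms.rank_list_for_transform`, as in the example above.

```python
def build_checkpoint_files_map(rank_list, template):
    """Hypothetical helper: map each source rank id to its checkpoint file name."""
    # A named placeholder may appear several times in the template.
    return {rank: template.format(rank=rank) for rank in rank_list}


# Illustrative rank list; a real one is returned by ms.rank_list_for_transform.
files = build_checkpoint_files_map(
    [0, 4], "./origin_checkpoint_rank{rank}/pangu{rank}-100_2.ckpt")
print(files)
```

A named placeholder avoids having to pass the rank once per occurrence in the path.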