mindspore.transform_checkpoint_by_rank

mindspore.transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name, src_strategy_file=None, dst_strategy_file=None)

Transform the distributed checkpoint of a network from the source sharding strategy to the destination sharding strategy, for a single rank.

Parameters
  • rank_id (int) – The rank whose distributed checkpoint is to be obtained after conversion.

  • checkpoint_files_map (dict) – A map of checkpoint files, in which each key is a rank id and each value is the checkpoint file name for that rank.

  • save_checkpoint_file_name (str) – The file name to save the converted checkpoint.

  • src_strategy_file (str) – Name of the source sharding strategy file, as saved by ‘mindspore.set_auto_parallel_context(strategy_ckpt_save_file)’. When ‘src_strategy_file’ is None, the source sharding strategy is taken to have no sharding for any parameter. Default: None.

  • dst_strategy_file (str) – Name of the destination sharding strategy file, as saved by ‘mindspore.set_auto_parallel_context(strategy_ckpt_save_file)’. When ‘dst_strategy_file’ is None, the destination sharding strategy is taken to have no sharding for any parameter. Default: None.
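
As an illustration of the shape expected for checkpoint_files_map, the following is a minimal sketch that builds such a dict from a list of rank ids. The helper name build_checkpoint_files_map, the path template, and the rank ids 0 and 4 are assumptions made for this example; they are not part of MindSpore.

```python
def build_checkpoint_files_map(rank_list, path_template):
    """Map each source rank id to its checkpoint file path.

    Hypothetical helper: `path_template` is a str.format template
    containing a `{rank}` field.
    """
    return {int(rank): path_template.format(rank=rank) for rank in rank_list}


# Ranks 0 and 4 stand in for the source ranks needed by one destination rank.
files_map = build_checkpoint_files_map(
    [0, 4], "./origin_checkpoint_rank{rank}/net.ckpt"
)
print(files_map)
```

The resulting dict has integer rank ids as keys and file names as values, which is the form described for checkpoint_files_map above.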

Raises
  • ValueError – src_strategy_file or dst_strategy_file is incorrect.

  • ValueError – An item in checkpoint_files_map is incorrect.

  • ValueError – save_checkpoint_file_name does not end with “.ckpt”.

  • TypeError – checkpoint_files_map is not a dict.

  • TypeError – src_strategy_file or dst_strategy_file is not a string.

  • TypeError – rank_id is not an int.

  • TypeError – save_checkpoint_file_name is not a string.
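
Some of the conditions above can be checked before the conversion is started. The sketch below mirrors only the documented type checks and the “.ckpt” suffix check. The function check_transform_args is a hypothetical helper, not MindSpore's own validation, and it does not inspect the strategy files or the items of checkpoint_files_map.

```python
def check_transform_args(rank_id, checkpoint_files_map, save_checkpoint_file_name,
                         src_strategy_file=None, dst_strategy_file=None):
    """Hypothetical pre-check mirroring some of the documented error cases."""
    if not isinstance(rank_id, int):
        raise TypeError("rank_id must be an int")
    if not isinstance(checkpoint_files_map, dict):
        raise TypeError("checkpoint_files_map must be a dict")
    if not isinstance(save_checkpoint_file_name, str):
        raise TypeError("save_checkpoint_file_name must be a string")
    # None is allowed for both strategy files (it is the documented default).
    for name, value in (("src_strategy_file", src_strategy_file),
                        ("dst_strategy_file", dst_strategy_file)):
        if value is not None and not isinstance(value, str):
            raise TypeError(f"{name} must be a string or None")
    if not save_checkpoint_file_name.endswith(".ckpt"):
        raise ValueError('save_checkpoint_file_name must end with ".ckpt"')


# Passes silently: all arguments have the documented types and suffix.
check_transform_args(0, {0: "./rank0/net.ckpt"}, "./new_rank0/net.ckpt")
```

Running such a check first lets a long conversion loop fail early on an obviously malformed argument, while errors in the strategy files or the map contents are still left for the library to report.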

Examples

>>> from mindspore import rank_list_for_transform, transform_checkpoint_by_rank
>>> dst_device_num = 8
>>> for rank_id in range(dst_device_num):
...     # Source ranks whose checkpoints are needed to build the checkpoint for rank_id.
...     rank_list = rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt")
...     checkpoint_files_map = {}
...     for rank in rank_list:
...         checkpoint_files_map[rank] = "./origin_checkpoint_rank{0}/pangu{0}-100_2.ckpt".format(rank)
...     save_checkpoint_file_name = "./new_checkpoint_rank{0}/pangu{0}-100_2.ckpt".format(rank_id)
...     transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name,
...                                  "./src_strategy.ckpt", "./dst_strategy.ckpt")