mindspore.transform_checkpoint_by_rank
- mindspore.transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name, src_strategy_file=None, dst_strategy_file=None)[源代码]
将一个分布式网络的Checkpoint由源切分策略转换到目标切分策略,对特定一个rank进行转换。关于更多分布式Checkpoint转换的细节,请参考:模型转换。
- 参数:
rank_id (int) - 待转换得到的Checkpoint的rank号。
checkpoint_files_map (dict) - 源Checkpoint字典,其key为rank号,值为该rank号对应的Checkpoint文件路径。
save_checkpoint_file_name (str) - 目标Checkpoint路径以及名字。
src_strategy_file (str) - 源切分策略proto文件名,由mindspore.set_auto_parallel_context(strategy_ckpt_save_file)接口存储下来的文件。当其为
None
时,表示切分策略为不切分。默认值:None
。dst_strategy_file (str) - 目标切分策略proto文件名,由mindspore.set_auto_parallel_context(strategy_ckpt_save_file)接口存储下来的文件。当其为
None
时,表示切分策略为不切分。默认值:None
。
- 异常:
ValueError - src_strategy_file 或者 dst_strategy_file 不是正确的切分策略proto文件。
ValueError - checkpoint_files_map 内的元素不是正确的Checkpoint文件。
ValueError - save_checkpoint_file_name 不以“.ckpt”结尾。
TypeError - checkpoint_files_map 不是一个字典。
TypeError - src_strategy_file 或者 dst_strategy_file 不是字符串。
TypeError - rank_id 不是一个整数。
TypeError - save_checkpoint_file_name 不是字符串。
样例:
>>> import mindspore as ms >>> dst_device_num = 8 >>> for rank_id in range(dst_device_num): ... rank_list = ms.rank_list_for_transform(rank_id, "./src_strategy.ckpt", "./dst_strategy.ckpt") ... checkpoint_files_map = {} ... for rank in rank_list: ... checkpoint_files_map[rank] = "./origin_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank) ... save_checkpoint_file_name = "./new_checkpoint_rank{}/pangu{}-100_2.ckpt".format(rank_id) ... ms.transform_checkpoint_by_rank(rank_id, checkpoint_files_map, save_checkpoint_file_name, ... "./src_strategy.ckpt", "./dst_strategy.ckpt")