mindspore.load_distributed_checkpoint
- mindspore.load_distributed_checkpoint(network, checkpoint_filenames=None, predict_strategy=None, train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM', format='ckpt', unified_safetensors_dir=None, dst_safetensors_dir=None, rank_id=None, output_format='safetensors', name_map=None, max_process_num=64, return_param_dict=False)[source]
Load checkpoint files into the network for distributed prediction. Used for distributed inference. For details of distributed inference, refer to: Distributed Model Loading.
Note
output_format takes effect only when format is set to safetensors and network is None (a minimal save-mode sketch is given after the parameter list below).
- Parameters:
network (Cell) - The network for distributed prediction. When format is safetensors, the network argument can be omitted or passed as None, in which case the interface runs in save mode.
checkpoint_filenames (list[str]) - The names of the checkpoint files, ordered by rank id. Default: None.
predict_strategy (Union[dict, str]) - The slicing strategy of the parameters for prediction, or a strategy file. Default: None.
train_strategy_filename (str) - The file name of the training strategy proto file. Default: None.
strict_load (bool) - Whether to load the parameters into the network strictly. If False, a parameter from the checkpoint file is loaded into the network when the suffix of its name matches a parameter name in the network, and when the data types are inconsistent, type conversion is performed on parameters of compatible types, for example from float32 to float16. Default: False.
dec_key (Union[None, bytes]) - The byte-type key used for decryption. If the value is None, decryption is not required. Default: None.
dec_mode (str) - Effective only when dec_key is not set to None. Specifies the decryption mode; currently 'AES-GCM', 'AES-CBC' and 'SM4-CBC' are supported. Default: 'AES-GCM'.
format (str) - The format of the input weights to be loaded into the network. Can be set to "ckpt" or "safetensors". Default: "ckpt".
unified_safetensors_dir (str) - The directory of the input weight files to be loaded into the network. Default: None.
dst_safetensors_dir (str) - In the save-mode scenario, the directory where the weights are saved.
rank_id (int) - The logical id of the device. In non-save mode, it is obtained automatically and globally through network initialization; in save mode, files are saved according to the id that is passed in, and if no id is passed in, all weights are saved.
output_format (str, optional) - Controls the checkpoint format of the converted output. Can be set to "ckpt" or "safetensors". Default: "safetensors".
name_map (dict) - The weight-name mapping dictionary. Before the sliced weights are loaded into the network or saved, their names are modified according to this mapping dictionary. Default: None.
max_process_num (int) - The maximum number of processes. Default: 64.
return_param_dict (bool) - Whether to return the param_dict. Default: False.
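The following is a minimal, hedged sketch of the save mode mentioned in the note above: with format set to "safetensors" and network passed as None, the interface slices the unified weights according to predict_strategy and writes the result under dst_safetensors_dir instead of loading them into a network. The paths ./predict_strategy.ckpt, ./unified_safetensors and ./dst_safetensors are hypothetical placeholders, not files produced by the example on this page.

>>> import mindspore as ms
>>> # Save mode (sketch): no network is passed, so the sliced weights are written
>>> # to dst_safetensors_dir instead of being loaded. All paths are hypothetical.
>>> ms.load_distributed_checkpoint(network=None,
...                                predict_strategy="./predict_strategy.ckpt",
...                                format="safetensors",
...                                unified_safetensors_dir="./unified_safetensors",
...                                dst_safetensors_dir="./dst_safetensors",
...                                rank_id=0,
...                                output_format="safetensors")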
- Raises:
TypeError - The input type does not conform to the requirements.
ValueError - Failed to load the checkpoint files into the network.
- Supported Platforms:
Ascend
GPU
CPU
Examples:
>>> import os
>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.dataset as ds
>>> from mindspore import nn, ops, train
>>> from mindspore.communication import init
>>>
>>> step_per_epoch = 4
>>> device_num = 8
>>>
>>> # Define the network structure.
>>> class Net(nn.Cell):
...     def __init__(self, matmul_size, strategy=None):
...         super().__init__()
...         matmul_np = np.full(matmul_size, 0.5, dtype=np.float32)
...         self.matmul_weight = ms.Parameter(ms.Tensor(matmul_np))
...         self.matmul = ops.MatMul()
...         self.neg = ops.Neg()
...         if strategy is not None:
...             self.matmul.shard(strategy)
...
...     def construct(self, inputs):
...         x = self.matmul(inputs, self.matmul_weight)
...         x = self.neg(x)
...         return x
>>>
>>> # Create dataset.
>>> def get_dataset(*inputs):
...     def generate():
...         for _ in range(step_per_epoch):
...             yield inputs
...     return generate
>>>
>>> # Train network and save distributed checkpoint.
>>> def train_net():
...     ms.set_context(mode=ms.GRAPH_MODE)
...     init()
...     np.random.seed(1)
...     input_data = np.random.rand(16, 96).astype(np.float32)
...     label_data = np.random.rand(16, 16).astype(np.float32)
...     fake_dataset = get_dataset(input_data, label_data)
...     dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
...
...     # Set parallel strategy.
...     strategy = ((1, 4), (4, 1))
...     ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num,
...                                  strategy_ckpt_save_file="./train_strategy.ckpt")
...     network = Net(matmul_size=(96, 16), strategy=strategy)
...     net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
...     net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
...     model = ms.Model(network=network, loss_fn=net_loss, optimizer=net_opt)
...     ckpt_config = train.CheckpointConfig(keep_checkpoint_max=1, integrated_save=False)
...     global_rank_id = int(os.getenv("RANK_ID"))
...     ckpt_path = "./rank_{}_ckpt".format(global_rank_id)
...     ckpt_callback = train.ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config)
...     model.train(epoch=2, train_dataset=dataset, callbacks=[ckpt_callback], dataset_sink_mode=False)
...     ms.reset_auto_parallel_context()
>>>
>>> # Load distributed checkpoint and test.
>>> def load_model():
...     ms.set_context(mode=ms.GRAPH_MODE)
...     init()
...     ms.set_auto_parallel_context(full_batch=True, parallel_mode="semi_auto_parallel",
...                                  strategy_ckpt_load_file="./train_strategy.ckpt", device_num=device_num)
...     predict_data = ms.Tensor(np.random.randn(128, 96).astype(np.float32))
...     network = Net(matmul_size=(96, 16))
...     model = ms.Model(network)
...     predict_layout = model.infer_predict_layout(ms.Tensor(predict_data))
...     ckpt_file_list = ["./rank_{}_ckpt/parallel-2_4.ckpt".format(i) for i in range(0, device_num)]
...     ms.load_distributed_checkpoint(network, ckpt_file_list, predict_layout)
...     predict_result = model.predict(predict_data)
...     print(predict_result)
>>>
>>> train_net()
>>> load_model()
[[-7.3259363 -7.497216  -7.398196  ... -7.374962  -7.204874  -7.234935 ]
 [ 3.362938   3.3535435  3.3832688 ...  3.4263954  3.279045   3.3202887]
 ...
 [ 1.6067538  1.6244187  1.5384722 ...  1.5449994  1.6195512  1.6176052]]
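For the safetensors loading path, a hedged variant of load_model above could replace the ckpt_file_list-based call as sketched here, assuming the unified weights have already been placed in a hypothetical ./unified_safetensors directory and that network and predict_layout are prepared exactly as in load_model:

>>> # Sketch only: ./unified_safetensors is a hypothetical directory of unified weights.
>>> ms.load_distributed_checkpoint(network, predict_strategy=predict_layout,
...                                format="safetensors",
...                                unified_safetensors_dir="./unified_safetensors")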