mindspore.load_distributed_checkpoint
- mindspore.load_distributed_checkpoint(network, checkpoint_filenames=None, predict_strategy=None, train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM', format='ckpt', unified_safetensors_dir=None, dst_safetensors_dir=None, rank_id=None)
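Loads checkpoint files into the network for distributed prediction. Used for distributed inference. For details on distributed inference, see: Distributed Model Loading.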
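- Parameters:
network (Cell) - The network used for distributed prediction.
checkpoint_filenames (list[str]) - Names of the checkpoint files, ordered by rank id. Default: None.
predict_strategy (dict) - The sharding strategy of the parameters at prediction time. Default: None.
train_strategy_filename (str) - Name of the training strategy proto file. Default: None.
strict_load (bool) - Whether to load parameters into the network strictly. If False, a parameter is loaded into the network when the suffix of its name in the checkpoint file matches a parameter in the network. When the types are inconsistent, type conversion is performed on parameters of the same type, such as float32 to float16. Default: False.
dec_key (Union[None, bytes]) - Byte-type key used for decryption. If the value is None, no decryption is needed. Default: None.
dec_mode (str) - Takes effect only when dec_key is not None. Specifies the decryption mode; 'AES-GCM', 'AES-CBC' and 'SM4-CBC' are currently supported. Default: 'AES-GCM'.
format (str) - Format of the input weights to be loaded into the network. Can be "ckpt" or "safetensors". Default: "ckpt".
unified_safetensors_dir (str) - Directory of the input weight files to be loaded into the network. Default: None.
dst_safetensors_dir (str) - In save mode, the directory where the safetensors files are saved.
rank_id (int) - Logical index of the device. In non-save mode, it is obtained automatically from the global state when the network is initialized. In save mode, files are saved according to the rank id passed in; if none is passed, everything is saved.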
- Raises:
TypeError - The input type does not meet the requirements.
ValueError - The checkpoint files cannot be loaded into the network.
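The library does its own validation, but a misconfigured call can be caught earlier, before any devices are initialized. The following is a minimal sketch of such a pre-check that only mirrors the constraints stated in the parameter descriptions. The function precheck_load_args is hypothetical, and the choice of TypeError versus ValueError for each case is an assumption, not a statement of what MindSpore raises internally.

```python
SUPPORTED_DEC_MODES = ("AES-GCM", "AES-CBC", "SM4-CBC")
SUPPORTED_FORMATS = ("ckpt", "safetensors")


def precheck_load_args(checkpoint_filenames=None, dec_key=None,
                       dec_mode="AES-GCM", format="ckpt"):
    """Hypothetical pre-flight check; not part of MindSpore."""
    # checkpoint_filenames is documented as list[str] (or None).
    if checkpoint_filenames is not None:
        if not isinstance(checkpoint_filenames, list) or not all(
                isinstance(name, str) for name in checkpoint_filenames):
            raise TypeError("checkpoint_filenames must be a list of str")

    # dec_key is documented as bytes (or None for no decryption).
    if dec_key is not None and not isinstance(dec_key, bytes):
        raise TypeError("dec_key must be bytes or None")

    # dec_mode only takes effect when a key is supplied.
    if dec_key is not None and dec_mode not in SUPPORTED_DEC_MODES:
        raise ValueError(f"dec_mode must be one of {SUPPORTED_DEC_MODES}")

    if format not in SUPPORTED_FORMATS:
        raise ValueError(f"format must be one of {SUPPORTED_FORMATS}")
```

Calling precheck_load_args with the same keyword values that will be passed to load_distributed_checkpoint returns None when the documented constraints hold and raises otherwise.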
- Supported Platforms:
Ascend
GPU
Examples:
>>> import os
>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.dataset as ds
>>> from mindspore import nn, ops, train
>>> from mindspore.communication import init
>>>
>>> step_per_epoch = 4
>>> device_num = 8
>>>
>>> # Define the network structure.
>>> class Net(nn.Cell):
...     def __init__(self, matmul_size, strategy=None):
...         super().__init__()
...         matmul_np = np.full(matmul_size, 0.5, dtype=np.float32)
...         self.matmul_weight = ms.Parameter(ms.Tensor(matmul_np))
...         self.matmul = ops.MatMul()
...         self.neg = ops.Neg()
...         if strategy is not None:
...             self.matmul.shard(strategy)
...
...     def construct(self, inputs):
...         x = self.matmul(inputs, self.matmul_weight)
...         x = self.neg(x)
...         return x
>>>
>>> # Create dataset.
>>> def get_dataset(*inputs):
...     def generate():
...         for _ in range(step_per_epoch):
...             yield inputs
...     return generate
>>>
>>> # Train network and save distributed checkpoint.
>>> def train_net():
...     ms.set_context(mode=ms.GRAPH_MODE)
...     init()
...     np.random.seed(1)
...     input_data = np.random.rand(16, 96).astype(np.float32)
...     label_data = np.random.rand(16, 16).astype(np.float32)
...     fake_dataset = get_dataset(input_data, label_data)
...     dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
...
...     # Set parallel strategy.
...     strategy = ((1, 4), (4, 1))
...     ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num,
...                                  strategy_ckpt_save_file="./train_strategy.ckpt")
...     network = Net(matmul_size=(96, 16), strategy=strategy)
...     net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
...     net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
...     model = ms.Model(network=network, loss_fn=net_loss, optimizer=net_opt)
...     ckpt_config = train.CheckpointConfig(keep_checkpoint_max=1, integrated_save=False)
...     global_rank_id = int(os.getenv("RANK_ID"))
...     ckpt_path = "./rank_{}_ckpt".format(global_rank_id)
...     ckpt_callback = train.ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config)
...     model.train(epoch=2, train_dataset=dataset, callbacks=[ckpt_callback], dataset_sink_mode=False)
...     ms.reset_auto_parallel_context()
>>>
>>> # Load distributed checkpoint and test.
>>> def load_model():
...     ms.set_context(mode=ms.GRAPH_MODE)
...     init()
...     ms.set_auto_parallel_context(full_batch=True, parallel_mode="semi_auto_parallel",
...                                  strategy_ckpt_load_file="./train_strategy.ckpt", device_num=device_num)
...     predict_data = ms.Tensor(np.random.randn(128, 96).astype(np.float32))
...     network = Net(matmul_size=(96, 16))
...     model = ms.Model(network)
...     predict_layout = model.infer_predict_layout(ms.Tensor(predict_data))
...     ckpt_file_list = ["./rank_{}_ckpt/parallel-2_4.ckpt".format(i) for i in range(0, device_num)]
...     ms.load_distributed_checkpoint(network, ckpt_file_list, predict_layout)
...     predict_result = model.predict(predict_data)
...     print(predict_result)
>>>
>>> train_net()
>>> load_model()
[[-7.3259363 -7.497216  -7.398196  ... -7.374962  -7.204874  -7.234935 ]
 [ 3.362938   3.3535435  3.3832688 ...  3.4263954  3.279045   3.3202887]
 ...
 [ 1.6067538  1.6244187  1.5384722 ...  1.5449994  1.6195512  1.6176052]]