mindspore.load_distributed_checkpoint
- mindspore.load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None, train_strategy_filename=None, strict_load=False, dec_key=None, dec_mode='AES-GCM')
Load checkpoints into a network for distributed prediction. Used for distributed inference.
- Parameters
network (Cell) – Network for distributed prediction.
checkpoint_filenames (list[str]) – The names of the checkpoint files, ordered by rank ID.
predict_strategy (dict) – Strategy of the prediction process. If set to None, prediction is performed on a single device. Default: None.
train_strategy_filename (str) – The filename of the training strategy protocol buffer file. When train_strategy_filename is None, the training strategy file is obtained from context.get_auto_parallel_context("strategy_ckpt_load_file"). Therefore, the training strategy file must be specified in at least one of these two places. Default: None.
strict_load (bool) – Whether to strictly load the parameters into the network. If False, a parameter is loaded into the network when the suffix of its name in the checkpoint file matches the name of the parameter in the network. When the types are inconsistent, type conversion is performed between types of the same kind, for example float32 to float16. Default: False.
dec_key (Union[None, bytes]) – Byte-type key used for decryption. If the value is None, decryption is not required. Default: None.
dec_mode (str) – Valid only when dec_key is not None. Specifies the decryption mode; currently 'AES-GCM', 'AES-CBC' and 'SM4-CBC' are supported. Default: 'AES-GCM'.
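Because checkpoint_filenames must be ordered by rank ID, a plain lexicographic sort of the paths is not enough once there are ten or more ranks ("rank_10" sorts before "rank_2"). The following is a minimal sketch that orders paths numerically. It assumes the ./rank_<id>_ckpt/ directory layout used in the example below and assumes exactly one file per rank with IDs 0..N-1; the helper name order_by_rank is hypothetical and not part of the MindSpore API.

```python
import re

# Assumed layout (taken from the example): ./rank_<id>_ckpt/<name>.ckpt
RANK_DIR = re.compile(r"rank_(\d+)_ckpt")


def order_by_rank(paths):
    """Hypothetical helper: sort checkpoint paths by the numeric rank ID in each path.

    This sketch assumes one file per rank with IDs 0..N-1 and raises
    ValueError if a path has no rank ID or if an ID is missing or duplicated.
    """
    keyed = []
    for path in paths:
        match = RANK_DIR.search(path)
        if match is None:
            raise ValueError(f"no rank id found in {path!r}")
        keyed.append((int(match.group(1)), path))
    keyed.sort()  # numeric sort on the rank ID, not on the string
    ranks = [rank for rank, _ in keyed]
    if ranks != list(range(len(ranks))):
        raise ValueError(f"expected ranks 0..{len(ranks) - 1}, got {ranks}")
    return [path for _, path in keyed]


# Usage: files discovered in arbitrary order, e.g. from a directory listing.
files = [f"./rank_{i}_ckpt/parallel-2_4.ckpt" for i in (10, 2, 0, 1, 3, 4, 5, 6, 7, 8, 9, 11)]
ckpt_file_list = order_by_rank(files)
print(ckpt_file_list[:3])
```

The resulting list could then be passed as checkpoint_filenames; the explicit range check fails early instead of silently loading slices into the wrong ranks.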
- Raises
TypeError – The types of the inputs do not match the requirements.
ValueError – Failed to load checkpoint into net.
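To catch type problems before a long distributed job starts, arguments can be checked up front against the documented parameter types. The following is a minimal sketch of such a pre-flight check; check_load_args is a hypothetical helper, not MindSpore's own validation. Raising ValueError for an unsupported dec_mode is this sketch's own choice, and the check assumes dec_mode only matters when dec_key is set, as stated in the parameter list.

```python
# Decryption modes listed in the dec_mode parameter description.
SUPPORTED_DEC_MODES = ("AES-GCM", "AES-CBC", "SM4-CBC")


def check_load_args(checkpoint_filenames, predict_strategy=None,
                    train_strategy_filename=None, strict_load=False,
                    dec_key=None, dec_mode="AES-GCM"):
    """Hypothetical pre-flight check mirroring the documented parameter types."""
    if not isinstance(checkpoint_filenames, list) or not all(
            isinstance(name, str) for name in checkpoint_filenames):
        raise TypeError("checkpoint_filenames must be a list of str")
    if predict_strategy is not None and not isinstance(predict_strategy, dict):
        raise TypeError("predict_strategy must be a dict or None")
    if train_strategy_filename is not None and not isinstance(train_strategy_filename, str):
        raise TypeError("train_strategy_filename must be a str or None")
    if not isinstance(strict_load, bool):
        raise TypeError("strict_load must be a bool")
    if dec_key is not None:
        if not isinstance(dec_key, bytes):
            raise TypeError("dec_key must be bytes or None")
        # dec_mode is only checked when a key is supplied.
        if dec_mode not in SUPPORTED_DEC_MODES:
            raise ValueError(f"dec_mode must be one of {SUPPORTED_DEC_MODES}")


# Usage: passes silently for valid arguments, raises otherwise.
check_load_args(["./rank_0_ckpt/parallel-2_4.ckpt"], dec_key=b"0123456789abcdef", dec_mode="SM4-CBC")
```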
- Supported Platforms:
Ascend
GPU
Examples
Note
Before running the following examples, you need to configure the communication environment variables.
For the Ascend devices, users need to prepare the rank table and set rank_id and device_id. Please see the rank table startup for more details.
For the GPU devices, users need to prepare the host file and mpi. Please see the mpirun startup.
For the CPU device, users need to write a dynamic cluster startup script. Please see the Dynamic Cluster Startup.
>>> import os
>>> import numpy as np
>>> import mindspore as ms
>>> import mindspore.dataset as ds
>>> from mindspore import nn, ops, train
>>> from mindspore.communication import init
>>>
>>> step_per_epoch = 4
>>> device_num = 8
>>>
>>> # Define the network structure.
>>> class Net(nn.Cell):
...     def __init__(self, matmul_size, strategy=None):
...         super().__init__()
...         matmul_np = np.full(matmul_size, 0.5, dtype=np.float32)
...         self.matmul_weight = ms.Parameter(ms.Tensor(matmul_np))
...         self.matmul = ops.MatMul()
...         self.neg = ops.Neg()
...         if strategy is not None:
...             self.matmul.shard(strategy)
...
...     def construct(self, inputs):
...         x = self.matmul(inputs, self.matmul_weight)
...         x = self.neg(x)
...         return x
>>>
>>> # Create dataset.
>>> def get_dataset(*inputs):
...     def generate():
...         for _ in range(step_per_epoch):
...             yield inputs
...     return generate
>>>
>>> # Train network and save distributed checkpoint.
>>> def train_net():
...     ms.set_context(mode=ms.GRAPH_MODE)
...     init()
...     np.random.seed(1)
...     input_data = np.random.rand(16, 96).astype(np.float32)
...     label_data = np.random.rand(16, 16).astype(np.float32)
...     fake_dataset = get_dataset(input_data, label_data)
...     dataset = ds.GeneratorDataset(fake_dataset, ["input", "label"])
...
...     # Set parallel strategy.
...     strategy = ((1, 4), (4, 1))
...     ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num,
...                                  strategy_ckpt_save_file="./train_strategy.ckpt")
...     network = Net(matmul_size=(96, 16), strategy=strategy)
...     net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
...     net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
...     model = ms.Model(network=network, loss_fn=net_loss, optimizer=net_opt)
...     ckpt_config = train.CheckpointConfig(keep_checkpoint_max=1, integrated_save=False)
...     global_rank_id = int(os.getenv("RANK_ID"))
...     ckpt_path = "./rank_{}_ckpt".format(global_rank_id)
...     ckpt_callback = train.ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config)
...     model.train(epoch=2, train_dataset=dataset, callbacks=[ckpt_callback], dataset_sink_mode=False)
...     ms.reset_auto_parallel_context()
>>>
>>> # Load distributed checkpoint and test.
>>> def load_model():
...     ms.set_context(mode=ms.GRAPH_MODE)
...     init()
...     ms.set_auto_parallel_context(full_batch=True, parallel_mode="semi_auto_parallel",
...                                  strategy_ckpt_load_file="./train_strategy.ckpt", device_num=device_num)
...     predict_data = ms.Tensor(np.random.randn(128, 96).astype(np.float32))
...     network = Net(matmul_size=(96, 16))
...     model = ms.Model(network)
...     # predict_data is already a Tensor, so it can be passed directly.
...     predict_layout = model.infer_predict_layout(predict_data)
...     ckpt_file_list = ["./rank_{}_ckpt/parallel-2_4.ckpt".format(i) for i in range(0, device_num)]
...     ms.load_distributed_checkpoint(network, ckpt_file_list, predict_layout)
...     predict_result = model.predict(predict_data)
...     print(predict_result)
>>>
>>> train_net()
>>> load_model()
[[-7.3259363 -7.497216  -7.398196  ... -7.374962  -7.204874  -7.234935 ]
 [ 3.362938   3.3535435  3.3832688 ... 3.4263954  3.279045   3.3202887]
 ...
 [ 1.6067538  1.6244187  1.5384722 ... 1.5449994  1.6195512  1.6176052]]
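In the example, the shard strategy ((1, 4), (4, 1)) splits the (16, 96) input along its columns and the (96, 16) weight along its rows into 4 slices each. With device_num = 8, each slice is therefore held by two devices. Because integrated_save=False, each rank's checkpoint is expected to contain only its own slice of matmul_weight, which is why all eight rank files are passed to load_distributed_checkpoint. The slice arithmetic can be sketched in plain Python. slice_shape and replica_count are hypothetical helpers, and the sketch assumes every dimension divides evenly by its split count.

```python
from math import prod


def slice_shape(shape, split):
    """Hypothetical helper: per-device slice shape of a tensor under a shard strategy tuple."""
    if len(shape) != len(split):
        raise ValueError("strategy length must match tensor rank")
    for dim, cuts in zip(shape, split):
        if dim % cuts:
            raise ValueError(f"dimension {dim} is not divisible by {cuts}")
    return tuple(dim // cuts for dim, cuts in zip(shape, split))


def replica_count(split, device_num):
    """Hypothetical helper: number of devices holding an identical copy of each slice."""
    shards = prod(split)
    if device_num % shards:
        raise ValueError("device_num must be a multiple of the number of slices")
    return device_num // shards


# Values from the example: input (16, 96), weight (96, 16), 8 devices.
input_split, weight_split = (1, 4), (4, 1)
print(slice_shape((16, 96), input_split))   # (16, 24)
print(slice_shape((96, 16), weight_split))  # (24, 16)
print(replica_count(weight_split, 8))       # 2
```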