mindspore.load_checkpoint
- mindspore.load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode='AES-GCM', specify_prefix=None)
Load checkpoint information from the specified file.
Note
specify_prefix and filter_prefix do not affect each other.
If no parameters are loaded from the checkpoint file, a ValueError is raised.
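The interaction between the two prefix arguments can be modeled in a few lines of plain Python. This is a toy sketch, not MindSpore's code: `select_params` is a hypothetical helper, and it assumes a name is kept when it starts with any `specify_prefix` (if one is given) and does not start with any `filter_prefix`. That rule is consistent with the documented example, where `specify_prefix="conv"` with `filter_prefix="conv1"` still yields `conv2.weight`.

```python
def select_params(names, specify_prefix=None, filter_prefix=None):
    """Toy model of specify_prefix / filter_prefix selection (hypothetical helper)."""
    # Normalize a single string into a one-element tuple.
    if isinstance(specify_prefix, str):
        specify_prefix = (specify_prefix,)
    if isinstance(filter_prefix, str):
        filter_prefix = (filter_prefix,)

    selected = []
    for name in names:
        # With no specify_prefix, every name is a candidate.
        keep = not specify_prefix or any(name.startswith(p) for p in specify_prefix)
        # Assumption: filter_prefix is checked independently and excludes any match.
        if filter_prefix and any(name.startswith(p) for p in filter_prefix):
            keep = False
        if keep:
            selected.append(name)

    # Mirrors the note: loading nothing at all is an error.
    if not selected:
        raise ValueError("No parameters were loaded from the checkpoint file.")
    return selected


names = ["conv1.weight", "conv2.weight", "fc1.weight", "fc1.bias"]
print(select_params(names, specify_prefix="conv", filter_prefix="conv1"))
# ['conv2.weight']
```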
- Parameters
ckpt_file_name (str) – Checkpoint file name.
net (Cell) – The network into which the parameters will be loaded. Default: None.
strict_load (bool) – Whether to load parameters into net strictly. If False, a network parameter is also loaded when the suffix of a parameter name in the checkpoint file matches that network parameter's name. In this mode, when data types are inconsistent, type conversion is performed between parameters of the same type category, such as float32 to float16. Default: False.
filter_prefix (Union[str, list[str], tuple[str]]) – Parameters starting with the filter_prefix will not be loaded. Default: None.
dec_key (Union[None, bytes]) – Key of type bytes used for decryption. If None, no decryption is performed. Default: None.
dec_mode (str) – Decryption mode. Takes effect only when dec_key is not None. Currently ‘AES-GCM’ and ‘AES-CBC’ are supported. Default: ‘AES-GCM’.
specify_prefix (Union[str, list[str], tuple[str]]) – Parameters starting with the specify_prefix will be loaded. Default: None.
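The effect of `strict_load` on name matching can be illustrated with a simplified sketch. `match_names` is a hypothetical helper, not part of MindSpore: it assumes exact name matches are used in both modes and that, when `strict_load` is False, a checkpoint name ending with the network parameter's name is accepted as a fallback (for example, a checkpoint saved from a wrapper cell that prefixed every name). The float32-to-float16 conversion is not modeled here.

```python
def match_names(net_param_names, ckpt_names, strict_load=False):
    """Map network parameter names to checkpoint names (simplified, hypothetical)."""
    ckpt_set = set(ckpt_names)
    mapping = {}
    for name in net_param_names:
        if name in ckpt_set:
            mapping[name] = name  # exact name match, used in both modes
        elif not strict_load:
            # Non-strict fallback: accept the first checkpoint name whose
            # suffix is the network parameter name.
            for ckpt_name in ckpt_names:
                if ckpt_name.endswith(name):
                    mapping[name] = ckpt_name
                    break
    return mapping


net_names = ["conv1.weight", "fc1.bias"]
ckpt_names = ["network.conv1.weight", "network.fc1.bias", "learning_rate"]
print(match_names(net_names, ckpt_names))
# {'conv1.weight': 'network.conv1.weight', 'fc1.bias': 'network.fc1.bias'}
print(match_names(net_names, ckpt_names, strict_load=True))
# {}
```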
- Returns
Dict, where each key is a parameter name and each value is a Parameter or a string. If the checkpoint was saved using the append_dict parameter of mindspore.save_checkpoint() or the append_info parameter of CheckpointConfig, and those dicts have string values, the corresponding entries are loaded as strings; in all other cases the values are Parameter objects.
- Raises
ValueError – Checkpoint file’s format is incorrect.
ValueError – The parameter dict is None after loading the checkpoint file.
TypeError – The type of specify_prefix or filter_prefix is incorrect.
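The TypeError above concerns the accepted prefix types, Union[str, list[str], tuple[str]]. The following is a minimal sketch of such a check; `normalize_prefix` is a hypothetical helper, and the error message wording is an assumption rather than MindSpore's actual message.

```python
def normalize_prefix(prefix, arg_name):
    """Validate a prefix argument and return it as a tuple of str (or None).

    Hypothetical helper: accepts None, a str, or a list/tuple of str;
    any other type raises TypeError.
    """
    if prefix is None:
        return None
    if isinstance(prefix, str):
        return (prefix,)
    if isinstance(prefix, (list, tuple)) and all(isinstance(p, str) for p in prefix):
        return tuple(prefix)
    raise TypeError(
        f"{arg_name} should be str, list[str] or tuple[str], "
        f"but got {type(prefix).__name__}."
    )


print(normalize_prefix("conv", "filter_prefix"))             # ('conv',)
print(normalize_prefix(["conv1", "fc"], "specify_prefix"))   # ('conv1', 'fc')
```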
Examples
>>> import mindspore as ms
>>>
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
>>> param_dict = ms.load_checkpoint(ckpt_file_name, filter_prefix="conv1", specify_prefix="conv")
>>> print(param_dict["conv2.weight"])
Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
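Because the returned dict can mix Parameter values with string entries that came from append_dict or append_info, it can be convenient to separate the two before further use. The sketch below is a hypothetical helper, not a MindSpore API; it relies only on `isinstance(value, str)`, so it runs without importing MindSpore.

```python
def split_checkpoint_dict(param_dict):
    """Split a loaded checkpoint dict into (parameters, string entries).

    Hypothetical helper: per the Returns description, string values come
    from string-valued append_dict / append_info entries; everything else
    is treated as a parameter.
    """
    params, extras = {}, {}
    for name, value in param_dict.items():
        if isinstance(value, str):
            extras[name] = value
        else:
            params[name] = value
    return params, extras


# Usage with a real checkpoint (requires MindSpore and an existing file):
# import mindspore as ms
# params, extras = split_checkpoint_dict(ms.load_checkpoint("./checkpoint/LeNet5-1_32.ckpt"))
```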