mindspore.load_param_into_net

mindspore.load_param_into_net(net, parameter_dict, strict_load=False)

Load parameters into the network. Returns the names of the network parameters that were not loaded and the names of the checkpoint parameters that were not loaded into the network.

Parameters
  • net (Cell) – The network into which the parameters will be loaded.

  • parameter_dict (dict) – The dictionary produced by loading a checkpoint file, for example with mindspore.load_checkpoint. Its keys are parameter names and its values are the corresponding Parameter objects.

  • strict_load (bool) – Whether to load the parameters into net strictly. If False, a parameter is also loaded into net when its name in the checkpoint file ends with the name of the corresponding parameter in the network. In this mode, when the data types are inconsistent, type conversion is performed between parameters of the same kind, for example from float32 to float16. Default: False.
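The non-strict type-conversion rule can be sketched in plain Python. This is a hypothetical model, not MindSpore code: data types are represented as strings, the float and integer groups below are assumptions about what "the same kind" covers, and treating every type mismatch as non-convertible when strict_load is True is also an assumption, since the description above only covers the non-strict case.

```python
# Hypothetical model of the dtype rule for strict_load=False.
# Dtypes are plain strings; this is not MindSpore's implementation.
FLOAT_TYPES = {"float16", "float32", "float64"}  # assumed "same kind" group
INT_TYPES = {"int8", "int16", "int32", "int64"}  # assumed "same kind" group


def can_convert(ckpt_dtype: str, net_dtype: str, strict_load: bool) -> bool:
    """Return True if a checkpoint value of ckpt_dtype may be loaded into a
    network parameter of net_dtype under this sketch's rules."""
    if ckpt_dtype == net_dtype:
        return True  # identical types need no conversion
    if strict_load:
        return False  # assumption: no implicit conversion in strict mode
    pair = {ckpt_dtype, net_dtype}
    # Conversion is allowed only within the same group of types.
    return pair <= FLOAT_TYPES or pair <= INT_TYPES


print(can_convert("float32", "float16", strict_load=False))  # True
print(can_convert("float32", "float16", strict_load=True))   # False
print(can_convert("float32", "int32", strict_load=False))    # False
```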

Returns

  • param_not_load (List) – Names of the network parameters that were not loaded.

  • ckpt_not_load (List) – Names of the parameters in the checkpoint file that were not loaded into the network.
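How the two returned lists relate to each other can be illustrated with a pure-Python sketch that works on parameter names only. The function split_by_name is hypothetical and not part of MindSpore. It models the name matching described for strict_load under two assumptions: exact name matches are tried first for all parameters, and a non-strict suffix match must fall on a "." boundary. It does not model type conversion or the loading of actual values.

```python
from typing import Iterable, List, Tuple


def split_by_name(
    net_names: Iterable[str], ckpt_names: Iterable[str], strict_load: bool
) -> Tuple[List[str], List[str]]:
    """Hypothetical model returning (param_not_load, ckpt_not_load)."""
    remaining = list(ckpt_names)  # checkpoint keys not yet consumed
    unmatched: List[str] = []

    # Pass 1: exact name matches.
    for name in net_names:
        if name in remaining:
            remaining.remove(name)
        else:
            unmatched.append(name)
    if strict_load:
        return unmatched, remaining

    # Pass 2 (non-strict only): checkpoint key ending with "." + network name,
    # e.g. "network.conv1.weight" matches "conv1.weight".
    param_not_load: List[str] = []
    for name in unmatched:
        match = next((k for k in remaining if k.endswith("." + name)), None)
        if match is None:
            param_not_load.append(name)
        else:
            remaining.remove(match)
    return param_not_load, remaining


net = ["conv1.weight", "fc1.weight", "fc1.bias"]
ckpt = ["network.conv1.weight", "network.fc1.weight", "global_step"]
print(split_by_name(net, ckpt, strict_load=False))  # (['fc1.bias'], ['global_step'])
```

In strict mode the same inputs leave every network name in the first list and every checkpoint key in the second, because no name matches exactly.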

Raises

TypeError – If net is not a Cell, or parameter_dict is not a dictionary of Parameter objects.

Examples

>>> import mindspore as ms
>>>
>>> # Define the network structure of LeNet5. Refer to
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
>>> param_dict = ms.load_checkpoint(ckpt_file_name, filter_prefix="conv1")
>>> param_not_load, _ = ms.load_param_into_net(net, param_dict)
>>> print(param_not_load)
['conv1.weight']