mindspore.load_param_into_net
- mindspore.load_param_into_net(net, parameter_dict, strict_load=False, remove_redundancy=False)[源代码]
将参数加载到网络中,返回网络中没有被加载的参数列表。
说明
当加载去冗余的参数字典时,网络应该是编译过的。
- 参数:
net (Cell) - 将要加载参数的网络。
parameter_dict (dict) - 加载checkpoint文件得到的字典。
strict_load (bool) - 是否将参数严格加载到网络中。如果是
False
, 它将以相同的后缀名将参数字典中的参数加载到网络中,并会在精度不匹配时,进行精度转换,比如将 float32 转换为 float16 。默认值:False
。remove_redundancy (bool) - 是否开启加载去冗余保存的checkpoint。去冗余是指去除数据并行模式下的冗余数据。默认值:
False
,不开启去冗余加载。
- 返回:
param_not_load (List),网络中没有被加载的参数。
ckpt_not_load (List),checkpoint文件中没有被加载的参数。
- 异常:
TypeError - 如果参数不是Cell,或者 parameter_dict 不是Parameter类型的字典。
样例:
>>> import mindspore as ms >>> >>> # Define the network structure of LeNet5. Refer to >>> # https://gitee.com/mindspore/docs/blob/r2.4.1/docs/mindspore/code/lenet.py >>> net = LeNet5() >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt" >>> param_dict = ms.load_checkpoint(ckpt_file_name, filter_prefix="conv1") >>> param_not_load, _ = ms.load_param_into_net(net, param_dict) >>> print(param_not_load) ['conv1.weight']
- 教程样例: