mindspore.load_obf_params_into_net

mindspore.load_obf_params_into_net(network, target_modules=None, obf_ratios=None, obf_config=None, data_parallel_num=1, **kwargs)[源代码]

根据用户配置的混淆策略,对模型结构进行修改,并将混淆态Checkpoint加载到模型中。

参数:
  • network (nn.Cell) - 待混淆的原始网络。

  • target_modules (list[str]) - 需要混淆的目标算子。第一个字符串表示目标算子在原网络中的路径,应该是 "A/B/C" 的形式。第二个字符串表示同一个路径下的多个目标算子名,它应该是 "D|E|F" 的形式。例如,GPT2的 target_modules 可以是 ['backbone/blocks/attention', 'dense1|dense2|dense3'] 。如果 target_modules 有第三个值,它的格式应该是 "obfuscate_layers:all""obfuscate_layers:int" ,这表示需要混淆重复层(如transformer层或resnet块)的层数。默认值:None

  • obf_ratios (Tensor) - 混淆系数,由 mindspore.obfuscate_ckpt 接口生成。默认值:None

  • obf_config (dict) - 模型混淆策略的配置。默认值:None

  • data_parallel_num (int) - 模型并行训练的数据并行度。默认值:1

  • kwargs (dict) - 配置选项字典。

    • ignored_func_decorators (list[str]) - Python代码中函数装饰器的名字列表。

    • ignored_class_decorators (list[str]) - Python代码中类装饰器的名字列表。

返回:

混淆后模型(nn.Cell)。

异常:
  • TypeError - network 不是nn.Cell类型。

  • TypeError - obf_ratios 不是Tensor类型。

  • TypeError - target_modules 不是list类型。

  • TypeError - obf_config 不是dict类型。

  • TypeError - target_modules 中的元素不是str类型。

  • ValueError - obf_ratios 为空。

  • ValueError - target_modules 中的元素个数小于2。

  • ValueError - target_modules 的第一个字符串包含大小写字母、数字、 '_''/' 以外的字符。

  • ValueError - target_modules 的第二个字符串为空或包含大小写字母,数字, '_''/'' 以外的字符。

  • ValueError - target_modules 的第三个字符串不是 "obfuscate_layers:all""obfuscate_layers:int" 的格式。

  • TypeError - ignored_func_decorators 不是字符串列表,或 ignored_class_decorators 不是字符串列表。

样例:

>>> from mindspore import obfuscate_ckpt, save_checkpoint, load_checkpoint, Tensor
>>> import mindspore.common.dtype as mstype
>>> import numpy as np
>>> # Refer to https://gitee.com/mindspore/docs/blob/r2.4.1/docs/mindspore/code/lenet.py
>>> net = LeNet5()
>>> save_checkpoint(net, './test_net.ckpt')
>>> target_modules = ['', 'fc1|fc2']
>>> # obfuscate ckpt files
>>> obfuscate_ckpt(net, './', target_modules=target_modules, saved_path='./')
>>> # load obf ckpt into network
>>> new_net = LeNet5()
>>> load_checkpoint('./test_net_obf.ckpt', new_net)
>>> obf_net = load_obf_params_into_net(new_net, target_modules)