mindspore.nn.utils.no_init_parameters
- mindspore.nn.utils.no_init_parameters()[源代码]
使用该接口可跳过parameter初始化。
加载checkpoint的场景下网络实例化中parameter会实例化并占用物理内存,加载checkpoint会替换parameter值, 使用该接口在网络实例化时用装饰器给当前Cell里所有parameter添加一个属性: init_param ,并设为 init_param=False , 检测到 init_param=False 时跳过parameter初始化,加载checkpoint时从checkpoint给parameter赋值,可优化性能和减少物理内存。
说明
只能跳过使用 initializer 创建的parameter的初始化,由 Tensor 或 numpy 创建的parameter无法跳过。
样例:
>>> import mindspore as ms >>> from mindspore import nn, ops, load_checkpoint >>> from mindspore.common.initializer import initializer >>> from mindspore.nn.utils import no_init_parameters >>> # 1. Add a decorator to the network that requires delayed initialization >>> class Net(nn.Cell): ... def __init__(self, in_channels, out_channels): ... super().__init__() ... self.weight = ms.Parameter(initializer("normal", [in_channels, out_channels], ms.float32)) ... self.bias = ms.Parameter(initializer("normal", [out_channels], ms.float32)) ... self.matmul = ops.MatMul() ... self.add = ops.Add() ... ... def construct(self, x): ... x = self.matmul(x, self.weight) ... x = self.add(x, self.bias) ... return x >>> with no_init_parameters(): ... # After instantiation, all parameters in the net are not initialized ... net = Net(28*28, 64) >>> # 2. Load checkpoint parameters to the net >>> load_checkpoint('./checkpoint/test_net.ckpt', net=net) >>> # 3. After loading the checkpoint, manually call init_parameters_data() to initialize >>> # the uninitialized parameters in the net if need. If the network is executed, >>> # the framework will automatically call this interface. >>> net.init_parameters_data()