mindspore.nn.utils.no_init_parameters

查看源文件
mindspore.nn.utils.no_init_parameters()[源代码]

使用该接口可跳过parameter初始化。

加载checkpoint的场景下网络实例化中parameter会实例化并占用物理内存,加载checkpoint会替换parameter值, 使用该接口在网络实例化时用装饰器给当前Cell里所有parameter添加一个属性: init_param ,并设为 init_param=False , 检测到 init_param=False 时跳过parameter初始化,加载checkpoint时从checkpoint给parameter赋值,可优化性能和减少物理内存。

说明

只能跳过使用 initializer 创建的parameter的初始化,由 Tensornumpy 创建的parameter无法跳过。

样例:

>>> import mindspore as ms
>>> from mindspore import nn, ops, load_checkpoint
>>> from mindspore.common.initializer import initializer
>>> from mindspore.nn.utils import no_init_parameters
>>> # 1. Add a decorator to the network that requires delayed initialization
>>> class Net(nn.Cell):
...     def __init__(self, in_channels, out_channels):
...         super().__init__()
...         self.weight = ms.Parameter(initializer("normal", [in_channels, out_channels], ms.float32))
...         self.bias = ms.Parameter(initializer("normal", [out_channels], ms.float32))
...         self.matmul = ops.MatMul()
...         self.add = ops.Add()
...
...     def construct(self, x):
...         x = self.matmul(x, self.weight)
...         x = self.add(x, self.bias)
...         return x
>>> with no_init_parameters():
...     # After instantiation, all parameters in the net are not initialized
...     net = Net(28*28, 64)
>>> # 2. Load checkpoint parameters to the net
>>> load_checkpoint('./checkpoint/test_net.ckpt', net=net)
>>> # 3. After loading the checkpoint, manually call init_parameters_data() to initialize
>>> #    the uninitialized parameters in the net if need. If the network is executed,
>>> #    the framework will automatically call this interface.
>>> net.init_parameters_data()