mindspore.nn.utils.no_init_parameters

mindspore.nn.utils.no_init_parameters()

When a checkpoint is loaded, the parameters created during network instantiation are first initialized and occupy physical memory, and loading the checkpoint then replaces their values. Applying this interface during network instantiation adds the attribute init_param to every parameter in the current Cell and sets it to init_param=False. When init_param=False is detected, parameter initialization is skipped and the parameters are assigned values directly from the checkpoint during loading. This can improve performance and reduce physical memory usage.
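To make the mechanism concrete, here is a minimal pure-Python sketch of deferred initialization driven by a context manager. It is a toy model under stated assumptions, not MindSpore's implementation: `ToyCell`, `ToyParameter`, `toy_no_init_parameters` and `toy_load_checkpoint` are hypothetical names, a class-level flag stands in for however the framework marks parameters, and plain lists stand in for tensors. It shows that parameters created inside the `with` block carry `init_param=False`, never run their initializer, and receive their values straight from the checkpoint.

```python
import contextlib
import random

# Toy model only: ToyCell, ToyParameter, toy_no_init_parameters and
# toy_load_checkpoint are hypothetical names, not MindSpore APIs.
class ToyCell:
    _no_init = False  # class-level flag read by every ToyParameter at construction


@contextlib.contextmanager
def toy_no_init_parameters():
    """While active, new parameters are created with init_param=False."""
    ToyCell._no_init = True
    try:
        yield
    finally:
        ToyCell._no_init = False  # restore eager initialization even on error


class ToyParameter:
    init_calls = 0  # how many times an initializer actually allocated data

    def __init__(self, name, size):
        self.name = name
        self.size = size
        self.init_param = not ToyCell._no_init
        self.data = None
        if self.init_param:
            self.init_data()

    def init_data(self):
        ToyParameter.init_calls += 1
        self.data = [random.gauss(0.0, 1.0) for _ in range(self.size)]


class ToyNet(ToyCell):
    def __init__(self, n_in, n_out):
        self.weight = ToyParameter("weight", n_in * n_out)
        self.bias = ToyParameter("bias", n_out)

    def parameters(self):
        return [self.weight, self.bias]


def toy_load_checkpoint(ckpt, net):
    # Values are assigned directly; deferred parameters never held random data.
    for p in net.parameters():
        p.data = list(ckpt[p.name])


with toy_no_init_parameters():
    net = ToyNet(3, 2)
print(ToyParameter.init_calls)  # 0: no initializer ran
toy_load_checkpoint({"weight": [0.5] * 6, "bias": [0.1] * 2}, net)
print(net.bias.data)  # [0.1, 0.1]
```

The flag is reset in a `finally` clause so that an exception raised during instantiation does not leave later networks stuck in deferred mode.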

Note

Only the initialization of parameters created with initializer can be skipped. Parameters created from a Tensor or a numpy array cannot skip initialization.

Examples

>>> import mindspore as ms
>>> from mindspore import nn, ops, load_checkpoint
>>> from mindspore.common.initializer import initializer
>>> from mindspore.nn.utils import no_init_parameters
>>> # 1. Define the network; instantiating it inside no_init_parameters() delays initialization
>>> class Net(nn.Cell):
...     def __init__(self, in_channels, out_channels):
...         super().__init__()
...         self.weight = ms.Parameter(initializer("normal", [in_channels, out_channels], ms.float32))
...         self.bias = ms.Parameter(initializer("normal", [out_channels], ms.float32))
...         self.matmul = ops.MatMul()
...         self.add = ops.Add()
...
...     def construct(self, x):
...         x = self.matmul(x, self.weight)
...         x = self.add(x, self.bias)
...         return x
>>> with no_init_parameters():
...     # After instantiation, none of the parameters in the net are initialized
...     net = Net(28*28, 64)
>>> # 2. Load the checkpoint parameters into the net
>>> load_checkpoint('./checkpoint/test_net.ckpt', net=net)
>>> # 3. After loading the checkpoint, call init_parameters_data() manually if needed to
>>> #    initialize any parameters in the net that are still uninitialized. When the network
>>> #    is executed, the framework calls this interface automatically.
>>> net.init_parameters_data()
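Step 3 matters when some parameters are not covered by the checkpoint. The pure-Python sketch below models that case under explicit assumptions: the checkpoint is partial (it lacks `bias`), `ToyNet`, `ToyParameter` and `toy_load_checkpoint` are hypothetical stand-ins, and `init_parameters_data` here only mirrors the documented behavior of initializing whatever is still uninitialized, with `__call__` imitating the framework triggering it before execution. It is not MindSpore's implementation.

```python
# Toy model only: names mirror the documentation but are NOT MindSpore APIs.
class ToyParameter:
    def __init__(self, name, fill, size, deferred):
        self.name, self.fill, self.size = name, fill, size
        self.data = None if deferred else [fill] * size

    def init_data(self):
        if self.data is None:  # leave checkpoint-loaded values untouched
            self.data = [self.fill] * self.size


class ToyNet:
    def __init__(self, deferred):
        self.params = {
            "weight": ToyParameter("weight", 0.0, 4, deferred),
            "bias": ToyParameter("bias", 0.0, 2, deferred),
        }

    def init_parameters_data(self):
        for p in self.params.values():
            p.init_data()

    def __call__(self, x):
        # Imitates the framework initializing leftovers automatically on execution.
        self.init_parameters_data()
        w, b = self.params["weight"].data, self.params["bias"].data
        return [x * sum(w) + bi for bi in b]


def toy_load_checkpoint(ckpt, net):
    for name, values in ckpt.items():
        net.params[name].data = list(values)


net = ToyNet(deferred=True)
toy_load_checkpoint({"weight": [1.0, 1.0, 1.0, 1.0]}, net)  # partial: no "bias"
print(net.params["bias"].data)    # None: absent from the checkpoint
net.init_parameters_data()
print(net.params["bias"].data)    # [0.0, 0.0]
print(net.params["weight"].data)  # [1.0, 1.0, 1.0, 1.0]: loaded values kept
```

Because `init_data` only fills empty parameters, calling `init_parameters_data()` manually and then executing the network does not overwrite anything loaded from the checkpoint.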