Source code for mindspore.nn.utils.init

# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""init for nn.Cell."""
from __future__ import absolute_import

from contextlib import contextmanager
from mindspore.common.parameter import Parameter


@contextmanager
def no_init_parameters():
    r"""
    When a checkpoint is going to be loaded, the parameters created during network instantiation
    are initialized anyway and occupy physical memory, only for their values to be replaced by the
    checkpoint afterwards. This context manager can be applied around network instantiation to add
    an attribute `init_param` to all parameters within the current Cell and set it to
    `init_param=False`. When `init_param=False` is detected, parameter initialization is skipped
    and the parameters take their values directly from the checkpoint during loading, which
    improves performance and reduces physical memory usage.

    Note:
        Only the initialization of parameters created with `initializer` can be skipped.
        Parameters created from a `Tensor` or `numpy` array cannot be skipped.

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import nn, ops, load_checkpoint
        >>> from mindspore.common.initializer import initializer
        >>> from mindspore.nn.utils import no_init_parameters
        >>> # 1. Wrap the instantiation of the network that requires delayed initialization
        >>> class Net(nn.Cell):
        ...     def __init__(self, in_channels, out_channels):
        ...         super().__init__()
        ...         self.weight = ms.Parameter(initializer("normal", [in_channels, out_channels], ms.float32))
        ...         self.bias = ms.Parameter(initializer("normal", [out_channels], ms.float32))
        ...         self.matmul = ops.MatMul()
        ...         self.add = ops.Add()
        ...
        ...     def construct(self, x):
        ...         x = self.matmul(x, self.weight)
        ...         x = self.add(x, self.bias)
        ...         return x
        >>> with no_init_parameters():
        ...     # After instantiation, none of the parameters in the net are initialized
        ...     net = Net(28*28, 64)
        >>> # 2. Load checkpoint parameters into the net
        >>> load_checkpoint('./checkpoint/test_net.ckpt', net=net)
        >>> # 3. After loading the checkpoint, call init_parameters_data() manually to initialize
        >>> # any parameters in the net that are still uninitialized, if needed. If the network is
        >>> # executed, the framework calls this interface automatically.
        >>> net.init_parameters_data()
    """
    init_class = Parameter
    # Set the flag on the Parameter class itself, so every parameter created while
    # the context manager is active defers its initializer.
    setattr(init_class, "init_param", False)
    try:
        yield
    finally:
        # Restore the default behaviour once the context exits.
        setattr(init_class, "init_param", True)
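
As a quick illustration of the mechanism above, the following sketch (not part of the module source;
the shape, the parameter name, and the use of `Parameter.init_data()` are illustrative assumptions)
checks the class-level `init_param` flag that the context manager toggles on `Parameter` and then
materializes a deferred initializer by hand.

import mindspore as ms
from mindspore.common.initializer import initializer
from mindspore.nn.utils import no_init_parameters

with no_init_parameters():
    # Inside the context, the class-level flag is False, so this parameter keeps its
    # initializer deferred instead of allocating real data.
    w = ms.Parameter(initializer("normal", [64, 64], ms.float32), name="w")
    print(getattr(ms.Parameter, "init_param", True))   # False while the context is active
print(getattr(ms.Parameter, "init_param", True))        # True again after exit
w.init_data()  # materialize the deferred initializer if no checkpoint supplies a value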