# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""init for nn.Cell."""
from __future__ import absolute_import
from contextlib import contextmanager
from mindspore.common.parameter import Parameter
@contextmanager
def no_init_parameters():
    r"""
    Context manager used to skip parameter initialization.

    In scenarios where a checkpoint is loaded, parameters within the network instantiation will be
    instantiated and occupy physical memory. Loading a checkpoint will replace the parameter values.
    This context manager can be applied during network instantiation to set the attribute
    `init_param` to ``False`` for all parameters created within the ``with`` block.
    When `init_param=False` is detected, the initialization of the parameters is skipped,
    and the parameters are assigned values directly from the checkpoint during loading,
    which can optimize performance and reduce physical memory usage.

    The flag is stored as a class attribute on `Parameter`. On exit, including when the body
    raises, the value that was in effect on entry is restored, so nested
    ``with no_init_parameters():`` blocks keep initialization disabled until the outermost
    block exits.

    Note:
        Only the initialization of parameters created with `initializer` can be skipped.
        Parameters created by `Tensor` or `numpy` cannot be skipped.

    Examples:
        >>> from mindspore.nn.utils import no_init_parameters
        >>> # 1. Wrap the instantiation of the network that requires delayed initialization
        >>> with no_init_parameters():
        ...     # After instantiation, all parameters in the net are not initialized
        ...     net = Net()
        >>> # 2. Load checkpoint parameters to the net
        >>> load_checkpoint(ckpt_file, net=net)
        >>> # 3. After loading the checkpoint, manually call init_parameters_data() to initialize
        >>> # the uninitialized parameters in the net if need. If the network is executed,
        >>> # the framework will automatically call this interface.
        >>> net.init_parameters_data()
    """
    # Remember the flag seen on entry instead of assuming it was True; falling back to
    # True when the attribute is absent keeps the non-nested result the same as before.
    previous_flag = getattr(Parameter, "init_param", True)
    Parameter.init_param = False
    try:
        yield
    finally:
        # Restore rather than force True, so an enclosing no_init_parameters() block
        # is not switched back on prematurely by an inner one.
        Parameter.init_param = previous_flag