# mindspore.nn.utils.hooks source code

# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hooks"""
from collections import OrderedDict
import weakref
from typing import Any, Tuple

__all__ = ["RemovableHandle"]


class RemovableHandle:
    r"""
    A handle which provides the capability to remove a hook.

    The handle keeps only weak references to the dictionaries it is given, so it never
    keeps them alive. Calling :meth:`remove` (or leaving a ``with`` block that uses the
    handle) deletes the entry keyed by this handle's ``id`` from `hooks_dict` and from
    every extra dictionary that is still alive.

    Args:
        hooks_dict (dict): A dictionary of hooks, indexed by hook `id`. It must support
            weak references (e.g. ``collections.OrderedDict``); a plain ``dict``
            instance does not, and constructing the handle with one raises ``TypeError``.

    Keyword Args:
        extra_dict (Union[dict, List[dict], Tuple[dict, ...]], optional): An additional
            dictionary, or a list or tuple of dictionaries, whose keys will be deleted
            when the same keys are removed from `hooks_dict`. Each dictionary must support
            weak references, like `hooks_dict`. Any other value is ignored.
            Default ``None``.
    """

    # Unique id of this handle; it is the key removed from `hooks_dict` and the extra dicts.
    id: int
    # Class-wide counter from which every new handle draws its id.
    next_id: int = 0

    def __init__(self, hooks_dict: Any, *, extra_dict: Any = None) -> None:
        self.hooks_dict_ref = weakref.ref(hooks_dict)
        self.id = RemovableHandle.next_id
        RemovableHandle.next_id += 1

        self.extra_dict_ref: Tuple = ()
        if isinstance(extra_dict, dict):
            self.extra_dict_ref = (weakref.ref(extra_dict),)
        elif isinstance(extra_dict, (list, tuple)):
            # Tuples used to be silently ignored; accept them alongside lists.
            self.extra_dict_ref = tuple(weakref.ref(d) for d in extra_dict)

    def remove(self) -> None:
        """Delete this handle's ``id`` from `hooks_dict` and from every live extra dict.

        Dictionaries that have already been garbage-collected, or that no longer contain
        the key, are skipped, so calling this more than once is harmless.
        """
        hooks_dict = self.hooks_dict_ref()
        if hooks_dict is not None and self.id in hooks_dict:
            del hooks_dict[self.id]

        for ref in self.extra_dict_ref:
            extra_dict = ref()
            if extra_dict is not None and self.id in extra_dict:
                del extra_dict[self.id]

    def __getstate__(self):
        """Return picklable state.

        Returns:
            tuple: ``(hooks_dict, id, extra_dicts)``, where each dictionary is ``None`` if
            it had already been garbage-collected.
        """
        if self.extra_dict_ref is None:
            # __init__ and __setstate__ always store a tuple; this branch only applies
            # if the attribute was reset to None from outside.
            return (self.hooks_dict_ref(), self.id)
        return (self.hooks_dict_ref(), self.id, tuple(ref() for ref in self.extra_dict_ref))

    def __setstate__(self, state) -> None:
        """Restore a handle from the tuple produced by :meth:`__getstate__`."""
        if state[0] is None:
            # The hooks dict was gone when pickled: reference a temporary OrderedDict
            # with no other referents, which yields a dead reference once collected.
            self.hooks_dict_ref = weakref.ref(OrderedDict())
        else:
            self.hooks_dict_ref = weakref.ref(state[0])
        self.id = state[1]
        # Keep ids handed out after un-pickling from colliding with the restored one.
        RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)

        if len(state) < 3 or state[2] is None:
            self.extra_dict_ref = ()
        else:
            # Entries are None for extra dicts collected before pickling. weakref.ref(None)
            # raises TypeError, and a dead reference would be a no-op in remove(), so skip them.
            self.extra_dict_ref = tuple(weakref.ref(d) for d in state[2] if d is not None)

    def __enter__(self) -> "RemovableHandle":
        return self

    def __exit__(self, type: Any, value: Any, tb: Any) -> None:
        # Leaving a ``with`` block removes the hook, whether or not an exception occurred.
        self.remove()