# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hooks"""
from collections import OrderedDict
import weakref
from typing import Any, Tuple
__all__ = ["RemovableHandle"]


class RemovableHandle:
    r"""
    A handle which provides the capability to remove a hook.

    Every handle receives an integer ``id`` drawn from a class-wide counter
    (``RemovableHandle.next_id``). The handle stores only weak references to the
    dictionaries it manages, so it does not keep them alive by itself.

    Args:
        hooks_dict (dict): A dictionary of hooks, indexed by hook ``id``. It must
            support weak references (e.g. ``collections.OrderedDict``; instances
            of the plain built-in ``dict`` cannot be weakly referenced).

    Keyword Args:
        extra_dict (Union[dict, List[dict]], optional): An additional dictionary or
            list of dictionaries whose keys will be deleted when the same keys are
            removed from ``hooks_dict``. Values that are neither a dict nor a list
            are ignored. Default ``None``.
    """

    id: int
    # Class-wide counter used to hand out ids to new handles.
    next_id: int = 0

    def __init__(self, hooks_dict: Any, *, extra_dict: Any = None) -> None:
        self.hooks_dict_ref = weakref.ref(hooks_dict)
        self.id = RemovableHandle.next_id
        RemovableHandle.next_id += 1

        self.extra_dict_ref: Tuple = ()
        if isinstance(extra_dict, dict):
            self.extra_dict_ref = (weakref.ref(extra_dict),)
        elif isinstance(extra_dict, list):
            self.extra_dict_ref = tuple(weakref.ref(d) for d in extra_dict)

    def remove(self) -> None:
        """Delete this handle's ``id`` from ``hooks_dict`` and every extra dict.

        Dictionaries that have already been garbage-collected, or that no
        longer contain the key, are skipped, so calling this repeatedly is safe.
        """
        hooks_dict = self.hooks_dict_ref()
        if hooks_dict is not None and self.id in hooks_dict:
            del hooks_dict[self.id]

        for ref in self.extra_dict_ref:
            extra_dict = ref()
            if extra_dict is not None and self.id in extra_dict:
                del extra_dict[self.id]

    def __getstate__(self):
        """Return picklable state ``(hooks_dict, id, extra_dicts)``.

        ``hooks_dict`` is ``None`` if it has already been garbage-collected.
        Extra dictionaries that have already been collected are omitted.
        """
        alive_extras = (ref() for ref in (self.extra_dict_ref or ()))
        extra_dicts = tuple(d for d in alive_extras if d is not None)
        return (self.hooks_dict_ref(), self.id, extra_dicts)

    def __setstate__(self, state) -> None:
        """Restore the handle from state produced by :meth:`__getstate__`.

        Accepts ``(hooks_dict, id)`` as well as ``(hooks_dict, id, extra_dicts)``.
        """
        if state[0] is None:
            # create a dead reference (in CPython the temporary OrderedDict is
            # freed as soon as this call returns)
            self.hooks_dict_ref = weakref.ref(OrderedDict())
        else:
            self.hooks_dict_ref = weakref.ref(state[0])
        self.id = state[1]
        # Make sure handles created later do not reuse the restored id.
        RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)

        if len(state) < 3 or state[2] is None:
            self.extra_dict_ref = ()
        else:
            # Skip ``None`` entries (extra dicts collected before pickling by
            # older versions): ``weakref.ref(None)`` would raise ``TypeError``.
            self.extra_dict_ref = tuple(
                weakref.ref(d) for d in state[2] if d is not None
            )

    def __enter__(self) -> "RemovableHandle":
        return self

    def __exit__(self, type: Any, value: Any, tb: Any) -> None:
        # Leaving a ``with`` block removes the hook.
        self.remove()