Created May 1, 2022 18:23
Save ProGamerGov/e4060b55c702835ac933d95f063a2f6e to your computer and use it in GitHub Desktop.
Remove hooks in PyTorch without using the hook handle
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import OrderedDict | |
from typing import Callable, Dict, Optional | |
from warnings import warn | |
import torch | |
def _remove_all_forward_hooks( | |
module: torch.nn.Module, hook_fn_name: Optional[str] = None | |
) -> None: | |
""" | |
This function removes all forward hooks in the specified module, without requiring | |
any hook handles. This lets us clean up & remove any hooks that weren't property | |
deleted. | |
Warning: Various PyTorch modules and systems make use of hooks, and thus extreme | |
caution should be exercised when removing all hooks. Users are recommended to give | |
their hook function a unique name that can be used to safely identify and remove | |
the target forward hooks. | |
Args: | |
module (nn.Module): The module instance to remove forward hooks from. | |
hook_fn_name (str, optional): Optionally only remove specific forward hooks | |
based on their function's __name__ attribute. | |
Default: None | |
""" | |
if hook_fn_name is None: | |
warn("Removing all active hooks will break some PyTorch modules & systems.") | |
def _remove_hooks(m: torch.nn.Module, name: Optional[str] = None) -> None: | |
if hasattr(module, "_forward_hooks"): | |
if m._forward_hooks != OrderedDict(): | |
if name is not None: | |
dict_items = list(m._forward_hooks.items()) | |
m._forward_hooks = OrderedDict( | |
[(i, fn) for i, fn in dict_items if fn.__name__ != name] | |
) | |
else: | |
m._forward_hooks: Dict[int, Callable] = OrderedDict() | |
def _remove_child_hooks( | |
target_module: torch.nn.Module, hook_name: Optional[str] = None | |
) -> None: | |
for name, child in target_module._modules.items(): | |
if child is not None: | |
_remove_hooks(child, hook_name) | |
_remove_child_hooks(child, hook_name) | |
# Remove hooks from target submodules | |
_remove_child_hooks(module, hook_fn_name) | |
# Remove hooks from the target module | |
_remove_hooks(module, hook_fn_name) | |
from collections import OrderedDict | |
from typing import List, Optional | |
import torch | |
def _count_forward_hooks( | |
module: torch.nn.Module, hook_fn_name: Optional[str] = None | |
) -> int: | |
""" | |
Count the number of active forward hooks on the specified module instance. | |
Args: | |
module (nn.Module): The model module instance to count the number of | |
forward hooks on. | |
name (str, optional): Optionally only count specific forward hooks based on | |
their function's __name__ attribute. | |
Default: None | |
Returns: | |
num_hooks (int): The number of active hooks in the specified module. | |
""" | |
num_hooks: List[int] = [0] | |
def _count_hooks(m: torch.nn.Module, name: Optional[str] = None) -> None: | |
if hasattr(m, "_forward_hooks"): | |
if m._forward_hooks != OrderedDict(): | |
dict_items = list(m._forward_hooks.items()) | |
for i, fn in dict_items: | |
if hook_fn_name is None or fn.__name__ == name: | |
num_hooks[0] += 1 | |
def _count_child_hooks( | |
target_module: torch.nn.Module, | |
hook_name: Optional[str] = None, | |
) -> None: | |
for name, child in target_module._modules.items(): | |
if child is not None: | |
_count_hooks(child, hook_name) | |
_count_child_hooks(child, hook_name) | |
_count_child_hooks(module, hook_fn_name) | |
_count_hooks(module, hook_fn_name) | |
return num_hooks[0] |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
These functions work around issues with PyTorch's hook management system that I raised in pytorch/pytorch#70455.