copy model hook implementation from pab
a-r-r-o-w committed Jan 16, 2025
1 parent deda9a3 commit 80ac5a7
Showing 1 changed file with 38 additions and 19 deletions.
57 changes: 38 additions & 19 deletions src/diffusers/hooks/hooks.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import functools
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 
@@ -33,7 +33,6 @@ class ModelHook:
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""
         Hook that is executed when a model is initialized.
-
         Args:
             module (`torch.nn.Module`):
                 The module attached to this hook.
@@ -43,7 +42,6 @@ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
     def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""
         Hook that is executed when a model is deinitalized.
-
         Args:
             module (`torch.nn.Module`):
                 The module attached to this hook.
@@ -55,15 +53,13 @@ def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
     def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
         r"""
         Hook that is executed just before the forward method of the model.
-
         Args:
             module (`torch.nn.Module`):
                 The module whose forward pass will be executed just after this event.
             args (`Tuple[Any]`):
                 The positional arguments passed to the module.
             kwargs (`Dict[Str, Any]`):
                 The keyword arguments passed to the module.
-
         Returns:
             `Tuple[Tuple[Any], Dict[Str, Any]]`:
                 A tuple with the treated `args` and `kwargs`.
@@ -73,13 +69,11 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
     def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
         r"""
         Hook that is executed just after the forward method of the model.
-
         Args:
             module (`torch.nn.Module`):
                 The module whose forward pass been executed just before this event.
             output (`Any`):
                 The output of the module.
-
         Returns:
             `Any`: The processed `output`.
         """
@@ -88,7 +82,6 @@ def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
     def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         r"""
         Hook that is executed when the hook is detached from a module.
-
         Args:
             module (`torch.nn.Module`):
                 The module detached from this hook.
@@ -123,31 +116,57 @@ def register_hook(self, hook: ModelHook, name: str) -> None:
         self._module_ref = hook.initialize_hook(self._module_ref)
 
         if hasattr(hook, "new_forward"):
-            new_forward = hook.new_forward
+            rewritten_forward = hook.new_forward
+
+            def new_forward(module, *args, **kwargs):
+                args, kwargs = hook.pre_forward(module, *args, **kwargs)
+                output = rewritten_forward(module, *args, **kwargs)
+                return hook.post_forward(module, output)
         else:
 
             def new_forward(module, *args, **kwargs):
                 args, kwargs = hook.pre_forward(module, *args, **kwargs)
                 output = old_forward(*args, **kwargs)
                 return hook.post_forward(module, output)
 
-        new_forward = functools.update_wrapper(new_forward, old_forward)
-        self._module_ref.forward = new_forward.__get__(self._module_ref)
+        self._module_ref.forward = functools.update_wrapper(
+            functools.partial(new_forward, self._module_ref), old_forward
+        )
 
         self.hooks[name] = hook
         self._hook_order.append(name)
 
-    def get_hook(self, name: str) -> ModelHook:
+    def get_hook(self, name: str) -> Optional[ModelHook]:
         if name not in self.hooks.keys():
-            raise ValueError(f"Hook with name {name} not found.")
+            return None
         return self.hooks[name]
 
-    def remove_hook(self, name: str) -> None:
-        if name not in self.hooks.keys():
-            raise ValueError(f"Hook with name {name} not found.")
-        self.hooks[name].deinitalize_hook(self._module_ref)
-        del self.hooks[name]
-        self._hook_order.remove(name)
+    def remove_hook(self, name: str, recurse: bool = True) -> None:
+        if name in self.hooks.keys():
+            hook = self.hooks[name]
+            self._module_ref = hook.deinitalize_hook(self._module_ref)
+            del self.hooks[name]
+            self._hook_order.remove(name)
+
+        if recurse:
+            for module_name, module in self._module_ref.named_modules():
+                if module_name == "":
+                    continue
+                if hasattr(module, "_diffusers_hook"):
+                    module._diffusers_hook.remove_hook(name, recurse=False)
+
+    def reset_stateful_hooks(self, recurse: bool = True) -> None:
+        for hook_name in self._hook_order:
+            hook = self.hooks[hook_name]
+            if hook._is_stateful:
+                hook.reset_state(self._module_ref)
+
+        if recurse:
+            for module_name, module in self._module_ref.named_modules():
+                if module_name == "":
+                    continue
+                if hasattr(module, "_diffusers_hook"):
+                    module._diffusers_hook.reset_stateful_hooks(recurse=False)
 
     @classmethod
     def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
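For orientation (not part of the commit itself), below is a minimal, self-contained sketch of the wrapping strategy the updated register_hook uses: the hook's pre_forward and post_forward run around the module's original forward, and the wrapper is bound to the module with functools.partial and given the original forward's metadata via functools.update_wrapper, replacing the previous __get__ binding. The LoggingHook class and attach helper are hypothetical illustrations, not diffusers APIs.

import functools

import torch


class LoggingHook:
    # Hypothetical stand-in for a ModelHook subclass: only pre_forward/post_forward.
    def pre_forward(self, module, *args, **kwargs):
        print(f"entering {module.__class__.__name__}")
        return args, kwargs

    def post_forward(self, module, output):
        print(f"leaving {module.__class__.__name__}")
        return output


def attach(module: torch.nn.Module, hook: LoggingHook) -> None:
    # Same pattern as the updated register_hook: wrap the original forward, bind the
    # module as the wrapper's first argument with functools.partial, and copy the
    # original forward's metadata onto the wrapper with functools.update_wrapper.
    old_forward = module.forward

    def new_forward(module, *args, **kwargs):
        args, kwargs = hook.pre_forward(module, *args, **kwargs)
        output = old_forward(*args, **kwargs)
        return hook.post_forward(module, output)

    module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)


layer = torch.nn.Linear(4, 4)
attach(layer, LoggingHook())
layer(torch.randn(1, 4))  # prints "entering Linear" then "leaving Linear"

Binding with functools.partial rather than __get__ keeps the wrapper an ordinary callable attribute on the module instance, which makes it straightforward to stack or remove hooks without touching the class-level forward.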
