[core] Layerwise Upcasting #10347

Merged 55 commits on Jan 22, 2025
Changes shown are from 33 of the 55 commits.

Commits
36b0c37
update
a-r-r-o-w Dec 22, 2024
42046c0
update
a-r-r-o-w Dec 22, 2024
7dc739b
make style
a-r-r-o-w Dec 22, 2024
7ed7141
Merge branch 'main' into layerwise-upcasting-hook
a-r-r-o-w Dec 23, 2024
1fa4ee5
remove dynamo disable
a-r-r-o-w Dec 23, 2024
da4907e
add coauthor
a-r-r-o-w Dec 23, 2024
bc2ada4
update
a-r-r-o-w Dec 23, 2024
91bfc3d
Merge branch 'main' into layerwise-upcasting-hook
a-r-r-o-w Jan 2, 2025
7c31bb0
update
a-r-r-o-w Jan 2, 2025
8975bbf
update
a-r-r-o-w Jan 2, 2025
341fbfc
update mixin
a-r-r-o-w Jan 2, 2025
5f898a1
add some basic tests
a-r-r-o-w Jan 2, 2025
558c64e
update
a-r-r-o-w Jan 4, 2025
7858f2c
update
a-r-r-o-w Jan 12, 2025
2663026
Merge branch 'main' into layerwise-upcasting-hook
a-r-r-o-w Jan 12, 2025
3d84b9e
non_blocking
a-r-r-o-w Jan 12, 2025
9372647
improvements
a-r-r-o-w Jan 12, 2025
a0f1de7
Merge branch 'main' into layerwise-upcasting-hook
a-r-r-o-w Jan 14, 2025
e586ef3
update
a-r-r-o-w Jan 13, 2025
cfe6318
norm.* -> norm
a-r-r-o-w Jan 14, 2025
9235f77
Merge branch 'main' into layerwise-upcasting-hook
a-r-r-o-w Jan 15, 2025
7627415
apply suggestions from review
a-r-r-o-w Jan 15, 2025
b9e1217
add example
a-r-r-o-w Jan 15, 2025
bde103c
update hook implementation to the latest changes from pyramid attenti…
a-r-r-o-w Jan 15, 2025
64e6c9c
deinitialize should raise an error
a-r-r-o-w Jan 15, 2025
7037133
update doc page
a-r-r-o-w Jan 15, 2025
f1b46d6
Merge branch 'main' into layerwise-upcasting-hook
a-r-r-o-w Jan 15, 2025
390742b
Apply suggestions from code review
a-r-r-o-w Jan 16, 2025
19901e7
update docs
a-r-r-o-w Jan 17, 2025
3ae32b4
update
a-r-r-o-w Jan 17, 2025
bf797e7
refactor
a-r-r-o-w Jan 17, 2025
d22465a
Merge branch 'main' into layerwise-upcasting-hook
a-r-r-o-w Jan 17, 2025
5956a9e
fix _always_upcast_modules for asym ae and vq_model
a-r-r-o-w Jan 17, 2025
93bd8ee
fix lumina embedding forward to not depend on weight dtype
a-r-r-o-w Jan 21, 2025
77a32a7
refactor tests
a-r-r-o-w Jan 21, 2025
1335d7e
add simple lora inference tests
a-r-r-o-w Jan 21, 2025
a263e1a
_always_upcast_modules -> _precision_sensitive_module_patterns
a-r-r-o-w Jan 21, 2025
93e36ba
Merge branch 'main' into layerwise-upcasting-hook
a-r-r-o-w Jan 21, 2025
245137f
remove todo comments about review; revert changes to self.dtype in un…
a-r-r-o-w Jan 21, 2025
b713511
check layer dtypes in lora test
a-r-r-o-w Jan 21, 2025
4450b1c
Merge branch 'main' into layerwise-upcasting-hook
a-r-r-o-w Jan 21, 2025
ed14d26
fix UNet1DModelTests::test_layerwise_upcasting_inference
a-r-r-o-w Jan 21, 2025
2c9c33f
_precision_sensitive_module_patterns -> _skip_layerwise_casting_patte…
a-r-r-o-w Jan 21, 2025
08211f7
skip test in NCSNppModelTests
a-r-r-o-w Jan 21, 2025
59e04c3
skip tests for AutoencoderTinyTests
a-r-r-o-w Jan 21, 2025
0a16826
skip tests for AutoencoderOobleckTests
a-r-r-o-w Jan 21, 2025
1d306b8
skip tests for UNet1DModelTests - unsupported pytorch operations
a-r-r-o-w Jan 21, 2025
a9364bd
layerwise_upcasting -> layerwise_casting
a-r-r-o-w Jan 21, 2025
c4d5a2b
skip tests for UNetRLModelTests; needs next pytorch release for curre…
a-r-r-o-w Jan 21, 2025
d175d93
add layerwise fp8 pipeline test
a-r-r-o-w Jan 21, 2025
bf11691
use xfail
a-r-r-o-w Jan 21, 2025
1c523b2
Apply suggestions from code review
a-r-r-o-w Jan 22, 2025
7803364
Merge branch 'main' into layerwise-upcasting-hook
a-r-r-o-w Jan 22, 2025
376adf9
add assertion with fp32 comparison; add tolerance to fp8-fp32 vs fp32…
a-r-r-o-w Jan 22, 2025
719e8d3
add note about memory consumption on tesla CI runner for failing test
a-r-r-o-w Jan 22, 2025
4 changes: 4 additions & 0 deletions docs/source/en/api/utilities.md
@@ -41,3 +41,7 @@ Utility and helper functions for working with 🤗 Diffusers.
## randn_tensor

[[autodoc]] utils.torch_utils.randn_tensor

## apply_layerwise_upcasting

[[autodoc]] hooks.layerwise_upcasting.apply_layerwise_upcasting
37 changes: 37 additions & 0 deletions docs/source/en/optimization/memory.md
@@ -158,6 +158,43 @@ In order to properly offload models after they're called, it is required to run

</Tip>

## FP8 layerwise weight-casting

PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but these dtypes cannot be used for computation in many tensor operations because kernel support is not yet implemented. However, you can still store model weights in fp8 precision and upcast them on-the-fly to a higher-precision compute dtype when each layer is used in the forward pass. This is known as layerwise weight-casting.

Typically, inference on most models is done with `torch.float16` or `torch.bfloat16` weight/computation precision. Layerwise weight-casting cuts down the memory footprint of the model weights by approximately half.

```python
import torch
from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel
from diffusers.utils import export_to_video

model_id = "THUDM/CogVideoX-5b"

# Load the model in bfloat16 and enable layerwise upcasting
transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
transformer.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)

# Load the pipeline
pipe = CogVideoXPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
```

In the above example, layerwise upcasting is enabled on the transformer component of the pipeline. By default, certain precision-sensitive layers, including the normalization- and modulation-related weight parameters, are skipped from FP8 weight casting because casting them can significantly degrade generation quality.

However, you gain more control and flexibility by using the [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`] function directly instead of [`~ModelMixin.enable_layerwise_upcasting`], as sketched below.
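
A rough sketch of what direct usage could look like follows; the `skip_modules_pattern` keyword and the pattern values shown are illustrative assumptions about the hook API rather than a complete reference:

```python
import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks import apply_layerwise_upcasting

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Store the weights in fp8 and upcast them to bf16 on-the-fly during the forward pass.
# The skip pattern below is illustrative: it keeps the patch-embedding and normalization
# layers in the compute dtype to avoid degrading generation quality.
apply_layerwise_upcasting(
    transformer,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    skip_modules_pattern=["patch_embed", "norm"],
)
```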

## Channels-last memory format

The channels-last memory format is an alternative way of ordering NCHW tensors in memory that preserves dimension ordering. Channels-last tensors are ordered in such a way that the channels become the densest dimension (storing images pixel-per-pixel). Since not all operators currently support the channels-last format, it may result in worse performance, but you should still try it and see if it works for your model.
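
For example, converting a model's weights is typically a single in-place call (a minimal sketch; `pipe.unet` and its `conv_out` layer are assumed from a previously loaded convolutional pipeline):

```python
import torch

# Reorder the UNet weights to channels-last memory format (in-place operation).
pipe.unet.to(memory_format=torch.channels_last)

# For a 4D convolution weight, a stride of 1 in the channel dimension confirms the conversion.
print(pipe.unet.conv_out.state_dict()["weight"].stride())
```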
5 changes: 5 additions & 0 deletions src/diffusers/hooks/__init__.py
@@ -0,0 +1,5 @@
from ..utils import is_torch_available


if is_torch_available():
    from .layerwise_upcasting import apply_layerwise_upcasting, apply_layerwise_upcasting_hook
188 changes: 188 additions & 0 deletions src/diffusers/hooks/hooks.py
@@ -0,0 +1,188 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Any, Dict, Optional, Tuple

import torch

from ..utils.logging import get_logger


logger = get_logger(__name__) # pylint: disable=invalid-name


class ModelHook:
    r"""
    A hook that contains callbacks to be executed just before and after the forward method of a model.
    """

    _is_stateful = False

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when a model is initialized.

        Args:
            module (`torch.nn.Module`):
                The module attached to this hook.
        """
        return module

    def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when a model is deinitialized.

        Args:
            module (`torch.nn.Module`):
                The module attached to this hook.
        """
        module.forward = module._old_forward
        del module._old_forward
        return module

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
        r"""
        Hook that is executed just before the forward method of the model.

        Args:
            module (`torch.nn.Module`):
                The module whose forward pass will be executed just after this event.
            args (`Tuple[Any]`):
                The positional arguments passed to the module.
            kwargs (`Dict[str, Any]`):
                The keyword arguments passed to the module.
        Returns:
            `Tuple[Tuple[Any], Dict[str, Any]]`:
                A tuple with the treated `args` and `kwargs`.
        """
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
        r"""
        Hook that is executed just after the forward method of the model.

        Args:
            module (`torch.nn.Module`):
                The module whose forward pass has been executed just before this event.
            output (`Any`):
                The output of the module.
        Returns:
            `Any`: The processed `output`.
        """
        return output

    def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Hook that is executed when the hook is detached from a module.

        Args:
            module (`torch.nn.Module`):
                The module detached from this hook.
        """
        return module

    def reset_state(self, module: torch.nn.Module):
        if self._is_stateful:
            raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
        return module


class HookRegistry:
    def __init__(self, module_ref: torch.nn.Module) -> None:
        super().__init__()

        self.hooks: Dict[str, ModelHook] = {}

        self._module_ref = module_ref
        self._hook_order = []

    def register_hook(self, hook: ModelHook, name: str) -> None:
        if name in self.hooks.keys():
            logger.warning(f"Hook with name {name} already exists, replacing it.")

        if hasattr(self._module_ref, "_old_forward"):
            old_forward = self._module_ref._old_forward
        else:
            old_forward = self._module_ref.forward
            self._module_ref._old_forward = self._module_ref.forward

        self._module_ref = hook.initialize_hook(self._module_ref)

        if hasattr(hook, "new_forward"):
            rewritten_forward = hook.new_forward

            def new_forward(module, *args, **kwargs):
                args, kwargs = hook.pre_forward(module, *args, **kwargs)
                output = rewritten_forward(module, *args, **kwargs)
                return hook.post_forward(module, output)
        else:

            def new_forward(module, *args, **kwargs):
                args, kwargs = hook.pre_forward(module, *args, **kwargs)
                output = old_forward(*args, **kwargs)
                return hook.post_forward(module, output)

        self._module_ref.forward = functools.update_wrapper(
            functools.partial(new_forward, self._module_ref), old_forward
        )

        self.hooks[name] = hook
        self._hook_order.append(name)

    def get_hook(self, name: str) -> Optional[ModelHook]:
        if name not in self.hooks.keys():
            return None
        return self.hooks[name]

    def remove_hook(self, name: str, recurse: bool = True) -> None:
        if name in self.hooks.keys():
            hook = self.hooks[name]
            self._module_ref = hook.deinitalize_hook(self._module_ref)
            del self.hooks[name]
            self._hook_order.remove(name)

        if recurse:
            for module_name, module in self._module_ref.named_modules():
                if module_name == "":
                    continue
                if hasattr(module, "_diffusers_hook"):
                    module._diffusers_hook.remove_hook(name, recurse=False)

    def reset_stateful_hooks(self, recurse: bool = True) -> None:
        for hook_name in self._hook_order:
            hook = self.hooks[hook_name]
            if hook._is_stateful:
                hook.reset_state(self._module_ref)

        if recurse:
            for module_name, module in self._module_ref.named_modules():
                if module_name == "":
                    continue
                if hasattr(module, "_diffusers_hook"):
                    module._diffusers_hook.reset_stateful_hooks(recurse=False)

    @classmethod
    def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
        if not hasattr(module, "_diffusers_hook"):
            module._diffusers_hook = cls(module)
        return module._diffusers_hook

    def __repr__(self) -> str:
        hook_repr = ""
        for i, hook_name in enumerate(self._hook_order):
            hook_repr += f" ({i}) {hook_name} - ({self.hooks[hook_name].__class__.__name__})"
            if i < len(self._hook_order) - 1:
                hook_repr += "\n"
        return f"HookRegistry(\n{hook_repr}\n)"