[core] Layerwise Upcasting #10347
Conversation
Co-Authored-By: Dhruv Nair <[email protected]>
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Nice start 👍🏽 A few things to consider here:

It is difficult to maintain a large global list of supported ops, and doing so can lead to us either missing modules or not applying upcasting in cases where it could be used.

So any kind of layerwise casting on these modules runs into an error because the parameters remain in a lower memory dtype unless the entire module is upcast. The initial PR got around this by adding the

This implementation seems to do something similar using the global
Layerwise upcasting of text encoders is possible by leveraging apply_layerwise_upcasting:
Code:

import gc

import torch
from diffusers import CogVideoXPipeline, apply_layerwise_upcasting
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug
from transformers.models.t5.modeling_t5 import T5DenseGatedActDense

set_verbosity_debug()


def main(apply_layerwise_upcasting_text_encoder: bool = False):
    # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
    pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    pipe.transformer.enable_layerwise_upcasting(
        storage_dtype=torch.float8_e4m3fn,
        compute_dtype=torch.bfloat16,
        granularity="pytorch_layer",
        skip_modules_pattern=["patch_embed", "norm"],
    )

    if apply_layerwise_upcasting_text_encoder:
        for name, module in pipe.text_encoder.named_modules():
            if isinstance(module, T5DenseGatedActDense):
                module.forward = T5DenseGatedActDense_forward.__get__(module)

        pipe.text_encoder = apply_layerwise_upcasting(
            pipe.text_encoder,
            storage_dtype=torch.float8_e4m3fn,
            compute_dtype=torch.bfloat16,
            granularity="pytorch_layer",
            skip_modules_pattern=["norm"],
        )

    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    print(f"Model memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")

    prompt = (
        "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
        "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
        "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
        "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
        "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
        "atmosphere of this unique musical performance."
    )

    with torch.no_grad():
        prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
            prompt=prompt,
            negative_prompt="",
            do_classifier_free_guidance=True,
            num_videos_per_prompt=1,
            device="cuda",
            dtype=torch.bfloat16,
        )

    video = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        guidance_scale=6,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(42),
    ).frames[0]
    export_to_video(video, "output.mp4", fps=8)

    print(f"Inference memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")


# If we don't overwrite the forward method of T5DenseGatedActDense, the original forward would downcast hidden_states
# to torch.float8_e4m3fn, which would cause an error.
# Line in question: https://github.com/huggingface/transformers/blob/e5fd865ebae062b7cf03a81b8c6affeb39f30bec/src/transformers/models/t5/modeling_t5.py#L292
def T5DenseGatedActDense_forward(self, hidden_states):
    hidden_gelu = self.act(self.wi_0(hidden_states))
    hidden_linear = self.wi_1(hidden_states)
    hidden_states = hidden_gelu * hidden_linear
    hidden_states = self.dropout(hidden_states)
    hidden_states = self.wo(hidden_states)
    return hidden_states


main()

# Without text encoder fp8 layerwise upcasting:
# Model memory: 15.228 GB
# Inference memory: 30.010 GB

# With text encoder fp8 layerwise upcasting:
# Model memory: 10.915 GB
# Inference memory: 25.705 GB

When layerwise upcasting is not enabled in T5, the memory required is about 30 GB. When enabled, the memory usage is about 25.7 GB.
I'm not sure how to get around this easily. We could probably use a simple context manager in the hook implementations to disable any tensor casts in the internal model forwards, but it seems a bit too hacky to me :/ Open to suggestions.

class DisableTensorTo:
    def __enter__(self):
        self.original_to = torch.Tensor.to

        def noop_to(self, *args, **kwargs):
            return self

        torch.Tensor.to = noop_to

    def __exit__(self, exc_type, exc_value, traceback):
        torch.Tensor.to = self.original_to
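Purely as an illustration, one way such a context manager might be wired into a hook; wrapped_forward and original_forward are assumed names for this sketch, not the PR's actual API:

def wrapped_forward(module, *args, original_forward, **kwargs):
    # While inside the context, torch.Tensor.to is a no-op, so internal casts
    # (e.g. T5 downcasting hidden_states to the fp8 weight dtype) cannot fire.
    with DisableTensorTo():
        return original_forward(*args, **kwargs)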
For the reasons mentioned above, we are unable to run LoRA inference either. You can overwrite the forward methods, for one (to get it to work without too much thinking), but maybe the context manager solution for disabling tensor casts could work too. There are actually two problems that need to be dealt with for LoRA. Whether you load the LoRA weights before or after enabling layerwise upcasting, they will be loaded in the correct dtype, same as the transformer.
Code:

import gc
from typing import Any

import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug
from peft.tuners.lora.layer import Linear as LoRALinear

set_verbosity_debug()


def main():
    # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
    pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    pipe.load_lora_weights("Cseti/CogVideoX-LoRA-Wallace_and_Gromit", weight_name="walgro1-3000.safetensors", adapter_name="cogvideox-lora")
    pipe.set_adapters(["cogvideox-lora"], [1.0])

    pipe.transformer.enable_layerwise_upcasting(
        storage_dtype=torch.float8_e4m3fn,
        compute_dtype=torch.bfloat16,
        granularity="pytorch_layer",
        skip_modules_pattern=["patch_embed", "norm"],
    )

    # Loading the LoRA weights after enabling layerwise upcasting does load them in torch.float8_e4m3fn,
    # but since no hooks are attached to them, it errors during inference.
    # pipe.load_lora_weights("Cseti/CogVideoX-LoRA-Wallace_and_Gromit", weight_name="walgro1-3000.safetensors", adapter_name="cogvideox-lora")
    # pipe.set_adapters(["cogvideox-lora"], [1.0])

    for name, parameter in pipe.transformer.named_parameters():
        if "lora" in name:
            assert parameter.dtype == torch.float8_e4m3fn

    LoRALinear.forward = LoRALinear_forward

    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    print(f"Model memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")

    prompt = "walgro1. The scene begins with a close-up of Gromit's face, his expressive eyes filling the frame. His brow furrows slightly, ears perked forward in concentration. The soft lighting highlights the subtle details of his fur, every strand catching the warm sunlight filtering in from a nearby window. His dark, round nose twitches ever so slightly, sensing something in the air, and his gaze darts to the side, following an unseen movement. The camera lingers on Gromit’s face, capturing the subtleties of his expression—a quirked eyebrow and a knowing look that suggests he’s piecing together something clever. His silent, thoughtful demeanor speaks volumes as he watches the scene unfold with quiet intensity. The background remains out of focus, drawing all attention to the sharp intelligence in his eyes and the slight tilt of his head. In the claymation style of Wallace and Gromit."

    with torch.no_grad():
        prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
            prompt=prompt,
            negative_prompt="",
            do_classifier_free_guidance=True,
            num_videos_per_prompt=1,
            device="cuda",
            dtype=torch.bfloat16,
        )

    video = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        guidance_scale=6,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(42),
    ).frames[0]
    export_to_video(video, "output.mp4", fps=8)

    print(f"Inference memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")


# peft's LoRA Linear.forward with the x.to(lora_A.weight.dtype) cast commented out,
# so activations are not downcast to the fp8 storage dtype.
def LoRALinear_forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
    self._check_forward_args(x, *args, **kwargs)
    adapter_names = kwargs.pop("adapter_names", None)

    if self.disable_adapters:
        if self.merged:
            self.unmerge()
        result = self.base_layer(x, *args, **kwargs)
    elif adapter_names is not None:
        result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
    elif self.merged:
        result = self.base_layer(x, *args, **kwargs)
    else:
        result = self.base_layer(x, *args, **kwargs)
        torch_result_dtype = result.dtype
        for active_adapter in self.active_adapters:
            if active_adapter not in self.lora_A.keys():
                continue
            lora_A = self.lora_A[active_adapter]
            lora_B = self.lora_B[active_adapter]
            dropout = self.lora_dropout[active_adapter]
            scaling = self.scaling[active_adapter]
            # x = x.to(lora_A.weight.dtype)

            if not self.use_dora[active_adapter]:
                result = result + lora_B(lora_A(dropout(x))) * scaling
            else:
                if isinstance(dropout, torch.nn.Identity) or not self.training:
                    base_result = result
                else:
                    x = dropout(x)
                    base_result = None

                result = result + self.lora_magnitude_vector[active_adapter](
                    x,
                    lora_A=lora_A,
                    lora_B=lora_B,
                    scaling=scaling,
                    base_layer=self.get_base_layer(),
                    base_result=base_result,
                )

        result = result.to(torch_result_dtype)

    return result


main()

# Model memory: 15.351 GB
# Inference memory: 30.134 GB

cogvideox-layerwise-lora.mp4
Just documenting for now because I don't know what approach we're going to take to deal with these kinds of problems yet. If we enable layerwise upcasting for something like HunyuanVideo, the prompt_embeds and prompt_attention_mask are moved to fp8 dtype, and the latent input passed into the model is also fp8. This is because of lines like this:
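For illustration only, a hypothetical example of the kind of line being referred to (not quoted from the actual pipeline code): inputs get cast to the transformer's reported dtype, which is the fp8 storage dtype once the weights are downcast.

# Hypothetical pipeline-style cast; transformer.dtype reports torch.float8_e4m3fn after layerwise downcasting
prompt_embeds = prompt_embeds.to(transformer.dtype)
latents = latents.to(transformer.dtype)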
One solution would be to overwrite the

Edit: Ah, so the reason why it worked in the benchmark script is because the
Thanks for the ping, I left some comments.
@stevhliu Hi! Would love to hear your thoughts on how best to document this and where it would fit for good visibility about consumer-GPU-friendly optimizations. Will update the docs accordingly. Would like to make a mention of

@DN6 @sayakpaul @hlky Would you be able to give this a review when you find time? 🤗 This should be almost complete apart from docs, I think. If there are additional tests you'd like to see, LMK about it.
def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
    module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)
cc @DN6 Another thing I've had in mind is allowing self.compute_dtype to be torch.float8_e4m3fn as well. This is because we can leverage FP8 matmul for a speedup on Ada and above GPUs, similar to how it's done here: https://github.com/facebookresearch/lingua/blob/main/lingua/float8.py.

Not for this PR, but definitely something we should consider supporting in the future.
Super nice @a-r-r-o-w, you can add it to the Reduce memory usage doc! If you think it's going to be a pretty substantial doc or is important enough to deserve a dedicated page, then you can also make a standalone guide in the "Accelerate inference and reduce memory" section. The upside of this option is you'll have better visibility :)
@stevhliu It's nothing serious or too innovative to deserve its own doc page, so we can mention the memory savings, and eventually in a follow-up PR the time savings with fp8 matmul (if we decide to add support for it), on this page. It seems like a great place to do so, thanks!
Thanks for the docs! 👏
Some very minor things to take care of. Think we're very close to merge. And regarding this: #10347 (comment)

Yeah, I think we can do as you suggested: add a check in ModelMixin's dtype property for the upcasting hook and return the compute dtype.
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
In this example, since we move the entire pipeline to cuda in bfloat16, the memory would still spike. I think you would need to load the transformer with torch_dtype=torch.float8_e4m3fn, enable layerwise upcasting, and then pass it to the pipeline.
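A rough sketch of that suggestion, assuming from_pretrained accepts a float8 torch_dtype for the transformer (not verified against the final API):

import torch
from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel

# Load only the transformer directly in the fp8 storage dtype to avoid the bf16 memory spike.
transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.float8_e4m3fn
)
transformer.enable_layerwise_upcasting(
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    granularity="pytorch_layer",
)

# The rest of the pipeline stays in bf16, with the already-prepared transformer passed in.
pipe = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-5b", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")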
Thanks, missed it! I've updated the example.

I think we should also try to support directly loading in float8_* types in from_pretrained in a future PR, which respects _keep_modules_in_fp32 and _always_upcast_modules (I'm not too sure if this is a good variable name btw, so open to suggestions).
should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
    skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
)
This pattern wouldn't catch class names? e.g. _always_upcast_modules = ["MaskConditionDecoder"] in the asymmetric autoencoder?
Oh, looks like a mistake in some files. My intention with _always_upcast_modules is for them to be regex patterns for the layer names. I think we can cover a wider number of layers, in a clean manner, that are prone to producing worse results (such as the normalization ones). Places where it is mentioned incorrectly, such as "MaskConditionDecoder", should really just have been the name of the layer that is initialized to an instance of that module type.

LMK if you'd prefer that we make _always_upcast_modules a list of classes.
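To illustrate the intent with made-up module names, the patterns are matched with re.search against dotted layer names (the same matching used for skip_modules_pattern in the diff above), not against class names:

import re

patterns = ["patch_embed", "norm"]
# Hypothetical names as yielded by named_modules()
for name in ["patch_embed.proj", "transformer_blocks.0.norm1", "transformer_blocks.0.attn1.to_q"]:
    matched = any(re.search(pattern, name) for pattern in patterns)
    print(name, "-> matched" if matched else "-> not matched")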
# 1. Check if we have attached any dtype modifying hooks (eg. layerwise upcasting)
if isinstance(parameter, nn.Module):
    for name, submodule in parameter.named_modules():
        if not hasattr(submodule, "_diffusers_hook"):
            continue
        registry = submodule._diffusers_hook
        hook = registry.get_hook("layerwise_upcasting")
        if hook is not None:
            return hook.compute_dtype
@DN6 Added this here as discussed
t_emb = t_emb.to(dtype=self.dtype)  # old
# TODO(aryan): Need to have this reviewed
t_emb = t_emb.to(dtype=sample.dtype)  # new (this PR)
@DN6 There are a few places I've marked as TODO for review. I think it should be a safe change, but a second set of eyes will be really helpful. That said, the change is probably not required now that we changed the .dtype logic in ModelMixin. Will check to make sure.
[...continuation of #9177]
PyTorch has had support for float8_e4m3fn and float8_e5m2 as storage dtypes for a while now. This allows one to store model weights in a lower-precision dtype and upcast them on-the-fly when a layer is required for proceeding with computation.
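To make the mechanism concrete, here is a minimal, simplified sketch of the idea (not the PR's actual hook implementation): weights live in the low-precision storage dtype and are cast to the compute dtype only for the duration of each layer's forward.

import torch
import torch.nn as nn

def naive_layerwise_casting(model, storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16):
    def upcast_before_forward(module, args):
        # Upcast weights so the matmul runs in a supported compute dtype.
        module.to(compute_dtype)

    def downcast_after_forward(module, args, output):
        # Return weights to the low-precision storage dtype right after the forward.
        module.to(storage_dtype)

    for submodule in model.modules():
        # Only cast layers known to be safe; norms, embeddings, etc. would be skipped.
        if isinstance(submodule, nn.Linear):
            submodule.to(storage_dtype)
            submodule.register_forward_pre_hook(upcast_before_forward)
            submodule.register_forward_hook(downcast_after_forward)

model = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
naive_layerwise_casting(model)
print(model(torch.randn(1, 64, dtype=torch.bfloat16)).dtype)  # torch.bfloat16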
Flux visual results
CogVideoX visual results
cogvideox-1.0---storage_dtype-bfloat16---compute_dtype-bfloat16---granularity-none---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-diffusers_model---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e5m2---compute_dtype-bfloat16---granularity-diffusers_model---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-diffusers_layer---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e5m2---compute_dtype-bfloat16---granularity-diffusers_layer---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-pytorch_layer---compile-False.mp4
cogvideox-1.0---storage_dtype-float8_e5m2---compute_dtype-bfloat16---granularity-pytorch_layer---compile-False.mp4
Hunyuan Video visual results
hunyuan_video---storage_dtype-bfloat16---compute_dtype-bfloat16---granularity-none---compile-False.mp4
hunyuan_video---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-diffusers_layer---compile-False.mp4
hunyuan_video---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-pytorch_layer---compile-False.mp4
Mochi visual results
mochi---storage_dtype-bfloat16---compute_dtype-bfloat16---granularity-none---compile-False.mp4
mochi---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-diffusers_layer---compile-False.mp4
mochi---storage_dtype-float8_e4m3fn---compute_dtype-bfloat16---granularity-pytorch_layer---compile-False.mp4
Assumptions made so far:
- compute_dtype
- storage_dtype

Why is there no memory savings in the initial load memory?

We are first moving weights to VRAM and then performing the lower-dtype casting. We should maybe look into directly allowing loading of weights in a lower dtype.
Why a different approach from #9177?

While providing the API to use this via ModelMixin is okay, it puts a restriction that requires all implementations to derive from it to use it. As this method can be generally applied to any modeling component, at any level of granularity, implementing it independently of ModelMixin allows for its use in other modeling components like text encoders, which come from transformers, and any downstream research work or library can directly use it for their demos on Spaces without having to reinvent the wheel.

Not opposed to the idea of having enable_layerwise_upcasting in ModelMixin, but let's do it in a way that does not impose any restrictions on how it's possible to use it.

Also, the original PR typecasted all leaf nodes to the storage dtype, but this may not be ideal for things like normalization and modulation, so supporting parameters like skip_modules_pattern and skip_modules_classes helps ignore a few layers. We can default to sensible values, while maintaining another per-class attribute for layers that should not be upcast/downcast. This is also one of the places where it helps to follow a common naming convention across all our models.

Fixes #9949
cc @vladmandic @asomoza
TODOs:
- Explore non_blocking and cuda streams for overlapping weight casting with computation (no real impact on time unless weight casting is combined with device casting)
- Test tensor caching in lower precision for methods like [core] Pyramid Attention Broadcast #9562 and [core] FasterCache #10163

Nice reading material for the interested: