Commit 7037133: update doc page
a-r-r-o-w committed Jan 15, 2025 · 1 parent 64e6c9c
Showing 1 changed file with 33 additions and 0 deletions: docs/source/en/optimization/memory.md
@@ -158,6 +158,39 @@ In order to properly offload models after they're called, it is required to run

</Tip>

## FP8 layerwise weight-casting

PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes. Many tensor operations cannot compute directly in these precisions because kernel support has not been implemented, but the dtypes can still be used to store model weights in lower FP8 precision. For computation, the weights are upcast on the fly as each layer is invoked in the forward pass.
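
To illustrate the idea outside of any pipeline, here is a minimal sketch using plain PyTorch (the shapes and the matmul are only illustrative): a weight stored in `torch.float8_e4m3fn` takes one byte per element instead of two, and is upcast to `torch.bfloat16` right before it is used for computation.

```python
import torch

# Store the weight in FP8 (1 byte/element) instead of bfloat16 (2 bytes/element)
weight_bf16 = torch.randn(4096, 4096, dtype=torch.bfloat16)
weight_fp8 = weight_bf16.to(torch.float8_e4m3fn)
print(weight_bf16.element_size(), weight_fp8.element_size())  # 2 1

# FP8 matmul kernels are not generally available, so upcast on the fly for compute
x = torch.randn(1, 4096, dtype=torch.bfloat16)
y = x @ weight_fp8.to(torch.bfloat16).T
```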

Typically, inference on most models is done with `torch.float16` or `torch.bfloat16` weight and computation precision. Storing the weights in FP8 with layerwise weight-casting roughly halves the memory footprint of the model weights.

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
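# Store the transformer weights in FP8 and upcast them to bfloat16 for each layer's forward pass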
pipe.transformer.enable_layerwise_upcasting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)

prompt = (
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
"atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
```

In the above example, we enable layerwise upcasting on the transformer component of the pipeline. By default, certain layers are skipped from FP8 weight casting because casting them can significantly degrade generation quality. In most cases, skipping the normalization- and modulation-related weight parameters is a good choice (and is the default). However, more control and flexibility can be obtained by using the [`~hooks.layerwise_upcasting.apply_layerwise_upcasting`] function directly instead of [`~ModelMixin.enable_layerwise_upcasting`].
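
As a rough sketch of what that might look like (the import path and the `skip_modules_pattern` argument shown here are assumptions based on the description above; refer to the API reference below for the exact signature):

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks.layerwise_upcasting import apply_layerwise_upcasting

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Cast the transformer weights to FP8 for storage and upcast to bfloat16 for compute,
# skipping modules whose names match these patterns (argument name is an assumption)
apply_layerwise_upcasting(
    pipe.transformer,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    skip_modules_pattern=["patch_embed", "norm"],
)
```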

[[autodoc]] ModelMixin.enable_layerwise_upcasting

[[autodoc]] hooks.layerwise_upcasting.apply_layerwise_upcasting

## Channels-last memory format

The channels-last memory format is an alternative way of ordering NCHW tensors in memory that preserves dimension ordering. Channels-last tensors are ordered so that the channels become the densest dimension (storing images pixel-by-pixel). Since not all operators currently support the channels-last format, it may result in worse performance, but you should still try it to see if it works for your model.
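
As a minimal sketch of switching a model to channels-last (this assumes a pipeline that exposes a `unet` component; the checkpoint id is only illustrative), the format can be set with `torch.channels_last` and verified by inspecting a weight's strides:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
)

print(pipe.unet.conv_out.state_dict()["weight"].stride())  # NCHW-contiguous strides
pipe.unet.to(memory_format=torch.channels_last)  # in-place operation
print(pipe.unet.conv_out.state_dict()["weight"].stride())  # channel dimension now has stride 1
```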