[core] Layerwise Upcasting #10347

Status: Open · wants to merge 33 commits into main
Conversation

a-r-r-o-w (Member) commented Dec 23, 2024

[...continuation of #9177]

PyTorch has supported float8_e4m3fn and float8_e5m2 as storage dtypes for a while now. This allows storing model weights in a lower-precision dtype and upcasting them on-the-fly when a layer is needed for computation.
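The underlying mechanism is simple: keep the weights in float8 and temporarily upcast them around each layer's forward call. A minimal sketch in plain PyTorch (not the diffusers API, just the idea) using forward hooks:

import torch
import torch.nn as nn

def upcast_before_forward(module, args):
    # Upcast weights to the compute dtype right before this layer runs
    module.to(dtype=torch.bfloat16)

def downcast_after_forward(module, args, output):
    # Return weights to the low-precision storage dtype once the layer is done
    module.to(dtype=torch.float8_e4m3fn)

layer = nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
layer.to(dtype=torch.float8_e4m3fn)  # weights are now stored in float8 (about half the memory)
layer.register_forward_pre_hook(upcast_before_forward)
layer.register_forward_hook(downcast_after_forward)

x = torch.randn(1, 4096, device="cuda", dtype=torch.bfloat16)
y = layer(x)  # the matmul runs in bfloat16; the weights are back in float8 afterwards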

Code
import argparse
import gc
import pathlib
import traceback

import git
import pandas as pd
import torch
from diffusers import AllegroPipeline, CogVideoXPipeline, LattePipeline, FluxPipeline, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel, MochiPipeline, LTXPipeline
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_info, set_verbosity_debug
from tabulate import tabulate


repo = git.Repo(path="/home/aryan/work/diffusers")
branch = repo.active_branch


def pretty_print_results(results, precision: int = 3):
    def format_value(value):
        if isinstance(value, float):
            return f"{value:.{precision}f}"
        return value

    filtered_table = {k: format_value(v) for k, v in results.items()}
    print(tabulate([filtered_table], headers="keys", tablefmt="pipe", stralign="center"))


def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)

    return elapsed_time, output


def prepare_flux(dtype: torch.dtype, compile: bool = False, **kwargs) -> None:
    model_id = "black-forest-labs/Flux.1-Dev"
    cache_dir = "/raid/.cache/huggingface"

    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "A cat holding a sign that says hello world",
        "height": 768,
        "width": 768,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_cogvideox_1_0(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "THUDM/CogVideoX-5b"
    cache_dir = None

    pipe = CogVideoXPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": (
            "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
            "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
            "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
            "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
            "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
            "atmosphere of this unique musical performance."
        ),
        "height": 480,
        "width": 720,
        "num_frames": 49,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_latte(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "maxin-cn/Latte-1"
    cache_dir = None

    pipe = LattePipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "a cat wearing sunglasses and working as a lifeguard at pool.",
        "height": 512,
        "width": 512,
        "video_length": 16,
        "num_inference_steps": 50,
    }

    return pipe, generation_kwargs


def prepare_ltx_video(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "a-r-r-o-w/LTX-Video-diffusers"
    cache_dir = None

    pipe = LTXPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)
    
    generation_kwargs = {
        "prompt": "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage",
        "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
        "width": 768,
        "height": 512,
        "num_frames": 161,
        "num_inference_steps": 50,
    }

    return pipe, generation_kwargs


def prepare_allegro(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "rhymes-ai/Allegro"
    cache_dir = None

    pipe = AllegroPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")
    pipe.vae.enable_tiling()

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this location might be a popular spot for docking fishing boats.",
        "height": 720,
        "width": 1280,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_hunyuan_video(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "hunyuanvideo-community/HunyuanVideo"
    cache_dir = None

    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
        model_id, subfolder="transformer", torch_dtype=torch.bfloat16
    )
    pipe = HunyuanVideoPipeline.from_pretrained(
        model_id, transformer=transformer, torch_dtype=torch.float16, cache_dir=cache_dir
    )
    pipe.to("cuda")

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "A cat wearing sunglasses and working as a lifeguard at pool.",
        "height": 320,
        "width": 512,
        "num_frames": 61,
        "num_inference_steps": 30,
    }

    return pipe, generation_kwargs


def prepare_mochi(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "genmo/mochi-1-preview"
    cache_dir = None

    pipe = MochiPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")
    pipe.vae.enable_tiling()

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
        "height": 480,
        "width": 848,
        "num_frames": 85,
        "num_inference_steps": 50,
    }

    return pipe, generation_kwargs


def decode_flux(pipe: FluxPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    height = kwargs["height"]
    width = kwargs["width"]
    filename = f"{filename.as_posix()}.png"
    latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(latents, return_dict=False)[0]
    image = pipe.image_processor.postprocess(image, output_type="pil")[0]
    image.save(filename)
    return filename


def decode_cogvideox_1_0(pipe: CogVideoXPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_latte(pipe: LattePipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents, video_length=kwargs["video_length"])
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_allegro(pipe: AllegroPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_hunyuan_video(pipe: HunyuanVideoPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_mochi(pipe: MochiPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
    latents = latents * latents_std / pipe.vae.config.scaling_factor + latents_mean
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_ltx_video(pipe: LTXPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latent_num_frames = (kwargs["num_frames"] - 1) // pipe.vae_temporal_compression_ratio + 1
    latent_height = kwargs["height"] // pipe.vae_spatial_compression_ratio
    latent_width = kwargs["width"] // pipe.vae_spatial_compression_ratio

    latents = pipe._unpack_latents(
        latents,
        latent_num_frames,
        latent_height,
        latent_width,
        pipe.transformer_spatial_patch_size,
        pipe.transformer_temporal_patch_size,
    )
    latents = pipe._denormalize_latents(
        latents, pipe.vae.latents_mean, pipe.vae.latents_std, pipe.vae.config.scaling_factor
    )
    latents = latents.to(pipe.vae.dtype)

    timestep = None
    video = pipe.vae.decode(latents, timestep, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video, output_type="pil")[0]
    export_to_video(video, filename, fps=24)
    return filename


def clean_memory():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()


MODEL_MAPPING = {
    "flux": {
        "prepare": prepare_flux,
        "decode": decode_flux,
    },
    "cogvideox-1.0": {
        "prepare": prepare_cogvideox_1_0,
        "decode": decode_cogvideox_1_0,
    },
    "latte": {
        "prepare": prepare_latte,
        "decode": decode_latte,
    },
    "allegro": {
        "prepare": prepare_allegro,
        "decode": decode_allegro,
    },
    "hunyuan_video": {
        "prepare": prepare_hunyuan_video,
        "decode": decode_hunyuan_video,
    },
    "mochi": {
        "prepare": prepare_mochi,
        "decode": decode_mochi,
    },
    "ltx_video": {
        "prepare": prepare_ltx_video,
        "decode": decode_ltx_video,
    },
}

STR_TO_DTYPE = {
    "float8_e4m3fn": torch.float8_e4m3fn,
    "float8_e5m2": torch.float8_e5m2,
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
    "float32": torch.float32,
}


def run_inference(pipe, generation_kwargs):
    generator = torch.Generator("cuda").manual_seed(181201)
    output = pipe(generator=generator, output_type="latent", **generation_kwargs)[0]
    torch.cuda.synchronize()
    return output


@torch.no_grad()
def main(
    model_id: str, apply_layerwise_upcasting: bool, output_dir: str, storage_dtype: str, compute_dtype: str, compile: bool = False
):
    if model_id not in MODEL_MAPPING.keys():
        raise ValueError("Unsupported `model_id` specified.")

    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    csv_filename = output_dir / f"{model_id}.csv"

    pytorch_storage_dtype = STR_TO_DTYPE[storage_dtype]
    pytorch_compute_dtype = STR_TO_DTYPE[compute_dtype]
    model = MODEL_MAPPING[model_id]

    try:
        clean_memory()

        # 1. Prepare inputs and generation kwargs
        pipe, generation_kwargs = model["prepare"](dtype=pytorch_compute_dtype, compile=compile)

        model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)

        # 2. Apply layerwise upcasting technique
        if apply_layerwise_upcasting:
            pipe.transformer.enable_layerwise_upcasting(
                storage_dtype=pytorch_storage_dtype,
                compute_dtype=pytorch_compute_dtype,
                skip_modules_pattern=["pos_embed", "patch_embed", "norm"],
            )

        downcast_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)

        # 3. Warmup
        num_warmups = 1
        original_num_inference_steps = generation_kwargs["num_inference_steps"]
        generation_kwargs["num_inference_steps"] = 2
        for _ in range(num_warmups):
            run_inference(pipe, generation_kwargs)
        generation_kwargs["num_inference_steps"] = original_num_inference_steps

        # 4. Benchmark
        clean_memory()
        time, latents = benchmark_fn(run_inference, pipe, generation_kwargs)
        inference_memory = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

        # 5. Decode latents
        filename = (
            output_dir
            / f"{model_id}---storage_dtype-{storage_dtype}---compute_dtype-{compute_dtype}---compile-{compile}"
        )
        filename = model["decode"](
            pipe,
            latents,
            filename,
            height=generation_kwargs["height"],
            width=generation_kwargs["width"],
            num_frames=generation_kwargs.get("num_frames", None),
            video_length=generation_kwargs.get("video_length", None),
        )

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "upcasting": apply_layerwise_upcasting,
            "time": time,
            "initial_memory": model_memory,
            "model_memory": downcast_memory,
            "inference_memory": inference_memory,
            "storage_dtype": storage_dtype,
            "compute_dtype": compute_dtype,
            "compile": compile,
            "branch": branch,
            "filename": filename,
            "exception": None,
        }

    except Exception as e:
        print(f"An error occurred: {e}")
        traceback.print_exc()

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "upcasting": apply_layerwise_upcasting,
            "time": None,
            "initial_memory": None,
            "model_memory": None,
            "inference_memory": None,
            "storage_dtype": storage_dtype,
            "compute_dtype": compute_dtype,
            "compile": compile,
            "branch": branch,
            "filename": None,
            "exception": str(e),
        }

    pretty_print_results(info, precision=3)

    df = pd.DataFrame([info])
    df.to_csv(csv_filename.as_posix(), mode="a", index=False, header=not csv_filename.is_file())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        type=str,
        default="flux",
        choices=["flux", "cogvideox-1.0", "latte", "allegro", "hunyuan_video", "mochi", "ltx_video"],
        help="Model to run benchmark for.",
    )
    parser.add_argument(
        "--apply_layerwise_upcasting",
        action="store_true",
        help="Whether to apply layerwise upcasting to the transformer.",
    )
    parser.add_argument(
        "--output_dir", type=str, help="Path where the benchmark artifacts and outputs are the be saved."
    )
    parser.add_argument(
        "--storage_dtype",
        type=str,
        choices=["float8_e4m3fn", "float8_e5m2", "bfloat16", "float16", "float32"],
        help="Storage torch.dtype to use for transformer",
    )
    parser.add_argument(
        "--compute_dtype",
        type=str,
        choices=["bfloat16", "float16", "float32"],
        help="Compute torch.dtype to use for transformer",
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        default=False,
        help="Whether to torch.compile the denoiser.",
    )
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging.")
    args = parser.parse_args()

    if args.verbose:
        set_verbosity_debug()
    else:
        set_verbosity_info()

    main(args.model_id, args.apply_layerwise_upcasting, args.output_dir, args.storage_dtype, args.compute_dtype, args.compile)
Benchmark results (time in seconds, memory in GB):

| model_id | granularity | storage_dtype | time (s) | initial_memory (GB) | model_max_memory (GB) | inference_max_memory (GB) |
|---|---|---|---|---|---|---|
| flux | none | bfloat16 | 16.939 | 31.438 | 31.447 | 32.02 |
| flux | diffusers_model | float8_e4m3fn | 23.866 | 31.438 | 21.178 | 33.963 |
| flux | diffusers_layer | float8_e4m3fn | 18.125 | 31.438 | 28.779 | 29.291 |
| flux | pytorch_layer | float8_e4m3fn | 20.339 | 31.438 | 24.449 | 24.945 |
| flux | diffusers_model | float8_e5m2 | 22.097 | 31.438 | 21.18 | 33.949 |
| flux | diffusers_layer | float8_e5m2 | 18.013 | 31.438 | 28.797 | 29.309 |
| flux | pytorch_layer | float8_e5m2 | 20.084 | 31.44 | 24.451 | 24.947 |
| cogvideox-1.0 | none | bfloat16 | 244.255 | 19.661 | 19.678 | 24.426 |
| cogvideox-1.0 | diffusers_model | float8_e4m3fn | 243.65 | 19.661 | 14.531 | 25.217 |
| cogvideox-1.0 | diffusers_layer | float8_e4m3fn | 243.541 | 19.66 | 16.76 | 21.469 |
| cogvideox-1.0 | pytorch_layer | float8_e4m3fn | 243.346 | 19.661 | 15.281 | 19.992 |
| cogvideox-1.0 | diffusers_model | float8_e5m2 | 243.899 | 19.661 | 14.531 | 25.217 |
| cogvideox-1.0 | diffusers_layer | float8_e5m2 | 243.182 | 19.661 | 16.76 | 21.469 |
| cogvideox-1.0 | pytorch_layer | float8_e5m2 | 243.136 | 19.661 | 15.281 | 19.992 |
| hunyuan_video | none | bfloat16 | 71.748 | 38.584 | 38.613 | 41.141 |
| hunyuan_video | diffusers_layer | float8_e4m3fn | 71.933 | 38.574 | 35.904 | 38.314 |
| hunyuan_video | pytorch_layer | float8_e4m3fn | 72.869 | 38.573 | 31.33 | 33.719 |
| latte | none | bfloat16 | 27.986 | 11.005 | 11.314 | 12.471 |
| latte | diffusers_layer | float8_e4m3fn | 27.921 | 11.005 | 10.75 | 11.889 |
| latte | pytorch_layer | float8_e4m3fn | 28.079 | 11.005 | 10.879 | 12.018 |
| mochi | none | bfloat16 | 431.799 | 28.411 | 28.648 | 36.059 |
| mochi | diffusers_layer | float8_e4m3fn | 432.142 | 28.411 | 24.424 | 31.934 |
| mochi | pytorch_layer | float8_e4m3fn | 431.947 | 28.411 | 21.988 | 29.441 |
Flux visual results (images omitted): baseline, plus diffusers_model, diffusers_layer and pytorch_layer variants for float8_e4m3 and float8_e5m2.

CogVideoX visual results (videos omitted): baseline (bfloat16), plus diffusers_model, diffusers_layer and pytorch_layer variants for float8_e4m3fn and float8_e5m2.

Hunyuan Video visual results (videos omitted): baseline (bfloat16), diffusers_layer-float8_e4m3fn, and pytorch_layer-float8_e4m3fn.

Mochi visual results (videos omitted): baseline (bfloat16), diffusers_layer-float8_e4m3fn, and pytorch_layer-float8_e4m3fn.

Assumptions made so far:

  • Inputs to modules with a hook are not cast; they are expected to already be in compute_dtype.
  • Casting the learned parameters of normalization layers to a lower-precision dtype can lead to poor quality, as we've seen in the past few integrations. By default, normalization and modulation layers are not downcast to storage_dtype.
  • Sensible default name patterns are used to skip embedding, normalization and modulation layers. This is still configurable, so users can choose to downcast them if they want (see the matching sketch below).
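For illustration, the skip patterns are regular expressions matched against each module's qualified name. A minimal sketch of the matching logic (the helper name here is purely illustrative):

import re

skip_modules_pattern = ["pos_embed", "patch_embed", "norm"]

def should_skip(qualified_name: str) -> bool:
    # A module stays in compute_dtype if any pattern matches its name
    return any(re.search(pattern, qualified_name) for pattern in skip_modules_pattern)

assert should_skip("transformer_blocks.0.norm1")           # normalization: not downcast
assert not should_skip("transformer_blocks.0.attn.to_q")   # regular linear: downcast to storage_dtype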

Why are there no memory savings in the initial load memory?

We first move the weights to VRAM and only then cast them to the lower dtype, so the load-time peak is unchanged. We should probably look into directly loading weights in the lower dtype.
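For context, a sketch of the current order of operations (model and dtypes as in the benchmark script above); the comments describe what direct low-dtype loading would avoid:

import torch
from diffusers import FluxPipeline

# Current order: the full bfloat16 model lands in VRAM first and only then gets downcast,
# so the initial load peak is the same as without layerwise upcasting.
pipe = FluxPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")
pipe.transformer.enable_layerwise_upcasting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)

# Casting before the move to the GPU, or loading the weights directly in float8 (future work),
# would keep the bfloat16 copy of the transformer out of VRAM entirely.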


Why a different approach from #9177?

While providing the API via ModelMixin is okay, it restricts usage to implementations that derive from it. Since this method can be applied generally to any modeling component, at any level of granularity, implementing it independently of ModelMixin allows it to be used with other modeling components, such as text encoders from transformers, and lets any downstream research work or library use it directly for demos on Spaces without reinventing the wheel.

Not opposed to the idea of having enable_layerwise_upcasting in ModelMixin, but let's do it in a way that does not impose any restrictions on how it's possible to use it.

Also, the original PR typecast all leaf nodes to the storage dtype, but that may not be ideal for things like normalization and modulation, so parameters like skip_modules_pattern and skip_modules_classes help ignore those layers. We can default to sensible values rather than maintaining yet another per-class attribute listing layers that should not be downcast. This is also one of the places where following a common naming convention across all our models pays off.


Fixes #9949

cc @vladmandic @asomoza

TODOs:

  • Explore non_blocking and CUDA streams for overlapping weight casting with computation (no real impact on time unless weight casting is combined with device casting; a rough sketch follows after this list)
  • Try to make torch.compile work (edit: works if we increase the cache_size_limit, but it still recompiles multiple times)
  • Test with LoRAs
  • Test with training
  • Test tensor caching in lower precision for methods like [core] Pyramid Attention Broadcast #9562 and [core] FasterCache #10163
  • Tests
  • Docs
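A rough sketch of the overlap idea from the first TODO above. The cast only hides latency when it also moves data across devices, which is where non_blocking and a side stream come in. This is purely illustrative and not the PR's implementation:

import torch
import torch.nn as nn

copy_stream = torch.cuda.Stream()

def prefetch_next_layer(next_module: nn.Module):
    # Start the cast + host-to-device copy on a side stream while the current layer computes.
    # For the copy to actually be asynchronous, the CPU weights must live in pinned memory.
    with torch.cuda.stream(copy_stream):
        next_module.to(device="cuda", dtype=torch.bfloat16, non_blocking=True)

def wait_for_prefetch():
    # Block the compute stream until the prefetched weights are ready
    torch.cuda.current_stream().wait_stream(copy_stream)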

Nice reading material for the interested:

a-r-r-o-w requested review from DN6, sayakpaul and hlky on December 23, 2024.

DN6 (Collaborator) commented Dec 27, 2024

Nice start 👍🏽 A few things to consider here:

  1. Rather than relying on standard names for keys to ignore for upcasting, it would be better to have them be class attributes in the models themselves, e.g. something like this:

    _always_upcast_modules = ["Decoder"]

  It is difficult to maintain a large global list of supported ops, and it can lead to us either missing modules or not applying upcasting in cases where it could be used.

  2. Upcasting should also account for _keep_in_fp32_modules the way we do with quantization.

  3. There are model components that have casting operations internally, such as:

    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

  Any kind of layerwise casting on these modules runs into an error, because the parameters remain in the lower-memory dtype unless the entire module is upcast. The initial PR got around this by adding the _always_upcast_modules attribute, which applies the hook to the top-level module instead of the individual layers:

    if hasattr(self, "_always_upcast_modules") and module.__class__.__name__ in self._always_upcast_modules:

  This implementation seems to do something similar using the global _SUPPORTED_DIFFUSERS_LAYERS list, but that should also be a class attribute IMO.

  4. It's fine to add the hooks via the functions defined here, but enabling and disabling upcasting should be done through ModelMixin IMO. If users want to apply them to other models, we can include a section in the docs about importing the relevant functions from the hooks module.

src/diffusers/models/hooks.py (resolved review comment)
a-r-r-o-w added the roadmap label (Add to current release roadmap) on Dec 30, 2024.
a-r-r-o-w (Member, Author) commented Jan 4, 2025

Layerwise upcasting of text encoders is possible by leveraging apply_layerwise_upcasting. Tested with many different existing models, as well as ComfyUI nodes, and it seems to work really well unless:

  • There is weight-dtype-based casting of hidden states within the original forward (this seems to be uncommon). Since we overwrite the forward method ourselves, any casting in the original forward is very problematic to deal with. This is the case with T5Encoder from transformers, so I'm not quite sure how to deal with it without workarounds like the one in the example below.
  • There is model-weight-based casting of input tensors. This is used a lot in PEFT and requires workarounds too.
Code
import gc
import torch
from diffusers import CogVideoXPipeline, apply_layerwise_upcasting
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug
from transformers.models.t5.modeling_t5 import T5DenseGatedActDense

set_verbosity_debug()


def main(apply_layerwise_upcasting_text_encoder: bool = False):
    # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
    pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    pipe.transformer.enable_layerwise_upcasting(
        storage_dtype=torch.float8_e4m3fn,
        compute_dtype=torch.bfloat16,
        granularity="pytorch_layer",
        skip_modules_pattern=["patch_embed", "norm"]
    )

    if apply_layerwise_upcasting_text_encoder:
        for name, module in pipe.text_encoder.named_modules():
            if isinstance(module, T5DenseGatedActDense):
                module.forward = T5DenseGatedActDense_forward.__get__(module)

        pipe.text_encoder = apply_layerwise_upcasting(
            pipe.text_encoder,
            storage_dtype=torch.float8_e4m3fn,
            compute_dtype=torch.bfloat16,
            granularity="pytorch_layer",
            skip_modules_pattern=["norm"]
        )

    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    print(f"Model memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")

    prompt = (
        "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
        "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
        "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
        "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
        "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
        "atmosphere of this unique musical performance."
    )

    with torch.no_grad():
        prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
            prompt=prompt,
            negative_prompt="",
            do_classifier_free_guidance=True,
            num_videos_per_prompt=1,
            device="cuda",
            dtype=torch.bfloat16,
        )

    video = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        guidance_scale=6,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(42)
    ).frames[0]
    export_to_video(video, "output.mp4", fps=8)

    print(f"Inference memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")


# If we don't overwrite the forward method of T5DenseGatedActDense, the original forward would downcast hidden_states
# to torch.float8_e4m3fn, which would cause an error.
# Line in question: https://github.com/huggingface/transformers/blob/e5fd865ebae062b7cf03a81b8c6affeb39f30bec/src/transformers/models/t5/modeling_t5.py#L292
def T5DenseGatedActDense_forward(self, hidden_states):
    hidden_gelu = self.act(self.wi_0(hidden_states))
    hidden_linear = self.wi_1(hidden_states)
    hidden_states = hidden_gelu * hidden_linear
    hidden_states = self.dropout(hidden_states)
    hidden_states = self.wo(hidden_states)
    return hidden_states


main()

# Without text encoder fp8 layerwise upcasting:
# Model memory: 15.228 GB
# Inference memory: 30.010 GB

# With text encoder fp8 layerwise upcasting:
# Model memory: 10.915 GB
# Inference memory: 25.705 GB

When layerwise upcasting is not enabled in T5, the memory required is about 30 GB. When enabled, the memory usage is about 25.7 GB.

Comparison videos: without text-encoder fp8 (cogvideox-layerwise-without-text-encoder.mp4) vs. with text-encoder fp8 (cogvideox-layerwise-with-text-encoder.mp4).

I'm not sure how to get around this easily. We could probably use a simple context manager in the hook implementations to disable any tensor casts in the internal model forwards, but that seems a bit too hacky to me :/ Open to suggestions:

class DisableTensorTo:
    """Temporarily replace torch.Tensor.to with a no-op so internal forwards cannot recast tensors."""

    def __enter__(self):
        self.original_to = torch.Tensor.to

        def noop_to(self, *args, **kwargs):
            # Ignore any requested dtype/device change and return the tensor unchanged
            return self

        torch.Tensor.to = noop_to

    def __exit__(self, exc_type, exc_value, traceback):
        torch.Tensor.to = self.original_to

a-r-r-o-w (Member, Author):
For the reasons mentioned above, we are unable to run LoRA inference either. peft.tuners.lora.layer::Linear casts the model input based on the weight dtype, which is problematic too. Our hooks assume the input is already in compute_dtype, but peft will cast it to the lower dtype here, which would be a lossy cast (if we were to align input dtypes ourselves in the hook forward).

You can overwrite the forward methods to get it to work without too much thinking, but maybe the context-manager solution for disabling torch.Tensor::to works here as well :/

There are actually two problems to deal with for LoRA. Whether you load the LoRA weights before or after enabling layerwise upcasting, they are loaded in the same dtype as the transformer (torch.float8_e4m3fn, for example). This is good because we don't have to add any additional code to handle it. But:

  • If we load the LoRA before enabling layerwise upcasting, all LoRA linears get the upcasting hook attached, which is what we want. All is well in this case, except that input tensors are cast to float8 in peft.
  • If we load the LoRA after enabling layerwise upcasting, the LoRA linears don't get a hook attached (with the current implementation), which leaves their weights in float8 permanently. Open to suggestions on how to go about this: do I add some logic in load_lora_weights to check whether the model already has upcasting hooks and, if so, attach them to the LoRA layers as well? (A hypothetical sketch follows below, before the reproduction script.)
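A hypothetical sketch of that check, not part of this PR. It reuses the _diffusers_hook registry and apply_layerwise_upcasting from this PR, but the exact attribute and argument names are assumptions:

import torch
from diffusers import apply_layerwise_upcasting

def maybe_extend_upcasting_to_lora(transformer):
    # Look for an existing layerwise upcasting hook on the base model
    hook = None
    for module in transformer.modules():
        registry = getattr(module, "_diffusers_hook", None)
        if registry is not None:
            hook = registry.get_hook("layerwise_upcasting")
            if hook is not None:
                break
    if hook is None:
        return  # upcasting not enabled; nothing to do

    # Attach the same storage/compute dtypes to the freshly loaded LoRA linears
    for name, module in transformer.named_modules():
        if "lora_A" in name or "lora_B" in name:
            apply_layerwise_upcasting(
                module,
                storage_dtype=hook.storage_dtype,
                compute_dtype=hook.compute_dtype,
            )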
Code
import gc
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug
from peft.tuners.lora.layer import Linear as LoRALinear

set_verbosity_debug()


def main():
    # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
    pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    pipe.load_lora_weights("Cseti/CogVideoX-LoRA-Wallace_and_Gromit", weight_name="walgro1-3000.safetensors", adapter_name="cogvideox-lora")
    pipe.set_adapters(["cogvideox-lora"], [1.0])
    
    pipe.transformer.enable_layerwise_upcasting(
        storage_dtype=torch.float8_e4m3fn,
        compute_dtype=torch.bfloat16,
        granularity="pytorch_layer",
        skip_modules_pattern=["patch_embed", "norm"]
    )

    # Post layerwise upcasting does load lora weights in torch.float8_e4m3fn but due to no hooks, it errors during inference
    # pipe.load_lora_weights("Cseti/CogVideoX-LoRA-Wallace_and_Gromit", weight_name="walgro1-3000.safetensors", adapter_name="cogvideox-lora")
    # pipe.set_adapters(["cogvideox-lora"], [1.0])
    
    for name, parameter in pipe.transformer.named_parameters():
        if "lora" in name:
            assert(parameter.dtype == torch.float8_e4m3fn)
    
    LoRALinear.forward = LoRALinear_forward

    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    print(f"Model memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")

    prompt = "walgro1. The scene begins with a close-up of Gromit's face, his expressive eyes filling the frame. His brow furrows slightly, ears perked forward in concentration. The soft lighting highlights the subtle details of his fur, every strand catching the warm sunlight filtering in from a nearby window. His dark, round nose twitches ever so slightly, sensing something in the air, and his gaze darts to the side, following an unseen movement. The camera lingers on Gromit’s face, capturing the subtleties of his expression—a quirked eyebrow and a knowing look that suggests he’s piecing together something clever. His silent, thoughtful demeanor speaks volumes as he watches the scene unfold with quiet intensity. The background remains out of focus, drawing all attention to the sharp intelligence in his eyes and the slight tilt of his head. In the claymation style of Wallace and Gromit."

    with torch.no_grad():
        prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
            prompt=prompt,
            negative_prompt="",
            do_classifier_free_guidance=True,
            num_videos_per_prompt=1,
            device="cuda",
            dtype=torch.bfloat16,
        )

    video = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        guidance_scale=6,
        num_inference_steps=50,
        generator=torch.Generator().manual_seed(42)
    ).frames[0]
    export_to_video(video, "output.mp4", fps=8)

    print(f"Inference memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")


from typing import Any

def LoRALinear_forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
    self._check_forward_args(x, *args, **kwargs)
    adapter_names = kwargs.pop("adapter_names", None)

    if self.disable_adapters:
        if self.merged:
            self.unmerge()
        result = self.base_layer(x, *args, **kwargs)
    elif adapter_names is not None:
        result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
    elif self.merged:
        result = self.base_layer(x, *args, **kwargs)
    else:
        result = self.base_layer(x, *args, **kwargs)
        torch_result_dtype = result.dtype
        for active_adapter in self.active_adapters:
            if active_adapter not in self.lora_A.keys():
                continue
            lora_A = self.lora_A[active_adapter]
            lora_B = self.lora_B[active_adapter]
            dropout = self.lora_dropout[active_adapter]
            scaling = self.scaling[active_adapter]
            # x = x.to(lora_A.weight.dtype)

            if not self.use_dora[active_adapter]:
                result = result + lora_B(lora_A(dropout(x))) * scaling
            else:
                if isinstance(dropout, torch.nn.Identity) or not self.training:
                    base_result = result
                else:
                    x = dropout(x)
                    base_result = None

                result = result + self.lora_magnitude_vector[active_adapter](
                    x,
                    lora_A=lora_A,
                    lora_B=lora_B,
                    scaling=scaling,
                    base_layer=self.get_base_layer(),
                    base_result=base_result,
                )

        result = result.to(torch_result_dtype)

    return result


main()

# Model memory: 15.351 GB
# Inference memory: 30.134 GB
Output video: cogvideox-layerwise-lora.mp4

a-r-r-o-w (Member, Author) commented Jan 6, 2025

Just documenting for now, because I don't know yet what approach we're going to take to deal with these kinds of problems.

If we enable layerwise upcasting for something like HunyuanVideo, the prompt_embeds and prompt_attention_mask are moved to the fp8 dtype, and the latent input passed into the model is also fp8. This is because of lines like this:

One solution would be to overwrite the dtype method of ModelMixin to detect whether a LayerwiseUpcastingHook is attached to any submodule and, if so, simply return its compute_dtype. Open to suggestions, and I'll be on the lookout for more such cases. Surprisingly, this does not error out on A100s (Ampere), but it does on H100 (Hopper). This was discovered during an fp8 LoRA training run.

Edit: Ah, the reason it worked in the benchmark script is that the x_embedder layer is skipped from layerwise upcasting. Our dtype() method therefore simply reads bfloat16 as the first parameter dtype and works with it. If we enable upcasting for x_embedder as well, it errors on A100 too.

BenjaminBossan (Member) left a comment:

Thanks for the ping, I left some comments.

src/diffusers/hooks/layerwise_upcasting.py (resolved review comments)
src/diffusers/models/modeling_utils.py (resolved review comments)
a-r-r-o-w (Member, Author):
@stevhliu Hi! Would love to hear your thoughts on how best to document this and where it would fit for good visibility among the consumer-GPU-friendly optimizations. Will update the docs accordingly. I'd like to mention ModelMixin.enable_layerwise_upcasting, diffusers.hooks.layerwise_upcasting.apply_layerwise_upcasting, and some information about which kinds of layers it makes sense to apply the lossy fp8 downcasting to without hurting generation quality too much.

@DN6 @sayakpaul @hlky Would you be able to give this a review when you find time? 🤗 This should be almost complete, apart from docs, I think. If there are additional tests you'd like to see, LMK.

Inline review comment on the hook's pre_forward:

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)

a-r-r-o-w (Member, Author):

cc @DN6 Another thing I've had in mind is allowing self.compute_dtype to be torch.float8_e4m3fn as well, because we could then leverage FP8 matmul for a speedup on Ada and newer GPUs, similar to how it's done here: https://github.com/facebookresearch/lingua/blob/main/lingua/float8.py.

Not for this PR, but definitely something we should consider supporting in the future.

stevhliu (Member):

Super nice @a-r-r-o-w, you can add it to the Reduce memory usage doc!

If you think it's going to be a pretty substantial doc or is important enough to deserve a dedicated page, then you can also make a standalone guide in "Accelerate inference and reduce memory" section. The upside of this option is you'll have better visibility :)

a-r-r-o-w (Member, Author):

@stevhliu It's nothing serious or too innovative to deserve its own doc page, so we can mention the memory savings there and, eventually in a follow-up PR, the time savings from fp8 matmul (if we decide to add support for it). It seems like a great place to do so, thanks!

stevhliu (Member) left a comment:

Thanks for the docs! 👏

Review comments on docs/source/en/optimization/memory.md (all resolved).
DN6 (Collaborator) left a comment:
Some very minor things to take care of; I think we're very close to merge. And regarding #10347 (comment): yeah, I think we can do as you suggested and add a check to dtype in ModelMixin that looks for the upcasting hook and returns the compute dtype.

src/diffusers/models/layerwise_upcasting_utils.py (resolved review comment)
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
DN6 (Collaborator) commented on the docs snippet above:
In this example, since we move the entire pipeline to CUDA in bfloat16, the memory would still spike.

I think you would need to load the transformer with torch_dtype=torch.float8_e4m3fn, enable layerwise upcasting, and then pass it to the pipeline.

a-r-r-o-w (Member, Author):

Thanks, missed it! I've updated the example.

I think we should also try to support directly loading float8_* dtypes in from_pretrained in a future PR, in a way that respects _keep_in_fp32_modules and _always_upcast_modules (I'm not too sure that's a good variable name, by the way, so open to suggestions).
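For reference, a sketch of the reworked example along those lines: load the transformer on its own, enable layerwise upcasting, and only then assemble the pipeline. Since direct float8 loading in from_pretrained is still future work, the transformer is loaded in bfloat16 here and downcast by the hook (prompt and file name are placeholders):

import torch
from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel
from diffusers.utils import export_to_video

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.enable_layerwise_upcasting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)

pipe = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-5b", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")

video = pipe("A panda playing a guitar in a serene bamboo forest").frames[0]
export_to_video(video, "output.mp4", fps=8)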

src/diffusers/hooks/layerwise_upcasting.py (resolved review comment)
Inline review comment on:

    should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
        skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
    )

DN6 (Collaborator):
This pattern wouldn't catch class names, though? E.g. _always_upcast_modules = ["MaskConditionDecoder"] in the asymmetric autoencoder.

a-r-r-o-w (Member, Author):
Oh, that looks like a mistake in some files. My intention is for _always_upcast_modules to hold regex patterns for layer names. I think that lets us cover, in a clean manner, a wider set of layers that are prone to producing worse results (such as the normalization ones). Places where a class name is mentioned, such as "MaskConditionDecoder", should really just use the name of the attribute that is initialized to an instance of that module type.

LMK if you'd prefer that we make _always_upcast_modules a list of classes.

Comment on lines +107 to +115
    # 1. Check if we have attached any dtype modifying hooks (eg. layerwise upcasting)
    if isinstance(parameter, nn.Module):
        for name, submodule in parameter.named_modules():
            if not hasattr(submodule, "_diffusers_hook"):
                continue
            registry = submodule._diffusers_hook
            hook = registry.get_hook("layerwise_upcasting")
            if hook is not None:
                return hook.compute_dtype
a-r-r-o-w (Member, Author):
@DN6 Added this here as discussed

Comment on lines -2134 to +2136
-    t_emb = t_emb.to(dtype=self.dtype)
+    # TODO(aryan): Need to have this reviewed
+    t_emb = t_emb.to(dtype=sample.dtype)
a-r-r-o-w (Member, Author) commented on Jan 17, 2025:
@DN6 There are a few places I've marked as TODO for review. I think it should be a safe change, but a second set of eyes would be really helpful. The change is probably not required now that we changed the .dtype logic in ModelMixin; will check to make sure.

sayakpaul mentioned this pull request on Jan 20, 2025.
Labels: roadmap (Add to current release roadmap)
Project status: In Progress

Merging this pull request may close: [Experimental] expose dynamic upcasting of layers as experimental APIs