Skip to content

Commit

Permalink
Merge branch 'main' into groupwise-offloading
Browse files Browse the repository at this point in the history
  • Loading branch information
a-r-r-o-w committed Jan 16, 2025
2 parents a8eabd0 + b0c8973 commit deda9a3
Show file tree
Hide file tree
Showing 102 changed files with 1,464 additions and 164 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/push_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ jobs:
python utils/print_env.py
- name: PyTorch CUDA checkpoint tests on Ubuntu
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
Expand Down Expand Up @@ -137,7 +137,7 @@ jobs:
- name: Run PyTorch CUDA tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
Expand Down
2 changes: 1 addition & 1 deletion docs/source/en/api/pipelines/flux.md
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ transformer_8bit = FluxTransformer2DModel.from_pretrained(

pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
text_encoder=text_encoder_8bit,
text_encoder_2=text_encoder_8bit,
transformer=transformer_8bit,
torch_dtype=torch.float16,
device_map="balanced",
Expand Down
8 changes: 4 additions & 4 deletions docs/source/en/api/pipelines/hunyuan_video.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

[HunyuanVideo](https://www.arxiv.org/abs/2412.03603) by Tencent.

*Recent advancements in video generation have significantly impacted daily life for both individuals and industries. However, the leading video generation models remain closed-source, resulting in a notable performance gap between industry capabilities and those available to the public. In this report, we introduce HunyuanVideo, an innovative open-source video foundation model that demonstrates performance in video generation comparable to, or even surpassing, that of leading closed-source models. HunyuanVideo encompasses a comprehensive framework that integrates several key elements, including data curation, advanced architectural design, progressive model scaling and training, and an efficient infrastructure tailored for large-scale model training and inference. As a result, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models. We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion dynamics, text-video alignment, and advanced filming techniques. According to evaluations by professionals, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and three top-performing Chinese video generative models. By releasing the code for the foundation model and its applications, we aim to bridge the gap between closed-source and open-source communities. This initiative will empower individuals within the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem. The code is publicly available at [this https URL](https://github.com/Tencent/HunyuanVideo).*
*Recent advancements in video generation have significantly impacted daily life for both individuals and industries. However, the leading video generation models remain closed-source, resulting in a notable performance gap between industry capabilities and those available to the public. In this report, we introduce HunyuanVideo, an innovative open-source video foundation model that demonstrates performance in video generation comparable to, or even surpassing, that of leading closed-source models. HunyuanVideo encompasses a comprehensive framework that integrates several key elements, including data curation, advanced architectural design, progressive model scaling and training, and an efficient infrastructure tailored for large-scale model training and inference. As a result, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models. We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion dynamics, text-video alignment, and advanced filming techniques. According to evaluations by professionals, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and three top-performing Chinese video generative models. By releasing the code for the foundation model and its applications, we aim to bridge the gap between closed-source and open-source communities. This initiative will empower individuals within the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem. The code is publicly available at [this https URL](https://github.com/tencent/HunyuanVideo).*

<Tip>

Expand Down Expand Up @@ -45,14 +45,14 @@ from diffusers.utils import export_to_video

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = HunyuanVideoTransformer3DModel.from_pretrained(
"tencent/HunyuanVideo",
"hunyuanvideo-community/HunyuanVideo",
subfolder="transformer",
quantization_config=quant_config,
torch_dtype=torch.float16,
torch_dtype=torch.bfloat16,
)

pipeline = HunyuanVideoPipeline.from_pretrained(
"tencent/HunyuanVideo",
"hunyuanvideo-community/HunyuanVideo",
transformer=transformer_8bit,
torch_dtype=torch.float16,
device_map="balanced",
Expand Down
40 changes: 40 additions & 0 deletions docs/source/en/using-diffusers/other-formats.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,46 @@ Benefits of using a single-file layout include:
1. Easy compatibility with diffusion interfaces such as [ComfyUI](https://github.com/comfyanonymous/ComfyUI) or [Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) which commonly use a single-file layout.
2. Easier to manage (download and share) a single file.

### DDUF

> [!WARNING]
> DDUF is an experimental file format and APIs related to it can change in the future.
DDUF (**D**DUF **D**iffusion **U**nified **F**ormat) is a file format designed to make storing, distributing, and using diffusion models much easier. Built on the ZIP file format, DDUF offers a standardized, efficient, and flexible way to package all parts of a diffusion model into a single, easy-to-manage file. It provides a balance between Diffusers multi-folder format and the widely popular single-file format.

Learn more details about DDUF on the Hugging Face Hub [documentation](https://huggingface.co/docs/hub/dduf).

Pass a checkpoint to the `dduf_file` parameter to load it in [`DiffusionPipeline`].

```py
from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained(
"DDUF/FLUX.1-dev-DDUF", dduf_file="FLUX.1-dev.dduf", torch_dtype=torch.bfloat16
).to("cuda")
image = pipe(
"photo a cat holding a sign that says Diffusers", num_inference_steps=50, guidance_scale=3.5
).images[0]
image.save("cat.png")
```

To save a pipeline as a `.dduf` checkpoint, use the [`~huggingface_hub.export_folder_as_dduf`] utility, which takes care of all the necessary file-level validations.

```py
from huggingface_hub import export_folder_as_dduf
from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

save_folder = "flux-dev"
pipe.save_pretrained("flux-dev")
export_folder_as_dduf("flux-dev.dduf", folder_path=save_folder)

> [!TIP]
> Packaging and loading quantized checkpoints in the DDUF format is supported as long as they respect the multi-folder structure.

## Convert layout and files

Diffusers provides many scripts and methods to convert storage layouts and file formats to enable broader support across the diffusion ecosystem.
Expand Down
4 changes: 2 additions & 2 deletions docs/source/en/using-diffusers/text-img2vid.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

transformer = HunyuanVideoTransformer3DModel.from_pretrained(
"tencent/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16
"hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(
"tencent/HunyuanVideo", transformer=transformer, torch_dtype=torch.float16
"hunyuanvideo-community/HunyuanVideo", transformer=transformer, torch_dtype=torch.float16
)

# reduce memory requirements
Expand Down
4 changes: 4 additions & 0 deletions examples/dreambooth/train_dreambooth_lora_sana.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ def log_validation(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
if args.enable_vae_tiling:
pipeline.vae.enable_tiling(tile_sample_min_height=1024, tile_sample_stride_width=1024)

pipeline.text_encoder = pipeline.text_encoder.to(torch.bfloat16)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
Expand Down Expand Up @@ -597,6 +600,7 @@ def parse_args(input_args=None):
help="Whether to offload the VAE and the text encoder to CPU when they are not used.",
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_vae_tiling", action="store_true", help="Enabla vae tiling in log validation")

if input_args is not None:
args = parser.parse_args(input_args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)

transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
"filelock",
"flax>=0.4.1",
"hf-doc-builder>=0.3.0",
"huggingface-hub>=0.23.2",
"huggingface-hub>=0.27.0",
"requests-mock==1.10.0",
"importlib_metadata",
"invisible-watermark>=0.2.0",
Expand Down Expand Up @@ -135,6 +135,7 @@
"transformers>=4.41.2",
"urllib3<=2.0.0",
"black",
"phonemizer",
]

# this is a lookup table with items like:
Expand Down Expand Up @@ -227,6 +228,7 @@ def run(self):
"scipy",
"torchvision",
"transformers",
"phonemizer",
)
extras["torch"] = deps_list("torch", "accelerate")

Expand Down
45 changes: 35 additions & 10 deletions src/diffusers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
import re
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
from huggingface_hub import create_repo, hf_hub_download
from huggingface_hub import DDUFEntry, create_repo, hf_hub_download
from huggingface_hub.utils import (
EntryNotFoundError,
RepositoryNotFoundError,
Expand Down Expand Up @@ -347,6 +347,7 @@ def load_config(
_ = kwargs.pop("mirror", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = kwargs.pop("user_agent", {})
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)

user_agent = {**user_agent, "file_type": "config"}
user_agent = http_user_agent(user_agent)
Expand All @@ -358,8 +359,15 @@ def load_config(
"`self.config_name` is not defined. Note that one should not load a config from "
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
)

if os.path.isfile(pretrained_model_name_or_path):
# Custom path for now
if dduf_entries:
if subfolder is not None:
raise ValueError(
"DDUF file only allow for 1 level of directory (e.g transformer/model1/model.safetentors is not allowed). "
"Please check the DDUF structure"
)
config_file = cls._get_config_file_from_dduf(pretrained_model_name_or_path, dduf_entries)
elif os.path.isfile(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path
elif os.path.isdir(pretrained_model_name_or_path):
if subfolder is not None and os.path.isfile(
Expand Down Expand Up @@ -426,10 +434,8 @@ def load_config(
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a {cls.config_name} file"
)

try:
# Load config dict
config_dict = cls._dict_from_json_file(config_file)
config_dict = cls._dict_from_json_file(config_file, dduf_entries=dduf_entries)

commit_hash = extract_commit_hash(config_file)
except (json.JSONDecodeError, UnicodeDecodeError):
Expand Down Expand Up @@ -552,9 +558,14 @@ def extract_init_dict(cls, config_dict, **kwargs):
return init_dict, unused_kwargs, hidden_config_dict

@classmethod
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
def _dict_from_json_file(
cls, json_file: Union[str, os.PathLike], dduf_entries: Optional[Dict[str, DDUFEntry]] = None
):
if dduf_entries:
text = dduf_entries[json_file].read_text()
else:
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
return json.loads(text)

def __repr__(self):
Expand Down Expand Up @@ -616,6 +627,20 @@ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string())

@classmethod
def _get_config_file_from_dduf(cls, pretrained_model_name_or_path: str, dduf_entries: Dict[str, DDUFEntry]):
# paths inside a DDUF file must always be "/"
config_file = (
cls.config_name
if pretrained_model_name_or_path == ""
else "/".join([pretrained_model_name_or_path, cls.config_name])
)
if config_file not in dduf_entries:
raise ValueError(
f"We did not manage to find the file {config_file} in the dduf file. We only have the following files {dduf_entries.keys()}"
)
return config_file


def register_to_config(init):
r"""
Expand Down
3 changes: 2 additions & 1 deletion src/diffusers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"filelock": "filelock",
"flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.23.2",
"huggingface-hub": "huggingface-hub>=0.27.0",
"requests-mock": "requests-mock==1.10.0",
"importlib_metadata": "importlib_metadata",
"invisible-watermark": "invisible-watermark>=0.2.0",
Expand Down Expand Up @@ -43,4 +43,5 @@
"transformers": "transformers>=4.41.2",
"urllib3": "urllib3<=2.0.0",
"black": "black",
"phonemizer": "phonemizer",
}
39 changes: 36 additions & 3 deletions src/diffusers/loaders/lora_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ..utils import (
USE_PEFT_BACKEND,
deprecate,
get_submodule_by_name,
is_peft_available,
is_peft_version,
is_torch_version,
Expand Down Expand Up @@ -1981,10 +1982,17 @@ def _maybe_expand_transformer_param_shape_or_error_(
in_features = state_dict[lora_A_weight_name].shape[1]
out_features = state_dict[lora_B_weight_name].shape[0]

# Model maybe loaded with different quantization schemes which may flatten the params.
# `bitsandbytes`, for example, flatten the weights when using 4bit. 8bit bnb models
# preserve weight shape.
module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)

# This means there's no need for an expansion in the params, so we simply skip.
if tuple(module_weight.shape) == (out_features, in_features):
if tuple(module_weight_shape) == (out_features, in_features):
continue

# TODO (sayakpaul): We still need to consider if the module we're expanding is
# quantized and handle it accordingly if that is the case.
module_out_features, module_in_features = module_weight.shape
debug_message = ""
if in_features > module_in_features:
Expand Down Expand Up @@ -2080,13 +2088,16 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
base_weight_param = transformer_state_dict[base_param_name]
lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]

if base_weight_param.shape[1] > lora_A_param.shape[1]:
# TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization.
base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name)

if base_module_shape[1] > lora_A_param.shape[1]:
shape = (lora_A_param.shape[0], base_weight_param.shape[1])
expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
expanded_module_names.add(k)
elif base_weight_param.shape[1] < lora_A_param.shape[1]:
elif base_module_shape[1] < lora_A_param.shape[1]:
raise NotImplementedError(
f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
)
Expand All @@ -2098,6 +2109,28 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):

return lora_state_dict

@staticmethod
def _calculate_module_shape(
model: "torch.nn.Module",
base_module: "torch.nn.Linear" = None,
base_weight_param_name: str = None,
) -> "torch.Size":
def _get_weight_shape(weight: torch.Tensor):
return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape

if base_module is not None:
return _get_weight_shape(base_module.weight)
elif base_weight_param_name is not None:
if not base_weight_param_name.endswith(".weight"):
raise ValueError(
f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
)
module_path = base_weight_param_name.rsplit(".weight", 1)[0]
submodule = get_submodule_by_name(model, module_path)
return _get_weight_shape(submodule.weight)

raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")


# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
Expand Down
Loading

0 comments on commit deda9a3

Please sign in to comment.