Skip to content

Commit

Permalink
Merge branch 'main' into RMS
Browse files Browse the repository at this point in the history
  • Loading branch information
leisuzz authored Jan 15, 2025
2 parents 4f98ac0 + 2432f80 commit 1f7ee3f
Show file tree
Hide file tree
Showing 75 changed files with 998 additions and 68 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/push_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ jobs:
python utils/print_env.py
- name: PyTorch CUDA checkpoint tests on Ubuntu
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
Expand Down Expand Up @@ -137,7 +137,7 @@ jobs:
- name: Run PyTorch CUDA tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
Expand Down
40 changes: 40 additions & 0 deletions docs/source/en/using-diffusers/other-formats.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,46 @@ Benefits of using a single-file layout include:
1. Easy compatibility with diffusion interfaces such as [ComfyUI](https://github.com/comfyanonymous/ComfyUI) or [Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) which commonly use a single-file layout.
2. Easier to manage (download and share) a single file.

### DDUF

> [!WARNING]
> DDUF is an experimental file format and APIs related to it can change in the future.
DDUF (**D**DUF **D**iffusion **U**nified **F**ormat) is a file format designed to make storing, distributing, and using diffusion models much easier. Built on the ZIP file format, DDUF offers a standardized, efficient, and flexible way to package all parts of a diffusion model into a single, easy-to-manage file. It provides a balance between Diffusers multi-folder format and the widely popular single-file format.

Learn more details about DDUF on the Hugging Face Hub [documentation](https://huggingface.co/docs/hub/dduf).

Pass a checkpoint to the `dduf_file` parameter to load it in [`DiffusionPipeline`].

```py
from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained(
"DDUF/FLUX.1-dev-DDUF", dduf_file="FLUX.1-dev.dduf", torch_dtype=torch.bfloat16
).to("cuda")
image = pipe(
"photo a cat holding a sign that says Diffusers", num_inference_steps=50, guidance_scale=3.5
).images[0]
image.save("cat.png")
```

To save a pipeline as a `.dduf` checkpoint, use the [`~huggingface_hub.export_folder_as_dduf`] utility, which takes care of all the necessary file-level validations.

```py
from huggingface_hub import export_folder_as_dduf
from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

save_folder = "flux-dev"
pipe.save_pretrained("flux-dev")
export_folder_as_dduf("flux-dev.dduf", folder_path=save_folder)

> [!TIP]
> Packaging and loading quantized checkpoints in the DDUF format is supported as long as they respect the multi-folder structure.

## Convert layout and files

Diffusers provides many scripts and methods to convert storage layouts and file formats to enable broader support across the diffusion ecosystem.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
"filelock",
"flax>=0.4.1",
"hf-doc-builder>=0.3.0",
"huggingface-hub>=0.23.2",
"huggingface-hub>=0.27.0",
"requests-mock==1.10.0",
"importlib_metadata",
"invisible-watermark>=0.2.0",
Expand Down
45 changes: 35 additions & 10 deletions src/diffusers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
import re
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
from huggingface_hub import create_repo, hf_hub_download
from huggingface_hub import DDUFEntry, create_repo, hf_hub_download
from huggingface_hub.utils import (
EntryNotFoundError,
RepositoryNotFoundError,
Expand Down Expand Up @@ -347,6 +347,7 @@ def load_config(
_ = kwargs.pop("mirror", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = kwargs.pop("user_agent", {})
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)

user_agent = {**user_agent, "file_type": "config"}
user_agent = http_user_agent(user_agent)
Expand All @@ -358,8 +359,15 @@ def load_config(
"`self.config_name` is not defined. Note that one should not load a config from "
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
)

if os.path.isfile(pretrained_model_name_or_path):
# Custom path for now
if dduf_entries:
if subfolder is not None:
raise ValueError(
"DDUF file only allow for 1 level of directory (e.g transformer/model1/model.safetentors is not allowed). "
"Please check the DDUF structure"
)
config_file = cls._get_config_file_from_dduf(pretrained_model_name_or_path, dduf_entries)
elif os.path.isfile(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path
elif os.path.isdir(pretrained_model_name_or_path):
if subfolder is not None and os.path.isfile(
Expand Down Expand Up @@ -426,10 +434,8 @@ def load_config(
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a {cls.config_name} file"
)

try:
# Load config dict
config_dict = cls._dict_from_json_file(config_file)
config_dict = cls._dict_from_json_file(config_file, dduf_entries=dduf_entries)

commit_hash = extract_commit_hash(config_file)
except (json.JSONDecodeError, UnicodeDecodeError):
Expand Down Expand Up @@ -552,9 +558,14 @@ def extract_init_dict(cls, config_dict, **kwargs):
return init_dict, unused_kwargs, hidden_config_dict

@classmethod
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
def _dict_from_json_file(
cls, json_file: Union[str, os.PathLike], dduf_entries: Optional[Dict[str, DDUFEntry]] = None
):
if dduf_entries:
text = dduf_entries[json_file].read_text()
else:
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
return json.loads(text)

def __repr__(self):
Expand Down Expand Up @@ -616,6 +627,20 @@ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string())

@classmethod
def _get_config_file_from_dduf(cls, pretrained_model_name_or_path: str, dduf_entries: Dict[str, DDUFEntry]):
# paths inside a DDUF file must always be "/"
config_file = (
cls.config_name
if pretrained_model_name_or_path == ""
else "/".join([pretrained_model_name_or_path, cls.config_name])
)
if config_file not in dduf_entries:
raise ValueError(
f"We did not manage to find the file {config_file} in the dduf file. We only have the following files {dduf_entries.keys()}"
)
return config_file


def register_to_config(init):
r"""
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"filelock": "filelock",
"flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.23.2",
"huggingface-hub": "huggingface-hub>=0.27.0",
"requests-mock": "requests-mock==1.10.0",
"importlib_metadata": "importlib_metadata",
"invisible-watermark": "invisible-watermark>=0.2.0",
Expand Down
39 changes: 36 additions & 3 deletions src/diffusers/loaders/lora_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ..utils import (
USE_PEFT_BACKEND,
deprecate,
get_submodule_by_name,
is_peft_available,
is_peft_version,
is_torch_version,
Expand Down Expand Up @@ -1981,10 +1982,17 @@ def _maybe_expand_transformer_param_shape_or_error_(
in_features = state_dict[lora_A_weight_name].shape[1]
out_features = state_dict[lora_B_weight_name].shape[0]

# Model maybe loaded with different quantization schemes which may flatten the params.
# `bitsandbytes`, for example, flatten the weights when using 4bit. 8bit bnb models
# preserve weight shape.
module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module)

# This means there's no need for an expansion in the params, so we simply skip.
if tuple(module_weight.shape) == (out_features, in_features):
if tuple(module_weight_shape) == (out_features, in_features):
continue

# TODO (sayakpaul): We still need to consider if the module we're expanding is
# quantized and handle it accordingly if that is the case.
module_out_features, module_in_features = module_weight.shape
debug_message = ""
if in_features > module_in_features:
Expand Down Expand Up @@ -2080,13 +2088,16 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
base_weight_param = transformer_state_dict[base_param_name]
lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]

if base_weight_param.shape[1] > lora_A_param.shape[1]:
# TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization.
base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name)

if base_module_shape[1] > lora_A_param.shape[1]:
shape = (lora_A_param.shape[0], base_weight_param.shape[1])
expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device)
expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param)
lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight
expanded_module_names.add(k)
elif base_weight_param.shape[1] < lora_A_param.shape[1]:
elif base_module_shape[1] < lora_A_param.shape[1]:
raise NotImplementedError(
f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new."
)
Expand All @@ -2098,6 +2109,28 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):

return lora_state_dict

@staticmethod
def _calculate_module_shape(
model: "torch.nn.Module",
base_module: "torch.nn.Linear" = None,
base_weight_param_name: str = None,
) -> "torch.Size":
def _get_weight_shape(weight: torch.Tensor):
return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape

if base_module is not None:
return _get_weight_shape(base_module.weight)
elif base_weight_param_name is not None:
if not base_weight_param_name.endswith(".weight"):
raise ValueError(
f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}."
)
module_path = base_weight_param_name.rsplit(".weight", 1)[0]
submodule = get_submodule_by_name(model, module_path)
return _get_weight_shape(submodule.weight)

raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")


# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
Expand Down
8 changes: 4 additions & 4 deletions src/diffusers/loaders/textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
hf_token = kwargs.pop("hf_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
Expand Down Expand Up @@ -73,7 +73,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
token=hf_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
Expand All @@ -93,7 +93,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
token=hf_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
Expand Down Expand Up @@ -312,7 +312,7 @@ def load_textual_inversion(
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
hf_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
Expand Down
44 changes: 34 additions & 10 deletions src/diffusers/models/model_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
from array import array
from collections import OrderedDict
from pathlib import Path
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union

import safetensors
import torch
from huggingface_hub import DDUFEntry
from huggingface_hub.utils import EntryNotFoundError

from ..utils import (
Expand Down Expand Up @@ -132,7 +133,10 @@ def _fetch_remapped_cls_from_config(config, old_class):


def load_state_dict(
checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, disable_mmap: bool = False
checkpoint_file: Union[str, os.PathLike],
variant: Optional[str] = None,
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
disable_mmap: bool = False,
):
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
Expand All @@ -144,6 +148,10 @@ def load_state_dict(
try:
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
if file_extension == SAFETENSORS_FILE_EXTENSION:
if dduf_entries:
# tensors are loaded on cpu
with dduf_entries[checkpoint_file].as_mmap() as mm:
return safetensors.torch.load(mm)
if disable_mmap:
return safetensors.torch.load(open(checkpoint_file, "rb").read())
else:
Expand Down Expand Up @@ -284,6 +292,7 @@ def _fetch_index_file(
revision,
user_agent,
commit_hash,
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
):
if is_local:
index_file = Path(
Expand All @@ -309,8 +318,10 @@ def _fetch_index_file(
subfolder=None,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)
index_file = Path(index_file)
if not dduf_entries:
index_file = Path(index_file)
except (EntryNotFoundError, EnvironmentError):
index_file = None

Expand All @@ -319,7 +330,9 @@ def _fetch_index_file(

# Adapted from
# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
def _merge_sharded_checkpoints(
sharded_ckpt_cached_folder, sharded_metadata, dduf_entries: Optional[Dict[str, DDUFEntry]] = None
):
weight_map = sharded_metadata.get("weight_map", None)
if weight_map is None:
raise KeyError("'weight_map' key not found in the shard index file.")
Expand All @@ -332,14 +345,23 @@ def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
# Load tensors from each unique file
for file_name in files_to_load:
part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
if not os.path.exists(part_file_path):
raise FileNotFoundError(f"Part file {file_name} not found.")
if dduf_entries:
if part_file_path not in dduf_entries:
raise FileNotFoundError(f"Part file {file_name} not found.")
else:
if not os.path.exists(part_file_path):
raise FileNotFoundError(f"Part file {file_name} not found.")

if is_safetensors:
with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
for tensor_key in f.keys():
if tensor_key in weight_map:
merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
if dduf_entries:
with dduf_entries[part_file_path].as_mmap() as mm:
tensors = safetensors.torch.load(mm)
merged_state_dict.update(tensors)
else:
with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
for tensor_key in f.keys():
if tensor_key in weight_map:
merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
else:
merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))

Expand All @@ -360,6 +382,7 @@ def _fetch_index_file_legacy(
revision,
user_agent,
commit_hash,
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
):
if is_local:
index_file = Path(
Expand Down Expand Up @@ -400,6 +423,7 @@ def _fetch_index_file_legacy(
subfolder=None,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)
index_file = Path(index_file)
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
Expand Down
Loading

0 comments on commit 1f7ee3f

Please sign in to comment.