Skip to content

Commit

Permalink
[FEAT] DDUF format (#10037)
Browse files Browse the repository at this point in the history
* load and save dduf archive

* style

* switch to zip uncompressed

* updates

* Update src/diffusers/pipelines/pipeline_utils.py

Co-authored-by: Sayak Paul <[email protected]>

* Update src/diffusers/pipelines/pipeline_utils.py

Co-authored-by: Sayak Paul <[email protected]>

* first draft

* remove print

* switch to dduf_file for consistency

* switch to huggingface hub api

* fix log

* add a basic test

* Update src/diffusers/configuration_utils.py

Co-authored-by: Sayak Paul <[email protected]>

* Update src/diffusers/pipelines/pipeline_utils.py

Co-authored-by: Sayak Paul <[email protected]>

* Update src/diffusers/pipelines/pipeline_utils.py

Co-authored-by: Sayak Paul <[email protected]>

* fix

* fix variant

* change saving logic

* DDUF - Load transformers components manually (#10171)

* update hfh version

* Load transformers components manually

* load encoder from_pretrained with state_dict

* working version with transformers and tokenizer !

* add generation_config case

* fix tests

* remove saving for now

* typing

* need next version from transformers

* Update src/diffusers/configuration_utils.py

Co-authored-by: Lucain <[email protected]>

* check path corectly

* Apply suggestions from code review

Co-authored-by: Lucain <[email protected]>

* udapte

* typing

* remove check for subfolder

* quality

* revert setup changes

* oups

* more readable condition

* add loading from the hub test

* add basic docs.

* Apply suggestions from code review

Co-authored-by: Lucain <[email protected]>

* add example

* add

* make functions private

* Apply suggestions from code review

Co-authored-by: Steven Liu <[email protected]>

* minor.

* fixes

* fix

* change the precdence of parameterized.

* error out when custom pipeline is passed with dduf_file.

* updates

* fix

* updates

* fixes

* updates

* fix xfail condition.

* fix xfail

* fixes

* sharded checkpoint compat

* add test for sharded checkpoint

* add suggestions

* Update src/diffusers/models/model_loading_utils.py

Co-authored-by: YiYi Xu <[email protected]>

* from suggestions

* add class attributes to flag dduf tests

* last one

* fix logic

* remove comment

* revert changes

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Lucain <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
  • Loading branch information
5 people authored Jan 14, 2025
1 parent 3279751 commit fbff43a
Show file tree
Hide file tree
Showing 62 changed files with 750 additions and 45 deletions.
40 changes: 40 additions & 0 deletions docs/source/en/using-diffusers/other-formats.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,46 @@ Benefits of using a single-file layout include:
1. Easy compatibility with diffusion interfaces such as [ComfyUI](https://github.com/comfyanonymous/ComfyUI) or [Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) which commonly use a single-file layout.
2. Easier to manage (download and share) a single file.

### DDUF

> [!WARNING]
> DDUF is an experimental file format and APIs related to it can change in the future.
DDUF (**D**DUF **D**iffusion **U**nified **F**ormat) is a file format designed to make storing, distributing, and using diffusion models much easier. Built on the ZIP file format, DDUF offers a standardized, efficient, and flexible way to package all parts of a diffusion model into a single, easy-to-manage file. It provides a balance between Diffusers multi-folder format and the widely popular single-file format.

Learn more details about DDUF on the Hugging Face Hub [documentation](https://huggingface.co/docs/hub/dduf).

Pass a checkpoint to the `dduf_file` parameter to load it in [`DiffusionPipeline`].

```py
from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained(
"DDUF/FLUX.1-dev-DDUF", dduf_file="FLUX.1-dev.dduf", torch_dtype=torch.bfloat16
).to("cuda")
image = pipe(
"photo a cat holding a sign that says Diffusers", num_inference_steps=50, guidance_scale=3.5
).images[0]
image.save("cat.png")
```

To save a pipeline as a `.dduf` checkpoint, use the [`~huggingface_hub.export_folder_as_dduf`] utility, which takes care of all the necessary file-level validations.

```py
from huggingface_hub import export_folder_as_dduf
from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

save_folder = "flux-dev"
pipe.save_pretrained("flux-dev")
export_folder_as_dduf("flux-dev.dduf", folder_path=save_folder)

> [!TIP]
> Packaging and loading quantized checkpoints in the DDUF format is supported as long as they respect the multi-folder structure.

## Convert layout and files

Diffusers provides many scripts and methods to convert storage layouts and file formats to enable broader support across the diffusion ecosystem.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
"filelock",
"flax>=0.4.1",
"hf-doc-builder>=0.3.0",
"huggingface-hub>=0.23.2",
"huggingface-hub>=0.27.0",
"requests-mock==1.10.0",
"importlib_metadata",
"invisible-watermark>=0.2.0",
Expand Down
45 changes: 35 additions & 10 deletions src/diffusers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
import re
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
from huggingface_hub import create_repo, hf_hub_download
from huggingface_hub import DDUFEntry, create_repo, hf_hub_download
from huggingface_hub.utils import (
EntryNotFoundError,
RepositoryNotFoundError,
Expand Down Expand Up @@ -347,6 +347,7 @@ def load_config(
_ = kwargs.pop("mirror", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = kwargs.pop("user_agent", {})
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)

user_agent = {**user_agent, "file_type": "config"}
user_agent = http_user_agent(user_agent)
Expand All @@ -358,8 +359,15 @@ def load_config(
"`self.config_name` is not defined. Note that one should not load a config from "
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
)

if os.path.isfile(pretrained_model_name_or_path):
# Custom path for now
if dduf_entries:
if subfolder is not None:
raise ValueError(
"DDUF file only allow for 1 level of directory (e.g transformer/model1/model.safetentors is not allowed). "
"Please check the DDUF structure"
)
config_file = cls._get_config_file_from_dduf(pretrained_model_name_or_path, dduf_entries)
elif os.path.isfile(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path
elif os.path.isdir(pretrained_model_name_or_path):
if subfolder is not None and os.path.isfile(
Expand Down Expand Up @@ -426,10 +434,8 @@ def load_config(
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a {cls.config_name} file"
)

try:
# Load config dict
config_dict = cls._dict_from_json_file(config_file)
config_dict = cls._dict_from_json_file(config_file, dduf_entries=dduf_entries)

commit_hash = extract_commit_hash(config_file)
except (json.JSONDecodeError, UnicodeDecodeError):
Expand Down Expand Up @@ -552,9 +558,14 @@ def extract_init_dict(cls, config_dict, **kwargs):
return init_dict, unused_kwargs, hidden_config_dict

@classmethod
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
def _dict_from_json_file(
cls, json_file: Union[str, os.PathLike], dduf_entries: Optional[Dict[str, DDUFEntry]] = None
):
if dduf_entries:
text = dduf_entries[json_file].read_text()
else:
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
return json.loads(text)

def __repr__(self):
Expand Down Expand Up @@ -616,6 +627,20 @@ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string())

@classmethod
def _get_config_file_from_dduf(cls, pretrained_model_name_or_path: str, dduf_entries: Dict[str, DDUFEntry]):
# paths inside a DDUF file must always be "/"
config_file = (
cls.config_name
if pretrained_model_name_or_path == ""
else "/".join([pretrained_model_name_or_path, cls.config_name])
)
if config_file not in dduf_entries:
raise ValueError(
f"We did not manage to find the file {config_file} in the dduf file. We only have the following files {dduf_entries.keys()}"
)
return config_file


def register_to_config(init):
r"""
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"filelock": "filelock",
"flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.23.2",
"huggingface-hub": "huggingface-hub>=0.27.0",
"requests-mock": "requests-mock==1.10.0",
"importlib_metadata": "importlib_metadata",
"invisible-watermark": "invisible-watermark>=0.2.0",
Expand Down
44 changes: 34 additions & 10 deletions src/diffusers/models/model_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
from array import array
from collections import OrderedDict
from pathlib import Path
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union

import safetensors
import torch
from huggingface_hub import DDUFEntry
from huggingface_hub.utils import EntryNotFoundError

from ..utils import (
Expand Down Expand Up @@ -132,7 +133,10 @@ def _fetch_remapped_cls_from_config(config, old_class):


def load_state_dict(
checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, disable_mmap: bool = False
checkpoint_file: Union[str, os.PathLike],
variant: Optional[str] = None,
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
disable_mmap: bool = False,
):
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
Expand All @@ -144,6 +148,10 @@ def load_state_dict(
try:
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
if file_extension == SAFETENSORS_FILE_EXTENSION:
if dduf_entries:
# tensors are loaded on cpu
with dduf_entries[checkpoint_file].as_mmap() as mm:
return safetensors.torch.load(mm)
if disable_mmap:
return safetensors.torch.load(open(checkpoint_file, "rb").read())
else:
Expand Down Expand Up @@ -284,6 +292,7 @@ def _fetch_index_file(
revision,
user_agent,
commit_hash,
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
):
if is_local:
index_file = Path(
Expand All @@ -309,8 +318,10 @@ def _fetch_index_file(
subfolder=None,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)
index_file = Path(index_file)
if not dduf_entries:
index_file = Path(index_file)
except (EntryNotFoundError, EnvironmentError):
index_file = None

Expand All @@ -319,7 +330,9 @@ def _fetch_index_file(

# Adapted from
# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
def _merge_sharded_checkpoints(
sharded_ckpt_cached_folder, sharded_metadata, dduf_entries: Optional[Dict[str, DDUFEntry]] = None
):
weight_map = sharded_metadata.get("weight_map", None)
if weight_map is None:
raise KeyError("'weight_map' key not found in the shard index file.")
Expand All @@ -332,14 +345,23 @@ def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
# Load tensors from each unique file
for file_name in files_to_load:
part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
if not os.path.exists(part_file_path):
raise FileNotFoundError(f"Part file {file_name} not found.")
if dduf_entries:
if part_file_path not in dduf_entries:
raise FileNotFoundError(f"Part file {file_name} not found.")
else:
if not os.path.exists(part_file_path):
raise FileNotFoundError(f"Part file {file_name} not found.")

if is_safetensors:
with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
for tensor_key in f.keys():
if tensor_key in weight_map:
merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
if dduf_entries:
with dduf_entries[part_file_path].as_mmap() as mm:
tensors = safetensors.torch.load(mm)
merged_state_dict.update(tensors)
else:
with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
for tensor_key in f.keys():
if tensor_key in weight_map:
merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
else:
merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))

Expand All @@ -360,6 +382,7 @@ def _fetch_index_file_legacy(
revision,
user_agent,
commit_hash,
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
):
if is_local:
index_file = Path(
Expand Down Expand Up @@ -400,6 +423,7 @@ def _fetch_index_file_legacy(
subfolder=None,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)
index_file = Path(index_file)
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
Expand Down
27 changes: 20 additions & 7 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
from collections import OrderedDict
from functools import partial, wraps
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import safetensors
import torch
from huggingface_hub import create_repo, split_torch_state_dict_into_shards
from huggingface_hub import DDUFEntry, create_repo, split_torch_state_dict_into_shards
from huggingface_hub.utils import validate_hf_hub_args
from torch import Tensor, nn

Expand Down Expand Up @@ -607,6 +607,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
quantization_config = kwargs.pop("quantization_config", None)
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
disable_mmap = kwargs.pop("disable_mmap", False)

allow_pickle = False
Expand Down Expand Up @@ -700,6 +701,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
dduf_entries=dduf_entries,
**kwargs,
)
# no in-place modification of the original config.
Expand Down Expand Up @@ -776,13 +778,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
"revision": revision,
"user_agent": user_agent,
"commit_hash": commit_hash,
"dduf_entries": dduf_entries,
}
index_file = _fetch_index_file(**index_file_kwargs)
# In case the index file was not found we still have to consider the legacy format.
# this becomes applicable when the variant is not None.
if variant is not None and (index_file is None or not os.path.exists(index_file)):
index_file = _fetch_index_file_legacy(**index_file_kwargs)
if index_file is not None and index_file.is_file():
if index_file is not None and (dduf_entries or index_file.is_file()):
is_sharded = True

if is_sharded and from_flax:
Expand Down Expand Up @@ -811,6 +814,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

model = load_flax_checkpoint_in_pytorch_model(model, model_file)
else:
# in the case it is sharded, we have already the index
if is_sharded:
sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
pretrained_model_name_or_path,
Expand All @@ -822,10 +826,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
user_agent=user_agent,
revision=revision,
subfolder=subfolder or "",
dduf_entries=dduf_entries,
)
# TODO: https://github.com/huggingface/diffusers/issues/10013
if hf_quantizer is not None:
model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
if hf_quantizer is not None or dduf_entries:
model_file = _merge_sharded_checkpoints(
sharded_ckpt_cached_folder, sharded_metadata, dduf_entries=dduf_entries
)
logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
is_sharded = False

Expand All @@ -843,6 +850,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)

except IOError as e:
Expand All @@ -866,6 +874,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)

if low_cpu_mem_usage:
Expand All @@ -887,7 +896,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# TODO (sayakpaul, SunMarc): remove this after model loading refactor
else:
param_device = torch.device(torch.cuda.current_device())
state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
state_dict = load_state_dict(
model_file, variant=variant, dduf_entries=dduf_entries, disable_mmap=disable_mmap
)
model._convert_deprecated_attention_blocks(state_dict)

# move the params from meta device to cpu
Expand Down Expand Up @@ -983,7 +994,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
else:
model = cls.from_config(config, **unused_kwargs)

state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
state_dict = load_state_dict(
model_file, variant=variant, dduf_entries=dduf_entries, disable_mmap=disable_mmap
)
model._convert_deprecated_attention_blocks(state_dict)

model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
Expand Down
Loading

0 comments on commit fbff43a

Please sign in to comment.