Skip to content

Commit

Permalink
Merge branch 'main' into flux_ptxla_trillium
Browse files Browse the repository at this point in the history
  • Loading branch information
sayakpaul authored Jan 10, 2025
2 parents 2816222 + 83ba01a commit 8c42542
Show file tree
Hide file tree
Showing 29 changed files with 367 additions and 776 deletions.
11 changes: 11 additions & 0 deletions examples/advanced_diffusion_training/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,17 @@ write_basic_config()
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.

Lastly, we recommend logging into your HF account so that your trained LoRA is automatically uploaded to the hub:
```bash
huggingface-cli login
```
This command will prompt you for a token. Copy-paste yours from your [settings/tokens](https://huggingface.co/settings/tokens),and press Enter.

> [!NOTE]
> In the examples below we use `wandb` to document the training runs. To do the same, make sure to install `wandb`:
> `pip install wandb`
> Alternatively, you can use other tools / train without reporting by modifying the flag `--report_to="wandb"`.
### Pivotal Tuning
**Training with text encoder(s)**

Expand Down
11 changes: 11 additions & 0 deletions examples/advanced_diffusion_training/README_flux.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,17 @@ write_basic_config()
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.

Lastly, we recommend logging into your HF account so that your trained LoRA is automatically uploaded to the hub:
```bash
huggingface-cli login
```
This command will prompt you for a token. Copy-paste yours from your [settings/tokens](https://huggingface.co/settings/tokens),and press Enter.

> [!NOTE]
> In the examples below we use `wandb` to document the training runs. To do the same, make sure to install `wandb`:
> `pip install wandb`
> Alternatively, you can use other tools / train without reporting by modifying the flag `--report_to="wandb"`.
### Target Modules
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
Expand Down
8 changes: 4 additions & 4 deletions examples/community/pipeline_flux_differential_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,10 +875,10 @@ def __call__(
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
Expand Down
16 changes: 8 additions & 8 deletions examples/community/pipeline_flux_rf_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,10 +820,10 @@ def __call__(
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
Expand Down Expand Up @@ -990,10 +990,10 @@ def invert(
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16),
)
timesteps, num_inversion_steps = retrieve_timesteps(
self.scheduler,
Expand Down
9 changes: 5 additions & 4 deletions examples/community/pipeline_flux_with_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
"""


# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
Expand Down Expand Up @@ -755,10 +756,10 @@ def __call__(
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
Expand Down
6 changes: 3 additions & 3 deletions examples/community/rerender_a_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ def __call__(
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
frames (`List[np.ndarray]` or `torch.Tensor`): The input images to be used as the starting point for the image generation process.
control_frames (`List[np.ndarray]` or `torch.Tensor`): The ControlNet input images condition to provide guidance to the `unet` for generation.
control_frames (`List[np.ndarray]` or `torch.Tensor` or `Callable`): The ControlNet input images condition to provide guidance to the `unet` for generation or any callable object to convert frame to control_frame.
strength ('float'): SDEdit strength.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
Expand Down Expand Up @@ -789,7 +789,7 @@ def __call__(
# Currently we only support single control
if isinstance(controlnet, ControlNetModel):
control_image = self.prepare_control_image(
image=control_frames[0],
image=control_frames(frames[0]) if callable(control_frames) else control_frames[0],
width=width,
height=height,
batch_size=batch_size,
Expand Down Expand Up @@ -924,7 +924,7 @@ def __call__(
for idx in range(1, len(frames)):
image = frames[idx]
prev_image = frames[idx - 1]
control_image = control_frames[idx]
control_image = control_frames(image) if callable(control_frames) else control_frames[idx]
# 5.1 prepare frames
image = self.image_processor.preprocess(image).to(dtype=self.dtype)
prev_image = self.image_processor.preprocess(prev_image).to(dtype=self.dtype)
Expand Down
177 changes: 156 additions & 21 deletions src/diffusers/loaders/lora_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,20 @@
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
delete_adapter_layers,
deprecate,
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
is_peft_available,
is_peft_version,
is_transformers_available,
is_transformers_version,
logging,
recurse_remove_peft_layers,
scale_lora_layers,
set_adapter_layers,
set_weights_and_activate_adapters,
)
Expand All @@ -43,6 +50,8 @@
if is_transformers_available():
from transformers import PreTrainedModel

from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules

if is_peft_available():
from peft.tuners.tuners_utils import BaseTunerLayer

Expand Down Expand Up @@ -297,6 +306,152 @@ def _best_guess_weight_name(
return weight_name


def _load_lora_into_text_encoder(
state_dict,
network_alphas,
text_encoder,
prefix=None,
lora_scale=1.0,
text_encoder_name="text_encoder",
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")

peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
if not is_transformers_version(">", "4.45.2"):
# Note from sayakpaul: It's not in `transformers` stable yet.
# https://github.com/huggingface/transformers/pull/33725/
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

from peft import LoraConfig

# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
prefix = text_encoder_name if prefix is None else prefix

# Safe prefix to check with.
if any(text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
text_encoder_lora_state_dict = {
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}

if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)

# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)

for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]

if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)

if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")

if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")

lora_config = LoraConfig(**lora_config_kwargs)

# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)

is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)

# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
**peft_kwargs,
)

# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)

text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)

# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />


def _func_optionally_disable_offloading(_pipeline):
is_model_cpu_offload = False
is_sequential_cpu_offload = False

if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)

logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

return (is_model_cpu_offload, is_sequential_cpu_offload)


class LoraBaseMixin:
"""Utility class for handling LoRAs."""

Expand Down Expand Up @@ -327,27 +482,7 @@ def _optionally_disable_offloading(cls, _pipeline):
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
is_model_cpu_offload = False
is_sequential_cpu_offload = False

if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)

logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

return (is_model_cpu_offload, is_sequential_cpu_offload)
return _func_optionally_disable_offloading(_pipeline=_pipeline)

@classmethod
def _fetch_state_dict(cls, *args, **kwargs):
Expand Down
Loading

0 comments on commit 8c42542

Please sign in to comment.