[LoRA] Add LoRA support to AuraFlow #10216

Open · wants to merge 39 commits into base: main

Commits (39)
e50a108  Add AuraFlowLoraLoaderMixin (Warlord-K, Jul 30, 2024)
658d058  Add comments, remove qkv fusion (Warlord-K, Jul 30, 2024)
4208d09  Add Tests (Warlord-K, Jul 30, 2024)
98b19f6  Add AuraFlowLoraLoaderMixin to documentation (Warlord-K, Jul 30, 2024)
71f8bac  Add Suggested changes (Warlord-K, Aug 11, 2024)
0eee03e  Change attention_kwargs->joint_attention_kwargs (Warlord-K, Aug 12, 2024)
4e4f780  Rebasing derp. (hameerabbasi, Dec 13, 2024)
c07d1f5  fix (hlky, Dec 13, 2024)
1b7f99f  fix (hlky, Dec 13, 2024)
875a3e0  Quality fixes. (hameerabbasi, Dec 13, 2024)
a242d7a  make style (hlky, Dec 13, 2024)
a73df6b  `make fix-copies` (hameerabbasi, Dec 13, 2024)
894eac0  `ruff check --fix` (hameerabbasi, Dec 13, 2024)
2b36416  Attept 1 to fix tests. (hameerabbasi, Dec 15, 2024)
6b762b8  Attept 2 to fix tests. (hameerabbasi, Dec 15, 2024)
bc2a466  Attept 3 to fix tests. (hameerabbasi, Dec 15, 2024)
1c79095  Address review comments. (hameerabbasi, Dec 19, 2024)
9454e84  Rebasing derp. (hameerabbasi, Dec 19, 2024)
5700e52  Merge branch 'main' into auraflow-lora (hameerabbasi, Jan 3, 2025)
6da81f8  Merge branch 'main' into auraflow-lora (sayakpaul, Jan 6, 2025)
28a4918  Get more tests passing by copying from Flux. Address review comments. (hameerabbasi, Jan 7, 2025)
d6028cd  `joint_attention_kwargs`->`attention_kwargs` (hameerabbasi, Jan 7, 2025)
6e899a3  Merge branch 'main' into auraflow-lora (hameerabbasi, Jan 7, 2025)
2d02c2c  Add `lora_scale` property for te LoRAs. (hameerabbasi, Jan 7, 2025)
2b934b4  Make test better. (hameerabbasi, Jan 7, 2025)
532013f  Remove useless property. (hameerabbasi, Jan 7, 2025)
0ea9ecd  Merge branch 'main' into auraflow-lora (hlky, Jan 8, 2025)
e06d8eb  Skip TE-only tests for AuraFlow. (hameerabbasi, Jan 8, 2025)
2b35909  Support LoRA for non-CLIP TEs. (hameerabbasi, Jan 10, 2025)
1ec07a1  Merge remote-tracking branch 'upstream/main' into auraflow-lora (hameerabbasi, Jan 10, 2025)
077a452  Merge branch 'main' into auraflow-lora (hlky, Jan 10, 2025)
3095644  Merge branch 'main' into auraflow-lora (hameerabbasi, Jan 13, 2025)
df28362  Merge remote-tracking branch 'upstream/main' into auraflow-lora (hameerabbasi, Jan 19, 2025)
7e63330  Restore LoRA tests. (hameerabbasi, Jan 19, 2025)
5620384  Undo adding LoRA support for non-CLIP TEs. (hameerabbasi, Jan 19, 2025)
cd691d3  Undo support for TE in AuraFlow LoRA. (hameerabbasi, Jan 19, 2025)
0fa5cd5  `make fix-copies` (hameerabbasi, Jan 19, 2025)
83e0825  Sync with upstream changes. (hameerabbasi, Jan 19, 2025)
12fbd11  Remove unneeded stuff. (hameerabbasi, Jan 19, 2025)
4 changes: 4 additions & 0 deletions docs/source/en/api/loaders/lora.md
@@ -20,6 +20,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`FluxLoraLoaderMixin`] provides similar functions for [Flux](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux).
- [`CogVideoXLoraLoaderMixin`] provides similar functions for [CogVideoX](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox).
- [`Mochi1LoraLoaderMixin`] provides similar functions for [Mochi](https://huggingface.co/docs/diffusers/main/en/api/pipelines/mochi).
- [`AuraFlowLoraLoaderMixin`] provides similar functions for [AuraFlow](https://huggingface.co/fal/AuraFlow).
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.

@@ -52,6 +53,9 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
## Mochi1LoraLoaderMixin

[[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin
## AuraFlowLoraLoaderMixin

[[autodoc]] loaders.lora_pipeline.AuraFlowLoraLoaderMixin

## AmusedLoraLoaderMixin

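For context (not part of the diff), a minimal usage sketch of the loader documented above; the LoRA repository id is hypothetical:

```python
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

# `load_lora_weights` is provided by the new AuraFlowLoraLoaderMixin.
pipe.load_lora_weights("your-username/auraflow-lora")  # hypothetical LoRA checkpoint

image = pipe("a portrait of a fox in a spacesuit").images[0]
image.save("auraflow_lora.png")
```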
2 changes: 2 additions & 0 deletions src/diffusers/loaders/__init__.py
@@ -65,6 +65,7 @@ def text_encoder_attn_modules(text_encoder):
"AmusedLoraLoaderMixin",
"StableDiffusionLoraLoaderMixin",
"SD3LoraLoaderMixin",
"AuraFlowLoraLoaderMixin",
"StableDiffusionXLLoraLoaderMixin",
"LTXVideoLoraLoaderMixin",
"LoraLoaderMixin",
@@ -100,6 +101,7 @@ def text_encoder_attn_modules(text_encoder):
)
from .lora_pipeline import (
AmusedLoraLoaderMixin,
AuraFlowLoraLoaderMixin,
CogVideoXLoraLoaderMixin,
FluxLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin,
328 changes: 328 additions & 0 deletions src/diffusers/loaders/lora_pipeline.py

Large diffs are not rendered by default.

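The 328 added lines of lora_pipeline.py, i.e. the AuraFlowLoraLoaderMixin itself, are not rendered here. Based on the other per-pipeline mixins and the "copying from Flux" commit above, it presumably exposes the usual load_lora_weights / save_lora_weights / fuse_lora / unfuse_lora surface; a hedged sketch of fusing an adapter for inference:

```python
# Sketch only; assumes the mixin mirrors FluxLoraLoaderMixin's public API and
# that "your-username/auraflow-lora" is a hypothetical LoRA checkpoint.
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights("your-username/auraflow-lora", adapter_name="style")

pipe.fuse_lora(lora_scale=0.8)  # bake the adapter into the transformer weights
image = pipe("a studio photo of a ceramic teapot").images[0]
pipe.unfuse_lora()              # restore the original weights afterwards
```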
1 change: 1 addition & 0 deletions src/diffusers/loaders/peft.py
@@ -52,6 +52,7 @@
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
"SanaTransformer2DModel": lambda model_cls, weights: weights,
"AuraFlowTransformer2DModel": lambda model_cls, weights: weights,
}


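The entry added above registers AuraFlowTransformer2DModel in what appears to be the per-model adapter-scale mapping with an identity lambda, so adapter weights set at the pipeline level are passed through unchanged. A hedged sketch of what this enables; adapter names and repositories are hypothetical:

```python
# Sketch only; combines two hypothetical AuraFlow LoRAs with different strengths.
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights("your-username/auraflow-style-lora", adapter_name="style")
pipe.load_lora_weights("your-username/auraflow-detail-lora", adapter_name="detail")

# These weights are expected to flow through the mapping shown in peft.py above.
pipe.set_adapters(["style", "detail"], adapter_weights=[1.0, 0.5])
image = pipe("a medieval castle at dawn").images[0]
```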
4 changes: 2 additions & 2 deletions src/diffusers/models/lora.py
@@ -38,7 +38,7 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


def text_encoder_attn_modules(text_encoder):
def text_encoder_attn_modules(text_encoder: nn.Module):
attn_modules = []

if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
@@ -52,7 +52,7 @@ def text_encoder_attn_modules(text_encoder):
return attn_modules


def text_encoder_mlp_modules(text_encoder):
def text_encoder_mlp_modules(text_encoder: nn.Module):
mlp_modules = []

if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
27 changes: 22 additions & 5 deletions src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -13,15 +13,15 @@
# limitations under the License.


from typing import Any, Dict, Union
from typing import Any, Dict, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import is_torch_version, logging
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_processor import (
Attention,
@@ -254,7 +254,7 @@ def forward(
return encoder_hidden_states, hidden_states


class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
r"""
A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).

@@ -452,6 +452,7 @@ def forward(
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
timestep: torch.LongTensor = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
height, width = hidden_states.shape[-2:]
@@ -464,7 +465,19 @@
encoder_hidden_states = torch.cat(
[self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1
)

if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# MMDiT blocks.
for index_block, block in enumerate(self.joint_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
@@ -539,6 +552,10 @@ def custom_forward(*inputs):
shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
)

if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)

if not return_dict:
return (output,)

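Since AuraFlowTransformer2DModel now inherits PeftAdapterMixin and its forward() accepts attention_kwargs, an adapter can also be attached to the transformer directly and its strength chosen per call. A minimal sketch, assuming peft is installed and that AuraFlow's attention modules use the to_q/to_k/to_v/to_out.0 naming (an assumption, not taken from this PR):

```python
# Sketch only: requires `peft`; the rank and target module names are illustrative.
from peft import LoraConfig
from diffusers import AuraFlowPipeline

# Loaded in fp32 so the freshly added fp32 LoRA layers match the base weights.
pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow").to("cuda")

lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)
pipe.transformer.add_adapter(lora_config)  # enabled by the new PeftAdapterMixin base class

# "scale" is popped from `attention_kwargs` in forward() and applied via
# scale_lora_layers()/unscale_lora_layers(), as in the diff above.
image = pipe("a watercolor fox", attention_kwargs={"scale": 0.7}).images[0]
```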
47 changes: 44 additions & 3 deletions src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
@@ -12,16 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from transformers import T5Tokenizer, UMT5EncoderModel

from ...image_processor import VaeImageProcessor
from ...loaders import AuraFlowLoraLoaderMixin
from ...models import AuraFlowTransformer2DModel, AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput

@@ -111,7 +119,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps


class AuraFlowPipeline(DiffusionPipeline):
class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
r"""
Args:
tokenizer (`T5TokenizerFast`):
@@ -219,6 +227,7 @@ def encode_prompt(
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 256,
lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -245,10 +254,21 @@
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
if device is None:
device = self._execution_device

# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, AuraFlowLoraLoaderMixin):
self._lora_scale = lora_scale

# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)

if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -406,6 +426,7 @@ def __call__(
max_sequence_length: int = 256,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[ImagePipelineOutput, Tuple]:
r"""
Function invoked when calling the pipeline for generation.
@@ -461,6 +482,10 @@
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).

Examples:

@@ -483,6 +508,8 @@
negative_prompt_attention_mask,
)

self._attention_kwargs = attention_kwargs

# 2. Determine batch size.
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -492,6 +519,9 @@
batch_size = prompt_embeds.shape[0]

device = self._execution_device
lora_scale = (
self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
)

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -515,6 +545,7 @@
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
@@ -555,6 +586,7 @@
encoder_hidden_states=prompt_embeds,
timestep=timestep,
return_dict=False,
attention_kwargs=self.attention_kwargs,
)[0]

# perform guidance
@@ -586,7 +618,16 @@
# Offload all models
self.maybe_free_model_hooks()

if self.text_encoder is not None:
if isinstance(self, AuraFlowLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)

if not return_dict:
return (image,)

return ImagePipelineOutput(images=image)

@property
def attention_kwargs(self):
return self._attention_kwargs
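Taken together, the pipeline changes make the LoRA strength a per-call choice: the "scale" entry of attention_kwargs is read in __call__, handed to encode_prompt as lora_scale, and forwarded to the transformer. A usage sketch with a hypothetical LoRA repository:

```python
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights("your-username/auraflow-lora")  # hypothetical LoRA checkpoint

# Same adapter, two different strengths chosen at inference time.
strong = pipe("an isometric city at dusk", attention_kwargs={"scale": 1.0}).images[0]
subtle = pipe("an isometric city at dusk", attention_kwargs={"scale": 0.3}).images[0]
```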
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
@@ -117,6 +117,7 @@
)
from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
from .state_dict_utils import (
StateDictType,
convert_all_state_dict_to_peft,
convert_state_dict_to_diffusers,
convert_state_dict_to_kohya,