
Commit 67f8839
Merge branch 'main' into Add-AnyText
tolgacangoz authored Jan 6, 2025
2 parents b04d015 + b572635 commit 67f8839
Showing 17 changed files with 390 additions and 79 deletions.
12 changes: 6 additions & 6 deletions docs/source/en/tutorials/using_peft_for_inference.md
@@ -56,7 +56,7 @@ image

With the `adapter_name` parameter, it is really easy to use another adapter for inference! Load the [nerijs/pixel-art-xl](https://huggingface.co/nerijs/pixel-art-xl) adapter that has been fine-tuned to generate pixel art images and call it `"pixel"`.

- The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter, but you can activate the `"pixel"` adapter with the [`~PeftAdapterMixin.set_adapters`] method:
+ The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter, but you can activate the `"pixel"` adapter with the [`~loaders.peft.PeftAdapterMixin.set_adapters`] method:

```python
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
@@ -85,7 +85,7 @@ By default, if the most up-to-date versions of PEFT and Transformers are detecte

You can also merge different adapter checkpoints for inference to blend their styles together.

- Once again, use the [`~PeftAdapterMixin.set_adapters`] method to activate the `pixel` and `toy` adapters and specify the weights for how they should be merged.
+ Once again, use the [`~loaders.peft.PeftAdapterMixin.set_adapters`] method to activate the `pixel` and `toy` adapters and specify the weights for how they should be merged.

```python
pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
@@ -114,7 +114,7 @@ Impressive! As you can see, the model generated an image that mixed the characte
> [!TIP]
> Through its PEFT integration, Diffusers also offers more efficient merging methods which you can learn about in the [Merge LoRAs](../using-diffusers/merge_loras) guide!
- To return to only using one adapter, use the [`~PeftAdapterMixin.set_adapters`] method to activate the `"toy"` adapter:
+ To return to only using one adapter, use the [`~loaders.peft.PeftAdapterMixin.set_adapters`] method to activate the `"toy"` adapter:

```python
pipe.set_adapters("toy")
@@ -127,7 +127,7 @@ image = pipe(
image
```

- Or to disable all adapters entirely, use the [`~PeftAdapterMixin.disable_lora`] method to return the base model.
+ Or to disable all adapters entirely, use the [`~loaders.peft.PeftAdapterMixin.disable_lora`] method to return to the base model.

```python
pipe.disable_lora()
@@ -141,7 +141,7 @@ image

### Customize adapter strength

- For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~PeftAdapterMixin.set_adapters`].
+ For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~loaders.peft.PeftAdapterMixin.set_adapters`].

For example, here's how you can turn on the adapter for the `down` parts, but turn it off for the `mid` and `up` parts:
```python
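# Illustrative sketch of the elided example (names assumed): scale the LoRA
# per UNet stage, full strength on the down blocks, off for mid and up.
scales = {"unet": {"down": 1.0, "mid": 0.0, "up": 0.0}}
pipe.set_adapters("pixel", scales)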
@@ -214,7 +214,7 @@ list_adapters_component_wise
{"text_encoder": ["toy", "pixel"], "unet": ["toy", "pixel"], "text_encoder_2": ["toy", "pixel"]}
```

- The [`~PeftAdapterMixin.delete_adapters`] function completely removes an adapter and their LoRA layers from a model.
+ The [`~loaders.peft.PeftAdapterMixin.delete_adapters`] function completely removes an adapter and its LoRA layers from a model.

```py
pipe.delete_adapters("toy")
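The doc changes in this file only retarget the API links; the workflow itself is unchanged. For reference, here is a compact sketch of the full adapter lifecycle the tutorial walks through (checkpoint and adapter names taken from the tutorial; treat it as illustrative, not part of this diff):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Load two LoRAs under explicit adapter names.
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")

pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])  # blend the two styles
pipe.set_adapters("toy")     # back to a single active adapter
pipe.disable_lora()          # run the plain base model again
pipe.delete_adapters("toy")  # drop the adapter and its LoRA layers entirely
```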
18 changes: 14 additions & 4 deletions examples/community/rerender_a_video.py
@@ -30,10 +30,17 @@
from diffusers.pipelines.controlnet.pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
- from diffusers.utils import BaseOutput, deprecate, logging
+ from diffusers.utils import BaseOutput, deprecate, is_torch_xla_available, logging
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor


+ if is_torch_xla_available():
+     import torch_xla.core.xla_model as xm
+
+     XLA_AVAILABLE = True
+ else:
+     XLA_AVAILABLE = False

logger = logging.get_logger(__name__) # pylint: disable=invalid-name


@@ -775,7 +782,7 @@ def __call__(
self.attn_state.reset()

# 4.1 prepare frames
- image = self.image_processor.preprocess(frames[0]).to(dtype=torch.float32)
+ image = self.image_processor.preprocess(frames[0]).to(dtype=self.dtype)
first_image = image[0] # C, H, W

# 4.2 Prepare controlnet_conditioning_image
@@ -919,8 +926,8 @@ def __call__(
prev_image = frames[idx - 1]
control_image = control_frames[idx]
# 5.1 prepare frames
- image = self.image_processor.preprocess(image).to(dtype=torch.float32)
- prev_image = self.image_processor.preprocess(prev_image).to(dtype=torch.float32)
+ image = self.image_processor.preprocess(image).to(dtype=self.dtype)
+ prev_image = self.image_processor.preprocess(prev_image).to(dtype=self.dtype)

warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask(
self.flow_model, first_image, image[0], first_result, False, self.device
@@ -1100,6 +1107,9 @@ def denoising_loop(latents, mask=None, xtrg=None, noise_rescale=None):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

+ if XLA_AVAILABLE:
+     xm.mark_step()

return latents

if mask_start_t <= mask_end_t:
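For context, the edits above follow the torch_xla convention used across Diffusers pipelines: guard the import, and flush XLA's lazily built graph once per denoising step. A minimal self-contained sketch of the pattern (the loop body is a stand-in, not this pipeline's actual update):

```python
import torch
from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


def denoising_loop(latents: torch.Tensor, timesteps: range) -> torch.Tensor:
    for t in timesteps:
        latents = latents * 1.0  # stand-in for the scheduler/UNet update
        if XLA_AVAILABLE:
            xm.mark_step()  # compile and execute the graph accumulated this step
    return latents
```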
19 changes: 18 additions & 1 deletion examples/flux-control/train_control_lora_flux.py
@@ -923,11 +923,28 @@ def load_model_hook(models, input_dir):
transformer_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")

else:
transformer_ = FluxTransformer2DModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="transformer"
).to(accelerator.device, weight_dtype)

+ # Handle input dimension doubling before adding adapter
+ with torch.no_grad():
+     initial_input_channels = transformer_.config.in_channels
+     new_linear = torch.nn.Linear(
+         transformer_.x_embedder.in_features * 2,
+         transformer_.x_embedder.out_features,
+         bias=transformer_.x_embedder.bias is not None,
+         dtype=transformer_.dtype,
+         device=transformer_.device,
+     )
+     new_linear.weight.zero_()
+     new_linear.weight[:, :initial_input_channels].copy_(transformer_.x_embedder.weight)
+     if transformer_.x_embedder.bias is not None:
+         new_linear.bias.copy_(transformer_.x_embedder.bias)
+     transformer_.x_embedder = new_linear
+     transformer_.register_to_config(in_channels=initial_input_channels * 2)

transformer_.add_adapter(transformer_lora_config)

lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
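The restored block above is the non-trivial part of this hook: before re-adding the LoRA adapter, `x_embedder` must be widened to accept the control latents concatenated with the image latents. The technique in isolation, as a hedged sketch (helper name and the 2x factor are assumptions for illustration):

```python
import torch


def widen_linear_inputs(linear: torch.nn.Linear, factor: int = 2) -> torch.nn.Linear:
    """Return a Linear accepting `factor`x more inputs; the extra columns start at zero."""
    widened = torch.nn.Linear(
        linear.in_features * factor,
        linear.out_features,
        bias=linear.bias is not None,
        dtype=linear.weight.dtype,
        device=linear.weight.device,
    )
    with torch.no_grad():
        widened.weight.zero_()
        # Copy the pretrained weights into the first block of columns.
        widened.weight[:, : linear.in_features].copy_(linear.weight)
        if linear.bias is not None:
            widened.bias.copy_(linear.bias)
    return widened
```

Because the new columns are zero, the widened layer computes exactly the original output whenever the extra input channels carry no signal, so training starts from the pretrained behavior rather than from noise.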
2 changes: 1 addition & 1 deletion examples/text_to_image/train_text_to_image_sdxl.py
@@ -919,7 +919,7 @@ def preprocess_train(examples):
# fingerprint used by the cache for the other processes to load the result
# details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
new_fingerprint = Hasher.hash(args)
- new_fingerprint_for_vae = Hasher.hash(vae_path)
+ new_fingerprint_for_vae = Hasher.hash((vae_path, args))
train_dataset_with_embeddings = train_dataset.map(
compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
)
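This one-line change fixes cache invalidation for the precomputed VAE latents: the fingerprint previously depended only on `vae_path`, so rerunning with different arguments (for example a new resolution) could silently reuse stale cached latents. Hashing the `(vae_path, args)` pair keys the cache on both. A minimal illustration (the values below are assumed, not from the script):

```python
from datasets.fingerprint import Hasher

vae_path = "madebyollin/sdxl-vae-fp16-fix"  # assumed example checkpoint
args_a = {"resolution": 512}                # stand-ins for the parsed CLI args
args_b = {"resolution": 1024}

# Different args now produce different cache keys for the same VAE.
assert Hasher.hash((vae_path, args_a)) != Hasher.hash((vae_path, args_b))
```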
4 changes: 3 additions & 1 deletion src/diffusers/loaders/lora_pipeline.py
@@ -2466,7 +2466,9 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
continue

base_param_name = (
-     f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
+     f"{k.replace(prefix, '')}.base_layer.weight"
+     if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
+     else f"{k.replace(prefix, '')}.weight"
)
base_weight_param = transformer_state_dict[base_param_name]
lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
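The added condition guards the state-dict lookup: when PEFT is loaded, only modules that already have a LoRA injected are renamed to `*.base_layer.weight`, so the code now falls back to the plain `.weight` name when the `base_layer` key is absent instead of raising a `KeyError`. Schematically (a sketch with assumed variable names, not the actual method):

```python
def resolve_base_param_name(k: str, prefix: str, is_peft_loaded: bool, transformer_state_dict: dict) -> str:
    # Pick whichever parameter name actually exists in the transformer state dict.
    name = k.replace(prefix, "")
    candidate = f"{name}.base_layer.weight"
    if is_peft_loaded and candidate in transformer_state_dict:
        return candidate  # layer is PEFT-wrapped; base weights live under base_layer
    return f"{name}.weight"  # layer has no injected LoRA; weight keeps the plain name
```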
3 changes: 2 additions & 1 deletion src/diffusers/models/embeddings.py
@@ -1248,7 +1248,8 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
- freqs_dtype = torch.float32 if is_mps else torch.float64
+ is_npu = ids.device.type == "npu"
+ freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
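The `freqs_dtype` change extends the existing MPS exception to Ascend NPUs: neither backend implements float64 kernels, so the rotary-embedding frequencies are computed in float32 there. The selection logic in isolation (a sketch, not a Diffusers API):

```python
import torch


def rope_freqs_dtype(device: torch.device) -> torch.dtype:
    # float64 keeps long-sequence rotary frequencies precise, but the MPS and
    # NPU backends only support float32, so fall back on those devices.
    return torch.float32 if device.type in ("mps", "npu") else torch.float64
```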
2 changes: 1 addition & 1 deletion src/diffusers/models/transformers/sana_transformer.py
@@ -250,14 +250,14 @@ def __init__(
inner_dim = num_attention_heads * attention_head_dim

# 1. Patch Embedding
- interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)
self.patch_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=interpolation_scale,
+ pos_embed_type="sincos" if interpolation_scale is not None else None,
)

# 2. Additional condition embeddings
105 changes: 63 additions & 42 deletions src/diffusers/models/transformers/transformer_flux.py
@@ -85,11 +85,11 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):

def forward(
self,
- hidden_states: torch.FloatTensor,
- temb: torch.FloatTensor,
- image_rotary_emb=None,
- joint_attention_kwargs=None,
- ):
+ hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> torch.Tensor:
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
@@ -117,15 +117,22 @@ class FluxTransformerBlock(nn.Module):
Reference: https://arxiv.org/abs/2403.03206
- Parameters:
-     dim (`int`): The number of channels in the input and output.
-     num_attention_heads (`int`): The number of heads to use for multi-head attention.
-     attention_head_dim (`int`): The number of channels in each head.
-     context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
-         processing of `context` conditions.
+ Args:
+     dim (`int`):
+         The embedding dimension of the block.
+     num_attention_heads (`int`):
+         The number of attention heads to use.
+     attention_head_dim (`int`):
+         The number of dimensions to use for each attention head.
+     qk_norm (`str`, defaults to `"rms_norm"`):
+         The normalization to use for the query and key tensors.
+     eps (`float`, defaults to `1e-6`):
+         The epsilon value to use for the normalization.
"""

- def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
+ def __init__(
+     self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+ ):
super().__init__()

self.norm1 = AdaLayerNormZero(dim)
@@ -164,12 +171,12 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_no

def forward(
self,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor,
- temb: torch.FloatTensor,
- image_rotary_emb=None,
- joint_attention_kwargs=None,
- ):
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
@@ -227,16 +234,30 @@ class FluxTransformer2DModel(
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
- Parameters:
-     patch_size (`int`): Patch size to turn the input data into small patches.
-     in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
-     num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
-     num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
-     attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-     num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
-     joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
-     pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
-     guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
+ Args:
+     patch_size (`int`, defaults to `1`):
+         Patch size to turn the input data into small patches.
+     in_channels (`int`, defaults to `64`):
+         The number of channels in the input.
+     out_channels (`int`, *optional*, defaults to `None`):
+         The number of channels in the output. If not specified, it defaults to `in_channels`.
+     num_layers (`int`, defaults to `19`):
+         The number of layers of dual stream DiT blocks to use.
+     num_single_layers (`int`, defaults to `38`):
+         The number of layers of single stream DiT blocks to use.
+     attention_head_dim (`int`, defaults to `128`):
+         The number of dimensions to use for each attention head.
+     num_attention_heads (`int`, defaults to `24`):
+         The number of attention heads to use.
+     joint_attention_dim (`int`, defaults to `4096`):
+         The number of dimensions to use for the joint attention (embedding/channel dimension of
+         `encoder_hidden_states`).
+     pooled_projection_dim (`int`, defaults to `768`):
+         The number of dimensions to use for the pooled projection.
+     guidance_embeds (`bool`, defaults to `False`):
+         Whether to use guidance embeddings for guidance-distilled variant of the model.
+     axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
+         The dimensions to use for the rotary positional embeddings.
"""

_supports_gradient_checkpointing = True
@@ -259,39 +280,39 @@ def __init__(
):
super().__init__()
self.out_channels = out_channels or in_channels
- self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+ self.inner_dim = num_attention_heads * attention_head_dim

self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)

text_time_guidance_cls = (
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
)
self.time_text_embed = text_time_guidance_cls(
- embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)

- self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
- self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim)

self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
- num_attention_heads=self.config.num_attention_heads,
- attention_head_dim=self.config.attention_head_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
)
- for i in range(self.config.num_layers)
+ for _ in range(num_layers)
]
)

self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.inner_dim,
- num_attention_heads=self.config.num_attention_heads,
- attention_head_dim=self.config.attention_head_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
)
- for i in range(self.config.num_single_layers)
+ for _ in range(num_single_layers)
]
)

@@ -418,16 +439,16 @@ def forward(
controlnet_single_block_samples=None,
return_dict: bool = True,
controlnet_blocks_repeat: bool = False,
- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
Args:
- hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
Input `hidden_states`.
- encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
- pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
12 changes: 9 additions & 3 deletions src/diffusers/pipelines/auto_pipeline.py
@@ -528,7 +528,9 @@ def from_pipe(cls, pipeline, **kwargs):
if k not in text_2_image_kwargs
}

- missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(text_2_image_kwargs.keys())
+ missing_modules = (
+     set(expected_modules) - set(text_2_image_cls._optional_components) - set(text_2_image_kwargs.keys())
+ )

if len(missing_modules) > 0:
raise ValueError(
@@ -838,7 +840,9 @@ def from_pipe(cls, pipeline, **kwargs):
if k not in image_2_image_kwargs
}

- missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(image_2_image_kwargs.keys())
+ missing_modules = (
+     set(expected_modules) - set(image_2_image_cls._optional_components) - set(image_2_image_kwargs.keys())
+ )

if len(missing_modules) > 0:
raise ValueError(
@@ -1141,7 +1145,9 @@ def from_pipe(cls, pipeline, **kwargs):
if k not in inpainting_kwargs
}

- missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(inpainting_kwargs.keys())
+ missing_modules = (
+     set(expected_modules) - set(inpainting_cls._optional_components) - set(inpainting_kwargs.keys())
+ )

if len(missing_modules) > 0:
raise ValueError(
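All three hunks make the same fix: the missing-module check now subtracts the *target* class's `_optional_components` instead of the source pipeline's, since the two pipeline classes may declare different optional sets. A toy illustration of the set arithmetic (component names assumed):

```python
expected_modules = {"unet", "vae", "text_encoder", "scheduler", "safety_checker"}
target_optional = {"safety_checker"}            # the target class's _optional_components
provided_kwargs = {"unet", "vae", "scheduler"}  # assembled from the source pipeline

missing_modules = expected_modules - target_optional - provided_kwargs
if missing_modules:  # here: {"text_encoder"}, which must be supplied explicitly
    raise ValueError(f"Pipeline is missing required modules: {sorted(missing_modules)}")
```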
(The remaining 8 of the 17 changed files are not rendered in this view.)
