Fix the bug that joint_attention_kwargs is not passed to FLUX's transformer attention processors #9517

Merged (9 commits) on Oct 8, 2024
8 changes (7 additions & 1 deletion) in src/diffusers/models/transformers/transformer_flux.py
@@ -83,11 +83,12 @@ def forward(
hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
image_rotary_emb=None,
joint_attention_kwargs=None,
):
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))

joint_attention_kwargs = joint_attention_kwargs if joint_attention_kwargs is not None else {}
Collaborator
should we pass this to attn too?

Contributor Author
Yes! I think it will be useful for other experiments too!

attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
@@ -161,6 +162,7 @@ def forward(
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
image_rotary_emb=None,
joint_attention_kwargs={},
Collaborator
Can you explain what additional argument you need to pass down to the FLUX attention processor?

Contributor Author

Thank you for the review!

In our work, I am trying to integrate box and mask inputs into the FLUX model to implement layout control (similar to what many prior works have done on SD 1.4). This requires modifying the attention processor. I believe the FLUX architecture, and transformers like it, can also be used to develop better layout-control algorithms, so these modifications should help future training-free experiments on FLUX.
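
For illustration, here is a minimal sketch of the kind of custom processor this change enables. The `LayoutAttnProcessor` class and its `layout_mask` keyword are hypothetical and not part of this PR; the sketch only assumes that whatever is placed in `joint_attention_kwargs` gets forwarded to the processor, which is what this PR adds. The actual attention computation is delegated to the stock `FluxAttnProcessor2_0`.

```python
from diffusers.models.attention_processor import FluxAttnProcessor2_0


class LayoutAttnProcessor:
    """Hypothetical processor that consumes an extra `layout_mask` keyword
    delivered through `joint_attention_kwargs` (illustration only)."""

    def __init__(self):
        self.base = FluxAttnProcessor2_0()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        image_rotary_emb=None,
        layout_mask=None,  # hypothetical extra kwarg routed via joint_attention_kwargs
    ):
        # Placeholder manipulation: only applied when the mask matches the token
        # sequence length (single-stream blocks see text and image tokens
        # concatenated, so a real implementation would handle both layouts).
        if layout_mask is not None and layout_mask.shape[1] == hidden_states.shape[1]:
            hidden_states = hidden_states * (1.0 + 0.1 * layout_mask)

        # Delegate the actual attention computation to the stock processor.
        return self.base(
            attn,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )
```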

Collaborator
Suggested change
joint_attention_kwargs={},
joint_attention_kwargs=None,
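
The suggestion follows the usual Python caution about mutable default arguments: a `{}` default is created once when the function is defined and shared by every call, whereas a `None` default with an in-body fallback (the pattern already used earlier in this diff) gives each call a fresh dict. A standalone illustration, not code from this PR:

```python
def bad(kwargs={}):
    # The same dict object is reused by every call that omits the argument.
    kwargs.setdefault("scale", 1.0)
    return kwargs


def good(kwargs=None):
    # A fresh dict is created on each call.
    kwargs = kwargs if kwargs is not None else {}
    kwargs.setdefault("scale", 1.0)
    return kwargs


assert bad() is bad()        # shared state across calls
assert good() is not good()  # independent per call
```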

):
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

@@ -173,6 +175,8 @@ def forward(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,

)

# Process attention outputs for the `hidden_states`.
@@ -497,6 +501,7 @@ def custom_forward(*inputs):
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)

# controlnet residual
@@ -533,6 +538,7 @@ def custom_forward(*inputs):
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)

# controlnet residual
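
End to end, the fix means keyword arguments supplied via the pipeline's `joint_attention_kwargs` now reach the transformer blocks' attention processors. A rough usage sketch under the assumptions above (the `LayoutAttnProcessor` and `layout_mask` names are the hypothetical examples from earlier, and the mask shape is illustrative):

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Install the custom processor on every attention module of the transformer.
pipe.transformer.set_attn_processor(LayoutAttnProcessor())

# Hypothetical layout mask: 4096 image tokens at the default 1024x1024 resolution.
layout_mask = torch.ones(1, 4096, 1, device="cuda", dtype=torch.bfloat16)

image = pipe(
    "a cat sitting inside a red box",
    joint_attention_kwargs={"layout_mask": layout_mask},
).images[0]
```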