From 15fbf706d977fcf351167475bf6d794a43b7f4eb Mon Sep 17 00:00:00 2001
From: huangzehuan
Date: Sun, 5 Jan 2025 16:21:31 +0800
Subject: [PATCH 1/3] Support pass kwargs to cogvideox custom attention processor

---
 src/diffusers/models/attention_processor.py          | 4 ++++
 .../models/transformers/cogvideox_transformer_3d.py  | 5 +++++
 2 files changed, 9 insertions(+)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 4d7ae6bef26e..60fefd97c33d 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2813,6 +2813,8 @@ def __call__(
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        *args,
+        **kwargs,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
 
@@ -2884,6 +2886,8 @@ def __call__(
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        *args,
+        **kwargs,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
 
diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
index b47d439774cc..5f87ce5bc1d8 100644
--- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py
+++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -120,8 +120,10 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
+        attention_kwargs = attention_kwargs or {}
 
         # norm & modulate
         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
@@ -133,6 +135,7 @@ def forward(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **attention_kwargs,
         )
 
         hidden_states = hidden_states + gate_msa * attn_hidden_states
@@ -497,6 +500,7 @@ def custom_forward(*inputs):
                     encoder_hidden_states,
                     emb,
                     image_rotary_emb,
+                    attention_kwargs,
                     **ckpt_kwargs,
                 )
             else:
@@ -505,6 +509,7 @@
                     hidden_states=hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
                     temb=emb,
                     image_rotary_emb=image_rotary_emb,
+                    attention_kwargs=attention_kwargs,
                 )
 
         if not self.config.use_rotary_positional_embeddings:

From 15b5fef2fb73717294da47dc789b92e82d5a1428 Mon Sep 17 00:00:00 2001
From: huangzehuan
Date: Tue, 7 Jan 2025 11:26:20 +0800
Subject: [PATCH 2/3] remove args in cogvideox attn processor

---
 src/diffusers/models/attention_processor.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 60fefd97c33d..fef1593b673b 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2813,7 +2813,6 @@ def __call__(
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
-        *args,
         **kwargs,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
@@ -2886,7 +2885,6 @@ def __call__(
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
-        *args,
         **kwargs,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)

From a44f962b8e5ff81866f4cad688691b952a67b656 Mon Sep 17 00:00:00 2001
From: huangzehuan
Date: Thu, 9 Jan 2025 12:13:41 +0800
Subject: [PATCH 3/3] remove unused kwargs

---
 src/diffusers/models/attention_processor.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index fef1593b673b..4d7ae6bef26e 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2813,7 +2813,6 @@ def __call__(
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
-        **kwargs,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
 
@@ -2885,7 +2884,6 @@ def __call__(
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
-        **kwargs,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
 
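Usage sketch. With the series above applied, the attention_kwargs dict that the model forward already receives is passed into each CogVideoXBlock and, from there, spread into the attention call via **attention_kwargs, so a custom processor can declare extra keyword arguments and receive per-call values. The GainCogVideoXAttnProcessor class, the attn_gain kwarg, and the tiny randomly initialized config below are illustrative assumptions, not diffusers APIs; entries whose names do not match a parameter of the installed processor are filtered out by Attention.forward with a warning, which is why the built-in processors no longer need **kwargs after PATCH 3/3.

import torch

from diffusers import CogVideoXTransformer3DModel
from diffusers.models.attention_processor import CogVideoXAttnProcessor2_0


class GainCogVideoXAttnProcessor(CogVideoXAttnProcessor2_0):
    # Illustrative processor: declares an extra keyword argument that reaches it
    # through the attention_kwargs plumbing added in PATCH 1/3.
    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states,
        attention_mask=None,
        image_rotary_emb=None,
        attn_gain: float = 1.0,  # hypothetical per-call knob, not a diffusers argument
    ):
        hidden_states, encoder_hidden_states = super().__call__(
            attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb
        )
        # Toy use of the extra kwarg: rescale only the visual token stream.
        return hidden_states * attn_gain, encoder_hidden_states


# Tiny randomly initialized model; the config values are arbitrary and only
# meant to keep the example cheap to run.
transformer = CogVideoXTransformer3DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=4,
    out_channels=4,
    text_embed_dim=16,
    num_layers=1,
    sample_width=8,
    sample_height=8,
    sample_frames=9,
    max_text_seq_length=8,
    use_rotary_positional_embeddings=False,
)
transformer.set_attn_processor(GainCogVideoXAttnProcessor())

sample = transformer(
    hidden_states=torch.randn(1, 3, 4, 8, 8),     # (batch, frames, channels, height, width) latents
    encoder_hidden_states=torch.randn(1, 8, 16),  # (batch, seq_len, text_embed_dim) prompt embeddings
    timestep=torch.tensor([1]),
    attention_kwargs={"attn_gain": 0.9},          # forwarded block -> attention processor
    return_dict=False,
)[0]
print(sample.shape)  # expected (1, 3, 4, 8, 8)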