From 451ca0af739f16fa93aa5028d8fa24a08ae85cdc Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 13 Jan 2025 05:23:44 +0000 Subject: [PATCH] Fix batch > 1 in HunyuanVideo --- src/diffusers/models/transformers/transformer_hunyuan_video.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 044f2048775f..4495623119e5 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -727,7 +727,8 @@ def forward( for i in range(batch_size): attention_mask[i, : effective_sequence_length[i]] = True - attention_mask = attention_mask.unsqueeze(1) # [B, 1, N], for broadcasting across attention heads + # [B, 1, 1, N], for broadcasting across attention heads + attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # 4. Transformer blocks if torch.is_grad_enabled() and self.gradient_checkpointing: