From 0e14cacffc24b7926c94d5aa7a56ccc8baf1a800 Mon Sep 17 00:00:00 2001 From: hlky Date: Tue, 14 Jan 2025 04:55:06 +0000 Subject: [PATCH] Fix batch > 1 in HunyuanVideo (#10548) --- src/diffusers/models/transformers/transformer_hunyuan_video.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 846104718b9a..b5ff734eee25 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -721,7 +721,8 @@ def forward( for i in range(batch_size): attention_mask[i, : effective_sequence_length[i]] = True - attention_mask = attention_mask.unsqueeze(1) # [B, 1, N], for broadcasting across attention heads + # [B, 1, 1, N], for broadcasting across attention heads + attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # 4. Transformer blocks if torch.is_grad_enabled() and self.gradient_checkpointing: