From 13ea83f0faecf6ef475d58c4137e563c1014fcc5 Mon Sep 17 00:00:00 2001
From: hlky
Date: Tue, 7 Jan 2025 23:13:55 +0000
Subject: [PATCH] Fix HunyuanVideo produces NaN on PyTorch<2.5 (#10482)

Co-authored-by: Sayak Paul
---
 .../models/transformers/transformer_hunyuan_video.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py
index 6cb97af93652..846104718b9a 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -713,15 +713,15 @@ def forward(
         condition_sequence_length = encoder_hidden_states.shape[1]
         sequence_length = latent_sequence_length + condition_sequence_length
         attention_mask = torch.zeros(
-            batch_size, sequence_length, sequence_length, device=hidden_states.device, dtype=torch.bool
-        )  # [B, N, N]
+            batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
+        )  # [B, N]
         effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int)  # [B,]
         effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
 
         for i in range(batch_size):
-            attention_mask[i, : effective_sequence_length[i], : effective_sequence_length[i]] = True
-        attention_mask = attention_mask.unsqueeze(1)  # [B, 1, N, N], for broadcasting across attention heads
+            attention_mask[i, : effective_sequence_length[i]] = True
+        attention_mask = attention_mask.unsqueeze(1)  # [B, 1, N], for broadcasting across attention heads
 
         # 4. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
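
Why this fixes the NaN: with the old square [B, N, N] mask, every padded query row is
entirely False, and on PyTorch < 2.5 F.scaled_dot_product_attention softmaxes such a
fully masked row of -inf scores into NaN. The new [B, N] key-padding mask is broadcast
across query positions, so every row still attends to the valid keys and no row is
fully masked. Below is a minimal sketch of the two behaviors, assuming the PyTorch < 2.5
behavior described in the patch title; all shapes and values are illustrative, not taken
from the model:

    import torch
    import torch.nn.functional as F

    batch, heads, seq, head_dim = 1, 2, 4, 8
    query = torch.randn(batch, heads, seq, head_dim)
    key = torch.randn(batch, heads, seq, head_dim)
    value = torch.randn(batch, heads, seq, head_dim)

    valid = 3  # pretend the last token is padding

    # Old behavior: square [B, N, N] mask. Query row 3 is all False, i.e. a
    # fully masked row; on PyTorch < 2.5 its softmax is taken over nothing
    # but -inf scores and the corresponding output row becomes NaN.
    square_mask = torch.zeros(batch, seq, seq, dtype=torch.bool)
    square_mask[:, :valid, :valid] = True
    out_old = F.scaled_dot_product_attention(
        query, key, value, attn_mask=square_mask.unsqueeze(1)  # [B, 1, N, N]
    )
    print(out_old.isnan().any())  # tensor(True) on PyTorch < 2.5

    # New behavior: [B, N] key-padding mask, broadcast over heads and query
    # positions (extra singleton dims added here so the broadcast is
    # unambiguous). Every query row still sees the `valid` keys, so no row
    # is fully masked and the output stays finite.
    flat_mask = torch.zeros(batch, seq, dtype=torch.bool)
    flat_mask[:, :valid] = True
    out_new = F.scaled_dot_product_attention(
        query, key, value, attn_mask=flat_mask[:, None, None, :]  # [B, 1, 1, N]
    )
    print(out_new.isnan().any())  # tensor(False)

The trade-off in the patch is that padded query rows now attend to the valid keys and
produce garbage values instead of NaN, which is harmless because those positions are
padding and their outputs are never used.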