Fix HunyuanVideo produces NaN on PyTorch<2.5 (#10482)
Co-authored-by: Sayak Paul <[email protected]>
2 people authored and DN6 committed Jan 15, 2025
1 parent 2b432ac commit 13ea83f
Showing 1 changed file with 4 additions and 4 deletions.
@@ -713,15 +713,15 @@ def forward(
         condition_sequence_length = encoder_hidden_states.shape[1]
         sequence_length = latent_sequence_length + condition_sequence_length
         attention_mask = torch.zeros(
-            batch_size, sequence_length, sequence_length, device=hidden_states.device, dtype=torch.bool
-        )  # [B, N, N]
+            batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
+        )  # [B, N]
 
         effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int)  # [B,]
         effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
 
         for i in range(batch_size):
-            attention_mask[i, : effective_sequence_length[i], : effective_sequence_length[i]] = True
-        attention_mask = attention_mask.unsqueeze(1)  # [B, 1, N, N], for broadcasting across attention heads
+            attention_mask[i, : effective_sequence_length[i]] = True
+        attention_mask = attention_mask.unsqueeze(1)  # [B, 1, N], for broadcasting across attention heads
 
         # 4. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
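
Why the four-line change fixes the NaN: with the old square mask, any sample whose text condition is shorter than the padded length ends up with query rows that are entirely False, so every pre-softmax score in those rows is masked to -inf and scaled_dot_product_attention returns NaN for them on PyTorch releases before 2.5 (per the commit title). The new 1-D key mask marks the same set of valid keys for every query row, and because effective_sequence_length is always at least latent_sequence_length, no row can be fully masked. A minimal sketch of the two behaviours (toy shapes; the [:, None, None, :] expansion here is illustrative, not the model's exact broadcasting code):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_heads, seq_len, head_dim = 1, 2, 4, 8
valid_len = 2  # pretend the last seq_len - valid_len tokens are padding

q = torch.randn(batch_size, num_heads, seq_len, head_dim)
k = torch.randn(batch_size, num_heads, seq_len, head_dim)
v = torch.randn(batch_size, num_heads, seq_len, head_dim)

# Old scheme: square [B, 1, N, N] boolean mask. The rows belonging to padded
# queries are all False, so their scores are all -inf and softmax yields NaN.
square_mask = torch.zeros(batch_size, seq_len, seq_len, dtype=torch.bool)
square_mask[0, :valid_len, :valid_len] = True
out_old = F.scaled_dot_product_attention(q, k, v, attn_mask=square_mask.unsqueeze(1))
print(out_old.isnan().any())  # tensor(True) on PyTorch < 2.5, per the commit title

# New scheme: key-padding mask broadcast across the query dimension. Every
# query row keeps valid_len unmasked keys, so no row is fully masked.
key_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
key_mask[0, :valid_len] = True
out_new = F.scaled_dot_product_attention(q, k, v, attn_mask=key_mask[:, None, None, :])
print(out_new.isnan().any())  # tensor(False)

For the real (unpadded) positions the set of visible keys is unchanged; only the padded query rows behave differently, attending over the valid keys instead of producing NaN, and their outputs correspond to padding tokens that are presumably discarded downstream.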
