
Commit

[DC-AE, SANA] fix SanaMultiscaleLinearAttention apply_quadratic_attention bf16 (#10595)

* autoencoder_dc tiling

* add tiling and slicing support in SANA pipelines

* create variables for padding length because the line becomes too long

* add tiling and slicing support in pag SANA pipelines

* revert changes to tile size

* make style

* add vae tiling test

* fix SanaMultiscaleLinearAttention apply_quadratic_attention bf16

---------

Co-authored-by: Aryan <[email protected]>
chenjy2003 and a-r-r-o-w authored Jan 16, 2025
1 parent e8114bd commit b785ddb
Showing 1 changed file with 1 addition and 1 deletion.
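
The commit description above mentions tiling and slicing support for the SANA pipelines and the DC-AE autoencoder. Below is a minimal usage sketch, not part of this diff: the enable_tiling/enable_slicing method names on the VAE and the checkpoint id follow the usual diffusers conventions and are assumptions here.

import torch
from diffusers import SanaPipeline

# Example checkpoint id, used for illustration only.
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

# Decode latents tile-by-tile and one image at a time to reduce peak memory;
# these AutoencoderDC methods are assumed from the commit description.
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()

image = pipe(prompt="a cyberpunk cat").images[0]
image.save("sana.png")
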
src/diffusers/models/attention_processor.py (2 changes: 1 addition & 1 deletion)

@@ -899,7 +899,7 @@ def apply_quadratic_attention(self, query: torch.Tensor, key: torch.Tensor, valu
         scores = torch.matmul(key.transpose(-1, -2), query)
         scores = scores.to(dtype=torch.float32)
         scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
-        hidden_states = torch.matmul(value, scores)
+        hidden_states = torch.matmul(value, scores.to(value.dtype))
         return hidden_states

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
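
The one-line change casts the float32-normalized scores back to the dtype of value before the final matmul. A self-contained sketch of why this matters under bfloat16 follows; the tensor shapes and eps value are illustrative, not the library code itself.

import torch

# Illustrative shapes and eps; the real module derives these from its config.
query = torch.randn(1, 8, 64, 32, dtype=torch.bfloat16)
key = torch.randn(1, 8, 64, 32, dtype=torch.bfloat16)
value = torch.randn(1, 8, 64, 32, dtype=torch.bfloat16)
eps = 1e-15

scores = torch.matmul(key.transpose(-1, -2), query)
# The normalization is done in float32 for numerical stability.
scores = scores.to(dtype=torch.float32)
scores = scores / (torch.sum(scores, dim=2, keepdim=True) + eps)

# Before the fix, torch.matmul(value, scores) mixed bfloat16 and float32
# operands and raised a RuntimeError. Casting the scores back resolves it:
hidden_states = torch.matmul(value, scores.to(value.dtype))
print(hidden_states.dtype)  # torch.bfloat16
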
