Skip to content

Commit

Permalink
qk is fp32 when float32_logits
Browse files Browse the repository at this point in the history
Signed-off-by: Phuong Nguyen <[email protected]>
  • Loading branch information
phu0ngng committed Feb 13, 2025
1 parent 4a2137a commit f0e4993
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions transformer_engine/jax/flax/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ def __call__(
del self.scale_factor

if self.float32_logits:
query = query.astype(self.dtype)
key = key.astype(self.dtype)
query = query.astype(jnp.float32)
key = key.astype(jnp.float32)
h_q, h_kv = query.shape[-2], key.shape[-2]
# The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
# Therefore, we have to maintain two code paths.
Expand Down

0 comments on commit f0e4993

Please sign in to comment.