Skip to content

Commit

Permalink
don't introduce new bug
Browse files Browse the repository at this point in the history
ae-foster committed Oct 30, 2024
1 parent 896e926 commit c6fda6a
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions folx/experimental/pallas/attention/forward_laplacian.py
Original file line number Diff line number Diff line change
@@ -351,15 +351,15 @@ def mhsa_forward_laplacian_kernel(
q_mask = pl.load(mask_ref, (q_slice,))
square_mask = q_mask[:, None] * kv_mask[None, :]
# Forward pass
q = jnp.where(q_mask[:, None], q_x_ref[q_slice, :], 0.0)
q = jnp.where(q_mask[:, None], q_x_ref[:, :], 0.0)
s = jnp.where(square_mask, pl.dot(q, k, trans_b=True), -big_number(q.dtype))
p = jax.nn.softmax(s, axis=1)
o = pl.dot(p, v)
o_x_ref[:, :] = o

# Laplacian L(h) J(F) terms
# We don't need to mask q_lap, no cross-electron contributions
q_lap = jnp.where(q_mask[:, None], q_lap_ref[q_slice, :], 0.0)
q_lap = jnp.where(q_mask[:, None], q_lap_ref[:, :], 0.0)
qr2_k = pl.dot(q_lap, k, trans_b=True)
qr2_k_p = qr2_k * p
q_kr2 = pl.dot(q, k_lap, trans_b=True)
@@ -371,7 +371,7 @@ def mhsa_forward_laplacian_kernel(
def body_of_loop_over_elec_coords(p_idx, o_lap):
# Jacobian
# We don't need to mask the electron coordinate axis of the Jacobian, no cross-electron-coordinate contributions
q_jac = jnp.where(q_mask[:, None], q_jac_ref[p_idx, q_slice, :], 0.0)
q_jac = jnp.where(q_mask[:, None], q_jac_ref[p_idx, :, :], 0.0)
k_jac = jnp.where(kv_mask[:, None], k_jac_ref[p_idx, :, :], 0.0)
v_jac = jnp.where(kv_mask[:, None], v_jac_ref[p_idx, :, :], 0.0)
input_mask = input_mask_ref[p_idx]
2 changes: 1 addition & 1 deletion folx/experimental/pallas/attention/mhsa.py
Original file line number Diff line number Diff line change
@@ -117,7 +117,7 @@ def mhsa_kernel(
q_mask = pl.load(mask_ref, (q_slice,))
square_mask = q_mask[:, None] * kv_mask[None, :]
# Forward pass
q = jnp.where(q_mask[:, None], q_ref[q_slice, :], 0.0)
q = jnp.where(q_mask[:, None], q_ref[:, :], 0.0)
k = jnp.where(kv_mask[:, None], k_ref[:, :], 0.0)
v = jnp.where(kv_mask[:, None], v_ref[:, :], 0.0)
s = jnp.where(square_mask, pl.dot(q, k, trans_b=True), -big_number(q.dtype))

0 comments on commit c6fda6a

Please sign in to comment.