diff --git a/folx/experimental/pallas/attention/forward_laplacian.py b/folx/experimental/pallas/attention/forward_laplacian.py
index 1f5c9d1..5f9fd88 100644
--- a/folx/experimental/pallas/attention/forward_laplacian.py
+++ b/folx/experimental/pallas/attention/forward_laplacian.py
@@ -351,7 +351,7 @@ def mhsa_forward_laplacian_kernel(
     q_mask = pl.load(mask_ref, (q_slice,))
     square_mask = q_mask[:, None] * kv_mask[None, :]
     # Forward pass
-    q = jnp.where(q_mask[:, None], q_x_ref[q_slice, :], 0.0)
+    q = jnp.where(q_mask[:, None], q_x_ref[:, :], 0.0)
     s = jnp.where(square_mask, pl.dot(q, k, trans_b=True), -big_number(q.dtype))
     p = jax.nn.softmax(s, axis=1)
     o = pl.dot(p, v)
@@ -359,7 +359,7 @@

     # Laplacian L(h) J(F) terms
     # We don't need to mask q_lap, no cross-electron contributions
-    q_lap = jnp.where(q_mask[:, None], q_lap_ref[q_slice, :], 0.0)
+    q_lap = jnp.where(q_mask[:, None], q_lap_ref[:, :], 0.0)
     qr2_k = pl.dot(q_lap, k, trans_b=True)
     qr2_k_p = qr2_k * p
     q_kr2 = pl.dot(q, k_lap, trans_b=True)
@@ -371,7 +371,7 @@
     def body_of_loop_over_elec_coords(p_idx, o_lap):
        # Jacobian
        # We don't need to mask the electron coordinate axis of the Jacobian, no cross-electron-coordinate contributions
-        q_jac = jnp.where(q_mask[:, None], q_jac_ref[p_idx, q_slice, :], 0.0)
+        q_jac = jnp.where(q_mask[:, None], q_jac_ref[p_idx, :, :], 0.0)
         k_jac = jnp.where(kv_mask[:, None], k_jac_ref[p_idx, :, :], 0.0)
         v_jac = jnp.where(kv_mask[:, None], v_jac_ref[p_idx, :, :], 0.0)
         input_mask = input_mask_ref[p_idx]
diff --git a/folx/experimental/pallas/attention/mhsa.py b/folx/experimental/pallas/attention/mhsa.py
index 3f84f70..db4b00c 100644
--- a/folx/experimental/pallas/attention/mhsa.py
+++ b/folx/experimental/pallas/attention/mhsa.py
@@ -117,7 +117,7 @@ def mhsa_kernel(
     q_mask = pl.load(mask_ref, (q_slice,))
     square_mask = q_mask[:, None] * kv_mask[None, :]
     # Forward pass
-    q = jnp.where(q_mask[:, None], q_ref[q_slice, :], 0.0)
+    q = jnp.where(q_mask[:, None], q_ref[:, :], 0.0)
     k = jnp.where(kv_mask[:, None], k_ref[:, :], 0.0)
     v = jnp.where(kv_mask[:, None], v_ref[:, :], 0.0)
     s = jnp.where(square_mask, pl.dot(q, k, trans_b=True), -big_number(q.dtype))