Merge pull request #26 from microsoft/ae-foster/pallas-attention
Add `pallas` fused attention and a corresponding forward Laplacian
n-gao authored Dec 7, 2024
2 parents d6da803 + 32a5c07 commit a983cdc
Showing 6 changed files with 1,261 additions and 0 deletions.
74 changes: 74 additions & 0 deletions folx/experimental/pallas/attention/__init__.py
@@ -0,0 +1,74 @@
from functools import partial
from typing import Literal

import jax
from folx import register_function

from .custom_gradients import mhsa_backward, mhsa_forward
from .forward_laplacian import mhsa_forward_laplacian
from .mhsa import mhsa

custom_vjp_mhsa = jax.custom_vjp(mhsa, nondiff_argnums=(5, 6, 7, 8, 9))
custom_vjp_mhsa.defvjp(mhsa_forward, mhsa_backward)


@partial(jax.jit, static_argnums=(5, 6, 7, 8, 9))
def multi_head_self_attention(
q: jax.Array,
k: jax.Array,
v: jax.Array,
# TODO: support multiple masks for cross-attention
mask: jax.Array,
input_mask: jax.Array,
kernel: Literal["pallas", "reference"] = "pallas",
interpret: bool = False,
q_block_len: int | None = None,
num_warps: int = 2,
num_stages: int = 2,
):
r"""Compute multi-head attention (support VJP not JVP).
Having this wrapper jit block is necessary for folx to recognize the attention block.
Args:
q: Queries of shape ``(batch_size, sequence_length, num_heads, head_dim)``
k: Keys of shape ``(batch_size, sequence_length, num_heads, head_dim)``
v: Values of shape ``(batch_size, sequence_length, num_heads, head_dim)``
mask: Mask of the q, k, v values, shape ``(batch_size, sequence_length)``
input_mask: Used only in forward-Laplacian mode: the mask of the original
input to the model, with respect to which the forward Laplacian is computed.
In our use case it is normally of shape ``(3 * sequence_length, batch_size)``,
but if ``q``, ``k``, ``v`` are padded (e.g. in the flash attention below with
``n_elec < 16``), it should still retain the original ``3 * n_elec`` length.
kernel (str): The kernel to use. Default is ``pallas``.
- ``folx``: the vanilla folx kernel is used.
- ``reference``: the reference JAX kernel is used.
- ``pallas``: the Pallas kernel is used.
interpret: If ``True``, the pallas kernels are executed in interpret mode,
which allows them to be executed e.g. on a CPU (slow). Default is ``False``.
q_block_len (int | None): If ``None``, there is no blocking of the query
array, otherwise it's blocked into blocks of length ``q_block_len``.
Default is ``None``.
num_warps (int): The number of warps used to execute a single instance of the
kernel. Default is 2.
num_stages (int): The number of pipeline stages used by the Triton compiler.
Default is 2.
"""
return custom_vjp_mhsa(
q,
k,
v,
mask,
input_mask,
kernel,
interpret,
q_block_len,
num_warps,
num_stages,
)


register_function("multi_head_self_attention", mhsa_forward_laplacian)


__all__ = ["multi_head_self_attention"]
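
For readers of the diff, a minimal usage sketch (not part of the commit) of the wrapper above. The shapes follow the docstring; the all-True boolean masks and the choice of the reference kernel (which should run without a GPU) are assumptions. When this call appears inside a function transformed with folx's forward Laplacian, the register_function registration above should route it to mhsa_forward_laplacian instead.

import jax
import jax.numpy as jnp

from folx.experimental.pallas.attention import multi_head_self_attention

batch, seq, heads, dim = 2, 16, 4, 8
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (batch, seq, heads, dim))
k = jax.random.normal(kk, (batch, seq, heads, dim))
v = jax.random.normal(kv, (batch, seq, heads, dim))
mask = jnp.ones((batch, seq), dtype=bool)            # all sequence positions valid
input_mask = jnp.ones((3 * seq, batch), dtype=bool)  # only read in forward-Laplacian mode

# kernel is a static argument; "reference" selects the plain JAX kernel.
out = multi_head_self_attention(q, k, v, mask, input_mask, "reference")
print(out.shape)  # (2, 16, 4, 8)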
208 changes: 208 additions & 0 deletions folx/experimental/pallas/attention/custom_gradients.py
@@ -0,0 +1,208 @@
import logging
from functools import partial
from typing import Literal, Tuple

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

from .mhsa import mhsa_kernel, reference_mhsa_kernel
from .utils import (
big_number,
compute_q_and_kv_block_len,
create_grid,
get_mask_block_spec,
get_value_or_laplacian_block_spec,
sum_columns,
)

#######################################################################################################
# Multi-head attention VJP
#######################################################################################################


def mhsa_forward(
q: jax.Array,
k: jax.Array,
v: jax.Array,
mask: jax.Array,
input_mask: jax.Array,
kernel: Literal['pallas', 'reference'],
interpret: bool,
q_block_len: int | None,
num_warps: int,
num_stages: int,
) -> Tuple[jax.Array, Tuple[jax.Array, jax.Array, jax.Array, jax.Array]]:
del input_mask # Only used in the forward Laplacian
batch_len, seq_len, num_heads, head_len = q.shape
q_block_len, kv_block_len = compute_q_and_kv_block_len(seq_len, q_block_len)

if kernel == 'pallas':
kernel_fn = pl.pallas_call(
partial(mhsa_kernel, q_block_len=q_block_len),
grid=create_grid(batch_len, seq_len, num_heads, q_block_len),
in_specs=[
get_value_or_laplacian_block_spec(seq_len, head_len, q_block_len),
get_value_or_laplacian_block_spec(seq_len, head_len, kv_block_len),
get_value_or_laplacian_block_spec(seq_len, head_len, kv_block_len),
get_mask_block_spec(seq_len, q_block_len),
],
out_specs=get_value_or_laplacian_block_spec(seq_len, head_len, q_block_len),
out_shape=jax.ShapeDtypeStruct(
shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
),
compiler_params=dict(
triton=dict(num_warps=num_warps, num_stages=num_stages)
),
debug=False,
interpret=interpret,
name='mhsa_forward',
)
elif kernel == 'reference':
logging.warning(
'Passing kernel="reference" to function mhsa is not recommended in production, '
'as it is very slow. Use kernel="pallas" instead.'
)
kernel_fn = reference_mhsa_kernel
else:
raise ValueError(f'Unknown multi-head attention kernel: {kernel}')
o = kernel_fn(q, k, v, mask)
return o, (q, k, v, mask)


def mhsa_backward(
kernel: Literal['pallas', 'reference'],
interpret: bool,
q_block_len: int | None,
num_warps: int,
num_stages: int,
fwd_cache: Tuple[jax.Array, jax.Array, jax.Array, jax.Array],
o_vjp: jax.Array,
) -> Tuple[jax.Array, jax.Array, jax.Array, None, None]:
assert q_block_len is None, 'Q blocking is not implemented in backward'
q, k, v, mask = fwd_cache
batch_len, seq_len, num_heads, head_len = q.shape
q_block_len, kv_block_len = compute_q_and_kv_block_len(seq_len, q_block_len)

if kernel == 'pallas':
kernel_fn = pl.pallas_call(
mhsa_backward_kernel,
grid=create_grid(batch_len, seq_len, num_heads, q_block_len),
in_specs=[
get_value_or_laplacian_block_spec(seq_len, head_len, q_block_len),
get_value_or_laplacian_block_spec(seq_len, head_len, kv_block_len),
get_value_or_laplacian_block_spec(seq_len, head_len, kv_block_len),
get_mask_block_spec(seq_len, q_block_len),
get_value_or_laplacian_block_spec(seq_len, head_len, q_block_len),
],
out_specs=[
get_value_or_laplacian_block_spec(seq_len, head_len, q_block_len),
get_value_or_laplacian_block_spec(seq_len, head_len, kv_block_len),
get_value_or_laplacian_block_spec(seq_len, head_len, kv_block_len),
],
out_shape=[
jax.ShapeDtypeStruct(
shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
),
jax.ShapeDtypeStruct(
shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
),
jax.ShapeDtypeStruct(
shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
),
],
compiler_params=dict(
triton=dict(num_warps=num_warps, num_stages=num_stages)
),
debug=False,
interpret=interpret,
name='mhsa_backward',
)
elif kernel == 'reference':
kernel_fn = reference_mhsa_backward_kernel
else:
raise ValueError(f'Unknown multi-head attention kernel: {kernel}')
dq, dk, dv = kernel_fn(q, k, v, mask, o_vjp)
return dq, dk, dv, None, None


def reference_mhsa_backward_kernel(
q: jax.Array, k: jax.Array, v: jax.Array, mask: jax.Array, o_vjp: jax.Array
) -> Tuple[jax.Array, jax.Array, jax.Array]:
r"""Reference jax implementation of the multi-head attention backward kernel."""
# The attention matrices s and p below have shape [batch_size, seq_len, num_heads, seq_len]
q = jnp.where(mask[:, :, None, None], q, 0.0)
square_mask = mask[:, None, None, :] * mask[:, :, None, None]
s = jnp.einsum('Biha,Bjha->Bihj', q, k)
s = jnp.where(square_mask, s, -big_number(q.dtype))
p = jax.nn.softmax(s, axis=-1)

# Compute the VJPs
p_vjp = jnp.einsum('Biha,Bjha->Bihj', o_vjp, v)
q_vjp = jnp.einsum('Bkha,Bihk,Bihk->Biha', k, p, p_vjp) - jnp.einsum(
'Bmha,Bihk,Bihm,Bihk->Biha', k, p, p, p_vjp
)
k_vjp = jnp.einsum('Bjha,Bjhi,Bjhi->Biha', q, p, p_vjp) - jnp.einsum(
'Bjha,Bjhk,Bjhi,Bjhk->Biha', q, p, p, p_vjp
)
v_vjp = jnp.einsum('Bjhi,Bjha->Biha', p, o_vjp)

return q_vjp, k_vjp, v_vjp
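
As a sanity check (not part of the commit), the einsum VJPs above can be compared against jax.vjp of an inline re-implementation of the masked forward pass. This is a minimal sketch that assumes a full (all-True) mask, so the masking ops are no-ops, and substitutes 1e9 for big_number:

import jax
import jax.numpy as jnp


def attention_forward(q, k, v, mask, big=1e9):
    # Mirrors the forward pass implied by the backward kernel above.
    q = jnp.where(mask[:, :, None, None], q, 0.0)
    square_mask = mask[:, None, None, :] * mask[:, :, None, None]
    s = jnp.where(square_mask, jnp.einsum('Biha,Bjha->Bihj', q, k), -big)
    p = jax.nn.softmax(s, axis=-1)
    return jnp.einsum('Bihj,Bjha->Biha', p, v)


key = jax.random.PRNGKey(0)
kq, kk, kv, ko = jax.random.split(key, 4)
shape = (2, 8, 4, 16)  # (batch_size, seq_len, num_heads, head_dim)
q, k, v, o_vjp = (jax.random.normal(k_, shape) for k_ in (kq, kk, kv, ko))
mask = jnp.ones((2, 8), dtype=bool)

_, vjp_fn = jax.vjp(lambda q_, k_, v_: attention_forward(q_, k_, v_, mask), q, k, v)
dq, dk, dv = vjp_fn(o_vjp)
dq_ref, dk_ref, dv_ref = reference_mhsa_backward_kernel(q, k, v, mask, o_vjp)
for a, b in ((dq, dq_ref), (dk, dk_ref), (dv, dv_ref)):
    assert jnp.allclose(a, b, atol=1e-4)  # float32 tolerance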


def mhsa_backward_kernel(
q_ref, # Inputs
k_ref,
v_ref,
mask_ref,
o_vjp_ref,
q_vjp_ref, # Outputs
k_vjp_ref,
v_vjp_ref,
):
r"""The pallas implementation of the backward of the multi-head attention kernel.
Here pallas grid has already removed the batch and head dimensions.
Args:
q_ref: Queries, shape ``(sequence_length, head_dim)``
k_ref: Keys, shape ``(sequence_length, head_dim)``
v_ref: Values, shape ``(sequence_length, head_dim)``
mask_ref: Mask of the q, k, v values, shape ``(sequence_length,)``
o_vjp_ref: VJP of the output of MHA, shape ``(sequence_length, head_dim)``
q_vjp_ref: output, VJP of the queries, shape ``(sequence_length, head_dim)``
k_vjp_ref: output, VJP of the keys, shape ``(sequence_length, head_dim)``
v_vjp_ref: output, VJP of the values, shape ``(sequence_length, head_dim)``
"""
mask = mask_ref[:]
square_mask = mask[:, None] * mask[None, :]
# Recompute the output to save memory
q = jnp.where(mask[:, None], q_ref[:, :], 0.0)
k = jnp.where(mask[:, None], k_ref[:, :], 0.0)
v = jnp.where(mask[:, None], v_ref[:, :], 0.0)
s = jnp.where(square_mask, pl.dot(q, k, trans_b=True), -big_number(q.dtype))
p = jax.nn.softmax(s)

# Compute the VJPs
o_vjp = o_vjp_ref[:, :]

# v_vjp
v_vjp = pl.dot(p, o_vjp, trans_a=True)
v_vjp_ref[:, :] = v_vjp

# q_vjp
lo_v_p = pl.dot(o_vjp, v, trans_b=True) * p
## First term
q_vjp = pl.dot(lo_v_p, k)
## Second term
pk = pl.dot(p, k)
q_vjp -= pk * sum_columns(lo_v_p)
q_vjp_ref[:, :] = q_vjp

# k_vjp
## First term
k_vjp = pl.dot(lo_v_p.T, q)
## Second term
p_vjp = pl.dot(o_vjp, v, trans_b=True)
k_vjp -= pl.dot((p * sum_columns(p_vjp * p)), q, trans_a=True)
k_vjp_ref[:, :] = k_vjp
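
For orientation (not part of the commit), the same per-(batch, head) block algebra as mhsa_backward_kernel, written with plain jnp operations. This is a minimal sketch that assumes sum_columns returns per-row sums of shape ``(seq_len, 1)`` and substitutes 1e9 for big_number:

import jax
import jax.numpy as jnp


def mhsa_backward_block(q, k, v, mask, o_vjp, big=1e9):
    # q, k, v, o_vjp: (seq_len, head_dim); mask: (seq_len,)
    square_mask = mask[:, None] * mask[None, :]
    q = jnp.where(mask[:, None], q, 0.0)
    k = jnp.where(mask[:, None], k, 0.0)
    v = jnp.where(mask[:, None], v, 0.0)
    s = jnp.where(square_mask, q @ k.T, -big)
    p = jax.nn.softmax(s, axis=-1)           # P = softmax(Q K^T)

    p_vjp = o_vjp @ v.T                      # dP = dO V^T
    lo_v_p = p_vjp * p                       # P * dP (elementwise)
    row = lo_v_p.sum(axis=1, keepdims=True)  # rowsum(P * dP); plays the role of sum_columns

    v_vjp = p.T @ o_vjp                      # dV = P^T dO
    q_vjp = lo_v_p @ k - (p @ k) * row       # dQ = dS K, with dS = P * (dP - rowsum)
    k_vjp = lo_v_p.T @ q - (p * row).T @ q   # dK = dS^T Q
    return q_vjp, k_vjp, v_vjp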
