allow for full attention mask in naive flash attention impl
lucidrains committed Jul 16, 2023
1 parent edcef83 commit 55ae343
Showing 4 changed files with 81 additions and 15 deletions.
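For context, the mask handling before this commit assumed a per-key padding mask of shape (batch, seq); the diff and tests below add support for a full (batch, 1, seq, seq) boolean attention mask. A minimal usage sketch based on the module names and keyword arguments used in the new tests (the package-level import path is an assumption):

import torch
from memory_efficient_attention_pytorch import FlashAttention  # import path assumed

batch, seq, dim = 2, 256, 512

flash_attn = FlashAttention(
    dim = dim,
    dim_head = 64,
    heads = 8,
    q_bucket_size = 64,
    k_bucket_size = 64,
    causal = True
)

x = torch.randn(batch, seq, dim)

# per-key padding mask, the shape supported before this commit
padding_mask = torch.ones(batch, seq).bool()

# full attention mask, one boolean per (query, key) pair, newly supported
full_mask = torch.ones(batch, 1, seq, seq).bool()

out_from_padding_mask = flash_attn(x, mask = padding_mask)
out_from_full_mask = flash_attn(x, mask = full_mask)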
32 changes: 20 additions & 12 deletions memory_efficient_attention_pytorch/flash_attention.py
@@ -38,11 +38,18 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
         scale = (q.shape[-1] ** -0.5)
 
+        num_row_tiles = math.ceil(q.shape[-2] / q_bucket_size)
+        num_col_tiles = math.ceil(k.shape[-2] / k_bucket_size)
+
+        if exists(mask) and mask.ndim == 2:
+            mask = rearrange(mask, 'b n -> b 1 1 n')
+
         if not exists(mask):
-            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
+            col_masks = (None,) * num_col_tiles
+            mask = (col_masks,) * num_row_tiles
         else:
-            mask = rearrange(mask, 'b n -> b 1 1 n')
-            mask = mask.split(q_bucket_size, dim = -1)
+            mask = ((mask,) * num_row_tiles) if mask.shape[-2] == 1 else mask.split(q_bucket_size, dim = -2)
+            mask = tuple(((row_mask,) * num_col_tiles) if row_mask.shape[-1] == 1 else row_mask.split(k_bucket_size, dim = -1) for row_mask in mask)
 
         row_splits = zip(
             q.split(q_bucket_size, dim = -2),
@@ -58,15 +65,16 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
             col_splits = zip(
                 k.split(k_bucket_size, dim = -2),
                 v.split(k_bucket_size, dim = -2),
+                row_mask
             )
 
-            for k_ind, (kc, vc) in enumerate(col_splits):
+            for k_ind, (kc, vc, col_mask) in enumerate(col_splits):
                 k_start_index = k_ind * k_bucket_size
 
                 attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
 
-                if exists(row_mask):
-                    attn_weights.masked_fill_(~row_mask, max_neg_value)
+                if exists(col_mask):
+                    attn_weights.masked_fill_(~col_mask, max_neg_value)
 
                 if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                     causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
@@ -76,8 +84,8 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
                 attn_weights -= block_row_maxes
                 exp_weights = torch.exp(attn_weights)
 
-                if exists(row_mask):
-                    exp_weights.masked_fill_(~row_mask, 0.)
+                if exists(col_mask):
+                    exp_weights.masked_fill_(~col_mask, 0.)
 
                 block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON)
 
@@ -136,9 +144,10 @@ def backward(ctx, do):
                 v.split(k_bucket_size, dim = -2),
                 dk.split(k_bucket_size, dim = -2),
                 dv.split(k_bucket_size, dim = -2),
+                row_mask
             )
 
-            for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
+            for k_ind, (kc, vc, dkc, dvc, col_mask) in enumerate(col_splits):
                 k_start_index = k_ind * k_bucket_size
 
                 attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
@@ -149,8 +158,8 @@ def backward(ctx, do):
 
                 p = torch.exp(attn_weights - lsec)
 
-                if exists(row_mask):
-                    p.masked_fill_(~row_mask, 0.)
+                if exists(col_mask):
+                    p.masked_fill_(~col_mask, 0.)
 
                 dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
                 dp = einsum('... i d, ... j d -> ... i j', doc, vc)
@@ -186,7 +195,6 @@ def __init__(
     ):
        super().__init__()
        self.heads = heads
-
        self.causal = causal
 
        inner_dim = heads * dim_head
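The heart of the change above is how the mask is tiled. Previously a (b, 1, 1, n) padding mask was split once and reused; now a full (b, 1, i, j) mask is split along the query dimension into row tiles, and each row tile is split along the key dimension into column tiles, so the inner loop sees a col_mask matching each (q_bucket, k_bucket) block, while a mask with a singleton query or key dimension is simply repeated across tiles. A standalone sketch of that tiling (illustrative only, not the library code; the shapes are made up):

import math
import torch

q_len, k_len = 128, 192
q_bucket_size, k_bucket_size = 64, 64

num_row_tiles = math.ceil(q_len / q_bucket_size)   # 2
num_col_tiles = math.ceil(k_len / k_bucket_size)   # 3

# hypothetical full boolean mask, one entry per (query, key) pair
full_mask = torch.ones(2, 1, q_len, k_len).bool()

# split along the query dimension (row tiles), then along the key dimension (column tiles)
row_masks = full_mask.split(q_bucket_size, dim = -2)
mask_tiles = tuple(row_mask.split(k_bucket_size, dim = -1) for row_mask in row_masks)

assert len(mask_tiles) == num_row_tiles
assert all(len(cols) == num_col_tiles for cols in mask_tiles)
assert mask_tiles[0][0].shape == (2, 1, q_bucket_size, k_bucket_size)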
@@ -34,7 +34,8 @@ def attention(
     mask_value = -torch.finfo(sim.dtype).max
 
     if exists(mask):
-        mask = rearrange(mask, 'b j -> b 1 1 j')
+        if mask.ndim == 2:
+            mask = rearrange(mask, 'b j -> b 1 1 j')
         sim = sim.masked_fill(~mask, mask_value)
 
     if causal:
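In the non-flash attention path the same masked_fill call covers both shapes thanks to broadcasting: sim has shape (b, h, i, j), a (b, 1, 1, j) padding mask broadcasts over heads and query positions, and a (b, 1, i, j) full mask broadcasts over heads only. A small sketch of that broadcasting (toy shapes, not the library code):

import torch
from einops import rearrange

b, h, i, j = 2, 8, 4, 6
sim = torch.randn(b, h, i, j)
mask_value = -torch.finfo(sim.dtype).max

padding_mask = rearrange(torch.ones(b, j).bool(), 'b j -> b 1 1 j')
full_mask = torch.ones(b, 1, i, j).bool()

# both broadcast against (b, h, i, j) without any extra reshaping
masked_by_padding = sim.masked_fill(~padding_mask, mask_value)
masked_by_full = sim.masked_fill(~full_mask, mask_value)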
3 changes: 1 addition & 2 deletions setup.py
@@ -23,8 +23,7 @@
     'pytest-runner',
   ],
   tests_require=[
-    'pytest',
-    'torch==1.12.1'
+    'pytest'
   ],
   classifiers=[
     'Development Status :: 4 - Beta',
58 changes: 58 additions & 0 deletions tests/test.py
@@ -114,3 +114,61 @@ def test_flash_attn_gradients_equal():
     assert isclose(flash_dq_grad, dq_grad, atol = 1e-5)
     assert isclose(flash_dk_grad, dk_grad, atol = 1e-5)
     assert isclose(flash_dv_grad, dv_grad, atol = 1e-5)
+
+# test flash attention - full attention mask
+
+def test_flash_attn_full_attn_mask_output_equal():
+    attn_kwargs = dict(
+        dim = 512,
+        dim_head = 64,
+        heads = 8,
+        q_bucket_size = 64,
+        k_bucket_size = 64,
+        causal = True
+    )
+
+    attn = Attention(**attn_kwargs)
+    flash_attn = FlashAttention(**attn_kwargs)
+
+    flash_attn.to_q = attn.to_q
+    flash_attn.to_kv = attn.to_kv
+    flash_attn.to_out = attn.to_out
+
+    x = torch.randn(2, 2048, 512)
+    mask = torch.ones(2, 1, 2048, 2048).bool()
+
+    out = attn(x, mask = mask)
+    mem_efficient_out = flash_attn(x, mask = mask)
+
+    assert isclose(mem_efficient_out, out, atol = 1e-6)
+
+# test gradients equal - full attention mask
+
+def test_flash_attn_full_attn_mask_gradients_equal():
+    q = torch.randn(1, 8, 1024, 512).requires_grad_()
+    k = torch.randn(1, 8, 1024, 512).requires_grad_()
+    v = torch.randn(1, 8, 1024, 512).requires_grad_()
+
+    mask = torch.ones(1, 1, 1024, 1024).bool()
+
+    o = attention(q, k, v, mask = mask, causal = True)
+    o.sum().backward()
+
+    dq_grad = q.grad.clone()
+    dk_grad = k.grad.clone()
+    dv_grad = v.grad.clone()
+
+    q.grad.zero_()
+    k.grad.zero_()
+    v.grad.zero_()
+
+    flash_o = FlashAttentionFunction.apply(q, k, v, mask, True, 64, 64)
+    flash_o.sum().backward()
+
+    flash_dq_grad = q.grad.clone()
+    flash_dk_grad = k.grad.clone()
+    flash_dv_grad = v.grad.clone()
+
+    assert isclose(flash_dq_grad, dq_grad, atol = 1e-5)
+    assert isclose(flash_dk_grad, dk_grad, atol = 1e-5)
+    assert isclose(flash_dv_grad, dv_grad, atol = 1e-5)
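If the repository is laid out as shown above, the two new tests can be selected on their own with pytest's keyword filter:

pytest tests/test.py -k full_attn_mask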
