add tests, fix bugs
lucidrains committed Jul 23, 2022
1 parent c901ae5 commit 33fb78a
Showing 7 changed files with 75 additions and 17 deletions.
1 change: 1 addition & 0 deletions memory_efficient_attention_pytorch/__init__.py
@@ -1,2 +1,3 @@
from memory_efficient_attention_pytorch.memory_efficient_attention import Attention, memory_efficient_attention
from memory_efficient_attention_pytorch.memory_efficient_cosine_sim_attention import CosineSimAttention, numerically_unstable_memory_efficient_attention
from memory_efficient_attention_pytorch.flash_attention import FlashAttention
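
Note: with FlashAttention now exported from the package root, a minimal usage sketch looks as follows. The constructor arguments mirror those exercised by the new test in tests/test.py; everything else (tensor sizes, variable names) is illustrative and not part of this commit.

import torch
from memory_efficient_attention_pytorch import FlashAttention

attn = FlashAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    causal = True,
    q_bucket_size = 64,
    k_bucket_size = 64
)

x = torch.randn(2, 2048, 512)
mask = torch.ones(2, 2048).bool()

out = attn(x, mask = mask)  # (2, 2048, 512)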
@@ -42,7 +42,8 @@ def forward(ctx, q, k, v, mask, scale, causal, q_bucket_size, k_bucket_size):
if not exists(mask):
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
else:
- mask = mask.split(q_bucket_size, dim = -2)
+ mask = rearrange(mask, 'b n -> b 1 1 n')
+ mask = mask.split(q_bucket_size, dim = -1)

row_splits = zip(
q.split(q_bucket_size, dim = -2),
@@ -184,7 +185,7 @@ def __init__(

self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
- self.to_out = nn.Linear(inner_dim, dim)
+ self.to_out = nn.Linear(inner_dim, dim, bias = False)

# memory efficient attention related parameters
# can be overriden on forward
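
Note on the mask change above: for a 2-D boolean mask of shape (batch, n), splitting along dim = -2 chunks the batch dimension, so the number of mask chunks does not match the number of query buckets and the zip over the row splits terminates early. Rearranging to (batch, 1, 1, n) first makes the mask broadcastable against the (batch, heads, i, j) attention weights and lets the split run along the sequence axis. A minimal sketch of the shape mechanics (toy sizes, not taken from the commit):

import math
import torch
from einops import rearrange

b, n, bucket_size = 2, 256, 64
q = torch.randn(b, 8, n, 64)            # (batch, heads, seq, dim_head)
mask = torch.ones(b, n).bool()          # boolean padding mask, (batch, seq)

# old behaviour: dim = -2 on a 2-D mask is the batch dimension,
# yielding a single chunk instead of one chunk per bucket
old_chunks = mask.split(bucket_size, dim = -2)
assert len(old_chunks) == 1 and len(q.split(bucket_size, dim = -2)) == 4

# fixed behaviour: broadcastable shape, then split along the sequence axis
new_chunks = rearrange(mask, 'b n -> b 1 1 n').split(bucket_size, dim = -1)
assert len(new_chunks) == math.ceil(n / bucket_size)
assert new_chunks[0].shape == (b, 1, 1, bucket_size)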
21 changes: 9 additions & 12 deletions memory_efficient_attention_pytorch/flash_attention.py
@@ -37,12 +37,12 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, device = device)

scale = (q.shape[-1] ** -0.5)
- q = q * scale

if not exists(mask):
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
else:
- mask = mask.split(q_bucket_size, dim = -2)
+ mask = rearrange(mask, 'b n -> b 1 1 n')
+ mask = mask.split(q_bucket_size, dim = -1)

row_splits = zip(
q.split(q_bucket_size, dim = -2),
@@ -63,7 +63,7 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
for k_ind, (kc, vc) in enumerate(col_splits):
k_start_index = k_ind * k_bucket_size

- attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc)
+ attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

if exists(row_mask):
attn_weights.masked_fill_(~row_mask, max_neg_value)
@@ -73,7 +73,6 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
attn_weights.masked_fill_(causal_mask, max_neg_value)

block_row_maxes = attn_weights.amax(dim = -1, keepdims = True)

attn_weights -= block_row_maxes
exp_weights = torch.exp(attn_weights)

@@ -82,7 +81,7 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):

block_row_sums = exp_weights.sum(dim = -1, keepdims = True).clamp(min = EPSILON)

- new_row_maxes = torch.maximum(block_row_maxes, row_sums)
+ new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)

@@ -92,10 +91,11 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums

oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)

row_maxes.copy_(new_row_maxes)
row_sums.copy_(new_row_sums)

- ctx.args = (causal, mask, q_bucket_size, k_bucket_size)
+ ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

return o
@@ -105,7 +105,7 @@ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
def backward(ctx, do):
""" Algorithm 4 in the paper """

- causal, mask, q_bucket_size, k_bucket_size = ctx.args
+ causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
q, k, v, o, l, m = ctx.saved_tensors

device = q.device
@@ -117,8 +117,6 @@ def backward(ctx, do):
dk = torch.zeros_like(k)
dv = torch.zeros_like(v)

- scale = q.shape[-1] ** -0.5

row_splits = zip(
q.split(q_bucket_size, dim = -2),
o.split(q_bucket_size, dim = -2),
@@ -142,8 +140,7 @@ def backward(ctx, do):
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
k_start_index = k_ind * k_bucket_size

- qc_scaled = qc * scale
- attn_weights = einsum('... i d, ... j d -> ... i j', qc_scaled, kc)
+ attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale

if causal and q_start_index < (k_start_index + k_bucket_size - 1):
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype = torch.bool, device = device).triu(q_start_index - k_start_index + 1)
@@ -197,7 +194,7 @@ def __init__(

self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
- self.to_out = nn.Linear(inner_dim, dim)
+ self.to_out = nn.Linear(inner_dim, dim, bias = False)

# memory efficient attention related parameters
# can be overriden on forward
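
Note on the flash_attention.py changes above: the forward pass now applies the softmax scale to each block of attention logits instead of pre-scaling q, fixes the running maximum to be taken against row_maxes rather than row_sums, and stashes scale in ctx.args so the backward pass reuses it instead of recomputing it. For reference, a self-contained sketch of the running-softmax bookkeeping the inner loop performs, for a single query row split into two key blocks (variable names follow the diff; sizes are illustrative):

import torch

torch.manual_seed(0)
d = 4
scores = torch.randn(8)                 # logits for one query over 8 keys
values = torch.randn(8, d)
blocks = [(scores[:4], values[:4]), (scores[4:], values[4:])]

row_max = torch.tensor(float('-inf'))   # running max, as in all_row_maxes
row_sum = torch.tensor(0.)              # running denominator, as in all_row_sums
out = torch.zeros(d)

for block_scores, block_values in blocks:
    block_max = block_scores.amax()
    exp_weights = torch.exp(block_scores - block_max)
    block_sum = exp_weights.sum()

    new_max = torch.maximum(block_max, row_max)       # the fixed line: compare maxes, not sums
    exp_row_diff = torch.exp(row_max - new_max)       # rescale what was accumulated so far
    exp_block_diff = torch.exp(block_max - new_max)   # rescale the current block

    new_sum = exp_row_diff * row_sum + exp_block_diff * block_sum
    out = out * (row_sum / new_sum) * exp_row_diff + (exp_block_diff / new_sum) * (exp_weights @ block_values)

    row_max, row_sum = new_max, new_sum

# matches an ordinary softmax over the full row
assert torch.allclose(out, torch.softmax(scores, dim = -1) @ values, atol = 1e-6)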
@@ -180,7 +180,7 @@ def __init__(

self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
- self.to_out = nn.Linear(inner_dim, dim)
+ self.to_out = nn.Linear(inner_dim, dim, bias = False)

# memory efficient attention related parameters
# can be overriden on forward
@@ -164,7 +164,7 @@ def __init__(

self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
- self.to_out = nn.Linear(inner_dim, dim)
+ self.to_out = nn.Linear(inner_dim, dim, bias = False)

# memory efficient attention related parameters
# can be overriden on forward
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'memory-efficient-attention-pytorch',
packages = find_packages(exclude=[]),
- version = '0.0.22',
+ version = '0.0.23',
license='MIT',
description = 'Memory Efficient Attention - Pytorch',
long_description_content_type = 'text/markdown',
59 changes: 59 additions & 0 deletions tests/test.py
@@ -1,6 +1,9 @@
import torch
from memory_efficient_attention_pytorch import Attention

from memory_efficient_attention_pytorch.memory_efficient_attention import attention
from memory_efficient_attention_pytorch.flash_attention import FlashAttention, FlashAttentionFunction

# constants

def isclose(a, b, atol = 1e-6):
@@ -53,3 +56,59 @@ def loss_fn(inp, **kwargs):
mem_efficient_out_grad = x.grad.clone()

assert isclose(out_grad, mem_efficient_out_grad, atol = 1e-5)

# test flash attention

def test_flash_attn_output_equal():
attn_kwargs = dict(
dim = 512,
dim_head = 64,
heads = 8,
q_bucket_size = 64,
k_bucket_size = 64,
causal = True
)

attn = Attention(**attn_kwargs)
flash_attn = FlashAttention(**attn_kwargs)

flash_attn.to_q = attn.to_q
flash_attn.to_kv = attn.to_kv
flash_attn.to_out = attn.to_out

x = torch.randn(2, 2048, 512)
mask = torch.ones(2, 2048).bool()

out = attn(x, mask = mask)
mem_efficient_out = flash_attn(x, mask = mask)

assert isclose(mem_efficient_out, out, atol = 1e-6)

# test gradients equal

def test_flash_attn_gradients_equal():
q = torch.randn(1, 8, 1024, 512).requires_grad_()
k = torch.randn(1, 8, 1024, 512).requires_grad_()
v = torch.randn(1, 8, 1024, 512).requires_grad_()

o = attention(q, k, v, causal = False)
o.sum().backward()

dq_grad = q.grad.clone()
dk_grad = k.grad.clone()
dv_grad = v.grad.clone()

q.grad.zero_()
k.grad.zero_()
v.grad.zero_()

flash_o = FlashAttentionFunction.apply(q, k, v, None, False, 64, 64)
flash_o.sum().backward()

flash_dq_grad = q.grad.clone()
flash_dk_grad = k.grad.clone()
flash_dv_grad = v.grad.clone()

assert isclose(flash_dq_grad, dq_grad, atol = 1e-5)
assert isclose(flash_dk_grad, dk_grad, atol = 1e-5)
assert isclose(flash_dv_grad, dv_grad, atol = 1e-5)
