From fa93fa4c0fe9627702473adfbb517fa3f67f6870 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sat, 23 Jul 2022 14:44:01 -0700
Subject: [PATCH] test out flash attention in GPT

---
 memory_efficient_attention_pytorch/transformer.py | 7 +++++--
 setup.py                                          | 2 +-
 tests/test.py                                     | 6 ++++--
 train.py                                          | 6 +++---
 4 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/memory_efficient_attention_pytorch/transformer.py b/memory_efficient_attention_pytorch/transformer.py
index f2248d8..849ec2b 100644
--- a/memory_efficient_attention_pytorch/transformer.py
+++ b/memory_efficient_attention_pytorch/transformer.py
@@ -3,7 +3,7 @@
 import torch.nn.functional as F
 from einops import rearrange
 
-from memory_efficient_attention_pytorch import Attention
+from memory_efficient_attention_pytorch import FlashAttention, Attention
 from memory_efficient_attention_pytorch.reversible import ReversibleSequence
 
 def exists(val):
@@ -51,6 +51,7 @@ def __init__(
         heads = 8,
         ff_mult = 4,
         ff_chunks = 1,
+        use_flash_attn = True,
         **kwargs
     ):
         super().__init__()
@@ -59,10 +60,12 @@ def __init__(
         self.token_emb = nn.Embedding(num_tokens, dim)
         self.pos_emb = nn.Embedding(max_seq_len, dim)
 
+        attn_klass = FlashAttention if use_flash_attn else partial(Attention, memory_efficient = True)
+
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, Attention(dim = dim, dim_head = dim_head, heads = heads, causal = causal, **kwargs)),
+                PreNorm(dim, attn_klass(dim = dim, dim_head = dim_head, heads = heads, causal = causal, **kwargs)),
                 PreNorm(dim, FeedForward(dim = dim, mult = ff_mult, chunks = ff_chunks)),
             ]))
 
diff --git a/setup.py b/setup.py
index 78f54d5..b8cceff 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'memory-efficient-attention-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.23',
+  version = '0.0.24',
  license='MIT',
   description = 'Memory Efficient Attention - Pytorch',
   long_description_content_type = 'text/markdown',
diff --git a/tests/test.py b/tests/test.py
index fbdc470..1a152fb 100644
--- a/tests/test.py
+++ b/tests/test.py
@@ -91,7 +91,9 @@ def test_flash_attn_gradients_equal():
     k = torch.randn(1, 8, 1024, 512).requires_grad_()
     v = torch.randn(1, 8, 1024, 512).requires_grad_()
 
-    o = attention(q, k, v, causal = False)
+    mask = torch.ones(1, 1024).bool()
+
+    o = attention(q, k, v, mask = mask, causal = True)
     o.sum().backward()
 
     dq_grad = q.grad.clone()
@@ -102,7 +104,7 @@ def test_flash_attn_gradients_equal():
     k.grad.zero_()
     v.grad.zero_()
 
-    flash_o = FlashAttentionFunction.apply(q, k, v, None, False, 64, 64)
+    flash_o = FlashAttentionFunction.apply(q, k, v, mask, True, 64, 64)
     flash_o.sum().backward()
 
     flash_dq_grad = q.grad.clone()
diff --git a/train.py b/train.py
index 4844a41..fa005b7 100644
--- a/train.py
+++ b/train.py
@@ -18,7 +18,7 @@
 LEARNING_RATE = 2e-4
 VALIDATE_EVERY = 100
 GENERATE_EVERY = 500
-GENERATE_LENGTH = 4096
+GENERATE_LENGTH = 1024
 SEQ_LEN = 4096
 
 # helpers
@@ -43,10 +43,10 @@
     depth = 6,
     heads = 8,
     causal = True,
-    memory_efficient = True,
     q_bucket_size = 256,
     k_bucket_size = 256,
-    ff_chunks = 5
+    ff_chunks = 5,
+    use_flash_attn = True
 )
 
 model = AutoregressiveWrapper(model)
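
Usage sketch (not part of the patch): how the new use_flash_attn flag would be exercised end to end. The Transformer class name, its import path, and the forward behavior are assumed from the repo layout and train.py; the hyperparameter values below are illustrative only.

    import torch
    from memory_efficient_attention_pytorch.transformer import Transformer  # class name / module path assumed

    # with use_flash_attn = True every layer wraps FlashAttention;
    # with False it falls back to the memory-efficient Attention (attn_klass above)
    model = Transformer(
        num_tokens = 256,        # illustrative vocab size
        dim = 512,               # illustrative model width
        max_seq_len = 4096,      # matches SEQ_LEN in train.py
        depth = 6,
        heads = 8,
        causal = True,
        q_bucket_size = 256,
        k_bucket_size = 256,
        ff_chunks = 5,
        use_flash_attn = True    # flag added by this patch
    )

    tokens = torch.randint(0, 256, (1, 1024))   # any length up to max_seq_len
    logits = model(tokens)                      # assumed to return per-token logits, GPT-style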