From 285e61be862000c38bfeb7ebf61a469be5131a32 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 21 Jun 2022 19:08:14 -0700
Subject: [PATCH] complete cross attention dropout logic

---
 README.md                                    |  1 +
 perceiver_ar_pytorch/perceiver_ar_pytorch.py | 47 +++++++++++++++++---
 setup.py                                     |  2 +-
 train.py                                     |  1 +
 4 files changed, 45 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index f44edf9..7250477 100644
--- a/README.md
+++ b/README.md
@@ -32,6 +32,7 @@ model = PerceiverAR(
     heads = 8,                  # attention heads
     max_seq_len = 4096,         # total max sequence length
     cross_attn_seq_len = 3072,  # the sequence length in which to attend to, but does not undergo self attention (must be less than max_seq_len)
+    cross_attn_dropout = 0.5,   # what percentage of the prefix to dropout during training, in paper they had extensive experimentation to show up to 50% dropout helped prevent overfitting
 )
 
 x = torch.randint(0, 20000, (1, 4096))
diff --git a/perceiver_ar_pytorch/perceiver_ar_pytorch.py b/perceiver_ar_pytorch/perceiver_ar_pytorch.py
index c82aede..00780cb 100644
--- a/perceiver_ar_pytorch/perceiver_ar_pytorch.py
+++ b/perceiver_ar_pytorch/perceiver_ar_pytorch.py
@@ -2,7 +2,7 @@
 import torch.nn.functional as F
 from torch import nn, einsum
 
-from einops import rearrange
+from einops import rearrange, repeat
 
 # helper functions
 
@@ -104,7 +104,8 @@ def __init__(
         dim_head = 64,
         heads = 8,
         max_heads_process = 2,
-        dropout = 0.
+        dropout = 0.,
+        cross_attn_dropout = 0.
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
@@ -117,14 +118,47 @@ def __init__(
         self.context_norm = nn.LayerNorm(dim)
 
         self.dropout = nn.Dropout(dropout)
+        self.cross_attn_dropout = cross_attn_dropout # they drop out a percentage of the prefix during training, shown to help prevent overfitting
+
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
         self.to_out = nn.Linear(inner_dim, dim)
 
     def forward(self, x, context, context_mask = None, rotary_pos_emb = None):
+        batch, context_len, device = x.shape[0], context.shape[-2], x.device
+
+        q_rotary_pos_emb = rotary_pos_emb
+        k_rotary_pos_emb = rotary_pos_emb
+
+        # take care of cross attention dropout
+
+        if self.training and self.cross_attn_dropout > 0.:
+            rand = torch.zeros((batch, context_len), device = device).uniform_()
+            keep_context_len = context_len - int(context_len * self.cross_attn_dropout)
+            keep_indices = rand.topk(keep_context_len, dim = -1).indices
+            keep_mask = torch.zeros_like(rand).scatter_(1, keep_indices, 1).bool()
+
+            context = rearrange(context[keep_mask], '(b n) d -> b n d', b = batch)
+
+            if exists(context_mask):
+                context_mask = rearrange(context_mask[keep_mask], '(b n) -> b n', b = batch)
+
+            # operate on rotary position embeddings for keys
+
+            k_rotary_pos_emb = repeat(k_rotary_pos_emb, '... -> b ...', b = batch)
+            k_rotary_pos_emb_context, k_rotary_pos_emb_seq = k_rotary_pos_emb[:, :context_len], k_rotary_pos_emb[:, context_len:]
+            k_rotary_pos_emb_context = rearrange(k_rotary_pos_emb_context[keep_mask], '(b n) d -> b n d', b = batch)
+
+            k_rotary_pos_emb = torch.cat((k_rotary_pos_emb_context, k_rotary_pos_emb_seq), dim = 1)
+            k_rotary_pos_emb = rearrange(k_rotary_pos_emb, 'b n d -> b 1 n d')
+
+        # normalization
+
         x = self.norm(x)
         context = self.context_norm(context)
 
+        # derive queries, keys, values
+
         q = self.to_q(x)
         k_input, v_input = self.to_kv(x).chunk(2, dim = -1)
 
@@ -137,9 +171,11 @@ def forward(self, x, context, context_mask = None, rotary_pos_emb = None):
 
         q = q * self.scale
 
+        # rotate queries and keys with rotary embeddings
+
         if exists(rotary_pos_emb):
-            q = apply_rotary_pos_emb(rotary_pos_emb, q)
-            k = apply_rotary_pos_emb(rotary_pos_emb, k)
+            q = apply_rotary_pos_emb(q_rotary_pos_emb, q)
+            k = apply_rotary_pos_emb(k_rotary_pos_emb, k)
 
         # take care of masking
 
@@ -195,6 +231,7 @@ def __init__(
         dim_head = 64,
         heads = 8,
         dropout = 0.,
+        cross_attn_dropout = 0.,
         ff_mult = 4,
         perceive_depth = 1,
         perceive_max_heads_process = 2 # processes the heads in the perceiver layer in chunks to lower peak memory, in the case the prefix is really long
@@ -213,7 +250,7 @@ def __init__(
 
         for _ in range(perceive_depth):
             self.perceive_layers.append(nn.ModuleList([
-                CausalPrefixAttention(dim = dim, dim_head = dim_head, heads = heads, max_heads_process = perceive_max_heads_process, dropout = dropout),
+                CausalPrefixAttention(dim = dim, dim_head = dim_head, heads = heads, max_heads_process = perceive_max_heads_process, dropout = dropout, cross_attn_dropout = cross_attn_dropout),
                 FeedForward(dim, mult = ff_mult, dropout = dropout)
             ]))
 
diff --git a/setup.py b/setup.py
index 5803de6..1433117 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'perceiver-ar-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.6',
+  version = '0.0.7',
   license='MIT',
   description = 'Perceiver AR',
   author = 'Phil Wang',
diff --git a/train.py b/train.py
index 49b91d6..517fbc9 100644
--- a/train.py
+++ b/train.py
@@ -46,6 +46,7 @@ def decode_tokens(tokens):
     depth = 8,
     heads = 8,
     dim_head = 64,
+    cross_attn_dropout = 0.5,
     max_seq_len = SEQ_LEN,
     cross_attn_seq_len = PREFIX_SEQ_LEN
 )
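
Note on the cross attention dropout added above: the keep-mask construction can be exercised on its own. The following is a minimal standalone sketch with made-up toy shapes (illustrative only, not part of the patch); it mirrors the topk / scatter_ / rearrange sequence the patch adds to CausalPrefixAttention.forward.

    import torch
    from einops import rearrange

    # toy sizes, chosen only for illustration
    batch, context_len, dim = 2, 8, 4
    cross_attn_dropout = 0.5

    context = torch.randn(batch, context_len, dim)

    # rank uniform noise per prefix position and keep the top-scoring ones,
    # so every batch element keeps the same number of randomly chosen positions
    rand = torch.zeros((batch, context_len)).uniform_()
    keep_context_len = context_len - int(context_len * cross_attn_dropout)
    keep_indices = rand.topk(keep_context_len, dim = -1).indices
    keep_mask = torch.zeros_like(rand).scatter_(1, keep_indices, 1).bool()

    # boolean indexing flattens the kept positions in batch order;
    # regrouping works because the keep count per batch element is constant
    context = rearrange(context[keep_mask], '(b n) d -> b n d', b = batch)

    assert context.shape == (2, 4, 4)

Because each batch element keeps exactly keep_context_len positions, the pruned prefix stays a dense rectangular tensor, and the rotary embeddings for the kept keys can be gathered with the same keep_mask, which is why the patch first repeats the rotary embeddings across the batch before slicing out the context portion.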