Commit
complete cross attention dropout logic
lucidrains committed Jun 22, 2022
1 parent be37653 commit 285e61b
Showing 4 changed files with 45 additions and 6 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -32,6 +32,7 @@ model = PerceiverAR(
heads = 8, # attention heads
max_seq_len = 4096, # total max sequence length
cross_attn_seq_len = 3072, # the prefix length that is attended to via cross attention but does not undergo self attention (must be less than max_seq_len)
cross_attn_dropout = 0.5, # the fraction of the prefix to drop out during training; the paper's experiments showed dropout of up to 50% helps prevent overfitting
)

x = torch.randint(0, 20000, (1, 4096))
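For orientation, a minimal usage sketch corresponding to the README snippet above; the constructor arguments above the hunk (num_tokens, dim, depth, dim_head) and the logits shape comment are assumptions filled in for illustration, not part of this diff:

import torch
from perceiver_ar_pytorch import PerceiverAR

model = PerceiverAR(
    num_tokens = 20000,           # vocabulary size (assumed)
    dim = 512,                    # model dimension (assumed)
    depth = 8,                    # number of self attention layers (assumed)
    dim_head = 64,                # attention head dimension
    heads = 8,                    # attention heads
    max_seq_len = 4096,           # total max sequence length
    cross_attn_seq_len = 3072,    # prefix length that is only cross-attended to
    cross_attn_dropout = 0.5      # fraction of the prefix randomly dropped during training
)

x = torch.randint(0, 20000, (1, 4096))
logits = model(x)                 # expected shape (1, 1024, 20000): logits only for the 4096 - 3072 = 1024 non-prefix positions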
47 changes: 42 additions & 5 deletions perceiver_ar_pytorch/perceiver_ar_pytorch.py
@@ -2,7 +2,7 @@
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange
from einops import rearrange, repeat

# helper functions

@@ -104,7 +104,8 @@ def __init__(
dim_head = 64,
heads = 8,
max_heads_process = 2,
dropout = 0.
dropout = 0.,
cross_attn_dropout = 0.
):
super().__init__()
self.scale = dim_head ** -0.5
@@ -117,14 +118,47 @@ def __init__(
self.context_norm = nn.LayerNorm(dim)
self.dropout = nn.Dropout(dropout)

self.cross_attn_dropout = cross_attn_dropout # they drop out a percentage of the prefix during training, shown to help prevent overfitting

self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim)

def forward(self, x, context, context_mask = None, rotary_pos_emb = None):
batch, context_len, device = x.shape[0], context.shape[-2], x.device

q_rotary_pos_emb = rotary_pos_emb
k_rotary_pos_emb = rotary_pos_emb

# take care of cross attention dropout

if self.training and self.cross_attn_dropout > 0.:
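# sample uniform noise and keep the top-k prefix positions per row, so every batch row retains
# the same number of tokens and the boolean-masked tensors below can be reshaped back to a batch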
rand = torch.zeros((batch, context_len), device = device).uniform_()
keep_context_len = context_len - int(context_len * self.cross_attn_dropout)
keep_indices = rand.topk(keep_context_len, dim = -1).indices
keep_mask = torch.zeros_like(rand).scatter_(1, keep_indices, 1).bool()

context = rearrange(context[keep_mask], '(b n) d -> b n d', b = batch)

if exists(context_mask):
context_mask = rearrange(context_mask[keep_mask], '(b n) -> b n', b = batch)

# filter the keys' rotary position embeddings with the same keep mask, so the surviving prefix tokens retain their original positions (query rotary embeddings are left untouched)

k_rotary_pos_emb = repeat(k_rotary_pos_emb, '... -> b ...', b = batch)
k_rotary_pos_emb_context, k_rotary_pos_emb_seq = k_rotary_pos_emb[:, :context_len], k_rotary_pos_emb[:, context_len:]
k_rotary_pos_emb_context = rearrange(k_rotary_pos_emb_context[keep_mask], '(b n) d -> b n d', b = batch)

k_rotary_pos_emb = torch.cat((k_rotary_pos_emb_context, k_rotary_pos_emb_seq), dim = 1)
k_rotary_pos_emb = rearrange(k_rotary_pos_emb, 'b n d -> b 1 n d')

# normalization

x = self.norm(x)
context = self.context_norm(context)

# derive queries, keys, values

q = self.to_q(x)

k_input, v_input = self.to_kv(x).chunk(2, dim = -1)
@@ -137,9 +171,11 @@ def forward(self, x, context, context_mask = None, rotary_pos_emb = None):

q = q * self.scale

# rotate queries and keys with rotary embeddings

if exists(rotary_pos_emb):
q = apply_rotary_pos_emb(rotary_pos_emb, q)
k = apply_rotary_pos_emb(rotary_pos_emb, k)
q = apply_rotary_pos_emb(q_rotary_pos_emb, q)
k = apply_rotary_pos_emb(k_rotary_pos_emb, k)

# take care of masking

@@ -195,6 +231,7 @@ def __init__(
dim_head = 64,
heads = 8,
dropout = 0.,
cross_attn_dropout = 0.,
ff_mult = 4,
perceive_depth = 1,
perceive_max_heads_process = 2 # processes the heads in the perceiver layer in chunks to lower peak memory, in case the prefix is very long
@@ -213,7 +250,7 @@ def __init__(

for _ in range(perceive_depth):
self.perceive_layers.append(nn.ModuleList([
CausalPrefixAttention(dim = dim, dim_head = dim_head, heads = heads, max_heads_process = perceive_max_heads_process, dropout = dropout),
CausalPrefixAttention(dim = dim, dim_head = dim_head, heads = heads, max_heads_process = perceive_max_heads_process, dropout = dropout, cross_attn_dropout = cross_attn_dropout),
FeedForward(dim, mult = ff_mult, dropout = dropout)
]))

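To make the new logic easier to follow in isolation, here is a self-contained sketch of the keep-mask trick used in CausalPrefixAttention.forward above; the function name dropout_prefix and the toy tensor shapes are illustrative only:

import torch
from einops import rearrange

def dropout_prefix(context, cross_attn_dropout = 0.5):
    # context: (batch, context_len, dim) prefix token embeddings
    batch, context_len, _ = context.shape

    # uniform noise + topk keeps exactly the same number of positions in every batch row
    rand = torch.zeros((batch, context_len), device = context.device).uniform_()
    keep_context_len = context_len - int(context_len * cross_attn_dropout)
    keep_indices = rand.topk(keep_context_len, dim = -1).indices
    keep_mask = torch.zeros_like(rand).scatter_(1, keep_indices, 1).bool()

    # boolean indexing flattens batch and sequence; because every row keeps the same
    # number of positions, rearrange can restore the batch dimension
    return rearrange(context[keep_mask], '(b n) d -> b n d', b = batch)

context = torch.randn(2, 8, 16)
print(dropout_prefix(context).shape)   # torch.Size([2, 4, 16])

Boolean indexing preserves the within-row order of the kept positions, which is why the same mask can also be applied to the keys' rotary position embeddings and the surviving prefix tokens still line up with their original absolute positions.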
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'perceiver-ar-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.6',
version = '0.0.7',
license='MIT',
description = 'Perceiver AR',
author = 'Phil Wang',
1 change: 1 addition & 0 deletions train.py
@@ -46,6 +46,7 @@ def decode_tokens(tokens):
depth = 8,
heads = 8,
dim_head = 64,
cross_attn_dropout = 0.5,
max_seq_len = SEQ_LEN,
cross_attn_seq_len = PREFIX_SEQ_LEN
)
