From d86231b014f60fdb11149bbde5b3597440bfe5bf Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 10 Nov 2020 23:21:09 -0800
Subject: [PATCH] make performer work as encoder / decoder with cross attention

---
 README.md                              | 33 ++++++++++++++++++++
 performer_pytorch/performer_pytorch.py | 43 +++++++++++++++++++-------
 setup.py                               |  2 +-
 3 files changed, 66 insertions(+), 12 deletions(-)

diff --git a/README.md b/README.md
index 557c15c..03f89ea 100644
--- a/README.md
+++ b/README.md
@@ -66,6 +66,39 @@ x = torch.randn(1, 2048, 512)
 model(x) # (1, 2048, 512)
 ```
 
+Full encoder / decoder
+
+```python
+import torch
+from performer_pytorch import PerformerLM
+
+enc = PerformerLM(
+    num_tokens = 20000,
+    max_seq_len = 2048,
+    dim = 512,
+    depth = 6,
+    heads = 8
+).cuda()
+
+dec = PerformerLM(
+    num_tokens = 20000,
+    max_seq_len = 2048,
+    dim = 512,
+    depth = 6,
+    heads = 8,
+    causal = True,
+    cross_attend = True
+).cuda()
+
+src = torch.randint(0, 20000, (1, 2048)).cuda()
+tgt = torch.randint(0, 20000, (1, 2048)).cuda()
+src_mask = torch.ones_like(src).bool()
+tgt_mask = torch.ones_like(tgt).bool()
+
+encodings = enc(src, mask = src_mask, return_encodings = True)
+logits = dec(tgt, context = encodings, mask = tgt_mask, context_mask = src_mask) # (1, 2048, 20000)
+```
+
 Standalone self-attention layer with linear complexity in respect to sequence length, for replacing trained full-attention transformer self-attention layers.
 
 ```python
diff --git a/performer_pytorch/performer_pytorch.py b/performer_pytorch/performer_pytorch.py
index 70430ca..ff942a2 100644
--- a/performer_pytorch/performer_pytorch.py
+++ b/performer_pytorch/performer_pytorch.py
@@ -269,24 +269,30 @@ def __init__(self, dim, causal = False, heads = 8, local_heads = 0, local_window
         self.to_out = nn.Linear(dim, dim)
         self.dropout = nn.Dropout(dropout)
 
-    def forward(self, x, mask = None):
+    def forward(self, x, context = None, mask = None, context_mask = None):
         b, n, _, h, gh = *x.shape, self.heads, self.global_heads
 
-        qkv = map(lambda fn: fn(x), (self.to_q, self.to_k, self.to_v))
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
+        cross_attend = exists(context)
+        context = default(context, x)
+        context_mask = default(context_mask, mask)
+
+        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
         (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))
 
         attn_outs = []
 
         if not empty(q):
-            if exists(mask):
-                global_mask = mask[:, None, :, None]
+            if exists(context_mask):
+                global_mask = context_mask[:, None, :, None]
                 k.masked_fill_(~global_mask, 0)
 
             out = self.fast_attention(q, k, v)
             attn_outs.append(out)
 
         if not empty(lq):
+            assert not cross_attend, 'local attention is not compatible with cross attention'
             out = self.local_attn(lq, lk, lv, input_mask = mask)
             attn_outs.append(out)
@@ -296,7 +302,7 @@ def forward(self, x, mask = None):
         return self.dropout(out)
 
 class Performer(nn.Module):
-    def __init__(self, dim, depth, heads, local_attn_heads = 0, local_window_size = 256, causal = False, ff_mult = 4, nb_features = None, reversible = False, ff_chunks = 1, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, use_scalenorm = False, use_rezero = False, ff_glu = False, ff_dropout = 0., attn_dropout = 0.):
+    def __init__(self, dim, depth, heads, local_attn_heads = 0, local_window_size = 256, causal = False, ff_mult = 4, nb_features = None, reversible = False, ff_chunks = 1, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, use_scalenorm = False, use_rezero = False, ff_glu = False, ff_dropout = 0., attn_dropout = 0., cross_attend = False):
         super().__init__()
         layers = nn.ModuleList([])
         local_attn_heads = cast_tuple(local_attn_heads)
@@ -317,16 +323,27 @@ def __init__(self, dim, depth, heads, local_attn_heads = 0, local_window_size =
                 wrapper_fn(Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1))
             ]))
 
+            if not cross_attend:
+                continue
+
+            layers.append(nn.ModuleList([
+                wrapper_fn(SelfAttention(dim, heads = heads, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, qr_uniform_q = qr_uniform_q, dropout = attn_dropout)),
+                wrapper_fn(Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1))
+            ]))
+
         execute_type = ReversibleSequence if reversible else SequentialSequence
-        route_attn = ((True, False),) * depth
+
+        route_attn = ((True, False),) * depth * (2 if cross_attend else 1)
+        route_context = ((False, False), (True, False)) * depth
         attn_route_map = {'mask': route_attn}
-        self.net = execute_type(layers, args_route = {**attn_route_map})
+        context_route_map = {'context': route_context, 'context_mask': route_context} if cross_attend else {}
+        self.net = execute_type(layers, args_route = {**attn_route_map, **context_route_map})
 
     def forward(self, x, **kwargs):
         return self.net(x, **kwargs)
 
 class PerformerLM(nn.Module):
-    def __init__(self, *, num_tokens, max_seq_len, dim, depth, heads, local_attn_heads = 0, local_window_size = 256, causal = False, ff_mult = 4, nb_features = None, reversible = False, ff_chunks = 1, ff_glu = False, emb_dropout = 0., ff_dropout = 0., attn_dropout = 0., generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, use_scalenorm = False, use_rezero = False):
+    def __init__(self, *, num_tokens, max_seq_len, dim, depth, heads, local_attn_heads = 0, local_window_size = 256, causal = False, ff_mult = 4, nb_features = None, reversible = False, ff_chunks = 1, ff_glu = False, emb_dropout = 0., ff_dropout = 0., attn_dropout = 0., generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, use_scalenorm = False, use_rezero = False, cross_attend = False):
         super().__init__()
         local_attn_heads = cast_tuple(local_attn_heads)
 
@@ -338,7 +355,7 @@ def __init__(self, *, num_tokens, max_seq_len, dim, depth, heads, local_attn_hea
         nn.init.normal_(self.token_emb.weight, std = 0.02)
         nn.init.normal_(self.pos_emb.weight, std = 0.02)
 
-        self.performer = Performer(dim, depth, heads, local_attn_heads, local_window_size, causal, ff_mult, nb_features, reversible, ff_chunks, generalized_attention, kernel_fn, qr_uniform_q, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout)
+        self.performer = Performer(dim, depth, heads, local_attn_heads, local_window_size, causal, ff_mult, nb_features, reversible, ff_chunks, generalized_attention, kernel_fn, qr_uniform_q, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend)
         self.norm = nn.LayerNorm(dim)
 
     def fix_projection_matrices_(self):
@@ -347,7 +364,7 @@ def __init__(self, *, num_tokens, max_seq_len, dim, depth, heads, local_attn_hea
         for fast_attention in fast_attentions:
             fast_attention.set_projection_matrix(device)
 
-    def forward(self, x, **kwargs):
+    def forward(self, x, return_encodings = False, **kwargs):
         b, n, device = *x.shape, x.device
         # token and positional embeddings
         x = self.token_emb(x)
@@ -359,4 +376,8 @@ def forward(self, x, **kwargs):
 
         # norm and to logits
         x = self.norm(x)
+
+        if return_encodings:
+            return x
+
         return x @ self.token_emb.weight.t()
diff --git a/setup.py b/setup.py
index 6759c27..726c724 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'performer-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.8.1',
+  version = '0.9.0',
   license='MIT',
   description = 'Performer - Pytorch',
   author = 'Phil Wang',
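
Not part of the patch, but a quick illustration of what it enables: the sketch below wires the new `cross_attend`, `return_encodings`, `context` and `context_mask` API (the same one shown in the README hunk above) into a single teacher-forced training step. Model sizes, batch shape, optimizer and loss are arbitrary choices for the sketch, not library defaults.

```python
# Illustrative only: one encoder / decoder training step using the API added in this patch.
import torch
import torch.nn.functional as F
from performer_pytorch import PerformerLM

# small models so the sketch runs quickly on CPU; dimensions are arbitrary
enc = PerformerLM(num_tokens = 256, max_seq_len = 64, dim = 128, depth = 2, heads = 4)
dec = PerformerLM(num_tokens = 256, max_seq_len = 64, dim = 128, depth = 2, heads = 4, causal = True, cross_attend = True)

opt = torch.optim.Adam([*enc.parameters(), *dec.parameters()], lr = 1e-4)

src = torch.randint(0, 256, (2, 64))      # source token ids
tgt = torch.randint(0, 256, (2, 64))      # target token ids
src_mask = torch.ones_like(src).bool()    # True = keep position
tgt_mask = torch.ones_like(tgt).bool()

# encode the source, then decode the target with cross attention over the encodings
encodings = enc(src, mask = src_mask, return_encodings = True)                    # (2, 64, 128)
logits = dec(tgt, context = encodings, mask = tgt_mask, context_mask = src_mask)  # (2, 64, 256)

# teacher forcing: predict token t+1 from tokens <= t
loss = F.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)), tgt[:, 1:].reshape(-1))
loss.backward()
opt.step()
opt.zero_grad()
```

Note how the routing works in `Performer`: `route_context = ((False, False), (True, False)) * depth` sends `context` and `context_mask` only to every second attention block, i.e. the interleaved cross-attention layers appended when `cross_attend = True`, while `mask` is routed to every attention block via `route_attn`.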