expose projection update logic, for easy use in alphafold2 project
lucidrains committed Apr 21, 2021
1 parent 9ae49af commit 7a68c9c
Showing 3 changed files with 36 additions and 23 deletions.
2 changes: 1 addition & 1 deletion performer_pytorch/__init__.py
@@ -1,3 +1,3 @@
-from performer_pytorch.performer_pytorch import PerformerLM, Performer, FastAttention, SelfAttention, CrossAttention
+from performer_pytorch.performer_pytorch import PerformerLM, Performer, FastAttention, SelfAttention, CrossAttention, ProjectionUpdater
 from performer_pytorch.autoregressive_wrapper import AutoregressiveWrapper
 from performer_pytorch.performer_enc_dec import PerformerEncDec
55 changes: 34 additions & 21 deletions performer_pytorch/performer_pytorch.py
@@ -238,6 +238,36 @@ def forward(self, q, k, v):
         out = attn_fn(q, k, v)
         return out
 
+# a module for keeping track of when to update the projections
+
+class ProjectionUpdater(nn.Module):
+    def __init__(self, instance, feature_redraw_interval):
+        super().__init__()
+        self.instance = instance
+        self.feature_redraw_interval = feature_redraw_interval
+        self.register_buffer('calls_since_last_redraw', torch.tensor(0))
+
+    def redraw_projections(self):
+        model = self.instance
+
+        if not self.training:
+            return
+
+        if exists(self.feature_redraw_interval) and self.calls_since_last_redraw >= self.feature_redraw_interval:
+            device = get_module_device(model)
+
+            fast_attentions = find_modules(model, FastAttention)
+            for fast_attention in fast_attentions:
+                fast_attention.redraw_projection_matrix(device)
+
+            self.calls_since_last_redraw.zero_()
+            return
+
+        self.calls_since_last_redraw += 1
+
+    def forward(self, x):
+        raise NotImplementedError
+
 # classes
 
 class ReZero(nn.Module):
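
The ProjectionUpdater added above is self-contained, so any module that holds FastAttention layers can be wrapped and have its random feature projections redrawn on a schedule outside of Performer itself, which is the point of exposing it for the alphafold2 project. A minimal usage sketch, assuming a single SelfAttention layer and a redraw interval of 10 chosen purely for illustration:

import torch
from performer_pytorch import SelfAttention, ProjectionUpdater

# any module containing FastAttention submodules can be wrapped;
# one performer self-attention layer stands in for a larger network here
attn = SelfAttention(dim = 512, heads = 8, causal = False)

# redraw the random feature projections every 10 tracked calls (while in train mode)
proj_updater = ProjectionUpdater(attn, feature_redraw_interval = 10)

x = torch.randn(1, 1024, 512)

for _ in range(20):
    proj_updater.redraw_projections()  # counts calls and redraws once the interval is reached
    out = attn(x)                      # (1, 1024, 512)
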
@@ -360,7 +390,7 @@ def forward(self, x, pos_emb = None, context = None, mask = None, context_mask =
                 v.masked_fill_(~global_mask, 0.)
 
             if exists(pos_emb) and not cross_attend:
-                q, k, = apply_rotary_pos_emb(q, k, pos_emb)
+                q, k = apply_rotary_pos_emb(q, k, pos_emb)
 
             out = self.fast_attention(q, k, v)
             attn_outs.append(out)
@@ -493,31 +523,14 @@ def __init__(

         # keeping track of when to redraw projections for all attention layers
         self.auto_check_redraw = auto_check_redraw
-        self.feature_redraw_interval = feature_redraw_interval
-        self.register_buffer('calls_since_last_redraw', torch.tensor(0))
+        self.proj_updater = ProjectionUpdater(self.net, feature_redraw_interval)
 
     def fix_projection_matrices_(self):
-        self.feature_redraw_interval = None
-
-    def check_redraw_projections(self):
-        if not self.training:
-            return
-
-        if exists(self.feature_redraw_interval) and self.calls_since_last_redraw >= self.feature_redraw_interval:
-            device = get_module_device(self)
-
-            fast_attentions = find_modules(self, FastAttention)
-            for fast_attention in fast_attentions:
-                fast_attention.redraw_projection_matrix(device)
-
-            self.calls_since_last_redraw.zero_()
-            return
-
-        self.calls_since_last_redraw += 1
+        self.proj_updater.feature_redraw_interval = None
 
     def forward(self, x, **kwargs):
         if self.auto_check_redraw:
-            self.check_redraw_projections()
+            self.proj_updater.redraw_projections()
         return self.net(x, **kwargs)
 
 class PerformerLM(nn.Module):
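
At the Performer level nothing changes for callers: with auto_check_redraw enabled, every forward pass delegates the redraw bookkeeping to the internal ProjectionUpdater, and fix_projection_matrices_ now simply clears the interval on that updater. A short sketch with placeholder hyperparameters:

import torch
from performer_pytorch import Performer

model = Performer(
    dim = 512,
    depth = 4,
    heads = 8,
    dim_head = 64,
    causal = False,
    feature_redraw_interval = 1000   # forwarded to the internal ProjectionUpdater
)

x = torch.randn(1, 1024, 512)
out = model(x)                   # the redraw check runs automatically inside forward

model.fix_projection_matrices_() # freeze projections by setting the interval to None
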
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'performer-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.0.8',
+  version = '1.0.9',
   license='MIT',
   description = 'Performer - Pytorch',
   author = 'Phil Wang',
