diff --git a/performer_pytorch/__init__.py b/performer_pytorch/__init__.py
index f13b4b2..1ef16dc 100644
--- a/performer_pytorch/__init__.py
+++ b/performer_pytorch/__init__.py
@@ -1,3 +1,3 @@
-from performer_pytorch.performer_pytorch import PerformerLM, Performer, FastAttention, SelfAttention, CrossAttention
+from performer_pytorch.performer_pytorch import PerformerLM, Performer, FastAttention, SelfAttention, CrossAttention, ProjectionUpdater
 from performer_pytorch.autoregressive_wrapper import AutoregressiveWrapper
 from performer_pytorch.performer_enc_dec import PerformerEncDec
diff --git a/performer_pytorch/performer_pytorch.py b/performer_pytorch/performer_pytorch.py
index 21848d3..32c8218 100644
--- a/performer_pytorch/performer_pytorch.py
+++ b/performer_pytorch/performer_pytorch.py
@@ -238,6 +238,36 @@ def forward(self, q, k, v):
         out = attn_fn(q, k, v)
         return out
 
+# a module for keeping track of when to update the projections
+
+class ProjectionUpdater(nn.Module):
+    def __init__(self, instance, feature_redraw_interval):
+        super().__init__()
+        self.instance = instance
+        self.feature_redraw_interval = feature_redraw_interval
+        self.register_buffer('calls_since_last_redraw', torch.tensor(0))
+
+    def redraw_projections(self):
+        model = self.instance
+
+        if not self.training:
+            return
+
+        if exists(self.feature_redraw_interval) and self.calls_since_last_redraw >= self.feature_redraw_interval:
+            device = get_module_device(model)
+
+            fast_attentions = find_modules(model, FastAttention)
+            for fast_attention in fast_attentions:
+                fast_attention.redraw_projection_matrix(device)
+
+            self.calls_since_last_redraw.zero_()
+            return
+
+        self.calls_since_last_redraw += 1
+
+    def forward(self, x):
+        raise NotImplemented
+
 # classes
 
 class ReZero(nn.Module):
@@ -360,7 +390,7 @@ def forward(self, x, pos_emb = None, context = None, mask = None, context_mask =
                 v.masked_fill_(~global_mask, 0.)
 
             if exists(pos_emb) and not cross_attend:
-                q, k, = apply_rotary_pos_emb(q, k, pos_emb)
+                q, k = apply_rotary_pos_emb(q, k, pos_emb)
 
             out = self.fast_attention(q, k, v)
             attn_outs.append(out)
@@ -493,31 +523,14 @@ def __init__(
         # keeping track of when to redraw projections for all attention layers
 
         self.auto_check_redraw = auto_check_redraw
-        self.feature_redraw_interval = feature_redraw_interval
-        self.register_buffer('calls_since_last_redraw', torch.tensor(0))
+        self.proj_updater = ProjectionUpdater(self.net, feature_redraw_interval)
 
     def fix_projection_matrices_(self):
-        self.feature_redraw_interval = None
-
-    def check_redraw_projections(self):
-        if not self.training:
-            return
-
-        if exists(self.feature_redraw_interval) and self.calls_since_last_redraw >= self.feature_redraw_interval:
-            device = get_module_device(self)
-
-            fast_attentions = find_modules(self, FastAttention)
-            for fast_attention in fast_attentions:
-                fast_attention.redraw_projection_matrix(device)
-
-            self.calls_since_last_redraw.zero_()
-            return
-
-        self.calls_since_last_redraw += 1
+        self.proj_updater.feature_redraw_interval = None
 
     def forward(self, x, **kwargs):
         if self.auto_check_redraw:
-            self.check_redraw_projections()
+            self.proj_updater.redraw_projections()
         return self.net(x, **kwargs)
 
 class PerformerLM(nn.Module):
diff --git a/setup.py b/setup.py
index 7277b5a..e1990cd 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'performer-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.0.8',
+  version = '1.0.9',
   license='MIT',
   description = 'Performer - Pytorch',
   author = 'Phil Wang',
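
Not part of the patch above: a minimal usage sketch of the newly exported ProjectionUpdater, assuming the SelfAttention constructor arguments documented in the project README (dim, heads, causal) and an arbitrary interval of 1000. Inside Performer this wiring now happens automatically when auto_check_redraw is set; the sketch only illustrates wrapping a custom stack of attention layers by hand.

# hypothetical usage sketch, not from the patch: redraw the random-feature
# projections of any module containing FastAttention every N training calls
import torch
from torch import nn
from performer_pytorch import SelfAttention, ProjectionUpdater

# any container holding FastAttention submodules; SelfAttention uses one internally
model = nn.Sequential(
    SelfAttention(dim = 512, heads = 8, causal = False),
    SelfAttention(dim = 512, heads = 8, causal = False)
)

proj_updater = ProjectionUpdater(model, feature_redraw_interval = 1000)
proj_updater.train()  # redraw_projections is a no-op in eval mode

x = torch.randn(1, 1024, 512)

for _ in range(3):  # stand-in for a training loop
    proj_updater.redraw_projections()  # same call Performer.forward now makes
    out = model(x)

# mirror Performer.fix_projection_matrices_ to stop redrawing permanently
proj_updater.feature_redraw_interval = None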