diff --git a/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py b/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py
index f8ec0fd..4cde598 100644
--- a/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py
+++ b/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py
@@ -210,6 +210,7 @@ def __init__(
         enhanced_xl_recurrence = False,    # add simple method for enhancing receptive field of xl memories, from ernie-doc paper
         emb_gradient_frac = 0.1,           # trick from cogview paper that leads to a bit more stability
         memory_not_causal = True,          # flash attention behaves a bit more optimally if causal mask is not explicitly passed in - but if the memories perform better without a causal mask, it is necessary to have this turned on
+        resi_dual_scale = 1.,              # in the case of overflows in fp16 on the prenorm branch, set this to a value less than 1.
     ):
         super().__init__()
         self.causal = causal
@@ -217,6 +218,9 @@ def __init__(

         self.emb_gradient_frac = emb_gradient_frac

+        assert 0 < resi_dual_scale <= 1., 'resiDual scale must be between 0 and 1'
+        self.resi_dual_scale = resi_dual_scale
+
         assert num_memory_tokens > 0

         self.token_emb = nn.Embedding(num_tokens, dim)
@@ -384,7 +388,7 @@ def forward(

         # attention and feedforward

-        residual = x
+        residual = x * self.resi_dual_scale

         for attn, attn_post_norm, ff, ff_post_norm in self.layers:
             attn_out, xl_memories = attn(shift_fn(x), mask = mask, xl_memories = next(xl_memories_iter, None), rotary_emb = rotary_emb)
@@ -392,13 +396,13 @@

             x = attn_post_norm(x + attn_out)

-            residual = residual + attn_out
+            residual = residual + attn_out * self.resi_dual_scale

             ff_out = ff(shift_fn(x))

             x = ff_post_norm(x + ff_out)

-            residual = residual + ff_out
+            residual = residual + ff_out * self.resi_dual_scale

         # whether to return xl memories

diff --git a/setup.py b/setup.py
index 854dfb9..46703a0 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'recurrent-memory-transformer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.4.0',
+  version = '0.4.1',
   license='MIT',
   description = 'Recurrent Memory Transformer - Pytorch',
   author = 'Phil Wang',
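
For reference, the new resi_dual_scale knob scales every contribution accumulated on the prenorm (resiDual) residual branch (the initial x, each attention output, and each feedforward output), so the running sum stays small enough not to overflow in fp16. Below is a minimal usage sketch, not something taken from this diff: aside from resi_dual_scale, num_tokens, dim, num_memory_tokens, and causal, which appear in the diffed __init__, the constructor arguments (depth, seq_len) and the example values are assumptions about the package's public API.

import torch
from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

# hypothetical hyperparameters for illustration; only resi_dual_scale, num_tokens,
# dim, num_memory_tokens and causal are visible in the diff above, the rest are assumed
model = RecurrentMemoryTransformer(
    num_tokens = 256,            # vocabulary size
    dim = 512,
    depth = 6,
    seq_len = 1024,
    num_memory_tokens = 128,
    causal = True,
    resi_dual_scale = 0.5        # values below 1. shrink the prenorm residual branch to avoid fp16 overflow
)

x = torch.randint(0, 256, (1, 1024))

out = model(x)                   # forward pass; the scaled resiDual branch is applied internally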