Commit
turns out resiDual is prone to overflowing in fp16. add the scaling solution proposed in the paper
lucidrains committed May 27, 2023
1 parent a8d4345 commit eee6c34
Showing 2 changed files with 8 additions and 4 deletions.
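
The issue being fixed: resiDual keeps a second, unnormalized residual stream alongside the post-norm stream, and because that branch sums raw block outputs with no norm in between, its magnitude grows with depth and can exceed the fp16 maximum (about 65504). Below is a minimal sketch of the failure mode and of the scaling mitigation; the depth and activation magnitudes are hypothetical, chosen purely for illustration.

import torch

depth = 48                # hypothetical depth, for illustration only
resi_dual_scale = 0.25    # a value < 1., as the new hyperparameter allows

plain = torch.zeros(1, dtype = torch.float16)
scaled = torch.zeros(1, dtype = torch.float16)

for _ in range(depth):
    # stand-in for one block's output with large activations
    block_out = torch.full((1,), 2000., dtype = torch.float16)

    plain = plain + block_out                      # unscaled prenorm branch
    scaled = scaled + block_out * resi_dual_scale  # scaled branch, as in this commit

print(plain)   # tensor([inf], dtype=torch.float16) - 48 * 2000 exceeds the fp16 max
print(scaled)  # finite, roughly 48 * 2000 * 0.25 ≈ 24000

Since the resiDual branch is normalized before it is used, a constant scale on it is largely absorbed by that norm, which is why values below 1. are a safe numerical fix rather than a change in what the network can represent.
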
@@ -210,13 +210,17 @@ def __init__(
         enhanced_xl_recurrence = False, # add simple method for enhancing receptive field of xl memories, from ernie-doc paper
         emb_gradient_frac = 0.1, # trick from cogview paper that leads to a bit more stability
         memory_not_causal = True, # flash attention behaves a bit more optimally if causal mask is not explicitly passed in - but if the memories perform better without a causal mask, it is necessary to have this turned on
+        resi_dual_scale = 1., # in the case of overflows in fp16 on the prenorm branch, set this to a value less than 1.
     ):
         super().__init__()
         self.causal = causal
         self.seq_len = seq_len
 
         self.emb_gradient_frac = emb_gradient_frac
 
+        assert 0 < resi_dual_scale <= 1., 'resiDual scale must be between 0 and 1'
+        self.resi_dual_scale = resi_dual_scale
+
         assert num_memory_tokens > 0
 
         self.token_emb = nn.Embedding(num_tokens, dim)
@@ -384,21 +388,21 @@ def forward(

         # attention and feedforward
 
-        residual = x
+        residual = x * self.resi_dual_scale
 
         for attn, attn_post_norm, ff, ff_post_norm in self.layers:
             attn_out, xl_memories = attn(shift_fn(x), mask = mask, xl_memories = next(xl_memories_iter, None), rotary_emb = rotary_emb)
             new_xl_memories.append(xl_memories)
 
             x = attn_post_norm(x + attn_out)
 
-            residual = residual + attn_out
+            residual = residual + attn_out * self.resi_dual_scale
 
             ff_out = ff(shift_fn(x))
 
             x = ff_post_norm(x + ff_out)
 
-            residual = residual + ff_out
+            residual = residual + ff_out * self.resi_dual_scale
 
         # whether to return xl memories

2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'recurrent-memory-transformer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.4.0',
+  version = '0.4.1',
   license='MIT',
   description = 'Recurrent Memory Transformer - Pytorch',
   author = 'Phil Wang',
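
For completeness, a hedged usage sketch of the new hyperparameter. Only resi_dual_scale is confirmed by this diff; the class name matches the package, but the remaining constructor arguments and the forward return values are assumptions based on the surrounding codebase rather than shown here.

import torch
from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

# sketch only - arguments other than resi_dual_scale are assumed, not shown in this diff
model = RecurrentMemoryTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 6,
    seq_len = 1024,
    num_memory_tokens = 128,
    resi_dual_scale = 0.25   # set below 1. when training in fp16 to keep the prenorm branch from overflowing
)

x = torch.randint(0, 256, (1, 1024))
out = model(x)   # returns logits along with the memories to carry into the next segment (exact signature not shown here)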
