From 35cd18deeb7965491873fcba4a15d581106eae39 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sun, 11 Feb 2024 10:28:43 -0800
Subject: [PATCH] address https://github.com/lucidrains/recurrent-memory-transformer-pytorch/issues/22

---
 .../recurrent_memory_transformer.py | 10 +++++++---
 setup.py                            |  2 +-
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py b/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py
index 313b962..1af7741 100644
--- a/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py
+++ b/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py
@@ -113,12 +113,13 @@ def forward(self, x):
         x, gate = x.chunk(2, dim = -1)
         return x * F.gelu(gate)
 
-def FeedForward(dim, mult = 4):
+def FeedForward(dim, mult = 4, dropout = 0.):
     dim_inner = int(dim * mult * 2 / 3)
     return nn.Sequential(
         Linear(dim, dim_inner * 2, bias = False),
         GEGLU(),
         RMSNorm(dim_inner),
+        nn.Dropout(dropout),
         Linear(dim_inner, dim, bias = False)
     )
 
@@ -217,6 +218,8 @@ def __init__(
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
+        attn_dropout = 0.,
+        ff_dropout = 0.,
         use_flash_attn = False,
         ignore_index = -1,
         abs_pos_emb = True,
@@ -286,10 +289,11 @@ def __init__(
                 causal = causal,
                 heads = heads,
                 use_flash_attn = use_flash_attn,
-                use_custom_causal_attn_mask = memory_not_causal
+                use_custom_causal_attn_mask = memory_not_causal,
+                dropout = attn_dropout
             ),
             RMSNorm(dim),
-            FeedForward(dim = dim, mult = ff_mult),
+            FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout),
            RMSNorm(dim)
         ]))
 
diff --git a/setup.py b/setup.py
index 5189119..75360d6 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'recurrent-memory-transformer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.5.5',
+  version = '0.5.6',
   license='MIT',
   description = 'Recurrent Memory Transformer - Pytorch',
   author = 'Phil Wang',
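
For context, this patch threads two new keyword arguments, `attn_dropout` and `ff_dropout`, through the `RecurrentMemoryTransformer` constructor and inserts an `nn.Dropout` inside the GEGLU feedforward, addressing the linked issue. Below is a minimal usage sketch after the patch is applied; only the two dropout keywords come directly from this diff, the remaining constructor keywords (`num_tokens`, `num_memory_tokens`, `dim`, `depth`, `seq_len`, `causal`) are assumed from the repository's existing public API, and the forward return is left unpacked since its exact tuple structure is not part of this change.

```python
import torch

from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

# constructor keywords other than attn_dropout / ff_dropout are assumed from the
# repository's existing API; the two dropout arguments are the ones added by this patch
model = RecurrentMemoryTransformer(
    num_tokens = 20000,        # vocabulary size
    num_memory_tokens = 128,   # number of recurrent memory tokens
    dim = 512,
    depth = 6,
    causal = True,
    seq_len = 1024,            # segment length processed per recurrent step
    attn_dropout = 0.1,        # dropout passed into each Attention block (new)
    ff_dropout = 0.1           # nn.Dropout after the GEGLU projection (new)
)

x = torch.randint(0, 20000, (1, 1024))

out = model(x)  # logits plus recurrent memories; tuple structure per the library's API, not shown in this patch
```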