Commit 35cd18d: address #22

lucidrains committed Feb 11, 2024
1 parent d45ef72
Showing 2 changed files with 8 additions and 4 deletions.
10 changes: 7 additions & 3 deletions recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py
@@ -113,12 +113,13 @@ def forward(self, x):
         x, gate = x.chunk(2, dim = -1)
         return x * F.gelu(gate)
 
-def FeedForward(dim, mult = 4):
+def FeedForward(dim, mult = 4, dropout = 0.):
     dim_inner = int(dim * mult * 2 / 3)
     return nn.Sequential(
         Linear(dim, dim_inner * 2, bias = False),
         GEGLU(),
         RMSNorm(dim_inner),
+        nn.Dropout(dropout),
         Linear(dim_inner, dim, bias = False)
     )

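For context, the feedforward after this change assembles into the following self-contained sketch. The GEGLU and RMSNorm classes below are simplified stand-ins written for illustration (the repository defines its own), and plain nn.Linear stands in for the module's Linear:

import torch
from torch import nn
import torch.nn.functional as F

class GEGLU(nn.Module):
    # gated GELU: split the projection in half, gate one half with GELU of the other
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return x * F.gelu(gate)

class RMSNorm(nn.Module):
    # simplified RMS normalization stand-in (assumption, for illustration)
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale * self.gamma

def FeedForward(dim, mult = 4, dropout = 0.):
    # inner width follows the 2/3 scaling convention used with GLU variants
    dim_inner = int(dim * mult * 2 / 3)
    return nn.Sequential(
        nn.Linear(dim, dim_inner * 2, bias = False),  # doubled width, split by GEGLU
        GEGLU(),
        RMSNorm(dim_inner),
        nn.Dropout(dropout),  # the dropout introduced by this commit
        nn.Linear(dim_inner, dim, bias = False)
    )

ff = FeedForward(512, dropout = 0.1)
out = ff(torch.randn(2, 1024, 512))  # (batch, seq, dim) -> same shape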
@@ -217,6 +218,8 @@ def __init__(
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
+        attn_dropout = 0.,
+        ff_dropout = 0.,
         use_flash_attn = False,
         ignore_index = -1,
         abs_pos_emb = True,
@@ -286,10 +289,11 @@ def __init__(
                 causal = causal,
                 heads = heads,
                 use_flash_attn = use_flash_attn,
-                use_custom_causal_attn_mask = memory_not_causal
+                use_custom_causal_attn_mask = memory_not_causal,
+                dropout = attn_dropout
             ),
             RMSNorm(dim),
-            FeedForward(dim = dim, mult = ff_mult),
+            FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout),
             RMSNorm(dim)
         ]))

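With these changes the two dropout rates become constructor options, with attn_dropout routed into each Attention block and ff_dropout into each FeedForward. A hypothetical usage sketch follows; the remaining constructor arguments (num_tokens, dim, depth, seq_len, num_memory_tokens) and the forward return convention are taken from the repository README, so treat the exact values and shapes as assumptions:

import torch
from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

model = RecurrentMemoryTransformer(
    num_tokens = 20000,       # vocab size (per the README)
    dim = 512,
    depth = 6,
    seq_len = 1024,           # segment length the model attends over
    num_memory_tokens = 128,  # memories carried between segments
    attn_dropout = 0.1,       # new in this commit: dropout inside attention
    ff_dropout = 0.1          # new in this commit: dropout inside the feedforward
)

x = torch.randint(0, 20000, (1, 1024))

logits, mem, _ = model(x)       # per the README: logits plus memories for the next segment
logits, mem, _ = model(x, mem)  # pass memories back in to recur over segments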
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'recurrent-memory-transformer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.5.5',
+  version = '0.5.6',
   license='MIT',
   description = 'Recurrent Memory Transformer - Pytorch',
   author = 'Phil Wang',