From 67f2fdb7b71567ba74503553e577117cc9334756 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 25 Apr 2023 19:42:50 -0700 Subject: [PATCH] final tweak, so network can differentiate better between read and write memory --- .../recurrent_memory_transformer.py | 11 +++++++++-- setup.py | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py b/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py index e9b4ed9..8a8dcdf 100644 --- a/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py +++ b/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py @@ -208,6 +208,9 @@ def __init__( self.num_memory_tokens = num_memory_tokens + self.read_memory_emb = nn.Parameter(torch.zeros(dim)) + nn.init.normal_(self.read_memory_emb, std = 0.02) + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) nn.init.normal_(self.memory_tokens, std = 0.02) @@ -265,8 +268,12 @@ def forward( write_memories = self.init_memory(b) - read_memories = default(read_memories, x[:, 0:0]) - read_mem_length = read_memories.shape[-2] + if exists(read_memories): + read_mem_length = mem_length + read_memories = read_memories + self.read_memory_emb + else: + read_mem_length = 0 + read_memories = x[:, 0:0] # concat to main sequence using einop's pack diff --git a/setup.py b/setup.py index d315503..da95df3 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'recurrent-memory-transformer-pytorch', packages = find_packages(exclude=[]), - version = '0.1.6', + version = '0.1.7', license='MIT', description = 'Recurrent Memory Transformer - Pytorch', author = 'Phil Wang',