From 72ee8da87393ca00f4215362612b101324c4286d Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 25 Apr 2023 13:26:29 -0700
Subject: [PATCH] labels can be directly passed in, if training encoder

---
 .../recurrent_memory_transformer.py | 3 ++-
 setup.py                            | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py b/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py
index 2e1fc2f..9a955b1 100644
--- a/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py
+++ b/recurrent_memory_transformer_pytorch/recurrent_memory_transformer.py
@@ -344,13 +344,14 @@ def forward(
         *,
         mask = None,
         return_loss = False,
+        labels = None,
         memory_replay_backprop = False, # whether to have the class do the backwards pass memory efficiently
         mrbp_loss_weight = 1.           # if using memory replay backprop with gradient accumulation, scale loss by this factor ex. (1. / <num grad accum steps>)
     ):
         seq_len = self.seq_len

         labels = None
-        if return_loss or memory_replay_backprop:
+        if (return_loss or memory_replay_backprop) and not exists(labels):
             x, labels = x[:, :-1], x[:, 1:]

         # segment input
diff --git a/setup.py b/setup.py
index 8ed9a68..cabbf70 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'recurrent-memory-transformer-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.0',
+  version = '0.1.1',
   license='MIT',
   description = 'Recurrent Memory Transformer - Pytorch',
   author = 'Phil Wang',
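
Usage sketch (not part of the patch): with the new labels keyword, targets can be
passed to forward() directly instead of being derived by shifting the input, e.g.
when training the transformer as an encoder. The constructor arguments and the
(loss, memories) return shape shown below follow the library's typical README-style
usage and are assumptions for illustration, not taken from this patch.

    import torch
    from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

    # assumed hyperparameters, chosen only for this sketch
    model = RecurrentMemoryTransformer(
        num_tokens = 256,           # vocabulary size
        dim = 512,                  # model dimension
        depth = 2,                  # number of transformer layers
        seq_len = 1024,             # segment length
        num_memory_tokens = 128     # memory tokens carried across segments
    )

    x = torch.randint(0, 256, (1, 1024))       # input token ids
    labels = torch.randint(0, 256, (1, 1024))  # externally supplied targets

    # when labels are given, the patch's "not exists(labels)" guard is meant to
    # skip the input-shifting step and compute the loss against the passed labels
    loss, memories = model(x, labels = labels, return_loss = True)
    loss.backward()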