From 59d786a1337f467096ec00f39c143b978850c387 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Wed, 20 Mar 2024 13:54:41 -0400 Subject: [PATCH] minor --- examples/wikitext/train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/wikitext/train.py b/examples/wikitext/train.py index 9f47cc5..113317c 100644 --- a/examples/wikitext/train.py +++ b/examples/wikitext/train.py @@ -123,6 +123,7 @@ def evaluate_model(model: nn.Module, dataset: data.Dataset, batch_size: int) -> model.eval() total_loss = 0.0 + total_num = 0 for batch in dataloader: with torch.no_grad(): lm_logits = model( @@ -133,8 +134,9 @@ def evaluate_model(model: nn.Module, dataset: data.Dataset, batch_size: int) -> shift_logits = lm_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() total_loss += F.cross_entropy( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), reduction="sum", ).detach() + total_num += shift_labels.view(-1).shape[0] return total_loss.item() / len(dataloader.dataset)