diff --git a/examples/wikitext/train.py b/examples/wikitext/train.py index 9f47cc5..113317c 100644 --- a/examples/wikitext/train.py +++ b/examples/wikitext/train.py @@ -123,6 +123,7 @@ def evaluate_model(model: nn.Module, dataset: data.Dataset, batch_size: int) -> model.eval() total_loss = 0.0 + total_num = 0 for batch in dataloader: with torch.no_grad(): lm_logits = model( @@ -133,8 +134,9 @@ def evaluate_model(model: nn.Module, dataset: data.Dataset, batch_size: int) -> shift_logits = lm_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() total_loss += F.cross_entropy( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), reduction="sum", ).detach() + total_num += shift_labels.view(-1).shape[0] return total_loss.item() / len(dataloader.dataset)