Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Mar 20, 2024
1 parent 0b6811b commit 59d786a
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion examples/wikitext/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def evaluate_model(model: nn.Module, dataset: data.Dataset, batch_size: int) ->

model.eval()
total_loss = 0.0
total_num = 0
for batch in dataloader:
with torch.no_grad():
lm_logits = model(
Expand All @@ -133,8 +134,9 @@ def evaluate_model(model: nn.Module, dataset: data.Dataset, batch_size: int) ->
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
total_loss += F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), reduction="sum",
).detach()
total_num += shift_labels.view(-1).shape[0]
return total_loss.item() / len(dataloader.dataset)


Expand Down

0 comments on commit 59d786a

Please sign in to comment.