Skip to content

Commit

Permalink
Fix bug in evaluation code
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Mar 20, 2024
1 parent 6696431 commit 8c42bc8
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions examples/glue/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from accelerate.utils import set_seed
from torch import nn
from torch.utils import data
from tqdm import tqdm
from transformers import default_data_collator

from examples.glue.pipeline import construct_bert, get_glue_dataset
Expand Down Expand Up @@ -129,9 +128,9 @@ def evaluate_model(model: nn.Module, dataset: data.Dataset, batch_size: int) ->
for batch in dataloader:
with torch.no_grad():
logits = model(
batch["input_ids"].to(device=DEVICE),
batch["token_type_ids"].to(device=DEVICE),
batch["attention_mask"].to(device=DEVICE),
input_ids=batch["input_ids"].to(device=DEVICE),
attention_mask=batch["attention_mask"].to(device=DEVICE),
token_type_ids=batch["token_type_ids"].to(device=DEVICE),
).logits
labels = batch["labels"].to(device=DEVICE)
total_loss += F.cross_entropy(logits, labels, reduction="sum").detach()
Expand Down

0 comments on commit 8c42bc8

Please sign in to comment.