diff --git a/examples/glue/train.py b/examples/glue/train.py
index b61428a..9a57510 100644
--- a/examples/glue/train.py
+++ b/examples/glue/train.py
@@ -9,7 +9,6 @@
 from accelerate.utils import set_seed
 from torch import nn
 from torch.utils import data
-from tqdm import tqdm
 from transformers import default_data_collator
 
 from examples.glue.pipeline import construct_bert, get_glue_dataset
@@ -129,9 +128,9 @@ def evaluate_model(model: nn.Module, dataset: data.Dataset, batch_size: int) ->
     for batch in dataloader:
         with torch.no_grad():
             logits = model(
-                batch["input_ids"].to(device=DEVICE),
-                batch["token_type_ids"].to(device=DEVICE),
-                batch["attention_mask"].to(device=DEVICE),
+                input_ids=batch["input_ids"].to(device=DEVICE),
+                attention_mask=batch["attention_mask"].to(device=DEVICE),
+                token_type_ids=batch["token_type_ids"].to(device=DEVICE),
             ).logits
             labels = batch["labels"].to(device=DEVICE)
             total_loss += F.cross_entropy(logits, labels, reduction="sum").detach()
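
The switch to keyword arguments matters because `BertForSequenceClassification.forward` accepts `input_ids`, `attention_mask`, and `token_type_ids` in that positional order, so the old positional call handed `token_type_ids` to the `attention_mask` parameter. Below is a minimal sketch of the corrected call pattern, assuming a stock `bert-base-uncased` checkpoint and tokenizer rather than the example's own `construct_bert` / `get_glue_dataset` pipeline:

```python
import torch
from transformers import AutoTokenizer, BertForSequenceClassification

# Hypothetical standalone setup; the real example builds its model via construct_bert.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
model.eval()

batch = tokenizer(["a short example sentence"], return_tensors="pt")

with torch.no_grad():
    # Keyword arguments make the tensor-to-parameter mapping explicit and
    # independent of the order in which the tensors are listed.
    logits = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        token_type_ids=batch["token_type_ids"],
    ).logits
```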