From c1291eb0240c11c9fd727064d35d1a4caaf3d493 Mon Sep 17 00:00:00 2001
From: Juhan Bae
Date: Wed, 20 Mar 2024 03:56:30 -0400
Subject: [PATCH] Fix GLUE not training

---
 DOCUMENTATION.md       | 3 +--
 examples/glue/train.py | 3 ++-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md
index 08c8d92..91aea41 100644
--- a/DOCUMENTATION.md
+++ b/DOCUMENTATION.md
@@ -355,8 +355,7 @@ score_args = ScoreArguments(
 
 all modules, this will keep track of intermediate module-wise scores.
 - `query_gradient_rank`: The rank for the query batching. If `None`, no query batching will be used.
-- `query_gradient_svd_dtype`: `dtype` for performing singular value decomposition (SVD) for query batch. You can use `torch.float32`,
-but `torch.float64` is recommended.
+- `query_gradient_svd_dtype`: `dtype` for performing singular value decomposition (SVD) for query batch. You can also use `torch.float32`.
 - `cached_activation_cpu_offload`: Whether to offload cached activations to CPU.
 - `score_dtype`: `dtype` for computing influence scores. You can use `torch.bfloat16` or `torch.float16`.
 
diff --git a/examples/glue/train.py b/examples/glue/train.py
index cfff139..d277671 100644
--- a/examples/glue/train.py
+++ b/examples/glue/train.py
@@ -104,16 +104,17 @@ def train(
     for epoch in range(num_train_epochs):
         total_loss = 0.0
         for batch in train_dataloader:
-            model.zero_grad()
             loss = model(
                 input_ids=batch["input_ids"].to(device=DEVICE),
                 attention_mask=batch["attention_mask"].to(device=DEVICE),
                 token_type_ids=batch["token_type_ids"].to(device=DEVICE),
                 labels=batch["labels"].to(device=DEVICE)
             )
+            optimizer.zero_grad()
             loss.backward()
             optimizer.step()
             total_loss += loss.detach().float()
+        logging.info(f"Epoch {epoch + 1} - Averaged Loss: {total_loss / len(dataset)}")
 
     return model
 