Skip to content

Commit

Permalink
Fix GLUE not training
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Mar 20, 2024
1 parent 11aff5d commit c1291eb
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
3 changes: 1 addition & 2 deletions DOCUMENTATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,7 @@ score_args = ScoreArguments(
all modules, this will keep track of intermediate module-wise scores.

- `query_gradient_rank`: The rank for the query batching. If `None`, no query batching will be used.
- `query_gradient_svd_dtype`: `dtype` for performing singular value decomposition (SVD) for query batch. You can use `torch.float32`,
but `torch.float64` is recommended.
- `query_gradient_svd_dtype`: `dtype` for performing singular value decomposition (SVD) for query batch. You can also use `torch.float32`.

- `cached_activation_cpu_offload`: Whether to offload cached activations to CPU.
- `score_dtype`: `dtype` for computing influence scores. You can use `torch.bfloat16` or `torch.float16`.
Expand Down
3 changes: 2 additions & 1 deletion examples/glue/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,17 @@ def train(
for epoch in range(num_train_epochs):
total_loss = 0.0
for batch in train_dataloader:
model.zero_grad()
loss = model(
input_ids=batch["input_ids"].to(device=DEVICE),
attention_mask=batch["attention_mask"].to(device=DEVICE),
token_type_ids=batch["token_type_ids"].to(device=DEVICE),
labels=batch["labels"].to(device=DEVICE)
)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.detach().float()
logging.info(f"Epoch {epoch + 1} - Averaged Loss: {total_loss / len(dataset)}")
return model


Expand Down

0 comments on commit c1291eb

Please sign in to comment.