From edf0a2634ab52d0fc9fdefe40445a2bedd5ee372 Mon Sep 17 00:00:00 2001
From: rizar
Date: Mon, 3 Mar 2025 15:17:18 -0500
Subject: [PATCH] we should not autocast

---
 tapeagents/finetune/finetune.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/tapeagents/finetune/finetune.py b/tapeagents/finetune/finetune.py
index 18e0bf44..8ce37831 100644
--- a/tapeagents/finetune/finetune.py
+++ b/tapeagents/finetune/finetune.py
@@ -203,16 +203,15 @@ def toggle_sync(sync: bool):
             torch.cuda.empty_cache()
 
         do_optimizer_step = training_metrics.passes % args.gradient_accumulation_passes == 0
-        with torch.autocast("cuda"):
-            with toggle_sync(do_optimizer_step):
-                loss, this_step_rl_metrics = forward(model, batch)
-                for k, v in this_step_rl_metrics.items():
-                    rl_metrics[k].append(v)
-                training_metrics.train_loss = loss.item()
-                training_metrics.lr = optimizer.param_groups[0]["lr"]
-                training_metrics.max_batch_len = max(batch["input_ids"].shape[1], training_metrics.max_batch_len)
-                training_metrics.min_batch_len = min(batch["input_ids"].shape[1], training_metrics.min_batch_len)
-                accelerator.backward(loss / args.gradient_accumulation_passes)
+        with toggle_sync(do_optimizer_step):
+            loss, this_step_rl_metrics = forward(model, batch)
+            for k, v in this_step_rl_metrics.items():
+                rl_metrics[k].append(v)
+            training_metrics.train_loss = loss.item()
+            training_metrics.lr = optimizer.param_groups[0]["lr"]
+            training_metrics.max_batch_len = max(batch["input_ids"].shape[1], training_metrics.max_batch_len)
+            training_metrics.min_batch_len = min(batch["input_ids"].shape[1], training_metrics.min_batch_len)
+            accelerator.backward(loss / args.gradient_accumulation_passes)
 
         if not do_optimizer_step:
             continue
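
Note: the subject line does not state the rationale, but a common reason for dropping an explicit torch.autocast("cuda") around the training step (an assumption here, not stated in the commit) is that mixed precision is already managed by the accelerate Accelerator. When the Accelerator is created with a mixed_precision policy, prepare() wraps the model's forward in the appropriate autocast context, and an extra torch.autocast("cuda") (which defaults to float16) nested inside it would override a configured bf16 policy. Below is a minimal sketch of that accelerate-managed setup; the model, optimizer, and batch are illustrative stand-ins, not code from finetune.py.

import torch
from accelerate import Accelerator

# Mixed precision is configured once on the Accelerator (or via
# `accelerate config`); no explicit torch.autocast is needed later.
accelerator = Accelerator(mixed_precision="bf16")

model = torch.nn.Linear(16, 1)  # stand-in for the fine-tuned model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer = accelerator.prepare(model, optimizer)

batch = torch.randn(4, 16, device=accelerator.device)
target = torch.zeros(4, 1, device=accelerator.device)

# The prepared model already runs its forward under the configured
# mixed-precision policy, so the step body stays un-wrapped.
loss = torch.nn.functional.mse_loss(model(batch), target)
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()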