From e8bd0fa5adabfd96347b57272ada9927f431c8a2 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Date: Wed, 26 Feb 2025 17:52:18 -0800 Subject: [PATCH] Update hf_auto_model_for_causal_lm.py Signed-off-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> --- nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py index b1ee557c8246..1a716e976d33 100644 --- a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py +++ b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py @@ -290,7 +290,7 @@ def on_before_optimizer_step(self, optimizer) -> None: mean_loss = torch.tensor([mean_loss], device=self.device).detach() dist.all_reduce(mean_loss, group=group, op=dist.ReduceOp.AVG) mean_loss = mean_loss.item() - tps = torch.tensor([tps], device=self.device, dtype=torch.int64) + tps = torch.tensor([tps], device=self.device, dtype=torch.int64).detach() dist.all_reduce(tps, group=group, op=dist.ReduceOp.SUM) tps = tps.item()