diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index 745ef64eb..9ef5e6533 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -766,7 +766,9 @@ def train(self) -> None:
                 if self._optimizer_in_bwd:
                     torch.distributed.all_reduce(num_tokens)
                     torch.distributed.all_reduce(running_loss)
-                    current_loss = current_loss / num_tokens
+
+                    # We multiply by world_size to undo FSDP2 gradient normalization.
+                    current_loss = current_loss * (world_size / num_tokens)

                 current_loss.backward()

@@ -778,7 +780,8 @@ def train(self) -> None:
                         # This will ensure that the logged loss matches what we're optimizing
                         torch.distributed.all_reduce(running_loss)
                         # Manually scale the gradients from unnormalized loss by total # of tokens
-                        training.scale_grads(self._model, 1 / num_tokens)
+                        # We multiply by world_size to undo FSDP2 gradient normalization.
+                        training.scale_grads(self._model, world_size / num_tokens)
                         if self._clip_grad_norm is not None:
                             grad_norm = torch.nn.utils.clip_grad_norm_(
                                 self._model.parameters(),
diff --git a/recipes/knowledge_distillation_distributed.py b/recipes/knowledge_distillation_distributed.py
index a11adabb9..77fc50927 100644
--- a/recipes/knowledge_distillation_distributed.py
+++ b/recipes/knowledge_distillation_distributed.py
@@ -769,7 +769,6 @@ def save_checkpoint(self, epoch: int) -> None:
     def _loss_step(
         self, batch: Dict[str, torch.Tensor]
     ) -> (torch.Tensor, torch.Tensor):
-
         # Both are shape [b, s]
         tokens, labels = batch["tokens"], batch["labels"]

@@ -875,7 +874,8 @@ def train(self) -> None:
                     torch.distributed.all_reduce(running_class_loss)
                     torch.distributed.all_reduce(running_kd_loss)
                     # Manually scale the gradients from unnormalized loss by total # of tokens
-                    training.scale_grads(self._model, 1 / num_tokens)
+                    # We multiply by world_size to undo FSDP2 gradient normalization.
+                    training.scale_grads(self._model, world_size / num_tokens)
                     class_loss_to_log = running_class_loss.item() / num_tokens
                     kd_loss_to_log = running_kd_loss.item() / num_tokens
                     self._optimizer.step()
diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py
index ac05e2060..2cdfcd801 100644
--- a/recipes/lora_finetune_distributed.py
+++ b/recipes/lora_finetune_distributed.py
@@ -822,7 +822,8 @@ def train(self) -> None:
                     # This will ensure that the logged loss matches what we're optimizing
                     torch.distributed.all_reduce(running_loss)
                     # Manually scale the gradients from unnormalized loss by total # of tokens
-                    training.scale_grads(self._model, 1 / num_tokens)
+                    # We multiply by world_size to undo FSDP2 gradient normalization.
+                    training.scale_grads(self._model, world_size / num_tokens)
                     if self._clip_grad_norm is not None:
                         grad_norm = torch.nn.utils.clip_grad_norm_(
                             self._model.parameters(),
diff --git a/recipes/lora_finetune_distributed_multi_dataset.py b/recipes/lora_finetune_distributed_multi_dataset.py
index 30ece7034..ce482bfa2 100644
--- a/recipes/lora_finetune_distributed_multi_dataset.py
+++ b/recipes/lora_finetune_distributed_multi_dataset.py
@@ -851,7 +851,8 @@ def train(self) -> None:
                     # This will ensure that the logged loss matches what we're optimizing
                     torch.distributed.all_reduce(running_loss)
                     # Manually scale the gradients from unnormalized loss by total # of tokens
-                    training.scale_grads(self._model, 1 / num_tokens)
+                    # We multiply by world_size to undo FSDP2 gradient normalization.
+                    training.scale_grads(self._model, world_size / num_tokens)
                     if self._clip_grad_norm is not None:
                         grad_norm = torch.nn.utils.clip_grad_norm_(
                             self._model.parameters(),
diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py
index eaa297457..f1b1302b7 100644
--- a/recipes/qat_distributed.py
+++ b/recipes/qat_distributed.py
@@ -837,7 +837,9 @@ def train(self) -> None:
                 if self._optimizer_in_bwd:
                     torch.distributed.all_reduce(num_tokens)
                     torch.distributed.all_reduce(running_loss)
-                    current_loss = current_loss / num_tokens
+
+                    # We multiply by world_size to undo FSDP2 gradient normalization.
+                    current_loss = current_loss * (world_size / num_tokens)

                 current_loss.backward()

@@ -849,7 +851,8 @@ def train(self) -> None:
                         # This will ensure that the logged loss matches what we're optimizing
                         torch.distributed.all_reduce(running_loss)
                         # Manually scale the gradients from unnormalized loss by total # of tokens
-                        training.scale_grads(self._model, 1 / num_tokens)
+                        # We multiply by world_size to undo FSDP2 gradient normalization.
+                        training.scale_grads(self._model, world_size / num_tokens)
                         if self._clip_grad_norm is not None:
                             grad_norm = torch.nn.utils.clip_grad_norm_(
                                 self._model.parameters(),
diff --git a/recipes/qat_lora_finetune_distributed.py b/recipes/qat_lora_finetune_distributed.py
index b9080de77..133c39c94 100644
--- a/recipes/qat_lora_finetune_distributed.py
+++ b/recipes/qat_lora_finetune_distributed.py
@@ -866,7 +866,8 @@ def train(self) -> None:
                     # This will ensure that the logged loss matches what we're optimizing
                     torch.distributed.all_reduce(running_loss)
                     # Manually scale the gradients from unnormalized loss by total # of tokens
-                    training.scale_grads(self._model, 1 / num_tokens)
+                    # We multiply by world_size to undo FSDP2 gradient normalization.
+                    training.scale_grads(self._model, world_size / num_tokens)
                     if self._clip_grad_norm is not None:
                         grad_norm = torch.nn.utils.clip_grad_norm_(
                             self._model.parameters(),
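
The change is identical across all six recipes, so a single sketch covers the idea. The snippet below is a minimal, hypothetical illustration rather than the torchtune implementation: the local `scale_grads` helper only stands in for `training.scale_grads`, whose real signature is not shown in this diff. Each rank backpropagates an unnormalized sum of per-token losses; FSDP2's gradient reduction averages across ranks (an implicit division by `world_size`), while the intended normalization is by the all-reduced global token count, so scaling gradients by `world_size / num_tokens` cancels the averaging and applies the `1 / num_tokens` normalization in one in-place step.

```python
import torch


def scale_grads(model: torch.nn.Module, scaler: float) -> None:
    """Hypothetical stand-in for ``training.scale_grads``: multiply every
    existing gradient in ``model`` in place by ``scaler``."""
    for p in model.parameters():
        if p.grad is not None:
            p.grad.detach().mul_(scaler)


if __name__ == "__main__":
    # Tiny single-process demo of the in-place scaling itself.
    model = torch.nn.Linear(4, 2)
    model(torch.randn(8, 4)).sum().backward()
    world_size, num_tokens = 4, 1024  # placeholder values for illustration
    scale_grads(model, world_size / num_tokens)
```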