Add comments, account for optimizer_in_bwd case.
mirceamironenco committed Dec 18, 2024
1 parent f07241b commit 34906b2
Showing 6 changed files with 12 additions and 2 deletions.
5 changes: 4 additions & 1 deletion recipes/full_finetune_distributed.py
@@ -765,7 +765,9 @@ def train(self) -> None:
 if self._optimizer_in_bwd:
     torch.distributed.all_reduce(num_tokens)
     torch.distributed.all_reduce(running_loss)
-    current_loss = current_loss / num_tokens
+
+    # We multiply by world_size to undo FSDP2 gradient normalization.
+    current_loss = current_loss * (world_size / num_tokens)

 current_loss.backward()
@@ -777,6 +779,7 @@ def train(self) -> None:
 # This will ensure that the logged loss matches what we're optimizing
 torch.distributed.all_reduce(running_loss)
 # Manually scale the gradients from unnormalized loss by total # of tokens
+# We multiply by world_size to undo FSDP2 gradient normalization.
 training.scale_grads(self._model, world_size / num_tokens)
 if self._clip_grad_norm is not None:
     grad_norm = torch.nn.utils.clip_grad_norm_(
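A note on what the scaling in training.scale_grads(self._model, world_size / num_tokens) accomplishes: the recipe accumulates an unnormalized sum-of-token losses per rank, while FSDP2 mean-reduces (averages) gradients across ranks. Multiplying by world_size undoes that averaging, and dividing by the globally reduced num_tokens normalizes by the total token count. The single-parameter sketch below is illustrative only, not part of this commit; world_size, tokens_per_rank, and w are made-up values. It checks that the rescaled gradient matches the gradient of a loss normalized by the global token count.

import torch

world_size = 4                        # hypothetical number of ranks
tokens_per_rank = [3, 5, 2, 6]        # hypothetical per-rank token counts
num_tokens = sum(tokens_per_rank)     # global count, as after all_reduce(num_tokens)

w = torch.tensor(2.0, requires_grad=True)

# Each rank's analogue of the unnormalized running loss: a sum over its tokens.
per_rank_losses = [w * float(n) for n in tokens_per_rank]

# FSDP2 mean-reduces gradients across ranks; for this toy scalar that is the
# same as backpropagating the average of the per-rank losses.
fsdp_mean_loss = sum(per_rank_losses) / world_size
fsdp_mean_loss.backward()
rescaled_grad = w.grad * (world_size / num_tokens)  # the scale_grads correction

# Reference: gradient of a loss normalized by the global token count.
w.grad = None
ref_loss = sum(w * float(n) for n in tokens_per_rank) / num_tokens
ref_loss.backward()

assert torch.allclose(rescaled_grad, w.grad)  # both equal 1.0 here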
1 change: 1 addition & 0 deletions recipes/knowledge_distillation_distributed.py
@@ -875,6 +875,7 @@ def train(self) -> None:
 torch.distributed.all_reduce(running_class_loss)
 torch.distributed.all_reduce(running_kd_loss)
 # Manually scale the gradients from unnormalized loss by total # of tokens
+# We multiply by world_size to undo FSDP2 gradient normalization.
 training.scale_grads(self._model, world_size / num_tokens)
 class_loss_to_log = running_class_loss.item() / num_tokens
 kd_loss_to_log = running_kd_loss.item() / num_tokens
1 change: 1 addition & 0 deletions recipes/lora_finetune_distributed.py
@@ -825,6 +825,7 @@ def train(self) -> None:
 # This will ensure that the logged loss matches what we're optimizing
 torch.distributed.all_reduce(running_loss)
 # Manually scale the gradients from unnormalized loss by total # of tokens
+# We multiply by world_size to undo FSDP2 gradient normalization.
 training.scale_grads(self._model, world_size / num_tokens)
 if self._clip_grad_norm is not None:
     grad_norm = torch.nn.utils.clip_grad_norm_(
1 change: 1 addition & 0 deletions recipes/lora_finetune_distributed_multi_dataset.py
@@ -853,6 +853,7 @@ def train(self) -> None:
 # This will ensure that the logged loss matches what we're optimizing
 torch.distributed.all_reduce(running_loss)
 # Manually scale the gradients from unnormalized loss by total # of tokens
+# We multiply by world_size to undo FSDP2 gradient normalization.
 training.scale_grads(self._model, world_size / num_tokens)
 if self._clip_grad_norm is not None:
     grad_norm = torch.nn.utils.clip_grad_norm_(
5 changes: 4 additions & 1 deletion recipes/qat_distributed.py
@@ -829,7 +829,9 @@ def train(self) -> None:
 if self._optimizer_in_bwd:
     torch.distributed.all_reduce(num_tokens)
     torch.distributed.all_reduce(running_loss)
-    current_loss = current_loss / num_tokens
+
+    # We multiply by world_size to undo FSDP2 gradient normalization.
+    current_loss = current_loss * (world_size / num_tokens)

 current_loss.backward()

@@ -841,6 +843,7 @@ def train(self) -> None:
 # This will ensure that the logged loss matches what we're optimizing
 torch.distributed.all_reduce(running_loss)
 # Manually scale the gradients from unnormalized loss by total # of tokens
+# We multiply by world_size to undo FSDP2 gradient normalization.
 training.scale_grads(self._model, world_size / num_tokens)
 if self._clip_grad_norm is not None:
     grad_norm = torch.nn.utils.clip_grad_norm_(
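The optimizer_in_bwd branches above scale the loss rather than the gradients because the per-parameter optimizers step during backward(), so there is no window to call training.scale_grads() afterwards. A minimal single-tensor sketch (toy values, not recipe code) of why scaling the loss by world_size / num_tokens before backward() yields the same gradients as scaling them afterwards:

import torch

world_size, num_tokens = 4, 16             # hypothetical values after all_reduce
w = torch.tensor(2.0, requires_grad=True)
unnormalized_loss = w * 16.0               # stand-in for a summed, unnormalized loss

# Path A: backward on the raw loss, then rescale the gradient afterwards
# (what the non-fused path does via training.scale_grads).
unnormalized_loss.backward(retain_graph=True)
grad_scaled_after = w.grad * (world_size / num_tokens)

# Path B: rescale the loss first, then backward (the optimizer_in_bwd path).
w.grad = None
(unnormalized_loss * (world_size / num_tokens)).backward()
grad_scaled_before = w.grad

assert torch.allclose(grad_scaled_after, grad_scaled_before)  # both equal 4.0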
1 change: 1 addition & 0 deletions recipes/qat_lora_finetune_distributed.py
@@ -863,6 +863,7 @@ def train(self) -> None:
 # This will ensure that the logged loss matches what we're optimizing
 torch.distributed.all_reduce(running_loss)
 # Manually scale the gradients from unnormalized loss by total # of tokens
+# We multiply by world_size to undo FSDP2 gradient normalization.
 training.scale_grads(self._model, world_size / num_tokens)
 if self._clip_grad_norm is not None:
     grad_norm = torch.nn.utils.clip_grad_norm_(
