Fix gradient scaling to account for world_size normalization #2172
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2172
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit a6dc03a with merge base 27fd3a1: This comment was automatically generated by Dr. CI and updates every 15 minutes.
Thanks @mirceamironenco for finding this bug and for making the fix! Apologies for the delay in getting back to it, I wanted to put together a minimal repro to validate this myself (I trust the code pointers, but I like seeing numerical parity). So I put together the following script(s) to convince myself. Can confirm that on identical toy models with identical data we see |
@ebsmothers any updates on this? We've also seen this in our multi-node runs -- our grad norms are significantly smaller than what we see from other frameworks (e.g. NeMo).
Hey @EugenHotaj thanks for the bump -- yes, we plan to land this soon. Actually the main reason for being slow on this PR (besides the holidays and PSC) is that we wanna be careful about breaking people who have e.g. their LR tuned to this setting. Ultimately I think we need to just rip the bandaid off and make the fix, then put comms here and in our Discord. Let me try to review and land later today |
34906b2 to a6dc03a
Thanks for your patience @mirceamironenco. Just ran some quick experiments on my end on a single node with 8 GPUs. Attaching some plots below, WandB project is here. There are three runs: one on main, one on this PR, and one on this PR with learning rate scaled by 1/8. Unsurprisingly, it's similar to what @EugenHotaj mentioned -- the grad norm is off, almost exactly by a factor of 8. At least for my case the loss curves are pretty much identical too, not sure if there's a noticeable difference on multinode. |
@ebsmothers I've noticed this as well on my runs and found it a bit surprising. Is it expected that the losses would be identical? The gradients point in the same direction but I would have thought we'd see some divergence after taking a few hundred gradient steps. I guess gradient clipping / LR accounts for a lot of this? |
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2172       +/-   ##
===========================================
- Coverage   65.41%   23.95%   -41.47%
===========================================
  Files         344      352        +8
  Lines       20658    20847      +189
===========================================
- Hits        13514     4993     -8521
- Misses       7144    15854     +8710
☔ View full report in Codecov by Sentry.
@EugenHotaj yeah this gave me a bit of a scare, especially considering that we don't enable gradient clipping by default. Actually I believe that the behavior has to do with the optimizer: I used SGD instead of AdamW and manually hacked in a really high value for the grad scaler just to make sure nothing was broken. In that case the difference is very noticeable (see below). I didn't think that momentum would result in consistent loss curves when scaling grads up and down, but maybe I just need to refresh my memory on Adam a bit. |
@ebsmothers any chance we also need to do the same to the Adam momentum params when using FSDP? Pretty surprising to me as well that Adam would lead to identical learning curves.
Just to make sure I understand, if you only hack the grad scaler to be very large but keep AdamW, the loss curves are still basically identical? (IIUC you did both in this comparison?) Maybe the loss curves being very similar is not so strange since the denominator will have a very large number of tokens compared to world_size, but some other ideas to battle test this more (I can implement these in a separate branch just for a comparison if you want):
# Must be done for each sharded module.
module = fully_shard(
module,
mesh=mesh,
reshard_after_forward=reshard_after_forward,
shard_placement_fn=shard_placement_fn,
mp_policy=mp_policy,
offload_policy=offload_policy,
)
# Change the reduce op manually
fsdp_param_group = fully_shard.state(module)._fsdp_param_group
fsdp_param_group.reduce_scatter_reduce_op = ReduceOp.SUM
This happens before the optimizer is initialized, in case anything is happening there.
Getting some feedback from the FSDP2 authors as a sanity check could also be useful.
@mirceamironenco Yeah this is correct. Re your suggestions, (3) was the first one that came to my mind (also conveniently the easiest 😃) so I gave that a try on our distributed LoRA recipe. The below plot is the result of running AdamW with a much higher LR of 0.01, you can see that the two loss curves diverge (also unsurprisingly the loss blows up): But the point is that scaling the gradients can result in different loss curves with AdamW, it just doesn't really show up under our baseline configs (which I suppose is a good thing with respect to the impact of this whole world-size-scaling bug). Can also tag in our resident optimizer expert @janeyx99 in case she has any thoughts. TLDR for Jane is that we manually scale grads using this utility just before optimizer step, but surprisingly even scaling by a pretty large amount doesn't really mess with our loss curves when using AdamW (while for SGD there is a noticeable impact). |
Thanks @janeyx99! This is very helpful. Also I clearly should've just dug up the Adam paper. Direct quote: Assuming (Please excuse my sloppy LaTeX, I swear I was good at this once..) Also thanks @mirceamironenco for mentioning this over chat and forcing me to dig it up. So actually I think we are good here -- in fact I'm no longer even worried about breaking BC with this change after actually having done my homework. |
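The scale-invariance argument can be checked numerically. The sketch below is a hand-written, single-parameter Adam update (bias-corrected, no weight decay) next to plain SGD. It is not the torch.optim implementation, and the gradient sequence, the scale factor c, and the helper names are made up for illustration. With a small eps, rescaling every gradient by c scales the first-moment estimate by c and the second-moment estimate by c^2, so their ratio should be nearly unchanged, whereas the SGD step scales linearly with c.

```python
import math


def adam_steps(grads, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Return the per-step Adam updates for a scalar gradient sequence."""
    m = v = 0.0
    updates = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g          # first-moment estimate
        v = beta2 * v + (1 - beta2) * g * g      # second-moment estimate
        m_hat = m / (1 - beta1 ** t)             # bias correction
        v_hat = v / (1 - beta2 ** t)
        updates.append(-lr * m_hat / (math.sqrt(v_hat) + eps))
    return updates


def sgd_steps(grads, lr=1e-3):
    """Return the per-step plain SGD updates (no momentum)."""
    return [-lr * g for g in grads]


grads = [0.3, -0.1, 0.25, 0.4, -0.2]   # hypothetical gradient sequence
c = 8.0                                 # hypothetical scale, e.g. a world_size of 8
scaled = [c * g for g in grads]

# Ratio of the update computed from scaled gradients to the original update.
adam_ratio = [s / u for s, u in zip(adam_steps(scaled), adam_steps(grads))]
sgd_ratio = [s / u for s, u in zip(sgd_steps(scaled), sgd_steps(grads))]

print("Adam update ratio:", adam_ratio)  # stays close to 1
print("SGD update ratio:", sgd_ratio)    # tracks c
```

The cancellation is only approximate: it relies on eps being negligible next to the square root of the second-moment estimate, so very small gradients or a large eps would make the Adam ratios drift away from 1.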
Thanks for finding and fixing the bug @mirceamironenco! And thanks for your patience while we sorted out the whole Adam grad scaling thing in review. Based on our discussion, I think this is good to go.
Context
What is the purpose of this PR? Is it to
Please link to any issues this PR addresses.
Changelog
What are the changes made in this PR?
scale_grads: torchtune/recipes/full_finetune_distributed.py, line 780 in 3518492
If A and B are processed on separate data parallel workers, the current gradients are produced by loss(A) / 2 + loss(B) / 2, because the reduction averages across workers. With the normalization done as before, the effective loss becomes (loss(A) + loss(B)) / (2 * (|A| + |B|)) rather than the intended (loss(A) + loss(B)) / (|A| + |B|). This PR multiplies the scaling factor by world_size so that the extra 1 / world_size cancels out.
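To make this arithmetic concrete, here is a minimal pure-Python sketch. It assumes each worker holds the summed (un-normalized) token loss for its shard and that the data parallel reduction averages gradients across workers. The one-parameter model, the token values, and the helper names are hypothetical and are not taken from the recipe.

```python
# Toy model: per-token loss is 0.5 * (w * x - y)^2 with a single scalar weight w.
# All names and numbers are illustrative assumptions, not torchtune code.

def summed_grad(w, tokens):
    """Gradient w.r.t. w of the loss summed over one worker's tokens."""
    return sum((w * x - y) * x for x, y in tokens)


w = 0.5
shard_a = [(1.0, 2.0), (2.0, 1.0), (3.0, 0.0)]   # |A| = 3 tokens
shard_b = [(4.0, 1.0)]                            # |B| = 1 token
shards = [shard_a, shard_b]

world_size = len(shards)
num_tokens = sum(len(s) for s in shards)

# Reference: gradient of the mean loss over all tokens on a single worker.
reference = summed_grad(w, shard_a + shard_b) / num_tokens

# Data parallel reduction that averages the per-worker gradients.
reduced = sum(summed_grad(w, s) for s in shards) / world_size

old_scaling = reduced * (1 / num_tokens)           # scaling before this PR
new_scaling = reduced * (world_size / num_tokens)  # scaling after this PR

print(reference, old_scaling, new_scaling)
```

In this toy setup the old scaling comes out smaller than the reference by a factor of world_size, while the new scaling matches it, which is consistent with the grad norm gap reported in the discussion.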
I haven't seen very large differences wrt loss curves in my preliminary experiments after this change:
Where world_size means the gradient scaling factor is world_size / num_tokens and otherwise 1 / num_tokens. The commands to replicate these plots being:
tune run --nproc_per_node 2 full_finetune_distributed --config llama3_2/3B_full metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=llama3.23b_fix metric_logger.name=world_size dataset.packed=True tokenizer.max_seq_len=512 compile=True
tune run --nproc_per_node 2 full_finetune_distributed.py --config configs/llama3_2/3B_full metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=llama3.23b_fix_noprompt metric_logger.name=world_size dataset.packed=True dataset.train_on_input=False tokenizer.max_seq_len=512 compile=True
Someone with more compute budget can probably get a better idea of the effect for larger models.
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
pre-commit install
pytest tests
pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example