[BE] LR scheduler flatten #794
Conversation
@@ -183,9 +183,9 @@ def __init__(
        "model": ModelWrapper(model_parts),
        "optimizer": optimizers,
        "dataloader": dataloader,
        "lr_scheduler": lr_schedulers,
I think it won't be this simple. Both `OptimizersContainer` and `ModelWrapper` define `state_dict` and `load_state_dict` to handle flattening and unflattening. Since we don't have things like `get_model_state_dict` and `set_model_state_dict` for the lr scheduler in `torch.distributed.checkpoint.state_dict`, we will likely need to manually write something for the LambdaLR we are using. See #738 (comment).
Let's work with @fegin on this.
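For reference, here is a minimal sketch of what such a manually written, flattened container could look like. The class and attribute names are assumptions for illustration, not necessarily the exact torchtitan API:

```python
from typing import Any, Dict, List

from torch.optim.lr_scheduler import LambdaLR


class SchedulersContainer:
    """Wraps one LambdaLR per optimizer and exposes a flattened Stateful-style API."""

    def __init__(self, schedulers: List[LambdaLR]) -> None:
        self.schedulers = schedulers

    def step(self) -> None:
        # Advance every scheduler by one step.
        for scheduler in self.schedulers:
            scheduler.step()

    def state_dict(self) -> Dict[str, Any]:
        # All schedulers follow the same schedule, so saving one is enough.
        return self.schedulers[0].state_dict()

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Restore every scheduler from the single saved state_dict.
        for scheduler in self.schedulers:
            scheduler.load_state_dict(state_dict)
```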
Compared lr_schedulers before and after flattening, with and without checkpointing. The lr_scheduler values are consistent with the changes here.
Does it support DCP resharding? E.g., PP degree from 2 to 4 across two jobs.
I think this PR doesn't address the resharding issue, hence the [BE] prefix. Supporting lr resharding deserves a separate PR.
torchtitan/optimizer.py (outdated)
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    # Load the same state_dict for all schedulers
    for scheduler in self.schedulers:
        scheduler.load_state_dict(state_dict)
We may need to explicitly copy `state_dict` before loading; otherwise there could be silent errors. See details of the behavior here: https://github.com/pytorch/pytorch/blob/v2.6.0/torch/optim/lr_scheduler.py#L359
Please add a more detailed comment/NOTE here in the code.
Please add verified experiment results in the PR summary.
We should consider adding a unit test under the test folder to guard this behavior, but feel free to do that in a later PR.
I checked the LambdaLR scheduler code. It seems the only thing that matters in the state is the current step, which is an int, so `load_state_dict` will automatically make copies. See `last_epoch` in https://github.com/pytorch/pytorch/blob/v2.6.0/torch/optim/lr_scheduler.py#L122
Therefore, the behavior should be correct as long as we don't modify `training.steps` and `training.warmup_steps` when resuming from a checkpoint.
But for safety, let's still explicitly call `.copy()` on the `state_dict`, as the overhead is small.
Let's document this in the code here.
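Continuing the hypothetical `SchedulersContainer` sketch from earlier, the documented copy-on-load version could look roughly like this (a sketch of the idea, not necessarily the exact code merged in this PR):

```python
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Load the same state_dict into every scheduler.
        # NOTE: the only LambdaLR state that matters here is last_epoch, an int
        # that is copied by value, so sharing one dict works as long as
        # training.steps and training.warmup_steps are unchanged on resume.
        # Still, LRScheduler.load_state_dict updates the scheduler's __dict__
        # with the entries of the dict it receives, so hand each scheduler its
        # own shallow copy for safety; the overhead is negligible.
        for scheduler in self.schedulers:
            scheduler.load_state_dict(state_dict.copy())
```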
Thanks for the comments. Added more detailed comments here with our discussion results.
Updated the summary with multi-optimizer results, showing the same lr values after checkpointing and resharding.
Looks awesome. Thanks for the effort!
Currently, lr_scheduler is stored differently from optimizer, model, and data_loader, with per-scheduler keys "lr_scheduler_0", "lr_scheduler_1", ... in the state.
This PR flattens lr_scheduler so that all the schedulers are stored under a single self.state['lr_scheduler'] entry, which is consistent with the optimizer.
We assume all the optimizers use the same lr schedule, so it is enough to save a single lr_scheduler's state_dict and load it into all the schedulers.
The lr_scheduler state_dict looks like:
{'base_lrs': [0.0003], 'last_epoch': 1, 'verbose': False, '_step_count': 2, '_get_lr_called_within_step': False, '_last_lr': [2.985074626865671e-06], 'lr_lambdas': [{}]}
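For context, here is a small, runnable sketch that produces a state_dict of this shape. The warmup-then-decay lambda and the hyperparameters below are illustrative stand-ins, not torchtitan's exact schedule:

```python
import functools

import torch
from torch.optim.lr_scheduler import LambdaLR


def warmup_then_decay(step: int, warmup_steps: int, decay_steps: int) -> float:
    # Linear warmup followed by linear decay (illustrative schedule only).
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    return max(0.0, 1.0 - (step - warmup_steps) / decay_steps)


model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
lr_lambda = functools.partial(warmup_then_decay, warmup_steps=200, decay_steps=800)
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

optimizer.step()
scheduler.step()
print(scheduler.state_dict())
# Prints something like:
# {'base_lrs': [0.0003], 'last_epoch': 1, ..., '_last_lr': [...], 'lr_lambdas': [{}]}
```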
The PR is tested in 2 parts:
[dp=2, tp=4, pp=1] -> [dp=2, tp=1, pp=4]
[dp=2, tp=1, pp=4] -> [dp=2, tp=4, pp=1]
(data_loader does not support resharding right now.)
logs:
[dp=2, tp=4, pp=1]
step 5 before saving to checkpoint:
[{'lr': 8.955223880597014e-06, ...}]
step 10 after loading from checkpoint and reshard to [dp=2, tp=2, pp=2]:
[{'lr': 1.6417910447761194e-05, ...}, {'lr': 1.6417910447761194e-05, ...}]
[dp=8, tp=1, pp=1]
step 5 without checkpoint:
[{'lr': 8.955223880597014e-06, ...}]
step 10 without checkpoint:
[{'lr': 1.6417910447761194e-05, ...}]
Before the flatten, rerun llama3_8b.toml from step 5 to step 10:
After the flatten, rerun llama3_8b.toml from step 5 to step 10:
(Attached screenshot: "Screenshot 2025-01-16 at 2 40 21 PM")