[BE] Lr scheduler flatten #794

Merged (9 commits) on Jan 31, 2025
torchtitan/checkpoint.py: 2 changes (1 addition, 1 deletion)
@@ -183,9 +183,9 @@ def __init__(
"model": ModelWrapper(model_parts),
"optimizer": optimizers,
"dataloader": dataloader,
"lr_scheduler": lr_schedulers,
Contributor:
I think it won't be this simple. Both OptimizersContainer and ModelWrapper define state_dict and load_state_dict to handle flattening and unflattening. Since we don't have things like get_model_state_dict and set_model_state_dict for lr scheduler in torch.distributed.checkpoint.state_dict, we likely will need to manually write something for the LambdaLR we are using. See #738 (comment)

Let's work with @fegin on this.
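(For context, a minimal sketch of what "manually write something for the LambdaLR" could look like: a thin wrapper exposing the scheduler through the Stateful protocol from torch.distributed.checkpoint. The wrapper name and details below are illustrative, not code from this PR.)

```python
# Sketch of a hand-written Stateful wrapper around LambdaLR (illustrative, not the PR code).
from typing import Any, Dict

from torch.distributed.checkpoint.stateful import Stateful
from torch.optim.lr_scheduler import LambdaLR


class LambdaLRWrapper(Stateful):
    """Exposes a LambdaLR through the DCP Stateful protocol."""

    def __init__(self, scheduler: LambdaLR) -> None:
        self.scheduler = scheduler

    def state_dict(self) -> Dict[str, Any]:
        # LambdaLR's state is tiny (essentially last_epoch / _step_count), so it
        # can be saved flat without any get_*_state_dict helper from DCP.
        return self.scheduler.state_dict()

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Copy before loading so the caller's dict is not shared or mutated.
        self.scheduler.load_state_dict(dict(state_dict))
```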

Contributor (PR author):
I compared lr_schedulers before and after flattening, with and without checkpointing; the lr_scheduler values are consistent with the changes here.
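(Roughly the kind of check described, as an illustrative sketch rather than the actual verification commands: step a LambdaLR, snapshot its state_dict, load it into a freshly built scheduler, and compare the resulting lr values.)

```python
# Rough sketch of the before/after consistency check described above (illustrative).
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR


def warmup(step: int) -> float:
    # toy linear warmup over 10 steps
    return min(1.0, (step + 1) / 10)


model = torch.nn.Linear(4, 4)
opt = AdamW(model.parameters(), lr=1e-3)
sched = LambdaLR(opt, lr_lambda=warmup)

for _ in range(5):  # pretend to train for a few steps
    opt.step()
    sched.step()

snapshot = sched.state_dict().copy()  # what would go into the checkpoint

# "Resume": rebuild optimizer + scheduler, then load the saved state.
opt2 = AdamW(model.parameters(), lr=1e-3)
sched2 = LambdaLR(opt2, lr_lambda=warmup)
sched2.load_state_dict(snapshot)

assert sched.get_last_lr() == sched2.get_last_lr()  # lr values stay consistent
```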

Contributor:
Does it support DCP resharding, e.g. changing the PP degree from 2 to 4 across two jobs?

Contributor (PR author):
I think this PR doesn't address the resharding issue, hence the [BE] prefix. Supporting lr_scheduler resharding deserves a separate PR.

}
)
- self.states.update(lr_schedulers.get_lr_scheduler_state())

self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
self.interval_type = (
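(Background on why the container can simply be registered in self.states: torch.distributed.checkpoint treats objects implementing the Stateful protocol specially, calling state_dict() during save and load_state_dict() during load. Below is a minimal single-process illustration, independent of torchtitan; DCP falls back to local saving when no process group is initialized.)

```python
# Minimal illustration of DCP's Stateful handling (not torchtitan code).
import tempfile

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful


class Counter(Stateful):
    """A toy stateful object standing in for SchedulersContainer."""

    def __init__(self) -> None:
        self.step = 0

    def state_dict(self):
        return {"step": torch.tensor(self.step)}

    def load_state_dict(self, state_dict) -> None:
        self.step = int(state_dict["step"])


with tempfile.TemporaryDirectory() as ckpt_dir:
    src = Counter()
    src.step = 7
    # dcp.save notices that Counter is Stateful and calls src.state_dict() itself.
    dcp.save({"counter": src}, checkpoint_id=ckpt_dir)

    dst = Counter()
    # dcp.load restores into dst via dst.load_state_dict(...).
    dcp.load({"counter": dst}, checkpoint_id=ckpt_dir)
    assert dst.step == 7
```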
torchtitan/optimizer.py: 23 changes (12 additions, 11 deletions)
@@ -167,7 +167,7 @@ def linear_warmup_linear_decay(
return curr_adjustment


- class SchedulersContainer:
+ class SchedulersContainer(Stateful):
"""Util for calling step on multiple learning rate schedulers needed for virtual pipeline stages"""

def __init__(self, optimizers, lr_lambda) -> None:
@@ -179,16 +179,17 @@ def step(self) -> None:
for scheduler in self.schedulers:
scheduler.step()

- def get_lr_scheduler_state(self) -> Dict[str, Any]:
- state_dict = {}
- if len(self.schedulers) == 1:
- state_dict["lr_scheduler"] = self.schedulers[0]
- else:
- # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler.
- # It should only support saving and loading a distributed checkpoint with the same number of pp ranks
- for idx, lr_scheduler in enumerate(self.schedulers):
- state_dict[f"lr_scheduler_{idx}"] = lr_scheduler
- return state_dict
+ def state_dict(self) -> Dict[str, Any]:
+ # We have lr_scheduler with the same state_dict for all optimizers, so can just save one.
+ assert (
+ len(self.schedulers) > 0
+ ), "Must have at least one scheduler to save state_dict"
+ return self.schedulers[0].state_dict()
+
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+ # Load the same state_dict for all schedulers
+ for scheduler in self.schedulers:
+ scheduler.load_state_dict(state_dict)
tianyu-l (Contributor), Jan 31, 2025:

We may need to explicitly copy state_dict before loading. Otherwise there could be silent errors. See details of behavior here https://github.com/pytorch/pytorch/blob/v2.6.0/torch/optim/lr_scheduler.py#L359

Please add more detailed comment/NOTE here in the code.
Please add verified experiment results in the PR summary.

We should consider adding a unit test under the test folder to guard this behavior, but feel free to do that in a later PR.

Contributor:

I checked the LambdaLR scheduler code. It seems the only thing that matters in the state is the current step, which is an int, so load_state_dict will automatically make copies. See last_epoch in https://github.com/pytorch/pytorch/blob/v2.6.0/torch/optim/lr_scheduler.py#L122

Therefore, the behavior should be correct, as long as we don't modify training.steps and training.warmup_steps when resuming from a checkpoint.

But for safety, let's still explicitly call .copy() on the state_dict, as the overhead is small.

Let's document this in the code here.
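(Putting the thread together, the loading path could look roughly like this. This is a sketch under the assumptions above, not necessarily the exact code that was merged.)

```python
# Sketch: SchedulersContainer.load_state_dict with the defensive copy discussed above.
from typing import Any, Dict

from torch.distributed.checkpoint.stateful import Stateful


class SchedulersContainer(Stateful):
    ...  # __init__, step, and state_dict as in the diff above

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # NOTE: all schedulers share the same flat state. LambdaLR's meaningful
        # state is last_epoch (an int), so aliasing is mostly harmless, but copy
        # anyway (the overhead is small) so no scheduler mutates the dict that
        # another scheduler will load.
        for scheduler in self.schedulers:
            scheduler.load_state_dict(state_dict.copy())
```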

Contributor (PR author):

Thanks for the comments. Added more detailed comments in the code reflecting our discussion, and updated the PR summary with multi-optimizer results showing the same lr values after checkpointing and resharding.



def build_lr_schedulers(optimizers, job_config: JobConfig) -> SchedulersContainer: