[BE] Lr scheduler flatten #794
Changes from 8 commits
@@ -167,7 +167,7 @@ def linear_warmup_linear_decay(
     return curr_adjustment


-class SchedulersContainer:
+class SchedulersContainer(Stateful):
     """Util for calling step on multiple learning rate schedulers needed for virtual pipeline stages"""

     def __init__(self, optimizers, lr_lambda) -> None:

@@ -179,16 +179,17 @@ def step(self) -> None:
         for scheduler in self.schedulers:
             scheduler.step()

-    def get_lr_scheduler_state(self) -> Dict[str, Any]:
-        state_dict = {}
-        if len(self.schedulers) == 1:
-            state_dict["lr_scheduler"] = self.schedulers[0]
-        else:
-            # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler.
-            # It should only support saving and loading a distributed checkpoint with the same number of pp ranks
-            for idx, lr_scheduler in enumerate(self.schedulers):
-                state_dict[f"lr_scheduler_{idx}"] = lr_scheduler
-        return state_dict
+    def state_dict(self) -> Dict[str, Any]:
+        # We have lr_scheduler with the same state_dict for all optimizers, so can just save one.
+        assert (
+            len(self.schedulers) > 0
+        ), "Must have at least one scheduler to save state_dict"
+        return self.schedulers[0].state_dict()
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        # Load the same state_dict for all schedulers
+        for scheduler in self.schedulers:
+            scheduler.load_state_dict(state_dict)
We may need to make an explicit copy here. Please add a more detailed comment/NOTE in the code. We should also consider adding a unit test for this.

I checked the LambdaLR scheduler code. It seems the only thing that matters in the state is the current step, which is an int. Therefore, the behavior should be correct as long as we don't modify that state in place. But for safety, let's still make the copy explicit. Let's document this in the code here.

Thanks for the comments. Added more detailed comments here with the results of our discussion.
 def build_lr_schedulers(optimizers, job_config: JobConfig) -> SchedulersContainer:
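To make the LambdaLR discussion in the thread above concrete, here is a minimal standalone sketch (not part of the PR; the toy warmup lambda and optimizer setup are made up for illustration). It shows that `LambdaLR.state_dict()` is a small dict of plain Python values whose important entry is the step counter, and it applies the explicit copy the reviewers ask for when loading one scheduler's state into another.

```python
import copy

import torch
from torch.optim.lr_scheduler import LambdaLR

# Two optimizers stand in for the per-pipeline-stage optimizers; the lr_lambda is a toy warmup.
params = [torch.nn.Parameter(torch.zeros(1))]
opt_a = torch.optim.AdamW(params, lr=1e-3)
opt_b = torch.optim.AdamW(params, lr=1e-3)
lr_lambda = lambda step: min(1.0, (step + 1) / 10)
sched_a = LambdaLR(opt_a, lr_lambda=lr_lambda)
sched_b = LambdaLR(opt_b, lr_lambda=lr_lambda)

for _ in range(5):
    sched_a.step()

# The saved state is a small dict of plain Python values; the important entry is the
# step counter (last_epoch). Plain-function lr_lambdas are stored as None. Exact keys
# vary by PyTorch version.
print(sched_a.state_dict())

# Loading a defensive copy of one scheduler's state into another keeps them in sync
# without risking shared mutable state.
sched_b.load_state_dict(copy.deepcopy(sched_a.state_dict()))
assert sched_b.last_epoch == sched_a.last_epoch
```

Because the state is just these scalars, saving a single scheduler's state_dict and loading it into all schedulers, as the new `state_dict()`/`load_state_dict()` above do, preserves the schedule as long as the schedule configuration is unchanged between save and load.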
I think it won't be this simple. Both OptimizersContainer and ModelWrapper define state_dict and load_state_dict to handle flattening and unflattening. Since we don't have things like get_model_state_dict and set_model_state_dict for lr scheduler in torch.distributed.checkpoint.state_dict, we likely will need to manually write something for the LambdaLR we are using. See #738 (comment). Let's work with @fegin on this.
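To illustrate the asymmetry this comment points out, here is a hedged sketch, not the PR's actual checkpoint wiring; the function name build_app_state and its shape are assumptions. Model and optimizer state can be flattened with helpers from torch.distributed.checkpoint.state_dict, while the LR scheduler has no such helper, so the container has to implement the Stateful protocol's state_dict()/load_state_dict() by hand, which is what this PR does.

```python
from typing import Any, Dict

import torch
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_optimizer_state_dict,
)


def build_app_state(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    lr_schedulers,  # the SchedulersContainer from the diff above
) -> Dict[str, Any]:
    return {
        # Library-provided flattening helpers for model and optimizer state.
        "model": get_model_state_dict(model),
        "optimizer": get_optimizer_state_dict(model, optimizer),
        # There is no lr-scheduler counterpart, so the container itself is the
        # Stateful object whose state_dict()/load_state_dict() the checkpointer calls.
        "lr_scheduler": lr_schedulers,
    }
```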
Compared lr_schedulers before and after flattening, with and without checkpointing; the lr_scheduler values are consistent with the changes here.
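A rough sketch of that kind of before/after consistency check (assumed here; it is not the author's actual comparison script): run a toy schedule uninterrupted, run it again with a state_dict round-trip in the middle, and verify the final learning rates match.

```python
import torch
from torch.optim.lr_scheduler import LambdaLR


def make_scheduler(lr_lambda):
    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1.0)
    return LambdaLR(opt, lr_lambda=lr_lambda)


lr_lambda = lambda step: 0.5 ** (step // 3)  # toy step-decay schedule

# Baseline: six scheduler steps with no checkpointing.
uninterrupted = make_scheduler(lr_lambda)
for _ in range(6):
    uninterrupted.step()

# "Checkpointed" run: three steps, save the state, then resume in a fresh scheduler.
first_half = make_scheduler(lr_lambda)
for _ in range(3):
    first_half.step()
saved = first_half.state_dict()

resumed = make_scheduler(lr_lambda)
resumed.load_state_dict(saved)
for _ in range(3):
    resumed.step()

# The resumed run should end at the same learning rate as the uninterrupted one.
assert resumed.get_last_lr() == uninterrupted.get_last_lr()
```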
Does it support DCP resharding? E.g., changing the PP degree from 2 to 4 across two jobs.
I think this PR doesn't address the resharding issue, hence the [BE] prefix. Supporting lr resharding deserves a separate PR.