[BE] LR scheduler flatten (#794)
Currently, lr_scheduler is stored differently from optimizer, model, and
data_loader: each scheduler is stored in the state under its own key
("lr_scheduler_0", "lr_scheduler_1", ...).
This PR flattens lr_scheduler so that all the schedulers are stored under a
single self.states["lr_scheduler"] entry, which is consistent with optimizer.

Here we assume that all the optimizers share the same lr_scheduler, so it is
sufficient to save a single scheduler's state_dict and load it into all the
schedulers.

An lr_scheduler's state_dict looks like:
`{'base_lrs': [0.0003], 'last_epoch': 1, 'verbose': False,
'_step_count': 2, '_get_lr_called_within_step': False, '_last_lr':
[2.985074626865671e-06], 'lr_lambdas': [{}]}`
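For reference, a minimal self-contained snippet that reproduces a state_dict of
this shape. This is a reconstruction, not torchtitan's actual code: the
linear_warmup lambda and warmup_steps = 200 are assumptions chosen to match the
logged values below.

```python
import functools

import torch
from torch.optim.lr_scheduler import LambdaLR


def linear_warmup(warmup_steps: int, step: int) -> float:
    # Linear warmup factor: ramps from 1/(warmup_steps+1) toward 1.0.
    return float(step + 1) / (warmup_steps + 1)


model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# functools.partial is a callable object (not a plain function), so LambdaLR
# stores its (empty) __dict__ in the state_dict: that is where 'lr_lambdas': [{}]
# comes from.
scheduler = LambdaLR(optimizer, lr_lambda=functools.partial(linear_warmup, 200))

optimizer.step()
scheduler.step()  # last_epoch -> 1, _last_lr -> [3e-4 * 2/201] = [2.985e-06]
print(scheduler.state_dict())
```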

The PR is tested in two parts:
1. Test lr_scheduler values before and after checkpointing, resharding with
degree changes on tp and pp:
[dp=2, tp=4, pp=1] -> [dp=2, tp=1, pp=4]
[dp=2, tp=1, pp=4] -> [dp=2, tp=4, pp=1]
(data_loader does not support resharding right now.)

logs:
[dp=2, tp=4, pp=1] 
step 5 before saving to checkpoint:
[{'lr': 8.955223880597014e-06, ...}]

step 10 after loading from checkpoint and resharding to [dp=2, tp=2, pp=2]:
[{'lr': 1.6417910447761194e-05, ...}, {'lr': 1.6417910447761194e-05,
...}]

[dp=8, tp=1, pp=1]
step 5 without checkpoint:
[{'lr': 8.955223880597014e-06, ...}]
step 10 without checkpoint:
[{'lr': 1.6417910447761194e-05, ...}]
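A quick sanity check on these values, assuming the default warmup_steps = 200
(so the warmup factor at step s is (s + 1) / 201) and base lr 3e-4: step 5
gives 3e-4 * 6/201 ≈ 8.955e-6 and step 10 gives 3e-4 * 11/201 ≈ 1.642e-5,
matching the logs both with and without checkpointing/resharding.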

2. Memory trace:
Before the flattening, rerun llama3_8b.toml from step 5 to step 10:
<img width="1166" alt="Screenshot 2025-01-16 at 2 40 03 PM"
src="https://github.com/user-attachments/assets/d3e84d63-30be-4604-823b-68bd217498a0"
/>

After the flattening, rerun llama3_8b.toml from step 5 to step 10:
<img width="1166" alt="Screenshot 2025-01-16 at 2 40 21 PM"
src="https://github.com/user-attachments/assets/b6ed68ae-2dbf-400a-b723-06eae6740ade"
/>
mori360 authored Jan 31, 2025
1 parent 2271b63 commit cca0702
Showing 2 changed files with 17 additions and 12 deletions.
2 changes: 1 addition & 1 deletion torchtitan/checkpoint.py
```diff
@@ -183,9 +183,9 @@ def __init__(
                 "model": ModelWrapper(model_parts),
                 "optimizer": optimizers,
                 "dataloader": dataloader,
+                "lr_scheduler": lr_schedulers,
             }
         )
-        self.states.update(lr_schedulers.get_lr_scheduler_state())

         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval_type = (
```
27 changes: 16 additions & 11 deletions torchtitan/optimizer.py
```diff
@@ -178,7 +178,7 @@ def linear_warmup_linear_decay(
     return curr_adjustment


-class SchedulersContainer:
+class SchedulersContainer(Stateful):
     """Util for calling step on multiple learning rate schedulers needed for virtual pipeline stages"""

     def __init__(self, optimizers, lr_lambda) -> None:
@@ -190,16 +190,21 @@ def step(self) -> None:
         for scheduler in self.schedulers:
             scheduler.step()

-    def get_lr_scheduler_state(self) -> Dict[str, Any]:
-        state_dict = {}
-        if len(self.schedulers) == 1:
-            state_dict["lr_scheduler"] = self.schedulers[0]
-        else:
-            # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler.
-            # It should only support saving and loading a distributed checkpoint with the same number of pp ranks
-            for idx, lr_scheduler in enumerate(self.schedulers):
-                state_dict[f"lr_scheduler_{idx}"] = lr_scheduler
-        return state_dict
+    def state_dict(self) -> Dict[str, Any]:
+        # Currently, we have one scheduler per optimizer. However, when using MultiSchedule PP or optimizer-in-backward,
+        # there are multiple optimizers and schedulers, but the scheduler state_dict remains the same for all.
+        # Therefore, we only save the first one and later load it for all.
+        assert (
+            len(self.schedulers) > 0
+        ), "Must have at least one scheduler to save state_dict"
+        return self.schedulers[0].state_dict()
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        # Load the same state_dict for all schedulers. The key value we're concerned with in scheduler.state_dict() is `last_epoch`,
+        # which is an integer that will be automatically copied. As long as `training.steps` and `training.warmup_steps` remain
+        # unchanged when resuming from a checkpoint, this approach is safe. We call `.copy()` here to ensure extra safety.
+        for scheduler in self.schedulers:
+            scheduler.load_state_dict(state_dict.copy())


 def build_lr_schedulers(optimizers, job_config: JobConfig) -> SchedulersContainer:
```
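For context, here is a minimal sketch (an illustration, not code from this PR)
of how torch.distributed.checkpoint drives a Stateful value such as
SchedulersContainer: dcp.save() calls state_dict() on it and dcp.load() calls
load_state_dict(), so registering the container under one flat "lr_scheduler"
key is all the checkpoint integration needed. TinyStateful and the /tmp/ckpt
path are hypothetical.

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful


class TinyStateful(Stateful):
    # Hypothetical stand-in for SchedulersContainer.
    def __init__(self) -> None:
        self.last_epoch = 0

    def state_dict(self):
        return {"last_epoch": self.last_epoch}

    def load_state_dict(self, state_dict) -> None:
        self.last_epoch = state_dict["last_epoch"]


states = {"lr_scheduler": TinyStateful()}  # one flat key, as in this PR
states["lr_scheduler"].last_epoch = 5
dcp.save(states, checkpoint_id="/tmp/ckpt")   # invokes state_dict()

restored = {"lr_scheduler": TinyStateful()}
dcp.load(restored, checkpoint_id="/tmp/ckpt")  # invokes load_state_dict()
assert restored["lr_scheduler"].last_epoch == 5
```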
