From cca07028e440de6a13189d251c28337bd34256ef Mon Sep 17 00:00:00 2001
From: yifanmao
Date: Fri, 31 Jan 2025 14:10:43 -0800
Subject: [PATCH] [BE] Lr scheduler flatten (#794)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Currently, lr_scheduler is stored differently from optimizer, model, and
data_loader: each scheduler is kept in the state under its own key,
"lr_scheduler_0", "lr_scheduler_1", ...

This PR flattens lr_scheduler so that all the schedulers are stored under a
single self.states["lr_scheduler"] entry, which is consistent with optimizer.

We assume that all the optimizers share the same lr_scheduler, so it is enough
to save a single scheduler's state_dict and load it into all the schedulers.
An lr_scheduler state_dict looks like:
`{'base_lrs': [0.0003], 'last_epoch': 1, 'verbose': False, '_step_count': 2, '_get_lr_called_within_step': False, '_last_lr': [2.985074626865671e-06], 'lr_lambdas': [{}]}`

The PR is tested in 2 parts:

1. Check the lr_scheduler value before and after checkpointing, resharding
   with degree changes on tp and pp:
   [dp=2, tp=4, pp=1] -> [dp=2, tp=1, pp=4]
   [dp=2, tp=1, pp=4] -> [dp=2, tp=4, pp=1]
   (data_loader does not support resharding right now.)

   logs:
   [dp=2, tp=4, pp=1]
   step 5 before saving to checkpoint: [{'lr': 8.955223880597014e-06, ...}]
   step 10 after loading from checkpoint and reshard to [dp=2, tp=2, pp=2]:
   [{'lr': 1.6417910447761194e-05, ...}, {'lr': 1.6417910447761194e-05, ...}]

   [dp=8, tp=1, pp=1]
   step 5 without checkpoint: [{'lr': 8.955223880597014e-06, ...}]
   step 10 without checkpoint: [{'lr': 1.6417910447761194e-05, ...}]

2. Memory trace, rerunning llama3_8b.toml from step 5 to step 10:
   Before the flatten: [screenshot: memory trace, 2025-01-16 2:40:03 PM]
   After the flatten: [screenshot: memory trace, 2025-01-16 2:40:21 PM]
---
 torchtitan/checkpoint.py |  2 +-
 torchtitan/optimizer.py  | 27 ++++++++++++++++-----------
 2 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py
index 68479aad..7d143383 100644
--- a/torchtitan/checkpoint.py
+++ b/torchtitan/checkpoint.py
@@ -183,9 +183,9 @@ def __init__(
                 "model": ModelWrapper(model_parts),
                 "optimizer": optimizers,
                 "dataloader": dataloader,
+                "lr_scheduler": lr_schedulers,
             }
         )
-        self.states.update(lr_schedulers.get_lr_scheduler_state())

         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval_type = (
diff --git a/torchtitan/optimizer.py b/torchtitan/optimizer.py
index 8927125f..1b724b7a 100644
--- a/torchtitan/optimizer.py
+++ b/torchtitan/optimizer.py
@@ -178,7 +178,7 @@ def linear_warmup_linear_decay(
     return curr_adjustment


-class SchedulersContainer:
+class SchedulersContainer(Stateful):
     """Util for calling step on multiple learning rate schedulers needed for virtual pipeline stages"""

     def __init__(self, optimizers, lr_lambda) -> None:
@@ -190,16 +190,21 @@ def step(self) -> None:
         for scheduler in self.schedulers:
             scheduler.step()

-    def get_lr_scheduler_state(self) -> Dict[str, Any]:
-        state_dict = {}
-        if len(self.schedulers) == 1:
-            state_dict["lr_scheduler"] = self.schedulers[0]
-        else:
-            # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler.
-            # It should only support saving and loading a distributed checkpoint with the same number of pp ranks
-            for idx, lr_scheduler in enumerate(self.schedulers):
-                state_dict[f"lr_scheduler_{idx}"] = lr_scheduler
-        return state_dict
+    def state_dict(self) -> Dict[str, Any]:
+        # Currently, we have one scheduler per optimizer. However, when using MultiSchedule PP or optimizer-in-backward,
+        # there are multiple optimizers and schedulers, but the scheduler state_dict remains the same for all.
+        # Therefore, we only save the first one and later load it for all.
+        assert (
+            len(self.schedulers) > 0
+        ), "Must have at least one scheduler to save state_dict"
+        return self.schedulers[0].state_dict()
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        # Load the same state_dict for all schedulers. The key value we're concerned with in scheduler.state_dict() is `last_epoch`,
+        # which is an integer that will be automatically copied. As long as `training.steps` and `training.warmup_steps` remain
+        # unchanged when resuming from a checkpoint, this approach is safe. We call `.copy()` here to ensure extra safety.
+        for scheduler in self.schedulers:
+            scheduler.load_state_dict(state_dict.copy())


 def build_lr_schedulers(optimizers, job_config: JobConfig) -> SchedulersContainer:
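
For reference (not part of the patch): a minimal standalone sketch of the save-one / load-into-all behavior that the new `state_dict()` / `load_state_dict()` rely on, using a plain `torch.optim.lr_scheduler.LambdaLR`. `SimpleSchedulersContainer` and `toy_warmup` are hypothetical names for illustration only, not torchtitan code.

```python
from typing import Any, Dict, List

import torch
from torch.optim.lr_scheduler import LambdaLR


class SimpleSchedulersContainer:
    """Holds one LambdaLR per optimizer (e.g. one per virtual pipeline stage)."""

    def __init__(self, optimizers: List[torch.optim.Optimizer], lr_lambda) -> None:
        self.schedulers = [LambdaLR(opt, lr_lambda=lr_lambda) for opt in optimizers]

    def step(self) -> None:
        for scheduler in self.schedulers:
            scheduler.step()

    def state_dict(self) -> Dict[str, Any]:
        # All schedulers follow the same schedule, so saving the first one is enough.
        return self.schedulers[0].state_dict()

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Load the single saved state into every scheduler; `.copy()` for safety,
        # mirroring the patch, since load_state_dict touches the dict it is given.
        for scheduler in self.schedulers:
            scheduler.load_state_dict(state_dict.copy())


def toy_warmup(step: int) -> float:
    # Toy schedule: linear warmup over 10 steps, then flat.
    return min(1.0, (step + 1) / 10)


params = [torch.nn.Parameter(torch.zeros(1)) for _ in range(2)]
container = SimpleSchedulersContainer(
    [torch.optim.AdamW([p], lr=3e-4) for p in params], toy_warmup
)
for _ in range(5):
    container.step()

# Single flattened state, e.g. {'base_lrs': [0.0003], 'last_epoch': 5, ...}
saved = container.state_dict()

# A fresh run resuming from the saved state: every scheduler picks up last_epoch=5.
resumed = SimpleSchedulersContainer(
    [torch.optim.AdamW([p], lr=3e-4) for p in params], toy_warmup
)
resumed.load_state_dict(saved)
assert all(s.last_epoch == 5 for s in resumed.schedulers)
```

Only the plain fields such as `last_epoch` round-trip through the checkpoint; the `lr_lambdas` themselves are rebuilt from the job config on resume, which is why the patch notes that `training.steps` and `training.warmup_steps` must stay unchanged when resuming.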