Currently, `lr_scheduler` is stored differently from `optimizer`, `model`, and `data_loader`: the schedulers are kept in the state under the keys `"lr_scheduler_0"`, `"lr_scheduler_1"`, and so on. This PR flattens `lr_scheduler` so that all schedulers are stored as a list of `state_dict`s under `self.state['lr_scheduler']`, consistent with `optimizer`.

We assume that all optimizers share the same `lr_scheduler`, so it is enough to save a single scheduler's `state_dict` and load it into every scheduler (a minimal sketch of this idea follows the test results below).

The scheduler's `state_dict` looks like:
`{'base_lrs': [0.0003], 'last_epoch': 1, 'verbose': False, '_step_count': 2, '_get_lr_called_within_step': False, '_last_lr': [2.985074626865671e-06], 'lr_lambdas': [{}]}`

The PR is tested in two parts:

1. Compare the `lr_scheduler` values before and after checkpointing, with resharding that changes the tp and pp degrees:
   - [dp=2, tp=4, pp=1] -> [dp=2, tp=1, pp=4]
   - [dp=2, tp=1, pp=4] -> [dp=2, tp=4, pp=1]

   `data_loader` does not support resharding right now.

   Logs:
   - [dp=2, tp=4, pp=1] step 5 before saving to checkpoint: [{'lr': 8.955223880597014e-06, ...}]
   - step 10 after loading from checkpoint and resharding to [dp=2, tp=2, pp=2]: [{'lr': 1.6417910447761194e-05, ...}, {'lr': 1.6417910447761194e-05, ...}]
   - [dp=8, tp=1, pp=1] step 5 without checkpoint: [{'lr': 8.955223880597014e-06, ...}]
   - step 10 without checkpoint: [{'lr': 1.6417910447761194e-05, ...}]

2. Memory trace: rerun llama3_8b.toml from step 5 to step 10.

   Before the flatten:
   <img width="1166" alt="Screenshot 2025-01-16 at 2 40 03 PM" src="https://github.com/user-attachments/assets/d3e84d63-30be-4604-823b-68bd217498a0" />

   After the flatten:
   <img width="1166" alt="Screenshot 2025-01-16 at 2 40 21 PM" src="https://github.com/user-attachments/assets/b6ed68ae-2dbf-400a-b723-06eae6740ade" />
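For reference, here is a minimal sketch of the flattening idea, not the PR's actual implementation. The container class and its name (`FlatSchedulersContainer`) are assumptions for illustration: it wraps one `LambdaLR` per optimizer, checkpoints only the first scheduler's `state_dict`, and loads that same state into every scheduler, which is what allows the number of schedulers to change across resharding.

```python
# Minimal sketch (hypothetical names): hold one LambdaLR per optimizer, but
# checkpoint a single flattened state_dict, assuming all optimizers use the
# same learning-rate schedule.
import copy
from typing import Any, Dict, List

import torch
from torch.optim.lr_scheduler import LambdaLR


class FlatSchedulersContainer:
    """Holds one LambdaLR per optimizer; checkpoints a single state_dict."""

    def __init__(self, optimizers: List[torch.optim.Optimizer], lr_lambda) -> None:
        self.schedulers = [LambdaLR(opt, lr_lambda) for opt in optimizers]

    def step(self) -> None:
        for scheduler in self.schedulers:
            scheduler.step()

    def state_dict(self) -> Dict[str, Any]:
        # All schedulers are assumed identical, so saving the first is enough.
        return self.schedulers[0].state_dict()

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Load the same (deep-copied) state into every scheduler, so the
        # number of schedulers may differ from the number at save time.
        for scheduler in self.schedulers:
            scheduler.load_state_dict(copy.deepcopy(state_dict))
```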
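And a hypothetical round-trip check in the spirit of test 1 above, reusing the `FlatSchedulersContainer` sketch: save the flattened state from a container with one optimizer and load it into a container with two optimizers, as happens when the pp degree changes. The schedule, optimizer setup, and helper function here are all assumptions for illustration (and `optimizer.step()` calls are omitted for brevity).

```python
# Hypothetical round-trip check: one flattened scheduler state_dict is reloaded
# into a container holding a different number of schedulers.
import torch


def warmup_lambda(step: int) -> float:  # assumed schedule, for illustration only
    return min((step + 1) / 10, 1.0)


def make_container(num_optimizers: int) -> FlatSchedulersContainer:
    optimizers = [
        torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=3e-4)
        for _ in range(num_optimizers)
    ]
    return FlatSchedulersContainer(optimizers, warmup_lambda)


source = make_container(num_optimizers=1)   # e.g. pp=1: a single optimizer
for _ in range(5):
    source.step()
saved = source.state_dict()                 # single flattened state_dict

target = make_container(num_optimizers=2)   # e.g. pp=2: two optimizers after resharding
target.load_state_dict(saved)
for scheduler in target.schedulers:
    assert scheduler.last_epoch == source.schedulers[0].last_epoch
    print(scheduler.get_last_lr())          # both schedulers resume from the same lr
```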