
fix: adapter checkpoint loading on resume #3769

Merged: 5 commits into master, Nov 8, 2023
Conversation

geoffreyangus (Contributor)
This PR ensures that DeepSpeed does not throw an error on resume when loading adapter checkpoints.

Adapter checkpoints only contain a subset of the model's weights. This means that strict checkpoint loading throws a RuntimeError that may look like the following:

> /home/ray/anaconda3/lib/python3.8/site-packages/ludwig/trainers/trainer.py(839)train()
-> self.resume_weights_and_optimizer(training_checkpoints_path, checkpoint)
  /home/ray/anaconda3/lib/python3.8/site-packages/ludwig/trainers/trainer.py(1466)resume_weights_and_optimizer()
-> CheckpointManager.load_latest_checkpoint(checkpoint, model_weights_progress_path, self.device)
  /home/ray/anaconda3/lib/python3.8/site-packages/ludwig/utils/checkpoint_utils.py(332)load_latest_checkpoint()
-> checkpoint.load(last_ckpt, device)
  /home/ray/anaconda3/lib/python3.8/site-packages/ludwig/distributed/deepspeed.py(224)load()
-> _, client_state = self.model.load_checkpoint(save_path, load_lr_scheduler_states=False)
  /home/ray/anaconda3/lib/python3.8/site-packages/deepspeed/runtime/engine.py(2705)load_checkpoint()
-> load_path, client_states = self._load_checkpoint(load_dir,
  /home/ray/anaconda3/lib/python3.8/site-packages/deepspeed/runtime/engine.py(2773)_load_checkpoint()
-> self.load_module_state_dict(checkpoint=checkpoint,
  /home/ray/anaconda3/lib/python3.8/site-packages/deepspeed/runtime/engine.py(2568)load_module_state_dict()
-> self.module.load_state_dict(
  /home/ray/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py(2041)load_state_dict()
-> raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for LLM:
Missing key(s) in state_dict: "model.base_model.model.model.embed_tokens.weight", "model.base_model.model.model.layers.0.self_attn.q_proj.weight", "model.base_model.model.model.layers.0.self_attn.k_proj.weight", "model.base_model.model.model.layers.0.self_attn.v_proj.weight", "model.base_model.model.model.layers.0.self_attn.o_proj.weight",
...

since the base model weights are not stored in the checkpoint. This PR relaxes that strict constraint so resuming from an adapter checkpoint succeeds.
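The failure mode and the fix can be illustrated with plain PyTorch. This is a minimal sketch, not the PR's actual code: the two-layer model is hypothetical, and the PR applies the same idea inside Ludwig's DeepSpeed checkpoint loading rather than calling `load_state_dict` directly.

```python
import torch
import torch.nn as nn

# A hypothetical stand-in for the full model.
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

# An adapter-style checkpoint holds only a subset of the weights
# (here, only the parameters of the second layer).
partial_state = {"1.weight": torch.zeros(2, 8), "1.bias": torch.zeros(2)}

# Strict loading raises RuntimeError because the layer-0 keys are missing,
# mirroring the "Missing key(s) in state_dict" error in the traceback above.
try:
    model.load_state_dict(partial_state, strict=True)
except RuntimeError as e:
    print("strict load failed:", type(e).__name__)

# Relaxing the constraint loads the subset and merely reports missing keys.
result = model.load_state_dict(partial_state, strict=False)
print("missing keys:", result.missing_keys)
```

With `strict=False`, the base model keeps its existing weights for any parameter absent from the checkpoint, which is exactly the desired behavior when only adapter weights were saved.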

@arnavgarg1 (Contributor) left a comment

Makes sense! Thanks for this PR @geoffreyangus


github-actions bot commented Nov 7, 2023

Unit Test Results

  6 files ±0    6 suites ±0    19m 47s ⏱️ +6s
12 tests ±0:  9 ✔️ ±0   3 💤 ±0   0 failed ±0
60 runs  ±0: 42 ✔️ ±0  18 💤 ±0   0 failed ±0

Results for commit 8993a33. ± Comparison against base commit 9992f32.

♻️ This comment has been updated with latest results.

@geoffreyangus geoffreyangus merged commit 7458885 into master Nov 8, 2023
18 checks passed
@geoffreyangus geoffreyangus deleted the MLX-1509 branch November 8, 2023 01:06
Infernaught pushed a commit to Infernaught/nightlyfix that referenced this pull request Nov 9, 2023
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 participants