fix: adapter checkpoint loading on resume (#3769)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
geoffreyangus and pre-commit-ci[bot] authored Nov 8, 2023
1 parent 9992f32 commit 7458885
Showing 1 changed file with 8 additions and 1 deletion.
ludwig/distributed/deepspeed.py: 8 additions & 1 deletion
@@ -221,7 +221,14 @@ def load(self, save_path: str, device: Optional[torch.device] = None) -> bool:
         https://deepspeed.readthedocs.io/en/latest/model-checkpointing.html#loading-training-checkpoints
         """
-        _, client_state = self.model.load_checkpoint(save_path, load_lr_scheduler_states=False)
+        # NOTE(geoffrey): `load_module_strict=False` because this code path is frequently used to load models trained
+        # using adapter-based fine-tuning, where the checkpoints only contain the adapter weights, and not the full
+        # model weights. This may lead to silent, unexpected behavior when resuming full model fine-tuning,
+        # where all the model weights *must* be loaded in.
+        # TODO(geoffrey): Add a boolean arg to this function to control load_module_strict behavior.
+        _, client_state = self.model.load_checkpoint(
+            save_path, load_lr_scheduler_states=False, load_module_strict=False
+        )
         self.global_step = self._get_global_step(client_state, save_path)
         if self.scheduler is not None and "scheduler_state" in client_state:
             self.scheduler.load_state_dict(client_state["scheduler_state"])
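The TODO in the new code suggests exposing load_module_strict as an argument instead of hard-coding it to False. A minimal sketch of what that could look like for the load method on the DeepSpeed strategy class in ludwig/distributed/deepspeed.py, under stated assumptions: the load_module_strict parameter, its default, and the final return value are hypothetical and not part of this commit, and Optional and torch are assumed to be imported already in that module.

# Sketch of the TODO: expose load_module_strict as a parameter on the existing
# DeepSpeed strategy's load() method (hypothetical; not part of this commit).
def load(
    self,
    save_path: str,
    device: Optional[torch.device] = None,
    load_module_strict: bool = False,
) -> bool:
    """Load a DeepSpeed checkpoint from save_path.

    Pass load_module_strict=True when resuming full-model fine-tuning so that missing
    module weights fail loudly; leave it False for adapter-based fine-tuning, where the
    checkpoint contains only the adapter weights.
    """
    _, client_state = self.model.load_checkpoint(
        save_path, load_lr_scheduler_states=False, load_module_strict=load_module_strict
    )
    self.global_step = self._get_global_step(client_state, save_path)
    if self.scheduler is not None and "scheduler_state" in client_state:
        self.scheduler.load_state_dict(client_state["scheduler_state"])
    # The remainder of the original method is not shown in this diff; returning True
    # on success is an assumption made to satisfy the -> bool annotation.
    return True

With such a parameter, callers resuming full-model fine-tuning could pass load_module_strict=True to surface missing weights as errors instead of silently skipping them, while adapter-based fine-tuning would keep the lenient default.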
