Fix pytorch checkpointing for CL callback (#1583)
b-chu authored Oct 10, 2024
1 parent 1654827 commit 85b251f
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions llmfoundry/callbacks/curriculum_learning_callback.py
```diff
@@ -204,12 +204,12 @@ def load_state_dict(self, state: dict[str, Any]):
             ))

         # Ensure that the datamix has not changed on the current datamix
-        current_loader = self._schedule[self._schedule_index]['train_loader']
-        saved_loader = schedule[self._schedule_index]['train_loader']
-        if current_loader != saved_loader:
+        current_dataset = self._schedule[self._schedule_index]['dataset']
+        saved_dataset = schedule[self._schedule_index]['dataset']
+        if current_dataset != saved_dataset:
             raise ValueError((
                 f'The current datamix must stay the same across resumptions. ',
-                f'Expected {saved_loader} but got {current_loader}',
+                f'Expected {saved_dataset} but got {current_dataset}',
             ))

         # Ensure that the current datamix duration is in the correct units
```
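The fix narrows the resumption check from the full `train_loader` config to just its `dataset` entry, so loader-level settings (e.g. worker counts) can change across resumptions while the datamix itself must stay fixed. A minimal sketch of that validation pattern, with a hypothetical `validate_datamix` helper and made-up schedule dictionaries (not the actual llm-foundry structures):

```python
from typing import Any


def validate_datamix(
    current_schedule: list[dict[str, Any]],
    saved_schedule: list[dict[str, Any]],
    index: int,
) -> None:
    """Raise if the dataset at the current schedule index changed across
    resumptions. Only the 'dataset' entry must match; other loader
    settings are allowed to differ, mirroring the fix in the diff above."""
    current_dataset = current_schedule[index]['dataset']
    saved_dataset = saved_schedule[index]['dataset']
    if current_dataset != saved_dataset:
        raise ValueError(
            f'The current datamix must stay the same across resumptions. '
            f'Expected {saved_dataset} but got {current_dataset}'
        )


# Same dataset, different loader settings: passes after the fix.
saved = [{'dataset': {'remote': 's3://data/mix-a'}, 'num_workers': 2}]
current = [{'dataset': {'remote': 's3://data/mix-a'}, 'num_workers': 8}]
validate_datamix(current, saved, 0)  # no error

# A changed dataset is still rejected.
changed = [{'dataset': {'remote': 's3://data/mix-b'}, 'num_workers': 8}]
try:
    validate_datamix(changed, saved, 0)
except ValueError:
    print('datamix changed')
```

Note the original code passes a tuple of f-strings to `ValueError` (trailing commas inside the parentheses); the sketch joins them into a single message string instead.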
