diff --git a/llmfoundry/callbacks/curriculum_learning_callback.py b/llmfoundry/callbacks/curriculum_learning_callback.py index 7478496666..bfc7862a10 100644 --- a/llmfoundry/callbacks/curriculum_learning_callback.py +++ b/llmfoundry/callbacks/curriculum_learning_callback.py @@ -204,12 +204,12 @@ def load_state_dict(self, state: dict[str, Any]): )) # Ensure that the datamix has not changed on the current datamix - current_loader = self._schedule[self._schedule_index]['train_loader'] - saved_loader = schedule[self._schedule_index]['train_loader'] - if current_loader != saved_loader: + current_dataset = self._schedule[self._schedule_index]['dataset'] + saved_dataset = schedule[self._schedule_index]['dataset'] + if current_dataset != saved_dataset: raise ValueError(( f'The current datamix must stay the same across resumptions. ', - f'Expected {saved_loader} but got {current_loader}', + f'Expected {saved_dataset} but got {current_dataset}', )) # Ensure that the current datamix duration is in the correct units