From 25e4ac12f49af9c79c63810a663d38d4c89800d4 Mon Sep 17 00:00:00 2001
From: Geoffrey Angus
Date: Tue, 19 Mar 2024 17:15:20 -0700
Subject: [PATCH] enh: enable loading model weights from training checkpoint (#3969)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 ludwig/api.py                    | 31 ++++++--
 ludwig/trainers/trainer.py       |  9 ++-
 ludwig/utils/checkpoint_utils.py |  2 +-
 .../test_model_save_and_load.py  | 73 +++++++++++++++++++
 4 files changed, 106 insertions(+), 9 deletions(-)

diff --git a/ludwig/api.py b/ludwig/api.py
index 4df24eda93a..540c5797d6d 100644
--- a/ludwig/api.py
+++ b/ludwig/api.py
@@ -67,6 +67,7 @@
     MODEL_HYPERPARAMETERS_FILE_NAME,
     set_disable_progressbar,
     TRAIN_SET_METADATA_FILE_NAME,
+    TRAINING_CHECKPOINTS_DIR_PATH,
 )
 from ludwig.models.base import BaseModel
 from ludwig.models.calibrator import Calibrator
@@ -1282,9 +1283,12 @@ def evaluate(
                 self.model.output_features, predictions, dataset, training_set_metadata
             )
             eval_stats = {
-                of_name: {**eval_stats[of_name], **overall_stats[of_name]}
-                # account for presence of 'combined' key
-                if of_name in overall_stats else {**eval_stats[of_name]}
+                of_name: (
+                    {**eval_stats[of_name], **overall_stats[of_name]}
+                    # account for presence of 'combined' key
+                    if of_name in overall_stats
+                    else {**eval_stats[of_name]}
+                )
                 for of_name in eval_stats
             }
 
@@ -1765,6 +1769,7 @@
         gpu_memory_limit: Optional[float] = None,
         allow_parallel_threads: bool = True,
         callbacks: List[Callback] = None,
+        from_checkpoint: bool = False,
     ) -> "LudwigModel":  # return is an instance of ludwig.api.LudwigModel class
         """This function allows for loading pretrained models.
 
@@ -1788,6 +1793,9 @@
         :param callbacks: (list, default: `None`) a list of
             `ludwig.callbacks.Callback` objects that provide hooks into the
             Ludwig pipeline.
+        :param from_checkpoint: (bool, default: `False`) if `True`, the model
+            will be loaded from the latest checkpoint (training_checkpoints/)
+            instead of the final model weights.
 
         # Return
 
@@ -1834,7 +1842,7 @@
             ludwig_model.model = LudwigModel.create_model(config_obj)
 
             # load model weights
-            ludwig_model.load_weights(model_dir)
+            ludwig_model.load_weights(model_dir, from_checkpoint)
 
             # The LoRA layers appear to be loaded again (perhaps due to a potential bug); hence, we merge and unload again.
             if ludwig_model.is_merge_and_unload_set():
@@ -1851,12 +1859,16 @@
     def load_weights(
         self,
         model_dir: str,
+        from_checkpoint: bool = False,
     ) -> None:
         """Loads weights from a pre-trained model.
 
         # Inputs
 
         :param model_dir: (str) filepath string to location of a pre-trained model
+        :param from_checkpoint: (bool, default: `False`) if `True`, the model
+            will be loaded from the latest checkpoint (training_checkpoints/)
+            instead of the final model weights.
 
         # Return
 
         :return: `None`
@@ -1868,7 +1880,16 @@ def load_weights(
         ```
         """
         if self.backend.is_coordinator():
-            self.model.load(model_dir)
+            if from_checkpoint:
+                with self.backend.create_trainer(
+                    model=self.model,
+                    config=self.config_obj.trainer,
+                ) as trainer:
+                    checkpoint = trainer.create_checkpoint_handle()
+                    training_checkpoints_path = os.path.join(model_dir, TRAINING_CHECKPOINTS_DIR_PATH)
+                    trainer.resume_weights_and_optimizer(training_checkpoints_path, checkpoint)
+            else:
+                self.model.load(model_dir)
 
         self.backend.sync_model(self.model)
 
diff --git a/ludwig/trainers/trainer.py b/ludwig/trainers/trainer.py
index 7a6ff400ca6..b4a66af14cf 100644
--- a/ludwig/trainers/trainer.py
+++ b/ludwig/trainers/trainer.py
@@ -821,6 +821,11 @@ def save_checkpoint(self, progress_tracker: ProgressTracker, save_path: str, che
         # Callback that the checkpoint was reached, regardless of whether the model was evaluated.
         self.callback(lambda c: c.on_checkpoint(self, progress_tracker))
 
+    def create_checkpoint_handle(self):
+        return self.distributed.create_checkpoint_handle(
+            dist_model=self.dist_model, model=self.model, optimizer=self.optimizer, scheduler=self.scheduler
+        )
+
     def train(
         self,
         training_set,
@@ -873,9 +878,7 @@ def train(
         )
 
         # ====== Setup session =======
-        checkpoint = self.distributed.create_checkpoint_handle(
-            dist_model=self.dist_model, model=self.model, optimizer=self.optimizer, scheduler=self.scheduler
-        )
+        checkpoint = self.create_checkpoint_handle()
         checkpoint_manager = CheckpointManager(checkpoint, training_checkpoints_path, device=self.device)
 
         # ====== Setup Tensorboard writers =======
diff --git a/ludwig/utils/checkpoint_utils.py b/ludwig/utils/checkpoint_utils.py
index 9d37ddcbfc3..6988347a19c 100644
--- a/ludwig/utils/checkpoint_utils.py
+++ b/ludwig/utils/checkpoint_utils.py
@@ -337,4 +337,4 @@ def load_latest_checkpoint(checkpoint: Checkpoint, directory: str, device: torch
     if last_ckpt:
         checkpoint.load(last_ckpt, device)
     else:
-        logger.error(f"No checkpoints found in {directory}.")
+        raise FileNotFoundError(f"No checkpoints found in {directory}.")
diff --git a/tests/integration_tests/test_model_save_and_load.py b/tests/integration_tests/test_model_save_and_load.py
index 5dc2d794fe5..8f2eaab372e 100644
--- a/tests/integration_tests/test_model_save_and_load.py
+++ b/tests/integration_tests/test_model_save_and_load.py
@@ -32,6 +32,79 @@
 )
 
 
+def test_model_load_from_checkpoint(tmpdir, csv_filename, tmp_path):
+    torch.manual_seed(1)
+    random.seed(1)
+    np.random.seed(1)
+
+    input_features = [
+        binary_feature(),
+        number_feature(),
+    ]
+
+    output_features = [
+        binary_feature(),
+    ]
+
+    data_csv_path = generate_data(input_features, output_features, csv_filename, num_examples=50)
+
+    config = {
+        "input_features": input_features,
+        "output_features": output_features,
+        TRAINER: {"epochs": 1, BATCH_SIZE: 2},
+    }
+    backend = LocalTestBackend()
+
+    # create sub-directory to store results
+    results_dir = tmp_path / "results"
+    results_dir.mkdir()
+
+    data_df = read_csv(data_csv_path)
+    splitter = get_splitter("random")
+    training_set, validation_set, test_set = splitter.split(data_df, backend)
+    ludwig_model1 = LudwigModel(config, backend=backend)
+    _, _, output_dir = ludwig_model1.train(
+        training_set=training_set,
+        validation_set=validation_set,
+        test_set=test_set,
+        output_directory=str(results_dir),
+    )
+
+    model_dir = os.path.join(output_dir, "model")
+    ludwig_model_loaded = LudwigModel.load(model_dir, backend=backend, from_checkpoint=True)
+    preds_1, _ = ludwig_model1.predict(dataset=validation_set)
+
+    def check_model_equal(ludwig_model2):
+        # Compare model predictions
+        preds_2, _ = ludwig_model2.predict(dataset=validation_set)
+        assert set(preds_1.keys()) == set(preds_2.keys())
+        for key in preds_1:
+            assert preds_1[key].dtype == preds_2[key].dtype, key
+            assert all(a == b for a, b in zip(preds_1[key], preds_2[key])), key
+            # assert preds_2[key].dtype == preds_3[key].dtype, key
+            # assert list(preds_2[key]) == list(preds_3[key]), key
+
+        # Compare model weights
+        for if_name in ludwig_model1.model.input_features:
+            if1 = ludwig_model1.model.input_features.get(if_name)
+            if2 = ludwig_model2.model.input_features.get(if_name)
+            for if1_w, if2_w in zip(if1.encoder_obj.parameters(), if2.encoder_obj.parameters()):
+                assert torch.allclose(if1_w, if2_w)
+
+        c1 = ludwig_model1.model.combiner
+        c2 = ludwig_model2.model.combiner
+        for c1_w, c2_w in zip(c1.parameters(), c2.parameters()):
+            assert torch.allclose(c1_w, c2_w)
+
+        for of_name in ludwig_model1.model.output_features:
+            of1 = ludwig_model1.model.output_features.get(of_name)
+            of2 = ludwig_model2.model.output_features.get(of_name)
+            for of1_w, of2_w in zip(of1.decoder_obj.parameters(), of2.decoder_obj.parameters()):
+                assert torch.allclose(of1_w, of2_w)
+
+    check_model_equal(ludwig_model_loaded)
+
+
 def test_model_save_reload_api(tmpdir, csv_filename, tmp_path):
     torch.manual_seed(1)
     random.seed(1)
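
A minimal usage sketch of the new flag, not part of the patch itself: the `results/api_experiment_run/model` path below is hypothetical, and the `FileNotFoundError` fallback assumes the error now raised by `load_latest_checkpoint` when `training_checkpoints/` is empty propagates up through `LudwigModel.load`.

```python
from ludwig.api import LudwigModel

model_dir = "results/api_experiment_run/model"  # hypothetical path to a trained model directory

try:
    # Load weights from the latest entry under model_dir/training_checkpoints/
    # instead of the final saved weights.
    model = LudwigModel.load(model_dir, from_checkpoint=True)
except FileNotFoundError:
    # No checkpoints were written; fall back to the final model weights.
    model = LudwigModel.load(model_dir)
```

With the default `from_checkpoint=False`, `LudwigModel.load` behaves exactly as before the patch.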