diff --git a/geo_deep_learning/train.py b/geo_deep_learning/train.py
index d2f63218..9f899662 100644
--- a/geo_deep_learning/train.py
+++ b/geo_deep_learning/train.py
@@ -1,6 +1,8 @@
 import torch
 from lightning.pytorch import Trainer
 from lightning.pytorch.cli import ArgsType, LightningCLI
+from lightning.pytorch.loggers import MLFlowLogger
+from lightning.pytorch.loggers import Logger
 
 
 class GeoDeepLearningCLI(LightningCLI):
@@ -16,7 +18,9 @@ def after_fit(self):
                                accelerator="auto",
                                strategy="auto")
         best_model = self.model.__class__.load_from_checkpoint(best_model_path)
-        test_trainer.test(model=best_model, dataloaders=self.datamodule.test_dataloader())
+        test_results = test_trainer.test(model=best_model, dataloaders=self.datamodule.test_dataloader())
+        self.log_test_metrics_to_mlflow(test_results)
+        # self.log_test_metrics(test_results)
         self.trainer.strategy.barrier()
 
     def print_dataset_sizes(self):
@@ -28,6 +32,40 @@ def print_dataset_sizes(self):
         print(f"Number of training samples: {train_size}")
         print(f"Number of validation samples: {val_size}")
         print(f"Number of test samples: {test_size}")
+
+    def log_test_metrics_to_mlflow(self, test_results):
+        # Get the MLflow logger from the trainer
+        mlf_logger = next((logger for logger in self.trainer.loggers if isinstance(logger, MLFlowLogger)), None)
+
+        if mlf_logger is not None:
+            # Log each metric from the test results
+            for metric_name, metric_value in test_results[0].items():
+                mlf_logger.experiment.log_metric(
+                    run_id=mlf_logger.run_id,
+                    key=f"test_{metric_name}",
+                    value=metric_value
+                )
+            print("Test metrics logged to MLflow successfully.")
+        else:
+            print("MLflow logger not found in the trainer's loggers.")
+
+    # def log_test_metrics(self, test_results):
+    #     if not self.trainer.logger:
+    #         print("No logger found. Test metrics will not be logged.")
+    #         return
+
+    #     if isinstance(self.trainer.logger, Logger):
+    #         for metric_name, metric_value in test_results[0].items():
+    #             self.trainer.logger.log_metrics({f"test_{metric_name}": metric_value})
+    #         print("Test metrics logged successfully.")
+    #     elif isinstance(self.trainer.logger, list):
+    #         for logger in self.trainer.logger:
+    #             if isinstance(logger, Logger):
+    #                 for metric_name, metric_value in test_results[0].items():
+    #                     logger.log_metrics({f"test_{metric_name}": metric_value})
+    #         print("Test metrics logged successfully to all loggers.")
+    #     else:
+    #         print("Unsupported logger type. Test metrics will not be logged.")
 
 
 def main(args: ArgsType = None) -> None:
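
For context on where the new log_test_metrics_to_mlflow helper expects to find its logger: the sketch below is illustrative only and not part of the diff; it assumes an MLFlowLogger is attached to the trainer, and the experiment name, tracking URI, and the "test_iou" metric are placeholders.

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import MLFlowLogger

# Placeholder MLflow setup: a local file store and an arbitrary experiment name.
mlf_logger = MLFlowLogger(
    experiment_name="geo-deep-learning",
    tracking_uri="file:./mlruns",
)
trainer = Trainer(logger=mlf_logger, accelerator="auto", devices=1)

# trainer.loggers exposes the configured loggers as a list, so the same
# lookup used in the diff finds the MLflow logger here.
found = next((lg for lg in trainer.loggers if isinstance(lg, MLFlowLogger)), None)
assert found is mlf_logger

# Metrics are written through the underlying MlflowClient against the run
# created by the logger, shown here with a hypothetical metric.
found.experiment.log_metric(run_id=found.run_id, key="test_iou", value=0.85)

Going through mlf_logger.experiment (the raw MlflowClient) with an explicit run_id, rather than the generic log_metrics API, lets the helper write into the same MLflow run that the fit stage created, even though the test metrics come from a separately constructed test Trainer.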