Skip to content

Commit

Permalink
Refactor train.py to log test metrics to MLflow
Browse files Browse the repository at this point in the history
  • Loading branch information
valhassan committed Sep 27, 2024
1 parent 2926b66 commit deca95a
Showing 1 changed file with 39 additions and 1 deletion.
40 changes: 39 additions & 1 deletion geo_deep_learning/train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.cli import ArgsType, LightningCLI
from lightning.pytorch.loggers import MLFlowLogger
from lightning.pytorch.loggers import Logger


class GeoDeepLearningCLI(LightningCLI):
Expand All @@ -16,7 +18,9 @@ def after_fit(self):
accelerator="auto",
strategy="auto")
best_model = self.model.__class__.load_from_checkpoint(best_model_path)
test_trainer.test(model=best_model, dataloaders=self.datamodule.test_dataloader())
test_results = test_trainer.test(model=best_model, dataloaders=self.datamodule.test_dataloader())
self.log_test_metrics_to_mlflow(test_results)
# self.log_test_metrics(test_results)
self.trainer.strategy.barrier()

def print_dataset_sizes(self):
Expand All @@ -28,6 +32,40 @@ def print_dataset_sizes(self):
print(f"Number of training samples: {train_size}")
print(f"Number of validation samples: {val_size}")
print(f"Number of test samples: {test_size}")

def log_test_metrics_to_mlflow(self, test_results):
# Get the MLflow logger from the trainer
mlf_logger = next((logger for logger in self.trainer.loggers if isinstance(logger, MLFlowLogger)), None)

if mlf_logger is not None:
# Log each metric from the test results
for metric_name, metric_value in test_results[0].items():
mlf_logger.experiment.log_metric(
run_id=mlf_logger.run_id,
key=f"test_{metric_name}",
value=metric_value
)
print("Test metrics logged to MLflow successfully.")
else:
print("MLflow logger not found in the trainer's loggers.")

# def log_test_metrics(self, test_results):
# if not self.trainer.logger:
# print("No logger found. Test metrics will not be logged.")
# return

# if isinstance(self.trainer.logger, Logger):
# for metric_name, metric_value in test_results[0].items():
# self.trainer.logger.log_metrics({f"test_{metric_name}": metric_value})
# print("Test metrics logged successfully.")
# elif isinstance(self.trainer.logger, list):
# for logger in self.trainer.logger:
# if isinstance(logger, Logger):
# for metric_name, metric_value in test_results[0].items():
# logger.log_metrics({f"test_{metric_name}": metric_value})
# print("Test metrics logged successfully to all loggers.")
# else:
# print("Unsupported logger type. Test metrics will not be logged.")


def main(args: ArgsType = None) -> None:
Expand Down

0 comments on commit deca95a

Please sign in to comment.