Refactor train.py to include model logging in MLflow
valhassan committed Oct 2, 2024
1 parent 6e273e5 commit 7f96659
Showing 2 changed files with 6 additions and 18 deletions.
1 change: 1 addition & 0 deletions configs/config.yaml
@@ -9,6 +9,7 @@ trainer:
    class_path: lightning.pytorch.loggers.mlflow.MLFlowLogger
    init_args:
      save_dir: /home/valhassa/Projects/geo-deep-learning/logs
      log_model: "all"
      experiment_name: "gdl_experiment"
      run_name: "gdl_run"
  callbacks:
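For context, a minimal sketch of the logger this YAML block configures once LightningCLI builds the Trainer; the argument values are copied from the config above. With log_model: "all", the MLFlowLogger uploads checkpoints as MLflow artifacts as they are saved during training, not only at the end.

# Sketch: programmatic equivalent of the MLFlowLogger block in config.yaml.
from lightning.pytorch import Trainer
from lightning.pytorch.loggers.mlflow import MLFlowLogger

mlflow_logger = MLFlowLogger(
    experiment_name="gdl_experiment",
    run_name="gdl_run",
    save_dir="/home/valhassa/Projects/geo-deep-learning/logs",
    log_model="all",  # log checkpoints as MLflow artifacts during training
)
trainer = Trainer(logger=mlflow_logger)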
23 changes: 5 additions & 18 deletions geo_deep_learning/train.py
@@ -1,7 +1,7 @@
from tools.mlflow_logger import LoggerSaveConfigCallback
from lightning.pytorch import Trainer
from lightning.pytorch.cli import ArgsType, LightningCLI
from lightning.pytorch.loggers import Logger

class GeoDeepLearningCLI(LightningCLI):
    def before_fit(self):
        self.datamodule.prepare_data()
@@ -17,7 +17,10 @@ def after_fit(self):
            logger=False)
        best_model = self.model.__class__.load_from_checkpoint(best_model_path)
        test_results = test_trainer.test(model=best_model, dataloaders=self.datamodule.test_dataloader())
        self.log_test_metrics(test_results)
        for metric_name, metric_value in test_results[0].items():
            self.trainer.logger.log_metrics({f"test_{metric_name}": metric_value})
        self.trainer.logger.log_hyperparams({"best_model_path": best_model_path})
        print("Test metrics logged successfully to all loggers.")
        self.trainer.strategy.barrier()

    def log_dataset_sizes(self):
@@ -39,22 +42,6 @@ def log_dataset_sizes(self):
print(f"Number of validation samples: {val_size}")
print(f"Number of test samples: {test_size}")

    def log_test_metrics(self, test_results):
        if not self.trainer.logger:
            print("No logger found. Test metrics will not be logged.")
            return
        if isinstance(self.trainer.logger, Logger):
            for metric_name, metric_value in test_results[0].items():
                self.trainer.logger.log_metrics({f"test_{metric_name}": metric_value})
            print("Test metrics logged successfully.")
        elif isinstance(self.trainer.logger, list):
            for logger in self.trainer.logger:
                if isinstance(logger, Logger):
                    for metric_name, metric_value in test_results[0].items():
                        logger.log_metrics({f"test_{metric_name}": metric_value})
            print("Test metrics logged successfully to all loggers.")
        else:
            print("Unsupported logger type. Test metrics will not be logged.")


def main(args: ArgsType = None) -> None:
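Taken together, the train.py hunks above drop the standalone log_test_metrics helper and log the test metrics inline in after_fit. A rough sketch of the resulting flow follows; the checkpoint lookup and the test Trainer arguments are assumptions, since those lines fall outside this diff.

# Sketch of the refactored after_fit flow (abbreviated). The best_model_path
# lookup and the test Trainer arguments are assumptions; the logging loop at
# the bottom is taken from the diff above.
from lightning.pytorch import Trainer
from lightning.pytorch.cli import LightningCLI


class GeoDeepLearningCLI(LightningCLI):
    def after_fit(self):
        # Assumption: the best checkpoint path comes from the checkpoint callback.
        best_model_path = self.trainer.checkpoint_callback.best_model_path
        # Run the test loop with a separate Trainer that has logging disabled;
        # the metrics are forwarded manually below.
        test_trainer = Trainer(logger=False)
        best_model = self.model.__class__.load_from_checkpoint(best_model_path)
        test_results = test_trainer.test(
            model=best_model, dataloaders=self.datamodule.test_dataloader()
        )
        # Forward each test metric and the checkpoint path to the fit Trainer's
        # logger, i.e. the MLFlowLogger configured in config.yaml.
        for metric_name, metric_value in test_results[0].items():
            self.trainer.logger.log_metrics({f"test_{metric_name}": metric_value})
        self.trainer.logger.log_hyperparams({"best_model_path": best_model_path})
        self.trainer.strategy.barrier()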
