From 0c655e086b7d6d74685f7379b2c9a3bdab01390e Mon Sep 17 00:00:00 2001 From: Infernaught <72055086+Infernaught@users.noreply.github.com> Date: Wed, 27 Sep 2023 15:04:05 -0400 Subject: [PATCH] Remove duplicate metrics (#3670) --- ludwig/contribs/mlflow/__init__.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/ludwig/contribs/mlflow/__init__.py b/ludwig/contribs/mlflow/__init__.py index 452b5a9e36d..fa1f99c384d 100644 --- a/ludwig/contribs/mlflow/__init__.py +++ b/ludwig/contribs/mlflow/__init__.py @@ -37,6 +37,8 @@ def get_or_create_experiment_id(experiment_name, artifact_uri: str = None): @PublicAPI class MlflowCallback(Callback): def __init__(self, tracking_uri=None, log_artifacts: bool = True): + self.logged_steps = set() + if tracking_uri: mlflow.set_tracking_uri(tracking_uri) self.tracking_uri = mlflow.get_tracking_uri() @@ -157,13 +159,17 @@ def on_trainer_train_setup(self, trainer, save_path, is_coordinator): self.save_fn = lambda args: _log_mlflow(*args, self.log_artifacts) def on_eval_end(self, trainer, progress_tracker, save_path): - self.save_fn((progress_tracker.log_metrics(), progress_tracker.steps, save_path, True)) + if progress_tracker.steps not in self.logged_steps: + self.logged_steps.add(progress_tracker.steps) + self.save_fn((progress_tracker.log_metrics(), progress_tracker.steps, save_path, True)) # Why True? def on_trainer_train_teardown(self, trainer, progress_tracker, save_path, is_coordinator): if is_coordinator: - self.save_fn((progress_tracker.log_metrics(), progress_tracker.steps, save_path, False)) - if self.save_thread is not None: - self.save_thread.join() + if progress_tracker.steps not in self.logged_steps: + self.logged_steps.add(progress_tracker.steps) + self.save_fn((progress_tracker.log_metrics(), progress_tracker.steps, save_path, False)) # Why False? + if self.save_thread is not None: + self.save_thread.join() def on_visualize_figure(self, fig): # TODO: need to also include a filename for this figure