Skip to content

Commit

Permalink
Remove duplicate metrics (#3670)
Browse files Browse the repository at this point in the history
  • Loading branch information
Infernaught authored Sep 27, 2023
1 parent 1286123 commit 0c655e0
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions ludwig/contribs/mlflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def get_or_create_experiment_id(experiment_name, artifact_uri: str = None):
@PublicAPI
class MlflowCallback(Callback):
def __init__(self, tracking_uri=None, log_artifacts: bool = True):
self.logged_steps = set()

if tracking_uri:
mlflow.set_tracking_uri(tracking_uri)
self.tracking_uri = mlflow.get_tracking_uri()
Expand Down Expand Up @@ -157,13 +159,17 @@ def on_trainer_train_setup(self, trainer, save_path, is_coordinator):
self.save_fn = lambda args: _log_mlflow(*args, self.log_artifacts)

def on_eval_end(self, trainer, progress_tracker, save_path):
self.save_fn((progress_tracker.log_metrics(), progress_tracker.steps, save_path, True))
if progress_tracker.steps not in self.logged_steps:
self.logged_steps.add(progress_tracker.steps)
self.save_fn((progress_tracker.log_metrics(), progress_tracker.steps, save_path, True)) # Why True?

def on_trainer_train_teardown(self, trainer, progress_tracker, save_path, is_coordinator):
if is_coordinator:
self.save_fn((progress_tracker.log_metrics(), progress_tracker.steps, save_path, False))
if self.save_thread is not None:
self.save_thread.join()
if progress_tracker.steps not in self.logged_steps:
self.logged_steps.add(progress_tracker.steps)
self.save_fn((progress_tracker.log_metrics(), progress_tracker.steps, save_path, False)) # Why False?
if self.save_thread is not None:
self.save_thread.join()

def on_visualize_figure(self, fig):
# TODO: need to also include a filename for this figure
Expand Down

0 comments on commit 0c655e0

Please sign in to comment.