Skip to content

Commit

Permalink
Add tag support to MLFlowLogger (#2716)
Browse files Browse the repository at this point in the history
  • Loading branch information
aspfohl authored Nov 17, 2023
1 parent 15d4c34 commit 1b28293
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions composer/loggers/mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class MLFlowLogger(LoggerDestination):
use the MLflow environment variable or a default value
run_name: (str, optional): MLflow run name. If not set it will be the same as the
Trainer run name
tags: (dict, optional): MLflow tags to log with the run
tracking_uri (str | pathlib.Path, optional): MLflow tracking uri, the URI to the
remote or local endpoint where logs are stored (If none it is set to MLflow default)
rank_zero_only (bool, optional): Whether to log only on the rank-zero process
Expand All @@ -55,6 +56,7 @@ def __init__(
self,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
tags: Optional[Dict[str, Any]] = None,
tracking_uri: Optional[Union[str, pathlib.Path]] = None,
rank_zero_only: bool = True,
flush_interval: int = 10,
Expand All @@ -71,8 +73,9 @@ def __init__(
conda_channel='conda-forge') from e
self._enabled = (not rank_zero_only) or dist.get_global_rank() == 0

self.run_name = run_name
self.experiment_name = experiment_name
self.run_name = run_name
self.tags = tags
self.model_registry_prefix = model_registry_prefix
self.model_registry_uri = model_registry_uri
if self.model_registry_uri == 'databricks-uc':
Expand Down Expand Up @@ -133,7 +136,7 @@ def init(self, state: State, logger: Logger) -> None:
run_name=self.run_name,
)
self._run_id = new_run.info.run_id
mlflow.start_run(run_id=self._run_id)
mlflow.start_run(run_id=self._run_id, tags=self.tags)

def log_table(self, columns: List[str], rows: List[List[Any]], name: str = 'Table') -> None:
if self._enabled:
Expand Down

0 comments on commit 1b28293

Please sign in to comment.