Skip to content

Commit

Permalink
Update CometMLTracker to allow re-using an experiment
Browse files Browse the repository at this point in the history
Update CometMLTracker to use the new `comet_ml.start` function to create
Experiments. This way, end-users can create online or offline experiments and
append data to an existing experiment; it also automatically re-uses a running
experiment if one is present rather than creating a new one.
  • Loading branch information
Lothiraldan committed Jan 7, 2025
1 parent d6d3e03 commit 6fc980b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 20 deletions.
16 changes: 8 additions & 8 deletions src/accelerate/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ class CometMLTracker(GeneralTracker):
run_name (`str`):
The name of the experiment run.
**kwargs (additional keyword arguments, *optional*):
Additional key word arguments passed along to the `Experiment.__init__` method.
Additional key word arguments passed along to the `comet_ml.start` method.
"""

name = "comet_ml"
Expand All @@ -417,9 +417,9 @@ def __init__(self, run_name: str, **kwargs):
super().__init__()
self.run_name = run_name

from comet_ml import Experiment
from comet_ml import start

self.writer = Experiment(project_name=run_name, **kwargs)
self.writer = start(project_name=run_name, **kwargs)
logger.debug(f"Initialized CometML project {self.run_name}")
logger.debug(
"Make sure to log any initial configurations with `self.store_init_configuration` before training!"
Expand All @@ -440,7 +440,7 @@ def store_init_configuration(self, values: dict):
`str`, `float`, `int`, or `None`.
"""
self.writer.log_parameters(values)
logger.debug("Stored initial configuration hyperparameters to CometML")
logger.debug("Stored initial configuration hyperparameters to Comet")

@on_main_process
def log(self, values: dict, step: Optional[int] = None, **kwargs):
Expand All @@ -466,15 +466,15 @@ def log(self, values: dict, step: Optional[int] = None, **kwargs):
self.writer.log_other(k, v, **kwargs)
elif isinstance(v, dict):
self.writer.log_metrics(v, step=step, **kwargs)
logger.debug("Successfully logged to CometML")
logger.debug("Successfully logged to Comet")

@on_main_process
def finish(self):
"""
Closes `comet-ml` writer
Flush `comet-ml` writer
"""
self.writer.end()
logger.debug("CometML run closed")
self.writer.flush()
logger.debug("Comet run flushed")


class AimTracker(GeneralTracker):
Expand Down
19 changes: 7 additions & 12 deletions tests/test_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@


if is_comet_ml_available():
from comet_ml import OfflineExperiment
from comet_ml import ExperimentConfig

if is_tensorboard_available():
import struct
Expand Down Expand Up @@ -201,16 +201,7 @@ def test_wandb(self):
assert logged_items["_step"] == "0"


# Comet has a special `OfflineExperiment` we need to use for testing
def offline_init(self, run_name: str, tmpdir: str):
self.run_name = run_name
self.writer = OfflineExperiment(project_name=run_name, offline_directory=tmpdir)
logger.info(f"Initialized offline CometML project {self.run_name}")
logger.info("Make sure to log any initial configurations with `self.store_init_configuration` before training!")


@require_comet_ml
@mock.patch.object(CometMLTracker, "__init__", offline_init)
class CometMLTest(unittest.TestCase):
@staticmethod
def get_value_from_key(log_list, key: str, is_param: bool = False):
Expand All @@ -231,7 +222,9 @@ def get_value_from_key(log_list, key: str, is_param: bool = False):

def test_init_trackers(self):
with tempfile.TemporaryDirectory() as d:
tracker = CometMLTracker("test_project_with_config", d)
tracker = CometMLTracker(
"test_project_with_config", online=False, experiment_config=ExperimentConfig(offline_directory=d)
)
accelerator = Accelerator(log_with=tracker)
config = {"num_iterations": 12, "learning_rate": 1e-2, "some_boolean": False, "some_string": "some_value"}
accelerator.init_trackers(None, config)
Expand All @@ -249,7 +242,9 @@ def test_init_trackers(self):

def test_log(self):
with tempfile.TemporaryDirectory() as d:
tracker = CometMLTracker("test_project_with_config", d)
tracker = CometMLTracker(
"test_project_with_config", online=False, experiment_config=ExperimentConfig(offline_directory=d)
)
accelerator = Accelerator(log_with=tracker)
accelerator.init_trackers(None)
values = {"total_loss": 0.1, "iteration": 1, "my_text": "some_value"}
Expand Down

0 comments on commit 6fc980b

Please sign in to comment.