diff --git a/rubicon_ml/client/config.py b/rubicon_ml/client/config.py index f1bcc814..4df91a82 100644 --- a/rubicon_ml/client/config.py +++ b/rubicon_ml/client/config.py @@ -25,7 +25,7 @@ class Config: root_dir : str, optional Absolute or relative filepath. Defaults to using the local filesystem. Prefix with s3:// to use s3 instead. - auto_git_enabled : bool, optional + is_auto_git_enabled : bool, optional True to use the `git` command to automatically log relevant repository information to projects and experiments logged with this client instance, False otherwise. Defaults to False. diff --git a/rubicon_ml/client/project.py b/rubicon_ml/client/project.py index 74aa11ea..eb1ea7ab 100644 --- a/rubicon_ml/client/project.py +++ b/rubicon_ml/client/project.py @@ -78,7 +78,7 @@ def _create_experiment_domain( comments, ): """Instantiates and returns an experiment domain object.""" - if self.is_auto_git_enabled: + if self.is_auto_git_enabled(): if branch_name is None: branch_name = self._get_branch_name() if commit_hash is None: diff --git a/rubicon_ml/client/rubicon.py b/rubicon_ml/client/rubicon.py index d8ea41c7..2883f063 100644 --- a/rubicon_ml/client/rubicon.py +++ b/rubicon_ml/client/rubicon.py @@ -46,13 +46,20 @@ def __init__( Config( persistence=config["persistence"], root_dir=config["root_dir"], - auto_git_enabled=auto_git_enabled, + is_auto_git_enabled=auto_git_enabled, **storage_options, ) for config in composite_config ] else: - self.configs = [Config(persistence, root_dir, auto_git_enabled, **storage_options)] + self.configs = [ + Config( + persistence=persistence, + root_dir=root_dir, + is_auto_git_enabled=auto_git_enabled, + **storage_options, + ), + ] @property def config(self): @@ -116,7 +123,7 @@ def _create_project_domain( training_metadata: Optional[Union[List[Tuple], Tuple]], ): """Instantiates and returns a project domain object.""" - if self.is_auto_git_enabled and github_url is None: + if self.is_auto_git_enabled() and github_url is None: github_url = self._get_github_url() if training_metadata is not None: diff --git a/rubicon_ml/viz/experiments_table.py b/rubicon_ml/viz/experiments_table.py index fd4785f3..bcf4c11d 100644 --- a/rubicon_ml/viz/experiments_table.py +++ b/rubicon_ml/viz/experiments_table.py @@ -241,7 +241,7 @@ def load_experiment_data(self): "tags": ", ".join(str(tag) for tag in experiment.tags), } - if experiment.commit_hash is not None: + if experiment.commit_hash: experiment_record["commit_hash"] = experiment.commit_hash[:7] commit_hashes.add(experiment.commit_hash) diff --git a/tests/unit/client/test_rubicon_client.py b/tests/unit/client/test_rubicon_client.py index 06214337..c24807de 100644 --- a/tests/unit/client/test_rubicon_client.py +++ b/tests/unit/client/test_rubicon_client.py @@ -113,6 +113,20 @@ def test_create_project_with_auto_git(mock_completed_process_git): rubicon.repository.filesystem.store = {} +def test_create_project_withouy_auto_git(mock_completed_process_git): + with mock.patch("subprocess.run") as mock_run: + mock_run.return_value = mock_completed_process_git + + rubicon = Rubicon("memory", "test-root", auto_git_enabled=False) + rubicon.create_project("test_create_project_withouy_auto_git") + + expected = [] + + assert mock_run.mock_calls == expected + + rubicon.repository.filesystem.store = {} + + def test_get_project_by_name(rubicon_and_project_client): rubicon, project = rubicon_and_project_client diff --git a/tests/unit/viz/test_experiments_table.py b/tests/unit/viz/test_experiments_table.py index 5236b7e2..406549a5 100644 --- a/tests/unit/viz/test_experiments_table.py +++ b/tests/unit/viz/test_experiments_table.py @@ -17,6 +17,15 @@ def test_experiments_table(viz_experiments): assert experiments_table.is_selectable is True +def test_experiments_table_no_git_commit(viz_experiments): + for experiment in viz_experiments: + experiment._domain.commit_hash = "" + + experiments_table = ExperimentsTable(experiments=viz_experiments, is_selectable=True) + + assert len(viz_experiments) == len(experiments_table.experiments) + + def test_experiments_table_load_data(viz_experiments): experiments_table = ExperimentsTable(experiments=viz_experiments) experiments_table.load_experiment_data()