From a9cf1a18e0934511c1ec38028bf3e8e374016fa7 Mon Sep 17 00:00:00 2001 From: sonaalthaker <159057578+sonaalthaker@users.noreply.github.com> Date: Fri, 16 Feb 2024 11:19:43 -0800 Subject: [PATCH] comment retrieval support (#411) --- rubicon_ml/repository/base.py | 48 +++++++++++++++++++++++++ tests/unit/repository/test_base_repo.py | 36 +++++++++++++++++-- 2 files changed, 82 insertions(+), 2 deletions(-) diff --git a/rubicon_ml/repository/base.py b/rubicon_ml/repository/base.py index 93b28303..32d8c431 100644 --- a/rubicon_ml/repository/base.py +++ b/rubicon_ml/repository/base.py @@ -1252,3 +1252,51 @@ def add_comments( comment_metadata_path = f"{comment_metadata_root}/comments_{domain.utils.uuid.uuid4()}.json" self._persist_domain({"added_comments": comments}, comment_metadata_path) + + def _sort_comment_paths(self, comment_paths): + """Sorts the paths in `comment_paths` by when they were + created. + """ + return self._sort_tag_paths(comment_paths) + + def get_comments( + self, project_name, experiment_id=None, entity_identifier=None, entity_type=None + ): + """Retrieve comments from the configured filesystem. + + Parameters + ---------- + project_name : str + The name of the project the object to retrieve + comments from belongs to. + experiment_id : str, optional + The ID of the experiment to retrieve comments from. + entity_identifier : str, optional + The ID or name of the entity to apply the comments + `comments` to. + entity_type : str, optional + The name of the entity's type as returned by + `entity_cls.__class__.__name__`. + + Returns + ------- + dict + A dictionary, `added_comments` where the + value is a list of comment names that have + been added to the specified object. + """ + comment_metadata_root = self._get_comment_metadata_root( + project_name, experiment_id, entity_identifier, entity_type + ) + comment_metadata_glob = f"{comment_metadata_root}/comments_*.json" + + comment_paths = self._glob(comment_metadata_glob) + if len(comment_paths) == 0: + return [] + + sorted_comment_paths = self._sort_comment_paths(comment_paths) + + comment_data = self._cat([p for _, p in sorted_comment_paths]) + sorted_comment_data = [json.loads(comment_data[p]) for _, p in sorted_comment_paths] + + return sorted_comment_data diff --git a/tests/unit/repository/test_base_repo.py b/tests/unit/repository/test_base_repo.py index 2405a0ae..42bc5da0 100644 --- a/tests/unit/repository/test_base_repo.py +++ b/tests/unit/repository/test_base_repo.py @@ -21,12 +21,12 @@ def _create_project(repository): return project -def _create_experiment(repository, project=None, tags=[]): +def _create_experiment(repository, project=None, tags=[], comments=[]): if project is None: project = _create_project(repository) experiment = domain.Experiment( - name=f"Test Experiment {uuid.uuid4()}", project_name=project.name, tags=[] + name=f"Test Experiment {uuid.uuid4()}", project_name=project.name, tags=[], comments=[] ) repository.create_experiment(experiment) @@ -1002,3 +1002,35 @@ def test_add_comments(memory_repository): comments_json = json.load(f) assert ["this is a comment"] == comments_json["added_comments"] + + +def test_get_comments(memory_repository): + repository = memory_repository + experiment = _create_experiment(repository) + repository.add_comments( + experiment.project_name, + ["this is a comment"], + experiment_id=experiment.id, + entity_type=experiment.__class__.__name__, + ) + + comments = repository.get_comments( + experiment.project_name, + experiment_id=experiment.id, + entity_type=experiment.__class__.__name__, + ) + + assert {"added_comments": ["this is a comment"]} in comments + + +def test_get_comments_with_no_results(memory_repository): + repository = memory_repository + experiment = _create_experiment(repository) + + comments = repository.get_comments( + experiment.project_name, + experiment_id=experiment.id, + entity_type=experiment.__class__.__name__, + ) + + assert comments == []