Skip to content

Commit

Permalink
comment retrieval support (#411)
Browse files Browse the repository at this point in the history
  • Loading branch information
sonaalthaker authored Feb 16, 2024
1 parent f907c59 commit a9cf1a1
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 2 deletions.
48 changes: 48 additions & 0 deletions rubicon_ml/repository/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,3 +1252,51 @@ def add_comments(
comment_metadata_path = f"{comment_metadata_root}/comments_{domain.utils.uuid.uuid4()}.json"

self._persist_domain({"added_comments": comments}, comment_metadata_path)

def _sort_comment_paths(self, comment_paths):
"""Sorts the paths in `comment_paths` by when they were
created.
"""
return self._sort_tag_paths(comment_paths)

def get_comments(
self, project_name, experiment_id=None, entity_identifier=None, entity_type=None
):
"""Retrieve comments from the configured filesystem.
Parameters
----------
project_name : str
The name of the project the object to retrieve
comments from belongs to.
experiment_id : str, optional
The ID of the experiment to retrieve comments from.
entity_identifier : str, optional
The ID or name of the entity to apply the comments
`comments` to.
entity_type : str, optional
The name of the entity's type as returned by
`entity_cls.__class__.__name__`.
Returns
-------
dict
A dictionary, `added_comments` where the
value is a list of comment names that have
been added to the specified object.
"""
comment_metadata_root = self._get_comment_metadata_root(
project_name, experiment_id, entity_identifier, entity_type
)
comment_metadata_glob = f"{comment_metadata_root}/comments_*.json"

comment_paths = self._glob(comment_metadata_glob)
if len(comment_paths) == 0:
return []

sorted_comment_paths = self._sort_comment_paths(comment_paths)

comment_data = self._cat([p for _, p in sorted_comment_paths])
sorted_comment_data = [json.loads(comment_data[p]) for _, p in sorted_comment_paths]

return sorted_comment_data
36 changes: 34 additions & 2 deletions tests/unit/repository/test_base_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ def _create_project(repository):
return project


def _create_experiment(repository, project=None, tags=[]):
def _create_experiment(repository, project=None, tags=[], comments=[]):
if project is None:
project = _create_project(repository)

experiment = domain.Experiment(
name=f"Test Experiment {uuid.uuid4()}", project_name=project.name, tags=[]
name=f"Test Experiment {uuid.uuid4()}", project_name=project.name, tags=[], comments=[]
)
repository.create_experiment(experiment)

Expand Down Expand Up @@ -1002,3 +1002,35 @@ def test_add_comments(memory_repository):
comments_json = json.load(f)

assert ["this is a comment"] == comments_json["added_comments"]


def test_get_comments(memory_repository):
repository = memory_repository
experiment = _create_experiment(repository)
repository.add_comments(
experiment.project_name,
["this is a comment"],
experiment_id=experiment.id,
entity_type=experiment.__class__.__name__,
)

comments = repository.get_comments(
experiment.project_name,
experiment_id=experiment.id,
entity_type=experiment.__class__.__name__,
)

assert {"added_comments": ["this is a comment"]} in comments


def test_get_comments_with_no_results(memory_repository):
repository = memory_repository
experiment = _create_experiment(repository)

comments = repository.get_comments(
experiment.project_name,
experiment_id=experiment.id,
entity_type=experiment.__class__.__name__,
)

assert comments == []

0 comments on commit a9cf1a1

Please sign in to comment.