diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d90a63ad..348a6b5c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: exclude: (versioneer.py|_version.py) - repo: https://github.com/timothycrosley/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort diff --git a/rubicon_ml/domain/experiment.py b/rubicon_ml/domain/experiment.py index d5118b37..e0dd4b38 100644 --- a/rubicon_ml/domain/experiment.py +++ b/rubicon_ml/domain/experiment.py @@ -4,12 +4,12 @@ from datetime import datetime from typing import List, Optional -from rubicon_ml.domain.mixin import TagMixin +from rubicon_ml.domain.mixin import CommentMixin, TagMixin from rubicon_ml.domain.utils import TrainingMetadata, uuid @dataclass -class Experiment(TagMixin): +class Experiment(TagMixin, CommentMixin): project_name: str id: str = field(default_factory=uuid.uuid4) @@ -20,4 +20,5 @@ class Experiment(TagMixin): commit_hash: Optional[str] = None training_metadata: Optional[TrainingMetadata] = None tags: List[str] = field(default_factory=list) + comments: List[str] = field(default_factory=list) created_at: datetime = field(default_factory=datetime.utcnow) diff --git a/rubicon_ml/domain/mixin.py b/rubicon_ml/domain/mixin.py index 974f3bdb..6485f50f 100644 --- a/rubicon_ml/domain/mixin.py +++ b/rubicon_ml/domain/mixin.py @@ -25,3 +25,18 @@ def remove_tags(self, tags: List[str]): A list of string tags to remove from this domain model. """ self.tags = list(set(self.tags).difference(set(tags))) + + +class CommentMixin: + """Adds comment support to a domain model.""" + + def add_comments(self, comments: List[str]): + """ + Add new comments to this model. + + Parameters + ---------- + comments : List[str] + A list of string comments to add to the domain model. + """ + self.comments.extend(comments) diff --git a/rubicon_ml/repository/base.py b/rubicon_ml/repository/base.py index 950cefd7..93b28303 100644 --- a/rubicon_ml/repository/base.py +++ b/rubicon_ml/repository/base.py @@ -1212,3 +1212,43 @@ def get_tags(self, project_name, experiment_id=None, entity_identifier=None, ent sorted_tag_data = [json.loads(tag_data[p]) for _, p in sorted_tag_paths] return sorted_tag_data + + # ---------- Comments ---------- + + def _get_comment_metadata_root( + self, project_name, experiment_id=None, entity_identifier=None, entity_type=None + ): + """Returns the directory to write comments to.""" + # comments and tags are currently written to the same root with a different filename + return self._get_tag_metadata_root( + project_name, experiment_id, entity_identifier, entity_type + ) + + def add_comments( + self, project_name, comments, experiment_id=None, entity_identifier=None, entity_type=None + ): + """Persist comments to the configured filesystem. + + Parameters + ---------- + project_name : str + The name of the project the object to comment + belongs to. + comments : list of str + The comment values to persist. + experiment_id : str, optional + The ID of the experiment to apply the comments + `comments` to. + entity_identifier : str, optional + The ID or name of the entity to apply the comments + `comments` to. + entity_type : str, optional + The name of the entity's type as returned by + `entity_cls.__class__.__name__`. + """ + comment_metadata_root = self._get_comment_metadata_root( + project_name, experiment_id, entity_identifier, entity_type + ) + comment_metadata_path = f"{comment_metadata_root}/comments_{domain.utils.uuid.uuid4()}.json" + + self._persist_domain({"added_comments": comments}, comment_metadata_path) diff --git a/tests/unit/domain/test_mixin.py b/tests/unit/domain/test_mixin.py index 3d84792b..3b2ea6f1 100644 --- a/tests/unit/domain/test_mixin.py +++ b/tests/unit/domain/test_mixin.py @@ -1,9 +1,10 @@ -from rubicon_ml.domain.mixin import TagMixin +from rubicon_ml.domain.mixin import CommentMixin, TagMixin -class Taggable(TagMixin): - def __init__(self, tags=[]): +class Taggable(TagMixin, CommentMixin): + def __init__(self, tags=[], comments=[]): self.tags = tags + self.comments = comments def test_add_tags(): @@ -25,3 +26,10 @@ def test_remove_tags(): taggable.remove_tags(["x"]) assert taggable.tags == ["y"] + + +def test_add_comments(): + taggable = Taggable() + taggable.add_comments(["x"]) + + assert taggable.comments == ["x"] diff --git a/tests/unit/repository/test_base_repo.py b/tests/unit/repository/test_base_repo.py index 4ceba57b..2405a0ae 100644 --- a/tests/unit/repository/test_base_repo.py +++ b/tests/unit/repository/test_base_repo.py @@ -980,3 +980,25 @@ def test_get_tags_with_no_results(memory_repository): ) assert tags == [] + + +def test_add_comments(memory_repository): + repository = memory_repository + experiment = _create_experiment(repository) + repository.add_comments( + experiment.project_name, + ["this is a comment"], + experiment_id=experiment.id, + entity_type=experiment.__class__.__name__, + ) + + comments_glob = f"{repository.root_dir}/{slugify(experiment.project_name)}/experiments/{experiment.id}/comments_*.json" + comments_files = repository.filesystem.glob(comments_glob) + + assert len(comments_files) == 1 + + open_file = repository.filesystem.open(comments_files[0]) + with open_file as f: + comments_json = json.load(f) + + assert ["this is a comment"] == comments_json["added_comments"]