Skip to content

Commit

Permalink
Merge branch 'main' into edgetest-patch
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanSoley authored Feb 27, 2024
2 parents 9213b5b + c921595 commit a4eb8ea
Show file tree
Hide file tree
Showing 9 changed files with 228 additions and 3 deletions.
7 changes: 6 additions & 1 deletion rubicon_ml/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from rubicon_ml.client.base import Base # noqa F401
from rubicon_ml.client.config import Config

from rubicon_ml.client.mixin import ArtifactMixin, DataframeMixin, TagMixin # noqa F401
from rubicon_ml.client.mixin import ( # noqa F401
ArtifactMixin,
DataframeMixin,
TagMixin,
CommentMixin,
)

from rubicon_ml.client.artifact import Artifact
from rubicon_ml.client.dataframe import Dataframe
Expand Down
3 changes: 2 additions & 1 deletion rubicon_ml/client/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from rubicon_ml.client import (
ArtifactMixin,
Base,
CommentMixin,
DataframeMixin,
Feature,
Metric,
Expand All @@ -21,7 +22,7 @@
from rubicon_ml.domain import Experiment as ExperimentDomain


class Experiment(Base, ArtifactMixin, DataframeMixin, TagMixin):
class Experiment(Base, ArtifactMixin, DataframeMixin, TagMixin, CommentMixin):
"""A client experiment.
An `experiment` represents a model run and is identified by
Expand Down
99 changes: 99 additions & 0 deletions rubicon_ml/client/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,3 +604,102 @@ def tags(self) -> List[str]:
return self._domain.tags

self._raise_rubicon_exception(return_err)


class CommentMixin:
    """Mixin giving a client object the ability to add, remove and read comments.

    Comment changes are applied to the wrapped domain object and then
    written to every repository in `self.repositories`.
    """

    _domain: DOMAIN_TYPES

    def _get_commentable_identifiers(self):
        """Resolve the identifiers used to locate this object's comments.

        Returns
        -------
        tuple
            `(project_name, experiment_id, entity_identifier)`. The first
            two start out as whatever the parent's `_get_identifiers`
            returns; the third depends on the type of this object.
        """
        project_name, experiment_id = self._parent._get_identifiers()

        # an experiment is the entity itself: its own `id` replaces the
        # parent's experiment ID and there is no entity identifier
        if isinstance(self, client.Experiment):
            return project_name, self.id, None

        # dataframes and artifacts are identified by their `id`s
        if isinstance(self, (client.Dataframe, client.Artifact)):
            return project_name, experiment_id, self.id

        # everything else is identified by its `name`
        return project_name, experiment_id, self.name

    @failsafe
    def add_comments(self, comments: List[str]):
        """Add comments to this client object.

        The comments are added to the domain object first and then
        written to each configured repository.

        Parameters
        ----------
        comments : list of str
            The comment values to add.

        Raises
        ------
        ValueError
            If `comments` is not a `list` whose items are all `str`.
        """
        is_list_of_str = isinstance(comments, list) and all(
            isinstance(item, str) for item in comments
        )
        if not is_list_of_str:
            raise ValueError("`comments` must be `list` of type `str`")

        project_name, experiment_id, entity_identifier = self._get_commentable_identifiers()
        target = {
            "experiment_id": experiment_id,
            "entity_identifier": entity_identifier,
            "entity_type": self.__class__.__name__,
        }

        self._domain.add_comments(comments)
        for backend in self.repositories:
            backend.add_comments(project_name, comments, **target)

    @failsafe
    def remove_comments(self, comments: List[str]):
        """Remove comments from this client object.

        The comments are removed from the domain object first and the
        removal is then written to each configured repository.

        Parameters
        ----------
        comments : list of str
            The comment values to remove.
        """
        project_name, experiment_id, entity_identifier = self._get_commentable_identifiers()
        target = {
            "experiment_id": experiment_id,
            "entity_identifier": entity_identifier,
            "entity_type": self.__class__.__name__,
        }

        self._domain.remove_comments(comments)
        for backend in self.repositories:
            backend.remove_comments(project_name, comments, **target)

    def _update_comments(self, comment_data):
        """Replay the entries of `comment_data` onto the domain object.

        For every entry, the values under its `added_comments` key are
        added and the values under its `removed_comments` key are
        removed; a missing key is treated as an empty list, so both
        domain calls happen for every entry.
        """
        # NOTE(review): this replay runs on every `comments` access against
        # the same domain object - confirm the domain's add/remove calls are
        # safe to repeat with data that was already applied.
        for entry in comment_data:
            additions = entry.get("added_comments", [])
            removals = entry.get("removed_comments", [])

            self._domain.add_comments(additions)
            self._domain.remove_comments(removals)

    @property
    def comments(self) -> List[str]:
        """Get this client object's comments.

        Repositories are tried in order. The comment data of the first
        one that answers without raising is replayed onto the domain
        object and the domain's comments are returned. If every
        repository raises, the most recent error is handed to
        `_raise_rubicon_exception`.
        """
        project_name, experiment_id, entity_identifier = self._get_commentable_identifiers()
        target = {
            "experiment_id": experiment_id,
            "entity_identifier": entity_identifier,
            "entity_type": self.__class__.__name__,
        }

        last_error = None
        for backend in self.repositories:
            try:
                persisted = backend.get_comments(project_name, **target)
            except Exception as exc:
                # remember the failure and fall through to the next backend
                last_error = exc
                continue

            self._update_comments(persisted)
            return self._domain.comments

        self._raise_rubicon_exception(last_error)
14 changes: 14 additions & 0 deletions rubicon_ml/client/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def _create_experiment_domain(
commit_hash,
training_metadata,
tags,
comments,
):
"""Instantiates and returns an experiment domain object."""
if self.is_auto_git_enabled:
Expand All @@ -94,6 +95,7 @@ def _create_experiment_domain(
commit_hash=commit_hash,
training_metadata=training_metadata,
tags=tags,
comments=comments,
)

def _group_experiments(self, experiments: List[Experiment], group_by: Optional[str] = None):
Expand Down Expand Up @@ -218,6 +220,7 @@ def log_experiment(
commit_hash: Optional[str] = None,
training_metadata: Optional[Union[Tuple, List[Tuple]]] = None,
tags: Optional[List[str]] = None,
comments: Optional[List[str]] = None,
) -> Experiment:
"""Log a new experiment to this project.
Expand Down Expand Up @@ -248,6 +251,8 @@ def log_experiment(
to differentiate between the type of model or classifier
used during the experiment (i.e. `linear regression`
or `random forest`).
comments : list of str, optional
Values to comment the experiment with.
Returns
-------
Expand All @@ -260,6 +265,14 @@ def log_experiment(
if not isinstance(tags, list) or not all([isinstance(tag, str) for tag in tags]):
raise ValueError("`tags` must be `list` of type `str`")

if comments is None:
comments = []

if not isinstance(comments, list) or not all(
[isinstance(comment, str) for comment in comments]
):
raise ValueError("`comments` must be `list` of type `str`")

experiment = self._create_experiment_domain(
name,
description,
Expand All @@ -268,6 +281,7 @@ def log_experiment(
commit_hash,
training_metadata,
tags,
comments,
)
for repo in self.repositories:
repo.create_experiment(experiment)
Expand Down
11 changes: 11 additions & 0 deletions rubicon_ml/domain/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,14 @@ def add_comments(self, comments: List[str]):
A list of string comments to add to the domain model.
"""
self.comments.extend(comments)

def remove_comments(self, comments: List[str]):
    """
    Remove comments from this model.

    Every occurrence of each value in `comments` is dropped. The
    comments that remain keep the order in which they first appeared,
    and repeated values among them are collapsed into one entry.

    Parameters
    ----------
    comments : List[str]
        A list of string comments to remove from this domain model.
    """
    unwanted = set(comments)

    # `dict.fromkeys` de-duplicates like the previous set difference did,
    # but keeps first-seen order so the result is deterministic
    self.comments = list(
        dict.fromkeys(comment for comment in self.comments if comment not in unwanted)
    )
29 changes: 29 additions & 0 deletions rubicon_ml/repository/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1253,6 +1253,35 @@ def add_comments(

self._persist_domain({"added_comments": comments}, comment_metadata_path)

def remove_comments(
    self, project_name, comments, experiment_id=None, entity_identifier=None, entity_type=None
):
    """Record the removal of comments on the configured filesystem.

    A new `comments_<uuid>.json` file holding `comments` under the
    `removed_comments` key is persisted beneath the entity's comment
    metadata root.

    Parameters
    ----------
    project_name : str
        The name of the project the object to remove
        comments from belongs to.
    comments : list of str
        The comment values to remove.
    experiment_id : str, optional
        The ID of the experiment to remove the comments
        `comments` from.
    entity_identifier : str, optional
        The ID or name of the entity to remove the comments
        `comments` from.
    entity_type : str, optional
        The name of the entity's type as returned by
        `entity.__class__.__name__`.
    """
    metadata_root = self._get_comment_metadata_root(
        project_name, experiment_id, entity_identifier, entity_type
    )
    # a fresh UUID per call gives every removal its own file
    removal_id = domain.utils.uuid.uuid4()
    removal_record = {"removed_comments": comments}

    self._persist_domain(removal_record, f"{metadata_root}/comments_{removal_id}.json")

def _sort_comment_paths(self, comment_paths):
"""Sorts the paths in `comment_paths` by when they were
created.
Expand Down
39 changes: 38 additions & 1 deletion tests/unit/client/test_mixin_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@

import pytest

from rubicon_ml.client.mixin import ArtifactMixin, DataframeMixin, TagMixin
from rubicon_ml.client.mixin import (
ArtifactMixin,
CommentMixin,
DataframeMixin,
TagMixin,
)
from rubicon_ml.exceptions import RubiconException


Expand Down Expand Up @@ -537,3 +542,35 @@ def raise_error():
with pytest.raises(RubiconException) as e:
experiment.tags()
assert "all configured storage backends failed" in str(e)


def test_add_comments(project_client):
    """Comments added through the mixin are returned by the `comments` property."""
    experiment = project_client.log_experiment()
    expected_comments = ["this is a comment"]

    CommentMixin.add_comments(experiment, ["this is a comment"])

    assert experiment.comments == expected_comments


def test_remove_comments(project_client):
    """Removing every logged comment leaves the experiment with no comments."""
    experiment = project_client.log_experiment(comments=["comment 1", "comment 2"])
    to_remove = ["comment 1", "comment 2"]

    CommentMixin.remove_comments(experiment, to_remove)

    assert experiment.comments == []


@mock.patch("rubicon_ml.repository.BaseRepository.get_comments")
def test_comments_multiple_backend_error(mock_get_comments, project_composite_client):
    """Reading `comments` raises a `RubiconException` when every backend fails."""
    project = project_composite_client
    experiment = project.log_experiment(comments=["comment 1", "comment 2"])

    # an exception instance as `side_effect` is raised on every mocked call,
    # whatever arguments `get_comments` receives
    mock_get_comments.side_effect = RubiconException()

    with pytest.raises(RubiconException) as e:
        # `comments` is a property: plain attribute access runs the lookup
        _ = experiment.comments

    # checked after the `with` block so the assertion always executes
    assert "all configured storage backends failed" in str(e.value)
7 changes: 7 additions & 0 deletions tests/unit/domain/test_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,10 @@ def test_add_comments():
taggable.add_comments(["x"])

assert taggable.comments == ["x"]


def test_remove_comments():
    """Removing one of two comments leaves only the other."""
    model = Taggable(comments=["comment 1", "comment 2"])
    remaining = ["comment 2"]

    model.remove_comments(["comment 1"])

    assert model.comments == remaining
22 changes: 22 additions & 0 deletions tests/unit/repository/test_base_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,28 @@ def test_add_comments(memory_repository):
assert ["this is a comment"] == comments_json["added_comments"]


def test_remove_comments(memory_repository):
    """`remove_comments` writes one `removed_comments` record for the experiment."""
    experiment = _create_experiment(memory_repository, comments=["this is a comment"])

    memory_repository.remove_comments(
        experiment.project_name,
        ["this is a comment"],
        experiment_id=experiment.id,
        entity_type=experiment.__class__.__name__,
    )

    # comment records live next to the experiment's own metadata
    experiment_root = (
        f"{memory_repository.root_dir}/{slugify(experiment.project_name)}"
        f"/experiments/{experiment.id}"
    )
    comment_paths = memory_repository.filesystem.glob(f"{experiment_root}/comments_*.json")

    assert len(comment_paths) == 1

    with memory_repository.filesystem.open(comment_paths[0]) as comment_file:
        persisted = json.load(comment_file)

    assert ["this is a comment"] == persisted["removed_comments"]


def test_get_comments(memory_repository):
repository = memory_repository
experiment = _create_experiment(repository)
Expand Down

0 comments on commit a4eb8ea

Please sign in to comment.