diff --git a/rubicon_ml/client/artifact.py b/rubicon_ml/client/artifact.py index f44ab428..ebd44f4b 100644 --- a/rubicon_ml/client/artifact.py +++ b/rubicon_ml/client/artifact.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import os import pickle import warnings @@ -90,6 +91,11 @@ def get_data(self, unpickle: bool = False): self._raise_rubicon_exception(return_err) + @failsafe + def get_json(self): + data = self.get_data() + return json.loads(data) + @failsafe def download(self, location: Optional[str] = None, name: Optional[str] = None): """Download this artifact's data. diff --git a/rubicon_ml/client/mixin.py b/rubicon_ml/client/mixin.py index ec35bbae..26c80abd 100644 --- a/rubicon_ml/client/mixin.py +++ b/rubicon_ml/client/mixin.py @@ -1,11 +1,12 @@ from __future__ import annotations +import json import os import pickle import subprocess import warnings from datetime import datetime -from typing import TYPE_CHECKING, Any, List, Optional, TextIO, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, TextIO, Union import fsspec @@ -312,6 +313,47 @@ def delete_artifacts(self, ids: List[str]): for repo in self.repositories: repo.delete_artifact(project_name, artifact_id, experiment_id=experiment_id) + @failsafe + def log_json( + self, + json_object: Dict[str, Any], + name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[List[str]] = None, + ) -> Artifact: + """Log a python dictionary to a JSON file. + + Parameters + ---------- + json_object : Dict[str, Any] + A python dictionary capable of being converted to JSON. + name : Optional[str], optional + A name for this JSON file, by default None + description : Optional[str], optional + A description for this file, by default None + tags : Optional[List[str]], optional + Any Rubicon tags, by default None + + Returns + ------- + Artifact + The new artifact. + + """ + if name is None: + json_name = f"dictionary-{datetime.now().strftime('%Y_%m_%d-%I_%M_%S_%p')}.json" + else: + json_name = name + + artifact = self.log_artifact( + data_bytes=bytes(json.dumps(json_object), "utf-8"), + name=json_name, + description=description, + tags=tags, + ) + + return artifact + class DataframeMixin: """Adds dataframe support to a client object.""" diff --git a/tests/integration/test_rubicon.py b/tests/integration/test_rubicon.py index fbe51052..2e80fb79 100644 --- a/tests/integration/test_rubicon.py +++ b/tests/integration/test_rubicon.py @@ -61,6 +61,15 @@ def test_rubicon(rubicon, request): df=pd.DataFrame([[0, 1], [1, 0]], columns=["a", "b"]) ) + json_dict = {"hello": "world", "numbers": [1, 2, 3]} + + written_project_json = written_project.log_json( + name=f"Test JSON {uuid.uuid4()}.json", json_object=json_dict + ) + written_experiment_json = written_experiment.log_json( + name=f"Test JSON {uuid.uuid4()}.json", json_object=json_dict + ) + written_project_dataframe.add_tags(["x", "y"]) written_project_dataframe.remove_tags(["x"]) @@ -89,17 +98,22 @@ def test_rubicon(rubicon, request): assert written_metric.value == read_metrics[0].value read_project_artifacts = read_project.artifacts() - assert len(read_project_artifacts) == 1 + assert len(read_project_artifacts) == 2 assert written_project_artifact.id == read_project_artifacts[0].id assert written_project_artifact.data == read_project_artifacts[0].data + assert written_project_json.id == read_project_artifacts[1].id + assert written_project_json.data == read_project_artifacts[1].data - read_project.delete_artifacts([read_project_artifacts[0].id]) + read_project.delete_artifacts([artifact.id for artifact in read_project_artifacts]) assert len(read_project.artifacts()) == 0 read_experiment_artifacts = read_experiment.artifacts() - assert len(read_experiment_artifacts) == 1 + assert len(read_experiment_artifacts) == 2 assert written_experiment_artifact.id == read_experiment_artifacts[0].id assert written_experiment_artifact.data == read_experiment_artifacts[0].data + assert written_experiment_json.id == read_experiment_artifacts[1].id + assert written_experiment_json.data == read_experiment_artifacts[1].data + assert json_dict == read_experiment_artifacts[1].get_json() read_project_dataframes = read_project.dataframes() assert len(read_project_dataframes) == 1 diff --git a/tests/unit/client/test_artifact_client.py b/tests/unit/client/test_artifact_client.py index 0f8fc132..82c84cf9 100644 --- a/tests/unit/client/test_artifact_client.py +++ b/tests/unit/client/test_artifact_client.py @@ -29,6 +29,14 @@ def test_get_data(project_client): assert artifact.data == data +def test_get_json(project_client): + project = project_client + data = {"hello": "world", "numbers": [1, 2, 3]} + artifact = project.log_json(json_object=data, name="test.json") + + assert artifact.get_json() == data + + def test_internal_get_data_multiple_backend_error(): rb = Rubicon( composite_config=[ diff --git a/tests/unit/client/test_mixin_client.py b/tests/unit/client/test_mixin_client.py index 3aa89750..0ad0e893 100644 --- a/tests/unit/client/test_mixin_client.py +++ b/tests/unit/client/test_mixin_client.py @@ -135,6 +135,20 @@ def test_log_pip_requirements(project_client, mock_completed_process_empty): assert artifact.data == b"\n" +def test_log_json(project_client): + project = project_client + + data = {"hello": "world", "foo": [1, 2, 3]} + artifact_a = ArtifactMixin.log_json(project, data, name="test.json") + artifact_b = ArtifactMixin.log_json(project, data, name="test.txt") + + artifacts = ArtifactMixin.artifacts(project) + + assert len(artifacts) == 2 + assert artifact_a.id in [a.id for a in artifacts] + assert artifact_b.id in [a.id for a in artifacts] + + def test_artifacts(project_client): project = project_client data = b"content"