diff --git a/rubicon_ml/client/artifact.py b/rubicon_ml/client/artifact.py index 5a24bb81..f07bbbcd 100644 --- a/rubicon_ml/client/artifact.py +++ b/rubicon_ml/client/artifact.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import os import pickle import warnings @@ -86,6 +87,11 @@ def get_data(self, unpickle: bool = False): return data raise RubiconException("all configured storage backends failed") from return_err + @failsafe + def get_json(self): + data = self.get_data() + return json.loads(data) + @failsafe def download(self, location: Optional[str] = None, name: Optional[str] = None): """Download this artifact's data. diff --git a/tests/integration/test_rubicon.py b/tests/integration/test_rubicon.py index fbe51052..2e80fb79 100644 --- a/tests/integration/test_rubicon.py +++ b/tests/integration/test_rubicon.py @@ -61,6 +61,15 @@ def test_rubicon(rubicon, request): df=pd.DataFrame([[0, 1], [1, 0]], columns=["a", "b"]) ) + json_dict = {"hello": "world", "numbers": [1, 2, 3]} + + written_project_json = written_project.log_json( + name=f"Test JSON {uuid.uuid4()}.json", json_object=json_dict + ) + written_experiment_json = written_experiment.log_json( + name=f"Test JSON {uuid.uuid4()}.json", json_object=json_dict + ) + written_project_dataframe.add_tags(["x", "y"]) written_project_dataframe.remove_tags(["x"]) @@ -89,17 +98,22 @@ def test_rubicon(rubicon, request): assert written_metric.value == read_metrics[0].value read_project_artifacts = read_project.artifacts() - assert len(read_project_artifacts) == 1 + assert len(read_project_artifacts) == 2 assert written_project_artifact.id == read_project_artifacts[0].id assert written_project_artifact.data == read_project_artifacts[0].data + assert written_project_json.id == read_project_artifacts[1].id + assert written_project_json.data == read_project_artifacts[1].data - read_project.delete_artifacts([read_project_artifacts[0].id]) + read_project.delete_artifacts([artifact.id for artifact in read_project_artifacts]) assert len(read_project.artifacts()) == 0 read_experiment_artifacts = read_experiment.artifacts() - assert len(read_experiment_artifacts) == 1 + assert len(read_experiment_artifacts) == 2 assert written_experiment_artifact.id == read_experiment_artifacts[0].id assert written_experiment_artifact.data == read_experiment_artifacts[0].data + assert written_experiment_json.id == read_experiment_artifacts[1].id + assert written_experiment_json.data == read_experiment_artifacts[1].data + assert json_dict == read_experiment_artifacts[1].get_json() read_project_dataframes = read_project.dataframes() assert len(read_project_dataframes) == 1 diff --git a/tests/unit/client/test_artifact_client.py b/tests/unit/client/test_artifact_client.py index 0f8fc132..82c84cf9 100644 --- a/tests/unit/client/test_artifact_client.py +++ b/tests/unit/client/test_artifact_client.py @@ -29,6 +29,14 @@ def test_get_data(project_client): assert artifact.data == data +def test_get_json(project_client): + project = project_client + data = {"hello": "world", "numbers": [1, 2, 3]} + artifact = project.log_json(json_object=data, name="test.json") + + assert artifact.get_json() == data + + def test_internal_get_data_multiple_backend_error(): rb = Rubicon( composite_config=[