Skip to content

Commit

Permalink
Add new reader and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenpardy committed Sep 26, 2023
1 parent f318608 commit 9eee769
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 3 deletions.
6 changes: 6 additions & 0 deletions rubicon_ml/client/artifact.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
import os
import pickle
import warnings
Expand Down Expand Up @@ -86,6 +87,11 @@ def get_data(self, unpickle: bool = False):
return data
raise RubiconException("all configured storage backends failed") from return_err

@failsafe
def get_json(self):
data = self.get_data()
return json.loads(data)

@failsafe
def download(self, location: Optional[str] = None, name: Optional[str] = None):
"""Download this artifact's data.
Expand Down
20 changes: 17 additions & 3 deletions tests/integration/test_rubicon.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ def test_rubicon(rubicon, request):
df=pd.DataFrame([[0, 1], [1, 0]], columns=["a", "b"])
)

json_dict = {"hello": "world", "numbers": [1, 2, 3]}

written_project_json = written_project.log_json(
name=f"Test JSON {uuid.uuid4()}.json", json_object=json_dict
)
written_experiment_json = written_experiment.log_json(
name=f"Test JSON {uuid.uuid4()}.json", json_object=json_dict
)

written_project_dataframe.add_tags(["x", "y"])
written_project_dataframe.remove_tags(["x"])

Expand Down Expand Up @@ -89,17 +98,22 @@ def test_rubicon(rubicon, request):
assert written_metric.value == read_metrics[0].value

read_project_artifacts = read_project.artifacts()
assert len(read_project_artifacts) == 1
assert len(read_project_artifacts) == 2
assert written_project_artifact.id == read_project_artifacts[0].id
assert written_project_artifact.data == read_project_artifacts[0].data
assert written_project_json.id == read_project_artifacts[1].id
assert written_project_json.data == read_project_artifacts[1].data

read_project.delete_artifacts([read_project_artifacts[0].id])
read_project.delete_artifacts([artifact.id for artifact in read_project_artifacts])
assert len(read_project.artifacts()) == 0

read_experiment_artifacts = read_experiment.artifacts()
assert len(read_experiment_artifacts) == 1
assert len(read_experiment_artifacts) == 2
assert written_experiment_artifact.id == read_experiment_artifacts[0].id
assert written_experiment_artifact.data == read_experiment_artifacts[0].data
assert written_experiment_json.id == read_experiment_artifacts[1].id
assert written_experiment_json.data == read_experiment_artifacts[1].data
assert json_dict == read_experiment_artifacts[1].get_json()

read_project_dataframes = read_project.dataframes()
assert len(read_project_dataframes) == 1
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/client/test_artifact_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ def test_get_data(project_client):
assert artifact.data == data


def test_get_json(project_client):
project = project_client
data = {"hello": "world", "numbers": [1, 2, 3]}
artifact = project.log_json(json_object=data, name="test.json")

assert artifact.get_json() == data


def test_internal_get_data_multiple_backend_error():
rb = Rubicon(
composite_config=[
Expand Down

0 comments on commit 9eee769

Please sign in to comment.