Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a simple JSON wrapper #385

Merged
merged 4 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions rubicon_ml/client/artifact.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
import os
import pickle
import warnings
Expand Down Expand Up @@ -90,6 +91,11 @@ def get_data(self, unpickle: bool = False):

self._raise_rubicon_exception(return_err)

@failsafe
def get_json(self):
    """Return this artifact's data parsed as JSON.

    The data is fetched with `get_data` and decoded with `json.loads`.
    """
    return json.loads(self.get_data())

@failsafe
def download(self, location: Optional[str] = None, name: Optional[str] = None):
"""Download this artifact's data.
Expand Down
44 changes: 43 additions & 1 deletion rubicon_ml/client/mixin.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import json
import os
import pickle
import subprocess
import warnings
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Optional, TextIO, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TextIO, Union

import fsspec

Expand Down Expand Up @@ -312,6 +313,47 @@ def delete_artifacts(self, ids: List[str]):
for repo in self.repositories:
repo.delete_artifact(project_name, artifact_id, experiment_id=experiment_id)

@failsafe
def log_json(
    self,
    json_object: Dict[str, Any],
    name: Optional[str] = None,
    description: Optional[str] = None,
    tags: Optional[List[str]] = None,
) -> Artifact:
    """Serialize a python dictionary to JSON and log it as an artifact.

    Parameters
    ----------
    json_object : Dict[str, Any]
        The dictionary to serialize with `json.dumps`.
    name : Optional[str], optional
        The name to log the artifact under. When None, a timestamped
        "dictionary-<timestamp>.json" name is generated.
    description : Optional[str], optional
        Forwarded to `log_artifact`, by default None
    tags : Optional[List[str]], optional
        Forwarded to `log_artifact`, by default None

    Returns
    -------
    Artifact
        The artifact returned by `log_artifact`.

    """
    json_name = name
    if json_name is None:
        # No explicit name: fall back to a name derived from the current time.
        json_name = f"dictionary-{datetime.now().strftime('%Y_%m_%d-%I_%M_%S_%p')}.json"

    # `log_artifact` is given raw bytes, so encode the JSON text as UTF-8.
    encoded_json = json.dumps(json_object).encode("utf-8")

    return self.log_artifact(
        data_bytes=encoded_json,
        name=json_name,
        description=description,
        tags=tags,
    )


class DataframeMixin:
"""Adds dataframe support to a client object."""
Expand Down
20 changes: 17 additions & 3 deletions tests/integration/test_rubicon.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ def test_rubicon(rubicon, request):
df=pd.DataFrame([[0, 1], [1, 0]], columns=["a", "b"])
)

json_dict = {"hello": "world", "numbers": [1, 2, 3]}

written_project_json = written_project.log_json(
name=f"Test JSON {uuid.uuid4()}.json", json_object=json_dict
)
written_experiment_json = written_experiment.log_json(
name=f"Test JSON {uuid.uuid4()}.json", json_object=json_dict
)

written_project_dataframe.add_tags(["x", "y"])
written_project_dataframe.remove_tags(["x"])

Expand Down Expand Up @@ -89,17 +98,22 @@ def test_rubicon(rubicon, request):
assert written_metric.value == read_metrics[0].value

read_project_artifacts = read_project.artifacts()
assert len(read_project_artifacts) == 1
assert len(read_project_artifacts) == 2
assert written_project_artifact.id == read_project_artifacts[0].id
assert written_project_artifact.data == read_project_artifacts[0].data
assert written_project_json.id == read_project_artifacts[1].id
assert written_project_json.data == read_project_artifacts[1].data

read_project.delete_artifacts([read_project_artifacts[0].id])
read_project.delete_artifacts([artifact.id for artifact in read_project_artifacts])
assert len(read_project.artifacts()) == 0

read_experiment_artifacts = read_experiment.artifacts()
assert len(read_experiment_artifacts) == 1
assert len(read_experiment_artifacts) == 2
assert written_experiment_artifact.id == read_experiment_artifacts[0].id
assert written_experiment_artifact.data == read_experiment_artifacts[0].data
assert written_experiment_json.id == read_experiment_artifacts[1].id
assert written_experiment_json.data == read_experiment_artifacts[1].data
assert json_dict == read_experiment_artifacts[1].get_json()

read_project_dataframes = read_project.dataframes()
assert len(read_project_dataframes) == 1
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/client/test_artifact_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ def test_get_data(project_client):
assert artifact.data == data


def test_get_json(project_client):
    """Round-trip a dictionary through `log_json` and `get_json`."""
    expected = {"hello": "world", "numbers": [1, 2, 3]}
    artifact = project_client.log_json(json_object=expected, name="test.json")

    assert artifact.get_json() == expected


def test_internal_get_data_multiple_backend_error():
rb = Rubicon(
composite_config=[
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/client/test_mixin_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,20 @@ def test_log_pip_requirements(project_client, mock_completed_process_empty):
assert artifact.data == b"\n"


def test_log_json(project_client):
    """Log one dictionary under two names and check both artifacts are listed."""
    payload = {"hello": "world", "foo": [1, 2, 3]}

    # Same payload, once with a ".json" name and once with a ".txt" name.
    logged = [
        ArtifactMixin.log_json(project_client, payload, name=file_name)
        for file_name in ("test.json", "test.txt")
    ]

    listed = ArtifactMixin.artifacts(project_client)
    listed_ids = [artifact.id for artifact in listed]

    assert len(listed) == 2
    for artifact in logged:
        assert artifact.id in listed_ids


def test_artifacts(project_client):
project = project_client
data = b"content"
Expand Down
Loading