Skip to content

Commit

Permalink
Merge branch 'main' into json_helper
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenpardy authored Sep 26, 2023
2 parents 9eee769 + 5ecbf82 commit 5cb71ab
Show file tree
Hide file tree
Showing 12 changed files with 172 additions and 125 deletions.
10 changes: 7 additions & 3 deletions rubicon_ml/client/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from rubicon_ml.client.base import Base
from rubicon_ml.client.mixin import TagMixin
from rubicon_ml.client.utils.exception_handling import failsafe
from rubicon_ml.exceptions import RubiconException

if TYPE_CHECKING:
from rubicon_ml.client import Project
Expand Down Expand Up @@ -50,7 +49,9 @@ def _get_data(self):
"""Loads the data associated with this artifact."""
project_name, experiment_id = self.parent._get_identifiers()
return_err = None

self._data = None

for repo in self.repositories or []:
try:
self._data = repo.get_artifact_data(
Expand All @@ -60,8 +61,9 @@ def _get_data(self):
return_err = err
else:
return

if self._data is None:
raise RubiconException("all configured storage backends failed") from return_err
self._raise_rubicon_exception(return_err)

@failsafe
def get_data(self, unpickle: bool = False):
Expand All @@ -76,6 +78,7 @@ def get_data(self, unpickle: bool = False):
"""
project_name, experiment_id = self.parent._get_identifiers()
return_err = None

for repo in self.repositories or []:
try:
data = repo.get_artifact_data(project_name, self.id, experiment_id=experiment_id)
Expand All @@ -85,7 +88,8 @@ def get_data(self, unpickle: bool = False):
if unpickle:
data = pickle.loads(data)
return data
raise RubiconException("all configured storage backends failed") from return_err

self._raise_rubicon_exception(return_err)

@failsafe
def get_json(self):
Expand Down
8 changes: 8 additions & 0 deletions rubicon_ml/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from typing import TYPE_CHECKING, List, Optional

from rubicon_ml.exceptions import RubiconException

if TYPE_CHECKING:
from rubicon_ml.client import Config
from rubicon_ml.domain import DOMAIN_TYPES
Expand All @@ -26,6 +28,12 @@ def __init__(self, domain: DOMAIN_TYPES, config: Optional[Config] = None):
def __str__(self) -> str:
return self._domain.__str__()

def _raise_rubicon_exception(self, exception: Exception):
if len(self.repositories) > 1:
raise RubiconException("all configured storage backends failed") from exception
else:
raise exception

@property
def repository(self) -> Optional[BaseRepository]:
return self._config.repository if self._config is not None else None
Expand Down
46 changes: 27 additions & 19 deletions rubicon_ml/client/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
)
from rubicon_ml.client.utils.exception_handling import failsafe
from rubicon_ml.client.utils.tags import filter_children
from rubicon_ml.exceptions import RubiconException

if TYPE_CHECKING:
from rubicon_ml.client import Project
Expand Down Expand Up @@ -112,16 +111,18 @@ def metrics(self, name=None, tags=[], qtype="or"):
The metrics previously logged to this experiment.
"""
return_err = None

for repo in self.repositories:
try:
metrics = [Metric(m, self) for m in repo.get_metrics(self.project.name, self.id)]
except Exception as err:
return_err = err
else:
self._metrics = filter_children(metrics, tags, qtype, name)

return self._metrics

raise RubiconException("all configured storage backends failed") from return_err
self._raise_rubicon_exception(return_err)

@failsafe
def metric(self, name=None, id=None):
Expand All @@ -144,18 +145,18 @@ def metric(self, name=None, id=None):

if name is not None:
return_err = None

for repo in self.repositories:
try:
metric = repo.get_metric(self.project.name, self.id, name)
except Exception as err:
return_err = err
else:
metric = Metric(metric, self)
return metric
raise RubiconException("all configured storage backends failed") from return_err
return Metric(metric, self)

self._raise_rubicon_exception(return_err)
else:
metric = [m for m in self.metrics() if m.id == id][0]
return metric
return [m for m in self.metrics() if m.id == id][0]

@failsafe
def log_feature(self, name, description=None, importance=None, tags=[]):
Expand Down Expand Up @@ -183,6 +184,7 @@ def log_feature(self, name, description=None, importance=None, tags=[]):
raise ValueError("`tags` must be `list` of type `str`")

feature = domain.Feature(name, description=description, importance=importance, tags=tags)

for repo in self.repositories:
repo.create_feature(feature, self.project.name, self.id)

Expand All @@ -208,16 +210,18 @@ def features(self, name=None, tags=[], qtype="or"):
The features previously logged to this experiment.
"""
return_err = None

for repo in self.repositories:
try:
features = [Feature(f, self) for f in repo.get_features(self.project.name, self.id)]
except Exception as err:
return_err = err
else:
self._features = filter_children(features, tags, qtype, name)

return self._features

raise RubiconException("all configured storage backends failed") from return_err
self._raise_rubicon_exception(return_err)

@failsafe
def feature(self, name=None, id=None):
Expand All @@ -237,20 +241,21 @@ def feature(self, name=None, id=None):
"""
if (name is None and id is None) or (name is not None and id is not None):
raise ValueError("`name` OR `id` required.")

if name is not None:
return_err = None

for repo in self.repositories:
try:
feature = repo.get_feature(self.project.name, self.id, name)
except Exception as err:
return_err = err
else:
feature = Feature(feature, self)
return feature
raise RubiconException("all configured storage backends failed") from return_err
return Feature(feature, self)

self._raise_rubicon_exception(return_err)
else:
feature = [f for f in self.features() if f.id == id][0]
return feature
return [f for f in self.features() if f.id == id][0]

@failsafe
def log_parameter(self, name, value=None, description=None, tags=[]):
Expand Down Expand Up @@ -280,6 +285,7 @@ def log_parameter(self, name, value=None, description=None, tags=[]):
raise ValueError("`tags` must be `list` of type `str`")

parameter = domain.Parameter(name, value=value, description=description, tags=tags)

for repo in self.repositories:
repo.create_parameter(parameter, self.project.name, self.id)

Expand All @@ -305,6 +311,7 @@ def parameters(self, name=None, tags=[], qtype="or"):
The parameters previously logged to this experiment.
"""
return_err = None

for repo in self.repositories:
try:
parameters = [
Expand All @@ -314,9 +321,10 @@ def parameters(self, name=None, tags=[], qtype="or"):
return_err = err
else:
self._parameters = filter_children(parameters, tags, qtype, name)

return self._parameters

raise RubiconException("all configured storage backends failed") from return_err
self._raise_rubicon_exception(return_err)

@failsafe
def parameter(self, name=None, id=None):
Expand All @@ -339,18 +347,18 @@ def parameter(self, name=None, id=None):

if name is not None:
return_err = None

for repo in self.repositories:
try:
parameter = repo.get_parameter(self.project.name, self.id, name)
except Exception as err:
return_err = err
else:
parameter = Parameter(parameter, self)
return parameter
raise RubiconException("all configured storage backends failed") from return_err
return Parameter(parameter, self)

self._raise_rubicon_exception(return_err)
else:
parameter = [p for p in self.parameters() if p.id == id][0]
return parameter
return [p for p in self.parameters() if p.id == id][0]

@property
def id(self):
Expand Down
10 changes: 5 additions & 5 deletions rubicon_ml/client/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def artifacts(
self._artifacts = filter_children(artifacts, tags, qtype, name)
return self._artifacts

raise RubiconException("all configured storage backends failed") from return_err
self._raise_rubicon_exception(return_err)

@failsafe
def artifact(self, name: Optional[str] = None, id: Optional[str] = None) -> Artifact:
Expand Down Expand Up @@ -295,7 +295,7 @@ def artifact(self, name: Optional[str] = None, id: Optional[str] = None) -> Arti
else:
return artifact

raise RubiconException("all configured storage backends failed") from return_err
self._raise_rubicon_exception(return_err)

@failsafe
def delete_artifacts(self, ids: List[str]):
Expand Down Expand Up @@ -440,7 +440,7 @@ def dataframes(
self._dataframes = filter_children(dataframes, tags, qtype, name)
return self._dataframes

raise RubiconException("all configured storage backends failed") from return_err
self._raise_rubicon_exception(return_err)

@failsafe
def dataframe(self, name: Optional[str] = None, id: Optional[str] = None) -> Dataframe:
Expand Down Expand Up @@ -490,7 +490,7 @@ def dataframe(self, name: Optional[str] = None, id: Optional[str] = None) -> Dat
else:
return dataframe

raise RubiconException("all configured storage backends failed") from return_err
self._raise_rubicon_exception(return_err)

@failsafe
def delete_dataframes(self, ids: List[str]):
Expand Down Expand Up @@ -603,4 +603,4 @@ def tags(self) -> List[str]:

return self._domain.tags

raise RubiconException("all configured storage backends failed") from return_err
self._raise_rubicon_exception(return_err)
13 changes: 9 additions & 4 deletions rubicon_ml/client/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ def log_experiment(
"""
if tags is None:
tags = []

if not isinstance(tags, list) or not all([isinstance(tag, str) for tag in tags]):
raise ValueError("`tags` must be `list` of type `str`")

Expand Down Expand Up @@ -302,18 +303,19 @@ def experiment(self, id: Optional[str] = None, name: Optional[str] = None) -> Ex
" Returning most recently logged."
)

experiment = experiments[-1]
return experiment
return experiments[-1]
else:
return_err = None

for repo in self.repositories:
try:
experiment = Experiment(repo.get_experiment(self.name, id), self)
except Exception as err:
return_err = err
else:
return experiment
raise RubiconException("all configured storage backends failed") from return_err

self._raise_rubicon_exception(return_err)

@failsafe
def experiments(
Expand All @@ -338,17 +340,20 @@ def experiments(
"""
if tags is None:
tags = []

return_err = None

for repo in self.repositories:
try:
experiments = [Experiment(e, self) for e in repo.get_experiments(self.name)]
except Exception as err:
return_err = err
else:
self._experiments = filter_children(experiments, tags, qtype, name)

return self._experiments

raise RubiconException("all configured storage backends failed") from return_err
self._raise_rubicon_exception(return_err)

@failsafe
def dataframes(
Expand Down
20 changes: 14 additions & 6 deletions rubicon_ml/client/rubicon.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,18 +143,21 @@ def get_project(self, name=None, id=None):

if name is not None:
return_err = None

for repo in self.repositories:
try:
project = repo.get_project(name)
except Exception as err:
return_err = err
else:
project = Project(project, self.config)
return project
raise RubiconException("all configured storage backends failed") from return_err
return Project(project, self.config)

if len(self.repositories) > 1:
raise RubiconException("all configured storage backends failed") from return_err
else:
raise return_err
else:
project = [p for p in self.projects() if p.id == id][0]
return project
return [p for p in self.projects() if p.id == id][0]

def get_project_as_dask_df(self, name, group_by=None):
"""DEPRECATED: Available for backwards compatibility."""
Expand Down Expand Up @@ -229,14 +232,19 @@ def projects(self):
The list of available projects.
"""
return_err = None

for repo in self.repositories:
try:
projects = [Project(project, self.config) for project in repo.get_projects()]
except Exception as err:
return_err = err
else:
return projects
raise RubiconException("all configured storage backends failed") from return_err

if len(self.repositories) > 1:
raise RubiconException("all configured storage backends failed") from return_err
else:
raise return_err

@failsafe
def sync(self, project_name, s3_root_dir):
Expand Down
Loading

0 comments on commit 5cb71ab

Please sign in to comment.