From 5ecbf82ab2d3abcf3f5a246fa227d1520d001f88 Mon Sep 17 00:00:00 2001 From: Ryan Soley Date: Fri, 22 Sep 2023 16:21:33 -0400 Subject: [PATCH] multi-backend errors (#384) * move multi-backend error to base * fix tests --- rubicon_ml/client/artifact.py | 10 +++- rubicon_ml/client/base.py | 8 +++ rubicon_ml/client/experiment.py | 46 +++++++++------ rubicon_ml/client/mixin.py | 10 ++-- rubicon_ml/client/project.py | 13 +++-- rubicon_ml/client/rubicon.py | 20 +++++-- tests/fixtures.py | 40 +++++++++++++ tests/unit/client/test_experiment_client.py | 58 +++++++------------ tests/unit/client/test_mixin_client.py | 46 ++++++--------- tests/unit/client/test_project_client.py | 22 ++++--- tests/unit/client/test_rubicon_client.py | 22 ++++--- .../client/utils/test_exception_handling.py | 2 +- 12 files changed, 172 insertions(+), 125 deletions(-) diff --git a/rubicon_ml/client/artifact.py b/rubicon_ml/client/artifact.py index 5a24bb81..f44ab428 100644 --- a/rubicon_ml/client/artifact.py +++ b/rubicon_ml/client/artifact.py @@ -10,7 +10,6 @@ from rubicon_ml.client.base import Base from rubicon_ml.client.mixin import TagMixin from rubicon_ml.client.utils.exception_handling import failsafe -from rubicon_ml.exceptions import RubiconException if TYPE_CHECKING: from rubicon_ml.client import Project @@ -49,7 +48,9 @@ def _get_data(self): """Loads the data associated with this artifact.""" project_name, experiment_id = self.parent._get_identifiers() return_err = None + self._data = None + for repo in self.repositories or []: try: self._data = repo.get_artifact_data( @@ -59,8 +60,9 @@ def _get_data(self): return_err = err else: return + if self._data is None: - raise RubiconException("all configured storage backends failed") from return_err + self._raise_rubicon_exception(return_err) @failsafe def get_data(self, unpickle: bool = False): @@ -75,6 +77,7 @@ def get_data(self, unpickle: bool = False): """ project_name, experiment_id = self.parent._get_identifiers() return_err = None + for repo in self.repositories or []: try: data = repo.get_artifact_data(project_name, self.id, experiment_id=experiment_id) @@ -84,7 +87,8 @@ def get_data(self, unpickle: bool = False): if unpickle: data = pickle.loads(data) return data - raise RubiconException("all configured storage backends failed") from return_err + + self._raise_rubicon_exception(return_err) @failsafe def download(self, location: Optional[str] = None, name: Optional[str] = None): diff --git a/rubicon_ml/client/base.py b/rubicon_ml/client/base.py index 30be1b76..92cb8000 100644 --- a/rubicon_ml/client/base.py +++ b/rubicon_ml/client/base.py @@ -2,6 +2,8 @@ from typing import TYPE_CHECKING, List, Optional +from rubicon_ml.exceptions import RubiconException + if TYPE_CHECKING: from rubicon_ml.client import Config from rubicon_ml.domain import DOMAIN_TYPES @@ -26,6 +28,12 @@ def __init__(self, domain: DOMAIN_TYPES, config: Optional[Config] = None): def __str__(self) -> str: return self._domain.__str__() + def _raise_rubicon_exception(self, exception: Exception): + if len(self.repositories) > 1: + raise RubiconException("all configured storage backends failed") from exception + else: + raise exception + @property def repository(self) -> Optional[BaseRepository]: return self._config.repository if self._config is not None else None diff --git a/rubicon_ml/client/experiment.py b/rubicon_ml/client/experiment.py index 50803f07..c8e1abb4 100644 --- a/rubicon_ml/client/experiment.py +++ b/rubicon_ml/client/experiment.py @@ -14,7 +14,6 @@ ) from rubicon_ml.client.utils.exception_handling import failsafe from rubicon_ml.client.utils.tags import filter_children -from rubicon_ml.exceptions import RubiconException if TYPE_CHECKING: from rubicon_ml.client import Project @@ -112,6 +111,7 @@ def metrics(self, name=None, tags=[], qtype="or"): The metrics previously logged to this experiment. """ return_err = None + for repo in self.repositories: try: metrics = [Metric(m, self) for m in repo.get_metrics(self.project.name, self.id)] @@ -119,9 +119,10 @@ def metrics(self, name=None, tags=[], qtype="or"): return_err = err else: self._metrics = filter_children(metrics, tags, qtype, name) + return self._metrics - raise RubiconException("all configured storage backends failed") from return_err + self._raise_rubicon_exception(return_err) @failsafe def metric(self, name=None, id=None): @@ -144,18 +145,18 @@ def metric(self, name=None, id=None): if name is not None: return_err = None + for repo in self.repositories: try: metric = repo.get_metric(self.project.name, self.id, name) except Exception as err: return_err = err else: - metric = Metric(metric, self) - return metric - raise RubiconException("all configured storage backends failed") from return_err + return Metric(metric, self) + + self._raise_rubicon_exception(return_err) else: - metric = [m for m in self.metrics() if m.id == id][0] - return metric + return [m for m in self.metrics() if m.id == id][0] @failsafe def log_feature(self, name, description=None, importance=None, tags=[]): @@ -183,6 +184,7 @@ def log_feature(self, name, description=None, importance=None, tags=[]): raise ValueError("`tags` must be `list` of type `str`") feature = domain.Feature(name, description=description, importance=importance, tags=tags) + for repo in self.repositories: repo.create_feature(feature, self.project.name, self.id) @@ -208,6 +210,7 @@ def features(self, name=None, tags=[], qtype="or"): The features previously logged to this experiment. """ return_err = None + for repo in self.repositories: try: features = [Feature(f, self) for f in repo.get_features(self.project.name, self.id)] @@ -215,9 +218,10 @@ def features(self, name=None, tags=[], qtype="or"): return_err = err else: self._features = filter_children(features, tags, qtype, name) + return self._features - raise RubiconException("all configured storage backends failed") from return_err + self._raise_rubicon_exception(return_err) @failsafe def feature(self, name=None, id=None): @@ -237,20 +241,21 @@ def feature(self, name=None, id=None): """ if (name is None and id is None) or (name is not None and id is not None): raise ValueError("`name` OR `id` required.") + if name is not None: return_err = None + for repo in self.repositories: try: feature = repo.get_feature(self.project.name, self.id, name) except Exception as err: return_err = err else: - feature = Feature(feature, self) - return feature - raise RubiconException("all configured storage backends failed") from return_err + return Feature(feature, self) + + self._raise_rubicon_exception(return_err) else: - feature = [f for f in self.features() if f.id == id][0] - return feature + return [f for f in self.features() if f.id == id][0] @failsafe def log_parameter(self, name, value=None, description=None, tags=[]): @@ -280,6 +285,7 @@ def log_parameter(self, name, value=None, description=None, tags=[]): raise ValueError("`tags` must be `list` of type `str`") parameter = domain.Parameter(name, value=value, description=description, tags=tags) + for repo in self.repositories: repo.create_parameter(parameter, self.project.name, self.id) @@ -305,6 +311,7 @@ def parameters(self, name=None, tags=[], qtype="or"): The parameters previously logged to this experiment. """ return_err = None + for repo in self.repositories: try: parameters = [ @@ -314,9 +321,10 @@ def parameters(self, name=None, tags=[], qtype="or"): return_err = err else: self._parameters = filter_children(parameters, tags, qtype, name) + return self._parameters - raise RubiconException("all configured storage backends failed") from return_err + self._raise_rubicon_exception(return_err) @failsafe def parameter(self, name=None, id=None): @@ -339,18 +347,18 @@ def parameter(self, name=None, id=None): if name is not None: return_err = None + for repo in self.repositories: try: parameter = repo.get_parameter(self.project.name, self.id, name) except Exception as err: return_err = err else: - parameter = Parameter(parameter, self) - return parameter - raise RubiconException("all configured storage backends failed") from return_err + return Parameter(parameter, self) + + self._raise_rubicon_exception(return_err) else: - parameter = [p for p in self.parameters() if p.id == id][0] - return parameter + return [p for p in self.parameters() if p.id == id][0] @property def id(self): diff --git a/rubicon_ml/client/mixin.py b/rubicon_ml/client/mixin.py index 37cc5c0e..ec35bbae 100644 --- a/rubicon_ml/client/mixin.py +++ b/rubicon_ml/client/mixin.py @@ -248,7 +248,7 @@ def artifacts( self._artifacts = filter_children(artifacts, tags, qtype, name) return self._artifacts - raise RubiconException("all configured storage backends failed") from return_err + self._raise_rubicon_exception(return_err) @failsafe def artifact(self, name: Optional[str] = None, id: Optional[str] = None) -> Artifact: @@ -294,7 +294,7 @@ def artifact(self, name: Optional[str] = None, id: Optional[str] = None) -> Arti else: return artifact - raise RubiconException("all configured storage backends failed") from return_err + self._raise_rubicon_exception(return_err) @failsafe def delete_artifacts(self, ids: List[str]): @@ -398,7 +398,7 @@ def dataframes( self._dataframes = filter_children(dataframes, tags, qtype, name) return self._dataframes - raise RubiconException("all configured storage backends failed") from return_err + self._raise_rubicon_exception(return_err) @failsafe def dataframe(self, name: Optional[str] = None, id: Optional[str] = None) -> Dataframe: @@ -448,7 +448,7 @@ def dataframe(self, name: Optional[str] = None, id: Optional[str] = None) -> Dat else: return dataframe - raise RubiconException("all configured storage backends failed") from return_err + self._raise_rubicon_exception(return_err) @failsafe def delete_dataframes(self, ids: List[str]): @@ -561,4 +561,4 @@ def tags(self) -> List[str]: return self._domain.tags - raise RubiconException("all configured storage backends failed") from return_err + self._raise_rubicon_exception(return_err) diff --git a/rubicon_ml/client/project.py b/rubicon_ml/client/project.py index dffc5cb2..62543b76 100644 --- a/rubicon_ml/client/project.py +++ b/rubicon_ml/client/project.py @@ -255,6 +255,7 @@ def log_experiment( """ if tags is None: tags = [] + if not isinstance(tags, list) or not all([isinstance(tag, str) for tag in tags]): raise ValueError("`tags` must be `list` of type `str`") @@ -302,10 +303,10 @@ def experiment(self, id: Optional[str] = None, name: Optional[str] = None) -> Ex " Returning most recently logged." ) - experiment = experiments[-1] - return experiment + return experiments[-1] else: return_err = None + for repo in self.repositories: try: experiment = Experiment(repo.get_experiment(self.name, id), self) @@ -313,7 +314,8 @@ def experiment(self, id: Optional[str] = None, name: Optional[str] = None) -> Ex return_err = err else: return experiment - raise RubiconException("all configured storage backends failed") from return_err + + self._raise_rubicon_exception(return_err) @failsafe def experiments( @@ -338,7 +340,9 @@ def experiments( """ if tags is None: tags = [] + return_err = None + for repo in self.repositories: try: experiments = [Experiment(e, self) for e in repo.get_experiments(self.name)] @@ -346,9 +350,10 @@ def experiments( return_err = err else: self._experiments = filter_children(experiments, tags, qtype, name) + return self._experiments - raise RubiconException("all configured storage backends failed") from return_err + self._raise_rubicon_exception(return_err) @failsafe def dataframes( diff --git a/rubicon_ml/client/rubicon.py b/rubicon_ml/client/rubicon.py index a7c02a88..089a57ac 100644 --- a/rubicon_ml/client/rubicon.py +++ b/rubicon_ml/client/rubicon.py @@ -143,18 +143,21 @@ def get_project(self, name=None, id=None): if name is not None: return_err = None + for repo in self.repositories: try: project = repo.get_project(name) except Exception as err: return_err = err else: - project = Project(project, self.config) - return project - raise RubiconException("all configured storage backends failed") from return_err + return Project(project, self.config) + + if len(self.repositories) > 1: + raise RubiconException("all configured storage backends failed") from return_err + else: + raise return_err else: - project = [p for p in self.projects() if p.id == id][0] - return project + return [p for p in self.projects() if p.id == id][0] def get_project_as_dask_df(self, name, group_by=None): """DEPRECATED: Available for backwards compatibility.""" @@ -229,6 +232,7 @@ def projects(self): The list of available projects. """ return_err = None + for repo in self.repositories: try: projects = [Project(project, self.config) for project in repo.get_projects()] @@ -236,7 +240,11 @@ def projects(self): return_err = err else: return projects - raise RubiconException("all configured storage backends failed") from return_err + + if len(self.repositories) > 1: + raise RubiconException("all configured storage backends failed") from return_err + else: + raise return_err @failsafe def sync(self, project_name, s3_root_dir): diff --git a/tests/fixtures.py b/tests/fixtures.py index 2a76b3e2..71579d1b 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -39,9 +39,34 @@ def rubicon_client(): # teardown after yield yield rubicon + rubicon.repository.filesystem.rm(rubicon.config.root_dir, recursive=True) +@pytest.fixture +def rubicon_composite_client(): + """Setup an instance of rubicon configured to log to two memory + backends and clean it up afterwards. + """ + from rubicon_ml import Rubicon + + rubicon = Rubicon( + composite_config=[ + {"persistence": "memory", "root_dir": "a"}, + {"persistence": "memory", "root_dir": "b"}, + ], + ) + + # teardown after yield + yield rubicon + + for i, repository in enumerate(rubicon.repositories): + repository.filesystem.rm( + rubicon.config.storage_options["composite_config"][i]["root_dir"], + recursive=True, + ) + + @pytest.fixture def rubicon_local_filesystem_client(): """Setup an instance of rubicon configured to log to the @@ -84,6 +109,21 @@ def project_client(rubicon_client): return project +@pytest.fixture +def project_composite_client(rubicon_composite_client): + """Setup an instance of rubicon configured to log to two memory + backends with a default project and clean it up afterwards. + """ + rubicon = rubicon_composite_client + + project_name = "Test Project" + project = rubicon.get_or_create_project( + project_name, description="In memory project for testing." + ) + + return project + + @pytest.fixture def rubicon_and_project_client(rubicon_client): """Setup an instance of rubicon configured to log to memory diff --git a/tests/unit/client/test_experiment_client.py b/tests/unit/client/test_experiment_client.py index 97bb7c80..d8102424 100644 --- a/tests/unit/client/test_experiment_client.py +++ b/tests/unit/client/test_experiment_client.py @@ -7,6 +7,10 @@ from rubicon_ml.exceptions import RubiconException +def _raise_error(): + raise RubiconException() + + def test_properties(project_client): project = project_client @@ -71,14 +75,11 @@ def test_get_metrics(project_client): @mock.patch("rubicon_ml.repository.BaseRepository.get_metrics") -def test_get_metrics_multiple_backend_error(mock_get_metrics, project_client): - project = project_client +def test_get_metrics_multiple_backend_error(mock_get_metrics, project_composite_client): + project = project_composite_client experiment = project.log_experiment(name="exp1") - def raise_error(): - raise RubiconException() - - mock_get_metrics.side_effect = raise_error + mock_get_metrics.side_effect = _raise_error with pytest.raises(RubiconException) as e: experiment.metrics() assert "all configured storage backends failed" in str(e) @@ -155,14 +156,11 @@ def test_get_metric_by_id(project_client): @mock.patch("rubicon_ml.repository.BaseRepository.get_metric") -def test_get_metric_multiple_backend_error(mock_get_metric, project_client): - project = project_client +def test_get_metric_multiple_backend_error(mock_get_metric, project_composite_client): + project = project_composite_client experiment = project.log_experiment(name="exp1") - def raise_error(): - raise RubiconException() - - mock_get_metric.side_effect = raise_error + mock_get_metric.side_effect = _raise_error with pytest.raises(RubiconException) as e: experiment.metric("accuracy") assert "all configured storage backends failed" in str(e) @@ -191,14 +189,11 @@ def test_get_features(project_client): @mock.patch("rubicon_ml.repository.BaseRepository.get_features") -def test_get_features_multiple_backend_error(mock_get_features, project_client): - project = project_client +def test_get_features_multiple_backend_error(mock_get_features, project_composite_client): + project = project_composite_client experiment = project.log_experiment(name="exp1") - def raise_error(): - raise RubiconException() - - mock_get_features.side_effect = raise_error + mock_get_features.side_effect = _raise_error with pytest.raises(RubiconException) as e: experiment.features() assert "all configured storage backends failed" in str(e) @@ -246,14 +241,11 @@ def test_get_feature_fails_both_set(project_client): @mock.patch("rubicon_ml.repository.BaseRepository.get_feature") -def test_get_feature_multiple_backend_error(mock_get_feature, project_client): - project = project_client +def test_get_feature_multiple_backend_error(mock_get_feature, project_composite_client): + project = project_composite_client experiment = project.log_experiment(name="exp1") - def raise_error(): - raise RubiconException() - - mock_get_feature.side_effect = raise_error + mock_get_feature.side_effect = _raise_error with pytest.raises(RubiconException) as e: experiment.feature("year") assert "all configured storage backends failed" in str(e) @@ -313,14 +305,11 @@ def test_parameters(project_client): @mock.patch("rubicon_ml.repository.BaseRepository.get_parameters") -def test_parameters_multiple_backend_error(mock_get_parameters, project_client): - project = project_client +def test_parameters_multiple_backend_error(mock_get_parameters, project_composite_client): + project = project_composite_client experiment = project.log_experiment(name="exp1") - def raise_error(): - raise RubiconException() - - mock_get_parameters.side_effect = raise_error + mock_get_parameters.side_effect = _raise_error with pytest.raises(RubiconException) as e: experiment.parameters() assert "all configured storage backends failed" in str(e) @@ -368,14 +357,11 @@ def test_get_parameter_fails_both_set(project_client): @mock.patch("rubicon_ml.repository.BaseRepository.get_parameter") -def test_get_parameter_multiple_backend_error(mock_get_parameter, project_client): - project = project_client +def test_get_parameter_multiple_backend_error(mock_get_parameter, project_composite_client): + project = project_composite_client experiment = project.log_experiment(name="exp1") - def raise_error(): - raise RubiconException() - - mock_get_parameter.side_effect = raise_error + mock_get_parameter.side_effect = _raise_error with pytest.raises(RubiconException) as e: experiment.parameter("n_estimators") assert "all configured storage backends failed" in str(e) diff --git a/tests/unit/client/test_mixin_client.py b/tests/unit/client/test_mixin_client.py index 051b226f..3aa89750 100644 --- a/tests/unit/client/test_mixin_client.py +++ b/tests/unit/client/test_mixin_client.py @@ -9,6 +9,10 @@ from rubicon_ml.exceptions import RubiconException +def _raise_error(): + raise RubiconException() + + # ArtifactMixin def test_log_artifact_from_bytes(project_client): project = project_client @@ -145,13 +149,10 @@ def test_artifacts(project_client): @mock.patch("rubicon_ml.repository.BaseRepository.get_artifacts_metadata") -def test_artifacts_multiple_backend_error(mock_get_artifacts_metadata, project_client): - project = project_client - - def raise_error(): - raise RubiconException() +def test_artifacts_multiple_backend_error(mock_get_artifacts_metadata, project_composite_client): + project = project_composite_client - mock_get_artifacts_metadata.side_effect = raise_error + mock_get_artifacts_metadata.side_effect = _raise_error with pytest.raises(RubiconException) as e: ArtifactMixin.artifacts(project) assert "all configured storage backends failed" in str(e) @@ -248,15 +249,12 @@ def test_artifact_by_id(project_client): @mock.patch("rubicon_ml.repository.BaseRepository.get_artifact_metadata") -def test_artifact_multiple_backend_error(mock_get_artifact_metadata, project_client): - project = project_client +def test_artifact_multiple_backend_error(mock_get_artifact_metadata, project_composite_client): + project = project_composite_client data = b"content" artifact = ArtifactMixin.log_artifact(project, data_bytes=data, name="test.txt") - def raise_error(): - raise RubiconException() - - mock_get_artifact_metadata.side_effect = raise_error + mock_get_artifact_metadata.side_effect = _raise_error with pytest.raises(RubiconException) as e: ArtifactMixin.artifact(project, id=artifact.id) assert "all configured storage backends failed" in str(e) @@ -299,13 +297,10 @@ def test_dataframes(project_client, test_dataframe): @mock.patch("rubicon_ml.repository.BaseRepository.get_dataframes_metadata") -def test_dataframes_multiple_backend_error(mock_get_dataframes_metadata, project_client): - project = project_client +def test_dataframes_multiple_backend_error(mock_get_dataframes_metadata, project_composite_client): + project = project_composite_client - def raise_error(): - raise RubiconException() - - mock_get_dataframes_metadata.side_effect = raise_error + mock_get_dataframes_metadata.side_effect = _raise_error with pytest.raises(RubiconException) as e: DataframeMixin.dataframes(project) assert "all configured storage backends failed" in str(e) @@ -345,15 +340,12 @@ def test_dataframe_by_id(project_client, test_dataframe): @mock.patch("rubicon_ml.repository.BaseRepository.get_dataframe_metadata") def test_dataframe_multiple_backend_error( - mock_get_dataframe_metadata, project_client, test_dataframe + mock_get_dataframe_metadata, project_composite_client, test_dataframe ): - project = project_client + project = project_composite_client dataframe = DataframeMixin.log_dataframe(project, test_dataframe) - def raise_error(): - raise RubiconException() - - mock_get_dataframe_metadata.side_effect = raise_error + mock_get_dataframe_metadata.side_effect = _raise_error with pytest.raises(RubiconException) as e: DataframeMixin.dataframe(project, id=dataframe.id) assert "all configured storage backends failed" in str(e) @@ -520,14 +512,14 @@ def test_remove_tags(project_client): @mock.patch("rubicon_ml.repository.BaseRepository.get_tags") -def test_tags_multiple_backend_error(mock_get_tags, project_client): - project = project_client +def test_tags_multiple_backend_error(mock_get_tags, project_composite_client): + project = project_composite_client experiment = project.log_experiment(tags=["x", "y"]) def raise_error(): raise RubiconException() - mock_get_tags.side_effect = raise_error + mock_get_tags.side_effect = _raise_error with pytest.raises(RubiconException) as e: experiment.tags() assert "all configured storage backends failed" in str(e) diff --git a/tests/unit/client/test_project_client.py b/tests/unit/client/test_project_client.py index 7d6724c9..2b33cdc8 100644 --- a/tests/unit/client/test_project_client.py +++ b/tests/unit/client/test_project_client.py @@ -12,6 +12,10 @@ from rubicon_ml.repository.utils import slugify +def _raise_error(): + raise RubiconException() + + class MockCompletedProcess: def __init__(self, stdout="", returncode=0): self.stdout = stdout @@ -105,13 +109,10 @@ def test_experiments_log_and_retrieval(project_client): @mock.patch("rubicon_ml.repository.BaseRepository.get_experiments") -def test_get_experiments_multiple_backend_error(mock_get_experiments, project_client): - project = project_client +def test_get_experiments_multiple_backend_error(mock_get_experiments, project_composite_client): + project = project_composite_client - def raise_error(): - raise RubiconException() - - mock_get_experiments.side_effect = raise_error + mock_get_experiments.side_effect = _raise_error with pytest.raises(RubiconException) as e: project.experiments() assert "all configured storage backends failed" in str(e) @@ -154,13 +155,10 @@ def test_get_experiment_fails_neither_set(project_client): @mock.patch("rubicon_ml.repository.BaseRepository.get_experiment") -def test_get_experiment_multiple_backend_error(mock_get_experiment, project_client): - project = project_client - - def raise_error(): - raise RubiconException() +def test_get_experiment_multiple_backend_error(mock_get_experiment, project_composite_client): + project = project_composite_client - mock_get_experiment.side_effect = raise_error + mock_get_experiment.side_effect = _raise_error with pytest.raises(RubiconException) as e: project.experiment("exp1") assert "all configured storage backends failed" in str(e) diff --git a/tests/unit/client/test_rubicon_client.py b/tests/unit/client/test_rubicon_client.py index 17e53c0d..a5917593 100644 --- a/tests/unit/client/test_rubicon_client.py +++ b/tests/unit/client/test_rubicon_client.py @@ -22,6 +22,10 @@ def rm(self, path, recursive): return TestFilesystem() +def _raise_error(): + raise RubiconException() + + def test_get_repository(rubicon_client): rubicon = rubicon_client assert rubicon.repository == rubicon.config.repository @@ -124,13 +128,10 @@ def test_get_project_fails_neither_set(rubicon_and_project_client): @mock.patch("rubicon_ml.repository.BaseRepository.get_project") -def test_get_project_multiple_backend_error(mock_get_project, rubicon_client): - rubicon = rubicon_client +def test_get_project_multiple_backend_error(mock_get_project, rubicon_composite_client): + rubicon = rubicon_composite_client - def raise_error(): - raise RubiconException() - - mock_get_project.side_effect = raise_error + mock_get_project.side_effect = _raise_error with pytest.raises(RubiconException) as e: rubicon.get_project(name="Test Project") assert "all configured storage backends failed" in str(e) @@ -149,13 +150,10 @@ def test_get_projects(rubicon_client): @mock.patch("rubicon_ml.repository.BaseRepository.get_projects") -def test_get_projects_multiple_backend_error(mock_get_projects, rubicon_client): - rubicon = rubicon_client - - def raise_error(): - raise RubiconException() +def test_get_projects_multiple_backend_error(mock_get_projects, rubicon_composite_client): + rubicon = rubicon_composite_client - mock_get_projects.side_effect = raise_error + mock_get_projects.side_effect = _raise_error with pytest.raises(RubiconException) as e: rubicon.projects() assert "all configured storage backends failed" in str(e) diff --git a/tests/unit/client/utils/test_exception_handling.py b/tests/unit/client/utils/test_exception_handling.py index 12582cd4..8af95961 100644 --- a/tests/unit/client/utils/test_exception_handling.py +++ b/tests/unit/client/utils/test_exception_handling.py @@ -46,7 +46,7 @@ def test_failure_mode_raise(rubicon_client): with pytest.raises(RubiconException) as e: rubicon_client.get_project(name="does not exist") - assert "all configured storage backends failed" in repr(e) + assert "No project with name 'does not exist' found." in repr(e) @patch("warnings.warn")