From c322c68badbf6f8433b0abeba0618703dbab82d5 Mon Sep 17 00:00:00 2001 From: stephenpardy Date: Wed, 21 Jun 2023 17:08:23 -0500 Subject: [PATCH 01/11] Add tests for ValueError --- tests/unit/client/test_experiment_client.py | 68 +++++++++++++++++++++ tests/unit/client/test_mixin_client.py | 17 ++++++ tests/unit/client/test_project_client.py | 18 ++++++ tests/unit/client/test_rubicon_client.py | 16 +++++ 4 files changed, 119 insertions(+) diff --git a/tests/unit/client/test_experiment_client.py b/tests/unit/client/test_experiment_client.py index 1b533341..6b370347 100644 --- a/tests/unit/client/test_experiment_client.py +++ b/tests/unit/client/test_experiment_client.py @@ -1,3 +1,5 @@ +import pytest + from rubicon_ml import domain from rubicon_ml.client import Experiment @@ -74,6 +76,28 @@ def test_get_metric_by_name(project_client): assert metric == "accuracy" +def test_get_metric_fails_neither_set(project_client): + project = project_client + experiment = project.log_experiment(name="exp1") + experiment.log_metric("accuracy", 100) + + with pytest.raises(ValueError) as e: + experiment.metric(name=None, id=None) + + assert "`name` OR `id` required." in str(e) + + +def test_get_metric_fails_both_set(project_client): + project = project_client + experiment = project.log_experiment(name="exp1") + experiment.log_metric("accuracy", 100) + + with pytest.raises(ValueError) as e: + experiment.metric(name="foo", id=123) + + assert "`name` OR `id` required." in str(e) + + def test_metrics_tagged_and(project_client): project = project_client experiment = project.log_experiment(name="exp1") @@ -154,6 +178,28 @@ def test_get_feature_by_id(project_client): assert feature == "year" +def test_get_feature_fails_neither_set(project_client): + project = project_client + experiment = project.log_experiment(name="exp1") + experiment.log_feature("year") + + with pytest.raises(ValueError) as e: + experiment.feature(name=None, id=None) + + assert "`name` OR `id` required." in str(e) + + +def test_get_feature_fails_both_set(project_client): + project = project_client + experiment = project.log_experiment(name="exp1") + experiment.log_feature("year") + + with pytest.raises(ValueError) as e: + experiment.feature(name="foo", id=123) + + assert "`name` OR `id` required." in str(e) + + def test_features_tagged_and(project_client): project = project_client experiment = project.log_experiment(name="exp1") @@ -226,6 +272,28 @@ def test_get_parameter_by_id(project_client): assert parameter == "n_estimators" +def test_get_parameter_fails_neither_set(project_client): + project = project_client + experiment = project.log_experiment(name="exp1") + experiment.log_parameter("n_estimators", "estimator") + + with pytest.raises(ValueError) as e: + experiment.parameter(name=None, id=None) + + assert "`name` OR `id` required." in str(e) + + +def test_get_parameter_fails_both_set(project_client): + project = project_client + experiment = project.log_experiment(name="exp1") + experiment.log_parameter("n_estimators", "estimator") + + with pytest.raises(ValueError) as e: + experiment.parameter(name="foo", id=123) + + assert "`name` OR `id` required." in str(e) + + def test_parameters_tagged_and(project_client): project = project_client experiment = project.log_experiment(name="exp1") diff --git a/tests/unit/client/test_mixin_client.py b/tests/unit/client/test_mixin_client.py index 66816869..17f4e460 100644 --- a/tests/unit/client/test_mixin_client.py +++ b/tests/unit/client/test_mixin_client.py @@ -332,6 +332,23 @@ def test_dataframes_by_name_not_found(project_client, test_dataframe): assert dataframes == [] +def test_get_dataframe_fails_both_set(project_client, test_dataframe): + project = project_client + with pytest.raises(ValueError) as e: + DataframeMixin.dataframe(project, name="foo", id=123) + + assert "`name` OR `id` required." in str(e.value) + + +def test_get_dataframe_fails_neither_set(project_client, test_dataframe): + project = project_client + + with pytest.raises(ValueError) as e: + DataframeMixin.dataframe(project, name=None, id=None) + + assert "`name` OR `id` required." in str(e.value) + + def test_dataframes_tagged_and(project_client, test_dataframe): project = project_client df = test_dataframe diff --git a/tests/unit/client/test_project_client.py b/tests/unit/client/test_project_client.py index 7ed23051..55c0be3a 100644 --- a/tests/unit/client/test_project_client.py +++ b/tests/unit/client/test_project_client.py @@ -121,6 +121,24 @@ def test_experiment_by_name(project_client): assert experiment.name == "exp1" +def test_get_experiment_fails_both_set(project_client): + project = project_client + project.log_experiment(name="exp1") + with pytest.raises(ValueError) as e: + project.experiment(name="foo", id=123) + + assert "`name` OR `id` required." in str(e.value) + + +def test_get_experiment_fails_neither_set(project_client): + project = project_client + project.log_experiment(name="exp1") + with pytest.raises(ValueError) as e: + project.experiment(name=None, id=None) + + assert "`name` OR `id` required." in str(e.value) + + def test_experiment_warning(project_client, test_dataframe): project = project_client experiment_a = project.log_experiment(name="exp1") diff --git a/tests/unit/client/test_rubicon_client.py b/tests/unit/client/test_rubicon_client.py index 97b4bf13..e0b50129 100644 --- a/tests/unit/client/test_rubicon_client.py +++ b/tests/unit/client/test_rubicon_client.py @@ -107,6 +107,22 @@ def test_get_project_by_id(rubicon_and_project_client): assert project_id == rubicon.get_project(id=project_id).id +def test_get_project_fails_both_set(rubicon_and_project_client): + rubicon, project = rubicon_and_project_client + with pytest.raises(ValueError) as e: + rubicon.get_project(name="foo", id=123) + + assert "`name` OR `id` required." in str(e.value) + + +def test_get_project_fails_neither_set(rubicon_and_project_client): + rubicon, project = rubicon_and_project_client + with pytest.raises(ValueError) as e: + rubicon.get_project(name=None, id=None) + + assert "`name` OR `id` required." in str(e.value) + + def test_get_projects(rubicon_client): rubicon = rubicon_client rubicon.create_project("Project A") From 1c3812788128867f3ec3e6fe1c665f2a99dc1e30 Mon Sep 17 00:00:00 2001 From: stephenpardy Date: Tue, 25 Jul 2023 10:13:46 -0500 Subject: [PATCH 02/11] Start the process of adding type hints --- rubicon_ml/client/artifact.py | 24 +++++++++++++++--------- rubicon_ml/client/base.py | 21 ++++++++++++++++----- rubicon_ml/client/dataframe.py | 18 +++++++++++++++--- rubicon_ml/client/experiment.py | 9 ++++++++- rubicon_ml/client/feature.py | 8 +++++++- rubicon_ml/client/project.py | 12 ++++++++---- rubicon_ml/domain/__init__.py | 3 +++ tests/unit/client/test_rubicon_client.py | 3 --- 8 files changed, 72 insertions(+), 26 deletions(-) diff --git a/rubicon_ml/client/artifact.py b/rubicon_ml/client/artifact.py index 6b2cb44b..56650d54 100644 --- a/rubicon_ml/client/artifact.py +++ b/rubicon_ml/client/artifact.py @@ -1,5 +1,6 @@ import os import pickle +from typing import Optional, TYPE_CHECKING import warnings import fsspec @@ -10,6 +11,11 @@ from rubicon_ml.exceptions import RubiconException +if TYPE_CHECKING: + from rubicon_ml.domain import Artifact as ArtifactDomain + from rubicon_ml.client import Project + + class Artifact(Base, TagMixin): """A client artifact. @@ -32,7 +38,7 @@ class Artifact(Base, TagMixin): logged to. """ - def __init__(self, domain, parent): + def __init__(self, domain: ArtifactDomain, parent: Project): super().__init__(domain, parent._config) self._data = None @@ -42,8 +48,8 @@ def _get_data(self): """Loads the data associated with this artifact.""" project_name, experiment_id = self.parent._get_identifiers() return_err = None - for repo in self.repositories: - self._data = None + self._data = None + for repo in self.repositories or []: try: self._data = repo.get_artifact_data( project_name, self.id, experiment_id=experiment_id @@ -56,7 +62,7 @@ def _get_data(self): raise RubiconException("all configured storage backends failed") from return_err @failsafe - def get_data(self, unpickle=False): + def get_data(self, unpickle: bool = False): """Loads the data associated with this artifact and unpickles if needed. @@ -68,7 +74,7 @@ def get_data(self, unpickle=False): """ project_name, experiment_id = self.parent._get_identifiers() return_err = None - for repo in self.repositories: + for repo in self.repositories or []: try: data = repo.get_artifact_data(project_name, self.id, experiment_id=experiment_id) except Exception as err: @@ -80,7 +86,7 @@ def get_data(self, unpickle=False): raise RubiconException("all configured storage backends failed") from return_err @failsafe - def download(self, location=None, name=None): + def download(self, location: Optional[str] = None, name: Optional[str] = None): """Download this artifact's data. Parameters @@ -104,17 +110,17 @@ def download(self, location=None, name=None): f.write(self.data) @property - def id(self): + def id(self) -> str: """Get the artifact's id.""" return self._domain.id @property - def name(self): + def name(self) -> str: """Get the artifact's name.""" return self._domain.name @property - def description(self): + def description(self) -> str: """Get the artifact's description.""" return self._domain.description diff --git a/rubicon_ml/client/base.py b/rubicon_ml/client/base.py index 736f0851..e6aa8284 100644 --- a/rubicon_ml/client/base.py +++ b/rubicon_ml/client/base.py @@ -1,3 +1,11 @@ +from typing import List, Optional, TYPE_CHECKING + + +if TYPE_CHECKING: + from rubicon_ml.client import Config + from rubicon_ml.domain import DOMAIN_TYPES + + class Base: """The base object for all top-level client objects. @@ -9,19 +17,22 @@ class Base: The config, which injects the repository to use. """ - def __init__(self, domain, config=None): + def __init__(self, domain: DOMAIN_TYPES, config: Optional[Config] = None): self._config = config self._domain = domain - def __str__(self): + def __str__(self) -> str: return self._domain.__str__() @property - def repository(self): - return self._config.repository + def repository(self) -> Optional[str]: + return self._config.repository if self._config is not None else None @property - def repositories(self): + def repositories(self) -> Optional[List[str]]: + if self._config is None: + return None + if hasattr(self._config, "repositories"): return self._config.repositories else: diff --git a/rubicon_ml/client/dataframe.py b/rubicon_ml/client/dataframe.py index d1c8394f..a76da08c 100644 --- a/rubicon_ml/client/dataframe.py +++ b/rubicon_ml/client/dataframe.py @@ -1,8 +1,15 @@ +from typing import Callable, Literal, Optional, TYPE_CHECKING, Union + from rubicon_ml.client import Base, TagMixin from rubicon_ml.client.utils.exception_handling import failsafe from rubicon_ml.exceptions import RubiconException +if TYPE_CHECKING: + from rubicon_ml.domain import Dataframe as DataframeDomain + from rubicon_ml.client import Experiment, Project + + class Dataframe(Base, TagMixin): """A client dataframe. @@ -24,14 +31,14 @@ class Dataframe(Base, TagMixin): logged to. """ - def __init__(self, domain, parent): + def __init__(self, domain: DataframeDomain, parent: Union[Experiment, Project]): super().__init__(domain, parent._config) self._data = None self._parent = parent @failsafe - def get_data(self, df_type="pandas"): + def get_data(self, df_type: Literal["pandas", "dask"] = "pandas"): """Loads the data associated with this Dataframe into a `pandas` or `dask` dataframe. @@ -59,7 +66,12 @@ def get_data(self, df_type="pandas"): raise RubiconException(return_err) @failsafe - def plot(self, df_type="pandas", plotting_func=None, **kwargs): + def plot( + self, + df_type: Literal["pandas", "dask"] = "pandas", + plotting_func: Optional[Callable] = None, + **kwargs + ): """Render the dataframe using `plotly.express`. Parameters diff --git a/rubicon_ml/client/experiment.py b/rubicon_ml/client/experiment.py index b71307d4..b7b23398 100644 --- a/rubicon_ml/client/experiment.py +++ b/rubicon_ml/client/experiment.py @@ -1,3 +1,5 @@ +from typing import Optional, TYPE_CHECKING + from rubicon_ml import domain from rubicon_ml.client import ( ArtifactMixin, @@ -13,6 +15,11 @@ from rubicon_ml.exceptions import RubiconException +if TYPE_CHECKING: + from rubicon_ml.domain import Experiment as ExperimentDomain + from rubicon_ml.client import Project + + class Experiment(Base, ArtifactMixin, DataframeMixin, TagMixin): """A client experiment. @@ -30,7 +37,7 @@ class Experiment(Base, ArtifactMixin, DataframeMixin, TagMixin): The project that the experiment is logged to. """ - def __init__(self, domain, parent): + def __init__(self, domain: ExperimentDomain, parent: Project): super().__init__(domain, parent._config) self._parent = parent diff --git a/rubicon_ml/client/feature.py b/rubicon_ml/client/feature.py index be1991fa..6cd36817 100644 --- a/rubicon_ml/client/feature.py +++ b/rubicon_ml/client/feature.py @@ -1,5 +1,11 @@ +from typing import TYPE_CHECKING + from rubicon_ml.client import Base, TagMixin +if TYPE_CHECKING: + from rubicon_ml.domain import Feature as FeatureDomain + from rubicon_ml.client import Experiment + class Feature(Base, TagMixin): """A client feature. @@ -25,7 +31,7 @@ class Feature(Base, TagMixin): logged to. """ - def __init__(self, domain, parent): + def __init__(self, domain: FeatureDomain, parent: Experiment): super().__init__(domain, parent._config) self._data = None diff --git a/rubicon_ml/client/project.py b/rubicon_ml/client/project.py index 8ed592f9..ce7d8541 100644 --- a/rubicon_ml/client/project.py +++ b/rubicon_ml/client/project.py @@ -1,6 +1,6 @@ import subprocess import warnings -from typing import List, Optional +from typing import List, Optional, TYPE_CHECKING import dask.dataframe as dd import pandas as pd @@ -11,6 +11,10 @@ from rubicon_ml.client.utils.tags import filter_children from rubicon_ml.exceptions import RubiconException +if TYPE_CHECKING: + from rubicon_ml.domain import Project as ProjectDomain + from rubicon_ml.client import Config + class Project(Base, ArtifactMixin, DataframeMixin): """A client project. @@ -26,14 +30,14 @@ class Project(Base, ArtifactMixin, DataframeMixin): The config, which specifies the underlying repository. """ - def __init__(self, domain, config=None): + def __init__(self, domain: ProjectDomain, config: Optional[Config] = None): super().__init__(domain, config) self._artifacts = [] self._dataframes = [] self._experiments = [] - def _get_branch_name(self): + def _get_branch_name(self) -> str: """Returns the name of the active branch of the `git` repo it is called from. """ @@ -42,7 +46,7 @@ def _get_branch_name(self): return completed_process.stdout.decode("utf8").replace("\n", "") - def _get_commit_hash(self): + def _get_commit_hash(self) -> str: """Returns the hash of the last commit to the active branch of the `git` repo it is called from. """ diff --git a/rubicon_ml/domain/__init__.py b/rubicon_ml/domain/__init__.py index ad4002c5..064d09ea 100644 --- a/rubicon_ml/domain/__init__.py +++ b/rubicon_ml/domain/__init__.py @@ -1,4 +1,5 @@ from __future__ import annotations +from typing import Union from rubicon_ml.domain.artifact import Artifact from rubicon_ml.domain.dataframe import Dataframe @@ -8,4 +9,6 @@ from rubicon_ml.domain.parameter import Parameter from rubicon_ml.domain.project import Project +DOMAIN_TYPES = Union[Artifact, Dataframe, Experiment, Feature, Metric, Parameter, Project] + __all__ = ["Artifact", "Dataframe", "Experiment", "Feature", "Metric", "Parameter", "Project"] diff --git a/tests/unit/client/test_rubicon_client.py b/tests/unit/client/test_rubicon_client.py index 858c2558..17e53c0d 100644 --- a/tests/unit/client/test_rubicon_client.py +++ b/tests/unit/client/test_rubicon_client.py @@ -123,8 +123,6 @@ def test_get_project_fails_neither_set(rubicon_and_project_client): assert "`name` OR `id` required." in str(e.value) -<<<<<<< HEAD -======= @mock.patch("rubicon_ml.repository.BaseRepository.get_project") def test_get_project_multiple_backend_error(mock_get_project, rubicon_client): rubicon = rubicon_client @@ -138,7 +136,6 @@ def raise_error(): assert "all configured storage backends failed" in str(e) ->>>>>>> 34d3bcbccd2fc9079f3f5c9dd0171c7cf04a51c3 def test_get_projects(rubicon_client): rubicon = rubicon_client rubicon.create_project("Project A") From 2a9ad46533fbd07eb9d757c41242ba137603fcd7 Mon Sep 17 00:00:00 2001 From: stephenpardy Date: Tue, 25 Jul 2023 12:12:25 -0500 Subject: [PATCH 03/11] More types --- rubicon_ml/client/parameter.py | 13 ++++++++++--- tests/unit/client/test_project_client.py | 3 --- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/rubicon_ml/client/parameter.py b/rubicon_ml/client/parameter.py index ada71e34..e76fae6a 100644 --- a/rubicon_ml/client/parameter.py +++ b/rubicon_ml/client/parameter.py @@ -1,6 +1,13 @@ +from typing import TYPE_CHECKING + from rubicon_ml.client import Base, TagMixin +if TYPE_CHECKING: + from rubicon_ml.domain import Parameter as ParameterDomain + from rubicon_ml.client import Experiment + + class Parameter(Base, TagMixin): """A client parameter. @@ -22,17 +29,17 @@ class Parameter(Base, TagMixin): The experiment that the parameter is logged to. """ - def __init__(self, domain, parent): + def __init__(self, domain: ParameterDomain, parent: Experiment): super().__init__(domain, parent._config) self._parent = parent @property - def id(self): + def id(self) -> str: """Get the parameter's id.""" return self._domain.id @property - def name(self): + def name(self) -> str: """Get the parameter's name.""" return self._domain.name diff --git a/tests/unit/client/test_project_client.py b/tests/unit/client/test_project_client.py index 9dddd4d9..7d6724c9 100644 --- a/tests/unit/client/test_project_client.py +++ b/tests/unit/client/test_project_client.py @@ -153,8 +153,6 @@ def test_get_experiment_fails_neither_set(project_client): assert "`name` OR `id` required." in str(e.value) -<<<<<<< HEAD -======= @mock.patch("rubicon_ml.repository.BaseRepository.get_experiment") def test_get_experiment_multiple_backend_error(mock_get_experiment, project_client): project = project_client @@ -168,7 +166,6 @@ def raise_error(): assert "all configured storage backends failed" in str(e) ->>>>>>> 34d3bcbccd2fc9079f3f5c9dd0171c7cf04a51c3 def test_experiment_warning(project_client, test_dataframe): project = project_client experiment_a = project.log_experiment(name="exp1") From 3cb3b9c4fa5569e6d717b595bbfb285881a2226f Mon Sep 17 00:00:00 2001 From: stephenpardy Date: Sun, 10 Sep 2023 17:52:52 -0500 Subject: [PATCH 04/11] More complete type hinting --- rubicon_ml/client/base.py | 5 +- rubicon_ml/client/config.py | 19 ++++-- rubicon_ml/client/feature.py | 13 ++-- rubicon_ml/client/metric.py | 21 ++++--- rubicon_ml/client/mixin.py | 58 ++++++++++++------ rubicon_ml/client/parameter.py | 15 ++--- rubicon_ml/client/project.py | 61 +++++++++++++------ rubicon_ml/client/rubicon.py | 7 ++- rubicon_ml/client/utils/exception_handling.py | 7 ++- rubicon_ml/client/utils/tags.py | 5 +- rubicon_ml/domain/artifact.py | 6 +- rubicon_ml/domain/dataframe.py | 8 +-- rubicon_ml/domain/experiment.py | 14 ++--- rubicon_ml/domain/feature.py | 6 +- rubicon_ml/domain/metric.py | 4 +- rubicon_ml/domain/mixin.py | 23 ++++++- rubicon_ml/domain/parameter.py | 4 +- rubicon_ml/domain/project.py | 7 ++- rubicon_ml/domain/utils/training_metadata.py | 6 +- rubicon_ml/domain/utils/uuid.py | 2 +- rubicon_ml/repository/base.py | 2 +- 21 files changed, 191 insertions(+), 102 deletions(-) diff --git a/rubicon_ml/client/base.py b/rubicon_ml/client/base.py index e6aa8284..bb735dc9 100644 --- a/rubicon_ml/client/base.py +++ b/rubicon_ml/client/base.py @@ -4,6 +4,7 @@ if TYPE_CHECKING: from rubicon_ml.client import Config from rubicon_ml.domain import DOMAIN_TYPES + from rubicon_ml.repository import BaseRepository class Base: @@ -25,11 +26,11 @@ def __str__(self) -> str: return self._domain.__str__() @property - def repository(self) -> Optional[str]: + def repository(self) -> Optional[BaseRepository]: return self._config.repository if self._config is not None else None @property - def repositories(self) -> Optional[List[str]]: + def repositories(self) -> Optional[List[BaseRepository]]: if self._config is None: return None diff --git a/rubicon_ml/client/config.py b/rubicon_ml/client/config.py index a7b41bcf..8f4913da 100644 --- a/rubicon_ml/client/config.py +++ b/rubicon_ml/client/config.py @@ -1,8 +1,9 @@ import os +from typing import Dict, Optional, Tuple import subprocess from rubicon_ml.exceptions import RubiconException -from rubicon_ml.repository import LocalRepository, MemoryRepository, S3Repository +from rubicon_ml.repository import BaseRepository, LocalRepository, MemoryRepository, S3Repository class Config: @@ -29,14 +30,18 @@ class Config: """ PERSISTENCE_TYPES = ["filesystem", "memory"] - REPOSITORIES = { + REPOSITORIES: Dict[str, type[BaseRepository]] = { "memory-memory": MemoryRepository, "filesystem-local": LocalRepository, "filesystem-s3": S3Repository, } def __init__( - self, persistence=None, root_dir=None, is_auto_git_enabled=False, **storage_options + self, + persistence: Optional[str] = None, + root_dir: Optional[str] = None, + is_auto_git_enabled: bool = False, + **storage_options, ): self.storage_options = storage_options if storage_options is not None and "composite_config" in storage_options: @@ -62,7 +67,9 @@ def _check_is_in_git_repo(self): "Not a `git` repo: Falied to locate the '.git' directory in this or any parent directories." ) - def _load_config(self, persistence, root_dir, is_auto_git_enabled): + def _load_config( + self, persistence: Optional[str], root_dir: Optional[str], is_auto_git_enabled: bool + ) -> Tuple[str, Optional[str], bool]: """Get the configuration values.""" persistence = os.environ.get("PERSISTENCE", persistence) if persistence not in self.PERSISTENCE_TYPES: @@ -77,7 +84,7 @@ def _load_config(self, persistence, root_dir, is_auto_git_enabled): return (persistence, root_dir, is_auto_git_enabled) - def _get_protocol(self): + def _get_protocol(self) -> str: """Get the file protocol of the configured root directory.""" if self.persistence == "memory": return "memory" @@ -89,7 +96,7 @@ def _get_protocol(self): return "custom" # catch-all for external backends - def _get_repository(self): + def _get_repository(self) -> BaseRepository: """Get the repository for the configured persistence type.""" protocol = self._get_protocol() diff --git a/rubicon_ml/client/feature.py b/rubicon_ml/client/feature.py index 6cd36817..49c9a73b 100644 --- a/rubicon_ml/client/feature.py +++ b/rubicon_ml/client/feature.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING +from datetime import datetime +from typing import TYPE_CHECKING, Optional from rubicon_ml.client import Base, TagMixin @@ -38,17 +39,17 @@ def __init__(self, domain: FeatureDomain, parent: Experiment): self._parent = parent @property - def id(self): + def id(self) -> str: """Get the feature's id.""" return self._domain.id @property - def name(self): + def name(self) -> Optional[str]: """Get the feature's name.""" return self._domain.name @property - def description(self): + def description(self) -> Optional[str]: """Get the feature's description.""" return self._domain.description @@ -58,11 +59,11 @@ def importance(self): return self._domain.importance @property - def created_at(self): + def created_at(self) -> datetime: """Get the feature's created_at.""" return self._domain.created_at @property - def parent(self): + def parent(self) -> Experiment: """Get the feature's parent client object.""" return self._parent diff --git a/rubicon_ml/client/metric.py b/rubicon_ml/client/metric.py index 4e61588a..88677ffc 100644 --- a/rubicon_ml/client/metric.py +++ b/rubicon_ml/client/metric.py @@ -1,5 +1,12 @@ +from datetime import datetime + +from typing import Optional, TYPE_CHECKING from rubicon_ml.client import Base, TagMixin +if TYPE_CHECKING: + from rubicon_ml.domain import Metric as MetricDomain + from rubicon_ml.client import Experiment + class Metric(Base, TagMixin): """A client metric. @@ -21,19 +28,19 @@ class Metric(Base, TagMixin): logged to. """ - def __init__(self, domain, parent): + def __init__(self, domain: MetricDomain, parent: Experiment): super().__init__(domain, parent._config) self._data = None self._parent = parent @property - def id(self): + def id(self) -> str: """Get the metric's id.""" return self._domain.id @property - def name(self): + def name(self) -> Optional[str]: """Get the metric's name.""" return self._domain.name @@ -43,21 +50,21 @@ def value(self): return self._domain.value @property - def directionality(self): + def directionality(self) -> str: """Get the metric's directionality.""" return self._domain.directionality @property - def description(self): + def description(self) -> Optional[str]: """Get the metric's description.""" return self._domain.description @property - def created_at(self): + def created_at(self) -> datetime: """Get the metric's created_at.""" return self._domain.created_at @property - def parent(self): + def parent(self) -> Experiment: """Get the metric's parent client object.""" return self._parent diff --git a/rubicon_ml/client/mixin.py b/rubicon_ml/client/mixin.py index f1c792ac..308bfa25 100644 --- a/rubicon_ml/client/mixin.py +++ b/rubicon_ml/client/mixin.py @@ -1,6 +1,7 @@ import os import pickle import subprocess +from typing import Any, Optional, List, Union import warnings from datetime import datetime @@ -11,6 +12,12 @@ from rubicon_ml.client.utils.tags import filter_children from rubicon_ml.exceptions import RubiconException +if TYPE_CHECKING: + import pandas as pd + import dask.dataframe as dd + + from rubicon_ml.client import Artifact, Dataframe + class ArtifactMixin: """Adds artifact support to a client object.""" @@ -47,14 +54,14 @@ def _validate_data(self, data_bytes, data_file, data_object, data_path, name): @failsafe def log_artifact( self, - data_bytes=None, + data_bytes: Optional[bytes] = None, data_file=None, - data_object=None, - data_path=None, - name=None, - description=None, - tags=[], - ): + data_object: Optional[Any] = None, + data_path: Optional[str] = None, + name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[List[str]] = None, + ) -> Artifact: """Log an artifact to this client object. Parameters @@ -108,7 +115,8 @@ def log_artifact( ... data_path="./path/to/artifact.pkl", description="log artifact from file path" ... ) """ - + if tags is None: + tags = [] if not isinstance(tags, list) or not all([isinstance(tag, str) for tag in tags]): raise ValueError("`tags` must be `list` of type `str`") @@ -127,7 +135,7 @@ def log_artifact( return client.Artifact(artifact, self) - def _get_environment_bytes(self, export_cmd): + def _get_environment_bytes(self, export_cmd: List[str]) -> bytes: """Get the working environment as a sequence of bytes. Parameters @@ -148,7 +156,7 @@ def _get_environment_bytes(self, export_cmd): return completed_process.stdout @failsafe - def log_conda_environment(self, artifact_name=None): + def log_conda_environment(self, artifact_name: Optional[str] = None) -> Artifact: """Log the conda environment as an artifact to this client object. Useful for recreating your exact environment at a later date. @@ -175,7 +183,7 @@ def log_conda_environment(self, artifact_name=None): return artifact @failsafe - def log_pip_requirements(self, artifact_name=None): + def log_pip_requirements(self, artifact_name: Optional[str] = None) -> Artifact: """Log the pip requirements as an artifact to this client object. Useful for recreating your exact environment at a later date. @@ -198,7 +206,9 @@ def log_pip_requirements(self, artifact_name=None): return artifact @failsafe - def artifacts(self, name=None, tags=[], qtype="or"): + def artifacts( + self, name: Optional[str] = None, tags: Optional[List[str]] = None, qtype: str = "or" + ) -> List[Artifact]: """Get the artifacts logged to this client object. Parameters @@ -216,6 +226,8 @@ def artifacts(self, name=None, tags=[], qtype="or"): list of rubicon.client.Artifact The artifacts previously logged to this client object. """ + if tags is None: + tags = [] project_name, experiment_id = self._get_identifiers() return_err = None for repo in self.repositories: @@ -233,7 +245,7 @@ def artifacts(self, name=None, tags=[], qtype="or"): raise RubiconException("all configured storage backends failed") from return_err @failsafe - def artifact(self, name=None, id=None): + def artifact(self, name: Optional[str] = None, id: Optional[str] = None) -> Artifact: """Get an artifact logged to this project by id or name. Parameters @@ -279,7 +291,7 @@ def artifact(self, name=None, id=None): raise RubiconException("all configured storage backends failed") from return_err @failsafe - def delete_artifacts(self, ids): + def delete_artifacts(self, ids: List[str]): """Delete the artifacts logged to with client object with ids `ids`. @@ -299,7 +311,9 @@ class DataframeMixin: """Adds dataframe support to a client object.""" @failsafe - def log_dataframe(self, df, description=None, name=None, tags=[]): + def log_dataframe( + self, df: Union[pd.DataFrame, dd.DataFrame], description=None, name=None, tags=[] + ) -> Dataframe: """Log a dataframe to this client object. Parameters @@ -334,7 +348,9 @@ def log_dataframe(self, df, description=None, name=None, tags=[]): return client.Dataframe(dataframe, self) @failsafe - def dataframes(self, name=None, tags=[], qtype="or"): + def dataframes( + self, name: Optional[str] = None, tags: Optional[List[str]] = None, qtype: str = "or" + ) -> List[Dataframe]: """Get the dataframes logged to this client object. Parameters @@ -352,6 +368,8 @@ def dataframes(self, name=None, tags=[], qtype="or"): list of rubicon.client.Dataframe The dataframes previously logged to this client object. """ + if tags is None: + tags = [] project_name, experiment_id = self._get_identifiers() return_err = None for repo in self.repositories: @@ -369,7 +387,7 @@ def dataframes(self, name=None, tags=[], qtype="or"): raise RubiconException("all configured storage backends failed") from return_err @failsafe - def dataframe(self, name=None, id=None): + def dataframe(self, name: Optional[str] = None, id: Optional[str] = None) -> Dataframe: """ Get the dataframe logged to this client object. @@ -419,7 +437,7 @@ def dataframe(self, name=None, id=None): raise RubiconException("all configured storage backends failed") from return_err @failsafe - def delete_dataframes(self, ids): + def delete_dataframes(self, ids: List[str]): """Delete the dataframes with ids `ids` logged to this client object. @@ -455,7 +473,7 @@ def _get_taggable_identifiers(self): return project_name, experiment_id, entity_identifier @failsafe - def add_tags(self, tags): + def add_tags(self, tags: List[str]): """Add tags to this client object. Parameters @@ -479,7 +497,7 @@ def add_tags(self, tags): ) @failsafe - def remove_tags(self, tags): + def remove_tags(self, tags: List[str]): """Remove tags from this client object. Parameters diff --git a/rubicon_ml/client/parameter.py b/rubicon_ml/client/parameter.py index e76fae6a..69d0cbf7 100644 --- a/rubicon_ml/client/parameter.py +++ b/rubicon_ml/client/parameter.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING +from datetime import datetime +from typing import Optional, TYPE_CHECKING, Union from rubicon_ml.client import Base, TagMixin @@ -39,26 +40,26 @@ def id(self) -> str: return self._domain.id @property - def name(self) -> str: + def name(self) -> Optional[str]: """Get the parameter's name.""" return self._domain.name @property - def value(self): + def value(self) -> Optional[Union[object, float]]: """Get the parameter's value.""" - return self._domain.value + return getattr(self._domain, "value", None) @property - def description(self): + def description(self) -> Optional[str]: """Get the parameter's description.""" return self._domain.description @property - def created_at(self): + def created_at(self) -> datetime: """Get the time the parameter was created.""" return self._domain.created_at @property - def parent(self): + def parent(self) -> Experiment: """Get the parameter's parent client object.""" return self._parent diff --git a/rubicon_ml/client/project.py b/rubicon_ml/client/project.py index ce7d8541..a6cd7cdc 100644 --- a/rubicon_ml/client/project.py +++ b/rubicon_ml/client/project.py @@ -1,6 +1,6 @@ import subprocess import warnings -from typing import List, Optional, TYPE_CHECKING +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union import dask.dataframe as dd import pandas as pd @@ -13,7 +13,8 @@ if TYPE_CHECKING: from rubicon_ml.domain import Project as ProjectDomain - from rubicon_ml.client import Config + from rubicon_ml.client import Config, Dataframe + from rubicon_ml import Rubicon class Project(Base, ArtifactMixin, DataframeMixin): @@ -55,7 +56,7 @@ def _get_commit_hash(self) -> str: return completed_process.stdout.decode("utf8").replace("\n", "") - def _get_identifiers(self): + def _get_identifiers(self) -> Tuple[Optional[str], None]: """Get the project's name.""" return self.name, None @@ -90,7 +91,7 @@ def _create_experiment_domain( tags=tags, ) - def _group_experiments(self, experiments, group_by=None): + def _group_experiments(self, experiments: List[Experiment], group_by: Optional[str] = None): """Groups experiments by `group_by`. Valid options include ["commit_hash"]. Returns @@ -115,7 +116,7 @@ def _group_experiments(self, experiments, group_by=None): return grouped_experiments - def to_dask_df(self, group_by=None): + def to_dask_df(self, group_by: Optional[str] = None): """DEPRECATED: Available for backwards compatibility.""" warnings.warn( "`to_dask_df` is deprecated and will be removed in a future release. " @@ -126,7 +127,9 @@ def to_dask_df(self, group_by=None): return self.to_df(df_type="dask", group_by=group_by) @failsafe - def to_df(self, df_type="pandas", group_by=None): + def to_df( + self, df_type: str = "pandas", group_by: Optional[str] = None + ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame], dd.DataFrame, Dict[str, dd.DataFrame]]: """Loads the project's data into dask or pandas dataframe(s) sorted by `created_at`. This includes the experiment details along with parameters and metrics. @@ -142,9 +145,9 @@ def to_df(self, df_type="pandas", group_by=None): Returns ------- - pandas.DataFrame or list of pandas.DataFrame or dask.DataFrame or list of dask.DataFrame + pandas.DataFrame or dict of pandas.DataFrame or dask.DataFrame or dict of dask.DataFrame If `group_by` is `None`, a dask or pandas dataframe holding the project's - data. Otherwise a list of dask or pandas dataframes holding the project's + data. Otherwise a dict of dask or pandas dataframes holding the project's data grouped by `group_by`. """ DEFAULT_COLUMNS = [ @@ -203,14 +206,14 @@ def to_df(self, df_type="pandas", group_by=None): @failsafe def log_experiment( self, - name=None, - description=None, - model_name=None, - branch_name=None, - commit_hash=None, - training_metadata=None, - tags=[], - ): + name: Optional[str] = None, + description: Optional[str] = None, + model_name: Optional[str] = None, + branch_name: Optional[str] = None, + commit_hash: Optional[str] = None, + training_metadata: Optional[Union[Tuple, List[Tuple]]] = None, + tags: Optional[List[str]] = None, + ) -> Experiment: """Log a new experiment to this project. Parameters @@ -246,6 +249,8 @@ def log_experiment( rubicon.client.Experiment The created experiment. """ + if tags is None: + tags = [] if not isinstance(tags, list) or not all([isinstance(tag, str) for tag in tags]): raise ValueError("`tags` must be `list` of type `str`") @@ -264,7 +269,7 @@ def log_experiment( return Experiment(experiment, self) @failsafe - def experiment(self, id=None, name=None): + def experiment(self, id: Optional[str] = None, name: Optional[str] = None) -> Experiment: """Get an experiment logged to this project by id or name. Parameters @@ -307,7 +312,9 @@ def experiment(self, id=None, name=None): raise RubiconException("all configured storage backends failed") from return_err @failsafe - def experiments(self, tags=[], qtype="or", name=None): + def experiments( + self, tags: Optional[List[str]] = None, qtype: str = "or", name: Optional[str] = None + ) -> List[Experiment]: """Get the experiments logged to this project. Parameters @@ -325,6 +332,8 @@ def experiments(self, tags=[], qtype="or", name=None): list of rubicon.client.Experiment The experiments previously logged to this project. """ + if tags is None: + tags = [] return_err = None for repo in self.repositories: try: @@ -338,7 +347,13 @@ def experiments(self, tags=[], qtype="or", name=None): raise RubiconException("all configured storage backends failed") from return_err @failsafe - def dataframes(self, tags=[], qtype="or", recursive=False, name=None): + def dataframes( + self, + tags: Optional[List[str]] = None, + qtype: str = "or", + recursive: bool = False, + name: Optional[str] = None, + ) -> List[Dataframe]: """Get the dataframes logged to this project. Parameters @@ -359,6 +374,8 @@ def dataframes(self, tags=[], qtype="or", recursive=False, name=None): list of rubicon.client.Dataframe The dataframes previously logged to this client object. """ + if tags is None: + tags = [] super().dataframes(tags=tags, qtype=qtype, name=name) if recursive is True: @@ -368,7 +385,11 @@ def dataframes(self, tags=[], qtype="or", recursive=False, name=None): return self._dataframes @failsafe - def archive(self, experiments: Optional[List[Experiment]] = None, remote_rubicon=None): + def archive( + self, + experiments: Optional[List[Experiment]] = None, + remote_rubicon: Optional[Rubicon] = None, + ): """Archive the experiments logged to this project. Parameters diff --git a/rubicon_ml/client/rubicon.py b/rubicon_ml/client/rubicon.py index ea9f038b..a12f0a3a 100644 --- a/rubicon_ml/client/rubicon.py +++ b/rubicon_ml/client/rubicon.py @@ -1,4 +1,5 @@ import subprocess +from typing import Optional import warnings from rubicon_ml import domain @@ -32,7 +33,11 @@ class Rubicon: """ def __init__( - self, persistence="filesystem", root_dir=None, auto_git_enabled=False, **storage_options + self, + persistence: Optional[str] = "filesystem", + root_dir=None, + auto_git_enabled=False, + **storage_options, ): self.config = Config(persistence, root_dir, auto_git_enabled, **storage_options) diff --git a/rubicon_ml/client/utils/exception_handling.py b/rubicon_ml/client/utils/exception_handling.py index e8fc3835..1880836a 100644 --- a/rubicon_ml/client/utils/exception_handling.py +++ b/rubicon_ml/client/utils/exception_handling.py @@ -1,6 +1,7 @@ import functools import logging import traceback +from typing import Callable, Optional import warnings FAILURE_MODE = "raise" @@ -9,7 +10,9 @@ TRACEBACK_LIMIT = None -def set_failure_mode(failure_mode, traceback_chain=False, traceback_limit=None): +def set_failure_mode( + failure_mode: str, traceback_chain: bool = False, traceback_limit: Optional[int] = None +) -> None: """Set the failure mode. Parameters @@ -38,7 +41,7 @@ def set_failure_mode(failure_mode, traceback_chain=False, traceback_limit=None): TRACEBACK_LIMIT = traceback_limit -def failsafe(func): +def failsafe(func: Callable) -> Callable: @functools.wraps(func) def wrapper(*args, **kwargs): try: diff --git a/rubicon_ml/client/utils/tags.py b/rubicon_ml/client/utils/tags.py index a6c62250..bb444bce 100644 --- a/rubicon_ml/client/utils/tags.py +++ b/rubicon_ml/client/utils/tags.py @@ -1,4 +1,7 @@ -def has_tag_requirements(tags, required_tags, qtype): +from typing import List + + +def has_tag_requirements(tags: List[str], required_tags: List[str], qtype: str) -> bool: """Returns True if `tags` meets the requirements based on the values of `required_tags` and `qtype`. False otherwise. """ diff --git a/rubicon_ml/domain/artifact.py b/rubicon_ml/domain/artifact.py index 4a47f59b..14dcf01b 100644 --- a/rubicon_ml/domain/artifact.py +++ b/rubicon_ml/domain/artifact.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from datetime import datetime -from typing import List +from typing import List, Optional from rubicon_ml.domain.mixin import TagMixin from rubicon_ml.domain.utils import uuid @@ -13,8 +13,8 @@ class Artifact(TagMixin): name: str id: str = field(default_factory=uuid.uuid4) - description: str = None + description: Optional[str] = None created_at: datetime = field(default_factory=datetime.utcnow) tags: List[str] = field(default_factory=list) - parent_id: str = None + parent_id: Optional[str] = None diff --git a/rubicon_ml/domain/dataframe.py b/rubicon_ml/domain/dataframe.py index f8b12b51..e2f085b8 100644 --- a/rubicon_ml/domain/dataframe.py +++ b/rubicon_ml/domain/dataframe.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from datetime import datetime -from typing import List +from typing import List, Optional from rubicon_ml.domain.mixin import TagMixin from rubicon_ml.domain.utils import uuid @@ -11,9 +11,9 @@ @dataclass class Dataframe(TagMixin): id: str = field(default_factory=uuid.uuid4) - name: str = None - description: str = None + name: Optional[str] = None + description: Optional[str] = None tags: List[str] = field(default_factory=list) created_at: datetime = field(default_factory=datetime.utcnow) - parent_id: str = None + parent_id: Optional[str] = None diff --git a/rubicon_ml/domain/experiment.py b/rubicon_ml/domain/experiment.py index 7c98da2f..d5118b37 100644 --- a/rubicon_ml/domain/experiment.py +++ b/rubicon_ml/domain/experiment.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from datetime import datetime -from typing import List +from typing import List, Optional from rubicon_ml.domain.mixin import TagMixin from rubicon_ml.domain.utils import TrainingMetadata, uuid @@ -13,11 +13,11 @@ class Experiment(TagMixin): project_name: str id: str = field(default_factory=uuid.uuid4) - name: str = None - description: str = None - model_name: str = None - branch_name: str = None - commit_hash: str = None - training_metadata: TrainingMetadata = None + name: Optional[str] = None + description: Optional[str] = None + model_name: Optional[str] = None + branch_name: Optional[str] = None + commit_hash: Optional[str] = None + training_metadata: Optional[TrainingMetadata] = None tags: List[str] = field(default_factory=list) created_at: datetime = field(default_factory=datetime.utcnow) diff --git a/rubicon_ml/domain/feature.py b/rubicon_ml/domain/feature.py index 9da684dd..fd4215d4 100644 --- a/rubicon_ml/domain/feature.py +++ b/rubicon_ml/domain/feature.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from datetime import datetime -from typing import List +from typing import List, Optional from rubicon_ml.domain.mixin import TagMixin from rubicon_ml.domain.utils import uuid @@ -11,7 +11,7 @@ class Feature(TagMixin): name: str id: str = field(default_factory=uuid.uuid4) - description: str = None - importance: float = None + description: Optional[str] = None + importance: Optional[float] = None tags: List[str] = field(default_factory=list) created_at: datetime = field(default_factory=datetime.utcnow) diff --git a/rubicon_ml/domain/metric.py b/rubicon_ml/domain/metric.py index b58278fa..9efd338d 100644 --- a/rubicon_ml/domain/metric.py +++ b/rubicon_ml/domain/metric.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from datetime import datetime -from typing import List +from typing import List, Optional from rubicon_ml.domain.mixin import TagMixin from rubicon_ml.domain.utils import uuid @@ -14,7 +14,7 @@ class Metric(TagMixin): value: float id: str = field(default_factory=uuid.uuid4) - description: str = None + description: Optional[str] = None directionality: str = "score" created_at: datetime = field(default_factory=datetime.utcnow) tags: List[str] = field(default_factory=list) diff --git a/rubicon_ml/domain/mixin.py b/rubicon_ml/domain/mixin.py index 4dd3f4ab..974f3bdb 100644 --- a/rubicon_ml/domain/mixin.py +++ b/rubicon_ml/domain/mixin.py @@ -1,8 +1,27 @@ +from typing import List + + class TagMixin: """Adds tagging support to a domain model.""" - def add_tags(self, tags): + def add_tags(self, tags: List[str]): + """ + Add new tags to this model. + + Parameters + ---------- + tags : List[str] + A list of string tags to add to the domain model. + """ self.tags = list(set(self.tags).union(set(tags))) - def remove_tags(self, tags): + def remove_tags(self, tags: List[str]): + """ + Remove tags from this model. + + Parameters + ---------- + tags : List[str] + A list of string tags to remove from this domain model. + """ self.tags = list(set(self.tags).difference(set(tags))) diff --git a/rubicon_ml/domain/parameter.py b/rubicon_ml/domain/parameter.py index f8497416..1238c5cc 100644 --- a/rubicon_ml/domain/parameter.py +++ b/rubicon_ml/domain/parameter.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from datetime import datetime -from typing import List +from typing import List, Optional from rubicon_ml.domain.mixin import TagMixin from rubicon_ml.domain.utils import uuid @@ -12,6 +12,6 @@ class Parameter(TagMixin): id: str = field(default_factory=uuid.uuid4) value: object = None - description: str = None + description: Optional[str] = None tags: List[str] = field(default_factory=list) created_at: datetime = field(default_factory=datetime.utcnow) diff --git a/rubicon_ml/domain/project.py b/rubicon_ml/domain/project.py index fe7190b4..ff159590 100644 --- a/rubicon_ml/domain/project.py +++ b/rubicon_ml/domain/project.py @@ -2,6 +2,7 @@ from dataclasses import dataclass, field from datetime import datetime +from typing import Optional from rubicon_ml.domain.utils import TrainingMetadata, uuid @@ -11,7 +12,7 @@ class Project: name: str id: str = field(default_factory=uuid.uuid4) - description: str = None - github_url: str = None - training_metadata: TrainingMetadata = None + description: Optional[str] = None + github_url: Optional[str] = None + training_metadata: Optional[TrainingMetadata] = None created_at: datetime = field(default_factory=datetime.utcnow) diff --git a/rubicon_ml/domain/utils/training_metadata.py b/rubicon_ml/domain/utils/training_metadata.py index b766b474..a636756b 100644 --- a/rubicon_ml/domain/utils/training_metadata.py +++ b/rubicon_ml/domain/utils/training_metadata.py @@ -1,3 +1,5 @@ +from typing import List, Tuple, Union + from rubicon_ml.exceptions import RubiconException @@ -25,7 +27,7 @@ class TrainingMetadata: [('s3', ['bucket/a.csv', 'bucket/b.csv'], 'SELECT * FROM x')] """ - def __init__(self, training_metadata): + def __init__(self, training_metadata: Union[List[Tuple], Tuple]): if not isinstance(training_metadata, list): training_metadata = [training_metadata] @@ -34,5 +36,5 @@ def __init__(self, training_metadata): self.training_metadata = training_metadata - def __repr__(self): + def __repr__(self) -> str: return str(self.training_metadata) diff --git a/rubicon_ml/domain/utils/uuid.py b/rubicon_ml/domain/utils/uuid.py index c57a60bb..bbf1f567 100644 --- a/rubicon_ml/domain/utils/uuid.py +++ b/rubicon_ml/domain/utils/uuid.py @@ -1,7 +1,7 @@ import uuid -def uuid4(): +def uuid4() -> str: """Generate a UUID as a string in a single function. To be used as a default factory within `dataclasses.field`. diff --git a/rubicon_ml/repository/base.py b/rubicon_ml/repository/base.py index 9b33a2c1..950cefd7 100644 --- a/rubicon_ml/repository/base.py +++ b/rubicon_ml/repository/base.py @@ -36,7 +36,7 @@ class BaseRepository: the underlying filesystem class. """ - def __init__(self, root_dir, **storage_options): + def __init__(self, root_dir: str, **storage_options): self.filesystem = fsspec.filesystem(self.PROTOCOL, **storage_options) self.root_dir = root_dir.rstrip("/") From 0a74eef63388d943e4608078ea5776bb255ad60f Mon Sep 17 00:00:00 2001 From: stephenpardy Date: Mon, 11 Sep 2023 09:31:50 -0500 Subject: [PATCH 05/11] Fixing some linting issues --- rubicon_ml/client/artifact.py | 5 ++--- rubicon_ml/client/base.py | 3 +-- rubicon_ml/client/config.py | 9 +++++++-- rubicon_ml/client/dataframe.py | 5 ++--- rubicon_ml/client/experiment.py | 5 ++--- rubicon_ml/client/feature.py | 2 +- rubicon_ml/client/metric.py | 4 ++-- rubicon_ml/client/mixin.py | 6 +++--- rubicon_ml/client/parameter.py | 5 ++--- rubicon_ml/client/project.py | 6 +++--- rubicon_ml/client/rubicon.py | 2 +- rubicon_ml/client/utils/exception_handling.py | 2 +- rubicon_ml/domain/__init__.py | 1 + 13 files changed, 28 insertions(+), 27 deletions(-) diff --git a/rubicon_ml/client/artifact.py b/rubicon_ml/client/artifact.py index 56650d54..0c2abbec 100644 --- a/rubicon_ml/client/artifact.py +++ b/rubicon_ml/client/artifact.py @@ -1,7 +1,7 @@ import os import pickle -from typing import Optional, TYPE_CHECKING import warnings +from typing import TYPE_CHECKING, Optional import fsspec @@ -10,10 +10,9 @@ from rubicon_ml.client.utils.exception_handling import failsafe from rubicon_ml.exceptions import RubiconException - if TYPE_CHECKING: - from rubicon_ml.domain import Artifact as ArtifactDomain from rubicon_ml.client import Project + from rubicon_ml.domain import Artifact as ArtifactDomain class Artifact(Base, TagMixin): diff --git a/rubicon_ml/client/base.py b/rubicon_ml/client/base.py index bb735dc9..d8e9e41a 100644 --- a/rubicon_ml/client/base.py +++ b/rubicon_ml/client/base.py @@ -1,5 +1,4 @@ -from typing import List, Optional, TYPE_CHECKING - +from typing import TYPE_CHECKING, List, Optional if TYPE_CHECKING: from rubicon_ml.client import Config diff --git a/rubicon_ml/client/config.py b/rubicon_ml/client/config.py index 8f4913da..1dd74739 100644 --- a/rubicon_ml/client/config.py +++ b/rubicon_ml/client/config.py @@ -1,9 +1,14 @@ import os -from typing import Dict, Optional, Tuple import subprocess +from typing import Dict, Optional, Tuple from rubicon_ml.exceptions import RubiconException -from rubicon_ml.repository import BaseRepository, LocalRepository, MemoryRepository, S3Repository +from rubicon_ml.repository import ( + BaseRepository, + LocalRepository, + MemoryRepository, + S3Repository, +) class Config: diff --git a/rubicon_ml/client/dataframe.py b/rubicon_ml/client/dataframe.py index a76da08c..29571491 100644 --- a/rubicon_ml/client/dataframe.py +++ b/rubicon_ml/client/dataframe.py @@ -1,13 +1,12 @@ -from typing import Callable, Literal, Optional, TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Callable, Literal, Optional, Union from rubicon_ml.client import Base, TagMixin from rubicon_ml.client.utils.exception_handling import failsafe from rubicon_ml.exceptions import RubiconException - if TYPE_CHECKING: - from rubicon_ml.domain import Dataframe as DataframeDomain from rubicon_ml.client import Experiment, Project + from rubicon_ml.domain import Dataframe as DataframeDomain class Dataframe(Base, TagMixin): diff --git a/rubicon_ml/client/experiment.py b/rubicon_ml/client/experiment.py index b7b23398..aafc278f 100644 --- a/rubicon_ml/client/experiment.py +++ b/rubicon_ml/client/experiment.py @@ -1,4 +1,4 @@ -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from rubicon_ml import domain from rubicon_ml.client import ( @@ -14,10 +14,9 @@ from rubicon_ml.client.utils.tags import filter_children from rubicon_ml.exceptions import RubiconException - if TYPE_CHECKING: - from rubicon_ml.domain import Experiment as ExperimentDomain from rubicon_ml.client import Project + from rubicon_ml.domain import Experiment as ExperimentDomain class Experiment(Base, ArtifactMixin, DataframeMixin, TagMixin): diff --git a/rubicon_ml/client/feature.py b/rubicon_ml/client/feature.py index 49c9a73b..f7fd112a 100644 --- a/rubicon_ml/client/feature.py +++ b/rubicon_ml/client/feature.py @@ -4,8 +4,8 @@ from rubicon_ml.client import Base, TagMixin if TYPE_CHECKING: - from rubicon_ml.domain import Feature as FeatureDomain from rubicon_ml.client import Experiment + from rubicon_ml.domain import Feature as FeatureDomain class Feature(Base, TagMixin): diff --git a/rubicon_ml/client/metric.py b/rubicon_ml/client/metric.py index 88677ffc..a2ba6e84 100644 --- a/rubicon_ml/client/metric.py +++ b/rubicon_ml/client/metric.py @@ -1,11 +1,11 @@ from datetime import datetime +from typing import TYPE_CHECKING, Optional -from typing import Optional, TYPE_CHECKING from rubicon_ml.client import Base, TagMixin if TYPE_CHECKING: - from rubicon_ml.domain import Metric as MetricDomain from rubicon_ml.client import Experiment + from rubicon_ml.domain import Metric as MetricDomain class Metric(Base, TagMixin): diff --git a/rubicon_ml/client/mixin.py b/rubicon_ml/client/mixin.py index 308bfa25..0058dcc7 100644 --- a/rubicon_ml/client/mixin.py +++ b/rubicon_ml/client/mixin.py @@ -1,9 +1,9 @@ import os import pickle import subprocess -from typing import Any, Optional, List, Union import warnings from datetime import datetime +from typing import Any, List, Optional, TYPE_CHECKING, Union import fsspec @@ -13,8 +13,8 @@ from rubicon_ml.exceptions import RubiconException if TYPE_CHECKING: - import pandas as pd import dask.dataframe as dd + import pandas as pd from rubicon_ml.client import Artifact, Dataframe @@ -526,7 +526,7 @@ def _update_tags(self, tag_data): self._domain.remove_tags(tag.get("removed_tags", [])) @property - def tags(self): + def tags(self) -> List[str]: """Get this client object's tags.""" project_name, experiment_id, entity_identifier = self._get_taggable_identifiers() return_err = None diff --git a/rubicon_ml/client/parameter.py b/rubicon_ml/client/parameter.py index 69d0cbf7..7a2f25e8 100644 --- a/rubicon_ml/client/parameter.py +++ b/rubicon_ml/client/parameter.py @@ -1,12 +1,11 @@ from datetime import datetime -from typing import Optional, TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Optional, Union from rubicon_ml.client import Base, TagMixin - if TYPE_CHECKING: - from rubicon_ml.domain import Parameter as ParameterDomain from rubicon_ml.client import Experiment + from rubicon_ml.domain import Parameter as ParameterDomain class Parameter(Base, TagMixin): diff --git a/rubicon_ml/client/project.py b/rubicon_ml/client/project.py index a6cd7cdc..4ff9dcc9 100644 --- a/rubicon_ml/client/project.py +++ b/rubicon_ml/client/project.py @@ -1,6 +1,6 @@ import subprocess import warnings -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import dask.dataframe as dd import pandas as pd @@ -12,9 +12,9 @@ from rubicon_ml.exceptions import RubiconException if TYPE_CHECKING: - from rubicon_ml.domain import Project as ProjectDomain - from rubicon_ml.client import Config, Dataframe from rubicon_ml import Rubicon + from rubicon_ml.client import Config, Dataframe + from rubicon_ml.domain import Project as ProjectDomain class Project(Base, ArtifactMixin, DataframeMixin): diff --git a/rubicon_ml/client/rubicon.py b/rubicon_ml/client/rubicon.py index a12f0a3a..93c4417b 100644 --- a/rubicon_ml/client/rubicon.py +++ b/rubicon_ml/client/rubicon.py @@ -1,6 +1,6 @@ import subprocess -from typing import Optional import warnings +from typing import Optional from rubicon_ml import domain from rubicon_ml.client import Config, Project diff --git a/rubicon_ml/client/utils/exception_handling.py b/rubicon_ml/client/utils/exception_handling.py index 1880836a..36407e79 100644 --- a/rubicon_ml/client/utils/exception_handling.py +++ b/rubicon_ml/client/utils/exception_handling.py @@ -1,8 +1,8 @@ import functools import logging import traceback -from typing import Callable, Optional import warnings +from typing import Callable, Optional FAILURE_MODE = "raise" FAILURE_MODES = ["log", "raise", "warn"] diff --git a/rubicon_ml/domain/__init__.py b/rubicon_ml/domain/__init__.py index 064d09ea..7db2ea93 100644 --- a/rubicon_ml/domain/__init__.py +++ b/rubicon_ml/domain/__init__.py @@ -1,4 +1,5 @@ from __future__ import annotations + from typing import Union from rubicon_ml.domain.artifact import Artifact From 6251bd0d35c640954fc0a8f691ba259109a92b58 Mon Sep 17 00:00:00 2001 From: stephenpardy Date: Mon, 11 Sep 2023 10:27:57 -0500 Subject: [PATCH 06/11] Fix import order and use github --- .pre-commit-config.yaml | 2 +- rubicon_ml/client/mixin.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f6ec9ea2..06905f7c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,7 +10,7 @@ repos: hooks: - id: isort - - repo: https://gitlab.com/pycqa/flake8 + - repo: https://github.com/pycqa/flake8 rev: 3.9.2 hooks: - id: flake8 diff --git a/rubicon_ml/client/mixin.py b/rubicon_ml/client/mixin.py index 0058dcc7..730dd88a 100644 --- a/rubicon_ml/client/mixin.py +++ b/rubicon_ml/client/mixin.py @@ -3,7 +3,7 @@ import subprocess import warnings from datetime import datetime -from typing import Any, List, Optional, TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Any, List, Optional, Union import fsspec From 1a45b9252eb9eeaf58d5325252ba8f8d29a2d6fa Mon Sep 17 00:00:00 2001 From: stephenpardy Date: Mon, 11 Sep 2023 16:45:56 -0500 Subject: [PATCH 07/11] More hints --- rubicon_ml/client/config.py | 2 +- rubicon_ml/client/dataframe.py | 2 ++ rubicon_ml/client/experiment.py | 2 ++ rubicon_ml/client/feature.py | 2 ++ rubicon_ml/client/metric.py | 2 ++ rubicon_ml/client/mixin.py | 7 +++++++ rubicon_ml/client/project.py | 2 ++ 7 files changed, 18 insertions(+), 1 deletion(-) diff --git a/rubicon_ml/client/config.py b/rubicon_ml/client/config.py index 1dd74739..881376ad 100644 --- a/rubicon_ml/client/config.py +++ b/rubicon_ml/client/config.py @@ -50,7 +50,7 @@ def __init__( ): self.storage_options = storage_options if storage_options is not None and "composite_config" in storage_options: - composite_config = storage_options.get("composite_config") + composite_config = storage_options.get("composite_config", []) repositories = [] for config in composite_config: self.persistence, self.root_dir, self.is_auto_git_enabled = self._load_config( diff --git a/rubicon_ml/client/dataframe.py b/rubicon_ml/client/dataframe.py index 29571491..34c10a9f 100644 --- a/rubicon_ml/client/dataframe.py +++ b/rubicon_ml/client/dataframe.py @@ -33,6 +33,8 @@ class Dataframe(Base, TagMixin): def __init__(self, domain: DataframeDomain, parent: Union[Experiment, Project]): super().__init__(domain, parent._config) + self._domain: DataframeDomain + self._data = None self._parent = parent diff --git a/rubicon_ml/client/experiment.py b/rubicon_ml/client/experiment.py index aafc278f..02ed29b9 100644 --- a/rubicon_ml/client/experiment.py +++ b/rubicon_ml/client/experiment.py @@ -39,6 +39,8 @@ class Experiment(Base, ArtifactMixin, DataframeMixin, TagMixin): def __init__(self, domain: ExperimentDomain, parent: Project): super().__init__(domain, parent._config) + self._domain: ExperimentDomain + self._parent = parent self._artifacts = [] self._dataframes = [] diff --git a/rubicon_ml/client/feature.py b/rubicon_ml/client/feature.py index f7fd112a..94100e53 100644 --- a/rubicon_ml/client/feature.py +++ b/rubicon_ml/client/feature.py @@ -35,6 +35,8 @@ class Feature(Base, TagMixin): def __init__(self, domain: FeatureDomain, parent: Experiment): super().__init__(domain, parent._config) + self._domain: FeatureDomain + self._data = None self._parent = parent diff --git a/rubicon_ml/client/metric.py b/rubicon_ml/client/metric.py index a2ba6e84..1a28c4fd 100644 --- a/rubicon_ml/client/metric.py +++ b/rubicon_ml/client/metric.py @@ -31,6 +31,8 @@ class Metric(Base, TagMixin): def __init__(self, domain: MetricDomain, parent: Experiment): super().__init__(domain, parent._config) + self._domain: MetricDomain + self._data = None self._parent = parent diff --git a/rubicon_ml/client/mixin.py b/rubicon_ml/client/mixin.py index 730dd88a..df386c3c 100644 --- a/rubicon_ml/client/mixin.py +++ b/rubicon_ml/client/mixin.py @@ -17,11 +17,14 @@ import pandas as pd from rubicon_ml.client import Artifact, Dataframe + from rubicon_ml.domain import DOMAIN_TYPES, Artifact as ArtifactDomain class ArtifactMixin: """Adds artifact support to a client object.""" + _domain: ArtifactDomain + def _validate_data(self, data_bytes, data_file, data_object, data_path, name): """Raises a `RubiconException` if the data to log as an artifact is improperly provided. @@ -310,6 +313,8 @@ def delete_artifacts(self, ids: List[str]): class DataframeMixin: """Adds dataframe support to a client object.""" + _domain: DOMAIN_TYPES + @failsafe def log_dataframe( self, df: Union[pd.DataFrame, dd.DataFrame], description=None, name=None, tags=[] @@ -456,6 +461,8 @@ def delete_dataframes(self, ids: List[str]): class TagMixin: """Adds tag support to a client object.""" + _domain: DOMAIN_TYPES + def _get_taggable_identifiers(self): project_name, experiment_id = self._parent._get_identifiers() entity_identifier = None diff --git a/rubicon_ml/client/project.py b/rubicon_ml/client/project.py index 4ff9dcc9..3a3ec60b 100644 --- a/rubicon_ml/client/project.py +++ b/rubicon_ml/client/project.py @@ -34,6 +34,8 @@ class Project(Base, ArtifactMixin, DataframeMixin): def __init__(self, domain: ProjectDomain, config: Optional[Config] = None): super().__init__(domain, config) + self._domain: ProjectDomain + self._artifacts = [] self._dataframes = [] self._experiments = [] From bd3eb2f27b00cee14042fa70b784e7449614d129 Mon Sep 17 00:00:00 2001 From: stephenpardy Date: Tue, 12 Sep 2023 09:24:13 -0500 Subject: [PATCH 08/11] Add annotations import --- rubicon_ml/client/artifact.py | 2 ++ rubicon_ml/client/base.py | 2 ++ rubicon_ml/client/dataframe.py | 4 +++- rubicon_ml/client/experiment.py | 2 ++ rubicon_ml/client/feature.py | 2 ++ rubicon_ml/client/metric.py | 2 ++ rubicon_ml/client/parameter.py | 2 ++ rubicon_ml/client/project.py | 2 ++ 8 files changed, 17 insertions(+), 1 deletion(-) diff --git a/rubicon_ml/client/artifact.py b/rubicon_ml/client/artifact.py index 0c2abbec..5a24bb81 100644 --- a/rubicon_ml/client/artifact.py +++ b/rubicon_ml/client/artifact.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import pickle import warnings diff --git a/rubicon_ml/client/base.py b/rubicon_ml/client/base.py index d8e9e41a..30be1b76 100644 --- a/rubicon_ml/client/base.py +++ b/rubicon_ml/client/base.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import TYPE_CHECKING, List, Optional if TYPE_CHECKING: diff --git a/rubicon_ml/client/dataframe.py b/rubicon_ml/client/dataframe.py index 34c10a9f..facef63f 100644 --- a/rubicon_ml/client/dataframe.py +++ b/rubicon_ml/client/dataframe.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import TYPE_CHECKING, Callable, Literal, Optional, Union from rubicon_ml.client import Base, TagMixin @@ -71,7 +73,7 @@ def plot( self, df_type: Literal["pandas", "dask"] = "pandas", plotting_func: Optional[Callable] = None, - **kwargs + **kwargs, ): """Render the dataframe using `plotly.express`. diff --git a/rubicon_ml/client/experiment.py b/rubicon_ml/client/experiment.py index 02ed29b9..50803f07 100644 --- a/rubicon_ml/client/experiment.py +++ b/rubicon_ml/client/experiment.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import TYPE_CHECKING from rubicon_ml import domain diff --git a/rubicon_ml/client/feature.py b/rubicon_ml/client/feature.py index 94100e53..505d2803 100644 --- a/rubicon_ml/client/feature.py +++ b/rubicon_ml/client/feature.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime from typing import TYPE_CHECKING, Optional diff --git a/rubicon_ml/client/metric.py b/rubicon_ml/client/metric.py index 1a28c4fd..d5390c26 100644 --- a/rubicon_ml/client/metric.py +++ b/rubicon_ml/client/metric.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime from typing import TYPE_CHECKING, Optional diff --git a/rubicon_ml/client/parameter.py b/rubicon_ml/client/parameter.py index 7a2f25e8..046c6bbf 100644 --- a/rubicon_ml/client/parameter.py +++ b/rubicon_ml/client/parameter.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime from typing import TYPE_CHECKING, Optional, Union diff --git a/rubicon_ml/client/project.py b/rubicon_ml/client/project.py index 3a3ec60b..dffc5cb2 100644 --- a/rubicon_ml/client/project.py +++ b/rubicon_ml/client/project.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import subprocess import warnings from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union From 2af829b7dc494aace1f5c676c4af2ce1ff969375 Mon Sep 17 00:00:00 2001 From: stephenpardy Date: Tue, 12 Sep 2023 09:30:22 -0500 Subject: [PATCH 09/11] Import fix --- rubicon_ml/client/mixin.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/rubicon_ml/client/mixin.py b/rubicon_ml/client/mixin.py index df386c3c..65807329 100644 --- a/rubicon_ml/client/mixin.py +++ b/rubicon_ml/client/mixin.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import pickle import subprocess @@ -10,6 +12,7 @@ from rubicon_ml import client, domain from rubicon_ml.client.utils.exception_handling import failsafe from rubicon_ml.client.utils.tags import filter_children +from rubicon_ml.domain import Artifact as ArtifactDomain from rubicon_ml.exceptions import RubiconException if TYPE_CHECKING: @@ -17,7 +20,7 @@ import pandas as pd from rubicon_ml.client import Artifact, Dataframe - from rubicon_ml.domain import DOMAIN_TYPES, Artifact as ArtifactDomain + from rubicon_ml.domain import DOMAIN_TYPES class ArtifactMixin: From 10d3a12530e0f521315cfe74280282362ffb1d75 Mon Sep 17 00:00:00 2001 From: stephenpardy Date: Tue, 12 Sep 2023 11:23:50 -0500 Subject: [PATCH 10/11] Fix for python3.8 --- rubicon_ml/client/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rubicon_ml/client/config.py b/rubicon_ml/client/config.py index 881376ad..623e89df 100644 --- a/rubicon_ml/client/config.py +++ b/rubicon_ml/client/config.py @@ -1,6 +1,6 @@ import os import subprocess -from typing import Dict, Optional, Tuple +from typing import Dict, Optional, Tuple, Type from rubicon_ml.exceptions import RubiconException from rubicon_ml.repository import ( @@ -35,7 +35,7 @@ class Config: """ PERSISTENCE_TYPES = ["filesystem", "memory"] - REPOSITORIES: Dict[str, type[BaseRepository]] = { + REPOSITORIES: Dict[str, Type[BaseRepository]] = { "memory-memory": MemoryRepository, "filesystem-local": LocalRepository, "filesystem-s3": S3Repository, From 734f76d395fd1bc74d7c57f492fbb11ad37b3e41 Mon Sep 17 00:00:00 2001 From: stephenpardy Date: Wed, 13 Sep 2023 14:18:06 -0500 Subject: [PATCH 11/11] More hints --- rubicon_ml/client/mixin.py | 12 +++++++++--- rubicon_ml/client/parameter.py | 4 +++- rubicon_ml/client/rubicon.py | 21 +++++++++++++++------ 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/rubicon_ml/client/mixin.py b/rubicon_ml/client/mixin.py index 65807329..37cc5c0e 100644 --- a/rubicon_ml/client/mixin.py +++ b/rubicon_ml/client/mixin.py @@ -5,7 +5,7 @@ import subprocess import warnings from datetime import datetime -from typing import TYPE_CHECKING, Any, List, Optional, Union +from typing import TYPE_CHECKING, Any, List, Optional, TextIO, Union import fsspec @@ -61,7 +61,7 @@ def _validate_data(self, data_bytes, data_file, data_object, data_path, name): def log_artifact( self, data_bytes: Optional[bytes] = None, - data_file=None, + data_file: Optional[TextIO] = None, data_object: Optional[Any] = None, data_path: Optional[str] = None, name: Optional[str] = None, @@ -320,7 +320,11 @@ class DataframeMixin: @failsafe def log_dataframe( - self, df: Union[pd.DataFrame, dd.DataFrame], description=None, name=None, tags=[] + self, + df: Union[pd.DataFrame, dd.DataFrame], + description: Optional[str] = None, + name: Optional[str] = None, + tags: Optional[List[str]] = None, ) -> Dataframe: """Log a dataframe to this client object. @@ -339,6 +343,8 @@ def log_dataframe( rubicon.client.Dataframe The new dataframe. """ + if tags is None: + tags = [] if not isinstance(tags, list) or not all([isinstance(tag, str) for tag in tags]): raise ValueError("`tags` must be `list` of type `str`") diff --git a/rubicon_ml/client/parameter.py b/rubicon_ml/client/parameter.py index 046c6bbf..e17a20ad 100644 --- a/rubicon_ml/client/parameter.py +++ b/rubicon_ml/client/parameter.py @@ -35,6 +35,8 @@ def __init__(self, domain: ParameterDomain, parent: Experiment): super().__init__(domain, parent._config) self._parent = parent + self._domain: ParameterDomain + @property def id(self) -> str: """Get the parameter's id.""" @@ -48,7 +50,7 @@ def name(self) -> Optional[str]: @property def value(self) -> Optional[Union[object, float]]: """Get the parameter's value.""" - return getattr(self._domain, "value", None) + return self._domain.value @property def description(self) -> Optional[str]: diff --git a/rubicon_ml/client/rubicon.py b/rubicon_ml/client/rubicon.py index 93c4417b..a7c02a88 100644 --- a/rubicon_ml/client/rubicon.py +++ b/rubicon_ml/client/rubicon.py @@ -1,10 +1,11 @@ import subprocess import warnings -from typing import Optional +from typing import List, Optional, Tuple, Union from rubicon_ml import domain from rubicon_ml.client import Config, Project from rubicon_ml.client.utils.exception_handling import failsafe +from rubicon_ml.domain.utils import TrainingMetadata from rubicon_ml.exceptions import RubiconException from rubicon_ml.repository.utils import slugify @@ -35,8 +36,8 @@ class Rubicon: def __init__( self, persistence: Optional[str] = "filesystem", - root_dir=None, - auto_git_enabled=False, + root_dir: Optional[str] = None, + auto_git_enabled: bool = False, **storage_options, ): self.config = Config(persistence, root_dir, auto_git_enabled, **storage_options) @@ -69,19 +70,27 @@ def _get_github_url(self): return github_url - def _create_project_domain(self, name, description, github_url, training_metadata): + def _create_project_domain( + self, + name: str, + description: str, + github_url: str, + training_metadata: Union[List[Tuple], Tuple], + ): """Instantiates and returns a project domain object.""" if self.config.is_auto_git_enabled and github_url is None: github_url = self._get_github_url() if training_metadata is not None: - training_metadata = domain.utils.TrainingMetadata(training_metadata) + training_metadata_class = TrainingMetadata(training_metadata) + else: + training_metadata_class = None return domain.Project( name, description=description, github_url=github_url, - training_metadata=training_metadata, + training_metadata=training_metadata_class, ) @failsafe