diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6db48054..d90a63ad 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ repos: rev: 5.12.0 hooks: - id: isort - + - repo: https://github.com/pycqa/flake8 rev: 6.1.0 hooks: diff --git a/rubicon_ml/client/artifact.py b/rubicon_ml/client/artifact.py index 6b2cb44b..5a24bb81 100644 --- a/rubicon_ml/client/artifact.py +++ b/rubicon_ml/client/artifact.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import os import pickle import warnings +from typing import TYPE_CHECKING, Optional import fsspec @@ -9,6 +12,10 @@ from rubicon_ml.client.utils.exception_handling import failsafe from rubicon_ml.exceptions import RubiconException +if TYPE_CHECKING: + from rubicon_ml.client import Project + from rubicon_ml.domain import Artifact as ArtifactDomain + class Artifact(Base, TagMixin): """A client artifact. @@ -32,7 +39,7 @@ class Artifact(Base, TagMixin): logged to. """ - def __init__(self, domain, parent): + def __init__(self, domain: ArtifactDomain, parent: Project): super().__init__(domain, parent._config) self._data = None @@ -42,8 +49,8 @@ def _get_data(self): """Loads the data associated with this artifact.""" project_name, experiment_id = self.parent._get_identifiers() return_err = None - for repo in self.repositories: - self._data = None + self._data = None + for repo in self.repositories or []: try: self._data = repo.get_artifact_data( project_name, self.id, experiment_id=experiment_id @@ -56,7 +63,7 @@ def _get_data(self): raise RubiconException("all configured storage backends failed") from return_err @failsafe - def get_data(self, unpickle=False): + def get_data(self, unpickle: bool = False): """Loads the data associated with this artifact and unpickles if needed. @@ -68,7 +75,7 @@ def get_data(self, unpickle=False): """ project_name, experiment_id = self.parent._get_identifiers() return_err = None - for repo in self.repositories: + for repo in self.repositories or []: try: data = repo.get_artifact_data(project_name, self.id, experiment_id=experiment_id) except Exception as err: @@ -80,7 +87,7 @@ def get_data(self, unpickle=False): raise RubiconException("all configured storage backends failed") from return_err @failsafe - def download(self, location=None, name=None): + def download(self, location: Optional[str] = None, name: Optional[str] = None): """Download this artifact's data. Parameters @@ -104,17 +111,17 @@ def download(self, location=None, name=None): f.write(self.data) @property - def id(self): + def id(self) -> str: """Get the artifact's id.""" return self._domain.id @property - def name(self): + def name(self) -> str: """Get the artifact's name.""" return self._domain.name @property - def description(self): + def description(self) -> str: """Get the artifact's description.""" return self._domain.description diff --git a/rubicon_ml/client/base.py b/rubicon_ml/client/base.py index 736f0851..30be1b76 100644 --- a/rubicon_ml/client/base.py +++ b/rubicon_ml/client/base.py @@ -1,3 +1,13 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Optional + +if TYPE_CHECKING: + from rubicon_ml.client import Config + from rubicon_ml.domain import DOMAIN_TYPES + from rubicon_ml.repository import BaseRepository + + class Base: """The base object for all top-level client objects. @@ -9,19 +19,22 @@ class Base: The config, which injects the repository to use. """ - def __init__(self, domain, config=None): + def __init__(self, domain: DOMAIN_TYPES, config: Optional[Config] = None): self._config = config self._domain = domain - def __str__(self): + def __str__(self) -> str: return self._domain.__str__() @property - def repository(self): - return self._config.repository + def repository(self) -> Optional[BaseRepository]: + return self._config.repository if self._config is not None else None @property - def repositories(self): + def repositories(self) -> Optional[List[BaseRepository]]: + if self._config is None: + return None + if hasattr(self._config, "repositories"): return self._config.repositories else: diff --git a/rubicon_ml/client/config.py b/rubicon_ml/client/config.py index a7b41bcf..623e89df 100644 --- a/rubicon_ml/client/config.py +++ b/rubicon_ml/client/config.py @@ -1,8 +1,14 @@ import os import subprocess +from typing import Dict, Optional, Tuple, Type from rubicon_ml.exceptions import RubiconException -from rubicon_ml.repository import LocalRepository, MemoryRepository, S3Repository +from rubicon_ml.repository import ( + BaseRepository, + LocalRepository, + MemoryRepository, + S3Repository, +) class Config: @@ -29,18 +35,22 @@ class Config: """ PERSISTENCE_TYPES = ["filesystem", "memory"] - REPOSITORIES = { + REPOSITORIES: Dict[str, Type[BaseRepository]] = { "memory-memory": MemoryRepository, "filesystem-local": LocalRepository, "filesystem-s3": S3Repository, } def __init__( - self, persistence=None, root_dir=None, is_auto_git_enabled=False, **storage_options + self, + persistence: Optional[str] = None, + root_dir: Optional[str] = None, + is_auto_git_enabled: bool = False, + **storage_options, ): self.storage_options = storage_options if storage_options is not None and "composite_config" in storage_options: - composite_config = storage_options.get("composite_config") + composite_config = storage_options.get("composite_config", []) repositories = [] for config in composite_config: self.persistence, self.root_dir, self.is_auto_git_enabled = self._load_config( @@ -62,7 +72,9 @@ def _check_is_in_git_repo(self): "Not a `git` repo: Falied to locate the '.git' directory in this or any parent directories." ) - def _load_config(self, persistence, root_dir, is_auto_git_enabled): + def _load_config( + self, persistence: Optional[str], root_dir: Optional[str], is_auto_git_enabled: bool + ) -> Tuple[str, Optional[str], bool]: """Get the configuration values.""" persistence = os.environ.get("PERSISTENCE", persistence) if persistence not in self.PERSISTENCE_TYPES: @@ -77,7 +89,7 @@ def _load_config(self, persistence, root_dir, is_auto_git_enabled): return (persistence, root_dir, is_auto_git_enabled) - def _get_protocol(self): + def _get_protocol(self) -> str: """Get the file protocol of the configured root directory.""" if self.persistence == "memory": return "memory" @@ -89,7 +101,7 @@ def _get_protocol(self): return "custom" # catch-all for external backends - def _get_repository(self): + def _get_repository(self) -> BaseRepository: """Get the repository for the configured persistence type.""" protocol = self._get_protocol() diff --git a/rubicon_ml/client/dataframe.py b/rubicon_ml/client/dataframe.py index d1c8394f..facef63f 100644 --- a/rubicon_ml/client/dataframe.py +++ b/rubicon_ml/client/dataframe.py @@ -1,7 +1,15 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Literal, Optional, Union + from rubicon_ml.client import Base, TagMixin from rubicon_ml.client.utils.exception_handling import failsafe from rubicon_ml.exceptions import RubiconException +if TYPE_CHECKING: + from rubicon_ml.client import Experiment, Project + from rubicon_ml.domain import Dataframe as DataframeDomain + class Dataframe(Base, TagMixin): """A client dataframe. @@ -24,14 +32,16 @@ class Dataframe(Base, TagMixin): logged to. """ - def __init__(self, domain, parent): + def __init__(self, domain: DataframeDomain, parent: Union[Experiment, Project]): super().__init__(domain, parent._config) + self._domain: DataframeDomain + self._data = None self._parent = parent @failsafe - def get_data(self, df_type="pandas"): + def get_data(self, df_type: Literal["pandas", "dask"] = "pandas"): """Loads the data associated with this Dataframe into a `pandas` or `dask` dataframe. @@ -59,7 +69,12 @@ def get_data(self, df_type="pandas"): raise RubiconException(return_err) @failsafe - def plot(self, df_type="pandas", plotting_func=None, **kwargs): + def plot( + self, + df_type: Literal["pandas", "dask"] = "pandas", + plotting_func: Optional[Callable] = None, + **kwargs, + ): """Render the dataframe using `plotly.express`. Parameters diff --git a/rubicon_ml/client/experiment.py b/rubicon_ml/client/experiment.py index b71307d4..50803f07 100644 --- a/rubicon_ml/client/experiment.py +++ b/rubicon_ml/client/experiment.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from rubicon_ml import domain from rubicon_ml.client import ( ArtifactMixin, @@ -12,6 +16,10 @@ from rubicon_ml.client.utils.tags import filter_children from rubicon_ml.exceptions import RubiconException +if TYPE_CHECKING: + from rubicon_ml.client import Project + from rubicon_ml.domain import Experiment as ExperimentDomain + class Experiment(Base, ArtifactMixin, DataframeMixin, TagMixin): """A client experiment. @@ -30,9 +38,11 @@ class Experiment(Base, ArtifactMixin, DataframeMixin, TagMixin): The project that the experiment is logged to. """ - def __init__(self, domain, parent): + def __init__(self, domain: ExperimentDomain, parent: Project): super().__init__(domain, parent._config) + self._domain: ExperimentDomain + self._parent = parent self._artifacts = [] self._dataframes = [] diff --git a/rubicon_ml/client/feature.py b/rubicon_ml/client/feature.py index be1991fa..505d2803 100644 --- a/rubicon_ml/client/feature.py +++ b/rubicon_ml/client/feature.py @@ -1,5 +1,14 @@ +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING, Optional + from rubicon_ml.client import Base, TagMixin +if TYPE_CHECKING: + from rubicon_ml.client import Experiment + from rubicon_ml.domain import Feature as FeatureDomain + class Feature(Base, TagMixin): """A client feature. @@ -25,24 +34,26 @@ class Feature(Base, TagMixin): logged to. """ - def __init__(self, domain, parent): + def __init__(self, domain: FeatureDomain, parent: Experiment): super().__init__(domain, parent._config) + self._domain: FeatureDomain + self._data = None self._parent = parent @property - def id(self): + def id(self) -> str: """Get the feature's id.""" return self._domain.id @property - def name(self): + def name(self) -> Optional[str]: """Get the feature's name.""" return self._domain.name @property - def description(self): + def description(self) -> Optional[str]: """Get the feature's description.""" return self._domain.description @@ -52,11 +63,11 @@ def importance(self): return self._domain.importance @property - def created_at(self): + def created_at(self) -> datetime: """Get the feature's created_at.""" return self._domain.created_at @property - def parent(self): + def parent(self) -> Experiment: """Get the feature's parent client object.""" return self._parent diff --git a/rubicon_ml/client/metric.py b/rubicon_ml/client/metric.py index 4e61588a..d5390c26 100644 --- a/rubicon_ml/client/metric.py +++ b/rubicon_ml/client/metric.py @@ -1,5 +1,14 @@ +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING, Optional + from rubicon_ml.client import Base, TagMixin +if TYPE_CHECKING: + from rubicon_ml.client import Experiment + from rubicon_ml.domain import Metric as MetricDomain + class Metric(Base, TagMixin): """A client metric. @@ -21,19 +30,21 @@ class Metric(Base, TagMixin): logged to. """ - def __init__(self, domain, parent): + def __init__(self, domain: MetricDomain, parent: Experiment): super().__init__(domain, parent._config) + self._domain: MetricDomain + self._data = None self._parent = parent @property - def id(self): + def id(self) -> str: """Get the metric's id.""" return self._domain.id @property - def name(self): + def name(self) -> Optional[str]: """Get the metric's name.""" return self._domain.name @@ -43,21 +54,21 @@ def value(self): return self._domain.value @property - def directionality(self): + def directionality(self) -> str: """Get the metric's directionality.""" return self._domain.directionality @property - def description(self): + def description(self) -> Optional[str]: """Get the metric's description.""" return self._domain.description @property - def created_at(self): + def created_at(self) -> datetime: """Get the metric's created_at.""" return self._domain.created_at @property - def parent(self): + def parent(self) -> Experiment: """Get the metric's parent client object.""" return self._parent diff --git a/rubicon_ml/client/mixin.py b/rubicon_ml/client/mixin.py index f1c792ac..37cc5c0e 100644 --- a/rubicon_ml/client/mixin.py +++ b/rubicon_ml/client/mixin.py @@ -1,20 +1,33 @@ +from __future__ import annotations + import os import pickle import subprocess import warnings from datetime import datetime +from typing import TYPE_CHECKING, Any, List, Optional, TextIO, Union import fsspec from rubicon_ml import client, domain from rubicon_ml.client.utils.exception_handling import failsafe from rubicon_ml.client.utils.tags import filter_children +from rubicon_ml.domain import Artifact as ArtifactDomain from rubicon_ml.exceptions import RubiconException +if TYPE_CHECKING: + import dask.dataframe as dd + import pandas as pd + + from rubicon_ml.client import Artifact, Dataframe + from rubicon_ml.domain import DOMAIN_TYPES + class ArtifactMixin: """Adds artifact support to a client object.""" + _domain: ArtifactDomain + def _validate_data(self, data_bytes, data_file, data_object, data_path, name): """Raises a `RubiconException` if the data to log as an artifact is improperly provided. @@ -47,14 +60,14 @@ def _validate_data(self, data_bytes, data_file, data_object, data_path, name): @failsafe def log_artifact( self, - data_bytes=None, - data_file=None, - data_object=None, - data_path=None, - name=None, - description=None, - tags=[], - ): + data_bytes: Optional[bytes] = None, + data_file: Optional[TextIO] = None, + data_object: Optional[Any] = None, + data_path: Optional[str] = None, + name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[List[str]] = None, + ) -> Artifact: """Log an artifact to this client object. Parameters @@ -108,7 +121,8 @@ def log_artifact( ... data_path="./path/to/artifact.pkl", description="log artifact from file path" ... ) """ - + if tags is None: + tags = [] if not isinstance(tags, list) or not all([isinstance(tag, str) for tag in tags]): raise ValueError("`tags` must be `list` of type `str`") @@ -127,7 +141,7 @@ def log_artifact( return client.Artifact(artifact, self) - def _get_environment_bytes(self, export_cmd): + def _get_environment_bytes(self, export_cmd: List[str]) -> bytes: """Get the working environment as a sequence of bytes. Parameters @@ -148,7 +162,7 @@ def _get_environment_bytes(self, export_cmd): return completed_process.stdout @failsafe - def log_conda_environment(self, artifact_name=None): + def log_conda_environment(self, artifact_name: Optional[str] = None) -> Artifact: """Log the conda environment as an artifact to this client object. Useful for recreating your exact environment at a later date. @@ -175,7 +189,7 @@ def log_conda_environment(self, artifact_name=None): return artifact @failsafe - def log_pip_requirements(self, artifact_name=None): + def log_pip_requirements(self, artifact_name: Optional[str] = None) -> Artifact: """Log the pip requirements as an artifact to this client object. Useful for recreating your exact environment at a later date. @@ -198,7 +212,9 @@ def log_pip_requirements(self, artifact_name=None): return artifact @failsafe - def artifacts(self, name=None, tags=[], qtype="or"): + def artifacts( + self, name: Optional[str] = None, tags: Optional[List[str]] = None, qtype: str = "or" + ) -> List[Artifact]: """Get the artifacts logged to this client object. Parameters @@ -216,6 +232,8 @@ def artifacts(self, name=None, tags=[], qtype="or"): list of rubicon.client.Artifact The artifacts previously logged to this client object. """ + if tags is None: + tags = [] project_name, experiment_id = self._get_identifiers() return_err = None for repo in self.repositories: @@ -233,7 +251,7 @@ def artifacts(self, name=None, tags=[], qtype="or"): raise RubiconException("all configured storage backends failed") from return_err @failsafe - def artifact(self, name=None, id=None): + def artifact(self, name: Optional[str] = None, id: Optional[str] = None) -> Artifact: """Get an artifact logged to this project by id or name. Parameters @@ -279,7 +297,7 @@ def artifact(self, name=None, id=None): raise RubiconException("all configured storage backends failed") from return_err @failsafe - def delete_artifacts(self, ids): + def delete_artifacts(self, ids: List[str]): """Delete the artifacts logged to with client object with ids `ids`. @@ -298,8 +316,16 @@ def delete_artifacts(self, ids): class DataframeMixin: """Adds dataframe support to a client object.""" + _domain: DOMAIN_TYPES + @failsafe - def log_dataframe(self, df, description=None, name=None, tags=[]): + def log_dataframe( + self, + df: Union[pd.DataFrame, dd.DataFrame], + description: Optional[str] = None, + name: Optional[str] = None, + tags: Optional[List[str]] = None, + ) -> Dataframe: """Log a dataframe to this client object. Parameters @@ -317,6 +343,8 @@ def log_dataframe(self, df, description=None, name=None, tags=[]): rubicon.client.Dataframe The new dataframe. """ + if tags is None: + tags = [] if not isinstance(tags, list) or not all([isinstance(tag, str) for tag in tags]): raise ValueError("`tags` must be `list` of type `str`") @@ -334,7 +362,9 @@ def log_dataframe(self, df, description=None, name=None, tags=[]): return client.Dataframe(dataframe, self) @failsafe - def dataframes(self, name=None, tags=[], qtype="or"): + def dataframes( + self, name: Optional[str] = None, tags: Optional[List[str]] = None, qtype: str = "or" + ) -> List[Dataframe]: """Get the dataframes logged to this client object. Parameters @@ -352,6 +382,8 @@ def dataframes(self, name=None, tags=[], qtype="or"): list of rubicon.client.Dataframe The dataframes previously logged to this client object. """ + if tags is None: + tags = [] project_name, experiment_id = self._get_identifiers() return_err = None for repo in self.repositories: @@ -369,7 +401,7 @@ def dataframes(self, name=None, tags=[], qtype="or"): raise RubiconException("all configured storage backends failed") from return_err @failsafe - def dataframe(self, name=None, id=None): + def dataframe(self, name: Optional[str] = None, id: Optional[str] = None) -> Dataframe: """ Get the dataframe logged to this client object. @@ -419,7 +451,7 @@ def dataframe(self, name=None, id=None): raise RubiconException("all configured storage backends failed") from return_err @failsafe - def delete_dataframes(self, ids): + def delete_dataframes(self, ids: List[str]): """Delete the dataframes with ids `ids` logged to this client object. @@ -438,6 +470,8 @@ def delete_dataframes(self, ids): class TagMixin: """Adds tag support to a client object.""" + _domain: DOMAIN_TYPES + def _get_taggable_identifiers(self): project_name, experiment_id = self._parent._get_identifiers() entity_identifier = None @@ -455,7 +489,7 @@ def _get_taggable_identifiers(self): return project_name, experiment_id, entity_identifier @failsafe - def add_tags(self, tags): + def add_tags(self, tags: List[str]): """Add tags to this client object. Parameters @@ -479,7 +513,7 @@ def add_tags(self, tags): ) @failsafe - def remove_tags(self, tags): + def remove_tags(self, tags: List[str]): """Remove tags from this client object. Parameters @@ -508,7 +542,7 @@ def _update_tags(self, tag_data): self._domain.remove_tags(tag.get("removed_tags", [])) @property - def tags(self): + def tags(self) -> List[str]: """Get this client object's tags.""" project_name, experiment_id, entity_identifier = self._get_taggable_identifiers() return_err = None diff --git a/rubicon_ml/client/parameter.py b/rubicon_ml/client/parameter.py index ada71e34..e17a20ad 100644 --- a/rubicon_ml/client/parameter.py +++ b/rubicon_ml/client/parameter.py @@ -1,5 +1,14 @@ +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING, Optional, Union + from rubicon_ml.client import Base, TagMixin +if TYPE_CHECKING: + from rubicon_ml.client import Experiment + from rubicon_ml.domain import Parameter as ParameterDomain + class Parameter(Base, TagMixin): """A client parameter. @@ -22,36 +31,38 @@ class Parameter(Base, TagMixin): The experiment that the parameter is logged to. """ - def __init__(self, domain, parent): + def __init__(self, domain: ParameterDomain, parent: Experiment): super().__init__(domain, parent._config) self._parent = parent + self._domain: ParameterDomain + @property - def id(self): + def id(self) -> str: """Get the parameter's id.""" return self._domain.id @property - def name(self): + def name(self) -> Optional[str]: """Get the parameter's name.""" return self._domain.name @property - def value(self): + def value(self) -> Optional[Union[object, float]]: """Get the parameter's value.""" return self._domain.value @property - def description(self): + def description(self) -> Optional[str]: """Get the parameter's description.""" return self._domain.description @property - def created_at(self): + def created_at(self) -> datetime: """Get the time the parameter was created.""" return self._domain.created_at @property - def parent(self): + def parent(self) -> Experiment: """Get the parameter's parent client object.""" return self._parent diff --git a/rubicon_ml/client/project.py b/rubicon_ml/client/project.py index 8ed592f9..dffc5cb2 100644 --- a/rubicon_ml/client/project.py +++ b/rubicon_ml/client/project.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import subprocess import warnings -from typing import List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import dask.dataframe as dd import pandas as pd @@ -11,6 +13,11 @@ from rubicon_ml.client.utils.tags import filter_children from rubicon_ml.exceptions import RubiconException +if TYPE_CHECKING: + from rubicon_ml import Rubicon + from rubicon_ml.client import Config, Dataframe + from rubicon_ml.domain import Project as ProjectDomain + class Project(Base, ArtifactMixin, DataframeMixin): """A client project. @@ -26,14 +33,16 @@ class Project(Base, ArtifactMixin, DataframeMixin): The config, which specifies the underlying repository. """ - def __init__(self, domain, config=None): + def __init__(self, domain: ProjectDomain, config: Optional[Config] = None): super().__init__(domain, config) + self._domain: ProjectDomain + self._artifacts = [] self._dataframes = [] self._experiments = [] - def _get_branch_name(self): + def _get_branch_name(self) -> str: """Returns the name of the active branch of the `git` repo it is called from. """ @@ -42,7 +51,7 @@ def _get_branch_name(self): return completed_process.stdout.decode("utf8").replace("\n", "") - def _get_commit_hash(self): + def _get_commit_hash(self) -> str: """Returns the hash of the last commit to the active branch of the `git` repo it is called from. """ @@ -51,7 +60,7 @@ def _get_commit_hash(self): return completed_process.stdout.decode("utf8").replace("\n", "") - def _get_identifiers(self): + def _get_identifiers(self) -> Tuple[Optional[str], None]: """Get the project's name.""" return self.name, None @@ -86,7 +95,7 @@ def _create_experiment_domain( tags=tags, ) - def _group_experiments(self, experiments, group_by=None): + def _group_experiments(self, experiments: List[Experiment], group_by: Optional[str] = None): """Groups experiments by `group_by`. Valid options include ["commit_hash"]. Returns @@ -111,7 +120,7 @@ def _group_experiments(self, experiments, group_by=None): return grouped_experiments - def to_dask_df(self, group_by=None): + def to_dask_df(self, group_by: Optional[str] = None): """DEPRECATED: Available for backwards compatibility.""" warnings.warn( "`to_dask_df` is deprecated and will be removed in a future release. " @@ -122,7 +131,9 @@ def to_dask_df(self, group_by=None): return self.to_df(df_type="dask", group_by=group_by) @failsafe - def to_df(self, df_type="pandas", group_by=None): + def to_df( + self, df_type: str = "pandas", group_by: Optional[str] = None + ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame], dd.DataFrame, Dict[str, dd.DataFrame]]: """Loads the project's data into dask or pandas dataframe(s) sorted by `created_at`. This includes the experiment details along with parameters and metrics. @@ -138,9 +149,9 @@ def to_df(self, df_type="pandas", group_by=None): Returns ------- - pandas.DataFrame or list of pandas.DataFrame or dask.DataFrame or list of dask.DataFrame + pandas.DataFrame or dict of pandas.DataFrame or dask.DataFrame or dict of dask.DataFrame If `group_by` is `None`, a dask or pandas dataframe holding the project's - data. Otherwise a list of dask or pandas dataframes holding the project's + data. Otherwise a dict of dask or pandas dataframes holding the project's data grouped by `group_by`. """ DEFAULT_COLUMNS = [ @@ -199,14 +210,14 @@ def to_df(self, df_type="pandas", group_by=None): @failsafe def log_experiment( self, - name=None, - description=None, - model_name=None, - branch_name=None, - commit_hash=None, - training_metadata=None, - tags=[], - ): + name: Optional[str] = None, + description: Optional[str] = None, + model_name: Optional[str] = None, + branch_name: Optional[str] = None, + commit_hash: Optional[str] = None, + training_metadata: Optional[Union[Tuple, List[Tuple]]] = None, + tags: Optional[List[str]] = None, + ) -> Experiment: """Log a new experiment to this project. Parameters @@ -242,6 +253,8 @@ def log_experiment( rubicon.client.Experiment The created experiment. """ + if tags is None: + tags = [] if not isinstance(tags, list) or not all([isinstance(tag, str) for tag in tags]): raise ValueError("`tags` must be `list` of type `str`") @@ -260,7 +273,7 @@ def log_experiment( return Experiment(experiment, self) @failsafe - def experiment(self, id=None, name=None): + def experiment(self, id: Optional[str] = None, name: Optional[str] = None) -> Experiment: """Get an experiment logged to this project by id or name. Parameters @@ -303,7 +316,9 @@ def experiment(self, id=None, name=None): raise RubiconException("all configured storage backends failed") from return_err @failsafe - def experiments(self, tags=[], qtype="or", name=None): + def experiments( + self, tags: Optional[List[str]] = None, qtype: str = "or", name: Optional[str] = None + ) -> List[Experiment]: """Get the experiments logged to this project. Parameters @@ -321,6 +336,8 @@ def experiments(self, tags=[], qtype="or", name=None): list of rubicon.client.Experiment The experiments previously logged to this project. """ + if tags is None: + tags = [] return_err = None for repo in self.repositories: try: @@ -334,7 +351,13 @@ def experiments(self, tags=[], qtype="or", name=None): raise RubiconException("all configured storage backends failed") from return_err @failsafe - def dataframes(self, tags=[], qtype="or", recursive=False, name=None): + def dataframes( + self, + tags: Optional[List[str]] = None, + qtype: str = "or", + recursive: bool = False, + name: Optional[str] = None, + ) -> List[Dataframe]: """Get the dataframes logged to this project. Parameters @@ -355,6 +378,8 @@ def dataframes(self, tags=[], qtype="or", recursive=False, name=None): list of rubicon.client.Dataframe The dataframes previously logged to this client object. """ + if tags is None: + tags = [] super().dataframes(tags=tags, qtype=qtype, name=name) if recursive is True: @@ -364,7 +389,11 @@ def dataframes(self, tags=[], qtype="or", recursive=False, name=None): return self._dataframes @failsafe - def archive(self, experiments: Optional[List[Experiment]] = None, remote_rubicon=None): + def archive( + self, + experiments: Optional[List[Experiment]] = None, + remote_rubicon: Optional[Rubicon] = None, + ): """Archive the experiments logged to this project. Parameters diff --git a/rubicon_ml/client/rubicon.py b/rubicon_ml/client/rubicon.py index ea9f038b..a7c02a88 100644 --- a/rubicon_ml/client/rubicon.py +++ b/rubicon_ml/client/rubicon.py @@ -1,9 +1,11 @@ import subprocess import warnings +from typing import List, Optional, Tuple, Union from rubicon_ml import domain from rubicon_ml.client import Config, Project from rubicon_ml.client.utils.exception_handling import failsafe +from rubicon_ml.domain.utils import TrainingMetadata from rubicon_ml.exceptions import RubiconException from rubicon_ml.repository.utils import slugify @@ -32,7 +34,11 @@ class Rubicon: """ def __init__( - self, persistence="filesystem", root_dir=None, auto_git_enabled=False, **storage_options + self, + persistence: Optional[str] = "filesystem", + root_dir: Optional[str] = None, + auto_git_enabled: bool = False, + **storage_options, ): self.config = Config(persistence, root_dir, auto_git_enabled, **storage_options) @@ -64,19 +70,27 @@ def _get_github_url(self): return github_url - def _create_project_domain(self, name, description, github_url, training_metadata): + def _create_project_domain( + self, + name: str, + description: str, + github_url: str, + training_metadata: Union[List[Tuple], Tuple], + ): """Instantiates and returns a project domain object.""" if self.config.is_auto_git_enabled and github_url is None: github_url = self._get_github_url() if training_metadata is not None: - training_metadata = domain.utils.TrainingMetadata(training_metadata) + training_metadata_class = TrainingMetadata(training_metadata) + else: + training_metadata_class = None return domain.Project( name, description=description, github_url=github_url, - training_metadata=training_metadata, + training_metadata=training_metadata_class, ) @failsafe diff --git a/rubicon_ml/client/utils/exception_handling.py b/rubicon_ml/client/utils/exception_handling.py index e8fc3835..36407e79 100644 --- a/rubicon_ml/client/utils/exception_handling.py +++ b/rubicon_ml/client/utils/exception_handling.py @@ -2,6 +2,7 @@ import logging import traceback import warnings +from typing import Callable, Optional FAILURE_MODE = "raise" FAILURE_MODES = ["log", "raise", "warn"] @@ -9,7 +10,9 @@ TRACEBACK_LIMIT = None -def set_failure_mode(failure_mode, traceback_chain=False, traceback_limit=None): +def set_failure_mode( + failure_mode: str, traceback_chain: bool = False, traceback_limit: Optional[int] = None +) -> None: """Set the failure mode. Parameters @@ -38,7 +41,7 @@ def set_failure_mode(failure_mode, traceback_chain=False, traceback_limit=None): TRACEBACK_LIMIT = traceback_limit -def failsafe(func): +def failsafe(func: Callable) -> Callable: @functools.wraps(func) def wrapper(*args, **kwargs): try: diff --git a/rubicon_ml/client/utils/tags.py b/rubicon_ml/client/utils/tags.py index a6c62250..bb444bce 100644 --- a/rubicon_ml/client/utils/tags.py +++ b/rubicon_ml/client/utils/tags.py @@ -1,4 +1,7 @@ -def has_tag_requirements(tags, required_tags, qtype): +from typing import List + + +def has_tag_requirements(tags: List[str], required_tags: List[str], qtype: str) -> bool: """Returns True if `tags` meets the requirements based on the values of `required_tags` and `qtype`. False otherwise. """ diff --git a/rubicon_ml/domain/__init__.py b/rubicon_ml/domain/__init__.py index ad4002c5..7db2ea93 100644 --- a/rubicon_ml/domain/__init__.py +++ b/rubicon_ml/domain/__init__.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Union + from rubicon_ml.domain.artifact import Artifact from rubicon_ml.domain.dataframe import Dataframe from rubicon_ml.domain.experiment import Experiment @@ -8,4 +10,6 @@ from rubicon_ml.domain.parameter import Parameter from rubicon_ml.domain.project import Project +DOMAIN_TYPES = Union[Artifact, Dataframe, Experiment, Feature, Metric, Parameter, Project] + __all__ = ["Artifact", "Dataframe", "Experiment", "Feature", "Metric", "Parameter", "Project"] diff --git a/rubicon_ml/domain/artifact.py b/rubicon_ml/domain/artifact.py index 4a47f59b..14dcf01b 100644 --- a/rubicon_ml/domain/artifact.py +++ b/rubicon_ml/domain/artifact.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from datetime import datetime -from typing import List +from typing import List, Optional from rubicon_ml.domain.mixin import TagMixin from rubicon_ml.domain.utils import uuid @@ -13,8 +13,8 @@ class Artifact(TagMixin): name: str id: str = field(default_factory=uuid.uuid4) - description: str = None + description: Optional[str] = None created_at: datetime = field(default_factory=datetime.utcnow) tags: List[str] = field(default_factory=list) - parent_id: str = None + parent_id: Optional[str] = None diff --git a/rubicon_ml/domain/dataframe.py b/rubicon_ml/domain/dataframe.py index f8b12b51..e2f085b8 100644 --- a/rubicon_ml/domain/dataframe.py +++ b/rubicon_ml/domain/dataframe.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from datetime import datetime -from typing import List +from typing import List, Optional from rubicon_ml.domain.mixin import TagMixin from rubicon_ml.domain.utils import uuid @@ -11,9 +11,9 @@ @dataclass class Dataframe(TagMixin): id: str = field(default_factory=uuid.uuid4) - name: str = None - description: str = None + name: Optional[str] = None + description: Optional[str] = None tags: List[str] = field(default_factory=list) created_at: datetime = field(default_factory=datetime.utcnow) - parent_id: str = None + parent_id: Optional[str] = None diff --git a/rubicon_ml/domain/experiment.py b/rubicon_ml/domain/experiment.py index 7c98da2f..d5118b37 100644 --- a/rubicon_ml/domain/experiment.py +++ b/rubicon_ml/domain/experiment.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from datetime import datetime -from typing import List +from typing import List, Optional from rubicon_ml.domain.mixin import TagMixin from rubicon_ml.domain.utils import TrainingMetadata, uuid @@ -13,11 +13,11 @@ class Experiment(TagMixin): project_name: str id: str = field(default_factory=uuid.uuid4) - name: str = None - description: str = None - model_name: str = None - branch_name: str = None - commit_hash: str = None - training_metadata: TrainingMetadata = None + name: Optional[str] = None + description: Optional[str] = None + model_name: Optional[str] = None + branch_name: Optional[str] = None + commit_hash: Optional[str] = None + training_metadata: Optional[TrainingMetadata] = None tags: List[str] = field(default_factory=list) created_at: datetime = field(default_factory=datetime.utcnow) diff --git a/rubicon_ml/domain/feature.py b/rubicon_ml/domain/feature.py index 9da684dd..fd4215d4 100644 --- a/rubicon_ml/domain/feature.py +++ b/rubicon_ml/domain/feature.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from datetime import datetime -from typing import List +from typing import List, Optional from rubicon_ml.domain.mixin import TagMixin from rubicon_ml.domain.utils import uuid @@ -11,7 +11,7 @@ class Feature(TagMixin): name: str id: str = field(default_factory=uuid.uuid4) - description: str = None - importance: float = None + description: Optional[str] = None + importance: Optional[float] = None tags: List[str] = field(default_factory=list) created_at: datetime = field(default_factory=datetime.utcnow) diff --git a/rubicon_ml/domain/metric.py b/rubicon_ml/domain/metric.py index b58278fa..9efd338d 100644 --- a/rubicon_ml/domain/metric.py +++ b/rubicon_ml/domain/metric.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from datetime import datetime -from typing import List +from typing import List, Optional from rubicon_ml.domain.mixin import TagMixin from rubicon_ml.domain.utils import uuid @@ -14,7 +14,7 @@ class Metric(TagMixin): value: float id: str = field(default_factory=uuid.uuid4) - description: str = None + description: Optional[str] = None directionality: str = "score" created_at: datetime = field(default_factory=datetime.utcnow) tags: List[str] = field(default_factory=list) diff --git a/rubicon_ml/domain/mixin.py b/rubicon_ml/domain/mixin.py index 4dd3f4ab..974f3bdb 100644 --- a/rubicon_ml/domain/mixin.py +++ b/rubicon_ml/domain/mixin.py @@ -1,8 +1,27 @@ +from typing import List + + class TagMixin: """Adds tagging support to a domain model.""" - def add_tags(self, tags): + def add_tags(self, tags: List[str]): + """ + Add new tags to this model. + + Parameters + ---------- + tags : List[str] + A list of string tags to add to the domain model. + """ self.tags = list(set(self.tags).union(set(tags))) - def remove_tags(self, tags): + def remove_tags(self, tags: List[str]): + """ + Remove tags from this model. + + Parameters + ---------- + tags : List[str] + A list of string tags to remove from this domain model. + """ self.tags = list(set(self.tags).difference(set(tags))) diff --git a/rubicon_ml/domain/parameter.py b/rubicon_ml/domain/parameter.py index f8497416..1238c5cc 100644 --- a/rubicon_ml/domain/parameter.py +++ b/rubicon_ml/domain/parameter.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from datetime import datetime -from typing import List +from typing import List, Optional from rubicon_ml.domain.mixin import TagMixin from rubicon_ml.domain.utils import uuid @@ -12,6 +12,6 @@ class Parameter(TagMixin): id: str = field(default_factory=uuid.uuid4) value: object = None - description: str = None + description: Optional[str] = None tags: List[str] = field(default_factory=list) created_at: datetime = field(default_factory=datetime.utcnow) diff --git a/rubicon_ml/domain/project.py b/rubicon_ml/domain/project.py index fe7190b4..ff159590 100644 --- a/rubicon_ml/domain/project.py +++ b/rubicon_ml/domain/project.py @@ -2,6 +2,7 @@ from dataclasses import dataclass, field from datetime import datetime +from typing import Optional from rubicon_ml.domain.utils import TrainingMetadata, uuid @@ -11,7 +12,7 @@ class Project: name: str id: str = field(default_factory=uuid.uuid4) - description: str = None - github_url: str = None - training_metadata: TrainingMetadata = None + description: Optional[str] = None + github_url: Optional[str] = None + training_metadata: Optional[TrainingMetadata] = None created_at: datetime = field(default_factory=datetime.utcnow) diff --git a/rubicon_ml/domain/utils/training_metadata.py b/rubicon_ml/domain/utils/training_metadata.py index b766b474..a636756b 100644 --- a/rubicon_ml/domain/utils/training_metadata.py +++ b/rubicon_ml/domain/utils/training_metadata.py @@ -1,3 +1,5 @@ +from typing import List, Tuple, Union + from rubicon_ml.exceptions import RubiconException @@ -25,7 +27,7 @@ class TrainingMetadata: [('s3', ['bucket/a.csv', 'bucket/b.csv'], 'SELECT * FROM x')] """ - def __init__(self, training_metadata): + def __init__(self, training_metadata: Union[List[Tuple], Tuple]): if not isinstance(training_metadata, list): training_metadata = [training_metadata] @@ -34,5 +36,5 @@ def __init__(self, training_metadata): self.training_metadata = training_metadata - def __repr__(self): + def __repr__(self) -> str: return str(self.training_metadata) diff --git a/rubicon_ml/domain/utils/uuid.py b/rubicon_ml/domain/utils/uuid.py index c57a60bb..bbf1f567 100644 --- a/rubicon_ml/domain/utils/uuid.py +++ b/rubicon_ml/domain/utils/uuid.py @@ -1,7 +1,7 @@ import uuid -def uuid4(): +def uuid4() -> str: """Generate a UUID as a string in a single function. To be used as a default factory within `dataclasses.field`. diff --git a/rubicon_ml/repository/base.py b/rubicon_ml/repository/base.py index 9b33a2c1..950cefd7 100644 --- a/rubicon_ml/repository/base.py +++ b/rubicon_ml/repository/base.py @@ -36,7 +36,7 @@ class BaseRepository: the underlying filesystem class. """ - def __init__(self, root_dir, **storage_options): + def __init__(self, root_dir: str, **storage_options): self.filesystem = fsspec.filesystem(self.PROTOCOL, **storage_options) self.root_dir = root_dir.rstrip("/")