Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add typehints to much of the core code. #379

Merged
merged 14 commits into from
Sep 13, 2023
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ repos:
rev: 5.12.0
hooks:
- id: isort

- repo: https://github.com/pycqa/flake8
rev: 6.1.0
hooks:
Expand Down
25 changes: 16 additions & 9 deletions rubicon_ml/client/artifact.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import annotations

import os
import pickle
import warnings
from typing import TYPE_CHECKING, Optional

import fsspec

Expand All @@ -9,6 +12,10 @@
from rubicon_ml.client.utils.exception_handling import failsafe
from rubicon_ml.exceptions import RubiconException

if TYPE_CHECKING:
from rubicon_ml.client import Project
from rubicon_ml.domain import Artifact as ArtifactDomain


class Artifact(Base, TagMixin):
"""A client artifact.
Expand All @@ -32,7 +39,7 @@ class Artifact(Base, TagMixin):
logged to.
"""

def __init__(self, domain, parent):
def __init__(self, domain: ArtifactDomain, parent: Project):
super().__init__(domain, parent._config)

self._data = None
Expand All @@ -42,8 +49,8 @@ def _get_data(self):
"""Loads the data associated with this artifact."""
project_name, experiment_id = self.parent._get_identifiers()
return_err = None
for repo in self.repositories:
self._data = None
self._data = None
for repo in self.repositories or []:
try:
self._data = repo.get_artifact_data(
project_name, self.id, experiment_id=experiment_id
Expand All @@ -56,7 +63,7 @@ def _get_data(self):
raise RubiconException("all configured storage backends failed") from return_err

@failsafe
def get_data(self, unpickle=False):
def get_data(self, unpickle: bool = False):
"""Loads the data associated with this artifact and
unpickles if needed.

Expand All @@ -68,7 +75,7 @@ def get_data(self, unpickle=False):
"""
project_name, experiment_id = self.parent._get_identifiers()
return_err = None
for repo in self.repositories:
for repo in self.repositories or []:
try:
data = repo.get_artifact_data(project_name, self.id, experiment_id=experiment_id)
except Exception as err:
Expand All @@ -80,7 +87,7 @@ def get_data(self, unpickle=False):
raise RubiconException("all configured storage backends failed") from return_err

@failsafe
def download(self, location=None, name=None):
def download(self, location: Optional[str] = None, name: Optional[str] = None):
"""Download this artifact's data.

Parameters
Expand All @@ -104,17 +111,17 @@ def download(self, location=None, name=None):
f.write(self.data)

@property
def id(self):
def id(self) -> str:
"""Get the artifact's id."""
return self._domain.id

@property
def name(self):
def name(self) -> str:
"""Get the artifact's name."""
return self._domain.name

@property
def description(self):
def description(self) -> str:
"""Get the artifact's description."""
return self._domain.description

Expand Down
23 changes: 18 additions & 5 deletions rubicon_ml/client/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional

if TYPE_CHECKING:
from rubicon_ml.client import Config
from rubicon_ml.domain import DOMAIN_TYPES
from rubicon_ml.repository import BaseRepository


class Base:
"""The base object for all top-level client objects.

Expand All @@ -9,19 +19,22 @@ class Base:
The config, which injects the repository to use.
"""

def __init__(self, domain, config=None):
def __init__(self, domain: DOMAIN_TYPES, config: Optional[Config] = None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have to look into why the config is optional - that doesn't seem right

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, the default value here is None, do we need to make it required?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not in this PR, maybe in the future. I think its a remnant of the fact that Rubicon used to extend the Base and didn't take in a config since it generated it. I'll verify that's the case and if so make an issue to fix it

self._config = config
self._domain = domain

def __str__(self):
def __str__(self) -> str:
return self._domain.__str__()

@property
def repository(self):
return self._config.repository
def repository(self) -> Optional[BaseRepository]:
return self._config.repository if self._config is not None else None

@property
def repositories(self):
def repositories(self) -> Optional[List[BaseRepository]]:
if self._config is None:
return None

if hasattr(self._config, "repositories"):
return self._config.repositories
else:
Expand Down
26 changes: 19 additions & 7 deletions rubicon_ml/client/config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import os
import subprocess
from typing import Dict, Optional, Tuple, Type

from rubicon_ml.exceptions import RubiconException
from rubicon_ml.repository import LocalRepository, MemoryRepository, S3Repository
from rubicon_ml.repository import (
BaseRepository,
LocalRepository,
MemoryRepository,
S3Repository,
)


class Config:
Expand All @@ -29,18 +35,22 @@ class Config:
"""

PERSISTENCE_TYPES = ["filesystem", "memory"]
REPOSITORIES = {
REPOSITORIES: Dict[str, Type[BaseRepository]] = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need the Type here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Type here means that the dictionary contains uninstantiated classes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahh gotcha, makes sense

"memory-memory": MemoryRepository,
"filesystem-local": LocalRepository,
"filesystem-s3": S3Repository,
}

def __init__(
self, persistence=None, root_dir=None, is_auto_git_enabled=False, **storage_options
self,
persistence: Optional[str] = None,
root_dir: Optional[str] = None,
is_auto_git_enabled: bool = False,
**storage_options,
):
self.storage_options = storage_options
if storage_options is not None and "composite_config" in storage_options:
composite_config = storage_options.get("composite_config")
composite_config = storage_options.get("composite_config", [])
repositories = []
for config in composite_config:
self.persistence, self.root_dir, self.is_auto_git_enabled = self._load_config(
Expand All @@ -62,7 +72,9 @@ def _check_is_in_git_repo(self):
"Not a `git` repo: Falied to locate the '.git' directory in this or any parent directories."
)

def _load_config(self, persistence, root_dir, is_auto_git_enabled):
def _load_config(
self, persistence: Optional[str], root_dir: Optional[str], is_auto_git_enabled: bool
) -> Tuple[str, Optional[str], bool]:
"""Get the configuration values."""
persistence = os.environ.get("PERSISTENCE", persistence)
if persistence not in self.PERSISTENCE_TYPES:
Expand All @@ -77,7 +89,7 @@ def _load_config(self, persistence, root_dir, is_auto_git_enabled):

return (persistence, root_dir, is_auto_git_enabled)

def _get_protocol(self):
def _get_protocol(self) -> str:
"""Get the file protocol of the configured root directory."""
if self.persistence == "memory":
return "memory"
Expand All @@ -89,7 +101,7 @@ def _get_protocol(self):

return "custom" # catch-all for external backends

def _get_repository(self):
def _get_repository(self) -> BaseRepository:
"""Get the repository for the configured persistence type."""
protocol = self._get_protocol()

Expand Down
21 changes: 18 additions & 3 deletions rubicon_ml/client/dataframe.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Literal, Optional, Union

from rubicon_ml.client import Base, TagMixin
from rubicon_ml.client.utils.exception_handling import failsafe
from rubicon_ml.exceptions import RubiconException

if TYPE_CHECKING:
from rubicon_ml.client import Experiment, Project
from rubicon_ml.domain import Dataframe as DataframeDomain


class Dataframe(Base, TagMixin):
"""A client dataframe.
Expand All @@ -24,14 +32,16 @@ class Dataframe(Base, TagMixin):
logged to.
"""

def __init__(self, domain, parent):
def __init__(self, domain: DataframeDomain, parent: Union[Experiment, Project]):
super().__init__(domain, parent._config)

self._domain: DataframeDomain

self._data = None
self._parent = parent

@failsafe
def get_data(self, df_type="pandas"):
def get_data(self, df_type: Literal["pandas", "dask"] = "pandas"):
"""Loads the data associated with this Dataframe
into a `pandas` or `dask` dataframe.

Expand Down Expand Up @@ -59,7 +69,12 @@ def get_data(self, df_type="pandas"):
raise RubiconException(return_err)

@failsafe
def plot(self, df_type="pandas", plotting_func=None, **kwargs):
def plot(
self,
df_type: Literal["pandas", "dask"] = "pandas",
plotting_func: Optional[Callable] = None,
**kwargs,
):
"""Render the dataframe using `plotly.express`.

Parameters
Expand Down
12 changes: 11 additions & 1 deletion rubicon_ml/client/experiment.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from rubicon_ml import domain
from rubicon_ml.client import (
ArtifactMixin,
Expand All @@ -12,6 +16,10 @@
from rubicon_ml.client.utils.tags import filter_children
from rubicon_ml.exceptions import RubiconException

if TYPE_CHECKING:
from rubicon_ml.client import Project
from rubicon_ml.domain import Experiment as ExperimentDomain


class Experiment(Base, ArtifactMixin, DataframeMixin, TagMixin):
"""A client experiment.
Expand All @@ -30,9 +38,11 @@ class Experiment(Base, ArtifactMixin, DataframeMixin, TagMixin):
The project that the experiment is logged to.
"""

def __init__(self, domain, parent):
def __init__(self, domain: ExperimentDomain, parent: Project):
super().__init__(domain, parent._config)

self._domain: ExperimentDomain

self._parent = parent
self._artifacts = []
self._dataframes = []
Expand Down
23 changes: 17 additions & 6 deletions rubicon_ml/client/feature.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING, Optional

from rubicon_ml.client import Base, TagMixin

if TYPE_CHECKING:
from rubicon_ml.client import Experiment
from rubicon_ml.domain import Feature as FeatureDomain


class Feature(Base, TagMixin):
"""A client feature.
Expand All @@ -25,24 +34,26 @@ class Feature(Base, TagMixin):
logged to.
"""

def __init__(self, domain, parent):
def __init__(self, domain: FeatureDomain, parent: Experiment):
super().__init__(domain, parent._config)

self._domain: FeatureDomain

self._data = None
self._parent = parent

@property
def id(self):
def id(self) -> str:
"""Get the feature's id."""
return self._domain.id

@property
def name(self):
def name(self) -> Optional[str]:
"""Get the feature's name."""
return self._domain.name

@property
def description(self):
def description(self) -> Optional[str]:
"""Get the feature's description."""
return self._domain.description

Expand All @@ -52,11 +63,11 @@ def importance(self):
return self._domain.importance

@property
def created_at(self):
def created_at(self) -> datetime:
"""Get the feature's created_at."""
return self._domain.created_at

@property
def parent(self):
def parent(self) -> Experiment:
"""Get the feature's parent client object."""
return self._parent
Loading
Loading