Skip to content

Commit

Permalink
More hints
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenpardy committed Sep 13, 2023
1 parent 10d3a12 commit 734f76d
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 10 deletions.
12 changes: 9 additions & 3 deletions rubicon_ml/client/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import subprocess
import warnings
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Optional, Union
from typing import TYPE_CHECKING, Any, List, Optional, TextIO, Union

import fsspec

Expand Down Expand Up @@ -61,7 +61,7 @@ def _validate_data(self, data_bytes, data_file, data_object, data_path, name):
def log_artifact(
self,
data_bytes: Optional[bytes] = None,
data_file=None,
data_file: Optional[TextIO] = None,
data_object: Optional[Any] = None,
data_path: Optional[str] = None,
name: Optional[str] = None,
Expand Down Expand Up @@ -320,7 +320,11 @@ class DataframeMixin:

@failsafe
def log_dataframe(
self, df: Union[pd.DataFrame, dd.DataFrame], description=None, name=None, tags=[]
self,
df: Union[pd.DataFrame, dd.DataFrame],
description: Optional[str] = None,
name: Optional[str] = None,
tags: Optional[List[str]] = None,
) -> Dataframe:
"""Log a dataframe to this client object.
Expand All @@ -339,6 +343,8 @@ def log_dataframe(
rubicon.client.Dataframe
The new dataframe.
"""
if tags is None:
tags = []
if not isinstance(tags, list) or not all([isinstance(tag, str) for tag in tags]):
raise ValueError("`tags` must be `list` of type `str`")

Expand Down
4 changes: 3 additions & 1 deletion rubicon_ml/client/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def __init__(self, domain: ParameterDomain, parent: Experiment):
super().__init__(domain, parent._config)
self._parent = parent

self._domain: ParameterDomain

@property
def id(self) -> str:
"""Get the parameter's id."""
Expand All @@ -48,7 +50,7 @@ def name(self) -> Optional[str]:
@property
def value(self) -> Optional[Union[object, float]]:
"""Get the parameter's value."""
return getattr(self._domain, "value", None)
return self._domain.value

@property
def description(self) -> Optional[str]:
Expand Down
21 changes: 15 additions & 6 deletions rubicon_ml/client/rubicon.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import subprocess
import warnings
from typing import Optional
from typing import List, Optional, Tuple, Union

from rubicon_ml import domain
from rubicon_ml.client import Config, Project
from rubicon_ml.client.utils.exception_handling import failsafe
from rubicon_ml.domain.utils import TrainingMetadata
from rubicon_ml.exceptions import RubiconException
from rubicon_ml.repository.utils import slugify

Expand Down Expand Up @@ -35,8 +36,8 @@ class Rubicon:
def __init__(
self,
persistence: Optional[str] = "filesystem",
root_dir=None,
auto_git_enabled=False,
root_dir: Optional[str] = None,
auto_git_enabled: bool = False,
**storage_options,
):
self.config = Config(persistence, root_dir, auto_git_enabled, **storage_options)
Expand Down Expand Up @@ -69,19 +70,27 @@ def _get_github_url(self):

return github_url

def _create_project_domain(self, name, description, github_url, training_metadata):
def _create_project_domain(
self,
name: str,
description: str,
github_url: str,
training_metadata: Union[List[Tuple], Tuple],
):
"""Instantiates and returns a project domain object."""
if self.config.is_auto_git_enabled and github_url is None:
github_url = self._get_github_url()

if training_metadata is not None:
training_metadata = domain.utils.TrainingMetadata(training_metadata)
training_metadata_class = TrainingMetadata(training_metadata)
else:
training_metadata_class = None

return domain.Project(
name,
description=description,
github_url=github_url,
training_metadata=training_metadata,
training_metadata=training_metadata_class,
)

@failsafe
Expand Down

0 comments on commit 734f76d

Please sign in to comment.