Skip to content

Commit

Permalink
feat(synthesizer): Move name from fit to init method
Browse files Browse the repository at this point in the history
  • Loading branch information
portellaa committed Dec 13, 2023
1 parent df50229 commit c28f3ed
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 10 deletions.
20 changes: 20 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
name: ydata-sdk
channels:
- defaults
dependencies:
- bzip2=1.0.8=h1de35cc_0
- ca-certificates=2023.08.22=hecd8cb5_0
- libffi=3.4.4=hecd8cb5_0
- ncurses=6.4=hcec6c5f_0
- openssl=3.0.12=hca72f7f_0
- pip=23.3.1=py310hecd8cb5_0
- python=3.10.13=h5ee71fb_0
- readline=8.2=hca72f7f_0
- setuptools=68.0.0=py310hecd8cb5_0
- sqlite=3.41.2=h6c40b1e_0
- tk=8.6.12=h5d9f67b_0
- tzdata=2023c=h04d1e81_0
- wheel=0.41.2=py310hecd8cb5_0
- xz=5.4.5=h6c40b1e_0
- zlib=1.2.13=h4dc903c_0
prefix: /usr/local/Caskroom/miniconda/base/envs/ydata-sdk
3 changes: 1 addition & 2 deletions src/ydata/sdk/synthesizers/regular.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def fit(self, X: Union[DataSource, pdDataFrame],
exclude_cols: Optional[List[str]] = None,
dtypes: Optional[Dict[str, Union[str, DataType]]] = None,
target: Optional[str] = None,
name: Optional[str] = None,
anonymize: Optional[dict] = None,
condition_on: Optional[List[str]] = None) -> None:
"""Fit the synthesizer.
Expand All @@ -61,7 +60,7 @@ def fit(self, X: Union[DataSource, pdDataFrame],
"""
BaseSynthesizer.fit(self, X=X, datatype=DataSourceType.TABULAR, entities=entities,
generate_cols=generate_cols, exclude_cols=exclude_cols, dtypes=dtypes,
target=target, name=name, anonymize=anonymize, privacy_level=privacy_level,
target=target, anonymize=anonymize, privacy_level=privacy_level,
condition_on=condition_on)

def __repr__(self):
Expand Down
9 changes: 3 additions & 6 deletions src/ydata/sdk/synthesizers/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ class BaseSynthesizer(ABC, ModelFactoryMixin):
client (Client): (optional) Client to connect to the backend
"""

def __init__(self, client: Optional[Client] = None):
def __init__(self, name: str | None = None, client: Client | None = None):
self._init_common(client=client)
self._model: Optional[mSynthesizer] = None
self._model = mSynthesizer(name=name or str(uuid4()))

@init_client
def _init_common(self, client: Optional[Client] = None):
Expand All @@ -69,7 +69,6 @@ def fit(self, X: Union[DataSource, pdDataFrame],
exclude_cols: Optional[List[str]] = None,
dtypes: Optional[Dict[str, Union[str, DataType]]] = None,
target: Optional[str] = None,
name: Optional[str] = None,
anonymize: Optional[dict] = None,
condition_on: Optional[List[str]] = None) -> None:
"""Fit the synthesizer.
Expand Down Expand Up @@ -223,15 +222,13 @@ def _fit_from_datasource(
privacy_level: PrivacyLevel = PrivacyLevel.HIGH_FIDELITY,
dataset_attrs: Optional[DataSourceAttrs] = None,
target: Optional[str] = None,
name: Optional[str] = None,
anonymize: Optional[dict] = None,
condition_on: Optional[List[str]] = None
) -> None:
_name = name if name is not None else str(uuid4())
metadata = self._metadata_to_payload(
DataSourceType(X.datatype), X.metadata, dataset_attrs, target)
payload = {
'name': _name,
'name': self._model.name,
'dataSourceUID': X.uid,
'metadata': metadata,
'extraData': {},
Expand Down
3 changes: 1 addition & 2 deletions src/ydata/sdk/synthesizers/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def fit(self, X: Union[DataSource, pdDataFrame],
exclude_cols: Optional[List[str]] = None,
dtypes: Optional[Dict[str, Union[str, DataType]]] = None,
target: Optional[str] = None,
name: Optional[str] = None,
anonymize: Optional[dict] = None,
condition_on: Optional[List[str]] = None) -> None:
"""Fit the synthesizer.
Expand All @@ -65,7 +64,7 @@ def fit(self, X: Union[DataSource, pdDataFrame],
"""
BaseSynthesizer.fit(self, X=X, datatype=DataSourceType.TIMESERIES, sortbykey=sortbykey,
entities=entities, generate_cols=generate_cols, exclude_cols=exclude_cols,
dtypes=dtypes, target=target, name=name, anonymize=anonymize, privacy_level=privacy_level,
dtypes=dtypes, target=target, anonymize=anonymize, privacy_level=privacy_level,
condition_on=condition_on)

def __repr__(self):
Expand Down

0 comments on commit c28f3ed

Please sign in to comment.