From 890a418cb349878076ac8ef58ccee2c8e796f544 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20Portela=20Afonso?= Date: Thu, 7 Sep 2023 17:38:49 +0100 Subject: [PATCH] fix(synthesizer): conform with recent changes to the metadata --- .editorconfig | 2 +- src/ydata/sdk/synthesizers/synthesizer.py | 62 ++++++++++------------- 2 files changed, 27 insertions(+), 37 deletions(-) diff --git a/.editorconfig b/.editorconfig index 886cdb20..78e11480 100644 --- a/.editorconfig +++ b/.editorconfig @@ -17,5 +17,5 @@ indent_size = 2 indent_style = tab [*.py] -indent_size = 2 +indent_size = 4 indent_style = space diff --git a/src/ydata/sdk/synthesizers/synthesizer.py b/src/ydata/sdk/synthesizers/synthesizer.py index 2a620873..7435bbb1 100644 --- a/src/ydata/sdk/synthesizers/synthesizer.py +++ b/src/ydata/sdk/synthesizers/synthesizer.py @@ -179,7 +179,10 @@ def _validate_datasource_attributes(X: Union[DataSource, pdDataFrame], dataset_a "The dataset attributes are invalid:\n {}".format('\n'.join(error_msgs))) @staticmethod - def _metadata_to_payload(datatype: DataSourceType, ds_metadata: Metadata, dataset_attrs: Optional[DataSourceAttrs] = None) -> list: + def _metadata_to_payload( + datatype: DataSourceType, ds_metadata: Metadata, + dataset_attrs: Optional[DataSourceAttrs] = None, target: str | None = None + ) -> dict: """Transform a the metadata and dataset attributes into a valid payload. @@ -187,39 +190,32 @@ def _metadata_to_payload(datatype: DataSourceType, ds_metadata: Metadata, datase datatype (DataSourceType): datasource type ds_metadata (Metadata): datasource metadata object dataset_attrs ( Optional[DataSourceAttrs] ): (optional) Dataset attributes + target (Optional[str]): (optional) target column name Returns: - payload dictionary + metadata payload dictionary """ - columns = {} - for c in ds_metadata.columns: - columns[c.name] = { + + columns = [ + { 'name': c.name, - 'generation': True, - 'dataType': c.datatype if c.datatype != DataType.STR.value else DataType.CATEGORICAL.value, + 'generation': c.name in dataset_attrs.sortbykey or (c.name in dataset_attrs.generate_cols and c.name not in dataset_attrs.exclude_cols), + 'dataType': DataType(dataset_attrs.dtypes[c.name]).value if c.name in dataset_attrs.dtypes else c.datatype, 'varType': c.vartype, - 'entity': False, } - if dataset_attrs is not None: - if datatype == DataSourceType.TIMESERIES: - for c in ds_metadata.columns: - columns[c.name]['sortBy'] = c.name in dataset_attrs.sortbykey - - for c in dataset_attrs.entities: - columns[c]['entity'] = True - - for c in dataset_attrs.generate_cols: - columns[c]['generation'] = True + for c in ds_metadata.columns ] - for c in dataset_attrs.exclude_cols: - columns[c]['generation'] = False + metadata = { + 'columns': columns, + 'target': target + } - # Update metadata based on the datatypes and vartypes provided by the user - for k, v in dataset_attrs.dtypes.items(): - if k in columns and columns[k]['generation']: - columns[k]['dataType'] = v.value + if dataset_attrs is not None: + if datatype == DataSourceType.TIMESERIES: + metadata['sortBy'] = [c for c in dataset_attrs.sortbykey] + metadata['entity'] = [c for c in dataset_attrs.entities] - return list(columns.values()) + return metadata def _fit_from_datasource( self, @@ -232,25 +228,19 @@ def _fit_from_datasource( condition_on: Optional[List[str]] = None ) -> None: _name = name if name is not None else str(uuid4()) - columns = self._metadata_to_payload( - DataSourceType(X.datatype), X.metadata, dataset_attrs) + metadata = self._metadata_to_payload( + DataSourceType(X.datatype), X.metadata, dataset_attrs, target) payload = { 'name': _name, 'dataSourceUID': X.uid, - 'metadata': { - 'dataType': X.datatype, - "columns": columns, - }, - 'extraData': { - 'privacy_level': privacy_level.value - } + 'metadata': metadata, + 'extraData': {}, + 'privacyLevel': privacy_level.value } if anonymize is not None: payload["extraData"]["anonymize"] = anonymize if condition_on is not None: payload["extraData"]["condition_on"] = condition_on - if target is not None: - payload['metadata']['target'] = target response = self._client.post('/synthesizer/', json=payload) data: list = response.json()