Skip to content

Commit

Permalink
def
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Oct 27, 2023
1 parent 5ebecbb commit 99d31e2
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions rdt/hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
ConfigNotSetError, InvalidConfigError, InvalidDataError, NotFittedError, TransformerInputError,
TransformerProcessingError)
from rdt.transformers import (
BaseTransformer, get_class_by_transformer_name, get_default_transformer,
get_transformers_by_type)
BaseMultiColumnTransformer, BaseTransformer, get_class_by_transformer_name,
get_default_transformer, get_transformers_by_type)
from rdt.transformers.utils import flatten_column_list

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -599,10 +599,13 @@ def _get_columns_to_sdtypes(self, field):
"""Generate the ``columns_to_sdtypes`` dict for the given field.
Args:
field (tuple):
field (str, tuple[str]):
Names of the column for the multi column trnasformer.
"""
columns_to_sdtypes = {}
if isinstance(field, str):
field = (field,)

for column in field:
columns_to_sdtypes[column] = self.field_sdtypes[column]

Expand Down Expand Up @@ -630,7 +633,7 @@ def _fit_field_transformer(self, data, field, transformer):
self._output_columns.append(field)

else:
if isinstance(field, tuple):
if isinstance(transformer, BaseMultiColumnTransformer):
columns_to_sdtypes = self._get_columns_to_sdtypes(field)
transformer.fit(data, columns_to_sdtypes)
else:
Expand Down

0 comments on commit 99d31e2

Please sign in to comment.