Skip to content

Commit

Permalink
Multi column transformers crash when assigned to single column (#733)
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Oct 31, 2023
1 parent 003c7ba commit edf0a56
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 8 deletions.
11 changes: 7 additions & 4 deletions rdt/hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
ConfigNotSetError, InvalidConfigError, InvalidDataError, NotFittedError, TransformerInputError,
TransformerProcessingError)
from rdt.transformers import (
BaseTransformer, get_class_by_transformer_name, get_default_transformer,
get_transformers_by_type)
BaseMultiColumnTransformer, BaseTransformer, get_class_by_transformer_name,
get_default_transformer, get_transformers_by_type)
from rdt.transformers.utils import flatten_column_list

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -599,10 +599,13 @@ def _get_columns_to_sdtypes(self, field):
"""Generate the ``columns_to_sdtypes`` dict for the given field.
Args:
field (tuple):
field (str, tuple[str]):
Names of the column for the multi column trnasformer.
"""
columns_to_sdtypes = {}
if isinstance(field, str):
field = (field,)

for column in field:
columns_to_sdtypes[column] = self.field_sdtypes[column]

Expand Down Expand Up @@ -630,7 +633,7 @@ def _fit_field_transformer(self, data, field, transformer):
self._output_columns.append(field)

else:
if isinstance(field, tuple):
if isinstance(transformer, BaseMultiColumnTransformer):
columns_to_sdtypes = self._get_columns_to_sdtypes(field)
transformer.fit(data, columns_to_sdtypes)
else:
Expand Down
42 changes: 40 additions & 2 deletions tests/integration/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1481,6 +1481,44 @@ def test_hypertransformer_with_mutli_column_transformer_end_to_end(self):
pd.testing.assert_frame_equal(transformed_data, expected_transformed_data)
pd.testing.assert_frame_equal(reverse_transformed_data, data_test)

def test_hypertransformer_with_mutli_column_transformer_and_single_column(self):
"""Test a mutli column transformer used with for a single column."""
# Setup
data_test = pd.DataFrame({
'A': ['1.0', '2.0', '3.0'],
'B2': ['4.0', '5.0', '6.0'],
'C': [True, False, True]
})
dict_config = {
'sdtypes': {
'A': 'categorical',
'B2': 'categorical',
'C': 'boolean'
},
'transformers': {
'A': DummyMultiColumnTransformerNumerical(),
('B2', ): DummyMultiColumnTransformerNumerical(),
'C': UniformEncoder()
}
}
config = Config(dict_config)
ht = HyperTransformer()
ht.set_config(config)

# Run
transformed_data = ht.fit_transform(data_test)
reverse_transformed_data = ht.reverse_transform(transformed_data)

# Assert
expected_transformed_data = pd.DataFrame({
'A': [1.0, 2.0, 3.0],
'B2': [4.0, 5.0, 6.0],
'C': [0.04206197607326308, 0.8000968077312287, 0.06325519846695522]
})

pd.testing.assert_frame_equal(transformed_data, expected_transformed_data)
pd.testing.assert_frame_equal(reverse_transformed_data, data_test)

def test_update_transformers_single_to_multi_column(self):
"""Test ``update_transformers`` to go from single to mutli column transformer."""
# Setup
Expand Down Expand Up @@ -1515,13 +1553,13 @@ def test_update_transformers_single_to_multi_column(self):
},
'transformers': {
'C': UniformEncoder(),
"('A', 'B')": DummyMultiColumnTransformerNumerical()
('A', 'B'): DummyMultiColumnTransformerNumerical(),
}
})

expected_multi_columns = {
'A': ('A', 'B'),
'B': ('A', 'B')
'B': ('A', 'B'),
}

assert repr(new_config) == repr(expected_config)
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,15 +340,17 @@ def test__get_columns_to_sdtypes(self):
column_tuple = ('col1', 'col2', 'col3')

# Run
columns_to_sdtypes = ht._get_columns_to_sdtypes(column_tuple)
columns_to_sdtypes_tuple = ht._get_columns_to_sdtypes(column_tuple)
columns_to_sdtypes_str = ht._get_columns_to_sdtypes('col4')

# Assert
expected_columns_to_sdtypes = {
'col1': 'numerical',
'col2': 'categorical',
'col3': 'boolean',
}
assert columns_to_sdtypes == expected_columns_to_sdtypes
assert columns_to_sdtypes_tuple == expected_columns_to_sdtypes
assert columns_to_sdtypes_str == {'col4': 'datetime'}

def test__fit_field_transformer(self):
"""Test the ``_fit_field_transformer`` method.
Expand Down

0 comments on commit edf0a56

Please sign in to comment.