diff --git a/rdt/hyper_transformer.py b/rdt/hyper_transformer.py index ba01f7cc..8ac743e5 100644 --- a/rdt/hyper_transformer.py +++ b/rdt/hyper_transformer.py @@ -12,8 +12,8 @@ ConfigNotSetError, InvalidConfigError, InvalidDataError, NotFittedError, TransformerInputError, TransformerProcessingError) from rdt.transformers import ( - BaseTransformer, get_class_by_transformer_name, get_default_transformer, - get_transformers_by_type) + BaseMultiColumnTransformer, BaseTransformer, get_class_by_transformer_name, + get_default_transformer, get_transformers_by_type) from rdt.transformers.utils import flatten_column_list LOGGER = logging.getLogger(__name__) @@ -599,10 +599,13 @@ def _get_columns_to_sdtypes(self, field): """Generate the ``columns_to_sdtypes`` dict for the given field. Args: - field (tuple): + field (str, tuple[str]): Names of the column for the multi column trnasformer. """ columns_to_sdtypes = {} + if isinstance(field, str): + field = (field,) + for column in field: columns_to_sdtypes[column] = self.field_sdtypes[column] @@ -630,7 +633,7 @@ def _fit_field_transformer(self, data, field, transformer): self._output_columns.append(field) else: - if isinstance(field, tuple): + if isinstance(transformer, BaseMultiColumnTransformer): columns_to_sdtypes = self._get_columns_to_sdtypes(field) transformer.fit(data, columns_to_sdtypes) else: diff --git a/tests/integration/test_hyper_transformer.py b/tests/integration/test_hyper_transformer.py index ed820a00..828457c6 100644 --- a/tests/integration/test_hyper_transformer.py +++ b/tests/integration/test_hyper_transformer.py @@ -1481,6 +1481,44 @@ def test_hypertransformer_with_mutli_column_transformer_end_to_end(self): pd.testing.assert_frame_equal(transformed_data, expected_transformed_data) pd.testing.assert_frame_equal(reverse_transformed_data, data_test) + def test_hypertransformer_with_mutli_column_transformer_and_single_column(self): + """Test a mutli column transformer used with for a single column.""" + # Setup + data_test = pd.DataFrame({ + 'A': ['1.0', '2.0', '3.0'], + 'B2': ['4.0', '5.0', '6.0'], + 'C': [True, False, True] + }) + dict_config = { + 'sdtypes': { + 'A': 'categorical', + 'B2': 'categorical', + 'C': 'boolean' + }, + 'transformers': { + 'A': DummyMultiColumnTransformerNumerical(), + ('B2', ): DummyMultiColumnTransformerNumerical(), + 'C': UniformEncoder() + } + } + config = Config(dict_config) + ht = HyperTransformer() + ht.set_config(config) + + # Run + transformed_data = ht.fit_transform(data_test) + reverse_transformed_data = ht.reverse_transform(transformed_data) + + # Assert + expected_transformed_data = pd.DataFrame({ + 'A': [1.0, 2.0, 3.0], + 'B2': [4.0, 5.0, 6.0], + 'C': [0.04206197607326308, 0.8000968077312287, 0.06325519846695522] + }) + + pd.testing.assert_frame_equal(transformed_data, expected_transformed_data) + pd.testing.assert_frame_equal(reverse_transformed_data, data_test) + def test_update_transformers_single_to_multi_column(self): """Test ``update_transformers`` to go from single to mutli column transformer.""" # Setup @@ -1515,13 +1553,13 @@ def test_update_transformers_single_to_multi_column(self): }, 'transformers': { 'C': UniformEncoder(), - "('A', 'B')": DummyMultiColumnTransformerNumerical() + ('A', 'B'): DummyMultiColumnTransformerNumerical(), } }) expected_multi_columns = { 'A': ('A', 'B'), - 'B': ('A', 'B') + 'B': ('A', 'B'), } assert repr(new_config) == repr(expected_config) diff --git a/tests/unit/test_hyper_transformer.py b/tests/unit/test_hyper_transformer.py index 9d3f4f93..7793593c 100644 --- a/tests/unit/test_hyper_transformer.py +++ b/tests/unit/test_hyper_transformer.py @@ -340,7 +340,8 @@ def test__get_columns_to_sdtypes(self): column_tuple = ('col1', 'col2', 'col3') # Run - columns_to_sdtypes = ht._get_columns_to_sdtypes(column_tuple) + columns_to_sdtypes_tuple = ht._get_columns_to_sdtypes(column_tuple) + columns_to_sdtypes_str = ht._get_columns_to_sdtypes('col4') # Assert expected_columns_to_sdtypes = { @@ -348,7 +349,8 @@ def test__get_columns_to_sdtypes(self): 'col2': 'categorical', 'col3': 'boolean', } - assert columns_to_sdtypes == expected_columns_to_sdtypes + assert columns_to_sdtypes_tuple == expected_columns_to_sdtypes + assert columns_to_sdtypes_str == {'col4': 'datetime'} def test__fit_field_transformer(self): """Test the ``_fit_field_transformer`` method.