From 56f728169dfd55c16ee35293fa42e2093cf9a787 Mon Sep 17 00:00:00 2001 From: R-Palazzo Date: Fri, 27 Oct 2023 07:37:05 -0600 Subject: [PATCH] tests --- tests/integration/test_hyper_transformer.py | 42 ++++++++++++++++++++- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/tests/integration/test_hyper_transformer.py b/tests/integration/test_hyper_transformer.py index ed820a001..fdd85a236 100644 --- a/tests/integration/test_hyper_transformer.py +++ b/tests/integration/test_hyper_transformer.py @@ -1481,6 +1481,44 @@ def test_hypertransformer_with_mutli_column_transformer_end_to_end(self): pd.testing.assert_frame_equal(transformed_data, expected_transformed_data) pd.testing.assert_frame_equal(reverse_transformed_data, data_test) + def test_hypertransformer_with_mutli_column_transformer_and_single_column(self): + """Test a mutli column transformer used with for a single column.""" + # Setup + data_test = pd.DataFrame({ + 'A': ['1.0', '2.0', '3.0'], + 'B2': ['4.0', '5.0', '6.0'], + 'C': [True, False, True] + }) + dict_config = { + 'sdtypes': { + 'A': 'categorical', + 'B2': 'categorical', + 'C': 'boolean' + }, + 'transformers': { + 'A': DummyMultiColumnTransformerNumerical(), + ('B2'): DummyMultiColumnTransformerNumerical(), + 'C': UniformEncoder() + } + } + config = Config(dict_config) + ht = HyperTransformer() + ht.set_config(config) + + # Run + transformed_data = ht.fit_transform(data_test) + reverse_transformed_data = ht.reverse_transform(transformed_data) + + # Assert + expected_transformed_data = pd.DataFrame({ + 'A': [1.0, 2.0, 3.0], + 'B2': [4.0, 5.0, 6.0], + 'C': [0.04206197607326308, 0.8000968077312287, 0.06325519846695522] + }) + + pd.testing.assert_frame_equal(transformed_data, expected_transformed_data) + pd.testing.assert_frame_equal(reverse_transformed_data, data_test) + def test_update_transformers_single_to_multi_column(self): """Test ``update_transformers`` to go from single to mutli column transformer.""" # Setup @@ -1515,13 +1553,13 @@ def test_update_transformers_single_to_multi_column(self): }, 'transformers': { 'C': UniformEncoder(), - "('A', 'B')": DummyMultiColumnTransformerNumerical() + ('A', 'B'): DummyMultiColumnTransformerNumerical(), } }) expected_multi_columns = { 'A': ('A', 'B'), - 'B': ('A', 'B') + 'B': ('A', 'B'), } assert repr(new_config) == repr(expected_config)