Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Oct 27, 2023
1 parent 99d31e2 commit 56f7281
Showing 1 changed file with 40 additions and 2 deletions.
42 changes: 40 additions & 2 deletions tests/integration/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1481,6 +1481,44 @@ def test_hypertransformer_with_mutli_column_transformer_end_to_end(self):
pd.testing.assert_frame_equal(transformed_data, expected_transformed_data)
pd.testing.assert_frame_equal(reverse_transformed_data, data_test)

def test_hypertransformer_with_mutli_column_transformer_and_single_column(self):
"""Test a mutli column transformer used with for a single column."""
# Setup
data_test = pd.DataFrame({
'A': ['1.0', '2.0', '3.0'],
'B2': ['4.0', '5.0', '6.0'],
'C': [True, False, True]
})
dict_config = {
'sdtypes': {
'A': 'categorical',
'B2': 'categorical',
'C': 'boolean'
},
'transformers': {
'A': DummyMultiColumnTransformerNumerical(),
('B2'): DummyMultiColumnTransformerNumerical(),
'C': UniformEncoder()
}
}
config = Config(dict_config)
ht = HyperTransformer()
ht.set_config(config)

# Run
transformed_data = ht.fit_transform(data_test)
reverse_transformed_data = ht.reverse_transform(transformed_data)

# Assert
expected_transformed_data = pd.DataFrame({
'A': [1.0, 2.0, 3.0],
'B2': [4.0, 5.0, 6.0],
'C': [0.04206197607326308, 0.8000968077312287, 0.06325519846695522]
})

pd.testing.assert_frame_equal(transformed_data, expected_transformed_data)
pd.testing.assert_frame_equal(reverse_transformed_data, data_test)

def test_update_transformers_single_to_multi_column(self):
"""Test ``update_transformers`` to go from single to mutli column transformer."""
# Setup
Expand Down Expand Up @@ -1515,13 +1553,13 @@ def test_update_transformers_single_to_multi_column(self):
},
'transformers': {
'C': UniformEncoder(),
"('A', 'B')": DummyMultiColumnTransformerNumerical()
('A', 'B'): DummyMultiColumnTransformerNumerical(),
}
})

expected_multi_columns = {
'A': ('A', 'B'),
'B': ('A', 'B')
'B': ('A', 'B'),
}

assert repr(new_config) == repr(expected_config)
Expand Down

0 comments on commit 56f7281

Please sign in to comment.