Skip to content

Commit

Permalink
Keep the same column order when fitting the hypertransformer (#714)
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Oct 31, 2023
1 parent 0687a54 commit 003c7ba
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 3 deletions.
18 changes: 16 additions & 2 deletions rdt/hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ def update_transformers_by_sdtype(
if field in self._multi_column_fields:
self._remove_column_in_multi_column_fields(field)

self._multi_column_fields = self._create_multi_column_fields()
self._modified_config = True

def update_sdtypes(self, column_name_to_sdtype):
Expand Down Expand Up @@ -443,6 +444,7 @@ def update_sdtypes(self, column_name_to_sdtype):
"Use 'get_config()' to verify the transformers."
)

self._multi_column_fields = self._create_multi_column_fields()
self._modified_config = True
if self._fitted:
warnings.warn(self._REFIT_MESSAGE)
Expand Down Expand Up @@ -482,6 +484,7 @@ def update_transformers(self, column_name_to_transformer):

self.field_transformers[column_name] = transformer

self._multi_column_fields = self._create_multi_column_fields()
self._modified_config = True

def remove_transformers(self, column_names):
Expand Down Expand Up @@ -697,8 +700,19 @@ def fit(self, data):
self._validate_detect_config_called(data)
self._unfit()
self._input_columns = list(data.columns)
for field_column, field_transformer in self.field_transformers.items():
data = self._fit_field_transformer(data, field_column, field_transformer)
skipped_columns = [] # skip columns in multi column transformer already fitted
for column in self._input_columns:
if column in skipped_columns:
continue

if column in self._multi_column_fields:
field = self._multi_column_fields[column]
field_to_skip = [col for col in field if col != column]
skipped_columns.extend(field_to_skip)
else:
field = column

data = self._fit_field_transformer(data, field, self.field_transformers[field])

self._validate_all_fields_fitted()
self._fitted = True
Expand Down
28 changes: 27 additions & 1 deletion tests/integration/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,7 +1475,7 @@ def test_hypertransformer_with_mutli_column_transformer_end_to_end(self):
expected_transformed_data = pd.DataFrame({
'A': [1.0, 2.0, 3.0],
'B': [4.0, 5.0, 6.0],
'C': [0.5892351646057272, 0.8615278122985615, 0.36493646501970534]
'C': [0.10333535312718026, 0.6697388922326716, 0.18775548909503287]
})

pd.testing.assert_frame_equal(transformed_data, expected_transformed_data)
Expand Down Expand Up @@ -1519,7 +1519,13 @@ def test_update_transformers_single_to_multi_column(self):
}
})

expected_multi_columns = {
'A': ('A', 'B'),
'B': ('A', 'B')
}

assert repr(new_config) == repr(expected_config)
assert ht._multi_column_fields == expected_multi_columns

def test_update_transformers_multi_to_single_column(self):
"""Test ``update_transformers`` to go from multi to single column transformer."""
Expand Down Expand Up @@ -1568,7 +1574,12 @@ def test_update_transformers_multi_to_single_column(self):
}
})

expected_multi_columns = {
'A': ('A', 'B'),
'B': ('A', 'B'),
}
assert repr(new_config) == repr(expected_config)
assert ht._multi_column_fields == expected_multi_columns

def test_update_transformers_by_sdtype_mutli_column(self):
"""Test ``update_transformers_by_sdtype`` with mutli column transformers."""
Expand Down Expand Up @@ -1612,8 +1623,13 @@ def test_update_transformers_by_sdtype_mutli_column(self):
"('B', 'D')": DummyMultiColumnTransformerNumerical()
}
})
expected_multi_columns = {
'B': ('B', 'D'),
'D': ('B', 'D')
}

assert repr(new_config) == repr(expected_config)
assert ht._multi_column_fields == expected_multi_columns

def test_remove_transformer(self):
"""Test ``remove_transformer`` with multi column transformer."""
Expand Down Expand Up @@ -1656,8 +1672,13 @@ def test_remove_transformer(self):
'B': None
}
})
exepected_multi_columns = {
'C': ('C', 'D'),
'D': ('C', 'D')
}

assert repr(new_config) == repr(expected_config)
assert ht._multi_column_fields == exepected_multi_columns

def test_remove_transformer_by_sdtype(self):
"""Test ``remove_transformer_by_sdtype`` with multi column transformer."""
Expand Down Expand Up @@ -1749,5 +1770,10 @@ def test_update_sdtype(self):
'C': FloatFormatter()
}
})
expected_multi_columns = {
'B': ('B', 'D'),
'D': ('B', 'D')
}

assert repr(new_config) == repr(expected_config)
assert ht._multi_column_fields == expected_multi_columns
12 changes: 12 additions & 0 deletions tests/unit/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,10 @@ def _reverse_transform(self, data):
'col2': 'categorical',
}
ht = HyperTransformer()
ht._multi_column_fields = {
'col1': ('col1', 'col2'),
'col2': ('col1', 'col2'),
}
ht.field_transformers = field_transformers
ht.field_sdtypes = field_sdtypes
ht._get_columns_to_sdtypes = Mock(return_value=columns_to_sdtype)
Expand Down Expand Up @@ -2326,6 +2330,7 @@ def test_update_transformers_by_sdtype_with_multi_column_transformer(self):
}
mock__remove_column_in_multi_column_fields = Mock()
ht._remove_column_in_multi_column_fields = mock__remove_column_in_multi_column_fields
ht._create_multi_column_fields = Mock()

# Run
ht.update_transformers_by_sdtype(
Expand All @@ -2336,6 +2341,7 @@ def test_update_transformers_by_sdtype_with_multi_column_transformer(self):
# Assert
assert len(ht.field_transformers) == 4
assert mock__remove_column_in_multi_column_fields.call_count == 1
ht._create_multi_column_fields.assert_called_once()

@patch('rdt.hyper_transformer.warnings')
def test_update_transformers_fitted(self, mock_warnings):
Expand Down Expand Up @@ -2402,6 +2408,8 @@ def test_update_transformers_multi_column(self):
('A', 'B'): None,
'C': None,
}
ht._create_multi_column_fields = Mock()

# Run
ht.update_transformers(column_name_to_transformer)

Expand All @@ -2410,7 +2418,9 @@ def test_update_transformers_multi_column(self):
('A', 'B'): None,
'C': None,
}

assert ht.field_transformers == expected_field_transformers
ht._create_multi_column_fields.assert_called_once()

def test_update_transformers_changing_multi_column_transformer(self):
"""Test ``update_transformers`` when changing a multi column transformer."""
Expand Down Expand Up @@ -2976,6 +2986,7 @@ class DummyMultiColumnTransformer(BaseMultiColumnTransformer):
'column2': ('column2', 'column3'),
'column3': ('column2', 'column3')
}
ht._create_multi_column_fields = Mock()

# Run
ht.update_sdtypes(column_name_to_sdtype={
Expand All @@ -2998,6 +3009,7 @@ class DummyMultiColumnTransformer(BaseMultiColumnTransformer):
}
assert ht.field_sdtypes == expected_field_sdtypes
assert str(ht.field_transformers) == str(expected_field_transformers)
ht._create_multi_column_fields.assert_called_once()

def test_update_sdtypes_multi_column_with_unsupported_sdtypes(self):
"""Test the ``update_sdtypes`` method.
Expand Down

0 comments on commit 003c7ba

Please sign in to comment.