Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a _update_multi_column_transformer method #758

Merged
merged 6 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions rdt/hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,32 @@ def _remove_column_in_multi_column_fields(self, column):

self.field_transformers[new_tuple] = self.field_transformers.pop(old_tuple)

def _update_multi_column_transformer(self):
    """Drop multi-column transformers that reject the sdtypes of their columns.

    Every entry of ``self.field_transformers`` holding a
    ``BaseMultiColumnTransformer`` is validated against the sdtypes of the
    columns it is mapped to. If the transformer raises
    ``TransformerInputError``, a warning is emitted, the multi-column entry is
    removed and each of its columns gets a deep copy of the default
    transformer for its sdtype. ``self._multi_column_fields`` is rebuilt at
    the end.
    """
    # Snapshot the multi-column keys first: ``field_transformers`` is mutated below.
    multi_column_fields = {
        field
        for field, candidate in self.field_transformers.items()
        if isinstance(candidate, BaseMultiColumnTransformer)
    }

    for field in multi_column_fields:
        multi_transformer = self.field_transformers[field]
        sdtypes_by_column = self._get_columns_to_sdtypes(field)
        try:
            # pylint: disable-next=protected-access
            multi_transformer._validate_sdtypes(sdtypes_by_column)
        except TransformerInputError:
            incompatible_message = (
                f"Transformer '{multi_transformer.get_name()}' is incompatible with the "
                f"multi-column field '{field}'. Assigning default transformer to the columns."
            )
            warnings.warn(incompatible_message)

            # Replace the tuple entry by one single-column entry per column.
            self.field_transformers.pop(field)
            for column_name, column_sdtype in sdtypes_by_column.items():
                default_transformer = get_default_transformer(column_sdtype)
                self.field_transformers[column_name] = deepcopy(default_transformer)

    self._multi_column_fields = self._create_multi_column_fields()

def update_transformers_by_sdtype(
self, sdtype, transformer=None, transformer_name=None, transformer_parameters=None):
"""Update the transformers for the specified ``sdtype``.
Expand Down Expand Up @@ -397,6 +423,7 @@ def update_transformers_by_sdtype(
self._remove_column_in_multi_column_fields(field)

self._multi_column_fields = self._create_multi_column_fields()
self._update_multi_column_transformer()
self._modified_config = True

def update_sdtypes(self, column_name_to_sdtype):
Expand Down Expand Up @@ -445,6 +472,7 @@ def update_sdtypes(self, column_name_to_sdtype):
)

self._multi_column_fields = self._create_multi_column_fields()
self._update_multi_column_transformer()
self._modified_config = True
if self._fitted:
warnings.warn(self._REFIT_MESSAGE)
Expand Down Expand Up @@ -485,6 +513,7 @@ def update_transformers(self, column_name_to_transformer):
self.field_transformers[column_name] = transformer

self._multi_column_fields = self._create_multi_column_fields()
self._update_multi_column_transformer()
self._modified_config = True

def remove_transformers(self, column_names):
Expand Down Expand Up @@ -514,6 +543,7 @@ def remove_transformers(self, column_names):

self.field_transformers[column_name] = None

self._update_multi_column_transformer()
if self._fitted:
warnings.warn(self._REFIT_MESSAGE)

Expand All @@ -540,6 +570,7 @@ def remove_transformers_by_sdtype(self, sdtype):

self.field_transformers[column_name] = None

self._update_multi_column_transformer()
if self._fitted:
warnings.warn(self._REFIT_MESSAGE)

Expand Down
4 changes: 4 additions & 0 deletions rdt/transformers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,10 @@ def _validate_columns_to_sdtypes(self, data, columns_to_sdtypes):
missing_to_print = ', '.join(missing)
raise ValueError(f'Columns ({missing_to_print}) are not present in the data.')

@classmethod
def _validate_sdtypes(cls, columns_to_sdtypes):
raise NotImplementedError()

def _fit(self, data):
"""Fit the transformer to the data.

Expand Down
104 changes: 103 additions & 1 deletion tests/integration/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import pytest

from rdt import get_demo
from rdt.errors import ConfigNotSetError, InvalidConfigError, InvalidDataError, NotFittedError
from rdt.errors import (
ConfigNotSetError, InvalidConfigError, InvalidDataError, NotFittedError, TransformerInputError)
from rdt.hyper_transformer import Config, HyperTransformer
from rdt.transformers import (
AnonymizedFaker, BaseMultiColumnTransformer, BaseTransformer, BinaryEncoder,
Expand Down Expand Up @@ -67,6 +68,10 @@ def _fit(self, data):
} for column in self.columns
}

@classmethod
def _validate_sdtypes(cls, columns_to_sdtype):
return None

def _get_prefix(self):
return None

Expand Down Expand Up @@ -1846,3 +1851,100 @@ def test_with_tuple_returned_by_faker(self):
]
})
pd.testing.assert_frame_equal(result, expected_results)

# Building blocks for the configs expected by ``test_invalid_multi_column``.
# The 'sdtypes' section is the same for every parametrized case.
expected_sdtype = {
'sdtypes': {
'A': 'categorical',
'B': 'categorical',
'D': 'categorical',
'E': 'categorical',
'C': 'boolean'
}
}
# 'transformers' section expected after an ``update_*`` call: every column,
# including 'C', maps to a ``UniformEncoder``.
expected_transformer_update = {
'transformers': {
'A': UniformEncoder(),
'E': UniformEncoder(),
'C': UniformEncoder(),
'B': UniformEncoder(),
'D': UniformEncoder()
}
}
# 'transformers' section expected after a ``remove_*`` call: 'C' maps to ``None``,
# the other columns map to a ``UniformEncoder``.
expected_transformer_remove = {
'transformers': {
'A': UniformEncoder(),
'E': UniformEncoder(),
'C': None,
'B': UniformEncoder(),
'D': UniformEncoder()
}
}
# Full expected config dicts: shared sdtypes merged with each transformers section.
expected_update = {
**expected_sdtype,
**expected_transformer_update
}
expected_remove = {
**expected_sdtype,
**expected_transformer_remove
}

# One entry per case: (method name, keyword arguments for that method,
# expected config dict after the call).
parametrization = [
(
'update_transformers', {'column_name_to_transformer': {'C': UniformEncoder()}},
expected_update
),
(
'update_transformers_by_sdtype',
{'sdtype': 'boolean', 'transformer': UniformEncoder()}, expected_update
),
('remove_transformers', {'column_names': 'C'}, expected_remove),
('remove_transformers_by_sdtype', {'sdtype': 'boolean'}, expected_remove),
]

@pytest.mark.parametrize(('method_name', 'method_input', 'expected_result'), parametrization)
def test_invalid_multi_column(self, method_name, method_input, expected_result):
    """Test the ``update`` and ``remove`` methods with invalid multi column transformer.

    When a multi column is no longer valid, all these methods should raise a warning
    and assign the default transformer to the columns.

    Args:
        method_name (str):
            Name of the ``HyperTransformer`` method to call.
        method_input (dict):
            Keyword arguments passed to that method.
        expected_result (dict):
            Config dict the ``HyperTransformer`` is expected to hold afterwards.
    """
    # Setup
    class BadDummyMultiColumnTransformer(DummyMultiColumnTransformerNumerical):
        """Multi-column transformer whose sdtype validation always fails."""

        @classmethod
        def _validate_sdtypes(cls, columns_to_sdtype):
            raise TransformerInputError('Invalid sdtype')

    dict_config = {
        'sdtypes': {
            'A': 'categorical',
            'B': 'categorical',
            'D': 'categorical',
            'E': 'categorical',
            'C': 'boolean',
        },
        'transformers': {
            'A': UniformEncoder(),
            ('B', 'D', 'C'): BadDummyMultiColumnTransformer(),
            'E': UniformEncoder()
        }
    }

    config = Config(dict_config)
    ht = HyperTransformer()
    ht.set_config(config)

    # Run
    # The warning names the field as ('B', 'D'), not the configured ('B', 'D', 'C').
    expected_warning = re.escape(
        "Transformer 'BadDummyMultiColumnTransformer' is incompatible with the "
        "multi-column field '('B', 'D')'. Assigning default transformer to the columns."
    )
    with pytest.warns(UserWarning, match=expected_warning):
        # ``getattr`` goes through the regular attribute lookup protocol instead of
        # calling the ``__getattribute__`` dunder directly.
        getattr(ht, method_name)(**method_input)

    # Assert
    new_config = ht.get_config()
    expected_config = Config(expected_result)
    # No multi-column field should remain after the fallback to default transformers.
    assert ht._multi_column_fields == {}
    assert repr(new_config) == repr(expected_config)
Loading
Loading