Skip to content

Commit

Permalink
make update_transformer work for multi column
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Jan 4, 2024
1 parent b5bf5d8 commit f381e9e
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 1 deletion.
6 changes: 6 additions & 0 deletions sdv/data_processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,12 @@ def update_transformers(self, column_name_to_transformer):
warnings.filterwarnings('ignore', module='rdt.hyper_transformer')
self._hyper_transformer.update_transformers(column_name_to_transformer)

self.grouped_columns_to_transformers = {
col_tuple: transformer
for col_tuple, transformer in self._hyper_transformer.field_transformers.items()
if isinstance(col_tuple, tuple)
}

def _fit_hyper_transformer(self, data):
"""Create and return a new ``rdt.HyperTransformer`` instance.
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/metadata/test_multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _validate_sdtypes(cls, columns_to_sdtypes):
instance.update_column('hotels', 'state', sdtype='state')

# Run
instance.add_column_relationship('address', 'hotels', ['city', 'state'])
instance.add_column_relationship('hotels', 'address', ['city', 'state'])

# Assert
instance.validate()
Expand Down
31 changes: 31 additions & 0 deletions tests/unit/data_processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1380,6 +1380,37 @@ def test_update_transformers_not_fitted(self):
with pytest.raises(NotFittedError, match=error_msg):
dp.update_transformers({'column': None})

def test_update_transformer_with_multi_column(self):
"""Test when a multi-column transformer is updated."""
# Setup
dp = DataProcessor(SingleTableMetadata())
dp.grouped_columns_to_transformers = {
('col_3', 'col_4'): 'transformer_3',
}
dp._hyper_transformer = Mock()
dp._hyper_transformer.field_transformers = {
'col_1': 'transformer_1',
'col_2': 'transformer_2',
('col_3', 'col_4'): 'transformer_3'
}

def update_transformers_effect(update_dict):
dp._hyper_transformer.field_transformers = {
'col_2': 'transformer_2',
('col_1', 'col_3'): 'transformer_4',
'col_4': 'transformer_3'
}

dp._hyper_transformer.update_transformers.side_effect = update_transformers_effect

# Run
dp.update_transformers({('col_1', 'col_3'): 'transformer_4'})

# Assert
assert dp.grouped_columns_to_transformers == {
('col_1', 'col_3'): 'transformer_4',
}

def test_update_transformers_ignores_rdt_refit_warning(self):
"""Test silencing hypertransformer refit warning (replaced by SDV warning elsewhere)"""
metadata = SingleTableMetadata()
Expand Down

0 comments on commit f381e9e

Please sign in to comment.