From 75976125c3315bca3b57354e9dfb311a6efa6d70 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Fri, 1 Apr 2022 18:59:42 +0200 Subject: [PATCH 1/8] move functionality to the config class --- rdt/hyper_transformer.py | 253 +++++++++++++++++++++++---------------- 1 file changed, 147 insertions(+), 106 deletions(-) diff --git a/rdt/hyper_transformer.py b/rdt/hyper_transformer.py index c29eecf5e..ecfa05cef 100644 --- a/rdt/hyper_transformer.py +++ b/rdt/hyper_transformer.py @@ -15,11 +15,129 @@ class Config(dict): """Config dict for ``HyperTransformer`` with a better representation.""" + def __init__(self): + super().__init__() + self._provided_field_sdtypes = {} + self._provided_field_transformers = {} + self['field_transformers'] = {} + self['field_sdtypes'] = {} + + @staticmethod + def _validate_config(config): + sdtypes = config.get('sdtypes', {}) + transformers = config.get('transformers', {}) + for column, transformer in transformers.items(): + input_sdtype = transformer.get_input_sdtype() + sdtype = sdtypes.get(column) + if input_sdtype != sdtype: + warnings.warn( + f'You are assigning a {input_sdtype} transformer to a {sdtype} ' + f"column ('{column}'). If the transformer doesn't match the " + 'sdtype, it may lead to errors.' + ) + + @staticmethod + def _get_supported_sdtypes(): + return get_transformers_by_type().keys() + + def reset(self): + """Reset the `field_sdtypes` and `field_transformers`.""" + self['field_sdtypes'] = self._provided_field_sdtypes.copy() + self['field_transformers'] = self._provided_field_transformers.copy() + + def update_sdtypes(self, column_name_to_sdtype): + """Update the ``sdtypes`` for each specified column name. + + Args: + column_name_to_sdtype(dict): + Dict mapping column names to ``sdtypes`` for that column. 
+ """ + unsupported_sdtypes = [] + transformers_to_update = {} + for column, sdtype in column_name_to_sdtype.items(): + if sdtype not in self._get_supported_sdtypes(): + unsupported_sdtypes.append(sdtype) + elif self.field_sdtypes.get(column) != sdtype: + current_transformer = self.field_transformers.get(column) + if not current_transformer or current_transformer.get_input_sdtype() != sdtype: + transformers_to_update[column] = get_default_transformer(sdtype) + + if unsupported_sdtypes: + raise Error( + f'Unsupported sdtypes ({unsupported_sdtypes}). To use ``sdtypes`` with specific ' + 'semantic meanings, please contact the SDV team to update to rdt_plus. Otherwise, ' + "use 'pii' to anonymize the column." + ) + + self.field_sdtypes.update(column_name_to_sdtype) + self.field_transformers.update(transformers_to_update) + self._provided_field_sdtypes.update(column_name_to_sdtype) + + def update_transformers_by_sdtype(self, sdtype, transformer): + """Update the transformers for the specified ``sdtype``. + + Given an ``sdtype`` and a ``transformer``, change all the fields of the ``sdtype`` + to use the given transformer. + + Args: + sdtype (str): + Semantic data type for the transformer. + transformer (rdt.transformers.BaseTransformer): + Transformer class or instance to be used for the given ``sdtype``. + """ + if not self['field_sdtypes']: + raise Error( + 'Nothing to update. Use the `detect_initial_config` method to ' + 'pre-populate all the sdtypes and transformers from your dataset.' + ) + + for field, field_sdtype in self['field_sdtypes'].items(): + if field_sdtype == sdtype: + self._provided_field_transformers[field] = transformer + self['field_transformers'][field] = transformer + + def update_transformers(self, column_name_to_transformer): + """Update any of the transformers assigned to each of the column names. + + Args: + column_name_to_transformer(dict): + Dict mapping column names to transformers to be used for that column. 
+ """ + for column_name, transformer in column_name_to_transformer.items(): + current_sdtype = self['field_sdtypes'].get(column_name) + if current_sdtype and current_sdtype != transformer.get_input_sdtype(): + warnings.warn( + f'You are assigning a {transformer.get_input_sdtype()} transformer ' + f'to a {current_sdtype} column ({column_name}). ' + "If the transformer doesn't match the sdtype, it may lead to errors." + ) + + self['field_transformers'][column_name] = transformer + self._provided_field_transformers[column_name] = transformer + + def set_config(self, config): + """Set the ``HyperTransformer`` configuration. + + This method will only update the sdtypes/transformers passed. Other previously + learned sdtypes/transformers will not be affected. + + Args: + config (dict): + A dictionary containing the following two dictionaries: + - sdtypes: A dictionary mapping column names to their ``sdtypes``. + - transformers: A dictionary mapping column names to their transformer instances. 
+ """ + self._validate_config(config) + self._provided_field_sdtypes = config['sdtypes'] + self['field_sdtypes'].update(config['sdtypes']) + self._provided_field_transformers = config['transformers'] + self['field_transformers'].update(config['transformers']) + def __repr__(self): """Pretty print the dictionary.""" config = { - 'sdtypes': self['sdtypes'], - 'transformers': {k: repr(v) for k, v in self['transformers'].items()} + 'sdtypes': self['field_sdtypes'], + 'transformers': {k: repr(v) for k, v in self['field_transformers'].items()} } return json.dumps(config, indent=4) @@ -121,14 +239,15 @@ def _subset(input_list, other_list, not_in=False): def _create_multi_column_fields(self): multi_column_fields = {} - for field in list(self.field_sdtypes) + list(self.field_transformers): + for field in list(self.config['field_sdtypes']) + list(self.config['field_transformers']): if isinstance(field, tuple): for column in field: multi_column_fields[column] = field + return multi_column_fields def _validate_field_transformers(self): - for field in self.field_transformers: + for field in self.config['field_transformers']: if self._field_in_set(field, self._specified_fields): raise ValueError(f'Multiple transformers specified for the field {field}. ' 'Each field can have at most one transformer defined in ' @@ -138,15 +257,7 @@ def _validate_field_transformers(self): def __init__(self): self._default_sdtype_transformers = {} - - # ``_provided_field_sdtypes``` contains only the sdtypes specified by the user, - # while `field_sdtypes` contains both the sdtypes specified by the user and the - # ones learned through ``fit``/``detect_initial_config``. Same for ``field_transformers``. 
- self._provided_field_sdtypes = {} - self.field_sdtypes = {} - self._provided_field_transformers = {} - self.field_transformers = {} - + self.config = Config() self._specified_fields = set() self._validate_field_transformers() self._valid_output_sdtypes = self._DEFAULT_OUTPUT_SDTYPES @@ -163,22 +274,6 @@ def _field_in_data(field, data): all_columns_in_data = isinstance(field, tuple) and all(col in data for col in field) return field in data or all_columns_in_data - @staticmethod - def _validate_config(config): - sdtypes = config.get('sdtypes', {}) - transformers = config.get('transformers', {}) - for column, transformer in transformers.items(): - input_sdtype = transformer.get_input_sdtype() - sdtype = sdtypes.get(column) - if input_sdtype != sdtype: - warnings.warn(f'You are assigning a {input_sdtype} transformer to a {sdtype} ' - f"column ('{column}'). If the transformer doesn't match the " - 'sdtype, it may lead to errors.') - - @staticmethod - def _get_supported_sdtypes(): - return get_transformers_by_type().keys() - def get_config(self): """Get the current ``HyperTransformer`` configuration. @@ -188,10 +283,7 @@ def get_config(self): - sdtypes: A dictionary mapping column names to their ``sdtypes``. - transformers: A dictionary mapping column names to their transformer instances. """ - return Config({ - 'sdtypes': self.field_sdtypes, - 'transformers': self.field_transformers - }) + return self.config def set_config(self, config): """Set the ``HyperTransformer`` configuration. @@ -205,13 +297,7 @@ def set_config(self, config): - sdtypes: A dictionary mapping column names to their ``sdtypes``. - transformers: A dictionary mapping column names to their transformer instances. 
""" - self._validate_config(config) - self._provided_field_sdtypes = config['sdtypes'] - self.field_sdtypes.update(config['sdtypes']) - self._provided_field_transformers = config['transformers'] - self.field_transformers.update(config['transformers']) - if self._fitted: - warnings.warn(self._REFIT_MESSAGE) + self.config.set_config(config) def update_transformers_by_sdtype(self, sdtype, transformer): """Update the transformers for the specified ``sdtype``. @@ -225,17 +311,7 @@ def update_transformers_by_sdtype(self, sdtype, transformer): transformer (rdt.transformers.BaseTransformer): Transformer class or instance to be used for the given ``sdtype``. """ - if not self.field_sdtypes: - raise Error( - 'Nothing to update. Use the `detect_initial_config` method to ' - 'pre-populate all the sdtypes and transformers from your dataset.' - ) - - for field, field_sdtype in self.field_sdtypes.items(): - if field_sdtype == sdtype: - self._provided_field_transformers[field] = transformer - self.field_transformers[field] = transformer - + self.config.update_transformers_by_sdtype(sdtype, transformer) if self._fitted: warnings.warn( 'For this change to take effect, please refit your data using ' @@ -249,29 +325,10 @@ def update_sdtypes(self, column_name_to_sdtype): column_name_to_sdtype(dict): Dict mapping column names to ``sdtypes`` for that column. 
""" - if len(self.field_sdtypes) == 0: + if len(self.config['field_sdtypes']) == 0: raise Error(self._DETECT_CONFIG_MESSAGE) - unsupported_sdtypes = [] - transformers_to_update = {} - for column, sdtype in column_name_to_sdtype.items(): - if sdtype not in self._get_supported_sdtypes(): - unsupported_sdtypes.append(sdtype) - elif self.field_sdtypes.get(column) != sdtype: - current_transformer = self.field_transformers.get(column) - if not current_transformer or current_transformer.get_input_sdtype() != sdtype: - transformers_to_update[column] = get_default_transformer(sdtype) - - if unsupported_sdtypes: - raise Error( - f'Unsupported sdtypes ({unsupported_sdtypes}). To use ``sdtypes`` with specific ' - 'semantic meanings, please contact the SDV team to update to rdt_plus. Otherwise, ' - "use 'pii' to anonymize the column." - ) - - self.field_sdtypes.update(column_name_to_sdtype) - self.field_transformers.update(transformers_to_update) - self._provided_field_sdtypes.update(column_name_to_sdtype) + self.config.update_sdtypes(column_name_to_sdtype) self._user_message( 'The transformers for these columns may change based on the new sdtype.\n' "Use 'get_config()' to verify the transformers.", 'Info' @@ -288,21 +345,10 @@ def update_transformers(self, column_name_to_transformer): """ if self._fitted: warnings.warn(self._REFIT_MESSAGE) - - if len(self.field_transformers) == 0: + if len(self.config['field_transformers']) == 0: raise Error(self._DETECT_CONFIG_MESSAGE) - for column_name, transformer in column_name_to_transformer.items(): - current_sdtype = self.field_sdtypes.get(column_name) - if current_sdtype and current_sdtype != transformer.get_input_sdtype(): - warnings.warn( - f'You are assigning a {transformer.get_input_sdtype()} transformer ' - f'to a {current_sdtype} column ({column_name}). ' - "If the transformer doesn't match the sdtype, it may lead to errors." 
- ) - - self.field_transformers[column_name] = transformer - self._provided_field_transformers[column_name] = transformer + self.config.update_transformers(column_name_to_transformer) def get_transformer(self, field): """Get the transformer instance used for a field. @@ -405,11 +451,10 @@ def get_transformer_tree_yaml(self): def _set_field_sdtype(self, data, field): clean_data = data[field].dropna() kind = clean_data.infer_objects().dtype.kind - self.field_sdtypes[field] = self._DTYPES_TO_SDTYPES[kind] + self.config['field_sdtypes'][field] = self._DTYPES_TO_SDTYPES[kind] def _unfit(self): - self.field_sdtypes = self._provided_field_sdtypes.copy() - self.field_transformers = self._provided_field_transformers.copy() + self.config.reset() self._transformers_sequence = [] self._input_columns = [] self._output_columns = [] @@ -421,14 +466,15 @@ def _learn_config(self, data): """Unfit the HyperTransformer and learn the sdtypes and transformers of the data.""" self._unfit() for field in data: - if field not in self.field_sdtypes: + if field not in self.config['field_sdtypes']: self._set_field_sdtype(data, field) - if field not in self.field_transformers: - sdtype = self.field_sdtypes[field] + if field not in self.config['field_transformers']: + sdtype = self.config['field_sdtypes'][field] if sdtype in self._default_sdtype_transformers: - self.field_transformers[field] = self._default_sdtype_transformers[sdtype] + default_transformer = self._default_sdtype_transformers[sdtype] + self.config['field_transformers'][field] = default_transformer else: - self.field_transformers[field] = get_default_transformer(sdtype) + self.config['field_transformers'][field] = get_default_transformer(sdtype) def detect_initial_config(self, data): """Print the configuration of the data. 
@@ -444,8 +490,7 @@ def detect_initial_config(self, data): """ # Reset the state of the HyperTransformer self._default_sdtype_transformers = {} - self._provided_field_sdtypes = {} - self._provided_field_transformers = {} + self.config = Config() # Set the sdtypes and transformers of all fields to their defaults self._learn_config(data) @@ -453,18 +498,13 @@ def detect_initial_config(self, data): self._user_message('Detecting a new config from the data ... SUCCESS') self._user_message('Setting the new config ... SUCCESS') - config = Config({ - 'sdtypes': self.field_sdtypes, - 'transformers': self.field_transformers - }) - self._user_message('Config:') - self._user_message(config) + self._user_message(self.config) def _get_next_transformer(self, output_field, output_sdtype, next_transformers): next_transformer = None - if output_field in self.field_transformers: - next_transformer = self.field_transformers[output_field] + if output_field in self.config['field_transformers']: + next_transformer = self.config['field_transformers'][output_field] elif output_sdtype not in self._valid_output_sdtypes: if next_transformers is not None and output_field in next_transformers: @@ -527,14 +567,14 @@ def _sort_output_columns(self): def _validate_detect_config_called(self, data): """Assert the ``detect_initial_config`` method is correcly called before fitting.""" - if len(self.field_sdtypes) == 0 and len(self.field_transformers) == 0: + if len(self.config['field_sdtypes']) == 0 and len(self.config['field_transformers']) == 0: raise NotFittedError( "No config detected. Set the config using 'set_config' or pre-populate " "it automatically from your data using 'detect_initial_config' prior to " 'fitting your data.' 
) - fields = list(self.field_sdtypes.keys()) + fields = list(self.config['field_sdtypes'].keys()) unknown_columns = self._subset(data.columns, fields, not_in=True) if unknown_columns: raise NotFittedError( @@ -555,7 +595,8 @@ def fit(self, data): self._learn_config(data) self._input_columns = list(data.columns) for field in self._input_columns: - data = self._fit_field_transformer(data, field, self.field_transformers[field]) + data = self._fit_field_transformer( + data, field, self.config['field_transformers'][field]) self._validate_all_fields_fitted() self._fitted = True From 3eae43e1d9d86b5d061272dbfb833aed3f6ec707 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Wed, 6 Apr 2022 17:09:55 +0200 Subject: [PATCH 2/8] Fix tests --- rdt/hyper_transformer.py | 53 ++-- tests/unit/test_hyper_transformer.py | 441 +++++++++++++++++---------- 2 files changed, 306 insertions(+), 188 deletions(-) diff --git a/rdt/hyper_transformer.py b/rdt/hyper_transformer.py index ecfa05cef..e6021fa9c 100644 --- a/rdt/hyper_transformer.py +++ b/rdt/hyper_transformer.py @@ -15,12 +15,8 @@ class Config(dict): """Config dict for ``HyperTransformer`` with a better representation.""" - def __init__(self): - super().__init__() - self._provided_field_sdtypes = {} - self._provided_field_transformers = {} - self['field_transformers'] = {} - self['field_sdtypes'] = {} + _provided_field_transformers = None + _provided_field_sdtypes = None @staticmethod def _validate_config(config): @@ -40,6 +36,33 @@ def _validate_config(config): def _get_supported_sdtypes(): return get_transformers_by_type().keys() + def set_config(self, config): + """Set the ``HyperTransformer`` configuration. + + This method will only update the sdtypes/transformers passed. Other previously + learned sdtypes/transformers will not be affected. + + Args: + config (dict): + A dictionary containing the following two dictionaries: + - sdtypes: A dictionary mapping column names to their ``sdtypes``. 
+ - transformers: A dictionary mapping column names to their transformer instances. + """ + self._validate_config(config) + self._provided_field_sdtypes = config['sdtypes'] + self['field_sdtypes'].update(config['sdtypes']) + self._provided_field_transformers = config['transformers'] + self['field_transformers'].update(config['transformers']) + + def __init__(self, config=None): + super().__init__() + self._provided_field_sdtypes = {} + self._provided_field_transformers = {} + self['field_transformers'] = {} + self['field_sdtypes'] = {} + if config: + self.set_config(config) + def reset(self): """Reset the `field_sdtypes` and `field_transformers`.""" self['field_sdtypes'] = self._provided_field_sdtypes.copy() @@ -115,24 +138,6 @@ def update_transformers(self, column_name_to_transformer): self['field_transformers'][column_name] = transformer self._provided_field_transformers[column_name] = transformer - def set_config(self, config): - """Set the ``HyperTransformer`` configuration. - - This method will only update the sdtypes/transformers passed. Other previously - learned sdtypes/transformers will not be affected. - - Args: - config (dict): - A dictionary containing the following two dictionaries: - - sdtypes: A dictionary mapping column names to their ``sdtypes``. - - transformers: A dictionary mapping column names to their transformer instances. 
- """ - self._validate_config(config) - self._provided_field_sdtypes = config['sdtypes'] - self['field_sdtypes'].update(config['sdtypes']) - self._provided_field_transformers = config['transformers'] - self['field_transformers'].update(config['transformers']) - def __repr__(self): """Pretty print the dictionary.""" config = { diff --git a/tests/unit/test_hyper_transformer.py b/tests/unit/test_hyper_transformer.py index 7fa8ea65a..d85e415e1 100644 --- a/tests/unit/test_hyper_transformer.py +++ b/tests/unit/test_hyper_transformer.py @@ -1,5 +1,6 @@ import contextlib import io +import json import re from collections import defaultdict from unittest import TestCase @@ -10,12 +11,216 @@ import pytest from rdt import HyperTransformer +from rdt.hyper_transformer import Config from rdt.errors import Error, NotFittedError from rdt.transformers import ( BinaryEncoder, FloatFormatter, FrequencyEncoder, GaussianNormalizer, OneHotEncoder, UnixTimestampEncoder) +class TestConfig(TestCase): + + @patch('rdt.hyper_transformer.warnings') + def test__validate_config(self, warnings_mock): + """Test the ``_validate_config`` method. + + The method should throw a warnings if the ``sdtypes`` of any column name doesn't match + the ``sdtype`` of its transformer. + + Setup: + - A mock for warnings. + + Input: + - A config with a transformers dict that has a transformer that doesn't match the + sdtype for the same column in sdtypes dict. + + Expected behavior: + - There should be a warning. + """ + # Setup + transformers = { + 'column1': FloatFormatter(), + 'column2': FrequencyEncoder() + } + sdtypes = { + 'column1': 'numerical', + 'column2': 'numerical' + } + config = { + 'sdtypes': sdtypes, + 'transformers': transformers + } + + # Run + Config._validate_config(config) + + # Assert + expected_message = ( + "You are assigning a categorical transformer to a numerical column ('column2'). " + "If the transformer doesn't match the sdtype, it may lead to errors." 
+ ) + warnings_mock.warn.assert_called_once_with(expected_message) + + @patch('rdt.hyper_transformer.warnings') + def test__validate_config_no_warning(self, warnings_mock): + """Test the ``_validate_config`` method with no warning. + + The method should not throw a warnings if the ``sdtypes`` of all columns match + the ``sdtype`` of their transformers. + + Setup: + - A mock for warnings. + + Input: + - A config with a transformers dict that matches the sdtypes for each column. + + Expected behavior: + - There should be no warning. + """ + # Setup + transformers = { + 'column1': FloatFormatter(), + 'column2': FrequencyEncoder() + } + sdtypes = { + 'column1': 'numerical', + 'column2': 'categorical' + } + config = { + 'sdtypes': sdtypes, + 'transformers': transformers + } + + # Run + Config._validate_config(config) + + # Assert + warnings_mock.warn.assert_not_called() + + def test_set_config(self): + """Test the ``set_config`` method. + + The method should set the ``instance._provided_field_sdtypes``, + ``instance['field_sdtypes']``, ``instance._provided_field_transformers ``and + ``instance['field_transformers']`` attributes based on the config. + + Setup: + - Mock the ``_validate_config`` method so no warnings get raised. + + Input: + - A dict with two keys: + - transformers: Maps to a dict that maps column names to transformers. + - sdtypes: Maps to a dict that maps column names to ``sdtypes``. + + Expected behavior: + - The attributes ``instance._provided_field_sdtypes``, ``instance['field_sdtypes']``, + ``instance._provided_field_transformers `` and ``instance['field_transformers']`` + should be set. 
+ """ + # Setup + transformers = { + 'column1': FloatFormatter(), + 'column2': FrequencyEncoder() + } + sdtypes = { + 'column1': 'numerical', + 'column2': 'categorical' + } + config_dict = { + 'sdtypes': sdtypes, + 'transformers': transformers + } + config = Config() + config._validate_config = Mock() + + # Run + config.set_config(config_dict) + + # Assert + config._validate_config.assert_called_once_with(config_dict) + assert config._provided_field_transformers == config_dict['transformers'] + assert config['field_transformers'] == config_dict['transformers'] + assert config._provided_field_sdtypes == config_dict['sdtypes'] + assert config['field_sdtypes'] == config_dict['sdtypes'] + + def test___init__(self): + """Test the instantiation of ``Config`` with the default values.""" + # Setup + config = Config() + + # Assert + assert config._provided_field_sdtypes == {} + assert config._provided_field_transformers == {} + assert config['field_transformers'] == {} + assert config['field_sdtypes'] == {} + + def test___init__with_custom_config(self): + """Test the instantiation of ``Config`` with the default values. + + Setup: + - Dict with transformers + - Dict with sdtypes + - Custom config dict that contains transformers and sdtypes. + + Side Effect: + - config._provided_field_sdtypes has been updated with the `sdtypes`. + - config._provided_field_transformers has been updated with the `transformers`. + - config['field_sdtypes'] has been updated with the `sdtypes`. + - config['field_transformers'] has been updated with the `transformers`. 
+ """ + # Setup + transformers = { + 'column1': FloatFormatter(), + 'column2': FrequencyEncoder() + } + sdtypes = { + 'column1': 'numerical', + 'column2': 'numerical' + } + custom_config = { + 'sdtypes': sdtypes, + 'transformers': transformers + } + + # Run + config = Config(custom_config) + + # Assert + assert config._provided_field_sdtypes == sdtypes + assert config._provided_field_transformers == transformers + assert config['field_transformers'] == transformers + assert config['field_sdtypes'] == sdtypes + + + def test_reset(self): + """Test that `reset` resets the `field_sdtypes` and `field_transformers`. + + Test that when calling the `reset` method the `field_sdtypes` and `field_transformers` + are reseted to the `config._provided_field_sdtypes` and + `config._provided_field_transformers` respectively. + + Setup: + - Instance of config. + - `instance._provided_field_sdtypes` is a dict containing information. + - `instance._provided_field_transformers` is a dict containing information. + + Side Effects: + - config['field_sdtypes'] is set to `config._provided_field_sdtypes`. + - config['field_transformers'] is set to `config._provided_field_transformers`. 
+ """ + # Setup + instance = Config() + instance._provided_field_sdtypes = {'my_column': 'boolean'} + instance._provided_field_transformers = {'my_column': BinaryEncoder} + + # Run + instance.reset() + + # Assert + assert instance['field_transformers'] == {'my_column': BinaryEncoder} + assert instance['field_sdtypes'] == {'my_column': 'boolean'} + + class TestHyperTransformer(TestCase): @patch('rdt.hyper_transformer.print') @@ -119,7 +324,7 @@ def test__validate_field_transformers(self): ('integer',): int_transformer } ht = HyperTransformer() - ht.field_transformers = field_transformers + ht.config['field_transformers'] = field_transformers # Run / Assert error_msg = ( @@ -130,19 +335,16 @@ def test__validate_field_transformers(self): with pytest.raises(ValueError, match=error_msg): ht._validate_field_transformers() + @patch('rdt.hyper_transformer.Config') @patch('rdt.hyper_transformer.HyperTransformer._create_multi_column_fields') @patch('rdt.hyper_transformer.HyperTransformer._validate_field_transformers') - def test___init__(self, validation_mock, multi_column_mock): + def test___init__(self, validation_mock, multi_column_mock, mock_config): """Test create new instance of HyperTransformer""" # Run ht = HyperTransformer() # Asserts - assert ht._provided_field_sdtypes == {} - assert ht.field_sdtypes == {} - assert ht._default_sdtype_transformers == {} - assert ht._provided_field_transformers == {} - assert ht.field_transformers == {} + assert ht.config == mock_config.return_value multi_column_mock.assert_called_once() validation_mock.assert_called_once() @@ -156,29 +358,24 @@ def test__unfit(self): - instance._transformers_sequence is a list of transformers - instance._output_columns is a list of columns - instance._input_columns is a list of columns - - instance._provided_field_sdtypes is a dict of strings to strings - - instance._provided_field_transformers is a dict of strings to transformers + - instance.config is a mock Expected behavior: - instance._fitted 
is set to False - instance._transformers_sequence is set to [] - instance._output_columns is an empty list - instance._input_columns is an empty list - - instance._provided_field_sdtypes doesn't change - - instance.field_sdtypes is the same as instance._provided_field_sdtypes - - instance._provided_field_transformers doesn't change - - instance.field_transformers is the same as instance._provided_field_transformers + - instance.config.reset has been called once. """ # Setup ht = HyperTransformer() + ht.config = Mock() ht._fitted = True ht._transformers_sequence = [BinaryEncoder(), FloatFormatter()] ht._output_columns = ['col1', 'col2'] ht._input_columns = ['col3', 'col4'] sdtypes = {'col1': 'float', 'col2': 'categorical'} - ht._provided_field_sdtypes = sdtypes transformers = {'col2': FloatFormatter(), 'col3': BinaryEncoder()} - ht._provided_field_transformers = transformers # Run ht._unfit() @@ -190,10 +387,7 @@ def test__unfit(self): assert ht._output_columns == [] assert ht._input_columns == [] assert ht._transformers_tree == {} - assert ht.field_sdtypes == sdtypes - assert ht._provided_field_sdtypes == sdtypes - assert ht.field_transformers == transformers - assert ht._provided_field_transformers == transformers + ht.config.reset.assert_called_once() def test__create_multi_column_fields(self): """Test the ``_create_multi_column_fields`` method. @@ -203,8 +397,8 @@ def test__create_multi_column_fields(self): each column to its corresponding tuple. 
Setup: - - instance.field_transformers will be populated with multi-column fields - - instance.field_sdtypes will be populated with multi-column fields + - instance.config['field_transformers'] will be populated with multi-column fields + - instance.config['field_sdtypes'] will be populated with multi-column fields Output: - A dict mapping each column name that is part of a multi-column @@ -212,13 +406,13 @@ def test__create_multi_column_fields(self): """ # Setup ht = HyperTransformer() - ht.field_transformers = { + ht.config['field_transformers'] = { 'a': BinaryEncoder, 'b': UnixTimestampEncoder, ('c', 'd'): UnixTimestampEncoder, 'e': FloatFormatter } - ht.field_sdtypes = { + ht.config['field_sdtypes'] = { 'f': 'categorical', ('g', 'h'): 'datetime' } @@ -239,7 +433,7 @@ def test__get_next_transformer_field_transformer(self): """Test the ``_get_next_transformer method. This tests that if the transformer is defined in the - ``instance.field_transformers`` dict, then it is returned + ``instance.config['field_transformers']`` dict, then it is returned even if the output sdtype is final. Setup: @@ -259,7 +453,7 @@ def test__get_next_transformer_field_transformer(self): # Setup transformer = FloatFormatter() ht = HyperTransformer() - ht.field_transformers = {'a.out': transformer} + ht.config['field_transformers'] = {'a.out': transformer} ht._default_sdtype_transformers = {'numerical': GaussianNormalizer()} # Run @@ -272,7 +466,7 @@ def test__get_next_transformer_final_output_sdtype(self): """Test the ``_get_next_transformer method. This tests that if the transformer is not defined in the - ``instance.field_transformers`` dict and its output sdtype + ``instance.config['field_transformers']`` dict and its output sdtype is in ``instance._transform_output_sdtypes``, then ``None`` is returned. @@ -302,7 +496,7 @@ def test__get_next_transformer_next_transformers(self): """Test the ``_get_next_transformer method. 
This tests that if the transformer is not defined in the - ``instance.field_transformers`` dict and its output sdtype + ``instance.config['field_transformers']`` dict and its output sdtype is not in ``instance._transform_output_sdtypes`` and the ``next_transformers`` dict has a transformer for the output field, then it is used. @@ -337,7 +531,7 @@ def test__get_next_transformer_default_transformer(self, mock): """Test the ``_get_next_transformer method. This tests that if the transformer is not defined in the - ``instance.field_transformers`` dict or ``next_transformers`` + ``instance.config['field_transformers']`` dict or ``next_transformers`` and its output sdtype is not in ``instance._transform_output_sdtypes`` then the default_transformer is used. @@ -463,9 +657,9 @@ def test_detect_initial_config(self): output = f_out.getvalue() # Assert - assert ht._provided_field_sdtypes == {} - assert ht._provided_field_transformers == {} - assert ht.field_sdtypes == { + assert ht.config._provided_field_sdtypes == {} + assert ht.config._provided_field_transformers == {} + assert ht.config['field_sdtypes'] == { 'col1': 'numerical', 'col2': 'categorical', 'col3': 'boolean', @@ -473,7 +667,7 @@ def test_detect_initial_config(self): 'col5': 'numerical' } - field_transformers = {k: repr(v) for (k, v) in ht.field_transformers.items()} + field_transformers = {k: repr(v) for (k, v) in ht.config['field_transformers'].items()} assert field_transformers == { 'col1': "FloatFormatter(missing_value_replacement='mean')", 'col2': 'FrequencyEncoder()', @@ -703,83 +897,6 @@ def test__sort_output_columns(self): # Assert assert ht._output_columns == ['a.is_null', 'b.value', 'b.is_null', 'c.value'] - @patch('rdt.hyper_transformer.warnings') - def test__validate_config(self, warnings_mock): - """Test the ``_validate_config`` method. - - The method should throw a warnings if the ``sdtypes`` of any column name doesn't match - the ``sdtype`` of its transformer. - - Setup: - - A mock for warnings. 
- - Input: - - A config with a transformers dict that has a transformer that doesn't match the - sdtype for the same column in sdtypes dict. - - Expected behavior: - - There should be a warning. - """ - # Setup - transformers = { - 'column1': FloatFormatter(), - 'column2': FrequencyEncoder() - } - sdtypes = { - 'column1': 'numerical', - 'column2': 'numerical' - } - config = { - 'sdtypes': sdtypes, - 'transformers': transformers - } - - # Run - HyperTransformer._validate_config(config) - - # Assert - expected_message = ( - "You are assigning a categorical transformer to a numerical column ('column2'). " - "If the transformer doesn't match the sdtype, it may lead to errors." - ) - warnings_mock.warn.assert_called_once_with(expected_message) - - @patch('rdt.hyper_transformer.warnings') - def test__validate_config_no_warning(self, warnings_mock): - """Test the ``_validate_config`` method with no warning. - - The method should not throw a warnings if the ``sdtypes`` of all columns match - the ``sdtype`` of their transformers. - - Setup: - - A mock for warnings. - - Input: - - A config with a transformers dict that matches the sdtypes for each column. - - Expected behavior: - - There should be no warning. - """ - # Setup - transformers = { - 'column1': FloatFormatter(), - 'column2': FrequencyEncoder() - } - sdtypes = { - 'column1': 'numerical', - 'column2': 'categorical' - } - config = { - 'sdtypes': sdtypes, - 'transformers': transformers - } - - # Run - HyperTransformer._validate_config(config) - - # Assert - warnings_mock.warn.assert_not_called() - def test_get_config(self): """Test the ``get_config`` method. 
@@ -797,11 +914,11 @@ def test_get_config(self): """ # Setup ht = HyperTransformer() - ht.field_transformers = { + ht.config['field_transformers'] = { 'column1': FloatFormatter(), 'column2': FrequencyEncoder() } - ht.field_sdtypes = { + ht.config['field_sdtypes'] = { 'column1': 'numerical', 'column2': 'categorical' } @@ -810,11 +927,7 @@ def test_get_config(self): config = ht.get_config() # Assert - expected_config = { - 'sdtypes': ht.field_sdtypes, - 'transformers': ht.field_transformers - } - assert config == expected_config + assert config == ht.config def test_get_config_empty(self): """Test the ``get_config`` method when the config is empty. @@ -834,18 +947,15 @@ def test_get_config_empty(self): config = ht.get_config() # Assert - expected_config = { - 'sdtypes': {}, - 'transformers': {} - } + expected_config = Config() assert config == expected_config def test_set_config(self): """Test the ``set_config`` method. - The method should set the ``instance._provided_field_sdtypes``, - ``instance.field_sdtypes``, ``instance._provided_field_transformers ``and - ``instance.field_transformers`` attributes based on the config. + The method should set the ``instance.config._provided_field_sdtypes``, + ``instance.config['field_sdtypes']``, ``instance.config._provided_field_transformers ``and + ``instance.config['field_transformers']`` attributes based on the config. Setup: - Mock the ``_validate_config`` method so no warnings get raised. @@ -856,9 +966,10 @@ def test_set_config(self): - sdtypes: Maps to a dict that maps column names to ``sdtypes``. Expected behavior: - - The attributes ``instance._provided_field_sdtypes``, ``instance.field_sdtypes``, - ``instance._provided_field_transformers `` and ``instance.field_transformers`` - should be set. + - The attributes ``instance.config._provided_field_sdtypes``, + ``instance.config['field_sdtypes']``, + ``instance.config._provided_field_transformers`` + and ``instance.config['field_transformers']`` should be set. 
""" # Setup transformers = { @@ -874,17 +985,17 @@ def test_set_config(self): 'transformers': transformers } ht = HyperTransformer() - ht._validate_config = Mock() + ht.config._validate_config = Mock() # Run ht.set_config(config) # Assert - ht._validate_config.assert_called_once_with(config) - assert ht._provided_field_transformers == config['transformers'] - assert ht.field_transformers == config['transformers'] - assert ht._provided_field_sdtypes == config['sdtypes'] - assert ht.field_sdtypes == config['sdtypes'] + ht.config._validate_config.assert_called_once_with(config) + assert ht.config._provided_field_transformers == config['transformers'] + assert ht.config['field_transformers'] == config['transformers'] + assert ht.config._provided_field_sdtypes == config['sdtypes'] + assert ht.config['field_sdtypes'] == config['sdtypes'] @patch('rdt.hyper_transformer.warnings') def test_set_config_already_fitted(self, mock_warnings): @@ -976,7 +1087,7 @@ def test__validate_detect_config_called_incorrect_data(self): """ # Setup ht = HyperTransformer() - ht.field_sdtypes = {'col1': 'float', 'col2': 'categorical'} + ht.config['field_sdtypes'] = {'col1': 'float', 'col2': 'categorical'} data = pd.DataFrame({'col1': [1, 2], 'col3': ['a', 'b']}) error_msg = re.escape( 'The data you are trying to fit has different columns than the original ' @@ -1033,7 +1144,7 @@ def test_fit(self, get_default_transformer_mock): get_default_transformer_mock.return_value = datetime_transformer ht = HyperTransformer() - ht.field_transformers = field_transformers + ht.config['field_transformers'] = field_transformers ht._default_sdtype_transformers = default_sdtype_transformers ht._fit_field_transformer = Mock() ht._fit_field_transformer.return_value = data @@ -1265,7 +1376,7 @@ def test_update_transformers_by_sdtype_no_field_sdtypes(self, mock_print): ht.update_transformers_by_sdtype('categorical', object()) # Assert - assert ht.field_transformers == {} + assert ht.config['field_transformers'] == 
{} @patch('rdt.hyper_transformer.print') def test_update_transformers_by_sdtype_field_sdtypes_not_fitted(self, mock_print): @@ -1285,11 +1396,11 @@ def test_update_transformers_by_sdtype_field_sdtypes_not_fitted(self, mock_print """ # Setup ht = HyperTransformer() - ht.field_transformers = { + ht.config['field_transformers'] = { 'categorical_column': 'rdt.transformers.BaseTransformer', 'numerical_column': 'rdt.transformers.FloatFormatter', } - ht.field_sdtypes = { + ht.config['field_sdtypes'] = { 'categorical_column': 'categorical', 'numerical_column': 'numerical', @@ -1305,7 +1416,7 @@ def test_update_transformers_by_sdtype_field_sdtypes_not_fitted(self, mock_print 'categorical_column': transformer, 'numerical_column': 'rdt.transformers.FloatFormatter', } - assert ht.field_transformers == expected_field_transformers + assert ht.config['field_transformers'] == expected_field_transformers @patch('rdt.hyper_transformer.warnings') @patch('rdt.hyper_transformer.print') @@ -1329,8 +1440,10 @@ def test_update_transformers_by_sdtype_field_sdtypes_fitted(self, mock_print, mo # Setup ht = HyperTransformer() ht._fitted = True - ht.field_transformers = {'categorical_column': 'rdt.transformers.BaseTransformer'} - ht.field_sdtypes = {'categorical_column': 'categorical'} + ht.config['field_transformers'] = { + 'categorical_column': 'rdt.transformers.BaseTransformer' + } + ht.config['field_sdtypes'] = {'categorical_column': 'categorical'} transformer = object() # Run @@ -1344,7 +1457,7 @@ def test_update_transformers_by_sdtype_field_sdtypes_fitted(self, mock_print, mo mock_print.assert_not_called() mock_warnings.warn.assert_called_once_with(expected_warnings_msg) - assert ht.field_transformers == {'categorical_column': transformer} + assert ht.config['field_transformers'] == {'categorical_column': transformer} @patch('rdt.hyper_transformer.warnings') def test_update_transformers_fitted(self, mock_warnings): @@ -1371,7 +1484,7 @@ def test_update_transformers_fitted(self, 
mock_warnings): # Setup instance = HyperTransformer() instance._fitted = True - instance.field_transformers = {'a': object()} + instance.config['field_transformers'] = {'a': object()} mock_transformer = Mock() mock_transformer.get_input_sdtype.return_value = 'datetime' column_name_to_transformer = { @@ -1388,8 +1501,8 @@ def test_update_transformers_fitted(self, mock_warnings): ) mock_warnings.warn.assert_called_once_with(expected_message) - assert instance.field_transformers['my_column'] == mock_transformer - assert instance._provided_field_transformers == {'my_column': mock_transformer} + assert instance.config['field_transformers']['my_column'] == mock_transformer + assert instance.config._provided_field_transformers == {'my_column': mock_transformer} @patch('rdt.hyper_transformer.warnings') def test_update_transformers_not_fitted(self, mock_warnings): @@ -1417,7 +1530,7 @@ def test_update_transformers_not_fitted(self, mock_warnings): # Setup instance = HyperTransformer() instance._fitted = False - instance.field_transformers = {'a': object()} + instance.config['field_transformers'] = {'a': object()} mock_transformer = Mock() mock_transformer.get_input_sdtype.return_value = 'datetime' column_name_to_transformer = { @@ -1429,8 +1542,8 @@ def test_update_transformers_not_fitted(self, mock_warnings): # Assert mock_warnings.warn.assert_not_called() - assert instance.field_transformers['my_column'] == mock_transformer - assert instance._provided_field_transformers == {'my_column': mock_transformer} + assert instance.config['field_transformers']['my_column'] == mock_transformer + assert instance.config._provided_field_transformers == {'my_column': mock_transformer} def test_update_transformers_no_field_transformers(self): """Test update transformers. 
@@ -1498,8 +1611,8 @@ def test_update_transformers_missmatch_sdtypes(self, mock_warnings): instance = HyperTransformer() instance._fitted = False mock_numerical = Mock() - instance.field_transformers = {'my_column': mock_numerical} - instance.field_sdtypes = {'my_column': 'categorical'} + instance.config['field_transformers'] = {'my_column': mock_numerical} + instance.config['field_sdtypes'] = {'my_column': 'categorical'} mock_transformer = Mock() mock_transformer.get_input_sdtype.return_value = 'datetime' column_name_to_transformer = { @@ -1517,8 +1630,8 @@ def test_update_transformers_missmatch_sdtypes(self, mock_warnings): ) assert mock_warnings.called_once_with(expected_call) - assert instance.field_transformers['my_column'] == mock_transformer - assert instance._provided_field_transformers == {'my_column': mock_transformer} + assert instance.config['field_transformers']['my_column'] == mock_transformer + assert instance.config._provided_field_transformers == {'my_column': mock_transformer} @patch('rdt.hyper_transformer.warnings') def test_update_sdtypes_fitted(self, mock_warnings): @@ -1544,8 +1657,8 @@ def test_update_sdtypes_fitted(self, mock_warnings): """ # Setup instance = HyperTransformer() - instance.field_transformers = {'a': FrequencyEncoder, 'b': FloatFormatter} - instance.field_sdtypes = {'a': 'categorical'} + instance.config['field_transformers'] = {'a': FrequencyEncoder, 'b': FloatFormatter} + instance.config['field_sdtypes'] = {'a': 'categorical'} instance._fitted = True instance._user_message = Mock() column_name_to_sdtype = { @@ -1566,8 +1679,8 @@ def test_update_sdtypes_fitted(self, mock_warnings): ) mock_warnings.warn.assert_called_once_with(expected_message) - assert instance.field_sdtypes == {'my_column': 'numerical', 'a': 'categorical'} - assert instance._provided_field_sdtypes == {'my_column': 'numerical'} + assert instance.config['field_sdtypes'] == {'my_column': 'numerical', 'a': 'categorical'} + assert 
instance.config._provided_field_sdtypes == {'my_column': 'numerical'} instance._user_message.assert_called_once_with(user_message, 'Info') @patch('rdt.hyper_transformer.warnings') @@ -1597,7 +1710,7 @@ def test_update_sdtypes_not_fitted(self, mock_warnings): instance = HyperTransformer() instance._fitted = False instance._user_message = Mock() - instance.field_sdtypes = {'a': 'categorical'} + instance.config['field_sdtypes'] = {'a': 'categorical'} column_name_to_sdtype = { 'my_column': 'numerical' } @@ -1611,8 +1724,8 @@ def test_update_sdtypes_not_fitted(self, mock_warnings): "Use 'get_config()' to verify the transformers." ) mock_warnings.warn.assert_not_called() - assert instance.field_sdtypes == {'my_column': 'numerical', 'a': 'categorical'} - assert instance._provided_field_sdtypes == {'my_column': 'numerical'} + assert instance.config['field_sdtypes'] == {'my_column': 'numerical', 'a': 'categorical'} + assert instance.config._provided_field_sdtypes == {'my_column': 'numerical'} instance._user_message.assert_called_once_with(user_message, 'Info') def test_update_sdtypes_no_field_sdtypes(self): @@ -1633,7 +1746,7 @@ def test_update_sdtypes_no_field_sdtypes(self): # Setup instance = HyperTransformer() instance._fitted = False - instance.field_sdtypes = {} + instance.config['field_sdtypes'] = {} column_name_to_sdtype = { 'my_column': 'numerical' } @@ -1666,7 +1779,7 @@ def test_update_sdtypes_invalid_sdtype(self): instance._get_supported_sdtypes = Mock() instance._get_supported_sdtypes.return_value = [] instance._fitted = False - instance.field_sdtypes = { + instance.config['field_sdtypes'] = { 'my_column': 'categorical' } column_name_to_sdtype = { @@ -1712,7 +1825,7 @@ def test_update_sdtypes_different_sdtype(self, mock_warnings, default_mock): instance = HyperTransformer() instance._fitted = False instance._user_message = Mock() - instance.field_sdtypes = {'a': 'categorical'} + instance.config['field_sdtypes'] = {'a': 'categorical'} transformer_mock = Mock() 
default_mock.return_value = transformer_mock column_name_to_sdtype = { @@ -1728,9 +1841,9 @@ def test_update_sdtypes_different_sdtype(self, mock_warnings, default_mock): "Use 'get_config()' to verify the transformers." ) mock_warnings.warn.assert_not_called() - assert instance.field_sdtypes == {'a': 'numerical'} - assert instance.field_transformers == {'a': transformer_mock} - assert instance._provided_field_sdtypes == {'a': 'numerical'} + assert instance.config['field_sdtypes'] == {'a': 'numerical'} + assert instance.config['field_transformers'] == {'a': transformer_mock} + assert instance.config._provided_field_sdtypes == {'a': 'numerical'} instance._user_message.assert_called_once_with(user_message, 'Info') @patch('rdt.hyper_transformer.warnings') From d017a06b1327188a1c80103103365a0fa7b6f671 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Wed, 6 Apr 2022 17:11:38 +0200 Subject: [PATCH 3/8] Fix lint --- tests/unit/test_hyper_transformer.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/unit/test_hyper_transformer.py b/tests/unit/test_hyper_transformer.py index d85e415e1..fc342c0ed 100644 --- a/tests/unit/test_hyper_transformer.py +++ b/tests/unit/test_hyper_transformer.py @@ -1,6 +1,5 @@ import contextlib import io -import json import re from collections import defaultdict from unittest import TestCase @@ -11,8 +10,8 @@ import pytest from rdt import HyperTransformer -from rdt.hyper_transformer import Config from rdt.errors import Error, NotFittedError +from rdt.hyper_transformer import Config from rdt.transformers import ( BinaryEncoder, FloatFormatter, FrequencyEncoder, GaussianNormalizer, OneHotEncoder, UnixTimestampEncoder) @@ -191,7 +190,6 @@ def test___init__with_custom_config(self): assert config['field_transformers'] == transformers assert config['field_sdtypes'] == sdtypes - def test_reset(self): """Test that `reset` resets the `field_sdtypes` and `field_transformers`. 
@@ -374,8 +372,6 @@ def test__unfit(self): ht._transformers_sequence = [BinaryEncoder(), FloatFormatter()] ht._output_columns = ['col1', 'col2'] ht._input_columns = ['col3', 'col4'] - sdtypes = {'col1': 'float', 'col2': 'categorical'} - transformers = {'col2': FloatFormatter(), 'col3': BinaryEncoder()} # Run ht._unfit() From 33200f9f04928d8c05777fcc7cd972b64cdf3420 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Fri, 8 Apr 2022 17:16:38 +0200 Subject: [PATCH 4/8] Fix api --- rdt/hyper_transformer.py | 99 +++++++++++++++++++++++----------------- 1 file changed, 57 insertions(+), 42 deletions(-) diff --git a/rdt/hyper_transformer.py b/rdt/hyper_transformer.py index e6021fa9c..6426e0797 100644 --- a/rdt/hyper_transformer.py +++ b/rdt/hyper_transformer.py @@ -12,11 +12,13 @@ get_default_transformer, get_transformer_instance, get_transformers_by_type) -class Config(dict): - """Config dict for ``HyperTransformer`` with a better representation.""" +class Config: + """Config class for ``HyperTransformer``.""" _provided_field_transformers = None _provided_field_sdtypes = None + field_sdtypes = None + field_transformers = None @staticmethod def _validate_config(config): @@ -50,25 +52,24 @@ def set_config(self, config): """ self._validate_config(config) self._provided_field_sdtypes = config['sdtypes'] - self['field_sdtypes'].update(config['sdtypes']) + self.field_sdtypes.update(config['sdtypes']) self._provided_field_transformers = config['transformers'] - self['field_transformers'].update(config['transformers']) + self.field_transformers.update(config['transformers']) def __init__(self, config=None): - super().__init__() self._provided_field_sdtypes = {} self._provided_field_transformers = {} - self['field_transformers'] = {} - self['field_sdtypes'] = {} + self.field_transformers = {} + self.field_sdtypes = {} if config: self.set_config(config) def reset(self): """Reset the `field_sdtypes` and `field_transformers`.""" - self['field_sdtypes'] = 
self._provided_field_sdtypes.copy() - self['field_transformers'] = self._provided_field_transformers.copy() + self.field_sdtypes = self._provided_field_sdtypes.copy() + self.field_transformers = self._provided_field_transformers.copy() - def update_sdtypes(self, column_name_to_sdtype): + def update_sdtypes(self, column_name_to_sdtype, provided_sdtype=False): """Update the ``sdtypes`` for each specified column name. Args: @@ -81,9 +82,7 @@ def update_sdtypes(self, column_name_to_sdtype): if sdtype not in self._get_supported_sdtypes(): unsupported_sdtypes.append(sdtype) elif self.field_sdtypes.get(column) != sdtype: - current_transformer = self.field_transformers.get(column) - if not current_transformer or current_transformer.get_input_sdtype() != sdtype: - transformers_to_update[column] = get_default_transformer(sdtype) + transformers_to_update[column] = get_default_transformer(sdtype) if unsupported_sdtypes: raise Error( @@ -94,7 +93,8 @@ def update_sdtypes(self, column_name_to_sdtype): self.field_sdtypes.update(column_name_to_sdtype) self.field_transformers.update(transformers_to_update) - self._provided_field_sdtypes.update(column_name_to_sdtype) + if provided_sdtype: + self._provided_field_sdtypes.update(column_name_to_sdtype) def update_transformers_by_sdtype(self, sdtype, transformer): """Update the transformers for the specified ``sdtype``. @@ -108,26 +108,32 @@ def update_transformers_by_sdtype(self, sdtype, transformer): transformer (rdt.transformers.BaseTransformer): Transformer class or instance to be used for the given ``sdtype``. """ - if not self['field_sdtypes']: + if not self.field_sdtypes: raise Error( 'Nothing to update. Use the `detect_initial_config` method to ' 'pre-populate all the sdtypes and transformers from your dataset.' 
) - for field, field_sdtype in self['field_sdtypes'].items(): + for field, field_sdtype in self.field_sdtypes.items(): if field_sdtype == sdtype: self._provided_field_transformers[field] = transformer - self['field_transformers'][field] = transformer + self.field_transformers[field] = transformer - def update_transformers(self, column_name_to_transformer): - """Update any of the transformers assigned to each of the column names. + def update_transformers(self, column_name_to_transformer, provided_transformer=False): + """Update `self.field_transformers`. + + Update the `self.field_transformers` with the provided column name and transformer, if + `provided_transformer` is `True` will automatically update the + `self._provided_field_transformers` which will not reset. Args: column_name_to_transformer(dict): Dict mapping column names to transformers to be used for that column. + provided_transformer(bool): + Wether or not to add to `self._provided_field_transformers`. """ for column_name, transformer in column_name_to_transformer.items(): - current_sdtype = self['field_sdtypes'].get(column_name) + current_sdtype = self.field_sdtypes.get(column_name) if current_sdtype and current_sdtype != transformer.get_input_sdtype(): warnings.warn( f'You are assigning a {transformer.get_input_sdtype()} transformer ' @@ -135,14 +141,23 @@ def update_transformers(self, column_name_to_transformer): "If the transformer doesn't match the sdtype, it may lead to errors." 
) - self['field_transformers'][column_name] = transformer - self._provided_field_transformers[column_name] = transformer + self.field_transformers[column_name] = transformer + if provided_transformer: + self._provided_field_transformers[column_name] = transformer + + def get_field_transformers(self): + """Return the fields transformers.""" + return self.field_transformers + + def get_field_sdtypes(self): + """Return the fields sdtypes.""" + return self.field_sdtypes def __repr__(self): """Pretty print the dictionary.""" config = { - 'sdtypes': self['field_sdtypes'], - 'transformers': {k: repr(v) for k, v in self['field_transformers'].items()} + 'sdtypes': self.field_sdtypes, + 'transformers': {k: repr(v) for k, v in self.field_transformers.items()} } return json.dumps(config, indent=4) @@ -244,7 +259,7 @@ def _subset(input_list, other_list, not_in=False): def _create_multi_column_fields(self): multi_column_fields = {} - for field in list(self.config['field_sdtypes']) + list(self.config['field_transformers']): + for field in list(self.config.field_sdtypes) + list(self.config.field_transformers): if isinstance(field, tuple): for column in field: multi_column_fields[column] = field @@ -252,7 +267,7 @@ def _create_multi_column_fields(self): return multi_column_fields def _validate_field_transformers(self): - for field in self.config['field_transformers']: + for field in self.config.field_transformers: if self._field_in_set(field, self._specified_fields): raise ValueError(f'Multiple transformers specified for the field {field}. ' 'Each field can have at most one transformer defined in ' @@ -288,7 +303,7 @@ def get_config(self): - sdtypes: A dictionary mapping column names to their ``sdtypes``. - transformers: A dictionary mapping column names to their transformer instances. """ - return self.config + return self.config.to_dict() def set_config(self, config): """Set the ``HyperTransformer`` configuration. 
@@ -330,10 +345,10 @@ def update_sdtypes(self, column_name_to_sdtype): column_name_to_sdtype(dict): Dict mapping column names to ``sdtypes`` for that column. """ - if len(self.config['field_sdtypes']) == 0: + if len(self.config.field_sdtypes) == 0: raise Error(self._DETECT_CONFIG_MESSAGE) - self.config.update_sdtypes(column_name_to_sdtype) + self.config.update_sdtypes(column_name_to_sdtype, provided_sdtype=True) self._user_message( 'The transformers for these columns may change based on the new sdtype.\n' "Use 'get_config()' to verify the transformers.", 'Info' @@ -350,10 +365,10 @@ def update_transformers(self, column_name_to_transformer): """ if self._fitted: warnings.warn(self._REFIT_MESSAGE) - if len(self.config['field_transformers']) == 0: + if len(self.config.field_transformers) == 0: raise Error(self._DETECT_CONFIG_MESSAGE) - self.config.update_transformers(column_name_to_transformer) + self.config.update_transformers(column_name_to_transformer, provided_transformer=True) def get_transformer(self, field): """Get the transformer instance used for a field. 
@@ -456,7 +471,7 @@ def get_transformer_tree_yaml(self): def _set_field_sdtype(self, data, field): clean_data = data[field].dropna() kind = clean_data.infer_objects().dtype.kind - self.config['field_sdtypes'][field] = self._DTYPES_TO_SDTYPES[kind] + self.config.update_sdtypes({field: self._DTYPES_TO_SDTYPES[kind]}) def _unfit(self): self.config.reset() @@ -471,15 +486,15 @@ def _learn_config(self, data): """Unfit the HyperTransformer and learn the sdtypes and transformers of the data.""" self._unfit() for field in data: - if field not in self.config['field_sdtypes']: + if field not in self.config.field_sdtypes: self._set_field_sdtype(data, field) - if field not in self.config['field_transformers']: - sdtype = self.config['field_sdtypes'][field] + if field not in self.config.field_transformers: + sdtype = self.config.field_sdtypes.get(field) if sdtype in self._default_sdtype_transformers: default_transformer = self._default_sdtype_transformers[sdtype] - self.config['field_transformers'][field] = default_transformer + self.config.update_transformers({field: default_transformer}) else: - self.config['field_transformers'][field] = get_default_transformer(sdtype) + self.config.update_transformers({field: get_default_transformer(sdtype)}) def detect_initial_config(self, data): """Print the configuration of the data. 
@@ -508,8 +523,8 @@ def detect_initial_config(self, data): def _get_next_transformer(self, output_field, output_sdtype, next_transformers): next_transformer = None - if output_field in self.config['field_transformers']: - next_transformer = self.config['field_transformers'][output_field] + if output_field in self.config.field_transformers: + next_transformer = self.config.field_transformers.get(output_field) elif output_sdtype not in self._valid_output_sdtypes: if next_transformers is not None and output_field in next_transformers: @@ -572,14 +587,14 @@ def _sort_output_columns(self): def _validate_detect_config_called(self, data): """Assert the ``detect_initial_config`` method is correcly called before fitting.""" - if len(self.config['field_sdtypes']) == 0 and len(self.config['field_transformers']) == 0: + if len(self.config.field_sdtypes) == 0 and len(self.config.field_transformers) == 0: raise NotFittedError( "No config detected. Set the config using 'set_config' or pre-populate " "it automatically from your data using 'detect_initial_config' prior to " 'fitting your data.' 
) - fields = list(self.config['field_sdtypes'].keys()) + fields = list(self.config.field_sdtypes.keys()) unknown_columns = self._subset(data.columns, fields, not_in=True) if unknown_columns: raise NotFittedError( @@ -601,7 +616,7 @@ def fit(self, data): self._input_columns = list(data.columns) for field in self._input_columns: data = self._fit_field_transformer( - data, field, self.config['field_transformers'][field]) + data, field, self.config.field_transformers.get(field)) self._validate_all_fields_fitted() self._fitted = True From b243d90c0048d0c742747bba4a9ef29273fd4571 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Fri, 8 Apr 2022 18:19:38 +0200 Subject: [PATCH 5/8] Fix unit --- rdt/hyper_transformer.py | 21 +++- tests/unit/test_hyper_transformer.py | 147 ++++++++++++--------------- 2 files changed, 79 insertions(+), 89 deletions(-) diff --git a/rdt/hyper_transformer.py b/rdt/hyper_transformer.py index 6426e0797..79ed8c05b 100644 --- a/rdt/hyper_transformer.py +++ b/rdt/hyper_transformer.py @@ -82,7 +82,9 @@ def update_sdtypes(self, column_name_to_sdtype, provided_sdtype=False): if sdtype not in self._get_supported_sdtypes(): unsupported_sdtypes.append(sdtype) elif self.field_sdtypes.get(column) != sdtype: - transformers_to_update[column] = get_default_transformer(sdtype) + current_transformer = self.field_transformers.get(column) + if not current_transformer or current_transformer.get_input_sdtype() != sdtype: + transformers_to_update[column] = get_default_transformer(sdtype) if unsupported_sdtypes: raise Error( @@ -153,12 +155,16 @@ def get_field_sdtypes(self): """Return the fields sdtypes.""" return self.field_sdtypes - def __repr__(self): - """Pretty print the dictionary.""" - config = { + def to_dict(self): + """Return a `dict` object of `Config`.""" + return { 'sdtypes': self.field_sdtypes, 'transformers': {k: repr(v) for k, v in self.field_transformers.items()} } + + def __repr__(self): + """Pretty print the dictionary.""" + config = 
self.to_dict() return json.dumps(config, indent=4) @@ -303,7 +309,7 @@ def get_config(self): - sdtypes: A dictionary mapping column names to their ``sdtypes``. - transformers: A dictionary mapping column names to their transformer instances. """ - return self.config.to_dict() + return self.config def set_config(self, config): """Set the ``HyperTransformer`` configuration. @@ -318,6 +324,11 @@ def set_config(self, config): - transformers: A dictionary mapping column names to their transformer instances. """ self.config.set_config(config) + if self._fitted: + warnings.warn( + 'For this change to take effect, please refit your data using ' + "'fit' or 'fit_transform'." + ) def update_transformers_by_sdtype(self, sdtype, transformer): """Update the transformers for the specified ``sdtype``. diff --git a/tests/unit/test_hyper_transformer.py b/tests/unit/test_hyper_transformer.py index fc342c0ed..444082c1d 100644 --- a/tests/unit/test_hyper_transformer.py +++ b/tests/unit/test_hyper_transformer.py @@ -138,9 +138,9 @@ def test_set_config(self): # Assert config._validate_config.assert_called_once_with(config_dict) assert config._provided_field_transformers == config_dict['transformers'] - assert config['field_transformers'] == config_dict['transformers'] + assert config.field_transformers == config_dict['transformers'] assert config._provided_field_sdtypes == config_dict['sdtypes'] - assert config['field_sdtypes'] == config_dict['sdtypes'] + assert config.field_sdtypes == config_dict['sdtypes'] def test___init__(self): """Test the instantiation of ``Config`` with the default values.""" @@ -150,8 +150,8 @@ def test___init__(self): # Assert assert config._provided_field_sdtypes == {} assert config._provided_field_transformers == {} - assert config['field_transformers'] == {} - assert config['field_sdtypes'] == {} + assert config.field_transformers == {} + assert config.field_sdtypes == {} def test___init__with_custom_config(self): """Test the instantiation of ``Config`` with 
the default values. @@ -164,8 +164,8 @@ def test___init__with_custom_config(self): Side Effect: - config._provided_field_sdtypes has been updated with the `sdtypes`. - config._provided_field_transformers has been updated with the `transformers`. - - config['field_sdtypes'] has been updated with the `sdtypes`. - - config['field_transformers'] has been updated with the `transformers`. + - config.field_sdtypes has been updated with the `sdtypes`. + - config.field_transformers has been updated with the `transformers`. """ # Setup transformers = { @@ -187,8 +187,8 @@ def test___init__with_custom_config(self): # Assert assert config._provided_field_sdtypes == sdtypes assert config._provided_field_transformers == transformers - assert config['field_transformers'] == transformers - assert config['field_sdtypes'] == sdtypes + assert config.field_transformers == transformers + assert config.field_sdtypes == sdtypes def test_reset(self): """Test that `reset` resets the `field_sdtypes` and `field_transformers`. @@ -203,8 +203,8 @@ def test_reset(self): - `instance._provided_field_transformers` is a dict containing information. Side Effects: - - config['field_sdtypes'] is set to `config._provided_field_sdtypes`. - - config['field_transformers'] is set to `config._provided_field_transformers`. + - config.field_sdtypes is set to `config._provided_field_sdtypes`. + - config.field_transformers is set to `config._provided_field_transformers`. 
""" # Setup instance = Config() @@ -215,8 +215,8 @@ def test_reset(self): instance.reset() # Assert - assert instance['field_transformers'] == {'my_column': BinaryEncoder} - assert instance['field_sdtypes'] == {'my_column': 'boolean'} + assert instance.field_transformers == {'my_column': BinaryEncoder} + assert instance.field_sdtypes == {'my_column': 'boolean'} class TestHyperTransformer(TestCase): @@ -322,7 +322,7 @@ def test__validate_field_transformers(self): ('integer',): int_transformer } ht = HyperTransformer() - ht.config['field_transformers'] = field_transformers + ht.config.field_transformers = field_transformers # Run / Assert error_msg = ( @@ -393,8 +393,8 @@ def test__create_multi_column_fields(self): each column to its corresponding tuple. Setup: - - instance.config['field_transformers'] will be populated with multi-column fields - - instance.config['field_sdtypes'] will be populated with multi-column fields + - instance.config.field_transformers will be populated with multi-column fields + - instance.config.field_sdtypes will be populated with multi-column fields Output: - A dict mapping each column name that is part of a multi-column @@ -402,13 +402,13 @@ def test__create_multi_column_fields(self): """ # Setup ht = HyperTransformer() - ht.config['field_transformers'] = { + ht.config.field_transformers = { 'a': BinaryEncoder, 'b': UnixTimestampEncoder, ('c', 'd'): UnixTimestampEncoder, 'e': FloatFormatter } - ht.config['field_sdtypes'] = { + ht.config.field_sdtypes = { 'f': 'categorical', ('g', 'h'): 'datetime' } @@ -429,7 +429,7 @@ def test__get_next_transformer_field_transformer(self): """Test the ``_get_next_transformer method. This tests that if the transformer is defined in the - ``instance.config['field_transformers']`` dict, then it is returned + ``instance.config.field_transformers`` dict, then it is returned even if the output sdtype is final. 
Setup: @@ -449,7 +449,7 @@ def test__get_next_transformer_field_transformer(self): # Setup transformer = FloatFormatter() ht = HyperTransformer() - ht.config['field_transformers'] = {'a.out': transformer} + ht.config.field_transformers = {'a.out': transformer} ht._default_sdtype_transformers = {'numerical': GaussianNormalizer()} # Run @@ -462,7 +462,7 @@ def test__get_next_transformer_final_output_sdtype(self): """Test the ``_get_next_transformer method. This tests that if the transformer is not defined in the - ``instance.config['field_transformers']`` dict and its output sdtype + ``instance.config.field_transformers`` dict and its output sdtype is in ``instance._transform_output_sdtypes``, then ``None`` is returned. @@ -492,7 +492,7 @@ def test__get_next_transformer_next_transformers(self): """Test the ``_get_next_transformer method. This tests that if the transformer is not defined in the - ``instance.config['field_transformers']`` dict and its output sdtype + ``instance.config.field_transformers`` dict and its output sdtype is not in ``instance._transform_output_sdtypes`` and the ``next_transformers`` dict has a transformer for the output field, then it is used. @@ -527,7 +527,7 @@ def test__get_next_transformer_default_transformer(self, mock): """Test the ``_get_next_transformer method. This tests that if the transformer is not defined in the - ``instance.config['field_transformers']`` dict or ``next_transformers`` + ``instance.config.field_transformers`` dict or ``next_transformers`` and its output sdtype is not in ``instance._transform_output_sdtypes`` then the default_transformer is used. 
@@ -655,7 +655,7 @@ def test_detect_initial_config(self): # Assert assert ht.config._provided_field_sdtypes == {} assert ht.config._provided_field_transformers == {} - assert ht.config['field_sdtypes'] == { + assert ht.config.field_sdtypes == { 'col1': 'numerical', 'col2': 'categorical', 'col3': 'boolean', @@ -663,7 +663,7 @@ def test_detect_initial_config(self): 'col5': 'numerical' } - field_transformers = {k: repr(v) for (k, v) in ht.config['field_transformers'].items()} + field_transformers = {k: repr(v) for (k, v) in ht.config.field_transformers.items()} assert field_transformers == { 'col1': "FloatFormatter(missing_value_replacement='mean')", 'col2': 'FrequencyEncoder()', @@ -910,11 +910,11 @@ def test_get_config(self): """ # Setup ht = HyperTransformer() - ht.config['field_transformers'] = { + ht.config.field_transformers = { 'column1': FloatFormatter(), 'column2': FrequencyEncoder() } - ht.config['field_sdtypes'] = { + ht.config.field_sdtypes = { 'column1': 'numerical', 'column2': 'categorical' } @@ -925,33 +925,12 @@ def test_get_config(self): # Assert assert config == ht.config - def test_get_config_empty(self): - """Test the ``get_config`` method when the config is empty. - - The method should return a dictionary containing the following keys: - - sdtypes: Maps to a dictionary that maps column names to ``sdtypes``. - - transformers: Maps to a dictionary that maps column names to transformers. - - Output: - - A dictionary with the key sdtypes mapping to an empty dict and the key - transformers mapping to an empty dict. - """ - # Setup - ht = HyperTransformer() - - # Run - config = ht.get_config() - - # Assert - expected_config = Config() - assert config == expected_config - def test_set_config(self): """Test the ``set_config`` method. 
The method should set the ``instance.config._provided_field_sdtypes``, - ``instance.config['field_sdtypes']``, ``instance.config._provided_field_transformers ``and - ``instance.config['field_transformers']`` attributes based on the config. + ``instance.config.field_sdtypes``, ``instance.config._provided_field_transformers`` and + ``instance.config.field_transformers`` attributes based on the config. Setup: - Mock the ``_validate_config`` method so no warnings get raised. @@ -963,9 +942,9 @@ def test_set_config(self): Expected behavior: - The attributes ``instance.config._provided_field_sdtypes``, - ``instance.config['field_sdtypes']``, + ``instance.config.field_sdtypes``, ``instance.config._provided_field_transformers`` - and ``instance.config['field_transformers']`` should be set. + and ``instance.config.field_transformers`` should be set. """ # Setup transformers = { @@ -989,9 +968,9 @@ def test_set_config(self): # Assert ht.config._validate_config.assert_called_once_with(config) assert ht.config._provided_field_transformers == config['transformers'] - assert ht.config['field_transformers'] == config['transformers'] + assert ht.config.field_transformers == config['transformers'] assert ht.config._provided_field_sdtypes == config['sdtypes'] - assert ht.config['field_sdtypes'] == config['sdtypes'] + assert ht.config.field_sdtypes == config['sdtypes'] @patch('rdt.hyper_transformer.warnings') def test_set_config_already_fitted(self, mock_warnings): @@ -1083,7 +1062,7 @@ def test__validate_detect_config_called_incorrect_data(self): """ # Setup ht = HyperTransformer() - ht.config['field_sdtypes'] = {'col1': 'float', 'col2': 'categorical'} + ht.config.field_sdtypes = {'col1': 'float', 'col2': 'categorical'} data = pd.DataFrame({'col1': [1, 2], 'col3': ['a', 'b']}) error_msg = re.escape( 'The data you are trying to fit has different columns than the original ' @@ -1140,7 +1119,7 @@ def test_fit(self, get_default_transformer_mock): 
get_default_transformer_mock.return_value = datetime_transformer ht = HyperTransformer() - ht.config['field_transformers'] = field_transformers + ht.config.field_transformers = field_transformers ht._default_sdtype_transformers = default_sdtype_transformers ht._fit_field_transformer = Mock() ht._fit_field_transformer.return_value = data @@ -1372,7 +1351,7 @@ def test_update_transformers_by_sdtype_no_field_sdtypes(self, mock_print): ht.update_transformers_by_sdtype('categorical', object()) # Assert - assert ht.config['field_transformers'] == {} + assert ht.config.field_transformers == {} @patch('rdt.hyper_transformer.print') def test_update_transformers_by_sdtype_field_sdtypes_not_fitted(self, mock_print): @@ -1392,11 +1371,11 @@ def test_update_transformers_by_sdtype_field_sdtypes_not_fitted(self, mock_print """ # Setup ht = HyperTransformer() - ht.config['field_transformers'] = { + ht.config.field_transformers = { 'categorical_column': 'rdt.transformers.BaseTransformer', 'numerical_column': 'rdt.transformers.FloatFormatter', } - ht.config['field_sdtypes'] = { + ht.config.field_sdtypes = { 'categorical_column': 'categorical', 'numerical_column': 'numerical', @@ -1412,7 +1391,7 @@ def test_update_transformers_by_sdtype_field_sdtypes_not_fitted(self, mock_print 'categorical_column': transformer, 'numerical_column': 'rdt.transformers.FloatFormatter', } - assert ht.config['field_transformers'] == expected_field_transformers + assert ht.config.field_transformers == expected_field_transformers @patch('rdt.hyper_transformer.warnings') @patch('rdt.hyper_transformer.print') @@ -1436,10 +1415,10 @@ def test_update_transformers_by_sdtype_field_sdtypes_fitted(self, mock_print, mo # Setup ht = HyperTransformer() ht._fitted = True - ht.config['field_transformers'] = { + ht.config.field_transformers = { 'categorical_column': 'rdt.transformers.BaseTransformer' } - ht.config['field_sdtypes'] = {'categorical_column': 'categorical'} + ht.config.field_sdtypes = {'categorical_column': 
'categorical'} transformer = object() # Run @@ -1453,7 +1432,7 @@ def test_update_transformers_by_sdtype_field_sdtypes_fitted(self, mock_print, mo mock_print.assert_not_called() mock_warnings.warn.assert_called_once_with(expected_warnings_msg) - assert ht.config['field_transformers'] == {'categorical_column': transformer} + assert ht.config.field_transformers == {'categorical_column': transformer} @patch('rdt.hyper_transformer.warnings') def test_update_transformers_fitted(self, mock_warnings): @@ -1480,7 +1459,7 @@ def test_update_transformers_fitted(self, mock_warnings): # Setup instance = HyperTransformer() instance._fitted = True - instance.config['field_transformers'] = {'a': object()} + instance.config.field_transformers = {'a': object()} mock_transformer = Mock() mock_transformer.get_input_sdtype.return_value = 'datetime' column_name_to_transformer = { @@ -1497,7 +1476,7 @@ def test_update_transformers_fitted(self, mock_warnings): ) mock_warnings.warn.assert_called_once_with(expected_message) - assert instance.config['field_transformers']['my_column'] == mock_transformer + assert instance.config.field_transformers['my_column'] == mock_transformer assert instance.config._provided_field_transformers == {'my_column': mock_transformer} @patch('rdt.hyper_transformer.warnings') @@ -1526,7 +1505,7 @@ def test_update_transformers_not_fitted(self, mock_warnings): # Setup instance = HyperTransformer() instance._fitted = False - instance.config['field_transformers'] = {'a': object()} + instance.config.field_transformers = {'a': object()} mock_transformer = Mock() mock_transformer.get_input_sdtype.return_value = 'datetime' column_name_to_transformer = { @@ -1538,7 +1517,7 @@ def test_update_transformers_not_fitted(self, mock_warnings): # Assert mock_warnings.warn.assert_not_called() - assert instance.config['field_transformers']['my_column'] == mock_transformer + assert instance.config.field_transformers['my_column'] == mock_transformer assert 
instance.config._provided_field_transformers == {'my_column': mock_transformer} def test_update_transformers_no_field_transformers(self): @@ -1607,8 +1586,8 @@ def test_update_transformers_missmatch_sdtypes(self, mock_warnings): instance = HyperTransformer() instance._fitted = False mock_numerical = Mock() - instance.config['field_transformers'] = {'my_column': mock_numerical} - instance.config['field_sdtypes'] = {'my_column': 'categorical'} + instance.config.field_transformers = {'my_column': mock_numerical} + instance.config.field_sdtypes = {'my_column': 'categorical'} mock_transformer = Mock() mock_transformer.get_input_sdtype.return_value = 'datetime' column_name_to_transformer = { @@ -1626,7 +1605,7 @@ def test_update_transformers_missmatch_sdtypes(self, mock_warnings): ) assert mock_warnings.called_once_with(expected_call) - assert instance.config['field_transformers']['my_column'] == mock_transformer + assert instance.config.field_transformers['my_column'] == mock_transformer assert instance.config._provided_field_transformers == {'my_column': mock_transformer} @patch('rdt.hyper_transformer.warnings') @@ -1653,8 +1632,8 @@ def test_update_sdtypes_fitted(self, mock_warnings): """ # Setup instance = HyperTransformer() - instance.config['field_transformers'] = {'a': FrequencyEncoder, 'b': FloatFormatter} - instance.config['field_sdtypes'] = {'a': 'categorical'} + instance.config.field_transformers = {'a': FrequencyEncoder, 'b': FloatFormatter} + instance.config.field_sdtypes = {'a': 'categorical'} instance._fitted = True instance._user_message = Mock() column_name_to_sdtype = { @@ -1675,7 +1654,7 @@ def test_update_sdtypes_fitted(self, mock_warnings): ) mock_warnings.warn.assert_called_once_with(expected_message) - assert instance.config['field_sdtypes'] == {'my_column': 'numerical', 'a': 'categorical'} + assert instance.config.field_sdtypes == {'my_column': 'numerical', 'a': 'categorical'} assert instance.config._provided_field_sdtypes == {'my_column': 
'numerical'} instance._user_message.assert_called_once_with(user_message, 'Info') @@ -1706,7 +1685,7 @@ def test_update_sdtypes_not_fitted(self, mock_warnings): instance = HyperTransformer() instance._fitted = False instance._user_message = Mock() - instance.config['field_sdtypes'] = {'a': 'categorical'} + instance.config.field_sdtypes = {'a': 'categorical'} column_name_to_sdtype = { 'my_column': 'numerical' } @@ -1720,7 +1699,7 @@ def test_update_sdtypes_not_fitted(self, mock_warnings): "Use 'get_config()' to verify the transformers." ) mock_warnings.warn.assert_not_called() - assert instance.config['field_sdtypes'] == {'my_column': 'numerical', 'a': 'categorical'} + assert instance.config.field_sdtypes == {'my_column': 'numerical', 'a': 'categorical'} assert instance.config._provided_field_sdtypes == {'my_column': 'numerical'} instance._user_message.assert_called_once_with(user_message, 'Info') @@ -1742,7 +1721,7 @@ def test_update_sdtypes_no_field_sdtypes(self): # Setup instance = HyperTransformer() instance._fitted = False - instance.config['field_sdtypes'] = {} + instance.config.field_sdtypes = {} column_name_to_sdtype = { 'my_column': 'numerical' } @@ -1775,7 +1754,7 @@ def test_update_sdtypes_invalid_sdtype(self): instance._get_supported_sdtypes = Mock() instance._get_supported_sdtypes.return_value = [] instance._fitted = False - instance.config['field_sdtypes'] = { + instance.config.field_sdtypes = { 'my_column': 'categorical' } column_name_to_sdtype = { @@ -1821,7 +1800,7 @@ def test_update_sdtypes_different_sdtype(self, mock_warnings, default_mock): instance = HyperTransformer() instance._fitted = False instance._user_message = Mock() - instance.config['field_sdtypes'] = {'a': 'categorical'} + instance.config.field_sdtypes = {'a': 'categorical'} transformer_mock = Mock() default_mock.return_value = transformer_mock column_name_to_sdtype = { @@ -1837,8 +1816,8 @@ def test_update_sdtypes_different_sdtype(self, mock_warnings, default_mock): "Use 
'get_config()' to verify the transformers." ) mock_warnings.warn.assert_not_called() - assert instance.config['field_sdtypes'] == {'a': 'numerical'} - assert instance.config['field_transformers'] == {'a': transformer_mock} + assert instance.config.field_sdtypes == {'a': 'numerical'} + assert instance.config.field_transformers == {'a': transformer_mock} assert instance.config._provided_field_sdtypes == {'a': 'numerical'} instance._user_message.assert_called_once_with(user_message, 'Info') @@ -1872,9 +1851,9 @@ def test_update_sdtypes_different_sdtype_transformer_provided(self, mock_warning instance = HyperTransformer() instance._fitted = False instance._user_message = Mock() - instance.field_sdtypes = {'a': 'categorical'} + instance.config.field_sdtypes = {'a': 'categorical'} transformer = FloatFormatter() - instance.field_transformers = {'a': transformer} + instance.config.field_transformers = {'a': transformer} column_name_to_sdtype = { 'a': 'numerical' } @@ -1888,9 +1867,9 @@ def test_update_sdtypes_different_sdtype_transformer_provided(self, mock_warning "Use 'get_config()' to verify the transformers." 
) mock_warnings.warn.assert_not_called() - assert instance.field_sdtypes == {'a': 'numerical'} - assert instance.field_transformers == {'a': transformer} - assert instance._provided_field_sdtypes == {'a': 'numerical'} + assert instance.config.field_sdtypes == {'a': 'numerical'} + assert instance.config.field_transformers == {'a': transformer} + assert instance.config._provided_field_sdtypes == {'a': 'numerical'} instance._user_message.assert_called_once_with(user_message, 'Info') def test_get_transformer(self): From 2b30d0af20d7999148626947177a8bdeb89c58b4 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Fri, 8 Apr 2022 18:39:54 +0200 Subject: [PATCH 6/8] Fix unit tests and coverage --- rdt/hyper_transformer.py | 8 ----- tests/unit/test_hyper_transformer.py | 46 ++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/rdt/hyper_transformer.py b/rdt/hyper_transformer.py index 79ed8c05b..0e634bcea 100644 --- a/rdt/hyper_transformer.py +++ b/rdt/hyper_transformer.py @@ -147,14 +147,6 @@ def update_transformers(self, column_name_to_transformer, provided_transformer=F if provided_transformer: self._provided_field_transformers[column_name] = transformer - def get_field_transformers(self): - """Return the fields transformers.""" - return self.field_transformers - - def get_field_sdtypes(self): - """Return the fields sdtypes.""" - return self.field_sdtypes - def to_dict(self): """Return a `dict` object of `Config`.""" return { diff --git a/tests/unit/test_hyper_transformer.py b/tests/unit/test_hyper_transformer.py index 444082c1d..a6deb5b89 100644 --- a/tests/unit/test_hyper_transformer.py +++ b/tests/unit/test_hyper_transformer.py @@ -385,6 +385,52 @@ def test__unfit(self): assert ht._transformers_tree == {} ht.config.reset.assert_called_once() + @patch('rdt.hyper_transformer.get_default_transformer') + def test__learn_config(self, mock_get_default_transformer): + """Test the ``_learn_config`` method. 
+ + Test that the method properly learns the config. + + - Setup: + - Create instance of HyperTransformer. + - Set some field_sdtypes to the config. + - Set a default `numerical` transformer. + + - Input: + - Fields, an array of fields. + + - Mock: + - Mock the `config` class. + - Patch `get_default_transformer` to return a transformer and assert calling it. + + - Side Effects: + - `instance.config.update_transformers` has been called three times with + the expected transformers. + """ + # Setup + ht = HyperTransformer() + ht._default_sdtype_transformers = {'numerical': FloatFormatter} + ht.config = Mock() + ht.config.field_transformers = {} + ht.config.field_sdtypes = { + 'a': 'numerical', + 'b': 'numerical', + 'c': 'other' + } + fields = ['a', 'b', 'c'] + mock_get_default_transformer.return_value = BinaryEncoder + + # Run + ht._learn_config(fields) + + # Assert + ht.config.update_transformers.call_args_list == [ + call('a', FloatFormatter), + call('b', FloatFormatter), + call('c', BinaryEncoder) + ] + mock_get_default_transformer.assert_called_once_with('other') + def test__create_multi_column_fields(self): """Test the ``_create_multi_column_fields`` method. 
From 2802267576b5d0083e4fa5f80e50e16ad363f4b5 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Fri, 8 Apr 2022 20:52:09 +0200 Subject: [PATCH 7/8] Fix printable version of config --- rdt/hyper_transformer.py | 18 +++++++++++------- rdt/transformers/base.py | 2 +- tests/unit/test_hyper_transformer.py | 20 ++++++++++---------- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/rdt/hyper_transformer.py b/rdt/hyper_transformer.py index 0e634bcea..0710e1376 100644 --- a/rdt/hyper_transformer.py +++ b/rdt/hyper_transformer.py @@ -1,6 +1,7 @@ """Hyper transformer module.""" import json +import re import warnings from collections import defaultdict from copy import deepcopy @@ -147,17 +148,20 @@ def update_transformers(self, column_name_to_transformer, provided_transformer=F if provided_transformer: self._provided_field_transformers[column_name] = transformer - def to_dict(self): - """Return a `dict` object of `Config`.""" - return { + def __repr__(self): + """Pretty print the dictionary.""" + config = { 'sdtypes': self.field_sdtypes, 'transformers': {k: repr(v) for k, v in self.field_transformers.items()} } - def __repr__(self): - """Pretty print the dictionary.""" - config = self.to_dict() - return json.dumps(config, indent=4) + printed = json.dumps(config, indent=4) + for transformer in self.field_transformers.values(): + quoted_transformer = f'"{transformer}"' + if quoted_transformer in printed: + printed = printed.replace(quoted_transformer, repr(transformer)) + + return printed class HyperTransformer: diff --git a/rdt/transformers/base.py b/rdt/transformers/base.py index 0480fa1c5..215214abb 100644 --- a/rdt/transformers/base.py +++ b/rdt/transformers/base.py @@ -186,7 +186,7 @@ def __repr__(self): for arg, value in instanced.items(): if defaults[arg] != value: - value = f"'{value}'" if isinstance(value, str) else value + value = repr(value) custom_args.append(f'{arg}={value}') args_string = ', '.join(custom_args) diff --git 
a/tests/unit/test_hyper_transformer.py b/tests/unit/test_hyper_transformer.py index a6deb5b89..2f0685beb 100644 --- a/tests/unit/test_hyper_transformer.py +++ b/tests/unit/test_hyper_transformer.py @@ -711,11 +711,11 @@ def test_detect_initial_config(self): field_transformers = {k: repr(v) for (k, v) in ht.config.field_transformers.items()} assert field_transformers == { - 'col1': "FloatFormatter(missing_value_replacement='mean')", - 'col2': 'FrequencyEncoder()', - 'col3': "BinaryEncoder(missing_value_replacement='mode')", - 'col4': "UnixTimestampEncoder(missing_value_replacement='mean')", - 'col5': "FloatFormatter(missing_value_replacement='mean')" + 'col1': repr(FloatFormatter(missing_value_replacement='mean')), + 'col2': repr(FrequencyEncoder()), + 'col3': repr(BinaryEncoder(missing_value_replacement='mode')), + 'col4': repr(UnixTimestampEncoder(missing_value_replacement='mean')), + 'col5': repr(FloatFormatter(missing_value_replacement='mean')) } expected_output = '\n'.join(( @@ -731,11 +731,11 @@ def test_detect_initial_config(self): ' "col5": "numerical"', ' },', ' "transformers": {', - ' "col1": "FloatFormatter(missing_value_replacement=\'mean\')",', - ' "col2": "FrequencyEncoder()",', - ' "col3": "BinaryEncoder(missing_value_replacement=\'mode\')",', - ' "col4": "UnixTimestampEncoder(missing_value_replacement=\'mean\')",', - ' "col5": "FloatFormatter(missing_value_replacement=\'mean\')"', + ' "col1": FloatFormatter(missing_value_replacement=\'mean\'),', + ' "col2": FrequencyEncoder(),', + ' "col3": BinaryEncoder(missing_value_replacement=\'mode\'),', + ' "col4": UnixTimestampEncoder(missing_value_replacement=\'mean\'),', + ' "col5": FloatFormatter(missing_value_replacement=\'mean\')', ' }', '}', '' From 155359e6d822427e3404b35ee94a90d2c63466f0 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Fri, 8 Apr 2022 20:55:24 +0200 Subject: [PATCH 8/8] Fix lint --- rdt/hyper_transformer.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/rdt/hyper_transformer.py b/rdt/hyper_transformer.py index 0710e1376..b69ca3ac1 100644 --- a/rdt/hyper_transformer.py +++ b/rdt/hyper_transformer.py @@ -1,7 +1,6 @@ """Hyper transformer module.""" import json -import re import warnings from collections import defaultdict from copy import deepcopy