Skip to content

Commit

Permalink
Fix api
Browse files Browse the repository at this point in the history
  • Loading branch information
pvk-developer committed Apr 8, 2022
1 parent d017a06 commit 33200f9
Showing 1 changed file with 57 additions and 42 deletions.
99 changes: 57 additions & 42 deletions rdt/hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
get_default_transformer, get_transformer_instance, get_transformers_by_type)


class Config(dict):
"""Config dict for ``HyperTransformer`` with a better representation."""
class Config:
"""Config class for ``HyperTransformer``."""

_provided_field_transformers = None
_provided_field_sdtypes = None
field_sdtypes = None
field_transformers = None

@staticmethod
def _validate_config(config):
Expand Down Expand Up @@ -50,25 +52,24 @@ def set_config(self, config):
"""
self._validate_config(config)
self._provided_field_sdtypes = config['sdtypes']
self['field_sdtypes'].update(config['sdtypes'])
self.field_sdtypes.update(config['sdtypes'])
self._provided_field_transformers = config['transformers']
self['field_transformers'].update(config['transformers'])
self.field_transformers.update(config['transformers'])

def __init__(self, config=None):
super().__init__()
self._provided_field_sdtypes = {}
self._provided_field_transformers = {}
self['field_transformers'] = {}
self['field_sdtypes'] = {}
self.field_transformers = {}
self.field_sdtypes = {}
if config:
self.set_config(config)

def reset(self):
"""Reset the `field_sdtypes` and `field_transformers`."""
self['field_sdtypes'] = self._provided_field_sdtypes.copy()
self['field_transformers'] = self._provided_field_transformers.copy()
self.field_sdtypes = self._provided_field_sdtypes.copy()
self.field_transformers = self._provided_field_transformers.copy()

def update_sdtypes(self, column_name_to_sdtype):
def update_sdtypes(self, column_name_to_sdtype, provided_sdtype=False):
"""Update the ``sdtypes`` for each specified column name.
Args:
Expand All @@ -81,9 +82,7 @@ def update_sdtypes(self, column_name_to_sdtype):
if sdtype not in self._get_supported_sdtypes():
unsupported_sdtypes.append(sdtype)
elif self.field_sdtypes.get(column) != sdtype:
current_transformer = self.field_transformers.get(column)
if not current_transformer or current_transformer.get_input_sdtype() != sdtype:
transformers_to_update[column] = get_default_transformer(sdtype)
transformers_to_update[column] = get_default_transformer(sdtype)

if unsupported_sdtypes:
raise Error(
Expand All @@ -94,7 +93,8 @@ def update_sdtypes(self, column_name_to_sdtype):

self.field_sdtypes.update(column_name_to_sdtype)
self.field_transformers.update(transformers_to_update)
self._provided_field_sdtypes.update(column_name_to_sdtype)
if provided_sdtype:
self._provided_field_sdtypes.update(column_name_to_sdtype)

def update_transformers_by_sdtype(self, sdtype, transformer):
"""Update the transformers for the specified ``sdtype``.
Expand All @@ -108,41 +108,56 @@ def update_transformers_by_sdtype(self, sdtype, transformer):
transformer (rdt.transformers.BaseTransformer):
Transformer class or instance to be used for the given ``sdtype``.
"""
if not self['field_sdtypes']:
if not self.field_sdtypes:
raise Error(
'Nothing to update. Use the `detect_initial_config` method to '
'pre-populate all the sdtypes and transformers from your dataset.'
)

for field, field_sdtype in self['field_sdtypes'].items():
for field, field_sdtype in self.field_sdtypes.items():
if field_sdtype == sdtype:
self._provided_field_transformers[field] = transformer
self['field_transformers'][field] = transformer
self.field_transformers[field] = transformer

def update_transformers(self, column_name_to_transformer):
"""Update any of the transformers assigned to each of the column names.
def update_transformers(self, column_name_to_transformer, provided_transformer=False):
"""Update `self.field_transformers`.
Update `self.field_transformers` with the provided column names and transformers. If
`provided_transformer` is `True`, the same entries are also stored in
`self._provided_field_transformers`, so they are kept when the config is reset.
Args:
column_name_to_transformer(dict):
Dict mapping column names to transformers to be used for that column.
provided_transformer(bool):
Whether or not to add to `self._provided_field_transformers`.
"""
for column_name, transformer in column_name_to_transformer.items():
current_sdtype = self['field_sdtypes'].get(column_name)
current_sdtype = self.field_sdtypes.get(column_name)
if current_sdtype and current_sdtype != transformer.get_input_sdtype():
warnings.warn(
f'You are assigning a {transformer.get_input_sdtype()} transformer '
f'to a {current_sdtype} column ({column_name}). '
"If the transformer doesn't match the sdtype, it may lead to errors."
)

self['field_transformers'][column_name] = transformer
self._provided_field_transformers[column_name] = transformer
self.field_transformers[column_name] = transformer
if provided_transformer:
self._provided_field_transformers[column_name] = transformer

def get_field_transformers(self):
"""Return the fields transformers."""
return self.field_transformers

def get_field_sdtypes(self):
"""Return the fields sdtypes."""
return self.field_sdtypes

def __repr__(self):
"""Pretty print the dictionary."""
config = {
'sdtypes': self['field_sdtypes'],
'transformers': {k: repr(v) for k, v in self['field_transformers'].items()}
'sdtypes': self.field_sdtypes,
'transformers': {k: repr(v) for k, v in self.field_transformers.items()}
}
return json.dumps(config, indent=4)

Expand Down Expand Up @@ -244,15 +259,15 @@ def _subset(input_list, other_list, not_in=False):

def _create_multi_column_fields(self):
multi_column_fields = {}
for field in list(self.config['field_sdtypes']) + list(self.config['field_transformers']):
for field in list(self.config.field_sdtypes) + list(self.config.field_transformers):
if isinstance(field, tuple):
for column in field:
multi_column_fields[column] = field

return multi_column_fields

def _validate_field_transformers(self):
for field in self.config['field_transformers']:
for field in self.config.field_transformers:
if self._field_in_set(field, self._specified_fields):
raise ValueError(f'Multiple transformers specified for the field {field}. '
'Each field can have at most one transformer defined in '
Expand Down Expand Up @@ -288,7 +303,7 @@ def get_config(self):
- sdtypes: A dictionary mapping column names to their ``sdtypes``.
- transformers: A dictionary mapping column names to their transformer instances.
"""
return self.config
return self.config.to_dict()

def set_config(self, config):
"""Set the ``HyperTransformer`` configuration.
Expand Down Expand Up @@ -330,10 +345,10 @@ def update_sdtypes(self, column_name_to_sdtype):
column_name_to_sdtype(dict):
Dict mapping column names to ``sdtypes`` for that column.
"""
if len(self.config['field_sdtypes']) == 0:
if len(self.config.field_sdtypes) == 0:
raise Error(self._DETECT_CONFIG_MESSAGE)

self.config.update_sdtypes(column_name_to_sdtype)
self.config.update_sdtypes(column_name_to_sdtype, provided_sdtype=True)
self._user_message(
'The transformers for these columns may change based on the new sdtype.\n'
"Use 'get_config()' to verify the transformers.", 'Info'
Expand All @@ -350,10 +365,10 @@ def update_transformers(self, column_name_to_transformer):
"""
if self._fitted:
warnings.warn(self._REFIT_MESSAGE)
if len(self.config['field_transformers']) == 0:
if len(self.config.field_transformers) == 0:
raise Error(self._DETECT_CONFIG_MESSAGE)

self.config.update_transformers(column_name_to_transformer)
self.config.update_transformers(column_name_to_transformer, provided_transformer=True)

def get_transformer(self, field):
"""Get the transformer instance used for a field.
Expand Down Expand Up @@ -456,7 +471,7 @@ def get_transformer_tree_yaml(self):
def _set_field_sdtype(self, data, field):
clean_data = data[field].dropna()
kind = clean_data.infer_objects().dtype.kind
self.config['field_sdtypes'][field] = self._DTYPES_TO_SDTYPES[kind]
self.config.update_sdtypes({field: self._DTYPES_TO_SDTYPES[kind]})

def _unfit(self):
self.config.reset()
Expand All @@ -471,15 +486,15 @@ def _learn_config(self, data):
"""Unfit the HyperTransformer and learn the sdtypes and transformers of the data."""
self._unfit()
for field in data:
if field not in self.config['field_sdtypes']:
if field not in self.config.field_sdtypes:
self._set_field_sdtype(data, field)
if field not in self.config['field_transformers']:
sdtype = self.config['field_sdtypes'][field]
if field not in self.config.field_transformers:
sdtype = self.config.field_sdtypes.get(field)
if sdtype in self._default_sdtype_transformers:
default_transformer = self._default_sdtype_transformers[sdtype]
self.config['field_transformers'][field] = default_transformer
self.config.update_transformers({field: default_transformer})
else:
self.config['field_transformers'][field] = get_default_transformer(sdtype)
self.config.update_transformers({field: get_default_transformer(sdtype)})

def detect_initial_config(self, data):
"""Print the configuration of the data.
Expand Down Expand Up @@ -508,8 +523,8 @@ def detect_initial_config(self, data):

def _get_next_transformer(self, output_field, output_sdtype, next_transformers):
next_transformer = None
if output_field in self.config['field_transformers']:
next_transformer = self.config['field_transformers'][output_field]
if output_field in self.config.field_transformers:
next_transformer = self.config.field_transformers.get(output_field)

elif output_sdtype not in self._valid_output_sdtypes:
if next_transformers is not None and output_field in next_transformers:
Expand Down Expand Up @@ -572,14 +587,14 @@ def _sort_output_columns(self):

def _validate_detect_config_called(self, data):
"""Assert the ``detect_initial_config`` method is correctly called before fitting."""
if len(self.config['field_sdtypes']) == 0 and len(self.config['field_transformers']) == 0:
if len(self.config.field_sdtypes) == 0 and len(self.config.field_transformers) == 0:
raise NotFittedError(
"No config detected. Set the config using 'set_config' or pre-populate "
"it automatically from your data using 'detect_initial_config' prior to "
'fitting your data.'
)

fields = list(self.config['field_sdtypes'].keys())
fields = list(self.config.field_sdtypes.keys())
unknown_columns = self._subset(data.columns, fields, not_in=True)
if unknown_columns:
raise NotFittedError(
Expand All @@ -601,7 +616,7 @@ def fit(self, data):
self._input_columns = list(data.columns)
for field in self._input_columns:
data = self._fit_field_transformer(
data, field, self.config['field_transformers'][field])
data, field, self.config.field_transformers.get(field))

self._validate_all_fields_fitted()
self._fitted = True
Expand Down

0 comments on commit 33200f9

Please sign in to comment.