diff --git a/rdt/hyper_transformer.py b/rdt/hyper_transformer.py index 9a0ee88c..32b1bc22 100644 --- a/rdt/hyper_transformer.py +++ b/rdt/hyper_transformer.py @@ -154,7 +154,20 @@ def get_config(self): }) def _update_multi_column_config(self, column_name_to_transformer): + """Update the multi column fields in the config. + Review all the existing multi column fields and update them if necessary. + + Args: + column_name_to_transformer (dict): + Dict mapping column names to transformers to be used for that column. + + Returns: + dict: + Dict mapping all the multi column fields to their transformers. + Include unchanged multi column fields, new multi column fields and + updated multi column fields. + """ column_names = [ item for key in column_name_to_transformer.keys() for item in (key if isinstance(key, tuple) else (key,)) @@ -203,6 +216,17 @@ def _update_multi_column_config(self, column_name_to_transformer): return multi_columns_to_transformer def _validate_multi_column_transformers(self, column_name_to_transformer): + """Validate the given multi column transformers are valid. + + Update the ``column_name_to_transformer`` dict to include changes required + by the multi column transformers. If a multi column transformer is no longer + valid according to its columns names, then it will be replaced by the default + transformer for the sdtype of the columns. + + Args: + column_name_to_transformer (dict): + Dict mapping column names to transformers to be used for that column. + """ multi_columns_to_transformer = self._update_multi_column_config(column_name_to_transformer) for columns in list(multi_columns_to_transformer.keys()): transformer = multi_columns_to_transformer[columns] @@ -223,6 +247,17 @@ def _validate_multi_column_transformers(self, column_name_to_transformer): return column_name_to_transformer def _update_multi_column_transformers(self, column_name_to_transformer): + """Update the transformer field for multi column fields. + + Args: + column_name_to_transformer (dict): + Dict mapping column names to transformers to be used for that column. + + Returns: + dict: + Updated ``column_name_to_transformer`` dict with everything to update + multi and single column fields. + """ column_name_to_transformer = self._validate_multi_column_transformers( column_name_to_transformer ) @@ -242,6 +277,12 @@ def _update_multi_column_transformers(self, column_name_to_transformer): def _update_single_column_transformers(self, column_name_to_transformer): + """Update the transformer field for single column fields. + + Args: + column_name_to_transformer (dict): + Dict mapping column names to transformers to be used for that column. + """ for column, transformer in column_name_to_transformer.items(): if isinstance(column, tuple): continue