Make synthesizers work with column_relationships (#1727)
R-Palazzo authored Jan 8, 2024
1 parent 64e8df2 commit 6659835
Showing 12 changed files with 326 additions and 359 deletions.
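For orientation, a hedged sketch of the end-to-end flow this commit enables: a column relationship declared in the metadata is now picked up by the synthesizer's data processor automatically, with no separate `set_address_columns` call. The data, column names, and locale below are illustrative; the 'address' relationship type still requires the SDV Enterprise address add-on, and without it SDV is expected to warn and skip the relationship rather than error.

```python
# Illustrative sketch only (hypothetical data and column names). Without the
# SDV Enterprise address add-on, adding the relationship should emit a warning
# and the relationship is skipped rather than raising an error.
import pandas as pd

from sdv.metadata import SingleTableMetadata
from sdv.single_table import GaussianCopulaSynthesizer

data = pd.DataFrame({
    'city': ['Boston', 'Seattle', 'Austin', 'Denver'],
    'state': ['MA', 'WA', 'TX', 'CO'],
    'age': [34, 51, 28, 45],
})

metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data)
metadata.update_column('city', sdtype='city')
metadata.update_column('state', sdtype='administrative_unit')

# Declare the multi-column relationship directly on the metadata.
metadata.add_column_relationship(
    relationship_type='address',
    column_names=['city', 'state'],
)

# The DataProcessor now detects the relationship on its own when the
# synthesizer is created, instead of requiring set_address_columns().
synthesizer = GaussianCopulaSynthesizer(metadata, locales=['en_US'])
synthesizer.fit(data)
synthetic = synthesizer.sample(num_rows=5)
```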
82 changes: 35 additions & 47 deletions sdv/data_processing/data_processor.py
@@ -22,6 +22,7 @@
from sdv.data_processing.utils import load_module_from_path
from sdv.errors import SynthesizerInputError, log_exc_stacktrace
from sdv.metadata.single_table import SingleTableMetadata
from sdv.metadata.validation import _check_import_address_transformers

LOGGER = logging.getLogger(__name__)

@@ -66,6 +67,10 @@ class DataProcessor:
'M': 'datetime',
}

_COLUMN_RELATIONSHIP_TO_TRANSFORMER = {
'address': 'RandomLocationGenerator',
}

def _update_numerical_transformer(self, enforce_rounding, enforce_min_max_values):
custom_float_formatter = rdt.transformers.FloatFormatter(
missing_value_replacement='mean',
@@ -75,6 +80,26 @@ def _update_numerical_transformer(self, enforce_rounding, enforce_min_max_values
)
self._transformers_by_sdtype.update({'numerical': custom_float_formatter})

def _detect_multi_column_transformers(self):
"""Detect if there are any multi column transformers in the metadata.
Returns:
dict:
A dictionary mapping column names to the multi column transformer.
"""
result = {}
if self.metadata.column_relationships:
for relationship in self.metadata._valid_column_relationships:
column_names = tuple(relationship['column_names'])
relationship_type = relationship['type']
if relationship_type in self._COLUMN_RELATIONSHIP_TO_TRANSFORMER:
transformer_name = self._COLUMN_RELATIONSHIP_TO_TRANSFORMER[relationship_type]
module = getattr(rdt.transformers, relationship_type)
transformer = getattr(module, transformer_name)
result[column_names] = transformer(locales=self._locales)

return result

def __init__(self, metadata, enforce_rounding=True, enforce_min_max_values=True,
model_kwargs=None, table_name=None, locales=None):
self.metadata = metadata
@@ -90,7 +115,7 @@ def __init__(self, metadata, enforce_rounding=True, enforce_min_max_values=True,
self._transformers_by_sdtype = deepcopy(get_default_transformers())
self._transformers_by_sdtype['id'] = rdt.transformers.RegexGenerator()
del self._transformers_by_sdtype['text']
self.grouped_columns_to_transformers = {}
self.grouped_columns_to_transformers = self._detect_multi_column_transformers()

self._update_numerical_transformer(enforce_rounding, enforce_min_max_values)
self._hyper_transformer = rdt.HyperTransformer()
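A rough, hypothetical sketch of the lookup that `_detect_multi_column_transformers` performs and that the constructor now uses to seed `grouped_columns_to_transformers`: the relationship type names both the `rdt.transformers` submodule and, through the class-level mapping, the transformer class inside it. The `hasattr` guard below is added only so the sketch runs on a plain RDT install (the real method instead relies on the metadata's already-validated relationships); `rdt.transformers.address` ships only with the SDV Enterprise add-on.

```python
# Hypothetical, simplified version of the detection step: resolve the
# relationship type to a transformer class and key the result by the column
# tuple. The hasattr guard is an addition for this sketch; a plain rdt install
# falls through to the empty mapping.
import rdt

_COLUMN_RELATIONSHIP_TO_TRANSFORMER = {'address': 'RandomLocationGenerator'}

relationship = {'type': 'address', 'column_names': ['city', 'state']}
relationship_type = relationship['type']
grouped = {}

known_type = relationship_type in _COLUMN_RELATIONSHIP_TO_TRANSFORMER
if known_type and hasattr(rdt.transformers, relationship_type):
    # e.g. rdt.transformers.address (enterprise add-on only)
    module = getattr(rdt.transformers, relationship_type)
    transformer_class = getattr(module, _COLUMN_RELATIONSHIP_TO_TRANSFORMER[relationship_type])
    grouped[tuple(relationship['column_names'])] = transformer_class(locales=['en_US'])

print(grouped)  # {} without the add-on; {('city', 'state'): RandomLocationGenerator(...)} with it
```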
@@ -115,15 +140,6 @@ def _get_grouped_columns(self):
col for col_tuple in self.grouped_columns_to_transformers for col in col_tuple
]

def _check_import_address_transformers(self):
"""Check that the address transformers can be imported."""
has_randomlocationgenerator = hasattr(rdt.transformers, 'RandomLocationGenerator')
has_regionalanonymizer = hasattr(rdt.transformers, 'RegionalAnonymizer')
if not has_randomlocationgenerator or not has_regionalanonymizer:
raise ImportError(
'You must have SDV Enterprise with the address add-on to use the address features'
)

def _get_columns_in_address_transformer(self):
"""Get the columns that are part of an address transformer.
@@ -132,14 +148,14 @@ def _get_columns_in_address_transformer(self):
A list of columns that are part of the address transformers.
"""
try:
self._check_import_address_transformers()
_check_import_address_transformers()
result = []
for col_tuple, transformer in self.grouped_columns_to_transformers.items():
is_randomlocationgenerator = isinstance(
transformer, rdt.transformers.RandomLocationGenerator
transformer, rdt.transformers.address.RandomLocationGenerator
)
is_regionalanonymizer = isinstance(
transformer, rdt.transformers.RegionalAnonymizer
transformer, rdt.transformers.address.RegionalAnonymizer
)
if is_randomlocationgenerator or is_regionalanonymizer:
result.extend(list(col_tuple))
@@ -148,40 +164,6 @@ def _get_columns_in_address_transformer(self):
except ImportError:
return []

def _get_address_transformer(self, anonymization_level):
"""Get the address transformer.
Args:
anonymization_level (str):
The anonymization level for the address transformer.
"""
locales = self._locales if self._locales else ['en_US']
self._check_import_address_transformers()
if anonymization_level == 'street_address':
return rdt.transformers.RegionalAnonymizer(locales=locales)

return rdt.transformers.RandomLocationGenerator(locales=locales)

def set_address_transformer(self, column_names, anonymization_level):
"""Set the address transformer.
Args:
column_names (tuple[str]):
The column names to set the transformer for.
anonymization_level (str):
The anonymization level for the address transformer.
"""
columns_to_sdtypes = {
column: self.metadata.columns[column]['sdtype'] for column in column_names
}
transformer = self._get_address_transformer(anonymization_level)
transformer._validate_sdtypes(columns_to_sdtypes)

if self._prepared_for_fitting:
self.update_transformers({column_names: transformer})

self.grouped_columns_to_transformers[column_names] = transformer

def get_model_kwargs(self, model_name):
"""Return the required model kwargs for the indicated model.
@@ -672,6 +654,12 @@ def update_transformers(self, column_name_to_transformer):
warnings.filterwarnings('ignore', module='rdt.hyper_transformer')
self._hyper_transformer.update_transformers(column_name_to_transformer)

self.grouped_columns_to_transformers = {
col_tuple: transformer
for col_tuple, transformer in self._hyper_transformer.field_transformers.items()
if isinstance(col_tuple, tuple)
}
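A minimal sketch of the idea behind the addition above: after `update_transformers`, multi-column transformers live under tuple keys while single-column ones use plain string keys, so the grouped mapping can be rebuilt by filtering on key type. The `field_transformers` dict below is a stand-in, not a live `HyperTransformer` attribute.

```python
# Stand-in dict mimicking the shape of the hyper transformer's field mapping
# after update_transformers(): string keys for single columns, tuple keys for
# column groups handled by one multi-column transformer.
field_transformers = {
    'age': 'FloatFormatter instance',
    ('city', 'state'): 'RandomLocationGenerator instance',
}

grouped_columns_to_transformers = {
    col_tuple: transformer
    for col_tuple, transformer in field_transformers.items()
    if isinstance(col_tuple, tuple)
}

print(grouped_columns_to_transformers)
# {('city', 'state'): 'RandomLocationGenerator instance'}
```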

def _fit_hyper_transformer(self, data):
"""Create and return a new ``rdt.HyperTransformer`` instance.
6 changes: 3 additions & 3 deletions sdv/metadata/multi_table.py
@@ -585,14 +585,14 @@ def _validate_column_relationships_foreign_keys(
f'Cannot use foreign keys {invalid_columns} in column relationship.'
)

def add_column_relationship(self, relationship_type, table_name, column_names):
def add_column_relationship(self, table_name, relationship_type, column_names):
"""Add a column relationship to a table in the metadata.
Args:
relationship_type (str):
The type of the relationship.
table_name (str):
The name of the table to add this relationship to.
relationship_type (str):
The type of the relationship.
column_names (list[str]):
The list of column names involved in this relationship.
"""
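For the multi-table API, a hedged usage sketch of the reordered signature: `table_name` now comes first, which matches other per-table methods such as `update_column`. The table, columns, and data are made up; as above, the 'address' type needs the enterprise add-on to take effect.

```python
# Hypothetical multi-table usage of the reordered signature: table_name is now
# the first argument, followed by relationship_type and column_names.
import pandas as pd

from sdv.metadata import MultiTableMetadata

users = pd.DataFrame({
    'user_id': [1, 2],
    'city': ['Boston', 'Seattle'],
    'state': ['MA', 'WA'],
})

metadata = MultiTableMetadata()
metadata.detect_table_from_dataframe('users', users)
metadata.update_column('users', 'city', sdtype='city')
metadata.update_column('users', 'state', sdtype='administrative_unit')

metadata.add_column_relationship(
    table_name='users',
    relationship_type='address',
    column_names=['city', 'state'],
)
```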
51 changes: 33 additions & 18 deletions sdv/metadata/single_table.py
@@ -615,22 +615,22 @@ def _append_error(self, errors, method, *args, **kwargs):
except InvalidMetadataError as e:
errors.append(e)

def _validate_column_relationship(self, relationship_type, column_names):
def _validate_column_relationship(self, relationship):
"""Validate a column relationship.
Verify that a column relationship has a valid relationship type, has
columns that are present in the metadata, and that those columns have
valid sdtypes for the relationship type.
Args:
relationship_type (str):
Type of column relationship.
column_names (list[str]):
List of column names in this column relationship.
relationship (dict):
Column relationship to validate.
Raises:
- ``InvalidMetadataError`` if relationship is invalid
"""
relationship_type = relationship['type']
column_names = relationship['column_names']
if relationship_type not in self._COLUMN_RELATIONSHIP_TYPES:
raise InvalidMetadataError(
f"Unknown column relationship type '{relationship_type}'. "
@@ -645,8 +645,22 @@ def _validate_column_relationship(self, relationship_type, column_names):
errors.append(
f"Cannot use primary key '{column}' in column relationship."
)

columns_to_sdtypes = {
column: self.columns.get(column, {}).get('sdtype') for column in column_names
}
try:
self._COLUMN_RELATIONSHIP_TYPES[relationship_type](self.columns, column_names)
self._COLUMN_RELATIONSHIP_TYPES[relationship_type](columns_to_sdtypes)

except ImportError:
warnings.warn(
f"The metadata contains a column relationship of type '{relationship_type}'. "
f'which requires the {relationship_type} add-on.'
'This relationship will be ignored. For higher quality data in this'
' relationship, please inquire about the SDV Enterprise tier.'
)
raise ImportError

except Exception as e:
errors.append(str(e))
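A simplified, hypothetical restatement of the flow added in this hunk, for readers skimming the diff: build a column-to-sdtype map, run the type-specific validator, and translate a missing add-on (`ImportError`) into a user-facing warning before re-raising so the caller can drop that relationship from the valid set.

```python
# Simplified sketch with hypothetical names; mirrors the validation flow above
# rather than reproducing SDV's exact implementation.
import warnings


def validate_relationship(relationship, columns, validators):
    relationship_type = relationship['type']
    columns_to_sdtypes = {
        column: columns.get(column, {}).get('sdtype')
        for column in relationship['column_names']
    }
    try:
        # e.g. validators == {'address': validate_address_sdtypes}
        validators[relationship_type](columns_to_sdtypes)
    except ImportError:
        warnings.warn(
            f"Column relationship of type '{relationship_type}' requires an "
            'add-on and will be ignored.'
        )
        raise  # let the caller remove it from the list of valid relationships
```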

@@ -688,15 +702,16 @@ def _validate_all_column_relationships(self, column_relationships):

# Validate each individual relationship
errors = []
for relationship in column_relationships:
relationship_type = relationship['type']
columns = relationship['column_names']
self._append_error(
errors,
self._validate_column_relationship,
relationship_type,
columns
)
self._valid_column_relationships = deepcopy(column_relationships)
for idx, relationship in enumerate(column_relationships):
try:
self._append_error(
errors,
self._validate_column_relationship,
relationship,
)
except ImportError:
self._valid_column_relationships.pop(idx)

if errors:
raise InvalidMetadataError(
@@ -713,10 +728,10 @@ def add_column_relationship(self, relationship_type, column_names):
column_names (list[str]):
List of column names in the relationship.
"""
to_check = [{'type': relationship_type, 'column_names': column_names}] + \
self.column_relationships
relationship = {'type': relationship_type, 'column_names': column_names}
to_check = [relationship] + self.column_relationships
self._validate_all_column_relationships(to_check)
self.column_relationships.append({'type': relationship_type, 'column_names': column_names})
self.column_relationships.append(relationship)
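A small sketch of the validate-before-mutate pattern kept here: the candidate relationship is checked together with the relationships that already exist, and only appended once that combined validation passes. All names below are hypothetical.

```python
# Hypothetical illustration of validating the new relationship alongside the
# existing ones before mutating the list.
def add_relationship(existing_relationships, relationship_type, column_names, validate_all):
    relationship = {'type': relationship_type, 'column_names': column_names}
    validate_all([relationship] + existing_relationships)  # raises if anything is invalid
    existing_relationships.append(relationship)


relationships = []
add_relationship(relationships, 'address', ['city', 'state'], validate_all=lambda rels: None)
print(relationships)  # [{'type': 'address', 'column_names': ['city', 'state']}]
```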

def validate(self):
"""Validate the metadata.
45 changes: 24 additions & 21 deletions sdv/metadata/validation.py
@@ -1,32 +1,35 @@
"""Column relationship validation functions."""
import rdt
from rdt.errors import TransformerInputError

from sdv.metadata.errors import InvalidMetadataError


def validate_address_sdtypes(column_metadata, column_names):
def _check_import_address_transformers():
"""Check that the address transformers can be imported."""
error_message = (
'You must have SDV Enterprise with the address add-on to use the address features'
)
if not hasattr(rdt.transformers, 'address'):
raise ImportError(error_message)

has_randomlocationgenerator = hasattr(rdt.transformers.address, 'RandomLocationGenerator')
has_regionalanonymizer = hasattr(rdt.transformers.address, 'RegionalAnonymizer')
if not has_randomlocationgenerator or not has_regionalanonymizer:
raise ImportError(error_message)


def validate_address_sdtypes(columns_to_sdtypes):
"""Validate sdtypes for address column relationship.
Args:
column_metadata (dict):
Column metadata for the table.
column_names (list[str]):
List of the column names involved in this relationship.
columns_to_sdtypes (dict):
    Dictionary mapping column names to sdtypes.
Raises:
- ``InvalidMetadataError`` if column sdtypes are invalid for the relationship.
"""
valid_sdtypes = (
'country_code', 'administrative_unit', 'city', 'postcode', 'street_address',
'secondary_address', 'state', 'state_abbr'
)
bad_columns = []
for column_name in column_names:
if column_name not in column_metadata:
continue
if column_metadata[column_name].get('sdtype') not in valid_sdtypes:
bad_columns.append(column_name)

if bad_columns:
raise InvalidMetadataError(
f'Columns {bad_columns} have unsupported sdtypes for column relationship '
"type 'address'."
)
_check_import_address_transformers()
try:
rdt.transformers.address.RandomLocationGenerator._validate_sdtypes(columns_to_sdtypes)
except TransformerInputError as error:
raise InvalidMetadataError(str(error))
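An illustrative call shape for the rewritten validator: it now takes a plain column-to-sdtype mapping and delegates the sdtype check to the add-on transformer. Column names and sdtypes below are made up; on a plain SDV/RDT install the import guard raises `ImportError` before any sdtype check runs.

```python
# Illustrative only: exercise the new validator signature. Expect ImportError
# without the address add-on; with it, an unsupported sdtype such as
# 'numerical' should surface as InvalidMetadataError instead.
from sdv.metadata.errors import InvalidMetadataError
from sdv.metadata.validation import validate_address_sdtypes

columns_to_sdtypes = {
    'city': 'city',
    'state': 'administrative_unit',
    'age': 'numerical',  # not an address sdtype
}

try:
    validate_address_sdtypes(columns_to_sdtypes)
except ImportError:
    print('Address add-on not installed; the relationship would be skipped.')
except InvalidMetadataError as error:
    print(f'Invalid address relationship: {error}')
```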
62 changes: 5 additions & 57 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
@@ -92,64 +92,12 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True,
self._fitted_date = None
self._fitted_sdv_version = None

def _check_address_columns(self, column_names):
"""Check that the column is valid for the address transformer.
Args:
column_names (tuple[str]):
The column name to check.
existing_address (set):
The set of columns that are already being used in a set of address columns.
"""
missing_columns_metadata = []
columns_in_multi_columns = []
multi_columns = self._data_processor._get_grouped_columns()

for column in column_names:
if column not in self.metadata.columns:
missing_columns_metadata.append(column)

if column in multi_columns:
columns_in_multi_columns.append(column)

if missing_columns_metadata:
to_print = "', '".join(missing_columns_metadata)
raise ValueError(
f"Unknown column names ('{to_print}'). Please choose column names listed"
' in the metadata for your table.'
)

if columns_in_multi_columns:
to_print = "', '".join(columns_in_multi_columns)
raise ValueError(
f"Columns '{to_print}' are already being used in a multi-column transformer."
)

def set_address_columns(self, column_names, anonymization_level='full'):
"""Set the address multi-column transformer.
Args:
column_names (tuple[str]):
The column names to be used for the address transformer.
anonymization_level (str):
The anonymization level to use for the address transformer.
"""
if anonymization_level not in {'full', 'street_address'}:
raise ValueError(
f"Invalid value '{anonymization_level}' for parameter 'anonymization_level'."
" Please provide 'full' or 'street_address'."
)

if not isinstance(column_names, tuple):
column_names = tuple(column_names) if len(column_names) > 1 else (column_names,)

self._check_address_columns(column_names)
self._data_processor.set_address_transformer(column_names, anonymization_level)
if self._fitted:
warnings.warn(
'Please refit your synthesizer for the address changes to appear in'
' your synthetic data.'
)
"""Set the address multi-column transformer."""
warnings.warn(
'`set_address_columns` is deprecated. Please add these columns directly to your'
' metadata using `add_column_relationship`.', DeprecationWarning
)
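A hedged migration sketch for callers of the now-deprecated method: the address columns are declared on the metadata via `add_column_relationship` before the synthesizer is built, instead of being configured on the synthesizer afterwards. Column names and sdtypes are illustrative.

```python
# Migration sketch (illustrative names). Previously:
#     synthesizer.set_address_columns(('city', 'state'), anonymization_level='full')
# Now the relationship is declared on the metadata up front.
from sdv.metadata import SingleTableMetadata
from sdv.single_table import GaussianCopulaSynthesizer

metadata = SingleTableMetadata()
metadata.add_column('city', sdtype='city')
metadata.add_column('state', sdtype='administrative_unit')
metadata.add_column('age', sdtype='numerical')

metadata.add_column_relationship(
    relationship_type='address',
    column_names=['city', 'state'],
)

synthesizer = GaussianCopulaSynthesizer(metadata)  # picks up the relationship itself
```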

def _validate_metadata(self, data):
"""Validate that the data follows the metadata."""