Skip to content

Commit

Permalink
Add infer_sdtypes and infer_keys parameters to `detect_from_dataframes` method (#2363)
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho authored Feb 11, 2025
1 parent 363d8bd commit c1c5a19
Show file tree
Hide file tree
Showing 7 changed files with 581 additions and 84 deletions.
55 changes: 50 additions & 5 deletions sdv/metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,13 @@ def load_from_dict(cls, metadata_dict, single_table_name=None):
instance._set_metadata_dict(metadata_dict, single_table_name)
return instance

@staticmethod
def _validate_infer_sdtypes(infer_sdtypes):
if not isinstance(infer_sdtypes, bool):
raise ValueError("'infer_sdtypes' must be a boolean value.")

@classmethod
def detect_from_dataframes(cls, data):
def detect_from_dataframes(cls, data, infer_sdtypes=True, infer_keys='primary_and_foreign'):
"""Detect the metadata for all tables in a dictionary of dataframes.
This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrames``.
Expand All @@ -71,23 +76,50 @@ def detect_from_dataframes(cls, data):
Args:
data (dict):
Dictionary of table names to dataframes.
infer_sdtypes (bool):
A boolean describing whether to infer the sdtypes of each column.
If True it infers the sdtypes based on the data.
If False it does not infer the sdtypes and all columns are marked as unknown.
Defaults to True.
infer_keys (str):
A string describing whether to infer the primary and/or foreign keys. Options are:
- 'primary_and_foreign': Infer the primary keys in each table,
and the foreign keys in other tables that refer to them
- 'primary_only': Infer only the primary keys of each table
- None: Do not infer any keys
Defaults to 'primary_and_foreign'.
Returns:
Metadata:
A new metadata object with the sdtypes detected from the data.
"""
if not data or not all(isinstance(df, pd.DataFrame) for df in data.values()):
raise ValueError('The provided dictionary must contain only pandas DataFrame objects.')
if infer_keys not in ['primary_and_foreign', 'primary_only', None]:
raise ValueError(
"'infer_keys' must be one of: 'primary_and_foreign', 'primary_only', None."
)
cls._validate_infer_sdtypes(infer_sdtypes)

metadata = Metadata()
for table_name, dataframe in data.items():
metadata.detect_table_from_dataframe(table_name, dataframe)
metadata.detect_table_from_dataframe(
table_name, dataframe, infer_sdtypes, None if infer_keys is None else 'primary_only'
)

if infer_keys == 'primary_and_foreign':
metadata._detect_relationships(data)

metadata._detect_relationships(data)
return metadata

@classmethod
def detect_from_dataframe(cls, data, table_name=DEFAULT_SINGLE_TABLE_NAME):
def detect_from_dataframe(
cls,
data,
table_name=DEFAULT_SINGLE_TABLE_NAME,
infer_sdtypes=True,
infer_keys='primary_only',
):
"""Detect the metadata for a DataFrame.
This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``.
Expand All @@ -96,16 +128,29 @@ def detect_from_dataframe(cls, data, table_name=DEFAULT_SINGLE_TABLE_NAME):
Args:
data (pandas.DataFrame):
The dataframe to detect the metadata from.
infer_sdtypes (bool):
A boolean describing whether to infer the sdtypes of each column.
If True it infers the sdtypes based on the data.
If False it does not infer the sdtypes and all columns are marked as unknown.
Defaults to True.
infer_keys (str):
A string describing whether to infer the primary keys. Options are:
- 'primary_only': Infer only the primary keys of each table
- None: Do not infer any keys
Defaults to 'primary_only'.
Returns:
Metadata:
A new metadata object with the sdtypes detected from the data.
"""
if not isinstance(data, pd.DataFrame):
raise ValueError('The provided data must be a pandas DataFrame object.')
if infer_keys not in ['primary_only', None]:
raise ValueError("'infer_keys' must be one of: 'primary_only', None.")
cls._validate_infer_sdtypes(infer_sdtypes)

metadata = Metadata()
metadata.detect_table_from_dataframe(table_name, data)
metadata.detect_table_from_dataframe(table_name, data, infer_sdtypes, infer_keys)
return metadata

def _set_metadata_dict(self, metadata, single_table_name=None):
Expand Down
16 changes: 14 additions & 2 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,9 @@ def _detect_relationships(self, data=None):
)
continue

def detect_table_from_dataframe(self, table_name, data):
def detect_table_from_dataframe(
self, table_name, data, infer_sdtypes=True, infer_keys='primary_only'
):
"""Detect the metadata for a table from a dataframe.
This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``,
Expand All @@ -541,10 +543,20 @@ def detect_table_from_dataframe(self, table_name, data):
Name of the table to detect.
data (pandas.DataFrame):
``pandas.DataFrame`` to detect the metadata from.
infer_sdtypes (bool):
A boolean describing whether to infer the sdtypes of each column.
If True it infers the sdtypes based on the data.
If False it does not infer the sdtypes and all columns are marked as unknown.
Defaults to True.
infer_keys (str):
A string describing whether to infer the primary keys. Options are:
- 'primary_only': Infer only the primary keys of each table
- None: Do not infer any keys
Defaults to 'primary_only'.
"""
self._validate_table_not_detected(table_name)
table = SingleTableMetadata()
table._detect_columns(data, table_name)
table._detect_columns(data, table_name, infer_sdtypes, infer_keys)
self.tables[table_name] = table
self._log_detected_table(table)

Expand Down
85 changes: 50 additions & 35 deletions sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,53 +595,67 @@ def _detect_primary_key(self, data):

return None

def _detect_columns(self, data, table_name=None):
def _detect_columns(self, data, table_name=None, infer_sdtypes=True, infer_keys='primary_only'):
"""Detect the columns' sdtypes from the data.
Args:
data (pandas.DataFrame):
The data to be analyzed.
table_name (str):
The name of the table to be analyzed. Defaults to ``None``.
infer_sdtypes (bool):
A boolean describing whether to infer the sdtypes of each column.
If True it infers the sdtypes based on the data.
If False it does not infer the sdtypes and all columns are marked as unknown.
Defaults to True.
infer_keys (str):
A string describing whether to infer the primary keys. Options are:
- 'primary_only': Infer the primary keys.
- None: Do not infer any keys.
Defaults to 'primary_only'.
"""
old_columns = data.columns
data.columns = data.columns.astype(str)
for field in data:
try:
column_data = data[field]
clean_data = column_data.dropna()
dtype = clean_data.infer_objects().dtype.kind

sdtype = self._detect_pii_column(field)
if sdtype is None:
if dtype in self._DTYPES_TO_SDTYPES:
sdtype = self._DTYPES_TO_SDTYPES[dtype]
elif dtype in ['i', 'f', 'u']:
sdtype = self._determine_sdtype_for_numbers(column_data)

elif dtype == 'O':
sdtype = self._determine_sdtype_for_objects(column_data)
if infer_sdtypes:
try:
column_data = data[field]
clean_data = column_data.dropna()
dtype = clean_data.infer_objects().dtype.kind

sdtype = self._detect_pii_column(field)
if sdtype is None:
table_str = f"table '{table_name}' " if table_name else ''
error_message = (
f"Unsupported data type for {table_str}column '{field}' (kind: {dtype}"
"). The valid data types are: 'object', 'int', 'float', 'datetime',"
" 'bool'."
)
raise InvalidMetadataError(error_message)

except Exception as e:
error_type = type(e).__name__
if error_type == 'InvalidMetadataError':
raise e

table_str = f"table '{table_name}' " if table_name else ''
error_message = (
f"Unable to detect metadata for {table_str}column '{field}' due to an invalid "
f'data format.\n {error_type}: {e}'
)
raise InvalidMetadataError(error_message) from e
if dtype in self._DTYPES_TO_SDTYPES:
sdtype = self._DTYPES_TO_SDTYPES[dtype]
elif dtype in ['i', 'f', 'u']:
sdtype = self._determine_sdtype_for_numbers(column_data)

elif dtype == 'O':
sdtype = self._determine_sdtype_for_objects(column_data)

if sdtype is None:
table_str = f"table '{table_name}' " if table_name else ''
error_message = (
f"Unsupported data type for {table_str}column '{field}' "
f"(kind: {dtype}). The valid data types are: 'object', "
"'int', 'float', 'datetime', 'bool'."
)
raise InvalidMetadataError(error_message)

except Exception as e:
error_type = type(e).__name__
if error_type == 'InvalidMetadataError':
raise e

table_str = f"table '{table_name}' " if table_name else ''
error_message = (
f"Unable to detect metadata for {table_str}column '{field}' due "
f'to an invalid data format.\n {error_type}: {e}'
)
raise InvalidMetadataError(error_message) from e

else:
sdtype = 'unknown'

column_dict = {'sdtype': sdtype}
sdtype_in_reference = sdtype in self._REFERENCE_TO_SDTYPE.values()
Expand All @@ -655,7 +669,8 @@ def _detect_columns(self, data, table_name=None):

self.columns[field] = deepcopy(column_dict)

self.primary_key = self._detect_primary_key(data)
if infer_keys == 'primary_only':
self.primary_key = self._detect_primary_key(data)
self._updated = True
data.columns = old_columns

Expand Down
Loading

0 comments on commit c1c5a19

Please sign in to comment.