
Add tests
fealho committed Feb 7, 2025
1 parent 4be1c4c commit 52ef546
Showing 6 changed files with 200 additions and 33 deletions.
35 changes: 21 additions & 14 deletions sdv/metadata/metadata.py
@@ -62,15 +62,10 @@ def load_from_dict(cls, metadata_dict, single_table_name=None):
return instance

@staticmethod
def _validate_infer_sdtypes_and_keys(infer_sdtypes, infer_keys):
def _validate_infer_sdtypes(infer_sdtypes):
if not isinstance(infer_sdtypes, bool):
raise ValueError("'infer_sdtypes' must be a boolean value.")

if infer_keys not in ['primary_and_foreign', 'primary_only', None]:
raise ValueError(
"'infer_keys' must be one of: 'primary_and_foreign', 'primary_only', None."
)

@classmethod
def detect_from_dataframes(cls, data, infer_sdtypes=True, infer_keys='primary_and_foreign'):
"""Detect the metadata for all tables in a dictionary of dataframes.
@@ -100,19 +95,29 @@ def detect_from_dataframes(cls, data, infer_sdtypes=True, infer_keys='primary_an
"""
if not data or not all(isinstance(df, pd.DataFrame) for df in data.values()):
raise ValueError('The provided dictionary must contain only pandas DataFrame objects.')
cls._validate_detect_from_dataframes(infer_sdtypes, infer_keys)
if infer_keys not in ['primary_and_foreign', 'primary_only', None]:
raise ValueError(
"'infer_keys' must be one of: 'primary_and_foreign', 'primary_only', None."
)
cls._validate_infer_sdtypes(infer_sdtypes)

metadata = Metadata()
for table_name, dataframe in data.items():
metadata.detect_table_from_dataframe(table_name, dataframe, infer_sdtypes, infer_keys)
metadata.detect_table_from_dataframe(
table_name,
dataframe,
infer_sdtypes,
None if infer_keys is None else 'primary_only'
)

if infer_keys == 'primary_and_foreign':
metadata._detect_relationships(data)

return metadata
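
For context, a minimal usage sketch of the multi-table entry point after this change. This is not part of the commit; the import follows the sdv/metadata/metadata.py module path shown above, and the toy tables and column names are invented for illustration.

import pandas as pd

from sdv.metadata.metadata import Metadata

# Hypothetical toy tables; real data would have more rows and columns.
data = {
    'hotels': pd.DataFrame({'hotel_id': [1, 2, 3]}),
    'guests': pd.DataFrame({'guest_id': [10, 11, 12], 'hotel_id': [1, 1, 2]}),
}

# Default: infer sdtypes plus primary and foreign keys.
metadata = Metadata.detect_from_dataframes(data)

# Infer primary keys per table but skip relationship (foreign key) detection.
metadata = Metadata.detect_from_dataframes(data, infer_keys='primary_only')

# Skip both sdtype inference and key inference.
metadata = Metadata.detect_from_dataframes(data, infer_sdtypes=False, infer_keys=None)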

@classmethod
def detect_from_dataframe(cls, data, table_name=DEFAULT_SINGLE_TABLE_NAME, infer_sdtypes=True, infer_keys='primary_and_foreign'):
def detect_from_dataframe(cls, data, table_name=DEFAULT_SINGLE_TABLE_NAME,
infer_sdtypes=True, infer_keys='primary_only'):
"""Detect the metadata for a DataFrame.
This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``.
@@ -128,22 +133,24 @@ def detect_from_dataframe(cls, data, table_name=DEFAULT_SINGLE_TABLE_NAME, infer
Defaults to True.
infer_keys (str):
A string describing whether to infer the primary and/or foreign keys. Options are:
- 'primary_and_foreign': Infer the primary keys in each table,
and the foreign keys in other tables that refer to them
- 'primary_only': Infer only the primary keys of each table
- None: Do not infer any keys
Defaults to 'primary_and_foreign'.
Defaults to 'primary_only'.
Returns:
Metadata:
A new metadata object with the sdtypes detected from the data.
"""
if not isinstance(data, pd.DataFrame):
raise ValueError('The provided data must be a pandas DataFrame object.')
cls._validate_infer_sdtypes_and_keys(infer_sdtypes, infer_keys)
if infer_keys not in ['primary_only', None]:
raise ValueError(
"'infer_keys' must be one of: 'primary_only', None."
)
cls._validate_infer_sdtypes(infer_sdtypes)

metadata = Metadata()
metadata.detect_table_from_dataframe(table_name, data)
metadata.detect_table_from_dataframe(table_name, data, infer_sdtypes, infer_keys)
return metadata

def _set_metadata_dict(self, metadata, single_table_name=None):
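A similar sketch for the single-table entry point (not part of the commit). Per the validation added above, detect_from_dataframe accepts only 'primary_only' or None for infer_keys, and 'primary_only' is now its default.

import pandas as pd

from sdv.metadata.metadata import Metadata

# Hypothetical single table.
guests = pd.DataFrame({'guest_id': [1, 2, 3], 'amount_paid': [10.0, 12.5, 9.9]})

# Default: infer sdtypes and the primary key.
metadata = Metadata.detect_from_dataframe(guests, table_name='guests')

# Detect sdtypes only, without inferring a primary key.
metadata = Metadata.detect_from_dataframe(guests, table_name='guests', infer_keys=None)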
26 changes: 24 additions & 2 deletions sdv/metadata/multi_table.py
@@ -530,7 +530,19 @@ def _detect_relationships(self, data=None):
)
continue

def detect_table_from_dataframe(self, table_name, data):
@staticmethod
def _validate_infer_sdtypes_and_keys(infer_sdtypes, infer_keys):
if not isinstance(infer_sdtypes, bool):
raise ValueError("'infer_sdtypes' must be a boolean value.")

if infer_keys not in ['primary_only', None]:
raise ValueError(
"'infer_keys' must be one of: 'primary_only', None."
)

def detect_table_from_dataframe(
self, table_name, data, infer_sdtypes=True, infer_keys='primary_only'
):
"""Detect the metadata for a table from a dataframe.
This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``,
@@ -541,10 +553,20 @@ def detect_table_from_dataframe(self, table_name, data):
Name of the table to detect.
data (pandas.DataFrame):
``pandas.DataFrame`` to detect the metadata from.
infer_sdtypes (bool):
A boolean describing whether to infer the sdtypes of each column.
If True it infers the sdtypes based on the data.
If False it does not infer the sdtypes and all columns are marked as unknown.
Defaults to True.
infer_keys (str):
A string describing whether to infer the primary and/or foreign keys. Options are:
- 'primary_only': Infer only the primary keys of each table
- None: Do not infer any keys
Defaults to 'primary_only'.
"""
self._validate_table_not_detected(table_name)
table = SingleTableMetadata()
table._detect_columns(data, table_name)
table._detect_columns(data, table_name, infer_sdtypes, infer_keys)
self.tables[table_name] = table
self._log_detected_table(table)

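A short sketch (not from the commit) of the per-table method that the higher-level detectors delegate to, assuming a metadata instance is used as in the metadata.py diff above; the table name and data are invented.

import pandas as pd

from sdv.metadata.metadata import Metadata

metadata = Metadata()
metadata.detect_table_from_dataframe(
    'sessions',                               # hypothetical table name
    pd.DataFrame({'session_id': [1, 2, 3]}),  # hypothetical data
    infer_sdtypes=True,
    infer_keys='primary_only',
)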
23 changes: 11 additions & 12 deletions sdv/metadata/single_table.py
@@ -595,7 +595,7 @@ def _detect_primary_key(self, data):

return None

def _detect_columns(self, data, table_name=None, infer_sdtypes=True, infer_keys=True):
def _detect_columns(self, data, table_name=None, infer_sdtypes=True, infer_keys='primary_only'):
"""Detect the columns' sdtypes from the data.
Args:
@@ -609,11 +609,10 @@ def _detect_columns(self, data, table_name=None, infer_sdtypes=True, infer_keys=
If False it does not infer the sdtypes and all columns are marked as unknown.
Defaults to True.
infer_keys (str):
A string describing whether to infer the primary and/or foreign keys. Options are:
- 'primary_and_foreign': Infer the primary keys in each table
- 'primary_only': Same as 'primary_and_foreign', infer only the primary keys
- None: Do not infer any keys
Defaults to 'primary_and_foreign'.
A string describing whether to infer the primary keys. Options are:
- 'primary_only': Infer the primary keys.
- None: Do not infer any keys.
Defaults to 'primary_only'.
"""
old_columns = data.columns
data.columns = data.columns.astype(str)
@@ -637,9 +636,9 @@ def _detect_columns(self, data, table_name=None, infer_sdtypes=True, infer_keys=
if sdtype is None:
table_str = f"table '{table_name}' " if table_name else ''
error_message = (
f"Unsupported data type for {table_str}column '{field}' (kind: {dtype}"
"). The valid data types are: 'object', 'int', 'float', 'datetime',"
" 'bool'."
f"Unsupported data type for {table_str}column '{field}' "
f"(kind: {dtype}). The valid data types are: 'object', "
"'int', 'float', 'datetime', 'bool'."
)
raise InvalidMetadataError(error_message)

@@ -650,8 +649,8 @@ def _detect_columns(self, data, table_name=None, infer_sdtypes=True, infer_keys=

table_str = f"table '{table_name}' " if table_name else ''
error_message = (
f"Unable to detect metadata for {table_str}column '{field}' due to an invalid "
f'data format.\n {error_type}: {e}'
f"Unable to detect metadata for {table_str}column '{field}' due "
f'to an invalid data format.\n {error_type}: {e}'
)
raise InvalidMetadataError(error_message) from e

@@ -670,7 +669,7 @@ def _detect_columns(self, data, table_name=None, infer_sdtypes=True, infer_keys=

self.columns[field] = deepcopy(column_dict)

if infer_keys:
if infer_keys == 'primary_only':
self.primary_key = self._detect_primary_key(data)
self._updated = True
data.columns = old_columns
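To summarize the control flow this hunk adds to _detect_columns, here is a standalone sketch. It is not the SDV implementation: the helper stub and the placeholder sdtype logic are invented, and only illustrate how the two flags gate the work independently.

import pandas as pd

def _detect_primary_key_stub(data):
    # Hypothetical stand-in for SingleTableMetadata._detect_primary_key.
    return None

def detect_columns_sketch(data, infer_sdtypes=True, infer_keys='primary_only'):
    columns = {}
    for field in data.columns:
        if infer_sdtypes:
            # Placeholder for the real per-column inference.
            sdtype = 'numerical' if pd.api.types.is_numeric_dtype(data[field]) else 'unknown'
        else:
            # With infer_sdtypes=False, every column is marked as unknown.
            sdtype = 'unknown'
        columns[str(field)] = {'sdtype': sdtype}

    # Keys are only inferred when infer_keys is 'primary_only'; None skips detection.
    primary_key = _detect_primary_key_stub(data) if infer_keys == 'primary_only' else None
    return columns, primary_key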
103 changes: 99 additions & 4 deletions tests/unit/metadata/test_metadata.py
@@ -594,15 +594,62 @@ def test_detect_from_dataframes(self, mock_metadata):

# Assert
mock_metadata.return_value.detect_table_from_dataframe.assert_any_call(
'guests', guests_table
'guests', guests_table, True, 'primary_only'
)
mock_metadata.return_value.detect_table_from_dataframe.assert_any_call(
'hotels', hotels_table
'hotels', hotels_table, True, 'primary_only'
)
mock_metadata.return_value._detect_relationships.assert_called_once_with(data)
assert metadata == mock_metadata.return_value

def test_detect_from_dataframes_bad_input(self):
@patch('sdv.metadata.metadata.Metadata')
def test_detect_from_dataframes_infer_keys_none(self, mock_metadata):
"""Test ``detect_from_dataframes`` with infer_keys set to None."""
# Setup
mock_metadata.detect_table_from_dataframe = Mock()
mock_metadata._detect_relationships = Mock()
guests_table = pd.DataFrame()
hotels_table = pd.DataFrame()
data = {'guests': guests_table, 'hotels': hotels_table}

# Run
metadata = Metadata.detect_from_dataframes(data, infer_sdtypes=False, infer_keys=None)

# Assert
mock_metadata.return_value.detect_table_from_dataframe.assert_any_call(
'guests', guests_table, False, None
)
mock_metadata.return_value.detect_table_from_dataframe.assert_any_call(
'hotels', hotels_table, False, None
)
mock_metadata.return_value._detect_relationships.assert_not_called()
assert metadata == mock_metadata.return_value

@patch('sdv.metadata.metadata.Metadata')
def test_detect_from_dataframes_infer_keys_primary_only(self, mock_metadata):
"""Test ``detect_from_dataframes`` with infer_keys set to 'primary_only'."""
# Setup
mock_metadata.detect_table_from_dataframe = Mock()
mock_metadata._detect_relationships = Mock()
guests_table = pd.DataFrame()
hotels_table = pd.DataFrame()
data = {'guests': guests_table, 'hotels': hotels_table}

# Run
metadata = Metadata.detect_from_dataframes(
data, infer_sdtypes=False, infer_keys='primary_only')

# Assert
mock_metadata.return_value.detect_table_from_dataframe.assert_any_call(
'guests', guests_table, False, 'primary_only'
)
mock_metadata.return_value.detect_table_from_dataframe.assert_any_call(
'hotels', hotels_table, False, 'primary_only'
)
mock_metadata.return_value._detect_relationships.assert_not_called()
assert metadata == mock_metadata.return_value

def test_detect_from_dataframes_bad_input_data(self):
"""Test that an error is raised if the dictionary contains something other than DataFrames.
If the data contains values that aren't pandas.DataFrames, it should error.
@@ -615,6 +662,30 @@ def test_detect_from_dataframes_bad_input(self):
with pytest.raises(ValueError, match=expected_message):
Metadata.detect_from_dataframes(data)

def test_detect_from_dataframes_bad_input_infer_sdtypes(self):
"""Test that an error is raised if the infer_sdtypes is not a boolean."""
# Setup
data = {'guests': pd.DataFrame(), 'hotels': pd.DataFrame()}
infer_sdtypes = 'not_a_boolean'

# Run and Assert
expected_message = "'infer_sdtypes' must be a boolean value."
with pytest.raises(ValueError, match=expected_message):
Metadata.detect_from_dataframes(data, infer_sdtypes=infer_sdtypes)

def test_detect_from_dataframes_bad_input_infer_keys(self):
"""Test that an error is raised if the infer_keys is not a correct string."""
# Setup
data = {'guests': pd.DataFrame(), 'hotels': pd.DataFrame()}
infer_keys = 'incorrect_string'

# Run and Assert
expected_message = re.escape(
"'infer_keys' must be one of: 'primary_and_foreign', 'primary_only', None."
)
with pytest.raises(ValueError, match=expected_message):
Metadata.detect_from_dataframes(data, infer_keys=infer_keys)

@patch('sdv.metadata.metadata.Metadata')
def test_detect_from_dataframe(self, mock_metadata):
"""Test that the method calls the detection method and returns the metadata.
@@ -630,7 +701,7 @@ def test_detect_from_dataframe(self, mock_metadata):

# Assert
mock_metadata.return_value.detect_table_from_dataframe.assert_any_call(
Metadata.DEFAULT_SINGLE_TABLE_NAME, DataFrameMatcher(data)
Metadata.DEFAULT_SINGLE_TABLE_NAME, DataFrameMatcher(data), True, 'primary_only'
)
assert metadata == mock_metadata.return_value

@@ -641,6 +712,30 @@ def test_detect_from_dataframe_raises_error_if_not_dataframe(self):
with pytest.raises(ValueError, match=expected_message):
Metadata.detect_from_dataframe(Mock())

def test_detect_from_dataframe_bad_input_infer_sdtypes(self):
"""Test that an error is raised if the infer_sdtypes is not a boolean."""
# Setup
data = pd.DataFrame()
infer_sdtypes = 'not_a_boolean'

# Run and Assert
expected_message = "'infer_sdtypes' must be a boolean value."
with pytest.raises(ValueError, match=expected_message):
Metadata.detect_from_dataframe(data, infer_sdtypes=infer_sdtypes)

def test_detect_from_dataframe_bad_input_infer_keys(self):
"""Test that an error is raised if the infer_keys is not a correct string."""
# Setup
data = pd.DataFrame()
infer_keys = 'primary_and_foreign'

# Run and Assert
expected_message = re.escape(
"'infer_keys' must be one of: 'primary_only', None."
)
with pytest.raises(ValueError, match=expected_message):
Metadata.detect_from_dataframe(data, infer_keys=infer_keys)

def test__handle_table_name(self):
"""Test the ``_handle_table_name`` method."""
# Setup
3 changes: 2 additions & 1 deletion tests/unit/metadata/test_multi_table.py
@@ -2499,7 +2499,8 @@ def test_detect_table_from_dataframe(self, single_table_mock, log_mock):
metadata.detect_table_from_dataframe('table', data)

# Assert
single_table_mock.return_value._detect_columns.assert_called_once_with(data, 'table')
single_table_mock.return_value._detect_columns.assert_called_once_with(
data, 'table', True, 'primary_only')
assert metadata.tables == {'table': single_table_mock.return_value}

expected_log_calls = call(
43 changes: 43 additions & 0 deletions tests/unit/metadata/test_single_table.py
@@ -1320,6 +1320,49 @@ def test__detect_columns_invalid_data_format(self):
with pytest.raises(InvalidMetadataError, match=expected_error_message):
instance._detect_columns(data)

def test__detect_columns_without_infer_sdtypes(self):
"""Test the _detect_columns when infer_sdtypes is False."""
# Setup
instance = SingleTableMetadata()
data = pd.DataFrame({
'id': ['id1', 'id2', 'id3', 'id4', 'id5', 'id6', 'id7', 'id8', 'id9', 'id10', 'id11'],
'numerical': [1, 2, 3, 2, 5, 6, 7, 8, 9, 10, 11],
'datetime': [
'2022-01-01', '2022-02-01', '2022-03-01', '2022-04-01', '2022-05-01', '2022-06-01',
'2022-07-01', '2022-08-01', '2022-09-01', '2022-10-01', '2022-11-01'
],
'alternate_id': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
'alternate_id_string': ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10'],
'categorical': ['a', 'b', 'a', 'a', 'b', 'b', 'a', 'b', 'a', 'b', 'a'],
'bool': [True, False, True, False, True, False, True, False, True, False, True],
'unknown': ['a', 'b', 'c', 'c', 1, 2.2, np.nan, None, 'd', 'e', 'f'],
'first_name': [
'John', 'Jane', 'John', 'Jane', 'John', 'Jane',
'John', 'Jane', 'John', 'Jane', 'John'
],
})

# Run
instance._detect_columns(data, infer_sdtypes=False)

# Assert
assert instance.columns['id']['sdtype'] == 'unknown'
assert instance.columns['numerical']['sdtype'] == 'unknown'
assert instance.columns['datetime']['sdtype'] == 'unknown'
assert instance.columns['alternate_id']['sdtype'] == 'unknown'
assert instance.columns['alternate_id']['pii'] is True
assert instance.columns['alternate_id_string']['sdtype'] == 'unknown'
assert instance.columns['alternate_id_string']['pii'] is True
assert instance.columns['categorical']['sdtype'] == 'unknown'
assert instance.columns['unknown']['sdtype'] == 'unknown'
assert instance.columns['unknown']['pii'] is True
assert instance.columns['bool']['sdtype'] == 'unknown'
assert instance.columns['first_name']['sdtype'] == 'unknown'
assert instance.columns['first_name']['pii'] is True

assert instance.primary_key is None
assert instance._updated is True

def test__detect_primary_key_missing_sdtypes(self):
"""The method should raise an error if not all sdtypes were detected."""
# Setup