From c1c5a190b7087bd40c88594b7ef7b0f79040a761 Mon Sep 17 00:00:00 2001 From: Felipe Alex Hofmann Date: Tue, 11 Feb 2025 09:43:46 -0800 Subject: [PATCH] Add `infer_sdtypes` and `infer_keys` parameters to `detect_from_dataframes` method (#2363) --- sdv/metadata/metadata.py | 55 +++- sdv/metadata/multi_table.py | 16 +- sdv/metadata/single_table.py | 85 +++--- tests/integration/metadata/test_metadata.py | 286 ++++++++++++++++++++ tests/unit/metadata/test_metadata.py | 102 ++++++- tests/unit/metadata/test_multi_table.py | 4 +- tests/unit/metadata/test_single_table.py | 117 +++++--- 7 files changed, 581 insertions(+), 84 deletions(-) diff --git a/sdv/metadata/metadata.py b/sdv/metadata/metadata.py index 02437384f..4b0af7215 100644 --- a/sdv/metadata/metadata.py +++ b/sdv/metadata/metadata.py @@ -61,8 +61,13 @@ def load_from_dict(cls, metadata_dict, single_table_name=None): instance._set_metadata_dict(metadata_dict, single_table_name) return instance + @staticmethod + def _validate_infer_sdtypes(infer_sdtypes): + if not isinstance(infer_sdtypes, bool): + raise ValueError("'infer_sdtypes' must be a boolean value.") + @classmethod - def detect_from_dataframes(cls, data): + def detect_from_dataframes(cls, data, infer_sdtypes=True, infer_keys='primary_and_foreign'): """Detect the metadata for all tables in a dictionary of dataframes. This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrames``. @@ -71,6 +76,18 @@ def detect_from_dataframes(cls, data): Args: data (dict): Dictionary of table names to dataframes. + infer_sdtypes (bool): + A boolean describing whether to infer the sdtypes of each column. + If True it infers the sdtypes based on the data. + If False it does not infer the sdtypes and all columns are marked as unknown. + Defaults to True. + infer_keys (str): + A string describing whether to infer the primary and/or foreign keys. 
Options are: + - 'primary_and_foreign': Infer the primary keys in each table, + and the foreign keys in other tables that refer to them + - 'primary_only': Infer only the primary keys of each table + - None: Do not infer any keys + Defaults to 'primary_and_foreign'. Returns: Metadata: @@ -78,16 +95,31 @@ def detect_from_dataframes(cls, data): """ if not data or not all(isinstance(df, pd.DataFrame) for df in data.values()): raise ValueError('The provided dictionary must contain only pandas DataFrame objects.') + if infer_keys not in ['primary_and_foreign', 'primary_only', None]: + raise ValueError( + "'infer_keys' must be one of: 'primary_and_foreign', 'primary_only', None." + ) + cls._validate_infer_sdtypes(infer_sdtypes) metadata = Metadata() for table_name, dataframe in data.items(): - metadata.detect_table_from_dataframe(table_name, dataframe) + metadata.detect_table_from_dataframe( + table_name, dataframe, infer_sdtypes, None if infer_keys is None else 'primary_only' + ) + + if infer_keys == 'primary_and_foreign': + metadata._detect_relationships(data) - metadata._detect_relationships(data) return metadata @classmethod - def detect_from_dataframe(cls, data, table_name=DEFAULT_SINGLE_TABLE_NAME): + def detect_from_dataframe( + cls, + data, + table_name=DEFAULT_SINGLE_TABLE_NAME, + infer_sdtypes=True, + infer_keys='primary_only', + ): """Detect the metadata for a DataFrame. This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``. @@ -96,6 +128,16 @@ def detect_from_dataframe(cls, data, table_name=DEFAULT_SINGLE_TABLE_NAME): Args: data (pandas.DataFrame): Dictionary of table names to dataframes. + infer_sdtypes (bool): + A boolean describing whether to infer the sdtypes of each column. + If True it infers the sdtypes based on the data. + If False it does not infer the sdtypes and all columns are marked as unknown. + Defaults to True. + infer_keys (str): + A string describing whether to infer the primary keys. 
Options are: + - 'primary_only': Infer only the primary keys of each table + - None: Do not infer any keys + Defaults to 'primary_only'. Returns: Metadata: @@ -103,9 +145,12 @@ def detect_from_dataframe(cls, data, table_name=DEFAULT_SINGLE_TABLE_NAME): """ if not isinstance(data, pd.DataFrame): raise ValueError('The provided data must be a pandas DataFrame object.') + if infer_keys not in ['primary_only', None]: + raise ValueError("'infer_keys' must be one of: 'primary_only', None.") + cls._validate_infer_sdtypes(infer_sdtypes) metadata = Metadata() - metadata.detect_table_from_dataframe(table_name, data) + metadata.detect_table_from_dataframe(table_name, data, infer_sdtypes, infer_keys) return metadata def _set_metadata_dict(self, metadata, single_table_name=None): diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index a2b2b35c5..3e58dffe7 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -530,7 +530,9 @@ def _detect_relationships(self, data=None): ) continue - def detect_table_from_dataframe(self, table_name, data): + def detect_table_from_dataframe( + self, table_name, data, infer_sdtypes=True, infer_keys='primary_only' + ): """Detect the metadata for a table from a dataframe. This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``, @@ -541,10 +543,20 @@ def detect_table_from_dataframe(self, table_name, data): Name of the table to detect. data (pandas.DataFrame): ``pandas.DataFrame`` to detect the metadata from. + infer_sdtypes (bool): + A boolean describing whether to infer the sdtypes of each column. + If True it infers the sdtypes based on the data. + If False it does not infer the sdtypes and all columns are marked as unknown. + Defaults to True. + infer_keys (str): + A string describing whether to infer the primary keys. Options are: + - 'primary_only': Infer only the primary keys of each table + - None: Do not infer any keys + Defaults to 'primary_only'. 
""" self._validate_table_not_detected(table_name) table = SingleTableMetadata() - table._detect_columns(data, table_name) + table._detect_columns(data, table_name, infer_sdtypes, infer_keys) self.tables[table_name] = table self._log_detected_table(table) diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index 6f91f20d3..5c14ce868 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -595,7 +595,7 @@ def _detect_primary_key(self, data): return None - def _detect_columns(self, data, table_name=None): + def _detect_columns(self, data, table_name=None, infer_sdtypes=True, infer_keys='primary_only'): """Detect the columns' sdtypes from the data. Args: @@ -603,45 +603,59 @@ def _detect_columns(self, data, table_name=None): The data to be analyzed. table_name (str): The name of the table to be analyzed. Defaults to ``None``. + infer_sdtypes (bool): + A boolean describing whether to infer the sdtypes of each column. + If True it infers the sdtypes based on the data. + If False it does not infer the sdtypes and all columns are marked as unknown. + Defaults to True. + infer_keys (str): + A string describing whether to infer the primary keys. Options are: + - 'primary_only': Infer the primary keys. + - None: Do not infer any keys. + Defaults to 'primary_only'. 
""" old_columns = data.columns data.columns = data.columns.astype(str) for field in data: - try: - column_data = data[field] - clean_data = column_data.dropna() - dtype = clean_data.infer_objects().dtype.kind - - sdtype = self._detect_pii_column(field) - if sdtype is None: - if dtype in self._DTYPES_TO_SDTYPES: - sdtype = self._DTYPES_TO_SDTYPES[dtype] - elif dtype in ['i', 'f', 'u']: - sdtype = self._determine_sdtype_for_numbers(column_data) - - elif dtype == 'O': - sdtype = self._determine_sdtype_for_objects(column_data) + if infer_sdtypes: + try: + column_data = data[field] + clean_data = column_data.dropna() + dtype = clean_data.infer_objects().dtype.kind + sdtype = self._detect_pii_column(field) if sdtype is None: - table_str = f"table '{table_name}' " if table_name else '' - error_message = ( - f"Unsupported data type for {table_str}column '{field}' (kind: {dtype}" - "). The valid data types are: 'object', 'int', 'float', 'datetime'," - " 'bool'." - ) - raise InvalidMetadataError(error_message) - - except Exception as e: - error_type = type(e).__name__ - if error_type == 'InvalidMetadataError': - raise e - - table_str = f"table '{table_name}' " if table_name else '' - error_message = ( - f"Unable to detect metadata for {table_str}column '{field}' due to an invalid " - f'data format.\n {error_type}: {e}' - ) - raise InvalidMetadataError(error_message) from e + if dtype in self._DTYPES_TO_SDTYPES: + sdtype = self._DTYPES_TO_SDTYPES[dtype] + elif dtype in ['i', 'f', 'u']: + sdtype = self._determine_sdtype_for_numbers(column_data) + + elif dtype == 'O': + sdtype = self._determine_sdtype_for_objects(column_data) + + if sdtype is None: + table_str = f"table '{table_name}' " if table_name else '' + error_message = ( + f"Unsupported data type for {table_str}column '{field}' " + f"(kind: {dtype}). The valid data types are: 'object', " + "'int', 'float', 'datetime', 'bool'." 
+ ) + raise InvalidMetadataError(error_message) + + except Exception as e: + error_type = type(e).__name__ + if error_type == 'InvalidMetadataError': + raise e + + table_str = f"table '{table_name}' " if table_name else '' + error_message = ( + f"Unable to detect metadata for {table_str}column '{field}' due " + f'to an invalid data format.\n {error_type}: {e}' + ) + raise InvalidMetadataError(error_message) from e + + else: + sdtype = 'unknown' column_dict = {'sdtype': sdtype} sdtype_in_reference = sdtype in self._REFERENCE_TO_SDTYPE.values() @@ -655,7 +669,8 @@ def _detect_columns(self, data, table_name=None): self.columns[field] = deepcopy(column_dict) - self.primary_key = self._detect_primary_key(data) + if infer_keys == 'primary_only': + self.primary_key = self._detect_primary_key(data) self._updated = True data.columns = old_columns diff --git a/tests/integration/metadata/test_metadata.py b/tests/integration/metadata/test_metadata.py index 09f2f54b7..8b7e62cd1 100644 --- a/tests/integration/metadata/test_metadata.py +++ b/tests/integration/metadata/test_metadata.py @@ -121,6 +121,149 @@ def test_detect_from_dataframes_multi_table(): assert metadata.to_dict() == expected_metadata +def test_detect_from_dataframes_multi_table_without_infer_sdtypes(): + """Test it when infer_sdtypes is False.""" + # Setup + real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') + + # Run + metadata = Metadata.detect_from_dataframes(real_data, infer_sdtypes=False) + + # Assert + metadata.update_column( + table_name='hotels', + column_name='classification', + sdtype='categorical', + ) + + expected_metadata = { + 'tables': { + 'hotels': { + 'columns': { + 'hotel_id': {'sdtype': 'unknown', 'pii': True}, + 'city': {'sdtype': 'unknown', 'pii': True}, + 'state': {'sdtype': 'unknown', 'pii': True}, + 'rating': {'sdtype': 'unknown', 'pii': True}, + 'classification': {'sdtype': 'categorical'}, + }, + }, + 'guests': { + 'columns': { + 'guest_email': {'sdtype': 
'unknown', 'pii': True}, + 'hotel_id': {'sdtype': 'unknown', 'pii': True}, + 'has_rewards': {'sdtype': 'unknown', 'pii': True}, + 'room_type': {'sdtype': 'unknown', 'pii': True}, + 'amenities_fee': {'sdtype': 'unknown', 'pii': True}, + 'checkin_date': {'sdtype': 'unknown', 'pii': True}, + 'checkout_date': {'sdtype': 'unknown', 'pii': True}, + 'room_rate': {'sdtype': 'unknown', 'pii': True}, + 'billing_address': {'sdtype': 'unknown', 'pii': True}, + 'credit_card_number': {'sdtype': 'unknown', 'pii': True}, + }, + }, + }, + 'relationships': [], + 'METADATA_SPEC_VERSION': 'V1', + } + assert metadata.to_dict() == expected_metadata + + +def test_detect_from_dataframes_multi_table_with_infer_keys_primary_only(): + """Test it when infer_keys is 'primary_only'.""" + # Setup + real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') + + # Run + metadata = Metadata.detect_from_dataframes(real_data, infer_keys='primary_only') + + # Assert + metadata.update_column( + table_name='hotels', + column_name='classification', + sdtype='categorical', + ) + + expected_metadata = { + 'tables': { + 'hotels': { + 'columns': { + 'hotel_id': {'sdtype': 'id'}, + 'city': {'sdtype': 'city', 'pii': True}, + 'state': {'sdtype': 'administrative_unit', 'pii': True}, + 'rating': {'sdtype': 'numerical'}, + 'classification': {'sdtype': 'categorical'}, + }, + 'primary_key': 'hotel_id', + }, + 'guests': { + 'columns': { + 'guest_email': {'sdtype': 'email', 'pii': True}, + 'hotel_id': {'sdtype': 'categorical'}, + 'has_rewards': {'sdtype': 'categorical'}, + 'room_type': {'sdtype': 'categorical'}, + 'amenities_fee': {'sdtype': 'numerical'}, + 'checkin_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'}, + 'checkout_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'}, + 'room_rate': {'sdtype': 'numerical'}, + 'billing_address': {'sdtype': 'unknown', 'pii': True}, + 'credit_card_number': {'sdtype': 'credit_card_number', 'pii': True}, + }, + 'primary_key': 
'guest_email', + }, + }, + 'relationships': [], + 'METADATA_SPEC_VERSION': 'V1', + } + assert metadata.to_dict() == expected_metadata + + +def test_detect_from_dataframes_multi_table_with_infer_keys_none(): + """Test it when infer_keys is None.""" + # Setup + real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') + + # Run + metadata = Metadata.detect_from_dataframes(real_data, infer_keys=None) + + # Assert + metadata.update_column( + table_name='hotels', + column_name='classification', + sdtype='categorical', + ) + + expected_metadata = { + 'tables': { + 'hotels': { + 'columns': { + 'hotel_id': {'sdtype': 'id'}, + 'city': {'sdtype': 'city', 'pii': True}, + 'state': {'sdtype': 'administrative_unit', 'pii': True}, + 'rating': {'sdtype': 'numerical'}, + 'classification': {'sdtype': 'categorical'}, + }, + }, + 'guests': { + 'columns': { + 'guest_email': {'sdtype': 'email', 'pii': True}, + 'hotel_id': {'sdtype': 'categorical'}, + 'has_rewards': {'sdtype': 'categorical'}, + 'room_type': {'sdtype': 'categorical'}, + 'amenities_fee': {'sdtype': 'numerical'}, + 'checkin_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'}, + 'checkout_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'}, + 'room_rate': {'sdtype': 'numerical'}, + 'billing_address': {'sdtype': 'unknown', 'pii': True}, + 'credit_card_number': {'sdtype': 'credit_card_number', 'pii': True}, + }, + }, + }, + 'relationships': [], + 'METADATA_SPEC_VERSION': 'V1', + } + assert metadata.to_dict() == expected_metadata + + def test_detect_from_dataframes_single_table(): """Test the ``detect_from_dataframes`` method works with a single table.""" # Setup @@ -151,6 +294,93 @@ def test_detect_from_dataframes_single_table(): assert metadata.to_dict() == expected_metadata +def test_detect_from_dataframes_single_table_infer_sdtypes_false(): + """Test it for a single table when infer_sdtypes is False.""" + # Setup + data, _ = download_demo(modality='multi_table', 
dataset_name='fake_hotels') + metadata = Metadata.detect_from_dataframes({'table_1': data['hotels']}, infer_sdtypes=False) + + # Run + metadata.validate() + + # Assert + expected_metadata = { + 'METADATA_SPEC_VERSION': 'V1', + 'tables': { + 'table_1': { + 'columns': { + 'hotel_id': {'sdtype': 'unknown', 'pii': True}, + 'city': {'sdtype': 'unknown', 'pii': True}, + 'state': {'sdtype': 'unknown', 'pii': True}, + 'rating': {'sdtype': 'unknown', 'pii': True}, + 'classification': {'sdtype': 'unknown', 'pii': True}, + }, + } + }, + 'relationships': [], + } + assert metadata.to_dict() == expected_metadata + + +def test_detect_from_dataframes_single_table_infer_keys_primary_only(): + """Test it for a single table when infer_keys is 'primary_only'.""" + # Setup + data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') + metadata = Metadata.detect_from_dataframes( + {'table_1': data['hotels']}, infer_keys='primary_only' + ) + + # Run + metadata.validate() + + # Assert + expected_metadata = { + 'METADATA_SPEC_VERSION': 'V1', + 'tables': { + 'table_1': { + 'columns': { + 'hotel_id': {'sdtype': 'id'}, + 'city': {'sdtype': 'city', 'pii': True}, + 'state': {'sdtype': 'administrative_unit', 'pii': True}, + 'rating': {'sdtype': 'numerical'}, + 'classification': {'sdtype': 'unknown', 'pii': True}, + }, + 'primary_key': 'hotel_id', + } + }, + 'relationships': [], + } + assert metadata.to_dict() == expected_metadata + + +def test_detect_from_dataframes_single_table_infer_keys_none(): + """Test it for a single table when infer_keys is None.""" + # Setup + data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') + metadata = Metadata.detect_from_dataframes({'table_1': data['hotels']}, infer_keys=None) + + # Run + metadata.validate() + + # Assert + expected_metadata = { + 'METADATA_SPEC_VERSION': 'V1', + 'tables': { + 'table_1': { + 'columns': { + 'hotel_id': {'sdtype': 'id'}, + 'city': {'sdtype': 'city', 'pii': True}, + 'state': {'sdtype': 
'administrative_unit', 'pii': True}, + 'rating': {'sdtype': 'numerical'}, + 'classification': {'sdtype': 'unknown', 'pii': True}, + }, + } + }, + 'relationships': [], + } + assert metadata.to_dict() == expected_metadata + + def test_detect_from_dataframe(): """Test that a single table can be detected as a DataFrame.""" # Setup @@ -181,6 +411,62 @@ def test_detect_from_dataframe(): assert metadata.to_dict() == expected_metadata +def test_detect_from_dataframe_infer_sdtypes_false(): + """Test it when infer_sdtypes is False.""" + # Setup + data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') + metadata = Metadata.detect_from_dataframe(data['hotels'], infer_sdtypes=False) + + # Run + metadata.validate() + + # Assert + expected_metadata = { + 'METADATA_SPEC_VERSION': 'V1', + 'tables': { + DEFAULT_TABLE_NAME: { + 'columns': { + 'hotel_id': {'sdtype': 'unknown', 'pii': True}, + 'city': {'sdtype': 'unknown', 'pii': True}, + 'state': {'sdtype': 'unknown', 'pii': True}, + 'rating': {'sdtype': 'unknown', 'pii': True}, + 'classification': {'sdtype': 'unknown', 'pii': True}, + }, + } + }, + 'relationships': [], + } + assert metadata.to_dict() == expected_metadata + + +def test_detect_from_dataframe_infer_keys_none(): + """Test it when infer_keys is None.""" + # Setup + data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') + metadata = Metadata.detect_from_dataframe(data['hotels'], infer_keys=None) + + # Run + metadata.validate() + + # Assert + expected_metadata = { + 'METADATA_SPEC_VERSION': 'V1', + 'tables': { + DEFAULT_TABLE_NAME: { + 'columns': { + 'hotel_id': {'sdtype': 'id'}, + 'city': {'sdtype': 'city', 'pii': True}, + 'state': {'sdtype': 'administrative_unit', 'pii': True}, + 'rating': {'sdtype': 'numerical'}, + 'classification': {'sdtype': 'unknown', 'pii': True}, + }, + } + }, + 'relationships': [], + } + assert metadata.to_dict() == expected_metadata + + def test_detect_from_csvs(tmp_path): """Test the ``detect_from_csvs`` 
method.""" # Setup diff --git a/tests/unit/metadata/test_metadata.py b/tests/unit/metadata/test_metadata.py index 8e963e017..5a4398c82 100644 --- a/tests/unit/metadata/test_metadata.py +++ b/tests/unit/metadata/test_metadata.py @@ -594,15 +594,63 @@ def test_detect_from_dataframes(self, mock_metadata): # Assert mock_metadata.return_value.detect_table_from_dataframe.assert_any_call( - 'guests', guests_table + 'guests', guests_table, True, 'primary_only' ) mock_metadata.return_value.detect_table_from_dataframe.assert_any_call( - 'hotels', hotels_table + 'hotels', hotels_table, True, 'primary_only' ) mock_metadata.return_value._detect_relationships.assert_called_once_with(data) assert metadata == mock_metadata.return_value - def test_detect_from_dataframes_bad_input(self): + @patch('sdv.metadata.metadata.Metadata') + def test_detect_from_dataframes_infer_keys_none(self, mock_metadata): + """Test ``detect_from_dataframes`` with infer_keys set to None.""" + # Setup + mock_metadata.detect_table_from_dataframe = Mock() + mock_metadata._detect_relationships = Mock() + guests_table = pd.DataFrame() + hotels_table = pd.DataFrame() + data = {'guests': guests_table, 'hotels': hotels_table} + + # Run + metadata = Metadata.detect_from_dataframes(data, infer_sdtypes=False, infer_keys=None) + + # Assert + mock_metadata.return_value.detect_table_from_dataframe.assert_any_call( + 'guests', guests_table, False, None + ) + mock_metadata.return_value.detect_table_from_dataframe.assert_any_call( + 'hotels', hotels_table, False, None + ) + mock_metadata.return_value._detect_relationships.assert_not_called() + assert metadata == mock_metadata.return_value + + @patch('sdv.metadata.metadata.Metadata') + def test_detect_from_dataframes_infer_keys_primary_only(self, mock_metadata): + """Test ``detect_from_dataframes`` with infer_keys set to 'primary_only'.""" + # Setup + mock_metadata.detect_table_from_dataframe = Mock() + mock_metadata._detect_relationships = Mock() + guests_table = 
pd.DataFrame() + hotels_table = pd.DataFrame() + data = {'guests': guests_table, 'hotels': hotels_table} + + # Run + metadata = Metadata.detect_from_dataframes( + data, infer_sdtypes=False, infer_keys='primary_only' + ) + + # Assert + mock_metadata.return_value.detect_table_from_dataframe.assert_any_call( + 'guests', guests_table, False, 'primary_only' + ) + mock_metadata.return_value.detect_table_from_dataframe.assert_any_call( + 'hotels', hotels_table, False, 'primary_only' + ) + mock_metadata.return_value._detect_relationships.assert_not_called() + assert metadata == mock_metadata.return_value + + def test_detect_from_dataframes_bad_input_data(self): """Test that an error is raised if the dictionary contains something other than DataFrames. If the data contains values that aren't pandas.DataFrames, it should error. @@ -615,6 +663,30 @@ def test_detect_from_dataframes_bad_input(self): with pytest.raises(ValueError, match=expected_message): Metadata.detect_from_dataframes(data) + def test_detect_from_dataframes_bad_input_infer_sdtypes(self): + """Test that an error is raised if the infer_sdtypes is not a boolean.""" + # Setup + data = {'guests': pd.DataFrame(), 'hotels': pd.DataFrame()} + infer_sdtypes = 'not_a_boolean' + + # Run and Assert + expected_message = "'infer_sdtypes' must be a boolean value." + with pytest.raises(ValueError, match=expected_message): + Metadata.detect_from_dataframes(data, infer_sdtypes=infer_sdtypes) + + def test_detect_from_dataframes_bad_input_infer_keys(self): + """Test that an error is raised if the infer_keys is not a correct string.""" + # Setup + data = {'guests': pd.DataFrame(), 'hotels': pd.DataFrame()} + infer_keys = 'incorrect_string' + + # Run and Assert + expected_message = re.escape( + "'infer_keys' must be one of: 'primary_and_foreign', 'primary_only', None." 
+ ) + with pytest.raises(ValueError, match=expected_message): + Metadata.detect_from_dataframes(data, infer_keys=infer_keys) + @patch('sdv.metadata.metadata.Metadata') def test_detect_from_dataframe(self, mock_metadata): """Test that the method calls the detection method and returns the metadata. @@ -630,7 +702,7 @@ def test_detect_from_dataframe(self, mock_metadata): # Assert mock_metadata.return_value.detect_table_from_dataframe.assert_any_call( - Metadata.DEFAULT_SINGLE_TABLE_NAME, DataFrameMatcher(data) + Metadata.DEFAULT_SINGLE_TABLE_NAME, DataFrameMatcher(data), True, 'primary_only' ) assert metadata == mock_metadata.return_value @@ -641,6 +713,28 @@ def test_detect_from_dataframe_raises_error_if_not_dataframe(self): with pytest.raises(ValueError, match=expected_message): Metadata.detect_from_dataframe(Mock()) + def test_detect_from_dataframe_bad_input_infer_sdtypes(self): + """Test that an error is raised if the infer_sdtypes is not a boolean.""" + # Setup + data = pd.DataFrame() + infer_sdtypes = 'not_a_boolean' + + # Run and Assert + expected_message = "'infer_sdtypes' must be a boolean value." 
+ with pytest.raises(ValueError, match=expected_message): + Metadata.detect_from_dataframe(data, infer_sdtypes=infer_sdtypes) + + def test_detect_from_dataframe_bad_input_infer_keys(self): + """Test that an error is raised if the infer_keys is not a correct string.""" + # Setup + data = pd.DataFrame() + infer_keys = 'primary_and_foreign' + + # Run and Assert + expected_message = re.escape("'infer_keys' must be one of: 'primary_only', None.") + with pytest.raises(ValueError, match=expected_message): + Metadata.detect_from_dataframe(data, infer_keys=infer_keys) + def test__handle_table_name(self): """Test the ``_handle_table_name`` method.""" # Setup diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py index cf05790fd..64587901d 100644 --- a/tests/unit/metadata/test_multi_table.py +++ b/tests/unit/metadata/test_multi_table.py @@ -2499,7 +2499,9 @@ def test_detect_table_from_dataframe(self, single_table_mock, log_mock): metadata.detect_table_from_dataframe('table', data) # Assert - single_table_mock.return_value._detect_columns.assert_called_once_with(data, 'table') + single_table_mock.return_value._detect_columns.assert_called_once_with( + data, 'table', True, 'primary_only' + ) assert metadata.tables == {'table': single_table_mock.return_value} expected_log_calls = call( diff --git a/tests/unit/metadata/test_single_table.py b/tests/unit/metadata/test_single_table.py index ad952b346..57018a2fe 100644 --- a/tests/unit/metadata/test_single_table.py +++ b/tests/unit/metadata/test_single_table.py @@ -76,6 +76,44 @@ class TestSingleTableMetadata: ), ] # noqa: JS102 + @pytest.fixture + def data(self): + return pd.DataFrame({ + 'id': ['id1', 'id2', 'id3', 'id4', 'id5', 'id6', 'id7', 'id8', 'id9', 'id10', 'id11'], + 'numerical': [1, 2, 3, 2, 5, 6, 7, 8, 9, 10, 11], + 'datetime': [ + '2022-01-01', + '2022-02-01', + '2022-03-01', + '2022-04-01', + '2022-05-01', + '2022-06-01', + '2022-07-01', + '2022-08-01', + '2022-09-01', + 
'2022-10-01', + '2022-11-01', + ], + 'alternate_id': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + 'alternate_id_string': ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10'], + 'categorical': ['a', 'b', 'a', 'a', 'b', 'b', 'a', 'b', 'a', 'b', 'a'], + 'bool': [True, False, True, False, True, False, True, False, True, False, True], + 'unknown': ['a', 'b', 'c', 'c', 1, 2.2, np.nan, None, 'd', 'e', 'f'], + 'first_name': [ + 'John', + 'Jane', + 'John', + 'Jane', + 'John', + 'Jane', + 'John', + 'Jane', + 'John', + 'Jane', + 'John', + ], + }) + def test___init__(self): """Test creating an instance of ``SingleTableMetadata``.""" # Run @@ -1111,46 +1149,10 @@ def test__determine_sdtype_for_objects_with_none(self): # Assert assert sdtype == 'categorical' - def test__detect_columns(self): + def test__detect_columns(self, data): """Test the ``_detect_columns`` method.""" # Setup instance = SingleTableMetadata() - data = pd.DataFrame({ - 'id': ['id1', 'id2', 'id3', 'id4', 'id5', 'id6', 'id7', 'id8', 'id9', 'id10', 'id11'], - 'numerical': [1, 2, 3, 2, 5, 6, 7, 8, 9, 10, 11], - 'datetime': [ - '2022-01-01', - '2022-02-01', - '2022-03-01', - '2022-04-01', - '2022-05-01', - '2022-06-01', - '2022-07-01', - '2022-08-01', - '2022-09-01', - '2022-10-01', - '2022-11-01', - ], - 'alternate_id': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - 'alternate_id_string': ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10'], - 'categorical': ['a', 'b', 'a', 'a', 'b', 'b', 'a', 'b', 'a', 'b', 'a'], - 'bool': [True, False, True, False, True, False, True, False, True, False, True], - 'unknown': ['a', 'b', 'c', 'c', 1, 2.2, np.nan, None, 'd', 'e', 'f'], - 'first_name': [ - 'John', - 'Jane', - 'John', - 'Jane', - 'John', - 'Jane', - 'John', - 'Jane', - 'John', - 'Jane', - 'John', - ], - }) - expected_datetime_format = '%Y-%m-%d' # Run @@ -1320,6 +1322,47 @@ def test__detect_columns_invalid_data_format(self): with pytest.raises(InvalidMetadataError, match=expected_error_message): instance._detect_columns(data) 
+ def test__detect_columns_without_infer_sdtypes(self, data): + """Test the _detect_columns when infer_sdtypes is False.""" + # Setup + instance = SingleTableMetadata() + + # Run + instance._detect_columns(data, infer_sdtypes=False) + + # Assert + for column in data.columns: + assert instance.columns[column]['sdtype'] == 'unknown' + assert instance.columns[column]['pii'] is True + + assert instance.primary_key is None + assert instance._updated is True + + def test__detect_columns_without_infer_keys(self, data): + """Test the _detect_columns when infer_keys is None.""" + # Setup + instance = SingleTableMetadata() + + # Run + instance._detect_columns(data, infer_keys=None) + + # Assert + assert instance.columns['id']['sdtype'] == 'id' + assert instance.columns['numerical']['sdtype'] == 'numerical' + assert instance.columns['datetime']['sdtype'] == 'datetime' + assert instance.columns['datetime']['datetime_format'] == '%Y-%m-%d' + assert instance.columns['alternate_id']['sdtype'] == 'id' + assert instance.columns['alternate_id_string']['sdtype'] == 'id' + assert instance.columns['categorical']['sdtype'] == 'categorical' + assert instance.columns['unknown']['sdtype'] == 'unknown' + assert instance.columns['unknown']['pii'] is True + assert instance.columns['bool']['sdtype'] == 'categorical' + assert instance.columns['first_name']['sdtype'] == 'first_name' + assert instance.columns['first_name']['pii'] is True + + assert instance.primary_key is None + assert instance._updated is True + def test__detect_primary_key_missing_sdtypes(self): """The method should raise an error if not all sdtypes were detected.""" # Setup