From 7a52ed200c80fbd78a3ade3c4f3256d4d190026c Mon Sep 17 00:00:00 2001 From: Felipe Alex Hofmann Date: Fri, 16 Feb 2024 08:30:10 -0800 Subject: [PATCH] Provide user-friendly error messages when there are missing values in conditional sampling (#1791) --- sdv/single_table/base.py | 44 +++++++++--- tests/integration/single_table/test_base.py | 78 +++++++++++++++++++++ tests/unit/single_table/test_base.py | 58 +++++++++++++-- 3 files changed, 164 insertions(+), 16 deletions(-) diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index b89f29ef1..d2ca490e8 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -794,13 +794,6 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file show_progress_bar=show_progress_bar ) - def _validate_conditions(self, conditions): - """Validate the user-passed conditions.""" - for column in conditions.columns: - if column not in self._data_processor.get_sdtypes(): - raise ValueError(f"Unexpected column name '{column}'. " - f'Use a column name that was present in the original data.') - def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size, progress_bar=None, output_file_path=None): """Sample rows with conditions. @@ -904,6 +897,27 @@ def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size, return all_sampled_rows + def _validate_conditions_unseen_columns(self, conditions): + """Validate the user-passed conditions.""" + for column in conditions.columns: + if column not in self._data_processor.get_sdtypes(): + raise ValueError(f"Unexpected column name '{column}'. " + f'Use a column name that was present in the original data.') + + @staticmethod + def _raise_condition_with_nans(): + raise SynthesizerInputError( + 'Missing values are not yet supported for conditional sampling. ' + 'Please include only non-null values in your Condition objects.' + ) + + def _validate_conditions(self, conditions): + """Validate the user-passed conditions.""" + for condition_dataframe in conditions: + self._validate_conditions_unseen_columns(condition_dataframe) + if condition_dataframe.isna().any().any(): + self._raise_condition_with_nans() + def sample_from_conditions(self, conditions, max_tries_per_batch=100, batch_size=None, output_file_path=None): """Sample rows from this table with the given conditions. @@ -939,8 +953,7 @@ def sample_from_conditions(self, conditions, max_tries_per_batch=100, lambda num_rows, condition: condition.get_num_rows() + num_rows, conditions, 0) conditions = self._make_condition_dfs(conditions) - for condition_dataframe in conditions: - self._validate_conditions(condition_dataframe) + self._validate_conditions(conditions) sampled = pd.DataFrame() try: @@ -974,6 +987,17 @@ def sample_from_conditions(self, conditions, max_tries_per_batch=100, return sampled + def _validate_known_columns(self, conditions): + """Validate the user-passed conditions.""" + self._validate_conditions_unseen_columns(conditions) + if conditions.dropna().empty: + self._raise_condition_with_nans() + elif conditions.isna().any().any(): + warnings.warn( + 'Missing values are not yet supported. ' + 'Rows with any missing values will not be created.' + ) + def sample_remaining_columns(self, known_columns, max_tries_per_batch=100, batch_size=None, output_file_path=None): """Sample remaining rows from already known columns. @@ -1006,7 +1030,7 @@ def sample_remaining_columns(self, known_columns, max_tries_per_batch=100, output_file_path = validate_file_path(output_file_path) known_columns = known_columns.copy() - self._validate_conditions(known_columns) + self._validate_known_columns(known_columns) sampled = pd.DataFrame() try: with tqdm.tqdm(total=len(known_columns)) as progress_bar: diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index ebb726d96..ef049be67 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -3,11 +3,14 @@ import warnings from unittest.mock import patch +import numpy as np import pandas as pd import pkg_resources import pytest from rdt.transformers import AnonymizedFaker, FloatFormatter, RegexGenerator, UniformEncoder +from sdv.datasets.demo import download_demo +from sdv.errors import SynthesizerInputError from sdv.metadata import SingleTableMetadata from sdv.sampling import Condition from sdv.single_table import ( @@ -147,6 +150,81 @@ def test_sample_from_conditions_negative_float(): pd.testing.assert_series_equal(sampled_data['column1'], expected) +def test_sample_from_conditions_with_nans(): + """Test it crashes when condition has nans (GH#1758).""" + # Setup + data, metadata = download_demo( + modality='single_table', + dataset_name='fake_hotel_guests' + ) + synthesizer = GaussianCopulaSynthesizer(metadata) + my_condition = Condition( + num_rows=250, + column_values={'room_type': None, 'has_rewards': False} + ) + + # Run + synthesizer.fit(data) + + # Assert + error_msg = ( + 'Missing values are not yet supported for conditional sampling. ' + 'Please include only non-null values in your Condition objects.' + ) + with pytest.raises(SynthesizerInputError, match=error_msg): + synthesizer.sample_from_conditions(conditions=[my_condition]) + + +def test_sample_remaining_columns_with_all_nans(): + """Test it crashes when every condition row has a nan (GH#1758).""" + # Setup + data, metadata = download_demo( + modality='single_table', + dataset_name='fake_hotel_guests' + ) + synthesizer = GaussianCopulaSynthesizer(metadata) + known_columns = pd.DataFrame(data={ + 'has_rewards': [np.nan, False, True], + 'amenities_fee': [5.00, np.nan, None] + }) + + # Run + synthesizer.fit(data) + + # Assert + error_msg = ( + 'Missing values are not yet supported for conditional sampling. ' + 'Please include only non-null values in your Condition objects.' + ) + with pytest.raises(SynthesizerInputError, match=error_msg): + synthesizer.sample_remaining_columns(known_columns=known_columns) + + +def test_sample_remaining_columns_with_some_nans(): + """Test it warns when some of the condition rows contain nans (GH#1758).""" + # Setup + data, metadata = download_demo( + modality='single_table', + dataset_name='fake_hotel_guests' + ) + synthesizer = GaussianCopulaSynthesizer(metadata) + known_columns = pd.DataFrame(data={ + 'has_rewards': [True, False, np.nan], + 'amenities_fee': [5.00, np.nan, None] + }) + + # Run + synthesizer.fit(data) + + # Assert + warn_msg = ( + 'Missing values are not yet supported. ' + 'Rows with any missing values will not be created.' + ) + with pytest.warns(UserWarning, match=warn_msg): + synthesizer.sample_remaining_columns(known_columns=known_columns) + + def test_multiple_fits(): """Test the synthesizer refits correctly on new data. diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 45cae5f41..5fe85393b 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -2,6 +2,7 @@ from datetime import date, datetime from unittest.mock import ANY, MagicMock, Mock, call, mock_open, patch +import numpy as np import pandas as pd import pytest from copulas.multivariate import GaussianMultivariate @@ -1279,7 +1280,7 @@ def test_sample(self): ) assert result == instance._sample_with_progress_bar.return_value - def test__validate_conditions(self): + def test__validate_conditions_unseen_columns(self): """Test that conditions are within the ``data_processor`` fields.""" # Setup instance = Mock() @@ -1290,12 +1291,12 @@ def test__validate_conditions(self): conditions = pd.DataFrame({'name': ['Johanna'], 'surname': ['Doe']}) # Run - BaseSingleTableSynthesizer._validate_conditions(instance, conditions) + BaseSingleTableSynthesizer._validate_conditions_unseen_columns(instance, conditions) # Assert instance._data_processor.get_sdtypes.assert_called() - def test__validate_conditions_raises_error(self): + def test__validate_conditions_unseen_columns_raises_error(self): """Test that conditions are not in the ``data_processor`` fields.""" # Setup instance = Mock() @@ -1311,7 +1312,22 @@ def test__validate_conditions_raises_error(self): 'original data.' ) with pytest.raises(ValueError, match=error_msg): - BaseSingleTableSynthesizer._validate_conditions(instance, conditions) + BaseSingleTableSynthesizer._validate_conditions_unseen_columns(instance, conditions) + + def test__validate_conditions_nans(self): + """Test that it raises an error when nans are in the data.""" + # Setup + conditions = [pd.DataFrame({'names': [np.nan], 'surname': ['Doe']})] + synthesizer = BaseSingleTableSynthesizer(MagicMock()) + synthesizer._validate_conditions_unseen_columns = Mock() + + # Run and Assert + error_msg = ( + 'Missing values are not yet supported for conditional sampling. ' + 'Please include only non-null values in your Condition objects.' + ) + with pytest.raises(SynthesizerInputError, match=error_msg): + synthesizer._validate_conditions(conditions) def test__sample_with_conditions_constraints_not_met(self): """Test when conditions are not met.""" @@ -1523,7 +1539,7 @@ def test_sample_remaining_columns(self, mock_validate_file_path, mock_tqdm, instance = BaseSingleTableSynthesizer(metadata) known_columns = pd.DataFrame({'name': ['Johanna Doe']}) - instance._validate_conditions = Mock() + instance._validate_known_columns = Mock() instance._sample_with_conditions = Mock() instance._model = GaussianMultivariate() instance._sample_with_conditions.return_value = pd.DataFrame({'name': ['John Doe']}) @@ -1563,7 +1579,7 @@ def test_sample_remaining_columns_handles_sampling_error( instance = BaseSingleTableSynthesizer(metadata) known_columns = pd.DataFrame({'name': ['Johanna Doe']}) - instance._validate_conditions = Mock() + instance._validate_known_columns = Mock() instance._sample_with_conditions = Mock() instance._model = GaussianMultivariate() keyboard_error = KeyboardInterrupt() @@ -1585,6 +1601,36 @@ def test_sample_remaining_columns_handles_sampling_error( pd.testing.assert_frame_equal(result, pd.DataFrame()) mock_handle_sampling_error.assert_called_once_with(False, 'temp_file', keyboard_error) + def test__validate_known_columns_nans(self): + """Test that it crashes when condition has nans.""" + # Setup + conditions = pd.DataFrame({'names': [np.nan], 'surname': ['Doe']}) + synthesizer = BaseSingleTableSynthesizer(MagicMock()) + synthesizer._validate_conditions_unseen_columns = Mock() + + # Run and Assert + error_msg = ( + 'Missing values are not yet supported for conditional sampling. ' + 'Please include only non-null values in your Condition objects.' + ) + with pytest.raises(SynthesizerInputError, match=error_msg): + synthesizer._validate_known_columns(conditions) + + def test__validate_known_columns_a_few_nans(self): + """Test that it warns when condition has a few nans, but at least a valid row.""" + # Setup + conditions = pd.DataFrame({'names': [np.nan, 'Dae'], 'surname': ['Doe', 'Due']}) + synthesizer = BaseSingleTableSynthesizer(MagicMock()) + synthesizer._validate_conditions_unseen_columns = Mock() + + # Run and Assert + warn_msg = ( + 'Missing values are not yet supported. ' + 'Rows with any missing values will not be created.' + ) + with pytest.warns(UserWarning, match=warn_msg): + synthesizer._validate_known_columns(conditions) + @patch('sdv.single_table.base.cloudpickle') def test_save(self, cloudpickle_mock, tmp_path): """Test that the synthesizer is saved correctly."""