Skip to content

Commit

Permalink
Provide user-friendly error messages when there are missing values in…
Browse files Browse the repository at this point in the history
… conditional sampling (#1791)
  • Loading branch information
fealho authored Feb 16, 2024
1 parent c0dd7c7 commit 7a52ed2
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 16 deletions.
44 changes: 34 additions & 10 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,13 +794,6 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file
show_progress_bar=show_progress_bar
)

def _validate_conditions(self, conditions):
    """Validate that every condition column is known to the data processor.

    Args:
        conditions (pandas.DataFrame):
            Dataframe whose column names are checked.

    Raises:
        ValueError:
            If a column name is not among the names returned by
            ``self._data_processor.get_sdtypes()``.
    """
    # The known columns do not change while iterating, so look them up once.
    known_columns = self._data_processor.get_sdtypes()
    for column in conditions.columns:
        if column not in known_columns:
            raise ValueError(
                f"Unexpected column name '{column}'. "
                'Use a column name that was present in the original data.'
            )

def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size,
progress_bar=None, output_file_path=None):
"""Sample rows with conditions.
Expand Down Expand Up @@ -904,6 +897,27 @@ def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size,

return all_sampled_rows

def _validate_conditions_unseen_columns(self, conditions):
    """Validate that every condition column is known to the data processor.

    Args:
        conditions (pandas.DataFrame):
            Dataframe whose column names are checked.

    Raises:
        ValueError:
            If a column name is not among the names returned by
            ``self._data_processor.get_sdtypes()``.
    """
    # The known columns do not change while iterating, so look them up once.
    known_columns = self._data_processor.get_sdtypes()
    for column in conditions.columns:
        if column not in known_columns:
            raise ValueError(
                f"Unexpected column name '{column}'. "
                'Use a column name that was present in the original data.'
            )

@staticmethod
def _raise_condition_with_nans():
    """Raise the error reported when conditions contain missing values.

    Raises:
        SynthesizerInputError:
            Always raised; this helper only exists to share the message.
    """
    message = (
        'Missing values are not yet supported for conditional sampling. '
        'Please include only non-null values in your Condition objects.'
    )
    raise SynthesizerInputError(message)

def _validate_conditions(self, conditions):
    """Validate each of the user-passed condition dataframes.

    Every dataframe is first checked for unknown column names and then for
    missing values; the first dataframe that fails stops the validation.

    Args:
        conditions (iterable of pandas.DataFrame):
            The condition dataframes to validate.
    """
    for dataframe in conditions:
        self._validate_conditions_unseen_columns(dataframe)

        columns_with_nans = dataframe.isna().any()
        if not columns_with_nans.any():
            continue

        self._raise_condition_with_nans()

def sample_from_conditions(self, conditions, max_tries_per_batch=100,
batch_size=None, output_file_path=None):
"""Sample rows from this table with the given conditions.
Expand Down Expand Up @@ -939,8 +953,7 @@ def sample_from_conditions(self, conditions, max_tries_per_batch=100,
lambda num_rows, condition: condition.get_num_rows() + num_rows, conditions, 0)

conditions = self._make_condition_dfs(conditions)
for condition_dataframe in conditions:
self._validate_conditions(condition_dataframe)
self._validate_conditions(conditions)

sampled = pd.DataFrame()
try:
Expand Down Expand Up @@ -974,6 +987,17 @@ def sample_from_conditions(self, conditions, max_tries_per_batch=100,

return sampled

def _validate_known_columns(self, conditions):
    """Validate the known columns given to ``sample_remaining_columns``.

    Column names are checked first, then the rows are inspected for missing
    values. A dataframe without any rows has no missing values and passes.

    Args:
        conditions (pandas.DataFrame):
            Dataframe holding the known column values.

    Raises:
        SynthesizerInputError:
            Raised through ``_raise_condition_with_nans`` when every row
            contains at least one missing value.

    Warns:
        UserWarning:
            When only some of the rows contain missing values.
    """
    self._validate_conditions_unseen_columns(conditions)

    row_has_nan = conditions.isna().any(axis=1)
    if not row_has_nan.any():
        # Also reached for an empty dataframe, which must not be reported
        # as containing missing values.
        return

    if row_has_nan.all():
        self._raise_condition_with_nans()

    warnings.warn(
        'Missing values are not yet supported. '
        'Rows with any missing values will not be created.'
    )

def sample_remaining_columns(self, known_columns, max_tries_per_batch=100,
batch_size=None, output_file_path=None):
"""Sample remaining rows from already known columns.
Expand Down Expand Up @@ -1006,7 +1030,7 @@ def sample_remaining_columns(self, known_columns, max_tries_per_batch=100,
output_file_path = validate_file_path(output_file_path)

known_columns = known_columns.copy()
self._validate_conditions(known_columns)
self._validate_known_columns(known_columns)
sampled = pd.DataFrame()
try:
with tqdm.tqdm(total=len(known_columns)) as progress_bar:
Expand Down
78 changes: 78 additions & 0 deletions tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
import warnings
from unittest.mock import patch

import numpy as np
import pandas as pd
import pkg_resources
import pytest
from rdt.transformers import AnonymizedFaker, FloatFormatter, RegexGenerator, UniformEncoder

from sdv.datasets.demo import download_demo
from sdv.errors import SynthesizerInputError
from sdv.metadata import SingleTableMetadata
from sdv.sampling import Condition
from sdv.single_table import (
Expand Down Expand Up @@ -147,6 +150,81 @@ def test_sample_from_conditions_negative_float():
pd.testing.assert_series_equal(sampled_data['column1'], expected)


def test_sample_from_conditions_with_nans():
    """Test that a ``Condition`` holding a missing value is rejected (GH#1758)."""
    # Setup
    real_data, metadata = download_demo(
        modality='single_table',
        dataset_name='fake_hotel_guests',
    )
    synthesizer = GaussianCopulaSynthesizer(metadata)
    condition_with_nan = Condition(
        num_rows=250,
        column_values={'room_type': None, 'has_rewards': False},
    )
    synthesizer.fit(real_data)
    expected_message = (
        'Missing values are not yet supported for conditional sampling. '
        'Please include only non-null values in your Condition objects.'
    )

    # Run and Assert
    with pytest.raises(SynthesizerInputError, match=expected_message):
        synthesizer.sample_from_conditions(conditions=[condition_with_nan])


def test_sample_remaining_columns_with_all_nans():
    """Test that it errors when every known-columns row has a missing value (GH#1758)."""
    # Setup
    real_data, metadata = download_demo(
        modality='single_table',
        dataset_name='fake_hotel_guests',
    )
    synthesizer = GaussianCopulaSynthesizer(metadata)
    # Each of the three rows below contains at least one missing value.
    rows_all_with_nans = pd.DataFrame(data={
        'has_rewards': [np.nan, False, True],
        'amenities_fee': [5.00, np.nan, None]
    })
    synthesizer.fit(real_data)
    expected_message = (
        'Missing values are not yet supported for conditional sampling. '
        'Please include only non-null values in your Condition objects.'
    )

    # Run and Assert
    with pytest.raises(SynthesizerInputError, match=expected_message):
        synthesizer.sample_remaining_columns(known_columns=rows_all_with_nans)


def test_sample_remaining_columns_with_some_nans():
    """Test that it warns when only some known-columns rows have missing values (GH#1758)."""
    # Setup
    real_data, metadata = download_demo(
        modality='single_table',
        dataset_name='fake_hotel_guests',
    )
    synthesizer = GaussianCopulaSynthesizer(metadata)
    # Only the first row is free of missing values.
    rows_partly_with_nans = pd.DataFrame(data={
        'has_rewards': [True, False, np.nan],
        'amenities_fee': [5.00, np.nan, None]
    })
    synthesizer.fit(real_data)
    expected_warning = (
        'Missing values are not yet supported. '
        'Rows with any missing values will not be created.'
    )

    # Run and Assert
    with pytest.warns(UserWarning, match=expected_warning):
        synthesizer.sample_remaining_columns(known_columns=rows_partly_with_nans)


def test_multiple_fits():
"""Test the synthesizer refits correctly on new data.
Expand Down
58 changes: 52 additions & 6 deletions tests/unit/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from datetime import date, datetime
from unittest.mock import ANY, MagicMock, Mock, call, mock_open, patch

import numpy as np
import pandas as pd
import pytest
from copulas.multivariate import GaussianMultivariate
Expand Down Expand Up @@ -1279,7 +1280,7 @@ def test_sample(self):
)
assert result == instance._sample_with_progress_bar.return_value

def test__validate_conditions(self):
def test__validate_conditions_unseen_columns(self):
"""Test that conditions are within the ``data_processor`` fields."""
# Setup
instance = Mock()
Expand All @@ -1290,12 +1291,12 @@ def test__validate_conditions(self):
conditions = pd.DataFrame({'name': ['Johanna'], 'surname': ['Doe']})

# Run
BaseSingleTableSynthesizer._validate_conditions(instance, conditions)
BaseSingleTableSynthesizer._validate_conditions_unseen_columns(instance, conditions)

# Assert
instance._data_processor.get_sdtypes.assert_called()

def test__validate_conditions_raises_error(self):
def test__validate_conditions_unseen_columns_raises_error(self):
"""Test that conditions are not in the ``data_processor`` fields."""
# Setup
instance = Mock()
Expand All @@ -1311,7 +1312,22 @@ def test__validate_conditions_raises_error(self):
'original data.'
)
with pytest.raises(ValueError, match=error_msg):
BaseSingleTableSynthesizer._validate_conditions(instance, conditions)
BaseSingleTableSynthesizer._validate_conditions_unseen_columns(instance, conditions)

def test__validate_conditions_nans(self):
    """Test that a condition dataframe containing a NaN raises an error."""
    # Setup
    instance = BaseSingleTableSynthesizer(MagicMock())
    # Column-name validation is mocked out so only the NaN check is exercised.
    instance._validate_conditions_unseen_columns = Mock()
    frame_with_nan = pd.DataFrame({'names': [np.nan], 'surname': ['Doe']})
    expected_message = (
        'Missing values are not yet supported for conditional sampling. '
        'Please include only non-null values in your Condition objects.'
    )

    # Run and Assert
    with pytest.raises(SynthesizerInputError, match=expected_message):
        instance._validate_conditions([frame_with_nan])

def test__sample_with_conditions_constraints_not_met(self):
"""Test when conditions are not met."""
Expand Down Expand Up @@ -1523,7 +1539,7 @@ def test_sample_remaining_columns(self, mock_validate_file_path, mock_tqdm,
instance = BaseSingleTableSynthesizer(metadata)
known_columns = pd.DataFrame({'name': ['Johanna Doe']})

instance._validate_conditions = Mock()
instance._validate_known_columns = Mock()
instance._sample_with_conditions = Mock()
instance._model = GaussianMultivariate()
instance._sample_with_conditions.return_value = pd.DataFrame({'name': ['John Doe']})
Expand Down Expand Up @@ -1563,7 +1579,7 @@ def test_sample_remaining_columns_handles_sampling_error(
instance = BaseSingleTableSynthesizer(metadata)
known_columns = pd.DataFrame({'name': ['Johanna Doe']})

instance._validate_conditions = Mock()
instance._validate_known_columns = Mock()
instance._sample_with_conditions = Mock()
instance._model = GaussianMultivariate()
keyboard_error = KeyboardInterrupt()
Expand All @@ -1585,6 +1601,36 @@ def test_sample_remaining_columns_handles_sampling_error(
pd.testing.assert_frame_equal(result, pd.DataFrame())
mock_handle_sampling_error.assert_called_once_with(False, 'temp_file', keyboard_error)

def test__validate_known_columns_nans(self):
    """Test that it raises an error when every row has a missing value."""
    # Setup
    instance = BaseSingleTableSynthesizer(MagicMock())
    # Column-name validation is mocked out so only the NaN check is exercised.
    instance._validate_conditions_unseen_columns = Mock()
    single_row_with_nan = pd.DataFrame({'names': [np.nan], 'surname': ['Doe']})
    expected_message = (
        'Missing values are not yet supported for conditional sampling. '
        'Please include only non-null values in your Condition objects.'
    )

    # Run and Assert
    with pytest.raises(SynthesizerInputError, match=expected_message):
        instance._validate_known_columns(single_row_with_nan)

def test__validate_known_columns_a_few_nans(self):
    """Test that it warns when some rows have NaNs but at least one row is complete."""
    # Setup
    instance = BaseSingleTableSynthesizer(MagicMock())
    # Column-name validation is mocked out so only the NaN check is exercised.
    instance._validate_conditions_unseen_columns = Mock()
    # First row has a NaN, second row is complete.
    partly_missing = pd.DataFrame({'names': [np.nan, 'Dae'], 'surname': ['Doe', 'Due']})
    expected_warning = (
        'Missing values are not yet supported. '
        'Rows with any missing values will not be created.'
    )

    # Run and Assert
    with pytest.warns(UserWarning, match=expected_warning):
        instance._validate_known_columns(partly_missing)

@patch('sdv.single_table.base.cloudpickle')
def test_save(self, cloudpickle_mock, tmp_path):
"""Test that the synthesizer is saved correctly."""
Expand Down

0 comments on commit 7a52ed2

Please sign in to comment.