diff --git a/sdv/constraints/tabular.py b/sdv/constraints/tabular.py index 4de025c71..3895ce8e5 100644 --- a/sdv/constraints/tabular.py +++ b/sdv/constraints/tabular.py @@ -618,9 +618,8 @@ def __init__(self, column_name, relation, value): self._dtype = None self._operator = INEQUALITY_TO_OPERATION[relation] - def _get_is_datetime(self, table_data): - column = table_data[self._column_name].to_numpy() - is_column_datetime = is_datetime_type(column) + def _get_is_datetime(self): + is_column_datetime = self.metadata.columns[self._column_name]['sdtype'] == 'datetime' is_value_datetime = is_datetime_type(self._value) is_datetime = is_column_datetime and is_value_datetime @@ -641,7 +640,7 @@ def _fit(self, table_data): The Table data. """ self._validate_columns_exist(table_data) - self._is_datetime = self._get_is_datetime(table_data) + self._is_datetime = self._get_is_datetime() self._dtype = table_data[self._column_name].dtypes def is_valid(self, table_data): @@ -842,14 +841,10 @@ def __init__(self, low_column_name, middle_column_name, high_column_name, self.strict_boundaries = strict_boundaries self._operator = operator.lt if strict_boundaries else operator.le - def _get_is_datetime(self, table_data): - low = table_data[self.low_column_name] - middle = table_data[self.middle_column_name] - high = table_data[self.high_column_name] - - is_low_datetime = is_datetime_type(low) - is_middle_datetime = is_datetime_type(middle) - is_high_datetime = is_datetime_type(high) + def _get_is_datetime(self): + is_low_datetime = self.metadata.columns[self.low_column_name]['sdtype'] == 'datetime' + is_middle_datetime = self.metadata.columns[self.middle_column_name]['sdtype'] == 'datetime' + is_high_datetime = self.metadata.columns[self.high_column_name]['sdtype'] == 'datetime' is_datetime = is_low_datetime and is_high_datetime and is_middle_datetime if not is_datetime and any([is_low_datetime, is_middle_datetime, is_high_datetime]): @@ -865,7 +860,7 @@ def _fit(self, table_data): The Table data. """ self._dtype = table_data[self.middle_column_name].dtypes - self._is_datetime = self._get_is_datetime(table_data) + self._is_datetime = self._get_is_datetime() self.low_diff_column_name = f'{self.low_column_name}#{self.middle_column_name}' self.high_diff_column_name = f'{self.middle_column_name}#{self.high_column_name}' @@ -1078,10 +1073,8 @@ def _get_diff_column_name(self, table_data): return token.join(components) - def _get_is_datetime(self, table_data): - data = table_data[self._column_name] - - is_column_datetime = is_datetime_type(data) + def _get_is_datetime(self): + is_column_datetime = self.metadata.columns[self._column_name]['sdtype'] == 'datetime' is_low_datetime = is_datetime_type(self._low_value) is_high_datetime = is_datetime_type(self._high_value) is_datetime = is_low_datetime and is_high_datetime and is_column_datetime @@ -1099,7 +1092,7 @@ def _fit(self, table_data): Table data. """ self._dtype = table_data[self._column_name].dtypes - self._is_datetime = self._get_is_datetime(table_data) + self._is_datetime = self._get_is_datetime() self._transformed_column = self._get_diff_column_name(table_data) if self._is_datetime: self._low_value = cast_to_datetime64(self._low_value) diff --git a/tests/unit/constraints/test_tabular.py b/tests/unit/constraints/test_tabular.py index 177fd8c8b..be4631e3a 100644 --- a/tests/unit/constraints/test_tabular.py +++ b/tests/unit/constraints/test_tabular.py @@ -3,7 +3,7 @@ import operator import re from datetime import datetime -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import numpy as np import pandas as pd @@ -2066,15 +2066,13 @@ def test__get_is_datetime_incorrect_data(self): - Raises ``ValueError`` if only one of column/value is datetime. """ # Setupy - table_data = pd.DataFrame({ - 'a': pd.to_datetime(['2020-01-01']), - }) instance = ScalarInequality(column_name='a', value=1, relation='<') + instance.metadata = Mock(columns={'a': {'sdtype': 'datetime'}}) # Run / Assert err_msg = 'Both column and value must be datetime.' with pytest.raises(ValueError, match=err_msg): - instance._get_is_datetime(table_data) + instance._get_is_datetime() def test__validate_columns_exist_incorrect_columns(self): """Test the ``ScalarInequality._validate_columns_exist`` method. @@ -2122,7 +2120,7 @@ def test__fit(self): # Assert instance._validate_columns_exist.assert_called_once_with(table_data) - instance._get_is_datetime.assert_called_once_with(table_data) + instance._get_is_datetime.assert_called_once() assert not instance._is_datetime assert instance._dtype == pd.Series([1]).dtype # exact dtype (32 or 64) depends on OS @@ -2142,6 +2140,7 @@ def test__fit_floats(self): 'b': [4., 5., 6.] }) instance = ScalarInequality(column_name='b', value=10, relation='>') + instance.metadata = MagicMock() # Run instance._fit(table_data) @@ -2165,6 +2164,10 @@ def test__fit_datetime(self): 'b': pd.to_datetime(['2020-01-02']) }) instance = ScalarInequality(column_name='b', value='2020-01-01', relation='>') + instance.metadata = Mock(columns={ + 'a': {'sdtype': 'datetime'}, + 'b': {'sdtype': 'datetime'}, + }) # Run instance._fit(table_data) @@ -2990,15 +2993,15 @@ def test__get_is_datetime(self): - The output should be True. """ # Setup - table_data = pd.DataFrame({ - 'join_date': pd.to_datetime(['2021-02-10', '2021-05-10', '2021-08-11']), - 'promotion_date': pd.to_datetime(['2022-05-10', '2022-06-10', '2022-11-17']), - 'retirement_date': pd.to_datetime(['2050-10-11', '2058-10-04', '2075-11-14']) - }) instance = Range('join_date', 'promotion_date', 'retirement_date') + instance.metadata = Mock(columns={ + 'join_date': {'sdtype': 'datetime'}, + 'promotion_date': {'sdtype': 'datetime'}, + 'retirement_date': {'sdtype': 'datetime'}, + }) # Run - is_datetime = instance._get_is_datetime(table_data) + is_datetime = instance._get_is_datetime() # Assert assert is_datetime @@ -3020,15 +3023,11 @@ def test__get_is_datetime_no_datetimes(self): - The output should be false since all the data is ``int``. """ # Setup - table_data = pd.DataFrame({ - 'age_when_joined': [18, 19, 20], - 'current_age': [21, 22, 25], - 'retirement_age': [65, 68, 75] - }) instance = Range('age_when_joined', 'current_age', 'retirement_age') + instance.metadata = MagicMock() # Run - is_datetime = instance._get_is_datetime(table_data) + is_datetime = instance._get_is_datetime() # Assert assert not is_datetime @@ -3051,17 +3050,17 @@ def test__get_is_datetime_raises_an_error(self): - Value error with the expected message should be raised. """ # Setup - table_data = pd.DataFrame({ - 'join_date': pd.to_datetime(['2021-02-10', '2021-05-10', '2021-08-11']), - 'promotion_date': pd.to_datetime(['2022-05-10', '2022-06-10', '2022-11-17']), - 'current_age': [21, 22, 25], - }) instance = Range('join_date', 'promotion_date', 'current_age') + instance.metadata = Mock(columns={ + 'join_date': {'sdtype': 'datetime'}, + 'promotion_date': {'sdtype': 'datetime'}, + 'current_age': {'sdtype': 'numerical'}, + }) expected_text = 'The constraint column and bounds must all be datetime.' # Run with pytest.raises(ValueError, match=expected_text): - instance._get_is_datetime(table_data) + instance._get_is_datetime() def test__fit(self): """Test the ``_fit`` method of ``Range``. @@ -3087,6 +3086,7 @@ def test__fit(self): 'current_age#age_when_joined#retirement_age': [1, 2, 3] }, dtype=np.int64) instance = Range('age_when_joined', 'current_age', 'retirement_age') + instance.metadata = MagicMock() # Run instance._fit(table_data) @@ -3254,6 +3254,7 @@ def test_reverse_transform(self): 'b#c': [np.log(5), np.log(4), np.log(6)], }) instance = Range('a', 'b', 'c') + instance.metadata = MagicMock() # Run instance.fit(table_data) @@ -3278,6 +3279,11 @@ def test_reverse_transform_is_datetime(self): }) instance = Range('a', 'b', 'c') + instance.metadata = Mock(columns={ + 'a': {'sdtype': 'datetime'}, + 'b': {'sdtype': 'datetime'}, + 'c': {'sdtype': 'datetime'}, + }) # Run instance.fit(table_data) @@ -3661,13 +3667,11 @@ def test__get_is_datetime(self): - The output should be True. """ # Setup - table_data = pd.DataFrame({ - 'promotion_date': pd.to_datetime(['2022-05-10', '2022-06-10', '2022-11-17']), - }) instance = ScalarRange('promotion_date', '2021-02-10', '2050-10-11') + instance.metadata = Mock(columns={'promotion_date': {'sdtype': 'datetime'}}) # Run - is_datetime = instance._get_is_datetime(table_data) + is_datetime = instance._get_is_datetime() # Assert assert is_datetime @@ -3689,11 +3693,11 @@ def test__get_is_datetime_no_datetimes(self): - The output should be false since all the data is ``int``. """ # Setup - table_data = pd.DataFrame({'current_age': [21, 22, 25]}) instance = ScalarRange('current_age', 21, 30) + instance.metadata = MagicMock() # Run - is_datetime = instance._get_is_datetime(table_data) + is_datetime = instance._get_is_datetime() # Assert assert not is_datetime @@ -3715,15 +3719,13 @@ def test__get_is_datetime_raises_an_error(self): - The output should be false since all the data is ``int``. """ # Setup - table_data = pd.DataFrame({ - 'promotion_date': pd.to_datetime(['2022-05-10', '2022-06-10', '2022-11-17']), - }) instance = ScalarRange('promotion_date', 18, 25) + instance.metadata = Mock(columns={'promotion_date': {'sdtype': 'datetime'}}) expected_text = 'The constraint column and bounds must all be datetime.' # Run with pytest.raises(ValueError, match=expected_text): - instance._get_is_datetime(table_data) + instance._get_is_datetime() def test__get_diff_column_name(self): """Test the ``ScalarRange._get_diff_column_name`` method. @@ -3772,6 +3774,7 @@ def test__fit(self): # Setup table_data = pd.DataFrame({'current_age': [21, 22, 25]}) instance = ScalarRange('current_age', 18, 20) + instance.metadata = MagicMock() # Run instance._fit(table_data) @@ -3800,6 +3803,7 @@ def test__fit_datetime(self): ] }) instance = ScalarRange('checkin', '2022-05-05', '2022-06-01') + instance.metadata = Mock(columns={'checkin': {'sdtype': 'datetime'}}) # Run instance._fit(table_data) @@ -3903,6 +3907,7 @@ def test__transform(self, mock_logit): table_data = pd.DataFrame({'current_age': [21, 22, 25]}) instance = ScalarRange('current_age', 20, 28) mock_logit.return_value = [1, 2, 3] + instance.metadata = MagicMock() # Run instance.fit(table_data) @@ -3937,6 +3942,7 @@ def test_reverse_transform(self, mock_sigmoid): transformed_data = pd.DataFrame({'current_age#20#28': [1, 2, 3]}) mock_sigmoid.return_value = pd.Series([21, 22, 25]) instance = ScalarRange('current_age', 20, 28) + instance.metadata = MagicMock() # Run instance.fit(table_data)