Skip to content

Commit

Permalink
Update get_is_datetime methods
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Jan 5, 2024
1 parent f25037b commit 159195b
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 52 deletions.
29 changes: 11 additions & 18 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,9 +618,8 @@ def __init__(self, column_name, relation, value):
self._dtype = None
self._operator = INEQUALITY_TO_OPERATION[relation]

def _get_is_datetime(self, table_data):
column = table_data[self._column_name].to_numpy()
is_column_datetime = is_datetime_type(column)
def _get_is_datetime(self):
is_column_datetime = self.metadata.columns[self._column_name]['sdtype'] == 'datetime'
is_value_datetime = is_datetime_type(self._value)
is_datetime = is_column_datetime and is_value_datetime

Expand All @@ -641,7 +640,7 @@ def _fit(self, table_data):
The Table data.
"""
self._validate_columns_exist(table_data)
self._is_datetime = self._get_is_datetime(table_data)
self._is_datetime = self._get_is_datetime()
self._dtype = table_data[self._column_name].dtypes

def is_valid(self, table_data):
Expand Down Expand Up @@ -842,14 +841,10 @@ def __init__(self, low_column_name, middle_column_name, high_column_name,
self.strict_boundaries = strict_boundaries
self._operator = operator.lt if strict_boundaries else operator.le

def _get_is_datetime(self, table_data):
low = table_data[self.low_column_name]
middle = table_data[self.middle_column_name]
high = table_data[self.high_column_name]

is_low_datetime = is_datetime_type(low)
is_middle_datetime = is_datetime_type(middle)
is_high_datetime = is_datetime_type(high)
def _get_is_datetime(self):
is_low_datetime = self.metadata.columns[self.low_column_name]['sdtype'] == 'datetime'
is_middle_datetime = self.metadata.columns[self.middle_column_name]['sdtype'] == 'datetime'
is_high_datetime = self.metadata.columns[self.high_column_name]['sdtype'] == 'datetime'
is_datetime = is_low_datetime and is_high_datetime and is_middle_datetime

if not is_datetime and any([is_low_datetime, is_middle_datetime, is_high_datetime]):
Expand All @@ -865,7 +860,7 @@ def _fit(self, table_data):
The Table data.
"""
self._dtype = table_data[self.middle_column_name].dtypes
self._is_datetime = self._get_is_datetime(table_data)
self._is_datetime = self._get_is_datetime()

self.low_diff_column_name = f'{self.low_column_name}#{self.middle_column_name}'
self.high_diff_column_name = f'{self.middle_column_name}#{self.high_column_name}'
Expand Down Expand Up @@ -1078,10 +1073,8 @@ def _get_diff_column_name(self, table_data):

return token.join(components)

def _get_is_datetime(self, table_data):
data = table_data[self._column_name]

is_column_datetime = is_datetime_type(data)
def _get_is_datetime(self):
is_column_datetime = self.metadata.columns[self._column_name]['sdtype'] == 'datetime'
is_low_datetime = is_datetime_type(self._low_value)
is_high_datetime = is_datetime_type(self._high_value)
is_datetime = is_low_datetime and is_high_datetime and is_column_datetime
Expand All @@ -1099,7 +1092,7 @@ def _fit(self, table_data):
Table data.
"""
self._dtype = table_data[self._column_name].dtypes
self._is_datetime = self._get_is_datetime(table_data)
self._is_datetime = self._get_is_datetime()
self._transformed_column = self._get_diff_column_name(table_data)
if self._is_datetime:
self._low_value = cast_to_datetime64(self._low_value)
Expand Down
74 changes: 40 additions & 34 deletions tests/unit/constraints/test_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import operator
import re
from datetime import datetime
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -2066,15 +2066,13 @@ def test__get_is_datetime_incorrect_data(self):
- Raises ``ValueError`` if only one of column/value is datetime.
"""
# Setupy
table_data = pd.DataFrame({
'a': pd.to_datetime(['2020-01-01']),
})
instance = ScalarInequality(column_name='a', value=1, relation='<')
instance.metadata = Mock(columns={'a': {'sdtype': 'datetime'}})

# Run / Assert
err_msg = 'Both column and value must be datetime.'
with pytest.raises(ValueError, match=err_msg):
instance._get_is_datetime(table_data)
instance._get_is_datetime()

def test__validate_columns_exist_incorrect_columns(self):
"""Test the ``ScalarInequality._validate_columns_exist`` method.
Expand Down Expand Up @@ -2122,7 +2120,7 @@ def test__fit(self):

# Assert
instance._validate_columns_exist.assert_called_once_with(table_data)
instance._get_is_datetime.assert_called_once_with(table_data)
instance._get_is_datetime.assert_called_once()
assert not instance._is_datetime
assert instance._dtype == pd.Series([1]).dtype # exact dtype (32 or 64) depends on OS

Expand All @@ -2142,6 +2140,7 @@ def test__fit_floats(self):
'b': [4., 5., 6.]
})
instance = ScalarInequality(column_name='b', value=10, relation='>')
instance.metadata = MagicMock()

# Run
instance._fit(table_data)
Expand All @@ -2165,6 +2164,10 @@ def test__fit_datetime(self):
'b': pd.to_datetime(['2020-01-02'])
})
instance = ScalarInequality(column_name='b', value='2020-01-01', relation='>')
instance.metadata = Mock(columns={
'a': {'sdtype': 'datetime'},
'b': {'sdtype': 'datetime'},
})

# Run
instance._fit(table_data)
Expand Down Expand Up @@ -2990,15 +2993,15 @@ def test__get_is_datetime(self):
- The output should be True.
"""
# Setup
table_data = pd.DataFrame({
'join_date': pd.to_datetime(['2021-02-10', '2021-05-10', '2021-08-11']),
'promotion_date': pd.to_datetime(['2022-05-10', '2022-06-10', '2022-11-17']),
'retirement_date': pd.to_datetime(['2050-10-11', '2058-10-04', '2075-11-14'])
})
instance = Range('join_date', 'promotion_date', 'retirement_date')
instance.metadata = Mock(columns={
'join_date': {'sdtype': 'datetime'},
'promotion_date': {'sdtype': 'datetime'},
'retirement_date': {'sdtype': 'datetime'},
})

# Run
is_datetime = instance._get_is_datetime(table_data)
is_datetime = instance._get_is_datetime()

# Assert
assert is_datetime
Expand All @@ -3020,15 +3023,11 @@ def test__get_is_datetime_no_datetimes(self):
- The output should be false since all the data is ``int``.
"""
# Setup
table_data = pd.DataFrame({
'age_when_joined': [18, 19, 20],
'current_age': [21, 22, 25],
'retirement_age': [65, 68, 75]
})
instance = Range('age_when_joined', 'current_age', 'retirement_age')
instance.metadata = MagicMock()

# Run
is_datetime = instance._get_is_datetime(table_data)
is_datetime = instance._get_is_datetime()

# Assert
assert not is_datetime
Expand All @@ -3051,17 +3050,17 @@ def test__get_is_datetime_raises_an_error(self):
- Value error with the expected message should be raised.
"""
# Setup
table_data = pd.DataFrame({
'join_date': pd.to_datetime(['2021-02-10', '2021-05-10', '2021-08-11']),
'promotion_date': pd.to_datetime(['2022-05-10', '2022-06-10', '2022-11-17']),
'current_age': [21, 22, 25],
})
instance = Range('join_date', 'promotion_date', 'current_age')
instance.metadata = Mock(columns={
'join_date': {'sdtype': 'datetime'},
'promotion_date': {'sdtype': 'datetime'},
'current_age': {'sdtype': 'numerical'},
})
expected_text = 'The constraint column and bounds must all be datetime.'

# Run
with pytest.raises(ValueError, match=expected_text):
instance._get_is_datetime(table_data)
instance._get_is_datetime()

def test__fit(self):
"""Test the ``_fit`` method of ``Range``.
Expand All @@ -3087,6 +3086,7 @@ def test__fit(self):
'current_age#age_when_joined#retirement_age': [1, 2, 3]
}, dtype=np.int64)
instance = Range('age_when_joined', 'current_age', 'retirement_age')
instance.metadata = MagicMock()

# Run
instance._fit(table_data)
Expand Down Expand Up @@ -3254,6 +3254,7 @@ def test_reverse_transform(self):
'b#c': [np.log(5), np.log(4), np.log(6)],
})
instance = Range('a', 'b', 'c')
instance.metadata = MagicMock()

# Run
instance.fit(table_data)
Expand All @@ -3278,6 +3279,11 @@ def test_reverse_transform_is_datetime(self):
})

instance = Range('a', 'b', 'c')
instance.metadata = Mock(columns={
'a': {'sdtype': 'datetime'},
'b': {'sdtype': 'datetime'},
'c': {'sdtype': 'datetime'},
})

# Run
instance.fit(table_data)
Expand Down Expand Up @@ -3661,13 +3667,11 @@ def test__get_is_datetime(self):
- The output should be True.
"""
# Setup
table_data = pd.DataFrame({
'promotion_date': pd.to_datetime(['2022-05-10', '2022-06-10', '2022-11-17']),
})
instance = ScalarRange('promotion_date', '2021-02-10', '2050-10-11')
instance.metadata = Mock(columns={'promotion_date': {'sdtype': 'datetime'}})

# Run
is_datetime = instance._get_is_datetime(table_data)
is_datetime = instance._get_is_datetime()

# Assert
assert is_datetime
Expand All @@ -3689,11 +3693,11 @@ def test__get_is_datetime_no_datetimes(self):
- The output should be false since all the data is ``int``.
"""
# Setup
table_data = pd.DataFrame({'current_age': [21, 22, 25]})
instance = ScalarRange('current_age', 21, 30)
instance.metadata = MagicMock()

# Run
is_datetime = instance._get_is_datetime(table_data)
is_datetime = instance._get_is_datetime()

# Assert
assert not is_datetime
Expand All @@ -3715,15 +3719,13 @@ def test__get_is_datetime_raises_an_error(self):
- The output should be false since all the data is ``int``.
"""
# Setup
table_data = pd.DataFrame({
'promotion_date': pd.to_datetime(['2022-05-10', '2022-06-10', '2022-11-17']),
})
instance = ScalarRange('promotion_date', 18, 25)
instance.metadata = Mock(columns={'promotion_date': {'sdtype': 'datetime'}})
expected_text = 'The constraint column and bounds must all be datetime.'

# Run
with pytest.raises(ValueError, match=expected_text):
instance._get_is_datetime(table_data)
instance._get_is_datetime()

def test__get_diff_column_name(self):
"""Test the ``ScalarRange._get_diff_column_name`` method.
Expand Down Expand Up @@ -3772,6 +3774,7 @@ def test__fit(self):
# Setup
table_data = pd.DataFrame({'current_age': [21, 22, 25]})
instance = ScalarRange('current_age', 18, 20)
instance.metadata = MagicMock()

# Run
instance._fit(table_data)
Expand Down Expand Up @@ -3800,6 +3803,7 @@ def test__fit_datetime(self):
]
})
instance = ScalarRange('checkin', '2022-05-05', '2022-06-01')
instance.metadata = Mock(columns={'checkin': {'sdtype': 'datetime'}})

# Run
instance._fit(table_data)
Expand Down Expand Up @@ -3903,6 +3907,7 @@ def test__transform(self, mock_logit):
table_data = pd.DataFrame({'current_age': [21, 22, 25]})
instance = ScalarRange('current_age', 20, 28)
mock_logit.return_value = [1, 2, 3]
instance.metadata = MagicMock()

# Run
instance.fit(table_data)
Expand Down Expand Up @@ -3937,6 +3942,7 @@ def test_reverse_transform(self, mock_sigmoid):
transformed_data = pd.DataFrame({'current_age#20#28': [1, 2, 3]})
mock_sigmoid.return_value = pd.Series([21, 22, 25])
instance = ScalarRange('current_age', 20, 28)
instance.metadata = MagicMock()

# Run
instance.fit(table_data)
Expand Down

0 comments on commit 159195b

Please sign in to comment.