Skip to content

Commit

Permalink
Use datetime_format in constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Jan 10, 2024
1 parent 159195b commit 3b9ab77
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 22 deletions.
77 changes: 62 additions & 15 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,8 @@ def __init__(self, low_column_name, high_column_name, strict_boundaries=False):
self.constraint_columns = (low_column_name, high_column_name)
self._dtype = None
self._is_datetime = None
self._low_datetime_format = None
self._high_datetime_format = None
self._nan_column_name = None

def _get_data(self, table_data):
Expand Down Expand Up @@ -435,8 +437,13 @@ def _fit(self, table_data):
The Table data.
"""
self._validate_columns_exist(table_data)
self._is_datetime = self._get_is_datetime()
self._dtype = table_data[self._high_column_name].dtypes
self._is_datetime = self._get_is_datetime()
if self._is_datetime:
self._low_datetime_format = self.metadata.columns[self._low_column_name].get(
'datetime_format')
self._high_datetime_format = self.metadata.columns[self._high_column_name].get(
'datetime_format')

def is_valid(self, table_data):
"""Check whether ``high`` is greater than ``low`` in each row.
Expand All @@ -451,8 +458,8 @@ def is_valid(self, table_data):
"""
low, high = self._get_data(table_data)
if self._is_datetime and self._dtype == 'O':
low = cast_to_datetime64(low)
high = cast_to_datetime64(high)
low = cast_to_datetime64(low, self._low_datetime_format)
high = cast_to_datetime64(high, self._high_datetime_format)

valid = pd.isna(low) | pd.isna(high) | self._operator(high, low)
return valid
Expand All @@ -476,7 +483,12 @@ def _transform(self, table_data):
"""
low, high = self._get_data(table_data)
if self._is_datetime:
diff_column = get_datetime_diff(high, low)
diff_column = get_datetime_diff(
high=high,
low=low,
high_datetime_format=self._high_datetime_format,
low_datetime_format=self._low_datetime_format
)
else:
diff_column = high - low

Expand Down Expand Up @@ -615,6 +627,7 @@ def __init__(self, column_name, relation, value):
self._diff_column_name = f'{self._column_name}#diff'
self.constraint_columns = (column_name,)
self._is_datetime = None
self._datetime_format = None
self._dtype = None
self._operator = INEQUALITY_TO_OPERATION[relation]

Expand All @@ -640,8 +653,10 @@ def _fit(self, table_data):
The Table data.
"""
self._validate_columns_exist(table_data)
self._is_datetime = self._get_is_datetime()
self._dtype = table_data[self._column_name].dtypes
self._is_datetime = self._get_is_datetime()
if self._is_datetime:
self._datetime_format = self.metadata.columns[self._column_name].get('datetime_format')

def is_valid(self, table_data):
"""Say whether ``high`` is greater than ``low`` in each row.
Expand All @@ -656,7 +671,7 @@ def is_valid(self, table_data):
"""
column = table_data[self._column_name].to_numpy()
if self._is_datetime and self._dtype == 'O':
column = cast_to_datetime64(column)
column = cast_to_datetime64(column, datetime_format=self._datetime_format)

valid = pd.isna(column) | self._operator(column, self._value)
return valid
Expand All @@ -680,7 +695,7 @@ def _transform(self, table_data):
"""
column = table_data[self._column_name].to_numpy()
if self._is_datetime:
column = cast_to_datetime64(column)
column = cast_to_datetime64(column, datetime_format=self._datetime_format)
diff_column = abs(column - self._value)
diff_column = diff_column.astype(np.float64)
else:
Expand Down Expand Up @@ -837,6 +852,9 @@ def __init__(self, low_column_name, middle_column_name, high_column_name,
self.high_column_name = high_column_name
self.nan_column_name = None
self._is_datetime = None
self._low_datetime_format = None
self._middle_datetime_format = None
self._high_datetime_format = None
self._dtype = None
self.strict_boundaries = strict_boundaries
self._operator = operator.lt if strict_boundaries else operator.le
Expand All @@ -861,6 +879,13 @@ def _fit(self, table_data):
"""
self._dtype = table_data[self.middle_column_name].dtypes
self._is_datetime = self._get_is_datetime()
if self._is_datetime:
self._low_datetime_format = self.metadata.columns[self.low_column_name].get(
'datetime_format')
self._middle_datetime_format = self.metadata.columns[self.middle_column_name].get(
'datetime_format')
self._high_datetime_format = self.metadata.columns[self.high_column_name].get(
'datetime_format')

self.low_diff_column_name = f'{self.low_column_name}#{self.middle_column_name}'
self.high_diff_column_name = f'{self.middle_column_name}#{self.high_column_name}'
Expand Down Expand Up @@ -917,8 +942,21 @@ def _transform(self, table_data):
high = table_data[self.high_column_name].to_numpy()

if self._is_datetime:
low_diff_column = get_datetime_diff(middle, low, self._dtype)
high_diff_column = get_datetime_diff(high, middle, self._dtype)
low_diff_column = get_datetime_diff(
middle,
low,
high_datetime_format=self._middle_datetime_format,
low_datetime_format=self._low_datetime_format,
dtype=self._dtype,
)
high_diff_column = get_datetime_diff(
high,
middle,
high_datetime_format=self._high_datetime_format,
low_datetime_format=self._middle_datetime_format,
dtype=self._dtype,
)

else:
low_diff_column = middle - low
high_diff_column = high - middle
Expand Down Expand Up @@ -971,7 +1009,7 @@ def _reverse_transform(self, table_data):

low = table_data[self.low_column_name].to_numpy()
if self._is_datetime and self._dtype == 'O':
low = cast_to_datetime64(low)
low = cast_to_datetime64(low, self._low_datetime_format)

middle = pd.Series(low_diff_column + low).astype(self._dtype)
table_data[self.middle_column_name] = middle
Expand Down Expand Up @@ -1095,8 +1133,12 @@ def _fit(self, table_data):
self._is_datetime = self._get_is_datetime()
self._transformed_column = self._get_diff_column_name(table_data)
if self._is_datetime:
self._low_value = cast_to_datetime64(self._low_value)
self._high_value = cast_to_datetime64(self._high_value)
self._datetime_format = self.metadata.columns[self._column_name].get(
'datetime_format')
self._low_value = cast_to_datetime64(
self._low_value, datetime_format=self._datetime_format)
self._high_value = cast_to_datetime64(
self._high_value, datetime_format=self._datetime_format)

def is_valid(self, table_data):
"""Say whether the ``column_name`` is between the ``low`` and ``high`` values.
Expand All @@ -1112,7 +1154,7 @@ def is_valid(self, table_data):
data = table_data[self._column_name]

if self._is_datetime:
data = cast_to_datetime64(data)
data = cast_to_datetime64(data, datetime_format=self._datetime_format)

satisfy_low_bound = np.logical_or(
self._operator(self._low_value, data),
Expand Down Expand Up @@ -1145,7 +1187,8 @@ def _transform(self, table_data):
"""
data = table_data[self._column_name]
if self._is_datetime:
data = cast_to_datetime64(table_data[self._column_name])
data = cast_to_datetime64(
table_data[self._column_name], datetime_format=self._datetime_format)

data = logit(data, self._low_value, self._high_value)
table_data[self._transformed_column] = data
Expand Down Expand Up @@ -1174,7 +1217,11 @@ def _reverse_transform(self, table_data):
data = data.clip(self._low_value, self._high_value)

if self._is_datetime:
table_data[self._column_name] = pd.to_datetime(data)
pandas_datetime_format = None
if self._datetime_format:
pandas_datetime_format = self._datetime_format.replace('%-', '%')
table_data[self._column_name] = pd.to_datetime(data, format=pandas_datetime_format)

else:
table_data[self._column_name] = data.astype(self._dtype)

Expand Down
21 changes: 15 additions & 6 deletions sdv/constraints/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,28 @@
import pandas as pd


def cast_to_datetime64(value):
def cast_to_datetime64(value, datetime_format=None):
"""Cast a given value to a ``numpy.datetime64`` format.
Args:
value (pandas.Series, np.ndarray, list, or str):
Input data to convert to ``numpy.datetime64``.
datetime_format (str):
Datetime format of the `value`.
Return:
``numpy.datetime64`` value or values.
"""
if datetime_format:
datetime_format = datetime_format.replace('%-', '%')

if isinstance(value, str):
value = pd.to_datetime(value).to_datetime64()
value = pd.to_datetime(value, format=datetime_format).to_datetime64()
elif isinstance(value, pd.Series):
value = value.astype('datetime64[ns]')
elif isinstance(value, (np.ndarray, list)):
value = np.array([
pd.to_datetime(item).to_datetime64()
pd.to_datetime(item, format=datetime_format).to_datetime64()
if not pd.isna(item)
else pd.NaT.to_datetime64()
for item in value
Expand Down Expand Up @@ -169,7 +174,7 @@ def revert_nans_columns(table_data, nan_column_name):
return table_data.drop(columns=nan_column_name)


def get_datetime_diff(high, low, dtype='O'):
def get_datetime_diff(high, low, high_datetime_format=None, low_datetime_format=None, dtype='O'):
"""Calculate the difference between two datetime columns.
When casting datetimes to float using ``astype``, NaT values are not automatically
Expand All @@ -181,14 +186,18 @@ def get_datetime_diff(high, low, dtype='O'):
The high column values.
low (numpy.ndarray):
The low column values.
high_datetime_format (str):
Datetime format of the `high` column.
low_datetime_format (str):
Datetime format of the `low` column.
Returns:
numpy.ndarray:
The difference between the high and low column values.
"""
if dtype == 'O':
low = cast_to_datetime64(low)
high = cast_to_datetime64(high)
low = cast_to_datetime64(low, low_datetime_format)
high = cast_to_datetime64(high, high_datetime_format)

diff_column = high - low
nan_mask = pd.isna(diff_column)
Expand Down
7 changes: 6 additions & 1 deletion tests/unit/constraints/test_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,6 +1288,11 @@ def test__fit(self):
instance = Inequality(low_column_name='a', high_column_name='b')
instance._validate_columns_exist = Mock()
instance._get_is_datetime = Mock(return_value='abc')
instance.metadata = Mock()
instance.metadata.columns = {
'a': {'sdtype': 'datetime', 'datetime_format': '%y %m, %d'},
'b': {'sdtype': 'datetime', 'datetime_format': '%y %m, %d'},
}

# Run
instance._fit(table_data)
Expand Down Expand Up @@ -3981,7 +3986,7 @@ def test_reverse_transform_is_datetime(self, mock_sigmoid, mock_pd):
instance = ScalarRange('current_age', 20, 28)
instance._transformed_column = 'current_age#20#28'
instance._is_datetime = True
mock_pd.to_datetime.side_effect = lambda x: pd.to_datetime('2021-02-02 10:10:59')
mock_pd.to_datetime.side_effect = lambda x, format: pd.to_datetime('2021-02-02 10:10:59')

# Run
output = instance.reverse_transform(transformed_data)
Expand Down
29 changes: 29 additions & 0 deletions tests/unit/constraints/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,35 @@ def test_cast_to_datetime64():
assert expected_string_output == string_out


def test_cast_to_datetime64_datetime_format():
"""Test it when `datetime_format` is passed."""
# Setup
string_value = '2021-02-02'
list_value = [None, np.nan, '2021-02-02']
series_value = pd.Series(['2021-02-02', None, pd.NaT])

# Run
string_out = cast_to_datetime64(string_value, datetime_format='%Y-%m-%d')
list_out = cast_to_datetime64(list_value, datetime_format='%Y-%m-%d')
series_out = cast_to_datetime64(series_value, datetime_format='%Y-%m-%d')

# Assert
expected_string_output = np.datetime64('2021-02-02')
expected_series_output = pd.Series([
np.datetime64('2021-02-02'),
np.datetime64('NaT'),
np.datetime64('NaT')
])
expected_list_output = np.array([
np.datetime64('NaT'),
np.datetime64('NaT'),
'2021-02-02'
], dtype='datetime64[ns]')
np.testing.assert_array_equal(expected_list_output, list_out)
pd.testing.assert_series_equal(expected_series_output, series_out)
assert expected_string_output == string_out


def test_matches_datetime_format():
"""Test the ``matches_datetime_format`` method.
Expand Down

0 comments on commit 3b9ab77

Please sign in to comment.