Skip to content

Commit

Permalink
Update _get_is_datetime constraint logic to use the metadata (#1732)
Browse files: browse the repository at this point in the history
  • Loading branch information
fealho authored Jan 11, 2024
1 parent f25037b commit 69a9638
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 78 deletions.
104 changes: 72 additions & 32 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,8 @@ def __init__(self, low_column_name, high_column_name, strict_boundaries=False):
self.constraint_columns = (low_column_name, high_column_name)
self._dtype = None
self._is_datetime = None
self._low_datetime_format = None
self._high_datetime_format = None
self._nan_column_name = None

def _get_data(self, table_data):
Expand Down Expand Up @@ -435,8 +437,13 @@ def _fit(self, table_data):
The Table data.
"""
self._validate_columns_exist(table_data)
self._is_datetime = self._get_is_datetime()
self._dtype = table_data[self._high_column_name].dtypes
self._is_datetime = self._get_is_datetime()
if self._is_datetime:
self._low_datetime_format = self.metadata.columns[self._low_column_name].get(
'datetime_format')
self._high_datetime_format = self.metadata.columns[self._high_column_name].get(
'datetime_format')

def is_valid(self, table_data):
"""Check whether ``high`` is greater than ``low`` in each row.
Expand All @@ -451,8 +458,8 @@ def is_valid(self, table_data):
"""
low, high = self._get_data(table_data)
if self._is_datetime and self._dtype == 'O':
low = cast_to_datetime64(low)
high = cast_to_datetime64(high)
low = cast_to_datetime64(low, self._low_datetime_format)
high = cast_to_datetime64(high, self._high_datetime_format)

valid = pd.isna(low) | pd.isna(high) | self._operator(high, low)
return valid
Expand All @@ -476,7 +483,12 @@ def _transform(self, table_data):
"""
low, high = self._get_data(table_data)
if self._is_datetime:
diff_column = get_datetime_diff(high, low)
diff_column = get_datetime_diff(
high=high,
low=low,
high_datetime_format=self._high_datetime_format,
low_datetime_format=self._low_datetime_format
)
else:
diff_column = high - low

Expand Down Expand Up @@ -615,12 +627,12 @@ def __init__(self, column_name, relation, value):
self._diff_column_name = f'{self._column_name}#diff'
self.constraint_columns = (column_name,)
self._is_datetime = None
self._datetime_format = None
self._dtype = None
self._operator = INEQUALITY_TO_OPERATION[relation]

def _get_is_datetime(self, table_data):
column = table_data[self._column_name].to_numpy()
is_column_datetime = is_datetime_type(column)
def _get_is_datetime(self):
is_column_datetime = self.metadata.columns[self._column_name]['sdtype'] == 'datetime'
is_value_datetime = is_datetime_type(self._value)
is_datetime = is_column_datetime and is_value_datetime

Expand All @@ -641,8 +653,10 @@ def _fit(self, table_data):
The Table data.
"""
self._validate_columns_exist(table_data)
self._is_datetime = self._get_is_datetime(table_data)
self._dtype = table_data[self._column_name].dtypes
self._is_datetime = self._get_is_datetime()
if self._is_datetime:
self._datetime_format = self.metadata.columns[self._column_name].get('datetime_format')

def is_valid(self, table_data):
"""Say whether ``high`` is greater than ``low`` in each row.
Expand All @@ -657,7 +671,7 @@ def is_valid(self, table_data):
"""
column = table_data[self._column_name].to_numpy()
if self._is_datetime and self._dtype == 'O':
column = cast_to_datetime64(column)
column = cast_to_datetime64(column, datetime_format=self._datetime_format)

valid = pd.isna(column) | self._operator(column, self._value)
return valid
Expand All @@ -681,7 +695,7 @@ def _transform(self, table_data):
"""
column = table_data[self._column_name].to_numpy()
if self._is_datetime:
column = cast_to_datetime64(column)
column = cast_to_datetime64(column, datetime_format=self._datetime_format)
diff_column = abs(column - self._value)
diff_column = diff_column.astype(np.float64)
else:
Expand Down Expand Up @@ -838,18 +852,17 @@ def __init__(self, low_column_name, middle_column_name, high_column_name,
self.high_column_name = high_column_name
self.nan_column_name = None
self._is_datetime = None
self._low_datetime_format = None
self._middle_datetime_format = None
self._high_datetime_format = None
self._dtype = None
self.strict_boundaries = strict_boundaries
self._operator = operator.lt if strict_boundaries else operator.le

def _get_is_datetime(self, table_data):
low = table_data[self.low_column_name]
middle = table_data[self.middle_column_name]
high = table_data[self.high_column_name]

is_low_datetime = is_datetime_type(low)
is_middle_datetime = is_datetime_type(middle)
is_high_datetime = is_datetime_type(high)
def _get_is_datetime(self):
is_low_datetime = self.metadata.columns[self.low_column_name]['sdtype'] == 'datetime'
is_middle_datetime = self.metadata.columns[self.middle_column_name]['sdtype'] == 'datetime'
is_high_datetime = self.metadata.columns[self.high_column_name]['sdtype'] == 'datetime'
is_datetime = is_low_datetime and is_high_datetime and is_middle_datetime

if not is_datetime and any([is_low_datetime, is_middle_datetime, is_high_datetime]):
Expand All @@ -865,7 +878,14 @@ def _fit(self, table_data):
The Table data.
"""
self._dtype = table_data[self.middle_column_name].dtypes
self._is_datetime = self._get_is_datetime(table_data)
self._is_datetime = self._get_is_datetime()
if self._is_datetime:
self._low_datetime_format = self.metadata.columns[self.low_column_name].get(
'datetime_format')
self._middle_datetime_format = self.metadata.columns[self.middle_column_name].get(
'datetime_format')
self._high_datetime_format = self.metadata.columns[self.high_column_name].get(
'datetime_format')

self.low_diff_column_name = f'{self.low_column_name}#{self.middle_column_name}'
self.high_diff_column_name = f'{self.middle_column_name}#{self.high_column_name}'
Expand Down Expand Up @@ -922,8 +942,21 @@ def _transform(self, table_data):
high = table_data[self.high_column_name].to_numpy()

if self._is_datetime:
low_diff_column = get_datetime_diff(middle, low, self._dtype)
high_diff_column = get_datetime_diff(high, middle, self._dtype)
low_diff_column = get_datetime_diff(
middle,
low,
high_datetime_format=self._middle_datetime_format,
low_datetime_format=self._low_datetime_format,
dtype=self._dtype,
)
high_diff_column = get_datetime_diff(
high,
middle,
high_datetime_format=self._high_datetime_format,
low_datetime_format=self._middle_datetime_format,
dtype=self._dtype,
)

else:
low_diff_column = middle - low
high_diff_column = high - middle
Expand Down Expand Up @@ -976,7 +1009,7 @@ def _reverse_transform(self, table_data):

low = table_data[self.low_column_name].to_numpy()
if self._is_datetime and self._dtype == 'O':
low = cast_to_datetime64(low)
low = cast_to_datetime64(low, self._low_datetime_format)

middle = pd.Series(low_diff_column + low).astype(self._dtype)
table_data[self.middle_column_name] = middle
Expand Down Expand Up @@ -1078,10 +1111,8 @@ def _get_diff_column_name(self, table_data):

return token.join(components)

def _get_is_datetime(self, table_data):
data = table_data[self._column_name]

is_column_datetime = is_datetime_type(data)
def _get_is_datetime(self):
is_column_datetime = self.metadata.columns[self._column_name]['sdtype'] == 'datetime'
is_low_datetime = is_datetime_type(self._low_value)
is_high_datetime = is_datetime_type(self._high_value)
is_datetime = is_low_datetime and is_high_datetime and is_column_datetime
Expand All @@ -1099,11 +1130,15 @@ def _fit(self, table_data):
Table data.
"""
self._dtype = table_data[self._column_name].dtypes
self._is_datetime = self._get_is_datetime(table_data)
self._is_datetime = self._get_is_datetime()
self._transformed_column = self._get_diff_column_name(table_data)
if self._is_datetime:
self._low_value = cast_to_datetime64(self._low_value)
self._high_value = cast_to_datetime64(self._high_value)
self._datetime_format = self.metadata.columns[self._column_name].get(
'datetime_format')
self._low_value = cast_to_datetime64(
self._low_value, datetime_format=self._datetime_format)
self._high_value = cast_to_datetime64(
self._high_value, datetime_format=self._datetime_format)

def is_valid(self, table_data):
"""Say whether the ``column_name`` is between the ``low`` and ``high`` values.
Expand All @@ -1119,7 +1154,7 @@ def is_valid(self, table_data):
data = table_data[self._column_name]

if self._is_datetime:
data = cast_to_datetime64(data)
data = cast_to_datetime64(data, datetime_format=self._datetime_format)

satisfy_low_bound = np.logical_or(
self._operator(self._low_value, data),
Expand Down Expand Up @@ -1152,7 +1187,8 @@ def _transform(self, table_data):
"""
data = table_data[self._column_name]
if self._is_datetime:
data = cast_to_datetime64(table_data[self._column_name])
data = cast_to_datetime64(
table_data[self._column_name], datetime_format=self._datetime_format)

data = logit(data, self._low_value, self._high_value)
table_data[self._transformed_column] = data
Expand Down Expand Up @@ -1181,7 +1217,11 @@ def _reverse_transform(self, table_data):
data = data.clip(self._low_value, self._high_value)

if self._is_datetime:
table_data[self._column_name] = pd.to_datetime(data)
pandas_datetime_format = None
if self._datetime_format:
pandas_datetime_format = self._datetime_format.replace('%-', '%')
table_data[self._column_name] = pd.to_datetime(data, format=pandas_datetime_format)

else:
table_data[self._column_name] = data.astype(self._dtype)

Expand Down
21 changes: 15 additions & 6 deletions sdv/constraints/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,28 @@
import pandas as pd


def cast_to_datetime64(value):
def cast_to_datetime64(value, datetime_format=None):
"""Cast a given value to a ``numpy.datetime64`` format.
Args:
value (pandas.Series, np.ndarray, list, or str):
Input data to convert to ``numpy.datetime64``.
datetime_format (str):
Datetime format of the `value`.
Return:
``numpy.datetime64`` value or values.
"""
if datetime_format:
datetime_format = datetime_format.replace('%-', '%')

if isinstance(value, str):
value = pd.to_datetime(value).to_datetime64()
value = pd.to_datetime(value, format=datetime_format).to_datetime64()
elif isinstance(value, pd.Series):
value = value.astype('datetime64[ns]')
elif isinstance(value, (np.ndarray, list)):
value = np.array([
pd.to_datetime(item).to_datetime64()
pd.to_datetime(item, format=datetime_format).to_datetime64()
if not pd.isna(item)
else pd.NaT.to_datetime64()
for item in value
Expand Down Expand Up @@ -169,7 +174,7 @@ def revert_nans_columns(table_data, nan_column_name):
return table_data.drop(columns=nan_column_name)


def get_datetime_diff(high, low, dtype='O'):
def get_datetime_diff(high, low, high_datetime_format=None, low_datetime_format=None, dtype='O'):
"""Calculate the difference between two datetime columns.
When casting datetimes to float using ``astype``, NaT values are not automatically
Expand All @@ -181,14 +186,18 @@ def get_datetime_diff(high, low, dtype='O'):
The high column values.
low (numpy.ndarray):
The low column values.
high_datetime_format (str):
Datetime format of the `high` column.
low_datetime_format (str):
Datetime format of the `low` column.
Returns:
numpy.ndarray:
The difference between the high and low column values.
"""
if dtype == 'O':
low = cast_to_datetime64(low)
high = cast_to_datetime64(high)
low = cast_to_datetime64(low, low_datetime_format)
high = cast_to_datetime64(high, high_datetime_format)

diff_column = high - low
nan_mask = pd.isna(diff_column)
Expand Down
Loading

0 comments on commit 69a9638

Please sign in to comment.