Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update _get_is_datetime constraint logic to use the metadata #1732

Merged
merged 4 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 72 additions & 32 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,8 @@ def __init__(self, low_column_name, high_column_name, strict_boundaries=False):
self.constraint_columns = (low_column_name, high_column_name)
self._dtype = None
self._is_datetime = None
self._low_datetime_format = None
self._high_datetime_format = None
self._nan_column_name = None

def _get_data(self, table_data):
Expand Down Expand Up @@ -435,8 +437,13 @@ def _fit(self, table_data):
The Table data.
"""
self._validate_columns_exist(table_data)
self._is_datetime = self._get_is_datetime()
self._dtype = table_data[self._high_column_name].dtypes
self._is_datetime = self._get_is_datetime()
if self._is_datetime:
self._low_datetime_format = self.metadata.columns[self._low_column_name].get(
'datetime_format')
self._high_datetime_format = self.metadata.columns[self._high_column_name].get(
'datetime_format')

def is_valid(self, table_data):
"""Check whether ``high`` is greater than ``low`` in each row.
Expand All @@ -451,8 +458,8 @@ def is_valid(self, table_data):
"""
low, high = self._get_data(table_data)
if self._is_datetime and self._dtype == 'O':
low = cast_to_datetime64(low)
high = cast_to_datetime64(high)
low = cast_to_datetime64(low, self._low_datetime_format)
high = cast_to_datetime64(high, self._high_datetime_format)

valid = pd.isna(low) | pd.isna(high) | self._operator(high, low)
return valid
Expand All @@ -476,7 +483,12 @@ def _transform(self, table_data):
"""
low, high = self._get_data(table_data)
if self._is_datetime:
diff_column = get_datetime_diff(high, low)
diff_column = get_datetime_diff(
high=high,
low=low,
high_datetime_format=self._high_datetime_format,
low_datetime_format=self._low_datetime_format
)
else:
diff_column = high - low

Expand Down Expand Up @@ -615,12 +627,12 @@ def __init__(self, column_name, relation, value):
self._diff_column_name = f'{self._column_name}#diff'
self.constraint_columns = (column_name,)
self._is_datetime = None
self._datetime_format = None
self._dtype = None
self._operator = INEQUALITY_TO_OPERATION[relation]

def _get_is_datetime(self, table_data):
column = table_data[self._column_name].to_numpy()
is_column_datetime = is_datetime_type(column)
def _get_is_datetime(self):
is_column_datetime = self.metadata.columns[self._column_name]['sdtype'] == 'datetime'
is_value_datetime = is_datetime_type(self._value)
is_datetime = is_column_datetime and is_value_datetime

Expand All @@ -641,8 +653,10 @@ def _fit(self, table_data):
The Table data.
"""
self._validate_columns_exist(table_data)
self._is_datetime = self._get_is_datetime(table_data)
self._dtype = table_data[self._column_name].dtypes
self._is_datetime = self._get_is_datetime()
if self._is_datetime:
self._datetime_format = self.metadata.columns[self._column_name].get('datetime_format')

def is_valid(self, table_data):
"""Say whether the column satisfies the inequality against the scalar ``value`` in each row.
Expand All @@ -657,7 +671,7 @@ def is_valid(self, table_data):
"""
column = table_data[self._column_name].to_numpy()
if self._is_datetime and self._dtype == 'O':
column = cast_to_datetime64(column)
column = cast_to_datetime64(column, datetime_format=self._datetime_format)

valid = pd.isna(column) | self._operator(column, self._value)
return valid
Expand All @@ -681,7 +695,7 @@ def _transform(self, table_data):
"""
column = table_data[self._column_name].to_numpy()
if self._is_datetime:
column = cast_to_datetime64(column)
column = cast_to_datetime64(column, datetime_format=self._datetime_format)
diff_column = abs(column - self._value)
diff_column = diff_column.astype(np.float64)
else:
Expand Down Expand Up @@ -838,18 +852,17 @@ def __init__(self, low_column_name, middle_column_name, high_column_name,
self.high_column_name = high_column_name
self.nan_column_name = None
self._is_datetime = None
self._low_datetime_format = None
self._middle_datetime_format = None
self._high_datetime_format = None
self._dtype = None
self.strict_boundaries = strict_boundaries
self._operator = operator.lt if strict_boundaries else operator.le

def _get_is_datetime(self, table_data):
low = table_data[self.low_column_name]
middle = table_data[self.middle_column_name]
high = table_data[self.high_column_name]

is_low_datetime = is_datetime_type(low)
is_middle_datetime = is_datetime_type(middle)
is_high_datetime = is_datetime_type(high)
def _get_is_datetime(self):
is_low_datetime = self.metadata.columns[self.low_column_name]['sdtype'] == 'datetime'
is_middle_datetime = self.metadata.columns[self.middle_column_name]['sdtype'] == 'datetime'
is_high_datetime = self.metadata.columns[self.high_column_name]['sdtype'] == 'datetime'
is_datetime = is_low_datetime and is_high_datetime and is_middle_datetime

if not is_datetime and any([is_low_datetime, is_middle_datetime, is_high_datetime]):
Expand All @@ -865,7 +878,14 @@ def _fit(self, table_data):
The Table data.
"""
self._dtype = table_data[self.middle_column_name].dtypes
self._is_datetime = self._get_is_datetime(table_data)
self._is_datetime = self._get_is_datetime()
if self._is_datetime:
self._low_datetime_format = self.metadata.columns[self.low_column_name].get(
'datetime_format')
self._middle_datetime_format = self.metadata.columns[self.middle_column_name].get(
'datetime_format')
self._high_datetime_format = self.metadata.columns[self.high_column_name].get(
'datetime_format')

self.low_diff_column_name = f'{self.low_column_name}#{self.middle_column_name}'
self.high_diff_column_name = f'{self.middle_column_name}#{self.high_column_name}'
Expand Down Expand Up @@ -922,8 +942,21 @@ def _transform(self, table_data):
high = table_data[self.high_column_name].to_numpy()

if self._is_datetime:
low_diff_column = get_datetime_diff(middle, low, self._dtype)
high_diff_column = get_datetime_diff(high, middle, self._dtype)
low_diff_column = get_datetime_diff(
middle,
low,
high_datetime_format=self._middle_datetime_format,
low_datetime_format=self._low_datetime_format,
dtype=self._dtype,
)
high_diff_column = get_datetime_diff(
high,
middle,
high_datetime_format=self._high_datetime_format,
low_datetime_format=self._middle_datetime_format,
dtype=self._dtype,
)

else:
low_diff_column = middle - low
high_diff_column = high - middle
Expand Down Expand Up @@ -976,7 +1009,7 @@ def _reverse_transform(self, table_data):

low = table_data[self.low_column_name].to_numpy()
if self._is_datetime and self._dtype == 'O':
low = cast_to_datetime64(low)
low = cast_to_datetime64(low, self._low_datetime_format)

middle = pd.Series(low_diff_column + low).astype(self._dtype)
table_data[self.middle_column_name] = middle
Expand Down Expand Up @@ -1078,10 +1111,8 @@ def _get_diff_column_name(self, table_data):

return token.join(components)

def _get_is_datetime(self, table_data):
data = table_data[self._column_name]

is_column_datetime = is_datetime_type(data)
def _get_is_datetime(self):
is_column_datetime = self.metadata.columns[self._column_name]['sdtype'] == 'datetime'
is_low_datetime = is_datetime_type(self._low_value)
is_high_datetime = is_datetime_type(self._high_value)
is_datetime = is_low_datetime and is_high_datetime and is_column_datetime
Expand All @@ -1099,11 +1130,15 @@ def _fit(self, table_data):
Table data.
"""
self._dtype = table_data[self._column_name].dtypes
self._is_datetime = self._get_is_datetime(table_data)
self._is_datetime = self._get_is_datetime()
self._transformed_column = self._get_diff_column_name(table_data)
if self._is_datetime:
self._low_value = cast_to_datetime64(self._low_value)
self._high_value = cast_to_datetime64(self._high_value)
self._datetime_format = self.metadata.columns[self._column_name].get(
'datetime_format')
self._low_value = cast_to_datetime64(
self._low_value, datetime_format=self._datetime_format)
self._high_value = cast_to_datetime64(
self._high_value, datetime_format=self._datetime_format)

def is_valid(self, table_data):
"""Say whether the ``column_name`` is between the ``low`` and ``high`` values.
Expand All @@ -1119,7 +1154,7 @@ def is_valid(self, table_data):
data = table_data[self._column_name]

if self._is_datetime:
data = cast_to_datetime64(data)
data = cast_to_datetime64(data, datetime_format=self._datetime_format)

satisfy_low_bound = np.logical_or(
self._operator(self._low_value, data),
Expand Down Expand Up @@ -1152,7 +1187,8 @@ def _transform(self, table_data):
"""
data = table_data[self._column_name]
if self._is_datetime:
data = cast_to_datetime64(table_data[self._column_name])
data = cast_to_datetime64(
table_data[self._column_name], datetime_format=self._datetime_format)

data = logit(data, self._low_value, self._high_value)
table_data[self._transformed_column] = data
Expand Down Expand Up @@ -1181,7 +1217,11 @@ def _reverse_transform(self, table_data):
data = data.clip(self._low_value, self._high_value)

if self._is_datetime:
table_data[self._column_name] = pd.to_datetime(data)
pandas_datetime_format = None
if self._datetime_format:
pandas_datetime_format = self._datetime_format.replace('%-', '%')
table_data[self._column_name] = pd.to_datetime(data, format=pandas_datetime_format)

else:
table_data[self._column_name] = data.astype(self._dtype)

Expand Down
21 changes: 15 additions & 6 deletions sdv/constraints/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,28 @@
import pandas as pd


def cast_to_datetime64(value):
def cast_to_datetime64(value, datetime_format=None):
"""Cast a given value to a ``numpy.datetime64`` format.

Args:
value (pandas.Series, np.ndarray, list, or str):
Input data to convert to ``numpy.datetime64``.
datetime_format (str):
Datetime format of the ``value``.

Return:
``numpy.datetime64`` value or values.
"""
if datetime_format:
datetime_format = datetime_format.replace('%-', '%')

if isinstance(value, str):
value = pd.to_datetime(value).to_datetime64()
value = pd.to_datetime(value, format=datetime_format).to_datetime64()
elif isinstance(value, pd.Series):
value = value.astype('datetime64[ns]')
elif isinstance(value, (np.ndarray, list)):
value = np.array([
pd.to_datetime(item).to_datetime64()
pd.to_datetime(item, format=datetime_format).to_datetime64()
if not pd.isna(item)
else pd.NaT.to_datetime64()
for item in value
Expand Down Expand Up @@ -169,7 +174,7 @@ def revert_nans_columns(table_data, nan_column_name):
return table_data.drop(columns=nan_column_name)


def get_datetime_diff(high, low, dtype='O'):
def get_datetime_diff(high, low, high_datetime_format=None, low_datetime_format=None, dtype='O'):
"""Calculate the difference between two datetime columns.

When casting datetimes to float using ``astype``, NaT values are not automatically
Expand All @@ -181,14 +186,18 @@ def get_datetime_diff(high, low, dtype='O'):
The high column values.
low (numpy.ndarray):
The low column values.
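high_datetime_format (str):
Datetime format of the ``high`` column.
low_datetime_format (str):
Datetime format of the ``low`` column.
dtype (str or numpy.dtype):
Dtype of the columns. When it is ``'O'`` (the default), both columns are cast to ``numpy.datetime64`` using their respective formats before the difference is computed.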

Returns:
numpy.ndarray:
The difference between the high and low column values.
"""
if dtype == 'O':
low = cast_to_datetime64(low)
high = cast_to_datetime64(high)
low = cast_to_datetime64(low, low_datetime_format)
high = cast_to_datetime64(high, high_datetime_format)

diff_column = high - low
nan_mask = pd.isna(diff_column)
Expand Down
Loading
Loading