-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[feat] Support for numeric date feature inputs (#3517)
- Loading branch information
1 parent
6552adc
commit 3a8d65e
Showing
5 changed files
with
292 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import datetime | ||
import time | ||
|
||
import pandas as pd | ||
import pytest | ||
from dateutil.parser import parse | ||
|
||
from ludwig.api import LudwigModel | ||
from ludwig.constants import ( | ||
BACKEND, | ||
BINARY, | ||
DATE, | ||
EPOCHS, | ||
FILL_WITH_CONST, | ||
INPUT_FEATURES, | ||
MISSING_VALUE_STRATEGY, | ||
NAME, | ||
OUTPUT_FEATURES, | ||
PREPROCESSING, | ||
RAY, | ||
TRAINER, | ||
TYPE, | ||
) | ||
from ludwig.utils.date_utils import create_vector_from_datetime_obj | ||
|
||
ray = pytest.importorskip("ray") | ||
|
||
pytestmark = [ | ||
pytest.mark.distributed, | ||
] | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def string_date_df() -> "pd.DataFrame": | ||
df = pd.DataFrame.from_dict( | ||
{ | ||
"date_feature": [str(datetime.datetime.now()) for i in range(100)], | ||
"binary_feature": [i % 2 for i in range(100)], | ||
} | ||
) | ||
return df | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def int_date_df() -> "pd.DataFrame": | ||
df = pd.DataFrame.from_dict( | ||
{ | ||
"date_feature": [time.time_ns() for i in range(100)], | ||
"binary_feature": [i % 2 for i in range(100)], | ||
} | ||
) | ||
return df | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def float_date_df() -> "pd.DataFrame": | ||
df = pd.DataFrame.from_dict( | ||
{ | ||
"date_feature": [time.time() for i in range(100)], | ||
"binary_feature": [i % 2 for i in range(100)], | ||
} | ||
) | ||
return df | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"date_df", | ||
[ | ||
pytest.param("string_date_df", id="string_date"), | ||
pytest.param("int_date_df", id="int_date"), | ||
pytest.param("float_date_df", id="float_date"), | ||
], | ||
) | ||
def test_date_feature_formats(date_df, request, ray_cluster_2cpu): | ||
df = request.getfixturevalue(date_df) | ||
|
||
config = { | ||
INPUT_FEATURES: [ | ||
{ | ||
NAME: "date_feature", | ||
TYPE: DATE, | ||
PREPROCESSING: {MISSING_VALUE_STRATEGY: FILL_WITH_CONST, "fill_value": "1970-01-01 00:00:00"}, | ||
} | ||
], | ||
OUTPUT_FEATURES: [{NAME: "binary_feature", TYPE: BINARY}], | ||
TRAINER: {EPOCHS: 2}, | ||
BACKEND: {TYPE: RAY, "processor": {TYPE: "dask"}}, | ||
} | ||
|
||
fill_value = create_vector_from_datetime_obj(parse("1970-01-01 00:00:00")) | ||
|
||
model = LudwigModel(config) | ||
preprocessed = model.preprocess(df) | ||
|
||
# Because parsing errors are suppressed, we want to ensure that the data was preprocessed correctly. Sample data is | ||
# drawn from the current time, so the recorded years should not match the fill value's year. | ||
for date in preprocessed.training_set.to_df().compute().iloc[:, 0].values: | ||
assert date[0] != fill_value[0] | ||
|
||
for date in preprocessed.validation_set.to_df().compute().iloc[:, 0].values: | ||
assert date[0] != fill_value[0] | ||
|
||
for date in preprocessed.test_set.to_df().compute().iloc[:, 0].values: | ||
assert date[0] != fill_value[0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import datetime | ||
from contextlib import nullcontext as does_not_raise | ||
from typing import Any, ContextManager | ||
|
||
import pytest | ||
|
||
from ludwig.utils.date_utils import convert_number_to_datetime | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
def reference_datetime() -> datetime.datetime: | ||
return datetime.datetime.utcfromtimestamp(1691600953.443032) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"timestamp,raises", | ||
[ | ||
pytest.param(1691600953.443032, does_not_raise(), id="float-s"), | ||
pytest.param(1691600953443.032, does_not_raise(), id="float-ms"), | ||
pytest.param(1691600953, does_not_raise(), id="int-s"), | ||
pytest.param(1691600953443, does_not_raise(), id="int-ms"), | ||
pytest.param("1691600953.443032", does_not_raise(), id="string[float]-s"), | ||
pytest.param("1691600953443.0032", does_not_raise(), id="string[float]-ms"), | ||
pytest.param("1691600953", does_not_raise(), id="string[int]-s"), | ||
pytest.param("1691600953443", does_not_raise(), id="string[int]-ms"), | ||
pytest.param("foo", pytest.raises(ValueError), id="string error"), | ||
pytest.param([1691600953.443032], pytest.raises(ValueError), id="list error"), | ||
pytest.param(datetime.datetime(2023, 8, 9, 13, 9, 13), pytest.raises(ValueError), id="datetime error"), | ||
pytest.param(None, pytest.raises(ValueError), id="NoneType error"), | ||
], | ||
) | ||
def test_convert_number_to_datetime(reference_datetime: datetime.datetime, timestamp: Any, raises: ContextManager): | ||
"""Ensure that numeric timestamps are correctly converted to datetime objects. | ||
Args: | ||
reference_datetime: A datetime object with the expected date/time | ||
timestamp: The timestamp to convert in s or ms | ||
raises: context manager to check for expected exceptions | ||
""" | ||
with raises: | ||
dt = convert_number_to_datetime(timestamp) | ||
|
||
# Check that the returned datetime is accurate to the scale of seconds. | ||
assert dt.year == reference_datetime.year | ||
assert dt.month == reference_datetime.month | ||
assert dt.day == reference_datetime.day | ||
assert dt.hour == reference_datetime.hour | ||
assert dt.minute == reference_datetime.minute | ||
assert dt.second == reference_datetime.second |