Skip to content

Commit

Permalink
[bug] Support preprocessing datetime.date date features (#3534)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffkinnison authored Aug 16, 2023
1 parent 090918d commit 1b0774f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
4 changes: 3 additions & 1 deletion ludwig/features/date_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
# ==============================================================================
import logging
from datetime import datetime
from datetime import date, datetime
from typing import Dict, List

import numpy as np
Expand Down Expand Up @@ -66,6 +66,8 @@ def date_to_list(date_value, datetime_format, preprocessing_parameters):
try:
if isinstance(date_value, datetime):
datetime_obj = date_value
elif isinstance(date_value, date):
datetime_obj = datetime.combine(date=date_value, time=datetime.min.time())
elif isinstance(date_value, str) and datetime_format is not None:
try:
datetime_obj = datetime.strptime(date_value, datetime_format)
Expand Down
25 changes: 24 additions & 1 deletion tests/ludwig/features/test_date_feature.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from datetime import datetime
from datetime import date, datetime
from typing import Any, List

import pytest
Expand Down Expand Up @@ -157,3 +157,26 @@ def test_date_to_list__UsesFillValueOnInvalidDate():
0,
0,
]


@pytest.fixture(scope="module")
def date_obj():
return date.fromisoformat("2022-06-25")


@pytest.fixture(scope="module")
def date_obj_vec():
return create_vector_from_datetime_obj(datetime.fromisoformat("2022-06-25"))


def test_date_object_to_list(date_obj, date_obj_vec, fill_value):
"""Test support for datetime.date object conversion.
Args:
date_obj: Date object to convert into a vector
date_obj_vector: Expected vector version of `date_obj`
"""
computed_date_vec = date_feature.DateInputFeature.date_to_list(
date_obj, None, preprocessing_parameters={MISSING_VALUE_STRATEGY: FILL_WITH_CONST, "fill_value": fill_value}
)
assert computed_date_vec == date_obj_vec

0 comments on commit 1b0774f

Please sign in to comment.