Skip to content

Commit

Permalink
feat(python): upcast int->float and date->datetime for certain Series comparisons (#11779)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller authored Oct 21, 2023
1 parent 96b465e commit 6155e7f
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 11 deletions.
35 changes: 25 additions & 10 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,14 +481,30 @@ def _comp(self, other: Any, op: ComparisonOperator) -> Series:
return self.clone()
elif (other is False and op == "eq") or (other is True and op == "neq"):
return ~self

if isinstance(other, datetime) and self.dtype == Datetime:
time_zone = self.dtype.time_zone # type: ignore[union-attr]
if str(other.tzinfo) != str(time_zone):
raise TypeError(
f"Datetime time zone {other.tzinfo!r} does not match Series timezone {time_zone!r}"
elif isinstance(other, float) and self.dtype in INTEGER_DTYPES:
# require upcast when comparing int series to float value
self = self.cast(Float64)
f = get_ffi_func(op + "_<>", Float64, self._s)
assert f is not None
return self._from_pyseries(f(other))
elif isinstance(other, datetime):
if self.dtype == Date:
# require upcast when comparing date series to datetime
self = self.cast(Datetime("us"))
time_unit = "us"
elif self.dtype == Datetime:
# Use local time zone info
time_zone = self.dtype.time_zone # type: ignore[union-attr]
if str(other.tzinfo) != str(time_zone):
raise TypeError(
f"Datetime time zone {other.tzinfo!r} does not match Series timezone {time_zone!r}"
)
time_unit = self.dtype.time_unit # type: ignore[union-attr]
else:
raise ValueError(
f"cannot compare datetime.datetime to series of type {self.dtype}"
)
ts = _datetime_to_pl_timestamp(other, self.dtype.time_unit) # type: ignore[union-attr]
ts = _datetime_to_pl_timestamp(other, time_unit) # type: ignore[arg-type]
f = get_ffi_func(op + "_<>", Int64, self._s)
assert f is not None
return self._from_pyseries(f(ts))
Expand All @@ -497,14 +513,13 @@ def _comp(self, other: Any, op: ComparisonOperator) -> Series:
f = get_ffi_func(op + "_<>", Int64, self._s)
assert f is not None
return self._from_pyseries(f(d))
elif self.dtype == Categorical and not isinstance(other, Series):
other = Series([other])
elif isinstance(other, date) and self.dtype == Date:
d = _date_to_pl_date(other)
f = get_ffi_func(op + "_<>", Int32, self._s)
assert f is not None
return self._from_pyseries(f(d))
elif self.dtype == Categorical and not isinstance(other, Series):
other = Series([other])

if isinstance(other, Sequence) and not isinstance(other, str):
other = Series("", other, dtype_if_empty=self.dtype)
if isinstance(other, Series):
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/utils/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _negate_duration(duration: str) -> str:
return f"-{duration}"


def _datetime_to_pl_timestamp(dt: datetime, time_unit: TimeUnit | None) -> int:
def _datetime_to_pl_timestamp(dt: datetime, time_unit: TimeUnit) -> int:
"""Convert a python datetime to a timestamp in given time unit."""
if dt.tzinfo is None:
# Make sure to use UTC rather than system time zone.
Expand Down
15 changes: 15 additions & 0 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1568,6 +1568,21 @@ def test_comparisons_int_series_to_float() -> None:
assert_series_equal(srs_int - True, pl.Series([0, 1, 2, 3]))


def test_comparisons_int_series_to_float_scalar() -> None:
    """Check ``<`` and ``>`` between an integer Series and a float scalar.

    With the bound at 1.5, only the value 1 lies below it and the values
    2, 3 and 4 lie above it, so the fractional part of the scalar must be
    taken into account by the comparison.
    """
    int_values = pl.Series([1, 2, 3, 4])
    bound = 1.5

    expected_below = pl.Series([True, False, False, False])
    assert_series_equal(int_values < bound, expected_below)

    expected_above = pl.Series([False, True, True, True])
    assert_series_equal(int_values > bound, expected_above)


def test_comparisons_datetime_series_to_date_scalar() -> None:
    """Check ``<`` and ``>`` between a Series of dates and a datetime scalar.

    Despite the function name, the Series holds ``date`` values and the
    scalar is a ``datetime``. The scalar is noon on 2023-01-01, so only the
    first date (2023-01-01) is expected to compare as earlier, while the two
    following days compare as later.
    """
    date_values = pl.Series(
        [
            date(2023, 1, 1),
            date(2023, 1, 2),
            date(2023, 1, 3),
        ]
    )
    noon_on_first_day = datetime(2023, 1, 1, 12, 0, 0)

    expected_earlier = pl.Series([True, False, False])
    assert_series_equal(date_values < noon_on_first_day, expected_earlier)

    expected_later = pl.Series([False, True, True])
    assert_series_equal(date_values > noon_on_first_day, expected_later)


def test_comparisons_float_series_to_int() -> None:
srs_float = pl.Series([1.0, 2.0, 3.0, 4.0])

Expand Down

0 comments on commit 6155e7f

Please sign in to comment.