Skip to content

Commit

Permalink
Add lazyframe comparison operators
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed Aug 3, 2023
1 parent b81d9ec commit 19bcf60
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 0 deletions.
81 changes: 81 additions & 0 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
from polars.type_aliases import (
AsofJoinStrategy,
ClosedInterval,
ComparisonOperator,
CsvEncoding,
FillNullStrategy,
FrameInitTypes,
Expand Down Expand Up @@ -635,6 +636,68 @@ def __dataframe_consortium_standard__(
self, api_version=api_version
)

def _comp(self, other: Any, op: ComparisonOperator) -> LazyFrame:
"""Compare a DataFrame with another object."""
if isinstance(other, LazyFrame):
return self._compare_to_other_df(other, op)
else:
return self._compare_to_non_df(other, op)

def _compare_to_other_df(
self,
other: LazyFrame,
op: ComparisonOperator,
) -> LazyFrame:
"""Compare a DataFrame with another DataFrame."""
if self.columns != other.columns:
raise ValueError("DataFrame columns do not match")
# if self.shape != other.shape:
# raise ValueError("DataFrame dimensions do not match")

suffix = "__POLARS_CMP_OTHER"
other_renamed = other.select(F.all().suffix(suffix))

# we must join on row count, since we cannot concatenate two lazy frames
combined = self.with_context(other_renamed)

if op == "eq":
expr = [F.col(n) == F.col(f"{n}{suffix}") for n in self.columns]
elif op == "neq":
expr = [F.col(n) != F.col(f"{n}{suffix}") for n in self.columns]
elif op == "gt":
expr = [F.col(n) > F.col(f"{n}{suffix}") for n in self.columns]
elif op == "lt":
expr = [F.col(n) < F.col(f"{n}{suffix}") for n in self.columns]
elif op == "gt_eq":
expr = [F.col(n) >= F.col(f"{n}{suffix}") for n in self.columns]
elif op == "lt_eq":
expr = [F.col(n) <= F.col(f"{n}{suffix}") for n in self.columns]
else:
raise ValueError(f"got unexpected comparison operator: {op}")

return combined.select(expr)

def _compare_to_non_df(
self,
other: Any,
op: ComparisonOperator,
) -> LazyFrame:
"""Compare a DataFrame with a non-DataFrame object."""
if op == "eq":
return self.select(F.all() == other)
elif op == "neq":
return self.select(F.all() != other)
elif op == "gt":
return self.select(F.all() > other)
elif op == "lt":
return self.select(F.all() < other)
elif op == "gt_eq":
return self.select(F.all() >= other)
elif op == "lt_eq":
return self.select(F.all() <= other)
else:
raise ValueError(f"got unexpected comparison operator: {op}")

@property
def width(self) -> int:
"""
Expand All @@ -660,6 +723,24 @@ def __bool__(self) -> NoReturn:
"cannot be used in boolean context with and/or/not operators. "
)

def __eq__(self, other: Any) -> LazyFrame: # type: ignore[override]
return self._comp(other, "eq")

def __ne__(self, other: Any) -> LazyFrame: # type: ignore[override]
return self._comp(other, "neq")

def __gt__(self, other: Any) -> LazyFrame:
return self._comp(other, "gt")

def __lt__(self, other: Any) -> LazyFrame:
return self._comp(other, "lt")

def __ge__(self, other: Any) -> LazyFrame:
return self._comp(other, "gt_eq")

def __le__(self, other: Any) -> LazyFrame:
return self._comp(other, "lt_eq")

def __contains__(self, key: str) -> bool:
return key in self.columns

Expand Down
56 changes: 56 additions & 0 deletions py-polars/tests/unit/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1459,3 +1459,59 @@ def test_compare_aggregation_between_lazy_and_eager_6904(
dtype_eager = result_eager["x"].dtype
result_lazy = df.lazy().select(func.over("y")).select(pl.col(dtype_eager)).collect()
assert result_eager.frame_equal(result_lazy)


def test_lazy_comparison_operators() -> None:
df1 = pl.DataFrame(
{
"a": [1, 2, 3],
"b": ["a", "b", "c"],
}
).lazy()
df2 = pl.DataFrame(
{
"a": [1, 2, 3],
"b": ["a", "b", "d"],
}
).lazy()

assert (df1 == df2).collect().rows() == [(True, True), (True, True), (True, False)]
assert (df1 < df2).collect().rows() == [
(False, False),
(False, False),
(False, True),
]
assert (df1 <= df2).collect().rows() == [(True, True), (True, True), (True, True)]
assert (df1 > df2).collect().rows() == [
(False, False),
(False, False),
(False, False),
]
assert (df1 >= df2).collect().rows() == [(True, True), (True, True), (True, False)]
assert (df1 != df2).collect().rows() == [
(False, False),
(False, False),
(False, True),
]

# test with different columns
df1 = df1.with_columns(pl.lit(0).alias("c"))
with pytest.raises(ValueError, match="DataFrame columns do not match"):
(df1 == df2).collect()

# test with different # of rows
df1 = pl.DataFrame(
{
"a": [1, 2, 3],
}
).lazy()
df2 = pl.DataFrame(
{
"a": [1, 2, 3, 4],
}
).lazy()
with pytest.raises(
pl.exceptions.ComputeError,
match="cannot evaluate two series of different lengths",
):
(df1 == df2).collect()

0 comments on commit 19bcf60

Please sign in to comment.