Skip to content

Commit

Permalink
fix(rust, python): fix is_in on empty series (#10195)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jul 31, 2023
1 parent f8c4c4e commit e3437ec
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 29 deletions.
6 changes: 5 additions & 1 deletion crates/polars-core/src/chunked_array/ops/is_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ where
match other.dtype() {
DataType::List(dt) => {
let st = try_get_supertype(self.dtype(), dt)?;
if &st != self.dtype() {
if &st != self.dtype() || **dt != st {
let left = self.cast(&st)?;
let right = other.cast(&DataType::List(Box::new(st)))?;
return left.is_in(&right);
Expand All @@ -65,6 +65,7 @@ where
})
.collect_trusted()
} else {
polars_ensure!(self.len() == other.len(), ComputeError: "shapes don't match: expected {} elements in 'is_in' comparison, got {}", self.len(), other.len());
self.into_iter()
.zip(other.list()?.amortized_iter())
.map(|(value, series)| match (value, series) {
Expand Down Expand Up @@ -192,6 +193,7 @@ impl IsIn for BinaryChunked {
})
.collect_trusted()
} else {
polars_ensure!(self.len() == other.len(), ComputeError: "shapes don't match: expected {} elements in 'is_in' comparison, got {}", self.len(), other.len());
self.into_iter()
.zip(other.list()?.amortized_iter())
.map(|(value, series)| match (value, series) {
Expand Down Expand Up @@ -252,6 +254,7 @@ impl IsIn for BooleanChunked {
.collect_trusted()
}
} else {
polars_ensure!(self.len() == other.len(), ComputeError: "shapes don't match: expected {} elements in 'is_in' comparison, got {}", self.len(), other.len());
self.into_iter()
.zip(other.list()?.amortized_iter())
.map(|(value, series)| match (value, series) {
Expand Down Expand Up @@ -310,6 +313,7 @@ impl IsIn for StructChunked {
})
.collect()
} else {
polars_ensure!(self.len() == other.len(), ComputeError: "shapes don't match: expected {} elements in 'is_in' comparison, got {}", self.len(), other.len());
self.into_iter()
.zip(other.list()?.amortized_iter())
.map(|(value, series)| match (value, series) {
Expand Down
13 changes: 2 additions & 11 deletions crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ pub use crate::logical_plan::lit;
use crate::prelude::*;
use crate::utils::has_expr;
#[cfg(feature = "is_in")]
use crate::utils::has_root_literal_expr;
use crate::utils::has_leaf_literal;

impl Expr {
/// Modify the Options passed to the `Function` node.
Expand Down Expand Up @@ -1045,16 +1045,7 @@ impl Expr {
#[cfg(feature = "is_in")]
pub fn is_in<E: Into<Expr>>(self, other: E) -> Self {
let other = other.into();
let has_literal = has_root_literal_expr(&other);
if has_literal
&& match &other {
Expr::Literal(LiteralValue::Series(s)) if s.is_empty() => true,
Expr::Literal(LiteralValue::Null) => true,
_ => false,
}
{
return Expr::Literal(LiteralValue::Boolean(false));
}
let has_literal = has_leaf_literal(&other);

let arguments = &[other];
// we don't have to apply on groups, so this is faster
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,9 @@ where
current_expr.into_iter().any(matches)
}

/// Check if root expression is a literal
/// Check if leaf expression is a literal
#[cfg(feature = "is_in")]
pub(crate) fn has_root_literal_expr(e: &Expr) -> bool {
pub(crate) fn has_leaf_literal(e: &Expr) -> bool {
match e {
Expr::Literal(_) => true,
_ => {
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4796,7 +4796,7 @@ def is_in(self, other: Expr | Collection[Any] | Series) -> Self:
if isinstance(other, Collection) and not isinstance(other, str):
if isinstance(other, (Set, FrozenSet)):
other = list(other)
other = F.lit(None) if len(other) == 0 else F.lit(pl.Series(other))
other = F.lit(pl.Series(other))
other = other._pyexpr
else:
other = parse_as_expression(other)
Expand Down
16 changes: 9 additions & 7 deletions py-polars/tests/unit/io/test_pyarrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def helper_dataset_test(


@pytest.mark.write_disk()
def test_dataset(df: pl.DataFrame, tmp_path: Path) -> None:
def test_dataset_foo(df: pl.DataFrame, tmp_path: Path) -> None:
file_path = tmp_path / "small.ipc"
df.write_ipc(file_path)

Expand Down Expand Up @@ -113,12 +113,14 @@ def test_dataset(df: pl.DataFrame, tmp_path: Path) -> None:
.select(["bools", "floats", "date"])
.collect(),
)
helper_dataset_test(
file_path,
lambda lf: lf.filter(pl.col("cat").is_in([]))
.select(["bools", "floats", "date"])
.collect(),
)
# todo! remove string cache
with pl.StringCache():
helper_dataset_test(
file_path,
lambda lf: lf.filter(pl.col("cat").is_in([]))
.select(["bools", "floats", "date"])
.collect(),
)
# direct filter
helper_dataset_test(
file_path,
Expand Down
18 changes: 12 additions & 6 deletions py-polars/tests/unit/operations/test_is_in.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,14 @@ def test_is_in_empty_list_4639() -> None:
df = pl.DataFrame({"a": [1, None]})
empty_list: list[int] = []

print(df.with_columns(pl.col("a").is_in(empty_list)))
assert df.with_columns([pl.col("a").is_in(empty_list).alias("a_in_list")]).to_dict(
False
) == {"a": [1, None], "a_in_list": [False, False]}
df = pl.DataFrame()
assert df.with_columns(
[pl.lit(None).cast(pl.Int64).is_in(empty_list).alias("in_empty_list")]
).to_dict(False) == {"in_empty_list": [False]}
# df = pl.DataFrame()
# assert df.with_columns(
# [pl.lit(None).cast(pl.Int64).is_in(empty_list).alias("in_empty_list")]
# ).to_dict(False) == {"in_empty_list": [False]}


def test_is_in_struct() -> None:
Expand Down Expand Up @@ -87,7 +88,7 @@ def test_is_in_series() -> None:

# Check if empty list is converted to pl.Utf8.
out = s.is_in([])
assert out.to_list() == [False] # one element?
assert out.to_list() == [False] * out.len()

for x_y_z in (["x", "y", "z"], {"x", "y", "z"}):
out = s.is_in(x_y_z)
Expand All @@ -98,7 +99,7 @@ def test_is_in_series() -> None:
True,
False,
]
assert df.select(pl.col("b").is_in([])).to_series().to_list() == [False]
assert df.select(pl.col("b").is_in([])).to_series().to_list() == [False] * df.height

with pytest.raises(pl.ComputeError, match=r"cannot compare"):
df.select(pl.col("b").is_in(["x", "x"]))
Expand All @@ -109,3 +110,8 @@ def test_is_in_series() -> None:

assert a.name == "a"
assert_series_equal(b, pl.Series("b", [True, False]))


def test_is_in_invalid_shape() -> None:
with pytest.raises(pl.ComputeError):
pl.Series("a", [1, 2, 3]).is_in([[]])
8 changes: 7 additions & 1 deletion py-polars/tests/unit/test_empty.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

import polars as pl
from polars.testing import assert_frame_equal
from polars.testing import assert_frame_equal, assert_series_equal


def test_empty_str_concat_lit() -> None:
Expand Down Expand Up @@ -87,3 +87,9 @@ def test_empty_list_namespace_output_9585() -> None:
assert df.select(
[eval(f"pl.col('A').list.{name}().suffix(f'_{name}')") for name in names]
).dtypes == [dtype] * len(names)


def test_empty_is_in() -> None:
assert_series_equal(
pl.Series("a", [1, 2, 3]).is_in([]), pl.Series("a", [False] * 3)
)

0 comments on commit e3437ec

Please sign in to comment.