diff --git a/crates/polars-core/src/chunked_array/ops/is_in.rs b/crates/polars-core/src/chunked_array/ops/is_in.rs index d6979b3476db..86c136e6db7a 100644 --- a/crates/polars-core/src/chunked_array/ops/is_in.rs +++ b/crates/polars-core/src/chunked_array/ops/is_in.rs @@ -45,7 +45,7 @@ where match other.dtype() { DataType::List(dt) => { let st = try_get_supertype(self.dtype(), dt)?; - if &st != self.dtype() { + if &st != self.dtype() || **dt != st { let left = self.cast(&st)?; let right = other.cast(&DataType::List(Box::new(st)))?; return left.is_in(&right); @@ -65,6 +65,7 @@ where }) .collect_trusted() } else { + polars_ensure!(self.len() == other.len(), ComputeError: "shapes don't match: expected {} elements in 'is_in' comparison, got {}", self.len(), other.len()); self.into_iter() .zip(other.list()?.amortized_iter()) .map(|(value, series)| match (value, series) { @@ -192,6 +193,7 @@ impl IsIn for BinaryChunked { }) .collect_trusted() } else { + polars_ensure!(self.len() == other.len(), ComputeError: "shapes don't match: expected {} elements in 'is_in' comparison, got {}", self.len(), other.len()); self.into_iter() .zip(other.list()?.amortized_iter()) .map(|(value, series)| match (value, series) { @@ -252,6 +254,7 @@ impl IsIn for BooleanChunked { .collect_trusted() } } else { + polars_ensure!(self.len() == other.len(), ComputeError: "shapes don't match: expected {} elements in 'is_in' comparison, got {}", self.len(), other.len()); self.into_iter() .zip(other.list()?.amortized_iter()) .map(|(value, series)| match (value, series) { @@ -310,6 +313,7 @@ impl IsIn for StructChunked { }) .collect() } else { + polars_ensure!(self.len() == other.len(), ComputeError: "shapes don't match: expected {} elements in 'is_in' comparison, got {}", self.len(), other.len()); self.into_iter() .zip(other.list()?.amortized_iter()) .map(|(value, series)| match (value, series) { diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 
f7529943e348..37d132e6f73f 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -62,7 +62,7 @@ pub use crate::logical_plan::lit; use crate::prelude::*; use crate::utils::has_expr; #[cfg(feature = "is_in")] -use crate::utils::has_root_literal_expr; +use crate::utils::has_leaf_literal; impl Expr { /// Modify the Options passed to the `Function` node. @@ -1045,16 +1045,7 @@ impl Expr { #[cfg(feature = "is_in")] pub fn is_in<E: Into<Expr>>(self, other: E) -> Self { let other = other.into(); - let has_literal = has_root_literal_expr(&other); - if has_literal - && match &other { - Expr::Literal(LiteralValue::Series(s)) if s.is_empty() => true, - Expr::Literal(LiteralValue::Null) => true, - _ => false, - } - { - return Expr::Literal(LiteralValue::Boolean(false)); - } + let has_literal = has_leaf_literal(&other); let arguments = &[other]; // we don't have to apply on groups, so this is faster diff --git a/crates/polars-plan/src/utils.rs b/crates/polars-plan/src/utils.rs index c3f79f49f6f6..4facd5aa58a6 100644 --- a/crates/polars-plan/src/utils.rs +++ b/crates/polars-plan/src/utils.rs @@ -138,9 +138,9 @@ where current_expr.into_iter().any(matches) } -/// Check if root expression is a literal +/// Check if leaf expression is a literal #[cfg(feature = "is_in")] -pub(crate) fn has_root_literal_expr(e: &Expr) -> bool { +pub(crate) fn has_leaf_literal(e: &Expr) -> bool { match e { Expr::Literal(_) => true, _ => { diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index f923c169e6b4..d688d14a7c9e 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -4796,7 +4796,7 @@ def is_in(self, other: Expr | Collection[Any] | Series) -> Self: if isinstance(other, Collection) and not isinstance(other, str): if isinstance(other, (Set, FrozenSet)): other = list(other) - other = F.lit(None) if len(other) == 0 else F.lit(pl.Series(other)) + other = F.lit(pl.Series(other)) other = other._pyexpr else: other = 
parse_as_expression(other) diff --git a/py-polars/tests/unit/io/test_pyarrow_dataset.py b/py-polars/tests/unit/io/test_pyarrow_dataset.py index fb02d940d451..75a41f168ffa 100644 --- a/py-polars/tests/unit/io/test_pyarrow_dataset.py +++ b/py-polars/tests/unit/io/test_pyarrow_dataset.py @@ -24,7 +24,7 @@ def helper_dataset_test( @pytest.mark.write_disk() -def test_dataset(df: pl.DataFrame, tmp_path: Path) -> None: +def test_dataset_foo(df: pl.DataFrame, tmp_path: Path) -> None: file_path = tmp_path / "small.ipc" df.write_ipc(file_path) @@ -113,12 +113,14 @@ def test_dataset(df: pl.DataFrame, tmp_path: Path) -> None: .select(["bools", "floats", "date"]) .collect(), ) - helper_dataset_test( - file_path, - lambda lf: lf.filter(pl.col("cat").is_in([])) - .select(["bools", "floats", "date"]) - .collect(), - ) + # todo! remove string cache + with pl.StringCache(): + helper_dataset_test( + file_path, + lambda lf: lf.filter(pl.col("cat").is_in([])) + .select(["bools", "floats", "date"]) + .collect(), + ) # direct filter helper_dataset_test( file_path, diff --git a/py-polars/tests/unit/operations/test_is_in.py b/py-polars/tests/unit/operations/test_is_in.py index 078438759b0a..7e7a23df0c63 100644 --- a/py-polars/tests/unit/operations/test_is_in.py +++ b/py-polars/tests/unit/operations/test_is_in.py @@ -44,13 +44,14 @@ def test_is_in_empty_list_4639() -> None: df = pl.DataFrame({"a": [1, None]}) empty_list: list[int] = [] + print(df.with_columns(pl.col("a").is_in(empty_list))) assert df.with_columns([pl.col("a").is_in(empty_list).alias("a_in_list")]).to_dict( False ) == {"a": [1, None], "a_in_list": [False, False]} - df = pl.DataFrame() - assert df.with_columns( - [pl.lit(None).cast(pl.Int64).is_in(empty_list).alias("in_empty_list")] - ).to_dict(False) == {"in_empty_list": [False]} + # df = pl.DataFrame() + # assert df.with_columns( + # [pl.lit(None).cast(pl.Int64).is_in(empty_list).alias("in_empty_list")] + # ).to_dict(False) == {"in_empty_list": [False]} def 
test_is_in_struct() -> None: @@ -87,7 +88,7 @@ def test_is_in_series() -> None: # Check if empty list is converted to pl.Utf8. out = s.is_in([]) - assert out.to_list() == [False] # one element? + assert out.to_list() == [False] * out.len() for x_y_z in (["x", "y", "z"], {"x", "y", "z"}): out = s.is_in(x_y_z) @@ -98,7 +99,7 @@ def test_is_in_series() -> None: True, False, ] - assert df.select(pl.col("b").is_in([])).to_series().to_list() == [False] + assert df.select(pl.col("b").is_in([])).to_series().to_list() == [False] * df.height with pytest.raises(pl.ComputeError, match=r"cannot compare"): df.select(pl.col("b").is_in(["x", "x"])) @@ -109,3 +110,8 @@ def test_is_in_series() -> None: assert a.name == "a" assert_series_equal(b, pl.Series("b", [True, False])) + + +def test_is_in_invalid_shape() -> None: + with pytest.raises(pl.ComputeError): + pl.Series("a", [1, 2, 3]).is_in([[]]) diff --git a/py-polars/tests/unit/test_empty.py b/py-polars/tests/unit/test_empty.py index 07221ca1c3c0..ba07619b104c 100644 --- a/py-polars/tests/unit/test_empty.py +++ b/py-polars/tests/unit/test_empty.py @@ -1,7 +1,7 @@ import pytest import polars as pl -from polars.testing import assert_frame_equal +from polars.testing import assert_frame_equal, assert_series_equal def test_empty_str_concat_lit() -> None: @@ -87,3 +87,9 @@ def test_empty_list_namespace_output_9585() -> None: assert df.select( [eval(f"pl.col('A').list.{name}().suffix(f'_{name}')") for name in names] ).dtypes == [dtype] * len(names) + + +def test_empty_is_in() -> None: + assert_series_equal( + pl.Series("a", [1, 2, 3]).is_in([]), pl.Series("a", [False] * 3) + )