From 31c0e44c085aa4f6bcd1f7f0eeb53171c96f56dd Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Fri, 19 Jul 2024 11:54:37 -0400 Subject: [PATCH 01/30] A more functional sketch. --- .../src/series/arithmetic/borrowed.rs | 46 +++++++++++++++++++ .../src/series/implementations/list.rs | 18 ++++++++ 2 files changed, 64 insertions(+) diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index 6cecab742ffd..7612859617c3 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -171,6 +171,52 @@ impl NumOpsDispatchInner for FixedSizeListType { } } +impl ListChunked { + fn arithm_helper( + &self, + rhs: &Series, + op: &dyn Fn(&Series, &Series) -> PolarsResult, + ) -> PolarsResult { + polars_ensure!(self.len() == rhs.len(), InvalidOperation: "can only do arithmetic operations on Series of the same size; got {} and {}", self.len(), rhs.len()); + + // TODO ensure same dtype? + let mut result = self.clear(); + let combined = self.amortized_iter().zip(rhs.list()?.amortized_iter()).map(|(a, b)| { + // We ensured the original Series are the same length, so we can + // assume no None: + let a_owner = a.unwrap(); + let b_owner = b.unwrap(); + let a = a_owner.as_ref(); + let b = b_owner.as_ref(); + polars_ensure!(a.len() == b.len(), InvalidOperation: "can only do arithmetic operations on lists of the same size; got {} and {}", a.len(), b.len()); + let result = op(a, b).and_then(|s| s.implode()).map(|ca|Series::from(ca)); + result + }); + for c in combined.into_iter() { + result.append(c?.list()?)?; + } + Ok(result.into()) + } +} + +impl NumOpsDispatchInner for ListType { + fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + lhs.arithm_helper(rhs, &|l, r| l.add_to(r)) + } + fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + lhs.arithm_helper(rhs, &|l, r| l.subtract(r)) + } + fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + lhs.arithm_helper(rhs, &|l, r| l.multiply(r)) + } + fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + lhs.arithm_helper(rhs, &|l, r| l.divide(r)) + } + fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + lhs.arithm_helper(rhs, &|l, r| l.remainder(r)) + } +} + #[cfg(feature = "checked_arithmetic")] pub mod checked { use num_traits::{CheckedDiv, One, ToPrimitive, Zero}; diff --git a/crates/polars-core/src/series/implementations/list.rs b/crates/polars-core/src/series/implementations/list.rs index a67dc8e8f487..c66fd2ce8492 100644 --- a/crates/polars-core/src/series/implementations/list.rs +++ b/crates/polars-core/src/series/implementations/list.rs @@ -47,6 +47,24 @@ impl private::PrivateSeries for SeriesWrap { fn into_total_eq_inner<'a>(&'a self) -> Box { (&self.0).into_total_eq_inner() } + + fn add_to(&self, rhs: &Series) -> PolarsResult { + self.0.add_to(rhs) + } + + fn subtract(&self, rhs: &Series) -> PolarsResult { + self.0.subtract(rhs) + } + + fn multiply(&self, rhs: &Series) -> PolarsResult { + self.0.multiply(rhs) + } + fn divide(&self, rhs: &Series) -> PolarsResult { + self.0.divide(rhs) + } + fn remainder(&self, rhs: &Series) -> PolarsResult { + self.0.remainder(rhs) + } } impl SeriesTrait for SeriesWrap { From 559d9d27eb9e73eb7237f25ceaefb87752c0e2fc Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 23 Jul 2024 10:48:06 -0400 Subject: [PATCH 02/30] Make division work for `list[int64]` (and arrays too) --- py-polars/polars/series/series.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 8b49a3a4bd8f..d87f2b268433 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -1061,6 +1061,22 @@ def __sub__(self, other: Any) -> Self | Expr: return F.lit(self) - other return self._arithmetic(other, "sub", "sub_<>") + def _recursive_cast_to_float64(self) -> Series: + """ + Traverse dtype recursively, eventually converting leaf integer dtypes + to Float64 dtypes. + """ + + def convert_to_float64(dtype: DataType) -> DataType: + if isinstance(dtype, Array): + return Array(convert_to_float64(dtype.inner), shape=dtype.shape) + if isinstance(dtype, List): + return List(convert_to_float64(dtype.inner)) + # TODO are there other types to handle? Struct? + return Float64 + + return self.cast(convert_to_float64(self.dtype)) + @overload def __truediv__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @@ -1077,9 +1093,11 @@ def __truediv__(self, other: Any) -> Series | Expr: # this branch is exactly the floordiv function without rounding the floats if self.dtype.is_float() or self.dtype == Decimal: - return self._arithmetic(other, "div", "div_<>") + as_float = self + else: + as_float = self._recursive_cast_to_float64() - return self.cast(Float64) / other + return as_float._arithmetic(other, "div", "div_<>") @overload def __floordiv__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] From f205d4f90dca63212c6f4232d13884153ded114a Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 23 Jul 2024 13:06:31 -0400 Subject: [PATCH 03/30] More thorough testing of array math expressions --- .../operations/arithmetic/test_arithmetic.py | 48 +++++++++++-------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index 374d2965a029..24dfa97557db 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -1,7 +1,7 @@ import operator from collections import OrderedDict from datetime import date, datetime, timedelta -from typing import Any +from typing import Any, Callable import numpy as np import pytest @@ -558,33 +558,41 @@ def test_power_series() -> None: @pytest.mark.parametrize( - ("expected", "expr"), + ("expected", "expr", "column_names"), [ + (np.array([[2, 4], [6, 8]]), lambda a, b: a + b, ("a", "a")), + (np.array([[0, 0], [0, 0]]), lambda a, b: a - b, ("a", "a")), + (np.array([[1, 4], [9, 16]]), lambda a, b: a * b, ("a", "a")), + (np.array([[1.0, 1.0], [1.0, 1.0]]), lambda a, b: a / b, ("a", "a")), + (np.array([[0, 0], [0, 0]]), lambda a, b: a % b, ("a", "a")), ( - np.array([[2, 4], [6, 8]]), - pl.col("a") + pl.col("a"), - ), - ( - np.array([[0, 0], [0, 0]]), - pl.col("a") - pl.col("a"), - ), - ( - np.array([[1, 4], [9, 16]]), - pl.col("a") * pl.col("a"), - ), - ( - np.array([[1.0, 1.0], [1.0, 1.0]]), - pl.col("a") / pl.col("a"), + np.array([[3, 4], [7, 8]], dtype=np.int64), + lambda a, b: a + b, + ("a", "uint8"), ), ], ) -def test_array_arithmetic_same_size(expected: Any, expr: pl.Expr) -> None: - df = pl.Series("a", np.array([[1, 2], [3, 4]])).to_frame() - +def test_array_arithmetic_same_size( + expected: Any, + expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series], + column_names: tuple[str, str], +) -> None: + df = pl.DataFrame( + [ + pl.Series("a", np.array([[1, 2], [3, 4]], dtype=np.int64)), + pl.Series("uint8", np.array([[2, 2], [4, 4]], dtype=np.uint8)), + ] + ) + # Expr-based arithmetic: assert_frame_equal( - df.select(expr), + df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))), pl.Series("a", expected).to_frame(), ) + # Direct arithmetic on the Series: + assert_series_equal( + expr(df[column_names[0]], df[column_names[1]]), + pl.Series("a", expected), + ) def test_schema_owned_arithmetic_5669() -> None: From 53dd23391e816c7be7f866af4651e055931bb29c Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 23 Jul 2024 13:27:46 -0400 Subject: [PATCH 04/30] Another test, commented out --- .../unit/operations/arithmetic/test_arithmetic.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index 24dfa97557db..f7cf436ba019 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -570,6 +570,14 @@ def test_power_series() -> None: lambda a, b: a + b, ("a", "uint8"), ), + # This fails because the code is buggy, see + # https://github.com/pola-rs/polars/issues/17820 + # + # ( + # np.array([[[2, 4]], [[6, 8]]], dtype=np.int64), + # lambda a, b: a + b, + # ("nested", "nested"), + # ), ], ) def test_array_arithmetic_same_size( @@ -581,17 +589,19 @@ def test_array_arithmetic_same_size( [ pl.Series("a", np.array([[1, 2], [3, 4]], dtype=np.int64)), pl.Series("uint8", np.array([[2, 2], [4, 4]], dtype=np.uint8)), + pl.Series("nested", np.array([[[1, 2]], [[3, 4]]], dtype=np.int64)), ] ) + print(df.select(expr(pl.col(column_names[0]), pl.col(column_names[1])))) # Expr-based arithmetic: assert_frame_equal( df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))), - pl.Series("a", expected).to_frame(), + pl.Series(column_names[0], expected).to_frame(), ) # Direct arithmetic on the Series: assert_series_equal( expr(df[column_names[0]], df[column_names[1]]), - pl.Series("a", expected), + pl.Series(column_names[0], expected), ) From b58f7f48af02e2e6d54791960802a358d9c39184 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 23 Jul 2024 13:32:58 -0400 Subject: [PATCH 05/30] Success case test suite for list arithmetic --- .../operations/arithmetic/test_arithmetic.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index f7cf436ba019..aaf49694b75f 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -605,6 +605,51 @@ def test_array_arithmetic_same_size( ) +@pytest.mark.parametrize( + ("expected", "expr", "column_names"), + [ + ([[2, 4], [6]], lambda a, b: a + b, ("a", "a")), + ([[0, 0], [0]], lambda a, b: a - b, ("a", "a")), + ([[1, 4], [9]], lambda a, b: a * b, ("a", "a")), + ([[1.0, 1.0], [1.0]], lambda a, b: a / b, ("a", "a")), + ([[0, 0], [0]], lambda a, b: a % b, ("a", "a")), + ( + [[3, 4], [7]], + lambda a, b: a + b, + ("a", "uint8"), + ), + ( + [[[2, 4]], [[6]]], + lambda a, b: a + b, + ("nested", "nested"), + ), + ], +) +def test_list_arithmetic_same_size( + expected: Any, + expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series], + column_names: tuple[str, str], +) -> None: + print(expected) + df = pl.DataFrame( + [ + pl.Series("a", [[1, 2], [3]]), + pl.Series("uint8", [[2, 2], [4]]), + pl.Series("nested", [[[1, 2]], [[3]]]), + ] + ) + # Expr-based arithmetic: + assert_frame_equal( + df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))), + pl.Series(column_names[0], expected).to_frame(), + ) + # Direct arithmetic on the Series: + assert_series_equal( + expr(df[column_names[0]], df[column_names[1]]), + pl.Series(column_names[0], expected), + ) + + def test_schema_owned_arithmetic_5669() -> None: df = ( pl.LazyFrame({"A": [1, 2, 3]}) From 6e888ee6c7e9a4533d17e95b291f66e0fd030243 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 23 Jul 2024 13:43:40 -0400 Subject: [PATCH 06/30] Include reference to Rust code --- py-polars/polars/series/series.py | 90 ++++++++++++++++++++----------- 1 file changed, 60 insertions(+), 30 deletions(-) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index d87f2b268433..592ea898f1be 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -748,10 +748,12 @@ def _comp(self, other: Any, op: ComparisonOperator) -> Series: return self._from_pyseries(f(other)) @overload # type: ignore[override] - def __eq__(self, other: Expr) -> Expr: ... # type: ignore[overload-overlap] + def __eq__(self, other: Expr) -> Expr: + ... # type: ignore[overload-overlap] @overload - def __eq__(self, other: Any) -> Series: ... + def __eq__(self, other: Any) -> Series: + ... def __eq__(self, other: Any) -> Series | Expr: warn_null_comparison(other) @@ -764,7 +766,8 @@ def __ne__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __ne__(self, other: Any) -> Series: ... + def __ne__(self, other: Any) -> Series: + ... def __ne__(self, other: Any) -> Series | Expr: warn_null_comparison(other) @@ -777,7 +780,8 @@ def __gt__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __gt__(self, other: Any) -> Series: ... + def __gt__(self, other: Any) -> Series: + ... def __gt__(self, other: Any) -> Series | Expr: warn_null_comparison(other) @@ -790,7 +794,8 @@ def __lt__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __lt__(self, other: Any) -> Series: ... + def __lt__(self, other: Any) -> Series: + ... def __lt__(self, other: Any) -> Series | Expr: warn_null_comparison(other) @@ -803,7 +808,8 @@ def __ge__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __ge__(self, other: Any) -> Series: ... + def __ge__(self, other: Any) -> Series: + ... def __ge__(self, other: Any) -> Series | Expr: warn_null_comparison(other) @@ -816,7 +822,8 @@ def __le__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __le__(self, other: Any) -> Series: ... + def __le__(self, other: Any) -> Series: + ... def __le__(self, other: Any) -> Series | Expr: warn_null_comparison(other) @@ -829,7 +836,8 @@ def le(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def le(self, other: Any) -> Series: ... + def le(self, other: Any) -> Series: + ... def le(self, other: Any) -> Series | Expr: """Method equivalent of operator expression `series <= other`.""" @@ -840,7 +848,8 @@ def lt(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def lt(self, other: Any) -> Series: ... + def lt(self, other: Any) -> Series: + ... def lt(self, other: Any) -> Series | Expr: """Method equivalent of operator expression `series < other`.""" @@ -851,7 +860,8 @@ def eq(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def eq(self, other: Any) -> Series: ... + def eq(self, other: Any) -> Series: + ... def eq(self, other: Any) -> Series | Expr: """Method equivalent of operator expression `series == other`.""" @@ -862,7 +872,8 @@ def eq_missing(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def eq_missing(self, other: Any) -> Series: ... + def eq_missing(self, other: Any) -> Series: + ... def eq_missing(self, other: Any) -> Series | Expr: """ @@ -910,7 +921,8 @@ def ne(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def ne(self, other: Any) -> Series: ... + def ne(self, other: Any) -> Series: + ... def ne(self, other: Any) -> Series | Expr: """Method equivalent of operator expression `series != other`.""" @@ -921,7 +933,8 @@ def ne_missing(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def ne_missing(self, other: Any) -> Series: ... + def ne_missing(self, other: Any) -> Series: + ... def ne_missing(self, other: Any) -> Series | Expr: """ @@ -969,7 +982,8 @@ def ge(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def ge(self, other: Any) -> Series: ... + def ge(self, other: Any) -> Series: + ... def ge(self, other: Any) -> Series | Expr: """Method equivalent of operator expression `series >= other`.""" @@ -980,7 +994,8 @@ def gt(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def gt(self, other: Any) -> Series: ... + def gt(self, other: Any) -> Series: + ... def gt(self, other: Any) -> Series | Expr: """Method equivalent of operator expression `series > other`.""" @@ -1038,7 +1053,8 @@ def __add__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __add__(self, other: Any) -> Self: ... + def __add__(self, other: Any) -> Self: + ... def __add__(self, other: Any) -> Self | DataFrame | Expr: if isinstance(other, str): @@ -1054,7 +1070,8 @@ def __sub__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __sub__(self, other: Any) -> Self: ... + def __sub__(self, other: Any) -> Self: + ... def __sub__(self, other: Any) -> Self | Expr: if isinstance(other, pl.Expr): @@ -1065,6 +1082,8 @@ def _recursive_cast_to_float64(self) -> Series: """ Traverse dtype recursively, eventually converting leaf integer dtypes to Float64 dtypes. + + This is equivalent to logic in DataType::cast_leaf() in Rust. """ def convert_to_float64(dtype: DataType) -> DataType: @@ -1072,7 +1091,6 @@ def convert_to_float64(dtype: DataType) -> DataType: return Array(convert_to_float64(dtype.inner), shape=dtype.shape) if isinstance(dtype, List): return List(convert_to_float64(dtype.inner)) - # TODO are there other types to handle? Struct? return Float64 return self.cast(convert_to_float64(self.dtype)) @@ -1082,7 +1100,8 @@ def __truediv__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __truediv__(self, other: Any) -> Series: ... + def __truediv__(self, other: Any) -> Series: + ... def __truediv__(self, other: Any) -> Series | Expr: if isinstance(other, pl.Expr): @@ -1104,7 +1123,8 @@ def __floordiv__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __floordiv__(self, other: Any) -> Series: ... + def __floordiv__(self, other: Any) -> Series: + ... def __floordiv__(self, other: Any) -> Series | Expr: if isinstance(other, pl.Expr): @@ -1129,7 +1149,8 @@ def __mul__(self, other: DataFrame) -> DataFrame: # type: ignore[overload-overl ... @overload - def __mul__(self, other: Any) -> Series: ... + def __mul__(self, other: Any) -> Series: + ... def __mul__(self, other: Any) -> Series | DataFrame | Expr: if isinstance(other, pl.Expr): @@ -1147,7 +1168,8 @@ def __mod__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __mod__(self, other: Any) -> Series: ... + def __mod__(self, other: Any) -> Series: + ... def __mod__(self, other: Any) -> Series | Expr: if isinstance(other, pl.Expr): @@ -1250,10 +1272,12 @@ def __iter__(self) -> Generator[Any, None, None]: yield from self.slice(offset, buffer_size).to_list() @overload - def __getitem__(self, key: SingleIndexSelector) -> Any: ... + def __getitem__(self, key: SingleIndexSelector) -> Any: + ... @overload - def __getitem__(self, key: MultiIndexSelector) -> Series: ... + def __getitem__(self, key: MultiIndexSelector) -> Series: + ... def __getitem__( self, key: SingleIndexSelector | MultiIndexSelector @@ -1581,10 +1605,12 @@ def cbrt(self) -> Series: """ @overload - def any(self, *, ignore_nulls: Literal[True] = ...) -> bool: ... + def any(self, *, ignore_nulls: Literal[True] = ...) -> bool: + ... @overload - def any(self, *, ignore_nulls: bool) -> bool | None: ... + def any(self, *, ignore_nulls: bool) -> bool | None: + ... def any(self, *, ignore_nulls: bool = True) -> bool | None: """ @@ -1623,10 +1649,12 @@ def any(self, *, ignore_nulls: bool = True) -> bool | None: return self._s.any(ignore_nulls=ignore_nulls) @overload - def all(self, *, ignore_nulls: Literal[True] = ...) -> bool: ... + def all(self, *, ignore_nulls: Literal[True] = ...) -> bool: + ... @overload - def all(self, *, ignore_nulls: bool) -> bool | None: ... + def all(self, *, ignore_nulls: bool) -> bool | None: + ... def all(self, *, ignore_nulls: bool = True) -> bool | None: """ @@ -3368,14 +3396,16 @@ def arg_max(self) -> int | None: @overload def search_sorted( self, element: NonNestedLiteral | None, side: SearchSortedSide = ... - ) -> int: ... + ) -> int: + ... @overload def search_sorted( self, element: list[NonNestedLiteral | None] | np.ndarray[Any, Any] | Expr | Series, side: SearchSortedSide = ..., - ) -> Series: ... + ) -> Series: + ... def search_sorted( self, From 0ef28469462df46e58b96b9a93d4428261d37b23 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 23 Jul 2024 14:09:30 -0400 Subject: [PATCH 07/30] Support division in lists --- crates/polars-expr/src/expressions/binary.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/crates/polars-expr/src/expressions/binary.rs b/crates/polars-expr/src/expressions/binary.rs index 0d4634d6eeaf..c3df716d7e3d 100644 --- a/crates/polars-expr/src/expressions/binary.rs +++ b/crates/polars-expr/src/expressions/binary.rs @@ -72,6 +72,11 @@ pub fn apply_operator(left: &Series, right: &Series, op: Operator) -> PolarsResu let right_dt = right.dtype().cast_leaf(Float64); left.cast(&left_dt)? / right.cast(&right_dt)? }, + dt @ List(_) => { + let left_dt = dt.cast_leaf(Float64); + let right_dt = right.dtype().cast_leaf(Float64); + left.cast(&left_dt)? / right.cast(&right_dt)? + }, _ => { if right.dtype().is_temporal() { return left / right; From bffb77f621690e1daa42dab6e9cb6166aae31d3c Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 23 Jul 2024 14:14:59 -0400 Subject: [PATCH 08/30] Test error edge cases for List arithmetic --- .../operations/arithmetic/test_arithmetic.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index aaf49694b75f..e524015210c6 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -650,6 +650,26 @@ def test_list_arithmetic_same_size( ) +def test_list_arithmetic_error_cases(): + # Different series length: + with pytest.raises( + InvalidOperationError, match="Series of the same size; got 1 and 2" + ): + _ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1, 2], [3, 4]]) + + # Different list length: + # Different series length: + with pytest.raises( + InvalidOperationError, match="lists of the same size; got 2 and 1" + ): + _ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1]]) + + # Wrong types: + # Different series length: + with pytest.raises(InvalidOperationError, match="cannot cast List type"): + _ = pl.Series("a", [[1, 2]]) + pl.Series("b", ["hello"]) + + def test_schema_owned_arithmetic_5669() -> None: df = ( pl.LazyFrame({"A": [1, 2, 3]}) From 856a884150887fc8b1fa12a28bf4babbea0e20e5 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 23 Jul 2024 14:16:17 -0400 Subject: [PATCH 09/30] Run ruff to fix formatting --- py-polars/polars/series/series.py | 87 +++++++++++-------------------- 1 file changed, 29 insertions(+), 58 deletions(-) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 592ea898f1be..6532e056a562 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -748,12 +748,10 @@ def _comp(self, other: Any, op: ComparisonOperator) -> Series: return self._from_pyseries(f(other)) @overload # type: ignore[override] - def __eq__(self, other: Expr) -> Expr: - ... # type: ignore[overload-overlap] + def __eq__(self, other: Expr) -> Expr: ... # type: ignore[overload-overlap] @overload - def __eq__(self, other: Any) -> Series: - ... + def __eq__(self, other: Any) -> Series: ... def __eq__(self, other: Any) -> Series | Expr: warn_null_comparison(other) @@ -766,8 +764,7 @@ def __ne__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __ne__(self, other: Any) -> Series: - ... + def __ne__(self, other: Any) -> Series: ... def __ne__(self, other: Any) -> Series | Expr: warn_null_comparison(other) @@ -780,8 +777,7 @@ def __gt__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __gt__(self, other: Any) -> Series: - ... + def __gt__(self, other: Any) -> Series: ... def __gt__(self, other: Any) -> Series | Expr: warn_null_comparison(other) @@ -794,8 +790,7 @@ def __lt__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __lt__(self, other: Any) -> Series: - ... + def __lt__(self, other: Any) -> Series: ... def __lt__(self, other: Any) -> Series | Expr: warn_null_comparison(other) @@ -808,8 +803,7 @@ def __ge__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __ge__(self, other: Any) -> Series: - ... + def __ge__(self, other: Any) -> Series: ... def __ge__(self, other: Any) -> Series | Expr: warn_null_comparison(other) @@ -822,8 +816,7 @@ def __le__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __le__(self, other: Any) -> Series: - ... + def __le__(self, other: Any) -> Series: ... def __le__(self, other: Any) -> Series | Expr: warn_null_comparison(other) @@ -836,8 +829,7 @@ def le(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def le(self, other: Any) -> Series: - ... + def le(self, other: Any) -> Series: ... def le(self, other: Any) -> Series | Expr: """Method equivalent of operator expression `series <= other`.""" @@ -848,8 +840,7 @@ def lt(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def lt(self, other: Any) -> Series: - ... + def lt(self, other: Any) -> Series: ... def lt(self, other: Any) -> Series | Expr: """Method equivalent of operator expression `series < other`.""" @@ -860,8 +851,7 @@ def eq(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def eq(self, other: Any) -> Series: - ... + def eq(self, other: Any) -> Series: ... def eq(self, other: Any) -> Series | Expr: """Method equivalent of operator expression `series == other`.""" @@ -872,8 +862,7 @@ def eq_missing(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def eq_missing(self, other: Any) -> Series: - ... + def eq_missing(self, other: Any) -> Series: ... def eq_missing(self, other: Any) -> Series | Expr: """ @@ -921,8 +910,7 @@ def ne(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def ne(self, other: Any) -> Series: - ... + def ne(self, other: Any) -> Series: ... def ne(self, other: Any) -> Series | Expr: """Method equivalent of operator expression `series != other`.""" @@ -933,8 +921,7 @@ def ne_missing(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def ne_missing(self, other: Any) -> Series: - ... + def ne_missing(self, other: Any) -> Series: ... def ne_missing(self, other: Any) -> Series | Expr: """ @@ -982,8 +969,7 @@ def ge(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def ge(self, other: Any) -> Series: - ... + def ge(self, other: Any) -> Series: ... def ge(self, other: Any) -> Series | Expr: """Method equivalent of operator expression `series >= other`.""" @@ -994,8 +980,7 @@ def gt(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def gt(self, other: Any) -> Series: - ... + def gt(self, other: Any) -> Series: ... def gt(self, other: Any) -> Series | Expr: """Method equivalent of operator expression `series > other`.""" @@ -1053,8 +1038,7 @@ def __add__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __add__(self, other: Any) -> Self: - ... + def __add__(self, other: Any) -> Self: ... def __add__(self, other: Any) -> Self | DataFrame | Expr: if isinstance(other, str): @@ -1070,8 +1054,7 @@ def __sub__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __sub__(self, other: Any) -> Self: - ... + def __sub__(self, other: Any) -> Self: ... def __sub__(self, other: Any) -> Self | Expr: if isinstance(other, pl.Expr): @@ -1100,8 +1083,7 @@ def __truediv__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __truediv__(self, other: Any) -> Series: - ... + def __truediv__(self, other: Any) -> Series: ... def __truediv__(self, other: Any) -> Series | Expr: if isinstance(other, pl.Expr): @@ -1123,8 +1105,7 @@ def __floordiv__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __floordiv__(self, other: Any) -> Series: - ... + def __floordiv__(self, other: Any) -> Series: ... def __floordiv__(self, other: Any) -> Series | Expr: if isinstance(other, pl.Expr): @@ -1149,8 +1130,7 @@ def __mul__(self, other: DataFrame) -> DataFrame: # type: ignore[overload-overl ... @overload - def __mul__(self, other: Any) -> Series: - ... + def __mul__(self, other: Any) -> Series: ... def __mul__(self, other: Any) -> Series | DataFrame | Expr: if isinstance(other, pl.Expr): @@ -1168,8 +1148,7 @@ def __mod__(self, other: Expr) -> Expr: # type: ignore[overload-overlap] ... @overload - def __mod__(self, other: Any) -> Series: - ... + def __mod__(self, other: Any) -> Series: ... def __mod__(self, other: Any) -> Series | Expr: if isinstance(other, pl.Expr): @@ -1272,12 +1251,10 @@ def __iter__(self) -> Generator[Any, None, None]: yield from self.slice(offset, buffer_size).to_list() @overload - def __getitem__(self, key: SingleIndexSelector) -> Any: - ... + def __getitem__(self, key: SingleIndexSelector) -> Any: ... @overload - def __getitem__(self, key: MultiIndexSelector) -> Series: - ... + def __getitem__(self, key: MultiIndexSelector) -> Series: ... def __getitem__( self, key: SingleIndexSelector | MultiIndexSelector @@ -1605,12 +1582,10 @@ def cbrt(self) -> Series: """ @overload - def any(self, *, ignore_nulls: Literal[True] = ...) -> bool: - ... + def any(self, *, ignore_nulls: Literal[True] = ...) -> bool: ... @overload - def any(self, *, ignore_nulls: bool) -> bool | None: - ... + def any(self, *, ignore_nulls: bool) -> bool | None: ... def any(self, *, ignore_nulls: bool = True) -> bool | None: """ @@ -1649,12 +1624,10 @@ def any(self, *, ignore_nulls: bool = True) -> bool | None: return self._s.any(ignore_nulls=ignore_nulls) @overload - def all(self, *, ignore_nulls: Literal[True] = ...) -> bool: - ... + def all(self, *, ignore_nulls: Literal[True] = ...) -> bool: ... @overload - def all(self, *, ignore_nulls: bool) -> bool | None: - ... + def all(self, *, ignore_nulls: bool) -> bool | None: ... def all(self, *, ignore_nulls: bool = True) -> bool | None: """ @@ -3396,16 +3369,14 @@ def arg_max(self) -> int | None: @overload def search_sorted( self, element: NonNestedLiteral | None, side: SearchSortedSide = ... - ) -> int: - ... + ) -> int: ... @overload def search_sorted( self, element: list[NonNestedLiteral | None] | np.ndarray[Any, Any] | Expr | Series, side: SearchSortedSide = ..., - ) -> Series: - ... + ) -> Series: ... def search_sorted( self, From c830d33c302a5e47a7b584ccbdb4a1b1755cae0e Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 23 Jul 2024 14:22:52 -0400 Subject: [PATCH 10/30] Fix lints --- py-polars/polars/series/series.py | 7 +++---- .../tests/unit/operations/arithmetic/test_arithmetic.py | 4 +++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 6532e056a562..8c84a095820d 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -1063,18 +1063,17 @@ def __sub__(self, other: Any) -> Self | Expr: def _recursive_cast_to_float64(self) -> Series: """ - Traverse dtype recursively, eventually converting leaf integer dtypes - to Float64 dtypes. + Convert leaf dtypes to Float64 dtypes. This is equivalent to logic in DataType::cast_leaf() in Rust. """ - def convert_to_float64(dtype: DataType) -> DataType: + def convert_to_float64(dtype: PolarsDataType) -> PolarsDataType: if isinstance(dtype, Array): return Array(convert_to_float64(dtype.inner), shape=dtype.shape) if isinstance(dtype, List): return List(convert_to_float64(dtype.inner)) - return Float64 + return Float64() return self.cast(convert_to_float64(self.dtype)) diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index e524015210c6..33d85c2ab96b 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import operator from collections import OrderedDict from datetime import date, datetime, timedelta @@ -650,7 +652,7 @@ def test_list_arithmetic_same_size( ) -def test_list_arithmetic_error_cases(): +def test_list_arithmetic_error_cases() -> None: # Different series length: with pytest.raises( InvalidOperationError, match="Series of the same size; got 1 and 2" From 5692b62074985b4170986ad329954b14a7ff745c Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 23 Jul 2024 15:26:10 -0400 Subject: [PATCH 11/30] Fix lint --- crates/polars-core/src/series/arithmetic/borrowed.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index 7612859617c3..8d1b7b0dcb0f 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -189,8 +189,7 @@ impl ListChunked { let a = a_owner.as_ref(); let b = b_owner.as_ref(); polars_ensure!(a.len() == b.len(), InvalidOperation: "can only do arithmetic operations on lists of the same size; got {} and {}", a.len(), b.len()); - let result = op(a, b).and_then(|s| s.implode()).map(|ca|Series::from(ca)); - result + op(a, b).and_then(|s| s.implode()).map(Series::from) }); for c in combined.into_iter() { result.append(c?.list()?)?; From e194f9767faac9762fe73443fa5c37e6b2bc96fb Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 23 Jul 2024 16:40:51 -0400 Subject: [PATCH 12/30] Clean up --- crates/polars-core/src/series/arithmetic/borrowed.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index 8d1b7b0dcb0f..e056f6a5cd3e 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -179,7 +179,6 @@ impl ListChunked { ) -> PolarsResult { polars_ensure!(self.len() == rhs.len(), InvalidOperation: "can only do arithmetic operations on Series of the same size; got {} and {}", self.len(), rhs.len()); - // TODO ensure same dtype? let mut result = self.clear(); let combined = self.amortized_iter().zip(rhs.list()?.amortized_iter()).map(|(a, b)| { // We ensured the original Series are the same length, so we can From 0601f5ecb406c048989a3098d137ef94d3eb4d43 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 23 Jul 2024 16:41:53 -0400 Subject: [PATCH 13/30] Specify dtype explicitly --- .../unit/operations/arithmetic/test_arithmetic.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index 33d85c2ab96b..ba6b00ffcdc4 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -562,11 +562,15 @@ def test_power_series() -> None: @pytest.mark.parametrize( ("expected", "expr", "column_names"), [ - (np.array([[2, 4], [6, 8]]), lambda a, b: a + b, ("a", "a")), - (np.array([[0, 0], [0, 0]]), lambda a, b: a - b, ("a", "a")), - (np.array([[1, 4], [9, 16]]), lambda a, b: a * b, ("a", "a")), - (np.array([[1.0, 1.0], [1.0, 1.0]]), lambda a, b: a / b, ("a", "a")), - (np.array([[0, 0], [0, 0]]), lambda a, b: a % b, ("a", "a")), + (np.array([[2, 4], [6, 8]], dtype=np.int64), lambda a, b: a + b, ("a", "a")), + (np.array([[0, 0], [0, 0]], dtype=np.int64), lambda a, b: a - b, ("a", "a")), + (np.array([[1, 4], [9, 16]], dtype=np.int64), lambda a, b: a * b, ("a", "a")), + ( + np.array([[1.0, 1.0], [1.0, 1.0]], dtype=np.float64), + lambda a, b: a / b, + ("a", "a"), + ), + (np.array([[0, 0], [0, 0]], dtype=np.int64), lambda a, b: a % b, ("a", "a")), ( np.array([[3, 4], [7, 8]], dtype=np.int64), lambda a, b: a + b, From 2a86fbc1879170fa2d9950671c41c322f825e5a6 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Wed, 7 Aug 2024 11:23:13 -0400 Subject: [PATCH 14/30] Rewrite to operate directly on underlying data in one chunk. --- .../src/series/arithmetic/borrowed.rs | 52 +++++++++++++------ 1 file changed, 37 insertions(+), 15 deletions(-) diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index e056f6a5cd3e..0e4a9a2d22f5 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -1,3 +1,5 @@ +use arrow::array::Array; + use super::*; use crate::utils::align_chunks_binary; @@ -171,6 +173,22 @@ impl NumOpsDispatchInner for FixedSizeListType { } } +/// Given an ArrayRef with some primitive values, wrap it in list(s) until it +/// matches the requested shape. +fn reshape_based_on(data: &ArrayRef, shape: &ArrayRef) -> PolarsResult { + if let Some(list_chunk) = shape.as_any().downcast_ref::() { + let result = LargeListArray::try_new( + list_chunk.data_type().clone(), + list_chunk.offsets().clone(), + reshape_based_on(data, list_chunk.values())?, + list_chunk.validity().cloned(), + )?; + Ok(Box::new(result)) + } else { + Ok(data.clone()) + } +} + impl ListChunked { fn arithm_helper( &self, @@ -178,22 +196,26 @@ impl ListChunked { op: &dyn Fn(&Series, &Series) -> PolarsResult, ) -> PolarsResult { polars_ensure!(self.len() == rhs.len(), InvalidOperation: "can only do arithmetic operations on Series of the same size; got {} and {}", self.len(), rhs.len()); - - let mut result = self.clear(); - let combined = self.amortized_iter().zip(rhs.list()?.amortized_iter()).map(|(a, b)| { - // We ensured the original Series are the same length, so we can - // assume no None: - let a_owner = a.unwrap(); - let b_owner = b.unwrap(); - let a = a_owner.as_ref(); - let b = b_owner.as_ref(); - polars_ensure!(a.len() == b.len(), InvalidOperation: "can only do arithmetic operations on lists of the same size; got {} and {}", a.len(), b.len()); - op(a, b).and_then(|s| s.implode()).map(Series::from) - }); - for c in combined.into_iter() { - result.append(c?.list()?)?; + // TODO make sure the list shapes the same + let l_rechunked = self.rechunk().into_series(); + let l_leaf_array = l_rechunked.get_leaf_array(); + let r_leaf_array = rhs.rechunk().get_leaf_array(); + let result = op(&l_leaf_array, &r_leaf_array)?; + + // We now need to wrap the Arrow arrays with the metadata that turns + // them into lists: + // TODO is there a way to do this without cloning the underlying data? + let result_chunks = result.chunks(); + assert_eq!(result_chunks.len(), 1); + let left_chunk = &l_rechunked.chunks()[0]; + let result_chunk = reshape_based_on(&result_chunks[0], left_chunk)?; + + unsafe { + let mut result = + ListChunked::new_with_dims(self.field.clone(), vec![result_chunk], 0, 0); + result.compute_len(); + Ok(result.into()) } - Ok(result.into()) } } From f0eea119f9abb1e1b8918fb89c2c4a26bff0a439 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Thu, 8 Aug 2024 11:09:54 -0400 Subject: [PATCH 15/30] Handle nulls correctly --- .../src/series/arithmetic/borrowed.rs | 23 +++++++++++++++++++ .../operations/arithmetic/test_arithmetic.py | 16 ++++++++++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index 0e4a9a2d22f5..387bf6390eaa 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -197,9 +197,32 @@ impl ListChunked { ) -> PolarsResult { polars_ensure!(self.len() == rhs.len(), InvalidOperation: "can only do arithmetic operations on Series of the same size; got {} and {}", self.len(), rhs.len()); // TODO make sure the list shapes the same + + if self.null_count() > 0 || rhs.null_count() > 0 { + // A slower implementation since we can't just add the underlying + // values Arrow arrays. Given nulls, the two values arrays might not + // line up the way we expect. + let mut result = self.clear(); + let combined = self.amortized_iter().zip(rhs.list()?.amortized_iter()).map(|(a, b)| { + let (Some(a_owner), Some(b_owner)) = (a, b) else { + // Operations with nulls always result in nulls: + return Ok(Series::full_null(self.name(), 1, self.dtype())); + }; + let a = a_owner.as_ref(); + let b = b_owner.as_ref(); + polars_ensure!(a.len() == b.len(), InvalidOperation: "can only do arithmetic operations on lists of the same size; got {} and {}", a.len(), b.len()); + op(a, b).and_then(|s| s.implode()).map(Series::from) + }); + for c in combined.into_iter() { + result.append(c?.list()?)?; + } + return Ok(result.into()); + } + let l_rechunked = self.rechunk().into_series(); let l_leaf_array = l_rechunked.get_leaf_array(); let r_leaf_array = rhs.rechunk().get_leaf_array(); + let result = op(&l_leaf_array, &r_leaf_array)?; // We now need to wrap the Arrow arrays with the metadata that turns diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index ba6b00ffcdc4..0d5bd9da7be8 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -636,7 +636,6 @@ def test_list_arithmetic_same_size( expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series], column_names: tuple[str, str], ) -> None: - print(expected) df = pl.DataFrame( [ pl.Series("a", [[1, 2], [3]]), @@ -656,6 +655,21 @@ def test_list_arithmetic_same_size( ) +@pytest.mark.parametrize( + ("a", "b", "expected"), + [ + ([[2, 3]], [[None, 5]], [[None, 8]]), + ([[2], None, [5]], [None, [3], [2]], [None, None, [7]]), + ([[[2]], [None]], [[[3]], [[6]]], [[[5]], [None]]), + ], +) +def test_list_arithmetic_nulls(a: list[Any], b: list[Any], expected: list[Any]) -> None: + assert_series_equal( + pl.Series(a) + pl.Series(b), + pl.Series(expected), + ) + + def test_list_arithmetic_error_cases() -> None: # Different series length: with pytest.raises( From 7ba7fd6765daac25b2e093c72a4487b6effd3495 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Thu, 8 Aug 2024 13:44:34 -0400 Subject: [PATCH 16/30] WIP improvements to null handling. --- .../src/series/arithmetic/borrowed.rs | 61 ++++++++++++++++--- .../operations/arithmetic/test_arithmetic.py | 18 ++++-- 2 files changed, 66 insertions(+), 13 deletions(-) diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index 387bf6390eaa..86474ee43782 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -175,12 +175,12 @@ impl NumOpsDispatchInner for FixedSizeListType { /// Given an ArrayRef with some primitive values, wrap it in list(s) until it /// matches the requested shape. -fn reshape_based_on(data: &ArrayRef, shape: &ArrayRef) -> PolarsResult { +fn reshape_list_based_on(data: &ArrayRef, shape: &ArrayRef) -> PolarsResult { if let Some(list_chunk) = shape.as_any().downcast_ref::() { let result = LargeListArray::try_new( list_chunk.data_type().clone(), list_chunk.offsets().clone(), - reshape_based_on(data, list_chunk.values())?, + reshape_list_based_on(data, list_chunk.values())?, list_chunk.validity().cloned(), )?; Ok(Box::new(result)) @@ -189,6 +189,32 @@ fn reshape_based_on(data: &ArrayRef, shape: &ArrayRef) -> PolarsResult } } +/// Given an ArrayRef, return true if it's a LargeListArrays and it has one or +/// more nulls. +fn maybe_list_has_nulls(data: &ArrayRef) -> bool { + if let Some(list_chunk) = data.as_any().downcast_ref::() { + println!("BITMAP UNSET BITS {:?}", list_chunk + .validity() + .map(|bitmap| bitmap.unset_bits())); + if list_chunk + .validity() + .map(|bitmap| bitmap.unset_bits() > 0) + .unwrap_or(false) + { + true + } else { + maybe_list_has_nulls(list_chunk.values()) + } + } else { + false + } +} + +// fn same_shapes(left: &ArrayRef, right: &ArrayRef) -> bool { +// let left_as_list = left.as_any().downcast_ref::(); +// let right_as_list = right.as_any().downcast_ref::(); +// match +// } impl ListChunked { fn arithm_helper( &self, @@ -196,9 +222,23 @@ impl ListChunked { op: &dyn Fn(&Series, &Series) -> PolarsResult, ) -> PolarsResult { polars_ensure!(self.len() == rhs.len(), InvalidOperation: "can only do arithmetic operations on Series of the same size; got {} and {}", self.len(), rhs.len()); - // TODO make sure the list shapes the same - if self.null_count() > 0 || rhs.null_count() > 0 { + let mut has_nulls = false; + for chunk in self.chunks().iter() { + if maybe_list_has_nulls(chunk) { + has_nulls = true; + break; + } + } + if !has_nulls { + for chunk in rhs.chunks().iter() { + if maybe_list_has_nulls(chunk) { + has_nulls = true; + break; + } + } + } + if has_nulls { // A slower implementation since we can't just add the underlying // values Arrow arrays. Given nulls, the two values arrays might not // line up the way we expect. @@ -211,6 +251,7 @@ impl ListChunked { let a = a_owner.as_ref(); let b = b_owner.as_ref(); polars_ensure!(a.len() == b.len(), InvalidOperation: "can only do arithmetic operations on lists of the same size; got {} and {}", a.len(), b.len()); + println!("SLOW {a:?} {b:?}"); op(a, b).and_then(|s| s.implode()).map(Series::from) }); for c in combined.into_iter() { @@ -218,10 +259,12 @@ impl ListChunked { } return Ok(result.into()); } - - let l_rechunked = self.rechunk().into_series(); - let l_leaf_array = l_rechunked.get_leaf_array(); - let r_leaf_array = rhs.rechunk().get_leaf_array(); + println!("FAST INPUTS {self:?} {rhs:?}"); + let l_rechunked = self.clone().rechunk().into_series(); + println!("FAST L_RECHUNKED {l_rechunked:?}"); + let l_leaf_array = l_rechunked.explode()?; + let r_leaf_array = rhs.rechunk().explode()?; + println!("FAST LEAF ARRAY {l_leaf_array:?} {r_leaf_array:?}"); let result = op(&l_leaf_array, &r_leaf_array)?; @@ -231,7 +274,7 @@ impl ListChunked { let result_chunks = result.chunks(); assert_eq!(result_chunks.len(), 1); let left_chunk = &l_rechunked.chunks()[0]; - let result_chunk = reshape_based_on(&result_chunks[0], left_chunk)?; + let result_chunk = reshape_list_based_on(&result_chunks[0], left_chunk)?; unsafe { let mut result = diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index 0d5bd9da7be8..bb04b64fd773 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -658,9 +658,9 @@ def test_list_arithmetic_same_size( @pytest.mark.parametrize( ("a", "b", "expected"), [ - ([[2, 3]], [[None, 5]], [[None, 8]]), + ([[1, 2, 3]], [[1, None, 5]], [[2, None, 8]]), ([[2], None, [5]], [None, [3], [2]], [None, None, [7]]), - ([[[2]], [None]], [[[3]], [[6]]], [[[5]], [None]]), + ([[[2]], [None], [[4]]], [[[3]], [[6]], [[8]]], [[[5]], [None], [[12]]]), ], ) def test_list_arithmetic_nulls(a: list[Any], b: list[Any], expected: list[Any]) -> None: @@ -676,19 +676,29 @@ def test_list_arithmetic_error_cases() -> None: InvalidOperationError, match="Series of the same size; got 1 and 2" ): _ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1, 2], [3, 4]]) + with pytest.raises( + InvalidOperationError, match="Series of the same size; got 1 and 2" + ): + _ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1, 2], None]) # Different list length: - # Different series length: with pytest.raises( InvalidOperationError, match="lists of the same size; got 2 and 1" ): _ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1]]) + with pytest.raises( + InvalidOperationError, match="lists of the same size; got 2 and 1" + ): + _ = pl.Series("a", [[1, 2], [2, 3]]) / pl.Series("b", [[1], None]) # Wrong types: - # Different series length: with pytest.raises(InvalidOperationError, match="cannot cast List type"): _ = pl.Series("a", [[1, 2]]) + pl.Series("b", ["hello"]) + # Different nesting: + with pytest.raises(InvalidOperationError, match="TODO"): + _ = pl.Series("a", [[1]]) / pl.Series("b", [[[1]]]) + def test_schema_owned_arithmetic_5669() -> None: df = ( From 00ba975ee5a1fd2df2ac39c8befbf0d73e858aa4 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Thu, 8 Aug 2024 14:11:50 -0400 Subject: [PATCH 17/30] Null handling now appears to work with latest tests. --- Cargo.lock | 1 + Cargo.toml | 1 + crates/polars-core/Cargo.toml | 1 + .../src/series/arithmetic/borrowed.rs | 40 ++++++++++++------- 4 files changed, 29 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bcc3fe133d87..4d85a199a8f7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3034,6 +3034,7 @@ dependencies = [ "rand_distr", "rayon", "regex", + "scopeguard", "serde", "serde_json", "smartstring", diff --git a/Cargo.toml b/Cargo.toml index 2372919640cc..4a6716efe131 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -74,6 +74,7 @@ recursive = "0.1" regex = "1.9" reqwest = { version = "0.12", default-features = false } ryu = "1.0.13" +scopeguard = "1.2.0" serde = { version = "1.0.188", features = ["derive"] } serde_json = "1" simd-json = { version = "0.13", features = ["known-key"] } diff --git a/crates/polars-core/Cargo.toml b/crates/polars-core/Cargo.toml index 882392f080cf..3fa4152dd877 100644 --- a/crates/polars-core/Cargo.toml +++ b/crates/polars-core/Cargo.toml @@ -32,6 +32,7 @@ rand = { workspace = true, optional = true, features = ["small_rng", "std"] } rand_distr = { workspace = true, optional = true } rayon = { workspace = true } regex = { workspace = true, optional = true } +scopeguard = { workspace = true } # activate if you want serde support for Series and DataFrames serde = { workspace = true, optional = true } serde_json = { workspace = true, optional = true } diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index 86474ee43782..d696e1949702 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -1,3 +1,5 @@ +use std::cell::Cell; + use arrow::array::Array; use super::*; @@ -193,9 +195,6 @@ fn reshape_list_based_on(data: &ArrayRef, shape: &ArrayRef) -> PolarsResult bool { if let Some(list_chunk) = data.as_any().downcast_ref::() { - println!("BITMAP UNSET BITS {:?}", list_chunk - .validity() - .map(|bitmap| bitmap.unset_bits())); if list_chunk .validity() .map(|bitmap| bitmap.unset_bits() > 0) @@ -215,6 +214,7 @@ fn maybe_list_has_nulls(data: &ArrayRef) -> bool { // let right_as_list = right.as_any().downcast_ref::(); // match // } + impl ListChunked { fn arithm_helper( &self, @@ -223,11 +223,19 @@ impl ListChunked { ) -> PolarsResult { polars_ensure!(self.len() == rhs.len(), InvalidOperation: "can only do arithmetic operations on Series of the same size; got {} and {}", self.len(), rhs.len()); - let mut has_nulls = false; - for chunk in self.chunks().iter() { - if maybe_list_has_nulls(chunk) { - has_nulls = true; - break; + thread_local!{ + static HAS_NULLS: Cell = const { Cell::new(false) }; + }; + + let orig_has_nulls = HAS_NULLS.get(); + + let mut has_nulls = orig_has_nulls; + if !has_nulls { + for chunk in self.chunks().iter() { + if maybe_list_has_nulls(chunk) { + has_nulls = true; + break; + } } } if !has_nulls { @@ -242,6 +250,14 @@ impl ListChunked { // A slower implementation since we can't just add the underlying // values Arrow arrays. Given nulls, the two values arrays might not // line up the way we expect. + + // This can be recursive, so preserve the knowledge that there + // were nulls. Unfortunately get_leaf_array() and explode() + // don't work on the Series that come out of amortized_iter(), + // so we need to stick to this code path. + HAS_NULLS.set(true); + scopeguard::defer!{ HAS_NULLS.set(orig_has_nulls); }; + let mut result = self.clear(); let combined = self.amortized_iter().zip(rhs.list()?.amortized_iter()).map(|(a, b)| { let (Some(a_owner), Some(b_owner)) = (a, b) else { @@ -251,7 +267,6 @@ impl ListChunked { let a = a_owner.as_ref(); let b = b_owner.as_ref(); polars_ensure!(a.len() == b.len(), InvalidOperation: "can only do arithmetic operations on lists of the same size; got {} and {}", a.len(), b.len()); - println!("SLOW {a:?} {b:?}"); op(a, b).and_then(|s| s.implode()).map(Series::from) }); for c in combined.into_iter() { @@ -259,12 +274,9 @@ impl ListChunked { } return Ok(result.into()); } - println!("FAST INPUTS {self:?} {rhs:?}"); let l_rechunked = self.clone().rechunk().into_series(); - println!("FAST L_RECHUNKED {l_rechunked:?}"); - let l_leaf_array = l_rechunked.explode()?; - let r_leaf_array = rhs.rechunk().explode()?; - println!("FAST LEAF ARRAY {l_leaf_array:?} {r_leaf_array:?}"); + let l_leaf_array = l_rechunked.get_leaf_array(); + let r_leaf_array = rhs.rechunk().get_leaf_array(); let result = op(&l_leaf_array, &r_leaf_array)?; From 254b37ee1e3be3297a328eee94f0bbf33dbeebd8 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Thu, 8 Aug 2024 14:26:06 -0400 Subject: [PATCH 18/30] All tests pass. --- .../src/series/arithmetic/borrowed.rs | 40 ++++++++++++++----- .../operations/arithmetic/test_arithmetic.py | 10 ++--- 2 files changed, 35 insertions(+), 15 deletions(-) diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index d696e1949702..f6f80107b081 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -209,11 +209,19 @@ fn maybe_list_has_nulls(data: &ArrayRef) -> bool { } } -// fn same_shapes(left: &ArrayRef, right: &ArrayRef) -> bool { -// let left_as_list = left.as_any().downcast_ref::(); -// let right_as_list = right.as_any().downcast_ref::(); -// match -// } +/// Return whether the left and right have the same shape. We assume neither has +/// any nulls, recursively. +fn lists_same_shapes(left: &ArrayRef, right: &ArrayRef) -> bool { + let left_as_list = left.as_any().downcast_ref::(); + let right_as_list = right.as_any().downcast_ref::(); + match (left_as_list, right_as_list) { + (Some(left), Some(right)) => { + left.offsets() == right.offsets() && lists_same_shapes(left.values(), right.values()) + }, + (None, None) => left.len() == right.len(), + _ => false, + } +} impl ListChunked { fn arithm_helper( @@ -221,9 +229,14 @@ impl ListChunked { rhs: &Series, op: &dyn Fn(&Series, &Series) -> PolarsResult, ) -> PolarsResult { - polars_ensure!(self.len() == rhs.len(), InvalidOperation: "can only do arithmetic operations on Series of the same size; got {} and {}", self.len(), rhs.len()); + polars_ensure!( + self.len() == rhs.len(), + InvalidOperation: "can only do arithmetic operations on Series of the same size; got {} and {}", + self.len(), + rhs.len() + ); - thread_local!{ + thread_local! { static HAS_NULLS: Cell = const { Cell::new(false) }; }; @@ -256,7 +269,7 @@ impl ListChunked { // don't work on the Series that come out of amortized_iter(), // so we need to stick to this code path. HAS_NULLS.set(true); - scopeguard::defer!{ HAS_NULLS.set(orig_has_nulls); }; + scopeguard::defer! { HAS_NULLS.set(orig_has_nulls); }; let mut result = self.clear(); let combined = self.amortized_iter().zip(rhs.list()?.amortized_iter()).map(|(a, b)| { @@ -266,7 +279,12 @@ impl ListChunked { }; let a = a_owner.as_ref(); let b = b_owner.as_ref(); - polars_ensure!(a.len() == b.len(), InvalidOperation: "can only do arithmetic operations on lists of the same size; got {} and {}", a.len(), b.len()); + polars_ensure!( + a.len() == b.len(), + InvalidOperation: "can only do arithmetic operations on lists of the same size; got {} and {}", + a.len(), + b.len() + ); op(a, b).and_then(|s| s.implode()).map(Series::from) }); for c in combined.into_iter() { @@ -277,6 +295,10 @@ impl ListChunked { let l_rechunked = self.clone().rechunk().into_series(); let l_leaf_array = l_rechunked.get_leaf_array(); let r_leaf_array = rhs.rechunk().get_leaf_array(); + polars_ensure!( + lists_same_shapes(&l_leaf_array.chunks()[0], &r_leaf_array.chunks()[0]), + InvalidOperation: "can only do arithmetic operations on lists of the same size" + ); let result = op(&l_leaf_array, &r_leaf_array)?; diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index bb04b64fd773..5c3fdc7e4469 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -21,7 +21,7 @@ UInt32, UInt64, ) -from polars.exceptions import ColumnNotFoundError, InvalidOperationError +from polars.exceptions import ColumnNotFoundError, InvalidOperationError, SchemaError from polars.testing import assert_frame_equal, assert_series_equal from tests.unit.conftest import INTEGER_DTYPES, NUMERIC_DTYPES @@ -682,9 +682,7 @@ def test_list_arithmetic_error_cases() -> None: _ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1, 2], None]) # Different list length: - with pytest.raises( - InvalidOperationError, match="lists of the same size; got 2 and 1" - ): + with pytest.raises(InvalidOperationError, match="lists of the same size"): _ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1]]) with pytest.raises( InvalidOperationError, match="lists of the same size; got 2 and 1" @@ -696,8 +694,8 @@ def test_list_arithmetic_error_cases() -> None: _ = pl.Series("a", [[1, 2]]) + pl.Series("b", ["hello"]) # Different nesting: - with pytest.raises(InvalidOperationError, match="TODO"): - _ = pl.Series("a", [[1]]) / pl.Series("b", [[[1]]]) + with pytest.raises(SchemaError, match="failed to determine supertype"): + _ = pl.Series("a", [[1]]) + pl.Series("b", [[[1]]]) def test_schema_owned_arithmetic_5669() -> None: From cf4fa30388ca71e6464e65bc6f90ccf846a55288 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Wed, 11 Sep 2024 14:29:05 -0400 Subject: [PATCH 19/30] Update to compile with latest code. --- crates/polars-core/src/series/arithmetic/borrowed.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index 65cce0e9450d..118ef6eaa934 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -180,7 +180,7 @@ impl NumOpsDispatchInner for FixedSizeListType { fn reshape_list_based_on(data: &ArrayRef, shape: &ArrayRef) -> PolarsResult { if let Some(list_chunk) = shape.as_any().downcast_ref::() { let result = LargeListArray::try_new( - list_chunk.data_type().clone(), + list_chunk.dtype().clone(), list_chunk.offsets().clone(), reshape_list_based_on(data, list_chunk.values())?, list_chunk.validity().cloned(), @@ -275,7 +275,7 @@ impl ListChunked { let combined = self.amortized_iter().zip(rhs.list()?.amortized_iter()).map(|(a, b)| { let (Some(a_owner), Some(b_owner)) = (a, b) else { // Operations with nulls always result in nulls: - return Ok(Series::full_null(self.name(), 1, self.dtype())); + return Ok(Series::full_null(self.name().clone(), 1, self.dtype())); }; let a = a_owner.as_ref(); let b = b_owner.as_ref(); From 03cddddfffb640afe1eac75868d4a4da554d8f32 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Thu, 12 Sep 2024 11:10:38 -0400 Subject: [PATCH 20/30] Get rid of thread local, expand testing slightly. --- .../src/series/arithmetic/borrowed.rs | 42 +++++++++---------- py-polars/polars/series/series.py | 16 +++---- .../operations/arithmetic/test_arithmetic.py | 13 +++++- 3 files changed, 38 insertions(+), 33 deletions(-) diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index 118ef6eaa934..9904a440fa18 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -1,5 +1,3 @@ -use std::cell::Cell; - use arrow::array::Array; use super::*; @@ -228,6 +226,7 @@ impl ListChunked { &self, rhs: &Series, op: &dyn Fn(&Series, &Series) -> PolarsResult, + has_nulls: Option, ) -> PolarsResult { polars_ensure!( self.len() == rhs.len(), @@ -236,13 +235,7 @@ impl ListChunked { rhs.len() ); - thread_local! { - static HAS_NULLS: Cell = const { Cell::new(false) }; - }; - - let orig_has_nulls = HAS_NULLS.get(); - - let mut has_nulls = orig_has_nulls; + let mut has_nulls = has_nulls.unwrap_or(false); if !has_nulls { for chunk in self.chunks().iter() { if maybe_list_has_nulls(chunk) { @@ -263,14 +256,6 @@ impl ListChunked { // A slower implementation since we can't just add the underlying // values Arrow arrays. Given nulls, the two values arrays might not // line up the way we expect. - - // This can be recursive, so preserve the knowledge that there - // were nulls. Unfortunately get_leaf_array() and explode() - // don't work on the Series that come out of amortized_iter(), - // so we need to stick to this code path. - HAS_NULLS.set(true); - scopeguard::defer! { HAS_NULLS.set(orig_has_nulls); }; - let mut result = self.clear(); let combined = self.amortized_iter().zip(rhs.list()?.amortized_iter()).map(|(a, b)| { let (Some(a_owner), Some(b_owner)) = (a, b) else { @@ -285,7 +270,18 @@ impl ListChunked { a.len(), b.len() ); - op(a, b).and_then(|s| s.implode()).map(Series::from) + let chunk_result = if let Ok(a_listchunked) = a.list() { + // If `a` contains more lists, we're going to reach this + // function recursively, and again have to decide whether to + // use the fast path (no nulls) or slow path (there were + // nulls). Since we know there were nulls, that means we + // have to stick to the slow path, so pass that information + // along. + a_listchunked.arithm_helper(b, op, Some(true)) + } else { + op(a, b) + }; + chunk_result.and_then(|s| s.implode()).map(Series::from) }); for c in combined.into_iter() { result.append(c?.list()?)?; @@ -321,19 +317,19 @@ impl ListChunked { impl NumOpsDispatchInner for ListType { fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.add_to(r)) + lhs.arithm_helper(rhs, &|l, r| l.add_to(r), None) } fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.subtract(r)) + lhs.arithm_helper(rhs, &|l, r| l.subtract(r), None) } fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.multiply(r)) + lhs.arithm_helper(rhs, &|l, r| l.multiply(r), None) } fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.divide(r)) + lhs.arithm_helper(rhs, &|l, r| l.divide(r), None) } fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.remainder(r)) + lhs.arithm_helper(rhs, &|l, r| l.remainder(r), None) } } diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index df0f4ec469bf..1966b1b78f1f 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -1058,21 +1058,21 @@ def __sub__(self, other: Any) -> Self | Expr: return F.lit(self) - other return self._arithmetic(other, "sub", "sub_<>") - def _recursive_cast_to_float64(self) -> Series: + def _recursive_cast_to_dtype(self, leaf_dtype: PolarsDataType) -> Series: """ - Convert leaf dtypes to Float64 dtypes. + Convert leaf dtype the to given primitive datatype. This is equivalent to logic in DataType::cast_leaf() in Rust. """ - def convert_to_float64(dtype: PolarsDataType) -> PolarsDataType: + def convert_to_primitive(dtype: PolarsDataType) -> PolarsDataType: if isinstance(dtype, Array): - return Array(convert_to_float64(dtype.inner), shape=dtype.shape) + return Array(convert_to_primitive(dtype.inner), shape=dtype.shape) if isinstance(dtype, List): - return List(convert_to_float64(dtype.inner)) - return Float64() + return List(convert_to_primitive(dtype.inner)) + return leaf_dtype - return self.cast(convert_to_float64(self.dtype)) + return self.cast(convert_to_primitive(self.dtype)) @overload def __truediv__(self, other: Expr) -> Expr: ... @@ -1091,7 +1091,7 @@ def __truediv__(self, other: Any) -> Series | Expr: if self.dtype.is_float() or self.dtype == Decimal: as_float = self else: - as_float = self._recursive_cast_to_float64() + as_float = self._recursive_cast_to_dtype(Float64()) return as_float._arithmetic(other, "div", "div_<>") diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index 5c3fdc7e4469..2f4b4673885f 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -664,9 +664,18 @@ def test_list_arithmetic_same_size( ], ) def test_list_arithmetic_nulls(a: list[Any], b: list[Any], expected: list[Any]) -> None: + series_a = pl.Series(a) + series_b = pl.Series(b) + series_expected = pl.Series(expected) + + # Same dtype: + assert_series_equal(series_a + series_b, series_expected) + + # Different dtype: assert_series_equal( - pl.Series(a) + pl.Series(b), - pl.Series(expected), + series_a._recursive_cast_to_dtype(pl.Int32()) + + series_b._recursive_cast_to_dtype(pl.Int64()), + series_expected._recursive_cast_to_dtype(pl.Int64()), ) From 19650abe0cf04804525578d7d3ffa72569af9c67 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Thu, 12 Sep 2024 11:12:11 -0400 Subject: [PATCH 21/30] Drop scopeguard as explicit dependency. --- Cargo.lock | 1 - Cargo.toml | 1 - crates/polars-core/Cargo.toml | 1 - 3 files changed, 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5c15fbf6d1fd..bd7cad4040f1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3093,7 +3093,6 @@ dependencies = [ "rand_distr", "rayon", "regex", - "scopeguard", "serde", "serde_json", "thiserror", diff --git a/Cargo.toml b/Cargo.toml index 01943f51c029..11952df77330 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -76,7 +76,6 @@ recursive = "0.1" regex = "1.9" reqwest = { version = "0.12", default-features = false } ryu = "1.0.13" -scopeguard = "1.2.0" serde = { version = "1.0.188", features = ["derive", "rc"] } serde_json = "1" simd-json = { version = "0.13", features = ["known-key"] } diff --git a/crates/polars-core/Cargo.toml b/crates/polars-core/Cargo.toml index f6309d07d8c3..8ce74ced160c 100644 --- a/crates/polars-core/Cargo.toml +++ b/crates/polars-core/Cargo.toml @@ -33,7 +33,6 @@ rand = { workspace = true, optional = true, features = ["small_rng", "std"] } rand_distr = { workspace = true, optional = true } rayon = { workspace = true } regex = { workspace = true, optional = true } -scopeguard = { workspace = true } # activate if you want serde support for Series and DataFrames serde = { workspace = true, optional = true } serde_json = { workspace = true, optional = true } From 4972d0687e06d10f07a4420f909139da492aed9a Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Thu, 12 Sep 2024 11:27:12 -0400 Subject: [PATCH 22/30] Simplify by getting rid of intermediate Series. --- crates/polars-core/src/series/arithmetic/borrowed.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index 9904a440fa18..39b560311d84 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -260,7 +260,8 @@ impl ListChunked { let combined = self.amortized_iter().zip(rhs.list()?.amortized_iter()).map(|(a, b)| { let (Some(a_owner), Some(b_owner)) = (a, b) else { // Operations with nulls always result in nulls: - return Ok(Series::full_null(self.name().clone(), 1, self.dtype())); + let inner_dtype = self.dtype().inner_dtype().unwrap(); + return Ok(ListChunked::full_null_with_dtype(self.name().clone(), 1, inner_dtype)); }; let a = a_owner.as_ref(); let b = b_owner.as_ref(); @@ -281,10 +282,10 @@ impl ListChunked { } else { op(a, b) }; - chunk_result.and_then(|s| s.implode()).map(Series::from) + chunk_result.and_then(|s| s.implode()) }); for c in combined.into_iter() { - result.append(c?.list()?)?; + result.append(&c?)?; } return Ok(result.into()); } From ee74063b5c59d8f19251015d70a012db95def224 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Mon, 16 Sep 2024 13:03:58 -0400 Subject: [PATCH 23/30] Simpler signature, better name. --- .../src/series/arithmetic/borrowed.rs | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index 39b560311d84..aef7002f1af9 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -175,23 +175,23 @@ impl NumOpsDispatchInner for FixedSizeListType { /// Given an ArrayRef with some primitive values, wrap it in list(s) until it /// matches the requested shape. -fn reshape_list_based_on(data: &ArrayRef, shape: &ArrayRef) -> PolarsResult { +fn reshape_list_based_on(data: &ArrayRef, shape: &ArrayRef) -> ArrayRef { if let Some(list_chunk) = shape.as_any().downcast_ref::() { - let result = LargeListArray::try_new( + let result = LargeListArray::new( list_chunk.dtype().clone(), list_chunk.offsets().clone(), - reshape_list_based_on(data, list_chunk.values())?, + reshape_list_based_on(data, list_chunk.values()), list_chunk.validity().cloned(), - )?; - Ok(Box::new(result)) + ); + Box::new(result) } else { - Ok(data.clone()) + data.clone() } } /// Given an ArrayRef, return true if it's a LargeListArrays and it has one or /// more nulls. -fn maybe_list_has_nulls(data: &ArrayRef) -> bool { +fn does_list_have_nulls(data: &ArrayRef) -> bool { if let Some(list_chunk) = data.as_any().downcast_ref::() { if list_chunk .validity() @@ -200,7 +200,7 @@ fn maybe_list_has_nulls(data: &ArrayRef) -> bool { { true } else { - maybe_list_has_nulls(list_chunk.values()) + does_list_have_nulls(list_chunk.values()) } } else { false @@ -238,7 +238,7 @@ impl ListChunked { let mut has_nulls = has_nulls.unwrap_or(false); if !has_nulls { for chunk in self.chunks().iter() { - if maybe_list_has_nulls(chunk) { + if does_list_have_nulls(chunk) { has_nulls = true; break; } @@ -246,7 +246,7 @@ impl ListChunked { } if !has_nulls { for chunk in rhs.chunks().iter() { - if maybe_list_has_nulls(chunk) { + if does_list_have_nulls(chunk) { has_nulls = true; break; } @@ -305,7 +305,7 @@ impl ListChunked { let result_chunks = result.chunks(); assert_eq!(result_chunks.len(), 1); let left_chunk = &l_rechunked.chunks()[0]; - let result_chunk = reshape_list_based_on(&result_chunks[0], left_chunk)?; + let result_chunk = reshape_list_based_on(&result_chunks[0], left_chunk); unsafe { let mut result = From b27b7ff96064bb0b0cc5dacc6cebddaf1f477b14 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Mon, 16 Sep 2024 13:44:29 -0400 Subject: [PATCH 24/30] Use an AnonymousListBuilder. --- .../src/series/arithmetic/borrowed.rs | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index aef7002f1af9..327ed6197be3 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -1,6 +1,7 @@ use arrow::array::Array; use super::*; +use crate::chunked_array::builder::AnonymousListBuilder; use crate::utils::align_chunks_binary; pub trait NumOpsDispatchInner: PolarsDataType + Sized { @@ -210,6 +211,8 @@ fn does_list_have_nulls(data: &ArrayRef) -> bool { /// Return whether the left and right have the same shape. We assume neither has /// any nulls, recursively. fn lists_same_shapes(left: &ArrayRef, right: &ArrayRef) -> bool { + debug_assert!(!does_list_have_nulls(left)); + debug_assert!(!does_list_have_nulls(right)); let left_as_list = left.as_any().downcast_ref::(); let right_as_list = right.as_any().downcast_ref::(); match (left_as_list, right_as_list) { @@ -256,12 +259,15 @@ impl ListChunked { // A slower implementation since we can't just add the underlying // values Arrow arrays. Given nulls, the two values arrays might not // line up the way we expect. - let mut result = self.clear(); + let mut result = AnonymousListBuilder::new( + self.name().clone(), + self.len(), + Some(self.inner_dtype().clone()), + ); let combined = self.amortized_iter().zip(rhs.list()?.amortized_iter()).map(|(a, b)| { let (Some(a_owner), Some(b_owner)) = (a, b) else { // Operations with nulls always result in nulls: - let inner_dtype = self.dtype().inner_dtype().unwrap(); - return Ok(ListChunked::full_null_with_dtype(self.name().clone(), 1, inner_dtype)); + return Ok(None); }; let a = a_owner.as_ref(); let b = b_owner.as_ref(); @@ -282,12 +288,16 @@ impl ListChunked { } else { op(a, b) }; - chunk_result.and_then(|s| s.implode()) - }); - for c in combined.into_iter() { - result.append(&c?)?; + chunk_result.map(Some) + }).collect::>>>()?; + for s in combined.iter() { + if let Some(s) = s { + result.append_series(s)?; + } else { + result.append_null(); + } } - return Ok(result.into()); + return Ok(result.finish().into()); } let l_rechunked = self.clone().rechunk().into_series(); let l_leaf_array = l_rechunked.get_leaf_array(); From b3566833bec62d87b86380df004a3e21c921d492 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Mon, 16 Sep 2024 13:54:21 -0400 Subject: [PATCH 25/30] Split list handling into its own module. --- .../src/series/arithmetic/borrowed.rs | 173 ----------------- .../src/series/arithmetic/list_borrowed.rs | 177 ++++++++++++++++++ .../polars-core/src/series/arithmetic/mod.rs | 1 + 3 files changed, 178 insertions(+), 173 deletions(-) create mode 100644 crates/polars-core/src/series/arithmetic/list_borrowed.rs diff --git a/crates/polars-core/src/series/arithmetic/borrowed.rs b/crates/polars-core/src/series/arithmetic/borrowed.rs index 327ed6197be3..6003d0b05792 100644 --- a/crates/polars-core/src/series/arithmetic/borrowed.rs +++ b/crates/polars-core/src/series/arithmetic/borrowed.rs @@ -1,7 +1,4 @@ -use arrow::array::Array; - use super::*; -use crate::chunked_array::builder::AnonymousListBuilder; use crate::utils::align_chunks_binary; pub trait NumOpsDispatchInner: PolarsDataType + Sized { @@ -174,176 +171,6 @@ impl NumOpsDispatchInner for FixedSizeListType { } } -/// Given an ArrayRef with some primitive values, wrap it in list(s) until it -/// matches the requested shape. -fn reshape_list_based_on(data: &ArrayRef, shape: &ArrayRef) -> ArrayRef { - if let Some(list_chunk) = shape.as_any().downcast_ref::() { - let result = LargeListArray::new( - list_chunk.dtype().clone(), - list_chunk.offsets().clone(), - reshape_list_based_on(data, list_chunk.values()), - list_chunk.validity().cloned(), - ); - Box::new(result) - } else { - data.clone() - } -} - -/// Given an ArrayRef, return true if it's a LargeListArrays and it has one or -/// more nulls. -fn does_list_have_nulls(data: &ArrayRef) -> bool { - if let Some(list_chunk) = data.as_any().downcast_ref::() { - if list_chunk - .validity() - .map(|bitmap| bitmap.unset_bits() > 0) - .unwrap_or(false) - { - true - } else { - does_list_have_nulls(list_chunk.values()) - } - } else { - false - } -} - -/// Return whether the left and right have the same shape. We assume neither has -/// any nulls, recursively. -fn lists_same_shapes(left: &ArrayRef, right: &ArrayRef) -> bool { - debug_assert!(!does_list_have_nulls(left)); - debug_assert!(!does_list_have_nulls(right)); - let left_as_list = left.as_any().downcast_ref::(); - let right_as_list = right.as_any().downcast_ref::(); - match (left_as_list, right_as_list) { - (Some(left), Some(right)) => { - left.offsets() == right.offsets() && lists_same_shapes(left.values(), right.values()) - }, - (None, None) => left.len() == right.len(), - _ => false, - } -} - -impl ListChunked { - fn arithm_helper( - &self, - rhs: &Series, - op: &dyn Fn(&Series, &Series) -> PolarsResult, - has_nulls: Option, - ) -> PolarsResult { - polars_ensure!( - self.len() == rhs.len(), - InvalidOperation: "can only do arithmetic operations on Series of the same size; got {} and {}", - self.len(), - rhs.len() - ); - - let mut has_nulls = has_nulls.unwrap_or(false); - if !has_nulls { - for chunk in self.chunks().iter() { - if does_list_have_nulls(chunk) { - has_nulls = true; - break; - } - } - } - if !has_nulls { - for chunk in rhs.chunks().iter() { - if does_list_have_nulls(chunk) { - has_nulls = true; - break; - } - } - } - if has_nulls { - // A slower implementation since we can't just add the underlying - // values Arrow arrays. Given nulls, the two values arrays might not - // line up the way we expect. - let mut result = AnonymousListBuilder::new( - self.name().clone(), - self.len(), - Some(self.inner_dtype().clone()), - ); - let combined = self.amortized_iter().zip(rhs.list()?.amortized_iter()).map(|(a, b)| { - let (Some(a_owner), Some(b_owner)) = (a, b) else { - // Operations with nulls always result in nulls: - return Ok(None); - }; - let a = a_owner.as_ref(); - let b = b_owner.as_ref(); - polars_ensure!( - a.len() == b.len(), - InvalidOperation: "can only do arithmetic operations on lists of the same size; got {} and {}", - a.len(), - b.len() - ); - let chunk_result = if let Ok(a_listchunked) = a.list() { - // If `a` contains more lists, we're going to reach this - // function recursively, and again have to decide whether to - // use the fast path (no nulls) or slow path (there were - // nulls). Since we know there were nulls, that means we - // have to stick to the slow path, so pass that information - // along. - a_listchunked.arithm_helper(b, op, Some(true)) - } else { - op(a, b) - }; - chunk_result.map(Some) - }).collect::>>>()?; - for s in combined.iter() { - if let Some(s) = s { - result.append_series(s)?; - } else { - result.append_null(); - } - } - return Ok(result.finish().into()); - } - let l_rechunked = self.clone().rechunk().into_series(); - let l_leaf_array = l_rechunked.get_leaf_array(); - let r_leaf_array = rhs.rechunk().get_leaf_array(); - polars_ensure!( - lists_same_shapes(&l_leaf_array.chunks()[0], &r_leaf_array.chunks()[0]), - InvalidOperation: "can only do arithmetic operations on lists of the same size" - ); - - let result = op(&l_leaf_array, &r_leaf_array)?; - - // We now need to wrap the Arrow arrays with the metadata that turns - // them into lists: - // TODO is there a way to do this without cloning the underlying data? - let result_chunks = result.chunks(); - assert_eq!(result_chunks.len(), 1); - let left_chunk = &l_rechunked.chunks()[0]; - let result_chunk = reshape_list_based_on(&result_chunks[0], left_chunk); - - unsafe { - let mut result = - ListChunked::new_with_dims(self.field.clone(), vec![result_chunk], 0, 0); - result.compute_len(); - Ok(result.into()) - } - } -} - -impl NumOpsDispatchInner for ListType { - fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.add_to(r), None) - } - fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.subtract(r), None) - } - fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.multiply(r), None) - } - fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.divide(r), None) - } - fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult { - lhs.arithm_helper(rhs, &|l, r| l.remainder(r), None) - } -} - #[cfg(feature = "checked_arithmetic")] pub mod checked { use num_traits::{CheckedDiv, One, ToPrimitive, Zero}; diff --git a/crates/polars-core/src/series/arithmetic/list_borrowed.rs b/crates/polars-core/src/series/arithmetic/list_borrowed.rs new file mode 100644 index 000000000000..1628780d7b0e --- /dev/null +++ b/crates/polars-core/src/series/arithmetic/list_borrowed.rs @@ -0,0 +1,177 @@ +//! Allow arithmetic operations for ListChunked. + +use super::*; +use crate::chunked_array::builder::AnonymousListBuilder; + +/// Given an ArrayRef with some primitive values, wrap it in list(s) until it +/// matches the requested shape. +fn reshape_list_based_on(data: &ArrayRef, shape: &ArrayRef) -> ArrayRef { + if let Some(list_chunk) = shape.as_any().downcast_ref::() { + let result = LargeListArray::new( + list_chunk.dtype().clone(), + list_chunk.offsets().clone(), + reshape_list_based_on(data, list_chunk.values()), + list_chunk.validity().cloned(), + ); + Box::new(result) + } else { + data.clone() + } +} + +/// Given an ArrayRef, return true if it's a LargeListArrays and it has one or +/// more nulls. +fn does_list_have_nulls(data: &ArrayRef) -> bool { + if let Some(list_chunk) = data.as_any().downcast_ref::() { + if list_chunk + .validity() + .map(|bitmap| bitmap.unset_bits() > 0) + .unwrap_or(false) + { + true + } else { + does_list_have_nulls(list_chunk.values()) + } + } else { + false + } +} + +/// Return whether the left and right have the same shape. We assume neither has +/// any nulls, recursively. +fn lists_same_shapes(left: &ArrayRef, right: &ArrayRef) -> bool { + debug_assert!(!does_list_have_nulls(left)); + debug_assert!(!does_list_have_nulls(right)); + let left_as_list = left.as_any().downcast_ref::(); + let right_as_list = right.as_any().downcast_ref::(); + match (left_as_list, right_as_list) { + (Some(left), Some(right)) => { + left.offsets() == right.offsets() && lists_same_shapes(left.values(), right.values()) + }, + (None, None) => left.len() == right.len(), + _ => false, + } +} + +impl ListChunked { + /// Helper function for NumOpsDispatchInner implementation for ListChunked. + /// + /// Run the given `op` on `self` and `rhs`. + fn arithm_helper( + &self, + rhs: &Series, + op: &dyn Fn(&Series, &Series) -> PolarsResult, + has_nulls: Option, + ) -> PolarsResult { + polars_ensure!( + self.len() == rhs.len(), + InvalidOperation: "can only do arithmetic operations on Series of the same size; got {} and {}", + self.len(), + rhs.len() + ); + + let mut has_nulls = has_nulls.unwrap_or(false); + if !has_nulls { + for chunk in self.chunks().iter() { + if does_list_have_nulls(chunk) { + has_nulls = true; + break; + } + } + } + if !has_nulls { + for chunk in rhs.chunks().iter() { + if does_list_have_nulls(chunk) { + has_nulls = true; + break; + } + } + } + if has_nulls { + // A slower implementation since we can't just add the underlying + // values Arrow arrays. Given nulls, the two values arrays might not + // line up the way we expect. + let mut result = AnonymousListBuilder::new( + self.name().clone(), + self.len(), + Some(self.inner_dtype().clone()), + ); + let combined = self.amortized_iter().zip(rhs.list()?.amortized_iter()).map(|(a, b)| { + let (Some(a_owner), Some(b_owner)) = (a, b) else { + // Operations with nulls always result in nulls: + return Ok(None); + }; + let a = a_owner.as_ref(); + let b = b_owner.as_ref(); + polars_ensure!( + a.len() == b.len(), + InvalidOperation: "can only do arithmetic operations on lists of the same size; got {} and {}", + a.len(), + b.len() + ); + let chunk_result = if let Ok(a_listchunked) = a.list() { + // If `a` contains more lists, we're going to reach this + // function recursively, and again have to decide whether to + // use the fast path (no nulls) or slow path (there were + // nulls). Since we know there were nulls, that means we + // have to stick to the slow path, so pass that information + // along. + a_listchunked.arithm_helper(b, op, Some(true)) + } else { + op(a, b) + }; + chunk_result.map(Some) + }).collect::>>>()?; + for s in combined.iter() { + if let Some(s) = s { + result.append_series(s)?; + } else { + result.append_null(); + } + } + return Ok(result.finish().into()); + } + let l_rechunked = self.clone().rechunk().into_series(); + let l_leaf_array = l_rechunked.get_leaf_array(); + let r_leaf_array = rhs.rechunk().get_leaf_array(); + polars_ensure!( + lists_same_shapes(&l_leaf_array.chunks()[0], &r_leaf_array.chunks()[0]), + InvalidOperation: "can only do arithmetic operations on lists of the same size" + ); + + let result = op(&l_leaf_array, &r_leaf_array)?; + + // We now need to wrap the Arrow arrays with the metadata that turns + // them into lists: + // TODO is there a way to do this without cloning the underlying data? + let result_chunks = result.chunks(); + assert_eq!(result_chunks.len(), 1); + let left_chunk = &l_rechunked.chunks()[0]; + let result_chunk = reshape_list_based_on(&result_chunks[0], left_chunk); + + unsafe { + let mut result = + ListChunked::new_with_dims(self.field.clone(), vec![result_chunk], 0, 0); + result.compute_len(); + Ok(result.into()) + } + } +} + +impl NumOpsDispatchInner for ListType { + fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + lhs.arithm_helper(rhs, &|l, r| l.add_to(r), None) + } + fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + lhs.arithm_helper(rhs, &|l, r| l.subtract(r), None) + } + fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + lhs.arithm_helper(rhs, &|l, r| l.multiply(r), None) + } + fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + lhs.arithm_helper(rhs, &|l, r| l.divide(r), None) + } + fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + lhs.arithm_helper(rhs, &|l, r| l.remainder(r), None) + } +} diff --git a/crates/polars-core/src/series/arithmetic/mod.rs b/crates/polars-core/src/series/arithmetic/mod.rs index 7aa703221b7c..d7d7dbdb8a0e 100644 --- a/crates/polars-core/src/series/arithmetic/mod.rs +++ b/crates/polars-core/src/series/arithmetic/mod.rs @@ -1,4 +1,5 @@ mod borrowed; +mod list_borrowed; mod owned; use std::borrow::Cow; From 85cc6dd054a299945e0951a040ddb2a1184565c6 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Thu, 19 Sep 2024 13:29:55 -0400 Subject: [PATCH 26/30] Improve testing, and fix bug caught by the better test. --- crates/polars-expr/src/expressions/cast.rs | 9 ++++++++- .../tests/unit/operations/arithmetic/test_arithmetic.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/crates/polars-expr/src/expressions/cast.rs b/crates/polars-expr/src/expressions/cast.rs index ebfd50311918..1c486179ecd5 100644 --- a/crates/polars-expr/src/expressions/cast.rs +++ b/crates/polars-expr/src/expressions/cast.rs @@ -13,7 +13,14 @@ pub struct CastExpr { impl CastExpr { fn finish(&self, input: &Series) -> PolarsResult { - input.cast_with_options(&self.dtype, self.options) + let dtype = if matches!(input.dtype(), DataType::List(_)) && !self.dtype.is_nested() { + // Necessary for expressions that e.g. add UInt8 to List[Int64] to + // work. + &input.dtype().cast_leaf(self.dtype.clone()) + } else { + &self.dtype + }; + input.cast_with_options(dtype, self.options) } } diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index 2f4b4673885f..e8daf4d87926 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -639,7 +639,7 @@ def test_list_arithmetic_same_size( df = pl.DataFrame( [ pl.Series("a", [[1, 2], [3]]), - pl.Series("uint8", [[2, 2], [4]]), + pl.Series("uint8", [[2, 2], [4]], dtype=pl.List(pl.UInt8())), pl.Series("nested", [[[1, 2]], [[3]]]), ] ) From 920fed2d89f12b9d0b685a1e63104e02c942ad68 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Fri, 20 Sep 2024 11:24:51 -0400 Subject: [PATCH 27/30] There's an API for that. --- crates/polars-expr/src/expressions/cast.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/polars-expr/src/expressions/cast.rs b/crates/polars-expr/src/expressions/cast.rs index 1c486179ecd5..43301aff7189 100644 --- a/crates/polars-expr/src/expressions/cast.rs +++ b/crates/polars-expr/src/expressions/cast.rs @@ -13,7 +13,7 @@ pub struct CastExpr { impl CastExpr { fn finish(&self, input: &Series) -> PolarsResult { - let dtype = if matches!(input.dtype(), DataType::List(_)) && !self.dtype.is_nested() { + let dtype = if input.dtype().is_list() && !self.dtype.is_nested() { // Necessary for expressions that e.g. add UInt8 to List[Int64] to // work. &input.dtype().cast_leaf(self.dtype.clone()) From 0002d9a51df99650ad6c60b907fc8118588a7254 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Fri, 20 Sep 2024 14:20:48 -0400 Subject: [PATCH 28/30] Additional testing. --- .../tests/unit/operations/arithmetic/test_arithmetic.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index e8daf4d87926..99bb5e89ff71 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -629,6 +629,11 @@ def test_array_arithmetic_same_size( lambda a, b: a + b, ("nested", "nested"), ), + ( + [[[2, 4]], [[6]]], + lambda a, b: a + b, + ("nested", "nested_uint8"), + ), ], ) def test_list_arithmetic_same_size( @@ -641,6 +646,7 @@ def test_list_arithmetic_same_size( pl.Series("a", [[1, 2], [3]]), pl.Series("uint8", [[2, 2], [4]], dtype=pl.List(pl.UInt8())), pl.Series("nested", [[[1, 2]], [[3]]]), + pl.Series("nested_uint8", [[[1, 2]], [[3]]], dtype=pl.List(pl.List(pl.UInt8()))), ] ) # Expr-based arithmetic: From 9e2e346d05350cf9139237035cb57c8962769f66 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Fri, 20 Sep 2024 14:38:49 -0400 Subject: [PATCH 29/30] Remove a broken workaround I added, and replace it with actual fix for the problem. --- crates/polars-expr/src/expressions/cast.rs | 9 +-------- .../src/plans/conversion/type_coercion/binary.rs | 14 ++++++++------ 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/crates/polars-expr/src/expressions/cast.rs b/crates/polars-expr/src/expressions/cast.rs index 43301aff7189..ebfd50311918 100644 --- a/crates/polars-expr/src/expressions/cast.rs +++ b/crates/polars-expr/src/expressions/cast.rs @@ -13,14 +13,7 @@ pub struct CastExpr { impl CastExpr { fn finish(&self, input: &Series) -> PolarsResult { - let dtype = if input.dtype().is_list() && !self.dtype.is_nested() { - // Necessary for expressions that e.g. add UInt8 to List[Int64] to - // work. - &input.dtype().cast_leaf(self.dtype.clone()) - } else { - &self.dtype - }; - input.cast_with_options(dtype, self.options) + input.cast_with_options(&self.dtype, self.options) } } diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs b/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs index 7ee2282b0da9..37d58e004ab1 100644 --- a/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs +++ b/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs @@ -56,11 +56,12 @@ fn process_list_arithmetic( expr_arena: &mut Arena, ) -> PolarsResult> { match (&type_left, &type_right) { - (DataType::List(inner), _) => { - if type_right != **inner { + (DataType::List(_), _) => { + let leaf = type_left.leaf_dtype(); + if type_right != *leaf { let new_node_right = expr_arena.add(AExpr::Cast { expr: node_right, - dtype: *inner.clone(), + dtype: type_left.cast_leaf(leaf.clone()), options: CastOptions::NonStrict, }); @@ -73,11 +74,12 @@ fn process_list_arithmetic( Ok(None) } }, - (_, DataType::List(inner)) => { - if type_left != **inner { + (_, DataType::List(_)) => { + let leaf = type_right.leaf_dtype(); + if type_left != *leaf { let new_node_left = expr_arena.add(AExpr::Cast { expr: node_left, - dtype: *inner.clone(), + dtype: type_right.cast_leaf(leaf.clone()), options: CastOptions::NonStrict, }); From ead35ac6655cc3cb92efd8cefd342610890e7b68 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Fri, 20 Sep 2024 14:43:11 -0400 Subject: [PATCH 30/30] Fix formatting --- py-polars/tests/unit/operations/arithmetic/test_arithmetic.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index 99bb5e89ff71..360def065ca1 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -646,7 +646,9 @@ def test_list_arithmetic_same_size( pl.Series("a", [[1, 2], [3]]), pl.Series("uint8", [[2, 2], [4]], dtype=pl.List(pl.UInt8())), pl.Series("nested", [[[1, 2]], [[3]]]), - pl.Series("nested_uint8", [[[1, 2]], [[3]]], dtype=pl.List(pl.List(pl.UInt8()))), + pl.Series( + "nested_uint8", [[[1, 2]], [[3]]], dtype=pl.List(pl.List(pl.UInt8())) + ), ] ) # Expr-based arithmetic: