diff --git a/crates/polars-core/src/series/arithmetic/list_borrowed.rs b/crates/polars-core/src/series/arithmetic/list_borrowed.rs new file mode 100644 index 000000000000..1628780d7b0e --- /dev/null +++ b/crates/polars-core/src/series/arithmetic/list_borrowed.rs @@ -0,0 +1,177 @@ +//! Allow arithmetic operations for ListChunked. + +use super::*; +use crate::chunked_array::builder::AnonymousListBuilder; + +/// Given an ArrayRef with some primitive values, wrap it in list(s) until it +/// matches the requested shape. +fn reshape_list_based_on(data: &ArrayRef, shape: &ArrayRef) -> ArrayRef { + if let Some(list_chunk) = shape.as_any().downcast_ref::() { + let result = LargeListArray::new( + list_chunk.dtype().clone(), + list_chunk.offsets().clone(), + reshape_list_based_on(data, list_chunk.values()), + list_chunk.validity().cloned(), + ); + Box::new(result) + } else { + data.clone() + } +} + +/// Given an ArrayRef, return true if it's a LargeListArrays and it has one or +/// more nulls. +fn does_list_have_nulls(data: &ArrayRef) -> bool { + if let Some(list_chunk) = data.as_any().downcast_ref::() { + if list_chunk + .validity() + .map(|bitmap| bitmap.unset_bits() > 0) + .unwrap_or(false) + { + true + } else { + does_list_have_nulls(list_chunk.values()) + } + } else { + false + } +} + +/// Return whether the left and right have the same shape. We assume neither has +/// any nulls, recursively. +fn lists_same_shapes(left: &ArrayRef, right: &ArrayRef) -> bool { + debug_assert!(!does_list_have_nulls(left)); + debug_assert!(!does_list_have_nulls(right)); + let left_as_list = left.as_any().downcast_ref::(); + let right_as_list = right.as_any().downcast_ref::(); + match (left_as_list, right_as_list) { + (Some(left), Some(right)) => { + left.offsets() == right.offsets() && lists_same_shapes(left.values(), right.values()) + }, + (None, None) => left.len() == right.len(), + _ => false, + } +} + +impl ListChunked { + /// Helper function for NumOpsDispatchInner implementation for ListChunked. + /// + /// Run the given `op` on `self` and `rhs`. + fn arithm_helper( + &self, + rhs: &Series, + op: &dyn Fn(&Series, &Series) -> PolarsResult, + has_nulls: Option, + ) -> PolarsResult { + polars_ensure!( + self.len() == rhs.len(), + InvalidOperation: "can only do arithmetic operations on Series of the same size; got {} and {}", + self.len(), + rhs.len() + ); + + let mut has_nulls = has_nulls.unwrap_or(false); + if !has_nulls { + for chunk in self.chunks().iter() { + if does_list_have_nulls(chunk) { + has_nulls = true; + break; + } + } + } + if !has_nulls { + for chunk in rhs.chunks().iter() { + if does_list_have_nulls(chunk) { + has_nulls = true; + break; + } + } + } + if has_nulls { + // A slower implementation since we can't just add the underlying + // values Arrow arrays. Given nulls, the two values arrays might not + // line up the way we expect. + let mut result = AnonymousListBuilder::new( + self.name().clone(), + self.len(), + Some(self.inner_dtype().clone()), + ); + let combined = self.amortized_iter().zip(rhs.list()?.amortized_iter()).map(|(a, b)| { + let (Some(a_owner), Some(b_owner)) = (a, b) else { + // Operations with nulls always result in nulls: + return Ok(None); + }; + let a = a_owner.as_ref(); + let b = b_owner.as_ref(); + polars_ensure!( + a.len() == b.len(), + InvalidOperation: "can only do arithmetic operations on lists of the same size; got {} and {}", + a.len(), + b.len() + ); + let chunk_result = if let Ok(a_listchunked) = a.list() { + // If `a` contains more lists, we're going to reach this + // function recursively, and again have to decide whether to + // use the fast path (no nulls) or slow path (there were + // nulls). Since we know there were nulls, that means we + // have to stick to the slow path, so pass that information + // along. + a_listchunked.arithm_helper(b, op, Some(true)) + } else { + op(a, b) + }; + chunk_result.map(Some) + }).collect::>>>()?; + for s in combined.iter() { + if let Some(s) = s { + result.append_series(s)?; + } else { + result.append_null(); + } + } + return Ok(result.finish().into()); + } + let l_rechunked = self.clone().rechunk().into_series(); + let l_leaf_array = l_rechunked.get_leaf_array(); + let r_leaf_array = rhs.rechunk().get_leaf_array(); + polars_ensure!( + lists_same_shapes(&l_leaf_array.chunks()[0], &r_leaf_array.chunks()[0]), + InvalidOperation: "can only do arithmetic operations on lists of the same size" + ); + + let result = op(&l_leaf_array, &r_leaf_array)?; + + // We now need to wrap the Arrow arrays with the metadata that turns + // them into lists: + // TODO is there a way to do this without cloning the underlying data? + let result_chunks = result.chunks(); + assert_eq!(result_chunks.len(), 1); + let left_chunk = &l_rechunked.chunks()[0]; + let result_chunk = reshape_list_based_on(&result_chunks[0], left_chunk); + + unsafe { + let mut result = + ListChunked::new_with_dims(self.field.clone(), vec![result_chunk], 0, 0); + result.compute_len(); + Ok(result.into()) + } + } +} + +impl NumOpsDispatchInner for ListType { + fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + lhs.arithm_helper(rhs, &|l, r| l.add_to(r), None) + } + fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + lhs.arithm_helper(rhs, &|l, r| l.subtract(r), None) + } + fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + lhs.arithm_helper(rhs, &|l, r| l.multiply(r), None) + } + fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + lhs.arithm_helper(rhs, &|l, r| l.divide(r), None) + } + fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult { + lhs.arithm_helper(rhs, &|l, r| l.remainder(r), None) + } +} diff --git a/crates/polars-core/src/series/arithmetic/mod.rs b/crates/polars-core/src/series/arithmetic/mod.rs index 7aa703221b7c..d7d7dbdb8a0e 100644 --- a/crates/polars-core/src/series/arithmetic/mod.rs +++ b/crates/polars-core/src/series/arithmetic/mod.rs @@ -1,4 +1,5 @@ mod borrowed; +mod list_borrowed; mod owned; use std::borrow::Cow; diff --git a/crates/polars-core/src/series/implementations/list.rs b/crates/polars-core/src/series/implementations/list.rs index 865fadcfcb93..5e5b4a95d5e2 100644 --- a/crates/polars-core/src/series/implementations/list.rs +++ b/crates/polars-core/src/series/implementations/list.rs @@ -43,6 +43,24 @@ impl private::PrivateSeries for SeriesWrap { fn into_total_eq_inner<'a>(&'a self) -> Box { (&self.0).into_total_eq_inner() } + + fn add_to(&self, rhs: &Series) -> PolarsResult { + self.0.add_to(rhs) + } + + fn subtract(&self, rhs: &Series) -> PolarsResult { + self.0.subtract(rhs) + } + + fn multiply(&self, rhs: &Series) -> PolarsResult { + self.0.multiply(rhs) + } + fn divide(&self, rhs: &Series) -> PolarsResult { + self.0.divide(rhs) + } + fn remainder(&self, rhs: &Series) -> PolarsResult { + self.0.remainder(rhs) + } } impl SeriesTrait for SeriesWrap { diff --git a/crates/polars-expr/src/expressions/binary.rs b/crates/polars-expr/src/expressions/binary.rs index d76db8db59f8..179f4524aa7e 100644 --- a/crates/polars-expr/src/expressions/binary.rs +++ b/crates/polars-expr/src/expressions/binary.rs @@ -75,6 +75,11 @@ pub fn apply_operator(left: &Series, right: &Series, op: Operator) -> PolarsResu let right_dt = right.dtype().cast_leaf(Float64); left.cast(&left_dt)? / right.cast(&right_dt)? }, + dt @ List(_) => { + let left_dt = dt.cast_leaf(Float64); + let right_dt = right.dtype().cast_leaf(Float64); + left.cast(&left_dt)? / right.cast(&right_dt)? + }, _ => { if right.dtype().is_temporal() { return left / right; diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs b/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs index 7ee2282b0da9..37d58e004ab1 100644 --- a/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs +++ b/crates/polars-plan/src/plans/conversion/type_coercion/binary.rs @@ -56,11 +56,12 @@ fn process_list_arithmetic( expr_arena: &mut Arena, ) -> PolarsResult> { match (&type_left, &type_right) { - (DataType::List(inner), _) => { - if type_right != **inner { + (DataType::List(_), _) => { + let leaf = type_left.leaf_dtype(); + if type_right != *leaf { let new_node_right = expr_arena.add(AExpr::Cast { expr: node_right, - dtype: *inner.clone(), + dtype: type_left.cast_leaf(leaf.clone()), options: CastOptions::NonStrict, }); @@ -73,11 +74,12 @@ fn process_list_arithmetic( Ok(None) } }, - (_, DataType::List(inner)) => { - if type_left != **inner { + (_, DataType::List(_)) => { + let leaf = type_right.leaf_dtype(); + if type_left != *leaf { let new_node_left = expr_arena.add(AExpr::Cast { expr: node_left, - dtype: *inner.clone(), + dtype: type_right.cast_leaf(leaf.clone()), options: CastOptions::NonStrict, }); diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 532cd4c74bc1..506c52ed9f7a 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -1058,6 +1058,22 @@ def __sub__(self, other: Any) -> Self | Expr: return F.lit(self) - other return self._arithmetic(other, "sub", "sub_<>") + def _recursive_cast_to_dtype(self, leaf_dtype: PolarsDataType) -> Series: + """ + Convert leaf dtype the to given primitive datatype. + + This is equivalent to logic in DataType::cast_leaf() in Rust. + """ + + def convert_to_primitive(dtype: PolarsDataType) -> PolarsDataType: + if isinstance(dtype, Array): + return Array(convert_to_primitive(dtype.inner), shape=dtype.shape) + if isinstance(dtype, List): + return List(convert_to_primitive(dtype.inner)) + return leaf_dtype + + return self.cast(convert_to_primitive(self.dtype)) + @overload def __truediv__(self, other: Expr) -> Expr: ... @@ -1073,9 +1089,11 @@ def __truediv__(self, other: Any) -> Series | Expr: # this branch is exactly the floordiv function without rounding the floats if self.dtype.is_float() or self.dtype == Decimal: - return self._arithmetic(other, "div", "div_<>") + as_float = self + else: + as_float = self._recursive_cast_to_dtype(Float64()) - return self.cast(Float64) / other + return as_float._arithmetic(other, "div", "div_<>") @overload def __floordiv__(self, other: Expr) -> Expr: ... diff --git a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py index 374d2965a029..360def065ca1 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py +++ b/py-polars/tests/unit/operations/arithmetic/test_arithmetic.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import operator from collections import OrderedDict from datetime import date, datetime, timedelta -from typing import Any +from typing import Any, Callable import numpy as np import pytest @@ -19,7 +21,7 @@ UInt32, UInt64, ) -from polars.exceptions import ColumnNotFoundError, InvalidOperationError +from polars.exceptions import ColumnNotFoundError, InvalidOperationError, SchemaError from polars.testing import assert_frame_equal, assert_series_equal from tests.unit.conftest import INTEGER_DTYPES, NUMERIC_DTYPES @@ -558,35 +560,161 @@ def test_power_series() -> None: @pytest.mark.parametrize( - ("expected", "expr"), + ("expected", "expr", "column_names"), [ + (np.array([[2, 4], [6, 8]], dtype=np.int64), lambda a, b: a + b, ("a", "a")), + (np.array([[0, 0], [0, 0]], dtype=np.int64), lambda a, b: a - b, ("a", "a")), + (np.array([[1, 4], [9, 16]], dtype=np.int64), lambda a, b: a * b, ("a", "a")), ( - np.array([[2, 4], [6, 8]]), - pl.col("a") + pl.col("a"), + np.array([[1.0, 1.0], [1.0, 1.0]], dtype=np.float64), + lambda a, b: a / b, + ("a", "a"), ), + (np.array([[0, 0], [0, 0]], dtype=np.int64), lambda a, b: a % b, ("a", "a")), ( - np.array([[0, 0], [0, 0]]), - pl.col("a") - pl.col("a"), + np.array([[3, 4], [7, 8]], dtype=np.int64), + lambda a, b: a + b, + ("a", "uint8"), + ), + # This fails because the code is buggy, see + # https://github.com/pola-rs/polars/issues/17820 + # + # ( + # np.array([[[2, 4]], [[6, 8]]], dtype=np.int64), + # lambda a, b: a + b, + # ("nested", "nested"), + # ), + ], +) +def test_array_arithmetic_same_size( + expected: Any, + expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series], + column_names: tuple[str, str], +) -> None: + df = pl.DataFrame( + [ + pl.Series("a", np.array([[1, 2], [3, 4]], dtype=np.int64)), + pl.Series("uint8", np.array([[2, 2], [4, 4]], dtype=np.uint8)), + pl.Series("nested", np.array([[[1, 2]], [[3, 4]]], dtype=np.int64)), + ] + ) + print(df.select(expr(pl.col(column_names[0]), pl.col(column_names[1])))) + # Expr-based arithmetic: + assert_frame_equal( + df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))), + pl.Series(column_names[0], expected).to_frame(), + ) + # Direct arithmetic on the Series: + assert_series_equal( + expr(df[column_names[0]], df[column_names[1]]), + pl.Series(column_names[0], expected), + ) + + +@pytest.mark.parametrize( + ("expected", "expr", "column_names"), + [ + ([[2, 4], [6]], lambda a, b: a + b, ("a", "a")), + ([[0, 0], [0]], lambda a, b: a - b, ("a", "a")), + ([[1, 4], [9]], lambda a, b: a * b, ("a", "a")), + ([[1.0, 1.0], [1.0]], lambda a, b: a / b, ("a", "a")), + ([[0, 0], [0]], lambda a, b: a % b, ("a", "a")), + ( + [[3, 4], [7]], + lambda a, b: a + b, + ("a", "uint8"), ), ( - np.array([[1, 4], [9, 16]]), - pl.col("a") * pl.col("a"), + [[[2, 4]], [[6]]], + lambda a, b: a + b, + ("nested", "nested"), ), ( - np.array([[1.0, 1.0], [1.0, 1.0]]), - pl.col("a") / pl.col("a"), + [[[2, 4]], [[6]]], + lambda a, b: a + b, + ("nested", "nested_uint8"), ), ], ) -def test_array_arithmetic_same_size(expected: Any, expr: pl.Expr) -> None: - df = pl.Series("a", np.array([[1, 2], [3, 4]])).to_frame() - +def test_list_arithmetic_same_size( + expected: Any, + expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series], + column_names: tuple[str, str], +) -> None: + df = pl.DataFrame( + [ + pl.Series("a", [[1, 2], [3]]), + pl.Series("uint8", [[2, 2], [4]], dtype=pl.List(pl.UInt8())), + pl.Series("nested", [[[1, 2]], [[3]]]), + pl.Series( + "nested_uint8", [[[1, 2]], [[3]]], dtype=pl.List(pl.List(pl.UInt8())) + ), + ] + ) + # Expr-based arithmetic: assert_frame_equal( - df.select(expr), - pl.Series("a", expected).to_frame(), + df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))), + pl.Series(column_names[0], expected).to_frame(), + ) + # Direct arithmetic on the Series: + assert_series_equal( + expr(df[column_names[0]], df[column_names[1]]), + pl.Series(column_names[0], expected), ) +@pytest.mark.parametrize( + ("a", "b", "expected"), + [ + ([[1, 2, 3]], [[1, None, 5]], [[2, None, 8]]), + ([[2], None, [5]], [None, [3], [2]], [None, None, [7]]), + ([[[2]], [None], [[4]]], [[[3]], [[6]], [[8]]], [[[5]], [None], [[12]]]), + ], +) +def test_list_arithmetic_nulls(a: list[Any], b: list[Any], expected: list[Any]) -> None: + series_a = pl.Series(a) + series_b = pl.Series(b) + series_expected = pl.Series(expected) + + # Same dtype: + assert_series_equal(series_a + series_b, series_expected) + + # Different dtype: + assert_series_equal( + series_a._recursive_cast_to_dtype(pl.Int32()) + + series_b._recursive_cast_to_dtype(pl.Int64()), + series_expected._recursive_cast_to_dtype(pl.Int64()), + ) + + +def test_list_arithmetic_error_cases() -> None: + # Different series length: + with pytest.raises( + InvalidOperationError, match="Series of the same size; got 1 and 2" + ): + _ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1, 2], [3, 4]]) + with pytest.raises( + InvalidOperationError, match="Series of the same size; got 1 and 2" + ): + _ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1, 2], None]) + + # Different list length: + with pytest.raises(InvalidOperationError, match="lists of the same size"): + _ = pl.Series("a", [[1, 2]]) / pl.Series("b", [[1]]) + with pytest.raises( + InvalidOperationError, match="lists of the same size; got 2 and 1" + ): + _ = pl.Series("a", [[1, 2], [2, 3]]) / pl.Series("b", [[1], None]) + + # Wrong types: + with pytest.raises(InvalidOperationError, match="cannot cast List type"): + _ = pl.Series("a", [[1, 2]]) + pl.Series("b", ["hello"]) + + # Different nesting: + with pytest.raises(SchemaError, match="failed to determine supertype"): + _ = pl.Series("a", [[1]]) + pl.Series("b", [[[1]]]) + + def test_schema_owned_arithmetic_5669() -> None: df = ( pl.LazyFrame({"A": [1, 2, 3]})