Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support arithmetic between Series with dtype list #17823

Merged
merged 34 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
31c0e44
A more functional sketch.
pythonspeed Jul 19, 2024
559d9d2
Make division work for `list[int64]` (and arrays too)
pythonspeed Jul 23, 2024
f205d4f
More thorough testing of array math expressions
pythonspeed Jul 23, 2024
53dd233
Another test, commented out
pythonspeed Jul 23, 2024
b58f7f4
Success case test suite for list arithmetic
pythonspeed Jul 23, 2024
6e888ee
Include reference to Rust code
pythonspeed Jul 23, 2024
0ef2846
Support division in lists
pythonspeed Jul 23, 2024
bffb77f
Test error edge cases for List arithmetic
pythonspeed Jul 23, 2024
856a884
Run ruff to fix formatting
pythonspeed Jul 23, 2024
c830d33
Fix lints
pythonspeed Jul 23, 2024
5692b62
Fix lint
pythonspeed Jul 23, 2024
e194f97
Clean up
pythonspeed Jul 23, 2024
0601f5e
Specify dtype explicitly
pythonspeed Jul 23, 2024
2a86fbc
Rewrite to operate directly on underlying data in one chunk.
pythonspeed Aug 7, 2024
f0eea11
Handle nulls correctly
pythonspeed Aug 8, 2024
7ba7fd6
WIP improvements to null handling.
pythonspeed Aug 8, 2024
ef8b39d
Merge remote-tracking branch 'origin/main' into 9188-list-arithmetic
pythonspeed Aug 8, 2024
00ba975
Null handling now appears to work with latest tests.
pythonspeed Aug 8, 2024
254b37e
All tests pass.
pythonspeed Aug 8, 2024
d1d3950
Merge branch 'main' into 9188-list-arithmetic
itamarst Sep 9, 2024
cfd08f9
Merge remote-tracking branch 'origin/main' into 9188-list-arithmetic
pythonspeed Sep 11, 2024
cf4fa30
Update to compile with latest code.
pythonspeed Sep 11, 2024
03cdddd
Get rid of thread local, expand testing slightly.
pythonspeed Sep 12, 2024
19650ab
Drop scopeguard as explicit dependency.
pythonspeed Sep 12, 2024
4972d06
Simplify by getting rid of intermediate Series.
pythonspeed Sep 12, 2024
677f8d8
Merge remote-tracking branch 'origin/main' into 9188-list-arithmetic
pythonspeed Sep 16, 2024
ee74063
Simpler signature, better name.
pythonspeed Sep 16, 2024
b27b7ff
Use an AnonymousListBuilder.
pythonspeed Sep 16, 2024
b356683
Split list handling into its own module.
pythonspeed Sep 16, 2024
85cc6dd
Improve testing, and fix bug caught by the better test.
pythonspeed Sep 19, 2024
920fed2
There's an API for that.
pythonspeed Sep 20, 2024
0002d9a
Additional testing.
pythonspeed Sep 20, 2024
9e2e346
Remove a broken workaround I added, and replace it with actual fix fo…
pythonspeed Sep 20, 2024
ead35ac
Fix formatting
pythonspeed Sep 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 177 additions & 0 deletions crates/polars-core/src/series/arithmetic/list_borrowed.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
//! Allow arithmetic operations for ListChunked.

use super::*;
use crate::chunked_array::builder::AnonymousListBuilder;

/// Given an ArrayRef with some primitive values, wrap it in list(s) until it
/// matches the requested shape.
fn reshape_list_based_on(data: &ArrayRef, shape: &ArrayRef) -> ArrayRef {
if let Some(list_chunk) = shape.as_any().downcast_ref::<LargeListArray>() {
let result = LargeListArray::new(
list_chunk.dtype().clone(),
list_chunk.offsets().clone(),
reshape_list_based_on(data, list_chunk.values()),
list_chunk.validity().cloned(),
);
Box::new(result)
} else {
data.clone()
}
}

/// Given an ArrayRef, return true if it's a LargeListArrays and it has one or
/// more nulls.
fn does_list_have_nulls(data: &ArrayRef) -> bool {
if let Some(list_chunk) = data.as_any().downcast_ref::<LargeListArray>() {
if list_chunk
.validity()
.map(|bitmap| bitmap.unset_bits() > 0)
.unwrap_or(false)
{
true
} else {
does_list_have_nulls(list_chunk.values())
}
} else {
false
}
}

/// Return whether the left and right have the same shape. We assume neither has
/// any nulls, recursively.
fn lists_same_shapes(left: &ArrayRef, right: &ArrayRef) -> bool {
debug_assert!(!does_list_have_nulls(left));
debug_assert!(!does_list_have_nulls(right));
let left_as_list = left.as_any().downcast_ref::<LargeListArray>();
let right_as_list = right.as_any().downcast_ref::<LargeListArray>();
match (left_as_list, right_as_list) {
(Some(left), Some(right)) => {
left.offsets() == right.offsets() && lists_same_shapes(left.values(), right.values())
},
(None, None) => left.len() == right.len(),
_ => false,
}
}

impl ListChunked {
/// Helper function for NumOpsDispatchInner implementation for ListChunked.
///
/// Run the given `op` on `self` and `rhs`.
fn arithm_helper(
&self,
rhs: &Series,
op: &dyn Fn(&Series, &Series) -> PolarsResult<Series>,
has_nulls: Option<bool>,
) -> PolarsResult<Series> {
polars_ensure!(
self.len() == rhs.len(),
InvalidOperation: "can only do arithmetic operations on Series of the same size; got {} and {}",
self.len(),
rhs.len()
);

let mut has_nulls = has_nulls.unwrap_or(false);
if !has_nulls {
for chunk in self.chunks().iter() {
if does_list_have_nulls(chunk) {
has_nulls = true;
break;
}
}
}
if !has_nulls {
for chunk in rhs.chunks().iter() {
if does_list_have_nulls(chunk) {
has_nulls = true;
break;
}
}
}
if has_nulls {
// A slower implementation since we can't just add the underlying
// values Arrow arrays. Given nulls, the two values arrays might not
// line up the way we expect.
let mut result = AnonymousListBuilder::new(
self.name().clone(),
self.len(),
Some(self.inner_dtype().clone()),
);
let combined = self.amortized_iter().zip(rhs.list()?.amortized_iter()).map(|(a, b)| {
let (Some(a_owner), Some(b_owner)) = (a, b) else {
// Operations with nulls always result in nulls:
return Ok(None);
};
let a = a_owner.as_ref();
let b = b_owner.as_ref();
polars_ensure!(
a.len() == b.len(),
InvalidOperation: "can only do arithmetic operations on lists of the same size; got {} and {}",
a.len(),
b.len()
);
let chunk_result = if let Ok(a_listchunked) = a.list() {
// If `a` contains more lists, we're going to reach this
// function recursively, and again have to decide whether to
// use the fast path (no nulls) or slow path (there were
// nulls). Since we know there were nulls, that means we
// have to stick to the slow path, so pass that information
// along.
a_listchunked.arithm_helper(b, op, Some(true))
} else {
op(a, b)
};
chunk_result.map(Some)
}).collect::<PolarsResult<Vec<Option<Series>>>>()?;
for s in combined.iter() {
if let Some(s) = s {
result.append_series(s)?;
} else {
result.append_null();
}
}
return Ok(result.finish().into());
}
let l_rechunked = self.clone().rechunk().into_series();
let l_leaf_array = l_rechunked.get_leaf_array();
let r_leaf_array = rhs.rechunk().get_leaf_array();
polars_ensure!(
lists_same_shapes(&l_leaf_array.chunks()[0], &r_leaf_array.chunks()[0]),
InvalidOperation: "can only do arithmetic operations on lists of the same size"
);

let result = op(&l_leaf_array, &r_leaf_array)?;

// We now need to wrap the Arrow arrays with the metadata that turns
// them into lists:
// TODO is there a way to do this without cloning the underlying data?
let result_chunks = result.chunks();
assert_eq!(result_chunks.len(), 1);
let left_chunk = &l_rechunked.chunks()[0];
let result_chunk = reshape_list_based_on(&result_chunks[0], left_chunk);

unsafe {
let mut result =
ListChunked::new_with_dims(self.field.clone(), vec![result_chunk], 0, 0);
result.compute_len();
Ok(result.into())
}
}
}

impl NumOpsDispatchInner for ListType {
fn add_to(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
lhs.arithm_helper(rhs, &|l, r| l.add_to(r), None)
}
fn subtract(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
lhs.arithm_helper(rhs, &|l, r| l.subtract(r), None)
}
fn multiply(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
lhs.arithm_helper(rhs, &|l, r| l.multiply(r), None)
}
fn divide(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
lhs.arithm_helper(rhs, &|l, r| l.divide(r), None)
}
fn remainder(lhs: &ListChunked, rhs: &Series) -> PolarsResult<Series> {
lhs.arithm_helper(rhs, &|l, r| l.remainder(r), None)
}
}
1 change: 1 addition & 0 deletions crates/polars-core/src/series/arithmetic/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod borrowed;
mod list_borrowed;
mod owned;

use std::borrow::Cow;
Expand Down
18 changes: 18 additions & 0 deletions crates/polars-core/src/series/implementations/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,24 @@ impl private::PrivateSeries for SeriesWrap<ListChunked> {
fn into_total_eq_inner<'a>(&'a self) -> Box<dyn TotalEqInner + 'a> {
(&self.0).into_total_eq_inner()
}

fn add_to(&self, rhs: &Series) -> PolarsResult<Series> {
self.0.add_to(rhs)
}

fn subtract(&self, rhs: &Series) -> PolarsResult<Series> {
self.0.subtract(rhs)
}

fn multiply(&self, rhs: &Series) -> PolarsResult<Series> {
self.0.multiply(rhs)
}
fn divide(&self, rhs: &Series) -> PolarsResult<Series> {
self.0.divide(rhs)
}
fn remainder(&self, rhs: &Series) -> PolarsResult<Series> {
self.0.remainder(rhs)
}
}

impl SeriesTrait for SeriesWrap<ListChunked> {
Expand Down
5 changes: 5 additions & 0 deletions crates/polars-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ pub fn apply_operator(left: &Series, right: &Series, op: Operator) -> PolarsResu
let right_dt = right.dtype().cast_leaf(Float64);
left.cast(&left_dt)? / right.cast(&right_dt)?
},
dt @ List(_) => {
let left_dt = dt.cast_leaf(Float64);
let right_dt = right.dtype().cast_leaf(Float64);
left.cast(&left_dt)? / right.cast(&right_dt)?
},
_ => {
if right.dtype().is_temporal() {
return left / right;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,12 @@ fn process_list_arithmetic(
expr_arena: &mut Arena<AExpr>,
) -> PolarsResult<Option<AExpr>> {
match (&type_left, &type_right) {
(DataType::List(inner), _) => {
if type_right != **inner {
(DataType::List(_), _) => {
let leaf = type_left.leaf_dtype();
if type_right != *leaf {
let new_node_right = expr_arena.add(AExpr::Cast {
expr: node_right,
dtype: *inner.clone(),
dtype: type_left.cast_leaf(leaf.clone()),
options: CastOptions::NonStrict,
});

Expand All @@ -73,11 +74,12 @@ fn process_list_arithmetic(
Ok(None)
}
},
(_, DataType::List(inner)) => {
if type_left != **inner {
(_, DataType::List(_)) => {
let leaf = type_right.leaf_dtype();
if type_left != *leaf {
let new_node_left = expr_arena.add(AExpr::Cast {
expr: node_left,
dtype: *inner.clone(),
dtype: type_right.cast_leaf(leaf.clone()),
options: CastOptions::NonStrict,
});

Expand Down
22 changes: 20 additions & 2 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,22 @@ def __sub__(self, other: Any) -> Self | Expr:
return F.lit(self) - other
return self._arithmetic(other, "sub", "sub_<>")

def _recursive_cast_to_dtype(self, leaf_dtype: PolarsDataType) -> Series:
"""
Convert leaf dtype the to given primitive datatype.
This is equivalent to logic in DataType::cast_leaf() in Rust.
"""

def convert_to_primitive(dtype: PolarsDataType) -> PolarsDataType:
if isinstance(dtype, Array):
return Array(convert_to_primitive(dtype.inner), shape=dtype.shape)
if isinstance(dtype, List):
return List(convert_to_primitive(dtype.inner))
return leaf_dtype

return self.cast(convert_to_primitive(self.dtype))

@overload
def __truediv__(self, other: Expr) -> Expr: ...

Expand All @@ -1073,9 +1089,11 @@ def __truediv__(self, other: Any) -> Series | Expr:

# this branch is exactly the floordiv function without rounding the floats
if self.dtype.is_float() or self.dtype == Decimal:
return self._arithmetic(other, "div", "div_<>")
as_float = self
else:
as_float = self._recursive_cast_to_dtype(Float64())

return self.cast(Float64) / other
return as_float._arithmetic(other, "div", "div_<>")

@overload
def __floordiv__(self, other: Expr) -> Expr: ...
Expand Down
Loading