Skip to content

Commit

Permalink
feat: List set_operations supports float (pola-rs#13920)
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa authored Jan 23, 2024
1 parent 83f7b16 commit 1bdac4e
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 21 deletions.
9 changes: 9 additions & 0 deletions crates/polars-arrow/src/array/primitive/mutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::iter::FromIterator;
use std::sync::Arc;

use polars_error::PolarsResult;
use polars_utils::total_ord::TotalOrdWrap;

use super::{check, PrimitiveArray};
use crate::array::physical_binary::extend_validity;
Expand Down Expand Up @@ -363,6 +364,14 @@ impl<T: NativeType> Extend<Option<T>> for MutablePrimitiveArray<T> {
}
}

impl<T: NativeType> Extend<Option<TotalOrdWrap<T>>> for MutablePrimitiveArray<T> {
    /// Appends every item produced by `iter`, storing the value held inside each
    /// [`TotalOrdWrap`] (or a null for `None`).
    ///
    /// Capacity is reserved up front for the iterator's lower size bound.
    fn extend<I: IntoIterator<Item = Option<TotalOrdWrap<T>>>>(&mut self, iter: I) {
        let wrapped_values = iter.into_iter();
        let (lower_bound, _) = wrapped_values.size_hint();
        self.reserve(lower_bound);
        for wrapped in wrapped_values {
            // Strip the wrapper so the array stores the plain native value.
            self.push(wrapped.map(|w| w.0));
        }
    }
}

impl<T: NativeType> TryExtend<Option<T>> for MutablePrimitiveArray<T> {
/// This is infallible and is implemented for consistency with all other types
fn try_extend<I: IntoIterator<Item = Option<T>>>(&mut self, iter: I) -> PolarsResult<()> {
Expand Down
16 changes: 8 additions & 8 deletions crates/polars-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ macro_rules! match_arrow_data_type_apply_macro_ca {
DataType::Int64 => $macro!($self.i64().unwrap() $(, $opt_args)*),
DataType::Float32 => $macro!($self.f32().unwrap() $(, $opt_args)*),
DataType::Float64 => $macro!($self.f64().unwrap() $(, $opt_args)*),
_ => unimplemented!(),
dt => panic!("not implemented for dtype {:?}", dt),
}
}};
}
Expand All @@ -302,7 +302,7 @@ macro_rules! with_match_physical_numeric_type {(
UInt64 => __with_ty__! { u64 },
Float32 => __with_ty__! { f32 },
Float64 => __with_ty__! { f64 },
_ => unimplemented!()
dt => panic!("not implemented for dtype {:?}", dt),
}
})}

Expand All @@ -321,7 +321,7 @@ macro_rules! with_match_physical_integer_type {(
UInt16 => __with_ty__! { u16 },
UInt32 => __with_ty__! { u32 },
UInt64 => __with_ty__! { u64 },
_ => unimplemented!()
dt => panic!("not implemented for dtype {:?}", dt),
}
})}

Expand All @@ -334,7 +334,7 @@ macro_rules! with_match_physical_float_polars_type {(
match $key_type {
Float32 => __with_ty__! { Float32Type },
Float64 => __with_ty__! { Float64Type },
_ => unimplemented!()
dt => panic!("not implemented for dtype {:?}", dt),
}
})}

Expand All @@ -359,7 +359,7 @@ macro_rules! with_match_physical_numeric_polars_type {(
UInt64 => __with_ty__! { UInt64Type },
Float32 => __with_ty__! { Float32Type },
Float64 => __with_ty__! { Float64Type },
dt => panic!("not implemented for dtype: {}", dt)
dt => panic!("not implemented for dtype {:?}", dt),
}
})}

Expand All @@ -383,7 +383,7 @@ macro_rules! with_match_physical_integer_polars_type {(
UInt16 => __with_ty__! { UInt16Type },
UInt32 => __with_ty__! { UInt32Type },
UInt64 => __with_ty__! { UInt64Type },
_ => unimplemented!()
dt => panic!("not implemented for dtype {:?}", dt),
}
})}

Expand Down Expand Up @@ -505,7 +505,7 @@ macro_rules! apply_amortized_generic_list_or_array {
#[cfg(feature = "dtype-array")]
DataType::Array(_, _) => $self.array().unwrap().apply_amortized_generic($($args),*),
DataType::List(_) => $self.list().unwrap().apply_amortized_generic($($args),*),
_ => unimplemented!(),
dt => panic!("not implemented for dtype {:?}", dt),
}
}
}
Expand All @@ -526,7 +526,7 @@ macro_rules! apply_method_physical_integer {
DataType::Int16 => $self.i16().unwrap().$method($($args),*),
DataType::Int32 => $self.i32().unwrap().$method($($args),*),
DataType::Int64 => $self.i64().unwrap().$method($($args),*),
_ => unimplemented!(),
dt => panic!("not implemented for dtype {:?}", dt),
}
}
}
Expand Down
37 changes: 24 additions & 13 deletions crates/polars-ops/src/chunked_array/list/sets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ use arrow::compute::utils::combine_validities_and;
use arrow::offset::OffsetsBuffer;
use arrow::types::NativeType;
use polars_core::prelude::*;
use polars_core::with_match_physical_integer_type;
use polars_core::with_match_physical_numeric_type;
use polars_utils::total_ord::{TotalEq, TotalHash, TotalOrdWrap};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

Expand All @@ -29,6 +30,16 @@ where
}
}

impl<T: NativeType> MaterializeValues<Option<TotalOrdWrap<T>>> for MutablePrimitiveArray<T> {
    /// Appends all `values` to this array and returns the array's length after
    /// the append.
    fn extend_buf<I: Iterator<Item = Option<TotalOrdWrap<T>>>>(&mut self, values: I) -> usize {
        // Name the `Extend` impl for wrapped values explicitly; the array also
        // implements `Extend` for plain `Option<T>` items.
        <Self as Extend<Option<TotalOrdWrap<T>>>>::extend(self, values);
        self.len()
    }
}

impl<'a> MaterializeValues<Option<&'a [u8]>> for MutablePlBinary {
fn extend_buf<I: Iterator<Item = Option<&'a [u8]>>>(&mut self, values: I) -> usize {
self.extend(values);
Expand Down Expand Up @@ -91,8 +102,8 @@ where
}
}

fn copied_opt<T: Copy>(v: Option<&T>) -> Option<T> {
v.copied()
fn copied_wrapper_opt<T: Copy>(v: Option<&T>) -> Option<TotalOrdWrap<T>> {
v.copied().map(TotalOrdWrap)
}

#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
Expand Down Expand Up @@ -125,13 +136,13 @@ fn primitive<T>(
validity: Option<Bitmap>,
) -> PolarsResult<ListArray<i64>>
where
T: NativeType + Hash + Copy + Eq,
T: NativeType + TotalHash + Copy + TotalEq,
{
let broadcast_lhs = offsets_a.len() == 2;
let broadcast_rhs = offsets_b.len() == 2;

let mut set = Default::default();
let mut set2: PlIndexSet<Option<T>> = Default::default();
let mut set2: PlIndexSet<Option<TotalOrdWrap<T>>> = Default::default();

let mut values_out = MutablePrimitiveArray::with_capacity(std::cmp::max(
*offsets_a.last().unwrap(),
Expand All @@ -141,7 +152,7 @@ where
offsets.push(0i64);

if broadcast_rhs {
set2.extend(b.into_iter().map(copied_opt));
set2.extend(b.into_iter().map(copied_wrapper_opt));
}
let offsets_slice = if offsets_a.len() > offsets_b.len() {
offsets_a
Expand All @@ -168,8 +179,8 @@ where
.into_iter()
.skip(start_a)
.take(end_a - start_a)
.map(copied_opt);
let b_iter = b.into_iter().map(copied_opt);
.map(copied_wrapper_opt);
let b_iter = b.into_iter().map(copied_wrapper_opt);
set_operation(
&mut set,
&mut set2,
Expand All @@ -180,13 +191,13 @@ where
true,
)
} else if broadcast_lhs {
let a_iter = a.into_iter().map(copied_opt);
let a_iter = a.into_iter().map(copied_wrapper_opt);

let b_iter = b
.into_iter()
.skip(start_b)
.take(end_b - start_b)
.map(copied_opt);
.map(copied_wrapper_opt);

set_operation(
&mut set,
Expand All @@ -203,13 +214,13 @@ where
.into_iter()
.skip(start_a)
.take(end_a - start_a)
.map(copied_opt);
.map(copied_wrapper_opt);

let b_iter = b
.into_iter()
.skip(start_b)
.take(end_b - start_b)
.map(copied_opt);
.map(copied_wrapper_opt);
set_operation(
&mut set,
&mut set2,
Expand Down Expand Up @@ -366,7 +377,7 @@ fn array_set_operation(
polars_bail!(InvalidOperation: "boolean type not yet supported in list 'set' operations")
},
_ => {
with_match_physical_integer_type!(dtype.into(), |$T| {
with_match_physical_numeric_type!(dtype.into(), |$T| {
let a = values_a.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();
let b = values_b.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();

Expand Down
8 changes: 8 additions & 0 deletions crates/polars-utils/src/total_ord.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,14 @@ impl<T: TotalHash> Hash for TotalOrdWrap<T> {
}
}

impl<T: Clone> Clone for TotalOrdWrap<T> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}

// Marker impl with no methods: the wrapper is `Copy` whenever the wrapped `T` is `Copy`.
impl<T: Copy> Copy for TotalOrdWrap<T> {}

macro_rules! impl_trivial_total {
($T: ty) => {
impl TotalEq for $T {
Expand Down
28 changes: 28 additions & 0 deletions py-polars/tests/unit/namespaces/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,34 @@ def test_list_set_oob() -> None:
) == {"a": [[], []]}


def test_list_set_operations_float() -> None:
    """Check list set union, intersection and difference on ``List(Float32)`` columns."""
    f32_list = pl.List(pl.Float32)
    df = pl.DataFrame(
        {"a": [[1, 2, 3], [1, 1, 1], [4]], "b": [[4, 2, 1], [2, 1, 12], [4]]},
        schema={"a": f32_list, "b": f32_list},
    )

    # Union of "a" with "b", row by row.
    union_ab = df.select(pl.col("a").list.set_union("b"))["a"].to_list()
    assert union_ab == [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 12.0], [4.0]]

    # Values present in both lists of a row.
    intersection_ab = df.select(pl.col("a").list.set_intersection("b"))["a"].to_list()
    assert intersection_ab == [[1.0, 2.0], [1.0], [4.0]]

    # Values of "a" that are absent from "b".
    difference_ab = df.select(pl.col("a").list.set_difference("b"))["a"].to_list()
    assert difference_ab == [[3.0], [], []]

    # Values of "b" that are absent from "a".
    difference_ba = df.select(pl.col("b").list.set_difference("a"))["b"].to_list()
    assert difference_ba == [[4.0], [2.0, 12.0], []]


def test_list_set_operations() -> None:
df = pl.DataFrame(
{"a": [[1, 2, 3], [1, 1, 1], [4]], "b": [[4, 2, 1], [2, 1, 12], [4]]}
Expand Down

0 comments on commit 1bdac4e

Please sign in to comment.