Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(rust): Use defunctionalization in polars-core scalar.rs in order to reduce code duplication #20377

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 50 additions & 36 deletions crates/polars-core/src/chunked_array/comparison/scalar.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,35 @@
use super::*;

// Comparison operator passed around as plain data. The `apply` helper nested
// in `bitonic_mask` maps each variant to the matching total-order comparison
// of a value `x` against the scalar right-hand side `rhs`.
#[derive(Clone, Copy)]
enum CmpOp {
// Less than: evaluated as `x.tot_lt(rhs)`.
Lt,
// Less than or equal: evaluated as `x.tot_le(rhs)`.
Le,
// Greater than: evaluated as `x.tot_gt(rhs)`.
Gt,
// Greater than or equal: evaluated as `x.tot_ge(rhs)`.
Ge,
}

// Given two monotonic functions f_a and f_d, where f_a is ascending
// (f_a(x[0]) <= f_a(x[1]) <= .. <= f_a(x[n-1])) and f_d is descending
// (f_d(x[0]) >= f_d(x[1]) >= .. >= f_d(x[n-1])),
// outputs a mask that is true where both f_a and f_d are true.
//
// If a function is not given, it is treated as always true. If invert is
// true, the output mask is inverted.
fn bitonic_mask<T: PolarsNumericType, FA, FD>(
fn bitonic_mask<T: PolarsNumericType>(
ca: &ChunkedArray<T>,
f_a: Option<FA>,
f_d: Option<FD>,
f_a: Option<CmpOp>,
f_d: Option<CmpOp>,
rhs: &T::Native,
invert: bool,
) -> BooleanChunked
where
FA: Fn(T::Native) -> bool,
FD: Fn(T::Native) -> bool,
{
) -> BooleanChunked {
// Evaluates the comparison selected by `op` with `x` on the left-hand side
// and `rhs` on the right-hand side, returning the boolean result.
// The match has no catch-all arm, so every `CmpOp` variant is handled
// explicitly.
fn apply<T: PolarsNumericType>(op: CmpOp, x: T::Native, rhs: &T::Native) -> bool {
match op {
CmpOp::Lt => x.tot_lt(rhs),
CmpOp::Le => x.tot_le(rhs),
CmpOp::Gt => x.tot_gt(rhs),
CmpOp::Ge => x.tot_ge(rhs),
}
}
let mut output_order: Option<IsSorted> = None;
let mut last_value: Option<bool> = None;
let mut logical_extend = |len: usize, val: bool| {
Expand All @@ -36,13 +49,14 @@ where

let chunks = ca.downcast_iter().map(|arr| {
let values = arr.values();
let true_range_start = if let Some(f_a) = f_a.as_ref() {
values.partition_point(|x| !f_a(*x))
let true_range_start = if let Some(f_a) = f_a {
values.partition_point(|x| !apply::<T>(f_a, *x, rhs))
} else {
0
};
let true_range_end = if let Some(f_d) = f_d.as_ref() {
true_range_start + values[true_range_start..].partition_point(|x| f_d(*x))
let true_range_end = if let Some(f_d) = f_d {
true_range_start
+ values[true_range_start..].partition_point(|x| apply::<T>(f_d, *x, rhs))
} else {
values.len()
};
Expand Down Expand Up @@ -71,11 +85,11 @@ where

fn equal(&self, rhs: Rhs) -> BooleanChunked {
let rhs: T::Native = NumCast::from(rhs).unwrap();
let fa = Some(|x: T::Native| x.tot_ge(&rhs));
let fd = Some(|x: T::Native| x.tot_le(&rhs));
let fa = Some(CmpOp::Ge);
let fd = Some(CmpOp::Le);
match (self.is_sorted_flag(), self.null_count()) {
(IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, false),
(IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, false),
(IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false),
(IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false),
_ => arity::unary_mut_values(self, |arr| arr.tot_eq_kernel_broadcast(&rhs).into()),
}
}
Expand All @@ -93,11 +107,11 @@ where

fn not_equal(&self, rhs: Rhs) -> BooleanChunked {
let rhs: T::Native = NumCast::from(rhs).unwrap();
let fa = Some(|x: T::Native| x.tot_ge(&rhs));
let fd = Some(|x: T::Native| x.tot_le(&rhs));
let fa = Some(CmpOp::Ge);
let fd = Some(CmpOp::Le);
match (self.is_sorted_flag(), self.null_count()) {
(IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, true),
(IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, true),
(IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, true),
(IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, true),
_ => arity::unary_mut_values(self, |arr| arr.tot_ne_kernel_broadcast(&rhs).into()),
}
}
Expand All @@ -124,44 +138,44 @@ where

fn gt(&self, rhs: Rhs) -> BooleanChunked {
let rhs: T::Native = NumCast::from(rhs).unwrap();
let fa = Some(|x: T::Native| x.tot_gt(&rhs));
let fd: Option<fn(_) -> _> = None;
let fa = Some(CmpOp::Gt);
let fd = None;
match (self.is_sorted_flag(), self.null_count()) {
(IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, false),
(IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, false),
(IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false),
(IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false),
_ => arity::unary_mut_values(self, |arr| arr.tot_gt_kernel_broadcast(&rhs).into()),
}
}

fn gt_eq(&self, rhs: Rhs) -> BooleanChunked {
let rhs: T::Native = NumCast::from(rhs).unwrap();
let fa = Some(|x: T::Native| x.tot_ge(&rhs));
let fd: Option<fn(_) -> _> = None;
let fa = Some(CmpOp::Ge);
let fd = None;
match (self.is_sorted_flag(), self.null_count()) {
(IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, false),
(IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, false),
(IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false),
(IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false),
_ => arity::unary_mut_values(self, |arr| arr.tot_ge_kernel_broadcast(&rhs).into()),
}
}

fn lt(&self, rhs: Rhs) -> BooleanChunked {
let rhs: T::Native = NumCast::from(rhs).unwrap();
let fa: Option<fn(_) -> _> = None;
let fd = Some(|x: T::Native| x.tot_lt(&rhs));
let fa = None;
let fd = Some(CmpOp::Lt);
match (self.is_sorted_flag(), self.null_count()) {
(IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, false),
(IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, false),
(IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false),
(IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false),
_ => arity::unary_mut_values(self, |arr| arr.tot_lt_kernel_broadcast(&rhs).into()),
}
}

fn lt_eq(&self, rhs: Rhs) -> BooleanChunked {
let rhs: T::Native = NumCast::from(rhs).unwrap();
let fa: Option<fn(_) -> _> = None;
let fd = Some(|x: T::Native| x.tot_le(&rhs));
let fa = None;
let fd = Some(CmpOp::Le);
match (self.is_sorted_flag(), self.null_count()) {
(IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, false),
(IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, false),
(IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false),
(IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false),
_ => arity::unary_mut_values(self, |arr| arr.tot_le_kernel_broadcast(&rhs).into()),
}
}
Expand Down
Loading