diff --git a/crates/polars-core/src/chunked_array/comparison/scalar.rs b/crates/polars-core/src/chunked_array/comparison/scalar.rs index 4649696b41dd..458194e1dc3c 100644 --- a/crates/polars-core/src/chunked_array/comparison/scalar.rs +++ b/crates/polars-core/src/chunked_array/comparison/scalar.rs @@ -1,5 +1,13 @@ use super::*; +#[derive(Clone, Copy)] +enum CmpOp { + Lt, + Le, + Gt, + Ge, +} + // Given two monotonic functions f_a and f_d where f_a is ascending // (f_a(x[0]) <= f_a(x[1]) <= .. <= f_a(x[n-1])) and f_d is descending // (f_d(x[0]) >= f_d(x[1]) >= .. >= f_d(x[n-1])), @@ -7,16 +15,21 @@ use super::*; // // If a function is not given it is always assumed to be true. If invert is // true the output mask is inverted. -fn bitonic_mask( +fn bitonic_mask( ca: &ChunkedArray, - f_a: Option, - f_d: Option, + f_a: Option, + f_d: Option, + rhs: &T::Native, invert: bool, -) -> BooleanChunked -where - FA: Fn(T::Native) -> bool, - FD: Fn(T::Native) -> bool, -{ +) -> BooleanChunked { + fn apply(op: CmpOp, x: T::Native, rhs: &T::Native) -> bool { + match op { + CmpOp::Lt => x.tot_lt(rhs), + CmpOp::Le => x.tot_le(rhs), + CmpOp::Gt => x.tot_gt(rhs), + CmpOp::Ge => x.tot_ge(rhs), + } + } let mut output_order: Option = None; let mut last_value: Option = None; let mut logical_extend = |len: usize, val: bool| { @@ -36,13 +49,14 @@ where let chunks = ca.downcast_iter().map(|arr| { let values = arr.values(); - let true_range_start = if let Some(f_a) = f_a.as_ref() { - values.partition_point(|x| !f_a(*x)) + let true_range_start = if let Some(f_a) = f_a { + values.partition_point(|x| !apply::(f_a, *x, rhs)) } else { 0 }; - let true_range_end = if let Some(f_d) = f_d.as_ref() { - true_range_start + values[true_range_start..].partition_point(|x| f_d(*x)) + let true_range_end = if let Some(f_d) = f_d { + true_range_start + + values[true_range_start..].partition_point(|x| apply::(f_d, *x, rhs)) } else { values.len() }; @@ -71,11 +85,11 @@ where fn equal(&self, rhs: Rhs) -> BooleanChunked { let rhs: T::Native = NumCast::from(rhs).unwrap(); - let fa = Some(|x: T::Native| x.tot_ge(&rhs)); - let fd = Some(|x: T::Native| x.tot_le(&rhs)); + let fa = Some(CmpOp::Ge); + let fd = Some(CmpOp::Le); match (self.is_sorted_flag(), self.null_count()) { - (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, false), - (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, false), + (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false), + (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false), _ => arity::unary_mut_values(self, |arr| arr.tot_eq_kernel_broadcast(&rhs).into()), } } @@ -93,11 +107,11 @@ where fn not_equal(&self, rhs: Rhs) -> BooleanChunked { let rhs: T::Native = NumCast::from(rhs).unwrap(); - let fa = Some(|x: T::Native| x.tot_ge(&rhs)); - let fd = Some(|x: T::Native| x.tot_le(&rhs)); + let fa = Some(CmpOp::Ge); + let fd = Some(CmpOp::Le); match (self.is_sorted_flag(), self.null_count()) { - (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, true), - (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, true), + (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, true), + (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, true), _ => arity::unary_mut_values(self, |arr| arr.tot_ne_kernel_broadcast(&rhs).into()), } } @@ -124,44 +138,44 @@ where fn gt(&self, rhs: Rhs) -> BooleanChunked { let rhs: T::Native = NumCast::from(rhs).unwrap(); - let fa = Some(|x: T::Native| x.tot_gt(&rhs)); - let fd: Option _> = None; + let fa = Some(CmpOp::Gt); + let fd = None; match (self.is_sorted_flag(), self.null_count()) { - (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, false), - (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, false), + (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false), + (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false), _ => arity::unary_mut_values(self, |arr| arr.tot_gt_kernel_broadcast(&rhs).into()), } } fn gt_eq(&self, rhs: Rhs) -> BooleanChunked { let rhs: T::Native = NumCast::from(rhs).unwrap(); - let fa = Some(|x: T::Native| x.tot_ge(&rhs)); - let fd: Option _> = None; + let fa = Some(CmpOp::Ge); + let fd = None; match (self.is_sorted_flag(), self.null_count()) { - (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, false), - (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, false), + (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false), + (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false), _ => arity::unary_mut_values(self, |arr| arr.tot_ge_kernel_broadcast(&rhs).into()), } } fn lt(&self, rhs: Rhs) -> BooleanChunked { let rhs: T::Native = NumCast::from(rhs).unwrap(); - let fa: Option _> = None; - let fd = Some(|x: T::Native| x.tot_lt(&rhs)); + let fa = None; + let fd = Some(CmpOp::Lt); match (self.is_sorted_flag(), self.null_count()) { - (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, false), - (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, false), + (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false), + (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false), _ => arity::unary_mut_values(self, |arr| arr.tot_lt_kernel_broadcast(&rhs).into()), } } fn lt_eq(&self, rhs: Rhs) -> BooleanChunked { let rhs: T::Native = NumCast::from(rhs).unwrap(); - let fa: Option _> = None; - let fd = Some(|x: T::Native| x.tot_le(&rhs)); + let fa = None; + let fd = Some(CmpOp::Le); match (self.is_sorted_flag(), self.null_count()) { - (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, false), - (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, false), + (IsSorted::Ascending, 0) => bitonic_mask(self, fa, fd, &rhs, false), + (IsSorted::Descending, 0) => bitonic_mask(self, fd, fa, &rhs, false), _ => arity::unary_mut_values(self, |arr| arr.tot_le_kernel_broadcast(&rhs).into()), } }