Skip to content

Commit

Permalink
feat: clip supports expr arguments and physical numeric dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa committed Sep 25, 2023
1 parent 480a823 commit c946e8f
Show file tree
Hide file tree
Showing 11 changed files with 350 additions and 118 deletions.
51 changes: 51 additions & 0 deletions crates/polars-core/src/chunked_array/ops/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@ use crate::datatypes::{ArrayCollectIterExt, ArrayFromIter, StaticArray};
use crate::prelude::{ChunkedArray, PolarsDataType};
use crate::utils::{align_chunks_binary, align_chunks_ternary};

// We need this helper because for<'a> notation can't yet be applied properly
// on the return type.
pub trait TernaryFnMut<A1, A2, A3>: FnMut(A1, A2, A3) -> Self::Ret {
type Ret;
}

impl<A1, A2, A3, R, T: FnMut(A1, A2, A3) -> R> TernaryFnMut<A1, A2, A3> for T {
type Ret = R;
}

// We need this helper because for<'a> notation can't yet be applied properly
// on the return type.
pub trait BinaryFnMut<A1, A2>: FnMut(A1, A2) -> Self::Ret {
Expand Down Expand Up @@ -334,3 +344,44 @@ where
});
ChunkedArray::try_from_chunk_iter(ca1.name(), iter)
}

#[inline]
pub fn ternary_elementwise<T, U, V, G, F>(
ca1: &ChunkedArray<T>,
ca2: &ChunkedArray<U>,
ca3: &ChunkedArray<G>,
mut op: F,
) -> ChunkedArray<V>
where
T: PolarsDataType,
U: PolarsDataType,
G: PolarsDataType,
V: PolarsDataType,
F: for<'a> TernaryFnMut<
Option<T::Physical<'a>>,
Option<U::Physical<'a>>,
Option<G::Physical<'a>>,
>,
V::Array: for<'a> ArrayFromIter<
<F as TernaryFnMut<
Option<T::Physical<'a>>,
Option<U::Physical<'a>>,
Option<G::Physical<'a>>,
>>::Ret,
>,
{
let (ca1, ca2, ca3) = align_chunks_ternary(ca1, ca2, ca3);
let iter = ca1
.downcast_iter()
.zip(ca2.downcast_iter())
.zip(ca3.downcast_iter())
.map(|((ca1_arr, ca2_arr), ca3_arr)| {
let element_iter = ca1_arr.iter().zip(ca2_arr.iter()).zip(ca3_arr.iter()).map(
|((ca1_opt_val, ca2_opt_val), ca3_opt_val)| {
op(ca1_opt_val, ca2_opt_val, ca3_opt_val)
},
);
element_iter.collect_arr()
});
ChunkedArray::from_chunk_iter(ca1.name(), iter)
}
64 changes: 0 additions & 64 deletions crates/polars-core/src/series/ops/round.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use num_traits::pow::Pow;
use num_traits::{clamp_max, clamp_min};

use crate::prelude::*;

Expand Down Expand Up @@ -60,67 +59,4 @@ impl Series {
}
polars_bail!(opq = ceil, self.dtype());
}

/// Clamp underlying values to the `min` and `max` values.
pub fn clip(mut self, min: AnyValue<'_>, max: AnyValue<'_>) -> PolarsResult<Self> {
if self.dtype().is_numeric() {
macro_rules! apply_clip {
($pl_type:ty, $ca:expr) => {{
let min = min
.extract::<<$pl_type as PolarsNumericType>::Native>()
.unwrap();
let max = max
.extract::<<$pl_type as PolarsNumericType>::Native>()
.unwrap();

$ca.apply_mut(|val| val.clamp(min, max));
}};
}
let mutable = self._get_inner_mut();
downcast_as_macro_arg_physical_mut!(mutable, apply_clip);
Ok(self)
} else {
polars_bail!(opq = clip, self.dtype());
}
}

/// Clamp underlying values to the `max` value.
pub fn clip_max(mut self, max: AnyValue<'_>) -> PolarsResult<Self> {
if self.dtype().is_numeric() {
macro_rules! apply_clip {
($pl_type:ty, $ca:expr) => {{
let max = max
.extract::<<$pl_type as PolarsNumericType>::Native>()
.unwrap();

$ca.apply_mut(|val| clamp_max(val, max));
}};
}
let mutable = self._get_inner_mut();
downcast_as_macro_arg_physical_mut!(mutable, apply_clip);
Ok(self)
} else {
polars_bail!(opq = clip_max, self.dtype());
}
}

/// Clamp underlying values to the `min` value.
pub fn clip_min(mut self, min: AnyValue<'_>) -> PolarsResult<Self> {
if self.dtype().is_numeric() {
macro_rules! apply_clip {
($pl_type:ty, $ca:expr) => {{
let min = min
.extract::<<$pl_type as PolarsNumericType>::Native>()
.unwrap();

$ca.apply_mut(|val| clamp_min(val, min));
}};
}
let mutable = self._get_inner_mut();
downcast_as_macro_arg_physical_mut!(mutable, apply_clip);
Ok(self)
} else {
polars_bail!(opq = clip_min, self.dtype());
}
}
}
151 changes: 151 additions & 0 deletions crates/polars-ops/src/series/ops/clip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
use num_traits::{clamp, clamp_max, clamp_min};
use polars_core::prelude::arity::{binary_elementwise, ternary_elementwise};
use polars_core::prelude::*;
use polars_core::with_match_physical_numeric_polars_type;

fn clip_helper<T>(
ca: &ChunkedArray<T>,
min: &ChunkedArray<T>,
max: &ChunkedArray<T>,
) -> ChunkedArray<T>
where
T: PolarsNumericType,
T::Native: PartialOrd,
{
match (min.len(), max.len()) {
(1, 1) => match (min.get(0), max.get(0)) {
(Some(min), Some(max)) => {
ca.apply_generic(|s| s.map(|s| num_traits::clamp(s, min, max)))
},
_ => ChunkedArray::<T>::full_null(ca.name(), ca.len()),
},
(1, _) => match min.get(0) {
Some(min) => binary_elementwise(ca, max, |opt_s, opt_max| match (opt_s, opt_max) {
(Some(s), Some(max)) => Some(clamp(s, min, max)),
_ => None,
}),
_ => ChunkedArray::<T>::full_null(ca.name(), ca.len()),
},
(_, 1) => match max.get(0) {
Some(max) => binary_elementwise(ca, min, |opt_s, opt_min| match (opt_s, opt_min) {
(Some(s), Some(min)) => Some(clamp(s, min, max)),
_ => None,
}),
_ => ChunkedArray::<T>::full_null(ca.name(), ca.len()),
},
_ => ternary_elementwise(ca, min, max, |opt_s, opt_min, opt_max| {
match (opt_s, opt_min, opt_max) {
(Some(s), Some(min), Some(max)) => Some(clamp(s, min, max)),
_ => None,
}
}),
}
}

fn clip_min_max_helper<T, F>(
ca: &ChunkedArray<T>,
bound: &ChunkedArray<T>,
op: F,
) -> ChunkedArray<T>
where
T: PolarsNumericType,
T::Native: PartialOrd,
F: Fn(T::Native, T::Native) -> T::Native,
{
match bound.len() {
1 => match bound.get(0) {
Some(bound) => ca.apply_generic(|s| s.map(|s| op(s, bound))),
_ => ChunkedArray::<T>::full_null(ca.name(), ca.len()),
},
_ => binary_elementwise(ca, bound, |opt_s, opt_bound| match (opt_s, opt_bound) {
(Some(s), Some(bound)) => Some(op(s, bound)),
_ => None,
}),
}
}

/// Clamp underlying values to the `min` and `max` values.
pub fn clip(s: &Series, min: &Series, max: &Series) -> PolarsResult<Series> {
polars_ensure!(s.dtype().to_physical().is_numeric(), InvalidOperation: "Only physical numeric types are supported.");

let original_type = s.dtype();
// cast min & max to the dtype of s first.
let (min, max) = (min.cast(s.dtype())?, max.cast(s.dtype())?);

let (s, min, max) = (
s.to_physical_repr(),
min.to_physical_repr(),
max.to_physical_repr(),
);

match s.dtype() {
dt if dt.is_numeric() => {
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
let min: &ChunkedArray<$T> = min.as_ref().as_ref().as_ref();
let max: &ChunkedArray<$T> = max.as_ref().as_ref().as_ref();
let out = clip_helper(ca, min, max).into_series();
if original_type.is_logical(){
out.cast(original_type)
}else{
Ok(out)
}
})
},
dt => polars_bail!(opq = clippy, dt),
}
}

/// Clamp underlying values to the `max` value.
pub fn clip_max(s: &Series, max: &Series) -> PolarsResult<Series> {
polars_ensure!(s.dtype().to_physical().is_numeric(), InvalidOperation: "Only physical numeric types are supported.");

let original_type = s.dtype();
// cast max to the dtype of s first.
let max = max.cast(s.dtype())?;

let (s, max) = (s.to_physical_repr(), max.to_physical_repr());

match s.dtype() {
dt if dt.is_numeric() => {
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
let max: &ChunkedArray<$T> = max.as_ref().as_ref().as_ref();
let out = clip_min_max_helper(ca, max, clamp_max).into_series();
if original_type.is_logical(){
out.cast(original_type)
}else{
Ok(out)
}
})
},
dt => polars_bail!(opq = clippy_max, dt),
}
}

/// Clamp underlying values to the `min` value.
pub fn clip_min(s: &Series, min: &Series) -> PolarsResult<Series> {
polars_ensure!(s.dtype().to_physical().is_numeric(), InvalidOperation: "Only physical numeric types are supported.");

let original_type = s.dtype();
// cast min to the dtype of s first.
let min = min.cast(s.dtype())?;

let (s, min) = (s.to_physical_repr(), min.to_physical_repr());

match s.dtype() {
dt if dt.is_numeric() => {
with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
let min: &ChunkedArray<$T> = min.as_ref().as_ref().as_ref();
let out = clip_min_max_helper(ca, min, clamp_min).into_series();
if original_type.is_logical(){
out.cast(original_type)
}else{
Ok(out)
}
})
},
dt => polars_bail!(opq = clippy_min, dt),
}
}
2 changes: 2 additions & 0 deletions crates/polars-ops/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod approx_algo;
#[cfg(feature = "approx_unique")]
mod approx_unique;
mod arg_min_max;
mod clip;
#[cfg(feature = "cutqcut")]
mod cut;
#[cfg(feature = "round_series")]
Expand Down Expand Up @@ -34,6 +35,7 @@ pub use approx_algo::*;
#[cfg(feature = "approx_unique")]
pub use approx_unique::*;
pub use arg_min_max::ArgAgg;
pub use clip::*;
#[cfg(feature = "cutqcut")]
pub use cut::*;
#[cfg(feature = "round_series")]
Expand Down
14 changes: 5 additions & 9 deletions crates/polars-plan/src/dsl/function_expr/clip.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
use super::*;

pub(super) fn clip(
s: Series,
min: Option<AnyValue<'_>>,
max: Option<AnyValue<'_>>,
) -> PolarsResult<Series> {
match (min, max) {
(Some(min), Some(max)) => s.clip(min, max),
(Some(min), None) => s.clip_min(min),
(None, Some(max)) => s.clip_max(max),
pub(super) fn clip(s: &[Series], has_min: bool, has_max: bool) -> PolarsResult<Series> {
match (has_min, has_max) {
(true, true) => polars_ops::prelude::clip(&s[0], &s[1], &s[2]),
(true, false) => polars_ops::prelude::clip_min(&s[0], &s[1]),
(false, true) => polars_ops::prelude::clip_max(&s[0], &s[1]),
_ => unreachable!(),
}
}
16 changes: 8 additions & 8 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ pub enum FunctionExpr {
DropNans,
#[cfg(feature = "round_series")]
Clip {
min: Option<AnyValue<'static>>,
max: Option<AnyValue<'static>>,
has_min: bool,
has_max: bool,
},
ListExpr(ListFunction),
#[cfg(feature = "dtype-array")]
Expand Down Expand Up @@ -321,10 +321,10 @@ impl Display for FunctionExpr {
ShiftAndFill { .. } => "shift_and_fill",
DropNans => "drop_nans",
#[cfg(feature = "round_series")]
Clip { min, max } => match (min, max) {
(Some(_), Some(_)) => "clip",
(None, Some(_)) => "clip_max",
(Some(_), None) => "clip_min",
Clip { has_min, has_max } => match (has_min, has_max) {
(true, true) => "clip",
(false, true) => "clip_max",
(true, false) => "clip_min",
_ => unreachable!(),
},
ListExpr(func) => return write!(f, "{func}"),
Expand Down Expand Up @@ -543,8 +543,8 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
},
DropNans => map_owned!(nan::drop_nans),
#[cfg(feature = "round_series")]
Clip { min, max } => {
map_owned!(clip::clip, min.clone(), max.clone())
Clip { has_min, has_max } => {
map_as_slice!(clip::clip, has_min, has_max)
},
ListExpr(lf) => {
use ListFunction::*;
Expand Down
Loading

0 comments on commit c946e8f

Please sign in to comment.