Skip to content

Commit

Permalink
refactor: Add reduce ComputeNode in new streaming engine (#17389)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jul 17, 2024
1 parent 28d8196 commit cffa970
Show file tree
Hide file tree
Showing 23 changed files with 763 additions and 239 deletions.
38 changes: 2 additions & 36 deletions crates/polars-core/src/datatypes/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,40 +9,6 @@ use polars_utils::sync::SyncPtr;
use polars_utils::total_ord::ToTotalOrd;
use polars_utils::unwrap::UnwrapUncheckedRelease;

#[derive(Clone)]
pub struct Scalar {
dtype: DataType,
value: AnyValue<'static>,
}

impl Scalar {
pub fn new(dtype: DataType, value: AnyValue<'static>) -> Self {
Self { dtype, value }
}

pub fn value(&self) -> &AnyValue<'static> {
&self.value
}

pub fn as_any_value(&self) -> AnyValue {
self.value
.strict_cast(&self.dtype)
.unwrap_or_else(|| self.value.clone())
}

pub fn into_series(self, name: &str) -> Series {
Series::from_any_values_and_dtype(name, &[self.as_any_value()], &self.dtype, true).unwrap()
}

pub fn dtype(&self) -> &DataType {
&self.dtype
}

pub fn update(&mut self, value: AnyValue<'static>) {
self.value = value;
}
}

use super::*;
#[cfg(feature = "dtype-struct")]
use crate::prelude::any_value::arr_to_any_value;
Expand Down Expand Up @@ -854,8 +820,8 @@ impl<'a> AnyValue<'a> {
pub fn add(&self, rhs: &AnyValue) -> AnyValue<'static> {
use AnyValue::*;
match (self, rhs) {
(Null, _) => Null,
(_, Null) => Null,
(Null, r) => r.clone().into_static().unwrap(),
(l, Null) => l.clone().into_static().unwrap(),
(Int32(l), Int32(r)) => Int32(l + r),
(Int64(l), Int64(r)) => Int64(l + r),
(UInt32(l), UInt32(r)) => UInt32(l + r),
Expand Down
1 change: 1 addition & 0 deletions crates/polars-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ mod named_from;
pub mod prelude;
#[cfg(feature = "random")]
pub mod random;
pub mod scalar;
pub mod schema;
#[cfg(feature = "serde")]
pub mod serde;
Expand Down
1 change: 1 addition & 0 deletions crates/polars-core/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ pub use crate::frame::group_by::*;
pub use crate::frame::{DataFrame, UniqueKeepStrategy};
pub use crate::hashing::VecHash;
pub use crate::named_from::{NamedFrom, NamedFromOwned};
pub use crate::scalar::Scalar;
pub use crate::schema::*;
#[cfg(feature = "checked_arithmetic")]
pub use crate::series::arithmetic::checked::NumOpsDispatchChecked;
Expand Down
38 changes: 38 additions & 0 deletions crates/polars-core/src/scalar/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
pub mod reduce;

use crate::datatypes::{AnyValue, DataType};
use crate::prelude::Series;

#[derive(Clone)]
pub struct Scalar {
dtype: DataType,
value: AnyValue<'static>,
}

impl Scalar {
pub fn new(dtype: DataType, value: AnyValue<'static>) -> Self {
Self { dtype, value }
}

pub fn value(&self) -> &AnyValue<'static> {
&self.value
}

pub fn as_any_value(&self) -> AnyValue {
self.value
.strict_cast(&self.dtype)
.unwrap_or_else(|| self.value.clone())
}

pub fn into_series(self, name: &str) -> Series {
Series::from_any_values_and_dtype(name, &[self.as_any_value()], &self.dtype, true).unwrap()
}

pub fn dtype(&self) -> &DataType {
&self.dtype
}

pub fn update(&mut self, value: AnyValue<'static>) {
self.value = value;
}
}
37 changes: 37 additions & 0 deletions crates/polars-core/src/scalar/reduce.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use crate::datatypes::{AnyValue, TimeUnit};
#[cfg(feature = "dtype-date")]
use crate::prelude::MS_IN_DAY;
use crate::prelude::{DataType, Scalar};

pub fn mean_reduce(value: Option<f64>, dtype: DataType) -> Scalar {
match dtype {
DataType::Float32 => {
let val = value.map(|m| m as f32);
Scalar::new(dtype, val.into())
},
dt if dt.is_numeric() || dt.is_decimal() || dt.is_bool() => {
Scalar::new(DataType::Float64, value.into())
},
#[cfg(feature = "dtype-date")]
DataType::Date => {
let val = value.map(|v| (v * MS_IN_DAY as f64) as i64);
Scalar::new(DataType::Datetime(TimeUnit::Milliseconds, None), val.into())
},
#[cfg(feature = "dtype-datetime")]
dt @ DataType::Datetime(_, _) => {
let val = value.map(|v| v as i64);
Scalar::new(dt, val.into())
},
#[cfg(feature = "dtype-duration")]
dt @ DataType::Duration(_) => {
let val = value.map(|v| v as i64);
Scalar::new(dt, val.into())
},
#[cfg(feature = "dtype-time")]
dt @ DataType::Time => {
let val = value.map(|v| v as i64);
Scalar::new(dt, val.into())
},
dt => Scalar::new(dt, AnyValue::Null),
}
}
36 changes: 1 addition & 35 deletions crates/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -808,41 +808,7 @@ impl Series {
}

pub fn mean_reduce(&self) -> Scalar {
match self.dtype() {
DataType::Float32 => {
let val = self.mean().map(|m| m as f32);
Scalar::new(self.dtype().clone(), val.into())
},
dt if dt.is_numeric() || dt.is_decimal() || dt.is_bool() => {
let val = self.mean();
Scalar::new(DataType::Float64, val.into())
},
#[cfg(feature = "dtype-date")]
DataType::Date => {
let val = self.mean().map(|v| (v * MS_IN_DAY as f64) as i64);
let av: AnyValue = val.into();
Scalar::new(DataType::Datetime(TimeUnit::Milliseconds, None), av)
},
#[cfg(feature = "dtype-datetime")]
dt @ DataType::Datetime(_, _) => {
let val = self.mean().map(|v| v as i64);
let av: AnyValue = val.into();
Scalar::new(dt.clone(), av)
},
#[cfg(feature = "dtype-duration")]
dt @ DataType::Duration(_) => {
let val = self.mean().map(|v| v as i64);
let av: AnyValue = val.into();
Scalar::new(dt.clone(), av)
},
#[cfg(feature = "dtype-time")]
dt @ DataType::Time => {
let val = self.mean().map(|v| v as i64);
let av: AnyValue = val.into();
Scalar::new(dt.clone(), av)
},
dt => Scalar::new(dt.clone(), AnyValue::Null),
}
crate::scalar::reduce::mean_reduce(self.mean(), self.dtype().clone())
}

/// Compute the unique elements, but maintain order. This requires more work
Expand Down
1 change: 1 addition & 0 deletions crates/polars-expr/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod expressions;
pub mod planner;
pub mod prelude;
pub mod reduce;
pub mod state;

pub use crate::planner::{create_physical_expr, ExpressionConversionState};
82 changes: 82 additions & 0 deletions crates/polars-expr/src/reduce/convert.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
use polars_core::error::feature_gated;
use polars_plan::prelude::*;
use polars_utils::arena::{Arena, Node};

use super::extrema::*;
use super::sum::SumReduce;
use super::*;
use crate::reduce::mean::MeanReduce;

pub fn can_convert_into_reduction(node: Node, expr_arena: &Arena<AExpr>) -> bool {
match expr_arena.get(node) {
AExpr::Agg(agg) => matches!(
agg,
IRAggExpr::Min { .. }
| IRAggExpr::Max { .. }
| IRAggExpr::Mean { .. }
| IRAggExpr::Sum(_)
),
_ => false,
}
}

pub fn into_reduction(
node: Node,
expr_arena: &Arena<AExpr>,
schema: &Schema,
) -> PolarsResult<Option<(Box<dyn Reduction>, Node)>> {
let e = expr_arena.get(node);
let field = e.to_field(schema, Context::Default, expr_arena)?;
let out = match expr_arena.get(node) {
AExpr::Agg(agg) => match agg {
IRAggExpr::Sum(node) => (
Box::new(SumReduce::new(field.dtype.clone())) as Box<dyn Reduction>,
*node,
),
IRAggExpr::Min {
propagate_nans,
input,
} => {
if *propagate_nans && field.dtype.is_float() {
feature_gated!("propagate_nans", {
let out: Box<dyn Reduction> = match field.dtype {
DataType::Float32 => Box::new(MinNanReduce::<Float32Type>::new()),
DataType::Float64 => Box::new(MinNanReduce::<Float64Type>::new()),
_ => unreachable!(),
};
(out, *input)
})
} else {
(
Box::new(MinReduce::new(field.dtype.clone())) as Box<dyn Reduction>,
*input,
)
}
},
IRAggExpr::Max {
propagate_nans,
input,
} => {
if *propagate_nans && field.dtype.is_float() {
feature_gated!("propagate_nans", {
let out: Box<dyn Reduction> = match field.dtype {
DataType::Float32 => Box::new(MaxNanReduce::<Float32Type>::new()),
DataType::Float64 => Box::new(MaxNanReduce::<Float64Type>::new()),
_ => unreachable!(),
};
(out, *input)
})
} else {
(Box::new(MaxReduce::new(field.dtype.clone())) as _, *input)
}
},
IRAggExpr::Mean(input) => {
let out: Box<dyn Reduction> = Box::new(MeanReduce::new(field.dtype.clone()));
(out, *input)
},
_ => return Ok(None),
},
_ => return Ok(None),
};
Ok(Some(out))
}
Loading

0 comments on commit cffa970

Please sign in to comment.