Skip to content

Commit

Permalink
fix(rust): implement proper hash for identifier in cse
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 23, 2023
1 parent ce3dd72 commit 394cfc5
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 5 deletions.
12 changes: 12 additions & 0 deletions crates/polars-arrow/src/legacy/kernels/ewm/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
mod average;
mod variance;

use std::hash::{Hash, Hasher};

pub use average::*;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -29,6 +31,16 @@ impl Default for EWMOptions {
}
}

impl Hash for EWMOptions {
fn hash<H: Hasher>(&self, state: &mut H) {
self.alpha.to_bits().hash(state);
self.adjust.hash(state);
self.bias.hash(state);
self.min_periods.hash(state);
self.ignore_nulls.hash(state);
}
}

impl EWMOptions {
pub fn and_min_periods(mut self, min_periods: usize) -> Self {
self.min_periods = min_periods;
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/series/series_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::chunked_array::object::PolarsObjectSafe;
pub use crate::prelude::ChunkCompare;
use crate::prelude::*;

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum IsSorted {
Ascending,
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-ops/src/series/ops/rank.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize};

use crate::prelude::SeriesSealed;

#[derive(Copy, Clone, Debug, PartialEq)]
#[derive(Copy, Clone, Debug, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum RankMethod {
Average,
Expand All @@ -24,7 +24,7 @@ pub enum RankMethod {
}

// We might want to add a `nulls_last` or `null_behavior` field.
#[derive(Copy, Clone, Debug, PartialEq)]
#[derive(Copy, Clone, Debug, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RankOptions {
pub method: RankMethod,
Expand Down
1 change: 1 addition & 0 deletions crates/polars-plan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ polars-utils = { workspace = true }

ahash = { workspace = true }
arrow = { workspace = true }
bytemuck = { workspace = true }
chrono = { workspace = true, optional = true }
chrono-tz = { workspace = true, optional = true }
ciborium = { workspace = true, optional = true }
Expand Down
145 changes: 143 additions & 2 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ pub enum FunctionExpr {
ShiftAndFill {
periods: i64,
},
Shift(i64),
DropNans,
DropNulls,
#[cfg(feature = "mode")]
Expand Down Expand Up @@ -165,7 +166,6 @@ pub enum FunctionExpr {
AsStruct,
#[cfg(feature = "top_k")]
TopK(bool),
Shift(i64),
#[cfg(feature = "cum_agg")]
Cumcount {
reverse: bool,
Expand Down Expand Up @@ -345,7 +345,148 @@ impl Hash for FunctionExpr {
lib.hash(state);
symbol.hash(state);
},
_ => {},
FunctionExpr::Shift(periods) | FunctionExpr::ShiftAndFill { periods } => {
periods.hash(state)
},
FunctionExpr::SumHorizontal
| FunctionExpr::MaxHorizontal
| FunctionExpr::MinHorizontal
| FunctionExpr::DropNans
| FunctionExpr::DropNulls
| FunctionExpr::Reverse
| FunctionExpr::ArgUnique => {},
#[cfg(feature = "mode")]
FunctionExpr::Mode => {},
#[cfg(feature = "abs")]
FunctionExpr::Abs => {},
FunctionExpr::NullCount => {},
#[cfg(feature = "date_offset")]
FunctionExpr::DateOffset => {},
#[cfg(feature = "arg_where")]
FunctionExpr::ArgWhere => {},
#[cfg(feature = "trigonometry")]
FunctionExpr::Atan2 => {},
#[cfg(feature = "dtype-struct")]
FunctionExpr::AsStruct => {},
#[cfg(feature = "sign")]
FunctionExpr::Sign => {},
FunctionExpr::Hash(a, b, c, d) => (a, b, c, d).hash(state),
FunctionExpr::FillNull { super_type } => super_type.hash(state),
#[cfg(all(feature = "rolling_window", feature = "moment"))]
FunctionExpr::RollingSkew { window_size, bias } => {
window_size.hash(state);
bias.hash(state);
},
#[cfg(feature = "moment")]
FunctionExpr::Skew(a) => a.hash(state),
#[cfg(feature = "moment")]
FunctionExpr::Kurtosis(a, b) => {
a.hash(state);
b.hash(state);
},
#[cfg(feature = "rank")]
FunctionExpr::Rank { options, seed } => {
options.hash(state);
seed.hash(state);
},
#[cfg(feature = "round_series")]
FunctionExpr::Clip { has_min, has_max } => {
has_min.hash(state);
has_max.hash(state);
},
#[cfg(feature = "top_k")]
FunctionExpr::TopK(a) => a.hash(state),
#[cfg(feature = "cum_agg")]
FunctionExpr::Cumcount { reverse } => reverse.hash(state),
#[cfg(feature = "cum_agg")]
FunctionExpr::Cumsum { reverse } => reverse.hash(state),
#[cfg(feature = "cum_agg")]
FunctionExpr::Cumprod { reverse } => reverse.hash(state),
#[cfg(feature = "cum_agg")]
FunctionExpr::Cummin { reverse } => reverse.hash(state),
#[cfg(feature = "cum_agg")]
FunctionExpr::Cummax { reverse } => reverse.hash(state),
#[cfg(feature = "dtype-struct")]
FunctionExpr::ValueCounts { sort, parallel } => {
sort.hash(state);
parallel.hash(state);
},
#[cfg(feature = "unique_counts")]
FunctionExpr::UniqueCounts => {},
#[cfg(feature = "approx_unique")]
FunctionExpr::ApproxNUnique => {},
FunctionExpr::Coalesce => {},
FunctionExpr::ShrinkType => {},
#[cfg(feature = "pct_change")]
FunctionExpr::PctChange => {},
#[cfg(feature = "log")]
FunctionExpr::Entropy { base, normalize } => {
base.to_bits().hash(state);
normalize.hash(state);
},
#[cfg(feature = "log")]
FunctionExpr::Log { base } => base.to_bits().hash(state),
#[cfg(feature = "log")]
FunctionExpr::Log1p => {},
#[cfg(feature = "log")]
FunctionExpr::Exp => {},
FunctionExpr::Unique(a) => a.hash(state),
#[cfg(feature = "round_series")]
FunctionExpr::Round { decimals } => decimals.hash(state),
#[cfg(feature = "round_series")]
FunctionExpr::Floor => {},
#[cfg(feature = "round_series")]
FunctionExpr::Ceil => {},
FunctionExpr::UpperBound => {},
FunctionExpr::LowerBound => {},
FunctionExpr::ConcatExpr(a) => a.hash(state),
#[cfg(feature = "peaks")]
FunctionExpr::PeakMin => {},
#[cfg(feature = "peaks")]
FunctionExpr::PeakMax => {},
#[cfg(feature = "cutqcut")]
FunctionExpr::Cut {
breaks,
labels,
left_closed,
include_breaks,
} => {
let slice = bytemuck::cast_slice::<_, u64>(breaks);
slice.hash(state);
labels.hash(state);
left_closed.hash(state);
include_breaks.hash(state);
},
#[cfg(feature = "cutqcut")]
FunctionExpr::QCut {
probs,
labels,
left_closed,
allow_duplicates,
include_breaks,
} => {
let slice = bytemuck::cast_slice::<_, u64>(probs);
slice.hash(state);
labels.hash(state);
left_closed.hash(state);
allow_duplicates.hash(state);
include_breaks.hash(state);
},
#[cfg(feature = "rle")]
FunctionExpr::RLE => {},
#[cfg(feature = "rle")]
FunctionExpr::RLEID => {},
FunctionExpr::ToPhysical => {},
FunctionExpr::SetSortedFlag(is_sorted) => is_sorted.hash(state),
FunctionExpr::BackwardFill { limit } | FunctionExpr::ForwardFill { limit } => {
limit.hash(state)
},
#[cfg(feature = "ewma")]
FunctionExpr::EwmMean { options } => options.hash(state),
#[cfg(feature = "ewma")]
FunctionExpr::EwmStd { options } => options.hash(state),
#[cfg(feature = "ewma")]
FunctionExpr::EwmVar { options } => options.hash(state),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions py-polars/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions py-polars/tests/unit/test_cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,3 +536,21 @@ def test_cse_is_in_11489() -> None:
"any_cond": [False, True, True, True, False],
"val": [0.0, 1.0, 1.0, 1.0, 0.0],
}


def test_cse_11958() -> None:
df = pl.LazyFrame({"a": [1, 2, 3, 4, 5]})
vector_losses = []
for lag in range(1, 5):
difference = pl.col("a") - pl.col("a").shift(lag)
component_loss = pl.when(difference >= 0).then(difference * 10)
vector_losses.append(component_loss.alias(f"diff{lag}"))

q = df.select(vector_losses)
assert "__POLARS_CSE" in q.explain(comm_subexpr_elim=True)
assert q.collect(comm_subexpr_elim=True).to_dict(False) == {
"diff1": [None, 10, 10, 10, 10],
"diff2": [None, None, 20, 20, 20],
"diff3": [None, None, None, 30, 30],
"diff4": [None, None, None, None, 40],
}

0 comments on commit 394cfc5

Please sign in to comment.