From 394cfc57da06bc51df8fb7114557ccc7f6651071 Mon Sep 17 00:00:00 2001 From: ritchie Date: Mon, 23 Oct 2023 20:32:34 +0200 Subject: [PATCH] fix(rust): implement proper hash for identifier in cse --- .../src/legacy/kernels/ewm/mod.rs | 12 ++ crates/polars-core/src/series/series_trait.rs | 2 +- crates/polars-ops/src/series/ops/rank.rs | 4 +- crates/polars-plan/Cargo.toml | 1 + .../polars-plan/src/dsl/function_expr/mod.rs | 145 +++++++++++++++++- py-polars/Cargo.lock | 1 + py-polars/tests/unit/test_cse.py | 18 +++ 7 files changed, 178 insertions(+), 5 deletions(-) diff --git a/crates/polars-arrow/src/legacy/kernels/ewm/mod.rs b/crates/polars-arrow/src/legacy/kernels/ewm/mod.rs index 5984106f1521..2eafab5bbed4 100644 --- a/crates/polars-arrow/src/legacy/kernels/ewm/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/ewm/mod.rs @@ -1,6 +1,8 @@ mod average; mod variance; +use std::hash::{Hash, Hasher}; + pub use average::*; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -29,6 +31,16 @@ impl Default for EWMOptions { } } +impl Hash for EWMOptions { + fn hash(&self, state: &mut H) { + self.alpha.to_bits().hash(state); + self.adjust.hash(state); + self.bias.hash(state); + self.min_periods.hash(state); + self.ignore_nulls.hash(state); + } +} + impl EWMOptions { pub fn and_min_periods(mut self, min_periods: usize) -> Self { self.min_periods = min_periods; diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs index f7ab23947447..27ac2d735e38 100644 --- a/crates/polars-core/src/series/series_trait.rs +++ b/crates/polars-core/src/series/series_trait.rs @@ -12,7 +12,7 @@ use crate::chunked_array::object::PolarsObjectSafe; pub use crate::prelude::ChunkCompare; use crate::prelude::*; -#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum IsSorted { Ascending, diff --git a/crates/polars-ops/src/series/ops/rank.rs b/crates/polars-ops/src/series/ops/rank.rs index c251aaa6922e..8bcc3347fc66 100644 --- a/crates/polars-ops/src/series/ops/rank.rs +++ b/crates/polars-ops/src/series/ops/rank.rs @@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize}; use crate::prelude::SeriesSealed; -#[derive(Copy, Clone, Debug, PartialEq)] +#[derive(Copy, Clone, Debug, PartialEq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum RankMethod { Average, @@ -24,7 +24,7 @@ pub enum RankMethod { } // We might want to add a `nulls_last` or `null_behavior` field. -#[derive(Copy, Clone, Debug, PartialEq)] +#[derive(Copy, Clone, Debug, PartialEq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct RankOptions { pub method: RankMethod, diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index b8d55668befe..f51c41bcde40 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -22,6 +22,7 @@ polars-utils = { workspace = true } ahash = { workspace = true } arrow = { workspace = true } +bytemuck = { workspace = true } chrono = { workspace = true, optional = true } chrono-tz = { workspace = true, optional = true } ciborium = { workspace = true, optional = true } diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index f8d22b9f8968..38eb03906604 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -137,6 +137,7 @@ pub enum FunctionExpr { ShiftAndFill { periods: i64, }, + Shift(i64), DropNans, DropNulls, #[cfg(feature = "mode")] @@ -165,7 +166,6 @@ pub enum FunctionExpr { AsStruct, #[cfg(feature = "top_k")] TopK(bool), - Shift(i64), #[cfg(feature = "cum_agg")] Cumcount { reverse: bool, @@ -345,7 +345,148 @@ impl Hash for FunctionExpr { lib.hash(state); symbol.hash(state); }, - _ => {}, + FunctionExpr::Shift(periods) | FunctionExpr::ShiftAndFill { periods } => { + periods.hash(state) + }, + FunctionExpr::SumHorizontal + | FunctionExpr::MaxHorizontal + | FunctionExpr::MinHorizontal + | FunctionExpr::DropNans + | FunctionExpr::DropNulls + | FunctionExpr::Reverse + | FunctionExpr::ArgUnique => {}, + #[cfg(feature = "mode")] + FunctionExpr::Mode => {}, + #[cfg(feature = "abs")] + FunctionExpr::Abs => {}, + FunctionExpr::NullCount => {}, + #[cfg(feature = "date_offset")] + FunctionExpr::DateOffset => {}, + #[cfg(feature = "arg_where")] + FunctionExpr::ArgWhere => {}, + #[cfg(feature = "trigonometry")] + FunctionExpr::Atan2 => {}, + #[cfg(feature = "dtype-struct")] + FunctionExpr::AsStruct => {}, + #[cfg(feature = "sign")] + FunctionExpr::Sign => {}, + FunctionExpr::Hash(a, b, c, d) => (a, b, c, d).hash(state), + FunctionExpr::FillNull { super_type } => super_type.hash(state), + #[cfg(all(feature = "rolling_window", feature = "moment"))] + FunctionExpr::RollingSkew { window_size, bias } => { + window_size.hash(state); + bias.hash(state); + }, + #[cfg(feature = "moment")] + FunctionExpr::Skew(a) => a.hash(state), + #[cfg(feature = "moment")] + FunctionExpr::Kurtosis(a, b) => { + a.hash(state); + b.hash(state); + }, + #[cfg(feature = "rank")] + FunctionExpr::Rank { options, seed } => { + options.hash(state); + seed.hash(state); + }, + #[cfg(feature = "round_series")] + FunctionExpr::Clip { has_min, has_max } => { + has_min.hash(state); + has_max.hash(state); + }, + #[cfg(feature = "top_k")] + FunctionExpr::TopK(a) => a.hash(state), + #[cfg(feature = "cum_agg")] + FunctionExpr::Cumcount { reverse } => reverse.hash(state), + #[cfg(feature = "cum_agg")] + FunctionExpr::Cumsum { reverse } => reverse.hash(state), + #[cfg(feature = "cum_agg")] + FunctionExpr::Cumprod { reverse } => reverse.hash(state), + #[cfg(feature = "cum_agg")] + FunctionExpr::Cummin { reverse } => reverse.hash(state), + #[cfg(feature = "cum_agg")] + FunctionExpr::Cummax { reverse } => reverse.hash(state), + #[cfg(feature = "dtype-struct")] + FunctionExpr::ValueCounts { sort, parallel } => { + sort.hash(state); + parallel.hash(state); + }, + #[cfg(feature = "unique_counts")] + FunctionExpr::UniqueCounts => {}, + #[cfg(feature = "approx_unique")] + FunctionExpr::ApproxNUnique => {}, + FunctionExpr::Coalesce => {}, + FunctionExpr::ShrinkType => {}, + #[cfg(feature = "pct_change")] + FunctionExpr::PctChange => {}, + #[cfg(feature = "log")] + FunctionExpr::Entropy { base, normalize } => { + base.to_bits().hash(state); + normalize.hash(state); + }, + #[cfg(feature = "log")] + FunctionExpr::Log { base } => base.to_bits().hash(state), + #[cfg(feature = "log")] + FunctionExpr::Log1p => {}, + #[cfg(feature = "log")] + FunctionExpr::Exp => {}, + FunctionExpr::Unique(a) => a.hash(state), + #[cfg(feature = "round_series")] + FunctionExpr::Round { decimals } => decimals.hash(state), + #[cfg(feature = "round_series")] + FunctionExpr::Floor => {}, + #[cfg(feature = "round_series")] + FunctionExpr::Ceil => {}, + FunctionExpr::UpperBound => {}, + FunctionExpr::LowerBound => {}, + FunctionExpr::ConcatExpr(a) => a.hash(state), + #[cfg(feature = "peaks")] + FunctionExpr::PeakMin => {}, + #[cfg(feature = "peaks")] + FunctionExpr::PeakMax => {}, + #[cfg(feature = "cutqcut")] + FunctionExpr::Cut { + breaks, + labels, + left_closed, + include_breaks, + } => { + let slice = bytemuck::cast_slice::<_, u64>(breaks); + slice.hash(state); + labels.hash(state); + left_closed.hash(state); + include_breaks.hash(state); + }, + #[cfg(feature = "cutqcut")] + FunctionExpr::QCut { + probs, + labels, + left_closed, + allow_duplicates, + include_breaks, + } => { + let slice = bytemuck::cast_slice::<_, u64>(probs); + slice.hash(state); + labels.hash(state); + left_closed.hash(state); + allow_duplicates.hash(state); + include_breaks.hash(state); + }, + #[cfg(feature = "rle")] + FunctionExpr::RLE => {}, + #[cfg(feature = "rle")] + FunctionExpr::RLEID => {}, + FunctionExpr::ToPhysical => {}, + FunctionExpr::SetSortedFlag(is_sorted) => is_sorted.hash(state), + FunctionExpr::BackwardFill { limit } | FunctionExpr::ForwardFill { limit } => { + limit.hash(state) + }, + #[cfg(feature = "ewma")] + FunctionExpr::EwmMean { options } => options.hash(state), + #[cfg(feature = "ewma")] + FunctionExpr::EwmStd { options } => options.hash(state), + #[cfg(feature = "ewma")] + FunctionExpr::EwmVar { options } => options.hash(state), } } } diff --git a/py-polars/Cargo.lock b/py-polars/Cargo.lock index 0c4084a5d0a5..6f62fe0f6532 100644 --- a/py-polars/Cargo.lock +++ b/py-polars/Cargo.lock @@ -1828,6 +1828,7 @@ name = "polars-plan" version = "0.33.2" dependencies = [ "ahash", + "bytemuck", "chrono", "chrono-tz", "ciborium", diff --git a/py-polars/tests/unit/test_cse.py b/py-polars/tests/unit/test_cse.py index a6bcb8f541d1..319f2172949c 100644 --- a/py-polars/tests/unit/test_cse.py +++ b/py-polars/tests/unit/test_cse.py @@ -536,3 +536,21 @@ def test_cse_is_in_11489() -> None: "any_cond": [False, True, True, True, False], "val": [0.0, 1.0, 1.0, 1.0, 0.0], } + + +def test_cse_11958() -> None: + df = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + vector_losses = [] + for lag in range(1, 5): + difference = pl.col("a") - pl.col("a").shift(lag) + component_loss = pl.when(difference >= 0).then(difference * 10) + vector_losses.append(component_loss.alias(f"diff{lag}")) + + q = df.select(vector_losses) + assert "__POLARS_CSE" in q.explain(comm_subexpr_elim=True) + assert q.collect(comm_subexpr_elim=True).to_dict(False) == { + "diff1": [None, 10, 10, 10, 10], + "diff2": [None, None, 20, 20, 20], + "diff3": [None, None, None, 30, 30], + "diff4": [None, None, None, None, 40], + }