diff --git a/crates/polars-core/Cargo.toml b/crates/polars-core/Cargo.toml index 7c28a57e21a4..afecf3c5a88e 100644 --- a/crates/polars-core/Cargo.toml +++ b/crates/polars-core/Cargo.toml @@ -96,7 +96,6 @@ moment = [] diagonal_concat = [] horizontal_concat = [] abs = [] -ewma = [] dataframe_arithmetic = [] product = [] unique_counts = [] diff --git a/crates/polars-core/src/series/ops/ewm.rs b/crates/polars-core/src/series/ops/ewm.rs deleted file mode 100644 index 388a44eb4a2d..000000000000 --- a/crates/polars-core/src/series/ops/ewm.rs +++ /dev/null @@ -1,104 +0,0 @@ -use std::convert::TryFrom; - -pub use arrow::legacy::kernels::ewm::EWMOptions; -use arrow::legacy::kernels::ewm::{ewm_mean, ewm_std, ewm_var}; - -use crate::prelude::*; - -fn check_alpha(alpha: f64) -> PolarsResult<()> { - polars_ensure!((0.0..=1.0).contains(&alpha), ComputeError: "alpha must be in [0; 1]"); - Ok(()) -} - -impl Series { - pub fn ewm_mean(&self, options: EWMOptions) -> PolarsResult { - check_alpha(options.alpha)?; - match self.dtype() { - DataType::Float32 => { - let xs = self.f32().unwrap(); - let result = ewm_mean( - xs, - options.alpha as f32, - options.adjust, - options.min_periods, - options.ignore_nulls, - ); - Series::try_from((self.name(), Box::new(result) as ArrayRef)) - }, - DataType::Float64 => { - let xs = self.f64().unwrap(); - let result = ewm_mean( - xs, - options.alpha, - options.adjust, - options.min_periods, - options.ignore_nulls, - ); - Series::try_from((self.name(), Box::new(result) as ArrayRef)) - }, - _ => self.cast(&DataType::Float64)?.ewm_mean(options), - } - } - - pub fn ewm_std(&self, options: EWMOptions) -> PolarsResult { - check_alpha(options.alpha)?; - match self.dtype() { - DataType::Float32 => { - let xs = self.f32().unwrap(); - let result = ewm_std( - xs, - options.alpha as f32, - options.adjust, - options.bias, - options.min_periods, - options.ignore_nulls, - ); - Series::try_from((self.name(), Box::new(result) as ArrayRef)) - }, - DataType::Float64 => { - let xs = self.f64().unwrap(); - let result = ewm_std( - xs, - options.alpha, - options.adjust, - options.bias, - options.min_periods, - options.ignore_nulls, - ); - Series::try_from((self.name(), Box::new(result) as ArrayRef)) - }, - _ => self.cast(&DataType::Float64)?.ewm_std(options), - } - } - - pub fn ewm_var(&self, options: EWMOptions) -> PolarsResult { - check_alpha(options.alpha)?; - match self.dtype() { - DataType::Float32 => { - let xs = self.f32().unwrap(); - let result = ewm_var( - xs, - options.alpha as f32, - options.adjust, - options.bias, - options.min_periods, - options.ignore_nulls, - ); - Series::try_from((self.name(), Box::new(result) as ArrayRef)) - }, - DataType::Float64 => { - let xs = self.f64().unwrap(); - let result = ewm_var( - xs, - options.alpha, - options.adjust, - options.bias, - options.min_periods, - options.ignore_nulls, - ); - Series::try_from((self.name(), Box::new(result) as ArrayRef)) - }, - _ => self.cast(&DataType::Float64)?.ewm_var(options), - } - } -} diff --git a/crates/polars-core/src/series/ops/mod.rs b/crates/polars-core/src/series/ops/mod.rs index 1cf51c1743dd..bc57ad3ee480 100644 --- a/crates/polars-core/src/series/ops/mod.rs +++ b/crates/polars-core/src/series/ops/mod.rs @@ -1,8 +1,6 @@ #[cfg(feature = "diff")] pub mod diff; mod downcast; -#[cfg(feature = "ewma")] -mod ewm; mod extend; #[cfg(feature = "moment")] pub mod moment; diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index 3e1f83fa8978..8e5c2aaa25c3 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -113,3 +113,4 @@ convert_index = [] repeat_by = [] peaks = [] cum_agg = [] +ewma = [] diff --git a/crates/polars-ops/src/series/ops/ewm.rs b/crates/polars-ops/src/series/ops/ewm.rs new file mode 100644 index 000000000000..6f4458777306 --- /dev/null +++ b/crates/polars-ops/src/series/ops/ewm.rs @@ -0,0 +1,103 @@ +use std::convert::TryFrom; + +pub use arrow::legacy::kernels::ewm::EWMOptions; +use arrow::legacy::kernels::ewm::{ + ewm_mean as kernel_ewm_mean, ewm_std as kernel_ewm_std, ewm_var as kernel_ewm_var, +}; +use polars_core::prelude::*; + +fn check_alpha(alpha: f64) -> PolarsResult<()> { + polars_ensure!((0.0..=1.0).contains(&alpha), ComputeError: "alpha must be in [0; 1]"); + Ok(()) +} + +pub fn ewm_mean(s: &Series, options: EWMOptions) -> PolarsResult { + check_alpha(options.alpha)?; + match s.dtype() { + DataType::Float32 => { + let xs = s.f32().unwrap(); + let result = kernel_ewm_mean( + xs, + options.alpha as f32, + options.adjust, + options.min_periods, + options.ignore_nulls, + ); + Series::try_from((s.name(), Box::new(result) as ArrayRef)) + }, + DataType::Float64 => { + let xs = s.f64().unwrap(); + let result = kernel_ewm_mean( + xs, + options.alpha, + options.adjust, + options.min_periods, + options.ignore_nulls, + ); + Series::try_from((s.name(), Box::new(result) as ArrayRef)) + }, + _ => ewm_mean(&s.cast(&DataType::Float64)?, options), + } +} + +pub fn ewm_std(s: &Series, options: EWMOptions) -> PolarsResult { + check_alpha(options.alpha)?; + match s.dtype() { + DataType::Float32 => { + let xs = s.f32().unwrap(); + let result = kernel_ewm_std( + xs, + options.alpha as f32, + options.adjust, + options.bias, + options.min_periods, + options.ignore_nulls, + ); + Series::try_from((s.name(), Box::new(result) as ArrayRef)) + }, + DataType::Float64 => { + let xs = s.f64().unwrap(); + let result = kernel_ewm_std( + xs, + options.alpha, + options.adjust, + options.bias, + options.min_periods, + options.ignore_nulls, + ); + Series::try_from((s.name(), Box::new(result) as ArrayRef)) + }, + _ => ewm_std(&s.cast(&DataType::Float64)?, options), + } +} + +pub fn ewm_var(s: &Series, options: EWMOptions) -> PolarsResult { + check_alpha(options.alpha)?; + match s.dtype() { + DataType::Float32 => { + let xs = s.f32().unwrap(); + let result = kernel_ewm_var( + xs, + options.alpha as f32, + options.adjust, + options.bias, + options.min_periods, + options.ignore_nulls, + ); + Series::try_from((s.name(), Box::new(result) as ArrayRef)) + }, + DataType::Float64 => { + let xs = s.f64().unwrap(); + let result = kernel_ewm_var( + xs, + options.alpha, + options.adjust, + options.bias, + options.min_periods, + options.ignore_nulls, + ); + Series::try_from((s.name(), Box::new(result) as ArrayRef)) + }, + _ => ewm_var(&s.cast(&DataType::Float64)?, options), + } +} diff --git a/crates/polars-ops/src/series/ops/mod.rs b/crates/polars-ops/src/series/ops/mod.rs index c524a880d967..6437c7a0ffac 100644 --- a/crates/polars-ops/src/series/ops/mod.rs +++ b/crates/polars-ops/src/series/ops/mod.rs @@ -7,6 +7,8 @@ mod clip; mod cum_agg; #[cfg(feature = "cutqcut")] mod cut; +#[cfg(feature = "ewma")] +mod ewm; #[cfg(feature = "round_series")] mod floor_divide; #[cfg(feature = "fused")] @@ -47,6 +49,8 @@ pub use clip::*; pub use cum_agg::*; #[cfg(feature = "cutqcut")] pub use cut::*; +#[cfg(feature = "ewma")] +pub use ewm::*; #[cfg(feature = "round_series")] pub use floor_divide::*; #[cfg(feature = "fused")] diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index 8ea14f5fc7b4..ee412f48dbec 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -113,7 +113,7 @@ moment = ["polars-core/moment", "polars-ops/moment"] abs = ["polars-core/abs"] random = ["polars-core/random"] dynamic_group_by = ["polars-core/dynamic_group_by"] -ewma = ["polars-core/ewma"] +ewma = ["polars-ops/ewma"] dot_diagram = [] unique_counts = ["polars-core/unique_counts"] log = ["polars-ops/log"] diff --git a/crates/polars-plan/src/dsl/function_expr/ewm.rs b/crates/polars-plan/src/dsl/function_expr/ewm.rs index a26285eef33a..b824ca3013e9 100644 --- a/crates/polars-plan/src/dsl/function_expr/ewm.rs +++ b/crates/polars-plan/src/dsl/function_expr/ewm.rs @@ -1,13 +1,13 @@ use super::*; pub(super) fn ewm_mean(s: &Series, options: EWMOptions) -> PolarsResult { - s.ewm_mean(options) + polars_ops::prelude::ewm_mean(s, options) } pub(super) fn ewm_std(s: &Series, options: EWMOptions) -> PolarsResult { - s.ewm_std(options) + polars_ops::prelude::ewm_std(s, options) } pub(super) fn ewm_var(s: &Series, options: EWMOptions) -> PolarsResult { - s.ewm_var(options) + polars_ops::prelude::ewm_var(s, options) } diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index 876193e089fb..58f6a8334bcd 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -152,7 +152,7 @@ diagonal_concat = ["polars-core/diagonal_concat", "polars-lazy?/diagonal_concat" horizontal_concat = ["polars-core/horizontal_concat"] abs = ["polars-core/abs", "polars-lazy?/abs"] dynamic_group_by = ["polars-core/dynamic_group_by", "polars-lazy?/dynamic_group_by"] -ewma = ["polars-core/ewma", "polars-lazy?/ewma"] +ewma = ["polars-ops/ewma", "polars-lazy?/ewma"] dot_diagram = ["polars-lazy?/dot_diagram"] dataframe_arithmetic = ["polars-core/dataframe_arithmetic"] product = ["polars-core/product"]