From ce3dd72e59ef398394c5680c1be0231e0cae3f33 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Mon, 23 Oct 2023 14:56:43 +0200 Subject: [PATCH] fix: fix take return dtype in group context. (#11949) --- .../src/physical_plan/expressions/apply.rs | 66 +++++++---- .../src/physical_plan/expressions/take.rs | 21 +++- .../src/physical_plan/expressions/window.rs | 4 +- .../src/physical_plan/planner/expr.rs | 55 +++++---- .../src/physical_plan/planner/lp.rs | 2 +- .../src/physical_plan/streaming/checks.rs | 4 +- crates/polars-lazy/src/tests/aggregations.rs | 19 ++-- crates/polars-plan/src/dsl/binary.rs | 6 +- crates/polars-plan/src/dsl/dt.rs | 2 +- crates/polars-plan/src/dsl/expr.rs | 1 + .../polars-plan/src/dsl/functions/coerce.rs | 2 +- .../polars-plan/src/dsl/functions/concat.rs | 8 +- .../src/dsl/functions/correlation.rs | 12 +- .../src/dsl/functions/horizontal.rs | 38 +++---- crates/polars-plan/src/dsl/functions/index.rs | 2 +- crates/polars-plan/src/dsl/functions/range.rs | 12 +- .../polars-plan/src/dsl/functions/temporal.rs | 4 +- crates/polars-plan/src/dsl/list.rs | 10 +- crates/polars-plan/src/dsl/mod.rs | 91 +++++++++------ crates/polars-plan/src/dsl/python_udf.rs | 2 +- crates/polars-plan/src/dsl/string.rs | 15 +-- .../polars-plan/src/logical_plan/aexpr/mod.rs | 5 +- .../src/logical_plan/conversion.rs | 14 ++- crates/polars-plan/src/logical_plan/format.rs | 12 +- .../polars-plan/src/logical_plan/iterator.rs | 2 +- .../src/logical_plan/optimizer/fused.rs | 2 +- .../optimizer/predicate_pushdown/utils.rs | 10 +- .../logical_plan/optimizer/simplify_expr.rs | 12 +- .../optimizer/slice_pushdown_expr.rs | 2 +- .../polars-plan/src/logical_plan/options.rs | 14 +-- crates/polars-plan/src/utils.rs | 13 ++- .../reference/expressions/modify_select.rst | 1 + py-polars/polars/expr/expr.py | 104 +++++++++++++++--- py-polars/src/expr/general.rs | 12 +- py-polars/tests/unit/dataframe/test_df.py | 2 +- py-polars/tests/unit/namespaces/test_list.py | 16 ++- .../unit/operations/test_aggregations.py | 2 +- .../tests/unit/operations/test_group_by.py | 4 +- 38 files changed, 390 insertions(+), 213 deletions(-) diff --git a/crates/polars-lazy/src/physical_plan/expressions/apply.rs b/crates/polars-lazy/src/physical_plan/expressions/apply.rs index 21d54fc149f3..95734356a988 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/apply.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/apply.rs @@ -14,20 +14,48 @@ use crate::physical_plan::state::ExecutionState; use crate::prelude::*; pub struct ApplyExpr { - pub inputs: Vec>, - pub function: SpecialEq>, - pub expr: Expr, - pub collect_groups: ApplyOptions, - pub auto_explode: bool, - pub allow_rename: bool, - pub pass_name_to_apply: bool, - pub input_schema: Option, - pub allow_threading: bool, - pub check_lengths: bool, - pub allow_group_aware: bool, + inputs: Vec>, + function: SpecialEq>, + expr: Expr, + collect_groups: ApplyOptions, + returns_scalar: bool, + allow_rename: bool, + pass_name_to_apply: bool, + input_schema: Option, + allow_threading: bool, + check_lengths: bool, + allow_group_aware: bool, } impl ApplyExpr { + pub(crate) fn new( + inputs: Vec>, + function: SpecialEq>, + expr: Expr, + options: FunctionOptions, + allow_threading: bool, + input_schema: Option, + ) -> Self { + #[cfg(debug_assertions)] + if matches!(options.collect_groups, ApplyOptions::ElementWise) && options.returns_scalar { + panic!("expr {} is not implemented correctly. 'returns_scalar' and 'elementwise' are mutually exclusive", expr) + } + + Self { + inputs, + function, + expr, + collect_groups: options.collect_groups, + returns_scalar: options.returns_scalar, + allow_rename: options.allow_rename, + pass_name_to_apply: options.pass_name_to_apply, + input_schema, + allow_threading, + check_lengths: options.check_lengths(), + allow_group_aware: options.allow_group_aware, + } + } + pub(crate) fn new_minimal( inputs: Vec>, function: SpecialEq>, @@ -39,7 +67,7 @@ impl ApplyExpr { function, expr, collect_groups, - auto_explode: false, + returns_scalar: false, allow_rename: false, pass_name_to_apply: false, input_schema: None, @@ -70,7 +98,7 @@ impl ApplyExpr { ca: ListChunked, ) -> PolarsResult> { let all_unit_len = all_unit_length(&ca); - if all_unit_len && self.auto_explode { + if all_unit_len && self.returns_scalar { ac.with_series(ca.explode().unwrap().into_series(), true, Some(&self.expr))?; ac.update_groups = UpdateGroups::No; } else { @@ -289,8 +317,8 @@ impl PhysicalExpr for ApplyExpr { ac.with_series(s, true, Some(&self.expr))?; Ok(ac) }, - ApplyOptions::ApplyGroups => self.apply_single_group_aware(ac), - ApplyOptions::ApplyFlat => self.apply_single_elementwise(ac), + ApplyOptions::GroupWise => self.apply_single_group_aware(ac), + ApplyOptions::ElementWise => self.apply_single_elementwise(ac), } } else { let mut acs = self.prepare_multiple_inputs(df, groups, state)?; @@ -305,8 +333,8 @@ impl PhysicalExpr for ApplyExpr { ac.with_series(s, true, Some(&self.expr))?; Ok(ac) }, - ApplyOptions::ApplyGroups => self.apply_multiple_group_aware(acs, df), - ApplyOptions::ApplyFlat => { + ApplyOptions::GroupWise => self.apply_multiple_group_aware(acs, df), + ApplyOptions::ElementWise => { if acs .iter() .any(|ac| matches!(ac.agg_state(), AggState::AggregatedList(_))) @@ -328,7 +356,7 @@ impl PhysicalExpr for ApplyExpr { self.expr.to_field(input_schema, Context::Default) } fn is_valid_aggregation(&self) -> bool { - matches!(self.collect_groups, ApplyOptions::ApplyGroups) + matches!(self.collect_groups, ApplyOptions::GroupWise) } #[cfg(feature = "parquet")] fn as_stats_evaluator(&self) -> Option<&dyn polars_io::predicates::StatsEvaluator> { @@ -345,7 +373,7 @@ impl PhysicalExpr for ApplyExpr { } } fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> { - if self.inputs.len() == 1 && matches!(self.collect_groups, ApplyOptions::ApplyFlat) { + if self.inputs.len() == 1 && matches!(self.collect_groups, ApplyOptions::ElementWise) { Some(self) } else { None diff --git a/crates/polars-lazy/src/physical_plan/expressions/take.rs b/crates/polars-lazy/src/physical_plan/expressions/take.rs index 08d273978d18..b86c585cfeaf 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/take.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/take.rs @@ -12,6 +12,7 @@ pub struct TakeExpr { pub(crate) phys_expr: Arc, pub(crate) idx: Arc, pub(crate) expr: Expr, + pub(crate) returns_scalar: bool, } impl TakeExpr { @@ -101,12 +102,23 @@ impl PhysicalExpr for TakeExpr { }, }; let taken = ac.flat_naive().take(&idx)?; + + let taken = if self.returns_scalar { + taken + } else { + taken.as_list().into_series() + }; + ac.with_series(taken, true, Some(&self.expr))?; return Ok(ac); }, - AggState::AggregatedList(s) => s.list().unwrap().clone(), + AggState::AggregatedList(s) => { + polars_ensure!(!self.returns_scalar, ComputeError: "expected single index"); + s.list().unwrap().clone() + }, // Maybe a literal as well, this needs a different path. AggState::NotAggregated(_) => { + polars_ensure!(!self.returns_scalar, ComputeError: "expected single index"); let s = idx.aggregated(); s.list().unwrap().clone() }, @@ -144,6 +156,13 @@ impl PhysicalExpr for TakeExpr { }, }; let taken = ac.flat_naive().take(&idx.into_inner())?; + + let taken = if self.returns_scalar { + taken + } else { + taken.as_list().into_series() + }; + ac.with_series(taken, true, Some(&self.expr))?; ac.with_update_groups(UpdateGroups::WithGroupsLen); Ok(ac) diff --git a/crates/polars-lazy/src/physical_plan/expressions/window.rs b/crates/polars-lazy/src/physical_plan/expressions/window.rs index e74d4c800210..4810126075b5 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/window.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/window.rs @@ -302,8 +302,8 @@ impl WindowExpr { }, Expr::Function { options, .. } | Expr::AnonymousFunction { options, .. } => { - if options.auto_explode - && matches!(options.collect_groups, ApplyOptions::ApplyGroups) + if options.returns_scalar + && matches!(options.collect_groups, ApplyOptions::GroupWise) { agg_col = true; } diff --git a/crates/polars-lazy/src/physical_plan/planner/expr.rs b/crates/polars-lazy/src/physical_plan/planner/expr.rs index 72b0b80c5d3c..8e7c4a51063b 100644 --- a/crates/polars-lazy/src/physical_plan/planner/expr.rs +++ b/crates/polars-lazy/src/physical_plan/planner/expr.rs @@ -192,13 +192,18 @@ pub(crate) fn create_physical_expr( node_to_expr(expression, expr_arena), ))) }, - Take { expr, idx } => { + Take { + expr, + idx, + returns_scalar, + } => { let phys_expr = create_physical_expr(expr, ctxt, expr_arena, schema, state)?; let phys_idx = create_physical_expr(idx, ctxt, expr_arena, schema, state)?; Ok(Arc::new(TakeExpr { phys_expr, idx: phys_idx, expr: node_to_expr(expression, expr_arena), + returns_scalar, })) }, SortBy { @@ -391,7 +396,7 @@ pub(crate) fn create_physical_expr( vec![input], function, node_to_expr(expression, expr_arena), - ApplyOptions::ApplyFlat, + ApplyOptions::ElementWise, ))) }, _ => { @@ -463,7 +468,7 @@ pub(crate) fn create_physical_expr( options, } => { let is_reducing_aggregation = - options.auto_explode && matches!(options.collect_groups, ApplyOptions::ApplyGroups); + options.returns_scalar && matches!(options.collect_groups, ApplyOptions::GroupWise); // will be reset in the function so get that here let has_window = state.local.has_window; let input = create_physical_expressions_check_state( @@ -478,19 +483,14 @@ pub(crate) fn create_physical_expr( }, )?; - Ok(Arc::new(ApplyExpr { - inputs: input, + Ok(Arc::new(ApplyExpr::new( + input, function, - expr: node_to_expr(expression, expr_arena), - collect_groups: options.collect_groups, - auto_explode: options.auto_explode, - allow_rename: options.allow_rename, - pass_name_to_apply: options.pass_name_to_apply, - input_schema: schema.cloned(), - allow_threading: !state.has_cache, - check_lengths: options.check_lengths(), - allow_group_aware: options.allow_group_aware, - })) + node_to_expr(expression, expr_arena), + options, + !state.has_cache, + schema.cloned(), + ))) }, Function { input, @@ -499,7 +499,7 @@ pub(crate) fn create_physical_expr( .. } => { let is_reducing_aggregation = - options.auto_explode && matches!(options.collect_groups, ApplyOptions::ApplyGroups); + options.returns_scalar && matches!(options.collect_groups, ApplyOptions::GroupWise); // will be reset in the function so get that here let has_window = state.local.has_window; let input = create_physical_expressions_check_state( @@ -514,19 +514,14 @@ pub(crate) fn create_physical_expr( }, )?; - Ok(Arc::new(ApplyExpr { - inputs: input, - function: function.into(), - expr: node_to_expr(expression, expr_arena), - collect_groups: options.collect_groups, - auto_explode: options.auto_explode, - allow_rename: options.allow_rename, - pass_name_to_apply: options.pass_name_to_apply, - input_schema: schema.cloned(), - allow_threading: !state.has_cache, - check_lengths: options.check_lengths(), - allow_group_aware: options.allow_group_aware, - })) + Ok(Arc::new(ApplyExpr::new( + input, + function.into(), + node_to_expr(expression, expr_arena), + options, + !state.has_cache, + schema.cloned(), + ))) }, Slice { input, @@ -553,7 +548,7 @@ pub(crate) fn create_physical_expr( vec![input], function, node_to_expr(expression, expr_arena), - ApplyOptions::ApplyGroups, + ApplyOptions::GroupWise, ))) }, Wildcard => panic!("should be no wildcard at this point"), diff --git a/crates/polars-lazy/src/physical_plan/planner/lp.rs b/crates/polars-lazy/src/physical_plan/planner/lp.rs index f0128817a70d..65f44f29b61e 100644 --- a/crates/polars-lazy/src/physical_plan/planner/lp.rs +++ b/crates/polars-lazy/src/physical_plan/planner/lp.rs @@ -94,7 +94,7 @@ fn partitionable_gb( ) }, Function {input, options, ..} => { - matches!(options.collect_groups, ApplyOptions::ApplyFlat) && input.len() == 1 && + matches!(options.collect_groups, ApplyOptions::ElementWise) && input.len() == 1 && !has_aggregation(input[0]) } BinaryExpr {left, right, ..} => { diff --git a/crates/polars-lazy/src/physical_plan/streaming/checks.rs b/crates/polars-lazy/src/physical_plan/streaming/checks.rs index 02f6f636e9b4..fc8b8f2e1ad2 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/checks.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/checks.rs @@ -29,9 +29,9 @@ pub(super) fn is_streamable(node: Node, expr_arena: &Arena, context: Cont { Context::Default => matches!( options.collect_groups, - ApplyOptions::ApplyFlat | ApplyOptions::ApplyList + ApplyOptions::ElementWise | ApplyOptions::ApplyList ), - Context::Aggregation => matches!(options.collect_groups, ApplyOptions::ApplyFlat), + Context::Aggregation => matches!(options.collect_groups, ApplyOptions::ElementWise), }, AExpr::Column(_) => { seen_column = true; diff --git a/crates/polars-lazy/src/tests/aggregations.rs b/crates/polars-lazy/src/tests/aggregations.rs index 87e9e594052b..b6b6c65e86ef 100644 --- a/crates/polars-lazy/src/tests/aggregations.rs +++ b/crates/polars-lazy/src/tests/aggregations.rs @@ -421,7 +421,7 @@ fn take_aggregations() -> PolarsResult<()> { .clone() .lazy() .group_by([col("user")]) - .agg([col("book").take(col("count").arg_max()).alias("fav_book")]) + .agg([col("book").get(col("count").arg_max()).alias("fav_book")]) .sort("user", Default::default()) .collect()?; @@ -460,7 +460,7 @@ fn take_aggregations() -> PolarsResult<()> { let out = df .lazy() .group_by([col("user")]) - .agg([col("book").take(lit(0)).alias("take_lit")]) + .agg([col("book").get(lit(0)).alias("take_lit")]) .sort("user", Default::default()) .collect()?; @@ -484,7 +484,7 @@ fn test_take_consistency() -> PolarsResult<()> { multithreaded: true, maintain_order: false, }) - .take(lit(0))]) + .get(lit(0))]) .collect()?; let a = out.column("A")?; @@ -502,7 +502,7 @@ fn test_take_consistency() -> PolarsResult<()> { multithreaded: true, maintain_order: false, }) - .take(lit(0))]) + .get(lit(0))]) .collect()?; let out = out.column("A")?; @@ -521,10 +521,10 @@ fn test_take_consistency() -> PolarsResult<()> { multithreaded: true, maintain_order: false, }) - .take(lit(0)) + .get(lit(0)) .alias("1"), col("A") - .take( + .get( col("A") .arg_sort(SortOptions { descending: true, @@ -532,7 +532,7 @@ fn test_take_consistency() -> PolarsResult<()> { multithreaded: true, maintain_order: false, }) - .take(lit(0)), + .get(lit(0)), ) .alias("2"), ]) @@ -556,10 +556,7 @@ fn test_take_in_groups() -> PolarsResult<()> { let out = df .lazy() .sort("fruits", Default::default()) - .select([col("B") - .take(lit(Series::new("", &[0u32]))) - .over([col("fruits")]) - .alias("taken")]) + .select([col("B").get(lit(0u32)).over([col("fruits")]).alias("taken")]) .collect()?; assert_eq!( diff --git a/crates/polars-plan/src/dsl/binary.rs b/crates/polars-plan/src/dsl/binary.rs index 5c8aab05c579..57c40df7ed6f 100644 --- a/crates/polars-plan/src/dsl/binary.rs +++ b/crates/polars-plan/src/dsl/binary.rs @@ -9,7 +9,7 @@ impl BinaryNameSpace { self.0.map_many_private( FunctionExpr::BinaryExpr(BinaryFunction::Contains), &[pat], - true, + false, true, ) } @@ -19,7 +19,7 @@ impl BinaryNameSpace { self.0.map_many_private( FunctionExpr::BinaryExpr(BinaryFunction::EndsWith), &[sub], - true, + false, true, ) } @@ -29,7 +29,7 @@ impl BinaryNameSpace { self.0.map_many_private( FunctionExpr::BinaryExpr(BinaryFunction::StartsWith), &[sub], - true, + false, true, ) } diff --git a/crates/polars-plan/src/dsl/dt.rs b/crates/polars-plan/src/dsl/dt.rs index 84e9023f5994..fd077a498086 100644 --- a/crates/polars-plan/src/dsl/dt.rs +++ b/crates/polars-plan/src/dsl/dt.rs @@ -172,7 +172,7 @@ impl DateLikeNameSpace { self.0.map_many_private( FunctionExpr::TemporalExpr(TemporalFunction::Truncate(offset)), &[every, ambiguous], - true, + false, false, ) } diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs index dd7b2eeec8fa..f319e9052d47 100644 --- a/crates/polars-plan/src/dsl/expr.rs +++ b/crates/polars-plan/src/dsl/expr.rs @@ -89,6 +89,7 @@ pub enum Expr { Take { expr: Box, idx: Box, + returns_scalar: bool, }, SortBy { expr: Box, diff --git a/crates/polars-plan/src/dsl/functions/coerce.rs b/crates/polars-plan/src/dsl/functions/coerce.rs index e009e9e61918..82d0c3324f5a 100644 --- a/crates/polars-plan/src/dsl/functions/coerce.rs +++ b/crates/polars-plan/src/dsl/functions/coerce.rs @@ -10,7 +10,7 @@ pub fn as_struct(exprs: Vec) -> Expr { options: FunctionOptions { input_wildcard_expansion: true, pass_name_to_apply: true, - collect_groups: ApplyOptions::ApplyFlat, + collect_groups: ApplyOptions::ElementWise, ..Default::default() }, } diff --git a/crates/polars-plan/src/dsl/functions/concat.rs b/crates/polars-plan/src/dsl/functions/concat.rs index f66a6985fa67..1f8e91cf5cb2 100644 --- a/crates/polars-plan/src/dsl/functions/concat.rs +++ b/crates/polars-plan/src/dsl/functions/concat.rs @@ -10,9 +10,9 @@ pub fn concat_str>(s: E, separator: &str) -> Expr { input, function: StringFunction::ConcatHorizontal(separator).into(), options: FunctionOptions { - collect_groups: ApplyOptions::ApplyFlat, + collect_groups: ApplyOptions::ElementWise, input_wildcard_expansion: true, - auto_explode: true, + returns_scalar: false, ..Default::default() }, } @@ -58,7 +58,7 @@ pub fn concat_list, IE: Into + Clone>(s: E) -> PolarsResult input: s, function: FunctionExpr::ListExpr(ListFunction::Concat), options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, + collect_groups: ApplyOptions::GroupWise, input_wildcard_expansion: true, ..Default::default() }, @@ -76,7 +76,7 @@ pub fn concat_expr, IE: Into + Clone>( input: s, function: FunctionExpr::ConcatExpr(rechunk), options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, + collect_groups: ApplyOptions::GroupWise, input_wildcard_expansion: true, cast_to_supertypes: true, ..Default::default() diff --git a/crates/polars-plan/src/dsl/functions/correlation.rs b/crates/polars-plan/src/dsl/functions/correlation.rs index f5912c390a20..6b947721a6ec 100644 --- a/crates/polars-plan/src/dsl/functions/correlation.rs +++ b/crates/polars-plan/src/dsl/functions/correlation.rs @@ -11,9 +11,9 @@ pub fn cov(a: Expr, b: Expr) -> Expr { input, function, options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, + collect_groups: ApplyOptions::GroupWise, cast_to_supertypes: true, - auto_explode: true, + returns_scalar: true, ..Default::default() }, } @@ -34,9 +34,9 @@ pub fn pearson_corr(a: Expr, b: Expr, ddof: u8) -> Expr { input, function, options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, + collect_groups: ApplyOptions::GroupWise, cast_to_supertypes: true, - auto_explode: true, + returns_scalar: true, ..Default::default() }, } @@ -62,9 +62,9 @@ pub fn spearman_rank_corr(a: Expr, b: Expr, ddof: u8, propagate_nans: bool) -> E input, function, options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, + collect_groups: ApplyOptions::GroupWise, cast_to_supertypes: true, - auto_explode: true, + returns_scalar: true, ..Default::default() }, } diff --git a/crates/polars-plan/src/dsl/functions/horizontal.rs b/crates/polars-plan/src/dsl/functions/horizontal.rs index 8517ef115387..1b10f917282a 100644 --- a/crates/polars-plan/src/dsl/functions/horizontal.rs +++ b/crates/polars-plan/src/dsl/functions/horizontal.rs @@ -44,9 +44,9 @@ where function, output_type: GetOutput::super_type(), options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, + collect_groups: ApplyOptions::GroupWise, input_wildcard_expansion: true, - auto_explode: true, + returns_scalar: true, fmt_str: "fold", ..Default::default() }, @@ -87,9 +87,9 @@ where function, output_type: GetOutput::super_type(), options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, + collect_groups: ApplyOptions::GroupWise, input_wildcard_expansion: true, - auto_explode: true, + returns_scalar: true, fmt_str: "reduce", ..Default::default() }, @@ -132,9 +132,9 @@ where function, output_type: cumfold_dtype(), options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, + collect_groups: ApplyOptions::GroupWise, input_wildcard_expansion: true, - auto_explode: true, + returns_scalar: true, fmt_str: "cumreduce", ..Default::default() }, @@ -181,9 +181,9 @@ where function, output_type: cumfold_dtype(), options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, + collect_groups: ApplyOptions::GroupWise, input_wildcard_expansion: true, - auto_explode: true, + returns_scalar: true, fmt_str: "cumfold", ..Default::default() }, @@ -199,9 +199,9 @@ pub fn all_horizontal>(exprs: E) -> Expr { input: exprs, function: FunctionExpr::Boolean(BooleanFunction::AllHorizontal), options: FunctionOptions { - collect_groups: ApplyOptions::ApplyFlat, + collect_groups: ApplyOptions::ElementWise, input_wildcard_expansion: true, - auto_explode: true, + returns_scalar: false, cast_to_supertypes: false, allow_rename: true, ..Default::default() @@ -218,9 +218,9 @@ pub fn any_horizontal>(exprs: E) -> Expr { input: exprs, function: FunctionExpr::Boolean(BooleanFunction::AnyHorizontal), options: FunctionOptions { - collect_groups: ApplyOptions::ApplyFlat, + collect_groups: ApplyOptions::ElementWise, input_wildcard_expansion: true, - auto_explode: true, + returns_scalar: false, cast_to_supertypes: false, allow_rename: true, ..Default::default() @@ -241,9 +241,9 @@ pub fn max_horizontal>(exprs: E) -> Expr { input: exprs, function: FunctionExpr::MaxHorizontal, options: FunctionOptions { - collect_groups: ApplyOptions::ApplyFlat, + collect_groups: ApplyOptions::ElementWise, input_wildcard_expansion: true, - auto_explode: true, + returns_scalar: false, allow_rename: true, ..Default::default() }, @@ -263,9 +263,9 @@ pub fn min_horizontal>(exprs: E) -> Expr { input: exprs, function: FunctionExpr::MinHorizontal, options: FunctionOptions { - collect_groups: ApplyOptions::ApplyFlat, + collect_groups: ApplyOptions::ElementWise, input_wildcard_expansion: true, - auto_explode: true, + returns_scalar: false, allow_rename: true, ..Default::default() }, @@ -282,9 +282,9 @@ pub fn sum_horizontal>(exprs: E) -> Expr { input: exprs, function: FunctionExpr::SumHorizontal, options: FunctionOptions { - collect_groups: ApplyOptions::ApplyFlat, + collect_groups: ApplyOptions::ElementWise, input_wildcard_expansion: true, - auto_explode: true, + returns_scalar: false, cast_to_supertypes: false, allow_rename: true, ..Default::default() @@ -301,7 +301,7 @@ pub fn coalesce(exprs: &[Expr]) -> Expr { input, function: FunctionExpr::Coalesce, options: FunctionOptions { - collect_groups: ApplyOptions::ApplyFlat, + collect_groups: ApplyOptions::ElementWise, cast_to_supertypes: true, input_wildcard_expansion: true, ..Default::default() diff --git a/crates/polars-plan/src/dsl/functions/index.rs b/crates/polars-plan/src/dsl/functions/index.rs index 2a103457aad5..da210d830a44 100644 --- a/crates/polars-plan/src/dsl/functions/index.rs +++ b/crates/polars-plan/src/dsl/functions/index.rs @@ -22,7 +22,7 @@ pub fn arg_where>(condition: E) -> Expr { input: vec![condition], function: FunctionExpr::ArgWhere, options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, + collect_groups: ApplyOptions::GroupWise, ..Default::default() }, } diff --git a/crates/polars-plan/src/dsl/functions/range.rs b/crates/polars-plan/src/dsl/functions/range.rs index 89c3e6bf6312..98d2909cec77 100644 --- a/crates/polars-plan/src/dsl/functions/range.rs +++ b/crates/polars-plan/src/dsl/functions/range.rs @@ -56,7 +56,7 @@ pub fn date_range( time_zone, }), options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, + collect_groups: ApplyOptions::GroupWise, cast_to_supertypes: true, allow_rename: true, ..Default::default() @@ -85,7 +85,7 @@ pub fn date_ranges( time_zone, }), options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, + collect_groups: ApplyOptions::GroupWise, cast_to_supertypes: true, allow_rename: true, ..Default::default() @@ -114,7 +114,7 @@ pub fn datetime_range( time_zone, }), options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, + collect_groups: ApplyOptions::GroupWise, cast_to_supertypes: true, allow_rename: true, ..Default::default() @@ -143,7 +143,7 @@ pub fn datetime_ranges( time_zone, }), options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, + collect_groups: ApplyOptions::GroupWise, cast_to_supertypes: true, allow_rename: true, ..Default::default() @@ -160,7 +160,7 @@ pub fn time_range(start: Expr, end: Expr, interval: Duration, closed: ClosedWind input, function: FunctionExpr::Range(RangeFunction::TimeRange { interval, closed }), options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, + collect_groups: ApplyOptions::GroupWise, allow_rename: true, ..Default::default() }, @@ -176,7 +176,7 @@ pub fn time_ranges(start: Expr, end: Expr, interval: Duration, closed: ClosedWin input, function: FunctionExpr::Range(RangeFunction::TimeRanges { interval, closed }), options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, + collect_groups: ApplyOptions::GroupWise, allow_rename: true, ..Default::default() }, diff --git a/crates/polars-plan/src/dsl/functions/temporal.rs b/crates/polars-plan/src/dsl/functions/temporal.rs index b2eb3af5a229..a3a1097c88d7 100644 --- a/crates/polars-plan/src/dsl/functions/temporal.rs +++ b/crates/polars-plan/src/dsl/functions/temporal.rs @@ -141,7 +141,7 @@ pub fn datetime(args: DatetimeArgs) -> Expr { time_zone, }), options: FunctionOptions { - collect_groups: ApplyOptions::ApplyFlat, + collect_groups: ApplyOptions::ElementWise, allow_rename: true, input_wildcard_expansion: true, fmt_str: "datetime", @@ -359,7 +359,7 @@ pub fn duration(args: DurationArgs) -> Expr { function, output_type: GetOutput::from_type(DataType::Duration(args.time_unit)), options: FunctionOptions { - collect_groups: ApplyOptions::ApplyFlat, + collect_groups: ApplyOptions::ElementWise, input_wildcard_expansion: true, fmt_str: "duration", ..Default::default() diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index c8741dba73ff..6232c29c0a46 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -137,7 +137,7 @@ impl ListNameSpace { self.0.map_many_private( FunctionExpr::ListExpr(ListFunction::Get), &[index], - true, + false, false, ) } @@ -152,7 +152,7 @@ impl ListNameSpace { self.0.map_many_private( FunctionExpr::ListExpr(ListFunction::Take(null_on_oob)), &[index], - true, + false, false, ) } @@ -216,7 +216,7 @@ impl ListNameSpace { self.0.map_many_private( FunctionExpr::ListExpr(ListFunction::Slice), &[offset, length], - true, + false, false, ) } @@ -296,7 +296,7 @@ impl ListNameSpace { .map_many_private( FunctionExpr::ListExpr(ListFunction::Contains), &[other], - true, + false, false, ) .with_function_options(|mut options| { @@ -313,7 +313,7 @@ impl ListNameSpace { .map_many_private( FunctionExpr::ListExpr(ListFunction::CountMatches), &[other], - true, + false, false, ) .with_function_options(|mut options| { diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 05dca3120e10..516eaa183639 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -349,8 +349,8 @@ impl Expr { /// Get the index value that has the minimum value. pub fn arg_min(self) -> Self { let options = FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, - auto_explode: true, + collect_groups: ApplyOptions::GroupWise, + returns_scalar: true, fmt_str: "arg_min", ..Default::default() }; @@ -370,8 +370,8 @@ impl Expr { /// Get the index value that has the maximum value. pub fn arg_max(self) -> Self { let options = FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, - auto_explode: true, + collect_groups: ApplyOptions::GroupWise, + returns_scalar: true, fmt_str: "arg_max", ..Default::default() }; @@ -391,7 +391,7 @@ impl Expr { /// Get the index values that would sort this expression. pub fn arg_sort(self, sort_options: SortOptions) -> Self { let options = FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, + collect_groups: ApplyOptions::GroupWise, fmt_str: "arg_sort", ..Default::default() }; @@ -411,8 +411,8 @@ impl Expr { input: vec![self, element], function: FunctionExpr::SearchSorted(side), options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, - auto_explode: true, + collect_groups: ApplyOptions::GroupWise, + returns_scalar: true, fmt_str: "search_sorted", cast_to_supertypes: true, ..Default::default() @@ -444,6 +444,16 @@ impl Expr { Expr::Take { expr: Box::new(self), idx: Box::new(idx.into()), + returns_scalar: false, + } + } + + /// Take the values by a single index. + pub fn get>(self, idx: E) -> Self { + Expr::Take { + expr: Box::new(self), + idx: Box::new(idx.into()), + returns_scalar: true, } } @@ -507,7 +517,7 @@ impl Expr { function: SpecialEq::new(Arc::new(f)), output_type, options: FunctionOptions { - collect_groups: ApplyOptions::ApplyFlat, + collect_groups: ApplyOptions::ElementWise, fmt_str: "map", ..Default::default() }, @@ -519,7 +529,7 @@ impl Expr { input: vec![self], function: function_expr, options: FunctionOptions { - collect_groups: ApplyOptions::ApplyFlat, + collect_groups: ApplyOptions::ElementWise, ..Default::default() }, } @@ -540,7 +550,7 @@ impl Expr { function: SpecialEq::new(Arc::new(function)), output_type, options: FunctionOptions { - collect_groups: ApplyOptions::ApplyFlat, + collect_groups: ApplyOptions::ElementWise, fmt_str: "", ..Default::default() }, @@ -612,7 +622,7 @@ impl Expr { function: SpecialEq::new(Arc::new(f)), output_type, options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, + collect_groups: ApplyOptions::GroupWise, fmt_str: "", ..Default::default() }, @@ -624,7 +634,7 @@ impl Expr { input: vec![self], function: function_expr, options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, + collect_groups: ApplyOptions::GroupWise, ..Default::default() }, } @@ -645,7 +655,7 @@ impl Expr { function: SpecialEq::new(Arc::new(function)), output_type, options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, + collect_groups: ApplyOptions::GroupWise, fmt_str: "", ..Default::default() }, @@ -667,8 +677,8 @@ impl Expr { input, function: function_expr, options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, - auto_explode, + collect_groups: ApplyOptions::GroupWise, + returns_scalar: auto_explode, cast_to_supertypes, ..Default::default() }, @@ -679,7 +689,7 @@ impl Expr { self, function_expr: FunctionExpr, arguments: &[Expr], - auto_explode: bool, + returns_scalar: bool, cast_to_supertypes: bool, ) -> Self { let mut input = Vec::with_capacity(arguments.len() + 1); @@ -690,8 +700,8 @@ impl Expr { input, function: function_expr, options: FunctionOptions { - collect_groups: ApplyOptions::ApplyFlat, - auto_explode, + collect_groups: ApplyOptions::ElementWise, + returns_scalar, cast_to_supertypes, ..Default::default() }, @@ -768,8 +778,8 @@ impl Expr { /// Get the product aggregation of an expression. pub fn product(self) -> Self { let options = FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, - auto_explode: true, + collect_groups: ApplyOptions::GroupWise, + returns_scalar: true, fmt_str: "product", ..Default::default() }; @@ -962,7 +972,7 @@ impl Expr { super_type: DataType::Unknown, }, options: FunctionOptions { - collect_groups: ApplyOptions::ApplyFlat, + collect_groups: ApplyOptions::ElementWise, cast_to_supertypes: true, ..Default::default() }, @@ -1019,7 +1029,7 @@ impl Expr { pub fn approx_n_unique(self) -> Self { self.apply_private(FunctionExpr::ApproxNUnique) .with_function_options(|mut options| { - options.auto_explode = true; + options.returns_scalar = true; options }) } @@ -1060,12 +1070,25 @@ impl Expr { let other = other.into(); let has_literal = has_leaf_literal(&other); + // lit(true).is_in() returns a scalar. + let returns_scalar = all_leaf_literal(&self); + let arguments = &[other]; // we don't have to apply on groups, so this is faster if has_literal { - self.map_many_private(BooleanFunction::IsIn.into(), arguments, true, true) + self.map_many_private( + BooleanFunction::IsIn.into(), + arguments, + returns_scalar, + true, + ) } else { - self.apply_many_private(BooleanFunction::IsIn.into(), arguments, true, true) + self.apply_many_private( + BooleanFunction::IsIn.into(), + arguments, + returns_scalar, + true, + ) } } @@ -1558,7 +1581,7 @@ impl Expr { pub fn skew(self, bias: bool) -> Expr { self.apply_private(FunctionExpr::Skew(bias)) .with_function_options(|mut options| { - options.auto_explode = true; + options.returns_scalar = true; options }) } @@ -1574,7 +1597,7 @@ impl Expr { pub fn kurtosis(self, fisher: bool, bias: bool) -> Expr { self.apply_private(FunctionExpr::Kurtosis(fisher, bias)) .with_function_options(|mut options| { - options.auto_explode = true; + options.returns_scalar = true; options }) } @@ -1644,7 +1667,7 @@ impl Expr { pub fn any(self, ignore_nulls: bool) -> Self { self.apply_private(BooleanFunction::Any { ignore_nulls }.into()) .with_function_options(|mut opt| { - opt.auto_explode = true; + opt.returns_scalar = true; opt }) } @@ -1659,7 +1682,7 @@ impl Expr { pub fn all(self, ignore_nulls: bool) -> Self { self.apply_private(BooleanFunction::All { ignore_nulls }.into()) .with_function_options(|mut opt| { - opt.auto_explode = true; + opt.returns_scalar = true; opt }) } @@ -1714,7 +1737,7 @@ impl Expr { pub fn entropy(self, base: f64, normalize: bool) -> Self { self.apply_private(FunctionExpr::Entropy { base, normalize }) .with_function_options(|mut options| { - options.auto_explode = true; + options.returns_scalar = true; options }) } @@ -1722,7 +1745,7 @@ impl Expr { pub fn null_count(self) -> Expr { self.apply_private(FunctionExpr::NullCount) .with_function_options(|mut options| { - options.auto_explode = true; + options.returns_scalar = true; options }) } @@ -1810,7 +1833,7 @@ where function: SpecialEq::new(Arc::new(function)), output_type, options: FunctionOptions { - collect_groups: ApplyOptions::ApplyFlat, + collect_groups: ApplyOptions::ElementWise, fmt_str: "", ..Default::default() }, @@ -1837,7 +1860,7 @@ where output_type, options: FunctionOptions { collect_groups: ApplyOptions::ApplyList, - auto_explode: true, + returns_scalar: true, fmt_str: "", ..Default::default() }, @@ -1870,10 +1893,10 @@ where function: SpecialEq::new(Arc::new(function)), output_type, options: FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, + collect_groups: ApplyOptions::GroupWise, // don't set this to true // this is for the caller to decide - auto_explode: returns_scalar, + returns_scalar, fmt_str: "", ..Default::default() }, diff --git a/crates/polars-plan/src/dsl/python_udf.rs b/crates/polars-plan/src/dsl/python_udf.rs index 52d12b06d218..098f71b27adb 100644 --- a/crates/polars-plan/src/dsl/python_udf.rs +++ b/crates/polars-plan/src/dsl/python_udf.rs @@ -205,7 +205,7 @@ impl Expr { let (collect_groups, name) = if agg_list { (ApplyOptions::ApplyList, MAP_LIST_NAME) } else { - (ApplyOptions::ApplyFlat, "python_udf") + (ApplyOptions::ElementWise, "python_udf") }; let return_dtype = func.output_type.clone(); diff --git a/crates/polars-plan/src/dsl/string.rs b/crates/polars-plan/src/dsl/string.rs index 7e74806d7f31..4f1fcc255020 100644 --- a/crates/polars-plan/src/dsl/string.rs +++ b/crates/polars-plan/src/dsl/string.rs @@ -13,7 +13,7 @@ impl StringNameSpace { strict: false, }), &[pat], - true, + false, true, ) } @@ -28,7 +28,7 @@ impl StringNameSpace { strict, }), &[pat], - true, + false, true, ) } @@ -38,7 +38,7 @@ impl StringNameSpace { self.0.map_many_private( FunctionExpr::StringExpr(StringFunction::EndsWith), &[sub], - true, + false, true, ) } @@ -48,7 +48,7 @@ impl StringNameSpace { self.0.map_many_private( FunctionExpr::StringExpr(StringFunction::StartsWith), &[sub], - true, + false, true, ) } @@ -131,7 +131,7 @@ impl StringNameSpace { self.0.map_many_private( StringFunction::CountMatches(literal).into(), &[pat], - true, + false, false, ) } @@ -142,7 +142,7 @@ impl StringNameSpace { self.0.map_many_private( StringFunction::Strptime(dtype, options).into(), &[ambiguous], - true, + false, false, ) } @@ -206,7 +206,8 @@ impl StringNameSpace { self.0 .apply_private(StringFunction::ConcatVertical(delimiter.to_owned()).into()) .with_function_options(|mut options| { - options.auto_explode = true; + options.returns_scalar = true; + options.collect_groups = ApplyOptions::GroupWise; options }) } diff --git a/crates/polars-plan/src/logical_plan/aexpr/mod.rs b/crates/polars-plan/src/logical_plan/aexpr/mod.rs index bed9f63fe32b..9bc134f0c840 100644 --- a/crates/polars-plan/src/logical_plan/aexpr/mod.rs +++ b/crates/polars-plan/src/logical_plan/aexpr/mod.rs @@ -146,6 +146,7 @@ pub enum AExpr { Take { expr: Node, idx: Node, + returns_scalar: bool, }, SortBy { expr: Node, @@ -267,7 +268,7 @@ impl AExpr { }, Cast { expr, .. } => container.push(*expr), Sort { expr, .. } => container.push(*expr), - Take { expr, idx } => { + Take { expr, idx, .. } => { container.push(*idx); // latest, so that it is popped first container.push(*expr); @@ -346,7 +347,7 @@ impl AExpr { *left = inputs[1]; return self; }, - Take { expr, idx } => { + Take { expr, idx, .. } => { *idx = inputs[0]; *expr = inputs[1]; return self; diff --git a/crates/polars-plan/src/logical_plan/conversion.rs b/crates/polars-plan/src/logical_plan/conversion.rs index 2aecd7a693fc..7417c4367059 100644 --- a/crates/polars-plan/src/logical_plan/conversion.rs +++ b/crates/polars-plan/src/logical_plan/conversion.rs @@ -31,9 +31,14 @@ pub fn to_aexpr(expr: Expr, arena: &mut Arena) -> Node { data_type, strict, }, - Expr::Take { expr, idx } => AExpr::Take { + Expr::Take { + expr, + idx, + returns_scalar, + } => AExpr::Take { expr: to_aexpr(*expr, arena), idx: to_aexpr(*idx, arena), + returns_scalar, }, Expr::Sort { expr, options } => AExpr::Sort { expr: to_aexpr(*expr, arena), @@ -399,12 +404,17 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { options, } }, - AExpr::Take { expr, idx } => { + AExpr::Take { + expr, + idx, + returns_scalar, + } => { let expr = node_to_expr(expr, expr_arena); let idx = node_to_expr(idx, expr_arena); Expr::Take { expr: Box::new(expr), idx: Box::new(idx), + returns_scalar, } }, AExpr::SortBy { diff --git a/crates/polars-plan/src/logical_plan/format.rs b/crates/polars-plan/src/logical_plan/format.rs index 4ce57b088cf6..3c98e3f35043 100644 --- a/crates/polars-plan/src/logical_plan/format.rs +++ b/crates/polars-plan/src/logical_plan/format.rs @@ -278,8 +278,16 @@ impl Debug for Expr { Filter { input, by } => { write!(f, "{input:?}.filter({by:?})") }, - Take { expr, idx } => { - write!(f, "{expr:?}.take({idx:?})") + Take { + expr, + idx, + returns_scalar, + } => { + if *returns_scalar { + write!(f, "{expr:?}.get({idx:?})") + } else { + write!(f, "{expr:?}.take({idx:?})") + } }, SubPlan(lf, _) => { write!(f, ".subplan({lf:?})") diff --git a/crates/polars-plan/src/logical_plan/iterator.rs b/crates/polars-plan/src/logical_plan/iterator.rs index 48e40f04dd11..e9577a1efbc3 100644 --- a/crates/polars-plan/src/logical_plan/iterator.rs +++ b/crates/polars-plan/src/logical_plan/iterator.rs @@ -15,7 +15,7 @@ macro_rules! push_expr { }, Cast { expr, .. } => $push(expr), Sort { expr, .. } => $push(expr), - Take { expr, idx } => { + Take { expr, idx, .. } => { $push(idx); $push(expr); }, diff --git a/crates/polars-plan/src/logical_plan/optimizer/fused.rs b/crates/polars-plan/src/logical_plan/optimizer/fused.rs index 982a0b04d6d3..dbe19bef9d06 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/fused.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/fused.rs @@ -4,7 +4,7 @@ pub struct FusedArithmetic {} fn get_expr(input: Vec, op: FusedOperator) -> AExpr { let mut options = FunctionOptions { - collect_groups: ApplyOptions::ApplyFlat, + collect_groups: ApplyOptions::ElementWise, cast_to_supertypes: true, ..Default::default() }; diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs index 9d83ec2edfa4..d9df69fc82c1 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/utils.rs @@ -97,7 +97,7 @@ pub(super) fn predicate_is_sort_boundary(node: Node, expr_arena: &Arena) // group sensitive and doesn't auto-explode (e.g. is a reduction/aggregation // like sum, min, etc). // function that match this are `cumsum`, `shift`, `sort`, etc. - options.is_groups_sensitive() && !options.auto_explode + options.is_groups_sensitive() && !options.returns_scalar }, _ => false, }; @@ -120,8 +120,8 @@ pub(super) fn predicate_is_pushdown_boundary(node: Node, expr_arena: &Arena + if matches!(options.collect_groups, ApplyOptions::ElementWise) => { if let AnonymousFunction { input, diff --git a/crates/polars-plan/src/logical_plan/options.rs b/crates/polars-plan/src/logical_plan/options.rs index 6dc3edbcd16b..1b169ddfddbc 100644 --- a/crates/polars-plan/src/logical_plan/options.rs +++ b/crates/polars-plan/src/logical_plan/options.rs @@ -148,13 +148,13 @@ pub enum ApplyOptions { /// Collect groups to a list and apply the function over the groups. /// This can be important in aggregation context. // e.g. [g1, g1, g2] -> [[g1, g1], g2] - ApplyGroups, + GroupWise, // collect groups to a list and then apply // e.g. [g1, g1, g2] -> list([g1, g1, g2]) ApplyList, // do not collect before apply // e.g. [g1, g1, g2] -> [g1, g1, g2] - ApplyFlat, + ElementWise, } // a boolean that can only be set to `false` safely @@ -191,7 +191,7 @@ pub struct FunctionOptions { /// /// this also accounts for regex expansion pub input_wildcard_expansion: bool, - /// automatically explode on unit length it ran as final aggregation. + /// Automatically explode on unit length if it ran as final aggregation. /// /// this is the case for aggregations like sum, min, covariance etc. /// We need to know this because we cannot see the difference between @@ -201,7 +201,7 @@ pub struct FunctionOptions { /// /// head_1(x) -> {1} /// sum(x) -> {4} - pub auto_explode: bool, + pub returns_scalar: bool, // if the expression and its inputs should be cast to supertypes pub cast_to_supertypes: bool, // The physical expression may rename the output of this function. @@ -225,7 +225,7 @@ impl FunctionOptions { /// - Sorts /// - Counts pub fn is_groups_sensitive(&self) -> bool { - matches!(self.collect_groups, ApplyOptions::ApplyGroups) + matches!(self.collect_groups, ApplyOptions::GroupWise) } #[cfg(feature = "fused")] @@ -240,9 +240,9 @@ impl FunctionOptions { impl Default for FunctionOptions { fn default() -> Self { FunctionOptions { - collect_groups: ApplyOptions::ApplyGroups, + collect_groups: ApplyOptions::GroupWise, input_wildcard_expansion: false, - auto_explode: false, + returns_scalar: false, fmt_str: "", cast_to_supertypes: false, allow_rename: false, diff --git a/crates/polars-plan/src/utils.rs b/crates/polars-plan/src/utils.rs index 54644dff2bb8..70ee83d05f9a 100644 --- a/crates/polars-plan/src/utils.rs +++ b/crates/polars-plan/src/utils.rs @@ -99,7 +99,7 @@ pub(crate) fn aexpr_is_elementwise(current_node: Node, arena: &Arena) -> use AExpr::*; match e { AnonymousFunction { options, .. } | Function { options, .. } => { - !matches!(options.collect_groups, ApplyOptions::ApplyGroups) + !matches!(options.collect_groups, ApplyOptions::GroupWise) }, Column(_) | Alias(_, _) @@ -147,6 +147,17 @@ pub(crate) fn has_leaf_literal(e: &Expr) -> bool { }, } } +/// Check if leaf expression is a literal +#[cfg(feature = "is_in")] +pub(crate) fn all_leaf_literal(e: &Expr) -> bool { + match e { + Expr::Literal(_) => true, + _ => { + let roots = expr_to_root_column_exprs(e); + roots.iter().all(|e| matches!(e, Expr::Literal(_))) + }, + } +} pub fn has_null(current_expr: &Expr) -> bool { has_expr(current_expr, |e| { diff --git a/py-polars/docs/source/reference/expressions/modify_select.rst b/py-polars/docs/source/reference/expressions/modify_select.rst index d40fffd24a7f..3514a44ae087 100644 --- a/py-polars/docs/source/reference/expressions/modify_select.rst +++ b/py-polars/docs/source/reference/expressions/modify_select.rst @@ -27,6 +27,7 @@ Manipulation/selection Expr.flatten Expr.floor Expr.forward_fill + Expr.get Expr.head Expr.inspect Expr.interpolate diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 542a75cad7c4..988fd08de9e4 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -2331,7 +2331,60 @@ def take( ... "value": [1, 98, 2, 3, 99, 4], ... } ... ) - >>> df.group_by("group", maintain_order=True).agg(pl.col("value").take(1)) + >>> df.group_by("group", maintain_order=True).agg(pl.col("value").take([2, 1])) + shape: (2, 2) + ┌───────┬───────────┐ + │ group ┆ value │ + │ --- ┆ --- │ + │ str ┆ list[i64] │ + ╞═══════╪═══════════╡ + │ one ┆ [2, 98] │ + │ two ┆ [4, 99] │ + └───────┴───────────┘ + + See Also + -------- + Expr.get : Take a single value + + """ + if isinstance(indices, list) or ( + _check_for_numpy(indices) and isinstance(indices, np.ndarray) + ): + indices_lit = F.lit(pl.Series("", indices, dtype=UInt32))._pyexpr + else: + indices_lit = parse_as_expression(indices) # type: ignore[arg-type] + return self._from_pyexpr(self._pyexpr.take(indices_lit)) + + def get(self, index: int | Expr) -> Self: + """ + Return a single value by index. + + Parameters + ---------- + index + An expression that leads to a UInt32 index. + + Returns + ------- + Expr + Expression of the same data type. + + Examples + -------- + >>> df = pl.DataFrame( + ... { + ... "group": [ + ... "one", + ... "one", + ... "one", + ... "two", + ... "two", + ... "two", + ... ], + ... "value": [1, 98, 2, 3, 99, 4], + ... } + ... ) + >>> df.group_by("group", maintain_order=True).agg(pl.col("value").get(1)) shape: (2, 2) ┌───────┬───────┐ │ group ┆ value │ @@ -2343,13 +2396,8 @@ def take( └───────┴───────┘ """ - if isinstance(indices, list) or ( - _check_for_numpy(indices) and isinstance(indices, np.ndarray) - ): - indices_lit = F.lit(pl.Series("", indices, dtype=UInt32))._pyexpr - else: - indices_lit = parse_as_expression(indices) # type: ignore[arg-type] - return self._from_pyexpr(self._pyexpr.take(indices_lit)) + index_lit = parse_as_expression(index) + return self._from_pyexpr(self._pyexpr.get(index_lit)) @deprecate_renamed_parameter("periods", "n", version="0.19.11") def shift(self, n: int = 1) -> Self: @@ -9512,7 +9560,7 @@ def is_last(self) -> Self: """ return self.is_last_distinct() - def _register_plugin( + def register_plugin( self, *, lib: str, @@ -9521,7 +9569,7 @@ def _register_plugin( kwargs: dict[Any, Any] | None = None, is_elementwise: bool = False, input_wildcard_expansion: bool = False, - auto_explode: bool = False, + returns_scalar: bool = False, cast_to_supertypes: bool = False, ) -> Self: """ @@ -9551,9 +9599,10 @@ def _register_plugin( this will trigger fast paths. input_wildcard_expansion Expand expressions as input of this function. - auto_explode - Explode the results in a group_by. - This is recommended for aggregation functions. + returns_scalar + Automatically explode on unit length if it ran as final aggregation. + this is the case for aggregations like ``sum``, ``min``, ``covariance`` etc. + cast_to_supertypes Cast the input datatypes to their supertype. @@ -9567,7 +9616,8 @@ def _register_plugin( else: import pickle - serialized_kwargs = pickle.dumps(kwargs, protocol=2) + # Choose the highest protocol supported by https://docs.rs/serde-pickle/latest/serde_pickle/ + serialized_kwargs = pickle.dumps(kwargs, protocol=5) return self._from_pyexpr( self._pyexpr.register_plugin( @@ -9577,11 +9627,35 @@ def _register_plugin( serialized_kwargs, is_elementwise, input_wildcard_expansion, - auto_explode, + returns_scalar, cast_to_supertypes, ) ) + @deprecate_renamed_function("register_plugin", version="0.19.12") + def _register_plugin( + self, + *, + lib: str, + symbol: str, + args: list[IntoExpr] | None = None, + kwargs: dict[Any, Any] | None = None, + is_elementwise: bool = False, + input_wildcard_expansion: bool = False, + auto_explode: bool = False, + cast_to_supertypes: bool = False, + ) -> Self: + return self.register_plugin( + lib=lib, + symbol=symbol, + args=args, + kwargs=kwargs, + is_elementwise=is_elementwise, + input_wildcard_expansion=input_wildcard_expansion, + returns_scalar=auto_explode, + cast_to_supertypes=cast_to_supertypes, + ) + @property def bin(self) -> ExprBinaryNameSpace: """ diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs index 951ad3917eae..4e822b3a5d94 100644 --- a/py-polars/src/expr/general.rs +++ b/py-polars/src/expr/general.rs @@ -322,6 +322,10 @@ impl PyExpr { self.clone().inner.take(idx.inner).into() } + fn get(&self, idx: Self) -> Self { + self.clone().inner.get(idx.inner).into() + } + fn sort_by(&self, by: Vec, descending: Vec) -> Self { let by = by.into_iter().map(|e| e.inner).collect::>(); self.clone().inner.sort_by(by, descending).into() @@ -895,16 +899,16 @@ impl PyExpr { kwargs: Vec, is_elementwise: bool, input_wildcard_expansion: bool, - auto_explode: bool, + returns_scalar: bool, cast_to_supertypes: bool, ) -> PyResult { use polars_plan::prelude::*; let inner = self.inner.clone(); let collect_groups = if is_elementwise { - ApplyOptions::ApplyFlat + ApplyOptions::ElementWise } else { - ApplyOptions::ApplyGroups + ApplyOptions::GroupWise }; let mut input = Vec::with_capacity(args.len() + 1); input.push(inner); @@ -922,7 +926,7 @@ impl PyExpr { options: FunctionOptions { collect_groups, input_wildcard_expansion, - auto_explode, + returns_scalar, cast_to_supertypes, ..Default::default() }, diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 6988322dd67d..705f6bce069b 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -420,7 +420,7 @@ def test_take_misc(fruits_cars: pl.DataFrame) -> None: assert out[4, "B"].to_list() == [1, 4] out = df.sort("fruits").select( - [pl.col("B").reverse().take(pl.lit(1)).over("fruits"), "fruits"] + [pl.col("B").reverse().get(pl.lit(1)).over("fruits"), "fruits"] ) assert out[0, "B"] == 3 assert out[4, "B"] == 4 diff --git a/py-polars/tests/unit/namespaces/test_list.py b/py-polars/tests/unit/namespaces/test_list.py index 89ceff4e6a0a..5fb543d5119a 100644 --- a/py-polars/tests/unit/namespaces/test_list.py +++ b/py-polars/tests/unit/namespaces/test_list.py @@ -447,15 +447,19 @@ def test_list_function_group_awareness() -> None: assert df.group_by("group").agg( [ - pl.col("a").implode().list.get(0).alias("get"), - pl.col("a").implode().list.take([0]).alias("take"), - pl.col("a").implode().list.slice(0, 3).alias("slice"), + pl.col("a").get(0).alias("get_scalar"), + pl.col("a").take([0]).alias("take_no_implode"), + pl.col("a").implode().list.get(0).alias("implode_get"), + pl.col("a").implode().list.take([0]).alias("implode_take"), + pl.col("a").implode().list.slice(0, 3).alias("implode_slice"), ] ).sort("group").to_dict(False) == { "group": [0, 1, 2], - "get": [100, 105, 100], - "take": [[100], [105], [100]], - "slice": [[100, 103], [105, 106, 105], [100, 102]], + "get_scalar": [100, 105, 100], + "take_no_implode": [[100], [105], [100]], + "implode_get": [[100], [105], [100]], + "implode_take": [[[100]], [[105]], [[100]]], + "implode_slice": [[[100, 103]], [[105, 106, 105]], [[100, 102]]], } diff --git a/py-polars/tests/unit/operations/test_aggregations.py b/py-polars/tests/unit/operations/test_aggregations.py index 0cd22d787120..3d4c5a5a38d2 100644 --- a/py-polars/tests/unit/operations/test_aggregations.py +++ b/py-polars/tests/unit/operations/test_aggregations.py @@ -255,7 +255,7 @@ def test_err_on_implode_and_agg() -> None: pl.col("type").implode().list.head().alias("foo") ).to_dict(False) == { "type": ["water", "fire", "earth"], - "foo": [["water", "water"], ["fire"], ["earth"]], + "foo": [[["water", "water"]], [["fire"]], [["earth"]]], } # but not during a window function as the groups cannot be mapped back diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index 1902459bb05a..448bddac5ef3 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -269,7 +269,7 @@ def test_apply_after_take_in_group_by_3869() -> None: ) .group_by("k", maintain_order=True) .agg( - pl.col("v").take(pl.col("t").arg_max()).sqrt() + pl.col("v").get(pl.col("t").arg_max()).sqrt() ) # <- fails for sqrt, exp, log, pow, etc. ).to_dict(False) == {"k": ["a", "b"], "v": [1.4142135623730951, 2.0]} @@ -387,7 +387,7 @@ def test_group_by_dynamic_overlapping_groups_flat_apply_multiple_5038( def test_take_in_group_by() -> None: df = pl.DataFrame({"group": [1, 1, 1, 2, 2, 2], "values": [10, 200, 3, 40, 500, 6]}) assert df.group_by("group").agg( - pl.col("values").take(1) - pl.col("values").take(2) + pl.col("values").get(1) - pl.col("values").get(2) ).sort("group").to_dict(False) == {"group": [1, 2], "values": [197, 494]}