Skip to content

Commit

Permalink
fix: fix take return dtype in group context. (#11949)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Oct 23, 2023
1 parent cd612c9 commit ce3dd72
Show file tree
Hide file tree
Showing 38 changed files with 390 additions and 213 deletions.
66 changes: 47 additions & 19 deletions crates/polars-lazy/src/physical_plan/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,48 @@ use crate::physical_plan::state::ExecutionState;
use crate::prelude::*;

pub struct ApplyExpr {
pub inputs: Vec<Arc<dyn PhysicalExpr>>,
pub function: SpecialEq<Arc<dyn SeriesUdf>>,
pub expr: Expr,
pub collect_groups: ApplyOptions,
pub auto_explode: bool,
pub allow_rename: bool,
pub pass_name_to_apply: bool,
pub input_schema: Option<SchemaRef>,
pub allow_threading: bool,
pub check_lengths: bool,
pub allow_group_aware: bool,
inputs: Vec<Arc<dyn PhysicalExpr>>,
function: SpecialEq<Arc<dyn SeriesUdf>>,
expr: Expr,
collect_groups: ApplyOptions,
returns_scalar: bool,
allow_rename: bool,
pass_name_to_apply: bool,
input_schema: Option<SchemaRef>,
allow_threading: bool,
check_lengths: bool,
allow_group_aware: bool,
}

impl ApplyExpr {
pub(crate) fn new(
inputs: Vec<Arc<dyn PhysicalExpr>>,
function: SpecialEq<Arc<dyn SeriesUdf>>,
expr: Expr,
options: FunctionOptions,
allow_threading: bool,
input_schema: Option<SchemaRef>,
) -> Self {
#[cfg(debug_assertions)]
if matches!(options.collect_groups, ApplyOptions::ElementWise) && options.returns_scalar {
panic!("expr {} is not implemented correctly. 'returns_scalar' and 'elementwise' are mutually exclusive", expr)
}

Self {
inputs,
function,
expr,
collect_groups: options.collect_groups,
returns_scalar: options.returns_scalar,
allow_rename: options.allow_rename,
pass_name_to_apply: options.pass_name_to_apply,
input_schema,
allow_threading,
check_lengths: options.check_lengths(),
allow_group_aware: options.allow_group_aware,
}
}

pub(crate) fn new_minimal(
inputs: Vec<Arc<dyn PhysicalExpr>>,
function: SpecialEq<Arc<dyn SeriesUdf>>,
Expand All @@ -39,7 +67,7 @@ impl ApplyExpr {
function,
expr,
collect_groups,
auto_explode: false,
returns_scalar: false,
allow_rename: false,
pass_name_to_apply: false,
input_schema: None,
Expand Down Expand Up @@ -70,7 +98,7 @@ impl ApplyExpr {
ca: ListChunked,
) -> PolarsResult<AggregationContext<'a>> {
let all_unit_len = all_unit_length(&ca);
if all_unit_len && self.auto_explode {
if all_unit_len && self.returns_scalar {
ac.with_series(ca.explode().unwrap().into_series(), true, Some(&self.expr))?;
ac.update_groups = UpdateGroups::No;
} else {
Expand Down Expand Up @@ -289,8 +317,8 @@ impl PhysicalExpr for ApplyExpr {
ac.with_series(s, true, Some(&self.expr))?;
Ok(ac)
},
ApplyOptions::ApplyGroups => self.apply_single_group_aware(ac),
ApplyOptions::ApplyFlat => self.apply_single_elementwise(ac),
ApplyOptions::GroupWise => self.apply_single_group_aware(ac),
ApplyOptions::ElementWise => self.apply_single_elementwise(ac),
}
} else {
let mut acs = self.prepare_multiple_inputs(df, groups, state)?;
Expand All @@ -305,8 +333,8 @@ impl PhysicalExpr for ApplyExpr {
ac.with_series(s, true, Some(&self.expr))?;
Ok(ac)
},
ApplyOptions::ApplyGroups => self.apply_multiple_group_aware(acs, df),
ApplyOptions::ApplyFlat => {
ApplyOptions::GroupWise => self.apply_multiple_group_aware(acs, df),
ApplyOptions::ElementWise => {
if acs
.iter()
.any(|ac| matches!(ac.agg_state(), AggState::AggregatedList(_)))
Expand All @@ -328,7 +356,7 @@ impl PhysicalExpr for ApplyExpr {
self.expr.to_field(input_schema, Context::Default)
}
fn is_valid_aggregation(&self) -> bool {
matches!(self.collect_groups, ApplyOptions::ApplyGroups)
matches!(self.collect_groups, ApplyOptions::GroupWise)
}
#[cfg(feature = "parquet")]
fn as_stats_evaluator(&self) -> Option<&dyn polars_io::predicates::StatsEvaluator> {
Expand All @@ -345,7 +373,7 @@ impl PhysicalExpr for ApplyExpr {
}
}
fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
if self.inputs.len() == 1 && matches!(self.collect_groups, ApplyOptions::ApplyFlat) {
if self.inputs.len() == 1 && matches!(self.collect_groups, ApplyOptions::ElementWise) {
Some(self)
} else {
None
Expand Down
21 changes: 20 additions & 1 deletion crates/polars-lazy/src/physical_plan/expressions/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub struct TakeExpr {
pub(crate) phys_expr: Arc<dyn PhysicalExpr>,
pub(crate) idx: Arc<dyn PhysicalExpr>,
pub(crate) expr: Expr,
pub(crate) returns_scalar: bool,
}

impl TakeExpr {
Expand Down Expand Up @@ -101,12 +102,23 @@ impl PhysicalExpr for TakeExpr {
},
};
let taken = ac.flat_naive().take(&idx)?;

let taken = if self.returns_scalar {
taken
} else {
taken.as_list().into_series()
};

ac.with_series(taken, true, Some(&self.expr))?;
return Ok(ac);
},
AggState::AggregatedList(s) => s.list().unwrap().clone(),
AggState::AggregatedList(s) => {
polars_ensure!(!self.returns_scalar, ComputeError: "expected single index");
s.list().unwrap().clone()
},
// Maybe a literal as well, this needs a different path.
AggState::NotAggregated(_) => {
polars_ensure!(!self.returns_scalar, ComputeError: "expected single index");
let s = idx.aggregated();
s.list().unwrap().clone()
},
Expand Down Expand Up @@ -144,6 +156,13 @@ impl PhysicalExpr for TakeExpr {
},
};
let taken = ac.flat_naive().take(&idx.into_inner())?;

let taken = if self.returns_scalar {
taken
} else {
taken.as_list().into_series()
};

ac.with_series(taken, true, Some(&self.expr))?;
ac.with_update_groups(UpdateGroups::WithGroupsLen);
Ok(ac)
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-lazy/src/physical_plan/expressions/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,8 @@ impl WindowExpr {
},
Expr::Function { options, .. }
| Expr::AnonymousFunction { options, .. } => {
if options.auto_explode
&& matches!(options.collect_groups, ApplyOptions::ApplyGroups)
if options.returns_scalar
&& matches!(options.collect_groups, ApplyOptions::GroupWise)
{
agg_col = true;
}
Expand Down
55 changes: 25 additions & 30 deletions crates/polars-lazy/src/physical_plan/planner/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,18 @@ pub(crate) fn create_physical_expr(
node_to_expr(expression, expr_arena),
)))
},
Take { expr, idx } => {
Take {
expr,
idx,
returns_scalar,
} => {
let phys_expr = create_physical_expr(expr, ctxt, expr_arena, schema, state)?;
let phys_idx = create_physical_expr(idx, ctxt, expr_arena, schema, state)?;
Ok(Arc::new(TakeExpr {
phys_expr,
idx: phys_idx,
expr: node_to_expr(expression, expr_arena),
returns_scalar,
}))
},
SortBy {
Expand Down Expand Up @@ -391,7 +396,7 @@ pub(crate) fn create_physical_expr(
vec![input],
function,
node_to_expr(expression, expr_arena),
ApplyOptions::ApplyFlat,
ApplyOptions::ElementWise,
)))
},
_ => {
Expand Down Expand Up @@ -463,7 +468,7 @@ pub(crate) fn create_physical_expr(
options,
} => {
let is_reducing_aggregation =
options.auto_explode && matches!(options.collect_groups, ApplyOptions::ApplyGroups);
options.returns_scalar && matches!(options.collect_groups, ApplyOptions::GroupWise);
// will be reset in the function so get that here
let has_window = state.local.has_window;
let input = create_physical_expressions_check_state(
Expand All @@ -478,19 +483,14 @@ pub(crate) fn create_physical_expr(
},
)?;

Ok(Arc::new(ApplyExpr {
inputs: input,
Ok(Arc::new(ApplyExpr::new(
input,
function,
expr: node_to_expr(expression, expr_arena),
collect_groups: options.collect_groups,
auto_explode: options.auto_explode,
allow_rename: options.allow_rename,
pass_name_to_apply: options.pass_name_to_apply,
input_schema: schema.cloned(),
allow_threading: !state.has_cache,
check_lengths: options.check_lengths(),
allow_group_aware: options.allow_group_aware,
}))
node_to_expr(expression, expr_arena),
options,
!state.has_cache,
schema.cloned(),
)))
},
Function {
input,
Expand All @@ -499,7 +499,7 @@ pub(crate) fn create_physical_expr(
..
} => {
let is_reducing_aggregation =
options.auto_explode && matches!(options.collect_groups, ApplyOptions::ApplyGroups);
options.returns_scalar && matches!(options.collect_groups, ApplyOptions::GroupWise);
// will be reset in the function so get that here
let has_window = state.local.has_window;
let input = create_physical_expressions_check_state(
Expand All @@ -514,19 +514,14 @@ pub(crate) fn create_physical_expr(
},
)?;

Ok(Arc::new(ApplyExpr {
inputs: input,
function: function.into(),
expr: node_to_expr(expression, expr_arena),
collect_groups: options.collect_groups,
auto_explode: options.auto_explode,
allow_rename: options.allow_rename,
pass_name_to_apply: options.pass_name_to_apply,
input_schema: schema.cloned(),
allow_threading: !state.has_cache,
check_lengths: options.check_lengths(),
allow_group_aware: options.allow_group_aware,
}))
Ok(Arc::new(ApplyExpr::new(
input,
function.into(),
node_to_expr(expression, expr_arena),
options,
!state.has_cache,
schema.cloned(),
)))
},
Slice {
input,
Expand All @@ -553,7 +548,7 @@ pub(crate) fn create_physical_expr(
vec![input],
function,
node_to_expr(expression, expr_arena),
ApplyOptions::ApplyGroups,
ApplyOptions::GroupWise,
)))
},
Wildcard => panic!("should be no wildcard at this point"),
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-lazy/src/physical_plan/planner/lp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ fn partitionable_gb(
)
},
Function {input, options, ..} => {
matches!(options.collect_groups, ApplyOptions::ApplyFlat) && input.len() == 1 &&
matches!(options.collect_groups, ApplyOptions::ElementWise) && input.len() == 1 &&
!has_aggregation(input[0])
}
BinaryExpr {left, right, ..} => {
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-lazy/src/physical_plan/streaming/checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ pub(super) fn is_streamable(node: Node, expr_arena: &Arena<AExpr>, context: Cont
{
Context::Default => matches!(
options.collect_groups,
ApplyOptions::ApplyFlat | ApplyOptions::ApplyList
ApplyOptions::ElementWise | ApplyOptions::ApplyList
),
Context::Aggregation => matches!(options.collect_groups, ApplyOptions::ApplyFlat),
Context::Aggregation => matches!(options.collect_groups, ApplyOptions::ElementWise),
},
AExpr::Column(_) => {
seen_column = true;
Expand Down
19 changes: 8 additions & 11 deletions crates/polars-lazy/src/tests/aggregations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ fn take_aggregations() -> PolarsResult<()> {
.clone()
.lazy()
.group_by([col("user")])
.agg([col("book").take(col("count").arg_max()).alias("fav_book")])
.agg([col("book").get(col("count").arg_max()).alias("fav_book")])
.sort("user", Default::default())
.collect()?;

Expand Down Expand Up @@ -460,7 +460,7 @@ fn take_aggregations() -> PolarsResult<()> {
let out = df
.lazy()
.group_by([col("user")])
.agg([col("book").take(lit(0)).alias("take_lit")])
.agg([col("book").get(lit(0)).alias("take_lit")])
.sort("user", Default::default())
.collect()?;

Expand All @@ -484,7 +484,7 @@ fn test_take_consistency() -> PolarsResult<()> {
multithreaded: true,
maintain_order: false,
})
.take(lit(0))])
.get(lit(0))])
.collect()?;

let a = out.column("A")?;
Expand All @@ -502,7 +502,7 @@ fn test_take_consistency() -> PolarsResult<()> {
multithreaded: true,
maintain_order: false,
})
.take(lit(0))])
.get(lit(0))])
.collect()?;

let out = out.column("A")?;
Expand All @@ -521,18 +521,18 @@ fn test_take_consistency() -> PolarsResult<()> {
multithreaded: true,
maintain_order: false,
})
.take(lit(0))
.get(lit(0))
.alias("1"),
col("A")
.take(
.get(
col("A")
.arg_sort(SortOptions {
descending: true,
nulls_last: false,
multithreaded: true,
maintain_order: false,
})
.take(lit(0)),
.get(lit(0)),
)
.alias("2"),
])
Expand All @@ -556,10 +556,7 @@ fn test_take_in_groups() -> PolarsResult<()> {
let out = df
.lazy()
.sort("fruits", Default::default())
.select([col("B")
.take(lit(Series::new("", &[0u32])))
.over([col("fruits")])
.alias("taken")])
.select([col("B").get(lit(0u32)).over([col("fruits")]).alias("taken")])
.collect()?;

assert_eq!(
Expand Down
Loading

0 comments on commit ce3dd72

Please sign in to comment.