Skip to content

Commit

Permalink
fix take/get relation
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 23, 2023
1 parent 8eba908 commit 6b824ec
Show file tree
Hide file tree
Showing 24 changed files with 190 additions and 53 deletions.
1 change: 1 addition & 0 deletions crates/polars-lazy/src/physical_plan/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ impl ApplyExpr {
allow_threading: bool,
input_schema: Option<SchemaRef>,
) -> Self {
#[cfg(debug_assertions)]
if matches!(options.collect_groups, ApplyOptions::ElementWise) && options.returns_scalar {
panic!("expr {} is not implemented correctly. 'returns_scalar' and 'elementwise' are mutually exclusive", expr)
}
Expand Down
21 changes: 20 additions & 1 deletion crates/polars-lazy/src/physical_plan/expressions/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub struct TakeExpr {
pub(crate) phys_expr: Arc<dyn PhysicalExpr>,
pub(crate) idx: Arc<dyn PhysicalExpr>,
pub(crate) expr: Expr,
pub(crate) returns_scalar: bool,
}

impl TakeExpr {
Expand Down Expand Up @@ -101,12 +102,23 @@ impl PhysicalExpr for TakeExpr {
},
};
let taken = ac.flat_naive().take(&idx)?;

let taken = if self.returns_scalar {
taken
} else {
taken.as_list().into_series()
};

ac.with_series(taken, true, Some(&self.expr))?;
return Ok(ac);
},
AggState::AggregatedList(s) => s.list().unwrap().clone(),
AggState::AggregatedList(s) => {
polars_ensure!(!self.returns_scalar, ComputeError: "expected single index");
s.list().unwrap().clone()
},
// Maybe a literal as well, this needs a different path.
AggState::NotAggregated(_) => {
polars_ensure!(!self.returns_scalar, ComputeError: "expected single index");
let s = idx.aggregated();
s.list().unwrap().clone()
},
Expand Down Expand Up @@ -144,6 +156,13 @@ impl PhysicalExpr for TakeExpr {
},
};
let taken = ac.flat_naive().take(&idx.into_inner())?;

let taken = if self.returns_scalar {
taken
} else {
taken.as_list().into_series()
};

ac.with_series(taken, true, Some(&self.expr))?;
ac.with_update_groups(UpdateGroups::WithGroupsLen);
Ok(ac)
Expand Down
7 changes: 6 additions & 1 deletion crates/polars-lazy/src/physical_plan/planner/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,18 @@ pub(crate) fn create_physical_expr(
node_to_expr(expression, expr_arena),
)))
},
Take { expr, idx } => {
Take {
expr,
idx,
returns_scalar,
} => {
let phys_expr = create_physical_expr(expr, ctxt, expr_arena, schema, state)?;
let phys_idx = create_physical_expr(idx, ctxt, expr_arena, schema, state)?;
Ok(Arc::new(TakeExpr {
phys_expr,
idx: phys_idx,
expr: node_to_expr(expression, expr_arena),
returns_scalar,
}))
},
SortBy {
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-plan/src/dsl/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ impl BinaryNameSpace {
self.0.map_many_private(
FunctionExpr::BinaryExpr(BinaryFunction::Contains),
&[pat],
true,
false,
true,
)
}
Expand All @@ -19,7 +19,7 @@ impl BinaryNameSpace {
self.0.map_many_private(
FunctionExpr::BinaryExpr(BinaryFunction::EndsWith),
&[sub],
true,
false,
true,
)
}
Expand All @@ -29,7 +29,7 @@ impl BinaryNameSpace {
self.0.map_many_private(
FunctionExpr::BinaryExpr(BinaryFunction::StartsWith),
&[sub],
true,
false,
true,
)
}
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/dt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ impl DateLikeNameSpace {
self.0.map_many_private(
FunctionExpr::TemporalExpr(TemporalFunction::Truncate(offset)),
&[every, ambiguous],
true,
false,
false,
)
}
Expand Down
1 change: 1 addition & 0 deletions crates/polars-plan/src/dsl/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ pub enum Expr {
Take {
expr: Box<Expr>,
idx: Box<Expr>,
returns_scalar: bool,
},
SortBy {
expr: Box<Expr>,
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/functions/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub fn concat_str<E: AsRef<[Expr]>>(s: E, separator: &str) -> Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
input_wildcard_expansion: true,
returns_scalar: true,
returns_scalar: false,
..Default::default()
},
}
Expand Down
10 changes: 5 additions & 5 deletions crates/polars-plan/src/dsl/functions/horizontal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ pub fn all_horizontal<E: AsRef<[Expr]>>(exprs: E) -> Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
input_wildcard_expansion: true,
returns_scalar: true,
returns_scalar: false,
cast_to_supertypes: false,
allow_rename: true,
..Default::default()
Expand All @@ -220,7 +220,7 @@ pub fn any_horizontal<E: AsRef<[Expr]>>(exprs: E) -> Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
input_wildcard_expansion: true,
returns_scalar: true,
returns_scalar: false,
cast_to_supertypes: false,
allow_rename: true,
..Default::default()
Expand All @@ -243,7 +243,7 @@ pub fn max_horizontal<E: AsRef<[Expr]>>(exprs: E) -> Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
input_wildcard_expansion: true,
returns_scalar: true,
returns_scalar: false,
allow_rename: true,
..Default::default()
},
Expand All @@ -265,7 +265,7 @@ pub fn min_horizontal<E: AsRef<[Expr]>>(exprs: E) -> Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
input_wildcard_expansion: true,
returns_scalar: true,
returns_scalar: false,
allow_rename: true,
..Default::default()
},
Expand All @@ -284,7 +284,7 @@ pub fn sum_horizontal<E: AsRef<[Expr]>>(exprs: E) -> Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
input_wildcard_expansion: true,
returns_scalar: true,
returns_scalar: false,
cast_to_supertypes: false,
allow_rename: true,
..Default::default()
Expand Down
10 changes: 5 additions & 5 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ impl ListNameSpace {
self.0.map_many_private(
FunctionExpr::ListExpr(ListFunction::Get),
&[index],
true,
false,
false,
)
}
Expand All @@ -152,7 +152,7 @@ impl ListNameSpace {
self.0.map_many_private(
FunctionExpr::ListExpr(ListFunction::Take(null_on_oob)),
&[index],
true,
false,
false,
)
}
Expand Down Expand Up @@ -216,7 +216,7 @@ impl ListNameSpace {
self.0.map_many_private(
FunctionExpr::ListExpr(ListFunction::Slice),
&[offset, length],
true,
false,
false,
)
}
Expand Down Expand Up @@ -296,7 +296,7 @@ impl ListNameSpace {
.map_many_private(
FunctionExpr::ListExpr(ListFunction::Contains),
&[other],
true,
false,
false,
)
.with_function_options(|mut options| {
Expand All @@ -313,7 +313,7 @@ impl ListNameSpace {
.map_many_private(
FunctionExpr::ListExpr(ListFunction::CountMatches),
&[other],
true,
false,
false,
)
.with_function_options(|mut options| {
Expand Down
31 changes: 27 additions & 4 deletions crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,16 @@ impl Expr {
Expr::Take {
expr: Box::new(self),
idx: Box::new(idx.into()),
returns_scalar: false,
}
}

/// Take the values by a single index.
pub fn get<E: Into<Expr>>(self, idx: E) -> Self {
Expr::Take {
expr: Box::new(self),
idx: Box::new(idx.into()),
returns_scalar: true,
}
}

Expand Down Expand Up @@ -679,7 +689,7 @@ impl Expr {
self,
function_expr: FunctionExpr,
arguments: &[Expr],
auto_explode: bool,
returns_scalar: bool,
cast_to_supertypes: bool,
) -> Self {
let mut input = Vec::with_capacity(arguments.len() + 1);
Expand All @@ -691,7 +701,7 @@ impl Expr {
function: function_expr,
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
returns_scalar: auto_explode,
returns_scalar,
cast_to_supertypes,
..Default::default()
},
Expand Down Expand Up @@ -1060,12 +1070,25 @@ impl Expr {
let other = other.into();
let has_literal = has_leaf_literal(&other);

// lit(true).is_in() returns a scalar.
let returns_scalar = all_leaf_literal(&self);

let arguments = &[other];
// we don't have to apply on groups, so this is faster
if has_literal {
self.map_many_private(BooleanFunction::IsIn.into(), arguments, true, true)
self.map_many_private(
BooleanFunction::IsIn.into(),
arguments,
returns_scalar,
true,
)
} else {
self.apply_many_private(BooleanFunction::IsIn.into(), arguments, true, true)
self.apply_many_private(
BooleanFunction::IsIn.into(),
arguments,
returns_scalar,
true,
)
}
}

Expand Down
13 changes: 7 additions & 6 deletions crates/polars-plan/src/dsl/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ impl StringNameSpace {
strict: false,
}),
&[pat],
true,
false,
true,
)
}
Expand All @@ -28,7 +28,7 @@ impl StringNameSpace {
strict,
}),
&[pat],
true,
false,
true,
)
}
Expand All @@ -38,7 +38,7 @@ impl StringNameSpace {
self.0.map_many_private(
FunctionExpr::StringExpr(StringFunction::EndsWith),
&[sub],
true,
false,
true,
)
}
Expand All @@ -48,7 +48,7 @@ impl StringNameSpace {
self.0.map_many_private(
FunctionExpr::StringExpr(StringFunction::StartsWith),
&[sub],
true,
false,
true,
)
}
Expand Down Expand Up @@ -131,7 +131,7 @@ impl StringNameSpace {
self.0.map_many_private(
StringFunction::CountMatches(literal).into(),
&[pat],
true,
false,
false,
)
}
Expand All @@ -142,7 +142,7 @@ impl StringNameSpace {
self.0.map_many_private(
StringFunction::Strptime(dtype, options).into(),
&[ambiguous],
true,
false,
false,
)
}
Expand Down Expand Up @@ -207,6 +207,7 @@ impl StringNameSpace {
.apply_private(StringFunction::ConcatVertical(delimiter.to_owned()).into())
.with_function_options(|mut options| {
options.returns_scalar = true;
options.collect_groups = ApplyOptions::GroupWise;
options
})
}
Expand Down
5 changes: 3 additions & 2 deletions crates/polars-plan/src/logical_plan/aexpr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ pub enum AExpr {
Take {
expr: Node,
idx: Node,
returns_scalar: bool,
},
SortBy {
expr: Node,
Expand Down Expand Up @@ -267,7 +268,7 @@ impl AExpr {
},
Cast { expr, .. } => container.push(*expr),
Sort { expr, .. } => container.push(*expr),
Take { expr, idx } => {
Take { expr, idx, .. } => {
container.push(*idx);
// latest, so that it is popped first
container.push(*expr);
Expand Down Expand Up @@ -346,7 +347,7 @@ impl AExpr {
*left = inputs[1];
return self;
},
Take { expr, idx } => {
Take { expr, idx, .. } => {
*idx = inputs[0];
*expr = inputs[1];
return self;
Expand Down
Loading

0 comments on commit 6b824ec

Please sign in to comment.