From 6b824ecd3a9a54eb5b0d1e87486c242aa9f3bbff Mon Sep 17 00:00:00 2001 From: ritchie Date: Mon, 23 Oct 2023 13:39:29 +0200 Subject: [PATCH] fix take/get relation --- .../src/physical_plan/expressions/apply.rs | 1 + .../src/physical_plan/expressions/take.rs | 21 +++++- .../src/physical_plan/planner/expr.rs | 7 +- crates/polars-plan/src/dsl/binary.rs | 6 +- crates/polars-plan/src/dsl/dt.rs | 2 +- crates/polars-plan/src/dsl/expr.rs | 1 + .../polars-plan/src/dsl/functions/concat.rs | 2 +- .../src/dsl/functions/horizontal.rs | 10 +-- crates/polars-plan/src/dsl/list.rs | 10 +-- crates/polars-plan/src/dsl/mod.rs | 31 +++++++-- crates/polars-plan/src/dsl/string.rs | 13 ++-- .../polars-plan/src/logical_plan/aexpr/mod.rs | 5 +- .../src/logical_plan/conversion.rs | 14 +++- crates/polars-plan/src/logical_plan/format.rs | 12 +++- .../polars-plan/src/logical_plan/iterator.rs | 2 +- .../logical_plan/optimizer/simplify_expr.rs | 2 +- crates/polars-plan/src/utils.rs | 11 ++++ .../reference/expressions/modify_select.rst | 1 + py-polars/polars/expr/expr.py | 64 ++++++++++++++++--- py-polars/src/expr/general.rs | 4 ++ py-polars/tests/unit/dataframe/test_df.py | 2 +- py-polars/tests/unit/namespaces/test_list.py | 16 +++-- .../unit/operations/test_aggregations.py | 2 +- .../tests/unit/operations/test_group_by.py | 4 +- 24 files changed, 190 insertions(+), 53 deletions(-) diff --git a/crates/polars-lazy/src/physical_plan/expressions/apply.rs b/crates/polars-lazy/src/physical_plan/expressions/apply.rs index e0c3f329011b..95734356a988 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/apply.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/apply.rs @@ -36,6 +36,7 @@ impl ApplyExpr { allow_threading: bool, input_schema: Option, ) -> Self { + #[cfg(debug_assertions)] if matches!(options.collect_groups, ApplyOptions::ElementWise) && options.returns_scalar { panic!("expr {} is not implemented correctly. 'returns_scalar' and 'elementwise' are mutually exclusive", expr) } diff --git a/crates/polars-lazy/src/physical_plan/expressions/take.rs b/crates/polars-lazy/src/physical_plan/expressions/take.rs index 08d273978d18..b86c585cfeaf 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/take.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/take.rs @@ -12,6 +12,7 @@ pub struct TakeExpr { pub(crate) phys_expr: Arc, pub(crate) idx: Arc, pub(crate) expr: Expr, + pub(crate) returns_scalar: bool, } impl TakeExpr { @@ -101,12 +102,23 @@ impl PhysicalExpr for TakeExpr { }, }; let taken = ac.flat_naive().take(&idx)?; + + let taken = if self.returns_scalar { + taken + } else { + taken.as_list().into_series() + }; + ac.with_series(taken, true, Some(&self.expr))?; return Ok(ac); }, - AggState::AggregatedList(s) => s.list().unwrap().clone(), + AggState::AggregatedList(s) => { + polars_ensure!(!self.returns_scalar, ComputeError: "expected single index"); + s.list().unwrap().clone() + }, // Maybe a literal as well, this needs a different path. AggState::NotAggregated(_) => { + polars_ensure!(!self.returns_scalar, ComputeError: "expected single index"); let s = idx.aggregated(); s.list().unwrap().clone() }, @@ -144,6 +156,13 @@ impl PhysicalExpr for TakeExpr { }, }; let taken = ac.flat_naive().take(&idx.into_inner())?; + + let taken = if self.returns_scalar { + taken + } else { + taken.as_list().into_series() + }; + ac.with_series(taken, true, Some(&self.expr))?; ac.with_update_groups(UpdateGroups::WithGroupsLen); Ok(ac) diff --git a/crates/polars-lazy/src/physical_plan/planner/expr.rs b/crates/polars-lazy/src/physical_plan/planner/expr.rs index 4cbfebab732c..8e7c4a51063b 100644 --- a/crates/polars-lazy/src/physical_plan/planner/expr.rs +++ b/crates/polars-lazy/src/physical_plan/planner/expr.rs @@ -192,13 +192,18 @@ pub(crate) fn create_physical_expr( node_to_expr(expression, expr_arena), ))) }, - Take { expr, idx } => { + Take { + expr, + idx, + returns_scalar, + } => { let phys_expr = create_physical_expr(expr, ctxt, expr_arena, schema, state)?; let phys_idx = create_physical_expr(idx, ctxt, expr_arena, schema, state)?; Ok(Arc::new(TakeExpr { phys_expr, idx: phys_idx, expr: node_to_expr(expression, expr_arena), + returns_scalar, })) }, SortBy { diff --git a/crates/polars-plan/src/dsl/binary.rs b/crates/polars-plan/src/dsl/binary.rs index 5c8aab05c579..57c40df7ed6f 100644 --- a/crates/polars-plan/src/dsl/binary.rs +++ b/crates/polars-plan/src/dsl/binary.rs @@ -9,7 +9,7 @@ impl BinaryNameSpace { self.0.map_many_private( FunctionExpr::BinaryExpr(BinaryFunction::Contains), &[pat], - true, + false, true, ) } @@ -19,7 +19,7 @@ impl BinaryNameSpace { self.0.map_many_private( FunctionExpr::BinaryExpr(BinaryFunction::EndsWith), &[sub], - true, + false, true, ) } @@ -29,7 +29,7 @@ impl BinaryNameSpace { self.0.map_many_private( FunctionExpr::BinaryExpr(BinaryFunction::StartsWith), &[sub], - true, + false, true, ) } diff --git a/crates/polars-plan/src/dsl/dt.rs b/crates/polars-plan/src/dsl/dt.rs index 84e9023f5994..fd077a498086 100644 --- a/crates/polars-plan/src/dsl/dt.rs +++ b/crates/polars-plan/src/dsl/dt.rs @@ -172,7 +172,7 @@ impl DateLikeNameSpace { self.0.map_many_private( FunctionExpr::TemporalExpr(TemporalFunction::Truncate(offset)), &[every, ambiguous], - true, + false, false, ) } diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs index dd7b2eeec8fa..f319e9052d47 100644 --- a/crates/polars-plan/src/dsl/expr.rs +++ b/crates/polars-plan/src/dsl/expr.rs @@ -89,6 +89,7 @@ pub enum Expr { Take { expr: Box, idx: Box, + returns_scalar: bool, }, SortBy { expr: Box, diff --git a/crates/polars-plan/src/dsl/functions/concat.rs b/crates/polars-plan/src/dsl/functions/concat.rs index 68271772540c..1f8e91cf5cb2 100644 --- a/crates/polars-plan/src/dsl/functions/concat.rs +++ b/crates/polars-plan/src/dsl/functions/concat.rs @@ -12,7 +12,7 @@ pub fn concat_str>(s: E, separator: &str) -> Expr { options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, input_wildcard_expansion: true, - returns_scalar: true, + returns_scalar: false, ..Default::default() }, } diff --git a/crates/polars-plan/src/dsl/functions/horizontal.rs b/crates/polars-plan/src/dsl/functions/horizontal.rs index 4bc134dc2e6f..1b10f917282a 100644 --- a/crates/polars-plan/src/dsl/functions/horizontal.rs +++ b/crates/polars-plan/src/dsl/functions/horizontal.rs @@ -201,7 +201,7 @@ pub fn all_horizontal>(exprs: E) -> Expr { options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, input_wildcard_expansion: true, - returns_scalar: true, + returns_scalar: false, cast_to_supertypes: false, allow_rename: true, ..Default::default() @@ -220,7 +220,7 @@ pub fn any_horizontal>(exprs: E) -> Expr { options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, input_wildcard_expansion: true, - returns_scalar: true, + returns_scalar: false, cast_to_supertypes: false, allow_rename: true, ..Default::default() @@ -243,7 +243,7 @@ pub fn max_horizontal>(exprs: E) -> Expr { options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, input_wildcard_expansion: true, - returns_scalar: true, + returns_scalar: false, allow_rename: true, ..Default::default() }, @@ -265,7 +265,7 @@ pub fn min_horizontal>(exprs: E) -> Expr { options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, input_wildcard_expansion: true, - returns_scalar: true, + returns_scalar: false, allow_rename: true, ..Default::default() }, @@ -284,7 +284,7 @@ pub fn sum_horizontal>(exprs: E) -> Expr { options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, input_wildcard_expansion: true, - returns_scalar: true, + returns_scalar: false, cast_to_supertypes: false, allow_rename: true, ..Default::default() diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index c8741dba73ff..6232c29c0a46 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -137,7 +137,7 @@ impl ListNameSpace { self.0.map_many_private( FunctionExpr::ListExpr(ListFunction::Get), &[index], - true, + false, false, ) } @@ -152,7 +152,7 @@ impl ListNameSpace { self.0.map_many_private( FunctionExpr::ListExpr(ListFunction::Take(null_on_oob)), &[index], - true, + false, false, ) } @@ -216,7 +216,7 @@ impl ListNameSpace { self.0.map_many_private( FunctionExpr::ListExpr(ListFunction::Slice), &[offset, length], - true, + false, false, ) } @@ -296,7 +296,7 @@ impl ListNameSpace { .map_many_private( FunctionExpr::ListExpr(ListFunction::Contains), &[other], - true, + false, false, ) .with_function_options(|mut options| { @@ -313,7 +313,7 @@ impl ListNameSpace { .map_many_private( FunctionExpr::ListExpr(ListFunction::CountMatches), &[other], - true, + false, false, ) .with_function_options(|mut options| { diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index d3e2e62599bc..516eaa183639 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -444,6 +444,16 @@ impl Expr { Expr::Take { expr: Box::new(self), idx: Box::new(idx.into()), + returns_scalar: false, + } + } + + /// Take the values by a single index. + pub fn get>(self, idx: E) -> Self { + Expr::Take { + expr: Box::new(self), + idx: Box::new(idx.into()), + returns_scalar: true, } } @@ -679,7 +689,7 @@ impl Expr { self, function_expr: FunctionExpr, arguments: &[Expr], - auto_explode: bool, + returns_scalar: bool, cast_to_supertypes: bool, ) -> Self { let mut input = Vec::with_capacity(arguments.len() + 1); @@ -691,7 +701,7 @@ impl Expr { function: function_expr, options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, - returns_scalar: auto_explode, + returns_scalar, cast_to_supertypes, ..Default::default() }, @@ -1060,12 +1070,25 @@ impl Expr { let other = other.into(); let has_literal = has_leaf_literal(&other); + // lit(true).is_in() returns a scalar. + let returns_scalar = all_leaf_literal(&self); + let arguments = &[other]; // we don't have to apply on groups, so this is faster if has_literal { - self.map_many_private(BooleanFunction::IsIn.into(), arguments, true, true) + self.map_many_private( + BooleanFunction::IsIn.into(), + arguments, + returns_scalar, + true, + ) } else { - self.apply_many_private(BooleanFunction::IsIn.into(), arguments, true, true) + self.apply_many_private( + BooleanFunction::IsIn.into(), + arguments, + returns_scalar, + true, + ) } } diff --git a/crates/polars-plan/src/dsl/string.rs b/crates/polars-plan/src/dsl/string.rs index 5cd361046ac8..4f1fcc255020 100644 --- a/crates/polars-plan/src/dsl/string.rs +++ b/crates/polars-plan/src/dsl/string.rs @@ -13,7 +13,7 @@ impl StringNameSpace { strict: false, }), &[pat], - true, + false, true, ) } @@ -28,7 +28,7 @@ impl StringNameSpace { strict, }), &[pat], - true, + false, true, ) } @@ -38,7 +38,7 @@ impl StringNameSpace { self.0.map_many_private( FunctionExpr::StringExpr(StringFunction::EndsWith), &[sub], - true, + false, true, ) } @@ -48,7 +48,7 @@ impl StringNameSpace { self.0.map_many_private( FunctionExpr::StringExpr(StringFunction::StartsWith), &[sub], - true, + false, true, ) } @@ -131,7 +131,7 @@ impl StringNameSpace { self.0.map_many_private( StringFunction::CountMatches(literal).into(), &[pat], - true, + false, false, ) } @@ -142,7 +142,7 @@ impl StringNameSpace { self.0.map_many_private( StringFunction::Strptime(dtype, options).into(), &[ambiguous], - true, + false, false, ) } @@ -207,6 +207,7 @@ impl StringNameSpace { .apply_private(StringFunction::ConcatVertical(delimiter.to_owned()).into()) .with_function_options(|mut options| { options.returns_scalar = true; + options.collect_groups = ApplyOptions::GroupWise; options }) } diff --git a/crates/polars-plan/src/logical_plan/aexpr/mod.rs b/crates/polars-plan/src/logical_plan/aexpr/mod.rs index bed9f63fe32b..9bc134f0c840 100644 --- a/crates/polars-plan/src/logical_plan/aexpr/mod.rs +++ b/crates/polars-plan/src/logical_plan/aexpr/mod.rs @@ -146,6 +146,7 @@ pub enum AExpr { Take { expr: Node, idx: Node, + returns_scalar: bool, }, SortBy { expr: Node, @@ -267,7 +268,7 @@ impl AExpr { }, Cast { expr, .. } => container.push(*expr), Sort { expr, .. } => container.push(*expr), - Take { expr, idx } => { + Take { expr, idx, .. } => { container.push(*idx); // latest, so that it is popped first container.push(*expr); @@ -346,7 +347,7 @@ impl AExpr { *left = inputs[1]; return self; }, - Take { expr, idx } => { + Take { expr, idx, .. } => { *idx = inputs[0]; *expr = inputs[1]; return self; diff --git a/crates/polars-plan/src/logical_plan/conversion.rs b/crates/polars-plan/src/logical_plan/conversion.rs index 2aecd7a693fc..7417c4367059 100644 --- a/crates/polars-plan/src/logical_plan/conversion.rs +++ b/crates/polars-plan/src/logical_plan/conversion.rs @@ -31,9 +31,14 @@ pub fn to_aexpr(expr: Expr, arena: &mut Arena) -> Node { data_type, strict, }, - Expr::Take { expr, idx } => AExpr::Take { + Expr::Take { + expr, + idx, + returns_scalar, + } => AExpr::Take { expr: to_aexpr(*expr, arena), idx: to_aexpr(*idx, arena), + returns_scalar, }, Expr::Sort { expr, options } => AExpr::Sort { expr: to_aexpr(*expr, arena), @@ -399,12 +404,17 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { options, } }, - AExpr::Take { expr, idx } => { + AExpr::Take { + expr, + idx, + returns_scalar, + } => { let expr = node_to_expr(expr, expr_arena); let idx = node_to_expr(idx, expr_arena); Expr::Take { expr: Box::new(expr), idx: Box::new(idx), + returns_scalar, } }, AExpr::SortBy { diff --git a/crates/polars-plan/src/logical_plan/format.rs b/crates/polars-plan/src/logical_plan/format.rs index 3f6631163716..c48c1108a833 100644 --- a/crates/polars-plan/src/logical_plan/format.rs +++ b/crates/polars-plan/src/logical_plan/format.rs @@ -278,8 +278,16 @@ impl Debug for Expr { Filter { input, by } => { write!(f, "{input:?}.filter({by:?})") }, - Take { expr, idx } => { - write!(f, "{expr:?}.take({idx:?})") + Take { + expr, + idx, + returns_scalar, + } => { + if *returns_scalar { + write!(f, "{expr:?}.get({idx:?})") + } else { + write!(f, "{expr:?}.take({idx:?})") + } }, SubPlan(lf, _) => { write!(f, ".subplan({lf:?})") diff --git a/crates/polars-plan/src/logical_plan/iterator.rs b/crates/polars-plan/src/logical_plan/iterator.rs index 48e40f04dd11..e9577a1efbc3 100644 --- a/crates/polars-plan/src/logical_plan/iterator.rs +++ b/crates/polars-plan/src/logical_plan/iterator.rs @@ -15,7 +15,7 @@ macro_rules! push_expr { }, Cast { expr, .. } => $push(expr), Sort { expr, .. } => $push(expr), - Take { expr, idx } => { + Take { expr, idx, .. } => { $push(idx); $push(expr); }, diff --git a/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs b/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs index 1cd4baf87281..316485450187 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs @@ -402,7 +402,7 @@ fn string_addition_to_linear_concat( options: FunctionOptions { collect_groups: ApplyOptions::ElementWise, input_wildcard_expansion: true, - returns_scalar: true, + returns_scalar: false, ..Default::default() }, }), diff --git a/crates/polars-plan/src/utils.rs b/crates/polars-plan/src/utils.rs index 54c9b445b5ed..70ee83d05f9a 100644 --- a/crates/polars-plan/src/utils.rs +++ b/crates/polars-plan/src/utils.rs @@ -147,6 +147,17 @@ pub(crate) fn has_leaf_literal(e: &Expr) -> bool { }, } } +/// Check if leaf expression is a literal +#[cfg(feature = "is_in")] +pub(crate) fn all_leaf_literal(e: &Expr) -> bool { + match e { + Expr::Literal(_) => true, + _ => { + let roots = expr_to_root_column_exprs(e); + roots.iter().all(|e| matches!(e, Expr::Literal(_))) + }, + } +} pub fn has_null(current_expr: &Expr) -> bool { has_expr(current_expr, |e| { diff --git a/py-polars/docs/source/reference/expressions/modify_select.rst b/py-polars/docs/source/reference/expressions/modify_select.rst index d40fffd24a7f..3514a44ae087 100644 --- a/py-polars/docs/source/reference/expressions/modify_select.rst +++ b/py-polars/docs/source/reference/expressions/modify_select.rst @@ -27,6 +27,7 @@ Manipulation/selection Expr.flatten Expr.floor Expr.forward_fill + Expr.get Expr.head Expr.inspect Expr.interpolate diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 9bd679ede957..4f6a01f91fba 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -2331,7 +2331,60 @@ def take( ... "value": [1, 98, 2, 3, 99, 4], ... } ... ) - >>> df.group_by("group", maintain_order=True).agg(pl.col("value").take(1)) + >>> df.group_by("group", maintain_order=True).agg(pl.col("value").take([2, 1])) + shape: (2, 2) + ┌───────┬───────────┐ + │ group ┆ value │ + │ --- ┆ --- │ + │ str ┆ list[i64] │ + ╞═══════╪═══════════╡ + │ one ┆ [2, 98] │ + │ two ┆ [4, 99] │ + └───────┴───────────┘ + + See Also + -------- + Expr.get : Take a single value + + """ + if isinstance(indices, list) or ( + _check_for_numpy(indices) and isinstance(indices, np.ndarray) + ): + indices_lit = F.lit(pl.Series("", indices, dtype=UInt32))._pyexpr + else: + indices_lit = parse_as_expression(indices) # type: ignore[arg-type] + return self._from_pyexpr(self._pyexpr.take(indices_lit)) + + def get(self, index: int | Expr) -> Self: + """ + Return a single value by index. + + Parameters + ---------- + index + An expression that leads to a UInt32 index. + + Returns + ------- + Expr + Expression of the same data type. + + Examples + -------- + >>> df = pl.DataFrame( + ... { + ... "group": [ + ... "one", + ... "one", + ... "one", + ... "two", + ... "two", + ... "two", + ... ], + ... "value": [1, 98, 2, 3, 99, 4], + ... } + ... ) + >>> df.group_by("group", maintain_order=True).agg(pl.col("value").get(1)) shape: (2, 2) ┌───────┬───────┐ │ group ┆ value │ @@ -2343,13 +2396,8 @@ def take( └───────┴───────┘ """ - if isinstance(indices, list) or ( - _check_for_numpy(indices) and isinstance(indices, np.ndarray) - ): - indices_lit = F.lit(pl.Series("", indices, dtype=UInt32))._pyexpr - else: - indices_lit = parse_as_expression(indices) # type: ignore[arg-type] - return self._from_pyexpr(self._pyexpr.take(indices_lit)) + index_lit = parse_as_expression(index) + return self._from_pyexpr(self._pyexpr.get(index_lit)) @deprecate_renamed_parameter("periods", "n", version="0.19.11") def shift(self, n: int = 1) -> Self: diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs index 2e9f5aea94b6..686b2d0c277e 100644 --- a/py-polars/src/expr/general.rs +++ b/py-polars/src/expr/general.rs @@ -322,6 +322,10 @@ impl PyExpr { self.clone().inner.take(idx.inner).into() } + fn get(&self, idx: Self) -> Self { + self.clone().inner.get(idx.inner).into() + } + fn sort_by(&self, by: Vec, descending: Vec) -> Self { let by = by.into_iter().map(|e| e.inner).collect::>(); self.clone().inner.sort_by(by, descending).into() diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 6988322dd67d..705f6bce069b 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -420,7 +420,7 @@ def test_take_misc(fruits_cars: pl.DataFrame) -> None: assert out[4, "B"].to_list() == [1, 4] out = df.sort("fruits").select( - [pl.col("B").reverse().take(pl.lit(1)).over("fruits"), "fruits"] + [pl.col("B").reverse().get(pl.lit(1)).over("fruits"), "fruits"] ) assert out[0, "B"] == 3 assert out[4, "B"] == 4 diff --git a/py-polars/tests/unit/namespaces/test_list.py b/py-polars/tests/unit/namespaces/test_list.py index 89ceff4e6a0a..5fb543d5119a 100644 --- a/py-polars/tests/unit/namespaces/test_list.py +++ b/py-polars/tests/unit/namespaces/test_list.py @@ -447,15 +447,19 @@ def test_list_function_group_awareness() -> None: assert df.group_by("group").agg( [ - pl.col("a").implode().list.get(0).alias("get"), - pl.col("a").implode().list.take([0]).alias("take"), - pl.col("a").implode().list.slice(0, 3).alias("slice"), + pl.col("a").get(0).alias("get_scalar"), + pl.col("a").take([0]).alias("take_no_implode"), + pl.col("a").implode().list.get(0).alias("implode_get"), + pl.col("a").implode().list.take([0]).alias("implode_take"), + pl.col("a").implode().list.slice(0, 3).alias("implode_slice"), ] ).sort("group").to_dict(False) == { "group": [0, 1, 2], - "get": [100, 105, 100], - "take": [[100], [105], [100]], - "slice": [[100, 103], [105, 106, 105], [100, 102]], + "get_scalar": [100, 105, 100], + "take_no_implode": [[100], [105], [100]], + "implode_get": [[100], [105], [100]], + "implode_take": [[[100]], [[105]], [[100]]], + "implode_slice": [[[100, 103]], [[105, 106, 105]], [[100, 102]]], } diff --git a/py-polars/tests/unit/operations/test_aggregations.py b/py-polars/tests/unit/operations/test_aggregations.py index 0cd22d787120..3d4c5a5a38d2 100644 --- a/py-polars/tests/unit/operations/test_aggregations.py +++ b/py-polars/tests/unit/operations/test_aggregations.py @@ -255,7 +255,7 @@ def test_err_on_implode_and_agg() -> None: pl.col("type").implode().list.head().alias("foo") ).to_dict(False) == { "type": ["water", "fire", "earth"], - "foo": [["water", "water"], ["fire"], ["earth"]], + "foo": [[["water", "water"]], [["fire"]], [["earth"]]], } # but not during a window function as the groups cannot be mapped back diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index 1902459bb05a..448bddac5ef3 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -269,7 +269,7 @@ def test_apply_after_take_in_group_by_3869() -> None: ) .group_by("k", maintain_order=True) .agg( - pl.col("v").take(pl.col("t").arg_max()).sqrt() + pl.col("v").get(pl.col("t").arg_max()).sqrt() ) # <- fails for sqrt, exp, log, pow, etc. ).to_dict(False) == {"k": ["a", "b"], "v": [1.4142135623730951, 2.0]} @@ -387,7 +387,7 @@ def test_group_by_dynamic_overlapping_groups_flat_apply_multiple_5038( def test_take_in_group_by() -> None: df = pl.DataFrame({"group": [1, 1, 1, 2, 2, 2], "values": [10, 200, 3, 40, 500, 6]}) assert df.group_by("group").agg( - pl.col("values").take(1) - pl.col("values").take(2) + pl.col("values").get(1) - pl.col("values").get(2) ).sort("group").to_dict(False) == {"group": [1, 2], "values": [197, 494]}