From 2e77d3112558f656581677a0aa291c69c27ac2c2 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Tue, 30 Jul 2024 23:34:10 +0400 Subject: [PATCH] fix: Properly account for multi-output column expressions in frame `sort` method --- crates/polars-core/src/frame/explode.rs | 2 +- crates/polars-core/src/frame/mod.rs | 22 +++++---- crates/polars-io/src/partition.rs | 2 +- .../src/plans/conversion/dsl_to_ir.rs | 2 +- crates/polars-plan/src/plans/functions/dsl.rs | 2 +- py-polars/polars/_utils/various.py | 15 ++++++ py-polars/polars/lazyframe/frame.py | 17 +++++-- py-polars/tests/unit/dataframe/test_df.py | 46 +++++++++++++++++++ 8 files changed, 91 insertions(+), 17 deletions(-) diff --git a/crates/polars-core/src/frame/explode.rs b/crates/polars-core/src/frame/explode.rs index 3854ccf5b49f..8318fb1d10af 100644 --- a/crates/polars-core/src/frame/explode.rs +++ b/crates/polars-core/src/frame/explode.rs @@ -308,7 +308,7 @@ impl DataFrame { let mut iter = on.iter().map(|v| { schema .get(v) - .ok_or_else(|| polars_err!(ColumnNotFound: "{}", v)) + .ok_or_else(|| polars_err!(ColumnNotFound: "{:?}", v)) }); let mut st = iter.next().unwrap()?.clone(); for dt in iter { diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index 3f0293832266..a1216a9e5efc 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -189,7 +189,7 @@ impl DataFrame { /// Get the index of the column. fn check_name_to_idx(&self, name: &str) -> PolarsResult { self.get_column_index(name) - .ok_or_else(|| polars_err!(ColumnNotFound: "{}", name)) + .ok_or_else(|| polars_err!(ColumnNotFound: "{:?}", name)) } fn check_already_present(&self, name: &str) -> PolarsResult<()> { @@ -1361,7 +1361,7 @@ impl DataFrame { /// Get column index of a [`Series`] by name. pub fn try_get_column_index(&self, name: &str) -> PolarsResult { self.get_column_index(name) - .ok_or_else(|| polars_err!(ColumnNotFound: "{}", name)) + .ok_or_else(|| polars_err!(ColumnNotFound: "{:?}", name)) } /// Select a single column by name. @@ -1560,7 +1560,7 @@ impl DataFrame { .map(|name| { let idx = *name_to_idx .get(name.as_str()) - .ok_or_else(|| polars_err!(ColumnNotFound: "{}", name))?; + .ok_or_else(|| polars_err!(ColumnNotFound: "{:?}", name))?; Ok(self .select_at_idx(idx) .unwrap() @@ -1588,7 +1588,7 @@ impl DataFrame { .map(|name| { let idx = *name_to_idx .get(name.as_str()) - .ok_or_else(|| polars_err!(ColumnNotFound: "{}", name))?; + .ok_or_else(|| polars_err!(ColumnNotFound: "{:?}", name))?; Ok(self.select_at_idx(idx).unwrap().clone()) }) .collect::>>()? @@ -1696,7 +1696,7 @@ impl DataFrame { /// ``` pub fn rename(&mut self, column: &str, name: &str) -> PolarsResult<&mut Self> { self.select_mut(column) - .ok_or_else(|| polars_err!(ColumnNotFound: "{}", column)) + .ok_or_else(|| polars_err!(ColumnNotFound: "{:?}", column)) .map(|s| s.rename(name))?; let unique_names: AHashSet<&str, ahash::RandomState> = AHashSet::from_iter(self.columns.iter().map(|s| s.name())); @@ -1728,11 +1728,13 @@ impl DataFrame { mut sort_options: SortMultipleOptions, slice: Option<(i64, usize)>, ) -> PolarsResult { + if by_column.is_empty() { + polars_bail!(ComputeError: "No columns selected for sorting"); + } // note that the by_column argument also contains evaluated expression from - // polars-lazy that may not even be present in this dataframe. - - // therefore when we try to set the first columns as sorted, we ignore the error - // as expressions are not present (they are renamed to _POLARS_SORT_COLUMN_i. + // polars-lazy that may not even be present in this dataframe. therefore + // when we try to set the first columns as sorted, we ignore the error as + // expressions are not present (they are renamed to _POLARS_SORT_COLUMN_i. let first_descending = sort_options.descending[0]; let first_by_column = by_column[0].name().to_string(); @@ -2966,7 +2968,7 @@ impl DataFrame { for col in cols { let _ = schema .get(&col) - .ok_or_else(|| polars_err!(ColumnNotFound: "{}", col))?; + .ok_or_else(|| polars_err!(ColumnNotFound: "{:?}", col))?; } } DataFrame::new(new_cols) diff --git a/crates/polars-io/src/partition.rs b/crates/polars-io/src/partition.rs index a2f852d5b1a4..b16224f48700 100644 --- a/crates/polars-io/src/partition.rs +++ b/crates/polars-io/src/partition.rs @@ -53,7 +53,7 @@ where .iter() .map(|x| { let Some(i) = schema.index_of(x.as_ref()) else { - polars_bail!(ColumnNotFound: "{}", x.as_ref()) + polars_bail!(ColumnNotFound: "{:?}", x.as_ref()) }; Ok(i) }) diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs index 3ee71042b78d..9a989eee11f8 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs @@ -604,7 +604,7 @@ pub fn to_alp_impl( DslFunction::Drop(DropFunction { to_drop, strict }) => { if strict { for col_name in to_drop.iter() { - polars_ensure!(input_schema.contains(col_name), ColumnNotFound: "{col_name}"); + polars_ensure!(input_schema.contains(col_name), ColumnNotFound: "{:?}", col_name); } } diff --git a/crates/polars-plan/src/plans/functions/dsl.rs b/crates/polars-plan/src/plans/functions/dsl.rs index 6c53e8b676f0..aea343c1859d 100644 --- a/crates/polars-plan/src/plans/functions/dsl.rs +++ b/crates/polars-plan/src/plans/functions/dsl.rs @@ -69,7 +69,7 @@ impl DslFunction { polars_bail!(InvalidOperation: "expected column expression") }; - polars_ensure!(input_schema.contains(name), ColumnNotFound: "{name}"); + polars_ensure!(input_schema.contains(name), ColumnNotFound: "{:?}", name); Ok(name.clone()) }) diff --git a/py-polars/polars/_utils/various.py b/py-polars/polars/_utils/various.py index b14a3841103c..8d936bd3e27d 100644 --- a/py-polars/polars/_utils/various.py +++ b/py-polars/polars/_utils/various.py @@ -533,6 +533,8 @@ def extend_bool( n_match: int, value_name: str, match_name: str, + *, + condense: bool = False, ) -> Sequence[bool]: """Ensure the given bool or sequence of bools is the correct length.""" values = [value] * n_match if isinstance(value, bool) else value @@ -542,6 +544,9 @@ def extend_bool( f"does not match the length of `{match_name}` ({n_match})" ) raise ValueError(msg) + + if condense and len(set(values)) == 1: + return [values[0]] return values @@ -600,3 +605,13 @@ def re_escape(s: str) -> str: # escapes _only_ those metachars with meaning to the rust regex crate re_rust_metachars = r"\\?()|\[\]{}^$#&~.+*-" return re.sub(f"([{re_rust_metachars}])", r"\\\1", s) + + +def has_multiple_outputs(x: Any) -> bool: + """Check if the given input is an Expr/PyExpr that could return multiple outputs.""" + from polars.expr import Expr + from polars.polars import PyExpr + + return (isinstance(x, Expr) and x.meta.has_multiple_outputs()) or ( + isinstance(x, PyExpr) and x.meta_has_multiple_outputs() + ) diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 97516f877f74..d47eb2f9bdd9 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -41,6 +41,7 @@ _in_notebook, _is_generator, extend_bool, + has_multiple_outputs, is_bool_sequence, is_sequence, issue_warning, @@ -77,7 +78,7 @@ ) from polars.datatypes.group import DataTypeGroup from polars.dependencies import import_optional, subprocess -from polars.exceptions import PerformanceWarning +from polars.exceptions import InvalidOperationError, PerformanceWarning from polars.lazyframe.engine_config import GPUEngine from polars.lazyframe.group_by import LazyGroupBy from polars.lazyframe.in_process import InProcessQuery @@ -1366,8 +1367,18 @@ def sort( ) by = parse_into_list_of_expressions(by, *more_by) - descending = extend_bool(descending, len(by), "descending", "by") - nulls_last = extend_bool(nulls_last, len(by), "nulls_last", "by") + descending = extend_bool(descending, len(by), "descending", "by", condense=True) + nulls_last = extend_bool(nulls_last, len(by), "nulls_last", "by", condense=True) + + if (len(descending) > 1 or len(nulls_last) > 1) and any( + has_multiple_outputs(x) for x in by + ): + msg = ( + "Cannot set mixed per-column `descending` or `nulls_last` " + "when `by` contains multi-output expressions" + ) + raise InvalidOperationError(msg) + return self._from_pyldf( self._ldf.sort_by_exprs( by, descending, nulls_last, maintain_order, multithreaded diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 2ed5a8a39b32..777490c10a99 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -306,6 +306,52 @@ def test_sort() -> None: ) +def test_sort_multi_output() -> None: + df = pl.DataFrame( + { + "dts": [date(2077, 10, 3), date(2077, 10, 2), date(2077, 10, 2)], + "strs": ["abc", "def", "ghi"], + "vals": [10.5, 20.3, 15.7], + } + ) + + expected = pl.DataFrame( + { + "dts": [date(2077, 10, 2), date(2077, 10, 2), date(2077, 10, 3)], + "strs": ["ghi", "def", "abc"], + "vals": [15.7, 20.3, 10.5], + } + ) + assert_frame_equal(expected, df.sort(pl.col("^(d|v).*$"))) + assert_frame_equal(expected, df.sort(cs.temporal() | cs.numeric())) + assert_frame_equal(expected, df.sort(cs.temporal(), cs.numeric(), cs.binary())) + + expected = pl.DataFrame( + { + "dts": [date(2077, 10, 3), date(2077, 10, 2), date(2077, 10, 2)], + "strs": ["abc", "def", "ghi"], + "vals": [10.5, 20.3, 15.7], + } + ) + assert_frame_equal(expected, df.sort(pl.col("^(d|v).*$"), descending=[True])) + assert_frame_equal( + expected, df.sort(cs.temporal() | cs.numeric(), descending=[True]) + ) + assert_frame_equal( + expected, df.sort(cs.temporal(), cs.numeric(), descending=[True, True]) + ) + + with pytest.raises(ComputeError, match="No columns selected for sorting"): + df.sort(pl.col("^xxx$")) + + with pytest.raises( + InvalidOperationError, + match="Cannot set mixed per-column `descending` or " + "`nulls_last` when `by` contains multi-output expressions", + ): + df.sort(cs.temporal(), cs.numeric(), descending=[True, False]) + + def test_sort_maintain_order() -> None: l1 = ( pl.LazyFrame({"A": [1] * 4, "B": ["A", "B", "C", "D"]})