Skip to content

Commit

Permalink
fix: Properly account for multi-output column expressions in frame `sort` method
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Jul 30, 2024
1 parent dea0679 commit 2e77d31
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 17 deletions.
2 changes: 1 addition & 1 deletion crates/polars-core/src/frame/explode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ impl DataFrame {
let mut iter = on.iter().map(|v| {
schema
.get(v)
.ok_or_else(|| polars_err!(ColumnNotFound: "{}", v))
.ok_or_else(|| polars_err!(ColumnNotFound: "{:?}", v))
});
let mut st = iter.next().unwrap()?.clone();
for dt in iter {
Expand Down
22 changes: 12 additions & 10 deletions crates/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ impl DataFrame {
/// Get the index of the column.
///
/// Looks `name` up with `get_column_index` and returns its position.
///
/// # Errors
///
/// Returns a `ColumnNotFound` error carrying `name` when the lookup yields `None`.
fn check_name_to_idx(&self, name: &str) -> PolarsResult<usize> {
self.get_column_index(name)
// NOTE(review): the two `.ok_or_else` lines below look like the removed (`{}`) and
// added (`{:?}`) sides of the diff hunk rather than a real chained call — confirm.
.ok_or_else(|| polars_err!(ColumnNotFound: "{}", name))
.ok_or_else(|| polars_err!(ColumnNotFound: "{:?}", name))
}

fn check_already_present(&self, name: &str) -> PolarsResult<()> {
Expand Down Expand Up @@ -1361,7 +1361,7 @@ impl DataFrame {
/// Get column index of a [`Series`] by name.
///
/// Fallible counterpart of `get_column_index`: the `Option` it returns is turned
/// into a `PolarsResult`.
///
/// # Errors
///
/// Returns a `ColumnNotFound` error carrying `name` when no index is found.
pub fn try_get_column_index(&self, name: &str) -> PolarsResult<usize> {
self.get_column_index(name)
// NOTE(review): the two `.ok_or_else` lines below look like the removed (`{}`) and
// added (`{:?}`) sides of the diff hunk rather than a real chained call — confirm.
.ok_or_else(|| polars_err!(ColumnNotFound: "{}", name))
.ok_or_else(|| polars_err!(ColumnNotFound: "{:?}", name))
}

/// Select a single column by name.
Expand Down Expand Up @@ -1560,7 +1560,7 @@ impl DataFrame {
.map(|name| {
let idx = *name_to_idx
.get(name.as_str())
.ok_or_else(|| polars_err!(ColumnNotFound: "{}", name))?;
.ok_or_else(|| polars_err!(ColumnNotFound: "{:?}", name))?;
Ok(self
.select_at_idx(idx)
.unwrap()
Expand Down Expand Up @@ -1588,7 +1588,7 @@ impl DataFrame {
.map(|name| {
let idx = *name_to_idx
.get(name.as_str())
.ok_or_else(|| polars_err!(ColumnNotFound: "{}", name))?;
.ok_or_else(|| polars_err!(ColumnNotFound: "{:?}", name))?;
Ok(self.select_at_idx(idx).unwrap().clone())
})
.collect::<PolarsResult<Vec<_>>>()?
Expand Down Expand Up @@ -1696,7 +1696,7 @@ impl DataFrame {
/// ```
pub fn rename(&mut self, column: &str, name: &str) -> PolarsResult<&mut Self> {
self.select_mut(column)
.ok_or_else(|| polars_err!(ColumnNotFound: "{}", column))
.ok_or_else(|| polars_err!(ColumnNotFound: "{:?}", column))
.map(|s| s.rename(name))?;
let unique_names: AHashSet<&str, ahash::RandomState> =
AHashSet::from_iter(self.columns.iter().map(|s| s.name()));
Expand Down Expand Up @@ -1728,11 +1728,13 @@ impl DataFrame {
mut sort_options: SortMultipleOptions,
slice: Option<(i64, usize)>,
) -> PolarsResult<Self> {
if by_column.is_empty() {
polars_bail!(ComputeError: "No columns selected for sorting");
}
// note that the by_column argument also contains evaluated expression from
// polars-lazy that may not even be present in this dataframe.

// therefore when we try to set the first columns as sorted, we ignore the error
// as expressions are not present (they are renamed to _POLARS_SORT_COLUMN_i.
// polars-lazy that may not even be present in this dataframe. therefore
// when we try to set the first columns as sorted, we ignore the error as
// expressions are not present (they are renamed to _POLARS_SORT_COLUMN_i).
let first_descending = sort_options.descending[0];
let first_by_column = by_column[0].name().to_string();

Expand Down Expand Up @@ -2966,7 +2968,7 @@ impl DataFrame {
for col in cols {
let _ = schema
.get(&col)
.ok_or_else(|| polars_err!(ColumnNotFound: "{}", col))?;
.ok_or_else(|| polars_err!(ColumnNotFound: "{:?}", col))?;
}
}
DataFrame::new(new_cols)
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-io/src/partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ where
.iter()
.map(|x| {
let Some(i) = schema.index_of(x.as_ref()) else {
polars_bail!(ColumnNotFound: "{}", x.as_ref())
polars_bail!(ColumnNotFound: "{:?}", x.as_ref())
};
Ok(i)
})
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ pub fn to_alp_impl(
DslFunction::Drop(DropFunction { to_drop, strict }) => {
if strict {
for col_name in to_drop.iter() {
polars_ensure!(input_schema.contains(col_name), ColumnNotFound: "{col_name}");
polars_ensure!(input_schema.contains(col_name), ColumnNotFound: "{:?}", col_name);
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/functions/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ impl DslFunction {
polars_bail!(InvalidOperation: "expected column expression")
};

polars_ensure!(input_schema.contains(name), ColumnNotFound: "{name}");
polars_ensure!(input_schema.contains(name), ColumnNotFound: "{:?}", name);

Ok(name.clone())
})
Expand Down
15 changes: 15 additions & 0 deletions py-polars/polars/_utils/various.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,8 @@ def extend_bool(
n_match: int,
value_name: str,
match_name: str,
*,
condense: bool = False,
) -> Sequence[bool]:
"""Ensure the given bool or sequence of bools is the correct length."""
values = [value] * n_match if isinstance(value, bool) else value
Expand All @@ -542,6 +544,9 @@ def extend_bool(
f"does not match the length of `{match_name}` ({n_match})"
)
raise ValueError(msg)

if condense and len(set(values)) == 1:
return [values[0]]
return values


Expand Down Expand Up @@ -600,3 +605,13 @@ def re_escape(s: str) -> str:
# escapes _only_ those metachars with meaning to the rust regex crate
re_rust_metachars = r"\\?()|\[\]{}^$#&~.+*-"
return re.sub(f"([{re_rust_metachars}])", r"\\\1", s)


def has_multiple_outputs(x: Any) -> bool:
    """
    Report whether `x` is an Expr/PyExpr that may produce more than one output.

    Anything that is neither an `Expr` nor a `PyExpr` yields `False`.
    """
    # imported lazily, inside the function, exactly as the original does
    from polars.expr import Expr
    from polars.polars import PyExpr

    # public expression wrapper: query through its `meta` namespace
    if isinstance(x, Expr):
        outcome = x.meta.has_multiple_outputs()
        if outcome:
            return outcome

    # raw binding object: query through its flat `meta_*` method
    if isinstance(x, PyExpr):
        return x.meta_has_multiple_outputs()

    return False
17 changes: 14 additions & 3 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
_in_notebook,
_is_generator,
extend_bool,
has_multiple_outputs,
is_bool_sequence,
is_sequence,
issue_warning,
Expand Down Expand Up @@ -77,7 +78,7 @@
)
from polars.datatypes.group import DataTypeGroup
from polars.dependencies import import_optional, subprocess
from polars.exceptions import PerformanceWarning
from polars.exceptions import InvalidOperationError, PerformanceWarning
from polars.lazyframe.engine_config import GPUEngine
from polars.lazyframe.group_by import LazyGroupBy
from polars.lazyframe.in_process import InProcessQuery
Expand Down Expand Up @@ -1366,8 +1367,18 @@ def sort(
)

by = parse_into_list_of_expressions(by, *more_by)
descending = extend_bool(descending, len(by), "descending", "by")
nulls_last = extend_bool(nulls_last, len(by), "nulls_last", "by")
descending = extend_bool(descending, len(by), "descending", "by", condense=True)
nulls_last = extend_bool(nulls_last, len(by), "nulls_last", "by", condense=True)

if (len(descending) > 1 or len(nulls_last) > 1) and any(
has_multiple_outputs(x) for x in by
):
msg = (
"Cannot set mixed per-column `descending` or `nulls_last` "
"when `by` contains multi-output expressions"
)
raise InvalidOperationError(msg)

return self._from_pyldf(
self._ldf.sort_by_exprs(
by, descending, nulls_last, maintain_order, multithreaded
Expand Down
46 changes: 46 additions & 0 deletions py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,52 @@ def test_sort() -> None:
)


def test_sort_multi_output() -> None:
    """Check `DataFrame.sort` with regex columns and selectors as sort keys."""
    df = pl.DataFrame(
        {
            "dts": [date(2077, 10, 3), date(2077, 10, 2), date(2077, 10, 2)],
            "strs": ["abc", "def", "ghi"],
            "vals": [10.5, 20.3, 15.7],
        }
    )

    # default (ascending) ordering; every key set below must give the same frame
    expected_ascending = pl.DataFrame(
        {
            "dts": [date(2077, 10, 2), date(2077, 10, 2), date(2077, 10, 3)],
            "strs": ["ghi", "def", "abc"],
            "vals": [15.7, 20.3, 10.5],
        }
    )
    ascending_keys = (
        (pl.col("^(d|v).*$"),),
        (cs.temporal() | cs.numeric(),),
        (cs.temporal(), cs.numeric(), cs.binary()),
    )
    for sort_keys in ascending_keys:
        assert_frame_equal(expected_ascending, df.sort(*sort_keys))

    # uniform `descending` flags combined with multi-output keys
    expected_descending = pl.DataFrame(
        {
            "dts": [date(2077, 10, 3), date(2077, 10, 2), date(2077, 10, 2)],
            "strs": ["abc", "def", "ghi"],
            "vals": [10.5, 20.3, 15.7],
        }
    )
    descending_cases = (
        ((pl.col("^(d|v).*$"),), [True]),
        ((cs.temporal() | cs.numeric(),), [True]),
        ((cs.temporal(), cs.numeric()), [True, True]),
    )
    for sort_keys, flags in descending_cases:
        assert_frame_equal(
            expected_descending, df.sort(*sort_keys, descending=flags)
        )

    # the regex matches none of the frame's columns
    with pytest.raises(ComputeError, match="No columns selected for sorting"):
        df.sort(pl.col("^xxx$"))

    # mixed per-column flags are expected to be rejected for multi-output keys
    mixed_flags_message = (
        "Cannot set mixed per-column `descending` or "
        "`nulls_last` when `by` contains multi-output expressions"
    )
    with pytest.raises(InvalidOperationError, match=mixed_flags_message):
        df.sort(cs.temporal(), cs.numeric(), descending=[True, False])


def test_sort_maintain_order() -> None:
l1 = (
pl.LazyFrame({"A": [1] * 4, "B": ["A", "B", "C", "D"]})
Expand Down

0 comments on commit 2e77d31

Please sign in to comment.