From d25e3d9e63fcefce95de252072cb622890130c59 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Thu, 3 Aug 2023 15:47:49 +0400 Subject: [PATCH] add selector expansion utility function, streamline "exclude" expression --- py-polars/polars/expr/expr.py | 56 +++++++++++--------------- py-polars/polars/functions/lazy.py | 4 +- py-polars/polars/lazyframe/frame.py | 19 +-------- py-polars/polars/selectors.py | 24 ++++++++++- py-polars/tests/unit/test_exprs.py | 2 +- py-polars/tests/unit/test_selectors.py | 7 ++-- 6 files changed, 56 insertions(+), 56 deletions(-) diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 28eccc96f6917..fbf6a931b6e26 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -781,13 +781,14 @@ def keep_name(self) -> Self: def exclude( self, - columns: str | PolarsDataType | Iterable[str] | Iterable[PolarsDataType], + columns: str | PolarsDataType | Collection[str] | Collection[PolarsDataType], *more_columns: str | PolarsDataType, ) -> Self: """ Exclude columns from a multi-column expression. - Only works after a wildcard or regex column selection. + Only works after a wildcard or regex column selection, and you cannot provide + both string column names *and* dtypes (you may prefer to use selectors instead). Parameters ---------- @@ -862,43 +863,34 @@ def exclude( └──────┘ """ - if more_columns: - if isinstance(columns, str): - columns_str = [columns] - columns_str.extend(more_columns) # type: ignore[arg-type] - return self._from_pyexpr(self._pyexpr.exclude(columns_str)) - elif is_polars_dtype(columns): - dtypes = [columns] - dtypes.extend(more_columns) - return self._from_pyexpr(self._pyexpr.exclude_dtype(dtypes)) - else: - raise TypeError( - f"Invalid input for `exclude`. Expected `str` or `DataType`, got {type(columns)!r}" - ) - - if isinstance(columns, str): - return self._from_pyexpr(self._pyexpr.exclude([columns])) - elif is_polars_dtype(columns): - return self._from_pyexpr(self._pyexpr.exclude_dtype([columns])) - elif isinstance(columns, Iterable): - columns_list = list(columns) - if not columns_list: - return self - - item = columns_list[0] + exclude_cols: list[str] = [] + exclude_dtypes: list[PolarsDataType] = [] + for item in ( + *( + columns + if isinstance(columns, Collection) and not isinstance(columns, str) + else [columns] + ), + *more_columns, + ): if isinstance(item, str): - return self._from_pyexpr(self._pyexpr.exclude(columns_list)) + exclude_cols.append(item) elif is_polars_dtype(item): - return self._from_pyexpr(self._pyexpr.exclude_dtype(columns_list)) + exclude_dtypes.append(item) else: raise TypeError( - "Invalid input for `exclude`. Expected iterable of type `str` or `DataType`," - f" got iterable of type {type(item)!r}" + "Invalid input for `exclude`. Expected one or more `str`, " + f"`DataType`, or selector; found {type(item)!r} instead" ) - else: + + if exclude_cols and exclude_dtypes: raise TypeError( - f"Invalid input for `exclude`. Expected `str` or `DataType`, got {type(columns)!r}" + "Cannot exclude by both column name and dtype; use a selector instead" ) + elif exclude_dtypes: + return self._from_pyexpr(self._pyexpr.exclude_dtype(exclude_dtypes)) + else: + return self._from_pyexpr(self._pyexpr.exclude(exclude_cols)) def pipe( self, diff --git a/py-polars/polars/functions/lazy.py b/py-polars/polars/functions/lazy.py index 1a9c0f21cacfe..5f22b8841f1ed 100644 --- a/py-polars/polars/functions/lazy.py +++ b/py-polars/polars/functions/lazy.py @@ -37,7 +37,7 @@ if TYPE_CHECKING: - from typing import Literal + from typing import Collection, Literal from polars import DataFrame, Expr, LazyFrame, Series from polars.type_aliases import ( @@ -1648,7 +1648,7 @@ def arctan2d(y: str | Expr, x: str | Expr) -> Expr: def exclude( - columns: str | PolarsDataType | Iterable[str] | Iterable[PolarsDataType], + columns: str | PolarsDataType | Collection[str] | Collection[PolarsDataType], *more_columns: str | PolarsDataType, ) -> Expr: """ diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index df6422d12c6fb..ba2212aab5d0b 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -47,7 +47,7 @@ from polars.io.ipc.anonymous_scan import _scan_ipc_fsspec from polars.io.parquet.anonymous_scan import _scan_parquet_fsspec from polars.lazyframe.groupby import LazyGroupBy -from polars.selectors import selector_column_names +from polars.selectors import expand_selectors from polars.slice import LazyPolarsSlice from polars.utils._parse_expr_input import ( parse_as_expression, @@ -3370,22 +3370,7 @@ def drop( └─────┘ """ - input_cols: list[str | SelectorType] = [ - *( - columns - if isinstance(columns, Collection) and not isinstance(columns, str) - else [columns] # type: ignore[list-item] - ), - *more_columns, - ] - drop_cols: list[str] = [] - for col in input_cols: - if isinstance(col, str): - drop_cols.append(col) - else: - selector_cols = selector_column_names(self, col) - drop_cols.extend(selector_cols) - + drop_cols = expand_selectors(self, columns, *more_columns) return self._from_pyldf(self._ldf.drop(drop_cols)) def rename(self, mapping: dict[str, str]) -> Self: diff --git a/py-polars/polars/selectors.py b/py-polars/polars/selectors.py index 7029bd59c3110..e593fa0cf15e6 100644 --- a/py-polars/polars/selectors.py +++ b/py-polars/polars/selectors.py @@ -31,6 +31,27 @@ from typing_extensions import Self +def expand_selectors( + frame: DataFrame | LazyFrame, items: Any, *more_items: Any +) -> list[Any]: + """Expand any selectors in the given input.""" + expanded: list[Any] = [] + for item in ( + *( + items + if isinstance(items, Collection) and not isinstance(items, str) + else [items] + ), + *more_items, + ): + if is_selector(item): + selector_cols = selector_column_names(frame, item) + expanded.extend(selector_cols) + else: + expanded.append(item) + return expanded + + def is_selector(obj: Any) -> bool: """ Indicate whether the given object/expression is a selector. @@ -121,7 +142,7 @@ def __repr__(self) -> str: set_ops = {"and": "&", "or": "|", "sub": "-"} if selector_name in set_ops: op = set_ops[selector_name] - return f" {op} ".join(repr(p) for p in params.values()) + return "(%s)" % f" {op} ".join(repr(p) for p in params.values()) else: str_params = ",".join( (repr(v)[1:-1] if k.startswith("*") else f"{k}={v!r}") @@ -1365,6 +1386,7 @@ def temporal() -> SelectorType: "temporal", "string", "is_selector", + "expand_selectors", "selector_column_names", "SelectorType", ] diff --git a/py-polars/tests/unit/test_exprs.py b/py-polars/tests/unit/test_exprs.py index d88b1f9faf417..3948ff331a1b7 100644 --- a/py-polars/tests/unit/test_exprs.py +++ b/py-polars/tests/unit/test_exprs.py @@ -870,7 +870,7 @@ def test_exclude(input: tuple[Any, ...], expected: list[str]) -> None: assert df.select(pl.all().exclude(*input)).columns == expected -@pytest.mark.parametrize("input", [(5,), (["a"], "b"), (pl.Int64, "a")]) +@pytest.mark.parametrize("input", [(5,), (["a"], date.today()), (pl.Int64, "a")]) def test_exclude_invalid_input(input: tuple[Any, ...]) -> None: df = pl.DataFrame(schema=["a", "b", "c"]) with pytest.raises(TypeError): diff --git a/py-polars/tests/unit/test_selectors.py b/py-polars/tests/unit/test_selectors.py index a6280d51357fa..cc6ee5e00bee7 100644 --- a/py-polars/tests/unit/test_selectors.py +++ b/py-polars/tests/unit/test_selectors.py @@ -331,11 +331,12 @@ def test_selector_expansion() -> None: def test_selector_repr() -> None: - assert repr(cs.all() - cs.first()) == "cs.all() - cs.first()" + assert repr(cs.all() - cs.first()) == "(cs.all() - cs.first())" assert repr(~cs.starts_with("a", "b")) == "~cs.starts_with('a', 'b')" - assert repr(cs.float() | cs.by_name("x")) == "cs.float() | cs.by_name('x')" + assert repr(cs.float() | cs.by_name("x")) == "(cs.float() | cs.by_name('x'))" assert ( - repr(cs.integer() & cs.matches("z")) == "cs.integer() & cs.matches(pattern='z')" + repr(cs.integer() & cs.matches("z")) + == "(cs.integer() & cs.matches(pattern='z'))" )