Skip to content

Commit

Permalink
add selector expansion utility function, streamline "exclude" expression
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Aug 3, 2023
1 parent fa82d3b commit d25e3d9
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 56 deletions.
56 changes: 24 additions & 32 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,13 +781,14 @@ def keep_name(self) -> Self:

def exclude(
self,
columns: str | PolarsDataType | Iterable[str] | Iterable[PolarsDataType],
columns: str | PolarsDataType | Collection[str] | Collection[PolarsDataType],
*more_columns: str | PolarsDataType,
) -> Self:
"""
Exclude columns from a multi-column expression.
Only works after a wildcard or regex column selection.
Only works after a wildcard or regex column selection, and you cannot provide
both string column names *and* dtypes (you may prefer to use selectors instead).
Parameters
----------
Expand Down Expand Up @@ -862,43 +863,34 @@ def exclude(
└──────┘
"""
if more_columns:
if isinstance(columns, str):
columns_str = [columns]
columns_str.extend(more_columns) # type: ignore[arg-type]
return self._from_pyexpr(self._pyexpr.exclude(columns_str))
elif is_polars_dtype(columns):
dtypes = [columns]
dtypes.extend(more_columns)
return self._from_pyexpr(self._pyexpr.exclude_dtype(dtypes))
else:
raise TypeError(
f"Invalid input for `exclude`. Expected `str` or `DataType`, got {type(columns)!r}"
)

if isinstance(columns, str):
return self._from_pyexpr(self._pyexpr.exclude([columns]))
elif is_polars_dtype(columns):
return self._from_pyexpr(self._pyexpr.exclude_dtype([columns]))
elif isinstance(columns, Iterable):
columns_list = list(columns)
if not columns_list:
return self

item = columns_list[0]
exclude_cols: list[str] = []
exclude_dtypes: list[PolarsDataType] = []
for item in (
*(
columns
if isinstance(columns, Collection) and not isinstance(columns, str)
else [columns]
),
*more_columns,
):
if isinstance(item, str):
return self._from_pyexpr(self._pyexpr.exclude(columns_list))
exclude_cols.append(item)
elif is_polars_dtype(item):
return self._from_pyexpr(self._pyexpr.exclude_dtype(columns_list))
exclude_dtypes.append(item)
else:
raise TypeError(
"Invalid input for `exclude`. Expected iterable of type `str` or `DataType`,"
f" got iterable of type {type(item)!r}"
"Invalid input for `exclude`. Expected one or more `str`, "
f"`DataType`, or selector; found {type(item)!r} instead"
)
else:

if exclude_cols and exclude_dtypes:
raise TypeError(
f"Invalid input for `exclude`. Expected `str` or `DataType`, got {type(columns)!r}"
"Cannot exclude by both column name and dtype; use a selector instead"
)
elif exclude_dtypes:
return self._from_pyexpr(self._pyexpr.exclude_dtype(exclude_dtypes))
else:
return self._from_pyexpr(self._pyexpr.exclude(exclude_cols))

def pipe(
self,
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/functions/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@


if TYPE_CHECKING:
from typing import Literal
from typing import Collection, Literal

from polars import DataFrame, Expr, LazyFrame, Series
from polars.type_aliases import (
Expand Down Expand Up @@ -1648,7 +1648,7 @@ def arctan2d(y: str | Expr, x: str | Expr) -> Expr:


def exclude(
columns: str | PolarsDataType | Iterable[str] | Iterable[PolarsDataType],
columns: str | PolarsDataType | Collection[str] | Collection[PolarsDataType],
*more_columns: str | PolarsDataType,
) -> Expr:
"""
Expand Down
19 changes: 2 additions & 17 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from polars.io.ipc.anonymous_scan import _scan_ipc_fsspec
from polars.io.parquet.anonymous_scan import _scan_parquet_fsspec
from polars.lazyframe.groupby import LazyGroupBy
from polars.selectors import selector_column_names
from polars.selectors import expand_selectors
from polars.slice import LazyPolarsSlice
from polars.utils._parse_expr_input import (
parse_as_expression,
Expand Down Expand Up @@ -3370,22 +3370,7 @@ def drop(
└─────┘
"""
input_cols: list[str | SelectorType] = [
*(
columns
if isinstance(columns, Collection) and not isinstance(columns, str)
else [columns] # type: ignore[list-item]
),
*more_columns,
]
drop_cols: list[str] = []
for col in input_cols:
if isinstance(col, str):
drop_cols.append(col)
else:
selector_cols = selector_column_names(self, col)
drop_cols.extend(selector_cols)

drop_cols = expand_selectors(self, columns, *more_columns)
return self._from_pyldf(self._ldf.drop(drop_cols))

def rename(self, mapping: dict[str, str]) -> Self:
Expand Down
24 changes: 23 additions & 1 deletion py-polars/polars/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,27 @@
from typing_extensions import Self


def expand_selectors(
    frame: DataFrame | LazyFrame, items: Any, *more_items: Any
) -> list[Any]:
    """
    Expand any selectors in the given input.

    Parameters
    ----------
    frame
        Frame handed to ``selector_column_names`` when resolving a selector.
    items
        A single item, or a collection of items. A string is always treated
        as one item rather than as a collection of characters.
    *more_items
        Additional items, supplied as positional arguments.

    Returns
    -------
    list
        The input items in their original order, with every selector replaced
        by the values ``selector_column_names`` returns for it; non-selector
        items are passed through unchanged.

    """
    # Normalise the first argument so that a lone value (including a plain
    # string) is handled the same way as a one-element collection.
    if isinstance(items, str) or not isinstance(items, Collection):
        leading = [items]
    else:
        leading = items

    result: list[Any] = []
    for entry in (*leading, *more_items):
        if not is_selector(entry):
            result.append(entry)
            continue
        result.extend(selector_column_names(frame, entry))
    return result


def is_selector(obj: Any) -> bool:
"""
Indicate whether the given object/expression is a selector.
Expand Down Expand Up @@ -121,7 +142,7 @@ def __repr__(self) -> str:
set_ops = {"and": "&", "or": "|", "sub": "-"}
if selector_name in set_ops:
op = set_ops[selector_name]
return f" {op} ".join(repr(p) for p in params.values())
return "(%s)" % f" {op} ".join(repr(p) for p in params.values())
else:
str_params = ",".join(
(repr(v)[1:-1] if k.startswith("*") else f"{k}={v!r}")
Expand Down Expand Up @@ -1365,6 +1386,7 @@ def temporal() -> SelectorType:
"temporal",
"string",
"is_selector",
"expand_selectors",
"selector_column_names",
"SelectorType",
]
2 changes: 1 addition & 1 deletion py-polars/tests/unit/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,7 @@ def test_exclude(input: tuple[Any, ...], expected: list[str]) -> None:
assert df.select(pl.all().exclude(*input)).columns == expected


@pytest.mark.parametrize("input", [(5,), (["a"], "b"), (pl.Int64, "a")])
@pytest.mark.parametrize("input", [(5,), (["a"], date.today()), (pl.Int64, "a")])
def test_exclude_invalid_input(input: tuple[Any, ...]) -> None:
df = pl.DataFrame(schema=["a", "b", "c"])
with pytest.raises(TypeError):
Expand Down
7 changes: 4 additions & 3 deletions py-polars/tests/unit/test_selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,11 +331,12 @@ def test_selector_expansion() -> None:


def test_selector_repr() -> None:
assert repr(cs.all() - cs.first()) == "cs.all() - cs.first()"
assert repr(cs.all() - cs.first()) == "(cs.all() - cs.first())"
assert repr(~cs.starts_with("a", "b")) == "~cs.starts_with('a', 'b')"
assert repr(cs.float() | cs.by_name("x")) == "cs.float() | cs.by_name('x')"
assert repr(cs.float() | cs.by_name("x")) == "(cs.float() | cs.by_name('x'))"
assert (
repr(cs.integer() & cs.matches("z")) == "cs.integer() & cs.matches(pattern='z')"
repr(cs.integer() & cs.matches("z"))
== "(cs.integer() & cs.matches(pattern='z'))"
)


Expand Down

0 comments on commit d25e3d9

Please sign in to comment.