diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 94cb0ef997e11..d176a6c67f63c 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -15,7 +15,6 @@ BinaryIO, Callable, ClassVar, - Collection, Generator, Iterable, Iterator, @@ -124,6 +123,8 @@ AsofJoinStrategy, AvroCompression, ClosedInterval, + ColumnNameOrSelectorCollection, + ColumnNameOrSelectorSequence, ColumnTotalsDefinition, ComparisonOperator, ConditionalFormatDict, @@ -4512,9 +4513,7 @@ def limit(self, n: int = 5) -> Self: def drop_nulls( self, - subset: ( - str | SelectorType | Collection[str] | Collection[SelectorType] | None - ) = None, + subset: ColumnNameOrSelectorCollection | None = None, ) -> DataFrame: """ Drop all rows that contain null values. @@ -6028,7 +6027,7 @@ def extend(self, other: DataFrame) -> Self: def drop( self, - columns: str | Collection[str] | SelectorType, + columns: ColumnNameOrSelectorCollection, *more_columns: str | SelectorType, ) -> DataFrame: """ @@ -6473,15 +6472,8 @@ def fill_nan(self, value: Expr | int | float | None) -> DataFrame: def explode( self, - columns: ( - str - | Expr - | SelectorType - | Sequence[str] - | Sequence[Expr] - | Sequence[SelectorType] - ), - *more_columns: str | Expr | SelectorType, + columns: str | Expr | Sequence[str | Expr], + *more_columns: str | Expr, ) -> DataFrame: """ Explode the dataframe to long format by exploding the given columns. @@ -6489,8 +6481,8 @@ def explode( Parameters ---------- columns - Name of the column(s) to explode. Columns must be of datatype List or Utf8. - Accepts ``col`` expressions as input as well. + Column names, expressions, or a selector defining them. The underlying + columns being exploded must be of List or Utf8 datatype. *more_columns Additional names of columns to explode, specified as positional arguments. @@ -6540,9 +6532,9 @@ def explode( def pivot( self, - values: str | Sequence[str] | SelectorType, - index: str | Sequence[str] | SelectorType, - columns: str | Sequence[str] | SelectorType, + values: ColumnNameOrSelectorSequence | None, + index: ColumnNameOrSelectorSequence | None, + columns: ColumnNameOrSelectorSequence | None, aggregate_function: PivotAgg | Expr | None | NoDefault = no_default, *, maintain_order: bool = True, @@ -6772,9 +6764,7 @@ def unstack( self, step: int, how: UnstackDirection = "vertical", - columns: ( - str | SelectorType | Sequence[str] | Sequence[SelectorType] | None - ) = None, + columns: ColumnNameOrSelectorSequence | None = None, fill_values: list[Any] | None = None, ) -> DataFrame: """ @@ -6916,7 +6906,7 @@ def unstack( @overload def partition_by( self, - by: str | SelectorType | Iterable[str] | Iterable[SelectorType], + by: ColumnNameOrSelectorSequence, *more_by: str, maintain_order: bool = ..., include_key: bool = ..., @@ -6927,7 +6917,7 @@ def partition_by( @overload def partition_by( self, - by: str | SelectorType | Iterable[str] | Iterable[SelectorType], + by: ColumnNameOrSelectorSequence, *more_by: str, maintain_order: bool = ..., include_key: bool = ..., @@ -6937,7 +6927,7 @@ def partition_by( def partition_by( self, - by: str | SelectorType | Iterable[str] | Iterable[SelectorType], + by: ColumnNameOrSelectorSequence, *more_by: str | SelectorType, maintain_order: bool = True, include_key: bool = True, @@ -8005,9 +7995,7 @@ def quantile( def to_dummies( self, - columns: ( - str | SelectorType | Sequence[str] | Sequence[SelectorType] | None - ) = None, + columns: ColumnNameOrSelectorSequence | None = None, *, separator: str = "_", drop_first: bool = False, @@ -8086,9 +8074,7 @@ def to_dummies( def unique( self, - subset: ( - str | SelectorType | Collection[str] | Collection[SelectorType] | None - ) = None, + subset: ColumnNameOrSelectorCollection | None = None, *, keep: UniqueKeepStrategy = "any", maintain_order: bool = False, @@ -9105,7 +9091,7 @@ def to_struct(self, name: str) -> Series: def unnest( self, - columns: str | SelectorType | Sequence[str] | Sequence[SelectorType], + columns: ColumnNameOrSelectorCollection, *more_columns: str | SelectorType, ) -> Self: """ diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 2bc5a3b3ab6ff..727e0ec23a697 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -10,7 +10,6 @@ Any, Callable, ClassVar, - Collection, Iterable, NoReturn, Sequence, @@ -81,6 +80,8 @@ from polars.type_aliases import ( AsofJoinStrategy, ClosedInterval, + ColumnNameOrSelectorCollection, + ColumnNameOrSelectorSequence, CsvEncoding, FillNullStrategy, FrameInitTypes, @@ -3304,7 +3305,7 @@ def with_context(self, other: Self | list[Self]) -> Self: def drop( self, - columns: str | Collection[str] | SelectorType, + columns: ColumnNameOrSelectorCollection, *more_columns: str | SelectorType, ) -> Self: """ @@ -4303,15 +4304,8 @@ def quantile( def explode( self, - columns: ( - str - | Expr - | SelectorType - | Sequence[str] - | Sequence[Expr] - | Sequence[SelectorType] - ), - *more_columns: str | Expr | SelectorType, + columns: str | Expr | Sequence[str | Expr], + *more_columns: str | Expr, ) -> Self: """ Explode the dataframe to long format by exploding the given columns. @@ -4319,8 +4313,8 @@ def explode( Parameters ---------- columns - Name of the column(s) to explode. Columns must be of datatype List or Utf8. - Accepts ``col`` expressions as input as well. + Column names, expressions, or a selector defining them. The underlying + columns being exploded must be of List or Utf8 datatype. *more_columns Additional names of columns to explode, specified as positional arguments. @@ -4357,9 +4351,7 @@ def explode( def unique( self, - subset: ( - str | SelectorType | Collection[str] | Collection[SelectorType] | None - ) = None, + subset: ColumnNameOrSelectorCollection | None = None, *, keep: UniqueKeepStrategy = "any", maintain_order: bool = False, @@ -4444,9 +4436,7 @@ def unique( def drop_nulls( self, - subset: ( - str | SelectorType | Collection[str] | Collection[SelectorType] | None - ) = None, + subset: ColumnNameOrSelectorCollection | None = None, ) -> Self: """ Drop all rows that contain null values. @@ -4544,8 +4534,8 @@ def drop_nulls( def melt( self, - id_vars: str | Sequence[str] | SelectorType | None = None, - value_vars: str | Sequence[str] | SelectorType | None = None, + id_vars: ColumnNameOrSelectorSequence | None = None, + value_vars: ColumnNameOrSelectorSequence | None = None, variable_name: str | None = None, value_name: str | None = None, *, diff --git a/py-polars/polars/type_aliases.py b/py-polars/polars/type_aliases.py index c716fe6bbc4e3..66c3cfe5c7237 100644 --- a/py-polars/polars/type_aliases.py +++ b/py-polars/polars/type_aliases.py @@ -68,8 +68,14 @@ IntoExpr: TypeAlias = Union["Expr", PythonLiteral, "Series", None] ComparisonOperator: TypeAlias = Literal["eq", "neq", "gt", "lt", "gt_eq", "lt_eq"] -# selector type and column +# selector type, and related collection/sequence SelectorType: TypeAlias = "_selector_proxy_" +ColumnNameOrSelectorCollection: TypeAlias = Union[ + str, SelectorType, Collection[Union[str, SelectorType]] +] +ColumnNameOrSelectorSequence: TypeAlias = Union[ + str, SelectorType, Sequence[Union[str, SelectorType]] +] # User-facing string literal types # The following all have an equivalent Rust enum with the same name