diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index fe7edd24e828c..005c5778abd3e 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -7949,7 +7949,7 @@ def quantile( def to_dummies( self, - columns: str | Sequence[str] | None = None, + columns: str | Sequence[str] | SelectorType | None = None, *, separator: str = "_", drop_first: bool = False, @@ -7988,8 +7988,8 @@ def to_dummies( └───────┴───────┴───────┴───────┴───────┴───────┘ """ - if isinstance(columns, str): - columns = [columns] + if columns is not None: + columns = expand_selectors(self, columns) return self._from_pydf(self._df.to_dummies(columns, separator, drop_first)) def unique( diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 4ce0533f01214..18364b2661da8 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -805,7 +805,10 @@ def test_to_dummies() -> None: assert_frame_equal(res, expected) df = pl.DataFrame( - {"i": [1, 2, 3], "category": ["dog", "cat", "cat"]}, + { + "i": [1, 2, 3], + "category": ["dog", "cat", "cat"], + }, schema={"i": pl.Int32, "category": pl.Categorical}, ) expected = pl.DataFrame( @@ -816,8 +819,9 @@ def test_to_dummies() -> None: }, schema={"i": pl.Int32, "category|cat": pl.UInt8, "category|dog": pl.UInt8}, ) - result = df.to_dummies(columns=["category"], separator="|") - assert_frame_equal(result, expected) + for _cols in ("category", cs.string()): + result = df.to_dummies(columns=["category"], separator="|") + assert_frame_equal(result, expected) # test sorted fast path assert pl.DataFrame({"x": pl.arange(0, 3, eager=True)}).to_dummies("x").to_dict( diff --git a/py-polars/tests/unit/datatypes/test_struct.py b/py-polars/tests/unit/datatypes/test_struct.py index 747fd5ea9aaff..50a4f21315aa6 100644 --- a/py-polars/tests/unit/datatypes/test_struct.py +++ b/py-polars/tests/unit/datatypes/test_struct.py @@ -115,7 +115,7 @@ def test_struct_unnesting() -> None: } ) for cols in ("foo", cs.ends_with("oo")): - out = df.unnest(cols) + out = df.unnest(cols) # type: ignore[arg-type] assert_frame_equal(out, expected) out = (