Skip to content

Commit

Permalink
add selector support to "to_dummies" frame method
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Aug 3, 2023
1 parent a7864ae commit 3532ae6
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
6 changes: 3 additions & 3 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -7949,7 +7949,7 @@ def quantile(

def to_dummies(
self,
columns: str | Sequence[str] | None = None,
columns: str | Sequence[str] | SelectorType | None = None,
*,
separator: str = "_",
drop_first: bool = False,
Expand Down Expand Up @@ -7988,8 +7988,8 @@ def to_dummies(
└───────┴───────┴───────┴───────┴───────┴───────┘
"""
if isinstance(columns, str):
columns = [columns]
if columns is not None:
columns = expand_selectors(self, columns)
return self._from_pydf(self._df.to_dummies(columns, separator, drop_first))

def unique(
Expand Down
10 changes: 7 additions & 3 deletions py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,10 @@ def test_to_dummies() -> None:
assert_frame_equal(res, expected)

df = pl.DataFrame(
{"i": [1, 2, 3], "category": ["dog", "cat", "cat"]},
{
"i": [1, 2, 3],
"category": ["dog", "cat", "cat"],
},
schema={"i": pl.Int32, "category": pl.Categorical},
)
expected = pl.DataFrame(
Expand All @@ -816,8 +819,9 @@ def test_to_dummies() -> None:
},
schema={"i": pl.Int32, "category|cat": pl.UInt8, "category|dog": pl.UInt8},
)
result = df.to_dummies(columns=["category"], separator="|")
assert_frame_equal(result, expected)
for _cols in ("category", cs.string()):
result = df.to_dummies(columns=["category"], separator="|")
assert_frame_equal(result, expected)

# test sorted fast path
assert pl.DataFrame({"x": pl.arange(0, 3, eager=True)}).to_dummies("x").to_dict(
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/datatypes/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_struct_unnesting() -> None:
}
)
for cols in ("foo", cs.ends_with("oo")):
out = df.unnest(cols)
out = df.unnest(cols) # type: ignore[arg-type]
assert_frame_equal(out, expected)

out = (
Expand Down

0 comments on commit 3532ae6

Please sign in to comment.