From 9cc04a26cb45f0650332d995ab6291c3325af57d Mon Sep 17 00:00:00 2001 From: Marshall Date: Fri, 8 Mar 2024 07:06:16 -0500 Subject: [PATCH] fix(python): Add `drop_first` parameter to `Series.to_dummies` (#14846) --- py-polars/polars/series/series.py | 22 +++++++++++++++++++--- py-polars/tests/unit/series/test_series.py | 10 ++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 0e61efe8669f..6a1e9c4f127f 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -2108,7 +2108,9 @@ def quantile( """ return self._s.quantile(quantile, interpolation) - def to_dummies(self, separator: str = "_") -> DataFrame: + def to_dummies( + self, *, separator: str = "_", drop_first: bool = False + ) -> DataFrame: """ Get dummy/indicator variables. @@ -2116,6 +2118,8 @@ def to_dummies(self, separator: str = "_") -> DataFrame: ---------- separator Separator/delimiter used when generating column names. + drop_first + Remove the first category from the variable being encoded. Examples -------- @@ -2131,8 +2135,20 @@ def to_dummies(self, separator: str = "_") -> DataFrame: │ 0 ┆ 1 ┆ 0 │ │ 0 ┆ 0 ┆ 1 │ └─────┴─────┴─────┘ - """ - return wrap_df(self._s.to_dummies(separator)) + + >>> s.to_dummies(drop_first=True) + shape: (3, 2) + ┌─────┬─────┐ + │ a_2 ┆ a_3 │ + │ --- ┆ --- │ + │ u8 ┆ u8 │ + ╞═════╪═════╡ + │ 0 ┆ 0 │ + │ 1 ┆ 0 │ + │ 0 ┆ 1 │ + └─────┴─────┘ + """ + return wrap_df(self._s.to_dummies(separator, drop_first)) @overload def cut( diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 5f48a3347711..21faed0fcd92 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -1542,6 +1542,16 @@ def test_to_dummies() -> None: assert_frame_equal(result, expected) +def test_to_dummies_drop_first() -> None: + s = pl.Series("a", [1, 2, 3]) + result = s.to_dummies(drop_first=True) + expected = pl.DataFrame( + {"a_2": [0, 1, 0], "a_3": [0, 0, 1]}, + schema={"a_2": pl.UInt8, "a_3": pl.UInt8}, + ) + assert_frame_equal(result, expected) + + def test_chunk_lengths() -> None: s = pl.Series("a", [1, 2, 2, 3]) # this is a Series with one chunk, of length 4