Skip to content

Commit

Permalink
refactor(python): deprecate bins argument and rename to breaks in…
Browse files Browse the repository at this point in the history
… `Series.cut` (#9913)
  • Loading branch information
mcrumiller authored Jul 16, 2023
1 parent 2021249 commit f93e796
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 13 deletions.
2 changes: 1 addition & 1 deletion py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3298,7 +3298,7 @@ def cut(
breaks
A list of unique cut points.
labels
Labels to assign to bins. If given, the length must be len(probs) + 1.
Labels to assign to bins. If given, the length must be len(breaks) + 1.
left_closed
Whether intervals should be [) instead of the default of (]
include_breaks
Expand Down
15 changes: 8 additions & 7 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1604,9 +1604,10 @@ def to_dummies(self, separator: str = "_") -> DataFrame:
"""
return wrap_df(self._s.to_dummies(separator))

@deprecated_alias(bins="breaks")
def cut(
self,
bins: list[float],
breaks: list[float],
labels: list[str] | None = None,
break_point_label: str = "break_point",
category_label: str = "category",
Expand All @@ -1620,11 +1621,11 @@ def cut(
Parameters
----------
bins
Bins to create.
breaks
A list of unique cut points.
labels
Labels to assign to the bins. If given the length of labels must be
len(bins) + 1.
len(breaks) + 1.
break_point_label
Name given to the breakpoint column/field. Only used if series == False or
include_breaks == True
Expand Down Expand Up @@ -1707,14 +1708,14 @@ def cut(
return (
self.to_frame()
.with_columns(
F.col(n).cut(bins, labels, left_closed, True).alias(n + "_bin")
F.col(n).cut(breaks, labels, left_closed, True).alias(n + "_bin")
)
.unnest(n + "_bin")
.rename({"brk": break_point_label, n + "_bin": category_label})
)
res = (
self.to_frame()
.select(F.col(n).cut(bins, labels, left_closed, include_breaks))
.select(F.col(n).cut(breaks, labels, left_closed, include_breaks))
.to_series()
)
if include_breaks:
Expand Down Expand Up @@ -1743,7 +1744,7 @@ def qcut(
We expect quantiles ``0.0 <= quantile <= 1``
labels
Labels to assign to the quantiles. If given the length of labels must be
len(bins) + 1.
len(breaks) + 1.
break_point_label
Name given to the breakpoint column/field. Only used if series == False or
include_breaks == True
Expand Down
8 changes: 4 additions & 4 deletions py-polars/tests/unit/operations/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_corr() -> None:

def test_cut() -> None:
a = pl.Series("a", [v / 10 for v in range(-30, 30, 5)])
out = cast(pl.DataFrame, a.cut(bins=[-1, 1], series=False))
out = cast(pl.DataFrame, a.cut(breaks=[-1, 1], series=False))

assert out.shape == (12, 3)
assert out.filter(pl.col("break_point") < 1e9).to_dict(False) == {
Expand All @@ -50,7 +50,7 @@ def test_cut() -> None:
inf = float("inf")
df = pl.DataFrame({"a": list(range(5))})
ser = df.select("a").to_series()
assert cast(pl.DataFrame, ser.cut(bins=[-1, 1], series=False)).rows() == [
assert cast(pl.DataFrame, ser.cut(breaks=[-1, 1], series=False)).rows() == [
(0.0, 1.0, "(-1, 1]"),
(1.0, 1.0, "(-1, 1]"),
(2.0, inf, "(1, inf]"),
Expand Down Expand Up @@ -78,8 +78,8 @@ def test_cut() -> None:
)
np.random.seed(1)
a = pl.Series("a", np.random.randint(0, 10, 10))
out = cast(pl.DataFrame, a.cut(bins=[-1, 1], series=False))
out_s = cast(pl.Series, a.cut(bins=[-1, 1], series=True))
out = cast(pl.DataFrame, a.cut(breaks=[-1, 1], series=False))
out_s = cast(pl.Series, a.cut(breaks=[-1, 1], series=True))
assert out["a"].cast(int).series_equal(a)
# Compare strings and categoricals without a hassle
assert_frame_equal(expected_df, out, check_dtype=False)
Expand Down
4 changes: 3 additions & 1 deletion py-polars/tests/unit/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def test_error_on_reducing_map() -> None:
),
):
df.select(
pl.col("x").map(lambda x: x.cut(bins=[1, 2, 3], series=False)).over("group")
pl.col("x")
.map(lambda x: x.cut(breaks=[1, 2, 3], series=False))
.over("group")
)


Expand Down

0 comments on commit f93e796

Please sign in to comment.