Skip to content

Commit

Permalink
Fix cut/3 and qcut/3 when :include_breaks is false [ci skip] (#…
Browse files Browse the repository at this point in the history
…1009)

* Fix `cut/3` and `qcut/3` when `:include_breaks` is false

Also add docs regarding the `:left_close` and `:include_breaks` options.

Thanks @billylanchantin :D
Reference: #1007 (comment)

* Change default of `cut/qcut` to not include breaks
  • Loading branch information
philss authored Oct 23, 2024
1 parent 9845751 commit d865544
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 24 deletions.
33 changes: 28 additions & 5 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4824,21 +4824,37 @@ defmodule Explorer.Series do
bounds (e.g. `(-inf -1.0]`, `(-1.0, 1.0]`, `(1.0, inf]`)
* `:break_point_label` - The name given to the breakpoint column.
This is only relevant if `:include_breaks` is `true`.
Defaults to `break_point`.
* `:category_label` - The name given to the category column.
Defaults to `category`.
* `:left_close` - Set the intervals to be left-closed instead
of right-closed. Defaults to `false`.
* `:include_breaks` - Include a column with the right endpoint
of the bin that each observation falls in.
Defaults to `false`.
## Examples
iex> s = Explorer.Series.from_list([1.0, 2.0, 3.0])
iex> Explorer.Series.cut(s, [1.5, 2.5])
iex> Explorer.Series.cut(s, [1.5, 2.5], include_breaks: true)
#Explorer.DataFrame<
Polars[3 x 3]
values f64 [1.0, 2.0, 3.0]
break_point f64 [1.5, 2.5, Inf]
category category ["(-inf, 1.5]", "(1.5, 2.5]", "(2.5, inf]"]
>
iex> s = Explorer.Series.from_list([1.0, 2.0, 3.0])
iex> Explorer.Series.cut(s, [1.5, 2.5])
#Explorer.DataFrame<
Polars[3 x 2]
values f64 [1.0, 2.0, 3.0]
category category ["(-inf, 1.5]", "(1.5, 2.5]", "(2.5, inf]"]
>
"""
@doc type: :aggregation
def cut(series, bins, opts \\ []) do
Expand All @@ -4848,7 +4864,7 @@ defmodule Explorer.Series do
break_point_label: nil,
category_label: nil,
left_close: false,
include_breaks: true
include_breaks: false
)

apply_series(series, :cut, [
Expand All @@ -4874,6 +4890,7 @@ defmodule Explorer.Series do
bounds (e.g. `(-inf -1.0]`, `(-1.0, 1.0]`, `(1.0, inf]`)
* `:break_point_label` - The name given to the breakpoint column.
This is only relevant if `:include_breaks` is `true`.
Defaults to `break_point`.
* `:category_label` - The name given to the category column.
Expand All @@ -4882,14 +4899,20 @@ defmodule Explorer.Series do
* `:allow_duplicates` - If quantiles can have duplicated values.
Defaults to `false`.
* `:left_close` - Set the intervals to be left-closed instead
of right-closed. Defaults to `false`.
* `:include_breaks` - Include a column with the right endpoint
of the bin that each observation falls in.
Defaults to `false`.
## Examples
iex> s = Explorer.Series.from_list([1.0, 2.0, 3.0, 4.0, 5.0])
iex> Explorer.Series.qcut(s, [0.25, 0.75])
#Explorer.DataFrame<
Polars[5 x 3]
Polars[5 x 2]
values f64 [1.0, 2.0, 3.0, 4.0, 5.0]
break_point f64 [2.0, 2.0, 4.0, 4.0, Inf]
category category ["(-inf, 2]", "(-inf, 2]", "(2, 4]", "(2, 4]", "(4, inf]"]
>
"""
Expand All @@ -4902,7 +4925,7 @@ defmodule Explorer.Series do
category_label: nil,
allow_duplicates: false,
left_close: false,
include_breaks: true
include_breaks: false
)

apply_series(series, :qcut, [
Expand Down
47 changes: 31 additions & 16 deletions native/explorer/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,17 +219,25 @@ pub fn s_cut(
left_close,
include_breaks,
)?;
let mut cut_df = DataFrame::new(cut_series.struct_()?.fields_as_series())?;

let cut_df = cut_df.insert_column(0, series)?;
if include_breaks {
let mut cut_df = DataFrame::new(cut_series.struct_()?.fields_as_series())?;

cut_df.set_column_names([
"values",
break_point_label.unwrap_or("break_point"),
category_label.unwrap_or("category"),
])?;
let cut_df = cut_df.insert_column(0, series)?;

Ok(ExDataFrame::new(cut_df.clone()))
cut_df.set_column_names([
"values",
break_point_label.unwrap_or("break_point"),
category_label.unwrap_or("category"),
])?;

Ok(ExDataFrame::new(cut_df.clone()))
} else {
let mut cut_df = DataFrame::new(vec![series, cut_series])?;
cut_df.set_column_names(["values", category_label.unwrap_or("category")])?;

Ok(ExDataFrame::new(cut_df.clone()))
}
}

#[allow(clippy::too_many_arguments)]
Expand All @@ -255,16 +263,23 @@ pub fn s_qcut(
include_breaks,
)?;

let mut qcut_df = DataFrame::new(qcut_series.struct_()?.fields_as_series())?;
let qcut_df = qcut_df.insert_column(0, series)?;
if include_breaks {
let mut qcut_df = DataFrame::new(qcut_series.struct_()?.fields_as_series())?;
let qcut_df = qcut_df.insert_column(0, series)?;

qcut_df.set_column_names([
"values",
break_point_label.unwrap_or("break_point"),
category_label.unwrap_or("category"),
])?;
qcut_df.set_column_names([
"values",
break_point_label.unwrap_or("break_point"),
category_label.unwrap_or("category"),
])?;

Ok(ExDataFrame::new(qcut_df.clone()))
Ok(ExDataFrame::new(qcut_df.clone()))
} else {
let mut qcut_df = DataFrame::new(vec![series, qcut_series])?;
qcut_df.set_column_names(["values", category_label.unwrap_or("category")])?;

Ok(ExDataFrame::new(qcut_df.clone()))
}
}

#[rustler::nif(schedule = "DirtyCpu")]
Expand Down
38 changes: 35 additions & 3 deletions test/explorer/series_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5950,21 +5950,21 @@ defmodule Explorer.SeriesTest do
end

describe "categorisation functions" do
test "cut/6 with no nils" do
test "cut/3 with no nils" do
series = -30..30//5 |> Enum.map(&(&1 / 10)) |> Enum.to_list() |> Series.from_list()
df = Series.cut(series, [-1, 1])
freqs = Series.frequencies(df[:category])
assert Series.to_list(freqs[:values]) == ["(-inf, -1]", "(-1, 1]", "(1, inf]"]
assert Series.to_list(freqs[:counts]) == [5, 4, 4]
end

test "cut/6 with nils" do
test "cut/3 with nils" do
series = Series.from_list([1, 2, 3, nil, nil])
df = Series.cut(series, [2])
assert [_, _, _, nil, nil] = Series.to_list(df[:category])
end

test "cut/6 options" do
test "cut/3 options" do
series = Series.from_list([1, 2, 3])

assert_raise ArgumentError,
Expand All @@ -5973,6 +5973,7 @@ defmodule Explorer.SeriesTest do

df =
Series.cut(series, [2],
include_breaks: true,
labels: ["x", "y"],
break_point_label: "bp",
category_label: "cat"
Expand All @@ -5981,6 +5982,17 @@ defmodule Explorer.SeriesTest do
assert Explorer.DataFrame.names(df) == ["values", "bp", "cat"]
end

# Opting in with `include_breaks: true` yields a `break_point` column next to
# `values` and `category`; the break point of the last, open-ended bin is
# `:infinity`.
test "cut/3 with include breaks" do
series = Series.from_list([1.0, 2.0, 3.0])
df = Series.cut(series, [1.5, 2.5], include_breaks: true)

assert Explorer.DataFrame.to_columns(df, atom_keys: true) == %{
category: ["(-inf, 1.5]", "(1.5, 2.5]", "(2.5, inf]"],
break_point: [1.5, 2.5, :infinity],
values: [1.0, 2.0, 3.0]
}
end

test "qcut/3" do
series = Enum.to_list(-5..3) |> Series.from_list()
df = Series.qcut(series, [0.0, 0.25, 0.75])
Expand Down Expand Up @@ -6009,6 +6021,26 @@ defmodule Explorer.SeriesTest do

assert Series.to_list(freqs[:counts]) == [3, 2, 1]
end

# With `include_breaks: false` the resulting data frame is expected to hold only
# the `values` and `category` columns, i.e. no break-point column.
test "qcut/3 without include breaks" do
series = Enum.to_list(-5..3) |> Series.from_list()
df = Series.qcut(series, [0.0, 0.25, 0.75], include_breaks: false)

assert Explorer.DataFrame.to_columns(df, atom_keys: true) == %{
category: [
"(-inf, -5]",
"(-5, -3]",
"(-5, -3]",
"(-3, 1]",
"(-3, 1]",
"(-3, 1]",
"(-3, 1]",
"(1, inf]",
"(1, inf]"
],
values: [-5, -4, -3, -2, -1, 0, 1, 2, 3]
}
end
end

describe "join/2" do
Expand Down

0 comments on commit d865544

Please sign in to comment.