Skip to content

Commit

Permalink
Expose options to Series' cut and qcut (#1007)
Browse files Browse the repository at this point in the history
* Add ":allow_duplicates" and ":left_close" opts to qcut

This may fix #1006

* Add more options to `Series.cut/3`, but without docs
  • Loading branch information
philss authored Oct 23, 2024
1 parent 29a93ea commit 9845751
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 23 deletions.
4 changes: 2 additions & 2 deletions lib/explorer/backend/lazy_series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1244,9 +1244,9 @@ defmodule Explorer.Backend.LazySeries do
at_every: 2,
categories: 1,
categorise: 2,
cut: 5,
cut: 7,
frequencies: 1,
qcut: 5,
qcut: 8,
mask: 2,
owner_import: 1,
owner_export: 1,
Expand Down
21 changes: 19 additions & 2 deletions lib/explorer/backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,26 @@ defmodule Explorer.Backend.Series do

# Categorisation

@callback cut(s, [float()], [String.t()] | nil, String.t() | nil, String.t() | nil) ::
# Backend contract for `cut`: seven positional arguments, result typed as `df`.
#
#   * `bins` - list of floats.
#   * `labels`, `break_point_label`, `category_label` - wrapped in `option/1`
#     (declared outside this view; presumably `t | nil` - TODO confirm).
#   * `left_close`, `include_breaks` - plain booleans; implementations receive
#     them already resolved, so they are never `nil`.
#
# NOTE(review): `s` and `df` are type aliases declared elsewhere in this module;
# they look like the backend series and dataframe types - confirm.
@callback cut(
s,
bins :: [float()],
labels :: option([String.t()]),
break_point_label :: option(String.t()),
category_label :: option(String.t()),
left_close :: boolean(),
include_breaks :: boolean()
) ::
df
@callback qcut(s, [float()], [String.t()] | nil, String.t() | nil, String.t() | nil) ::
# Backend contract for `qcut`: eight positional arguments, result typed as `df`.
#
#   * `quantiles` - list of floats.
#   * `labels`, `break_point_label`, `category_label` - wrapped in `option/1`
#     (declared outside this view; presumably `t | nil` - TODO confirm).
#   * `allow_duplicates`, `left_close`, `include_breaks` - plain booleans.
#
# Compared with `cut`, the only extra argument is `allow_duplicates`, placed
# before `left_close`; argument order matters because backends are invoked
# positionally.
@callback qcut(
s,
quantiles :: [float()],
labels :: option([String.t()]),
break_point_label :: option(String.t()),
category_label :: option(String.t()),
allow_duplicates :: boolean(),
left_close :: boolean(),
include_breaks :: boolean()
) ::
df

# Rolling
Expand Down
26 changes: 23 additions & 3 deletions lib/explorer/polars_backend/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -418,13 +418,33 @@ defmodule Explorer.PolarsBackend.Native do
def s_upcase(_s), do: err()
def s_unordered_distinct(_s), do: err()
def s_frequencies(_s), do: err()
def s_cut(_s, _bins, _labels, _break_point_label, _category_label), do: err()

# Elixir-side stub for the native `s_cut` function (arity 7).
# Every argument is ignored here; the body only calls `err()`, which is defined
# elsewhere in this module.
# NOTE(review): presumably this clause is replaced by the Rust NIF at load time
# and `err()` only runs when the native library is missing - confirm.
def s_cut(
      _s,
      _bins,
      _labels,
      _break_point_label,
      _category_label,
      _left_close,
      _include_breaks
    ) do
  err()
end

def s_substring(_s, _offset, _length), do: err()
def s_split(_s, _by), do: err()
def s_split_into(_s, _by, _num_fields), do: err()

def s_qcut(_s, _quantiles, _labels, _break_point_label, _category_label),
do: err()
# Elixir-side stub for the native `s_qcut` function (arity 8).
# Every argument is ignored here; the body only calls `err()`, which is defined
# elsewhere in this module.
# NOTE(review): presumably this clause is replaced by the Rust NIF at load time
# and `err()` only runs when the native library is missing - confirm.
def s_qcut(
      _s,
      _quantiles,
      _labels,
      _break_point_label,
      _category_label,
      _allow_duplicates,
      _left_close,
      _include_breaks
    ) do
  err()
end

def s_variance(_s, _ddof), do: err()
def s_window_max(_s, _window_size, _weight, _ignore_null, _min_periods), do: err()
Expand Down
22 changes: 18 additions & 4 deletions lib/explorer/polars_backend/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -541,13 +541,15 @@ defmodule Explorer.PolarsBackend.Series do
# Categorisation

@impl true
def cut(series, bins, labels, break_point_label, category_label) do
def cut(series, bins, labels, break_point_label, category_label, left_close, include_breaks) do
case Explorer.PolarsBackend.Native.s_cut(
series.data,
bins,
labels,
break_point_label,
category_label
category_label,
left_close,
include_breaks
) do
{:ok, polars_df} ->
Shared.create_dataframe!(polars_df)
Expand All @@ -561,13 +563,25 @@ defmodule Explorer.PolarsBackend.Series do
end

@impl true
def qcut(series, quantiles, labels, break_point_label, category_label) do
def qcut(
series,
quantiles,
labels,
break_point_label,
category_label,
allow_duplicates,
left_close,
include_breaks
) do
Shared.apply(:s_qcut, [
series.data,
quantiles,
labels,
break_point_label,
category_label
category_label,
allow_duplicates,
left_close,
include_breaks
])
|> Shared.create_dataframe!()
end
Expand Down
39 changes: 33 additions & 6 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -4842,11 +4842,22 @@ defmodule Explorer.Series do
"""
@doc type: :aggregation
def cut(series, bins, opts \\ []) do
opts =
Keyword.validate!(opts,
labels: nil,
break_point_label: nil,
category_label: nil,
left_close: false,
include_breaks: true
)

apply_series(series, :cut, [
Enum.map(bins, &(&1 / 1.0)),
Keyword.get(opts, :labels),
Keyword.get(opts, :break_point_label),
Keyword.get(opts, :category_label)
opts[:labels],
opts[:break_point_label],
opts[:category_label],
opts[:left_close],
opts[:include_breaks]
])
end

Expand All @@ -4868,6 +4879,9 @@ defmodule Explorer.Series do
* `:category_label` - The name given to the category column.
Defaults to `category`.
* `:allow_duplicates` - Whether duplicated quantile break points are tolerated.
When `true`, repeated break points are dropped instead of raising an error.
Defaults to `false`.
## Examples
iex> s = Explorer.Series.from_list([1.0, 2.0, 3.0, 4.0, 5.0])
Expand All @@ -4881,11 +4895,24 @@ defmodule Explorer.Series do
"""
@doc type: :aggregation
def qcut(series, quantiles, opts \\ []) do
opts =
Keyword.validate!(opts,
labels: nil,
break_point_label: nil,
category_label: nil,
allow_duplicates: false,
left_close: false,
include_breaks: true
)

apply_series(series, :qcut, [
Enum.map(quantiles, &(&1 / 1.0)),
Keyword.get(opts, :labels),
Keyword.get(opts, :break_point_label),
Keyword.get(opts, :category_label)
opts[:labels],
opts[:break_point_label],
opts[:category_label],
opts[:allow_duplicates],
opts[:left_close],
opts[:include_breaks]
])
end

Expand Down
13 changes: 8 additions & 5 deletions native/explorer/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,17 +206,18 @@ pub fn s_cut(
labels: Option<Vec<String>>,
break_point_label: Option<&str>,
category_label: Option<&str>,
left_close: bool,
include_breaks: bool,
) -> Result<ExDataFrame, ExplorerError> {
let series = series.clone_inner();
let left_close = false;

// Cut is going to return a Series of a Struct. We need to convert it to a DF.
let cut_series = cut(
&series,
bins,
labels.map(|vec| vec.iter().map(|label| label.into()).collect()),
left_close,
true,
include_breaks,
)?;
let mut cut_df = DataFrame::new(cut_series.struct_()?.fields_as_series())?;

Expand All @@ -231,25 +232,27 @@ pub fn s_cut(
Ok(ExDataFrame::new(cut_df.clone()))
}

#[allow(clippy::too_many_arguments)]
#[rustler::nif(schedule = "DirtyCpu")]
pub fn s_qcut(
series: ExSeries,
quantiles: Vec<f64>,
labels: Option<Vec<String>>,
break_point_label: Option<&str>,
category_label: Option<&str>,
allow_duplicates: bool,
left_close: bool,
include_breaks: bool,
) -> Result<ExDataFrame, ExplorerError> {
let series = series.clone_inner();
let left_close = false;
let allow_duplicates = false;

let qcut_series: Series = qcut(
&series,
quantiles,
labels.map(|vec| vec.iter().map(|label| label.into()).collect()),
left_close,
allow_duplicates,
true,
include_breaks,
)?;

let mut qcut_df = DataFrame::new(qcut_series.struct_()?.fields_as_series())?;
Expand Down
16 changes: 15 additions & 1 deletion test/explorer/series_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -5981,7 +5981,7 @@ defmodule Explorer.SeriesTest do
assert Explorer.DataFrame.names(df) == ["values", "bp", "cat"]
end

test "qcut/6" do
test "qcut/3" do
series = Enum.to_list(-5..3) |> Series.from_list()
df = Series.qcut(series, [0.0, 0.25, 0.75])
freqs = Series.frequencies(df[:category])
Expand All @@ -5995,6 +5995,20 @@ defmodule Explorer.SeriesTest do

assert Series.to_list(freqs[:counts]) == [4, 2, 2, 1]
end

test "qcut/3 with duplicates" do
  # Three leading zeros make the 0.1 and 0.25 quantiles coincide, so the
  # requested quantiles yield a repeated break point. With
  # `allow_duplicates: true` the call succeeds and only three categories
  # come back, as asserted below.
  values = [0.0, 0.0, 0.0, 3.0, 4.0, 5.0]

  df =
    values
    |> Series.from_list()
    |> Series.qcut([0.1, 0.25, 0.75], allow_duplicates: true)

  freqs = Series.frequencies(df[:category])

  expected_categories = ["(-inf, 0]", "(3.75, inf]", "(0, 3.75]"]
  expected_counts = [3, 2, 1]

  assert Series.to_list(freqs[:values]) == expected_categories
  assert Series.to_list(freqs[:counts]) == expected_counts
end
end

describe "join/2" do
Expand Down

0 comments on commit 9845751

Please sign in to comment.