From 2d9923eb84c1fe61cd9f172a4c197e00f1b3ca79 Mon Sep 17 00:00:00 2001 From: Philip Sampaio Date: Tue, 22 Oct 2024 12:43:07 -0300 Subject: [PATCH 1/2] Add ":allow_duplicates" and ":left_close" opts to qcut This may fix https://github.com/elixir-explorer/explorer/issues/1006 --- lib/explorer/backend/lazy_series.ex | 2 +- lib/explorer/backend/series.ex | 11 ++++++++++- lib/explorer/polars_backend/native.ex | 13 +++++++++++-- lib/explorer/polars_backend/series.ex | 16 ++++++++++++++-- lib/explorer/series.ex | 22 +++++++++++++++++++--- native/explorer/src/series.rs | 8 +++++--- test/explorer/series_test.exs | 16 +++++++++++++++- 7 files changed, 75 insertions(+), 13 deletions(-) diff --git a/lib/explorer/backend/lazy_series.ex b/lib/explorer/backend/lazy_series.ex index ce26fe40e..86b8afb36 100644 --- a/lib/explorer/backend/lazy_series.ex +++ b/lib/explorer/backend/lazy_series.ex @@ -1246,7 +1246,7 @@ defmodule Explorer.Backend.LazySeries do categorise: 2, cut: 5, frequencies: 1, - qcut: 5, + qcut: 8, mask: 2, owner_import: 1, owner_export: 1, diff --git a/lib/explorer/backend/series.ex b/lib/explorer/backend/series.ex index 517a2b5ef..96445ec77 100644 --- a/lib/explorer/backend/series.ex +++ b/lib/explorer/backend/series.ex @@ -192,7 +192,16 @@ defmodule Explorer.Backend.Series do @callback cut(s, [float()], [String.t()] | nil, String.t() | nil, String.t() | nil) :: df - @callback qcut(s, [float()], [String.t()] | nil, String.t() | nil, String.t() | nil) :: + @callback qcut( + s, + quantiles :: [float()], + labels :: option([String.t()]), + break_point_label :: option(String.t()), + category_label :: option(String.t()), + allow_duplicates :: boolean(), + left_close :: boolean(), + include_breaks :: boolean() + ) :: df # Rolling diff --git a/lib/explorer/polars_backend/native.ex b/lib/explorer/polars_backend/native.ex index 4d438b97e..53546e309 100644 --- a/lib/explorer/polars_backend/native.ex +++ b/lib/explorer/polars_backend/native.ex @@ -423,8 +423,17 @@ defmodule Explorer.PolarsBackend.Native do def s_split(_s, _by), do: err() def s_split_into(_s, _by, _num_fields), do: err() - def s_qcut(_s, _quantiles, _labels, _break_point_label, _category_label), - do: err() + def s_qcut( + _s, + _quantiles, + _labels, + _break_point_label, + _category_label, + _allow_duplicates, + _left_close, + _include_breaks + ), + do: err() def s_variance(_s, _ddof), do: err() def s_window_max(_s, _window_size, _weight, _ignore_null, _min_periods), do: err() diff --git a/lib/explorer/polars_backend/series.ex b/lib/explorer/polars_backend/series.ex index e21804f4b..27cc776c3 100644 --- a/lib/explorer/polars_backend/series.ex +++ b/lib/explorer/polars_backend/series.ex @@ -561,13 +561,25 @@ defmodule Explorer.PolarsBackend.Series do end @impl true - def qcut(series, quantiles, labels, break_point_label, category_label) do + def qcut( + series, + quantiles, + labels, + break_point_label, + category_label, + allow_duplicates, + left_close, + include_breaks + ) do Shared.apply(:s_qcut, [ series.data, quantiles, labels, break_point_label, - category_label + category_label, + allow_duplicates, + left_close, + include_breaks ]) |> Shared.create_dataframe!() end diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index 26a503959..a8d51ee0f 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -4868,6 +4868,9 @@ defmodule Explorer.Series do * `:category_label` - The name given to the category column. Defaults to `category`. + * `:allow_duplicates` - If quantiles can have duplicated values. + Defaults to `false`. + ## Examples iex> s = Explorer.Series.from_list([1.0, 2.0, 3.0, 4.0, 5.0]) @@ -4881,11 +4884,24 @@ defmodule Explorer.Series do """ @doc type: :aggregation def qcut(series, quantiles, opts \\ []) do + opts = + Keyword.validate!(opts, + labels: nil, + break_point_label: nil, + category_label: nil, + allow_duplicates: false, + left_close: false, + include_breaks: true + ) + apply_series(series, :qcut, [ Enum.map(quantiles, &(&1 / 1.0)), - Keyword.get(opts, :labels), - Keyword.get(opts, :break_point_label), - Keyword.get(opts, :category_label) + opts[:labels], + opts[:break_point_label], + opts[:category_label], + opts[:allow_duplicates], + opts[:left_close], + opts[:include_breaks] ]) end diff --git a/native/explorer/src/series.rs b/native/explorer/src/series.rs index 2af37ab7a..35f664c5b 100644 --- a/native/explorer/src/series.rs +++ b/native/explorer/src/series.rs @@ -231,6 +231,7 @@ pub fn s_cut( Ok(ExDataFrame::new(cut_df.clone())) } +#[allow(clippy::too_many_arguments)] #[rustler::nif(schedule = "DirtyCpu")] pub fn s_qcut( series: ExSeries, @@ -238,10 +239,11 @@ pub fn s_qcut( labels: Option>, break_point_label: Option<&str>, category_label: Option<&str>, + allow_duplicates: bool, + left_close: bool, + include_breaks: bool, ) -> Result { let series = series.clone_inner(); - let left_close = false; - let allow_duplicates = false; let qcut_series: Series = qcut( &series, @@ -249,7 +251,7 @@ pub fn s_qcut( labels.map(|vec| vec.iter().map(|label| label.into()).collect()), left_close, allow_duplicates, - true, + include_breaks, )?; let mut qcut_df = DataFrame::new(qcut_series.struct_()?.fields_as_series())?; diff --git a/test/explorer/series_test.exs b/test/explorer/series_test.exs index 3bfff11ca..307da6d75 100644 --- a/test/explorer/series_test.exs +++ b/test/explorer/series_test.exs @@ -5981,7 +5981,7 @@ defmodule Explorer.SeriesTest do assert Explorer.DataFrame.names(df) == ["values", "bp", "cat"] end - test "qcut/6" do + test "qcut/3" do series = Enum.to_list(-5..3) |> Series.from_list() df = Series.qcut(series, [0.0, 0.25, 0.75]) freqs = Series.frequencies(df[:category]) @@ -5995,6 +5995,20 @@ defmodule Explorer.SeriesTest do assert Series.to_list(freqs[:counts]) == [4, 2, 2, 1] end + + test "qcut/3 with duplicates" do + series = Explorer.Series.from_list([0.0, 0.0, 0.0, 3.0, 4.0, 5.0]) + df = Explorer.Series.qcut(series, [0.1, 0.25, 0.75], allow_duplicates: true) + freqs = Series.frequencies(df[:category]) + + assert Series.to_list(freqs[:values]) == [ + "(-inf, 0]", + "(3.75, inf]", + "(0, 3.75]" + ] + + assert Series.to_list(freqs[:counts]) == [3, 2, 1] + end end describe "join/2" do From 5c7133dad6ed40db6d4dd13a77a090eb4f1784c9 Mon Sep 17 00:00:00 2001 From: Philip Sampaio Date: Tue, 22 Oct 2024 20:57:21 -0300 Subject: [PATCH 2/2] Add more options to `Series.cut/3`, but without docs --- lib/explorer/backend/lazy_series.ex | 2 +- lib/explorer/backend/series.ex | 10 +++++++++- lib/explorer/polars_backend/native.ex | 13 ++++++++++++- lib/explorer/polars_backend/series.ex | 6 ++++-- lib/explorer/series.ex | 17 ++++++++++++++--- native/explorer/src/series.rs | 5 +++-- 6 files changed, 43 insertions(+), 10 deletions(-) diff --git a/lib/explorer/backend/lazy_series.ex b/lib/explorer/backend/lazy_series.ex index 86b8afb36..764d0251b 100644 --- a/lib/explorer/backend/lazy_series.ex +++ b/lib/explorer/backend/lazy_series.ex @@ -1244,7 +1244,7 @@ defmodule Explorer.Backend.LazySeries do at_every: 2, categories: 1, categorise: 2, - cut: 5, + cut: 7, frequencies: 1, qcut: 8, mask: 2, diff --git a/lib/explorer/backend/series.ex b/lib/explorer/backend/series.ex index 96445ec77..ded019664 100644 --- a/lib/explorer/backend/series.ex +++ b/lib/explorer/backend/series.ex @@ -190,7 +190,15 @@ defmodule Explorer.Backend.Series do # Categorisation - @callback cut(s, [float()], [String.t()] | nil, String.t() | nil, String.t() | nil) :: + @callback cut( + s, + bins :: [float()], + labels :: option([String.t()]), + break_point_label :: option(String.t()), + category_label :: option(String.t()), + left_close :: boolean(), + include_breaks :: boolean() + ) :: df @callback qcut( s, diff --git a/lib/explorer/polars_backend/native.ex b/lib/explorer/polars_backend/native.ex index 53546e309..688ed0827 100644 --- a/lib/explorer/polars_backend/native.ex +++ b/lib/explorer/polars_backend/native.ex @@ -418,7 +418,18 @@ defmodule Explorer.PolarsBackend.Native do def s_upcase(_s), do: err() def s_unordered_distinct(_s), do: err() def s_frequencies(_s), do: err() - def s_cut(_s, _bins, _labels, _break_point_label, _category_label), do: err() + + def s_cut( + _s, + _bins, + _labels, + _break_point_label, + _category_label, + _left_close, + _include_breaks + ), + do: err() + def s_substring(_s, _offset, _length), do: err() def s_split(_s, _by), do: err() def s_split_into(_s, _by, _num_fields), do: err() diff --git a/lib/explorer/polars_backend/series.ex b/lib/explorer/polars_backend/series.ex index 27cc776c3..56af73590 100644 --- a/lib/explorer/polars_backend/series.ex +++ b/lib/explorer/polars_backend/series.ex @@ -541,13 +541,15 @@ defmodule Explorer.PolarsBackend.Series do # Categorisation @impl true - def cut(series, bins, labels, break_point_label, category_label) do + def cut(series, bins, labels, break_point_label, category_label, left_close, include_breaks) do case Explorer.PolarsBackend.Native.s_cut( series.data, bins, labels, break_point_label, - category_label + category_label, + left_close, + include_breaks ) do {:ok, polars_df} -> Shared.create_dataframe!(polars_df) diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index a8d51ee0f..afec16156 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -4842,11 +4842,22 @@ defmodule Explorer.Series do """ @doc type: :aggregation def cut(series, bins, opts \\ []) do + opts = + Keyword.validate!(opts, + labels: nil, + break_point_label: nil, + category_label: nil, + left_close: false, + include_breaks: true + ) + apply_series(series, :cut, [ Enum.map(bins, &(&1 / 1.0)), - Keyword.get(opts, :labels), - Keyword.get(opts, :break_point_label), - Keyword.get(opts, :category_label) + opts[:labels], + opts[:break_point_label], + opts[:category_label], + opts[:left_close], + opts[:include_breaks] ]) end diff --git a/native/explorer/src/series.rs b/native/explorer/src/series.rs index 35f664c5b..d5584a0a1 100644 --- a/native/explorer/src/series.rs +++ b/native/explorer/src/series.rs @@ -206,9 +206,10 @@ pub fn s_cut( labels: Option>, break_point_label: Option<&str>, category_label: Option<&str>, + left_close: bool, + include_breaks: bool, ) -> Result { let series = series.clone_inner(); - let left_close = false; // Cut is going to return a Series of a Struct. We need to convert it to a DF. let cut_series = cut( @@ -216,7 +217,7 @@ pub fn s_cut( bins, labels.map(|vec| vec.iter().map(|label| label.into()).collect()), left_close, - true, + include_breaks, )?; let mut cut_df = DataFrame::new(cut_series.struct_()?.fields_as_series())?;