From d8655447590f45fa12f5883889edd2fbdad16cd0 Mon Sep 17 00:00:00 2001 From: Philip Sampaio Date: Wed, 23 Oct 2024 18:56:54 -0300 Subject: [PATCH] Fix `cut/3` and `qcut/3` when `:include_breaks` is false [ci skip] (#1009) * Fix `cut/3` and `qcut/3` when `:include_breaks` is false Also add docs regarding the `:left_close` and `:include_breaks` options. Thanks @billylanchantin :D Reference: https://github.com/elixir-explorer/explorer/pull/1007#discussion_r1812946449 * Change default of `cut/qcut` to not include breaks --- lib/explorer/series.ex | 33 ++++++++++++++++++++---- native/explorer/src/series.rs | 47 +++++++++++++++++++++++------------ test/explorer/series_test.exs | 38 +++++++++++++++++++++++++--- 3 files changed, 94 insertions(+), 24 deletions(-) diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index afec16156..73be883b1 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -4824,21 +4824,37 @@ defmodule Explorer.Series do bounds (e.g. `(-inf -1.0]`, `(-1.0, 1.0]`, `(1.0, inf]`) * `:break_point_label` - The name given to the breakpoint column. + This is only relevant if `:include_breaks` is `true`. Defaults to `break_point`. * `:category_label` - The name given to the category column. Defaults to `category`. + * `:left_closed` - Set the intervals to be left-closed instead + of right-closed. Defaults to `false`. + + * `:include_breaks` - Include a column with the right endpoint + of the bin that each observation falls in. + Defaults to `false`. + ## Examples iex> s = Explorer.Series.from_list([1.0, 2.0, 3.0]) - iex> Explorer.Series.cut(s, [1.5, 2.5]) + iex> Explorer.Series.cut(s, [1.5, 2.5], include_breaks: true) #Explorer.DataFrame< Polars[3 x 3] values f64 [1.0, 2.0, 3.0] break_point f64 [1.5, 2.5, Inf] category category ["(-inf, 1.5]", "(1.5, 2.5]", "(2.5, inf]"] > + + iex> s = Explorer.Series.from_list([1.0, 2.0, 3.0]) + iex> Explorer.Series.cut(s, [1.5, 2.5]) + #Explorer.DataFrame< + Polars[3 x 2] + values f64 [1.0, 2.0, 3.0] + category category ["(-inf, 1.5]", "(1.5, 2.5]", "(2.5, inf]"] + > """ @doc type: :aggregation def cut(series, bins, opts \\ []) do @@ -4848,7 +4864,7 @@ defmodule Explorer.Series do break_point_label: nil, category_label: nil, left_close: false, - include_breaks: true + include_breaks: false ) apply_series(series, :cut, [ @@ -4874,6 +4890,7 @@ defmodule Explorer.Series do bounds (e.g. `(-inf -1.0]`, `(-1.0, 1.0]`, `(1.0, inf]`) * `:break_point_label` - The name given to the breakpoint column. + This is only relevant if `:include_breaks` is `true`. Defaults to `break_point`. * `:category_label` - The name given to the category column. @@ -4882,14 +4899,20 @@ defmodule Explorer.Series do * `:allow_duplicates` - If quantiles can have duplicated values. Defaults to `false`. + * `:left_closed` - Set the intervals to be left-closed instead + of right-closed. Defaults to `false`. + + * `:include_breaks` - Include a column with the right endpoint + of the bin that each observation falls in. + Defaults to `false`. + ## Examples iex> s = Explorer.Series.from_list([1.0, 2.0, 3.0, 4.0, 5.0]) iex> Explorer.Series.qcut(s, [0.25, 0.75]) #Explorer.DataFrame< - Polars[5 x 3] + Polars[5 x 2] values f64 [1.0, 2.0, 3.0, 4.0, 5.0] - break_point f64 [2.0, 2.0, 4.0, 4.0, Inf] category category ["(-inf, 2]", "(-inf, 2]", "(2, 4]", "(2, 4]", "(4, inf]"] > """ @@ -4902,7 +4925,7 @@ defmodule Explorer.Series do category_label: nil, allow_duplicates: false, left_close: false, - include_breaks: true + include_breaks: false ) apply_series(series, :qcut, [ diff --git a/native/explorer/src/series.rs b/native/explorer/src/series.rs index d5584a0a1..8fbde7f8b 100644 --- a/native/explorer/src/series.rs +++ b/native/explorer/src/series.rs @@ -219,17 +219,25 @@ pub fn s_cut( left_close, include_breaks, )?; - let mut cut_df = DataFrame::new(cut_series.struct_()?.fields_as_series())?; - let cut_df = cut_df.insert_column(0, series)?; + if include_breaks { + let mut cut_df = DataFrame::new(cut_series.struct_()?.fields_as_series())?; - cut_df.set_column_names([ - "values", - break_point_label.unwrap_or("break_point"), - category_label.unwrap_or("category"), - ])?; + let cut_df = cut_df.insert_column(0, series)?; - Ok(ExDataFrame::new(cut_df.clone())) + cut_df.set_column_names([ + "values", + break_point_label.unwrap_or("break_point"), + category_label.unwrap_or("category"), + ])?; + + Ok(ExDataFrame::new(cut_df.clone())) + } else { + let mut cut_df = DataFrame::new(vec![series, cut_series])?; + cut_df.set_column_names(["values", category_label.unwrap_or("category")])?; + + Ok(ExDataFrame::new(cut_df.clone())) + } } #[allow(clippy::too_many_arguments)] @@ -255,16 +263,23 @@ pub fn s_qcut( include_breaks, )?; - let mut qcut_df = DataFrame::new(qcut_series.struct_()?.fields_as_series())?; - let qcut_df = qcut_df.insert_column(0, series)?; + if include_breaks { + let mut qcut_df = DataFrame::new(qcut_series.struct_()?.fields_as_series())?; + let qcut_df = qcut_df.insert_column(0, series)?; - qcut_df.set_column_names([ - "values", - break_point_label.unwrap_or("break_point"), - category_label.unwrap_or("category"), - ])?; + qcut_df.set_column_names([ + "values", + break_point_label.unwrap_or("break_point"), + category_label.unwrap_or("category"), + ])?; - Ok(ExDataFrame::new(qcut_df.clone())) + Ok(ExDataFrame::new(qcut_df.clone())) + } else { + let mut qcut_df = DataFrame::new(vec![series, qcut_series])?; + qcut_df.set_column_names(["values", category_label.unwrap_or("category")])?; + + Ok(ExDataFrame::new(qcut_df.clone())) + } } #[rustler::nif(schedule = "DirtyCpu")] diff --git a/test/explorer/series_test.exs b/test/explorer/series_test.exs index 307da6d75..8ce2d6c1f 100644 --- a/test/explorer/series_test.exs +++ b/test/explorer/series_test.exs @@ -5950,7 +5950,7 @@ defmodule Explorer.SeriesTest do end describe "categorisation functions" do - test "cut/6 with no nils" do + test "cut/3 with no nils" do series = -30..30//5 |> Enum.map(&(&1 / 10)) |> Enum.to_list() |> Series.from_list() df = Series.cut(series, [-1, 1]) freqs = Series.frequencies(df[:category]) @@ -5958,13 +5958,13 @@ defmodule Explorer.SeriesTest do assert Series.to_list(freqs[:counts]) == [5, 4, 4] end - test "cut/6 with nils" do + test "cut/3 with nils" do series = Series.from_list([1, 2, 3, nil, nil]) df = Series.cut(series, [2]) assert [_, _, _, nil, nil] = Series.to_list(df[:category]) end - test "cut/6 options" do + test "cut/3 options" do series = Series.from_list([1, 2, 3]) assert_raise ArgumentError, @@ -5973,6 +5973,7 @@ defmodule Explorer.SeriesTest do df = Series.cut(series, [2], + include_breaks: true, labels: ["x", "y"], break_point_label: "bp", category_label: "cat" @@ -5981,6 +5982,17 @@ defmodule Explorer.SeriesTest do assert Explorer.DataFrame.names(df) == ["values", "bp", "cat"] end + test "cut/3 with include breaks" do + series = Series.from_list([1.0, 2.0, 3.0]) + df = Series.cut(series, [1.5, 2.5], include_breaks: true) + + assert Explorer.DataFrame.to_columns(df, atom_keys: true) == %{ + category: ["(-inf, 1.5]", "(1.5, 2.5]", "(2.5, inf]"], + break_point: [1.5, 2.5, :infinity], + values: [1.0, 2.0, 3.0] + } + end + test "qcut/3" do series = Enum.to_list(-5..3) |> Series.from_list() df = Series.qcut(series, [0.0, 0.25, 0.75]) @@ -6009,6 +6021,26 @@ defmodule Explorer.SeriesTest do assert Series.to_list(freqs[:counts]) == [3, 2, 1] end + + test "qcut/3 without include breaks" do + series = Enum.to_list(-5..3) |> Series.from_list() + df = Series.qcut(series, [0.0, 0.25, 0.75], include_breaks: false) + + assert Explorer.DataFrame.to_columns(df, atom_keys: true) == %{ + category: [ + "(-inf, -5]", + "(-5, -3]", + "(-5, -3]", + "(-3, 1]", + "(-3, 1]", + "(-3, 1]", + "(-3, 1]", + "(1, inf]", + "(1, inf]" + ], + values: [-5, -4, -3, -2, -1, 0, 1, 2, 3] + } + end end describe "join/2" do