diff --git a/crates/polars-core/src/chunked_array/list/iterator.rs b/crates/polars-core/src/chunked_array/list/iterator.rs index 2dc2c7eb85590..fe7f89abd8ec0 100644 --- a/crates/polars-core/src/chunked_array/list/iterator.rs +++ b/crates/polars-core/src/chunked_array/list/iterator.rs @@ -229,6 +229,52 @@ impl ListChunked { out } + pub fn try_zip_and_apply_amortized<'a, T, I, F>( + &'a self, + ca: &'a ChunkedArray, + mut f: F, + ) -> PolarsResult + where + T: PolarsDataType, + &'a ChunkedArray: IntoIterator, + I: TrustedLen>>, + F: FnMut( + Option>, + Option>, + ) -> PolarsResult>, + { + if self.is_empty() { + return Ok(self.clone()); + } + let mut fast_explode = self.null_count() == 0; + // SAFETY: unstable series never lives longer than the iterator. + let mut out: ListChunked = unsafe { + self.amortized_iter() + .zip(ca) + .map(|(opt_s, opt_v)| { + let out = f(opt_s, opt_v)?; + match out { + Some(out) if out.is_empty() => { + fast_explode = false; + Ok(Some(out)) + }, + None => { + fast_explode = false; + Ok(out) + }, + _ => Ok(out), + } + }) + .collect::>()? + }; + + out.rename(self.name()); + if fast_explode { + out.set_fast_explode(); + } + Ok(out) + } + /// Apply a closure `F` elementwise. #[must_use] pub fn apply_amortized<'a, F>(&'a self, mut f: F) -> Self diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml index 7c6a7770f8d04..785c2aeea51c6 100644 --- a/crates/polars-lazy/Cargo.toml +++ b/crates/polars-lazy/Cargo.toml @@ -138,6 +138,7 @@ fused = ["polars-plan/fused", "polars-ops/fused"] list_sets = ["polars-plan/list_sets", "polars-ops/list_sets"] list_any_all = ["polars-ops/list_any_all", "polars-plan/list_any_all"] list_drop_nulls = ["polars-ops/list_drop_nulls", "polars-plan/list_drop_nulls"] +list_sample = ["polars-ops/list_sample", "polars-plan/list_sample"] cutqcut = ["polars-plan/cutqcut", "polars-ops/cutqcut"] rle = ["polars-plan/rle", "polars-ops/rle"] extract_groups = ["polars-plan/extract_groups"] diff --git a/crates/polars-lazy/src/tests/mod.rs b/crates/polars-lazy/src/tests/mod.rs index 641e250f6dd7a..3bb99a46ae4cb 100644 --- a/crates/polars-lazy/src/tests/mod.rs +++ b/crates/polars-lazy/src/tests/mod.rs @@ -4,6 +4,7 @@ mod arity; mod cse; #[cfg(feature = "parquet")] mod io; +mod lazy_test; mod logical; mod optimization_checks; mod predicate_queries; diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index 8e5c2aaa25c32..62955af37e127 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -107,6 +107,7 @@ list_take = [] list_sets = [] list_any_all = [] list_drop_nulls = [] +list_sample = [] extract_groups = ["dtype-struct", "polars-core/regex"] is_in = ["polars-core/reinterpret"] convert_index = [] diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index f9f92a75ad85b..7762b9c42703a 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -402,6 +402,87 @@ pub trait ListNameSpaceImpl: AsList { list_ca.apply_amortized(|s| s.as_ref().drop_nulls()) } + #[cfg(feature = "list_sample")] + fn lst_sample_n( + &self, + n: &Series, + with_replacement: bool, + shuffle: bool, + seed: Option, + ) -> PolarsResult { + let ca = self.as_list(); + + let n_s = n.cast(&IDX_DTYPE)?; + let n = n_s.idx()?; + + let out = match n.len() { + 1 => { + if let Some(n) = n.get(0) { + ca.try_apply_amortized(|s| { + s.as_ref() + .sample_n(n as usize, with_replacement, shuffle, seed) + }) + } else { + Ok(ListChunked::full_null_with_dtype( + ca.name(), + ca.len(), + &ca.inner_dtype(), + )) + } + }, + _ => ca.try_zip_and_apply_amortized(n, |opt_s, opt_n| match (opt_s, opt_n) { + (Some(s), Some(n)) => s + .as_ref() + .sample_n(n as usize, with_replacement, shuffle, seed) + .map(Some), + _ => Ok(None), + }), + }; + out.map(|ok| self.same_type(ok)) + } + + #[cfg(feature = "list_sample")] + fn lst_sample_fraction( + &self, + fraction: &Series, + with_replacement: bool, + shuffle: bool, + seed: Option, + ) -> PolarsResult { + let ca = self.as_list(); + + let fraction_s = fraction.cast(&DataType::Float64)?; + let fraction = fraction_s.f64()?; + + let out = match fraction.len() { + 1 => { + if let Some(fraction) = fraction.get(0) { + ca.try_apply_amortized(|s| { + let n = (s.as_ref().len() as f64 * fraction) as usize; + s.as_ref() + .sample_n(n as usize, with_replacement, shuffle, seed) + }) + } else { + Ok(ListChunked::full_null_with_dtype( + ca.name(), + ca.len(), + &ca.inner_dtype(), + )) + } + }, + _ => ca.try_zip_and_apply_amortized(fraction, |opt_s, opt_n| match (opt_s, opt_n) { + (Some(s), Some(fraction)) => { + let n = (s.as_ref().len() as f64 * fraction) as usize; + s.as_ref() + .sample_n(n, with_replacement, shuffle, seed) + .map(Some) + }, + _ => Ok(None), + }), + }; + out.map(|ok| self.same_type(ok)) + } + fn lst_concat(&self, other: &[Series]) -> PolarsResult { let ca = self.as_list(); let other_len = other.len(); diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index ee412f48dbec0..041dbbb82fd98 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -136,6 +136,7 @@ fused = ["polars-ops/fused"] list_sets = ["polars-ops/list_sets"] list_any_all = ["polars-ops/list_any_all"] list_drop_nulls = ["polars-ops/list_drop_nulls"] +list_sample = ["polars-ops/list_sample"] cutqcut = ["polars-ops/cutqcut"] rle = ["polars-ops/rle"] extract_groups = ["regex", "dtype-struct", "polars-ops/extract_groups"] diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index 35155902b788c..bb238d09bd755 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -11,6 +11,13 @@ pub enum ListFunction { Contains, #[cfg(feature = "list_drop_nulls")] DropNulls, + #[cfg(feature = "list_sample")] + Sample { + is_fraction: bool, + with_replacement: bool, + shuffle: bool, + seed: Option, + }, Slice, Shift, Get, @@ -52,6 +59,14 @@ impl Display for ListFunction { Contains => "contains", #[cfg(feature = "list_drop_nulls")] DropNulls => "drop_nulls", + #[cfg(feature = "list_sample")] + Sample { is_fraction, .. } => { + if *is_fraction { + "sample_fraction" + } else { + "sample_n" + } + }, Slice => "slice", Shift => "shift", Get => "get", @@ -107,6 +122,32 @@ pub(super) fn drop_nulls(s: &Series) -> PolarsResult { Ok(list.lst_drop_nulls().into_series()) } +#[cfg(feature = "list_sample")] +pub(super) fn sample_n( + s: &[Series], + with_replacement: bool, + shuffle: bool, + seed: Option, +) -> PolarsResult { + let list = s[0].list()?; + let n = &s[1]; + list.lst_sample_n(n, with_replacement, shuffle, seed) + .map(|ok| ok.into_series()) +} + +#[cfg(feature = "list_sample")] +pub(super) fn sample_fraction( + s: &[Series], + with_replacement: bool, + shuffle: bool, + seed: Option, +) -> PolarsResult { + let list = s[0].list()?; + let fraction = &s[1]; + list.lst_sample_fraction(fraction, with_replacement, shuffle, seed) + .map(|ok| ok.into_series()) +} + fn check_slice_arg_shape(slice_len: usize, ca_len: usize, name: &str) -> PolarsResult<()> { polars_ensure!( slice_len == ca_len, diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 40a5eef2f3268..3ca6b5f2c7c9b 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -671,6 +671,19 @@ impl From for SpecialEq> { Contains => wrap!(list::contains), #[cfg(feature = "list_drop_nulls")] DropNulls => map!(list::drop_nulls), + #[cfg(feature = "list_sample")] + Sample { + is_fraction, + with_replacement, + shuffle, + seed, + } => { + if is_fraction { + map_as_slice!(list::sample_fraction, with_replacement, shuffle, seed) + } else { + map_as_slice!(list::sample_n, with_replacement, shuffle, seed) + } + }, Slice => wrap!(list::slice), Shift => map_as_slice!(list::shift), Get => wrap!(list::get), diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index cf2178fcd42cd..af99f7f81b52f 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -70,6 +70,8 @@ impl FunctionExpr { Contains => mapper.with_dtype(DataType::Boolean), #[cfg(feature = "list_drop_nulls")] DropNulls => mapper.with_same_dtype(), + #[cfg(feature = "list_sample")] + Sample { .. } => mapper.with_same_dtype(), Slice => mapper.with_same_dtype(), Shift => mapper.with_same_dtype(), Get => mapper.map_to_list_inner_dtype(), diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index 6e9bde5b68eb0..c8741dba73ff8 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -34,6 +34,48 @@ impl ListNameSpace { .map_private(FunctionExpr::ListExpr(ListFunction::DropNulls)) } + #[cfg(feature = "list_sample")] + pub fn sample_n( + self, + n: Expr, + with_replacement: bool, + shuffle: bool, + seed: Option, + ) -> Expr { + self.0.map_many_private( + FunctionExpr::ListExpr(ListFunction::Sample { + is_fraction: false, + with_replacement, + shuffle, + seed, + }), + &[n], + false, + false, + ) + } + + #[cfg(feature = "list_sample")] + pub fn sample_fraction( + self, + fraction: Expr, + with_replacement: bool, + shuffle: bool, + seed: Option, + ) -> Expr { + self.0.map_many_private( + FunctionExpr::ListExpr(ListFunction::Sample { + is_fraction: true, + with_replacement, + shuffle, + seed, + }), + &[fraction], + false, + false, + ) + } + /// Return the number of elements in each list. /// /// Null values are treated like regular elements in this context. diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index 58f6a8334bcde..e2908aec78e66 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -190,6 +190,7 @@ fused = ["polars-ops/fused", "polars-lazy?/fused"] list_sets = ["polars-lazy?/list_sets"] list_any_all = ["polars-lazy?/list_any_all"] list_drop_nulls = ["polars-lazy?/list_drop_nulls"] +list_sample = ["polars-lazy?/list_sample"] cutqcut = ["polars-lazy?/cutqcut"] rle = ["polars-lazy?/rle"] extract_groups = ["polars-lazy?/extract_groups"] diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index 445646088c1d9..f782192c869e7 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -138,6 +138,7 @@ binary_encoding = ["polars/binary_encoding"] list_sets = ["polars-lazy/list_sets"] list_any_all = ["polars/list_any_all"] list_drop_nulls = ["polars/list_drop_nulls"] +list_sample = ["polars/list_sample"] cutqcut = ["polars/cutqcut"] rle = ["polars/rle"] extract_groups = ["polars/extract_groups"] @@ -165,6 +166,7 @@ operations = [ "list_sets", "list_any_all", "list_drop_nulls", + "list_sample", "cutqcut", "rle", "extract_groups", diff --git a/py-polars/docs/source/reference/expressions/list.rst b/py-polars/docs/source/reference/expressions/list.rst index d56b44abcc301..f43401e205612 100644 --- a/py-polars/docs/source/reference/expressions/list.rst +++ b/py-polars/docs/source/reference/expressions/list.rst @@ -34,6 +34,7 @@ The following methods are available under the `expr.list` attribute. Expr.list.mean Expr.list.min Expr.list.reverse + Expr.list.sample Expr.list.set_difference Expr.list.set_intersection Expr.list.set_symmetric_difference diff --git a/py-polars/docs/source/reference/series/list.rst b/py-polars/docs/source/reference/series/list.rst index 7f3b709e80dbd..ad766dd92eb9e 100644 --- a/py-polars/docs/source/reference/series/list.rst +++ b/py-polars/docs/source/reference/series/list.rst @@ -34,6 +34,7 @@ The following methods are available under the `Series.list` attribute. Series.list.mean Series.list.min Series.list.reverse + Series.list.sample Series.list.set_difference Series.list.set_intersection Series.list.set_symmetric_difference diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 608cb0db48407..8ad68bb58117e 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -8244,7 +8244,7 @@ def shuffle(self, seed: int | None = None) -> Self: def sample( self, - n: int | Expr | None = None, + n: int | IntoExprColumn | None = None, *, fraction: float | None = None, with_replacement: bool = False, diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index 5a41f4e2f2131..249c17ab5cb4a 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -138,6 +138,64 @@ def drop_nulls(self) -> Expr: """ return wrap_expr(self._pyexpr.list_drop_nulls()) + def sample( + self, + n: int | IntoExprColumn | None = None, + *, + fraction: float | IntoExprColumn | None = None, + with_replacement: bool = False, + shuffle: bool = False, + seed: int | None = None, + ) -> Expr: + """ + Sample from this list. + + Parameters + ---------- + n + Number of items to return. Cannot be used with `fraction`. Defaults to 1 if + `fraction` is None. + fraction + Fraction of items to return. Cannot be used with `n`. + with_replacement + Allow values to be sampled more than once. + shuffle + Shuffle the order of sampled data points. + seed + Seed for the random number generator. If set to None (default), a + random seed is generated for each sample operation. + + Examples + -------- + >>> df = pl.DataFrame({"values": [[1, 2, 3], [4, 5]], "n": [2, 1]}) + >>> df.select(pl.col("values").list.sample(n=pl.col("n"), seed=1)) + shape: (2, 1) + ┌───────────┐ + │ values │ + │ --- │ + │ list[i64] │ + ╞═══════════╡ + │ [2, 1] │ + │ [5] │ + └───────────┘ + + """ + if n is not None and fraction is not None: + raise ValueError("cannot specify both `n` and `fraction`") + + if fraction is not None: + fraction = parse_as_expression(fraction) + return wrap_expr( + self._pyexpr.list_sample_fraction( + fraction, with_replacement, shuffle, seed + ) + ) + + if n is None: + n = 1 + n = parse_as_expression(n) + return wrap_expr(self._pyexpr.list_sample_n(n, with_replacement, shuffle, seed)) + def sum(self) -> Expr: """ Sum all the lists in the array. diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index 3b883df4dea3c..3f4b11d257b52 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -125,6 +125,46 @@ def drop_nulls(self) -> Series: """ + def sample( + self, + n: int | IntoExprColumn | None = None, + *, + fraction: float | IntoExprColumn | None = None, + with_replacement: bool = False, + shuffle: bool = False, + seed: int | None = None, + ) -> Series: + """ + Sample from this list. + + Parameters + ---------- + n + Number of items to return. Cannot be used with `fraction`. Defaults to 1 if + `fraction` is None. + fraction + Fraction of items to return. Cannot be used with `n`. + with_replacement + Allow values to be sampled more than once. + shuffle + Shuffle the order of sampled data points. + seed + Seed for the random number generator. If set to None (default), a + random seed is generated for each sample operation. + + Examples + -------- + >>> s = pl.Series("values", [[1, 2, 3], [4, 5]]) + >>> s.list.sample(n=pl.Series("n", [2, 1]), seed=1) + shape: (2,) + Series: 'values' [list[i64]] + [ + [2, 1] + [5] + ] + + """ + def sum(self) -> Series: """Sum all the arrays in the list.""" diff --git a/py-polars/src/expr/list.rs b/py-polars/src/expr/list.rs index dbac07c08a3bc..a8a6db6613b92 100644 --- a/py-polars/src/expr/list.rs +++ b/py-polars/src/expr/list.rs @@ -115,6 +115,36 @@ impl PyExpr { self.inner.clone().list().drop_nulls().into() } + #[cfg(feature = "list_sample")] + fn list_sample_n( + &self, + n: PyExpr, + with_replacement: bool, + shuffle: bool, + seed: Option, + ) -> Self { + self.inner + .clone() + .list() + .sample_n(n.inner, with_replacement, shuffle, seed) + .into() + } + + #[cfg(feature = "list_sample")] + fn list_sample_fraction( + &self, + fraction: PyExpr, + with_replacement: bool, + shuffle: bool, + seed: Option, + ) -> Self { + self.inner + .clone() + .list() + .sample_fraction(fraction.inner, with_replacement, shuffle, seed) + .into() + } + #[cfg(feature = "list_take")] fn list_take(&self, index: PyExpr, null_on_oob: bool) -> Self { self.inner diff --git a/py-polars/tests/unit/namespaces/test_list.py b/py-polars/tests/unit/namespaces/test_list.py index 9c0eeb51e0630..89ceff4e6a0a1 100644 --- a/py-polars/tests/unit/namespaces/test_list.py +++ b/py-polars/tests/unit/namespaces/test_list.py @@ -179,6 +179,37 @@ def test_list_drop_nulls() -> None: assert_frame_equal(df, expected_df) +def test_list_sample() -> None: + s = pl.Series("values", [[1, 2, 3, None], [None, None], [1, 2], None]) + + expected_sample_n = pl.Series("values", [[3, 1], [None], [2], None]) + assert_series_equal( + s.list.sample(n=pl.Series([2, 1, 1, 1]), seed=1), expected_sample_n + ) + + expected_sample_frac = pl.Series("values", [[3, 1], [None], [1, 2], None]) + assert_series_equal( + s.list.sample(fraction=pl.Series([0.5, 0.5, 1.0, 0.3]), seed=1), + expected_sample_frac, + ) + + df = pl.DataFrame( + { + "values": [[1, 2, 3, None], [None, None], [3, 4]], + "n": [2, 1, 2], + "frac": [0.5, 0.5, 1.0], + } + ) + df = df.select( + sample_n=pl.col("values").list.sample(n=pl.col("n"), seed=1), + sample_frac=pl.col("values").list.sample(fraction=pl.col("frac"), seed=1), + ) + expected_df = pl.DataFrame( + {"sample_n": [[3, 1], [None], [3, 4]], "sample_frac": [[3, 1], [None], [3, 4]]} + ) + assert_frame_equal(df, expected_df) + + def test_list_diff() -> None: s = pl.Series("a", [[1, 2], [10, 2, 1]]) expected = pl.Series("a", [[None, 1], [None, -8, -1]])