diff --git a/crates/polars-core/src/chunked_array/random.rs b/crates/polars-core/src/chunked_array/random.rs index 99641dc0b2a9..195c955d879f 100644 --- a/crates/polars-core/src/chunked_array/random.rs +++ b/crates/polars-core/src/chunked_array/random.rs @@ -168,6 +168,34 @@ where impl DataFrame { /// Sample n datapoints from this [`DataFrame`]. pub fn sample_n( + &self, + n: &Series, + with_replacement: bool, + shuffle: bool, + seed: Option, + ) -> PolarsResult { + polars_ensure!( + n.len() == 1, + ComputeError: "Sample size must be a single value." + ); + + let n = n.cast(&IDX_DTYPE)?; + let n = n.idx()?; + + match n.get(0) { + Some(n) => self.sample_n_literal(n as usize, with_replacement, shuffle, seed), + None => { + let new_cols = self + .columns + .iter() + .map(|c| Series::new_empty(c.name(), c.dtype())) + .collect_trusted(); + Ok(DataFrame::new_no_checks(new_cols)) + }, + } + } + + pub fn sample_n_literal( &self, n: usize, with_replacement: bool, @@ -194,7 +222,7 @@ impl DataFrame { seed: Option, ) -> PolarsResult { let n = (self.height() as f64 * frac) as usize; - self.sample_n(n, with_replacement, shuffle, seed) + self.sample_n_literal(n, with_replacement, shuffle, seed) } } @@ -268,14 +296,20 @@ mod test { .unwrap(); // default samples are random and don't require seeds - assert!(df.sample_n(3, false, false, None).is_ok()); + assert!(df + .sample_n(&Series::new("s", &[3]), false, false, None) + .is_ok()); assert!(df.sample_frac(0.4, false, false, None).is_ok()); // with seeding - assert!(df.sample_n(3, false, false, Some(0)).is_ok()); + assert!(df + .sample_n(&Series::new("s", &[3]), false, false, Some(0)) + .is_ok()); assert!(df.sample_frac(0.4, false, false, Some(0)).is_ok()); // without replacement can not sample more than 100% assert!(df.sample_frac(2.0, false, false, Some(0)).is_err()); - assert!(df.sample_n(3, true, false, Some(0)).is_ok()); + assert!(df + .sample_n(&Series::new("s", &[3]), true, false, Some(0)) + .is_ok()); assert!(df.sample_frac(0.4, true, false, Some(0)).is_ok()); // with replacement can sample more than 100% assert!(df.sample_frac(2.0, true, false, Some(0)).is_ok()); diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index f6b668442b16..42f5e3ceee21 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -679,7 +679,21 @@ impl From for SpecialEq> { RLEID => map!(rle_id), ToPhysical => map!(dispatch::to_physical), #[cfg(feature = "random")] - Random { method, seed } => map!(random::random, method, seed), + Random { method, seed } => { + use RandomMethod::*; + match method { + Shuffle => map!(random::shuffle, seed), + SampleFrac { + frac, + with_replacement, + shuffle, + } => map!(random::sample_frac, frac, with_replacement, shuffle, seed), + SampleN { + with_replacement, + shuffle, + } => map_as_slice!(random::sample_n, with_replacement, shuffle, seed), + } + }, SetSortedFlag(sorted) => map!(dispatch::set_sorted_flag, sorted), #[cfg(feature = "ffi_plugin")] FfiPlugin { lib, symbol, .. } => unsafe { diff --git a/crates/polars-plan/src/dsl/function_expr/random.rs b/crates/polars-plan/src/dsl/function_expr/random.rs index cb8b7c586915..9555671abaf0 100644 --- a/crates/polars-plan/src/dsl/function_expr/random.rs +++ b/crates/polars-plan/src/dsl/function_expr/random.rs @@ -10,7 +10,6 @@ use super::*; pub enum RandomMethod { Shuffle, SampleN { - n: usize, with_replacement: bool, shuffle: bool, }, @@ -27,18 +26,39 @@ impl Hash for RandomMethod { } } -pub(super) fn random(s: &Series, method: RandomMethod, seed: Option) -> PolarsResult { - match method { - RandomMethod::Shuffle => Ok(s.shuffle(seed)), - RandomMethod::SampleFrac { - frac, - with_replacement, - shuffle, - } => s.sample_frac(frac, with_replacement, shuffle, seed), - RandomMethod::SampleN { - n, - with_replacement, - shuffle, - } => s.sample_n(n, with_replacement, shuffle, seed), +pub(super) fn shuffle(s: &Series, seed: Option) -> PolarsResult { + Ok(s.shuffle(seed)) +} + +pub(super) fn sample_frac( + s: &Series, + frac: f64, + with_replacement: bool, + shuffle: bool, + seed: Option, +) -> PolarsResult { + s.sample_frac(frac, with_replacement, shuffle, seed) +} + +pub(super) fn sample_n( + s: &[Series], + with_replacement: bool, + shuffle: bool, + seed: Option, +) -> PolarsResult { + let src = &s[0]; + let n_s = &s[1]; + + polars_ensure!( + n_s.len() == 1, + ComputeError: "Sample size must be a single value." + ); + + let n_s = n_s.cast(&IDX_DTYPE)?; + let n = n_s.idx()?; + + match n.get(0) { + Some(n) => src.sample_n(n as usize, with_replacement, shuffle, seed), + None => Ok(Series::new_empty(src.name(), src.dtype())), } } diff --git a/crates/polars-plan/src/dsl/random.rs b/crates/polars-plan/src/dsl/random.rs index 8c1f9e1c683b..efd36a15a86c 100644 --- a/crates/polars-plan/src/dsl/random.rs +++ b/crates/polars-plan/src/dsl/random.rs @@ -10,19 +10,23 @@ impl Expr { pub fn sample_n( self, - n: usize, + n: Expr, with_replacement: bool, shuffle: bool, seed: Option, ) -> Self { - self.apply_private(FunctionExpr::Random { - method: RandomMethod::SampleN { - n, - with_replacement, - shuffle, + self.apply_many_private( + FunctionExpr::Random { + method: RandomMethod::SampleN { + with_replacement, + shuffle, + }, + seed, }, - seed, - }) + &[n], + false, + false, + ) } pub fn sample_frac( diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index d13f4b62ccd3..c0ef8427cf15 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -8678,7 +8678,7 @@ def null_count(self) -> Self: def sample( self, - n: int | None = None, + n: int | Series | None = None, *, fraction: float | None = None, with_replacement: bool = False, @@ -8739,7 +8739,11 @@ def sample( if n is None: n = 1 - return self._from_pydf(self._df.sample_n(n, with_replacement, shuffle, seed)) + + if not isinstance(n, pl.Series): + n = pl.Series("", [n]) + + return self._from_pydf(self._df.sample_n(n._s, with_replacement, shuffle, seed)) def fold(self, operation: Callable[[Series, Series], Series]) -> Series: """ diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 5d74ab9bfe15..c9d35301df39 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -8053,7 +8053,7 @@ def shuffle(self, seed: int | None = None) -> Self: def sample( self, - n: int | None = None, + n: int | Expr | None = None, *, fraction: float | None = None, with_replacement: bool = False, @@ -8104,6 +8104,7 @@ def sample( if n is None: n = 1 + n = parse_as_expression(n) return self._from_pyexpr( self._pyexpr.sample_n(n, with_replacement, shuffle, seed) ) diff --git a/py-polars/src/dataframe.rs b/py-polars/src/dataframe.rs index d6736fce0200..1b39b274cf2f 100644 --- a/py-polars/src/dataframe.rs +++ b/py-polars/src/dataframe.rs @@ -886,14 +886,14 @@ impl PyDataFrame { pub fn sample_n( &self, - n: usize, + n: &PySeries, with_replacement: bool, shuffle: bool, seed: Option, ) -> PyResult { let df = self .df - .sample_n(n, with_replacement, shuffle, seed) + .sample_n(&n.series, with_replacement, shuffle, seed) .map_err(PyPolarsErr::from)?; Ok(df.into()) } diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs index 6f4efebd0179..8fb35d267eef 100644 --- a/py-polars/src/expr/general.rs +++ b/py-polars/src/expr/general.rs @@ -735,10 +735,10 @@ impl PyExpr { } #[pyo3(signature = (n, with_replacement, shuffle, seed))] - fn sample_n(&self, n: usize, with_replacement: bool, shuffle: bool, seed: Option) -> Self { + fn sample_n(&self, n: Self, with_replacement: bool, shuffle: bool, seed: Option) -> Self { self.inner .clone() - .sample_n(n, with_replacement, shuffle, seed) + .sample_n(n.inner, with_replacement, shuffle, seed) .into() } diff --git a/py-polars/tests/unit/operations/test_random.py b/py-polars/tests/unit/operations/test_random.py index 9c7307dd3820..d9099e5f4f34 100644 --- a/py-polars/tests/unit/operations/test_random.py +++ b/py-polars/tests/unit/operations/test_random.py @@ -54,6 +54,29 @@ def test_sample_df() -> None: assert df.sample(fraction=0.4, seed=0).shape == (1, 3) +def test_sample_n_expr() -> None: + df = pl.DataFrame( + { + "group": [1, 1, 1, 2, 2, 2], + "val": [1, 2, 3, 2, 1, 1], + } + ) + + out_df = df.sample(pl.Series([3]), seed=0) + expected_df = pl.DataFrame({"group": [1, 1, 2], "val": [1, 2, 1]}) + assert_frame_equal(out_df, expected_df) + + agg_df = df.group_by("group", maintain_order=True).agg( + pl.col("val").sample(pl.col("val").max(), seed=0) + ) + expected_df = pl.DataFrame({"group": [1, 2], "val": [[1, 2, 3], [2, 1]]}) + assert_frame_equal(agg_df, expected_df) + + select_df = df.select(pl.col("val").sample(pl.col("val").max(), seed=0)) + expected_df = pl.DataFrame({"val": [1, 2, 1]}) + assert_frame_equal(select_df, expected_df) + + def test_sample_empty_df() -> None: df = pl.DataFrame({"foo": []})