diff --git a/crates/polars-core/src/chunked_array/random.rs b/crates/polars-core/src/chunked_array/random.rs index 22081708bd1a6..195c955d879fe 100644 --- a/crates/polars-core/src/chunked_array/random.rs +++ b/crates/polars-core/src/chunked_array/random.rs @@ -174,14 +174,14 @@ impl DataFrame { shuffle: bool, seed: Option, ) -> PolarsResult { - let n = n.cast(&IDX_DTYPE)?; - let n = n.idx()?; - polars_ensure!( n.len() == 1, ComputeError: "Sample size must be a single value." ); + let n = n.cast(&IDX_DTYPE)?; + let n = n.idx()?; + match n.get(0) { Some(n) => self.sample_n_literal(n as usize, with_replacement, shuffle, seed), None => { diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 9cfb82d6e3002..5172ae503ba97 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -8739,7 +8739,9 @@ def sample( if n is None: n = 1 - n = _prepare_other_arg(n, 1) + + if not isinstance(n, pl.Series): + n = pl.Series("", [n]) return self._from_pydf(self._df.sample_n(n._s, with_replacement, shuffle, seed)) diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index ced8ef629f81e..bfe08022039a8 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -8038,7 +8038,7 @@ def shuffle(self, seed: int | None = None) -> Self: def sample( self, - n: IntoExpr = None, + n: int | Expr | None = None, *, fraction: float | None = None, with_replacement: bool = False,