Skip to content

Commit

Permalink
feat: sample n can take an expr (#11257)
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa authored Sep 26, 2023
1 parent ddc6a1b commit 6cd7633
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 34 deletions.
42 changes: 38 additions & 4 deletions crates/polars-core/src/chunked_array/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,34 @@ where
impl DataFrame {
/// Sample `n` datapoints from this [`DataFrame`].
///
/// `n` is a [`Series`] that must hold exactly one value. It is cast to the
/// index dtype and its first element is used as the sample size, which is
/// then forwarded to `sample_n_literal`.
///
/// When that first element is missing (`get(0)` yields `None` — presumably a
/// null sample size), a frame with zero rows is returned that keeps the name
/// and dtype of every column.
///
/// # Errors
///
/// Returns a `ComputeError` when `n` does not contain exactly one value, and
/// propagates any error raised by the cast, the index accessor or the
/// sampling itself.
pub fn sample_n(
    &self,
    n: &Series,
    with_replacement: bool,
    shuffle: bool,
    seed: Option<u64>,
) -> PolarsResult<Self> {
    polars_ensure!(
        n.len() == 1,
        ComputeError: "Sample size must be a single value."
    );

    let size = n.cast(&IDX_DTYPE)?;
    let size = size.idx()?;

    if let Some(count) = size.get(0) {
        return self.sample_n_literal(count as usize, with_replacement, shuffle, seed);
    }

    // No usable sample size: hand back an empty frame with the same schema.
    let empty_cols = self
        .columns
        .iter()
        .map(|col| Series::new_empty(col.name(), col.dtype()))
        .collect_trusted();
    Ok(DataFrame::new_no_checks(empty_cols))
}

pub fn sample_n_literal(
&self,
n: usize,
with_replacement: bool,
Expand All @@ -194,7 +222,7 @@ impl DataFrame {
seed: Option<u64>,
) -> PolarsResult<Self> {
let n = (self.height() as f64 * frac) as usize;
self.sample_n(n, with_replacement, shuffle, seed)
self.sample_n_literal(n, with_replacement, shuffle, seed)
}
}

Expand Down Expand Up @@ -268,14 +296,20 @@ mod test {
.unwrap();

// default samples are random and don't require seeds
assert!(df.sample_n(3, false, false, None).is_ok());
assert!(df
.sample_n(&Series::new("s", &[3]), false, false, None)
.is_ok());
assert!(df.sample_frac(0.4, false, false, None).is_ok());
// with seeding
assert!(df.sample_n(3, false, false, Some(0)).is_ok());
assert!(df
.sample_n(&Series::new("s", &[3]), false, false, Some(0))
.is_ok());
assert!(df.sample_frac(0.4, false, false, Some(0)).is_ok());
// without replacement can not sample more than 100%
assert!(df.sample_frac(2.0, false, false, Some(0)).is_err());
assert!(df.sample_n(3, true, false, Some(0)).is_ok());
assert!(df
.sample_n(&Series::new("s", &[3]), true, false, Some(0))
.is_ok());
assert!(df.sample_frac(0.4, true, false, Some(0)).is_ok());
// with replacement can sample more than 100%
assert!(df.sample_frac(2.0, true, false, Some(0)).is_ok());
Expand Down
16 changes: 15 additions & 1 deletion crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,21 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
RLEID => map!(rle_id),
ToPhysical => map!(dispatch::to_physical),
#[cfg(feature = "random")]
Random { method, seed } => map!(random::random, method, seed),
Random { method, seed } => {
use RandomMethod::*;
match method {
Shuffle => map!(random::shuffle, seed),
SampleFrac {
frac,
with_replacement,
shuffle,
} => map!(random::sample_frac, frac, with_replacement, shuffle, seed),
SampleN {
with_replacement,
shuffle,
} => map_as_slice!(random::sample_n, with_replacement, shuffle, seed),
}
},
SetSortedFlag(sorted) => map!(dispatch::set_sorted_flag, sorted),
#[cfg(feature = "ffi_plugin")]
FfiPlugin { lib, symbol, .. } => unsafe {
Expand Down
48 changes: 34 additions & 14 deletions crates/polars-plan/src/dsl/function_expr/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use super::*;
pub enum RandomMethod {
Shuffle,
SampleN {
n: usize,
with_replacement: bool,
shuffle: bool,
},
Expand All @@ -27,18 +26,39 @@ impl Hash for RandomMethod {
}
}

pub(super) fn random(s: &Series, method: RandomMethod, seed: Option<u64>) -> PolarsResult<Series> {
match method {
RandomMethod::Shuffle => Ok(s.shuffle(seed)),
RandomMethod::SampleFrac {
frac,
with_replacement,
shuffle,
} => s.sample_frac(frac, with_replacement, shuffle, seed),
RandomMethod::SampleN {
n,
with_replacement,
shuffle,
} => s.sample_n(n, with_replacement, shuffle, seed),
/// Shuffle `s`, forwarding the optional `seed` to the series' own `shuffle`.
///
/// This function itself never produces an error; the `PolarsResult` return
/// type presumably exists so it fits the function-expression dispatch
/// (NOTE(review): confirm against the `map!` macro).
pub(super) fn shuffle(s: &Series, seed: Option<u64>) -> PolarsResult<Series> {
    let shuffled = s.shuffle(seed);
    Ok(shuffled)
}

/// Sample a fraction of `s` by delegating to the series' own `sample_frac`.
///
/// All arguments (`frac`, `with_replacement`, `shuffle`, `seed`) are passed
/// through unchanged, and the result of the delegate — success or error — is
/// returned as-is.
pub(super) fn sample_frac(
    s: &Series,
    frac: f64,
    with_replacement: bool,
    shuffle: bool,
    seed: Option<u64>,
) -> PolarsResult<Series> {
    let sampled = s.sample_frac(frac, with_replacement, shuffle, seed);
    sampled
}

/// Sample values from a series, with the sample size supplied as a second series.
///
/// `s[0]` is the series to sample from and `s[1]` carries the sample size.
/// The size series must hold exactly one value; it is cast to the index dtype
/// and its first element is used as `n`.
///
/// When that first element is missing (`get(0)` yields `None` — presumably a
/// null sample size), an empty series with the source's name and dtype is
/// returned.
///
/// # Errors
///
/// Returns a `ComputeError` when the size series does not contain exactly one
/// value, and propagates errors from the cast, the index accessor and the
/// sampling itself.
///
/// # Panics
///
/// Panics on the slice indexing if `s` has fewer than two elements.
pub(super) fn sample_n(
    s: &[Series],
    with_replacement: bool,
    shuffle: bool,
    seed: Option<u64>,
) -> PolarsResult<Series> {
    let (src, size) = (&s[0], &s[1]);

    polars_ensure!(
        size.len() == 1,
        ComputeError: "Sample size must be a single value."
    );

    let size = size.cast(&IDX_DTYPE)?;
    let size = size.idx()?;

    if let Some(count) = size.get(0) {
        src.sample_n(count as usize, with_replacement, shuffle, seed)
    } else {
        // No usable sample size: return an empty series of the same dtype.
        Ok(Series::new_empty(src.name(), src.dtype()))
    }
}
20 changes: 12 additions & 8 deletions crates/polars-plan/src/dsl/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,23 @@ impl Expr {

pub fn sample_n(
self,
n: usize,
n: Expr,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> Self {
self.apply_private(FunctionExpr::Random {
method: RandomMethod::SampleN {
n,
with_replacement,
shuffle,
self.apply_many_private(
FunctionExpr::Random {
method: RandomMethod::SampleN {
with_replacement,
shuffle,
},
seed,
},
seed,
})
&[n],
false,
false,
)
}

pub fn sample_frac(
Expand Down
8 changes: 6 additions & 2 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -8678,7 +8678,7 @@ def null_count(self) -> Self:

def sample(
self,
n: int | None = None,
n: int | Series | None = None,
*,
fraction: float | None = None,
with_replacement: bool = False,
Expand Down Expand Up @@ -8739,7 +8739,11 @@ def sample(

if n is None:
n = 1
return self._from_pydf(self._df.sample_n(n, with_replacement, shuffle, seed))

if not isinstance(n, pl.Series):
n = pl.Series("", [n])

return self._from_pydf(self._df.sample_n(n._s, with_replacement, shuffle, seed))

def fold(self, operation: Callable[[Series, Series], Series]) -> Series:
"""
Expand Down
3 changes: 2 additions & 1 deletion py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8053,7 +8053,7 @@ def shuffle(self, seed: int | None = None) -> Self:

def sample(
self,
n: int | None = None,
n: int | Expr | None = None,
*,
fraction: float | None = None,
with_replacement: bool = False,
Expand Down Expand Up @@ -8104,6 +8104,7 @@ def sample(

if n is None:
n = 1
n = parse_as_expression(n)
return self._from_pyexpr(
self._pyexpr.sample_n(n, with_replacement, shuffle, seed)
)
Expand Down
4 changes: 2 additions & 2 deletions py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -886,14 +886,14 @@ impl PyDataFrame {

pub fn sample_n(
&self,
n: usize,
n: &PySeries,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> PyResult<Self> {
let df = self
.df
.sample_n(n, with_replacement, shuffle, seed)
.sample_n(&n.series, with_replacement, shuffle, seed)
.map_err(PyPolarsErr::from)?;
Ok(df.into())
}
Expand Down
4 changes: 2 additions & 2 deletions py-polars/src/expr/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -735,10 +735,10 @@ impl PyExpr {
}

#[pyo3(signature = (n, with_replacement, shuffle, seed))]
fn sample_n(&self, n: usize, with_replacement: bool, shuffle: bool, seed: Option<u64>) -> Self {
fn sample_n(&self, n: Self, with_replacement: bool, shuffle: bool, seed: Option<u64>) -> Self {
self.inner
.clone()
.sample_n(n, with_replacement, shuffle, seed)
.sample_n(n.inner, with_replacement, shuffle, seed)
.into()
}

Expand Down
23 changes: 23 additions & 0 deletions py-polars/tests/unit/operations/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,29 @@ def test_sample_df() -> None:
assert df.sample(fraction=0.4, seed=0).shape == (1, 3)


def test_sample_n_expr() -> None:
    """Check that the sample size can be given as a Series or as an expression."""
    frame = pl.DataFrame(
        {
            "group": [1, 1, 1, 2, 2, 2],
            "val": [1, 2, 3, 2, 1, 1],
        }
    )

    # DataFrame.sample with the size passed as a single-value Series.
    sampled = frame.sample(pl.Series([3]), seed=0)
    assert_frame_equal(
        sampled, pl.DataFrame({"group": [1, 1, 2], "val": [1, 2, 1]})
    )

    # Expr.sample with the size computed by another expression.
    sample_max = pl.col("val").sample(pl.col("val").max(), seed=0)

    # Evaluated per group inside an aggregation.
    per_group = frame.group_by("group", maintain_order=True).agg(sample_max)
    assert_frame_equal(
        per_group, pl.DataFrame({"group": [1, 2], "val": [[1, 2, 3], [2, 1]]})
    )

    # Evaluated over the whole column in a select.
    selected = frame.select(sample_max)
    assert_frame_equal(selected, pl.DataFrame({"val": [1, 2, 1]}))


def test_sample_empty_df() -> None:
df = pl.DataFrame({"foo": []})

Expand Down

0 comments on commit 6cd7633

Please sign in to comment.