Skip to content

Commit

Permalink
feat(rust, python): Introduce list.sample
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa committed Oct 19, 2023
1 parent d24c508 commit 6fac63b
Show file tree
Hide file tree
Showing 19 changed files with 394 additions and 1 deletion.
46 changes: 46 additions & 0 deletions crates/polars-core/src/chunked_array/list/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,52 @@ impl ListChunked {
out
}

pub fn try_zip_and_apply_amortized<'a, T, I, F>(
&'a self,
ca: &'a ChunkedArray<T>,
mut f: F,
) -> PolarsResult<Self>
where
T: PolarsDataType,
&'a ChunkedArray<T>: IntoIterator<IntoIter = I>,
I: TrustedLen<Item = Option<T::Physical<'a>>>,
F: FnMut(
Option<UnstableSeries<'a>>,
Option<T::Physical<'a>>,
) -> PolarsResult<Option<Series>>,
{
if self.is_empty() {
return Ok(self.clone());
}
let mut fast_explode = self.null_count() == 0;
// SAFETY: unstable series never lives longer than the iterator.
let mut out: ListChunked = unsafe {
self.amortized_iter()
.zip(ca)
.map(|(opt_s, opt_v)| {
let out = f(opt_s, opt_v)?;
match out {
Some(out) if out.is_empty() => {
fast_explode = false;
Ok(Some(out))
},
None => {
fast_explode = false;
Ok(out)
},
_ => Ok(out),
}
})
.collect::<PolarsResult<_>>()?
};

out.rename(self.name());
if fast_explode {
out.set_fast_explode();
}
Ok(out)
}

/// Apply a closure `F` elementwise.
#[must_use]
pub fn apply_amortized<'a, F>(&'a self, mut f: F) -> Self
Expand Down
1 change: 1 addition & 0 deletions crates/polars-lazy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ fused = ["polars-plan/fused", "polars-ops/fused"]
list_sets = ["polars-plan/list_sets", "polars-ops/list_sets"]
list_any_all = ["polars-ops/list_any_all", "polars-plan/list_any_all"]
list_drop_nulls = ["polars-ops/list_drop_nulls", "polars-plan/list_drop_nulls"]
list_sample = ["polars-ops/list_sample", "polars-plan/list_sample"]
cutqcut = ["polars-plan/cutqcut", "polars-ops/cutqcut"]
rle = ["polars-plan/rle", "polars-ops/rle"]
extract_groups = ["polars-plan/extract_groups"]
Expand Down
1 change: 1 addition & 0 deletions crates/polars-lazy/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod arity;
mod cse;
#[cfg(feature = "parquet")]
mod io;
mod lazy_test;
mod logical;
mod optimization_checks;
mod predicate_queries;
Expand Down
1 change: 1 addition & 0 deletions crates/polars-ops/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ list_take = []
list_sets = []
list_any_all = []
list_drop_nulls = []
list_sample = []
extract_groups = ["dtype-struct", "polars-core/regex"]
is_in = ["polars-core/reinterpret"]
convert_index = []
Expand Down
81 changes: 81 additions & 0 deletions crates/polars-ops/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,87 @@ pub trait ListNameSpaceImpl: AsList {
list_ca.apply_amortized(|s| s.as_ref().drop_nulls())
}

#[cfg(feature = "list_sample")]
fn lst_sample_n(
&self,
n: &Series,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> PolarsResult<ListChunked> {
let ca = self.as_list();

let n_s = n.cast(&IDX_DTYPE)?;
let n = n_s.idx()?;

let out = match n.len() {
1 => {
if let Some(n) = n.get(0) {
ca.try_apply_amortized(|s| {
s.as_ref()
.sample_n(n as usize, with_replacement, shuffle, seed)
})
} else {
Ok(ListChunked::full_null_with_dtype(
ca.name(),
ca.len(),
&ca.inner_dtype(),
))
}
},
_ => ca.try_zip_and_apply_amortized(n, |opt_s, opt_n| match (opt_s, opt_n) {
(Some(s), Some(n)) => s
.as_ref()
.sample_n(n as usize, with_replacement, shuffle, seed)
.map(Some),
_ => Ok(None),
}),
};
out.map(|ok| self.same_type(ok))
}

#[cfg(feature = "list_sample")]
fn lst_sample_fraction(
&self,
fraction: &Series,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> PolarsResult<ListChunked> {
let ca = self.as_list();

let fraction_s = fraction.cast(&DataType::Float64)?;
let fraction = fraction_s.f64()?;

let out = match fraction.len() {
1 => {
if let Some(fraction) = fraction.get(0) {
ca.try_apply_amortized(|s| {
let n = (s.as_ref().len() as f64 * fraction) as usize;
s.as_ref()
.sample_n(n as usize, with_replacement, shuffle, seed)
})
} else {
Ok(ListChunked::full_null_with_dtype(
ca.name(),
ca.len(),
&ca.inner_dtype(),
))
}
},
_ => ca.try_zip_and_apply_amortized(fraction, |opt_s, opt_n| match (opt_s, opt_n) {
(Some(s), Some(fraction)) => {
let n = (s.as_ref().len() as f64 * fraction) as usize;
s.as_ref()
.sample_n(n, with_replacement, shuffle, seed)
.map(Some)
},
_ => Ok(None),
}),
};
out.map(|ok| self.same_type(ok))
}

fn lst_concat(&self, other: &[Series]) -> PolarsResult<ListChunked> {
let ca = self.as_list();
let other_len = other.len();
Expand Down
1 change: 1 addition & 0 deletions crates/polars-plan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ fused = ["polars-ops/fused"]
list_sets = ["polars-ops/list_sets"]
list_any_all = ["polars-ops/list_any_all"]
list_drop_nulls = ["polars-ops/list_drop_nulls"]
list_sample = ["polars-ops/list_sample"]
cutqcut = ["polars-ops/cutqcut"]
rle = ["polars-ops/rle"]
extract_groups = ["regex", "dtype-struct", "polars-ops/extract_groups"]
Expand Down
41 changes: 41 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ pub enum ListFunction {
Contains,
#[cfg(feature = "list_drop_nulls")]
DropNulls,
#[cfg(feature = "list_sample")]
Sample {
is_fraction: bool,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
},
Slice,
Shift,
Get,
Expand Down Expand Up @@ -52,6 +59,14 @@ impl Display for ListFunction {
Contains => "contains",
#[cfg(feature = "list_drop_nulls")]
DropNulls => "drop_nulls",
#[cfg(feature = "list_sample")]
Sample { is_fraction, .. } => {
if *is_fraction {
"sample_fraction"
} else {
"sample_n"
}
},
Slice => "slice",
Shift => "shift",
Get => "get",
Expand Down Expand Up @@ -107,6 +122,32 @@ pub(super) fn drop_nulls(s: &Series) -> PolarsResult<Series> {
Ok(list.lst_drop_nulls().into_series())
}

#[cfg(feature = "list_sample")]
pub(super) fn sample_n(
s: &[Series],
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> PolarsResult<Series> {
let list = s[0].list()?;
let n = &s[1];
list.lst_sample_n(n, with_replacement, shuffle, seed)
.map(|ok| ok.into_series())
}

#[cfg(feature = "list_sample")]
pub(super) fn sample_fraction(
s: &[Series],
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> PolarsResult<Series> {
let list = s[0].list()?;
let fraction = &s[1];
list.lst_sample_fraction(fraction, with_replacement, shuffle, seed)
.map(|ok| ok.into_series())
}

fn check_slice_arg_shape(slice_len: usize, ca_len: usize, name: &str) -> PolarsResult<()> {
polars_ensure!(
slice_len == ca_len,
Expand Down
13 changes: 13 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,19 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
Contains => wrap!(list::contains),
#[cfg(feature = "list_drop_nulls")]
DropNulls => map!(list::drop_nulls),
#[cfg(feature = "list_sample")]
Sample {
is_fraction,
with_replacement,
shuffle,
seed,
} => {
if is_fraction {
map_as_slice!(list::sample_fraction, with_replacement, shuffle, seed)
} else {
map_as_slice!(list::sample_n, with_replacement, shuffle, seed)
}
},
Slice => wrap!(list::slice),
Shift => map_as_slice!(list::shift),
Get => wrap!(list::get),
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ impl FunctionExpr {
Contains => mapper.with_dtype(DataType::Boolean),
#[cfg(feature = "list_drop_nulls")]
DropNulls => mapper.with_same_dtype(),
#[cfg(feature = "list_sample")]
Sample { .. } => mapper.with_same_dtype(),
Slice => mapper.with_same_dtype(),
Shift => mapper.with_same_dtype(),
Get => mapper.map_to_list_inner_dtype(),
Expand Down
42 changes: 42 additions & 0 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,48 @@ impl ListNameSpace {
.map_private(FunctionExpr::ListExpr(ListFunction::DropNulls))
}

#[cfg(feature = "list_sample")]
pub fn sample_n(
self,
n: Expr,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> Expr {
self.0.map_many_private(
FunctionExpr::ListExpr(ListFunction::Sample {
is_fraction: false,
with_replacement,
shuffle,
seed,
}),
&[n],
false,
false,
)
}

#[cfg(feature = "list_sample")]
pub fn sample_fraction(
self,
fraction: Expr,
with_replacement: bool,
shuffle: bool,
seed: Option<u64>,
) -> Expr {
self.0.map_many_private(
FunctionExpr::ListExpr(ListFunction::Sample {
is_fraction: true,
with_replacement,
shuffle,
seed,
}),
&[fraction],
false,
false,
)
}

/// Return the number of elements in each list.
///
/// Null values are treated like regular elements in this context.
Expand Down
1 change: 1 addition & 0 deletions crates/polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ fused = ["polars-ops/fused", "polars-lazy?/fused"]
list_sets = ["polars-lazy?/list_sets"]
list_any_all = ["polars-lazy?/list_any_all"]
list_drop_nulls = ["polars-lazy?/list_drop_nulls"]
list_sample = ["polars-lazy?/list_sample"]
cutqcut = ["polars-lazy?/cutqcut"]
rle = ["polars-lazy?/rle"]
extract_groups = ["polars-lazy?/extract_groups"]
Expand Down
2 changes: 2 additions & 0 deletions py-polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ binary_encoding = ["polars/binary_encoding"]
list_sets = ["polars-lazy/list_sets"]
list_any_all = ["polars/list_any_all"]
list_drop_nulls = ["polars/list_drop_nulls"]
list_sample = ["polars/list_sample"]
cutqcut = ["polars/cutqcut"]
rle = ["polars/rle"]
extract_groups = ["polars/extract_groups"]
Expand Down Expand Up @@ -165,6 +166,7 @@ operations = [
"list_sets",
"list_any_all",
"list_drop_nulls",
"list_sample",
"cutqcut",
"rle",
"extract_groups",
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/expressions/list.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ The following methods are available under the `expr.list` attribute.
Expr.list.mean
Expr.list.min
Expr.list.reverse
Expr.list.sample
Expr.list.set_difference
Expr.list.set_intersection
Expr.list.set_symmetric_difference
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/series/list.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ The following methods are available under the `Series.list` attribute.
Series.list.mean
Series.list.min
Series.list.reverse
Series.list.sample
Series.list.set_difference
Series.list.set_intersection
Series.list.set_symmetric_difference
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8244,7 +8244,7 @@ def shuffle(self, seed: int | None = None) -> Self:

def sample(
self,
n: int | Expr | None = None,
n: int | IntoExprColumn | None = None,
*,
fraction: float | None = None,
with_replacement: bool = False,
Expand Down
Loading

0 comments on commit 6fac63b

Please sign in to comment.