From 5e7574b301bfe2793de13fd78b9d2c18816ea23d Mon Sep 17 00:00:00 2001 From: Jacob Trueb Date: Fri, 25 Aug 2023 11:27:26 -0500 Subject: [PATCH 1/2] fix(rust): Prevent panic on sample_n with replacement from empty df --- .../polars-core/src/chunked_array/random.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/crates/polars-core/src/chunked_array/random.rs b/crates/polars-core/src/chunked_array/random.rs index 046bc8fb5a08..8b2e0c0c9e67 100644 --- a/crates/polars-core/src/chunked_array/random.rs +++ b/crates/polars-core/src/chunked_array/random.rs @@ -9,6 +9,9 @@ use crate::random::get_global_random_u64; use crate::utils::{CustomIterTools, NoNull}; fn create_rand_index_with_replacement(n: usize, len: usize, seed: Option<u64>) -> IdxCa { + if len == 0 { + return NoNull::<IdxCa>::from_iter(std::iter::empty()).into_inner(); + } let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64)); let dist = Uniform::new(0, len as IdxSize); (0..n as IdxSize) @@ -257,6 +260,22 @@ impl BooleanChunked { mod test { use super::*; + #[test] + fn test_sample_empty_df() { + let df = df![ + "foo" => Vec::::new() + ] + .unwrap(); + + // If with replacement, then expect empty df + assert_eq!(df.sample_n(3, true, false, None).unwrap().height(), 0); + assert_eq!(df.sample_frac(0.4, true, false, None).unwrap().height(), 0); + + // If without replacement, then expect shape mismatch on sample_n not sample_frac + assert!(df.sample_n(3, false, false, None).is_err()); + assert_eq!(df.sample_frac(0.4, false, false, None).unwrap().height(), 0); + } + #[test] fn test_sample() { let df = df![ From 6296041d150dfab7e5bf0bdc3d8672c6a40770f5 Mon Sep 17 00:00:00 2001 From: Jacob Trueb Date: Sat, 26 Aug 2023 10:13:03 -0500 Subject: [PATCH 2/2] Move empty df test to py-polars and use concise empty IdxCa new --- crates/polars-core/src/chunked_array/random.rs | 18 +----------------- py-polars/tests/unit/operations/test_random.py | 13 +++++++++++++ 2 files changed, 14
insertions(+), 17 deletions(-) diff --git a/crates/polars-core/src/chunked_array/random.rs b/crates/polars-core/src/chunked_array/random.rs index 8b2e0c0c9e67..74ece90d2c3e 100644 --- a/crates/polars-core/src/chunked_array/random.rs +++ b/crates/polars-core/src/chunked_array/random.rs @@ -10,7 +10,7 @@ use crate::utils::{CustomIterTools, NoNull}; fn create_rand_index_with_replacement(n: usize, len: usize, seed: Option<u64>) -> IdxCa { if len == 0 { - return NoNull::<IdxCa>::from_iter(std::iter::empty()).into_inner(); + return IdxCa::new_vec("", vec![]); } let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64)); let dist = Uniform::new(0, len as IdxSize); @@ -260,22 +260,6 @@ impl BooleanChunked { mod test { use super::*; - #[test] - fn test_sample_empty_df() { - let df = df![ - "foo" => Vec::::new() - ] - .unwrap(); - - // If with replacement, then expect empty df - assert_eq!(df.sample_n(3, true, false, None).unwrap().height(), 0); - assert_eq!(df.sample_frac(0.4, true, false, None).unwrap().height(), 0); - - // If without replacement, then expect shape mismatch on sample_n not sample_frac - assert!(df.sample_n(3, false, false, None).is_err()); - assert_eq!(df.sample_frac(0.4, false, false, None).unwrap().height(), 0); - } - #[test] fn test_sample() { let df = df![ diff --git a/py-polars/tests/unit/operations/test_random.py b/py-polars/tests/unit/operations/test_random.py index a92dfbe69677..9b9954d8e79c 100644 --- a/py-polars/tests/unit/operations/test_random.py +++ b/py-polars/tests/unit/operations/test_random.py @@ -54,6 +54,19 @@ def test_sample_df() -> None: assert df.sample(fraction=0.4, seed=0).shape == (1, 3) +def test_sample_empty_df() -> None: + df = pl.DataFrame({"foo": []}) + + # // If with replacement, then expect empty df + assert df.sample(n=3, with_replacement=True).shape == (0, 1) + assert df.sample(fraction=0.4, with_replacement=True).shape == (0, 1) + + # // If without replacement, then expect shape mismatch on sample_n not
sample_frac + with pytest.raises(pl.ShapeError): + df.sample(n=3, with_replacement=False) + assert df.sample(fraction=0.4, with_replacement=False).shape == (0, 1) + + def test_sample_series() -> None: s = pl.Series("a", [1, 2, 3, 4, 5])