From bcc96198430c20a986d0e02281773451b02b4883 Mon Sep 17 00:00:00 2001 From: Jacob Trueb Date: Sat, 26 Aug 2023 10:13:03 -0500 Subject: [PATCH] Move empty df test to py-polars and use concise empty IdxCa new --- crates/polars-core/src/chunked_array/random.rs | 18 +----------------- py-polars/tests/unit/operations/test_random.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/crates/polars-core/src/chunked_array/random.rs b/crates/polars-core/src/chunked_array/random.rs index 8b2e0c0c9e672..74ece90d2c3e4 100644 --- a/crates/polars-core/src/chunked_array/random.rs +++ b/crates/polars-core/src/chunked_array/random.rs @@ -10,7 +10,7 @@ use crate::utils::{CustomIterTools, NoNull}; fn create_rand_index_with_replacement(n: usize, len: usize, seed: Option) -> IdxCa { if len == 0 { - return NoNull::::from_iter(std::iter::empty()).into_inner(); + return IdxCa::new_vec("", vec![]); } let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64)); let dist = Uniform::new(0, len as IdxSize); @@ -260,22 +260,6 @@ impl BooleanChunked { mod test { use super::*; - #[test] - fn test_sample_empty_df() { - let df = df![ - "foo" => Vec::::new() - ] - .unwrap(); - - // If with replacement, then expect empty df - assert_eq!(df.sample_n(3, true, false, None).unwrap().height(), 0); - assert_eq!(df.sample_frac(0.4, true, false, None).unwrap().height(), 0); - - // If without replacement, then expect shape mismatch on sample_n not sample_frac - assert!(df.sample_n(3, false, false, None).is_err()); - assert_eq!(df.sample_frac(0.4, false, false, None).unwrap().height(), 0); - } - #[test] fn test_sample() { let df = df![ diff --git a/py-polars/tests/unit/operations/test_random.py b/py-polars/tests/unit/operations/test_random.py index a92dfbe696772..479ec51b07f01 100644 --- a/py-polars/tests/unit/operations/test_random.py +++ b/py-polars/tests/unit/operations/test_random.py @@ -53,6 +53,18 @@ def test_sample_df() -> None: assert df.sample(n=2, seed=0).shape == (2, 3) assert df.sample(fraction=0.4, seed=0).shape == (1, 3) +def test_sample_empty_df() -> None: + df = pl.DataFrame({"foo": []}) + + # // If with replacement, then expect empty df + assert df.sample(n=3, with_replacement=True).shape == (0, 1) + assert df.sample(fraction=0.4, with_replacement=True).shape == (0, 1) + + # // If without replacement, then expect shape mismatch on sample_n not sample_frac + with pytest.raises(pl.ShapeError): + df.sample(n=3, with_replacement=False) + assert df.sample(fraction=0.4, with_replacement=False).shape == (0, 1) + def test_sample_series() -> None: s = pl.Series("a", [1, 2, 3, 4, 5])