Skip to content

Commit

Permalink
Move empty df test to py-polars and use concise empty IdxCa new
Browse files Browse the repository at this point in the history
  • Loading branch information
trueb2 committed Aug 26, 2023
1 parent 5e7574b commit bcc9619
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 17 deletions.
18 changes: 1 addition & 17 deletions crates/polars-core/src/chunked_array/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::utils::{CustomIterTools, NoNull};

fn create_rand_index_with_replacement(n: usize, len: usize, seed: Option<u64>) -> IdxCa {
if len == 0 {
return NoNull::<IdxCa>::from_iter(std::iter::empty()).into_inner();
return IdxCa::new_vec("", vec![]);
}
let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_global_random_u64));
let dist = Uniform::new(0, len as IdxSize);
Expand Down Expand Up @@ -260,22 +260,6 @@ impl BooleanChunked {
mod test {
use super::*;

#[test]
fn test_sample_empty_df() {
let df = df![
"foo" => Vec::<i32>::new()
]
.unwrap();

// If with replacement, then expect empty df
assert_eq!(df.sample_n(3, true, false, None).unwrap().height(), 0);
assert_eq!(df.sample_frac(0.4, true, false, None).unwrap().height(), 0);

// If without replacement, then expect shape mismatch on sample_n not sample_frac
assert!(df.sample_n(3, false, false, None).is_err());
assert_eq!(df.sample_frac(0.4, false, false, None).unwrap().height(), 0);
}

#[test]
fn test_sample() {
let df = df![
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/unit/operations/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@ def test_sample_df() -> None:
assert df.sample(n=2, seed=0).shape == (2, 3)
assert df.sample(fraction=0.4, seed=0).shape == (1, 3)

def test_sample_empty_df() -> None:
df = pl.DataFrame({"foo": []})

# If sampling with replacement, expect an empty df
assert df.sample(n=3, with_replacement=True).shape == (0, 1)
assert df.sample(fraction=0.4, with_replacement=True).shape == (0, 1)

# If sampling without replacement, expect a ShapeError for sample(n=...) but an empty df for sample(fraction=...)
with pytest.raises(pl.ShapeError):
df.sample(n=3, with_replacement=False)
assert df.sample(fraction=0.4, with_replacement=False).shape == (0, 1)


def test_sample_series() -> None:
s = pl.Series("a", [1, 2, 3, 4, 5])
Expand Down

0 comments on commit bcc9619

Please sign in to comment.