Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hypergeo fix #1510

Merged
merged 10 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rand_distr/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Move some of the computations in Binomial from `sample` to `new` (#1484)
- Add Kolmogorov Smirnov test for sampling of `Normal` and `Binomial` (#1494)
- Add Kolmogorov Smirnov test for more distributions (#1504)
- Fix bug in `Hypergeometric`, this is a Value-breaking change (#1510)

### Added
- Add plots for `rand_distr` distributions to documentation (#1434)
Expand Down
21 changes: 19 additions & 2 deletions rand_distr/src/hypergeometric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,17 @@ fn fraction_of_products_of_factorials(numerator: (u64, u64), denominator: (u64,
result
}

const LOGSQRT2PI: f64 = 0.91893853320467274178; // log(sqrt(2*pi))

fn ln_of_factorial(v: f64) -> f64 {
// the paper calls for ln(v!), but also wants to pass in fractions,
// so we need to use Stirling's approximation to fill in the gaps:
v * v.ln() - v

// shift v by 3, because Stirling is bad for small values
let v_3 = v + 3.0;
let ln_fac = (v_3 + 0.5) * v_3.ln() - v_3 + LOGSQRT2PI + 1.0 / (12.0 * v_3);
// make the correction for the shift
ln_fac - ((v + 3.0) * (v + 2.0) * (v + 1.0)).ln()
}

impl Hypergeometric {
Expand Down Expand Up @@ -359,7 +366,7 @@ impl Distribution<u64> for Hypergeometric {
} else {
for i in (y as u64 + 1)..=(m as u64) {
f *= i as f64 * (n2 - k + i) as f64;
f /= (n1 - i) as f64 * (k - i) as f64;
f /= (n1 - i + 1) as f64 * (k - i + 1) as f64;
}
}

Expand Down Expand Up @@ -441,6 +448,7 @@ impl Distribution<u64> for Hypergeometric {

#[cfg(test)]
mod test {

use super::*;

#[test]
Expand Down Expand Up @@ -494,4 +502,13 @@ mod test {
fn hypergeometric_distributions_can_be_compared() {
assert_eq!(Hypergeometric::new(1, 2, 3), Hypergeometric::new(1, 2, 3));
}

#[test]
fn stirling() {
let test = [0.5, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
for &v in test.iter() {
let ln_fac = ln_of_factorial(v);
assert!((special::Gamma::ln_gamma(v + 1.0).0 - ln_fac).abs() < 1e-4);
}
}
}
2 changes: 1 addition & 1 deletion rand_distr/tests/cdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ fn hypergeometric() {
(60, 10, 7),
(70, 20, 50),
(100, 50, 10),
// (100, 50, 49), // Fail case
(100, 50, 49),
];

for (seed, (n, k, n_)) in parameters.into_iter().enumerate() {
Expand Down
2 changes: 1 addition & 1 deletion rand_distr/tests/value_stability.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ fn hypergeometric_stability() {
test_samples(
7221,
Hypergeometric::new(100, 50, 50).unwrap(),
&[23, 27, 26, 27, 22, 24, 31, 22],
&[23, 27, 26, 27, 22, 25, 31, 25],
benjamin-lieser marked this conversation as resolved.
Show resolved Hide resolved
); // Algorithm H2PE
}

Expand Down