Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rand_distr::TruncatedNormal #1523

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions benches/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,7 @@ harness = false
[[bench]]
name = "weighted"
harness = false

[[bench]]
name = "truncnorm"
harness = false
83 changes: 83 additions & 0 deletions benches/benches/truncnorm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use criterion_cycles_per_byte::CyclesPerByte;

use rand::prelude::*;
use rand_distr::*;

// At this time, distributions are optimised for 64-bit platforms.
use rand_pcg::Pcg64Mcg;

struct TruncatedNormalByRejection {
normal: Normal<f64>,
a: f64,
b: f64,
}

impl TruncatedNormalByRejection {
fn new(mean: f64, std_dev: f64, a: f64, b: f64) -> Self {
Self {
normal: Normal::new(mean, std_dev).unwrap(),
a,
b,
}
}
}

impl Distribution<f64> for TruncatedNormalByRejection {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> f64 {
let mut value;
loop {
value = rng.sample(self.normal);
if value >= self.a && value <= self.b {
return value;
}
}
}
}

fn bench(c: &mut Criterion<CyclesPerByte>) {
let distr = TruncatedNormal::new(0., 1., f64::NEG_INFINITY, f64::INFINITY).unwrap();

let ranges = [
(1, f64::NEG_INFINITY, distr.ppf(0.01)),
(3, f64::NEG_INFINITY, distr.ppf(0.03)),
(5, f64::NEG_INFINITY, distr.ppf(0.05)),
(7, f64::NEG_INFINITY, distr.ppf(0.07)),
(10, f64::NEG_INFINITY, distr.ppf(0.1)),
(30, f64::NEG_INFINITY, distr.ppf(0.3)),
(50, f64::NEG_INFINITY, distr.ppf(0.5)),
(70, f64::NEG_INFINITY, distr.ppf(0.7)),
(100, f64::NEG_INFINITY, f64::INFINITY),
];

let mut g = c.benchmark_group("truncnorm by rejection");
for range in &ranges {
let mut rng = Pcg64Mcg::from_os_rng();
g.throughput(Throughput::Elements(range.0));
g.bench_with_input(BenchmarkId::from_parameter(range.0), range, |b, &range| {
let distr = TruncatedNormalByRejection::new(0.0, 1.0, range.1, range.2);
b.iter(|| std::hint::black_box(Distribution::<f64>::sample(&distr, &mut rng)));
});
}
g.finish();

let mut g = c.benchmark_group("truncnorm by ppf");
for range in &ranges {
let mut rng = Pcg64Mcg::from_os_rng();
g.throughput(Throughput::Elements(range.0));
g.bench_with_input(BenchmarkId::from_parameter(range.0), range, |b, &range| {
let distr = TruncatedNormal::new(0.0, 1.0, range.1, range.2).unwrap();
b.iter(|| std::hint::black_box(Distribution::<f64>::sample(&distr, &mut rng)));
});
}
g.finish();
}

criterion_group!(
name = benches;
config = Criterion::default().with_measurement(CyclesPerByte)
.warm_up_time(core::time::Duration::from_secs(1))
.measurement_time(core::time::Duration::from_secs(2));
targets = bench
);
criterion_main!(benches);
1 change: 1 addition & 0 deletions rand_distr/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Add plots for `rand_distr` distributions to documentation (#1434)
- Add `PertBuilder`, fix case where mode ≅ mean (#1452)
- Add `rand_distr::TruncatedNormal` (#1523)

## [0.5.0-alpha.1] - 2024-03-18
- Target `rand` version `0.9.0-alpha.1`
Expand Down
1 change: 1 addition & 0 deletions rand_distr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ rand = { path = "..", version = "=0.9.0-alpha.1", default-features = false }
num-traits = { version = "0.2", default-features = false, features = ["libm"] }
serde = { version = "1.0.103", features = ["derive"], optional = true }
serde_with = { version = ">= 3.0, <= 3.11", optional = true }
spec_math = "0.1.6"

[dev-dependencies]
rand_pcg = { version = "=0.9.0-alpha.1", path = "../rand_pcg" }
Expand Down
2 changes: 1 addition & 1 deletion rand_distr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ pub use self::geometric::{Error as GeoError, Geometric, StandardGeometric};
pub use self::gumbel::{Error as GumbelError, Gumbel};
pub use self::hypergeometric::{Error as HyperGeoError, Hypergeometric};
pub use self::inverse_gaussian::{Error as InverseGaussianError, InverseGaussian};
pub use self::normal::{Error as NormalError, LogNormal, Normal, StandardNormal};
pub use self::normal::{Error as NormalError, LogNormal, Normal, StandardNormal, TruncatedNormal};
pub use self::normal_inverse_gaussian::{
Error as NormalInverseGaussianError, NormalInverseGaussian,
};
Expand Down
192 changes: 191 additions & 1 deletion rand_distr/src/normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
use crate::utils::ziggurat;
use crate::{ziggurat_tables, Distribution, Open01};
use core::fmt;
use num_traits::Float;
use num_traits::{cast, Float, FloatConst};
use rand::distr::uniform::Uniform;
use rand::Rng;

use spec_math::cephes64::{ndtr, ndtri};

/// The standard Normal distribution `N(0, 1)`.
///
/// This is equivalent to `Normal::new(0.0, 1.0)`, but faster.
Expand Down Expand Up @@ -160,13 +163,16 @@ pub enum Error {
MeanTooSmall,
/// The standard deviation or other dispersion parameter is not finite.
BadVariance,
/// The left bound is greater than or equal to the right bound.
InvalidInterval,
}

impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
Error::MeanTooSmall => "mean < 0 or NaN in log-normal distribution",
Error::BadVariance => "variation parameter is non-finite in (log)normal distribution",
Error::InvalidInterval => "the left bound must be less than right bound",
})
}
}
Expand Down Expand Up @@ -363,6 +369,144 @@ where
}
}

/// The [Truncated normal distribution](https://en.wikipedia.org/wiki/Truncated_normal_distribution)
///
/// # Example
///
/// ```
/// use rand_distr::{TruncatedNormal, Distribution};
///
/// let truncnorm = TruncatedNormal::new(0., 1., -1.0, 2.0).unwrap();
/// let val = truncnorm.sample(&mut rand::thread_rng());
/// println!("{}", val);
/// ```
///
/// # Notes
///
/// This implementation is ported from [`scipy.stats.truncnorm`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.truncnorm.html),
/// which is based on [Cephes Mathematical Library](https://www.netlib.org/cephes/).
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TruncatedNormal<F>
where
F: Float + FloatConst,
{
mean: F,
std_dev: F,
a: F,
b: F,
uniform: Uniform<f64>,
}

impl<F> TruncatedNormal<F>
where
F: Float + FloatConst,
{
/// Construct, from mean and standard deviation
///
/// Parameters:
///
/// - mean (`μ`, unrestricted)
/// - standard deviation (`σ`, must be finite)
/// - left bound (`a`)
/// - right bound (`b`)
#[inline]
pub fn new(mean: F, std_dev: F, a: F, b: F) -> Result<TruncatedNormal<F>, Error> {
if !std_dev.is_finite() {
return Err(Error::BadVariance);
}
if a >= b {
return Err(Error::InvalidInterval);
}
Ok(TruncatedNormal {
mean,
std_dev,
a,
b,
uniform: Uniform::new(0., 1.).unwrap(),
})
}

/// Returns the mean (`μ`) of the distribution.
pub fn mean(&self) -> F {
self.mean
}

/// Returns the standard deviation (`σ`) of the distribution.
pub fn std_dev(&self) -> F {
self.std_dev
}

/// Cumulative Distribution Function
pub fn cdf(&self, x: F) -> F {
cast(ndtr(cast(x).unwrap())).unwrap()
}

/// Inverse Cumulative Distribution Function
pub fn icdf(&self, x: F) -> F {
cast(ndtri(cast(x).unwrap())).unwrap()
}

/// Percent Point Function
/// based on `scipy.stats.truncnorm` with modifications
pub fn ppf(&self, q: F) -> F {
// logsumexp trick for log(p + q) with only log(p) and log(q)
let log_sum_exp = |log_p: F, log_q: F| -> F {
let max = log_p.max(log_q);
((log_p - max).exp() + (log_q - max).exp()).ln() + max
};

// Log diff for log(p - q) and insuring that the difference is not negative
let log_diff_exp = |log_p: F, log_q: F| -> F {
let max = log_p.max(log_q);
((log_p - max).exp() - (log_q - max).exp()).abs().ln() + max
};

// Log of Gaussian probability mass within an interval
let log_gauss_mass = |a: F, b: F| -> F {
if b <= F::zero() {
log_diff_exp(self.cdf(b).ln(), self.cdf(a).ln())
} else if a > F::zero() {
// Calculations in right tail are inaccurate, so we'll exploit the
// symmetry and work only in the left tail
log_diff_exp(self.cdf(-b).ln(), self.cdf(-a).ln())
} else {
// Catastrophic cancellation occurs as exp(log_mass) approaches 1.
// Correct for this with an alternative formulation.
// We're not concerned with underflow here: if only one term
// underflows, it was insignificant; if both terms underflow,
// the result can't accurately be represented in logspace anyway
// because log1p(x) ~ x for small x.
(-self.cdf(a) - self.cdf(-b)).ln_1p()
}
};

if self.a < F::zero() {
let log_phi_x = log_sum_exp(
self.cdf(self.a).ln(),
q.ln() + log_gauss_mass(self.a, self.b),
);
self.icdf(log_phi_x.exp())
} else {
let log_phi_x = log_sum_exp(
self.cdf(-self.b).ln(),
(-q).ln_1p() + log_gauss_mass(self.a, self.b),
);
-self.icdf(log_phi_x.exp())
}
}
}

impl<F> Distribution<F> for TruncatedNormal<F>
where
F: Float + FloatConst,
{
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
self.mean + self.std_dev * self.ppf(cast(self.uniform.sample(rng)).unwrap())
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -429,4 +573,50 @@ mod tests {
fn log_normal_distributions_can_be_compared() {
assert_eq!(LogNormal::new(1.0, 2.0), LogNormal::new(1.0, 2.0));
}

#[test]
fn test_truncated_normal_pos() {
let truncnorm = TruncatedNormal::new(0., 1., 1., 2.).unwrap();
let mut rng = crate::test::rng(212);
let mut integral = 0.;
for _ in 0..1000 {
integral += truncnorm.sample(&mut rng);
}
// According to the result from:
// https://www.wolframalpha.com/input?i=integral+e%5E%28-%28x%5E2%29%2F2%29%2F%28sqrt%282+%CF%80%29%29+from+1+to+2
//
// integral e^(-(x^2)/2)/(sqrt(2 π)) from 1 to 2 ≈ 0.135905
//
// The error of the integral result by 1000 samples from TruncatedNormal is below 3%
assert_almost_eq!(integral, 1359.05, 1359.05 * 0.03);
}

#[test]
fn test_truncated_normal_neg() {
let truncnorm = TruncatedNormal::new(0., 1., -2., -1.).unwrap();
let mut rng = crate::test::rng(212);
let mut integral = 0.;
for _ in 0..1000 {
integral += truncnorm.sample(&mut rng);
}
// Mirror case of the `test_truncated_normal_pos`
assert_almost_eq!(integral, -1359.05, 1359.05 * 0.03);
}

#[test]
fn test_truncated_normal_across() {
let truncnorm = TruncatedNormal::new(0., 1., -1., 1.).unwrap();
let mut rng = crate::test::rng(212);
let mut integral = 0.;
for _ in 0..1000 {
integral += truncnorm.sample(&mut rng);
}
// Symmetry case, the sum of result is almost equal to zero.
assert_almost_eq!(integral, 0., 1359.05 * 0.03);
}

#[test]
fn test_truncated_normal_invalid_bounds() {
assert!(TruncatedNormal::new(0., 1., 2., 1.).is_err());
}
}