Skip to content

Commit

Permalink
Trap weighted index overflow (#1353)
Browse files Browse the repository at this point in the history
* WeightedIndex: add test overflow (expected to panic)

* WeightedIndex::new: trap overflow in release builds only

* Introduce trait Weight

* Update regarding nightly SIMD changes
  • Loading branch information
dhardy authored Dec 30, 2023
1 parent 3c2e82f commit e9a27a8
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 35 deletions.
2 changes: 1 addition & 1 deletion src/distributions/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::distributions::utils::{IntAsSIMD, FloatAsSIMD, FloatSIMDUtils};
use crate::distributions::{Distribution, Standard};
use crate::Rng;
use core::mem;
#[cfg(feature = "simd_support")] use core::simd::*;
#[cfg(feature = "simd_support")] use core::simd::prelude::*;

#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};
Expand Down
2 changes: 1 addition & 1 deletion src/distributions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ pub use self::slice::Slice;
#[doc(inline)]
pub use self::uniform::Uniform;
#[cfg(feature = "alloc")]
pub use self::weighted_index::{WeightedError, WeightedIndex};
pub use self::weighted_index::{Weight, WeightedError, WeightedIndex};

#[allow(unused)]
use crate::Rng;
Expand Down
6 changes: 4 additions & 2 deletions src/distributions/other.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ use crate::Rng;
use serde::{Serialize, Deserialize};
use core::mem::{self, MaybeUninit};
#[cfg(feature = "simd_support")]
use core::simd::*;
use core::simd::prelude::*;
#[cfg(feature = "simd_support")]
use core::simd::{LaneCount, MaskElement, SupportedLaneCount};


// ----- Sampling distributions -----
Expand Down Expand Up @@ -163,7 +165,7 @@ impl Distribution<bool> for Standard {
/// Since most bits are unused you could also generate only as many bits as you need, i.e.:
/// ```
/// #![feature(portable_simd)]
/// use std::simd::*;
/// use std::simd::prelude::*;
/// use rand::prelude::*;
/// let mut rng = thread_rng();
///
Expand Down
7 changes: 4 additions & 3 deletions src/distributions/uniform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ use crate::{Rng, RngCore};
#[allow(unused_imports)] // rustc doesn't detect that this is actually used
use crate::distributions::utils::Float;

#[cfg(feature = "simd_support")] use core::simd::*;
#[cfg(feature = "simd_support")] use core::simd::prelude::*;
#[cfg(feature = "simd_support")] use core::simd::{LaneCount, SupportedLaneCount};

/// Error type returned from [`Uniform::new`] and `new_inclusive`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
Expand Down Expand Up @@ -1433,7 +1434,7 @@ mod tests {
(-::core::$f_scalar::MAX * 0.2, ::core::$f_scalar::MAX * 0.7),
];
for &(low_scalar, high_scalar) in v.iter() {
for lane in 0..<$ty>::LANES {
for lane in 0..<$ty>::LEN {
let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar);
let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar);
let my_uniform = Uniform::new(low, high).unwrap();
Expand Down Expand Up @@ -1565,7 +1566,7 @@ mod tests {
(::std::$f_scalar::NEG_INFINITY, ::std::$f_scalar::INFINITY),
];
for &(low_scalar, high_scalar) in v.iter() {
for lane in 0..<$ty>::LANES {
for lane in 0..<$ty>::LEN {
let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar);
let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar);
assert!(catch_unwind(|| range(low, high)).is_err());
Expand Down
5 changes: 3 additions & 2 deletions src/distributions/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

//! Math helper functions

#[cfg(feature = "simd_support")] use core::simd::*;
#[cfg(feature = "simd_support")] use core::simd::prelude::*;
#[cfg(feature = "simd_support")] use core::simd::{LaneCount, SimdElement, SupportedLaneCount};


pub(crate) trait WideningMultiply<RHS = Self> {
Expand Down Expand Up @@ -245,7 +246,7 @@ pub(crate) trait Float: Sized {

/// Implement functions on f32/f64 to give them APIs similar to SIMD types
pub(crate) trait FloatAsSIMD: Sized {
const LANES: usize = 1;
const LEN: usize = 1;
#[inline(always)]
fn splat(scalar: Self) -> Self {
scalar
Expand Down
75 changes: 70 additions & 5 deletions src/distributions/weighted_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
where
I: IntoIterator,
I::Item: SampleBorrow<X>,
X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
X: Weight,
{
let mut iter = weights.into_iter();
let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone();

let zero = <X as Default>::default();
let zero = X::ZERO;
if !(total_weight >= zero) {
return Err(WeightedError::InvalidWeight);
}
Expand All @@ -117,7 +117,10 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
return Err(WeightedError::InvalidWeight);
}
weights.push(total_weight.clone());
total_weight += w.borrow();

if let Err(()) = total_weight.checked_add_assign(w.borrow()) {
return Err(WeightedError::Overflow);
}
}

if total_weight == zero {
Expand Down Expand Up @@ -236,6 +239,60 @@ where X: SampleUniform + PartialOrd
}
}

/// Bounds on a weight
///
/// See usage in [`WeightedIndex`].
pub trait Weight: Clone {
/// Representation of 0
const ZERO: Self;

/// Checked addition
///
/// - `Result::Ok`: On success, `v` is added to `self`
/// - `Result::Err`: Returns an error when `Self` cannot represent the
/// result of `self + v` (i.e. overflow). The value of `self` should be
/// discarded.
fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()>;
}

macro_rules! impl_weight_int {
($t:ty) => {
impl Weight for $t {
const ZERO: Self = 0;
fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> {
match self.checked_add(*v) {
Some(sum) => {
*self = sum;
Ok(())
}
None => Err(()),
}
}
}
};
($t:ty, $($tt:ty),*) => {
impl_weight_int!($t);
impl_weight_int!($($tt),*);
}
}
impl_weight_int!(i8, i16, i32, i64, i128, isize);
impl_weight_int!(u8, u16, u32, u64, u128, usize);

macro_rules! impl_weight_float {
($t:ty) => {
impl Weight for $t {
const ZERO: Self = 0.0;
fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> {
// Floats have an explicit representation for overflow
*self += *v;
Ok(())
}
}
}
}
impl_weight_float!(f32);
impl_weight_float!(f64);

#[cfg(test)]
mod test {
use super::*;
Expand Down Expand Up @@ -388,12 +445,11 @@ mod test {

#[test]
fn value_stability() {
fn test_samples<X: SampleUniform + PartialOrd, I>(
fn test_samples<X: Weight + SampleUniform + PartialOrd, I>(
weights: I, buf: &mut [usize], expected: &[usize],
) where
I: IntoIterator,
I::Item: SampleBorrow<X>,
X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
{
assert_eq!(buf.len(), expected.len());
let distr = WeightedIndex::new(weights).unwrap();
Expand All @@ -420,6 +476,11 @@ mod test {
fn weighted_index_distributions_can_be_compared() {
assert_eq!(WeightedIndex::new(&[1, 2]), WeightedIndex::new(&[1, 2]));
}

#[test]
fn overflow() {
assert_eq!(WeightedIndex::new([2, usize::MAX]), Err(WeightedError::Overflow));
}
}

/// Error type returned from `WeightedIndex::new`.
Expand All @@ -438,6 +499,9 @@ pub enum WeightedError {

/// Too many weights are provided (length greater than `u32::MAX`)
TooMany,

/// The sum of weights overflows
Overflow,
}

#[cfg(feature = "std")]
Expand All @@ -450,6 +514,7 @@ impl fmt::Display for WeightedError {
WeightedError::InvalidWeight => "A weight is invalid in distribution",
WeightedError::AllWeightsZero => "All weights are zero in distribution",
WeightedError::TooMany => "Too many weights (hit u32::MAX) in distribution",
WeightedError::Overflow => "The sum of weights overflowed",
})
}
}
26 changes: 5 additions & 21 deletions src/seq/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use alloc::vec::Vec;
#[cfg(feature = "alloc")]
use crate::distributions::uniform::{SampleBorrow, SampleUniform};
#[cfg(feature = "alloc")]
use crate::distributions::WeightedError;
use crate::distributions::{Weight, WeightedError};
use crate::Rng;

use self::coin_flipper::CoinFlipper;
Expand Down Expand Up @@ -170,11 +170,7 @@ pub trait SliceRandom {
R: Rng + ?Sized,
F: Fn(&Self::Item) -> B,
B: SampleBorrow<X>,
X: SampleUniform
+ for<'a> ::core::ops::AddAssign<&'a X>
+ ::core::cmp::PartialOrd<X>
+ Clone
+ Default;
X: SampleUniform + Weight + ::core::cmp::PartialOrd<X>;

/// Biased sampling for one element (mut)
///
Expand Down Expand Up @@ -203,11 +199,7 @@ pub trait SliceRandom {
R: Rng + ?Sized,
F: Fn(&Self::Item) -> B,
B: SampleBorrow<X>,
X: SampleUniform
+ for<'a> ::core::ops::AddAssign<&'a X>
+ ::core::cmp::PartialOrd<X>
+ Clone
+ Default;
X: SampleUniform + Weight + ::core::cmp::PartialOrd<X>;

/// Biased sampling of `amount` distinct elements
///
Expand Down Expand Up @@ -585,11 +577,7 @@ impl<T> SliceRandom for [T] {
R: Rng + ?Sized,
F: Fn(&Self::Item) -> B,
B: SampleBorrow<X>,
X: SampleUniform
+ for<'a> ::core::ops::AddAssign<&'a X>
+ ::core::cmp::PartialOrd<X>
+ Clone
+ Default,
X: SampleUniform + Weight + ::core::cmp::PartialOrd<X>,
{
use crate::distributions::{Distribution, WeightedIndex};
let distr = WeightedIndex::new(self.iter().map(weight))?;
Expand All @@ -604,11 +592,7 @@ impl<T> SliceRandom for [T] {
R: Rng + ?Sized,
F: Fn(&Self::Item) -> B,
B: SampleBorrow<X>,
X: SampleUniform
+ for<'a> ::core::ops::AddAssign<&'a X>
+ ::core::cmp::PartialOrd<X>
+ Clone
+ Default,
X: SampleUniform + Weight + ::core::cmp::PartialOrd<X>,
{
use crate::distributions::{Distribution, WeightedIndex};
let distr = WeightedIndex::new(self.iter().map(weight))?;
Expand Down

0 comments on commit e9a27a8

Please sign in to comment.