Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trap weighted index overflow #1353

Merged
merged 4 commits into from
Dec 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/distributions/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::distributions::utils::{IntAsSIMD, FloatAsSIMD, FloatSIMDUtils};
use crate::distributions::{Distribution, Standard};
use crate::Rng;
use core::mem;
#[cfg(feature = "simd_support")] use core::simd::*;
#[cfg(feature = "simd_support")] use core::simd::prelude::*;

#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};
Expand Down
2 changes: 1 addition & 1 deletion src/distributions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ pub use self::slice::Slice;
#[doc(inline)]
pub use self::uniform::Uniform;
#[cfg(feature = "alloc")]
pub use self::weighted_index::{WeightedError, WeightedIndex};
pub use self::weighted_index::{Weight, WeightedError, WeightedIndex};

#[allow(unused)]
use crate::Rng;
Expand Down
6 changes: 4 additions & 2 deletions src/distributions/other.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ use crate::Rng;
use serde::{Serialize, Deserialize};
use core::mem::{self, MaybeUninit};
#[cfg(feature = "simd_support")]
use core::simd::*;
use core::simd::prelude::*;
#[cfg(feature = "simd_support")]
use core::simd::{LaneCount, MaskElement, SupportedLaneCount};


// ----- Sampling distributions -----
Expand Down Expand Up @@ -163,7 +165,7 @@ impl Distribution<bool> for Standard {
/// Since most bits are unused you could also generate only as many bits as you need, i.e.:
/// ```
/// #![feature(portable_simd)]
/// use std::simd::*;
/// use std::simd::prelude::*;
/// use rand::prelude::*;
/// let mut rng = thread_rng();
///
Expand Down
7 changes: 4 additions & 3 deletions src/distributions/uniform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ use crate::{Rng, RngCore};
#[allow(unused_imports)] // rustc doesn't detect that this is actually used
use crate::distributions::utils::Float;

#[cfg(feature = "simd_support")] use core::simd::*;
#[cfg(feature = "simd_support")] use core::simd::prelude::*;
#[cfg(feature = "simd_support")] use core::simd::{LaneCount, SupportedLaneCount};

/// Error type returned from [`Uniform::new`] and `new_inclusive`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
Expand Down Expand Up @@ -1433,7 +1434,7 @@ mod tests {
(-::core::$f_scalar::MAX * 0.2, ::core::$f_scalar::MAX * 0.7),
];
for &(low_scalar, high_scalar) in v.iter() {
for lane in 0..<$ty>::LANES {
for lane in 0..<$ty>::LEN {
let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar);
let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar);
let my_uniform = Uniform::new(low, high).unwrap();
Expand Down Expand Up @@ -1565,7 +1566,7 @@ mod tests {
(::std::$f_scalar::NEG_INFINITY, ::std::$f_scalar::INFINITY),
];
for &(low_scalar, high_scalar) in v.iter() {
for lane in 0..<$ty>::LANES {
for lane in 0..<$ty>::LEN {
let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar);
let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar);
assert!(catch_unwind(|| range(low, high)).is_err());
Expand Down
5 changes: 3 additions & 2 deletions src/distributions/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

//! Math helper functions

#[cfg(feature = "simd_support")] use core::simd::*;
#[cfg(feature = "simd_support")] use core::simd::prelude::*;
#[cfg(feature = "simd_support")] use core::simd::{LaneCount, SimdElement, SupportedLaneCount};


pub(crate) trait WideningMultiply<RHS = Self> {
Expand Down Expand Up @@ -245,7 +246,7 @@ pub(crate) trait Float: Sized {

/// Implement functions on f32/f64 to give them APIs similar to SIMD types
pub(crate) trait FloatAsSIMD: Sized {
const LANES: usize = 1;
const LEN: usize = 1;
#[inline(always)]
fn splat(scalar: Self) -> Self {
scalar
Expand Down
75 changes: 70 additions & 5 deletions src/distributions/weighted_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
where
I: IntoIterator,
I::Item: SampleBorrow<X>,
X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
X: Weight,
{
let mut iter = weights.into_iter();
let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone();

let zero = <X as Default>::default();
let zero = X::ZERO;
if !(total_weight >= zero) {
return Err(WeightedError::InvalidWeight);
}
Expand All @@ -117,7 +117,10 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
return Err(WeightedError::InvalidWeight);
}
weights.push(total_weight.clone());
total_weight += w.borrow();

if let Err(()) = total_weight.checked_add_assign(w.borrow()) {
return Err(WeightedError::Overflow);
}
}

if total_weight == zero {
Expand Down Expand Up @@ -236,6 +239,60 @@ where X: SampleUniform + PartialOrd
}
}

/// Bounds on a weight
///
/// See usage in [`WeightedIndex`].
/// Bounds on a weight
///
/// A weight type must be cloneable, expose a zero value ([`Weight::ZERO`])
/// and support an addition that reports overflow to the caller
/// ([`Weight::checked_add_assign`]).
///
/// Implementations are provided in this module for the primitive integer
/// types (`i8`..`i128`, `isize`, `u8`..`u128`, `usize`) and for `f32`/`f64`.
///
/// See usage in [`WeightedIndex`].
pub trait Weight: Clone {
/// Representation of 0
///
/// `WeightedIndex::new` compares weights and the accumulated total against
/// this value.
const ZERO: Self;

/// Checked addition
///
/// - `Result::Ok`: On success, `v` is added to `self`
/// - `Result::Err`: Returns an error when `Self` cannot represent the
/// result of `self + v` (i.e. overflow). The value of `self` should be
/// discarded.
///
/// `WeightedIndex::new` turns an `Err` from this method into
/// `WeightedError::Overflow`.
///
/// The integer implementations delegate to the inherent `checked_add`;
/// the floating-point implementations add unconditionally and always
/// return `Ok(())`.
fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()>;
Comment on lines +249 to +255
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only thing that concerns me slightly is that this is a fairly complex signature that, as you say, we can't easily change later.

Possibly the standard checked_add signature would be better, but this one is closer to what we actually need (considering we don't have a Copy bound).

}

/// Implements `Weight` for each primitive integer type in a comma-separated
/// list (one or more types, no trailing comma).
///
/// `ZERO` is the literal `0`, and `checked_add_assign` defers to the type's
/// inherent `checked_add`: on success the sum is stored into `self`, on
/// overflow `Err(())` is returned and `self` is left as it was.
macro_rules! impl_weight_int {
    ($($int:ty),+) => {
        $(
            impl Weight for $int {
                const ZERO: Self = 0;

                fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> {
                    // `checked_add` yields `None` on overflow; bail out
                    // before touching `self` in that case.
                    let sum = self.checked_add(*v).ok_or(())?;
                    *self = sum;
                    Ok(())
                }
            }
        )+
    };
}
impl_weight_int!(i8, i16, i32, i64, i128, isize);
impl_weight_int!(u8, u16, u32, u64, u128, usize);

/// Implements `Weight` for a single floating-point type.
///
/// `ZERO` is the literal `0.0`. `checked_add_assign` performs an ordinary
/// addition and always returns `Ok(())`.
macro_rules! impl_weight_float {
    ($float:ty) => {
        impl Weight for $float {
            const ZERO: Self = 0.0;

            fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> {
                // Floats have an explicit representation for overflow, so
                // the sum is stored as-is and no error is ever reported.
                let sum = *self + *v;
                *self = sum;
                Ok(())
            }
        }
    };
}
impl_weight_float!(f32);
impl_weight_float!(f64);

#[cfg(test)]
mod test {
use super::*;
Expand Down Expand Up @@ -388,12 +445,11 @@ mod test {

#[test]
fn value_stability() {
fn test_samples<X: SampleUniform + PartialOrd, I>(
fn test_samples<X: Weight + SampleUniform + PartialOrd, I>(
weights: I, buf: &mut [usize], expected: &[usize],
) where
I: IntoIterator,
I::Item: SampleBorrow<X>,
X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
{
assert_eq!(buf.len(), expected.len());
let distr = WeightedIndex::new(weights).unwrap();
Expand All @@ -420,6 +476,11 @@ mod test {
fn weighted_index_distributions_can_be_compared() {
assert_eq!(WeightedIndex::new(&[1, 2]), WeightedIndex::new(&[1, 2]));
}

#[test]
fn overflow() {
    // 2 + usize::MAX is not representable in usize, so summing the weights
    // must be reported as an overflow error.
    let result = WeightedIndex::new([2, usize::MAX]);
    assert_eq!(result, Err(WeightedError::Overflow));
}
}

/// Error type returned from `WeightedIndex::new`.
Expand All @@ -438,6 +499,9 @@ pub enum WeightedError {

/// Too many weights are provided (length greater than `u32::MAX`)
TooMany,

/// The sum of weights overflows
Overflow,
}

#[cfg(feature = "std")]
Expand All @@ -450,6 +514,7 @@ impl fmt::Display for WeightedError {
WeightedError::InvalidWeight => "A weight is invalid in distribution",
WeightedError::AllWeightsZero => "All weights are zero in distribution",
WeightedError::TooMany => "Too many weights (hit u32::MAX) in distribution",
WeightedError::Overflow => "The sum of weights overflowed",
})
}
}
26 changes: 5 additions & 21 deletions src/seq/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use alloc::vec::Vec;
#[cfg(feature = "alloc")]
use crate::distributions::uniform::{SampleBorrow, SampleUniform};
#[cfg(feature = "alloc")]
use crate::distributions::WeightedError;
use crate::distributions::{Weight, WeightedError};
use crate::Rng;

use self::coin_flipper::CoinFlipper;
Expand Down Expand Up @@ -170,11 +170,7 @@ pub trait SliceRandom {
R: Rng + ?Sized,
F: Fn(&Self::Item) -> B,
B: SampleBorrow<X>,
X: SampleUniform
+ for<'a> ::core::ops::AddAssign<&'a X>
+ ::core::cmp::PartialOrd<X>
+ Clone
+ Default;
X: SampleUniform + Weight + ::core::cmp::PartialOrd<X>;

/// Biased sampling for one element (mut)
///
Expand Down Expand Up @@ -203,11 +199,7 @@ pub trait SliceRandom {
R: Rng + ?Sized,
F: Fn(&Self::Item) -> B,
B: SampleBorrow<X>,
X: SampleUniform
+ for<'a> ::core::ops::AddAssign<&'a X>
+ ::core::cmp::PartialOrd<X>
+ Clone
+ Default;
X: SampleUniform + Weight + ::core::cmp::PartialOrd<X>;

/// Biased sampling of `amount` distinct elements
///
Expand Down Expand Up @@ -585,11 +577,7 @@ impl<T> SliceRandom for [T] {
R: Rng + ?Sized,
F: Fn(&Self::Item) -> B,
B: SampleBorrow<X>,
X: SampleUniform
+ for<'a> ::core::ops::AddAssign<&'a X>
+ ::core::cmp::PartialOrd<X>
+ Clone
+ Default,
X: SampleUniform + Weight + ::core::cmp::PartialOrd<X>,
{
use crate::distributions::{Distribution, WeightedIndex};
let distr = WeightedIndex::new(self.iter().map(weight))?;
Expand All @@ -604,11 +592,7 @@ impl<T> SliceRandom for [T] {
R: Rng + ?Sized,
F: Fn(&Self::Item) -> B,
B: SampleBorrow<X>,
X: SampleUniform
+ for<'a> ::core::ops::AddAssign<&'a X>
+ ::core::cmp::PartialOrd<X>
+ Clone
+ Default,
X: SampleUniform + Weight + ::core::cmp::PartialOrd<X>,
{
use crate::distributions::{Distribution, WeightedIndex};
let distr = WeightedIndex::new(self.iter().map(weight))?;
Expand Down