diff --git a/src/distributions/float.rs b/src/distributions/float.rs index d4a1075726..37b71612a1 100644 --- a/src/distributions/float.rs +++ b/src/distributions/float.rs @@ -12,7 +12,7 @@ use crate::distributions::utils::{IntAsSIMD, FloatAsSIMD, FloatSIMDUtils}; use crate::distributions::{Distribution, Standard}; use crate::Rng; use core::mem; -#[cfg(feature = "simd_support")] use core::simd::*; +#[cfg(feature = "simd_support")] use core::simd::prelude::*; #[cfg(feature = "serde1")] use serde::{Serialize, Deserialize}; diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs index a923f879d2..5adb82f811 100644 --- a/src/distributions/mod.rs +++ b/src/distributions/mod.rs @@ -126,7 +126,7 @@ pub use self::slice::Slice; #[doc(inline)] pub use self::uniform::Uniform; #[cfg(feature = "alloc")] -pub use self::weighted_index::{WeightedError, WeightedIndex}; +pub use self::weighted_index::{Weight, WeightedError, WeightedIndex}; #[allow(unused)] use crate::Rng; diff --git a/src/distributions/other.rs b/src/distributions/other.rs index 3ecf00492c..ebe3d57ed3 100644 --- a/src/distributions/other.rs +++ b/src/distributions/other.rs @@ -22,7 +22,9 @@ use crate::Rng; use serde::{Serialize, Deserialize}; use core::mem::{self, MaybeUninit}; #[cfg(feature = "simd_support")] -use core::simd::*; +use core::simd::prelude::*; +#[cfg(feature = "simd_support")] +use core::simd::{LaneCount, MaskElement, SupportedLaneCount}; // ----- Sampling distributions ----- @@ -163,7 +165,7 @@ impl Distribution for Standard { /// Since most bits are unused you could also generate only as many bits as you need, i.e.: /// ``` /// #![feature(portable_simd)] -/// use std::simd::*; +/// use std::simd::prelude::*; /// use rand::prelude::*; /// let mut rng = thread_rng(); /// diff --git a/src/distributions/uniform.rs b/src/distributions/uniform.rs index ace3a6b43e..879d8ea258 100644 --- a/src/distributions/uniform.rs +++ b/src/distributions/uniform.rs @@ -119,7 +119,8 @@ use crate::{Rng, RngCore}; #[allow(unused_imports)] // rustc doesn't detect that this is actually used use crate::distributions::utils::Float; -#[cfg(feature = "simd_support")] use core::simd::*; +#[cfg(feature = "simd_support")] use core::simd::prelude::*; +#[cfg(feature = "simd_support")] use core::simd::{LaneCount, SupportedLaneCount}; /// Error type returned from [`Uniform::new`] and `new_inclusive`. #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -1433,7 +1434,7 @@ mod tests { (-::core::$f_scalar::MAX * 0.2, ::core::$f_scalar::MAX * 0.7), ]; for &(low_scalar, high_scalar) in v.iter() { - for lane in 0..<$ty>::LANES { + for lane in 0..<$ty>::LEN { let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar); let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar); let my_uniform = Uniform::new(low, high).unwrap(); @@ -1565,7 +1566,7 @@ mod tests { (::std::$f_scalar::NEG_INFINITY, ::std::$f_scalar::INFINITY), ]; for &(low_scalar, high_scalar) in v.iter() { - for lane in 0..<$ty>::LANES { + for lane in 0..<$ty>::LEN { let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar); let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar); assert!(catch_unwind(|| range(low, high)).is_err()); diff --git a/src/distributions/utils.rs b/src/distributions/utils.rs index bddb0a4a59..f3b3089d7e 100644 --- a/src/distributions/utils.rs +++ b/src/distributions/utils.rs @@ -8,7 +8,8 @@ //! Math helper functions -#[cfg(feature = "simd_support")] use core::simd::*; +#[cfg(feature = "simd_support")] use core::simd::prelude::*; +#[cfg(feature = "simd_support")] use core::simd::{LaneCount, SimdElement, SupportedLaneCount}; pub(crate) trait WideningMultiply { @@ -245,7 +246,7 @@ pub(crate) trait Float: Sized { /// Implement functions on f32/f64 to give them APIs similar to SIMD types pub(crate) trait FloatAsSIMD: Sized { - const LANES: usize = 1; + const LEN: usize = 1; #[inline(always)] fn splat(scalar: Self) -> Self { scalar diff --git a/src/distributions/weighted_index.rs b/src/distributions/weighted_index.rs index 4c57edc5f6..de3628b5ea 100644 --- a/src/distributions/weighted_index.rs +++ b/src/distributions/weighted_index.rs @@ -99,12 +99,12 @@ impl WeightedIndex { where I: IntoIterator, I::Item: SampleBorrow, - X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default, + X: Weight, { let mut iter = weights.into_iter(); let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone(); - let zero = ::default(); + let zero = X::ZERO; if !(total_weight >= zero) { return Err(WeightedError::InvalidWeight); } @@ -117,7 +117,10 @@ impl WeightedIndex { return Err(WeightedError::InvalidWeight); } weights.push(total_weight.clone()); - total_weight += w.borrow(); + + if let Err(()) = total_weight.checked_add_assign(w.borrow()) { + return Err(WeightedError::Overflow); + } } if total_weight == zero { @@ -236,6 +239,60 @@ where X: SampleUniform + PartialOrd } } +/// Bounds on a weight +/// +/// See usage in [`WeightedIndex`]. +pub trait Weight: Clone { + /// Representation of 0 + const ZERO: Self; + + /// Checked addition + /// + /// - `Result::Ok`: On success, `v` is added to `self` + /// - `Result::Err`: Returns an error when `Self` cannot represent the + /// result of `self + v` (i.e. overflow). The value of `self` should be + /// discarded. + fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()>; +} + +macro_rules! impl_weight_int { + ($t:ty) => { + impl Weight for $t { + const ZERO: Self = 0; + fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> { + match self.checked_add(*v) { + Some(sum) => { + *self = sum; + Ok(()) + } + None => Err(()), + } + } + } + }; + ($t:ty, $($tt:ty),*) => { + impl_weight_int!($t); + impl_weight_int!($($tt),*); + } +} +impl_weight_int!(i8, i16, i32, i64, i128, isize); +impl_weight_int!(u8, u16, u32, u64, u128, usize); + +macro_rules! impl_weight_float { + ($t:ty) => { + impl Weight for $t { + const ZERO: Self = 0.0; + fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> { + // Floats have an explicit representation for overflow + *self += *v; + Ok(()) + } + } + } +} +impl_weight_float!(f32); +impl_weight_float!(f64); + #[cfg(test)] mod test { use super::*; @@ -388,12 +445,11 @@ mod test { #[test] fn value_stability() { - fn test_samples( + fn test_samples( weights: I, buf: &mut [usize], expected: &[usize], ) where I: IntoIterator, I::Item: SampleBorrow, - X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default, { assert_eq!(buf.len(), expected.len()); let distr = WeightedIndex::new(weights).unwrap(); @@ -420,6 +476,11 @@ mod test { fn weighted_index_distributions_can_be_compared() { assert_eq!(WeightedIndex::new(&[1, 2]), WeightedIndex::new(&[1, 2])); } + + #[test] + fn overflow() { + assert_eq!(WeightedIndex::new([2, usize::MAX]), Err(WeightedError::Overflow)); + } } /// Error type returned from `WeightedIndex::new`. @@ -438,6 +499,9 @@ pub enum WeightedError { /// Too many weights are provided (length greater than `u32::MAX`) TooMany, + + /// The sum of weights overflows + Overflow, } #[cfg(feature = "std")] @@ -450,6 +514,7 @@ impl fmt::Display for WeightedError { WeightedError::InvalidWeight => "A weight is invalid in distribution", WeightedError::AllWeightsZero => "All weights are zero in distribution", WeightedError::TooMany => "Too many weights (hit u32::MAX) in distribution", + WeightedError::Overflow => "The sum of weights overflowed", }) } } diff --git a/src/seq/mod.rs b/src/seq/mod.rs index bbb46fc55f..9012b21b90 100644 --- a/src/seq/mod.rs +++ b/src/seq/mod.rs @@ -40,7 +40,7 @@ use alloc::vec::Vec; #[cfg(feature = "alloc")] use crate::distributions::uniform::{SampleBorrow, SampleUniform}; #[cfg(feature = "alloc")] -use crate::distributions::WeightedError; +use crate::distributions::{Weight, WeightedError}; use crate::Rng; use self::coin_flipper::CoinFlipper; @@ -170,11 +170,7 @@ pub trait SliceRandom { R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow, - X: SampleUniform - + for<'a> ::core::ops::AddAssign<&'a X> - + ::core::cmp::PartialOrd - + Clone - + Default; + X: SampleUniform + Weight + ::core::cmp::PartialOrd; /// Biased sampling for one element (mut) /// @@ -203,11 +199,7 @@ pub trait SliceRandom { R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow, - X: SampleUniform - + for<'a> ::core::ops::AddAssign<&'a X> - + ::core::cmp::PartialOrd - + Clone - + Default; + X: SampleUniform + Weight + ::core::cmp::PartialOrd; /// Biased sampling of `amount` distinct elements /// @@ -585,11 +577,7 @@ impl SliceRandom for [T] { R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow, - X: SampleUniform - + for<'a> ::core::ops::AddAssign<&'a X> - + ::core::cmp::PartialOrd - + Clone - + Default, + X: SampleUniform + Weight + ::core::cmp::PartialOrd, { use crate::distributions::{Distribution, WeightedIndex}; let distr = WeightedIndex::new(self.iter().map(weight))?; @@ -604,11 +592,7 @@ impl SliceRandom for [T] { R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow, - X: SampleUniform - + for<'a> ::core::ops::AddAssign<&'a X> - + ::core::cmp::PartialOrd - + Clone - + Default, + X: SampleUniform + Weight + ::core::cmp::PartialOrd, { use crate::distributions::{Distribution, WeightedIndex}; let distr = WeightedIndex::new(self.iter().map(weight))?;