Skip to content

Commit

Permalink
UniformUsize: inline sub-impls
Browse files Browse the repository at this point in the history
  • Loading branch information
dhardy committed Sep 6, 2024
1 parent d9debb7 commit 2144ae5
Showing 1 changed file with 70 additions and 29 deletions.
99 changes: 70 additions & 29 deletions src/distr/uniform_int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,21 +395,19 @@ uniform_simd_int_impl! { (u8, i8), (u16, i16), (u32, i32), (u64, i64) }
/// this implementation will use 32-bit sampling when possible.
#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct UniformUsize(UniformUsizeImpl);
// All fields are computed once at construction so that sampling itself needs no division.
pub struct UniformUsize {
/// Inclusive lower bound; every sample is `low` plus a sampled offset (wrapping add).
low: usize,
/// Number of values in the range, computed as `high - low + 1` with wrapping arithmetic
/// (in 32-bit arithmetic when the 32-bit path is used). `0` means the range wrapped,
/// i.e. it spans every value of the sampling width, and raw random output is returned.
range: usize,
/// Rejection threshold: `range.wrapping_neg() % range`, or `0` when `range` is `0`.
/// A draw is accepted only when the low half of its product with `range` is `>= thresh`.
thresh: usize,
/// `true` when `high` exceeds `u32::MAX`, which selects the 64-bit sampling path.
#[cfg(target_pointer_width = "64")]
mode64: bool,
}

// Registers `UniformUsize` as the sampler type for `usize`.
// Only compiled for targets with 32-bit or 64-bit pointers.
#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
impl SampleUniform for usize {
type Sampler = UniformUsize;
}

/// Width-specific sampler wrapped by the tuple-struct form of `UniformUsize`.
///
/// The constructors pick `U64` only when `high > u32::MAX` on a 64-bit target;
/// every other range is handled by `U32`.
#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UniformUsizeImpl {
/// Bounds are cast to `u32` and sampled with the 32-bit integer sampler.
U32(UniformInt<u32>),
/// Bounds are cast to `u64`; this variant exists only on 64-bit targets.
#[cfg(target_pointer_width = "64")]
U64(UniformInt<u64>),
}

#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
impl UniformSampler for UniformUsize {
type X = usize;
Expand All @@ -427,13 +425,7 @@ impl UniformSampler for UniformUsize {
return Err(Error::EmptyRange);
}

#[cfg(target_pointer_width = "64")]
if high > (u32::MAX as usize) {
return UniformInt::new(low as u64, high as u64)
.map(|ui| UniformUsize(UniformUsizeImpl::U64(ui)));
}

UniformInt::new(low as u32, high as u32).map(|ui| UniformUsize(UniformUsizeImpl::U32(ui)))
UniformSampler::new_inclusive(low, high - 1)
}

#[inline] // if the range is constant, this helps LLVM to do the
Expand All @@ -450,21 +442,72 @@ impl UniformSampler for UniformUsize {
}

#[cfg(target_pointer_width = "64")]
if high > (u32::MAX as usize) {
return UniformInt::new_inclusive(low as u64, high as u64)
.map(|ui| UniformUsize(UniformUsizeImpl::U64(ui)));
let mode64 = high > (u32::MAX as usize);
#[cfg(target_pointer_width = "32")]
let mode64 = false;

let (range, thresh);
if cfg!(target_pointer_width = "64") && !mode64 {
let range32 = (high as u32).wrapping_sub(low as u32).wrapping_add(1);
range = range32 as usize;
thresh = if range32 > 0 {
(range32.wrapping_neg() % range32) as usize
} else {
0
};
} else {
range = high.wrapping_sub(low).wrapping_add(1);
thresh = if range > 0 {
range.wrapping_neg() % range
} else {
0
};
}

UniformInt::new_inclusive(low as u32, high as u32)
.map(|ui| UniformUsize(UniformUsizeImpl::U32(ui)))
Ok(UniformUsize {
low,
range,
thresh,
#[cfg(target_pointer_width = "64")]
mode64,
})
}

// Draws one `usize` from the range described by `self`, using `rng` as the randomness source.
//
// NOTE(review): this span comes from a diff view. The `match self.0 { ... }` arms appear to be
// the pre-change body (delegating to `UniformInt`), and the `mode32` code after them its
// field-based replacement -- confirm against the actual file before relying on either.
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
match self.0 {
UniformUsizeImpl::U32(uu) => uu.sample(rng) as usize,
#[cfg(target_pointer_width = "64")]
UniformUsizeImpl::U64(uu) => uu.sample(rng) as usize,
// 32-bit targets always use the 32-bit path; 64-bit targets use it unless `mode64` is set.
#[cfg(target_pointer_width = "32")]
let mode32 = true;
#[cfg(target_pointer_width = "64")]
let mode32 = !self.mode64;

if mode32 {
let range = self.range as u32;
// `range == 0` encodes a wrapped (full-width) range: any raw `u32` is a valid sample.
if range == 0 {
return rng.random::<u32>() as usize;
}

let thresh = self.thresh as u32;
// Rejection loop: redraw until the low half of the product reaches `thresh`.
// `wmul` is presumably a widening multiply yielding (high, low) halves -- TODO confirm.
let hi = loop {
let (hi, lo) = rng.random::<u32>().wmul(range);
if lo >= thresh {
break hi;
}
};
// The accepted high half is the offset from `low`.
self.low.wrapping_add(hi as usize)
} else {
// Same scheme as the 32-bit path, carried out on `u64` draws.
let range = self.range as u64;
if range == 0 {
return rng.random::<u64>() as usize;
}

let thresh = self.thresh as u64;
let hi = loop {
let (hi, lo) = rng.random::<u64>().wmul(range);
if lo >= thresh {
break hi;
}
};
self.low.wrapping_add(hi as usize)
}
}

Expand All @@ -484,8 +527,7 @@ impl UniformSampler for UniformUsize {
return Err(Error::EmptyRange);
}

#[cfg(target_pointer_width = "64")]
if high > (u32::MAX as usize) {
if cfg!(target_pointer_width = "64") && high > (u32::MAX as usize) {
return UniformInt::<u64>::sample_single(low as u64, high as u64, rng)
.map(|x| x as usize);
}
Expand All @@ -509,8 +551,7 @@ impl UniformSampler for UniformUsize {
return Err(Error::EmptyRange);
}

#[cfg(target_pointer_width = "64")]
if high > (u32::MAX as usize) {
if cfg!(target_pointer_width = "64") && high > (u32::MAX as usize) {
return UniformInt::<u64>::sample_single_inclusive(low as u64, high as u64, rng)
.map(|x| x as usize);
}
Expand Down

0 comments on commit 2144ae5

Please sign in to comment.