Skip to content

Commit

Permalink
Generate seed on calling thread if channel is empty
Browse files Browse the repository at this point in the history
  • Loading branch information
Pr0methean committed Jan 19, 2024
1 parent d1063af commit 5703a05
Showing 1 changed file with 25 additions and 38 deletions.
63 changes: 25 additions & 38 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
use aligned::{Aligned, A64};
use bytemuck::{cast_slice_mut, Pod, Zeroable};
use core::fmt::Debug;
use core::marker::PhantomData;
use core::mem::size_of;
use crossbeam_channel::{bounded, Receiver};
use crossbeam_channel::{bounded, Receiver, TryRecvError};
use log::info;
use rand::rngs::adapter::ReseedingRng;
use rand::rngs::OsRng;
Expand Down Expand Up @@ -71,27 +70,14 @@ unsafe impl<const N: usize, T: Pod> Pod for DefaultableAlignedArray<N, T> {}
/// * [SEEDS_CAPACITY] is the maximum number of `[u64; [WORDS_PER_SEED]]` instances to keep in memory for future use.
/// * [SourceType] is the type of the seed source; currently it's only used to ensure the [SharedBufferRng] implements
/// [CryptoRng] if and only if the seed source does so.
#[derive(Debug)]
pub struct SharedBufferRng<const WORDS_PER_SEED: usize, const SEEDS_CAPACITY: usize, SourceType> {
#[derive(Clone, Debug)]
pub struct SharedBufferRng<const WORDS_PER_SEED: usize, const SEEDS_CAPACITY: usize, SourceType: Rng + Clone> {
receiver: Receiver<DefaultableAlignedArray<WORDS_PER_SEED, u64>>,
// Used to determine whether to implement CryptoRng
_source: PhantomData<SourceType>,
source: SourceType,
}

// Can't derive Clone because that would only work for SourceType: Clone but we don't actually clone the source
impl<const WORDS_PER_SEED: usize, const SEEDS_CAPACITY: usize, SourceType> Clone
for SharedBufferRng<WORDS_PER_SEED, SEEDS_CAPACITY, SourceType>
{
/// Returns a new SharedBufferRng view on the same buffer.
fn clone(&self) -> Self {
SharedBufferRng {
receiver: self.receiver.clone(),
_source: self._source,
}
}
}

impl<const WORDS_PER_SEED: usize, const SEEDS_CAPACITY: usize, SourceType>
impl<const WORDS_PER_SEED: usize, const SEEDS_CAPACITY: usize, SourceType: Rng + Clone>
SharedBufferRng<WORDS_PER_SEED, SEEDS_CAPACITY, SourceType>
{
pub fn new_seeder(&self) -> BlockRng64<Self> {
Expand Down Expand Up @@ -146,7 +132,7 @@ pub fn default_rng() -> ReseedingRngStd {
impl<
const WORDS_PER_SEED: usize,
const SEEDS_CAPACITY: usize,
SourceType: Rng + Send + Debug + 'static,
SourceType: Rng + Send + Clone + Debug + 'static,
> SharedBufferRng<WORDS_PER_SEED, SEEDS_CAPACITY, SourceType>
where [(); WORDS_PER_SEED * size_of::<u64>()]:, [(); WORDS_PER_SEED * SEEDS_CAPACITY * size_of::<u64>()]:
{
Expand All @@ -155,6 +141,7 @@ impl<
pub fn new(mut source: SourceType) -> Self {
let (sender, receiver) = bounded(SEEDS_CAPACITY);
info!("Creating a SharedBufferRngInner for {:?}", source);
let source_copy = source.clone();
Builder::new().name(format!("Load seed from {:?} into shared buffer", source)).spawn(move || {
let mut aligned_seed: DefaultableAlignedArray<WORDS_PER_SEED, u64> = DefaultableAlignedArray::default();
if SEEDS_CAPACITY > 1 {
Expand All @@ -171,39 +158,39 @@ impl<
}
}
} else {
source.fill_bytes(cast_slice_mut(aligned_seed.as_mut()));
let result = sender.send(aligned_seed);
if !result.is_ok() {
info!("Detected (with seed already fetched) that a seed channel is no longer open for receiving");
return;
loop {
source.fill_bytes(cast_slice_mut(aligned_seed.as_mut()));
let result = sender.send(aligned_seed);
if !result.is_ok() {
info!("Detected (with seed already fetched) that a seed channel is no longer open for receiving");
return;
}
}
}
}).unwrap();
SharedBufferRng {
receiver: receiver.into(),
_source: PhantomData::default(),
source: source_copy,
}
}
}

impl<const WORDS_PER_SEED: usize, const SEEDS_CAPACITY: usize, SourceType> BlockRngCore
impl<const WORDS_PER_SEED: usize, const SEEDS_CAPACITY: usize, SourceType: Rng + Clone> BlockRngCore
for SharedBufferRng<WORDS_PER_SEED, SEEDS_CAPACITY, SourceType>
{
type Item = u64;
type Results = DefaultableAlignedArray<WORDS_PER_SEED, u64>;

fn generate(&mut self, results: &mut Self::Results) {
match self.receiver.recv() {
Ok(seed) => {
*results = seed;
return;
}
Err(e) => panic!("Error from recv(): {}", e),
match self.receiver.try_recv() {
Ok(seed) => *results = seed,
Err(TryRecvError::Empty) => self.source.clone().fill_bytes(cast_slice_mut(results.as_mut())),
Err(TryRecvError::Disconnected) => panic!("SharedBufferRng already closed"),
}
}
}

impl<const WORDS_PER_SEED: usize, const SEEDS_CAPACITY: usize, T: CryptoRng> CryptoRng
impl<const WORDS_PER_SEED: usize, const SEEDS_CAPACITY: usize, T: Rng + CryptoRng + Clone> CryptoRng
for SharedBufferRng<WORDS_PER_SEED, SEEDS_CAPACITY, T>
{
}
Expand All @@ -215,14 +202,14 @@ mod tests {
use rand_core::block::{BlockRng64, BlockRngCore};
use rand_core::Error;
use scc::Bag;
use std::sync::OnceLock;
use std::sync::{Arc, OnceLock};
use std::thread::spawn;

const U8_VALUES: usize = u8::MAX as usize + 1;

#[derive(Debug)]
#[derive(Clone, Debug)]
struct ByteValuesInOrderRng {
words_written: AtomicUsize,
words_written: Arc<AtomicUsize>,
}

impl BlockRngCore for ByteValuesInOrderRng {
Expand Down Expand Up @@ -261,7 +248,7 @@ mod tests {
const ITERS_PER_THREAD: usize = 1;
let seeder: SharedBufferRng<8, 4, _> =
SharedBufferRng::new(BlockRng64::new(ByteValuesInOrderRng {
words_written: AtomicUsize::new(0),
words_written: AtomicUsize::new(0).into(),
}));
let ths: Vec<_> = (0..THREADS)
.map(|_| {
Expand Down

0 comments on commit 5703a05

Please sign in to comment.