Skip to content

Commit

Permalink
Bug fix: do the right thing when a thread-local is shared with clones…
Browse files Browse the repository at this point in the history
… on the same thread
  • Loading branch information
Pr0methean committed Jan 20, 2024
1 parent 2e399d3 commit 7f77e1d
Showing 1 changed file with 25 additions and 26 deletions.
51 changes: 25 additions & 26 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
use aligned::{Aligned, A64};
use bytemuck::{cast_slice_mut, Pod, Zeroable};
use core::fmt::Debug;
use core::mem::{MaybeUninit, size_of};
use crossbeam_channel::{bounded, never, Receiver, Sender, TryRecvError};
use core::mem::{MaybeUninit, replace, size_of};
use crossbeam_channel::{bounded, Receiver, Sender, TryRecvError};
use log::{error, info};
use rand::rngs::adapter::ReseedingRng;
use rand::rngs::OsRng;
Expand Down Expand Up @@ -69,6 +69,19 @@ unsafe impl<const N: usize, T: Zeroable> Zeroable for DefaultableAlignedArray<N,

unsafe impl<const N: usize, T: Pod> Pod for DefaultableAlignedArray<N, T> {}

struct RecyclableVec<T, U: From<T>> {
contents: Vec<T>,
recycler: Sender<U>
}

impl <T, U: From<T>> Drop for RecyclableVec<T, U> {
fn drop(&mut self) {
let contents = replace(&mut self.contents, vec![]);
let _ = contents.into_iter().map_while(
|seed| self.recycler.try_send(seed.into()).ok()).last();
}
}

/// An RNG that reads from a shared buffer, to which only one thread per buffer will read from a seed source. It will
/// share the buffer with all of its clones. Once this and all clones have been dropped, the source-reading thread will
/// detect this using a [std::sync::Weak] reference and terminate. Since this RNG is used to implement [BlockRngCore]
Expand All @@ -86,24 +99,7 @@ pub struct SharedBufferRng<const WORDS_PER_SEED: usize, const SEEDS_CAPACITY: us
receiver: Receiver<DefaultableAlignedArray<WORDS_PER_SEED, u64>>,
// Used to determine whether to implement CryptoRng
source: SourceType,
thread_local_buffer: Arc<ThreadLocal<Vec<[u64; WORDS_PER_SEED]>>>
}

impl <const WORDS_PER_SEED: usize, const SEEDS_CAPACITY: usize, SourceType: Rng + Clone> Drop
for SharedBufferRng<WORDS_PER_SEED, SEEDS_CAPACITY, SourceType> {
fn drop(&mut self) {
// Drop own reference to the receiver, so the channel will close if no other refs exist
self.receiver = never();

// Recycle thread-local buffer into shared buffer, but only until it's full
match self.thread_local_buffer.remove() {
None => {}
Some(buffer) => {
let _ = buffer.into_iter().map_while(
|seed| self.sender.try_send(seed.into()).ok()).last();
}
}
}
thread_local_buffer: Arc<ThreadLocal<RecyclableVec<[u64; WORDS_PER_SEED], DefaultableAlignedArray<WORDS_PER_SEED, u64>>>>
}

impl<const WORDS_PER_SEED: usize, const SEEDS_CAPACITY: usize, SourceType: Rng + Clone>
Expand Down Expand Up @@ -218,18 +214,18 @@ impl<const WORDS_PER_SEED: usize, const SEEDS_CAPACITY: usize, SourceType: Rng +
self.thread_local_buffer.entry(|entry| match entry {
Entry::Occupied(local_buffer) => {
let local_buffer = local_buffer.into_mut();
if !local_buffer.is_empty() {
*(results.as_mut()) = local_buffer.pop().unwrap();
if !local_buffer.contents.is_empty() {
*(results.as_mut()) = local_buffer.contents.pop().unwrap();
} else {
match self.receiver.try_recv() {
Ok(seed) => *results = seed,
Err(TryRecvError::Empty) => {
unsafe {
self.source.clone().fill_bytes(
MaybeUninit::slice_assume_init_mut(MaybeUninit::slice_as_bytes_mut(local_buffer.spare_capacity_mut())));
local_buffer.set_len(SEEDS_CAPACITY);
MaybeUninit::slice_assume_init_mut(MaybeUninit::slice_as_bytes_mut(local_buffer.contents.spare_capacity_mut())));
local_buffer.contents.set_len(SEEDS_CAPACITY);
}
*(results.as_mut()) = local_buffer.pop().unwrap();
*(results.as_mut()) = local_buffer.contents.pop().unwrap();
},
Err(TryRecvError::Disconnected) => panic!("SharedBufferRng already closed"),
}
Expand All @@ -246,7 +242,10 @@ impl<const WORDS_PER_SEED: usize, const SEEDS_CAPACITY: usize, SourceType: Rng +
local_buffer.set_len(SEEDS_CAPACITY);
}
*(results.as_mut()) = local_buffer.pop().unwrap();
vacancy.insert(local_buffer);
vacancy.insert(RecyclableVec {
contents: local_buffer,
recycler: self.sender.clone()
});
},
Err(TryRecvError::Disconnected) => panic!("SharedBufferRng already closed"),
}
Expand Down

0 comments on commit 7f77e1d

Please sign in to comment.