Skip to content

Commit

Permalink
[feat] add shrink_to_fit() to Sender<T> and Receiver<T>
Browse files Browse the repository at this point in the history
  • Loading branch information
rakbladsvalsen committed May 3, 2024
1 parent fcf3849 commit 96e7399
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 80 deletions.
230 changes: 151 additions & 79 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
#![deny(missing_docs)]

#[cfg(feature = "select")]
pub mod select;
#[cfg(feature = "async")]
pub mod r#async;
#[cfg(feature = "select")]
pub mod select;

mod signal;

Expand All @@ -40,16 +40,19 @@ pub use select::Selector;

use std::{
collections::VecDeque,
sync::{Arc, atomic::{AtomicUsize, AtomicBool, Ordering}, Weak},
time::{Duration, Instant},
fmt,
marker::PhantomData,
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc, Weak,
},
thread,
fmt,
time::{Duration, Instant},
};

use crate::signal::{Signal, SyncSignal};
#[cfg(feature = "spin")]
use spin1::{Mutex as Spinlock, MutexGuard as SpinlockGuard};
use crate::signal::{Signal, SyncSignal};

/// An error that may be emitted when attempting to send a value into a channel on a sender when
/// all receivers are dropped.
Expand All @@ -58,7 +61,9 @@ pub struct SendError<T>(pub T);

impl<T> SendError<T> {
/// Consume the error, yielding the message that failed to send.
pub fn into_inner(self) -> T { self.0 }
pub fn into_inner(self) -> T {
self.0
}
}

impl<T> fmt::Debug for SendError<T> {
Expand Down Expand Up @@ -423,7 +428,6 @@ type ChanLock<T> = Spinlock<T>;
#[cfg(not(feature = "spin"))]
type ChanLock<T> = Mutex<T>;


type SignalVec<T> = VecDeque<Arc<Hook<T, dyn signal::Signal>>>;
struct Chan<T> {
sending: Option<(usize, SignalVec<T>)>,
Expand Down Expand Up @@ -476,6 +480,15 @@ impl<T> Shared<T> {
}
}

fn shrink_to_fit(&self) {
let mut lock = wait_lock(&self.chan);
lock.queue.shrink_to_fit();
}

fn queue_capacity(&self) -> usize {
wait_lock(&self.chan).queue.capacity()
}

fn send<S: Signal, R: From<Result<(), TrySendTimeoutError<T>>>>(
&self,
msg: T,
Expand Down Expand Up @@ -513,20 +526,26 @@ impl<T> Shared<T> {
drop(chan);
break;
}
},
}
Some((None, signal)) => {
drop(chan);
signal.fire();
break; // Was sync, so it has acquired the message
},
}
}
}

Ok(()).into()
} else if chan.sending.as_ref().map(|(cap, _)| chan.queue.len() < *cap).unwrap_or(true) {
} else if chan
.sending
.as_ref()
.map(|(cap, _)| chan.queue.len() < *cap)
.unwrap_or(true)
{
chan.queue.push_back(msg);
Ok(()).into()
} else if should_block { // Only bounded from here on
} else if should_block {
// Only bounded from here on
let hook = make_signal(msg);
chan.sending.as_mut().unwrap().1.push_back(hook.clone());
drop(chan);
Expand All @@ -550,29 +569,37 @@ impl<T> Shared<T> {
// make_signal
|msg| Hook::slot(Some(msg), SyncSignal::default()),
// do_block
|hook| if let Some(deadline) = block.unwrap() {
hook.wait_deadline_send(&self.disconnected, deadline)
.or_else(|timed_out| {
if timed_out { // Remove our signal
let hook: Arc<Hook<T, dyn signal::Signal>> = hook.clone();
wait_lock(&self.chan).sending
.as_mut()
.unwrap().1
.retain(|s| s.signal().as_ptr() != hook.signal().as_ptr());
}
hook.try_take().map(|msg| if self.is_disconnected() {
Err(TrySendTimeoutError::Disconnected(msg))
} else {
Err(TrySendTimeoutError::Timeout(msg))
|hook| {
if let Some(deadline) = block.unwrap() {
hook.wait_deadline_send(&self.disconnected, deadline)
.or_else(|timed_out| {
if timed_out {
// Remove our signal
let hook: Arc<Hook<T, dyn signal::Signal>> = hook.clone();
wait_lock(&self.chan)
.sending
.as_mut()
.unwrap()
.1
.retain(|s| s.signal().as_ptr() != hook.signal().as_ptr());
}
hook.try_take()
.map(|msg| {
if self.is_disconnected() {
Err(TrySendTimeoutError::Disconnected(msg))
} else {
Err(TrySendTimeoutError::Timeout(msg))
}
})
.unwrap_or(Ok(()))
})
.unwrap_or(Ok(()))
})
} else {
hook.wait_send(&self.disconnected);
} else {
hook.wait_send(&self.disconnected);

match hook.try_take() {
Some(msg) => Err(TrySendTimeoutError::Disconnected(msg)),
None => Ok(()),
match hook.try_take() {
Some(msg) => Err(TrySendTimeoutError::Disconnected(msg)),
None => Ok(()),
}
}
},
)
Expand Down Expand Up @@ -612,32 +639,36 @@ impl<T> Shared<T> {
// make_signal
|| Hook::slot(None, SyncSignal::default()),
// do_block
|hook| if let Some(deadline) = block.unwrap() {
hook.wait_deadline_recv(&self.disconnected, deadline)
.or_else(|timed_out| {
if timed_out { // Remove our signal
let hook: Arc<Hook<T, dyn Signal>> = hook.clone();
wait_lock(&self.chan).waiting
.retain(|s| s.signal().as_ptr() != hook.signal().as_ptr());
}
match hook.try_take() {
Some(msg) => Ok(msg),
None => {
let disconnected = self.is_disconnected(); // Check disconnect *before* msg
if let Some(msg) = wait_lock(&self.chan).queue.pop_front() {
Ok(msg)
} else if disconnected {
Err(TryRecvTimeoutError::Disconnected)
} else {
Err(TryRecvTimeoutError::Timeout)
|hook| {
if let Some(deadline) = block.unwrap() {
hook.wait_deadline_recv(&self.disconnected, deadline)
.or_else(|timed_out| {
if timed_out {
// Remove our signal
let hook: Arc<Hook<T, dyn Signal>> = hook.clone();
wait_lock(&self.chan)
.waiting
.retain(|s| s.signal().as_ptr() != hook.signal().as_ptr());
}
match hook.try_take() {
Some(msg) => Ok(msg),
None => {
let disconnected = self.is_disconnected(); // Check disconnect *before* msg
if let Some(msg) = wait_lock(&self.chan).queue.pop_front() {
Ok(msg)
} else if disconnected {
Err(TryRecvTimeoutError::Disconnected)
} else {
Err(TryRecvTimeoutError::Timeout)
}
}
},
}
})
} else {
hook.wait_recv(&self.disconnected)
.or_else(|| wait_lock(&self.chan).queue.pop_front())
.ok_or(TryRecvTimeoutError::Disconnected)
}
})
} else {
hook.wait_recv(&self.disconnected)
.or_else(|| wait_lock(&self.chan).queue.pop_front())
.ok_or(TryRecvTimeoutError::Disconnected)
}
},
)
}
Expand Down Expand Up @@ -668,7 +699,9 @@ impl<T> Shared<T> {
}

fn is_full(&self) -> bool {
self.capacity().map(|cap| cap == self.len()).unwrap_or(false)
self.capacity()
.map(|cap| cap == self.len())
.unwrap_or(false)
}

fn len(&self) -> usize {
Expand Down Expand Up @@ -712,22 +745,26 @@ impl<T> Sender<T> {
/// or all receivers have been dropped. If the channel is unbounded, this method will not
/// block.
pub fn send(&self, msg: T) -> Result<(), SendError<T>> {
self.shared.send_sync(msg, Some(None)).map_err(|err| match err {
TrySendTimeoutError::Disconnected(msg) => SendError(msg),
_ => unreachable!(),
})
self.shared
.send_sync(msg, Some(None))
.map_err(|err| match err {
TrySendTimeoutError::Disconnected(msg) => SendError(msg),
_ => unreachable!(),
})
}

/// Send a value into the channel, returning an error if all receivers have been dropped
/// or the deadline has passed. If the channel is bounded and is full, this method will
/// block until space is available, the deadline is reached, or all receivers have been
/// dropped.
pub fn send_deadline(&self, msg: T, deadline: Instant) -> Result<(), SendTimeoutError<T>> {
self.shared.send_sync(msg, Some(Some(deadline))).map_err(|err| match err {
TrySendTimeoutError::Disconnected(msg) => SendTimeoutError::Disconnected(msg),
TrySendTimeoutError::Timeout(msg) => SendTimeoutError::Timeout(msg),
_ => unreachable!(),
})
self.shared
.send_sync(msg, Some(Some(deadline)))
.map_err(|err| match err {
TrySendTimeoutError::Disconnected(msg) => SendTimeoutError::Disconnected(msg),
TrySendTimeoutError::Timeout(msg) => SendTimeoutError::Timeout(msg),
_ => unreachable!(),
})
}

/// Send a value into the channel, returning an error if all receivers have been dropped
Expand Down Expand Up @@ -770,6 +807,16 @@ impl<T> Sender<T> {
self.shared.sender_count()
}

/// Discards excess capacity in the internal queue.
pub fn shrink_to_fit(&self) {
self.shared.shrink_to_fit();
}

/// Returns the number of elements the internal queue can hold without reallocating.
pub fn queue_capacity(&self) -> usize {
self.shared.queue_capacity()
}

/// Get the number of receivers that currently exist.
///
/// Note that this method makes no guarantees that a subsequent send will succeed; it's
Expand Down Expand Up @@ -800,7 +847,9 @@ impl<T> Clone for Sender<T> {
/// contents will only be cleaned up when all senders and the receiver have been dropped.
fn clone(&self) -> Self {
self.shared.sender_count.fetch_add(1, Ordering::Relaxed);
Self { shared: self.shared.clone() }
Self {
shared: self.shared.clone(),
}
}
}

Expand Down Expand Up @@ -861,7 +910,9 @@ impl<T> WeakSender<T> {
impl<T> Clone for WeakSender<T> {
/// Clones this [`WeakSender`].
fn clone(&self) -> Self {
Self { shared: self.shared.clone() }
Self {
shared: self.shared.clone(),
}
}
}

Expand Down Expand Up @@ -897,11 +948,13 @@ impl<T> Receiver<T> {
/// Wait for an incoming value from the channel associated with this receiver, returning an
/// error if all senders have been dropped or the deadline has passed.
pub fn recv_deadline(&self, deadline: Instant) -> Result<T, RecvTimeoutError> {
self.shared.recv_sync(Some(Some(deadline))).map_err(|err| match err {
TryRecvTimeoutError::Disconnected => RecvTimeoutError::Disconnected,
TryRecvTimeoutError::Timeout => RecvTimeoutError::Timeout,
_ => unreachable!(),
})
self.shared
.recv_sync(Some(Some(deadline)))
.map_err(|err| match err {
TryRecvTimeoutError::Disconnected => RecvTimeoutError::Disconnected,
TryRecvTimeoutError::Timeout => RecvTimeoutError::Timeout,
_ => unreachable!(),
})
}

/// Wait for an incoming value from the channel associated with this receiver, returning an
Expand All @@ -910,6 +963,16 @@ impl<T> Receiver<T> {
self.recv_deadline(Instant::now().checked_add(dur).unwrap())
}

/// Discard excess capacity in the internal queue.
pub fn shrink_to_fit(&self) {
self.shared.shrink_to_fit();
}

/// Returns the number of elements the internal queue can hold without reallocating.
pub fn queue_capacity(&self) -> usize {
self.shared.queue_capacity()
}

/// Create a blocking iterator over the values received on the channel that finishes iteration
/// when all senders have been dropped.
///
Expand All @@ -932,7 +995,10 @@ impl<T> Receiver<T> {
chan.pull_pending(false);
let queue = std::mem::take(&mut chan.queue);

Drain { queue, _phantom: PhantomData }
Drain {
queue,
_phantom: PhantomData,
}
}

/// Returns true if all senders for this channel have been dropped.
Expand Down Expand Up @@ -988,7 +1054,9 @@ impl<T> Clone for Receiver<T> {
/// implementing work stealing for concurrent programs.
fn clone(&self) -> Self {
self.shared.receiver_count.fetch_add(1, Ordering::Relaxed);
Self { shared: self.shared.clone() }
Self {
shared: self.shared.clone(),
}
}
}

Expand Down Expand Up @@ -1108,7 +1176,9 @@ impl<T> Iterator for IntoIter<T> {
pub fn unbounded<T>() -> (Sender<T>, Receiver<T>) {
let shared = Arc::new(Shared::new(None));
(
Sender { shared: shared.clone() },
Sender {
shared: shared.clone(),
},
Receiver { shared },
)
}
Expand Down Expand Up @@ -1143,7 +1213,9 @@ pub fn unbounded<T>() -> (Sender<T>, Receiver<T>) {
pub fn bounded<T>(cap: usize) -> (Sender<T>, Receiver<T>) {
let shared = Arc::new(Shared::new(Some(cap)));
(
Sender { shared: shared.clone() },
Sender {
shared: shared.clone(),
},
Receiver { shared },
)
}
Loading

0 comments on commit 96e7399

Please sign in to comment.