diff --git a/Cargo.toml b/Cargo.toml index 60fd044..eae351f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,8 @@ license = "MIT OR Apache-2.0" repository = "https://github.com/plabayo/tokio-graceful" [dependencies] +pin-project-lite = "0.2.13" +slab = "0.4.9" tokio = { version = "1", features = ["rt", "signal", "sync", "macros", "time"] } tracing = "0.1" diff --git a/src/guard.rs b/src/guard.rs index 72ec84e..a118c00 100644 --- a/src/guard.rs +++ b/src/guard.rs @@ -3,31 +3,25 @@ use std::{ sync::{atomic::AtomicUsize, Arc}, }; -use tokio::{sync::Notify, task::JoinHandle}; +use tokio::task::JoinHandle; + +use crate::trigger::{Receiver, Sender}; #[derive(Debug)] pub struct ShutdownGuard(WeakShutdownGuard); #[derive(Debug, Clone)] pub struct WeakShutdownGuard { - pub(crate) notify_signal: Arc, - pub(crate) notify_zero: Arc, + pub(crate) trigger_rx: Receiver, + pub(crate) zero_tx: Sender, pub(crate) ref_count: Arc, } impl ShutdownGuard { - pub fn new( - notify_signal: Arc, - notify_zero: Arc, - ref_count: Arc, - ) -> Self { + pub fn new(trigger_rx: Receiver, zero_tx: Sender, ref_count: Arc) -> Self { let value = ref_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); tracing::trace!("new shutdown guard: ref_count+1: {}", value + 1); - Self(WeakShutdownGuard::new( - notify_signal, - notify_zero, - ref_count, - )) + Self(WeakShutdownGuard::new(trigger_rx, zero_tx, ref_count)) } #[inline] @@ -113,27 +107,23 @@ impl Drop for ShutdownGuard { .fetch_sub(1, std::sync::atomic::Ordering::SeqCst); tracing::trace!("drop shutdown guard: ref_count-1: {}", cnt - 1); if cnt == 1 { - self.0.notify_zero.notify_one(); + self.0.zero_tx.trigger(); } } } impl WeakShutdownGuard { - pub fn new( - notify_signal: Arc, - notify_zero: Arc, - ref_count: Arc, - ) -> Self { + pub fn new(trigger_rx: Receiver, zero_tx: Sender, ref_count: Arc) -> Self { Self { - notify_signal, - notify_zero, + trigger_rx, + zero_tx, ref_count, } } #[inline] pub async fn cancelled(&self) { - self.notify_signal.notified().await; + self.trigger_rx.clone().await; } #[inline] diff --git a/src/lib.rs b/src/lib.rs index 272f8de..1406e62 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,3 +5,5 @@ pub use guard::{ShutdownGuard, WeakShutdownGuard}; mod shutdown; pub use shutdown::Shutdown; + +pub(crate) mod trigger; diff --git a/src/shutdown.rs b/src/shutdown.rs index 6ba0973..ecb058a 100644 --- a/src/shutdown.rs +++ b/src/shutdown.rs @@ -1,30 +1,28 @@ use std::{future::Future, sync::Arc, time}; -use tokio::sync::Notify; - -use crate::{ShutdownGuard, WeakShutdownGuard}; +use crate::{ + trigger::{trigger, Receiver}, + ShutdownGuard, WeakShutdownGuard, +}; pub struct Shutdown { guard: ShutdownGuard, - notify_zero: Arc, + zero_rx: Receiver, } impl Shutdown { pub fn new(signal: impl Future + Send + 'static) -> Self { - let notify_signal = Arc::new(Notify::new()); - let notify_zero = Arc::new(Notify::new()); - let guard = ShutdownGuard::new( - notify_signal.clone(), - notify_zero.clone(), - Arc::new(0usize.into()), - ); + let (signal_tx, signal_rx) = trigger(); + let (zero_tx, zero_rx) = trigger(); + + let guard = ShutdownGuard::new(signal_rx, zero_tx, Arc::new(0usize.into())); tokio::spawn(async move { signal.await; - notify_signal.notify_waiters(); + signal_tx.trigger(); }); - Self { guard, notify_zero } + Self { guard, zero_rx } } #[inline] @@ -57,11 +55,10 @@ impl Shutdown { } pub async fn shutdown(self) { - let zero_notified = self.notify_zero.notified(); tracing::trace!("::shutdown: waiting for signal to trigger (read: to be cancelled)"); self.guard.downgrade().cancelled().await; tracing::trace!("::shutdown: waiting for all guards to drop"); - zero_notified.await; + self.zero_rx.await; tracing::trace!("::shutdown: ready"); } @@ -69,7 +66,6 @@ impl Shutdown { self, limit: time::Duration, ) -> Result { - let zero_notified = self.notify_zero.notified(); tracing::trace!("::shutdown: waiting for signal to trigger (read: to be cancelled)"); self.guard.downgrade().cancelled().await; tracing::trace!( @@ -79,7 +75,7 @@ impl Shutdown { let start: time::Instant = time::Instant::now(); tokio::select! { _ = tokio::time::sleep(limit) => { Err(TimeoutError(limit)) } - _ = zero_notified => { Ok(start.elapsed()) } + _ = self.zero_rx => { Ok(start.elapsed()) } } } } diff --git a/src/trigger.rs b/src/trigger.rs new file mode 100644 index 0000000..b62998c --- /dev/null +++ b/src/trigger.rs @@ -0,0 +1,155 @@ +use std::{ + future::Future, + pin::Pin, + sync::{atomic::AtomicBool, Arc, Mutex}, + task::{Context, Poll, Waker}, +}; + +use pin_project_lite::pin_project; +use slab::Slab; + +type WakerList = Arc>>; +type TriggerState = Arc; + +#[derive(Debug, Clone)] +struct Subscriber { + wakers: WakerList, + state: TriggerState, +} + +#[derive(Debug)] +enum SubscriberState { + Waiting(usize), + Triggered, +} + +impl Subscriber { + pub fn state(&self, cx: &mut Context, key: Option) -> SubscriberState { + if self.state.load(std::sync::atomic::Ordering::SeqCst) { + return SubscriberState::Triggered; + } + + let mut wakers = self.wakers.lock().unwrap(); + if self.state.load(std::sync::atomic::Ordering::SeqCst) { + return SubscriberState::Triggered; + } + + let waker = cx.waker().clone(); + + SubscriberState::Waiting(if let Some(key) = key { + tracing::trace!("trigger::Subscriber: updating waker for key: {}", key); + *wakers.get_mut(key).unwrap() = waker; + key + } else { + let key = wakers.insert(waker); + tracing::trace!("trigger::Subscriber: insert waker for key: {}", key); + key + }) + } +} + +#[derive(Debug)] +enum ReceiverState { + Open { sub: Subscriber, key: Option }, + Closed, +} + +impl Clone for ReceiverState { + fn clone(&self) -> Self { + match self { + ReceiverState::Open { sub, .. } => ReceiverState::Open { + sub: sub.clone(), + key: None, + }, + ReceiverState::Closed => ReceiverState::Closed, + } + } +} + +impl Drop for ReceiverState { + fn drop(&mut self) { + if let ReceiverState::Open { sub, key } = self { + if let Some(key) = key.take() { + let mut wakers = sub.wakers.lock().unwrap(); + tracing::trace!( + "trigger::ReceiverState::Drop: remove waker for key: {}", + key + ); + wakers.remove(key); + } + } + } +} + +pin_project! { + #[derive(Debug, Clone)] + pub struct Receiver { + state: ReceiverState, + } +} + +impl Receiver { + fn new(wakers: WakerList, state: TriggerState) -> Self { + Self { + state: ReceiverState::Open { + sub: Subscriber { wakers, state }, + key: None, + }, + } + } +} + +impl Future for Receiver { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + let this = self.project(); + match this.state { + ReceiverState::Open { sub, key } => { + let state = sub.state(cx, *key); + match state { + SubscriberState::Waiting(new_key) => { + *key = Some(new_key); + std::task::Poll::Pending + } + SubscriberState::Triggered => { + *this.state = ReceiverState::Closed; + std::task::Poll::Ready(()) + } + } + } + ReceiverState::Closed => std::task::Poll::Ready(()), + } + } +} + +#[derive(Debug, Clone)] +pub struct Sender { + wakers: WakerList, + state: TriggerState, +} + +impl Sender { + fn new(wakers: WakerList, state: TriggerState) -> Self { + Self { wakers, state } + } + + pub fn trigger(&self) { + let wakers = self.wakers.lock().unwrap(); + self.state.store(true, std::sync::atomic::Ordering::SeqCst); + for (key, waker) in wakers.iter() { + tracing::trace!("trigger::Sender: wake up waker with key: {}", key); + waker.wake_by_ref(); + } + } +} + +pub fn trigger() -> (Sender, Receiver) { + let wakers = Arc::new(Mutex::new(Slab::new())); + let state = Arc::new(AtomicBool::new(false)); + + let sender = Sender::new(wakers.clone(), state.clone()); + let receiver = Receiver::new(wakers, state); + + (sender, receiver) +}