Skip to content

Commit

Permalink
implement custom trigger logic
Browse files Browse the repository at this point in the history
  • Loading branch information
glendc committed Sep 3, 2023
1 parent 961e4d7 commit 42fba35
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 39 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ license = "MIT OR Apache-2.0"
repository = "https://github.com/plabayo/tokio-graceful"

[dependencies]
pin-project-lite = "0.2.13"
slab = "0.4.9"
tokio = { version = "1", features = ["rt", "signal", "sync", "macros", "time"] }
tracing = "0.1"

Expand Down
34 changes: 12 additions & 22 deletions src/guard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,25 @@ use std::{
sync::{atomic::AtomicUsize, Arc},
};

use tokio::{sync::Notify, task::JoinHandle};
use tokio::task::JoinHandle;

use crate::trigger::{Receiver, Sender};

#[derive(Debug)]
pub struct ShutdownGuard(WeakShutdownGuard);

#[derive(Debug, Clone)]
pub struct WeakShutdownGuard {
pub(crate) notify_signal: Arc<Notify>,
pub(crate) notify_zero: Arc<Notify>,
pub(crate) trigger_rx: Receiver,
pub(crate) zero_tx: Sender,
pub(crate) ref_count: Arc<AtomicUsize>,
}

impl ShutdownGuard {
pub fn new(
notify_signal: Arc<Notify>,
notify_zero: Arc<Notify>,
ref_count: Arc<AtomicUsize>,
) -> Self {
pub fn new(trigger_rx: Receiver, zero_tx: Sender, ref_count: Arc<AtomicUsize>) -> Self {
let value = ref_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
tracing::trace!("new shutdown guard: ref_count+1: {}", value + 1);
Self(WeakShutdownGuard::new(
notify_signal,
notify_zero,
ref_count,
))
Self(WeakShutdownGuard::new(trigger_rx, zero_tx, ref_count))
}

#[inline]
Expand Down Expand Up @@ -113,27 +107,23 @@ impl Drop for ShutdownGuard {
.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
tracing::trace!("drop shutdown guard: ref_count-1: {}", cnt - 1);
if cnt == 1 {
self.0.notify_zero.notify_one();
self.0.zero_tx.trigger();
}
}
}

impl WeakShutdownGuard {
pub fn new(
notify_signal: Arc<Notify>,
notify_zero: Arc<Notify>,
ref_count: Arc<AtomicUsize>,
) -> Self {
pub fn new(trigger_rx: Receiver, zero_tx: Sender, ref_count: Arc<AtomicUsize>) -> Self {
Self {
notify_signal,
notify_zero,
trigger_rx,
zero_tx,
ref_count,
}
}

#[inline]
pub async fn cancelled(&self) {
self.notify_signal.notified().await;
self.trigger_rx.clone().await;
}

#[inline]
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ pub use guard::{ShutdownGuard, WeakShutdownGuard};

mod shutdown;
pub use shutdown::Shutdown;

pub(crate) mod trigger;
30 changes: 13 additions & 17 deletions src/shutdown.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,28 @@
use std::{future::Future, sync::Arc, time};

use tokio::sync::Notify;

use crate::{ShutdownGuard, WeakShutdownGuard};
use crate::{
trigger::{trigger, Receiver},
ShutdownGuard, WeakShutdownGuard,
};

pub struct Shutdown {
guard: ShutdownGuard,
notify_zero: Arc<Notify>,
zero_rx: Receiver,
}

impl Shutdown {
pub fn new(signal: impl Future<Output = ()> + Send + 'static) -> Self {
let notify_signal = Arc::new(Notify::new());
let notify_zero = Arc::new(Notify::new());
let guard = ShutdownGuard::new(
notify_signal.clone(),
notify_zero.clone(),
Arc::new(0usize.into()),
);
let (signal_tx, signal_rx) = trigger();
let (zero_tx, zero_rx) = trigger();

let guard = ShutdownGuard::new(signal_rx, zero_tx, Arc::new(0usize.into()));

tokio::spawn(async move {
signal.await;
notify_signal.notify_waiters();
signal_tx.trigger();
});

Self { guard, notify_zero }
Self { guard, zero_rx }
}

#[inline]
Expand Down Expand Up @@ -57,19 +55,17 @@ impl Shutdown {
}

pub async fn shutdown(self) {
let zero_notified = self.notify_zero.notified();
tracing::trace!("::shutdown: waiting for signal to trigger (read: to be cancelled)");
self.guard.downgrade().cancelled().await;
tracing::trace!("::shutdown: waiting for all guards to drop");
zero_notified.await;
self.zero_rx.await;
tracing::trace!("::shutdown: ready");
}

pub async fn shutdown_with_limit(
self,
limit: time::Duration,
) -> Result<time::Duration, TimeoutError> {
let zero_notified = self.notify_zero.notified();
tracing::trace!("::shutdown: waiting for signal to trigger (read: to be cancelled)");
self.guard.downgrade().cancelled().await;
tracing::trace!(
Expand All @@ -79,7 +75,7 @@ impl Shutdown {
let start: time::Instant = time::Instant::now();
tokio::select! {
_ = tokio::time::sleep(limit) => { Err(TimeoutError(limit)) }
_ = zero_notified => { Ok(start.elapsed()) }
_ = self.zero_rx => { Ok(start.elapsed()) }
}
}
}
Expand Down
155 changes: 155 additions & 0 deletions src/trigger.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
use std::{
future::Future,
pin::Pin,
sync::{atomic::AtomicBool, Arc, Mutex},
task::{Context, Poll, Waker},
};

use pin_project_lite::pin_project;
use slab::Slab;

type WakerList = Arc<Mutex<Slab<Waker>>>;
type TriggerState = Arc<AtomicBool>;

#[derive(Debug, Clone)]
struct Subscriber {
wakers: WakerList,
state: TriggerState,
}

#[derive(Debug)]
enum SubscriberState {
Waiting(usize),
Triggered,
}

impl Subscriber {
pub fn state(&self, cx: &mut Context, key: Option<usize>) -> SubscriberState {
if self.state.load(std::sync::atomic::Ordering::SeqCst) {
return SubscriberState::Triggered;
}

let mut wakers = self.wakers.lock().unwrap();
if self.state.load(std::sync::atomic::Ordering::SeqCst) {
return SubscriberState::Triggered;
}

let waker = cx.waker().clone();

SubscriberState::Waiting(if let Some(key) = key {
tracing::trace!("trigger::Subscriber: updating waker for key: {}", key);
*wakers.get_mut(key).unwrap() = waker;
key
} else {
let key = wakers.insert(waker);
tracing::trace!("trigger::Subscriber: insert waker for key: {}", key);
key
})
}
}

#[derive(Debug)]
enum ReceiverState {
Open { sub: Subscriber, key: Option<usize> },
Closed,
}

impl Clone for ReceiverState {
fn clone(&self) -> Self {
match self {
ReceiverState::Open { sub, .. } => ReceiverState::Open {
sub: sub.clone(),
key: None,
},
ReceiverState::Closed => ReceiverState::Closed,
}
}
}

impl Drop for ReceiverState {
fn drop(&mut self) {
if let ReceiverState::Open { sub, key } = self {
if let Some(key) = key.take() {
let mut wakers = sub.wakers.lock().unwrap();
tracing::trace!(
"trigger::ReceiverState::Drop: remove waker for key: {}",
key
);
wakers.remove(key);
}
}
}
}

pin_project! {
#[derive(Debug, Clone)]
pub struct Receiver {
state: ReceiverState,
}
}

impl Receiver {
fn new(wakers: WakerList, state: TriggerState) -> Self {
Self {
state: ReceiverState::Open {
sub: Subscriber { wakers, state },
key: None,
},
}
}
}

impl Future for Receiver {
type Output = ();

fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.project();
match this.state {
ReceiverState::Open { sub, key } => {
let state = sub.state(cx, *key);
match state {
SubscriberState::Waiting(new_key) => {
*key = Some(new_key);
std::task::Poll::Pending
}
SubscriberState::Triggered => {
*this.state = ReceiverState::Closed;
std::task::Poll::Ready(())
}
}
}
ReceiverState::Closed => std::task::Poll::Ready(()),
}
}
}

#[derive(Debug, Clone)]
pub struct Sender {
wakers: WakerList,
state: TriggerState,
}

impl Sender {
fn new(wakers: WakerList, state: TriggerState) -> Self {
Self { wakers, state }
}

pub fn trigger(&self) {
let wakers = self.wakers.lock().unwrap();
self.state.store(true, std::sync::atomic::Ordering::SeqCst);
for (key, waker) in wakers.iter() {
tracing::trace!("trigger::Sender: wake up waker with key: {}", key);
waker.wake_by_ref();
}
}
}

pub fn trigger() -> (Sender, Receiver) {
let wakers = Arc::new(Mutex::new(Slab::new()));
let state = Arc::new(AtomicBool::new(false));

let sender = Sender::new(wakers.clone(), state.clone());
let receiver = Receiver::new(wakers, state);

(sender, receiver)
}

0 comments on commit 42fba35

Please sign in to comment.