From 01a4214510beccea82e28102eefca391e6f2d5b9 Mon Sep 17 00:00:00 2001 From: Al Liu Date: Sun, 11 Feb 2024 02:32:10 +0800 Subject: [PATCH] polish `AsyncWaitGroup` --- Cargo.toml | 13 ++++-- README.md | 16 ++++++- src/lib.rs | 122 ++++++++++++++++++++++++++++++++--------------------- 3 files changed, 98 insertions(+), 53 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9dbd72a..60d952f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ homepage = "https://github.com/al8n/wg" repository = "https://github.com/al8n/wg.git" documentation = "https://docs.rs/wg/" readme = "README.md" -version = "0.5.0" +version = "0.6.0" license = "MIT OR Apache-2.0" keywords = ["waitgroup", "async", "sync", "notify", "wake"] categories = ["asynchronous", "concurrency", "data-structures"] @@ -18,14 +18,19 @@ full = ["triomphe", "parking_lot"] triomphe = ["dep:triomphe"] parking_lot = ["dep:parking_lot"] +future = ["event-listener", "event-listener-strategy", "pin-project-lite"] + [dependencies] -parking_lot = {version = "0.12", optional = true } +parking_lot = { version = "0.12", optional = true } triomphe = { version = "0.1", optional = true } +event-listener = { version = "5", optional = true } +event-listener-strategy = { version = "0.5", optional = true } +pin-project-lite = { version = "0.2", optional = true } [dev-dependencies] tokio = { version = "1", features = ["full"] } -async-std = { version = "1.12", features = ["attributes"] } +async-std = { version = "1.12", features = ["attributes"] } [package.metadata.docs.rs] all-features = true -rustdoc-args = ["--cfg", "docsrs"] \ No newline at end of file +rustdoc-args = ["--cfg", "docsrs"] diff --git a/README.md b/README.md index 575c359..1f10415 100644 --- a/README.md +++ b/README.md @@ -17,11 +17,25 @@ Golang like WaitGroup implementation for sync/async Rust. ## Installation + +By default, blocking version `WaitGroup` is enabled, if you want to use non-blocking `AsyncWaitGroup`, you need to +enbale `future` feature in your `Cargo.toml`. + +### Sync + ```toml [dependencies] -wg = "0.5" +wg = "0.6" ``` +### Async + +```toml +[dependencies] +wg = { version: "0.6", features = ["future"] } +``` + + ## Example ### Sync diff --git a/src/lib.rs b/src/lib.rs index c216c95..632ee06 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -371,21 +371,25 @@ impl WaitGroup { } } +#[cfg(feature = "future")] pub use r#async::*; +#[cfg(feature = "future")] mod r#async { use super::*; - use std::sync::atomic::{AtomicUsize, Ordering}; + use event_listener::{Event, EventListener}; + use event_listener_strategy::{easy_wrapper, EventListenerFuture, Strategy}; use std::{ - future::Future, pin::Pin, - task::{Context, Poll, Waker}, + sync::atomic::{AtomicUsize, Ordering}, + task::Poll, }; + #[derive(Debug)] struct AsyncInner { - waker: Mutex>, - count: AtomicUsize, + counter: AtomicUsize, + event: Event, } /// An AsyncWaitGroup waits for a collection of threads to finish. @@ -429,7 +433,7 @@ mod r#async { /// /// [`wait`]: struct.AsyncWaitGroup.html#method.wait /// [`add`]: struct.AsyncWaitGroup.html#method.add - #[cfg_attr(docsrs, doc(cfg(feature = "test")))] + #[cfg_attr(docsrs, doc(cfg(feature = "future")))] pub struct AsyncWaitGroup { inner: Arc, } @@ -438,8 +442,8 @@ mod r#async { fn default() -> Self { Self { inner: Arc::new(AsyncInner { - count: AtomicUsize::new(0), - waker: Mutex::new(None), + counter: AtomicUsize::new(0), + event: Event::new(), }), } } @@ -449,8 +453,8 @@ mod r#async { fn from(count: usize) -> Self { Self { inner: Arc::new(AsyncInner { - count: AtomicUsize::new(count), - waker: Mutex::new(None), + counter: AtomicUsize::new(count), + event: Event::new(), }), } } @@ -466,10 +470,8 @@ mod r#async { impl std::fmt::Debug for AsyncWaitGroup { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let count = self.inner.count.load(Ordering::Relaxed); - f.debug_struct("AsyncWaitGroup") - .field("count", &count) + .field("counter", &self.inner.counter) .finish() } } @@ -513,7 +515,7 @@ mod r#async { /// /// [`wait`]: struct.AsyncWaitGroup.html#method.wait pub fn add(&self, num: usize) -> Self { - self.inner.count.fetch_add(num, Ordering::SeqCst); + self.inner.counter.fetch_add(num, Ordering::AcqRel); Self { inner: self.inner.clone(), @@ -539,18 +541,14 @@ mod r#async { /// } /// ``` pub fn done(&self) { - let count = self.inner.count.fetch_sub(1, Ordering::Relaxed); - // We are the last worker - if count == 1 { - if let Some(waker) = self.inner.waker.lock_me().take() { - waker.wake(); - } + if self.inner.counter.fetch_sub(1, Ordering::SeqCst) == 1 { + self.inner.event.notify(usize::MAX); } } /// waitings return how many jobs are waiting. pub fn waitings(&self) -> usize { - self.inner.count.load(Ordering::Acquire) + self.inner.counter.load(Ordering::Acquire) } /// wait blocks until the [`AsyncWaitGroup`] counter is zero. @@ -575,8 +573,8 @@ mod r#async { /// wg.wait().await; /// } /// ``` - pub async fn wait(&self) { - WaitGroupFuture::new(&self.inner).await + pub fn wait(&self) -> WaitGroupFuture<'_> { + WaitGroupFuture::_new(WaitGroupFutureInner::new(&self.inner)) } /// Wait blocks until the [`AsyncWaitGroup`] counter is zero. This method is @@ -606,39 +604,67 @@ mod r#async { /// } /// ``` pub fn block_wait(&self) { - loop { - match self.inner.count.load(Ordering::Acquire) { - 0 => return, - _ => core::hint::spin_loop(), - } - } + WaitGroupFutureInner::new(&self.inner).wait(); } } - struct WaitGroupFuture<'a> { - inner: &'a Arc, + easy_wrapper! { + /// A future returned by [`AsyncWaitGroup::wait()`]. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + #[cfg_attr(docsrs, doc(cfg(feature = "future")))] + pub struct WaitGroupFuture<'a>(WaitGroupFutureInner<'a> => ()); + + #[cfg(all(feature = "std", not(target_family = "wasm")))] + pub(crate) wait(); } - impl<'a> WaitGroupFuture<'a> { + pin_project_lite::pin_project! { + /// A future that used to wait for the [`AsyncWaitGroup`] counter is zero. + #[must_use = "futures do nothing unless you `.await` or poll them"] + #[project(!Unpin)] + #[derive(Debug)] + struct WaitGroupFutureInner<'a> { + inner: &'a Arc, + listener: Option, + #[pin] + _pin: std::marker::PhantomPinned, + } + } + + impl<'a> WaitGroupFutureInner<'a> { fn new(inner: &'a Arc) -> Self { - Self { inner } + Self { + inner, + listener: None, + _pin: std::marker::PhantomPinned, + } } } - impl Future for WaitGroupFuture<'_> { + impl EventListenerFuture for WaitGroupFutureInner<'_> { type Output = (); - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if self.inner.count.load(Ordering::Relaxed) == 0 { - return Poll::Ready(()); - } - - let waker = cx.waker().clone(); - *self.inner.waker.lock_me() = Some(waker); + fn poll_with_strategy<'a, S: Strategy<'a>>( + self: Pin<&mut Self>, + strategy: &mut S, + context: &mut S::Context, + ) -> Poll { + let this = self.project(); + loop { + if this.inner.counter.load(Ordering::Acquire) == 0 { + return Poll::Ready(()); + } - match self.inner.count.load(Ordering::Relaxed) { - 0 => Poll::Ready(()), - _ => Poll::Pending, + if this.listener.is_some() { + // Poll using the given strategy + match S::poll(strategy, &mut *this.listener, context) { + Poll::Ready(_) => {} + Poll::Pending => return Poll::Pending, + } + } else { + *this.listener = Some(this.inner.event.listen()); + } } } } @@ -759,13 +785,13 @@ mod r#async { assert_eq!(wg.waitings(), 2); } - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn test_async_block_wait() { + #[test] + fn test_async_block_wait() { let wg = AsyncWaitGroup::new(); let t_wg = wg.add(1); - tokio::spawn(async move { + std::thread::spawn(move || { // do some time consuming task - t_wg.done() + t_wg.done(); }); // wait other thread completes