diff --git a/Cargo.toml b/Cargo.toml index 869ad2e..3baf1c0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "deno_unsync" -version = "0.3.7" +version = "0.3.10" edition = "2021" authors = ["the Deno authors"] license = "MIT" @@ -9,6 +9,7 @@ description = "A collection of adapters to make working with Tokio single-thread readme = "README.md" [dependencies] +parking_lot = "0.12.3" tokio = { version = "1", features = ["rt"] } [dev-dependencies] diff --git a/src/flag.rs b/src/flag.rs index 293eac1..76e0e74 100644 --- a/src/flag.rs +++ b/src/flag.rs @@ -1,6 +1,8 @@ // Copyright 2018-2024 the Deno authors. MIT license. use std::cell::Cell; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; /// A flag with interior mutability that can be raised or lowered. /// Useful for indicating if an event has occurred. @@ -29,6 +31,32 @@ impl Flag { } } +/// Simplifies the use of an atomic boolean as a flag. +#[derive(Debug, Default)] +pub struct AtomicFlag(AtomicBool); + +impl AtomicFlag { + /// Creates a new flag that's raised. + pub fn raised() -> AtomicFlag { + Self(AtomicBool::new(true)) + } + + /// Raises the flag returning if the raise was successful. + pub fn raise(&self) -> bool { + !self.0.swap(true, Ordering::SeqCst) + } + + /// Lowers the flag returning if the lower was successful. + pub fn lower(&self) -> bool { + self.0.swap(false, Ordering::SeqCst) + } + + /// Gets if the flag is raised. + pub fn is_raised(&self) -> bool { + self.0.load(Ordering::SeqCst) + } +} + #[cfg(test)] mod test { use super::*; @@ -46,4 +74,21 @@ mod test { assert!(!flag.lower()); assert!(!flag.is_raised()); } + + #[test] + fn atomic_flag_raises_lowers() { + let flag = AtomicFlag::default(); + assert!(!flag.is_raised()); // false by default + assert!(flag.raise()); + assert!(flag.is_raised()); + assert!(!flag.raise()); + assert!(flag.is_raised()); + assert!(flag.lower()); + assert!(flag.raise()); + assert!(flag.lower()); + assert!(!flag.lower()); + let flag = AtomicFlag::raised(); + assert!(flag.is_raised()); + assert!(flag.lower()); + } } diff --git a/src/future.rs b/src/future.rs index 55488a7..d264508 100644 --- a/src/future.rs +++ b/src/future.rs @@ -1,15 +1,16 @@ // Copyright 2018-2024 the Deno authors. MIT license. +use parking_lot::Mutex; use std::cell::RefCell; use std::future::Future; use std::pin::Pin; use std::rc::Rc; +use std::sync::Arc; use std::task::Context; -use std::task::RawWaker; -use std::task::RawWakerVTable; +use std::task::Wake; use std::task::Waker; -use crate::Flag; +use crate::AtomicFlag; impl LocalFutureExt for T where T: Future {} @@ -57,7 +58,7 @@ where struct SharedLocalInner { data: RefCell>, - child_waker_state: Rc, + child_waker_state: Arc, } impl std::fmt::Debug for SharedLocalInner @@ -103,8 +104,8 @@ where data: RefCell::new(SharedLocalData { future_or_output: FutureOrOutput::Future(future), }), - child_waker_state: Rc::new(ChildWakerState { - can_poll: Flag::raised(), + child_waker_state: Arc::new(ChildWakerState { + can_poll: AtomicFlag::raised(), wakers: Default::default(), }), })) @@ -128,8 +129,7 @@ where FutureOrOutput::Future(fut) => { self.0.child_waker_state.wakers.push(cx.waker().clone()); if self.0.child_waker_state.can_poll.lower() { - let child_waker = - create_child_waker(self.0.child_waker_state.clone()); + let child_waker = Waker::from(self.0.child_waker_state.clone()); let mut child_cx = Context::from_waker(&child_waker); let fut = unsafe { Pin::new_unchecked(fut) }; match fut.poll(&mut child_cx) { @@ -154,81 +154,49 @@ where } #[derive(Debug, Default)] -struct WakerStore(RefCell>); +struct WakerStore(Mutex>); impl WakerStore { pub fn take_all(&self) -> Vec { - let mut wakers = self.0.borrow_mut(); + let mut wakers = self.0.lock(); std::mem::take(&mut *wakers) } pub fn clone_all(&self) -> Vec { - self.0.borrow().clone() + self.0.lock().clone() } pub fn push(&self, waker: Waker) { - self.0.borrow_mut().push(waker); + self.0.lock().push(waker); } } #[derive(Debug)] struct ChildWakerState { - can_poll: Flag, + can_poll: AtomicFlag, wakers: WakerStore, } -fn create_child_waker(state: Rc) -> Waker { - let raw_waker = RawWaker::new( - Rc::into_raw(state) as *const (), - &RawWakerVTable::new( - clone_waker, - wake_waker, - wake_by_ref_waker, - drop_waker, - ), - ); - unsafe { Waker::from_raw(raw_waker) } -} - -unsafe fn clone_waker(data: *const ()) -> RawWaker { - Rc::increment_strong_count(data as *const ChildWakerState); - RawWaker::new( - data, - &RawWakerVTable::new( - clone_waker, - wake_waker, - wake_by_ref_waker, - drop_waker, - ), - ) -} +impl Wake for ChildWakerState { + fn wake(self: Arc) { + self.can_poll.raise(); + let wakers = self.wakers.take_all(); -unsafe fn wake_waker(data: *const ()) { - let state = Rc::from_raw(data as *const ChildWakerState); - state.can_poll.raise(); - let wakers = state.wakers.take_all(); - drop(state); - - for waker in wakers { - waker.wake(); + for waker in wakers { + waker.wake(); + } } -} -unsafe fn wake_by_ref_waker(data: *const ()) { - let state = Rc::from_raw(data as *const ChildWakerState); - state.can_poll.raise(); - let wakers = state.wakers.clone_all(); - let _ = Rc::into_raw(state); // keep it alive + fn wake_by_ref(self: &Arc) { + self.can_poll.raise(); + let wakers = self.wakers.clone_all(); - for waker in wakers { - waker.wake_by_ref(); + for waker in wakers { + waker.wake_by_ref(); + } } } -unsafe fn drop_waker(data: *const ()) { - Rc::decrement_strong_count(data as *const ChildWakerState); -} - #[cfg(test)] mod test { use std::sync::Arc; diff --git a/src/lib.rs b/src/lib.rs index 0297f28..e1a6ef3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,6 +9,7 @@ mod task; mod task_queue; mod waker; +pub use flag::AtomicFlag; pub use flag::Flag; pub use joinset::JoinSet; pub use split::split_io;