Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: make waker in Shared Send and Sync #16

Merged
merged 5 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "deno_unsync"
version = "0.3.7"
version = "0.3.10"
edition = "2021"
authors = ["the Deno authors"]
license = "MIT"
Expand All @@ -9,6 +9,7 @@ description = "A collection of adapters to make working with Tokio single-thread
readme = "README.md"

[dependencies]
parking_lot = "0.12.3"
tokio = { version = "1", features = ["rt"] }

[dev-dependencies]
Expand Down
45 changes: 45 additions & 0 deletions src/flag.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright 2018-2024 the Deno authors. MIT license.

use std::cell::Cell;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;

/// A flag with interior mutability that can be raised or lowered.
/// Useful for indicating if an event has occurred.
Expand Down Expand Up @@ -29,6 +31,32 @@ impl Flag {
}
}

/// Simplifies the use of an atomic boolean as a flag.
#[derive(Debug, Default)]
pub struct AtomicFlag(AtomicBool);

impl AtomicFlag {
/// Creates a new flag that's raised.
pub fn raised() -> AtomicFlag {
Self(AtomicBool::new(true))
}

/// Raises the flag returning if the raise was successful.
pub fn raise(&self) -> bool {
!self.0.swap(true, Ordering::SeqCst)
}

/// Lowers the flag returning if the lower was successful.
pub fn lower(&self) -> bool {
self.0.swap(false, Ordering::SeqCst)
}

/// Gets if the flag is raised.
pub fn is_raised(&self) -> bool {
self.0.load(Ordering::SeqCst)
}
}

#[cfg(test)]
mod test {
use super::*;
Expand All @@ -46,4 +74,21 @@ mod test {
assert!(!flag.lower());
assert!(!flag.is_raised());
}

#[test]
fn atomic_flag_raises_lowers() {
let flag = AtomicFlag::default();
assert!(!flag.is_raised()); // false by default
assert!(flag.raise());
assert!(flag.is_raised());
assert!(!flag.raise());
assert!(flag.is_raised());
assert!(flag.lower());
assert!(flag.raise());
assert!(flag.lower());
assert!(!flag.lower());
let flag = AtomicFlag::raised();
assert!(flag.is_raised());
assert!(flag.lower());
}
}
84 changes: 26 additions & 58 deletions src/future.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
// Copyright 2018-2024 the Deno authors. MIT license.

use parking_lot::Mutex;
use std::cell::RefCell;
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc;
use std::sync::Arc;
use std::task::Context;
use std::task::RawWaker;
use std::task::RawWakerVTable;
use std::task::Wake;
use std::task::Waker;

use crate::Flag;
use crate::AtomicFlag;

impl<T: ?Sized> LocalFutureExt for T where T: Future {}

Expand Down Expand Up @@ -57,7 +58,7 @@ where

struct SharedLocalInner<TFuture: Future> {
data: RefCell<SharedLocalData<TFuture>>,
child_waker_state: Rc<ChildWakerState>,
child_waker_state: Arc<ChildWakerState>,
}

impl<TFuture: Future> std::fmt::Debug for SharedLocalInner<TFuture>
Expand Down Expand Up @@ -103,8 +104,8 @@ where
data: RefCell::new(SharedLocalData {
future_or_output: FutureOrOutput::Future(future),
}),
child_waker_state: Rc::new(ChildWakerState {
can_poll: Flag::raised(),
child_waker_state: Arc::new(ChildWakerState {
can_poll: AtomicFlag::raised(),
wakers: Default::default(),
}),
}))
Expand All @@ -128,8 +129,7 @@ where
FutureOrOutput::Future(fut) => {
self.0.child_waker_state.wakers.push(cx.waker().clone());
if self.0.child_waker_state.can_poll.lower() {
let child_waker =
create_child_waker(self.0.child_waker_state.clone());
let child_waker = Waker::from(self.0.child_waker_state.clone());
let mut child_cx = Context::from_waker(&child_waker);
let fut = unsafe { Pin::new_unchecked(fut) };
match fut.poll(&mut child_cx) {
Expand All @@ -154,81 +154,49 @@ where
}

#[derive(Debug, Default)]
struct WakerStore(RefCell<Vec<Waker>>);
struct WakerStore(Mutex<Vec<Waker>>);

impl WakerStore {
pub fn take_all(&self) -> Vec<Waker> {
let mut wakers = self.0.borrow_mut();
let mut wakers = self.0.lock();
std::mem::take(&mut *wakers)
}

pub fn clone_all(&self) -> Vec<Waker> {
self.0.borrow().clone()
self.0.lock().clone()
}

pub fn push(&self, waker: Waker) {
self.0.borrow_mut().push(waker);
self.0.lock().push(waker);
}
}

#[derive(Debug)]
struct ChildWakerState {
can_poll: Flag,
can_poll: AtomicFlag,
wakers: WakerStore,
}

fn create_child_waker(state: Rc<ChildWakerState>) -> Waker {
let raw_waker = RawWaker::new(
Rc::into_raw(state) as *const (),
&RawWakerVTable::new(
clone_waker,
wake_waker,
wake_by_ref_waker,
drop_waker,
),
);
unsafe { Waker::from_raw(raw_waker) }
}

unsafe fn clone_waker(data: *const ()) -> RawWaker {
Rc::increment_strong_count(data as *const ChildWakerState);
RawWaker::new(
data,
&RawWakerVTable::new(
clone_waker,
wake_waker,
wake_by_ref_waker,
drop_waker,
),
)
}
impl Wake for ChildWakerState {
fn wake(self: Arc<Self>) {
self.can_poll.raise();
let wakers = self.wakers.take_all();

unsafe fn wake_waker(data: *const ()) {
let state = Rc::from_raw(data as *const ChildWakerState);
state.can_poll.raise();
let wakers = state.wakers.take_all();
drop(state);

for waker in wakers {
waker.wake();
for waker in wakers {
waker.wake();
}
}
}

unsafe fn wake_by_ref_waker(data: *const ()) {
let state = Rc::from_raw(data as *const ChildWakerState);
state.can_poll.raise();
let wakers = state.wakers.clone_all();
let _ = Rc::into_raw(state); // keep it alive
fn wake_by_ref(self: &Arc<Self>) {
self.can_poll.raise();
let wakers = self.wakers.clone_all();

for waker in wakers {
waker.wake_by_ref();
for waker in wakers {
waker.wake_by_ref();
}
}
}

unsafe fn drop_waker(data: *const ()) {
Rc::decrement_strong_count(data as *const ChildWakerState);
}

#[cfg(test)]
mod test {
use std::sync::Arc;
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ mod task;
mod task_queue;
mod waker;

pub use flag::AtomicFlag;
pub use flag::Flag;
pub use joinset::JoinSet;
pub use split::split_io;
Expand Down