From 37d644a30914e6aeac784395883f07384b756cd0 Mon Sep 17 00:00:00 2001 From: hozan23 Date: Thu, 4 Jul 2024 02:48:19 +0200 Subject: [PATCH] Use AtomicU64 instead of u64 for the position field in Receiver The `recv()`, `recv_direct()`, `recv_blocking()`, and `try_recv()` methods currently require `&mut self` to modify the value of the `position` field. By using AtomicU64 for the `position` field eliminates the need for mutability. Fixes issue #66 --- src/lib.rs | 94 +++++++++++++++++++++++++++------------------------ tests/test.rs | 20 +++++------ 2 files changed, 59 insertions(+), 55 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 67032a4..a3a2b8f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -113,7 +113,10 @@ use std::fmt; use std::future::Future; use std::marker::PhantomPinned; use std::pin::Pin; -use std::sync::{Arc, Mutex}; +use std::sync::{ + atomic::{AtomicU64, Ordering}, + Arc, Mutex, +}; use std::task::{Context, Poll}; use event_listener::{Event, EventListener}; @@ -135,8 +138,8 @@ use pin_project_lite::pin_project; /// # futures_lite::future::block_on(async { /// use async_broadcast::{broadcast, TryRecvError, TrySendError}; /// -/// let (s, mut r1) = broadcast(1); -/// let mut r2 = r1.clone(); +/// let (s, r1) = broadcast(1); +/// let r2 = r1.clone(); /// /// assert_eq!(s.broadcast(10).await, Ok(None)); /// assert_eq!(s.try_broadcast(20), Err(TrySendError::Full(20))); @@ -169,7 +172,7 @@ pub fn broadcast(cap: usize) -> (Sender, Receiver) { }; let r = Receiver { inner, - pos: 0, + pos: AtomicU64::new(0), listener: None, }; @@ -203,21 +206,22 @@ impl Inner { /// Try receiving at the given position, returning either the element or a reference to it. /// /// Result is used here instead of Cow because we don't have a Clone bound on T. - fn try_recv_at(&mut self, pos: &mut u64) -> Result, TryRecvError> { - let i = match pos.checked_sub(self.head_pos) { + fn try_recv_at(&mut self, pos: &AtomicU64) -> Result, TryRecvError> { + let i = pos.load(Ordering::Acquire); + let i = match i.checked_sub(self.head_pos) { Some(i) => i .try_into() .expect("Head position more than usize::MAX behind a receiver"), None => { - let count = self.head_pos - *pos; - *pos = self.head_pos; + let count = self.head_pos - pos.load(Ordering::Relaxed); + pos.store(self.head_pos, Ordering::Release); return Err(TryRecvError::Overflowed(count)); } }; let last_waiter; if let Some((_elt, waiters)) = self.queue.get_mut(i) { - *pos += 1; + pos.fetch_add(1, Ordering::Release); *waiters -= 1; last_waiter = *waiters == 0; } else { @@ -331,7 +335,7 @@ impl Sender { /// ``` /// use async_broadcast::{broadcast, TrySendError, TryRecvError}; /// - /// let (mut s, mut r) = broadcast::(3); + /// let (mut s, r) = broadcast::(3); /// assert_eq!(s.capacity(), 3); /// s.try_broadcast(1).unwrap(); /// s.try_broadcast(2).unwrap(); @@ -378,7 +382,7 @@ impl Sender { /// ``` /// use async_broadcast::{broadcast, TrySendError, TryRecvError}; /// - /// let (mut s, mut r) = broadcast::(2); + /// let (mut s, r) = broadcast::(2); /// s.try_broadcast(1).unwrap(); /// s.try_broadcast(2).unwrap(); /// assert_eq!(s.try_broadcast(3), Err(TrySendError::Full(3))); @@ -423,7 +427,7 @@ impl Sender { /// # futures_lite::future::block_on(async { /// use async_broadcast::broadcast; /// - /// let (mut s, mut r) = broadcast::(2); + /// let (mut s, r) = broadcast::(2); /// s.broadcast(1).await.unwrap(); /// /// let _ = r.deactivate(); @@ -447,7 +451,7 @@ impl Sender { /// # futures_lite::future::block_on(async { /// use async_broadcast::{broadcast, RecvError}; /// - /// let (s, mut r) = broadcast(1); + /// let (s, r) = broadcast(1); /// s.broadcast(1).await.unwrap(); /// assert!(s.close()); /// @@ -611,11 +615,11 @@ impl Sender { /// # futures_lite::future::block_on(async { /// use async_broadcast::{broadcast, RecvError}; /// - /// let (s, mut r1) = broadcast(2); + /// let (s, r1) = broadcast(2); /// /// assert_eq!(s.broadcast(1).await, Ok(None)); /// - /// let mut r2 = s.new_receiver(); + /// let r2 = s.new_receiver(); /// /// assert_eq!(s.broadcast(2).await, Ok(None)); /// drop(s); @@ -633,7 +637,7 @@ impl Sender { inner.receiver_count += 1; Receiver { inner: self.inner.clone(), - pos: inner.head_pos + inner.queue.len() as u64, + pos: AtomicU64::new(inner.head_pos + inner.queue.len() as u64), listener: None, } } @@ -816,7 +820,7 @@ impl Clone for Sender { #[derive(Debug)] pub struct Receiver { inner: Arc>>, - pos: u64, + pos: AtomicU64, /// Listens for a send or close event to unblock this stream. listener: Option, @@ -964,7 +968,7 @@ impl Receiver { /// # futures_lite::future::block_on(async { /// use async_broadcast::{broadcast, RecvError}; /// - /// let (s, mut r) = broadcast(1); + /// let (s, r) = broadcast(1); /// s.broadcast(1).await.unwrap(); /// assert!(s.close()); /// @@ -1138,7 +1142,7 @@ impl Receiver { /// let inactive = r.deactivate(); /// assert_eq!(s.try_broadcast(10), Err(TrySendError::Inactive(10))); /// - /// let mut r = inactive.activate(); + /// let r = inactive.activate(); /// assert_eq!(s.broadcast(10).await, Ok(None)); /// assert_eq!(r.recv().await, Ok(10)); /// # }); @@ -1175,8 +1179,8 @@ impl Receiver { /// # futures_lite::future::block_on(async { /// use async_broadcast::{broadcast, RecvError}; /// - /// let (s, mut r1) = broadcast(1); - /// let mut r2 = r1.clone(); + /// let (s, r1) = broadcast(1); + /// let r2 = r1.clone(); /// /// assert_eq!(s.broadcast(1).await, Ok(None)); /// drop(s); @@ -1187,7 +1191,7 @@ impl Receiver { /// assert_eq!(r2.recv().await, Err(RecvError::Closed)); /// # }); /// ``` - pub fn recv(&mut self) -> Pin>> { + pub fn recv(&self) -> Pin>> { Box::pin(self.recv_direct()) } @@ -1203,8 +1207,8 @@ impl Receiver { /// # futures_lite::future::block_on(async { /// use async_broadcast::{broadcast, RecvError}; /// - /// let (s, mut r1) = broadcast(1); - /// let mut r2 = r1.clone(); + /// let (s, r1) = broadcast(1); + /// let r2 = r1.clone(); /// /// assert_eq!(s.broadcast(1).await, Ok(None)); /// drop(s); @@ -1215,7 +1219,7 @@ impl Receiver { /// assert_eq!(r2.recv_direct().await, Err(RecvError::Closed)); /// # }); /// ``` - pub fn recv_direct(&mut self) -> Recv<'_, T> { + pub fn recv_direct(&self) -> Recv<'_, T> { Recv::_new(RecvInner { receiver: self, listener: None, @@ -1237,10 +1241,9 @@ impl Receiver { /// # futures_lite::future::block_on(async { /// use async_broadcast::{broadcast, TryRecvError}; /// - /// let (s, mut r1) = broadcast(1); - /// let mut r2 = r1.clone(); + /// let (s, r1) = broadcast(1); + /// let r2 = r1.clone(); /// assert_eq!(s.broadcast(1).await, Ok(None)); - /// /// assert_eq!(r1.try_recv(), Ok(1)); /// assert_eq!(r1.try_recv(), Err(TryRecvError::Empty)); /// assert_eq!(r2.try_recv(), Ok(1)); @@ -1251,11 +1254,11 @@ impl Receiver { /// assert_eq!(r2.try_recv(), Err(TryRecvError::Closed)); /// # }); /// ``` - pub fn try_recv(&mut self) -> Result { + pub fn try_recv(&self) -> Result { self.inner .lock() .unwrap() - .try_recv_at(&mut self.pos) + .try_recv_at(&self.pos) .map(|cow| cow.unwrap_or_else(T::clone)) } @@ -1284,7 +1287,7 @@ impl Receiver { /// ``` /// use async_broadcast::{broadcast, RecvError}; /// - /// let (s, mut r) = broadcast(1); + /// let (s, r) = broadcast(1); /// /// assert_eq!(s.broadcast_blocking(1), Ok(None)); /// drop(s); @@ -1293,7 +1296,7 @@ impl Receiver { /// assert_eq!(r.recv_blocking(), Err(RecvError::Closed)); /// ``` #[cfg(not(target_family = "wasm"))] - pub fn recv_blocking(&mut self) -> Result { + pub fn recv_blocking(&self) -> Result { self.recv_direct().wait() } @@ -1307,7 +1310,7 @@ impl Receiver { /// # futures_lite::future::block_on(async { /// use async_broadcast::{broadcast, RecvError}; /// - /// let (s1, mut r) = broadcast(2); + /// let (s1, r) = broadcast(2); /// /// assert_eq!(s1.broadcast(1).await, Ok(None)); /// @@ -1341,11 +1344,11 @@ impl Receiver { /// # futures_lite::future::block_on(async { /// use async_broadcast::{broadcast, RecvError}; /// - /// let (s, mut r1) = broadcast(2); + /// let (s, r1) = broadcast(2); /// /// assert_eq!(s.broadcast(1).await, Ok(None)); /// - /// let mut r2 = r1.new_receiver(); + /// let r2 = r1.new_receiver(); /// /// assert_eq!(s.broadcast(2).await, Ok(None)); /// drop(s); @@ -1363,7 +1366,7 @@ impl Receiver { inner.receiver_count += 1; Receiver { inner: self.inner.clone(), - pos: inner.head_pos + inner.queue.len() as u64, + pos: AtomicU64::new(inner.head_pos + inner.queue.len() as u64), listener: None, } } @@ -1458,7 +1461,7 @@ impl Drop for Receiver { // Remove ourself from each item's counter loop { - match inner.try_recv_at(&mut self.pos) { + match inner.try_recv_at(&self.pos) { Ok(_) => continue, Err(TryRecvError::Overflowed(_)) => continue, Err(TryRecvError::Closed) => break, @@ -1481,12 +1484,12 @@ impl Clone for Receiver { /// # futures_lite::future::block_on(async { /// use async_broadcast::{broadcast, RecvError}; /// - /// let (s, mut r1) = broadcast(1); + /// let (s, r1) = broadcast(1); /// /// assert_eq!(s.broadcast(1).await, Ok(None)); /// drop(s); /// - /// let mut r2 = r1.clone(); + /// let r2 = r1.clone(); /// /// assert_eq!(r1.recv().await, Ok(1)); /// assert_eq!(r1.recv().await, Err(RecvError::Closed)); @@ -1498,13 +1501,14 @@ impl Clone for Receiver { let mut inner = self.inner.lock().unwrap(); inner.receiver_count += 1; // increment the waiter count on all items not yet received by this object - let n = self.pos.saturating_sub(inner.head_pos) as usize; + let pos = self.pos.load(Ordering::Relaxed); + let n = pos.saturating_sub(inner.head_pos) as usize; for (_elt, waiters) in inner.queue.iter_mut().skip(n) { *waiters += 1; } Receiver { inner: self.inner.clone(), - pos: self.pos, + pos: AtomicU64::new(pos), listener: None, } } @@ -1798,7 +1802,7 @@ easy_wrapper! { pin_project! { #[derive(Debug)] struct RecvInner<'a, T> { - receiver: &'a mut Receiver, + receiver: &'a Receiver, listener: Option, // Keeping this type `!Unpin` enables future optimizations. @@ -1870,7 +1874,7 @@ impl InactiveReceiver { /// let inactive = r.deactivate(); /// assert_eq!(s.try_broadcast(10), Err(TrySendError::Inactive(10))); /// - /// let mut r = inactive.activate(); + /// let r = inactive.activate(); /// assert_eq!(s.try_broadcast(10), Ok(None)); /// assert_eq!(r.try_recv(), Ok(10)); /// ``` @@ -1889,7 +1893,7 @@ impl InactiveReceiver { /// let inactive = r.deactivate(); /// assert_eq!(s.try_broadcast(10), Err(TrySendError::Inactive(10))); /// - /// let mut r = inactive.activate_cloned(); + /// let r = inactive.activate_cloned(); /// assert_eq!(s.try_broadcast(10), Ok(None)); /// assert_eq!(r.try_recv(), Ok(10)); /// ``` @@ -1905,7 +1909,7 @@ impl InactiveReceiver { Receiver { inner: self.inner.clone(), - pos: inner.head_pos + inner.queue.len() as u64, + pos: AtomicU64::new(inner.head_pos + inner.queue.len() as u64), listener: None, } } diff --git a/tests/test.rs b/tests/test.rs index 7b9c13e..40e125d 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -12,14 +12,14 @@ fn ms(ms: u64) -> Duration { #[test] fn basic_sync() { - let (s, mut r1) = broadcast(10); - let mut r2 = r1.clone(); + let (s, r1) = broadcast(10); + let r2 = r1.clone(); s.try_broadcast(7).unwrap(); assert_eq!(r1.try_recv().unwrap(), 7); assert_eq!(r2.try_recv().unwrap(), 7); - let mut r3 = r1.clone(); + let r3 = r1.clone(); s.try_broadcast(8).unwrap(); assert_eq!(r1.try_recv().unwrap(), 8); assert_eq!(r2.try_recv().unwrap(), 8); @@ -48,7 +48,7 @@ fn basic_async() { #[cfg(not(target_family = "wasm"))] #[test] fn basic_blocking() { - let (s, mut r) = broadcast(1); + let (s, r) = broadcast(1); s.broadcast_blocking(7).unwrap(); assert_eq!(r.try_recv(), Ok(7)); @@ -64,9 +64,9 @@ fn basic_blocking() { #[test] fn parallel() { - let (s1, mut r1) = broadcast(2); + let (s1, r1) = broadcast(2); let s2 = s1.clone(); - let mut r2 = r1.clone(); + let r2 = r1.clone(); let (sender_sync_send, sender_sync_recv) = mpsc::channel(); let (receiver_sync_send, receiver_sync_recv) = mpsc::channel(); @@ -162,9 +162,9 @@ fn parallel_async() { #[test] fn channel_shrink() { let (s1, mut r1) = broadcast(4); - let mut r2 = r1.clone(); - let mut r3 = r1.clone(); - let mut r4 = r1.clone(); + let r2 = r1.clone(); + let r3 = r1.clone(); + let r4 = r1.clone(); s1.try_broadcast(1).unwrap(); s1.try_broadcast(2).unwrap(); @@ -287,7 +287,7 @@ fn open_channel() { sender_sync_send.send(()).unwrap(); receiver_sync_recv.recv().unwrap(); - let mut r = inactive.activate(); + let r = inactive.activate(); assert_eq!(r.recv().await.unwrap(), 9); assert_eq!(r.recv().await.unwrap(), 10); })