Skip to content

Commit

Permalink
Improve async channels error handling and replace unbounded channels …
Browse files Browse the repository at this point in the history
…with bounded channels

Remove all unbounded channels to prevent unbounded memory usage and
potential crashes.

Use `FuturesUnordered` for sending to multiple channels simultaneously.
This prevents the sending loop from blocking if one channel is blocked,
and helps handle errors properly.
  • Loading branch information
hozan23 committed Jun 27, 2024
1 parent 1a3ef2d commit b8b5f00
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 50 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ authors.workspace = true

[features]
default = ["smol"]
crypto = ["dep:ed25519-dalek"]
crypto = ["ed25519-dalek"]
tokio = ["dep:tokio"]
smol = ["dep:smol", "dep:async-process"]
smol = ["dep:smol", "async-process"]

[dependencies]
log = "0.4.21"
Expand All @@ -29,6 +29,9 @@ pin-project-lite = "0.2.14"
async-process = { version = "2.2.3", optional = true }
smol = { version = "2.0.0", optional = true }
tokio = { version = "1.38.0", features = ["full"], optional = true }
futures-util = { version = "0.3.5", features = [
"alloc",
], default-features = false }

# encode
bincode = "2.0.0-rc.3"
Expand Down
5 changes: 5 additions & 0 deletions core/src/async_runtime/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ impl Executor {
) -> Task<T> {
self.inner.spawn(future).into()
}

#[cfg(feature = "tokio")]
pub fn handle(&self) -> &tokio::runtime::Handle {
return self.inner.handle();
}
}

static GLOBAL_EXECUTOR: OnceCell<Executor> = OnceCell::new();
Expand Down
1 change: 1 addition & 0 deletions core/src/async_util/task_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ impl TaskGroup {
self.stop_signal.broadcast().await;

loop {
// XXX BE CAREFUL HERE, it hold synchronous mutex across .await point.
let task = self.tasks.lock().pop();
if let Some(t) = task {
t.cancel().await
Expand Down
74 changes: 50 additions & 24 deletions core/src/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,19 @@ use std::{

use async_channel::{Receiver, Sender};
use chrono::{DateTime, Utc};
use futures_util::stream::{FuturesUnordered, StreamExt};
use log::{debug, error};

use crate::{async_runtime::lock::Mutex, util::random_16, Result};
use crate::{async_runtime::lock::Mutex, util::random_32, Result};

const CHANNEL_BUFFER_SIZE: usize = 1000;

pub type ArcEventSys<T> = Arc<EventSys<T>>;
pub type WeakEventSys<T> = Weak<EventSys<T>>;
pub type EventListenerID = u16;
pub type EventListenerID = u32;

type Listeners<T> = HashMap<T, HashMap<String, HashMap<EventListenerID, Sender<Event>>>>;

/// EventSys supports event emission to registered listeners based on topics.
/// EventSys emits events to registered listeners based on topics.
/// # Example
///
/// ```
Expand Down Expand Up @@ -74,22 +76,41 @@ type Listeners<T> = HashMap<T, HashMap<String, HashMap<EventListenerID, Sender<E
///
pub struct EventSys<T> {
listeners: Mutex<Listeners<T>>,
listener_buffer_size: usize,
}

impl<T> EventSys<T>
where
T: std::hash::Hash + Eq + std::fmt::Debug + Clone,
{
/// Creates a new `EventSys`
/// Creates a new [`EventSys`]
pub fn new() -> ArcEventSys<T> {
Arc::new(Self {
listeners: Mutex::new(HashMap::new()),
listener_buffer_size: CHANNEL_BUFFER_SIZE,
})
}

/// Creates a new [`EventSys`] with the provided buffer size for the
/// [`EventListener`] channel.
///
/// This is important to control the memory used by the listener channel.
/// If the consumer for the event listener can't keep up with the new events
/// coming, then the channel buffer will fill with new events, and if the
/// buffer is full, the emit function will block until the listener
/// starts to consume the buffered events.
///
/// If `size` is zero, this function will panic.
pub fn with_buffer_size(size: usize) -> ArcEventSys<T> {
Arc::new(Self {
listeners: Mutex::new(HashMap::new()),
listener_buffer_size: size,
})
}

/// Emits an event to the listeners.
///
/// The event must implement the `EventValueTopic` trait to indicate the
/// The event must implement the [`EventValueTopic`] trait to indicate the
/// topic of the event. Otherwise, you can use `emit_by_topic()`.
pub async fn emit<E: EventValueTopic<Topic = T> + Clone>(&self, value: &E) {
let topic = E::topic();
Expand All @@ -115,22 +136,26 @@ where
let event_id = E::id().to_string();

if !event_ids.contains_key(&event_id) {
debug!(
"Failed to emit an event to a non-existent event id: {:?}",
event_id
);
debug!("Failed to emit an event: unknown event id {:?}", event_id);
return;
}

let mut failed_listeners = vec![];
let mut results = FuturesUnordered::new();

let listeners = event_ids.get_mut(&event_id).unwrap();
for (listener_id, listener) in listeners.iter() {
if let Err(err) = listener.send(event.clone()).await {
let result = async { (*listener_id, listener.send(event.clone()).await) };
results.push(result);
}

let mut failed_listeners = vec![];
while let Some((id, fut_err)) = results.next().await {
if let Err(err) = fut_err {
debug!("Failed to emit event for topic {:?}: {}", topic, err);
failed_listeners.push(*listener_id);
failed_listeners.push(id);
}
}
drop(results);

for listener_id in failed_listeners.iter() {
listeners.remove(listener_id);
Expand All @@ -142,7 +167,7 @@ where
self: &Arc<Self>,
topic: &T,
) -> EventListener<T, E> {
let chan = async_channel::unbounded();
let chan = async_channel::bounded(self.listener_buffer_size);

let topics = &mut self.listeners.lock().await;

Expand All @@ -159,9 +184,10 @@ where

let listeners = event_ids.get_mut(&event_id).unwrap();

let mut listener_id = random_16();
let mut listener_id = random_32();
// Generate a new one if listener_id already exists
while listeners.contains_key(&listener_id) {
listener_id = random_16();
listener_id = random_32();
}

let listener =
Expand Down Expand Up @@ -197,7 +223,7 @@ where
pub struct EventListener<T, E> {
id: EventListenerID,
recv_chan: Receiver<Event>,
event_sys: WeakEventSys<T>,
event_sys: Weak<EventSys<T>>,
event_id: String,
topic: T,
phantom: PhantomData<E>,
Expand All @@ -208,10 +234,10 @@ where
T: std::hash::Hash + Eq + Clone + std::fmt::Debug,
E: EventValueAny + Clone + EventValue,
{
/// Create a new event listener.
/// Creates a new [`EventListener`].
fn new(
id: EventListenerID,
event_sys: WeakEventSys<T>,
event_sys: Weak<EventSys<T>>,
recv_chan: Receiver<Event>,
event_id: &str,
topic: &T,
Expand All @@ -226,12 +252,12 @@ where
}
}

/// Receive the next event.
/// Receives the next event.
pub async fn recv(&self) -> Result<E> {
match self.recv_chan.recv().await {
Ok(event) => match ((*event.value).value_as_any()).downcast_ref::<E>() {
Some(v) => Ok(v.clone()),
None => unreachable!("Error when attempting to downcast the event value."),
None => unreachable!("Failed to downcast the event value."),
},
Err(err) => {
error!("Failed to receive new event: {err}");
Expand All @@ -241,20 +267,20 @@ where
}
}

/// Cancels the listener and removes it from the `EventSys`.
/// Cancels the event listener and removes it from the [`EventSys`].
pub async fn cancel(&self) {
if let Some(es) = self.event_sys.upgrade() {
es.remove(&self.topic, &self.event_id, &self.id).await;
}
}

/// Returns the topic for this event listener.
pub async fn topic(&self) -> &T {
pub fn topic(&self) -> &T {
&self.topic
}

/// Returns the event id for this event listener.
pub async fn event_id(&self) -> &String {
pub fn event_id(&self) -> &String {
&self.event_id
}
}
Expand Down
56 changes: 43 additions & 13 deletions core/src/pubsub.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use std::{collections::HashMap, sync::Arc};

use futures_util::stream::{FuturesUnordered, StreamExt};
use log::error;

use crate::{async_runtime::lock::Mutex, util::random_16, Result};
use crate::{async_runtime::lock::Mutex, util::random_32, Result};

const CHANNEL_BUFFER_SIZE: usize = 1000;

pub type ArcPublisher<T> = Arc<Publisher<T>>;
pub type SubscriptionID = u16;
pub type SubscriptionID = u32;

/// A simple publish-subscribe system.
// # Example
Expand All @@ -28,27 +31,46 @@ pub type SubscriptionID = u16;
/// ```
pub struct Publisher<T> {
subs: Mutex<HashMap<SubscriptionID, async_channel::Sender<T>>>,
subscription_buffer_size: usize,
}

impl<T: Clone> Publisher<T> {
/// Creates a new Publisher
/// Creates a new [`Publisher`]
pub fn new() -> ArcPublisher<T> {
Arc::new(Self {
subs: Mutex::new(HashMap::new()),
subscription_buffer_size: CHANNEL_BUFFER_SIZE,
})
}

/// Creates a new [`Publisher`] with the provided buffer size for the
/// [`Subscription`] channel.
///
/// This is important to control the memory used by the [`Subscription`] channel.
/// If the subscriber can't keep up with the new messages coming, then the
/// channel buffer will fill with new messages, and if the buffer is full,
/// the emit function will block until the subscriber starts to process
/// the buffered messages.
///
/// If `size` is zero, this function will panic.
pub fn with_buffer_size(size: usize) -> ArcPublisher<T> {
Arc::new(Self {
subs: Mutex::new(HashMap::new()),
subscription_buffer_size: size,
})
}

/// Subscribe and return a Subscription
/// Subscribes and return a [`Subscription`]
pub async fn subscribe(self: &Arc<Self>) -> Subscription<T> {
let mut subs = self.subs.lock().await;

let chan = async_channel::unbounded();
let chan = async_channel::bounded(self.subscription_buffer_size);

let mut sub_id = random_16();
let mut sub_id = random_32();

// While the SubscriptionID already exists, generate a new one
// Generate a new one if sub_id already exists
while subs.contains_key(&sub_id) {
sub_id = random_16();
sub_id = random_32();
}

let sub = Subscription::new(sub_id, self.clone(), chan.1);
Expand All @@ -57,22 +79,30 @@ impl<T: Clone> Publisher<T> {
sub
}

/// Unsubscribe from the Publisher
/// Unsubscribes from the publisher
pub async fn unsubscribe(self: &Arc<Self>, id: &SubscriptionID) {
self.subs.lock().await.remove(id);
}

/// Notify all subscribers
/// Notifies all subscribers
pub async fn notify(self: &Arc<Self>, value: &T) {
let mut subs = self.subs.lock().await;

let mut results = FuturesUnordered::new();
let mut closed_subs = vec![];

for (sub_id, sub) in subs.iter() {
if let Err(err) = sub.send(value.clone()).await {
error!("failed to notify {}: {}", sub_id, err);
closed_subs.push(*sub_id);
let result = async { (*sub_id, sub.send(value.clone()).await) };
results.push(result);
}

while let Some((id, fut_err)) = results.next().await {
if let Err(err) = fut_err {
error!("failed to notify {}: {}", id, err);
closed_subs.push(id);
}
}
drop(results);

for sub_id in closed_subs.iter() {
subs.remove(sub_id);
Expand Down
12 changes: 7 additions & 5 deletions jsonrpc/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ pub const FAILED_TO_PARSE_ERROR_MSG: &str = "Failed to parse";
pub const METHOD_NOT_FOUND_ERROR_MSG: &str = "Method not found";
pub const UNSUPPORTED_JSONRPC_VERSION: &str = "Unsupported jsonrpc version";

const CHANNEL_SUBSCRIPTION_BUFFER_SIZE: usize = 100;

struct NewRequest {
srvc_name: String,
method_name: String,
Expand Down Expand Up @@ -108,7 +110,7 @@ impl Server {

let conn = Arc::new(conn);

let (ch_tx, ch_rx) = async_channel::unbounded();
let (ch_tx, ch_rx) = async_channel::bounded(CHANNEL_SUBSCRIPTION_BUFFER_SIZE);
// Create a new connection channel for managing subscriptions
let channel = Channel::new(ch_tx);

Expand All @@ -120,13 +122,13 @@ impl Server {
if let TaskResult::Completed(Err(err)) = result {
debug!("Notification loop stopped: {err}");
}
// Close the connection subscription channel
// Close the connection channel
chan.close();
};

let conn_cloned = conn.clone();
let queue_cloned = queue.clone();
// Start listening for responses in the queue or new notifications
// Start listening for new responses in the queue or new notifications
self.task_group.spawn(
async move {
loop {
Expand Down Expand Up @@ -163,12 +165,12 @@ impl Server {
} else {
warn!("Connection {} dropped", endpoint);
}
// Close the subscription channel when the connection dropped
// Close the connection channel when the connection dropped
chan.close();
};

let selfc = self.clone();
// Spawn a new task and wait for requests.
// Spawn a new task and wait for new requests.
self.task_group.spawn(
async move {
loop {
Expand Down
Loading

0 comments on commit b8b5f00

Please sign in to comment.