Skip to content
This repository has been archived by the owner on Dec 26, 2024. It is now read-only.

Commit

Permalink
fix(network): wake the poll when a new event is pushed (#1703)
Browse files Browse the repository at this point in the history
  • Loading branch information
ShahakShama authored Feb 13, 2024
1 parent 2ca225a commit 69fd2b2
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 14 deletions.
41 changes: 28 additions & 13 deletions crates/papyrus_network/src/streamed_data/behaviour.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::collections::{HashMap, HashSet, VecDeque};
use std::io;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::task::{Context, Poll, Waker};
use std::time::Duration;

use defaultmap::DefaultHashMap;
Expand Down Expand Up @@ -121,6 +121,7 @@ pub struct Behaviour<Query: QueryBound, Data: DataBound> {
next_outbound_session_id: OutboundSessionId,
next_inbound_session_id: Arc<AtomicUsize>,
dropped_sessions: HashSet<SessionId>,
wakers_waiting_for_event: Vec<Waker>,
}

impl<Query: QueryBound, Data: DataBound> Behaviour<Query, Data> {
Expand All @@ -134,6 +135,7 @@ impl<Query: QueryBound, Data: DataBound> Behaviour<Query, Data> {
next_outbound_session_id: Default::default(),
next_inbound_session_id: Arc::new(Default::default()),
dropped_sessions: Default::default(),
wakers_waiting_for_event: Default::default(),
}
}

Expand All @@ -153,7 +155,7 @@ impl<Query: QueryBound, Data: DataBound> Behaviour<Query, Data> {
self.session_id_to_peer_id_and_connection_id
.insert(outbound_session_id.into(), (peer_id, connection_id));

self.pending_events.push_back(ToSwarm::NotifyHandler {
self.add_event_to_queue(ToSwarm::NotifyHandler {
peer_id,
handler: NotifyHandler::One(connection_id),
event: RequestFromBehaviourEvent::CreateOutboundSession { query, outbound_session_id },
Expand All @@ -170,7 +172,7 @@ impl<Query: QueryBound, Data: DataBound> Behaviour<Query, Data> {
) -> Result<(), SessionIdNotFoundError> {
let (peer_id, connection_id) =
self.get_peer_id_and_connection_id_from_session_id(inbound_session_id.into())?;
self.pending_events.push_back(ToSwarm::NotifyHandler {
self.add_event_to_queue(ToSwarm::NotifyHandler {
peer_id,
handler: NotifyHandler::One(connection_id),
event: RequestFromBehaviourEvent::SendData { data, inbound_session_id },
Expand All @@ -186,7 +188,7 @@ impl<Query: QueryBound, Data: DataBound> Behaviour<Query, Data> {
) -> Result<(), SessionIdNotFoundError> {
let (peer_id, connection_id) =
self.get_peer_id_and_connection_id_from_session_id(inbound_session_id.into())?;
self.pending_events.push_back(ToSwarm::NotifyHandler {
self.add_event_to_queue(ToSwarm::NotifyHandler {
peer_id,
handler: NotifyHandler::One(connection_id),
event: RequestFromBehaviourEvent::CloseInboundSession { inbound_session_id },
Expand All @@ -200,7 +202,7 @@ impl<Query: QueryBound, Data: DataBound> Behaviour<Query, Data> {
let (peer_id, connection_id) =
self.get_peer_id_and_connection_id_from_session_id(session_id)?;
if self.dropped_sessions.insert(session_id) {
self.pending_events.push_back(ToSwarm::NotifyHandler {
self.add_event_to_queue(ToSwarm::NotifyHandler {
peer_id,
handler: NotifyHandler::One(connection_id),
event: RequestFromBehaviourEvent::DropSession { session_id },
Expand All @@ -218,6 +220,16 @@ impl<Query: QueryBound, Data: DataBound> Behaviour<Query, Data> {
.copied()
.ok_or(SessionIdNotFoundError)
}

fn add_event_to_queue(
&mut self,
event: ToSwarm<Event<Query, Data>, RequestFromBehaviourEvent<Query, Data>>,
) {
self.pending_events.push_back(event);
for waker in self.wakers_waiting_for_event.drain(..) {
waker.wake();
}
}
}

impl<Query: QueryBound, Data: DataBound> NetworkBehaviour for Behaviour<Query, Data> {
Expand Down Expand Up @@ -254,21 +266,23 @@ impl<Query: QueryBound, Data: DataBound> NetworkBehaviour for Behaviour<Query, D
self.connection_ids_map.get_mut(peer_id).insert(connection_id);
}
FromSwarm::ConnectionClosed(ConnectionClosed { peer_id, connection_id, .. }) => {
let mut session_ids = Vec::new();
self.session_id_to_peer_id_and_connection_id.retain(
|session_id, (session_peer_id, session_connection_id)| {
if peer_id == *session_peer_id && connection_id == *session_connection_id {
self.pending_events.push_back(ToSwarm::GenerateEvent(
Event::SessionFailed {
session_id: *session_id,
error: SessionError::ConnectionClosed,
},
));
session_ids.push(session_id.clone());
false
} else {
true
}
},
);
for session_id in session_ids {
self.add_event_to_queue(ToSwarm::GenerateEvent(Event::SessionFailed {
session_id,
error: SessionError::ConnectionClosed,
}));
}
}
_ => {}
}
Expand Down Expand Up @@ -304,7 +318,7 @@ impl<Query: QueryBound, Data: DataBound> NetworkBehaviour for Behaviour<Query, D
}
}
if !is_event_muted {
self.pending_events.push_back(ToSwarm::GenerateEvent(converted_event));
self.add_event_to_queue(ToSwarm::GenerateEvent(converted_event));
}
}
RequestToBehaviourEvent::NotifySessionDropped { session_id } => {
Expand All @@ -315,12 +329,13 @@ impl<Query: QueryBound, Data: DataBound> NetworkBehaviour for Behaviour<Query, D

fn poll(
&mut self,
_cx: &mut Context<'_>,
cx: &mut Context<'_>,
) -> Poll<ToSwarm<Self::ToSwarm, <Self::ConnectionHandler as ConnectionHandler>::FromBehaviour>>
{
if let Some(event) = self.pending_events.pop_front() {
return Poll::Ready(event);
}
self.wakers_waiting_for_event.push(cx.waker().clone());
Poll::Pending
}
}
18 changes: 17 additions & 1 deletion crates/papyrus_network/src/streamed_data/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ impl<Query: QueryBound, Data: DataBound> Handler<Query, Data> {
) -> bool {
match inbound_session.poll_unpin(cx) {
Poll::Ready(Err(io_error)) => {
// No need to wake those waiting for pending events because this function is called
// inside `poll`.
pending_events.push_back(ConnectionHandlerEvent::NotifyBehaviour(
RequestToBehaviourEvent::GenerateEvent(GenericEvent::SessionFailed {
session_id: inbound_session_id.into(),
Expand All @@ -129,6 +131,8 @@ impl<Query: QueryBound, Data: DataBound> Handler<Query, Data> {
true
}
Poll::Ready(Ok(())) => {
// No need to wake those waiting for pending events because this function is called
// inside `poll`.
pending_events.push_back(ConnectionHandlerEvent::NotifyBehaviour(
RequestToBehaviourEvent::GenerateEvent(
GenericEvent::SessionFinishedSuccessfully {
Expand Down Expand Up @@ -228,7 +232,8 @@ impl<Query: QueryBound, Data: DataBound> ConnectionHandler for Handler<Query, Da
}
});

// Handling pending_events at the end of the function to avoid starvation.
// Handling pending_events at the end of the function to avoid starvation and to make sure
// we don't return Pending if the code above created an event.
if let Some(event) = self.pending_events.pop_front() {
return Poll::Ready(event);
}
Expand All @@ -240,6 +245,9 @@ impl<Query: QueryBound, Data: DataBound> ConnectionHandler for Handler<Query, Da
RequestFromBehaviourEvent::CreateOutboundSession { query, outbound_session_id } => {
// TODO(shahak) Consider extracting to a utility function to prevent forgetfulness
// of the timeout.

// No need to wake because the swarm guarantees that `poll` will be called after
// on_behaviour_event. See https://github.com/libp2p/rust-libp2p/issues/5147
self.pending_events.push_back(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(
OutboundProtocol {
Expand Down Expand Up @@ -283,6 +291,8 @@ impl<Query: QueryBound, Data: DataBound> ConnectionHandler for Handler<Query, Da
if remove_result.is_none() {
self.dropped_outbound_sessions_non_negotiated.insert(outbound_session_id);
}
// No need to wake because the swarm guarantees that `poll` will be called after
// on_behaviour_event. See https://github.com/libp2p/rust-libp2p/issues/5147
self.pending_events.push_back(ConnectionHandlerEvent::NotifyBehaviour(
RequestToBehaviourEvent::NotifySessionDropped {
session_id: outbound_session_id.into(),
Expand All @@ -293,6 +303,8 @@ impl<Query: QueryBound, Data: DataBound> ConnectionHandler for Handler<Query, Da
session_id: SessionId::InboundSessionId(inbound_session_id),
} => {
self.id_to_inbound_session.remove(&inbound_session_id);
// No need to wake because the swarm guarantees that `poll` will be called after
// on_behaviour_event. See https://github.com/libp2p/rust-libp2p/issues/5147
self.pending_events.push_back(ConnectionHandlerEvent::NotifyBehaviour(
RequestToBehaviourEvent::NotifySessionDropped {
session_id: inbound_session_id.into(),
Expand Down Expand Up @@ -344,6 +356,8 @@ impl<Query: QueryBound, Data: DataBound> ConnectionHandler for Handler<Query, Da
protocol: (query, write_stream),
info: inbound_session_id,
}) => {
// No need to wake because the swarm guarantees that `poll` will be called after
// on_connection_event. See https://github.com/libp2p/rust-libp2p/issues/5147
self.pending_events.push_back(ConnectionHandlerEvent::NotifyBehaviour(
RequestToBehaviourEvent::GenerateEvent(GenericEvent::NewInboundSession {
query,
Expand Down Expand Up @@ -372,6 +386,8 @@ impl<Query: QueryBound, Data: DataBound> ConnectionHandler for Handler<Query, Da
}
StreamUpgradeError::Io(error) => SessionError::IOError(error),
};
// No need to wake because the swarm guarantees that `poll` will be called after
// on_connection_event. See https://github.com/libp2p/rust-libp2p/issues/5147
self.pending_events.push_back(ConnectionHandlerEvent::NotifyBehaviour(
RequestToBehaviourEvent::GenerateEvent(GenericEvent::SessionFailed {
session_id: outbound_session_id.into(),
Expand Down

0 comments on commit 69fd2b2

Please sign in to comment.