Skip to content

Commit

Permalink
fix bugs and cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
divagant-martian committed Feb 16, 2024
1 parent 66e4350 commit 5dbf20b
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 54 deletions.
78 changes: 33 additions & 45 deletions src/transport/quinn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pub struct QuinnServerEndpoint<In: RpcMessage, Out: RpcMessage> {
impl<In: RpcMessage, Out: RpcMessage> QuinnServerEndpoint<In, Out> {
/// handles RPC requests from a connection
///
/// to cleanly shutdown tee handler, drop the receiver side of the sender.
/// to cleanly shutdown the handler, drop the receiver side of the sender.
async fn connection_handler(connection: quinn::Connection, sender: flume::Sender<SocketInner>) {
loop {
tracing::debug!("Awaiting incoming bidi substream on existing connection...");
Expand Down Expand Up @@ -106,7 +106,6 @@ impl<In: RpcMessage, Out: RpcMessage> QuinnServerEndpoint<In, Out> {
tracing::debug!("Spawning connection handler...");
tokio::spawn(Self::connection_handler(conection, sender.clone()));
}
tracing::debug!("endpoint handler finished");
}

/// Create a new server channel, given a quinn endpoint.
Expand Down Expand Up @@ -320,12 +319,11 @@ impl<In: RpcMessage, Out: RpcMessage> QuinnConnection<In, Out> {
loop {
tokio::select! {
// wait for a new connection to be opened
conn_result = reconnect.as_mut() => {
conn_result = reconnect.as_mut(), if !reconnect.connected() => {
tracing::trace!("tick: connection result");
match conn_result {
Ok(new_connection) => {
connection = Some(new_connection);
tracing::debug!("got new connection");
},
Err(e) => {
let connection_err = match e {
Expand All @@ -350,12 +348,12 @@ impl<In: RpcMessage, Out: RpcMessage> QuinnConnection<In, Out> {
},
}
}

// wait for a new request as long as there is no pending one
req = receiver.next(), if pending_request.is_none() => {
tracing::trace!("tick: bidi request");
match req {
Some(request) => {
tracing::debug!("got new bidi request");
pending_request = Some(request)
},
None => {
Expand All @@ -369,11 +367,6 @@ impl<In: RpcMessage, Out: RpcMessage> QuinnConnection<In, Out> {
}
}

tracing::trace!(
"connection is some {}; request is some {}",
connection.is_some(),
pending_request.is_some()
);
if let Some(connection) = connection.as_mut() {
if let Some(request) = pending_request.take() {
match connection.open_bi().await {
Expand Down Expand Up @@ -453,6 +446,10 @@ impl ReconnectHandler {
pub fn set_not_connected(&mut self) {
self.state.set_not_connected()
}

pub fn connected(&self) -> bool {
matches!(self.state, ConnectionState::Connected(_))
}
}

enum ConnectionState {
Expand All @@ -461,7 +458,7 @@ enum ConnectionState {
/// Connecting to the remote.
Connecting(quinn::Connecting),
/// A connection is already established. In this state, no more connection attempts are made.
Connected,
Connected(quinn::Connection),
/// Intermediate state while processing.
Poisoned,
}
Expand All @@ -486,44 +483,35 @@ impl Future for ReconnectHandler {

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.state.poison() {
ConnectionState::NotConnected => {
tracing::debug!(addr = %self.addr, name = self.name, "calling connect");

match self.endpoint.connect(self.addr, &self.name) {
Ok(connecting) => {
self.state = ConnectionState::Connecting(connecting);
self.poll(cx)
ConnectionState::NotConnected => match self.endpoint.connect(self.addr, &self.name) {
Ok(connecting) => {
self.state = ConnectionState::Connecting(connecting);
self.poll(cx)
}
Err(e) => {
self.state = ConnectionState::NotConnected;
Poll::Ready(Err(ReconnectErr::Connect(e)))
}
},
ConnectionState::Connecting(mut connecting) => match connecting.poll_unpin(cx) {
Poll::Ready(res) => match res {
Ok(connection) => {
self.state = ConnectionState::Connected(connection.clone());
Poll::Ready(Ok(connection))
}
Err(e) => {
self.state = ConnectionState::NotConnected;
Poll::Ready(Err(ReconnectErr::Connect(e)))
}
}
}
ConnectionState::Connecting(mut connecting) => {
tracing::debug!(addr = %self.addr, name = self.name, "awaiting connect");

match connecting.poll_unpin(cx) {
Poll::Ready(res) => match res {
Ok(connection) => {
self.state = ConnectionState::Connected;
Poll::Ready(Ok(connection))
}
Err(e) => {
self.state = ConnectionState::NotConnected;
Poll::Ready(Err(ReconnectErr::Connection(e)))
}
},
Poll::Pending => {
self.state = ConnectionState::Connecting(connecting);
Poll::Pending
Poll::Ready(Err(ReconnectErr::Connection(e)))
}
},
Poll::Pending => {
self.state = ConnectionState::Connecting(connecting);
Poll::Pending
}
}
ConnectionState::Connected => {
// waiting for a request to open a new connection, nothing to do
self.state = ConnectionState::Connected;
Poll::Pending
},
ConnectionState::Connected(connection) => {
self.state = ConnectionState::Connected(connection.clone());
Poll::Ready(Ok(connection))
}
ConnectionState::Poisoned => unreachable!("poisoned connection state"),
}
Expand Down Expand Up @@ -573,7 +561,7 @@ impl<'a, T> futures::stream::Stream for Receiver<'a, T> {
Poll::Ready(None)
}
Poll::Pending => {
*self = Receiver::PreReceive(recv);
*self = Receiver::Receiving(recv, fut);
Poll::Pending
}
},
Expand Down
2 changes: 0 additions & 2 deletions tests/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,7 @@ impl ComputeService {
let service = ComputeService;
while received < count {
received += 1;
tracing::debug!("before accept");
let (req, chan) = s.accept().await?;
tracing::debug!("after accept");
let service = service.clone();
tokio::spawn(async move {
use ComputeRequest::*;
Expand Down
7 changes: 0 additions & 7 deletions tests/quinn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ async fn server_away_and_back() -> anyhow::Result<()> {

// send a request. No server available so it should fail
let e = client.rpc(Sqr(4)).await.unwrap_err();
tracing::info!(%e, "got expected request failure");

// create the RPC Server
let server = Endpoint::server(server_config.clone(), server_addr)?;
Expand All @@ -164,27 +163,21 @@ async fn server_away_and_back() -> anyhow::Result<()> {

// send the first request and wait for the response to ensure everything works as expected
let SqrResponse(response) = client.rpc(Sqr(4)).await.unwrap();
tracing::info!(%response, "got expected response");
assert_eq!(response, 16);

let server = server_handle.await.unwrap().unwrap();
drop(server);
// wait for drop to free the socket
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;

tracing::info!("SERVER DROPPED");

// make the server run again
let server = Endpoint::server(server_config, server_addr)?;
let connection = transport::quinn::QuinnServerEndpoint::new(server)?;
let server = RpcServer::<ComputeService, _>::new(connection);
let server_handle = tokio::task::spawn(ComputeService::server_bounded(server, 5));
tracing::info!("Server spawned");

// server is running, this should work
tracing::info!("sending Sqr(3)");
let SqrResponse(response) = client.rpc(Sqr(3)).await.unwrap();
tracing::info!(%response, "got expected response");
assert_eq!(response, 9);

server_handle.abort();
Expand Down

0 comments on commit 5dbf20b

Please sign in to comment.