Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
divagant-martian committed Feb 15, 2024
1 parent b06e613 commit 66e4350
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 28 deletions.
75 changes: 48 additions & 27 deletions src/transport/quinn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pub struct QuinnServerEndpoint<In: RpcMessage, Out: RpcMessage> {
impl<In: RpcMessage, Out: RpcMessage> QuinnServerEndpoint<In, Out> {
/// handles RPC requests from a connection
///
/// to cleanly shutdown the handler, drop the receiver side of the sender.
/// to cleanly shutdown tee handler, drop the receiver side of the sender.
async fn connection_handler(connection: quinn::Connection, sender: flume::Sender<SocketInner>) {
loop {
tracing::debug!("Awaiting incoming bidi substream on existing connection...");
Expand Down Expand Up @@ -106,6 +106,7 @@ impl<In: RpcMessage, Out: RpcMessage> QuinnServerEndpoint<In, Out> {
tracing::debug!("Spawning connection handler...");
tokio::spawn(Self::connection_handler(conection, sender.clone()));
}
tracing::debug!("endpoint handler finished");
}

/// Create a new server channel, given a quinn endpoint.
Expand Down Expand Up @@ -311,8 +312,6 @@ impl<In: RpcMessage, Out: RpcMessage> QuinnConnection<In, Out> {

let mut receiver = Receiver::new(&requests);

// a pending request to open a bi-directional stream that was received with a lost
// connection
let mut pending_request: Option<
oneshot::Sender<Result<SocketInner, quinn::ConnectionError>>,
> = None;
Expand All @@ -322,19 +321,26 @@ impl<In: RpcMessage, Out: RpcMessage> QuinnConnection<In, Out> {
tokio::select! {
// wait for a new connection to be opened
conn_result = reconnect.as_mut() => {
tracing::trace!("tick: connection result");
match conn_result {
Ok(new_connection) => connection = Some(new_connection),
Ok(new_connection) => {
connection = Some(new_connection);
tracing::debug!("got new connection");
},
Err(e) => {
let connection_err = match e {
ReconnectErr::Connect(e) => {
// TODO(@divma): the type for now accepts only a
// ConnectionError, not a ConnectError. I'm mapping this now to
// some ConnectionError since before it was not even reported.
// Maybe adjust the type?
tracing::warn!("error calling connect: {}", e);
tracing::warn!(%e, "error calling connect");
quinn::ConnectionError::Reset
},
ReconnectErr::Connection(e) => e,
ReconnectErr::Connection(e) => {
tracing::warn!(%e, "failed to connect");
e
},
};
if let Some(request) = pending_request.take() {
if request.send(Err(connection_err)).is_err() {
Expand All @@ -346,8 +352,10 @@ impl<In: RpcMessage, Out: RpcMessage> QuinnConnection<In, Out> {
}
// wait for a new request as long as there is no pending one
req = receiver.next(), if pending_request.is_none() => {
tracing::trace!("tick: bidi request");
match req {
Some(request) => {
tracing::debug!("got new bidi request");
pending_request = Some(request)
},
None => {
Expand All @@ -361,6 +369,11 @@ impl<In: RpcMessage, Out: RpcMessage> QuinnConnection<In, Out> {
}
}

tracing::trace!(
"connection is some {}; request is some {}",
connection.is_some(),
pending_request.is_some()
);
if let Some(connection) = connection.as_mut() {
if let Some(request) = pending_request.take() {
match connection.open_bi().await {
Expand Down Expand Up @@ -473,32 +486,40 @@ impl Future for ReconnectHandler {

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.state.poison() {
ConnectionState::NotConnected => match self.endpoint.connect(self.addr, &self.name) {
Ok(connecting) => {
self.state = ConnectionState::Connecting(connecting);
self.poll(cx)
}
Err(e) => {
self.state = ConnectionState::NotConnected;
Poll::Ready(Err(ReconnectErr::Connect(e)))
}
},
ConnectionState::Connecting(mut connecting) => match connecting.poll_unpin(cx) {
Poll::Ready(res) => match res {
Ok(connection) => {
self.state = ConnectionState::Connected;
Poll::Ready(Ok(connection))
ConnectionState::NotConnected => {
tracing::debug!(addr = %self.addr, name = self.name, "calling connect");

match self.endpoint.connect(self.addr, &self.name) {
Ok(connecting) => {
self.state = ConnectionState::Connecting(connecting);
self.poll(cx)
}
Err(e) => {
self.state = ConnectionState::NotConnected;
Poll::Ready(Err(ReconnectErr::Connection(e)))
Poll::Ready(Err(ReconnectErr::Connect(e)))
}
},
Poll::Pending => {
self.state = ConnectionState::Connecting(connecting);
Poll::Pending
}
},
}
ConnectionState::Connecting(mut connecting) => {
tracing::debug!(addr = %self.addr, name = self.name, "awaiting connect");

match connecting.poll_unpin(cx) {
Poll::Ready(res) => match res {
Ok(connection) => {
self.state = ConnectionState::Connected;
Poll::Ready(Ok(connection))
}
Err(e) => {
self.state = ConnectionState::NotConnected;
Poll::Ready(Err(ReconnectErr::Connection(e)))
}
},
Poll::Pending => {
self.state = ConnectionState::Connecting(connecting);
Poll::Pending
}
}
}
ConnectionState::Connected => {
// waiting for a request to open a new connection, nothing to do
self.state = ConnectionState::Connected;
Expand Down
2 changes: 2 additions & 0 deletions tests/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,9 @@ impl ComputeService {
let service = ComputeService;
while received < count {
received += 1;
tracing::debug!("before accept");
let (req, chan) = s.accept().await?;
tracing::debug!("after accept");
let service = service.clone();
tokio::spawn(async move {
use ComputeRequest::*;
Expand Down
8 changes: 7 additions & 1 deletion tests/quinn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ async fn server_away_and_back() -> anyhow::Result<()> {

// send a request. No server available so it should fail
let e = client.rpc(Sqr(4)).await.unwrap_err();
println!("{e}");
tracing::info!(%e, "got expected request failure");

// create the RPC Server
let server = Endpoint::server(server_config.clone(), server_addr)?;
Expand All @@ -164,21 +164,27 @@ async fn server_away_and_back() -> anyhow::Result<()> {

// send the first request and wait for the response to ensure everything works as expected
let SqrResponse(response) = client.rpc(Sqr(4)).await.unwrap();
tracing::info!(%response, "got expected response");
assert_eq!(response, 16);

let server = server_handle.await.unwrap().unwrap();
drop(server);
// wait for drop to free the socket
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;

tracing::info!("SERVER DROPPED");

// make the server run again
let server = Endpoint::server(server_config, server_addr)?;
let connection = transport::quinn::QuinnServerEndpoint::new(server)?;
let server = RpcServer::<ComputeService, _>::new(connection);
let server_handle = tokio::task::spawn(ComputeService::server_bounded(server, 5));
tracing::info!("Server spawned");

// server is running, this should work
tracing::info!("sending Sqr(3)");
let SqrResponse(response) = client.rpc(Sqr(3)).await.unwrap();
tracing::info!(%response, "got expected response");
assert_eq!(response, 9);

server_handle.abort();
Expand Down

0 comments on commit 66e4350

Please sign in to comment.