diff --git a/src/transport/quinn.rs b/src/transport/quinn.rs index 78d91b5..101739b 100644 --- a/src/transport/quinn.rs +++ b/src/transport/quinn.rs @@ -301,47 +301,103 @@ impl QuinnConnection { name: String, requests: flume::Receiver>>, ) { - 'outer: loop { - tracing::debug!("Connecting to {} as {}", addr, name); - let connecting = match endpoint.connect(addr, &name) { - Ok(connecting) => connecting, - Err(e) => { - tracing::warn!("error calling connect: {}", e); - // try again. Maybe delay? - continue; + let reconnect = ReconnectHandler { + endpoint, + state: ConnectionState::NotConnected, + addr, + name, + }; + futures::pin_mut!(reconnect); + + let mut receiver = Receiver::new(&requests); + + let mut pending_request: Option< + oneshot::Sender>, + > = None; + let mut connection = None; + + loop { + let mut conn_result = None; + let mut chann_result = None; + if !reconnect.connected() && pending_request.is_none() { + match futures::future::select(reconnect.as_mut(), receiver.next()).await { + futures::future::Either::Left((connection_result, _)) => { + conn_result = Some(connection_result) + } + futures::future::Either::Right((channel_result, _)) => { + chann_result = Some(channel_result); + } } - }; - let connection = match connecting.await { - Ok(connection) => connection, - Err(e) => { - tracing::warn!("error awaiting connect: {}", e); - // try again. Maybe delay? - continue; + } else if !reconnect.connected() { + // only need a new connection + conn_result = Some(reconnect.as_mut().await); + } else if pending_request.is_none() { + // there is a connection, just need a request + chann_result = Some(receiver.next().await); + } + + if let Some(conn_result) = conn_result { + tracing::trace!("tick: connection result"); + match conn_result { + Ok(new_connection) => { + connection = Some(new_connection); + } + Err(e) => { + let connection_err = match e { + ReconnectErr::Connect(e) => { + // TODO(@divma): the type for now accepts only a + // ConnectionError, not a ConnectError. I'm mapping this now to + // some ConnectionError since before it was not even reported. + // Maybe adjust the type? + tracing::warn!(%e, "error calling connect"); + quinn::ConnectionError::Reset + } + ReconnectErr::Connection(e) => { + tracing::warn!(%e, "failed to connect"); + e + } + }; + if let Some(request) = pending_request.take() { + if request.send(Err(connection_err)).is_err() { + tracing::debug!("requester dropped"); + } + } + } } - }; - loop { - tracing::debug!("Awaiting request for new bidi substream..."); - let request = match requests.recv_async().await { - Ok(request) => request, - Err(_) => { + } + + if let Some(req) = chann_result { + tracing::trace!("tick: bidi request"); + match req { + Some(request) => pending_request = Some(request), + None => { tracing::debug!("client dropped"); - connection.close(0u32.into(), b"requester dropped"); + if let Some(connection) = connection { + connection.close(0u32.into(), b"requester dropped"); + } break; } - }; - tracing::debug!("Got request for new bidi substream"); - match connection.open_bi().await { - Ok(pair) => { - tracing::debug!("Bidi substream opened"); - if request.send(Ok(pair)).is_err() { - tracing::debug!("requester dropped"); + } + } + + if let Some(connection) = connection.as_mut() { + if let Some(request) = pending_request.take() { + match connection.open_bi().await { + Ok(pair) => { + tracing::debug!("Bidi substream opened"); + if request.send(Ok(pair)).is_err() { + tracing::debug!("requester dropped"); + } + } + Err(e) => { + tracing::warn!("error opening bidi substream: {}", e); + tracing::warn!("recreating connection"); + // NOTE: the connection might be stale, so we recreate the + // connection and set the request as pending instead of + // sending the error as a response + reconnect.set_not_connected(); + pending_request = Some(request); } - } - Err(e) => { - tracing::warn!("error opening bidi substream: {}", e); - tracing::warn!("recreating connection"); - // try again. Maybe delay? - continue 'outer; } } } @@ -392,6 +448,141 @@ impl QuinnConnection { } } +struct ReconnectHandler { + endpoint: quinn::Endpoint, + state: ConnectionState, + addr: SocketAddr, + name: String, +} + +impl ReconnectHandler { + pub fn set_not_connected(&mut self) { + self.state.set_not_connected() + } + + pub fn connected(&self) -> bool { + matches!(self.state, ConnectionState::Connected(_)) + } +} + +enum ConnectionState { + /// There is no active connection. An attempt to connect will be made. + NotConnected, + /// Connecting to the remote. + Connecting(quinn::Connecting), + /// A connection is already established. In this state, no more connection attempts are made. + Connected(quinn::Connection), + /// Intermediate state while processing. + Poisoned, +} + +impl ConnectionState { + pub fn poison(&mut self) -> ConnectionState { + std::mem::replace(self, ConnectionState::Poisoned) + } + + pub fn set_not_connected(&mut self) { + *self = ConnectionState::NotConnected + } +} + +enum ReconnectErr { + Connect(quinn::ConnectError), + Connection(quinn::ConnectionError), +} + +impl Future for ReconnectHandler { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.state.poison() { + ConnectionState::NotConnected => match self.endpoint.connect(self.addr, &self.name) { + Ok(connecting) => { + self.state = ConnectionState::Connecting(connecting); + self.poll(cx) + } + Err(e) => { + self.state = ConnectionState::NotConnected; + Poll::Ready(Err(ReconnectErr::Connect(e))) + } + }, + ConnectionState::Connecting(mut connecting) => match connecting.poll_unpin(cx) { + Poll::Ready(res) => match res { + Ok(connection) => { + self.state = ConnectionState::Connected(connection.clone()); + Poll::Ready(Ok(connection)) + } + Err(e) => { + self.state = ConnectionState::NotConnected; + Poll::Ready(Err(ReconnectErr::Connection(e))) + } + }, + Poll::Pending => { + self.state = ConnectionState::Connecting(connecting); + Poll::Pending + } + }, + ConnectionState::Connected(connection) => { + self.state = ConnectionState::Connected(connection.clone()); + Poll::Ready(Ok(connection)) + } + ConnectionState::Poisoned => unreachable!("poisoned connection state"), + } + } +} + +/// Wrapper over [`flume::Receiver`] that can be used with [`tokio::select`]. +/// +/// NOTE: from https://github.com/zesterer/flume/issues/104: +/// > If RecvFut is dropped without being polled, the item is never received. +enum Receiver<'a, T> +where + Self: 'a, +{ + PreReceive(&'a flume::Receiver), + Receiving(&'a flume::Receiver, flume::r#async::RecvFut<'a, T>), + Poisoned, +} + +impl<'a, T> Receiver<'a, T> { + fn new(recv: &'a flume::Receiver) -> Self { + Receiver::PreReceive(recv) + } + + fn poison(&mut self) -> Self { + std::mem::replace(self, Self::Poisoned) + } +} + +impl<'a, T> futures::stream::Stream for Receiver<'a, T> { + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.poison() { + Receiver::PreReceive(recv) => { + let fut = recv.recv_async(); + *self = Receiver::Receiving(recv, fut); + self.poll_next(cx) + } + Receiver::Receiving(recv, mut fut) => match fut.poll_unpin(cx) { + Poll::Ready(Ok(t)) => { + *self = Receiver::PreReceive(recv); + Poll::Ready(Some(t)) + } + Poll::Ready(Err(flume::RecvError::Disconnected)) => { + *self = Receiver::PreReceive(recv); + Poll::Ready(None) + } + Poll::Pending => { + *self = Receiver::Receiving(recv, fut); + Poll::Pending + } + }, + Receiver::Poisoned => unreachable!("poisoned receiver state"), + } + } +} + impl fmt::Debug for QuinnConnection { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ClientChannel") diff --git a/tests/math.rs b/tests/math.rs index ab44b75..583b7ca 100644 --- a/tests/math.rs +++ b/tests/math.rs @@ -153,6 +153,38 @@ impl ComputeService { } } + /// Runs the service until `count` requests have been received. + pub async fn server_bounded>( + server: RpcServer, + count: usize, + ) -> result::Result, RpcServerError> { + tracing::info!(%count, "server running"); + let s = server; + let mut received = 0; + let service = ComputeService; + while received < count { + received += 1; + let (req, chan) = s.accept().await?; + let service = service.clone(); + tokio::spawn(async move { + use ComputeRequest::*; + tracing::info!(?req, "got request"); + #[rustfmt::skip] + match req { + Sqr(msg) => chan.rpc(msg, service, ComputeService::sqr).await, + Sum(msg) => chan.client_streaming(msg, service, ComputeService::sum).await, + Fibonacci(msg) => chan.server_streaming(msg, service, ComputeService::fibonacci).await, + Multiply(msg) => chan.bidi_streaming(msg, service, ComputeService::multiply).await, + SumUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?, + MultiplyUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?, + }?; + Ok::<_, RpcServerError>(()) + }); + } + tracing::info!(%count, "server finished"); + Ok(s) + } + pub async fn server_par>( server: RpcServer, parallelism: usize, diff --git a/tests/quinn.rs b/tests/quinn.rs index a5a0126..de5fc1f 100644 --- a/tests/quinn.rs +++ b/tests/quinn.rs @@ -4,7 +4,7 @@ use std::{ sync::Arc, }; -use quic_rpc::{RpcClient, RpcServer}; +use quic_rpc::{transport, RpcClient, RpcServer}; use quinn::{ClientConfig, Endpoint, ServerConfig}; use tokio::task::JoinHandle; @@ -91,7 +91,7 @@ pub fn make_endpoints(port: u16) -> anyhow::Result { fn run_server(server: quinn::Endpoint) -> JoinHandle> { tokio::task::spawn(async move { - let connection = quic_rpc::transport::quinn::QuinnServerEndpoint::new(server)?; + let connection = transport::quinn::QuinnServerEndpoint::new(server)?; let server = RpcServer::::new(connection); ComputeService::server(server).await?; anyhow::Ok(()) @@ -110,8 +110,7 @@ async fn quinn_channel_bench() -> anyhow::Result<()> { tracing::debug!("Starting server"); let server_handle = run_server(server); tracing::debug!("Starting client"); - let client = - quic_rpc::transport::quinn::QuinnConnection::new(client, server_addr, "localhost".into()); + let client = transport::quinn::QuinnConnection::new(client, server_addr, "localhost".into()); let client = RpcClient::::new(client); tracing::debug!("Starting benchmark"); bench(client, 50000).await?; @@ -129,8 +128,58 @@ async fn quinn_channel_smoke() -> anyhow::Result<()> { } = make_endpoints(12346)?; let server_handle = run_server(server); let client_connection = - quic_rpc::transport::quinn::QuinnConnection::new(client, server_addr, "localhost".into()); + transport::quinn::QuinnConnection::new(client, server_addr, "localhost".into()); smoke_test(client_connection).await?; server_handle.abort(); Ok(()) } + +/// Test that using the client after the server goes away and comes back behaves as if the server +/// had never gone away in the first place. +/// +/// This is a regression test. +#[tokio::test] +async fn server_away_and_back() -> anyhow::Result<()> { + tracing_subscriber::fmt::try_init().ok(); + tracing::info!("Creating endpoints"); + + let server_addr: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 12347)); + let (server_config, server_cert) = configure_server()?; + + // create the RPC client + let client = make_client_endpoint("0.0.0.0:0".parse()?, &[&server_cert])?; + let client_connection = + transport::quinn::QuinnConnection::new(client, server_addr, "localhost".into()); + let client = RpcClient::new(client_connection); + + // send a request. No server available so it should fail + client.rpc(Sqr(4)).await.unwrap_err(); + + // create the RPC Server + let server = Endpoint::server(server_config.clone(), server_addr)?; + let connection = transport::quinn::QuinnServerEndpoint::new(server)?; + let server = RpcServer::::new(connection); + let server_handle = tokio::task::spawn(ComputeService::server_bounded(server, 1)); + + // send the first request and wait for the response to ensure everything works as expected + let SqrResponse(response) = client.rpc(Sqr(4)).await.unwrap(); + assert_eq!(response, 16); + + let server = server_handle.await.unwrap().unwrap(); + drop(server); + // wait for drop to free the socket + tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; + + // make the server run again + let server = Endpoint::server(server_config, server_addr)?; + let connection = transport::quinn::QuinnServerEndpoint::new(server)?; + let server = RpcServer::::new(connection); + let server_handle = tokio::task::spawn(ComputeService::server_bounded(server, 5)); + + // server is running, this should work + let SqrResponse(response) = client.rpc(Sqr(3)).await.unwrap(); + assert_eq!(response, 9); + + server_handle.abort(); + Ok(()) +}