diff --git a/src/transport/quinn.rs b/src/transport/quinn.rs
index 82a356f..77ed943 100644
--- a/src/transport/quinn.rs
+++ b/src/transport/quinn.rs
@@ -301,61 +301,84 @@ impl QuinnConnection {
         name: String,
         requests: flume::Receiver>>,
     ) {
+        let reconnect = ReconnectHandler {
+            endpoint,
+            state: ConnectionState::NotConnected,
+            addr,
+            name,
+        };
+        futures::pin_mut!(reconnect);
+
+        let mut receiver = Receiver::new(&requests);
+
         // a pending request to open a bi-directional stream that was received with a lost
         // connection
-        let mut pending_request = None;
-        'outer: loop {
-            tracing::debug!("Connecting to {} as {}", addr, name);
-            let connecting = match endpoint.connect(addr, &name) {
-                Ok(connecting) => connecting,
-                Err(e) => {
-                    tracing::warn!("error calling connect: {}", e);
-                    // could not connect, if a request is pending, drop it.
-                    pending_request = None;
-                    // try again. Maybe delay?
-                    continue;
-                }
-            };
-            let connection = match connecting.await {
-                Ok(connection) => connection,
-                Err(e) => {
-                    tracing::warn!("error awaiting connect: {}", e);
-                    // could not connect, if a request is pending, drop it.
-                    pending_request = None;
-                    // try again. Maybe delay?
-                    continue;
+        let mut pending_request: Option<
+            oneshot::Sender<Result<(quinn::SendStream, quinn::RecvStream), quinn::ConnectionError>>,
+        > = None;
+        let mut connection = None;
+
+        loop {
+            tokio::select! {
+                // wait for a new connection to be opened
+                conn_result = reconnect.as_mut() => {
+                    match conn_result {
+                        Ok(new_connection) => connection = Some(new_connection),
+                        Err(e) => {
+                            let connection_err = match e {
+                                ReconnectErr::Connect(e) => {
+                                    // TODO(@divma): the type for now accepts only a
+                                    // ConnectionError, not a ConnectError. I'm mapping this now to
+                                    // some ConnectionError since before it was not even reported.
+                                    // Maybe adjust the type?
+                                    tracing::warn!("error calling connect: {}", e);
+                                    quinn::ConnectionError::Reset
+                                },
+                                ReconnectErr::Connection(e) => e,
+                            };
+                            if let Some(request) = pending_request.take() {
+                                if request.send(Err(connection_err)).is_err() {
+                                    tracing::debug!("requester dropped");
+                                }
+                            }
+                        },
+                    }
                 }
-            };
-            loop {
-                // first handle the pending request, then check for new requests
-                let request = match pending_request.take() {
-                    Some(request) => request,
-                    None => {
-                        tracing::debug!("Awaiting request for new bidi substream...");
-                        match requests.recv_async().await {
-                            Ok(request) => request,
-                            Err(_) => {
-                                tracing::debug!("client dropped");
+                // wait for a new request as long as there is no pending one
+                req = receiver.next(), if pending_request.is_none() => {
+                    match req {
+                        Some(request) => {
+                            pending_request = Some(request)
+                        },
+                        None => {
+                            tracing::debug!("client dropped");
+                            if let Some(connection) = connection {
                                 connection.close(0u32.into(), b"requester dropped");
-                                break;
                             }
+                            break;
                         }
                     }
-                };
-                tracing::debug!("Got request for new bidi substream");
-                match connection.open_bi().await {
-                    Ok(pair) => {
-                        tracing::debug!("Bidi substream opened");
-                        if request.send(Ok(pair)).is_err() {
-                            tracing::debug!("requester dropped");
+                }
+            }
+
+            if let Some(connection) = connection.as_mut() {
+                if let Some(request) = pending_request.take() {
+                    match connection.open_bi().await {
+                        Ok(pair) => {
+                            tracing::debug!("Bidi substream opened");
+                            if request.send(Ok(pair)).is_err() {
+                                tracing::debug!("requester dropped");
+                            }
+                        }
+                        Err(e) => {
+                            tracing::warn!("error opening bidi substream: {}", e);
+                            tracing::warn!("recreating connection");
+                            // NOTE: the connection might be stale, so we recreate the
+                            // connection and set the request as pending instead of
+                            // sending the error as a response
+                            reconnect.set_not_connected();
+                            pending_request = Some(request);
                         }
-                    }
-                    Err(e) => {
-                        tracing::warn!("error opening bidi substream: {}", e);
-                        tracing::warn!("recreating connection");
-                        pending_request = Some(request);
-                        // try again. Maybe delay?
-                        continue 'outer;
                     }
                 }
             }
@@ -406,6 +429,153 @@ impl QuinnConnection {
     }
 }
 
+/// Future that (re)establishes the connection to `addr`.
+///
+/// Every connection attempt resolves with either the new connection or a [`ReconnectErr`].
+/// While connected it stays pending until [`ReconnectHandler::set_not_connected`] is called.
+struct ReconnectHandler {
+    endpoint: quinn::Endpoint,
+    state: ConnectionState,
+    addr: SocketAddr,
+    name: String,
+}
+
+impl ReconnectHandler {
+    /// Marks the connection as lost so that the next poll starts a new connection attempt.
+    pub fn set_not_connected(&mut self) {
+        self.state.set_not_connected()
+    }
+}
+
+/// State of a [`ReconnectHandler`].
+enum ConnectionState {
+    /// There is no active connection. An attempt to connect will be made.
+    NotConnected,
+    /// Connecting to the remote.
+    Connecting(quinn::Connecting),
+    /// A connection is already established. In this state, no more connection attempts are made.
+    Connected,
+    /// Intermediate state while processing.
+    Poisoned,
+}
+
+impl ConnectionState {
+    /// Takes the current state, leaving [`ConnectionState::Poisoned`] in its place.
+    pub fn poison(&mut self) -> ConnectionState {
+        std::mem::replace(self, ConnectionState::Poisoned)
+    }
+
+    /// Resets the state to [`ConnectionState::NotConnected`].
+    pub fn set_not_connected(&mut self) {
+        *self = ConnectionState::NotConnected
+    }
+}
+
+/// Reasons a connection attempt can fail.
+enum ReconnectErr {
+    /// The call to connect on the endpoint returned an error.
+    Connect(quinn::ConnectError),
+    /// Awaiting the connection attempt returned an error.
+    Connection(quinn::ConnectionError),
+}
+
+impl Future for ReconnectHandler {
+    type Output = Result<quinn::Connection, ReconnectErr>;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        match self.state.poison() {
+            ConnectionState::NotConnected => match self.endpoint.connect(self.addr, &self.name) {
+                Ok(connecting) => {
+                    self.state = ConnectionState::Connecting(connecting);
+                    self.poll(cx)
+                }
+                Err(e) => {
+                    self.state = ConnectionState::NotConnected;
+                    Poll::Ready(Err(ReconnectErr::Connect(e)))
+                }
+            },
+            ConnectionState::Connecting(mut connecting) => match connecting.poll_unpin(cx) {
+                Poll::Ready(res) => match res {
+                    Ok(connection) => {
+                        self.state = ConnectionState::Connected;
+                        Poll::Ready(Ok(connection))
+                    }
+                    Err(e) => {
+                        self.state = ConnectionState::NotConnected;
+                        Poll::Ready(Err(ReconnectErr::Connection(e)))
+                    }
+                },
+                Poll::Pending => {
+                    self.state = ConnectionState::Connecting(connecting);
+                    Poll::Pending
+                }
+            },
+            ConnectionState::Connected => {
+                // waiting for a request to open a new connection, nothing to do
+                self.state = ConnectionState::Connected;
+                Poll::Pending
+            }
+            ConnectionState::Poisoned => unreachable!("poisoned connection state"),
+        }
+    }
+}
+
+/// Wrapper over [`flume::Receiver`] that can be used with [`tokio::select`].
+///
+/// NOTE: from https://github.com/zesterer/flume/issues/104:
+/// > If RecvFut is dropped without being polled, the item is never received.
+enum Receiver<'a, T>
+where
+    Self: 'a,
+{
+    PreReceive(&'a flume::Receiver<T>),
+    Receiving(&'a flume::Receiver<T>, flume::r#async::RecvFut<'a, T>),
+    Poisoned,
+}
+
+impl<'a, T> Receiver<'a, T> {
+    /// Wraps `recv`; no receive is started until the stream is first polled.
+    fn new(recv: &'a flume::Receiver<T>) -> Self {
+        Receiver::PreReceive(recv)
+    }
+
+    /// Takes the current state, leaving [`Receiver::Poisoned`] in its place.
+    fn poison(&mut self) -> Self {
+        std::mem::replace(self, Self::Poisoned)
+    }
+}
+
+impl<'a, T> futures::stream::Stream for Receiver<'a, T> {
+    type Item = T;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        match self.poison() {
+            Receiver::PreReceive(recv) => {
+                let fut = recv.recv_async();
+                *self = Receiver::Receiving(recv, fut);
+                self.poll_next(cx)
+            }
+            Receiver::Receiving(recv, mut fut) => match fut.poll_unpin(cx) {
+                Poll::Ready(Ok(t)) => {
+                    *self = Receiver::PreReceive(recv);
+                    Poll::Ready(Some(t))
+                }
+                Poll::Ready(Err(flume::RecvError::Disconnected)) => {
+                    *self = Receiver::PreReceive(recv);
+                    Poll::Ready(None)
+                }
+                Poll::Pending => {
+                    // keep the in-flight future so the next poll resumes it instead of dropping
+                    // it and starting a new receive on every poll
+                    *self = Receiver::Receiving(recv, fut);
+                    Poll::Pending
+                }
+            },
+            Receiver::Poisoned => unreachable!("poisoned receiver state"),
+        }
+    }
+}
+
 impl fmt::Debug for QuinnConnection {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         f.debug_struct("ClientChannel")
diff --git a/tests/quinn.rs b/tests/quinn.rs
index 1ecf116..6a6068d 100644
--- a/tests/quinn.rs
+++ b/tests/quinn.rs
@@ -145,19 +145,23 @@ async fn server_away_and_back() -> anyhow::Result<()> {
 
     let server_addr: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 12347));
     let (server_config, server_cert) = configure_server()?;
-    let server = Endpoint::server(server_config.clone(), server_addr)?;
+
+    // create the RPC client
     let client = make_client_endpoint("0.0.0.0:0".parse()?, &[&server_cert])?;
+    let client_connection =
+        transport::quinn::QuinnConnection::new(client, server_addr, "localhost".into());
+    let client = RpcClient::new(client_connection);
+
+    // send a request. No server available so it should fail
+    let e = client.rpc(Sqr(4)).await.unwrap_err();
+    println!("{e}");
 
     // create the RPC Server
+    let server = Endpoint::server(server_config.clone(), server_addr)?;
     let connection = transport::quinn::QuinnServerEndpoint::new(server)?;
     let server = RpcServer::::new(connection);
     let server_handle = tokio::task::spawn(ComputeService::server_bounded(server, 1));
 
-    // create the RPC client
-    let client_connection =
-        transport::quinn::QuinnConnection::new(client, server_addr, "localhost".into());
-    let client = RpcClient::new(client_connection);
-
     // send the first request and wait for the response to ensure everything works as expected
     let SqrResponse(response) = client.rpc(Sqr(4)).await.unwrap();
     assert_eq!(response, 16);