Skip to content

Commit

Permalink
poll for new connections and new bidi requests concurrently
Browse files Browse the repository at this point in the history
  • Loading branch information
divagant-martian committed Feb 15, 2024
1 parent adf2ec9 commit b06e613
Show file tree
Hide file tree
Showing 2 changed files with 212 additions and 53 deletions.
249 changes: 202 additions & 47 deletions src/transport/quinn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,61 +301,84 @@ impl<In: RpcMessage, Out: RpcMessage> QuinnConnection<In, Out> {
name: String,
requests: flume::Receiver<oneshot::Sender<Result<SocketInner, quinn::ConnectionError>>>,
) {
let reconnect = ReconnectHandler {
endpoint,
state: ConnectionState::NotConnected,
addr,
name,
};
futures::pin_mut!(reconnect);

let mut receiver = Receiver::new(&requests);

// a pending request to open a bi-directional stream that was received with a lost
// connection
let mut pending_request = None;
'outer: loop {
tracing::debug!("Connecting to {} as {}", addr, name);
let connecting = match endpoint.connect(addr, &name) {
Ok(connecting) => connecting,
Err(e) => {
tracing::warn!("error calling connect: {}", e);
// could not connect, if a request is pending, drop it.
pending_request = None;
// try again. Maybe delay?
continue;
}
};
let connection = match connecting.await {
Ok(connection) => connection,
Err(e) => {
tracing::warn!("error awaiting connect: {}", e);
// could not connect, if a request is pending, drop it.
pending_request = None;
// try again. Maybe delay?
continue;
let mut pending_request: Option<
oneshot::Sender<Result<SocketInner, quinn::ConnectionError>>,
> = None;
let mut connection = None;

loop {
tokio::select! {
// wait for a new connection to be opened
conn_result = reconnect.as_mut() => {
match conn_result {
Ok(new_connection) => connection = Some(new_connection),
Err(e) => {
let connection_err = match e {
ReconnectErr::Connect(e) => {
// TODO(@divma): the type for now accepts only a
// ConnectionError, not a ConnectError. I'm mapping this now to
// some ConnectionError since before it was not even reported.
// Maybe adjust the type?
tracing::warn!("error calling connect: {}", e);
quinn::ConnectionError::Reset
},
ReconnectErr::Connection(e) => e,
};
if let Some(request) = pending_request.take() {
if request.send(Err(connection_err)).is_err() {
tracing::debug!("requester dropped");
}
}
},
}
}
};
loop {
// first handle the pending request, then check for new requests
let request = match pending_request.take() {
Some(request) => request,
None => {
tracing::debug!("Awaiting request for new bidi substream...");
match requests.recv_async().await {
Ok(request) => request,
Err(_) => {
tracing::debug!("client dropped");
// wait for a new request as long as there is no pending one
req = receiver.next(), if pending_request.is_none() => {
match req {
Some(request) => {
pending_request = Some(request)
},
None => {
tracing::debug!("client dropped");
if let Some(connection) = connection {
connection.close(0u32.into(), b"requester dropped");
break;
}
break;
}
}
};
tracing::debug!("Got request for new bidi substream");
match connection.open_bi().await {
Ok(pair) => {
tracing::debug!("Bidi substream opened");
if request.send(Ok(pair)).is_err() {
tracing::debug!("requester dropped");
}
}

if let Some(connection) = connection.as_mut() {
if let Some(request) = pending_request.take() {
match connection.open_bi().await {
Ok(pair) => {
tracing::debug!("Bidi substream opened");
if request.send(Ok(pair)).is_err() {
tracing::debug!("requester dropped");
}
}
Err(e) => {
tracing::warn!("error opening bidi substream: {}", e);
tracing::warn!("recreating connection");
// NOTE: the connection might be stale, so we recreate the
// connection and set the request as pending instead of
// sending the error as a response
reconnect.set_not_connected();
pending_request = Some(request);
}
}
Err(e) => {
tracing::warn!("error opening bidi substream: {}", e);
tracing::warn!("recreating connection");
pending_request = Some(request);
// try again. Maybe delay?
continue 'outer;
}
}
}
Expand Down Expand Up @@ -406,6 +429,138 @@ impl<In: RpcMessage, Out: RpcMessage> QuinnConnection<In, Out> {
}
}

/// Future that drives connection attempts towards a single remote.
///
/// Each poll advances `state`; the `Future` implementation resolves with the
/// established connection or with the error of the failed attempt.
struct ReconnectHandler {
/// Endpoint on which `connect` is called to start a connection attempt.
endpoint: quinn::Endpoint,
/// Current stage of the connection attempt.
state: ConnectionState,
/// Remote address handed to `connect`.
addr: SocketAddr,
/// Name handed to `connect` together with `addr`.
name: String,
}

impl ReconnectHandler {
pub fn set_not_connected(&mut self) {
self.state.set_not_connected()
}
}

/// Stage of the connection attempt tracked by `ReconnectHandler`.
enum ConnectionState {
/// There is no active connection. An attempt to connect will be made.
NotConnected,
/// Connecting to the remote.
Connecting(quinn::Connecting),
/// A connection is already established. In this state, no more connection attempts are made.
Connected,
/// Intermediate state while processing. It is the placeholder left behind by `poison` and
/// is expected to be overwritten before the next poll; polling in this state hits
/// `unreachable!`.
Poisoned,
}

impl ConnectionState {
pub fn poison(&mut self) -> ConnectionState {
std::mem::replace(self, ConnectionState::Poisoned)
}

pub fn set_not_connected(&mut self) {
*self = ConnectionState::NotConnected
}
}

/// Failure modes of a single connection attempt made by `ReconnectHandler`.
enum ReconnectErr {
/// Calling `connect` on the endpoint failed, so no attempt was started.
Connect(quinn::ConnectError),
/// The connection attempt was started but awaiting it resolved with an error.
Connection(quinn::ConnectionError),
}

impl Future for ReconnectHandler {
    type Output = Result<quinn::Connection, ReconnectErr>;

    /// Advances the connection state machine.
    ///
    /// * `NotConnected`: calls `connect` on the endpoint. On success the state becomes
    ///   `Connecting` and is polled right away; on failure the state stays `NotConnected`
    ///   and the error is returned as [`ReconnectErr::Connect`].
    /// * `Connecting`: polls the in-flight attempt. Success moves to `Connected` and yields
    ///   the connection; failure moves back to `NotConnected` and yields
    ///   [`ReconnectErr::Connection`].
    /// * `Connected`: always returns `Poll::Pending`. No waker is registered in this branch.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            // The state is taken out by value so that the `Connecting` future can be polled
            // and then either stored back or replaced.
            let current = self.state.poison();
            match current {
                ConnectionState::NotConnected => {
                    let attempt = self.endpoint.connect(self.addr, &self.name);
                    match attempt {
                        Ok(connecting) => {
                            // Go around again to poll the freshly created attempt.
                            self.state = ConnectionState::Connecting(connecting);
                        }
                        Err(e) => {
                            self.state = ConnectionState::NotConnected;
                            return Poll::Ready(Err(ReconnectErr::Connect(e)));
                        }
                    }
                }
                ConnectionState::Connecting(mut connecting) => {
                    return match connecting.poll_unpin(cx) {
                        Poll::Pending => {
                            // Keep the in-flight attempt for the next poll.
                            self.state = ConnectionState::Connecting(connecting);
                            Poll::Pending
                        }
                        Poll::Ready(Ok(connection)) => {
                            self.state = ConnectionState::Connected;
                            Poll::Ready(Ok(connection))
                        }
                        Poll::Ready(Err(e)) => {
                            self.state = ConnectionState::NotConnected;
                            Poll::Ready(Err(ReconnectErr::Connection(e)))
                        }
                    };
                }
                ConnectionState::Connected => {
                    // waiting for a request to open a new connection, nothing to do
                    self.state = ConnectionState::Connected;
                    return Poll::Pending;
                }
                ConnectionState::Poisoned => unreachable!("poisoned connection state"),
            }
        }
    }
}

/// Wrapper over [`flume::Receiver`] that can be used with [`tokio::select`].
///
/// NOTE: from https://github.com/zesterer/flume/issues/104:
/// > If RecvFut is dropped without being polled, the item is never received.
///
/// The wrapper is a small state machine driven by its `Stream` implementation.
enum Receiver<'a, T>
where
Self: 'a,
{
/// No receive future is currently stored; only the borrowed channel receiver is held.
PreReceive(&'a flume::Receiver<T>),
/// A future created with `recv_async` is stored next to the receiver it came from.
Receiving(&'a flume::Receiver<T>, flume::r#async::RecvFut<'a, T>),
/// Placeholder left behind by `poison` while the previous state is being handled.
Poisoned,
}

impl<'a, T> Receiver<'a, T> {
    /// Wraps `recv`, starting in the state where no receive future exists yet.
    fn new(recv: &'a flume::Receiver<T>) -> Self {
        Self::PreReceive(recv)
    }

    /// Moves the current state out to the caller, leaving `Poisoned` in its place.
    fn poison(&mut self) -> Self {
        let mut taken = Self::Poisoned;
        std::mem::swap(self, &mut taken);
        taken
    }
}

impl<'a, T> futures::stream::Stream for Receiver<'a, T> {
    type Item = T;

    /// Polls the wrapped channel for the next item.
    ///
    /// Yields `Some(item)` for every received item and `None` once the channel reports
    /// `Disconnected`. A receive future is created on demand and, while it is pending, it is
    /// kept inside `self` so that the same future is polled again next time.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Take the state out by value, creating the receive future if none is in flight.
        let (recv, mut fut) = match self.poison() {
            Receiver::PreReceive(recv) => (recv, recv.recv_async()),
            Receiver::Receiving(recv, fut) => (recv, fut),
            Receiver::Poisoned => unreachable!("poisoned receiver state"),
        };

        match fut.poll_unpin(cx) {
            Poll::Ready(Ok(item)) => {
                *self = Receiver::PreReceive(recv);
                Poll::Ready(Some(item))
            }
            Poll::Ready(Err(flume::RecvError::Disconnected)) => {
                *self = Receiver::PreReceive(recv);
                Poll::Ready(None)
            }
            Poll::Pending => {
                // Keep the pending future alive. Dropping it here would remove its waker
                // from the channel, so the task would not be woken when an item arrives.
                *self = Receiver::Receiving(recv, fut);
                Poll::Pending
            }
        }
    }
}

impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for QuinnConnection<In, Out> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ClientChannel")
Expand Down
16 changes: 10 additions & 6 deletions tests/quinn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,19 +145,23 @@ async fn server_away_and_back() -> anyhow::Result<()> {

let server_addr: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 12347));
let (server_config, server_cert) = configure_server()?;
let server = Endpoint::server(server_config.clone(), server_addr)?;

// create the RPC client
let client = make_client_endpoint("0.0.0.0:0".parse()?, &[&server_cert])?;
let client_connection =
transport::quinn::QuinnConnection::new(client, server_addr, "localhost".into());
let client = RpcClient::new(client_connection);

// send a request. No server available so it should fail
let e = client.rpc(Sqr(4)).await.unwrap_err();
println!("{e}");

// create the RPC Server
let server = Endpoint::server(server_config.clone(), server_addr)?;
let connection = transport::quinn::QuinnServerEndpoint::new(server)?;
let server = RpcServer::<ComputeService, _>::new(connection);
let server_handle = tokio::task::spawn(ComputeService::server_bounded(server, 1));

// create the RPC client
let client_connection =
transport::quinn::QuinnConnection::new(client, server_addr, "localhost".into());
let client = RpcClient::new(client_connection);

// send the first request and wait for the response to ensure everything works as expected
let SqrResponse(response) = client.rpc(Sqr(4)).await.unwrap();
assert_eq!(response, 16);
Expand Down

0 comments on commit b06e613

Please sign in to comment.