Skip to content

Commit

Permalink
fix: rpc client concurrently waits for requests and new connections (#62
Browse files Browse the repository at this point in the history
)
  • Loading branch information
divagant-martian authored Feb 27, 2024
1 parent 865622e commit 3323574
Show file tree
Hide file tree
Showing 3 changed files with 312 additions and 40 deletions.
261 changes: 226 additions & 35 deletions src/transport/quinn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,47 +301,103 @@ impl<In: RpcMessage, Out: RpcMessage> QuinnConnection<In, Out> {
name: String,
requests: flume::Receiver<oneshot::Sender<Result<SocketInner, quinn::ConnectionError>>>,
) {
'outer: loop {
tracing::debug!("Connecting to {} as {}", addr, name);
let connecting = match endpoint.connect(addr, &name) {
Ok(connecting) => connecting,
Err(e) => {
tracing::warn!("error calling connect: {}", e);
// try again. Maybe delay?
continue;
let reconnect = ReconnectHandler {
endpoint,
state: ConnectionState::NotConnected,
addr,
name,
};
futures::pin_mut!(reconnect);

let mut receiver = Receiver::new(&requests);

let mut pending_request: Option<
oneshot::Sender<Result<SocketInner, quinn::ConnectionError>>,
> = None;
let mut connection = None;

loop {
let mut conn_result = None;
let mut chann_result = None;
if !reconnect.connected() && pending_request.is_none() {
match futures::future::select(reconnect.as_mut(), receiver.next()).await {
futures::future::Either::Left((connection_result, _)) => {
conn_result = Some(connection_result)
}
futures::future::Either::Right((channel_result, _)) => {
chann_result = Some(channel_result);
}
}
};
let connection = match connecting.await {
Ok(connection) => connection,
Err(e) => {
tracing::warn!("error awaiting connect: {}", e);
// try again. Maybe delay?
continue;
} else if !reconnect.connected() {
// only need a new connection
conn_result = Some(reconnect.as_mut().await);
} else if pending_request.is_none() {
// there is a connection, just need a request
chann_result = Some(receiver.next().await);
}

if let Some(conn_result) = conn_result {
tracing::trace!("tick: connection result");
match conn_result {
Ok(new_connection) => {
connection = Some(new_connection);
}
Err(e) => {
let connection_err = match e {
ReconnectErr::Connect(e) => {
// TODO(@divma): the type for now accepts only a
// ConnectionError, not a ConnectError. I'm mapping this now to
// some ConnectionError since before it was not even reported.
// Maybe adjust the type?
tracing::warn!(%e, "error calling connect");
quinn::ConnectionError::Reset
}
ReconnectErr::Connection(e) => {
tracing::warn!(%e, "failed to connect");
e
}
};
if let Some(request) = pending_request.take() {
if request.send(Err(connection_err)).is_err() {
tracing::debug!("requester dropped");
}
}
}
}
};
loop {
tracing::debug!("Awaiting request for new bidi substream...");
let request = match requests.recv_async().await {
Ok(request) => request,
Err(_) => {
}

if let Some(req) = chann_result {
tracing::trace!("tick: bidi request");
match req {
Some(request) => pending_request = Some(request),
None => {
tracing::debug!("client dropped");
connection.close(0u32.into(), b"requester dropped");
if let Some(connection) = connection {
connection.close(0u32.into(), b"requester dropped");
}
break;
}
};
tracing::debug!("Got request for new bidi substream");
match connection.open_bi().await {
Ok(pair) => {
tracing::debug!("Bidi substream opened");
if request.send(Ok(pair)).is_err() {
tracing::debug!("requester dropped");
}
}

if let Some(connection) = connection.as_mut() {
if let Some(request) = pending_request.take() {
match connection.open_bi().await {
Ok(pair) => {
tracing::debug!("Bidi substream opened");
if request.send(Ok(pair)).is_err() {
tracing::debug!("requester dropped");
}
}
Err(e) => {
tracing::warn!("error opening bidi substream: {}", e);
tracing::warn!("recreating connection");
// NOTE: the connection might be stale, so we recreate the
// connection and set the request as pending instead of
// sending the error as a response
reconnect.set_not_connected();
pending_request = Some(request);
}
}
Err(e) => {
tracing::warn!("error opening bidi substream: {}", e);
tracing::warn!("recreating connection");
// try again. Maybe delay?
continue 'outer;
}
}
}
Expand Down Expand Up @@ -392,6 +448,141 @@ impl<In: RpcMessage, Out: RpcMessage> QuinnConnection<In, Out> {
}
}

struct ReconnectHandler {
endpoint: quinn::Endpoint,
state: ConnectionState,
addr: SocketAddr,
name: String,
}

impl ReconnectHandler {
pub fn set_not_connected(&mut self) {
self.state.set_not_connected()
}

pub fn connected(&self) -> bool {
matches!(self.state, ConnectionState::Connected(_))
}
}

enum ConnectionState {
/// There is no active connection. An attempt to connect will be made.
NotConnected,
/// Connecting to the remote.
Connecting(quinn::Connecting),
/// A connection is already established. In this state, no more connection attempts are made.
Connected(quinn::Connection),
/// Intermediate state while processing.
Poisoned,
}

impl ConnectionState {
pub fn poison(&mut self) -> ConnectionState {
std::mem::replace(self, ConnectionState::Poisoned)
}

pub fn set_not_connected(&mut self) {
*self = ConnectionState::NotConnected
}
}

enum ReconnectErr {
Connect(quinn::ConnectError),
Connection(quinn::ConnectionError),
}

impl Future for ReconnectHandler {
type Output = Result<quinn::Connection, ReconnectErr>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.state.poison() {
ConnectionState::NotConnected => match self.endpoint.connect(self.addr, &self.name) {
Ok(connecting) => {
self.state = ConnectionState::Connecting(connecting);
self.poll(cx)
}
Err(e) => {
self.state = ConnectionState::NotConnected;
Poll::Ready(Err(ReconnectErr::Connect(e)))
}
},
ConnectionState::Connecting(mut connecting) => match connecting.poll_unpin(cx) {
Poll::Ready(res) => match res {
Ok(connection) => {
self.state = ConnectionState::Connected(connection.clone());
Poll::Ready(Ok(connection))
}
Err(e) => {
self.state = ConnectionState::NotConnected;
Poll::Ready(Err(ReconnectErr::Connection(e)))
}
},
Poll::Pending => {
self.state = ConnectionState::Connecting(connecting);
Poll::Pending
}
},
ConnectionState::Connected(connection) => {
self.state = ConnectionState::Connected(connection.clone());
Poll::Ready(Ok(connection))
}
ConnectionState::Poisoned => unreachable!("poisoned connection state"),
}
}
}

/// Wrapper over [`flume::Receiver`] that can be used with [`tokio::select`].
///
/// NOTE: from https://github.com/zesterer/flume/issues/104:
/// > If RecvFut is dropped without being polled, the item is never received.
enum Receiver<'a, T>
where
Self: 'a,
{
PreReceive(&'a flume::Receiver<T>),
Receiving(&'a flume::Receiver<T>, flume::r#async::RecvFut<'a, T>),
Poisoned,
}

impl<'a, T> Receiver<'a, T> {
fn new(recv: &'a flume::Receiver<T>) -> Self {
Receiver::PreReceive(recv)
}

fn poison(&mut self) -> Self {
std::mem::replace(self, Self::Poisoned)
}
}

impl<'a, T> futures::stream::Stream for Receiver<'a, T> {
type Item = T;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.poison() {
Receiver::PreReceive(recv) => {
let fut = recv.recv_async();
*self = Receiver::Receiving(recv, fut);
self.poll_next(cx)
}
Receiver::Receiving(recv, mut fut) => match fut.poll_unpin(cx) {
Poll::Ready(Ok(t)) => {
*self = Receiver::PreReceive(recv);
Poll::Ready(Some(t))
}
Poll::Ready(Err(flume::RecvError::Disconnected)) => {
*self = Receiver::PreReceive(recv);
Poll::Ready(None)
}
Poll::Pending => {
*self = Receiver::Receiving(recv, fut);
Poll::Pending
}
},
Receiver::Poisoned => unreachable!("poisoned receiver state"),
}
}
}

impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for QuinnConnection<In, Out> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ClientChannel")
Expand Down
32 changes: 32 additions & 0 deletions tests/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,38 @@ impl ComputeService {
}
}

/// Runs the service until `count` requests have been received.
pub async fn server_bounded<C: ServiceEndpoint<ComputeService>>(
server: RpcServer<ComputeService, C>,
count: usize,
) -> result::Result<RpcServer<ComputeService, C>, RpcServerError<C>> {
tracing::info!(%count, "server running");
let s = server;
let mut received = 0;
let service = ComputeService;
while received < count {
received += 1;
let (req, chan) = s.accept().await?;
let service = service.clone();
tokio::spawn(async move {
use ComputeRequest::*;
tracing::info!(?req, "got request");
#[rustfmt::skip]
match req {
Sqr(msg) => chan.rpc(msg, service, ComputeService::sqr).await,
Sum(msg) => chan.client_streaming(msg, service, ComputeService::sum).await,
Fibonacci(msg) => chan.server_streaming(msg, service, ComputeService::fibonacci).await,
Multiply(msg) => chan.bidi_streaming(msg, service, ComputeService::multiply).await,
SumUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?,
MultiplyUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?,
}?;
Ok::<_, RpcServerError<C>>(())
});
}
tracing::info!(%count, "server finished");
Ok(s)
}

pub async fn server_par<C: ServiceEndpoint<ComputeService>>(
server: RpcServer<ComputeService, C>,
parallelism: usize,
Expand Down
Loading

0 comments on commit 3323574

Please sign in to comment.