Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: rpc client concurrently waits for requests and new connections #62

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 226 additions & 35 deletions src/transport/quinn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,47 +301,103 @@ impl<In: RpcMessage, Out: RpcMessage> QuinnConnection<In, Out> {
name: String,
requests: flume::Receiver<oneshot::Sender<Result<SocketInner, quinn::ConnectionError>>>,
) {
'outer: loop {
tracing::debug!("Connecting to {} as {}", addr, name);
let connecting = match endpoint.connect(addr, &name) {
Ok(connecting) => connecting,
Err(e) => {
tracing::warn!("error calling connect: {}", e);
// try again. Maybe delay?
continue;
let reconnect = ReconnectHandler {
endpoint,
state: ConnectionState::NotConnected,
addr,
name,
};
futures::pin_mut!(reconnect);

let mut receiver = Receiver::new(&requests);

let mut pending_request: Option<
oneshot::Sender<Result<SocketInner, quinn::ConnectionError>>,
> = None;
let mut connection = None;

loop {
let mut conn_result = None;
let mut chann_result = None;
if !reconnect.connected() && pending_request.is_none() {
match futures::future::select(reconnect.as_mut(), receiver.next()).await {
futures::future::Either::Left((connection_result, _)) => {
conn_result = Some(connection_result)
}
futures::future::Either::Right((channel_result, _)) => {
chann_result = Some(channel_result);
}
}
};
let connection = match connecting.await {
Ok(connection) => connection,
Err(e) => {
tracing::warn!("error awaiting connect: {}", e);
// try again. Maybe delay?
continue;
} else if !reconnect.connected() {
// only need a new connection
conn_result = Some(reconnect.as_mut().await);
} else if pending_request.is_none() {
// there is a connection, just need a request
chann_result = Some(receiver.next().await);
}

if let Some(conn_result) = conn_result {
tracing::trace!("tick: connection result");
match conn_result {
Ok(new_connection) => {
connection = Some(new_connection);
}
Err(e) => {
let connection_err = match e {
ReconnectErr::Connect(e) => {
// TODO(@divma): the type for now accepts only a
// ConnectionError, not a ConnectError. I'm mapping this now to
// some ConnectionError since before it was not even reported.
// Maybe adjust the type?
tracing::warn!(%e, "error calling connect");
quinn::ConnectionError::Reset
}
ReconnectErr::Connection(e) => {
tracing::warn!(%e, "failed to connect");
e
}
};
if let Some(request) = pending_request.take() {
if request.send(Err(connection_err)).is_err() {
tracing::debug!("requester dropped");
}
}
}
}
};
loop {
tracing::debug!("Awaiting request for new bidi substream...");
let request = match requests.recv_async().await {
Ok(request) => request,
Err(_) => {
}

if let Some(req) = chann_result {
tracing::trace!("tick: bidi request");
match req {
Some(request) => pending_request = Some(request),
None => {
tracing::debug!("client dropped");
connection.close(0u32.into(), b"requester dropped");
if let Some(connection) = connection {
connection.close(0u32.into(), b"requester dropped");
}
break;
}
};
tracing::debug!("Got request for new bidi substream");
match connection.open_bi().await {
Ok(pair) => {
tracing::debug!("Bidi substream opened");
if request.send(Ok(pair)).is_err() {
tracing::debug!("requester dropped");
}
}

if let Some(connection) = connection.as_mut() {
if let Some(request) = pending_request.take() {
match connection.open_bi().await {
Ok(pair) => {
tracing::debug!("Bidi substream opened");
if request.send(Ok(pair)).is_err() {
tracing::debug!("requester dropped");
}
}
Err(e) => {
tracing::warn!("error opening bidi substream: {}", e);
tracing::warn!("recreating connection");
// NOTE: the connection might be stale, so we recreate the
// connection and set the request as pending instead of
// sending the error as a response
reconnect.set_not_connected();
pending_request = Some(request);
}
}
Err(e) => {
tracing::warn!("error opening bidi substream: {}", e);
tracing::warn!("recreating connection");
// try again. Maybe delay?
continue 'outer;
}
}
}
Expand Down Expand Up @@ -392,6 +448,141 @@ impl<In: RpcMessage, Out: RpcMessage> QuinnConnection<In, Out> {
}
}

/// Future that produces a [`quinn::Connection`] to `addr`, connecting on demand.
///
/// Progress is tracked in `state`. Polling while already connected returns the
/// established connection again; calling `set_not_connected` makes the next poll
/// start a new connection attempt.
struct ReconnectHandler {
/// Endpoint whose `connect` method is called to start each connection attempt.
endpoint: quinn::Endpoint,
/// Current stage of the connection attempt.
state: ConnectionState,
/// Remote address passed to `endpoint.connect`.
addr: SocketAddr,
/// Name passed to `endpoint.connect` together with `addr`.
name: String,
}

impl ReconnectHandler {
    /// Forgets the current connection state so that the next poll of this
    /// handler starts a new connection attempt.
    pub fn set_not_connected(&mut self) {
        self.state.set_not_connected();
    }

    /// Returns `true` only while the state holds an established connection.
    pub fn connected(&self) -> bool {
        // Spelled out per variant so that adding a new state forces a decision here.
        match self.state {
            ConnectionState::Connected(_) => true,
            ConnectionState::NotConnected
            | ConnectionState::Connecting(_)
            | ConnectionState::Poisoned => false,
        }
    }
}

/// Stages of the connection state machine driven by `ReconnectHandler::poll`.
enum ConnectionState {
/// There is no active connection. An attempt to connect will be made.
NotConnected,
/// Connecting to the remote.
Connecting(quinn::Connecting),
/// A connection is already established. In this state, no more connection attempts are made.
Connected(quinn::Connection),
/// Placeholder swapped in by `poison` while the previous state is being processed.
/// Encountering it at the start of a poll is treated as unreachable.
Poisoned,
}

impl ConnectionState {
pub fn poison(&mut self) -> ConnectionState {
std::mem::replace(self, ConnectionState::Poisoned)
}

pub fn set_not_connected(&mut self) {
*self = ConnectionState::NotConnected
}
}

/// Failure modes of a single connection attempt made by `ReconnectHandler`.
enum ReconnectErr {
/// The call to `endpoint.connect` itself returned an error, so no attempt was started.
Connect(quinn::ConnectError),
/// An attempt was started but awaiting it ended with a connection error.
Connection(quinn::ConnectionError),
}

impl Future for ReconnectHandler {
    type Output = Result<quinn::Connection, ReconnectErr>;

    /// Advances the connection state machine.
    ///
    /// * `NotConnected`: starts a new attempt and polls it immediately; a failure
    ///   to start the attempt is reported as [`ReconnectErr::Connect`].
    /// * `Connecting`: polls the in-flight attempt; on success the connection is
    ///   stored and returned, on failure the state goes back to `NotConnected`
    ///   and [`ReconnectErr::Connection`] is returned.
    /// * `Connected`: returns the stored connection again without reconnecting.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            // Take ownership of the current state. Every branch below stores a
            // real state again before it returns or loops.
            let current = self.state.poison();
            match current {
                ConnectionState::NotConnected => {
                    let attempt = self.endpoint.connect(self.addr, &self.name);
                    match attempt {
                        Ok(connecting) => {
                            // Go around once more so the fresh attempt is polled
                            // (and registers its waker) within this same call.
                            self.state = ConnectionState::Connecting(connecting);
                        }
                        Err(e) => {
                            self.state = ConnectionState::NotConnected;
                            return Poll::Ready(Err(ReconnectErr::Connect(e)));
                        }
                    }
                }
                ConnectionState::Connecting(mut connecting) => {
                    return match connecting.poll_unpin(cx) {
                        Poll::Ready(Ok(connection)) => {
                            // Keep one handle for later polls, hand the other out.
                            self.state = ConnectionState::Connected(connection.clone());
                            Poll::Ready(Ok(connection))
                        }
                        Poll::Ready(Err(e)) => {
                            self.state = ConnectionState::NotConnected;
                            Poll::Ready(Err(ReconnectErr::Connection(e)))
                        }
                        Poll::Pending => {
                            // Put the attempt back so the next poll resumes it.
                            self.state = ConnectionState::Connecting(connecting);
                            Poll::Pending
                        }
                    };
                }
                ConnectionState::Connected(connection) => {
                    self.state = ConnectionState::Connected(connection.clone());
                    return Poll::Ready(Ok(connection));
                }
                ConnectionState::Poisoned => unreachable!("poisoned connection state"),
            }
        }
    }
}

/// Wrapper over [`flume::Receiver`] that can be used with [`tokio::select`].
///
/// The in-flight [`flume::r#async::RecvFut`] is stored inside the wrapper between
/// polls instead of being recreated by the caller, so a receive that is still
/// pending is kept when the caller stops waiting on this stream for a while.
///
/// NOTE: from https://github.com/zesterer/flume/issues/104:
/// > If RecvFut is dropped without being polled, the item is never received.
enum Receiver<'a, T>
where
Self: 'a,
{
/// No receive is in progress; the next poll creates a new receive future.
PreReceive(&'a flume::Receiver<T>),
/// A receive future has been created and is kept here until it resolves.
Receiving(&'a flume::Receiver<T>, flume::r#async::RecvFut<'a, T>),
/// Placeholder left behind by `poison` while the previous state is being processed.
Poisoned,
}

impl<'a, T> Receiver<'a, T> {
    /// Wraps `recv` in the idle state; no receive future is created until the
    /// first poll.
    fn new(recv: &'a flume::Receiver<T>) -> Self {
        Self::PreReceive(recv)
    }

    /// Moves the current state out to the caller, leaving `Poisoned` behind
    /// until a new state is stored.
    fn poison(&mut self) -> Self {
        let mut taken = Self::Poisoned;
        std::mem::swap(self, &mut taken);
        taken
    }
}

impl<'a, T> futures::stream::Stream for Receiver<'a, T> {
    type Item = T;

    /// Yields the next item from the wrapped channel, or `None` once the
    /// channel reports that it is disconnected.
    ///
    /// A pending receive future is stored back into `self` so that the next
    /// poll resumes it instead of starting a new receive.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // Take ownership of the current state. Every branch below stores a
            // real state again before it returns or loops.
            let current = self.poison();
            match current {
                Receiver::PreReceive(recv) => {
                    // Start a receive and go around again to poll it right away.
                    let pending = recv.recv_async();
                    *self = Receiver::Receiving(recv, pending);
                }
                Receiver::Receiving(recv, mut pending) => {
                    let polled = pending.poll_unpin(cx);
                    return match polled {
                        Poll::Ready(Ok(item)) => {
                            *self = Receiver::PreReceive(recv);
                            Poll::Ready(Some(item))
                        }
                        Poll::Ready(Err(flume::RecvError::Disconnected)) => {
                            *self = Receiver::PreReceive(recv);
                            Poll::Ready(None)
                        }
                        Poll::Pending => {
                            *self = Receiver::Receiving(recv, pending);
                            Poll::Pending
                        }
                    };
                }
                Receiver::Poisoned => unreachable!("poisoned receiver state"),
            }
        }
    }
}

impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for QuinnConnection<In, Out> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ClientChannel")
Expand Down
32 changes: 32 additions & 0 deletions tests/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,38 @@ impl ComputeService {
}
}

/// Runs the service until `count` requests have been received.
///
/// Each accepted request is handled on its own spawned tokio task, so this
/// function does not wait for handlers to finish and does not observe their
/// results.
///
/// # Errors
///
/// Returns the error from `accept` as soon as accepting a request fails.
///
/// # Returns
///
/// The server that was passed in, once `count` requests have been accepted.
pub async fn server_bounded<C: ServiceEndpoint<ComputeService>>(
    server: RpcServer<ComputeService, C>,
    count: usize,
) -> result::Result<RpcServer<ComputeService, C>, RpcServerError<C>> {
    tracing::info!(%count, "server running");
    let template = ComputeService;
    for _ in 0..count {
        let (req, chan) = server.accept().await?;
        let service = template.clone();
        tokio::spawn(async move {
            use ComputeRequest::*;
            tracing::info!(?req, "got request");
            #[rustfmt::skip]
            match req {
                Sqr(msg) => chan.rpc(msg, service, ComputeService::sqr).await,
                Sum(msg) => chan.client_streaming(msg, service, ComputeService::sum).await,
                Fibonacci(msg) => chan.server_streaming(msg, service, ComputeService::fibonacci).await,
                Multiply(msg) => chan.bidi_streaming(msg, service, ComputeService::multiply).await,
                SumUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?,
                MultiplyUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?,
            }?;
            Ok::<_, RpcServerError<C>>(())
        });
    }
    tracing::info!(%count, "server finished");
    Ok(server)
}

pub async fn server_par<C: ServiceEndpoint<ComputeService>>(
server: RpcServer<ComputeService, C>,
parallelism: usize,
Expand Down
Loading
Loading