Skip to content

Commit

Permalink
Add TryServerStreaming interaction pattern
Browse files Browse the repository at this point in the history
also reorganize this into a separate pattern dir.
  • Loading branch information
rklaehn committed Mar 26, 2024
1 parent aa598a4 commit 4a04a51
Show file tree
Hide file tree
Showing 6 changed files with 326 additions and 5 deletions.
6 changes: 3 additions & 3 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ pub type BoxStreamSync<'a, T> = Pin<Box<dyn Stream<Item = T> + Send + Sync + 'a>
/// for the client DSL. `S` is the service type, `C` is the substream source.
#[derive(Debug)]
pub struct RpcClient<S, C, SInner = S> {
source: C,
map: Arc<dyn MapService<S, SInner>>,
pub(crate) source: C,
pub(crate) map: Arc<dyn MapService<S, SInner>>,
}

impl<S, C: Clone, SInner> Clone for RpcClient<S, C, SInner> {
Expand Down Expand Up @@ -405,7 +405,7 @@ impl<S: ConnectionErrors> error::Error for StreamingResponseItemError<S> {}

/// Wrap a stream with an additional item that is kept alive until the stream is dropped
#[pin_project]
struct DeferDrop<S: Stream, X>(#[pin] S, X);
pub(crate) struct DeferDrop<S: Stream, X>(#[pin] pub S, pub X);

impl<S: Stream, X> Stream for DeferDrop<S, X> {
type Item = S::Item;
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ pub use server::RpcServer;
mod macros;
mod map;

pub mod pattern;

/// Requirements for a RPC message
///
/// Even when just using the mem transport, we require messages to be Serializable and Deserializable.
Expand Down
2 changes: 2 additions & 0 deletions src/pattern/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
//!
pub mod try_server_streaming;
214 changes: 214 additions & 0 deletions src/pattern/try_server_streaming.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
//!
use futures::{FutureExt, SinkExt, Stream, StreamExt, TryFutureExt};
use serde::{Deserialize, Serialize};

use crate::{
client::{BoxStreamSync, DeferDrop},
message::{InteractionPattern, Msg},
server::{race2, RpcChannel, RpcServerError},
transport::ConnectionErrors,
RpcClient, Service, ServiceConnection, ServiceEndpoint,
};

use std::{
error,
fmt::{self, Debug},
result,
sync::Arc,
};

///
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct StreamCreated;

///
#[derive(Debug, Clone, Copy)]
pub struct TryServerStreaming;

impl InteractionPattern for TryServerStreaming {}

/// Same as [ServerStreamingMsg], but with lazy stream creation and the error type explicitly defined.
pub trait TryServerStreamingMsg<S: Service>: Msg<S, Pattern = TryServerStreaming>
where
std::result::Result<Self::Item, Self::ItemError>: Into<S::Res> + TryFrom<S::Res>,
std::result::Result<StreamCreated, Self::CreateError>: Into<S::Res> + TryFrom<S::Res>,
{
/// Error when creating the stream
type CreateError: Debug + Send + 'static;

/// Error for stream items
type ItemError: Debug + Send + 'static;

/// The type for the response
///
/// For requests that can produce errors, this can be set to [Result<T, E>](std::result::Result).
type Item: Send + 'static;
}

/// Server error when accepting a server streaming request
///
/// This combines network errors with application errors. Usually you don't
/// care about the exact nature of the error, but if you want to handle
/// application errors differently, you can match on this enum.
#[derive(Debug)]
pub enum Error<C: ConnectionErrors, E: Debug> {
/// Unable to open a substream at all
Open(C::OpenError),
/// Unable to send the request to the server
Send(C::SendError),
/// Error received when creating the stream
Recv(C::RecvError),
/// Connection was closed before receiving the first message
EarlyClose,
/// Unexpected response from the server
Downcast,
/// Application error
Application(E),
}

impl<S: ConnectionErrors, E: Debug> fmt::Display for Error<S, E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(self, f)
}
}

impl<S: ConnectionErrors, E: Debug> error::Error for Error<S, E> {}

/// Client error when handling responses from a server streaming request
#[derive(Debug)]
pub enum ItemError<S: ConnectionErrors, E: Debug> {
/// Unable to receive the response from the server
Recv(S::RecvError),
/// Unexpected response from the server
Downcast,
/// Application error
Application(E),
}

impl<S: ConnectionErrors, E: Debug> fmt::Display for ItemError<S, E> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(self, f)
}
}

impl<S: ConnectionErrors, E: Debug> error::Error for ItemError<S, E> {}

impl<S, C, SInner> RpcChannel<S, C, SInner>
where
S: Service,
C: ServiceEndpoint<S>,
SInner: Service,
{
/// handle the message M using the given function on the target object
///
/// If you want to support concurrent requests, you need to spawn this on a tokio task yourself.
///
/// Compared to [RpcChannel::server_streaming], with this method the stream creation is via
/// a function that returns a future that resolves to a stream.
pub async fn try_server_streaming<M, F, Fut, Str, T>(
self,
req: M,
target: T,
f: F,
) -> result::Result<(), RpcServerError<C>>
where
M: TryServerStreamingMsg<SInner>,
std::result::Result<M::Item, M::ItemError>: Into<SInner::Res> + TryFrom<SInner::Res>,
std::result::Result<StreamCreated, M::CreateError>:
Into<SInner::Res> + TryFrom<SInner::Res>,
F: FnOnce(T, M) -> Fut + Send + 'static,
Fut: futures::Future<Output = std::result::Result<Str, M::CreateError>> + Send + 'static,
Str: Stream<Item = std::result::Result<M::Item, M::ItemError>> + Send + 'static,
T: Send + 'static,
{
let Self {
mut send, mut recv, ..
} = self;
// cancel if we get an update, no matter what it is
let cancel = recv
.next()
.map(|_| RpcServerError::UnexpectedUpdateMessage::<C>);
// race the computation and the cancellation
race2(cancel.map(Err), async move {
// get the response
let responses = match f(target, req).await {
Ok(responses) => {
// turn into a S::Res so we can send it
let response = self.map.res_into_outer(Ok(StreamCreated).into());
// send it and return the error if any
send.send(response)
.await
.map_err(RpcServerError::SendError)?;
responses
}
Err(cause) => {
// turn into a S::Res so we can send it
let response = self.map.res_into_outer(Err(cause).into());
// send it and return the error if any
send.send(response)
.await
.map_err(RpcServerError::SendError)?;
return Ok(());
}
};
tokio::pin!(responses);
while let Some(response) = responses.next().await {
// turn into a S::Res so we can send it
let response = self.map.res_into_outer(response.into());
// send it and return the error if any
send.send(response)
.await
.map_err(RpcServerError::SendError)?;
}
Ok(())
})
.await
}
}

impl<S, C, SInner> RpcClient<S, C, SInner>
where
S: Service,
C: ServiceConnection<S>,
SInner: Service,
{
/// Bidi call to the server, request opens a stream, response is a stream
pub async fn try_server_streaming<M>(
&self,
msg: M,
) -> result::Result<
BoxStreamSync<'static, Result<M::Item, ItemError<C, M::ItemError>>>,
Error<C, M::CreateError>,
>
where
M: TryServerStreamingMsg<SInner>,
Result<M::Item, M::ItemError>: Into<SInner::Res> + TryFrom<SInner::Res>,
Result<StreamCreated, M::CreateError>: Into<SInner::Res> + TryFrom<SInner::Res>,
{
let msg = self.map.req_into_outer(msg.into());
let (mut send, mut recv) = self.source.open_bi().await.map_err(Error::Open)?;
send.send(msg).map_err(Error::Send).await?;
let map = Arc::clone(&self.map);
let Some(initial) = recv.next().await else {
return Err(Error::EarlyClose);
};
let initial = initial.map_err(Error::Recv)?; // initial response
let initial = map
.res_try_into_inner(initial)
.map_err(|_| Error::Downcast)?;
let initial = <std::result::Result<StreamCreated, M::CreateError>>::try_from(initial)
.map_err(|_| Error::Downcast)?;
let _ = initial.map_err(Error::Application)?;
let recv = recv.map(move |x| {
let x = x.map_err(ItemError::Recv)?;
let x = map.res_try_into_inner(x).map_err(|_| ItemError::Downcast)?;
let x = <std::result::Result<M::Item, M::ItemError>>::try_from(x)
.map_err(|_| ItemError::Downcast)?;
let x = x.map_err(ItemError::Application)?;
Ok(x)
});
// keep send alive so the request on the server side does not get cancelled
let recv = Box::pin(DeferDrop(recv, send));
Ok(recv)
}
}
4 changes: 2 additions & 2 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pub struct RpcChannel<S: Service, C: ServiceEndpoint<S>, SInner: Service = S> {
/// Stream to receive requests from the client.
pub recv: C::RecvStream,
/// Mapper to map between S and S2
map: Arc<dyn MapService<S, SInner>>,
pub map: Arc<dyn MapService<S, SInner>>,
}

impl<S, C> RpcChannel<S, C, S>
Expand Down Expand Up @@ -447,7 +447,7 @@ impl<T> Future for UnwrapToPending<T> {
}
}

async fn race2<T, A: Future<Output = T>, B: Future<Output = T>>(f1: A, f2: B) -> T {
pub(crate) async fn race2<T, A: Future<Output = T>, B: Future<Output = T>>(f1: A, f2: B) -> T {
tokio::select! {
x = f1 => x,
x = f2 => x,
Expand Down
103 changes: 103 additions & 0 deletions tests/try.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#![cfg(feature = "flume-transport")]
use derive_more::{From, TryInto};
use futures::{Stream, StreamExt};
use quic_rpc::{
message::Msg,
pattern::try_server_streaming::{StreamCreated, TryServerStreaming, TryServerStreamingMsg},
server::RpcServerError,
transport::flume,
RpcClient, RpcServer, Service,
};
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone)]
struct TryService;

impl Service for TryService {
type Req = TryRequest;
type Res = TryResponse;
}

#[derive(Debug, Serialize, Deserialize)]
pub struct StreamN {
n: u64,
}

impl Msg<TryService> for StreamN {
type Pattern = TryServerStreaming;
}

impl TryServerStreamingMsg<TryService> for StreamN {
type Item = u64;
type ItemError = String;
type CreateError = String;
}

/// request enum
#[derive(Debug, Serialize, Deserialize, From, TryInto)]
pub enum TryRequest {
StreamN(StreamN),
}

#[derive(Debug, Serialize, Deserialize, From, TryInto, Clone)]
pub enum TryResponse {
StreamN(std::result::Result<u64, String>),
StreamNError(std::result::Result<StreamCreated, String>),
}

#[derive(Clone)]
struct Handler;

impl Handler {
async fn try_stream_n(
self,
req: StreamN,
) -> std::result::Result<impl Stream<Item = std::result::Result<u64, String>>, String> {
if req.n % 2 != 0 {
return Err("odd n not allowed".to_string());
}
let stream = async_stream::stream! {
for i in 0..req.n {
if i > 5 {
yield Err("n too large".to_string());
return;
}
yield Ok(i);
}
};
Ok(stream)
}
}

#[tokio::test]
async fn try_server_streaming() -> anyhow::Result<()> {
tracing_subscriber::fmt::try_init().ok();
let (server, client) = flume::connection::<TryRequest, TryResponse>(1);

let server = RpcServer::<TryService, _>::new(server);
let server_handle = tokio::task::spawn(async move {
loop {
let (req, chan) = server.accept().await?;
let handler = Handler;
match req {
TryRequest::StreamN(req) => {
chan.try_server_streaming(req, handler, Handler::try_stream_n)
.await?;
}
}
}
#[allow(unreachable_code)]
Ok(())
});
let client = RpcClient::<TryService, _>::new(client);
let stream_n = client.try_server_streaming(StreamN { n: 10 }).await?;
let items: Vec<_> = stream_n.collect().await;
println!("{:?}", items);
drop(client);
// dropping the client will cause the server to terminate
match server_handle.await? {
Err(RpcServerError::Accept(_)) => {}
e => panic!("unexpected termination result {e:?}"),
}
Ok(())
}

0 comments on commit 4a04a51

Please sign in to comment.