diff --git a/src/client.rs b/src/client.rs index 1c412c7..513f0ec 100644 --- a/src/client.rs +++ b/src/client.rs @@ -28,8 +28,8 @@ pub type BoxStreamSync<'a, T> = Pin + Send + Sync + 'a> /// for the client DSL. `S` is the service type, `C` is the substream source. #[derive(Debug)] pub struct RpcClient { - source: C, - map: Arc>, + pub(crate) source: C, + pub(crate) map: Arc>, } impl Clone for RpcClient { @@ -405,7 +405,7 @@ impl error::Error for StreamingResponseItemError {} /// Wrap a stream with an additional item that is kept alive until the stream is dropped #[pin_project] -struct DeferDrop(#[pin] S, X); +pub(crate) struct DeferDrop(#[pin] pub S, pub X); impl Stream for DeferDrop { type Item = S::Item; diff --git a/src/lib.rs b/src/lib.rs index 8c6464d..bc92ea3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -104,6 +104,8 @@ pub use server::RpcServer; mod macros; mod map; +pub mod pattern; + /// Requirements for a RPC message /// /// Even when just using the mem transport, we require messages to be Serializable and Deserializable. diff --git a/src/pattern/mod.rs b/src/pattern/mod.rs new file mode 100644 index 0000000..5826152 --- /dev/null +++ b/src/pattern/mod.rs @@ -0,0 +1,2 @@ +//! +pub mod try_server_streaming; diff --git a/src/pattern/try_server_streaming.rs b/src/pattern/try_server_streaming.rs new file mode 100644 index 0000000..f2e9d52 --- /dev/null +++ b/src/pattern/try_server_streaming.rs @@ -0,0 +1,214 @@ +//! +use futures::{FutureExt, SinkExt, Stream, StreamExt, TryFutureExt}; +use serde::{Deserialize, Serialize}; + +use crate::{ + client::{BoxStreamSync, DeferDrop}, + message::{InteractionPattern, Msg}, + server::{race2, RpcChannel, RpcServerError}, + transport::ConnectionErrors, + RpcClient, Service, ServiceConnection, ServiceEndpoint, +}; + +use std::{ + error, + fmt::{self, Debug}, + result, + sync::Arc, +}; + +/// +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub struct StreamCreated; + +/// +#[derive(Debug, Clone, Copy)] +pub struct TryServerStreaming; + +impl InteractionPattern for TryServerStreaming {} + +/// Same as [ServerStreamingMsg], but with lazy stream creation and the error type explicitly defined. +pub trait TryServerStreamingMsg: Msg +where + std::result::Result: Into + TryFrom, + std::result::Result: Into + TryFrom, +{ + /// Error when creating the stream + type CreateError: Debug + Send + 'static; + + /// Error for stream items + type ItemError: Debug + Send + 'static; + + /// The type for the response + /// + /// For requests that can produce errors, this can be set to [Result](std::result::Result). + type Item: Send + 'static; +} + +/// Server error when accepting a server streaming request +/// +/// This combines network errors with application errors. Usually you don't +/// care about the exact nature of the error, but if you want to handle +/// application errors differently, you can match on this enum. +#[derive(Debug)] +pub enum Error { + /// Unable to open a substream at all + Open(C::OpenError), + /// Unable to send the request to the server + Send(C::SendError), + /// Error received when creating the stream + Recv(C::RecvError), + /// Connection was closed before receiving the first message + EarlyClose, + /// Unexpected response from the server + Downcast, + /// Application error + Application(E), +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +impl error::Error for Error {} + +/// Client error when handling responses from a server streaming request +#[derive(Debug)] +pub enum ItemError { + /// Unable to receive the response from the server + Recv(S::RecvError), + /// Unexpected response from the server + Downcast, + /// Application error + Application(E), +} + +impl fmt::Display for ItemError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +impl error::Error for ItemError {} + +impl RpcChannel +where + S: Service, + C: ServiceEndpoint, + SInner: Service, +{ + /// handle the message M using the given function on the target object + /// + /// If you want to support concurrent requests, you need to spawn this on a tokio task yourself. + /// + /// Compared to [RpcChannel::server_streaming], with this method the stream creation is via + /// a function that returns a future that resolves to a stream. + pub async fn try_server_streaming( + self, + req: M, + target: T, + f: F, + ) -> result::Result<(), RpcServerError> + where + M: TryServerStreamingMsg, + std::result::Result: Into + TryFrom, + std::result::Result: + Into + TryFrom, + F: FnOnce(T, M) -> Fut + Send + 'static, + Fut: futures::Future> + Send + 'static, + Str: Stream> + Send + 'static, + T: Send + 'static, + { + let Self { + mut send, mut recv, .. + } = self; + // cancel if we get an update, no matter what it is + let cancel = recv + .next() + .map(|_| RpcServerError::UnexpectedUpdateMessage::); + // race the computation and the cancellation + race2(cancel.map(Err), async move { + // get the response + let responses = match f(target, req).await { + Ok(responses) => { + // turn into a S::Res so we can send it + let response = self.map.res_into_outer(Ok(StreamCreated).into()); + // send it and return the error if any + send.send(response) + .await + .map_err(RpcServerError::SendError)?; + responses + } + Err(cause) => { + // turn into a S::Res so we can send it + let response = self.map.res_into_outer(Err(cause).into()); + // send it and return the error if any + send.send(response) + .await + .map_err(RpcServerError::SendError)?; + return Ok(()); + } + }; + tokio::pin!(responses); + while let Some(response) = responses.next().await { + // turn into a S::Res so we can send it + let response = self.map.res_into_outer(response.into()); + // send it and return the error if any + send.send(response) + .await + .map_err(RpcServerError::SendError)?; + } + Ok(()) + }) + .await + } +} + +impl RpcClient +where + S: Service, + C: ServiceConnection, + SInner: Service, +{ + /// Bidi call to the server, request opens a stream, response is a stream + pub async fn try_server_streaming( + &self, + msg: M, + ) -> result::Result< + BoxStreamSync<'static, Result>>, + Error, + > + where + M: TryServerStreamingMsg, + Result: Into + TryFrom, + Result: Into + TryFrom, + { + let msg = self.map.req_into_outer(msg.into()); + let (mut send, mut recv) = self.source.open_bi().await.map_err(Error::Open)?; + send.send(msg).map_err(Error::Send).await?; + let map = Arc::clone(&self.map); + let Some(initial) = recv.next().await else { + return Err(Error::EarlyClose); + }; + let initial = initial.map_err(Error::Recv)?; // initial response + let initial = map + .res_try_into_inner(initial) + .map_err(|_| Error::Downcast)?; + let initial = >::try_from(initial) + .map_err(|_| Error::Downcast)?; + let _ = initial.map_err(Error::Application)?; + let recv = recv.map(move |x| { + let x = x.map_err(ItemError::Recv)?; + let x = map.res_try_into_inner(x).map_err(|_| ItemError::Downcast)?; + let x = >::try_from(x) + .map_err(|_| ItemError::Downcast)?; + let x = x.map_err(ItemError::Application)?; + Ok(x) + }); + // keep send alive so the request on the server side does not get cancelled + let recv = Box::pin(DeferDrop(recv, send)); + Ok(recv) + } +} diff --git a/src/server.rs b/src/server.rs index d5be6aa..0b71b4e 100644 --- a/src/server.rs +++ b/src/server.rs @@ -68,7 +68,7 @@ pub struct RpcChannel, SInner: Service = S> { /// Stream to receive requests from the client. pub recv: C::RecvStream, /// Mapper to map between S and S2 - map: Arc>, + pub map: Arc>, } impl RpcChannel @@ -447,7 +447,7 @@ impl Future for UnwrapToPending { } } -async fn race2, B: Future>(f1: A, f2: B) -> T { +pub(crate) async fn race2, B: Future>(f1: A, f2: B) -> T { tokio::select! { x = f1 => x, x = f2 => x, diff --git a/tests/try.rs b/tests/try.rs new file mode 100644 index 0000000..be20774 --- /dev/null +++ b/tests/try.rs @@ -0,0 +1,103 @@ +#![cfg(feature = "flume-transport")] +use derive_more::{From, TryInto}; +use futures::{Stream, StreamExt}; +use quic_rpc::{ + message::Msg, + pattern::try_server_streaming::{StreamCreated, TryServerStreaming, TryServerStreamingMsg}, + server::RpcServerError, + transport::flume, + RpcClient, RpcServer, Service, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone)] +struct TryService; + +impl Service for TryService { + type Req = TryRequest; + type Res = TryResponse; +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct StreamN { + n: u64, +} + +impl Msg for StreamN { + type Pattern = TryServerStreaming; +} + +impl TryServerStreamingMsg for StreamN { + type Item = u64; + type ItemError = String; + type CreateError = String; +} + +/// request enum +#[derive(Debug, Serialize, Deserialize, From, TryInto)] +pub enum TryRequest { + StreamN(StreamN), +} + +#[derive(Debug, Serialize, Deserialize, From, TryInto, Clone)] +pub enum TryResponse { + StreamN(std::result::Result), + StreamNError(std::result::Result), +} + +#[derive(Clone)] +struct Handler; + +impl Handler { + async fn try_stream_n( + self, + req: StreamN, + ) -> std::result::Result>, String> { + if req.n % 2 != 0 { + return Err("odd n not allowed".to_string()); + } + let stream = async_stream::stream! { + for i in 0..req.n { + if i > 5 { + yield Err("n too large".to_string()); + return; + } + yield Ok(i); + } + }; + Ok(stream) + } +} + +#[tokio::test] +async fn try_server_streaming() -> anyhow::Result<()> { + tracing_subscriber::fmt::try_init().ok(); + let (server, client) = flume::connection::(1); + + let server = RpcServer::::new(server); + let server_handle = tokio::task::spawn(async move { + loop { + let (req, chan) = server.accept().await?; + let handler = Handler; + match req { + TryRequest::StreamN(req) => { + chan.try_server_streaming(req, handler, Handler::try_stream_n) + .await?; + } + } + } + #[allow(unreachable_code)] + Ok(()) + }); + let client = RpcClient::::new(client); + let stream_n = client.try_server_streaming(StreamN { n: 10 }).await?; + let items: Vec<_> = stream_n.collect().await; + println!("{:?}", items); + drop(client); + // dropping the client will cause the server to terminate + match server_handle.await? { + Err(RpcServerError::Accept(_)) => {} + e => panic!("unexpected termination result {e:?}"), + } + Ok(()) +}