Skip to content

Commit

Permalink
push test for trt bidi
Browse files Browse the repository at this point in the history
  • Loading branch information
rklaehn committed Dec 20, 2024
1 parent 11d843b commit dda76e4
Showing 1 changed file with 109 additions and 0 deletions.
109 changes: 109 additions & 0 deletions tests/try_bidi.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#![cfg(feature = "flume-transport")]
use derive_more::{From, TryInto};
use futures_lite::{Stream, StreamExt};
use quic_rpc::{
message::Msg,
pattern::try_bidi_streaming::{StreamCreated, TryBidiStreaming, TryBidiStreamingMsg},
server::RpcServerError,
transport::flume,
RpcClient, RpcServer, Service,
};
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone)]
struct TryService;

impl Service for TryService {
type Req = TryRequest;
type Res = TryResponse;
}

#[derive(Debug, Serialize, Deserialize)]
pub struct StreamN {
n: u64,
}

impl Msg<TryService> for StreamN {
type Pattern = TryBidiStreaming;
}

impl TryBidiStreamingMsg<TryService> for StreamN {
type Item = u64;
type ItemError = String;
type CreateError = String;
type Update = u64;
}

/// request enum
#[derive(Debug, Serialize, Deserialize, From, TryInto)]
pub enum TryRequest {
StreamN(StreamN),
StreamNUpdate(u64),
}

#[derive(Debug, Serialize, Deserialize, From, TryInto, Clone)]
pub enum TryResponse {
StreamN(std::result::Result<u64, String>),
StreamNError(std::result::Result<StreamCreated, String>),
}

#[derive(Clone)]
struct Handler;

impl Handler {
async fn try_stream_n(
self,
req: StreamN,
updates: impl Stream<Item = u64>,
) -> std::result::Result<impl Stream<Item = std::result::Result<u64, String>>, String> {
if req.n % 2 != 0 {
return Err("odd n not allowed".to_string());
}
let stream = async_stream::stream! {
for i in 0..req.n {
if i > 5 {
yield Err("n too large".to_string());
return;
}
yield Ok(i);
}
};
Ok(stream)
}
}

#[tokio::test]
async fn try_bidi_streaming() -> anyhow::Result<()> {
tracing_subscriber::fmt::try_init().ok();
let (server, client) = flume::channel(1);

let server = RpcServer::<TryService, _>::new(server);
let server_handle = tokio::task::spawn(async move {
loop {
let (req, chan) = server.accept().await?.read_first().await?;
let handler = Handler;
match req {
TryRequest::StreamN(req) => {
chan.try_bidi_streaming(req, handler, Handler::try_stream_n)
.await?;
}
TryRequest::StreamNUpdate(_) => {
return Err(RpcServerError::UnexpectedUpdateMessage);
}
}
}
#[allow(unreachable_code)]
Ok(())
});
let client = RpcClient::<TryService, _>::new(client);
let (stream_n, update_sink) = client.try_bidi_streaming(StreamN { n: 10 }).await?;
let items: Vec<_> = stream_n.collect().await;
println!("{:?}", items);
drop(client);
// dropping the client will cause the server to terminate
match server_handle.await? {
Err(RpcServerError::Accept(_)) => {}
e => panic!("unexpected termination result {e:?}"),
}
Ok(())
}

0 comments on commit dda76e4

Please sign in to comment.