Skip to content

Commit

Permalink
refactor: use two stage accept (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
rklaehn authored Jun 26, 2024
2 parents 8144fde + b3c37ff commit c2520b8
Show file tree
Hide file tree
Showing 10 changed files with 41 additions and 22 deletions.
2 changes: 1 addition & 1 deletion examples/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async fn main() -> anyhow::Result<()> {
let server = RpcServer::new(server);
let handle = tokio::task::spawn(async move {
for _ in 0..1 {
let (req, chan) = server.accept().await?;
let (req, chan) = server.accept().await?.read_first().await?;
match req {
IoRequest::Write(req) => chan.rpc_map_err(req, fs, Fs::write).await,
}?
Expand Down
5 changes: 4 additions & 1 deletion examples/modularize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ async fn main() -> Result<()> {
async fn run_server<C: ServiceEndpoint<AppService>>(server_conn: C, handler: app::Handler) {
let server = RpcServer::new(server_conn);
loop {
match server.accept().await {
let Ok(accepting) = server.accept().await else {
continue;
};
match accepting.read_first().await {
Err(err) => warn!(?err, "server accept failed"),
Ok((req, chan)) => {
let handler = handler.clone();
Expand Down
2 changes: 1 addition & 1 deletion examples/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ async fn main() -> anyhow::Result<()> {
let s = server;
let store = Store;
loop {
let (req, chan) = s.accept().await?;
let (req, chan) = s.accept().await?.read_first().await?;
use StoreRequest::*;
let store = store.clone();
#[rustfmt::skip]
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
//! let handler = Handler;
//! loop {
//! // accept connections
//! let (msg, chan) = server.accept().await?;
//! let (msg, chan) = server.accept().await?.read_first().await?;
//! // dispatch the message to the appropriate handler
//! match msg {
//! PingRequest::Ping(ping) => chan.rpc(ping, handler, Handler::ping).await?,
Expand Down
34 changes: 24 additions & 10 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,14 @@ where
}
}

impl<S: Service, C: ServiceEndpoint<S>> RpcServer<S, C> {
/// Accepts a new channel from a client, and reads the first request.
/// The result of accepting a new connection.
pub struct Accepting<S: Service, C: ServiceEndpoint<S>> {
send: C::SendSink,
recv: C::RecvStream,
}

impl<S: Service, C: ServiceEndpoint<S>> Accepting<S, C> {
/// Read the first message from the client.
///
/// The return value is a tuple of `(request, channel)`. Here `request` is the
/// first request which is already read from the stream. The `channel` is a
Expand All @@ -127,13 +133,8 @@ impl<S: Service, C: ServiceEndpoint<S>> RpcServer<S, C> {
///
/// Often sink and stream will wrap an an underlying byte stream. In this case you can
/// call into_inner() on them to get it back to perform byte level reads and writes.
pub async fn accept(&self) -> result::Result<(S::Req, RpcChannel<S, C>), RpcServerError<C>> {
let (send, mut recv) = self
.source
.accept_bi()
.await
.map_err(RpcServerError::Accept)?;

pub async fn read_first(self) -> result::Result<(S::Req, RpcChannel<S, C>), RpcServerError<C>> {
let Accepting { send, mut recv } = self;
// get the first message from the client. This will tell us what it wants to do.
let request: S::Req = recv
.next()
Expand All @@ -144,6 +145,19 @@ impl<S: Service, C: ServiceEndpoint<S>> RpcServer<S, C> {
.map_err(RpcServerError::RecvError)?;
Ok((request, RpcChannel::new(send, recv)))
}
}

impl<S: Service, C: ServiceEndpoint<S>> RpcServer<S, C> {
/// Accepts a new channel from a client. The result is an [Accepting] object that
/// can be used to read the first request.
pub async fn accept(&self) -> result::Result<Accepting<S, C>, RpcServerError<C>> {
let (send, recv) = self
.source
.accept_bi()
.await
.map_err(RpcServerError::Accept)?;
Ok(Accepting { send, recv })
}

/// Get the underlying service endpoint
pub fn into_inner(self) -> C {
Expand Down Expand Up @@ -309,7 +323,7 @@ where
{
let server = RpcServer::<S, C>::new(conn);
loop {
let (req, chan) = server.accept().await?;
let (req, chan) = server.accept().await?.read_first().await?;
let target = target.clone();
handler(chan, req, target).await?;
}
Expand Down
2 changes: 1 addition & 1 deletion tests/flume.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async fn flume_channel_mapped_bench() -> anyhow::Result<()> {
tokio::task::spawn(async move {
let service = ComputeService;
loop {
let (req, chan) = server.accept().await?;
let (req, chan) = server.accept().await?.read_first().await?;
let service = service.clone();
tokio::spawn(async move {
let req: OuterRequest = req;
Expand Down
6 changes: 4 additions & 2 deletions tests/hyper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,10 @@ async fn hyper_channel_errors() -> anyhow::Result<()> {
let (res_tx, res_rx) = flume::unbounded();
let handle = tokio::spawn(async move {
loop {
let x = server.accept().await;
let res = match x {
let Ok(x) = server.accept().await else {
continue;
};
let res = match x.read_first().await {
Ok((req, chan)) => match req {
TestRequest::BigRequest(req) => {
chan.rpc(req, TestService, TestService::big).await
Expand Down
6 changes: 3 additions & 3 deletions tests/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ impl ComputeService {
let s = server;
let service = ComputeService;
loop {
let (req, chan) = s.accept().await?;
let (req, chan) = s.accept().await?.read_first().await?;
let service = service.clone();
tokio::spawn(async move { Self::handle_rpc_request(service, req, chan).await });
}
Expand Down Expand Up @@ -206,7 +206,7 @@ impl ComputeService {
let service = ComputeService;
while received < count {
received += 1;
let (req, chan) = s.accept().await?;
let (req, chan) = s.accept().await?.read_first().await?;
let service = service.clone();
tokio::spawn(async move {
use ComputeRequest::*;
Expand Down Expand Up @@ -236,7 +236,7 @@ impl ComputeService {
let service = ComputeService;
let request_stream = stream! {
loop {
yield s2.accept().await;
yield s2.accept().await?.read_first().await;
}
};
let process_stream = request_stream.map(move |r| {
Expand Down
2 changes: 1 addition & 1 deletion tests/slow_math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl ComputeService {
let s = server;
let service = ComputeService;
loop {
let (req, chan) = s.accept().await?;
let (req, chan) = s.accept().await?.read_first().await?;
use ComputeRequest::*;
let service = service.clone();
#[rustfmt::skip]
Expand Down
2 changes: 1 addition & 1 deletion tests/try.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ async fn try_server_streaming() -> anyhow::Result<()> {
let server = RpcServer::<TryService, _>::new(server);
let server_handle = tokio::task::spawn(async move {
loop {
let (req, chan) = server.accept().await?;
let (req, chan) = server.accept().await?.read_first().await?;
let handler = Handler;
match req {
TryRequest::StreamN(req) => {
Expand Down

0 comments on commit c2520b8

Please sign in to comment.