From d3ebc14607b5cf21d12601bb36eea30cec74996a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20R=C3=BCdiger?= Date: Thu, 25 Mar 2021 11:25:35 +0100 Subject: [PATCH] typed response payload --- .gitignore | 1 + src/formats.rs | 6 +-- src/lib.rs | 117 ++++++++++++++++++++++++++++--------------------- 3 files changed, 70 insertions(+), 54 deletions(-) diff --git a/.gitignore b/.gitignore index ea8c4bf..de358ff 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ /target +.vscode/ diff --git a/src/formats.rs b/src/formats.rs index 5bbbf97..5de42ad 100644 --- a/src/formats.rs +++ b/src/formats.rs @@ -38,9 +38,9 @@ pub struct RequestBody<'a> { #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(tag = "type")] #[serde(rename_all = "camelCase")] -pub enum Outgoing { +pub enum Outgoing { #[serde(rename_all = "camelCase")] - Next { request_id: ReqId, payload: Value }, + Next { request_id: ReqId, payload: Resp }, #[serde(rename_all = "camelCase")] Complete { request_id: ReqId }, #[serde(rename_all = "camelCase")] @@ -64,7 +64,7 @@ pub enum ErrorKind { }, } -impl Outgoing { +impl Outgoing { #[cfg(test)] pub fn request_id(&self) -> ReqId { match self { diff --git a/src/lib.rs b/src/lib.rs index 56d80e3..05555ac 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,7 +38,7 @@ const INTER_STREAM_FAIRNESS: u64 = 64; pub trait Service { type Req: DeserializeOwned; - type Resp: Serialize + 'static; + type Resp: Serialize + Send + 'static; type Error: Serialize + 'static; type Ctx: Clone; @@ -48,7 +48,7 @@ pub trait Service { req: Self::Req, ) -> BoxStream<'static, Result>; - fn boxed(self) -> BoxedService + fn boxed(self) -> BoxedService where Self: Send + Sized + Sync + 'static, { @@ -56,20 +56,20 @@ pub trait Service { } } -pub trait WebsocketService { +pub trait WebsocketService { fn serve_ws( &self, ctx: Ctx, raw_req: Value, service_id: &str, - ) -> BoxStream<'static, Result>; + ) -> BoxStream<'static, Result>; } -impl WebsocketService for S +impl WebsocketService for S where S: Service, Req: DeserializeOwned, - Resp: Serialize + 'static, + Resp: Serialize + Send + 'static, Ctx: Clone, { fn serve_ws( @@ -77,7 +77,7 @@ where ctx: Ctx, raw_req: Value, service_id: &str, - ) -> BoxStream<'static, Result> { + ) -> BoxStream<'static, Result> { trace!( "Serving raw request for service {}: {:?}", service_id, @@ -86,16 +86,11 @@ where match serde_json::from_value(raw_req) { Ok(req) => self .serve(ctx, req) - .map(|resp_result| { - resp_result - .map(|resp| { - serde_json::to_value(&resp) - .expect("Could not serialize service response") - }) - .map_err(|err| ErrorKind::ServiceError { - value: serde_json::to_value(&err) - .expect("Could not serialize service error response"), - }) + .map(|res| { + res.map_err(|err| ErrorKind::ServiceError { + value: serde_json::to_value(&err) + .expect("Could not serialize service error response"), + }) }) .boxed(), Err(cause) => { @@ -110,11 +105,11 @@ where } } -pub type BoxedService = Box + Send + Sync>; +pub type BoxedService = Box + Send + Sync>; -pub async fn serve( +pub async fn serve( ws: warp::ws::Ws, - services: Arc>>, + services: Arc>>, ctx: Ctx, ) -> Result { // Set the max frame size to 64 MB (defaults to 16 MB which we have hit at CTA) @@ -126,11 +121,10 @@ pub async fn serve( // on_upgrade does not take in errors any longer } -#[allow(clippy::cognitive_complexity)] -fn client_connected( +fn client_connected( ws: WebSocket, ctx: Ctx, - services: Arc>>, + services: Arc>>, ) -> impl Future> { let (ws_out, ws_in) = ws.split(); @@ -225,8 +219,6 @@ fn client_connected( }) } -// Wtf, clippy? -#[allow(clippy::cognitive_complexity)] fn cancel_response_stream(snd_cancel: oneshot::Sender<()>) { if snd_cancel.is_canceled() { trace!("Not trying to cancel response stream whose cancel rcv has already dropped") @@ -250,44 +242,41 @@ fn cancel_response_streams_close_channel( mux_in.close_channel(); } -fn serve_request_stream( - srv: &BoxedService, +fn serve_request_stream( + srv: &BoxedService, ctx: Ctx, service_id: &str, - req_id: ReqId, + request_id: ReqId, payload: Value, ) -> impl Stream> { let resp_stream = srv .serve_ws(ctx, payload, service_id) - .map(move |payload_result| match payload_result { + .map(move |res| match res { Ok(payload) => Outgoing::Next { - request_id: req_id, + request_id, payload, }, - Err(kind) => Outgoing::Error { - request_id: req_id, - kind, - }, + Err(kind) => Outgoing::Error { request_id, kind }, }); AssertUnwindSafe(resp_stream) .catch_unwind() - .map(move |msg_result| match msg_result { + .map(move |res| match res { Ok(msg) => msg, Err(_) => Outgoing::Error { - request_id: req_id, + request_id, kind: ErrorKind::InternalError, }, }) .chain(stream::once(future::ready(Outgoing::Complete { - request_id: req_id, + request_id, }))) .map(|env| Ok(Message::text(serde_json::to_string(&env).unwrap()))) } -fn serve_request( +fn serve_request( canceled: oneshot::Receiver<()>, - srv: &BoxedService, + srv: &BoxedService, ctx: Ctx, service_id: &str, req_id: ReqId, @@ -317,7 +306,7 @@ fn serve_error( error_kind: ErrorKind, output: impl Sink>, ) -> impl Future { - let msg = Outgoing::Error { + let msg: Outgoing<()> = Outgoing::Error { request_id: req_id, kind: error_kind, }; @@ -361,7 +350,7 @@ mod tests { bad_field: String, } - #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] + #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)] struct Response(u64); struct TestService(); @@ -403,12 +392,38 @@ mod tests { } } - fn test_client( + #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)] + struct Response2(u64); + + struct TestService2(); + + impl TestService2 { + fn new() -> TestService2 { + TestService2() + } + } + + impl Service for TestService2 { + type Req = String; + type Resp = String; + type Error = String; + type Ctx = (); + + fn serve( + &self, + _ctx: (), + req: Self::Req, + ) -> BoxStream<'static, Result> { + stream::once(future::ok(req.chars().rev().collect::())).boxed() + } + } + + fn test_client( addr: SocketAddr, endpoint: &str, id: u64, req: Req, - ) -> (Vec, Outgoing) { + ) -> (Vec, Outgoing) { let addr = format!("ws://{}/test_ws", addr); let client = ClientBuilder::new(&*addr) .expect("Could not setup client") @@ -430,14 +445,14 @@ mod tests { .send_message(&OwnedMessage::Text(req_env_json)) .expect("Could not send request"); - let mut completion: Option = None; + let mut completion: Option> = None; let msgs = receiver .incoming_messages() .filter_map(move |msg| { let msg_ok = msg.expect("Expected message but got websocket error"); if let OwnedMessage::Text(raw_resp) = msg_ok { - let resp_env: Outgoing = serde_json::from_str(&*raw_resp) + let resp_env: Outgoing = serde_json::from_str(&*raw_resp) .expect("Could not deserialize response envelope"); if resp_env.request_id().0 == id { Some(resp_env) @@ -452,16 +467,13 @@ mod tests { if let Outgoing::Next { .. } = env { true } else { - completion = Some(env.clone()); + completion = Some(env.to_owned()); false } }) .filter_map(|env| { if let Outgoing::Next { payload, .. } = env { - Some( - serde_json::from_value::(payload) - .expect("Could not deserialize response"), - ) + Some(payload) } else { None } @@ -471,7 +483,10 @@ mod tests { } async fn start_test_service() -> SocketAddr { - let services = Arc::new(maplit::btreemap! {"test" => TestService::new().boxed()}); + let services = Arc::new(maplit::btreemap! { + "test" => TestService::new().boxed(), + "test2" => TestService2::new().boxed(), + }); let ws = warp::path("test_ws") .and(warp::ws()) .and(warp::any().map(move || services.clone()))