diff --git a/libvrs/src/client.rs b/libvrs/src/client.rs index fd13478..4fb3b57 100644 --- a/libvrs/src/client.rs +++ b/libvrs/src/client.rs @@ -11,18 +11,27 @@ use tracing::{debug, error}; pub struct Client { /// Sender half to send messages to shared task hdl_tx: mpsc::Sender, - /// Cancellation token to shutdown event loop and handle - cancel_token: CancellationToken, + /// Cancellation token to shutdown async task + cancel: CancellationToken, } /// Errors from interacting with [Client] #[derive(thiserror::Error, Debug)] pub enum Error { - #[error("Failed to on mpsc - {0}")] + #[error("Failed to send on mpsc - {0}")] MpscSendError(#[from] tokio::sync::mpsc::error::SendError), #[error("Failed to recv on oneshot - {0}")] OneShotRecvError(#[from] tokio::sync::oneshot::error::RecvError), + + #[error("{0}")] + IOError(#[from] std::io::Error), + + #[error("Internal inconsistency - {0}")] + InternalError(String), + + #[error("Connection disconnected")] + DisconnectedError, } /// Messages processed by event loop @@ -35,12 +44,6 @@ pub enum Event { }, /// Event when receiving response from remote RecvResponse(Response), - /// Event when reading on connection results in IO error - RecvError(std::io::Error), - /// Event when receiving request from remote - RecvRequest(Request), - /// Event when connection with runtime disconnects - DisconnectedFromRuntime, } /// The state of active [Client] @@ -52,112 +55,127 @@ struct State { inflight_reqs: HashMap>, /// Next request id to use next_req_id: u32, - /// Cancellation token used to shutdown event loop - cancel_token: CancellationToken, } impl Client { /// Create new client from connection transport between client and runtime pub fn new(conn: Connection) -> Self { let (hdl_tx, hdl_rx) = mpsc::channel(32); - let cancel_token = CancellationToken::new(); - let state = State::new(conn, cancel_token.clone()); - tokio::spawn(run(state, hdl_rx)); - Self { - hdl_tx, - cancel_token, - } + let cancel = CancellationToken::new(); + + let cancel_clone = cancel.clone(); + let state = State::new(conn); + tokio::spawn(async move { + tokio::select! { + res = run(state, hdl_rx) => { + if let Err(e) = res { + eprintln!("Client terminated with err - {e}"); + } + }, + _ = cancel_clone.cancelled() => { + debug!("terminating client..."); + } + } + }); + Self { hdl_tx, cancel } } /// Dispatch a request - pub async fn request(&self, contents: lyric::Form) -> Result { - debug!("request contents = {:?}", contents); + pub async fn request(&self, req: lyric::Form) -> Result { + debug!("request req = {}", req); let (resp_tx, resp_rx) = oneshot::channel(); - let ev = Event::SendRequest { - req: contents, - resp_tx, - }; - self.hdl_tx.send(ev).await?; + self.hdl_tx + .send(Event::SendRequest { req, resp_tx }) + .await?; Ok(resp_rx.await?) } + /// Detect if client has terminated + pub async fn closed(&self) { + self.hdl_tx.closed().await + } + /// Initiate Shutdown. The future completes when shutdown is complete pub async fn shutdown(&self) { debug!("shutdown - start"); - self.cancel_token.cancel(); - let _ = self.hdl_tx.closed().await; // wait until rx drop + self.cancel.cancel(); + let _ = self.hdl_tx.closed().await; // wait until rx drop in `run` debug!("shutdown - done"); } } -/// Start client side event loop -async fn run(mut state: State, mut hdl_rx: mpsc::Receiver) { +/// Run client task over command channel and connection +async fn run(mut state: State, mut hdl_rx: mpsc::Receiver) -> Result<(), Error> { loop { let ev = tokio::select! { Some(e) = hdl_rx.recv() => e, msg = state.conn.recv() => match msg { - Some(msg) => msg.map(Event::from).unwrap_or_else(Event::RecvError), - None => Event::DisconnectedFromRuntime, + Some(msg) => msg.map(Event::try_from)??, + None => return Err(Error::DisconnectedError), }, - _ = state.cancel_token.cancelled() => { - break; - } }; - state.handle_event(ev).await; + state.handle_event(ev).await?; } } impl State { - fn new(conn: Connection, cancel_token: CancellationToken) -> Self { + fn new(conn: Connection) -> Self { Self { conn, next_req_id: 0, inflight_reqs: HashMap::new(), - cancel_token, } } - async fn handle_event(&mut self, e: Event) { - debug!("received {:?}", e); - use Event::*; + async fn handle_event(&mut self, e: Event) -> Result<(), Error> { + debug!("handle_event e = {:?}", e); match e { - SendRequest { + Event::SendRequest { req: contents, resp_tx, - } => { - let req = Request { - id: self.next_req_id, - contents, - }; - self.next_req_id += 1; - self.inflight_reqs.insert(req.id, resp_tx); - let _ = self.conn.send(&Message::Request(req)).await; - } - RecvResponse(resp) => match self.inflight_reqs.remove(&resp.req_id) { - Some(tx) => { - let _ = tx.send(resp); - } - None => { - error!("Received unexpected response for request - {:?}", resp); - } - }, - RecvError(e) => { - error!("Encountered error - {}", e); - } - RecvRequest { .. } => panic!("Unimplemented - received request from runtime"), - DisconnectedFromRuntime => { - debug!("shutting down event loop..."); - self.cancel_token.cancel(); + } => self.handle_request(contents, resp_tx).await, + Event::RecvResponse(resp) => self.handle_recv_response(resp).await, + } + } + + /// Handle a send request event + async fn handle_request( + &mut self, + contents: lyric::Form, + resp_tx: oneshot::Sender, + ) -> Result<(), Error> { + let req = Request { + id: self.next_req_id, + contents, + }; + self.next_req_id += 1; + self.inflight_reqs.insert(req.id, resp_tx); + Ok(self.conn.send(&Message::Request(req)).await?) + } + + /// Handle a recv response event + async fn handle_recv_response(&mut self, resp: Response) -> Result<(), Error> { + match self.inflight_reqs.remove(&resp.req_id) { + Some(tx) => { + let _ = tx.send(resp); + Ok(()) } + None => Err(Error::InternalError(format!( + "Received unexpected response for request - {:?}", + resp + ))), } } } -impl From for Event { - fn from(msg: Message) -> Self { - match msg { - Message::Response(resp) => Self::RecvResponse(resp), - Message::Request(req) => Self::RecvRequest(req), +impl TryFrom for Event { + type Error = Error; + fn try_from(value: Message) -> Result { + match value { + Message::Response(resp) => Ok(Self::RecvResponse(resp)), + Message::Request(_) => Err(Error::InternalError( + "Client unexpectedly received Message::Request".to_string(), + )), } } } @@ -226,6 +244,10 @@ mod test { .await .expect("Request should be notified that remote connection was dropped before timeout"); - assert!(matches!(resp, Err(Error::OneShotRecvError(_)))); + assert_matches!( + resp, + Err(Error::OneShotRecvError(_)), + "Request should error when connection terminates" + ); } }