Skip to content

Commit

Permalink
refactor: Simplify client.rs API and internals
Browse files Browse the repository at this point in the history
  • Loading branch information
leoshimo committed Dec 7, 2023
1 parent 728a718 commit 69807d1
Showing 1 changed file with 93 additions and 71 deletions.
164 changes: 93 additions & 71 deletions libvrs/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,27 @@ use tracing::{debug, error};
pub struct Client {
/// Sender half to send messages to shared task
hdl_tx: mpsc::Sender<Event>,
/// Cancellation token to shutdown event loop and handle
cancel_token: CancellationToken,
/// Cancellation token to shutdown async task
cancel: CancellationToken,
}

/// Errors from interacting with [Client]
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("Failed to on mpsc - {0}")]
#[error("Failed to send on mpsc - {0}")]
MpscSendError(#[from] tokio::sync::mpsc::error::SendError<Event>),

#[error("Failed to recv on oneshot - {0}")]
OneShotRecvError(#[from] tokio::sync::oneshot::error::RecvError),

#[error("{0}")]
IOError(#[from] std::io::Error),

#[error("Internal inconsistency - {0}")]
InternalError(String),

#[error("Connection disconnected")]
DisconnectedError,
}

/// Messages processed by event loop
Expand All @@ -35,12 +44,6 @@ pub enum Event {
},
/// Event when receiving response from remote
RecvResponse(Response),
/// Event when reading on connection results in IO error
RecvError(std::io::Error),
/// Event when receiving request from remote
RecvRequest(Request),
/// Event when connection with runtime disconnects
DisconnectedFromRuntime,
}

/// The state of active [Client]
Expand All @@ -52,112 +55,127 @@ struct State {
inflight_reqs: HashMap<u32, oneshot::Sender<Response>>,
/// Next request id to use
next_req_id: u32,
/// Cancellation token used to shutdown event loop
cancel_token: CancellationToken,
}

impl Client {
/// Create new client from connection transport between client and runtime
pub fn new(conn: Connection) -> Self {
let (hdl_tx, hdl_rx) = mpsc::channel(32);
let cancel_token = CancellationToken::new();
let state = State::new(conn, cancel_token.clone());
tokio::spawn(run(state, hdl_rx));
Self {
hdl_tx,
cancel_token,
}
let cancel = CancellationToken::new();

let cancel_clone = cancel.clone();
let state = State::new(conn);
tokio::spawn(async move {
tokio::select! {
res = run(state, hdl_rx) => {
if let Err(e) = res {
eprintln!("Client terminated with err - {e}");
}
},
_ = cancel_clone.cancelled() => {
debug!("terminating client...");
}
}
});
Self { hdl_tx, cancel }
}

/// Dispatch a request
pub async fn request(&self, contents: lyric::Form) -> Result<Response, Error> {
debug!("request contents = {:?}", contents);
pub async fn request(&self, req: lyric::Form) -> Result<Response, Error> {
debug!("request req = {}", req);
let (resp_tx, resp_rx) = oneshot::channel();
let ev = Event::SendRequest {
req: contents,
resp_tx,
};
self.hdl_tx.send(ev).await?;
self.hdl_tx
.send(Event::SendRequest { req, resp_tx })
.await?;
Ok(resp_rx.await?)
}

/// Detect if client has terminated
pub async fn closed(&self) {
self.hdl_tx.closed().await
}

/// Initiate Shutdown. The future completes when shutdown is complete
pub async fn shutdown(&self) {
debug!("shutdown - start");
self.cancel_token.cancel();
let _ = self.hdl_tx.closed().await; // wait until rx drop
self.cancel.cancel();
let _ = self.hdl_tx.closed().await; // wait until rx drop in `run`
debug!("shutdown - done");
}
}

/// Start client side event loop
async fn run(mut state: State, mut hdl_rx: mpsc::Receiver<Event>) {
/// Run client task over command channel and connection
async fn run(mut state: State, mut hdl_rx: mpsc::Receiver<Event>) -> Result<(), Error> {
loop {
let ev = tokio::select! {
Some(e) = hdl_rx.recv() => e,
msg = state.conn.recv() => match msg {
Some(msg) => msg.map(Event::from).unwrap_or_else(Event::RecvError),
None => Event::DisconnectedFromRuntime,
Some(msg) => msg.map(Event::try_from)??,
None => return Err(Error::DisconnectedError),
},
_ = state.cancel_token.cancelled() => {
break;
}
};
state.handle_event(ev).await;
state.handle_event(ev).await?;
}
}

impl State {
fn new(conn: Connection, cancel_token: CancellationToken) -> Self {
fn new(conn: Connection) -> Self {
Self {
conn,
next_req_id: 0,
inflight_reqs: HashMap::new(),
cancel_token,
}
}

async fn handle_event(&mut self, e: Event) {
debug!("received {:?}", e);
use Event::*;
async fn handle_event(&mut self, e: Event) -> Result<(), Error> {
debug!("handle_event e = {:?}", e);
match e {
SendRequest {
Event::SendRequest {
req: contents,
resp_tx,
} => {
let req = Request {
id: self.next_req_id,
contents,
};
self.next_req_id += 1;
self.inflight_reqs.insert(req.id, resp_tx);
let _ = self.conn.send(&Message::Request(req)).await;
}
RecvResponse(resp) => match self.inflight_reqs.remove(&resp.req_id) {
Some(tx) => {
let _ = tx.send(resp);
}
None => {
error!("Received unexpected response for request - {:?}", resp);
}
},
RecvError(e) => {
error!("Encountered error - {}", e);
}
RecvRequest { .. } => panic!("Unimplemented - received request from runtime"),
DisconnectedFromRuntime => {
debug!("shutting down event loop...");
self.cancel_token.cancel();
} => self.handle_request(contents, resp_tx).await,
Event::RecvResponse(resp) => self.handle_recv_response(resp).await,
}
}

/// Handle a send request event
async fn handle_request(
&mut self,
contents: lyric::Form,
resp_tx: oneshot::Sender<Response>,
) -> Result<(), Error> {
let req = Request {
id: self.next_req_id,
contents,
};
self.next_req_id += 1;
self.inflight_reqs.insert(req.id, resp_tx);
Ok(self.conn.send(&Message::Request(req)).await?)
}

/// Handle a recv response event
async fn handle_recv_response(&mut self, resp: Response) -> Result<(), Error> {
match self.inflight_reqs.remove(&resp.req_id) {
Some(tx) => {
let _ = tx.send(resp);
Ok(())
}
None => Err(Error::InternalError(format!(
"Received unexpected response for request - {:?}",
resp
))),
}
}
}

impl From<Message> for Event {
fn from(msg: Message) -> Self {
match msg {
Message::Response(resp) => Self::RecvResponse(resp),
Message::Request(req) => Self::RecvRequest(req),
impl TryFrom<Message> for Event {
type Error = Error;
fn try_from(value: Message) -> Result<Self, Self::Error> {
match value {
Message::Response(resp) => Ok(Self::RecvResponse(resp)),
Message::Request(_) => Err(Error::InternalError(
"Client unexpectedly received Message::Request".to_string(),
)),
}
}
}
Expand Down Expand Up @@ -226,6 +244,10 @@ mod test {
.await
.expect("Request should be notified that remote connection was dropped before timeout");

assert!(matches!(resp, Err(Error::OneShotRecvError(_))));
assert_matches!(
resp,
Err(Error::OneShotRecvError(_)),
"Request should error when connection terminates"
);
}
}

0 comments on commit 69807d1

Please sign in to comment.