diff --git a/examples/log.rs b/examples/log.rs index 09f9ab0..7928b32 100644 --- a/examples/log.rs +++ b/examples/log.rs @@ -39,8 +39,8 @@ async fn main() { shutdown_signal: shutdown_signal(), request_handler, response_handler, - incoming_message_handler: |msg| msg, - outgoing_message_handler: |msg| msg, + incoming_message_handler: |msg| Some(msg), + outgoing_message_handler: |msg| Some(msg), upstream_proxy: None, ca, }; diff --git a/examples/noop.rs b/examples/noop.rs index 3566fdb..920c384 100644 --- a/examples/noop.rs +++ b/examples/noop.rs @@ -14,8 +14,8 @@ async fn main() { let request_handler = |req| RequestOrResponse::Request(req); let response_handler = |res| res; - let incoming_message_handler = |msg| msg; - let outgoing_message_handler = |msg| msg; + let incoming_message_handler = |msg| Some(msg); + let outgoing_message_handler = |msg| Some(msg); let mut private_key_bytes: &[u8] = include_bytes!("ca/hudsucker.key"); let mut ca_cert_bytes: &[u8] = include_bytes!("ca/hudsucker.pem"); diff --git a/examples/upstream_proxy.rs b/examples/upstream_proxy.rs index 934c7ad..c78a4ad 100644 --- a/examples/upstream_proxy.rs +++ b/examples/upstream_proxy.rs @@ -44,8 +44,8 @@ async fn main() { shutdown_signal: shutdown_signal(), request_handler, response_handler, - incoming_message_handler: |msg| msg, - outgoing_message_handler: |msg| msg, + incoming_message_handler: |msg| Some(msg), + outgoing_message_handler: |msg| Some(msg), upstream_proxy: None, ca: ca.clone(), }; @@ -54,8 +54,8 @@ async fn main() { listen_addr: SocketAddr::from(([127, 0, 0, 1], 3000)), request_handler: |req| RequestOrResponse::Request(req), response_handler: |res| res, - incoming_message_handler: |msg| msg, - outgoing_message_handler: |msg| msg, + incoming_message_handler: |msg| Some(msg), + outgoing_message_handler: |msg| Some(msg), shutdown_signal: shutdown_signal(), upstream_proxy: Some(UpstreamProxy::new( Intercept::All, diff --git a/src/lib.rs b/src/lib.rs index 7b09f0c..ab56ee0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -72,9 +72,16 @@ impl ResponseHandler for T where /// Handler for websocket messages. /// -/// The handler will be called for each websocket message. It can return a modified message. -pub trait MessageHandler: FnMut(Message) -> Message + Send + Sync + Clone + 'static {} -impl MessageHandler for T where T: FnMut(Message) -> Message + Send + Sync + Clone + 'static {} +/// The handler will be called for each websocket message. It can return an optional modified +/// message. If None is returned the message will not be forwarded. +pub trait MessageHandler: + FnMut(Message) -> Option + Send + Sync + Clone + 'static +{ +} +impl MessageHandler for T where + T: FnMut(Message) -> Option + Send + Sync + Clone + 'static +{ +} /// Configuration for the proxy server. /// diff --git a/src/proxy.rs b/src/proxy.rs index c1471ce..4f5d410 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -161,7 +161,11 @@ where while let Some(message) = server_stream.next().await { match message { Ok(message) => { - let message = incoming_message_handler(message); + let message = match incoming_message_handler(message) { + Some(message) => message, + None => continue, + }; + match client_sink.send(message).await { Err(tungstenite::Error::ConnectionClosed) => (), Err(e) => error!("websocket send error: {}", e), @@ -177,7 +181,11 @@ where while let Some(message) = client_stream.next().await { match message { Ok(message) => { - let message = outgoing_message_handler(message); + let message = match outgoing_message_handler(message) { + Some(message) => message, + None => continue, + }; + match server_sink.send(message).await { Err(tungstenite::Error::ConnectionClosed) => (), Err(e) => error!("websocket send error: {}", e),