diff --git a/roles/tests-integration/lib/mod.rs b/roles/tests-integration/lib/mod.rs index aaa489e50..5865635e0 100644 --- a/roles/tests-integration/lib/mod.rs +++ b/roles/tests-integration/lib/mod.rs @@ -30,7 +30,7 @@ pub async fn start_sniffer( identifier: String, upstream: SocketAddr, check_on_drop: bool, - intercept_message: Option>, + action: Option, ) -> (Sniffer, SocketAddr) { let listening_address = get_available_address(); let sniffer = Sniffer::new( @@ -38,7 +38,7 @@ pub async fn start_sniffer( listening_address, upstream, check_on_drop, - intercept_message, + action, ) .await; let sniffer_clone = sniffer.clone(); diff --git a/roles/tests-integration/lib/sniffer.rs b/roles/tests-integration/lib/sniffer.rs index 6460a121f..4eb5117f4 100644 --- a/roles/tests-integration/lib/sniffer.rs +++ b/roles/tests-integration/lib/sniffer.rs @@ -59,7 +59,42 @@ pub struct Sniffer { messages_from_downstream: MessagesAggregator, messages_from_upstream: MessagesAggregator, check_on_drop: bool, - intercept_messages: Vec, + action: Option, +} + +/// Represents an action that [`Sniffer`] can take on intercepted messages. +#[derive(Debug, Clone)] +pub enum Action { + /// Blocks the message stream after encountering a specific message. + BlockFromMessage(BlockFromMessage), + /// Intercepts and modifies a message before forwarding it. + InterceptMessage(Box), +} + +/// Defines an action that blocks the message stream after detecting a specific message. +#[derive(Debug, Clone)] +pub struct BlockFromMessage { + direction: MessageDirection, + expected_message_type: MsgType, +} + +impl BlockFromMessage { + /// Creates a new [`BlockFromMessage`] action. + /// + /// - `direction`: The direction of the message stream to block. + /// - `expected_message_type`: The type of message after which the stream should be blocked. + pub fn new(direction: MessageDirection, expected_message_type: MsgType) -> Self { + BlockFromMessage { + direction, + expected_message_type, + } + } +} + +impl From for Action { + fn from(value: BlockFromMessage) -> Self { + Action::BlockFromMessage(value) + } } /// Allows [`Sniffer`] to replace some intercepted message before forwarding it. @@ -89,6 +124,12 @@ impl InterceptMessage { } } +impl From for Action { + fn from(value: InterceptMessage) -> Self { + Action::InterceptMessage(Box::new(value)) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub enum MessageDirection { ToDownstream, @@ -103,7 +144,7 @@ impl Sniffer { listening_address: SocketAddr, upstream_address: SocketAddr, check_on_drop: bool, - intercept_messages: Option>, + action: Option, ) -> Self { Self { identifier, @@ -112,7 +153,7 @@ impl Sniffer { messages_from_downstream: MessagesAggregator::new(), messages_from_upstream: MessagesAggregator::new(), check_on_drop, - intercept_messages: intercept_messages.unwrap_or_default(), + action, } } @@ -134,10 +175,10 @@ impl Sniffer { .expect("Failed to create upstream"); let downstream_messages = self.messages_from_downstream.clone(); let upstream_messages = self.messages_from_upstream.clone(); - let intercept_messages = self.intercept_messages.clone(); + let action = self.action.clone(); let _ = select! { - r = Self::recv_from_down_send_to_up(downstream_receiver, upstream_sender, downstream_messages, intercept_messages.clone()) => r, - r = Self::recv_from_up_send_to_down(upstream_receiver, downstream_sender, upstream_messages, intercept_messages) => r, + r = Self::recv_from_down_send_to_up(downstream_receiver, upstream_sender, downstream_messages, action.clone()) => r, + r = Self::recv_from_up_send_to_down(upstream_receiver, downstream_sender, upstream_messages, action) => r, }; // wait a bit so we dont drop the sniffer before the test has finished sleep(std::time::Duration::from_secs(1)).await; @@ -215,30 +256,56 @@ impl Sniffer { recv: Receiver, send: Sender, downstream_messages: MessagesAggregator, - intercept_messages: Vec, + action: Option, ) -> Result<(), SnifferError> { + let mut blocked = false; while let Ok(mut frame) = recv.recv().await { + if blocked { + continue; + } let (msg_type, msg) = Self::message_from_frame(&mut frame); - let intercept_message = intercept_messages.iter().find(|im| { - im.direction == MessageDirection::ToUpstream && im.expected_message_type == msg_type + let action = action.as_ref().and_then(|action| match action { + Action::BlockFromMessage(bm) + if bm.direction == MessageDirection::ToUpstream + && bm.expected_message_type == msg_type => + { + Some(action) + } + + Action::InterceptMessage(im) + if im.direction == MessageDirection::ToUpstream + && im.expected_message_type == msg_type => + { + Some(action) + } + + _ => None, }); - if let Some(intercept_message) = intercept_message { - let intercept_frame = StandardEitherFrame::>::Sv2( - Sv2Frame::from_message( - intercept_message.replacement_message.clone(), - intercept_message.replacement_message.message_type(), - 0, - false, - ) - .expect("Failed to create the frame"), - ); - downstream_messages.add_message( - intercept_message.replacement_message.message_type(), - intercept_message.replacement_message.clone(), - ); - send.send(intercept_frame) - .await - .map_err(|_| SnifferError::UpstreamClosed)?; + if let Some(ref action) = action { + match action { + Action::BlockFromMessage(_) => { + blocked = true; + continue; + } + Action::InterceptMessage(intercept_message) => { + let intercept_frame = StandardEitherFrame::>::Sv2( + Sv2Frame::from_message( + intercept_message.replacement_message.clone(), + intercept_message.replacement_message.message_type(), + 0, + false, + ) + .expect("Failed to create the frame"), + ); + downstream_messages.add_message( + intercept_message.replacement_message.message_type(), + intercept_message.replacement_message.clone(), + ); + send.send(intercept_frame) + .await + .map_err(|_| SnifferError::UpstreamClosed)?; + } + } } else { downstream_messages.add_message(msg_type, msg); send.send(frame) @@ -253,31 +320,58 @@ impl Sniffer { recv: Receiver, send: Sender, upstream_messages: MessagesAggregator, - intercept_messages: Vec, + action: Option, ) -> Result<(), SnifferError> { + let mut blocked = false; while let Ok(mut frame) = recv.recv().await { + if blocked { + continue; + } let (msg_type, msg) = Self::message_from_frame(&mut frame); - let intercept_message = intercept_messages.iter().find(|im| { - im.direction == MessageDirection::ToDownstream - && im.expected_message_type == msg_type + + let action = action.as_ref().and_then(|action| match action { + Action::BlockFromMessage(bm) + if bm.direction == MessageDirection::ToDownstream + && bm.expected_message_type == msg_type => + { + Some(action) + } + + Action::InterceptMessage(im) + if im.direction == MessageDirection::ToDownstream + && im.expected_message_type == msg_type => + { + Some(action) + } + + _ => None, }); - if let Some(intercept_message) = intercept_message { - let intercept_frame = StandardEitherFrame::>::Sv2( - Sv2Frame::from_message( - intercept_message.replacement_message.clone(), - intercept_message.replacement_message.message_type(), - 0, - false, - ) - .expect("Failed to create the frame"), - ); - upstream_messages.add_message( - intercept_message.replacement_message.message_type(), - intercept_message.replacement_message.clone(), - ); - send.send(intercept_frame) - .await - .map_err(|_| SnifferError::DownstreamClosed)?; + + if let Some(ref action) = action { + match action { + Action::BlockFromMessage(_) => { + blocked = true; + continue; + } + Action::InterceptMessage(intercept_message) => { + let intercept_frame = StandardEitherFrame::>::Sv2( + Sv2Frame::from_message( + intercept_message.replacement_message.clone(), + intercept_message.replacement_message.message_type(), + 0, + false, + ) + .expect("Failed to create the frame"), + ); + upstream_messages.add_message( + intercept_message.replacement_message.message_type(), + intercept_message.replacement_message.clone(), + ); + send.send(intercept_frame) + .await + .map_err(|_| SnifferError::DownstreamClosed)?; + } + } } else { upstream_messages.add_message(msg_type, msg); send.send(frame) diff --git a/roles/tests-integration/tests/sniffer_integration.rs b/roles/tests-integration/tests/sniffer_integration.rs index e81f771d9..1e578dd5f 100644 --- a/roles/tests-integration/tests/sniffer_integration.rs +++ b/roles/tests-integration/tests/sniffer_integration.rs @@ -43,7 +43,7 @@ async fn test_sniffer_intercept_to_downstream() { // this sniffer will replace SetupConnectionSuccess with SetupConnectionError let (_sniffer_a, sniffer_a_addr) = - start_sniffer("A".to_string(), tp_addr, false, Some(vec![intercept])).await; + start_sniffer("A".to_string(), tp_addr, false, Some(intercept.into())).await; // this sniffer will assert SetupConnectionSuccess was correctly replaced with // SetupConnectionError @@ -83,7 +83,7 @@ async fn test_sniffer_intercept_to_upstream() { ); let (sniffer_a, sniffer_a_addr) = - start_sniffer("A".to_string(), tp_addr, false, Some(vec![intercept])).await; + start_sniffer("A".to_string(), tp_addr, false, Some(intercept.into())).await; let (_sniffer_b, sniffer_b_addr) = start_sniffer("B".to_string(), sniffer_a_addr, false, None).await; @@ -142,3 +142,44 @@ async fn test_sniffer_wait_for_message_type_with_remove() { false ); } + +/// Verifies that [`Sniffer`] can intercept and block a message stream. +/// +/// This test sets up a chain where a message from the Template Provider (TP) +/// passes through two sniffers (`sniffer_a` and `sniffer_b`) before reaching the Pool. +/// +/// - `sniffer_a` is configured to block `SetupConnectionSuccess` messages directed downstream. +/// - `sniffer_b` should receive no messages after initial setup, ensuring the block works. +/// +/// **Flow:** +/// `TP -> sniffer_a -> sniffer_b -> Pool` +#[tokio::test] +async fn test_sniffer_blocks_message() { + start_tracing(); + let (_tp, tp_addr) = start_template_provider(None); + + // Define an action to block SetupConnectionSuccess messages going downstream. + let block_from_message = BlockFromMessage::new( + MessageDirection::ToDownstream, + MESSAGE_TYPE_SETUP_CONNECTION_SUCCESS, + ); + + // `sniffer_a` intercepts and blocks `SetupConnectionSuccess` messages. + let (_sniffer_a, sniffer_a_addr) = start_sniffer( + "A".to_string(), + tp_addr, + false, + Some(block_from_message.into()), + ) + .await; + + // `sniffer_b` is placed downstream of `sniffer_a` and should receive nothing. + let (sniffer_b, sniffer_b_addr) = + start_sniffer("B".to_string(), sniffer_a_addr, false, None).await; + + // Start the Pool, connected to `sniffer_b`. + let _ = start_pool(Some(sniffer_b_addr)).await; + + // Assert that `sniffer_b` does not receive any messages, confirming `sniffer_a`'s block works. + assert!(sniffer_b.next_message_from_upstream().is_none()); +}