Skip to content

Commit

Permalink
Add action construct and BlockFromMessage action
Browse files Browse the repository at this point in the history
This commit introduces an action construct that encapsulates all sniffer actions.
Additionally, it adds the `BlockFromMessage` action, allowing message streams
to be blocked after a specific message is encountered.

something
  • Loading branch information
Shourya742 committed Mar 2, 2025
1 parent 6dde6b9 commit 5453875
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 51 deletions.
4 changes: 2 additions & 2 deletions roles/tests-integration/lib/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ pub async fn start_sniffer(
identifier: String,
upstream: SocketAddr,
check_on_drop: bool,
intercept_message: Option<Vec<sniffer::InterceptMessage>>,
action: Option<Action>,
) -> (Sniffer, SocketAddr) {
let listening_address = get_available_address();
let sniffer = Sniffer::new(
identifier,
listening_address,
upstream,
check_on_drop,
intercept_message,
action,
)
.await;
let sniffer_clone = sniffer.clone();
Expand Down
188 changes: 141 additions & 47 deletions roles/tests-integration/lib/sniffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,42 @@ pub struct Sniffer {
messages_from_downstream: MessagesAggregator,
messages_from_upstream: MessagesAggregator,
check_on_drop: bool,
intercept_messages: Vec<InterceptMessage>,
action: Option<Action>,
}

/// Represents an action that [`Sniffer`] can take on intercepted messages.
#[derive(Debug, Clone)]
pub enum Action {
/// Blocks the message stream after encountering a specific message.
BlockFromMessage(BlockFromMessage),
/// Intercepts and modifies a message before forwarding it.
InterceptMessage(Box<InterceptMessage>),
}

/// Defines an action that blocks the message stream after detecting a specific message.
#[derive(Debug, Clone)]
pub struct BlockFromMessage {
direction: MessageDirection,
expected_message_type: MsgType,
}

impl BlockFromMessage {
/// Creates a new [`BlockFromMessage`] action.
///
/// - `direction`: The direction of the message stream to block.
/// - `expected_message_type`: The type of message after which the stream should be blocked.
pub fn new(direction: MessageDirection, expected_message_type: MsgType) -> Self {
BlockFromMessage {
direction,
expected_message_type,
}
}
}

impl From<BlockFromMessage> for Action {
fn from(value: BlockFromMessage) -> Self {
Action::BlockFromMessage(value)
}
}

/// Allows [`Sniffer`] to replace some intercepted message before forwarding it.
Expand Down Expand Up @@ -89,6 +124,12 @@ impl InterceptMessage {
}
}

impl From<InterceptMessage> for Action {
fn from(value: InterceptMessage) -> Self {
Action::InterceptMessage(Box::new(value))
}
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageDirection {
ToDownstream,
Expand All @@ -103,7 +144,7 @@ impl Sniffer {
listening_address: SocketAddr,
upstream_address: SocketAddr,
check_on_drop: bool,
intercept_messages: Option<Vec<InterceptMessage>>,
action: Option<Action>,
) -> Self {
Self {
identifier,
Expand All @@ -112,7 +153,7 @@ impl Sniffer {
messages_from_downstream: MessagesAggregator::new(),
messages_from_upstream: MessagesAggregator::new(),
check_on_drop,
intercept_messages: intercept_messages.unwrap_or_default(),
action,
}
}

Expand All @@ -134,10 +175,10 @@ impl Sniffer {
.expect("Failed to create upstream");
let downstream_messages = self.messages_from_downstream.clone();
let upstream_messages = self.messages_from_upstream.clone();
let intercept_messages = self.intercept_messages.clone();
let action = self.action.clone();
let _ = select! {
r = Self::recv_from_down_send_to_up(downstream_receiver, upstream_sender, downstream_messages, intercept_messages.clone()) => r,
r = Self::recv_from_up_send_to_down(upstream_receiver, downstream_sender, upstream_messages, intercept_messages) => r,
r = Self::recv_from_down_send_to_up(downstream_receiver, upstream_sender, downstream_messages, action.clone()) => r,
r = Self::recv_from_up_send_to_down(upstream_receiver, downstream_sender, upstream_messages, action) => r,
};
// wait a bit so we dont drop the sniffer before the test has finished
sleep(std::time::Duration::from_secs(1)).await;
Expand Down Expand Up @@ -215,30 +256,56 @@ impl Sniffer {
recv: Receiver<MessageFrame>,
send: Sender<MessageFrame>,
downstream_messages: MessagesAggregator,
intercept_messages: Vec<InterceptMessage>,
action: Option<Action>,
) -> Result<(), SnifferError> {
let mut blocked = false;
while let Ok(mut frame) = recv.recv().await {
if blocked {
continue;
}
let (msg_type, msg) = Self::message_from_frame(&mut frame);
let intercept_message = intercept_messages.iter().find(|im| {
im.direction == MessageDirection::ToUpstream && im.expected_message_type == msg_type
let action = action.as_ref().and_then(|action| match action {
Action::BlockFromMessage(bm)
if bm.direction == MessageDirection::ToUpstream
&& bm.expected_message_type == msg_type =>
{
Some(action)
}

Action::InterceptMessage(im)
if im.direction == MessageDirection::ToUpstream
&& im.expected_message_type == msg_type =>
{
Some(action)
}

_ => None,
});
if let Some(intercept_message) = intercept_message {
let intercept_frame = StandardEitherFrame::<AnyMessage<'_>>::Sv2(
Sv2Frame::from_message(
intercept_message.replacement_message.clone(),
intercept_message.replacement_message.message_type(),
0,
false,
)
.expect("Failed to create the frame"),
);
downstream_messages.add_message(
intercept_message.replacement_message.message_type(),
intercept_message.replacement_message.clone(),
);
send.send(intercept_frame)
.await
.map_err(|_| SnifferError::UpstreamClosed)?;
if let Some(ref action) = action {
match action {
Action::BlockFromMessage(_) => {
blocked = true;
continue;
}
Action::InterceptMessage(intercept_message) => {
let intercept_frame = StandardEitherFrame::<AnyMessage<'_>>::Sv2(
Sv2Frame::from_message(
intercept_message.replacement_message.clone(),
intercept_message.replacement_message.message_type(),
0,
false,
)
.expect("Failed to create the frame"),
);
downstream_messages.add_message(
intercept_message.replacement_message.message_type(),
intercept_message.replacement_message.clone(),
);
send.send(intercept_frame)
.await
.map_err(|_| SnifferError::UpstreamClosed)?;
}
}
} else {
downstream_messages.add_message(msg_type, msg);
send.send(frame)
Expand All @@ -253,31 +320,58 @@ impl Sniffer {
recv: Receiver<MessageFrame>,
send: Sender<MessageFrame>,
upstream_messages: MessagesAggregator,
intercept_messages: Vec<InterceptMessage>,
action: Option<Action>,
) -> Result<(), SnifferError> {
let mut blocked = false;
while let Ok(mut frame) = recv.recv().await {
if blocked {
continue;
}
let (msg_type, msg) = Self::message_from_frame(&mut frame);
let intercept_message = intercept_messages.iter().find(|im| {
im.direction == MessageDirection::ToDownstream
&& im.expected_message_type == msg_type

let action = action.as_ref().and_then(|action| match action {
Action::BlockFromMessage(bm)
if bm.direction == MessageDirection::ToDownstream
&& bm.expected_message_type == msg_type =>
{
Some(action)
}

Action::InterceptMessage(im)
if im.direction == MessageDirection::ToDownstream
&& im.expected_message_type == msg_type =>
{
Some(action)
}

_ => None,
});
if let Some(intercept_message) = intercept_message {
let intercept_frame = StandardEitherFrame::<AnyMessage<'_>>::Sv2(
Sv2Frame::from_message(
intercept_message.replacement_message.clone(),
intercept_message.replacement_message.message_type(),
0,
false,
)
.expect("Failed to create the frame"),
);
upstream_messages.add_message(
intercept_message.replacement_message.message_type(),
intercept_message.replacement_message.clone(),
);
send.send(intercept_frame)
.await
.map_err(|_| SnifferError::DownstreamClosed)?;

if let Some(ref action) = action {
match action {
Action::BlockFromMessage(_) => {
blocked = true;
continue;
}
Action::InterceptMessage(intercept_message) => {
let intercept_frame = StandardEitherFrame::<AnyMessage<'_>>::Sv2(
Sv2Frame::from_message(
intercept_message.replacement_message.clone(),
intercept_message.replacement_message.message_type(),
0,
false,
)
.expect("Failed to create the frame"),
);
upstream_messages.add_message(
intercept_message.replacement_message.message_type(),
intercept_message.replacement_message.clone(),
);
send.send(intercept_frame)
.await
.map_err(|_| SnifferError::DownstreamClosed)?;
}
}
} else {
upstream_messages.add_message(msg_type, msg);
send.send(frame)
Expand Down
45 changes: 43 additions & 2 deletions roles/tests-integration/tests/sniffer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async fn test_sniffer_intercept_to_downstream() {

// this sniffer will replace SetupConnectionSuccess with SetupConnectionError
let (_sniffer_a, sniffer_a_addr) =
start_sniffer("A".to_string(), tp_addr, false, Some(vec![intercept])).await;
start_sniffer("A".to_string(), tp_addr, false, Some(intercept.into())).await;

// this sniffer will assert SetupConnectionSuccess was correctly replaced with
// SetupConnectionError
Expand Down Expand Up @@ -83,7 +83,7 @@ async fn test_sniffer_intercept_to_upstream() {
);

let (sniffer_a, sniffer_a_addr) =
start_sniffer("A".to_string(), tp_addr, false, Some(vec![intercept])).await;
start_sniffer("A".to_string(), tp_addr, false, Some(intercept.into())).await;

let (_sniffer_b, sniffer_b_addr) =
start_sniffer("B".to_string(), sniffer_a_addr, false, None).await;
Expand Down Expand Up @@ -142,3 +142,44 @@ async fn test_sniffer_wait_for_message_type_with_remove() {
false
);
}

/// Verifies that [`Sniffer`] can intercept and block a message stream.
///
/// This test sets up a chain where a message from the Template Provider (TP)
/// passes through two sniffers (`sniffer_a` and `sniffer_b`) before reaching the Pool.
///
/// - `sniffer_a` is configured to block `SetupConnectionSuccess` messages directed downstream.
/// - `sniffer_b` should receive no messages after initial setup, ensuring the block works.
///
/// **Flow:**
/// `TP -> sniffer_a -> sniffer_b -> Pool`
#[tokio::test]
async fn test_sniffer_blocks_message() {
start_tracing();
let (_tp, tp_addr) = start_template_provider(None);

// Define an action to block SetupConnectionSuccess messages going downstream.
let block_from_message = BlockFromMessage::new(
MessageDirection::ToDownstream,
MESSAGE_TYPE_SETUP_CONNECTION_SUCCESS,
);

// `sniffer_a` intercepts and blocks `SetupConnectionSuccess` messages.
let (_sniffer_a, sniffer_a_addr) = start_sniffer(
"A".to_string(),
tp_addr,
false,
Some(block_from_message.into()),
)
.await;

// `sniffer_b` is placed downstream of `sniffer_a` and should receive nothing.
let (sniffer_b, sniffer_b_addr) =
start_sniffer("B".to_string(), sniffer_a_addr, false, None).await;

// Start the Pool, connected to `sniffer_b`.
let _ = start_pool(Some(sniffer_b_addr)).await;

// Assert that `sniffer_b` does not receive any messages, confirming `sniffer_a`'s block works.
assert!(sniffer_b.next_message_from_upstream().is_none());
}

0 comments on commit 5453875

Please sign in to comment.