Skip to content

Commit

Permalink
Allow sniffer to mutate ongoing messages
Browse files Browse the repository at this point in the history
..Add new `InterceptMessage` property to allow the sniffer to mutate a
message before it sent to downstream/upstream.
  • Loading branch information
jbesraa committed Nov 18, 2024
1 parent b900d0a commit f52863f
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 10 deletions.
9 changes: 7 additions & 2 deletions roles/tests-integration/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use key_utils::{Secp256k1PublicKey, Secp256k1SecretKey};
use once_cell::sync::Lazy;
use pool_sv2::PoolSv2;
use sniffer::Sniffer;
pub use sniffer::{InterceptMessage, MessageDirection};
use std::{
collections::HashSet,
convert::TryFrom,
Expand Down Expand Up @@ -193,8 +194,12 @@ pub fn get_available_address() -> SocketAddr {
SocketAddr::from(([127, 0, 0, 1], port))
}

pub async fn start_sniffer(listening_address: SocketAddr, upstream: SocketAddr) -> Sniffer {
let sniffer = Sniffer::new(listening_address, upstream).await;
pub async fn start_sniffer(
listening_address: SocketAddr,
upstream: SocketAddr,
intercept_messages: Option<Vec<InterceptMessage>>,
) -> Sniffer {
let sniffer = Sniffer::new(listening_address, upstream, intercept_messages).await;
let sniffer_clone = sniffer.clone();
tokio::spawn(async move {
sniffer_clone.start().await;
Expand Down
111 changes: 104 additions & 7 deletions roles/tests-integration/tests/common/sniffer.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use async_channel::{Receiver, Sender};
use codec_sv2::{
framing_sv2::framing::Frame, HandshakeRole, Initiator, Responder, StandardEitherFrame,
framing_sv2::framing::Frame, HandshakeRole, Initiator, Responder, StandardEitherFrame, Sv2Frame,
};
use key_utils::{Secp256k1PublicKey, Secp256k1SecretKey};
use network_helpers_sv2::noise_connection_tokio::Connection;
Expand All @@ -13,15 +13,16 @@ use roles_logic_sv2::{
IdentifyTransactionsSuccess, ProvideMissingTransactions,
ProvideMissingTransactionsSuccess, SubmitSolution,
},
TemplateDistribution,
TemplateDistribution::CoinbaseOutputDataSize,
PoolMessages,
TemplateDistribution::{self, CoinbaseOutputDataSize},
},
utils::Mutex,
};
use std::{collections::VecDeque, convert::TryInto, net::SocketAddr, sync::Arc};
use tokio::{
net::{TcpListener, TcpStream},
select,
time::sleep,
};
type MessageFrame = StandardEitherFrame<AnyMessage<'static>>;
type MsgType = u8;
Expand All @@ -30,6 +31,7 @@ type MsgType = u8;
enum SnifferError {
DownstreamClosed,
UpstreamClosed,
MessageInterrupted,
}

/// Allows to intercept messages sent between two roles.
Expand All @@ -50,17 +52,56 @@ pub struct Sniffer {
upstream_address: SocketAddr,
messages_from_downstream: MessagesAggregator,
messages_from_upstream: MessagesAggregator,
intercept_messages: Vec<InterceptMessage>,
}

#[derive(Debug, Clone)]
pub struct InterceptMessage {
direction: MessageDirection,
expected_message_type: MsgType,
response_message: PoolMessages<'static>,
response_message_type: MsgType,
break_on: bool,
}

impl InterceptMessage {
pub fn new(
direction: MessageDirection,
expected_message_type: MsgType,
response_message: PoolMessages<'static>,
response_message_type: MsgType,
break_on: bool,
) -> Self {
Self {
direction,
expected_message_type,
response_message,
response_message_type,
break_on,
}
}
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageDirection {
ToDownstream,
ToUpstream,
}

impl Sniffer {
/// Creates a new sniffer that listens on the given listening address and connects to the given
/// upstream address.
pub async fn new(listening_address: SocketAddr, upstream_address: SocketAddr) -> Self {
pub async fn new(
listening_address: SocketAddr,
upstream_address: SocketAddr,
intercept_messages: Option<Vec<InterceptMessage>>,
) -> Self {
Self {
listening_address,
upstream_address,
messages_from_downstream: MessagesAggregator::new(),
messages_from_upstream: MessagesAggregator::new(),
intercept_messages: intercept_messages.unwrap_or_default(),
}
}

Expand All @@ -82,10 +123,13 @@ impl Sniffer {
.expect("Failed to create upstream");
let downstream_messages = self.messages_from_downstream.clone();
let upstream_messages = self.messages_from_upstream.clone();
let intercept_messages = self.intercept_messages.clone();
let _ = select! {
r = Self::recv_from_down_send_to_up(downstream_receiver, upstream_sender, downstream_messages) => r,
r = Self::recv_from_up_send_to_down(upstream_receiver, downstream_sender, upstream_messages) => r,
r = Self::recv_from_down_send_to_up(downstream_receiver, upstream_sender, downstream_messages, intercept_messages.clone()) => r,
r = Self::recv_from_up_send_to_down(upstream_receiver, downstream_sender, upstream_messages, intercept_messages) => r,
};
// wait a bit so we dont drop the sniffer before the test has finished
sleep(std::time::Duration::from_secs(1)).await;
}

/// Returns the oldest message sent by downstream.
Expand Down Expand Up @@ -160,9 +204,36 @@ impl Sniffer {
recv: Receiver<MessageFrame>,
send: Sender<MessageFrame>,
downstream_messages: MessagesAggregator,
intercept_messages: Vec<InterceptMessage>,
) -> Result<(), SnifferError> {
while let Ok(mut frame) = recv.recv().await {
let (msg_type, msg) = Self::message_from_frame(&mut frame);
for interrupt_message in intercept_messages.iter() {
if interrupt_message.direction == MessageDirection::ToUpstream
&& interrupt_message.expected_message_type == msg_type
{
let extension_type = 0;
let channel_msg = false;
let frame = StandardEitherFrame::<AnyMessage<'_>>::Sv2(
Sv2Frame::from_message(
interrupt_message.response_message.clone(),
interrupt_message.response_message_type,
extension_type,
channel_msg,
)
.expect("Failed to create the frame"),
);
downstream_messages
.add_message(msg_type, interrupt_message.response_message.clone());
let _ = send.send(frame).await;
if interrupt_message.break_on {
return Err(SnifferError::MessageInterrupted);
} else {
continue;
}
}
}

downstream_messages.add_message(msg_type, msg);
if send.send(frame).await.is_err() {
return Err(SnifferError::UpstreamClosed);
Expand All @@ -175,13 +246,39 @@ impl Sniffer {
recv: Receiver<MessageFrame>,
send: Sender<MessageFrame>,
upstream_messages: MessagesAggregator,
intercept_messages: Vec<InterceptMessage>,
) -> Result<(), SnifferError> {
while let Ok(mut frame) = recv.recv().await {
let (msg_type, msg) = Self::message_from_frame(&mut frame);
upstream_messages.add_message(msg_type, msg);
for interrupt_message in intercept_messages.iter() {
if interrupt_message.direction == MessageDirection::ToDownstream
&& interrupt_message.expected_message_type == msg_type
{
let extension_type = 0;
let channel_msg = false;
let frame = StandardEitherFrame::<AnyMessage<'_>>::Sv2(
Sv2Frame::from_message(
interrupt_message.response_message.clone(),
interrupt_message.response_message_type,
extension_type,
channel_msg,
)
.expect("Failed to create the frame"),
);
upstream_messages
.add_message(msg_type, interrupt_message.response_message.clone());
let _ = send.send(frame).await;
if interrupt_message.break_on {
return Err(SnifferError::MessageInterrupted);
} else {
continue;
}
}
}
if send.send(frame).await.is_err() {
return Err(SnifferError::DownstreamClosed);
};
upstream_messages.add_message(msg_type, msg);
}
Err(SnifferError::UpstreamClosed)
}
Expand Down
2 changes: 1 addition & 1 deletion roles/tests-integration/tests/pool_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ async fn success_pool_template_provider_connection() {
let tp_addr = common::get_available_address();
let pool_addr = common::get_available_address();
let _tp = common::start_template_provider(tp_addr.port()).await;
let sniffer = common::start_sniffer(sniffer_addr, tp_addr).await;
let sniffer = common::start_sniffer(sniffer_addr, tp_addr, None).await;
let _ = common::start_pool(Some(pool_addr), Some(sniffer_addr)).await;
// here we assert that the downstream(pool in this case) have sent `SetupConnection` message
// with the correct parameters, protocol, flags, min_version and max_version. Note that the
Expand Down

0 comments on commit f52863f

Please sign in to comment.