Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pull messages out of mempool rather than pushing messages into engine #203

Merged
merged 7 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,14 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Use the new non-global metrics registry when we upgrade to newer version of malachite
let _ = Metrics::register(registry);

let (messages_request_tx, messages_request_rx) = mpsc::channel(100);
let node = SnapchainNode::create(
keypair.clone(),
app_config.consensus.clone(),
Some(app_config.rpc_address.clone()),
gossip_tx.clone(),
None,
messages_request_tx,
block_store.clone(),
app_config.rocksdb_dir.clone(),
statsd_client.clone(),
Expand All @@ -135,8 +137,8 @@ async fn main() -> Result<(), Box<dyn Error>> {
let (mempool_tx, mempool_rx) = mpsc::channel(app_config.mempool.queue_size as usize);
let mut mempool = Mempool::new(
mempool_rx,
messages_request_rx,
app_config.consensus.num_shards,
node.shard_senders.clone(),
node.shard_stores.clone(),
);
tokio::spawn(async move { mempool.run().await });
Expand Down
113 changes: 80 additions & 33 deletions src/mempool/mempool.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use std::collections::HashMap;
use std::collections::{BTreeMap, HashMap};

use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use tokio::{
sync::{mpsc, oneshot},
time::Instant,
};

use crate::storage::{
store::{
engine::{MempoolMessage, Senders},
stores::Stores,
},
store::{engine::MempoolMessage, stores::Stores},
trie::merkle_trie::{self, TrieKey},
};

Expand All @@ -25,27 +25,34 @@ impl Default for Config {
}
}

/// Ordering key for a message held in the mempool.
///
/// The derived `Ord` compares `inserted_at`, the only field, so a `BTreeMap`
/// keyed by `MempoolKey` iterates its entries from earliest to latest insertion.
///
/// NOTE(review): two messages stamped with the same `Instant` would compare
/// equal, and a later map insert would replace the earlier entry — confirm
/// whether this can happen in practice.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct MempoolKey {
// Time at which the message was placed into the mempool.
inserted_at: Instant,
}

pub struct Mempool {
shard_senders: HashMap<u32, Senders>,
shard_stores: HashMap<u32, Stores>,
message_router: Box<dyn MessageRouter>,
num_shards: u32,
mempool_rx: mpsc::Receiver<MempoolMessage>,
messages_request_rx: mpsc::Receiver<(u32, oneshot::Sender<Option<MempoolMessage>>)>,
aditiharini marked this conversation as resolved.
Show resolved Hide resolved
messages: HashMap<u32, BTreeMap<MempoolKey, MempoolMessage>>,
}

impl Mempool {
pub fn new(
mempool_rx: mpsc::Receiver<MempoolMessage>,
messages_request_rx: mpsc::Receiver<(u32, oneshot::Sender<Option<MempoolMessage>>)>,
num_shards: u32,
shard_senders: HashMap<u32, Senders>,
shard_stores: HashMap<u32, Stores>,
) -> Self {
Mempool {
shard_senders,
shard_stores,
num_shards,
mempool_rx,
message_router: Box::new(ShardRouter {}),
messages: HashMap::new(),
messages_request_rx,
}
}

Expand Down Expand Up @@ -75,32 +82,60 @@ impl Mempool {
}
}

fn is_message_already_merged(&mut self, message: &MempoolMessage) -> bool {
let fid = message.fid();
match message {
MempoolMessage::UserMessage(message) => {
self.message_exists_in_trie(fid, TrieKey::for_message(message))
async fn pull_message(&mut self, shard_id: u32, tx: oneshot::Sender<Option<MempoolMessage>>) {
let mut message = None;
loop {
let messages = self.messages.get_mut(&shard_id);
match messages {
None => break,
Some(messages) => {
match messages.pop_first() {
None => break,
Some((_, next_message)) => {
if self.message_is_valid(&next_message) {
message = Some(next_message);
break;
}
}
};
}
}
}

if let Err(_) = tx.send(message) {
error!("Unable to send message from mempool");
aditiharini marked this conversation as resolved.
Show resolved Hide resolved
}
}

fn get_trie_key(message: &MempoolMessage) -> Option<Vec<u8>> {
match message {
MempoolMessage::UserMessage(message) => return Some(TrieKey::for_message(message)),
MempoolMessage::ValidatorMessage(validator_message) => {
if let Some(onchain_event) = &validator_message.on_chain_event {
return self
.message_exists_in_trie(fid, TrieKey::for_onchain_event(&onchain_event));
return Some(TrieKey::for_onchain_event(&onchain_event));
}

if let Some(fname_transfer) = &validator_message.fname_transfer {
if let Some(proof) = &fname_transfer.proof {
let name = String::from_utf8(proof.name.clone()).unwrap();
return self.message_exists_in_trie(
fid,
TrieKey::for_fname(fname_transfer.id, &name),
);
return Some(TrieKey::for_fname(fname_transfer.id, &name));
}
}
false

return None;
}
}
}

/// Reports whether `message` is already present in the trie.
///
/// The trie key is derived with `get_trie_key`; when no key can be derived
/// the message is treated as not merged and `false` is returned. Otherwise
/// the result of `message_exists_in_trie` for the message's fid is returned.
fn is_message_already_merged(&mut self, message: &MempoolMessage) -> bool {
    let fid = message.fid();
    // Guard clause: without a trie key there is nothing to look up.
    let Some(key) = Self::get_trie_key(message) else {
        return false;
    };
    self.message_exists_in_trie(fid, key)
}

pub fn message_is_valid(&mut self, message: &MempoolMessage) -> bool {
if self.is_message_already_merged(message) {
return false;
Expand All @@ -110,21 +145,33 @@ impl Mempool {
}

pub async fn run(&mut self) {
while let Some(message) = self.mempool_rx.recv().await {
if self.message_is_valid(&message) {
let fid = message.fid();
let shard = self.message_router.route_message(fid, self.num_shards);
let senders = self.shard_senders.get(&shard);
match senders {
None => {
error!("Unable to find shard to send message to")
}
Some(senders) => {
if let Err(err) = senders.messages_tx.send(message).await {
error!("Unable to send message to engine: {}", err.to_string())
loop {
tokio::select! {
message = self.mempool_rx.recv() => {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Selecting directly here is dangerous: the mempool rx channel can overload the messages_request_rx channel, so that a request is never fulfilled if we get sufficient message traffic to the mempool.

It's better to loop using a tick and ensure that the messages request channel is processed first and always prioritized.

See

let deadline = Instant::now() + timeout;
loop {
let timeout = time::sleep_until(deadline);
select! {
_ = poll_interval.tick() => {
for an example of how we've done this elsewhere.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made a change to specify `biased` in the select macro. This causes the macro to prefer branches in the order they are specified. I think this is what we want here?

In this case, I don't really think we need a timeout because we will just poll indefinitely.

if let Some(message) = message {
// TODO(aditi): Maybe we don't need to run validations here?
if self.message_is_valid(&message) {
let fid = message.fid();
let shard_id = self.message_router.route_message(fid, self.num_shards);
// TODO(aditi): We need a size limit on the mempool and we need to figure out what to do if it's exceeded
match self.messages.get_mut(&shard_id) {
None => {
let mut messages = BTreeMap::new();
messages.insert(MempoolKey { inserted_at: Instant::now()}, message.clone());
self.messages.insert(shard_id, messages);
}
Some(messages) => {
messages.insert(MempoolKey { inserted_at: Instant::now()}, message.clone());
}
}
}
}
}
message_request = self.messages_request_rx.recv() => {
if let Some((shard_id, tx)) = message_request {
self.pull_message(shard_id, tx).await
}
}
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/mempool/mempool_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ mod tests {

fn setup() -> (ShardEngine, Mempool) {
let (_mempool_tx, mempool_rx) = mpsc::channel(100);
let (_mempool_tx, messages_request_rx) = mpsc::channel(100);
let (engine, _) = test_helper::new_engine();
let mut shard_senders = HashMap::new();
shard_senders.insert(1, engine.get_senders());
let mut shard_stores = HashMap::new();
shard_stores.insert(1, engine.get_stores());
let mempool = Mempool::new(mempool_rx, 1, shard_senders, shard_stores);
let mempool = Mempool::new(mempool_rx, messages_request_rx, 1, shard_stores);
(engine, mempool)
}

Expand Down
1 change: 1 addition & 0 deletions src/network/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ impl MyHubService {
stores.store_limits.clone(),
self.statsd_client.clone(),
100,
None,
);
let result = readonly_engine.simulate_message(&message);

Expand Down
10 changes: 6 additions & 4 deletions src/network/server_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,29 +134,31 @@ mod tests {
let (engine1, _) = test_helper::new_engine_with_options(test_helper::EngineOptions {
limits: Some(limits.clone()),
db_name: Some("db1.db".to_string()),
messages_request_tx: None,
});
let (engine2, _) = test_helper::new_engine_with_options(test_helper::EngineOptions {
limits: Some(limits.clone()),
db_name: Some("db2.db".to_string()),
messages_request_tx: None,
});
let db1 = engine1.db.clone();
let db2 = engine2.db.clone();

let (msgs_tx, _msgs_rx) = mpsc::channel(100);
let (_msgs_request_tx, msgs_request_rx) = mpsc::channel(100);

let shard1_stores = Stores::new(
db1,
merkle_trie::MerkleTrie::new(16).unwrap(),
limits.clone(),
);
let shard1_senders = Senders::new(msgs_tx.clone());
let shard1_senders = Senders::new();

let shard2_stores = Stores::new(
db2,
merkle_trie::MerkleTrie::new(16).unwrap(),
limits.clone(),
);
let shard2_senders = Senders::new(msgs_tx.clone());
let shard2_senders = Senders::new();
let stores = HashMap::from([(1, shard1_stores), (2, shard2_stores)]);
let senders = HashMap::from([(1, shard1_senders), (2, shard2_senders)]);
let num_shards = senders.len() as u32;
Expand All @@ -169,7 +171,7 @@ mod tests {
assert_eq!(message_router.route_message(SHARD2_FID, 2), 2);

let (mempool_tx, mempool_rx) = mpsc::channel(1000);
let mut mempool = Mempool::new(mempool_rx, num_shards, senders.clone(), stores.clone());
let mut mempool = Mempool::new(mempool_rx, msgs_request_rx, num_shards, stores.clone());
tokio::spawn(async move { mempool.run().await });

(
Expand Down
6 changes: 4 additions & 2 deletions src/node/snapchain_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::core::types::{
use crate::network::gossip::GossipEvent;
use crate::proto::{Block, ShardChunk};
use crate::storage::db::RocksDB;
use crate::storage::store::engine::{BlockEngine, Senders, ShardEngine};
use crate::storage::store::engine::{BlockEngine, MempoolMessage, Senders, ShardEngine};
use crate::storage::store::stores::StoreLimits;
use crate::storage::store::stores::Stores;
use crate::storage::store::BlockStore;
Expand All @@ -20,7 +20,7 @@ use informalsystems_malachitebft_metrics::Metrics;
use libp2p::identity::ed25519::Keypair;
use ractor::ActorRef;
use std::collections::{BTreeMap, HashMap};
use tokio::sync::mpsc;
use tokio::sync::{mpsc, oneshot};
use tracing::warn;

const MAX_SHARDS: u32 = 64;
Expand All @@ -39,6 +39,7 @@ impl SnapchainNode {
rpc_address: Option<String>,
gossip_tx: mpsc::Sender<GossipEvent<SnapchainValidatorContext>>,
block_tx: Option<mpsc::Sender<Block>>,
messages_request_tx: mpsc::Sender<(u32, oneshot::Sender<Option<MempoolMessage>>)>,
block_store: BlockStore,
rocksdb_dir: String,
statsd_client: StatsdClientWrapper,
Expand Down Expand Up @@ -91,6 +92,7 @@ impl SnapchainNode {
StoreLimits::default(),
statsd_client.clone(),
config.max_messages_per_block,
Some(messages_request_tx.clone()),
);

shard_senders.insert(shard_id, engine.get_senders());
Expand Down
19 changes: 17 additions & 2 deletions src/perf/engine_only_perftest.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
use tokio::sync::mpsc;

use crate::mempool::mempool::Mempool;
use crate::proto::{Height, ShardChunk, ShardHeader};
use crate::storage::store::engine::{MempoolMessage, ShardStateChange};
use crate::storage::store::stores::StoreLimits;
use crate::storage::store::test_helper;
use crate::utils::cli::compose_message;
use std::collections::HashMap;
use std::error::Error;
use std::time::Duration;

Expand All @@ -28,16 +32,27 @@ fn state_change_to_shard_chunk(
}

pub async fn run() -> Result<(), Box<dyn Error>> {
let (mempool_tx, mempool_rx) = mpsc::channel(1000);
let (messages_request_tx, messages_request_rx) = mpsc::channel(100);

let (mut engine, _tmpdir) = test_helper::new_engine_with_options(test_helper::EngineOptions {
limits: Some(StoreLimits {
limits: test_helper::limits::unlimited(),
legacy_limits: test_helper::limits::unlimited(),
}),
db_name: None,
messages_request_tx: Some(messages_request_tx),
});

let mut shard_stores = HashMap::new();
shard_stores.insert(1, engine.get_stores());
let mut mempool = Mempool::new(mempool_rx, messages_request_rx, 1, shard_stores);

tokio::spawn(async move {
mempool.run().await;
});

let mut i = 0;
let messages_tx = engine.messages_tx();

let fid = test_helper::FID_FOR_TEST;

Expand All @@ -54,7 +69,7 @@ pub async fn run() -> Result<(), Box<dyn Error>> {
let text = format!("For benchmarking {}", i);
let msg = compose_message(fid, text.as_str(), None, None);

messages_tx
mempool_tx
.send(MempoolMessage::UserMessage(msg.clone()))
.await
.unwrap();
Expand Down
Loading