diff --git a/src/main.rs b/src/main.rs index e9fe2f8..d11e675 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,6 +10,7 @@ use snapchain::network::server::MyHubService; use snapchain::node::snapchain_node::SnapchainNode; use snapchain::proto::admin_service_server::AdminServiceServer; use snapchain::proto::hub_service_server::HubServiceServer; +use snapchain::proto::Block; use snapchain::storage::db::RocksDB; use snapchain::storage::store::engine::MempoolMessage; use snapchain::storage::store::BlockStore; @@ -96,6 +97,7 @@ async fn main() -> Result<(), Box> { let (system_tx, mut system_rx) = mpsc::channel::(100); let (mempool_tx, mut mempool_rx) = mpsc::channel::(100); + let (block_tx, mut block_rx) = mpsc::channel::(100); let gossip_result = SnapchainGossip::create( keypair.clone(), @@ -130,7 +132,7 @@ async fn main() -> Result<(), Box> { app_config.mempool.clone(), Some(app_config.rpc_address.clone()), gossip_tx.clone(), - None, + Some(block_tx.clone()), block_store.clone(), app_config.rocksdb_dir.clone(), statsd_client.clone(), @@ -339,6 +341,9 @@ async fn main() -> Result<(), Box> { } } } + Some(block) = block_rx.recv() => { + node.handle_block(block).await; + } } } } diff --git a/src/node/snapchain_node.rs b/src/node/snapchain_node.rs index 61aee05..6df5c85 100644 --- a/src/node/snapchain_node.rs +++ b/src/node/snapchain_node.rs @@ -21,7 +21,8 @@ use informalsystems_malachitebft_metrics::Metrics; use libp2p::identity::ed25519::Keypair; use ractor::ActorRef; use std::collections::{BTreeMap, HashMap}; -use tokio::sync::mpsc; +use std::sync::Arc; +use tokio::sync::{mpsc, RwLock}; use tracing::warn; const MAX_SHARDS: u32 = 64; @@ -31,6 +32,7 @@ pub struct SnapchainNode { pub shard_stores: HashMap, pub shard_senders: HashMap, pub address: Address, + shard_mempools: HashMap>>, } impl SnapchainNode { @@ -54,6 +56,7 @@ impl SnapchainNode { let mut shard_senders: HashMap = HashMap::new(); let mut shard_stores: HashMap = HashMap::new(); + let mut shard_mempools: HashMap>> = HashMap::new(); // Create the shard validators for shard_id in config.shard_ids { @@ -99,6 +102,7 @@ impl SnapchainNode { shard_senders.insert(shard_id, engine.get_senders()); shard_stores.insert(shard_id, engine.get_stores()); + shard_mempools.insert(shard_id, engine.mempool()); let shard_proposer = ShardProposer::new( validator_address.clone(), @@ -190,6 +194,7 @@ impl SnapchainNode { address: validator_address, shard_senders, shard_stores, + shard_mempools, } } @@ -224,4 +229,24 @@ impl SnapchainNode { warn!("No actor found for shard, could not forward message"); } } + + pub async fn handle_block(&self, block: Block) { + for chunk in block.shard_chunks { + let header = chunk.header.expect("Expects chunk to have a header"); + let height = header.height.expect("Expects header to have a height"); + let mempool = self + .shard_mempools + .get(&height.shard_index) + .expect("Expects mempool to exist for shard"); + let mut mempool_write = mempool.write().await; + for transaction in chunk.transactions { + for user_message in transaction.user_messages { + mempool_write.remove(user_message.hex_hash()); + } + for system_message in transaction.system_messages { + mempool_write.remove(system_message.hex_hash()); + } + } + } + } }