diff --git a/src/mempool.rs b/src/mempool.rs index 05f96a50..58603439 100644 --- a/src/mempool.rs +++ b/src/mempool.rs @@ -34,7 +34,7 @@ use std::{ sync::Arc, }; use storage::{DataProposalVerdict, LaneEntry, Storage}; -use tokio::task::{futures, JoinHandle, JoinSet}; +use tokio::task::JoinSet; // Pick one of the two implementations // use storage_memory::LanesStorage; use storage_fjall::LanesStorage; @@ -113,7 +113,7 @@ pub struct MempoolStore { pub struct Mempool { bus: MempoolBusClient, file: Option, - blocker: JoinSet<()>, + blocker: JoinSet>, conf: SharedConf, crypto: SharedBlstCrypto, metrics: MempoolMetrics, @@ -237,7 +237,7 @@ impl Module for Mempool { bus, file: Some(ctx.common.config.data_directory.clone()), conf: ctx.common.config.clone(), - blocker: JoinSet::new() + blocker: JoinSet::new(), metrics, crypto: Arc::clone(&ctx.node.crypto), lanes: LanesStorage::new( @@ -265,6 +265,12 @@ impl Mempool { module_handle_messages! { on_bus self.bus, + on_shutdown { + // Waiting all proof txs being processed + let mut join_set: JoinSet> = JoinSet::new(); + std::mem::swap(&mut self.blocker, &mut join_set); + join_set.join_all().await; + }, listen> cmd => { let _ = self.handle_net_message(cmd) .log_error("Handling MempoolNetMessage in Mempool"); @@ -974,14 +980,14 @@ impl Mempool { let kc = self.known_contracts.clone(); let sender: &tokio::sync::broadcast::Sender = self.bus.get(); let sender = sender.clone(); - let t = tokio::task::spawn_blocking(move || { - let tx = - Self::process_proof_tx(kc, tx).log_error("Error processing proof tx")?; + self.blocker.spawn_blocking(move || { + let tx = Self::process_proof_tx(kc, tx) + .log_error("Processing proof tx in blocker")?; sender .send(InternalMempoolEvent::OnProcessedNewTx(tx)) - .log_warn("sending processed TX") + .log_warn("sending processed TX")?; + Ok(()) }); - while !t.is_finished() {} return Ok(()); } @@ -1242,6 +1248,7 @@ pub mod test { bus, file: None, conf: SharedConf::default(), + blocker: JoinSet::new(), crypto: Arc::new(crypto), metrics: MempoolMetrics::global("id".to_string()), lanes, diff --git a/src/utils/modules.rs b/src/utils/modules.rs index 1a512d81..b3d74b2f 100644 --- a/src/utils/modules.rs +++ b/src/utils/modules.rs @@ -133,6 +133,22 @@ pub mod signal { #[macro_export] macro_rules! module_handle_messages { + (on_bus $bus:expr, on_shutdown $on_shutdown:block, $($rest:tt)*) => { + { + let mut shutdown_receiver = unsafe { &mut *Pick::>::splitting_get_mut(&mut $bus) }; + let mut should_shutdown = false; + $crate::handle_messages! { + on_bus $bus, + $($rest)* + Ok(_) = $crate::utils::modules::signal::async_receive_shutdown::(&mut should_shutdown, &mut shutdown_receiver) => { + tracing::debug!("Break signal received for module {}", std::any::type_name::()); + $on_shutdown; + break; + } + } + should_shutdown + } + }; (on_bus $bus:expr, $($rest:tt)*) => { { let mut shutdown_receiver = unsafe { &mut *Pick::>::splitting_get_mut(&mut $bus) };