diff --git a/crates/papyrus_network/src/db_executor/mod.rs b/crates/papyrus_network/src/db_executor/mod.rs index d3ea3d3bcf..fd216ca2e4 100644 --- a/crates/papyrus_network/src/db_executor/mod.rs +++ b/crates/papyrus_network/src/db_executor/mod.rs @@ -6,12 +6,14 @@ use futures::channel::mpsc::Sender; use futures::future::poll_fn; use futures::stream::FuturesUnordered; use futures::{Stream, StreamExt}; +#[cfg(test)] +use mockall::automock; use papyrus_storage::header::HeaderStorageReader; -use papyrus_storage::StorageReader; +use papyrus_storage::{db, StorageReader, StorageTxn}; use starknet_api::block::{BlockHeader, BlockNumber, BlockSignature}; use tokio::task::JoinHandle; -use crate::{BlockHashOrNumber, InternalQuery}; +use crate::{BlockHashOrNumber, DataType, InternalQuery}; #[cfg(test)] mod test; @@ -21,10 +23,14 @@ mod utils; #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Display)] pub struct QueryId(pub usize); -#[cfg_attr(test, derive(Debug, Clone, PartialEq, Eq))] +#[cfg_attr(test, derive(Debug, Clone, PartialEq, Eq, Default))] pub enum Data { // TODO(shahak): Consider uniting with SignedBlockHeader. - BlockHeaderAndSignature { header: BlockHeader, signatures: Vec<BlockSignature> }, + BlockHeaderAndSignature { + header: BlockHeader, + signatures: Vec<BlockSignature>, + }, + #[cfg_attr(test, default)] Fin, } @@ -86,7 +92,12 @@ impl DBExecutorError { /// The stream is never exhausted, and it is the responsibility of the user to poll it. pub trait DBExecutor: Stream<Item = Result<QueryId, DBExecutorError>> + Unpin { // TODO: add writer functionality - fn register_query(&mut self, query: InternalQuery, sender: Sender<Data>) -> QueryId; + fn register_query( + &mut self, + query: InternalQuery, + data_type: impl FetchBlockDataFromDb + Send + 'static, + sender: Sender<Data>, + ) -> QueryId; } // TODO: currently this executor returns only block headers and signatures. 
@@ -104,7 +115,12 @@ impl BlockHeaderDBExecutor { } impl DBExecutor for BlockHeaderDBExecutor { - fn register_query(&mut self, query: InternalQuery, mut sender: Sender<Data>) -> QueryId { + fn register_query( + &mut self, + query: InternalQuery, + data_type: impl FetchBlockDataFromDb + Send + 'static, + mut sender: Sender<Data>, + ) -> QueryId { let query_id = QueryId(self.next_query_id); self.next_query_id += 1; let storage_reader_clone = self.storage_reader.clone(); @@ -135,30 +151,12 @@ impl DBExecutor for BlockHeaderDBExecutor { block_counter, query_id, )?); - let header = txn - .get_block_header(block_number) - .map_err(|err| DBExecutorError::DBInternalError { - query_id, - storage_error: err, - })? - .ok_or(DBExecutorError::BlockNotFound { - block_hash_or_number: BlockHashOrNumber::Number(block_number), - query_id, - })?; - let signature = txn - .get_block_signature(block_number) - .map_err(|err| DBExecutorError::DBInternalError { - query_id, - storage_error: err, - })? - .ok_or(DBExecutorError::SignatureNotFound { block_number, query_id })?; + let data = data_type.fetch_block_data_from_db(block_number, query_id, &txn)?; // Using poll_fn because Sender::poll_ready is not a future match poll_fn(|cx| sender.poll_ready(cx)).await { Ok(()) => { - if let Err(e) = sender.start_send(Data::BlockHeaderAndSignature { - header, - signatures: vec![signature], - }) { + if let Err(e) = sender.start_send(data) { + // TODO: consider implement retry mechanism. return Err(DBExecutorError::SendError { query_id, send_error: e }); }; } @@ -201,3 +199,48 @@ pub(crate) fn poll_query_execution_set( Poll::Pending => Poll::Pending, } } + +#[cfg_attr(test, automock)] +// we need to tell clippy to ignore the "needless" lifetime warning because it's not true. +// we do need the lifetime for the automock, following clippy's suggestion will break the code. 
+#[allow(clippy::needless_lifetimes)] +pub trait FetchBlockDataFromDb { + fn fetch_block_data_from_db<'a>( + &self, + block_number: BlockNumber, + query_id: QueryId, + txn: &StorageTxn<'a, db::RO>, + ) -> Result<Data, DBExecutorError>; +} + +impl FetchBlockDataFromDb for DataType { + fn fetch_block_data_from_db( + &self, + block_number: BlockNumber, + query_id: QueryId, + txn: &StorageTxn<'_, db::RO>, + ) -> Result<Data, DBExecutorError> { + match self { + DataType::SignedBlockHeader => { + let header = txn + .get_block_header(block_number) + .map_err(|err| DBExecutorError::DBInternalError { + query_id, + storage_error: err, + })? + .ok_or(DBExecutorError::BlockNotFound { + block_hash_or_number: BlockHashOrNumber::Number(block_number), + query_id, + })?; + let signature = txn + .get_block_signature(block_number) + .map_err(|err| DBExecutorError::DBInternalError { + query_id, + storage_error: err, + })? + .ok_or(DBExecutorError::SignatureNotFound { block_number, query_id })?; + Ok(Data::BlockHeaderAndSignature { header, signatures: vec![signature] }) + } + } + } +} diff --git a/crates/papyrus_network/src/db_executor/test.rs b/crates/papyrus_network/src/db_executor/test.rs index b0edd27880..b6abb3ff33 100644 --- a/crates/papyrus_network/src/db_executor/test.rs +++ b/crates/papyrus_network/src/db_executor/test.rs @@ -10,8 +10,8 @@ use rand::random; use starknet_api::block::{BlockHash, BlockHeader, BlockNumber, BlockSignature}; use super::Data::BlockHeaderAndSignature; -use crate::db_executor::{DBExecutor, DBExecutorError}; -use crate::{BlockHashOrNumber, Direction, InternalQuery}; +use crate::db_executor::{DBExecutor, DBExecutorError, Data, MockFetchBlockDataFromDb}; +use crate::{BlockHashOrNumber, DataType, Direction, InternalQuery}; const BUFFER_SIZE: usize = 10; #[tokio::test] @@ -31,7 +31,7 @@ async fn header_db_executor_can_register_and_run_a_query() { limit: NUM_OF_BLOCKS, step: 1, }; - let query_id = db_executor.register_query(query, sender); + let query_id = db_executor.register_query(query, 
DataType::SignedBlockHeader, sender); // run the executor and collect query results. tokio::select! { @@ -75,7 +75,7 @@ async fn header_db_executor_start_block_given_by_hash() { limit: NUM_OF_BLOCKS, step: 1, }; - let query_id = db_executor.register_query(query, sender); + let query_id = db_executor.register_query(query, DataType::SignedBlockHeader, sender); // run the executor and collect query results. tokio::select! { @@ -109,7 +109,20 @@ async fn header_db_executor_query_of_missing_block() { limit: NUM_OF_BLOCKS, step: 1, }; - let _query_id = db_executor.register_query(query, sender); + let mut mock_data_type = MockFetchBlockDataFromDb::new(); + mock_data_type.expect_fetch_block_data_from_db().times((BLOCKS_DELTA + 1) as usize).returning( + |block_number, query_id, _| { + if block_number.0 == NUM_OF_BLOCKS { + Err(DBExecutorError::BlockNotFound { + block_hash_or_number: BlockHashOrNumber::Number(block_number), + query_id, + }) + } else { + Ok(Data::default()) + } + }, + ); + let _query_id = db_executor.register_query(query, mock_data_type, sender); tokio::select! { res = db_executor.next() => { @@ -148,7 +161,12 @@ async fn header_db_executor_can_receive_queries_after_stream_is_exhausted() { limit: NUM_OF_BLOCKS, step: 1, }; - let query_id = db_executor.register_query(query, sender); + let mut mock_data_type = MockFetchBlockDataFromDb::new(); + mock_data_type + .expect_fetch_block_data_from_db() + .times(NUM_OF_BLOCKS as usize) + .returning(|_, _, _| Ok(Data::default())); + let query_id = db_executor.register_query(query, mock_data_type, sender); // run the executor and collect query results. receiver.collect::<Vec<_>>().await; @@ -183,7 +201,7 @@ async fn header_db_executor_drop_receiver_before_query_is_done() { drop(receiver); // register a query. - let _query_id = db_executor.register_query(query, sender); + let _query_id = db_executor.register_query(query, MockFetchBlockDataFromDb::new(), sender); // executor should return an error. 
let res = db_executor.next().await; diff --git a/crates/papyrus_network/src/network_manager/mod.rs b/crates/papyrus_network/src/network_manager/mod.rs index f17aaf1f90..1f4ae69bbd 100644 --- a/crates/papyrus_network/src/network_manager/mod.rs +++ b/crates/papyrus_network/src/network_manager/mod.rs @@ -22,7 +22,7 @@ use crate::db_executor::{self, BlockHeaderDBExecutor, DBExecutor, Data, QueryId} use crate::protobuf_messages::protobuf; use crate::streamed_bytes::behaviour::{Behaviour, SessionError}; use crate::streamed_bytes::{Config, GenericEvent, InboundSessionId}; -use crate::{NetworkConfig, PeerAddressConfig, Protocol, Query, ResponseReceivers}; +use crate::{DataType, NetworkConfig, PeerAddressConfig, Protocol, Query, ResponseReceivers}; type StreamCollection = SelectAll<BoxStream<'static, (Data, InboundSessionId)>>; type SubscriberChannels = (Receiver<Query>, Router); @@ -170,7 +171,11 @@ impl GenericNetworkManager - fn register_query(&mut self, query: InternalQuery, mut sender: Sender<Data>) -> QueryId { + fn register_query( + &mut self, + query: InternalQuery, + _data_type: impl FetchBlockDataFromDb + Send, + mut sender: Sender<Data>, + ) -> QueryId { let query_id = QueryId(self.next_query_id); self.next_query_id += 1; let headers = self.query_to_headers.get(&query).unwrap().clone();