From 55482237672b096849b81fd0ca6ee6747860392e Mon Sep 17 00:00:00 2001
From: Alexandru Vasile <60601340+lexnv@users.noreply.github.com>
Date: Thu, 16 May 2024 14:50:20 +0300
Subject: [PATCH] kad: Refactor `GetRecord` query and add tests (#97)

This PR refactors the `GetRecord` query to make it more robust.

- Avoid panics on unwraps and unimplemented logic
- Separate the immutable config variables from the main query logic
- Simplify the logic of the `next_action` method
- Move internal logic to private methods

Builds on top of: https://github.com/paritytech/litep2p/pull/96

cc @paritytech/networking

---------

Signed-off-by: Alexandru Vasile
---
 src/protocol/libp2p/kademlia/mod.rs           |   3 +-
 .../libp2p/kademlia/query/get_record.rs       | 432 ++++++++++++++----
 src/protocol/libp2p/kademlia/query/mod.rs     |  29 +-
 3 files changed, 367 insertions(+), 97 deletions(-)

diff --git a/src/protocol/libp2p/kademlia/mod.rs b/src/protocol/libp2p/kademlia/mod.rs
index 78f5c6a6..9a885ead 100644
--- a/src/protocol/libp2p/kademlia/mod.rs
+++ b/src/protocol/libp2p/kademlia/mod.rs
@@ -47,13 +47,12 @@ use tokio::sync::mpsc::{Receiver, Sender};
 
 use std::collections::{hash_map::Entry, HashMap};
 
+pub use self::handle::RecordsType;
 pub use config::{Config, ConfigBuilder};
 pub use handle::{KademliaEvent, KademliaHandle, Quorum, RoutingTableUpdateMode};
 pub use query::QueryId;
 pub use record::{Key as RecordKey, PeerRecord, Record};
 
-pub use self::handle::RecordsType;
-
 /// Logging target for the file.
 const LOG_TARGET: &str = "litep2p::ipfs::kademlia";
 
diff --git a/src/protocol/libp2p/kademlia/query/get_record.rs b/src/protocol/libp2p/kademlia/query/get_record.rs
index 4d766af7..fb5004ca 100644
--- a/src/protocol/libp2p/kademlia/query/get_record.rs
+++ b/src/protocol/libp2p/kademlia/query/get_record.rs
@@ -18,7 +18,7 @@
 // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 // DEALINGS IN THE SOFTWARE.
 
-#![allow(unused)]
+use bytes::Bytes;
 
 use crate::{
     protocol::libp2p::kademlia::{
@@ -36,22 +36,57 @@ use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
 /// Logging target for the file.
 const LOG_TARGET: &str = "litep2p::ipfs::kademlia::query::get_record";
 
+/// The configuration needed to instantiate a new [`GetRecordContext`].
 #[derive(Debug)]
-pub struct GetRecordContext {
+pub struct GetRecordConfig {
     /// Local peer ID.
-    local_peer_id: PeerId,
+    pub local_peer_id: PeerId,
 
-    /// How many records have been successfully found.
-    pub record_count: usize,
+    /// How many records we already know about (i.e. extracted from storage).
+    ///
+    /// This can either be 0 or 1 when the record is extracted from local storage.
+    pub known_records: usize,
 
     /// Quorum for the query.
     pub quorum: Quorum,
 
+    /// Replication factor.
+    pub replication_factor: usize,
+
+    /// Parallelism factor.
+    pub parallelism_factor: usize,
+
     /// Query ID.
     pub query: QueryId,
 
     /// Target key.
     pub target: Key,
+}
+
+impl GetRecordConfig {
+    /// Checks if the number of found records meets the specified quorum.
+    ///
+    /// Used to determine if the query found enough records to stop.
+    fn sufficient_records(&self, records: usize) -> bool {
+        // The total number of known records is the sum of the records we knew about before starting
+        // the query and the records we found along the way.
+        let total_known = self.known_records + records;
+
+        match self.quorum {
+            Quorum::All => total_known >= self.replication_factor,
+            Quorum::One => total_known >= 1,
+            Quorum::N(needed_responses) => total_known >= needed_responses.get(),
+        }
+    }
+}
+
+#[derive(Debug)]
+pub struct GetRecordContext {
+    /// Query immutable config.
+    pub config: GetRecordConfig,
+
+    /// Cached Kademlia message to send.
+    kad_message: Bytes,
 
     /// Peers from whom the `QueryEngine` is waiting to hear a response.
     pub pending: HashMap<PeerId, KademliaPeer>,
 
@@ -67,42 +102,25 @@ pub struct GetRecordContext {
 
     /// Found records.
     pub found_records: Vec<PeerRecord>,
-
-    /// Replication factor.
-    pub replication_factor: usize,
-
-    /// Parallelism factor.
-    pub parallelism_factor: usize,
 }
 
 impl GetRecordContext {
     /// Create new [`GetRecordContext`].
-    pub fn new(
-        local_peer_id: PeerId,
-        query: QueryId,
-        target: Key,
-        in_peers: VecDeque<KademliaPeer>,
-        replication_factor: usize,
-        parallelism_factor: usize,
-        quorum: Quorum,
-        record_count: usize,
-    ) -> Self {
+    pub fn new(config: GetRecordConfig, in_peers: VecDeque<KademliaPeer>) -> Self {
        let mut candidates = BTreeMap::new();
 
        for candidate in &in_peers {
-            let distance = target.distance(&candidate.key);
+            let distance = config.target.distance(&candidate.key);
            candidates.insert(distance, candidate.clone());
        }
 
+        let kad_message = KademliaMessage::get_record(config.target.clone().into_preimage());
+
        Self {
-            query,
-            target,
-            quorum,
+            config,
+            kad_message,
            candidates,
-            record_count,
-            local_peer_id,
-            replication_factor,
-            parallelism_factor,
            pending: HashMap::new(),
            queried: HashSet::new(),
            found_records: Vec::new(),
@@ -110,7 +128,7 @@ impl GetRecordContext {
    }
 
    /// Get the found records.
-    pub fn found_records(mut self) -> Vec<PeerRecord> {
+    pub fn found_records(self) -> Vec<PeerRecord> {
        self.found_records
    }
 
@@ -145,21 +163,32 @@ impl GetRecordContext {
        }
    }
 
-        // add the queried peer to `queried` and all new peers which haven't been
+        // Add the queried peer to `queried` and all new peers which haven't been
        // queried to `candidates`
        self.queried.insert(peer.peer);
 
-        for candidate in peers {
-            if !self.queried.contains(&candidate.peer)
-                && !self.pending.contains_key(&candidate.peer)
-            {
-                if self.local_peer_id == candidate.peer {
-                    continue;
-                }
+        let to_query_candidate = peers.into_iter().filter_map(|peer| {
+            // Peer already produced a response.
+            if self.queried.contains(&peer.peer) {
+                return None;
+            }
 
-                let distance = self.target.distance(&candidate.key);
-                self.candidates.insert(distance, candidate);
+            // Peer was queried, awaiting response.
+            if self.pending.contains_key(&peer.peer) {
+                return None;
+            }
+
+            // Local node.
+            if self.config.local_peer_id == peer.peer {
+                return None;
            }
+
+            Some(peer)
+        });
+
+        for candidate in to_query_candidate {
+            let distance = self.config.target.distance(&candidate.key);
+            self.candidates.insert(distance, candidate);
        }
    }
 
    // TODO: remove this and store the next action to `PeerAction`
    pub fn next_peer_action(&mut self, peer: &PeerId) -> Option<QueryAction> {
        self.pending.contains_key(peer).then_some(QueryAction::SendMessage {
-            query: self.query,
+            query: self.config.query,
            peer: *peer,
-            message: KademliaMessage::get_record(self.target.clone().into_preimage()),
+            message: self.kad_message.clone(),
        })
    }
 
    /// Schedule next peer for outbound `GET_VALUE` query.
-    pub fn schedule_next_peer(&mut self) -> QueryAction {
-        tracing::trace!(target: LOG_TARGET, query = ?self.query, "get next peer");
+    fn schedule_next_peer(&mut self) -> Option<QueryAction> {
+        tracing::trace!(target: LOG_TARGET, query = ?self.config.query, "get next peer");
+
+        let (_, candidate) = self.candidates.pop_first()?;
 
-        let (_, candidate) = self.candidates.pop_first().expect("entry to exist");
        let peer = candidate.peer;
 
        tracing::trace!(target: LOG_TARGET, ?peer, "current candidate");
        self.pending.insert(candidate.peer, candidate);
 
-        QueryAction::SendMessage {
-            query: self.query,
+        Some(QueryAction::SendMessage {
+            query: self.config.query,
            peer,
-            message: KademliaMessage::get_record(self.target.clone().into_preimage()),
-        }
+            message: self.kad_message.clone(),
+        })
+    }
+
+    /// Check if the query cannot make any progress.
+    ///
+    /// Returns true when there are no pending responses and no candidates to query.
+    fn is_done(&self) -> bool {
+        self.pending.is_empty() && self.candidates.is_empty()
    }
 
    /// Get next action for a `GET_VALUE` query.
    pub fn next_action(&mut self) -> Option<QueryAction> {
-        // if there are no more peers to query, check if the query succeeded or failed
-        // the status is determined by whether a record was found
-        if self.pending.is_empty() && self.candidates.is_empty() {
-            match self.record_count + self.found_records.len() {
-                0 => return Some(QueryAction::QueryFailed { query: self.query }),
-                _ => return Some(QueryAction::QuerySucceeded { query: self.query }),
-            }
+        // These are the records we knew about before starting the query and
+        // the records we found along the way.
+        let known_records = self.config.known_records + self.found_records.len();
+
+        // If we cannot make progress, return the final result.
+        // The query failed if we could not find a single record.
+        if self.is_done() {
+            return if known_records == 0 {
+                Some(QueryAction::QueryFailed {
+                    query: self.config.query,
+                })
+            } else {
+                Some(QueryAction::QuerySucceeded {
+                    query: self.config.query,
+                })
+            };
        }
 
-        // check if enough records have been found
-        let continue_search = match self.quorum {
-            Quorum::All => (self.record_count + self.found_records.len() < self.replication_factor),
-            Quorum::One => (self.record_count + self.found_records.len() < 1),
-            Quorum::N(num_responses) =>
-                (self.record_count + self.found_records.len() < num_responses.into()),
+        // Check if enough records have been found.
+        let sufficient_records = self.config.sufficient_records(self.found_records.len());
+        if sufficient_records {
+            return Some(QueryAction::QuerySucceeded {
+                query: self.config.query,
+            });
+        }
+
+        // At this point, we either have pending responses or candidates to query, and we
+        // need more records. Ensure we do not exceed the parallelism factor.
+        if self.pending.len() == self.config.parallelism_factor {
+            return None;
+        }
+
+        self.schedule_next_peer()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::protocol::libp2p::kademlia::types::ConnectionType;
+
+    fn default_config() -> GetRecordConfig {
+        GetRecordConfig {
+            local_peer_id: PeerId::random(),
+            quorum: Quorum::All,
+            known_records: 0,
+            replication_factor: 20,
+            parallelism_factor: 10,
+            query: QueryId(0),
+            target: Key::new(vec![1, 2, 3].into()),
+        }
+    }
+
+    fn peer_to_kad(peer: PeerId) -> KademliaPeer {
+        KademliaPeer {
+            peer,
+            key: Key::from(peer),
+            addresses: vec![],
+            connection: ConnectionType::Connected,
+        }
+    }
+
+    #[test]
+    fn config_check() {
+        // Quorum::All with no known records.
+        let config = GetRecordConfig {
+            quorum: Quorum::All,
+            known_records: 0,
+            replication_factor: 20,
+            ..default_config()
+        };
+        assert!(config.sufficient_records(20));
+        assert!(!config.sufficient_records(19));
+
+        // Quorum::All with 1 known record.
+        let config = GetRecordConfig {
+            quorum: Quorum::All,
+            known_records: 1,
+            replication_factor: 20,
+            ..default_config()
+        };
+        assert!(config.sufficient_records(19));
+        assert!(!config.sufficient_records(18));
+
+        // Quorum::One with no known records.
+        let config = GetRecordConfig {
+            quorum: Quorum::One,
+            known_records: 0,
+            ..default_config()
+        };
+        assert!(config.sufficient_records(1));
+        assert!(!config.sufficient_records(0));
+
+        // Quorum::One with one known record.
+        let config = GetRecordConfig {
+            quorum: Quorum::One,
+            known_records: 1,
+            ..default_config()
+        };
+        assert!(config.sufficient_records(1));
+        assert!(config.sufficient_records(0));
+
+        // Quorum::N with no known records.
+        let config = GetRecordConfig {
+            quorum: Quorum::N(std::num::NonZeroUsize::new(10).expect("valid; qed")),
+            known_records: 0,
+            ..default_config()
+        };
+        assert!(config.sufficient_records(10));
+        assert!(!config.sufficient_records(9));
+
+        // Quorum::N with one known record.
+        let config = GetRecordConfig {
+            quorum: Quorum::N(std::num::NonZeroUsize::new(10).expect("valid; qed")),
+            known_records: 1,
+            ..default_config()
+        };
+        assert!(config.sufficient_records(9));
+        assert!(!config.sufficient_records(8));
+    }
+
+    #[test]
+    fn completes_when_no_candidates() {
+        let config = default_config();
+        let mut context = GetRecordContext::new(config, VecDeque::new());
+        assert!(context.is_done());
+        let event = context.next_action().unwrap();
+        assert_eq!(event, QueryAction::QueryFailed { query: QueryId(0) });
+
+        let config = GetRecordConfig {
+            known_records: 1,
+            ..default_config()
+        };
+        let mut context = GetRecordContext::new(config, VecDeque::new());
+        assert!(context.is_done());
+        let event = context.next_action().unwrap();
+        assert_eq!(event, QueryAction::QuerySucceeded { query: QueryId(0) });
+    }
+
+    #[test]
+    fn fulfill_parallelism() {
+        let config = GetRecordConfig {
+            parallelism_factor: 3,
+            ..default_config()
+        };
+
+        let in_peers_set: HashSet<_> =
+            [PeerId::random(), PeerId::random(), PeerId::random()].into_iter().collect();
+        assert_eq!(in_peers_set.len(), 3);
+
+        let in_peers = in_peers_set.iter().map(|peer| peer_to_kad(*peer)).collect();
+        let mut context = GetRecordContext::new(config, in_peers);
+
+        for num in 0..3 {
+            let event = context.next_action().unwrap();
+            match event {
+                QueryAction::SendMessage { query, peer, .. } => {
+                    assert_eq!(query, QueryId(0));
+                    // Added as pending.
+                    assert_eq!(context.pending.len(), num + 1);
+                    assert!(context.pending.contains_key(&peer));
 
-        // if enough replicas for the record have been received (defined by the quorum size),
-        /// mark the query as succeeded
-        if !continue_search {
-            return Some(QueryAction::QuerySucceeded { query: self.query });
+                    // Check the peer is the one provided.
+                    assert!(in_peers_set.contains(&peer));
+                }
+                _ => panic!("Unexpected event"),
+            }
        }
 
-        // if the search must continue, try to schedule next outbound message if possible
-        if !self.pending.is_empty() || !self.candidates.is_empty() {
-            if self.pending.len() == self.parallelism_factor || self.candidates.is_empty() {
-                return None;
            }
 
-            return Some(self.schedule_next_peer());
+        // Fulfilled parallelism.
+        assert!(context.next_action().is_none());
+    }
+
+    #[test]
+    fn completes_when_responses() {
+        let key = vec![1, 2, 3];
+        let config = GetRecordConfig {
+            parallelism_factor: 3,
+            replication_factor: 3,
+            ..default_config()
+        };
+
+        let peer_a = PeerId::random();
+        let peer_b = PeerId::random();
+        let peer_c = PeerId::random();
+
+        let in_peers_set: HashSet<_> = [peer_a, peer_b, peer_c].into_iter().collect();
+        assert_eq!(in_peers_set.len(), 3);
+
+        let in_peers = [peer_a, peer_b, peer_c].iter().map(|peer| peer_to_kad(*peer)).collect();
+        let mut context = GetRecordContext::new(config, in_peers);
+
+        // Schedule peer queries.
+        for num in 0..3 {
+            let event = context.next_action().unwrap();
+            match event {
+                QueryAction::SendMessage { query, peer, .. } => {
+                    assert_eq!(query, QueryId(0));
+                    // Added as pending.
+                    assert_eq!(context.pending.len(), num + 1);
+                    assert!(context.pending.contains_key(&peer));
+
+                    // Check the peer is the one provided.
+                    assert!(in_peers_set.contains(&peer));
+                }
+                _ => panic!("Unexpected event"),
+            }
+        }
+
+        // Check a response failure from a peer that was never queried.
+        let peer_d = PeerId::random();
+        context.register_response_failure(peer_d);
+        assert_eq!(context.pending.len(), 3);
+        assert!(context.queried.is_empty());
+
+        // Provide a response from peer A.
+        let record = Record::new(key.clone(), vec![1, 2, 3]);
+        context.register_response(peer_a, Some(record), vec![]);
+        assert_eq!(context.pending.len(), 2);
+        assert_eq!(context.queried.len(), 1);
+        assert_eq!(context.found_records.len(), 1);
+
+        // Provide a different response from peer B, with peer D as a new candidate.
+        let record = Record::new(key.clone(), vec![4, 5, 6]);
+        context.register_response(peer_b, Some(record), vec![peer_to_kad(peer_d.clone())]);
+        assert_eq!(context.pending.len(), 1);
+        assert_eq!(context.queried.len(), 2);
+        assert_eq!(context.found_records.len(), 2);
+        assert_eq!(context.candidates.len(), 1);
+
+        // Peer C fails.
+        context.register_response_failure(peer_c);
+        assert!(context.pending.is_empty());
+        assert_eq!(context.queried.len(), 3);
+        assert_eq!(context.found_records.len(), 2);
+
+        // Drain the last candidate.
+        let event = context.next_action().unwrap();
+        match event {
+            QueryAction::SendMessage { query, peer, .. } => {
+                assert_eq!(query, QueryId(0));
+                // Added as pending.
+                assert_eq!(context.pending.len(), 1);
+                assert_eq!(peer, peer_d);
+            }
+            _ => panic!("Unexpected event"),
        }
 
-        // TODO: probably not correct
-        tracing::warn!(
-            target: LOG_TARGET,
-            num_pending = ?self.pending.len(),
-            num_candidates = ?self.candidates.len(),
-            num_records = ?(self.record_count + self.found_records.len()),
-            quorum = ?self.quorum,
-            ?continue_search,
-            "unreachable condition for `GET_VALUE` search"
+        // Peer D responds.
+        let record = Record::new(key.clone(), vec![4, 5, 6]);
+        context.register_response(peer_d, Some(record), vec![]);
+
+        // Produces the result.
+        let event = context.next_action().unwrap();
+        assert_eq!(event, QueryAction::QuerySucceeded { query: QueryId(0) });
+
+        // Check results.
+ let found_records = context.found_records(); + assert_eq!( + found_records, + vec![ + PeerRecord { + peer: peer_a, + record: Record::new(key.clone(), vec![1, 2, 3]), + }, + PeerRecord { + peer: peer_b, + record: Record::new(key.clone(), vec![4, 5, 6]), + }, + PeerRecord { + peer: peer_d, + record: Record::new(key.clone(), vec![4, 5, 6]), + }, + ] ); - - unreachable!(); } } diff --git a/src/protocol/libp2p/kademlia/query/mod.rs b/src/protocol/libp2p/kademlia/query/mod.rs index 6177163d..f0287325 100644 --- a/src/protocol/libp2p/kademlia/query/mod.rs +++ b/src/protocol/libp2p/kademlia/query/mod.rs @@ -21,7 +21,10 @@ use crate::{ protocol::libp2p::kademlia::{ message::KademliaMessage, - query::{find_node::FindNodeContext, get_record::GetRecordContext}, + query::{ + find_node::FindNodeContext, + get_record::{GetRecordConfig, GetRecordContext}, + }, record::{Key as RecordKey, Record}, types::{KademliaPeer, Key}, PeerRecord, Quorum, @@ -83,7 +86,7 @@ enum QueryType { } /// Query action. -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] pub enum QueryAction { /// Send message to peer. SendMessage { @@ -283,20 +286,20 @@ impl QueryEngine { ); let target = Key::new(target); + let config = GetRecordConfig { + local_peer_id: self.local_peer_id, + known_records: count, + quorum, + replication_factor: self.replication_factor, + parallelism_factor: self.parallelism_factor, + query: query_id, + target, + }; self.queries.insert( query_id, QueryType::GetRecord { - context: GetRecordContext::new( - self.local_peer_id, - query_id, - target, - candidates, - self.replication_factor, - self.parallelism_factor, - quorum, - count, - ), + context: GetRecordContext::new(config, candidates), }, ); @@ -395,7 +398,7 @@ impl QueryEngine { peers: context.peers_to_report, }, QueryType::GetRecord { context } => QueryAction::GetRecordQueryDone { - query_id: context.query, + query_id: context.config.query, records: context.found_records(), }, }
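Note for reviewers: the quorum arithmetic introduced in `GetRecordConfig::sufficient_records`
is pure and can be sanity-checked outside the crate. The following is a minimal standalone
sketch, not code from this patch: the `Quorum` enum and the free function below are
re-declared locally for illustration and only mirror the logic added above.

use std::num::NonZeroUsize;

/// Local stand-in for `kademlia::Quorum`, re-declared here for illustration only.
enum Quorum {
    All,
    One,
    N(NonZeroUsize),
}

/// Mirrors the logic of `GetRecordConfig::sufficient_records`: the query may stop
/// once the records known before the query plus the records found along the way
/// meet the quorum.
fn sufficient_records(
    quorum: &Quorum,
    replication_factor: usize,
    known_records: usize,
    found: usize,
) -> bool {
    let total_known = known_records + found;
    match quorum {
        Quorum::All => total_known >= replication_factor,
        Quorum::One => total_known >= 1,
        Quorum::N(needed) => total_known >= needed.get(),
    }
}

fn main() {
    // With one record already in local storage and a replication factor of 20,
    // Quorum::All is met after 19 more records are found (cf. `config_check`).
    assert!(sufficient_records(&Quorum::All, 20, 1, 19));
    assert!(!sufficient_records(&Quorum::All, 20, 1, 18));
    // Quorum::One is met by the first record.
    assert!(sufficient_records(&Quorum::One, 20, 0, 1));
    // Quorum::N(10) needs ten records in total.
    assert!(sufficient_records(&Quorum::N(NonZeroUsize::new(10).unwrap()), 20, 0, 10));
}

The design point is that `known_records` (0 or 1, depending on whether local storage already
held the record) counts toward the quorum, which is why the tests above expect
`sufficient_records(19)` to pass under `Quorum::All` when one record is already known.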