diff --git a/consensus/src/dag/dag_driver.rs b/consensus/src/dag/dag_driver.rs index 682602be07f5c..01b33dedb5a17 100644 --- a/consensus/src/dag/dag_driver.rs +++ b/consensus/src/dag/dag_driver.rs @@ -37,9 +37,9 @@ use aptos_validator_transaction_pool as vtxn_pool; use async_trait::async_trait; use futures::{ executor::block_on, - future::{AbortHandle, Abortable}, - FutureExt, + future::{join, AbortHandle, Abortable}, }; +use futures_channel::oneshot; use std::{collections::HashSet, sync::Arc, time::Duration}; use tokio_retry::strategy::ExponentialBackoff; @@ -284,8 +284,9 @@ impl DagDriver { let rb = self.reliable_broadcast.clone(); let rb2 = self.reliable_broadcast.clone(); let (abort_handle, abort_registration) = AbortHandle::new_pair(); + let (tx, rx) = oneshot::channel(); let signature_builder = - SignatureBuilder::new(node.metadata().clone(), self.epoch_state.clone()); + SignatureBuilder::new(node.metadata().clone(), self.epoch_state.clone(), tx); let cert_ack_set = CertificateAckState::new(self.epoch_state.verifier.len()); let latest_ledger_info = self.ledger_info_provider.clone(); @@ -298,7 +299,12 @@ impl DagDriver { defer!( observe_round(timestamp, RoundStage::NodeBroadcasted); ); rb.broadcast(node, signature_builder).await }; - let core_task = node_broadcast.then(move |certificate| { + let certified_broadcast = async move { + let Ok(certificate) = rx.await else { + error!("channel closed before receiving ceritifcate"); + return; + }; + debug!( LogSchema::new(LogEvent::BroadcastCertifiedNode), id = node_clone.id() @@ -311,8 +317,9 @@ impl DagDriver { certified_node, latest_ledger_info.get_latest_ledger_info(), ); - rb2.broadcast(certified_node_msg, cert_ack_set) - }); + rb2.broadcast(certified_node_msg, cert_ack_set).await + }; + let core_task = join(node_broadcast, certified_broadcast); let author = self.author; let task = async move { debug!("{} Start reliable broadcast for round {}", author, round); diff --git a/consensus/src/dag/dag_store.rs b/consensus/src/dag/dag_store.rs index 7c9bf95d713d7..4023830412b92 100644 --- a/consensus/src/dag/dag_store.rs +++ b/consensus/src/dag/dag_store.rs @@ -1,7 +1,10 @@ // Copyright © Aptos Foundation // SPDX-License-Identifier: Apache-2.0 -use super::types::{DagSnapshotBitmask, NodeMetadata}; +use super::{ + types::{DagSnapshotBitmask, NodeMetadata}, + Node, +}; use crate::{ dag::{ storage::DAGStorage, @@ -23,19 +26,23 @@ use std::{ #[derive(Clone)] pub enum NodeStatus { - Unordered(Arc), + Unordered { + node: Arc, + aggregated_weak_voting_power: u128, + aggregated_strong_voting_power: u128, + }, Ordered(Arc), } impl NodeStatus { pub fn as_node(&self) -> &Arc { match self { - NodeStatus::Unordered(node) | NodeStatus::Ordered(node) => node, + NodeStatus::Unordered { node, .. } | NodeStatus::Ordered(node) => node, } } pub fn mark_as_ordered(&mut self) { - assert!(matches!(self, NodeStatus::Unordered(_))); + assert!(matches!(self, NodeStatus::Unordered { .. })); *self = NodeStatus::Ordered(self.as_node().clone()); } } @@ -107,7 +114,12 @@ impl InMemDag { .get_node_ref_mut(node.round(), node.author()) .expect("must be present"); ensure!(round_ref.is_none(), "race during insertion"); - *round_ref = Some(NodeStatus::Unordered(node.clone())); + *round_ref = Some(NodeStatus::Unordered { + node: node.clone(), + aggregated_weak_voting_power: 0, + aggregated_strong_voting_power: 0, + }); + self.update_votes(&node, true); Ok(()) } @@ -149,6 +161,39 @@ impl InMemDag { Ok(()) } + pub fn update_votes(&mut self, node: &Node, update_link_power: bool) { + if node.round() <= self.lowest_round() { + return; + } + + let voting_power = self + .epoch_state + .verifier + .get_voting_power(node.author()) + .expect("must exist"); + + for parent in node.parents_metadata() { + let node_status = self + .get_node_ref_mut(parent.round(), parent.author()) + .expect("must exist"); + match node_status { + Some(NodeStatus::Unordered { + aggregated_weak_voting_power, + aggregated_strong_voting_power, + .. + }) => { + if update_link_power { + *aggregated_strong_voting_power += voting_power as u128; + } else { + *aggregated_weak_voting_power += voting_power as u128; + } + }, + Some(NodeStatus::Ordered(_)) => {}, + None => unreachable!("parents must exist before voting for a node"), + } + } + } + pub fn exists(&self, metadata: &NodeMetadata) -> bool { self.get_node_ref_by_metadata(metadata).is_some() } @@ -211,24 +256,29 @@ impl InMemDag { .map(|node_status| node_status.as_node()) } - // TODO: I think we can cache votes in the NodeStatus::Unordered pub fn check_votes_for_node( &self, metadata: &NodeMetadata, validator_verifier: &ValidatorVerifier, ) -> bool { - self.get_round_iter(metadata.round() + 1) - .map(|next_round_iter| { - let votes = next_round_iter - .filter(|node_status| { - node_status - .as_node() - .parents() - .iter() - .any(|cert| cert.metadata() == metadata) - }) - .map(|node_status| node_status.as_node().author()); - validator_verifier.check_voting_power(votes, false).is_ok() + self.get_node_ref_by_metadata(metadata) + .map(|node_status| match node_status { + NodeStatus::Unordered { + aggregated_weak_voting_power, + aggregated_strong_voting_power, + .. + } => { + validator_verifier + .check_aggregated_voting_power(*aggregated_weak_voting_power, true) + .is_ok() + || validator_verifier + .check_aggregated_voting_power(*aggregated_strong_voting_power, false) + .is_ok() + }, + NodeStatus::Ordered(_) => { + error!("checking voting power for Ordered node"); + true + }, }) .unwrap_or(false) } @@ -260,7 +310,7 @@ impl InMemDag { .flat_map(|(_, round_ref)| round_ref.iter_mut()) .flatten() .filter(move |node_status| { - matches!(node_status, NodeStatus::Unordered(_)) + matches!(node_status, NodeStatus::Unordered { .. }) && reachable_filter(node_status.as_node()) }) } diff --git a/consensus/src/dag/order_rule.rs b/consensus/src/dag/order_rule.rs index 839e53f2a3d83..3f3a7aea00f66 100644 --- a/consensus/src/dag/order_rule.rs +++ b/consensus/src/dag/order_rule.rs @@ -136,7 +136,7 @@ impl OrderRule { .reachable( Some(current_anchor.metadata().clone()).iter(), Some(*self.lowest_unordered_anchor_round.read()), - |node_status| matches!(node_status, NodeStatus::Unordered(_)), + |node_status| matches!(node_status, NodeStatus::Unordered { .. }), ) // skip the current anchor itself .skip(1) diff --git a/consensus/src/dag/rb_handler.rs b/consensus/src/dag/rb_handler.rs index 2d3f307030d24..a484b75d6c0b9 100644 --- a/consensus/src/dag/rb_handler.rs +++ b/consensus/src/dag/rb_handler.rs @@ -104,6 +104,13 @@ impl NodeBroadcastHandler { } fn validate(&self, node: Node) -> anyhow::Result { + ensure!( + node.epoch() == self.epoch_state.epoch, + "different epoch {}, current {}", + node.epoch(), + self.epoch_state.epoch + ); + let num_vtxns = node.validator_txns().len() as u64; ensure!(num_vtxns <= self.vtxn_config.per_block_limit_txn_count()); for vtxn in node.validator_txns() { @@ -239,6 +246,8 @@ impl RpcHandler for NodeBroadcastHandler { .expect("must exist") .insert(*node.author(), vote.clone()); + self.dag.write().update_votes(&node, false); + debug!(LogSchema::new(LogEvent::Vote) .remote_peer(*node.author()) .round(node.round())); diff --git a/consensus/src/dag/tests/helpers.rs b/consensus/src/dag/tests/helpers.rs index d401968ff2664..7b651ac2d2bb0 100644 --- a/consensus/src/dag/tests/helpers.rs +++ b/consensus/src/dag/tests/helpers.rs @@ -40,7 +40,7 @@ pub(crate) fn new_node( parents: Vec, ) -> Node { Node::new( - 0, + 1, round, author, timestamp, diff --git a/consensus/src/dag/tests/rb_handler_tests.rs b/consensus/src/dag/tests/rb_handler_tests.rs index 11938cd7fd19d..9713f7e31f073 100644 --- a/consensus/src/dag/tests/rb_handler_tests.rs +++ b/consensus/src/dag/tests/rb_handler_tests.rs @@ -223,7 +223,7 @@ async fn test_node_broadcast_receiver_storage() { let sig = rb_receiver.process(node).await.expect("must succeed"); assert_ok_eq!(storage.get_votes(), vec![( - NodeId::new(0, 1, signers[0].author()), + NodeId::new(1, 1, signers[0].author()), sig )],); diff --git a/consensus/src/dag/types.rs b/consensus/src/dag/types.rs index a25248aecff41..4f677bf193961 100644 --- a/consensus/src/dag/types.rs +++ b/consensus/src/dag/types.rs @@ -30,12 +30,13 @@ use aptos_types::{ validator_txn::ValidatorTransaction, validator_verifier::ValidatorVerifier, }; +use futures_channel::oneshot; use serde::{Deserialize, Serialize}; use std::{ cmp::min, collections::HashSet, fmt::{Display, Formatter}, - ops::Deref, + ops::{Deref, DerefMut}, sync::Arc, }; @@ -343,7 +344,7 @@ impl Node { } } -#[derive(Serialize, Deserialize, PartialEq, Debug, Eq, Hash, Clone)] +#[derive(Serialize, Deserialize, PartialEq, Debug, Eq, Hash, Clone, PartialOrd, Ord)] pub struct NodeId { epoch: u64, round: Round, @@ -534,47 +535,65 @@ impl TryFrom for Vote { pub struct SignatureBuilder { metadata: NodeMetadata, - partial_signatures: Mutex, + inner: Mutex<(PartialSignatures, Option>)>, epoch_state: Arc, } impl SignatureBuilder { - pub fn new(metadata: NodeMetadata, epoch_state: Arc) -> Arc { + pub fn new( + metadata: NodeMetadata, + epoch_state: Arc, + tx: oneshot::Sender, + ) -> Arc { Arc::new(Self { metadata, - partial_signatures: Mutex::new(PartialSignatures::empty()), + inner: Mutex::new((PartialSignatures::empty(), Some(tx))), epoch_state, }) } } impl BroadcastStatus for Arc { - type Aggregated = NodeCertificate; + type Aggregated = (); type Message = Node; type Response = Vote; + /// Processes the [Vote]s received for a given [Node]. Once a supermajority voting power + /// is reached, this method sends [NodeCertificate] into a channel. It will only return + /// successfully when [Vote]s are received from all the peers. fn add(&self, peer: Author, ack: Self::Response) -> anyhow::Result> { ensure!(self.metadata == ack.metadata, "Digest mismatch"); ack.verify(peer, &self.epoch_state.verifier)?; debug!(LogSchema::new(LogEvent::ReceiveVote) .remote_peer(peer) .round(self.metadata.round())); - let mut signatures_lock = self.partial_signatures.lock(); - signatures_lock.add_signature(peer, ack.signature); - Ok(self - .epoch_state - .verifier - .check_voting_power(signatures_lock.signatures().keys(), true) - .ok() - .map(|_| { - let aggregated_signature = self - .epoch_state - .verifier - .aggregate_signatures(&signatures_lock) - .expect("Signature aggregation should succeed"); - observe_node(self.metadata.timestamp(), NodeStage::CertAggregated); - NodeCertificate::new(self.metadata.clone(), aggregated_signature) - })) + let mut guard = self.inner.lock(); + let (partial_signatures, tx) = guard.deref_mut(); + partial_signatures.add_signature(peer, ack.signature); + + if tx.is_some() + && self + .epoch_state + .verifier + .check_voting_power(partial_signatures.signatures().keys(), true) + .is_ok() + { + let aggregated_signature = self + .epoch_state + .verifier + .aggregate_signatures(partial_signatures) + .expect("Signature aggregation should succeed"); + observe_node(self.metadata.timestamp(), NodeStage::CertAggregated); + let certificate = NodeCertificate::new(self.metadata.clone(), aggregated_signature); + + _ = tx.take().expect("must exist").send(certificate); + } + + if partial_signatures.signatures().len() == self.epoch_state.verifier.len() { + Ok(Some(())) + } else { + Ok(None) + } } } diff --git a/types/src/validator_verifier.rs b/types/src/validator_verifier.rs index 5e48ea057a4a7..f035e3bb0bf02 100644 --- a/types/src/validator_verifier.rs +++ b/types/src/validator_verifier.rs @@ -381,7 +381,14 @@ impl ValidatorVerifier { check_super_majority: bool, ) -> std::result::Result { let aggregated_voting_power = self.sum_voting_power(authors)?; + self.check_aggregated_voting_power(aggregated_voting_power, check_super_majority) + } + pub fn check_aggregated_voting_power( + &self, + aggregated_voting_power: u128, + check_super_majority: bool, + ) -> std::result::Result { let target = if check_super_majority { self.quorum_voting_power } else {