Skip to content

Commit

Permalink
feat: multiproof proof verification
Browse files Browse the repository at this point in the history
  • Loading branch information
cchudant committed Sep 28, 2024
1 parent f79dc23 commit 95da8db
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 105 deletions.
3 changes: 3 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ where
Database(DatabaseError),
/// Error when decoding a node
NodeDecodeError(parity_scale_codec::Error),
/// Error when creating a storage proof.
CreateProof(String),
}

impl<DatabaseError: DBError> core::convert::From<DatabaseError>
Expand Down Expand Up @@ -52,6 +54,7 @@ where
BonsaiStorageError::Merge(e) => write!(f, "Merge error: {}", e),
BonsaiStorageError::Database(e) => write!(f, "Database error: {}", e),
BonsaiStorageError::NodeDecodeError(e) => write!(f, "Node decode error: {}", e),
BonsaiStorageError::CreateProof(e) => write!(f, "Proof creation error: {}", e),
}
}
}
8 changes: 7 additions & 1 deletion src/trie/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,13 @@ impl<H: StarkHash> NodeVisitor<H> for NoopVisitor<H> {
pub struct MerkleTreeIterator<'a, H: StarkHash, DB: BonsaiDatabase, ID: Id> {
pub(crate) tree: &'a mut MerkleTree<H>,
pub(crate) db: &'a KeyValueDB<DB, ID>,
/// Current iteration path.
pub(crate) current_path: Path,
/// The loaded nodes in the current path with their corresponding heights. Height is at the base of the node, meaning
/// the first node here will always have height 0.
pub(crate) current_nodes_heights: Vec<(NodeKey, usize)>,
/// Current leaf hash. Note that partial traversal (traversal that stops midway through the tree) will
/// also update this field if an exact match for the key is found, even though we may not have reached a leaf.
pub(crate) leaf_hash: Option<Felt>,
}

Expand All @@ -43,7 +48,8 @@ impl<'a, H: StarkHash, DB: BonsaiDatabase, ID: Id> fmt::Debug
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("MerkleTreeIterator")
.field("cur_path", &self.current_path)
.field("cur_path_nodes_heights", &self.current_nodes_heights)
.field("current_nodes_heights", &self.current_nodes_heights)
.field("leaf_hash", &self.leaf_hash)
.finish()
}
}
Expand Down
244 changes: 140 additions & 104 deletions src/trie/proof.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use core::marker::PhantomData;

use super::{path::Path, tree::MerkleTree};
use super::{
merkle_node::{hash_binary_node, hash_edge_node, Direction},
path::Path,
tree::MerkleTree,
};
use crate::{
id::Id,
key_value_db::KeyValueDB,
Expand All @@ -9,17 +13,32 @@ use crate::{
merkle_node::{Node, NodeHandle},
tree::NodeKey,
},
BitSlice, BonsaiDatabase, BonsaiStorageError, HashMap,
BitSlice, BitVec, BonsaiDatabase, BonsaiStorageError, HashMap, HashSet,
};
use bitvec::view::BitView;
use hashbrown::hash_set;
use starknet_types_core::{felt::Felt, hash::StarkHash};

#[derive(Debug, PartialEq, Eq)]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum Membership {
Member,
NonMember,
}

impl From<Membership> for bool {
fn from(value: Membership) -> Self {
value == Membership::Member
}
}

impl From<bool> for Membership {
fn from(value: bool) -> Self {
match value {
true => Self::Member,
false => Self::NonMember,
}
}
}

#[derive(Debug, Clone, PartialEq)]
pub enum ProofNode {
Binary { left: Felt, right: Felt },
Expand All @@ -29,27 +48,103 @@ pub enum ProofNode {
impl ProofNode {
pub fn hash<H: StarkHash>(&self) -> Felt {
match self {
ProofNode::Binary { left, right } => H::hash(left, right),
ProofNode::Edge { child, path } => {
let mut bytes = [0u8; 32];
bytes.view_bits_mut()[256 - path.0.len()..].copy_from_bitslice(&path.0);
// SAFETY: path len is <= 251
let path_hash = Felt::from_bytes_be(&bytes);

let length = Felt::from(path.0.len() as u8);
H::hash(child, &path_hash) + length
}
ProofNode::Binary { left, right } => hash_binary_node::<H>(*left, *right),
ProofNode::Edge { child, path } => hash_edge_node::<H>(path, *child),
}
}
}

#[derive(Debug, Clone)]
pub struct MultiProof(pub HashMap<Felt, ProofNode>);
impl MultiProof {
/// If the proof proves more than just the provided `key_values`, this function will not fail.
/// Not the most optimized way of doing it, but we don't actually need to verify proofs in madara.
/// As such, it has not been properly proptested.
pub fn verify_proof<'a, 'b: 'a, H: StarkHash>(
&'b self,
root: Felt,
key_values: impl IntoIterator<Item = (impl AsRef<BitSlice>, Felt)> + 'a,
) -> impl Iterator<Item = Membership> + 'a {
let mut checked_cache: HashSet<Felt> = Default::default();
key_values.into_iter().map(move |(k, v)| {
let k = k.as_ref();

// todo: find a way to disable this check in non-251bit key tests.
// if key.len() != 251 {
// return Err(BonsaiStorageError::CreateProof(format!("Key {key:b} is not the correct length.")));
// }

// Go down the tree, starting from the root.
let mut current_path = BitVec::with_capacity(251);
let mut current_felt = root;

loop {
log::trace!("Start verify loop: {current_path:b} => {current_felt:#x}");
if current_path.len() == k.len() {
// End of traversal, check if value is correct
log::trace!("End of traversal");
break (v == current_felt).into();
}
if current_path.len() > k.len() {
// We overshot.
log::trace!("Overshot");
break Membership::NonMember;
}
let Some(node) = self.0.get(&current_felt) else {
// Missing node.
log::trace!("Missing");
break Membership::NonMember;
};

// Check hash and save to verification cache.
if let hash_set::Entry::Vacant(entry) = checked_cache.entry(v) {
let computed_hash = node.hash::<H>();
if computed_hash != current_felt {
// Hash mismatch.
log::trace!("Hash mismatch: {computed_hash:#x} {current_felt:#x}");
break Membership::NonMember;
}
entry.insert();
}

match node {
ProofNode::Binary { left, right } => {
// PANIC: We checked above that current_path.len() < k.len().
let direction = Direction::from(k[current_path.len()]);
log::trace!("Binary {direction:?}");
current_path.push(direction.into());
current_felt = match direction {
Direction::Left => *left,
Direction::Right => *right,
}
}
ProofNode::Edge { child, path } => {
log::trace!("Edge");
if k.get(current_path.len()..(current_path.len() + path.len()))
!= Some(&path.0)
{
log::trace!("Wrong edge: {path:?}");
// Wrong edge path.
break Membership::NonMember;
}
current_path.extend_from_bitslice(&path.0);
current_felt = *child;
}
}
}
})
}
}

impl<H: StarkHash + Send + Sync> MerkleTree<H> {
/// This function is designed to be very efficient if the `keys` are sorted - this allows for
/// the minimal amount of backtracking when switching from one key to the next.
pub fn get_multi_proof<DB: BonsaiDatabase, ID: Id>(
&mut self,
db: &KeyValueDB<DB, ID>,
keys: impl IntoIterator<Item = impl AsRef<BitSlice>>,
) -> Result<HashMap<Felt, ProofNode>, BonsaiStorageError<DB::DatabaseError>> {
struct ProofVisitor<H>(HashMap<Felt, ProofNode>, PhantomData<H>);
) -> Result<MultiProof, BonsaiStorageError<DB::DatabaseError>> {
struct ProofVisitor<H>(MultiProof, PhantomData<H>);
impl<H: StarkHash + Send + Sync> NodeVisitor<H> for ProofVisitor<H> {
fn visit_node<DB: BonsaiDatabase>(
&mut self,
Expand All @@ -74,100 +169,28 @@ impl<H: StarkHash + Send + Sync> MerkleTree<H> {
}
};
let hash = tree.get_or_compute_node_hash::<DB>(NodeHandle::InMemory(node_id))?;
self.0.insert(hash, proof_node);
self.0 .0.insert(hash, proof_node);
Ok(())
}
}
let mut visitor = ProofVisitor::<H>(Default::default(), PhantomData);
let mut visitor = ProofVisitor::<H>(MultiProof(Default::default()), PhantomData);

let mut iter = self.iter(db);
for key in keys {
iter.traverse_to(&mut visitor, key.as_ref())?;
let key = key.as_ref();
// todo: find a way to disable this check in non-251bit key tests.
// if key.len() != 251 {
// return Err(BonsaiStorageError::CreateProof(format!("Key {key:b} is not the correct length.")));
// }
iter.traverse_to(&mut visitor, key)?;
// We should have found a leaf here.
iter.leaf_hash.ok_or_else(|| {
BonsaiStorageError::CreateProof(format!("Key {key:b} is not in the trie."))
})?;
}

Ok(visitor.0)
}

/// Function that come from pathfinder_merkle_tree::merkle_tree::MerkleTree
/// Verifies that the key `key` with value `value` is indeed part of the MPT that has root
/// `root`, given `proofs`.
/// Supports proofs of non-membership as well as proof of membership: this function returns
/// an enum corresponding to the membership of `value`, or returns `None` in case of a hash mismatch.
/// The algorithm follows this logic:
/// 1. init expected_hash <- root hash
/// 2. loop over nodes: current <- nodes[i]
/// 1. verify the current node's hash matches expected_hash (if not then we have a bad proof)
/// 2. move towards the target - if current is:
/// 1. binary node then choose the child that moves towards the target, else if
/// 2. edge node then check the path against the target bits
/// 1. If it matches then proceed with the child, else
/// 2. if it does not match then we now have a proof that the target does not exist
/// 3. nibble off target bits according to which child you got in (2). If all bits are gone then you
/// have reached the target and the child hash is the value you wanted and the proof is complete.
/// 4. set expected_hash <- to the child hash
/// 3. check that the expected_hash is `value` (we should've reached the leaf)
pub fn verify_proof(
_root: Felt,
_key: &BitSlice,
_value: Felt,
_proofs: &[ProofNode],
) -> Option<Membership> {
todo!()
// Protect from ill-formed keys
// if key.len() > 251 {
// return None;
// }

// let mut expected_hash = root;
// let mut remaining_path: &BitSlice = key;

// for proof_node in proofs.iter() {
// // Hash mismatch? Return None.
// if proof_node.hash::<H>() != expected_hash {
// return None;
// }
// match proof_node {
// ProofNode::Binary { left, right } => {
// // Direction will always correspond to the 0th index
// // because we're removing bits on every iteration.
// let direction = Direction::from(remaining_path[0]);

// // Set the next hash to be the left or right hash,
// // depending on the direction
// expected_hash = match direction {
// Direction::Left => *left,
// Direction::Right => *right,
// };

// // Advance by a single bit
// remaining_path = &remaining_path[1..];
// }
// ProofNode::Edge { child, path } => {
// if path.0 != remaining_path[..path.0.len()] {
// // If paths don't match, we've found a proof of non membership because we:
// // 1. Correctly moved towards the target insofar as is possible, and
// // 2. hashing all the nodes along the path does result in the root hash, which means
// // 3. the target definitely does not exist in this tree
// return Some(Membership::NonMember);
// }

// // Set the next hash to the child's hash
// expected_hash = *child;

// // Advance by the whole edge path
// remaining_path = &remaining_path[path.0.len()..];
// }
// }
// }

// // At this point, we should reach `value` !
// if expected_hash == value {
// Some(Membership::Member)
// } else {
// // Hash mismatch. Return `None`.
// None
// }
}
}

#[cfg(test)]
Expand Down Expand Up @@ -218,13 +241,26 @@ mod tests {
.get_mut(&smallvec::smallvec![])
.unwrap();

let proof = tree.get_multi_proof(&bonsai_storage.tries.db, [
bits![u8, Msb0; 0,0,0,1,0,0,0,1],
bits![u8, Msb0; 0,1,0,0,0,0,0,0],
])
let proof = tree
.get_multi_proof(
&bonsai_storage.tries.db,
[
bits![u8, Msb0; 0,0,0,1,0,0,0,1],
bits![u8, Msb0; 0,1,0,0,0,0,0,0],
],
)
.unwrap();

log::trace!("proof: {proof:?}");
assert_eq!(
proof
.verify_proof::<Pedersen>(
tree.root_hash(&bonsai_storage.tries.db).unwrap(),
[(bits![u8, Msb0; 0,0,0,1,0,0,0,0], ONE)]
)
.all(|v| v.into()),
true
);
todo!()
}
}

0 comments on commit 95da8db

Please sign in to comment.