diff --git a/Cargo.toml b/Cargo.toml index 4ab77d3..85e7393 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,8 @@ derive_more = { version = "0.99.17", default-features = false, features = [ ] } hashbrown = "0.14.3" log = "0.4.20" +smallvec = "1.11.2" + parity-scale-codec = { version = "3.0.0", default-features = false, features = [ "derive", ] } @@ -44,3 +46,4 @@ pathfinder-merkle-tree = { git = "https://github.com/massalabs/pathfinder.git", pathfinder-storage = { git = "https://github.com/massalabs/pathfinder.git", package = "pathfinder-storage", rev = "b7b6d76a76ab0e10f92e5f84ce099b5f727cb4db" } rand = "0.8.5" tempfile = "3.8.0" +rstest = "0.18.2" diff --git a/ensure_no_std/Cargo.lock b/ensure_no_std/Cargo.lock index 05af383..2861af3 100644 --- a/ensure_no_std/Cargo.lock +++ b/ensure_no_std/Cargo.lock @@ -63,6 +63,7 @@ dependencies = [ "log", "parity-scale-codec", "serde", + "smallvec", "starknet-types-core", ] @@ -379,6 +380,12 @@ dependencies = [ "keccak", ] +[[package]] +name = "smallvec" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" + [[package]] name = "spin" version = "0.5.2" diff --git a/src/bonsai_database.rs b/src/bonsai_database.rs index 0d49152..39b07ae 100644 --- a/src/bonsai_database.rs +++ b/src/bonsai_database.rs @@ -1,62 +1,62 @@ -use crate::{changes::ChangeKeyType, error::BonsaiStorageError, id::Id}; +use crate::id::Id; #[cfg(not(feature = "std"))] use alloc::vec::Vec; +#[cfg(feature = "std")] +use std::error::Error; +/// Key in the database of the different elements that can be stored in the database. #[derive(Debug, Hash, PartialEq, Eq)] -pub enum KeyType<'a> { +pub enum DatabaseKey<'a> { Trie(&'a [u8]), Flat(&'a [u8]), TrieLog(&'a [u8]), } -impl<'a> From<&'a ChangeKeyType> for KeyType<'a> { - fn from(change_key: &'a ChangeKeyType) -> Self { - match change_key { - ChangeKeyType::Trie(key) => KeyType::Trie(key.as_slice()), - ChangeKeyType::Flat(key) => KeyType::Flat(key.as_slice()), - } - } -} - -impl KeyType<'_> { +impl DatabaseKey<'_> { pub fn as_slice(&self) -> &[u8] { match self { - KeyType::Trie(slice) => slice, - KeyType::Flat(slice) => slice, - KeyType::TrieLog(slice) => slice, + DatabaseKey::Trie(slice) => slice, + DatabaseKey::Flat(slice) => slice, + DatabaseKey::TrieLog(slice) => slice, } } } +#[cfg(feature = "std")] +pub trait DBError: Error + Send + Sync {} + +#[cfg(not(feature = "std"))] +pub trait DBError: Send + Sync {} + /// Trait to be implemented on any type that can be used as a database. pub trait BonsaiDatabase { type Batch: Default; #[cfg(feature = "std")] - type DatabaseError: std::error::Error + Into; + type DatabaseError: Error + DBError; #[cfg(not(feature = "std"))] - type DatabaseError: Into; + type DatabaseError: DBError; /// Create a new empty batch of changes to be used in `insert`, `remove` and applied in database using `write_batch`. fn create_batch(&self) -> Self::Batch; /// Returns the value of the key if it exists - fn get(&self, key: &KeyType) -> Result>, Self::DatabaseError>; + fn get(&self, key: &DatabaseKey) -> Result>, Self::DatabaseError>; #[allow(clippy::type_complexity)] /// Returns all values with keys that start with the given prefix fn get_by_prefix( &self, - prefix: &KeyType, + prefix: &DatabaseKey, ) -> Result, Vec)>, Self::DatabaseError>; /// Returns true if the key exists - fn contains(&self, key: &KeyType) -> Result; + fn contains(&self, key: &DatabaseKey) -> Result; /// Insert a new key-value pair, returns the old value if it existed. /// If a batch is provided, the change will be written in the batch instead of the database. fn insert( &mut self, - key: &KeyType, + key: &DatabaseKey, value: &[u8], batch: Option<&mut Self::Batch>, ) -> Result>, Self::DatabaseError>; @@ -65,12 +65,12 @@ pub trait BonsaiDatabase { /// If a batch is provided, the change will be written in the batch instead of the database. fn remove( &mut self, - key: &KeyType, + key: &DatabaseKey, batch: Option<&mut Self::Batch>, ) -> Result>, Self::DatabaseError>; /// Remove all keys that start with the given prefix - fn remove_by_prefix(&mut self, prefix: &KeyType) -> Result<(), Self::DatabaseError>; + fn remove_by_prefix(&mut self, prefix: &DatabaseKey) -> Result<(), Self::DatabaseError>; /// Write batch of changes directly in the database fn write_batch(&mut self, batch: Self::Batch) -> Result<(), Self::DatabaseError>; @@ -81,11 +81,11 @@ pub trait BonsaiDatabase { } pub trait BonsaiPersistentDatabase { + type Transaction: BonsaiDatabase; #[cfg(feature = "std")] - type DatabaseError: std::error::Error + Into; + type DatabaseError: Error + DBError; #[cfg(not(feature = "std"))] - type DatabaseError: Into; - type Transaction: BonsaiDatabase; + type DatabaseError: DBError; /// Save a snapshot of the current database state /// This function returns a snapshot id that can be used to create a transaction fn snapshot(&mut self, id: ID); diff --git a/src/changes.rs b/src/changes.rs index 12fac91..3d9a19c 100644 --- a/src/changes.rs +++ b/src/changes.rs @@ -1,4 +1,4 @@ -use crate::id::Id; +use crate::{id::Id, trie::TrieKey}; use serde::{Deserialize, Serialize}; #[cfg(feature = "std")] use std::collections::{hash_map::Entry, HashMap, VecDeque}; @@ -14,45 +14,15 @@ pub struct Change { pub new_value: Option>, } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum ChangeKeyType { - Trie(Vec), - Flat(Vec), -} - -impl ChangeKeyType { - pub fn get_id(&self) -> u8 { - match self { - ChangeKeyType::Trie(_) => 0, - ChangeKeyType::Flat(_) => 1, - } - } - - pub fn as_slice(&self) -> &[u8] { - match self { - ChangeKeyType::Trie(key) => key.as_slice(), - ChangeKeyType::Flat(key) => key.as_slice(), - } - } - - pub fn from_id(id: u8, key: Vec) -> Self { - match id { - 0 => ChangeKeyType::Trie(key), - 1 => ChangeKeyType::Flat(key), - _ => panic!("Invalid id"), - } - } -} - #[derive(Debug, Default)] -pub struct ChangeBatch(pub(crate) HashMap); +pub struct ChangeBatch(pub(crate) HashMap); const KEY_SEPARATOR: u8 = 0x00; const NEW_VALUE: u8 = 0x00; const OLD_VALUE: u8 = 0x01; impl ChangeBatch { - pub fn insert_in_place(&mut self, key: ChangeKeyType, change: Change) { + pub fn insert_in_place(&mut self, key: TrieKey, change: Change) { match self.0.entry(key) { Entry::Occupied(mut entry) => { let e = entry.get_mut(); @@ -68,7 +38,7 @@ impl ChangeBatch { } pub fn serialize(&self, id: &ID) -> Vec<(Vec, &[u8])> { - let id = id.serialize(); + let id = id.to_bytes(); self.0 .iter() .flat_map(|(change_key, change)| { @@ -76,11 +46,16 @@ impl ChangeBatch { let mut changes = Vec::new(); if let Some(old_value) = &change.old_value { + if let Some(new_value) = &change.new_value { + if old_value == new_value { + return changes; + } + } let key = [ id.as_slice(), &[KEY_SEPARATOR], key_slice, - &[change_key.get_id()], + &[change_key.into()], &[OLD_VALUE], ] .concat(); @@ -92,7 +67,7 @@ impl ChangeBatch { id.as_slice(), &[KEY_SEPARATOR], key_slice, - &[change_key.get_id()], + &[change_key.into()], &[NEW_VALUE], ] .concat(); @@ -104,7 +79,7 @@ impl ChangeBatch { } pub fn deserialize(id: &ID, changes: Vec<(Vec, Vec)>) -> Self { - let id = id.serialize(); + let id = id.to_bytes(); let mut change_batch = ChangeBatch(HashMap::new()); let mut current_change = Change::default(); let mut last_key = None; @@ -116,7 +91,8 @@ impl ChangeBatch { let mut key = key.to_vec(); let change_type = key.pop().unwrap(); let key_type = key.pop().unwrap(); - let change_key = ChangeKeyType::from_id(key_type, key[id.len() + 1..].to_vec()); + let change_key = + TrieKey::from_variant_and_bytes(key_type, key[id.len() + 1..].to_vec()); if let Some(last_key) = last_key { if last_key != change_key { change_batch.insert_in_place(last_key, current_change); diff --git a/src/databases/hashmap_db.rs b/src/databases/hashmap_db.rs index 74a2d13..f11914e 100644 --- a/src/databases/hashmap_db.rs +++ b/src/databases/hashmap_db.rs @@ -1,11 +1,10 @@ use crate::{ - bonsai_database::BonsaiPersistentDatabase, error::BonsaiStorageError, id::Id, BonsaiDatabase, + bonsai_database::{BonsaiPersistentDatabase, DBError}, + id::Id, + BonsaiDatabase, }; #[cfg(not(feature = "std"))] -use alloc::{ - vec::Vec, - {collections::BTreeMap, string::ToString}, -}; +use alloc::{collections::BTreeMap, vec::Vec}; use core::{fmt, fmt::Display}; #[cfg(not(feature = "std"))] use hashbrown::HashMap; @@ -24,11 +23,7 @@ impl Display for HashMapDbError { } } -impl From for BonsaiStorageError { - fn from(err: HashMapDbError) -> Self { - Self::Database(err.to_string()) - } -} +impl DBError for HashMapDbError {} #[derive(Clone, Default)] pub struct HashMapDb { @@ -44,7 +39,7 @@ impl BonsaiDatabase for HashMapDb { fn remove_by_prefix( &mut self, - prefix: &crate::bonsai_database::KeyType, + prefix: &crate::bonsai_database::DatabaseKey, ) -> Result<(), Self::DatabaseError> { let mut keys_to_remove = Vec::new(); for key in self.db.keys() { @@ -60,14 +55,14 @@ impl BonsaiDatabase for HashMapDb { fn get( &self, - key: &crate::bonsai_database::KeyType, + key: &crate::bonsai_database::DatabaseKey, ) -> Result>, Self::DatabaseError> { Ok(self.db.get(key.as_slice()).cloned()) } fn get_by_prefix( &self, - prefix: &crate::bonsai_database::KeyType, + prefix: &crate::bonsai_database::DatabaseKey, ) -> Result, Vec)>, Self::DatabaseError> { let mut result = Vec::new(); for (key, value) in self.db.iter() { @@ -80,7 +75,7 @@ impl BonsaiDatabase for HashMapDb { fn insert( &mut self, - key: &crate::bonsai_database::KeyType, + key: &crate::bonsai_database::DatabaseKey, value: &[u8], _batch: Option<&mut Self::Batch>, ) -> Result>, Self::DatabaseError> { @@ -89,13 +84,16 @@ impl BonsaiDatabase for HashMapDb { fn remove( &mut self, - key: &crate::bonsai_database::KeyType, + key: &crate::bonsai_database::DatabaseKey, _batch: Option<&mut Self::Batch>, ) -> Result>, Self::DatabaseError> { Ok(self.db.remove(key.as_slice())) } - fn contains(&self, key: &crate::bonsai_database::KeyType) -> Result { + fn contains( + &self, + key: &crate::bonsai_database::DatabaseKey, + ) -> Result { Ok(self.db.contains_key(key.as_slice())) } diff --git a/src/databases/rocks_db.rs b/src/databases/rocks_db.rs index 6819586..dce95d3 100644 --- a/src/databases/rocks_db.rs +++ b/src/databases/rocks_db.rs @@ -12,9 +12,8 @@ use rocksdb::{ }; use crate::{ - bonsai_database::{BonsaiDatabase, BonsaiPersistentDatabase, KeyType}, + bonsai_database::{BonsaiDatabase, BonsaiPersistentDatabase, DBError, DatabaseKey}, id::Id, - BonsaiStorageError, }; use log::trace; @@ -89,12 +88,6 @@ pub enum RocksDBError { Custom(String), } -impl From for BonsaiStorageError { - fn from(err: RocksDBError) -> Self { - Self::Database(err.to_string()) - } -} - impl From for RocksDBError { fn from(err: Error) -> Self { Self::RocksDB(err) @@ -110,6 +103,8 @@ impl fmt::Display for RocksDBError { } } +impl DBError for RocksDBError {} + impl StdError for RocksDBError { fn cause(&self) -> Option<&dyn StdError> { match self { @@ -126,12 +121,12 @@ impl StdError for RocksDBError { } } -impl KeyType<'_> { +impl DatabaseKey<'_> { fn get_cf(&self) -> &'static str { match self { - KeyType::Trie(_) => TRIE_CF, - KeyType::Flat(_) => FLAT_CF, - KeyType::TrieLog(_) => TRIE_LOG_CF, + DatabaseKey::Trie(_) => TRIE_CF, + DatabaseKey::Flat(_) => FLAT_CF, + DatabaseKey::TrieLog(_) => TRIE_LOG_CF, } } } @@ -186,7 +181,7 @@ where fn insert( &mut self, - key: &KeyType, + key: &DatabaseKey, value: &[u8], batch: Option<&mut Self::Batch>, ) -> Result>, Self::DatabaseError> { @@ -201,7 +196,7 @@ where Ok(old_value) } - fn get(&self, key: &KeyType) -> Result>, Self::DatabaseError> { + fn get(&self, key: &DatabaseKey) -> Result>, Self::DatabaseError> { trace!("Getting from RocksDB: {:?}", key); let handle = self.db.cf_handle(key.get_cf()).expect(CF_ERROR); Ok(self.db.get_cf(&handle, key.as_slice())?) @@ -209,7 +204,7 @@ where fn get_by_prefix( &self, - prefix: &KeyType, + prefix: &DatabaseKey, ) -> Result, Vec)>, Self::DatabaseError> { trace!("Getting from RocksDB: {:?}", prefix); let handle = self.db.cf_handle(prefix.get_cf()).expect(CF_ERROR); @@ -232,7 +227,7 @@ where .collect()) } - fn contains(&self, key: &KeyType) -> Result { + fn contains(&self, key: &DatabaseKey) -> Result { trace!("Checking if RocksDB contains: {:?}", key); let handle = self.db.cf_handle(key.get_cf()).expect(CF_ERROR); Ok(self @@ -243,7 +238,7 @@ where fn remove( &mut self, - key: &KeyType, + key: &DatabaseKey, batch: Option<&mut Self::Batch>, ) -> Result>, Self::DatabaseError> { trace!("Removing from RocksDB: {:?}", key); @@ -257,7 +252,7 @@ where Ok(old_value) } - fn remove_by_prefix(&mut self, prefix: &KeyType) -> Result<(), Self::DatabaseError> { + fn remove_by_prefix(&mut self, prefix: &DatabaseKey) -> Result<(), Self::DatabaseError> { trace!("Getting from RocksDB: {:?}", prefix); let handle = self.db.cf_handle(prefix.get_cf()).expect(CF_ERROR); let iter = self.db.iterator_cf( @@ -328,7 +323,7 @@ impl<'db> BonsaiDatabase for RocksDBTransaction<'db> { fn insert( &mut self, - key: &KeyType, + key: &DatabaseKey, value: &[u8], batch: Option<&mut Self::Batch>, ) -> Result>, Self::DatabaseError> { @@ -345,7 +340,7 @@ impl<'db> BonsaiDatabase for RocksDBTransaction<'db> { Ok(old_value) } - fn get(&self, key: &KeyType) -> Result>, Self::DatabaseError> { + fn get(&self, key: &DatabaseKey) -> Result>, Self::DatabaseError> { trace!("Getting from RocksDB: {:?}", key); let handle = self.column_families.get(key.get_cf()).expect(CF_ERROR); Ok(self @@ -355,7 +350,7 @@ impl<'db> BonsaiDatabase for RocksDBTransaction<'db> { fn get_by_prefix( &self, - prefix: &KeyType, + prefix: &DatabaseKey, ) -> Result, Vec)>, Self::DatabaseError> { trace!("Getting from RocksDB: {:?}", prefix); let handle = self.column_families.get(prefix.get_cf()).expect(CF_ERROR); @@ -378,7 +373,7 @@ impl<'db> BonsaiDatabase for RocksDBTransaction<'db> { .collect()) } - fn contains(&self, key: &KeyType) -> Result { + fn contains(&self, key: &DatabaseKey) -> Result { trace!("Checking if RocksDB contains: {:?}", key); let handle = self.column_families.get(key.get_cf()).expect(CF_ERROR); Ok(self @@ -389,7 +384,7 @@ impl<'db> BonsaiDatabase for RocksDBTransaction<'db> { fn remove( &mut self, - key: &KeyType, + key: &DatabaseKey, batch: Option<&mut Self::Batch>, ) -> Result>, Self::DatabaseError> { trace!("Removing from RocksDB: {:?}", key); @@ -405,7 +400,7 @@ impl<'db> BonsaiDatabase for RocksDBTransaction<'db> { Ok(old_value) } - fn remove_by_prefix(&mut self, prefix: &KeyType) -> Result<(), Self::DatabaseError> { + fn remove_by_prefix(&mut self, prefix: &DatabaseKey) -> Result<(), Self::DatabaseError> { trace!("Getting from RocksDB: {:?}", prefix); let mut batch = self.create_batch(); { diff --git a/src/error.rs b/src/error.rs index e49ddd1..92e8ff4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,8 +1,16 @@ +#[cfg(feature = "std")] +use std::{error::Error, fmt::Display}; + +use crate::bonsai_database::DBError; + #[cfg(not(feature = "std"))] use alloc::string::String; /// All errors that can be returned by BonsaiStorage. #[derive(Debug)] -pub enum BonsaiStorageError { +pub enum BonsaiStorageError +where + DatabaseError: DBError, +{ /// Error from the underlying trie. Trie(String), /// Error when trying to go to a specific commit ID. @@ -12,5 +20,40 @@ pub enum BonsaiStorageError { /// Error when trying to merge a transactional state. Merge(String), /// Error from the underlying database. - Database(String), + Database(DatabaseError), + /// Error when decoding a node + NodeDecodeError(parity_scale_codec::Error), +} + +impl core::convert::From + for BonsaiStorageError +{ + fn from(value: DatabaseError) -> Self { + Self::Database(value) + } +} + +impl core::convert::From + for BonsaiStorageError +{ + fn from(value: parity_scale_codec::Error) -> Self { + Self::NodeDecodeError(value) + } +} + +#[cfg(feature = "std")] +impl Display for BonsaiStorageError +where + DatabaseError: Error + DBError, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BonsaiStorageError::Trie(e) => write!(f, "Trie error: {}", e), + BonsaiStorageError::GoTo(e) => write!(f, "GoTo error: {}", e), + BonsaiStorageError::Transaction(e) => write!(f, "Transaction error: {}", e), + BonsaiStorageError::Merge(e) => write!(f, "Merge error: {}", e), + BonsaiStorageError::Database(e) => write!(f, "Database error: {}", e), + BonsaiStorageError::NodeDecodeError(e) => write!(f, "Node decode error: {}", e), + } + } } diff --git a/src/id.rs b/src/id.rs index 0495067..a768b2c 100644 --- a/src/id.rs +++ b/src/id.rs @@ -4,7 +4,7 @@ use core::{fmt::Debug, hash}; /// Trait to be implemented on any type that can be used as an ID. pub trait Id: hash::Hash + PartialEq + Eq + PartialOrd + Ord + Debug + Copy { - fn serialize(&self) -> Vec; + fn to_bytes(&self) -> Vec; } /// A basic ID type that can be used for testing. @@ -12,7 +12,7 @@ pub trait Id: hash::Hash + PartialEq + Eq + PartialOrd + Ord + Debug + Copy { pub struct BasicId(u64); impl Id for BasicId { - fn serialize(&self) -> Vec { + fn to_bytes(&self) -> Vec { self.0.to_be_bytes().to_vec() } } @@ -36,7 +36,8 @@ impl BasicIdBuilder { /// Create a new ID (unique). pub fn new_id(&mut self) -> BasicId { - self.last_id += 1; - BasicId(self.last_id) + let id = BasicId(self.last_id); + self.last_id = self.last_id.checked_add(1).expect("Id overflow"); + id } } diff --git a/src/key_value_db.rs b/src/key_value_db.rs index aad5dc9..7e786f4 100644 --- a/src/key_value_db.rs +++ b/src/key_value_db.rs @@ -5,10 +5,10 @@ use log::trace; use std::collections::BTreeSet; use crate::{ - bonsai_database::{BonsaiDatabase, BonsaiPersistentDatabase, KeyType}, + bonsai_database::{BonsaiDatabase, BonsaiPersistentDatabase, DatabaseKey}, changes::{Change, ChangeBatch, ChangeStore}, id::Id, - trie::TrieKeyType, + trie::TrieKey, BonsaiStorageConfig, BonsaiStorageError, }; @@ -70,7 +70,6 @@ impl KeyValueDB where DB: BonsaiDatabase, ID: Id, - BonsaiStorageError: core::convert::From<::DatabaseError>, { pub(crate) fn new(underline_db: DB, config: KeyValueDBConfig, created_at: Option) -> Self { let mut changes_store = ChangeStore::new(); @@ -87,7 +86,7 @@ where } } - pub(crate) fn commit(&mut self, id: ID) -> Result<(), BonsaiStorageError> { + pub(crate) fn commit(&mut self, id: ID) -> Result<(), BonsaiStorageError> { if Some(&id) > self.changes_store.id_queue.back() { self.changes_store.id_queue.push_back(id); } else { @@ -102,15 +101,16 @@ where let current_changes = core::mem::take(&mut self.changes_store.current_changes); for (key, change) in current_changes.serialize(&id).iter() { self.db - .insert(&KeyType::TrieLog(key), change, Some(&mut batch))?; + .insert(&DatabaseKey::TrieLog(key), change, Some(&mut batch))?; } self.db.write_batch(batch)?; if let Some(max_saved_trie_logs) = self.config.max_saved_trie_logs { while self.changes_store.id_queue.len() > max_saved_trie_logs { // verified by previous conditional statement - let id = self.changes_store.id_queue.pop_front().unwrap().serialize(); - self.db.remove_by_prefix(&KeyType::TrieLog(&id))?; + let id = self.changes_store.id_queue.pop_front().unwrap(); + self.db + .remove_by_prefix(&DatabaseKey::TrieLog(&id.to_bytes()))?; } } Ok(()) @@ -124,26 +124,32 @@ where self.config.clone() } - pub(crate) fn get(&self, key: &TrieKeyType) -> Result>, BonsaiStorageError> { + pub(crate) fn get( + &self, + key: &TrieKey, + ) -> Result>, BonsaiStorageError> { trace!("Getting from KeyValueDB: {:?}", key); Ok(self.db.get(&key.into())?) } - pub(crate) fn contains(&self, key: &TrieKeyType) -> Result { + pub(crate) fn contains( + &self, + key: &TrieKey, + ) -> Result> { trace!("Contains from KeyValueDB: {:?}", key); Ok(self.db.contains(&key.into())?) } pub(crate) fn insert( &mut self, - key: &TrieKeyType, + key: &TrieKey, value: &[u8], batch: Option<&mut DB::Batch>, - ) -> Result<(), BonsaiStorageError> { + ) -> Result<(), BonsaiStorageError> { trace!("Inserting into KeyValueDB: {:?} {:?}", key, value); let old_value = self.db.insert(&key.into(), value, batch)?; self.changes_store.current_changes.insert_in_place( - key.into(), + key.clone(), Change { old_value, new_value: Some(value.to_vec()), @@ -154,13 +160,13 @@ where pub(crate) fn remove( &mut self, - key: &TrieKeyType, + key: &TrieKey, batch: Option<&mut DB::Batch>, - ) -> Result<(), BonsaiStorageError> { + ) -> Result<(), BonsaiStorageError> { trace!("Removing from KeyValueDB: {:?}", key); let old_value = self.db.remove(&key.into(), batch)?; self.changes_store.current_changes.insert_in_place( - key.into(), + key.clone(), Change { old_value, new_value: None, @@ -169,7 +175,10 @@ where Ok(()) } - pub(crate) fn write_batch(&mut self, batch: DB::Batch) -> Result<(), BonsaiStorageError> { + pub(crate) fn write_batch( + &mut self, + batch: DB::Batch, + ) -> Result<(), BonsaiStorageError> { trace!("Writing batch into KeyValueDB"); Ok(self.db.write_batch(batch)?) } @@ -196,10 +205,10 @@ where pub(crate) fn get_transaction( &self, id: ID, - ) -> Result, BonsaiStorageError> - where - BonsaiStorageError: core::convert::From<::DatabaseError>, - { + ) -> Result< + Option, + BonsaiStorageError<::DatabaseError>, + > { let Some(change_id) = self.snap_holder.range(..=id).last() else { return Ok(None); }; @@ -224,7 +233,7 @@ where let changes = ChangeBatch::deserialize( id, self.db - .get_by_prefix(&KeyType::TrieLog(id.serialize().as_ref())) + .get_by_prefix(&DatabaseKey::TrieLog(&id.to_bytes())) .map_err(|_| { BonsaiStorageError::Transaction(format!( "database is missing trie logs for {:?}", @@ -233,7 +242,7 @@ where })?, ); for (key, change) in changes.0 { - let key = KeyType::from(&key); + let key = DatabaseKey::from(&key); match (&change.old_value, &change.new_value) { (Some(_), Some(new_value)) => { txn.insert(&key, new_value, Some(&mut batch))?; @@ -255,11 +264,7 @@ where pub(crate) fn merge( &mut self, transaction: KeyValueDB, - ) -> Result<(), BonsaiStorageError> - where - BonsaiStorageError: - core::convert::From<>::DatabaseError>, - { + ) -> Result<(), BonsaiStorageError<>::DatabaseError>> { let Some(created_at) = transaction.created_at else { return Err(BonsaiStorageError::Merge( "Transaction has no created_at".to_string(), diff --git a/src/lib.rs b/src/lib.rs index 8335803..3a8b09c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -90,7 +90,7 @@ use crate::trie::merkle_tree::MerkleTree; #[cfg(not(feature = "std"))] use alloc::{format, vec::Vec}; use bitvec::{order::Msb0, slice::BitSlice}; -use bonsai_database::{BonsaiPersistentDatabase, KeyType}; +use bonsai_database::{BonsaiPersistentDatabase, DatabaseKey}; use changes::ChangeBatch; use key_value_db::KeyValueDB; use starknet_types_core::{ @@ -163,11 +163,13 @@ impl BonsaiStorage where DB: BonsaiDatabase, ChangeID: id::Id, - BonsaiStorageError: core::convert::From<::DatabaseError>, H: StarkHash, { /// Create a new bonsai storage instance - pub fn new(db: DB, config: BonsaiStorageConfig) -> Result { + pub fn new( + db: DB, + config: BonsaiStorageConfig, + ) -> Result> { let key_value_db = KeyValueDB::new(db, config.into(), None); Ok(Self { trie: MerkleTree::new(key_value_db)?, @@ -178,7 +180,7 @@ where db: DB, config: BonsaiStorageConfig, created_at: ChangeID, - ) -> Result { + ) -> Result> { let key_value_db = KeyValueDB::new(db, config.into(), Some(created_at)); Ok(Self { trie: MerkleTree::new(key_value_db)?, @@ -191,32 +193,44 @@ where &mut self, key: &BitSlice, value: &Felt, - ) -> Result<(), BonsaiStorageError> { + ) -> Result<(), BonsaiStorageError> { self.trie.set(key, *value)?; Ok(()) } /// Remove a key/value in the trie /// If the value doesn't exist it will do nothing - pub fn remove(&mut self, key: &BitSlice) -> Result<(), BonsaiStorageError> { + pub fn remove( + &mut self, + key: &BitSlice, + ) -> Result<(), BonsaiStorageError> { self.trie.set(key, Felt::ZERO)?; Ok(()) } /// Get a value in the trie. - pub fn get(&self, key: &BitSlice) -> Result, BonsaiStorageError> { + pub fn get( + &self, + key: &BitSlice, + ) -> Result, BonsaiStorageError> { self.trie.get(key) } /// Checks if the key exists in the trie. - pub fn contains(&self, key: &BitSlice) -> Result { + pub fn contains( + &self, + key: &BitSlice, + ) -> Result> { self.trie.contains(key) } /// Go to a specific commit ID. /// If insert/remove is called between the last `commit()` and a call to this function, /// the in-memory changes will be discarded. - pub fn revert_to(&mut self, requested_id: ChangeID) -> Result<(), BonsaiStorageError> { + pub fn revert_to( + &mut self, + requested_id: ChangeID, + ) -> Result<(), BonsaiStorageError> { let kv = self.trie.db_mut(); // Clear current changes @@ -246,7 +260,7 @@ where full.extend( ChangeBatch::deserialize( id, - kv.db.get_by_prefix(&KeyType::TrieLog(&id.serialize()))?, + kv.db.get_by_prefix(&DatabaseKey::TrieLog(&id.to_bytes()))?, ) .0, ); @@ -255,7 +269,7 @@ where // Revert changes let mut batch = kv.db.create_batch(); for (key, change) in full.iter().rev() { - let key = KeyType::from(key); + let key = DatabaseKey::from(key); match (&change.old_value, &change.new_value) { (Some(old_value), Some(_)) => { kv.db.insert(&key, old_value, Some(&mut batch))?; @@ -276,12 +290,13 @@ where kv.changes_store.id_queue.push_back(current); } for id in truncated.iter() { - kv.db.remove_by_prefix(&KeyType::TrieLog(&id.serialize()))?; + kv.db + .remove_by_prefix(&DatabaseKey::TrieLog(&id.to_bytes()))?; } // Write revert changes and trie logs truncation kv.db.write_batch(batch)?; - self.trie.reset_root_from_db()?; + self.trie.reset_to_last_commit()?; Ok(()) } @@ -291,13 +306,16 @@ where } /// Get trie root hash at the latest commit - pub fn root_hash(&self) -> Result { + pub fn root_hash(&self) -> Result> { Ok(self.trie.root_hash()) } /// This function must be used with transactional state only. /// Similar to `commit` but without optimizations. - pub fn transactional_commit(&mut self, id: ChangeID) -> Result<(), BonsaiStorageError> { + pub fn transactional_commit( + &mut self, + id: ChangeID, + ) -> Result<(), BonsaiStorageError> { self.trie.commit()?; self.trie.db_mut().commit(id)?; Ok(()) @@ -317,7 +335,7 @@ where pub fn get_proof( &self, key: &BitSlice, - ) -> Result, BonsaiStorageError> { + ) -> Result, BonsaiStorageError> { self.trie.get_proof(key) } @@ -336,17 +354,20 @@ impl BonsaiStorage where DB: BonsaiDatabase + BonsaiPersistentDatabase, ChangeID: id::Id, - BonsaiStorageError: core::convert::From<::DatabaseError>, H: StarkHash, { /// Update trie and database using all changes since the last commit. - pub fn commit(&mut self, id: ChangeID) -> Result<(), BonsaiStorageError> { + pub fn commit( + &mut self, + id: ChangeID, + ) -> Result<(), BonsaiStorageError<::DatabaseError>> { self.trie.commit()?; self.trie.db_mut().commit(id)?; self.trie.db_mut().create_snapshot(id); Ok(()) } + #[allow(clippy::type_complexity)] /// Get a transactional state of the trie at a specific commit ID. /// /// Transactional state allow you to fetch a point-in-time state of the trie. You can @@ -355,10 +376,10 @@ where &self, change_id: ChangeID, config: BonsaiStorageConfig, - ) -> Result>, BonsaiStorageError> - where - BonsaiStorageError: core::convert::From<::DatabaseError>, - { + ) -> Result< + Option>, + BonsaiStorageError<::DatabaseError>, + > { if let Some(transaction) = self.trie.db_ref().get_transaction(change_id)? { Ok(Some(BonsaiStorage::new_from_transactional_state( transaction, @@ -379,10 +400,7 @@ where pub fn merge( &mut self, transactional_bonsai_storage: BonsaiStorage, - ) -> Result<(), BonsaiStorageError> - where - BonsaiStorageError: - core::convert::From<>::DatabaseError>, + ) -> Result<(), BonsaiStorageError<>::DatabaseError>> { self.trie .db_mut() diff --git a/src/tests/madara_comparison.rs b/src/tests/madara_comparison.rs index 6a525f5..48b1a82 100644 --- a/src/tests/madara_comparison.rs +++ b/src/tests/madara_comparison.rs @@ -24,8 +24,7 @@ fn trie_height_251() { let mut id_builder = BasicIdBuilder::new(); let id = id_builder.new_id(); bonsai_storage.commit(id).unwrap(); - let root_hash = bonsai_storage.root_hash().unwrap(); - println!("root_hash: {:?}", root_hash); + bonsai_storage.root_hash().unwrap(); } // Test to add on Madara side to check with a tree of height 251 and see that we have same hash // #[test]// fn test_height_251() { diff --git a/src/tests/simple.rs b/src/tests/simple.rs index a67902f..9ae3eb6 100644 --- a/src/tests/simple.rs +++ b/src/tests/simple.rs @@ -35,19 +35,7 @@ fn basics() { ); let bitvec = BitVec::from_vec(pair3.0.clone()); bonsai_storage.insert(&bitvec, &pair3.1).unwrap(); - println!( - "get: {:?}", - bonsai_storage.get(&BitVec::from_vec(vec![1, 2, 1])) - ); bonsai_storage.commit(id_builder.new_id()).unwrap(); - println!( - "get: {:?}", - bonsai_storage.get(&BitVec::from_vec(vec![1, 2, 2])) - ); - println!( - "get: {:?}", - bonsai_storage.get(&BitVec::from_vec(vec![1, 2, 3])) - ); let bitvec = BitVec::from_vec(vec![1, 2, 1]); bonsai_storage.remove(&bitvec).unwrap(); assert_eq!( diff --git a/src/tests/trie_log.rs b/src/tests/trie_log.rs index 3145425..56c6ecc 100644 --- a/src/tests/trie_log.rs +++ b/src/tests/trie_log.rs @@ -184,9 +184,7 @@ fn remove_and_reinsert() { bonsai_storage.remove(&bitvec).unwrap(); bonsai_storage.commit(id1).unwrap(); let root_hash1 = bonsai_storage.root_hash().unwrap(); - let id2 = id_builder.new_id(); - println!("before second insert"); bonsai_storage.insert(&bitvec, &pair1.1).unwrap(); bonsai_storage.commit(id2).unwrap(); diff --git a/src/trie/merkle_node.rs b/src/trie/merkle_node.rs index 4905a94..fa02128 100644 --- a/src/trie/merkle_node.rs +++ b/src/trie/merkle_node.rs @@ -9,7 +9,7 @@ use bitvec::slice::BitSlice; use parity_scale_codec::{Decode, Encode}; use starknet_types_core::felt::Felt; -use super::merkle_tree::Path; +use super::path::Path; /// Id of a Node within the tree #[derive(Copy, Clone, Debug, PartialEq, Eq, Default, PartialOrd, Ord, Hash, Encode, Decode)] @@ -18,7 +18,7 @@ pub struct NodeId(pub u64); impl NodeId { /// Mutates the given NodeId to be the next one and returns it. pub fn next_id(&mut self) -> NodeId { - self.0 += 1; + self.0 = self.0.checked_add(1).expect("Node id overflow"); NodeId(self.0) } @@ -49,7 +49,8 @@ pub enum NodeHandle { /// Describes the [Node::Binary] variant. #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Encode, Decode)] pub struct BinaryNode { - /// The hash of this node. + /// The hash of this node. Is [None] if the node + /// has not yet been committed. pub hash: Option, /// The height of this node in the tree. pub height: u64, @@ -217,3 +218,151 @@ impl EdgeNode { &self.path.0[..common_length] } } + +#[test] +fn test_path_matches_basic() { + let path = Path( + BitSlice::::from_slice(&[0b10101010, 0b01010101, 0b10101010, 0b01010101]) + .to_bitvec(), + ); + let edge = EdgeNode { + hash: None, + height: 0, + path, + child: NodeHandle::Hash(Felt::ZERO), + }; + + let key = BitSlice::::from_slice(&[0b10101010, 0b01010101, 0b10101010, 0b01010101]); + assert!(edge.path_matches(key)); +} + +#[test] +fn test_path_matches_with_height() { + let path = Path( + BitSlice::::from_slice(&[0b10101010, 0b01010101, 0b10101010, 0b01010101]) + .to_bitvec(), + ); + let edge = EdgeNode { + hash: None, + height: 8, + path, + child: NodeHandle::Hash(Felt::ZERO), + }; + + let key = BitSlice::::from_slice(&[ + 0b10101010, 0b10101010, 0b01010101, 0b10101010, 0b01010101, + ]); + assert!(edge.path_matches(key)); +} + +#[test] +fn test_path_matches_only_part_with_height() { + let path = Path( + BitSlice::::from_slice(&[0b10101010, 0b01010101, 0b10101010, 0b01010101]) + .to_bitvec(), + ); + let edge = EdgeNode { + hash: None, + height: 8, + path, + child: NodeHandle::Hash(Felt::ZERO), + }; + + let key = BitSlice::::from_slice(&[ + 0b10101010, 0b10101010, 0b01010101, 0b10101010, 0b01010101, 0b10101010, + ]); + assert!(edge.path_matches(key)); +} + +#[test] +fn test_path_dont_match() { + let path = Path( + BitSlice::::from_slice(&[0b10111010, 0b01010101, 0b10101010, 0b01010101]) + .to_bitvec(), + ); + let edge = EdgeNode { + hash: None, + height: 0, + path, + child: NodeHandle::Hash(Felt::ZERO), + }; + + let key = BitSlice::::from_slice(&[ + 0b10101010, 0b01010101, 0b10101010, 0b01010101, 0b10101010, + ]); + assert!(!edge.path_matches(key)); +} + +#[test] +fn test_common_path_basic() { + let path = Path( + BitSlice::::from_slice(&[0b10101010, 0b01010101, 0b10101010, 0b01010101]) + .to_bitvec(), + ); + let edge = EdgeNode { + hash: None, + height: 0, + path: path.clone(), + child: NodeHandle::Hash(Felt::ZERO), + }; + + let key = BitSlice::::from_slice(&[0b10101010, 0b01010101, 0b10101010, 0b01010101]); + assert_eq!(edge.common_path(key), &path.0); +} + +#[test] +fn test_common_path_only_part() { + let path = Path( + BitSlice::::from_slice(&[0b10101010, 0b01010101, 0b10101010, 0b01010101]) + .to_bitvec(), + ); + let edge = EdgeNode { + hash: None, + height: 0, + path, + child: NodeHandle::Hash(Felt::ZERO), + }; + + let key = BitSlice::::from_slice(&[0b10101010, 0b01010101]); + assert_eq!( + edge.common_path(key), + BitSlice::::from_slice(&[0b10101010, 0b01010101]) + ); +} + +#[test] +fn test_common_path_part_with_height() { + let path = Path( + BitSlice::::from_slice(&[0b10101010, 0b01010101, 0b10101010, 0b01010101]) + .to_bitvec(), + ); + let edge = EdgeNode { + hash: None, + height: 8, + path, + child: NodeHandle::Hash(Felt::ZERO), + }; + + let key = BitSlice::::from_slice(&[0b01010101, 0b10101010]); + assert_eq!( + edge.common_path(key), + BitSlice::::from_slice(&[0b10101010]) + ); +} + +#[test] +fn test_no_common_path() { + let path = Path( + BitSlice::::from_slice(&[0b10101010, 0b01010101, 0b10101010, 0b01010101]) + .to_bitvec(), + ); + let edge = EdgeNode { + hash: None, + height: 0, + path, + child: NodeHandle::Hash(Felt::ZERO), + }; + + let key = BitSlice::::from_slice(&[0b01010101, 0b10101010]); + assert_eq!(edge.common_path(key), BitSlice::::empty()); +} diff --git a/src/trie/merkle_tree.rs b/src/trie/merkle_tree.rs index df2debe..6219cd7 100644 --- a/src/trie/merkle_tree.rs +++ b/src/trie/merkle_tree.rs @@ -10,7 +10,7 @@ use core::mem; use derive_more::Constructor; #[cfg(not(feature = "std"))] use hashbrown::HashMap; -use parity_scale_codec::{Decode, Encode, Error, Input, Output}; +use parity_scale_codec::{Decode, Encode}; use starknet_types_core::{felt::Felt, hash::StarkHash}; #[cfg(feature = "std")] use std::collections::HashMap; @@ -19,7 +19,8 @@ use crate::{error::BonsaiStorageError, id::Id, BonsaiDatabase, KeyValueDB}; use super::{ merkle_node::{BinaryNode, Direction, EdgeNode, Node, NodeHandle, NodeId}, - TrieKeyType, + path::Path, + TrieKey, }; #[cfg(test)] @@ -34,97 +35,11 @@ pub enum Membership { /// Wrapper type for a [HashMap] object. (It's not really a wrapper it's a /// copy of the type but we implement the necessary traits.) #[derive(Clone, Debug, PartialEq, Eq, Default, Constructor)] -pub struct NodesMapping(pub HashMap); - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Path(pub BitVec); - -impl Encode for Path { - fn encode_to(&self, dest: &mut T) { - // Inspired from scale_bits crate but don't use it to avoid copy and u32 length encoding - let iter = self.0.iter(); - let len = iter.len(); - // SAFETY: len is <= 251 - dest.push_byte(len as u8); - let mut next_store: u8 = 0; - let mut pos_in_next_store: u8 = 7; - for b in iter { - let bit = match *b { - true => 1, - false => 0, - }; - next_store |= bit << pos_in_next_store; - - if pos_in_next_store == 0 { - pos_in_next_store = 8; - dest.push_byte(next_store); - next_store = 0; - } - pos_in_next_store -= 1; - } - - if pos_in_next_store < 7 { - dest.push_byte(next_store); - } - } -} +pub struct NodesMapping(HashMap); -impl Decode for Path { - fn decode(input: &mut I) -> Result { - // Inspired from scale_bits crate but don't use it to avoid copy and u32 length encoding - // SAFETY: len is <= 251 - let len: u8 = input.read_byte()?; - let mut remaining_bits = len as usize; - let mut current_byte = None; - let mut bit = 7; - let mut bits = BitVec::::new(); - // No bits left to decode; we're done. - while remaining_bits != 0 { - // Get the next store entry to pull from: - let store = match current_byte { - Some(store) => store, - None => { - let store = match input.read_byte() { - Ok(s) => s, - Err(e) => return Err(e), - }; - current_byte = Some(store); - store - } - }; - - // Extract a bit: - let res = match (store >> bit) & 1 { - 0 => false, - 1 => true, - _ => unreachable!("Can only be 0 or 1 owing to &1"), - }; - bits.push(res); - - // Update records for next bit: - remaining_bits -= 1; - if bit == 0 { - current_byte = None; - bit = 8; - } - bit -= 1; - } - Ok(Self(bits)) - } -} - -#[test] -fn test_shared_path_encode_decode() { - let path = Path(BitVec::::from_slice(&[0b10101010, 0b10101010])); - let mut encoded = Vec::new(); - path.encode_to(&mut encoded); - - let decoded = Path::decode(&mut &encoded[..]).unwrap(); - assert_eq!(path, decoded); -} /// A node used in proof generated by the trie. /// -/// Each node hold only the minimum of data that need to be known for the proof: the child hashes (and path for edge node) +/// See pathfinders merkle-tree crate for more information. #[derive(Debug, Clone, PartialEq)] pub enum ProofNode { Binary { left: Felt, right: Felt }, @@ -155,17 +70,27 @@ impl ProofNode { /// /// For more information on how this functions internally, see [here](super::merkle_node). pub struct MerkleTree { + /// The handle to the current root node could be hash if no modifications has been done + /// since the last commit or in memory if there are some modifications. root_handle: NodeHandle, + /// The last known root hash. Updated only each commit. (possibly outdated between two commits) root_hash: Felt, + /// Temporary storage used to store the nodes that are modified during a commit. + /// This storage is used to avoid modifying the underlying database each time during a commit. storage_nodes: NodesMapping, + /// The underlying database used to store the nodes. db: KeyValueDB, + /// The id of the last node that has been added to the temporary storage. latest_node_id: NodeId, - death_row: Vec, + /// The list of nodes that should be removed from the underlying database during the next commit. + death_row: Vec, + /// The list of leaves that have been modified during the current commit. cache_leaf_modified: HashMap, InsertOrRemove>, + /// The hasher used to hash the nodes. _hasher: PhantomData, } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq)] enum InsertOrRemove { Insert(T), Remove, @@ -175,27 +100,21 @@ impl MerkleTree { /// Less visible initialization for `MerkleTree` as the main entry points should be /// [`MerkleTree::::load`] for persistent trees and [`MerkleTree::empty`] for /// transient ones. - pub fn new(mut db: KeyValueDB) -> Result - where - BonsaiStorageError: core::convert::From<::DatabaseError>, - { + pub fn new(mut db: KeyValueDB) -> Result> { let nodes_mapping: HashMap = HashMap::new(); - let root_node = db.get(&TrieKeyType::Trie(vec![]))?; + let root_node = db.get(&TrieKey::Trie(vec![]))?; let node = if let Some(root_node) = root_node { - Node::decode(&mut root_node.as_slice()).map_err(|err| { - BonsaiStorageError::Trie(format!("Couldn't decode root node: {}", err)) - })? + Node::decode(&mut root_node.as_slice())? } else { db.insert( - &TrieKeyType::Trie(vec![]), + &TrieKey::Trie(vec![]), &Node::Unresolved(Felt::ZERO).encode(), None, )?; Node::Unresolved(Felt::ZERO) }; - let root = node.hash().ok_or(BonsaiStorageError::Trie( - "Root doesn't exist in the storage".to_string(), - ))?; + // SAFETY: The root node has been created just above + let root = node.hash().unwrap(); Ok(Self { root_handle: NodeHandle::Hash(root), root_hash: root, @@ -212,12 +131,10 @@ impl MerkleTree { self.root_hash } - pub fn reset_root_from_db(&mut self) -> Result<(), BonsaiStorageError> - where - BonsaiStorageError: core::convert::From<::DatabaseError>, - { + /// Remove all the modifications that have been done since the last commit. + pub fn reset_to_last_commit(&mut self) -> Result<(), BonsaiStorageError> { let node = self - .get_tree_branch_in_db_from_path(&BitVec::::new())? + .get_trie_branch_in_db_from_path(&Path(BitVec::::new()))? .ok_or(BonsaiStorageError::Trie( "root node doesn't exist in the storage".to_string(), ))?; @@ -233,23 +150,20 @@ impl MerkleTree { } /// Persists all changes to storage and returns the new root hash. - pub fn commit(&mut self) -> Result - where - BonsaiStorageError: core::convert::From<::DatabaseError>, - { + pub fn commit(&mut self) -> Result> { let mut batch = self.db.create_batch(); for node_key in mem::take(&mut self.death_row) { self.db.remove(&node_key, Some(&mut batch))?; } - let root_hash = self.commit_subtree(self.root_handle, BitVec::new(), &mut batch)?; + let root_hash = self.commit_subtree(self.root_handle, Path(BitVec::new()), &mut batch)?; for (key, value) in mem::take(&mut self.cache_leaf_modified) { match value { InsertOrRemove::Insert(value) => { self.db - .insert(&TrieKeyType::Flat(key), &value.encode(), Some(&mut batch))?; + .insert(&TrieKey::Flat(key), &value.encode(), Some(&mut batch))?; } InsertOrRemove::Remove => { - self.db.remove(&TrieKeyType::Flat(key), Some(&mut batch))?; + self.db.remove(&TrieKey::Flat(key), Some(&mut batch))?; } } } @@ -274,12 +188,9 @@ impl MerkleTree { fn commit_subtree( &mut self, node_handle: NodeHandle, - path: BitVec, + path: Path, batch: &mut DB::Batch, - ) -> Result - where - BonsaiStorageError: core::convert::From<::DatabaseError>, - { + ) -> Result> { use Node::*; let node_id = match node_handle { NodeHandle::Hash(hash) => return Ok(hash), @@ -294,9 +205,9 @@ impl MerkleTree { "Couldn't fetch node in the temporary storage".to_string(), ))? { Unresolved(hash) => { - if path.is_empty() { + if path.0.is_empty() { self.db.insert( - &TrieKeyType::Trie(vec![]), + &TrieKey::Trie(vec![]), &Node::Unresolved(hash).encode(), Some(batch), )?; @@ -306,19 +217,17 @@ impl MerkleTree { } } Binary(mut binary) => { - let mut left_path = path.clone(); - left_path.push(false); + let left_path = path.new_with_direction(Direction::Left); let left_hash = self.commit_subtree(binary.left, left_path, batch)?; - let mut right_path = path.clone(); - right_path.push(true); + let right_path = path.new_with_direction(Direction::Right); let right_hash = self.commit_subtree(binary.right, right_path, batch)?; let hash = H::hash(&left_hash, &right_hash); binary.hash = Some(hash); binary.left = NodeHandle::Hash(left_hash); binary.right = NodeHandle::Hash(right_hash); - let key_bytes = [&[path.len() as u8], path.as_raw_slice()].concat(); + let key_bytes = [&[path.0.len() as u8], path.0.as_raw_slice()].concat(); self.db.insert( - &TrieKeyType::Trie(key_bytes), + &TrieKey::Trie(key_bytes), &Node::Binary(binary).encode(), Some(batch), )?; @@ -327,7 +236,7 @@ impl MerkleTree { Edge(mut edge) => { let mut child_path = path.clone(); - child_path.extend(&edge.path.0); + child_path.0.extend(&edge.path.0); let child_hash = self.commit_subtree(edge.child, child_path, batch)?; let mut bytes = [0u8; 32]; bytes.view_bits_mut::()[256 - edge.path.0.len()..] @@ -342,13 +251,13 @@ impl MerkleTree { let hash = H::hash(&child_hash, &felt_path) + length; edge.hash = Some(hash); edge.child = NodeHandle::Hash(child_hash); - let key_bytes = if path.is_empty() { + let key_bytes = if path.0.is_empty() { vec![] } else { - [&[path.len() as u8], path.as_raw_slice()].concat() + [&[path.0.len() as u8], path.0.as_raw_slice()].concat() }; self.db.insert( - &TrieKeyType::Trie(key_bytes), + &TrieKey::Trie(key_bytes), &Node::Edge(edge).encode(), Some(batch), )?; @@ -363,10 +272,11 @@ impl MerkleTree { /// /// * `key` - The key to set. /// * `value` - The value to set. - pub fn set(&mut self, key: &BitSlice, value: Felt) -> Result<(), BonsaiStorageError> - where - BonsaiStorageError: core::convert::From<::DatabaseError>, - { + pub fn set( + &mut self, + key: &BitSlice, + value: Felt, + ) -> Result<(), BonsaiStorageError> { if value == Felt::ZERO { return self.delete_leaf(key); } @@ -390,102 +300,94 @@ impl MerkleTree { use Node::*; match path.last() { Some(node_id) => { - let mut nodes_to_add = Vec::with_capacity(4); + let mut nodes_to_add = Vec::new(); self.storage_nodes.0.entry(*node_id).and_modify(|node| { - match node { - Edge(edge) => { - let common = edge.common_path(key); - // Height of the binary node - let branch_height = edge.height as usize + common.len(); - if branch_height == key.len() { - edge.child = NodeHandle::Hash(value); - // The leaf already exists, we simply change its value. - let key_bytes = - [&[key.len() as u8], key.to_bitvec().as_raw_slice()].concat(); - self.cache_leaf_modified - .insert(key_bytes, InsertOrRemove::Insert(value)); - return; - } - // Height of the binary node's children - let child_height = branch_height + 1; - - // Path from binary node to new leaf - let new_path = key[child_height..].to_bitvec(); - // Path from binary node to existing child - let old_path = edge.path.0[common.len() + 1..].to_bitvec(); - - // The new leaf branch of the binary node. - // (this may be edge -> leaf, or just leaf depending). - let key_bytes = - [&[key.len() as u8], key.to_bitvec().as_raw_slice()].concat(); + if let Edge(edge) = node { + let common = edge.common_path(key); + // Height of the binary node + let branch_height = edge.height as usize + common.len(); + if branch_height == key.len() { + edge.child = NodeHandle::Hash(value); + // The leaf already exists, we simply change its value. + let key_bytes = bitslice_to_bytes(key); self.cache_leaf_modified .insert(key_bytes, InsertOrRemove::Insert(value)); + return; + } + // Height of the binary node's children + let child_height = branch_height + 1; + + // Path from binary node to new leaf + let new_path = key[child_height..].to_bitvec(); + // Path from binary node to existing child + let old_path = edge.path.0[common.len() + 1..].to_bitvec(); + + // The new leaf branch of the binary node. + // (this may be edge -> leaf, or just leaf depending). + let key_bytes = bitslice_to_bytes(key); + self.cache_leaf_modified + .insert(key_bytes, InsertOrRemove::Insert(value)); + + let new = if new_path.is_empty() { + NodeHandle::Hash(value) + } else { + let new_edge = Node::Edge(EdgeNode { + hash: None, + height: child_height as u64, + path: Path(new_path), + child: NodeHandle::Hash(value), + }); + let edge_id = self.latest_node_id.next_id(); + nodes_to_add.push((edge_id, new_edge)); + NodeHandle::InMemory(edge_id) + }; - let new = if new_path.is_empty() { - NodeHandle::Hash(value) - } else { - let new_edge = Node::Edge(EdgeNode { - hash: None, - height: child_height as u64, - path: Path(new_path), - child: NodeHandle::Hash(value), - }); - let edge_id = self.latest_node_id.next_id(); - nodes_to_add.push((edge_id, new_edge)); - NodeHandle::InMemory(edge_id) - }; - - // The existing child branch of the binary node. - let old = if old_path.is_empty() { - edge.child - } else { - let old_edge = Node::Edge(EdgeNode { - hash: None, - height: child_height as u64, - path: Path(old_path), - child: edge.child, - }); - let edge_id = self.latest_node_id.next_id(); - nodes_to_add.push((edge_id, old_edge)); - NodeHandle::InMemory(edge_id) - }; - - let new_direction = Direction::from(key[branch_height]); - let (left, right) = match new_direction { - Direction::Left => (new, old), - Direction::Right => (old, new), - }; - - let branch = Node::Binary(BinaryNode { + // The existing child branch of the binary node. + let old = if old_path.is_empty() { + edge.child + } else { + let old_edge = Node::Edge(EdgeNode { hash: None, - height: branch_height as u64, - left, - right, + height: child_height as u64, + path: Path(old_path), + child: edge.child, }); + let edge_id = self.latest_node_id.next_id(); + nodes_to_add.push((edge_id, old_edge)); + NodeHandle::InMemory(edge_id) + }; - // We may require an edge leading to the binary node. - let new_node = if common.is_empty() { - branch - } else { - let branch_id = self.latest_node_id.next_id(); - nodes_to_add.push((branch_id, branch)); - - Node::Edge(EdgeNode { - hash: None, - height: edge.height, - path: Path(common.to_bitvec()), - child: NodeHandle::InMemory(branch_id), - }) - }; - let path = key[..edge.height as usize].to_bitvec(); - let key_bytes = - [&[path.len() as u8], path.into_vec().as_slice()].concat(); - self.death_row.push(TrieKeyType::Trie(key_bytes)); - *node = new_node; - } - Unresolved(_) | Binary(_) => { - unreachable!("The end of a traversion cannot be unresolved or binary") - } + let new_direction = Direction::from(key[branch_height]); + let (left, right) = match new_direction { + Direction::Left => (new, old), + Direction::Right => (old, new), + }; + + let branch = Node::Binary(BinaryNode { + hash: None, + height: branch_height as u64, + left, + right, + }); + + // We may require an edge leading to the binary node. + let new_node = if common.is_empty() { + branch + } else { + let branch_id = self.latest_node_id.next_id(); + nodes_to_add.push((branch_id, branch)); + + Node::Edge(EdgeNode { + hash: None, + height: edge.height, + path: Path(common.to_bitvec()), + child: NodeHandle::InMemory(branch_id), + }) + }; + let path = key[..edge.height as usize].to_bitvec(); + let key_bytes = [&[path.len() as u8], path.into_vec().as_slice()].concat(); + self.death_row.push(TrieKey::Trie(key_bytes)); + *node = new_node; }; }); for (id, node) in nodes_to_add { @@ -510,7 +412,7 @@ impl MerkleTree { self.root_handle = NodeHandle::InMemory(self.latest_node_id); - let key_bytes = [&[key.len() as u8], key.to_bitvec().as_raw_slice()].concat(); + let key_bytes = bitslice_to_bytes(key); self.cache_leaf_modified .insert(key_bytes, InsertOrRemove::Insert(value)); Ok(()) @@ -538,10 +440,10 @@ impl MerkleTree { /// # Arguments /// /// * `key` - The key to delete. - fn delete_leaf(&mut self, key: &BitSlice) -> Result<(), BonsaiStorageError> - where - BonsaiStorageError: core::convert::From<::DatabaseError>, - { + fn delete_leaf( + &mut self, + key: &BitSlice, + ) -> Result<(), BonsaiStorageError> { // Algorithm explanation: // // The leaf's parent node is either an edge, or a binary node. @@ -556,16 +458,13 @@ impl MerkleTree { // // Then we are done. - let key_bytes = [&[key.len() as u8], key.to_bitvec().as_raw_slice()].concat(); + let key_bytes = bitslice_to_bytes(key); self.cache_leaf_modified .insert(key_bytes.clone(), InsertOrRemove::Remove); - if !self.db.contains(&TrieKeyType::Flat(key_bytes))? { - return Ok(()); - } let path = self.preload_nodes(key)?; - let mut last_binary_path = key.to_bitvec(); + let mut last_binary_path = Path(key.to_bitvec()); // Go backwards until we hit a branch node. let mut node_iter = path.into_iter().rev().skip_while(|node| { @@ -576,18 +475,9 @@ impl MerkleTree { Node::Binary(_) => {} Node::Edge(edge) => { for _ in 0..edge.path.0.len() { - last_binary_path.pop(); - } - let key_bytes = [ - &[last_binary_path.len() as u8], - last_binary_path.as_raw_slice(), - ] - .concat(); - if last_binary_path.is_empty() { - self.death_row.push(TrieKeyType::Trie(vec![])); - } else { - self.death_row.push(TrieKeyType::Trie(key_bytes)); + last_binary_path.0.pop(); } + self.death_row.push((&last_binary_path).into()); } } !node.is_binary() @@ -609,7 +499,7 @@ impl MerkleTree { // Create an edge node to replace the old binary node // i.e. with the remaining child (note the direction invert), // and a path of just a single bit. - last_binary_path.push(direction.into()); + last_binary_path.0.push(direction.into()); let path = Path(once(bool::from(direction)).collect::>()); let mut edge = EdgeNode { hash: None, @@ -640,47 +530,45 @@ impl MerkleTree { }; // Check the parent of the new edge. If it is also an edge, then they must merge. - if let Some(node) = parent_branch_node { - let child = if let Node::Edge(edge) = - self.storage_nodes - .0 - .get(&node) - .ok_or(BonsaiStorageError::Trie( - "Node not found in memory".to_string(), - ))? { - let child_node = match edge.child { - NodeHandle::Hash(_) => return Ok(()), - NodeHandle::InMemory(child_id) => { - self.storage_nodes - .0 - .get(&child_id) - .ok_or(BonsaiStorageError::Trie( - "Node not found in memory".to_string(), - ))? + if let Some(node_id) = parent_branch_node { + let node = self + .storage_nodes + .0 + .get(&node_id) + .ok_or(BonsaiStorageError::Trie( + "Node not found in memory".to_string(), + ))?; + // If it's an edge node and the child is in memory and it's an edge too we + // return the child otherwise we leave + let child = + if let Node::Edge(edge) = node { + match edge.child { + NodeHandle::Hash(_) => return Ok(()), + NodeHandle::InMemory(child_id) => { + let child_node = self.storage_nodes.0.get(&child_id).ok_or( + BonsaiStorageError::Trie("Node not found in memory".to_string()), + )?; + if let Node::Edge(child_edge) = child_node { + child_edge.clone() + } else { + return Ok(()); + } + } } + } else { + return Ok(()); }; - match child_node { - Node::Edge(child_edge) => child_edge.clone(), - _ => { - return Ok(()); - } - } - } else { - return Ok(()); - }; + // Get a mutable reference to the parent node to merge them let edge = self .storage_nodes .0 - .get_mut(&node) + .get_mut(&node_id) .ok_or(BonsaiStorageError::Trie( "Node not found in memory".to_string(), ))?; - match edge { - Node::Edge(edge) => { - edge.path.0.extend_from_bitslice(&child.path.0); - edge.child = child.child; - } - _ => unreachable!(), + if let Node::Edge(edge) = edge { + edge.path.0.extend_from_bitslice(&child.path.0); + edge.child = child.child; } } Ok(()) @@ -695,34 +583,34 @@ impl MerkleTree { /// # Returns /// /// The value of the key. - pub fn get(&self, key: &BitSlice) -> Result, BonsaiStorageError> - where - BonsaiStorageError: core::convert::From<::DatabaseError>, - { - let key = &[&[key.len() as u8], key.to_bitvec().as_raw_slice()].concat(); - if let Some(value) = self.cache_leaf_modified.get(key) { + pub fn get( + &self, + key: &BitSlice, + ) -> Result, BonsaiStorageError> { + let key = bitslice_to_bytes(key); + if let Some(value) = self.cache_leaf_modified.get(&key) { match value { InsertOrRemove::Remove => return Ok(None), InsertOrRemove::Insert(value) => return Ok(Some(*value)), } } self.db - .get(&TrieKeyType::Flat(key.to_vec())) + .get(&TrieKey::Flat(key.to_vec())) .map(|r| r.map(|opt| Felt::decode(&mut opt.as_slice()).unwrap())) } - pub fn contains(&self, key: &BitSlice) -> Result - where - BonsaiStorageError: core::convert::From<::DatabaseError>, - { - let key = &[&[key.len() as u8], key.to_bitvec().as_raw_slice()].concat(); - if let Some(value) = self.cache_leaf_modified.get(key) { + pub fn contains( + &self, + key: &BitSlice, + ) -> Result> { + let key = bitslice_to_bytes(key); + if let Some(value) = self.cache_leaf_modified.get(&key) { match value { InsertOrRemove::Remove => return Ok(false), InsertOrRemove::Insert(_) => return Ok(true), } } - self.db.contains(&TrieKeyType::Flat(key.to_vec())) + self.db.contains(&TrieKey::Flat(key.to_vec())) } /// Returns the list of nodes along the path. @@ -743,15 +631,15 @@ impl MerkleTree { /// # Returns /// /// The merkle proof and all the child nodes hashes. - pub fn get_proof(&self, key: &BitSlice) -> Result, BonsaiStorageError> - where - BonsaiStorageError: core::convert::From<::DatabaseError>, - { + pub fn get_proof( + &self, + key: &BitSlice, + ) -> Result, BonsaiStorageError> { let mut nodes = Vec::with_capacity(251); let mut node = match self.root_handle { NodeHandle::Hash(_) => { let node = self - .get_tree_branch_in_db_from_path(&BitVec::::new())? + .get_trie_branch_in_db_from_path(&Path(BitVec::::new()))? .ok_or(BonsaiStorageError::Trie( "Couldn't fetch root node in db".to_string(), ))?; @@ -775,7 +663,7 @@ impl MerkleTree { let child_path = key[..edge.height as usize + edge.path.0.len()].to_bitvec(); let child_node = match edge.child { NodeHandle::Hash(hash) => { - let node = self.get_tree_branch_in_db_from_path(&child_path)?; + let node = self.get_trie_branch_in_db_from_path(&Path(child_path))?; if let Some(node) = node { node } else { @@ -819,7 +707,7 @@ impl MerkleTree { let next_path = key[..binary.height as usize + 1].to_bitvec(); let next_node = match next { NodeHandle::Hash(_) => self - .get_tree_branch_in_db_from_path(&next_path)? + .get_trie_branch_in_db_from_path(&Path(next_path))? .ok_or(BonsaiStorageError::Trie( "Couldn't fetch next node in db".to_string(), ))?, @@ -899,15 +787,15 @@ impl MerkleTree { /// # Returns /// /// The list of nodes along the path. - fn preload_nodes(&mut self, dst: &BitSlice) -> Result, BonsaiStorageError> - where - BonsaiStorageError: core::convert::From<::DatabaseError>, - { + fn preload_nodes( + &mut self, + dst: &BitSlice, + ) -> Result, BonsaiStorageError> { let mut nodes = Vec::with_capacity(251); let node_id = match self.root_handle { NodeHandle::Hash(_) => { let node = self - .get_tree_branch_in_db_from_path(&BitVec::::new())? + .get_trie_branch_in_db_from_path(&Path(BitVec::::new()))? .ok_or(BonsaiStorageError::Trie( "Couldn't fetch root node in db".to_string(), ))?; @@ -925,7 +813,7 @@ impl MerkleTree { root_id } }; - self.preload_nodes_subtree(dst, node_id, BitVec::::new(), &mut nodes)?; + self.preload_nodes_subtree(dst, node_id, Path(BitVec::::new()), &mut nodes)?; Ok(nodes) } @@ -933,12 +821,9 @@ impl MerkleTree { &mut self, dst: &BitSlice, root_id: NodeId, - mut path: BitVec, + mut path: Path, nodes: &mut Vec, - ) -> Result<(), BonsaiStorageError> - where - BonsaiStorageError: core::convert::From<::DatabaseError>, - { + ) -> Result<(), BonsaiStorageError> { let node = self .storage_nodes .0 @@ -948,14 +833,21 @@ impl MerkleTree { ))? .clone(); match node { + // We are in a case where the trie is empty and so there is nothing to preload. Node::Unresolved(_hash) => Ok(()), + // We are checking which side of the binary we should load in memory (if we don't have it already) + // We load this "child-side" node in the memory and refer his memory handle in the binary node. + // We also add the "child-side" node in the list that accumulate all the nodes we want to preload. + // We override the binary node in the memory with this new version that has the "child-side" memory handle + // instead of the hash. + // We call recursively the function with the "child-side" node. Node::Binary(mut binary_node) => { let next_direction = binary_node.direction(dst); - path.push(bool::from(next_direction)); + path.0.push(bool::from(next_direction)); let next = binary_node.get_child(next_direction); match next { NodeHandle::Hash(_) => { - let node = self.get_tree_branch_in_db_from_path(&path)?.ok_or( + let node = self.get_trie_branch_in_db_from_path(&path)?.ok_or( BonsaiStorageError::Trie("Couldn't fetch node in db".to_string()), )?; self.latest_node_id.next_id(); @@ -980,15 +872,19 @@ impl MerkleTree { } } } + // If the edge node match the path we want to preload then we load the child node in memory (if we don't have it already) + // and we override the edge node in the memory with this new version that has the child memory handle instead of the hash. + // We also add the child node in the list that accumulate all the nodes we want to preload. + // We call recursively the function with the child node. Node::Edge(mut edge_node) if edge_node.path_matches(dst) => { - path.extend_from_bitslice(&edge_node.path.0); - if path == dst { + path.0.extend_from_bitslice(&edge_node.path.0); + if path.0 == dst { return Ok(()); } let next = edge_node.child; match next { NodeHandle::Hash(_) => { - let node = self.get_tree_branch_in_db_from_path(&path)?; + let node = self.get_trie_branch_in_db_from_path(&path)?; if let Some(node) = node { self.latest_node_id.next_id(); self.storage_nodes.0.insert(self.latest_node_id, node); @@ -1006,24 +902,18 @@ impl MerkleTree { } } } + // We are in a case where the edge node doesn't match the path we want to preload so we return nothing. Node::Edge(_) => Ok(()), } } - fn get_tree_branch_in_db_from_path( + /// Get the node of the trie that corresponds to the path. + fn get_trie_branch_in_db_from_path( &self, - path: &BitVec, - ) -> Result, BonsaiStorageError> - where - BonsaiStorageError: core::convert::From<::DatabaseError>, - { - let key = if path.is_empty() { - vec![] - } else { - [&[path.len() as u8], path.as_raw_slice()].concat() - }; + path: &Path, + ) -> Result, BonsaiStorageError> { self.db - .get(&TrieKeyType::Trie(key))? + .get(&path.into())? .map(|node| { Node::decode(&mut node.as_slice()).map_err(|err| { BonsaiStorageError::Trie(format!("Couldn't decode node: {}", err)) @@ -1043,11 +933,10 @@ impl MerkleTree { /// # Arguments /// /// * `parent` - The parent node to merge the child with. - fn merge_edges(&self, parent: &mut EdgeNode) -> Result<(), BonsaiStorageError> - where - BonsaiStorageError: core::convert::From<::DatabaseError>, - { - //TODO: Add deletion of unused nodes + fn merge_edges( + &self, + parent: &mut EdgeNode, + ) -> Result<(), BonsaiStorageError> { let child_node = match parent.child { NodeHandle::Hash(_) => return Ok(()), NodeHandle::InMemory(child_id) => { @@ -1147,6 +1036,7 @@ impl MerkleTree { } #[cfg(test)] + #[allow(dead_code)] fn display(&self) { match self.root_handle { NodeHandle::Hash(hash) => { @@ -1160,6 +1050,7 @@ impl MerkleTree { } #[cfg(test)] + #[allow(dead_code)] fn print(&self, head: &NodeId) { use Node::*; @@ -1200,32 +1091,37 @@ impl MerkleTree { } } +fn bitslice_to_bytes(bitslice: &BitSlice) -> Vec { + [&[bitslice.len() as u8], bitslice.to_bitvec().as_raw_slice()].concat() +} + +#[cfg(test)] #[cfg(all(test, feature = "std"))] mod tests { - use crate::{ - databases::{create_rocks_db, RocksDB, RocksDBConfig}, - id::BasicId, - key_value_db::KeyValueDBConfig, - KeyValueDB, - }; - use bitvec::vec::BitVec; - use mp_felt::Felt252Wrapper; - use mp_hashers::pedersen::PedersenHasher; - use parity_scale_codec::{Decode, Encode}; - use rand::prelude::*; - use starknet_types_core::{felt::Felt, hash::Pedersen}; - - // convert a Madara felt to a standard Felt - fn felt_from_madara_felt(madara_felt: &Felt252Wrapper) -> Felt { - let encoded = madara_felt.encode(); - Felt::decode(&mut &encoded[..]).unwrap() - } + // use crate::{ + // databases::{create_rocks_db, RocksDB, RocksDBConfig}, + // id::BasicId, + // key_value_db::KeyValueDBConfig, + // KeyValueDB, + // }; + // use bitvec::vec::BitVec; + // use mp_felt::Felt252Wrapper; + // use mp_hashers::pedersen::PedersenHasher; + // use parity_scale_codec::{Decode, Encode}; + // use rand::prelude::*; + // use starknet_types_core::{felt::Felt, hash::Pedersen}; + + // // convert a Madara felt to a standard Felt + // fn felt_from_madara_felt(madara_felt: &Felt252Wrapper) -> Felt { + // let encoded = madara_felt.encode(); + // Felt::decode(&mut &encoded[..]).unwrap() + // } - // convert a standard Felt to a Madara felt - fn madara_felt_from_felt(felt: &Felt) -> Felt252Wrapper { - let encoded = felt.encode(); - Felt252Wrapper::decode(&mut &encoded[..]).unwrap() - } + // // convert a standard Felt to a Madara felt + // fn madara_felt_from_felt(felt: &Felt) -> Felt252Wrapper { + // let encoded = felt.encode(); + // Felt252Wrapper::decode(&mut &encoded[..]).unwrap() + // } // #[test] // fn one_commit_tree_compare() { diff --git a/src/trie/mod.rs b/src/trie/mod.rs index bb341ea..0d06bb8 100644 --- a/src/trie/mod.rs +++ b/src/trie/mod.rs @@ -1,5 +1,6 @@ mod merkle_node; pub mod merkle_tree; +mod path; mod trie_db; -pub use trie_db::TrieKeyType; +pub use trie_db::TrieKey; diff --git a/src/trie/path.rs b/src/trie/path.rs new file mode 100644 index 0000000..78906e4 --- /dev/null +++ b/src/trie/path.rs @@ -0,0 +1,144 @@ +use bitvec::{order::Msb0, vec::BitVec}; +use parity_scale_codec::{Decode, Encode, Error, Input, Output}; + +use super::{merkle_node::Direction, TrieKey}; + +#[cfg(not(feature = "std"))] +use alloc::vec::Vec; + +#[cfg(all(feature = "std", test))] +use rstest::rstest; + +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Path(pub BitVec); + +impl Encode for Path { + fn encode_to(&self, dest: &mut T) { + // Copied from scale_bits crate (https://github.com/paritytech/scale-bits/blob/820a3e8e0c9db18ef6acfa2a9a19f738400b0637/src/scale/encode_iter.rs#L28) + // but don't use it directly to avoid copy and u32 length encoding + // How it works ? + // 1. We encode the number of bits in the bitvec as a u8 + // 2. We build elements of a size of u8 using bit shifting + // 3. A last element, not full, is created if there is a remainder of bits + let iter = self.0.iter(); + let len = iter.len(); + // SAFETY: len is <= 251 + dest.push_byte(len as u8); + let mut next_store: u8 = 0; + let mut pos_in_next_store: u8 = 7; + for b in iter { + let bit = match *b { + true => 1, + false => 0, + }; + next_store |= bit << pos_in_next_store; + + if pos_in_next_store == 0 { + pos_in_next_store = 8; + dest.push_byte(next_store); + next_store = 0; + } + pos_in_next_store -= 1; + } + + if pos_in_next_store < 7 { + dest.push_byte(next_store); + } + } + + fn size_hint(&self) -> usize { + // Inspired from scale_bits crate but don't use it to avoid copy and u32 length encoding + 1 + (self.0.len() + 7) / 8 + } +} + +impl Decode for Path { + fn decode(input: &mut I) -> Result { + // Inspired from scale_bits crate but don't use it to avoid copy and u32 length encoding + // SAFETY: len is <= 251 + let len: u8 = input.read_byte()?; + let mut remaining_bits = len as usize; + let mut current_byte = None; + let mut bit = 7; + let mut bits = BitVec::::new(); + // No bits left to decode; we're done. + while remaining_bits != 0 { + // Get the next store entry to pull from: + let store = match current_byte { + Some(store) => store, + None => { + let store = match input.read_byte() { + Ok(s) => s, + Err(e) => return Err(e), + }; + current_byte = Some(store); + store + } + }; + + // Extract a bit: + let res = match (store >> bit) & 1 { + 0 => false, + 1 => true, + _ => unreachable!("Can only be 0 or 1 owing to &1"), + }; + bits.push(res); + + // Update records for next bit: + remaining_bits -= 1; + if bit == 0 { + current_byte = None; + bit = 8; + } + bit -= 1; + } + Ok(Self(bits)) + } +} + +impl Path { + pub(crate) fn new_with_direction(&self, direction: Direction) -> Path { + let mut path = self.0.clone(); + path.push(direction.into()); + Path(path) + } +} + +impl From for TrieKey { + fn from(path: Path) -> Self { + let key = if path.0.is_empty() { + Vec::new() + } else { + [&[path.0.len() as u8], path.0.as_raw_slice()].concat() + }; + TrieKey::Trie(key) + } +} + +impl From<&Path> for TrieKey { + fn from(path: &Path) -> Self { + let key = if path.0.is_empty() { + Vec::new() + } else { + [&[path.0.len() as u8], path.0.as_raw_slice()].concat() + }; + TrieKey::Trie(key) + } +} + +#[cfg(all(feature = "std", test))] +#[rstest] +#[case(&[0b10101010, 0b10101010])] +#[case(&[])] +#[case(&[0b10101010])] +#[case(&[0b00000000])] +#[case(&[0b11111111])] +#[case(&[0b11111111, 0b00000000, 0b10101010, 0b10101010, 0b11111111, 0b00000000, 0b10101010, 0b10101010, 0b11111111, 0b00000000, 0b10101010, 0b10101010])] +fn test_shared_path_encode_decode(#[case] input: &[u8]) { + let path = Path(BitVec::::from_slice(input)); + let mut encoded = Vec::new(); + path.encode_to(&mut encoded); + + let decoded = Path::decode(&mut &encoded[..]).unwrap(); + assert_eq!(path, decoded); +} diff --git a/src/trie/trie_db.rs b/src/trie/trie_db.rs index 534a754..b7debf2 100644 --- a/src/trie/trie_db.rs +++ b/src/trie/trie_db.rs @@ -1,38 +1,61 @@ -use crate::{bonsai_database::KeyType, changes::ChangeKeyType}; +use crate::bonsai_database::DatabaseKey; + #[cfg(not(feature = "std"))] use alloc::vec::Vec; -#[derive(Debug, Hash, PartialEq, Eq)] -pub enum TrieKeyType { +/// Key in the database of the different elements that are used in the storage of the trie data. +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub enum TrieKey { Trie(Vec), Flat(Vec), } -impl TrieKeyType { - pub fn as_slice(&self) -> &[u8] { - match self { - TrieKeyType::Trie(slice) => slice, - TrieKeyType::Flat(slice) => slice, +enum TrieKeyType { + Trie = 0, + Flat = 1, +} + +impl From for u8 { + fn from(value: TrieKey) -> Self { + match value { + TrieKey::Trie(_) => TrieKeyType::Trie as u8, + TrieKey::Flat(_) => TrieKeyType::Flat as u8, } } } -impl<'a> From<&'a TrieKeyType> for KeyType<'a> { - fn from(key: &'a TrieKeyType) -> Self { - let key_slice = key.as_slice(); - match key { - TrieKeyType::Trie(_) => KeyType::Trie(key_slice), - TrieKeyType::Flat(_) => KeyType::Flat(key_slice), +impl From<&TrieKey> for u8 { + fn from(value: &TrieKey) -> Self { + match value { + TrieKey::Trie(_) => TrieKeyType::Trie as u8, + TrieKey::Flat(_) => TrieKeyType::Flat as u8, + } + } +} + +impl TrieKey { + pub fn from_variant_and_bytes(variant: u8, bytes: Vec) -> Self { + match variant { + x if x == TrieKeyType::Trie as u8 => TrieKey::Trie(bytes), + x if x == TrieKeyType::Flat as u8 => TrieKey::Flat(bytes), + _ => panic!("Invalid trie key type"), + } + } + + pub fn as_slice(&self) -> &[u8] { + match self { + TrieKey::Trie(slice) => slice, + TrieKey::Flat(slice) => slice, } } } -impl<'a> From<&'a TrieKeyType> for ChangeKeyType { - fn from(key: &'a TrieKeyType) -> Self { +impl<'a> From<&'a TrieKey> for DatabaseKey<'a> { + fn from(key: &'a TrieKey) -> Self { let key_slice = key.as_slice(); match key { - TrieKeyType::Trie(_) => ChangeKeyType::Trie(key_slice.to_vec()), - TrieKeyType::Flat(_) => ChangeKeyType::Flat(key_slice.to_vec()), + TrieKey::Trie(_) => DatabaseKey::Trie(key_slice), + TrieKey::Flat(_) => DatabaseKey::Flat(key_slice), } } }