Skip to content

Commit

Permalink
mt for mmap tree
Browse files Browse the repository at this point in the history
  • Loading branch information
ewoolsey committed Mar 13, 2024
1 parent bf7b2c2 commit b78b0bf
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 39 deletions.
60 changes: 22 additions & 38 deletions src/lazy_merkle_tree.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::merkle_tree::{Branch, Hasher, Proof};
use crate::{
merkle_tree::{Branch, Hasher, Proof},
util::as_bytes,
};
use std::{
fs::OpenOptions,
io::Write,
Expand Down Expand Up @@ -686,7 +689,7 @@ impl<H: Hasher> Clone for DenseTree<H> {
}

impl<H: Hasher> DenseTree<H> {
fn new_with_values(values: &[H::Hash], empty_value: &H::Hash, depth: usize) -> Self {
fn vec_from_values(values: &[H::Hash], empty_value: &H::Hash, depth: usize) -> Vec<H::Hash> {
let leaf_count = 1 << depth;
let storage_size = 1 << (depth + 1);
let mut storage = Vec::with_capacity(storage_size);
Expand Down Expand Up @@ -714,6 +717,12 @@ impl<H: Hasher> DenseTree<H> {
});
}

storage
}

fn new_with_values(values: &[H::Hash], empty_value: &H::Hash, depth: usize) -> Self {
let storage = Self::vec_from_values(values, empty_value, depth);

Self {
depth,
root_index: 1,
Expand Down Expand Up @@ -898,7 +907,7 @@ impl<H: Hasher> DenseMMapTree<H> {
/// - returns Err if mmap creation fails
fn new_with_values(
values: &[H::Hash],
empty_leaf: &H::Hash,
empty_value: &H::Hash,
depth: usize,
mmap_file_path: &str,
) -> Result<Self, DenseMMapError> {
Expand All @@ -907,19 +916,9 @@ impl<H: Hasher> DenseMMapTree<H> {
Err(_e) => return Err(DenseMMapError::FailedToCreatePathBuf),
};

let leaf_count = 1 << depth;
let first_leaf_index = 1 << depth;
let storage_size = 1 << (depth + 1);

assert!(values.len() <= leaf_count);
let storage = DenseTree::<H>::vec_from_values(values, empty_value, depth);

let mut mmap = MmapMutWrapper::new_with_initial_values(path_buf, empty_leaf, storage_size)?;
mmap[first_leaf_index..(first_leaf_index + values.len())].clone_from_slice(values);
for i in (1..first_leaf_index).rev() {
let left = &mmap[2 * i];
let right = &mmap[2 * i + 1];
mmap[i] = H::hash_node(left, right);
}
let mmap = MmapMutWrapper::new_from_storage(path_buf, &storage)?;

Ok(Self {
depth,
Expand Down Expand Up @@ -1130,29 +1129,14 @@ impl<H: Hasher> MmapMutWrapper<H> {
/// - file size cannot be set
/// - file is too large, possible truncation can occur
/// - cannot build memory map
pub fn new_with_initial_values(
pub fn new_from_storage(
file_path: PathBuf,
initial_value: &H::Hash,
storage_size: usize,
storage: &[H::Hash],
) -> Result<Self, DenseMMapError> {
let size_of_val = std::mem::size_of_val(initial_value);
let initial_vals: Vec<H::Hash> = vec![initial_value.clone(); storage_size];

// cast Hash pointer to u8 pointer
let ptr = initial_vals.as_ptr().cast::<u8>();

let size_of_buffer: usize = storage_size * size_of_val;

let buf: &[u8] = unsafe {
// moving pointer by u8 for storage_size * size of hash would get us the full
// buffer
std::slice::from_raw_parts(ptr, size_of_buffer)
};

// assure that buffer is correct length
assert_eq!(buf.len(), size_of_buffer);

let file_size: u64 = storage_size as u64 * size_of_val as u64;
// Safety: potential uninitialized padding from `H::Hash` is safe to use if
// we're casting back to the same type.
let buf = unsafe { as_bytes(storage) };
let buf_len = buf.len();

let mut file = match OpenOptions::new()
.read(true)
Expand All @@ -1165,13 +1149,13 @@ impl<H: Hasher> MmapMutWrapper<H> {
Err(_e) => return Err(DenseMMapError::FileCreationFailed),
};

file.set_len(file_size).expect("cannot set file size");
file.set_len(buf_len as u64).expect("cannot set file size");
if file.write_all(buf).is_err() {
return Err(DenseMMapError::FileCannotWriteBytes);
}

let mmap = unsafe {
MmapOptions::new(usize::try_from(file_size).expect("file size truncated"))
MmapOptions::new(usize::try_from(buf_len as u64).expect("file size truncated"))
.expect("cannot create memory map")
.with_file(file, 0)
.map_mut()
Expand Down
2 changes: 1 addition & 1 deletion src/merkle_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use std::{
/// Hash types, values and algorithms for a Merkle tree
pub trait Hasher {
/// Type of the leaf and node hashes
type Hash: Clone + Eq + Serialize + Debug + Send + Sync;
type Hash: Clone + Eq + Serialize + Debug + Send + Sync + Sized;

/// Compute the hash of an intermediate node
fn hash_node(left: &Self::Hash, right: &Self::Hash) -> Self::Hash;
Expand Down
32 changes: 32 additions & 0 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@ pub(crate) fn keccak256(bytes: &[u8]) -> [u8; 32] {
output
}

/// function to convert a reference to any type to a slice of bytes
/// # Safety
/// This function is unsafe because it transmutes a reference to a slice of
/// bytes. If `T` is not densly packed, then the padding bytes will be
/// unitialized.
pub(crate) unsafe fn as_bytes<T: ?Sized>(p: &T) -> &[u8] {
::core::slice::from_raw_parts(
(p as *const T).cast::<u8>(),
::core::mem::size_of_val::<T>(p),
)
}

pub(crate) fn bytes_to_hex<const N: usize, const M: usize>(bytes: &[u8; N]) -> [u8; M] {
// TODO: Replace `M` with a const expression once it's stable.
debug_assert_eq!(M, 2 * N + 2);
Expand Down Expand Up @@ -113,6 +125,26 @@ pub(crate) fn deserialize_bytes<'de, const N: usize, D: Deserializer<'de>>(
mod test {
use super::*;

#[test]
fn test_as_bytes() {
// test byte array
let bytes = [0, 1, 2, 3u8];
let converted: [u8; 4] = unsafe { as_bytes(&bytes).try_into().unwrap() };
assert_eq!(bytes, converted);

// test u64 array
let array = [1u64, 1];
let converted: [u8; 16] = unsafe { as_bytes(&array).try_into().unwrap() };
let expected = [1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0];
assert_eq!(converted, expected);

// test u64
let value = 1u64;
let converted: [u8; 8] = unsafe { as_bytes(&value).try_into().unwrap() };
let expected = [1, 0, 0, 0, 0, 0, 0, 0];
assert_eq!(converted, expected);
}

#[test]
fn test_serialize_bytes_hex() {
let bytes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
Expand Down

0 comments on commit b78b0bf

Please sign in to comment.