Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(derive): Span batch bitlist encoding #122

Merged
merged 1 commit into from
Apr 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 91 additions & 26 deletions crates/derive/src/types/batch/span_batch/bits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::types::{SpanBatchError, MAX_SPAN_BATCH_SIZE};
use alloc::{vec, vec::Vec};
use alloy_rlp::Buf;
use anyhow::Result;
use core::cmp::Ordering;

/// Type for span batch bits.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
Expand Down Expand Up @@ -47,12 +48,11 @@ impl SpanBatchBits {
b.advance(buffer_len);
v
};
let sb_bits = SpanBatchBits(bits.to_vec());
let sb_bits = SpanBatchBits(bits);

// TODO(clabby): Why doesn't this check work?
// if sb_bits.bit_len() > bit_length {
// return Err(SpanBatchError::BitfieldTooLong);
// }
if sb_bits.bit_len() > bit_length {
return Err(SpanBatchError::BitfieldTooLong);
}

Ok(sb_bits)
}
Expand All @@ -65,10 +65,9 @@ impl SpanBatchBits {
bit_length: usize,
bits: &SpanBatchBits,
) -> Result<(), SpanBatchError> {
// TODO(clabby): Why doesn't this check work?
// if bits.bit_len() > bit_length {
// return Err(SpanBatchError::BitfieldTooLong);
// }
if bits.bit_len() > bit_length {
return Err(SpanBatchError::BitfieldTooLong);
}

// Round up, ensure enough bytes when number of bits is not a multiple of 8.
// Alternative of (L+7)/8 is not overflow-safe.
Expand All @@ -90,12 +89,12 @@ impl SpanBatchBits {
// Check if the byte index is within the bounds of the bitlist
if byte_index < self.0.len() {
// Retrieve the specific byte that contains the bit we're interested in
let byte = self.0[byte_index];
let byte = self.0[self.0.len() - byte_index - 1];

// Shift the bits of the byte to the right, based on the bit index, and
// mask it with 1 to isolate the bit we're interested in.
// If the result is not zero, the bit is set to 1, otherwise it's 0.
Some(if byte & (1 << (8 - bit_index)) != 0 { 1 } else { 0 })
Some(if byte & (1 << bit_index) != 0 { 1 } else { 0 })
} else {
// Return None if the index is out of bounds
None
Expand All @@ -110,34 +109,58 @@ impl SpanBatchBits {
// Ensure the vector is large enough to contain the bit at 'index'.
// If not, resize the vector, filling with 0s.
if byte_index >= self.0.len() {
self.0.resize(byte_index + 1, 0);
Self::resize_from_right(&mut self.0, byte_index + 1);
}

// Retrieve the specific byte to modify
let byte = &mut self.0[byte_index];
let len = self.0.len();
let byte = &mut self.0[len - byte_index - 1];

if value {
// Set the bit to 1
*byte |= 1 << (8 - bit_index);
*byte |= 1 << bit_index;
} else {
// Set the bit to 0
*byte &= !(1 << (8 - bit_index));
*byte &= !(1 << bit_index);
}
}

/// Calculates the bit length of the [SpanBatchBits] bitfield.
pub fn bit_len(&self) -> usize {
if let Some((top_word, rest)) = self.0.split_last() {
// Calculate bit length. Rust's leading_zeros counts zeros from the MSB, so subtract
// from total bits.
let significant_bits = 8 - top_word.leading_zeros() as usize;

// Return total bits, taking into account the full words in `rest` and the significant
// bits in `top`.
rest.len() * 8 + significant_bits
} else {
// If the slice is empty, return 0.
0
// Iterate over the bytes from left to right to find the first non-zero byte
for (i, &byte) in self.0.iter().enumerate() {
if byte != 0 {
// Calculate the index of the most significant bit in the byte
let msb_index = 7 - byte.leading_zeros() as usize; // 0-based index

// Calculate the total bit length
let total_bit_length = msb_index + 1 + ((self.0.len() - i - 1) * 8);
return total_bit_length;
}
}

// If all bytes are zero, the bitlist is considered to have a length of 0
0
}

/// Resizes `vec` to exactly `new_size` elements by adding or removing elements at the
/// front, so the elements already present stay aligned to the end of the vector.
///
/// Growing prepends `T::default()` values (for bytes, this is zero extension of a
/// big-endian value); shrinking discards the leading elements. When `new_size` equals
/// the current length the vector is left untouched.
fn resize_from_right<T: Default + Clone>(vec: &mut Vec<T>, new_size: usize) {
    let old_len = vec.len();
    match old_len.cmp(&new_size) {
        // Already the requested size: nothing to do.
        Ordering::Equal => {}
        Ordering::Greater => {
            // Too long: drop the surplus elements from the front, keeping the tail.
            vec.drain(..old_len - new_size);
        }
        Ordering::Less => {
            // Too short: build the default-valued prefix first, then move the existing
            // elements in behind it and swap the result into place.
            let mut padded = vec![T::default(); new_size - old_len];
            padded.append(vec);
            *vec = padded;
        }
    }
}
}
Expand All @@ -156,6 +179,48 @@ mod test {
SpanBatchBits::encode(&mut encoded, bits.0.len() * 8, &bits).unwrap();
assert_eq!(encoded, bits.0);
}

#[test]
fn test_span_bitlist_bitlen(index in 0usize..65536) {
    // Setting a single bit in an empty bitlist must allocate just enough bytes to hold
    // that bit, and the reported bit length must be one past the bit's index.
    let expected_bytes = (index / 8) + 1;

    let mut bitlist = SpanBatchBits::default();
    bitlist.set_bit(index, true);

    assert_eq!(bitlist.0.len(), expected_bytes);
    assert_eq!(bitlist.bit_len(), index + 1);
}

#[test]
fn test_span_bitlist_bitlen_shrink(first_index in 8usize..65536) {
    // Because `first_index >= 8`, this clamp always evaluates to `first_index - 8`,
    // i.e. the bit at the same position one byte lower.
    let second_index = first_index.clamp(0, first_index - 8);
    // Number of bytes needed to hold the bit at `first_index`; every length assertion
    // below expects the backing vector to stay at this size.
    let allocated_bytes = (first_index / 8) + 1;

    let mut bitlist = SpanBatchBits::default();

    // Set the first bit: the vector grows and the bit length reflects the new top bit.
    bitlist.set_bit(first_index, true);
    assert_eq!(bitlist.0.len(), allocated_bytes);
    assert_eq!(bitlist.bit_len(), first_index + 1);

    // Clear it again: the allocation is kept, but with no bit set the length is 0.
    bitlist.set_bit(first_index, false);
    assert_eq!(bitlist.0.len(), allocated_bytes);
    assert_eq!(bitlist.bit_len(), 0);

    // Set the lower bit. The vector is still as large as the first bit required, yet
    // the bit length must follow the highest set bit, since the higher-order bytes
    // are all zero.
    bitlist.set_bit(second_index, true);
    assert_eq!(bitlist.0.len(), allocated_bytes);
    assert_eq!(bitlist.bit_len(), second_index + 1);
}
}

#[test]
fn bitlist_big_endian_zero_extended() {
    let mut bitlist = SpanBatchBits::default();

    // Bits 1 and 6 fall in the low-order byte, bits 8 and 15 in the next byte up.
    for index in [1, 6, 8, 15] {
        bitlist.set_bit(index, true);
    }

    // The higher-order byte (bits 8 and 15) is expected first in the vector and the
    // low-order byte (bits 1 and 6) last, i.e. a big-endian layout.
    assert_eq!(bitlist.0[0], 0b1000_0001);
    assert_eq!(bitlist.0[1], 0b0100_0010);
    assert_eq!(bitlist.0.len(), 2);
    assert_eq!(bitlist.bit_len(), 16);
}

#[test]
Expand Down
Loading