Skip to content

Commit

Permalink
fixed benches; tried to make set_word_pos more coherent by changing t…
Browse files Browse the repository at this point in the history
…he WordPosInput struct and adjusting the documentation; added checks in test_set_and_get_equivalence
  • Loading branch information
nstilt1 committed Jan 18, 2024
1 parent 0295f7c commit a8cd5c3
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 27 deletions.
2 changes: 1 addition & 1 deletion benches/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use criterion::Criterion;

#[cfg(any(target_arch = "x86_64", target_arch = "x86", all(target_arch = "aarch64", target_os = "linux")))]
pub type Benchmarker = Criterion<criterion_cycles_per_byte::CyclesPerByte>;
Expand Down
77 changes: 51 additions & 26 deletions chacha20/src/rng.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,22 +80,26 @@ impl Debug for Seed {
/// A wrapper for set_word_pos() input that can be assembled from:
/// * `u64`
/// * `[u8; 5]`
pub struct WordPosInput([u8; 5]);
pub struct WordPosInput {
block_pos: u32,
index: usize
}

impl From<[u8; 5]> for WordPosInput {
fn from(value: [u8; 5]) -> Self {
Self(value)
Self {
block_pos: u32::from_le_bytes(value[0..4].try_into().unwrap()),
index: (value[4] & 0b1111) as usize,
}
}
}

impl From<u64> for WordPosInput {
fn from(value: u64) -> Self {
let mut result = [0u8; 5];
let block_pos = (value >> 4).to_le_bytes();
let index_byte = value.to_le_bytes()[0];
result[0..4].copy_from_slice(&block_pos[0..4]);
result[4] = index_byte;
Self(result)
Self {
block_pos: u32::from_le_bytes((value >> 4).to_le_bytes()[0..4].try_into().unwrap()),
index: (value.to_le_bytes()[0] & 0b1111) as usize
}
}
}

Expand Down Expand Up @@ -123,10 +127,12 @@ impl From<[u8; 12]> for StreamId {

impl From<u128> for StreamId {
fn from(value: u128) -> Self {
let mut lower_12_bytes: [u8; 12] = [0u8; 12];
let bytes = value.to_le_bytes();
lower_12_bytes.copy_from_slice(&bytes[0..12]);
lower_12_bytes.into()
let mut result = Self([0u32; 3]);
for (n, chunk) in result.0.iter_mut().zip(bytes[0..12].chunks_exact(4)) {
*n = u32::from_le_bytes(chunk.try_into().unwrap());
}
result
}
}

Expand Down Expand Up @@ -426,20 +432,23 @@ macro_rules! impl_chacha_rng {
/// * u64
/// * [u8; 5]
///
/// As with `get_word_pos`, we use a 36-bit number. Since the generator
/// simply cycles at the end of its period (256 GiB), we ignore the upper 28
/// bits of a `u64`. When given a `[u8; 5]`, we ignore the first 4 bits of the
/// last byte.
/// As with `get_word_pos`, we use a 36-bit number. When given a `u64`, we use
/// the least significant 4 bits as the RNG's index, and the 32 bits before it
/// as the block position.
///
/// When given a `[u8; 5]`, the word_pos is set similarly, but it is more
/// arbitrary.
#[inline]
pub fn set_word_pos<W: Into<WordPosInput>>(&mut self, word_offset: W) {
let word_offset: WordPosInput = word_offset.into();
self.core.state[12] = (u32::from_le_bytes(word_offset.0[0..4].try_into().unwrap()));
let word_pos: WordPosInput = word_offset.into();
self.core.state[12] = word_pos.block_pos;
// generate will increase block_pos by 4
self.generate_and_set((word_offset.0[4] & 0x0F) as usize);
self.generate_and_set(word_pos.index);
}

/// Set the stream number. The lower 96 bits are used and the rest are
/// discarded. This method takes either:
/// * [u32; 3]
/// * [u8; 12]
/// * u128
///
Expand Down Expand Up @@ -645,18 +654,34 @@ pub(crate) mod tests {
#[test]
fn test_set_and_get_equivalence() {
let seed = [44u8; 32];
let mut rng = ChaCha20Rng::from_seed(seed.into());
let stream = 1337 as u128;
rng.set_stream(stream);
let word_pos = 35534 as u64;
rng.set_word_pos(word_pos);
let mut rng = ChaCha20Rng::from_seed(seed);

// test set_stream with [u32; 3]
rng.set_stream([313453u32, 0u32, 0u32]);
assert_eq!(rng.get_stream(), 313453);

// test set_stream with [u8; 12]
rng.set_stream([89, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]);
assert_eq!(rng.get_stream(), 89);

assert_eq!(rng.get_seed(), seed);
assert_eq!(rng.get_stream(), stream);
assert_eq!(rng.get_word_pos(), word_pos);
// test set_stream with u128
rng.set_stream(11111111);
assert_eq!(rng.get_stream(), 11111111);

// test set_block_pos with u32
rng.core.set_block_pos(58392);
assert_eq!(rng.core.get_block_pos(), 58392);

// test set_block_pos with [u8; 4]
rng.core.set_block_pos([77, 0, 0, 0]);
assert_eq!(rng.core.get_block_pos(), 77);

// test set_word_pos with u64
rng.set_word_pos(8888);
assert_eq!(rng.get_word_pos(), 8888);

// test set_word_pos with [u8; 5]
rng.set_word_pos([55, 0, 0, 0, 0])
}

#[cfg(feature = "serde1")]
Expand Down

0 comments on commit a8cd5c3

Please sign in to comment.