diff --git a/chacha20/src/rng.rs b/chacha20/src/rng.rs index 6e1afc17..eda339f4 100644 --- a/chacha20/src/rng.rs +++ b/chacha20/src/rng.rs @@ -160,7 +160,7 @@ impl From<[u8; 5]> for WordPosInput { impl From<u64> for WordPosInput { fn from(value: u64) -> Self { - let shifted = (value >> 4).to_le_bytes(); + let shifted = (value >> 6).to_le_bytes(); let original = value.to_le_bytes(); let mut result = [0u8; 5]; result[4] = original[0]; @@ -466,25 +466,23 @@ macro_rules! impl_chacha_rng { /// Get the offset from the start of the stream, in 32-bit words. /// - /// Since the generated blocks are 16 words (2<sup>4</sup>) long and the + /// Since the generated blocks are 64 words (2<sup>6</sup>) long and the /// counter is 32-bits, the offset is a 36-bit number. Sub-word offsets are /// not supported, hence the result can simply be multiplied by 4 to get a /// byte-offset. #[inline] pub fn get_word_pos(&self) -> u64 { - let buf_start_block = { - let buf_end_block = self.core.block.get_block_pos(); - u32::wrapping_sub(buf_end_block, BUF_BLOCKS.into()) - }; - let (buf_offset_blocks, block_offset_words) = { - let buf_offset_words = self.index as u32; - let blocks_part = buf_offset_words / u32::from(BLOCK_WORDS); - let words_part = buf_offset_words % u32::from(BLOCK_WORDS); - (blocks_part, words_part) - }; - let pos_block = u32::wrapping_add(buf_start_block, buf_offset_blocks); - let pos_block_words = u64::from(pos_block) * u64::from(BLOCK_WORDS); - pos_block_words + u64::from(block_offset_words) + // block_pos is a multiple of 4, and offset by 4; therefore, it already has the + // last 2 bits set to 0, allowing us to shift it left 4 and add the index + let mut result = u64::from( + self.core + .block + .get_block_pos() + .wrapping_sub(BUF_BLOCKS.into()), + ) << 4; + result += self.index as u64; + // eliminate the 36th bit + result & 0xfffffffff } /// Set the offset from the start of the stream, in 32-bit words. This method @@ -502,10 +500,14 @@ macro_rules! 
impl_chacha_rng { #[inline] pub fn set_word_pos<W: Into<WordPosInput>>(&mut self, word_offset: W) { let word_offset: WordPosInput = word_offset.into(); - self.core - .block - .set_block_pos(u32::from_le_bytes(word_offset.0[0..4].try_into().unwrap())); - self.generate_and_set((word_offset.0[4] & 0x0F) as usize); + // when not using `set_word_pos`, the block_pos is always a multiple of 4. + // This change follows those conventions, as well as maintaining the 6-bit + // index + self.core.block.set_block_pos( + u32::from_le_bytes(word_offset.0[0..4].try_into().unwrap()) << 2, + ); + // generate will increase block_pos by 4 + self.generate_and_set((word_offset.0[4] & 0x3F) as usize); } /// Set the stream number. The lower 96 bits are used and the rest are @@ -529,8 +531,7 @@ macro_rules! impl_chacha_rng { *n = u32::from_le_bytes(chunk.try_into().unwrap()); } if self.index != 64 { - let wp = self.get_word_pos(); - self.set_word_pos(wp); + self.generate_and_set(self.index); } } @@ -683,21 +684,6 @@ mod tests { 26, 27, 28, 29, 30, 31, 32, ]; - // this test will not pass without the user passing a mutable input because the value - // is copied into the method - // #[test] - // #[cfg(feature = "zeroize")] - // fn test_zeroize_inputs_external() { - // let initial_seed = KEY.clone(); - // let ptr = initial_seed.as_ptr(); - // { - // let mut rng = ChaChaRng::from_seed(initial_seed.into()); - // rng.fill_bytes(&mut [0u8; 32]); - // } - // let memory_inspection = unsafe { core::slice::from_raw_parts(ptr, 32) }; - // assert_ne!(&KEY, memory_inspection); - // } - #[test] #[cfg(feature = "zeroize")] fn test_zeroize_inputs_internal() { @@ -727,11 +713,92 @@ mod tests { ); } + fn expend_u32(rng: &mut ChaCha20Rng, amount: usize) { + for _i in 0..amount { + rng.next_u32(); + } + } + #[test] - /// there was a little error with the usize::from_le_bytes() + /// an additional test for the new set_word_pos() + /// twas mainly for debugging it fn test_set_word_pos() { - let mut rng = ChaCha20Rng::from_entropy(); - 
rng.set_word_pos(3533); + let mut rng = ChaCha20Rng::from_seed([0u8; 32]); + let mut cloned = rng.clone(); + let mut u32_array = [0u32; 128]; + for val in u32_array.iter_mut() { + *val = cloned.next_u32(); + } + assert_ne!([0u32; 128], u32_array); + + // testing how block_pos normally works + assert_eq!(rng.core.block.get_block_pos(), 0); + expend_u32(&mut rng, 1); + assert_eq!(rng.core.block.get_block_pos(), 4); + expend_u32(&mut rng, 17); + assert_eq!(rng.core.block.get_block_pos(), 4); + expend_u32(&mut rng, 25); + assert_eq!(rng.core.block.get_block_pos(), 4); + + // advance to the next increase of block_pos + rng = ChaChaRng::from_seed([0u8; 32]); + expend_u32(&mut rng, 68); + assert_eq!(rng.core.block.get_block_pos(), 8); + + // advance to the next increase of block_pos + expend_u32(&mut rng, 18); + assert_eq!(rng.core.block.get_block_pos(), 8); + expend_u32(&mut rng, 18); + assert_eq!(rng.core.block.get_block_pos(), 8); + expend_u32(&mut rng, 34); + assert_eq!(rng.core.block.get_block_pos(), 12); + + rng = ChaCha20Rng::from_seed([0u8; 32]); + expend_u32(&mut rng, 513); + assert_eq!(rng.core.block.get_block_pos(), 36); + + // testing word_pos and output generation + rng.set_word_pos(0); + assert_eq!(rng.next_u32(), u32_array[0]); + + rng.set_word_pos(63); + assert_eq!(rng.get_word_pos(), 63); + assert_eq!(rng.index, 63); + + assert_eq!(rng.next_u32(), u32_array[63]); + assert_eq!(rng.index, 64); + assert_eq!(rng.core.block.get_block_pos(), 4); + assert_eq!(rng.get_word_pos(), 64); + + assert_eq!(rng.next_u32(), u32_array[64]); + assert_eq!(rng.index, 1); + assert_eq!(rng.get_word_pos(), 65); + + assert_eq!(rng.next_u32(), u32_array[65]); + assert_eq!(rng.index, 2); + assert_eq!(rng.get_word_pos(), 66); + + let test_word_pos = 1234567; + rng.set_word_pos(test_word_pos); + assert_eq!( + rng.core.block.get_block_pos(), + (test_word_pos as f32 / 64.0f32).ceil() as u32 * 4 + ); + assert_eq!(rng.get_word_pos(), test_word_pos); + + let max_word_pos: u64 = (2 as 
u64).pow(36) - 1; + rng.set_word_pos(max_word_pos); + assert_eq!(rng.get_word_pos(), max_word_pos); + rng.next_u32(); + rng.next_u32(); + assert_eq!(rng.get_word_pos(), 1); + + // final round for this test + for _i in 0..1024 { + let word_pos = rng.next_u64() & ((1 << 36 as u64) - 1); + rng.set_word_pos(word_pos); + assert_eq!(word_pos, rng.get_word_pos()); + } } #[test] fn test_wrapping_add() { @@ -1039,7 +1106,7 @@ mod tests { use super::{BLOCK_WORDS, BUF_BLOCKS}; let mut rng = ChaChaRng::from_seed(Default::default()); // refilling the buffer in set_word_pos will wrap the block counter to 0 - let last_block = (1 << 36) - u64::from(BUF_BLOCKS * BLOCK_WORDS); + let last_block = (2 as u64).pow(36) - u64::from(BUF_BLOCKS * BLOCK_WORDS); rng.set_word_pos(last_block); assert_eq!(rng.get_word_pos(), last_block); } @@ -1057,6 +1124,8 @@ mod tests { #[test] fn test_chacha_word_pos_zero() { let mut rng = ChaChaRng::from_seed(Default::default()); + assert_eq!(rng.core.block.get_block_pos(), 0); + assert_eq!(rng.index, 64); assert_eq!(rng.get_word_pos(), 0); rng.set_word_pos(0); assert_eq!(rng.get_word_pos(), 0);