
Commit

Merge pull request #3 from nstilt1/revert-2-counter_support
Revert "Added 64-bit counter for legacy; tested on all backends"
nstilt1 authored Dec 14, 2023
2 parents c887b10 + b3031ea commit c41cafa
Showing 16 changed files with 195 additions and 392 deletions.
31 changes: 5 additions & 26 deletions Cargo.lock

Some generated files are not rendered by default.

5 changes: 1 addition & 4 deletions benches/.cargo/config.example
@@ -1,12 +1,9 @@
# This config file will choose which Rust Flag to use for the benches.
# Rename this file to "config" and uncomment out a rustflag to use it.
# The config file will be excluded with `*/.cargo/config` in the .gitignore
#
# For SSE2, the C target flag isn't required for the code to work, but it can
# prevent the other instruction sets from being used instead

[build]
#rustflags = "--cfg chacha20_force_avx2 -C target-feature=+avx2"
#rustflags = "--cfg chacha20_force_sse2 -C target-feature=+sse2,-sse4.1,-sse4.2"
#rustflags = "--cfg chacha20_force_sse2" # untested
#rustflags = "--cfg chacha20_force_neon"
rustflags = "--cfg chacha20_force_soft"
4 changes: 2 additions & 2 deletions benches/src/chacha20.rs
@@ -11,7 +11,7 @@ use chacha20::{
};

const KB: usize = 1024;
#[cfg(any(target_arch = "x86_64", target_arch = "x86", all(target_arch = "aarch64", target_os = "Linux")))]
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
fn bench(c: &mut Criterion<CyclesPerByte>) {
let mut group = c.benchmark_group("stream-cipher");

@@ -31,7 +31,7 @@ fn bench(c: &mut Criterion<CyclesPerByte>) {
group.finish();
}

#[cfg(not(any(target_arch = "x86_64", target_arch = "x86", all(target_arch = "aarch64", target_os = "Linux"))))]
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
fn bench(c: &mut Criterion) {
let mut group = c.benchmark_group("stream-cipher");

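For context, the two `bench` entry points above are selected at compile time: on x86/x86_64 the benches measure cycles per byte, elsewhere they fall back to wall-clock time. Below is a minimal sketch of how such a cfg-gated pair is typically registered with criterion and criterion-cycles-per-byte; it reuses the `bench` functions shown in the diff, and the group name is illustrative rather than taken from this repository.

use criterion::{criterion_group, criterion_main, Criterion};
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
use criterion_cycles_per_byte::CyclesPerByte;

// Cycle-accurate measurement is only wired up where it is supported.
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
criterion_group!(
    name = benches;
    // Swap the default wall-clock measurement for cycles per byte.
    config = Criterion::default().with_measurement(CyclesPerByte);
    targets = bench
);

// Everywhere else, register the plain wall-clock variant of `bench`.
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
criterion_group!(benches, bench);

criterion_main!(benches);
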
1 change: 0 additions & 1 deletion chacha20/Cargo.toml
@@ -34,7 +34,6 @@ rand_chacha = "0.3.1"
serde_json = "1.0" # Only to test serde1
sha2 = "0.10" # For testing fill_bytes()
sha3 = "0.10" # Also for testing fill_bytes(), but it may be unnecessary
chacha_0_7 = { package = "chacha20", version = "0.7.0", features = ["legacy"] } # Testing 64-bit counter

[features]
default = ["cipher"]
81 changes: 16 additions & 65 deletions chacha20/src/backends/avx2.rs
@@ -1,7 +1,5 @@
use crate::{Rounds, variants::Variant, STATE_WORDS};
use crate::{Rounds, ChaChaCore, variants::Variant, STATE_WORDS};
use core::marker::PhantomData;
#[cfg(feature = "rand_core")]
use crate::ChaChaCore;

#[cfg(target_arch = "x86")]
use core::arch::x86::*;
@@ -22,21 +20,19 @@ const PAR_BLOCKS: usize = 4;
/// Number of `__m256i` to store parallel blocks.
const N: usize = PAR_BLOCKS / 2;

struct Backend<R: Rounds, V: Variant> {
struct Backend<R: Rounds> {
v: [__m256i; 3],
ctr: [__m256i; N],
_pd: PhantomData<R>,
variant: PhantomData<V>
}

#[inline]
#[cfg(feature = "cipher")]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn inner<R, F, V>(state: &mut [u32; STATE_WORDS], f: F)
pub(crate) unsafe fn inner<R, F>(state: &mut [u32; STATE_WORDS], f: F)
where
R: Rounds,
F: StreamClosure<BlockSize = U64>,
V: Variant
{
let state_ptr = state.as_ptr() as *const __m128i;
let v = [
@@ -45,59 +41,41 @@ where
_mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(2))),
];
let mut c = _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(3)));
if V::IS_32_BIT_COUNTER {
c = _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 1, 0, 0, 0, 0));
}else{
c = _mm256_add_epi64(c, _mm256_set_epi64x(0, 1, 0, 0));
}
c = _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 1, 0, 0, 0, 0));
let mut ctr = [c; N];
for i in 0..N {
ctr[i] = c;
if V::IS_32_BIT_COUNTER {
c = _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 2, 0, 0, 0, 2));
}else{
c = _mm256_add_epi64(c, _mm256_set_epi64x(0, 2, 0, 2));
}
c = _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 2, 0, 0, 0, 2));
}
let mut backend = Backend::<R, V> {
let mut backend = Backend::<R> {
v,
ctr,
_pd: PhantomData,
variant: PhantomData
};

f.call(&mut backend);

// handle 32-bit counter
state[12] = _mm256_extract_epi32(backend.ctr[0], 0) as u32;
// handle 64-bit counter
if !V::IS_32_BIT_COUNTER {
state[13] = _mm256_extract_epi32(backend.ctr[0], 1) as u32;
}
}

#[cfg(feature = "cipher")]
impl<R: Rounds, V: Variant> BlockSizeUser for Backend<R, V> {
impl<R: Rounds> BlockSizeUser for Backend<R> {
type BlockSize = U64;
}

#[cfg(feature = "cipher")]
impl<R: Rounds, V: Variant> ParBlocksSizeUser for Backend<R, V> {
impl<R: Rounds> ParBlocksSizeUser for Backend<R> {
type ParBlocksSize = U4;
}

#[cfg(feature = "cipher")]
impl<R: Rounds, V: Variant> StreamBackend for Backend<R, V> {
impl<R: Rounds> StreamBackend for Backend<R> {
#[inline(always)]
fn gen_ks_block(&mut self, block: &mut Block) {
unsafe {
let res = rounds::<R>(&self.v, &self.ctr);
for c in self.ctr.iter_mut() {
if V::IS_32_BIT_COUNTER {
*c = _mm256_add_epi32(*c, _mm256_set_epi32(0, 0, 0, 1, 0, 0, 0, 1));
}else{
*c = _mm256_add_epi64(*c, _mm256_set_epi64x(0, 1, 0, 1));
}
*c = _mm256_add_epi32(*c, _mm256_set_epi32(0, 0, 0, 1, 0, 0, 0, 1));
}

let res0: [__m128i; 8] = core::mem::transmute(res[0]);
@@ -116,11 +94,7 @@ impl<R: Rounds, V: Variant> StreamBackend for Backend<R, V> {

let pb = PAR_BLOCKS as i32;
for c in self.ctr.iter_mut() {
if V::IS_32_BIT_COUNTER {
*c = _mm256_add_epi32(*c, _mm256_set_epi32(0, 0, 0, pb, 0, 0, 0, pb));
}else{
*c = _mm256_add_epi64(*c, _mm256_set_epi64x(0, pb as i64, 0, pb as i64));
}
*c = _mm256_add_epi32(*c, _mm256_set_epi32(0, 0, 0, pb, 0, 0, 0, pb));
}

let mut block_ptr = blocks.as_mut_ptr() as *mut __m128i;
@@ -154,44 +128,25 @@ where
_mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(2))),
];
let mut c = _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(3)));

if V::IS_32_BIT_COUNTER {
// handle 32-bit counter
c = _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 1, 0, 0, 0, 0));
}else{
// handle 64-bit counter
c = _mm256_add_epi64(c, _mm256_set_epi64x(0, 1, 0, 0));
}
c = _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 1, 0, 0, 0, 0));
let mut ctr = [c; N];
for i in 0..N {
ctr[i] = c;
if V::IS_32_BIT_COUNTER {
// handle 32-bit counter
c = _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 2, 0, 0, 0, 2));
}else{
// handle 64-bit counter
c = _mm256_add_epi64(c, _mm256_set_epi64x(0, 2, 0, 2));
}
c = _mm256_add_epi32(c, _mm256_set_epi32(0, 0, 0, 2, 0, 0, 0, 2));
}
let mut backend = Backend::<R, V> {
let mut backend = Backend::<R> {
v,
ctr,
_pd: PhantomData,
variant: PhantomData
};

backend.rng_gen_par_ks_blocks(dest);

// handle 32-bit counter
core.state[12] = _mm256_extract_epi32(backend.ctr[0], 0) as u32;
// handle 64-bit counter
if !V::IS_32_BIT_COUNTER {
core.state[13] = _mm256_extract_epi32(backend.ctr[0], 1) as u32;
}
}

#[cfg(feature = "rand_core")]
impl<R: Rounds, V: Variant> Backend<R, V> {
impl<R: Rounds> Backend<R> {
#[inline(always)]
/// This is essentially the same as gen_par_ks_blocks except that it
/// takes a pointer.
@@ -201,11 +156,7 @@ impl<R: Rounds, V: Variant> Backend<R, V> {

let pb = PAR_BLOCKS as i32;
for c in self.ctr.iter_mut() {
if V::IS_32_BIT_COUNTER {
*c = _mm256_add_epi32(*c, _mm256_set_epi32(0, 0, 0, pb, 0, 0, 0, pb));
}else{
*c = _mm256_add_epi64(*c, _mm256_set_epi64x(0, pb as i64, 0, pb as i64));
}
*c = _mm256_add_epi32(*c, _mm256_set_epi32(0, 0, 0, pb, 0, 0, 0, pb));
}

let mut block_ptr = dest as *mut __m128i;
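For context (not part of the diff itself): the `V::IS_32_BIT_COUNTER` branches removed above chose between a 32-bit and a 64-bit block counter when advancing the state with `_mm256_add_epi32` / `_mm256_add_epi64`, and between writing back only `state[12]` or both `state[12]` and `state[13]`. The following is a scalar sketch of that distinction, assuming the usual 16-word ChaCha state layout; the function name and signature are illustrative and not part of the crate.

/// Illustrative scalar equivalent of the reverted counter handling.
/// `state` is the 16-word ChaCha state; the block counter lives in word 12,
/// extending into word 13 for the legacy 64-bit variant.
fn advance_counter(state: &mut [u32; 16], is_32_bit_counter: bool, blocks: u32) {
    if is_32_bit_counter {
        // IETF variant: a single 32-bit counter in word 12.
        state[12] = state[12].wrapping_add(blocks);
    } else {
        // Legacy variant: 64-bit counter, low half in word 12, high half in word 13.
        let ctr = (state[12] as u64) | ((state[13] as u64) << 32);
        let ctr = ctr.wrapping_add(blocks as u64);
        state[12] = ctr as u32;
        state[13] = (ctr >> 32) as u32;
    }
}
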
