Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML-DSA] AVX2 performance improvements in NTT #584

Merged
merged 15 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions libcrux-ml-dsa/src/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::{
constants::COEFFICIENTS_IN_RING_ELEMENT,
polynomial::{PolynomialRingElement, SIMD_UNITS_IN_RING_ELEMENT},
constants::COEFFICIENTS_IN_RING_ELEMENT, polynomial::PolynomialRingElement,
simd::traits::Operations,
};

Expand Down Expand Up @@ -72,7 +71,7 @@ pub(crate) fn decompose_vector<SIMDUnit: Operations, const DIMENSION: usize, con
let mut vector_high = [PolynomialRingElement::<SIMDUnit>::ZERO(); DIMENSION];

for i in 0..DIMENSION {
for j in 0..SIMD_UNITS_IN_RING_ELEMENT {
for j in 0..vector_low[0].simd_units.len() {
let (low, high) = SIMDUnit::decompose::<GAMMA2>(t[i].simd_units[j]);

vector_low[i].simd_units[j] = low;
Expand Down Expand Up @@ -118,7 +117,7 @@ pub(crate) fn use_hint<SIMDUnit: Operations, const DIMENSION: usize, const GAMMA
for i in 0..DIMENSION {
let hint_simd = PolynomialRingElement::<SIMDUnit>::from_i32_array(&hint[i]);

for j in 0..SIMD_UNITS_IN_RING_ELEMENT {
for j in 0..result[0].simd_units.len() {
result[i].simd_units[j] =
SIMDUnit::use_hint::<GAMMA2>(re_vector[i].simd_units[j], hint_simd.simd_units[j]);
}
Expand Down
93 changes: 4 additions & 89 deletions libcrux-ml-dsa/src/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,98 +34,13 @@ const ZETAS_TIMES_MONTGOMERY_R: [FieldElementTimesMontgomeryR; 256] = [
-1362209, 3937738, 1400424, -846154, 1976782,
];

/// Runs the layer-0 NTT step over every SIMD unit of `re`.
///
/// Each unit is handed the four entries of `ZETAS_TIMES_MONTGOMERY_R` that
/// immediately follow the index stored in `zeta_i`. On return `zeta_i` has
/// advanced by four per processed unit, i.e. it names the last entry used.
#[inline(always)]
fn ntt_at_layer_0<SIMDUnit: Operations>(
    zeta_i: &mut usize,
    re: &mut PolynomialRingElement<SIMDUnit>,
) {
    for unit in re.simd_units.iter_mut() {
        // First of the four table entries consumed by this unit.
        let first = *zeta_i + 1;

        *unit = SIMDUnit::ntt_at_layer_0(
            *unit,
            ZETAS_TIMES_MONTGOMERY_R[first],
            ZETAS_TIMES_MONTGOMERY_R[first + 1],
            ZETAS_TIMES_MONTGOMERY_R[first + 2],
            ZETAS_TIMES_MONTGOMERY_R[first + 3],
        );

        *zeta_i += 4;
    }
}
/// Runs the layer-1 NTT step over every SIMD unit of `re`.
///
/// Each unit is handed the two entries of `ZETAS_TIMES_MONTGOMERY_R` that
/// immediately follow the index stored in `zeta_i`. On return `zeta_i` has
/// advanced by two per processed unit, i.e. it names the last entry used.
#[inline(always)]
fn ntt_at_layer_1<SIMDUnit: Operations>(
    zeta_i: &mut usize,
    re: &mut PolynomialRingElement<SIMDUnit>,
) {
    for unit in re.simd_units.iter_mut() {
        let first = *zeta_i + 1;
        let second = first + 1;

        *unit = SIMDUnit::ntt_at_layer_1(
            *unit,
            ZETAS_TIMES_MONTGOMERY_R[first],
            ZETAS_TIMES_MONTGOMERY_R[second],
        );

        // The second entry is the last one consumed so far.
        *zeta_i = second;
    }
}
/// Runs the layer-2 NTT step over every SIMD unit of `re`.
///
/// `zeta_i` is incremented before each unit is processed, and the entry of
/// `ZETAS_TIMES_MONTGOMERY_R` it then names is passed to that unit.
#[inline(always)]
fn ntt_at_layer_2<SIMDUnit: Operations>(
    zeta_i: &mut usize,
    re: &mut PolynomialRingElement<SIMDUnit>,
) {
    for unit in re.simd_units.iter_mut() {
        *zeta_i += 1;
        let zeta = ZETAS_TIMES_MONTGOMERY_R[*zeta_i];

        *unit = SIMDUnit::ntt_at_layer_2(*unit, zeta);
    }
}
/// Runs one NTT layer whose butterfly partners live in different SIMD units.
///
/// The layer is split into `128 >> LAYER` groups. For each group `zeta_i` is
/// incremented and the table entry it then names is used for every butterfly
/// of that group: the upper unit is multiplied by the zeta, and the
/// (lower, upper) pair is replaced by (lower + t, lower - t).
#[inline(always)]
fn ntt_at_layer_3_plus<SIMDUnit: Operations, const LAYER: usize>(
    zeta_i: &mut usize,
    re: &mut PolynomialRingElement<SIMDUnit>,
) {
    // Distance between butterfly partners, in coefficients and in SIMD units.
    let coefficient_distance: usize = 1 << LAYER;
    let unit_distance = coefficient_distance / COEFFICIENTS_IN_SIMD_UNIT;
    let group_count: usize = 128 >> LAYER;

    for group in 0..group_count {
        *zeta_i += 1;

        // Index of the first SIMD unit belonging to this group.
        let first = (group * coefficient_distance * 2) / COEFFICIENTS_IN_SIMD_UNIT;

        for low in first..first + unit_distance {
            let high = low + unit_distance;
            let lower = re.simd_units[low];

            let t = montgomery_multiply_by_fer::<SIMDUnit>(
                re.simd_units[high],
                ZETAS_TIMES_MONTGOMERY_R[*zeta_i],
            );

            re.simd_units[high] = SIMDUnit::subtract(&lower, &t);
            re.simd_units[low] = SIMDUnit::add(&lower, &t);
        }
    }
}

// NOTE(review): this span is a rendered diff hunk whose +/- markers were lost,
// so removed and added lines of `ntt` appear interleaved. Per the hunk summary
// ("4 additions & 89 deletions") the added lines appear to be the plain `re`
// parameter and the trailing `PolynomialRingElement { .. }` expression — confirm
// against the actual commit before relying on this.
/// Transforms the ring element `re` and returns the transformed element.
#[inline(always)]
pub(crate) fn ntt<SIMDUnit: Operations>(
mut re: PolynomialRingElement<SIMDUnit>,
re: PolynomialRingElement<SIMDUnit>,
) -> PolynomialRingElement<SIMDUnit> {
let mut zeta_i = 0;

// Layer-by-layer variant: layers 7 down to 3 go through the generic
// cross-unit helper, then layers 2, 1 and 0 use their dedicated helpers.
// `zeta_i` is threaded through all of them as the running table index.
ntt_at_layer_3_plus::<SIMDUnit, 7>(&mut zeta_i, &mut re);
ntt_at_layer_3_plus::<SIMDUnit, 6>(&mut zeta_i, &mut re);
ntt_at_layer_3_plus::<SIMDUnit, 5>(&mut zeta_i, &mut re);
ntt_at_layer_3_plus::<SIMDUnit, 4>(&mut zeta_i, &mut re);
ntt_at_layer_3_plus::<SIMDUnit, 3>(&mut zeta_i, &mut re);
ntt_at_layer_2::<SIMDUnit>(&mut zeta_i, &mut re);
ntt_at_layer_1::<SIMDUnit>(&mut zeta_i, &mut re);
ntt_at_layer_0::<SIMDUnit>(&mut zeta_i, &mut re);

re
// Delegating variant: the whole transform is handed to the SIMD backend in
// one call over the full array of units.
PolynomialRingElement {
simd_units: SIMDUnit::ntt(re.simd_units),
}
}

#[inline(always)]
Expand Down
5 changes: 1 addition & 4 deletions libcrux-ml-dsa/src/polynomial.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
use crate::simd::traits::{Operations, COEFFICIENTS_IN_SIMD_UNIT};

pub(crate) const SIMD_UNITS_IN_RING_ELEMENT: usize =
crate::constants::COEFFICIENTS_IN_RING_ELEMENT / COEFFICIENTS_IN_SIMD_UNIT;
use crate::simd::traits::{Operations, COEFFICIENTS_IN_SIMD_UNIT, SIMD_UNITS_IN_RING_ELEMENT};

#[derive(Clone, Copy)]
pub(crate) struct PolynomialRingElement<SIMDUnit: Operations> {
Expand Down
14 changes: 5 additions & 9 deletions libcrux-ml-dsa/src/simd/avx2.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::simd::traits::Operations;
use crate::simd::traits::{Operations, SIMD_UNITS_IN_RING_ELEMENT};
use libcrux_intrinsics;

mod arithmetic;
Expand Down Expand Up @@ -118,14 +118,10 @@ impl Operations for AVX2SIMDUnit {
encoding::t1::deserialize(serialized).into()
}

fn ntt_at_layer_0(simd_unit: Self, zeta0: i32, zeta1: i32, zeta2: i32, zeta3: i32) -> Self {
ntt::ntt_at_layer_0(simd_unit.coefficients, zeta0, zeta1, zeta2, zeta3).into()
}
fn ntt_at_layer_1(simd_unit: Self, zeta0: i32, zeta1: i32) -> Self {
ntt::ntt_at_layer_1(simd_unit.coefficients, zeta0, zeta1).into()
}
fn ntt_at_layer_2(simd_unit: Self, zeta: i32) -> Self {
ntt::ntt_at_layer_2(simd_unit.coefficients, zeta).into()
fn ntt(simd_units: [Self; SIMD_UNITS_IN_RING_ELEMENT]) -> [Self; SIMD_UNITS_IN_RING_ELEMENT] {
let result = ntt::ntt(simd_units.map(|x| x.coefficients));

result.map(|x| x.into())
}

fn invert_ntt_at_layer_0(
Expand Down
43 changes: 30 additions & 13 deletions libcrux-ml-dsa/src/simd/avx2/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,29 @@ fn simd_multiply_i32_and_return_high(lhs: Vec256, rhs: Vec256) -> Vec256 {
}

// NOTE(review): this span is a rendered diff hunk whose +/- markers were lost;
// two versions of `montgomery_multiply_by_constant` are interleaved below (one
// taking `simd_unit`, one taking `lhs`). Confirm against the actual commit
// which lines belong to which version.
/// Multiplies every lane of the vector by the broadcast `constant` and applies
/// a Montgomery-style reduction using `FIELD_MODULUS` and
/// `INVERSE_OF_MODULUS_MOD_MONTGOMERY_R`.
#[inline(always)]
pub fn montgomery_multiply_by_constant(simd_unit: Vec256, constant: i32) -> Vec256 {
let constant = mm256_set1_epi32(constant);
pub fn montgomery_multiply_by_constant(lhs: Vec256, constant: i32) -> Vec256 {
let rhs = mm256_set1_epi32(constant);
let field_modulus = mm256_set1_epi32(FIELD_MODULUS);
let inverse_of_modulus_mod_montgomery_r =
mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32);

// `simd_unit` version: low halves of the products via `mullo`.
let product_low = mm256_mullo_epi32(simd_unit, constant);
// `lhs` version: two `mm256_mul_epi32` calls, one on the inputs as given and
// one on copies shuffled with 0b11_11_01_01. Presumably these cover the
// even- and odd-indexed 32-bit lanes respectively (hence `02` / `13`) —
// verify against the intrinsic documentation.
let prod02 = mm256_mul_epi32(lhs, rhs);
let prod13 = mm256_mul_epi32(
mm256_shuffle_epi32::<0b11_11_01_01>(lhs),
mm256_shuffle_epi32::<0b11_11_01_01>(rhs),
);

let k = mm256_mullo_epi32(product_low, inverse_of_modulus_mod_montgomery_r);
let k02 = mm256_mul_epi32(prod02, inverse_of_modulus_mod_montgomery_r);
let k13 = mm256_mul_epi32(prod13, inverse_of_modulus_mod_montgomery_r);

let c = simd_multiply_i32_and_return_high(k, field_modulus);
let product_high = simd_multiply_i32_and_return_high(simd_unit, constant);
let c02 = mm256_mul_epi32(k02, field_modulus);
let c13 = mm256_mul_epi32(k13, field_modulus);

// `simd_unit` version returns the difference of the two high halves.
mm256_sub_epi32(product_high, c)
// `lhs` version: subtract per half, shuffle the `02` result with the same
// 0b11_11_01_01 pattern, then blend it with the `13` result using mask
// 0b10101010 to assemble the returned vector.
let res02 = mm256_sub_epi32(prod02, c02);
let res13 = mm256_sub_epi32(prod13, c13);
let res02_shifted = mm256_shuffle_epi32::<0b11_11_01_01>(res02);
let res = mm256_blend_epi32::<0b10101010>(res02_shifted, res13);
res
}

#[inline(always)]
Expand All @@ -60,14 +69,22 @@ pub fn montgomery_multiply(lhs: Vec256, rhs: Vec256) -> Vec256 {
let inverse_of_modulus_mod_montgomery_r =
mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32);

let product_low = mm256_mullo_epi32(lhs, rhs);

let k = mm256_mullo_epi32(product_low, inverse_of_modulus_mod_montgomery_r);
let prod02 = mm256_mul_epi32(lhs, rhs);
let prod13 = mm256_mul_epi32(
mm256_shuffle_epi32::<0b11_11_01_01>(lhs),
mm256_shuffle_epi32::<0b11_11_01_01>(rhs),
);
let k02 = mm256_mul_epi32(prod02, inverse_of_modulus_mod_montgomery_r);
let k13 = mm256_mul_epi32(prod13, inverse_of_modulus_mod_montgomery_r);

let c = simd_multiply_i32_and_return_high(k, field_modulus);
let product_high = simd_multiply_i32_and_return_high(lhs, rhs);
let c02 = mm256_mul_epi32(k02, field_modulus);
let c13 = mm256_mul_epi32(k13, field_modulus);

mm256_sub_epi32(product_high, c)
let res02 = mm256_sub_epi32(prod02, c02);
let res13 = mm256_sub_epi32(prod13, c13);
let res02_shifted = mm256_shuffle_epi32::<0b11_11_01_01>(res02);
let res = mm256_blend_epi32::<0b10101010>(res02_shifted, res13);
res
}

#[inline(always)]
Expand Down
Loading
Loading