Skip to content

Commit

Permalink
refactor(core): remove a copy in the external product
Browse files Browse the repository at this point in the history
- add an fft backward primitive that can use the input fourier buffer as
output as well
- gains 0.6 ms on 2_2 m6i.metal
  • Loading branch information
IceTDrinker committed Oct 2, 2023
1 parent 8cc8dba commit 5edea3b
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 9 deletions.
8 changes: 5 additions & 3 deletions tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::super::math::decomposition::TensorSignedDecompositionLendingIter;
use super::super::math::fft::{FftView, FourierPolynomialList};
use super::super::math::polynomial::{FourierPolynomialMutView, FourierPolynomialView};
use super::super::math::polynomial::FourierPolynomialMutView;
use crate::core_crypto::commons::math::decomposition::{DecompositionLevel, SignedDecomposer};
use crate::core_crypto::commons::math::torus::UnsignedTorus;
use crate::core_crypto::commons::parameters::{
Expand Down Expand Up @@ -588,10 +588,12 @@ pub fn add_external_product_assign<Scalar, InputGlweCont>(
out.as_mut_polynomial_list().iter_mut(),
output_fft_buffer
.into_chunks(fourier_poly_size)
.map(|slice| FourierPolynomialView { data: slice }),
.map(|slice| FourierPolynomialMutView { data: slice }),
)
.for_each(|(out, fourier)| {
fft.add_backward_as_torus(out, fourier, substack0.rb_mut());
// The fourier buffer is not re-used afterwards so we can use the in-place version of
// the add_backward_as_torus function
fft.add_backward_in_place_as_torus(out, fourier, substack0.rb_mut());
});
}
}
Expand Down
31 changes: 31 additions & 0 deletions tfhe/src/core_crypto/fft_impl/fft64/math/fft/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,17 @@ impl<'a> FftView<'a> {
self.backward_with_conv(standard, fourier, convert_add_backward_torus, stack)
}

/// Variant of [`Self::add_backward_as_torus`] writing the output of the backward fourier
/// transform in the [`FourierPolynomialMutView`] in place.
///
/// The result of the backward transform is added to the content of `standard`; zero it out
/// beforehand if only the transform itself is wanted.
///
/// `fourier` is used as the working buffer of the backward transform and is overwritten by
/// this call: it no longer holds the original Fourier-domain polynomial afterwards. Only use
/// this variant when the content of `fourier` is not needed after the call.
///
/// `stack` is forwarded to the backward transform as scratch memory.
pub fn add_backward_in_place_as_torus<Scalar: UnsignedTorus>(
self,
standard: PolynomialMutView<'_, Scalar>,
fourier: FourierPolynomialMutView<'_>,
stack: PodStack<'_>,
) {
// Same conversion function as the non in-place add variant, only the backward primitive differs
self.backward_with_conv_in_place(standard, fourier, convert_add_backward_torus, stack)
}

fn forward_with_conv<
'out,
Scalar: UnsignedTorus,
Expand Down Expand Up @@ -524,6 +535,26 @@ impl<'a> FftView<'a> {
let (standard_re, standard_im) = standard.split_at_mut(n / 2);
conv_fn(standard_re, standard_im, &tmp, self.twisties);
}

/// Runs the inverse transform directly on the data of `fourier`, then hands the transformed
/// buffer to `conv_fn` so that it can update `standard`.
///
/// The content of `fourier` is overwritten by the inverse transform.
///
/// `conv_fn` receives, in order: the first half of `standard`, the second half of `standard`,
/// the transformed buffer and the twisties of this [`FftView`].
///
/// In debug builds, this checks that `standard` holds exactly twice as many elements as
/// `fourier`.
fn backward_with_conv_in_place<Scalar, F>(
    self,
    mut standard: PolynomialMutView<'_, Scalar>,
    fourier: FourierPolynomialMutView<'_>,
    conv_fn: F,
    stack: PodStack<'_>,
) where
    Scalar: UnsignedTorus,
    F: Fn(&mut [Scalar], &mut [Scalar], &[c64], TwistiesView<'_>),
{
    let coeffs = standard.as_mut();
    let buffer = fourier.data;

    let coeff_count = coeffs.len();
    debug_assert_eq!(coeff_count, 2 * buffer.len());

    // The inverse transform writes its output over its input
    self.plan.inv(buffer, stack);

    let (first_half, second_half) = coeffs.split_at_mut(coeff_count / 2);
    conv_fn(first_half, second_half, buffer, self.twisties);
}
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
Expand Down
95 changes: 89 additions & 6 deletions tfhe/src/core_crypto/fft_impl/fft64/math/fft/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ fn abs_diff<Scalar: UnsignedTorus>(a: Scalar, b: Scalar) -> Scalar {

fn test_roundtrip<Scalar: UnsignedTorus>() {
let mut generator = new_random_generator();
for size_log in 2..=14 {
// SIMD versions need size >= 8
for size_log in 3..=14 {
let size = 1_usize << size_log;

let fft = Fft::new(PolynomialSize(size));
Expand All @@ -40,6 +41,7 @@ fn test_roundtrip<Scalar: UnsignedTorus>() {
);
let mut stack = PodStack::new(&mut mem);

// Simple roundtrip
fft.forward_as_torus(fourier.as_mut_view(), poly.as_view(), stack.rb_mut());
fft.backward_as_torus(roundtrip.as_mut_view(), fourier.as_view(), stack.rb_mut());

Expand All @@ -50,6 +52,38 @@ fn test_roundtrip<Scalar: UnsignedTorus>() {
assert!(abs_diff(*expected, *actual) < (Scalar::ONE << (64 - 50)));
}
}

// Simple add roundtrip
// Need to zero out the buffer to have a correct result as we will be adding the result
roundtrip.as_mut().fill(Scalar::ZERO);
fft.forward_as_torus(fourier.as_mut_view(), poly.as_view(), stack.rb_mut());
fft.add_backward_as_torus(roundtrip.as_mut_view(), fourier.as_view(), stack.rb_mut());

for (expected, actual) in izip!(poly.as_ref().iter(), roundtrip.as_ref().iter()) {
if Scalar::BITS == 32 {
assert!(abs_diff(*expected, *actual) == Scalar::ZERO);
} else {
assert!(abs_diff(*expected, *actual) < (Scalar::ONE << (64 - 50)));
}
}

// Forward, then add backward in place
// Need to zero out the buffer to have a correct result as we will be adding the result
roundtrip.as_mut().fill(Scalar::ZERO);
fft.forward_as_torus(fourier.as_mut_view(), poly.as_view(), stack.rb_mut());
fft.add_backward_in_place_as_torus(
roundtrip.as_mut_view(),
fourier.as_mut_view(),
stack.rb_mut(),
);

for (expected, actual) in izip!(poly.as_ref().iter(), roundtrip.as_ref().iter()) {
if Scalar::BITS == 32 {
assert!(abs_diff(*expected, *actual) == Scalar::ZERO);
} else {
assert!(abs_diff(*expected, *actual) < (Scalar::ONE << (64 - 50)));
}
}
}
}

Expand All @@ -74,7 +108,8 @@ fn test_product<Scalar: UnsignedTorus>() {
}

let mut generator = new_random_generator();
for size_log in 5..=14 {
// SIMD versions need a minimal size
for size_log in 3..=14 {
for _ in 0..100 {
let size = 1_usize << size_log;

Expand Down Expand Up @@ -119,15 +154,63 @@ fn test_product<Scalar: UnsignedTorus>() {
*f0 *= *f1;
}

convolution_naive(
convolution_from_naive.as_mut(),
poly0.as_ref(),
poly1.as_ref(),
);

// Simple backward
fft.backward_as_torus(
convolution_from_fft.as_mut_view(),
fourier0.as_view(),
stack.rb_mut(),
);
convolution_naive(
convolution_from_naive.as_mut(),
poly0.as_ref(),
poly1.as_ref(),

for (expected, actual) in izip!(
convolution_from_naive.as_ref().iter(),
convolution_from_fft.as_ref().iter()
) {
let threshold =
Scalar::ONE << (Scalar::BITS.saturating_sub(52 - integer_magnitude - size_log));
let abs_diff = abs_diff(*expected, *actual);
assert!(
abs_diff <= threshold,
"abs_diff: {abs_diff}, threshold: {threshold}",
);
}

// Simple add backward
// Need to zero out the buffer to have a correct result as we will be adding the result
convolution_from_fft.as_mut().fill(Scalar::ZERO);
fft.add_backward_as_torus(
convolution_from_fft.as_mut_view(),
fourier0.as_view(),
stack.rb_mut(),
);

for (expected, actual) in izip!(
convolution_from_naive.as_ref().iter(),
convolution_from_fft.as_ref().iter()
) {
let threshold =
Scalar::ONE << (Scalar::BITS.saturating_sub(52 - integer_magnitude - size_log));
let abs_diff = abs_diff(*expected, *actual);
assert!(
abs_diff <= threshold,
"abs_diff: {abs_diff}, threshold: {threshold}",
);
}

// In place backward then add to output buffer
// Need to zero out the buffer to have a correct result as we will be adding the result
// Here fourier0 still contains the proper fourier transform, this call will overwrite
// it
convolution_from_fft.as_mut().fill(Scalar::ZERO);
fft.add_backward_in_place_as_torus(
convolution_from_fft.as_mut_view(),
fourier0.as_mut_view(),
stack.rb_mut(),
);

for (expected, actual) in izip!(
Expand Down

0 comments on commit 5edea3b

Please sign in to comment.