Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(core): remove a copy in the external product #600

Merged
merged 1 commit into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::super::math::decomposition::TensorSignedDecompositionLendingIter;
use super::super::math::fft::{FftView, FourierPolynomialList};
use super::super::math::polynomial::{FourierPolynomialMutView, FourierPolynomialView};
use super::super::math::polynomial::FourierPolynomialMutView;
use crate::core_crypto::commons::math::decomposition::{DecompositionLevel, SignedDecomposer};
use crate::core_crypto::commons::math::torus::UnsignedTorus;
use crate::core_crypto::commons::parameters::{
Expand Down Expand Up @@ -588,10 +588,12 @@ pub fn add_external_product_assign<Scalar, InputGlweCont>(
out.as_mut_polynomial_list().iter_mut(),
output_fft_buffer
.into_chunks(fourier_poly_size)
.map(|slice| FourierPolynomialView { data: slice }),
.map(|slice| FourierPolynomialMutView { data: slice }),
)
.for_each(|(out, fourier)| {
fft.add_backward_as_torus(out, fourier, substack0.rb_mut());
// The fourier buffer is not re-used afterwards so we can use the in-place version of
// the add_backward_as_torus function
fft.add_backward_in_place_as_torus(out, fourier, substack0.rb_mut());
});
}
}
Expand Down
31 changes: 31 additions & 0 deletions tfhe/src/core_crypto/fft_impl/fft64/math/fft/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,17 @@ impl<'a> FftView<'a> {
self.backward_with_conv(standard, fourier, convert_add_backward_torus, stack)
}

/// Variant of [`Self::add_backward_as_torus`] that performs the backward fourier transform in
/// place, directly inside the buffer of the [`FourierPolynomialMutView`].
///
/// The converted result is added to `standard`. The content of `fourier` is overwritten by this
/// call, so it must not be relied upon afterwards.
pub fn add_backward_in_place_as_torus<Scalar: UnsignedTorus>(
self,
standard: PolynomialMutView<'_, Scalar>,
fourier: FourierPolynomialMutView<'_>,
stack: PodStack<'_>,
) {
self.backward_with_conv_in_place(standard, fourier, convert_add_backward_torus, stack)
}

fn forward_with_conv<
'out,
Scalar: UnsignedTorus,
Expand Down Expand Up @@ -524,6 +535,26 @@ impl<'a> FftView<'a> {
let (standard_re, standard_im) = standard.split_at_mut(n / 2);
conv_fn(standard_re, standard_im, &tmp, self.twisties);
}

/// Runs the plan's `inv` transform directly on the data of `fourier`, overwriting it, then hands
/// the transformed data to `conv_fn` together with the two halves of `standard` and the
/// twisties.
///
/// Debug builds check that `standard` holds exactly twice as many elements as `fourier`.
fn backward_with_conv_in_place<
    Scalar: UnsignedTorus,
    F: Fn(&mut [Scalar], &mut [Scalar], &[c64], TwistiesView<'_>),
>(
    self,
    mut standard: PolynomialMutView<'_, Scalar>,
    fourier: FourierPolynomialMutView<'_>,
    conv_fn: F,
    stack: PodStack<'_>,
) {
    let coefficients = standard.as_mut();
    let spectrum = fourier.data;
    let coefficient_count = coefficients.len();
    debug_assert_eq!(coefficient_count, 2 * spectrum.len());

    // The transform is applied to the caller's fourier buffer itself, so its previous content
    // is lost once this returns.
    self.plan.inv(spectrum, stack);

    let half = coefficient_count / 2;
    let (re_half, im_half) = coefficients.split_at_mut(half);
    conv_fn(re_half, im_half, spectrum, self.twisties);
}
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
Expand Down
93 changes: 88 additions & 5 deletions tfhe/src/core_crypto/fft_impl/fft64/math/fft/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ fn abs_diff<Scalar: UnsignedTorus>(a: Scalar, b: Scalar) -> Scalar {

fn test_roundtrip<Scalar: UnsignedTorus>() {
let mut generator = new_random_generator();
for size_log in 2..=14 {
// SIMD versions need size >= 32 in case of AVX512
for size_log in 5..=14 {
let size = 1_usize << size_log;

let fft = Fft::new(PolynomialSize(size));
Expand All @@ -40,6 +41,7 @@ fn test_roundtrip<Scalar: UnsignedTorus>() {
);
let mut stack = PodStack::new(&mut mem);

// Simple roundtrip
fft.forward_as_torus(fourier.as_mut_view(), poly.as_view(), stack.rb_mut());
fft.backward_as_torus(roundtrip.as_mut_view(), fourier.as_view(), stack.rb_mut());

Expand All @@ -50,6 +52,38 @@ fn test_roundtrip<Scalar: UnsignedTorus>() {
assert!(abs_diff(*expected, *actual) < (Scalar::ONE << (64 - 50)));
}
}

// Simple add roundtrip
// Need to zero out the buffer to have a correct result as we will be adding the result
roundtrip.as_mut().fill(Scalar::ZERO);
fft.forward_as_torus(fourier.as_mut_view(), poly.as_view(), stack.rb_mut());
fft.add_backward_as_torus(roundtrip.as_mut_view(), fourier.as_view(), stack.rb_mut());

for (expected, actual) in izip!(poly.as_ref().iter(), roundtrip.as_ref().iter()) {
if Scalar::BITS == 32 {
assert!(abs_diff(*expected, *actual) == Scalar::ZERO);
} else {
assert!(abs_diff(*expected, *actual) < (Scalar::ONE << (64 - 50)));
}
}

// Forward, then add backward in place
// Need to zero out the buffer to have a correct result as we will be adding the result
roundtrip.as_mut().fill(Scalar::ZERO);
fft.forward_as_torus(fourier.as_mut_view(), poly.as_view(), stack.rb_mut());
fft.add_backward_in_place_as_torus(
roundtrip.as_mut_view(),
fourier.as_mut_view(),
stack.rb_mut(),
);

for (expected, actual) in izip!(poly.as_ref().iter(), roundtrip.as_ref().iter()) {
if Scalar::BITS == 32 {
assert!(abs_diff(*expected, *actual) == Scalar::ZERO);
} else {
assert!(abs_diff(*expected, *actual) < (Scalar::ONE << (64 - 50)));
}
}
}
}

Expand All @@ -74,6 +108,7 @@ fn test_product<Scalar: UnsignedTorus>() {
}

let mut generator = new_random_generator();
// SIMD versions need size >= 32 in case of AVX512
for size_log in 5..=14 {
for _ in 0..100 {
let size = 1_usize << size_log;
Expand Down Expand Up @@ -119,15 +154,63 @@ fn test_product<Scalar: UnsignedTorus>() {
*f0 *= *f1;
}

convolution_naive(
convolution_from_naive.as_mut(),
poly0.as_ref(),
poly1.as_ref(),
);

// Simple backward
fft.backward_as_torus(
convolution_from_fft.as_mut_view(),
fourier0.as_view(),
stack.rb_mut(),
);
convolution_naive(
convolution_from_naive.as_mut(),
poly0.as_ref(),
poly1.as_ref(),

for (expected, actual) in izip!(
convolution_from_naive.as_ref().iter(),
convolution_from_fft.as_ref().iter()
) {
let threshold =
Scalar::ONE << (Scalar::BITS.saturating_sub(52 - integer_magnitude - size_log));
let abs_diff = abs_diff(*expected, *actual);
assert!(
abs_diff <= threshold,
"abs_diff: {abs_diff}, threshold: {threshold}",
);
}

// Simple add backward
// Need to zero out the buffer to have a correct result as we will be adding the result
convolution_from_fft.as_mut().fill(Scalar::ZERO);
fft.add_backward_as_torus(
convolution_from_fft.as_mut_view(),
fourier0.as_view(),
stack.rb_mut(),
);

for (expected, actual) in izip!(
convolution_from_naive.as_ref().iter(),
convolution_from_fft.as_ref().iter()
) {
let threshold =
Scalar::ONE << (Scalar::BITS.saturating_sub(52 - integer_magnitude - size_log));
let abs_diff = abs_diff(*expected, *actual);
assert!(
abs_diff <= threshold,
"abs_diff: {abs_diff}, threshold: {threshold}",
);
}

// In-place backward, then add to the output buffer.
// The output buffer must be zeroed first, because the result is added to its content.
// At this point fourier0 still contains the proper fourier transform; this call will
// overwrite it.
convolution_from_fft.as_mut().fill(Scalar::ZERO);
fft.add_backward_in_place_as_torus(
convolution_from_fft.as_mut_view(),
fourier0.as_mut_view(),
stack.rb_mut(),
);

for (expected, actual) in izip!(
Expand Down
Loading