diff --git a/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs b/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs index 3862aa7463..0d765305ef 100644 --- a/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs +++ b/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs @@ -2,7 +2,7 @@ use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulus; use crate::core_crypto::commons::math::decomposition::{ SignedDecompositionIter, SignedDecompositionIterNonNative, }; -use crate::core_crypto::commons::numeric::{Numeric, UnsignedInteger}; +use crate::core_crypto::commons::numeric::UnsignedInteger; use crate::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount}; use crate::core_crypto::prelude::misc::divide_round_to_u128_custom_mod; use std::marker::PhantomData; @@ -104,17 +104,20 @@ where // We compute the number of least significant bits which can not be represented by the // decomposition - let non_rep_bit_count: usize = ::BITS - self.level_count * self.base_log; - // We generate a mask which captures the non representable bits - let non_rep_mask = Scalar::ONE << (non_rep_bit_count - 1); - // We retrieve the non representable bits - let non_rep_bits = input & non_rep_mask; - // We extract the msb of the non representable bits to perform the rounding - let non_rep_msb = non_rep_bits >> (non_rep_bit_count - 1); - // We remove the non-representable bits and perform the rounding - let res = input >> non_rep_bit_count; - let res = res + non_rep_msb; - res << non_rep_bit_count + // Example with level_count = 3, base_log = 4 and BITS == 64 -> 52 + let non_rep_bit_count: usize = Scalar::BITS - self.level_count * self.base_log; + let shift = non_rep_bit_count - 1; + // Move the representable bits + 1 to the LSB, with our example : + // |-----| 64 - (64 - 12 - 1) == 13 bits + // 0....0XX...XX + let mut res = input >> shift; + // Add one to do the rounding by adding the half interval + res += Scalar::ONE; + // Discard the LSB which was the one deciding in which direction we round + // -2 == 111...1110, i.e. all bits are 1 except the LSB which is 0 allowing to zero it + res &= Scalar::TWO.wrapping_neg(); + // Shift back to the right position + res << shift } /// Generate an iterator over the terms of the decomposition of the input. diff --git a/tfhe/src/core_crypto/fft_impl/common.rs b/tfhe/src/core_crypto/fft_impl/common.rs index b1fee4b343..2297f9cc38 100644 --- a/tfhe/src/core_crypto/fft_impl/common.rs +++ b/tfhe/src/core_crypto/fft_impl/common.rs @@ -25,7 +25,7 @@ pub fn pbs_modulus_switch>( // Start doing the right shift output >>= Scalar::BITS - poly_size.log2().0 - 2 + lut_count_log.0; // Do the rounding - output += output & Scalar::ONE; + output += Scalar::ONE; // Finish the right shift output >>= 1; // Apply the lsb padding