Skip to content

Commit

Permalink
refactor(core): simplify closest_representable and pbs_modulus_switch
Browse files Browse the repository at this point in the history
- both code were selecting the bit below the last representable bit,
extracted it and then added it to the bit above, the same effect can be
achieved by adding a 1 at the bit below the last representable bit
- update closest_representable to use an approach more like
pbs_modulus_switch yielding assembly with 42% less instructions (12 -> 7)
  • Loading branch information
IceTDrinker committed Sep 20, 2023
1 parent 9297a88 commit 0f88726
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
27 changes: 15 additions & 12 deletions tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulus;
use crate::core_crypto::commons::math::decomposition::{
SignedDecompositionIter, SignedDecompositionIterNonNative,
};
use crate::core_crypto::commons::numeric::{Numeric, UnsignedInteger};
use crate::core_crypto::commons::numeric::UnsignedInteger;
use crate::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount};
use crate::core_crypto::prelude::misc::divide_round_to_u128_custom_mod;
use std::marker::PhantomData;
Expand Down Expand Up @@ -104,17 +104,20 @@ where

// We compute the number of least significant bits which can not be represented by the
// decomposition
let non_rep_bit_count: usize = <Scalar as Numeric>::BITS - self.level_count * self.base_log;
// We generate a mask which captures the non representable bits
let non_rep_mask = Scalar::ONE << (non_rep_bit_count - 1);
// We retrieve the non representable bits
let non_rep_bits = input & non_rep_mask;
// We extract the msb of the non representable bits to perform the rounding
let non_rep_msb = non_rep_bits >> (non_rep_bit_count - 1);
// We remove the non-representable bits and perform the rounding
let res = input >> non_rep_bit_count;
let res = res + non_rep_msb;
res << non_rep_bit_count
// Example with level_count = 3, base_log = 4 and BITS == 64 -> 52
let non_rep_bit_count: usize = Scalar::BITS - self.level_count * self.base_log;
let shift = non_rep_bit_count - 1;
// Move the representable bits + 1 to the LSB, with our example :
// |-----| 64 - (64 - 12 - 1) == 13 bits
// 0....0XX...XX
let mut res = input >> shift;
// Add one to do the rounding by adding the half interval
res += Scalar::ONE;
// Discard the LSB which was the one deciding in which direction we round
// -2 == 111...1110, i.e. all bits are 1 except the LSB which is 0 allowing to zero it
res &= Scalar::TWO.wrapping_neg();
// Shift back to the right position
res << shift
}

/// Generate an iterator over the terms of the decomposition of the input.
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/core_crypto/fft_impl/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub fn pbs_modulus_switch<Scalar: UnsignedTorus + CastInto<usize>>(
// Start doing the right shift
output >>= Scalar::BITS - poly_size.log2().0 - 2 + lut_count_log.0;
// Do the rounding
output += output & Scalar::ONE;
output += Scalar::ONE;
// Finish the right shift
output >>= 1;
// Apply the lsb padding
Expand Down

0 comments on commit 0f88726

Please sign in to comment.