Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(core): simplify closest_representable and pbs_modulus_switch #576

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 101 additions & 13 deletions tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ pub struct SignedDecomposer<Scalar>
where
Scalar: UnsignedInteger,
{
// non_rep_bit_count: usize,
rounding_half_interval: Scalar,
representable_bits_mask: Scalar,
pub(crate) base_log: usize,
pub(crate) level_count: usize,
integer_type: PhantomData<Scalar>,
Expand Down Expand Up @@ -44,7 +47,14 @@ where
Scalar::BITS > base_log.0 * level_count.0,
"Decomposed bits exceeds the size of the integer to be decomposed"
);

let rounding_half_interval = Scalar::ONE << (Scalar::BITS - base_log.0 * level_count.0 - 1);
let representable_bits_mask = !((rounding_half_interval << 1).wrapping_sub(Scalar::ONE));

SignedDecomposer {
// non_rep_bit_count: Scalar::BITS - base_log.0 * level_count.0,
rounding_half_interval,
representable_bits_mask,
base_log: base_log.0,
level_count: level_count.0,
integer_type: PhantomData,
Expand Down Expand Up @@ -102,19 +112,78 @@ where
// The closest number representable by the decomposition can be computed by performing
// the rounding at the appropriate bit.

// We compute the number of least significant bits which can not be represented by the
// decomposition
let non_rep_bit_count: usize = <Scalar as Numeric>::BITS - self.level_count * self.base_log;
// We generate a mask which captures the non representable bits
let non_rep_mask = Scalar::ONE << (non_rep_bit_count - 1);
// We retrieve the non representable bits
let non_rep_bits = input & non_rep_mask;
// We extract the msb of the non representable bits to perform the rounding
let non_rep_msb = non_rep_bits >> (non_rep_bit_count - 1);
// We remove the non-representable bits and perform the rounding
let res = input >> non_rep_bit_count;
let res = res + non_rep_msb;
res << non_rep_bit_count
// // We compute the number of least significant bits which can not be represented by the
// // decomposition
// let non_rep_bit_count: usize = <Scalar as Numeric>::BITS - self.level_count *
// self.base_log; // We generate a mask which captures the non representable bits
// let non_rep_mask = Scalar::ONE << (non_rep_bit_count - 1);
// // We retrieve the non representable bits
// let non_rep_bits = input & non_rep_mask;
// // We extract the msb of the non representable bits to perform the rounding
// let non_rep_msb = non_rep_bits >> (non_rep_bit_count - 1);
// // We remove the non-representable bits and perform the rounding
// let res = input >> non_rep_bit_count;
// let res = res + non_rep_msb;
// res << non_rep_bit_count

// // simplified
// // We compute the number of least significant bits which can not be represented by the
// // decomposition
// // Example with level_count = 3, base_log = 4 and BITS == 64
// let non_rep_bit_count: usize = <Scalar as Numeric>::BITS - self.level_count *
// self.base_log; // |-----| 64 - (64 - 12 - 1) == 13 bits
// // 0....0XX...XX
// let mut res = input >> (non_rep_bit_count - 1);
// // Add one to do the rounding by adding the half interval
// res += Scalar::ONE;
// // Discard the LSB which was the one deciding in which direction we round
// res >>= 1;
// // Shift back to the right position with the rounding
// res << non_rep_bit_count

// // cached
// // We compute the number of least significant bits which can not be represented by the
// // decomposition
// // Example with level_count = 3, base_log = 4 and BITS == 64
// // |-----| 64 - (64 - 12 - 1) == 13 bits
// // 0....0XX...XX
// let mut res = input >> (self.non_rep_bit_count - 1);
// // Add one to do the rounding by adding the half interval
// res += Scalar::ONE;
// // Discard the LSB which was the one deciding in which direction we round
// res >>= 1;
// // Shift back to the right position with the rounding
// res << self.non_rep_bit_count

// // alternative
// let representable_bits = self.base_log * self.level_count;
// // Half interval
// //
// // 0...0010...0
// // ^ ^
// // |____|__ 12 representable bits
// let rounding_half_interval = Scalar::ONE << (Scalar::BITS - representable_bits - 1);
// // Representable mask
// // 0...0100...0
// // ^ ^
// // |___|__ 11 bits
// //
// // minus 1
// //
// // 0...0011...1
// // ^ ^
// // |____|__ 12 bits
// //
// // Not
// //
// // 1...1100...0
// // ^ ^
// // |____|__ 12 representable bits
// let representable_bits_mask = !((rounding_half_interval << 1).wrapping_sub(Scalar::ONE));
// input.wrapping_add(rounding_half_interval) & representable_bits_mask

// alternative cached
input.wrapping_add(self.rounding_half_interval) & self.representable_bits_mask
}

/// Generate an iterator over the terms of the decomposition of the input.
Expand Down Expand Up @@ -390,3 +459,22 @@ where
)
}
}

#[test]
pub fn test_closest_rep() {
use rand::Rng;
let mut rng = rand::thread_rng();
let values: Vec<u64> = Vec::from_iter((0..1_000_000_000).map(|_| rng.gen()));
let mut rounded = vec![0u64; values.len()];

let decomp = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(3));

let start = std::time::Instant::now();
values
.into_iter()
.zip(rounded.iter_mut())
.for_each(|(input, output)| *output = decomp.closest_representable(input));
let elapsed = start.elapsed().as_secs_f64();

panic!("{elapsed} s");
}
2 changes: 1 addition & 1 deletion tfhe/src/core_crypto/fft_impl/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub fn pbs_modulus_switch<Scalar: UnsignedTorus + CastInto<usize>>(
// Start doing the right shift
output >>= Scalar::BITS - poly_size.log2().0 - 2 + lut_count_log.0;
// Do the rounding
output += output & Scalar::ONE;
output += Scalar::ONE;
// Finish the right shift
output >>= 1;
// Apply the lsb padding
Expand Down
Loading