Skip to content

Commit

Permalink
fix(integer): do sum by safe chunk sizes
Browse files Browse the repository at this point in the history
Parameters are made with with assumptions on the number of leveled
add/sub/scalar_mul operations are made, so that the
noise level before doing a PBS has a correct level and everything is
safe, secure and correct.

So the lib implementation has to uphold these assumptions in order to
keep the error probability failure correct.

In the comparisons, at some point we had a vector of ciphertexts with a
degree == 1, so we greedily summed them (e.g with 2_2 params we summed
them by chunks of 15), while it is correct with regards to the carry and
message space it is however less correct with regards to the noise
level.

Noise wise, doing this huge sum is correct as long as the noise of each ciphertext
is independent from the others in the same chunk.

While it may generally be the case we are in, its not guaranteed, and
since we do not track that information we have to take the safer
approach of assuming the worst case: all noise are dependent.

So to fix the issue we compute the correct size of sum chunk by also
taking into account the max noise level.
  • Loading branch information
tmontaigu committed Sep 13, 2024
1 parent 0e64238 commit 72ad76b
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 70 deletions.
2 changes: 1 addition & 1 deletion backends/tfhe-cuda-backend/cuda/include/integer.h
Original file line number Diff line number Diff line change
Expand Up @@ -2095,7 +2095,7 @@ template <typename Torus> struct int_are_all_block_true_buffer {

if (allocate_gpu_memory) {
Torus total_modulus = params.message_modulus * params.carry_modulus;
uint32_t max_value = total_modulus - 1;
uint32_t max_value = (total_modulus - 1) / (params.message_modulus - 1);

int max_chunks = (num_radix_blocks + max_value - 1) / max_value;
tmp_block_accumulated = (Torus *)cuda_malloc_async(
Expand Down
4 changes: 2 additions & 2 deletions backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ __host__ void are_all_comparisons_block_true(
auto tmp_out = are_all_block_true_buffer->tmp_out;

uint32_t total_modulus = message_modulus * carry_modulus;
uint32_t max_value = total_modulus - 1;
uint32_t max_value = (total_modulus - 1) / (message_modulus - 1);

cuda_memcpy_async_gpu_to_gpu(tmp_out, lwe_array_in,
num_radix_blocks * (big_lwe_dimension + 1) *
Expand Down Expand Up @@ -173,7 +173,7 @@ __host__ void is_at_least_one_comparisons_block_true(
auto buffer = mem_ptr->eq_buffer->are_all_block_true_buffer;

uint32_t total_modulus = message_modulus * carry_modulus;
uint32_t max_value = total_modulus - 1;
uint32_t max_value = (total_modulus - 1) / (message_modulus - 1);

cuda_memcpy_async_gpu_to_gpu(mem_ptr->tmp_lwe_array_out, lwe_array_in,
num_radix_blocks * (big_lwe_dimension + 1) *
Expand Down
18 changes: 17 additions & 1 deletion tfhe/src/integer/server_key/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub(crate) mod radix;
pub(crate) mod radix_parallel;

use crate::integer::client_key::ClientKey;
use crate::shortint::ciphertext::MaxDegree;
use crate::shortint::ciphertext::{Degree, MaxDegree};
use serde::{Deserialize, Serialize};
use tfhe_versionable::Versionize;

Expand Down Expand Up @@ -227,6 +227,22 @@ impl ServerKey {

num_bits_to_represent_output_value.div_ceil(num_bits_in_message as usize)
}

/// Returns how many ciphertext can be summed at once
///
/// The number of ciphertext that can be added together depends on the degree
/// (in order not to go beyond the carry space and keep results correct) but also
/// on the noise level (in order to have the correct error probability and so correctness and
/// security)
///
/// - `degree` is expected degree of all elements to be summed
pub(crate) fn max_sum_size(&self, degree: Degree) -> usize {
let max_degree =
MaxDegree::from_msg_carry_modulus(self.message_modulus(), self.carry_modulus());
let max_sum_to_full_carry = max_degree.get() / degree.get();

max_sum_to_full_carry.min(self.key.max_noise_level.get())
}
}

impl AsRef<crate::shortint::ServerKey> for ServerKey {
Expand Down
25 changes: 10 additions & 15 deletions tfhe/src/integer/server_key/radix/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::ServerKey;
use crate::integer::ciphertext::boolean_value::BooleanBlock;
use crate::integer::ciphertext::IntegerRadixCiphertext;
use crate::integer::server_key::comparator::Comparator;
use crate::shortint::ciphertext::Degree;

impl ServerKey {
/// Compares for equality 2 ciphertexts
Expand Down Expand Up @@ -53,30 +54,27 @@ impl ServerKey {
.unchecked_apply_lookup_table_bivariate_assign(lhs_block, rhs_block, &lut);
});

let message_modulus = self.key.message_modulus.0;
let carry_modulus = self.key.carry_modulus.0;
let total_modulus = message_modulus * carry_modulus;
let max_value = total_modulus - 1;
let max_sum_size = self.max_sum_size(Degree::new(1));

let is_max_value = self
.key
.generate_lookup_table(|x| u64::from((x & max_value as u64) == max_value as u64));
.generate_lookup_table(|x| u64::from(x == max_sum_size as u64));

while block_comparisons.len() > 1 {
block_comparisons = block_comparisons
.chunks(max_value)
.chunks(max_sum_size)
.map(|blocks| {
let mut sum = blocks[0].clone();
for other_block in &blocks[1..] {
self.key.unchecked_add_assign(&mut sum, other_block);
}

if blocks.len() == max_value {
if blocks.len() == max_sum_size {
self.key.apply_lookup_table(&sum, &is_max_value)
} else {
let is_equal_to_num_blocks = self.key.generate_lookup_table(|x| {
u64::from((x & max_value as u64) == blocks.len() as u64)
});
let is_equal_to_num_blocks = self
.key
.generate_lookup_table(|x| u64::from(x == blocks.len() as u64));
self.key.apply_lookup_table(&sum, &is_equal_to_num_blocks)
}
})
Expand Down Expand Up @@ -112,15 +110,12 @@ impl ServerKey {
.unchecked_apply_lookup_table_bivariate_assign(lhs_block, rhs_block, &lut);
});

let message_modulus = self.key.message_modulus.0;
let carry_modulus = self.key.carry_modulus.0;
let total_modulus = message_modulus * carry_modulus;
let max_value = total_modulus - 1;
let max_sum_size = self.max_sum_size(Degree::new(1));
let is_non_zero = self.key.generate_lookup_table(|x| u64::from(x != 0));

while block_comparisons.len() > 1 {
block_comparisons = block_comparisons
.chunks(max_value)
.chunks(max_sum_size)
.map(|blocks| {
let mut sum = blocks[0].clone();
for other_block in &blocks[1..] {
Expand Down
34 changes: 4 additions & 30 deletions tfhe/src/integer/server_key/radix_parallel/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl ServerKey {
{
// Even though the corresponding function
// may already exist in self.key
// we generate our own lut to do less allocations
// we generate our own lut to do fewer allocations
// one for all the threads as opposed to one per thread
let lut = self
.key
Expand All @@ -76,7 +76,7 @@ impl ServerKey {
{
// Even though the corresponding function
// may already exist in self.key
// we generate our own lut to do less allocations
// we generate our own lut to do fewer allocations
// one for all the threads as opposed to one per thread
let lut = self
.key
Expand All @@ -90,34 +90,8 @@ impl ServerKey {
.unchecked_apply_lookup_table_bivariate_assign(lhs_block, rhs_block, &lut);
});

let message_modulus = self.key.message_modulus.0;
let carry_modulus = self.key.carry_modulus.0;
let total_modulus = message_modulus * carry_modulus;
let max_value = total_modulus - 1;

let mut block_comparisons_2 = Vec::with_capacity(block_comparisons.len() / 2);
let is_non_zero = self.key.generate_lookup_table(|x| u64::from(x != 0));

while block_comparisons.len() > 1 {
block_comparisons
.par_chunks(max_value)
.map(|blocks| {
let mut sum = blocks[0].clone();
for other_block in &blocks[1..] {
self.key.unchecked_add_assign(&mut sum, other_block);
}
self.key.apply_lookup_table(&sum, &is_non_zero)
})
.collect_into_vec(&mut block_comparisons_2);
std::mem::swap(&mut block_comparisons_2, &mut block_comparisons);
}

BooleanBlock::new_unchecked(
block_comparisons
.into_iter()
.next()
.unwrap_or_else(|| self.key.create_trivial(0)),
)
let result = self.is_at_least_one_comparisons_block_true(block_comparisons);
BooleanBlock::new_unchecked(result)
}

/// This implements all comparisons (<, <=, >, >=) for both signed and unsigned
Expand Down
36 changes: 15 additions & 21 deletions tfhe/src/integer/server_key/radix_parallel/scalar_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
use crate::integer::ciphertext::boolean_value::BooleanBlock;
use crate::integer::ciphertext::IntegerRadixCiphertext;
use crate::integer::server_key::comparator::{Comparator, ZeroComparisonType};
use crate::shortint::ciphertext::Degree;
use crate::shortint::server_key::LookupTableOwned;
use crate::shortint::Ciphertext;
use rayon::prelude::*;
Expand Down Expand Up @@ -160,27 +161,23 @@ impl ServerKey {
return self.key.create_trivial(1);
}

let message_modulus = self.key.message_modulus.0;
let carry_modulus = self.key.carry_modulus.0;
let total_modulus = message_modulus * carry_modulus;
let max_value = total_modulus - 1;

let max_sum_size = self.max_sum_size(Degree::new(1));
let is_max_value = self
.key
.generate_lookup_table(|x| u64::from(x == max_value as u64));
.generate_lookup_table(|x| u64::from(x == max_sum_size as u64));

while block_comparisons.len() > 1 {
// Since all blocks encrypt either 0 or 1, we can sum max_value of them
// as in the worst case we will be adding `max_value` ones
block_comparisons = block_comparisons
.par_chunks(max_value)
.par_chunks(max_sum_size)
.map(|blocks| {
let mut sum = blocks[0].clone();
for other_block in &blocks[1..] {
self.key.unchecked_add_assign(&mut sum, other_block);
}

if blocks.len() == max_value {
if blocks.len() == max_sum_size {
self.key.apply_lookup_table(&sum, &is_max_value)
} else {
let is_equal_to_num_blocks = self
Expand Down Expand Up @@ -213,25 +210,22 @@ impl ServerKey {
return self.key.create_trivial(1);
}

let message_modulus = self.key.message_modulus.0;
let carry_modulus = self.key.carry_modulus.0;
let total_modulus = message_modulus * carry_modulus;
let max_value = total_modulus - 1;

let is_not_zero = self.key.generate_lookup_table(|x| u64::from(x != 0));
let mut block_comparisons_2 = Vec::with_capacity(block_comparisons.len() / 2);
let max_sum_size = self.max_sum_size(Degree::new(1));

while block_comparisons.len() > 1 {
block_comparisons = block_comparisons
.par_chunks(max_value)
block_comparisons
.par_chunks(max_sum_size)
.map(|blocks| {
let mut sum = blocks[0].clone();
for other_block in &blocks[1..] {
self.key.unchecked_add_assign(&mut sum, other_block);
}

self.key.apply_lookup_table(&sum, &is_not_zero)
})
.collect::<Vec<_>>();
.collect_into_vec(&mut block_comparisons_2);
std::mem::swap(&mut block_comparisons_2, &mut block_comparisons);
}

block_comparisons
Expand Down Expand Up @@ -423,10 +417,10 @@ impl ServerKey {
let message_modulus = self.key.message_modulus.0;
let carry_modulus = self.key.carry_modulus.0;
let total_modulus = message_modulus * carry_modulus;
let max_value = total_modulus - 1;
let max_sum_size = self.max_sum_size(Degree::new(1));

assert!(carry_modulus >= message_modulus);
u8::try_from(max_value).unwrap();
u8::try_from(max_sum_size).unwrap();

let num_blocks = lhs.blocks().len();
let num_blocks_halved = (num_blocks / 2) + (num_blocks % 2);
Expand Down Expand Up @@ -516,10 +510,10 @@ impl ServerKey {
let message_modulus = self.key.message_modulus.0;
let carry_modulus = self.key.carry_modulus.0;
let total_modulus = message_modulus * carry_modulus;
let max_value = total_modulus - 1;
let max_sum_size = self.max_sum_size(Degree::new(1));

assert!(carry_modulus >= message_modulus);
u8::try_from(max_value).unwrap();
u8::try_from(max_sum_size).unwrap();

let num_blocks = lhs.blocks().len();
let num_blocks_halved = (num_blocks / 2) + (num_blocks % 2);
Expand Down

0 comments on commit 72ad76b

Please sign in to comment.