Skip to content

Commit

Permalink
feat(integer): add unsigned_overflowing_add
Browse files Browse the repository at this point in the history
  • Loading branch information
tmontaigu committed Nov 14, 2023
1 parent 1f825dd commit 72f1aee
Show file tree
Hide file tree
Showing 8 changed files with 250 additions and 16 deletions.
101 changes: 97 additions & 4 deletions tfhe/src/integer/server_key/radix_parallel/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,12 +232,94 @@ impl ServerKey {
};

if self.is_eligible_for_parallel_single_carry_propagation(lhs) {
self.unchecked_add_assign_parallelized_low_latency(lhs, rhs);
let _carry = self.unchecked_add_assign_parallelized_low_latency(lhs, rhs);
} else {
self.unchecked_add_assign(lhs, rhs);
self.full_propagate_parallelized(lhs);
}
}
/// Adds two unsigned radix ciphertexts and reports whether the addition overflowed.
///
/// Neither input is modified: the sum is computed on a clone of `ct_left`.
///
/// Returns a tuple made of the encrypted sum and of an encrypted overflow flag,
/// as produced by `unsigned_overflowing_add_assign_parallelized`.
///
/// # Example
///
/// ```rust
/// use tfhe::integer::gen_keys_radix;
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
///
/// // Generate the client key and the server key:
/// let num_blocks = 4;
/// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks);
///
/// let msg1 = u8::MAX;
/// let msg2 = 1;
///
/// let ct1 = cks.encrypt(msg1);
/// let ct2 = cks.encrypt(msg2);
///
/// let (ct_res, overflowed) = sks.unsigned_overflowing_add_parallelized(&ct1, &ct2);
///
/// // Decrypt:
/// let dec_result: u8 = cks.decrypt(&ct_res);
/// let dec_overflowed = cks.decrypt_one_block(&overflowed);
/// let (expected_result, expected_overflow) = msg1.overflowing_add(msg2);
/// assert_eq!(dec_result, expected_result);
/// assert_eq!(dec_overflowed, u64::from(expected_overflow));
/// ```
pub fn unsigned_overflowing_add_parallelized(
    &self,
    ct_left: &RadixCiphertext,
    ct_right: &RadixCiphertext,
) -> (RadixCiphertext, Ciphertext) {
    // Work on a copy so that the caller keeps its left operand intact.
    let mut sum = ct_left.clone();
    let overflow_flag = self.unsigned_overflowing_add_assign_parallelized(&mut sum, ct_right);
    (sum, overflow_flag)
}

/// Adds `ct_right` to `ct_left` in place and returns the encrypted output carry.
///
/// The returned block is the output carry of the addition, which can be used
/// to check for unsigned addition overflow.
///
/// - If either operand has no blocks, nothing is computed and a trivial
///   encryption of 0 is returned.
/// - Operands whose block carries are not empty are fully propagated first.
///   `ct_right` is never modified: propagation happens on a clone of it.
pub fn unsigned_overflowing_add_assign_parallelized(
    &self,
    ct_left: &mut RadixCiphertext,
    ct_right: &RadixCiphertext,
) -> Ciphertext {
    if ct_left.blocks.is_empty() || ct_right.blocks.is_empty() {
        return self.key.create_trivial(0);
    }

    let left_is_clean = ct_left.block_carries_are_empty();
    let right_is_clean = ct_right.block_carries_are_empty();

    // Only initialized when the right operand needs its carries propagated.
    let mut propagated_rhs: RadixCiphertext;
    let rhs = if right_is_clean {
        if !left_is_clean {
            self.full_propagate_parallelized(ct_left);
        }
        ct_right
    } else {
        propagated_rhs = ct_right.clone();
        if left_is_clean {
            self.full_propagate_parallelized(&mut propagated_rhs);
        } else {
            // Both operands are dirty: clean them concurrently.
            rayon::join(
                || self.full_propagate_parallelized(ct_left),
                || self.full_propagate_parallelized(&mut propagated_rhs),
            );
        }
        &propagated_rhs
    };

    if self.is_eligible_for_parallel_single_carry_propagation(ct_left) {
        return self.unchecked_add_assign_parallelized_low_latency(ct_left, rhs);
    }

    self.unchecked_add_assign(ct_left, rhs);

    // Propagate block by block; only the carry coming out of the last block
    // is returned, the intermediate ones are discarded.
    let last_index = ct_left.blocks.len() - 1;
    for index in 0..last_index {
        let _ = self.propagate_parallelized(ct_left, index);
    }
    self.propagate_parallelized(ct_left, last_index)
}

pub fn add_parallelized_work_efficient<T>(&self, ct_left: &T, ct_right: &T) -> T
where
Expand Down Expand Up @@ -309,6 +391,9 @@ impl ServerKey {
///
/// At most num_block - 1 threads are used
///
/// Returns the output carry that can be used to check for unsigned addition
/// overflow.
///
/// # Requirements
///
/// - The parameters have 4 bits in total
Expand All @@ -317,7 +402,11 @@ impl ServerKey {
/// # Output
///
/// - lhs will have its carries empty
pub(crate) fn unchecked_add_assign_parallelized_low_latency<T>(&self, lhs: &mut T, rhs: &T)
pub(crate) fn unchecked_add_assign_parallelized_low_latency<T>(
&self,
lhs: &mut T,
rhs: &T,
) -> Ciphertext
where
T: IntegerRadixCiphertext,
{
Expand All @@ -342,12 +431,15 @@ impl ServerKey {
/// - first unchecked_add
/// - at this point at most one bit of carry is taken
/// - use this function to propagate them in parallel
pub(crate) fn propagate_single_carry_parallelized_low_latency<T>(&self, ct: &mut T)
pub(crate) fn propagate_single_carry_parallelized_low_latency<T>(
&self,
ct: &mut T,
) -> Ciphertext
where
T: IntegerRadixCiphertext,
{
let generates_or_propagates = self.generate_init_carry_array(ct);
let (input_carries, _) =
let (input_carries, output_carry) =
self.compute_carry_propagation_parallelized_low_latency(generates_or_propagates);

ct.blocks_mut()
Expand All @@ -357,6 +449,7 @@ impl ServerKey {
self.key.unchecked_add_assign(block, input_carry);
self.key.message_extract_assign(block);
});
output_carry
}

/// Backbone algorithm of parallel carry (only one bit) propagation
Expand Down
12 changes: 9 additions & 3 deletions tfhe/src/integer/server_key/radix_parallel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ impl ServerKey {
/// let res: u64 = cks.decrypt_one_block(&ct_res.blocks()[1]);
/// assert_eq!(3, res);
/// ```
pub fn propagate_parallelized<T>(&self, ctxt: &mut T, index: usize)
pub fn propagate_parallelized<T>(
&self,
ctxt: &mut T,
index: usize,
) -> crate::shortint::Ciphertext
where
T: IntegerRadixCiphertext,
{
Expand All @@ -77,6 +81,8 @@ impl ServerKey {
self.key
.unchecked_add_assign(&mut ctxt.blocks_mut()[index + 1], &carry);
}

carry
}

pub fn partial_propagate_parallelized<T>(&self, ctxt: &mut T, start_index: usize)
Expand Down Expand Up @@ -107,11 +113,11 @@ impl ServerKey {

ctxt.blocks_mut()[start_index..].swap_with_slice(&mut message_blocks);
let carries = T::from_blocks(carry_blocks);
self.unchecked_add_assign_parallelized_low_latency(ctxt, &carries);
let _ = self.unchecked_add_assign_parallelized_low_latency(ctxt, &carries);
} else {
let len = ctxt.blocks().len();
for i in start_index..len {
self.propagate_parallelized(ctxt, i);
let _ = self.propagate_parallelized(ctxt, i);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/integer/server_key/radix_parallel/neg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl ServerKey {

if self.is_eligible_for_parallel_single_carry_propagation(ct) {
let mut ct = self.unchecked_neg(ct);
self.propagate_single_carry_parallelized_low_latency(&mut ct);
let _carry = self.propagate_single_carry_parallelized_low_latency(&mut ct);
ct
} else {
let mut ct = self.unchecked_neg(ct);
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/integer/server_key/radix_parallel/scalar_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ impl ServerKey {

if self.is_eligible_for_parallel_single_carry_propagation(ct) {
self.unchecked_scalar_add_assign(ct, scalar);
self.propagate_single_carry_parallelized_low_latency(ct);
let _carry = self.propagate_single_carry_parallelized_low_latency(ct);
} else {
self.unchecked_scalar_add_assign(ct, scalar);
self.full_propagate_parallelized(ct);
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ impl ServerKey {
self.unchecked_scalar_sub_assign(ct, scalar);

if self.is_eligible_for_parallel_single_carry_propagation(ct) {
self.propagate_single_carry_parallelized_low_latency(ct);
let _carry = self.propagate_single_carry_parallelized_low_latency(ct);
} else {
self.full_propagate_parallelized(ct);
}
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/integer/server_key/radix_parallel/sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ impl ServerKey {

if self.is_eligible_for_parallel_single_carry_propagation(lhs) {
let neg = self.unchecked_neg(rhs);
self.unchecked_add_assign_parallelized_low_latency(lhs, &neg);
let _carry = self.unchecked_add_assign_parallelized_low_latency(lhs, &neg);
} else {
self.unchecked_sub_assign(lhs, rhs);
self.full_propagate_parallelized(lhs);
Expand Down
136 changes: 131 additions & 5 deletions tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,17 @@ fn rotate_right_helper(value: u64, n: u32, actual_bit_size: u32) -> u64 {
}

fn overflowing_sub_under_modulus(lhs: u64, rhs: u64, modulus: u64) -> (u64, bool) {
let result = lhs.wrapping_sub(rhs);
// Technically using a div is not the fastest way to check for overflow,
// but as we have to do the remainder regardless, that /% should be one instruction
let (q, r) = (result / modulus, result % modulus);
assert!(
!(modulus.is_power_of_two() && (modulus - 1).overflowing_mul(2).1),
"If modulus is not a power of two, then must not overflow u64"
);
let (result, overflowed) = lhs.overflowing_sub(rhs);
(result % modulus, overflowed)
}

(r, q != 0)
/// Adds `lhs` and `rhs` and reduces the sum modulo `modulus`.
///
/// Returns `(sum % modulus, overflowed)`, where `overflowed` is `true` when the
/// exact (non-wrapped) sum is greater than or equal to `modulus`.
///
/// The sum is computed in `u128` so that the remainder stays correct even when
/// `lhs + rhs` does not fit in a `u64` and `modulus` is not a power of two.
///
/// # Panics
///
/// Panics if `modulus` is zero.
fn overflowing_add_under_modulus(lhs: u64, rhs: u64, modulus: u64) -> (u64, bool) {
    let wide_modulus = u128::from(modulus);
    // Two u64 values always fit in a u128, so this sum is exact.
    let exact_sum = u128::from(lhs) + u128::from(rhs);
    // The remainder is strictly below `modulus`, which fits in a u64:
    // the cast cannot truncate.
    let reduced = (exact_sum % wide_modulus) as u64;
    (reduced, exact_sum >= wide_modulus)
}

/// This trait is to be implemented by a struct that is capable
Expand Down Expand Up @@ -1771,6 +1776,127 @@ where
}
}

/// Checks an unsigned overflowing addition against clear `u64` arithmetic.
///
/// `executor` wraps the function under test; it is set up with the client key
/// and the server key obtained from `param`, and every homomorphic addition of
/// this test goes through it.
///
/// Three situations are exercised:
/// - freshly encrypted operands, including a determinism check (the same
///   inputs must give the same ciphertexts twice),
/// - operands that went through an unchecked scalar addition beforehand, so
///   that they are not clean ciphertexts,
/// - trivial ciphertexts.
///
/// In each case the decrypted sum and the decrypted overflow flag are compared
/// with `overflowing_add_under_modulus`.
pub(crate) fn default_overflowing_add_test<P, T>(param: P, mut executor: T)
where
    P: Into<PBSParameters>,
    T: for<'a> FunctionExecutor<
        (&'a RadixCiphertext, &'a RadixCiphertext),
        (RadixCiphertext, Ciphertext),
    >,
{
    let (cks, mut sks) = KEY_CACHE.get_from_params(param);
    let cks = RadixClientKey::from((cks, NB_CTXT));

    // Required by the determinism assertions below
    sks.set_deterministic_pbs_execution(true);
    let sks = Arc::new(sks);

    let mut rng = rand::thread_rng();

    // message_modulus^vec_length
    let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64;

    executor.setup(&cks, sks.clone());

    for _ in 0..NB_TEST_SMALLER {
        let clear_0 = rng.gen::<u64>() % modulus;
        let clear_1 = rng.gen::<u64>() % modulus;

        let ctxt_0 = cks.encrypt(clear_0);
        let ctxt_1 = cks.encrypt(clear_1);

        // Run twice on the same inputs to check the operation is deterministic
        let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
        let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, &ctxt_1));
        assert!(ct_res.block_carries_are_empty());
        assert!(result_overflowed.carry_is_empty());
        assert_eq!(ct_res, tmp_ct, "Failed determinism check");
        assert_eq!(tmp_o, result_overflowed, "Failed determinism check");

        let (expected_result, expected_overflowed) =
            overflowing_add_under_modulus(clear_0, clear_1, modulus);

        let decrypted_result: u64 = cks.decrypt(&ct_res);
        let decrypted_overflowed = cks.decrypt_one_block(&result_overflowed) == 1;
        assert_eq!(
            decrypted_result, expected_result,
            "Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} \
            expected {expected_result}, got {decrypted_result}"
        );
        assert_eq!(
            decrypted_overflowed,
            expected_overflowed,
            "Invalid overflow flag result for overflowing_add for ({clear_0} + {clear_1}) % {modulus} \
            expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
        );

        for _ in 0..NB_TEST_SMALLER {
            // Add non zero scalar to have non clean ciphertexts
            let clear_2 = random_non_zero_value(&mut rng, modulus);
            let clear_3 = random_non_zero_value(&mut rng, modulus);

            // These shadow the outer ciphertexts for this iteration only
            let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2);
            let ctxt_1 = sks.unchecked_scalar_add(&ctxt_1, clear_3);

            let (clear_lhs, _) = overflowing_add_under_modulus(clear_0, clear_2, modulus);
            let (clear_rhs, _) = overflowing_add_under_modulus(clear_1, clear_3, modulus);

            let d0: u64 = cks.decrypt(&ctxt_0);
            assert_eq!(d0, clear_lhs, "Failed sanity decryption check");
            let d1: u64 = cks.decrypt(&ctxt_1);
            assert_eq!(d1, clear_rhs, "Failed sanity decryption check");

            let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
            assert!(ct_res.block_carries_are_empty());
            assert!(result_overflowed.carry_is_empty());

            let (expected_result, expected_overflowed) =
                overflowing_add_under_modulus(clear_lhs, clear_rhs, modulus);

            let decrypted_result: u64 = cks.decrypt(&ct_res);
            let decrypted_overflowed = cks.decrypt_one_block(&result_overflowed) == 1;
            assert_eq!(
                decrypted_result, expected_result,
                "Invalid result for add, for ({clear_lhs} + {clear_rhs}) % {modulus} \
                expected {expected_result}, got {decrypted_result}"
            );
            assert_eq!(
                decrypted_overflowed,
                expected_overflowed,
                "Invalid overflow flag result for overflowing_add, for ({clear_lhs} + {clear_rhs}) % {modulus} \
                expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
            );
        }
    }

    // Test with trivial inputs
    for _ in 0..4 {
        let clear_0 = rng.gen::<u64>() % modulus;
        let clear_1 = rng.gen::<u64>() % modulus;

        let a: RadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT);
        let b: RadixCiphertext = sks.create_trivial_radix(clear_1, NB_CTXT);

        // Go through the executor, so that the function under test is the one
        // exercised here rather than a hard-coded server key method.
        let (encrypted_result, encrypted_overflow) = executor.execute((&a, &b));

        let (expected_result, expected_overflowed) =
            overflowing_add_under_modulus(clear_0, clear_1, modulus);

        let decrypted_result: u64 = cks.decrypt(&encrypted_result);
        let decrypted_overflowed = cks.decrypt_one_block(&encrypted_overflow) == 1;
        assert_eq!(
            decrypted_result, expected_result,
            "Invalid result for add, for ({clear_0} + {clear_1}) % {modulus} \
            expected {expected_result}, got {decrypted_result}"
        );
        assert_eq!(
            decrypted_overflowed,
            expected_overflowed,
            "Invalid overflow flag result for overflowing_add, for ({clear_0} + {clear_1}) % {modulus} \
            expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
        );
    }
}

pub(crate) fn default_overflowing_sub_test<P, T>(param: P, mut executor: T)
where
P: Into<PBSParameters>,
Expand Down
9 changes: 9 additions & 0 deletions tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ create_parametrized_test!(integer_smart_add);
create_parametrized_test!(integer_smart_add_sequence_multi_thread);
create_parametrized_test!(integer_smart_add_sequence_single_thread);
create_parametrized_test!(integer_default_add);
create_parametrized_test!(integer_default_overflowing_add);
create_parametrized_test!(integer_default_add_work_efficient {
// This algorithm requires 3 bits
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
Expand Down Expand Up @@ -717,6 +718,14 @@ where
default_add_test(param, executor);
}

/// Runs the shared overflowing-add test against
/// `ServerKey::unsigned_overflowing_add_parallelized`, wrapped in a CPU executor.
fn integer_default_overflowing_add<P>(param: P)
where
    P: Into<PBSParameters>,
{
    default_overflowing_add_test(
        param,
        CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_add_parallelized),
    );
}

fn integer_default_sub<P>(param: P)
where
P: Into<PBSParameters>,
Expand Down

0 comments on commit 72f1aee

Please sign in to comment.