Skip to content

Commit

Permalink
chore(integer): make tests work with different ServerKey
Browse files Browse the repository at this point in the history
This is a first step, a second step would be
to plug the non parallel radix tests so that
they are testing the same things.
  • Loading branch information
tmontaigu committed Sep 26, 2023
1 parent 37be751 commit 3b40670
Show file tree
Hide file tree
Showing 7 changed files with 3,533 additions and 2,594 deletions.
6 changes: 6 additions & 0 deletions tfhe/src/integer/client_key/radix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,9 @@ impl From<(ClientKey, usize)> for RadixClientKey {
Self { key, num_blocks }
}
}

impl From<RadixClientKey> for ClientKey {
fn from(ck: RadixClientKey) -> Self {
ck.key
}
}
5 changes: 4 additions & 1 deletion tfhe/src/integer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ pub use bigint::i256::I256;
pub use bigint::i512::I512;
pub use bigint::u256::U256;
pub use bigint::u512::U512;
pub use ciphertext::{CrtCiphertext, IntegerCiphertext, RadixCiphertext, SignedRadixCiphertext};
pub use ciphertext::{
CrtCiphertext, IntegerCiphertext, IntegerRadixCiphertext, RadixCiphertext,
SignedRadixCiphertext,
};
pub use client_key::{ClientKey, CrtClientKey, RadixClientKey};
pub use public_key::{CompressedCompactPublicKey, CompressedPublicKey, PublicKey};
pub use server_key::{CheckError, CompressedServerKey, ServerKey};
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/integer/server_key/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub mod comparator;
mod crt;
mod crt_parallel;
mod radix;
mod radix_parallel;
pub(crate) mod radix_parallel;

use crate::integer::client_key::ClientKey;
use crate::shortint::server_key::MaxDegree;
Expand Down
152 changes: 29 additions & 123 deletions tfhe/src/integer/server_key/radix/tests.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use crate::integer::keycache::KEY_CACHE;
use crate::integer::U256;
use crate::integer::{ServerKey, U256};
use crate::shortint::parameters::*;
use crate::shortint::ClassicPBSParameters;
use rand::Rng;

use crate::integer::server_key::radix_parallel::tests_cases_unsigned::*;
use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor;

/// Number of loop iteration within randomized tests
const NB_TEST: usize = 30;

Expand Down Expand Up @@ -41,7 +44,7 @@ create_parametrized_test!(integer_blockshift_right);
create_parametrized_test!(integer_smart_scalar_mul);
create_parametrized_test!(integer_unchecked_scalar_left_shift);
create_parametrized_test!(integer_unchecked_scalar_right_shift);
create_parametrized_test!(integer_unchecked_negation);
create_parametrized_test!(integer_unchecked_neg);
create_parametrized_test!(integer_smart_neg);
create_parametrized_test!(integer_unchecked_sub);
create_parametrized_test!(integer_smart_sub);
Expand Down Expand Up @@ -736,36 +739,12 @@ fn integer_unchecked_scalar_right_shift(param: ClassicPBSParameters) {
}
}

fn integer_unchecked_negation(param: ClassicPBSParameters) {
let (cks, sks) = KEY_CACHE.get_from_params(param);

//RNG
let mut rng = rand::thread_rng();

// message_modulus^vec_length
let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64;

for _ in 0..NB_TEST {
// Define the cleartexts
let clear = rng.gen::<u64>() % modulus;

// println!("clear = {}", clear);

// Encrypt the integers
let ctxt = cks.encrypt_radix(clear, NB_CTXT);

// Negates the ctxt
let ct_tmp = sks.unchecked_neg(&ctxt);

// Decrypt the result
let dec: u64 = cks.decrypt_radix(&ct_tmp);

// Check the correctness
let clear_result = clear.wrapping_neg() % modulus;

//println!("clear = {}", clear);
assert_eq!(clear_result, dec);
}
fn integer_unchecked_neg<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_neg);
unchecked_neg_test(param, executor);
}

fn integer_smart_neg(param: ClassicPBSParameters) {
Expand Down Expand Up @@ -796,34 +775,12 @@ fn integer_smart_neg(param: ClassicPBSParameters) {
}
}

fn integer_unchecked_sub(param: ClassicPBSParameters) {
let (cks, sks) = KEY_CACHE.get_from_params(param);

// RNG
let mut rng = rand::thread_rng();

// message_modulus^vec_length
let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64;

for _ in 0..NB_TEST {
// Define the cleartexts
let clear1 = rng.gen::<u64>() % modulus;
let clear2 = rng.gen::<u64>() % modulus;

// Encrypt the integers
let ctxt_1 = cks.encrypt_radix(clear1, NB_CTXT);
let ctxt_2 = cks.encrypt_radix(clear2, NB_CTXT);

// Add the ciphertext 1 and 2
let ct_tmp = sks.unchecked_sub(&ctxt_1, &ctxt_2);

// Decrypt the result
let dec: u64 = cks.decrypt_radix(&ct_tmp);

// Check the correctness
let clear_result = (clear1 - clear2) % modulus;
assert_eq!(clear_result, dec);
}
fn integer_unchecked_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_sub);
unchecked_sub_test(param, executor);
}

fn integer_smart_sub(param: ClassicPBSParameters) {
Expand Down Expand Up @@ -996,32 +953,12 @@ fn integer_smart_mul(param: ClassicPBSParameters) {
}
}

fn integer_unchecked_scalar_add(param: ClassicPBSParameters) {
let (cks, sks) = KEY_CACHE.get_from_params(param);

//RNG
let mut rng = rand::thread_rng();

// message_modulus^vec_length
let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64;

for _ in 0..NB_TEST {
let clear_0 = rng.gen::<u64>() % modulus;

let clear_1 = rng.gen::<u64>() % modulus;

// encryption of an integer
let ctxt_0 = cks.encrypt_radix(clear_0, NB_CTXT);

// add the two ciphertexts
let ct_res = sks.unchecked_scalar_add(&ctxt_0, clear_1);

// decryption of ct_res
let dec_res: u64 = cks.decrypt_radix(&ct_res);

// assert
assert_eq!((clear_0 + clear_1) % modulus, dec_res);
}
fn integer_unchecked_scalar_add<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_scalar_add);
unchecked_scalar_add_test(param, executor);
}

fn integer_smart_scalar_add(param: ClassicPBSParameters) {
Expand Down Expand Up @@ -1065,43 +1002,12 @@ fn integer_smart_scalar_add(param: ClassicPBSParameters) {
}
}

fn integer_unchecked_scalar_sub(param: ClassicPBSParameters) {
let (cks, sks) = KEY_CACHE.get_from_params(param);

//RNG
let mut rng = rand::thread_rng();

// message_modulus^vec_length
let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64;

// To force having one case where we subtract zero
{
let clear_0 = rng.gen::<u64>() % modulus;

let ctxt_0 = cks.encrypt_radix(clear_0, NB_CTXT);
let ct_res = sks.unchecked_scalar_sub(&ctxt_0, 0u64);

let dec_res: u64 = cks.decrypt_radix(&ct_res);

assert_eq!(clear_0, dec_res);
}

for _ in 0..NB_TEST {
let clear_0 = rng.gen::<u64>() % modulus;

let clear_1 = rng.gen::<u64>() % modulus;
// encryption of an integer
let ctxt_0 = cks.encrypt_radix(clear_0, NB_CTXT);

// add the two ciphertexts
let ct_res = sks.unchecked_scalar_sub(&ctxt_0, clear_1);

// decryption of ct_res
let dec_res: u64 = cks.decrypt_radix(&ct_res);

// assert
assert_eq!((clear_0.wrapping_sub(clear_1)) % modulus, dec_res);
}
fn integer_unchecked_scalar_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_scalar_sub);
unchecked_scalar_sub_test(param, executor);
}

fn integer_smart_scalar_sub(param: ClassicPBSParameters) {
Expand Down
4 changes: 3 additions & 1 deletion tfhe/src/integer/server_key/radix_parallel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ mod scalar_sub;
mod shift;
mod sub;

#[cfg(test)]
pub(crate) mod tests_cases_unsigned;
#[cfg(test)]
mod tests_signed;
#[cfg(test)]
mod tests_unsigned;
pub(crate) mod tests_unsigned;

use crate::integer::ciphertext::IntegerRadixCiphertext;

Expand Down
Loading

0 comments on commit 3b40670

Please sign in to comment.