diff --git a/tfhe/src/core_crypto/entities/compressed_modulus_switched_glwe_ciphertext.rs b/tfhe/src/core_crypto/entities/compressed_modulus_switched_glwe_ciphertext.rs index 963af0804..db2e7269c 100644 --- a/tfhe/src/core_crypto/entities/compressed_modulus_switched_glwe_ciphertext.rs +++ b/tfhe/src/core_crypto/entities/compressed_modulus_switched_glwe_ciphertext.rs @@ -77,7 +77,7 @@ use crate::core_crypto::prelude::*; /// ); /// } /// ``` -#[derive(Clone, serde::Serialize, serde::Deserialize, Versionize)] +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Versionize)] #[versionize(CompressedModulusSwitchedGlweCiphertextVersions)] pub struct CompressedModulusSwitchedGlweCiphertext { pub(crate) packed_integers: PackedIntegers, diff --git a/tfhe/src/core_crypto/entities/packed_integers.rs b/tfhe/src/core_crypto/entities/packed_integers.rs index 9f5e3ba2f..0df76d79f 100644 --- a/tfhe/src/core_crypto/entities/packed_integers.rs +++ b/tfhe/src/core_crypto/entities/packed_integers.rs @@ -4,7 +4,7 @@ use crate::conformance::ParameterSetConformant; use crate::core_crypto::backward_compatibility::entities::packed_integers::PackedIntegersVersions; use crate::core_crypto::prelude::*; -#[derive(Clone, serde::Serialize, serde::Deserialize, Versionize)] +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Versionize)] #[versionize(PackedIntegersVersions)] pub struct PackedIntegers { pub(crate) packed_coeffs: Vec, diff --git a/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs b/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs index 68811ef63..f83c74a73 100644 --- a/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs +++ b/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs @@ -95,7 +95,7 @@ impl CompressedCiphertextListBuilder { } } -#[derive(Clone, Serialize, Deserialize, Versionize)] +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Versionize)] #[versionize(CompressedCiphertextListVersions)] pub struct CompressedCiphertextList { pub(crate) packed_list: ShortintCompressedCiphertextList, @@ -156,6 +156,8 @@ mod tests { use crate::integer::ClientKey; use crate::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; + use itertools::Itertools; + use rand::Rng; #[test] fn test_heterogeneous_ciphertext_compression_ci_run_filter() { @@ -195,4 +197,173 @@ mod tests { assert!(cks.decrypt_bool(&decompressed3)); } + + const NB_TESTS: usize = 10; + const NB_OPERATOR_TESTS: usize = 1; + #[test] + fn test_ciphertext_compression() { + let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); + + let private_compression_key = + cks.new_compression_private_key(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); + + let (compression_key, decompression_key) = + cks.new_compression_decompression_keys(&private_compression_key); + + let mut rng = rand::thread_rng(); + + let message_modulus: u128 = cks.parameters().message_modulus().0 as u128; + + let num_blocks = 32; + for _ in 0..NB_TESTS { + // Unsigned + let modulus = message_modulus.pow(num_blocks as u32); + for _ in 0..NB_OPERATOR_TESTS { + let nb_messages = 1 + (rng.gen::() % 6); + let messages = (0..nb_messages) + .map(|_| rng.gen::() % modulus) + .collect::>(); + + let cts = messages + .iter() + .map(|message| cks.encrypt_radix(*message, num_blocks)) + .collect_vec(); + + let mut builder = CompressedCiphertextListBuilder::new(); + + for ct in cts { + builder.push(ct); + } + + let compressed = builder.build(&compression_key); + + for (i, message) in messages.iter().enumerate() { + let decompressed = compressed.get(i, &decompression_key).unwrap().unwrap(); + let decrypted: u128 = cks.decrypt_radix(&decompressed); + assert_eq!(decrypted, *message); + } + } + + // Signed + let modulus = message_modulus.pow((num_blocks - 1) as u32) as i128; + for _ in 0..NB_OPERATOR_TESTS { + let nb_messages = 1 + (rng.gen::() % 6); + let messages = (0..nb_messages) + .map(|_| rng.gen::() % modulus) + .collect::>(); + + let cts = messages + .iter() + .map(|message| cks.encrypt_signed_radix(*message, num_blocks)) + .collect_vec(); + + let mut builder = CompressedCiphertextListBuilder::new(); + + for ct in cts { + builder.push(ct); + } + + let compressed = builder.build(&compression_key); + + for (i, message) in messages.iter().enumerate() { + let decompressed = compressed.get(i, &decompression_key).unwrap().unwrap(); + let decrypted: i128 = cks.decrypt_signed_radix(&decompressed); + assert_eq!(decrypted, *message); + } + } + + // Boolean + for _ in 0..NB_OPERATOR_TESTS { + let nb_messages = 1 + (rng.gen::() % 6); + let messages = (0..nb_messages) + .map(|_| rng.gen::() % 2 != 0) + .collect::>(); + + let cts = messages + .iter() + .map(|message| cks.encrypt_bool(*message)) + .collect_vec(); + + let mut builder = CompressedCiphertextListBuilder::new(); + + for ct in cts { + builder.push(ct); + } + + let cuda_compressed = builder.build(&compression_key); + + for (i, message) in messages.iter().enumerate() { + let decompressed = cuda_compressed.get(i, &decompression_key).unwrap().unwrap(); + let decrypted = cks.decrypt_bool(&decompressed); + assert_eq!(decrypted, *message); + } + } + + // Hybrid + enum MessageType { + Unsigned(u128), + Signed(i128), + Boolean(bool), + } + for _ in 0..NB_OPERATOR_TESTS { + let mut builder = CompressedCiphertextListBuilder::new(); + + let nb_messages = 1 + (rng.gen::() % 6); + let mut messages = vec![]; + for _ in 0..nb_messages { + let case_selector = rng.gen_range(0..3); + match case_selector { + 0 => { + // Unsigned + let modulus = message_modulus.pow(num_blocks as u32); + let message = rng.gen::() % modulus; + let ct = cks.encrypt_radix(message, num_blocks); + builder.push(ct); + messages.push(MessageType::Unsigned(message)); + } + 1 => { + // Signed + let modulus = message_modulus.pow((num_blocks - 1) as u32) as i128; + let message = rng.gen::() % modulus; + let ct = cks.encrypt_signed_radix(message, num_blocks); + builder.push(ct); + messages.push(MessageType::Signed(message)); + } + _ => { + // Boolean + let message = rng.gen::() % 2 != 0; + let ct = cks.encrypt_bool(message); + builder.push(ct); + messages.push(MessageType::Boolean(message)); + } + } + } + + let compressed = builder.build(&compression_key); + + for (i, val) in messages.iter().enumerate() { + match val { + MessageType::Unsigned(message) => { + let decompressed = + compressed.get(i, &decompression_key).unwrap().unwrap(); + let decrypted: u128 = cks.decrypt_radix(&decompressed); + assert_eq!(decrypted, *message); + } + MessageType::Signed(message) => { + let decompressed = + compressed.get(i, &decompression_key).unwrap().unwrap(); + let decrypted: i128 = cks.decrypt_signed_radix(&decompressed); + assert_eq!(decrypted, *message); + } + MessageType::Boolean(message) => { + let decompressed = + compressed.get(i, &decompression_key).unwrap().unwrap(); + let decrypted = cks.decrypt_bool(&decompressed); + assert_eq!(decrypted, *message); + } + } + } + } + } + } } diff --git a/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs b/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs index eb289218f..c4e92b1fc 100644 --- a/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs +++ b/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs @@ -519,7 +519,7 @@ mod tests { use rand::Rng; const NB_TESTS: usize = 10; - const NB_OPERATOR_TESTS: usize = 10; + const NB_OPERATOR_TESTS: usize = 1; #[test] fn test_gpu_ciphertext_compression() { diff --git a/tfhe/src/shortint/ciphertext/compressed_ciphertext_list.rs b/tfhe/src/shortint/ciphertext/compressed_ciphertext_list.rs index 304df183e..7d4359335 100644 --- a/tfhe/src/shortint/ciphertext/compressed_ciphertext_list.rs +++ b/tfhe/src/shortint/ciphertext/compressed_ciphertext_list.rs @@ -7,7 +7,7 @@ use crate::shortint::backward_compatibility::ciphertext::CompressedCiphertextLis use crate::shortint::parameters::CompressedCiphertextConformanceParams; use crate::shortint::{CarryModulus, MessageModulus}; -#[derive(Clone, serde::Serialize, serde::Deserialize, Versionize)] +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Versionize)] #[versionize(CompressedCiphertextListVersions)] pub struct CompressedCiphertextList { pub modulus_switched_glwe_ciphertext_list: Vec>,