Skip to content

Commit

Permalink
chore(integer): brings the CPU and GPU comopression tests into line.
Browse files Browse the repository at this point in the history
- also implements Debug, Eq, PartialEq to CompressedCiphertextList
  • Loading branch information
pdroalves committed Oct 4, 2024
1 parent 256378f commit e2bb2e6
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ use crate::core_crypto::prelude::*;
/// );
/// }
/// ```
#[derive(Clone, serde::Serialize, serde::Deserialize, Versionize)]
#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Versionize)]
#[versionize(CompressedModulusSwitchedGlweCiphertextVersions)]
pub struct CompressedModulusSwitchedGlweCiphertext<Scalar: UnsignedInteger> {
pub(crate) packed_integers: PackedIntegers<Scalar>,
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/core_crypto/entities/packed_integers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::conformance::ParameterSetConformant;
use crate::core_crypto::backward_compatibility::entities::packed_integers::PackedIntegersVersions;
use crate::core_crypto::prelude::*;

#[derive(Clone, serde::Serialize, serde::Deserialize, Versionize)]
#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Versionize)]
#[versionize(PackedIntegersVersions)]
pub struct PackedIntegers<Scalar: UnsignedInteger> {
pub(crate) packed_coeffs: Vec<Scalar>,
Expand Down
173 changes: 172 additions & 1 deletion tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl CompressedCiphertextListBuilder {
}
}

#[derive(Clone, Serialize, Deserialize, Versionize)]
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Versionize)]
#[versionize(CompressedCiphertextListVersions)]
pub struct CompressedCiphertextList {
pub(crate) packed_list: ShortintCompressedCiphertextList,
Expand Down Expand Up @@ -156,6 +156,8 @@ mod tests {
use crate::integer::ClientKey;
use crate::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use itertools::Itertools;
use rand::Rng;

#[test]
fn test_heterogeneous_ciphertext_compression_ci_run_filter() {
Expand Down Expand Up @@ -195,4 +197,173 @@ mod tests {

assert!(cks.decrypt_bool(&decompressed3));
}

const NB_TESTS: usize = 10;
const NB_OPERATOR_TESTS: usize = 1;
#[test]
fn test_ciphertext_compression() {
let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);

let private_compression_key =
cks.new_compression_private_key(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);

let (compression_key, decompression_key) =
cks.new_compression_decompression_keys(&private_compression_key);

let mut rng = rand::thread_rng();

let message_modulus: u128 = cks.parameters().message_modulus().0 as u128;

let num_blocks = 32;
for _ in 0..NB_TESTS {
// Unsigned
let modulus = message_modulus.pow(num_blocks as u32);
for _ in 0..NB_OPERATOR_TESTS {
let nb_messages = 1 + (rng.gen::<u64>() % 6);
let messages = (0..nb_messages)
.map(|_| rng.gen::<u128>() % modulus)
.collect::<Vec<_>>();

let cts = messages
.iter()
.map(|message| cks.encrypt_radix(*message, num_blocks))
.collect_vec();

let mut builder = CompressedCiphertextListBuilder::new();

for ct in cts {
builder.push(ct);
}

let compressed = builder.build(&compression_key);

for (i, message) in messages.iter().enumerate() {
let decompressed = compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted: u128 = cks.decrypt_radix(&decompressed);
assert_eq!(decrypted, *message);
}
}

// Signed
let modulus = message_modulus.pow((num_blocks - 1) as u32) as i128;
for _ in 0..NB_OPERATOR_TESTS {
let nb_messages = 1 + (rng.gen::<u64>() % 6);
let messages = (0..nb_messages)
.map(|_| rng.gen::<i128>() % modulus)
.collect::<Vec<_>>();

let cts = messages
.iter()
.map(|message| cks.encrypt_signed_radix(*message, num_blocks))
.collect_vec();

let mut builder = CompressedCiphertextListBuilder::new();

for ct in cts {
builder.push(ct);
}

let compressed = builder.build(&compression_key);

for (i, message) in messages.iter().enumerate() {
let decompressed = compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted: i128 = cks.decrypt_signed_radix(&decompressed);
assert_eq!(decrypted, *message);
}
}

// Boolean
for _ in 0..NB_OPERATOR_TESTS {
let nb_messages = 1 + (rng.gen::<u64>() % 6);
let messages = (0..nb_messages)
.map(|_| rng.gen::<i64>() % 2 != 0)
.collect::<Vec<_>>();

let cts = messages
.iter()
.map(|message| cks.encrypt_bool(*message))
.collect_vec();

let mut builder = CompressedCiphertextListBuilder::new();

for ct in cts {
builder.push(ct);
}

let cuda_compressed = builder.build(&compression_key);

for (i, message) in messages.iter().enumerate() {
let decompressed = cuda_compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted = cks.decrypt_bool(&decompressed);
assert_eq!(decrypted, *message);
}
}

// Hybrid
enum MessageType {
Unsigned(u128),
Signed(i128),
Boolean(bool),
}
for _ in 0..NB_OPERATOR_TESTS {
let mut builder = CompressedCiphertextListBuilder::new();

let nb_messages = 1 + (rng.gen::<u64>() % 6);
let mut messages = vec![];
for _ in 0..nb_messages {
let case_selector = rng.gen_range(0..3);
match case_selector {
0 => {
// Unsigned
let modulus = message_modulus.pow(num_blocks as u32);
let message = rng.gen::<u128>() % modulus;
let ct = cks.encrypt_radix(message, num_blocks);
builder.push(ct);
messages.push(MessageType::Unsigned(message));
}
1 => {
// Signed
let modulus = message_modulus.pow((num_blocks - 1) as u32) as i128;
let message = rng.gen::<i128>() % modulus;
let ct = cks.encrypt_signed_radix(message, num_blocks);
builder.push(ct);
messages.push(MessageType::Signed(message));
}
_ => {
// Boolean
let message = rng.gen::<i64>() % 2 != 0;
let ct = cks.encrypt_bool(message);
builder.push(ct);
messages.push(MessageType::Boolean(message));
}
}
}

let compressed = builder.build(&compression_key);

for (i, val) in messages.iter().enumerate() {
match val {
MessageType::Unsigned(message) => {
let decompressed =
compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted: u128 = cks.decrypt_radix(&decompressed);
assert_eq!(decrypted, *message);
}
MessageType::Signed(message) => {
let decompressed =
compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted: i128 = cks.decrypt_signed_radix(&decompressed);
assert_eq!(decrypted, *message);
}
MessageType::Boolean(message) => {
let decompressed =
compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted = cks.decrypt_bool(&decompressed);
assert_eq!(decrypted, *message);
}
}
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ mod tests {
use rand::Rng;

const NB_TESTS: usize = 10;
const NB_OPERATOR_TESTS: usize = 10;
const NB_OPERATOR_TESTS: usize = 1;

#[test]
fn test_gpu_ciphertext_compression() {
Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/shortint/ciphertext/compressed_ciphertext_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::shortint::backward_compatibility::ciphertext::CompressedCiphertextLis
use crate::shortint::parameters::CompressedCiphertextConformanceParams;
use crate::shortint::{CarryModulus, MessageModulus};

#[derive(Clone, serde::Serialize, serde::Deserialize, Versionize)]
#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Versionize)]
#[versionize(CompressedCiphertextListVersions)]
pub struct CompressedCiphertextList {
pub modulus_switched_glwe_ciphertext_list: Vec<CompressedModulusSwitchedGlweCiphertext<u64>>,
Expand Down

0 comments on commit e2bb2e6

Please sign in to comment.