Skip to content

Commit

Permalink
chore(all): add tuniform params for GPU, make TUNIFORM params default…
Browse files Browse the repository at this point in the history
… whenever we have them
  • Loading branch information
agnesLeroy committed Oct 2, 2024
1 parent cb9dac6 commit a595740
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 22 deletions.
30 changes: 15 additions & 15 deletions tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,19 +129,19 @@ impl CudaCompressedCiphertextList {
/// use tfhe::integer::gpu::ciphertext::compressed_ciphertext_list::CudaCompressedCiphertextListBuilder;
/// use tfhe::integer::gpu::ciphertext::{CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext};
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
/// use tfhe::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64;
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64;
/// use tfhe::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
///
/// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64);
/// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
///
/// let private_compression_key =
/// cks.new_compression_private_key(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64);
/// cks.new_compression_private_key(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
///
/// let streams = CudaStreams::new_multi_gpu();
///
/// let num_blocks = 32;
/// let (radix_cks, _) = gen_keys_radix_gpu(
/// PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
/// PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
/// num_blocks,
/// &streams,
/// );
Expand Down Expand Up @@ -268,19 +268,19 @@ impl CompressedCiphertextList {
/// use tfhe::integer::gpu::ciphertext::{CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext};
/// use tfhe::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
/// use tfhe::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64;
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64;
/// use tfhe::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
///
/// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64);
/// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
///
/// let private_compression_key =
/// cks.new_compression_private_key(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64);
/// cks.new_compression_private_key(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
///
/// let streams = CudaStreams::new_multi_gpu();
///
/// let num_blocks = 32;
/// let (radix_cks, _) = gen_keys_radix_gpu(
/// PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
/// PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
/// num_blocks,
/// &streams,
/// );
Expand Down Expand Up @@ -514,25 +514,25 @@ mod tests {
use super::*;
use crate::integer::gpu::gen_keys_radix_gpu;
use crate::integer::ClientKey;
use crate::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64;
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64;
use crate::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use rand::Rng;

const NB_TESTS: usize = 10;
const NB_OPERATOR_TESTS: usize = 10;

#[test]
fn test_gpu_ciphertext_compression() {
let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64);
let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);

let private_compression_key =
cks.new_compression_private_key(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64);
cks.new_compression_private_key(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);

let streams = CudaStreams::new_multi_gpu();

let num_blocks = 32;
let (radix_cks, _) = gen_keys_radix_gpu(
PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
num_blocks,
&streams,
);
Expand Down
12 changes: 6 additions & 6 deletions tfhe/src/shortint/parameters/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@ pub const PARAM_MESSAGE_2_CARRY_0_KS_PBS: ClassicPBSParameters =
pub const PARAM_MESSAGE_2_CARRY_1_KS_PBS: ClassicPBSParameters =
PARAM_MESSAGE_2_CARRY_1_KS_PBS_GAUSSIAN_2M64;
pub const PARAM_MESSAGE_2_CARRY_2_KS_PBS: ClassicPBSParameters =
PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64;
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
pub const PARAM_MESSAGE_2_CARRY_3_KS_PBS: ClassicPBSParameters =
PARAM_MESSAGE_2_CARRY_3_KS_PBS_GAUSSIAN_2M64;
pub const PARAM_MESSAGE_2_CARRY_4_KS_PBS: ClassicPBSParameters =
Expand Down Expand Up @@ -839,13 +839,13 @@ pub const PARAM_MESSAGE_7_CARRY_1_KS_PBS: ClassicPBSParameters =
pub const PARAM_MESSAGE_8_CARRY_0_KS_PBS: ClassicPBSParameters =
PARAM_MESSAGE_8_CARRY_0_KS_PBS_GAUSSIAN_2M64;
pub const PARAM_MESSAGE_1_CARRY_1_PBS_KS: ClassicPBSParameters =
PARAM_MESSAGE_1_CARRY_1_PBS_KS_GAUSSIAN_2M64;
PARAM_MESSAGE_1_CARRY_1_PBS_KS_TUNIFORM_2M64;
pub const PARAM_MESSAGE_2_CARRY_2_PBS_KS: ClassicPBSParameters =
PARAM_MESSAGE_2_CARRY_2_PBS_KS_GAUSSIAN_2M64;
PARAM_MESSAGE_2_CARRY_2_PBS_KS_TUNIFORM_2M64;
pub const PARAM_MESSAGE_3_CARRY_3_PBS_KS: ClassicPBSParameters =
PARAM_MESSAGE_3_CARRY_3_PBS_KS_GAUSSIAN_2M64;
PARAM_MESSAGE_3_CARRY_3_PBS_KS_TUNIFORM_2M64;
pub const PARAM_MESSAGE_4_CARRY_4_PBS_KS: ClassicPBSParameters =
PARAM_MESSAGE_4_CARRY_4_PBS_KS_GAUSSIAN_2M64;
PARAM_MESSAGE_4_CARRY_4_PBS_KS_TUNIFORM_2M64;

pub const PARAM_MESSAGE_1_CARRY_0: ClassicPBSParameters = PARAM_MESSAGE_1_CARRY_0_KS_PBS;
pub const PARAM_MESSAGE_1_CARRY_1: ClassicPBSParameters = PARAM_MESSAGE_1_CARRY_1_KS_PBS;
Expand Down Expand Up @@ -889,6 +889,6 @@ pub const PARAM_SMALL_MESSAGE_3_CARRY_3: ClassicPBSParameters = PARAM_MESSAGE_3_
pub const PARAM_SMALL_MESSAGE_4_CARRY_4: ClassicPBSParameters = PARAM_MESSAGE_4_CARRY_4_PBS_KS;

pub const COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS: CompressionParameters =
list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64;
list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;

pub const COMP_PARAM_MESSAGE_2_CARRY_2: CompressionParameters = COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS;
2 changes: 1 addition & 1 deletion tfhe/src/shortint/parameters/multi_bit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,6 @@ pub const PARAM_GPU_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS: MultiBitPBSParam
pub const PARAM_GPU_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_3_KS_PBS: MultiBitPBSParameters =
PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64;
pub const PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS: MultiBitPBSParameters =
PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64;
PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
pub const PARAM_GPU_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS: MultiBitPBSParameters =
PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64;
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,23 @@ pub const PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64:
grouping_factor: LweBskGroupingFactor(3),
deterministic_execution: false,
};
pub const PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64:
MultiBitPBSParameters = MultiBitPBSParameters {
lwe_dimension: LweDimension(882),
glwe_dimension: GlweDimension(1),
polynomial_size: PolynomialSize(2048),
lwe_noise_distribution: DynamicDistribution::new_t_uniform(46),
glwe_noise_distribution: DynamicDistribution::new_t_uniform(17),
pbs_base_log: DecompositionBaseLog(22),
pbs_level: DecompositionLevelCount(1),
ks_base_log: DecompositionBaseLog(3),
ks_level: DecompositionLevelCount(5),
message_modulus: MessageModulus(4),
carry_modulus: CarryModulus(4),
max_noise_level: MaxNoiseLevel::new(5),
log2_p_fail: -64.59,
ciphertext_modulus: CiphertextModulus::new_native(),
encryption_key_choice: EncryptionKeyChoice::Big,
grouping_factor: LweBskGroupingFactor(3),
deterministic_execution: true,
};

0 comments on commit a595740

Please sign in to comment.