From 267ddd8b3f67ae9062f1b58022ee6ccbc4d90452 Mon Sep 17 00:00:00 2001 From: Guillermo Oyarzun Date: Mon, 29 Jul 2024 17:26:42 +0200 Subject: [PATCH] feat(gpu): implement ilog2, trailing and leading zeros and ones on GPU --- .../tfhe-cuda-backend/cuda/include/integer.h | 35 +- .../cuda/src/integer/integer.cu | 52 + .../cuda/src/integer/integer.cuh | 114 +- .../cuda/src/integer/multiplication.cu | 30 +- .../cuda/src/integer/multiplication.cuh | 15 +- .../cuda/src/integer/scalar_mul.cuh | 7 +- backends/tfhe-cuda-backend/src/cuda_bind.rs | 55 +- tfhe/src/integer/gpu/ciphertext/info.rs | 4 + tfhe/src/integer/gpu/mod.rs | 125 +- tfhe/src/integer/gpu/server_key/radix/add.rs | 65 +- .../gpu/server_key/radix/comparison.rs | 2 +- .../src/integer/gpu/server_key/radix/ilog2.rs | 1070 +++++++++++++++++ tfhe/src/integer/gpu/server_key/radix/mod.rs | 43 +- .../gpu/server_key/radix/scalar_comparison.rs | 2 +- .../gpu/server_key/radix/tests_signed/mod.rs | 55 + .../radix/tests_signed/test_ilog2.rs | 64 + .../server_key/radix/tests_unsigned/mod.rs | 33 + .../radix/tests_unsigned/test_ilog2.rs | 64 + .../server_key/radix_parallel/ilog2.rs | 875 +------------- .../integer/server_key/radix_parallel/mod.rs | 2 +- .../radix_parallel/tests_cases_unsigned.rs | 6 - .../radix_parallel/tests_signed/mod.rs | 99 +- .../radix_parallel/tests_signed/test_ilog2.rs | 504 ++++++++ .../radix_parallel/tests_unsigned/mod.rs | 65 +- .../tests_unsigned/test_ilog2.rs | 488 ++++++++ 25 files changed, 2763 insertions(+), 1111 deletions(-) create mode 100644 tfhe/src/integer/gpu/server_key/radix/ilog2.rs create mode 100644 tfhe/src/integer/gpu/server_key/radix/tests_signed/test_ilog2.rs create mode 100644 tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_ilog2.rs create mode 100644 tfhe/src/integer/server_key/radix_parallel/tests_signed/test_ilog2.rs create mode 100644 tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_ilog2.rs diff --git 
a/backends/tfhe-cuda-backend/cuda/include/integer.h b/backends/tfhe-cuda-backend/cuda/include/integer.h index c1874aed98..871829abeb 100644 --- a/backends/tfhe-cuda-backend/cuda/include/integer.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer.h @@ -283,7 +283,7 @@ void cleanup_cuda_propagate_single_carry(void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, int8_t **mem_ptr_void); -void scratch_cuda_integer_radix_sum_ciphertexts_vec_kb_64( +void scratch_cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64( void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, int8_t **mem_ptr, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level, uint32_t ks_base_log, uint32_t pbs_level, @@ -292,15 +292,14 @@ void scratch_cuda_integer_radix_sum_ciphertexts_vec_kb_64( uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory); -void cuda_integer_radix_sum_ciphertexts_vec_kb_64( +void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64( void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, void *radix_lwe_out, void *radix_lwe_vec, uint32_t num_radix_in_vec, int8_t *mem_ptr, void **bsks, void **ksks, uint32_t num_blocks_in_radix); -void cleanup_cuda_integer_radix_sum_ciphertexts_vec(void **streams, - uint32_t *gpu_indexes, - uint32_t gpu_count, - int8_t **mem_ptr_void); +void cleanup_cuda_integer_radix_partial_sum_ciphertexts_vec( + void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, + int8_t **mem_ptr_void); void scratch_cuda_integer_radix_overflowing_sub_kb_64( void **stream, uint32_t *gpu_indexes, uint32_t gpu_count, int8_t **mem_ptr, @@ -375,6 +374,30 @@ void cleanup_signed_overflowing_add_or_sub(void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, int8_t **mem_ptr_void); + +void scratch_cuda_integer_compute_prefix_sum_hillis_steele_64( + void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, int8_t **mem_ptr, + void *input_lut, uint32_t lwe_dimension, uint32_t 
glwe_dimension, + uint32_t polynomial_size, uint32_t ks_level, uint32_t ks_base_log, + uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor, + uint32_t num_radix_blocks, uint32_t message_modulus, uint32_t carry_modulus, + PBS_TYPE pbs_type, bool allocate_gpu_memory); + +void cuda_integer_compute_prefix_sum_hillis_steele_64( + void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, + void *output_radix_lwe, void *input_radix_lwe, int8_t *mem_ptr, void **ksks, + void **bsks, uint32_t num_blocks, uint32_t shift); + +void cleanup_cuda_integer_compute_prefix_sum_hillis_steele_64( + void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, + int8_t **mem_ptr_void); + +void cuda_integer_reverse_blocks_64_inplace(void **streams, + uint32_t *gpu_indexes, + uint32_t gpu_count, void *lwe_array, + uint32_t num_blocks, + uint32_t lwe_size); + } // extern C template diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cu b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cu index 9152b0fd39..e5b4b2e742 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cu +++ b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cu @@ -173,3 +173,55 @@ void cleanup_cuda_apply_bivariate_lut_kb_64(void **streams, int_radix_lut *mem_ptr = (int_radix_lut *)(*mem_ptr_void); mem_ptr->release((cudaStream_t *)(streams), gpu_indexes, gpu_count); } + +void scratch_cuda_integer_compute_prefix_sum_hillis_steele_64( + void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, int8_t **mem_ptr, + void *input_lut, uint32_t lwe_dimension, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t ks_level, uint32_t ks_base_log, + uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor, + uint32_t num_radix_blocks, uint32_t message_modulus, uint32_t carry_modulus, + PBS_TYPE pbs_type, bool allocate_gpu_memory) { + + int_radix_params params(pbs_type, glwe_dimension, polynomial_size, + glwe_dimension * polynomial_size, lwe_dimension, + ks_level, 
ks_base_log, pbs_level, pbs_base_log, + grouping_factor, message_modulus, carry_modulus); + + scratch_cuda_apply_bivariate_lut_kb( + (cudaStream_t *)(streams), gpu_indexes, gpu_count, + (int_radix_lut **)mem_ptr, static_cast(input_lut), + num_radix_blocks, params, allocate_gpu_memory); +} + +void cuda_integer_compute_prefix_sum_hillis_steele_64( + void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, + void *output_radix_lwe, void *input_radix_lwe, int8_t *mem_ptr, void **ksks, + void **bsks, uint32_t num_blocks, uint32_t shift) { + + int_radix_params params = ((int_radix_lut *)mem_ptr)->params; + + host_compute_prefix_sum_hillis_steele( + (cudaStream_t *)(streams), gpu_indexes, gpu_count, + static_cast(output_radix_lwe), + static_cast(input_radix_lwe), params, + (int_radix_lut *)mem_ptr, bsks, (uint64_t **)(ksks), + num_blocks); +} + +void cleanup_cuda_integer_compute_prefix_sum_hillis_steele_64( + void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, + int8_t **mem_ptr_void) { + int_radix_lut *mem_ptr = (int_radix_lut *)(*mem_ptr_void); + mem_ptr->release((cudaStream_t *)(streams), gpu_indexes, gpu_count); +} + +void cuda_integer_reverse_blocks_64_inplace(void **streams, + uint32_t *gpu_indexes, + uint32_t gpu_count, void *lwe_array, + uint32_t num_blocks, + uint32_t lwe_size) { + + host_radix_blocks_reverse_inplace( + (cudaStream_t *)(streams), gpu_indexes, + static_cast(lwe_array), num_blocks, lwe_size); +} diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh index 00a633fbbf..e935d8768d 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh @@ -99,6 +99,35 @@ host_radix_blocks_rotate_left(cudaStream_t *streams, uint32_t *gpu_indexes, dst, src, value, blocks_count, lwe_size); } +// reverse the blocks in a list +// each cuda block swaps a couple of blocks +template +__global__ void 
radix_blocks_reverse_lwe_inplace(Torus *src, + uint32_t blocks_count, + uint32_t lwe_size) { + + size_t idx = blockIdx.x; + size_t rev_idx = blocks_count - 1 - idx; + + for (int j = threadIdx.x; j < lwe_size; j += blockDim.x) { + Torus back_element = src[rev_idx * lwe_size + j]; + Torus front_element = src[idx * lwe_size + j]; + src[idx * lwe_size + j] = back_element; + src[rev_idx * lwe_size + j] = front_element; + } +} + +template +__host__ void +host_radix_blocks_reverse_inplace(cudaStream_t *streams, uint32_t *gpu_indexes, + Torus *src, uint32_t blocks_count, + uint32_t lwe_size) { + cudaSetDevice(gpu_indexes[0]); + int num_blocks = blocks_count / 2, num_threads = 1024; + radix_blocks_reverse_lwe_inplace<<>>( + src, blocks_count, lwe_size); +} + // polynomial_size threads template __global__ void @@ -501,30 +530,17 @@ void scratch_cuda_propagate_single_carry_kb_inplace( } template -void host_propagate_single_carry(cudaStream_t *streams, uint32_t *gpu_indexes, - uint32_t gpu_count, Torus *lwe_array, - Torus *carry_out, Torus *input_carries, - int_sc_prop_memory *mem, void **bsks, - Torus **ksks, uint32_t num_blocks) { - auto params = mem->params; +void host_compute_prefix_sum_hillis_steele( + cudaStream_t *streams, uint32_t *gpu_indexes, uint32_t gpu_count, + Torus *step_output, Torus *generates_or_propagates, int_radix_params params, + int_radix_lut *luts, void **bsks, Torus **ksks, + uint32_t num_blocks) { + auto glwe_dimension = params.glwe_dimension; auto polynomial_size = params.polynomial_size; auto big_lwe_size = glwe_dimension * polynomial_size + 1; auto big_lwe_size_bytes = big_lwe_size * sizeof(Torus); - auto generates_or_propagates = mem->generates_or_propagates; - auto step_output = mem->step_output; - - auto luts_array = mem->luts_array; - auto luts_carry_propagation_sum = mem->luts_carry_propagation_sum; - auto message_acc = mem->message_acc; - - integer_radix_apply_univariate_lookup_table_kb( - streams, gpu_indexes, gpu_count, 
generates_or_propagates, lwe_array, bsks, - ksks, num_blocks, luts_array); - - // compute prefix sum with hillis&steele - int num_steps = ceil(log2((double)num_blocks)); int space = 1; cuda_memcpy_async_gpu_to_gpu(step_output, generates_or_propagates, @@ -541,15 +557,42 @@ void host_propagate_single_carry(cudaStream_t *streams, uint32_t *gpu_indexes, integer_radix_apply_bivariate_lookup_table_kb( streams, gpu_indexes, gpu_count, cur_blocks, cur_blocks, prev_blocks, - bsks, ksks, cur_total_blocks, luts_carry_propagation_sum, - luts_carry_propagation_sum->params.message_modulus); + bsks, ksks, cur_total_blocks, luts, luts->params.message_modulus); - cuda_synchronize_stream(streams[0], gpu_indexes[0]); cuda_memcpy_async_gpu_to_gpu( &generates_or_propagates[space * big_lwe_size], cur_blocks, big_lwe_size_bytes * cur_total_blocks, streams[0], gpu_indexes[0]); space *= 2; } +} + +template +void host_propagate_single_carry(cudaStream_t *streams, uint32_t *gpu_indexes, + uint32_t gpu_count, Torus *lwe_array, + Torus *carry_out, Torus *input_carries, + int_sc_prop_memory *mem, void **bsks, + Torus **ksks, uint32_t num_blocks) { + auto params = mem->params; + auto glwe_dimension = params.glwe_dimension; + auto polynomial_size = params.polynomial_size; + auto big_lwe_size = glwe_dimension * polynomial_size + 1; + auto big_lwe_size_bytes = big_lwe_size * sizeof(Torus); + + auto generates_or_propagates = mem->generates_or_propagates; + auto step_output = mem->step_output; + + auto luts_array = mem->luts_array; + auto luts_carry_propagation_sum = mem->luts_carry_propagation_sum; + auto message_acc = mem->message_acc; + + integer_radix_apply_univariate_lookup_table_kb( + streams, gpu_indexes, gpu_count, generates_or_propagates, lwe_array, bsks, + ksks, num_blocks, luts_array); + + // compute prefix sum with hillis&steele + host_compute_prefix_sum_hillis_steele( + streams, gpu_indexes, gpu_count, step_output, generates_or_propagates, + params, luts_carry_propagation_sum, bsks, 
ksks, num_blocks); host_radix_blocks_rotate_right(streams, gpu_indexes, gpu_count, step_output, generates_or_propagates, 1, num_blocks, @@ -613,30 +656,9 @@ void host_propagate_single_sub_borrow(cudaStream_t *streams, ksks, num_blocks, luts_array); // compute prefix sum with hillis&steele - int num_steps = ceil(log2((double)num_blocks)); - int space = 1; - cuda_memcpy_async_gpu_to_gpu(step_output, generates_or_propagates, - big_lwe_size_bytes * num_blocks, streams[0], - gpu_indexes[0]); - - for (int step = 0; step < num_steps; step++) { - if (space > num_blocks - 1) - PANIC("Cuda error: step output is going out of bounds in Hillis Steele " - "propagation") - auto cur_blocks = &step_output[space * big_lwe_size]; - auto prev_blocks = generates_or_propagates; - int cur_total_blocks = num_blocks - space; - - integer_radix_apply_bivariate_lookup_table_kb( - streams, gpu_indexes, gpu_count, cur_blocks, cur_blocks, prev_blocks, - bsks, ksks, cur_total_blocks, luts_carry_propagation_sum, - luts_carry_propagation_sum->params.message_modulus); - - cuda_memcpy_async_gpu_to_gpu( - &generates_or_propagates[space * big_lwe_size], cur_blocks, - big_lwe_size_bytes * cur_total_blocks, streams[0], gpu_indexes[0]); - space *= 2; - } + host_compute_prefix_sum_hillis_steele( + streams, gpu_indexes, gpu_count, step_output, generates_or_propagates, + params, luts_carry_propagation_sum, bsks, ksks, num_blocks); cuda_memcpy_async_gpu_to_gpu( overflowed, &generates_or_propagates[big_lwe_size * (num_blocks - 1)], diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cu b/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cu index 49bb1e4dca..40c59a5724 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cu +++ b/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cu @@ -202,7 +202,7 @@ void cleanup_cuda_integer_mult(void **streams, uint32_t *gpu_indexes, mem_ptr->release((cudaStream_t *)(streams), gpu_indexes, gpu_count); } -void 
scratch_cuda_integer_radix_sum_ciphertexts_vec_kb_64( +void scratch_cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64( void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, int8_t **mem_ptr, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level, uint32_t ks_base_log, uint32_t pbs_level, @@ -215,13 +215,13 @@ void scratch_cuda_integer_radix_sum_ciphertexts_vec_kb_64( glwe_dimension * polynomial_size, lwe_dimension, ks_level, ks_base_log, pbs_level, pbs_base_log, grouping_factor, message_modulus, carry_modulus); - scratch_cuda_integer_sum_ciphertexts_vec_kb( + scratch_cuda_integer_partial_sum_ciphertexts_vec_kb( (cudaStream_t *)(streams), gpu_indexes, gpu_count, (int_sum_ciphertexts_vec_memory **)mem_ptr, num_blocks_in_radix, max_num_radix_in_vec, params, allocate_gpu_memory); } -void cuda_integer_radix_sum_ciphertexts_vec_kb_64( +void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64( void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, void *radix_lwe_out, void *radix_lwe_vec, uint32_t num_radix_in_vec, int8_t *mem_ptr, void **bsks, void **ksks, uint32_t num_blocks_in_radix) { @@ -237,42 +237,47 @@ void cuda_integer_radix_sum_ciphertexts_vec_kb_64( switch (mem->params.polynomial_size) { case 512: - host_integer_sum_ciphertexts_vec_kb>( + host_integer_partial_sum_ciphertexts_vec_kb>( (cudaStream_t *)(streams), gpu_indexes, gpu_count, static_cast(radix_lwe_out), static_cast(radix_lwe_vec), terms_degree, bsks, (uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec); break; case 1024: - host_integer_sum_ciphertexts_vec_kb>( + host_integer_partial_sum_ciphertexts_vec_kb>( (cudaStream_t *)(streams), gpu_indexes, gpu_count, static_cast(radix_lwe_out), static_cast(radix_lwe_vec), terms_degree, bsks, (uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec); break; case 2048: - host_integer_sum_ciphertexts_vec_kb>( + host_integer_partial_sum_ciphertexts_vec_kb>( (cudaStream_t *)(streams), gpu_indexes, 
gpu_count, static_cast(radix_lwe_out), static_cast(radix_lwe_vec), terms_degree, bsks, (uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec); break; case 4096: - host_integer_sum_ciphertexts_vec_kb>( + host_integer_partial_sum_ciphertexts_vec_kb>( (cudaStream_t *)(streams), gpu_indexes, gpu_count, static_cast(radix_lwe_out), static_cast(radix_lwe_vec), terms_degree, bsks, (uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec); break; case 8192: - host_integer_sum_ciphertexts_vec_kb>( + host_integer_partial_sum_ciphertexts_vec_kb>( (cudaStream_t *)(streams), gpu_indexes, gpu_count, static_cast(radix_lwe_out), static_cast(radix_lwe_vec), terms_degree, bsks, (uint64_t **)(ksks), mem, num_blocks_in_radix, num_radix_in_vec); break; case 16384: - host_integer_sum_ciphertexts_vec_kb>( + host_integer_partial_sum_ciphertexts_vec_kb>( (cudaStream_t *)(streams), gpu_indexes, gpu_count, static_cast(radix_lwe_out), static_cast(radix_lwe_vec), terms_degree, bsks, @@ -286,10 +291,9 @@ void cuda_integer_radix_sum_ciphertexts_vec_kb_64( free(terms_degree); } -void cleanup_cuda_integer_radix_sum_ciphertexts_vec(void **streams, - uint32_t *gpu_indexes, - uint32_t gpu_count, - int8_t **mem_ptr_void) { +void cleanup_cuda_integer_radix_partial_sum_ciphertexts_vec( + void **streams, uint32_t *gpu_indexes, uint32_t gpu_count, + int8_t **mem_ptr_void) { int_sum_ciphertexts_vec_memory *mem_ptr = (int_sum_ciphertexts_vec_memory *)(*mem_ptr_void); diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh index 99dbf438be..ba831d1f9f 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh @@ -169,7 +169,7 @@ __global__ void fill_radix_from_lsb_msb(Torus *result_blocks, Torus *lsb_blocks, } } template -__host__ void scratch_cuda_integer_sum_ciphertexts_vec_kb( +__host__ void 
scratch_cuda_integer_partial_sum_ciphertexts_vec_kb( cudaStream_t *streams, uint32_t *gpu_indexes, uint32_t gpu_count, int_sum_ciphertexts_vec_memory **mem_ptr, uint32_t num_blocks_in_radix, uint32_t max_num_radix_in_vec, @@ -181,7 +181,7 @@ __host__ void scratch_cuda_integer_sum_ciphertexts_vec_kb( } template -__host__ void host_integer_sum_ciphertexts_vec_kb( +__host__ void host_integer_partial_sum_ciphertexts_vec_kb( cudaStream_t *streams, uint32_t *gpu_indexes, uint32_t gpu_count, Torus *radix_lwe_out, Torus *terms, int *terms_degree, void **bsks, uint64_t **ksks, int_sum_ciphertexts_vec_memory *mem_ptr, @@ -425,10 +425,6 @@ __host__ void host_integer_sum_ciphertexts_vec_kb( host_addition(streams[0], gpu_indexes[0], radix_lwe_out, old_blocks, &old_blocks[num_blocks * big_lwe_size], big_lwe_dimension, num_blocks); - - host_propagate_single_carry(streams, gpu_indexes, gpu_count, - radix_lwe_out, nullptr, nullptr, - mem_ptr->scp_mem, bsks, ksks, num_blocks); } template @@ -539,10 +535,15 @@ __host__ void host_integer_mult_radix_kb( terms_degree_msb[i] = (b_id > r_id) ? 
message_modulus - 2 : 0; } - host_integer_sum_ciphertexts_vec_kb( + host_integer_partial_sum_ciphertexts_vec_kb( streams, gpu_indexes, gpu_count, radix_lwe_out, vector_result_sb, terms_degree, bsks, ksks, mem_ptr->sum_ciphertexts_mem, num_blocks, 2 * num_blocks, mem_ptr->luts_array); + + auto scp_mem_ptr = mem_ptr->sum_ciphertexts_mem->scp_mem; + host_propagate_single_carry(streams, gpu_indexes, gpu_count, + radix_lwe_out, nullptr, nullptr, + scp_mem_ptr, bsks, ksks, num_blocks); } template diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh index 817a7bc0d1..8347945532 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh @@ -105,10 +105,15 @@ __host__ void host_integer_scalar_mul_radix( for (int i = 0; i < j * num_radix_blocks; i++) { terms_degree[i] = message_modulus - 1; } - host_integer_sum_ciphertexts_vec_kb( + host_integer_partial_sum_ciphertexts_vec_kb( streams, gpu_indexes, gpu_count, lwe_array, all_shifted_buffer, terms_degree, bsks, ksks, mem->sum_ciphertexts_vec_mem, num_radix_blocks, j); + + auto scp_mem_ptr = mem->sum_ciphertexts_vec_mem->scp_mem; + host_propagate_single_carry(streams, gpu_indexes, gpu_count, lwe_array, + nullptr, nullptr, scp_mem_ptr, bsks, ksks, + num_radix_blocks); } } diff --git a/backends/tfhe-cuda-backend/src/cuda_bind.rs b/backends/tfhe-cuda-backend/src/cuda_bind.rs index 70f9b2a607..11f7c30632 100644 --- a/backends/tfhe-cuda-backend/src/cuda_bind.rs +++ b/backends/tfhe-cuda-backend/src/cuda_bind.rs @@ -1043,7 +1043,7 @@ extern "C" { mem_ptr: *mut *mut i8, ); - pub fn scratch_cuda_integer_radix_sum_ciphertexts_vec_kb_64( + pub fn scratch_cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64( streams: *const *mut c_void, gpu_indexes: *const u32, gpu_count: u32, @@ -1064,7 +1064,7 @@ extern "C" { allocate_gpu_memory: bool, ); - pub fn 
cuda_integer_radix_sum_ciphertexts_vec_kb_64( + pub fn cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64( streams: *const *mut c_void, gpu_indexes: *const u32, gpu_count: u32, @@ -1077,7 +1077,7 @@ extern "C" { num_blocks_in_radix: u32, ); - pub fn cleanup_cuda_integer_radix_sum_ciphertexts_vec( + pub fn cleanup_cuda_integer_radix_partial_sum_ciphertexts_vec( streams: *const *mut c_void, gpu_indexes: *const u32, gpu_count: u32, @@ -1210,4 +1210,53 @@ extern "C" { gpu_count: u32, mem_ptr: *mut *mut i8, ); + pub fn scratch_cuda_integer_compute_prefix_sum_hillis_steele_64( + streams: *const *mut c_void, + gpu_indexes: *const u32, + gpu_count: u32, + mem_ptr: *mut *mut i8, + input_lut: *const c_void, + lwe_dimension: u32, + glwe_dimension: u32, + polynomial_size: u32, + ks_level: u32, + ks_base_log: u32, + pbs_level: u32, + pbs_base_log: u32, + grouping_factor: u32, + num_blocks: u32, + message_modulus: u32, + carry_modulus: u32, + pbs_type: u32, + allocate_gpu_memory: bool, + ); + + pub fn cuda_integer_compute_prefix_sum_hillis_steele_64( + streams: *const *mut c_void, + gpu_indexes: *const u32, + gpu_count: u32, + output_radix_lwe: *mut c_void, + input_radix_lwe: *const c_void, + mem_ptr: *mut i8, + ksks: *const *mut c_void, + bsks: *const *mut c_void, + num_blocks: u32, + shift: u32, + ); + + pub fn cleanup_cuda_integer_compute_prefix_sum_hillis_steele_64( + streams: *const *mut c_void, + gpu_indexes: *const u32, + gpu_count: u32, + mem_ptr: *mut *mut i8, + ); + + pub fn cuda_integer_reverse_blocks_64_inplace( + streams: *const *mut c_void, + gpu_indexes: *const u32, + gpu_count: u32, + output_radix_lwe: *mut c_void, + num_blocks: u32, + lwe_size: u32, + ); } // extern "C" diff --git a/tfhe/src/integer/gpu/ciphertext/info.rs b/tfhe/src/integer/gpu/ciphertext/info.rs index e9c2431a9f..a2970b40fc 100644 --- a/tfhe/src/integer/gpu/ciphertext/info.rs +++ b/tfhe/src/integer/gpu/ciphertext/info.rs @@ -477,6 +477,8 @@ impl CudaRadixCiphertextInfo { &self, num_blocks: 
usize, ) -> Self { + assert!(num_blocks > 0); + let mut new_block_info = Self { blocks: Vec::with_capacity(self.blocks.len() + num_blocks), }; @@ -506,6 +508,8 @@ impl CudaRadixCiphertextInfo { } pub(crate) fn after_trim_radix_blocks_msb(&self, num_blocks: usize) -> Self { + assert!(num_blocks > 0); + let mut new_block_info = Self { blocks: Vec::with_capacity(self.blocks.len().saturating_sub(num_blocks)), }; diff --git a/tfhe/src/integer/gpu/mod.rs b/tfhe/src/integer/gpu/mod.rs index bb2151caa2..e4545da735 100644 --- a/tfhe/src/integer/gpu/mod.rs +++ b/tfhe/src/integer/gpu/mod.rs @@ -1901,7 +1901,7 @@ pub unsafe fn unchecked_scalar_rotate_right_integer_radix_kb_assign_async< /// /// - [CudaStreams::synchronize] __must__ be called after this function as soon as synchronization /// is required -pub unsafe fn unchecked_sum_ciphertexts_integer_radix_kb_assign_async< +pub unsafe fn unchecked_partial_sum_ciphertexts_integer_radix_kb_assign_async< T: UnsignedInteger, B: Numeric, >( @@ -1945,7 +1945,7 @@ pub unsafe fn unchecked_sum_ciphertexts_integer_radix_kb_assign_async< "GPU error: all data should reside on the same GPU." 
); let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - scratch_cuda_integer_radix_sum_ciphertexts_vec_kb_64( + scratch_cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64( streams.ptr.as_ptr(), streams.gpu_indexes.as_ptr(), streams.len() as u32, @@ -1965,7 +1965,7 @@ pub unsafe fn unchecked_sum_ciphertexts_integer_radix_kb_assign_async< pbs_type as u32, true, ); - cuda_integer_radix_sum_ciphertexts_vec_kb_64( + cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64( streams.ptr.as_ptr(), streams.gpu_indexes.as_ptr(), streams.len() as u32, @@ -1977,7 +1977,7 @@ pub unsafe fn unchecked_sum_ciphertexts_integer_radix_kb_assign_async< keyswitch_key.ptr.as_ptr(), num_blocks, ); - cleanup_cuda_integer_radix_sum_ciphertexts_vec( + cleanup_cuda_integer_radix_partial_sum_ciphertexts_vec( streams.ptr.as_ptr(), streams.gpu_indexes.as_ptr(), streams.len() as u32, @@ -2410,3 +2410,120 @@ pub unsafe fn unchecked_signed_overflowing_add_or_sub_radix_kb_assign_async< std::ptr::addr_of_mut!(mem_ptr), ); } + +#[allow(clippy::too_many_arguments)] +/// # Safety +/// +/// - [CudaStreams::synchronize] __must__ be called after this function as soon as synchronization +/// is required +pub unsafe fn compute_prefix_sum_hillis_steele_async( + streams: &CudaStreams, + radix_lwe_output: &mut CudaSliceMut, + radix_lwe_input: &CudaSlice, + input_lut: &[T], + bootstrapping_key: &CudaVec, + keyswitch_key: &CudaVec, + lwe_dimension: LweDimension, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + ks_level: DecompositionLevelCount, + ks_base_log: DecompositionBaseLog, + pbs_level: DecompositionLevelCount, + pbs_base_log: DecompositionBaseLog, + num_blocks: u32, + message_modulus: MessageModulus, + carry_modulus: CarryModulus, + pbs_type: PBSType, + grouping_factor: LweBskGroupingFactor, + shift: u32, +) { + assert_eq!( + streams.gpu_indexes[0], + radix_lwe_input.gpu_index(0), + "GPU error: all data should reside on the same GPU." 
+ ); + assert_eq!( + streams.gpu_indexes[0], + radix_lwe_output.gpu_index(0), + "GPU error: all data should reside on the same GPU." + ); + assert_eq!( + streams.gpu_indexes[0], + bootstrapping_key.gpu_index(0), + "GPU error: all data should reside on the same GPU." + ); + assert_eq!( + streams.gpu_indexes[0], + keyswitch_key.gpu_index(0), + "GPU error: all data should reside on the same GPU." + ); + let mut mem_ptr: *mut i8 = std::ptr::null_mut(); + scratch_cuda_integer_compute_prefix_sum_hillis_steele_64( + streams.ptr.as_ptr(), + streams.gpu_indexes.as_ptr(), + streams.len() as u32, + std::ptr::addr_of_mut!(mem_ptr), + input_lut.as_ptr().cast(), + lwe_dimension.0 as u32, + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + true, + ); + + cuda_integer_compute_prefix_sum_hillis_steele_64( + streams.ptr.as_ptr(), + streams.gpu_indexes.as_ptr(), + streams.len() as u32, + radix_lwe_output.as_mut_c_ptr(0), + radix_lwe_input.as_c_ptr(0), + mem_ptr, + keyswitch_key.ptr.as_ptr(), + bootstrapping_key.ptr.as_ptr(), + num_blocks, + shift, + ); + + cleanup_cuda_integer_compute_prefix_sum_hillis_steele_64( + streams.ptr.as_ptr(), + streams.gpu_indexes.as_ptr(), + streams.len() as u32, + std::ptr::addr_of_mut!(mem_ptr), + ); +} + +#[allow(clippy::too_many_arguments)] +/// # Safety +/// +/// - [CudaStreams::synchronize] __must__ be called after this function as soon as synchronization +/// is required +pub unsafe fn reverse_blocks_inplace_async( + streams: &CudaStreams, + radix_lwe_output: &mut CudaSliceMut, + num_blocks: u32, + lwe_size: u32, +) { + assert_eq!( + streams.gpu_indexes[0], + radix_lwe_output.gpu_index(0), + "GPU error: all data should reside on the same GPU." 
+ ); + if num_blocks > 1 { + cuda_integer_reverse_blocks_64_inplace( + streams.ptr.as_ptr(), + streams.gpu_indexes.as_ptr(), + streams.len() as u32, + radix_lwe_output.as_mut_c_ptr(0), + num_blocks, + lwe_size, + ); + } +} diff --git a/tfhe/src/integer/gpu/server_key/radix/add.rs b/tfhe/src/integer/gpu/server_key/radix/add.rs index b96585852e..06e6e832b2 100644 --- a/tfhe/src/integer/gpu/server_key/radix/add.rs +++ b/tfhe/src/integer/gpu/server_key/radix/add.rs @@ -8,8 +8,8 @@ use crate::integer::gpu::ciphertext::{ use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey}; use crate::integer::gpu::{ unchecked_add_integer_radix_assign_async, - unchecked_signed_overflowing_add_or_sub_radix_kb_assign_async, - unchecked_sum_ciphertexts_integer_radix_kb_assign_async, PBSType, + unchecked_partial_sum_ciphertexts_integer_radix_kb_assign_async, + unchecked_signed_overflowing_add_or_sub_radix_kb_assign_async, PBSType, }; use crate::shortint::ciphertext::NoiseLevel; @@ -231,10 +231,10 @@ impl CudaServerKey { /// /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must /// not be dropped until stream is synchronised - pub unsafe fn unchecked_sum_ciphertexts_assign_async( + pub unsafe fn unchecked_partial_sum_ciphertexts_assign_async( &self, - result: &mut CudaUnsignedRadixCiphertext, - ciphertexts: &[CudaUnsignedRadixCiphertext], + result: &mut T, + ciphertexts: &[T], streams: &CudaStreams, ) { if ciphertexts.is_empty() { @@ -275,7 +275,7 @@ impl CudaServerKey { match &self.bootstrapping_key { CudaBootstrappingKey::Classic(d_bsk) => { - unchecked_sum_ciphertexts_integer_radix_kb_assign_async( + unchecked_partial_sum_ciphertexts_integer_radix_kb_assign_async( streams, &mut result.as_mut().d_blocks.0.d_vec, &mut terms.0.d_vec, @@ -299,7 +299,7 @@ impl CudaServerKey { ); } CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { - unchecked_sum_ciphertexts_integer_radix_kb_assign_async( + 
unchecked_partial_sum_ciphertexts_integer_radix_kb_assign_async( streams, &mut result.as_mut().d_blocks.0.d_vec, &mut terms.0.d_vec, @@ -323,13 +323,46 @@ impl CudaServerKey { ); } } + self.propagate_single_carry_assign_async(result, streams); } - pub fn unchecked_sum_ciphertexts( + pub fn unchecked_sum_ciphertexts( &self, - ciphertexts: &[CudaUnsignedRadixCiphertext], + ciphertexts: &[T], streams: &CudaStreams, - ) -> Option { + ) -> T { + let mut result = unsafe { + self.unchecked_partial_sum_ciphertexts_async(ciphertexts, streams) + .unwrap() + }; + + unsafe { + self.propagate_single_carry_assign_async(&mut result, streams); + } + streams.synchronize(); + assert!(result.block_carries_are_empty()); + result + } + + pub fn unchecked_partial_sum_ciphertexts( + &self, + ciphertexts: &[T], + streams: &CudaStreams, + ) -> Option { + let result = unsafe { self.unchecked_partial_sum_ciphertexts_async(ciphertexts, streams) }; + streams.synchronize(); + result + } + + /// # Safety + /// + /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must + /// not be dropped until stream is synchronised + pub unsafe fn unchecked_partial_sum_ciphertexts_async( + &self, + ciphertexts: &[T], + streams: &CudaStreams, + ) -> Option { if ciphertexts.is_empty() { return None; } @@ -340,16 +373,16 @@ impl CudaServerKey { return Some(result); } - unsafe { self.unchecked_sum_ciphertexts_assign_async(&mut result, ciphertexts, streams) }; - streams.synchronize(); + self.unchecked_partial_sum_ciphertexts_assign_async(&mut result, ciphertexts, streams); + Some(result) } - pub fn sum_ciphertexts( + pub fn sum_ciphertexts( &self, - mut ciphertexts: Vec, + mut ciphertexts: Vec, streams: &CudaStreams, - ) -> Option { + ) -> Option { if ciphertexts.is_empty() { return None; } @@ -363,7 +396,7 @@ impl CudaServerKey { }); } - self.unchecked_sum_ciphertexts(&ciphertexts, streams) + Some(self.unchecked_sum_ciphertexts(&ciphertexts, streams)) } /// ```rust diff --git 
a/tfhe/src/integer/gpu/server_key/radix/comparison.rs b/tfhe/src/integer/gpu/server_key/radix/comparison.rs index 50e7ee9875..c1443193fa 100644 --- a/tfhe/src/integer/gpu/server_key/radix/comparison.rs +++ b/tfhe/src/integer/gpu/server_key/radix/comparison.rs @@ -41,7 +41,7 @@ impl CudaServerKey { stream, ); let mut block_info = ct_left.as_ref().info.blocks[0]; - block_info.degree = Degree::new(0); + block_info.degree = Degree::new(1); let ct_info = vec![block_info]; let ct_info = CudaRadixCiphertextInfo { blocks: ct_info }; diff --git a/tfhe/src/integer/gpu/server_key/radix/ilog2.rs b/tfhe/src/integer/gpu/server_key/radix/ilog2.rs new file mode 100644 index 0000000000..90fef9118d --- /dev/null +++ b/tfhe/src/integer/gpu/server_key/radix/ilog2.rs @@ -0,0 +1,1070 @@ +use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList; +use crate::core_crypto::gpu::vec::CudaVec; +use crate::core_crypto::gpu::CudaStreams; +use crate::core_crypto::prelude::{LweBskGroupingFactor, LweCiphertextCount}; +use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock; +use crate::integer::gpu::ciphertext::{ + CudaIntegerRadixCiphertext, CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext, +}; +use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey}; +use crate::integer::gpu::{ + apply_univariate_lut_kb_async, compute_prefix_sum_hillis_steele_async, + reverse_blocks_inplace_async, PBSType, +}; +use crate::integer::server_key::radix_parallel::ilog2::{BitValue, Direction}; + +impl CudaServerKey { + /// This function takes a ciphertext in radix representation + /// and returns a vec of blocks, where each blocks holds the number of leading_zeros/ones + /// + /// This contains the logic of making a block have 0 leading_ones/zeros if its preceding + /// block was not full of ones/zeros + /// # Safety + /// + /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must + /// not be dropped until stream is 
synchronised + pub(crate) unsafe fn prepare_count_of_consecutive_bits_async( + &self, + ct: &T, + direction: Direction, + bit_value: BitValue, + streams: &CudaStreams, + ) -> CudaLweCiphertextList { + assert!( + self.carry_modulus.0 >= self.message_modulus.0, + "A carry modulus as least as big as the message modulus is required" + ); + + let num_ct_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0; + + let lwe_size = ct.as_ref().d_blocks.0.lwe_dimension.to_lwe_size().0; + + // Allocate the necessary amount of memory + let mut output_radix = + CudaVec::new(num_ct_blocks * lwe_size, streams, streams.gpu_indexes[0]); + + let lut = match direction { + Direction::Trailing => self.generate_lookup_table(|x| { + let x = x % self.message_modulus.0 as u64; + + let mut count = 0; + for i in 0..self.message_modulus.0.ilog2() { + if (x >> i) & 1 == bit_value.opposite() as u64 { + break; + } + count += 1; + } + count + }), + Direction::Leading => self.generate_lookup_table(|x| { + let x = x % self.message_modulus.0 as u64; + + let mut count = 0; + for i in (0..self.message_modulus.0.ilog2()).rev() { + if (x >> i) & 1 == bit_value.opposite() as u64 { + break; + } + count += 1; + } + count + }), + }; + + output_radix.copy_from_gpu_async( + &ct.as_ref().d_blocks.0.d_vec, + streams, + streams.gpu_indexes[0], + ); + let mut output_slice = output_radix + .as_mut_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0]) + .unwrap(); + + let input_slice = ct + .as_ref() + .d_blocks + .0 + .d_vec + .as_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0]) + .unwrap(); + + // Assign to each block its number of leading/trailing zeros/ones + // in the message space + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + apply_univariate_lut_kb_async( + streams, + &mut output_slice, + &input_slice, + lut.acc.as_ref(), + &d_bsk.d_vec, + &self.key_switching_key.d_vec, + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + 
d_bsk.glwe_dimension, + d_bsk.polynomial_size, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + num_ct_blocks as u32, + self.message_modulus, + self.carry_modulus, + PBSType::Classical, + LweBskGroupingFactor(0), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + apply_univariate_lut_kb_async( + streams, + &mut output_slice, + &input_slice, + lut.acc.as_ref(), + &d_multibit_bsk.d_vec, + &self.key_switching_key.d_vec, + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + num_ct_blocks as u32, + self.message_modulus, + self.carry_modulus, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + ); + } + } + + if direction == Direction::Leading { + // Our blocks are from lsb to msb + // `leading` means starting from the msb, so we reverse block + // for the cum sum process done later + reverse_blocks_inplace_async( + streams, + &mut output_slice, + num_ct_blocks as u32, + lwe_size as u32, + ); + } + + // Use hillis-steele cumulative-sum algorithm + // Here, each block either keeps his value (the number of leading zeros) + // or becomes 0 if the preceding block + // had a bit set to one in it (leading_zeros != num bits in message) + let num_bits_in_message = self.message_modulus.0.ilog2() as u64; + let sum_lut = self.generate_lookup_table_bivariate( + |block_num_bit_count, more_significant_block_bit_count| { + if more_significant_block_bit_count == num_bits_in_message { + block_num_bit_count + } else { + 0 + } + }, + ); + + let mut cts = CudaLweCiphertextList::new( + ct.as_ref().d_blocks.lwe_dimension(), + LweCiphertextCount(num_ct_blocks * ct.as_ref().d_blocks.lwe_ciphertext_count().0), 
+ ct.as_ref().d_blocks.ciphertext_modulus(), + streams, + ); + + let input_radix_slice = output_radix + .as_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0]) + .unwrap(); + + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + compute_prefix_sum_hillis_steele_async( + streams, + &mut cts + .0 + .d_vec + .as_mut_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0]) + .unwrap(), + &input_radix_slice, + sum_lut.acc.acc.as_ref(), + &d_bsk.d_vec, + &self.key_switching_key.d_vec, + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + num_ct_blocks as u32, + self.message_modulus, + self.carry_modulus, + PBSType::Classical, + LweBskGroupingFactor(0), + 0u32, + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + compute_prefix_sum_hillis_steele_async( + streams, + &mut cts + .0 + .d_vec + .as_mut_slice(0..lwe_size * num_ct_blocks, streams.gpu_indexes[0]) + .unwrap(), + &input_radix_slice, + sum_lut.acc.acc.as_ref(), + &d_multibit_bsk.d_vec, + &self.key_switching_key.d_vec, + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + num_ct_blocks as u32, + self.message_modulus, + self.carry_modulus, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + 0u32, + ); + } + } + cts + } + + /// Counts how many consecutive bits there are + /// + /// The returned Ciphertexts has a variable size + /// i.e. It contains just the minimum number of block + /// needed to represent the maximum possible number of bits. 
+ /// # Safety + /// + /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must + /// not be dropped until stream is synchronised + pub(crate) unsafe fn count_consecutive_bits_async( + &self, + ct: &T, + direction: Direction, + bit_value: BitValue, + streams: &CudaStreams, + ) -> CudaUnsignedRadixCiphertext { + if ct.as_ref().d_blocks.0.d_vec.is_empty() { + return self.create_trivial_zero_radix(0, streams); + } + + let num_bits_in_message = self.message_modulus.0.ilog2(); + let original_num_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0; + + let num_bits_in_ciphertext = num_bits_in_message + .checked_mul(original_num_blocks as u32) + .expect("Number of bits encrypted exceeds u32::MAX"); + + let mut leading_count_per_blocks = self.prepare_count_of_consecutive_bits_async( + &ct.duplicate(streams), + direction, + bit_value, + streams, + ); + + // `num_bits_in_ciphertext` is the max value we want to represent + // its ilog2 + 1 gives us how many bits we need to be able to represent it. 
+ let counter_num_blocks = + (num_bits_in_ciphertext.ilog2() + 1).div_ceil(self.message_modulus.0.ilog2()) as usize; + + let lwe_dimension = ct.as_ref().d_blocks.lwe_dimension(); + + let lwe_size = lwe_dimension.to_lwe_size().0; + let mut cts = Vec::::with_capacity( + ct.as_ref().d_blocks.lwe_ciphertext_count().0, + ); + for i in 0..ct.as_ref().d_blocks.lwe_ciphertext_count().0 { + let mut new_item: CudaUnsignedRadixCiphertext = + self.create_trivial_zero_radix(counter_num_blocks, streams); + let mut dest_slice = new_item + .as_mut() + .d_blocks + .0 + .d_vec + .as_mut_slice(0..lwe_size, streams.gpu_indexes[0]) + .unwrap(); + + let src_slice = leading_count_per_blocks + .0 + .d_vec + .as_mut_slice((i * lwe_size)..((i + 1) * lwe_size), streams.gpu_indexes[0]) + .unwrap(); + dest_slice.copy_from_gpu_async(&src_slice, streams, 0); + cts.push(new_item); + } + + self.unchecked_sum_ciphertexts(&cts, streams) + } + + //============================================================================================== + // Unchecked + //============================================================================================== + + /// See [Self::trailing_zeros] + /// + /// Expects ct to have clean carries + pub fn unchecked_trailing_zeros( + &self, + ct: &T, + streams: &CudaStreams, + ) -> CudaUnsignedRadixCiphertext + where + T: CudaIntegerRadixCiphertext, + { + let res = unsafe { + self.count_consecutive_bits_async(ct, Direction::Trailing, BitValue::Zero, streams) + }; + streams.synchronize(); + res + } + + /// See [Self::trailing_ones] + /// + /// Expects ct to have clean carries + pub fn unchecked_trailing_ones( + &self, + ct: &T, + streams: &CudaStreams, + ) -> CudaUnsignedRadixCiphertext + where + T: CudaIntegerRadixCiphertext, + { + let res = unsafe { + self.count_consecutive_bits_async(ct, Direction::Trailing, BitValue::One, streams) + }; + streams.synchronize(); + res + } + + /// See [Self::leading_zeros] + /// + /// Expects ct to have clean carries + pub fn 
unchecked_leading_zeros( + &self, + ct: &T, + streams: &CudaStreams, + ) -> CudaUnsignedRadixCiphertext + where + T: CudaIntegerRadixCiphertext, + { + let res = unsafe { + self.count_consecutive_bits_async(ct, Direction::Leading, BitValue::Zero, streams) + }; + streams.synchronize(); + res + } + + /// See [Self::leading_ones] + /// + /// Expects ct to have clean carries + pub fn unchecked_leading_ones( + &self, + ct: &T, + streams: &CudaStreams, + ) -> CudaUnsignedRadixCiphertext + where + T: CudaIntegerRadixCiphertext, + { + let res = unsafe { + self.count_consecutive_bits_async(ct, Direction::Leading, BitValue::One, streams) + }; + streams.synchronize(); + res + } + + /// Returns the base 2 logarithm of the number, rounded down. + /// + /// See [Self::ilog2] for an example + /// + /// Expects ct to have clean carries + pub fn unchecked_ilog2(&self, ct: &T, streams: &CudaStreams) -> CudaUnsignedRadixCiphertext + where + T: CudaIntegerRadixCiphertext, + { + let res = unsafe { self.unchecked_ilog2_async(ct, streams) }; + streams.synchronize(); + res + } + + /// # Safety + /// + /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must + /// not be dropped until stream is synchronised + pub unsafe fn unchecked_ilog2_async( + &self, + ct: &T, + streams: &CudaStreams, + ) -> CudaUnsignedRadixCiphertext + where + T: CudaIntegerRadixCiphertext, + { + if ct.as_ref().d_blocks.0.d_vec.is_empty() { + return self + .create_trivial_zero_radix(ct.as_ref().d_blocks.lwe_ciphertext_count().0, streams); + } + + let num_bits_in_message = self.message_modulus.0.ilog2(); + let original_num_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0; + + let num_bits_in_ciphertext = num_bits_in_message + .checked_mul(original_num_blocks as u32) + .expect("Number of bits encrypted exceeds u32::MAX"); + + // `num_bits_in_ciphertext-1` is the max value we want to represent + // its ilog2 + 1 gives us how many bits we need to be able to represent it. 
+ // We add `1` to this number as we are going to use signed numbers later + // + // The ilog2 of a number that is on n bits, is in range 1..=n-1 + let counter_num_blocks = ((num_bits_in_ciphertext - 1).ilog2() + 1 + 1) + .div_ceil(self.message_modulus.0.ilog2()) as usize; + + // 11111000 + // x.ilog2() = (x.num_bit() - 1) - x.leading_zeros() + // - (x.num_bit() - 1) is trivially known + // - we can get leading zeros via a sum + // + // However, the sum include a full propagation, thus the subtraction + // will add another full propagation which is costly. + // + // However, we can do better: + // let N = (x.num_bit() - 1) + // let L0 = x.leading_zeros() + // ``` + // x.ilog2() = N - L0 + // x.ilog2() = -(-(N - L0)) + // x.ilog2() = -(-N + L0) + // ``` + // Since N is a clear number, getting -N is free, + // meaning -N + L0 where L0 is actually `sum(L0[b0], .., L0[num_blocks-1])` + // can be done with `sum(-N, L0[b0], .., L0[num_blocks-1]), by switching to signed + // numbers. + // + // Also, to do -(-N + L0) aka -sum(-N, L0[b0], .., L0[num_blocks-1]) + // we can make the sum not return a fully propagated result, + // and extract message/carry blocks while negating them at the same time + // using the fact that in twos complement -X = bitnot(X) + 1 + // so given a non propagated `C`, we can compute the fully propagated `PC` + // PC = bitnot(message(C)) + bitnot(blockshift(carry(C), 1)) + 2 + + let mut leading_zeros_per_blocks = self.prepare_count_of_consecutive_bits_async( + &ct.duplicate(streams), + Direction::Leading, + BitValue::Zero, + streams, + ); + let lwe_dimension = ct.as_ref().d_blocks.lwe_dimension(); + + let lwe_size = lwe_dimension.to_lwe_size().0; + let capacity = (leading_zeros_per_blocks.0.d_vec.len() / lwe_size) + 1; + let mut cts = Vec::::with_capacity(capacity); + + for i in 0..(capacity - 1) { + let mut new_item: CudaSignedRadixCiphertext = + self.create_trivial_zero_radix(counter_num_blocks, streams); + + let mut dest_slice = new_item + 
.as_mut() + .d_blocks + .0 + .d_vec + .as_mut_slice(0..lwe_size, streams.gpu_indexes[0]) + .unwrap(); + + let src_slice = leading_zeros_per_blocks + .0 + .d_vec + .as_mut_slice((i * lwe_size)..((i + 1) * lwe_size), streams.gpu_indexes[0]) + .unwrap(); + dest_slice.copy_from_gpu_async(&src_slice, streams, 0); + cts.push(new_item); + } + + let new_trivial: CudaSignedRadixCiphertext = self.create_trivial_radix( + -(num_bits_in_ciphertext as i32 - 1i32), + counter_num_blocks, + streams, + ); + + cts.push(new_trivial); + + let mut result = self + .unchecked_partial_sum_ciphertexts(&cts, streams) + .expect("internal error, empty ciphertext count"); + + // This is the part where we extract message and carry blocks + // while inverting their bits + let lut_a = self.generate_lookup_table(|x| { + // extract message + let x = x % self.message_modulus.0 as u64; + // bitnot the message + (!x) % self.message_modulus.0 as u64 + }); + + let mut message_blocks = CudaLweCiphertextList::new( + lwe_dimension, + LweCiphertextCount(counter_num_blocks), + ct.as_ref().d_blocks.ciphertext_modulus(), + streams, + ); + let mut message_blocks_slice = message_blocks + .0 + .d_vec + .as_mut_slice(0..lwe_size * counter_num_blocks, streams.gpu_indexes[0]) + .unwrap(); + let result_slice = result + .as_mut() + .d_blocks + .0 + .d_vec + .as_slice(0..lwe_size * counter_num_blocks, streams.gpu_indexes[0]) + .unwrap(); + + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + apply_univariate_lut_kb_async( + streams, + &mut message_blocks_slice, + &result_slice, + lut_a.acc.as_ref(), + &d_bsk.d_vec, + &self.key_switching_key.d_vec, + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + counter_num_blocks as u32, + self.message_modulus, + 
self.carry_modulus, + PBSType::Classical, + LweBskGroupingFactor(0), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + apply_univariate_lut_kb_async( + streams, + &mut message_blocks_slice, + &result_slice, + lut_a.acc.as_ref(), + &d_multibit_bsk.d_vec, + &self.key_switching_key.d_vec, + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + counter_num_blocks as u32, + self.message_modulus, + self.carry_modulus, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + ); + } + } + + let lut_b = self.generate_lookup_table(|x| { + // extract carry + let x = x / self.message_modulus.0 as u64; + // bitnot the carry + (!x) % self.message_modulus.0 as u64 + }); + + let mut carry_blocks = CudaLweCiphertextList::new( + lwe_dimension, + LweCiphertextCount( + counter_num_blocks, //* ct.as_ref().d_blocks.lwe_ciphertext_count().0, + ), + ct.as_ref().d_blocks.ciphertext_modulus(), + streams, + ); + + let mut trivial_last_block: CudaSignedRadixCiphertext = + self.create_trivial_radix((self.message_modulus.0 - 1) as u64, 1, streams); + let trivial_last_block_slice = trivial_last_block + .as_mut() + .d_blocks + .0 + .d_vec + .as_mut_slice(0..lwe_size, streams.gpu_indexes[0]) + .unwrap(); + + let mut carry_blocks_last = carry_blocks + .0 + .d_vec + .as_mut_slice( + lwe_size * (counter_num_blocks - 1)..lwe_size * counter_num_blocks, + streams.gpu_indexes[0], + ) + .unwrap(); + + carry_blocks_last.copy_from_gpu_async(&trivial_last_block_slice, streams, 0u32); + + let mut carry_blocks_slice = carry_blocks + .0 + .d_vec + .as_mut_slice(0..lwe_size * counter_num_blocks, streams.gpu_indexes[0]) + .unwrap(); + unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + 
apply_univariate_lut_kb_async( + streams, + &mut carry_blocks_slice, + &result_slice, + lut_b.acc.as_ref(), + &d_bsk.d_vec, + &self.key_switching_key.d_vec, + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + counter_num_blocks as u32 - 1, + self.message_modulus, + self.carry_modulus, + PBSType::Classical, + LweBskGroupingFactor(0), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + apply_univariate_lut_kb_async( + streams, + &mut carry_blocks_slice, + &result_slice, + lut_b.acc.as_ref(), + &d_multibit_bsk.d_vec, + &self.key_switching_key.d_vec, + self.key_switching_key + .output_key_lwe_size() + .to_lwe_dimension(), + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + self.key_switching_key.decomposition_level_count(), + self.key_switching_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + counter_num_blocks as u32 - 1, + self.message_modulus, + self.carry_modulus, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + ); + } + } + } + + let mut ciphertexts = Vec::::with_capacity(3); + + let mut new_item: CudaSignedRadixCiphertext = + self.create_trivial_zero_radix(counter_num_blocks, streams); + let mut dest_slice = new_item + .as_mut() + .d_blocks + .0 + .d_vec + .as_mut_slice(0..counter_num_blocks * lwe_size, streams.gpu_indexes[0]) + .unwrap(); + + let src_slice = message_blocks + .0 + .d_vec + .as_mut_slice(0..(counter_num_blocks * lwe_size), streams.gpu_indexes[0]) + .unwrap(); + + dest_slice.copy_from_gpu_async(&src_slice, streams, 0); + + ciphertexts.push(new_item); + + let mut new_item: CudaSignedRadixCiphertext = + self.create_trivial_zero_radix(counter_num_blocks, streams); + let mut dest_slice = new_item + .as_mut() + .d_blocks + .0 + .d_vec + 
.as_mut_slice(0..counter_num_blocks * lwe_size, streams.gpu_indexes[0]) + .unwrap(); + + let src_slice = carry_blocks + .0 + .d_vec + .as_mut_slice(0..(counter_num_blocks * lwe_size), streams.gpu_indexes[0]) + .unwrap(); + + dest_slice.copy_from_gpu_async(&src_slice, streams, 0); + + ciphertexts.push(new_item); + + let trivial_ct: CudaSignedRadixCiphertext = + self.create_trivial_radix(2u32, counter_num_blocks, streams); + ciphertexts.push(trivial_ct); + + let result = self.sum_ciphertexts(ciphertexts, streams).unwrap(); + + self.cast_to_unsigned(result, counter_num_blocks, streams) + } + + /// Returns the number of trailing zeros in the binary representation of `ct` + /// + /// The returned Ciphertexts has a variable size + /// i.e. It contains just the minimum number of block + /// needed to represent the maximum possible number of bits. + /// + /// This is a default function, it will internally clone the ciphertext if it has + /// non propagated carries, and it will output a ciphertext without any carries. 
+ /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::gpu::CudaStreams; + /// use tfhe::integer::gpu::ciphertext::CudaSignedRadixCiphertext; + /// use tfhe::integer::gpu::gen_keys_gpu; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// let number_of_blocks = 4; + /// + /// let gpu_index = 0; + /// let mut stream = CudaStreams::new_single_gpu(gpu_index); + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, &mut stream); + /// + /// let msg = -4i8; + /// + /// // Encrypt two messages + /// let ctxt = cks.encrypt_signed_radix(msg, number_of_blocks); + /// + /// let mut d_ctxt = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ctxt, &mut stream); + /// + /// // Compute homomorphically trailing zeros + /// let mut d_ct_res = sks.trailing_zeros(&d_ctxt, &stream); + /// + /// // Decrypt + /// let ct_res = d_ct_res.to_radix_ciphertext(&mut stream); + /// let res: u32 = cks.decrypt_radix(&ct_res); + /// assert_eq!(res, msg.trailing_zeros()); + /// ``` + pub fn trailing_zeros(&self, ct: &T, streams: &CudaStreams) -> CudaUnsignedRadixCiphertext + where + T: CudaIntegerRadixCiphertext, + { + let mut tmp; + let ct = if ct.block_carries_are_empty() { + ct + } else { + tmp = ct.duplicate(streams); + unsafe { + self.full_propagate_assign_async(&mut tmp, streams); + } + &tmp + }; + self.unchecked_trailing_zeros(ct, streams) + } + + /// Returns the number of trailing ones in the binary representation of `ct` + /// + /// The returned Ciphertexts has a variable size + /// i.e. It contains just the minimum number of block + /// needed to represent the maximum possible number of bits. + /// + /// This is a default function, it will internally clone the ciphertext if it has + /// non propagated carries, and it will output a ciphertext without any carries. 
+ /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::gpu::CudaStreams; + /// use tfhe::integer::gpu::ciphertext::CudaSignedRadixCiphertext; + /// use tfhe::integer::gpu::gen_keys_gpu; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// let number_of_blocks = 4; + /// + /// let gpu_index = 0; + /// let mut stream = CudaStreams::new_single_gpu(gpu_index); + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, &mut stream); + /// + /// let msg = -4i8; + /// + /// // Encrypt two messages + /// let ctxt = cks.encrypt_signed_radix(msg, number_of_blocks); + /// + /// let mut d_ctxt = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ctxt, &mut stream); + /// + /// // Compute homomorphically trailing ones + /// let mut d_ct_res = sks.trailing_ones(&d_ctxt, &stream); + /// + /// // Decrypt + /// let ct_res = d_ct_res.to_radix_ciphertext(&mut stream); + /// let res: u32 = cks.decrypt_radix(&ct_res); + /// assert_eq!(res, msg.trailing_ones()); + /// ``` + pub fn trailing_ones(&self, ct: &T, streams: &CudaStreams) -> CudaUnsignedRadixCiphertext + where + T: CudaIntegerRadixCiphertext, + { + let mut tmp; + let ct = if ct.block_carries_are_empty() { + ct + } else { + tmp = ct.duplicate(streams); + unsafe { + self.full_propagate_assign_async(&mut tmp, streams); + } + &tmp + }; + self.unchecked_trailing_ones(ct, streams) + } + + /// Returns the number of leading zeros in the binary representation of `ct` + /// + /// The returned Ciphertexts has a variable size + /// i.e. It contains just the minimum number of block + /// needed to represent the maximum possible number of bits. + /// + /// This is a default function, it will internally clone the ciphertext if it has + /// non propagated carries, and it will output a ciphertext without any carries. 
+ /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::gpu::CudaStreams; + /// use tfhe::integer::gpu::ciphertext::CudaSignedRadixCiphertext; + /// use tfhe::integer::gpu::gen_keys_gpu; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// let number_of_blocks = 4; + /// + /// let gpu_index = 0; + /// let mut stream = CudaStreams::new_single_gpu(gpu_index); + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, &mut stream); + /// + /// let msg = -4i8; + /// + /// // Encrypt two messages + /// let ctxt = cks.encrypt_signed_radix(msg, number_of_blocks); + /// + /// let mut d_ctxt = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ctxt, &mut stream); + /// + /// // Compute homomorphically leading zeros + /// let mut d_ct_res = sks.leading_zeros(&d_ctxt, &stream); + /// + /// // Decrypt + /// let ct_res = d_ct_res.to_radix_ciphertext(&mut stream); + /// let res: u32 = cks.decrypt_radix(&ct_res); + /// assert_eq!(res, msg.leading_zeros()); + /// ``` + pub fn leading_zeros(&self, ct: &T, streams: &CudaStreams) -> CudaUnsignedRadixCiphertext + where + T: CudaIntegerRadixCiphertext, + { + let mut tmp; + let ct = if ct.block_carries_are_empty() { + ct + } else { + tmp = ct.duplicate(streams); + unsafe { + self.full_propagate_assign_async(&mut tmp, streams); + } + &tmp + }; + self.unchecked_leading_zeros(ct, streams) + } + + /// Returns the number of leading ones in the binary representation of `ct` + /// + /// The returned Ciphertexts has a variable size + /// i.e. It contains just the minimum number of block + /// needed to represent the maximum possible number of bits. + /// + /// This is a default function, it will internally clone the ciphertext if it has + /// non propagated carries, and it will output a ciphertext without any carries. 
+ /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::gpu::CudaStreams; + /// use tfhe::integer::gpu::ciphertext::CudaSignedRadixCiphertext; + /// use tfhe::integer::gpu::gen_keys_gpu; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// let number_of_blocks = 4; + /// + /// let gpu_index = 0; + /// let mut stream = CudaStreams::new_single_gpu(gpu_index); + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, &mut stream); + /// + /// let msg = -4i8; + /// + /// // Encrypt two messages + /// let ctxt = cks.encrypt_signed_radix(msg, number_of_blocks); + /// + /// let mut d_ctxt = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ctxt, &mut stream); + /// + /// // Compute homomorphically leading ones + /// let mut d_ct_res = sks.leading_ones(&d_ctxt, &stream); + /// + /// // Decrypt + /// let ct_res = d_ct_res.to_radix_ciphertext(&mut stream); + /// let res: u32 = cks.decrypt_radix(&ct_res); + /// assert_eq!(res, msg.leading_ones()); + /// ``` + pub fn leading_ones(&self, ct: &T, streams: &CudaStreams) -> CudaUnsignedRadixCiphertext + where + T: CudaIntegerRadixCiphertext, + { + let mut tmp; + let ct = if ct.block_carries_are_empty() { + ct + } else { + tmp = ct.duplicate(streams); + unsafe { + self.full_propagate_assign_async(&mut tmp, streams); + } + &tmp + }; + self.unchecked_leading_ones(ct, streams) + } + + /// Returns the base 2 logarithm of the number, rounded down. 
+ /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::gpu::CudaStreams; + /// use tfhe::integer::gpu::ciphertext::CudaSignedRadixCiphertext; + /// use tfhe::integer::gpu::gen_keys_gpu; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// let number_of_blocks = 4; + /// + /// let gpu_index = 0; + /// let mut stream = CudaStreams::new_single_gpu(gpu_index); + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, &mut stream); + /// + /// let msg = 5i8; + /// + /// // Encrypt two messages + /// let ctxt = cks.encrypt_signed_radix(msg, number_of_blocks); + /// + /// let mut d_ctxt = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ctxt, &mut stream); + /// + /// // Compute homomorphically a log2 + /// let mut d_ct_res = sks.ilog2(&d_ctxt, &stream); + /// + /// // Decrypt + /// let ct_res = d_ct_res.to_radix_ciphertext(&mut stream); + /// let res: u32 = cks.decrypt_radix(&ct_res); + /// assert_eq!(res, msg.ilog2()); + /// ``` + pub fn ilog2(&self, ct: &T, streams: &CudaStreams) -> CudaUnsignedRadixCiphertext + where + T: CudaIntegerRadixCiphertext, + { + let mut tmp; + let ct = if ct.block_carries_are_empty() { + ct + } else { + tmp = ct.duplicate(streams); + unsafe { + self.full_propagate_assign_async(&mut tmp, streams); + } + &tmp + }; + + self.unchecked_ilog2(ct, streams) + } + + /// Returns the base 2 logarithm of the number, rounded down. + /// + /// Also returns a BooleanBlock, encrypting true (1) if the result is + /// valid (input is > 0), otherwise 0. 
+ /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::gpu::CudaStreams; + /// use tfhe::integer::gpu::ciphertext::CudaSignedRadixCiphertext; + /// use tfhe::integer::gpu::gen_keys_gpu; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// let number_of_blocks = 4; + /// + /// let gpu_index = 0; + /// let mut stream = CudaStreams::new_single_gpu(gpu_index); + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, &mut stream); + /// + /// let msg = 5i8; + /// + /// // Encrypt two messages + /// let ctxt = cks.encrypt_signed_radix(msg, number_of_blocks); + /// + /// let mut d_ctxt = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ctxt, &mut stream); + /// // Compute homomorphically a log2 and a check if input is valid + /// let (mut d_ct_res, mut d_is_oks) = sks.checked_ilog2(&d_ctxt, &stream); + /// + /// // Decrypt + /// let ct_res = d_ct_res.to_radix_ciphertext(&mut stream); + /// let res: u32 = cks.decrypt_radix(&ct_res); + /// assert_eq!(res, msg.ilog2()); + /// let is_oks = d_is_oks.to_boolean_block(&mut stream); + /// let is_ok = cks.decrypt_bool(&is_oks); + /// assert!(is_ok); + + pub fn checked_ilog2( + &self, + ct: &T, + streams: &CudaStreams, + ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) + where + T: CudaIntegerRadixCiphertext, + { + let mut tmp; + let ct = if ct.block_carries_are_empty() { + ct + } else { + tmp = ct.duplicate(streams); + unsafe { + self.full_propagate_assign_async(&mut tmp, streams); + } + &tmp + }; + + (self.ilog2(ct, streams), self.scalar_gt(ct, 0, streams)) + } +} diff --git a/tfhe/src/integer/gpu/server_key/radix/mod.rs b/tfhe/src/integer/gpu/server_key/radix/mod.rs index 6a64ef19e6..a4469977b0 100644 --- a/tfhe/src/integer/gpu/server_key/radix/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/mod.rs @@ -19,7 +19,7 @@ use crate::integer::gpu::{ }; use crate::shortint::ciphertext::{Degree, NoiseLevel}; 
use crate::shortint::engine::fill_accumulator; -use crate::shortint::server_key::LookupTableOwned; +use crate::shortint::server_key::{BivariateLookupTableOwned, LookupTableOwned}; use crate::shortint::PBSOrder; mod add; @@ -27,6 +27,7 @@ mod bitwise_op; mod cmux; mod comparison; mod div_mod; +mod ilog2; mod mul; mod neg; mod rotate; @@ -40,6 +41,7 @@ mod scalar_shift; mod scalar_sub; mod shift; mod sub; + #[cfg(test)] mod tests_signed; #[cfg(test)] @@ -404,6 +406,9 @@ impl CudaServerKey { num_blocks: usize, streams: &CudaStreams, ) -> T { + if num_blocks == 0 { + return ct.duplicate(streams); + } let new_num_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0 + num_blocks; let ciphertext_modulus = ct.as_ref().d_blocks.ciphertext_modulus(); let lwe_size = ct.as_ref().d_blocks.lwe_dimension().to_lwe_size(); @@ -475,6 +480,9 @@ impl CudaServerKey { num_blocks: usize, streams: &CudaStreams, ) -> T { + if num_blocks == 0 { + return ct.duplicate(streams); + } let new_num_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0 + num_blocks; let ciphertext_modulus = ct.as_ref().d_blocks.ciphertext_modulus(); let lwe_size = ct.as_ref().d_blocks.lwe_dimension().to_lwe_size(); @@ -540,6 +548,9 @@ impl CudaServerKey { num_blocks: usize, streams: &CudaStreams, ) -> T { + if num_blocks == 0 { + return ct.duplicate(streams); + } let new_num_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0 - num_blocks; let ciphertext_modulus = ct.as_ref().d_blocks.ciphertext_modulus(); let lwe_size = ct.as_ref().d_blocks.lwe_dimension().to_lwe_size(); @@ -607,6 +618,9 @@ impl CudaServerKey { num_blocks: usize, streams: &CudaStreams, ) -> T { + if num_blocks == 0 { + return ct.duplicate(streams); + } let new_num_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0 - num_blocks; let ciphertext_modulus = ct.as_ref().d_blocks.ciphertext_modulus(); let lwe_size = ct.as_ref().d_blocks.lwe_dimension().to_lwe_size(); @@ -661,6 +675,30 @@ impl CudaServerKey { } } + /// Generates a bivariate 
accumulator + pub(crate) fn generate_lookup_table_bivariate(&self, f: F) -> BivariateLookupTableOwned + where + F: Fn(u64, u64) -> u64, + { + // Depending on the factor used, rhs and / or lhs may have carries + // (degree >= message_modulus) which is why we need to apply the message_modulus + // to clear them + let message_modulus = self.message_modulus.0 as u64; + let factor_u64 = message_modulus; + let wrapped_f = |input: u64| -> u64 { + let lhs = (input / factor_u64) % message_modulus; + let rhs = (input % factor_u64) % message_modulus; + + f(lhs, rhs) + }; + let accumulator = self.generate_lookup_table(wrapped_f); + + BivariateLookupTableOwned { + acc: accumulator, + ct_right_modulus: self.message_modulus, + } + } + /// # Safety /// /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must @@ -671,6 +709,9 @@ impl CudaServerKey { num_blocks: usize, streams: &CudaStreams, ) -> T { + if num_blocks == 0 { + return ct.duplicate(streams); + } let message_modulus = self.message_modulus.0 as u64; let num_bits_in_block = message_modulus.ilog2(); let padding_block_creator_lut = self.generate_lookup_table(|x| { diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs index 345143d05a..ef96b41914 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_comparison.rs @@ -164,7 +164,7 @@ impl CudaServerKey { streams, ); let mut block_info = ct.as_ref().info.blocks[0]; - block_info.degree = Degree::new(0); + block_info.degree = Degree::new(1); let ct_info = vec![block_info]; let ct_info = CudaRadixCiphertextInfo { blocks: ct_info }; diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs index bffdec8a51..b60305d610 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs +++ 
b/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs @@ -2,6 +2,7 @@ pub(crate) mod test_add; pub(crate) mod test_bitwise_op; pub(crate) mod test_cmux; pub(crate) mod test_comparison; +pub(crate) mod test_ilog2; pub(crate) mod test_mul; pub(crate) mod test_neg; pub(crate) mod test_rotate; @@ -50,6 +51,60 @@ where gpu_result.to_signed_radix_ciphertext(&context.streams) } } +//For ilog2 +impl<'a, F> FunctionExecutor<&'a SignedRadixCiphertext, RadixCiphertext> for GpuFunctionExecutor +where + F: Fn(&CudaServerKey, &CudaSignedRadixCiphertext, &CudaStreams) -> CudaUnsignedRadixCiphertext, +{ + fn setup(&mut self, cks: &RadixClientKey, sks: Arc) { + self.setup_from_keys(cks, &sks); + } + + fn execute(&mut self, input: &'a SignedRadixCiphertext) -> RadixCiphertext { + let context = self + .context + .as_ref() + .expect("setup was not properly called"); + + let d_ctxt = + CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input, &context.streams); + + let gpu_result = (self.func)(&context.sks, &d_ctxt, &context.streams); + + gpu_result.to_radix_ciphertext(&context.streams) + } +} + +impl<'a, F> FunctionExecutor<&'a SignedRadixCiphertext, (RadixCiphertext, BooleanBlock)> + for GpuFunctionExecutor +where + F: Fn( + &CudaServerKey, + &CudaSignedRadixCiphertext, + &CudaStreams, + ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock), +{ + fn setup(&mut self, cks: &RadixClientKey, sks: Arc) { + self.setup_from_keys(cks, &sks); + } + + fn execute(&mut self, input: &'a SignedRadixCiphertext) -> (RadixCiphertext, BooleanBlock) { + let context = self + .context + .as_ref() + .expect("setup was not properly called"); + + let d_ctxt = + CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input, &context.streams); + + let gpu_result = (self.func)(&context.sks, &d_ctxt, &context.streams); + + ( + gpu_result.0.to_radix_ciphertext(&context.streams), + gpu_result.1.to_boolean_block(&context.streams), + ) + } +} /// For default/unchecked binary functions impl<'a, F> diff 
--git a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_ilog2.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_ilog2.rs new file mode 100644 index 0000000000..7f61c0850c --- /dev/null +++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_ilog2.rs @@ -0,0 +1,64 @@ +use crate::integer::gpu::server_key::radix::tests_unsigned::{ + create_gpu_parametrized_test, GpuFunctionExecutor, +}; +use crate::integer::gpu::CudaServerKey; +use crate::integer::server_key::radix_parallel::tests_signed::test_ilog2::{ + default_checked_ilog2_test, default_ilog2_test, default_leading_ones_test, + default_leading_zeros_test, default_trailing_ones_test, default_trailing_zeros_test, +}; +use crate::shortint::parameters::*; + +create_gpu_parametrized_test!(integer_signed_default_trailing_zeros); +create_gpu_parametrized_test!(integer_signed_default_trailing_ones); +create_gpu_parametrized_test!(integer_signed_default_leading_zeros); +create_gpu_parametrized_test!(integer_signed_default_leading_ones); +create_gpu_parametrized_test!(integer_signed_default_ilog2); +create_gpu_parametrized_test!(integer_signed_default_checked_ilog2); + +fn integer_signed_default_trailing_zeros
<P>
(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::trailing_zeros); + default_trailing_zeros_test(param, executor); +} + +fn integer_signed_default_trailing_ones
<P>
(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::trailing_ones); + default_trailing_ones_test(param, executor); +} + +fn integer_signed_default_leading_zeros
<P>
(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::leading_zeros); + default_leading_zeros_test(param, executor); +} + +fn integer_signed_default_leading_ones
<P>
(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::leading_ones); + default_leading_ones_test(param, executor); +} + +fn integer_signed_default_ilog2
<P>
(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::ilog2); + default_ilog2_test(param, executor); +} + +fn integer_signed_default_checked_ilog2
<P>
(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::checked_ilog2); + default_checked_ilog2_test(param, executor); +} diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs index 95d27ec390..9db1dfad72 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs @@ -3,6 +3,7 @@ pub(crate) mod test_bitwise_op; pub(crate) mod test_cmux; pub(crate) mod test_comparison; pub(crate) mod test_div_mod; +pub(crate) mod test_ilog2; pub(crate) mod test_mul; pub(crate) mod test_neg; pub(crate) mod test_rotate; @@ -351,6 +352,38 @@ where } } +/// For ilog operation +impl<'a, F> FunctionExecutor<&'a RadixCiphertext, (RadixCiphertext, BooleanBlock)> + for GpuFunctionExecutor +where + F: Fn( + &CudaServerKey, + &CudaUnsignedRadixCiphertext, + &CudaStreams, + ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock), +{ + fn setup(&mut self, cks: &RadixClientKey, sks: Arc) { + self.setup_from_keys(cks, &sks); + } + + fn execute(&mut self, input: &'a RadixCiphertext) -> (RadixCiphertext, BooleanBlock) { + let context = self + .context + .as_ref() + .expect("setup was not properly called"); + + let d_ctxt_1: CudaUnsignedRadixCiphertext = + CudaUnsignedRadixCiphertext::from_radix_ciphertext(input, &context.streams); + + let d_res = (self.func)(&context.sks, &d_ctxt_1, &context.streams); + + ( + d_res.0.to_radix_ciphertext(&context.streams), + d_res.1.to_boolean_block(&context.streams), + ) + } +} + impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), (RadixCiphertext, RadixCiphertext)> for GpuFunctionExecutor diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_ilog2.rs b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_ilog2.rs new file mode 100644 index 0000000000..6b3c83abe8 --- /dev/null +++ 
b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/test_ilog2.rs @@ -0,0 +1,64 @@ +use crate::integer::gpu::server_key::radix::tests_unsigned::{ + create_gpu_parametrized_test, GpuFunctionExecutor, +}; +use crate::integer::gpu::CudaServerKey; +use crate::integer::server_key::radix_parallel::tests_unsigned::test_ilog2::{ + default_checked_ilog2_test, default_ilog2_test, default_leading_ones_test, + default_leading_zeros_test, default_trailing_ones_test, default_trailing_zeros_test, +}; +use crate::shortint::parameters::*; + +create_gpu_parametrized_test!(integer_default_trailing_zeros); +create_gpu_parametrized_test!(integer_default_trailing_ones); +create_gpu_parametrized_test!(integer_default_leading_zeros); +create_gpu_parametrized_test!(integer_default_leading_ones); +create_gpu_parametrized_test!(integer_default_ilog2); +create_gpu_parametrized_test!(integer_default_checked_ilog2); + +fn integer_default_trailing_zeros
<P>
(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::trailing_zeros); + default_trailing_zeros_test(param, executor); +} + +fn integer_default_trailing_ones
<P>
(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::trailing_ones); + default_trailing_ones_test(param, executor); +} + +fn integer_default_leading_zeros
<P>
(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::leading_zeros); + default_leading_zeros_test(param, executor); +} + +fn integer_default_leading_ones
<P>
(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::leading_ones); + default_leading_ones_test(param, executor); +} + +fn integer_default_ilog2
<P>
(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::ilog2); + default_ilog2_test(param, executor); +} + +fn integer_default_checked_ilog2
<P>
(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::checked_ilog2); + default_checked_ilog2_test(param, executor); +} diff --git a/tfhe/src/integer/server_key/radix_parallel/ilog2.rs b/tfhe/src/integer/server_key/radix_parallel/ilog2.rs index 52cacbefc7..5f851c5495 100644 --- a/tfhe/src/integer/server_key/radix_parallel/ilog2.rs +++ b/tfhe/src/integer/server_key/radix_parallel/ilog2.rs @@ -10,13 +10,13 @@ use rayon::prelude::*; /// Used to improved readability over using a `bool`. #[derive(Copy, Clone, Eq, PartialEq)] #[repr(u64)] -enum BitValue { +pub(crate) enum BitValue { Zero = 0, One = 1, } impl BitValue { - fn opposite(self) -> Self { + pub(crate) fn opposite(self) -> Self { match self { Self::One => Self::Zero, Self::Zero => Self::One, @@ -26,7 +26,7 @@ impl BitValue { /// Direction to count consecutive bits #[derive(Copy, Clone, Eq, PartialEq)] -enum Direction { +pub(crate) enum Direction { /// Count starting from the LSB Trailing, /// Count starting from MSB @@ -693,872 +693,3 @@ impl ServerKey { ) } } - -#[cfg(test)] -pub(crate) mod tests_unsigned { - use super::*; - use crate::integer::keycache::KEY_CACHE; - use crate::integer::server_key::radix_parallel::tests_cases_unsigned::{ - FunctionExecutor, NB_CTXT, - }; - use crate::integer::server_key::radix_parallel::tests_unsigned::{ - nb_tests_smaller_for_params, random_non_zero_value, - }; - use crate::integer::{IntegerKeyKind, RadixClientKey}; - use crate::shortint::PBSParameters; - use rand::Rng; - use std::sync::Arc; - - fn default_test_count_consecutive_bits( - direction: Direction, - bit_value: BitValue, - param: P, - mut executor: T, - ) where - P: Into, - T: for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>, - { - let param = param.into(); - let nb_tests_smaller = nb_tests_smaller_for_params(param); - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - 
sks.set_deterministic_pbs_execution(true); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - // message_modulus^vec_length - let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; - - executor.setup(&cks, sks.clone()); - - let num_bits = NB_CTXT as u32 * cks.parameters().message_modulus().0.ilog2(); - - let compute_expected_clear = |x: u64| match (direction, bit_value) { - (Direction::Trailing, BitValue::Zero) => { - if x == 0 { - num_bits - } else { - x.trailing_zeros() - } - } - (Direction::Trailing, BitValue::One) => x.trailing_ones(), - (Direction::Leading, BitValue::Zero) => { - if x == 0 { - num_bits - } else { - (x << (u64::BITS - num_bits)).leading_zeros() - } - } - (Direction::Leading, BitValue::One) => (x << (u64::BITS - num_bits)).leading_ones(), - }; - - let method_name = match (direction, bit_value) { - (Direction::Trailing, BitValue::Zero) => "trailing_zeros", - (Direction::Trailing, BitValue::One) => "trailing_ones", - (Direction::Leading, BitValue::Zero) => "leading_zeros", - (Direction::Leading, BitValue::One) => "leading_ones", - }; - - let input_values = [0u64, modulus - 1] - .into_iter() - .chain((0..nb_tests_smaller).map(|_| rng.gen::() % modulus)) - .collect::>(); - - for clear in input_values { - let ctxt = cks.encrypt(clear); - - let ct_res = executor.execute(&ctxt); - let tmp = executor.execute(&ctxt); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); - - let decrypted_result: u32 = cks.decrypt(&ct_res); - let expected_result = compute_expected_clear(clear); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for {method_name}, for {clear}.{method_name}() \ - expected {expected_result}, got {decrypted_result}" - ); - - for _ in 0..nb_tests_smaller { - // Add non-zero scalar to have non-clean ciphertexts - let clear_2 = random_non_zero_value(&mut rng, modulus); - - let ctxt = sks.unchecked_scalar_add(&ctxt, clear_2); - - let clear = clear.wrapping_add(clear_2) 
% modulus; - - let d0: u64 = cks.decrypt(&ctxt); - assert_eq!(d0, clear, "Failed sanity decryption check"); - - let ct_res = executor.execute(&ctxt); - assert!(ct_res.block_carries_are_empty()); - - let expected_result = compute_expected_clear(clear); - - let decrypted_result: u32 = cks.decrypt(&ct_res); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for {method_name}, for {clear}.{method_name}() \ - expected {expected_result}, got {decrypted_result}" - ); - } - } - - let input_values = [0u64, modulus - 1] - .into_iter() - .chain((0..nb_tests_smaller).map(|_| rng.gen::() % modulus)); - - for clear in input_values { - let ctxt = sks.create_trivial_radix(clear, NB_CTXT); - - let ct_res = executor.execute(&ctxt); - let tmp = executor.execute(&ctxt); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); - - let decrypted_result: u32 = cks.decrypt(&ct_res); - let expected_result = compute_expected_clear(clear); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for {method_name}, for {clear}.{method_name}() \ - expected {expected_result}, got {decrypted_result}" - ); - } - } - - pub(crate) fn default_trailing_zeros_test(param: P, executor: T) - where - P: Into, - T: for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>, - { - default_test_count_consecutive_bits(Direction::Trailing, BitValue::Zero, param, executor); - } - - pub(crate) fn default_trailing_ones_test(param: P, executor: T) - where - P: Into, - T: for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>, - { - default_test_count_consecutive_bits(Direction::Trailing, BitValue::One, param, executor); - } - - pub(crate) fn default_leading_zeros_test(param: P, executor: T) - where - P: Into, - T: for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>, - { - default_test_count_consecutive_bits(Direction::Leading, BitValue::Zero, param, executor); - } - - pub(crate) fn default_leading_ones_test(param: P, executor: T) - where - P: 
Into, - T: for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>, - { - default_test_count_consecutive_bits(Direction::Leading, BitValue::One, param, executor); - } - - pub(crate) fn default_ilog2_test(param: P, mut executor: T) - where - P: Into, - T: for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>, - { - let param = param.into(); - let nb_tests_smaller = nb_tests_smaller_for_params(param); - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - sks.set_deterministic_pbs_execution(true); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - // message_modulus^vec_length - let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; - - executor.setup(&cks, sks.clone()); - - let num_bits = NB_CTXT as u32 * cks.parameters().message_modulus().0.ilog2(); - - // Test with invalid input - { - let ctxt = cks.encrypt(0u64); - - let ct_res = executor.execute(&ctxt); - assert!(ct_res.block_carries_are_empty()); - - let decrypted_result: u32 = cks.decrypt(&ct_res); - let counter_num_blocks = ((num_bits - 1).ilog2() + 1 + 1) - .div_ceil(cks.parameters().message_modulus().0.ilog2()) - as usize; - let expected_result = (1u32 - << (counter_num_blocks as u32 * cks.parameters().message_modulus().0.ilog2())) - - 1; - assert_eq!( - decrypted_result, expected_result, - "Invalid result for ilog2 for 0.ilog2() \ - expected {expected_result}, got {decrypted_result}" - ); - } - - let input_values = (0..num_bits) - .map(|i| 1 << i) - .chain( - (0..nb_tests_smaller.saturating_sub(num_bits as usize)) - .map(|_| rng.gen_range(1..modulus)), - ) - .collect::>(); - - for clear in input_values { - let ctxt = cks.encrypt(clear); - - let ct_res = executor.execute(&ctxt); - let tmp = executor.execute(&ctxt); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); - - let decrypted_result: u32 = cks.decrypt(&ct_res); - let expected_result = 
clear.ilog2(); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for ilog2 for {clear}.ilog2() \ - expected {expected_result}, got {decrypted_result}" - ); - - for _ in 0..nb_tests_smaller { - // Add non-zero scalar to have non-clean ciphertexts - // But here, we have to make sure clear is still > 0 - // as we are only testing valid ilog2 inputs - let (clear, clear_2) = loop { - let clear_2 = random_non_zero_value(&mut rng, modulus); - let clear = clear_2.wrapping_add(clear) % modulus; - if clear != 0 { - break (clear, clear_2); - } - }; - - let ctxt = sks.unchecked_scalar_add(&ctxt, clear_2); - - let d0: u64 = cks.decrypt(&ctxt); - assert_eq!(d0, clear, "Failed sanity decryption check"); - - let ct_res = executor.execute(&ctxt); - assert!(ct_res.block_carries_are_empty()); - - let expected_result = clear.ilog2(); - - let decrypted_result: u32 = cks.decrypt(&ct_res); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for ilog2, for {clear}.ilog2() \ - expected {expected_result}, got {decrypted_result}" - ); - } - } - - let input_values = (0..num_bits) - .map(|i| 1 << i) - .chain( - (0..nb_tests_smaller.saturating_sub(num_bits as usize)) - .map(|_| rng.gen_range(1..modulus)), - ) - .collect::>(); - - for clear in input_values { - let ctxt = sks.create_trivial_radix(clear, NB_CTXT); - - let ct_res = executor.execute(&ctxt); - let tmp = executor.execute(&ctxt); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); - - let decrypted_result: u32 = cks.decrypt(&ct_res); - let expected_result = clear.ilog2(); - - assert_eq!( - decrypted_result, expected_result, - "Invalid result for ilog2, for {clear}.ilog2() \ - expected {expected_result}, got {decrypted_result}" - ); - } - } - - pub(crate) fn default_checked_ilog2_test(param: P, mut executor: T) - where - P: Into, - T: for<'a> FunctionExecutor<&'a RadixCiphertext, (RadixCiphertext, BooleanBlock)>, - { - let param = param.into(); - let nb_tests_smaller = 
nb_tests_smaller_for_params(param); - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - sks.set_deterministic_pbs_execution(true); - let sks = Arc::new(sks); - - let mut rng = rand::thread_rng(); - - // message_modulus^vec_length - let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; - - executor.setup(&cks, sks.clone()); - - let num_bits = NB_CTXT as u32 * cks.parameters().message_modulus().0.ilog2(); - - // Test with invalid input - { - let ctxt = cks.encrypt(0u64); - - let (ct_res, is_ok) = executor.execute(&ctxt); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(is_ok.as_ref().degree.get(), 1); - - let decrypted_result: u32 = cks.decrypt(&ct_res); - let counter_num_blocks = ((num_bits - 1).ilog2() + 1 + 1) - .div_ceil(cks.parameters().message_modulus().0.ilog2()) - as usize; - let expected_result = (1u32 - << (counter_num_blocks as u32 * cks.parameters().message_modulus().0.ilog2())) - - 1; - assert_eq!( - decrypted_result, expected_result, - "Invalid result for ilog2 for 0.ilog2() \ - expected {expected_result}, got {decrypted_result}" - ); - let is_ok = cks.decrypt_bool(&is_ok); - assert!(!is_ok); - } - - let input_values = (0..num_bits) - .map(|i| 1 << i) - .chain( - (0..nb_tests_smaller.saturating_sub(num_bits as usize)) - .map(|_| rng.gen_range(1..modulus)), - ) - .collect::>(); - - for clear in input_values { - let ctxt = cks.encrypt(clear); - - let (ct_res, is_ok) = executor.execute(&ctxt); - let (tmp, tmp_is_ok) = executor.execute(&ctxt); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); - assert_eq!(is_ok, tmp_is_ok); - - let decrypted_result: u32 = cks.decrypt(&ct_res); - let expected_result = clear.ilog2(); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for ilog2 for {clear}.ilog2() \ - expected {expected_result}, got {decrypted_result}" - ); - let is_ok = cks.decrypt_bool(&is_ok); - 
assert!(is_ok); - - for _ in 0..nb_tests_smaller { - // Add non-zero scalar to have non-clean ciphertexts - // But here, we have to make sure clear is still > 0 - // as we are only testing valid ilog2 inputs - let (clear, clear_2) = loop { - let clear_2 = random_non_zero_value(&mut rng, modulus); - let clear = clear_2.wrapping_add(clear) % modulus; - if clear != 0 { - break (clear, clear_2); - } - }; - - let ctxt = sks.unchecked_scalar_add(&ctxt, clear_2); - - let d0: u64 = cks.decrypt(&ctxt); - assert_eq!(d0, clear, "Failed sanity decryption check"); - - let (ct_res, is_ok) = executor.execute(&ctxt); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(is_ok.as_ref().degree.get(), 1); - - let expected_result = clear.ilog2(); - - let decrypted_result: u32 = cks.decrypt(&ct_res); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for ilog2, for {clear}.ilog2() \ - expected {expected_result}, got {decrypted_result}" - ); - let is_ok = cks.decrypt_bool(&is_ok); - assert!(is_ok); - } - } - - let input_values = (0..num_bits) - .map(|i| 1 << i) - .chain( - (0..nb_tests_smaller.saturating_sub(num_bits as usize)) - .map(|_| rng.gen_range(1..modulus)), - ) - .collect::>(); - - for clear in input_values { - let ctxt = sks.create_trivial_radix(clear, NB_CTXT); - - let (ct_res, is_ok) = executor.execute(&ctxt); - assert!(ct_res.block_carries_are_empty()); - - let decrypted_result: u32 = cks.decrypt(&ct_res); - let expected_result = clear.ilog2(); - - assert_eq!( - decrypted_result, expected_result, - "Invalid result for ilog2, for {clear}.ilog2() \ - expected {expected_result}, got {decrypted_result}" - ); - let is_ok = cks.decrypt_bool(&is_ok); - assert!(is_ok); - } - } -} - -#[cfg(test)] -pub(crate) mod tests_signed { - use super::*; - use crate::integer::keycache::KEY_CACHE; - use crate::integer::server_key::radix_parallel::tests_signed::{ - random_non_zero_value, signed_add_under_modulus, - }; - use 
crate::integer::server_key::radix_parallel::tests_unsigned::{ - nb_tests_smaller_for_params, NB_CTXT, - }; - use crate::integer::{IntegerKeyKind, RadixClientKey}; - use crate::shortint::PBSParameters; - use rand::Rng; - - fn default_test_count_consecutive_bits( - direction: Direction, - bit_value: BitValue, - param: P, - sks_method: F, - ) where - P: Into, - F: for<'a> Fn(&'a ServerKey, &'a SignedRadixCiphertext) -> RadixCiphertext, - { - let param = param.into(); - let nb_tests_smaller = nb_tests_smaller_for_params(param); - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - sks.set_deterministic_pbs_execution(true); - - let mut rng = rand::thread_rng(); - - // message_modulus^vec_length - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - let num_bits = NB_CTXT as u32 * cks.parameters().message_modulus().0.ilog2(); - - let compute_expected_clear = |x: i64| match (direction, bit_value) { - (Direction::Trailing, BitValue::Zero) => { - if x == 0 { - num_bits - } else { - x.trailing_zeros() - } - } - (Direction::Trailing, BitValue::One) => x.trailing_ones().min(num_bits), - (Direction::Leading, BitValue::Zero) => { - if x == 0 { - num_bits - } else { - (x << (u64::BITS - num_bits)).leading_zeros() - } - } - (Direction::Leading, BitValue::One) => (x << (u64::BITS - num_bits)).leading_ones(), - }; - - let method_name = match (direction, bit_value) { - (Direction::Trailing, BitValue::Zero) => "trailing_zeros", - (Direction::Trailing, BitValue::One) => "trailing_ones", - (Direction::Leading, BitValue::Zero) => "leading_zeros", - (Direction::Leading, BitValue::One) => "leading_ones", - }; - - let input_values = [-modulus, 0i64, modulus - 1] - .into_iter() - .chain((0..nb_tests_smaller).map(|_| rng.gen::() % modulus)) - .collect::>(); - - for clear in input_values { - let ctxt = cks.encrypt_signed(clear); - - let ct_res = sks_method(&sks, &ctxt); - let tmp 
= sks_method(&sks, &ctxt); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); - - let decrypted_result: u32 = cks.decrypt(&ct_res); - let expected_result = compute_expected_clear(clear); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for {method_name}, for {clear}.{method_name}() \ - expected {expected_result}, got {decrypted_result}" - ); - - for _ in 0..nb_tests_smaller { - // Add non-zero scalar to have non-clean ciphertexts - let clear_2 = random_non_zero_value(&mut rng, modulus); - - let ctxt = sks.unchecked_scalar_add(&ctxt, clear_2); - - let clear = signed_add_under_modulus(clear, clear_2, modulus); - - let d0: i64 = cks.decrypt_signed(&ctxt); - assert_eq!(d0, clear, "Failed sanity decryption check"); - - let ct_res = sks_method(&sks, &ctxt); - assert!(ct_res.block_carries_are_empty()); - - let expected_result = compute_expected_clear(clear); - - let decrypted_result: u32 = cks.decrypt(&ct_res); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for {method_name}, for {clear}.{method_name}() \ - expected {expected_result}, got {decrypted_result}" - ); - } - } - - let input_values = [-modulus, 0i64, modulus - 1] - .into_iter() - .chain((0..nb_tests_smaller).map(|_| rng.gen::() % modulus)); - - for clear in input_values { - let ctxt = sks.create_trivial_radix(clear, NB_CTXT); - - let ct_res = sks_method(&sks, &ctxt); - assert!(ct_res.block_carries_are_empty()); - - let decrypted_result: u32 = cks.decrypt(&ct_res); - let expected_result = compute_expected_clear(clear); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for {method_name}, for {clear}.{method_name}() \ - expected {expected_result}, got {decrypted_result}" - ); - } - } - - pub(crate) fn default_trailing_zeros_test
<P>
(param: P) - where - P: Into, - { - default_test_count_consecutive_bits( - Direction::Trailing, - BitValue::Zero, - param, - ServerKey::trailing_zeros_parallelized, - ); - } - - pub(crate) fn default_trailing_ones_test
<P>
(param: P) - where - P: Into, - { - default_test_count_consecutive_bits( - Direction::Trailing, - BitValue::One, - param, - ServerKey::trailing_ones_parallelized, - ); - } - - pub(crate) fn default_leading_zeros_test
<P>
(param: P) - where - P: Into, - { - default_test_count_consecutive_bits( - Direction::Leading, - BitValue::Zero, - param, - ServerKey::leading_zeros_parallelized, - ); - } - - pub(crate) fn default_leading_ones_test
<P>
(param: P) - where - P: Into, - { - default_test_count_consecutive_bits( - Direction::Leading, - BitValue::One, - param, - ServerKey::leading_ones_parallelized, - ); - } - - pub(crate) fn default_ilog2_test
<P>
(param: P) - where - P: Into, - { - let param = param.into(); - let nb_tests_smaller = nb_tests_smaller_for_params(param); - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - sks.set_deterministic_pbs_execution(true); - - let mut rng = rand::thread_rng(); - - // message_modulus^vec_length - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - let num_bits = NB_CTXT as u32 * cks.parameters().message_modulus().0.ilog2(); - - // Test with invalid input - { - for clear in [0i64, rng.gen_range(-modulus..=-1i64)] { - let ctxt = cks.encrypt_signed(clear); - - let ct_res = sks.ilog2_parallelized(&ctxt); - assert!(ct_res.block_carries_are_empty()); - - let decrypted_result: u32 = cks.decrypt(&ct_res); - let expected_result = if clear < 0 { - num_bits - 1 - } else { - let counter_num_blocks = ((num_bits - 1).ilog2() + 1 + 1) - .div_ceil(cks.parameters().message_modulus().0.ilog2()) - as usize; - (1u32 - << (counter_num_blocks as u32 - * cks.parameters().message_modulus().0.ilog2())) - - 1 - }; - assert_eq!( - decrypted_result, expected_result, - "Invalid result for ilog2 for {clear}.ilog2() \ - expected {expected_result}, got {decrypted_result}" - ); - } - } - - let input_values = (0..num_bits - 1) - .map(|i| 1 << i) - .chain( - (0..nb_tests_smaller.saturating_sub(num_bits as usize)) - .map(|_| rng.gen_range(1..modulus)), - ) - .collect::>(); - - for clear in input_values { - let ctxt = cks.encrypt_signed(clear); - - let ct_res = sks.ilog2_parallelized(&ctxt); - let tmp = sks.ilog2_parallelized(&ctxt); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); - - let decrypted_result: u32 = cks.decrypt(&ct_res); - let expected_result = clear.ilog2(); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for ilog2 for {clear}.ilog2() \ - expected {expected_result}, got {decrypted_result}" - ); - - for _ in 0..nb_tests_smaller { - 
// Add non-zero scalar to have non-clean ciphertexts - // But here, we have to make sure clear is still > 0 - // as we are only testing valid ilog2 inputs - let (clear, clear_2) = loop { - let clear_2 = random_non_zero_value(&mut rng, modulus); - let clear = signed_add_under_modulus(clear, clear_2, modulus); - if clear > 0 { - break (clear, clear_2); - } - }; - - let ctxt = sks.unchecked_scalar_add(&ctxt, clear_2); - - let d0: i64 = cks.decrypt_signed(&ctxt); - assert_eq!(d0, clear, "Failed sanity decryption check"); - - let ct_res = sks.ilog2_parallelized(&ctxt); - assert!(ct_res.block_carries_are_empty()); - - let expected_result = clear.ilog2(); - - let decrypted_result: u32 = cks.decrypt(&ct_res); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for ilog2, for {clear}.ilog2() \ - expected {expected_result}, got {decrypted_result}" - ); - } - } - - let input_values = (0..num_bits - 1) - .map(|i| 1 << i) - .chain( - (0..nb_tests_smaller.saturating_sub(num_bits as usize)) - .map(|_| rng.gen_range(1..modulus)), - ) - .collect::>(); - - for clear in input_values { - let ctxt: SignedRadixCiphertext = sks.create_trivial_radix(clear, NB_CTXT); - - let ct_res = sks.ilog2_parallelized(&ctxt); - let tmp = sks.ilog2_parallelized(&ctxt); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); - - let decrypted_result: u32 = cks.decrypt(&ct_res); - let expected_result = clear.ilog2(); - - assert_eq!( - decrypted_result, expected_result, - "Invalid result for ilog2, for {clear}.ilog2() \ - expected {expected_result}, got {decrypted_result}" - ); - } - } - - pub(crate) fn default_checked_ilog2_test

(param: P) - where - P: Into, - { - let param = param.into(); - let nb_tests_smaller = nb_tests_smaller_for_params(param); - let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - sks.set_deterministic_pbs_execution(true); - - let mut rng = rand::thread_rng(); - - // message_modulus^vec_length - let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; - - let num_bits = NB_CTXT as u32 * cks.parameters().message_modulus().0.ilog2(); - - // Test with invalid input - { - for clear in [0i64, rng.gen_range(-modulus..=-1i64)] { - let ctxt = cks.encrypt_signed(clear); - - let (ct_res, is_ok) = sks.checked_ilog2_parallelized(&ctxt); - assert!(ct_res.block_carries_are_empty()); - - let decrypted_result: u32 = cks.decrypt(&ct_res); - let expected_result = if clear < 0 { - num_bits - 1 - } else { - let counter_num_blocks = ((num_bits - 1).ilog2() + 1 + 1) - .div_ceil(cks.parameters().message_modulus().0.ilog2()) - as usize; - (1u32 - << (counter_num_blocks as u32 - * cks.parameters().message_modulus().0.ilog2())) - - 1 - }; - assert_eq!( - decrypted_result, expected_result, - "Invalid result for ilog2 for {clear}.ilog2() \ - expected {expected_result}, got {decrypted_result}" - ); - let is_ok = cks.decrypt_bool(&is_ok); - assert!(!is_ok); - } - } - - let input_values = (0..num_bits - 1) - .map(|i| 1 << i) - .chain( - (0..nb_tests_smaller.saturating_sub(num_bits as usize)) - .map(|_| rng.gen_range(1..modulus)), - ) - .collect::>(); - - for clear in input_values { - let ctxt = cks.encrypt_signed(clear); - - let (ct_res, is_ok) = sks.checked_ilog2_parallelized(&ctxt); - let (tmp, tmp_is_ok) = sks.checked_ilog2_parallelized(&ctxt); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); - assert_eq!(is_ok, tmp_is_ok); - - let decrypted_result: u32 = cks.decrypt(&ct_res); - let expected_result = clear.ilog2(); - assert_eq!( - decrypted_result, expected_result, - 
"Invalid result for ilog2 for {clear}.ilog2() \ - expected {expected_result}, got {decrypted_result}" - ); - let is_ok = cks.decrypt_bool(&is_ok); - assert!(is_ok); - - for _ in 0..nb_tests_smaller { - // Add non-zero scalar to have non-clean ciphertexts - // But here, we have to make sure clear is still > 0 - // as we are only testing valid ilog2 inputs - let (clear, clear_2) = loop { - let clear_2 = random_non_zero_value(&mut rng, modulus); - let clear = signed_add_under_modulus(clear, clear_2, modulus); - if clear > 0 { - break (clear, clear_2); - } - }; - - let ctxt = sks.unchecked_scalar_add(&ctxt, clear_2); - - let d0: i64 = cks.decrypt_signed(&ctxt); - assert_eq!(d0, clear, "Failed sanity decryption check"); - - let (ct_res, is_ok) = sks.checked_ilog2_parallelized(&ctxt); - assert!(ct_res.block_carries_are_empty()); - assert_eq!(is_ok.as_ref().degree.get(), 1); - - let expected_result = clear.ilog2(); - - let decrypted_result: u32 = cks.decrypt(&ct_res); - assert_eq!( - decrypted_result, expected_result, - "Invalid result for ilog2, for {clear}.ilog2() \ - expected {expected_result}, got {decrypted_result}" - ); - let is_ok = cks.decrypt_bool(&is_ok); - assert!(is_ok); - } - } - - let input_values = (0..num_bits - 1) - .map(|i| 1 << i) - .chain( - (0..nb_tests_smaller.saturating_sub(num_bits as usize)) - .map(|_| rng.gen_range(1..modulus)), - ) - .collect::>(); - - for clear in input_values { - let ctxt: SignedRadixCiphertext = sks.create_trivial_radix(clear, NB_CTXT); - - let (ct_res, is_ok) = sks.checked_ilog2_parallelized(&ctxt); - assert!(ct_res.block_carries_are_empty()); - - let decrypted_result: u32 = cks.decrypt(&ct_res); - let expected_result = clear.ilog2(); - - assert_eq!( - decrypted_result, expected_result, - "Invalid result for ilog2, for {clear}.ilog2() \ - expected {expected_result}, got {decrypted_result}" - ); - let is_ok = cks.decrypt_bool(&is_ok); - assert!(is_ok); - } - } -} diff --git a/tfhe/src/integer/server_key/radix_parallel/mod.rs 
b/tfhe/src/integer/server_key/radix_parallel/mod.rs index 621a21fbba..d6c4310e6b 100644 --- a/tfhe/src/integer/server_key/radix_parallel/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/mod.rs @@ -21,7 +21,7 @@ mod shift; pub(crate) mod sub; mod sum; -mod ilog2; +pub(crate) mod ilog2; mod reverse_bits; #[cfg(test)] pub(crate) mod tests_cases_unsigned; diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs index 9847cb44aa..18f113d3c0 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs @@ -1838,12 +1838,6 @@ where } } -// Re-exports to still have tests case accessible from the same location -pub(crate) use crate::integer::server_key::radix_parallel::ilog2::tests_unsigned::{ - default_checked_ilog2_test, default_ilog2_test, default_leading_ones_test, - default_leading_zeros_test, default_trailing_ones_test, default_trailing_zeros_test, -}; - //============================================================================= // Default Scalar Tests //============================================================================= diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs index 074c4971b6..fe17f8c606 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/mod.rs @@ -3,6 +3,7 @@ pub(crate) mod test_add; pub(crate) mod test_bitwise_op; pub(crate) mod test_cmux; pub(crate) mod test_comparison; +pub(crate) mod test_ilog2; pub(crate) mod test_mul; pub(crate) mod test_neg; pub(crate) mod test_rotate; @@ -24,7 +25,7 @@ use crate::integer::server_key::radix_parallel::tests_unsigned::{ }; use crate::integer::tests::create_parametrized_test; use crate::integer::{ - BooleanBlock, IntegerKeyKind, RadixClientKey, 
ServerKey, SignedRadixCiphertext, + BooleanBlock, IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey, SignedRadixCiphertext, }; #[cfg(tarpaulin)] use crate::shortint::parameters::coverage_parameters::*; @@ -66,6 +67,34 @@ where (self.func)(sks, input) } } +impl<'a, F> FunctionExecutor<&'a SignedRadixCiphertext, (RadixCiphertext, BooleanBlock)> + for CpuFunctionExecutor +where + F: Fn(&ServerKey, &SignedRadixCiphertext) -> (RadixCiphertext, BooleanBlock), +{ + fn setup(&mut self, _cks: &RadixClientKey, sks: Arc) { + self.sks = Some(sks); + } + + fn execute(&mut self, input: &'a SignedRadixCiphertext) -> (RadixCiphertext, BooleanBlock) { + let sks = self.sks.as_ref().expect("setup was not properly called"); + (self.func)(sks, input) + } +} + +impl<'a, F> FunctionExecutor<&'a SignedRadixCiphertext, RadixCiphertext> for CpuFunctionExecutor +where + F: Fn(&ServerKey, &SignedRadixCiphertext) -> RadixCiphertext, +{ + fn setup(&mut self, _cks: &RadixClientKey, sks: Arc) { + self.sks = Some(sks); + } + + fn execute(&mut self, input: &'a SignedRadixCiphertext) -> RadixCiphertext { + let sks = self.sks.as_ref().expect("setup was not properly called"); + (self.func)(sks, input) + } +} impl<'a, F> FunctionExecutor<&'a mut SignedRadixCiphertext, ()> for CpuFunctionExecutor where F: Fn(&ServerKey, &'a mut SignedRadixCiphertext), @@ -419,22 +448,6 @@ fn integer_signed_smart_absolute_value(param: impl Into) { create_parametrized_test!(integer_signed_default_absolute_value); -create_parametrized_test!(integer_signed_default_trailing_zeros); -create_parametrized_test!(integer_signed_default_trailing_ones); -create_parametrized_test!(integer_signed_default_leading_zeros); -create_parametrized_test!(integer_signed_default_leading_ones); -create_parametrized_test!(integer_signed_default_ilog2); -create_parametrized_test!(integer_signed_default_checked_ilog2 { - // uses comparison so 1_1 parameters are not supported - PARAM_MESSAGE_2_CARRY_2_KS_PBS, - 
PARAM_MESSAGE_3_CARRY_3_KS_PBS, - PARAM_MESSAGE_4_CARRY_4_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS -}); - fn integer_signed_default_absolute_value(param: impl Into) { let param = param.into(); let nb_tests_smaller = nb_tests_smaller_for_params(param); @@ -471,58 +484,6 @@ fn integer_signed_default_absolute_value(param: impl Into) { } } -fn integer_signed_default_trailing_zeros

(param: P) -where - P: Into, -{ - crate::integer::server_key::radix_parallel::ilog2::tests_signed::default_trailing_zeros_test( - param, - ); -} - -fn integer_signed_default_trailing_ones

(param: P) -where - P: Into, -{ - crate::integer::server_key::radix_parallel::ilog2::tests_signed::default_trailing_ones_test( - param, - ); -} - -fn integer_signed_default_leading_zeros

(param: P) -where - P: Into, -{ - crate::integer::server_key::radix_parallel::ilog2::tests_signed::default_leading_zeros_test( - param, - ); -} - -fn integer_signed_default_leading_ones

(param: P) -where - P: Into, -{ - crate::integer::server_key::radix_parallel::ilog2::tests_signed::default_leading_ones_test( - param, - ); -} - -fn integer_signed_default_ilog2

(param: P) -where - P: Into, -{ - crate::integer::server_key::radix_parallel::ilog2::tests_signed::default_ilog2_test(param); -} - -fn integer_signed_default_checked_ilog2

(param: P) -where - P: Into, -{ - crate::integer::server_key::radix_parallel::ilog2::tests_signed::default_checked_ilog2_test( - param, - ); -} - //================================================================================ // Unchecked Scalar Tests //================================================================================ diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_ilog2.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_ilog2.rs new file mode 100644 index 0000000000..f63ab41e12 --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_ilog2.rs @@ -0,0 +1,504 @@ +use crate::integer::keycache::KEY_CACHE; +use crate::integer::server_key::radix_parallel::ilog2::{BitValue, Direction}; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::server_key::radix_parallel::tests_signed::{ + random_non_zero_value, signed_add_under_modulus, NB_CTXT, +}; +use crate::integer::server_key::radix_parallel::tests_unsigned::{ + nb_tests_smaller_for_params, CpuFunctionExecutor, +}; +use crate::integer::tests::create_parametrized_test; +use crate::integer::{ + BooleanBlock, IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey, SignedRadixCiphertext, +}; +#[cfg(tarpaulin)] +use crate::shortint::parameters::coverage_parameters::*; +use crate::shortint::parameters::*; +use crate::shortint::PBSParameters; +use rand::Rng; +use std::sync::Arc; + +create_parametrized_test!(integer_signed_default_trailing_zeros); +create_parametrized_test!(integer_signed_default_trailing_ones); +create_parametrized_test!(integer_signed_default_leading_zeros); +create_parametrized_test!(integer_signed_default_leading_ones); +create_parametrized_test!(integer_signed_default_ilog2); +create_parametrized_test!(integer_signed_default_checked_ilog2 { + // uses comparison so 1_1 parameters are not supported + PARAM_MESSAGE_2_CARRY_2_KS_PBS, + PARAM_MESSAGE_3_CARRY_3_KS_PBS, + 
PARAM_MESSAGE_4_CARRY_4_KS_PBS, + PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, + PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS, + PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, + PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS +}); + +fn integer_signed_default_trailing_zeros

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::trailing_zeros_parallelized); + default_trailing_zeros_test(param, executor); +} + +fn integer_signed_default_trailing_ones

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::trailing_ones_parallelized); + default_trailing_ones_test(param, executor); +} + +fn integer_signed_default_leading_zeros

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::leading_zeros_parallelized); + default_leading_zeros_test(param, executor); +} + +fn integer_signed_default_leading_ones

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::leading_ones_parallelized); + default_leading_ones_test(param, executor); +} + +fn integer_signed_default_ilog2

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::ilog2_parallelized); + default_ilog2_test(param, executor); +} + +fn integer_signed_default_checked_ilog2

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::checked_ilog2_parallelized); + default_checked_ilog2_test(param, executor); +} + +pub(crate) fn signed_default_count_consecutive_bits_test( + direction: Direction, + bit_value: BitValue, + param: P, + mut executor: T, +) where + P: Into, + T: for<'a> FunctionExecutor<&'a SignedRadixCiphertext, RadixCiphertext>, +{ + let param = param.into(); + let nb_tests_smaller = nb_tests_smaller_for_params(param); + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + executor.setup(&cks, sks.clone()); + + // message_modulus^vec_length + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + let num_bits = NB_CTXT as u32 * cks.parameters().message_modulus().0.ilog2(); + + let compute_expected_clear = |x: i64| match (direction, bit_value) { + (Direction::Trailing, BitValue::Zero) => { + if x == 0 { + num_bits + } else { + x.trailing_zeros() + } + } + (Direction::Trailing, BitValue::One) => x.trailing_ones().min(num_bits), + (Direction::Leading, BitValue::Zero) => { + if x == 0 { + num_bits + } else { + (x << (u64::BITS - num_bits)).leading_zeros() + } + } + (Direction::Leading, BitValue::One) => (x << (u64::BITS - num_bits)).leading_ones(), + }; + + let method_name = match (direction, bit_value) { + (Direction::Trailing, BitValue::Zero) => "trailing_zeros", + (Direction::Trailing, BitValue::One) => "trailing_ones", + (Direction::Leading, BitValue::Zero) => "leading_zeros", + (Direction::Leading, BitValue::One) => "leading_ones", + }; + + let input_values = [-modulus, 0i64, modulus - 1] + .into_iter() + .chain((0..nb_tests_smaller).map(|_| rng.gen::() % modulus)) + .collect::>(); + + for clear in input_values { + let ctxt = cks.encrypt_signed(clear); + + let ct_res = 
executor.execute(&ctxt); + let tmp = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = compute_expected_clear(clear); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for {method_name}, for {clear}.{method_name}() \ + expected {expected_result}, got {decrypted_result}" + ); + + for _ in 0..nb_tests_smaller { + // Add non-zero scalar to have non-clean ciphertexts + let clear_2 = random_non_zero_value(&mut rng, modulus); + + let ctxt = sks.unchecked_scalar_add(&ctxt, clear_2); + + let clear = signed_add_under_modulus(clear, clear_2, modulus); + + let d0: i64 = cks.decrypt_signed(&ctxt); + assert_eq!(d0, clear, "Failed sanity decryption check"); + + let ct_res = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + + let expected_result = compute_expected_clear(clear); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for {method_name}, for {clear}.{method_name}() \ + expected {expected_result}, got {decrypted_result}" + ); + } + } + + let input_values = [-modulus, 0i64, modulus - 1] + .into_iter() + .chain((0..nb_tests_smaller).map(|_| rng.gen::() % modulus)); + + for clear in input_values { + let ctxt = sks.create_trivial_radix(clear, NB_CTXT); + + let ct_res = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = compute_expected_clear(clear); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for {method_name}, for {clear}.{method_name}() \ + expected {expected_result}, got {decrypted_result}" + ); + } +} + +pub(crate) fn default_trailing_zeros_test(param: P, executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<&'a SignedRadixCiphertext, RadixCiphertext>, +{ + signed_default_count_consecutive_bits_test( + 
Direction::Trailing, + BitValue::Zero, + param, + executor, + ); +} + +pub(crate) fn default_trailing_ones_test(param: P, executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<&'a SignedRadixCiphertext, RadixCiphertext>, +{ + signed_default_count_consecutive_bits_test(Direction::Trailing, BitValue::One, param, executor); +} + +pub(crate) fn default_leading_zeros_test(param: P, executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<&'a SignedRadixCiphertext, RadixCiphertext>, +{ + signed_default_count_consecutive_bits_test(Direction::Leading, BitValue::Zero, param, executor); +} + +pub(crate) fn default_leading_ones_test(param: P, executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<&'a SignedRadixCiphertext, RadixCiphertext>, +{ + signed_default_count_consecutive_bits_test(Direction::Leading, BitValue::One, param, executor); +} + +pub(crate) fn default_ilog2_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<&'a SignedRadixCiphertext, RadixCiphertext>, +{ + let param = param.into(); + let nb_tests_smaller = nb_tests_smaller_for_params(param); + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + executor.setup(&cks, sks.clone()); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + let num_bits = NB_CTXT as u32 * cks.parameters().message_modulus().0.ilog2(); + + // Test with invalid input + { + for clear in [0i64, rng.gen_range(-modulus..=-1i64)] { + let ctxt = cks.encrypt_signed(clear); + + let ct_res = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = if clear < 0 { + num_bits - 1 + } else { + let counter_num_blocks = ((num_bits - 1).ilog2() + 1 + 1) + 
.div_ceil(cks.parameters().message_modulus().0.ilog2()) + as usize; + (1u32 << (counter_num_blocks as u32 * cks.parameters().message_modulus().0.ilog2())) + - 1 + }; + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2 for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + } + } + + let input_values = (0..num_bits - 1) + .map(|i| 1 << i) + .chain( + (0..nb_tests_smaller.saturating_sub(num_bits as usize)) + .map(|_| rng.gen_range(1..modulus)), + ) + .collect::>(); + + for clear in input_values { + let ctxt = cks.encrypt_signed(clear); + + let ct_res = executor.execute(&ctxt); + let tmp = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = clear.ilog2(); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2 for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + + for _ in 0..nb_tests_smaller { + // Add non-zero scalar to have non-clean ciphertexts + // But here, we have to make sure clear is still > 0 + // as we are only testing valid ilog2 inputs + let (clear, clear_2) = loop { + let clear_2 = random_non_zero_value(&mut rng, modulus); + let clear = signed_add_under_modulus(clear, clear_2, modulus); + if clear > 0 { + break (clear, clear_2); + } + }; + + let ctxt = sks.unchecked_scalar_add(&ctxt, clear_2); + + let d0: i64 = cks.decrypt_signed(&ctxt); + assert_eq!(d0, clear, "Failed sanity decryption check"); + + let ct_res = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + + let expected_result = clear.ilog2(); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2, for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + } + } + + let input_values = (0..num_bits - 1) + .map(|i| 1 << i) + .chain( + 
(0..nb_tests_smaller.saturating_sub(num_bits as usize)) + .map(|_| rng.gen_range(1..modulus)), + ) + .collect::>(); + + for clear in input_values { + let ctxt: SignedRadixCiphertext = sks.create_trivial_radix(clear, NB_CTXT); + + let ct_res = executor.execute(&ctxt); + let tmp = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = clear.ilog2(); + + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2, for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + } +} + +pub(crate) fn default_checked_ilog2_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<&'a SignedRadixCiphertext, (RadixCiphertext, BooleanBlock)>, +{ + let param = param.into(); + let nb_tests_smaller = nb_tests_smaller_for_params(param); + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + + let mut rng = rand::thread_rng(); + let sks = Arc::new(sks); + executor.setup(&cks, sks.clone()); + + // message_modulus^vec_length + let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64; + + let num_bits = NB_CTXT as u32 * cks.parameters().message_modulus().0.ilog2(); + + // Test with invalid input + { + for clear in [0i64, rng.gen_range(-modulus..=-1i64)] { + let ctxt = cks.encrypt_signed(clear); + + let (ct_res, is_ok) = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = if clear < 0 { + num_bits - 1 + } else { + let counter_num_blocks = ((num_bits - 1).ilog2() + 1 + 1) + .div_ceil(cks.parameters().message_modulus().0.ilog2()) + as usize; + (1u32 << (counter_num_blocks as u32 * cks.parameters().message_modulus().0.ilog2())) + - 1 + }; + assert_eq!( + 
decrypted_result, expected_result, + "Invalid result for ilog2 for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + let is_ok = cks.decrypt_bool(&is_ok); + assert!(!is_ok); + } + } + + let input_values = (0..num_bits - 1) + .map(|i| 1 << i) + .chain( + (0..nb_tests_smaller.saturating_sub(num_bits as usize)) + .map(|_| rng.gen_range(1..modulus)), + ) + .collect::>(); + + for clear in input_values { + let ctxt = cks.encrypt_signed(clear); + + let (ct_res, is_ok) = executor.execute(&ctxt); + let (tmp, tmp_is_ok) = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + assert_eq!(is_ok, tmp_is_ok); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = clear.ilog2(); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2 for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + let is_ok = cks.decrypt_bool(&is_ok); + assert!(is_ok); + + for _ in 0..nb_tests_smaller { + // Add non-zero scalar to have non-clean ciphertexts + // But here, we have to make sure clear is still > 0 + // as we are only testing valid ilog2 inputs + let (clear, clear_2) = loop { + let clear_2 = random_non_zero_value(&mut rng, modulus); + let clear = signed_add_under_modulus(clear, clear_2, modulus); + if clear > 0 { + break (clear, clear_2); + } + }; + + let ctxt = sks.unchecked_scalar_add(&ctxt, clear_2); + + let d0: i64 = cks.decrypt_signed(&ctxt); + assert_eq!(d0, clear, "Failed sanity decryption check"); + + let (ct_res, is_ok) = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(is_ok.as_ref().degree.get(), 1); + + let expected_result = clear.ilog2(); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2, for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + let is_ok = cks.decrypt_bool(&is_ok); + 
assert!(is_ok); + } + } + + let input_values = (0..num_bits - 1) + .map(|i| 1 << i) + .chain( + (0..nb_tests_smaller.saturating_sub(num_bits as usize)) + .map(|_| rng.gen_range(1..modulus)), + ) + .collect::>(); + + for clear in input_values { + let ctxt: SignedRadixCiphertext = sks.create_trivial_radix(clear, NB_CTXT); + + let (ct_res, is_ok) = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = clear.ilog2(); + + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2, for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + let is_ok = cks.decrypt_bool(&is_ok); + assert!(is_ok); + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs index ed23a5580d..a338445daf 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs @@ -4,6 +4,7 @@ pub(crate) mod test_bitwise_op; pub(crate) mod test_cmux; pub(crate) mod test_comparison; pub(crate) mod test_div_mod; +pub(crate) mod test_ilog2; pub(crate) mod test_mul; pub(crate) mod test_neg; pub(crate) mod test_rotate; @@ -483,22 +484,6 @@ create_parametrized_test!( ); // left/right rotations create_parametrized_test!(integer_trim_radix_msb_blocks_handles_dirty_inputs); -create_parametrized_test!(integer_default_trailing_zeros); -create_parametrized_test!(integer_default_trailing_ones); -create_parametrized_test!(integer_default_leading_zeros); -create_parametrized_test!(integer_default_leading_ones); -create_parametrized_test!(integer_default_ilog2); -create_parametrized_test!(integer_default_checked_ilog2 { - // This uses comparisons, so require more than 1 bit - PARAM_MESSAGE_2_CARRY_2_KS_PBS, - PARAM_MESSAGE_3_CARRY_3_KS_PBS, - PARAM_MESSAGE_4_CARRY_4_KS_PBS, - 
PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS, - PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS -}); - create_parametrized_test!( integer_full_propagate { coverage => { @@ -700,54 +685,6 @@ where // Default Tests //============================================================================= -fn integer_default_trailing_zeros

(param: P) -where - P: Into, -{ - let executor = CpuFunctionExecutor::new(&ServerKey::trailing_zeros_parallelized); - default_trailing_zeros_test(param, executor); -} - -fn integer_default_trailing_ones

(param: P) -where - P: Into, -{ - let executor = CpuFunctionExecutor::new(&ServerKey::trailing_ones_parallelized); - default_trailing_ones_test(param, executor); -} - -fn integer_default_leading_zeros

(param: P) -where - P: Into, -{ - let executor = CpuFunctionExecutor::new(&ServerKey::leading_zeros_parallelized); - default_leading_zeros_test(param, executor); -} - -fn integer_default_leading_ones

(param: P) -where - P: Into, -{ - let executor = CpuFunctionExecutor::new(&ServerKey::leading_ones_parallelized); - default_leading_ones_test(param, executor); -} - -fn integer_default_ilog2

(param: P) -where - P: Into, -{ - let executor = CpuFunctionExecutor::new(&ServerKey::ilog2_parallelized); - default_ilog2_test(param, executor); -} - -fn integer_default_checked_ilog2

(param: P) -where - P: Into, -{ - let executor = CpuFunctionExecutor::new(&ServerKey::checked_ilog2_parallelized); - default_checked_ilog2_test(param, executor); -} - #[test] #[cfg(not(tarpaulin))] fn test_non_regression_clone_from() { diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_ilog2.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_ilog2.rs new file mode 100644 index 0000000000..9673707856 --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_ilog2.rs @@ -0,0 +1,488 @@ +use crate::integer::keycache::KEY_CACHE; +use crate::integer::server_key::radix_parallel::ilog2::{BitValue, Direction}; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::server_key::radix_parallel::tests_unsigned::{ + nb_tests_smaller_for_params, random_non_zero_value, CpuFunctionExecutor, NB_CTXT, +}; +use crate::integer::tests::create_parametrized_test; +use crate::integer::{BooleanBlock, IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey}; +#[cfg(tarpaulin)] +use crate::shortint::parameters::coverage_parameters::*; +use crate::shortint::parameters::*; +use rand::Rng; +use std::sync::Arc; + +create_parametrized_test!(integer_default_trailing_zeros); +create_parametrized_test!(integer_default_trailing_ones); +create_parametrized_test!(integer_default_leading_zeros); +create_parametrized_test!(integer_default_leading_ones); +create_parametrized_test!(integer_default_ilog2); +create_parametrized_test!(integer_default_checked_ilog2 { + // This uses comparisons, so require more than 1 bit + PARAM_MESSAGE_2_CARRY_2_KS_PBS, + PARAM_MESSAGE_3_CARRY_3_KS_PBS, + PARAM_MESSAGE_4_CARRY_4_KS_PBS, + PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS, + PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, + PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS, + PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS +}); + +fn integer_default_trailing_zeros

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::trailing_zeros_parallelized); + default_trailing_zeros_test(param, executor); +} + +fn integer_default_trailing_ones

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::trailing_ones_parallelized); + default_trailing_ones_test(param, executor); +} + +fn integer_default_leading_zeros

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::leading_zeros_parallelized); + default_leading_zeros_test(param, executor); +} + +fn integer_default_leading_ones

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::leading_ones_parallelized); + default_leading_ones_test(param, executor); +} + +fn integer_default_ilog2

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::ilog2_parallelized); + default_ilog2_test(param, executor); +} + +fn integer_default_checked_ilog2

(param: P) +where + P: Into, +{ + let executor = CpuFunctionExecutor::new(&ServerKey::checked_ilog2_parallelized); + default_checked_ilog2_test(param, executor); +} + +pub(crate) fn default_count_consecutive_bits_test( + direction: Direction, + bit_value: BitValue, + param: P, + mut executor: T, +) where + P: Into, + T: for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>, +{ + let param = param.into(); + let nb_tests_smaller = nb_tests_smaller_for_params(param); + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; + + executor.setup(&cks, sks.clone()); + + let num_bits = NB_CTXT as u32 * cks.parameters().message_modulus().0.ilog2(); + + let compute_expected_clear = |x: u64| match (direction, bit_value) { + (Direction::Trailing, BitValue::Zero) => { + if x == 0 { + num_bits + } else { + x.trailing_zeros() + } + } + (Direction::Trailing, BitValue::One) => x.trailing_ones(), + (Direction::Leading, BitValue::Zero) => { + if x == 0 { + num_bits + } else { + (x << (u64::BITS - num_bits)).leading_zeros() + } + } + (Direction::Leading, BitValue::One) => (x << (u64::BITS - num_bits)).leading_ones(), + }; + + let method_name = match (direction, bit_value) { + (Direction::Trailing, BitValue::Zero) => "trailing_zeros", + (Direction::Trailing, BitValue::One) => "trailing_ones", + (Direction::Leading, BitValue::Zero) => "leading_zeros", + (Direction::Leading, BitValue::One) => "leading_ones", + }; + + let input_values = [0u64, modulus - 1] + .into_iter() + .chain((0..nb_tests_smaller).map(|_| rng.gen::() % modulus)) + .collect::>(); + + for clear in input_values { + let ctxt = cks.encrypt(clear); + + let ct_res = executor.execute(&ctxt); + let tmp = 
executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = compute_expected_clear(clear); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for {method_name}, for {clear}.{method_name}() \ + expected {expected_result}, got {decrypted_result}" + ); + + for _ in 0..nb_tests_smaller { + // Add non-zero scalar to have non-clean ciphertexts + let clear_2 = random_non_zero_value(&mut rng, modulus); + + let ctxt = sks.unchecked_scalar_add(&ctxt, clear_2); + + let clear = clear.wrapping_add(clear_2) % modulus; + + let d0: u64 = cks.decrypt(&ctxt); + assert_eq!(d0, clear, "Failed sanity decryption check"); + + let ct_res = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + + let expected_result = compute_expected_clear(clear); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for {method_name}, for {clear}.{method_name}() \ + expected {expected_result}, got {decrypted_result}" + ); + } + } + + let input_values = [0u64, modulus - 1] + .into_iter() + .chain((0..nb_tests_smaller).map(|_| rng.gen::() % modulus)); + + for clear in input_values { + let ctxt = sks.create_trivial_radix(clear, NB_CTXT); + + let ct_res = executor.execute(&ctxt); + let tmp = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = compute_expected_clear(clear); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for {method_name}, for {clear}.{method_name}() \ + expected {expected_result}, got {decrypted_result}" + ); + } +} + +pub(crate) fn default_trailing_zeros_test(param: P, executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>, +{ + default_count_consecutive_bits_test(Direction::Trailing, 
BitValue::Zero, param, executor); +} + +pub(crate) fn default_trailing_ones_test(param: P, executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>, +{ + default_count_consecutive_bits_test(Direction::Trailing, BitValue::One, param, executor); +} + +pub(crate) fn default_leading_zeros_test(param: P, executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>, +{ + default_count_consecutive_bits_test(Direction::Leading, BitValue::Zero, param, executor); +} + +pub(crate) fn default_leading_ones_test(param: P, executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>, +{ + default_count_consecutive_bits_test(Direction::Leading, BitValue::One, param, executor); +} + +pub(crate) fn default_ilog2_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>, +{ + let param = param.into(); + let nb_tests_smaller = nb_tests_smaller_for_params(param); + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; + + executor.setup(&cks, sks.clone()); + + let num_bits = NB_CTXT as u32 * cks.parameters().message_modulus().0.ilog2(); + + // Test with invalid input + { + let ctxt = cks.encrypt(0u64); + + let ct_res = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let counter_num_blocks = ((num_bits - 1).ilog2() + 1 + 1) + .div_ceil(cks.parameters().message_modulus().0.ilog2()) + as usize; + let expected_result = (1u32 + << (counter_num_blocks as u32 * cks.parameters().message_modulus().0.ilog2())) + - 1; + assert_eq!( + 
decrypted_result, expected_result, + "Invalid result for ilog2 for 0.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + } + + let input_values = (0..num_bits) + .map(|i| 1 << i) + .chain( + (0..nb_tests_smaller.saturating_sub(num_bits as usize)) + .map(|_| rng.gen_range(1..modulus)), + ) + .collect::>(); + + for clear in input_values { + let ctxt = cks.encrypt(clear); + + let ct_res = executor.execute(&ctxt); + let tmp = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = clear.ilog2(); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2 for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + + for _ in 0..nb_tests_smaller { + // Add non-zero scalar to have non-clean ciphertexts + // But here, we have to make sure clear is still > 0 + // as we are only testing valid ilog2 inputs + let (clear, clear_2) = loop { + let clear_2 = random_non_zero_value(&mut rng, modulus); + let clear = clear_2.wrapping_add(clear) % modulus; + if clear != 0 { + break (clear, clear_2); + } + }; + + let ctxt = sks.unchecked_scalar_add(&ctxt, clear_2); + + let d0: u64 = cks.decrypt(&ctxt); + assert_eq!(d0, clear, "Failed sanity decryption check"); + + let ct_res = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + + let expected_result = clear.ilog2(); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2, for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + } + } + + let input_values = (0..num_bits) + .map(|i| 1 << i) + .chain( + (0..nb_tests_smaller.saturating_sub(num_bits as usize)) + .map(|_| rng.gen_range(1..modulus)), + ) + .collect::>(); + + for clear in input_values { + let ctxt = sks.create_trivial_radix(clear, NB_CTXT); + + let ct_res = 
executor.execute(&ctxt); + let tmp = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = clear.ilog2(); + + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2, for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + } +} + +pub(crate) fn default_checked_ilog2_test(param: P, mut executor: T) +where + P: Into, + T: for<'a> FunctionExecutor<&'a RadixCiphertext, (RadixCiphertext, BooleanBlock)>, +{ + let param = param.into(); + let nb_tests_smaller = nb_tests_smaller_for_params(param); + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + let cks = RadixClientKey::from((cks, NB_CTXT)); + + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + + let mut rng = rand::thread_rng(); + + // message_modulus^vec_length + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; + + executor.setup(&cks, sks.clone()); + + let num_bits = NB_CTXT as u32 * cks.parameters().message_modulus().0.ilog2(); + + // Test with invalid input + { + let ctxt = cks.encrypt(0u64); + + let (ct_res, is_ok) = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(is_ok.as_ref().degree.get(), 1); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let counter_num_blocks = ((num_bits - 1).ilog2() + 1 + 1) + .div_ceil(cks.parameters().message_modulus().0.ilog2()) + as usize; + let expected_result = (1u32 + << (counter_num_blocks as u32 * cks.parameters().message_modulus().0.ilog2())) + - 1; + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2 for 0.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + let is_ok = cks.decrypt_bool(&is_ok); + assert!(!is_ok); + } + + let input_values = (0..num_bits) + .map(|i| 1 << i) + .chain( + (0..nb_tests_smaller.saturating_sub(num_bits as usize)) + 
.map(|_| rng.gen_range(1..modulus)), + ) + .collect::>(); + + for clear in input_values { + let ctxt = cks.encrypt(clear); + + let (ct_res, is_ok) = executor.execute(&ctxt); + let (tmp, tmp_is_ok) = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(ct_res, tmp); + assert_eq!(is_ok, tmp_is_ok); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = clear.ilog2(); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2 for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + let is_ok = cks.decrypt_bool(&is_ok); + assert!(is_ok); + + for _ in 0..nb_tests_smaller { + // Add non-zero scalar to have non-clean ciphertexts + // But here, we have to make sure clear is still > 0 + // as we are only testing valid ilog2 inputs + let (clear, clear_2) = loop { + let clear_2 = random_non_zero_value(&mut rng, modulus); + let clear = clear_2.wrapping_add(clear) % modulus; + if clear != 0 { + break (clear, clear_2); + } + }; + + let ctxt = sks.unchecked_scalar_add(&ctxt, clear_2); + + let d0: u64 = cks.decrypt(&ctxt); + assert_eq!(d0, clear, "Failed sanity decryption check"); + + let (ct_res, is_ok) = executor.execute(&ctxt); + assert!(ct_res.block_carries_are_empty()); + assert_eq!(is_ok.as_ref().degree.get(), 1); + + let expected_result = clear.ilog2(); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2, for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + let is_ok = cks.decrypt_bool(&is_ok); + assert!(is_ok); + } + } + + let input_values = (0..num_bits) + .map(|i| 1 << i) + .chain( + (0..nb_tests_smaller.saturating_sub(num_bits as usize)) + .map(|_| rng.gen_range(1..modulus)), + ) + .collect::>(); + + for clear in input_values { + let ctxt = sks.create_trivial_radix(clear, NB_CTXT); + + let (ct_res, is_ok) = executor.execute(&ctxt); + 
assert!(ct_res.block_carries_are_empty()); + + let decrypted_result: u32 = cks.decrypt(&ct_res); + let expected_result = clear.ilog2(); + + assert_eq!( + decrypted_result, expected_result, + "Invalid result for ilog2, for {clear}.ilog2() \ + expected {expected_result}, got {decrypted_result}" + ); + let is_ok = cks.decrypt_bool(&is_ok); + assert!(is_ok); + } +}