Skip to content

Commit

Permalink
feat(integer): improve scalar_mul
Browse files Browse the repository at this point in the history
This changes the algorithm for scalar_mul.
The new algorithm allows to remove a lot of work.

For small precisions (16, 32, 64) the gains are in range 5%-10%
for higher precisions the gains are 25%-50%.

This also changes the mul to use the functions that sums many
clean ciphertexts in parallel. For mul, there is only a 5%-10%
improvements for 128bits and 256bits mul.
  • Loading branch information
tmontaigu committed Sep 20, 2023
1 parent 53da809 commit 9225d65
Show file tree
Hide file tree
Showing 4 changed files with 268 additions and 235 deletions.
175 changes: 174 additions & 1 deletion tfhe/src/integer/server_key/radix_parallel/add.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::integer::ciphertext::IntegerRadixCiphertext;
use crate::integer::ServerKey;
use crate::integer::{RadixCiphertext, ServerKey};
use crate::shortint::Ciphertext;

use rayon::prelude::*;
Expand Down Expand Up @@ -740,4 +740,177 @@ impl ServerKey {

reduce_impl(self, ct_seq, op)
}

/// See [Self::unchecked_sum_ciphertexts_vec_parallelized] for constraints
pub fn unchecked_sum_ciphertexts_slice_parallelized(
&self,
ciphertexts: &[RadixCiphertext],
) -> Option<RadixCiphertext> {
self.unchecked_sum_ciphertexts_vec_parallelized(ciphertexts.to_vec())
}

/// Computes the sum of the ciphertexts in parallel.
///
/// - Returns None if ciphertexts is empty
///
/// - Expexts all ciphertexts to have empty carries
/// - Expects all ciphertexts to have the same size
pub fn unchecked_sum_ciphertexts_vec_parallelized<T>(
&self,
mut ciphertexts: Vec<T>,
) -> Option<T>
where
T: IntegerRadixCiphertext,
{
if ciphertexts.is_empty() {
return None;
}

if ciphertexts.len() == 1 {
return Some(ciphertexts[0].clone());
}

if ciphertexts.len() == 2 {
return Some(self.add_parallelized(&ciphertexts[0], &ciphertexts[1]));
}

let num_blocks = ciphertexts[0].blocks().len();
assert!(
ciphertexts[1..]
.iter()
.all(|ct| ct.blocks().len() == num_blocks),
"Not all ciphertexts have the same number of blocks"
);
assert!(
ciphertexts.iter().all(|ct| ct.block_carries_are_empty()),
"All ciphertexts must have empty carries"
);

let message_modulus = self.key.message_modulus.0;
let carry_modulus = self.key.carry_modulus.0;
let total_modulus = message_modulus * carry_modulus;
let message_max = message_modulus - 1;

let num_elements_to_fill_carry = (total_modulus - 1) / message_max;

let mut tmp_out = Vec::new();

while ciphertexts.len() > num_elements_to_fill_carry {
let mut chunks_iter = ciphertexts.par_chunks_exact_mut(num_elements_to_fill_carry);
let remainder_len = chunks_iter.remainder().len();

chunks_iter
.map(|chunk| {
let (s, rest) = chunk.split_first_mut().unwrap();
let mut first_block_where_addition_happenned = num_blocks - 1;
let mut last_block_where_addition_happenned = num_blocks - 1;
for a in rest.iter() {
let first_block_to_add = a
.blocks()
.iter()
.position(|block| block.degree.0 != 0)
.unwrap_or(num_blocks);
first_block_where_addition_happenned =
first_block_where_addition_happenned.min(first_block_to_add);
let last_block_to_add = a
.blocks()
.iter()
.rev()
.position(|block| block.degree.0 != 0)
.map(|pos| num_blocks - pos - 1)
.unwrap_or(num_blocks - 1);
last_block_where_addition_happenned =
last_block_where_addition_happenned.max(last_block_to_add);
for (ct_left_i, ct_right_i) in s.blocks_mut()
[first_block_to_add..last_block_to_add + 1]
.iter_mut()
.zip(a.blocks()[first_block_to_add..last_block_to_add + 1].iter())
{
self.key.unchecked_add_assign(ct_left_i, ct_right_i);
}
}

// last carry is not interesting
let mut carry_blocks = s.blocks()
[first_block_where_addition_happenned..last_block_where_addition_happenned]
.to_vec();

let message_blocks = s.blocks_mut();

rayon::join(
|| {
message_blocks[first_block_where_addition_happenned
..last_block_where_addition_happenned + 1]
.par_iter_mut()
.for_each(|block| {
self.key.message_extract_assign(block);
});
},
|| {
carry_blocks.par_iter_mut().for_each(|block| {
self.key.carry_extract_assign(block);
});
},
);

let mut carry_ct = RadixCiphertext::from(carry_blocks);
let num_blocks_to_add = s.blocks().len() - carry_ct.blocks.len();
self.extend_radix_with_trivial_zero_blocks_lsb_assign(&mut carry_ct, num_blocks_to_add);
let carry_ct = T::from(carry_ct.blocks);
(s.clone(), carry_ct)
})
.collect_into_vec(&mut tmp_out);

// tmp_out elements are tuple of 2 elements (message, carry)
let num_ct_created = tmp_out.len() * 2;
// Ciphertexts not treated in this iteration are at the end of ciphertexts vec.
// the rotation will make them 'wrap around' and be placed at range index
// (num_ct_created..remainder_len + num_ct_created)
// We will then fill the indices in range (0..num_ct_created)
ciphertexts.rotate_right(remainder_len + num_ct_created);

// Drain elements out of tmp_out to replace them
// at the beginning of the ciphertexts left to add
for (i, (m, c)) in tmp_out.drain(..).enumerate() {
ciphertexts[i * 2] = m;
ciphertexts[(i * 2) + 1] = c;
}
ciphertexts.truncate(num_ct_created + remainder_len);
}

// Now we will add the last chunk of terms
// just as was done above, however we do it
// we want to use an addition that leaves
// the resulting ciphertext with empty carries
let (result, rest) = ciphertexts.split_first_mut().unwrap();
for term in rest.iter() {
self.unchecked_add_assign(result, term);
}

let (message_blocks, carry_blocks) = rayon::join(
|| {
result
.blocks()
.par_iter()
.map(|block| self.key.message_extract(block))
.collect::<Vec<_>>()
},
|| {
let mut carry_blocks = Vec::with_capacity(num_blocks);
result.blocks()[..num_blocks - 1] // last carry is not interesting
.par_iter()
.map(|block| self.key.carry_extract(block))
.collect_into_vec(&mut carry_blocks);
carry_blocks.insert(0, self.key.create_trivial(0));
carry_blocks
},
);

let mut result = T::from(message_blocks);
let carry = T::from(carry_blocks);
self.add_assign_parallelized(&mut result, &carry);
assert!(result.block_carries_are_empty());

Some(result)
}
}
Loading

0 comments on commit 9225d65

Please sign in to comment.