Skip to content

Commit

Permalink
chore(docs): simplify improved formula in dark market
Browse files Browse the repository at this point in the history
  • Loading branch information
mayeul-zama committed Sep 13, 2023
1 parent b752837 commit 622287c
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 123 deletions.
212 changes: 116 additions & 96 deletions tfhe/docs/application_tutorials/dark_market.md
Original file line number Diff line number Diff line change
Expand Up @@ -331,14 +331,16 @@ We will call the new list the "prefix sum" of the array.

The new version for the plain `fill_orders` is as follows:
```rust
let fill_orders = |orders: &mut [u64], prefix_sum: &[u64], total_orders: u64|{
fn fill_orders(total_orders: u16, orders: &mut [u16], prefix_sum_arr: &[u16]) {
orders.iter().for_each(|order : &mut u64| {
if (total_orders >= prefix_sum[i]) {
continue;
} else if total_orders >= prefix_sum.get(i-1).unwrap_or(0) {
*order = total_orders - prefix_sum.get(i-1).unwrap_or(0);
} else {
let diff = total_orders as i64 - *prefix_sum_arr.get(i - 1).unwrap_or(&0) as i64;
if (diff < 0) {
*order = 0;
} else if diff < order {
*order = diff as u16;
} else {
// *order = *order;
continue;
}
});
};
Expand All @@ -347,11 +349,13 @@ let fill_orders = |orders: &mut [u64], prefix_sum: &[u64], total_orders: u64|{
To write this new function we need transform the conditional code into a mathematical expression since FHE does not support conditional operations.
```rust

let fill_orders = |orders: &mut [u64], prefix_sum: &[u64], total_orders: u64| {
orders.iter().for_each(|order| : &mut){
*order = *order + ((total_orders >= prefix_sum - std::cmp::min(total_orders, prefix_sum.get(i - 1).unwrap_or(&0).clone()) - *order);
fn fill_orders(total_orders: u16, orders: &mut [u16], prefix_sum_arr: &[u16]) {
for (i, order) in orders.iter_mut().enumerate() {
*order = (total_orders as i64 - *prefix_sum_arr.get(i - 1).unwrap_or(&0) as i64)
.max(0)
.min(*order as i64) as u16;
}
};
}
```

New `fill_order` function requires a prefix sum array. We are going to calculate this prefix sum array in parallel
Expand All @@ -364,108 +368,124 @@ So we modify how the algorithm is implemented, but we don't change the algorithm

Here is the modified version of the algorithm in TFHE-rs:
```rust
fn volume_match_fhe_modified(
fn compute_prefix_sum(server_key: &ServerKey, arr: &[RadixCiphertext]) -> Vec<RadixCiphertext> {
if arr.is_empty() {
return arr.to_vec();
}
let mut prefix_sum: Vec<RadixCiphertext> = (0..arr.len().next_power_of_two())
.into_par_iter()
.map(|i| {
if i < arr.len() {
arr[i].clone()
} else {
server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS)
}
})
.collect();
for d in 0..prefix_sum.len().ilog2() {
prefix_sum
.par_chunks_exact_mut(2_usize.pow(d + 1))
.for_each(move |chunk| {
let length = chunk.len();
let mut left = chunk.get((length - 1) / 2).unwrap().clone();
server_key.smart_add_assign_parallelized(chunk.last_mut().unwrap(), &mut left)
});
}
let last = prefix_sum.last().unwrap().clone();
*prefix_sum.last_mut().unwrap() = server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS);
for d in (0..prefix_sum.len().ilog2()).rev() {
prefix_sum
.par_chunks_exact_mut(2_usize.pow(d + 1))
.for_each(move |chunk| {
let length = chunk.len();
let temp = chunk.last().unwrap().clone();
let mut mid = chunk.get((length - 1) / 2).unwrap().clone();
server_key.smart_add_assign_parallelized(chunk.last_mut().unwrap(), &mut mid);
chunk[(length - 1) / 2] = temp;
});
}
prefix_sum.push(last);
prefix_sum[1..=arr.len()].to_vec()
}

fn fill_orders(
server_key: &ServerKey,
total_orders: &RadixCiphertext,
orders: &mut [RadixCiphertext],
prefix_sum_arr: &[RadixCiphertext],
) {
orders
.into_par_iter()
.enumerate()
.for_each(move |(i, order)| {
// (total_orders - prefix_sum).max(0)
let mut diff = if let Some(prefix_sum) = prefix_sum_arr.get(i - 1) {
// total_orders - prefix_sum
let mut diff = server_key
.smart_sub_parallelized(&mut total_orders.clone(), &mut prefix_sum.clone());

// total_orders > prefix_sum

let mut cond = server_key
.smart_gt_parallelized(&mut total_orders.clone(), &mut prefix_sum.clone());

// (total_orders - prefix_sum) * (total_orders > prefix_sum)
// = (total_orders - prefix_sum).max(0)
server_key.smart_mul_parallelized(&mut cond, &mut diff)
} else {
total_orders.clone()
};

// (total_orders - prefix_sum.get(i - 1)).max(0).min(*order);
*order = server_key.smart_min_parallelized(&mut diff, order);
});
}

/// FHE implementation of the volume matching algorithm.
///
/// In this function, the implemented algorithm is modified to utilize more concurrency.
///
/// Matches the given encrypted [sell_orders] with encrypted [buy_orders] using the given
/// [server_key]. The amount of the orders that are successfully filled is written over the original
/// order count.
pub fn volume_match(
sell_orders: &mut [RadixCiphertext],
buy_orders: &mut [RadixCiphertext],
server_key: &ServerKey,
) {
let compute_prefix_sum = |arr: &[RadixCiphertext]| {
if arr.is_empty() {
return arr.to_vec();
}
let mut prefix_sum: Vec<RadixCiphertext> = (0..arr.len().next_power_of_two())
.into_par_iter()
.map(|i| {
if i < arr.len() {
arr[i].clone()
} else {
server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS)
}
})
.collect();
// Up sweep
for d in 0..(prefix_sum.len().ilog2() as u32) {
prefix_sum
.par_chunks_exact_mut(2_usize.pow(d + 1))
.for_each(move |chunk| {
let length = chunk.len();
let mut left = chunk.get((length - 1) / 2).unwrap().clone();
server_key.smart_add_assign_parallelized(chunk.last_mut().unwrap(), &mut left)
});
}
// Down sweep
let last = prefix_sum.last().unwrap().clone();
*prefix_sum.last_mut().unwrap() = server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS);
for d in (0..(prefix_sum.len().ilog2() as u32)).rev() {
prefix_sum
.par_chunks_exact_mut(2_usize.pow(d + 1))
.for_each(move |chunk| {
let length = chunk.len();
let t = chunk.last().unwrap().clone();
let mut left = chunk.get((length - 1) / 2).unwrap().clone();
server_key.smart_add_assign_parallelized(chunk.last_mut().unwrap(), &mut left);
chunk[(length - 1) / 2] = t;
});
}
prefix_sum.push(last);
prefix_sum[1..=arr.len()].to_vec()
};

println!("Creating prefix sum arrays...");
let time = Instant::now();
let (prefix_sum_sell_orders, prefix_sum_buy_orders) = rayon::join(
|| compute_prefix_sum(sell_orders),
|| compute_prefix_sum(buy_orders),
|| compute_prefix_sum(server_key, sell_orders),
|| compute_prefix_sum(server_key, buy_orders),
);
println!("Created prefix sum arrays in {:?}", time.elapsed());

let fill_orders = |total_orders: &RadixCiphertext,
orders: &mut [RadixCiphertext],
prefix_sum_arr: &[RadixCiphertext]| {
orders
.into_par_iter()
.enumerate()
.for_each(move |(i, order)| {
server_key.smart_add_assign_parallelized(
order,
&mut server_key.smart_mul_parallelized(
&mut server_key
.smart_ge_parallelized(&mut order.clone(), &mut total_orders.clone()),
&mut server_key.smart_sub_parallelized(
&mut server_key.smart_sub_parallelized(
&mut total_orders.clone(),
&mut server_key.smart_min_parallelized(
&mut total_orders.clone(),
&mut prefix_sum_arr
.get(i - 1)
.unwrap_or(
&server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS),
)
.clone(),
),
),
&mut order.clone(),
),
),
);
});
};
let zero = server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS);

let total_buy_orders = &mut prefix_sum_buy_orders
.last()
.unwrap_or(&server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS))
.clone();
let total_buy_orders = prefix_sum_buy_orders.last().unwrap_or(&zero);

let total_sell_orders = &mut prefix_sum_sell_orders
.last()
.unwrap_or(&server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS))
.clone();
let total_sell_orders = prefix_sum_sell_orders.last().unwrap_or(&zero);

println!("Matching orders...");
let time = Instant::now();
rayon::join(
|| fill_orders(total_sell_orders, buy_orders, &prefix_sum_buy_orders),
|| fill_orders(total_buy_orders, sell_orders, &prefix_sum_sell_orders),
|| {
fill_orders(
server_key,
total_sell_orders,
buy_orders,
&prefix_sum_buy_orders,
)
},
|| {
fill_orders(
server_key,
total_buy_orders,
sell_orders,
&prefix_sum_sell_orders,
)
},
);
println!("Matched orders in {:?}", time.elapsed());
}
Expand Down
42 changes: 20 additions & 22 deletions tfhe/examples/dark_market/improved_parallel_fhe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,28 +57,26 @@ fn fill_orders(
.into_par_iter()
.enumerate()
.for_each(move |(i, order)| {
server_key.smart_add_assign_parallelized(
order,
&mut server_key.smart_mul_parallelized(
&mut server_key
.smart_ge_parallelized(&mut order.clone(), &mut total_orders.clone()),
&mut server_key.smart_sub_parallelized(
&mut server_key.smart_sub_parallelized(
&mut total_orders.clone(),
&mut server_key.smart_min_parallelized(
&mut total_orders.clone(),
&mut prefix_sum_arr
.get(i - 1)
.unwrap_or(
&server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS),
)
.clone(),
),
),
&mut order.clone(),
),
),
);
// (total_orders - prefix_sum).max(0)
let mut diff = if let Some(prefix_sum) = prefix_sum_arr.get(i - 1) {
// total_orders - prefix_sum
let mut diff = server_key
.smart_sub_parallelized(&mut total_orders.clone(), &mut prefix_sum.clone());

// total_orders > prefix_sum

let mut cond = server_key
.smart_gt_parallelized(&mut total_orders.clone(), &mut prefix_sum.clone());

// (total_orders - prefix_sum) * (total_orders > prefix_sum)
// = (total_orders - prefix_sum).max(0)
server_key.smart_mul_parallelized(&mut cond, &mut diff)
} else {
total_orders.clone()
};

// (total_orders - prefix_sum.get(i - 1)).max(0).min(*order);
*order = server_key.smart_min_parallelized(&mut diff, order);
});
}

Expand Down
30 changes: 30 additions & 0 deletions tfhe/examples/dark_market/improved_plain.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
fn compute_prefix_sum(arr: &[u16]) -> Vec<u16> {
let mut sum = 0;
arr.iter()
.map(|a| {
sum += a;
sum
})
.collect()
}

fn fill_orders(total_orders: u16, orders: &mut [u16], prefix_sum_arr: &[u16]) {
for (i, order) in orders.iter_mut().enumerate() {
*order = (total_orders as i64 - *prefix_sum_arr.get(i - 1).unwrap_or(&0) as i64)
.max(0)
.min(*order as i64) as u16;
}
}

pub fn volume_match(sell_orders: &mut [u16], buy_orders: &mut [u16]) {
let prefix_sum_sell_orders = compute_prefix_sum(sell_orders);

let prefix_sum_buy_orders = compute_prefix_sum(buy_orders);

let total_buy_orders = *prefix_sum_buy_orders.last().unwrap_or(&0);

let total_sell_orders = *prefix_sum_sell_orders.last().unwrap_or(&0);

fill_orders(total_sell_orders, buy_orders, &prefix_sum_buy_orders);
fill_orders(total_buy_orders, sell_orders, &prefix_sum_sell_orders);
}
14 changes: 10 additions & 4 deletions tfhe/examples/dark_market/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;

mod fhe;
mod improved_parallel_fhe;
mod improved_plain;
mod parallel_fhe;
mod plain;

Expand Down Expand Up @@ -62,9 +63,11 @@ fn run_test_cases(tester: impl Fn(&[u16], &[u16], &[u16], &[u16])) {
}
}

/// Runs the test cases for the fhe implementation of the volume matching algorithm.
///
/// [parallelized] indicates whether the fhe implementation should be run in parallel.
fn test_volume_match_plain(function: fn(&mut [u16], &mut [u16])) {
println!("Running test cases for the plain implementation");
run_test_cases(|a, b, c, d| plain::tester(a, b, c, d, function));
}

fn test_volume_match_fhe(
fhe_function: fn(&mut [RadixCiphertext], &mut [RadixCiphertext], &ServerKey),
) {
Expand All @@ -81,9 +84,12 @@ fn main() {
for argument in std::env::args() {
if argument == "plain" {
println!("Running plain version");
run_test_cases(plain::tester);
test_volume_match_plain(plain::volume_match);
println!();
}
if argument == "plain-improved" {
println!("Running plain improved version");
test_volume_match_plain(improved_plain::volume_match);
println!();
}
if argument == "fhe" {
Expand Down
3 changes: 2 additions & 1 deletion tfhe/examples/dark_market/plain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@ pub fn tester(
input_buy_orders: &[u16],
expected_filled_sells: &[u16],
expected_filled_buys: &[u16],
function: fn(&mut [u16], &mut [u16]),
) {
let mut sell_orders = input_sell_orders.to_vec();
let mut buy_orders = input_buy_orders.to_vec();

println!("Running plain implementation...");
let time = Instant::now();
volume_match(&mut sell_orders, &mut buy_orders);
function(&mut sell_orders, &mut buy_orders);
println!("Ran plain implementation in {:?}", time.elapsed());

assert_eq!(sell_orders, expected_filled_sells);
Expand Down

0 comments on commit 622287c

Please sign in to comment.