Skip to content

Commit

Permalink
expose partial matches
Browse files Browse the repository at this point in the history
  • Loading branch information
philsippl authored and eaypek-tfh committed Oct 25, 2024
1 parent d117a4d commit 5873808
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 32 deletions.
3 changes: 3 additions & 0 deletions iris-mpc-common/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ pub struct Config {
#[serde(default)]
pub fake_db_size: usize,

#[serde(default)]
pub return_partial_results: bool,

#[serde(default)]
pub disable_persistence: bool,
}
Expand Down
16 changes: 11 additions & 5 deletions iris-mpc-common/src/helpers/smpc_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,11 +302,13 @@ impl UniquenessRequest {

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct UniquenessResult {
pub node_id: usize,
pub serial_id: Option<u32>,
pub is_match: bool,
pub signup_id: String,
pub matched_serial_ids: Option<Vec<u32>>,
pub node_id: usize,
pub serial_id: Option<u32>,
pub is_match: bool,
pub signup_id: String,
pub matched_serial_ids: Option<Vec<u32>>,
pub matched_serial_ids_left: Option<Vec<u32>>,
pub matched_serial_ids_right: Option<Vec<u32>>,
}

impl UniquenessResult {
Expand All @@ -316,13 +318,17 @@ impl UniquenessResult {
is_match: bool,
signup_id: String,
matched_serial_ids: Option<Vec<u32>>,
matched_serial_ids_left: Option<Vec<u32>>,
matched_serial_ids_right: Option<Vec<u32>>,
) -> Self {
Self {
node_id,
serial_id,
is_match,
signup_id,
matched_serial_ids,
matched_serial_ids_left,
matched_serial_ids_right,
}
}
}
Expand Down
46 changes: 39 additions & 7 deletions iris-mpc-gpu/src/dot/distance_comparator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,26 @@ pub struct DistanceComparator {
pub final_results_init_host: Vec<u32>,
pub match_counters: Vec<CudaSlice<u32>>,
pub all_matches: Vec<CudaSlice<u32>>,
pub match_counters_left: Vec<CudaSlice<u32>>,
pub match_counters_right: Vec<CudaSlice<u32>>,
pub partial_results_left: Vec<CudaSlice<u32>>,
pub partial_results_right: Vec<CudaSlice<u32>>,
}

impl DistanceComparator {
pub fn init(query_length: usize, device_manager: Arc<DeviceManager>) -> Self {
let ptx = compile_ptx(PTX_SRC).unwrap();
let mut open_kernels = Vec::new();
let mut open_kernels: Vec<CudaFunction> = Vec::new();
let mut merge_db_kernels = Vec::new();
let mut merge_batch_kernels = Vec::new();
let mut opened_results = vec![];
let mut final_results = vec![];
let mut match_counters: Vec<CudaSlice<u32>> = vec![];
let mut all_matches: Vec<CudaSlice<u32>> = vec![];
let mut match_counters = vec![];
let mut match_counters_left = vec![];
let mut match_counters_right = vec![];
let mut all_matches = vec![];
let mut partial_results_left = vec![];
let mut partial_results_right = vec![];

let devices_count = device_manager.device_count();

Expand All @@ -63,11 +71,23 @@ impl DistanceComparator {
opened_results.push(device.htod_copy(results_init_host.clone()).unwrap());
final_results.push(device.htod_copy(final_results_init_host.clone()).unwrap());
match_counters.push(device.alloc_zeros(query_length / ROTATIONS).unwrap());
match_counters_left.push(device.alloc_zeros(query_length / ROTATIONS).unwrap());
match_counters_right.push(device.alloc_zeros(query_length / ROTATIONS).unwrap());
all_matches.push(
device
.alloc_zeros(ALL_MATCHES_LEN * query_length / ROTATIONS)
.unwrap(),
);
partial_results_left.push(
device
.alloc_zeros(ALL_MATCHES_LEN * query_length / ROTATIONS)
.unwrap(),
);
partial_results_right.push(
device
.alloc_zeros(ALL_MATCHES_LEN * query_length / ROTATIONS)
.unwrap(),
);

open_kernels.push(open_results_function);
merge_db_kernels.push(merge_db_results_function);
Expand All @@ -85,7 +105,11 @@ impl DistanceComparator {
results_init_host,
final_results_init_host,
match_counters,
match_counters_left,
match_counters_right,
all_matches,
partial_results_left,
partial_results_right,
}
}

Expand Down Expand Up @@ -213,6 +237,10 @@ impl DistanceComparator {
num_elements as u64,
&self.match_counters[i],
&self.all_matches[i],
&self.match_counters_left[i],
&self.match_counters_right[i],
&self.partial_results_left[i],
&self.partial_results_right[i],
),
)
.unwrap();
Expand All @@ -233,26 +261,30 @@ impl DistanceComparator {
results
}

pub fn fetch_match_counters(&self) -> Vec<Vec<u32>> {
pub fn fetch_match_counters(&self, counters: &[CudaSlice<u32>]) -> Vec<Vec<u32>> {
let mut results = vec![];
for i in 0..self.device_manager.device_count() {
results.push(
self.device_manager
.device(i)
.dtoh_sync_copy(&self.match_counters[i])
.dtoh_sync_copy(&counters[i])
.unwrap(),
);
}
results
}

pub fn fetch_all_match_ids(&self, match_counters: Vec<Vec<u32>>) -> Vec<Vec<u32>> {
pub fn fetch_all_match_ids(
&self,
match_counters: Vec<Vec<u32>>,
matches: &[CudaSlice<u32>],
) -> Vec<Vec<u32>> {
let mut results = vec![];
for i in 0..self.device_manager.device_count() {
results.push(
self.device_manager
.device(i)
.dtoh_sync_copy(&self.all_matches[i])
.dtoh_sync_copy(&matches[i])
.unwrap(),
);
}
Expand Down
18 changes: 16 additions & 2 deletions iris-mpc-gpu/src/dot/kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ extern "C" __global__ void openResults(unsigned long long *result1, unsigned lon
}
}

extern "C" __global__ void mergeDbResults(unsigned long long *matchResultsLeft, unsigned long long *matchResultsRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements, unsigned int *matchCounter, unsigned int *allMatches)
extern "C" __global__ void mergeDbResults(unsigned long long *matchResultsLeft, unsigned long long *matchResultsRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements, unsigned int *matchCounter, unsigned int *allMatches, unsigned int *matchCounterLeft, unsigned int *matchCounterRight, unsigned int *partialResultsLeft, unsigned int *partialResultsRight)
{
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numElements)
Expand All @@ -67,6 +67,20 @@ extern "C" __global__ void mergeDbResults(unsigned long long *matchResultsLeft,
if (queryIdx >= queryLength || dbIdx >= dbLength)
continue;

// Check for partial results (only used for debugging)
if (matchLeft)
{
unsigned int queryMatchCounter = atomicAdd(&matchCounterLeft[queryIdx], 1);
if (queryMatchCounter < MAX_MATCHES_LEN)
partialResultsLeft[MAX_MATCHES_LEN * queryIdx + queryMatchCounter] = dbIdx;
}
if (matchRight)
{
unsigned int queryMatchCounter = atomicAdd(&matchCounterRight[queryIdx], 1);
if (queryMatchCounter < MAX_MATCHES_LEN)
partialResultsRight[MAX_MATCHES_LEN * queryIdx + queryMatchCounter] = dbIdx;
}

// Current *AND* policy: only match, if both eyes match
if (matchLeft && matchRight)
{
Expand All @@ -79,7 +93,7 @@ extern "C" __global__ void mergeDbResults(unsigned long long *matchResultsLeft,
}
}

extern "C" __global__ void mergeBatchResults(unsigned long long *matchResultsSelfLeft, unsigned long long *matchResultsSelfRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements, unsigned int *__matchCounter, unsigned int *__allMatches)
extern "C" __global__ void mergeBatchResults(unsigned long long *matchResultsSelfLeft, unsigned long long *matchResultsSelfRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements, unsigned int *__matchCounter, unsigned int *__allMatches, unsigned int *__matchCounterLeft, unsigned int *__matchCounterRight, unsigned int *__partialResultsLeft, unsigned int *__partialResultsRight)
{
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numElements)
Expand Down
56 changes: 52 additions & 4 deletions iris-mpc-gpu/src/server/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ pub struct ServerActor {
query_db_size: Vec<usize>,
max_batch_size: usize,
max_db_size: usize,
return_partial_results: bool,
disable_persistence: bool,
}

Expand All @@ -112,6 +113,7 @@ impl ServerActor {
job_queue_size: usize,
max_db_size: usize,
max_batch_size: usize,
return_partial_results: bool,
disable_persistence: bool,
) -> eyre::Result<(Self, ServerActorHandle)> {
let device_manager = Arc::new(DeviceManager::init());
Expand All @@ -122,6 +124,7 @@ impl ServerActor {
job_queue_size,
max_db_size,
max_batch_size,
return_partial_results,
disable_persistence,
)
}
Expand All @@ -133,6 +136,7 @@ impl ServerActor {
job_queue_size: usize,
max_db_size: usize,
max_batch_size: usize,
return_partial_results: bool,
disable_persistence: bool,
) -> eyre::Result<(Self, ServerActorHandle)> {
let ids = device_manager.get_ids_from_magic(0);
Expand All @@ -145,6 +149,7 @@ impl ServerActor {
job_queue_size,
max_db_size,
max_batch_size,
return_partial_results,
disable_persistence,
)
}
Expand All @@ -158,6 +163,7 @@ impl ServerActor {
job_queue_size: usize,
max_db_size: usize,
max_batch_size: usize,
return_partial_results: bool,
disable_persistence: bool,
) -> eyre::Result<(Self, ServerActorHandle)> {
let (tx, rx) = mpsc::channel(job_queue_size);
Expand All @@ -169,6 +175,7 @@ impl ServerActor {
rx,
max_db_size,
max_batch_size,
return_partial_results,
disable_persistence,
)?;
Ok((actor, ServerActorHandle { job_queue: tx }))
Expand All @@ -183,6 +190,7 @@ impl ServerActor {
job_queue: mpsc::Receiver<ServerJob>,
max_db_size: usize,
max_batch_size: usize,
return_partial_results: bool,
disable_persistence: bool,
) -> eyre::Result<Self> {
assert!(max_batch_size != 0);
Expand Down Expand Up @@ -343,6 +351,7 @@ impl ServerActor {
batch_match_list_right,
max_batch_size,
max_db_size,
return_partial_results,
disable_persistence,
})
}
Expand Down Expand Up @@ -759,7 +768,7 @@ impl ServerActor {
// Fetch and truncate the match counters
let match_counters_devices = self
.distance_comparator
.fetch_match_counters()
.fetch_match_counters(&self.distance_comparator.match_counters)
.into_iter()
.map(|x| x[..batch_size].to_vec())
.collect::<Vec<_>>();
Expand All @@ -776,9 +785,10 @@ impl ServerActor {
});

// Transfer all match ids
let match_ids = self
.distance_comparator
.fetch_all_match_ids(match_counters_devices);
let match_ids = self.distance_comparator.fetch_all_match_ids(
match_counters_devices,
&self.distance_comparator.all_matches,
);

// Check if there are more matches than we fetch
// TODO: In the future we might want to dynamically allocate more memory here
Expand All @@ -793,6 +803,28 @@ impl ServerActor {
}
}

let (partial_match_ids_left, partial_match_ids_right) = if self.return_partial_results {
// Transfer the partial results to the host
let partial_match_counters_left = self
.distance_comparator
.fetch_match_counters(&self.distance_comparator.match_counters_left);
let partial_match_counters_right = self
.distance_comparator
.fetch_match_counters(&self.distance_comparator.match_counters_right);

let partial_results_left = self.distance_comparator.fetch_all_match_ids(
partial_match_counters_left,
&self.distance_comparator.partial_results_left,
);
let partial_results_right = self.distance_comparator.fetch_all_match_ids(
partial_match_counters_right,
&self.distance_comparator.partial_results_right,
);
(partial_results_left, partial_results_right)
} else {
(vec![], vec![])
};

// Write back to in-memory db
let previous_total_db_size = self.current_db_sizes.iter().sum::<usize>();
let n_insertions = insertion_list.iter().map(|x| x.len()).sum::<usize>();
Expand Down Expand Up @@ -854,6 +886,8 @@ impl ServerActor {
metadata: batch.metadata,
matches,
match_ids,
partial_match_ids_left,
partial_match_ids_right,
store_left: query_store_left,
store_right: query_store_right,
deleted_ids: batch.deletion_requests_indices,
Expand Down Expand Up @@ -881,6 +915,20 @@ impl ServerActor {
&self.streams[0],
);

reset_slice(
self.device_manager.devices(),
&self.distance_comparator.match_counters_left,
0,
&self.streams[0],
);

reset_slice(
self.device_manager.devices(),
&self.distance_comparator.match_counters_right,
0,
&self.streams[0],
);

// ---- END RESULT PROCESSING ----
log_timers(events);
let processed_mil_elements_per_second = (self.max_batch_size * previous_total_db_size)
Expand Down
18 changes: 10 additions & 8 deletions iris-mpc-gpu/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,16 @@ pub struct ServerJob {

#[derive(Debug, Clone)]
pub struct ServerJobResult {
pub merged_results: Vec<u32>,
pub request_ids: Vec<String>,
pub metadata: Vec<BatchMetadata>,
pub matches: Vec<bool>,
pub match_ids: Vec<Vec<u32>>,
pub store_left: BatchQueryEntries,
pub store_right: BatchQueryEntries,
pub deleted_ids: Vec<u32>,
pub merged_results: Vec<u32>,
pub request_ids: Vec<String>,
pub metadata: Vec<BatchMetadata>,
pub matches: Vec<bool>,
pub match_ids: Vec<Vec<u32>>,
pub partial_match_ids_left: Vec<Vec<u32>>,
pub partial_match_ids_right: Vec<Vec<u32>>,
pub store_left: BatchQueryEntries,
pub store_right: BatchQueryEntries,
pub deleted_ids: Vec<u32>,
}

enum Eye {
Expand Down
Loading

0 comments on commit 5873808

Please sign in to comment.