Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expose partial matches #534

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions .github/workflows/temp-build-and-push.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
name: Branch - Build and push docker image

on:
push:

concurrency:
group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}'
cancel-in-progress: true

env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}

jobs:
docker:
runs-on:
labels: ubuntu-22.04-64core
permissions:
packages: write
contents: read
attestations: write
id-token: write
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to the Container registry
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build and Push
uses: docker/build-push-action@v6
with:
context: .
push: true
tags: |
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ github.sha }}
platforms: linux/amd64
cache-from: type=gha
cache-to: type=gha,mode=max
2 changes: 1 addition & 1 deletion deploy/stage/common-values-iris-mpc.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
image: "ghcr.io/worldcoin/iris-mpc:v0.8.31"
image: "ghcr.io/worldcoin/iris-mpc:ee89ea0b0f1358d69ffbbf730b5ae427f7361153"

environment: stage
replicaCount: 1
Expand Down
3 changes: 3 additions & 0 deletions iris-mpc-common/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ pub struct Config {
#[serde(default)]
pub fake_db_size: usize,

#[serde(default)]
pub return_partial_results: bool,

#[serde(default)]
pub disable_persistence: bool,
}
Expand Down
16 changes: 11 additions & 5 deletions iris-mpc-common/src/helpers/smpc_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,11 +302,13 @@ impl UniquenessRequest {

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct UniquenessResult {
pub node_id: usize,
pub serial_id: Option<u32>,
pub is_match: bool,
pub signup_id: String,
pub matched_serial_ids: Option<Vec<u32>>,
pub node_id: usize,
pub serial_id: Option<u32>,
pub is_match: bool,
pub signup_id: String,
pub matched_serial_ids: Option<Vec<u32>>,
pub matched_serial_ids_left: Option<Vec<u32>>,
pub matched_serial_ids_right: Option<Vec<u32>>,
}

impl UniquenessResult {
Expand All @@ -316,13 +318,17 @@ impl UniquenessResult {
is_match: bool,
signup_id: String,
matched_serial_ids: Option<Vec<u32>>,
matched_serial_ids_left: Option<Vec<u32>>,
matched_serial_ids_right: Option<Vec<u32>>,
) -> Self {
Self {
node_id,
serial_id,
is_match,
signup_id,
matched_serial_ids,
matched_serial_ids_left,
matched_serial_ids_right,
}
}
}
Expand Down
46 changes: 39 additions & 7 deletions iris-mpc-gpu/src/dot/distance_comparator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,26 @@ pub struct DistanceComparator {
pub final_results_init_host: Vec<u32>,
pub match_counters: Vec<CudaSlice<u32>>,
pub all_matches: Vec<CudaSlice<u32>>,
pub match_counters_left: Vec<CudaSlice<u32>>,
pub match_counters_right: Vec<CudaSlice<u32>>,
pub partial_results_left: Vec<CudaSlice<u32>>,
pub partial_results_right: Vec<CudaSlice<u32>>,
}

impl DistanceComparator {
pub fn init(query_length: usize, device_manager: Arc<DeviceManager>) -> Self {
let ptx = compile_ptx(PTX_SRC).unwrap();
let mut open_kernels = Vec::new();
let mut open_kernels: Vec<CudaFunction> = Vec::new();
let mut merge_db_kernels = Vec::new();
let mut merge_batch_kernels = Vec::new();
let mut opened_results = vec![];
let mut final_results = vec![];
let mut match_counters: Vec<CudaSlice<u32>> = vec![];
let mut all_matches: Vec<CudaSlice<u32>> = vec![];
let mut match_counters = vec![];
let mut match_counters_left = vec![];
let mut match_counters_right = vec![];
let mut all_matches = vec![];
let mut partial_results_left = vec![];
let mut partial_results_right = vec![];

let devices_count = device_manager.device_count();

Expand All @@ -63,11 +71,23 @@ impl DistanceComparator {
opened_results.push(device.htod_copy(results_init_host.clone()).unwrap());
final_results.push(device.htod_copy(final_results_init_host.clone()).unwrap());
match_counters.push(device.alloc_zeros(query_length / ROTATIONS).unwrap());
match_counters_left.push(device.alloc_zeros(query_length / ROTATIONS).unwrap());
match_counters_right.push(device.alloc_zeros(query_length / ROTATIONS).unwrap());
all_matches.push(
device
.alloc_zeros(ALL_MATCHES_LEN * query_length / ROTATIONS)
.unwrap(),
);
partial_results_left.push(
device
.alloc_zeros(ALL_MATCHES_LEN * query_length / ROTATIONS)
.unwrap(),
);
partial_results_right.push(
device
.alloc_zeros(ALL_MATCHES_LEN * query_length / ROTATIONS)
.unwrap(),
);

open_kernels.push(open_results_function);
merge_db_kernels.push(merge_db_results_function);
Expand All @@ -85,7 +105,11 @@ impl DistanceComparator {
results_init_host,
final_results_init_host,
match_counters,
match_counters_left,
match_counters_right,
all_matches,
partial_results_left,
partial_results_right,
}
}

Expand Down Expand Up @@ -213,6 +237,10 @@ impl DistanceComparator {
num_elements as u64,
&self.match_counters[i],
&self.all_matches[i],
&self.match_counters_left[i],
&self.match_counters_right[i],
&self.partial_results_left[i],
&self.partial_results_right[i],
),
)
.unwrap();
Expand All @@ -233,26 +261,30 @@ impl DistanceComparator {
results
}

pub fn fetch_match_counters(&self) -> Vec<Vec<u32>> {
pub fn fetch_match_counters(&self, counters: &[CudaSlice<u32>]) -> Vec<Vec<u32>> {
let mut results = vec![];
for i in 0..self.device_manager.device_count() {
results.push(
self.device_manager
.device(i)
.dtoh_sync_copy(&self.match_counters[i])
.dtoh_sync_copy(&counters[i])
.unwrap(),
);
}
results
}

pub fn fetch_all_match_ids(&self, match_counters: Vec<Vec<u32>>) -> Vec<Vec<u32>> {
pub fn fetch_all_match_ids(
&self,
match_counters: Vec<Vec<u32>>,
matches: &[CudaSlice<u32>],
) -> Vec<Vec<u32>> {
let mut results = vec![];
for i in 0..self.device_manager.device_count() {
results.push(
self.device_manager
.device(i)
.dtoh_sync_copy(&self.all_matches[i])
.dtoh_sync_copy(&matches[i])
.unwrap(),
);
}
Expand Down
18 changes: 16 additions & 2 deletions iris-mpc-gpu/src/dot/kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ extern "C" __global__ void openResults(unsigned long long *result1, unsigned lon
}
}

extern "C" __global__ void mergeDbResults(unsigned long long *matchResultsLeft, unsigned long long *matchResultsRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements, unsigned int *matchCounter, unsigned int *allMatches)
extern "C" __global__ void mergeDbResults(unsigned long long *matchResultsLeft, unsigned long long *matchResultsRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements, unsigned int *matchCounter, unsigned int *allMatches, unsigned int *matchCounterLeft, unsigned int *matchCounterRight, unsigned int *partialResultsLeft, unsigned int *partialResultsRight)
{
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numElements)
Expand All @@ -67,6 +67,20 @@ extern "C" __global__ void mergeDbResults(unsigned long long *matchResultsLeft,
if (queryIdx >= queryLength || dbIdx >= dbLength)
continue;

// Check for partial results (only used for debugging)
if (matchLeft)
{
unsigned int queryMatchCounter = atomicAdd(&matchCounterLeft[queryIdx], 1);
if (queryMatchCounter < MAX_MATCHES_LEN)
partialResultsLeft[MAX_MATCHES_LEN * queryIdx + queryMatchCounter] = dbIdx;
}
if (matchRight)
{
unsigned int queryMatchCounter = atomicAdd(&matchCounterRight[queryIdx], 1);
if (queryMatchCounter < MAX_MATCHES_LEN)
partialResultsRight[MAX_MATCHES_LEN * queryIdx + queryMatchCounter] = dbIdx;
}

// Current *AND* policy: only match, if both eyes match
if (matchLeft && matchRight)
{
Expand All @@ -79,7 +93,7 @@ extern "C" __global__ void mergeDbResults(unsigned long long *matchResultsLeft,
}
}

extern "C" __global__ void mergeBatchResults(unsigned long long *matchResultsSelfLeft, unsigned long long *matchResultsSelfRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements, unsigned int *__matchCounter, unsigned int *__allMatches)
extern "C" __global__ void mergeBatchResults(unsigned long long *matchResultsSelfLeft, unsigned long long *matchResultsSelfRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements, unsigned int *__matchCounter, unsigned int *__allMatches, unsigned int *__matchCounterLeft, unsigned int *__matchCounterRight, unsigned int *__partialResultsLeft, unsigned int *__partialResultsRight)
{
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numElements)
Expand Down
Loading
Loading