Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added clippy to CI #74

Merged
merged 6 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/build-and-push.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ jobs:
with:
context: .
push: true
tags:
- ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ github.event.release.tag_name }}
- ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ github.sha }}
tags: |
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ github.event.release.tag_name }},
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ github.sha }}
platforms: linux/amd64
cache-from: type=gha
cache-to: type=gha,mode=max
19 changes: 19 additions & 0 deletions .github/workflows/lint-clippy.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
name: Lint Clippy

on:
push:

jobs:
lint-clippy:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Rust nightly
run: rustup toolchain install nightly
- name: Set Rust nightly as default
run: rustup default nightly
- name: Install Rust clippy for checking clippy errors
run: rustup component add clippy
- name: Run Rust Clippy
run: cargo clippy --all-targets --all-features -- -D warnings --no-deps
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name: Lint
name: Lint Rustfmt

on:
push:

jobs:
# TODO: change this to lint-rustfmt once we've updated required jobs
lint:
runs-on: ubuntu-latest
steps:
Expand All @@ -16,4 +17,4 @@ jobs:
- name: Install Rustfmt for formatting
run: rustup component add rustfmt
- name: Run Rustfmt
run: cargo fmt -- --check
run: cargo fmt -- --check
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ url = "2"
uuid = { version = "1.8.0", features = ["v4"] }
tracing = "0.1.40"
dotenvy = "0.15.7"
static_assertions = "1.1.0"

[dev-dependencies]
criterion = "0.5"
Expand Down
4 changes: 2 additions & 2 deletions src/bin/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ async fn main() -> eyre::Result<()> {

if expected_result.is_none() {
// New insertion
assert_eq!(result.is_match, false);
assert!(!result.is_match);
let request = thread_requests
.lock()
.await
Expand All @@ -130,7 +130,7 @@ async fn main() -> eyre::Result<()> {
} else {
// Existing entry
println!("Expected: {:?} Got: {:?}", expected_result, result.db_index);
assert_eq!(result.is_match, true);
assert!(result.is_match);
assert_eq!(result.db_index, expected_result.unwrap());
}

Expand Down
74 changes: 34 additions & 40 deletions src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ use gpu_iris_mpc::{
use lazy_static::lazy_static;
use rand::{rngs::StdRng, SeedableRng};
use ring::hkdf::{Algorithm, Okm, Salt, HKDF_SHA256};
use static_assertions::const_assert;
use std::{
fs::metadata,
mem,
ops::IndexMut,
path::PathBuf,
sync::{atomic::AtomicUsize, Arc, Mutex},
time::{Duration, Instant},
Expand Down Expand Up @@ -259,13 +261,7 @@ async fn receive_batch(
}

fn prepare_query_shares(shares: Vec<GaloisRingIrisCodeShare>) -> Vec<Vec<u8>> {
preprocess_query(
&shares
.into_iter()
.map(|e| e.coefs)
.flatten()
.collect::<Vec<_>>(),
)
preprocess_query(&shares.into_iter().flat_map(|e| e.coefs).collect::<Vec<_>>())
}

#[allow(clippy::type_complexity)]
Expand Down Expand Up @@ -317,7 +313,7 @@ fn open(
}
cudarc::nccl::result::group_end().unwrap();

distance_comparator.open_results(&a, &b, &c, results_ptrs, db_sizes, &streams);
distance_comparator.open_results(&a, &b, &c, results_ptrs, db_sizes, streams);
}

fn get_merged_results(host_results: &[Vec<u32>], n_devices: usize) -> Vec<u32> {
Expand Down Expand Up @@ -383,7 +379,7 @@ fn device_ptrs_to_shares<T>(
let b = device_ptrs_to_slices(b, lens, devs);

a.into_iter()
.zip(b.into_iter())
.zip(b)
.map(|(a, b)| ChunkShare::new(a, b))
.collect::<Vec<_>>()
}
Expand Down Expand Up @@ -551,28 +547,26 @@ async fn main() -> eyre::Result<()> {
let codes_db = db
.db
.iter()
.map(|iris| {
.flat_map(|iris| {
GaloisRingIrisCodeShare::encode_iris_code(
&iris.code,
&iris.mask,
&mut StdRng::seed_from_u64(RNG_SEED),
)[party_id]
.coefs
})
.flatten()
.collect::<Vec<_>>();

let masks_db = db
.db
.iter()
.map(|iris| {
.flat_map(|iris| {
GaloisRingIrisCodeShare::encode_mask_code(
&iris.mask,
&mut StdRng::seed_from_u64(RNG_SEED),
)[party_id]
.coefs
})
.flatten()
.collect::<Vec<_>>();

write_mmap_file(&code_db_path, &codes_db)?;
Expand Down Expand Up @@ -705,7 +699,7 @@ async fn main() -> eyre::Result<()> {
let mut next_exchange_event = device_manager.create_events();
let mut timer_events = vec![];
let start_timer = device_manager.create_events();
let mut end_timer = device_manager.create_events();
let end_timer = device_manager.create_events();

let current_db_size: Vec<usize> =
vec![DB_SIZE / device_manager.device_count(); device_manager.device_count()];
Expand Down Expand Up @@ -993,10 +987,11 @@ async fn main() -> eyre::Result<()> {
// before running that code.
// - End events are re-used in each thread, but we only end one thread at a
// time.
assert!(MAX_BATCHES_BEFORE_REUSE > MAX_CONCURRENT_BATCHES);
const_assert!(MAX_BATCHES_BEFORE_REUSE > MAX_CONCURRENT_BATCHES);

// into_iter() makes the Rust compiler check that the streams are not re-used.
let mut thread_streams = request_streams
.into_iter()
.iter()
.map(|s| unsafe { s.stream.as_mut().unwrap() })
.collect::<Vec<_>>();
// The compiler can't tell that we wait for the previous batch before re-using
Expand All @@ -1016,9 +1011,9 @@ async fn main() -> eyre::Result<()> {
.map(Arc::clone)
.collect::<Vec<_>>();
let db_sizes_batch = query_db_size.clone();
let thread_request_results_batch = device_ptrs(&request_results_batch);
let thread_request_results = device_ptrs(&request_results);
let thread_request_final_results = device_ptrs(&request_final_results);
let thread_request_results_batch = device_ptrs(request_results_batch);
let thread_request_results = device_ptrs(request_results);
let thread_request_final_results = device_ptrs(request_final_results);

// Batch phase 1 results
let thread_code_results_batch = device_ptrs(&batch_codes_engine.results);
Expand Down Expand Up @@ -1065,10 +1060,8 @@ async fn main() -> eyre::Result<()> {
// batches),
// - CUevent: thread_current_stream_event, thread_end_timer,
// - Comm: phase2, phase2_batch.
if previous_thread_handle.is_some() {
runtime::Handle::current()
.block_on(previous_thread_handle.unwrap())
.unwrap();
if let Some(phandle) = previous_thread_handle {
runtime::Handle::current().block_on(phandle).unwrap();
}
let thread_devs = thread_device_manager.devices();
let mut thread_phase2_batch = thread_phase2_batch.lock().unwrap();
Expand Down Expand Up @@ -1103,30 +1096,30 @@ async fn main() -> eyre::Result<()> {
&thread_code_results_batch,
&thread_code_results_peer_batch,
&result_sizes_batch,
&thread_devs,
thread_devs,
);
let mut mask_dots_batch: Vec<ChunkShare<u16>> = device_ptrs_to_shares(
&thread_mask_results_batch,
&thread_mask_results_peer_batch,
&result_sizes_batch,
&thread_devs,
thread_devs,
);

let mut code_dots: Vec<ChunkShare<u16>> = device_ptrs_to_shares(
&thread_code_results,
&thread_code_results_peer,
&result_sizes,
&thread_devs,
thread_devs,
);
let mut mask_dots: Vec<ChunkShare<u16>> = device_ptrs_to_shares(
&thread_mask_results,
&thread_mask_results_peer,
&result_sizes,
&thread_devs,
thread_devs,
);

// TODO: use phase 1 streams here
let mut phase2_streams = thread_phase2
let phase2_streams = thread_phase2
.get_devices()
.iter()
.map(|d| d.fork_default_stream().unwrap())
Expand All @@ -1144,7 +1137,7 @@ async fn main() -> eyre::Result<()> {
let mut thread_request_results_slice_batch: Vec<CudaSlice<u32>> = device_ptrs_to_slices(
&thread_request_results_batch,
&vec![QUERIES; thread_devs.len()],
&thread_devs,
thread_devs,
);

// Iterate over a list of tracing payloads, and create logs with mappings to
Expand Down Expand Up @@ -1179,7 +1172,7 @@ async fn main() -> eyre::Result<()> {
let mut thread_request_results_slice: Vec<CudaSlice<u32>> = device_ptrs_to_slices(
&thread_request_results,
&vec![QUERIES; thread_devs.len()],
&thread_devs,
thread_devs,
);

let chunk_size = thread_phase2.chunk_size();
Expand Down Expand Up @@ -1316,13 +1309,14 @@ async fn main() -> eyre::Result<()> {
// CudaDevice, which makes sure they aren't dropped.
unsafe {
event::record(
*&mut thread_current_stream_event[i],
*&mut thread_streams[i],
*thread_current_stream_event.index_mut(i),
*thread_streams.index_mut(i),
)
.unwrap();

// DEBUG: emit event to measure time for e2e process
event::record(*&mut thread_end_timer[i], *&mut thread_streams[i]).unwrap();
event::record(*thread_end_timer.index_mut(i), *thread_streams.index_mut(i))
.unwrap();
}
}

Expand All @@ -1333,22 +1327,22 @@ async fn main() -> eyre::Result<()> {

// Reset the results buffers for reuse
reset_results(
&thread_device_manager.devices(),
thread_device_manager.devices(),
&thread_request_results,
&RESULTS_INIT_HOST,
&mut phase2_streams,
&phase2_streams,
);
reset_results(
&thread_device_manager.devices(),
thread_device_manager.devices(),
&thread_request_results_batch,
&RESULTS_INIT_HOST,
&mut phase2_streams,
&phase2_streams,
);
reset_results(
&thread_device_manager.devices(),
thread_device_manager.devices(),
&thread_request_final_results,
&FINAL_RESULTS_INIT_HOST,
&mut phase2_streams,
&phase2_streams,
);

// Make sure to not call `Drop` on those
Expand Down Expand Up @@ -1407,7 +1401,7 @@ async fn main() -> eyre::Result<()> {
for i in 0..device_manager.device_count() {
unsafe {
device_manager.device(i).bind_to_thread().unwrap();
let total_time = elapsed(start_timer[i], *&mut end_timer[i]).unwrap();
let total_time = elapsed(start_timer[i], end_timer[i]).unwrap();
println!("Total time: {:?}", total_time);
}
}
Expand Down
Loading