From a9698f75699dafb48b376d3e5f2a2bdad26b3d1f Mon Sep 17 00:00:00 2001 From: Keming Date: Tue, 3 Dec 2024 18:02:20 +0800 Subject: [PATCH] fix: divide 64 in packages (#4) Signed-off-by: Keming --- Cargo.lock | 8 ++++---- Cargo.toml | 2 +- Makefile | 7 ++++++- crates/disk/src/cache.rs | 2 +- crates/disk/src/disk.rs | 28 +++++++++++++++++----------- crates/service/src/main.rs | 23 ++++++++++------------- 6 files changed, 39 insertions(+), 31 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5e43519..1b21454 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -704,7 +704,7 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "cli" -version = "0.2.1" +version = "0.2.2" dependencies = [ "argh", "env_logger", @@ -849,7 +849,7 @@ dependencies = [ [[package]] name = "disk" -version = "0.2.1" +version = "0.2.2" dependencies = [ "anyhow", "aws-config", @@ -1881,7 +1881,7 @@ dependencies = [ [[package]] name = "rabitq" -version = "0.2.1" +version = "0.2.2" dependencies = [ "faer", "log", @@ -2225,7 +2225,7 @@ dependencies = [ [[package]] name = "service" -version = "0.2.1" +version = "0.2.2" dependencies = [ "argh", "axum", diff --git a/Cargo.toml b/Cargo.toml index 88956db..7982a1f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ members = ["crates/*"] [workspace.package] -version = "0.2.1" +version = "0.2.2" edition = "2021" description = "A Rust implementation of the RaBitQ vector search algorithm." license = "AGPL-3.0" diff --git a/Makefile b/Makefile index 1aeed6d..7636786 100644 --- a/Makefile +++ b/Makefile @@ -1,12 +1,17 @@ +packages := cli disk service + build: cargo b format: @cargo +nightly fmt + @$(foreach package, $(packages), cargo +nightly fmt --package $(package);) lint: - @cargo +nightly fmt -- --check + @cargo +nightly fmt --check + @$(foreach package, $(packages), cargo +nightly fmt --package $(package) --check;) @cargo clippy -- -D warnings + @$(foreach package, $(packages), cargo clippy --package $(package) -- -D warnings;) test: @cargo test diff --git a/crates/disk/src/cache.rs b/crates/disk/src/cache.rs index 3c8a635..ffee024 100644 --- a/crates/disk/src/cache.rs +++ b/crates/disk/src/cache.rs @@ -79,7 +79,7 @@ impl CachedVector { let s3_client = Arc::new(Client::new(&s3_config)); let num_per_block = BLOCK_BYTE_LIMIT / (4 * (dim + 1)); let total_num = num; - let total_block = (total_num + num_per_block - 1) / num_per_block; + let total_block = total_num.div_ceil(num_per_block); let sqlite_conn = Connection::open(Path::new(&local_path)).expect("failed to open sqlite"); sqlite_conn .execute( diff --git a/crates/disk/src/disk.rs b/crates/disk/src/disk.rs index 95b46b8..0efff2e 100644 --- a/crates/disk/src/disk.rs +++ b/crates/disk/src/disk.rs @@ -18,15 +18,15 @@ use crate::cache::CachedVector; /// Rank with cached raw vectors. #[derive(Debug)] -pub struct CacheReRanker { +pub struct CacheReRanker<'a> { threshold: f32, topk: usize, heap: BinaryHeap<(Ord32, AlwaysEqual)>, - query: Vec, + query: &'a [f32], } -impl CacheReRanker { - fn new(query: Vec, topk: usize) -> Self { +impl<'a> CacheReRanker<'a> { + fn new(query: &'a [f32], topk: usize) -> Self { Self { threshold: f32::MAX, query, @@ -45,7 +45,7 @@ impl CacheReRanker { for &(rough, u) in rough_distances.iter() { if rough < self.threshold { let accurate = cache - .get_query_vec_distance(&self.query, u) + .get_query_vec_distance(self.query, u) .await .expect("failed to get distance"); precise += 1; @@ -142,11 +142,16 @@ impl DiskRaBitQ { /// Query the topk nearest neighbors for the given query asynchronously. pub async fn query(&self, query: Vec, probe: usize, topk: usize) -> Vec<(f32, u32)> { - assert_eq!(self.dim as usize, query.len()); - let y_projected = project(&query, &self.orthogonal.as_ref()); + assert_eq!(self.dim as usize, query.len().div_ceil(64) * 64); + // padding + let mut query_vec = query.to_vec(); + if query.len() < self.dim as usize { + query_vec.extend_from_slice(&vec![0.0; self.dim as usize - query.len()]); + } + + let y_projected = project(&query_vec, &self.orthogonal.as_ref()); let k = self.centroids.shape().1; let mut lists = Vec::with_capacity(k); - let mut residual = vec![0f32; self.dim as usize]; for (i, centroid) in self.centroids.col_iter().enumerate() { let dist = l2_squared_distance( centroid @@ -161,10 +166,11 @@ impl DiskRaBitQ { lists.truncate(length); lists.sort_by(|a, b| a.0.total_cmp(&b.0)); - let mut re_ranker = CacheReRanker::new(query, topk); + let mut re_ranker = CacheReRanker::new(&query_vec, topk); + let mut residual = vec![0f32; self.dim as usize]; + let mut quantized = vec![0u8; (self.dim as usize).div_ceil(64) * 64]; let mut rough_distances = Vec::new(); - let mut quantized = vec![0u8; self.dim as usize]; - let mut binary_vec = vec![0u64; self.dim as usize * THETA_LOG_DIM as usize / 64]; + let mut binary_vec = vec![0u64; self.dim.div_ceil(64) as usize * THETA_LOG_DIM as usize]; for &(dist, i) in lists[..length].iter() { let (lower_bound, upper_bound) = min_max_residual(&mut residual, &y_projected.as_ref(), &self.centroids.col(i)); diff --git a/crates/service/src/main.rs b/crates/service/src/main.rs index 05865ca..d3be624 100644 --- a/crates/service/src/main.rs +++ b/crates/service/src/main.rs @@ -18,18 +18,15 @@ mod args; async fn shutdown_signal() { let mut interrupt = signal(SignalKind::interrupt()).unwrap(); let mut terminate = signal(SignalKind::terminate()).unwrap(); - loop { - tokio::select! { - _ = interrupt.recv() => { - info!("Received interrupt signal"); - break; - } - _ = terminate.recv() => { - info!("Received terminate signal"); - break; - } - }; - } + tokio::select! { + _ = interrupt.recv() => { + info!("Received interrupt signal"); + } + _ = terminate.recv() => { + info!("Received terminate signal"); + } + }; + info!("Shutting down"); } async fn health_check() -> impl IntoResponse { @@ -75,7 +72,7 @@ async fn main() { let config: args::Args = argh::from_env(); let model_path = Path::new(&config.dir); - download_meta_from_s3(&config.bucket, &config.key, &model_path) + download_meta_from_s3(&config.bucket, &config.key, model_path) .await .expect("failed to download meta"); let rabitq =