diff --git a/Cargo.lock b/Cargo.lock index 93ade05..04985c9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -324,6 +324,7 @@ dependencies = [ "portable-atomic", "portable-atomic-util", "rawpointer", + "rayon", ] [[package]] @@ -682,6 +683,7 @@ dependencies = [ "ndarray-stats", "rand 0.9.0-alpha.2", "rand_distr", + "rayon", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 270e975..532ed04 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,11 +12,12 @@ categories = ["algorithms"] [dependencies] anyhow = "1.0.93" -ndarray = "0.16.1" +ndarray = { version = "0.16.1", features = ["rayon"] } ndarray-stats = "0.6.0" rand = "0.9.0-alpha.2" ndarray-rand = "0.15.0" rand_distr = "0.4.3" +rayon = "1.10.0" [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } @@ -31,4 +32,4 @@ path = "src/bin/example.rs" [[bin]] name = "readme_example" -path = "src/bin/readme_example.rs" \ No newline at end of file +path = "src/bin/readme_example.rs" diff --git a/src/pq.rs b/src/pq.rs index 9322bd7..485b5ab 100644 --- a/src/pq.rs +++ b/src/pq.rs @@ -1,6 +1,8 @@ use crate::utils::{determine_code_type, euclidean_distance, kmeans2}; use anyhow::Result; -use ndarray::{s, Array2, Array3}; +use ndarray::parallel::prelude::*; +use ndarray::{s, Array2, Array3, Axis}; +use rayon::prelude::*; #[derive(Debug, Clone, Copy)] pub enum CodeType { @@ -89,25 +91,32 @@ impl PQ { let max_width = dims_width.iter().max().unwrap(); let mut codewords = Array3::::zeros((self.m, self.ks as usize, *max_width)); - for m in 0..self.m { - if self.verbose { - println!( - "# Training the subspace: {} / {}, {} -> {}", - m, - self.m, - self.ds.as_ref().unwrap()[m], - self.ds.as_ref().unwrap()[m + 1] - ); - } + let trained_codewords: Vec<(usize, Array2)> = (0..self.m) + .into_par_iter() + .map(|m| { + if self.verbose { + println!( + "# Training the subspace: {} / {}, {} -> {}", + m, + self.m, + self.ds.as_ref().unwrap()[m], + self.ds.as_ref().unwrap()[m + 1] + ); + } - let vecs_sub = vecs.slice(s![ - .., - self.ds.as_ref().unwrap()[m]..self.ds.as_ref().unwrap()[m + 1] - ]); + let ds_ref = self.ds.as_ref().unwrap(); - let (centroids, _) = kmeans2(&vecs_sub.to_owned(), self.ks, iterations, "points")?; + let vecs_sub = vecs.slice(s![.., ds_ref[m]..ds_ref[m + 1]]); - let subspace_width = self.ds.as_ref().unwrap()[m + 1] - self.ds.as_ref().unwrap()[m]; + let (centroids, _) = kmeans2(&vecs_sub.to_owned(), self.ks, iterations, "points")?; + + Ok((m, centroids)) + }) + .collect::)>>>()?; + + for (m, centroids) in trained_codewords { + let ds_ref = self.ds.as_ref().unwrap(); + let subspace_width = ds_ref[m + 1] - ds_ref[m]; codewords .slice_mut(s![m, .., ..subspace_width]) @@ -141,41 +150,43 @@ impl PQ { .as_ref() .ok_or_else(|| anyhow::anyhow!("Model not trained. Call fit() first"))?; - for m in 0..self.m { - let vecs_sub = vecs.slice(s![.., ds[m]..ds[m + 1]]); - let subspace_width = ds[m + 1] - ds[m]; - let codewords_sub = codewords.slice(s![m, .., ..subspace_width]); - - for (i, vec) in vecs_sub.rows().into_iter().enumerate() { - let mut min_dist = f32::INFINITY; - let mut min_idx = 0; - - for (j, codeword) in codewords_sub.rows().into_iter().enumerate() { - let dist = euclidean_distance(&vec, &codeword); - if dist < min_dist { - min_dist = dist; - min_idx = j; + codes + .outer_iter_mut() + .into_par_iter() + .zip(vecs.outer_iter()) + .for_each(|(mut code_row, vec)| { + for m in 0..self.m { + let subspace = vec.slice(s![ds[m]..ds[m + 1]]); + let subspace_width = ds[m + 1] - ds[m]; + let codewords_sub = codewords.slice(s![m, .., ..subspace_width]); + + let mut min_dist = f32::INFINITY; + let mut min_idx = 0; + + for (j, codeword) in codewords_sub.axis_iter(Axis(0)).enumerate() { + let dist = euclidean_distance(&subspace, &codeword); + if dist < min_dist { + min_dist = dist; + min_idx = j; + } } - } - codes[[i, m]] = min_idx as u32; - } - } + code_row[m] = min_idx as u32; + } + }); - let codes = match self.code_dtype { + match self.code_dtype { CodeType::U8 => { if codes.iter().any(|&x| x > u8::MAX as u32) { anyhow::bail!("Encoded values exceed U8 range"); } - codes } CodeType::U16 => { if codes.iter().any(|&x| x > u16::MAX as u32) { anyhow::bail!("Encoded values exceed U16 range"); } - codes } - CodeType::U32 => codes, + CodeType::U32 => {} }; Ok(codes) @@ -202,23 +213,26 @@ impl PQ { let mut vecs = Array2::::zeros((n_vectors, dim)); - for m in 0..self.m { - let subspace_width = ds[m + 1] - ds[m]; + vecs.outer_iter_mut() + .into_par_iter() + .zip(codes.outer_iter()) + .try_for_each(|(mut vec_row, code_row)| -> Result<(), anyhow::Error> { + for m in 0..self.m { + let code_idx = code_row[m] as usize; + if code_idx >= self.ks as usize { + return Err(anyhow::anyhow!( + "Code value {} exceeds number of clusters {}", + code_idx, + self.ks + )); + } + let subspace_width = ds[m + 1] - ds[m]; + let codeword = codewords.slice(s![m, code_idx, ..subspace_width]); - for (i, code) in codes.column(m).iter().enumerate() { - let code_idx = *code as usize; - if code_idx >= self.ks as usize { - anyhow::bail!( - "Code value {} exceeds number of clusters {}", - code_idx, - self.ks - ); + vec_row.slice_mut(s![ds[m]..ds[m + 1]]).assign(&codeword); } - - vecs.slice_mut(s![i, ds[m]..ds[m + 1]]) - .assign(&codewords.slice(s![m, code_idx, ..subspace_width])); - } - } + Ok(()) + })?; Ok(vecs) } @@ -233,7 +247,6 @@ impl PQ { mod tests { use super::*; use crate::utils::create_random_vectors; - use anyhow::Result; use ndarray::Array2; fn create_dummy_vectors(num_vectors: usize, dimension: usize) -> Array2 { diff --git a/src/utils.rs b/src/utils.rs index 9022af0..4d5c6e8 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,6 +1,7 @@ use crate::pq::CodeType; use anyhow::Result; -use ndarray::{s, Array1, Array2, ArrayView1, Axis}; +use ndarray::parallel::prelude::*; +use ndarray::{Array1, Array2, ArrayView1, Axis}; use ndarray_stats::QuantileExt; use rand::distr::{Distribution, Uniform}; use rand::seq::SliceRandom; @@ -48,7 +49,7 @@ pub fn kmeans2( _ => anyhow::bail!("Unsupported initialization method"), }; - let mut labels = Array1::zeros(n_samples); + let mut labels = Array1::::zeros(n_samples); let mut old_centroids; let mut has_converged = false; @@ -59,39 +60,48 @@ pub fn kmeans2( old_centroids = centroids.clone(); - for (i, sample) in data.rows().into_iter().enumerate() { - let mut min_dist = f32::INFINITY; - let mut min_label = 0; - - for (j, centroid) in centroids.rows().into_iter().enumerate() { - let dist = euclidean_distance(&sample, ¢roid); - if dist < min_dist { - min_dist = dist; - min_label = j; + let labels_vec: Vec = data + .outer_iter() + .into_par_iter() + .map(|sample| { + let mut min_dist = f32::INFINITY; + let mut min_label = 0; + + for (j, centroid) in centroids.outer_iter().enumerate() { + let dist = euclidean_distance(&sample, ¢roid); + if dist < min_dist { + min_dist = dist; + min_label = j; + } } - } - labels[i] = min_label; - } + min_label + }) + .collect(); - let mut new_centroids = Array2::zeros((k, n_features)); - let mut counts = vec![0usize; k]; - - for (i, sample) in data.rows().into_iter().enumerate() { - let label = labels[i]; - new_centroids.row_mut(label).add_assign(&sample); - counts[label] += 1; - } + labels = Array1::from(labels_vec); - for (i, count) in counts.iter().enumerate() { - if *count > 0 { - new_centroids.row_mut(i).mapv_inplace(|x| x / *count as f32); - } else { - let random_idx = rand::thread_rng().gen_range(0..n_samples); - new_centroids.row_mut(i).assign(&data.row(random_idx)); - } - } - - centroids = new_centroids; + let mut counts = vec![0usize; k]; + let mut sums = vec![Array1::::zeros(n_features); k]; + + data.outer_iter() + .zip(labels.iter()) + .for_each(|(sample, &label)| { + counts[label] += 1; + sums[label].add_assign(&sample); + }); + + centroids + .outer_iter_mut() + .into_par_iter() + .enumerate() + .for_each(|(i, mut centroid_row)| { + if counts[i] > 0 { + centroid_row.assign(&(sums[i].clone() / counts[i] as f32)); + } else { + let random_idx = rand::thread_rng().gen_range(0..n_samples); + centroid_row.assign(&data.row(random_idx)); + } + }); has_converged = check_convergence(¢roids, &old_centroids); }