feat: parallelize operations
micheleriva committed Nov 26, 2024
1 parent 9295764; commit cbd4284
Showing 4 changed files with 115 additions and 89 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

5 changes: 3 additions & 2 deletions Cargo.toml
@@ -12,11 +12,12 @@ categories = ["algorithms"]
 
 [dependencies]
 anyhow = "1.0.93"
-ndarray = "0.16.1"
+ndarray = { version = "0.16.1", features = ["rayon"] }
 ndarray-stats = "0.6.0"
 rand = "0.9.0-alpha.2"
 ndarray-rand = "0.15.0"
 rand_distr = "0.4.3"
+rayon = "1.10.0"
 
 [dev-dependencies]
 criterion = { version = "0.5", features = ["html_reports"] }
@@ -31,4 +32,4 @@ path = "src/bin/example.rs"
 
 [[bin]]
 name = "readme_example"
-path = "src/bin/readme_example.rs"
\ No newline at end of file
+path = "src/bin/readme_example.rs"
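
These two dependency changes are what every parallel loop in this commit builds on: the new rayon crate provides the thread pool and the parallel iterator traits, and the "rayon" feature on ndarray makes the crate's axis iterators convertible into Rayon parallel iterators. A minimal sketch of that mechanism (illustrative code, not from this repository; row_norms is a made-up example):

// Minimal sketch (not part of this commit): the "rayon" feature on
// ndarray lets an axis iterator fan out over Rayon's thread pool.
use ndarray::parallel::prelude::*;
use ndarray::{Array1, Array2, Axis};

fn row_norms(data: &Array2<f32>) -> Array1<f32> {
    let norms: Vec<f32> = data
        .axis_iter(Axis(0))   // one read-only view per row
        .into_par_iter()      // distribute the rows across worker threads
        .map(|row| row.dot(&row).sqrt())
        .collect();           // an indexed parallel iterator preserves row order
    Array1::from(norms)
}

fn main() {
    let data = Array2::<f32>::from_elem((4, 3), 2.0);
    println!("{:?}", row_norms(&data)); // each norm is sqrt(3 * 4) ~ 3.46
}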
123 changes: 68 additions & 55 deletions src/pq.rs
@@ -1,6 +1,8 @@
 use crate::utils::{determine_code_type, euclidean_distance, kmeans2};
 use anyhow::Result;
-use ndarray::{s, Array2, Array3};
+use ndarray::parallel::prelude::*;
+use ndarray::{s, Array2, Array3, Axis};
+use rayon::prelude::*;
 
 #[derive(Debug, Clone, Copy)]
 pub enum CodeType {
@@ -89,25 +91,32 @@ impl PQ {
         let max_width = dims_width.iter().max().unwrap();
         let mut codewords = Array3::<f32>::zeros((self.m, self.ks as usize, *max_width));
 
-        for m in 0..self.m {
-            if self.verbose {
-                println!(
-                    "# Training the subspace: {} / {}, {} -> {}",
-                    m,
-                    self.m,
-                    self.ds.as_ref().unwrap()[m],
-                    self.ds.as_ref().unwrap()[m + 1]
-                );
-            }
+        let trained_codewords: Vec<(usize, Array2<f32>)> = (0..self.m)
+            .into_par_iter()
+            .map(|m| {
+                if self.verbose {
+                    println!(
+                        "# Training the subspace: {} / {}, {} -> {}",
+                        m,
+                        self.m,
+                        self.ds.as_ref().unwrap()[m],
+                        self.ds.as_ref().unwrap()[m + 1]
+                    );
+                }
 
-            let vecs_sub = vecs.slice(s![
-                ..,
-                self.ds.as_ref().unwrap()[m]..self.ds.as_ref().unwrap()[m + 1]
-            ]);
+                let ds_ref = self.ds.as_ref().unwrap();
 
-            let (centroids, _) = kmeans2(&vecs_sub.to_owned(), self.ks, iterations, "points")?;
+                let vecs_sub = vecs.slice(s![.., ds_ref[m]..ds_ref[m + 1]]);
 
-            let subspace_width = self.ds.as_ref().unwrap()[m + 1] - self.ds.as_ref().unwrap()[m];
+                let (centroids, _) = kmeans2(&vecs_sub.to_owned(), self.ks, iterations, "points")?;
+
+                Ok((m, centroids))
+            })
+            .collect::<Result<Vec<(usize, Array2<f32>)>>>()?;
+
+        for (m, centroids) in trained_codewords {
+            let ds_ref = self.ds.as_ref().unwrap();
+            let subspace_width = ds_ref[m + 1] - ds_ref[m];
 
             codewords
                 .slice_mut(s![m, .., ..subspace_width])
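
The fit() change above splits training into two phases: a parallel, fallible map over the m subspaces whose results are collected into a single Result (so the first kmeans2 failure aborts the whole phase), followed by a serial write-back into the shared codewords buffer, which avoids concurrent mutable access to it. A condensed sketch of just that pattern, with a hypothetical train_subspace standing in for the real kmeans2 call:

// Sketch of the fit() pattern above; train_subspace is a hypothetical
// stand-in for the real kmeans2 call.
use anyhow::Result;
use ndarray::Array2;
use rayon::prelude::*;

fn train_subspace(m: usize) -> Result<Array2<f32>> {
    Ok(Array2::from_elem((4, 2), m as f32)) // placeholder "centroids"
}

fn train_all(m_subspaces: usize) -> Result<Vec<(usize, Array2<f32>)>> {
    // collect::<Result<Vec<_>>>() short-circuits on the first Err,
    // which is what lets the `?` inside the parallel closure work:
    // each task returns a Result, and Rayon folds them into one.
    (0..m_subspaces)
        .into_par_iter()
        .map(|m| Ok((m, train_subspace(m)?)))
        .collect()
}

fn main() -> Result<()> {
    for (m, centroids) in train_all(4)? {
        println!("subspace {}: {:?}", m, centroids.dim()); // serial write-back phase
    }
    Ok(())
}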
@@ -141,41 +150,43 @@ impl PQ {
             .as_ref()
             .ok_or_else(|| anyhow::anyhow!("Model not trained. Call fit() first"))?;
 
-        for m in 0..self.m {
-            let vecs_sub = vecs.slice(s![.., ds[m]..ds[m + 1]]);
-            let subspace_width = ds[m + 1] - ds[m];
-            let codewords_sub = codewords.slice(s![m, .., ..subspace_width]);
-
-            for (i, vec) in vecs_sub.rows().into_iter().enumerate() {
-                let mut min_dist = f32::INFINITY;
-                let mut min_idx = 0;
-
-                for (j, codeword) in codewords_sub.rows().into_iter().enumerate() {
-                    let dist = euclidean_distance(&vec, &codeword);
-                    if dist < min_dist {
-                        min_dist = dist;
-                        min_idx = j;
-                    }
-                }
-
-                codes[[i, m]] = min_idx as u32;
-            }
-        }
+        codes
+            .outer_iter_mut()
+            .into_par_iter()
+            .zip(vecs.outer_iter())
+            .for_each(|(mut code_row, vec)| {
+                for m in 0..self.m {
+                    let subspace = vec.slice(s![ds[m]..ds[m + 1]]);
+                    let subspace_width = ds[m + 1] - ds[m];
+                    let codewords_sub = codewords.slice(s![m, .., ..subspace_width]);
+
+                    let mut min_dist = f32::INFINITY;
+                    let mut min_idx = 0;
+
+                    for (j, codeword) in codewords_sub.axis_iter(Axis(0)).enumerate() {
+                        let dist = euclidean_distance(&subspace, &codeword);
+                        if dist < min_dist {
+                            min_dist = dist;
+                            min_idx = j;
+                        }
+                    }
+
+                    code_row[m] = min_idx as u32;
+                }
+            });
 
-        let codes = match self.code_dtype {
+        match self.code_dtype {
             CodeType::U8 => {
                 if codes.iter().any(|&x| x > u8::MAX as u32) {
                     anyhow::bail!("Encoded values exceed U8 range");
                 }
-                codes
             }
             CodeType::U16 => {
                 if codes.iter().any(|&x| x > u16::MAX as u32) {
                     anyhow::bail!("Encoded values exceed U16 range");
                 }
-                codes
             }
-            CodeType::U32 => codes,
+            CodeType::U32 => {}
         };
 
         Ok(codes)
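
In the new encode(), zip pairs each mutable row of codes with the matching read-only row of vecs, so every Rayon task owns exactly one output row and no synchronization is needed. A stripped-down sketch of the pairing (the row sum is a stand-in for the real nearest-centroid search):

// Sketch of the encode() pattern above: zip mutable output rows with
// read-only input rows; the row sum stands in for the argmin search.
use ndarray::parallel::prelude::*;
use ndarray::Array2;

fn main() {
    let vecs = Array2::<f32>::from_elem((8, 4), 1.0);
    let mut codes = Array2::<u32>::zeros((8, 2));

    codes
        .outer_iter_mut()       // exclusive ArrayViewMut1 per output row
        .into_par_iter()
        .zip(vecs.outer_iter()) // matching read-only row of vecs
        .for_each(|(mut code_row, vec)| {
            code_row.fill(vec.sum() as u32); // stand-in for nearest-centroid
        });

    assert!(codes.iter().all(|&c| c == 4));
}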
@@ -202,23 +213,26 @@ impl PQ {
 
         let mut vecs = Array2::<f32>::zeros((n_vectors, dim));
 
-        for m in 0..self.m {
-            let subspace_width = ds[m + 1] - ds[m];
-
-            for (i, code) in codes.column(m).iter().enumerate() {
-                let code_idx = *code as usize;
-                if code_idx >= self.ks as usize {
-                    anyhow::bail!(
-                        "Code value {} exceeds number of clusters {}",
-                        code_idx,
-                        self.ks
-                    );
-                }
-
-                vecs.slice_mut(s![i, ds[m]..ds[m + 1]])
-                    .assign(&codewords.slice(s![m, code_idx, ..subspace_width]));
-            }
-        }
+        vecs.outer_iter_mut()
+            .into_par_iter()
+            .zip(codes.outer_iter())
+            .try_for_each(|(mut vec_row, code_row)| -> Result<(), anyhow::Error> {
+                for m in 0..self.m {
+                    let code_idx = code_row[m] as usize;
+                    if code_idx >= self.ks as usize {
+                        return Err(anyhow::anyhow!(
+                            "Code value {} exceeds number of clusters {}",
+                            code_idx,
+                            self.ks
+                        ));
+                    }
+                    let subspace_width = ds[m + 1] - ds[m];
+                    let codeword = codewords.slice(s![m, code_idx, ..subspace_width]);
+
+                    vec_row.slice_mut(s![ds[m]..ds[m + 1]]).assign(&codeword);
+                }
+                Ok(())
+            })?;
 
         Ok(vecs)
     }
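
decode() uses the same row pairing, but its body can fail, so it runs under try_for_each, which short-circuits the pool and propagates the first Err through the trailing `?`. A minimal sketch of that error path (validate and its arguments are illustrative):

// Sketch of the decode() error path above: try_for_each surfaces the
// first Err produced by any worker thread.
use anyhow::Result;
use ndarray::parallel::prelude::*;
use ndarray::Array2;

fn validate(codes: &Array2<u32>, ks: u32) -> Result<()> {
    codes
        .outer_iter()
        .into_par_iter()
        .try_for_each(|code_row| -> Result<()> {
            for &c in code_row.iter() {
                if c >= ks {
                    anyhow::bail!("Code value {} exceeds number of clusters {}", c, ks);
                }
            }
            Ok(())
        })
}

fn main() -> Result<()> {
    let codes = Array2::<u32>::zeros((8, 2));
    validate(&codes, 256)?; // all zeros, so this passes
    Ok(())
}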
@@ -233,7 +247,6 @@ impl PQ {
 mod tests {
     use super::*;
     use crate::utils::create_random_vectors;
-    use anyhow::Result;
     use ndarray::Array2;
 
     fn create_dummy_vectors(num_vectors: usize, dimension: usize) -> Array2<f32> {
74 changes: 42 additions & 32 deletions src/utils.rs
@@ -1,6 +1,7 @@
 use crate::pq::CodeType;
 use anyhow::Result;
-use ndarray::{s, Array1, Array2, ArrayView1, Axis};
+use ndarray::parallel::prelude::*;
+use ndarray::{Array1, Array2, ArrayView1, Axis};
 use ndarray_stats::QuantileExt;
 use rand::distr::{Distribution, Uniform};
 use rand::seq::SliceRandom;
@@ -48,7 +49,7 @@ pub fn kmeans2(
         _ => anyhow::bail!("Unsupported initialization method"),
     };
 
-    let mut labels = Array1::zeros(n_samples);
+    let mut labels = Array1::<usize>::zeros(n_samples);
     let mut old_centroids;
    let mut has_converged = false;

@@ -59,39 +60,48 @@ pub fn kmeans2(
 
         old_centroids = centroids.clone();
 
-        for (i, sample) in data.rows().into_iter().enumerate() {
-            let mut min_dist = f32::INFINITY;
-            let mut min_label = 0;
-
-            for (j, centroid) in centroids.rows().into_iter().enumerate() {
-                let dist = euclidean_distance(&sample, &centroid);
-                if dist < min_dist {
-                    min_dist = dist;
-                    min_label = j;
-                }
-            }
-            labels[i] = min_label;
-        }
+        let labels_vec: Vec<usize> = data
+            .outer_iter()
+            .into_par_iter()
+            .map(|sample| {
+                let mut min_dist = f32::INFINITY;
+                let mut min_label = 0;
+
+                for (j, centroid) in centroids.outer_iter().enumerate() {
+                    let dist = euclidean_distance(&sample, &centroid);
+                    if dist < min_dist {
+                        min_dist = dist;
+                        min_label = j;
+                    }
+                }
+                min_label
+            })
+            .collect();
 
-        let mut new_centroids = Array2::zeros((k, n_features));
-        let mut counts = vec![0usize; k];
+        labels = Array1::from(labels_vec);
 
-        for (i, sample) in data.rows().into_iter().enumerate() {
-            let label = labels[i];
-            new_centroids.row_mut(label).add_assign(&sample);
-            counts[label] += 1;
-        }
+        let mut counts = vec![0usize; k];
+        let mut sums = vec![Array1::<f32>::zeros(n_features); k];
 
-        for (i, count) in counts.iter().enumerate() {
-            if *count > 0 {
-                new_centroids.row_mut(i).mapv_inplace(|x| x / *count as f32);
-            } else {
-                let random_idx = rand::thread_rng().gen_range(0..n_samples);
-                new_centroids.row_mut(i).assign(&data.row(random_idx));
-            }
-        }
+        data.outer_iter()
+            .zip(labels.iter())
+            .for_each(|(sample, &label)| {
+                counts[label] += 1;
+                sums[label].add_assign(&sample);
+            });
 
-        centroids = new_centroids;
+        centroids
+            .outer_iter_mut()
+            .into_par_iter()
+            .enumerate()
+            .for_each(|(i, mut centroid_row)| {
+                if counts[i] > 0 {
+                    centroid_row.assign(&(sums[i].clone() / counts[i] as f32));
+                } else {
+                    let random_idx = rand::thread_rng().gen_range(0..n_samples);
+                    centroid_row.assign(&data.row(random_idx));
+                }
+            });
 
         has_converged = check_convergence(&centroids, &old_centroids);
     }
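
The assignment step of kmeans2 is embarrassingly parallel: each sample's nearest centroid reads only shared immutable data, so the labels come from a parallel map that is gathered back into an Array1. The update step, by contrast, accumulates counts and sums serially to avoid synchronized writes, and only the final per-centroid normalization runs in parallel. A self-contained sketch of the assignment step alone (assign_labels is illustrative):

// Sketch of the parallel k-means assignment step above.
use ndarray::parallel::prelude::*;
use ndarray::{Array1, Array2};

fn assign_labels(data: &Array2<f32>, centroids: &Array2<f32>) -> Array1<usize> {
    let labels: Vec<usize> = data
        .outer_iter()
        .into_par_iter()
        .map(|sample| {
            let mut best = (0usize, f32::INFINITY);
            for (j, centroid) in centroids.outer_iter().enumerate() {
                // squared Euclidean distance; sqrt is not needed for an argmin
                let dist: f32 = sample
                    .iter()
                    .zip(centroid.iter())
                    .map(|(x, c)| (x - c) * (x - c))
                    .sum();
                if dist < best.1 {
                    best = (j, dist);
                }
            }
            best.0
        })
        .collect();
    Array1::from(labels)
}

fn main() {
    let data = Array2::<f32>::from_elem((6, 2), 1.0);
    let centroids = Array2::<f32>::zeros((2, 2));
    println!("{:?}", assign_labels(&data, &centroids)); // all samples -> centroid 0
}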