Skip to content

Commit

Permalink
Merge pull request #3 from oramasearch/feat/adds-log
Browse files (browse the repository at this point in the history)
feat: use log crate instead of println
  • Loading branch information
micheleriva authored Nov 26, 2024
2 parents 019c5a6 + 7a166d3 commit 426c498
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 29 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ rand = "0.9.0-alpha.2"
ndarray-rand = "0.15.0"
rand_distr = "0.4.3"
rayon = "1.10.0"
log = "0.4.22"

[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ fn main() -> Result<()> {
// Configure PQ parameters
let m = 8; // Number of subspaces (controls compression ratio)
let ks = 256; // Number of centroids per subspace (usually 256 for uint8)
let mut pq = PQ::try_new(m, ks, Some(true))?;
let mut pq = PQ::try_new(m, ks)?;

// Train the quantizer on the data
println!("Training PQ model...");
Expand Down
2 changes: 1 addition & 1 deletion src/bin/example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ fn main() -> Result<()> {
let ks = 256; // Number of clusters per subspace
let verbose = Some(true);

let mut pq = PQ::try_new(m, ks, verbose)?;
let mut pq = PQ::try_new(m, ks)?;

// Step 3: Train the PQ Model
let iterations = 20; // Number of iterations for k-means
Expand Down
2 changes: 1 addition & 1 deletion src/bin/readme_example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fn main() -> Result<()> {
// Configure PQ parameters
let m = 8; // Number of subspaces (controls compression ratio)
let ks = 256; // Number of centroids per subspace (usually 256 for uint8)
let mut pq = PQ::try_new(m, ks, Some(true))?;
let mut pq = PQ::try_new(m, ks)?;

// Train the quantizer on the data
println!("Training PQ model...");
Expand Down
49 changes: 23 additions & 26 deletions src/pq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use anyhow::Result;
use ndarray::parallel::prelude::*;
use ndarray::{s, Array2, Array3, Axis};
use rayon::prelude::*;
use log::{debug, error, info, trace, warn};

#[derive(Debug, Clone, Copy)]
pub enum CodeType {
Expand All @@ -14,15 +15,14 @@ pub enum CodeType {
pub struct PQ {
m: usize,
ks: u32,
verbose: bool,
code_dtype: CodeType,
codewords: Option<Array3<f32>>,
ds: Option<Vec<usize>>,
dim: Option<usize>,
}

impl PQ {
pub fn try_new(m: usize, ks: u32, verbose: Option<bool>) -> Result<Self> {
pub fn try_new(m: usize, ks: u32) -> Result<Self> {
if ks == 0 {
anyhow::bail!(
"cluster subspaces (ks) must be a u32 between 1 and 2**32 - 1. Got {}",
Expand All @@ -37,7 +37,6 @@ impl PQ {
Ok(Self {
m,
ks,
verbose: verbose.unwrap_or(false),
code_dtype: determine_code_type(ks),
codewords: None,
ds: None,
Expand Down Expand Up @@ -94,15 +93,13 @@ impl PQ {
let trained_codewords: Vec<(usize, Array2<f32>)> = (0..self.m)
.into_par_iter()
.map(|m| {
if self.verbose {
println!(
"# Training the subspace: {} / {}, {} -> {}",
m,
self.m,
self.ds.as_ref().unwrap()[m],
self.ds.as_ref().unwrap()[m + 1]
);
}
info!(
"Training the subspace: {} / {}, {} -> {}",
m,
self.m,
self.ds.as_ref().unwrap()[m],
self.ds.as_ref().unwrap()[m + 1]
);

let ds_ref = self.ds.as_ref().unwrap();

Expand Down Expand Up @@ -256,13 +253,13 @@ mod tests {
// Edge Case: ks is zero (rejected) or equal to u32::MAX (accepted).
#[test]
fn test_try_new_invalid_ks_zero() {
let pq = PQ::try_new(4, 0, None);
let pq = PQ::try_new(4, 0);
assert!(pq.is_err(), "Initialization should fail when ks is zero");
}

#[test]
fn test_try_new_invalid_ks_max() {
let pq = PQ::try_new(4, u32::MAX, None);
let pq = PQ::try_new(4, u32::MAX);
assert!(
pq.is_ok(),
"Initialization should succeed when ks is u32::MAX"
Expand All @@ -272,7 +269,7 @@ mod tests {
// Edge Case: m is zero.
#[test]
fn test_try_new_invalid_m_zero() {
let pq = PQ::try_new(0, 256, None);
let pq = PQ::try_new(0, 256);
assert!(
pq.is_err(),
"Initialization should fail when m is zero, but it succeeded"
Expand All @@ -282,7 +279,7 @@ mod tests {
// Edge Case: Number of training vectors is less than ks.
#[test]
fn test_fit_vectors_less_than_ks() {
let mut pq = PQ::try_new(4, 256, None).unwrap();
let mut pq = PQ::try_new(4, 256).unwrap();
let vecs = create_dummy_vectors(100, 128); // Less than ks
let result = pq.fit(&vecs, 10);
assert!(
Expand All @@ -294,7 +291,7 @@ mod tests {
// Edge Case: Vectors have zero dimensions or m exceeds vector dimensions.
#[test]
fn test_fit_zero_dimensions() {
let mut pq = PQ::try_new(4, 256, None).unwrap();
let mut pq = PQ::try_new(4, 256).unwrap();
let vecs = create_dummy_vectors(1000, 0); // Zero dimensions
let result = pq.fit(&vecs, 10);
assert!(
Expand All @@ -305,7 +302,7 @@ mod tests {

#[test]
fn test_fit_m_greater_than_dimensions() {
let mut pq = PQ::try_new(200, 256, None).unwrap();
let mut pq = PQ::try_new(200, 256).unwrap();
let vecs = create_dummy_vectors(1000, 128); // m > dimensions
let result = pq.fit(&vecs, 10);
assert!(
Expand All @@ -317,7 +314,7 @@ mod tests {
// Edge Case: Calling encode before fit.
#[test]
fn test_encode_without_fit() {
let pq = PQ::try_new(4, 256, None).unwrap();
let pq = PQ::try_new(4, 256).unwrap();
let vecs = create_dummy_vectors(1000, 128);
let result = pq.encode(&vecs);
assert!(
Expand All @@ -329,7 +326,7 @@ mod tests {
// Edge Case: Vectors have different dimensions than those used in fit.
#[test]
fn test_encode_mismatched_dimensions() {
let mut pq = PQ::try_new(4, 256, None).unwrap();
let mut pq = PQ::try_new(4, 256).unwrap();
let train_vecs = create_dummy_vectors(1000, 128);
pq.fit(&train_vecs, 10).unwrap();

Expand All @@ -344,7 +341,7 @@ mod tests {
// Edge Case: Codes have incorrect dimensions or contain invalid values.
#[test]
fn test_decode_invalid_code_m() {
let mut pq = PQ::try_new(4, 256, None).unwrap();
let mut pq = PQ::try_new(4, 256).unwrap();
let train_vecs = create_dummy_vectors(1000, 128);
pq.fit(&train_vecs, 10).unwrap();

Expand All @@ -358,7 +355,7 @@ mod tests {

#[test]
fn test_decode_code_value_exceeds_ks() {
let mut pq = PQ::try_new(4, 256, None).unwrap();
let mut pq = PQ::try_new(4, 256).unwrap();
let train_vecs = create_dummy_vectors(1000, 128);
pq.fit(&train_vecs, 10).unwrap();

Expand All @@ -374,7 +371,7 @@ mod tests {
// Edge Case: Ensuring compress works end-to-end.
#[test]
fn test_compress() {
let mut pq = PQ::try_new(4, 256, None).unwrap();
let mut pq = PQ::try_new(4, 256).unwrap();
let vecs = create_dummy_vectors(1000, 128);
pq.fit(&vecs, 10).unwrap();

Expand All @@ -389,7 +386,7 @@ mod tests {
// Edge Case: Ensuring code values fit within specified data types.
#[test]
fn test_encode_code_dtype_u8_overflow() {
let mut pq = PQ::try_new(4, 300, None).unwrap(); // ks exceeds u8::MAX
let mut pq = PQ::try_new(4, 300).unwrap(); // ks exceeds u8::MAX
pq.code_dtype = CodeType::U8;
let vecs = create_random_vectors(1000, 128);
pq.fit(&vecs, 10).unwrap();
Expand All @@ -403,7 +400,7 @@ mod tests {

#[test]
fn test_encode_code_dtype_u16_overflow() {
let mut pq = PQ::try_new(4, 70000, None).unwrap();
let mut pq = PQ::try_new(4, 70000).unwrap();
pq.code_dtype = CodeType::U16;
pq.codewords = Some(Array3::zeros((pq.m, pq.ks as usize, 128 / pq.m)));
pq.dim = Some(128);
Expand All @@ -418,7 +415,7 @@ mod tests {

#[test]
fn test_encode_code_dtype_u8_valid() {
let mut pq = PQ::try_new(4, 200, None).unwrap(); // ks within u8::MAX
let mut pq = PQ::try_new(4, 200).unwrap(); // ks within u8::MAX
pq.code_dtype = CodeType::U8;
let vecs = create_random_vectors(1000, 128);
pq.fit(&vecs, 10).unwrap();
Expand Down

0 comments on commit 426c498

Please sign in to comment.