Skip to content

Commit

Permalink
Merge pull request #8 from oramasearch/feat/implements-pq-residual
Browse files Browse the repository at this point in the history
tests: adds tests for pq_residual
  • Loading branch information
micheleriva authored Dec 5, 2024
2 parents 9aa9a08 + 56b44a3 commit ab0abc5
Showing 1 changed file with 166 additions and 0 deletions.
166 changes: 166 additions & 0 deletions src/pq_residual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,3 +284,169 @@ impl PQResidual {
Ok(sum_residual)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::pq::{CodeType, PQ};
use ndarray::Array2;

fn create_dummy_pq(m: usize, ks: u32) -> PQ {
PQ::try_new(m, ks).unwrap()
}

#[test]
fn test_try_new_empty_pqs() {
let pqs: Vec<PQ> = vec![];
let result = PQResidual::try_new(pqs);
assert!(result.is_err(), "Should fail with empty pqs vector");
}

#[test]
fn test_try_new_single_pq() {
let pqs = vec![create_dummy_pq(4, 4)];
let result = PQResidual::try_new(pqs);
assert!(result.is_ok(), "Should succeed with a single PQ");
let residual = result.unwrap();
assert_eq!(residual.deep, 1);
assert_eq!(residual.m, 4);
assert_eq!(residual.code_dtype, CodeType::U8);
}

#[test]
fn test_try_new_multiple_pqs() {
let pqs = vec![create_dummy_pq(4, 4), create_dummy_pq(8, 4)];
let residual = PQResidual::try_new(pqs).expect("Should succeed with multiple PQs");
assert_eq!(residual.deep, 2);
assert_eq!(residual.m, 8);
}

#[test]
fn test_fit_with_small_data() {
let pqs = vec![create_dummy_pq(4, 4)];
let mut residual = PQResidual::try_new(pqs).unwrap();
let data = Array2::<f32>::from_shape_fn((10, 4), |(i, j)| (i * j) as f32);

let result = residual.fit(
&data,
5,
false,
&[],
&[],
false,
Some("testdata"),
None,
None,
);

assert!(
result.is_ok(),
"Fit should not fail on valid data with ks=4 and 10 samples"
);
}

#[test]
fn test_encode_decode_round_trip() {
let pqs = vec![create_dummy_pq(4, 4)];
let mut residual = PQResidual::try_new(pqs).unwrap();
let data = Array2::<f32>::from_shape_fn((20, 4), |(i, j)| (i + j) as f32);

residual.pqs[0]
.fit(&data, 5)
.expect("PQ fit should succeed with ks=4 and 20 samples");

let codes = residual
.encode(&data)
.expect("Encoding should succeed after fit");
let reconstructed = match codes {
Codes3D::U8(c) => residual.decode(&c.map(|&x| x as u32)).unwrap(),
Codes3D::U16(c) => residual.decode(&c.map(|&x| x as u32)).unwrap(),
Codes3D::U32(c) => residual.decode(&c).unwrap(),
};

assert_eq!(
reconstructed.dim(),
data.dim(),
"Decoded data shape mismatch"
);
}

#[test]
fn test_compress() {
let pqs = vec![create_dummy_pq(4, 4)];
let mut residual = PQResidual::try_new(pqs).unwrap();
let data = Array2::<f32>::from_shape_fn((20, 4), |(i, j)| (i * j) as f32);

residual.pqs[0]
.fit(&data, 5)
.expect("Fit should succeed with ks=4 and 20 samples");

let compressed = residual
.compress(&data)
.expect("Compress should succeed after fit");
assert_eq!(compressed.dim(), data.dim());
}

#[test]
fn test_decode_error_wrong_dimensions() {
let pqs = vec![create_dummy_pq(4, 4)];
let residual = PQResidual::try_new(pqs).unwrap();

let codes = Array3::<u32>::zeros((10, 2, residual.m));
let result = residual.decode(&codes);
assert!(result.is_err(), "Should fail if code depth doesn't match");
}

#[test]
fn test_error_before_fit() {
let pqs = vec![create_dummy_pq(4, 4)];
let residual = PQResidual::try_new(pqs).unwrap();
let data = Array2::<f32>::zeros((10, 4));

let encode_result = residual.encode(&data);
assert!(
encode_result.is_err(),
"Encode should fail if PQ not fitted"
);
}

#[test]
fn test_fit_with_zero_samples() {
let pqs = vec![create_dummy_pq(4, 4)];
let mut residual = PQResidual::try_new(pqs).unwrap();

let data = Array2::<f32>::zeros((0, 4));
let result = residual.fit(
&data,
5,
false,
&[],
&[],
false,
Some("zero_samples"),
None,
None,
);
assert!(result.is_err(), "Fit should fail with zero samples");
}

#[test]
fn test_fit_with_zero_dimensions() {
let pqs = vec![create_dummy_pq(4, 4)];
let mut residual = PQResidual::try_new(pqs).unwrap();

let data = Array2::<f32>::zeros((10, 0));
let result = residual.fit(
&data,
5,
false,
&[],
&[],
false,
Some("zero_dims"),
None,
None,
);
assert!(result.is_err(), "Fit should fail with zero dimensions");
}
}

0 comments on commit ab0abc5

Please sign in to comment.