Skip to content

Commit

Permalink
Merge pull request #41 from ava57r/box-index
Browse files Browse the repository at this point in the history
Add Box impl
  • Loading branch information
Enet4 authored Jul 22, 2021
2 parents 5592fb9 + eb05127 commit 4aba00d
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 0 deletions.
19 changes: 19 additions & 0 deletions src/index/flat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,25 @@ mod tests {
assert_eq!(xb.len(), 8 * 5);
}

#[test]
fn flat_index_boxed() {
let mut index = FlatIndexImpl::new_l2(8).unwrap();
assert_eq!(index.is_trained(), true); // Flat index does not need training
let some_data = &[
7.5_f32, -7.5, 7.5, -7.5, 7.5, 7.5, 7.5, 7.5, -1., 1., 1., 1., 1., 1., 1., -1., 0., 0.,
0., 1., 1., 0., 0., -1., 100., 100., 100., 100., -100., 100., 100., 100., 120., 100.,
100., 105., -100., 100., 100., 105.,
];
index.add(some_data).unwrap();
assert_eq!(index.ntotal(), 5);

let boxed = Box::new(index);
assert_eq!(boxed.is_trained(), true);
assert_eq!(boxed.ntotal(), 5);
let xb = boxed.xb();
assert_eq!(xb.len(), 8 * 5);
}

#[test]
fn index_verbose() {
let mut index = FlatIndexImpl::new_l2(D).unwrap();
Expand Down
91 changes: 91 additions & 0 deletions src/index/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,79 @@ pub trait Index {
fn set_verbose(&mut self, value: bool);
}

impl<I> Index for Box<I>
where
I: Index,
{
fn is_trained(&self) -> bool {
(**self).is_trained()
}

fn ntotal(&self) -> u64 {
(**self).ntotal()
}

fn d(&self) -> u32 {
(**self).d()
}

fn metric_type(&self) -> MetricType {
(**self).metric_type()
}

fn add(&mut self, x: &[f32]) -> Result<()> {
(**self).add(x)
}

fn add_with_ids(&mut self, x: &[f32], xids: &[Idx]) -> Result<()> {
(**self).add_with_ids(x, xids)
}

fn train(&mut self, x: &[f32]) -> Result<()> {
(**self).train(x)
}

fn assign(&mut self, q: &[f32], k: usize) -> Result<AssignSearchResult> {
(**self).assign(q, k)
}

fn search(&mut self, q: &[f32], k: usize) -> Result<SearchResult> {
(**self).search(q, k)
}

fn range_search(&mut self, q: &[f32], radius: f32) -> Result<RangeSearchResult> {
(**self).range_search(q, radius)
}

fn reset(&mut self) -> Result<()> {
(**self).reset()
}

fn remove_ids(&mut self, sel: &IdSelector) -> Result<usize> {
(**self).remove_ids(sel)
}

fn verbose(&self) -> bool {
(**self).verbose()
}

fn set_verbose(&mut self, value: bool) {
(**self).set_verbose(value)
}
}

/// Sub-trait for native implementations of a Faiss index.
pub trait NativeIndex: Index {
/// Retrieve a pointer to the native index object.
fn inner_ptr(&self) -> *mut FaissIndex;
}

impl<NI: NativeIndex> NativeIndex for Box<NI> {
fn inner_ptr(&self) -> *mut FaissIndex {
(**self).inner_ptr()
}
}

/// Trait for a Faiss index that can be safely searched over multiple threads.
/// Operations which do not modify the index are given a method taking an
/// immutable reference. This is not the default for every index type because
Expand All @@ -212,9 +279,25 @@ pub trait ConcurrentIndex: Index {
fn range_search(&self, q: &[f32], radius: f32) -> Result<RangeSearchResult>;
}

impl<CI: ConcurrentIndex> ConcurrentIndex for Box<CI> {
fn assign(&self, q: &[f32], k: usize) -> Result<AssignSearchResult> {
(**self).assign(q, k)
}

fn search(&self, q: &[f32], k: usize) -> Result<SearchResult> {
(**self).search(q, k)
}

fn range_search(&self, q: &[f32], radius: f32) -> Result<RangeSearchResult> {
(**self).range_search(q, radius)
}
}

/// Trait for Faiss index types known to be running on the CPU.
pub trait CpuIndex: Index {}

impl<CI: CpuIndex> CpuIndex for Box<CI> {}

/// Trait for Faiss index types which can be built from a pointer
/// to a native implementation.
pub trait FromInnerPtr: NativeIndex {
Expand Down Expand Up @@ -427,6 +510,14 @@ mod tests {
assert_eq!(index.ntotal(), 0);
}

#[test]
fn index_factory_flat_boxed() {
let index = index_factory(64, "Flat", MetricType::L2).unwrap();
let boxed = Box::new(index);
assert_eq!(boxed.is_trained(), true); // Flat index does not need training
assert_eq!(boxed.ntotal(), 0);
}

#[test]
fn index_factory_ivf_flat() {
let index = index_factory(64, "IVF8,Flat", MetricType::L2).unwrap();
Expand Down

0 comments on commit 4aba00d

Please sign in to comment.