Skip to content

Commit

Permalink
Add more comments and uts for #353 (#358)
Browse files Browse the repository at this point in the history
Signed-off-by: chasingegg <[email protected]>
  • Loading branch information
chasingegg authored Jan 22, 2024
1 parent b14a4d6 commit 170e88e
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/index/ivf/ivf_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class IvfConfig : public BaseConfig {
CFG_INT nlist;
CFG_INT nprobe;
CFG_BOOL use_elkan;
CFG_BOOL ensure_topk_full;
CFG_BOOL ensure_topk_full;  // currently only takes effect on the temp index (IVF_FLAT_CC)
KNOHWERE_DECLARE_CONFIG(IvfConfig) {
KNOWHERE_CONFIG_DECLARE_FIELD(nlist)
.set_default(128)
Expand Down
36 changes: 31 additions & 5 deletions tests/ut/test_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,15 +216,21 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") {
using std::make_tuple;
auto ivfflatcc_gen_ = [base_gen, nb]() {
knowhere::Json json = base_gen();
json[knowhere::indexparam::NLIST] = 16;
json[knowhere::indexparam::NLIST] = 32;
json[knowhere::indexparam::NPROBE] = 1;
json[knowhere::indexparam::SSIZE] = 48;
json[knowhere::meta::TOPK] = nb;
return json;
};
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen_),
}));
auto ivfflatcc_gen_no_ensure_topk_ = [ivfflatcc_gen_, nb]() {
knowhere::Json json = ivfflatcc_gen_();
json[knowhere::meta::TOPK] = nb / 2;
json[knowhere::indexparam::ENSURE_TOPK_FULL] = false;
return json;
};
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>(
{make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen_),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen_no_ensure_topk_)}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version);
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
Expand All @@ -235,7 +241,27 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") {
auto results = idx.Search(*query_ds, json, nullptr);
auto gt = knowhere::BruteForce::Search<knowhere::fp32>(train_ds, query_ds, json, nullptr);
float recall = GetKNNRecall(*gt.value(), *results.value());
REQUIRE(recall > kBruteForceRecallThreshold);
if (ivfflatcc_gen_().dump() == cfg_json) {
REQUIRE(recall > kBruteForceRecallThreshold);
} else {
REQUIRE(recall < kBruteForceRecallThreshold);
}

std::vector<std::function<std::vector<uint8_t>(size_t, size_t)>> gen_bitset_funcs = {
GenerateBitsetWithFirstTbitsSet, GenerateBitsetWithRandomTbitsSet};
const auto bitset_percentages = 0.5f;
for (const auto& gen_func : gen_bitset_funcs) {
auto bitset_data = gen_func(nb, bitset_percentages * nb);
knowhere::BitsetView bitset(bitset_data.data(), nb);
auto results = idx.Search(*query_ds, json, bitset);
auto gt = knowhere::BruteForce::Search<knowhere::fp32>(train_ds, query_ds, json, bitset);
float recall = GetKNNRecall(*gt.value(), *results.value());
if (ivfflatcc_gen_().dump() == cfg_json) {
REQUIRE(recall > kBruteForceRecallThreshold);
} else {
REQUIRE(recall < kBruteForceRecallThreshold);
}
}
}

SECTION("Test Search with Bitset") {
Expand Down
8 changes: 7 additions & 1 deletion thirdparty/faiss/faiss/IndexIVF.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,12 @@ struct Level1Quantizer {
struct SearchParametersIVF : SearchParameters {
size_t nprobe = 1; ///< number of probes at query time
size_t max_codes = 0; ///< max nb of codes to visit to do a query
bool ensure_topk_full = false; ///< indicate whether we make sure topk result is full
///< indicates whether we should terminate early, before the topk results are full, when the search reaches max_codes.
///< To minimize code changes, this config does not take effect when users only use nprobe to search, since we first retrieve the nearest nprobe buckets
///< and it is a bit heavy to retrieve more buckets afterwards.
///< Therefore, to make sure we get topk results, use nprobe=nlist and use max_codes to narrow down the search range.
bool ensure_topk_full = false;

SearchParameters* quantizer_params = nullptr;

virtual ~SearchParametersIVF() {}
Expand Down Expand Up @@ -485,6 +490,7 @@ struct InvertedListScanner {
* @param distances heap distances (size k)
* @param labels heap labels (size k)
* @param k heap size
* @param scan_cnt number of valid codes that were scanned
* @return number of heap updates performed
*/
virtual size_t scan_codes(
Expand Down

0 comments on commit 170e88e

Please sign in to comment.