From 433de9709c548b795fabc80132397c59897b7385 Mon Sep 17 00:00:00 2001 From: Gao Date: Sat, 20 Jan 2024 04:42:47 +0800 Subject: [PATCH] Ensure topk results for IVF_FLAT_CC (#353) Signed-off-by: chasingegg --- include/knowhere/comp/index_param.h | 1 + src/index/ivf/ivf.cc | 15 ++++++++--- src/index/ivf/ivf_config.h | 5 ++++ tests/ut/test_ivfflat_cc.cc | 1 + tests/ut/test_search.cc | 26 +++++++++++++++++++ thirdparty/faiss/faiss/IndexIVF.cpp | 16 ++++++++---- thirdparty/faiss/faiss/IndexIVF.h | 4 ++- thirdparty/faiss/faiss/IndexIVFFlat.cpp | 12 ++++++--- thirdparty/faiss/faiss/IndexIVFPQ.cpp | 3 ++- .../faiss/faiss/IndexIVFSpectralHash.cpp | 5 ++-- .../faiss/faiss/IndexScalarQuantizer.cpp | 3 ++- .../faiss/faiss/impl/ScalarQuantizerScanner.h | 6 +++-- thirdparty/faiss/tests/test_lowlevel_ivf.cpp | 14 ++++++---- 13 files changed, 87 insertions(+), 24 deletions(-) diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h index 046ce195c..1b922fba8 100644 --- a/include/knowhere/comp/index_param.h +++ b/include/knowhere/comp/index_param.h @@ -85,6 +85,7 @@ constexpr const char* M = "m"; // PQ param for IVFPQ constexpr const char* SSIZE = "ssize"; constexpr const char* REORDER_K = "reorder_k"; constexpr const char* WITH_RAW_DATA = "with_raw_data"; +constexpr const char* ENSURE_TOPK_FULL = "ensure_topk_full"; // RAFT Params constexpr const char* REFINE_RATIO = "refine_ratio"; constexpr const char* CACHE_DATASET_ON_DEVICE = "cache_dataset_on_device"; diff --git a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index c01550800..603f49a1b 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -541,7 +541,7 @@ IvfIndexNode::Search(const DataSet& dataset, const Config& distances[i + offset] = static_cast(i_distances[i + offset]); } } - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value) { auto cur_query = (const float*)data + index * dim; if (is_cosine) { copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); @@ -549,9 +549,18 @@ IvfIndexNode::Search(const DataSet& dataset, const Config& } faiss::IVFSearchParameters ivf_search_params; - ivf_search_params.nprobe = nprobe; - ivf_search_params.max_codes = 0; + ivf_search_params.sel = id_selector; + ivf_search_params.ensure_topk_full = ivf_cfg.ensure_topk_full.value(); + if (ivf_search_params.ensure_topk_full) { + ivf_search_params.nprobe = index_->nlist; + // use max_codes to early termination + ivf_search_params.max_codes = + (nprobe * 1.0 / index_->nlist) * (index_->ntotal - bitset.count()); + } else { + ivf_search_params.nprobe = nprobe; + ivf_search_params.max_codes = 0; + } index_->search(1, cur_query, k, distances + offset, ids + offset, &ivf_search_params); } else if constexpr (std::is_same::value) { diff --git a/src/index/ivf/ivf_config.h b/src/index/ivf/ivf_config.h index b0592f0fe..9de9c6231 100644 --- a/src/index/ivf/ivf_config.h +++ b/src/index/ivf/ivf_config.h @@ -21,6 +21,7 @@ class IvfConfig : public BaseConfig { CFG_INT nlist; CFG_INT nprobe; CFG_BOOL use_elkan; + CFG_BOOL ensure_topk_full; KNOHWERE_DECLARE_CONFIG(IvfConfig) { KNOWHERE_CONFIG_DECLARE_FIELD(nlist) .set_default(128) @@ -36,6 +37,10 @@ class IvfConfig : public BaseConfig { .set_default(true) .description("whether to use elkan algorithm") .for_train(); + KNOWHERE_CONFIG_DECLARE_FIELD(ensure_topk_full) + .set_default(true) + .description("whether to make sure topk results full") + .for_search(); } }; diff --git a/tests/ut/test_ivfflat_cc.cc b/tests/ut/test_ivfflat_cc.cc index 8e06841c4..d04346d0e 100644 --- a/tests/ut/test_ivfflat_cc.cc +++ b/tests/ut/test_ivfflat_cc.cc @@ -48,6 +48,7 @@ TEST_CASE("Test Build Search Concurrency", "[Concurrency]") { knowhere::Json json = base_gen(); json[knowhere::indexparam::NLIST] = 128; json[knowhere::indexparam::NPROBE] = 16; + json[knowhere::indexparam::ENSURE_TOPK_FULL] = false; return json; }; diff --git a/tests/ut/test_search.cc b/tests/ut/test_search.cc index 45cfd5f4f..56e8f51f7 100644 --- a/tests/ut/test_search.cc +++ b/tests/ut/test_search.cc @@ -212,6 +212,32 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { REQUIRE(recall > kBruteForceRecallThreshold); } + SECTION("Test Search with IVFFLATCC ensure topk full") { + using std::make_tuple; + auto ivfflatcc_gen_ = [base_gen, nb]() { + knowhere::Json json = base_gen(); + json[knowhere::indexparam::NLIST] = 16; + json[knowhere::indexparam::NPROBE] = 1; + json[knowhere::indexparam::SSIZE] = 48; + json[knowhere::meta::TOPK] = nb; + return json; + }; + auto [name, gen] = GENERATE_REF(table>({ + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen_), + })); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto cfg_json = gen().dump(); + CAPTURE(name, cfg_json); + knowhere::Json json = knowhere::Json::parse(cfg_json); + REQUIRE(idx.Type() == name); + REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success); + + auto results = idx.Search(*query_ds, json, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, nullptr); + float recall = GetKNNRecall(*gt.value(), *results.value()); + REQUIRE(recall > kBruteForceRecallThreshold); + } + SECTION("Test Search with Bitset") { using std::make_tuple; auto [name, gen, threshold] = GENERATE_REF(table, float>({ diff --git a/thirdparty/faiss/faiss/IndexIVF.cpp b/thirdparty/faiss/faiss/IndexIVF.cpp index 284dd83d5..2edbaec49 100644 --- a/thirdparty/faiss/faiss/IndexIVF.cpp +++ b/thirdparty/faiss/faiss/IndexIVF.cpp @@ -416,6 +416,7 @@ void IndexIVF::search_preassigned( const idx_t unlimited_list_size = std::numeric_limits::max(); idx_t max_codes = params ? params->max_codes : this->max_codes; + bool ensure_topk_full = params ? params->ensure_topk_full : false; IDSelector* sel = params ? params->sel : nullptr; const IDSelectorRange* selr = dynamic_cast(sel); if (selr) { @@ -545,7 +546,7 @@ void IndexIVF::search_preassigned( return list_size; } else { - size_t scan_cnt = 0; + size_t scan_cnt = 0; // only record valid cnt size_t segment_num = invlists->get_segment_num(key); for (size_t segment_idx = 0; segment_idx < segment_num; segment_idx++) { @@ -570,8 +571,8 @@ void IndexIVF::search_preassigned( ids, simi, idxi, - k); - scan_cnt += segment_size; + k, + scan_cnt); } return scan_cnt; @@ -613,7 +614,9 @@ void IndexIVF::search_preassigned( simi, idxi, max_codes - nscan); - if (nscan >= max_codes) { + + // if ensure_topk_full enabled, also make sure nscan >= k, then stop search further + if (nscan >= max_codes && (!ensure_topk_full || nscan >= k)) { break; } } @@ -1306,13 +1309,15 @@ size_t InvertedListScanner::scan_codes( const idx_t* ids, float* simi, idx_t* idxi, - size_t k) const { + size_t k, + size_t& scan_cnt) const { size_t nup = 0; if (!keep_max) { for (size_t j = 0; j < list_size; j++) { // // todo aguzhva: use int64_t id instead of j ? if (!sel || sel->is_member(j)) { + scan_cnt++; float dis = distance_to_code(codes); if (code_norms) { dis /= code_norms[j]; @@ -1329,6 +1334,7 @@ size_t InvertedListScanner::scan_codes( for (size_t j = 0; j < list_size; j++) { // // todo aguzhva: use int64_t id instead of j ? if (!sel || sel->is_member(j)) { + scan_cnt++; float dis = distance_to_code(codes); if (code_norms) { dis /= code_norms[j]; diff --git a/thirdparty/faiss/faiss/IndexIVF.h b/thirdparty/faiss/faiss/IndexIVF.h index 9e4bf0f62..e2334e884 100644 --- a/thirdparty/faiss/faiss/IndexIVF.h +++ b/thirdparty/faiss/faiss/IndexIVF.h @@ -71,6 +71,7 @@ struct Level1Quantizer { struct SearchParametersIVF : SearchParameters { size_t nprobe = 1; ///< number of probes at query time size_t max_codes = 0; ///< max nb of codes to visit to do a query + bool ensure_topk_full = false; ///< indicate whether we make sure topk result is full SearchParameters* quantizer_params = nullptr; virtual ~SearchParametersIVF() {} @@ -493,7 +494,8 @@ struct InvertedListScanner { const idx_t* ids, float* distances, idx_t* labels, - size_t k) const; + size_t k, + size_t& scan_cnt) const; // same as scan_codes, using an iterator virtual size_t iterate_codes( diff --git a/thirdparty/faiss/faiss/IndexIVFFlat.cpp b/thirdparty/faiss/faiss/IndexIVFFlat.cpp index 875637185..f67b521d3 100644 --- a/thirdparty/faiss/faiss/IndexIVFFlat.cpp +++ b/thirdparty/faiss/faiss/IndexIVFFlat.cpp @@ -286,7 +286,8 @@ struct IVFFlatScanner : InvertedListScanner { const idx_t* ids, float* simi, idx_t* idxi, - size_t k) const override { + size_t k, + size_t& scan_cnt) const override { const float* list_vecs = (const float*)codes; size_t nup = 0; @@ -294,10 +295,11 @@ struct IVFFlatScanner : InvertedListScanner { auto filter = [&](const size_t j) { return (!use_sel || sel->is_member(ids[j])); }; - // the lambda that applies a filtered element. + // the lambda that applies a valid element. auto apply = [&](const float dis_in, const size_t j) { const float dis = (code_norms == nullptr) ? dis_in : (dis_in / code_norms[j]); + scan_cnt++; if (C::cmp(simi[0], dis)) { const int64_t id = store_pairs ? lo_build(list_no, j) : ids[j]; heap_replace_top(k, simi, idxi, dis, id); @@ -389,7 +391,8 @@ struct IVFFlatBitsetViewScanner : InvertedListScanner { const idx_t* __restrict ids, float* __restrict simi, idx_t* __restrict idxi, - size_t k) const override { + size_t k, + size_t& scan_cnt) const override { const float* list_vecs = (const float*)codes; size_t nup = 0; @@ -397,10 +400,11 @@ struct IVFFlatBitsetViewScanner : InvertedListScanner { auto filter = [&](const size_t j) { return (!use_sel || !bitset.test(ids[j])); }; - // the lambda that applies a filtered element. + // the lambda that applies a valid element. auto apply = [&](const float dis_in, const size_t j) { const float dis = (code_norms == nullptr) ? dis_in : (dis_in / code_norms[j]); + scan_cnt++; if (C::cmp(simi[0], dis)) { const int64_t id = store_pairs ? lo_build(list_no, j) : ids[j]; heap_replace_top(k, simi, idxi, dis, id); diff --git a/thirdparty/faiss/faiss/IndexIVFPQ.cpp b/thirdparty/faiss/faiss/IndexIVFPQ.cpp index d0020b6ec..8db22b7eb 100644 --- a/thirdparty/faiss/faiss/IndexIVFPQ.cpp +++ b/thirdparty/faiss/faiss/IndexIVFPQ.cpp @@ -1228,7 +1228,8 @@ struct IVFPQScanner : IVFPQScannerT, const idx_t* ids, float* heap_sim, idx_t* heap_ids, - size_t k) const override { + size_t k, + size_t& scan_cnt) const override { KnnSearchResults res = { /* key */ this->key, /* ids */ this->store_pairs ? nullptr : ids, diff --git a/thirdparty/faiss/faiss/IndexIVFSpectralHash.cpp b/thirdparty/faiss/faiss/IndexIVFSpectralHash.cpp index e5d9a8e82..9ec202d98 100644 --- a/thirdparty/faiss/faiss/IndexIVFSpectralHash.cpp +++ b/thirdparty/faiss/faiss/IndexIVFSpectralHash.cpp @@ -258,12 +258,13 @@ struct IVFScanner : InvertedListScanner { const idx_t* ids, float* simi, idx_t* idxi, - size_t k) const override { + size_t k, + size_t& scan_cnt) const override { size_t nup = 0; for (size_t j = 0; j < list_size; j++) { if (!sel || sel->is_member(ids[j])) { float dis = hc.compute(codes); - + scan_cnt++; if (dis < simi[0]) { int64_t id = store_pairs ? lo_build(list_no, j) : ids[j]; maxheap_replace_top(k, simi, idxi, dis, id); diff --git a/thirdparty/faiss/faiss/IndexScalarQuantizer.cpp b/thirdparty/faiss/faiss/IndexScalarQuantizer.cpp index e30158c07..4c96aaf11 100644 --- a/thirdparty/faiss/faiss/IndexScalarQuantizer.cpp +++ b/thirdparty/faiss/faiss/IndexScalarQuantizer.cpp @@ -77,7 +77,8 @@ void IndexScalarQuantizer::search( minheap_heapify(k, D, I); } scanner->set_query(x + i * d); - scanner->scan_codes(ntotal, codes.data(), nullptr, nullptr, D, I, k); + size_t scan_cnt = 0; + scanner->scan_codes(ntotal, codes.data(), nullptr, nullptr, D, I, k, scan_cnt); // re-order heap if (metric_type == METRIC_L2) { diff --git a/thirdparty/faiss/faiss/impl/ScalarQuantizerScanner.h b/thirdparty/faiss/faiss/impl/ScalarQuantizerScanner.h index fd6ad6005..2ac6cd77f 100644 --- a/thirdparty/faiss/faiss/impl/ScalarQuantizerScanner.h +++ b/thirdparty/faiss/faiss/impl/ScalarQuantizerScanner.h @@ -62,7 +62,8 @@ struct IVFSQScannerIP : InvertedListScanner { const idx_t* ids, float* simi, idx_t* idxi, - size_t k) const override { + size_t k, + size_t& scan_cnt) const override { size_t nup = 0; for (size_t j = 0; j < list_size; j++, codes += code_size) { @@ -215,7 +216,8 @@ struct IVFSQScannerL2 : InvertedListScanner { const idx_t* ids, float* simi, idx_t* idxi, - size_t k) const override { + size_t k, + size_t& scan_cnt) const override { size_t nup = 0; // // baseline diff --git a/thirdparty/faiss/tests/test_lowlevel_ivf.cpp b/thirdparty/faiss/tests/test_lowlevel_ivf.cpp index e28e2a946..3734fca27 100644 --- a/thirdparty/faiss/tests/test_lowlevel_ivf.cpp +++ b/thirdparty/faiss/tests/test_lowlevel_ivf.cpp @@ -176,14 +176,15 @@ void test_lowlevel_access(const char* index_key, MetricType metric) { // here we get the inverted lists from the InvertedLists // object but they could come from anywhere - + size_t scan_cnt = 0; scanner->scan_codes( il->list_size(list_no), InvertedLists::ScopedCodes(il, list_no).get(), InvertedLists::ScopedIds(il, list_no).get(), D.data(), I.data(), - k); + k, + scan_cnt); if (j == 0) { // all results so far come from list_no, so let's check if @@ -338,14 +339,15 @@ void test_lowlevel_access_binary(const char* index_key) { // here we get the inverted lists from the InvertedLists // object but they could come from anywhere - + size_t scan_cnt = 0; scanner->scan_codes( il->list_size(list_no), InvertedLists::ScopedCodes(il, list_no).get(), InvertedLists::ScopedIds(il, list_no).get(), D.data(), I.data(), - k); + k, + scan_cnt); if (j == 0) { // all results so far come from list_no, so let's check if @@ -500,13 +502,15 @@ void test_threaded_search(const char* index_key, MetricType metric) { continue; scanner->set_list(list_no, q_dis[i * nprobe + j]); + size_t scan_cnt = 0; scanner->scan_codes( il->list_size(list_no), InvertedLists::ScopedCodes(il, list_no).get(), InvertedLists::ScopedIds(il, list_no).get(), local_D, local_I, - k); + k, + scan_cnt); } };