From ca2fb3b46970f6d49212b8c272d6d86ccdb12b3b Mon Sep 17 00:00:00 2001 From: rakri Date: Wed, 23 Oct 2024 23:08:04 -0700 Subject: [PATCH] made some changes for num local start pts --- apps/search_memory_index.cpp | 9 ++++++++- include/index.h | 3 ++- src/index.cpp | 39 ++++++++++++++++++++---------------- 3 files changed, 32 insertions(+), 19 deletions(-) diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index aaee7731a..2f8d06cdd 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -447,7 +447,7 @@ int main(int argc, char **argv) { std::string data_type, dist_fn, index_path_prefix, result_path, query_file, gt_file, filter_label, label_type, query_filters_file; - uint32_t num_threads, K, filter_penalty_threshold, bruteforce_threshold, clustering_threshold, L_for_print; + uint32_t num_threads, K, filter_penalty_threshold, bruteforce_threshold, clustering_threshold, L_for_print, num_local; std::vector Lvec; bool print_all_recalls, dynamic, tags, show_qps_per_thread, global_start; float fail_if_recall_below = 0.0f; @@ -499,6 +499,10 @@ int main(int argc, char **argv) optional_configs.add_options()("use_global_start", po::value(&global_start)->default_value(false), "Whether or not to use global start or predicate-aware starting point in graph search"); + optional_configs.add_options()("num_local_start", + po::value(&num_local)->default_value(0), + "How many local start points to use"); + optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), program_options_utils::LABEL_TYPE_DESCRIPTION); @@ -607,6 +611,9 @@ int main(int argc, char **argv) } use_global_start = global_start; + num_start_points = num_local; + + std::cout<<"Num local start points: " << num_start_points << std::endl; try { diff --git a/include/index.h b/include/index.h index 7ffa425bb..105153deb 100644 --- a/include/index.h +++ b/include/index.h @@ -49,6 +49,7 @@ inline int64_t curr_query = -1; inline uint32_t penalty_scale = 10; inline uint32_t num_sp = 2; inline bool use_global_start = false; +inline uint32_t num_start_points = 1; namespace diskann { @@ -277,7 +278,7 @@ template clas std::vector> sort_filter_counts(const std::vector &filter_label); - std::pair sample_intersection(roaring::Roaring &intersection_bitmap, + std::pair> sample_intersection(roaring::Roaring &intersection_bitmap, const std::vector &filter_label); std::unordered_map load_label_map(const std::string &map_file); diff --git a/src/index.cpp b/src/index.cpp index 378e38c14..a27f62bd5 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2541,7 +2541,7 @@ std::vector> Index::sort_filter_cou } template -std::pair Index::sample_intersection(roaring::Roaring &intersection_bitmap, +std::pair> Index::sample_intersection(roaring::Roaring &intersection_bitmap, const std::vector &filter_label) { intersection_bitmap = _labels_to_points_sample[filter_label[0]]; @@ -2551,12 +2551,16 @@ std::pair Index::sample_intersection(roarin } uint32_t val = std::numeric_limits::max(); auto x = intersection_bitmap.begin(); - if (x != intersection_bitmap.end()) + std::vector results; + results.reserve(num_start_points); + while (x != intersection_bitmap.end() && results.size() < num_start_points) { val = _sample_map[*x]; + results.emplace_back(val); + x++; } // std::cout< @@ -2666,17 +2670,18 @@ std::pair Index::search_with_filters(const case 2: num_graphs++; auto [inter_estim, cand] = sample_intersection(scratch->get_valid_bitmap(), filter_label); - if (!use_global_start) { - if (cand < std::numeric_limits::max()) + + if (cand.size() > 0) { - init_ids.emplace_back(cand); - } else { + init_ids.insert(init_ids.end(), cand.begin(), cand.end()); +// init_ids.emplace_back(cand); + } /*else { if (_label_to_start_id.find(filter_label[0]) != _label_to_start_id.end()) { init_ids.emplace_back(_label_to_start_id[filter_label[0]]); } - } - } else { + } */ + if (use_global_start) { init_ids.emplace_back(_start); } @@ -2685,7 +2690,7 @@ std::pair Index::search_with_filters(const { std::ofstream out("query_stats.txt", std::ios_base::app); out << "estimated intersection size is " << inter_estim << std::endl; - out << "setting up init ids with id " << cand << std::endl; + //out << "setting up init ids with id " << cand << std::endl; out.close(); } retval = iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true); @@ -2737,17 +2742,17 @@ std::pair Index::search_with_filters(const /* if (_dynamic_index) */ /* tl.unlock(); */ - if (!use_global_start) { - if (cand < std::numeric_limits::max()) + if (cand.size() > 0) { - init_ids.emplace_back(cand); - } else { + init_ids.insert(init_ids.end(), cand.begin(), cand.end()); +// init_ids.emplace_back(cand); + } /*else { if (_label_to_start_id.find(filter_label[0]) != _label_to_start_id.end()) { init_ids.emplace_back(_label_to_start_id[filter_label[0]]); } - } - } else { + }*/ + if (use_global_start) { init_ids.emplace_back(_start); } @@ -2760,7 +2765,7 @@ std::pair Index::search_with_filters(const out << filt << "/" << _labels_to_points[filt].cardinality() << " "; out << std::endl; out << "estimated intersection size is " << estimated_match << std::endl; - out << "setting up init ids with id " << cand << std::endl; + //out << "setting up init ids with id " << cand << std::endl; out << std::endl; out.close(); }