diff --git a/include/abstract_filter_store.h b/include/abstract_filter_store.h index 0e6316547..7d654f8e1 100644 --- a/include/abstract_filter_store.h +++ b/include/abstract_filter_store.h @@ -34,11 +34,6 @@ template class AbstractFilterStore // returns internal mapping for given raw_label DISKANN_DLLEXPORT virtual label_type get_numeric_label(const std::string &raw_label) = 0; - DISKANN_DLLEXPORT virtual void update_medoid_by_label(const label_type &label, const uint32_t new_medoid) = 0; - DISKANN_DLLEXPORT virtual const uint32_t &get_medoid_by_label(const label_type &label) = 0; - DISKANN_DLLEXPORT virtual const std::unordered_map &get_labels_to_medoids() = 0; - DISKANN_DLLEXPORT virtual bool label_has_medoid(const label_type &label) = 0; - // TODO: in future we may accept a set or vector of universal labels // DISKANN_DLLEXPORT virtual void set_universal_label(label_type universal_label) = 0; DISKANN_DLLEXPORT virtual void set_universal_labels(const std::string &universal_labels) = 0; @@ -52,14 +47,12 @@ template class AbstractFilterStore // For dynamic filtered build, we compact the data and hence location_to_labels, we need the compacted version of // raw labels to compute GT correctly. DISKANN_DLLEXPORT virtual void save_raw_labels(const std::string &save_path, const size_t total_points) = 0; - DISKANN_DLLEXPORT virtual void save_medoids(const std::string &save_path) = 0; DISKANN_DLLEXPORT virtual void save_label_map(const std::string &save_path) = 0; DISKANN_DLLEXPORT virtual void save_universal_label(const std::string &save_path) = 0; protected: // This is for internal use and only loads already parsed file DISKANN_DLLEXPORT virtual size_t load_labels(const std::string &labels_file) = 0; - DISKANN_DLLEXPORT virtual size_t load_medoids(const std::string &labels_to_medoid_file) = 0; DISKANN_DLLEXPORT virtual void load_label_map(const std::string &labels_map_file) = 0; DISKANN_DLLEXPORT virtual void load_universal_labels(const std::string &universal_labels_file) = 0; diff --git a/include/in_mem_filter_store.h b/include/in_mem_filter_store.h index 98c6a7408..55a8c5f10 100644 --- a/include/in_mem_filter_store.h +++ b/include/in_mem_filter_store.h @@ -29,12 +29,6 @@ template class InMemFilterStore : public AbstractFilterSto // returns internal mapping for given raw_label label_type get_numeric_label(const std::string &raw_label) override; - // Mode medoids related function to index class - void update_medoid_by_label(const label_type &label, const uint32_t new_medoid) override; - const uint32_t &get_medoid_by_label(const label_type &label) override; - const std::unordered_map &get_labels_to_medoids() override; - bool label_has_medoid(const label_type &label) override; - // takes raw universal labels and map them internally. void set_universal_labels(const std::string &raw_universal_labels) override; std::pair get_universal_label() override; @@ -46,7 +40,6 @@ template class InMemFilterStore : public AbstractFilterSto // For dynamic filtered build, we compact the data and hence location_to_labels, we need the compacted version of // raw labels to compute GT correctly. void save_raw_labels(const std::string &save_path, const size_t total_points) override; - void save_medoids(const std::string &save_path) override; void save_label_map(const std::string &save_path) override; void save_universal_label(const std::string &save_path) override; @@ -58,7 +51,6 @@ template class InMemFilterStore : public AbstractFilterSto protected: // This is for internal use and only loads already parsed file, used by index in during load(). size_t load_labels(const std::string &labels_file) override; - size_t load_medoids(const std::string &labels_to_medoid_file) override; void load_label_map(const std::string &labels_map_file) override; void load_universal_labels(const std::string &universal_labels_file) override; @@ -68,12 +60,6 @@ template class InMemFilterStore : public AbstractFilterSto tsl::robin_set _labels; std::unordered_map _label_map; - // medoids - - // move medoids to Index class since its property of index - std::unordered_map _label_to_medoid_id; - std::unordered_map _medoid_counts; // medoids only happen for filtered index - // universal label bool _has_universal_label = false; label_type _universal_label; diff --git a/include/index.h b/include/index.h index 021ff8f2f..077a29da2 100644 --- a/include/index.h +++ b/include/index.h @@ -255,6 +255,11 @@ template clas // Calculate best medoids for filter data void calculate_best_medoids(const size_t num_points_to_load, const uint32_t num_candidates); + //load medoids + size_t load_medoids(const std::string &labels_to_medoid_file); + //save medoids + void save_medoids(const std::string &save_path); + // The query to use is placed in scratch->aligned_query std::pair iterate_to_fixed_point(InMemQueryScratch *scratch, const uint32_t Lindex, const std::vector &init_ids, bool use_filter, diff --git a/src/in_mem_filter_store.cpp b/src/in_mem_filter_store.cpp index ee8f8c16a..a96cf4039 100644 --- a/src/in_mem_filter_store.cpp +++ b/src/in_mem_filter_store.cpp @@ -57,29 +57,6 @@ template void InMemFilterStore::add_to_label_s _labels.insert(label); } -template -void InMemFilterStore::update_medoid_by_label(const label_type &label, const uint32_t new_medoid) -{ - _label_to_medoid_id[label] = new_medoid; -} - -template -const uint32_t &InMemFilterStore::get_medoid_by_label(const label_type &label) -{ - return _label_to_medoid_id[label]; -} - -template -const std::unordered_map &InMemFilterStore::get_labels_to_medoids() -{ - return _label_to_medoid_id; -} - -template bool InMemFilterStore::label_has_medoid(const label_type &label) -{ - return _label_to_medoid_id.find(label) != _label_to_medoid_id.end(); -} - template void InMemFilterStore::add_label_to_location(const location_t point_id, const label_type label) { @@ -137,41 +114,6 @@ template size_t InMemFilterStore::load_labels( return parse_label_file(labels_file); } -template -size_t InMemFilterStore::load_medoids(const std::string &labels_to_medoid_file) -{ - if (file_exists(labels_to_medoid_file)) - { - std::ifstream medoid_stream(labels_to_medoid_file); - std::string line, token; - uint32_t line_cnt = 0; - - _label_to_medoid_id.clear(); - while (std::getline(medoid_stream, line)) - { - std::istringstream iss(line); - uint32_t cnt = 0; - uint32_t medoid = 0; - label_type label; - while (std::getline(iss, token, ',')) - { - token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - label_type token_as_num = (label_type)std::stoul(token); - if (cnt == 0) - label = token_as_num; - else - medoid = token_as_num; - cnt++; - } - _label_to_medoid_id[label] = medoid; - line_cnt++; - } - return (size_t)line_cnt; - } - throw ANNException("ERROR: can not load medoids file does not exist", -1); -} - template void InMemFilterStore::load_label_map(const std::string &labels_map_file) { if (file_exists(labels_map_file)) @@ -290,23 +232,6 @@ template void InMemFilterStore::save_universal } } -template void InMemFilterStore::save_medoids(const std::string &save_path) -{ - if (_label_to_medoid_id.size() > 0) - { - std::ofstream medoid_writer(save_path); - if (medoid_writer.fail()) - { - throw diskann::ANNException(std::string("Failed to open medoid file ") + save_path, -1); - } - for (auto iter : _label_to_medoid_id) - { - medoid_writer << iter.first << ", " << iter.second << std::endl; - } - medoid_writer.close(); - } -} - template void InMemFilterStore::save_label_map(const std::string &save_path) { if (_label_map.empty()) diff --git a/src/index.cpp b/src/index.cpp index f7b0244ee..478158183 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -290,7 +290,7 @@ void Index::save(const char *filename, bool compact_before_save { if (_filtered_index) { - _filter_store->save_medoids(std::string(filename) + "_labels_to_medoids.txt"); + save_medoids(std::string(filename) + "_labels_to_medoids.txt"); _filter_store->save_label_map(std::string(filename) + "_labels_map.txt"); _filter_store->save_universal_label(std::string(filename) + "_universal_label.txt"); _filter_store->save_labels(std::string(filename) + "_labels.txt", _nd + _num_frozen_pts); @@ -537,7 +537,7 @@ void Index::load(const char *filename, uint32_t num_threads, ui _filter_store->load_label_map(labels_map_file); label_num_pts = _filter_store->load_labels(labels_file); assert(label_num_pts == data_file_num_pts); - _filter_store->load_medoids(labels_to_medoids); + load_medoids(labels_to_medoids); _filter_store->load_universal_labels(std::string(filename) + "_universal_label.txt"); } @@ -722,6 +722,59 @@ void Index::calculate_best_medoids(const size_t num_points_to_l } } +template +size_t Index::load_medoids(const std::string &labels_to_medoid_file) +{ + if (file_exists(labels_to_medoid_file)) + { + std::ifstream medoid_stream(labels_to_medoid_file); + std::string line, token; + uint32_t line_cnt = 0; + + _label_to_medoid_id.clear(); + while (std::getline(medoid_stream, line)) + { + std::istringstream iss(line); + uint32_t cnt = 0; + uint32_t medoid = 0; + LabelT label; + while (std::getline(iss, token, ',')) + { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + LabelT token_as_num = (LabelT)std::stoul(token); + if (cnt == 0) + label = token_as_num; + else + medoid = token_as_num; + cnt++; + } + _label_to_medoid_id[label] = medoid; + line_cnt++; + } + return (size_t)line_cnt; + } + throw ANNException("ERROR: can not load medoids file does not exist", -1); +} + +template +void Index::save_medoids(const std::string &medoid_file_name) +{ + if (_label_to_medoid_id.size() > 0) + { + std::ofstream medoid_writer(medoid_file_name); + if (medoid_writer.fail()) + { + throw diskann::ANNException(std::string("Failed to open medoid file ") + medoid_file_name, -1); + } + for (auto iter : _label_to_medoid_id) + { + medoid_writer << iter.first << ", " << iter.second << std::endl; + } + medoid_writer.close(); + } +} + // Find common filter between a node's labels and a given set of labels, while // taking into account universal label template @@ -939,7 +992,7 @@ void Index::search_for_point_and_prune(int location, uint32_t L tl.lock(); std::vector filter_specific_start_nodes; for (auto &x : _filter_store->get_labels_by_location(location)) - filter_specific_start_nodes.emplace_back(_filter_store->get_medoid_by_label(x)); + filter_specific_start_nodes.emplace_back(_label_to_medoid_id[x]); if (_dynamic_index) tl.unlock(); @@ -1893,9 +1946,9 @@ std::pair Index::search_with_filters(const if (_dynamic_index) tl.lock(); - if (_filter_store->label_has_medoid(filter_label)) + if (_label_to_medoid_id.find(filter_label) != _label_to_medoid_id.end()) { - init_ids.emplace_back(_filter_store->get_medoid_by_label(filter_label)); + init_ids.emplace_back(_label_to_medoid_id[filter_label]); } else { @@ -2303,12 +2356,12 @@ template void Indexget_labels_to_medoids()) + for (auto &[label, medoid_id] : _label_to_medoid_id) { /* if (label == _universal_label) continue;*/ - _filter_store->update_medoid_by_label(label, (uint32_t)_nd + (medoid_id - (uint32_t)_max_points)); - //_label_to_start_id[label] = (uint32_t)_nd + (medoid_id - (uint32_t)_max_points); + uint32_t medoid = (uint32_t)_nd + (medoid_id - (uint32_t)_max_points); + _label_to_medoid_id[label] = medoid; } } } @@ -2608,12 +2661,9 @@ template void Indexget_labels_to_medoids()) + for (auto &[label, medoid_id] : _label_to_medoid_id) { - /*if (label == _universal_label) - continue;*/ - _filter_store->update_medoid_by_label(label, (uint32_t)_max_points + (medoid_id - (uint32_t)_nd)); - //_label_to_medoid_id[label] = (uint32_t)_max_points + (medoid_id - (uint32_t)_nd); + _label_to_medoid_id[label] = (uint32_t)_max_points + (medoid_id - (uint32_t)_nd); } } } @@ -2732,7 +2782,7 @@ int Index::insert_point(const T *point, const TagT tag, const s auto fz_location = (int)(_max_points) + _frozen_pts_used; // as first _fz_point _filter_store->add_to_label_set(label); - _filter_store->update_medoid_by_label(label, (uint32_t)fz_location); + _label_to_medoid_id[label] = (uint32_t)fz_location; std::vector fz_label = {label}; _filter_store->set_labels_to_location((location_t)fz_location, {label_str}); //_label_to_start_id[label] = (uint32_t)fz_location;