Skip to content

Commit

Permalink
remove medoid related data and methods to index class
Browse files Browse the repository at this point in the history
  • Loading branch information
NeelamMahapatro committed Jan 24, 2024
1 parent 0b9118f commit 7462bd0
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 110 deletions.
7 changes: 0 additions & 7 deletions include/abstract_filter_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,6 @@ template <typename label_type> class AbstractFilterStore
// returns internal mapping for given raw_label
DISKANN_DLLEXPORT virtual label_type get_numeric_label(const std::string &raw_label) = 0;

DISKANN_DLLEXPORT virtual void update_medoid_by_label(const label_type &label, const uint32_t new_medoid) = 0;
DISKANN_DLLEXPORT virtual const uint32_t &get_medoid_by_label(const label_type &label) = 0;
DISKANN_DLLEXPORT virtual const std::unordered_map<label_type, uint32_t> &get_labels_to_medoids() = 0;
DISKANN_DLLEXPORT virtual bool label_has_medoid(const label_type &label) = 0;

// TODO: in future we may accept a set or vector of universal labels
// DISKANN_DLLEXPORT virtual void set_universal_label(label_type universal_label) = 0;
DISKANN_DLLEXPORT virtual void set_universal_labels(const std::string &universal_labels) = 0;
Expand All @@ -52,14 +47,12 @@ template <typename label_type> class AbstractFilterStore
// For dynamic filtered build, we compact the data and hence location_to_labels, we need the compacted version of
// raw labels to compute GT correctly.
DISKANN_DLLEXPORT virtual void save_raw_labels(const std::string &save_path, const size_t total_points) = 0;
DISKANN_DLLEXPORT virtual void save_medoids(const std::string &save_path) = 0;
DISKANN_DLLEXPORT virtual void save_label_map(const std::string &save_path) = 0;
DISKANN_DLLEXPORT virtual void save_universal_label(const std::string &save_path) = 0;

protected:
// This is for internal use and only loads already parsed file
DISKANN_DLLEXPORT virtual size_t load_labels(const std::string &labels_file) = 0;
DISKANN_DLLEXPORT virtual size_t load_medoids(const std::string &labels_to_medoid_file) = 0;
DISKANN_DLLEXPORT virtual void load_label_map(const std::string &labels_map_file) = 0;
DISKANN_DLLEXPORT virtual void load_universal_labels(const std::string &universal_labels_file) = 0;

Expand Down
14 changes: 0 additions & 14 deletions include/in_mem_filter_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,6 @@ template <typename label_type> class InMemFilterStore : public AbstractFilterSto
// returns internal mapping for given raw_label
label_type get_numeric_label(const std::string &raw_label) override;

// Mode medoids related function to index class
void update_medoid_by_label(const label_type &label, const uint32_t new_medoid) override;
const uint32_t &get_medoid_by_label(const label_type &label) override;
const std::unordered_map<label_type, uint32_t> &get_labels_to_medoids() override;
bool label_has_medoid(const label_type &label) override;

// takes raw universal labels and map them internally.
void set_universal_labels(const std::string &raw_universal_labels) override;
std::pair<bool, label_type> get_universal_label() override;
Expand All @@ -46,7 +40,6 @@ template <typename label_type> class InMemFilterStore : public AbstractFilterSto
// For dynamic filtered build, we compact the data and hence location_to_labels, we need the compacted version of
// raw labels to compute GT correctly.
void save_raw_labels(const std::string &save_path, const size_t total_points) override;
void save_medoids(const std::string &save_path) override;
void save_label_map(const std::string &save_path) override;
void save_universal_label(const std::string &save_path) override;

Expand All @@ -58,7 +51,6 @@ template <typename label_type> class InMemFilterStore : public AbstractFilterSto
protected:
// This is for internal use and only loads already parsed file, used by index in during load().
size_t load_labels(const std::string &labels_file) override;
size_t load_medoids(const std::string &labels_to_medoid_file) override;
void load_label_map(const std::string &labels_map_file) override;
void load_universal_labels(const std::string &universal_labels_file) override;

Expand All @@ -68,12 +60,6 @@ template <typename label_type> class InMemFilterStore : public AbstractFilterSto
tsl::robin_set<label_type> _labels;
std::unordered_map<std::string, label_type> _label_map;

// medoids

// move medoids to Index class since its property of index
std::unordered_map<label_type, uint32_t> _label_to_medoid_id;
std::unordered_map<uint32_t, uint32_t> _medoid_counts; // medoids only happen for filtered index

// universal label
bool _has_universal_label = false;
label_type _universal_label;
Expand Down
5 changes: 5 additions & 0 deletions include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,11 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
// Calculate best medoids for filter data
void calculate_best_medoids(const size_t num_points_to_load, const uint32_t num_candidates);

//load medoids
size_t load_medoids(const std::string &labels_to_medoid_file);
//save medoids
void save_medoids(const std::string &save_path);

// The query to use is placed in scratch->aligned_query
std::pair<uint32_t, uint32_t> iterate_to_fixed_point(InMemQueryScratch<T> *scratch, const uint32_t Lindex,
const std::vector<uint32_t> &init_ids, bool use_filter,
Expand Down
75 changes: 0 additions & 75 deletions src/in_mem_filter_store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,29 +57,6 @@ template <typename label_type> void InMemFilterStore<label_type>::add_to_label_s
_labels.insert(label);
}

template <typename label_type>
void InMemFilterStore<label_type>::update_medoid_by_label(const label_type &label, const uint32_t new_medoid)
{
_label_to_medoid_id[label] = new_medoid;
}

template <typename label_type>
const uint32_t &InMemFilterStore<label_type>::get_medoid_by_label(const label_type &label)
{
return _label_to_medoid_id[label];
}

template <typename label_type>
const std::unordered_map<label_type, uint32_t> &InMemFilterStore<label_type>::get_labels_to_medoids()
{
return _label_to_medoid_id;
}

template <typename label_type> bool InMemFilterStore<label_type>::label_has_medoid(const label_type &label)
{
return _label_to_medoid_id.find(label) != _label_to_medoid_id.end();
}

template <typename label_type>
void InMemFilterStore<label_type>::add_label_to_location(const location_t point_id, const label_type label)
{
Expand Down Expand Up @@ -137,41 +114,6 @@ template <typename label_type> size_t InMemFilterStore<label_type>::load_labels(
return parse_label_file(labels_file);
}

template <typename label_type>
size_t InMemFilterStore<label_type>::load_medoids(const std::string &labels_to_medoid_file)
{
if (file_exists(labels_to_medoid_file))
{
std::ifstream medoid_stream(labels_to_medoid_file);
std::string line, token;
uint32_t line_cnt = 0;

_label_to_medoid_id.clear();
while (std::getline(medoid_stream, line))
{
std::istringstream iss(line);
uint32_t cnt = 0;
uint32_t medoid = 0;
label_type label;
while (std::getline(iss, token, ','))
{
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
label_type token_as_num = (label_type)std::stoul(token);
if (cnt == 0)
label = token_as_num;
else
medoid = token_as_num;
cnt++;
}
_label_to_medoid_id[label] = medoid;
line_cnt++;
}
return (size_t)line_cnt;
}
throw ANNException("ERROR: can not load medoids file does not exist", -1);
}

template <typename label_type> void InMemFilterStore<label_type>::load_label_map(const std::string &labels_map_file)
{
if (file_exists(labels_map_file))
Expand Down Expand Up @@ -290,23 +232,6 @@ template <typename label_type> void InMemFilterStore<label_type>::save_universal
}
}

template <typename label_type> void InMemFilterStore<label_type>::save_medoids(const std::string &save_path)
{
if (_label_to_medoid_id.size() > 0)
{
std::ofstream medoid_writer(save_path);
if (medoid_writer.fail())
{
throw diskann::ANNException(std::string("Failed to open medoid file ") + save_path, -1);
}
for (auto iter : _label_to_medoid_id)
{
medoid_writer << iter.first << ", " << iter.second << std::endl;
}
medoid_writer.close();
}
}

template <typename label_type> void InMemFilterStore<label_type>::save_label_map(const std::string &save_path)
{
if (_label_map.empty())
Expand Down
78 changes: 64 additions & 14 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ void Index<T, TagT, LabelT>::save(const char *filename, bool compact_before_save
{
if (_filtered_index)
{
_filter_store->save_medoids(std::string(filename) + "_labels_to_medoids.txt");
save_medoids(std::string(filename) + "_labels_to_medoids.txt");
_filter_store->save_label_map(std::string(filename) + "_labels_map.txt");
_filter_store->save_universal_label(std::string(filename) + "_universal_label.txt");
_filter_store->save_labels(std::string(filename) + "_labels.txt", _nd + _num_frozen_pts);
Expand Down Expand Up @@ -537,7 +537,7 @@ void Index<T, TagT, LabelT>::load(const char *filename, uint32_t num_threads, ui
_filter_store->load_label_map(labels_map_file);
label_num_pts = _filter_store->load_labels(labels_file);
assert(label_num_pts == data_file_num_pts);
_filter_store->load_medoids(labels_to_medoids);
load_medoids(labels_to_medoids);
_filter_store->load_universal_labels(std::string(filename) + "_universal_label.txt");
}

Expand Down Expand Up @@ -722,6 +722,59 @@ void Index<T, TagT, LabelT>::calculate_best_medoids(const size_t num_points_to_l
}
}

template <typename T, typename TagT, typename LabelT>
size_t Index<T, TagT, LabelT>::load_medoids(const std::string &labels_to_medoid_file)
{
if (file_exists(labels_to_medoid_file))
{
std::ifstream medoid_stream(labels_to_medoid_file);
std::string line, token;
uint32_t line_cnt = 0;

_label_to_medoid_id.clear();
while (std::getline(medoid_stream, line))
{
std::istringstream iss(line);
uint32_t cnt = 0;
uint32_t medoid = 0;
LabelT label;
while (std::getline(iss, token, ','))
{
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
LabelT token_as_num = (LabelT)std::stoul(token);
if (cnt == 0)
label = token_as_num;
else
medoid = token_as_num;
cnt++;
}
_label_to_medoid_id[label] = medoid;
line_cnt++;
}
return (size_t)line_cnt;
}
throw ANNException("ERROR: can not load medoids file does not exist", -1);
}

template <typename T, typename TagT, typename LabelT>
void Index<T, TagT, LabelT>::save_medoids(const std::string &medoid_file_name)
{
if (_label_to_medoid_id.size() > 0)
{
std::ofstream medoid_writer(medoid_file_name);
if (medoid_writer.fail())
{
throw diskann::ANNException(std::string("Failed to open medoid file ") + medoid_file_name, -1);
}
for (auto iter : _label_to_medoid_id)
{
medoid_writer << iter.first << ", " << iter.second << std::endl;
}
medoid_writer.close();
}
}

// Find common filter between a node's labels and a given set of labels, while
// taking into account universal label
template <typename T, typename TagT, typename LabelT>
Expand Down Expand Up @@ -939,7 +992,7 @@ void Index<T, TagT, LabelT>::search_for_point_and_prune(int location, uint32_t L
tl.lock();
std::vector<uint32_t> filter_specific_start_nodes;
for (auto &x : _filter_store->get_labels_by_location(location))
filter_specific_start_nodes.emplace_back(_filter_store->get_medoid_by_label(x));
filter_specific_start_nodes.emplace_back(_label_to_medoid_id[x]);

if (_dynamic_index)
tl.unlock();
Expand Down Expand Up @@ -1893,9 +1946,9 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search_with_filters(const
if (_dynamic_index)
tl.lock();

if (_filter_store->label_has_medoid(filter_label))
if (_label_to_medoid_id.find(filter_label) != _label_to_medoid_id.end())
{
init_ids.emplace_back(_filter_store->get_medoid_by_label(filter_label));
init_ids.emplace_back(_label_to_medoid_id[filter_label]);
}
else
{
Expand Down Expand Up @@ -2303,12 +2356,12 @@ template <typename T, typename TagT, typename LabelT> void Index<T, TagT, LabelT
if (_filtered_index && _dynamic_index)
{
// update medoid id's as frozen points are treated as medoid
for (auto &[label, medoid_id] : _filter_store->get_labels_to_medoids())
for (auto &[label, medoid_id] : _label_to_medoid_id)
{
/* if (label == _universal_label)
continue;*/
_filter_store->update_medoid_by_label(label, (uint32_t)_nd + (medoid_id - (uint32_t)_max_points));
//_label_to_start_id[label] = (uint32_t)_nd + (medoid_id - (uint32_t)_max_points);
uint32_t medoid = (uint32_t)_nd + (medoid_id - (uint32_t)_max_points);
_label_to_medoid_id[label] = medoid;
}
}
}
Expand Down Expand Up @@ -2608,12 +2661,9 @@ template <typename T, typename TagT, typename LabelT> void Index<T, TagT, LabelT
// update medoid id's as frozen points are treated as medoid
if (_filtered_index && _dynamic_index)
{
for (auto &[label, medoid_id] : _filter_store->get_labels_to_medoids())
for (auto &[label, medoid_id] : _label_to_medoid_id)
{
/*if (label == _universal_label)
continue;*/
_filter_store->update_medoid_by_label(label, (uint32_t)_max_points + (medoid_id - (uint32_t)_nd));
//_label_to_medoid_id[label] = (uint32_t)_max_points + (medoid_id - (uint32_t)_nd);
_label_to_medoid_id[label] = (uint32_t)_max_points + (medoid_id - (uint32_t)_nd);
}
}
}
Expand Down Expand Up @@ -2732,7 +2782,7 @@ int Index<T, TagT, LabelT>::insert_point(const T *point, const TagT tag, const s

auto fz_location = (int)(_max_points) + _frozen_pts_used; // as first _fz_point
_filter_store->add_to_label_set(label);
_filter_store->update_medoid_by_label(label, (uint32_t)fz_location);
_label_to_medoid_id[label] = (uint32_t)fz_location;
std::vector<LabelT> fz_label = {label};
_filter_store->set_labels_to_location((location_t)fz_location, {label_str});
//_label_to_start_id[label] = (uint32_t)fz_location;
Expand Down

0 comments on commit 7462bd0

Please sign in to comment.