Skip to content

Commit

Permalink
Add edges for those unreachable nodes to improve graph connectivity (#…
Browse files Browse the repository at this point in the history
…340)

Signed-off-by: chasingegg <[email protected]>
  • Loading branch information
chasingegg authored Jan 16, 2024
1 parent 474eaf9 commit 5230b6a
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 24 deletions.
76 changes: 52 additions & 24 deletions src/index/hnsw/hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class HnswIndexNode : public IndexNode {
return Status::empty_index;
}

knowhere::TimeRecorder build_time("Building HNSW cost");
knowhere::TimeRecorder build_time("Building HNSW cost", 2);
auto rows = dataset.GetRows();
if (rows <= 0) {
LOG_KNOWHERE_ERROR_ << "Can not add empty data to HNSW index.";
Expand Down Expand Up @@ -109,32 +109,60 @@ class HnswIndexNode : public IndexNode {
std::mt19937 urng(rng());
std::shuffle(shuffle_batch_ids.begin(), shuffle_batch_ids.end(), urng);
}
index_->addPoint(tensor, 0);

futures.reserve(batch_size);
for (int64_t round_id = 0; round_id < round_num; round_id++) {
int64_t start_id = (shuffle_build ? shuffle_batch_ids[round_id] : round_id) * batch_size;
int64_t end_id =
std::min(rows - 1, ((shuffle_build ? shuffle_batch_ids[round_id] : round_id) + 1) * batch_size);
for (int64_t i = start_id; i < end_id; ++i) {
futures.emplace_back(build_pool->push([&, idx = i + 1]() {
index_->addPoint(((const char*)tensor + index_->data_size_ * idx), idx);
uint64_t added = counter.fetch_add(1);
if (added % one_tenth_row == 0) {
LOG_KNOWHERE_INFO_ << "HNSW build progress: " << (added / one_tenth_row) << "0%";
}
}));
try {
index_->addPoint(tensor, 0);

futures.reserve(batch_size);
for (int64_t round_id = 0; round_id < round_num; round_id++) {
int64_t start_id = (shuffle_build ? shuffle_batch_ids[round_id] : round_id) * batch_size;
int64_t end_id =
std::min(rows - 1, ((shuffle_build ? shuffle_batch_ids[round_id] : round_id) + 1) * batch_size);
for (int64_t i = start_id; i < end_id; ++i) {
futures.emplace_back(build_pool->push([&, idx = i + 1]() {
index_->addPoint(((const char*)tensor + index_->data_size_ * idx), idx);
uint64_t added = counter.fetch_add(1);
if (added % one_tenth_row == 0) {
LOG_KNOWHERE_INFO_ << "HNSW build progress: " << (added / one_tenth_row) << "0%";
}
}));
}
for (auto& future : futures) {
future.wait();
}
// check for exceptions
for (auto& future : futures) {
future.result().value();
}
futures.clear();
}
for (auto& future : futures) {
future.wait();

build_time.RecordSection("graph build");
std::vector<unsigned> unreached = index_->findUnreachableVectors();
int unreached_num = unreached.size();
LOG_KNOWHERE_INFO_ << "there are " << unreached_num << " points can not be reached";
if (unreached_num > 0) {
futures.reserve(unreached_num);
for (int i = 0; i < unreached_num; ++i) {
futures.emplace_back(
build_pool->push([&, idx = i]() { index_->repairGraphConnectivity(unreached[idx]); }));
}
for (auto& future : futures) {
future.wait();
}
// check for exceptions
for (auto& future : futures) {
future.result().value();
}
}
futures.clear();
build_time.RecordSection("graph repair");
LOG_KNOWHERE_INFO_ << "HNSW built with #points num:" << index_->max_elements_ << " #M:" << index_->M_
<< " #max level:" << index_->maxlevel_
<< " #ef_construction:" << index_->ef_construction_
<< " #dim:" << *(size_t*)(index_->space_->get_dist_func_param());
} catch (std::exception& e) {
LOG_KNOWHERE_WARNING_ << "hnsw inner error: " << e.what();
return Status::hnsw_inner_error;
}

build_time.RecordSection("");
LOG_KNOWHERE_INFO_ << "HNSW built with #points num:" << index_->max_elements_ << " #M:" << index_->M_
<< " #max level:" << index_->maxlevel_ << " #ef_construction:" << index_->ef_construction_
<< " #dim:" << *(size_t*)(index_->space_->get_dist_func_param());
return Status::success;
}

Expand Down
126 changes: 126 additions & 0 deletions thirdparty/hnswlib/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,132 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
return getNeighboursWithinRadius(retset, query_data, radius, bitset);
}

// get those unreachable vectors at the base layer after index building
// only be called after index building
std::vector<tableint>
findUnreachableVectors() {
tableint currObj = enterpoint_node_;
std::vector<tableint> start_points;
start_points.push_back(currObj);
std::vector<bool> visited;
std::vector<tableint> unreached;
for (int level = maxlevel_; level >= 0; level--) {
visited = std::vector<bool>(cur_element_count, false);
std::vector<tableint> touched;
for (auto start_point : start_points) {
if (visited[start_point])
continue;
std::queue<tableint> q;
q.push(start_point);
visited[start_point] = true;
if (level > 0)
touched.push_back(start_point);
while (!q.empty()) {
tableint j = q.front();
q.pop();
unsigned int* data;
data = (unsigned int*)get_linklist_at_level(j, level);
size_t size = getListCount((linklistsizeint*)data);
tableint* datal = (tableint*)(data + 1);
for (size_t k = 0; k < size; k++) {
tableint cand = datal[k];
if (!visited[cand]) {
visited[cand] = true;
q.push(cand);
if (level > 0)
touched.push_back(cand);
}
}
}
}
start_points = touched;

for (tableint i = 0; i < cur_element_count; ++i) {
if (element_levels_[i] >= level) {
if (!visited[i]) {
if (level > 0) { // for upper level, directly add edges since nodes num is usually small and fast to search its neighbors
repairGraphConnectivity(i, level);
} else { // for base level, collect the unreachable nodes and repair them concurrently
unreached.push_back(i);
}
}
}
}
}
return unreached;
}

// Add incoming edges to `cur_c` on `level` to improve graph connectivity.
// Only call this method after index building.
//
// Steps:
//   1. Starting from the entry point, greedily descend through the layers above
//      `level`, moving to whichever neighbor is closer to `cur_c`.
//   2. Run searchBaseLayer() on `level` from the node found in step 1.
//   3. Walk the returned candidates and append `cur_c` to the link list of every
//      candidate that still has a free slot, stopping after m_max edges were added.
//
// cur_c: internal id of the node that should receive incoming edges.
// level: layer to repair; 0 (the default) is the base layer.
// Throws std::runtime_error("cand error") if a neighbor id read during the descent
// is not a valid internal id.
void
repairGraphConnectivity(tableint cur_c, int level = 0) {
    size_t m_max = level ? maxM_ : maxM0_;
    tableint currObj = enterpoint_node_;

    dist_t curdist = calcDistance(cur_c, currObj);

    for (int level_above = maxlevel_; level_above > level; level_above--) {
        bool changed = true;
        while (changed) {
            changed = false;
            unsigned int* data;
            // do not take a lock here, since the upper layers will not be modified
            data = (unsigned int*)get_linklist(currObj, level_above);
            int size = getListCount(data);
            tableint* datal = (tableint*)(data + 1);
#if defined(USE_PREFETCH)
            for (int i = 0; i < size; ++i) {
                _mm_prefetch(getDataByInternalId(datal[i]), _MM_HINT_T0);
            }
#endif
            for (int i = 0; i < size; i++) {
                tableint cand = datal[i];
                // tableint is unsigned, so only the upper bound can be violated;
                // max_elements_ itself is already one past the last valid id
                if (cand >= max_elements_)
                    throw std::runtime_error("cand error");
                dist_t d = calcDistance(cur_c, cand);

                if (d < curdist) {
                    curdist = d;
                    currObj = cand;
                    changed = true;
                }
            }
        }
    }
    std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
        candidates = searchBaseLayer(currObj, cur_c, level);

    // get sorted ids: the queue is drained from the back of the vector to the front
    // (presumably yielding nearest-first order, assuming the queue pops the farthest
    // candidate first -- TODO confirm against CompareByFirst)
    std::vector<tableint> top_candidate_ids(candidates.size());
    for (size_t i = candidates.size(); i-- > 0;) {
        top_candidate_ids[i] = candidates.top().second;
        candidates.pop();
    }

    size_t add_count = 0;
    for (auto cand_id : top_candidate_ids) {
        // skip same element
        if (cand_id == cur_c) {
            continue;
        }

        // try to connect candidate to the element
        // add an edge if there is space
        std::unique_lock<std::mutex> lock(link_list_locks_[cand_id]);
        linklistsizeint* ll_cand = get_linklist_at_level(cand_id, level);
        size_t size = getListCount(ll_cand);
        tableint* data_cand = (tableint*)(ll_cand + 1);

        // a candidate may already point at cur_c (e.g. a formerly unreachable node
        // that was repaired earlier); do not spend a second slot on the same edge
        bool already_linked = false;
        for (size_t k = 0; k < size; ++k) {
            if (data_cand[k] == cur_c) {
                already_linked = true;
                break;
            }
        }

        if (!already_linked && size < m_max) {
            data_cand[size] = cur_c;
            setListCount(ll_cand, size + 1);
            add_count++;
        }
        // do not add too many edges: stop once m_max nodes connect to the element
        if (add_count >= m_max) {
            break;
        }
    }
}

void
checkIntegrity() {
int connections_checked = 0;
Expand Down

0 comments on commit 5230b6a

Please sign in to comment.