Skip to content

Commit

Permalink
support basic searcher (#351)
Browse files Browse the repository at this point in the history
Signed-off-by: zhongxiaoyao.zxy <[email protected]>
  • Loading branch information
ShawnShawnYou authored Jan 24, 2025
1 parent 06ff47c commit 6ccab68
Show file tree
Hide file tree
Showing 7 changed files with 456 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/algorithm/hnswlib/hnswalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1490,4 +1490,10 @@ HierarchicalNSW::searchRange(const void* query_data,
// std::cout << "hnswalg::result.size(): " << result.size() << std::endl;
return result;
}

template MaxHeap
HierarchicalNSW::searchBaseLayerST<false, false>(InnerIdType ep_id,
const void* data_point,
size_t ef,
vsag::BaseFilterFunctor* isIdAllowed) const;
} // namespace hnswlib
11 changes: 11 additions & 0 deletions src/algorithm/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,17 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
bool
isValidLabel(LabelType label) override;

size_t
getMaxDegree() {
return maxM0_;
};

linklistsizeint*
get_linklist0(InnerIdType internal_id) const {
// only for test now
return (linklistsizeint*)(data_level0_memory_->GetElementPtr(internal_id, offsetLevel0_));
}

inline LabelType
getExternalLabel(InnerIdType internal_id) const {
std::shared_lock lock(points_locks_[internal_id]);
Expand Down
11 changes: 11 additions & 0 deletions src/data_cell/flatten_datacell.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,18 @@ FlattenDataCell<QuantTmpl, IOTmpl>::query(float* result_dists,
const std::shared_ptr<Computer<QuantTmpl>>& computer,
const InnerIdType* idx,
InnerIdType id_count) {
for (uint32_t i = 0; i < this->prefetch_jump_code_size_ and i < id_count; i++) {
this->io_->Prefetch(static_cast<uint64_t>(idx[i]) * static_cast<uint64_t>(code_size_),
this->prefetch_cache_line_size_);
}

for (int64_t i = 0; i < id_count; ++i) {
if (i + this->prefetch_jump_code_size_ < id_count) {
this->io_->Prefetch(static_cast<uint64_t>(idx[i + this->prefetch_jump_code_size_]) *
static_cast<uint64_t>(code_size_),
this->prefetch_cache_line_size_);
}

bool release = false;
const auto* codes = this->GetCodesById(idx[i], release);
computer->ComputeDist(codes, result_dists + i);
Expand Down
2 changes: 2 additions & 0 deletions src/data_cell/flatten_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ class FlattenInterface {
InnerIdType total_count_{0};
InnerIdType max_capacity_{1000000};
uint32_t code_size_{0};
uint32_t prefetch_jump_code_size_{1};
uint32_t prefetch_cache_line_size_{1};
};

} // namespace vsag
129 changes: 129 additions & 0 deletions src/impl/basic_searcher.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "basic_searcher.h"

namespace vsag {

BasicSearcher::BasicSearcher(const IndexCommonParam& common_param) {
this->allocator_ = common_param.allocator_.get();
}

uint32_t
BasicSearcher::visit(const GraphInterfacePtr& graph_data_cell,
const std::shared_ptr<VisitedList>& vl,
const std::pair<float, uint64_t>& current_node_pair,
Vector<InnerIdType>& to_be_visited_rid,
Vector<InnerIdType>& to_be_visited_id) const {
uint32_t count_no_visited = 0;
Vector<InnerIdType> neighbors(allocator_);

graph_data_cell->GetNeighbors(current_node_pair.second, neighbors);

for (uint32_t i = 0; i < prefetch_jump_visit_size_; i++) {
vl->Prefetch(neighbors[i]);
}

for (uint32_t i = 0; i < neighbors.size(); i++) {
if (i + prefetch_jump_visit_size_ < neighbors.size()) {
vl->Prefetch(neighbors[i + prefetch_jump_visit_size_]);
}
if (not vl->Get(neighbors[i])) {
to_be_visited_rid[count_no_visited] = i;
to_be_visited_id[count_no_visited] = neighbors[i];
count_no_visited++;
vl->Set(neighbors[i]);
}
}
return count_no_visited;
}

MaxHeap
BasicSearcher::Search(const GraphInterfacePtr& graph_data_cell,
const FlattenInterfacePtr& vector_data_cell,
const std::shared_ptr<VisitedList>& vl,
const float* query,
const InnerSearchParam& inner_search_param) const {
MaxHeap top_candidates(allocator_);
MaxHeap candidate_set(allocator_);

if (not graph_data_cell or not vector_data_cell) {
return top_candidates;
}

auto computer = vector_data_cell->FactoryComputer(query);

float lower_bound;
float dist;
uint64_t candidate_id;
uint32_t hops = 0;
uint32_t dist_cmp = 0;
uint32_t count_no_visited = 0;
Vector<InnerIdType> to_be_visited_rid(graph_data_cell->MaximumDegree(), allocator_);
Vector<InnerIdType> to_be_visited_id(graph_data_cell->MaximumDegree(), allocator_);
Vector<float> line_dists(graph_data_cell->MaximumDegree(), allocator_);

InnerIdType ep_id = inner_search_param.ep_;
vector_data_cell->Query(&dist, computer, &ep_id, 1);
top_candidates.emplace(dist, ep_id);
candidate_set.emplace(-dist, ep_id);
vl->Set(ep_id);

while (!candidate_set.empty()) {
hops++;
std::pair<float, uint64_t> current_node_pair = candidate_set.top();

if ((-current_node_pair.first) > lower_bound &&
(top_candidates.size() == inner_search_param.ef_)) {
break;
}
candidate_set.pop();
if (not candidate_set.empty()) {
graph_data_cell->Prefetch(candidate_set.top().second, 0);
}

count_no_visited =
visit(graph_data_cell, vl, current_node_pair, to_be_visited_rid, to_be_visited_id);

dist_cmp += count_no_visited;

vector_data_cell->Query(
line_dists.data(), computer, to_be_visited_id.data(), count_no_visited);

for (uint32_t i = 0; i < count_no_visited; i++) {
dist = line_dists[i];
candidate_id = to_be_visited_id[i];
if (top_candidates.size() < inner_search_param.ef_ || lower_bound > dist) {
candidate_set.emplace(-dist, candidate_id);

top_candidates.emplace(dist, candidate_id);

if (top_candidates.size() > inner_search_param.ef_)
top_candidates.pop();

if (!top_candidates.empty())
lower_bound = top_candidates.top().first;
}
}
}

while (top_candidates.size() > inner_search_param.topk_) {
top_candidates.pop();
}

return top_candidates;
}

} // namespace vsag
64 changes: 64 additions & 0 deletions src/impl/basic_searcher.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "../utils.h"
#include "algorithm/hnswlib/algorithm_interface.h"
#include "common.h"
#include "data_cell/flatten_interface.h"
#include "data_cell/graph_interface.h"
#include "index/index_common_param.h"
#include "utils/visited_list.h"

namespace vsag {

class InnerSearchParam {
public:
int topk_{0};
float radius_{0.0f};
InnerIdType ep_{0};
uint64_t ef_{10};
BaseFilterFunctor* is_id_allowed_{nullptr};
};

class BasicSearcher {
public:
BasicSearcher(const IndexCommonParam& common_param);

virtual MaxHeap
Search(const GraphInterfacePtr& graph_data_cell,
const FlattenInterfacePtr& vector_data_cell,
const std::shared_ptr<VisitedList>& vl,
const float* query,
const InnerSearchParam& inner_search_param) const;

private:
// rid means the neighbor's rank (e.g., the first neighbor's rid == 0)
// id means the neighbor's id (e.g., the first neighbor's id == 12345)
uint32_t
visit(const GraphInterfacePtr& graph_data_cell,
const std::shared_ptr<VisitedList>& vl,
const std::pair<float, uint64_t>& current_node_pair,
Vector<InnerIdType>& to_be_visited_rid,
Vector<InnerIdType>& to_be_visited_id) const;

private:
Allocator* allocator_{nullptr};

uint32_t prefetch_jump_visit_size_{1};
};

} // namespace vsag
Loading

0 comments on commit 6ccab68

Please sign in to comment.