Skip to content

Commit

Permalink
Add trace span in bruteforce search (#370)
Browse files Browse the repository at this point in the history
Signed-off-by: Yudong Cai <[email protected]>
  • Loading branch information
cydrain authored Feb 2, 2024
1 parent e9574eb commit 5699751
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 6 deletions.
91 changes: 88 additions & 3 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
#include "knowhere/sparse_utils.h"
#include "knowhere/utils.h"

#ifdef NOT_COMPILE_FOR_SWIG
#include "knowhere/tracer.h"
#endif

namespace knowhere {

/* knowhere wrapper API to call faiss brute force search for all metric types */
Expand All @@ -47,13 +51,25 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset

auto xq = query->GetTensor();
auto nq = query->GetRows();

BruteForceConfig cfg;
std::string msg;
auto status = Config::Load(cfg, config, knowhere::SEARCH, &msg);
if (status != Status::success) {
return expected<DataSetPtr>::Err(status, msg);
}

#ifdef NOT_COMPILE_FOR_SWIG
std::shared_ptr<tracer::trace::Span> span = nullptr;
if (cfg.trace_id.has_value()) {
auto ctx = tracer::TraceContext{(uint8_t*)cfg.trace_id.value().c_str(), (uint8_t*)cfg.span_id.value().c_str(),
(uint8_t)cfg.trace_flags.value()};
span = tracer::StartSpan("knowhere bf search", &ctx);
span->SetAttribute(meta::METRIC_TYPE, cfg.metric_type.value());
span->SetAttribute(meta::TOPK, cfg.k.value());
}
#endif

std::string metric_str = cfg.metric_type.value();
auto result = Str2FaissMetricType(metric_str);
if (result.error() != Status::success) {
Expand Down Expand Up @@ -133,7 +149,15 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
return GenResultDataSet(nq, cfg.k.value(), labels.release(), distances.release());
auto res = GenResultDataSet(nq, cfg.k.value(), labels.release(), distances.release());

#ifdef NOT_COMPILE_FOR_SWIG
if (cfg.trace_id.has_value()) {
span->End();
}
#endif

return res;
}

template <typename DataType>
Expand All @@ -156,6 +180,17 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
BruteForceConfig cfg;
RETURN_IF_ERROR(Config::Load(cfg, config, knowhere::SEARCH));

#ifdef NOT_COMPILE_FOR_SWIG
std::shared_ptr<tracer::trace::Span> span = nullptr;
if (cfg.trace_id.has_value()) {
auto ctx = tracer::TraceContext{(uint8_t*)cfg.trace_id.value().c_str(), (uint8_t*)cfg.span_id.value().c_str(),
(uint8_t)cfg.trace_flags.value()};
span = tracer::StartSpan("knowhere bf search with buf", &ctx);
span->SetAttribute(meta::METRIC_TYPE, cfg.metric_type.value());
span->SetAttribute(meta::TOPK, cfg.k.value());
}
#endif

std::string metric_str = cfg.metric_type.value();
auto result = Str2FaissMetricType(cfg.metric_type.value());
if (result.error() != Status::success) {
Expand Down Expand Up @@ -232,6 +267,13 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
}));
}
RETURN_IF_ERROR(WaitAllSuccess(futs));

#ifdef NOT_COMPILE_FOR_SWIG
if (cfg.trace_id.has_value()) {
span->End();
}
#endif

return Status::success;
}

Expand Down Expand Up @@ -261,6 +303,21 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
return expected<DataSetPtr>::Err(status, std::move(msg));
}

#ifdef NOT_COMPILE_FOR_SWIG
std::shared_ptr<tracer::trace::Span> span = nullptr;
if (cfg.trace_id.has_value()) {
auto ctx = tracer::TraceContext{(uint8_t*)cfg.trace_id.value().c_str(), (uint8_t*)cfg.span_id.value().c_str(),
(uint8_t)cfg.trace_flags.value()};
span = tracer::StartSpan("knowhere bf range search", &ctx);
span->SetAttribute(meta::METRIC_TYPE, cfg.metric_type.value());
span->SetAttribute(meta::METRIC_TYPE, cfg.metric_type.value());
span->SetAttribute(meta::RADIUS, cfg.radius.value());
if (cfg.range_filter.value() != defaultRangeFilter) {
span->SetAttribute(meta::RANGE_FILTER, cfg.range_filter.value());
}
}
#endif

std::string metric_str = cfg.metric_type.value();
auto result = Str2FaissMetricType(metric_str);
if (result.error() != Status::success) {
Expand Down Expand Up @@ -351,7 +408,15 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
float* distances = nullptr;
size_t* lims = nullptr;
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids, lims);
return GenResultDataSet(nq, ids, distances, lims);
auto res = GenResultDataSet(nq, ids, distances, lims);

#ifdef NOT_COMPILE_FOR_SWIG
if (cfg.trace_id.has_value()) {
span->End();
}
#endif

return res;
}

Status
Expand Down Expand Up @@ -430,12 +495,32 @@ BruteForce::SearchSparse(const DataSetPtr base_dataset, const DataSetPtr query_d
if (status != Status::success) {
return expected<DataSetPtr>::Err(status, msg);
}

#ifdef NOT_COMPILE_FOR_SWIG
std::shared_ptr<tracer::trace::Span> span = nullptr;
if (cfg.trace_id.has_value()) {
auto ctx = tracer::TraceContext{(uint8_t*)cfg.trace_id.value().c_str(), (uint8_t*)cfg.span_id.value().c_str(),
(uint8_t)cfg.trace_flags.value()};
span = tracer::StartSpan("knowhere bf search with buf", &ctx);
span->SetAttribute(meta::METRIC_TYPE, cfg.metric_type.value());
span->SetAttribute(meta::TOPK, cfg.k.value());
}
#endif

int topk = cfg.k.value();
auto labels = std::make_unique<sparse::label_t[]>(nq * topk);
auto distances = std::make_unique<float[]>(nq * topk);

SearchSparseWithBuf(base_dataset, query_dataset, labels.get(), distances.get(), config, bitset);
return GenResultDataSet(nq, topk, labels.release(), distances.release());
auto res = GenResultDataSet(nq, topk, labels.release(), distances.release());

#ifdef NOT_COMPILE_FOR_SWIG
if (cfg.trace_id.has_value()) {
span->End();
}
#endif

return res;
}

} // namespace knowhere
Expand Down
8 changes: 5 additions & 3 deletions src/common/tracer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,11 @@ AddEvent(const std::string& event_label) {

bool
isEmptyID(const uint8_t* id, int length) {
for (int i = 0; i < length; i++) {
if (id[i] != 0) {
return false;
if (id != nullptr) {
for (int i = 0; i < length; i++) {
if (id[i] != 0) {
return false;
}
}
}
return true;
Expand Down

0 comments on commit 5699751

Please sign in to comment.