diff --git a/include/parameters.h b/include/parameters.h index edde5df9c..2c16bf9e1 100644 --- a/include/parameters.h +++ b/include/parameters.h @@ -139,11 +139,12 @@ enum State : uint8_t template class IndexSearchContext { public: - IndexSearchContext(uint32_t time_limit_in_microseconds = 0u, uint32_t io_limit = UINT32_MAX) - : _time_limit_in_microseconds(time_limit_in_microseconds), _io_limit(io_limit), _result_state(State::Unknown) + IndexSearchContext(uint32_t time_limit_in_microseconds = 0u, uint32_t io_limit = UINT32_MAX, bool allowLessThanKResults = false) + : _time_limit_in_microseconds(time_limit_in_microseconds), _io_limit(io_limit), _result_state(State::Unknown), _allowLessThankResults(allowLessThanKResults) { _use_filter = false; _label = (LabelT)0; + _total_result_returned = 0; } void SetLabel(LabelT label, bool use_filter) @@ -157,6 +158,16 @@ template class IndexSearchContext _result_state = state; } + void UpdateResultReturned(size_t result_returned) + { + _total_result_returned = result_returned; + } + + size_t GetResultCount() + { + return _total_result_returned; + } + State GetState() const { return _result_state; @@ -198,6 +209,11 @@ template class IndexSearchContext return _stats; } + bool GetAllowLessThanKResults() + { + return _allowLessThankResults; + } + private: uint32_t _time_limit_in_microseconds; uint32_t _io_limit; @@ -206,6 +222,8 @@ template class IndexSearchContext LabelT _label; Timer _timer; QueryStats _stats; + bool _allowLessThankResults; + size_t _total_result_returned; }; } // namespace diskann diff --git a/src/index.cpp b/src/index.cpp index 4157edcef..aac892573 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2251,14 +2251,20 @@ std::pair Index::search(const T *query, con break; } - if (pos < K) + if (pos < K && context.GetAllowLessThanKResults()) + { + context.SetState(State::Success); + context.UpdateResultReturned(pos); + } + else if(pos < K) { context.SetState(State::Failure); - diskann::cerr << "Found pos: " << pos << "fewer than K elements " << K << " for query" << std::endl; + context.UpdateResultReturned(pos); } else { context.SetState(State::Success); + context.UpdateResultReturned(K); } return retval; @@ -2387,14 +2393,20 @@ std::pair Index::search_with_filters(const if (pos == K) break; } - if (pos < K) + if (pos < K && context.GetAllowLessThanKResults()) + { + context.SetState(State::Success); + context.UpdateResultReturned(pos); + } + else if(pos < K) { context.SetState(State::Failure); - diskann::cerr << "Found fewer than K elements for query" << std::endl; + context.UpdateResultReturned(pos); } else { context.SetState(State::Success); + context.UpdateResultReturned(K); } return retval; diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index 68c54ea65..1d9cce796 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -1687,6 +1687,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t } context.SetState(State::Success); + context.UpdateResultReturned(k_search); } // range search returns results of all neighbors within distance of range.