Skip to content

Commit

Permalink
first commit to generalize to AND of ORs
Browse files Browse the repository at this point in the history
  • Loading branch information
rakri committed Feb 28, 2025
1 parent ca2fb3b commit a2e6a0a
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 55 deletions.
8 changes: 7 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@
"cfenv": "cpp",
"typeindex": "cpp",
"typeinfo": "cpp",
"variant": "cpp"
"variant": "cpp",
"compare": "cpp",
"concepts": "cpp",
"future": "cpp",
"numbers": "cpp",
"semaphore": "cpp",
"stop_token": "cpp"
}
}
16 changes: 9 additions & 7 deletions apps/search_memory_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
const std::string &query_file, const std::string &truthset_file, const uint32_t num_threads,
const uint32_t recall_at, const bool print_all_recalls, const std::vector<uint32_t> &Lvec,
const bool dynamic, const bool tags, const bool show_qps_per_thread,
const std::vector<std::vector<std::string>> &query_filters,
const std::vector<std::vector<std::vector<std::string>>> &query_filters,
const uint32_t filter_penalty_threshold, const uint32_t bruteforce_threshold,
const uint32_t clustering_threshold, uint32_t L_for_print, const float fail_if_recall_below,
uint32_t maxN = 10000000, float p1 = 0.1, float p2 = 0.1)
Expand Down Expand Up @@ -218,7 +218,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
old_g = num_graphs;
old_c = num_clusters;
method_used = 0;
std::vector<std::string> raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i];
std::vector<std::vector<std::string>> raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i];

auto retval = index->search_with_filters(query + i * query_aligned_dim, raw_filter, recall_at, L,
query_result_ids[test_id].data() + i * recall_at,
Expand Down Expand Up @@ -248,7 +248,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
else
{
std::vector<std::string> raw_filter =
query_filters.size() == 1 ? query_filters[0] : query_filters[i];
query_filters.size() == 1 ? query_filters[0][0] : query_filters[i][0];

index->search_with_tags(query + i * query_aligned_dim, recall_at, L,
query_result_tags.data() + i * recall_at, nullptr, res, true,
Expand Down Expand Up @@ -598,16 +598,18 @@ int main(int argc, char **argv)
return -1;
}

std::vector<std::vector<std::string>> query_filters;
std::vector<std::vector<std::vector<std::string>>> query_filters;
if (filter_label != "")
{
std::vector<std::string> single_filter;
single_filter.push_back(filter_label);
std::vector<std::vector<std::string>> single_filter;
std::vector<std::string> tmp;
tmp.push_back(filter_label);
single_filter.push_back(tmp);
query_filters.push_back(single_filter);
}
else if (query_filters_file != "")
{
query_filters = read_file_to_vector_of_strings(query_filters_file);
query_filters = read_file_to_vector_of_vector_of_strings(query_filters_file);
}

use_global_start = global_start;
Expand Down
4 changes: 2 additions & 2 deletions apps/utils/compute_filtered_groundtruth.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,8 +416,8 @@ inline void parse_query_label_file(const std::string &query_label_file,
std::istringstream inner_iss(token);
while (getline(inner_iss, token, '|'))
{
if (print_flag)
std::cout<<token<<" || ";
// if (print_flag)
// std::cout<<token<<" || ";
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
or_clause.push_back(token);
Expand Down
4 changes: 2 additions & 2 deletions include/abstract_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class AbstractIndex
const size_t K, const uint32_t L, IndexType *indices,
float *distances);
template <typename IndexType>
std::pair<uint32_t, uint32_t> search_with_filters(const DataType &query, const std::vector<std::string> &raw_label,
std::pair<uint32_t, uint32_t> search_with_filters(const DataType &query, const std::vector<std::vector<std::string>> &raw_label,
const size_t K, const uint32_t L, IndexType *indices,
float *distances);

Expand Down Expand Up @@ -115,7 +115,7 @@ class AbstractIndex
virtual std::pair<uint32_t, uint32_t> _search(const DataType &query, const size_t K, const uint32_t L,
std::any &indices, float *distances = nullptr) = 0;
virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query,
const std::vector<std::string> &filter_label,
const std::vector<std::vector<std::string>> &filter_label,
const size_t K, const uint32_t L, std::any &indices,
float *distances) = 0;
virtual int _insert_point(const DataType &data_point, const TagType tag, Labelvector &labels) = 0;
Expand Down
4 changes: 2 additions & 2 deletions include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas

template <typename IndexType>
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search_with_filters(const T *query,
const std::vector<LabelT> &filter_label,
const std::vector<std::vector<LabelT>> &filter_label,
const size_t K, const uint32_t L,
IndexType *indices, float *distances);

Expand Down Expand Up @@ -232,7 +232,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
virtual std::pair<uint32_t, uint32_t> _search(const DataType &query, const size_t K, const uint32_t L,
std::any &indices, float *distances = nullptr) override;
virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query,
const std::vector<std::string> &filter_label_raw,
const std::vector<std::vector<std::string>> &filter_label_raw,
const size_t K, const uint32_t L, std::any &indices,
float *distances) override;

Expand Down
81 changes: 78 additions & 3 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1083,10 +1083,10 @@ template <typename T = float> inline void normalize(T *arr, const size_t dim)
}
}

inline std::vector<std::vector<std::string>> read_file_to_vector_of_strings(const std::string &filename,
inline std::vector<std::vector<std::vector<std::string>>> read_file_to_vector_of_vector_of_strings(const std::string &filename,
bool unique = false)
{
std::vector<std::vector<std::string>> query_filters;
std::vector<std::vector<std::vector<std::string>>> query_filters;
std::ifstream file(filename);
std::string line, token;

Expand All @@ -1097,20 +1097,95 @@ inline std::vector<std::vector<std::string>> read_file_to_vector_of_strings(cons

while (std::getline(file, line))
{

std::istringstream iss(line);
std::vector<std::vector<std::string>> lbls(0);

getline(iss, token, '\t');
std::istringstream new_iss(token);
while (getline(new_iss, token, '&'))
{
std::vector<std::string> or_clause(0);
std::istringstream inner_iss(token);
while (getline(inner_iss, token, '|'))
{
// if (print_flag)
// std::cout<<token<<" || ";
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
or_clause.push_back(token);
// labels.insert(token);
}
lbls.push_back(or_clause);
}
// std::sort(lbls.begin(), lbls.end());
query_filters.push_back(lbls);


/* std::istringstream iss(line);
std::vector<std::string> lbls(0);
while (getline(iss, token, '&'))
{
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
lbls.push_back(token);
}
query_filters.push_back(lbls);
query_filters.push_back(lbls);*/
}
std::cout << "Populated labels for " << query_filters.size() << " queries" << std::endl;
return query_filters;
}


// Reads one label filter per line of `filename`.
//
// For each line, only the text before the first tab character is used. That
// text is split on '&', every '\n' / '\r' character is removed from each piece,
// and the pieces are collected in order into that line's vector. Every line
// read from the file, including an empty one, contributes exactly one entry to
// the result (an empty line contributes an empty vector).
//
// `unique` is accepted but is not consulted by this implementation.
//
// Prints the number of populated entries to std::cout before returning.
// Throws diskann::ANNException when the file cannot be opened.
inline std::vector<std::vector<std::string>> read_file_to_vector_of_strings(const std::string &filename,
                                                                            bool unique = false)
{
    std::ifstream input(filename);
    if (input.fail())
    {
        throw diskann::ANNException(std::string("Failed to open file ") + filename, -1);
    }

    std::vector<std::vector<std::string>> query_filters;
    for (std::string row; std::getline(input, row);)
    {
        // Anything after the first tab on the row is ignored.
        std::istringstream and_stream(row.substr(0, row.find('\t')));

        std::vector<std::string> labels;
        for (std::string piece; std::getline(and_stream, piece, '&');)
        {
            // Strip stray line terminators (e.g. the '\r' left by CRLF files).
            piece.erase(std::remove_if(piece.begin(), piece.end(),
                                       [](char c) { return c == '\n' || c == '\r'; }),
                        piece.end());
            labels.push_back(piece);
        }
        query_filters.push_back(labels);
    }

    std::cout << "Populated labels for " << query_filters.size() << " queries" << std::endl;
    return query_filters;
}


inline void clean_up_artifacts(tsl::robin_set<std::string> paths_to_clean, tsl::robin_set<std::string> path_suffixes)
{
try
Expand Down
10 changes: 6 additions & 4 deletions src/abstract_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters(const DataType
auto any_indices = std::any(indices);
std::vector<std::string> tmp_lbls;
tmp_lbls.push_back(raw_label);
return _search_with_filters(query, tmp_lbls, K, L, any_indices, distances);
std::vector<std::vector<std::string>> tmp_lbls_vec;
tmp_lbls_vec.push_back(tmp_lbls);
return _search_with_filters(query, tmp_lbls_vec, K, L, any_indices, distances);
}

template <typename IndexType>
std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters(const DataType &query,
const std::vector<std::string> &raw_label,
const std::vector<std::vector<std::string>> &raw_label,
const size_t K, const uint32_t L, IndexType *indices,
float *distances)
{
Expand Down Expand Up @@ -175,11 +177,11 @@ template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search_w
const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint64_t *indices,
float *distances);
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters<uint32_t>(
const DataType &query, const std::vector<std::string> &raw_label, const size_t K, const uint32_t L,
const DataType &query, const std::vector<std::vector<std::string>> &raw_label, const size_t K, const uint32_t L,
uint32_t *indices, float *distances);

template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters<uint64_t>(
const DataType &query, const std::vector<std::string> &raw_label, const size_t K, const uint32_t L,
const DataType &query, const std::vector<std::vector<std::string>> &raw_label, const size_t K, const uint32_t L,
uint64_t *indices, float *distances);

template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags<float, int32_t>(
Expand Down
Loading

0 comments on commit a2e6a0a

Please sign in to comment.