diff --git a/CHANGELOG.md b/CHANGELOG.md index bd01d0740..93d524be1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Add functionality to iteratively insert vectors into a faiss index to improve the memory footprint during indexing. [#1840](https://github.com/opensearch-project/k-NN/pull/1840) ### Bug Fixes * Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874) +* Fixed and abstracted functionality for allocating index memory [#1933](https://github.com/opensearch-project/k-NN/pull/1933) ### Infrastructure ### Documentation ### Maintenance diff --git a/jni/include/faiss_index_service.h b/jni/include/faiss_index_service.h index f147a6e7e..c57309cfc 100644 --- a/jni/include/faiss_index_service.h +++ b/jni/include/faiss_index_service.h @@ -65,6 +65,7 @@ class IndexService { virtual void writeIndex(std::string indexPath, jlong idMapAddress); virtual ~IndexService() = default; protected: + virtual void allocIndex(faiss::Index * index, size_t dim, size_t numVectors); std::unique_ptr<FaissMethods> faissMethods; }; @@ -120,10 +121,12 @@ class BinaryIndexService : public IndexService { */ virtual void writeIndex(std::string indexPath, jlong idMapAddress) override; virtual ~BinaryIndexService() = default; +protected: + virtual void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) override; }; } } -#endif //OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H +#endif //OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H \ No newline at end of file diff --git a/jni/src/faiss_index_service.cpp b/jni/src/faiss_index_service.cpp index cfb30cdb0..69866da76 100644 --- a/jni/src/faiss_index_service.cpp +++ b/jni/src/faiss_index_service.cpp @@ -57,6 +57,21 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, IndexService::IndexService(std::unique_ptr<FaissMethods> faissMethods) : faissMethods(std::move(faissMethods)) {} +void 
IndexService::allocIndex(faiss::Index * index, size_t dim, size_t numVectors) { + if(auto * indexHNSWSQ = dynamic_cast<faiss::IndexHNSWSQ *>(index)) { + if(auto * indexScalarQuantizer = dynamic_cast<faiss::IndexScalarQuantizer *>(indexHNSWSQ->storage)) { + indexScalarQuantizer->codes.reserve(indexScalarQuantizer->code_size * numVectors); + } + return; + } + if(auto * indexHNSW = dynamic_cast<faiss::IndexHNSW *>(index)) { + if(auto * indexFlat = dynamic_cast<faiss::IndexFlat *>(indexHNSW->storage)) { + indexFlat->codes.reserve(indexFlat->code_size * numVectors); + } + return; + } +} + jlong IndexService::initIndex( knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, @@ -83,36 +98,9 @@ jlong IndexService::initIndex( throw std::runtime_error("Index is not trained"); } - // Add vectors std::unique_ptr<faiss::IndexIDMap> idMap (faissMethods->indexIdMap(indexWriter.get())); - /* - * NOTE: The process of memory allocation is currently only implemented for HNSW. - * This technique of checking the types of the index and subindices should be generalized into - * another function. - */ - - // Check to see if the current index is HNSW - faiss::IndexHNSWFlat * hnsw = dynamic_cast<faiss::IndexHNSWFlat *>(idMap->index); - if(hnsw != NULL) { - // Check to see if the HNSW storage is IndexFlat - faiss::IndexFlat * storage = dynamic_cast<faiss::IndexFlat *>(hnsw->storage); - if(storage != NULL) { - // Allocate enough memory for all of the vectors we plan on inserting - // We do this to avoid unnecessary memory allocations during insert - storage->codes.reserve(dim * numVectors * 4); - } - } - faiss::IndexHNSWSQ * hnswSq = dynamic_cast<faiss::IndexHNSWSQ *>(idMap->index); - if(hnswSq != NULL) { - // Check to see if the HNSW storage is IndexFlat - faiss::IndexFlat * storage = dynamic_cast<faiss::IndexFlat *>(hnswSq->storage); - if(storage != NULL) { - // Allocate enough memory for all of the vectors we plan on inserting - // We do this to avoid unnecessary memory allocations during insert - storage->codes.reserve(dim * numVectors * 2); - } - } + allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors); indexWriter.release(); return reinterpret_cast<jlong>(idMap.release()); } @@ -168,6 
+156,14 @@ void IndexService::writeIndex( BinaryIndexService::BinaryIndexService(std::unique_ptr<FaissMethods> faissMethods) : IndexService(std::move(faissMethods)) {} +void BinaryIndexService::allocIndex(faiss::Index * index, size_t dim, size_t numVectors) { + if(auto * indexBinaryHNSW = dynamic_cast<faiss::IndexBinaryHNSW *>(index)) { + auto * indexBinaryFlat = dynamic_cast<faiss::IndexBinaryFlat *>(indexBinaryHNSW->storage); + indexBinaryFlat->xb.reserve(dim * numVectors / 8); + return; + } +} + jlong BinaryIndexService::initIndex( knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, @@ -194,27 +190,9 @@ jlong BinaryIndexService::initIndex( throw std::runtime_error("Index is not trained"); } - // Add vectors std::unique_ptr<faiss::IndexBinaryIDMap> idMap(faissMethods->indexBinaryIdMap(indexWriter.get())); - /* - * NOTE: The process of memory allocation is currently only implemented for HNSW. - * This technique of checking the types of the index and subindices should be generalized into - * another function. - */ - - // Check to see if the current index is BinaryHNSW - faiss::IndexBinaryHNSW * hnsw = dynamic_cast<faiss::IndexBinaryHNSW *>(idMap->index); - - if(hnsw != NULL) { - // Check to see if the HNSW storage is IndexBinaryFlat - faiss::IndexBinaryFlat * storage = dynamic_cast<faiss::IndexBinaryFlat *>(hnsw->storage); - if(storage != NULL) { - // Allocate enough memory for all of the vectors we plan on inserting - // We do this to avoid unnecessary memory allocations during insert - storage->xb.reserve(dim / 8 * numVectors); - } - } + allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors); indexWriter.release(); return reinterpret_cast<jlong>(idMap.release()); } @@ -271,4 +249,4 @@ void BinaryIndexService::writeIndex( } } // namespace faiss_wrapper -} // namesapce knn_jni +} // namespace knn_jni \ No newline at end of file