From 91cfcde5adc2a9a8b29b20f01f1ceb7eb294311b Mon Sep 17 00:00:00 2001 From: "xianliang.li" Date: Tue, 14 Jan 2025 18:28:23 +0800 Subject: [PATCH] add zero copy support for faiss Signed-off-by: xianliang.li --- include/knowhere/index/index.h | 2 +- include/knowhere/index/index_node.h | 3 +- .../index/index_node_data_mock_wrapper.h | 4 +- .../index/index_node_thread_pool_wrapper.h | 4 +- python/knowhere/knowhere.i | 2 +- src/index/diskann/diskann.cc | 4 +- src/index/flat/flat.cc | 13 +- src/index/gpu/flat_gpu/flat_gpu.cc | 2 +- src/index/gpu/ivf_gpu/ivf_gpu.cc | 2 +- src/index/gpu_raft/gpu_raft.h | 2 +- src/index/gpu_raft/gpu_raft_cagra.cc | 4 +- src/index/hnsw/faiss_hnsw.cc | 16 ++- src/index/hnsw/hnsw.h | 7 +- src/index/index.cc | 6 +- src/index/ivf/ivf.cc | 18 +-- src/index/sparse/sparse_index_node.cc | 4 +- src/io/memory_io.h | 69 +++++++++ tests/ut/test_diskann.cc | 8 +- tests/ut/test_faiss_hnsw.cc | 2 +- tests/ut/test_get_vector.cc | 4 +- tests/ut/test_index_node.cc | 8 +- tests/ut/test_iterator.cc | 2 +- tests/ut/test_search.cc | 8 +- tests/ut/test_sparse.cc | 4 +- thirdparty/faiss/faiss/IVFlib.cpp | 7 +- thirdparty/faiss/faiss/Index.h | 3 + thirdparty/faiss/faiss/impl/index_read.cpp | 131 +++++++++++++----- thirdparty/faiss/faiss/impl/io.h | 2 - .../faiss/faiss/impl/maybe_owned_vector.h | 83 ++++++----- thirdparty/faiss/faiss/impl/zerocopy_io.cpp | 61 ++++++++ thirdparty/faiss/faiss/impl/zerocopy_io.h | 23 +++ thirdparty/faiss/faiss/index_io.h | 2 + .../faiss/faiss/invlists/InvertedLists.cpp | 4 +- .../faiss/faiss/invlists/InvertedLists.h | 7 +- thirdparty/hnswlib/hnswlib/hnswalg.h | 28 ++-- 35 files changed, 398 insertions(+), 151 deletions(-) create mode 100644 thirdparty/faiss/faiss/impl/zerocopy_io.cpp create mode 100644 thirdparty/faiss/faiss/impl/zerocopy_io.h diff --git a/include/knowhere/index/index.h b/include/knowhere/index/index.h index 7bbb22ec4..01782d389 100644 --- a/include/knowhere/index/index.h +++ b/include/knowhere/index/index.h @@ -183,7 +183,7 @@ class Index { Serialize(BinarySet& binset) const; Status - Deserialize(const BinarySet& binset, const Json& json = {}); + Deserialize(BinarySet&& binset, const Json& json = {}); Status DeserializeFromFile(const std::string& filename, const Json& json = {}); diff --git a/include/knowhere/index/index_node.h b/include/knowhere/index/index_node.h index 829669970..b70f5385c 100644 --- a/include/knowhere/index/index_node.h +++ b/include/knowhere/index/index_node.h @@ -385,7 +385,7 @@ class IndexNode : public Object { * its release. */ virtual Status - Deserialize(const BinarySet& binset, std::shared_ptr config) = 0; + Deserialize(BinarySet&& binset, std::shared_ptr config) = 0; /** * @brief Deserializes the index from a file. @@ -437,6 +437,7 @@ class IndexNode : public Object { } protected: + BinarySet binarySet_; Version version_; }; diff --git a/include/knowhere/index/index_node_data_mock_wrapper.h b/include/knowhere/index/index_node_data_mock_wrapper.h index 3e19bde72..6283c8444 100644 --- a/include/knowhere/index/index_node_data_mock_wrapper.h +++ b/include/knowhere/index/index_node_data_mock_wrapper.h @@ -64,8 +64,8 @@ class IndexNodeDataMockWrapper : public IndexNode { } Status - Deserialize(const BinarySet& binset, std::shared_ptr cfg) override { - return index_node_->Deserialize(binset, std::move(cfg)); + Deserialize(BinarySet&& binset, std::shared_ptr cfg) override { + return index_node_->Deserialize(std::move(binset), std::move(cfg)); } Status diff --git a/include/knowhere/index/index_node_thread_pool_wrapper.h b/include/knowhere/index/index_node_thread_pool_wrapper.h index 3a1cb932f..91cab7925 100644 --- a/include/knowhere/index/index_node_thread_pool_wrapper.h +++ b/include/knowhere/index/index_node_thread_pool_wrapper.h @@ -60,8 +60,8 @@ class IndexNodeThreadPoolWrapper : public IndexNode { } Status - Deserialize(const BinarySet& binset, std::shared_ptr cfg) override { - return index_node_->Deserialize(binset, std::move(cfg)); + Deserialize(BinarySet&& binset, std::shared_ptr cfg) override { + return index_node_->Deserialize(std::move(binset), std::move(cfg)); } Status diff --git a/python/knowhere/knowhere.i b/python/knowhere/knowhere.i index 0fc2cb1fe..270cb7db3 100644 --- a/python/knowhere/knowhere.i +++ b/python/knowhere/knowhere.i @@ -253,7 +253,7 @@ class IndexWrap { knowhere::Status Deserialize(knowhere::BinarySetPtr binset, const std::string& json) { GILReleaser rel; - return idx.value().Deserialize(*binset, knowhere::Json::parse(json)); + return idx.value().Deserialize(std::move(*binset), knowhere::Json::parse(json)); } knowhere::Status diff --git a/src/index/diskann/diskann.cc b/src/index/diskann/diskann.cc index 2effb67b9..7385e3063 100644 --- a/src/index/diskann/diskann.cc +++ b/src/index/diskann/diskann.cc @@ -90,7 +90,7 @@ class DiskANNIndexNode : public IndexNode { } Status - Deserialize(const BinarySet& binset, std::shared_ptr cfg) override; + Deserialize(BinarySet&& binset, std::shared_ptr cfg) override; Status DeserializeFromFile(const std::string& filename, std::shared_ptr config) override { @@ -385,7 +385,7 @@ DiskANNIndexNode::Build(const DataSetPtr dataset, std::shared_ptr Status -DiskANNIndexNode::Deserialize(const BinarySet& binset, std::shared_ptr cfg) { +DiskANNIndexNode::Deserialize(BinarySet&& binset, std::shared_ptr cfg) { std::lock_guard lock(preparation_lock_); auto prep_conf = static_cast(*cfg); if (!CheckMetric(prep_conf.metric_type.value())) { diff --git a/src/index/flat/flat.cc b/src/index/flat/flat.cc index 50075efcf..0b62d0b52 100644 --- a/src/index/flat/flat.cc +++ b/src/index/flat/flat.cc @@ -13,6 +13,7 @@ #include "faiss/IndexBinaryFlat.h" #include "faiss/IndexFlat.h" #include "faiss/impl/AuxIndexStructures.h" +#include "faiss/impl/zerocopy_io.h" #include "faiss/index_io.h" #include "index/flat/flat_config.h" #include "io/memory_io.h" @@ -313,23 +314,25 @@ class FlatIndexNode : public IndexNode { } Status - Deserialize(const BinarySet& binset, std::shared_ptr) override { + Deserialize(BinarySet&& binset, std::shared_ptr) override { std::vector names = {"IVF", // compatible with knowhere-1.x "BinaryIVF", // compatible with knowhere-1.x Type()}; - auto binary = binset.GetByNames(names); + binarySet_ = std::move(binset); + auto binary = binarySet_.GetByNames(names); if (binary == nullptr) { LOG_KNOWHERE_ERROR_ << "Invalid binary set."; return Status::invalid_binary_set; } - MemoryIOReader reader(binary->data.get(), binary->size); + int io_flags = faiss::IO_FLAG_ZERO_COPY; + faiss::ZeroCopyIOReader reader(binary->data.get(), binary->size); if constexpr (std::is_same::value) { - faiss::Index* index = faiss::read_index(&reader); + faiss::Index* index = faiss::read_index(&reader, io_flags); index_.reset(static_cast(index)); } if constexpr (std::is_same::value) { - faiss::IndexBinary* index = faiss::read_index_binary(&reader); + faiss::IndexBinary* index = faiss::read_index_binary(&reader, io_flags); index_.reset(static_cast(index)); } return Status::success; diff --git a/src/index/gpu/flat_gpu/flat_gpu.cc b/src/index/gpu/flat_gpu/flat_gpu.cc index 547474fa1..88fc40ed2 100644 --- a/src/index/gpu/flat_gpu/flat_gpu.cc +++ b/src/index/gpu/flat_gpu/flat_gpu.cc @@ -131,7 +131,7 @@ class GpuFlatIndexNode : public IndexNode { } Status - Deserialize(const BinarySet& binset, const Config& config) override { + Deserialize(BinarySet&& binset, const Config& config) override { auto binary = binset.GetByName(Type()); if (binary == nullptr) { LOG_KNOWHERE_ERROR_ << "Invalid binary set."; diff --git a/src/index/gpu/ivf_gpu/ivf_gpu.cc b/src/index/gpu/ivf_gpu/ivf_gpu.cc index 8a062dbf8..0257c6dcf 100644 --- a/src/index/gpu/ivf_gpu/ivf_gpu.cc +++ b/src/index/gpu/ivf_gpu/ivf_gpu.cc @@ -202,7 +202,7 @@ class GpuIvfIndexNode : public IndexNode { } Status - Deserialize(const BinarySet& binset, const Config& config) override { + Deserialize(BinarySet&& binset, const Config& config) override { auto binary = binset.GetByName(Type()); if (binary == nullptr) { LOG_KNOWHERE_ERROR_ << "invalid binary set."; diff --git a/src/index/gpu_raft/gpu_raft.h b/src/index/gpu_raft/gpu_raft.h index 04fe6f596..7cb2c703f 100644 --- a/src/index/gpu_raft/gpu_raft.h +++ b/src/index/gpu_raft/gpu_raft.h @@ -180,7 +180,7 @@ struct GpuRaftIndexNode : public IndexNode { } Status - Deserialize(const BinarySet& binset, std::shared_ptr) override { + Deserialize(BinarySet&& binset, std::shared_ptr) override { auto result = Status::success; std::stringbuf buf; auto binary = binset.GetByName(this->Type()); diff --git a/src/index/gpu_raft/gpu_raft_cagra.cc b/src/index/gpu_raft/gpu_raft_cagra.cc index aa93c9c6f..0fed4ae2d 100644 --- a/src/index/gpu_raft/gpu_raft_cagra.cc +++ b/src/index/gpu_raft/gpu_raft_cagra.cc @@ -122,7 +122,7 @@ class GpuRaftCagraHybridIndexNode : public GpuRaftCagraIndexNode { } Status - Deserialize(const BinarySet& binset, std::shared_ptr cfg) override { + Deserialize(BinarySet&& binset, std::shared_ptr cfg) override { if (binset.Contains(std::string(this->Type()) + "_cpu")) { this->adapt_for_cpu = true; auto binary = binset.GetByName(std::string(this->Type() + "_cpu")); @@ -147,7 +147,7 @@ class GpuRaftCagraHybridIndexNode : public GpuRaftCagraIndexNode { return Status::success; } - return GpuRaftCagraIndexNode::Deserialize(binset, std::move(cfg)); + return GpuRaftCagraIndexNode::Deserialize(std::move(binset), std::move(cfg)); } Status diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index 1c2406a3b..aff56f123 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -31,6 +31,7 @@ #include "faiss/IndexRefine.h" #include "faiss/impl/ScalarQuantizer.h" #include "faiss/impl/mapped_io.h" +#include "faiss/impl/zerocopy_io.h" #include "faiss/index_io.h" #include "index/hnsw/faiss_hnsw_config.h" #include "index/hnsw/hnsw.h" @@ -52,7 +53,6 @@ #include "knowhere/log.h" #include "knowhere/range_util.h" #include "knowhere/utils.h" - #if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) #include "knowhere/prometheus_client.h" #endif @@ -201,14 +201,16 @@ class BaseFaissRegularIndexNode : public BaseFaissIndexNode { } Status - Deserialize(const BinarySet& binset, std::shared_ptr config) override { - auto binary = binset.GetByName(Type()); + Deserialize(BinarySet&& binset, std::shared_ptr config) override { + binarySet_ = std::move(binset); + auto binary = binarySet_.GetByName(Type()); if (binary == nullptr) { LOG_KNOWHERE_ERROR_ << "Invalid binary set."; return Status::invalid_binary_set; } - MemoryIOReader reader(binary->data.get(), binary->size); + int io_flags = faiss::IO_FLAG_ZERO_COPY; + faiss::ZeroCopyIOReader reader(binary->data.get(), binary->size); try { // this is a hack for compatibility, faiss index has 4-byte header to indicate index category // create a new one to distinguish MV faiss hnsw from faiss hnsw @@ -1890,11 +1892,11 @@ class HNSWIndexNodeWithFallback : public IndexNode { } Status - Deserialize(const BinarySet& binset, std::shared_ptr config) override { + Deserialize(BinarySet&& binset, std::shared_ptr config) override { if (use_base_index) { - return base_index->Deserialize(binset, config); + return base_index->Deserialize(std::move(binset), config); } else { - return fallback_search_index->Deserialize(binset, config); + return fallback_search_index->Deserialize(std::move(binset), config); } } diff --git a/src/index/hnsw/hnsw.h b/src/index/hnsw/hnsw.h index 63461ced3..6c45caffd 100644 --- a/src/index/hnsw/hnsw.h +++ b/src/index/hnsw/hnsw.h @@ -485,18 +485,19 @@ class HnswIndexNode : public IndexNode { } Status - Deserialize(const BinarySet& binset, std::shared_ptr) override { + Deserialize(BinarySet&& binset, std::shared_ptr) override { if (index_) { delete index_; } try { - auto binary = binset.GetByName(Type()); + binarySet_ = std::move(binset); + auto binary = binarySet_.GetByName(Type()); if (binary == nullptr) { LOG_KNOWHERE_ERROR_ << "Invalid binary set."; return Status::invalid_binary_set; } - MemoryIOReader reader(binary->data.get(), binary->size); + ZeroCopyIOReader reader(binary->data.get(), binary->size); hnswlib::SpaceInterface* space = nullptr; index_ = new (std::nothrow) hnswlib::HierarchicalNSW(space); diff --git a/src/index/index.cc b/src/index/index.cc index d6da77f63..11b4e3bb7 100644 --- a/src/index/index.cc +++ b/src/index/index.cc @@ -301,7 +301,7 @@ Index::Serialize(BinarySet& binset) const { template inline Status -Index::Deserialize(const BinarySet& binset, const Json& json) { +Index::Deserialize(BinarySet&& binset, const Json& json) { Json json_(json); auto cfg = this->node->CreateConfig(); { @@ -318,12 +318,12 @@ Index::Deserialize(const BinarySet& binset, const Json& json) { #if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) TimeRecorder rc("Load index", 2); - res = this->node->Deserialize(binset, std::move(cfg)); + res = this->node->Deserialize(std::move(binset), std::move(cfg)); auto time = rc.ElapseFromBegin("done"); time *= 0.001; // convert to ms knowhere_load_latency.Observe(time); #else - res = this->node->Deserialize(binset, std::move(cfg)); + res = this->node->Deserialize(std::move(binset), std::move(cfg)); #endif return res; } diff --git a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index 57307d8b1..bc45f4fa1 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -20,6 +20,7 @@ #include "faiss/IndexIVFScalarQuantizerCC.h" #include "faiss/IndexScaNN.h" #include "faiss/IndexScalarQuantizer.h" +#include "faiss/impl/zerocopy_io.h" #include "faiss/index_io.h" #include "index/ivf/ivf_config.h" #include "io/memory_io.h" @@ -162,7 +163,7 @@ class IvfIndexNode : public IndexNode { return this->SerializeImpl(binset, typename IndexDispatch::Tag{}); } Status - Deserialize(const BinarySet& binset, std::shared_ptr cfg) override; + Deserialize(BinarySet&& binset, std::shared_ptr cfg) override; Status DeserializeFromFile(const std::string& filename, std::shared_ptr cfg) override; @@ -1172,17 +1173,18 @@ IvfIndexNode::SerializeImpl(BinarySet& binset, IVFFlatTag) template Status -IvfIndexNode::Deserialize(const BinarySet& binset, std::shared_ptr cfg) { +IvfIndexNode::Deserialize(BinarySet&& binset, std::shared_ptr cfg) { std::vector names = {"IVF", // compatible with knowhere-1.x "BinaryIVF", // compatible with knowhere-1.x Type()}; - auto binary = binset.GetByNames(names); + binarySet_ = std::move(binset); + auto binary = binarySet_.GetByNames(names); if (binary == nullptr) { LOG_KNOWHERE_ERROR_ << "Invalid binary set."; return Status::invalid_binary_set; } - - MemoryIOReader reader(binary->data.get(), binary->size); + int io_flags = faiss::IO_FLAG_ZERO_COPY; + faiss::ZeroCopyIOReader reader(binary->data.get(), binary->size); try { if constexpr (std::is_same::value) { if (this->version_ <= Version::GetMinimalVersion()) { @@ -1193,11 +1195,11 @@ IvfIndexNode::Deserialize(const BinarySet& binset, std::sha reader.data_ = binary->data.get(); reader.total_ = binary->size; } - index_.reset(static_cast(faiss::read_index(&reader))); + index_.reset(static_cast(faiss::read_index(&reader, io_flags))); } else if constexpr (std::is_same::value) { - index_.reset(static_cast(faiss::read_index_binary(&reader))); + index_.reset(static_cast(faiss::read_index_binary(&reader, io_flags))); } else { - index_.reset(static_cast(faiss::read_index(&reader))); + index_.reset(static_cast(faiss::read_index(&reader, io_flags))); } if constexpr (!std::is_same_v && !std::is_same_v) { diff --git a/src/index/sparse/sparse_index_node.cc b/src/index/sparse/sparse_index_node.cc index a3f136751..5d6a5d733 100644 --- a/src/index/sparse/sparse_index_node.cc +++ b/src/index/sparse/sparse_index_node.cc @@ -255,7 +255,7 @@ class SparseInvertedIndexNode : public IndexNode { } Status - Deserialize(const BinarySet& binset, std::shared_ptr config) override { + Deserialize(BinarySet&& binset, std::shared_ptr config) override { if (index_ != nullptr) { LOG_KNOWHERE_WARNING_ << Type() << " has already been created, deleting old"; DeleteExistingIndex(); @@ -528,7 +528,7 @@ class SparseInvertedIndexNodeCC : public SparseInvertedIndexNode { } Status - Deserialize(const BinarySet& binset, std::shared_ptr config) override { + Deserialize(BinarySet&& binset, std::shared_ptr config) override { return Status::not_implemented; } diff --git a/src/io/memory_io.h b/src/io/memory_io.h index 461c6aaba..a6386e3c2 100644 --- a/src/io/memory_io.h +++ b/src/io/memory_io.h @@ -13,6 +13,8 @@ #include +#include + namespace knowhere { #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ @@ -162,4 +164,71 @@ struct MemoryIOReader : public faiss::IOReader { } }; +struct ZeroCopyIOReader : public faiss::IOReader { + uint8_t* data_; + size_t rp_ = 0; + size_t total_ = 0; + + ZeroCopyIOReader(uint8_t* data, size_t size) : data_(data), rp_(0), total_(size) { + } + + ~ZeroCopyIOReader() = default; + + size_t + getDataView(void** ptr, size_t size, size_t nitems) { + if (size == 0) { + return nitems; + } + + size_t actual_size = size * nitems; + if (rp_ + size * nitems > total_) { + actual_size = total_ - rp_; + } + + size_t actual_nitems = (actual_size + size - 1) / size; + if (actual_nitems == 0) { + return 0; + } + + // get an address + *ptr = (void*)(reinterpret_cast(data_ + rp_)); + + // alter pos + rp_ += size * actual_nitems; + + return actual_nitems; + } + + size_t + operator()(void* ptr, size_t size, size_t nitems) override { + if (rp_ >= total_) { + return 0; + } + size_t nremain = (total_ - rp_) / size; + if (nremain < nitems) { + nitems = nremain; + } + memcpy(ptr, (data_ + rp_), size * nitems); + rp_ += size * nitems; + return nitems; + } + + template + size_t + read(T* ptr, size_t size, size_t nitems = 1) { + auto res = operator()((void*)ptr, size, nitems); +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + for (size_t i = 0; i < nitems; ++i) { + *(ptr + i) = getSwappedBytes(*(ptr + i)); + } +#endif + return res; + } + + int + filedescriptor() override { + return -1; + } +}; + } // namespace knowhere diff --git a/tests/ut/test_diskann.cc b/tests/ut/test_diskann.cc index 666fda856..83cbcfe4b 100644 --- a/tests/ut/test_diskann.cc +++ b/tests/ut/test_diskann.cc @@ -167,7 +167,7 @@ TEST_CASE("Invalid diskann params test", "[diskann]") { knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack).value(); diskann.Build(ds_ptr, test_gen()); diskann.Serialize(binarySet); - diskann.Deserialize(binarySet, test_gen()); + diskann.Deserialize(std::move(binarySet), test_gen()); knowhere::Json test_json; auto query_ds = GenDataSet(kNumQueries, kDim, 42); @@ -304,7 +304,7 @@ base_search() { // knn search auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack).value(); - diskann.Deserialize(binset, deserialize_json); + diskann.Deserialize(std::move(binset), deserialize_json); REQUIRE(diskann.HasRawData(metric_str) == knowhere::IndexStaticFaced::HasRawData("DISKANN", version, json)); @@ -324,7 +324,7 @@ base_search() { } auto diskann_tmp = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack).value(); - diskann_tmp.Deserialize(binset, deserialize_json); + diskann_tmp.Deserialize(std::move(binset), deserialize_json); auto knn_search_json = knn_search_gen().dump(); knowhere::Json knn_json = knowhere::Json::parse(knn_search_json); auto res = diskann_tmp.Search(query_ds, knn_json, nullptr); @@ -430,7 +430,7 @@ TEST_CASE("Test DiskANN GetVectorByIds", "[diskann]") { auto index = knowhere::IndexFactory::Instance() .Create("DISKANN", version, diskann_index_pack) .value(); - auto ret = index.Deserialize(binset, deserialize_json); + auto ret = index.Deserialize(std::move(binset), deserialize_json); REQUIRE(ret == knowhere::Status::success); REQUIRE(diskann.HasRawData(knowhere::metric::L2) == diff --git a/tests/ut/test_faiss_hnsw.cc b/tests/ut/test_faiss_hnsw.cc index d207e4c25..3836a98d2 100644 --- a/tests/ut/test_faiss_hnsw.cc +++ b/tests/ut/test_faiss_hnsw.cc @@ -184,7 +184,7 @@ read_index(knowhere::Index& index, const std::string& filen binary_set.Append(name, data_ptr, data_size); } - index.Deserialize(binary_set, conf); + index.Deserialize(std::move(binary_set), conf); } template diff --git a/tests/ut/test_get_vector.cc b/tests/ut/test_get_vector.cc index 588ce2925..6046b1786 100644 --- a/tests/ut/test_get_vector.cc +++ b/tests/ut/test_get_vector.cc @@ -82,7 +82,7 @@ TEST_CASE("Test Binary Get Vector By Ids", "[Binary GetVectorByIds]") { idx.Serialize(bs); auto idx_new = knowhere::IndexFactory::Instance().Create(name, version).value(); - idx_new.Deserialize(bs); + idx_new.Deserialize(std::move(bs)); REQUIRE(idx.HasRawData(metric_type) == knowhere::IndexStaticFaced::HasRawData(name, version, json)); @@ -209,7 +209,7 @@ TEST_CASE("Test Float Get Vector By Ids", "[Float GetVectorByIds]") { idx.Serialize(bs); auto idx_new = knowhere::IndexFactory::Instance().Create(name, version).value(); - idx_new.Deserialize(bs); + idx_new.Deserialize(std::move(bs)); auto retrieve_task = [&]() { auto results = idx_new.GetVectorByIds(ids_ds); diff --git a/tests/ut/test_index_node.cc b/tests/ut/test_index_node.cc index 4084d338f..d7b880adb 100644 --- a/tests/ut/test_index_node.cc +++ b/tests/ut/test_index_node.cc @@ -90,7 +90,7 @@ class BaseFlatIndexNode : public IndexNode { } virtual Status - Deserialize(const BinarySet& binset, std::shared_ptr config) override { + Deserialize(BinarySet&& binset, std::shared_ptr config) override { LOG_KNOWHERE_INFO_ << "BaseFlatIndexNode::Deserialize()"; return Status::success; } @@ -152,7 +152,7 @@ TEST_CASE("Test index node") { REQUIRE(index.HasRawData(metric::L2) == true); REQUIRE(index.GetIndexMeta({}).error() == Status::not_implemented); REQUIRE(index.Serialize(binset) == Status::success); - REQUIRE(index.Deserialize(binset, {}) == Status::success); + REQUIRE(index.Deserialize(std::move(binset), {}) == Status::success); REQUIRE(index.DeserializeFromFile("", {}) == Status::success); REQUIRE(index.Dim() == 0); REQUIRE(index.Size() == 0); @@ -173,7 +173,7 @@ TEST_CASE("Test index node") { REQUIRE(index.HasRawData(metric::L2) == true); REQUIRE(index.GetIndexMeta({}).error() == Status::not_implemented); REQUIRE(index.Serialize(binset) == Status::success); - REQUIRE(index.Deserialize(binset, {}) == Status::success); + REQUIRE(index.Deserialize(std::move(binset), {}) == Status::success); REQUIRE(index.DeserializeFromFile("", {}) == Status::success); REQUIRE(index.Dim() == 0); REQUIRE(index.Size() == 0); @@ -195,7 +195,7 @@ TEST_CASE("Test index node") { REQUIRE(index.HasRawData(metric::L2) == true); REQUIRE(index.GetIndexMeta({}).error() == Status::not_implemented); REQUIRE(index.Serialize(binset) == Status::success); - REQUIRE(index.Deserialize(binset, {}) == Status::success); + REQUIRE(index.Deserialize(std::move(binset), {}) == Status::success); REQUIRE(index.DeserializeFromFile("", {}) == Status::success); REQUIRE(index.Dim() == 0); REQUIRE(index.Size() == 0); diff --git a/tests/ut/test_iterator.cc b/tests/ut/test_iterator.cc index 611996b1a..6e5d17353 100644 --- a/tests/ut/test_iterator.cc +++ b/tests/ut/test_iterator.cc @@ -237,7 +237,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { knowhere::BinarySet bs; REQUIRE(idx.Serialize(bs) == knowhere::Status::success); - REQUIRE(idx.Deserialize(bs) == knowhere::Status::success); + REQUIRE(idx.Deserialize(std::move(bs)) == knowhere::Status::success); auto its = idx.AnnIterator(query_ds, json, nullptr); REQUIRE(its.has_value()); diff --git a/tests/ut/test_search.cc b/tests/ut/test_search.cc index ddb938752..365db7a01 100644 --- a/tests/ut/test_search.cc +++ b/tests/ut/test_search.cc @@ -211,7 +211,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { json["enable_mmap"] = true; REQUIRE(idx.DeserializeFromFile(kMmapIndexPath, json) == knowhere::Status::success); } else { - REQUIRE(idx.Deserialize(bs, json) == knowhere::Status::success); + REQUIRE(idx.Deserialize(std::move(bs), json) == knowhere::Status::success); } // TODO: qianya (IVFSQ_CC deserialize casted from the IVFSQ directly, which will cause the hasRawData @@ -275,7 +275,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { knowhere::BinarySet bs; REQUIRE(idx.Serialize(bs) == knowhere::Status::success); - REQUIRE(idx.Deserialize(bs, json) == knowhere::Status::success); + REQUIRE(idx.Deserialize(std::move(bs), json) == knowhere::Status::success); auto results = idx.RangeSearch(query_ds, json, nullptr); REQUIRE(results.has_value()); @@ -490,7 +490,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { idx.Serialize(bs); auto idx_ = knowhere::IndexFactory::Instance().Create(name, version).value(); - idx_.Deserialize(bs); + idx_.Deserialize(std::move(bs)); auto results = idx_.Search(query_ds, json, nullptr); REQUIRE(results.has_value()); } @@ -641,7 +641,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") { idx.Serialize(bs); auto idx_ = knowhere::IndexFactory::Instance().Create(name, version).value(); - idx_.Deserialize(bs); + idx_.Deserialize(std::move(bs)); auto results = idx_.Search(query_ds, json, nullptr); REQUIRE(results.has_value()); } diff --git a/tests/ut/test_sparse.cc b/tests/ut/test_sparse.cc index 879dfcb49..1390aed32 100644 --- a/tests/ut/test_sparse.cc +++ b/tests/ut/test_sparse.cc @@ -144,7 +144,7 @@ TEST_CASE("Test Mem Sparse Index With Float Vector", "[float metrics]") { WriteBinaryToFile(tmp_file, bs.GetByName(idx.Type())); REQUIRE(idx.DeserializeFromFile(tmp_file, json) == knowhere::Status::success); } else { - REQUIRE(idx.Deserialize(bs, json) == knowhere::Status::success); + REQUIRE(idx.Deserialize(std::move(bs), json) == knowhere::Status::success); } auto results = idx.Search(query_ds, json, nullptr); @@ -381,7 +381,7 @@ TEST_CASE("Test Mem Sparse Index Handle Empty Vector", "[float metrics]") { WriteBinaryToFile(tmp_file, bs.GetByName(idx.Type())); REQUIRE(idx.DeserializeFromFile(tmp_file, json) == knowhere::Status::success); } else { - REQUIRE(idx.Deserialize(bs, json) == knowhere::Status::success); + REQUIRE(idx.Deserialize(std::move(bs), json) == knowhere::Status::success); } const knowhere::Json conf = { diff --git a/thirdparty/faiss/faiss/IVFlib.cpp b/thirdparty/faiss/faiss/IVFlib.cpp index 23a16174d..720bca3dd 100644 --- a/thirdparty/faiss/faiss/IVFlib.cpp +++ b/thirdparty/faiss/faiss/IVFlib.cpp @@ -21,6 +21,7 @@ #include #include #include +#include namespace faiss { namespace ivflib { @@ -187,9 +188,9 @@ SlidingIndexWindow::SlidingIndexWindow(Index* index) : index(index) { template static void shift_and_add( - std::vector& dst, + MaybeOwnedVector& dst, size_t remove, - const std::vector& src) { + const MaybeOwnedVector& src) { if (remove > 0) memmove(dst.data(), dst.data() + remove, @@ -200,7 +201,7 @@ static void shift_and_add( } template -static void remove_from_begin(std::vector& v, size_t remove) { +static void remove_from_begin(MaybeOwnedVector& v, size_t remove) { if (remove > 0) v.erase(v.begin(), v.begin() + remove); } diff --git a/thirdparty/faiss/faiss/Index.h b/thirdparty/faiss/faiss/Index.h index 6fac99c3e..d02b0d370 100644 --- a/thirdparty/faiss/faiss/Index.h +++ b/thirdparty/faiss/faiss/Index.h @@ -15,6 +15,7 @@ #include #include #include +#include "faiss/impl/mapped_io.h" #define FAISS_VERSION_MAJOR 1 #define FAISS_VERSION_MINOR 8 @@ -81,6 +82,8 @@ struct Index { MetricType metric_type; float metric_arg; ///< argument of the metric type + std::shared_ptr mmap_owner; + explicit Index(idx_t d = 0, MetricType metric = METRIC_L2) : d(d), ntotal(0), diff --git a/thirdparty/faiss/faiss/impl/index_read.cpp b/thirdparty/faiss/faiss/impl/index_read.cpp index 7504114e8..c3780abae 100644 --- a/thirdparty/faiss/faiss/impl/index_read.cpp +++ b/thirdparty/faiss/faiss/impl/index_read.cpp @@ -48,8 +48,8 @@ #include #include #include -#include #include +#include #include #include @@ -60,8 +60,9 @@ #include // mmap facility -#include #include +#include +#include namespace faiss { @@ -88,7 +89,26 @@ bool read_is_mv(const char* fname) { } -template +template +void read_vector_with_size(VectorT& target, IOReader* f, size_t size) { + ZeroCopyIOReader* zr = dynamic_cast(f); + if (zr != nullptr) { + if constexpr (is_maybe_owned_vector_v) { + // create a view + char* address = nullptr; + size_t nread = zr->getDataView( + (void**)&address, + sizeof(typename VectorT::value_type), + size); + VectorT view = VectorT::createView(address, nread); + target = std::move(view); + + return; + } + } +} + +template void read_vector(VectorT& target, IOReader* f) { // is it a mmap-enabled reader? MappedFileIOReader* mf = dynamic_cast(f); @@ -114,10 +134,28 @@ void read_vector(VectorT& target, IOReader* f) { size, strerror(errno)); - VectorT mmapped = VectorT::from_mmapped( - address, nread, mf->mmap_owner - ); - target = std::move(mmapped); + VectorT mmapped_view = VectorT::createView(address, nread); + target = std::move(mmapped_view); + + return; + } + } + + ZeroCopyIOReader* zr = dynamic_cast(f); + if (zr != nullptr) { + if constexpr (is_maybe_owned_vector_v) { + // read the size first + size_t size = target.size(); + READANDCHECK(&size, 1); + + // create a view + char* address = nullptr; + size_t nread = zr->getDataView( + (void**)&address, + sizeof(typename VectorT::value_type), + size); + VectorT view = VectorT::createView(address, nread); + target = std::move(view); return; } @@ -142,30 +180,50 @@ void read_xb_vector(VectorT& target, IOReader* f) { // ok, mmap and check char* address = nullptr; - const size_t nread = mf->mmap((void**)&address, sizeof(typename VectorT::value_type), size); + const size_t nread = mf->mmap( + (void**)&address, + sizeof(typename VectorT::value_type), + size); FAISS_THROW_IF_NOT_FMT( - nread == (size), - "read error in %s: %zd != %zd (%s)", - f->name.c_str(), - nread, - size, - strerror(errno)); + nread == (size), + "read error in %s: %zd != %zd (%s)", + f->name.c_str(), + nread, + size, + strerror(errno)); - VectorT mmapped = VectorT::from_mmapped( - address, nread, mf->mmap_owner - ); - target = std::move(mmapped); + VectorT mmapped_view = VectorT::createView(address, nread); + target = std::move(mmapped_view); return; } } + ZeroCopyIOReader* zr = dynamic_cast(f); + if (zr != nullptr) { + if constexpr (std::is_same_v>) { + // read the size first + size_t size = target.size(); + READANDCHECK(&size, 1); + + size *= 4; + + char* address = nullptr; + size_t nread = zr->getDataView( + (void**)&address, + sizeof(typename VectorT::value_type), + size); + VectorT view = VectorT::createView(address, nread); + target = std::move(view); + return; + } + } + // the default case READXBVECTOR(target); } - /************************************************************* * Read **************************************************************/ @@ -373,19 +431,12 @@ InvertedLists* read_InvertedLists(IOReader* f, int io_flags) { std::vector sizes(ails->nlist); read_ArrayInvertedLists_sizes(f, sizes); for (size_t i = 0; i < ails->nlist; i++) { - ails->ids[i].resize(sizes[i]); - ails->codes[i].resize(sizes[i] * ails->code_size); - if (ails->with_norm) { - ails->code_norms[i].resize(sizes[i]); - } - } - for (size_t i = 0; i < ails->nlist; i++) { - size_t n = ails->ids[i].size(); + size_t n = sizes[i]; if (n > 0) { - READANDCHECK(ails->codes[i].data(), n * ails->code_size); - READANDCHECK(ails->ids[i].data(), n); + read_vector_with_size(ails->codes[i], f, n * ails->code_size); + read_vector_with_size(ails->ids[i], f, n); if (ails->with_norm) { - READANDCHECK(ails->code_norms[i].data(), n); + read_vector_with_size(ails->code_norms[i], f, n); } } } @@ -737,7 +788,11 @@ static ArrayInvertedLists* set_array_invlist( std::vector>& ids) { ArrayInvertedLists* ail = new ArrayInvertedLists(ivf->nlist, ivf->code_size); - std::swap(ail->ids, ids); + + ail->ids.resize(ids.size()); + for (size_t i = 0; i < ids.size(); i++) { + ail->ids[i] = MaybeOwnedVector(std::move(ids[i])); + } ivf->invlists = ail; ivf->own_invlists = true; return ail; @@ -1281,7 +1336,7 @@ Index* read_index(IOReader* f, int io_flags) { idxrf = new IndexRefineFlat(); *idxrf = *idxrf_old; delete idxrf_old; - } + } idxrf->own_fields = true; idxrf->own_refine_index = true; idx = idxrf; @@ -1452,7 +1507,11 @@ Index* read_index(FILE* f, int io_flags) { // enable mmap-supporting IOReader auto owner = std::make_shared(f); MappedFileIOReader reader(owner); - return read_index(&reader, io_flags); + Index* index = read_index(&reader, io_flags); + if (index != nullptr) { + index->mmap_owner = owner; + } + return index; } else { FileIOReader reader(f); return read_index(&reader, io_flags); @@ -1464,7 +1523,11 @@ Index* read_index(const char* fname, int io_flags) { // enable mmap-supporting IOReader auto owner = std::make_shared(fname); MappedFileIOReader reader(owner); - return read_index(&reader, io_flags); + Index* index = read_index(&reader, io_flags); + if (index != nullptr) { + index->mmap_owner = owner; + } + return index; } else { FileIOReader reader(fname); Index* idx = read_index(&reader, io_flags); diff --git a/thirdparty/faiss/faiss/impl/io.h b/thirdparty/faiss/faiss/impl/io.h index 59c2e3153..7f713d9f1 100644 --- a/thirdparty/faiss/faiss/impl/io.h +++ b/thirdparty/faiss/faiss/impl/io.h @@ -20,8 +20,6 @@ #include #include -#include - namespace faiss { struct IOReader { diff --git a/thirdparty/faiss/faiss/impl/maybe_owned_vector.h b/thirdparty/faiss/faiss/impl/maybe_owned_vector.h index 0437a6195..de0858d3f 100644 --- a/thirdparty/faiss/faiss/impl/maybe_owned_vector.h +++ b/thirdparty/faiss/faiss/impl/maybe_owned_vector.h @@ -14,11 +14,12 @@ struct MappingOwner { }; // a container that either works as std::vector that owns its own memory, -// or as a mapped pointer owned by someone third-party owner +// or as a view of a memory buffer, with a known size template struct MaybeOwnedVector { using value_type = T; using self_type = MaybeOwnedVector; + using vec_iterator = typename std::vector::const_iterator; bool is_owned = true; @@ -26,14 +27,13 @@ struct MaybeOwnedVector { std::vector owned_data; // these three are used if is_owned == false - T* mapped_data = nullptr; + T* view_data = nullptr; // the number of T elements - size_t mapped_size = 0; - std::shared_ptr mapping_owner; + size_t view_size = 0; - // points either to mapped_data, or to owned.data() + // points either to view_data, or to owned.data() T* c_ptr = nullptr; - // uses either mapped_size, or owned.size(); + // uses either view_size, or owned.size(); size_t c_size = 0; MaybeOwnedVector() = default; @@ -49,16 +49,15 @@ struct MaybeOwnedVector { is_owned = other.is_owned; owned_data = other.owned_data; - mapped_data = other.mapped_data; - mapped_size = other.mapped_size; - mapping_owner = other.mapping_owner; + view_data = other.view_data; + view_size = other.view_size; if (is_owned) { c_ptr = owned_data.data(); c_size = owned_data.size(); } else { - c_ptr = mapped_data; - c_size = mapped_size; + c_ptr = view_data; + c_size = view_size; } } @@ -66,16 +65,15 @@ struct MaybeOwnedVector { is_owned = other.is_owned; owned_data = std::move(other.owned_data); - mapped_data = other.mapped_data; - mapped_size = other.mapped_size; - mapping_owner = std::move(other.mapping_owner); + view_data = other.view_data; + view_size = other.view_size; if (is_owned) { c_ptr = owned_data.data(); c_size = owned_data.size(); } else { - c_ptr = mapped_data; - c_size = mapped_size; + c_ptr = view_data; + c_size = view_size; } } @@ -113,19 +111,16 @@ struct MaybeOwnedVector { c_size = owned_data.size(); } - static MaybeOwnedVector from_mmapped( + static MaybeOwnedVector createView( void* address, - const size_t n_mapped_elements, - const std::shared_ptr& owner - ) { + const size_t n_mapped_elements) { MaybeOwnedVector vec; vec.is_owned = false; - vec.mapped_data = reinterpret_cast(address); - vec.mapped_size = n_mapped_elements; - vec.mapping_owner = owner; + vec.view_data = reinterpret_cast(address); + vec.view_size = n_mapped_elements; - vec.c_ptr = vec.mapped_data; - vec.c_size = vec.mapped_size; + vec.c_ptr = vec.view_data; + vec.c_size = vec.view_size; return vec; } @@ -150,8 +145,26 @@ struct MaybeOwnedVector { return c_ptr[idx]; } + vec_iterator begin() const { + FAISS_ASSERT_MSG(is_owned, "This operation cannot be performed on a viewed vector"); + + return owned_data.begin(); + } + + vec_iterator end() const { + FAISS_ASSERT_MSG(is_owned, "This operation cannot be performed on a viewed vector"); + + return owned_data.end(); + } + + vec_iterator erase(vec_iterator begin, vec_iterator end) { + FAISS_ASSERT_MSG(is_owned, "This operation cannot be performed on a viewed vector"); + + return owned_data.erase(begin, end); + } + void clear() { - FAISS_ASSERT_MSG(is_owned, "This operation cannot be performed on a memory-mapped vector"); + FAISS_ASSERT_MSG(is_owned, "This operation cannot be performed on a viewed vector"); owned_data.clear(); c_ptr = owned_data.data(); @@ -159,7 +172,7 @@ struct MaybeOwnedVector { } void resize(const size_t new_size) { - FAISS_ASSERT_MSG(is_owned, "This operation cannot be performed on a memory-mapped vector"); + FAISS_ASSERT_MSG(is_owned, "This operation cannot be performed on a viewed vector"); owned_data.resize(new_size); c_ptr = owned_data.data(); @@ -167,7 +180,7 @@ struct MaybeOwnedVector { } void resize(const size_t new_size, const value_type v) { - FAISS_ASSERT_MSG(is_owned, "This operation cannot be performed on a memory-mapped vector"); + FAISS_ASSERT_MSG(is_owned, "This operation cannot be performed on a viewed vector"); owned_data.resize(new_size, v); c_ptr = owned_data.data(); @@ -177,12 +190,20 @@ struct MaybeOwnedVector { friend void swap(self_type& a, self_type& b) { std::swap(a.is_owned, b.is_owned); std::swap(a.owned_data, b.owned_data); - std::swap(a.mapped_data, b.mapped_data); - std::swap(a.mapped_size, b.mapped_size); - std::swap(a.mapping_owner, b.mapping_owner); + std::swap(a.view_data, b.view_data); + std::swap(a.view_size, b.view_size); std::swap(a.c_ptr, b.c_ptr); std::swap(a.c_size, b.c_size); } }; +template +struct is_maybe_owned_vector : std::false_type {}; + +template +struct is_maybe_owned_vector> : std::true_type {}; + +template +inline constexpr bool is_maybe_owned_vector_v = is_maybe_owned_vector::value; + } \ No newline at end of file diff --git a/thirdparty/faiss/faiss/impl/zerocopy_io.cpp b/thirdparty/faiss/faiss/impl/zerocopy_io.cpp new file mode 100644 index 000000000..1c0138f85 --- /dev/null +++ b/thirdparty/faiss/faiss/impl/zerocopy_io.cpp @@ -0,0 +1,61 @@ +#include +#include + +namespace faiss { + +ZeroCopyIOReader::ZeroCopyIOReader(uint8_t* data, size_t size) : data_(data), rp_(0), total_(size) { +} + +ZeroCopyIOReader::~ZeroCopyIOReader() { +} + +size_t +ZeroCopyIOReader::getDataView(void** ptr, size_t size, size_t nitems) { + if (size == 0) { + return nitems; + } + + size_t actual_size = size * nitems; + if (rp_ + size * nitems > total_) { + actual_size = total_ - rp_; + } + + size_t actual_nitems = (actual_size + size - 1) / size; + if (actual_nitems == 0) { + return 0; + } + + // get an address + *ptr = (void*)(reinterpret_cast(data_ + rp_)); + + // alter pos + rp_ += size * actual_nitems; + + return actual_nitems; +} + +void +ZeroCopyIOReader::reset() { + rp_ = 0; +} + +size_t +ZeroCopyIOReader::operator()(void* ptr, size_t size, size_t nitems) { + if (rp_ >= total_) { + return 0; + } + size_t nremain = (total_ - rp_) / size; + if (nremain < nitems) { + nitems = nremain; + } + memcpy(ptr, (data_ + rp_), size * nitems); + rp_ += size * nitems; + return nitems; +} + +int +ZeroCopyIOReader::filedescriptor() { + return -1; // Indicating no file descriptor available for memory buffer +} + +} // namespace faiss \ No newline at end of file diff --git a/thirdparty/faiss/faiss/impl/zerocopy_io.h b/thirdparty/faiss/faiss/impl/zerocopy_io.h new file mode 100644 index 000000000..d18a254b6 --- /dev/null +++ b/thirdparty/faiss/faiss/impl/zerocopy_io.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include + +namespace faiss { + +struct ZeroCopyIOReader : public faiss::IOReader { + uint8_t* data_; + size_t rp_ = 0; + size_t total_ = 0; + + ZeroCopyIOReader(uint8_t* data, size_t size); + ~ZeroCopyIOReader(); + + void reset(); + size_t getDataView(void** ptr, size_t size, size_t nitems); + size_t operator()(void* ptr, size_t size, size_t nitems) override; + + int filedescriptor() override; +}; + +} // namespace faiss \ No newline at end of file diff --git a/thirdparty/faiss/faiss/index_io.h b/thirdparty/faiss/faiss/index_io.h index 2f68a0bbb..06b3ee4d3 100644 --- a/thirdparty/faiss/faiss/index_io.h +++ b/thirdparty/faiss/faiss/index_io.h @@ -71,6 +71,8 @@ const int IO_FLAG_MMAP = IO_FLAG_SKIP_IVF_DATA | 0x646f0000; // after OnDiskInvertedLists get properly updated. const int IO_FLAG_MMAP_IFC = 1 << 9; +const int IO_FLAG_ZERO_COPY = 1 << 10; + Index* read_index(const char* fname, int io_flags = 0); Index* read_index(FILE* f, int io_flags = 0); Index* read_index(IOReader* reader, int io_flags = 0); diff --git a/thirdparty/faiss/faiss/invlists/InvertedLists.cpp b/thirdparty/faiss/faiss/invlists/InvertedLists.cpp index acf08c55b..a4574f3d3 100644 --- a/thirdparty/faiss/faiss/invlists/InvertedLists.cpp +++ b/thirdparty/faiss/faiss/invlists/InvertedLists.cpp @@ -488,8 +488,8 @@ InvertedLists* ArrayInvertedLists::to_readonly() { void ArrayInvertedLists::permute_invlists(const idx_t* map) { // todo aguzhva: permute norms as well? - std::vector> new_codes(nlist); - std::vector> new_ids(nlist); + std::vector> new_codes(nlist); + std::vector> new_ids(nlist); for (size_t i = 0; i < nlist; i++) { size_t o = map[i]; diff --git a/thirdparty/faiss/faiss/invlists/InvertedLists.h b/thirdparty/faiss/faiss/invlists/InvertedLists.h index bd4220017..5cb7d6257 100644 --- a/thirdparty/faiss/faiss/invlists/InvertedLists.h +++ b/thirdparty/faiss/faiss/invlists/InvertedLists.h @@ -23,6 +23,7 @@ #include #include +#include namespace faiss { @@ -338,11 +339,11 @@ struct InvertedLists { /// simple (default) implementation as an array of inverted lists struct ArrayInvertedLists : InvertedLists { - std::vector> codes; // binary codes, size nlist - std::vector> ids; ///< Inverted lists for indexes + std::vector> codes; // binary codes, size nlist + std::vector> ids; ///< Inverted lists for indexes bool with_norm = false; - std::vector> code_norms; // code norms + std::vector> code_norms; // code norms ArrayInvertedLists(size_t nlist, size_t code_size, bool with_norm = false); diff --git a/thirdparty/hnswlib/hnswlib/hnswalg.h b/thirdparty/hnswlib/hnswlib/hnswalg.h index f1b0aafa9..596187d1a 100644 --- a/thirdparty/hnswlib/hnswlib/hnswalg.h +++ b/thirdparty/hnswlib/hnswlib/hnswalg.h @@ -178,14 +178,16 @@ class HierarchicalNSW : public AlgorithmInterface { if (mmap_enabled_) { munmap(map_, map_size_); } else { - free(data_level0_memory_); - if (metric_type_ == Metric::COSINE) { + if (own_memory_data) { + free(data_level0_memory_); + } + if (metric_type_ == Metric::COSINE && own_memory_data) { free(data_norm_l2_); } } for (tableint i = 0; i < cur_element_count; i++) { - if (element_levels_[i] > 0) + if (element_levels_[i] > 0 && own_memory_data) free(linkLists_[i]); } free(linkLists_); @@ -226,6 +228,7 @@ class HierarchicalNSW : public AlgorithmInterface { size_t size_links_level0_; size_t offsetData_, offsetSQData_, offsetLevel0_; + bool own_memory_data = true; char* data_level0_memory_; float* data_norm_l2_; // vector's l2 norm char** linkLists_; @@ -937,7 +940,7 @@ class HierarchicalNSW : public AlgorithmInterface { } void - loadIndex(knowhere::MemoryIOReader& input, size_t max_elements_i = 0) { + loadIndex(knowhere::ZeroCopyIOReader& input, size_t max_elements_i = 0) { using knowhere::readBinaryPOD; // linxj: init with metrictype size_t dim; @@ -998,17 +1001,13 @@ class HierarchicalNSW : public AlgorithmInterface { readBinaryPOD(input, mult_); readBinaryPOD(input, ef_construction_); - data_level0_memory_ = (char*)malloc(max_elements * size_data_per_element_); // NOLINT - if (data_level0_memory_ == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); - input.read(data_level0_memory_, cur_element_count * size_data_per_element_); + own_memory_data = false; + + input.getDataView(reinterpret_cast(&data_level0_memory_), size_data_per_element_, cur_element_count); // for COSINE, need load data_norm_l2_ if (metric_type_ == Metric::COSINE) { - data_norm_l2_ = (float*)malloc(max_elements * sizeof(float)); // NOLINT - if (data_norm_l2_ == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); - input.read(data_norm_l2_, cur_element_count * sizeof(float)); + input.getDataView(reinterpret_cast(&data_norm_l2_), sizeof(float), cur_element_count); } size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); @@ -1031,10 +1030,7 @@ class HierarchicalNSW : public AlgorithmInterface { linkLists_[i] = nullptr; } else { element_levels_[i] = linkListSize / size_links_per_element_; - linkLists_[i] = (char*)malloc(linkListSize); - if (linkLists_[i] == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); - input.read(linkLists_[i], linkListSize); + input.getDataView(reinterpret_cast(&linkLists_[i]), sizeof(char), linkListSize); } } }