Skip to content

Commit

Permalink
add zero copy support for faiss
Browse files Browse the repository at this point in the history
Signed-off-by: xianliang.li <[email protected]>
  • Loading branch information
foxspy committed Feb 6, 2025
1 parent 45d757c commit 91cfcde
Show file tree
Hide file tree
Showing 35 changed files with 398 additions and 151 deletions.
2 changes: 1 addition & 1 deletion include/knowhere/index/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class Index {
Serialize(BinarySet& binset) const;

Status
Deserialize(const BinarySet& binset, const Json& json = {});
Deserialize(BinarySet&& binset, const Json& json = {});

Status
DeserializeFromFile(const std::string& filename, const Json& json = {});
Expand Down
3 changes: 2 additions & 1 deletion include/knowhere/index/index_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ class IndexNode : public Object {
* its release.
*/
virtual Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> config) = 0;
Deserialize(BinarySet&& binset, std::shared_ptr<Config> config) = 0;

/**
* @brief Deserializes the index from a file.
Expand Down Expand Up @@ -437,6 +437,7 @@ class IndexNode : public Object {
}

protected:
BinarySet binarySet_;
Version version_;
};

Expand Down
4 changes: 2 additions & 2 deletions include/knowhere/index/index_node_data_mock_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ class IndexNodeDataMockWrapper : public IndexNode {
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> cfg) override {
return index_node_->Deserialize(binset, std::move(cfg));
Deserialize(BinarySet&& binset, std::shared_ptr<Config> cfg) override {
return index_node_->Deserialize(std::move(binset), std::move(cfg));
}

Status
Expand Down
4 changes: 2 additions & 2 deletions include/knowhere/index/index_node_thread_pool_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ class IndexNodeThreadPoolWrapper : public IndexNode {
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> cfg) override {
return index_node_->Deserialize(binset, std::move(cfg));
Deserialize(BinarySet&& binset, std::shared_ptr<Config> cfg) override {
return index_node_->Deserialize(std::move(binset), std::move(cfg));
}

Status
Expand Down
2 changes: 1 addition & 1 deletion python/knowhere/knowhere.i
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ class IndexWrap {
knowhere::Status
Deserialize(knowhere::BinarySetPtr binset, const std::string& json) {
GILReleaser rel;
return idx.value().Deserialize(*binset, knowhere::Json::parse(json));
return idx.value().Deserialize(std::move(*binset), knowhere::Json::parse(json));
}

knowhere::Status
Expand Down
4 changes: 2 additions & 2 deletions src/index/diskann/diskann.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class DiskANNIndexNode : public IndexNode {
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> cfg) override;
Deserialize(BinarySet&& binset, std::shared_ptr<Config> cfg) override;

Status
DeserializeFromFile(const std::string& filename, std::shared_ptr<Config> config) override {
Expand Down Expand Up @@ -385,7 +385,7 @@ DiskANNIndexNode<DataType>::Build(const DataSetPtr dataset, std::shared_ptr<Conf

template <typename DataType>
Status
DiskANNIndexNode<DataType>::Deserialize(const BinarySet& binset, std::shared_ptr<Config> cfg) {
DiskANNIndexNode<DataType>::Deserialize(BinarySet&& binset, std::shared_ptr<Config> cfg) {
std::lock_guard<std::mutex> lock(preparation_lock_);
auto prep_conf = static_cast<const DiskANNConfig&>(*cfg);
if (!CheckMetric(prep_conf.metric_type.value())) {
Expand Down
13 changes: 8 additions & 5 deletions src/index/flat/flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "faiss/IndexBinaryFlat.h"
#include "faiss/IndexFlat.h"
#include "faiss/impl/AuxIndexStructures.h"
#include "faiss/impl/zerocopy_io.h"
#include "faiss/index_io.h"
#include "index/flat/flat_config.h"
#include "io/memory_io.h"
Expand Down Expand Up @@ -313,23 +314,25 @@ class FlatIndexNode : public IndexNode {
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config>) override {
Deserialize(BinarySet&& binset, std::shared_ptr<Config>) override {
std::vector<std::string> names = {"IVF", // compatible with knowhere-1.x
"BinaryIVF", // compatible with knowhere-1.x
Type()};
auto binary = binset.GetByNames(names);
binarySet_ = std::move(binset);
auto binary = binarySet_.GetByNames(names);
if (binary == nullptr) {
LOG_KNOWHERE_ERROR_ << "Invalid binary set.";
return Status::invalid_binary_set;
}

MemoryIOReader reader(binary->data.get(), binary->size);
int io_flags = faiss::IO_FLAG_ZERO_COPY;
faiss::ZeroCopyIOReader reader(binary->data.get(), binary->size);
if constexpr (std::is_same<IndexType, faiss::IndexFlat>::value) {
faiss::Index* index = faiss::read_index(&reader);
faiss::Index* index = faiss::read_index(&reader, io_flags);
index_.reset(static_cast<IndexType*>(index));
}
if constexpr (std::is_same<IndexType, faiss::IndexBinaryFlat>::value) {
faiss::IndexBinary* index = faiss::read_index_binary(&reader);
faiss::IndexBinary* index = faiss::read_index_binary(&reader, io_flags);
index_.reset(static_cast<IndexType*>(index));
}
return Status::success;
Expand Down
2 changes: 1 addition & 1 deletion src/index/gpu/flat_gpu/flat_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class GpuFlatIndexNode : public IndexNode {
}

Status
Deserialize(const BinarySet& binset, const Config& config) override {
Deserialize(BinarySet&& binset, const Config& config) override {
auto binary = binset.GetByName(Type());
if (binary == nullptr) {
LOG_KNOWHERE_ERROR_ << "Invalid binary set.";
Expand Down
2 changes: 1 addition & 1 deletion src/index/gpu/ivf_gpu/ivf_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ class GpuIvfIndexNode : public IndexNode {
}

Status
Deserialize(const BinarySet& binset, const Config& config) override {
Deserialize(BinarySet&& binset, const Config& config) override {
auto binary = binset.GetByName(Type());
if (binary == nullptr) {
LOG_KNOWHERE_ERROR_ << "invalid binary set.";
Expand Down
2 changes: 1 addition & 1 deletion src/index/gpu_raft/gpu_raft.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ struct GpuRaftIndexNode : public IndexNode {
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config>) override {
Deserialize(BinarySet&& binset, std::shared_ptr<Config>) override {
auto result = Status::success;
std::stringbuf buf;
auto binary = binset.GetByName(this->Type());
Expand Down
4 changes: 2 additions & 2 deletions src/index/gpu_raft/gpu_raft_cagra.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class GpuRaftCagraHybridIndexNode : public GpuRaftCagraIndexNode<DataType> {
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> cfg) override {
Deserialize(BinarySet&& binset, std::shared_ptr<Config> cfg) override {
if (binset.Contains(std::string(this->Type()) + "_cpu")) {
this->adapt_for_cpu = true;
auto binary = binset.GetByName(std::string(this->Type() + "_cpu"));
Expand All @@ -147,7 +147,7 @@ class GpuRaftCagraHybridIndexNode : public GpuRaftCagraIndexNode<DataType> {
return Status::success;
}

return GpuRaftCagraIndexNode<DataType>::Deserialize(binset, std::move(cfg));
return GpuRaftCagraIndexNode<DataType>::Deserialize(std::move(binset), std::move(cfg));
}

Status
Expand Down
16 changes: 9 additions & 7 deletions src/index/hnsw/faiss_hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "faiss/IndexRefine.h"
#include "faiss/impl/ScalarQuantizer.h"
#include "faiss/impl/mapped_io.h"
#include "faiss/impl/zerocopy_io.h"
#include "faiss/index_io.h"
#include "index/hnsw/faiss_hnsw_config.h"
#include "index/hnsw/hnsw.h"
Expand All @@ -52,7 +53,6 @@
#include "knowhere/log.h"
#include "knowhere/range_util.h"
#include "knowhere/utils.h"

#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
#include "knowhere/prometheus_client.h"
#endif
Expand Down Expand Up @@ -201,14 +201,16 @@ class BaseFaissRegularIndexNode : public BaseFaissIndexNode {
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> config) override {
auto binary = binset.GetByName(Type());
Deserialize(BinarySet&& binset, std::shared_ptr<Config> config) override {
binarySet_ = std::move(binset);
auto binary = binarySet_.GetByName(Type());
if (binary == nullptr) {
LOG_KNOWHERE_ERROR_ << "Invalid binary set.";
return Status::invalid_binary_set;
}

MemoryIOReader reader(binary->data.get(), binary->size);
int io_flags = faiss::IO_FLAG_ZERO_COPY;
faiss::ZeroCopyIOReader reader(binary->data.get(), binary->size);
try {
// this is a hack for compatibility, faiss index has 4-byte header to indicate index category
// create a new one to distinguish MV faiss hnsw from faiss hnsw
Expand Down Expand Up @@ -1890,11 +1892,11 @@ class HNSWIndexNodeWithFallback : public IndexNode {
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> config) override {
Deserialize(BinarySet&& binset, std::shared_ptr<Config> config) override {
if (use_base_index) {
return base_index->Deserialize(binset, config);
return base_index->Deserialize(std::move(binset), config);
} else {
return fallback_search_index->Deserialize(binset, config);
return fallback_search_index->Deserialize(std::move(binset), config);
}
}

Expand Down
7 changes: 4 additions & 3 deletions src/index/hnsw/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -485,18 +485,19 @@ class HnswIndexNode : public IndexNode {
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config>) override {
Deserialize(BinarySet&& binset, std::shared_ptr<Config>) override {
if (index_) {
delete index_;
}
try {
auto binary = binset.GetByName(Type());
binarySet_ = std::move(binset);
auto binary = binarySet_.GetByName(Type());
if (binary == nullptr) {
LOG_KNOWHERE_ERROR_ << "Invalid binary set.";
return Status::invalid_binary_set;
}

MemoryIOReader reader(binary->data.get(), binary->size);
ZeroCopyIOReader reader(binary->data.get(), binary->size);

hnswlib::SpaceInterface<DistType>* space = nullptr;
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DataType, DistType, quant_type>(space);
Expand Down
6 changes: 3 additions & 3 deletions src/index/index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ Index<T>::Serialize(BinarySet& binset) const {

template <typename T>
inline Status
Index<T>::Deserialize(const BinarySet& binset, const Json& json) {
Index<T>::Deserialize(BinarySet&& binset, const Json& json) {
Json json_(json);
auto cfg = this->node->CreateConfig();
{
Expand All @@ -318,12 +318,12 @@ Index<T>::Deserialize(const BinarySet& binset, const Json& json) {

#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
TimeRecorder rc("Load index", 2);
res = this->node->Deserialize(binset, std::move(cfg));
res = this->node->Deserialize(std::move(binset), std::move(cfg));
auto time = rc.ElapseFromBegin("done");
time *= 0.001; // convert to ms
knowhere_load_latency.Observe(time);
#else
res = this->node->Deserialize(binset, std::move(cfg));
res = this->node->Deserialize(std::move(binset), std::move(cfg));
#endif
return res;
}
Expand Down
18 changes: 10 additions & 8 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "faiss/IndexIVFScalarQuantizerCC.h"
#include "faiss/IndexScaNN.h"
#include "faiss/IndexScalarQuantizer.h"
#include "faiss/impl/zerocopy_io.h"
#include "faiss/index_io.h"
#include "index/ivf/ivf_config.h"
#include "io/memory_io.h"
Expand Down Expand Up @@ -162,7 +163,7 @@ class IvfIndexNode : public IndexNode {
return this->SerializeImpl(binset, typename IndexDispatch<IndexType>::Tag{});
}
Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> cfg) override;
Deserialize(BinarySet&& binset, std::shared_ptr<Config> cfg) override;
Status
DeserializeFromFile(const std::string& filename, std::shared_ptr<Config> cfg) override;

Expand Down Expand Up @@ -1172,17 +1173,18 @@ IvfIndexNode<DataType, IndexType>::SerializeImpl(BinarySet& binset, IVFFlatTag)

template <typename DataType, typename IndexType>
Status
IvfIndexNode<DataType, IndexType>::Deserialize(const BinarySet& binset, std::shared_ptr<Config> cfg) {
IvfIndexNode<DataType, IndexType>::Deserialize(BinarySet&& binset, std::shared_ptr<Config> cfg) {
std::vector<std::string> names = {"IVF", // compatible with knowhere-1.x
"BinaryIVF", // compatible with knowhere-1.x
Type()};
auto binary = binset.GetByNames(names);
binarySet_ = std::move(binset);
auto binary = binarySet_.GetByNames(names);
if (binary == nullptr) {
LOG_KNOWHERE_ERROR_ << "Invalid binary set.";
return Status::invalid_binary_set;
}

MemoryIOReader reader(binary->data.get(), binary->size);
int io_flags = faiss::IO_FLAG_ZERO_COPY;
faiss::ZeroCopyIOReader reader(binary->data.get(), binary->size);
try {
if constexpr (std::is_same<IndexType, faiss::IndexIVFFlat>::value) {
if (this->version_ <= Version::GetMinimalVersion()) {
Expand All @@ -1193,11 +1195,11 @@ IvfIndexNode<DataType, IndexType>::Deserialize(const BinarySet& binset, std::sha
reader.data_ = binary->data.get();
reader.total_ = binary->size;
}
index_.reset(static_cast<faiss::IndexIVFFlat*>(faiss::read_index(&reader)));
index_.reset(static_cast<faiss::IndexIVFFlat*>(faiss::read_index(&reader, io_flags)));
} else if constexpr (std::is_same<IndexType, faiss::IndexBinaryIVF>::value) {
index_.reset(static_cast<IndexType*>(faiss::read_index_binary(&reader)));
index_.reset(static_cast<IndexType*>(faiss::read_index_binary(&reader, io_flags)));
} else {
index_.reset(static_cast<IndexType*>(faiss::read_index(&reader)));
index_.reset(static_cast<IndexType*>(faiss::read_index(&reader, io_flags)));
}
if constexpr (!std::is_same_v<IndexType, faiss::IndexScaNN> &&
!std::is_same_v<IndexType, faiss::IndexIVFScalarQuantizerCC>) {
Expand Down
4 changes: 2 additions & 2 deletions src/index/sparse/sparse_index_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ class SparseInvertedIndexNode : public IndexNode {
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> config) override {
Deserialize(BinarySet&& binset, std::shared_ptr<Config> config) override {
if (index_ != nullptr) {
LOG_KNOWHERE_WARNING_ << Type() << " has already been created, deleting old";
DeleteExistingIndex();
Expand Down Expand Up @@ -528,7 +528,7 @@ class SparseInvertedIndexNodeCC : public SparseInvertedIndexNode<T, use_wand> {
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> config) override {
Deserialize(BinarySet&& binset, std::shared_ptr<Config> config) override {
return Status::not_implemented;
}

Expand Down
Loading

0 comments on commit 91cfcde

Please sign in to comment.