Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] feature : add zero copy support for vector index #1032

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/knowhere/index/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class Index {
Serialize(BinarySet& binset) const;

Status
Deserialize(const BinarySet& binset, const Json& json = {});
Deserialize(BinarySet&& binset, const Json& json = {});

Status
DeserializeFromFile(const std::string& filename, const Json& json = {});
Expand Down
3 changes: 2 additions & 1 deletion include/knowhere/index/index_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ class IndexNode : public Object {
* its release.
*/
virtual Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> config) = 0;
Deserialize(BinarySet&& binset, std::shared_ptr<Config> config) = 0;

/**
* @brief Deserializes the index from a file.
Expand Down Expand Up @@ -437,6 +437,7 @@ class IndexNode : public Object {
}

protected:
BinarySet binarySet_;
Version version_;
};

Expand Down
4 changes: 2 additions & 2 deletions include/knowhere/index/index_node_data_mock_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ class IndexNodeDataMockWrapper : public IndexNode {
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> cfg) override {
return index_node_->Deserialize(binset, std::move(cfg));
Deserialize(BinarySet&& binset, std::shared_ptr<Config> cfg) override {
return index_node_->Deserialize(std::move(binset), std::move(cfg));
}

Status
Expand Down
4 changes: 2 additions & 2 deletions include/knowhere/index/index_node_thread_pool_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ class IndexNodeThreadPoolWrapper : public IndexNode {
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> cfg) override {
return index_node_->Deserialize(binset, std::move(cfg));
Deserialize(BinarySet&& binset, std::shared_ptr<Config> cfg) override {
return index_node_->Deserialize(std::move(binset), std::move(cfg));
}

Status
Expand Down
2 changes: 1 addition & 1 deletion python/knowhere/knowhere.i
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ class IndexWrap {
knowhere::Status
Deserialize(knowhere::BinarySetPtr binset, const std::string& json) {
GILReleaser rel;
return idx.value().Deserialize(*binset, knowhere::Json::parse(json));
return idx.value().Deserialize(std::move(*binset), knowhere::Json::parse(json));
}

knowhere::Status
Expand Down
4 changes: 2 additions & 2 deletions src/index/diskann/diskann.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class DiskANNIndexNode : public IndexNode {
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> cfg) override;
Deserialize(BinarySet&& binset, std::shared_ptr<Config> cfg) override;

Status
DeserializeFromFile(const std::string& filename, std::shared_ptr<Config> config) override {
Expand Down Expand Up @@ -385,7 +385,7 @@ DiskANNIndexNode<DataType>::Build(const DataSetPtr dataset, std::shared_ptr<Conf

template <typename DataType>
Status
DiskANNIndexNode<DataType>::Deserialize(const BinarySet& binset, std::shared_ptr<Config> cfg) {
DiskANNIndexNode<DataType>::Deserialize(BinarySet&& binset, std::shared_ptr<Config> cfg) {
std::lock_guard<std::mutex> lock(preparation_lock_);
auto prep_conf = static_cast<const DiskANNConfig&>(*cfg);
if (!CheckMetric(prep_conf.metric_type.value())) {
Expand Down
13 changes: 8 additions & 5 deletions src/index/flat/flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "faiss/IndexBinaryFlat.h"
#include "faiss/IndexFlat.h"
#include "faiss/impl/AuxIndexStructures.h"
#include "faiss/impl/zerocopy_io.h"
#include "faiss/index_io.h"
#include "index/flat/flat_config.h"
#include "io/memory_io.h"
Expand Down Expand Up @@ -313,23 +314,25 @@ class FlatIndexNode : public IndexNode {
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config>) override {
Deserialize(BinarySet&& binset, std::shared_ptr<Config>) override {
std::vector<std::string> names = {"IVF", // compatible with knowhere-1.x
"BinaryIVF", // compatible with knowhere-1.x
Type()};
auto binary = binset.GetByNames(names);
binarySet_ = std::move(binset);
auto binary = binarySet_.GetByNames(names);
if (binary == nullptr) {
LOG_KNOWHERE_ERROR_ << "Invalid binary set.";
return Status::invalid_binary_set;
}

MemoryIOReader reader(binary->data.get(), binary->size);
int io_flags = faiss::IO_FLAG_ZERO_COPY;
faiss::ZeroCopyIOReader reader(binary->data.get(), binary->size);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Such an approach won't lead to a memory leak, correct?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The binary will be held by the index; so there will be no memory leak

if constexpr (std::is_same<IndexType, faiss::IndexFlat>::value) {
faiss::Index* index = faiss::read_index(&reader);
faiss::Index* index = faiss::read_index(&reader, io_flags);
index_.reset(static_cast<IndexType*>(index));
}
if constexpr (std::is_same<IndexType, faiss::IndexBinaryFlat>::value) {
faiss::IndexBinary* index = faiss::read_index_binary(&reader);
faiss::IndexBinary* index = faiss::read_index_binary(&reader, io_flags);
index_.reset(static_cast<IndexType*>(index));
}
return Status::success;
Expand Down
2 changes: 1 addition & 1 deletion src/index/gpu/flat_gpu/flat_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class GpuFlatIndexNode : public IndexNode {
}

Status
Deserialize(const BinarySet& binset, const Config& config) override {
Deserialize(BinarySet&& binset, const Config& config) override {
auto binary = binset.GetByName(Type());
if (binary == nullptr) {
LOG_KNOWHERE_ERROR_ << "Invalid binary set.";
Expand Down
2 changes: 1 addition & 1 deletion src/index/gpu/ivf_gpu/ivf_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ class GpuIvfIndexNode : public IndexNode {
}

Status
Deserialize(const BinarySet& binset, const Config& config) override {
Deserialize(BinarySet&& binset, const Config& config) override {
auto binary = binset.GetByName(Type());
if (binary == nullptr) {
LOG_KNOWHERE_ERROR_ << "invalid binary set.";
Expand Down
2 changes: 1 addition & 1 deletion src/index/gpu_raft/gpu_raft.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ struct GpuRaftIndexNode : public IndexNode {
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config>) override {
Deserialize(BinarySet&& binset, std::shared_ptr<Config>) override {
auto result = Status::success;
std::stringbuf buf;
auto binary = binset.GetByName(this->Type());
Expand Down
4 changes: 2 additions & 2 deletions src/index/gpu_raft/gpu_raft_cagra.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class GpuRaftCagraHybridIndexNode : public GpuRaftCagraIndexNode<DataType> {
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> cfg) override {
Deserialize(BinarySet&& binset, std::shared_ptr<Config> cfg) override {
if (binset.Contains(std::string(this->Type()) + "_cpu")) {
this->adapt_for_cpu = true;
auto binary = binset.GetByName(std::string(this->Type() + "_cpu"));
Expand All @@ -147,7 +147,7 @@ class GpuRaftCagraHybridIndexNode : public GpuRaftCagraIndexNode<DataType> {
return Status::success;
}

return GpuRaftCagraIndexNode<DataType>::Deserialize(binset, std::move(cfg));
return GpuRaftCagraIndexNode<DataType>::Deserialize(std::move(binset), std::move(cfg));
}

Status
Expand Down
16 changes: 9 additions & 7 deletions src/index/hnsw/faiss_hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "faiss/IndexRefine.h"
#include "faiss/impl/ScalarQuantizer.h"
#include "faiss/impl/mapped_io.h"
#include "faiss/impl/zerocopy_io.h"
#include "faiss/index_io.h"
#include "index/hnsw/faiss_hnsw_config.h"
#include "index/hnsw/hnsw.h"
Expand All @@ -52,7 +53,6 @@
#include "knowhere/log.h"
#include "knowhere/range_util.h"
#include "knowhere/utils.h"

#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
#include "knowhere/prometheus_client.h"
#endif
Expand Down Expand Up @@ -201,14 +201,16 @@ class BaseFaissRegularIndexNode : public BaseFaissIndexNode {
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> config) override {
auto binary = binset.GetByName(Type());
Deserialize(BinarySet&& binset, std::shared_ptr<Config> config) override {
binarySet_ = std::move(binset);
auto binary = binarySet_.GetByName(Type());
if (binary == nullptr) {
LOG_KNOWHERE_ERROR_ << "Invalid binary set.";
return Status::invalid_binary_set;
}

MemoryIOReader reader(binary->data.get(), binary->size);
int io_flags = faiss::IO_FLAG_ZERO_COPY;
faiss::ZeroCopyIOReader reader(binary->data.get(), binary->size);
try {
// this is a hack for compatibility, faiss index has 4-byte header to indicate index category
// create a new one to distinguish MV faiss hnsw from faiss hnsw
Expand Down Expand Up @@ -1890,11 +1892,11 @@ class HNSWIndexNodeWithFallback : public IndexNode {
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> config) override {
Deserialize(BinarySet&& binset, std::shared_ptr<Config> config) override {
if (use_base_index) {
return base_index->Deserialize(binset, config);
return base_index->Deserialize(std::move(binset), config);
} else {
return fallback_search_index->Deserialize(binset, config);
return fallback_search_index->Deserialize(std::move(binset), config);
}
}

Expand Down
7 changes: 4 additions & 3 deletions src/index/hnsw/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -485,18 +485,19 @@ class HnswIndexNode : public IndexNode {
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config>) override {
Deserialize(BinarySet&& binset, std::shared_ptr<Config>) override {
if (index_) {
delete index_;
}
try {
auto binary = binset.GetByName(Type());
binarySet_ = std::move(binset);
auto binary = binarySet_.GetByName(Type());
if (binary == nullptr) {
LOG_KNOWHERE_ERROR_ << "Invalid binary set.";
return Status::invalid_binary_set;
}

MemoryIOReader reader(binary->data.get(), binary->size);
ZeroCopyIOReader reader(binary->data.get(), binary->size);

hnswlib::SpaceInterface<DistType>* space = nullptr;
index_ = new (std::nothrow) hnswlib::HierarchicalNSW<DataType, DistType, quant_type>(space);
Expand Down
6 changes: 3 additions & 3 deletions src/index/index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ Index<T>::Serialize(BinarySet& binset) const {

template <typename T>
inline Status
Index<T>::Deserialize(const BinarySet& binset, const Json& json) {
Index<T>::Deserialize(BinarySet&& binset, const Json& json) {
Json json_(json);
auto cfg = this->node->CreateConfig();
{
Expand All @@ -318,12 +318,12 @@ Index<T>::Deserialize(const BinarySet& binset, const Json& json) {

#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
TimeRecorder rc("Load index", 2);
res = this->node->Deserialize(binset, std::move(cfg));
res = this->node->Deserialize(std::move(binset), std::move(cfg));
auto time = rc.ElapseFromBegin("done");
time *= 0.001; // convert to ms
knowhere_load_latency.Observe(time);
#else
res = this->node->Deserialize(binset, std::move(cfg));
res = this->node->Deserialize(std::move(binset), std::move(cfg));
#endif
return res;
}
Expand Down
18 changes: 10 additions & 8 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "faiss/IndexIVFScalarQuantizerCC.h"
#include "faiss/IndexScaNN.h"
#include "faiss/IndexScalarQuantizer.h"
#include "faiss/impl/zerocopy_io.h"
#include "faiss/index_io.h"
#include "index/ivf/ivf_config.h"
#include "io/memory_io.h"
Expand Down Expand Up @@ -162,7 +163,7 @@ class IvfIndexNode : public IndexNode {
return this->SerializeImpl(binset, typename IndexDispatch<IndexType>::Tag{});
}
Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> cfg) override;
Deserialize(BinarySet&& binset, std::shared_ptr<Config> cfg) override;
Status
DeserializeFromFile(const std::string& filename, std::shared_ptr<Config> cfg) override;

Expand Down Expand Up @@ -1172,17 +1173,18 @@ IvfIndexNode<DataType, IndexType>::SerializeImpl(BinarySet& binset, IVFFlatTag)

template <typename DataType, typename IndexType>
Status
IvfIndexNode<DataType, IndexType>::Deserialize(const BinarySet& binset, std::shared_ptr<Config> cfg) {
IvfIndexNode<DataType, IndexType>::Deserialize(BinarySet&& binset, std::shared_ptr<Config> cfg) {
std::vector<std::string> names = {"IVF", // compatible with knowhere-1.x
"BinaryIVF", // compatible with knowhere-1.x
Type()};
auto binary = binset.GetByNames(names);
binarySet_ = std::move(binset);
auto binary = binarySet_.GetByNames(names);
if (binary == nullptr) {
LOG_KNOWHERE_ERROR_ << "Invalid binary set.";
return Status::invalid_binary_set;
}

MemoryIOReader reader(binary->data.get(), binary->size);
int io_flags = faiss::IO_FLAG_ZERO_COPY;
faiss::ZeroCopyIOReader reader(binary->data.get(), binary->size);
try {
if constexpr (std::is_same<IndexType, faiss::IndexIVFFlat>::value) {
if (this->version_ <= Version::GetMinimalVersion()) {
Expand All @@ -1193,11 +1195,11 @@ IvfIndexNode<DataType, IndexType>::Deserialize(const BinarySet& binset, std::sha
reader.data_ = binary->data.get();
reader.total_ = binary->size;
}
index_.reset(static_cast<faiss::IndexIVFFlat*>(faiss::read_index(&reader)));
index_.reset(static_cast<faiss::IndexIVFFlat*>(faiss::read_index(&reader, io_flags)));
} else if constexpr (std::is_same<IndexType, faiss::IndexBinaryIVF>::value) {
index_.reset(static_cast<IndexType*>(faiss::read_index_binary(&reader)));
index_.reset(static_cast<IndexType*>(faiss::read_index_binary(&reader, io_flags)));
} else {
index_.reset(static_cast<IndexType*>(faiss::read_index(&reader)));
index_.reset(static_cast<IndexType*>(faiss::read_index(&reader, io_flags)));
}
if constexpr (!std::is_same_v<IndexType, faiss::IndexScaNN> &&
!std::is_same_v<IndexType, faiss::IndexIVFScalarQuantizerCC>) {
Expand Down
4 changes: 2 additions & 2 deletions src/index/sparse/sparse_index_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ class SparseInvertedIndexNode : public IndexNode {
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> config) override {
Deserialize(BinarySet&& binset, std::shared_ptr<Config> config) override {
if (index_ != nullptr) {
LOG_KNOWHERE_WARNING_ << Type() << " has already been created, deleting old";
DeleteExistingIndex();
Expand Down Expand Up @@ -528,7 +528,7 @@ class SparseInvertedIndexNodeCC : public SparseInvertedIndexNode<T, use_wand> {
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> config) override {
Deserialize(BinarySet&& binset, std::shared_ptr<Config> config) override {
return Status::not_implemented;
}

Expand Down
Loading