Add support for Qdrant Vector Store (#453)
* Add Qdrant vector store client

* Add setup for Qdrant vector store

* Add lazy import for Qdrant client library

* Add import_qdrant to index file

* Use models not types for building configs

* Add tests for Qdrant vector store
parthvnp authored Jun 29, 2023
1 parent 66a3a1b commit 98fae33
Showing 4 changed files with 186 additions and 1 deletion.
40 changes: 40 additions & 0 deletions gptcache/manager/vector_data/manager.py
@@ -19,6 +19,12 @@
PGVECTOR_URL = "postgresql://postgres:postgres@localhost:5432/postgres"
PGVECTOR_INDEX_PARAMS = {"index_type": "L2", "params": {"lists": 100, "probes": 10}}

QDRANT_GRPC_PORT = 6334
QDRANT_HTTP_PORT = 6333
QDRANT_INDEX_PARAMS = {"ef_construct": 100, "m": 16}
QDRANT_DEFAULT_LOCATION = "./qdrant_data"
QDRANT_FLUSH_INTERVAL_SEC = 5

COLLECTION_NAME = "gptcache"


@@ -217,6 +223,40 @@ def get(name, **kwargs):
                collection_name=collection_name,
                top_k=top_k,
            )
        elif name == "qdrant":
            from gptcache.manager.vector_data.qdrant import QdrantVectorStore
            url = kwargs.get("url", None)
            port = kwargs.get("port", QDRANT_HTTP_PORT)
            grpc_port = kwargs.get("grpc_port", QDRANT_GRPC_PORT)
            prefer_grpc = kwargs.get("prefer_grpc", False)
            https = kwargs.get("https", False)
            api_key = kwargs.get("api_key", None)
            prefix = kwargs.get("prefix", None)
            timeout = kwargs.get("timeout", None)
            host = kwargs.get("host", None)
            collection_name = kwargs.get("collection_name", COLLECTION_NAME)
            location = kwargs.get("location", QDRANT_DEFAULT_LOCATION)
            dimension = kwargs.get("dimension", DIMENSION)
            top_k: int = kwargs.get("top_k", TOP_K)
            flush_interval_sec = kwargs.get("flush_interval_sec", QDRANT_FLUSH_INTERVAL_SEC)
            index_params = kwargs.get("index_params", QDRANT_INDEX_PARAMS)
            vector_base = QdrantVectorStore(
                url=url,
                port=port,
                grpc_port=grpc_port,
                prefer_grpc=prefer_grpc,
                https=https,
                api_key=api_key,
                prefix=prefix,
                timeout=timeout,
                host=host,
                collection_name=collection_name,
                location=location,
                dimension=dimension,
                top_k=top_k,
                flush_interval_sec=flush_interval_sec,
                index_params=index_params,
            )
        else:
            raise NotFoundError("vector store", name)
        return vector_base
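
The new branch is reached through the same VectorBase factory as the existing stores, so a caller only supplies Qdrant-specific keyword arguments. A minimal sketch of both modes follows (the dimension, host, port, and top_k values are illustrative; the remote case assumes a Qdrant server is reachable locally). Note that location falls back to QDRANT_DEFAULT_LOCATION, so it must be passed as None explicitly for the host/port settings to reach the remote client path in qdrant.py below.

from gptcache.manager.vector_data import VectorBase

# Local / in-memory store, mirroring the bundled unit test
local_store = VectorBase("qdrant", dimension=128, location=":memory:", top_k=3)

# Remote server: location=None is required, otherwise the default on-disk
# location routes construction to a local client instead of host/port.
remote_store = VectorBase(
    "qdrant",
    dimension=128,      # illustrative embedding size; must be > 0
    location=None,
    host="localhost",   # assumed local Qdrant instance
    port=6333,          # QDRANT_HTTP_PORT default
    top_k=3,
)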
107 changes: 107 additions & 0 deletions gptcache/manager/vector_data/qdrant.py
@@ -0,0 +1,107 @@
from typing import List, Optional
import numpy as np

from gptcache.utils import import_qdrant
from gptcache.utils.log import gptcache_log
from gptcache.manager.vector_data.base import VectorBase, VectorData

import_qdrant()

from qdrant_client import QdrantClient # pylint: disable=C0413
from qdrant_client.models import PointStruct, HnswConfigDiff, VectorParams, OptimizersConfigDiff, \
    Distance  # pylint: disable=C0413


class QdrantVectorStore(VectorBase):

    def __init__(
        self,
        url: Optional[str] = None,
        port: Optional[int] = 6333,
        grpc_port: int = 6334,
        prefer_grpc: bool = False,
        https: Optional[bool] = None,
        api_key: Optional[str] = None,
        prefix: Optional[str] = None,
        timeout: Optional[float] = None,
        host: Optional[str] = None,
        collection_name: Optional[str] = "gptcache",
        location: Optional[str] = "./qdrant",
        dimension: int = 0,
        top_k: int = 1,
        flush_interval_sec: int = 5,
        index_params: Optional[dict] = None,
    ):
        if dimension <= 0:
            raise ValueError(
                f"invalid `dim` param: {dimension} in the Qdrant vector store."
            )
        self._client: QdrantClient
        self._collection_name = collection_name
        self._in_memory = location == ":memory:"
        self.dimension = dimension
        self.top_k = top_k
        if self._in_memory or location is not None:
            self._create_local(location)
        else:
            self._create_remote(url, port, api_key, timeout, host, grpc_port, prefer_grpc, prefix, https)
        self._create_collection(collection_name, flush_interval_sec, index_params)

    def _create_local(self, location):
        self._client = QdrantClient(location=location)

    def _create_remote(self, url, port, api_key, timeout, host, grpc_port, prefer_grpc, prefix, https):
        self._client = QdrantClient(
            url=url,
            port=port,
            api_key=api_key,
            timeout=timeout,
            host=host,
            grpc_port=grpc_port,
            prefer_grpc=prefer_grpc,
            prefix=prefix,
            https=https,
        )

    def _create_collection(self, collection_name: str, flush_interval_sec: int, index_params: Optional[dict] = None):
        hnsw_config = HnswConfigDiff(**(index_params or {}))
        vectors_config = VectorParams(size=self.dimension, distance=Distance.COSINE,
                                      hnsw_config=hnsw_config)
        optimizers_config = OptimizersConfigDiff(deleted_threshold=0.2, vacuum_min_vector_number=1000,
                                                 flush_interval_sec=flush_interval_sec)
        # check if the collection exists
        existing_collections = self._client.get_collections()
        for existing_collection in existing_collections.collections:
            if existing_collection.name == collection_name:
                gptcache_log.warning("The %s collection already exists, and it will be used directly.", collection_name)
                break
        else:
            self._client.create_collection(collection_name=collection_name, vectors_config=vectors_config,
                                           optimizers_config=optimizers_config)

    def mul_add(self, datas: List[VectorData]):
        points = [PointStruct(id=d.id, vector=d.data.reshape(-1).tolist()) for d in datas]
        self._client.upsert(collection_name=self._collection_name, points=points, wait=False)

    def search(self, data: np.ndarray, top_k: int = -1):
        if top_k == -1:
            top_k = self.top_k
        reshaped_data = data.reshape(-1).tolist()
        search_result = self._client.search(collection_name=self._collection_name, query_vector=reshaped_data,
                                            limit=top_k)
        return list(map(lambda x: (x.score, x.id), search_result))

    def delete(self, ids: List[str]):
        self._client.delete(collection_name=self._collection_name, points_selector=ids)

    def rebuild(self, ids=None):  # pylint: disable=unused-argument
        optimizers_config = OptimizersConfigDiff(deleted_threshold=0.2, vacuum_min_vector_number=1000)
        self._client.update_collection(collection_name=self._collection_name, optimizer_config=optimizers_config)

    def flush(self):
        # no need to flush manually as qdrant flushes automatically based on the optimizers_config for remote Qdrant
        pass

    def close(self):
        self.flush()
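
Used directly, the class exposes the standard VectorBase operations: mul_add stores id/vector pairs, search returns (score, id) tuples for the nearest stored vectors, and delete removes points by id. A short sketch against the in-memory client (the dimension, ids, and vector values are illustrative):

import numpy as np

from gptcache.manager.vector_data.base import VectorData
from gptcache.manager.vector_data.qdrant import QdrantVectorStore

store = QdrantVectorStore(location=":memory:", dimension=4, top_k=2)
vectors = np.random.randn(5, 4).astype(np.float32)
store.mul_add([VectorData(id=i, data=v) for i, v in enumerate(vectors)])

# up to top_k (score, id) tuples, nearest first under the cosine distance set in _create_collection
print(store.search(vectors[0]))

store.delete([0, 1])
store.close()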
7 changes: 6 additions & 1 deletion gptcache/utils/__init__.py
@@ -38,7 +38,8 @@
    "import_paddlenlp",
    "import_tiktoken",
    "import_fastapi",
    "import_redis",
    "import_qdrant"
]

import importlib.util
@@ -65,6 +66,10 @@ def import_milvus_lite():
    _check_library("milvus")


def import_qdrant():
    _check_library("qdrant_client")


def import_sbert():
    _check_library("sentence_transformers", package="sentence-transformers")

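The new helper follows the same lazy-import convention as the other import_* functions: it is called once, immediately before the third-party import, so qdrant-client only has to be installed when the Qdrant store is actually used. The top of qdrant.py above applies it like this (reproduced here as a sketch):

from gptcache.utils import import_qdrant

import_qdrant()  # checks that the optional qdrant-client dependency is available

from qdrant_client import QdrantClient  # safe to import after the check
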
33 changes: 33 additions & 0 deletions tests/unit_tests/manager/test_qdrant.py
@@ -0,0 +1,33 @@
import os
import unittest

import numpy as np

from gptcache.manager.vector_data import VectorBase
from gptcache.manager.vector_data.base import VectorData


class TestQdrant(unittest.TestCase):
    def test_normal(self):
        size = 10
        dim = 2
        top_k = 10
        qdrant = VectorBase(
            "qdrant",
            top_k=top_k,
            dimension=dim,
            location=":memory:"
        )
        data = np.random.randn(size, dim).astype(np.float32)
        qdrant.mul_add([VectorData(id=i, data=v) for v, i in zip(data, range(size))])
        search_result = qdrant.search(data[0], top_k)
        self.assertEqual(len(search_result), top_k)
        qdrant.mul_add([VectorData(id=size, data=data[0])])
        ret = qdrant.search(data[0])
        self.assertIn(ret[0][1], [0, size])
        self.assertIn(ret[1][1], [0, size])
        qdrant.delete([0, 1, 2, 3, 4, 5, size])
        ret = qdrant.search(data[0])
        self.assertNotIn(ret[0][1], [0, size])
        qdrant.rebuild()
        qdrant.close()
