Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Qdrant Vector Store #453

Merged
merged 6 commits into from
Jun 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions gptcache/manager/vector_data/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@
PGVECTOR_URL = "postgresql://postgres:postgres@localhost:5432/postgres"
PGVECTOR_INDEX_PARAMS = {"index_type": "L2", "params": {"lists": 100, "probes": 10}}

# Fallback values used by the "qdrant" branch of `get` whenever the caller
# does not supply the corresponding keyword argument.
QDRANT_GRPC_PORT = 6334  # default for the `grpc_port` keyword
QDRANT_HTTP_PORT = 6333  # default for the `port` keyword
QDRANT_INDEX_PARAMS = {"ef_construct": 100, "m": 16}  # default for the `index_params` keyword
QDRANT_DEFAULT_LOCATION = "./qdrant_data"  # default for the `location` keyword
QDRANT_FLUSH_INTERVAL_SEC = 5  # default for the `flush_interval_sec` keyword

# Default for the `collection_name` keyword in the "qdrant" branch of `get`.
COLLECTION_NAME = "gptcache"


Expand Down Expand Up @@ -217,6 +223,40 @@ def get(name, **kwargs):
collection_name=collection_name,
top_k=top_k,
)
elif name == "qdrant":
from gptcache.manager.vector_data.qdrant import QdrantVectorStore
url = kwargs.get("url", None)
port = kwargs.get("port", QDRANT_HTTP_PORT)
grpc_port = kwargs.get("grpc_port", QDRANT_GRPC_PORT)
prefer_grpc = kwargs.get("prefer_grpc", False)
https = kwargs.get("https", False)
api_key = kwargs.get("api_key", None)
prefix = kwargs.get("prefix", None)
timeout = kwargs.get("timeout", None)
host = kwargs.get("host", None)
collection_name = kwargs.get("collection_name", COLLECTION_NAME)
location = kwargs.get("location", QDRANT_DEFAULT_LOCATION)
dimension = kwargs.get("dimension", DIMENSION)
top_k: int = kwargs.get("top_k", TOP_K)
flush_interval_sec = kwargs.get("flush_interval_sec", QDRANT_FLUSH_INTERVAL_SEC)
index_params = kwargs.get("index_params", QDRANT_INDEX_PARAMS)
vector_base = QdrantVectorStore(
url=url,
port=port,
grpc_port=grpc_port,
prefer_grpc=prefer_grpc,
https=https,
api_key=api_key,
prefix=prefix,
timeout=timeout,
host=host,
collection_name=collection_name,
location=location,
dimension=dimension,
top_k=top_k,
flush_interval_sec=flush_interval_sec,
index_params=index_params,
)
else:
raise NotFoundError("vector store", name)
return vector_base
107 changes: 107 additions & 0 deletions gptcache/manager/vector_data/qdrant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from typing import List, Optional
import numpy as np

from gptcache.utils import import_qdrant
from gptcache.utils.log import gptcache_log
from gptcache.manager.vector_data.base import VectorBase, VectorData

import_qdrant()

from qdrant_client import QdrantClient # pylint: disable=C0413
from qdrant_client.models import PointStruct, HnswConfigDiff, VectorParams, OptimizersConfigDiff, \
Distance # pylint: disable=C0413


class QdrantVectorStore(VectorBase):
    """Vector store implementation backed by a Qdrant collection.

    The store talks to Qdrant through a ``QdrantClient``. When ``location`` is
    not ``None`` the client is created from that location (``":memory:"`` is
    recorded as the in-memory case); otherwise the client is created from the
    remote connection settings (``url``, ``host``, ``port``, ...).

    The collection named ``collection_name`` is created on construction unless
    a collection with that name is already listed by the client.
    """

    def __init__(
        self,
        url: Optional[str] = None,
        port: Optional[int] = 6333,
        grpc_port: int = 6334,
        prefer_grpc: bool = False,
        https: Optional[bool] = None,
        api_key: Optional[str] = None,
        prefix: Optional[str] = None,
        timeout: Optional[float] = None,
        host: Optional[str] = None,
        collection_name: Optional[str] = "gptcache",
        location: Optional[str] = "./qdrant",
        dimension: int = 0,
        top_k: int = 1,
        flush_interval_sec: int = 5,
        index_params: Optional[dict] = None,
    ):
        """Create the client and make sure the target collection exists.

        :param url, port, grpc_port, prefer_grpc, https, api_key, prefix,
            timeout, host: forwarded unchanged to ``QdrantClient`` when
            ``location`` is ``None``.
        :param collection_name: name of the collection used by every operation.
        :param location: forwarded to ``QdrantClient(location=...)`` when it is
            not ``None``.
        :param dimension: vector size of the collection; must be positive.
        :param top_k: default number of results returned by ``search``.
        :param flush_interval_sec: forwarded to the collection optimizer config.
        :param index_params: keyword arguments for ``HnswConfigDiff``; ``None``
            means no overrides.
        :raises ValueError: if ``dimension`` is not strictly positive.
        """
        if dimension <= 0:
            # NOTE: Codecov (patch check) reported this raise as not covered by tests.
            raise ValueError(
                f"invalid `dim` param: {dimension} in the Qdrant vector store."
            )

        self._client: QdrantClient
        self.dimension = dimension
        self.top_k = top_k
        self._collection_name = collection_name
        self._in_memory = location == ":memory:"

        # ":memory:" is itself a non-None location, so one None check selects
        # between the location-based client and the remote client.
        if location is None:
            # NOTE: Codecov (patch check) reported this remote branch as not covered by tests.
            self._create_remote(url, port, api_key, timeout, host, grpc_port, prefer_grpc, prefix, https)
        else:
            self._create_local(location)

        self._create_collection(collection_name, flush_interval_sec, index_params)

    def _create_local(self, location):
        """Bind ``self._client`` to a client created from ``location``."""
        self._client = QdrantClient(location=location)

    def _create_remote(self, url, port, api_key, timeout, host, grpc_port, prefer_grpc, prefix, https):
        """Bind ``self._client`` to a client built from remote connection settings."""
        # NOTE: Codecov (patch check) reported this client construction as not covered by tests.
        connection_options = {
            "url": url,
            "port": port,
            "api_key": api_key,
            "timeout": timeout,
            "host": host,
            "grpc_port": grpc_port,
            "prefer_grpc": prefer_grpc,
            "prefix": prefix,
            "https": https,
        }
        self._client = QdrantClient(**connection_options)

    def _create_collection(self, collection_name: str, flush_interval_sec: int, index_params: Optional[dict] = None):
        """Create ``collection_name`` unless the client already lists it.

        The collection uses cosine distance, vectors of size ``self.dimension``,
        an HNSW config built from ``index_params`` and an optimizer config that
        carries ``flush_interval_sec``.
        """
        # Build every config object first so invalid parameters fail before
        # the client is queried.
        hnsw_overrides = index_params or {}
        vectors_config = VectorParams(
            size=self.dimension,
            distance=Distance.COSINE,
            hnsw_config=HnswConfigDiff(**hnsw_overrides),
        )
        optimizers_config = OptimizersConfigDiff(
            deleted_threshold=0.2,
            vacuum_min_vector_number=1000,
            flush_interval_sec=flush_interval_sec,
        )

        # check if the collection exists
        known_collections = self._client.get_collections().collections
        already_exists = any(entry.name == collection_name for entry in known_collections)
        if already_exists:
            # NOTE: Codecov (patch check) reported this "already exists" branch as not covered by tests.
            gptcache_log.warning("The %s collection already exists, and it will be used directly.", collection_name)
            return

        self._client.create_collection(
            collection_name=collection_name,
            vectors_config=vectors_config,
            optimizers_config=optimizers_config,
        )

    def mul_add(self, datas: List[VectorData]):
        """Upsert every item of ``datas`` as a point (flattened vector, same id).

        The upsert is issued with ``wait=False``.
        """
        points = []
        for item in datas:
            flat_vector = item.data.reshape(-1).tolist()
            points.append(PointStruct(id=item.id, vector=flat_vector))
        self._client.upsert(collection_name=self._collection_name, points=points, wait=False)

    def search(self, data: np.ndarray, top_k: int = -1):
        """Return ``(score, id)`` tuples for the points closest to ``data``.

        ``top_k == -1`` means "use the ``top_k`` configured on the store".
        """
        limit = self.top_k if top_k == -1 else top_k
        query_vector = data.reshape(-1).tolist()
        hits = self._client.search(
            collection_name=self._collection_name,
            query_vector=query_vector,
            limit=limit,
        )
        return [(hit.score, hit.id) for hit in hits]

    def delete(self, ids: List[str]):
        """Delete the points identified by ``ids`` from the collection."""
        self._client.delete(collection_name=self._collection_name, points_selector=ids)

    def rebuild(self, ids=None):  # pylint: disable=unused-argument
        """Re-send the optimizer settings to the collection; ``ids`` is ignored."""
        optimizers_config = OptimizersConfigDiff(deleted_threshold=0.2, vacuum_min_vector_number=1000)
        self._client.update_collection(collection_name=self._collection_name, optimizer_config=optimizers_config)

    def flush(self):
        """Do nothing; kept to satisfy the vector store interface."""
        # no need to flush manually as qdrant flushes automatically based on the optimizers_config for remote Qdrant

    def close(self):
        """Close the store by delegating to ``flush``."""
        self.flush()
7 changes: 6 additions & 1 deletion gptcache/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
"import_paddlenlp",
"import_tiktoken",
"import_fastapi",
"import_redis"
"import_redis",
"import_qdrant"
]

import importlib.util
Expand All @@ -65,6 +66,10 @@ def import_milvus_lite():
_check_library("milvus")


def import_qdrant():
    """Run `_check_library` for the `qdrant_client` library.

    NOTE(review): `_check_library` is defined elsewhere in this module;
    presumably it verifies (or installs) the dependency -- confirm there.
    """
    _check_library("qdrant_client")


def import_sbert():
    """Run `_check_library` for the `sentence_transformers` library.

    NOTE(review): `package` presumably carries the distribution name
    ("sentence-transformers"), which differs from the import name --
    confirm against `_check_library`.
    """
    _check_library("sentence_transformers", package="sentence-transformers")

Expand Down
33 changes: 33 additions & 0 deletions tests/unit_tests/manager/test_qdrant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import os
import unittest

import numpy as np

from gptcache.manager.vector_data import VectorBase
from gptcache.manager.vector_data.base import VectorData


class TestQdrant(unittest.TestCase):
    """Exercise the "qdrant" vector store against an in-memory instance."""

    def test_normal(self):
        """Insert, search, delete, rebuild and close on a tiny random data set."""
        point_count, vector_dim, limit = 10, 2, 10
        store = VectorBase(
            "qdrant",
            top_k=limit,
            dimension=vector_dim,
            location=":memory:",
        )

        vectors = np.random.randn(point_count, vector_dim).astype(np.float32)
        store.mul_add([VectorData(id=idx, data=vec) for idx, vec in enumerate(vectors)])

        # With exactly `limit` points stored, an explicit top_k returns all of them.
        query = vectors[0]
        self.assertEqual(len(store.search(query, limit)), limit)

        # Store the query vector a second time under a fresh id; the two
        # identical vectors must then occupy the first two result slots.
        duplicate_id = point_count
        twin_ids = [0, duplicate_id]
        store.mul_add([VectorData(id=duplicate_id, data=query)])
        hits = store.search(query)
        self.assertIn(hits[0][1], twin_ids)
        self.assertIn(hits[1][1], twin_ids)

        # After deleting both copies (and a few other points) neither id may
        # be the best match any more.
        store.delete([0, 1, 2, 3, 4, 5, duplicate_id])
        hits = store.search(query)
        self.assertNotIn(hits[0][1], twin_ids)

        store.rebuild()
        store.close()
Loading