From 98fae339ab06df232721f86f3c31ad862dc46e47 Mon Sep 17 00:00:00 2001 From: Parth Patel <41171860+parthvnp@users.noreply.github.com> Date: Thu, 29 Jun 2023 03:08:21 -0400 Subject: [PATCH] Add support for Qdrant Vector Store (#453) * Add Qdrant vector store client * Add setup for Qdrant vector store * Add lazy import for Qdrant client library * Add import_qdrant to index file * Use models not types for building configs * Add tests for Qdrant vector store --- gptcache/manager/vector_data/manager.py | 40 +++++++++ gptcache/manager/vector_data/qdrant.py | 107 ++++++++++++++++++++++++ gptcache/utils/__init__.py | 7 +- tests/unit_tests/manager/test_qdrant.py | 33 ++++++++ 4 files changed, 186 insertions(+), 1 deletion(-) create mode 100644 gptcache/manager/vector_data/qdrant.py create mode 100644 tests/unit_tests/manager/test_qdrant.py diff --git a/gptcache/manager/vector_data/manager.py b/gptcache/manager/vector_data/manager.py index 6453c88e..86616b5b 100644 --- a/gptcache/manager/vector_data/manager.py +++ b/gptcache/manager/vector_data/manager.py @@ -19,6 +19,12 @@ PGVECTOR_URL = "postgresql://postgres:postgres@localhost:5432/postgres" PGVECTOR_INDEX_PARAMS = {"index_type": "L2", "params": {"lists": 100, "probes": 10}} +QDRANT_GRPC_PORT = 6334 +QDRANT_HTTP_PORT = 6333 +QDRANT_INDEX_PARAMS = {"ef_construct": 100, "m": 16} +QDRANT_DEFAULT_LOCATION = "./qdrant_data" +QDRANT_FLUSH_INTERVAL_SEC = 5 + COLLECTION_NAME = "gptcache" @@ -217,6 +223,40 @@ def get(name, **kwargs): collection_name=collection_name, top_k=top_k, ) + elif name == "qdrant": + from gptcache.manager.vector_data.qdrant import QdrantVectorStore + url = kwargs.get("url", None) + port = kwargs.get("port", QDRANT_HTTP_PORT) + grpc_port = kwargs.get("grpc_port", QDRANT_GRPC_PORT) + prefer_grpc = kwargs.get("prefer_grpc", False) + https = kwargs.get("https", False) + api_key = kwargs.get("api_key", None) + prefix = kwargs.get("prefix", None) + timeout = kwargs.get("timeout", None) + host = kwargs.get("host", None) + collection_name = kwargs.get("collection_name", COLLECTION_NAME) + location = kwargs.get("location", QDRANT_DEFAULT_LOCATION) + dimension = kwargs.get("dimension", DIMENSION) + top_k: int = kwargs.get("top_k", TOP_K) + flush_interval_sec = kwargs.get("flush_interval_sec", QDRANT_FLUSH_INTERVAL_SEC) + index_params = kwargs.get("index_params", QDRANT_INDEX_PARAMS) + vector_base = QdrantVectorStore( + url=url, + port=port, + grpc_port=grpc_port, + prefer_grpc=prefer_grpc, + https=https, + api_key=api_key, + prefix=prefix, + timeout=timeout, + host=host, + collection_name=collection_name, + location=location, + dimension=dimension, + top_k=top_k, + flush_interval_sec=flush_interval_sec, + index_params=index_params, + ) else: raise NotFoundError("vector store", name) return vector_base diff --git a/gptcache/manager/vector_data/qdrant.py b/gptcache/manager/vector_data/qdrant.py new file mode 100644 index 00000000..93bcf99e --- /dev/null +++ b/gptcache/manager/vector_data/qdrant.py @@ -0,0 +1,107 @@ +from typing import List, Optional +import numpy as np + +from gptcache.utils import import_qdrant +from gptcache.utils.log import gptcache_log +from gptcache.manager.vector_data.base import VectorBase, VectorData + +import_qdrant() + +from qdrant_client import QdrantClient # pylint: disable=C0413 +from qdrant_client.models import PointStruct, HnswConfigDiff, VectorParams, OptimizersConfigDiff, \ + Distance # pylint: disable=C0413 + + +class QdrantVectorStore(VectorBase): + + def __init__( + self, + url: Optional[str] = None, + port: Optional[int] = 6333, + grpc_port: int = 6334, + prefer_grpc: bool = False, + https: Optional[bool] = None, + api_key: Optional[str] = None, + prefix: Optional[str] = None, + timeout: Optional[float] = None, + host: Optional[str] = None, + collection_name: Optional[str] = "gptcache", + location: Optional[str] = "./qdrant", + dimension: int = 0, + top_k: int = 1, + flush_interval_sec: int = 5, + index_params: Optional[dict] = None, + ): + if dimension <= 0: + raise ValueError( + f"invalid `dim` param: {dimension} in the Qdrant vector store." + ) + self._client: QdrantClient + self._collection_name = collection_name + self._in_memory = location == ":memory:" + self.dimension = dimension + self.top_k = top_k + if self._in_memory or location is not None: + self._create_local(location) + else: + self._create_remote(url, port, api_key, timeout, host, grpc_port, prefer_grpc, prefix, https) + self._create_collection(collection_name, flush_interval_sec, index_params) + + def _create_local(self, location): + self._client = QdrantClient(location=location) + + def _create_remote(self, url, port, api_key, timeout, host, grpc_port, prefer_grpc, prefix, https): + self._client = QdrantClient( + url=url, + port=port, + api_key=api_key, + timeout=timeout, + host=host, + grpc_port=grpc_port, + prefer_grpc=prefer_grpc, + prefix=prefix, + https=https, + ) + + def _create_collection(self, collection_name: str, flush_interval_sec: int, index_params: Optional[dict] = None): + hnsw_config = HnswConfigDiff(**(index_params or {})) + vectors_config = VectorParams(size=self.dimension, distance=Distance.COSINE, + hnsw_config=hnsw_config) + optimizers_config = OptimizersConfigDiff(deleted_threshold=0.2, vacuum_min_vector_number=1000, + flush_interval_sec=flush_interval_sec) + # check if the collection exists + existing_collections = self._client.get_collections() + for existing_collection in existing_collections.collections: + if existing_collection.name == collection_name: + gptcache_log.warning("The %s collection already exists, and it will be used directly.", collection_name) + break + else: + self._client.create_collection(collection_name=collection_name, vectors_config=vectors_config, + optimizers_config=optimizers_config) + + def mul_add(self, datas: List[VectorData]): + points = [PointStruct(id=d.id, vector=d.data.reshape(-1).tolist()) for d in datas] + self._client.upsert(collection_name=self._collection_name, points=points, wait=False) + + def search(self, data: np.ndarray, top_k: int = -1): + if top_k == -1: + top_k = self.top_k + reshaped_data = data.reshape(-1).tolist() + search_result = self._client.search(collection_name=self._collection_name, query_vector=reshaped_data, + limit=top_k) + return list(map(lambda x: (x.score, x.id), search_result)) + + def delete(self, ids: List[str]): + self._client.delete(collection_name=self._collection_name, points_selector=ids) + + def rebuild(self, ids=None): # pylint: disable=unused-argument + optimizers_config = OptimizersConfigDiff(deleted_threshold=0.2, vacuum_min_vector_number=1000) + self._client.update_collection(collection_name=self._collection_name, optimizer_config=optimizers_config) + + def flush(self): + # no need to flush manually as qdrant flushes automatically based on the optimizers_config for remote Qdrant + pass + + + def close(self): + self.flush() diff --git a/gptcache/utils/__init__.py b/gptcache/utils/__init__.py index ae3d7b47..7a919721 100644 --- a/gptcache/utils/__init__.py +++ b/gptcache/utils/__init__.py @@ -38,7 +38,8 @@ "import_paddlenlp", "import_tiktoken", "import_fastapi", - "import_redis" + "import_redis", + "import_qdrant" ] import importlib.util @@ -65,6 +66,10 @@ def import_milvus_lite(): _check_library("milvus") +def import_qdrant(): + _check_library("qdrant_client") + + def import_sbert(): _check_library("sentence_transformers", package="sentence-transformers") diff --git a/tests/unit_tests/manager/test_qdrant.py b/tests/unit_tests/manager/test_qdrant.py new file mode 100644 index 00000000..2dafe06f --- /dev/null +++ b/tests/unit_tests/manager/test_qdrant.py @@ -0,0 +1,33 @@ +import os +import unittest + +import numpy as np + +from gptcache.manager.vector_data import VectorBase +from gptcache.manager.vector_data.base import VectorData + + +class TestQdrant(unittest.TestCase): + def test_normal(self): + size = 10 + dim = 2 + top_k = 10 + qdrant = VectorBase( + "qdrant", + top_k=top_k, + dimension=dim, + location=":memory:" + ) + data = np.random.randn(size, dim).astype(np.float32) + qdrant.mul_add([VectorData(id=i, data=v) for v, i in zip(data, range(size))]) + search_result = qdrant.search(data[0], top_k) + self.assertEqual(len(search_result), top_k) + qdrant.mul_add([VectorData(id=size, data=data[0])]) + ret = qdrant.search(data[0]) + self.assertIn(ret[0][1], [0, size]) + self.assertIn(ret[1][1], [0, size]) + qdrant.delete([0, 1, 2, 3, 4, 5, size]) + ret = qdrant.search(data[0]) + self.assertNotIn(ret[0][1], [0, size]) + qdrant.rebuild() + qdrant.close()