from typing import List, Optional

import numpy as np

from gptcache.utils import import_qdrant
from gptcache.utils.log import gptcache_log
from gptcache.manager.vector_data.base import VectorBase, VectorData

import_qdrant()

from qdrant_client import QdrantClient  # pylint: disable=C0413
from qdrant_client.models import (  # pylint: disable=C0413
    Distance,
    HnswConfigDiff,
    OptimizersConfigDiff,
    PointStruct,
    VectorParams,
)


class QdrantVectorStore(VectorBase):
    """Qdrant-backed vector store for GPTCache.

    Connects either to a local/embedded Qdrant instance (when ``location`` is
    given, including the special ``":memory:"`` value) or to a remote Qdrant
    server (via ``url``/``host``/``port``).

    :param url: URL of a remote Qdrant server; ignored for local mode.
    :param port: HTTP port of the remote server (default 6333).
    :param grpc_port: gRPC port of the remote server (default 6334).
    :param prefer_grpc: prefer gRPC over HTTP for the remote client.
    :param https: whether to use TLS for the remote connection.
    :param api_key: API key for Qdrant Cloud / secured servers.
    :param prefix: URL path prefix of the remote server.
    :param timeout: request timeout in seconds for the remote client.
    :param host: hostname of the remote server (alternative to ``url``).
    :param collection_name: name of the collection to use or create.
    :param location: local storage path, or ``":memory:"`` for in-memory mode.
    :param dimension: dimensionality of the stored vectors; must be > 0.
    :param top_k: default number of results returned by :meth:`search`.
    :param flush_interval_sec: optimizer flush interval for the collection.
    :param index_params: keyword arguments forwarded to ``HnswConfigDiff``.
    :raises ValueError: if ``dimension`` is not a positive integer.
    """

    def __init__(
        self,
        url: Optional[str] = None,
        port: Optional[int] = 6333,
        grpc_port: int = 6334,
        prefer_grpc: bool = False,
        https: Optional[bool] = None,
        api_key: Optional[str] = None,
        prefix: Optional[str] = None,
        timeout: Optional[float] = None,
        host: Optional[str] = None,
        collection_name: Optional[str] = "gptcache",
        location: Optional[str] = "./qdrant",
        dimension: int = 0,
        top_k: int = 1,
        flush_interval_sec: int = 5,
        index_params: Optional[dict] = None,
    ):
        if dimension <= 0:
            # Message names the actual parameter (`dimension`, not `dim`).
            raise ValueError(
                f"invalid `dimension` param: {dimension} in the Qdrant vector store."
            )
        self._client: QdrantClient
        self._collection_name = collection_name
        self._in_memory = location == ":memory:"
        self.dimension = dimension
        self.top_k = top_k
        # Any non-None `location` (including ":memory:") selects the embedded
        # local client; only when it is None do we dial a remote server.
        if self._in_memory or location is not None:
            self._create_local(location)
        else:
            self._create_remote(url, port, api_key, timeout, host, grpc_port,
                                prefer_grpc, prefix, https)
        self._create_collection(collection_name, flush_interval_sec, index_params)

    def _create_local(self, location):
        """Build an embedded (local or in-memory) Qdrant client."""
        self._client = QdrantClient(location=location)

    def _create_remote(self, url, port, api_key, timeout, host, grpc_port,
                       prefer_grpc, prefix, https):
        """Build a client connected to a remote Qdrant server."""
        self._client = QdrantClient(
            url=url,
            port=port,
            api_key=api_key,
            timeout=timeout,
            host=host,
            grpc_port=grpc_port,
            prefer_grpc=prefer_grpc,
            prefix=prefix,
            https=https,
        )

    def _create_collection(self, collection_name: str, flush_interval_sec: int,
                           index_params: Optional[dict] = None):
        """Create the collection if it does not already exist.

        An existing collection with the same name is reused as-is (with a
        warning); its configuration is NOT updated to match the arguments.
        """
        hnsw_config = HnswConfigDiff(**(index_params or {}))
        vectors_config = VectorParams(size=self.dimension,
                                      distance=Distance.COSINE,
                                      hnsw_config=hnsw_config)
        optimizers_config = OptimizersConfigDiff(
            deleted_threshold=0.2,
            vacuum_min_vector_number=1000,
            flush_interval_sec=flush_interval_sec,
        )
        # for/else: the `else` runs only when the loop finishes without
        # `break`, i.e. no existing collection matched the requested name.
        existing_collections = self._client.get_collections()
        for existing_collection in existing_collections.collections:
            if existing_collection.name == collection_name:
                gptcache_log.warning(
                    "The %s collection already exists, and it will be used directly.",
                    collection_name,
                )
                break
        else:
            self._client.create_collection(collection_name=collection_name,
                                           vectors_config=vectors_config,
                                           optimizers_config=optimizers_config)

    def mul_add(self, datas: List[VectorData]):
        """Upsert a batch of vectors, one point per ``VectorData``."""
        points = [
            PointStruct(id=d.id, vector=d.data.reshape(-1).tolist())
            for d in datas
        ]
        # wait=False: do not block until the server has applied the upsert.
        self._client.upsert(collection_name=self._collection_name,
                            points=points, wait=False)

    def search(self, data: np.ndarray, top_k: int = -1):
        """Return the ``top_k`` nearest points to ``data`` as (score, id) tuples.

        ``top_k == -1`` (the default) falls back to the instance's ``top_k``.
        """
        if top_k == -1:
            top_k = self.top_k
        # Qdrant expects a flat list of floats for the query vector.
        reshaped_data = data.reshape(-1).tolist()
        search_result = self._client.search(collection_name=self._collection_name,
                                            query_vector=reshaped_data,
                                            limit=top_k)
        return [(hit.score, hit.id) for hit in search_result]

    def delete(self, ids: List[str]):
        """Delete the points with the given ids from the collection."""
        self._client.delete(collection_name=self._collection_name,
                            points_selector=ids)

    def rebuild(self, ids=None):  # pylint: disable=unused-argument
        """Trigger a re-optimization of the collection; ``ids`` is ignored."""
        optimizers_config = OptimizersConfigDiff(deleted_threshold=0.2,
                                                 vacuum_min_vector_number=1000)
        # NOTE(review): newer qdrant-client releases rename this kwarg to
        # `optimizers_config` — confirm against the pinned client version.
        self._client.update_collection(collection_name=self._collection_name,
                                       optimizer_config=optimizers_config)

    def flush(self):
        # No manual flush needed: Qdrant flushes automatically based on the
        # collection's optimizers_config (flush_interval_sec).
        pass

    def close(self):
        """Release resources; currently just delegates to :meth:`flush`."""
        self.flush()
import unittest

import numpy as np

from gptcache.manager.vector_data import VectorBase
from gptcache.manager.vector_data.base import VectorData


class TestQdrant(unittest.TestCase):
    """End-to-end exercise of the Qdrant vector store against an in-memory instance."""

    def test_normal(self):
        size = 10
        dim = 2
        top_k = 10
        qdrant = VectorBase(
            "qdrant",
            top_k=top_k,
            dimension=dim,
            location=":memory:",
        )
        # Insert `size` random vectors with ids 0..size-1.
        data = np.random.randn(size, dim).astype(np.float32)
        qdrant.mul_add([VectorData(id=i, data=v) for v, i in zip(data, range(size))])
        search_result = qdrant.search(data[0], top_k)
        self.assertEqual(len(search_result), top_k)
        # Re-insert the first vector under a fresh id: both copies are equally
        # close to the query, so they should occupy the top two slots.
        qdrant.mul_add([VectorData(id=size, data=data[0])])
        ret = qdrant.search(data[0])
        # search() yields (score, id) tuples, hence the id lives at index 1.
        self.assertIn(ret[0][1], [0, size])
        self.assertIn(ret[1][1], [0, size])
        # After deleting both copies, neither id may appear as the best hit.
        qdrant.delete([0, 1, 2, 3, 4, 5, size])
        ret = qdrant.search(data[0])
        self.assertNotIn(ret[0][1], [0, size])
        qdrant.rebuild()
        qdrant.close()