From 61e44544151a5b498ec82979a1796670cf68dda3 Mon Sep 17 00:00:00 2001 From: akashmangoai Date: Sun, 1 Sep 2024 21:59:15 +0530 Subject: [PATCH] added support for lancedb Signed-off-by: akashmangoai --- README.md | 1 + examples/data_manager/vector_store.py | 1 + gptcache/manager/vector_data/lancedb.py | 85 ++++++++++++++++++++++++ gptcache/manager/vector_data/manager.py | 23 +++++++ gptcache/utils/__init__.py | 3 + tests/unit_tests/manager/test_lancedb.py | 24 +++++++ 6 files changed, 137 insertions(+) create mode 100644 gptcache/manager/vector_data/lancedb.py create mode 100644 tests/unit_tests/manager/test_lancedb.py diff --git a/README.md b/README.md index c5f6f955..c75e1b06 100644 --- a/README.md +++ b/README.md @@ -360,6 +360,7 @@ The **Vector Store** module helps find the K most similar requests from the inpu - [x] Support [DocArray](https://github.com/docarray/docarray), DocArray is a library for representing, sending and storing multi-modal data, perfect for Machine Learning applications. - [x] Support qdrant - [x] Support weaviate + - [x] Support [LanceDB](https://github.com/lancedb/lancedb), Developer-friendly, serverless vector database for AI applications. Easily add long-term memory to your LLM apps! - [ ] Support other vector databases. - **Cache Manager**: The **Cache Manager** is responsible for controlling the operation of both the **Cache Storage** and **Vector Store**. 
from typing import List, Optional

import numpy as np

from gptcache.manager.vector_data.base import VectorBase, VectorData
from gptcache.utils import import_lancedb, import_torch

import_torch()
import_lancedb()

# Imported only after import_lancedb() has verified the package is available;
# importing earlier would raise before the dependency check could run.
import lancedb
import pyarrow as pa


class LanceDB(VectorBase):
    """Vector store: LanceDB

    :param persist_directory: The directory to persist, defaults to '/tmp/lancedb'.
    :type persist_directory: str
    :param table_name: The name of the table in LanceDB, defaults to 'gptcache'.
    :type table_name: str
    :param top_k: The number of the vectors results to return, defaults to 1.
    :type top_k: int
    :param dimension: The dimensionality of the stored vectors.  When ``None``
        (the default) the schema is inferred from the first batch passed to
        :meth:`mul_add`, so the store works with embeddings of any size instead
        of a hard-coded dimension.
    :type dimension: Optional[int]
    """

    def __init__(
        self,
        persist_directory: Optional[str] = "/tmp/lancedb",
        table_name: str = "gptcache",
        top_k: int = 1,
        dimension: Optional[int] = None,
    ):
        self._persist_directory = persist_directory
        self._table_name = table_name
        self._top_k = top_k
        self._dimension = dimension

        # Connect to (creating if absent) the on-disk LanceDB database.
        self._db = lancedb.connect(self._persist_directory)

        if self._table_name in self._db.table_names():
            self._table = self._db.open_table(self._table_name)
        elif self._dimension is not None:
            # Caller supplied the vector size: create the table eagerly.
            self._table = self._db.create_table(
                self._table_name, schema=self._vector_schema(self._dimension)
            )
        else:
            # Defer creation until the first mul_add(), when the vector
            # dimensionality can be inferred from the data itself.
            self._table = None

    @staticmethod
    def _vector_schema(dimension: int) -> "pa.Schema":
        """Build the Arrow schema for a table of ``dimension``-sized float vectors."""
        return pa.schema(
            [
                pa.field("id", pa.string()),
                pa.field("vector", pa.list_(pa.float32(), list_size=dimension)),
            ]
        )

    def mul_add(self, datas: List[VectorData]):
        """Add multiple vectors to the LanceDB table."""
        rows = [{"id": str(data.id), "vector": data.data.tolist()} for data in datas]
        if not rows:
            return
        if self._table is None:
            # First insertion: create the table using the observed dimension.
            self._table = self._db.create_table(
                self._table_name,
                data=rows,
                schema=self._vector_schema(len(rows[0]["vector"])),
            )
            return
        self._table.add(rows)

    def search(self, data: np.ndarray, top_k: int = -1):
        """Search the table; return a list of ``(distance, id)`` tuples."""
        if self._table is None or len(self._table) == 0:
            return []
        if top_k == -1:
            top_k = self._top_k
        results = self._table.search(data.tolist()).limit(top_k).to_list()
        return [(result["_distance"], int(result["id"])) for result in results]

    def delete(self, ids: List[int]):
        """Delete vectors from the LanceDB table based on IDs."""
        if self._table is None or not ids:
            return
        # A single IN predicate removes the whole batch in one pass instead of
        # issuing one delete scan per id.  Ids are integers rendered as
        # strings, so the quoted values cannot break out of the predicate.
        predicate = "id IN ({})".format(", ".join(f"'{i}'" for i in ids))
        self._table.delete(predicate)

    def rebuild(self, ids: Optional[List[int]] = None):
        """Rebuild the index, if applicable (no-op for LanceDB)."""
        return True

    def flush(self):
        """Flush changes to disk; LanceDB persists each write immediately."""
        pass

    def close(self):
        """Close the connection; lancedb connections need no explicit close."""
        pass

    def count(self):
        """Return the total number of vectors in the table."""
        return len(self._table) if self._table is not None else 0
b/gptcache/manager/vector_data/manager.py index 815fb934..2314654d 100644 --- a/gptcache/manager/vector_data/manager.py +++ b/gptcache/manager/vector_data/manager.py @@ -42,6 +42,7 @@ class VectorBase: `Chromadb` (with `top_k`, `client_settings`, `persist_directory`, `collection_name` params), `Hnswlib` (with `index_file_path`, `dimension`, `top_k`, `max_elements` params). `pgvector` (with `url`, `collection_name`, `index_params`, `top_k`, `dimension` params). + `lancedb` (with `persist_directory`, `table_name`, `top_k` params). :param name: the name of the vectorbase, it is support 'milvus', 'faiss', 'chromadb', 'hnswlib' now. :type name: str @@ -89,6 +90,13 @@ class VectorBase: :param persist_directory: the directory to persist, defaults to '.chromadb/' in the current directory. :type persist_directory: str + :param persist_directory: The directory to persist, defaults to '/tmp/lancedb'. + :type persist_directory: str + :param table_name: The name of the table in LanceDB, defaults to 'gptcache'. + :type table_name: str + :param top_k: The number of the vectors results to return, defaults to 1. + :type top_k: int + :param index_path: the path to hnswlib index, defaults to 'hnswlib_index.bin'. :type index_path: str :param max_elements: max_elements of hnswlib, defaults 100000. 
import tempfile
import unittest

import numpy as np

from gptcache.manager import VectorBase
from gptcache.manager.vector_data.base import VectorData


class TestLanceDB(unittest.TestCase):
    def test_normal(self):
        # Use a fresh directory per run: LanceDB persists its tables on disk,
        # so a fixed path such as /tmp/test_lancedb would reopen the table left
        # behind by a previous run and break the count assertions below.
        with tempfile.TemporaryDirectory() as persist_dir:
            # Initialize the LanceDB store with top_k set to 3.
            db = VectorBase("lancedb", persist_directory=persist_dir, top_k=3)

            # Add 100 ten-dimensional vectors with ids 0..99.
            db.mul_add([VectorData(id=i, data=np.random.sample(10)) for i in range(100)])

            # A search with a random query vector must honour top_k=3.
            search_res = db.search(np.random.sample(10))
            self.assertEqual(len(search_res), 3)

            # Deleting four ids must leave exactly 96 vectors behind.
            db.delete([1, 3, 5, 7])
            self.assertEqual(db.count(), 96)