Skip to content

Commit

Permalink
added support for lancedb
Browse files Browse the repository at this point in the history
Signed-off-by: akashmangoai <[email protected]>
  • Loading branch information
akashmangoai committed Sep 1, 2024
1 parent bae7ffe commit 61e4454
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ The **Vector Store** module helps find the K most similar requests from the inpu
- [x] Support [DocArray](https://github.com/docarray/docarray), DocArray is a library for representing, sending and storing multi-modal data, perfect for Machine Learning applications.
- [x] Support qdrant
- [x] Support weaviate
- [x] Support [LanceDB](https://github.com/lancedb/lancedb), a developer-friendly, serverless vector database for AI applications. Easily add long-term memory to your LLM apps!
- [ ] Support other vector databases.
- **Cache Manager**:
The **Cache Manager** is responsible for controlling the operation of both the **Cache Storage** and **Vector Store**.
Expand Down
1 change: 1 addition & 0 deletions examples/data_manager/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def run():
'docarray',
'redis',
'weaviate',
'lancedb',
]
for vector_store in vector_stores:
cache_base = CacheBase('sqlite')
Expand Down
85 changes: 85 additions & 0 deletions gptcache/manager/vector_data/lancedb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from typing import List, Optional

import numpy as np

from gptcache.manager.vector_data.base import VectorBase, VectorData
from gptcache.utils import import_lancedb, import_torch

# Run the dependency helpers BEFORE importing the optional packages; importing
# lancedb first would raise ImportError before the helper ever gets to run.
import_torch()
import_lancedb()

import lancedb  # pylint: disable=C0413
import pyarrow as pa  # pylint: disable=C0413

class LanceDB(VectorBase):
    """Vector store: LanceDB

    :param persist_directory: The directory to persist, defaults to '/tmp/lancedb'.
        ``None`` is treated the same as the default.
    :type persist_directory: str
    :param table_name: The name of the table in LanceDB, defaults to 'gptcache'.
    :type table_name: str
    :param top_k: The number of the vectors results to return, defaults to 1.
    :type top_k: int
    :param dimension: The length of every stored vector, defaults to 10. It is
        only used when the table does not exist yet and has to be created; an
        existing table keeps the schema it was created with.
    :type dimension: int
    """

    _DEFAULT_PERSIST_DIRECTORY = "/tmp/lancedb"

    def __init__(
        self,
        persist_directory: Optional[str] = "/tmp/lancedb",
        table_name: str = "gptcache",
        top_k: int = 1,
        dimension: int = 10,
    ):
        if dimension <= 0:
            raise ValueError(f"dimension must be a positive integer, got {dimension}")

        # Callers may pass None explicitly (no directory configured); fall back
        # to the documented default instead of handing None to lancedb.connect.
        if persist_directory is None:
            persist_directory = self._DEFAULT_PERSIST_DIRECTORY

        self._persist_directory = persist_directory
        self._table_name = table_name
        self._top_k = top_k
        self._dimension = dimension

        # Initialize LanceDB database
        self._db = lancedb.connect(self._persist_directory)

        # Initialize or open table
        if self._table_name in self._db.table_names():
            self._table = self._db.open_table(self._table_name)
        else:
            # Ids are stored as strings; vectors as fixed-size float32 lists.
            schema = pa.schema([
                pa.field("id", pa.string()),
                pa.field("vector", pa.list_(pa.float32(), list_size=self._dimension)),
            ])
            self._table = self._db.create_table(self._table_name, schema=schema)

    def mul_add(self, datas: List[VectorData]):
        """Add multiple vectors to the LanceDB table.

        :param datas: the vectors to store; an empty list is a no-op.
        :type datas: List[VectorData]
        """
        rows = [
            {"id": str(item.id), "vector": np.asarray(item.data).reshape(-1).tolist()}
            for item in datas
        ]
        if not rows:
            return
        self._table.add(rows)

    def search(self, data: np.ndarray, top_k: int = -1):
        """Search for the most similar vectors in the LanceDB table.

        :param data: the query vector.
        :type data: np.ndarray
        :param top_k: number of results to return; -1 means use the instance default.
        :type top_k: int
        :return: a list of ``(distance, id)`` tuples, empty when the table has no rows.
        """
        if self._table.count_rows() == 0:
            return []

        if top_k == -1:
            top_k = self._top_k

        query = np.asarray(data).reshape(-1).tolist()
        hits = self._table.search(query).limit(top_k).to_list()
        return [(hit["_distance"], int(hit["id"])) for hit in hits]

    def delete(self, ids: List[int]):
        """Delete vectors from the LanceDB table based on IDs.

        :param ids: ids of the vectors to remove; an empty list is a no-op.
        :type ids: List[int]
        """
        # Ids are stored as strings, so quote them (doubling any embedded quote)
        # and remove them all with a single predicate rather than one call per id.
        quoted = ["'{}'".format(str(item_id).replace("'", "''")) for item_id in ids]
        if not quoted:
            return
        self._table.delete("id IN ({})".format(", ".join(quoted)))

    def rebuild(self, ids: Optional[List[int]] = None):
        """Rebuild the index, if applicable. Currently a no-op that returns True."""
        return True

    def flush(self):
        """Flush changes to disk (if necessary). Currently a no-op."""

    def close(self):
        """Close the connection to LanceDB. Currently a no-op."""

    def count(self):
        """Return the total number of vectors in the table."""
        return self._table.count_rows()
23 changes: 23 additions & 0 deletions gptcache/manager/vector_data/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class VectorBase:
`Chromadb` (with `top_k`, `client_settings`, `persist_directory`, `collection_name` params),
`Hnswlib` (with `index_file_path`, `dimension`, `top_k`, `max_elements` params).
`pgvector` (with `url`, `collection_name`, `index_params`, `top_k`, `dimension` params).
`lancedb` (with `persist_directory`, `table_name`, `top_k` params).
:param name: the name of the vectorbase, it is support 'milvus', 'faiss', 'chromadb', 'hnswlib' now.
:type name: str
Expand Down Expand Up @@ -89,6 +90,14 @@ class VectorBase:
:param persist_directory: the directory to persist, defaults to '.chromadb/' in the current directory.
:type persist_directory: str
:param client_settings: the setting for LanceDB.
:param persist_directory: The directory to persist, defaults to '/tmp/lancedb'.
:type persist_directory: str
:param table_name: The name of the table in LanceDB, defaults to 'gptcache'.
:type table_name: str
:param top_k: The number of the vectors results to return, defaults to 1.
:type top_k: int
:param index_path: the path to hnswlib index, defaults to 'hnswlib_index.bin'.
:type index_path: str
:param max_elements: max_elements of hnswlib, defaults 100000.
Expand Down Expand Up @@ -289,6 +298,20 @@ def get(name, **kwargs):
class_schema=class_schema,
top_k=top_k,
)

elif name == "lancedb":
from gptcache.manager.vector_data.lancedb import LanceDB

persist_directory = kwargs.get("persist_directory", None)
table_name = kwargs.get("table_name", COLLECTION_NAME)
top_k: int = kwargs.get("top_k", TOP_K)

vector_base = LanceDB(
persist_directory=persist_directory,
table_name=table_name,
top_k=top_k,
)

else:
raise NotFoundError("vector store", name)
return vector_base
3 changes: 3 additions & 0 deletions gptcache/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"import_redis",
"import_qdrant",
"import_weaviate",
"import_lancedb",
]

import importlib.util
Expand Down Expand Up @@ -147,6 +148,8 @@ def import_duckdb():
_check_library("duckdb", package="duckdb")
_check_library("duckdb-engine", package="duckdb-engine")

def import_lancedb():
    """Run the shared ``_check_library`` helper for the ``lancedb`` package."""
    _check_library("lancedb", package="lancedb")

def import_sql_client(db_name):
if db_name == "postgresql":
Expand Down
24 changes: 24 additions & 0 deletions tests/unit_tests/manager/test_lancedb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import unittest
import numpy as np
from gptcache.manager import VectorBase
from gptcache.manager.vector_data.base import VectorData

class TestLanceDB(unittest.TestCase):
    """Round-trip test for the LanceDB vector store: add, search, delete, count."""

    def test_normal(self):
        # Local import keeps this block self-contained.
        import tempfile

        # Use a fresh directory per run: a fixed path would reuse the table
        # persisted by a previous run and make the final count wrong.
        with tempfile.TemporaryDirectory() as persist_directory:
            # Initialize the LanceDB store with top_k set to 3
            db = VectorBase("lancedb", persist_directory=persist_directory, top_k=3)

            # Add 100 vectors to the LanceDB
            db.mul_add([VectorData(id=i, data=np.random.sample(10)) for i in range(100)])
            self.assertEqual(db.count(), 100)

            # Perform a search with a random query vector
            search_res = db.search(np.random.sample(10))

            # Check that the search returns 3 results
            self.assertEqual(len(search_res), 3)

            # Delete vectors with specific IDs
            db.delete([1, 3, 5, 7])

            # Check that the count of vectors in the table is now 96
            self.assertEqual(db.count(), 96)

0 comments on commit 61e4454

Please sign in to comment.