Merge pull request #362 from Anush008/main
feat: Qdrant as a supported vector store
zainhoda authored Apr 17, 2024
2 parents 4ddc61e + 4ec4420 commit 86b516d
Showing 4 changed files with 359 additions and 1 deletion.
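
For context, a minimal usage sketch of the new store, mirroring the test added below (the OpenAI key is a placeholder and the training DDL is illustrative):

from qdrant_client import QdrantClient

from vanna.openai.openai_chat import OpenAI_Chat
from vanna.qdrant import Qdrant_VectorStore


class MyVanna(Qdrant_VectorStore, OpenAI_Chat):
    def __init__(self, config=None):
        Qdrant_VectorStore.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)


# ":memory:" runs an ephemeral local instance; point QdrantClient at a real
# server (e.g. QdrantClient(url="http://localhost:6333")) in production.
vn = MyVanna(config={"client": QdrantClient(":memory:"), "api_key": "sk-...", "model": "gpt-3.5-turbo"})
vn.connect_to_sqlite("https://vanna.ai/Chinook.sqlite")
vn.train(ddl="CREATE TABLE customers (id INTEGER, name TEXT)")
print(vn.generate_sql("What are the top 7 customers by sales?"))
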
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -32,7 +32,7 @@ bigquery = ["google-cloud-bigquery"]
snowflake = ["snowflake-connector-python"]
duckdb = ["duckdb"]
google = ["google-generativeai", "google-cloud-aiplatform"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed"]
test = ["tox"]
chromadb = ["chromadb"]
openai = ["openai"]
@@ -41,3 +41,4 @@ anthropic = ["anthropic"]
gemini = ["google-generativeai"]
marqo = ["marqo"]
zhipuai = ["zhipuai"]
qdrant = ["qdrant-client", "fastembed"]
3 changes: 3 additions & 0 deletions src/vanna/qdrant/__init__.py
@@ -0,0 +1,3 @@
from .qdrant import Qdrant_VectorStore

__all__ = ["Qdrant_VectorStore"]
324 changes: 324 additions & 0 deletions src/vanna/qdrant/qdrant.py
@@ -0,0 +1,324 @@
from functools import cached_property
from typing import List, Tuple

import pandas as pd
from qdrant_client import QdrantClient, grpc, models

from ..base import VannaBase
from ..utils import deterministic_uuid

DOCUMENTATION_COLLECTION_NAME = "documentation"
DDL_COLLECTION_NAME = "ddl"
SQL_COLLECTION_NAME = "sql"
SCROLL_SIZE = 1000

ID_SUFFIXES = {
DDL_COLLECTION_NAME: "ddl",
DOCUMENTATION_COLLECTION_NAME: "doc",
SQL_COLLECTION_NAME: "sql",
}
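# Point ids exposed to callers take the form "<uuid>-<suffix>"; the suffix
# maps an id back to its source collection (see _format_point_id and
# _parse_point_id).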


class Qdrant_VectorStore(VannaBase):
"""Vectorstore implementation using Qdrant - https://qdrant.tech/"""

def __init__(self, config=None):
"""
Vectorstore implementation using Qdrant - https://qdrant.tech/
Args:
- config (dict, optional): Dictionary of `Qdrant_VectorStore config` options. Defaults to `{}`.
- client: A `qdrant_client.QdrantClient` instance. Overrides other config options.
- location: If `":memory:"` - use in-memory Qdrant instance. If `str` - use it as a `url` parameter.
- url: Either host or str of "Optional[scheme], host, Optional[port], Optional[prefix]". E.g. `"http://localhost:6333"`.
- prefer_grpc: If `true` - use the gRPC interface whenever possible in custom methods.
- https: If `true` - use HTTPS (SSL) protocol. Default: `None`.
- api_key: API key for authentication in Qdrant Cloud. Default: `None`
- timeout: Timeout for REST and gRPC API requests. Defaults to 5 seconds for REST and unlimited for gRPC.
- path: Persistence path for QdrantLocal. Default: `None`.
- prefix: Prefix to the REST URL paths. Example: `service/v1` will result in `http://localhost:6333/service/v1/{qdrant-endpoint}`.
- n_results: Number of results to return from similarity search. Defaults to 10.
- fastembed_model: [Model](https://qdrant.github.io/fastembed/examples/Supported_Models/#supported-text-embedding-models) to use for `fastembed.TextEmbedding`.
Defaults to `"BAAI/bge-small-en-v1.5"`.
- collection_params: Additional parameters to pass to `qdrant_client.QdrantClient#create_collection()` method.
- distance_metric: Distance metric to use when creating collections. Defaults to `qdrant_client.models.Distance.COSINE`.
Raises:
TypeError: If config["client"] is not a `qdrant_client.QdrantClient` instance
"""
config = config or {}
VannaBase.__init__(self, config=config)
client = config.get("client")

if client is None:
self._client = QdrantClient(
location=config.get("location", None),
url=config.get("url", None),
prefer_grpc=config.get("prefer_grpc", False),
https=config.get("https", None),
api_key=config.get("api_key", None),
timeout=config.get("timeout", None),
path=config.get("path", None),
prefix=config.get("prefix", None),
)
elif not isinstance(client, QdrantClient):
raise TypeError(
f"Unsupported client of type {client.__class__} was set in config"
)
else:
self._client = client

self.n_results = config.get("n_results", 10)
self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")
self.collection_params = config.get("collection_params", {})
self.distance_metric = config.get("distance_metric", models.Distance.COSINE)

self._setup_collections()

def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
question_answer = "Question: {0}\n\nSQL: {1}".format(question, sql)
id = deterministic_uuid(question_answer)

self._client.upsert(
SQL_COLLECTION_NAME,
points=[
models.PointStruct(
id=id,
vector=self.generate_embedding(question_answer),
payload={
"question": question,
"sql": sql,
},
)
],
)

return self._format_point_id(id, SQL_COLLECTION_NAME)

def add_ddl(self, ddl: str, **kwargs) -> str:
id = deterministic_uuid(ddl)
self._client.upsert(
DDL_COLLECTION_NAME,
points=[
models.PointStruct(
id=id,
vector=self.generate_embedding(ddl),
payload={
"ddl": ddl,
},
)
],
)
return self._format_point_id(id, DDL_COLLECTION_NAME)

def add_documentation(self, documentation: str, **kwargs) -> str:
id = deterministic_uuid(documentation)

self._client.upsert(
DOCUMENTATION_COLLECTION_NAME,
points=[
models.PointStruct(
id=id,
vector=self.generate_embedding(documentation),
payload={
"documentation": documentation,
},
)
],
)

return self._format_point_id(id, DOCUMENTATION_COLLECTION_NAME)

def get_training_data(self, **kwargs) -> pd.DataFrame:
df = pd.DataFrame()

if sql_data := self._get_all_points(SQL_COLLECTION_NAME):
question_list = [data.payload["question"] for data in sql_data]
sql_list = [data.payload["sql"] for data in sql_data]
id_list = [
self._format_point_id(data.id, SQL_COLLECTION_NAME) for data in sql_data
]

df_sql = pd.DataFrame(
{
"id": id_list,
"question": question_list,
"content": sql_list,
}
)

df_sql["training_data_type"] = "sql"

df = pd.concat([df, df_sql])

if ddl_data := self._get_all_points(DDL_COLLECTION_NAME):
ddl_list = [data.payload["ddl"] for data in ddl_data]
id_list = [
self._format_point_id(data.id, DDL_COLLECTION_NAME) for data in ddl_data
]

df_ddl = pd.DataFrame(
{
"id": id_list,
"question": [None for _ in ddl_list],
"content": ddl_list,
}
)

df_ddl["training_data_type"] = "ddl"

df = pd.concat([df, df_ddl])

if doc_data := self._get_all_points(DOCUMENTATION_COLLECTION_NAME):
document_list = [data.payload["documentation"] for data in doc_data]
id_list = [
self._format_point_id(data.id, DOCUMENTATION_COLLECTION_NAME)
for data in doc_data
]

df_doc = pd.DataFrame(
{
"id": id_list,
"question": [None for _ in document_list],
"content": document_list,
}
)

df_doc["training_data_type"] = "documentation"

df = pd.concat([df, df_doc])

return df

def remove_training_data(self, id: str, **kwargs) -> bool:
try:
id, collection_name = self._parse_point_id(id)
self._client.delete(collection_name, points_selector=[id])
return True
except ValueError:
return False

def remove_collection(self, collection_name: str) -> bool:
"""
Resets a collection to an empty state by deleting and recreating it.
Args:
collection_name (str): One of "sql", "ddl" or "documentation"
Returns:
bool: True if the collection was reset, False if the name is not recognized
"""
if collection_name in ID_SUFFIXES:
self._client.delete_collection(collection_name)
self._setup_collections()
return True
else:
return False

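# Embed a short probe string once to infer the model's vector size;
# cached_property memoizes the result so the model is only queried once.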
@cached_property
def embeddings_dimension(self):
return len(self.generate_embedding("ABCDEF"))

def get_similar_question_sql(self, question: str, **kwargs) -> list:
results = self._client.search(
SQL_COLLECTION_NAME,
query_vector=self.generate_embedding(question),
limit=self.n_results,
with_payload=True,
)

return [dict(result.payload) for result in results]

def get_related_ddl(self, question: str, **kwargs) -> list:
results = self._client.search(
DDL_COLLECTION_NAME,
query_vector=self.generate_embedding(question),
limit=self.n_results,
with_payload=True,
)

return [result.payload["ddl"] for result in results]

def get_related_documentation(self, question: str, **kwargs) -> list:
results = self._client.search(
DOCUMENTATION_COLLECTION_NAME,
query_vector=self.generate_embedding(question),
limit=self.n_results,
with_payload=True,
)

return [result.payload["documentation"] for result in results]

def generate_embedding(self, data: str, **kwargs) -> List[float]:
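# Relies on qdrant-client's optional FastEmbed integration;
# _get_or_init_model is an internal helper of its fastembed mixin that
# lazily loads the configured embedding model.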
embedding_model = self._client._get_or_init_model(
model_name=self.fastembed_model
)
embedding = next(embedding_model.embed(data))

return embedding.tolist()

def _get_all_points(self, collection_name: str):
results: List[models.Record] = []
next_offset = None
stop_scrolling = False
while not stop_scrolling:
records, next_offset = self._client.scroll(
collection_name,
limit=SCROLL_SIZE,
offset=next_offset,
with_payload=True,
with_vectors=False,
)
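# When scrolling over gRPC, an exhausted cursor comes back as a
# zero-valued PointId rather than None, so treat that sentinel as
# the end of the collection.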
stop_scrolling = next_offset is None or (
isinstance(next_offset, grpc.PointId)
and next_offset.num == 0
and next_offset.uuid == ""
)

results.extend(records)

return results

def _setup_collections(self):
if not self._client.collection_exists(SQL_COLLECTION_NAME):
self._client.create_collection(
collection_name=SQL_COLLECTION_NAME,
vectors_config=models.VectorParams(
size=self.embeddings_dimension,
distance=self.distance_metric,
),
**self.collection_params,
)

if not self._client.collection_exists(DDL_COLLECTION_NAME):
self._client.create_collection(
collection_name=DDL_COLLECTION_NAME,
vectors_config=models.VectorParams(
size=self.embeddings_dimension,
distance=self.distance_metric,
),
**self.collection_params,
)
if not self._client.collection_exists(DOCUMENTATION_COLLECTION_NAME):
self._client.create_collection(
collection_name=DOCUMENTATION_COLLECTION_NAME,
vectors_config=models.VectorParams(
size=self.embeddings_dimension,
distance=self.distance_metric,
),
**self.collection_params,
)

def _format_point_id(self, id: str, collection_name: str) -> str:
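# e.g. "<uuid>-sql" for a point in the sql collection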
return "{0}-{1}".format(id, ID_SUFFIXES[collection_name])

def _parse_point_id(self, id: str) -> Tuple[str, str]:
id, suffix = id.rsplit("-", 1)
for collection_name, collection_suffix in ID_SUFFIXES.items():
if suffix == collection_suffix:
return id, collection_name
raise ValueError(f"Invalid id {id}")
30 changes: 30 additions & 0 deletions tests/test_vanna.py
@@ -59,6 +59,31 @@ def test_vn_default():
df = vn_default.run_sql(sql)
assert len(df) == 6

from vanna.qdrant import Qdrant_VectorStore


class VannaQdrant(Qdrant_VectorStore, OpenAI_Chat):
def __init__(self, config=None):
Qdrant_VectorStore.__init__(self, config=config)
OpenAI_Chat.__init__(self, config=config)

from qdrant_client import QdrantClient

qdrant_memory_client = QdrantClient(":memory:")

vn_qdrant = VannaQdrant(config={'client': qdrant_memory_client, 'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'})
vn_qdrant.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')

def test_vn_qdrant():
df_ddl = vn_qdrant.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null")

for ddl in df_ddl['sql'].to_list():
vn_qdrant.train(ddl=ddl)

sql = vn_qdrant.generate_sql("What are the top 7 customers by sales?")
df = vn_qdrant.run_sql(sql)
assert len(df) == 7

from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
from vanna.openai.openai_chat import OpenAI_Chat

@@ -72,6 +97,11 @@ def __init__(self, config=None):
vn_chroma.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')

def test_vn_chroma():
existing_training_data = vn_chroma.get_training_data()
if len(existing_training_data) > 0:
for _, training_data in existing_training_data.iterrows():
vn_chroma.remove_training_data(training_data['id'])

df_ddl = vn_chroma.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null")

for ddl in df_ddl['sql'].to_list():
