Skip to content

Commit

Permalink
Merge pull request #496 from zc277584121/milvus
Browse files Browse the repository at this point in the history
Add Milvus vectorstore support
  • Loading branch information
zainhoda authored Jun 18, 2024
2 parents e46b2f1 + d23bd5e commit 254ff6a
Show file tree
Hide file tree
Showing 5 changed files with 338 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ bigquery = ["google-cloud-bigquery"]
snowflake = ["snowflake-connector-python"]
duckdb = ["duckdb"]
google = ["google-generativeai", "google-cloud-aiplatform"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]"]
test = ["tox"]
chromadb = ["chromadb"]
openai = ["openai"]
Expand All @@ -48,3 +48,4 @@ vllm = ["vllm"]
pinecone = ["pinecone-client", "fastembed"]
opensearch = ["opensearch-py", "opensearch-dsl"]
hf = ["transformers"]
milvus = ["pymilvus[model]"]
1 change: 1 addition & 0 deletions src/vanna/milvus/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .milvus_vector import Milvus_VectorStore
305 changes: 305 additions & 0 deletions src/vanna/milvus/milvus_vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,305 @@
import uuid
from typing import List

import pandas as pd
from pymilvus import DataType, MilvusClient, model

from ..base import VannaBase

# Setting the URI as a local file, e.g.`./milvus.db`,
# is the most convenient method, as it automatically utilizes Milvus Lite
# to store all data in this file.
#
# If you have large scale of data such as more than a million docs, we
# recommend setting up a more performant Milvus server on docker or kubernetes.
# When using this setup, please use the server URI,
# e.g.`http://localhost:19530`, as your URI.

DEFAULT_MILVUS_URI = "./milvus.db"
# DEFAULT_MILVUS_URI = "http://localhost:19530"

MAX_LIMIT_SIZE = 10_000


class Milvus_VectorStore(VannaBase):
"""
Vectorstore implementation using Milvus - https://milvus.io/docs/quickstart.md
Args:
- config (dict, optional): Dictionary of `Milvus_VectorStore config` options. Defaults to `None`.
- milvus_client: A `pymilvus.MilvusClient` instance.
- embedding_function:
A `milvus_model.base.BaseEmbeddingFunction` instance. Defaults to `DefaultEmbeddingFunction()`.
For more models, please refer to:
https://milvus.io/docs/embeddings.md
"""
def __init__(self, config=None):
VannaBase.__init__(self, config=config)

if "milvus_client" in config:
self.milvus_client = config["milvus_client"]
else:
self.milvus_client = MilvusClient(uri=DEFAULT_MILVUS_URI)

if "embedding_function" in config:
self.embedding_function = config.get("embedding_function")
else:
self.embedding_function = model.DefaultEmbeddingFunction()
self._embedding_dim = self.embedding_function.encode_documents(["foo"])[0].shape[0]
self._create_collections()
self.n_results = config.get("n_results", 10)

def _create_collections(self):
self._create_sql_collection("vannasql")
self._create_ddl_collection("vannaddl")
self._create_doc_collection("vannadoc")


def generate_embedding(self, data: str, **kwargs) -> List[float]:
return self.embedding_function.encode_documents(data).tolist()


def _create_sql_collection(self, name: str):
if not self.milvus_client.has_collection(collection_name=name):
vannasql_schema = MilvusClient.create_schema(
auto_id=False,
enable_dynamic_field=False,
)
vannasql_schema.add_field(field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True)
vannasql_schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=65535)
vannasql_schema.add_field(field_name="sql", datatype=DataType.VARCHAR, max_length=65535)
vannasql_schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self._embedding_dim)

vannasql_index_params = self.milvus_client.prepare_index_params()
vannasql_index_params.add_index(
field_name="vector",
index_name="vector",
index_type="AUTOINDEX",
metric_type="L2",
)
self.milvus_client.create_collection(
collection_name=name,
schema=vannasql_schema,
index_params=vannasql_index_params,
consistency_level="Strong"
)

def _create_ddl_collection(self, name: str):
if not self.milvus_client.has_collection(collection_name=name):
vannaddl_schema = MilvusClient.create_schema(
auto_id=False,
enable_dynamic_field=False,
)
vannaddl_schema.add_field(field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True)
vannaddl_schema.add_field(field_name="ddl", datatype=DataType.VARCHAR, max_length=65535)
vannaddl_schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self._embedding_dim)

vannaddl_index_params = self.milvus_client.prepare_index_params()
vannaddl_index_params.add_index(
field_name="vector",
index_name="vector",
index_type="AUTOINDEX",
metric_type="L2",
)
self.milvus_client.create_collection(
collection_name=name,
schema=vannaddl_schema,
index_params=vannaddl_index_params,
consistency_level="Strong"
)

def _create_doc_collection(self, name: str):
if not self.milvus_client.has_collection(collection_name=name):
vannadoc_schema = MilvusClient.create_schema(
auto_id=False,
enable_dynamic_field=False,
)
vannadoc_schema.add_field(field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True)
vannadoc_schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)
vannadoc_schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self._embedding_dim)

vannadoc_index_params = self.milvus_client.prepare_index_params()
vannadoc_index_params.add_index(
field_name="vector",
index_name="vector",
index_type="AUTOINDEX",
metric_type="L2",
)
self.milvus_client.create_collection(
collection_name=name,
schema=vannadoc_schema,
index_params=vannadoc_index_params,
consistency_level="Strong"
)

def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
if len(question) == 0 or len(sql) == 0:
raise Exception("pair of question and sql can not be null")
_id = str(uuid.uuid4()) + "-sql"
embedding = self.embedding_function.encode_documents([question])[0]
self.milvus_client.insert(
collection_name="vannasql",
data={
"id": _id,
"text": question,
"sql": sql,
"vector": embedding
}
)
return _id

def add_ddl(self, ddl: str, **kwargs) -> str:
if len(ddl) == 0:
raise Exception("ddl can not be null")
_id = str(uuid.uuid4()) + "-ddl"
embedding = self.embedding_function.encode_documents([ddl])[0]
self.milvus_client.insert(
collection_name="vannaddl",
data={
"id": _id,
"ddl": ddl,
"vector": embedding
}
)
return _id

def add_documentation(self, documentation: str, **kwargs) -> str:
if len(documentation) == 0:
raise Exception("documentation can not be null")
_id = str(uuid.uuid4()) + "-doc"
embedding = self.embedding_function.encode_documents([documentation])[0]
self.milvus_client.insert(
collection_name="vannadoc",
data={
"id": _id,
"doc": documentation,
"vector": embedding
}
)
return _id

def get_training_data(self, **kwargs) -> pd.DataFrame:
sql_data = self.milvus_client.query(
collection_name="vannasql",
output_fields=["*"],
limit=MAX_LIMIT_SIZE,
)
df = pd.DataFrame()
df_sql = pd.DataFrame(
{
"id": [doc["id"] for doc in sql_data],
"question": [doc["text"] for doc in sql_data],
"content": [doc["sql"] for doc in sql_data],
}
)
df = pd.concat([df, df_sql])

ddl_data = self.milvus_client.query(
collection_name="vannaddl",
output_fields=["*"],
limit=MAX_LIMIT_SIZE,
)

df_ddl = pd.DataFrame(
{
"id": [doc["id"] for doc in ddl_data],
"question": [None for doc in ddl_data],
"content": [doc["ddl"] for doc in ddl_data],
}
)
df = pd.concat([df, df_ddl])

doc_data = self.milvus_client.query(
collection_name="vannadoc",
output_fields=["*"],
limit=MAX_LIMIT_SIZE,
)

df_doc = pd.DataFrame(
{
"id": [doc["id"] for doc in doc_data],
"question": [None for doc in doc_data],
"content": [doc["doc"] for doc in doc_data],
}
)
df = pd.concat([df, df_doc])
return df

def get_similar_question_sql(self, question: str, **kwargs) -> list:
search_params = {
"metric_type": "L2",
"params": {"nprobe": 128},
}
embeddings = self.embedding_function.encode_queries([question])
res = self.milvus_client.search(
collection_name="vannasql",
anns_field="vector",
data=embeddings,
limit=self.n_results,
output_fields=["text", "sql"],
search_params=search_params
)
res = res[0]

list_sql = []
for doc in res:
dict = {}
dict["question"] = doc["entity"]["text"]
dict["sql"] = doc["entity"]["sql"]
list_sql.append(dict)
return list_sql

def get_related_ddl(self, question: str, **kwargs) -> list:
search_params = {
"metric_type": "L2",
"params": {"nprobe": 128},
}
embeddings = self.embedding_function.encode_queries([question])
res = self.milvus_client.search(
collection_name="vannaddl",
anns_field="vector",
data=embeddings,
limit=self.n_results,
output_fields=["ddl"],
search_params=search_params
)
res = res[0]

list_ddl = []
for doc in res:
list_ddl.append(doc["entity"]["ddl"])
return list_ddl

def get_related_documentation(self, question: str, **kwargs) -> list:
search_params = {
"metric_type": "L2",
"params": {"nprobe": 128},
}
embeddings = self.embedding_function.encode_queries([question])
res = self.milvus_client.search(
collection_name="vannadoc",
anns_field="vector",
data=embeddings,
limit=self.n_results,
output_fields=["doc"],
search_params=search_params
)
res = res[0]

list_doc = []
for doc in res:
list_doc.append(doc["entity"]["doc"])
return list_doc

def remove_training_data(self, id: str, **kwargs) -> bool:
if id.endswith("-sql"):
self.milvus_client.delete(collection_name="vannasql", ids=[id])
return True
elif id.endswith("-ddl"):
self.milvus_client.delete(collection_name="vannaddl", ids=[id])
return True
elif id.endswith("-doc"):
self.milvus_client.delete(collection_name="vannadoc", ids=[id])
return True
else:
return False
2 changes: 2 additions & 0 deletions tests/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ def test_regular_imports():
from vanna.hf.hf import Hf
from vanna.local import LocalContext_OpenAI
from vanna.marqo.marqo import Marqo_VectorStore
from vanna.milvus.milvus_vector import Milvus_VectorStore
from vanna.mistral.mistral import Mistral
from vanna.ollama.ollama import Ollama
from vanna.openai.openai_chat import OpenAI_Chat
Expand All @@ -24,6 +25,7 @@ def test_shortcut_imports():
from vanna.chromadb import ChromaDB_VectorStore
from vanna.hf import Hf
from vanna.marqo import Marqo_VectorStore
from vanna.milvus import Milvus_VectorStore
from vanna.mistral import Mistral
from vanna.ollama import Ollama
from vanna.openai import OpenAI_Chat, OpenAI_Embeddings
Expand Down
28 changes: 28 additions & 0 deletions tests/test_vanna.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,34 @@ def test_vn_chroma():
df = vn_chroma.run_sql(sql)
assert len(df) == 7


from vanna.milvus import Milvus_VectorStore


class VannaMilvus(Milvus_VectorStore, OpenAI_Chat):
def __init__(self, config=None):
Milvus_VectorStore.__init__(self, config=config)
OpenAI_Chat.__init__(self, config=config)

vn_milvus = VannaMilvus(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'})
vn_milvus.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')

def test_vn_milvus():
existing_training_data = vn_milvus.get_training_data()
if len(existing_training_data) > 0:
for _, training_data in existing_training_data.iterrows():
vn_milvus.remove_training_data(training_data['id'])

df_ddl = vn_milvus.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null")

for ddl in df_ddl['sql'].to_list():
vn_milvus.train(ddl=ddl)

sql = vn_milvus.generate_sql("What are the top 7 customers by sales?")
df = vn_milvus.run_sql(sql)
assert len(df) == 7


class VannaNumResults(ChromaDB_VectorStore, OpenAI_Chat):
def __init__(self, config=None):
ChromaDB_VectorStore.__init__(self, config=config)
Expand Down

0 comments on commit 254ff6a

Please sign in to comment.