Merge pull request #592 from vanna-ai/bigquery-vector
Add BigQuery as metadata and vector storage
zainhoda authored Aug 9, 2024
2 parents 175c98a + af1fabc commit 75776e3
Showing 5 changed files with 240 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

 [project]
 name = "vanna"
-version = "0.6.5"
+version = "0.6.6"
 authors = [
   { name="Zain Hoda", email="[email protected]" },
 ]
15 changes: 5 additions & 10 deletions src/vanna/base/base.py
@@ -437,7 +437,7 @@ def get_training_data(self, **kwargs) -> pd.DataFrame:
         pass

     @abstractmethod
-    def remove_training_data(id: str, **kwargs) -> bool:
+    def remove_training_data(self, id: str, **kwargs) -> bool:
         """
         Example:
         ```python
@@ -1276,15 +1276,10 @@ def connect_to_bigquery(

         def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]:
             if conn:
-                try:
-                    job = conn.query(sql)
-                    df = job.result().to_dataframe()
-                    return df
-                except GoogleAPIError as error:
-                    errors = []
-                    for error in error.errors:
-                        errors.append(error["message"])
-                    raise errors
+                job = conn.query(sql)
+                df = job.result().to_dataframe()
+                return df

             return None

         self.dialect = "BigQuery SQL"
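Dropping the handler is defensible: the old `except` block ended in `raise errors`, which raises a plain list and would itself fail with a TypeError, masking the underlying GoogleAPIError. For context, a minimal usage sketch of the simplified path; `vn` is assumed to be a VannaBase subclass instance, and the project id is a placeholder:

```python
# Sketch only: assumes `vn` exists and application-default credentials
# are configured in the environment.
vn.connect_to_bigquery(project_id="my-gcp-project")  # placeholder project id

df = vn.run_sql("SELECT 1 AS one")  # BigQuery errors now propagate from the client library
print(df)
```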
4 changes: 2 additions & 2 deletions src/vanna/flask/__init__.py
@@ -12,9 +12,9 @@
 from flask import Flask, Response, jsonify, request, send_from_directory
 from flask_sock import Sock

-from ..base import VannaBase
 from .assets import css_content, html_content, js_content
 from .auth import AuthInterface, NoAuth
+from ..base import VannaBase


 class Cache(ABC):
@@ -1211,7 +1211,7 @@ def __init__(
self.config["ask_results_correct"] = ask_results_correct
self.config["followup_questions"] = followup_questions
self.config["summarization"] = summarization
self.config["function_generation"] = function_generation
self.config["function_generation"] = function_generation and hasattr(vn, "get_function")

self.index_html_path = index_html_path
self.assets_folder = assets_folder
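This hunk gates the `function_generation` flag on the backend actually implementing `get_function`, so the Flask UI cannot advertise a capability its Vanna object lacks. A small illustration of the `hasattr` check; the class names and method signature here are invented for the example:

```python
class BackendWithFunctions:
    def get_function(self, question: str):
        ...

class PlainBackend:
    pass

# Mirrors the gating logic: the flag stays True only for capable backends.
function_generation = True
print(function_generation and hasattr(BackendWithFunctions(), "get_function"))  # True
print(function_generation and hasattr(PlainBackend(), "get_function"))          # False
```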
3 changes: 2 additions & 1 deletion src/vanna/google/__init__.py
@@ -1 +1,2 @@
-from .gemini_chat import GoogleGeminiChat
+from .bigquery_vector import BigQuery_VectorStore
+from .gemini_chat import GoogleGeminiChat
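With the export added, downstream code can import the vector store and the chat model together:

```python
from vanna.google import BigQuery_VectorStore, GoogleGeminiChat
```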
230 changes: 230 additions & 0 deletions src/vanna/google/bigquery_vector.py
@@ -0,0 +1,230 @@
import datetime
import os
import uuid
from typing import List, Optional

import pandas as pd
from google.cloud import bigquery

from ..base import VannaBase


class BigQuery_VectorStore(VannaBase):
    def __init__(self, config: dict, **kwargs):
        self.config = config

        self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
        self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10))
        self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))

        if "api_key" in config or os.getenv("GOOGLE_API_KEY"):
            # If a Google api_key is provided through the config or set as an
            # environment variable, use it to configure genai.
            print("Configuring genai")
            import google.generativeai as genai

            genai.configure(api_key=config.get("api_key", os.getenv("GOOGLE_API_KEY")))

            self.genai = genai
        else:
            # Authenticate using VertexAI
            from vertexai.language_models import (
                TextEmbeddingInput,
                TextEmbeddingModel,
            )

        if self.config.get("project_id"):
            self.project_id = self.config.get("project_id")
        else:
            self.project_id = os.getenv("GOOGLE_CLOUD_PROJECT")

        if self.project_id is None:
            raise ValueError("Project ID is not set")

        self.conn = bigquery.Client(project=self.project_id)

        dataset_name = self.config.get('bigquery_dataset_name', 'vanna_managed')
        self.dataset_id = f"{self.project_id}.{dataset_name}"
        dataset = bigquery.Dataset(self.dataset_id)

        try:
            self.conn.get_dataset(self.dataset_id)  # Make an API request.
            print(f"Dataset {self.dataset_id} already exists")
        except Exception:
            # Dataset does not exist, so create it.
            dataset.location = "US"
            self.conn.create_dataset(dataset, timeout=30)  # Make an API request.
            print(f"Created dataset {self.dataset_id}")

        # Create a table called training_data in the dataset with the columns:
        # id, training_data_type, question, content, embedding, created_at.

        self.table_id = f"{self.dataset_id}.training_data"
        schema = [
            bigquery.SchemaField("id", "STRING", mode="REQUIRED"),
            bigquery.SchemaField("training_data_type", "STRING", mode="REQUIRED"),
            bigquery.SchemaField("question", "STRING", mode="REQUIRED"),
            bigquery.SchemaField("content", "STRING", mode="REQUIRED"),
            bigquery.SchemaField("embedding", "FLOAT64", mode="REPEATED"),
            bigquery.SchemaField("created_at", "TIMESTAMP", mode="REQUIRED"),
        ]

        table = bigquery.Table(self.table_id, schema=schema)

        try:
            self.conn.get_table(self.table_id)  # Make an API request.
            print(f"Table {self.table_id} already exists")
        except Exception:
            # Table does not exist, so create it.
            self.conn.create_table(table, timeout=30)  # Make an API request.
            print(f"Created table {self.table_id}")

        # Create VECTOR INDEX IF NOT EXISTS
        # TODO: This requires 5000 rows before it can be created
        # vector_index_query = f"""
        # CREATE VECTOR INDEX IF NOT EXISTS my_index
        # ON `{self.table_id}`(embedding)
        # OPTIONS(
        #     distance_type='COSINE',
        #     index_type='IVF',
        #     ivf_options='{{"num_lists": 1000}}'
        # )
        # """

        # try:
        #     self.conn.query(vector_index_query).result()  # Make an API request.
        #     print(f"Vector index on {self.table_id} created or already exists")
        # except Exception as e:
        #     print(f"Failed to create vector index: {e}")

    def store_training_data(self, training_data_type: str, question: str, content: str, embedding: List[float], **kwargs) -> str:
        id = str(uuid.uuid4())
        created_at = datetime.datetime.now()
        self.conn.insert_rows_json(self.table_id, [{
            "id": id,
            "training_data_type": training_data_type,
            "question": question,
            "content": content,
            "embedding": embedding,
            "created_at": created_at.isoformat(),
        }])

        return id
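One caveat: `insert_rows_json` streams rows through BigQuery's streaming buffer, and rows still in that buffer cannot be touched by DML, so the `DELETE` in `remove_training_data` below can fail on freshly inserted rows until the buffer flushes. A hedged alternative sketch that inserts via a parameterized DML statement instead; the helper and parameter names are illustrative, not part of this PR:

```python
from google.cloud import bigquery

def store_row_dml(client: bigquery.Client, table_id: str, row: dict) -> None:
    # DML INSERT bypasses the streaming buffer, so the row is immediately
    # deletable by a later DML statement.
    sql = f"""
        INSERT INTO `{table_id}` (id, training_data_type, question, content, embedding, created_at)
        VALUES (@id, @type, @question, @content, @embedding, CURRENT_TIMESTAMP())
    """
    job_config = bigquery.QueryJobConfig(query_parameters=[
        bigquery.ScalarQueryParameter("id", "STRING", row["id"]),
        bigquery.ScalarQueryParameter("type", "STRING", row["training_data_type"]),
        bigquery.ScalarQueryParameter("question", "STRING", row["question"]),
        bigquery.ScalarQueryParameter("content", "STRING", row["content"]),
        bigquery.ArrayQueryParameter("embedding", "FLOAT64", row["embedding"]),
    ])
    client.query(sql, job_config=job_config).result()
```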

    def fetch_similar_training_data(self, training_data_type: str, question: str, n_results, **kwargs) -> pd.DataFrame:
        question_embedding = self.generate_question_embedding(question)

        query = f"""
        SELECT
            base.id as id,
            base.question as question,
            base.training_data_type as training_data_type,
            base.content as content,
            distance
        FROM
            VECTOR_SEARCH(
                TABLE `{self.table_id}`,
                'embedding',
                (SELECT * FROM UNNEST([STRUCT({question_embedding})])),
                top_k => {n_results},
                distance_type => 'COSINE',
                options => '{{"use_brute_force":true}}'
            )
        WHERE
            base.training_data_type = '{training_data_type}'
        """

        results = self.conn.query(query).result().to_dataframe()
        return results
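Note that `{question_embedding}` interpolates the Python list's repr (e.g. `[0.12, 0.34]`) directly into the SQL, which happens to be valid array syntax. A sketch of the same search with the embedding passed as a query parameter instead; `client`, `table_id`, and `question_embedding` are assumed to exist, and aliasing the parameter to an `embedding` column is my reading of VECTOR_SEARCH's query-table contract, not taken from this PR:

```python
from google.cloud import bigquery

# Pass the query embedding as an ARRAY<FLOAT64> parameter rather than
# interpolating its repr into the SQL string.
job_config = bigquery.QueryJobConfig(query_parameters=[
    bigquery.ArrayQueryParameter("query_embedding", "FLOAT64", question_embedding),
])
sql = f"""
SELECT base.id, base.question, base.training_data_type, base.content, distance
FROM VECTOR_SEARCH(
    TABLE `{table_id}`,
    'embedding',
    (SELECT @query_embedding AS embedding),
    top_k => 10,
    distance_type => 'COSINE',
    options => '{{"use_brute_force":true}}'
)
"""
df = client.query(sql, job_config=job_config).result().to_dataframe()
```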

    def generate_question_embedding(self, data: str, **kwargs) -> List[float]:
        result = self.genai.embed_content(
            model="models/text-embedding-004",
            content=data,
            task_type="retrieval_query")

        if 'embedding' in result:
            return result['embedding']
        else:
            raise ValueError("No embeddings returned")

    def generate_storage_embedding(self, data: str, **kwargs) -> List[float]:
        result = self.genai.embed_content(
            model="models/text-embedding-004",
            content=data,
            task_type="retrieval_document")

        if 'embedding' in result:
            return result['embedding']
        else:
            raise ValueError("No embeddings returned")

        # Vertex AI alternative, kept for reference:
        # task = "RETRIEVAL_DOCUMENT"
        # inputs = [TextEmbeddingInput(data, task)]
        # embeddings = self.vertex_embedding_model.get_embeddings(inputs)

        # if len(embeddings) == 0:
        #     raise ValueError("No embeddings returned")

        # return embeddings[0].values

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        return self.generate_storage_embedding(data, **kwargs)

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        df = self.fetch_similar_training_data(training_data_type="sql", question=question, n_results=self.n_results_sql)

        # Return a list of dictionaries with only the question and sql fields;
        # the content column is renamed to sql.
        return df.rename(columns={"content": "sql"})[["question", "sql"]].to_dict(orient="records")

    def get_related_ddl(self, question: str, **kwargs) -> list:
        df = self.fetch_similar_training_data(training_data_type="ddl", question=question, n_results=self.n_results_ddl)

        # Return a list of strings of the content
        return df["content"].tolist()

    def get_related_documentation(self, question: str, **kwargs) -> list:
        df = self.fetch_similar_training_data(training_data_type="documentation", question=question, n_results=self.n_results_documentation)

        # Return a list of strings of the content
        return df["content"].tolist()

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        doc = {
            "question": question,
            "sql": sql
        }

        embedding = self.generate_embedding(str(doc))

        return self.store_training_data(training_data_type="sql", question=question, content=sql, embedding=embedding)

    def add_ddl(self, ddl: str, **kwargs) -> str:
        embedding = self.generate_embedding(ddl)

        return self.store_training_data(training_data_type="ddl", question="", content=ddl, embedding=embedding)

    def add_documentation(self, documentation: str, **kwargs) -> str:
        embedding = self.generate_embedding(documentation)

        return self.store_training_data(training_data_type="documentation", question="", content=documentation, embedding=embedding)

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        query = f"SELECT id, training_data_type, question, content FROM `{self.table_id}`"

        return self.conn.query(query).result().to_dataframe()

    def remove_training_data(self, id: str, **kwargs) -> bool:
        query = f"DELETE FROM `{self.table_id}` WHERE id = '{id}'"

        try:
            self.conn.query(query).result()
            return True

        except Exception as e:
            print(f"Failed to remove training data: {e}")
            return False
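For context, here is a minimal usage sketch of the new store, following the mixin pattern vanna uses elsewhere (e.g. with ChromaDB); the class name, key, and project values are placeholders, and pairing with GoogleGeminiChat assumes it accepts the same config dict:

```python
from vanna.google import BigQuery_VectorStore, GoogleGeminiChat

class MyVanna(BigQuery_VectorStore, GoogleGeminiChat):
    def __init__(self, config: dict):
        BigQuery_VectorStore.__init__(self, config=config)
        GoogleGeminiChat.__init__(self, config=config)

vn = MyVanna(config={
    "api_key": "YOUR_GEMINI_API_KEY",        # placeholder
    "project_id": "my-gcp-project",          # placeholder
    "bigquery_dataset_name": "vanna_managed",
})

# Train on some metadata, then retrieve it by similarity.
vn.add_documentation("Our fiscal year starts on February 1.")
print(vn.get_related_documentation("When does the fiscal year begin?"))
```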
