Commit
Cache column sample values for future use #4
pramitchoudhary committed Jul 28, 2023

1 parent 192da49 commit 2823f01
Showing 4 changed files with 81 additions and 19 deletions.
2 changes: 2 additions & 0 deletions sidekick/configs/data_template.py
@@ -11,3 +11,5 @@
"Column Type": "",
"Sample Values": []
}

data_samples_template = "Column {column_name} contains values similar to {comma_separated_sample_values}."
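For illustration, a minimal sketch of how this template renders once column metadata is filled in; the column name and sample values below are hypothetical:

from sidekick.configs.data_template import data_samples_template

# Hypothetical metadata following the schema template above
entry = {"Column Name": "price", "Column Type": "REAL", "Sample Values": ["9.99", "24.50"]}
description = data_samples_template.format(
    column_name=entry["Column Name"],
    comma_separated_sample_values=",".join(entry["Sample Values"]),
)
# -> "Column price contains values similar to 9.99,24.50."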
15 changes: 15 additions & 0 deletions sidekick/db_config.py
@@ -7,6 +7,7 @@
import sqlalchemy
from pandasql import sqldf
from psycopg2.extras import Json
from sidekick.configs.data_template import data_samples_template
from sidekick.logger import logger
from sqlalchemy import create_engine
from sqlalchemy_utils import database_exists
@@ -73,6 +74,7 @@ def _extract_schema_info(self, schema_info_path=None):
with open(table_info_file, "r") as in_file:
schema_info_path = json.load(in_file)["schema_info_path"]
res = []
sample_values = []
try:
if Path(schema_info_path).exists():
with open(schema_info_path, "r") as in_file:
@@ -82,8 +84,21 @@ def _extract_schema_info(self, schema_info_path=None):
if "Column Name" in data and "Column Type" in data:
col_name = data["Column Name"]
col_type = data["Column Type"]
# If the column has sample values, cache them for future use.
if "Sample Values" in data:
_sample_values = data["Sample Values"]
_ds = data_samples_template.format(
column_name=col_name, comma_separated_sample_values=",".join(_sample_values)
)
sample_values.append(_ds)
_new_samples = f"{col_name} {col_type}"
res.append(_new_samples)
if sample_values:
# Cache the formatted sample values for future use
with open(
f"{self.base_path}/var/lib/tmp/data/{self._table_name}_column_values.json", "w"
) as outfile:
json.dump(sample_values, outfile, indent=2, sort_keys=False)
except ValueError as ve:
logger.error(f"Error in reading table context file: {ve}")
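A minimal sketch of how the cached file written above could be read back later; the helper name is hypothetical and not part of this commit:

import json
from pathlib import Path

def load_column_value_cache(base_path: str, table_name: str) -> list:
    # Hypothetical helper: read back the per-table sample-value descriptions cached above
    cache_file = Path(f"{base_path}/var/lib/tmp/data/{table_name}_column_values.json")
    if not cache_file.exists():
        return []
    with open(cache_file, "r") as in_file:
        return json.load(in_file)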
27 changes: 19 additions & 8 deletions sidekick/query.py
@@ -12,7 +12,7 @@
from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
LLMPredictor, ServiceContext, SQLDatabase)
from llama_index.indices.struct_store import SQLContextContainerBuilder
from sidekick.configs.prompt_template import (DEBUGGING_PROMPT, QUERY_PROMPT, NSQL_QUERY_PROMPT,
TASK_PROMPT)
from sidekick.logger import logger
from sidekick.utils import filter_samples, read_sample_pairs, remove_duplicates
@@ -50,6 +50,16 @@ def __init__(
self.openai_key = openai_key
self.content_queries = None

def load_table_info(self):
# Read the cached table context (table_context.json)
table_info_file = f"{self.path}/var/lib/tmp/data/table_context.json"

def setup(self):
# Load the table information
self.load_table_info()
def build_index(self, persist: bool = True):
# Re-assigning the OPENAI API key below looks redundant, but without it an error is thrown.
os.environ["OPENAI_API_KEY"] = self.openai_key
@@ -195,7 +205,7 @@ def generate_tasks(self, table_names: list, input_question: str):
except Exception as se:
raise se

def generate_sql(self, table_name: list, input_question: str, _dialect: str = "SQLite", model_name: str = "nsql"):
context_file = f"{self.path}/var/lib/tmp/data/context.json"
additional_context = json.load(open(context_file, "r")) if Path(context_file).exists() else {}
context_queries = self.content_queries
@@ -251,8 +261,9 @@ def generate_sql(self, table_name: list, input_question: str, _dialect: str = "p
tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-6B")
model = AutoModelForCausalLM.from_pretrained("NumbersStation/nsql-6B")

data_samples = context_queries

# TODO: add additional context.
# Note: context seems to have an effect only when sample QnA pairs are not present.
# Tried different prompt settings but was not able to get the desired results consistently.
_context = {
"if patterns like 'current time' or 'now' occurs in question": "always use NOW() - INTERVAL",
"if patterns like 'total number', or 'List' occurs in question": "always use DISTINCT",
@@ -270,16 +281,16 @@ def generate_sql(self, table_name: list, input_question: str, _dialect: str = "p
contextual_context.append(f"{_item}: {_val}")

print("Filtering Question/Query pairs")
_samples = filter_samples(input_question, probable_qs=context_queries,
model_path='', threshold=0.90)

# If QnA pairs > 5, we keep only 5 of them for focused context
if len(_samples) > 5:
_samples = _samples[0:5][::-1]
qna_samples = '\n'.join(_samples)

contextual_context_val = ', '.join(contextual_context)

column_names = [str(_c) for _c in self.sql_database.get_column_names(table_name[0])]
if len(_samples) > 2:
# Check for the columns referenced in the QnA samples provided; if any exist, keep them
context_columns = [_c for _c in column_names if _c.lower() in qna_samples.lower()]
@@ -289,7 +300,7 @@ def generate_sql(self, table_name: list, input_question: str, _dialect: str = "p
relevant_columns = context_columns if len(context_columns) > 0 else column_names
_data_info = ', '.join(relevant_columns)

query = NSQL_QUERY_PROMPT.format(table_name=table_name, data_info=_data_info, data_info_detailed=data_samples,
sample_queries=qna_samples, context=contextual_context_val,
question_txt=input_question)

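For context, a sketch of how the formatted prompt is typically fed to the nsql model loaded above, assuming the standard transformers generate/decode API; the generation parameters are illustrative, not the repository's settings:

input_ids = tokenizer(query, return_tensors="pt").input_ids
generated_ids = model.generate(input_ids, max_new_tokens=256)  # assumed token budget for the SQL completion
generated_sql = tokenizer.decode(generated_ids[0], skip_special_tokens=True)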
56 changes: 45 additions & 11 deletions sidekick/utils.py
@@ -4,10 +4,12 @@
from pathlib import Path
from typing import Optional

import torch
import numpy as np
import pandas as pd
from pandasql import sqldf
from sentence_transformers import SentenceTransformer
from InstructorEmbedding import INSTRUCTOR
from sidekick.logger import logger
from sklearn.metrics.pairwise import cosine_similarity

@@ -37,6 +39,38 @@ def generate_sentence_embeddings(model_path: str, x, batch_size: int = 32, devic
return all_res


def generate_text_embeddings(model_path: str, x, batch_size: int = 32, device: Optional[str] = 'cpu'):
# References:
# 1. https://www.sbert.net/docs/pretrained_models.html#sentence-embedding-models
# 2. Evaluation results: https://www.sbert.net/_static/html/models_en_sentence_embeddings.html
# 3. Model card: https://huggingface.co/hkunlp/instructor-large
# 4. MTEB leaderboard: https://huggingface.co/spaces/mteb/leaderboard
# Maps sentences & paragraphs to a 768-dimensional dense vector space.
model_name_path = f"{model_path}/text_embedding/instructor-large"
current_torch_home = os.environ.get("TORCH_HOME", "")
if Path(model_name_path).is_dir():
is_empty = not any(Path(model_name_path).iterdir())
if is_empty:
# Download and cache at the specified location
os.environ["TORCH_HOME"] = model_path
model_name_path = "hkunlp/instructor-large"
sentence_model = INSTRUCTOR(model_name_path, device=device)
if device != 'cuda':
# Issue https://github.com/pytorch/pytorch/issues/69364
# In initial experimentation, the quantized model generated slightly better results
_model = torch.quantization.quantize_dynamic(
sentence_model, {torch.nn.Linear}, dtype=torch.qint8)
else:
_model = sentence_model
_sentences = [['Represent the Financial question for retrieving duplicate examples: ', _item] for _item in x]

res = _model.encode(_sentences)
del sentence_model
del _model
os.environ["TORCH_HOME"] = current_torch_home
return res
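A usage sketch for the new embedding helper; the questions and model directory are made up:

from sklearn.metrics.pairwise import cosine_similarity

# Hypothetical inputs; "./models" must contain (or will receive) text_embedding/instructor-large
questions = ["What is the total revenue?", "List the distinct users"]
embeddings = generate_text_embeddings("./models", questions, device="cpu")
score = cosine_similarity([embeddings[0]], [embeddings[1]])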


def filter_samples(input_q: str, probable_qs: list, model_path: str, threshold: float = 0.45):
# Only consider the questions; note: this might change in the future.
_inq = ("# query: " + input_q).strip().lower()
@@ -102,21 +136,21 @@ def read_sample_pairs(input_path: str, model_name: str = "nsql"):
df = df.reset_index(drop=True)

# Format the QnA pairs for the target model
if model_name != "nsql":
# Open AI format
# Convert frame to the format below
# [
# "# query": ""
# "# answer": ""
# ]
res = df.apply(lambda row: f"# query: {row['query']}\n# answer: {row['answer']}", axis=1).to_list()
else:
# Convert frame to the format below
# [
# "Question": <question_text>
# "Answer":
# <response_text>
# ]
res = df.apply(lambda row: f"Question: {row['query']}\nAnswer:\n{row['answer']}", axis=1).to_list()
return res
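A small sketch of the two output formats for a single, made-up QnA row:

import pandas as pd

df = pd.DataFrame([{"query": "How many users signed up today?",
                    "answer": "SELECT COUNT(*) FROM users WHERE signup_date = CURRENT_DATE"}])

# model_name != "nsql" (OpenAI-style) yields:
#   # query: How many users signed up today?
#   # answer: SELECT COUNT(*) FROM users WHERE signup_date = CURRENT_DATE
# model_name == "nsql" yields:
#   Question: How many users signed up today?
#   Answer:
#   SELECT COUNT(*) FROM users WHERE signup_date = CURRENT_DATE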
