Commit

Initial skeleton to support local LLM #4

pramitchoudhary committed Jul 26, 2023
1 parent 25a9570 commit 192da49
Showing 5 changed files with 137 additions and 63 deletions.
4 changes: 2 additions & 2 deletions sidekick/configs/.env.toml
@@ -1,6 +1,6 @@
[OPENAI]
OPENAI_API_KEY = ""
MODEL_NAME = "gpt-3.5-turbo-0301" # Others: e.g. gpt-4, gpt-4-32k, text-davinci-003
MODEL_NAME = "nsql" # Others: e.g. gpt-4, gpt-4-32k, text-davinci-003

[LOCAL_DB_CONFIG]
HOST_NAME = "localhost"
@@ -13,4 +13,4 @@ PORT = "5432"
LOG-LEVEL = "INFO"

[DB-DIALECT]
DB_TYPE = "postgresql"
DB_TYPE = "SQLite" # postgresql
17 changes: 17 additions & 0 deletions sidekick/configs/prompt_template.py
@@ -59,3 +59,20 @@
Query:\n {qry_txt}
""",
}

NSQL_QUERY_PROMPT = """
For SQL TABLE '{table_name}' sample question/answer pairs,\n({sample_queries})
CREATE TABLE '{table_name}'({data_info}
)
Table '{table_name}' has sample values ({data_info_detailed})
-- Using valid {_dialect}, answer the following questions with the information for '{table_name}' provided above; for final SQL only use values from the question.
-- Using reference for TABLES '{table_name}' {context}; {question_txt}?
SELECT"""
12 changes: 0 additions & 12 deletions sidekick/db_config.py
@@ -101,18 +101,6 @@ def create_table(self, schema_info_path=None, schema_info=None):
        # If schema information is not provided, extract from the template.
        self.schema_info = """,\n""".join(self._extract_schema_info(schema_info_path)).strip()
        logger.debug(f"Schema info used for creating table:\n {self.schema_info}")
-        # self.schema_info = """
-        # id uuid PRIMARY KEY,
-        # ts TIMESTAMP WITH TIME ZONE NOT NULL,
-        # kind TEXT NOT NULL, -- or int?,
-        # user_id TEXT,
-        # user_name TEXT,
-        # resource_type TEXT NOT NULL, -- or int?,
-        # resource_id TEXT,
-        # stream TEXT NOT NULL,
-        # source TEXT NOT NULL,
-        # payload jsonb NOT NULL
-        # """
        create_syntax = f"""
        CREATE TABLE IF NOT EXISTS {self.table_name} (
        {self.schema_info}
142 changes: 100 additions & 42 deletions sidekick/query.py
@@ -7,13 +7,17 @@
import openai
import sqlglot
import toml
+import torch
from langchain import OpenAI
-from llama_index import GPTSimpleVectorIndex, GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase
+from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
+                         LLMPredictor, ServiceContext, SQLDatabase)
from llama_index.indices.struct_store import SQLContextContainerBuilder
-from sidekick.configs.prompt_template import DEBUGGING_PROMPT, QUERY_PROMPT, TASK_PROMPT
+from sidekick.configs.prompt_template import (DEBUGGING_PROMPT, QUERY_PROMPT,
+                                              TASK_PROMPT)
from sidekick.logger import logger
-from sidekick.utils import csv_parser, filter_samples, remove_duplicates
+from sidekick.utils import filter_samples, read_sample_pairs, remove_duplicates
from sqlalchemy import create_engine
+from transformers import AutoModelForCausalLM, AutoTokenizer


def _check_file_info(file_path: str):
@@ -63,7 +67,7 @@ def update_context_queries(self):
        new_context_queries = []
        if self.sample_queries_path is not None and Path(self.sample_queries_path).exists():
            logger.info(f"Using samples from path {self.sample_queries_path}")
-            new_context_queries = csv_parser(self.sample_queries_path)
+            new_context_queries = read_sample_pairs(self.sample_queries_path, "gpt")
            # cache the samples for future use
            with open(f"{self.path}/var/lib/tmp/data/queries_cache.json", "w") as f:
                json.dump(new_context_queries, f, indent=2)
@@ -191,51 +195,105 @@ def generate_tasks(self, table_names: list, input_question: str):
        except Exception as se:
            raise se

-    def generate_sql(
-        self, table_name: list, input_question: str, _dialect: str = "postgres", model_name: str = "gpt-3.5-turbo-0301"
-    ):
-        _tasks = self.task_formatter(self._tasks)
+    def generate_sql(self, table_name: list, input_question: str, _dialect: str = "postgres", model_name: str = "nsql"):
        context_file = f"{self.path}/var/lib/tmp/data/context.json"
        additional_context = json.load(open(context_file, "r")) if Path(context_file).exists() else {}

        context_queries = self.content_queries
-        # TODO: The need to pass data info again could be eliminated if Task generation becomes more consistent and accurate.
-        query_str = QUERY_PROMPT.format(
-            _dialect=_dialect,
-            _data_info=self._data_info,
-            _question=input_question,
-            _table_name=table_name,
-            _sample_queries=context_queries,
-            _tasks=_tasks,
-        )

-        table_context_dict = {str(table_name[0]).lower(): str(additional_context).lower()}
-        self.context_builder = SQLContextContainerBuilder(self.sql_database, context_dict=table_context_dict)
+        if model_name != "nsql":
+            _tasks = self.task_formatter(self._tasks)

-        table_schema_index = self.build_index(persist=False)
-        self.context_builder.query_index_for_context(table_schema_index, query_str, store_context_str=True)
-        context_container = self.context_builder.build_context_container()
+            # TODO: The need to pass data info again could be eliminated if Task generation becomes more consistent and accurate.
+            query_str = QUERY_PROMPT.format(
+                _dialect=_dialect,
+                _data_info=self._data_info,
+                _question=input_question,
+                _table_name=table_name,
+                _sample_queries=context_queries,
+                _tasks=_tasks,
+            )

-        # Reference: https://github.com/jerryjliu/llama_index/issues/987
-        llm_predictor_gpt3 = LLMPredictor(llm=OpenAI(temperature=0.5, model_name=model_name))
-        service_context_gpt3 = ServiceContext.from_defaults(llm_predictor=llm_predictor_gpt3, chunk_size_limit=512)
+            table_context_dict = {str(table_name[0]).lower(): str(additional_context).lower()}
+            self.context_builder = SQLContextContainerBuilder(self.sql_database, context_dict=table_context_dict)

-        index = GPTSQLStructStoreIndex(
-            [], sql_database=self.sql_database, table_name=table_name, service_context=service_context_gpt3
-        )
-        res = self.generate_response(context_container, sql_index=index, input_prompt=query_str)
-        try:
-            # Check if `SQL` is formatted ---> ``` SQL_text ```
-            if "```" in str(res):
-                res = (
-                    str(res).split("```", 1)[1].split(";", 1)[0].strip().replace("```", "").replace("sql\n", "").strip()
-                )
-            else:
-                res = str(res).split("Explanation:", 1)[0].strip()
-            sqlglot.transpile(res)
-        except (sqlglot.errors.ParseError, ValueError, RuntimeError) as e:
-            logger.info("We did the best we could; there might still be some errors:\n")
-            logger.info(f"Realized query so far:\n {res}")
+            table_schema_index = self.build_index(persist=False)
+            self.context_builder.query_index_for_context(table_schema_index, query_str, store_context_str=True)
+            context_container = self.context_builder.build_context_container()
+
+            # Reference: https://github.com/jerryjliu/llama_index/issues/987
+            llm_predictor_gpt3 = LLMPredictor(llm=OpenAI(temperature=0.5, model_name=model_name))
+            service_context_gpt3 = ServiceContext.from_defaults(llm_predictor=llm_predictor_gpt3, chunk_size_limit=512)
+
+            index = GPTSQLStructStoreIndex(
+                [], sql_database=self.sql_database, table_name=table_name, service_context=service_context_gpt3
+            )
+            res = self.generate_response(context_container, sql_index=index, input_prompt=query_str)
+            try:
+                # Check if `SQL` is formatted ---> ``` SQL_text ```
+                if "```" in str(res):
+                    res = (
+                        str(res)
+                        .split("```", 1)[1]
+                        .split(";", 1)[0]
+                        .strip()
+                        .replace("```", "")
+                        .replace("sql\n", "")
+                        .strip()
+                    )
+                else:
+                    res = str(res).split("Explanation:", 1)[0].strip()
+                sqlglot.transpile(res)
+            except (sqlglot.errors.ParseError, ValueError, RuntimeError) as e:
+                logger.info("We did the best we could; there might still be some errors:\n")
+                logger.info(f"Realized query so far:\n {res}")
+        else:
+            # Load h2oGPT.NSQL model
+            tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-6B")
+            model = AutoModelForCausalLM.from_pretrained("NumbersStation/nsql-6B")
+
+            data_samples = context_queries
+
+            _context = {
+                "if patterns like 'current time' or 'now' occurs in question": "always use NOW() - INTERVAL",
+                "if patterns like 'total number', or 'List' occurs in question": "always use DISTINCT",
+            }
+
+            filtered_context = filter_samples(input_question, probable_qs=list(_context.keys()),
+                                              model_path='', threshold=0.845)
+            print(f"Filter Context: {filtered_context}")
+
+            contextual_context = []
+            for _item in filtered_context:
+                _val = _context.get(_item, None)
+                if _val:
+                    contextual_context.append(f"{_item}: {_val}")
+
+            print("Filtering Question/Query pairs")
+            _samples = filter_samples(input_question, probable_qs=sample_pairs,
+                                      model_path=local_model_path, threshold=0.90)
+
+            # If there are more than 5 QnA pairs, keep only 5 for focused context
+            if len(_samples) > 5:
+                _samples = _samples[0:5][::-1]
+            qna_samples = '\n'.join(_samples)
+
+            contextual_context_val = ', '.join(contextual_context)
+
+            if len(_samples) > 2:
+                # Check for the columns in the QnA samples provided; if they exist, keep them
+                context_columns = [_c for _c in column_names if _c.lower() in qna_samples.lower()]
+                if len(context_columns) > 0:
+                    contextual_data_samples = [_d for _cc in context_columns for _d in data_samples_list if _cc.lower() in _d.lower()]
+                    data_samples = contextual_data_samples
+            relevant_columns = context_columns if len(context_columns) > 0 else column_names
+            _data_info = ', '.join(relevant_columns)
+
+            query = prompt_template.format(table_name=_table_name, data_info=_data_info, data_info_detailed=data_samples,
+                                           sample_queries=qna_samples, context=contextual_context_val,
+                                           question_txt=input_question)
+
+            input_ids = tokenizer(query, return_tensors="pt").input_ids
        return res

    def task_formatter(self, input_task: str):
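Note that the new else branch is explicitly skeleton code: several names (sample_pairs, local_model_path, column_names, data_samples_list, prompt_template, _table_name) are not yet bound, and the branch builds input_ids but never runs generation, so `return res` would fail on the local path. A self-contained sketch of what the missing step might look like, assuming greedy decoding (the helper name and generation settings are illustrative, not part of this commit):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def complete_sql_locally(prompt: str, model_id: str = "NumbersStation/nsql-6B") -> str:
    """Complete a prompt that ends in 'SELECT' with a local NSQL model."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    with torch.no_grad():
        generated = model.generate(input_ids, max_new_tokens=200)
    # Decode only the newly generated tokens, then re-attach the SELECT prefix
    # that the prompt template ends with.
    new_tokens = generated[0][input_ids.shape[1]:]
    return "SELECT " + tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
```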
25 changes: 18 additions & 7 deletions sidekick/utils.py
@@ -95,18 +95,29 @@ def setup_dir(base_path: str):
p.mkdir(parents=True, exist_ok=True)


-def csv_parser(input_path: str):
+def read_sample_pairs(input_path: str, model_name: str = "nsql"):
    df = pd.read_csv(input_path)
    df = df.dropna()
    df = df.drop_duplicates()
    df = df.reset_index(drop=True)

-    # Convert frame to below format
-    # [
-    #  "# query": ""
-    #  "# answer": ""
-    # ]
-    res = df.apply(lambda row: f"# query: {row['query']}\n# answer: {row['answer']}", axis=1).to_list()
+    if model_name != 'nsql':
+        # OpenAI format
+        # Convert frame to the format below:
+        # [
+        #  "# query": ""
+        #  "# answer": ""
+        # ]
+        res = df.apply(lambda row: f"# query: {row['query']}\n# answer: {row['answer']}", axis=1).to_list()
+    else:
+        # NSQL format
+        # Convert frame to the format below:
+        # [
+        #  "Question": <question_text>
+        #  "Answer":
+        #  <response_text>
+        # ]
+        res = df.apply(lambda row: f"Question: {row['query']}\nAnswer:\n{row['answer']}", axis=1).to_list()
    return res


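A short usage sketch for the renamed helper; the CSV path and rows are hypothetical, while the 'query'/'answer' column names match the code above:

```python
import pandas as pd
from sidekick.utils import read_sample_pairs

# Build a tiny sample-pairs file for demonstration.
pd.DataFrame(
    {"query": ["How many users signed up?"], "answer": ["SELECT COUNT(*) FROM users"]}
).to_csv("/tmp/sample_pairs.csv", index=False)

pairs = read_sample_pairs("/tmp/sample_pairs.csv", model_name="nsql")
print(pairs[0])
# Question: How many users signed up?
# Answer:
# SELECT COUNT(*) FROM users
```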
