Skip to content

Commit

Permalink
Remove <key><value> tags before persisting #1
Browse files Browse the repository at this point in the history
  • Loading branch information
pramitchoudhary committed May 31, 2023
1 parent 4742c5c commit 569c8a9
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
11 changes: 11 additions & 0 deletions sidekick/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ def save_context(self, info: str, extract_context: bool = True) -> Dict:

chat_history = {}
if query.strip() and "SELECT".lower() in response.lower():
# Remove <key> and <value> tags from the query/response before persisting
query = (
query.lower().replace("<key>", "").replace("</key>", "").replace("<value>", "").replace("</value>", "")
)
response = (
response.lower()
.replace("<key>", "")
.replace("</key>", "")
.replace("<value>", "")
.replace("</value>", "")
)
chat_history = {"Query": query, "Answer": response, "entity": extracted_entity}
self.track_history.append(chat_history)
else:
Expand Down
6 changes: 4 additions & 2 deletions sidekick/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from configs.prompt_template import DEBUGGING_PROMPT, QUERY_PROMPT, TASK_PROMPT
from examples.sample_data import sample_values, samples_queries
from langchain import OpenAI
from llama_index import GPTSimpleVectorIndex, GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase
from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
LLMPredictor, ServiceContext, SQLDatabase)
from llama_index.indices.struct_store import SQLContextContainerBuilder
from loguru import logger
from sqlalchemy import create_engine
Expand Down Expand Up @@ -70,9 +71,9 @@ def _query_tasks(self, question_str, data_info, sample_queries, table_name: list
_context=str(additional_context).lower(),
_question_str=question_str,
)

# Role and content
query_txt = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
logger.debug(f"Query Text:\n {query_txt}")
# TODO ADD local model
completion = openai.ChatCompletion.create(
model="gpt-3.5-turbo-0301",
Expand Down Expand Up @@ -157,6 +158,7 @@ def generate_sql(self, table_name: str, input_question: str, _dialect: str = "po
)

logger.debug(f"Query Prompt:\n{query_str}")

table_context_dict = {str(table_name[0]).lower(): str(additional_context).lower()}
self.context_builder = SQLContextContainerBuilder(self.sql_database, context_dict=table_context_dict)

Expand Down

0 comments on commit 569c8a9

Please sign in to comment.