Skip to content

Commit

Permalink
Extract entity from Qs and fix tracking #1
Browse files Browse the repository at this point in the history
  • Loading branch information
pramitchoudhary committed May 30, 2023
1 parent a4ca238 commit a064450
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 23 deletions.
49 changes: 32 additions & 17 deletions sidekick/memory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
from typing import Dict, List
import re
from pathlib import Path
from typing import Dict, List, Tuple


# Reference: https://python.langchain.com/en/latest/modules/memory/examples/custom_memory.html
Expand All @@ -10,15 +12,14 @@ def __init__(self, k, path: str = None):
self.track_entity: List = []
self.path = path

def extract_entity():
def extract_entity(self, question: str, answer: str) -> Tuple[List, List]:
    """Extract tagged entities from a question/answer pair.

    Anything wrapped in the tags below is extracted:
      1. From the input text:  <key> some key </key>
      2. From the output text: <value> some value </value>

    (The tags are ``<key>``/``<value>`` — kept in sync with the regexes below.)

    Args:
        question: Raw input text, possibly containing ``<key>...</key>`` tags.
        answer: Raw output text, possibly containing ``<value>...</value>`` tags.

    Returns:
        A ``(keys, values)`` tuple of lists; either list may be empty when
        no matching tags are present.
    """
    # TODO Chat mode for auto extraction of entities
    # Non-greedy match so multiple tag pairs on one line are captured separately.
    context_keys = re.findall(r"<key>(.+?)</key>", question)
    context_values = re.findall(r"<value>(.+?)</value>", answer)
    return (context_keys, context_values)

def save_context(self, info: str, extract_context: bool = True) -> Dict:
# Construct dictionary to record history
Expand All @@ -32,23 +33,37 @@ def save_context(self, info: str, extract_context: bool = True) -> Dict:
response = " ".join(info.partition(":")[2].split(split_token)[1].partition(":")[2].strip().split())
# TODO add additional guardrails to check if the response is a valid response.
# At-least syntactically correct SQL.
if query.strip() and "SELECT".lower() in response.lower():
chat_history = {"Query": query, "Answer": response}
self.track_history.append(chat_history)
else:
raise ValueError("Response not valid. Please try again.")

# Check if entity extraction is enabled
# Add logic for entity extraction
extracted_entity = None
if extract_context:
# Generic logic pending
extracted_entity = {"GPU": "payload->'engineEvent'-> 'pausing' -> 'engine'->> 'gpu'"}
_k, _v = self.extract_entity(query, response)
k_v = " ".join(_k)
c_v = ", ".join(_v)
extracted_entity = {k_v: c_v}
self.track_entity.append(extracted_entity)

chat_history = {}
if query.strip() and "SELECT".lower() in response.lower():
chat_history = {"Query": query, "Answer": response, "entity": extracted_entity}
self.track_history.append(chat_history)
else:
raise ValueError("Response not valid. Please try again.")
# persist the information for future use
res = {"history": self.track_history, "entity": self.track_entity}

with open(f"{self.path}/var/lib/tmp/data/history.jsonl", "a") as outfile:
for entry in self.track_history:
json.dump(entry, outfile)
# Persist added information locally
if chat_history:
with open(f"{self.path}/var/lib/tmp/data/history.jsonl", "a") as outfile:
json.dump(chat_history, outfile)
outfile.write("\n")
if extract_context:
# Update context.json file for tracking entities
content_file_path = f"{self.path}/var/lib/tmp/data/context.json"
if Path(content_file_path).exists():
context_dict = json.load(open(content_file_path, "r"))
context_dict.update(extracted_entity)
with open(content_file_path, "w") as outfile:
json.dump(context_dict, outfile, indent=4, sort_keys=False)
return res
4 changes: 2 additions & 2 deletions sidekick/prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _add_context(entity_memory: EntityMemory):
# Check if user has entered any value
if res:
try:
entity_memory.save_context(res)
_ = entity_memory.save_context(res)
except ValueError as ve:
logger.info(f"Not a valid input. Try again")

Expand Down Expand Up @@ -116,8 +116,8 @@ def update_context():
context_dict = json.load(open(content_file_path, "r"))
context_dict["<new_context_key>"] = "<new_context_value"
context_str = json.dumps(context_dict, indent=4, sort_keys=True, default=str)

updated_context = click.edit(context_str)
logger.debug(f"Context:\n {updated_context}")
if updated_context:
context_dict = json.loads(updated_context)
if "<new_context_key>" in context_dict:
Expand Down
6 changes: 2 additions & 4 deletions sidekick/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from configs.prompt_template import DEBUGGING_PROMPT, QUERY_PROMPT, TASK_PROMPT
from examples.sample_data import sample_values, samples_queries
from langchain import OpenAI
from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
LLMPredictor, ServiceContext, SQLDatabase)
from llama_index import GPTSimpleVectorIndex, GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase
from llama_index.indices.struct_store import SQLContextContainerBuilder
from loguru import logger
from sqlalchemy import create_engine
Expand Down Expand Up @@ -45,7 +44,7 @@ def update_context_queries(self):
new_context_queries = samples_queries
history_file = f"{self.path}/var/lib/tmp/data/history.jsonl"
if Path(history_file).exists():
with open(f"{self.path}/var/lib/tmp/data/history.jsonl", "r") as in_file:
with open(history_file, "r") as in_file:
for line in in_file:
# Format:
# """
Expand All @@ -56,7 +55,6 @@ def update_context_queries(self):
response = json.loads(line)["Answer"]
_new_samples = f"""# query: {query}\n# answer: {response}"""
new_context_queries.append(_new_samples)

return new_context_queries

def _query_tasks(self, question_str, data_info, sample_queries, table_name: list):
Expand Down

0 comments on commit a064450

Please sign in to comment.