From a064450451cd11c28afb925fd7786295b2881910 Mon Sep 17 00:00:00 2001
From: pramitchoudhary
Date: Mon, 29 May 2023 19:53:26 -0700
Subject: [PATCH] Extract entity from Qs n fix tracking #1

---
 sidekick/memory.py   | 49 +++++++++++++++++++++++++++++---------------
 sidekick/prompter.py |  4 ++--
 sidekick/query.py    |  6 ++----
 3 files changed, 36 insertions(+), 23 deletions(-)

diff --git a/sidekick/memory.py b/sidekick/memory.py
index 4d06372..e23b299 100644
--- a/sidekick/memory.py
+++ b/sidekick/memory.py
@@ -1,5 +1,7 @@
 import json
-from typing import Dict, List
+import re
+from pathlib import Path
+from typing import Dict, List, Tuple
 
 # Reference: https://python.langchain.com/en/latest/modules/memory/examples/custom_memory.html
 
@@ -10,15 +12,14 @@ def __init__(self, k, path: str = None):
         self.track_entity: List = []
         self.path = path
 
-    def extract_entity():
+    def extract_entity(self, question: str, answer: str) -> Tuple[List, List]:
         # Currently, anything tagged between below tags are extracted respectively,
         # 1. From Input text: some key
         # 2. From Output text: some key
-        {
-            "CPU": "payload->'engineEvent'-> 'pausing' -> 'engine'->> 'cpu'",
-            "global usage DAI instance": "GREATEST((gpu * 4.0)) as ai_units",
-            "global usage": "GREATEST((GREATEST((ram_gi / 64.0), (cpu / 8.0)) - gpu), 0) + (gpu * 4.0) as ai_units",
-        }
+        # TODO Chat mode for auto extraction of entities
+        c_k = re.findall(r"(.+?)", question)
+        c_val = re.findall(r"(.+?)", answer)
+        return (c_k, c_val)
 
     def save_context(self, info: str, extract_context: bool = True) -> Dict:
         # Construct dictionary to record history
@@ -32,23 +33,37 @@ def save_context(self, info: str, extract_context: bool = True) -> Dict:
         response = " ".join(info.partition(":")[2].split(split_token)[1].partition(":")[2].strip().split())
         # TODO add additional guardrails to check if the response is a valid response.
         # At-least syntactically correct SQL.
-        if query.strip() and "SELECT".lower() in response.lower():
-            chat_history = {"Query": query, "Answer": response}
-            self.track_history.append(chat_history)
-        else:
-            raise ValueError("Response not valid. Please try again.")
+        # Check if entity extraction is enabled
         # Add logic for entity extraction
+        extracted_entity = None
         if extract_context:
-            # Generic logic pending
-            extracted_entity = {"GPU": "payload->'engineEvent'-> 'pausing' -> 'engine'->> 'gpu'"}
+            _k, _v = self.extract_entity(query, response)
+            k_v = " ".join(_k)
+            c_v = ", ".join(_v)
+            extracted_entity = {k_v: c_v}
             self.track_entity.append(extracted_entity)
+        chat_history = {}
+        if query.strip() and "SELECT".lower() in response.lower():
+            chat_history = {"Query": query, "Answer": response, "entity": extracted_entity}
+            self.track_history.append(chat_history)
+        else:
+            raise ValueError("Response not valid. Please try again.")
         # persist the information for future use
         res = {"history": self.track_history, "entity": self.track_entity}
-        with open(f"{self.path}/var/lib/tmp/data/history.jsonl", "a") as outfile:
-            for entry in self.track_history:
-                json.dump(entry, outfile)
+        # Persist added information locally
+        if chat_history:
+            with open(f"{self.path}/var/lib/tmp/data/history.jsonl", "a") as outfile:
+                json.dump(chat_history, outfile)
                 outfile.write("\n")
+        if extract_context:
+            # Update context.json file for tracking entities
+            content_file_path = f"{self.path}/var/lib/tmp/data/context.json"
+            if Path(content_file_path).exists():
+                context_dict = json.load(open(content_file_path, "r"))
+                context_dict.update(extracted_entity)
+            with open(content_file_path, "w") as outfile:
+                json.dump(context_dict, outfile, indent=4, sort_keys=False)
         return res
diff --git a/sidekick/prompter.py b/sidekick/prompter.py
index 3654dba..0c235ce 100644
--- a/sidekick/prompter.py
+++ b/sidekick/prompter.py
@@ -88,7 +88,7 @@ def _add_context(entity_memory: EntityMemory):
     # Check if user has entered any value
     if res:
         try:
-            entity_memory.save_context(res)
+            _ = entity_memory.save_context(res)
         except ValueError as ve:
             logger.info(f"Not a valid input. Try again")
 
@@ -116,8 +116,8 @@ def update_context():
     context_dict = json.load(open(content_file_path, "r"))
     context_dict[""] = "" in context_dict:
diff --git a/sidekick/query.py b/sidekick/query.py
index 1b02f7f..fac93f5 100644
--- a/sidekick/query.py
+++ b/sidekick/query.py
@@ -9,8 +9,7 @@
 from configs.prompt_template import DEBUGGING_PROMPT, QUERY_PROMPT, TASK_PROMPT
 from examples.sample_data import sample_values, samples_queries
 from langchain import OpenAI
-from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
-                         LLMPredictor, ServiceContext, SQLDatabase)
+from llama_index import GPTSimpleVectorIndex, GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase
 from llama_index.indices.struct_store import SQLContextContainerBuilder
 from loguru import logger
 from sqlalchemy import create_engine
@@ -45,7 +44,7 @@ def update_context_queries(self):
         new_context_queries = samples_queries
         history_file = f"{self.path}/var/lib/tmp/data/history.jsonl"
         if Path(history_file).exists():
-            with open(f"{self.path}/var/lib/tmp/data/history.jsonl", "r") as in_file:
+            with open(history_file, "r") as in_file:
                 for line in in_file:
                     # Format:
                     # """
@@ -56,7 +55,6 @@
                     response = json.loads(line)["Answer"]
                     _new_samples = f"""# query: {query}\n# answer: {response}"""
                     new_context_queries.append(_new_samples)
-
         return new_context_queries
 
     def _query_tasks(self, question_str, data_info, sample_queries, table_name: list):