Skip to content

Commit

Permalink
Pending h2ogpt integration for vulnerability scanning & Q generation #4
Browse files Browse the repository at this point in the history
  • Loading branch information
pramitchoudhary committed Feb 6, 2024
1 parent 5a19a4b commit dafce22
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 43 deletions.
3 changes: 3 additions & 0 deletions sidekick/configs/env.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ H2OGPT_URL = 'http://38.128.233.247'
H2OGPT_API_TOKEN = ""
H2OGPTE_URL = ""
H2OGPTE_API_TOKEN = ""
H2OGPT_BASE_URL = "https://api.gpt.h2o.ai/v1"
H2OGPT_BASE_API_TOKEN = ""


RECOMMENDATION_MODEL = "h2oai/h2ogpt-4096-llama2-70b-chat"
VULNERABILITY_SCANNER = "h2oai/h2ogpt-4096-llama2-70b-chat" # other options openai models depending on availability (e.g. 'gpt-3.5-turbo')
Expand Down
38 changes: 22 additions & 16 deletions sidekick/prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from colorama import Back as B
from colorama import Fore as F
from colorama import Style
from dotenv import load_dotenv
from pandasql import sqldf
from sidekick.db_config import DBConfig
from sidekick.logger import logger
Expand All @@ -34,23 +35,33 @@
model_name = env_settings["MODEL_INFO"]["MODEL_NAME"]
h2o_remote_url = env_settings["MODEL_INFO"]["H2OGPTE_URL"]
h2o_key = env_settings["MODEL_INFO"]["H2OGPTE_API_TOKEN"]
# h2ogpt base model urls
h2ogpt_base_model_url = env_settings["MODEL_INFO"]["H2OGPT_URL"]
h2ogpt_base_model_key = env_settings["MODEL_INFO"]["H2OGPT_API_TOKEN"]
# h2ogpt code-sql model urls
h2ogpt_sql_model_url = env_settings["MODEL_INFO"]["H2OGPT_URL"]
h2ogpt_sql_model_key = env_settings["MODEL_INFO"]["H2OGPT_API_TOKEN"]

h2ogpt_base_model_url = env_settings["MODEL_INFO"]["H2OGPT_BASE_URL"]
h2ogpt_base_model_key = env_settings["MODEL_INFO"]["H2OGPT_BASE_API_TOKEN"]

self_correction_model = env_settings["MODEL_INFO"]["SELF_CORRECTION_MODEL"]
recommendation_model = env_settings["MODEL_INFO"]['RECOMMENDATION_MODEL']

# Load .env file
load_dotenv()

os.environ["TOKENIZERS_PARALLELISM"] = "False"
# Env variables
if not os.getenv("H2OGPT_URL"):
os.environ["H2OGPT_URL"] = h2ogpt_base_model_url
os.environ["H2OGPT_URL"] = h2ogpt_sql_model_url
if not os.getenv("H2OGPT_API_TOKEN"):
os.environ["H2OGPT_API_TOKEN"] = h2ogpt_base_model_key
os.environ["H2OGPT_API_TOKEN"] = h2ogpt_sql_model_key
if not os.getenv("H2OGPTE_URL"):
os.environ["H2OGPTE_URL"] = h2o_remote_url
if not os.getenv("H2OGPTE_API_TOKEN"):
os.environ["H2OGPTE_API_TOKEN"] = h2o_key
if not os.getenv("H2OGPT_BASE_URL"):
os.environ["H2OGPT_BASE_URL"] = h2ogpt_base_model_url
if not os.getenv("H2OGPT_BASE_API_TOKEN"):
os.environ["H2OGPT_BASE_API_TOKEN"] = h2ogpt_base_model_key
if not os.getenv("SELF_CORRECTION_MODEL"):
os.environ["SELF_CORRECTION_MODEL"] = self_correction_model
if not os.getenv("RECOMMENDATION_MODEL"):
Expand Down Expand Up @@ -163,17 +174,12 @@ def recommend_suggestions(cache_path: str, table_name: str, n_qs: int=10):
r_url = _key = None
# First check for keys in env variables
logger.debug(f"Checking environment settings ...")
env_url = os.environ["H2OGPTE_URL"]
env_key = os.environ["H2OGPTE_API_TOKEN"]
if env_url and env_key:
r_url = env_url
_key = env_key
elif Path(f"{app_base_path}/sidekick/configs/env.toml").exists():
# Reload .env info
logger.debug(f"Checking configuration file ...")
env_settings = toml.load(f"{app_base_path}/sidekick/configs/env.toml")
r_url = env_settings["MODEL_INFO"]["H2OGPTE_URL"]
_key = env_settings["MODEL_INFO"]["H2OGPTE_API_TOKEN"]
r_url = os.getenv("H2OGPTE_URL", None)
_key = os.getenv("H2OGPTE_API_TOKEN", None)
if not r_url or not _key:
logger.info(f"H2OGPTE client is not configured, attempting to use OSS H2OGPT client")
r_url = os.getenv("H2OGPT_BASE_URL", None)
_key = os.getenv("H2OGPT_BASE_API_TOKEN", None)
else:
raise Exception("Model url or key is missing.")

Expand Down
48 changes: 35 additions & 13 deletions sidekick/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,13 +247,30 @@ def self_correction(self, error_msg, input_query, remote_url, client_key):
_res = input_query
self_correction_model = os.getenv("SELF_CORRECTION_MODEL", "h2oai/h2ogpt-4096-llama2-70b-chat")
if "h2ogpt-" in self_correction_model:
from h2ogpte import H2OGPTE
client = H2OGPTE(address=remote_url, api_key=client_key)
text_completion = client.answer_question(
system_prompt=system_prompt,
text_context_list=[],
question=user_prompt,
llm=self_correction_model)
if remote_url and client_key:
try:
from h2ogpte import H2OGPTE
client = H2OGPTE(address=remote_url, api_key=client_key)
text_completion = client.answer_question(
system_prompt=system_prompt,
text_context_list=[],
question=user_prompt,
llm=self_correction_model)
except Exception as e:
logger.info(f"H2OGPTE client is not configured, reach out if API key is needed, {e}. Attempting to use H2OGPT client")
# Make attempt to use h2ogpt client with OSS access
_api_key = client_key if client_key else "EMPTY"
client_args = dict(base_url=remote_url, api_key=_api_key, timeout=20.0)
query_msg = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
h2ogpt_base_client = OpenAI(**client_args)
completion = h2ogpt_base_client.with_options(max_retries=3).chat.completions.create(
model=self_correction_model,
messages=query_msg,
max_tokens=512,
temperature=0.5,
stop="```",
seed=42)
text_completion = completion.choices[0].message
_response = text_completion.content
elif 'gpt-3.5' in self_correction_model.lower() or 'gpt-4' in self_correction_model.lower():
# Check if the API key is set, else inform user
Expand All @@ -268,7 +285,7 @@ def self_correction(self, error_msg, input_query, remote_url, client_key):
)
_response = completion.choices[0].message.content
else:
raise ValueError(f"Invalid model name: {self_correction_model}")
raise ValueError(f"Invalid request for: {self_correction_model}")

_response = _response.split("```sql")
_idx = [_response.index(_r) for _r in _response if _r.lower().strip().startswith("select")]
Expand Down Expand Up @@ -775,19 +792,24 @@ def generate_sql(
# Validate the generate SQL for parsing errors, along with dialect specific validation
# Note: Doesn't do well with handling date-time conversions
# e.g.
# sqlite: SELECT DATETIME(MAX(timestamp), '-5 minute') FROM demo WHERE isin_id = 'VM88109EGG92'
# postgres: SELECT MAX(timestamp) - INTERVAL '5 minutes' FROM demo where isin_id='VM88109EGG92'
# sqlite: SELECT DATETIME(MAX(timestamp), '-5 minute') FROM demo WHERE isin_id = 'VM123'
# postgres: SELECT MAX(timestamp) - INTERVAL '5 minutes' FROM demo where isin_id='VM123'
# Reference ticket: https://github.com/tobymao/sqlglot/issues/2011
result = res
try:
result = sqlglot.transpile(res, identify=True, write=self.dialect)[0] if res else None
except (sqlglot.errors.ParseError, ValueError, RuntimeError) as e:
_, ex_value, ex_traceback = sys.exc_info()
logger.info(f"Attempting to fix syntax error ...,\n {e}")
env_url = os.environ["H2OGPTE_URL"]
env_key = os.environ["H2OGPTE_API_TOKEN"]

h2o_client_url = os.getenv("H2OGPTE_URL", None)
h2o_client_key = os.getenv("H2OGPTE_API_TOKEN", None)
if not h2o_client_url or not h2o_client_key:
logger.info(f"H2OGPTE client is not configured, attempting to use OSS H2OGPT client")
h2o_client_url = os.getenv("H2OGPT_BASE_URL", None)
h2o_client_key = os.getenv("H2OGPT_BASE_API_TOKEN", None)
try:
result = self.self_correction(input_query=res, error_msg=str(ex_traceback), remote_url=env_url, client_key=env_key)
result = self.self_correction(input_query=res, error_msg=str(ex_traceback), remote_url=h2o_client_url, client_key=h2o_client_key)
except Exception as se:
# Another exception occurred, return the original SQL
logger.info(f"We did the best we could to fix syntactical error, there might be still be some issues:\n {se}")
Expand Down
60 changes: 46 additions & 14 deletions sidekick/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,9 +548,13 @@ def check_vulnerability(input_query: str):
# Step 2 is optional, if remote url is provided, check for SQL injection patterns in the generated SQL code via LLM
# Currently, only support only for models as an endpoints
logger.debug(f"Requesting additional scan using configured models")
remote_url = os.environ["H2OGPTE_URL"]
api_key = os.environ["H2OGPTE_API_TOKEN"]

h2ogpt_client_url = h2ogpt_client_key = None
h2ogpte_client_url = os.getenv("H2OGPTE_URL", None)
h2ogpte_client_key = os.getenv("H2OGPTE_API_TOKEN", None)
if not h2ogpte_client_url or not h2ogpte_client_key:
logger.info(f"H2OGPTE client is not configured, attempting to use OSS H2OGPT client")
h2ogpt_client_url = os.getenv("H2OGPT_BASE_URL", None)
h2ogpt_client_key = os.getenv("H2OGPT_BASE_API_TOKEN", None)
_system_prompt = GUARDRAIL_PROMPT["system_prompt"].strip()
output_schema = """{
"type": "object",
Expand All @@ -567,15 +571,28 @@ def check_vulnerability(input_query: str):
temp_result = None
try:
llm_scanner = os.getenv("VULNERABILITY_SCANNER", "h2oai/h2ogpt-4096-llama2-70b-chat")
if "h2ogpt-" in llm_scanner:
if "h2ogpt-" in llm_scanner and h2ogpte_client_url !='' and h2ogpte_client_url and h2ogpte_client_key != '' and h2ogpte_client_key:
from h2ogpte import H2OGPTE
client = H2OGPTE(address=remote_url, api_key=api_key)
client = H2OGPTE(address=h2ogpte_client_url, api_key=h2ogpte_client_key)
text_completion = client.answer_question(
system_prompt=_system_prompt,
text_context_list=[],
question=_user_prompt,
llm=llm_scanner)
generated_res = text_completion.content.split("\n\n")[0]
elif h2ogpt_client_url:
_api_key = h2ogpt_client_key if h2ogpt_client_key else "EMPTY"
client_args = dict(base_url=h2ogpt_client_url, api_key=_api_key, timeout=20.0)
query_msg = [{"role": "system", "content": _system_prompt}, {"role": "user", "content": _user_prompt}]
h2ogpt_base_client = OpenAI(**client_args)
completion = h2ogpt_base_client.with_options(max_retries=3).chat.completions.create(
model=llm_scanner,
messages=query_msg,
max_tokens=512,
temperature=0.5,
stop="```",
seed=42)
generated_res = completion.choices[0].message.content.split("\n\n")[0]
elif 'gpt-3.5' in llm_scanner.lower() or 'gpt-4' in llm_scanner.lower():
# Check if the API key is set, else inform user
query_msg = [{"role": "system", "content": _system_prompt}, {"role": "user", "content": _user_prompt}]
Expand All @@ -590,7 +607,7 @@ def check_vulnerability(input_query: str):
)
generated_res = completion.choices[0].message.content
else:
raise ValueError(f"Invalid model name: {llm_scanner}")
raise ValueError(f"Invalid request for: {llm_scanner}")

_res = generated_res.strip()
temp_result = json.loads(_res) if _res else None
Expand All @@ -615,17 +632,32 @@ def generate_suggestions(remote_url, client_key:str, column_names: list, n_qs: i
results = "Currently not supported or remote API key is missing."
else:
column_info = ','.join(column_names)
input_prompt = RECOMMENDATION_PROMPT.format(data_schema=column_info, n_questions=n_qs
_system_prompt = f"Act as a data analyst, based on below data schema help answer the question"
_user_prompt = RECOMMENDATION_PROMPT.format(data_schema=column_info, n_questions=n_qs
)

recommender_model = os.getenv("RECOMMENDATION_MODEL", "h2oai/h2ogpt-4096-llama2-70b-chat")
client = H2OGPTE(address=remote_url, api_key=client_key)
text_completion = client.answer_question(
system_prompt=f"Act as a data analyst, based on below data schema help answer the question",
text_context_list=[],
question=input_prompt,
llm=recommender_model
)
try:
client = H2OGPTE(address=remote_url, api_key=client_key)
text_completion = client.answer_question(
system_prompt=_system_prompt,
text_context_list=[],
question=_user_prompt,
llm=recommender_model
)
except Exception as e:
logger.info(f"H2OGPTE client is not configured, reach out if API key is needed. {e}. Attempting to use H2OGPT client")
# Make attempt to use h2ogpt client with OSS access
client_args = dict(base_url=remote_url, api_key=client_key, timeout=20.0)
query_msg = [{"role": "system", "content": _system_prompt}, {"role": "user", "content": _user_prompt}]
h2ogpt_base_client = OpenAI(**client_args)
completion = h2ogpt_base_client.with_options(max_retries=3).chat.completions.create(
model=recommender_model,
messages=query_msg,
max_tokens=512,
temperature=0.5,
seed=42)
text_completion = completion.choices[0].message
_res = text_completion.content.split("\n")[2:]
results = "\n".join(_res)
return results

0 comments on commit dafce22

Please sign in to comment.