From 886e143f080599f8b43602d48464737a07612503 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Fri, 23 Feb 2024 15:35:55 +0100 Subject: [PATCH 1/2] Moved Ollama class from __init__.py to ollama.py; linted code --- src/vanna/ollama/__init__.py | 75 ----------------------------------- src/vanna/ollama/ollama.py | 77 ++++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 75 deletions(-) diff --git a/src/vanna/ollama/__init__.py b/src/vanna/ollama/__init__.py index d0aee460..e69de29b 100644 --- a/src/vanna/ollama/__init__.py +++ b/src/vanna/ollama/__init__.py @@ -1,75 +0,0 @@ -from ..base import VannaBase -import requests -import json -import re - -class Ollama(VannaBase): - def __init__(self, config=None): - if config is None or 'ollama_host' not in config: - self.host = "http://localhost:11434" - else: - self.host = config['ollama_host'] - - if config is None or 'model' not in config: - raise ValueError("config must contain a Ollama model") - else: - self.model = config['model'] - - def system_message(self, message: str) -> any: - return {"role": "system", "content": message} - - def user_message(self, message: str) -> any: - return {"role": "user", "content": message} - - def assistant_message(self, message: str) -> any: - return {"role": "assistant", "content": message} - - def extract_sql_query(self, text): - """ - Extracts the first SQL statement after the word 'select', ignoring case, - matches until the first semicolon, three backticks, or the end of the string, - and removes three backticks if they exist in the extracted string. - - Args: - - text (str): The string to search within for an SQL statement. - - Returns: - - str: The first SQL statement found, with three backticks removed, or an empty string if no match is found. - """ - # Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string - pattern = re.compile(r'select.*?(?:;|```|$)', re.IGNORECASE | re.DOTALL) - - match = pattern.search(text) - if match: - # Remove three backticks from the matched string if they exist - return match.group(0).replace('```', '') - else: - return text - - def generate_sql(self, question: str, **kwargs) -> str: - # Use the super generate_sql - sql = super().generate_sql(question, **kwargs) - - # Replace "\_" with "_" - sql = sql.replace("\\_", "_") - - sql = sql.replace("\\", "") - - return self.extract_sql_query(sql) - - def submit_prompt(self, prompt, **kwargs) -> str: - url = f"{self.host}/api/chat" - data = { - "model": self.model, - "stream": False, - "messages": prompt, - } - - response = requests.post(url, json=data) - - response_dict = response.json() - - self.log(response.text) - - return response_dict['message']['content'] - diff --git a/src/vanna/ollama/ollama.py b/src/vanna/ollama/ollama.py index e69de29b..8644fbdf 100644 --- a/src/vanna/ollama/ollama.py +++ b/src/vanna/ollama/ollama.py @@ -0,0 +1,77 @@ +import json +import re + +import requests + +from ..base import VannaBase + + +class Ollama(VannaBase): + def __init__(self, config=None): + if config is None or "ollama_host" not in config: + self.host = "http://localhost:11434" + else: + self.host = config["ollama_host"] + + if config is None or "model" not in config: + raise ValueError("config must contain a Ollama model") + else: + self.model = config["model"] + + def system_message(self, message: str) -> any: + return {"role": "system", "content": message} + + def user_message(self, message: str) -> any: + return {"role": "user", "content": message} + + def assistant_message(self, message: str) -> any: + return {"role": "assistant", "content": message} + + def extract_sql_query(self, text): + """ + Extracts the first SQL statement after the word 'select', ignoring case, + matches until the first semicolon, three backticks, or the end of the string, + and removes three backticks if they exist in the extracted string. + + Args: + - text (str): The string to search within for an SQL statement. + + Returns: + - str: The first SQL statement found, with three backticks removed, or an empty string if no match is found. + """ + # Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string + pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL) + + match = pattern.search(text) + if match: + # Remove three backticks from the matched string if they exist + return match.group(0).replace("```", "") + else: + return text + + def generate_sql(self, question: str, **kwargs) -> str: + # Use the super generate_sql + sql = super().generate_sql(question, **kwargs) + + # Replace "\_" with "_" + sql = sql.replace("\\_", "_") + + sql = sql.replace("\\", "") + + return self.extract_sql_query(sql) + + def submit_prompt(self, prompt, **kwargs) -> str: + url = f"{self.host}/api/chat" + data = { + "model": self.model, + "stream": False, + "messages": prompt, + } + + response = requests.post(url, json=data) + + response_dict = response.json() + + self.log(response.text) + + return response_dict["message"]["content"] From e6032f660d8d34e5f013d3e3849edf928c2d7efb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Pedersen?= Date: Fri, 23 Feb 2024 15:37:00 +0100 Subject: [PATCH 2/2] Removed redundant re import in mistral; minor linting using pre-commit hooks --- src/vanna/mistral/mistral.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/src/vanna/mistral/mistral.py b/src/vanna/mistral/mistral.py index 627c46e7..f5c89ac8 100644 --- a/src/vanna/mistral/mistral.py +++ b/src/vanna/mistral/mistral.py @@ -1,33 +1,36 @@ from mistralai.client import MistralClient from mistralai.models.chat_completion import ChatMessage + from ..base import VannaBase -import re + class Mistral(VannaBase): def __init__(self, config=None): if config is None: - raise ValueError("For Mistral, config must be provided with an api_key and model") + raise ValueError( + "For Mistral, config must be provided with an api_key and model" + ) - if 'api_key' not in config: + if "api_key" not in config: raise ValueError("config must contain a Mistral api_key") - - if 'model' not in config: + + if "model" not in config: raise ValueError("config must contain a Mistral model") - api_key = config['api_key'] - model = config['model'] + api_key = config["api_key"] + model = config["model"] self.client = MistralClient(api_key=api_key) self.model = model def system_message(self, message: str) -> any: return ChatMessage(role="system", content=message) - + def user_message(self, message: str) -> any: return ChatMessage(role="user", content=message) - + def assistant_message(self, message: str) -> any: return ChatMessage(role="assistant", content=message) - + def generate_sql(self, question: str, **kwargs) -> str: # Use the super generate_sql sql = super().generate_sql(question, **kwargs) @@ -42,5 +45,5 @@ def submit_prompt(self, prompt, **kwargs) -> str: model=self.model, messages=prompt, ) - + return chat_response.choices[0].message.content