Skip to content

Commit

Permalink
Merge pull request #259 from andreped/ollama-fix
Browse files Browse the repository at this point in the history
Moved Ollama implementation to ollama.py; removed redundant import in mistral.py; minor linting in both
  • Loading branch information
zainhoda authored Mar 20, 2024
2 parents 919320f + b82a960 commit c8201f5
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 86 deletions.
25 changes: 14 additions & 11 deletions src/vanna/mistral/mistral.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,36 @@
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

from ..base import VannaBase
import re


class Mistral(VannaBase):
def __init__(self, config=None):
if config is None:
raise ValueError("For Mistral, config must be provided with an api_key and model")
raise ValueError(
"For Mistral, config must be provided with an api_key and model"
)

if 'api_key' not in config:
if "api_key" not in config:
raise ValueError("config must contain a Mistral api_key")
if 'model' not in config:

if "model" not in config:
raise ValueError("config must contain a Mistral model")

api_key = config['api_key']
model = config['model']
api_key = config["api_key"]
model = config["model"]
self.client = MistralClient(api_key=api_key)
self.model = model

def system_message(self, message: str) -> any:
return ChatMessage(role="system", content=message)

def user_message(self, message: str) -> any:
return ChatMessage(role="user", content=message)

def assistant_message(self, message: str) -> any:
return ChatMessage(role="assistant", content=message)

def generate_sql(self, question: str, **kwargs) -> str:
# Use the super generate_sql
sql = super().generate_sql(question, **kwargs)
Expand All @@ -42,5 +45,5 @@ def submit_prompt(self, prompt, **kwargs) -> str:
model=self.model,
messages=prompt,
)

return chat_response.choices[0].message.content
75 changes: 0 additions & 75 deletions src/vanna/ollama/__init__.py
Original file line number Diff line number Diff line change
@@ -1,75 +0,0 @@
from ..base import VannaBase
import requests
import json
import re

class Ollama(VannaBase):
def __init__(self, config=None):
if config is None or 'ollama_host' not in config:
self.host = "http://localhost:11434"
else:
self.host = config['ollama_host']

if config is None or 'model' not in config:
raise ValueError("config must contain a Ollama model")
else:
self.model = config['model']

def system_message(self, message: str) -> any:
return {"role": "system", "content": message}

def user_message(self, message: str) -> any:
return {"role": "user", "content": message}

def assistant_message(self, message: str) -> any:
return {"role": "assistant", "content": message}

def extract_sql_query(self, text):
"""
Extracts the first SQL statement after the word 'select', ignoring case,
matches until the first semicolon, three backticks, or the end of the string,
and removes three backticks if they exist in the extracted string.
Args:
- text (str): The string to search within for an SQL statement.
Returns:
- str: The first SQL statement found, with three backticks removed, or an empty string if no match is found.
"""
# Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string
pattern = re.compile(r'select.*?(?:;|```|$)', re.IGNORECASE | re.DOTALL)

match = pattern.search(text)
if match:
# Remove three backticks from the matched string if they exist
return match.group(0).replace('```', '')
else:
return text

def generate_sql(self, question: str, **kwargs) -> str:
# Use the super generate_sql
sql = super().generate_sql(question, **kwargs)

# Replace "\_" with "_"
sql = sql.replace("\\_", "_")

sql = sql.replace("\\", "")

return self.extract_sql_query(sql)

def submit_prompt(self, prompt, **kwargs) -> str:
url = f"{self.host}/api/chat"
data = {
"model": self.model,
"stream": False,
"messages": prompt,
}

response = requests.post(url, json=data)

response_dict = response.json()

self.log(response.text)

return response_dict['message']['content']

77 changes: 77 additions & 0 deletions src/vanna/ollama/ollama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import json
import re

import requests

from ..base import VannaBase


class Ollama(VannaBase):
def __init__(self, config=None):
if config is None or "ollama_host" not in config:
self.host = "http://localhost:11434"
else:
self.host = config["ollama_host"]

if config is None or "model" not in config:
raise ValueError("config must contain a Ollama model")
else:
self.model = config["model"]

def system_message(self, message: str) -> any:
return {"role": "system", "content": message}

def user_message(self, message: str) -> any:
return {"role": "user", "content": message}

def assistant_message(self, message: str) -> any:
return {"role": "assistant", "content": message}

def extract_sql_query(self, text):
"""
Extracts the first SQL statement after the word 'select', ignoring case,
matches until the first semicolon, three backticks, or the end of the string,
and removes three backticks if they exist in the extracted string.
Args:
- text (str): The string to search within for an SQL statement.
Returns:
- str: The first SQL statement found, with three backticks removed, or an empty string if no match is found.
"""
# Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string
pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL)

match = pattern.search(text)
if match:
# Remove three backticks from the matched string if they exist
return match.group(0).replace("```", "")
else:
return text

def generate_sql(self, question: str, **kwargs) -> str:
# Use the super generate_sql
sql = super().generate_sql(question, **kwargs)

# Replace "\_" with "_"
sql = sql.replace("\\_", "_")

sql = sql.replace("\\", "")

return self.extract_sql_query(sql)

def submit_prompt(self, prompt, **kwargs) -> str:
url = f"{self.host}/api/chat"
data = {
"model": self.model,
"stream": False,
"messages": prompt,
}

response = requests.post(url, json=data)

response_dict = response.json()

self.log(response.text)

return response_dict["message"]["content"]

0 comments on commit c8201f5

Please sign in to comment.