Skip to content

Commit

Permalink
Removed redundant re import in mistral; minor linting using pre-commi…
Browse files Browse the repository at this point in the history
…t hooks
  • Loading branch information
andreped committed Feb 23, 2024
1 parent 886e143 commit e6032f6
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions src/vanna/mistral/mistral.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,36 @@
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

from ..base import VannaBase
import re


class Mistral(VannaBase):
def __init__(self, config=None):
if config is None:
raise ValueError("For Mistral, config must be provided with an api_key and model")
raise ValueError(
"For Mistral, config must be provided with an api_key and model"
)

if 'api_key' not in config:
if "api_key" not in config:
raise ValueError("config must contain a Mistral api_key")
if 'model' not in config:

if "model" not in config:
raise ValueError("config must contain a Mistral model")

api_key = config['api_key']
model = config['model']
api_key = config["api_key"]
model = config["model"]
self.client = MistralClient(api_key=api_key)
self.model = model

def system_message(self, message: str) -> any:
return ChatMessage(role="system", content=message)

def user_message(self, message: str) -> any:
return ChatMessage(role="user", content=message)

def assistant_message(self, message: str) -> any:
return ChatMessage(role="assistant", content=message)

def generate_sql(self, question: str, **kwargs) -> str:
# Use the super generate_sql
sql = super().generate_sql(question, **kwargs)
Expand All @@ -42,5 +45,5 @@ def submit_prompt(self, prompt, **kwargs) -> str:
model=self.model,
messages=prompt,
)

return chat_response.choices[0].message.content

0 comments on commit e6032f6

Please sign in to comment.