diff --git a/src/vanna/mistral/mistral.py b/src/vanna/mistral/mistral.py index 627c46e7..f5c89ac8 100644 --- a/src/vanna/mistral/mistral.py +++ b/src/vanna/mistral/mistral.py @@ -1,33 +1,36 @@ from mistralai.client import MistralClient from mistralai.models.chat_completion import ChatMessage + from ..base import VannaBase -import re + class Mistral(VannaBase): def __init__(self, config=None): if config is None: - raise ValueError("For Mistral, config must be provided with an api_key and model") + raise ValueError( + "For Mistral, config must be provided with an api_key and model" + ) - if 'api_key' not in config: + if "api_key" not in config: raise ValueError("config must contain a Mistral api_key") - - if 'model' not in config: + + if "model" not in config: raise ValueError("config must contain a Mistral model") - api_key = config['api_key'] - model = config['model'] + api_key = config["api_key"] + model = config["model"] self.client = MistralClient(api_key=api_key) self.model = model def system_message(self, message: str) -> any: return ChatMessage(role="system", content=message) - + def user_message(self, message: str) -> any: return ChatMessage(role="user", content=message) - + def assistant_message(self, message: str) -> any: return ChatMessage(role="assistant", content=message) - + def generate_sql(self, question: str, **kwargs) -> str: # Use the super generate_sql sql = super().generate_sql(question, **kwargs) @@ -42,5 +45,5 @@ def submit_prompt(self, prompt, **kwargs) -> str: model=self.model, messages=prompt, ) - + return chat_response.choices[0].message.content