Skip to content

Commit

Permalink
Merge pull request #287 from kun321/support-claude
Browse files Browse the repository at this point in the history
add support for anthropic claude
  • Loading branch information
zainhoda authored Mar 27, 2024
2 parents c8201f5 + a2575f7 commit 790bfec
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 8 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ mysql = ["PyMySQL"]
bigquery = ["google-cloud-bigquery"]
snowflake = ["snowflake-connector-python"]
duckdb = ["duckdb"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic"]
test = ["tox"]
chromadb = ["chromadb"]
openai = ["openai"]
mistralai = ["mistralai"]
anthropic = ["anthropic"]
gemini = ["google-generativeai"]
marqo = ["marqo"]
Empty file added src/vanna/anthropic/__init__.py
Empty file.
78 changes: 78 additions & 0 deletions src/vanna/anthropic/anthropic_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import os

import anthropic

from ..base import VannaBase


class Anthropic_Chat(VannaBase):
def __init__(self, client=None, config=None):
VannaBase.__init__(self, config=config)

if client is not None:
self.client = client
return

if config is None and client is None:
self.client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
return

# default parameters - can be overrided using config
self.temperature = 0.7
self.max_tokens = 500

if "temperature" in config:
self.temperature = config["temperature"]

if "max_tokens" in config:
self.max_tokens = config["max_tokens"]

if "api_key" in config:
self.client = anthropic.Anthropic(api_key=config["api_key"])

def system_message(self, message: str) -> any:
return {"role": "system", "content": message}

def user_message(self, message: str) -> any:
return {"role": "user", "content": message}

def assistant_message(self, message: str) -> any:
return {"role": "assistant", "content": message}

def submit_prompt(self, prompt, **kwargs) -> str:
if prompt is None:
raise Exception("Prompt is None")

if len(prompt) == 0:
raise Exception("Prompt is empty")

# Count the number of tokens in the message log
# Use 4 as an approximation for the number of characters per token
num_tokens = 0
for message in prompt:
num_tokens += len(message["content"]) / 4

if self.config is not None and "model" in self.config:
print(
f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
)
# claude required system message is a single filed
# https://docs.anthropic.com/claude/reference/messages_post
system_message = ''
no_system_prompt = []
for prompt_message in prompt:
role = prompt_message['role']
if role == 'system':
system_message = prompt_message['content']
else:
no_system_prompt.append({"role": role, "content": prompt_message['content']})

response = self.client.messages.create(
model=self.config["model"],
messages=no_system_prompt,
system=system_message,
max_tokens=self.max_tokens,
temperature=self.temperature,
)

return response.content[0].text
33 changes: 26 additions & 7 deletions tests/test_vanna.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from vanna.openai.openai_chat import OpenAI_Chat
from vanna.vannadb.vannadb_vector import VannaDB_VectorStore
import os

from vanna.anthropic.anthropic_chat import Anthropic_Chat
from vanna.mistral.mistral import Mistral
from vanna.openai.openai_chat import OpenAI_Chat
from vanna.remote import VannaDefault


import os
from vanna.vannadb.vannadb_vector import VannaDB_VectorStore

try:
print("Trying to load .env")
Expand All @@ -15,9 +15,11 @@
pass

MY_VANNA_MODEL = 'chinook'
ANTHROPIC_Model = 'claude-3-sonnet-20240229'
MY_VANNA_API_KEY = os.environ['VANNA_API_KEY']
OPENAI_API_KEY = os.environ['OPENAI_API_KEY']
MISTRAL_API_KEY = os.environ['MISTRAL_API_KEY']
ANTHROPIC_API_KEY = os.environ['ANTHROPIC_API_KEY']

class VannaOpenAI(VannaDB_VectorStore, OpenAI_Chat):
def __init__(self, config=None):
Expand Down Expand Up @@ -53,8 +55,9 @@ def test_vn_default():
df = vn_default.run_sql(sql)
assert len(df) == 6

from vanna.openai.openai_chat import OpenAI_Chat
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
from vanna.openai.openai_chat import OpenAI_Chat


class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
def __init__(self, config=None):
Expand All @@ -72,4 +75,20 @@ def test_vn_chroma():

sql = vn_chroma.generate_sql("What are the top 7 customers by sales?")
df = vn_chroma.run_sql(sql)
assert len(df) == 7
assert len(df) == 7


class VannaClaude(VannaDB_VectorStore, Anthropic_Chat):
def __init__(self, config=None):
VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config)
Anthropic_Chat.__init__(self, config={'api_key': ANTHROPIC_API_KEY, 'model': ANTHROPIC_Model})


vn_claude = VannaClaude()
vn_claude.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')


def test_vn_claude():
sql = vn_claude.generate_sql("What are the top 5 customers by sales?")
df = vn_claude.run_sql(sql)
assert len(df) == 5

0 comments on commit 790bfec

Please sign in to comment.