Skip to content

Commit

Permalink
rename the claude to anthropic
Browse files Browse the repository at this point in the history
  • Loading branch information
kun321 committed Mar 11, 2024
1 parent 077ce4c commit a2575f7
Showing 1 changed file with 58 additions and 58 deletions.
116 changes: 58 additions & 58 deletions tests/test_vanna.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,65 +17,65 @@
MY_VANNA_MODEL = 'chinook'
ANTHROPIC_Model = 'claude-3-sonnet-20240229'
MY_VANNA_API_KEY = os.environ['VANNA_API_KEY']
# OPENAI_API_KEY = os.environ['OPENAI_API_KEY']
# MISTRAL_API_KEY = os.environ['MISTRAL_API_KEY']
OPENAI_API_KEY = os.environ['OPENAI_API_KEY']
MISTRAL_API_KEY = os.environ['MISTRAL_API_KEY']
ANTHROPIC_API_KEY = os.environ['ANTHROPIC_API_KEY']
#
# class VannaOpenAI(VannaDB_VectorStore, OpenAI_Chat):
# def __init__(self, config=None):
# VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config)
# OpenAI_Chat.__init__(self, config=config)
#
# vn_openai = VannaOpenAI(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'})
# vn_openai.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
#
# def test_vn_openai():
# sql = vn_openai.generate_sql("What are the top 4 customers by sales?")
# df = vn_openai.run_sql(sql)
# assert len(df) == 4
#
# class VannaMistral(VannaDB_VectorStore, Mistral):
# def __init__(self, config=None):
# VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config)
# Mistral.__init__(self, config={'api_key': MISTRAL_API_KEY, 'model': 'mistral-tiny'})
#
# vn_mistral = VannaMistral()
# vn_mistral.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
#
# def test_vn_mistral():
# sql = vn_mistral.generate_sql("What are the top 5 customers by sales?")
# df = vn_mistral.run_sql(sql)
# assert len(df) == 5
#
# vn_default = VannaDefault(model=MY_VANNA_MODEL, api_key=MY_VANNA_API_KEY)
# vn_default.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
#
# def test_vn_default():
# sql = vn_default.generate_sql("What are the top 6 customers by sales?")
# df = vn_default.run_sql(sql)
# assert len(df) == 6
#
# from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
# from vanna.openai.openai_chat import OpenAI_Chat
#
#
# class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
# def __init__(self, config=None):
# ChromaDB_VectorStore.__init__(self, config=config)
# OpenAI_Chat.__init__(self, config=config)
#
# vn_chroma = MyVanna(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'})
# vn_chroma.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
#
# def test_vn_chroma():
# df_ddl = vn_chroma.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null")
#
# for ddl in df_ddl['sql'].to_list():
# vn_chroma.train(ddl=ddl)
#
# sql = vn_chroma.generate_sql("What are the top 7 customers by sales?")
# df = vn_chroma.run_sql(sql)
# assert len(df) == 7

class VannaOpenAI(VannaDB_VectorStore, OpenAI_Chat):
def __init__(self, config=None):
VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config)
OpenAI_Chat.__init__(self, config=config)

vn_openai = VannaOpenAI(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'})
vn_openai.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')

def test_vn_openai():
sql = vn_openai.generate_sql("What are the top 4 customers by sales?")
df = vn_openai.run_sql(sql)
assert len(df) == 4

class VannaMistral(VannaDB_VectorStore, Mistral):
def __init__(self, config=None):
VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config)
Mistral.__init__(self, config={'api_key': MISTRAL_API_KEY, 'model': 'mistral-tiny'})

vn_mistral = VannaMistral()
vn_mistral.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')

def test_vn_mistral():
sql = vn_mistral.generate_sql("What are the top 5 customers by sales?")
df = vn_mistral.run_sql(sql)
assert len(df) == 5

vn_default = VannaDefault(model=MY_VANNA_MODEL, api_key=MY_VANNA_API_KEY)
vn_default.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')

def test_vn_default():
sql = vn_default.generate_sql("What are the top 6 customers by sales?")
df = vn_default.run_sql(sql)
assert len(df) == 6

from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
from vanna.openai.openai_chat import OpenAI_Chat


class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
def __init__(self, config=None):
ChromaDB_VectorStore.__init__(self, config=config)
OpenAI_Chat.__init__(self, config=config)

vn_chroma = MyVanna(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'})
vn_chroma.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')

def test_vn_chroma():
df_ddl = vn_chroma.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null")

for ddl in df_ddl['sql'].to_list():
vn_chroma.train(ddl=ddl)

sql = vn_chroma.generate_sql("What are the top 7 customers by sales?")
df = vn_chroma.run_sql(sql)
assert len(df) == 7


class VannaClaude(VannaDB_VectorStore, Anthropic_Chat):
Expand Down

0 comments on commit a2575f7

Please sign in to comment.