Skip to content

Commit

Permalink
Merge pull request #63 from vanna-ai/arslan/model-name-formatting
Browse files Browse the repository at this point in the history
Add rules for model name
  • Loading branch information
arslanhashmi authored Aug 2, 2023
2 parents feb8443 + a9bafd5 commit ea36329
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 47 deletions.
8 changes: 4 additions & 4 deletions src/vanna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
from typing import List, Union, Callable, Tuple
from .exceptions import ImproperlyConfigured, DependencyError, ConnectionError, OTPCodeError, SQLRemoveError, \
ValidationError, APIError
from .utils import validate_config_path
from .utils import validate_config_path, sanitize_model_name
import warnings
import traceback
import os
Expand Down Expand Up @@ -280,7 +280,7 @@ def create_model(model: str, db_type: str) -> bool:
global __org
if __org is None:
__org = 'demo-tpc-h'

model = sanitize_model_name(model)
params = [NewOrganization(org_name=model, db_type=db_type)]

d = __rpc_call(method="create_org", params=params)
Expand Down Expand Up @@ -407,7 +407,7 @@ def set_model(model: str):
model = env_model
else:
raise ValidationError("Please replace 'my-model' with the name of your model")

dataset = sanitize_model_name(model)
_set_org(org=model)


Expand Down Expand Up @@ -916,7 +916,7 @@ def remove_sql(question: str) -> bool:
if 'result' not in d:
raise Exception(f"Error removing SQL")
return False

status = Status(**d['result'])

if not status.success:
Expand Down
34 changes: 32 additions & 2 deletions src/vanna/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import re

from .exceptions import ImproperlyConfigured
from .exceptions import ImproperlyConfigured, ValidationError


def validate_config_path(path):
Expand All @@ -17,4 +18,33 @@ def validate_config_path(path):
if not os.access(path, os.R_OK):
raise ImproperlyConfigured(
f'Cannot read the config file. Please grant read privileges: {path}'
)
)


def sanitize_model_name(model_name):
try:
model_name = model_name.lower()

# Replace spaces with a hyphen
model_name = model_name.replace(" ", "-")

if '-' in model_name:

# remove double hyphones
model_name = re.sub(r"-+", "-", model_name)
if '_' in model_name:
# If name contains both underscores and hyphen replace all underscores with hyphens
model_name = re.sub(r'_', '-', model_name)

# Remove special characters only allow underscore
model_name = re.sub(r"[^a-zA-Z0-9-_]", "", model_name)

# Remove hyphen or underscore if any at the last or first
if model_name[-1] in ("-", "_"):
model_name = model_name[:-1]
if model_name[0] in ("-", "_"):
model_name = model_name[1:]

return model_name
except Exception as e:
raise ValidationError(e)
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
"question":"Which 10 domains received the highest amount of traffic on Black Friday in 2021 vs 2020",
"answer":"SELECT domain,\n sum(case when date = '2021-11-26' then total_visits\n else 0 end) as visits_2021,\n sum(case when date = '2020-11-27' then total_visits\n else 0 end) as visits_2020\nFROM s__p_500_by_domain_and_aggregated_by_tickers_sample.datafeeds.sp_500\nWHERE date in ('2021-11-26', '2020-11-27')\nGROUP BY domain\nORDER BY (visits_2021 - visits_2020) desc limit 10"
}
]
]
1 change: 1 addition & 0 deletions tests/fixtures/sql/testSqlCreate.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE employees (id INT, name VARCHAR(255), salary INT);
1 change: 1 addition & 0 deletions tests/fixtures/sql/testSqlSelect.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SELECT * FROM students WHERE name = 'Jane Doe';
1 change: 0 additions & 1 deletion tests/test_files/sql/testSqlCreate.sql

This file was deleted.

1 change: 0 additions & 1 deletion tests/test_files/sql/testSqlSelect.sql

This file was deleted.

153 changes: 115 additions & 38 deletions tests/test_vanna.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,20 @@ def test_create_user1(monkeypatch):

assert models == ['demo-tpc-h']

def test_create_model():
rv = vn.create_model(model='test_org', db_type='Snowflake')

@pytest.mark.parametrize("model_name", ["Test @Org_"])
def test_create_model(model_name):
rv = vn.create_model(model=model_name, db_type='Snowflake')
assert rv == True

models = vn.get_models()
assert 'test-org' in models


def test_is_user1_in_model():
rv = vn.get_models()
assert rv == ['demo-tpc-h', 'test_org']
assert rv == ['demo-tpc-h', 'test-org']


def test_is_user2_in_model(monkeypatch):
switch_to_user('user2', monkeypatch)
Expand All @@ -56,23 +63,23 @@ def test_switch_back_to_user1(monkeypatch):
switch_to_user('user1', monkeypatch)

models = vn.get_models()
assert models == ['demo-tpc-h', 'test_org']
assert models == ['demo-tpc-h', 'test-org']

def test_set_model_my_model():
with pytest.raises(ValidationError):
vn.set_model('my-model')

def test_set_model():
vn.set_model('test_org')
assert vn.__org == 'test_org' # type: ignore
vn.set_model('test-org')
assert vn.__org == 'test-org' # type: ignore

def test_add_user_to_model(monkeypatch):
rv = vn.add_user_to_model(model='test_org', email="[email protected]", is_admin=False)
rv = vn.add_user_to_model(model='test-org', email="[email protected]", is_admin=False)
assert rv == True

switch_to_user('user2', monkeypatch)
models = vn.get_models()
assert models == ['demo-tpc-h', 'test_org']
assert models == ['demo-tpc-h', 'test-org']

def test_update_model_visibility(monkeypatch):
rv = vn.update_model_visibility(public=True)
Expand All @@ -84,7 +91,7 @@ def test_update_model_visibility(monkeypatch):

switch_to_user('user3', monkeypatch)
models = vn.get_models()
assert models == ['demo-tpc-h', 'test_org']
assert models == ['demo-tpc-h', 'test-org']

switch_to_user('user1', monkeypatch)

Expand Down Expand Up @@ -188,7 +195,7 @@ def test_add_sql_pass_fail():

def test_add_documentation_pass(monkeypatch):
switch_to_user('user1', monkeypatch)
vn.set_model('test_org')
vn.set_model('test-org')
rv = vn.add_documentation(documentation="This is the documentation")
assert rv == True

Expand Down Expand Up @@ -216,10 +223,10 @@ def test_remove_training_data():
assert vn.get_training_data().shape[0] == num_training_data-1-index

def test_create_model_and_add_user():
created = vn.create_model('test_org2', 'Snowflake')
created = vn.create_model('test-org2', 'Snowflake')
assert created == True

added = vn.add_user_to_model(model='test_org2', email="[email protected]", is_admin=False)
added = vn.add_user_to_model(model='test-org2', email="[email protected]", is_admin=False)
assert added == True

def test_ask_no_output():
Expand All @@ -240,7 +247,7 @@ def test_generate_meta():
assert meta == 'AI Response'

def test_double_train():
vn.set_model('test_org')
vn.set_model('test-org')

training_data = vn.get_training_data()
assert training_data.shape == (0, 0)
Expand All @@ -256,33 +263,103 @@ def test_double_train():
training_data = vn.get_training_data()
assert training_data.shape == (1, 4)

@pytest.mark.parametrize("sql_file_path, json_file_path, should_work", [
('tests/test_files/sql/testSqlSelect.sql', 'tests/test_files/training/questions.json', True),
('tests/test_files/sql/testSqlCreate.sql', 'tests/test_files/training/questions.json', True),
('tests/test_files/sql/testSql.sql', 'tests/test_files/training/s.json', False),
])
def test_train(sql_file_path, json_file_path, should_work):
# if just question not sql
with pytest.raises(ValidationError):
vn.train(question="What's the data about student John Doe?")

# if just sql
assert vn.train(sql="SELECT * FROM students WHERE name = 'Jane Doe'") == True

# if just sql and documentation=True
assert vn.train(sql="SELECT * FROM students WHERE name = 'Jane Doe'", documentation=True) == True

# if just ddl statement
assert vn.train(ddl="This is the ddl") == True
@pytest.mark.parametrize("params", [
dict(
question=None,
sql="SELECT * FROM students WHERE name = 'Jane Doe'",
documentation=False,
ddl=None,
sql_file=None,
json_file=None,
),
dict(
question=None,
sql="SELECT * FROM students WHERE name = 'Jane Doe'",
documentation=True,
ddl=None,
sql_file=None,
json_file=None,
),
dict(
question=None,
sql=None,
documentation=False,
ddl="This is the ddl",
sql_file=None,
json_file=None,
),
dict(
question=None,
sql=None,
documentation=False,
ddl=None,
sql_file="tests/fixtures/sql/testSqlSelect.sql",
json_file=None,
),
dict(
question=None,
sql=None,
documentation=False,
ddl=None,
sql_file=None,
json_file="tests/fixtures/questions.json"
),
dict(
question=None,
sql=None,
documentation=False,
ddl=None,
sql_file="tests/fixtures/sql/testSqlCreate.sql",
json_file=None,
),
])
def test_train_success(monkeypatch, params):
vn.set_model('test-org')
assert vn.train(**params)


@pytest.mark.parametrize("params, expected_exc_class", [
(
dict(
question="What's the data about student John Doe?",
sql=None,
documentation=False,
ddl=None,
sql_file=None,
json_file=None,
),
ValidationError
),
(
dict(
question=None,
sql=None,
documentation=False,
ddl=None,
sql_file="wrong/path/or/file.sql",
json_file=None,
),
ImproperlyConfigured
),
(
dict(
question=None,
sql=None,
documentation=False,
ddl=None,
sql_file=None,
json_file="wrong/path/or/file.json",
),
ImproperlyConfigured
)
])
def test_train_validations(monkeypatch, params, expected_exc_class):
vn.set_model('test-org')

# if just sql_file
if should_work:
assert vn.train(sql_file=sql_file_path) == True
assert vn.train(json_file=json_file_path) == True
else:
with pytest.raises(ImproperlyConfigured):
vn.train(sql_file=sql_file_path)
vn.train(json_file=json_file_path)
with pytest.raises((ValidationError, ImproperlyConfigured)) as exc:
vn.train(**params)
assert isinstance(exc, expected_exc_class)


@pytest.mark.parametrize('model_name', [1234, ['test_org']])
Expand Down

0 comments on commit ea36329

Please sign in to comment.