-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #63 from vanna-ai/arslan/model-name-formatting
Add rules for model name
- Loading branch information
Showing
8 changed files
with
154 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
CREATE TABLE employees (id INT, name VARCHAR(255), salary INT); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
SELECT * FROM students WHERE name = 'Jane Doe'; |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,13 +37,20 @@ def test_create_user1(monkeypatch): | |
|
||
assert models == ['demo-tpc-h'] | ||
|
||
def test_create_model(): | ||
rv = vn.create_model(model='test_org', db_type='Snowflake') | ||
|
||
@pytest.mark.parametrize("model_name", ["Test @Org_"]) | ||
def test_create_model(model_name): | ||
rv = vn.create_model(model=model_name, db_type='Snowflake') | ||
assert rv == True | ||
|
||
models = vn.get_models() | ||
assert 'test-org' in models | ||
|
||
|
||
def test_is_user1_in_model(): | ||
rv = vn.get_models() | ||
assert rv == ['demo-tpc-h', 'test_org'] | ||
assert rv == ['demo-tpc-h', 'test-org'] | ||
|
||
|
||
def test_is_user2_in_model(monkeypatch): | ||
switch_to_user('user2', monkeypatch) | ||
|
@@ -56,23 +63,23 @@ def test_switch_back_to_user1(monkeypatch): | |
switch_to_user('user1', monkeypatch) | ||
|
||
models = vn.get_models() | ||
assert models == ['demo-tpc-h', 'test_org'] | ||
assert models == ['demo-tpc-h', 'test-org'] | ||
|
||
def test_set_model_my_model(): | ||
with pytest.raises(ValidationError): | ||
vn.set_model('my-model') | ||
|
||
def test_set_model(): | ||
vn.set_model('test_org') | ||
assert vn.__org == 'test_org' # type: ignore | ||
vn.set_model('test-org') | ||
assert vn.__org == 'test-org' # type: ignore | ||
|
||
def test_add_user_to_model(monkeypatch): | ||
rv = vn.add_user_to_model(model='test_org', email="[email protected]", is_admin=False) | ||
rv = vn.add_user_to_model(model='test-org', email="[email protected]", is_admin=False) | ||
assert rv == True | ||
|
||
switch_to_user('user2', monkeypatch) | ||
models = vn.get_models() | ||
assert models == ['demo-tpc-h', 'test_org'] | ||
assert models == ['demo-tpc-h', 'test-org'] | ||
|
||
def test_update_model_visibility(monkeypatch): | ||
rv = vn.update_model_visibility(public=True) | ||
|
@@ -84,7 +91,7 @@ def test_update_model_visibility(monkeypatch): | |
|
||
switch_to_user('user3', monkeypatch) | ||
models = vn.get_models() | ||
assert models == ['demo-tpc-h', 'test_org'] | ||
assert models == ['demo-tpc-h', 'test-org'] | ||
|
||
switch_to_user('user1', monkeypatch) | ||
|
||
|
@@ -188,7 +195,7 @@ def test_add_sql_pass_fail(): | |
|
||
def test_add_documentation_pass(monkeypatch): | ||
switch_to_user('user1', monkeypatch) | ||
vn.set_model('test_org') | ||
vn.set_model('test-org') | ||
rv = vn.add_documentation(documentation="This is the documentation") | ||
assert rv == True | ||
|
||
|
@@ -216,10 +223,10 @@ def test_remove_training_data(): | |
assert vn.get_training_data().shape[0] == num_training_data-1-index | ||
|
||
def test_create_model_and_add_user(): | ||
created = vn.create_model('test_org2', 'Snowflake') | ||
created = vn.create_model('test-org2', 'Snowflake') | ||
assert created == True | ||
|
||
added = vn.add_user_to_model(model='test_org2', email="[email protected]", is_admin=False) | ||
added = vn.add_user_to_model(model='test-org2', email="[email protected]", is_admin=False) | ||
assert added == True | ||
|
||
def test_ask_no_output(): | ||
|
@@ -240,7 +247,7 @@ def test_generate_meta(): | |
assert meta == 'AI Response' | ||
|
||
def test_double_train(): | ||
vn.set_model('test_org') | ||
vn.set_model('test-org') | ||
|
||
training_data = vn.get_training_data() | ||
assert training_data.shape == (0, 0) | ||
|
@@ -256,33 +263,103 @@ def test_double_train(): | |
training_data = vn.get_training_data() | ||
assert training_data.shape == (1, 4) | ||
|
||
@pytest.mark.parametrize("sql_file_path, json_file_path, should_work", [ | ||
('tests/test_files/sql/testSqlSelect.sql', 'tests/test_files/training/questions.json', True), | ||
('tests/test_files/sql/testSqlCreate.sql', 'tests/test_files/training/questions.json', True), | ||
('tests/test_files/sql/testSql.sql', 'tests/test_files/training/s.json', False), | ||
]) | ||
def test_train(sql_file_path, json_file_path, should_work): | ||
# if just question not sql | ||
with pytest.raises(ValidationError): | ||
vn.train(question="What's the data about student John Doe?") | ||
|
||
# if just sql | ||
assert vn.train(sql="SELECT * FROM students WHERE name = 'Jane Doe'") == True | ||
|
||
# if just sql and documentation=True | ||
assert vn.train(sql="SELECT * FROM students WHERE name = 'Jane Doe'", documentation=True) == True | ||
|
||
# if just ddl statement | ||
assert vn.train(ddl="This is the ddl") == True | ||
@pytest.mark.parametrize("params", [ | ||
dict( | ||
question=None, | ||
sql="SELECT * FROM students WHERE name = 'Jane Doe'", | ||
documentation=False, | ||
ddl=None, | ||
sql_file=None, | ||
json_file=None, | ||
), | ||
dict( | ||
question=None, | ||
sql="SELECT * FROM students WHERE name = 'Jane Doe'", | ||
documentation=True, | ||
ddl=None, | ||
sql_file=None, | ||
json_file=None, | ||
), | ||
dict( | ||
question=None, | ||
sql=None, | ||
documentation=False, | ||
ddl="This is the ddl", | ||
sql_file=None, | ||
json_file=None, | ||
), | ||
dict( | ||
question=None, | ||
sql=None, | ||
documentation=False, | ||
ddl=None, | ||
sql_file="tests/fixtures/sql/testSqlSelect.sql", | ||
json_file=None, | ||
), | ||
dict( | ||
question=None, | ||
sql=None, | ||
documentation=False, | ||
ddl=None, | ||
sql_file=None, | ||
json_file="tests/fixtures/questions.json" | ||
), | ||
dict( | ||
question=None, | ||
sql=None, | ||
documentation=False, | ||
ddl=None, | ||
sql_file="tests/fixtures/sql/testSqlCreate.sql", | ||
json_file=None, | ||
), | ||
]) | ||
def test_train_success(monkeypatch, params): | ||
vn.set_model('test-org') | ||
assert vn.train(**params) | ||
|
||
|
||
@pytest.mark.parametrize("params, expected_exc_class", [ | ||
( | ||
dict( | ||
question="What's the data about student John Doe?", | ||
sql=None, | ||
documentation=False, | ||
ddl=None, | ||
sql_file=None, | ||
json_file=None, | ||
), | ||
ValidationError | ||
), | ||
( | ||
dict( | ||
question=None, | ||
sql=None, | ||
documentation=False, | ||
ddl=None, | ||
sql_file="wrong/path/or/file.sql", | ||
json_file=None, | ||
), | ||
ImproperlyConfigured | ||
), | ||
( | ||
dict( | ||
question=None, | ||
sql=None, | ||
documentation=False, | ||
ddl=None, | ||
sql_file=None, | ||
json_file="wrong/path/or/file.json", | ||
), | ||
ImproperlyConfigured | ||
) | ||
]) | ||
def test_train_validations(monkeypatch, params, expected_exc_class): | ||
vn.set_model('test-org') | ||
|
||
# if just sql_file | ||
if should_work: | ||
assert vn.train(sql_file=sql_file_path) == True | ||
assert vn.train(json_file=json_file_path) == True | ||
else: | ||
with pytest.raises(ImproperlyConfigured): | ||
vn.train(sql_file=sql_file_path) | ||
vn.train(json_file=json_file_path) | ||
with pytest.raises((ValidationError, ImproperlyConfigured)) as exc: | ||
vn.train(**params) | ||
assert isinstance(exc, expected_exc_class) | ||
|
||
|
||
@pytest.mark.parametrize('model_name', [1234, ['test_org']]) | ||
|