Skip to content

Commit

Permalink
Bugfix #5 - string in where not working
Browse files Browse the repository at this point in the history
- String in WHERE statement not working
- Update tests
  • Loading branch information
gkaretka committed Dec 21, 2023
1 parent d499274 commit e84f6f7
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 50 deletions.
3 changes: 2 additions & 1 deletion src/duckberg/sqlparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,5 @@ def extract_tables(self, parsed_sql: sqlparse.sql.Statement) -> list[TableWithAl

def extract_where_conditions(self, where_statement: list[sqlparse.sql.Where]):
comparison = sqlparse.sql.TokenList(where_statement[1:])
return parser.parse(str(comparison))
where_condition = str(comparison).replace('"', "'") # revert from double to single
return parser.parse(where_condition)
67 changes: 36 additions & 31 deletions tests/sqlparser/test_selects.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,50 +4,55 @@


@pytest.fixture
def get_parser():
def get_parser() -> DuckBergSQLParser:
return DuckBergSQLParser()


def test_basic_select_1(get_parser):
sql1 = """
SELECT * FROM this_is_awesome_table"""
sql1_parsed = get_parser.parse_first_query(sql=sql1)
res1 = get_parser.extract_tables(sql1_parsed)
assert len(res1) == 1
assert list(map(lambda x: str(x), res1)) == ["this_is_awesome_table (None)"]
sql = """SELECT * FROM this_is_awesome_table"""
sql_parsed = get_parser.parse_first_query(sql=sql)
res = get_parser.extract_tables(sql_parsed)
assert len(res) == 1
assert list(map(lambda x: str(x), res)) == ["this_is_awesome_table (None)"]


def test_basic_select_2(get_parser):
sql2 = """
SELECT * FROM this_is_awesome_table, second_awesome_table"""
sql2_parsed = get_parser.parse_first_query(sql=sql2)
res2 = get_parser.extract_tables(sql2_parsed)
assert len(res2) == 2
assert list(map(lambda x: str(x), res2)) == ["this_is_awesome_table (None)", "second_awesome_table (None)"]
sql = """SELECT * FROM 'this_is_awesome_table'"""
sql_parsed = get_parser.parse_first_query(sql=sql)
print(str(sql_parsed.tokens))
res = get_parser.extract_tables(sql_parsed)
print(res)
assert len(res) == 1
assert list(map(lambda x: str(x), res)) == ["this_is_awesome_table (None)"]


def test_basic_select_3(get_parser):
sql3 = """
SELECT * FROM (SELECT * FROM (SELECT * FROM this_is_awesome_table))"""
sql3_parsed = get_parser.parse_first_query(sql=sql3)
res3 = get_parser.extract_tables(sql3_parsed)
assert len(res3) == 1
assert list(map(lambda x: str(x), res3)) == ["this_is_awesome_table (None)"]
sql = """SELECT * FROM this_is_awesome_table, second_awesome_table"""
sql_parsed = get_parser.parse_first_query(sql=sql)
res = get_parser.extract_tables(sql_parsed)
assert len(res) == 2
assert list(map(lambda x: str(x), res)) == ["this_is_awesome_table (None)", "second_awesome_table (None)"]


def test_basic_select_4(get_parser):
sql4 = """
SELECT * FROM (SELECT * FROM (SELECT * FROM this_is_awesome_table), second_awesome_table)"""
sql4_parsed = get_parser.parse_first_query(sql=sql4)
res4 = get_parser.extract_tables(sql4_parsed)
assert len(res4) == 2
assert list(map(lambda x: str(x), res4)) == ["this_is_awesome_table (None)", "second_awesome_table (None)"]
sql = """SELECT * FROM (SELECT * FROM (SELECT * FROM this_is_awesome_table))"""
sql_parsed = get_parser.parse_first_query(sql=sql)
res = get_parser.extract_tables(sql_parsed)
assert len(res) == 1
assert list(map(lambda x: str(x), res)) == ["this_is_awesome_table (None)"]


def test_basic_select_5(get_parser):
sql5 = """
SELECT * FROM (SELECT * FROM (SELECT * FROM this_is_awesome_table tiat, second_awesome_table))"""
sql5_parsed = get_parser.parse_first_query(sql=sql5)
res5 = get_parser.extract_tables(sql5_parsed)
assert len(res5) == 2
assert list(map(lambda x: str(x), res5)) == ["this_is_awesome_table (tiat)", "second_awesome_table (None)"]
sql = """SELECT * FROM (SELECT * FROM (SELECT * FROM this_is_awesome_table), second_awesome_table)"""
sql_parsed = get_parser.parse_first_query(sql=sql)
res = get_parser.extract_tables(sql_parsed)
assert len(res) == 2
assert list(map(lambda x: str(x), res)) == ["this_is_awesome_table (None)", "second_awesome_table (None)"]


def test_basic_select_6(get_parser):
sql = """SELECT * FROM (SELECT * FROM (SELECT * FROM this_is_awesome_table tiat, second_awesome_table))"""
sql_parsed = get_parser.parse_first_query(sql=sql)
res = get_parser.extract_tables(sql_parsed)
assert len(res) == 2
assert list(map(lambda x: str(x), res)) == ["this_is_awesome_table (tiat)", "second_awesome_table (None)"]
55 changes: 37 additions & 18 deletions tests/sqlparser/test_where.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,52 @@ def get_parser():


def test_select_where_1(get_parser):
sql1 = """
SELECT * FROM this_is_awesome_table WHERE a > 15"""
sql1_parsed = get_parser.parse_first_query(sql=sql1)
res1 = get_parser.extract_tables(sql1_parsed)
res1_where = str(res1[0].comparisons)
assert "GreaterThan(term=Reference(name='a'), literal=LongLiteral(15))" == res1_where
sql = """SELECT * FROM this_is_awesome_table WHERE a > 15"""
sql_parsed = get_parser.parse_first_query(sql=sql)
res = get_parser.extract_tables(sql_parsed)
res_where = str(res[0].comparisons)
assert "GreaterThan(term=Reference(name='a'), literal=LongLiteral(15))" == res_where


def test_select_where_2(get_parser):
sql2 = """
SELECT * FROM this_is_awesome_table WHERE a > 15 AND a < 16"""
sql2_parsed = get_parser.parse_first_query(sql=sql2)
res2 = get_parser.extract_tables(sql2_parsed)
res2_where = str(res2[0].comparisons)
sql = """SELECT * FROM this_is_awesome_table WHERE a > 15 AND a < 16"""
sql_parsed = get_parser.parse_first_query(sql=sql)
res = get_parser.extract_tables(sql_parsed)
res_where = str(res[0].comparisons)
assert (
"And(left=GreaterThan(term=Reference(name='a'), literal=LongLiteral(15)), right=LessThan(term=Reference(name='a'), literal=LongLiteral(16)))"
== res2_where
== res_where
)


def test_select_where_3(get_parser):
sql3 = """
SELECT * FROM this_is_awesome_table WHERE (a > 15 AND a < 16) OR c > 15"""
sql3_parsed = get_parser.parse_first_query(sql=sql3)
res3 = get_parser.extract_tables(sql3_parsed)
res3_where = str(res3[0].comparisons)
sql = """SELECT * FROM this_is_awesome_table WHERE (a > 15 AND a < 16) OR c > 15"""
sql_parsed = get_parser.parse_first_query(sql=sql)
res = get_parser.extract_tables(sql_parsed)
res_where = str(res[0].comparisons)
assert (
"Or(left=And(left=GreaterThan(term=Reference(name='a'), literal=LongLiteral(15)), right=LessThan(term=Reference(name='a'), literal=LongLiteral(16))), right=GreaterThan(term=Reference(name='c'), literal=LongLiteral(15)))"
== res3_where
== res_where
)


def test_select_where_4(get_parser):
sql = """SELECT * FROM this_is_awesome_table WHERE (b = "test string" AND a < 16) OR c > 15"""
sql_parsed = get_parser.parse_first_query(sql=sql)
res = get_parser.extract_tables(sql_parsed)
res_where = str(res[0].comparisons)
assert (
"Or(left=And(left=EqualTo(term=Reference(name='b'), literal=literal('test string')), right=LessThan(term=Reference(name='a'), literal=LongLiteral(16))), right=GreaterThan(term=Reference(name='c'), literal=LongLiteral(15)))"
== res_where
)


def test_select_where_4(get_parser):
sql = """SELECT * FROM this_is_awesome_table WHERE (b = "test string" AND column = '108e6307-f23a-4e10-9e38-1866d58b4355') OR c > 15"""
sql_parsed = get_parser.parse_first_query(sql=sql)
res = get_parser.extract_tables(sql_parsed)
res_where = str(res[0].comparisons)
assert (
"Or(left=And(left=EqualTo(term=Reference(name='b'), literal=literal('test string')), right=EqualTo(term=Reference(name='column'), literal=literal('108e6307-f23a-4e10-9e38-1866d58b4355'))), right=GreaterThan(term=Reference(name='c'), literal=LongLiteral(15)))"
== res_where
)

0 comments on commit e84f6f7

Please sign in to comment.