From e84f6f77356cde37174949642fcb03a46a3af0ed Mon Sep 17 00:00:00 2001
From: Gregor Karetka
Date: Thu, 21 Dec 2023 12:03:14 +0100
Subject: [PATCH] Bugfix #5 - string in where not working

- String in WHERE statement not working
- Update tests
---
 src/duckberg/sqlparser.py       |  3 +-
 tests/sqlparser/test_selects.py | 65 +++++++++++++++++---------------
 tests/sqlparser/test_where.py   | 55 ++++++++++++++++++---------
 3 files changed, 73 insertions(+), 50 deletions(-)

diff --git a/src/duckberg/sqlparser.py b/src/duckberg/sqlparser.py
index 02a70a6..8483952 100644
--- a/src/duckberg/sqlparser.py
+++ b/src/duckberg/sqlparser.py
@@ -46,4 +46,5 @@ def extract_tables(self, parsed_sql: sqlparse.sql.Statement) -> list[TableWithAl
 
     def extract_where_conditions(self, where_statement: list[sqlparse.sql.Where]):
         comparison = sqlparse.sql.TokenList(where_statement[1:])
-        return parser.parse(str(comparison))
+        where_condition = str(comparison).replace('"', "'")  # normalise double quotes to single quotes before parsing
+        return parser.parse(where_condition)
diff --git a/tests/sqlparser/test_selects.py b/tests/sqlparser/test_selects.py
index 08cbd83..f954e74 100644
--- a/tests/sqlparser/test_selects.py
+++ b/tests/sqlparser/test_selects.py
@@ -4,50 +4,53 @@
 
 
 @pytest.fixture
-def get_parser():
+def get_parser() -> DuckBergSQLParser:
     return DuckBergSQLParser()
 
 
 def test_basic_select_1(get_parser):
-    sql1 = """
-    SELECT * FROM this_is_awesome_table"""
-    sql1_parsed = get_parser.parse_first_query(sql=sql1)
-    res1 = get_parser.extract_tables(sql1_parsed)
-    assert len(res1) == 1
-    assert list(map(lambda x: str(x), res1)) == ["this_is_awesome_table (None)"]
+    sql = """SELECT * FROM this_is_awesome_table"""
+    sql_parsed = get_parser.parse_first_query(sql=sql)
+    res = get_parser.extract_tables(sql_parsed)
+    assert len(res) == 1
+    assert list(map(lambda x: str(x), res)) == ["this_is_awesome_table (None)"]
 
 
 def test_basic_select_2(get_parser):
-    sql2 = """
-    SELECT * FROM this_is_awesome_table, second_awesome_table"""
-    sql2_parsed = get_parser.parse_first_query(sql=sql2)
-    res2 = get_parser.extract_tables(sql2_parsed)
-    assert len(res2) == 2
-    assert list(map(lambda x: str(x), res2)) == ["this_is_awesome_table (None)", "second_awesome_table (None)"]
+    sql = """SELECT * FROM 'this_is_awesome_table'"""
+    sql_parsed = get_parser.parse_first_query(sql=sql)
+    res = get_parser.extract_tables(sql_parsed)
+    assert len(res) == 1
+    assert list(map(lambda x: str(x), res)) == ["this_is_awesome_table (None)"]
 
 
 def test_basic_select_3(get_parser):
-    sql3 = """
-    SELECT * FROM (SELECT * FROM (SELECT * FROM this_is_awesome_table))"""
-    sql3_parsed = get_parser.parse_first_query(sql=sql3)
-    res3 = get_parser.extract_tables(sql3_parsed)
-    assert len(res3) == 1
-    assert list(map(lambda x: str(x), res3)) == ["this_is_awesome_table (None)"]
+    sql = """SELECT * FROM this_is_awesome_table, second_awesome_table"""
+    sql_parsed = get_parser.parse_first_query(sql=sql)
+    res = get_parser.extract_tables(sql_parsed)
+    assert len(res) == 2
+    assert list(map(lambda x: str(x), res)) == ["this_is_awesome_table (None)", "second_awesome_table (None)"]
 
 
 def test_basic_select_4(get_parser):
-    sql4 = """
-    SELECT * FROM (SELECT * FROM (SELECT * FROM this_is_awesome_table), second_awesome_table)"""
-    sql4_parsed = get_parser.parse_first_query(sql=sql4)
-    res4 = get_parser.extract_tables(sql4_parsed)
-    assert len(res4) == 2
-    assert list(map(lambda x: str(x), res4)) == ["this_is_awesome_table (None)", "second_awesome_table (None)"]
+    sql = """SELECT * FROM (SELECT * FROM (SELECT * FROM this_is_awesome_table))"""
+    sql_parsed = get_parser.parse_first_query(sql=sql)
+    res = get_parser.extract_tables(sql_parsed)
+    assert len(res) == 1
+    assert list(map(lambda x: str(x), res)) == ["this_is_awesome_table (None)"]
 
 
 def test_basic_select_5(get_parser):
-    sql5 = """
-    SELECT * FROM (SELECT * FROM (SELECT * FROM this_is_awesome_table tiat, second_awesome_table))"""
-    sql5_parsed = get_parser.parse_first_query(sql=sql5)
-    res5 = get_parser.extract_tables(sql5_parsed)
-    assert len(res5) == 2
-    assert list(map(lambda x: str(x), res5)) == ["this_is_awesome_table (tiat)", "second_awesome_table (None)"]
+    sql = """SELECT * FROM (SELECT * FROM (SELECT * FROM this_is_awesome_table), second_awesome_table)"""
+    sql_parsed = get_parser.parse_first_query(sql=sql)
+    res = get_parser.extract_tables(sql_parsed)
+    assert len(res) == 2
+    assert list(map(lambda x: str(x), res)) == ["this_is_awesome_table (None)", "second_awesome_table (None)"]
+
+
+def test_basic_select_6(get_parser):
+    sql = """SELECT * FROM (SELECT * FROM (SELECT * FROM this_is_awesome_table tiat, second_awesome_table))"""
+    sql_parsed = get_parser.parse_first_query(sql=sql)
+    res = get_parser.extract_tables(sql_parsed)
+    assert len(res) == 2
+    assert list(map(lambda x: str(x), res)) == ["this_is_awesome_table (tiat)", "second_awesome_table (None)"]
diff --git a/tests/sqlparser/test_where.py b/tests/sqlparser/test_where.py
index 109d267..a88d4c0 100644
--- a/tests/sqlparser/test_where.py
+++ b/tests/sqlparser/test_where.py
@@ -9,33 +9,52 @@ def get_parser():
 
 
 def test_select_where_1(get_parser):
-    sql1 = """
-    SELECT * FROM this_is_awesome_table WHERE a > 15"""
-    sql1_parsed = get_parser.parse_first_query(sql=sql1)
-    res1 = get_parser.extract_tables(sql1_parsed)
-    res1_where = str(res1[0].comparisons)
-    assert "GreaterThan(term=Reference(name='a'), literal=LongLiteral(15))" == res1_where
+    sql = """SELECT * FROM this_is_awesome_table WHERE a > 15"""
+    sql_parsed = get_parser.parse_first_query(sql=sql)
+    res = get_parser.extract_tables(sql_parsed)
+    res_where = str(res[0].comparisons)
+    assert "GreaterThan(term=Reference(name='a'), literal=LongLiteral(15))" == res_where
 
 
 def test_select_where_2(get_parser):
-    sql2 = """
-    SELECT * FROM this_is_awesome_table WHERE a > 15 AND a < 16"""
-    sql2_parsed = get_parser.parse_first_query(sql=sql2)
-    res2 = get_parser.extract_tables(sql2_parsed)
-    res2_where = str(res2[0].comparisons)
+    sql = """SELECT * FROM this_is_awesome_table WHERE a > 15 AND a < 16"""
+    sql_parsed = get_parser.parse_first_query(sql=sql)
+    res = get_parser.extract_tables(sql_parsed)
+    res_where = str(res[0].comparisons)
     assert (
         "And(left=GreaterThan(term=Reference(name='a'), literal=LongLiteral(15)), right=LessThan(term=Reference(name='a'), literal=LongLiteral(16)))"
-        == res2_where
+        == res_where
     )
 
 
 def test_select_where_3(get_parser):
-    sql3 = """
-    SELECT * FROM this_is_awesome_table WHERE (a > 15 AND a < 16) OR c > 15"""
-    sql3_parsed = get_parser.parse_first_query(sql=sql3)
-    res3 = get_parser.extract_tables(sql3_parsed)
-    res3_where = str(res3[0].comparisons)
+    sql = """SELECT * FROM this_is_awesome_table WHERE (a > 15 AND a < 16) OR c > 15"""
+    sql_parsed = get_parser.parse_first_query(sql=sql)
+    res = get_parser.extract_tables(sql_parsed)
+    res_where = str(res[0].comparisons)
     assert (
         "Or(left=And(left=GreaterThan(term=Reference(name='a'), literal=LongLiteral(15)), right=LessThan(term=Reference(name='a'), literal=LongLiteral(16))), right=GreaterThan(term=Reference(name='c'), literal=LongLiteral(15)))"
-        == res3_where
+        == res_where
+    )
+
+
+def test_select_where_4(get_parser):
+    sql = """SELECT * FROM this_is_awesome_table WHERE (b = "test string" AND a < 16) OR c > 15"""
+    sql_parsed = get_parser.parse_first_query(sql=sql)
+    res = get_parser.extract_tables(sql_parsed)
+    res_where = str(res[0].comparisons)
+    assert (
+        "Or(left=And(left=EqualTo(term=Reference(name='b'), literal=literal('test string')), right=LessThan(term=Reference(name='a'), literal=LongLiteral(16))), right=GreaterThan(term=Reference(name='c'), literal=LongLiteral(15)))"
+        == res_where
+    )
+
+
+def test_select_where_5(get_parser):
+    sql = """SELECT * FROM this_is_awesome_table WHERE (b = "test string" AND column = '108e6307-f23a-4e10-9e38-1866d58b4355') OR c > 15"""
+    sql_parsed = get_parser.parse_first_query(sql=sql)
+    res = get_parser.extract_tables(sql_parsed)
+    res_where = str(res[0].comparisons)
+    assert (
+        "Or(left=And(left=EqualTo(term=Reference(name='b'), literal=literal('test string')), right=EqualTo(term=Reference(name='column'), literal=literal('108e6307-f23a-4e10-9e38-1866d58b4355'))), right=GreaterThan(term=Reference(name='c'), literal=LongLiteral(15)))"
+        == res_where
     )