Add basic structure and functionality for SQL parsing
- Update dependencies - add sqlparse
- Set new micro version - 0.0.3
- Add SQL parser
- Add table abstraction to hold necessary data
- Add exceptions - table not found in catalog
- Add some preliminary tests
gkaretka authored and y0j0 committed Dec 20, 2023
1 parent f9c68b9 commit 6782236
Showing 10 changed files with 291 additions and 29 deletions.
11 changes: 10 additions & 1 deletion README.md
@@ -82,9 +82,18 @@ db.list_partitions(table="nyc.taxis")

### Querying data to Pandas dataframe

In the latest update we added a simple (and admittedly crude) SQL parser that extracts the necessary information from the SQL query itself, so there is no need to specify `table` and `partition_filter` separately. This is the new and preferred way:

```python
query = "SELECT * FROM nyc.taxis WHERE payment_type = 1 AND trip_distance > 40 ORDER BY tolls_amount DESC"
df = db.select(sql=query)
```

The old way of selecting data (to be deprecated in a future release):

```python
query = "SELECT * FROM nyc.taxis WHERE trip_distance > 40 ORDER BY tolls_amount DESC"
df = db.select(table="nyc.taxis", partition_filter="payment_type = 1", sql=query).read_pandas()
df = db.select(sql=query, table="nyc.taxis", partition_filter="payment_type = 1")
```
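
Both ways also accept DuckDB prepared-statement parameters through `sql_params`. A minimal sketch (the `?` placeholder and the `LIMIT` value are illustrative; with the parser-based path it is safest to keep placeholders out of the `WHERE` clause, since that clause is handed to PyIceberg):

```python
query = "SELECT * FROM 'nyc.taxis' ORDER BY tolls_amount DESC LIMIT ?"
df = db.select(sql=query, sql_params=[10])
```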

## Playground
@@ -167,6 +167,26 @@
"With usage of partition filter to read data just from files we need."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f267cd22",
"metadata": {},
"outputs": [],
"source": [
"query = \"SELECT * FROM 'nyc.taxis' WHERE payment_type = 1 AND trip_distance > 40 ORDER BY tolls_amount DESC\"\n",
"df = db.select(sql=query)\n",
"df.head(10)"
]
},
{
"cell_type": "markdown",
"id": "40019f6b",
"metadata": {},
"source": [
"or the old way"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -179,7 +199,7 @@
"outputs": [],
"source": [
"query = \"SELECT * FROM 'nyc.taxis' WHERE trip_distance > 40 ORDER BY tolls_amount DESC\"\n",
"df = db.select(table=\"nyc.taxis\", partition_filter=\"payment_type = 1\", sql=query).read_pandas()\n",
"df = db.select(sql=query, table=\"nyc.taxis\", partition_filter=\"payment_type = 1\")\n",
"df.head(10)"
]
},
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -20,7 +20,8 @@ classifiers = [

dependencies = [
  "pyiceberg[duckdb,glue,s3fs]==0.5.1",
  "pandas==2.1.3"
  "pandas==2.1.3",
  "sqlparse==0.4.4"
]

[[tool.hatch.envs.test.matrix]]
74 changes: 59 additions & 15 deletions src/duckberg/__init__.py
@@ -1,11 +1,14 @@
"""Module containing services needed for executing queries with Duckdb + Iceberg."""

from typing import Optional

from duckdb import DuckDBPyConnection
import duckdb
from pyarrow.lib import RecordBatchReader
from pyiceberg.catalog import Catalog, load_catalog, load_rest
from pyiceberg.table import Table
from duckberg.exceptions import TableNotInCatalogException
from pyiceberg.expressions import AlwaysTrue

from duckberg.sqlparser import DuckBergSQLParser
from duckberg.table import DuckBergTable, TableWithAlias

BATCH_SIZE_ROWS = 1024
DEFAULT_MEM_LIMIT = "1GB"
@@ -20,24 +23,40 @@ def __init__(
        self,
        catalog_name: str,
        catalog_config: dict[str, str],
        duckdb_connection: duckdb.DuckDBPyConnection = None,
        db_thread_limit: Optional[int] = DEFAULT_DB_THREAD_LIMIT,
        db_mem_limit: Optional[str] = DEFAULT_MEM_LIMIT,
        batch_size_rows: Optional[int] = BATCH_SIZE_ROWS,
    ):
        self.db_thread_limit = db_thread_limit
        self.db_mem_limit = db_mem_limit
        self.batch_size_rows = batch_size_rows
        self.duckdb_connection = duckdb_connection

        if self.duckdb_connection is None:
            self.duckdb_connection = duckdb.connect()
            self.init_duckdb()

        self.sql_parser = DuckBergSQLParser()
        self.tables: dict[str, DuckBergTable] = {}

        self.init_duckdb()
        self.__get_tables(catalog_config, catalog_name)

    def init_duckdb(self):
        self.duckdb_connection.execute(f"SET memory_limit='{self.db_mem_limit}'")
        self.duckdb_connection.execute(f"SET threads TO {self.db_thread_limit}")
    def __get_tables(self, catalog_config, catalog_name):
        tables = {}
        catalog: Catalog = load_catalog(catalog_name, **catalog_config)
        rest = load_rest(catalog_name, catalog_config)
        namespaces = rest.list_namespaces()
        for n in namespaces:
            tables_names = rest.list_tables(n)
            self.tables: dict[str, Table] = {".".join(t): catalog.load_table(t) for t in tables_names}
            self.tables: dict[str, DuckBergTable] = {
                ".".join(t): DuckBergTable.from_pyiceberg_table(catalog.load_table(t)) for t in tables_names
            }
        return tables

def list_tables(self):
@@ -46,21 +65,46 @@ def list_tables(self):
    def list_partitions(self, table: str):
        t = self.tables[table]

        if t.spec().is_unpartitioned():
            return None
        if t.partitions is None:
            t.precomp_partitions()

        return t.partitions

        partition_cols_ids = [p["source-id"] for p in t.spec().model_dump()["fields"]]
        col_names = [c["name"] for c in t.schema().model_dump()["fields"] if c["id"] in partition_cols_ids]
    def select(self, sql: str, table: str = None, partition_filter: str = None, sql_params: list[str] = None) -> RecordBatchReader:
        if table is not None and partition_filter is not None:
            return self._select_old(sql, table, partition_filter, sql_params)

        return col_names
        parsed_sql = self.sql_parser.parse_first_query(sql)
        extracted_tables = self.sql_parser.extract_tables(parsed_sql)

    def select(self, table: str, sql: str, partition_filter: str, sql_params: [str] = None) -> RecordBatchReader:
        db_conn: DuckDBPyConnection = self.tables[table].scan(row_filter=partition_filter).to_duckdb(table_name=table)
        table: TableWithAlias
        for table in extracted_tables:
            table_name = table.table_name

        db_conn.execute(f"SET memory_limit='{self.db_mem_limit}'")
        db_conn.execute(f"SET threads TO {self.db_thread_limit}")
            if table_name not in self.tables:
                raise TableNotInCatalogException

            row_filter = AlwaysTrue()
            if table.comparisons is not None:
                row_filter = table.comparisons

            table_data_scan_as_arrow = self.tables[table_name].scan(row_filter=row_filter).to_arrow()
            self.duckdb_connection.register(table_name, table_data_scan_as_arrow)

        if sql_params is None:
            return self.duckdb_connection.execute(sql).fetch_record_batch(self.batch_size_rows).read_pandas()
        else:
            return (
                self.duckdb_connection.execute(sql, parameters=sql_params)
                .fetch_record_batch(self.batch_size_rows)
                .read_pandas()
            )

    def _select_old(self, sql: str, table: str, partition_filter: str, sql_params: list[str] = None):
        table_data_scan_as_arrow = self.tables[table].scan(row_filter=partition_filter).to_arrow()
        self.duckdb_connection.register(table, table_data_scan_as_arrow)

        if sql_params is None:
            return db_conn.execute(sql).fetch_record_batch(self.batch_size_rows).read_pandas()
            return self.duckdb_connection.execute(sql).fetch_record_batch(self.batch_size_rows).read_pandas()
        else:
            return db_conn.execute(sql, parameters=sql_params).fetch_record_batch(self.batch_size_rows).read_pandas()
            return self.duckdb_connection.execute(sql, parameters=sql_params).fetch_record_batch(self.batch_size_rows).read_pandas()
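
For reference, a minimal usage sketch of the reworked constructor; the catalog name and config values are hypothetical placeholders for whatever PyIceberg catalog is actually in use:

```python
import duckdb

from duckberg import DuckBerg

# Reusing an existing DuckDB connection is optional; when omitted,
# DuckBerg opens its own connection and applies the limits below.
conn = duckdb.connect()

db = DuckBerg(
    catalog_name="warehouse",  # hypothetical catalog name
    catalog_config={"uri": "http://localhost:8181"},  # hypothetical REST catalog config
    duckdb_connection=conn,
    db_mem_limit="2GB",
    db_thread_limit=4,
)

# Table and partition filter are now parsed out of the SQL itself
df = db.select(sql="SELECT * FROM 'nyc.taxis' WHERE payment_type = 1")
```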
4 changes: 4 additions & 0 deletions src/duckberg/exceptions.py
@@ -0,0 +1,4 @@
class TableNotInCatalogException(Exception):
    """
    The specified table is not registered in the data catalog.
    """
49 changes: 49 additions & 0 deletions src/duckberg/sqlparser.py
@@ -0,0 +1,49 @@
import sqlparse

from duckberg.table import TableWithAlias
from pyiceberg.expressions import *
from pyiceberg.expressions import parser


class DuckBergSQLParser:
    def parse_first_query(self, sql: str) -> sqlparse.sql.Statement:
        reformatted_sql = sql.replace("'", '"')  # replace all single quotes with double quotes
        return sqlparse.parse(reformatted_sql)[0]

    def unpack_identifiers(self, token: sqlparse.sql.IdentifierList) -> list[TableWithAlias]:
        return list(
            map(
                lambda y: TableWithAlias.from_identifier(y),
                filter(lambda x: type(x) is sqlparse.sql.Identifier, token.tokens),
            )
        )

    def extract_tables(self, parsed_sql: sqlparse.sql.Statement) -> list[TableWithAlias]:
        tables = []
        get_next = 0
        c_table: list[TableWithAlias] = []
        c_table_wc = None
        for token in parsed_sql.tokens:
            if get_next == 1 and token.ttype is not sqlparse.tokens.Whitespace:
                if type(token) is sqlparse.sql.Identifier:
                    c_table = [TableWithAlias.from_identifier(token)]
                    get_next += 1
                elif type(token) is sqlparse.sql.IdentifierList:
                    c_table = self.unpack_identifiers(token)
                    get_next += 1
                elif type(token) is sqlparse.sql.Parenthesis:
                    tables.extend(self.extract_tables(token))

            if token.ttype is sqlparse.tokens.Keyword and str(token.value).upper() == "FROM":
                get_next += 1

            if type(token) is sqlparse.sql.Where:
                c_table_wc = self.extract_where_conditions(token)

        mapped_c_table = list(map(lambda x: x.set_comparisons(c_table_wc), c_table))
        tables.extend(mapped_c_table)
        return tables

    def extract_where_conditions(self, where_statement: sqlparse.sql.Where):
        comparison = sqlparse.sql.TokenList(where_statement[1:])
        return parser.parse(str(comparison))
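
A quick sketch of the parser in isolation, in the spirit of the tests further down (the table name and filter are arbitrary; the printed reprs depend on the sqlparse and PyIceberg versions):

```python
from duckberg.sqlparser import DuckBergSQLParser

parser = DuckBergSQLParser()

stmt = parser.parse_first_query(sql="SELECT * FROM this_is_awesome_table t WHERE payment_type = 1")
tables = parser.extract_tables(stmt)
print(tables)                 # [this_is_awesome_table (t)]
print(tables[0].comparisons)  # PyIceberg BooleanExpression parsed from the WHERE clause
```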
58 changes: 58 additions & 0 deletions src/duckberg/table.py
@@ -0,0 +1,58 @@
from pyiceberg.catalog import Catalog
from pyiceberg.io import FileIO
from pyiceberg.table import Table
from pyiceberg.table.metadata import TableMetadata
from pyiceberg.typedef import Identifier
from pyiceberg.expressions import BooleanExpression
import sqlparse


class DuckBergTable(Table):
    """
    Class for storing precomputed data for faster processing of queries
    """

    def __init__(
        self, identifier: Identifier, metadata: TableMetadata, metadata_location: str, io: FileIO, catalog: Catalog
    ) -> None:
        super().__init__(identifier, metadata, metadata_location, io, catalog)
        self.partitions = None

    @classmethod
    def from_pyiceberg_table(cls, table: Table):
        return cls(table.identifier, table.metadata, table.metadata_location, table.io, table.catalog)

    def precomp_partitions(self):
        if self.spec().is_unpartitioned():
            self.partitions = []
            return

        partition_cols_ids = [p["source-id"] for p in self.spec().model_dump()["fields"]]
        self.partitions = [c["name"] for c in self.schema().model_dump()["fields"] if c["id"] in partition_cols_ids]

    def __repr__(self) -> str:
        return ".".join(self.identifier)


class TableWithAlias:
    """
    Dataclass that holds a table name together with its alias
    """

    def __init__(self, tname: str, talias: str) -> None:
        self.table_name: str = tname
        self.table_alias: str = talias
        self.comparisons: BooleanExpression = None

    @classmethod
    def from_identifier(cls, identf: sqlparse.sql.Identifier):
        return cls(identf.get_real_name(), identf.get_alias())

    def set_comparisons(self, comparisons: BooleanExpression):
        self.comparisons = comparisons
        return self

    def __str__(self) -> str:
        return f"{self.table_name} ({self.table_alias})"

    def __repr__(self) -> str:
        return self.__str__()
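
And a small sketch of `TableWithAlias` built straight from a sqlparse identifier (the table name and alias are arbitrary):

```python
import sqlparse

from duckberg.table import TableWithAlias

stmt = sqlparse.parse("SELECT * FROM my_table mt")[0]
ident = next(t for t in stmt.tokens if isinstance(t, sqlparse.sql.Identifier))

twa = TableWithAlias.from_identifier(ident)
print(twa)              # my_table (mt)
print(twa.table_name)   # my_table
print(twa.table_alias)  # mt
```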
27 changes: 16 additions & 11 deletions tests/duckberg-sample.py
@@ -19,14 +19,19 @@

tables = db.list_tables()

assert len(tables) == 1

partitions = db.list_partitions(table="nyc.taxis")

assert len(tables) == 1

query: str = "SELECT * FROM 'nyc.taxis' WHERE trip_distance > 40 ORDER BY tolls_amount DESC"

dd = db.select(table="nyc.taxis", partition_filter="payment_type = 1", sql=query)

df = dd.read_pandas()
assert len(tables) == 1

# New way of querying data without a partition filter
query: str = "SELECT count(*) FROM (SELECT * FROM 'nyc.taxis' WHERE trip_distance > 40 ORDER BY tolls_amount DESC)"
df = db.select(sql=query)
assert df['count_star()'][0] == 2614

# New way of querying data with the partition filter expressed in SQL
query: str = "SELECT count(*) FROM (SELECT * FROM 'nyc.taxis' WHERE payment_type = 1 AND trip_distance > 40 ORDER BY tolls_amount DESC)"
df = db.select(sql=query)
assert df['count_star()'][0] == 1673

# Old way of querying data
query: str = "SELECT count(*) FROM (SELECT * FROM 'nyc.taxis' WHERE payment_type = 1 AND trip_distance > 40 ORDER BY tolls_amount DESC)"
df = db.select(sql=query, table="nyc.taxis", partition_filter="payment_type = 1")
assert df['count_star()'][0] == 1673
40 changes: 40 additions & 0 deletions tests/sqlparser/basic_selects.py
@@ -0,0 +1,40 @@
from duckberg.sqlparser import DuckBergSQLParser


parser = DuckBergSQLParser()


sql1 = """
SELECT * FROM this_is_awesome_table"""
sql1_parsed = parser.parse_first_query(sql=sql1)
res1 = parser.extract_tables(sql1_parsed)
assert len(res1) == 1
assert list(map(lambda x: str(x), res1)) == ["this_is_awesome_table (None)"]

sql2 = """
SELECT * FROM this_is_awesome_table, second_awesome_table"""
sql2_parsed = parser.parse_first_query(sql=sql2)
res2 = parser.extract_tables(sql2_parsed)
assert len(res2) == 2
assert list(map(lambda x: str(x), res2)) == ["this_is_awesome_table (None)", "second_awesome_table (None)"]

sql3 = """
SELECT * FROM (SELECT * FROM (SELECT * FROM this_is_awesome_table))"""
sql3_parsed = parser.parse_first_query(sql=sql3)
res3 = parser.extract_tables(sql3_parsed)
assert len(res3) == 1
assert list(map(lambda x: str(x), res3)) == ["this_is_awesome_table (None)"]

sql4 = """
SELECT * FROM (SELECT * FROM (SELECT * FROM this_is_awesome_table), second_awesome_table)"""
sql4_parsed = parser.parse_first_query(sql=sql4)
res4 = parser.extract_tables(sql4_parsed)
assert len(res4) == 2
assert list(map(lambda x: str(x), res4)) == ["this_is_awesome_table (None)", "second_awesome_table (None)"]

sql5 = """
SELECT * FROM (SELECT * FROM (SELECT * FROM this_is_awesome_table tiat, second_awesome_table))"""
sql5_parsed = parser.parse_first_query(sql=sql5)
res5 = parser.extract_tables(sql5_parsed)
assert len(res5) == 2
assert list(map(lambda x: str(x), res5)) == ["this_is_awesome_table (tiat)", "second_awesome_table (None)"]
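
The `WHERE` handling could be exercised the same way; a hedged sketch to extend these tests (the exact expression repr depends on PyIceberg's parser, so only its presence is asserted):

```python
sql6 = """
SELECT * FROM this_is_awesome_table WHERE payment_type = 1"""
sql6_parsed = parser.parse_first_query(sql=sql6)
res6 = parser.extract_tables(sql6_parsed)
assert len(res6) == 1
assert res6[0].comparisons is not None  # PyIceberg expression for payment_type = 1
```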