diff --git a/README.md b/README.md index 44efc96..bdbc7e0 100644 --- a/README.md +++ b/README.md @@ -82,9 +82,18 @@ db.list_partitions(table="nyc.taxis") ### Querying data to Pandas dataframe +In the latest update we have added a very crude and simple SQL parser that can extract necessary information from the SQL query without the need to specify `table` and `partition_filters`. This is the new and preferred way: + +```python +query = "SELECT * FROM nyc.taxis WHERE payment_type = 1 AND trip_distance > 40 ORDER BY tolls_amount DESC" +df = db.select(sql=query).read_pandas() +``` + +Old way of selecting data (will be deprecated in the future): + +```python query = "SELECT * FROM nyc.taxis WHERE trip_distance > 40 ORDER BY tolls_amount DESC" -df = db.select(table="nyc.taxis", partition_filter="payment_type = 1", sql=query).read_pandas() +df = db.select(sql=query, table="nyc.taxis", partition_filter="payment_type = 1").read_pandas() ``` ## Playground diff --git a/playground/jupyter/notebooks/001 - Duckberg simple query - REST Iceberg catalog.ipynb b/playground/jupyter/notebooks/001 - Duckberg simple query - REST Iceberg catalog.ipynb index 209492b..18322cf 100644 --- a/playground/jupyter/notebooks/001 - Duckberg simple query - REST Iceberg catalog.ipynb +++ b/playground/jupyter/notebooks/001 - Duckberg simple query - REST Iceberg catalog.ipynb @@ -167,6 +167,26 @@ "With usage of partition filter to read data just from files we need." 
] }, + { + "cell_type": "code", + "execution_count": null, + "id": "f267cd22", + "metadata": {}, + "outputs": [], + "source": [ + "query = \"SELECT * FROM 'nyc.taxis' WHERE payment_type = 1 AND trip_distance > 40 ORDER BY tolls_amount DESC\"\n", + "df = db.select(sql=query)\n", + "df.head(10)" + ] + }, + { + "cell_type": "markdown", + "id": "40019f6b", + "metadata": {}, + "source": [ + "or the old way" + ] + }, { "cell_type": "code", "execution_count": null, @@ -179,7 +199,7 @@ "outputs": [], "source": [ "query = \"SELECT * FROM 'nyc.taxis' WHERE trip_distance > 40 ORDER BY tolls_amount DESC\"\n", - "df = db.select(table=\"nyc.taxis\", partition_filter=\"payment_type = 1\", sql=query).read_pandas()\n", + "df = db.select(sql=query, table=\"nyc.taxis\", partition_filter=\"payment_type = 1\")\n", "df.head(10)" ] }, diff --git a/pyproject.toml b/pyproject.toml index 6034560..5abf88b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,8 @@ classifiers = [ dependencies = [ "pyiceberg[duckdb,glue,s3fs]==0.5.1", - "pandas==2.1.3" + "pandas==2.1.3", + "sqlparse==0.4.4" ] [[tool.hatch.envs.test.matrix]] diff --git a/src/duckberg/__init__.py b/src/duckberg/__init__.py index dc5f990..29873ec 100644 --- a/src/duckberg/__init__.py +++ b/src/duckberg/__init__.py @@ -1,11 +1,14 @@ """Module containing services needed for executing queries with Duckdb + Iceberg.""" from typing import Optional - -from duckdb import DuckDBPyConnection +import duckdb from pyarrow.lib import RecordBatchReader from pyiceberg.catalog import Catalog, load_catalog, load_rest -from pyiceberg.table import Table +from duckberg.exceptions import TableNotInCatalogException +from pyiceberg.expressions import AlwaysTrue + +from duckberg.sqlparser import DuckBergSQLParser +from duckberg.table import DuckBergTable, TableWithAlias BATCH_SIZE_ROWS = 1024 DEFAULT_MEM_LIMIT = "1GB" @@ -20,6 +23,7 @@ def __init__( self, catalog_name: str, catalog_config: dict[str, str], + duckdb_connection: 
duckdb.DuckDBPyConnection = None, db_thread_limit: Optional[int] = DEFAULT_DB_THREAD_LIMIT, db_mem_limit: Optional[str] = DEFAULT_MEM_LIMIT, batch_size_rows: Optional[int] = BATCH_SIZE_ROWS, @@ -27,9 +31,22 @@ def __init__( self.db_thread_limit = db_thread_limit self.db_mem_limit = db_mem_limit self.batch_size_rows = batch_size_rows + self.duckdb_connection = duckdb_connection + + if self.duckdb_connection == None: + self.duckdb_connection = duckdb.connect() + self.init_duckdb() + self.sql_parser = DuckBergSQLParser() + self.tables: dict[str, DuckBergTable] = {} + + self.init_duckdb() self.__get_tables(catalog_config, catalog_name) + def init_duckdb(self): + self.duckdb_connection.execute(f"SET memory_limit='{self.db_mem_limit}'") + self.duckdb_connection.execute(f"SET threads TO {self.db_thread_limit}") + def __get_tables(self, catalog_config, catalog_name): tables = {} catalog: Catalog = load_catalog(catalog_name, **catalog_config) @@ -37,7 +54,9 @@ def __get_tables(self, catalog_config, catalog_name): namespaces = rest.list_namespaces() for n in namespaces: tables_names = rest.list_tables(n) - self.tables: dict[str, Table] = {".".join(t): catalog.load_table(t) for t in tables_names} + self.tables: dict[str, DuckBergTable] = { + ".".join(t): DuckBergTable.from_pyiceberg_table(catalog.load_table(t)) for t in tables_names + } return tables def list_tables(self): @@ -46,21 +65,46 @@ def list_tables(self): def list_partitions(self, table: str): t = self.tables[table] - if t.spec().is_unpartitioned(): - return None + if t.partitions == None: + t.precomp_partitions() + + return t.partitions - partition_cols_ids = [p["source-id"] for p in t.spec().model_dump()["fields"]] - col_names = [c["name"] for c in t.schema().model_dump()["fields"] if c["id"] in partition_cols_ids] + def select(self, sql: str, table: str = None, partition_filter: str = None, sql_params: [str] = None) -> RecordBatchReader: + if table is not None and partition_filter is not None: + return 
self._select_old(sql, table, partition_filter, sql_params) - return col_names + parsed_sql = self.sql_parser.parse_first_query(sql) + extracted_tables = self.sql_parser.extract_tables(parsed_sql) - def select(self, table: str, sql: str, partition_filter: str, sql_params: [str] = None) -> RecordBatchReader: - db_conn: DuckDBPyConnection = self.tables[table].scan(row_filter=partition_filter).to_duckdb(table_name=table) + table: TableWithAlias + for table in extracted_tables: + table_name = table.table_name - db_conn.execute(f"SET memory_limit='{self.db_mem_limit}'") - db_conn.execute(f"SET threads TO {self.db_thread_limit}") + if table_name not in self.tables: + raise TableNotInCatalogException + + row_filter = AlwaysTrue() + if table.comparisons is not None: + row_filter = table.comparisons + + table_data_scan_as_arrow = self.tables[table_name].scan(row_filter=row_filter).to_arrow() + self.duckdb_connection.register(table_name, table_data_scan_as_arrow) + + if sql_params is None: + return self.duckdb_connection.execute(sql).fetch_record_batch(self.batch_size_rows).read_pandas() + else: + return ( + self.duckdb_connection.execute(sql, parameters=sql_params) + .fetch_record_batch(self.batch_size_rows) + .read_pandas() + ) + + def _select_old(self, sql: str, table: str, partition_filter: str, sql_params: [str] = None): + table_data_scan_as_arrow = self.tables[table].scan(row_filter=partition_filter).to_arrow() + self.duckdb_connection.register(table, table_data_scan_as_arrow) if sql_params is None: - return db_conn.execute(sql).fetch_record_batch(self.batch_size_rows).read_pandas() + return self.duckdb_connection.execute(sql).fetch_record_batch(self.batch_size_rows).read_pandas() else: - return db_conn.execute(sql, parameters=sql_params).fetch_record_batch(self.batch_size_rows).read_pandas() + return self.duckdb_connection.execute(sql, parameters=sql_params).fetch_record_batch(self.batch_size_rows).read_pandas() diff --git a/src/duckberg/exceptions.py 
b/src/duckberg/exceptions.py new file mode 100644 index 0000000..db8eeb2 --- /dev/null +++ b/src/duckberg/exceptions.py @@ -0,0 +1,4 @@ +class TableNotInCatalogException(Exception): + """ + The specified table is not registered in data catalog + """ diff --git a/src/duckberg/sqlparser.py b/src/duckberg/sqlparser.py new file mode 100644 index 0000000..82b7822 --- /dev/null +++ b/src/duckberg/sqlparser.py @@ -0,0 +1,49 @@ +import sqlparse + +from duckberg.table import TableWithAlias +from pyiceberg.expressions import * +from pyiceberg.expressions import parser + + +class DuckBergSQLParser: + def parse_first_query(self, sql: str) -> sqlparse.sql.Statement: + reformated_sql = sql.replace("'", '"') # replace all single quotes with double quotes + return sqlparse.parse(reformated_sql)[0] + + def unpack_identifiers(self, token: sqlparse.sql.IdentifierList) -> list[TableWithAlias]: + return list( + map( + lambda y: TableWithAlias.from_identifier(y), + filter(lambda x: type(x) is sqlparse.sql.Identifier, token.tokens), + ) + ) + + def extract_tables(self, parsed_sql: sqlparse.sql.Statement) -> list[TableWithAlias]: + tables = [] + get_next = 0 + c_table: list[TableWithAlias] = [] + c_table_wc = None + for token in parsed_sql.tokens: + if get_next == 1 and token.ttype is not sqlparse.tokens.Whitespace: + if type(token) is sqlparse.sql.Identifier: + c_table = [TableWithAlias.from_identifier(token)] + get_next += 1 + elif type(token) is sqlparse.sql.IdentifierList: + c_table = self.unpack_identifiers(token) + get_next += 1 + elif type(token) is sqlparse.sql.Parenthesis: + tables.extend(self.extract_tables(token)) + + if token.ttype is sqlparse.tokens.Keyword and str(token.value).upper() == "FROM": + get_next += 1 + + if type(token) is sqlparse.sql.Where: + c_table_wc = self.extract_where_conditions(token) + + mapped_c_table = list(map(lambda x: x.set_comparisons(c_table_wc), c_table)) + tables.extend(mapped_c_table) + return tables + + def extract_where_conditions(self, 
where_statement: list[sqlparse.sql.Where]): + comparison = sqlparse.sql.TokenList(where_statement[1:]) + return parser.parse(str(comparison)) diff --git a/src/duckberg/table.py b/src/duckberg/table.py new file mode 100644 index 0000000..13ecac1 --- /dev/null +++ b/src/duckberg/table.py @@ -0,0 +1,58 @@ +from pyiceberg.catalog import Catalog +from pyiceberg.io import FileIO +from pyiceberg.table import Table +from pyiceberg.table.metadata import TableMetadata +from pyiceberg.typedef import Identifier +from pyiceberg.expressions import BooleanExpression +import sqlparse + + +class DuckBergTable(Table): + """ + Class for storing precomputed data for faster processing of queries + """ + + def __init__( + self, identifier: Identifier, metadata: TableMetadata, metadata_location: str, io: FileIO, catalog: Catalog + ) -> None: + super().__init__(identifier, metadata, metadata_location, io, catalog) + self.partitions = None + + @classmethod + def from_pyiceberg_table(cls, table: Table): + return cls(table.identifier, table.metadata, table.metadata_location, table.io, table.catalog) + + def precomp_partitions(self): + if self.spec().is_unpartitioned(): + self.partitions = [] + + partition_cols_ids = [p["source-id"] for p in self.spec().model_dump()["fields"]] + self.partitions = [c["name"] for c in self.schema().model_dump()["fields"] if c["id"] in partition_cols_ids] + + def __repr__(self) -> str: + return self.table + + +class TableWithAlias: + """ + Dataclass contains table name with alias + """ + + def __init__(self, tname: str, talias: str) -> None: + self.table_name: str = tname + self.table_alias: str = talias + self.comparisons: BooleanExpression = None + + @classmethod + def from_identifier(cls, identf: sqlparse.sql.Identifier): + return cls(identf.get_real_name(), identf.get_alias()) + + def set_comparisons(self, comparisons: BooleanExpression): + self.comparisons = comparisons + return self + + def __str__(self) -> str: + return f"{self.table_name} 
({self.table_alias})" + + def __repr__(self) -> str: + return self.__str__() diff --git a/tests/duckberg-sample.py b/tests/duckberg-sample.py index 0b64cca..1c0b012 100644 --- a/tests/duckberg-sample.py +++ b/tests/duckberg-sample.py @@ -19,14 +19,19 @@ tables = db.list_tables() -assert len(tables) == 1 - -partitions = db.list_partitions(table="nyc.taxis") - -assert len(tables) == 1 - -query: str = "SELECT * FROM 'nyc.taxis' WHERE trip_distance > 40 ORDER BY tolls_amount DESC" - -dd = db.select(table="nyc.taxis", partition_filter="payment_type = 1", sql=query) - -df = dd.read_pandas() +assert(len(tables) == 1) + +# New way of querying data without partition filter +query: str = "SELECT count(*) FROM (SELECT * FROM 'nyc.taxis' WHERE trip_distance > 40 ORDER BY tolls_amount DESC)" +df = db.select(sql=query) +assert(df['count_star()'][0] == 2614) + +# New way of querying data +query: str = "SELECT count(*) FROM (SELECT * FROM 'nyc.taxis' WHERE payment_type = 1 AND trip_distance > 40 ORDER BY tolls_amount DESC)" +df = db.select(sql=query) +assert(df['count_star()'][0] == 1673) + +# Old way of querying data +query: str = "SELECT count(*) FROM (SELECT * FROM 'nyc.taxis' WHERE payment_type = 1 AND trip_distance > 40 ORDER BY tolls_amount DESC)" +df = db.select(sql=query, table="nyc.taxis", partition_filter="payment_type = 1") +assert(df['count_star()'][0] == 1673) diff --git a/tests/sqlparser/basic_selects.py b/tests/sqlparser/basic_selects.py new file mode 100644 index 0000000..53efc1a --- /dev/null +++ b/tests/sqlparser/basic_selects.py @@ -0,0 +1,40 @@ +from duckberg.sqlparser import DuckBergSQLParser + + +parser = DuckBergSQLParser() + + +sql1 = """ +SELECT * FROM this_is_awesome_table""" +sql1_parsed = parser.parse_first_query(sql=sql1) +res1 = parser.extract_tables(sql1_parsed) +assert len(res1) == 1 +assert list(map(lambda x: str(x), res1)) == ["this_is_awesome_table (None)"] + +sql2 = """ +SELECT * FROM this_is_awesome_table, second_awesome_table""" +sql2_parsed = 
parser.parse_first_query(sql=sql2) +res2 = parser.extract_tables(sql2_parsed) +assert len(res2) == 2 +assert list(map(lambda x: str(x), res2)) == ["this_is_awesome_table (None)", "second_awesome_table (None)"] + +sql3 = """ +SELECT * FROM (SELECT * FROM (SELECT * FROM this_is_awesome_table))""" +sql3_parsed = parser.parse_first_query(sql=sql3) +res3 = parser.extract_tables(sql3_parsed) +assert len(res3) == 1 +assert list(map(lambda x: str(x), res3)) == ["this_is_awesome_table (None)"] + +sql4 = """ +SELECT * FROM (SELECT * FROM (SELECT * FROM this_is_awesome_table), second_awesome_table)""" +sql4_parsed = parser.parse_first_query(sql=sql4) +res4 = parser.extract_tables(sql4_parsed) +assert len(res4) == 2 +assert list(map(lambda x: str(x), res4)) == ["this_is_awesome_table (None)", "second_awesome_table (None)"] + +sql5 = """ +SELECT * FROM (SELECT * FROM (SELECT * FROM this_is_awesome_table tiat, second_awesome_table))""" +sql5_parsed = parser.parse_first_query(sql=sql5) +res5 = parser.extract_tables(sql5_parsed) +assert len(res5) == 2 +assert list(map(lambda x: str(x), res5)) == ["this_is_awesome_table (tiat)", "second_awesome_table (None)"] diff --git a/tests/sqlparser/where_selects.py b/tests/sqlparser/where_selects.py new file mode 100644 index 0000000..25d39c0 --- /dev/null +++ b/tests/sqlparser/where_selects.py @@ -0,0 +1,32 @@ +from duckberg.sqlparser import DuckBergSQLParser + + +parser = DuckBergSQLParser() + + +sql1 = """ +SELECT * FROM this_is_awesome_table WHERE a > 15""" +sql1_parsed = parser.parse_first_query(sql=sql1) +res1 = parser.extract_tables(sql1_parsed) +res1_where = str(res1[0].comparisons) +assert "GreaterThan(term=Reference(name='a'), literal=LongLiteral(15))" == res1_where + +sql2 = """ +SELECT * FROM this_is_awesome_table WHERE a > 15 AND a < 16""" +sql2_parsed = parser.parse_first_query(sql=sql2) +res2 = parser.extract_tables(sql2_parsed) +res2_where = str(res2[0].comparisons) +assert ( + "And(left=GreaterThan(term=Reference(name='a'), 
literal=LongLiteral(15)), right=LessThan(term=Reference(name='a'), literal=LongLiteral(16)))" + == res2_where +) + +sql3 = """ +SELECT * FROM this_is_awesome_table WHERE (a > 15 AND a < 16) OR c > 15""" +sql3_parsed = parser.parse_first_query(sql=sql3) +res3 = parser.extract_tables(sql3_parsed) +res3_where = str(res3[0].comparisons) +assert ( + "Or(left=And(left=GreaterThan(term=Reference(name='a'), literal=LongLiteral(15)), right=LessThan(term=Reference(name='a'), literal=LongLiteral(16))), right=GreaterThan(term=Reference(name='c'), literal=LongLiteral(15)))" + == res3_where +)