Skip to content

Commit

Permalink
add support for db_kwargs for adbc in pl.read_database
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasaarholt committed Jul 27, 2023
1 parent ef91c45 commit 857bbcd
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions py-polars/polars/io/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def read_database(
partition_num: int | None = None,
protocol: str | None = None,
engine: DbReadEngine = "connectorx",
adbc_db_kwargs: dict[str, str] | None = None,
) -> DataFrame:
"""
Read a SQL query into a DataFrame.
Expand Down Expand Up @@ -60,6 +61,10 @@ def read_database(
an up-to-date list of drivers please see the ADBC docs:
* https://arrow.apache.org/adbc/
adbc_db_kwargs
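Additional options passed as ``db_kwargs`` to the ADBC driver's ``connect()``.
The accepted keys are driver-specific; for example, see
https://arrow.apache.org/adbc/main/driver/snowflake.html#snowflake-driver
for the options available with the Snowflake driver.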
Notes
-----
Expand Down Expand Up @@ -121,7 +126,7 @@ def read_database(
elif engine == "adbc":
if not isinstance(query, str):
raise ValueError("Only a single SQL query string is accepted for adbc.")
return _read_sql_adbc(query, connection_uri)
return _read_sql_adbc(query, connection_uri, db_kwargs=adbc_db_kwargs)
else:
raise ValueError(f"Engine {engine!r} not implemented; use connectorx or adbc.")

Expand Down Expand Up @@ -153,16 +158,20 @@ def _read_sql_connectorx(
return from_arrow(tbl) # type: ignore[return-value]


def _read_sql_adbc(query: str, connection_uri: str) -> DataFrame:
with _open_adbc_connection(connection_uri) as conn:
def _read_sql_adbc(
    query: str, connection_uri: str, db_kwargs: dict[str, str] | None = None
) -> DataFrame:
    """
    Run a single SQL query through an ADBC driver and return it as a DataFrame.

    Parameters
    ----------
    query
        SQL query string to execute.
    connection_uri
        Connection URI; handed to ``_open_adbc_connection`` to open the
        connection.
    db_kwargs
        Optional extra options forwarded unchanged to ``_open_adbc_connection``
        as ``db_kwargs``. Defaults to ``None``, matching that helper's default.

    Returns
    -------
    DataFrame
        The result of ``from_arrow`` applied to the fetched Arrow table.

    """
    with _open_adbc_connection(connection_uri, db_kwargs=db_kwargs) as conn:
        cursor = conn.cursor()
        try:
            cursor.execute(query)
            tbl = cursor.fetch_arrow_table()
        finally:
            # Close the cursor even when execute/fetch raises, so a failing
            # query does not leave it open.
            cursor.close()
    return from_arrow(tbl)  # type: ignore[return-value]


def _open_adbc_connection(connection_uri: str) -> Any:
def _open_adbc_connection(
connection_uri: str, db_kwargs: dict[str, str] | None = None
) -> Any:
driver_name = connection_uri.split(":", 1)[0].lower()

# map uri prefix to module when not 1:1
Expand All @@ -184,4 +193,4 @@ def _open_adbc_connection(connection_uri: str) -> Any:
if driver_name in ("sqlite", "snowflake"):
connection_uri = re.sub(f"^{driver_name}:/{{,3}}", "", connection_uri)

return adbc_driver.connect(connection_uri)
return adbc_driver.connect(connection_uri, db_kwargs=db_kwargs)

0 comments on commit 857bbcd

Please sign in to comment.