Switch to DB native export format #357

Merged · 7 commits · Mar 4, 2025
Changes from 4 commits
58 changes: 21 additions & 37 deletions cumulus_library/actions/exporter.py
@@ -1,7 +1,7 @@
import pathlib

import pyarrow
from pyarrow import csv, parquet
import pandas
from rich import console
from rich.progress import track

from cumulus_library import base_utils, study_manifest
@@ -24,32 +24,23 @@ def reset_counts_exports(
file.unlink()


def _write_chunk(writer, chunk, arrow_schema):
writer.write(
pyarrow.Table.from_pandas(
chunk.sort_values(by=list(chunk.columns), ascending=False, na_position="first"),
preserve_index=False,
schema=arrow_schema,
)
)


def export_study(
config: base_utils.StudyConfig,
manifest: study_manifest.StudyManifest,
*,
data_path: pathlib.Path,
archive: bool,
chunksize: int = 1000000,
) -> list:
):
"""Exports csvs/parquet extracts of tables listed in export_list
:param config: a StudyConfig object
:param manifest: a StudyManifest object
:keyword data_path: the path to the place on disk to save data
:keyword archive: If true, get all study data and zip with timestamp
:keyword chunksize: number of rows to export in a single transaction
:returns: a list of queries, (only for unit tests)
"""

skipped_tables = []
reset_counts_exports(manifest)
if manifest.get_dedicated_schema():
prefix = f"{manifest.get_dedicated_schema()}."
@@ -64,34 +55,27 @@ def export_study(
table_list.append(study_manifest.ManifestExport(name=row[0], export_type="archive"))
else:
table_list = manifest.get_export_table_list()
queries = []
path = pathlib.Path(f"{data_path}/{manifest.get_study_prefix()}/")
path.mkdir(parents=True, exist_ok=True)
for table in track(
table_list,
description=f"Exporting {manifest.get_study_prefix()} data...",
):
query = f"SELECT * FROM {table.name}" # noqa: S608
query = base_utils.update_query_if_schema_specified(query, manifest)
dataframe_chunks, db_schema = config.db.execute_as_pandas(query, chunksize=chunksize)
path.mkdir(parents=True, exist_ok=True)
arrow_schema = pyarrow.schema(config.db.col_pyarrow_types_from_sql(db_schema))
with parquet.ParquetWriter(
f"{path}/{table.name}.{table.export_type}.parquet", arrow_schema
) as p_writer:
with csv.CSVWriter(
f"{path}/{table.name}.{table.export_type}.csv",
arrow_schema,
write_options=csv.WriteOptions(
# Note that this quoting style is not exactly csv.QUOTE_MINIMAL
# https://github.com/apache/arrow/issues/42032
quoting_style="needed"
),
) as c_writer:
for chunk in dataframe_chunks:
_write_chunk(p_writer, chunk, arrow_schema) # pragma: no cover
_write_chunk(c_writer, chunk, arrow_schema) # pragma: no cover
queries.append(query)
table.name = base_utils.update_query_if_schema_specified(table.name, manifest)
parquet_path = config.db.export_table_as_parquet(table.name, table.export_type, path)
if parquet_path:
df = pandas.read_parquet(parquet_path)
df.to_csv(
parquet_path.with_suffix(".csv"),
index=False,
)
else:
skipped_tables.append(table.name)

if len(skipped_tables) > 0:
c = console.Console()
c.print("The following tables were empty and were not exported:")
for table in skipped_tables:
c.print(table)
if archive:
base_utils.zip_dir(path, data_path, manifest.get_study_prefix())
return queries
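
Taken together, the rewritten export loop above reduces to roughly the following flow. This is a simplified sketch for orientation only, not code from this PR; export_type stands in for whatever the manifest specifies, and db is any backend implementing the new export_table_as_parquet contract.

import pathlib

import pandas


def export_tables_sketch(db, table_names: list[str], export_type: str, out_dir: pathlib.Path) -> list[str]:
    """Illustrative only: export each table natively as parquet, then mirror it as CSV."""
    out_dir.mkdir(parents=True, exist_ok=True)
    skipped = []
    for name in table_names:
        # The backend hands back a path to a DB-native parquet export,
        # or None if the table was empty.
        parquet_path = db.export_table_as_parquet(name, export_type, out_dir)
        if parquet_path is None:
            skipped.append(name)
            continue
        # Mirror the parquet extract as a CSV for downstream consumers.
        df = pandas.read_parquet(parquet_path)
        df.to_csv(parquet_path.with_suffix(".csv"), index=False)
    return skipped
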
69 changes: 41 additions & 28 deletions cumulus_library/databases/athena.py
@@ -8,11 +8,11 @@
import os
import pathlib

import awswrangler
import boto3
import botocore
import numpy
import pandas
import pyarrow
import pyathena
from pyathena.common import BaseCursor as AthenaCursor
from pyathena.pandas.cursor import PandasCursor as AthenaPandasCursor
@@ -96,33 +96,6 @@ def col_parquet_types_from_pandas(self, field_types: list) -> list:
)
return output

def col_pyarrow_types_from_sql(self, columns: list[tuple]) -> list:
output = []
for column in columns:
match column[1]:
case "varchar":
output.append((column[0], pyarrow.string()))
case "bigint":
output.append((column[0], pyarrow.int64()))
case "integer":
output.append((column[0], pyarrow.int64()))
case "double":
output.append((column[0], pyarrow.float64()))
# This is future proofing - we don't see this type currently.
case "decimal":
output.append( # pragma: no cover
(column[0], pyarrow.decimal128(column[4], column[5]))
)
case "boolean":
output.append((column[0], pyarrow.bool_()))
case "date":
output.append((column[0], pyarrow.date64()))
case "timestamp":
output.append((column[0], pyarrow.timestamp("s")))
case _:
raise errors.CumulusLibraryError(f"Unsupported SQL type '{column[1]}' found.")
return output

def upload_file(
self,
*,
@@ -168,6 +141,46 @@ def upload_file,
)
return f"s3://{bucket}/{s3_key}"

def export_table_as_parquet(
self, table_name: str, table_type: str, location: pathlib.Path, *args, **kwargs
) -> str | None:
s3_client = boto3.client("s3")
output_path = location / f"{table_name}.parquet"
workgroup = self.connection._client.get_work_group(WorkGroup=self.work_group)
wg_conf = workgroup["WorkGroup"]["Configuration"]["ResultConfiguration"]
s3_path = wg_conf["OutputLocation"]
bucket = "/".join(s3_path.split("/")[2:3])
output_path = location / f"{table_name}.{table_type}.parquet"
s3_path = f"s3://{bucket}/export/{table_name}.{table_type}.parquet"

# Cleanup location in case there was an error of some kind
res = s3_client.list_objects_v2(
Bucket=bucket, Prefix=f"export/{table_name}.{table_type}.parquet"
)
if "Contents" in res:
for file in res["Contents"]:
s3_client.delete_object(Bucket=bucket, Key=file["Key"])

self.connection.cursor().execute(f"""UNLOAD
(SELECT * FROM {table_name})
TO '{s3_path}'
WITH (format='PARQUET', compression='SNAPPY')
""") # noqa: S608
Comment on lines +166 to +170

Contributor Author:
I could? create a centralized jinja template for these DB-specific queries. I don't know how much this export mechanism needs the injection protection, since it's not in queries that are being distributed.

Contributor:
The table name comes from the manifest, yeah? So it is user input, which does mean it's subject to chicanery, by the user or study author or a malicious 3rd party app modifying files in $HOME... (btw: do we have any sanitizing of table names when reading the manifest?)

But I'm not overly stressed - I'll leave it up to you on risk assessment here.

Contributor Author:
We do have a regex validation for custom prefixes, but not for table names. If we did that, I'd be fine saying this is safe.

Contributor Author:
But... on the other hand, a user can run a DROP TABLE query if they edit the manifest and it would be valid, so... maybe we're just not safe from a malicious user.

Contributor:
Or a malicious script on the user's machine.
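
A minimal sketch of the kind of table-name check discussed in the thread above, analogous to the existing prefix regex; the pattern and helper name are hypothetical and are not added by this PR.

import re

# Hypothetical guard: allow only plain identifiers (letters, digits, underscores,
# and dots for schema-qualified names) before interpolating a table name into an
# UNLOAD/COPY statement.
_VALID_TABLE_NAME = re.compile(r"[A-Za-z0-9_.]+")


def validate_table_name(name: str) -> str:
    if not _VALID_TABLE_NAME.fullmatch(name):
        raise ValueError(f"Refusing to export suspicious table name: {name!r}")
    return name
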

# UNLOAD is not guaranteed to create a single file. AWS Wrangler's read_parquet
# allows us to ignore that wrinkle
try:
df = awswrangler.s3.read_parquet(s3_path)
except awswrangler.exceptions.NoFilesFound:
return None
df = df.sort_values(by=list(df.columns), ascending=False, na_position="first")
df.to_parquet(output_path)
res = s3_client.list_objects_v2(
Bucket=bucket, Prefix=f"export/{table_name}.{table_type}.parquet"
)
for file in res["Contents"]:
s3_client.delete_object(Bucket=bucket, Key=file["Key"])
return output_path

def create_schema(self, schema_name) -> None:
"""Creates a new schema object inside the database"""
glue_client = boto3.client("glue")
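
To make the multi-file point concrete: awswrangler's read_parquet takes an S3 path or prefix and reads every parquet part file it finds there into a single DataFrame, which is why the UNLOAD output above does not need to be a single object. A standalone usage sketch (bucket and key are placeholders):

import awswrangler

# One or many part files written under this prefix come back as one DataFrame.
df = awswrangler.s3.read_parquet("s3://example-bucket/export/study__table.cube.parquet")
print(df.shape)
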
13 changes: 10 additions & 3 deletions cumulus_library/databases/base.py
@@ -190,9 +190,6 @@ def col_parquet_types_from_pandas(self, field_types: list) -> list:
# )
return []

def col_pyarrow_types_from_sql(self, columns: list[tuple]) -> list:
return columns # pragma: no cover

def upload_file(
self,
*,
@@ -208,6 +205,16 @@ def upload_file(
have an API for file upload (i.e. cloud databases)"""
return None

@abc.abstractmethod
def export_table_as_parquet(
self, table_name: str, table_type: str, location: pathlib.Path, *args, **kwargs
Contributor:
nit: table_type seems like an odd addition here. In both implementations, it's just used to create the output file name, which felt like a bit of duplicated business logic.

What if you used an "input/output, did it write anything" pattern like (totally ignorable suggestion, just brainstorming):

def export_table_as_parquet(self, table_name: str, output_path: pathlib.Path) -> bool:

Contributor Author:
Well, you need the bare table name for the actual sql query, and then this concat with table_type to tell you if it's flat/cube/something else. I like that a *little* bit more than having a split for readability.

And I need the path downstream so that I have one place to handle the 'take the parquet and create a csv from it once it's downloaded' logic. So... I think I'm going to soft advocate for as is?

Contributor Author:
OK, as discussed offline, this now takes a filename arg and returns a bool.
) -> pathlib.Path | None:
"""Gets a parquet file from a specified table.

This is intended as a way to get the most database native parquet export possible,
so we don't have to infer schema information. Only do schema inferring if your
DB engine does not support parquet natively. If a table is empty, return None."""

@abc.abstractmethod
def create_schema(self, schema_name):
"""Creates a new schema object inside the catalog"""
49 changes: 23 additions & 26 deletions cumulus_library/databases/duckdb.py
@@ -9,14 +9,13 @@

import collections
import datetime
import pathlib
import re

import duckdb
import pandas
import pyarrow
import pyarrow.dataset

from cumulus_library import errors
from cumulus_library.databases import base


@@ -174,30 +173,6 @@ def execute_as_pandas(
return iter([result.df().convert_dtypes()]), result.description
return result.df().convert_dtypes(), result.description

def col_pyarrow_types_from_sql(self, columns: list[tuple]) -> list:
output = []
for column in columns:
match column[1]:
case "STRING":
output.append((column[0], pyarrow.string()))
case "INTEGER":
output.append((column[0], pyarrow.int64()))
case "NUMBER":
output.append((column[0], pyarrow.float64()))
case "DOUBLE":
output.append((column[0], pyarrow.float64()))
case "boolean" | "bool":
output.append((column[0], pyarrow.bool_()))
case "Date":
output.append((column[0], pyarrow.date64()))
case "TIMESTAMP" | "DATETIME":
output.append((column[0], pyarrow.timestamp("s")))
case _:
raise errors.CumulusLibraryError(
f"{column[0], column[1]} does not have a conversion type"
)
return output

def parser(self) -> base.DatabaseParser:
return DuckDbParser()

@@ -207,6 +182,28 @@ def operational_errors(self) -> tuple[type[Exception], ...]:
duckdb.BinderException,
)

def export_table_as_parquet(
self, table_name: str, table_type: str, location: pathlib.Path, *args, **kwargs
) -> str | None:
parquet_path = location / f"{table_name}.{table_type}.parquet"
parquet_path.parent.mkdir(exist_ok=True, parents=True)
table_size = self.connection.execute(f"SELECT count(*) FROM {table_name}").fetchone() # noqa: S608
if table_size[0] == 0:
return None
query = f"""COPY
(SELECT * FROM {table_name})
TO '{parquet_path}'
(FORMAT parquet)
""" # noqa: S608
self.connection.execute(query)

df = pandas.read_parquet(parquet_path)
df = df.sort_values(
by=list(df.columns), ascending=False, ignore_index=True, na_position="first"
)
df.to_parquet(parquet_path)
return parquet_path

def create_schema(self, schema_name):
"""Creates a new schema object inside the database"""
schemas = self.connection.sql(
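
For reference, DuckDB's COPY ... TO statement used above is the engine-native parquet export; a self-contained example against an in-memory database (table and file names are placeholders):

import duckdb

con = duckdb.connect()  # in-memory database, just for illustration
con.execute("CREATE TABLE demo AS SELECT 1 AS id, 'a' AS val UNION ALL SELECT 2, 'b'")
con.execute("COPY (SELECT * FROM demo) TO 'demo.parquet' (FORMAT parquet)")
# DuckDB can query the exported file directly to confirm the round trip.
print(con.execute("SELECT count(*) FROM 'demo.parquet'").fetchone())
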
1 change: 1 addition & 0 deletions pyproject.toml
@@ -2,6 +2,7 @@
name = "cumulus-library"
requires-python = ">= 3.11"
dependencies = [
"awswrangler >= 3.11, < 4",
"cumulus-fhir-support >= 1.3.1", # 1.3.1 fixes a "load all rows into memory" bug
"duckdb >= 1.1.3",
"Jinja2 > 3",