Merge pull request #384 from deployment-gap-model-education-fund/dev
`dev` -> `main` for `v2024.11.21`
bendnorman authored Nov 21, 2024
2 parents 336aa94 + 8b494a3 commit 79b9ea8
Showing 4 changed files with 60 additions and 15 deletions.
src/dbcp/cli.py (5 changes: 3 additions & 2 deletions)
@@ -1,4 +1,5 @@
 """A Command line interface for the down ballot project."""
+
 import argparse
 import logging
 import sys
@@ -66,9 +67,9 @@ def main():
         SPATIAL_CACHE.clear()

     if args.etl:
-        dbcp.etl.etl(args)
+        dbcp.etl.etl()
     if args.data_mart:
-        dbcp.data_mart.create_data_marts(args)
+        dbcp.data_mart.create_data_marts()


 if __name__ == "__main__":
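Since etl() and create_data_marts() no longer take the parsed arguments, main() only uses args for dispatch. A minimal sketch of that pattern (the flag names and parser wiring here are assumptions for illustration, not the repository's actual cli.py):

import argparse

import dbcp


def main():
    # Hypothetical parser; dbcp's real CLI may define these flags differently.
    parser = argparse.ArgumentParser(description="Run the dbcp pipeline.")
    parser.add_argument("--etl", action="store_true", help="run the ETL step")
    parser.add_argument("--data-mart", dest="data_mart", action="store_true")
    args = parser.parse_args()

    if args.etl:
        dbcp.etl.etl()  # no args threaded through anymore
    if args.data_mart:
        dbcp.data_mart.create_data_marts()


if __name__ == "__main__":
    main()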
src/dbcp/data_mart/__init__.py (15 changes: 10 additions & 5 deletions)
@@ -5,6 +5,8 @@
 import pkgutil

 import pandas as pd
+import pyarrow as pa
+import pyarrow.parquet as pq

 import dbcp
 from dbcp.constants import OUTPUT_DIR
@@ -15,7 +17,7 @@
 logger = logging.getLogger(__name__)


-def create_data_marts(args):  # noqa: max-complexity=11
+def create_data_marts():  # noqa: max-complexity=11
     """Collect and load all data mart tables to data warehouse."""
     engine = dbcp.helpers.get_sql_engine()
     data_marts = {}
@@ -64,8 +66,8 @@ def create_data_marts(args):  # noqa: max-complexity=11
     with engine.connect() as con:
         for table in metadata.sorted_tables:
             logger.info(f"Load {table.name} to postgres.")
-            df = enforce_dtypes(data_marts[table.name], table.name, "data_mart")
-            df = dbcp.helpers.trim_columns_length(df)
+            df = dbcp.helpers.trim_columns_length(data_marts[table.name])
+            df = enforce_dtypes(df, table.name, "data_mart")
             df.to_sql(
                 name=table.name,
                 con=con,
@@ -74,7 +76,10 @@ def create_data_marts(args):  # noqa: max-complexity=11
                 schema="data_mart",
                 method=psql_insert_copy,
             )
-
-            df.to_parquet(parquet_dir / f"{table.name}.parquet", index=False)
+            schema = dbcp.helpers.get_pyarrow_schema_from_metadata(
+                table.name, "data_mart"
+            )
+            pa_table = pa.Table.from_pandas(df, schema=schema)
+            pq.write_table(pa_table, parquet_dir / f"{table.name}.parquet")

     validate_data_mart(engine=engine)
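The substantive change above: instead of df.to_parquet(...), which lets pandas infer the Parquet types, each data mart table is converted through an explicit PyArrow schema derived from the warehouse metadata. A standalone sketch of the technique, with toy column names that are not from the repo:

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

# Toy frame; in dbcp the schema comes from SQLAlchemy metadata instead.
df = pd.DataFrame({"project_id": [1, 2], "capacity_mw": [150.0, None]})

# An explicit schema pins the Parquet types: a column that happens to be
# all-null or upcast in pandas still lands as the declared type on disk.
schema = pa.schema([("project_id", pa.int64()), ("capacity_mw", pa.float64())])

pa_table = pa.Table.from_pandas(df, schema=schema)
pq.write_table(pa_table, "project_table.parquet")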
src/dbcp/etl.py (11 changes: 9 additions & 2 deletions)
@@ -5,6 +5,8 @@
 from typing import Callable, Dict

 import pandas as pd
+import pyarrow as pa
+import pyarrow.parquet as pq
 import sqlalchemy as sa

 import dbcp
@@ -225,12 +227,17 @@ def run_etl(funcs: dict[str, Callable], schema_name: str):
                 chunksize=1000,
                 method=psql_insert_copy,
             )
-            df.to_parquet(parquet_dir / f"{table.name}.parquet", index=False)

+            schema = dbcp.helpers.get_pyarrow_schema_from_metadata(
+                table.name, schema_name
+            )
+            pa_table = pa.Table.from_pandas(df, schema=schema)
+            pq.write_table(pa_table, parquet_dir / f"{table.name}.parquet")
+
     logger.info("Sucessfully finished ETL.")


-def etl(args):
+def etl():
     """Run dbc ETL."""
     # Reduce size of caches if necessary
     GEOCODER_CACHE.reduce_size()
src/dbcp/helpers.py (44 changes: 38 additions & 6 deletions)
@@ -11,6 +11,7 @@
 import google.auth
 import pandas as pd
 import pandas_gbq
+import pyarrow as pa
 import sqlalchemy as sa
 from google.cloud import bigquery
 from tqdm import tqdm
@@ -35,6 +36,14 @@
     "BOOLEAN": "boolean",
     "DATETIME": "datetime64[ns]",
 }
+SA_TO_PA_TYPES = {
+    "VARCHAR": pa.string(),
+    "INTEGER": pa.int64(),
+    "BIGINT": pa.int64(),
+    "FLOAT": pa.float64(),
+    "BOOLEAN": pa.bool_(),
+    "DATETIME": pa.timestamp("ms"),
+}
 SA_TO_BQ_MODES = {True: "NULLABLE", False: "REQUIRED"}


@@ -81,6 +90,25 @@ def get_bq_schema_from_metadata(
     return bq_schema


+def get_pyarrow_schema_from_metadata(table_name: str, schema: str) -> pa.Schema:
+    """
+    Create a PyArrow schema from SQL Alchemy metadata.
+
+    Args:
+        table_name: the name of the table.
+        schema: the name of the database schema.
+    Returns:
+        pyarrow_schema: a PyArrow schema description.
+    """
+    table_name = f"{schema}.{table_name}"
+    metadata = get_schema_sql_alchemy_metadata(schema)
+    table_sa = metadata.tables[table_name]
+    pyarrow_schema = []
+    for column in table_sa.columns:
+        pyarrow_schema.append((column.name, SA_TO_PA_TYPES[str(column.type)]))
+    return pa.schema(pyarrow_schema)
+
+
 def enforce_dtypes(df: pd.DataFrame, table_name: str, schema: str):
     """Apply dtypes to a dataframe using the sqlalchemy metadata."""
     table_name = f"{schema}.{table_name}"
@@ -90,12 +118,16 @@ def enforce_dtypes(df: pd.DataFrame, table_name: str, schema: str):
     except KeyError:
         raise KeyError(f"{table_name} does not exist in metadata.")

-    dtypes = {
-        col.name: SA_TO_PD_TYPES[str(col.type)]
-        for col in table.columns
-        if col.name in df.columns
-    }
-    return df.astype(dtypes)
+    for col in table.columns:
+        # Add the column if it doesn't exist
+        if col.name not in df.columns:
+            df[col.name] = None
+        df[col.name] = df[col.name].astype(SA_TO_PD_TYPES[str(col.type)])
+
+    # convert datetime[ns] columns to milliseconds
+    for col in df.select_dtypes(include=["datetime64[ns]"]).columns:
+        df[col] = df[col].dt.floor("ms")
+    return df


 def get_sql_engine() -> sa.engine.Engine:
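One interaction worth calling out: SA_TO_PA_TYPES maps DATETIME to pa.timestamp("ms"), while pandas holds datetimes as datetime64[ns]; flooring to milliseconds in enforce_dtypes is what makes the later nanosecond-to-millisecond cast lossless. A small standalone sketch of that behavior (not repo code):

import pandas as pd
import pyarrow as pa

s = pd.Series(pd.to_datetime(["2024-11-21 12:00:00.123456789"]))

# Casting ns-precision values straight to timestamp("ms") would raise a
# "would lose data" ArrowInvalid error under pyarrow's default safe cast.
floored = s.dt.floor("ms")  # drop sub-millisecond precision up front

arr = pa.Array.from_pandas(floored, type=pa.timestamp("ms"))
print(arr[0])  # 2024-11-21 12:00:00.123000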
