Implement csv support (#107)
stinodego authored Apr 12, 2024
1 parent c899e0a commit 8a17b04
Showing 8 changed files with 89 additions and 102 deletions.
41 changes: 24 additions & 17 deletions queries/dask/utils.py
@@ -24,59 +24,66 @@


def read_ds(path: Path) -> DataFrame:
-    if settings.run.file_type != "parquet":
-        msg = f"unsupported file type: {settings.run.file_type!r}"
-        raise ValueError(msg)
-
-    if settings.run.include_io:
-        return dd.read_parquet(path, dtype_backend="pyarrow")  # type: ignore[attr-defined,no-any-return]
-
-    # TODO: Load into memory before returning the Dask DataFrame.
-    # Code below is tripped up by date types
-    # df = pd.read_parquet(path, dtype_backend="pyarrow")
-    # return dd.from_pandas(df, npartitions=os.cpu_count())
-    msg = "cannot run Dask starting from an in-memory representation"
-    raise RuntimeError(msg)
+    if not settings.run.include_io:
+        msg = "cannot run Dask starting from an in-memory representation"
+        raise RuntimeError(msg)
+
+    path_str = f"{path}.{settings.run.file_type}"
+    if settings.run.file_type == "parquet":
+        return dd.read_parquet(path_str, dtype_backend="pyarrow")  # type: ignore[attr-defined,no-any-return]
+    elif settings.run.file_type == "csv":
+        df = dd.read_csv(path_str, dtype_backend="pyarrow")  # type: ignore[attr-defined]
+        for c in df.columns:
+            if c.endswith("date"):
+                df[c] = df[c].astype("date32[day][pyarrow]")
+        return df  # type: ignore[no-any-return]
+    else:
+        msg = f"unsupported file type: {settings.run.file_type!r}"
+        raise ValueError(msg)


@on_second_call
def get_line_item_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "lineitem.parquet")
+    return read_ds(settings.dataset_base_dir / "lineitem")


@on_second_call
def get_orders_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "orders.parquet")
+    return read_ds(settings.dataset_base_dir / "orders")


@on_second_call
def get_customer_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "customer.parquet")
+    return read_ds(settings.dataset_base_dir / "customer")


@on_second_call
def get_region_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "region.parquet")
+    return read_ds(settings.dataset_base_dir / "region")


@on_second_call
def get_nation_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "nation.parquet")
+    return read_ds(settings.dataset_base_dir / "nation")


@on_second_call
def get_supplier_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "supplier.parquet")
+    return read_ds(settings.dataset_base_dir / "supplier")


@on_second_call
def get_part_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "part.parquet")
+    return read_ds(settings.dataset_base_dir / "part")


@on_second_call
def get_part_supp_ds() -> DataFrame:
-    return read_ds(settings.dataset_base_dir / "partsupp.parquet")
+    return read_ds(settings.dataset_base_dir / "partsupp")


def run_query(query_number: int, query: Callable[..., Any]) -> None:
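Note on the csv branch above: CSV files carry no schema, so the TPC-H date columns (all of which have names ending in "date", e.g. l_shipdate) arrive as pyarrow-backed strings, and the loop casts them to date32 so downstream queries see the same types as the Parquet path. A standalone sketch of the same idea, with a hypothetical file name:

    import dask.dataframe as dd

    df = dd.read_csv("lineitem.csv", dtype_backend="pyarrow")  # hypothetical file
    for c in (c for c in df.columns if c.endswith("date")):
        # cast string -> pyarrow date32, matching what read_parquet produces
        df[c] = df[c].astype("date32[day][pyarrow]")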
10 changes: 10 additions & 0 deletions queries/duckdb/utils.py
@@ -21,6 +21,16 @@ def _scan_ds(path: Path) -> str:
f"create temp table if not exists {name} as select * from read_parquet('{path_str}');"
)
return name
if settings.run.file_type == "csv":
if settings.run.include_io:
duckdb.read_csv(path_str)
return f"'{path_str}'"
else:
name = path_str.replace("/", "_").replace(".", "_").replace("-", "_")
duckdb.sql(
f"create temp table if not exists {name} as select * from read_csv('{path_str}');"
)
return name
elif settings.run.file_type == "feather":
msg = "duckdb does not support feather for now"
raise ValueError(msg)
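The two return shapes of _scan_ds matter at the call sites: with include_io enabled, the helper returns a quoted path literal that DuckDB scans on every query; otherwise it materializes a temp table once and returns its name. A rough sketch of both modes, with hypothetical paths:

    import duckdb

    # include_io=True: the query re-reads the file each time
    duckdb.sql("select count(*) from 'data/tables/lineitem.csv'")

    # include_io=False: pay the read once, then query the temp table by name
    duckdb.sql(
        "create temp table if not exists data_tables_lineitem_csv as "
        "select * from read_csv('data/tables/lineitem.csv');"
    )
    duckdb.sql("select count(*) from data_tables_lineitem_csv")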
6 changes: 6 additions & 0 deletions queries/pandas/utils.py
@@ -24,6 +24,12 @@ def _read_ds(path: Path) -> pd.DataFrame:
path_str = f"{path}.{settings.run.file_type}"
if settings.run.file_type == "parquet":
return pd.read_parquet(path_str, dtype_backend="pyarrow")
elif settings.run.file_type == "csv":
df = pd.read_csv(path_str, dtype_backend="pyarrow")
for c in df.columns:
if c.endswith("date"):
df[c] = df[c].astype("date32[day][pyarrow]") # type: ignore[call-overload]
return df
elif settings.run.file_type == "feather":
return pd.read_feather(path_str, dtype_backend="pyarrow")
else:
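The pandas reader follows the same pattern as Dask: with dtype_backend="pyarrow", pd.read_csv yields string columns for the dates, which are then cast. A minimal series-level sketch of that cast, assuming ISO-formatted dates as TPC-H data uses:

    import pandas as pd

    s = pd.Series(["1994-01-01", "1995-06-17"], dtype="string[pyarrow]")
    s = s.astype("date32[day][pyarrow]")  # pyarrow parses the ISO strings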
7 changes: 6 additions & 1 deletion queries/polars/utils.py
@@ -16,16 +16,21 @@

def _scan_ds(path: Path) -> pl.LazyFrame:
    path_str = f"{path}.{settings.run.file_type}"

    if settings.run.file_type == "parquet":
        scan = pl.scan_parquet(path_str)
    elif settings.run.file_type == "feather":
        scan = pl.scan_ipc(path_str)
+    elif settings.run.file_type == "csv":
+        scan = pl.scan_csv(path_str, try_parse_dates=True)
    else:
        msg = f"unsupported file type: {settings.run.file_type!r}"
        raise ValueError(msg)

    if settings.run.include_io:
        return scan
-    return scan.collect().rechunk().lazy()
+    else:
+        return scan.collect().rechunk().lazy()


def get_line_item_ds() -> pl.LazyFrame:
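Polars needs no post-hoc cast here: try_parse_dates=True makes scan_csv infer date columns during parsing, so the CSV scan produces the same schema as scan_parquet. Sketch with a hypothetical file:

    import polars as pl

    # columns like l_shipdate come back as pl.Date rather than pl.String
    lf = pl.scan_csv("lineitem.csv", try_parse_dates=True)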
40 changes: 2 additions & 38 deletions queries/pyspark/executor.py
@@ -1,40 +1,4 @@
-from linetimer import CodeTimer

-# TODO: works for now, but need dynamic imports for this.
-from queries.pyspark import (  # noqa: F401
-    q1,
-    q2,
-    q3,
-    q4,
-    q5,
-    q6,
-    q7,
-    q8,
-    q9,
-    q10,
-    q11,
-    q12,
-    q13,
-    q14,
-    q15,
-    q16,
-    q17,
-    q18,
-    q19,
-    q20,
-    q21,
-    q22,
-)
+from queries.common_utils import execute_all

if __name__ == "__main__":
-    num_queries = 22
-
-    with CodeTimer(name="Overall execution of ALL spark queries", unit="s"):
-        for query_number in range(1, num_queries + 1):
-            submodule = f"q{query_number}"
-            try:
-                eval(f"{submodule}.q()")
-            except Exception as exc:
-                print(
-                    f"Exception occurred while executing PySpark query {query_number}:\n{exc}"
-                )
+    execute_all("pyspark")
79 changes: 35 additions & 44 deletions queries/pyspark/utils.py
@@ -4,17 +4,11 @@

from pyspark.sql import SparkSession

-from queries.common_utils import (
-    check_query_result_pd,
-    on_second_call,
-    run_query_generic,
-)
+from queries.common_utils import check_query_result_pd, run_query_generic
from settings import Settings

if TYPE_CHECKING:
    from pathlib import Path

-    from pyspark.sql import DataFrame as SparkDF
+    from pyspark.sql import DataFrame

settings = Settings()

@@ -31,62 +25,59 @@ def get_or_create_spark() -> SparkSession:
    return spark


-def _read_parquet_ds(path: Path, table_name: str) -> SparkDF:
-    df = get_or_create_spark().read.parquet(str(path))
-    df.createOrReplaceTempView(table_name)
-    return df
+def _read_ds(table_name: str) -> DataFrame:
+    # TODO: Persist data in memory before query
+    if not settings.run.include_io:
+        msg = "cannot run PySpark starting from an in-memory representation"
+        raise RuntimeError(msg)
+
+    path = settings.dataset_base_dir / f"{table_name}.{settings.run.file_type}"

-@on_second_call
-def get_line_item_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "lineitem.parquet", "lineitem")
+    if settings.run.file_type == "parquet":
+        df = get_or_create_spark().read.parquet(str(path))
+    elif settings.run.file_type == "csv":
+        df = get_or_create_spark().read.csv(str(path), header=True, inferSchema=True)
+    else:
+        msg = f"unsupported file type: {settings.run.file_type!r}"
+        raise ValueError(msg)
+
+    df.createOrReplaceTempView(table_name)
+    return df


-@on_second_call
-def get_orders_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "orders.parquet", "orders")
+def get_line_item_ds() -> DataFrame:
+    return _read_ds("lineitem")


-@on_second_call
-def get_customer_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "customer.parquet", "customer")
+def get_orders_ds() -> DataFrame:
+    return _read_ds("orders")


-@on_second_call
-def get_region_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "region.parquet", "region")
+def get_customer_ds() -> DataFrame:
+    return _read_ds("customer")


-@on_second_call
-def get_nation_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "nation.parquet", "nation")
+def get_region_ds() -> DataFrame:
+    return _read_ds("region")


-@on_second_call
-def get_supplier_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "supplier.parquet", "supplier")
+def get_nation_ds() -> DataFrame:
+    return _read_ds("nation")


-@on_second_call
-def get_part_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "part.parquet", "part")
+def get_supplier_ds() -> DataFrame:
+    return _read_ds("supplier")


-@on_second_call
-def get_part_supp_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "partsupp.parquet", "partsupp")
+def get_part_ds() -> DataFrame:
+    return _read_ds("part")


-def drop_temp_view() -> None:
-    spark = get_or_create_spark()
-    [
-        spark.catalog.dropTempView(t.name)
-        for t in spark.catalog.listTables()
-        if t.isTemporary
-    ]
+def get_part_supp_ds() -> DataFrame:
+    return _read_ds("partsupp")


-def run_query(query_number: int, df: SparkDF) -> None:
+def run_query(query_number: int, df: DataFrame) -> None:
    query = df.toPandas
    run_query_generic(
        query, query_number, "pyspark", query_checker=check_query_result_pd
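Unlike the Parquet reader, Spark's CSV reader has no embedded metadata to lean on: header=True takes column names from the first row, and inferSchema=True triggers an extra pass over the file to guess column types. A standalone sketch with a hypothetical path:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    # inferSchema costs an additional scan but recovers dates and decimals
    # without a hand-written schema
    df = spark.read.csv("data/tables/lineitem.csv", header=True, inferSchema=True)
    df.createOrReplaceTempView("lineitem")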
3 changes: 2 additions & 1 deletion scripts/prepare_data.py
@@ -101,6 +101,7 @@
    lf = lf.select(columns)

    lf.sink_parquet(settings.dataset_base_dir / f"{table_name}.parquet")
+    lf.sink_csv(settings.dataset_base_dir / f"{table_name}.csv")

    # IPC currently not relevant
-    # lf.sink_ipc(base_path / f"{table_name}.ipc")
+    # lf.sink_ipc(base_path / f"{table_name}.feather")
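Like sink_parquet, sink_csv streams the lazy query to disk without collecting the whole table in memory, so data preparation now writes every table in both formats. Minimal sketch with a stand-in frame:

    import polars as pl

    lf = pl.LazyFrame({"o_orderdate": ["1994-01-01"]})  # stand-in for a TPC-H table
    lf.sink_parquet("orders.parquet")
    lf.sink_csv("orders.csv")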
5 changes: 4 additions & 1 deletion settings.py
@@ -1,8 +1,11 @@
from pathlib import Path
+from typing import Literal, TypeAlias

from pydantic import computed_field
from pydantic_settings import BaseSettings, SettingsConfigDict

+FileType: TypeAlias = Literal["parquet", "feather", "csv"]
+

class Paths(BaseSettings):
    answers: Path = Path("data/answers")
@@ -20,7 +23,7 @@ class Paths(BaseSettings):

class Run(BaseSettings):
    include_io: bool = False
-    file_type: str = "parquet"
+    file_type: FileType = "parquet"

    log_timings: bool = False
    show_results: bool = False
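Narrowing file_type from str to the FileType literal moves validation to settings-load time: an unsupported format now fails fast in pydantic rather than surfacing later as a ValueError inside a reader. Roughly, assuming pydantic v2 semantics:

    from typing import Literal, TypeAlias

    from pydantic_settings import BaseSettings

    FileType: TypeAlias = Literal["parquet", "feather", "csv"]

    class Run(BaseSettings):
        file_type: FileType = "parquet"

    Run(file_type="csv")   # fine
    Run(file_type="json")  # raises pydantic.ValidationError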
