Skip to content

Commit

Permalink
feat: remove hand-optimizations from queries
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Apr 2, 2024
1 parent 700a440 commit 72044a6
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 41 deletions.
53 changes: 18 additions & 35 deletions queries/pandas/q1.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,51 +18,34 @@ def query() -> pd.DataFrame:
nonlocal lineitem
lineitem = lineitem()

lineitem_filtered = lineitem.loc[
:,
[
"l_quantity",
"l_extendedprice",
"l_discount",
"l_tax",
"l_returnflag",
"l_linestatus",
"l_shipdate",
"l_orderkey",
],
]
sel = lineitem_filtered.l_shipdate <= VAR1
lineitem_filtered = lineitem_filtered[sel]
lineitem_filtered["sum_qty"] = lineitem_filtered.l_quantity
lineitem_filtered["sum_base_price"] = lineitem_filtered.l_extendedprice
lineitem_filtered["avg_qty"] = lineitem_filtered.l_quantity
lineitem_filtered["avg_price"] = lineitem_filtered.l_extendedprice
lineitem_filtered["sum_disc_price"] = lineitem_filtered.l_extendedprice * (
sel = lineitem.l_shipdate <= VAR1
lineitem_filtered = lineitem[sel]

# This is lenient towards pandas as normally an optimizer should decide
# that this could be computed before the groupby aggregation.
# Other implementations don't enjoy this benefit.
lineitem_filtered["disc_price"] = lineitem_filtered.l_extendedprice * (
1 - lineitem_filtered.l_discount
)
lineitem_filtered["sum_charge"] = (
lineitem_filtered["charge"] = (
lineitem_filtered.l_extendedprice
* (1 - lineitem_filtered.l_discount)
* (1 + lineitem_filtered.l_tax)
)
lineitem_filtered["avg_disc"] = lineitem_filtered.l_discount
lineitem_filtered["count_order"] = lineitem_filtered.l_orderkey
gb = lineitem_filtered.groupby(["l_returnflag", "l_linestatus"])
gb = lineitem_filtered.groupby(["l_returnflag", "l_linestatus"], as_index=False)

total = gb.agg(
{
"sum_qty": "sum",
"sum_base_price": "sum",
"sum_disc_price": "sum",
"sum_charge": "sum",
"avg_qty": "mean",
"avg_price": "mean",
"avg_disc": "mean",
"count_order": "count",
}
sum_qty=pd.NamedAgg(column="l_quantity", aggfunc="sum"),
sum_base_price=pd.NamedAgg(column="l_extendedprice", aggfunc="sum"),
sum_disc_price=pd.NamedAgg(column="disc_price", aggfunc="sum"),
sum_charge=pd.NamedAgg(column="charge", aggfunc="sum"),
avg_qty=pd.NamedAgg(column="l_quantity", aggfunc="mean"),
avg_price=pd.NamedAgg(column="l_extendedprice", aggfunc="mean"),
avg_disc=pd.NamedAgg(column="l_discount", aggfunc="mean"),
count_order=pd.NamedAgg(column="l_orderkey", aggfunc="size"),
)

result_df = total.reset_index().sort_values(["l_returnflag", "l_linestatus"])
result_df = total.sort_values(["l_returnflag", "l_linestatus"])

return result_df # type: ignore[no-any-return]

Expand Down
13 changes: 8 additions & 5 deletions queries/pandas/q5.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,24 @@ def query() -> pd.DataFrame:
supplier_ds = supplier_ds()

rsel = region_ds.r_name == "ASIA"
osel = (orders_ds.o_orderdate >= date1) & (orders_ds.o_orderdate < date2)
forders = orders_ds[osel]
fregion = region_ds[rsel]
jn1 = fregion.merge(nation_ds, left_on="r_regionkey", right_on="n_regionkey")
jn1 = region_ds.merge(nation_ds, left_on="r_regionkey", right_on="n_regionkey")
jn2 = jn1.merge(customer_ds, left_on="n_nationkey", right_on="c_nationkey")
jn3 = jn2.merge(forders, left_on="c_custkey", right_on="o_custkey")
jn3 = jn2.merge(orders_ds, left_on="c_custkey", right_on="o_custkey")
jn4 = jn3.merge(line_item_ds, left_on="o_orderkey", right_on="l_orderkey")
jn5 = supplier_ds.merge(
jn4,
left_on=["s_suppkey", "s_nationkey"],
right_on=["l_suppkey", "n_nationkey"],
)
jn5["revenue"] = jn5.l_extendedprice * (1.0 - jn5.l_discount)
jn5 = jn5[
(jn5.o_orderdate >= date1)
& (jn5.o_orderdate < date2)
& (jn5.r_name == rsel)
]
gb = jn5.groupby("n_name", as_index=False)["revenue"].sum()
result_df = gb.sort_values("revenue", ascending=False)

return result_df # type: ignore[no-any-return]

utils.run_query(Q_NUM, query)
Expand Down
2 changes: 1 addition & 1 deletion queries/polars/q2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def q() -> None:
.filter(pl.col("p_size") == var_1)
.filter(pl.col("p_type").str.ends_with(var_2))
.filter(pl.col("r_name") == var_3)
).cache()
)

final_cols = [
"s_acctbal",
Expand Down

0 comments on commit 72044a6

Please sign in to comment.