Skip to content

Commit

Permalink
fix(grouping): collect multiple grouping keys in single "groupings"
Browse files Browse the repository at this point in the history
Also updated our selection routine to properly handle the flattened
index count for projecting out of these groupbys.
  • Loading branch information
gforsyth committed Feb 6, 2024
1 parent d2c96a3 commit 30d35e3
Show file tree
Hide file tree
Showing 11 changed files with 57 additions and 88 deletions.
27 changes: 24 additions & 3 deletions ibis_substrait/compiler/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,27 @@ def selection(
for relname, rel in rels:
if relname == "aggregate" and rel.measures:
mapping_counter = itertools.count(
len(rel.measures) + len(rel.groupings)
len(rel.measures)
# Individual groups can have multiple grouping expressions
# and we need to count all of these to properly index projections
# e.g. for the following query
# SELECT
# t0.b,
# t0.sum
# FROM (
# SELECT
# t1.a AS a,
# t1.b AS b,
# SUM(t1.c) AS sum
# FROM t AS t1
# GROUP BY
# t1.a,
# t1.b
# ) AS t0
#
# the two grouping keys (t1.a, t1.b) will be grouping
# expressions in the first (and only) group.
+ sum(len(group.grouping_expressions) for group in rel.groupings)
)
break
elif output_mapping := rel.common.emit.output_mapping:
Expand Down Expand Up @@ -1092,9 +1112,10 @@ def aggregation(
input=input,
groupings=[
stalg.AggregateRel.Grouping(
grouping_expressions=[translate(by, compiler=compiler, **kwargs)]
grouping_expressions=[
translate(by, compiler=compiler, **kwargs) for by in op.by
]
)
for by in op.by
],
measures=[
stalg.AggregateRel.Measure(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,7 @@
},
"rootReference": {}
}
}
]
},
{
"groupingExpressions": [
},
{
"selection": {
"directReference": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -567,11 +567,7 @@
},
"rootReference": {}
}
}
]
},
{
"groupingExpressions": [
},
{
"selection": {
"directReference": {
Expand All @@ -581,11 +577,7 @@
},
"rootReference": {}
}
}
]
},
{
"groupingExpressions": [
},
{
"selection": {
"directReference": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,9 @@
}
}
},
"groupings": [
{}
],
"measures": [
{
"measure": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1128,11 +1128,7 @@
},
"rootReference": {}
}
}
]
},
{
"groupingExpressions": [
},
{
"selection": {
"directReference": {
Expand All @@ -1142,11 +1138,7 @@
},
"rootReference": {}
}
}
]
},
{
"groupingExpressions": [
},
{
"selection": {
"directReference": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1037,11 +1037,7 @@
},
"rootReference": {}
}
}
]
},
{
"groupingExpressions": [
},
{
"selection": {
"directReference": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1025,11 +1025,7 @@
},
"rootReference": {}
}
}
]
},
{
"groupingExpressions": [
},
{
"selection": {
"directReference": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -651,11 +651,7 @@
},
"rootReference": {}
}
}
]
},
{
"groupingExpressions": [
},
{
"selection": {
"directReference": {
Expand All @@ -665,11 +661,7 @@
},
"rootReference": {}
}
}
]
},
{
"groupingExpressions": [
},
{
"selection": {
"directReference": {
Expand All @@ -679,11 +671,7 @@
},
"rootReference": {}
}
}
]
},
{
"groupingExpressions": [
},
{
"selection": {
"directReference": {
Expand All @@ -693,11 +681,7 @@
},
"rootReference": {}
}
}
]
},
{
"groupingExpressions": [
},
{
"selection": {
"directReference": {
Expand All @@ -707,11 +691,7 @@
},
"rootReference": {}
}
}
]
},
{
"groupingExpressions": [
},
{
"selection": {
"directReference": {
Expand All @@ -721,11 +701,7 @@
},
"rootReference": {}
}
}
]
},
{
"groupingExpressions": [
},
{
"selection": {
"directReference": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -427,23 +427,15 @@
},
"rootReference": {}
}
}
]
},
{
"groupingExpressions": [
},
{
"selection": {
"directReference": {
"structField": {}
},
"rootReference": {}
}
}
]
},
{
"groupingExpressions": [
},
{
"selection": {
"directReference": {
Expand All @@ -453,11 +445,7 @@
},
"rootReference": {}
}
}
]
},
{
"groupingExpressions": [
},
{
"selection": {
"directReference": {
Expand All @@ -467,11 +455,7 @@
},
"rootReference": {}
}
}
]
},
{
"groupingExpressions": [
},
{
"selection": {
"directReference": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1322,6 +1322,9 @@
}
}
},
"groupings": [
{}
],
"measures": [
{
"measure": {
Expand Down
10 changes: 10 additions & 0 deletions ibis_substrait/tests/compiler/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,3 +525,13 @@ def test_aggregate_filter_select_output_mapping(compiler):
def test_filter_over_subquery(compiler):
t = ibis.table([("a", "int")], name="t").filter(_.a > _.a.mean())
translate(t, compiler=compiler)


def test_groupby_multiple_keys(compiler):
t = ibis.table(name="t", schema=(("a", "int"), ("b", "int")))
expr = t.group_by(["a", "b"]).agg()
plan = translate(expr, compiler=compiler)

# There should be one grouping with two separate expressions inside
assert len(plan.aggregate.groupings) == 1
assert len(plan.aggregate.groupings[0].grouping_expressions) == 2

0 comments on commit 30d35e3

Please sign in to comment.