diff --git a/ibis_substrait/compiler/translate.py b/ibis_substrait/compiler/translate.py index ac620d54..ad4b6b0d 100644 --- a/ibis_substrait/compiler/translate.py +++ b/ibis_substrait/compiler/translate.py @@ -524,6 +524,9 @@ def value_op( ) +_bounds_enum = {"rows": 1, "range": 2} + + @translate.register(ops.WindowOp) # type: ignore def window_op( op: ops.WindowOp, # type: ignore @@ -537,6 +540,9 @@ def window_op( end = op.end func = op.func func_args = op.func.args + how = op.how + + bounds_type = _bounds_enum[how] lower_bound, upper_bound = _translate_window_bounds(start, end) @@ -558,6 +564,7 @@ def window_op( ], lower_bound=lower_bound, upper_bound=upper_bound, + bounds_type=bounds_type, ) ) diff --git a/ibis_substrait/tests/compiler/test_compiler.py b/ibis_substrait/tests/compiler/test_compiler.py index df44bc5e..0ea38243 100644 --- a/ibis_substrait/tests/compiler/test_compiler.py +++ b/ibis_substrait/tests/compiler/test_compiler.py @@ -584,3 +584,35 @@ def test_join_chain_indexing_in_group_by(compiler): .selection.direct_reference.struct_field.field == 7 ) + + +_window_hows = { + "unspecified": "BOUNDS_TYPE_UNSPECIFIED", + "range": "BOUNDS_TYPE_RANGE", + "rows": "BOUNDS_TYPE_ROWS", +} + + +@pytest.mark.parametrize( + "bounds", + [ + (-4, 2), + (1, 5), + (None, None), + (2, 4), + ], +) +@pytest.mark.parametrize("how", ["range", "rows"]) +def test_aggregation_window_how(t, compiler, bounds, how): + how_arg = {how: bounds} + expr = t.projection( + [t.full_name.length().mean().over(ibis.window(group_by="age", **how_arg))] + ) + result = translate(expr, compiler=compiler) + + bounds_type_int = result.project.expressions[0].window_function.bounds_type + + assert ( + stalg.Expression.WindowFunction.BoundsType.Name(bounds_type_int) + == _window_hows[how] + )