From 0d52d4d65bdf5094e3c82ed3e1ff3948c16ab53f Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Wed, 2 Nov 2022 13:20:02 -0400 Subject: [PATCH] fix: cast floor and ceil to integer types to workaround acero Acero's return type for floor and ceil is float64, so we cast to ensure that the return type of the data matches the ibis expression type. --- ibis_substrait/compiler/translate.py | 42 ++++++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/ibis_substrait/compiler/translate.py b/ibis_substrait/compiler/translate.py index 8eb8745a..8ab04015 100644 --- a/ibis_substrait/compiler/translate.py +++ b/ibis_substrait/compiler/translate.py @@ -1287,7 +1287,7 @@ def _exists_subquery( predicates = [pred.op().to_expr() for pred in op.predicates] tuples = stalg.Rel( filter=stalg.FilterRel( - input=translate(op.foreign_table, compiler), + input=translate(op.foreign_table, compiler=compiler), condition=translate( functools.reduce(operator.and_, predicates), compiler, @@ -1318,10 +1318,10 @@ def _not_exists_subquery( predicates = [pred.op().to_expr() for pred in op.predicates] tuples = stalg.Rel( filter=stalg.FilterRel( - input=translate(op.foreign_table, compiler), + input=translate(op.foreign_table, compiler=compiler), condition=translate( functools.reduce(operator.and_, predicates), - compiler, + compiler=compiler, **kwargs, ), ) @@ -1345,3 +1345,39 @@ def _not_exists_subquery( ], ) ) + + +@translate.register(ops.Floor) +@translate.register(ops.Ceil) +def _floor_ceil_cast( + op: ops.Floor, + expr: ir.Column | None = None, + *, + compiler: SubstraitCompiler | None = None, + **kwargs: Any, +) -> stalg.Expression: + if compiler is None: + raise ValueError + output_type = translate(op.output_dtype) + input = stalg.Expression( + scalar_function=stalg.Expression.ScalarFunction( + function_reference=compiler.function_id( + expr=expr if expr is not None else op.to_expr() + ), + output_type=output_type, + arguments=[ + stalg.FunctionArgument( + value=translate(arg, compiler=compiler, **kwargs) + ) + for arg in op.args + if isinstance(arg, (ir.Expr, ops.Value)) + ], + ) + ) + return stalg.Expression( + cast=stalg.Expression.Cast( + type=output_type, + input=input, + failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_THROW_EXCEPTION, + ) + )