Skip to content

Commit

Permalink
fix: cast floor and ceil to integer types to workaround acero
Browse files Browse the repository at this point in the history
Acero's return type for floor and ceil is float64, so we cast
to ensure that the return type of the data matches the ibis expression
type.
  • Loading branch information
gforsyth committed Mar 10, 2023
1 parent 6b2138b commit 0d52d4d
Showing 1 changed file with 39 additions and 3 deletions.
42 changes: 39 additions & 3 deletions ibis_substrait/compiler/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1287,7 +1287,7 @@ def _exists_subquery(
predicates = [pred.op().to_expr() for pred in op.predicates]
tuples = stalg.Rel(
filter=stalg.FilterRel(
input=translate(op.foreign_table, compiler),
input=translate(op.foreign_table, compiler=compiler),
condition=translate(
functools.reduce(operator.and_, predicates),
compiler,
Expand Down Expand Up @@ -1318,10 +1318,10 @@ def _not_exists_subquery(
predicates = [pred.op().to_expr() for pred in op.predicates]
tuples = stalg.Rel(
filter=stalg.FilterRel(
input=translate(op.foreign_table, compiler),
input=translate(op.foreign_table, compiler=compiler),
condition=translate(
functools.reduce(operator.and_, predicates),
compiler,
compiler=compiler,
**kwargs,
),
)
Expand All @@ -1345,3 +1345,39 @@ def _not_exists_subquery(
],
)
)


@translate.register(ops.Floor)
@translate.register(ops.Ceil)
def _floor_ceil_cast(
op: ops.Floor,
expr: ir.Column | None = None,
*,
compiler: SubstraitCompiler | None = None,
**kwargs: Any,
) -> stalg.Expression:
if compiler is None:
raise ValueError
output_type = translate(op.output_dtype)
input = stalg.Expression(
scalar_function=stalg.Expression.ScalarFunction(
function_reference=compiler.function_id(
expr=expr if expr is not None else op.to_expr()
),
output_type=output_type,
arguments=[
stalg.FunctionArgument(
value=translate(arg, compiler=compiler, **kwargs)
)
for arg in op.args
if isinstance(arg, (ir.Expr, ops.Value))
],
)
)
return stalg.Expression(
cast=stalg.Expression.Cast(
type=output_type,
input=input,
failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_THROW_EXCEPTION,
)
)

0 comments on commit 0d52d4d

Please sign in to comment.