Skip to content

Commit

Permalink
fix: properly support Options for Scalar/Aggregate/Window functions
Browse files Browse the repository at this point in the history
https://substrait.io/expressions/scalar_functions/#options
Options were supported in the Java classes but in most cases dropped when converting into protobuf

This
- fixes the protobuf conversion,
- adds a builder for FunctionOption
- adds a ExpressionCreator functions with options
- does some refactoring
  • Loading branch information
Blizzara committed Jul 2, 2024
1 parent 3e553ee commit 984583b
Show file tree
Hide file tree
Showing 13 changed files with 171 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import io.substrait.extension.SimpleExtension;
import io.substrait.type.Type;
import java.util.List;
import java.util.Map;
import org.immutables.value.Value;

@Value.Immutable
Expand All @@ -12,7 +11,7 @@ public abstract class AggregateFunctionInvocation {

public abstract List<FunctionArg> arguments();

public abstract Map<String, FunctionOption> options();
public abstract List<FunctionOption> options();

public abstract Expression.AggregationPhase aggregationPhase();

Expand Down
4 changes: 2 additions & 2 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ abstract static class ScalarFunctionInvocation implements Expression {

public abstract List<FunctionArg> arguments();

public abstract Map<String, FunctionOption> options();
public abstract List<FunctionOption> options();

public abstract Type outputType();

Expand All @@ -620,7 +620,7 @@ abstract class WindowFunctionInvocation implements Expression {

public abstract List<FunctionArg> arguments();

public abstract Map<String, FunctionOption> options();
public abstract List<FunctionOption> options();

public abstract AggregationPhase aggregationPhase();

Expand Down
113 changes: 88 additions & 25 deletions core/src/main/java/io/substrait/expression/ExpressionCreator.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.UUID;
Expand Down Expand Up @@ -284,20 +285,25 @@ public static Expression.ScalarFunctionInvocation scalarFunction(
SimpleExtension.ScalarFunctionVariant declaration,
Type outputType,
FunctionArg... arguments) {
return Expression.ScalarFunctionInvocation.builder()
.declaration(declaration)
.outputType(outputType)
.addArguments(arguments)
.build();
return scalarFunction(declaration, outputType, Arrays.asList(), Arrays.asList(arguments));
}

public static Expression.ScalarFunctionInvocation scalarFunction(
SimpleExtension.ScalarFunctionVariant declaration,
Type outputType,
Iterable<? extends FunctionArg> arguments) {
return scalarFunction(declaration, outputType, Arrays.asList(), arguments);
}

public static Expression.ScalarFunctionInvocation scalarFunction(
SimpleExtension.ScalarFunctionVariant declaration,
Type outputType,
List<? extends FunctionOption> options,
Iterable<? extends FunctionArg> arguments) {
return Expression.ScalarFunctionInvocation.builder()
.declaration(declaration)
.outputType(outputType)
.options(options)
.addAllArguments(arguments)
.build();
}
Expand All @@ -309,12 +315,25 @@ public static AggregateFunctionInvocation aggregateFunction(
List<Expression.SortField> sort,
Expression.AggregationInvocation invocation,
Iterable<? extends FunctionArg> arguments) {
return aggregateFunction(
declaration, outputType, phase, sort, invocation, Arrays.asList(), arguments);
}

public static AggregateFunctionInvocation aggregateFunction(
SimpleExtension.AggregateFunctionVariant declaration,
Type outputType,
Expression.AggregationPhase phase,
List<Expression.SortField> sort,
Expression.AggregationInvocation invocation,
List<? extends FunctionOption> options,
Iterable<? extends FunctionArg> arguments) {
return AggregateFunctionInvocation.builder()
.declaration(declaration)
.outputType(outputType)
.aggregationPhase(phase)
.sort(sort)
.invocation(invocation)
.addAllOptions(options)
.addAllArguments(arguments)
.build();
}
Expand All @@ -326,14 +345,8 @@ public static AggregateFunctionInvocation aggregateFunction(
List<Expression.SortField> sort,
Expression.AggregationInvocation invocation,
FunctionArg... arguments) {
return AggregateFunctionInvocation.builder()
.declaration(declaration)
.outputType(outputType)
.aggregationPhase(phase)
.sort(sort)
.invocation(invocation)
.addArguments(arguments)
.build();
return aggregateFunction(
declaration, outputType, phase, sort, invocation, Arrays.asList(arguments));
}

public static Expression.WindowFunctionInvocation windowFunction(
Expand All @@ -347,6 +360,32 @@ public static Expression.WindowFunctionInvocation windowFunction(
WindowBound lowerBound,
WindowBound upperBound,
Iterable<? extends FunctionArg> arguments) {
return windowFunction(
declaration,
outputType,
phase,
sort,
invocation,
partitionBy,
boundsType,
lowerBound,
upperBound,
Arrays.asList(),
arguments);
}

public static Expression.WindowFunctionInvocation windowFunction(
SimpleExtension.WindowFunctionVariant declaration,
Type outputType,
Expression.AggregationPhase phase,
List<Expression.SortField> sort,
Expression.AggregationInvocation invocation,
List<Expression> partitionBy,
Expression.WindowBoundsType boundsType,
WindowBound lowerBound,
WindowBound upperBound,
List<? extends FunctionOption> options,
Iterable<? extends FunctionArg> arguments) {
return Expression.WindowFunctionInvocation.builder()
.declaration(declaration)
.outputType(outputType)
Expand All @@ -357,6 +396,7 @@ public static Expression.WindowFunctionInvocation windowFunction(
.lowerBound(lowerBound)
.upperBound(upperBound)
.invocation(invocation)
.addAllOptions(options)
.addAllArguments(arguments)
.build();
}
Expand All @@ -370,6 +410,28 @@ public static ConsistentPartitionWindow.WindowRelFunctionInvocation windowRelFun
WindowBound lowerBound,
WindowBound upperBound,
Iterable<? extends FunctionArg> arguments) {
return windowRelFunction(
declaration,
outputType,
phase,
invocation,
boundsType,
lowerBound,
upperBound,
Arrays.asList(),
arguments);
}

public static ConsistentPartitionWindow.WindowRelFunctionInvocation windowRelFunction(
SimpleExtension.WindowFunctionVariant declaration,
Type outputType,
Expression.AggregationPhase phase,
Expression.AggregationInvocation invocation,
Expression.WindowBoundsType boundsType,
WindowBound lowerBound,
WindowBound upperBound,
List<? extends FunctionOption> options,
Iterable<? extends FunctionArg> arguments) {
return ConsistentPartitionWindow.WindowRelFunctionInvocation.builder()
.declaration(declaration)
.outputType(outputType)
Expand All @@ -379,6 +441,7 @@ public static ConsistentPartitionWindow.WindowRelFunctionInvocation windowRelFun
.upperBound(upperBound)
.invocation(invocation)
.addAllArguments(arguments)
.addAllOptions(options)
.build();
}

Expand All @@ -393,18 +456,18 @@ public static Expression.WindowFunctionInvocation windowFunction(
WindowBound lowerBound,
WindowBound upperBound,
FunctionArg... arguments) {
return Expression.WindowFunctionInvocation.builder()
.declaration(declaration)
.outputType(outputType)
.aggregationPhase(phase)
.sort(sort)
.invocation(invocation)
.partitionBy(partitionBy)
.boundsType(boundsType)
.lowerBound(lowerBound)
.upperBound(upperBound)
.addArguments(arguments)
.build();
return windowFunction(
declaration,
outputType,
phase,
sort,
invocation,
partitionBy,
boundsType,
lowerBound,
upperBound,
Arrays.asList(),
Arrays.asList(arguments));
}

public static Expression cast(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,8 @@ public abstract class FunctionOption {
public abstract String getName();

public abstract List<String> values();

public static ImmutableFunctionOption.Builder builder() {
return ImmutableFunctionOption.builder();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import io.substrait.extension.SimpleExtension;
import io.substrait.proto.Expression;
import io.substrait.proto.FunctionArgument;
import io.substrait.proto.FunctionOption;
import io.substrait.proto.Rel;
import io.substrait.proto.SortField;
import io.substrait.proto.Type;
Expand Down Expand Up @@ -314,10 +315,21 @@ public Expression visit(io.substrait.expression.Expression.ScalarFunctionInvocat
.addAllArguments(
expr.arguments().stream()
.map(a -> a.accept(expr.declaration(), 0, argVisitor))
.collect(java.util.stream.Collectors.toList()))
.addAllOptions(
expr.options().stream()
.map(ExpressionProtoConverter::from)
.collect(java.util.stream.Collectors.toList())))
.build();
}

public static FunctionOption from(io.substrait.expression.FunctionOption option) {
return FunctionOption.newBuilder()
.setName(option.getName())
.addAllPreference(option.values())
.build();
}

@Override
public Expression visit(io.substrait.expression.Expression.Cast expr) {
return Expression.newBuilder()
Expand Down Expand Up @@ -495,7 +507,11 @@ public Expression visit(io.substrait.expression.Expression.WindowFunctionInvocat
.addAllPartitions(partitionExprs)
.setBoundsType(expr.boundsType().toProto())
.setLowerBound(lowerBound)
.setUpperBound(upperBound))
.setUpperBound(upperBound)
.addAllOptions(
expr.options().stream()
.map(ExpressionProtoConverter::from)
.collect(java.util.stream.Collectors.toList())))
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import io.substrait.expression.FunctionArg;
import io.substrait.expression.FunctionOption;
import io.substrait.expression.ImmutableExpression;
import io.substrait.expression.ImmutableFunctionOption;
import io.substrait.expression.WindowBound;
import io.substrait.extension.ExtensionLookup;
import io.substrait.extension.SimpleExtension;
Expand Down Expand Up @@ -117,10 +116,15 @@ public Expression from(io.substrait.proto.Expression expr) {
IntStream.range(0, scalarFunction.getArgumentsCount())
.mapToObj(i -> pF.convert(declaration, i, scalarFunction.getArguments(i)))
.collect(java.util.stream.Collectors.toList());
var options =
scalarFunction.getOptionsList().stream()
.map(ProtoExpressionConverter::fromFunctionOption)
.collect(Collectors.toList());
yield ImmutableExpression.ScalarFunctionInvocation.builder()
.addAllArguments(args)
.declaration(declaration)
.outputType(protoTypeConverter.from(scalarFunction.getOutputType()))
.options(options)
.build();
}
case WINDOW_FUNCTION -> fromWindowFunction(expr.getWindowFunction());
Expand Down Expand Up @@ -241,8 +245,8 @@ public Expression.WindowFunctionInvocation fromWindowFunction(
.collect(Collectors.toList());
var options =
windowFunction.getOptionsList().stream()
.map(this::fromFunctionOption)
.collect(Collectors.toMap(FunctionOption::getName, Function.identity()));
.map(ProtoExpressionConverter::fromFunctionOption)
.collect(Collectors.toList());

WindowBound lowerBound = toWindowBound(windowFunction.getLowerBound());
WindowBound upperBound = toWindowBound(windowFunction.getUpperBound());
Expand Down Expand Up @@ -276,8 +280,8 @@ public ConsistentPartitionWindow.WindowRelFunctionInvocation fromWindowRelFuncti
windowRelFunction::getArguments);
var options =
windowRelFunction.getOptionsList().stream()
.map(this::fromFunctionOption)
.collect(Collectors.toMap(FunctionOption::getName, Function.identity()));
.map(ProtoExpressionConverter::fromFunctionOption)
.collect(Collectors.toList());

WindowBound lowerBound = toWindowBound(windowRelFunction.getLowerBound());
WindowBound upperBound = toWindowBound(windowRelFunction.getUpperBound());
Expand Down Expand Up @@ -393,10 +397,7 @@ public Expression.SortField fromSortField(SortField s) {
.build();
}

public FunctionOption fromFunctionOption(io.substrait.proto.FunctionOption o) {
return ImmutableFunctionOption.builder()
.name(o.getName())
.addAllValues(o.getPreferenceList())
.build();
public static FunctionOption fromFunctionOption(io.substrait.proto.FunctionOption o) {
return FunctionOption.builder().name(o.getName()).addAllValues(o.getPreferenceList()).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,12 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp
break;
case MEASURE:
io.substrait.relation.Aggregate.Measure measure =
new ProtoAggregateFunctionConverter(
functionLookup, extensionCollection, protoExpressionConverter)
.from(expressionReference.getMeasure());
io.substrait.relation.Aggregate.Measure.builder()
.function(
new ProtoAggregateFunctionConverter(
functionLookup, extensionCollection, protoExpressionConverter)
.from(expressionReference.getMeasure()))
.build();
ImmutableAggregateFunctionReference buildMeasure =
ImmutableAggregateFunctionReference.builder()
.measure(measure)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import org.immutables.value.Value;

Expand Down Expand Up @@ -49,7 +48,7 @@ public abstract static class WindowRelFunctionInvocation {

public abstract List<FunctionArg> arguments();

public abstract Map<String, FunctionOption> options();
public abstract List<FunctionOption> options();

public abstract Type outputType();

Expand Down
Loading

0 comments on commit 984583b

Please sign in to comment.