-
Notifications
You must be signed in to change notification settings - Fork 72
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: properly support Options for Scalar/Aggregate/Window functions #278
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,4 +9,8 @@ public abstract class FunctionOption { | |
public abstract String getName(); | ||
|
||
public abstract List<String> values(); | ||
|
||
public static ImmutableFunctionOption.Builder builder() { | ||
return ImmutableFunctionOption.builder(); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good addition ✨ |
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,6 @@ | |
import io.substrait.expression.FunctionArg; | ||
import io.substrait.expression.FunctionOption; | ||
import io.substrait.expression.ImmutableExpression; | ||
import io.substrait.expression.ImmutableFunctionOption; | ||
import io.substrait.expression.WindowBound; | ||
import io.substrait.extension.ExtensionLookup; | ||
import io.substrait.extension.SimpleExtension; | ||
|
@@ -117,10 +116,15 @@ public Expression from(io.substrait.proto.Expression expr) { | |
IntStream.range(0, scalarFunction.getArgumentsCount()) | ||
.mapToObj(i -> pF.convert(declaration, i, scalarFunction.getArguments(i))) | ||
.collect(java.util.stream.Collectors.toList()); | ||
var options = | ||
scalarFunction.getOptionsList().stream() | ||
.map(ProtoExpressionConverter::fromFunctionOption) | ||
.collect(Collectors.toList()); | ||
yield ImmutableExpression.ScalarFunctionInvocation.builder() | ||
.addAllArguments(args) | ||
.declaration(declaration) | ||
.outputType(protoTypeConverter.from(scalarFunction.getOutputType())) | ||
.options(options) | ||
.build(); | ||
} | ||
case WINDOW_FUNCTION -> fromWindowFunction(expr.getWindowFunction()); | ||
|
@@ -241,8 +245,8 @@ public Expression.WindowFunctionInvocation fromWindowFunction( | |
.collect(Collectors.toList()); | ||
var options = | ||
windowFunction.getOptionsList().stream() | ||
.map(this::fromFunctionOption) | ||
.collect(Collectors.toMap(FunctionOption::getName, Function.identity())); | ||
.map(ProtoExpressionConverter::fromFunctionOption) | ||
.collect(Collectors.toList()); | ||
|
||
WindowBound lowerBound = toWindowBound(windowFunction.getLowerBound()); | ||
WindowBound upperBound = toWindowBound(windowFunction.getUpperBound()); | ||
|
@@ -276,8 +280,8 @@ public ConsistentPartitionWindow.WindowRelFunctionInvocation fromWindowRelFuncti | |
windowRelFunction::getArguments); | ||
var options = | ||
windowRelFunction.getOptionsList().stream() | ||
.map(this::fromFunctionOption) | ||
.collect(Collectors.toMap(FunctionOption::getName, Function.identity())); | ||
.map(ProtoExpressionConverter::fromFunctionOption) | ||
.collect(Collectors.toList()); | ||
|
||
WindowBound lowerBound = toWindowBound(windowRelFunction.getLowerBound()); | ||
WindowBound upperBound = toWindowBound(windowRelFunction.getUpperBound()); | ||
|
@@ -393,10 +397,7 @@ public Expression.SortField fromSortField(SortField s) { | |
.build(); | ||
} | ||
|
||
public FunctionOption fromFunctionOption(io.substrait.proto.FunctionOption o) { | ||
return ImmutableFunctionOption.builder() | ||
.name(o.getName()) | ||
.addAllValues(o.getPreferenceList()) | ||
.build(); | ||
public static FunctionOption fromFunctionOption(io.substrait.proto.FunctionOption o) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It is somewhat awkward to make this static so it can be re-used in the |
||
return FunctionOption.builder().name(o.getName()).addAllValues(o.getPreferenceList()).build(); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,12 +3,14 @@ | |
import io.substrait.expression.AggregateFunctionInvocation; | ||
import io.substrait.expression.Expression; | ||
import io.substrait.expression.FunctionArg; | ||
import io.substrait.expression.FunctionOption; | ||
import io.substrait.expression.proto.ProtoExpressionConverter; | ||
import io.substrait.extension.ExtensionLookup; | ||
import io.substrait.extension.SimpleExtension; | ||
import io.substrait.type.proto.ProtoTypeConverter; | ||
import java.io.IOException; | ||
import java.util.List; | ||
import java.util.stream.Collectors; | ||
import java.util.stream.IntStream; | ||
|
||
/** | ||
|
@@ -37,7 +39,7 @@ public ProtoAggregateFunctionConverter( | |
this.protoExpressionConverter = protoExpressionConverter; | ||
} | ||
|
||
public io.substrait.relation.Aggregate.Measure from( | ||
public io.substrait.expression.AggregateFunctionInvocation from( | ||
io.substrait.proto.AggregateFunction measure) { | ||
FunctionArg.ProtoFrom protoFrom = | ||
new FunctionArg.ProtoFrom(protoExpressionConverter, protoTypeConverter); | ||
|
@@ -47,15 +49,17 @@ public io.substrait.relation.Aggregate.Measure from( | |
IntStream.range(0, measure.getArgumentsCount()) | ||
.mapToObj(i -> protoFrom.convert(aggregateFunction, i, measure.getArguments(i))) | ||
.collect(java.util.stream.Collectors.toList()); | ||
return Aggregate.Measure.builder() | ||
.function( | ||
AggregateFunctionInvocation.builder() | ||
.arguments(functionArgs) | ||
.declaration(aggregateFunction) | ||
.outputType(protoTypeConverter.from(measure.getOutputType())) | ||
.aggregationPhase(Expression.AggregationPhase.fromProto(measure.getPhase())) | ||
.invocation(Expression.AggregationInvocation.fromProto(measure.getInvocation())) | ||
.build()) | ||
List<FunctionOption> options = | ||
measure.getOptionsList().stream() | ||
.map(ProtoExpressionConverter::fromFunctionOption) | ||
.collect(Collectors.toList()); | ||
return AggregateFunctionInvocation.builder() | ||
.arguments(functionArgs) | ||
.declaration(aggregateFunction) | ||
.outputType(protoTypeConverter.from(measure.getOutputType())) | ||
.aggregationPhase(Expression.AggregationPhase.fromProto(measure.getPhase())) | ||
.invocation(Expression.AggregationInvocation.fromProto(measure.getInvocation())) | ||
.options(options) | ||
.build(); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The other advantage of this is that, by forcing callers to wrap this in a Measure themselves, it also makes it easier for them to include the PreMeasure filter when it's available. |
||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seemed weird to me, since the name is already kept inside the FunctionOption, so this just duplicates the source of truth. That said, I don't mind reverting this if it's preferable to keep the map.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I imagine this is/was intended to allow for easy lookup of options given a name, so that consumers could look up the relevant options for a given function. However, given that:
https://github.com/substrait-io/substrait/blob/7dbbf0468083d932a61b9c720700bd6083558fa9/proto/substrait/algebra.proto#L741-L744
It is probably safer to have this be a list and force consumers to process all of them.