fix: properly support Options for Scalar/Aggregate/Window functions #278

Merged 3 commits on Jul 11, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import io.substrait.extension.SimpleExtension;
import io.substrait.type.Type;
import java.util.List;
import java.util.Map;
import org.immutables.value.Value;

@Value.Immutable
Expand All @@ -12,7 +11,7 @@ public abstract class AggregateFunctionInvocation {

public abstract List<FunctionArg> arguments();

public abstract Map<String, FunctionOption> options();
@Blizzara (Contributor, Author) commented on Jul 2, 2024:

Keying the map by name seemed odd to me: the name is already stored inside the FunctionOption, so the map key just duplicates the source of truth. That said, I don't mind reverting this if keeping the map is preferable.

Member:

I imagine this is/was intended to allow easy lookup of options by name, so that consumers could find the relevant options for a given function. However, given that:

  // Name of the option to set. If the consumer does not recognize the
  // option, it must reject the plan. The name is matched case-insensitively
  // with option names defined for the function.

https://github.com/substrait-io/substrait/blob/7dbbf0468083d932a61b9c720700bd6083558fa9/proto/substrait/algebra.proto#L741-L744

It is probably safer to make this a list and force consumers to process every option.

public abstract List<FunctionOption> options();

public abstract Expression.AggregationPhase aggregationPhase();
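
To illustrate the reasoning in the review thread above, here is a minimal sketch of a consumer that walks a list of options instead of looking them up in a map. `Option`, `OptionConsumer`, and the option names `overflow` and `rounding` are hypothetical stand-ins invented for this example, not substrait-java classes; taking the first value of each option is likewise an assumption made only to keep the sketch short.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;

// Hypothetical stand-in for io.substrait.expression.FunctionOption: a name plus its values.
final class Option {
  final String name;
  final List<String> values;

  Option(String name, String... values) {
    this.name = name;
    this.values = Arrays.asList(values);
  }
}

final class OptionConsumer {
  // Option names this hypothetical consumer recognizes, stored lower-cased.
  private static final List<String> KNOWN = Arrays.asList("overflow", "rounding");

  // Visits every option in order; any unrecognized name rejects the whole plan.
  static List<String> resolve(List<Option> options) {
    List<String> resolved = new ArrayList<>();
    for (Option option : options) {
      // Names are compared case-insensitively, as the quoted proto comment requires.
      String key = option.name.toLowerCase(Locale.ROOT);
      if (!KNOWN.contains(key)) {
        throw new IllegalArgumentException("Unrecognized function option: " + option.name);
      }
      if (option.values.isEmpty()) {
        throw new IllegalArgumentException("Option has no values: " + option.name);
      }
      // Assumption for this sketch: the first listed value is the one to apply.
      resolved.add(key + "=" + option.values.get(0));
    }
    return resolved;
  }
}
```

With a map keyed by the exact name, a consumer could fetch only the keys it knows and silently skip the rest; iterating a list makes the rejection path hard to miss.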

Expand Down
4 changes: 2 additions & 2 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ abstract static class ScalarFunctionInvocation implements Expression {

public abstract List<FunctionArg> arguments();

public abstract Map<String, FunctionOption> options();
public abstract List<FunctionOption> options();

public abstract Type outputType();

Expand All @@ -620,7 +620,7 @@ abstract class WindowFunctionInvocation implements Expression {

public abstract List<FunctionArg> arguments();

public abstract Map<String, FunctionOption> options();
public abstract List<FunctionOption> options();

public abstract AggregationPhase aggregationPhase();

Expand Down
56 changes: 31 additions & 25 deletions core/src/main/java/io/substrait/expression/ExpressionCreator.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.UUID;
Expand Down Expand Up @@ -284,13 +285,13 @@ public static Expression.ScalarFunctionInvocation scalarFunction(
SimpleExtension.ScalarFunctionVariant declaration,
Type outputType,
FunctionArg... arguments) {
return Expression.ScalarFunctionInvocation.builder()
.declaration(declaration)
.outputType(outputType)
.addArguments(arguments)
.build();
return scalarFunction(declaration, outputType, Arrays.asList(arguments));
}

/**
* Use {@link Expression.ScalarFunctionInvocation#builder()} directly to specify other parameters,
* e.g. options
*/
public static Expression.ScalarFunctionInvocation scalarFunction(
SimpleExtension.ScalarFunctionVariant declaration,
Type outputType,
Expand All @@ -302,6 +303,10 @@ public static Expression.ScalarFunctionInvocation scalarFunction(
.build();
}
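
The Javadoc above points callers at the builder when they need options. The following sketch shows that split with simplified, hypothetical types (`Invocation`, `InvocationCreator`, and string-typed arguments and options are assumptions for illustration, not the real substrait-java classes): a varargs overload delegates to the list overload, and options are only reachable through the builder.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Hypothetical, simplified stand-in for a function invocation.
final class Invocation {
  final String declaration;
  final List<String> arguments;
  final List<String> options;

  private Invocation(Builder b) {
    this.declaration = b.declaration;
    this.arguments = Collections.unmodifiableList(new ArrayList<>(b.arguments));
    this.options = Collections.unmodifiableList(new ArrayList<>(b.options));
  }

  static Builder builder() {
    return new Builder();
  }

  static final class Builder {
    private String declaration;
    private final List<String> arguments = new ArrayList<>();
    private final List<String> options = new ArrayList<>();

    Builder declaration(String declaration) {
      this.declaration = declaration;
      return this;
    }

    Builder addAllArguments(List<String> args) {
      arguments.addAll(args);
      return this;
    }

    Builder addOptions(String option) {
      options.add(option);
      return this;
    }

    Invocation build() {
      return new Invocation(this);
    }
  }
}

final class InvocationCreator {
  // Varargs convenience overload: delegates so construction logic lives in one place.
  static Invocation scalarFunction(String declaration, String... arguments) {
    return scalarFunction(declaration, Arrays.asList(arguments));
  }

  // List overload: covers the common case and leaves options empty.
  static Invocation scalarFunction(String declaration, List<String> arguments) {
    return Invocation.builder().declaration(declaration).addAllArguments(arguments).build();
  }
}
```

A caller that needs an option would skip `InvocationCreator` and write something like `Invocation.builder().declaration("add").addAllArguments(args).addOptions("overflow=ERROR").build()`.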

/**
* Use {@link AggregateFunctionInvocation#builder()} directly to specify other parameters, e.g.
* options
*/
public static AggregateFunctionInvocation aggregateFunction(
SimpleExtension.AggregateFunctionVariant declaration,
Type outputType,
Expand All @@ -326,16 +331,14 @@ public static AggregateFunctionInvocation aggregateFunction(
List<Expression.SortField> sort,
Expression.AggregationInvocation invocation,
FunctionArg... arguments) {
return AggregateFunctionInvocation.builder()
.declaration(declaration)
.outputType(outputType)
.aggregationPhase(phase)
.sort(sort)
.invocation(invocation)
.addArguments(arguments)
.build();
return aggregateFunction(
declaration, outputType, phase, sort, invocation, Arrays.asList(arguments));
}

/**
* Use {@link Expression.WindowFunctionInvocation#builder()} directly to specify other parameters,
* e.g. options
*/
public static Expression.WindowFunctionInvocation windowFunction(
SimpleExtension.WindowFunctionVariant declaration,
Type outputType,
Expand All @@ -361,6 +364,10 @@ public static Expression.WindowFunctionInvocation windowFunction(
.build();
}

/**
* Use {@link ConsistentPartitionWindow.WindowRelFunctionInvocation#builder()} directly to specify
* other parameters, e.g. options
*/
public static ConsistentPartitionWindow.WindowRelFunctionInvocation windowRelFunction(
SimpleExtension.WindowFunctionVariant declaration,
Type outputType,
Expand Down Expand Up @@ -393,18 +400,17 @@ public static Expression.WindowFunctionInvocation windowFunction(
WindowBound lowerBound,
WindowBound upperBound,
FunctionArg... arguments) {
return Expression.WindowFunctionInvocation.builder()
.declaration(declaration)
.outputType(outputType)
.aggregationPhase(phase)
.sort(sort)
.invocation(invocation)
.partitionBy(partitionBy)
.boundsType(boundsType)
.lowerBound(lowerBound)
.upperBound(upperBound)
.addArguments(arguments)
.build();
return windowFunction(
declaration,
outputType,
phase,
sort,
invocation,
partitionBy,
boundsType,
lowerBound,
upperBound,
Arrays.asList(arguments));
}

public static Expression cast(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,8 @@ public abstract class FunctionOption {
public abstract String getName();

public abstract List<String> values();

public static ImmutableFunctionOption.Builder builder() {
return ImmutableFunctionOption.builder();
}
Member:

Good addition ✨

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import io.substrait.extension.SimpleExtension;
import io.substrait.proto.Expression;
import io.substrait.proto.FunctionArgument;
import io.substrait.proto.FunctionOption;
import io.substrait.proto.Rel;
import io.substrait.proto.SortField;
import io.substrait.proto.Type;
Expand Down Expand Up @@ -314,10 +315,21 @@ public Expression visit(io.substrait.expression.Expression.ScalarFunctionInvocat
.addAllArguments(
expr.arguments().stream()
.map(a -> a.accept(expr.declaration(), 0, argVisitor))
.collect(java.util.stream.Collectors.toList()))
.addAllOptions(
expr.options().stream()
.map(ExpressionProtoConverter::from)
.collect(java.util.stream.Collectors.toList())))
.build();
}

public static FunctionOption from(io.substrait.expression.FunctionOption option) {
return FunctionOption.newBuilder()
.setName(option.getName())
.addAllPreference(option.values())
.build();
}

@Override
public Expression visit(io.substrait.expression.Expression.Cast expr) {
return Expression.newBuilder()
Expand Down Expand Up @@ -495,7 +507,11 @@ public Expression visit(io.substrait.expression.Expression.WindowFunctionInvocat
.addAllPartitions(partitionExprs)
.setBoundsType(expr.boundsType().toProto())
.setLowerBound(lowerBound)
.setUpperBound(upperBound))
.setUpperBound(upperBound)
.addAllOptions(
expr.options().stream()
.map(ExpressionProtoConverter::from)
.collect(java.util.stream.Collectors.toList())))
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import io.substrait.expression.FunctionArg;
import io.substrait.expression.FunctionOption;
import io.substrait.expression.ImmutableExpression;
import io.substrait.expression.ImmutableFunctionOption;
import io.substrait.expression.WindowBound;
import io.substrait.extension.ExtensionLookup;
import io.substrait.extension.SimpleExtension;
Expand Down Expand Up @@ -117,10 +116,15 @@ public Expression from(io.substrait.proto.Expression expr) {
IntStream.range(0, scalarFunction.getArgumentsCount())
.mapToObj(i -> pF.convert(declaration, i, scalarFunction.getArguments(i)))
.collect(java.util.stream.Collectors.toList());
var options =
scalarFunction.getOptionsList().stream()
.map(ProtoExpressionConverter::fromFunctionOption)
.collect(Collectors.toList());
yield ImmutableExpression.ScalarFunctionInvocation.builder()
.addAllArguments(args)
.declaration(declaration)
.outputType(protoTypeConverter.from(scalarFunction.getOutputType()))
.options(options)
.build();
}
case WINDOW_FUNCTION -> fromWindowFunction(expr.getWindowFunction());
Expand Down Expand Up @@ -241,8 +245,8 @@ public Expression.WindowFunctionInvocation fromWindowFunction(
.collect(Collectors.toList());
var options =
windowFunction.getOptionsList().stream()
.map(this::fromFunctionOption)
.collect(Collectors.toMap(FunctionOption::getName, Function.identity()));
.map(ProtoExpressionConverter::fromFunctionOption)
.collect(Collectors.toList());

WindowBound lowerBound = toWindowBound(windowFunction.getLowerBound());
WindowBound upperBound = toWindowBound(windowFunction.getUpperBound());
Expand Down Expand Up @@ -276,8 +280,8 @@ public ConsistentPartitionWindow.WindowRelFunctionInvocation fromWindowRelFuncti
windowRelFunction::getArguments);
var options =
windowRelFunction.getOptionsList().stream()
.map(this::fromFunctionOption)
.collect(Collectors.toMap(FunctionOption::getName, Function.identity()));
.map(ProtoExpressionConverter::fromFunctionOption)
.collect(Collectors.toList());

WindowBound lowerBound = toWindowBound(windowRelFunction.getLowerBound());
WindowBound upperBound = toWindowBound(windowRelFunction.getUpperBound());
Expand Down Expand Up @@ -393,10 +397,7 @@ public Expression.SortField fromSortField(SortField s) {
.build();
}

public FunctionOption fromFunctionOption(io.substrait.proto.FunctionOption o) {
return ImmutableFunctionOption.builder()
.name(o.getName())
.addAllValues(o.getPreferenceList())
.build();
public static FunctionOption fromFunctionOption(io.substrait.proto.FunctionOption o) {
Member:

It is somewhat awkward to make this static so it can be re-used in the ProtoAggregateFunctionConverter, but I also can't think of a better place to put this.

return FunctionOption.builder().name(o.getName()).addAllValues(o.getPreferenceList()).build();
}
}
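
As a rough illustration of why a static converter is convenient here, the sketch below converts options in both directions using method references. `WireOption`, `PojoOption`, and `OptionConverters` are hypothetical stand-ins for the protobuf message, the POJO, and the two converter classes. Collecting into a list preserves order and keeps entries with repeated names, whereas `Collectors.toMap` without a merge function throws `IllegalStateException` on a duplicate key.

```java
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical stand-in for the protobuf message (name plus repeated preference).
final class WireOption {
  final String name;
  final List<String> preference;

  WireOption(String name, List<String> preference) {
    this.name = name;
    this.preference = preference;
  }
}

// Hypothetical stand-in for the POJO FunctionOption (name plus values).
final class PojoOption {
  final String name;
  final List<String> values;

  PojoOption(String name, List<String> values) {
    this.name = name;
    this.values = values;
  }
}

final class OptionConverters {
  // Static, so any other converter can reuse it as a method reference.
  static PojoOption fromWire(WireOption o) {
    return new PojoOption(o.name, o.preference);
  }

  static WireOption toWire(PojoOption o) {
    return new WireOption(o.name, o.values);
  }

  // Order-preserving conversion; entries with the same name are kept, not merged.
  static List<PojoOption> fromWireAll(List<WireOption> options) {
    return options.stream().map(OptionConverters::fromWire).collect(Collectors.toList());
  }
}
```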
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,12 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp
break;
case MEASURE:
io.substrait.relation.Aggregate.Measure measure =
new ProtoAggregateFunctionConverter(
functionLookup, extensionCollection, protoExpressionConverter)
.from(expressionReference.getMeasure());
io.substrait.relation.Aggregate.Measure.builder()
.function(
new ProtoAggregateFunctionConverter(
functionLookup, extensionCollection, protoExpressionConverter)
.from(expressionReference.getMeasure()))
.build();
ImmutableAggregateFunctionReference buildMeasure =
ImmutableAggregateFunctionReference.builder()
.measure(measure)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import org.immutables.value.Value;

Expand Down Expand Up @@ -49,7 +48,7 @@ public abstract static class WindowRelFunctionInvocation {

public abstract List<FunctionArg> arguments();

public abstract Map<String, FunctionOption> options();
public abstract List<FunctionOption> options();

public abstract Type outputType();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.FunctionOption;
import io.substrait.expression.proto.ProtoExpressionConverter;
import io.substrait.extension.ExtensionLookup;
import io.substrait.extension.SimpleExtension;
import io.substrait.type.proto.ProtoTypeConverter;
import java.io.IOException;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
Expand Down Expand Up @@ -37,7 +39,7 @@ public ProtoAggregateFunctionConverter(
this.protoExpressionConverter = protoExpressionConverter;
}

public io.substrait.relation.Aggregate.Measure from(
public io.substrait.expression.AggregateFunctionInvocation from(
io.substrait.proto.AggregateFunction measure) {
FunctionArg.ProtoFrom protoFrom =
new FunctionArg.ProtoFrom(protoExpressionConverter, protoTypeConverter);
Expand All @@ -47,15 +49,17 @@ public io.substrait.relation.Aggregate.Measure from(
IntStream.range(0, measure.getArgumentsCount())
.mapToObj(i -> protoFrom.convert(aggregateFunction, i, measure.getArguments(i)))
.collect(java.util.stream.Collectors.toList());
return Aggregate.Measure.builder()
.function(
AggregateFunctionInvocation.builder()
.arguments(functionArgs)
.declaration(aggregateFunction)
.outputType(protoTypeConverter.from(measure.getOutputType()))
.aggregationPhase(Expression.AggregationPhase.fromProto(measure.getPhase()))
.invocation(Expression.AggregationInvocation.fromProto(measure.getInvocation()))
.build())
List<FunctionOption> options =
measure.getOptionsList().stream()
.map(ProtoExpressionConverter::fromFunctionOption)
.collect(Collectors.toList());
return AggregateFunctionInvocation.builder()
.arguments(functionArgs)
.declaration(aggregateFunction)
.outputType(protoTypeConverter.from(measure.getOutputType()))
.aggregationPhase(Expression.AggregationPhase.fromProto(measure.getPhase()))
.invocation(Expression.AggregationInvocation.fromProto(measure.getInvocation()))
.options(options)
.build();
Member:

Another advantage: because callers now have to wrap the result in a Measure themselves, it is easier for them to include the PreMeasure filter when one is available.

}
}
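
The point made in the review comment above can be sketched as follows. All types here (`AggInvocation`, `Measure`, `AggConverter`, `MeasureAssembler`) are hypothetical, string-based stand-ins rather than the real substrait-java classes: the converter returns only the function invocation, and the caller decides how to wrap it and whether a filter applies.

```java
import java.util.Optional;

// Hypothetical stand-in for AggregateFunctionInvocation.
final class AggInvocation {
  final String name;

  AggInvocation(String name) {
    this.name = name;
  }
}

// Hypothetical stand-in for Aggregate.Measure: a function plus an optional filter.
final class Measure {
  final AggInvocation function;
  final Optional<String> preMeasureFilter;

  Measure(AggInvocation function, Optional<String> preMeasureFilter) {
    this.function = function;
    this.preMeasureFilter = preMeasureFilter;
  }
}

final class AggConverter {
  // Returns only the invocation; it has no knowledge of measure-level filters.
  static AggInvocation from(String protoFunctionName) {
    return new AggInvocation(protoFunctionName);
  }
}

final class MeasureAssembler {
  // The caller wraps the invocation and attaches the filter if the plan has one.
  // filterOrNull is null when the source measure carries no filter.
  static Measure toMeasure(String protoFunctionName, String filterOrNull) {
    return new Measure(AggConverter.from(protoFunctionName), Optional.ofNullable(filterOrNull));
  }
}
```

Had the converter kept returning a finished `Measure`, a caller with a filter would have needed to unwrap and rebuild it.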
13 changes: 4 additions & 9 deletions core/src/main/java/io/substrait/relation/ProtoRelConverter.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package io.substrait.relation;

import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.ImmutableExpression;
Expand Down Expand Up @@ -392,6 +391,9 @@ private Aggregate newAggregate(AggregateRel rel) {
var input = from(rel.getInput());
var protoExprConverter =
new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this);
var protoAggrFuncConverter =
new ProtoAggregateFunctionConverter(lookup, extensions, protoExprConverter);

List<Aggregate.Grouping> groupings = new ArrayList<>(rel.getGroupingsCount());
for (var grouping : rel.getGroupingsList()) {
groupings.add(
Expand All @@ -413,14 +415,7 @@ private Aggregate newAggregate(AggregateRel rel) {
.collect(java.util.stream.Collectors.toList());
measures.add(
Aggregate.Measure.builder()
.function(
AggregateFunctionInvocation.builder()
.arguments(args)
.declaration(funcDecl)
.outputType(protoTypeConverter.from(func.getOutputType()))
.aggregationPhase(Expression.AggregationPhase.fromProto(func.getPhase()))
.invocation(Expression.AggregationInvocation.fromProto(func.getInvocation()))
.build())
.function(protoAggrFuncConverter.from(measure.getMeasure()))
.preMeasureFilter(
Optional.ofNullable(
measure.hasFilter() ? protoExprConverter.from(measure.getFilter()) : null))
Expand Down