Skip to content

Commit

Permalink
feat: add ExpandRel support to core and spark
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Coleman <[email protected]>
  • Loading branch information
andrew-coleman committed Sep 26, 2024
1 parent 79f3779 commit 35fde68
Show file tree
Hide file tree
Showing 19 changed files with 359 additions and 35 deletions.
18 changes: 18 additions & 0 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.substrait.plan.Plan;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Cross;
import io.substrait.relation.Expand;
import io.substrait.relation.Fetch;
import io.substrait.relation.Filter;
import io.substrait.relation.Join;
Expand Down Expand Up @@ -313,6 +314,23 @@ private Project project(
return Project.builder().input(input).expressions(expressions).remap(remap).build();
}

public Expand expand(Function<Rel, Iterable<? extends Expand.ExpandField>> fieldsFn, Rel input) {
return expand(fieldsFn, Optional.empty(), input);
}

public Expand expand(
Function<Rel, Iterable<? extends Expand.ExpandField>> fieldsFn, Rel.Remap remap, Rel input) {
return expand(fieldsFn, Optional.of(remap), input);
}

private Expand expand(
Function<Rel, Iterable<? extends Expand.ExpandField>> fieldsFn,
Optional<Rel.Remap> remap,
Rel input) {
var fields = fieldsFn.apply(input);
return Expand.builder().input(input).fields(fields).remap(remap).build();
}

public Set set(Set.SetOp op, Rel... inputs) {
return set(op, Optional.empty(), inputs);
}
Expand Down
23 changes: 23 additions & 0 deletions core/src/main/java/io/substrait/hint/Hint.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package io.substrait.hint;

import io.substrait.proto.RelCommon;
import java.util.List;
import java.util.Optional;
import org.immutables.value.Value;

@Value.Immutable
public abstract class Hint {
public abstract Optional<String> getAlias();

public abstract List<String> getOutputNames();

public RelCommon.Hint toProto() {
var builder = RelCommon.Hint.newBuilder().addAllOutputNames(getOutputNames());
getAlias().ifPresent(builder::setAlias);
return builder.build();
}

public static ImmutableHint.Builder builder() {
return ImmutableHint.builder();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ public OUTPUT visit(Project project) throws EXCEPTION {
return visitFallback(project);
}

@Override
public OUTPUT visit(Expand expand) throws EXCEPTION {
return visitFallback(expand);
}

@Override
public OUTPUT visit(Sort sort) throws EXCEPTION {
return visitFallback(sort);
Expand Down
62 changes: 62 additions & 0 deletions core/src/main/java/io/substrait/relation/Expand.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package io.substrait.relation;

import io.substrait.expression.Expression;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.List;
import java.util.stream.Stream;
import org.immutables.value.Value;

@Value.Enclosing
@Value.Immutable
public abstract class Expand extends SingleInputRel {
static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(Expand.class);

public abstract List<ExpandField> getFields();

@Override
public Type.Struct deriveRecordType() {
Type.Struct initial = getInput().getRecordType();
return TypeCreator.of(initial.nullable())
.struct(Stream.concat(initial.fields().stream(), Stream.of(TypeCreator.REQUIRED.I64)));
}

@Override
public <O, E extends Exception> O accept(RelVisitor<O, E> visitor) throws E {
return visitor.visit(this);
}

public static ImmutableExpand.Builder builder() {
return ImmutableExpand.builder();
}

public interface ExpandField {
Type getType();
}

@Value.Immutable
public abstract static class ConsistentField implements ExpandField {
public abstract Expression getExpression();

public Type getType() {
return getExpression().getType();
}

public static ImmutableExpand.ConsistentField.Builder builder() {
return ImmutableExpand.ConsistentField.builder();
}
}

@Value.Immutable
public abstract static class SwitchingField implements ExpandField {
public abstract List<Expression> getDuplicates();

public Type getType() {
return getDuplicates().get(0).getType();
}

public static ImmutableExpand.SwitchingField.Builder builder() {
return ImmutableExpand.SwitchingField.builder();
}
}
}
52 changes: 51 additions & 1 deletion core/src/main/java/io/substrait/relation/ProtoRelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import io.substrait.extension.AdvancedExtension;
import io.substrait.extension.ExtensionLookup;
import io.substrait.extension.SimpleExtension;
import io.substrait.hint.Hint;
import io.substrait.proto.AggregateRel;
import io.substrait.proto.ConsistentPartitionWindowRel;
import io.substrait.proto.CrossRel;
import io.substrait.proto.ExpandRel;
import io.substrait.proto.ExtensionLeafRel;
import io.substrait.proto.ExtensionMultiRel;
import io.substrait.proto.ExtensionSingleRel;
Expand Down Expand Up @@ -87,6 +89,9 @@ public Rel from(io.substrait.proto.Rel rel) {
case PROJECT -> {
return newProject(rel.getProject());
}
case EXPAND -> {
return newExpand(rel.getExpand());
}
case CROSS -> {
return newCross(rel.getCross());
}
Expand Down Expand Up @@ -155,7 +160,10 @@ protected Filter newFilter(FilterRel rel) {
}

protected NamedStruct newNamedStruct(ReadRel rel) {
var namedStruct = rel.getBaseSchema();
return newNamedStruct(rel.getBaseSchema());
}

protected NamedStruct newNamedStruct(io.substrait.proto.NamedStruct namedStruct) {
var struct = namedStruct.getStruct();
return ImmutableNamedStruct.builder()
.names(namedStruct.getNamesList())
Expand Down Expand Up @@ -389,6 +397,38 @@ protected Project newProject(ProjectRel rel) {
return builder.build();
}

protected Expand newExpand(ExpandRel rel) {
var input = from(rel.getInput());
var converter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this);
var builder =
Expand.builder()
.input(input)
.fields(
rel.getFieldsList().stream()
.map(
expandField ->
switch (expandField.getFieldTypeCase()) {
case CONSISTENT_FIELD -> Expand.ConsistentField.builder()
.expression(converter.from(expandField.getConsistentField()))
.build();
case SWITCHING_FIELD -> Expand.SwitchingField.builder()
.duplicates(
expandField.getSwitchingField().getDuplicatesList().stream()
.map(converter::from)
.collect(java.util.stream.Collectors.toList()))
.build();
case FIELDTYPE_NOT_SET -> throw new UnsupportedOperationException(
"Expand fields not set");
})
.collect(java.util.stream.Collectors.toList()));

builder
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
.remap(optionalRelmap(rel.getCommon()))
.hint(optionalHint(rel.getCommon()));
return builder.build();
}

protected Aggregate newAggregate(AggregateRel rel) {
var input = from(rel.getInput());
var protoExprConverter =
Expand Down Expand Up @@ -647,6 +687,16 @@ protected static Optional<Rel.Remap> optionalRelmap(io.substrait.proto.RelCommon
relCommon.hasEmit() ? Rel.Remap.of(relCommon.getEmit().getOutputMappingList()) : null);
}

protected static Optional<Hint> optionalHint(io.substrait.proto.RelCommon relCommon) {
if (!relCommon.hasHint()) return Optional.empty();
var hint = relCommon.getHint();
var builder = Hint.builder().addAllOutputNames(hint.getOutputNamesList());
if (!hint.getAlias().isEmpty()) {
builder.alias(hint.getAlias());
}
return Optional.of(builder.build());
}

protected Optional<AdvancedExtension> optionalAdvancedExtension(
io.substrait.proto.RelCommon relCommon) {
return Optional.ofNullable(
Expand Down
3 changes: 3 additions & 0 deletions core/src/main/java/io/substrait/relation/Rel.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.substrait.relation;

import io.substrait.extension.AdvancedExtension;
import io.substrait.hint.Hint;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.List;
Expand All @@ -21,6 +22,8 @@ public interface Rel {

List<Rel> getInputs();

Optional<Hint> getHint();

@Value.Immutable
public abstract static class Remap {
public abstract List<Integer> indices();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,11 @@ public Optional<Rel> visit(Project project) throws EXCEPTION {
.build());
}

@Override
public Optional<Rel> visit(Expand expand) throws EXCEPTION {
throw new UnsupportedOperationException();
}

@Override
public Optional<Rel> visit(Sort sort) throws EXCEPTION {
var input = sort.getInput().accept(this);
Expand Down
Loading

0 comments on commit 35fde68

Please sign in to comment.