Skip to content

Commit

Permalink
feature: Add support for nested struct in RexExpressionConverter
Browse files Browse the repository at this point in the history
also fix a bug in ExpressionProtoConverter where we don't keep any of the
segments of a nested structure
  • Loading branch information
MasseGuillaume committed Jul 12, 2023
1 parent fb1cc75 commit b4ea5cf
Show file tree
Hide file tree
Showing 3 changed files with 305 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ public Expression visit(io.substrait.expression.Expression.MultiOrList expr)

@Override
public Expression visit(FieldReference expr) {
Expression.ReferenceSegment top = null;

Expression.ReferenceSegment seg = null;
for (var segment : expr.segments()) {
Expression.ReferenceSegment.Builder protoSegment;
Expand All @@ -351,13 +351,11 @@ public Expression visit(FieldReference expr) {
throw new IllegalArgumentException("Unhandled type: " + segment);
}
var builtSegment = protoSegment.build();
if (top == null) {
top = builtSegment;
}
seg = builtSegment;
}

var out = Expression.FieldReference.newBuilder().setDirectReference(top);
var out = Expression.FieldReference.newBuilder().setDirectReference(seg);

if (expr.inputExpression().isPresent()) {
out.setExpression(from(expr.inputExpression().get()));
} else if (expr.outerReferenceStepsOut().isPresent()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
import java.util.stream.Collectors;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rex.*;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;

public class RexExpressionConverter implements RexVisitor<Expression> {

static final org.slf4j.Logger logger =
org.slf4j.LoggerFactory.getLogger(RexExpressionConverter.class);

Expand Down Expand Up @@ -119,16 +121,25 @@ public Expression visitRangeRef(RexRangeRef rangeRef) {

@Override
public Expression visitFieldAccess(RexFieldAccess fieldAccess) {
if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable) {
int stepsOut = relVisitor.getFieldAccessDepth(fieldAccess);

return FieldReference.newRootStructOuterReference(
fieldAccess.getField().getIndex(),
typeConverter.toSubstrait(fieldAccess.getType()),
stepsOut);
SqlKind kind = fieldAccess.getReferenceExpr().getKind();
switch (kind) {
case CORREL_VARIABLE -> {
int stepsOut = relVisitor.getFieldAccessDepth(fieldAccess);

return FieldReference.newRootStructOuterReference(
fieldAccess.getField().getIndex(),
typeConverter.toSubstrait(fieldAccess.getType()),
stepsOut);
}
case ITEM, INPUT_REF, FIELD_ACCESS -> {
Expression expression = fieldAccess.getReferenceExpr().accept(this);
System.out.println(expression.getClass());
return FieldReference.newStructReference(
fieldAccess.getField().getIndex(),
expression);
}
default -> throw new UnsupportedOperationException(String.format("RexFieldAccess for SqlKind %s not supported", kind));
}
throw new UnsupportedOperationException(
"RexFieldAccess for other than RexCorrelVariable not supported");
}

@Override
Expand Down
282 changes: 282 additions & 0 deletions isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
package io.substrait.isthmus;

import io.substrait.dsl.SubstraitBuilder;
import io.substrait.plan.ProtoPlanConverter;
import io.substrait.proto.Plan;

import io.substrait.proto.Expression;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.schema.Schema;
import org.apache.calcite.schema.Table;
import org.apache.calcite.schema.impl.AbstractSchema;
import org.apache.calcite.schema.impl.AbstractTable;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.Pair;

import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.util.Arrays;
import java.util.Map;
import java.util.function.Function;

public class NestedStructQueryTest extends PlanTestBase {
private class TypeHelper {
private final RelDataTypeFactory factory;

public TypeHelper(RelDataTypeFactory factory) {
this.factory = factory;
}

RelDataType struct(String field, RelDataType value) {
return factory.createStructType(Arrays.asList(Pair.of(field, value)));
}

RelDataType struct2(String field1, RelDataType value1, String field2, RelDataType value2) {
return factory.createStructType(
Arrays.asList(Pair.of(field1, value1), Pair.of(field2, value2)));
}

RelDataType struct3(
String field1,
RelDataType value1,
String field2,
RelDataType value2,
String field3,
RelDataType value3) {
return factory.createStructType(
Arrays.asList(Pair.of(field1, value1), Pair.of(field2, value2), Pair.of(field3, value3)));
}

RelDataType i32() {
return factory.createSqlType(SqlTypeName.INTEGER);
}

RelDataType string() {
return factory.createSqlType(SqlTypeName.VARCHAR);
}

RelDataType list(RelDataType elementType) {
return factory.createArrayType(elementType, -1);
}

RelDataType map(RelDataType key, RelDataType value) {
return factory.createMapType(key, value);
}
}

private void test(Table table, String query) {
final Schema schema =
new AbstractSchema() {
@Override
protected Map<String, Table> getTableMap() {
return Map.of("my_table", table);
}
};

try {
final SqlToSubstrait sqlToSubstrait = new SqlToSubstrait();
final SubstraitBuilder builder = new SubstraitBuilder(extensions);
Plan plan = sqlToSubstrait.execute(query, "nested", schema);
Expression expression =
plan.getRelations(0).getRoot().getInput().getProject().getExpressions(0);
System.out.println(expression.toString());
ProtoPlanConverter converter = new ProtoPlanConverter();
io.substrait.plan.Plan plan2 = converter.from(plan);
System.out.println(plan2);
assertPlanRoundrip(plan2);
} catch (IOException e) {
throw new RuntimeException(e);
} catch (SqlParseException e) {
throw new RuntimeException(e);
}
}

@Test
public void testNested1() {
final Table table =
new AbstractTable() {
@Override
public RelDataType getRowType(RelDataTypeFactory factory) {
var helper = new TypeHelper(factory);
return helper.struct2(
"x", helper.i32(),
"a", helper.struct("b", helper.i32()));
}
};

String query =
"""
SELECT
"nested"."my_table"."a"."b"
FROM
"nested"."my_table";
""";

test(table, query);
}

@Test
public void testNested2() {
final Table table =
new AbstractTable() {
@Override
public RelDataType getRowType(RelDataTypeFactory factory) {
var helper = new TypeHelper(factory);
return helper.struct("a", helper.struct("b", helper.struct("c", helper.i32())));
}
};

String query =
"""
SELECT
"nested"."my_table"."a"."b"."c"
FROM
"nested"."my_table";
""";

test(table, query);
}

@Test
public void testList() {
final Table table =
new AbstractTable() {
@Override
public RelDataType getRowType(RelDataTypeFactory factory) {
var helper = new TypeHelper(factory);

return helper.struct2("x", helper.i32(), "a", helper.list(helper.i32()));
}
};

String query =
"""
SELECT
"nested"."my_table"."a"[1]
FROM
"nested"."my_table";
""";

test(table, query);
}

@Test
public void testList2() {
final Table table =
new AbstractTable() {
@Override
public RelDataType getRowType(RelDataTypeFactory factory) {
var helper = new TypeHelper(factory);

return helper.struct2(
"x",
helper.i32(),
"a",
helper.list(helper.list(helper.list(helper.list(helper.i32())))));
}
};

String query =
"""
SELECT
"nested"."my_table"."a"[1][2][3]
FROM
"nested"."my_table";
""";

test(table, query);
}

@Test
public void testMap() throws SqlParseException {
final Table table =
new AbstractTable() {
@Override
public RelDataType getRowType(RelDataTypeFactory factory) {
var helper = new TypeHelper(factory);
return helper.struct(
"a", helper.map(helper.string(), helper.struct("c", helper.i32())));
}
};

String query =
"""
SELECT
"nested"."my_table"."a"['foo']."c"
FROM
"nested"."my_table";
""";

test(table, query);
}

@Test
public void testFull() throws SqlParseException {

final Table table =
new AbstractTable() {
@Override
public RelDataType getRowType(RelDataTypeFactory factory) {
var helper = new TypeHelper(factory);
return helper.struct3(
"aa", helper.i32(),
"a",
helper.struct3(
"b", helper.i32(),
"c", helper.list(helper.struct("d", helper.i32())),
"g", helper.i32()),
"h",
helper.map(
helper.string(),
helper.map(helper.string(), helper.struct("x", helper.i32()))));
}
};

// todo: a.b[2].c['my_map_key'].x

String query =
"""
SELECT
"nested"."my_table"."a"."b",
"nested"."my_table"."a"."c"[1]."d",
"nested"."my_table"."h"['key1']['key2']."x"
FROM
"nested"."my_table";
""";

test(table, query);
}

@Test
public void testProtobufDoc() throws SqlParseException {

final Table table =
new AbstractTable() {
@Override
public RelDataType getRowType(RelDataTypeFactory factory) {

var helper = new TypeHelper(factory);
return helper.struct(
"a",
helper.struct(
"b",
helper.list(
helper.struct(
"c", helper.map(helper.string(), helper.struct("x", helper.i32()))))));
}
};

String query =
"""
SELECT
"nested"."my_table".a.b[2].c['my_map_key'].x
FROM
"nested"."my_table";
""";

test(table, query);
}
}

0 comments on commit b4ea5cf

Please sign in to comment.