Skip to content

Commit

Permalink
feat: add PR review for nested struct in RexExpressionConverter (rebase)
Browse files Browse the repository at this point in the history
  • Loading branch information
MasseGuillaume committed Jul 14, 2023
1 parent 525f517 commit bc86b59
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 115 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ public FieldReference dereferenceStruct(int index) {
private FieldReference dereference(Type newType, ReferenceSegment nextSegment) {
return ImmutableFieldReference.builder()
.type(newType)
.addAllSegments(segments())
.addSegments(nextSegment)
.addAllSegments(segments())
.inputExpression(inputExpression())
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import io.substrait.type.Type;
import io.substrait.type.proto.ProtoTypeConverter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
Expand Down Expand Up @@ -67,7 +68,7 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc
"Unhandled type: " + segment.getReferenceTypeCase());
});
}

Collections.reverse(segments);
var fieldReference =
switch (reference.getRootTypeCase()) {
case EXPRESSION -> FieldReference.ofExpression(
Expand Down
162 changes: 49 additions & 113 deletions isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import static org.junit.jupiter.api.Assertions.assertEquals;

import com.google.protobuf.TextFormat;
import io.substrait.dsl.SubstraitBuilder;
import io.substrait.plan.ProtoPlanConverter;
import io.substrait.proto.Expression;
import io.substrait.proto.Plan;
Expand Down Expand Up @@ -38,17 +37,6 @@ RelDataType struct2(String field1, RelDataType value1, String field2, RelDataTyp
Arrays.asList(Pair.of(field1, value1), Pair.of(field2, value2)));
}

RelDataType struct3(
String field1,
RelDataType value1,
String field2,
RelDataType value2,
String field3,
RelDataType value3) {
return factory.createStructType(
Arrays.asList(Pair.of(field1, value1), Pair.of(field2, value2), Pair.of(field3, value3)));
}

RelDataType i32() {
return factory.createSqlType(SqlTypeName.INTEGER);
}
Expand All @@ -66,7 +54,8 @@ RelDataType map(RelDataType key, RelDataType value) {
}
}

private void test(Table table, String query, String expectedExpressionText) {
private void test(Table table, String query, String expectedExpressionText)
throws SqlParseException, IOException {
final Schema schema =
new AbstractSchema() {
@Override
Expand All @@ -75,27 +64,20 @@ protected Map<String, Table> getTableMap() {
}
};

try {
final SqlToSubstrait sqlToSubstrait = new SqlToSubstrait();
final SubstraitBuilder builder = new SubstraitBuilder(extensions);
Plan plan = sqlToSubstrait.execute(query, "nested", schema);
Expression obtainedExpression =
plan.getRelations(0).getRoot().getInput().getProject().getExpressions(0);
Expression expectedExpression = TextFormat.parse(expectedExpressionText, Expression.class);
assertEquals(expectedExpression, obtainedExpression);

ProtoPlanConverter converter = new ProtoPlanConverter();
io.substrait.plan.Plan plan2 = converter.from(plan);
assertPlanRoundrip(plan2);
} catch (IOException e) {
throw new RuntimeException(e);
} catch (SqlParseException e) {
throw new RuntimeException(e);
}
final SqlToSubstrait sqlToSubstrait = new SqlToSubstrait();
Plan plan = sqlToSubstrait.execute(query, "nested", schema);
Expression obtainedExpression =
plan.getRelations(0).getRoot().getInput().getProject().getExpressions(0);
Expression expectedExpression = TextFormat.parse(expectedExpressionText, Expression.class);
assertEquals(expectedExpression, obtainedExpression);

ProtoPlanConverter converter = new ProtoPlanConverter();
io.substrait.plan.Plan plan2 = converter.from(plan);
assertPlanRoundrip(plan2);
}

@Test
public void testNested0() {
public void testNestedStruct() throws SqlParseException, IOException {
final Table table =
new AbstractTable() {
@Override
Expand Down Expand Up @@ -131,7 +113,7 @@ public RelDataType getRowType(RelDataTypeFactory factory) {
}

@Test
public void testNested1() {
public void testNestedStruct2() throws SqlParseException, IOException {
final Table table =
new AbstractTable() {
@Override
Expand All @@ -156,10 +138,10 @@ public RelDataType getRowType(RelDataTypeFactory factory) {
selection {
direct_reference {
struct_field {
field: 0 # b
field: 1 # a
child {
struct_field {
field: 1 # a
field: 0 # b
}
}
}
Expand All @@ -172,13 +154,15 @@ public RelDataType getRowType(RelDataTypeFactory factory) {
}

@Test
public void testNested2() {
public void testNestedStruct3() throws SqlParseException, IOException {
final Table table =
new AbstractTable() {
@Override
public RelDataType getRowType(RelDataTypeFactory factory) {
var helper = new TypeHelper(factory);
return helper.struct("a", helper.struct("b", helper.struct("c", helper.i32())));
return helper.struct2(
"aa", helper.i32(),
"a", helper.struct("b", helper.struct("c", helper.i32())));
}
};

Expand All @@ -195,13 +179,13 @@ public RelDataType getRowType(RelDataTypeFactory factory) {
selection {
direct_reference {
struct_field {
field: 0 # c
field: 1 # a
child {
struct_field {
field: 0 # b
child: {
struct_field {
field: 0 # a
field: 0 # c
}
}
}
Expand All @@ -216,7 +200,7 @@ public RelDataType getRowType(RelDataTypeFactory factory) {
}

@Test
public void testList() {
public void testNestedList() throws SqlParseException, IOException {
final Table table =
new AbstractTable() {
@Override
Expand All @@ -239,11 +223,11 @@ public RelDataType getRowType(RelDataTypeFactory factory) {
"""
selection {
direct_reference {
list_element {
offset: 1
struct_field {
field: 1 # a
child {
struct_field {
field: 1 # a
list_element {
offset: 1
}
}
}
Expand All @@ -256,7 +240,7 @@ public RelDataType getRowType(RelDataTypeFactory factory) {
}

@Test
public void testList2() {
public void testNestedList2() throws SqlParseException, IOException {
final Table table =
new AbstractTable() {
@Override
Expand All @@ -283,17 +267,17 @@ public RelDataType getRowType(RelDataTypeFactory factory) {
"""
selection {
direct_reference {
list_element {
offset: 3
struct_field {
field: 1 # a
child {
list_element {
offset: 2
offset: 1
child {
list_element {
offset: 1
child: {
struct_field {
field: 1 # a
offset: 2
child {
list_element {
offset: 3
}
}
}
Expand All @@ -310,54 +294,7 @@ public RelDataType getRowType(RelDataTypeFactory factory) {
}

@Test
public void testMap() throws SqlParseException {
final Table table =
new AbstractTable() {
@Override
public RelDataType getRowType(RelDataTypeFactory factory) {
var helper = new TypeHelper(factory);
return helper.struct(
"a", helper.map(helper.string(), helper.struct("c", helper.i32())));
}
};

String query =
"""
SELECT
"nested"."my_table"."a"['foo']."c"
FROM
"nested"."my_table";
""";

String expectedExpressionText =
"""
selection {
direct_reference {
struct_field {
field: 0 # a
child: {
map_key: {
map_key: {
string: 'foo'
}
child: {
struct_field: {
field: 0 # c
}
}
}
}
}
}
root_reference: {}
}
""";

test(table, query, expectedExpressionText);
}

@Test
public void testProtobufDoc() throws SqlParseException {
public void testProtobufDoc() throws SqlParseException, IOException {

final Table table =
new AbstractTable() {
Expand Down Expand Up @@ -388,24 +325,24 @@ public RelDataType getRowType(RelDataTypeFactory factory) {
selection {
direct_reference {
struct_field {
field: 0 # .x
field: 0 # .a
child {
map_key {
map_key {
string: "my_map_key" # ['my_map_key']
}
struct_field {
field: 0 # .b
child {
struct_field {
field: 0 # .c
list_element {
offset: 2
child {
list_element {
offset: 2 # [2]
struct_field {
field: 0 # .c
child {
struct_field {
field: 0 # .b
map_key {
map_key {
string: "my_map_key" # ['my_map_key']
}
child {
struct_field {
field: 0 # .a
field: 0 # .x
}
}
}
Expand All @@ -418,8 +355,7 @@ public RelDataType getRowType(RelDataTypeFactory factory) {
}
}
}
root_reference {
}
root_reference {}
}
""";
test(table, query, expectedExpressionText);
Expand Down

0 comments on commit bc86b59

Please sign in to comment.