From bc86b59b09d9b81301a90eca535efa2cd7f6e242 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillaume=20Mass=C3=A9=20=28=E9=A9=AC=E8=B5=9B=E5=8D=AB=29?= Date: Thu, 13 Jul 2023 19:21:44 -0400 Subject: [PATCH] feat: add PR review for nested struct in RexExpressionConverter (rebase) --- .../substrait/expression/FieldReference.java | 2 +- .../proto/ProtoExpressionConverter.java | 3 +- .../isthmus/NestedStructQueryTest.java | 162 ++++++------------ 3 files changed, 52 insertions(+), 115 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/FieldReference.java b/core/src/main/java/io/substrait/expression/FieldReference.java index 90832dcc..c2a1856d 100644 --- a/core/src/main/java/io/substrait/expression/FieldReference.java +++ b/core/src/main/java/io/substrait/expression/FieldReference.java @@ -40,8 +40,8 @@ public FieldReference dereferenceStruct(int index) { private FieldReference dereference(Type newType, ReferenceSegment nextSegment) { return ImmutableFieldReference.builder() .type(newType) - .addAllSegments(segments()) .addSegments(nextSegment) + .addAllSegments(segments()) .inputExpression(inputExpression()) .build(); } diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index dae1d758..411f69de 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -11,6 +11,7 @@ import io.substrait.type.Type; import io.substrait.type.proto.ProtoTypeConverter; import java.util.ArrayList; +import java.util.Collections; import java.util.Objects; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -67,7 +68,7 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc "Unhandled type: " + segment.getReferenceTypeCase()); }); } - + Collections.reverse(segments); var fieldReference = switch (reference.getRootTypeCase()) { case EXPRESSION -> FieldReference.ofExpression( diff --git a/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java b/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java index 18688d02..ad434dd4 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java @@ -3,7 +3,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import com.google.protobuf.TextFormat; -import io.substrait.dsl.SubstraitBuilder; import io.substrait.plan.ProtoPlanConverter; import io.substrait.proto.Expression; import io.substrait.proto.Plan; @@ -38,17 +37,6 @@ RelDataType struct2(String field1, RelDataType value1, String field2, RelDataTyp Arrays.asList(Pair.of(field1, value1), Pair.of(field2, value2))); } - RelDataType struct3( - String field1, - RelDataType value1, - String field2, - RelDataType value2, - String field3, - RelDataType value3) { - return factory.createStructType( - Arrays.asList(Pair.of(field1, value1), Pair.of(field2, value2), Pair.of(field3, value3))); - } - RelDataType i32() { return factory.createSqlType(SqlTypeName.INTEGER); } @@ -66,7 +54,8 @@ RelDataType map(RelDataType key, RelDataType value) { } } - private void test(Table table, String query, String expectedExpressionText) { + private void test(Table table, String query, String expectedExpressionText) + throws SqlParseException, IOException { final Schema schema = new AbstractSchema() { @Override @@ -75,27 +64,20 @@ protected Map getTableMap() { } }; - try { - final SqlToSubstrait sqlToSubstrait = new SqlToSubstrait(); - final SubstraitBuilder builder = new SubstraitBuilder(extensions); - Plan plan = sqlToSubstrait.execute(query, "nested", schema); - Expression obtainedExpression = - plan.getRelations(0).getRoot().getInput().getProject().getExpressions(0); - Expression expectedExpression = TextFormat.parse(expectedExpressionText, Expression.class); - assertEquals(expectedExpression, obtainedExpression); - - ProtoPlanConverter converter = new ProtoPlanConverter(); - io.substrait.plan.Plan plan2 = converter.from(plan); - assertPlanRoundrip(plan2); - } catch (IOException e) { - throw new RuntimeException(e); - } catch (SqlParseException e) { - throw new RuntimeException(e); - } + final SqlToSubstrait sqlToSubstrait = new SqlToSubstrait(); + Plan plan = sqlToSubstrait.execute(query, "nested", schema); + Expression obtainedExpression = + plan.getRelations(0).getRoot().getInput().getProject().getExpressions(0); + Expression expectedExpression = TextFormat.parse(expectedExpressionText, Expression.class); + assertEquals(expectedExpression, obtainedExpression); + + ProtoPlanConverter converter = new ProtoPlanConverter(); + io.substrait.plan.Plan plan2 = converter.from(plan); + assertPlanRoundrip(plan2); } @Test - public void testNested0() { + public void testNestedStruct() throws SqlParseException, IOException { final Table table = new AbstractTable() { @Override @@ -131,7 +113,7 @@ public RelDataType getRowType(RelDataTypeFactory factory) { } @Test - public void testNested1() { + public void testNestedStruct2() throws SqlParseException, IOException { final Table table = new AbstractTable() { @Override @@ -156,10 +138,10 @@ public RelDataType getRowType(RelDataTypeFactory factory) { selection { direct_reference { struct_field { - field: 0 # b + field: 1 # a child { struct_field { - field: 1 # a + field: 0 # b } } } @@ -172,13 +154,15 @@ public RelDataType getRowType(RelDataTypeFactory factory) { } @Test - public void testNested2() { + public void testNestedStruct3() throws SqlParseException, IOException { final Table table = new AbstractTable() { @Override public RelDataType getRowType(RelDataTypeFactory factory) { var helper = new TypeHelper(factory); - return helper.struct("a", helper.struct("b", helper.struct("c", helper.i32()))); + return helper.struct2( + "aa", helper.i32(), + "a", helper.struct("b", helper.struct("c", helper.i32()))); } }; @@ -195,13 +179,13 @@ public RelDataType getRowType(RelDataTypeFactory factory) { selection { direct_reference { struct_field { - field: 0 # c + field: 1 # a child { struct_field { field: 0 # b child: { struct_field { - field: 0 # a + field: 0 # c } } } @@ -216,7 +200,7 @@ public RelDataType getRowType(RelDataTypeFactory factory) { } @Test - public void testList() { + public void testNestedList() throws SqlParseException, IOException { final Table table = new AbstractTable() { @Override @@ -239,11 +223,11 @@ public RelDataType getRowType(RelDataTypeFactory factory) { """ selection { direct_reference { - list_element { - offset: 1 + struct_field { + field: 1 # a child { - struct_field { - field: 1 # a + list_element { + offset: 1 } } } @@ -256,7 +240,7 @@ public RelDataType getRowType(RelDataTypeFactory factory) { } @Test - public void testList2() { + public void testNestedList2() throws SqlParseException, IOException { final Table table = new AbstractTable() { @Override @@ -283,17 +267,17 @@ public RelDataType getRowType(RelDataTypeFactory factory) { """ selection { direct_reference { - list_element { - offset: 3 + struct_field { + field: 1 # a child { list_element { - offset: 2 + offset: 1 child { list_element { - offset: 1 - child: { - struct_field { - field: 1 # a + offset: 2 + child { + list_element { + offset: 3 } } } @@ -310,54 +294,7 @@ public RelDataType getRowType(RelDataTypeFactory factory) { } @Test - public void testMap() throws SqlParseException { - final Table table = - new AbstractTable() { - @Override - public RelDataType getRowType(RelDataTypeFactory factory) { - var helper = new TypeHelper(factory); - return helper.struct( - "a", helper.map(helper.string(), helper.struct("c", helper.i32()))); - } - }; - - String query = - """ - SELECT - "nested"."my_table"."a"['foo']."c" - FROM - "nested"."my_table"; - """; - - String expectedExpressionText = - """ - selection { - direct_reference { - struct_field { - field: 0 # a - child: { - map_key: { - map_key: { - string: 'foo' - } - child: { - struct_field: { - field: 0 # c - } - } - } - } - } - } - root_reference: {} - } - """; - - test(table, query, expectedExpressionText); - } - - @Test - public void testProtobufDoc() throws SqlParseException { + public void testProtobufDoc() throws SqlParseException, IOException { final Table table = new AbstractTable() { @@ -388,24 +325,24 @@ public RelDataType getRowType(RelDataTypeFactory factory) { selection { direct_reference { struct_field { - field: 0 # .x + field: 0 # .a child { - map_key { - map_key { - string: "my_map_key" # ['my_map_key'] - } + struct_field { + field: 0 # .b child { - struct_field { - field: 0 # .c + list_element { + offset: 2 child { - list_element { - offset: 2 # [2] + struct_field { + field: 0 # .c child { - struct_field { - field: 0 # .b + map_key { + map_key { + string: "my_map_key" # ['my_map_key'] + } child { struct_field { - field: 0 # .a + field: 0 # .x } } } @@ -418,8 +355,7 @@ public RelDataType getRowType(RelDataTypeFactory factory) { } } } - root_reference { - } + root_reference {} } """; test(table, query, expectedExpressionText);