From 8bc5997279b49dcce8980b0f31a3daff4f8b9bc8 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Thu, 18 Jul 2024 15:10:29 -0700 Subject: [PATCH] feat: check that VirtualTableScan field names correspond to schema --- .../io/substrait/relation/VirtualTableScan.java | 13 ++++++++++++- .../io/substrait/relation/VirtualTableScanTest.java | 4 ++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/io/substrait/relation/VirtualTableScan.java b/core/src/main/java/io/substrait/relation/VirtualTableScan.java index 6272e25b4..c35dab8cb 100644 --- a/core/src/main/java/io/substrait/relation/VirtualTableScan.java +++ b/core/src/main/java/io/substrait/relation/VirtualTableScan.java @@ -24,13 +24,16 @@ public abstract class VirtualTableScan extends AbstractReadRel { @Value.Check protected void check() { var names = getInitialSchema().names(); + + assert names.size() + == NamedFieldCountingTypeVisitor.countNames(this.getInitialSchema().struct()); var rows = getRows(); assert rows.size() > 0 && names.stream().noneMatch(s -> s == null) && rows.stream().noneMatch(r -> r == null) && rows.stream() - .allMatch(r -> r.getType().accept(new NamedFieldCountingTypeVisitor()) == names.size()); + .allMatch(r -> NamedFieldCountingTypeVisitor.countNames(r.getType()) == names.size()); } @Override @@ -44,6 +47,14 @@ public static ImmutableVirtualTableScan.Builder builder() { private static class NamedFieldCountingTypeVisitor implements TypeVisitor { + + private static final NamedFieldCountingTypeVisitor VISITOR = + new NamedFieldCountingTypeVisitor(); + + private static Integer countNames(Type type) { + return type.accept(VISITOR); + } + @Override public Integer visit(Type.Bool type) throws RuntimeException { return 0; diff --git a/core/src/test/java/io/substrait/relation/VirtualTableScanTest.java b/core/src/test/java/io/substrait/relation/VirtualTableScanTest.java index 32781f791..ca6fceaa7 100644 --- a/core/src/test/java/io/substrait/relation/VirtualTableScanTest.java +++ b/core/src/test/java/io/substrait/relation/VirtualTableScanTest.java @@ -39,8 +39,8 @@ void check() { R.struct( R.STRING, R.struct(R.STRING, R.STRING), - R.list(R.STRING), - R.map(R.STRING, R.STRING)))) + R.list(R.struct(R.STRING)), + R.map(R.struct(R.STRING), R.struct(R.STRING))))) .addRows( struct( false,