From 14b7a1b1a8efa7b9706d6e1fa1f5a5930078b05f Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Thu, 20 Jun 2024 18:38:19 +0200 Subject: [PATCH] fix: make SubstraitRelVisitor to use initialSchema, as well as two tests --- .../io/substrait/type/proto/AggregateRoundtripTest.java | 9 +++++++-- .../io/substrait/type/proto/ExtensionRoundtripTest.java | 1 + .../java/io/substrait/isthmus/SubstraitRelVisitor.java | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java index 90c68e1d8..b710bdeb9 100644 --- a/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/AggregateRoundtripTest.java @@ -13,6 +13,7 @@ import io.substrait.relation.ProtoRelConverter; import io.substrait.relation.RelProtoConverter; import io.substrait.relation.VirtualTableScan; +import io.substrait.type.NamedStruct; import io.substrait.type.TypeCreator; import java.io.IOException; import java.math.BigDecimal; @@ -25,8 +26,12 @@ public class AggregateRoundtripTest extends TestBase { private void assertAggregateRoundtrip(Expression.AggregationInvocation invocation) { var expression = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2); Expression.StructLiteral literal = - ImmutableExpression.StructLiteral.builder().from(expression).build(); - var input = VirtualTableScan.builder().addRows(literal).build(); + ImmutableExpression.StructLiteral.builder().addFields(expression).build(); + var input = + VirtualTableScan.builder() + .initialSchema(NamedStruct.of(Arrays.asList("decimal"), R.struct(R.decimal(10, 2)))) + .addRows(literal) + .build(); ExtensionCollector functionCollector = new ExtensionCollector(); var to = new RelProtoConverter(functionCollector); var extensions = defaultExtensionCollection; diff --git a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java index 33081c0df..a99f0ede2 100644 --- a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java @@ -74,6 +74,7 @@ protected void verifyRoundTrip(Rel rel) { void virtualTable() { Rel rel = VirtualTableScan.builder() + .initialSchema(NamedStruct.of(Collections.emptyList(), R.struct())) .addRows(Expression.StructLiteral.builder().fields(Collections.emptyList()).build()) .commonExtension(commonExtension) .extension(relExtension) diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index 5a998aead..cd10a7ef3 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -136,7 +136,7 @@ public Rel visit(org.apache.calcite.rel.core.Values values) { return ExpressionCreator.struct(false, fields); }) .collect(Collectors.toUnmodifiableList()); - return VirtualTableScan.builder().addAllDfsNames(type.names()).addAllRows(structs).build(); + return VirtualTableScan.builder().initialSchema(type).addAllRows(structs).build(); } @Override