From f25c7366d78ff5741eaf8a142d2dfd38976a4f08 Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Sun, 27 Oct 2024 17:15:47 +0100 Subject: [PATCH 01/10] fix: converting proto to pojo should take into account join type for column matching --- core/src/main/java/io/substrait/relation/Join.java | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/java/io/substrait/relation/Join.java b/core/src/main/java/io/substrait/relation/Join.java index c3f9387b3..403663471 100644 --- a/core/src/main/java/io/substrait/relation/Join.java +++ b/core/src/main/java/io/substrait/relation/Join.java @@ -59,6 +59,7 @@ protected Type.Struct deriveRecordType() { switch (getJoinType()) { case LEFT, OUTER -> getRight().getRecordType().fields().stream() .map(TypeCreator::asNullable); + case SEMI, ANTI -> Stream.of(); // these are left joins which ignore right side columns default -> getRight().getRecordType().fields().stream(); }; return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); From 16a7694e0e2df9636489bd498bdd371a0a8cb1cd Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Sun, 27 Oct 2024 17:16:39 +0100 Subject: [PATCH 02/10] fix: support treestring for VirtualTableScan --- .../io/substrait/debug/RelToVerboseString.scala | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala index 79b34462f..0ef4a5310 100644 --- a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala +++ b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala @@ -100,6 +100,7 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { builder.append("commonExtension=").append(commonExtension) }) } + override def visit(namedScan: NamedScan): String = { withBuilder(namedScan, 10)( builder => { @@ -115,6 +116,21 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { }) 
} + override def visit(virtualTableScan: VirtualTableScan): String = { + withBuilder(virtualTableScan, 10)( + builder => { + fillReadRel(virtualTableScan, builder) + builder.append(", ") + builder.append("rows=").append(virtualTableScan.getRows) + + virtualTableScan.getExtension.ifPresent( + extension => { + builder.append(", ") + builder.append("extension=").append(extension) + }) + }) + } + override def visit(emptyScan: EmptyScan): String = { withBuilder(emptyScan, 10)( builder => { From 44c83dd4a0e1b95acdf1b03238b0c7b2f5148da3 Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Sun, 27 Oct 2024 17:16:57 +0100 Subject: [PATCH 03/10] fix: correctly set nullability for aggregate references --- .../main/scala/io/substrait/spark/logical/ToSubstraitRel.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index f00ebba53..d43c30cb5 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -134,7 +134,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { val aggregates = collectAggregates(actualResultExprs, aggExprToOutputOrdinal) val aggOutputMap = aggregates.zipWithIndex.map { case (e, i) => - AttributeReference(s"agg_func_$i", e.dataType)() -> e + AttributeReference(s"agg_func_$i", e.dataType, nullable = e.nullable)() -> e } val aggOutput = aggOutputMap.map(_._1) From 7fe9564d648407847a7f0d75e9a1a2abdb1bd7d3 Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Sun, 27 Oct 2024 21:21:08 -0400 Subject: [PATCH 04/10] fix: correctly set nullability for aggregate grouping exprs --- .../main/scala/io/substrait/spark/logical/ToSubstraitRel.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala 
b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index d43c30cb5..8313d9a94 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -148,7 +148,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { } val groupOutputMap = actualGroupExprs.zipWithIndex.map { case (e, i) => - AttributeReference(s"group_col_$i", e.dataType)() -> e + AttributeReference(s"group_col_$i", e.dataType, nullable = e.nullable)() -> e } val groupOutput = groupOutputMap.map(_._1) From 88d95d129bd952322b62ce8fa0c8b20ca5474595 Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Sun, 27 Oct 2024 21:23:51 -0400 Subject: [PATCH 05/10] fix: correctly set type for scalar subquery when converting proto to pojo --- .../expression/proto/ProtoExpressionConverter.java | 11 ++++++++++- .../substrait/type/proto/ExtensionRoundtripTest.java | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index d2b95d74f..8857614f8 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -15,6 +15,7 @@ import io.substrait.relation.ConsistentPartitionWindow; import io.substrait.relation.ProtoRelConverter; import io.substrait.type.Type; +import io.substrait.type.TypeVisitor; import io.substrait.type.proto.ProtoTypeConverter; import java.util.ArrayList; import java.util.Collections; @@ -196,7 +197,15 @@ public Expression from(io.substrait.proto.Expression expr) { var rel = protoRelConverter.from(expr.getSubquery().getScalar().getInput()); yield ImmutableExpression.ScalarSubquery.builder() .input(rel) - .type(rel.getRecordType()) + .type(rel.getRecordType().accept(new 
TypeVisitor.TypeThrowsVisitor<Type, RuntimeException>("Expected struct field") { + @Override + public Type visit(Type.Struct type) throws RuntimeException { + if (type.fields().size() != 1) { + throw new UnsupportedOperationException("Scalar subquery must have exactly one field"); + } + return type.fields().get(0); + } + })) .build(); } case IN_PREDICATE -> { diff --git a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java index eb2c05b29..7b2237661 100644 --- a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java @@ -312,7 +312,7 @@ void scalarSubquery() { Stream.of( Expression.ScalarSubquery.builder() .input(relWithEnhancement) - .type(TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.I64)) + .type(TypeCreator.REQUIRED.I64) .build()) .collect(Collectors.toList()), commonTable); From 8c9c67744b2875f3fba3fe100ff85be9d5e42362 Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Thu, 21 Nov 2024 18:27:02 +0100 Subject: [PATCH 06/10] fix: spotless --- .../proto/ProtoExpressionConverter.java | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 8857614f8..3120c8941 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -197,15 +197,20 @@ public Expression from(io.substrait.proto.Expression expr) { var rel = protoRelConverter.from(expr.getSubquery().getScalar().getInput()); yield ImmutableExpression.ScalarSubquery.builder() .input(rel) - .type(rel.getRecordType().accept(new TypeVisitor.TypeThrowsVisitor<Type, RuntimeException>("Expected struct field") { - @Override - public Type visit(Type.Struct type) throws 
RuntimeException { - if (type.fields().size() != 1) { - throw new UnsupportedOperationException("Scalar subquery must have exactly one field"); - } - return type.fields().get(0); - } - })) + .type( + rel.getRecordType() + .accept( + new TypeVisitor.TypeThrowsVisitor<Type, RuntimeException>( + "Expected struct field") { + @Override + public Type visit(Type.Struct type) throws RuntimeException { + if (type.fields().size() != 1) { + throw new UnsupportedOperationException( + "Scalar subquery must have exactly one field"); + } + return type.fields().get(0); + } + })) .build(); } case IN_PREDICATE -> { From 00eb7507683f45a4880396de8223ae188e509a3e Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Thu, 21 Nov 2024 18:48:42 +0100 Subject: [PATCH 07/10] fix: add assert to check the pojo-proto-roundtrip --- .../test/scala/io/substrait/spark/SubstraitPlanTestBase.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala index cbd7a151c..dea8f3a53 100644 --- a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala +++ b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala @@ -95,7 +95,8 @@ trait SubstraitPlanTestBase { self: SharedSparkSession => val extensionCollector = new ExtensionCollector; val proto = new RelProtoConverter(extensionCollector).toProto(pojoRel) - new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto) + val pojoFromProto = new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto) + assertResult(pojoRel)(pojoFromProto) pojoRel2.shouldEqualPlainly(pojoRel) logicalPlan2 From f6494e2486825ab08d66ea143876626a3de05ed3 Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Thu, 21 Nov 2024 18:49:46 +0100 Subject: [PATCH 08/10] fix: handle fetch's count in a way that matches roundtrip --- .../scala/io/substrait/spark/logical/ToLogicalPlan.scala | 2 +-
.../io/substrait/spark/logical/ToSubstraitRel.scala | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala index 058fbe393..606c85c51 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -218,7 +218,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] override def visit(fetch: relation.Fetch): LogicalPlan = { val child = fetch.getInput.accept(this) - val limit = fetch.getCount.getAsLong.intValue() + val limit = fetch.getCount.orElse(-1).intValue() val offset = fetch.getOffset.intValue() val toLiteral = (i: Int) => Literal(i, IntegerType) if (limit >= 0) { diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index 8313d9a94..f09571f27 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -212,12 +212,15 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { } private def fetch(child: LogicalPlan, offset: Long, limit: Long = -1): relation.Fetch = { - relation.Fetch + val builder = relation.Fetch .builder() .input(visit(child)) .offset(offset) - .count(limit) - .build() + if (limit != -1) { + builder.count(limit) + } + + builder.build() } override def visitGlobalLimit(p: GlobalLimit): relation.Rel = { From fb361979be185cb41b68a974af7d0e8adbb9f624 Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Thu, 21 Nov 2024 18:50:07 +0100 Subject: [PATCH 09/10] fix: spotless --- .../main/scala/io/substrait/spark/logical/ToSubstraitRel.scala | 2 +- .../test/scala/io/substrait/spark/SubstraitPlanTestBase.scala | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) 
diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index f09571f27..0b4315a63 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -220,7 +220,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { builder.count(limit) } - builder.build() + builder.build() } override def visitGlobalLimit(p: GlobalLimit): relation.Rel = { diff --git a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala index dea8f3a53..0571ee544 100644 --- a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala +++ b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala @@ -95,7 +95,8 @@ trait SubstraitPlanTestBase { self: SharedSparkSession => val extensionCollector = new ExtensionCollector; val proto = new RelProtoConverter(extensionCollector).toProto(pojoRel) - val pojoFromProto = new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto) + val pojoFromProto = + new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto) assertResult(pojoRel)(pojoFromProto) pojoRel2.shouldEqualPlainly(pojoRel) From f3fee70ab54e9041820a8165992cf825e4f3140b Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Wed, 5 Mar 2025 16:18:30 +0100 Subject: [PATCH 10/10] fix: set proto-to-rel type nullability handling --- .../main/java/io/substrait/relation/Set.java | 59 ++++++++++++++++++- 1 file changed, 56 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/io/substrait/relation/Set.java b/core/src/main/java/io/substrait/relation/Set.java index 4efef01e2..d542848ff 100644 --- a/core/src/main/java/io/substrait/relation/Set.java +++ b/core/src/main/java/io/substrait/relation/Set.java @@ -2,11 +2,15 @@ import 
io.substrait.proto.SetRel; import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; import org.immutables.value.Value; @Value.Immutable public abstract class Set extends AbstractRel implements HasExtension { - public abstract Set.SetOp getSetOp(); + public abstract SetOp getSetOp(); public static enum SetOp { UNKNOWN(SetRel.SetOp.SET_OP_UNSPECIFIED), @@ -29,7 +33,7 @@ public SetRel.SetOp toProto() { return proto; } - public static Set.SetOp fromProto(SetRel.SetOp proto) { + public static SetOp fromProto(SetRel.SetOp proto) { for (var v : values()) { if (v.proto == proto) { return v; @@ -42,7 +46,56 @@ public static Set.SetOp fromProto(SetRel.SetOp proto) { @Override protected Type.Struct deriveRecordType() { - return getInputs().get(0).getRecordType(); + // We intentionally don't validate that the types match to maintain backwards + // compatibility. We should, but then we'll need to handle things like VARCHAR + // vs FIXEDCHAR (comes up in Isthmus tests). We also don't recurse into nullability + // of the inner fields, in case the type itself is a struct or list or map. 
+ + List<Type.Struct> inputRecordTypes = + getInputs().stream().map(Rel::getRecordType).collect(Collectors.toList()); + if (inputRecordTypes.isEmpty()) { + throw new IllegalArgumentException("Set operation must have at least one input"); + } + Type.Struct first = inputRecordTypes.get(0); + List<Type.Struct> rest = inputRecordTypes.subList(1, inputRecordTypes.size()); + + int numFields = first.fields().size(); + if (rest.stream().anyMatch(t -> t.fields().size() != numFields)) { + throw new IllegalArgumentException("Set's input records have different number of fields"); + } + + // As defined in https://substrait.io/relations/logical_relations/#set-operation-types + return switch (getSetOp()) { + case UNKNOWN -> first; // alternative would be to throw an exception + case MINUS_PRIMARY, MINUS_PRIMARY_ALL, MINUS_MULTISET -> first; + case INTERSECTION_PRIMARY -> coalesceNullability(first, rest, true); + case INTERSECTION_MULTISET, + INTERSECTION_MULTISET_ALL, + UNION_DISTINCT, + UNION_ALL -> coalesceNullability(first, rest, false); + }; + } + + private Type.Struct coalesceNullability( + Type.Struct first, List<Type.Struct> rest, boolean prioritizeFirst) { + List<Type> fields = new ArrayList<>(); + for (int i = 0; i < first.fields().size(); i++) { + Type typeA = first.fields().get(i); + int finalI = i; + boolean anyOtherIsNullable = rest.stream().anyMatch(t -> t.fields().get(finalI).nullable()); + if (prioritizeFirst && !anyOtherIsNullable) { + // For INTERSECTION_PRIMARY: if no other field is nullable, type shouldn't be nullable + fields.add(TypeCreator.asNotNullable(typeA)); + } else if (!prioritizeFirst && anyOtherIsNullable) { + // For other INTERSECTIONs and UNIONs: if any other field is nullable, type should be + // nullable + fields.add(TypeCreator.asNullable(typeA)); + } else { + // Can keep nullability as-is + fields.add(typeA); + } + } + return Type.Struct.builder().fields(fields).nullable(first.nullable()).build(); } @Override