Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add proto roundtrips for Spark tests and fix issues it surfaces #315

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import io.substrait.relation.ConsistentPartitionWindow;
import io.substrait.relation.ProtoRelConverter;
import io.substrait.type.Type;
import io.substrait.type.TypeVisitor;
import io.substrait.type.proto.ProtoTypeConverter;
import java.util.ArrayList;
import java.util.Collections;
Expand Down Expand Up @@ -196,7 +197,20 @@ public Expression from(io.substrait.proto.Expression expr) {
var rel = protoRelConverter.from(expr.getSubquery().getScalar().getInput());
yield ImmutableExpression.ScalarSubquery.builder()
.input(rel)
.type(rel.getRecordType())
.type(
rel.getRecordType()
.accept(
new TypeVisitor.TypeThrowsVisitor<Type, RuntimeException>(
"Expected struct field") {
@Override
public Type visit(Type.Struct type) throws RuntimeException {
if (type.fields().size() != 1) {
throw new UnsupportedOperationException(
"Scalar subquery must have exactly one field");
}
return type.fields().get(0);
}
}))
.build();
}
case IN_PREDICATE -> {
Expand Down
1 change: 1 addition & 0 deletions core/src/main/java/io/substrait/relation/Join.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ protected Type.Struct deriveRecordType() {
switch (getJoinType()) {
case LEFT, OUTER -> getRight().getRecordType().fields().stream()
.map(TypeCreator::asNullable);
case SEMI, ANTI -> Stream.of(); // these are left joins which ignore right side columns
default -> getRight().getRecordType().fields().stream();
};
return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes));
Expand Down
37 changes: 36 additions & 1 deletion core/src/main/java/io/substrait/relation/Set.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

import io.substrait.proto.SetRel;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;
import org.immutables.value.Value;

@Value.Immutable
Expand Down Expand Up @@ -42,7 +46,38 @@ public static Set.SetOp fromProto(SetRel.SetOp proto) {

@Override
protected Type.Struct deriveRecordType() {
return getInputs().get(0).getRecordType();
// The different inputs may have schemas that differ in nullability, but not in type.
// In that case we should return a schema that is nullable where any of the inputs is nullable.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at the docs for this (https://substrait.io/relations/logical_relations/#set-operation-types), the output nullability depends on which set operation is being performed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, I realized that as well but forgot to fix it 😅 — I'll try to get to it tomorrow.

// Pairwise-merge the record types of all inputs into a single output type.
// NOTE(review): per the Substrait spec, the output nullability may also depend on
// which set operation is being performed (UNION vs. INTERSECT vs. ...); this merge
// treats all operations identically — TODO confirm against the spec.
Stream<Type.Struct> inputRecordTypes = getInputs().stream().map(Rel::getRecordType);
return inputRecordTypes
.reduce(
(a, b) -> {
if (a.equals(b)) {
return a; // Short-circuit trivial case
}

// The merge only makes sense when both sides have the same arity.
if (a.fields().size() != b.fields().size()) {
throw new IllegalArgumentException(
"Set's input records have different number of fields");
}

// Field-by-field: keep a's type, widening to nullable wherever b is nullable.
List<Type> fields = new ArrayList<>();
for (int i = 0; i < a.fields().size(); i++) {
// We intentionally don't validate that the types match, to maintain backwards
// compatibility. We should, but then we'll need to handle things like VARCHAR
// vs FIXEDCHAR (comes up in Isthmus tests). We also don't recurse into nullability
// of the inner fields, in case the type itself is a struct or list or map.
Type typeA = a.fields().get(i);
Type typeB = b.fields().get(i);
fields.add(typeB.nullable() ? TypeCreator.asNullable(typeA) : typeA);
}

// The merged struct itself is nullable if either input struct is nullable.
return Type.Struct.builder()
.fields(fields)
.nullable(a.nullable() || b.nullable())
.build();
})
.orElseThrow(() -> new IllegalStateException("A Set relation needs at least one input"));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ void scalarSubquery() {
Stream.of(
Expression.ScalarSubquery.builder()
.input(relWithEnhancement)
.type(TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.I64))
.type(TypeCreator.REQUIRED.I64)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not 100% sure about this — is it actually meant to return a struct type? Given that it's a scalar subquery, that seems a bit odd.

.build())
.collect(Collectors.toList()),
commonTable);
Expand Down
16 changes: 16 additions & 0 deletions spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] {
builder.append("commonExtension=").append(commonExtension)
})
}

override def visit(namedScan: NamedScan): String = {
withBuilder(namedScan, 10)(
builder => {
Expand All @@ -115,6 +116,21 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] {
})
}

/**
 * Renders a VirtualTableScan verbosely: the common read-rel fields, the
 * literal rows, and the advanced extension when one is present.
 */
override def visit(virtualTableScan: VirtualTableScan): String = {
  withBuilder(virtualTableScan, 10) { sb =>
    fillReadRel(virtualTableScan, sb)
    sb.append(", ")
    sb.append("rows=").append(virtualTableScan.getRows)

    // Only emitted when the scan carries an advanced extension.
    virtualTableScan.getExtension.ifPresent { ext =>
      sb.append(", ")
      sb.append("extension=").append(ext)
    }
  }
}

override def visit(emptyScan: EmptyScan): String = {
withBuilder(emptyScan, 10)(
builder => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]

override def visit(fetch: relation.Fetch): LogicalPlan = {
val child = fetch.getInput.accept(this)
val limit = fetch.getCount.getAsLong.intValue()
val limit = fetch.getCount.orElse(-1).intValue()
val offset = fetch.getOffset.intValue()
val toLiteral = (i: Int) => Literal(i, IntegerType)
if (limit >= 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
val aggregates = collectAggregates(actualResultExprs, aggExprToOutputOrdinal)
val aggOutputMap = aggregates.zipWithIndex.map {
case (e, i) =>
AttributeReference(s"agg_func_$i", e.dataType)() -> e
AttributeReference(s"agg_func_$i", e.dataType, nullable = e.nullable)() -> e
Copy link
Contributor Author

@Blizzara Blizzara Oct 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These were causing the wrong nullability for the type in the created POJOs. I don't think that type field is used anywhere, so it didn't cause harm, but it still failed the roundtrip tests: the type isn't written to proto, and on read it gets re-derived (correctly) from the other fields.

}
val aggOutput = aggOutputMap.map(_._1)

Expand All @@ -148,7 +148,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
}
val groupOutputMap = actualGroupExprs.zipWithIndex.map {
case (e, i) =>
AttributeReference(s"group_col_$i", e.dataType)() -> e
AttributeReference(s"group_col_$i", e.dataType, nullable = e.nullable)() -> e
}
val groupOutput = groupOutputMap.map(_._1)

Expand Down Expand Up @@ -212,12 +212,15 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
}

/**
 * Builds a Substrait Fetch relation over `child`.
 *
 * @param child the plan whose output is being limited
 * @param offset number of rows to skip
 * @param limit row count to keep; -1 (the default) means "no limit", in which
 *              case count is left unset on the builder rather than serialized as -1
 */
private def fetch(child: LogicalPlan, offset: Long, limit: Long = -1): relation.Fetch = {
  val fetchBuilder = relation.Fetch
    .builder()
    .input(visit(child))
    .offset(offset)
  // -1 is the "unbounded" sentinel: leave count unset instead of writing -1.
  if (limit != -1) fetchBuilder.count(limit)
  fetchBuilder.build()
}

override def visitGlobalLimit(p: GlobalLimit): relation.Rel = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ trait SubstraitPlanTestBase { self: SharedSparkSession =>

val extensionCollector = new ExtensionCollector;
val proto = new RelProtoConverter(extensionCollector).toProto(pojoRel)
new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto)
val pojoFromProto =
new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto)
assertResult(pojoRel)(pojoFromProto)

pojoRel2.shouldEqualPlainly(pojoRel)
logicalPlan2
Expand Down
Loading