From e7be83027c50d0e72189cad68b629830b8999aee Mon Sep 17 00:00:00 2001
From: Arttu Voutilainen
Date: Wed, 5 Mar 2025 16:18:30 +0100
Subject: [PATCH] fix: Set proto-to-rel type handling

---
 .../main/java/io/substrait/relation/Set.java | 49 ++++++++++++++++++-
 1 file changed, 48 insertions(+), 1 deletion(-)

diff --git a/core/src/main/java/io/substrait/relation/Set.java b/core/src/main/java/io/substrait/relation/Set.java
index 4efef01e2..ea4448222 100644
--- a/core/src/main/java/io/substrait/relation/Set.java
+++ b/core/src/main/java/io/substrait/relation/Set.java
@@ -2,6 +2,10 @@
 
 import io.substrait.proto.SetRel;
 import io.substrait.type.Type;
+import io.substrait.type.TypeCreator;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Stream;
 import org.immutables.value.Value;
 
 @Value.Immutable
@@ -42,7 +46,50 @@ public static Set.SetOp fromProto(SetRel.SetOp proto) {
 
   @Override
   protected Type.Struct deriveRecordType() {
-    return getInputs().get(0).getRecordType();
+    // The different inputs may have schemas that differ in nullability, but not in type.
+    // In that case we should return a schema that is nullable where any of the inputs is nullable.
+    Stream<Type.Struct> inputRecordTypes = getInputs().stream().map(Rel::getRecordType);
+    return inputRecordTypes
+        .reduce(
+            (a, b) -> {
+              if (a.equals(b)) {
+                return a; // Short-circuit trivial case
+              }
+
+              if (a.fields().size() != b.fields().size()) {
+                throw new IllegalArgumentException(
+                    "Set's input records have different number of fields");
+              }
+
+              List<Type> fields = new ArrayList<>();
+              for (int i = 0; i < a.fields().size(); i++) {
+                fields.add(coalesceNullabilityOrThrow(a.fields().get(i), b.fields().get(i), i));
+              }
+
+              return Type.Struct.builder()
+                  .fields(fields)
+                  .nullable(a.nullable() || b.nullable())
+                  .build();
+            })
+        .orElseThrow(() -> new IllegalStateException("A Set relation needs at least one input"));
+  }
+
+  private Type coalesceNullabilityOrThrow(Type a, Type b, int idx) {
+    if (a.equals(b)) {
+      return a;
+    } else if (a.getClass() == b.getClass() && a.nullable() != b.nullable()) {
+      // Class is equal but nullability is not (i.e. at least one is nullable), this we
+      // can fix by making the result nullable.
+      Type nullableFieldA = TypeCreator.asNullable(a);
+      // There may be differences also in the nullability or content of the inner fields.
+      // In those cases we throw. TODO: should the coalescing be recursive?
+      if (nullableFieldA.equals(TypeCreator.asNullable(b))) {
+        return nullableFieldA;
+      }
+    }
+    throw new IllegalArgumentException(
+        String.format(
+            "Set's input records have different types for field %s: %s vs %s", idx, a, b));
   }
 
   @Override