diff --git a/core/src/main/java/io/substrait/relation/AbstractReadRel.java b/core/src/main/java/io/substrait/relation/AbstractReadRel.java index b640916c..51df27fc 100644 --- a/core/src/main/java/io/substrait/relation/AbstractReadRel.java +++ b/core/src/main/java/io/substrait/relation/AbstractReadRel.java @@ -11,6 +11,8 @@ public abstract class AbstractReadRel extends ZeroInputRel implements HasExtensi public abstract Optional getFilter(); + public abstract Optional getBestEffortFilter(); + // TODO: // public abstract Optional diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index 130e3aa4..7ec44b37 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -183,6 +183,13 @@ protected EmptyScan newEmptyScan(ReadRel rel) { var builder = EmptyScan.builder() .initialSchema(namedStruct) + .bestEffortFilter( + Optional.ofNullable( + rel.hasBestEffortFilter() + ? new ProtoExpressionConverter( + lookup, extensions, namedStruct.struct(), this) + .from(rel.getBestEffortFilter()) + : null)) .filter( Optional.ofNullable( rel.hasFilter() @@ -238,6 +245,13 @@ protected NamedScan newNamedScan(ReadRel rel) { NamedScan.builder() .initialSchema(namedStruct) .names(rel.getNamedTable().getNamesList()) + .bestEffortFilter( + Optional.ofNullable( + rel.hasBestEffortFilter() + ? new ProtoExpressionConverter( + lookup, extensions, namedStruct.struct(), this) + .from(rel.getBestEffortFilter()) + : null)) .filter( Optional.ofNullable( rel.hasFilter() @@ -279,6 +293,13 @@ protected LocalFiles newLocalFiles(ReadRel rel) { rel.getLocalFiles().getItemsList().stream() .map(this::newFileOrFiles) .collect(java.util.stream.Collectors.toList())) + .bestEffortFilter( + Optional.ofNullable( + rel.hasBestEffortFilter() + ? new ProtoExpressionConverter( + lookup, extensions, namedStruct.struct(), this) + .from(rel.getBestEffortFilter()) + : null)) .filter( Optional.ofNullable( rel.hasFilter() @@ -356,6 +377,9 @@ protected VirtualTableScan newVirtualTable(ReadRel rel) { var builder = VirtualTableScan.builder() + .bestEffortFilter( + Optional.ofNullable( + rel.hasBestEffortFilter() ? converter.from(rel.getBestEffortFilter()) : null)) .filter(Optional.ofNullable(rel.hasFilter() ? converter.from(rel.getFilter()) : null)) .initialSchema(NamedStruct.fromProto(rel.getBaseSchema(), protoTypeConverter)) .rows(structLiterals); diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index ce0aac7a..2b936c42 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -208,6 +208,7 @@ public Rel visit(NamedScan namedScan) throws RuntimeException { .setBaseSchema(namedScan.getInitialSchema().toProto(typeProtoConverter)); namedScan.getFilter().ifPresent(f -> builder.setFilter(toProto(f))); + namedScan.getBestEffortFilter().ifPresent(f -> builder.setBestEffortFilter(toProto(f))); namedScan.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto())); return Rel.newBuilder().setRead(builder).build(); @@ -227,6 +228,7 @@ public Rel visit(LocalFiles localFiles) throws RuntimeException { .build()) .setBaseSchema(localFiles.getInitialSchema().toProto(typeProtoConverter)); localFiles.getFilter().ifPresent(t -> builder.setFilter(toProto(t))); + localFiles.getBestEffortFilter().ifPresent(t -> builder.setBestEffortFilter(toProto(t))); localFiles.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto())); return Rel.newBuilder().setRead(builder.build()).build(); @@ -450,6 +452,7 @@ public Rel visit(VirtualTableScan virtualTableScan) throws RuntimeException { .setBaseSchema(virtualTableScan.getInitialSchema().toProto(typeProtoConverter)); virtualTableScan.getFilter().ifPresent(f -> builder.setFilter(toProto(f))); + virtualTableScan.getBestEffortFilter().ifPresent(f -> builder.setBestEffortFilter(toProto(f))); virtualTableScan.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto())); return Rel.newBuilder().setRead(builder).build(); diff --git a/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java index 028a07e0..c27201b5 100644 --- a/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java @@ -25,6 +25,8 @@ void namedScan() { namedScan = NamedScan.builder() .from(namedScan) + .bestEffortFilter( + b.equal(b.fieldReference(namedScan, 0), b.fieldReference(namedScan, 1))) .filter(b.equal(b.fieldReference(namedScan, 0), b.fieldReference(namedScan, 1))) .build(); @@ -55,6 +57,8 @@ void virtualTable() { virtTable = VirtualTableScan.builder() .from(virtTable) + .bestEffortFilter( + b.equal(b.fieldReference(virtTable, 0), b.fieldReference(virtTable, 1))) .filter(b.equal(b.fieldReference(virtTable, 0), b.fieldReference(virtTable, 1))) .build(); verifyRoundTrip(virtTable);