feat(spark): support ExistenceJoin internal join type
For certain filter expressions that embed subqueries, the Spark
optimiser replaces them with a Join relation of type ‘ExistenceJoin’.
This internal join type does not map directly to any standard
SQL join type, or to a Substrait join type.
To address this, it is converted to a Substrait
‘InPredicate’ within a filter condition.

Signed-off-by: Andrew Coleman <andrew_coleman@uk.ibm.com>
andrew-coleman committed Feb 26, 2025
1 parent d8f22f7 commit ea009b1
Showing 9 changed files with 234 additions and 45 deletions.
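
For context, here is a minimal, self-contained sketch (not part of this commit; the object, view and column names are illustrative assumptions) of the kind of query that produces the pattern described above: an IN-subquery that appears under an OR cannot be rewritten into a plain semi-join, so Catalyst plans it as a Join of type ExistenceJoin plus a Filter over a synthetic boolean attribute.

```scala
import org.apache.spark.sql.SparkSession

// Illustrative only: demonstrates the optimised plan shape this commit handles.
object ExistenceJoinDemo extends App {
  val spark = SparkSession.builder().master("local[*]").appName("existence-join-demo").getOrCreate()

  // Two small temp views; names and data are made up for the example.
  spark.range(1, 100).selectExpr("id AS o_custkey", "id % 50 AS o_quantity")
    .createOrReplaceTempView("orders")
  spark.range(1, 10).selectExpr("id AS c_custkey")
    .createOrReplaceTempView("customers")

  // The IN-subquery sits under an OR, so it cannot become a simple LeftSemi join.
  val df = spark.sql(
    """SELECT * FROM orders
      |WHERE o_quantity > 10
      |   OR o_custkey IN (SELECT c_custkey FROM customers)""".stripMargin)

  // The optimised plan contains Join ... ExistenceJoin(exists#N) with a parent
  // Filter referencing the synthetic attribute exists#N. This commit maps that
  // pattern to a Substrait Filter whose condition uses an InPredicate
  // (needles = o_custkey, haystack = the customers subquery).
  println(df.queryExecution.optimizedPlan)

  spark.stop()
}
```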
18 changes: 17 additions & 1 deletion spark/src/main/scala/io/substrait/debug/ExpressionToString.scala
@@ -21,7 +21,7 @@ import io.substrait.spark.DefaultExpressionVisitor
import org.apache.spark.sql.catalyst.util.DateTimeUtils

import io.substrait.expression.{Expression, FieldReference}
import io.substrait.expression.Expression.{DateLiteral, DecimalLiteral, I32Literal, StrLiteral}
import io.substrait.expression.Expression.{DateLiteral, DecimalLiteral, I32Literal, I64Literal, StrLiteral}
import io.substrait.function.ToTypeString
import io.substrait.util.DecimalUtil

@@ -43,6 +43,10 @@ class ExpressionToString extends DefaultExpressionVisitor[String] {
expr.value().toString
}

override def visit(expr: I64Literal): String = {
expr.value().toString
}

override def visit(expr: DateLiteral): String = {
DateTimeUtils.toJavaDate(expr.value()).toString
}
@@ -76,4 +80,16 @@ class ExpressionToString extends DefaultExpressionVisitor[String] {
override def visit(expr: Expression.EmptyMapLiteral): String = {
expr.toString
}

override def visit(expr: Expression.Cast): String = {
expr.getType.toString
}

override def visit(expr: Expression.InPredicate): String = {
expr.toString
}

override def visit(expr: Expression.ScalarSubquery): String = {
expr.toString
}
}
20 changes: 20 additions & 0 deletions spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala
@@ -165,6 +165,26 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] {
})
}

override def visit(set: Set): String = {
withBuilder(set, 8)(
builder => {
builder
.append("operation=")
.append(set.getSetOp)
})
}

override def visit(cross: Cross): String = {
withBuilder(cross, 10)(
builder => {
builder
.append("left=")
.append(cross.getLeft)
.append("right=")
.append(cross.getRight)
})
}

override def visit(localFiles: LocalFiles): String = {
withBuilder(localFiles, 10)(
builder => {
6 changes: 6 additions & 0 deletions spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
@@ -68,6 +68,12 @@ private class ToSparkType

override def visit(expr: Type.IntervalYear): DataType = YearMonthIntervalType.DEFAULT

override def visit(expr: Type.Struct): DataType = {
StructType(
expr.fields.asScala.map(f => StructField(f.toString, f.accept(this), f.nullable()))
)
}

override def visit(expr: Type.ListType): DataType =
ArrayType(expr.elementType().accept(this), containsNull = expr.elementType().nullable())

spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
@@ -19,9 +19,11 @@ package io.substrait.spark.expression
import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, SparkExtension, ToSubstraitType}
import io.substrait.spark.logical.ToLogicalPlan

import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, CaseWhen, Cast, EqualTo, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
import org.apache.spark.sql.catalyst.plans.ExistenceJoin
import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint, LogicalPlan}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DateType, Decimal}
import org.apache.spark.sql.types.{BooleanType, DateType, Decimal}
import org.apache.spark.substrait.SparkTypeUtil
import org.apache.spark.unsafe.types.UTF8String

@@ -39,6 +41,12 @@ class ToSparkExpression(
extends DefaultExpressionVisitor[Expression]
with HasOutputStack[Seq[NamedExpression]] {

private val existenceJoins = scala.collection.mutable.Map[AttributeReference, LogicalPlan]()

def lookupExists(attr: AttributeReference): Option[LogicalPlan] = {
existenceJoins.get(attr)
}

override def visit(expr: SExpression.BoolLiteral): Expression = {
if (expr.value()) {
Literal.TrueLiteral
@@ -198,6 +206,25 @@ class ToSparkExpression(
.getOrElse(visitFallback(expr))
}

private def existenceJoin(expr: SExpression.InPredicate): AttributeReference = {
val dataType = ToSubstraitType.convert(expr.getType)
require(dataType == BooleanType)
val needles = expr.needles().asScala.map(e => e.accept(this))
val haystack = expr.haystack().accept(toLogicalPlan.get)
val exists = AttributeReference("exists", BooleanType, nullable = false)()
val joinType = ExistenceJoin(exists)
require(
needles.size == haystack.output.size,
"The number of needles must match the number of haystack outputs")
val condition = needles
.zip(haystack.output)
.map { case (needle, hay) => EqualTo(needle, hay) }
.reduce[Expression] { case (lhs, rhs) => And(lhs, rhs) }
val join = Join(null, haystack, joinType, Some(condition), JoinHint.NONE)
existenceJoins.put(exists, join)
exists
}

override def visit(expr: SExpression.SingleOrList): Expression = {
val value = expr.condition().accept(this)
val list = expr.options().asScala.map(e => e.accept(this))
@@ -208,7 +235,12 @@ class ToSparkExpression(
val eArgs = expr.arguments().asScala
val args = eArgs.zipWithIndex.map {
case (arg, i) =>
arg.accept(expr.declaration(), i, this)
arg match {
case ip: SExpression.InPredicate =>
existenceJoin(ip)
case _ =>
arg.accept(expr.declaration(), i, this)
}
}

expr.declaration.name match {
spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.execution.QueryExecution
@@ -46,6 +46,7 @@ import io.substrait.relation.files.FileFormat
import org.apache.hadoop.fs.Path

import scala.collection.JavaConverters.asScalaBufferConverter
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
@@ -218,7 +219,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]

override def visit(fetch: relation.Fetch): LogicalPlan = {
val child = fetch.getInput.accept(this)
val limit = fetch.getCount.getAsLong.intValue()
val limit = fetch.getCount.orElse(-1).intValue() // -1 means unassigned here
val offset = fetch.getOffset.intValue()
val toLiteral = (i: Int) => Literal(i, IntegerType)
if (limit >= 0) {
@@ -291,13 +292,37 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
}

override def visit(filter: relation.Filter): LogicalPlan = {
val child = filter.getInput.accept(this)
var child = filter.getInput.accept(this)
withChild(child) {
val condition = filter.getCondition.accept(expressionConverter)
val existenceJoins = mutable.ListBuffer[AttributeReference]()
findExistenceJoins(condition, existenceJoins)
existenceJoins.foreach {
exists =>
expressionConverter.lookupExists(exists) match {
case Some(Join(_, right, joinType, condition, hint)) =>
// the filter child now becomes the join's left child
// and the join becomes the filter's child.
child = Join(child, right, joinType, condition, hint)
case _ => // should not get here!
throw new UnsupportedOperationException(
s"Expression not currently supported: $condition")
}
}
Filter(condition, child)
}
}

private def findExistenceJoins(
expression: Expression,
attributes: mutable.ListBuffer[AttributeReference]): Unit = {
expression.children.foreach {
case attr: AttributeReference if expressionConverter.lookupExists(attr).isDefined =>
attributes.append(attr)
case child => findExistenceJoins(child, attributes)
}
}

override def visit(set: relation.Set): LogicalPlan = {
val children = set.getInputs.asScala.map(_.accept(this))
withOutput(children.flatMap(_.output)) {
spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala
@@ -247,8 +247,9 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
}

override def visitFilter(p: Filter): relation.Rel = {
val input = visit(p.child)
val condition = toExpression(p.child.output)(p.condition)
relation.Filter.builder().condition(condition).input(visit(p.child)).build()
relation.Filter.builder().condition(condition).input(input).build()
}

private def toSubstraitJoin(joinType: JoinType): relation.Join.JoinType = joinType match {
@@ -262,33 +263,73 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
}

override def visitJoin(p: Join): relation.Rel = {
val left = visit(p.left)
val right = visit(p.right)
val condition = p.condition.map(toExpression(p.left.output ++ p.right.output)).getOrElse(TRUE)
val joinType = toSubstraitJoin(p.joinType)

if (joinType == relation.Join.JoinType.INNER && TRUE == condition) {
relation.Cross.builder
.left(left)
.right(right)
.build
} else {
relation.Join.builder
.condition(condition)
.joinType(joinType)
.left(left)
.right(right)
.build
p match {
case Join(left, right, ExistenceJoin(exists), where, _) =>
convertExistenceJoin(left, right, where, exists)
case _ =>
val left = visit(p.left)
val right = visit(p.right)
val condition =
p.condition.map(toExpression(p.left.output ++ p.right.output)).getOrElse(TRUE)
val joinType = toSubstraitJoin(p.joinType)

if (joinType == relation.Join.JoinType.INNER && TRUE == condition) {
relation.Cross.builder
.left(left)
.right(right)
.build
} else {
relation.Join.builder
.condition(condition)
.joinType(joinType)
.left(left)
.right(right)
.build
}
}
}

private def convertExistenceJoin(
left: LogicalPlan,
right: LogicalPlan,
condition: Option[Expression],
exists: Attribute): relation.Rel = {
// An ExistenceJoin is an internal spark join type that is injected by the catalyst
// optimiser. It doesn't directly map to any SQL join type, but can be modelled using
// a Substrait `InPredicate` within a Filter condition.

// The 'exists' attribute in the parent filter condition will be associated with this.

// extract the needle expressions from the join condition
def findNeedles(expr: Expression): Iterable[Expression] = expr match {
case And(lhs, rhs) => findNeedles(lhs) ++ findNeedles(rhs)
case EqualTo(lhs, _) => Seq(lhs)
case _ =>
throw new UnsupportedOperationException(
s"Unable to convert the ExistenceJoin condition: $expr")
}
val needles = condition.toIterable
.flatMap(w => findNeedles(w))
.map(toExpression(left.output))
.asJava

val haystack = visit(right)

val inPredicate =
SExpression.InPredicate.builder().needles(needles).haystack(haystack).build()
// put this in a map exists->inPredicate for later lookup
toSubstraitExp.existenceJoins.put(exists.exprId.id, inPredicate)
// return the left child, which will become the child of the enclosing filter
visit(left)
}

override def visitProject(p: Project): relation.Rel = {
val expressions = p.projectList.map(toExpression(p.child.output)).toList

val child = visit(p.child)
relation.Project.builder
.remap(relation.Rel.Remap.offset(p.child.output.size, expressions.size))
.remap(relation.Rel.Remap.offset(child.getRecordType.fields.size, expressions.size))
.expressions(expressions.asJava)
.input(visit(p.child))
.input(child)
.build()
}

@@ -492,6 +533,8 @@ private[logical] class WithLogicalSubQuery(toSubstraitRel: ToSubstraitRel)
override protected val toScalarFunction: ToScalarFunction =
ToScalarFunction(SparkExtension.SparkScalarFunctions)

val existenceJoins = scala.collection.mutable.Map[Long, SExpression.InPredicate]()

override protected def translateSubQuery(expr: PlanExpression[_]): Option[SExpression] = {
expr match {
case s: ScalarSubquery if s.outerAttrs.isEmpty && s.joinCond.isEmpty =>
@@ -504,4 +547,11 @@ private[logical] class WithLogicalSubQuery(toSubstraitRel: ToSubstraitRel)
case other => default(other)
}
}

override protected def translateAttribute(a: AttributeReference): Option[SExpression] = {
existenceJoins.get(a.exprId.id) match {
case Some(exists) => Some(exists)
case None => super.translateAttribute(a)
}
}
}
36 changes: 23 additions & 13 deletions spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala
@@ -86,19 +86,29 @@ trait SubstraitPlanTestBase { self: SharedSparkSession =>

def assertSqlSubstraitRelRoundTrip(query: String): LogicalPlan = {
// TODO need a more robust way of testing this than round-tripping.
val logicalPlan = plan(query)
val pojoRel = new ToSubstraitRel().visit(logicalPlan)
val converter = new ToLogicalPlan(spark = spark);
val logicalPlan2 = pojoRel.accept(converter);
require(logicalPlan2.resolved);
val pojoRel2 = new ToSubstraitRel().visit(logicalPlan2)

val extensionCollector = new ExtensionCollector;
val proto = new RelProtoConverter(extensionCollector).toProto(pojoRel)
new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto)

pojoRel2.shouldEqualPlainly(pojoRel)
logicalPlan2
val sparkPlan = plan(query)

// convert spark logical plan to substrait
val substraitPlan = new ToSubstraitRel().visit(sparkPlan)

// Serialize to protobuf byte array
val extensionCollector = new ExtensionCollector
val bytes = new RelProtoConverter(extensionCollector).toProto(substraitPlan).toByteArray

// Read it back
val protoPlan = io.substrait.proto.Rel.parseFrom(bytes)
val substraitPlan2 =
new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(protoPlan)

// convert substrait back to spark plan
val sparkPlan2 = substraitPlan2.accept(new ToLogicalPlan(spark))

// and back to substrait again
val substraitPlan3 = new ToSubstraitRel().visit(sparkPlan2)

// compare with original substrait plan to ensure it round-tripped (via proto bytes) correctly
substraitPlan3.shouldEqualPlainly(substraitPlan)
sparkPlan2
}

def plan(sql: String): LogicalPlan = {
3 changes: 1 addition & 2 deletions spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala
@@ -35,8 +35,7 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase {
val failingSQL: Set[String] = Set(
"q2", // because round() isn't defined in substrait to work with Decimal. https://github.com/substrait-io/substrait/pull/713
"q9", // requires implementation of named_struct()
"q10", "q35", "q45", // Unsupported join type ExistenceJoin (this is an internal spark type)
"q51", "q84", // TBD
"q35", "q51", "q84", // These fail when comparing the round-tripped query plans, but are actually equivalent (due to aliases being ignored by substrait)
"q72" //requires implementation of date_add()
)
// spotless:on
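
As a closing illustration (not part of this commit), here is a hedged sketch of how the round-trip helper above could exercise the new conversion. The suite name, view names and data are assumptions, and it presumes the usual Spark test scaffolding (SharedSparkSession) that SubstraitPlanTestBase already requires.

```scala
import io.substrait.spark.SubstraitPlanTestBase
import org.apache.spark.sql.test.SharedSparkSession

// Hypothetical test suite, for illustration only.
class ExistenceJoinRoundTripSuite extends SharedSparkSession with SubstraitPlanTestBase {

  test("IN-subquery under OR round-trips via a Substrait InPredicate") {
    // Register two small temp views; names and data are illustrative.
    spark.range(1, 100).selectExpr("id AS o_custkey", "id % 50 AS o_quantity")
      .createOrReplaceTempView("orders")
    spark.range(1, 10).selectExpr("id AS c_custkey")
      .createOrReplaceTempView("customers")

    // Catalyst plans the IN-subquery under OR as an ExistenceJoin; the helper
    // converts the plan to Substrait, serialises it to proto bytes, reads it
    // back, converts it to a Spark plan and back again, and compares the two
    // Substrait plans.
    assertSqlSubstraitRelRoundTrip(
      """SELECT * FROM orders
        |WHERE o_quantity > 10
        |   OR o_custkey IN (SELECT c_custkey FROM customers)""".stripMargin)
  }
}
```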