feat(spark): support ExistenceJoin internal join type
For certain filter expressions that embed subqueries, the Spark
optimiser replaces these with a Join relation of type 'ExistenceJoin'.
This internal join type does not map directly to any standard
SQL join type or to any Substrait join type.
To address this, it is converted to a Substrait
'InPredicate' within a filter condition.
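For example, this query (one of the new TPCH test cases below) is
optimised by catalyst into an ExistenceJoin; the plan is abbreviated
and the expression IDs are illustrative:

    select p_retailprice from part
    where p_partkey > 0 or p_comment in (select o_comment from orders)

    Filter (p_partkey > 0 OR exists#1)
    +- Join ExistenceJoin(exists#1), (p_comment = o_comment)
       :- Relation part
       +- Relation orders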

Signed-off-by: Andrew Coleman <andrew_coleman@uk.ibm.com>
andrew-coleman committed Mar 6, 2025
1 parent d8f22f7 commit 151d013
Showing 10 changed files with 204 additions and 46 deletions.
18 changes: 17 additions & 1 deletion spark/src/main/scala/io/substrait/debug/ExpressionToString.scala
@@ -21,7 +21,7 @@ import io.substrait.spark.DefaultExpressionVisitor
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 
 import io.substrait.expression.{Expression, FieldReference}
-import io.substrait.expression.Expression.{DateLiteral, DecimalLiteral, I32Literal, StrLiteral}
+import io.substrait.expression.Expression.{DateLiteral, DecimalLiteral, I32Literal, I64Literal, StrLiteral}
 import io.substrait.function.ToTypeString
 import io.substrait.util.DecimalUtil
 
@@ -43,6 +43,10 @@ class ExpressionToString extends DefaultExpressionVisitor[String] {
     expr.value().toString
   }
 
+  override def visit(expr: I64Literal): String = {
+    expr.value().toString
+  }
+
   override def visit(expr: DateLiteral): String = {
     DateTimeUtils.toJavaDate(expr.value()).toString
   }
@@ -76,4 +80,16 @@ class ExpressionToString extends DefaultExpressionVisitor[String] {
   override def visit(expr: Expression.EmptyMapLiteral): String = {
     expr.toString
   }
+
+  override def visit(expr: Expression.Cast): String = {
+    expr.getType.toString
+  }
+
+  override def visit(expr: Expression.InPredicate): String = {
+    expr.toString
+  }
+
+  override def visit(expr: Expression.ScalarSubquery): String = {
+    expr.toString
+  }
 }
20 changes: 20 additions & 0 deletions spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala
@@ -165,6 +165,26 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] {
     })
   }
 
+  override def visit(set: Set): String = {
+    withBuilder(set, 8)(
+      builder => {
+        builder
+          .append("operation=")
+          .append(set.getSetOp)
+      })
+  }
+
+  override def visit(cross: Cross): String = {
+    withBuilder(cross, 10)(
+      builder => {
+        builder
+          .append("left=")
+          .append(cross.getLeft)
+          .append("right=")
+          .append(cross.getRight)
+      })
+  }
+
   override def visit(localFiles: LocalFiles): String = {
     withBuilder(localFiles, 10)(
       builder => {
7 changes: 7 additions & 0 deletions spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
@@ -68,6 +68,13 @@ private class ToSparkType
 
   override def visit(expr: Type.IntervalYear): DataType = YearMonthIntervalType.DEFAULT
 
+  override def visit(expr: Type.Struct): DataType = {
+    StructType(
+      expr.fields.asScala.zipWithIndex
+        .map { case (t, i) => StructField(s"col${i + 1}", t.accept(this), t.nullable()) }
+    )
+  }
+
   override def visit(expr: Type.ListType): DataType =
     ArrayType(expr.elementType().accept(this), containsNull = expr.elementType().nullable())
 
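Since a Substrait Type.Struct carries only field types (field names live in the top-level NamedStruct), the new visit method synthesises positional column names. For example, a two-field struct becomes the following Spark schema (a sketch, assuming both fields are nullable):

  import org.apache.spark.sql.types._

  // substrait: struct<string, i32>  ->  spark schema with generated names
  StructType(Seq(
    StructField("col1", StringType, nullable = true),
    StructField("col2", IntegerType, nullable = true)))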
spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
@@ -19,9 +19,11 @@ package io.substrait.spark.expression
 import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, SparkExtension, ToSubstraitType}
 import io.substrait.spark.logical.ToLogicalPlan
 
-import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, CaseWhen, Cast, EqualTo, Expression, In, InSubquery, ListQuery, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
+import org.apache.spark.sql.catalyst.plans.ExistenceJoin
+import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint, LogicalPlan}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DateType, Decimal}
+import org.apache.spark.sql.types.{BooleanType, DateType, Decimal}
 import org.apache.spark.substrait.SparkTypeUtil
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -204,6 +206,14 @@ class ToSparkExpression(
     In(value, list)
   }
 
+  override def visit(expr: SExpression.InPredicate): Expression = {
+    val needles = expr.needles().asScala.map(e => e.accept(this))
+    val haystack = expr.haystack().accept(toLogicalPlan.get)
+    new InSubquery(needles, ListQuery(haystack, childOutputs = haystack.output)) {
+      override def nullable: Boolean = expr.getType.nullable()
+    }
+  }
+
   override def visit(expr: SExpression.ScalarFunctionInvocation): Expression = {
     val eArgs = expr.arguments().asScala
     val args = eArgs.zipWithIndex.map {
spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala
@@ -65,6 +65,8 @@ abstract class ToSubstraitExpression extends HasOutputStack[Seq[Attribute]] {
 
   protected def translateSubQuery(expr: PlanExpression[_]): Option[SExpression] = default(expr)
 
+  protected def translateInSubquery(expr: InSubquery): Option[SExpression] = default(expr)
+
   protected def translateAttribute(a: AttributeReference): Option[SExpression] = {
     val bindReference =
       BindReferences.bindReference[Expression](a, currentOutput, allowFailures = false)
@@ -141,10 +143,6 @@ abstract class ToSubstraitExpression extends HasOutputStack[Seq[Attribute]] {
       case SubstraitLiteral(substraitLiteral) => Some(substraitLiteral)
      case a: AttributeReference if currentOutput.nonEmpty => translateAttribute(a)
       case a: Alias => translateUp(a.child)
-      case p
-          if p.getClass.getCanonicalName.equals( // removed in spark-3.3
-            "org.apache.spark.sql.catalyst.expressions.PromotePrecision") =>
-        translateUp(p.children.head)
       case CaseWhen(branches, elseValue) => translateCaseWhen(branches, elseValue)
       case In(value, list) => translateIn(value, list)
       case InSet(value, set) => translateIn(value, set.toSeq.map(v => Literal(v)))
@@ -153,6 +151,7 @@ abstract class ToSubstraitExpression extends HasOutputStack[Seq[Attribute]] {
         .seqToOption(children.map(translateUp))
         .flatMap(toScalarFunction.convert(scalar, _))
       case p: PlanExpression[_] => translateSubQuery(p)
+      case in: InSubquery => translateInSubquery(in)
       case other => default(other)
     }
   }
spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
@@ -218,7 +218,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
 
   override def visit(fetch: relation.Fetch): LogicalPlan = {
     val child = fetch.getInput.accept(this)
-    val limit = fetch.getCount.getAsLong.intValue()
+    val limit = fetch.getCount.orElse(-1).intValue() // -1 means unassigned here
     val offset = fetch.getOffset.intValue()
     val toLiteral = (i: Int) => Literal(i, IntegerType)
     if (limit >= 0) {
107 changes: 86 additions & 21 deletions spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala
@@ -43,6 +43,7 @@ import io.substrait.relation.RelProtoConverter
 import io.substrait.relation.Set.SetOp
 import io.substrait.relation.files.{FileFormat, ImmutableFileOrFiles}
 import io.substrait.relation.files.FileOrFiles.PathType
+import io.substrait.utils.Util
 
 import java.util.{Collections, Optional}
 
@@ -247,8 +248,9 @@
   }
 
   override def visitFilter(p: Filter): relation.Rel = {
+    val input = visit(p.child)
     val condition = toExpression(p.child.output)(p.condition)
-    relation.Filter.builder().condition(condition).input(visit(p.child)).build()
+    relation.Filter.builder().condition(condition).input(input).build()
   }
 
   private def toSubstraitJoin(joinType: JoinType): relation.Join.JoinType = joinType match {
@@ -262,33 +264,73 @@
   }
 
   override def visitJoin(p: Join): relation.Rel = {
-    val left = visit(p.left)
-    val right = visit(p.right)
-    val condition = p.condition.map(toExpression(p.left.output ++ p.right.output)).getOrElse(TRUE)
-    val joinType = toSubstraitJoin(p.joinType)
-
-    if (joinType == relation.Join.JoinType.INNER && TRUE == condition) {
-      relation.Cross.builder
-        .left(left)
-        .right(right)
-        .build
-    } else {
-      relation.Join.builder
-        .condition(condition)
-        .joinType(joinType)
-        .left(left)
-        .right(right)
-        .build
-    }
+    p match {
+      case Join(left, right, ExistenceJoin(exists), where, _) =>
+        convertExistenceJoin(left, right, where, exists)
+      case _ =>
+        val left = visit(p.left)
+        val right = visit(p.right)
+        val condition =
+          p.condition.map(toExpression(p.left.output ++ p.right.output)).getOrElse(TRUE)
+        val joinType = toSubstraitJoin(p.joinType)
+
+        if (joinType == relation.Join.JoinType.INNER && TRUE == condition) {
+          relation.Cross.builder
+            .left(left)
+            .right(right)
+            .build
+        } else {
+          relation.Join.builder
+            .condition(condition)
+            .joinType(joinType)
+            .left(left)
+            .right(right)
+            .build
+        }
+    }
   }
 
+  private def convertExistenceJoin(
+      left: LogicalPlan,
+      right: LogicalPlan,
+      condition: Option[Expression],
+      exists: Attribute): relation.Rel = {
+    // An ExistenceJoin is an internal spark join type that is injected by the catalyst
+    // optimiser. It doesn't directly map to any SQL join type, but can be modelled using
+    // a Substrait `InPredicate` within a Filter condition.
+
+    // The 'exists' attribute in the parent filter condition will be associated with this.
+
+    // extract the needle expressions from the join condition
+    def findNeedles(expr: Expression): Iterable[Expression] = expr match {
+      case And(lhs, rhs) => findNeedles(lhs) ++ findNeedles(rhs)
+      case EqualTo(lhs, _) => Seq(lhs)
+      case _ =>
+        throw new UnsupportedOperationException(
+          s"Unable to convert the ExistenceJoin condition: $expr")
+    }
+    val needles = condition.toIterable
+      .flatMap(w => findNeedles(w))
+      .map(toExpression(left.output))
+      .asJava
+
+    val haystack = visit(right)
+
+    val inPredicate =
+      SExpression.InPredicate.builder().needles(needles).haystack(haystack).build()
+    // put this in a map exists->inPredicate for later lookup
+    toSubstraitExp.existenceJoins.put(exists.exprId.id, inPredicate)
+    // return the left child, which will become the child of the enclosing filter
+    visit(left)
+  }
+
   override def visitProject(p: Project): relation.Rel = {
     val expressions = p.projectList.map(toExpression(p.child.output)).toList
 
+    val child = visit(p.child)
     relation.Project.builder
-      .remap(relation.Rel.Remap.offset(p.child.output.size, expressions.size))
+      .remap(relation.Rel.Remap.offset(child.getRecordType.fields.size, expressions.size))
       .expressions(expressions.asJava)
-      .input(visit(p.child))
+      .input(child)
       .build()
   }
 
@@ -492,6 +534,8 @@ private[logical] class WithLogicalSubQuery(toSubstraitRel: ToSubstraitRel)
   override protected val toScalarFunction: ToScalarFunction =
     ToScalarFunction(SparkExtension.SparkScalarFunctions)
 
+  val existenceJoins = scala.collection.mutable.Map[Long, SExpression.InPredicate]()
+
   override protected def translateSubQuery(expr: PlanExpression[_]): Option[SExpression] = {
     expr match {
       case s: ScalarSubquery if s.outerAttrs.isEmpty && s.joinCond.isEmpty =>
@@ -504,4 +548,25 @@
       case other => default(other)
     }
   }
+
+  override protected def translateInSubquery(expr: InSubquery): Option[SExpression] = {
+    Util
+      .seqToOption(expr.values.map(translateUp).toList)
+      .flatMap(
+        values =>
+          Some(
+            SExpression.InPredicate
+              .builder()
+              .needles(values.asJava)
+              .haystack(toSubstraitRel.visit(expr.query.plan))
+              .build()
+          ))
+  }
+
+  override protected def translateAttribute(a: AttributeReference): Option[SExpression] = {
+    existenceJoins.get(a.exprId.id) match {
+      case Some(exists) => Some(exists)
+      case None => super.translateAttribute(a)
+    }
+  }
 }
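Taken together, convertExistenceJoin and the translateAttribute override replace the internal join with a plain filter over the left input. For the example query from the commit message, the resulting Substrait tree looks roughly like this (a sketch, not actual debug output):

  Filter: or(gt(p_partkey, 0), InPredicate(needles = [p_comment], haystack = Read(orders)))
  +- Read(part)

The reordering in visitFilter above is what makes the lookup work: the child is now visited before the filter condition is translated, so convertExistenceJoin has already populated the existenceJoins map by the time translateAttribute needs to resolve the exists attribute.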
37 changes: 24 additions & 13 deletions spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala
@@ -86,19 +86,30 @@
 
   def assertSqlSubstraitRelRoundTrip(query: String): LogicalPlan = {
     // TODO need a more robust way of testing this than round-tripping.
-    val logicalPlan = plan(query)
-    val pojoRel = new ToSubstraitRel().visit(logicalPlan)
-    val converter = new ToLogicalPlan(spark = spark);
-    val logicalPlan2 = pojoRel.accept(converter);
-    require(logicalPlan2.resolved);
-    val pojoRel2 = new ToSubstraitRel().visit(logicalPlan2)
-
-    val extensionCollector = new ExtensionCollector;
-    val proto = new RelProtoConverter(extensionCollector).toProto(pojoRel)
-    new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto)
-
-    pojoRel2.shouldEqualPlainly(pojoRel)
-    logicalPlan2
+    val sparkPlan = plan(query)
+
+    // convert spark logical plan to substrait
+    val substraitPlan = new ToSubstraitRel().visit(sparkPlan)
+
+    // Serialize to protobuf byte array
+    val extensionCollector = new ExtensionCollector
+    val bytes = new RelProtoConverter(extensionCollector).toProto(substraitPlan).toByteArray
+
+    // Read it back
+    val protoPlan = io.substrait.proto.Rel.parseFrom(bytes)
+    val substraitPlan2 =
+      new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(protoPlan)
+
+    // convert substrait back to spark plan
+    val sparkPlan2 = substraitPlan2.accept(new ToLogicalPlan(spark))
+    require(sparkPlan2.resolved)
+
+    // and back to substrait again
+    val substraitPlan3 = new ToSubstraitRel().visit(sparkPlan2)
+
+    // compare with original substrait plan to ensure it round-tripped (via proto bytes) correctly
+    substraitPlan3.shouldEqualPlainly(substraitPlan)
+    sparkPlan2
   }
 
   def plan(sql: String): LogicalPlan = {
3 changes: 1 addition & 2 deletions spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala
@@ -35,8 +35,7 @@
   val failingSQL: Set[String] = Set(
     "q2", // because round() isn't defined in substrait to work with Decimal. https://github.com/substrait-io/substrait/pull/713
     "q9", // requires implementation of named_struct()
-    "q10", "q35", "q45", // Unsupported join type ExistenceJoin (this is an internal spark type)
-    "q51", "q84", // TBD
+    "q35", "q51", "q84", // These fail when comparing the round-tripped query plans, but are actually equivalent (due to aliases being ignored by substrait)
     "q72" //requires implementation of date_add()
   )
   // spotless:on
35 changes: 33 additions & 2 deletions spark/src/test/scala/io/substrait/spark/TPCHPlan.scala
@@ -16,8 +16,6 @@
  */
 package io.substrait.spark
 
-import io.substrait.spark.logical.{ToLogicalPlan, ToSubstraitRel}
-
 import org.apache.spark.sql.TPCHBase
 
 class TPCHPlan extends TPCHBase with SubstraitPlanTestBase {
@@ -178,6 +176,39 @@
       "((l_orderkey, L_COMMITDATE), (l_orderkey, L_COMMITDATE, l_linestatus), l_shipdate, ())")
   }
 
+  test("in_subquery") {
+    assertSqlSubstraitRelRoundTrip(
+      "select p_retailprice from part where p_comment in (select o_comment from orders)")
+  }
+
+  test("exists") {
+    assertSqlSubstraitRelRoundTrip(
+      "select p_retailprice from part where exists (select o_comment from orders)")
+  }
+
+  test("ExistenceJoin") {
+    assertSqlSubstraitRelRoundTrip(
+      "select p_retailprice from part where p_partkey > 0 or p_comment in (select o_comment from orders)")
+
+    assertSqlSubstraitRelRoundTrip("select p_retailprice from part where p_partkey > 0 or " +
+      "(p_comment in (select o_comment from orders) or p_comment in (select r_comment from region))")
+
+    assertSqlSubstraitRelRoundTrip(
+      "select p_retailprice, sum(p_retailprice) from part " +
+        "where p_partkey > 0 or p_comment in (select o_comment from orders) " +
+        "group by p_retailprice")
+
+    assertSqlSubstraitRelRoundTrip(
+      "select p_retailprice from part where p_partkey > 0 or " +
+        "(p_comment, p_retailprice) in (select o_comment, o_totalprice from orders)"
+    )
+
+    assertSqlSubstraitRelRoundTrip(
+      "select p_retailprice from part where p_partkey > 0 or " +
+        "(upper(p_comment), p_retailprice * 1.2) in (select upper(o_comment), o_totalprice from orders)"
+    )
+  }
+
   test("tpch_q1_variant") {
     // difference from tpch_q1 : 1) remove order by clause; 2) remove interval date literal
     assertSqlSubstraitRelRoundTrip(
