From ea009b1524842af57430f9079bd58dcae29e9a30 Mon Sep 17 00:00:00 2001
From: Andrew Coleman
Date: Fri, 14 Feb 2025 15:39:16 +0000
Subject: [PATCH] feat(spark): support ExistenceJoin internal join type
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

For certain filter expressions that embed subqueries, the Spark optimiser
replaces these with a Join relation of type ‘ExistenceJoin’. This internal
join type does not map directly to any standard SQL join type or Substrait
join type. To address this, it needs to be converted to a Substrait
‘InPredicate’ within a filter condition.

Signed-off-by: Andrew Coleman
---
 .../substrait/debug/ExpressionToString.scala  | 18 +++-
 .../substrait/debug/RelToVerboseString.scala  | 20 ++++
 .../io/substrait/spark/ToSubstraitType.scala  |  6 ++
 .../spark/expression/ToSparkExpression.scala  | 38 +++++++-
 .../spark/logical/ToLogicalPlan.scala         | 31 ++++++-
 .../spark/logical/ToSubstraitRel.scala        | 92 ++++++++++++++-----
 .../spark/SubstraitPlanTestBase.scala         | 36 +++++---
 .../scala/io/substrait/spark/TPCDSPlan.scala  |  3 +-
 .../scala/io/substrait/spark/TPCHPlan.scala   | 35 ++++++-
 9 files changed, 234 insertions(+), 45 deletions(-)

diff --git a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala
index 29c951363..2d4d3f833 100644
--- a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala
+++ b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala
@@ -21,7 +21,7 @@ import io.substrait.spark.DefaultExpressionVisitor
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 
 import io.substrait.expression.{Expression, FieldReference}
-import io.substrait.expression.Expression.{DateLiteral, DecimalLiteral, I32Literal, StrLiteral}
+import io.substrait.expression.Expression.{DateLiteral, DecimalLiteral, I32Literal, I64Literal, StrLiteral}
 import io.substrait.function.ToTypeString
 import io.substrait.util.DecimalUtil
 
@@ -43,6 +43,10 @@ class ExpressionToString extends DefaultExpressionVisitor[String] {
     expr.value().toString
   }
 
+  override def visit(expr: I64Literal): String = {
+    expr.value().toString
+  }
+
   override def visit(expr: DateLiteral): String = {
     DateTimeUtils.toJavaDate(expr.value()).toString
   }
@@ -76,4 +80,16 @@ class ExpressionToString extends DefaultExpressionVisitor[String] {
   override def visit(expr: Expression.EmptyMapLiteral): String = {
     expr.toString
   }
+
+  override def visit(expr: Expression.Cast): String = {
+    expr.getType.toString
+  }
+
+  override def visit(expr: Expression.InPredicate): String = {
+    expr.toString
+  }
+
+  override def visit(expr: Expression.ScalarSubquery): String = {
+    expr.toString
+  }
 }
diff --git a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala
index 79b34462f..efaa3f8bf 100644
--- a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala
+++ b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala
@@ -165,6 +165,26 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] {
     })
   }
 
+  override def visit(set: Set): String = {
+    withBuilder(set, 8)(
+      builder => {
+        builder
+          .append("operation=")
+          .append(set.getSetOp)
+      })
+  }
+
+  override def visit(cross: Cross): String = {
+    withBuilder(cross, 10)(
+      builder => {
+        builder
+          .append("left=")
+          .append(cross.getLeft)
+          .append("right=")
+          .append(cross.getRight)
+      })
+  }
+
   override def visit(localFiles: LocalFiles): String = {
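// [editor's sketch, not part of the patch] A minimal, self-contained illustration of the query
// shape that makes Catalyst introduce an ExistenceJoin: an IN-subquery under an OR cannot be
// rewritten to a plain semi join, so the optimiser plans Join(..., ExistenceJoin(exists#N), ...)
// with a Filter above it. The temp views below are hypothetical stand-ins for the TPC-H tables
// used by the tests at the end of this patch.
import org.apache.spark.sql.SparkSession

object ExistenceJoinDemo extends App {
  val spark = SparkSession.builder().master("local[*]").appName("existence-join-demo").getOrCreate()
  spark.sql("CREATE OR REPLACE TEMP VIEW part AS SELECT 1 AS p_partkey, 'c' AS p_comment, 1.0 AS p_retailprice")
  spark.sql("CREATE OR REPLACE TEMP VIEW orders AS SELECT 'c' AS o_comment")
  val df = spark.sql(
    "SELECT p_retailprice FROM part " +
      "WHERE p_partkey > 0 OR p_comment IN (SELECT o_comment FROM orders)")
  // the optimised plan contains an ExistenceJoin node, which this patch maps to a
  // Substrait InPredicate inside the Filter condition
  println(df.queryExecution.optimizedPlan.treeString)
  spark.stop()
}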
withBuilder(localFiles, 10)( builder => { diff --git a/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala index 951900cd8..d16034852 100644 --- a/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala +++ b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala @@ -68,6 +68,12 @@ private class ToSparkType override def visit(expr: Type.IntervalYear): DataType = YearMonthIntervalType.DEFAULT + override def visit(expr: Type.Struct): DataType = { + StructType( + expr.fields.asScala.map(f => StructField(f.toString, f.accept(this), f.nullable())) + ) + } + override def visit(expr: Type.ListType): DataType = ArrayType(expr.elementType().accept(this), containsNull = expr.elementType().nullable()) diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala index a91b8e02c..f2d598f64 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala @@ -19,9 +19,11 @@ package io.substrait.spark.expression import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, SparkExtension, ToSubstraitType} import io.substrait.spark.logical.ToLogicalPlan -import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, CaseWhen, Cast, EqualTo, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery} +import org.apache.spark.sql.catalyst.plans.ExistenceJoin +import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint, LogicalPlan} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DateType, Decimal} +import org.apache.spark.sql.types.{BooleanType, DateType, Decimal} import org.apache.spark.substrait.SparkTypeUtil import org.apache.spark.unsafe.types.UTF8String @@ -39,6 +41,12 @@ class ToSparkExpression( extends DefaultExpressionVisitor[Expression] with HasOutputStack[Seq[NamedExpression]] { + private val existenceJoins = scala.collection.mutable.Map[AttributeReference, LogicalPlan]() + + def lookupExists(attr: AttributeReference): Option[LogicalPlan] = { + existenceJoins.get(attr) + } + override def visit(expr: SExpression.BoolLiteral): Expression = { if (expr.value()) { Literal.TrueLiteral @@ -198,6 +206,25 @@ class ToSparkExpression( .getOrElse(visitFallback(expr)) } + private def existenceJoin(expr: SExpression.InPredicate): AttributeReference = { + val dataType = ToSubstraitType.convert(expr.getType) + require(dataType == BooleanType) + val needles = expr.needles().asScala.map(e => e.accept(this)) + val haystack = expr.haystack().accept(toLogicalPlan.get) + val exists = AttributeReference("exists", BooleanType, nullable = false)() + val joinType = ExistenceJoin(exists) + require( + needles.size == haystack.output.size, + "The number of needles must match the number of haystack outputs") + val condition = needles + .zip(haystack.output) + .map { case (needle, hay) => EqualTo(needle, hay) } + .reduce[Expression] { case (lhs, rhs) => And(lhs, rhs) } + val join = Join(null, haystack, joinType, Some(condition), JoinHint.NONE) + existenceJoins.put(exists, join) + exists + } + override def visit(expr: SExpression.SingleOrList): Expression = { val value = expr.condition().accept(this) val list = 
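// [editor's sketch, assumed standalone inputs] The condition built by existenceJoin() above pairs
// each needle with the corresponding haystack output column and ANDs the equalities together,
// e.g. (n1 = h1) AND (n2 = h2). The attribute names below are illustrative only.
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression}
import org.apache.spark.sql.types.StringType

object InPredicateConditionSketch {
  val needles: Seq[Expression] = Seq(
    AttributeReference("p_comment", StringType)(),
    AttributeReference("p_name", StringType)())
  val haystackOutput: Seq[AttributeReference] = Seq(
    AttributeReference("o_comment", StringType)(),
    AttributeReference("o_name", StringType)())

  // same zip/reduce shape as the patch; requires needles.size == haystackOutput.size
  val condition: Expression = needles
    .zip(haystackOutput)
    .map { case (needle, hay) => EqualTo(needle, hay) }
    .reduce[Expression] { case (lhs, rhs) => And(lhs, rhs) }
  // condition is ((p_comment = o_comment) AND (p_name = o_name))
}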
expr.options().asScala.map(e => e.accept(this)) @@ -208,7 +235,12 @@ class ToSparkExpression( val eArgs = expr.arguments().asScala val args = eArgs.zipWithIndex.map { case (arg, i) => - arg.accept(expr.declaration(), i, this) + arg match { + case ip: SExpression.InPredicate => + existenceJoin(ip) + case _ => + arg.accept(expr.declaration(), i, this) + } } expr.declaration.name match { diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala index 058fbe393..659635667 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.QueryExecution @@ -46,6 +46,7 @@ import io.substrait.relation.files.FileFormat import org.apache.hadoop.fs.Path import scala.collection.JavaConverters.asScalaBufferConverter +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer /** @@ -218,7 +219,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] override def visit(fetch: relation.Fetch): LogicalPlan = { val child = fetch.getInput.accept(this) - val limit = fetch.getCount.getAsLong.intValue() + val limit = fetch.getCount.orElse(-1).intValue() // -1 means unassigned here val offset = fetch.getOffset.intValue() val toLiteral = (i: Int) => Literal(i, IntegerType) if (limit >= 0) { @@ -291,13 +292,37 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] } override def visit(filter: relation.Filter): LogicalPlan = { - val child = filter.getInput.accept(this) + var child = filter.getInput.accept(this) withChild(child) { val condition = filter.getCondition.accept(expressionConverter) + val existenceJoins = mutable.ListBuffer[AttributeReference]() + findExistenceJoins(condition, existenceJoins) + existenceJoins.foreach { + exists => + expressionConverter.lookupExists(exists) match { + case Some(Join(_, right, joinType, condition, hint)) => + // the filter child now becomes the join's left child + // and the join becomes the filter's child. + child = Join(child, right, joinType, condition, hint) + case _ => // should not get here! 
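// [editor's sketch] The foreach loop in visit(filter) above grafts each deferred ExistenceJoin
// between the Filter and its original child. The helper below is a hypothetical standalone
// restatement of that step ('attachExistenceJoins' and 'pending' are not names from the patch):
// Filter(cond, child) effectively becomes Filter(cond, Join(child, right, ExistenceJoin(exists), joinCond, hint)).
import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan}

def attachExistenceJoins(child: LogicalPlan, pending: Seq[Join]): LogicalPlan =
  pending.foldLeft(child) {
    // each recorded Join was built with a null left child; the current plan takes that place
    case (current, Join(_, right, joinType, condition, hint)) =>
      Join(current, right, joinType, condition, hint)
  }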
+ throw new UnsupportedOperationException( + s"Expression not currently supported: $condition") + } + } Filter(condition, child) } } + private def findExistenceJoins( + expression: Expression, + attributes: mutable.ListBuffer[AttributeReference]): Unit = { + expression.children.foreach { + case attr: AttributeReference if expressionConverter.lookupExists(attr).isDefined => + attributes.append(attr) + case child => findExistenceJoins(child, attributes) + } + } + override def visit(set: relation.Set): LogicalPlan = { val children = set.getInputs.asScala.map(_.accept(this)) withOutput(children.flatMap(_.output)) { diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index f00ebba53..4d1d2516f 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -247,8 +247,9 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { } override def visitFilter(p: Filter): relation.Rel = { + val input = visit(p.child) val condition = toExpression(p.child.output)(p.condition) - relation.Filter.builder().condition(condition).input(visit(p.child)).build() + relation.Filter.builder().condition(condition).input(input).build() } private def toSubstraitJoin(joinType: JoinType): relation.Join.JoinType = joinType match { @@ -262,33 +263,73 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { } override def visitJoin(p: Join): relation.Rel = { - val left = visit(p.left) - val right = visit(p.right) - val condition = p.condition.map(toExpression(p.left.output ++ p.right.output)).getOrElse(TRUE) - val joinType = toSubstraitJoin(p.joinType) - - if (joinType == relation.Join.JoinType.INNER && TRUE == condition) { - relation.Cross.builder - .left(left) - .right(right) - .build - } else { - relation.Join.builder - .condition(condition) - .joinType(joinType) - .left(left) - .right(right) - .build + p match { + case Join(left, right, ExistenceJoin(exists), where, _) => + convertExistenceJoin(left, right, where, exists) + case _ => + val left = visit(p.left) + val right = visit(p.right) + val condition = + p.condition.map(toExpression(p.left.output ++ p.right.output)).getOrElse(TRUE) + val joinType = toSubstraitJoin(p.joinType) + + if (joinType == relation.Join.JoinType.INNER && TRUE == condition) { + relation.Cross.builder + .left(left) + .right(right) + .build + } else { + relation.Join.builder + .condition(condition) + .joinType(joinType) + .left(left) + .right(right) + .build + } } } + private def convertExistenceJoin( + left: LogicalPlan, + right: LogicalPlan, + condition: Option[Expression], + exists: Attribute): relation.Rel = { + // An ExistenceJoin is an internal spark join type that is injected by the catalyst + // optimiser. It doesn't directly map to any SQL join type, but can be modelled using + // a Substrait `InPredicate` within a Filter condition. + + // The 'exists' attribute in the parent filter condition will be associated with this. 
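// [editor's illustration] For the query
//   select p_retailprice from part where p_partkey > 0 or p_comment in (select o_comment from orders)
// the optimised Spark plan and the intended Substrait output look roughly like:
//
//   Spark:     Project [p_retailprice]
//              +- Filter (p_partkey > 0 OR exists#1)
//                 +- Join ExistenceJoin(exists#1), (p_comment = o_comment)
//                    :- Relation part
//                    +- Relation orders
//
//   Substrait: Project [p_retailprice]
//              +- Filter or(gt(p_partkey, 0), in_predicate(needles: [p_comment], haystack: read(orders)))
//                 +- Read part
//
// The ExistenceJoin node is dropped entirely: its right side survives only as the InPredicate
// haystack, and the exists#1 reference in the Filter condition is replaced by the InPredicate.
// This is also why visitProject (further down) now derives the remap offset from the converted
// child's field count rather than from p.child.output.size: the Spark child may still advertise
// the synthetic exists#1 column while the Substrait child does not. As a hypothetical example,
// a converted child with 3 fields and 2 project expressions yields
val remap = io.substrait.relation.Rel.Remap.offset(3, 2)
// which emits fields 3 and 4 of the 5-field (3 input + 2 expression) intermediate record,
// i.e. exactly the two new expressions, regardless of the Spark-side width.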
+ + // extract the needle expressions from the join condition + def findNeedles(expr: Expression): Iterable[Expression] = expr match { + case And(lhs, rhs) => findNeedles(lhs) ++ findNeedles(rhs) + case EqualTo(lhs, _) => Seq(lhs) + case _ => + throw new UnsupportedOperationException( + s"Unable to convert the ExistenceJoin condition: $expr") + } + val needles = condition.toIterable + .flatMap(w => findNeedles(w)) + .map(toExpression(left.output)) + .asJava + + val haystack = visit(right) + + val inPredicate = + SExpression.InPredicate.builder().needles(needles).haystack(haystack).build() + // put this in a map exists->inPredicate for later lookup + toSubstraitExp.existenceJoins.put(exists.exprId.id, inPredicate) + // return the left child, which will become the child of the enclosing filter + visit(left) + } + override def visitProject(p: Project): relation.Rel = { val expressions = p.projectList.map(toExpression(p.child.output)).toList - + val child = visit(p.child) relation.Project.builder - .remap(relation.Rel.Remap.offset(p.child.output.size, expressions.size)) + .remap(relation.Rel.Remap.offset(child.getRecordType.fields.size, expressions.size)) .expressions(expressions.asJava) - .input(visit(p.child)) + .input(child) .build() } @@ -492,6 +533,8 @@ private[logical] class WithLogicalSubQuery(toSubstraitRel: ToSubstraitRel) override protected val toScalarFunction: ToScalarFunction = ToScalarFunction(SparkExtension.SparkScalarFunctions) + val existenceJoins = scala.collection.mutable.Map[Long, SExpression.InPredicate]() + override protected def translateSubQuery(expr: PlanExpression[_]): Option[SExpression] = { expr match { case s: ScalarSubquery if s.outerAttrs.isEmpty && s.joinCond.isEmpty => @@ -504,4 +547,11 @@ private[logical] class WithLogicalSubQuery(toSubstraitRel: ToSubstraitRel) case other => default(other) } } + + override protected def translateAttribute(a: AttributeReference): Option[SExpression] = { + existenceJoins.get(a.exprId.id) match { + case Some(exists) => Some(exists) + case None => super.translateAttribute(a) + } + } } diff --git a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala index cbd7a151c..d026a1902 100644 --- a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala +++ b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala @@ -86,19 +86,29 @@ trait SubstraitPlanTestBase { self: SharedSparkSession => def assertSqlSubstraitRelRoundTrip(query: String): LogicalPlan = { // TODO need a more robust way of testing this than round-tripping. 
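// [editor's sketch] The rewritten test body below now serialises the Substrait plan to protobuf
// bytes and parses them back before converting to Spark again, so the byte-level encoding is
// exercised too. Its essence as a standalone helper (import paths assumed from substrait-java):
import io.substrait.extension.ExtensionCollector
import io.substrait.relation.{ProtoRelConverter, Rel, RelProtoConverter}

def roundTripThroughProto(rel: Rel): Rel = {
  val extensions = new ExtensionCollector
  // POJO -> proto -> bytes, as a producer would emit the plan
  val bytes = new RelProtoConverter(extensions).toProto(rel).toByteArray
  // bytes -> proto -> POJO, as a consumer would read it back
  val proto = io.substrait.proto.Rel.parseFrom(bytes)
  new ProtoRelConverter(extensions, SparkExtension.COLLECTION).from(proto)
}
// The ExistenceJoin tests added in TPCHPlan.scala below exercise this path via
// assertSqlSubstraitRelRoundTrip, which additionally converts back to a Spark plan and
// re-encodes it to confirm the two Substrait plans are equal.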
- val logicalPlan = plan(query) - val pojoRel = new ToSubstraitRel().visit(logicalPlan) - val converter = new ToLogicalPlan(spark = spark); - val logicalPlan2 = pojoRel.accept(converter); - require(logicalPlan2.resolved); - val pojoRel2 = new ToSubstraitRel().visit(logicalPlan2) - - val extensionCollector = new ExtensionCollector; - val proto = new RelProtoConverter(extensionCollector).toProto(pojoRel) - new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto) - - pojoRel2.shouldEqualPlainly(pojoRel) - logicalPlan2 + val sparkPlan = plan(query) + + // convert spark logical plan to substrait + val substraitPlan = new ToSubstraitRel().visit(sparkPlan) + + // Serialize to protobuf byte array + val extensionCollector = new ExtensionCollector + val bytes = new RelProtoConverter(extensionCollector).toProto(substraitPlan).toByteArray + + // Read it back + val protoPlan = io.substrait.proto.Rel.parseFrom(bytes) + val substraitPlan2 = + new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(protoPlan) + + // convert substrait back to spark plan + val sparkPlan2 = substraitPlan2.accept(new ToLogicalPlan(spark)) + + // and back to substrait again + val substraitPlan3 = new ToSubstraitRel().visit(sparkPlan2) + + // compare with original substrait plan to ensure it round-tripped (via proto bytes) correctly + substraitPlan3.shouldEqualPlainly(substraitPlan) + sparkPlan2 } def plan(sql: String): LogicalPlan = { diff --git a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala index 048fec0ed..f0de6c122 100644 --- a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala +++ b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala @@ -35,8 +35,7 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase { val failingSQL: Set[String] = Set( "q2", // because round() isn't defined in substrait to work with Decimal. 
https://github.com/substrait-io/substrait/pull/713 "q9", // requires implementation of named_struct() - "q10", "q35", "q45", // Unsupported join type ExistenceJoin (this is an internal spark type) - "q51", "q84", // TBD + "q35", "q51", "q84", // These fail when comparing the round-tripped query plans, but are actually equivalent (due to aliases being ignored by substrait) "q72" //requires implementation of date_add() ) // spotless:on diff --git a/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala index c96cd5531..038f1ef3e 100644 --- a/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala +++ b/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala @@ -16,8 +16,6 @@ */ package io.substrait.spark -import io.substrait.spark.logical.{ToLogicalPlan, ToSubstraitRel} - import org.apache.spark.sql.TPCHBase class TPCHPlan extends TPCHBase with SubstraitPlanTestBase { @@ -178,6 +176,39 @@ class TPCHPlan extends TPCHBase with SubstraitPlanTestBase { "((l_orderkey, L_COMMITDATE), (l_orderkey, L_COMMITDATE, l_linestatus), l_shipdate, ())") } + test("in_subquery") { + assertSqlSubstraitRelRoundTrip( + "select p_retailprice from part where p_comment in (select o_comment from orders)") + } + + test("exists") { + assertSqlSubstraitRelRoundTrip( + "select p_retailprice from part where exists (select o_comment from orders)") + } + + test("ExistenceJoin") { + assertSqlSubstraitRelRoundTrip( + "select p_retailprice from part where p_partkey > 0 or p_comment in (select o_comment from orders)") + + assertSqlSubstraitRelRoundTrip("select p_retailprice from part where p_partkey > 0 or " + + "(p_comment in (select o_comment from orders) or p_comment in (select r_comment from region))") + + assertSqlSubstraitRelRoundTrip( + "select p_retailprice, sum(p_retailprice) from part " + + "where p_partkey > 0 or p_comment in (select o_comment from orders) " + + "group by p_retailprice") + + assertSqlSubstraitRelRoundTrip( + "select p_retailprice from part where p_partkey > 0 or " + + "(p_comment, p_retailprice) in (select o_comment, o_totalprice from orders)" + ) + + assertSqlSubstraitRelRoundTrip( + "select p_retailprice from part where p_partkey > 0 or " + + "(upper(p_comment), p_retailprice * 1.2) in (select upper(o_comment), o_totalprice from orders)" + ) + } + test("tpch_q1_variant") { // difference from tpch_q1 : 1) remove order by clause; 2) remove interval date literal assertSqlSubstraitRelRoundTrip(