From 151d01354ff3def0c0bc3d3156d0c21d1f02177b Mon Sep 17 00:00:00 2001
From: Andrew Coleman
Date: Fri, 14 Feb 2025 15:39:16 +0000
Subject: [PATCH] feat(spark): support ExistenceJoin internal join type
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

For certain filter expressions that embed subqueries, the Spark optimiser
replaces these with a Join relation of type ‘ExistenceJoin’. This internal
join type does not map directly to any standard SQL join type or any
Substrait join type. To address this, it needs to be converted to a
Substrait ‘InPredicate’ within a filter condition.

Signed-off-by: Andrew Coleman
---
 .../substrait/debug/ExpressionToString.scala  |  18 ++-
 .../substrait/debug/RelToVerboseString.scala  |  20 ++++
 .../io/substrait/spark/ToSubstraitType.scala  |   7 ++
 .../spark/expression/ToSparkExpression.scala  |  14 ++-
 .../expression/ToSubstraitExpression.scala    |   7 +-
 .../spark/logical/ToLogicalPlan.scala         |   2 +-
 .../spark/logical/ToSubstraitRel.scala        | 107 ++++++++++++++----
 .../spark/SubstraitPlanTestBase.scala         |  37 +++---
 .../scala/io/substrait/spark/TPCDSPlan.scala  |   3 +-
 .../scala/io/substrait/spark/TPCHPlan.scala   |  35 +++++-
 10 files changed, 204 insertions(+), 46 deletions(-)

diff --git a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala
index 29c951363..2d4d3f833 100644
--- a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala
+++ b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala
@@ -21,7 +21,7 @@ import io.substrait.spark.DefaultExpressionVisitor
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 
 import io.substrait.expression.{Expression, FieldReference}
-import io.substrait.expression.Expression.{DateLiteral, DecimalLiteral, I32Literal, StrLiteral}
+import io.substrait.expression.Expression.{DateLiteral, DecimalLiteral, I32Literal, I64Literal, StrLiteral}
 import io.substrait.function.ToTypeString
 import io.substrait.util.DecimalUtil
 
@@ -43,6 +43,10 @@ class ExpressionToString extends DefaultExpressionVisitor[String] {
     expr.value().toString
   }
 
+  override def visit(expr: I64Literal): String = {
+    expr.value().toString
+  }
+
   override def visit(expr: DateLiteral): String = {
     DateTimeUtils.toJavaDate(expr.value()).toString
   }
@@ -76,4 +80,16 @@ class ExpressionToString extends DefaultExpressionVisitor[String] {
   override def visit(expr: Expression.EmptyMapLiteral): String = {
     expr.toString
   }
+
+  override def visit(expr: Expression.Cast): String = {
+    expr.getType.toString
+  }
+
+  override def visit(expr: Expression.InPredicate): String = {
+    expr.toString
+  }
+
+  override def visit(expr: Expression.ScalarSubquery): String = {
+    expr.toString
+  }
 }
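The rewrite this commit targets is easiest to see from Spark's side. Below is a minimal, standalone sketch (not part of this patch; the local session, demo object name, and tiny datasets are illustrative assumptions) that prints an optimized plan containing the injected ExistenceJoin node:

    import org.apache.spark.sql.SparkSession

    object ExistenceJoinDemo extends App {
      val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
      import spark.implicits._

      Seq((1, "a"), (2, "b")).toDF("p_partkey", "p_comment").createOrReplaceTempView("part")
      Seq(("a", 10.0)).toDF("o_comment", "o_totalprice").createOrReplaceTempView("orders")

      // A disjunction between an ordinary predicate and an IN-subquery cannot be
      // planned as a plain semi join, so Catalyst rewrites it into a Join with the
      // internal ExistenceJoin type plus a Filter over the resulting exists#N flag.
      val df = spark.sql(
        "select p_partkey from part " +
          "where p_partkey > 0 or p_comment in (select o_comment from orders)")

      println(df.queryExecution.optimizedPlan) // look for 'Join ExistenceJoin(exists#...)'
      spark.stop()
    }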
diff --git a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala
index 79b34462f..efaa3f8bf 100644
--- a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala
+++ b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala
@@ -165,6 +165,26 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] {
       })
   }
 
+  override def visit(set: Set): String = {
+    withBuilder(set, 8)(
+      builder => {
+        builder
+          .append("operation=")
+          .append(set.getSetOp)
+      })
+  }
+
+  override def visit(cross: Cross): String = {
+    withBuilder(cross, 10)(
+      builder => {
+        builder
+          .append("left=")
+          .append(cross.getLeft)
+          .append("right=")
+          .append(cross.getRight)
+      })
+  }
+
   override def visit(localFiles: LocalFiles): String = {
     withBuilder(localFiles, 10)(
       builder => {
diff --git a/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
index 951900cd8..dd7ecaebb 100644
--- a/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
+++ b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
@@ -68,6 +68,13 @@ private class ToSparkType
 
   override def visit(expr: Type.IntervalYear): DataType = YearMonthIntervalType.DEFAULT
 
+  override def visit(expr: Type.Struct): DataType = {
+    StructType(
+      expr.fields.asScala.zipWithIndex
+        .map { case (t, i) => StructField(s"col${i + 1}", t.accept(this), t.nullable()) }
+    )
+  }
+
   override def visit(expr: Type.ListType): DataType =
     ArrayType(expr.elementType().accept(this), containsNull = expr.elementType().nullable())
 
diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
index a91b8e02c..30d6b8f05 100644
--- a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
+++ b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
@@ -19,9 +19,11 @@ package io.substrait.spark.expression
 import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, SparkExtension, ToSubstraitType}
 import io.substrait.spark.logical.ToLogicalPlan
 
-import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, CaseWhen, Cast, EqualTo, Expression, In, InSubquery, ListQuery, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
+import org.apache.spark.sql.catalyst.plans.ExistenceJoin
+import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint, LogicalPlan}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DateType, Decimal}
+import org.apache.spark.sql.types.{BooleanType, DateType, Decimal}
 import org.apache.spark.substrait.SparkTypeUtil
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -204,6 +206,14 @@ class ToSparkExpression(
     In(value, list)
   }
 
+  override def visit(expr: SExpression.InPredicate): Expression = {
+    val needles = expr.needles().asScala.map(e => e.accept(this))
+    val haystack = expr.haystack().accept(toLogicalPlan.get)
+    new InSubquery(needles, ListQuery(haystack, childOutputs = haystack.output)) {
+      override def nullable: Boolean = expr.getType.nullable()
+    }
+  }
+
   override def visit(expr: SExpression.ScalarFunctionInvocation): Expression = {
     val eArgs = expr.arguments().asScala
     val args = eArgs.zipWithIndex.map {
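For reference, the needles/haystack shape used by visit(InPredicate) above can be read as a small standalone helper: a Substrait InPredicate tests N "needle" expressions against the rows of a "haystack" relation, which Spark expresses as InSubquery over a ListQuery. This is a sketch only; toSparkInSubquery is a hypothetical name, and the ListQuery call mirrors the one in the diff:

    import org.apache.spark.sql.catalyst.expressions.{Expression, InSubquery, ListQuery}
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    // `needles` are the already-converted value expressions; `haystack` is the
    // converted subquery plan. childOutputs pins the subquery's output columns.
    def toSparkInSubquery(needles: Seq[Expression], haystack: LogicalPlan): InSubquery =
      InSubquery(needles, ListQuery(haystack, childOutputs = haystack.output))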
diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala
index e43965363..611c0bcd5 100644
--- a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala
+++ b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala
@@ -65,6 +65,8 @@ abstract class ToSubstraitExpression extends HasOutputStack[Seq[Attribute]] {
 
   protected def translateSubQuery(expr: PlanExpression[_]): Option[SExpression] = default(expr)
 
+  protected def translateInSubquery(expr: InSubquery): Option[SExpression] = default(expr)
+
   protected def translateAttribute(a: AttributeReference): Option[SExpression] = {
     val bindReference = BindReferences.bindReference[Expression](a, currentOutput,
       allowFailures = false)
@@ -141,10 +143,6 @@ abstract class ToSubstraitExpression extends HasOutputStack[Seq[Attribute]] {
       case SubstraitLiteral(substraitLiteral) => Some(substraitLiteral)
       case a: AttributeReference if currentOutput.nonEmpty => translateAttribute(a)
       case a: Alias => translateUp(a.child)
-      case p
-          if p.getClass.getCanonicalName.equals( // removed in spark-3.3
-            "org.apache.spark.sql.catalyst.expressions.PromotePrecision") =>
-        translateUp(p.children.head)
       case CaseWhen(branches, elseValue) => translateCaseWhen(branches, elseValue)
       case In(value, list) => translateIn(value, list)
       case InSet(value, set) => translateIn(value, set.toSeq.map(v => Literal(v)))
@@ -153,6 +151,7 @@ abstract class ToSubstraitExpression extends HasOutputStack[Seq[Attribute]] {
           .seqToOption(children.map(translateUp))
           .flatMap(toScalarFunction.convert(scalar, _))
       case p: PlanExpression[_] => translateSubQuery(p)
+      case in: InSubquery => translateInSubquery(in)
       case other => default(other)
     }
   }
diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
index 058fbe393..c944ecffd 100644
--- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
+++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
@@ -218,7 +218,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
 
   override def visit(fetch: relation.Fetch): LogicalPlan = {
     val child = fetch.getInput.accept(this)
-    val limit = fetch.getCount.getAsLong.intValue()
+    val limit = fetch.getCount.orElse(-1).intValue() // -1 means unassigned here
     val offset = fetch.getOffset.intValue()
     val toLiteral = (i: Int) => Literal(i, IntegerType)
     if (limit >= 0) {
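A short sketch of why the Fetch change above is safer, assuming fetch.getCount is a java.util.OptionalLong (which the getAsLong/orElse calls in the diff suggest): getAsLong throws on an absent count, while folding the empty case to -1 leaves the caller a single "limit >= 0" test. limitOf is a hypothetical helper:

    import java.util.OptionalLong

    // Substrait's fetch count is optional; OptionalLong.getAsLong throws
    // NoSuchElementException when no count is set, so map "absent" to -1.
    def limitOf(count: OptionalLong): Int = count.orElse(-1L).toInt

    assert(limitOf(OptionalLong.of(10L)) == 10)
    assert(limitOf(OptionalLong.empty()) == -1)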
diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala
index f00ebba53..d1e1889e3 100644
--- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala
+++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala
@@ -43,6 +43,7 @@ import io.substrait.relation.RelProtoConverter
 import io.substrait.relation.Set.SetOp
 import io.substrait.relation.files.{FileFormat, ImmutableFileOrFiles}
 import io.substrait.relation.files.FileOrFiles.PathType
+import io.substrait.utils.Util
 
 import java.util.{Collections, Optional}
 
@@ -247,8 +248,9 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
   }
 
   override def visitFilter(p: Filter): relation.Rel = {
+    val input = visit(p.child)
     val condition = toExpression(p.child.output)(p.condition)
-    relation.Filter.builder().condition(condition).input(visit(p.child)).build()
+    relation.Filter.builder().condition(condition).input(input).build()
   }
 
   private def toSubstraitJoin(joinType: JoinType): relation.Join.JoinType = joinType match {
@@ -262,33 +264,73 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
   }
 
   override def visitJoin(p: Join): relation.Rel = {
-    val left = visit(p.left)
-    val right = visit(p.right)
-    val condition = p.condition.map(toExpression(p.left.output ++ p.right.output)).getOrElse(TRUE)
-    val joinType = toSubstraitJoin(p.joinType)
-
-    if (joinType == relation.Join.JoinType.INNER && TRUE == condition) {
-      relation.Cross.builder
-        .left(left)
-        .right(right)
-        .build
-    } else {
-      relation.Join.builder
-        .condition(condition)
-        .joinType(joinType)
-        .left(left)
-        .right(right)
-        .build
+    p match {
+      case Join(left, right, ExistenceJoin(exists), where, _) =>
+        convertExistenceJoin(left, right, where, exists)
+      case _ =>
+        val left = visit(p.left)
+        val right = visit(p.right)
+        val condition =
+          p.condition.map(toExpression(p.left.output ++ p.right.output)).getOrElse(TRUE)
+        val joinType = toSubstraitJoin(p.joinType)
+
+        if (joinType == relation.Join.JoinType.INNER && TRUE == condition) {
+          relation.Cross.builder
+            .left(left)
+            .right(right)
+            .build
+        } else {
+          relation.Join.builder
+            .condition(condition)
+            .joinType(joinType)
+            .left(left)
+            .right(right)
+            .build
+        }
     }
   }
 
+  private def convertExistenceJoin(
+      left: LogicalPlan,
+      right: LogicalPlan,
+      condition: Option[Expression],
+      exists: Attribute): relation.Rel = {
+    // An ExistenceJoin is an internal spark join type that is injected by the catalyst
+    // optimiser. It doesn't directly map to any SQL join type, but can be modelled using
+    // a Substrait `InPredicate` within a Filter condition.
+
+    // The 'exists' attribute in the parent filter condition will be associated with this.
+
+    // extract the needle expressions from the join condition
+    def findNeedles(expr: Expression): Iterable[Expression] = expr match {
+      case And(lhs, rhs) => findNeedles(lhs) ++ findNeedles(rhs)
+      case EqualTo(lhs, _) => Seq(lhs)
+      case _ =>
+        throw new UnsupportedOperationException(
+          s"Unable to convert the ExistenceJoin condition: $expr")
+    }
+    val needles = condition.toIterable
+      .flatMap(w => findNeedles(w))
+      .map(toExpression(left.output))
+      .asJava
+
+    val haystack = visit(right)
+
+    val inPredicate =
+      SExpression.InPredicate.builder().needles(needles).haystack(haystack).build()
+    // put this in a map exists->inPredicate for later lookup
+    toSubstraitExp.existenceJoins.put(exists.exprId.id, inPredicate)
+    // return the left child, which will become the child of the enclosing filter
+    visit(left)
+  }
+
   override def visitProject(p: Project): relation.Rel = {
     val expressions = p.projectList.map(toExpression(p.child.output)).toList
-
+    val child = visit(p.child)
     relation.Project.builder
-      .remap(relation.Rel.Remap.offset(p.child.output.size, expressions.size))
+      .remap(relation.Rel.Remap.offset(child.getRecordType.fields.size, expressions.size))
       .expressions(expressions.asJava)
-      .input(visit(p.child))
+      .input(child)
       .build()
   }
 
@@ -492,6 +534,8 @@ private[logical] class WithLogicalSubQuery(toSubstraitRel: ToSubstraitRel)
   override protected val toScalarFunction: ToScalarFunction =
     ToScalarFunction(SparkExtension.SparkScalarFunctions)
 
+  val existenceJoins = scala.collection.mutable.Map[Long, SExpression.InPredicate]()
+
   override protected def translateSubQuery(expr: PlanExpression[_]): Option[SExpression] = {
     expr match {
       case s: ScalarSubquery if s.outerAttrs.isEmpty && s.joinCond.isEmpty =>
@@ -504,4 +548,25 @@ private[logical] class WithLogicalSubQuery(toSubstraitRel: ToSubstraitRel)
       case other => default(other)
     }
   }
+
+  override protected def translateInSubquery(expr: InSubquery): Option[SExpression] = {
+    Util
+      .seqToOption(expr.values.map(translateUp).toList)
+      .flatMap(
+        values =>
+          Some(
+            SExpression.InPredicate
+              .builder()
+              .needles(values.asJava)
+              .haystack(toSubstraitRel.visit(expr.query.plan))
+              .build()
+          ))
+  }
+
+  override protected def translateAttribute(a: AttributeReference): Option[SExpression] = {
+    existenceJoins.get(a.exprId.id) match {
+      case Some(exists) => Some(exists)
+      case None => super.translateAttribute(a)
+    }
+  }
 }
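The needle extraction in convertExistenceJoin above can be exercised on its own. This standalone sketch repeats the findNeedles helper verbatim and checks it against hand-built Catalyst expressions; the attribute names are illustrative:

    import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression}
    import org.apache.spark.sql.types.{DoubleType, StringType}

    // A condition of the form (a1 = b1) AND (a2 = b2) ... yields the left-hand
    // sides as the InPredicate needles; any other shape is rejected.
    def findNeedles(expr: Expression): Iterable[Expression] = expr match {
      case And(lhs, rhs) => findNeedles(lhs) ++ findNeedles(rhs)
      case EqualTo(lhs, _) => Seq(lhs)
      case _ =>
        throw new UnsupportedOperationException(s"Unable to convert the ExistenceJoin condition: $expr")
    }

    val pComment = AttributeReference("p_comment", StringType)()
    val oComment = AttributeReference("o_comment", StringType)()
    val pPrice = AttributeReference("p_retailprice", DoubleType)()
    val oPrice = AttributeReference("o_totalprice", DoubleType)()

    assert(findNeedles(EqualTo(pComment, oComment)) == Seq(pComment))
    assert(findNeedles(And(EqualTo(pComment, oComment), EqualTo(pPrice, oPrice))) == Seq(pComment, pPrice))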
diff --git a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala
index cbd7a151c..beb57d1b9 100644
--- a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala
+++ b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala
@@ -86,19 +86,30 @@ trait SubstraitPlanTestBase { self: SharedSparkSession =>
 
   def assertSqlSubstraitRelRoundTrip(query: String): LogicalPlan = {
     // TODO need a more robust way of testing this than round-tripping.
-    val logicalPlan = plan(query)
-    val pojoRel = new ToSubstraitRel().visit(logicalPlan)
-    val converter = new ToLogicalPlan(spark = spark);
-    val logicalPlan2 = pojoRel.accept(converter);
-    require(logicalPlan2.resolved);
-    val pojoRel2 = new ToSubstraitRel().visit(logicalPlan2)
-
-    val extensionCollector = new ExtensionCollector;
-    val proto = new RelProtoConverter(extensionCollector).toProto(pojoRel)
-    new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto)
-
-    pojoRel2.shouldEqualPlainly(pojoRel)
-    logicalPlan2
+    val sparkPlan = plan(query)
+
+    // convert spark logical plan to substrait
+    val substraitPlan = new ToSubstraitRel().visit(sparkPlan)
+
+    // Serialize to protobuf byte array
+    val extensionCollector = new ExtensionCollector
+    val bytes = new RelProtoConverter(extensionCollector).toProto(substraitPlan).toByteArray
+
+    // Read it back
+    val protoPlan = io.substrait.proto.Rel.parseFrom(bytes)
+    val substraitPlan2 =
+      new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(protoPlan)
+
+    // convert substrait back to spark plan
+    val sparkPlan2 = substraitPlan2.accept(new ToLogicalPlan(spark))
+    require(sparkPlan2.resolved)
+
+    // and back to substrait again
+    val substraitPlan3 = new ToSubstraitRel().visit(sparkPlan2)
+
+    // compare with original substrait plan to ensure it round-tripped (via proto bytes) correctly
+    substraitPlan3.shouldEqualPlainly(substraitPlan)
+    sparkPlan2
   }
 
   def plan(sql: String): LogicalPlan = {
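The byte-level round trip the test now performs can be factored as a small helper. A sketch under stated assumptions: roundTripViaBytes is a hypothetical name, and the ProtoRelConverter/ExtensionCollector constructor signatures and packages are taken on trust from the substrait-java usage shown in the diff (the test passes SparkExtension.COLLECTION as the extension collection):

    import io.substrait.extension.{ExtensionCollector, SimpleExtension}
    import io.substrait.relation.{ProtoRelConverter, Rel, RelProtoConverter}

    def roundTripViaBytes(rel: Rel, extensions: SimpleExtension.ExtensionCollection): Rel = {
      val collector = new ExtensionCollector
      val bytes = new RelProtoConverter(collector).toProto(rel).toByteArray // POJO -> proto -> bytes
      val proto = io.substrait.proto.Rel.parseFrom(bytes)                   // bytes -> proto message
      new ProtoRelConverter(collector, extensions).from(proto)              // proto -> POJO
    }

Going via a byte array, rather than handing the proto object straight back, also verifies that the message actually serializes and parses, which is what downstream consumers of a Substrait plan will do.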
diff --git a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala
index 048fec0ed..f0de6c122 100644
--- a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala
+++ b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala
@@ -35,8 +35,7 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase {
   val failingSQL: Set[String] = Set(
     "q2", // because round() isn't defined in substrait to work with Decimal. https://github.com/substrait-io/substrait/pull/713
     "q9", // requires implementation of named_struct()
-    "q10", "q35", "q45", // Unsupported join type ExistenceJoin (this is an internal spark type)
-    "q51", "q84", // TBD
+    "q35", "q51", "q84", // These fail when comparing the round-tripped query plans, but are actually equivalent (due to aliases being ignored by substrait)
     "q72" //requires implementation of date_add()
   )
   // spotless:on
diff --git a/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala
index c96cd5531..038f1ef3e 100644
--- a/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala
+++ b/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala
@@ -16,8 +16,6 @@
  */
 package io.substrait.spark
 
-import io.substrait.spark.logical.{ToLogicalPlan, ToSubstraitRel}
-
 import org.apache.spark.sql.TPCHBase
 
 class TPCHPlan extends TPCHBase with SubstraitPlanTestBase {
@@ -178,6 +176,39 @@ class TPCHPlan extends TPCHBase with SubstraitPlanTestBase {
       "((l_orderkey, L_COMMITDATE), (l_orderkey, L_COMMITDATE, l_linestatus), l_shipdate, ())")
   }
 
+  test("in_subquery") {
+    assertSqlSubstraitRelRoundTrip(
+      "select p_retailprice from part where p_comment in (select o_comment from orders)")
+  }
+
+  test("exists") {
+    assertSqlSubstraitRelRoundTrip(
+      "select p_retailprice from part where exists (select o_comment from orders)")
+  }
+
+  test("ExistenceJoin") {
+    assertSqlSubstraitRelRoundTrip(
+      "select p_retailprice from part where p_partkey > 0 or p_comment in (select o_comment from orders)")
+
+    assertSqlSubstraitRelRoundTrip("select p_retailprice from part where p_partkey > 0 or " +
+      "(p_comment in (select o_comment from orders) or p_comment in (select r_comment from region))")
+
+    assertSqlSubstraitRelRoundTrip(
+      "select p_retailprice, sum(p_retailprice) from part " +
+        "where p_partkey > 0 or p_comment in (select o_comment from orders) " +
+        "group by p_retailprice")
+
+    assertSqlSubstraitRelRoundTrip(
+      "select p_retailprice from part where p_partkey > 0 or " +
+        "(p_comment, p_retailprice) in (select o_comment, o_totalprice from orders)"
+    )
+
+    assertSqlSubstraitRelRoundTrip(
+      "select p_retailprice from part where p_partkey > 0 or " +
+        "(upper(p_comment), p_retailprice * 1.2) in (select upper(o_comment), o_totalprice from orders)"
+    )
+  }
+
   test("tpch_q1_variant") {
     // difference from tpch_q1 : 1) remove order by clause; 2) remove interval date literal
     assertSqlSubstraitRelRoundTrip(