diff --git a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala
index 29c951363..2d4d3f833 100644
--- a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala
+++ b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala
@@ -21,7 +21,7 @@ import io.substrait.spark.DefaultExpressionVisitor
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 
 import io.substrait.expression.{Expression, FieldReference}
-import io.substrait.expression.Expression.{DateLiteral, DecimalLiteral, I32Literal, StrLiteral}
+import io.substrait.expression.Expression.{DateLiteral, DecimalLiteral, I32Literal, I64Literal, StrLiteral}
 import io.substrait.function.ToTypeString
 import io.substrait.util.DecimalUtil
 
@@ -43,6 +43,10 @@ class ExpressionToString extends DefaultExpressionVisitor[String] {
     expr.value().toString
   }
 
+  override def visit(expr: I64Literal): String = {
+    expr.value().toString
+  }
+
   override def visit(expr: DateLiteral): String = {
     DateTimeUtils.toJavaDate(expr.value()).toString
   }
@@ -76,4 +80,16 @@ class ExpressionToString extends DefaultExpressionVisitor[String] {
   override def visit(expr: Expression.EmptyMapLiteral): String = {
     expr.toString
   }
+
+  override def visit(expr: Expression.Cast): String = {
+    expr.getType.toString
+  }
+
+  override def visit(expr: Expression.InPredicate): String = {
+    expr.toString
+  }
+
+  override def visit(expr: Expression.ScalarSubquery): String = {
+    expr.toString
+  }
 }
diff --git a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala
index 79b34462f..efaa3f8bf 100644
--- a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala
+++ b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala
@@ -165,6 +165,26 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] {
     })
   }
 
+  override def visit(set: Set): String = {
+    withBuilder(set, 8)(
+      builder => {
+        builder
+          .append("operation=")
+          .append(set.getSetOp)
+      })
+  }
+
+  override def visit(cross: Cross): String = {
+    withBuilder(cross, 10)(
+      builder => {
+        builder
+          .append("left=")
+          .append(cross.getLeft)
+          .append(", right=")
+          .append(cross.getRight)
+      })
+  }
+
   override def visit(localFiles: LocalFiles): String = {
     withBuilder(localFiles, 10)(
       builder => {
diff --git a/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
index 951900cd8..d16034852 100644
--- a/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
+++ b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
@@ -68,6 +68,12 @@ private class ToSparkType
   override def visit(expr: Type.IntervalYear): DataType =
     YearMonthIntervalType.DEFAULT
 
+  override def visit(expr: Type.Struct): DataType = {
+    StructType(
+      expr.fields.asScala.map(f => StructField(f.toString, f.accept(this), f.nullable()))
+    )
+  }
+
   override def visit(expr: Type.ListType): DataType =
     ArrayType(expr.elementType().accept(this), containsNull = expr.elementType().nullable())
 
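Note on the Type.Struct visitor above: Substrait struct types carry no field names, so the Spark StructField names fall back to each field type's toString. A minimal sketch of the resulting Spark type for a hypothetical two-field struct (IntegerType/StringType stand in for the visited field types; the exact names depend on the type POJO's toString):

    import org.apache.spark.sql.types._

    // what visit(expr: Type.Struct) yields for a struct of (nullable i32, nullable string);
    // "<field.toString>" marks where the Substrait type's toString is used as the name
    val sparkType = StructType(Seq(
      StructField("<i32 field toString>", IntegerType, nullable = true),
      StructField("<string field toString>", StringType, nullable = true)))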
diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
index a91b8e02c..f2d598f64 100644
--- a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
+++ b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
@@ -19,9 +19,11 @@ package io.substrait.spark.expression
 
 import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, SparkExtension, ToSubstraitType}
 import io.substrait.spark.logical.ToLogicalPlan
 
-import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, CaseWhen, Cast, EqualTo, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
+import org.apache.spark.sql.catalyst.plans.ExistenceJoin
+import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint, LogicalPlan}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DateType, Decimal}
+import org.apache.spark.sql.types.{BooleanType, DateType, Decimal}
 import org.apache.spark.substrait.SparkTypeUtil
 import org.apache.spark.unsafe.types.UTF8String
@@ -39,6 +41,12 @@ class ToSparkExpression(
   extends DefaultExpressionVisitor[Expression]
   with HasOutputStack[Seq[NamedExpression]] {
 
+  private val existenceJoins = scala.collection.mutable.Map[AttributeReference, LogicalPlan]()
+
+  def lookupExists(attr: AttributeReference): Option[LogicalPlan] = {
+    existenceJoins.get(attr)
+  }
+
   override def visit(expr: SExpression.BoolLiteral): Expression = {
     if (expr.value()) {
       Literal.TrueLiteral
@@ -198,6 +206,25 @@ class ToSparkExpression(
       .getOrElse(visitFallback(expr))
   }
 
+  private def existenceJoin(expr: SExpression.InPredicate): AttributeReference = {
+    val dataType = ToSubstraitType.convert(expr.getType)
+    require(dataType == BooleanType)
+    val needles = expr.needles().asScala.map(e => e.accept(this))
+    val haystack = expr.haystack().accept(toLogicalPlan.get)
+    val exists = AttributeReference("exists", BooleanType, nullable = false)()
+    val joinType = ExistenceJoin(exists)
+    require(
+      needles.size == haystack.output.size,
+      "The number of needles must match the number of haystack outputs")
+    val condition = needles
+      .zip(haystack.output)
+      .map { case (needle, hay) => EqualTo(needle, hay) }
+      .reduce[Expression] { case (lhs, rhs) => And(lhs, rhs) }
+    val join = Join(null, haystack, joinType, Some(condition), JoinHint.NONE)
+    existenceJoins.put(exists, join)
+    exists
+  }
+
   override def visit(expr: SExpression.SingleOrList): Expression = {
     val value = expr.condition().accept(this)
     val list = expr.options().asScala.map(e => e.accept(this))
@@ -208,7 +235,12 @@
     val eArgs = expr.arguments().asScala
     val args = eArgs.zipWithIndex.map {
       case (arg, i) =>
-        arg.accept(expr.declaration(), i, this)
+        arg match {
+          case ip: SExpression.InPredicate =>
+            existenceJoin(ip)
+          case _ =>
+            arg.accept(expr.declaration(), i, this)
+        }
     }
 
     expr.declaration.name match {
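Note on existenceJoin above: for a multi-column membership test such as (p_comment, p_retailprice) in (select o_comment, o_totalprice from orders), the needles are zipped with the haystack's output and the equalities folded into one conjunction. A sketch of that step (attribute names and types are assumed for illustration):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.{DoubleType, StringType}

    // stand-ins for the needle expressions and the haystack subplan's output
    val needles: Seq[Expression] = Seq(
      AttributeReference("p_comment", StringType)(), AttributeReference("p_retailprice", DoubleType)())
    val haystackOutput: Seq[Expression] = Seq(
      AttributeReference("o_comment", StringType)(), AttributeReference("o_totalprice", DoubleType)())

    val condition = needles
      .zip(haystackOutput)
      .map { case (needle, hay) => EqualTo(needle, hay): Expression }
      .reduce(And(_, _))
    // => (p_comment = o_comment) AND (p_retailprice = o_totalprice)

The Join it registers deliberately has a null left child; the enclosing Filter supplies it later in ToLogicalPlan (next file).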
diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
index 058fbe393..659635667 100644
--- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
+++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
-import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter}
+import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.execution.QueryExecution
@@ -46,6 +46,7 @@ import io.substrait.relation.files.FileFormat
 import org.apache.hadoop.fs.Path
 
 import scala.collection.JavaConverters.asScalaBufferConverter
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 /**
@@ -218,7 +219,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
 
   override def visit(fetch: relation.Fetch): LogicalPlan = {
     val child = fetch.getInput.accept(this)
-    val limit = fetch.getCount.getAsLong.intValue()
+    val limit = fetch.getCount.orElse(-1).intValue() // -1 means unassigned here
     val offset = fetch.getOffset.intValue()
     val toLiteral = (i: Int) => Literal(i, IntegerType)
     if (limit >= 0) {
@@ -291,13 +292,37 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
   }
 
   override def visit(filter: relation.Filter): LogicalPlan = {
-    val child = filter.getInput.accept(this)
+    var child = filter.getInput.accept(this)
     withChild(child) {
       val condition = filter.getCondition.accept(expressionConverter)
+      val existenceJoins = mutable.ListBuffer[AttributeReference]()
+      findExistenceJoins(condition, existenceJoins)
+      existenceJoins.foreach {
+        exists =>
+          expressionConverter.lookupExists(exists) match {
+            case Some(Join(_, right, joinType, condition, hint)) =>
+              // the filter child now becomes the join's left child
+              // and the join becomes the filter's child.
+              child = Join(child, right, joinType, condition, hint)
+            case _ => // should not get here!
+              throw new UnsupportedOperationException(
+                s"Expression not currently supported: $condition")
+          }
+      }
       Filter(condition, child)
     }
   }
 
+  private def findExistenceJoins(
+      expression: Expression,
+      attributes: mutable.ListBuffer[AttributeReference]): Unit = {
+    expression.children.foreach {
+      case attr: AttributeReference if expressionConverter.lookupExists(attr).isDefined =>
+        attributes.append(attr)
+      case child => findExistenceJoins(child, attributes)
+    }
+  }
+
   override def visit(set: relation.Set): LogicalPlan = {
     val children = set.getInputs.asScala.map(_.accept(this))
     withOutput(children.flatMap(_.output)) {
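Note on visit(filter) above: it completes the Join parked by ToSparkExpression.existenceJoin, whose left child was left null. Sketched with Catalyst's own types (pComment, oComment, part and orders are illustrative stand-ins):

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo}
    import org.apache.spark.sql.catalyst.plans.ExistenceJoin
    import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, JoinHint, LocalRelation}
    import org.apache.spark.sql.types.{BooleanType, StringType}

    val pComment = AttributeReference("p_comment", StringType)()
    val oComment = AttributeReference("o_comment", StringType)()
    val part = LocalRelation(pComment)   // stands in for the filter's original child
    val orders = LocalRelation(oComment) // the haystack subplan
    val exists = AttributeReference("exists", BooleanType, nullable = false)()

    // parked by the expression converter with a null left child:
    val parked = Join(null, orders, ExistenceJoin(exists), Some(EqualTo(pComment, oComment)), JoinHint.NONE)
    // visit(filter) grafts the filter's input into that slot and wraps the result:
    val grafted = Filter(exists, parked.copy(left = part))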
diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala
index f00ebba53..4d1d2516f 100644
--- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala
+++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala
@@ -247,8 +247,9 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
   }
 
   override def visitFilter(p: Filter): relation.Rel = {
+    val input = visit(p.child)
     val condition = toExpression(p.child.output)(p.condition)
-    relation.Filter.builder().condition(condition).input(visit(p.child)).build()
+    relation.Filter.builder().condition(condition).input(input).build()
   }
 
   private def toSubstraitJoin(joinType: JoinType): relation.Join.JoinType = joinType match {
@@ -262,33 +263,73 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
   override def visitJoin(p: Join): relation.Rel = {
-    val left = visit(p.left)
-    val right = visit(p.right)
-    val condition = p.condition.map(toExpression(p.left.output ++ p.right.output)).getOrElse(TRUE)
-    val joinType = toSubstraitJoin(p.joinType)
-
-    if (joinType == relation.Join.JoinType.INNER && TRUE == condition) {
-      relation.Cross.builder
-        .left(left)
-        .right(right)
-        .build
-    } else {
-      relation.Join.builder
-        .condition(condition)
-        .joinType(joinType)
-        .left(left)
-        .right(right)
-        .build
+    p match {
+      case Join(left, right, ExistenceJoin(exists), where, _) =>
+        convertExistenceJoin(left, right, where, exists)
+      case _ =>
+        val left = visit(p.left)
+        val right = visit(p.right)
+        val condition =
+          p.condition.map(toExpression(p.left.output ++ p.right.output)).getOrElse(TRUE)
+        val joinType = toSubstraitJoin(p.joinType)
+
+        if (joinType == relation.Join.JoinType.INNER && TRUE == condition) {
+          relation.Cross.builder
+            .left(left)
+            .right(right)
+            .build
+        } else {
+          relation.Join.builder
+            .condition(condition)
+            .joinType(joinType)
+            .left(left)
+            .right(right)
+            .build
+        }
     }
   }
 
+  private def convertExistenceJoin(
+      left: LogicalPlan,
+      right: LogicalPlan,
+      condition: Option[Expression],
+      exists: Attribute): relation.Rel = {
+    // An ExistenceJoin is an internal spark join type that is injected by the catalyst
+    // optimiser. It doesn't directly map to any SQL join type, but can be modelled using
+    // a Substrait `InPredicate` within a Filter condition.
+
+    // The 'exists' attribute in the parent filter condition will be associated with this.
+
+    // extract the needle expressions from the join condition
+    def findNeedles(expr: Expression): Iterable[Expression] = expr match {
+      case And(lhs, rhs) => findNeedles(lhs) ++ findNeedles(rhs)
+      case EqualTo(lhs, _) => Seq(lhs)
+      case _ =>
+        throw new UnsupportedOperationException(
+          s"Unable to convert the ExistenceJoin condition: $expr")
+    }
+    val needles = condition.toIterable
+      .flatMap(w => findNeedles(w))
+      .map(toExpression(left.output))
+      .asJava
+
+    val haystack = visit(right)
+
+    val inPredicate =
+      SExpression.InPredicate.builder().needles(needles).haystack(haystack).build()
+    // put this in a map exists->inPredicate for later lookup
+    toSubstraitExp.existenceJoins.put(exists.exprId.id, inPredicate)
+    // return the left child, which will become the child of the enclosing filter
+    visit(left)
+  }
+
   override def visitProject(p: Project): relation.Rel = {
     val expressions = p.projectList.map(toExpression(p.child.output)).toList
-
+    val child = visit(p.child)
     relation.Project.builder
-      .remap(relation.Rel.Remap.offset(p.child.output.size, expressions.size))
+      .remap(relation.Rel.Remap.offset(child.getRecordType.fields.size, expressions.size))
       .expressions(expressions.asJava)
-      .input(visit(p.child))
+      .input(child)
       .build()
   }
@@ -492,6 +533,8 @@ private[logical] class WithLogicalSubQuery(toSubstraitRel: ToSubstraitRel)
   override protected val toScalarFunction: ToScalarFunction =
     ToScalarFunction(SparkExtension.SparkScalarFunctions)
 
+  val existenceJoins = scala.collection.mutable.Map[Long, SExpression.InPredicate]()
+
   override protected def translateSubQuery(expr: PlanExpression[_]): Option[SExpression] = {
     expr match {
       case s: ScalarSubquery if s.outerAttrs.isEmpty && s.joinCond.isEmpty =>
@@ -504,4 +547,11 @@ private[logical] class WithLogicalSubQuery(toSubstraitRel: ToSubstraitRel)
       case other => default(other)
     }
   }
+
+  override protected def translateAttribute(a: AttributeReference): Option[SExpression] = {
+    existenceJoins.get(a.exprId.id) match {
+      case Some(exists) => Some(exists)
+      case None => super.translateAttribute(a)
+    }
+  }
 }
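Note on findNeedles above: it relies on the optimiser always placing the outer-query expression on the left-hand side of each equality in the ExistenceJoin condition. A worked example of the extraction (attribute names hypothetical; the helper is replicated locally since it is nested in convertExistenceJoin):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.{DoubleType, StringType}

    // local copy of the extraction logic, for illustration
    def findNeedles(expr: Expression): Iterable[Expression] = expr match {
      case And(lhs, rhs) => findNeedles(lhs) ++ findNeedles(rhs)
      case EqualTo(lhs, _) => Seq(lhs)
      case other => throw new UnsupportedOperationException(s"$other")
    }

    val pComment = AttributeReference("p_comment", StringType)()
    val oComment = AttributeReference("o_comment", StringType)()
    val pPrice = AttributeReference("p_retailprice", DoubleType)()
    val oPrice = AttributeReference("o_totalprice", DoubleType)()

    findNeedles(And(EqualTo(pComment, oComment), EqualTo(pPrice, oPrice)))
    // => Seq(p_comment, p_retailprice): the left-hand sides, in order; with
    // haystack = visit(right) this serializes as (p_comment, p_retailprice) IN <orders subplan>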
diff --git a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala
index cbd7a151c..d026a1902 100644
--- a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala
+++ b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala
@@ -86,19 +86,29 @@ trait SubstraitPlanTestBase { self: SharedSparkSession =>
 
   def assertSqlSubstraitRelRoundTrip(query: String): LogicalPlan = {
     // TODO need a more robust way of testing this than round-tripping.
-    val logicalPlan = plan(query)
-    val pojoRel = new ToSubstraitRel().visit(logicalPlan)
-    val converter = new ToLogicalPlan(spark = spark);
-    val logicalPlan2 = pojoRel.accept(converter);
-    require(logicalPlan2.resolved);
-    val pojoRel2 = new ToSubstraitRel().visit(logicalPlan2)
-
-    val extensionCollector = new ExtensionCollector;
-    val proto = new RelProtoConverter(extensionCollector).toProto(pojoRel)
-    new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto)
-
-    pojoRel2.shouldEqualPlainly(pojoRel)
-    logicalPlan2
+    val sparkPlan = plan(query)
+
+    // convert the spark logical plan to substrait
+    val substraitPlan = new ToSubstraitRel().visit(sparkPlan)
+
+    // serialize it to a protobuf byte array
+    val extensionCollector = new ExtensionCollector
+    val bytes = new RelProtoConverter(extensionCollector).toProto(substraitPlan).toByteArray
+
+    // read it back
+    val protoPlan = io.substrait.proto.Rel.parseFrom(bytes)
+    val substraitPlan2 =
+      new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(protoPlan)
+
+    // convert substrait back to a spark plan
+    val sparkPlan2 = substraitPlan2.accept(new ToLogicalPlan(spark))
+
+    // and back to substrait again
+    val substraitPlan3 = new ToSubstraitRel().visit(sparkPlan2)
+
+    // compare with the original substrait plan to ensure it round-tripped (via proto bytes) correctly
+    substraitPlan3.shouldEqualPlainly(substraitPlan)
+    sparkPlan2
   }
 
   def plan(sql: String): LogicalPlan = {
diff --git a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala
index 048fec0ed..f0de6c122 100644
--- a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala
+++ b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala
@@ -35,8 +35,7 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase {
   val failingSQL: Set[String] = Set(
     "q2", // because round() isn't defined in substrait to work with Decimal. https://github.com/substrait-io/substrait/pull/713
     "q9", // requires implementation of named_struct()
-    "q10", "q35", "q45", // Unsupported join type ExistenceJoin (this is an internal spark type)
-    "q51", "q84", // TBD
+    "q35", "q51", "q84", // These fail when comparing the round-tripped query plans, but are actually equivalent (due to aliases being ignored by substrait)
     "q72" //requires implementation of date_add()
   )
   // spotless:on
diff --git a/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala
index c96cd5531..038f1ef3e 100644
--- a/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala
+++ b/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala
@@ -16,8 +16,6 @@
  */
 package io.substrait.spark
 
-import io.substrait.spark.logical.{ToLogicalPlan, ToSubstraitRel}
-
 import org.apache.spark.sql.TPCHBase
 
 class TPCHPlan extends TPCHBase with SubstraitPlanTestBase {
@@ -178,6 +176,39 @@ class TPCHPlan extends TPCHBase with SubstraitPlanTestBase {
       "((l_orderkey, L_COMMITDATE), (l_orderkey, L_COMMITDATE, l_linestatus), l_shipdate, ())")
   }
 
+  test("in_subquery") {
+    assertSqlSubstraitRelRoundTrip(
+      "select p_retailprice from part where p_comment in (select o_comment from orders)")
+  }
+
+  test("exists") {
+    assertSqlSubstraitRelRoundTrip(
+      "select p_retailprice from part where exists (select o_comment from orders)")
+  }
+
+  test("ExistenceJoin") {
+    assertSqlSubstraitRelRoundTrip(
+      "select p_retailprice from part where p_partkey > 0 or p_comment in (select o_comment from orders)")
+
+    assertSqlSubstraitRelRoundTrip("select p_retailprice from part where p_partkey > 0 or " +
+      "(p_comment in (select o_comment from orders) or p_comment in (select r_comment from region))")
+
+    assertSqlSubstraitRelRoundTrip(
+      "select p_retailprice, sum(p_retailprice) from part " +
+        "where p_partkey > 0 or p_comment in (select o_comment from orders) " +
+        "group by p_retailprice")
+
+    assertSqlSubstraitRelRoundTrip(
+      "select p_retailprice from part where p_partkey > 0 or " +
+        "(p_comment, p_retailprice) in (select o_comment, o_totalprice from orders)"
+    )
+
+    assertSqlSubstraitRelRoundTrip(
+      "select p_retailprice from part where p_partkey > 0 or " +
+        "(upper(p_comment), p_retailprice * 1.2) in (select upper(o_comment), o_totalprice from orders)"
+    )
+  }
+
   test("tpch_q1_variant") {
     // difference from tpch_q1 : 1) remove order by clause; 2) remove interval date literal
     assertSqlSubstraitRelRoundTrip(
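Note on the ExistenceJoin tests above: for the first query, Catalyst plans the disjunctive IN-subquery roughly as sketched below (plan shape only; part/orders stand in for the resolved scans and the final projection is omitted):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.plans.ExistenceJoin
    import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, JoinHint, LogicalPlan}
    import org.apache.spark.sql.types.BooleanType

    // select p_retailprice from part where p_partkey > 0 or p_comment in (select o_comment from orders)
    def planShape(part: LogicalPlan, orders: LogicalPlan,
        pPartkey: Attribute, pComment: Attribute, oComment: Attribute): LogicalPlan = {
      val exists = AttributeReference("exists", BooleanType, nullable = false)()
      val join = Join(part, orders, ExistenceJoin(exists), Some(EqualTo(pComment, oComment)), JoinHint.NONE)
      Filter(Or(GreaterThan(pPartkey, Literal(0)), exists), join)
    }

ToSubstraitRel now folds that shape into a Substrait Filter over part whose condition is p_partkey > 0 OR InPredicate([p_comment], <orders>), and ToLogicalPlan reconstructs the same ExistenceJoin on the way back, which is what these round-trip tests exercise.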