feat(spark): support ExistenceJoin internal join type #333

Open · wants to merge 1 commit into base: main
18 changes: 17 additions & 1 deletion spark/src/main/scala/io/substrait/debug/ExpressionToString.scala
@@ -21,7 +21,7 @@ import io.substrait.spark.DefaultExpressionVisitor
import org.apache.spark.sql.catalyst.util.DateTimeUtils

import io.substrait.expression.{Expression, FieldReference}
import io.substrait.expression.Expression.{DateLiteral, DecimalLiteral, I32Literal, StrLiteral}
import io.substrait.expression.Expression.{DateLiteral, DecimalLiteral, I32Literal, I64Literal, StrLiteral}
Contributor

Unrelated to this PR, but FWIW I'd vote for removing these /debug/ things (and have done so in our fork); it's a lot of boilerplate code to maintain for not that much value 😅

Contributor (Author)

Yes, I have mixed feelings on this. I might remove it in a follow-up PR.

import io.substrait.function.ToTypeString
import io.substrait.util.DecimalUtil

@@ -43,6 +43,10 @@ class ExpressionToString extends DefaultExpressionVisitor[String] {
expr.value().toString
}

override def visit(expr: I64Literal): String = {
expr.value().toString
}

override def visit(expr: DateLiteral): String = {
DateTimeUtils.toJavaDate(expr.value()).toString
}
@@ -76,4 +80,16 @@ class ExpressionToString extends DefaultExpressionVisitor[String] {
override def visit(expr: Expression.EmptyMapLiteral): String = {
expr.toString
}

override def visit(expr: Expression.Cast): String = {
expr.getType.toString
}

override def visit(expr: Expression.InPredicate): String = {
expr.toString
}

override def visit(expr: Expression.ScalarSubquery): String = {
expr.toString
}
}
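For context, these debug renderers are invoked through the usual substrait-java visitor entry point (the same `accept` used elsewhere in this PR). A minimal usage sketch, hypothetical and not part of the diff:

// Hypothetical usage, assuming `substraitExpr` is an io.substrait.expression.Expression:
// val rendered: String = substraitExpr.accept(new ExpressionToString)
// println(rendered) // e.g. "42" for an I64Literal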
20 changes: 20 additions & 0 deletions spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala
@@ -165,6 +165,26 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] {
})
}

override def visit(set: Set): String = {
withBuilder(set, 8)(
builder => {
builder
.append("operation=")
.append(set.getSetOp)
})
}

override def visit(cross: Cross): String = {
withBuilder(cross, 10)(
builder => {
builder
.append("left=")
.append(cross.getLeft)
.append("right=")
.append(cross.getRight)
})
}

override def visit(localFiles: LocalFiles): String = {
withBuilder(localFiles, 10)(
builder => {
7 changes: 7 additions & 0 deletions spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
@@ -68,6 +68,13 @@ private class ToSparkType

override def visit(expr: Type.IntervalYear): DataType = YearMonthIntervalType.DEFAULT

override def visit(expr: Type.Struct): DataType = {
StructType(
expr.fields.asScala.zipWithIndex
.map { case (t, i) => StructField(s"col${i + 1}", t.accept(this), t.nullable()) }
)
}

override def visit(expr: Type.ListType): DataType =
ArrayType(expr.elementType().accept(this), containsNull = expr.elementType().nullable())

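For context, Substrait struct types carry no field names, so the new Type.Struct mapping generates positional ones. A hypothetical illustration of the result (field types invented for the example):

// Hypothetical: a Substrait struct<i32 (required), string (nullable)> maps to
// StructType(Seq(
//   StructField("col1", IntegerType, nullable = false),
//   StructField("col2", StringType, nullable = true)))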
spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala
@@ -19,9 +19,11 @@ package io.substrait.spark.expression
import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, SparkExtension, ToSubstraitType}
import io.substrait.spark.logical.ToLogicalPlan

import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, CaseWhen, Cast, EqualTo, Expression, In, InSubquery, ListQuery, Literal, MakeDecimal, NamedExpression, ScalarSubquery}
import org.apache.spark.sql.catalyst.plans.ExistenceJoin
import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint, LogicalPlan}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DateType, Decimal}
import org.apache.spark.sql.types.{BooleanType, DateType, Decimal}
import org.apache.spark.substrait.SparkTypeUtil
import org.apache.spark.unsafe.types.UTF8String

@@ -204,6 +206,14 @@ class ToSparkExpression(
In(value, list)
}

override def visit(expr: SExpression.InPredicate): Expression = {
val needles = expr.needles().asScala.map(e => e.accept(this))
val haystack = expr.haystack().accept(toLogicalPlan.get)
new InSubquery(needles, ListQuery(haystack, childOutputs = haystack.output)) {
override def nullable: Boolean = expr.getType.nullable()
}
}

override def visit(expr: SExpression.ScalarFunctionInvocation): Expression = {
val eArgs = expr.arguments().asScala
val args = eArgs.zipWithIndex.map {
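As a simplified, hypothetical illustration of the InPredicate conversion added above:

// Substrait: InPredicate(needles = [p_comment], haystack = scan(orders))
// Spark:     InSubquery(Seq(p_comment), ListQuery(ordersPlan, childOutputs = ordersPlan.output))
// The anonymous subclass pins `nullable` to the Substrait type's nullability
// rather than letting Spark derive it from the child expressions.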
spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala
@@ -65,6 +65,8 @@ abstract class ToSubstraitExpression extends HasOutputStack[Seq[Attribute]] {

protected def translateSubQuery(expr: PlanExpression[_]): Option[SExpression] = default(expr)

protected def translateInSubquery(expr: InSubquery): Option[SExpression] = default(expr)

protected def translateAttribute(a: AttributeReference): Option[SExpression] = {
val bindReference =
BindReferences.bindReference[Expression](a, currentOutput, allowFailures = false)
@@ -141,10 +143,6 @@ abstract class ToSubstraitExpression extends HasOutputStack[Seq[Attribute]] {
case SubstraitLiteral(substraitLiteral) => Some(substraitLiteral)
case a: AttributeReference if currentOutput.nonEmpty => translateAttribute(a)
case a: Alias => translateUp(a.child)
case p
if p.getClass.getCanonicalName.equals( // removed in spark-3.3
"org.apache.spark.sql.catalyst.expressions.PromotePrecision") =>
translateUp(p.children.head)
case CaseWhen(branches, elseValue) => translateCaseWhen(branches, elseValue)
case In(value, list) => translateIn(value, list)
case InSet(value, set) => translateIn(value, set.toSeq.map(v => Literal(v)))
@@ -153,6 +151,7 @@ abstract class ToSubstraitExpression extends HasOutputStack[Seq[Attribute]] {
.seqToOption(children.map(translateUp))
.flatMap(toScalarFunction.convert(scalar, _))
case p: PlanExpression[_] => translateSubQuery(p)
case in: InSubquery => translateInSubquery(in)
case other => default(other)
}
}
spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
@@ -218,7 +218,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]

override def visit(fetch: relation.Fetch): LogicalPlan = {
val child = fetch.getInput.accept(this)
val limit = fetch.getCount.getAsLong.intValue()
val limit = fetch.getCount.orElse(-1).intValue() // -1 means unassigned here
val offset = fetch.getOffset.intValue()
val toLiteral = (i: Int) => Literal(i, IntegerType)
if (limit >= 0) {
107 changes: 86 additions & 21 deletions spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala
@@ -43,6 +43,7 @@ import io.substrait.relation.RelProtoConverter
import io.substrait.relation.Set.SetOp
import io.substrait.relation.files.{FileFormat, ImmutableFileOrFiles}
import io.substrait.relation.files.FileOrFiles.PathType
import io.substrait.utils.Util

import java.util.{Collections, Optional}

@@ -247,8 +248,9 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
}

override def visitFilter(p: Filter): relation.Rel = {
val input = visit(p.child)
val condition = toExpression(p.child.output)(p.condition)
relation.Filter.builder().condition(condition).input(visit(p.child)).build()
relation.Filter.builder().condition(condition).input(input).build()
}

private def toSubstraitJoin(joinType: JoinType): relation.Join.JoinType = joinType match {
@@ -262,33 +264,73 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
}

override def visitJoin(p: Join): relation.Rel = {
val left = visit(p.left)
val right = visit(p.right)
val condition = p.condition.map(toExpression(p.left.output ++ p.right.output)).getOrElse(TRUE)
val joinType = toSubstraitJoin(p.joinType)

if (joinType == relation.Join.JoinType.INNER && TRUE == condition) {
relation.Cross.builder
.left(left)
.right(right)
.build
} else {
relation.Join.builder
.condition(condition)
.joinType(joinType)
.left(left)
.right(right)
.build
p match {
case Join(left, right, ExistenceJoin(exists), where, _) =>
convertExistenceJoin(left, right, where, exists)
case _ =>
val left = visit(p.left)
val right = visit(p.right)
val condition =
p.condition.map(toExpression(p.left.output ++ p.right.output)).getOrElse(TRUE)
val joinType = toSubstraitJoin(p.joinType)

if (joinType == relation.Join.JoinType.INNER && TRUE == condition) {
relation.Cross.builder
.left(left)
.right(right)
.build
} else {
relation.Join.builder
.condition(condition)
.joinType(joinType)
.left(left)
.right(right)
.build
}
}
}

private def convertExistenceJoin(
left: LogicalPlan,
right: LogicalPlan,
condition: Option[Expression],
exists: Attribute): relation.Rel = {
// An ExistenceJoin is an internal Spark join type that is injected by the Catalyst
// optimiser. It doesn't directly map to any SQL join type, but can be modelled using
// a Substrait `InPredicate` within a Filter condition.

// The 'exists' attribute in the parent Filter's condition will be resolved to this
// InPredicate, via the existenceJoins map registered below.

// extract the needle expressions from the join condition
def findNeedles(expr: Expression): Iterable[Expression] = expr match {
case And(lhs, rhs) => findNeedles(lhs) ++ findNeedles(rhs)
case EqualTo(lhs, _) => Seq(lhs)
case _ =>
throw new UnsupportedOperationException(
s"Unable to convert the ExistenceJoin condition: $expr")
}
val needles = condition.toIterable
.flatMap(w => findNeedles(w))
.map(toExpression(left.output))
.asJava

val haystack = visit(right)

val inPredicate =
SExpression.InPredicate.builder().needles(needles).haystack(haystack).build()
// put this in a map exists->inPredicate for later lookup
toSubstraitExp.existenceJoins.put(exists.exprId.id, inPredicate)
// return the left child, which will become the child of the enclosing filter
visit(left)
}

override def visitProject(p: Project): relation.Rel = {
val expressions = p.projectList.map(toExpression(p.child.output)).toList

val child = visit(p.child)
relation.Project.builder
.remap(relation.Rel.Remap.offset(p.child.output.size, expressions.size))
.remap(relation.Rel.Remap.offset(child.getRecordType.fields.size, expressions.size))
.expressions(expressions.asJava)
.input(visit(p.child))
.input(child)
.build()
}

Expand Down Expand Up @@ -492,6 +534,8 @@ private[logical] class WithLogicalSubQuery(toSubstraitRel: ToSubstraitRel)
override protected val toScalarFunction: ToScalarFunction =
ToScalarFunction(SparkExtension.SparkScalarFunctions)

val existenceJoins = scala.collection.mutable.Map[Long, SExpression.InPredicate]()

override protected def translateSubQuery(expr: PlanExpression[_]): Option[SExpression] = {
expr match {
case s: ScalarSubquery if s.outerAttrs.isEmpty && s.joinCond.isEmpty =>
@@ -504,4 +548,25 @@
case other => default(other)
}
}

override protected def translateInSubquery(expr: InSubquery): Option[SExpression] = {
Util
.seqToOption(expr.values.map(translateUp).toList)
.flatMap(
values =>
Some(
SExpression.InPredicate
.builder()
.needles(values.asJava)
.haystack(toSubstraitRel.visit(expr.query.plan))
.build()
))
}

override protected def translateAttribute(a: AttributeReference): Option[SExpression] = {
existenceJoins.get(a.exprId.id) match {
case Some(exists) => Some(exists)
case None => super.translateAttribute(a)
}
}
}
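To illustrate the plan shape this conversion targets, here is a hypothetical, simplified sketch based on one of the new TPC-H test queries (attribute ids invented):

// SQL:  select p_retailprice from part
//       where p_partkey > 0 or p_comment in (select o_comment from orders)
//
// Catalyst optimised plan:
//   Filter (p_partkey > 0 OR exists#1)
//   +- Join ExistenceJoin(exists#1), (p_comment = o_comment)
//      :- Relation part
//      +- Relation orders
//
// Substrait plan produced by convertExistenceJoin + translateAttribute:
//   Filter(condition = or(gt(p_partkey, 0),
//                         InPredicate(needles = [p_comment], haystack = Scan(orders))),
//          input = Scan(part))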
37 changes: 24 additions & 13 deletions spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala
@@ -86,19 +86,30 @@ trait SubstraitPlanTestBase { self: SharedSparkSession =>

def assertSqlSubstraitRelRoundTrip(query: String): LogicalPlan = {
// TODO need a more robust way of testing this than round-tripping.
val logicalPlan = plan(query)
val pojoRel = new ToSubstraitRel().visit(logicalPlan)
val converter = new ToLogicalPlan(spark = spark);
val logicalPlan2 = pojoRel.accept(converter);
require(logicalPlan2.resolved);
val pojoRel2 = new ToSubstraitRel().visit(logicalPlan2)

val extensionCollector = new ExtensionCollector;
val proto = new RelProtoConverter(extensionCollector).toProto(pojoRel)
new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto)

pojoRel2.shouldEqualPlainly(pojoRel)
logicalPlan2
val sparkPlan = plan(query)

// convert spark logical plan to substrait
val substraitPlan = new ToSubstraitRel().visit(sparkPlan)

// Serialize to protobuf byte array
val extensionCollector = new ExtensionCollector
val bytes = new RelProtoConverter(extensionCollector).toProto(substraitPlan).toByteArray

// Read it back
val protoPlan = io.substrait.proto.Rel.parseFrom(bytes)
val substraitPlan2 =
new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(protoPlan)

Contributor

can we add a substraitPlan2.shouldEqualPlainly(substraitPlan) here?

Contributor (Author)

Could do, although the overall round trip is already checked at the end.

// convert substrait back to spark plan
val sparkPlan2 = substraitPlan2.accept(new ToLogicalPlan(spark))
require(sparkPlan2.resolved)

// and back to substrait again
val substraitPlan3 = new ToSubstraitRel().visit(sparkPlan2)
Contributor

why this additional conversion?

Contributor (Author)

I'm not really adding an extra conversion, I'm just moving the conversion to/from protobuf into the critical path. At the moment it just invokes the protobuf conversion but doesn't check it did the right thing.
This is not really core to this PR, though, so I'm happy to remove it if you prefer :)


// compare with original substrait plan to ensure it round-tripped (via proto bytes) correctly
substraitPlan3.shouldEqualPlainly(substraitPlan)
sparkPlan2
}

def plan(sql: String): LogicalPlan = {
3 changes: 1 addition & 2 deletions spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala
@@ -35,8 +35,7 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase {
val failingSQL: Set[String] = Set(
"q2", // because round() isn't defined in substrait to work with Decimal. https://github.com/substrait-io/substrait/pull/713
"q9", // requires implementation of named_struct()
"q10", "q35", "q45", // Unsupported join type ExistenceJoin (this is an internal spark type)
"q51", "q84", // TBD
"q35", "q51", "q84", // These fail when comparing the round-tripped query plans, but are actually equivalent (due to aliases being ignored by substrait)
"q72" //requires implementation of date_add()
)
// spotless:on
35 changes: 33 additions & 2 deletions spark/src/test/scala/io/substrait/spark/TPCHPlan.scala
@@ -16,8 +16,6 @@
*/
package io.substrait.spark

import io.substrait.spark.logical.{ToLogicalPlan, ToSubstraitRel}

import org.apache.spark.sql.TPCHBase

class TPCHPlan extends TPCHBase with SubstraitPlanTestBase {
@@ -178,6 +176,39 @@ class TPCHPlan extends TPCHBase with SubstraitPlanTestBase {
"((l_orderkey, L_COMMITDATE), (l_orderkey, L_COMMITDATE, l_linestatus), l_shipdate, ())")
}

test("in_subquery") {
assertSqlSubstraitRelRoundTrip(
"select p_retailprice from part where p_comment in (select o_comment from orders)")
}

test("exists") {
assertSqlSubstraitRelRoundTrip(
"select p_retailprice from part where exists (select o_comment from orders)")
}

test("ExistenceJoin") {
assertSqlSubstraitRelRoundTrip(
"select p_retailprice from part where p_partkey > 0 or p_comment in (select o_comment from orders)")

assertSqlSubstraitRelRoundTrip("select p_retailprice from part where p_partkey > 0 or " +
"(p_comment in (select o_comment from orders) or p_comment in (select r_comment from region))")

assertSqlSubstraitRelRoundTrip(
"select p_retailprice, sum(p_retailprice) from part " +
"where p_partkey > 0 or p_comment in (select o_comment from orders) " +
"group by p_retailprice")

assertSqlSubstraitRelRoundTrip(
"select p_retailprice from part where p_partkey > 0 or " +
"(p_comment, p_retailprice) in (select o_comment, o_totalprice from orders)"
)

assertSqlSubstraitRelRoundTrip(
"select p_retailprice from part where p_partkey > 0 or " +
"(upper(p_comment), p_retailprice * 1.2) in (select upper(o_comment), o_totalprice from orders)"
)
}

test("tpch_q1_variant") {
// difference from tpch_q1 : 1) remove order by clause; 2) remove interval date literal
assertSqlSubstraitRelRoundTrip(