feat(spark): add Window support #307

Merged 1 commit on Oct 25, 2024
13 changes: 13 additions & 0 deletions spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala
@@ -152,6 +152,19 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] {
})
}

override def visit(window: ConsistentPartitionWindow): String = {
withBuilder(window, 10)(
builder => {
builder
.append("functions=")
.append(window.getWindowFunctions)
.append("partitions=")
.append(window.getPartitionExpressions)
.append("sorts=")
.append(window.getSorts)
})
}
Contributor:
FWIW, in our internal fork I just deleted the whole RelToVerboseString thing. I don't see it bringing that much value over pretty-printing just the protobuf. And there's a fair amount of work to maintain this.
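For reference, a minimal sketch of that alternative (not part of this PR): once a relation has been converted to its io.substrait.proto.Rel message, protobuf's TextFormat already gives a readable dump.

// Sketch only: pretty-print the Substrait protobuf instead of maintaining
// a hand-written verbose-string visitor. Assumes the io.substrait.proto.Rel
// message is already in hand.
import com.google.protobuf.TextFormat
import io.substrait.proto.Rel

def dumpRel(relProto: Rel): String =
  TextFormat.printer().printToString(relProto)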


override def visit(localFiles: LocalFiles): String = {
withBuilder(localFiles, 10)(
builder => {
6 changes: 5 additions & 1 deletion spark/src/main/scala/io/substrait/spark/SparkExtension.scala
@@ -16,7 +16,7 @@
*/
package io.substrait.spark

-import io.substrait.spark.expression.ToAggregateFunction
+import io.substrait.spark.expression.{ToAggregateFunction, ToWindowFunction}

import io.substrait.extension.SimpleExtension

@@ -43,4 +43,8 @@ object SparkExtension {

val toAggregateFunction: ToAggregateFunction = ToAggregateFunction(
JavaConverters.asScalaBuffer(EXTENSION_COLLECTION.aggregateFunctions()))

val toWindowFunction: ToWindowFunction = ToWindowFunction(
JavaConverters.asScalaBuffer(EXTENSION_COLLECTION.windowFunctions())
)
}
spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala
@@ -21,7 +21,7 @@ import io.substrait.spark.ToSubstraitType
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, TypeCoercion}
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, WindowExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.types.DataType

@@ -238,7 +238,6 @@ class FunctionFinder[F <: SimpleExtension.Function, T](
val parent: FunctionConverter[F, T]) {

def attemptMatch(expression: Expression, operands: Seq[SExpression]): Option[T] = {

val opTypes = operands.map(_.getType)
val outputType = ToSubstraitType.apply(expression.dataType, expression.nullable)
val opTypesStr = opTypes.map(t => t.accept(ToTypeString.INSTANCE))
@@ -250,17 +249,23 @@
.map(name + ":" + _)
.find(k => directMap.contains(k))

-    if (directMatchKey.isDefined) {
+    if (operands.isEmpty) {
+      val variant = directMap(name + ":")
+      variant.validateOutputType(JavaConverters.bufferAsJavaList(operands.toBuffer), outputType)
+      Option(parent.generateBinding(expression, variant, operands, outputType))
+    } else if (directMatchKey.isDefined) {
val variant = directMap(directMatchKey.get)
variant.validateOutputType(JavaConverters.bufferAsJavaList(operands.toBuffer), outputType)
val funcArgs: Seq[FunctionArg] = operands
Option(parent.generateBinding(expression, variant, funcArgs, outputType))
} else if (singularInputType.isDefined) {
-      val types = expression match {
-        case agg: AggregateExpression => agg.aggregateFunction.children.map(_.dataType)
-        case other => other.children.map(_.dataType)
+      val children = expression match {
+        case agg: AggregateExpression => agg.aggregateFunction.children
+        case win: WindowExpression => win.windowFunction.children
+        case other => other.children
       }
-      val nullable = expression.children.exists(e => e.nullable)
+      val types = children.map(_.dataType)
+      val nullable = children.exists(e => e.nullable)
FunctionFinder
.leastRestrictive(types)
.flatMap(
@@ -298,6 +303,4 @@ class FunctionFinder[F <: SimpleExtension.Function, T](
}
})
}

-  def allowedArgCount(count: Int): Boolean = true
}
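Aside: the new operands.isEmpty branch is what lets zero-argument window functions bind, and it is also how Spark's rank functions, which do carry children, still match Substrait's argument-less declarations. A simplified sketch of the compound-key scheme the directMap lookup uses (the real keys come from the SimpleExtension function variants):

// Compound keys are "<name>:<argtype>_<argtype>..."; with no arguments the
// key collapses to "<name>:", so an empty operand list matches e.g. "rank:"
// even though Spark's Rank expression has ORDER BY columns as children.
def compoundKey(name: String, argTypeStrings: Seq[String]): String =
  name + ":" + argTypeStrings.mkString("_")

compoundKey("row_number", Seq.empty) // "row_number:"
compoundKey("ntile", Seq("i32"))     // "ntile:i32"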
spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala
@@ -80,9 +80,22 @@ class FunctionMappings {
s[HyperLogLogPlusPlus]("approx_count_distinct")
)

val WINDOW_SIGS: Seq[Sig] = Seq(
s[RowNumber]("row_number"),
s[Rank]("rank"),
Contributor: From looking at our internal fork's window impl, the rank functions in Spark have some child column but Substrait doesn't define any. I wonder how it works here, i.e. how do you still get a match?

s[DenseRank]("dense_rank"),
s[PercentRank]("percent_rank"),
s[CumeDist]("cume_dist"),
s[NTile]("ntile"),
s[Lead]("lead"),
s[Lag]("lag"),
Contributor: I also have a note about Lead and Lag having a NullType null as a child, which I think Substrait doesn't support. Did you run into anything like that?

Contributor (Author): I haven't seen this - do you have a test case?

@Blizzara (Contributor), Oct 25, 2024:
Contributor (Author): Oh yes - not known for my memory! 😂

s[NthValue]("nth_value")
)

lazy val scalar_functions_map: Map[Class[_], Sig] = SCALAR_SIGS.map(s => (s.expClass, s)).toMap
lazy val aggregate_functions_map: Map[Class[_], Sig] =
AGGREGATE_SIGS.map(s => (s.expClass, s)).toMap
lazy val window_functions_map: Map[Class[_], Sig] = WINDOW_SIGS.map(s => (s.expClass, s)).toMap
}

object FunctionMappings extends FunctionMappings
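These signatures cover the common Spark window functions. A hypothetical snippet exercising a few of them (table and column names invented); note that in Spark an elided lag default shows up as Literal(null) of NullType, the third child the Lead/Lag thread above asks about.

// Hypothetical usage; every window function here has a WINDOW_SIGS entry.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{lag, rank, row_number}

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val df = Seq(("a", 1), ("a", 3), ("b", 2)).toDF("grp", "x")
val w = Window.partitionBy($"grp").orderBy($"x")

df.select(
  $"grp",
  $"x",
  row_number().over(w).as("rn"),
  rank().over(w).as("rk"),
  lag($"x", 1).over(w).as("prev") // default elided -> Literal(null): NullType
).show()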
spark/src/main/scala/io/substrait/spark/expression/ToAggregateFunction.scala
@@ -53,7 +53,6 @@ abstract class ToAggregateFunction(functions: Seq[SimpleExtension.AggregateFunct
expression: AggregateExpression,
operands: Seq[SExpression]): Option[AggregateFunctionInvocation] = {
Option(signatures.get(expression.aggregateFunction.getClass))
-      .filter(m => m.allowedArgCount(2))
.flatMap(m => m.attemptMatch(expression, operands))
}

spark/src/main/scala/io/substrait/spark/expression/ToScalarFunction.scala
@@ -42,7 +42,6 @@ abstract class ToScalarFunction(functions: Seq[SimpleExtension.ScalarFunctionVar

def convert(expression: Expression, operands: Seq[SExpression]): Option[SExpression] = {
Option(signatures.get(expression.getClass))
-      .filter(m => m.allowedArgCount(2))
.flatMap(m => m.attemptMatch(expression, operands))
}
}
spark/src/main/scala/io/substrait/spark/expression/ToWindowFunction.scala (new file)
@@ -0,0 +1,151 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.substrait.spark.expression

import io.substrait.spark.expression.ToWindowFunction.fromSpark

import org.apache.spark.sql.catalyst.expressions.{CurrentRow, Expression, FrameType, Literal, OffsetWindowFunction, RangeFrame, RowFrame, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, UnspecifiedFrame, WindowExpression, WindowFrame, WindowSpecDefinition}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.types.{IntegerType, LongType}

import io.substrait.`type`.Type
import io.substrait.expression.{Expression => SExpression, ExpressionCreator, FunctionArg, WindowBound}
import io.substrait.expression.Expression.WindowBoundsType
import io.substrait.expression.WindowBound.{CURRENT_ROW, UNBOUNDED, WindowBoundVisitor}
import io.substrait.extension.SimpleExtension
import io.substrait.relation.ConsistentPartitionWindow.WindowRelFunctionInvocation

import scala.collection.JavaConverters

abstract class ToWindowFunction(functions: Seq[SimpleExtension.WindowFunctionVariant])
extends FunctionConverter[SimpleExtension.WindowFunctionVariant, WindowRelFunctionInvocation](
functions) {

override def generateBinding(
sparkExp: Expression,
function: SimpleExtension.WindowFunctionVariant,
arguments: Seq[FunctionArg],
outputType: Type): WindowRelFunctionInvocation = {

val (frameType, lower, upper) = sparkExp match {
case WindowExpression(
_,
WindowSpecDefinition(_, _, SpecifiedWindowFrame(frameType, lower, upper))) =>
(fromSpark(frameType), fromSpark(lower), fromSpark(upper))
case WindowExpression(_, WindowSpecDefinition(_, orderSpec, UnspecifiedFrame)) =>
if (orderSpec.isEmpty) {
(WindowBoundsType.ROWS, UNBOUNDED, UNBOUNDED)
} else {
(WindowBoundsType.RANGE, UNBOUNDED, CURRENT_ROW)
}

case _ => throw new UnsupportedOperationException(s"Unsupported window expression: $sparkExp")
}

ExpressionCreator.windowRelFunction(
function,
outputType,
SExpression.AggregationPhase.INITIAL_TO_RESULT, // use defaults...
SExpression.AggregationInvocation.ALL, // Spark doesn't define these
frameType,
lower,
upper,
JavaConverters.asJavaIterable(arguments)
)
}

def convert(
expression: WindowExpression,
operands: Seq[SExpression]): Option[WindowRelFunctionInvocation] = {
val cls = expression.windowFunction match {
case agg: AggregateExpression => agg.aggregateFunction.getClass
case other => other.getClass
}

Option(signatures.get(cls))
.flatMap(m => m.attemptMatch(expression, operands))
}

def apply(
expression: WindowExpression,
operands: Seq[SExpression]): WindowRelFunctionInvocation = {
convert(expression, operands).getOrElse(throw new UnsupportedOperationException(
s"Unable to find binding for call ${expression.windowFunction} -- $operands -- $expression"))
}
}

object ToWindowFunction {
def fromSpark(frameType: FrameType): WindowBoundsType = frameType match {
case RowFrame => WindowBoundsType.ROWS
case RangeFrame => WindowBoundsType.RANGE
case other => throw new UnsupportedOperationException(s"Unsupported bounds type: $other.")
}

def fromSpark(bound: Expression): WindowBound = bound match {
case UnboundedPreceding => WindowBound.UNBOUNDED
case UnboundedFollowing => WindowBound.UNBOUNDED
case CurrentRow => WindowBound.CURRENT_ROW
case e: Literal =>
e.dataType match {
case IntegerType | LongType =>
val offset = e.eval().asInstanceOf[Int]
if (offset < 0) WindowBound.Preceding.of(-offset)
else if (offset == 0) WindowBound.CURRENT_ROW
else WindowBound.Following.of(offset)
}
case _ => throw new UnsupportedOperationException(s"Unexpected bound: $bound")
}

def toSparkFrame(
boundsType: WindowBoundsType,
lowerBound: WindowBound,
upperBound: WindowBound): WindowFrame = {
val frameType = boundsType match {
case WindowBoundsType.ROWS => RowFrame
case WindowBoundsType.RANGE => RangeFrame
case WindowBoundsType.UNSPECIFIED => return UnspecifiedFrame
}
SpecifiedWindowFrame(
frameType,
toSparkBound(lowerBound, isLower = true),
toSparkBound(upperBound, isLower = false))
}

private def toSparkBound(bound: WindowBound, isLower: Boolean): Expression = {
bound.accept(new WindowBoundVisitor[Expression, Exception] {

override def visit(preceding: WindowBound.Preceding): Expression =
Literal(-preceding.offset().intValue())

override def visit(following: WindowBound.Following): Expression =
Literal(following.offset().intValue())

override def visit(currentRow: WindowBound.CurrentRow): Expression = CurrentRow

override def visit(unbounded: WindowBound.Unbounded): Expression =
if (isLower) UnboundedPreceding else UnboundedFollowing
})
}

def apply(functions: Seq[SimpleExtension.WindowFunctionVariant]): ToWindowFunction = {
new ToWindowFunction(functions) {
override def getSigs: Seq[Sig] =
FunctionMappings.WINDOW_SIGS ++ FunctionMappings.AGGREGATE_SIGS
}
}

}
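A small worked example against the conversions above (a sketch, not from the PR): Spark encodes ROWS BETWEEN 2 PRECEDING AND CURRENT ROW as SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow), and the two fromSpark overloads plus toSparkFrame round-trip it.

// Sketch: round-trip a Spark frame through the Substrait window bounds.
import org.apache.spark.sql.catalyst.expressions.{CurrentRow, Literal, RowFrame, SpecifiedWindowFrame}

val frame = SpecifiedWindowFrame(RowFrame, Literal(-2), CurrentRow)

val boundsType = ToWindowFunction.fromSpark(frame.frameType) // WindowBoundsType.ROWS
val lower = ToWindowFunction.fromSpark(frame.lower)          // WindowBound.Preceding.of(2)
val upper = ToWindowFunction.fromSpark(frame.upper)          // WindowBound.CURRENT_ROW

assert(ToWindowFunction.toSparkFrame(boundsType, lower, upper) == frame)

Note also how generateBinding fills Spark's UnspecifiedFrame: with no ORDER BY it emits ROWS between UNBOUNDED and UNBOUNDED, otherwise RANGE between UNBOUNDED and CURRENT ROW, matching the SQL default frame.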
spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
@@ -119,6 +119,56 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
}
}

override def visit(window: relation.ConsistentPartitionWindow): LogicalPlan = {
val child = window.getInput.accept(this)
withChild(child) {
val partitions = window.getPartitionExpressions.asScala
.map(expr => expr.accept(expressionConverter))
val sortOrders = window.getSorts.asScala.map(toSortOrder)
val windowExpressions = window.getWindowFunctions.asScala
.map(
func => {
val arguments = func.arguments().asScala.zipWithIndex.map {
case (arg, i) =>
arg.accept(func.declaration(), i, expressionConverter)
}
val windowFunction = SparkExtension.toWindowFunction
.getSparkExpressionFromSubstraitFunc(func.declaration.key, func.outputType)
.map(sig => sig.makeCall(arguments))
.map {
case win: WindowFunction => win
case agg: AggregateFunction =>
AggregateExpression(
agg,
ToAggregateFunction.toSpark(func.aggregationPhase()),
ToAggregateFunction.toSpark(func.invocation()),
None)
}
.getOrElse({
val msg = String.format(
"Unable to convert Window function %s(%s).",
func.declaration.name,
func.arguments.asScala
.map {
case ea: exp.EnumArg => ea.value.toString
case e: SExpression => e.getType.accept(new StringTypeVisitor)
case t: Type => t.accept(new StringTypeVisitor)
case a => throw new IllegalStateException("Unexpected value: " + a)
}
.mkString(", ")
)
throw new IllegalArgumentException(msg)
})
val frame =
ToWindowFunction.toSparkFrame(func.boundsType(), func.lowerBound(), func.upperBound())
val spec = WindowSpecDefinition(partitions, sortOrders, frame)
WindowExpression(windowFunction, spec)
})
.map(toNamedExpression)
Window(windowExpressions, partitions, sortOrders, child)
}
}
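A minimal consumption sketch (not from the PR): given an already-deserialized io.substrait.relation.ConsistentPartitionWindow, the standard accept/visitor dispatch drives the conversion above.

// Sketch: Substrait window relation -> Spark LogicalPlan via this visitor.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import io.substrait.relation.ConsistentPartitionWindow

def toSparkPlan(windowRel: ConsistentPartitionWindow, spark: SparkSession): LogicalPlan =
  windowRel.accept(new ToLogicalPlan(spark)) // dispatches to the visit above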

override def visit(join: relation.Join): LogicalPlan = {
val left = join.getLeft.accept(this)
val right = join.getRight.accept(this)
@@ -162,6 +212,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
}
SortOrder(expression, direction, nullOrdering, Seq.empty)
}

override def visit(fetch: relation.Fetch): LogicalPlan = {
val child = fetch.getInput.accept(this)
val limit = fetch.getCount.getAsLong.intValue()
@@ -180,6 +231,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
Offset(toLiteral(offset), child)
}
}

override def visit(sort: relation.Sort): LogicalPlan = {
val child = sort.getInput.accept(this)
withChild(child) {