-
Notifications
You must be signed in to change notification settings - Fork 77
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(spark): add Window support #307
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -80,9 +80,22 @@ class FunctionMappings { | |
s[HyperLogLogPlusPlus]("approx_count_distinct") | ||
) | ||
|
||
// Spark window-function expressions mapped to their Substrait function names.
// NOTE(review): Spark's rank/dense_rank/percent_rank carry child columns while
// the Substrait declarations take no arguments, and Lead/Lag may carry a
// NullType default child — confirm both are handled by the matching logic.
val WINDOW_SIGS: Seq[Sig] = Seq(
  s[RowNumber]("row_number"),
  s[Rank]("rank"),
  s[DenseRank]("dense_rank"),
  s[PercentRank]("percent_rank"),
  s[CumeDist]("cume_dist"),
  s[NTile]("ntile"),
  s[Lead]("lead"),
  s[Lag]("lag"),
  s[NthValue]("nth_value")
)
|
||
// Lazily-built lookup tables keyed by the Spark expression class of each signature.
lazy val scalar_functions_map: Map[Class[_], Sig] = SCALAR_SIGS.map(s => (s.expClass, s)).toMap
lazy val aggregate_functions_map: Map[Class[_], Sig] =
  AGGREGATE_SIGS.map(s => (s.expClass, s)).toMap
lazy val window_functions_map: Map[Class[_], Sig] = WINDOW_SIGS.map(s => (s.expClass, s)).toMap
} | ||
|
||
object FunctionMappings extends FunctionMappings |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package io.substrait.spark.expression | ||
|
||
import io.substrait.spark.expression.ToWindowFunction.fromSpark | ||
|
||
import org.apache.spark.sql.catalyst.expressions.{CurrentRow, Expression, FrameType, Literal, OffsetWindowFunction, RangeFrame, RowFrame, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, UnspecifiedFrame, WindowExpression, WindowFrame, WindowSpecDefinition} | ||
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression | ||
import org.apache.spark.sql.types.{IntegerType, LongType} | ||
|
||
import io.substrait.`type`.Type | ||
import io.substrait.expression.{Expression => SExpression, ExpressionCreator, FunctionArg, WindowBound} | ||
import io.substrait.expression.Expression.WindowBoundsType | ||
import io.substrait.expression.WindowBound.{CURRENT_ROW, UNBOUNDED, WindowBoundVisitor} | ||
import io.substrait.extension.SimpleExtension | ||
import io.substrait.relation.ConsistentPartitionWindow.WindowRelFunctionInvocation | ||
|
||
import scala.collection.JavaConverters | ||
|
||
abstract class ToWindowFunction(functions: Seq[SimpleExtension.WindowFunctionVariant])
  extends FunctionConverter[SimpleExtension.WindowFunctionVariant, WindowRelFunctionInvocation](
    functions) {

  /**
   * Builds a Substrait window-relation function invocation for the given Spark
   * window expression, deriving the frame type and bounds from its window spec.
   */
  override def generateBinding(
      sparkExp: Expression,
      function: SimpleExtension.WindowFunctionVariant,
      arguments: Seq[FunctionArg],
      outputType: Type): WindowRelFunctionInvocation = {

    // Resolve (frame type, lower bound, upper bound) from the window spec.
    val (frameType, lower, upper) = sparkExp match {
      case WindowExpression(_, WindowSpecDefinition(_, _, SpecifiedWindowFrame(ft, lo, hi))) =>
        (fromSpark(ft), fromSpark(lo), fromSpark(hi))
      case WindowExpression(_, WindowSpecDefinition(_, orderSpec, UnspecifiedFrame)) =>
        // Spark's documented defaults for an unspecified frame (see
        // org.apache.spark.sql.expressions.Window): without ordering the frame
        // is ROWS over the whole partition, with ordering it is
        // RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW.
        if (orderSpec.isEmpty) (WindowBoundsType.ROWS, UNBOUNDED, UNBOUNDED)
        else (WindowBoundsType.RANGE, UNBOUNDED, CURRENT_ROW)
      case other =>
        throw new UnsupportedOperationException(s"Unsupported window expression: $other")
    }

    ExpressionCreator.windowRelFunction(
      function,
      outputType,
      SExpression.AggregationPhase.INITIAL_TO_RESULT, // use defaults...
      SExpression.AggregationInvocation.ALL, // Spark doesn't define these
      frameType,
      lower,
      upper,
      JavaConverters.asJavaIterable(arguments)
    )
  }

  /**
   * Attempts to bind the window expression's function (unwrapping an
   * AggregateExpression when present) against the known signatures.
   * Returns None when no binding matches.
   */
  def convert(
      expression: WindowExpression,
      operands: Seq[SExpression]): Option[WindowRelFunctionInvocation] = {
    val functionClass = expression.windowFunction match {
      case agg: AggregateExpression => agg.aggregateFunction.getClass
      case other => other.getClass
    }
    Option(signatures.get(functionClass)).flatMap(_.attemptMatch(expression, operands))
  }

  /** Like [[convert]], but fails loudly when no binding can be found. */
  def apply(
      expression: WindowExpression,
      operands: Seq[SExpression]): WindowRelFunctionInvocation =
    convert(expression, operands).getOrElse(
      throw new UnsupportedOperationException(
        s"Unable to find binding for call ${expression.windowFunction} -- $operands -- $expression"))
}
|
||
object ToWindowFunction {

  /** Maps a Spark frame type to the Substrait window bounds type. */
  def fromSpark(frameType: FrameType): WindowBoundsType = frameType match {
    case RowFrame => WindowBoundsType.ROWS
    case RangeFrame => WindowBoundsType.RANGE
    case other => throw new UnsupportedOperationException(s"Unsupported bounds type: $other.")
  }

  /**
   * Maps a Spark frame-boundary expression to a Substrait [[WindowBound]].
   *
   * Substrait uses a single UNBOUNDED bound for both directions, so
   * UnboundedPreceding and UnboundedFollowing collapse to the same value;
   * direction is recovered later from the bound's position (lower vs upper).
   */
  def fromSpark(bound: Expression): WindowBound = bound match {
    case UnboundedPreceding => WindowBound.UNBOUNDED
    case UnboundedFollowing => WindowBound.UNBOUNDED
    case CurrentRow => WindowBound.CURRENT_ROW
    case e: Literal =>
      e.dataType match {
        case IntegerType | LongType =>
          // A LongType literal evaluates to java.lang.Long, which the previous
          // asInstanceOf[Int] cast would reject with a ClassCastException;
          // go through Number so both Int and Long literals are accepted.
          val offset = e.eval().asInstanceOf[Number].longValue()
          if (offset < 0) WindowBound.Preceding.of(-offset)
          else if (offset == 0) WindowBound.CURRENT_ROW
          else WindowBound.Following.of(offset)
        case other =>
          // Explicit failure instead of a bare MatchError for unsupported literal types.
          throw new UnsupportedOperationException(s"Unexpected bound: $bound")
      }
    case _ => throw new UnsupportedOperationException(s"Unexpected bound: $bound")
  }

  /**
   * Rebuilds a Spark window frame from Substrait bounds. An UNSPECIFIED bounds
   * type yields UnspecifiedFrame so Spark applies its own frame defaults.
   */
  def toSparkFrame(
      boundsType: WindowBoundsType,
      lowerBound: WindowBound,
      upperBound: WindowBound): WindowFrame =
    boundsType match {
      case WindowBoundsType.UNSPECIFIED => UnspecifiedFrame
      case WindowBoundsType.ROWS =>
        SpecifiedWindowFrame(
          RowFrame,
          toSparkBound(lowerBound, isLower = true),
          toSparkBound(upperBound, isLower = false))
      case WindowBoundsType.RANGE =>
        SpecifiedWindowFrame(
          RangeFrame,
          toSparkBound(lowerBound, isLower = true),
          toSparkBound(upperBound, isLower = false))
    }

  // Converts one Substrait bound back to its Spark expression; `isLower`
  // disambiguates the direction of the shared UNBOUNDED value.
  private def toSparkBound(bound: WindowBound, isLower: Boolean): Expression = {
    bound.accept(new WindowBoundVisitor[Expression, Exception] {

      override def visit(preceding: WindowBound.Preceding): Expression =
        Literal(-preceding.offset().intValue())

      override def visit(following: WindowBound.Following): Expression =
        Literal(following.offset().intValue())

      override def visit(currentRow: WindowBound.CurrentRow): Expression = CurrentRow

      override def visit(unbounded: WindowBound.Unbounded): Expression =
        if (isLower) UnboundedPreceding else UnboundedFollowing
    })
  }

  /**
   * Creates a converter covering both dedicated window functions and aggregate
   * functions used in a window context.
   */
  def apply(functions: Seq[SimpleExtension.WindowFunctionVariant]): ToWindowFunction = {
    new ToWindowFunction(functions) {
      override def getSigs: Seq[Sig] =
        FunctionMappings.WINDOW_SIGS ++ FunctionMappings.AGGREGATE_SIGS
    }
  }

}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FWIW, in our internal fork I just deleted the whole RelToVerboseString thing. I don't see it bringing that much value over pretty-printing just the protobuf. And there's a fair amount of work to maintain this.