From 03ac4c80b890a419649475968ba900b31c2bf9f3 Mon Sep 17 00:00:00 2001 From: Andrew Coleman Date: Wed, 2 Oct 2024 11:52:07 +0100 Subject: [PATCH] feat(spark): add Window support To support the OVER clause in SQL Signed-off-by: Andrew Coleman --- .../substrait/debug/RelToVerboseString.scala | 13 ++ .../io/substrait/spark/SparkExtension.scala | 6 +- .../spark/expression/FunctionConverter.scala | 19 ++- .../spark/expression/FunctionMappings.scala | 13 ++ .../spark/expression/ToWindowFunction.scala | 141 ++++++++++++++++++ .../spark/logical/ToLogicalPlan.scala | 55 +++++++ .../spark/logical/ToSubstraitRel.scala | 33 +++- .../main/scala/io/substrait/utils/Util.scala | 1 + .../scala/io/substrait/spark/TPCDSPlan.scala | 2 +- .../scala/io/substrait/spark/WindowPlan.scala | 67 +++++++++ 10 files changed, 340 insertions(+), 10 deletions(-) create mode 100644 spark/src/main/scala/io/substrait/spark/expression/ToWindowFunction.scala create mode 100644 spark/src/test/scala/io/substrait/spark/WindowPlan.scala diff --git a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala index 0ba749b9e..79b34462f 100644 --- a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala +++ b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala @@ -152,6 +152,19 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { }) } + override def visit(window: ConsistentPartitionWindow): String = { + withBuilder(window, 10)( + builder => { + builder + .append("functions=") + .append(window.getWindowFunctions) + .append("partitions=") + .append(window.getPartitionExpressions) + .append("sorts=") + .append(window.getSorts) + }) + } + override def visit(localFiles: LocalFiles): String = { withBuilder(localFiles, 10)( builder => { diff --git a/spark/src/main/scala/io/substrait/spark/SparkExtension.scala b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala index d61a06d3e..53b5bfaaf 100644 --- a/spark/src/main/scala/io/substrait/spark/SparkExtension.scala +++ b/spark/src/main/scala/io/substrait/spark/SparkExtension.scala @@ -16,7 +16,7 @@ */ package io.substrait.spark -import io.substrait.spark.expression.ToAggregateFunction +import io.substrait.spark.expression.{ToAggregateFunction, ToWindowFunction} import io.substrait.extension.SimpleExtension @@ -43,4 +43,8 @@ object SparkExtension { val toAggregateFunction: ToAggregateFunction = ToAggregateFunction( JavaConverters.asScalaBuffer(EXTENSION_COLLECTION.aggregateFunctions())) + + val toWindowFunction: ToWindowFunction = ToWindowFunction( + JavaConverters.asScalaBuffer(EXTENSION_COLLECTION.windowFunctions()) + ) } diff --git a/spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala b/spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala index 7b0c2b473..1dbbbcdc3 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/FunctionConverter.scala @@ -19,7 +19,7 @@ package io.substrait.spark.expression import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, TypeCoercion} -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, WindowExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.types.DataType @@ 
-237,7 +237,6 @@ class FunctionFinder[F <: SimpleExtension.Function, T]( val parent: FunctionConverter[F, T]) { def attemptMatch(expression: Expression, operands: Seq[SExpression]): Option[T] = { - val opTypes = operands.map(_.getType) val outputType = ToSubstraitType.apply(expression.dataType, expression.nullable) val opTypesStr = opTypes.map(t => t.accept(ToTypeString.INSTANCE)) @@ -249,17 +248,23 @@ class FunctionFinder[F <: SimpleExtension.Function, T]( .map(name + ":" + _) .find(k => directMap.contains(k)) - if (directMatchKey.isDefined) { + if (operands.isEmpty) { + val variant = directMap(name + ":") + variant.validateOutputType(JavaConverters.bufferAsJavaList(operands.toBuffer), outputType) + Option(parent.generateBinding(expression, variant, operands, outputType)) + } else if (directMatchKey.isDefined) { val variant = directMap(directMatchKey.get) variant.validateOutputType(JavaConverters.bufferAsJavaList(operands.toBuffer), outputType) val funcArgs: Seq[FunctionArg] = operands Option(parent.generateBinding(expression, variant, funcArgs, outputType)) } else if (singularInputType.isDefined) { - val types = expression match { - case agg: AggregateExpression => agg.aggregateFunction.children.map(_.dataType) - case other => other.children.map(_.dataType) + val children = expression match { + case agg: AggregateExpression => agg.aggregateFunction.children + case win: WindowExpression => win.windowFunction.children + case other => other.children } - val nullable = expression.children.exists(e => e.nullable) + val types = children.map(_.dataType) + val nullable = children.exists(e => e.nullable) FunctionFinder .leastRestrictive(types) .flatMap( diff --git a/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala b/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala index 27eebfc67..d6e36788e 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala @@ -72,9 +72,22 @@ class FunctionMappings { s[HyperLogLogPlusPlus]("approx_count_distinct") ) + val WINDOW_SIGS: Seq[Sig] = Seq( + s[RowNumber]("row_number"), + s[Rank]("rank"), + s[DenseRank]("dense_rank"), + s[PercentRank]("percent_rank"), + s[CumeDist]("cume_dist"), + s[NTile]("ntile"), + s[Lead]("lead"), + s[Lag]("lag"), + s[NthValue]("nth_value") + ) + lazy val scalar_functions_map: Map[Class[_], Sig] = SCALAR_SIGS.map(s => (s.expClass, s)).toMap lazy val aggregate_functions_map: Map[Class[_], Sig] = AGGREGATE_SIGS.map(s => (s.expClass, s)).toMap + lazy val window_functions_map: Map[Class[_], Sig] = WINDOW_SIGS.map(s => (s.expClass, s)).toMap } object FunctionMappings extends FunctionMappings diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToWindowFunction.scala b/spark/src/main/scala/io/substrait/spark/expression/ToWindowFunction.scala new file mode 100644 index 000000000..6b77e821a --- /dev/null +++ b/spark/src/main/scala/io/substrait/spark/expression/ToWindowFunction.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.substrait.spark.expression + +import io.substrait.spark.expression.ToWindowFunction.{fromSpark, fromSparkFollowing, fromSparkPreceding} + +import org.apache.spark.sql.catalyst.expressions.{CurrentRow, Expression, FrameType, Literal, OffsetWindowFunction, RangeFrame, RowFrame, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, UnspecifiedFrame, WindowExpression, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.types.{IntegerType, LongType} + +import io.substrait.`type`.Type +import io.substrait.expression.{Expression => SExpression, ExpressionCreator, FunctionArg, WindowBound} +import io.substrait.expression.Expression.WindowBoundsType +import io.substrait.expression.WindowBound.{CURRENT_ROW, Following, Preceding, UNBOUNDED} +import io.substrait.extension.SimpleExtension +import io.substrait.relation.ConsistentPartitionWindow.WindowRelFunctionInvocation + +import scala.collection.JavaConverters + +abstract class ToWindowFunction(functions: Seq[SimpleExtension.WindowFunctionVariant]) + extends FunctionConverter[SimpleExtension.WindowFunctionVariant, WindowRelFunctionInvocation]( + functions) { + + override def generateBinding( + sparkExp: Expression, + function: SimpleExtension.WindowFunctionVariant, + arguments: Seq[FunctionArg], + outputType: Type): WindowRelFunctionInvocation = { + + val (frameType, lower, upper) = sparkExp match { + case WindowExpression(_: OffsetWindowFunction, _) => + (WindowBoundsType.ROWS, UNBOUNDED, CURRENT_ROW) + case WindowExpression( + _, + WindowSpecDefinition(_, _, SpecifiedWindowFrame(frameType, lower, upper))) => + (fromSpark(frameType), fromSparkPreceding(lower), fromSparkFollowing(upper)) + case WindowExpression(_, WindowSpecDefinition(_, _, UnspecifiedFrame)) => + (WindowBoundsType.ROWS, UNBOUNDED, CURRENT_ROW) + case _ => throw new UnsupportedOperationException(s"Unsupported window expression: $sparkExp") + } + + ExpressionCreator.windowRelFunction( + function, + outputType, + SExpression.AggregationPhase.INITIAL_TO_RESULT, // use defaults... 
+ SExpression.AggregationInvocation.ALL, // Spark doesn't define these + frameType, + lower, + upper, + JavaConverters.asJavaIterable(arguments) + ) + } + + def convert( + expression: WindowExpression, + operands: Seq[SExpression]): Option[WindowRelFunctionInvocation] = { + val cls = expression.windowFunction match { + case agg: AggregateExpression => agg.aggregateFunction.getClass + case other => other.getClass + } + + Option(signatures.get(cls)) + .filter(m => m.allowedArgCount(2)) + .flatMap(m => m.attemptMatch(expression, operands)) + } + + def apply( + expression: WindowExpression, + operands: Seq[SExpression]): WindowRelFunctionInvocation = { + convert(expression, operands).getOrElse(throw new UnsupportedOperationException( + s"Unable to find binding for call ${expression.windowFunction} -- $operands -- $expression")) + } +} + +object ToWindowFunction { + def fromSpark(frameType: FrameType): WindowBoundsType = frameType match { + case RowFrame => WindowBoundsType.ROWS + case RangeFrame => WindowBoundsType.RANGE + case other => throw new UnsupportedOperationException(s"Unsupported bounds type: $other.") + } + + def toSpark(boundsType: WindowBoundsType): FrameType = boundsType match { + case WindowBoundsType.ROWS => RowFrame + case WindowBoundsType.RANGE => RangeFrame + case other => throw new UnsupportedOperationException(s"Unsupported bounds type: $other.") + } + + def fromSparkPreceding(bound: Expression): WindowBound = bound match { + case UnboundedPreceding => UNBOUNDED + case CurrentRow => CURRENT_ROW + case Literal(i: Int, IntegerType) => Preceding.of(i.toLong) + case Literal(l: Long, LongType) => Preceding.of(l) + case _ => + throw new UnsupportedOperationException(s"Unsupported bounds expression ${bound.getClass}") + } + + def fromSparkFollowing(bound: Expression): WindowBound = bound match { + case UnboundedFollowing => UNBOUNDED + case CurrentRow => CURRENT_ROW + case Literal(i: Int, IntegerType) => Following.of(i.toLong) + case Literal(l: Long, LongType) => Following.of(l) + case _ => + throw new UnsupportedOperationException(s"Unsupported bounds expression ${bound.getClass}") + } + + def toSparkPreceding(bound: WindowBound): Expression = bound match { + case UNBOUNDED => UnboundedPreceding + case CURRENT_ROW => CurrentRow + case p: Preceding => Literal(p.offset()) + case _ => throw new UnsupportedOperationException(s"Unsupported bounds expression $bound") + } + + def toSparkFollowing(bound: WindowBound): Expression = bound match { + case UNBOUNDED => UnboundedFollowing + case CURRENT_ROW => CurrentRow + case f: Following => Literal(f.offset()) + case _ => throw new UnsupportedOperationException(s"Unsupported bounds expression $bound") + } + + def apply(functions: Seq[SimpleExtension.WindowFunctionVariant]): ToWindowFunction = { + new ToWindowFunction(functions) { + override def getSigs: Seq[Sig] = + FunctionMappings.WINDOW_SIGS ++ FunctionMappings.AGGREGATE_SIGS + } + } + +} diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala index 3740babee..0cf5f46fe 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -116,6 +116,59 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] } } + override def visit(window: relation.ConsistentPartitionWindow): LogicalPlan = { + val child = window.getInput.accept(this) + withChild(child) { + val partitions = 
window.getPartitionExpressions.asScala + .map(expr => expr.accept(expressionConverter)) + val sortOrders = window.getSorts.asScala.map(toSortOrder) + val windowExpressions = window.getWindowFunctions.asScala + .map( + func => { + val arguments = func.arguments().asScala.zipWithIndex.map { + case (arg, i) => + arg.accept(func.declaration(), i, expressionConverter) + } + val windowFunction = SparkExtension.toWindowFunction + .getSparkExpressionFromSubstraitFunc(func.declaration.key, func.outputType) + .map(sig => sig.makeCall(arguments)) + .map { + case win: WindowFunction => win + case agg: AggregateFunction => + AggregateExpression( + agg, + ToAggregateFunction.toSpark(func.aggregationPhase()), + ToAggregateFunction.toSpark(func.invocation()), + None) + } + .getOrElse({ + val msg = String.format( + "Unable to convert Window function %s(%s).", + func.declaration.name, + func.arguments.asScala + .map { + case ea: exp.EnumArg => ea.value.toString + case e: SExpression => e.getType.accept(new StringTypeVisitor) + case t: Type => t.accept(new StringTypeVisitor) + case a => throw new IllegalStateException("Unexpected value: " + a) + } + .mkString(", ") + ) + throw new IllegalArgumentException(msg) + }) + val frame = SpecifiedWindowFrame( + ToWindowFunction.toSpark(func.boundsType), + ToWindowFunction.toSparkPreceding(func.lowerBound()), + ToWindowFunction.toSparkFollowing(func.upperBound()) + ) + val spec = WindowSpecDefinition(partitions, sortOrders, frame) + WindowExpression(windowFunction, spec) + }) + .map(toNamedExpression) + Window(windowExpressions, partitions, sortOrders, child) + } + } + override def visit(join: relation.Join): LogicalPlan = { val left = join.getLeft.accept(this) val right = join.getRight.accept(this) @@ -159,6 +212,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] } SortOrder(expression, direction, nullOrdering, Seq.empty) } + override def visit(fetch: relation.Fetch): LogicalPlan = { val child = fetch.getInput.accept(this) val limit = fetch.getCount.getAsLong.intValue() @@ -177,6 +231,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] Offset(toLiteral(offset), child) } } + override def visit(sort: relation.Sort): LogicalPlan = { val child = sort.getInput.accept(this) withChild(child) { diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index 46d00f8cd..216807718 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{NullType, StructType} import ToSubstraitType.toNamedStruct import io.substrait.{proto, relation} import io.substrait.debug.TreePrinter @@ -166,6 +166,37 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { .build() } + private def fromWindowCall( + expression: WindowExpression, + output: Seq[Attribute]): relation.ConsistentPartitionWindow.WindowRelFunctionInvocation = { + val children = expression.windowFunction match { + case agg: AggregateExpression => agg.aggregateFunction.children + 
case _: RankLike => Seq.empty + case other => other.children + } + val substraitExps = children.filter(_ != Literal(null, NullType)).map(toExpression(output)) + SparkExtension.toWindowFunction.apply(expression, substraitExps) + } + + override def visitWindow(window: Window): relation.Rel = { + val windowExpressions = window.windowExpressions.map { + case w: WindowExpression => fromWindowCall(w, window.child.output) + case a: Alias if a.child.isInstanceOf[WindowExpression] => + fromWindowCall(a.child.asInstanceOf[WindowExpression], window.child.output) + case other => + throw new UnsupportedOperationException(s"Unsupported window expression: $other") + }.asJava + + val partitionExpressions = window.partitionSpec.map(toExpression(window.child.output)).asJava + val sorts = window.orderSpec.map(toSortField(window.child.output)).asJava + relation.ConsistentPartitionWindow.builder + .input(visit(window.child)) + .addAllWindowFunctions(windowExpressions) + .addAllPartitionExpressions(partitionExpressions) + .addAllSorts(sorts) + .build() + } + private def asLong(e: Expression): Long = e match { case IntegerLiteral(limit) => limit case other => throw new UnsupportedOperationException(s"Unknown type: $other") diff --git a/spark/src/main/scala/io/substrait/utils/Util.scala b/spark/src/main/scala/io/substrait/utils/Util.scala index 165d59953..f7d373155 100644 --- a/spark/src/main/scala/io/substrait/utils/Util.scala +++ b/spark/src/main/scala/io/substrait/utils/Util.scala @@ -29,6 +29,7 @@ object Util { * Thomas Preissler */ def crossProduct[T](lists: Seq[Seq[T]]): Seq[Seq[T]] = { + if (lists.isEmpty) return lists /** list [a, b], element 1 => list + element => [a, b, 1] */ val appendElementToList: (Seq[T], T) => Seq[T] = diff --git a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala index 826d7200c..aa1cc8e18 100644 --- a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala +++ b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala @@ -56,7 +56,7 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase { } } - ignore("window") { + test("window") { val qry = s"""(SELECT | item_sk, | rank() diff --git a/spark/src/test/scala/io/substrait/spark/WindowPlan.scala b/spark/src/test/scala/io/substrait/spark/WindowPlan.scala new file mode 100644 index 000000000..769b97717 --- /dev/null +++ b/spark/src/test/scala/io/substrait/spark/WindowPlan.scala @@ -0,0 +1,67 @@ +package io.substrait.spark + +import io.substrait.spark.logical.{ToLogicalPlan, ToSubstraitRel} + +import org.apache.spark.sql.TPCBase +import org.apache.spark.sql.catalyst.TableIdentifier + +/** + * These tests are based on the examples in the Spark documentation on Window functions. 
+ * https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-window.html + */ +class WindowPlan extends TPCBase with SubstraitPlanTestBase { + override def beforeAll(): Unit = { + super.beforeAll() + sparkContext.setLogLevel("WARN") + } + + override protected def createTables(): Unit = { + spark.sql( + "CREATE TABLE employees (name STRING, dept STRING, salary INT, age INT) USING parquet;") + } + + override protected def dropTables(): Unit = { + spark.sessionState.catalog.dropTable(TableIdentifier("employees"), true, true) + } + + test("rank") { + val query = + """ + |SELECT name, dept, salary, RANK() OVER (PARTITION BY dept ORDER BY salary) AS rank FROM employees + | + |""".stripMargin + assertSqlSubstraitRelRoundTrip(query) + } + + test("cume_dist") { + val query = + """ + |SELECT name, dept, age, CUME_DIST() OVER (PARTITION BY dept ORDER BY age + | RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS cume_dist FROM employees + | + |""".stripMargin + assertSqlSubstraitRelRoundTrip(query) + } + + test("min") { + val query = + """ + |SELECT name, dept, salary, MIN(salary) OVER (PARTITION BY dept ORDER BY salary) AS min + | FROM employees + | + |""".stripMargin + assertSqlSubstraitRelRoundTrip(query) + } + + test("lag/lead") { + val query = + """ + |SELECT name, salary, + | LAG(salary) OVER (PARTITION BY dept ORDER BY salary) AS lag, + | LEAD(salary, 1, 0) OVER (PARTITION BY dept ORDER BY salary) AS lead + | FROM employees; + | + |""".stripMargin + assertSqlSubstraitRelRoundTrip(query) + } +}
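For illustration only (not part of the patch): the new WindowPlan tests drive the conversion in both directions through assertSqlSubstraitRelRoundTrip. The sketch below walks the same round trip by hand, assuming a live SparkSession named `spark` and the `employees` table created in WindowPlan.createTables(); it uses the visitor entry points visible in this diff (ToSubstraitRel / ToLogicalPlan), though the module's public helpers may wrap them differently.

    import org.apache.spark.sql.SparkSession
    import io.substrait.spark.logical.{ToLogicalPlan, ToSubstraitRel}

    object WindowRoundTripSketch {
      // Round-trips a query with an OVER clause: Spark plan -> Substrait rel -> Spark plan.
      // Assumes the `employees` table from the WindowPlan test setup already exists.
      def run(spark: SparkSession): Unit = {
        val plan = spark
          .sql(
            "SELECT name, dept, salary, " +
              "RANK() OVER (PARTITION BY dept ORDER BY salary) AS rnk FROM employees")
          .queryExecution
          .optimizedPlan

        // Spark -> Substrait: visitWindow emits a ConsistentPartitionWindow relation.
        val rel = new ToSubstraitRel().visit(plan)

        // Substrait -> Spark: visit(ConsistentPartitionWindow) rebuilds the Window node
        // with WindowExpression/WindowSpecDefinition children.
        val restored = rel.accept(new ToLogicalPlan(spark))
        println(restored.treeString)
      }
    }

The same pattern applies to the aggregate-over-window case (e.g. the MIN and LAG/LEAD test queries): fromWindowCall routes AggregateExpression children through the window-function mapping, so no separate aggregate handling is needed by callers.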