From 3498d44f5c0b59992fc29ad9d5dfc8b30502fe7f Mon Sep 17 00:00:00 2001 From: Andrew Coleman Date: Mon, 30 Sep 2024 10:55:39 +0100 Subject: [PATCH] =?UTF-8?q?feat(spark):=20add=20support=20for=20=E2=80=98o?= =?UTF-8?q?ffset=E2=80=99=20clause?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add missing support for the ‘offset’ clause in the spark module. Signed-off-by: Andrew Coleman --- .../spark/logical/ToLogicalPlan.scala | 19 +++++++--- .../spark/logical/ToSubstraitRel.scala | 36 +++++++++++++------ .../scala/io/substrait/spark/TPCHPlan.scala | 8 ++++- 3 files changed, 46 insertions(+), 17 deletions(-) diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala index 45b6c2205..547a4b501 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -160,11 +160,20 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] } override def visit(fetch: relation.Fetch): LogicalPlan = { val child = fetch.getInput.accept(this) - val limit = Literal(fetch.getCount.getAsLong.intValue(), IntegerType) - fetch.getOffset match { - case 1L => GlobalLimit(limitExpr = limit, child = child) - case -1L => LocalLimit(limitExpr = limit, child = child) - case _ => visitFallback(fetch) + val limit = fetch.getCount.getAsLong.intValue() + val offset = fetch.getOffset.intValue() + val toLiteral = (i: Int) => Literal(i, IntegerType) + if (limit >= 0) { + val limitExpr = toLiteral(limit) + if (offset > 0) { + GlobalLimit(limitExpr, + Offset(toLiteral(offset), + LocalLimit(toLiteral(offset + limit), child))) + } else { + GlobalLimit(limitExpr, LocalLimit(limitExpr, child)) + } + } else { + Offset(toLiteral(offset), child) } } override def visit(sort: relation.Sort): LogicalPlan = { diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index 08a06c2e4..e66126bb4 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -170,23 +170,37 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { case other => throw new UnsupportedOperationException(s"Unknown type: $other") } - private def fetchBuilder(limit: Long, global: Boolean): relation.ImmutableFetch.Builder = { - val offset = if (global) 1L else -1L - relation.Fetch - .builder() - .count(limit) + private def fetch(child: LogicalPlan, offset: Long, limit: Long = -1): relation.Fetch = { + relation.Fetch.builder() + .input(visit(child)) .offset(offset) + .count(limit) + .build() } + override def visitGlobalLimit(p: GlobalLimit): relation.Rel = { - fetchBuilder(asLong(p.limitExpr), global = true) - .input(visit(p.child)) - .build() + p match { + case OffsetAndLimit((offset, limit, child)) => fetch(child, offset, limit) + case GlobalLimit(IntegerLiteral(globalLimit), LocalLimit(IntegerLiteral(localLimit), child)) + if globalLimit == localLimit => fetch(child, 0, localLimit) + case _ => + throw new UnsupportedOperationException(s"Unable to convert the limit expression: $p") + } } override def visitLocalLimit(p: LocalLimit): relation.Rel = { - fetchBuilder(asLong(p.limitExpr), global = false) - .input(visit(p.child)) - .build() + val localLimit = asLong(p.limitExpr) + p.child match { + case OffsetAndLimit((offset, limit, child)) if localLimit >= limit => + fetch(child, offset, limit) + case GlobalLimit(IntegerLiteral(globalLimit), child) if localLimit >= globalLimit => + fetch(child, 0, globalLimit) + case _ => fetch(p.child, 0, localLimit) + } + } + + override def visitOffset(p: Offset): relation.Rel = { + fetch(p.child, asLong(p.offsetExpr)) } override def visitFilter(p: Filter): relation.Rel = { diff --git a/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala index 224ac2e8d..76c5b9a6a 100644 --- a/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala +++ b/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala @@ -73,10 +73,16 @@ class TPCHPlan extends TPCHBase with SubstraitPlanTestBase { "order by l_shipdate asc, l_discount desc nulls last") } - ignore("simpleOffsetClause") { // TODO need to implement the 'offset' clause for this to pass + test("simpleOffsetClause") { assertSqlSubstraitRelRoundTrip( "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " + "order by l_shipdate asc, l_discount desc limit 100 offset 1000") + assertSqlSubstraitRelRoundTrip( + "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " + + "order by l_shipdate asc, l_discount desc offset 1000") + assertSqlSubstraitRelRoundTrip( + "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " + + "order by l_shipdate asc, l_discount desc limit 100") } test("simpleTest") {