[FLINK-35976][table-planner] Fix column name conflicts in StreamPhysicalOverAggregate
lincoln-lil committed Aug 5, 2024
1 parent f4aee2a commit a9cddd1
Showing 7 changed files with 162 additions and 23 deletions.
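
Why the change, in brief: the stream rule previously handed rel.getRowType straight to StreamPhysicalOverAggregate, so when one over aggregate consumed the output of another (see the new testNestedOverAgg cases below), the outer aggregate's generated column was named w0$o0, colliding with the identically named column already produced by the inner aggregate. The commit moves inferOutputRowType from BatchPhysicalOverAggregateRule into OverAggregateUtil and calls it from both the batch and stream rules, so generated aggregate names are deduplicated against the input. The following is a minimal, standalone Scala sketch of that renaming idea; UniqueNameSketch and getUniqueNames are illustrative stand-ins for RowTypeUtils.getUniqueName, whose actual implementation is not part of this diff.

// Illustrative only: models the "avoid duplicated names with input column
// names" behavior referenced in inferOutputRowType below.
object UniqueNameSketch {
  def getUniqueNames(proposed: Seq[String], existing: Seq[String]): Seq[String] = {
    val taken = scala.collection.mutable.Set[String](existing: _*)
    proposed.map { name =>
      var candidate = name
      var i = 0
      while (taken(candidate)) { // name clash: try numeric suffixes
        candidate = s"${name}_$i"
        i += 1
      }
      taken += candidate
      candidate
    }
  }

  def main(args: Array[String]): Unit = {
    // The outer over aggregate's input already carries w0$o0 from the inner one.
    println(getUniqueNames(Seq("w0$o0"), Seq("a", "b", "ts", "w0$o0")))
    // List(w0$o0_0) -- matching the w0$o0_0 column in the new plan XML below
  }
}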
BatchPhysicalOverAggregateRule.scala

@@ -26,14 +26,13 @@ import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalOverAggregate
import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalOverAggregate, BatchPhysicalOverAggregateBase, BatchPhysicalPythonOverAggregate}
import org.apache.flink.table.planner.plan.utils.{AggregateUtil, OverAggregateUtil, SortUtil}
import org.apache.flink.table.planner.plan.utils.PythonUtil.isPythonAggregate
-import org.apache.flink.table.planner.typeutils.RowTypeUtils
import org.apache.flink.table.planner.utils.ShortcutUtils

-import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelOptRuleCall, RelOptUtil}
+import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
import org.apache.calcite.plan.RelOptRule._
import org.apache.calcite.rel._
import org.apache.calcite.rel.`type`.RelDataType
-import org.apache.calcite.rel.core.{AggregateCall, Window}
+import org.apache.calcite.rel.core.Window
import org.apache.calcite.rel.core.Window.Group
import org.apache.calcite.rex.{RexInputRef, RexNode, RexShuttle}
import org.apache.calcite.sql.SqlAggFunction
@@ -107,7 +106,7 @@ class BatchPhysicalOverAggregateRule
(group, aggCallToAggFunction)
}

-val outputRowType = inferOutputRowType(
+val outputRowType = OverAggregateUtil.inferOutputRowType(
logicWindow.getCluster,
inputRowType,
groupToAggCallToAggFunction.flatMap(_._2).map(_._1))
@@ -198,22 +197,6 @@ class BatchPhysicalOverAggregateRule
isSatisfied
}

-private def inferOutputRowType(
-cluster: RelOptCluster,
-inputType: RelDataType,
-aggCalls: Seq[AggregateCall]): RelDataType = {
-
-val inputNameList = inputType.getFieldNames
-val inputTypeList = inputType.getFieldList.asScala.map(field => field.getType)
-
-// we should avoid duplicated names with input column names
-val aggNames = RowTypeUtils.getUniqueName(aggCalls.map(_.getName), inputNameList)
-val aggTypes = aggCalls.map(_.getType)
-
-val typeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
-typeFactory.createStructType(inputTypeList ++ aggTypes, inputNameList ++ aggNames)
-}
-
private def adjustGroup(
groupBuffer: ArrayBuffer[Window.Group],
groupIdx: Int,
StreamPhysicalOverAggregateRule.scala

@@ -22,6 +22,7 @@ import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution
import org.apache.flink.table.planner.plan.nodes.FlinkConventions
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalOverAggregate
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalOverAggregate
+import org.apache.flink.table.planner.plan.utils.OverAggregateUtil
import org.apache.flink.table.planner.plan.utils.PythonUtil.isPythonAggregate

import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
@@ -66,13 +67,20 @@ class StreamPhysicalOverAggregateRule(config: Config) extends ConverterRule(config)
.replace(FlinkConventions.STREAM_PHYSICAL)
.replace(requiredDistribution)
val providedTraitSet = rel.getTraitSet.replace(FlinkConventions.STREAM_PHYSICAL)
-val newInput = RelOptRule.convert(logicWindow.getInput, requiredTraitSet)
+val input = logicWindow.getInput
+val newInput = RelOptRule.convert(input, requiredTraitSet)

+val outputRowType = OverAggregateUtil.inferOutputRowType(
+logicWindow.getCluster,
+input.getRowType,
+// only supports one group now
+logicWindow.groups.get(0).getAggregateCalls(logicWindow).asScala)
+
new StreamPhysicalOverAggregate(
rel.getCluster,
providedTraitSet,
newInput,
-rel.getRowType,
+outputRowType,
logicWindow)
}
}
OverAggregateUtil.scala

@@ -19,16 +19,21 @@ package org.apache.flink.table.planner.plan.utils

import org.apache.flink.table.api.TableException
import org.apache.flink.table.planner.JArrayList
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.nodes.exec.spec.{OverSpec, PartitionSpec}
import org.apache.flink.table.planner.plan.nodes.exec.spec.OverSpec.GroupSpec
+import org.apache.flink.table.planner.typeutils.RowTypeUtils

+import org.apache.calcite.plan.RelOptCluster
+import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.{RelCollation, RelCollations, RelFieldCollation}
import org.apache.calcite.rel.RelFieldCollation.{Direction, NullDirection}
-import org.apache.calcite.rel.core.Window
+import org.apache.calcite.rel.core.{AggregateCall, Window}
import org.apache.calcite.rex.{RexInputRef, RexLiteral, RexWindowBound}
import org.apache.calcite.sql.`type`.SqlTypeName

import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

object OverAggregateUtil {
@@ -219,4 +224,20 @@
}
}
}

+def inferOutputRowType(
+cluster: RelOptCluster,
+inputType: RelDataType,
+aggCalls: Seq[AggregateCall]): RelDataType = {
+
+val inputNameList = inputType.getFieldNames
+val inputTypeList = inputType.getFieldList.asScala.map(field => field.getType)
+
+// we should avoid duplicated names with input column names
+val aggNames = RowTypeUtils.getUniqueName(aggCalls.map(_.getName), inputNameList)
+val aggTypes = aggCalls.map(_.getType)
+
+val typeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
+typeFactory.createStructType(inputTypeList ++ aggTypes, inputNameList ++ aggNames)
+}
}
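
A pure-data sketch of what inferOutputRowType above assembles (field names only, no Calcite types): output fields are the input fields followed by the deduplicated aggregate fields, in that order, which is why the nested-over-aggregate plans below select [a, b, ts, w0$o0, w0$o0_0]. OutputNameOrderSketch is illustrative, not Flink code.

// Sketch under the assumption stated above: name assembly order only.
object OutputNameOrderSketch extends App {
  val inputNames = Seq("a", "b", "ts", "w0$o0") // output of the inner over aggregate
  val aggNames = Seq("w0$o0_0")                 // the outer COUNT(*) after renaming
  println(inputNames ++ aggNames)               // List(a, b, ts, w0$o0, w0$o0_0)
}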
OverAggregateTest.xml (batch plan resource)

@@ -735,6 +735,43 @@ Calc(select=[a, w0$o0 AS $1, w1$o0 AS $2])
+- Sort(orderBy=[b ASC, c ASC, a DESC])
+- Exchange(distribution=[hash[b]])
+- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b, c])
]]>
</Resource>
</TestCase>
<TestCase name="testNestedOverAgg">
<Resource name="sql">
<![CDATA[
SELECT *
FROM (
SELECT
*, count(*) OVER (PARTITION BY a ORDER BY ts) AS c2
FROM (
SELECT
*, count(*) OVER (PARTITION BY a,b ORDER BY ts) AS c1
FROM src
)
)
]]>
</Resource>
<Resource name="ast">
<![CDATA[
LogicalProject(a=[$0], b=[$1], ts=[$2], c1=[$3], c2=[$4])
+- LogicalProject(a=[$0], b=[$1], ts=[$2], c1=[$3], c2=[COUNT() OVER (PARTITION BY $0 ORDER BY $2 NULLS FIRST)])
+- LogicalProject(a=[$0], b=[$1], ts=[$2], c1=[COUNT() OVER (PARTITION BY $0, $1 ORDER BY $2 NULLS FIRST)])
+- LogicalTableScan(table=[[default_catalog, default_database, src]])
]]>
</Resource>
<Resource name="optimized exec plan">
<![CDATA[
OverAggregate(partitionBy=[a], orderBy=[ts ASC], window#0=[COUNT(*) AS w0$o0_0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, ts, w0$o0, w0$o0_0])
+- Exchange(distribution=[forward])
+- Sort(orderBy=[a ASC, ts ASC])
+- Exchange(distribution=[hash[a]])
+- OverAggregate(partitionBy=[a, b], orderBy=[ts ASC], window#0=[COUNT(*) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, ts, w0$o0])
+- Exchange(distribution=[forward])
+- Sort(orderBy=[a ASC, b ASC, ts ASC])
+- Exchange(distribution=[hash[a, b]])
+- TableSourceScan(table=[[default_catalog, default_database, src]], fields=[a, b, ts])
]]>
</Resource>
</TestCase>
OverAggregateTest.xml (stream plan resource)

@@ -16,6 +16,41 @@ See the License for the specific language governing permissions and
limitations under the License.
-->
<Root>
<TestCase name="testNestedOverAgg">
<Resource name="sql">
<![CDATA[
SELECT *
FROM (
SELECT
*, count(*) OVER (PARTITION BY a ORDER BY ts) AS c2
FROM (
SELECT
*, count(*) OVER (PARTITION BY a,b ORDER BY ts) AS c1
FROM src
)
)
]]>
</Resource>
<Resource name="ast">
<![CDATA[
LogicalProject(a=[$0], b=[$1], ts=[$2], c1=[$3], c2=[$4])
+- LogicalProject(a=[$0], b=[$1], ts=[$2], c1=[$3], c2=[COUNT() OVER (PARTITION BY $0 ORDER BY $2 NULLS FIRST)])
+- LogicalProject(a=[$0], b=[$1], ts=[$2], c1=[COUNT() OVER (PARTITION BY $0, $1 ORDER BY $2 NULLS FIRST)])
+- LogicalWatermarkAssigner(rowtime=[ts], watermark=[$2])
+- LogicalTableScan(table=[[default_catalog, default_database, src]])
]]>
</Resource>
<Resource name="optimized exec plan">
<![CDATA[
OverAggregate(partitionBy=[a], orderBy=[ts ASC], window=[ RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, ts, w0$o0, COUNT(*) AS w0$o0_0])
+- Exchange(distribution=[hash[a]])
+- OverAggregate(partitionBy=[a, b], orderBy=[ts ASC], window=[ RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, ts, COUNT(*) AS w0$o0])
+- Exchange(distribution=[hash[a, b]])
+- WatermarkAssigner(rowtime=[ts], watermark=[ts])
+- TableSourceScan(table=[[default_catalog, default_database, src]], fields=[a, b, ts])
]]>
</Resource>
</TestCase>
<TestCase name="testProctimeBoundedDistinctPartitionedRowOver">
<Resource name="sql">
<![CDATA[
OverAggregateTest.scala (batch)

@@ -367,4 +367,32 @@ class OverAggregateTest extends TableTestBase {
() =>
util.verifyExecPlan("SELECT overAgg(b, a) FROM T GROUP BY TUMBLE(ts, INTERVAL '2' HOUR)"))
}

+@Test
+def testNestedOverAgg(): Unit = {
+util.addTable(s"""
+|CREATE TEMPORARY TABLE src (
+| a STRING,
+| b STRING,
+| ts TIMESTAMP_LTZ(3),
+| watermark FOR ts as ts
+|) WITH (
+| 'connector' = 'values'
+| ,'bounded' = 'true'
+|)
+|""".stripMargin)
+
+util.verifyExecPlan(s"""
+|SELECT *
+|FROM (
+| SELECT
+| *, count(*) OVER (PARTITION BY a ORDER BY ts) AS c2
+| FROM (
+| SELECT
+| *, count(*) OVER (PARTITION BY a,b ORDER BY ts) AS c1
+| FROM src
+| )
+|)
+|""".stripMargin)
+}
}
OverAggregateTest.scala (stream)

@@ -441,4 +441,31 @@ class OverAggregateTest extends TableTestBase {

util.verifyExecPlan(sqlQuery)
}

+@Test
+def testNestedOverAgg(): Unit = {
+util.addTable(s"""
+|CREATE TEMPORARY TABLE src (
+| a STRING,
+| b STRING,
+| ts TIMESTAMP_LTZ(3),
+| watermark FOR ts as ts
+|) WITH (
+| 'connector' = 'values'
+|)
+|""".stripMargin)
+
+util.verifyExecPlan(s"""
+|SELECT *
+|FROM (
+| SELECT
+| *, count(*) OVER (PARTITION BY a ORDER BY ts) AS c2
+| FROM (
+| SELECT
+| *, count(*) OVER (PARTITION BY a,b ORDER BY ts) AS c1
+| FROM src
+| )
+|)
+|""".stripMargin)
+}
}
