This is an automated email from the ASF dual-hosted git repository.
lincoln pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 475f45ba0fb [FLINK-27519][table-planner] Fix column name conflicts in StreamPhysicalOverAggregate
475f45ba0fb is described below
commit 475f45ba0fb78f81adaa627ed2e8fbdcd71b83f6
Author: lincoln lee <[email protected]>
AuthorDate: Tue Aug 6 19:30:20 2024 +0800
[FLINK-27519][table-planner] Fix column name conflicts in StreamPhysicalOverAggregate
This closes #25152
---
.../batch/BatchPhysicalOverAggregateRule.scala | 23 ++------------
.../stream/StreamPhysicalOverAggregateRule.scala | 12 +++++--
.../planner/plan/utils/OverAggregateUtil.scala | 23 +++++++++++++-
.../plan/batch/sql/agg/OverAggregateTest.xml | 37 ++++++++++++++++++++++
.../plan/stream/sql/agg/OverAggregateTest.xml | 35 ++++++++++++++++++++
.../plan/batch/sql/agg/OverAggregateTest.scala | 28 ++++++++++++++++
.../plan/stream/sql/agg/OverAggregateTest.scala | 27 ++++++++++++++++
7 files changed, 162 insertions(+), 23 deletions(-)
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalOverAggregateRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalOverAggregateRule.scala
index 6d20ce229ec..a428d03f95f 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalOverAggregateRule.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalOverAggregateRule.scala
@@ -26,14 +26,13 @@ import
org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalOverAggrega
import
org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalOverAggregate,
BatchPhysicalOverAggregateBase, BatchPhysicalPythonOverAggregate}
import org.apache.flink.table.planner.plan.utils.{AggregateUtil,
OverAggregateUtil, SortUtil}
import org.apache.flink.table.planner.plan.utils.PythonUtil.isPythonAggregate
-import org.apache.flink.table.planner.typeutils.RowTypeUtils
import org.apache.flink.table.planner.utils.ShortcutUtils
-import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelOptRuleCall, RelOptUtil}
+import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
import org.apache.calcite.plan.RelOptRule._
import org.apache.calcite.rel._
import org.apache.calcite.rel.`type`.RelDataType
-import org.apache.calcite.rel.core.{AggregateCall, Window}
+import org.apache.calcite.rel.core.Window
import org.apache.calcite.rel.core.Window.Group
import org.apache.calcite.rex.{RexInputRef, RexNode, RexShuttle}
import org.apache.calcite.sql.SqlAggFunction
@@ -107,7 +106,7 @@ class BatchPhysicalOverAggregateRule
(group, aggCallToAggFunction)
}
- val outputRowType = inferOutputRowType(
+ val outputRowType = OverAggregateUtil.inferOutputRowType(
logicWindow.getCluster,
inputRowType,
groupToAggCallToAggFunction.flatMap(_._2).map(_._1))
@@ -198,22 +197,6 @@ class BatchPhysicalOverAggregateRule
isSatisfied
}
- private def inferOutputRowType(
- cluster: RelOptCluster,
- inputType: RelDataType,
- aggCalls: Seq[AggregateCall]): RelDataType = {
-
- val inputNameList = inputType.getFieldNames
- val inputTypeList = inputType.getFieldList.asScala.map(field => field.getType)
-
- // we should avoid duplicated names with input column names
- val aggNames = RowTypeUtils.getUniqueName(aggCalls.map(_.getName), inputNameList)
- val aggTypes = aggCalls.map(_.getType)
-
- val typeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
- typeFactory.createStructType(inputTypeList ++ aggTypes, inputNameList ++ aggNames)
- }
-
private def adjustGroup(
groupBuffer: ArrayBuffer[Window.Group],
groupIdx: Int,
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalOverAggregateRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalOverAggregateRule.scala
index 60fdaceb93c..7004bcdf5e7 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalOverAggregateRule.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalOverAggregateRule.scala
@@ -22,6 +22,7 @@ import
org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution
import org.apache.flink.table.planner.plan.nodes.FlinkConventions
import
org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalOverAggregate
import
org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalOverAggregate
+import org.apache.flink.table.planner.plan.utils.OverAggregateUtil
import org.apache.flink.table.planner.plan.utils.PythonUtil.isPythonAggregate
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
@@ -66,13 +67,20 @@ class StreamPhysicalOverAggregateRule(config: Config)
extends ConverterRule(conf
.replace(FlinkConventions.STREAM_PHYSICAL)
.replace(requiredDistribution)
val providedTraitSet =
rel.getTraitSet.replace(FlinkConventions.STREAM_PHYSICAL)
- val newInput = RelOptRule.convert(logicWindow.getInput, requiredTraitSet)
+ val input = logicWindow.getInput
+ val newInput = RelOptRule.convert(input, requiredTraitSet)
+
+ val outputRowType = OverAggregateUtil.inferOutputRowType(
+ logicWindow.getCluster,
+ input.getRowType,
+ // only supports one group now
+ logicWindow.groups.get(0).getAggregateCalls(logicWindow).asScala)
new StreamPhysicalOverAggregate(
rel.getCluster,
providedTraitSet,
newInput,
- rel.getRowType,
+ outputRowType,
logicWindow)
}
}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/OverAggregateUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/OverAggregateUtil.scala
index c68d6abe100..b054d2af886 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/OverAggregateUtil.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/OverAggregateUtil.scala
@@ -19,16 +19,21 @@ package org.apache.flink.table.planner.plan.utils
import org.apache.flink.table.api.TableException
import org.apache.flink.table.planner.JArrayList
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.nodes.exec.spec.{OverSpec,
PartitionSpec}
import org.apache.flink.table.planner.plan.nodes.exec.spec.OverSpec.GroupSpec
+import org.apache.flink.table.planner.typeutils.RowTypeUtils
+import org.apache.calcite.plan.RelOptCluster
+import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.{RelCollation, RelCollations, RelFieldCollation}
import org.apache.calcite.rel.RelFieldCollation.{Direction, NullDirection}
-import org.apache.calcite.rel.core.Window
+import org.apache.calcite.rel.core.{AggregateCall, Window}
import org.apache.calcite.rex.{RexInputRef, RexLiteral, RexWindowBound}
import org.apache.calcite.sql.`type`.SqlTypeName
import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
object OverAggregateUtil {
@@ -219,4 +224,20 @@ object OverAggregateUtil {
}
}
}
+
+ def inferOutputRowType(
+ cluster: RelOptCluster,
+ inputType: RelDataType,
+ aggCalls: Seq[AggregateCall]): RelDataType = {
+
+ val inputNameList = inputType.getFieldNames
+ val inputTypeList = inputType.getFieldList.asScala.map(_.getType)
+
+ // we should avoid duplicated names with input column names
+ val aggNames = RowTypeUtils.getUniqueName(aggCalls.map(_.getName), inputNameList)
+ val aggTypes = aggCalls.map(_.getType)
+
+ val typeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
+ typeFactory.createStructType(inputTypeList ++ aggTypes, inputNameList ++ aggNames)
+ }
}
diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml
index 0ca5ec28442..ed6b45f01fc 100644
--- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml
+++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.xml
@@ -735,6 +735,43 @@ Calc(select=[a, w0$o0 AS $1, w1$o0 AS $2])
+- Sort(orderBy=[b ASC, c ASC, a DESC])
+- Exchange(distribution=[hash[b]])
+- LegacyTableSourceScan(table=[[default_catalog,
default_database, MyTable, source: [TestTableSource(a, b, c)]]], fields=[a, b,
c])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testNestedOverAgg">
+ <Resource name="sql">
+ <![CDATA[
+SELECT *
+FROM (
+ SELECT
+ *, count(*) OVER (PARTITION BY a ORDER BY ts) AS c2
+ FROM (
+ SELECT
+ *, count(*) OVER (PARTITION BY a,b ORDER BY ts) AS c1
+ FROM src
+ )
+)
+]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], ts=[$2], c1=[$3], c2=[$4])
++- LogicalProject(a=[$0], b=[$1], ts=[$2], c1=[$3], c2=[COUNT() OVER
(PARTITION BY $0 ORDER BY $2 NULLS FIRST)])
+ +- LogicalProject(a=[$0], b=[$1], ts=[$2], c1=[COUNT() OVER (PARTITION BY
$0, $1 ORDER BY $2 NULLS FIRST)])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src]])
+]]>
+ </Resource>
+ <Resource name="optimized exec plan">
+ <![CDATA[
+OverAggregate(partitionBy=[a], orderBy=[ts ASC], window#0=[COUNT(*) AS w0$o0_0
RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, ts, w0$o0,
w0$o0_0])
++- Exchange(distribution=[forward])
+ +- Sort(orderBy=[a ASC, ts ASC])
+ +- Exchange(distribution=[hash[a]])
+ +- OverAggregate(partitionBy=[a, b], orderBy=[ts ASC],
window#0=[COUNT(*) AS w0$o0 RANG BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW],
select=[a, b, ts, w0$o0])
+ +- Exchange(distribution=[forward])
+ +- Sort(orderBy=[a ASC, b ASC, ts ASC])
+ +- Exchange(distribution=[hash[a, b]])
+ +- TableSourceScan(table=[[default_catalog,
default_database, src]], fields=[a, b, ts])
]]>
</Resource>
</TestCase>
diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/OverAggregateTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/OverAggregateTest.xml
index bab37227900..72e3bffb228 100644
--- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/OverAggregateTest.xml
+++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/OverAggregateTest.xml
@@ -16,6 +16,41 @@ See the License for the specific language governing
permissions and
limitations under the License.
-->
<Root>
+ <TestCase name="testNestedOverAgg">
+ <Resource name="sql">
+ <![CDATA[
+SELECT *
+FROM (
+ SELECT
+ *, count(*) OVER (PARTITION BY a ORDER BY ts) AS c2
+ FROM (
+ SELECT
+ *, count(*) OVER (PARTITION BY a,b ORDER BY ts) AS c1
+ FROM src
+ )
+)
+]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], ts=[$2], c1=[$3], c2=[$4])
++- LogicalProject(a=[$0], b=[$1], ts=[$2], c1=[$3], c2=[COUNT() OVER
(PARTITION BY $0 ORDER BY $2 NULLS FIRST)])
+ +- LogicalProject(a=[$0], b=[$1], ts=[$2], c1=[COUNT() OVER (PARTITION BY
$0, $1 ORDER BY $2 NULLS FIRST)])
+ +- LogicalWatermarkAssigner(rowtime=[ts], watermark=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src]])
+]]>
+ </Resource>
+ <Resource name="optimized exec plan">
+ <![CDATA[
+OverAggregate(partitionBy=[a], orderBy=[ts ASC], window=[ RANG BETWEEN
UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, ts, w0$o0, COUNT(*) AS
w0$o0_0])
++- Exchange(distribution=[hash[a]])
+ +- OverAggregate(partitionBy=[a, b], orderBy=[ts ASC], window=[ RANG
BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, b, ts, COUNT(*) AS
w0$o0])
+ +- Exchange(distribution=[hash[a, b]])
+ +- WatermarkAssigner(rowtime=[ts], watermark=[ts])
+ +- TableSourceScan(table=[[default_catalog, default_database,
src]], fields=[a, b, ts])
+]]>
+ </Resource>
+ </TestCase>
<TestCase name="testProctimeBoundedDistinctPartitionedRowOver">
<Resource name="sql">
<![CDATA[
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala
index f71325beb57..fb95adbf319 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/OverAggregateTest.scala
@@ -367,4 +367,32 @@ class OverAggregateTest extends TableTestBase {
() =>
util.verifyExecPlan("SELECT overAgg(b, a) FROM T GROUP BY TUMBLE(ts,
INTERVAL '2' HOUR)"))
}
+
+ @Test
+ def testNestedOverAgg(): Unit = {
+ util.addTable(s"""
+ |CREATE TEMPORARY TABLE src (
+ | a STRING,
+ | b STRING,
+ | ts TIMESTAMP_LTZ(3),
+ | watermark FOR ts as ts
+ |) WITH (
+ | 'connector' = 'values'
+ | ,'bounded' = 'true'
+ |)
+ |""".stripMargin)
+
+ util.verifyExecPlan(s"""
+ |SELECT *
+ |FROM (
+ | SELECT
+ | *, count(*) OVER (PARTITION BY a ORDER BY ts) AS c2
+ | FROM (
+ | SELECT
+ | *, count(*) OVER (PARTITION BY a,b ORDER BY ts) AS c1
+ | FROM src
+ | )
+ |)
+ |""".stripMargin)
+ }
}
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/OverAggregateTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/OverAggregateTest.scala
index 65e6fb40eb9..e290bbce225 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/OverAggregateTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/OverAggregateTest.scala
@@ -441,4 +441,31 @@ class OverAggregateTest extends TableTestBase {
util.verifyExecPlan(sqlQuery)
}
+
+ @Test
+ def testNestedOverAgg(): Unit = {
+ util.addTable(s"""
+ |CREATE TEMPORARY TABLE src (
+ | a STRING,
+ | b STRING,
+ | ts TIMESTAMP_LTZ(3),
+ | watermark FOR ts as ts
+ |) WITH (
+ | 'connector' = 'values'
+ |)
+ |""".stripMargin)
+
+ util.verifyExecPlan(s"""
+ |SELECT *
+ |FROM (
+ | SELECT
+ | *, count(*) OVER (PARTITION BY a ORDER BY ts) AS c2
+ | FROM (
+ | SELECT
+ | *, count(*) OVER (PARTITION BY a,b ORDER BY ts) AS c1
+ | FROM src
+ | )
+ |)
+ |""".stripMargin)
+ }
}