This is an automated email from the ASF dual-hosted git repository.

godfrey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 1eaf54b73de34fe86057675be090032746cd6049
Author: godfreyhe <[email protected]>
AuthorDate: Wed Dec 23 16:58:38 2020 +0800

    [FLINK-20737][table-planner-blink] Introduce 
StreamPhysicalIncrementalGroupAggregate, and make 
StreamExecIncrementalGroupAggregate extend only from ExecNode
    
    This closes #14478
---
 .../StreamExecIncrementalGroupAggregate.java       | 188 +++++++++++++++++++
 .../plan/metadata/FlinkRelMdColumnInterval.scala   |   8 +-
 .../metadata/FlinkRelMdModifiedMonotonicity.scala  |   2 +-
 .../StreamExecIncrementalGroupAggregate.scala      | 205 ---------------------
 .../StreamPhysicalGlobalGroupAggregate.scala       |   4 +-
 .../StreamPhysicalIncrementalGroupAggregate.scala  | 132 +++++++++++++
 .../physical/stream/IncrementalAggregateRule.scala |  73 +++-----
 .../table/planner/plan/utils/AggregateUtil.scala   |  59 +++++-
 8 files changed, 407 insertions(+), 264 deletions(-)

diff --git 
a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecIncrementalGroupAggregate.java
 
b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecIncrementalGroupAggregate.java
new file mode 100644
index 0000000..cab48e4
--- /dev/null
+++ 
b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecIncrementalGroupAggregate.java
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.nodes.exec.stream;
+
+import org.apache.flink.api.dag.Transformation;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.transformations.OneInputTransformation;
+import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
+import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator;
+import org.apache.flink.table.planner.delegation.PlannerBase;
+import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
+import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
+import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase;
+import org.apache.flink.table.planner.plan.utils.AggregateInfoList;
+import org.apache.flink.table.planner.plan.utils.AggregateUtil;
+import org.apache.flink.table.planner.plan.utils.KeySelectorUtil;
+import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
+import org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction;
+import org.apache.flink.table.runtime.keyselector.RowDataKeySelector;
+import 
org.apache.flink.table.runtime.operators.aggregate.MiniBatchIncrementalGroupAggFunction;
+import org.apache.flink.table.runtime.operators.bundle.KeyedMapBundleOperator;
+import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.RowType;
+
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.tools.RelBuilder;
+
+import java.util.Arrays;
+import java.util.Collections;
+
+/** Stream {@link ExecNode} for unbounded incremental group aggregate. */
+public class StreamExecIncrementalGroupAggregate extends ExecNodeBase<RowData>
+        implements StreamExecNode<RowData> {
+
+    /** The partial agg's grouping. */
+    private final int[] partialAggGrouping;
+    /** The final agg's grouping. */
+    private final int[] finalAggGrouping;
+    /** The partial agg's original agg calls. */
+    private final AggregateCall[] partialOriginalAggCalls;
+    /** Each element indicates whether the corresponding agg call needs 
`retract` method. */
+    private final boolean[] partialAggCallNeedRetractions;
+    /** The input row type of this node's partial local agg. */
+    private final RowType partialLocalAggInputRowType;
+    /** Whether this node consumes retraction messages. */
+    private final boolean partialAggNeedRetraction;
+
+    public StreamExecIncrementalGroupAggregate(
+            int[] partialAggGrouping,
+            int[] finalAggGrouping,
+            AggregateCall[] partialOriginalAggCalls,
+            boolean[] partialAggCallNeedRetractions,
+            RowType partialLocalAggInputRowType,
+            boolean partialAggNeedRetraction,
+            ExecEdge inputEdge,
+            RowType outputType,
+            String description) {
+        super(Collections.singletonList(inputEdge), outputType, description);
+        this.partialAggGrouping = partialAggGrouping;
+        this.finalAggGrouping = finalAggGrouping;
+        this.partialOriginalAggCalls = partialOriginalAggCalls;
+        this.partialAggCallNeedRetractions = partialAggCallNeedRetractions;
+        this.partialLocalAggInputRowType = partialLocalAggInputRowType;
+        this.partialAggNeedRetraction = partialAggNeedRetraction;
+    }
+
+    @SuppressWarnings("unchecked")
+    @Override
+    protected Transformation<RowData> translateToPlanInternal(PlannerBase 
planner) {
+        final TableConfig config = planner.getTableConfig();
+        final ExecNode<RowData> inputNode = (ExecNode<RowData>) 
getInputNodes().get(0);
+        final Transformation<RowData> inputTransform = 
inputNode.translateToPlan(planner);
+
+        final AggregateInfoList partialLocalAggInfoList =
+                AggregateUtil.createPartialAggInfoList(
+                        partialLocalAggInputRowType,
+                        
JavaScalaConversionUtil.toScala(Arrays.asList(partialOriginalAggCalls)),
+                        partialAggCallNeedRetractions,
+                        partialAggNeedRetraction,
+                        false);
+
+        final GeneratedAggsHandleFunction partialAggsHandler =
+                generateAggsHandler(
+                        "PartialGroupAggsHandler",
+                        partialLocalAggInfoList,
+                        partialAggGrouping.length,
+                        partialLocalAggInfoList.getAccTypes(),
+                        config,
+                        planner.getRelBuilder(),
+                        // the partial aggregate accumulators will be 
buffered, so need copy
+                        true);
+
+        final AggregateInfoList incrementalAggInfo =
+                AggregateUtil.createIncrementalAggInfoList(
+                        partialLocalAggInputRowType,
+                        
JavaScalaConversionUtil.toScala(Arrays.asList(partialOriginalAggCalls)),
+                        partialAggCallNeedRetractions,
+                        partialAggNeedRetraction);
+
+        final GeneratedAggsHandleFunction finalAggsHandler =
+                generateAggsHandler(
+                        "FinalGroupAggsHandler",
+                        incrementalAggInfo,
+                        0,
+                        partialLocalAggInfoList.getAccTypes(),
+                        config,
+                        planner.getRelBuilder(),
+                        // the final aggregate accumulators is not buffered
+                        false);
+
+        final RowDataKeySelector partialKeySelector =
+                KeySelectorUtil.getRowDataSelector(
+                        partialAggGrouping, 
InternalTypeInfo.of(inputNode.getOutputType()));
+        final RowDataKeySelector finalKeySelector =
+                KeySelectorUtil.getRowDataSelector(
+                        finalAggGrouping, 
partialKeySelector.getProducedType());
+
+        final MiniBatchIncrementalGroupAggFunction aggFunction =
+                new MiniBatchIncrementalGroupAggFunction(
+                        partialAggsHandler,
+                        finalAggsHandler,
+                        finalKeySelector,
+                        config.getIdleStateRetention().toMillis());
+
+        final OneInputStreamOperator<RowData, RowData> operator =
+                new KeyedMapBundleOperator<>(
+                        aggFunction, 
AggregateUtil.createMiniBatchTrigger(config));
+
+        // partitioned aggregation
+        final OneInputTransformation<RowData, RowData> transform =
+                new OneInputTransformation<>(
+                        inputTransform,
+                        getDesc(),
+                        operator,
+                        InternalTypeInfo.of(getOutputType()),
+                        inputTransform.getParallelism());
+
+        if (inputsContainSingleton()) {
+            transform.setParallelism(1);
+            transform.setMaxParallelism(1);
+        }
+
+        // set KeyType and Selector for state
+        transform.setStateKeySelector(partialKeySelector);
+        transform.setStateKeyType(partialKeySelector.getProducedType());
+        return transform;
+    }
+
+    private GeneratedAggsHandleFunction generateAggsHandler(
+            String name,
+            AggregateInfoList aggInfoList,
+            int mergedAccOffset,
+            DataType[] mergedAccExternalTypes,
+            TableConfig config,
+            RelBuilder relBuilder,
+            boolean inputFieldCopy) {
+
+        AggsHandlerCodeGenerator generator =
+                new AggsHandlerCodeGenerator(
+                        new CodeGeneratorContext(config),
+                        relBuilder,
+                        
JavaScalaConversionUtil.toScala(partialLocalAggInputRowType.getChildren()),
+                        inputFieldCopy);
+
+        return generator
+                .needMerge(mergedAccOffset, true, mergedAccExternalTypes)
+                .generateAggsHandler(name, aggInfoList);
+    }
+}
diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala
 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala
index 5153643..10ec872 100644
--- 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala
+++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala
@@ -538,7 +538,7 @@ class FlinkRelMdColumnInterval private extends 
MetadataHandler[ColumnInterval] {
       case agg: StreamPhysicalGroupAggregate => agg.grouping
       case agg: StreamPhysicalLocalGroupAggregate => agg.grouping
       case agg: StreamPhysicalGlobalGroupAggregate => agg.grouping
-      case agg: StreamExecIncrementalGroupAggregate => agg.partialAggGrouping
+      case agg: StreamPhysicalIncrementalGroupAggregate => 
agg.partialAggGrouping
       case agg: StreamExecGroupWindowAggregate => agg.getGrouping
       case agg: BatchExecGroupAggregateBase => agg.getGrouping ++ 
agg.getAuxGrouping
       case agg: Aggregate => AggregateUtil.checkAndGetFullGroupSet(agg)
@@ -608,9 +608,9 @@ class FlinkRelMdColumnInterval private extends 
MetadataHandler[ColumnInterval] {
             }
           case agg: StreamPhysicalLocalGroupAggregate =>
             getAggCallFromLocalAgg(aggCallIndex, agg.aggCalls, 
agg.getInput.getRowType)
-          case agg: StreamExecIncrementalGroupAggregate
-            if agg.partialAggInfoList.getActualAggregateCalls.length > 
aggCallIndex =>
-            agg.partialAggInfoList.getActualAggregateCalls(aggCallIndex)
+          case agg: StreamPhysicalIncrementalGroupAggregate
+            if agg.partialAggCalls.length > aggCallIndex =>
+            agg.partialAggCalls(aggCallIndex)
           case agg: StreamExecGroupWindowAggregate if agg.aggCalls.length > 
aggCallIndex =>
             agg.aggCalls(aggCallIndex)
           case agg: BatchExecLocalHashAggregate =>
diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala
 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala
index 272b557..9b0932a 100644
--- 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala
+++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala
@@ -303,7 +303,7 @@ class FlinkRelMdModifiedMonotonicity private extends 
MetadataHandler[ModifiedMon
   }
 
   def getRelModifiedMonotonicity(
-      rel: StreamExecIncrementalGroupAggregate,
+      rel: StreamPhysicalIncrementalGroupAggregate,
       mq: RelMetadataQuery): RelModifiedMonotonicity = {
     getRelModifiedMonotonicityOnAggregate(
       rel.getInput, mq, rel.finalAggCalls.toList, rel.finalAggGrouping)
diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecIncrementalGroupAggregate.scala
 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecIncrementalGroupAggregate.scala
deleted file mode 100644
index bcb6a93..0000000
--- 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecIncrementalGroupAggregate.scala
+++ /dev/null
@@ -1,205 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.planner.plan.nodes.physical.stream
-
-import org.apache.flink.api.dag.Transformation
-import org.apache.flink.streaming.api.transformations.OneInputTransformation
-import org.apache.flink.table.api.TableConfig
-import org.apache.flink.table.data.RowData
-import org.apache.flink.table.planner.calcite.FlinkTypeFactory
-import org.apache.flink.table.planner.codegen.CodeGeneratorContext
-import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator
-import org.apache.flink.table.planner.delegation.StreamPlanner
-import org.apache.flink.table.planner.plan.nodes.exec.LegacyStreamExecNode
-import org.apache.flink.table.planner.plan.utils.{KeySelectorUtil, _}
-import org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction
-import 
org.apache.flink.table.runtime.operators.aggregate.MiniBatchIncrementalGroupAggFunction
-import org.apache.flink.table.runtime.operators.bundle.KeyedMapBundleOperator
-import org.apache.flink.table.runtime.typeutils.InternalTypeInfo
-import org.apache.flink.table.types.DataType
-
-import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
-import org.apache.calcite.rel.`type`.RelDataType
-import org.apache.calcite.rel.core.AggregateCall
-import org.apache.calcite.rel.{RelNode, RelWriter}
-import org.apache.calcite.tools.RelBuilder
-
-import java.util
-
-import scala.collection.JavaConversions._
-
-/**
-  * Stream physical RelNode for unbounded incremental group aggregate.
-  *
-  * <p>Considering the following sub-plan:
-  * {{{
-  *   StreamPhysicalGlobalGroupAggregate (final-global-aggregate)
-  *   +- StreamPhysicalExchange
-  *      +- StreamPhysicalLocalGroupAggregate (final-local-aggregate)
-  *         +- StreamPhysicalGlobalGroupAggregate (partial-global-aggregate)
-  *            +- StreamPhysicalExchange
-  *               +- StreamPhysicalLocalGroupAggregate 
(partial-local-aggregate)
-  * }}}
-  *
-  * partial-global-aggregate and final-local-aggregate can be combined as
-  * this node to share [[org.apache.flink.api.common.state.State]].
-  * now the sub-plan is
-  * {{{
-  *   StreamPhysicalGlobalGroupAggregate (final-global-aggregate)
-  *   +- StreamPhysicalExchange
-  *      +- StreamExecIncrementalGroupAggregate
-  *         +- StreamPhysicalExchange
-  *            +- StreamPhysicalLocalGroupAggregate (partial-local-aggregate)
-  * }}}
-  *
-  * @see [[StreamPhysicalGroupAggregateBase]] for more info.
-  */
-class StreamExecIncrementalGroupAggregate(
-    cluster: RelOptCluster,
-    traitSet: RelTraitSet,
-    inputRel: RelNode,
-    inputRowType: RelDataType,
-    outputRowType: RelDataType,
-    val partialAggInfoList: AggregateInfoList,
-    finalAggInfoList: AggregateInfoList,
-    val finalAggCalls: Seq[AggregateCall],
-    val finalAggGrouping: Array[Int],
-    val partialAggGrouping: Array[Int])
-  extends StreamPhysicalGroupAggregateBase(cluster, traitSet, inputRel)
-  with LegacyStreamExecNode[RowData] {
-
-  override def deriveRowType(): RelDataType = outputRowType
-
-  override def requireWatermark: Boolean = false
-
-  override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): 
RelNode = {
-    new StreamExecIncrementalGroupAggregate(
-      cluster,
-      traitSet,
-      inputs.get(0),
-      inputRowType,
-      outputRowType,
-      partialAggInfoList,
-      finalAggInfoList,
-      finalAggCalls,
-      finalAggGrouping,
-      partialAggGrouping)
-  }
-
-  override def explainTerms(pw: RelWriter): RelWriter = {
-    super.explainTerms(pw)
-      .item("partialAggGrouping",
-        RelExplainUtil.fieldToString(partialAggGrouping, inputRel.getRowType))
-      .item("finalAggGrouping",
-    RelExplainUtil.fieldToString(finalAggGrouping, inputRel.getRowType))
-      .item("select", RelExplainUtil.streamGroupAggregationToString(
-        inputRel.getRowType,
-        getRowType,
-        finalAggInfoList,
-        finalAggGrouping,
-        shuffleKey = Some(partialAggGrouping)))
-  }
-
-  //~ ExecNode methods 
-----------------------------------------------------------
-
-  override protected def translateToPlanInternal(
-      planner: StreamPlanner): Transformation[RowData] = {
-    val config = planner.getTableConfig
-    val inputTransformation = getInputNodes.get(0).translateToPlan(planner)
-      .asInstanceOf[Transformation[RowData]]
-
-    val inRowType = FlinkTypeFactory.toLogicalRowType(inputRel.getRowType)
-    val outRowType = FlinkTypeFactory.toLogicalRowType(outputRowType)
-
-    val partialAggsHandler = generateAggsHandler(
-      "PartialGroupAggsHandler",
-      partialAggInfoList,
-      mergedAccOffset = partialAggGrouping.length,
-      partialAggInfoList.getAccTypes,
-      config,
-      planner.getRelBuilder,
-      // the partial aggregate accumulators will be buffered, so need copy
-      inputFieldCopy = true)
-
-    val finalAggsHandler = generateAggsHandler(
-      "FinalGroupAggsHandler",
-      finalAggInfoList,
-      mergedAccOffset = 0,
-      partialAggInfoList.getAccTypes,
-      config,
-      planner.getRelBuilder,
-      // the final aggregate accumulators is not buffered
-      inputFieldCopy = false)
-
-    val partialKeySelector = KeySelectorUtil.getRowDataSelector(
-      partialAggGrouping,
-      InternalTypeInfo.of(inRowType))
-    val finalKeySelector = KeySelectorUtil.getRowDataSelector(
-      finalAggGrouping,
-      partialKeySelector.getProducedType)
-
-    val aggFunction = new MiniBatchIncrementalGroupAggFunction(
-      partialAggsHandler,
-      finalAggsHandler,
-      finalKeySelector,
-      config.getIdleStateRetention.toMillis)
-
-    val operator = new KeyedMapBundleOperator(
-      aggFunction,
-      AggregateUtil.createMiniBatchTrigger(config))
-
-    // partitioned aggregation
-    val ret = new OneInputTransformation(
-      inputTransformation,
-      getRelDetailedDescription,
-      operator,
-      InternalTypeInfo.of(outRowType),
-      inputTransformation.getParallelism)
-
-    if (inputsContainSingleton()) {
-      ret.setParallelism(1)
-      ret.setMaxParallelism(1)
-    }
-
-    // set KeyType and Selector for state
-    ret.setStateKeySelector(partialKeySelector)
-    ret.setStateKeyType(partialKeySelector.getProducedType)
-    ret
-  }
-
-  def generateAggsHandler(
-    name: String,
-    aggInfoList: AggregateInfoList,
-    mergedAccOffset: Int,
-    mergedAccExternalTypes: Array[DataType],
-    config: TableConfig,
-    relBuilder: RelBuilder,
-    inputFieldCopy: Boolean): GeneratedAggsHandleFunction = {
-
-    val generator = new AggsHandlerCodeGenerator(
-      CodeGeneratorContext(config),
-      relBuilder,
-      FlinkTypeFactory.toLogicalRowType(inputRowType).getChildren,
-      inputFieldCopy)
-
-    generator
-      .needAccumulate()
-      .needMerge(mergedAccOffset, mergedAccOnHeap = true, 
mergedAccExternalTypes)
-      .generateAggsHandler(name, aggInfoList)
-  }
-}
diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalGroupAggregate.scala
 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalGroupAggregate.scala
index b4e0e2e..7cec0fc 100644
--- 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalGroupAggregate.scala
+++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalGroupAggregate.scala
@@ -40,9 +40,9 @@ class StreamPhysicalGlobalGroupAggregate(
     outputRowType: RelDataType,
     val grouping: Array[Int],
     val aggCalls: Seq[AggregateCall],
-    aggCallNeedRetractions: Array[Boolean],
+    val aggCallNeedRetractions: Array[Boolean],
     val localAggInputRowType: RelDataType,
-    needRetraction: Boolean,
+    val needRetraction: Boolean,
     val partialFinalType: PartialFinalType)
   extends StreamPhysicalGroupAggregateBase(cluster, traitSet, inputRel) {
 
diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalIncrementalGroupAggregate.scala
 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalIncrementalGroupAggregate.scala
new file mode 100644
index 0000000..cb203b7
--- /dev/null
+++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalIncrementalGroupAggregate.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.planner.plan.nodes.physical.stream
+
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory
+import 
org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecIncrementalGroupAggregate
+import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode}
+import org.apache.flink.table.planner.plan.utils._
+
+import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.core.AggregateCall
+import org.apache.calcite.rel.{RelNode, RelWriter}
+
+import java.util
+
+/**
+ * Stream physical RelNode for unbounded incremental group aggregate.
+ *
+ * <p>Considering the following sub-plan:
+ * {{{
+ *   StreamPhysicalGlobalGroupAggregate (final-global-aggregate)
+ *   +- StreamPhysicalExchange
+ *      +- StreamPhysicalLocalGroupAggregate (final-local-aggregate)
+ *         +- StreamPhysicalGlobalGroupAggregate (partial-global-aggregate)
+ *            +- StreamPhysicalExchange
+ *               +- StreamPhysicalLocalGroupAggregate (partial-local-aggregate)
+ * }}}
+ *
+ * partial-global-aggregate and final-local-aggregate can be combined as
+ * this node to share [[org.apache.flink.api.common.state.State]].
+ * now the sub-plan is
+ * {{{
+ *   StreamPhysicalGlobalGroupAggregate (final-global-aggregate)
+ *   +- StreamPhysicalExchange
+ *      +- StreamPhysicalIncrementalGroupAggregate
+ *         +- StreamPhysicalExchange
+ *            +- StreamPhysicalLocalGroupAggregate (partial-local-aggregate)
+ * }}}
+ *
+ * @see [[StreamPhysicalGroupAggregateBase]] for more info.
+ */
+class StreamPhysicalIncrementalGroupAggregate(
+    cluster: RelOptCluster,
+    traitSet: RelTraitSet,
+    inputRel: RelNode,
+    val partialAggGrouping: Array[Int],
+    val partialAggCalls: Array[AggregateCall],
+    val finalAggGrouping: Array[Int],
+    val finalAggCalls: Array[AggregateCall],
+    partialOriginalAggCalls: Array[AggregateCall],
+    partialAggCallNeedRetractions: Array[Boolean],
+    partialAggNeedRetraction: Boolean,
+    partialLocalAggInputRowType: RelDataType,
+    partialGlobalAggRowType: RelDataType)
+  extends StreamPhysicalGroupAggregateBase(cluster, traitSet, inputRel) {
+
+  private lazy val incrementalAggInfo = 
AggregateUtil.createIncrementalAggInfoList(
+    FlinkTypeFactory.toLogicalRowType(partialLocalAggInputRowType),
+    partialOriginalAggCalls,
+    partialAggCallNeedRetractions,
+    partialAggNeedRetraction)
+
+  override def deriveRowType(): RelDataType = {
+    AggregateUtil.inferLocalAggRowType(
+      incrementalAggInfo,
+      partialGlobalAggRowType,
+      finalAggGrouping,
+      getCluster.getTypeFactory.asInstanceOf[FlinkTypeFactory])
+  }
+
+  override def requireWatermark: Boolean = false
+
+  override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): 
RelNode = {
+    new StreamPhysicalIncrementalGroupAggregate(
+      cluster,
+      traitSet,
+      inputs.get(0),
+      partialAggGrouping,
+      partialAggCalls,
+      finalAggGrouping,
+      finalAggCalls,
+      partialOriginalAggCalls,
+      partialAggCallNeedRetractions,
+      partialAggNeedRetraction,
+      partialLocalAggInputRowType,
+      partialGlobalAggRowType)
+  }
+
+  override def explainTerms(pw: RelWriter): RelWriter = {
+    super.explainTerms(pw)
+      .item("partialAggGrouping",
+        RelExplainUtil.fieldToString(partialAggGrouping, inputRel.getRowType))
+      .item("finalAggGrouping",
+        RelExplainUtil.fieldToString(finalAggGrouping, inputRel.getRowType))
+      .item("select", RelExplainUtil.streamGroupAggregationToString(
+        inputRel.getRowType,
+        getRowType,
+        incrementalAggInfo,
+        finalAggGrouping,
+        shuffleKey = Some(partialAggGrouping)))
+  }
+
+  override def translateToExecNode(): ExecNode[_] = {
+    new StreamExecIncrementalGroupAggregate(
+      partialAggGrouping,
+      finalAggGrouping,
+      partialOriginalAggCalls,
+      partialAggCallNeedRetractions,
+      FlinkTypeFactory.toLogicalRowType(partialLocalAggInputRowType),
+      partialAggNeedRetraction,
+      ExecEdge.DEFAULT,
+      FlinkTypeFactory.toLogicalRowType(getRowType),
+      getRelDetailedDescription
+    )
+  }
+}
diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/IncrementalAggregateRule.scala
 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/IncrementalAggregateRule.scala
index ee8f074..235c441 100644
--- 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/IncrementalAggregateRule.scala
+++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/stream/IncrementalAggregateRule.scala
@@ -22,8 +22,8 @@ import org.apache.flink.configuration.ConfigOption
 import org.apache.flink.configuration.ConfigOptions.key
 import org.apache.flink.table.planner.calcite.{FlinkContext, FlinkTypeFactory}
 import org.apache.flink.table.planner.plan.PartialFinalType
-import 
org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamExecIncrementalGroupAggregate,
 StreamPhysicalExchange, StreamPhysicalGlobalGroupAggregate, 
StreamPhysicalLocalGroupAggregate}
-import org.apache.flink.table.planner.plan.utils.{AggregateInfoList, 
AggregateUtil, DistinctInfo}
+import 
org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamPhysicalExchange,
 StreamPhysicalGlobalGroupAggregate, StreamPhysicalIncrementalGroupAggregate, 
StreamPhysicalLocalGroupAggregate}
+import org.apache.flink.table.planner.plan.utils.AggregateUtil
 import org.apache.flink.util.Preconditions
 
 import org.apache.calcite.plan.RelOptRule.{any, operand}
@@ -37,7 +37,7 @@ import java.util.Collections
  * on final [[StreamPhysicalLocalGroupAggregate]] on partial 
[[StreamPhysicalGlobalGroupAggregate]],
  * and combines the final [[StreamPhysicalLocalGroupAggregate]] and
  * the partial [[StreamPhysicalGlobalGroupAggregate]] into a
- * [[StreamExecIncrementalGroupAggregate]].
+ * [[StreamPhysicalIncrementalGroupAggregate]].
  */
 class IncrementalAggregateRule
   extends RelOptRule(
@@ -69,55 +69,26 @@ class IncrementalAggregateRule
     val exchange: StreamPhysicalExchange = call.rel(1)
     val finalLocalAgg: StreamPhysicalLocalGroupAggregate = call.rel(2)
     val partialGlobalAgg: StreamPhysicalGlobalGroupAggregate = call.rel(3)
-    val aggInputRowType = partialGlobalAgg.localAggInputRowType
-
-    val partialLocalAggInfoList = partialGlobalAgg.localAggInfoList
-    val partialGlobalAggInfoList = partialGlobalAgg.globalAggInfoList
-    val finalGlobalAggInfoList = finalGlobalAgg.globalAggInfoList
-    val aggCalls = finalGlobalAggInfoList.getActualAggregateCalls
-
-    val typeFactory = 
finalGlobalAgg.getCluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
-
-    // pick distinct info from global which is on state, and modify excludeAcc 
parameter
-    val incrDistinctInfo = partialGlobalAggInfoList.distinctInfos.map { info =>
-      DistinctInfo(
-        info.argIndexes,
-        info.keyType,
-        info.accType,
-        // exclude distinct acc from the aggregate accumulator,
-        // because the output acc only need to contain the count
-        excludeAcc = true,
-        info.dataViewSpec,
-        info.consumeRetraction,
-        info.filterArgs,
-        info.aggIndexes
-      )
-    }
-
-    val incrAggInfoList = AggregateInfoList(
-      // pick local aggs info from local which is on heap
-      partialLocalAggInfoList.aggInfos,
-      partialGlobalAggInfoList.indexOfCountStar,
-      partialGlobalAggInfoList.countStarInserted,
-      incrDistinctInfo)
+    val partialLocalAggInputRowType = partialGlobalAgg.localAggInputRowType
 
-    val incrAggOutputRowType = AggregateUtil.inferLocalAggRowType(
-      incrAggInfoList,
-      partialGlobalAgg.getRowType,
-      finalGlobalAgg.grouping,
-      typeFactory)
+    val partialOriginalAggCalls = partialGlobalAgg.aggCalls.toArray
+    val partialRealAggCalls = 
partialGlobalAgg.localAggInfoList.getActualAggregateCalls
+    val finalRealAggCalls = 
finalGlobalAgg.globalAggInfoList.getActualAggregateCalls
 
-    val incrAgg = new StreamExecIncrementalGroupAggregate(
+    val incrAgg = new StreamPhysicalIncrementalGroupAggregate(
       partialGlobalAgg.getCluster,
       finalLocalAgg.getTraitSet, // extends final local agg traits (ACC trait)
       partialGlobalAgg.getInput,
-      aggInputRowType,
-      incrAggOutputRowType,
-      partialLocalAggInfoList,
-      incrAggInfoList,
-      aggCalls,
+      partialGlobalAgg.grouping,
+      partialRealAggCalls,
       finalLocalAgg.grouping,
-      partialGlobalAgg.grouping)
+      finalRealAggCalls,
+      partialOriginalAggCalls,
+      partialGlobalAgg.aggCallNeedRetractions,
+      partialGlobalAgg.needRetraction,
+      partialLocalAggInputRowType,
+      partialGlobalAgg.getRowType)
+    val incrAggOutputRowType = incrAgg.getRowType
 
     val newExchange = exchange.copy(exchange.getTraitSet, incrAgg, 
exchange.distribution)
 
@@ -135,9 +106,9 @@ class IncrementalAggregateRule
       val localAggInfoList = AggregateUtil.transformToStreamAggregateInfoList(
         // the final agg input is partial agg
         FlinkTypeFactory.toLogicalRowType(partialGlobalAgg.getRowType),
-        aggCalls,
+        finalRealAggCalls,
         // all the aggs do not need retraction
-        Array.fill(aggCalls.length)(false),
+        Array.fill(finalRealAggCalls.length)(false),
         // also do not need count*
         needInputCount = false,
         // the local agg does not work on state
@@ -148,7 +119,7 @@ class IncrementalAggregateRule
         localAggInfoList,
         incrAgg.getRowType,
         finalGlobalAgg.grouping,
-        typeFactory)
+        
finalGlobalAgg.getCluster.getTypeFactory.asInstanceOf[FlinkTypeFactory])
 
       Preconditions.checkState(RelOptUtil.areRowTypesEqual(
         incrAggOutputRowType,
@@ -161,9 +132,9 @@ class IncrementalAggregateRule
         newExchange,
         finalGlobalAgg.getRowType,
         finalGlobalAgg.grouping,
-        aggCalls,
+        finalRealAggCalls,
         // all the aggs do not need retraction
-        Array.fill(aggCalls.length)(false),
+        Array.fill(finalRealAggCalls.length)(false),
         finalGlobalAgg.localAggInputRowType,
         needRetraction = false,
         finalGlobalAgg.partialFinalType)
diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
index ea1205b..43f3e3e 100644
--- 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
+++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
@@ -175,6 +175,62 @@ object AggregateUtil extends Enumeration {
     map
   }
 
+  def createPartialAggInfoList(
+      partialLocalAggInputRowType: RowType,
+      partialOriginalAggCalls: Seq[AggregateCall],
+      partialAggCallNeedRetractions: Array[Boolean],
+      partialAggNeedRetraction: Boolean,
+      isGlobal: Boolean): AggregateInfoList = {
+    transformToStreamAggregateInfoList(
+      partialLocalAggInputRowType,
+      partialOriginalAggCalls,
+      partialAggCallNeedRetractions,
+      partialAggNeedRetraction,
+      isStateBackendDataViews = isGlobal)
+  }
+
+  def createIncrementalAggInfoList(
+      partialLocalAggInputRowType: RowType,
+      partialOriginalAggCalls: Seq[AggregateCall],
+      partialAggCallNeedRetractions: Array[Boolean],
+      partialAggNeedRetraction: Boolean): AggregateInfoList = {
+    val partialLocalAggInfoList = createPartialAggInfoList(
+      partialLocalAggInputRowType,
+      partialOriginalAggCalls,
+      partialAggCallNeedRetractions,
+      partialAggNeedRetraction,
+      isGlobal = false)
+    val partialGlobalAggInfoList = createPartialAggInfoList(
+      partialLocalAggInputRowType,
+      partialOriginalAggCalls,
+      partialAggCallNeedRetractions,
+      partialAggNeedRetraction,
+      isGlobal = true)
+
+    // pick distinct info from global which is on state, and modify excludeAcc 
parameter
+    val incrementalDistinctInfos = partialGlobalAggInfoList.distinctInfos.map 
{ info =>
+      DistinctInfo(
+        info.argIndexes,
+        info.keyType,
+        info.accType,
+        // exclude distinct acc from the aggregate accumulator,
+        // because the output acc only needs to contain the count
+        excludeAcc = true,
+        info.dataViewSpec,
+        info.consumeRetraction,
+        info.filterArgs,
+        info.aggIndexes
+      )
+    }
+
+    AggregateInfoList(
+      // pick local aggs info from local which is on heap
+      partialLocalAggInfoList.aggInfos,
+      partialGlobalAggInfoList.indexOfCountStar,
+      partialGlobalAggInfoList.countStarInserted,
+      incrementalDistinctInfos)
+  }
+
   def deriveAggregateInfoList(
       agg: StreamPhysicalRel,
       groupCount: Int,
@@ -532,6 +588,7 @@ object AggregateUtil extends Enumeration {
   /**
     * Inserts a COUNT(*) aggregate call if needed. The COUNT(*) aggregate 
call is used
     * to count the number of added and retracted input records.
+    *
     * @param needInputCount whether to insert an InputCount aggregate
     * @param aggregateCalls original aggregate calls
     * @return (indexOfCountStar, countStarInserted, newAggCalls)
@@ -885,7 +942,7 @@ object AggregateUtil extends Enumeration {
   def createMiniBatchTrigger(tableConfig: TableConfig): 
CountBundleTrigger[RowData] = {
     val size = tableConfig.getConfiguration.getLong(
       ExecutionConfigOptions.TABLE_EXEC_MINIBATCH_SIZE)
-    if (size <= 0 ) {
+    if (size <= 0) {
       throw new IllegalArgumentException(
         ExecutionConfigOptions.TABLE_EXEC_MINIBATCH_SIZE + " must be > 0.")
     }

Reply via email to