This is an automated email from the ASF dual-hosted git repository.

lincoln pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new bceca90aa7c [FLINK-36773][table] Introduce new Group Aggregate Operator with Async State API
bceca90aa7c is described below

commit bceca90aa7c492a15e5a0e3021586b1891ffec5d
Author: Xuyang <xyzhong...@163.com>
AuthorDate: Fri Jan 10 09:21:19 2025 +0800

    [FLINK-36773][table] Introduce new Group Aggregate Operator with Async State API
    
    This closes #25680
---
 ...syncKeyedOneInputStreamOperatorTestHarness.java |   6 +-
 .../exec/stream/StreamExecGroupAggregate.java      |  13 ++
 .../table/planner/plan/utils/AggregateUtil.scala   |  17 +++
 .../harness/GroupAggregateHarnessTest.scala        |  91 +++++++++---
 .../planner/runtime/harness/HarnessTestBase.scala  |  24 +++-
 .../runtime/stream/sql/AggregateITCase.scala       |  44 +++++-
 .../operators/aggregate/GroupAggFunction.java      | 155 +++------------------
 .../operators/aggregate/GroupAggFunctionBase.java  |  96 +++++++++++++
 .../async/AsyncStateGroupAggFunction.java          | 110 +++++++++++++++
 .../GroupAggHelper.java}                           | 125 +++++------------
 10 files changed, 422 insertions(+), 259 deletions(-)
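
For readers skimming the patch: the new operator is opt-in via configuration. Below is a minimal sketch of switching it on from user code (the streaming TableEnvironment setup is an illustrative assumption; the option constant is the one the planner checks in this patch):

    import org.apache.flink.table.api.EnvironmentSettings;
    import org.apache.flink.table.api.TableEnvironment;
    import org.apache.flink.table.api.config.ExecutionConfigOptions;

    public class EnableAsyncStateAgg {
        public static void main(String[] args) {
            TableEnvironment tEnv =
                    TableEnvironment.create(EnvironmentSettings.inStreamingMode());
            // Opt in to async state. The planner still falls back to the sync
            // group aggregate when the aggregation needs a DataView (see
            // AggregateUtil.isAsyncStateEnabled in this patch).
            tEnv.getConfig().set(ExecutionConfigOptions.TABLE_EXEC_ASYNC_STATE_ENABLED, true);
        }
    }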

diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/util/asyncprocessing/AsyncKeyedOneInputStreamOperatorTestHarness.java b/flink-runtime/src/test/java/org/apache/flink/streaming/util/asyncprocessing/AsyncKeyedOneInputStreamOperatorTestHarness.java
index e2101d88044..d1cee26347b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/streaming/util/asyncprocessing/AsyncKeyedOneInputStreamOperatorTestHarness.java
+++ b/flink-runtime/src/test/java/org/apache/flink/streaming/util/asyncprocessing/AsyncKeyedOneInputStreamOperatorTestHarness.java
@@ -35,7 +35,7 @@ import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
 import org.apache.flink.streaming.runtime.streamrecord.RecordAttributes;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
-import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
+import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
 import org.apache.flink.util.function.RunnableWithException;
 import org.apache.flink.util.function.ThrowingConsumer;
 
@@ -56,7 +56,7 @@ import static org.assertj.core.api.Assertions.fail;
  * async processing, please use methods of test harness instead of operator.
  */
 public class AsyncKeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
-        extends OneInputStreamOperatorTestHarness<IN, OUT> {
+        extends KeyedOneInputStreamOperatorTestHarness<K, IN, OUT> {
 
     /** Empty if the {@link #operator} is not {@link MultipleInputStreamOperator}. */
     private final List<Input<IN>> inputs = new ArrayList<>();
@@ -113,7 +113,7 @@ public class AsyncKeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
             int numSubtasks,
             int subtaskIndex)
             throws Exception {
-        super(operatorFactory, maxParallelism, numSubtasks, subtaskIndex);
+        super(operatorFactory, keySelector, keyType, maxParallelism, numSubtasks, subtaskIndex);
 
         ClosureCleaner.clean(keySelector, ExecutionConfig.ClosureCleanerLevel.RECURSIVE, false);
         config.setStatePartitioner(0, keySelector);
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupAggregate.java
index cb5a25ed913..5d363198573 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupAggregate.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupAggregate.java
@@ -21,6 +21,7 @@ package org.apache.flink.table.planner.plan.nodes.exec.stream;
 import org.apache.flink.FlinkVersion;
 import org.apache.flink.api.dag.Transformation;
 import org.apache.flink.configuration.ReadableConfig;
+import org.apache.flink.runtime.asyncprocessing.operators.AsyncKeyedProcessOperator;
 import org.apache.flink.streaming.api.operators.KeyedProcessOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.transformations.OneInputTransformation;
@@ -47,6 +48,7 @@ import org.apache.flink.table.runtime.generated.GeneratedRecordEqualiser;
 import org.apache.flink.table.runtime.keyselector.RowDataKeySelector;
 import org.apache.flink.table.runtime.operators.aggregate.GroupAggFunction;
 import org.apache.flink.table.runtime.operators.aggregate.MiniBatchGroupAggFunction;
+import org.apache.flink.table.runtime.operators.aggregate.async.AsyncStateGroupAggFunction;
 import org.apache.flink.table.runtime.operators.bundle.KeyedMapBundleOperator;
 import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter;
 import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
@@ -227,6 +229,7 @@ public class StreamExecGroupAggregate extends StreamExecAggregateBase {
                         .generateRecordEqualiser("GroupAggValueEqualiser");
         final int inputCountIndex = aggInfoList.getIndexOfCountStar();
         final boolean isMiniBatchEnabled = MinibatchUtil.isMiniBatchEnabled(config);
+        final boolean isAsyncStateEnabled = AggregateUtil.isAsyncStateEnabled(config, aggInfoList);
 
         final OneInputStreamOperator<RowData, RowData> operator;
         if (isMiniBatchEnabled) {
@@ -242,6 +245,16 @@ public class StreamExecGroupAggregate extends StreamExecAggregateBase {
             operator =
                     new KeyedMapBundleOperator<>(
                             aggFunction, MinibatchUtil.createMiniBatchTrigger(config));
+        } else if (isAsyncStateEnabled) {
+            AsyncStateGroupAggFunction aggFunction =
+                    new AsyncStateGroupAggFunction(
+                            aggsHandler,
+                            recordEqualiser,
+                            accTypes,
+                            inputCountIndex,
+                            generateUpdateBefore,
+                            stateRetentionTime);
+            operator = new AsyncKeyedProcessOperator<>(aggFunction);
         } else {
             GroupAggFunction aggFunction =
                     new GroupAggFunction(
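
A note on precedence in the hunk above: mini-batch is checked first, so enabling async state has no effect while mini-batch is on, which is also why the parameterized tests further down only pair EnableAsyncState=true with MiniBatchOff. A condensed restatement of the dispatch (planner wiring elided; the operator and function names are the ones from the diff):

    public class OperatorChoiceSketch {
        // Mirrors the if/else chain in StreamExecGroupAggregate above.
        static String choose(boolean isMiniBatchEnabled, boolean isAsyncStateEnabled) {
            if (isMiniBatchEnabled) {
                return "KeyedMapBundleOperator(MiniBatchGroupAggFunction)";
            } else if (isAsyncStateEnabled) {
                return "AsyncKeyedProcessOperator(AsyncStateGroupAggFunction)";
            } else {
                return "KeyedProcessOperator(GroupAggFunction)";
            }
        }

        public static void main(String[] args) {
            System.out.println(choose(true, true));   // mini-batch wins over async state
            System.out.println(choose(false, true));  // async state path
            System.out.println(choose(false, false)); // default sync path
        }
    }
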
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
index a6ffe0ec44b..968a62af425 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
@@ -17,7 +17,9 @@
  */
 package org.apache.flink.table.planner.plan.utils
 
+import org.apache.flink.configuration.ReadableConfig
 import org.apache.flink.table.api.TableException
+import org.apache.flink.table.api.config.ExecutionConfigOptions
 import org.apache.flink.table.expressions._
 import org.apache.flink.table.expressions.ExpressionUtils.extractValue
 import org.apache.flink.table.functions._
@@ -1175,4 +1177,19 @@ object AggregateUtil extends Enumeration {
           })
       .exists(_.getKind == FunctionKind.TABLE_AGGREGATE)
   }
+
+  def isAsyncStateEnabled(config: ReadableConfig, aggInfoList: AggregateInfoList): Boolean = {
+    // Currently, we do not support async state with agg functions that include DataView.
+    val containsDataViewInAggInfo =
+      aggInfoList.aggInfos.toStream.stream().anyMatch(agg => !agg.viewSpecs.isEmpty)
+
+    val containsDataViewInDistinctInfo =
+      aggInfoList.distinctInfos.toStream
+        .stream()
+        .anyMatch(distinct => distinct.dataViewSpec.isDefined)
+
+    config.get(ExecutionConfigOptions.TABLE_EXEC_ASYNC_STATE_ENABLED) &&
+    !containsDataViewInAggInfo &&
+    !containsDataViewInDistinctInfo
+  }
 }
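
To make the DataView restriction concrete, here is a hedged pair of queries (table T, the datagen source, and the session setup are illustrative assumptions, not part of the patch): a plain SUM keeps a flat accumulator and stays eligible, while COUNT(DISTINCT ...) stores its distinct keys in a DataView and therefore keeps the sync operator even with the option on.

    import org.apache.flink.table.api.EnvironmentSettings;
    import org.apache.flink.table.api.TableEnvironment;
    import org.apache.flink.table.api.config.ExecutionConfigOptions;

    public class AsyncStateEligibilitySketch {
        public static void main(String[] args) {
            TableEnvironment tEnv =
                    TableEnvironment.create(EnvironmentSettings.inStreamingMode());
            tEnv.getConfig().set(ExecutionConfigOptions.TABLE_EXEC_ASYNC_STATE_ENABLED, true);
            tEnv.executeSql(
                    "CREATE TABLE T (a STRING, b BIGINT) WITH ('connector' = 'datagen')");

            // No DataView involved -> isAsyncStateEnabled(...) is true.
            System.out.println(tEnv.explainSql("SELECT a, SUM(b) FROM T GROUP BY a"));

            // COUNT(DISTINCT b) keeps its distinct map in a DataView -> sync fallback.
            System.out.println(tEnv.explainSql("SELECT a, COUNT(DISTINCT b) FROM T GROUP BY a"));
        }
    }
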
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/harness/GroupAggregateHarnessTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/harness/GroupAggregateHarnessTest.scala
index 9e4f83c1e5b..932605af5f2 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/harness/GroupAggregateHarnessTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/harness/GroupAggregateHarnessTest.scala
@@ -22,7 +22,7 @@ import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness,
 import org.apache.flink.table.api.{EnvironmentSettings, _}
 import org.apache.flink.table.api.bridge.scala._
 import org.apache.flink.table.api.bridge.scala.internal.StreamTableEnvironmentImpl
-import org.apache.flink.table.api.config.AggregatePhaseStrategy
+import org.apache.flink.table.api.config.{AggregatePhaseStrategy, ExecutionConfigOptions}
 import org.apache.flink.table.api.config.ExecutionConfigOptions.{TABLE_EXEC_MINIBATCH_ALLOW_LATENCY, TABLE_EXEC_MINIBATCH_ENABLED, TABLE_EXEC_MINIBATCH_SIZE}
 import org.apache.flink.table.api.config.OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY
 import org.apache.flink.table.data.RowData
@@ -38,6 +38,7 @@ import org.apache.flink.testutils.junit.extensions.parameterized.{ParameterizedT
 import org.apache.flink.types.Row
 import org.apache.flink.types.RowKind._
 
+import org.assertj.core.api.Assertions.assertThat
 import org.junit.jupiter.api.{BeforeEach, TestTemplate}
 import org.junit.jupiter.api.extension.ExtendWith
 
@@ -50,7 +51,10 @@ import scala.collection.JavaConversions._
 import scala.collection.mutable
 
 @ExtendWith(Array(classOf[ParameterizedTestExtension]))
-class GroupAggregateHarnessTest(mode: StateBackendMode, miniBatch: MiniBatchMode)
+class GroupAggregateHarnessTest(
+    mode: StateBackendMode,
+    miniBatch: MiniBatchMode,
+    enableAsyncState: Boolean)
   extends HarnessTestBase(mode) {
 
   @BeforeEach
@@ -71,6 +75,9 @@ class GroupAggregateHarnessTest(mode: StateBackendMode, miniBatch: MiniBatchMode
       case MiniBatchOff =>
         tableConfig.getConfiguration.removeConfig(TABLE_EXEC_MINIBATCH_ALLOW_LATENCY)
     }
+    tEnv.getConfig.set(
+      ExecutionConfigOptions.TABLE_EXEC_ASYNC_STATE_ENABLED,
+      Boolean.box(enableAsyncState))
   }
 
   @TestTemplate
@@ -112,18 +119,38 @@ class GroupAggregateHarnessTest(mode: StateBackendMode, miniBatch: MiniBatchMode
 
     // retract after clean up
     testHarness.processElement(binaryRecord(UPDATE_BEFORE, "ccc", 3L: JLong))
-    // not output
+    // no output for the sync state op because of TTL-based state cleanup
+    // the async state op still emits output because it does not yet support TTL
+    if (enableAsyncState) {
+      expectedOutput.add(binaryRecord(DELETE, "ccc", 3L: JLong))
+    }
 
     // accumulate
     testHarness.processElement(binaryRecord(INSERT, "aaa", 4L: JLong))
-    expectedOutput.add(binaryRecord(INSERT, "aaa", 4L: JLong))
+    if (enableAsyncState) {
+      expectedOutput.add(binaryRecord(UPDATE_BEFORE, "aaa", 1L: JLong))
+      expectedOutput.add(binaryRecord(UPDATE_AFTER, "aaa", 5L: JLong))
+    } else {
+      expectedOutput.add(binaryRecord(INSERT, "aaa", 4L: JLong))
+    }
+
     testHarness.processElement(binaryRecord(INSERT, "bbb", 2L: JLong))
-    expectedOutput.add(binaryRecord(INSERT, "bbb", 2L: JLong))
+    if (enableAsyncState) {
+      expectedOutput.add(binaryRecord(UPDATE_BEFORE, "bbb", 1L: JLong))
+      expectedOutput.add(binaryRecord(UPDATE_AFTER, "bbb", 3L: JLong))
+    } else {
+      expectedOutput.add(binaryRecord(INSERT, "bbb", 2L: JLong))
+    }
 
     // retract
     testHarness.processElement(binaryRecord(INSERT, "aaa", 5L: JLong))
-    expectedOutput.add(binaryRecord(UPDATE_BEFORE, "aaa", 4L: JLong))
-    expectedOutput.add(binaryRecord(UPDATE_AFTER, "aaa", 9L: JLong))
+    if (enableAsyncState) {
+      expectedOutput.add(binaryRecord(UPDATE_BEFORE, "aaa", 5L: JLong))
+      expectedOutput.add(binaryRecord(UPDATE_AFTER, "aaa", 10L: JLong))
+    } else {
+      expectedOutput.add(binaryRecord(UPDATE_BEFORE, "aaa", 4L: JLong))
+      expectedOutput.add(binaryRecord(UPDATE_AFTER, "aaa", 9L: JLong))
+    }
 
     // accumulate
     testHarness.processElement(binaryRecord(INSERT, "eee", 6L: JLong))
@@ -131,16 +158,32 @@ class GroupAggregateHarnessTest(mode: StateBackendMode, miniBatch: MiniBatchMode
 
     // retract
     testHarness.processElement(binaryRecord(INSERT, "aaa", 7L: JLong))
-    expectedOutput.add(binaryRecord(UPDATE_BEFORE, "aaa", 9L: JLong))
-    expectedOutput.add(binaryRecord(UPDATE_AFTER, "aaa", 16L: JLong))
+    if (enableAsyncState) {
+      expectedOutput.add(binaryRecord(UPDATE_BEFORE, "aaa", 10L: JLong))
+      expectedOutput.add(binaryRecord(UPDATE_AFTER, "aaa", 17L: JLong))
+    } else {
+      expectedOutput.add(binaryRecord(UPDATE_BEFORE, "aaa", 9L: JLong))
+      expectedOutput.add(binaryRecord(UPDATE_AFTER, "aaa", 16L: JLong))
+    }
+
     testHarness.processElement(binaryRecord(INSERT, "bbb", 3L: JLong))
-    expectedOutput.add(binaryRecord(UPDATE_BEFORE, "bbb", 2L: JLong))
-    expectedOutput.add(binaryRecord(UPDATE_AFTER, "bbb", 5L: JLong))
+    if (enableAsyncState) {
+      expectedOutput.add(binaryRecord(UPDATE_BEFORE, "bbb", 3L: JLong))
+      expectedOutput.add(binaryRecord(UPDATE_AFTER, "bbb", 6L: JLong))
+    } else {
+      expectedOutput.add(binaryRecord(UPDATE_BEFORE, "bbb", 2L: JLong))
+      expectedOutput.add(binaryRecord(UPDATE_AFTER, "bbb", 5L: JLong))
+    }
 
     // accumulate
     testHarness.processElement(binaryRecord(INSERT, "aaa", 0L: JLong))
-    expectedOutput.add(binaryRecord(UPDATE_BEFORE, "aaa", 16L: JLong))
-    expectedOutput.add(binaryRecord(UPDATE_AFTER, "aaa", 16L: JLong))
+    if (enableAsyncState) {
+      expectedOutput.add(binaryRecord(UPDATE_BEFORE, "aaa", 17L: JLong))
+      expectedOutput.add(binaryRecord(UPDATE_AFTER, "aaa", 17L: JLong))
+    } else {
+      expectedOutput.add(binaryRecord(UPDATE_BEFORE, "aaa", 16L: JLong))
+      expectedOutput.add(binaryRecord(UPDATE_AFTER, "aaa", 16L: JLong))
+    }
 
     val result = testHarness.getOutput
     assertor.assertOutputEqualsSorted("result mismatch", expectedOutput, result)
@@ -328,6 +371,13 @@ class GroupAggregateHarnessTest(mode: StateBackendMode, miniBatch: MiniBatchMode
     val t1 = tEnv.sqlQuery(sql)
     val testHarness = createHarnessTester(t1.toRetractStream[Row], "GroupAggregate")
     val outputTypes = Array(DataTypes.STRING().getLogicalType, DataTypes.BIGINT().getLogicalType)
+
+    if (enableAsyncState) {
+      assertThat(isAsyncStateOperator(testHarness)).isTrue
+    } else {
+      assertThat(isAsyncStateOperator(testHarness)).isFalse
+    }
+
     (testHarness, outputTypes)
   }
 
@@ -386,6 +436,9 @@ class GroupAggregateHarnessTest(mode: StateBackendMode, miniBatch: MiniBatchMode
       DataTypes.BIGINT().getLogicalType
     )
 
+    // async state agg with data view is not supported yet
+    assertThat(isAsyncStateOperator(testHarness)).isFalse
+
     (testHarness, outputTypes)
   }
 
@@ -397,17 +450,19 @@ class GroupAggregateHarnessTest(mode: StateBackendMode, miniBatch: MiniBatchMode
     // expect no exception happens
     testHarness.close()
   }
+
 }
 
 object GroupAggregateHarnessTest {
 
-  @Parameters(name = "StateBackend={0}, MiniBatch={1}")
+  @Parameters(name = "StateBackend={0}, MiniBatch={1}, EnableAsyncState={2}")
   def parameters(): JCollection[Array[java.lang.Object]] = {
     Seq[Array[AnyRef]](
-      Array(HEAP_BACKEND, MiniBatchOff),
-      Array(HEAP_BACKEND, MiniBatchOn),
-      Array(ROCKSDB_BACKEND, MiniBatchOff),
-      Array(ROCKSDB_BACKEND, MiniBatchOn)
+      Array(HEAP_BACKEND, MiniBatchOff, Boolean.box(false)),
+      Array(HEAP_BACKEND, MiniBatchOff, Boolean.box(true)),
+      Array(HEAP_BACKEND, MiniBatchOn, Boolean.box(false)),
+      Array(ROCKSDB_BACKEND, MiniBatchOff, Boolean.box(false)),
+      Array(ROCKSDB_BACKEND, MiniBatchOn, Boolean.box(false))
     )
   }
 }
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/harness/HarnessTestBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/harness/HarnessTestBase.scala
index 345ce3f7ad4..16cac8ec45d 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/harness/HarnessTestBase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/harness/HarnessTestBase.scala
@@ -21,17 +21,18 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.dag.Transformation
 import org.apache.flink.api.java.functions.KeySelector
 import org.apache.flink.configuration.Configuration
-import org.apache.flink.runtime.state.CheckpointStorage
-import org.apache.flink.runtime.state.StateBackend
+import org.apache.flink.runtime.asyncprocessing.operators.AsyncKeyedProcessOperator
+import org.apache.flink.runtime.state.{CheckpointStorage, StateBackend}
 import org.apache.flink.runtime.state.hashmap.HashMapStateBackend
-import org.apache.flink.runtime.state.storage.FileSystemCheckpointStorage
-import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage
+import org.apache.flink.runtime.state.storage.{FileSystemCheckpointStorage, JobManagerCheckpointStorage}
 import org.apache.flink.state.rocksdb.EmbeddedRocksDBStateBackend
 import org.apache.flink.streaming.api.datastream.DataStream
-import org.apache.flink.streaming.api.operators.OneInputStreamOperator
+import org.apache.flink.streaming.api.operators.{OneInputStreamOperator, SimpleOperatorFactory}
 import org.apache.flink.streaming.api.transformations.{OneInputTransformation, PartitionTransformation}
 import org.apache.flink.streaming.api.watermark.Watermark
+import org.apache.flink.streaming.runtime.operators.asyncprocessing.AsyncStateProcessingOperator
 import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness, OneInputStreamOperatorTestHarness}
+import org.apache.flink.streaming.util.asyncprocessing.AsyncKeyedOneInputStreamOperatorTestHarness
 import org.apache.flink.table.data.RowData
 import org.apache.flink.table.planner.JLong
 import org.apache.flink.table.planner.runtime.utils.StreamingTestBase
@@ -73,8 +74,11 @@ class HarnessTestBase(mode: StateBackendMode) extends StreamingTestBase {
       operator: OneInputStreamOperator[IN, OUT],
       keySelector: KeySelector[IN, KEY],
       keyType: TypeInformation[KEY]): KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT] = {
-    val harness =
+    val harness = if (operator.isInstanceOf[AsyncStateProcessingOperator]) {
+      AsyncKeyedOneInputStreamOperatorTestHarness.create(operator, keySelector, keyType)
+    } else {
       new KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT](operator, keySelector, keyType)
+    }
     harness.setStateBackend(getStateBackend)
     harness.setCheckpointStorage(getCheckpointStorage)
     harness
@@ -127,6 +131,14 @@ class HarnessTestBase(mode: StateBackendMode) extends StreamingTestBase {
   def dropWatermarks(elements: Array[AnyRef]): util.Collection[AnyRef] = {
     elements.filter(e => !e.isInstanceOf[Watermark]).toList
   }
+
+  protected def isAsyncStateOperator(
+      testHarness: KeyedOneInputStreamOperatorTestHarness[RowData, RowData, RowData]): Boolean = {
+    testHarness.getOperatorFactory
+      .asInstanceOf[SimpleOperatorFactory[_]]
+      .getOperator
+      .isInstanceOf[AsyncKeyedProcessOperator[_, _, _]]
+  }
 }
 
 object HarnessTestBase {
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
index 339f8a4aa15..db676e2866d 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
@@ -22,6 +22,7 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo
 import org.apache.flink.streaming.api.datastream.DataStream
 import org.apache.flink.table.api._
 import org.apache.flink.table.api.bridge.scala._
+import org.apache.flink.table.api.config.ExecutionConfigOptions
 import org.apache.flink.table.api.internal.TableEnvironmentInternal
 import org.apache.flink.table.legacy.api.Types
 import org.apache.flink.table.planner.factories.TestValuesTableFactory
@@ -30,32 +31,38 @@ import org.apache.flink.table.planner.plan.utils.JavaUserDefinedAggFunctions.{Us
 import org.apache.flink.table.planner.runtime.batch.sql.agg.{MyPojoAggFunction, VarArgsAggFunction}
 import org.apache.flink.table.planner.runtime.utils._
 import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedAggFunctions.OverloadedMaxFunction
-import org.apache.flink.table.planner.runtime.utils.StreamingWithAggTestBase.AggMode
-import org.apache.flink.table.planner.runtime.utils.StreamingWithMiniBatchTestBase.MiniBatchMode
-import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.StateBackendMode
+import org.apache.flink.table.planner.runtime.utils.StreamingWithAggTestBase.{AggMode, LocalGlobalOff, LocalGlobalOn}
+import org.apache.flink.table.planner.runtime.utils.StreamingWithMiniBatchTestBase.{MiniBatchMode, MiniBatchOff, MiniBatchOn}
+import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.{HEAP_BACKEND, ROCKSDB_BACKEND, StateBackendMode}
 import org.apache.flink.table.planner.runtime.utils.TimeTestUtil.TimestampAndWatermarkWithOffset
 import org.apache.flink.table.planner.runtime.utils.UserDefinedFunctionTestUtils._
 import org.apache.flink.table.planner.utils.DateTimeTestUtil.{localDate, localDateTime, localTime => mLocalTime}
 import org.apache.flink.table.runtime.functions.aggregate.{ListAggWithRetractAggFunction, ListAggWsWithRetractAggFunction}
 import org.apache.flink.table.runtime.typeutils.BigDecimalTypeInfo
-import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension
+import org.apache.flink.testutils.junit.extensions.parameterized.{ParameterizedTestExtension, Parameters}
 import org.apache.flink.types.Row
 
 import org.assertj.core.api.Assertions.assertThat
 import org.assertj.core.data.Percentage
-import org.junit.jupiter.api.{Disabled, TestTemplate}
+import org.junit.jupiter.api.{BeforeEach, Disabled, TestTemplate}
 import org.junit.jupiter.api.extension.ExtendWith
 
 import java.lang.{Integer => JInt, Long => JLong}
 import java.math.{BigDecimal => JBigDecimal}
 import java.time.Duration
+import java.util
 
 import scala.collection.{mutable, Seq}
+import scala.collection.JavaConversions._
 import scala.math.BigDecimal.double2bigDecimal
 import scala.util.Random
 
 @ExtendWith(Array(classOf[ParameterizedTestExtension]))
-class AggregateITCase(aggMode: AggMode, miniBatch: MiniBatchMode, backend: StateBackendMode)
+class AggregateITCase(
+    aggMode: AggMode,
+    miniBatch: MiniBatchMode,
+    backend: StateBackendMode,
+    enableAsyncState: Boolean)
   extends StreamingWithAggTestBase(aggMode, miniBatch, backend) {
 
   val data = List(
@@ -70,6 +77,15 @@ class AggregateITCase(aggMode: AggMode, miniBatch: MiniBatchMode, backend: State
     (20000L, 20, "Hello World")
   )
 
+  @BeforeEach
+  override def before(): Unit = {
+    super.before()
+
+    tEnv.getConfig.set(
+      ExecutionConfigOptions.TABLE_EXEC_ASYNC_STATE_ENABLED,
+      Boolean.box(enableAsyncState))
+  }
+
   @TestTemplate
   def testEmptyInputAggregation(): Unit = {
     val data = new mutable.MutableList[(Int, Int)]
@@ -2077,3 +2093,19 @@ class AggregateITCase(aggMode: AggMode, miniBatch: MiniBatchMode, backend: State
     tEnv.dropTemporarySystemFunction("PERCENTILE")
   }
 }
+
+object AggregateITCase {
+
+  @Parameters(name = "LocalGlobal={0}, {1}, StateBackend={2}, EnableAsyncState={3}")
+  def parameters(): util.Collection[Array[java.lang.Object]] = {
+    Seq[Array[AnyRef]](
+      Array(LocalGlobalOff, MiniBatchOff, HEAP_BACKEND, Boolean.box(false)),
+      Array(LocalGlobalOff, MiniBatchOff, HEAP_BACKEND, Boolean.box(true)),
+      Array(LocalGlobalOff, MiniBatchOn, HEAP_BACKEND, Boolean.box(false)),
+      Array(LocalGlobalOn, MiniBatchOn, HEAP_BACKEND, Boolean.box(false)),
+      Array(LocalGlobalOff, MiniBatchOff, ROCKSDB_BACKEND, Boolean.box(false)),
+      Array(LocalGlobalOff, MiniBatchOn, ROCKSDB_BACKEND, Boolean.box(false)),
+      Array(LocalGlobalOn, MiniBatchOn, ROCKSDB_BACKEND, Boolean.box(false))
+    )
+  }
+}
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunction.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunction.java
index b1c21c559b2..fc384f470ab 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunction.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunction.java
@@ -19,61 +19,26 @@
 package org.apache.flink.table.runtime.operators.aggregate;
 
 import org.apache.flink.api.common.functions.OpenContext;
-import org.apache.flink.api.common.state.StateTtlConfig;
 import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
-import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
 import org.apache.flink.table.data.RowData;
-import org.apache.flink.table.data.utils.JoinedRowData;
-import org.apache.flink.table.runtime.dataview.PerKeyStateDataViewStore;
-import org.apache.flink.table.runtime.generated.AggsHandleFunction;
 import org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction;
 import org.apache.flink.table.runtime.generated.GeneratedRecordEqualiser;
-import org.apache.flink.table.runtime.generated.RecordEqualiser;
+import org.apache.flink.table.runtime.operators.aggregate.utils.GroupAggHelper;
 import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
 import org.apache.flink.table.types.logical.LogicalType;
-import org.apache.flink.types.RowKind;
 import org.apache.flink.util.Collector;
 
-import static org.apache.flink.table.data.util.RowDataUtil.isAccumulateMsg;
-import static org.apache.flink.table.data.util.RowDataUtil.isRetractMsg;
-import static org.apache.flink.table.runtime.util.StateConfigUtil.createTtlConfig;
-
 /** Aggregate Function used for the groupby (without window) aggregate. */
-public class GroupAggFunction extends KeyedProcessFunction<RowData, RowData, RowData> {
+public class GroupAggFunction extends GroupAggFunctionBase {
 
     private static final long serialVersionUID = -4767158666069797704L;
 
-    /** The code generated function used to handle aggregates. */
-    private final GeneratedAggsHandleFunction genAggsHandler;
-
-    /** The code generated equaliser used to equal RowData. */
-    private final GeneratedRecordEqualiser genRecordEqualiser;
-
-    /** The accumulator types. */
-    private final LogicalType[] accTypes;
-
-    /** Used to count the number of added and retracted input records. */
-    private final RecordCounter recordCounter;
-
-    /** Whether this operator will generate UPDATE_BEFORE messages. */
-    private final boolean generateUpdateBefore;
-
-    /** State idle retention time which unit is MILLISECONDS. */
-    private final long stateRetentionTime;
-
-    /** Reused output row. */
-    private transient JoinedRowData resultRow = null;
-
-    // function used to handle all aggregates
-    private transient AggsHandleFunction function = null;
-
-    // function used to equal RowData
-    private transient RecordEqualiser equaliser = null;
-
     // stores the accumulators
     private transient ValueState<RowData> accState = null;
 
+    private transient SyncStateGroupAggHelper aggHelper = null;
+
     /**
      * Creates a {@link GroupAggFunction}.
      *
@@ -93,23 +58,18 @@ public class GroupAggFunction extends KeyedProcessFunction<RowData, RowData, Row
             int indexOfCountStar,
             boolean generateUpdateBefore,
             long stateRetentionTime) {
-        this.genAggsHandler = genAggsHandler;
-        this.genRecordEqualiser = genRecordEqualiser;
-        this.accTypes = accTypes;
-        this.recordCounter = RecordCounter.of(indexOfCountStar);
-        this.generateUpdateBefore = generateUpdateBefore;
-        this.stateRetentionTime = stateRetentionTime;
+        super(
+                genAggsHandler,
+                genRecordEqualiser,
+                accTypes,
+                indexOfCountStar,
+                generateUpdateBefore,
+                stateRetentionTime);
     }
 
     @Override
     public void open(OpenContext openContext) throws Exception {
         super.open(openContext);
-        // instantiate function
-        StateTtlConfig ttlConfig = createTtlConfig(stateRetentionTime);
-        function = genAggsHandler.newInstance(getRuntimeContext().getUserCodeClassLoader());
-        function.open(new PerKeyStateDataViewStore(getRuntimeContext(), ttlConfig));
-        // instantiate equaliser
-        equaliser = genRecordEqualiser.newInstance(getRuntimeContext().getUserCodeClassLoader());
 
         InternalTypeInfo<RowData> accTypeInfo = InternalTypeInfo.ofFields(accTypes);
         ValueStateDescriptor<RowData> accDesc = new ValueStateDescriptor<>("accState", accTypeInfo);
@@ -118,100 +78,29 @@ public class GroupAggFunction extends KeyedProcessFunction<RowData, RowData, Row
         }
         accState = getRuntimeContext().getState(accDesc);
 
-        resultRow = new JoinedRowData();
+        aggHelper = new SyncStateGroupAggHelper();
     }
 
     @Override
     public void processElement(RowData input, Context ctx, Collector<RowData> out)
             throws Exception {
         RowData currentKey = ctx.getCurrentKey();
-        boolean firstRow;
-        RowData accumulators = accState.value();
-        if (null == accumulators) {
-            // Don't create a new accumulator for a retraction message. This
-            // might happen if the retraction message is the first message for the
-            // key or after a state clean up.
-            if (isRetractMsg(input)) {
-                return;
-            }
-            firstRow = true;
-            accumulators = function.createAccumulators();
-        } else {
-            firstRow = false;
-        }
-
-        // set accumulators to handler first
-        function.setAccumulators(accumulators);
-        // get previous aggregate result
-        RowData prevAggValue = function.getValue();
+        aggHelper.processElement(input, currentKey, accState.value(), out);
+    }
 
-        // update aggregate result and set to the newRow
-        if (isAccumulateMsg(input)) {
-            // accumulate input
-            function.accumulate(input);
-        } else {
-            // retract input
-            function.retract(input);
+    private class SyncStateGroupAggHelper extends GroupAggHelper {
+        public SyncStateGroupAggHelper() {
+            super(recordCounter, generateUpdateBefore, ttlConfig, function, equaliser);
         }
-        // get current aggregate result
-        RowData newAggValue = function.getValue();
-
-        // get accumulator
-        accumulators = function.getAccumulators();
 
-        if (!recordCounter.recordCountIsZero(accumulators)) {
-            // we aggregated at least one record for this key
-
-            // update the state
+        @Override
+        protected void updateAccumulatorsState(RowData accumulators) throws Exception {
             accState.update(accumulators);
-
-            // if this was not the first row and we have to emit retractions
-            if (!firstRow) {
-                if (stateRetentionTime <= 0 && equaliser.equals(prevAggValue, newAggValue)) {
-                    // newRow is the same as before and state cleaning is not enabled.
-                    // We do not emit retraction and acc message.
-                    // If state cleaning is enabled, we have to emit messages to prevent too early
-                    // state eviction of downstream operators.
-                    return;
-                } else {
-                    // retract previous result
-                    if (generateUpdateBefore) {
-                        // prepare UPDATE_BEFORE message for previous row
-                        resultRow
-                                .replace(currentKey, prevAggValue)
-                                .setRowKind(RowKind.UPDATE_BEFORE);
-                        out.collect(resultRow);
-                    }
-                    // prepare UPDATE_AFTER message for new row
-                    resultRow.replace(currentKey, newAggValue).setRowKind(RowKind.UPDATE_AFTER);
-                }
-            } else {
-                // this is the first, output new result
-                // prepare INSERT message for new row
-                resultRow.replace(currentKey, newAggValue).setRowKind(RowKind.INSERT);
-            }
-
-            out.collect(resultRow);
-
-        } else {
-            // we retracted the last record for this key
-            // sent out a delete message
-            if (!firstRow) {
-                // prepare delete message for previous row
-                resultRow.replace(currentKey, prevAggValue).setRowKind(RowKind.DELETE);
-                out.collect(resultRow);
-            }
-            // and clear all state
-            accState.clear();
-            // cleanup dataview under current key
-            function.cleanup();
         }
-    }
 
-    @Override
-    public void close() throws Exception {
-        if (function != null) {
-            function.close();
+        @Override
+        protected void clearAccumulatorsState() throws Exception {
+            accState.clear();
         }
     }
 }
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunctionBase.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunctionBase.java
new file mode 100644
index 00000000000..b6a65df3903
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunctionBase.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.aggregate;
+
+import org.apache.flink.api.common.functions.OpenContext;
+import org.apache.flink.api.common.state.StateTtlConfig;
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.runtime.dataview.PerKeyStateDataViewStore;
+import org.apache.flink.table.runtime.generated.AggsHandleFunction;
+import org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction;
+import org.apache.flink.table.runtime.generated.GeneratedRecordEqualiser;
+import org.apache.flink.table.runtime.generated.RecordEqualiser;
+import org.apache.flink.table.runtime.operators.aggregate.async.AsyncStateGroupAggFunction;
+import org.apache.flink.table.types.logical.LogicalType;
+
+import static org.apache.flink.table.runtime.util.StateConfigUtil.createTtlConfig;
+
+/** Base class for {@link GroupAggFunction} and {@link AsyncStateGroupAggFunction}. */
+public abstract class GroupAggFunctionBase extends KeyedProcessFunction<RowData, RowData, RowData> {
+
+    /** The code generated function used to handle aggregates. */
+    protected final GeneratedAggsHandleFunction genAggsHandler;
+
+    /** The code generated equaliser used to equal RowData. */
+    protected final GeneratedRecordEqualiser genRecordEqualiser;
+
+    /** The accumulator types. */
+    protected final LogicalType[] accTypes;
+
+    /** Used to count the number of added and retracted input records. */
+    protected final RecordCounter recordCounter;
+
+    /** Whether this operator will generate UPDATE_BEFORE messages. */
+    protected final boolean generateUpdateBefore;
+
+    /** State idle retention config. */
+    protected final StateTtlConfig ttlConfig;
+
+    // function used to handle all aggregates
+    protected transient AggsHandleFunction function = null;
+
+    // function used to equal RowData
+    protected transient RecordEqualiser equaliser = null;
+
+    public GroupAggFunctionBase(
+            GeneratedAggsHandleFunction genAggsHandler,
+            GeneratedRecordEqualiser genRecordEqualiser,
+            LogicalType[] accTypes,
+            int indexOfCountStar,
+            boolean generateUpdateBefore,
+            long stateRetentionTime) {
+        this.genAggsHandler = genAggsHandler;
+        this.genRecordEqualiser = genRecordEqualiser;
+        this.accTypes = accTypes;
+        this.recordCounter = RecordCounter.of(indexOfCountStar);
+        this.generateUpdateBefore = generateUpdateBefore;
+        this.ttlConfig = createTtlConfig(stateRetentionTime);
+    }
+
+    @Override
+    public void open(OpenContext openContext) throws Exception {
+        super.open(openContext);
+
+        // instantiate function
+        function = genAggsHandler.newInstance(getRuntimeContext().getUserCodeClassLoader());
+        function.open(new PerKeyStateDataViewStore(getRuntimeContext(), ttlConfig));
+        // instantiate equaliser
+        equaliser = genRecordEqualiser.newInstance(getRuntimeContext().getUserCodeClassLoader());
+    }
+
+    @Override
+    public void close() throws Exception {
+        super.close();
+
+        if (function != null) {
+            function.close();
+        }
+    }
+}
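
The base class folds the retention time into a StateTtlConfig once, in the constructor, so both subclasses can simply test ttlConfig.isEnabled(). A sketch of what StateConfigUtil.createTtlConfig plausibly produces — the exact builder settings are an assumption here; only the disabled-when-non-positive behavior is relied on by this patch:

    import org.apache.flink.api.common.state.StateTtlConfig;

    import java.time.Duration;

    public class TtlConfigSketch {
        // Assumed shape: retention <= 0 means "TTL disabled", which is what
        // ttlConfig.isEnabled() checks in GroupAggHelper below.
        static StateTtlConfig createTtlConfig(long retentionMillis) {
            if (retentionMillis <= 0) {
                return StateTtlConfig.DISABLED;
            }
            return StateTtlConfig.newBuilder(Duration.ofMillis(retentionMillis))
                    .setUpdateType(StateTtlConfig.UpdateType.OnCreateAndWrite)
                    .setStateVisibility(StateTtlConfig.StateVisibility.NeverReturnExpired)
                    .build();
        }

        public static void main(String[] args) {
            System.out.println(createTtlConfig(0).isEnabled());      // false
            System.out.println(createTtlConfig(60_000).isEnabled()); // true
        }
    }
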
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/AsyncStateGroupAggFunction.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/AsyncStateGroupAggFunction.java
new file mode 100644
index 00000000000..506c36cc899
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/AsyncStateGroupAggFunction.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.aggregate.async;
+
+import org.apache.flink.api.common.functions.OpenContext;
+import org.apache.flink.api.common.state.v2.ValueState;
+import org.apache.flink.runtime.state.v2.ValueStateDescriptor;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction;
+import org.apache.flink.table.runtime.generated.GeneratedRecordEqualiser;
+import org.apache.flink.table.runtime.operators.aggregate.GroupAggFunctionBase;
+import org.apache.flink.table.runtime.operators.aggregate.utils.GroupAggHelper;
+import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.util.Collector;
+
+/** Aggregate Function used for the groupby (without window) aggregate with the async state API. */
+public class AsyncStateGroupAggFunction extends GroupAggFunctionBase {
+
+    private static final long serialVersionUID = 1L;
+
+    // stores the accumulators
+    private transient ValueState<RowData> accState = null;
+
+    private transient AsyncStateGroupAggHelper aggHelper = null;
+
+    /**
+     * Creates an {@link AsyncStateGroupAggFunction}.
+     *
+     * @param genAggsHandler The code generated function used to handle aggregates.
+     * @param genRecordEqualiser The code generated equaliser used to equal RowData.
+     * @param accTypes The accumulator types.
+     * @param indexOfCountStar The index of COUNT(*) in the aggregates. -1 when the input doesn't
+     *     contain COUNT(*), i.e. doesn't contain retraction messages. We make sure there is a
+     *     COUNT(*) if input stream contains retraction.
+     * @param generateUpdateBefore Whether this operator will generate UPDATE_BEFORE messages.
+     * @param stateRetentionTime The state idle retention time, in milliseconds.
+     */
+    public AsyncStateGroupAggFunction(
+            GeneratedAggsHandleFunction genAggsHandler,
+            GeneratedRecordEqualiser genRecordEqualiser,
+            LogicalType[] accTypes,
+            int indexOfCountStar,
+            boolean generateUpdateBefore,
+            long stateRetentionTime) {
+        super(
+                genAggsHandler,
+                genRecordEqualiser,
+                accTypes,
+                indexOfCountStar,
+                generateUpdateBefore,
+                stateRetentionTime);
+    }
+
+    @Override
+    public void open(OpenContext openContext) throws Exception {
+        super.open(openContext);
+
+        InternalTypeInfo<RowData> accTypeInfo = InternalTypeInfo.ofFields(accTypes);
+        ValueStateDescriptor<RowData> accDesc = new ValueStateDescriptor<>("accState", accTypeInfo);
+        if (ttlConfig.isEnabled()) {
+            accDesc.enableTimeToLive(ttlConfig);
+        }
+
+        accState = ((StreamingRuntimeContext) getRuntimeContext()).getValueState(accDesc);
+        aggHelper = new AsyncStateGroupAggHelper();
+    }
+
+    @Override
+    public void processElement(RowData input, Context ctx, Collector<RowData> out)
+            throws Exception {
+        RowData currentKey = ctx.getCurrentKey();
+        accState.asyncValue()
+                .thenAccept(acc -> aggHelper.processElement(input, currentKey, acc, out));
+    }
+
+    private class AsyncStateGroupAggHelper extends GroupAggHelper {
+
+        public AsyncStateGroupAggHelper() {
+            super(recordCounter, generateUpdateBefore, ttlConfig, function, equaliser);
+        }
+
+        @Override
+        protected void updateAccumulatorsState(RowData accumulators) throws Exception {
+            accState.asyncUpdate(accumulators);
+        }
+
+        @Override
+        protected void clearAccumulatorsState() throws Exception {
+            accState.asyncClear();
+        }
+    }
+}
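
Note how processElement above never blocks on state: asyncValue() hands back a future and the shared helper runs in a callback. The real StateFuture carries Flink-specific ordering guarantees, but the shape of the chaining can be sketched with a plain CompletableFuture (an analogy only, not the actual state v2 API contract):

    import java.util.concurrent.CompletableFuture;

    public class AsyncChainShapeSketch {
        public static void main(String[] args) {
            // Stand-in for accState.asyncValue(): the accumulator arrives later;
            // null plays the role of "no accumulator yet for this key".
            CompletableFuture<Long> asyncValue = CompletableFuture.supplyAsync(() -> null);

            // Stand-in for .thenAccept(acc -> aggHelper.processElement(...)).
            asyncValue.thenAccept(acc -> {
                long updated = (acc == null ? 0L : acc) + 7L;
                System.out.println("new accumulator = " + updated);
            }).join();
        }
    }
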
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunction.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/utils/GroupAggHelper.java
similarity index 52%
copy from flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunction.java
copy to flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/utils/GroupAggHelper.java
index b1c21c559b2..bc42304d080 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunction.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/utils/GroupAggHelper.java
@@ -16,42 +16,22 @@
  * limitations under the License.
  */
 
-package org.apache.flink.table.runtime.operators.aggregate;
+package org.apache.flink.table.runtime.operators.aggregate.utils;
 
-import org.apache.flink.api.common.functions.OpenContext;
 import org.apache.flink.api.common.state.StateTtlConfig;
-import org.apache.flink.api.common.state.ValueState;
-import org.apache.flink.api.common.state.ValueStateDescriptor;
-import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
 import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.data.utils.JoinedRowData;
-import org.apache.flink.table.runtime.dataview.PerKeyStateDataViewStore;
 import org.apache.flink.table.runtime.generated.AggsHandleFunction;
-import org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction;
-import org.apache.flink.table.runtime.generated.GeneratedRecordEqualiser;
 import org.apache.flink.table.runtime.generated.RecordEqualiser;
-import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
-import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.runtime.operators.aggregate.RecordCounter;
 import org.apache.flink.types.RowKind;
 import org.apache.flink.util.Collector;
 
 import static org.apache.flink.table.data.util.RowDataUtil.isAccumulateMsg;
 import static org.apache.flink.table.data.util.RowDataUtil.isRetractMsg;
-import static org.apache.flink.table.runtime.util.StateConfigUtil.createTtlConfig;
 
-/** Aggregate Function used for the groupby (without window) aggregate. */
-public class GroupAggFunction extends KeyedProcessFunction<RowData, RowData, RowData> {
-
-    private static final long serialVersionUID = -4767158666069797704L;
-
-    /** The code generated function used to handle aggregates. */
-    private final GeneratedAggsHandleFunction genAggsHandler;
-
-    /** The code generated equaliser used to equal RowData. */
-    private final GeneratedRecordEqualiser genRecordEqualiser;
-
-    /** The accumulator types. */
-    private final LogicalType[] accTypes;
+/** A helper that performs the shared processing logic of group aggregations. */
+public abstract class GroupAggHelper {
 
     /** Used to count the number of added and retracted input records. */
     private final RecordCounter recordCounter;
@@ -59,74 +39,36 @@ public class GroupAggFunction extends KeyedProcessFunction<RowData, RowData, Row
     /** Whether this operator will generate UPDATE_BEFORE messages. */
     private final boolean generateUpdateBefore;
 
-    /** State idle retention time which unit is MILLISECONDS. */
-    private final long stateRetentionTime;
+    /** State idle retention config. */
+    private final StateTtlConfig ttlConfig;
+
+    /** function used to handle all aggregates. */
+    private final AggsHandleFunction function;
+
+    /** function used to equal RowData. */
+    private final RecordEqualiser equaliser;
 
     /** Reused output row. */
-    private transient JoinedRowData resultRow = null;
-
-    // function used to handle all aggregates
-    private transient AggsHandleFunction function = null;
-
-    // function used to equal RowData
-    private transient RecordEqualiser equaliser = null;
-
-    // stores the accumulators
-    private transient ValueState<RowData> accState = null;
-
-    /**
-     * Creates a {@link GroupAggFunction}.
-     *
-     * @param genAggsHandler The code generated function used to handle aggregates.
-     * @param genRecordEqualiser The code generated equaliser used to equal RowData.
-     * @param accTypes The accumulator types.
-     * @param indexOfCountStar The index of COUNT(*) in the aggregates. -1 when the input doesn't
-     *     contain COUNT(*), i.e. doesn't contain retraction messages. We make sure there is a
-     *     COUNT(*) if input stream contains retraction.
-     * @param generateUpdateBefore Whether this operator will generate UPDATE_BEFORE messages.
-     * @param stateRetentionTime state idle retention time which unit is MILLISECONDS.
-     */
-    public GroupAggFunction(
-            GeneratedAggsHandleFunction genAggsHandler,
-            GeneratedRecordEqualiser genRecordEqualiser,
-            LogicalType[] accTypes,
-            int indexOfCountStar,
+    private final JoinedRowData resultRow;
+
+    public GroupAggHelper(
+            RecordCounter recordCounter,
             boolean generateUpdateBefore,
-            long stateRetentionTime) {
-        this.genAggsHandler = genAggsHandler;
-        this.genRecordEqualiser = genRecordEqualiser;
-        this.accTypes = accTypes;
-        this.recordCounter = RecordCounter.of(indexOfCountStar);
+            StateTtlConfig ttlConfig,
+            AggsHandleFunction function,
+            RecordEqualiser equaliser) {
+        this.recordCounter = recordCounter;
         this.generateUpdateBefore = generateUpdateBefore;
-        this.stateRetentionTime = stateRetentionTime;
+        this.ttlConfig = ttlConfig;
+        this.function = function;
+        this.equaliser = equaliser;
+        this.resultRow = new JoinedRowData();
     }
 
-    @Override
-    public void open(OpenContext openContext) throws Exception {
-        super.open(openContext);
-        // instantiate function
-        StateTtlConfig ttlConfig = createTtlConfig(stateRetentionTime);
-        function = genAggsHandler.newInstance(getRuntimeContext().getUserCodeClassLoader());
-        function.open(new PerKeyStateDataViewStore(getRuntimeContext(), ttlConfig));
-        // instantiate equaliser
-        equaliser = genRecordEqualiser.newInstance(getRuntimeContext().getUserCodeClassLoader());
-
-        InternalTypeInfo<RowData> accTypeInfo = InternalTypeInfo.ofFields(accTypes);
-        ValueStateDescriptor<RowData> accDesc = new ValueStateDescriptor<>("accState", accTypeInfo);
-        if (ttlConfig.isEnabled()) {
-            accDesc.enableTimeToLive(ttlConfig);
-        }
-        accState = getRuntimeContext().getState(accDesc);
-
-        resultRow = new JoinedRowData();
-    }
-
-    @Override
-    public void processElement(RowData input, Context ctx, Collector<RowData> out)
+    public void processElement(
+            RowData input, RowData currentKey, RowData accumulators, Collector<RowData> out)
             throws Exception {
-        RowData currentKey = ctx.getCurrentKey();
         boolean firstRow;
-        RowData accumulators = accState.value();
         if (null == accumulators) {
             // Don't create a new accumulator for a retraction message. This
             // might happen if the retraction message is the first message for the
@@ -163,11 +105,11 @@ public class GroupAggFunction extends KeyedProcessFunction<RowData, RowData, Row
             // we aggregated at least one record for this key
 
             // update the state
-            accState.update(accumulators);
+            updateAccumulatorsState(accumulators);
 
             // if this was not the first row and we have to emit retractions
             if (!firstRow) {
-                if (stateRetentionTime <= 0 && equaliser.equals(prevAggValue, newAggValue)) {
+                if (!ttlConfig.isEnabled() && equaliser.equals(prevAggValue, newAggValue)) {
                     // newRow is the same as before and state cleaning is not enabled.
                     // We do not emit retraction and acc message.
                     // If state cleaning is enabled, we have to emit messages to prevent too early
@@ -202,16 +144,13 @@ public class GroupAggFunction extends KeyedProcessFunction<RowData, RowData, Row
                 out.collect(resultRow);
             }
             // and clear all state
-            accState.clear();
+            clearAccumulatorsState();
             // cleanup dataview under current key
             function.cleanup();
         }
     }
 
-    @Override
-    public void close() throws Exception {
-        if (function != null) {
-            function.close();
-        }
-    }
+    protected abstract void updateAccumulatorsState(RowData accumulators) throws Exception;
+
+    protected abstract void clearAccumulatorsState() throws Exception;
 }
