This is an automated email from the ASF dual-hosted git repository.

lincoln pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
     new bceca90aa7c  [FLINK-36773][table] Introduce new Group Aggregate Operator with Async State API
bceca90aa7c is described below

commit bceca90aa7c492a15e5a0e3021586b1891ffec5d
Author: Xuyang <xyzhong...@163.com>
AuthorDate: Fri Jan 10 09:21:19 2025 +0800

    [FLINK-36773][table] Introduce new Group Aggregate Operator with Async State API

    This closes #25680
---
 ...syncKeyedOneInputStreamOperatorTestHarness.java |   6 +-
 .../exec/stream/StreamExecGroupAggregate.java      |  13 ++
 .../table/planner/plan/utils/AggregateUtil.scala   |  17 +++
 .../harness/GroupAggregateHarnessTest.scala        |  91 +++++++++---
 .../planner/runtime/harness/HarnessTestBase.scala  |  24 +++-
 .../runtime/stream/sql/AggregateITCase.scala       |  44 +++++-
 .../operators/aggregate/GroupAggFunction.java      | 155 +++------------------
 .../operators/aggregate/GroupAggFunctionBase.java  |  96 +++++++++++++
 .../async/AsyncStateGroupAggFunction.java          | 110 +++++++++++++++
 .../GroupAggHelper.java}                           | 125 +++++------------
 10 files changed, 422 insertions(+), 259 deletions(-)

diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/util/asyncprocessing/AsyncKeyedOneInputStreamOperatorTestHarness.java b/flink-runtime/src/test/java/org/apache/flink/streaming/util/asyncprocessing/AsyncKeyedOneInputStreamOperatorTestHarness.java
index e2101d88044..d1cee26347b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/streaming/util/asyncprocessing/AsyncKeyedOneInputStreamOperatorTestHarness.java
+++ b/flink-runtime/src/test/java/org/apache/flink/streaming/util/asyncprocessing/AsyncKeyedOneInputStreamOperatorTestHarness.java
@@ -35,7 +35,7 @@ import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
 import org.apache.flink.streaming.runtime.streamrecord.RecordAttributes;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.watermarkstatus.WatermarkStatus;
-import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
+import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
 import org.apache.flink.util.function.RunnableWithException;
 import org.apache.flink.util.function.ThrowingConsumer;
@@ -56,7 +56,7 @@ import static org.assertj.core.api.Assertions.fail;
  * async processing, please use methods of test harness instead of operator.
  */
 public class AsyncKeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
-        extends OneInputStreamOperatorTestHarness<IN, OUT> {
+        extends KeyedOneInputStreamOperatorTestHarness<K, IN, OUT> {
 
     /** Empty if the {@link #operator} is not {@link MultipleInputStreamOperator}. */
     private final List<Input<IN>> inputs = new ArrayList<>();
@@ -113,7 +113,7 @@ public class AsyncKeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
             int numSubtasks,
             int subtaskIndex)
             throws Exception {
-        super(operatorFactory, maxParallelism, numSubtasks, subtaskIndex);
+        super(operatorFactory, keySelector, keyType, maxParallelism, numSubtasks, subtaskIndex);
 
         ClosureCleaner.clean(keySelector, ExecutionConfig.ClosureCleanerLevel.RECURSIVE, false);
         config.setStatePartitioner(0, keySelector);
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupAggregate.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupAggregate.java
index cb5a25ed913..5d363198573 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupAggregate.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGroupAggregate.java
@@ -21,6 +21,7 @@ package org.apache.flink.table.planner.plan.nodes.exec.stream;
 import org.apache.flink.FlinkVersion;
 import org.apache.flink.api.dag.Transformation;
 import org.apache.flink.configuration.ReadableConfig;
+import org.apache.flink.runtime.asyncprocessing.operators.AsyncKeyedProcessOperator;
 import org.apache.flink.streaming.api.operators.KeyedProcessOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.transformations.OneInputTransformation;
@@ -47,6 +48,7 @@ import org.apache.flink.table.runtime.generated.GeneratedRecordEqualiser;
 import org.apache.flink.table.runtime.keyselector.RowDataKeySelector;
 import org.apache.flink.table.runtime.operators.aggregate.GroupAggFunction;
 import org.apache.flink.table.runtime.operators.aggregate.MiniBatchGroupAggFunction;
+import org.apache.flink.table.runtime.operators.aggregate.async.AsyncStateGroupAggFunction;
 import org.apache.flink.table.runtime.operators.bundle.KeyedMapBundleOperator;
 import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter;
 import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
@@ -227,6 +229,7 @@ public class StreamExecGroupAggregate extends StreamExecAggregateBase {
                         .generateRecordEqualiser("GroupAggValueEqualiser");
         final int inputCountIndex = aggInfoList.getIndexOfCountStar();
         final boolean isMiniBatchEnabled = MinibatchUtil.isMiniBatchEnabled(config);
+        final boolean isAsyncStateEnabled = AggregateUtil.isAsyncStateEnabled(config, aggInfoList);
 
         final OneInputStreamOperator<RowData, RowData> operator;
         if (isMiniBatchEnabled) {
@@ -242,6 +245,16 @@ public class StreamExecGroupAggregate extends StreamExecAggregateBase {
             operator =
                     new KeyedMapBundleOperator<>(
                             aggFunction, MinibatchUtil.createMiniBatchTrigger(config));
+        } else if (isAsyncStateEnabled) {
+            AsyncStateGroupAggFunction aggFunction =
+                    new AsyncStateGroupAggFunction(
+                            aggsHandler,
+                            recordEqualiser,
+                            accTypes,
+                            inputCountIndex,
+                            generateUpdateBefore,
+                            stateRetentionTime);
+            operator = new AsyncKeyedProcessOperator<>(aggFunction);
         } else {
             GroupAggFunction aggFunction =
                     new GroupAggFunction(
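Note on usage: the new branch is only taken when async state is switched on in the table
configuration. A minimal sketch of enabling it from the Table API, assuming the option
constant referenced in this change (ExecutionConfigOptions.TABLE_EXEC_ASYNC_STATE_ENABLED)
is available in your Flink version:

    import org.apache.flink.table.api.EnvironmentSettings;
    import org.apache.flink.table.api.TableEnvironment;
    import org.apache.flink.table.api.config.ExecutionConfigOptions;

    public class EnableAsyncStateExample {
        public static void main(String[] args) {
            TableEnvironment tEnv = TableEnvironment.create(EnvironmentSettings.inStreamingMode());
            // Opt the job into the async state API; the planner still falls back to
            // the sync operator when the aggregation relies on DataView state
            // (see the AggregateUtil change below).
            tEnv.getConfig().set(ExecutionConfigOptions.TABLE_EXEC_ASYNC_STATE_ENABLED, true);
        }
    }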
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
index a6ffe0ec44b..968a62af425 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
@@ -17,7 +17,9 @@
  */
 package org.apache.flink.table.planner.plan.utils
 
+import org.apache.flink.configuration.ReadableConfig
 import org.apache.flink.table.api.TableException
+import org.apache.flink.table.api.config.ExecutionConfigOptions
 import org.apache.flink.table.expressions._
 import org.apache.flink.table.expressions.ExpressionUtils.extractValue
 import org.apache.flink.table.functions._
@@ -1175,4 +1177,19 @@ object AggregateUtil extends Enumeration {
       })
       .exists(_.getKind == FunctionKind.TABLE_AGGREGATE)
   }
+
+  def isAsyncStateEnabled(config: ReadableConfig, aggInfoList: AggregateInfoList): Boolean = {
+    // Currently, we do not support async state with agg functions that include DataView.
+    val containsDataViewInAggInfo =
+      aggInfoList.aggInfos.toStream.stream().anyMatch(agg => !agg.viewSpecs.isEmpty)
+
+    val containsDataViewInDistinctInfo =
+      aggInfoList.distinctInfos.toStream
+        .stream()
+        .anyMatch(distinct => distinct.dataViewSpec.isDefined)
+
+    config.get(ExecutionConfigOptions.TABLE_EXEC_ASYNC_STATE_ENABLED) &&
+    !containsDataViewInAggInfo &&
+    !containsDataViewInDistinctInfo
+  }
 }
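Note on the gating logic above: DataView-backed state is what disqualifies a query, so
aggregate functions whose accumulators embed a MapView/ListView, and distinct aggregates
(which keep a per-key distinct view), stay on the sync-state operator even with the option
set. A hedged illustration, assuming distinct aggregates are planned with a DataView-backed
distinct view as in current Flink:

    // Eligible for AsyncStateGroupAggFunction (plain accumulators, no DataView):
    //   SELECT b, SUM(a), COUNT(*) FROM T GROUP BY b
    // Falls back to the sync-state GroupAggFunction (distinct view = DataView state):
    //   SELECT b, COUNT(DISTINCT a) FROM T GROUP BY b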
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/harness/GroupAggregateHarnessTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/harness/GroupAggregateHarnessTest.scala
index 9e4f83c1e5b..932605af5f2 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/harness/GroupAggregateHarnessTest.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/harness/GroupAggregateHarnessTest.scala
@@ -22,7 +22,7 @@ import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness,
 import org.apache.flink.table.api.{EnvironmentSettings, _}
 import org.apache.flink.table.api.bridge.scala._
 import org.apache.flink.table.api.bridge.scala.internal.StreamTableEnvironmentImpl
-import org.apache.flink.table.api.config.AggregatePhaseStrategy
+import org.apache.flink.table.api.config.{AggregatePhaseStrategy, ExecutionConfigOptions}
 import org.apache.flink.table.api.config.ExecutionConfigOptions.{TABLE_EXEC_MINIBATCH_ALLOW_LATENCY, TABLE_EXEC_MINIBATCH_ENABLED, TABLE_EXEC_MINIBATCH_SIZE}
 import org.apache.flink.table.api.config.OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY
 import org.apache.flink.table.data.RowData
@@ -38,6 +38,7 @@ import org.apache.flink.testutils.junit.extensions.parameterized.{ParameterizedT
 import org.apache.flink.types.Row
 import org.apache.flink.types.RowKind._
 
+import org.assertj.core.api.Assertions.assertThat
 import org.junit.jupiter.api.{BeforeEach, TestTemplate}
 import org.junit.jupiter.api.extension.ExtendWith
 
@@ -50,7 +51,10 @@ import scala.collection.JavaConversions._
 import scala.collection.mutable
 
 @ExtendWith(Array(classOf[ParameterizedTestExtension]))
-class GroupAggregateHarnessTest(mode: StateBackendMode, miniBatch: MiniBatchMode)
+class GroupAggregateHarnessTest(
+    mode: StateBackendMode,
+    miniBatch: MiniBatchMode,
+    enableAsyncState: Boolean)
   extends HarnessTestBase(mode) {
 
   @BeforeEach
@@ -71,6 +75,9 @@ class GroupAggregateHarnessTest(mode: StateBackendMode, miniBatch: MiniBatchMode
       case MiniBatchOff =>
         tableConfig.getConfiguration.removeConfig(TABLE_EXEC_MINIBATCH_ALLOW_LATENCY)
     }
+    tEnv.getConfig.set(
+      ExecutionConfigOptions.TABLE_EXEC_ASYNC_STATE_ENABLED,
+      Boolean.box(enableAsyncState))
   }
 
   @TestTemplate
@@ -112,18 +119,38 @@ class GroupAggregateHarnessTest(mode: StateBackendMode, miniBatch: MiniBatchMode
 
     // retract after clean up
     testHarness.processElement(binaryRecord(UPDATE_BEFORE, "ccc", 3L: JLong))
-    // not output
+    // no output for sync state because the accumulator was cleaned up by TTL
+    // the async state operator still emits output because it does not support TTL yet
+    if (enableAsyncState) {
+      expectedOutput.add(binaryRecord(DELETE, "ccc", 3L: JLong))
+    }
 
     // accumulate
     testHarness.processElement(binaryRecord(INSERT, "aaa", 4L: JLong))
-    expectedOutput.add(binaryRecord(INSERT, "aaa", 4L: JLong))
+    if (enableAsyncState) {
+      expectedOutput.add(binaryRecord(UPDATE_BEFORE, "aaa", 1L: JLong))
+      expectedOutput.add(binaryRecord(UPDATE_AFTER, "aaa", 5L: JLong))
+    } else {
+      expectedOutput.add(binaryRecord(INSERT, "aaa", 4L: JLong))
+    }
+
     testHarness.processElement(binaryRecord(INSERT, "bbb", 2L: JLong))
-    expectedOutput.add(binaryRecord(INSERT, "bbb", 2L: JLong))
+    if (enableAsyncState) {
+      expectedOutput.add(binaryRecord(UPDATE_BEFORE, "bbb", 1L: JLong))
+      expectedOutput.add(binaryRecord(UPDATE_AFTER, "bbb", 3L: JLong))
+    } else {
+      expectedOutput.add(binaryRecord(INSERT, "bbb", 2L: JLong))
+    }
 
     // retract
     testHarness.processElement(binaryRecord(INSERT, "aaa", 5L: JLong))
-    expectedOutput.add(binaryRecord(UPDATE_BEFORE, "aaa", 4L: JLong))
-    expectedOutput.add(binaryRecord(UPDATE_AFTER, "aaa", 9L: JLong))
+    if (enableAsyncState) {
+      expectedOutput.add(binaryRecord(UPDATE_BEFORE, "aaa", 5L: JLong))
+      expectedOutput.add(binaryRecord(UPDATE_AFTER, "aaa", 10L: JLong))
+    } else {
+      expectedOutput.add(binaryRecord(UPDATE_BEFORE, "aaa", 4L: JLong))
+      expectedOutput.add(binaryRecord(UPDATE_AFTER, "aaa", 9L: JLong))
+    }
 
     // accumulate
     testHarness.processElement(binaryRecord(INSERT, "eee", 6L: JLong))
@@ -131,16 +158,32 @@ class GroupAggregateHarnessTest(mode: StateBackendMode, miniBatch: MiniBatchMode
 
     // retract
     testHarness.processElement(binaryRecord(INSERT, "aaa", 7L: JLong))
-    expectedOutput.add(binaryRecord(UPDATE_BEFORE, "aaa", 9L: JLong))
-    expectedOutput.add(binaryRecord(UPDATE_AFTER, "aaa", 16L: JLong))
+    if (enableAsyncState) {
+      expectedOutput.add(binaryRecord(UPDATE_BEFORE, "aaa", 10L: JLong))
+      expectedOutput.add(binaryRecord(UPDATE_AFTER, "aaa", 17L: JLong))
+    } else {
+      expectedOutput.add(binaryRecord(UPDATE_BEFORE, "aaa", 9L: JLong))
+      expectedOutput.add(binaryRecord(UPDATE_AFTER, "aaa", 16L: JLong))
+    }
+
     testHarness.processElement(binaryRecord(INSERT, "bbb", 3L: JLong))
-    expectedOutput.add(binaryRecord(UPDATE_BEFORE, "bbb", 2L: JLong))
-    expectedOutput.add(binaryRecord(UPDATE_AFTER, "bbb", 5L: JLong))
+    if (enableAsyncState) {
+      expectedOutput.add(binaryRecord(UPDATE_BEFORE, "bbb", 3L: JLong))
+      expectedOutput.add(binaryRecord(UPDATE_AFTER, "bbb", 6L: JLong))
+    } else {
+      expectedOutput.add(binaryRecord(UPDATE_BEFORE, "bbb", 2L: JLong))
+      expectedOutput.add(binaryRecord(UPDATE_AFTER, "bbb", 5L: JLong))
+    }
 
     // accumulate
     testHarness.processElement(binaryRecord(INSERT, "aaa", 0L: JLong))
-    expectedOutput.add(binaryRecord(UPDATE_BEFORE, "aaa", 16L: JLong))
-    expectedOutput.add(binaryRecord(UPDATE_AFTER, "aaa", 16L: JLong))
+    if (enableAsyncState) {
+      expectedOutput.add(binaryRecord(UPDATE_BEFORE, "aaa", 17L: JLong))
+      expectedOutput.add(binaryRecord(UPDATE_AFTER, "aaa", 17L: JLong))
+    } else {
+      expectedOutput.add(binaryRecord(UPDATE_BEFORE, "aaa", 16L: JLong))
+      expectedOutput.add(binaryRecord(UPDATE_AFTER, "aaa", 16L: JLong))
+    }
 
     val result = testHarness.getOutput
     assertor.assertOutputEqualsSorted("result mismatch", expectedOutput, result)
@@ -328,6 +371,13 @@ class GroupAggregateHarnessTest(mode: StateBackendMode, miniBatch: MiniBatchMode
     val t1 = tEnv.sqlQuery(sql)
     val testHarness = createHarnessTester(t1.toRetractStream[Row], "GroupAggregate")
     val outputTypes = Array(DataTypes.STRING().getLogicalType, DataTypes.BIGINT().getLogicalType)
+
+    if (enableAsyncState) {
+      assertThat(isAsyncStateOperator(testHarness)).isTrue
+    } else {
+      assertThat(isAsyncStateOperator(testHarness)).isFalse
+    }
+
     (testHarness, outputTypes)
   }
 
@@ -386,6 +436,9 @@ class GroupAggregateHarnessTest(mode: StateBackendMode, miniBatch: MiniBatchMode
       DataTypes.BIGINT().getLogicalType
     )
 
+    // async state agg with data view is not supported yet
+    assertThat(isAsyncStateOperator(testHarness)).isFalse
+
     (testHarness, outputTypes)
   }
 
@@ -397,17 +450,19 @@ class GroupAggregateHarnessTest(mode: StateBackendMode, miniBatch: MiniBatchMode
     // expect no exception happens
     testHarness.close()
   }
+
 }
 
 object GroupAggregateHarnessTest {
 
-  @Parameters(name = "StateBackend={0}, MiniBatch={1}")
+  @Parameters(name = "StateBackend={0}, MiniBatch={1}, EnableAsyncState={2}")
   def parameters(): JCollection[Array[java.lang.Object]] = {
     Seq[Array[AnyRef]](
-      Array(HEAP_BACKEND, MiniBatchOff),
-      Array(HEAP_BACKEND, MiniBatchOn),
-      Array(ROCKSDB_BACKEND, MiniBatchOff),
-      Array(ROCKSDB_BACKEND, MiniBatchOn)
+      Array(HEAP_BACKEND, MiniBatchOff, Boolean.box(false)),
+      Array(HEAP_BACKEND, MiniBatchOff, Boolean.box(true)),
+      Array(HEAP_BACKEND, MiniBatchOn, Boolean.box(false)),
+      Array(ROCKSDB_BACKEND, MiniBatchOff, Boolean.box(false)),
+      Array(ROCKSDB_BACKEND, MiniBatchOn, Boolean.box(false))
     )
   }
 }
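Why the expectations diverge above: the harness advances time past the configured idle-state
retention before the "retract after clean up" step. With sync state, TTL has already removed
the accumulator, so the retraction is dropped and the next "aaa" record is treated as a first
row. The async state operator does not support TTL yet, so the old accumulator (presumably 1
for "aaa" from earlier records in the test) survives:

    sync state  (TTL fired):  acc("aaa") cleaned  -> INSERT ("aaa", 4)
    async state (no TTL yet): acc("aaa") = 1      -> UPDATE_BEFORE ("aaa", 1), UPDATE_AFTER ("aaa", 5)

Every later total in the async branch is therefore shifted by the surviving value.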
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/harness/HarnessTestBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/harness/HarnessTestBase.scala
index 345ce3f7ad4..16cac8ec45d 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/harness/HarnessTestBase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/harness/HarnessTestBase.scala
@@ -21,17 +21,18 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.dag.Transformation
 import org.apache.flink.api.java.functions.KeySelector
 import org.apache.flink.configuration.Configuration
-import org.apache.flink.runtime.state.CheckpointStorage
-import org.apache.flink.runtime.state.StateBackend
+import org.apache.flink.runtime.asyncprocessing.operators.AsyncKeyedProcessOperator
+import org.apache.flink.runtime.state.{CheckpointStorage, StateBackend}
 import org.apache.flink.runtime.state.hashmap.HashMapStateBackend
-import org.apache.flink.runtime.state.storage.FileSystemCheckpointStorage
-import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage
+import org.apache.flink.runtime.state.storage.{FileSystemCheckpointStorage, JobManagerCheckpointStorage}
 import org.apache.flink.state.rocksdb.EmbeddedRocksDBStateBackend
 import org.apache.flink.streaming.api.datastream.DataStream
-import org.apache.flink.streaming.api.operators.OneInputStreamOperator
+import org.apache.flink.streaming.api.operators.{OneInputStreamOperator, SimpleOperatorFactory}
 import org.apache.flink.streaming.api.transformations.{OneInputTransformation, PartitionTransformation}
 import org.apache.flink.streaming.api.watermark.Watermark
+import org.apache.flink.streaming.runtime.operators.asyncprocessing.AsyncStateProcessingOperator
 import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness, OneInputStreamOperatorTestHarness}
+import org.apache.flink.streaming.util.asyncprocessing.AsyncKeyedOneInputStreamOperatorTestHarness
 import org.apache.flink.table.data.RowData
 import org.apache.flink.table.planner.JLong
 import org.apache.flink.table.planner.runtime.utils.StreamingTestBase
@@ -73,8 +74,11 @@ class HarnessTestBase(mode: StateBackendMode) extends StreamingTestBase {
       operator: OneInputStreamOperator[IN, OUT],
       keySelector: KeySelector[IN, KEY],
       keyType: TypeInformation[KEY]): KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT] = {
-    val harness =
+    val harness = if (operator.isInstanceOf[AsyncStateProcessingOperator]) {
+      AsyncKeyedOneInputStreamOperatorTestHarness.create(operator, keySelector, keyType)
+    } else {
       new KeyedOneInputStreamOperatorTestHarness[KEY, IN, OUT](operator, keySelector, keyType)
+    }
     harness.setStateBackend(getStateBackend)
     harness.setCheckpointStorage(getCheckpointStorage)
     harness
@@ -127,6 +131,14 @@ class HarnessTestBase(mode: StateBackendMode) extends StreamingTestBase {
   def dropWatermarks(elements: Array[AnyRef]): util.Collection[AnyRef] = {
     elements.filter(e => !e.isInstanceOf[Watermark]).toList
   }
+
+  protected def isAsyncStateOperator(
+      testHarness: KeyedOneInputStreamOperatorTestHarness[RowData, RowData, RowData]): Boolean = {
+    testHarness.getOperatorFactory
+      .asInstanceOf[SimpleOperatorFactory[_]]
+      .getOperator
+      .isInstanceOf[AsyncKeyedProcessOperator[_, _, _]]
+  }
 }
 
 object HarnessTestBase {
diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
index 339f8a4aa15..db676e2866d 100644
--- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
+++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
@@ -22,6 +22,7 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo
 import org.apache.flink.streaming.api.datastream.DataStream
 import org.apache.flink.table.api._
 import org.apache.flink.table.api.bridge.scala._
+import org.apache.flink.table.api.config.ExecutionConfigOptions
 import org.apache.flink.table.api.internal.TableEnvironmentInternal
 import org.apache.flink.table.legacy.api.Types
 import org.apache.flink.table.planner.factories.TestValuesTableFactory
@@ -30,32 +31,38 @@ import org.apache.flink.table.planner.plan.utils.JavaUserDefinedAggFunctions.{Us
 import org.apache.flink.table.planner.runtime.batch.sql.agg.{MyPojoAggFunction, VarArgsAggFunction}
 import org.apache.flink.table.planner.runtime.utils._
 import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedAggFunctions.OverloadedMaxFunction
-import org.apache.flink.table.planner.runtime.utils.StreamingWithAggTestBase.AggMode
-import org.apache.flink.table.planner.runtime.utils.StreamingWithMiniBatchTestBase.MiniBatchMode
-import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.StateBackendMode
+import org.apache.flink.table.planner.runtime.utils.StreamingWithAggTestBase.{AggMode, LocalGlobalOff, LocalGlobalOn}
+import org.apache.flink.table.planner.runtime.utils.StreamingWithMiniBatchTestBase.{MiniBatchMode, MiniBatchOff, MiniBatchOn}
+import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.{HEAP_BACKEND, ROCKSDB_BACKEND, StateBackendMode}
 import org.apache.flink.table.planner.runtime.utils.TimeTestUtil.TimestampAndWatermarkWithOffset
 import org.apache.flink.table.planner.runtime.utils.UserDefinedFunctionTestUtils._
 import org.apache.flink.table.planner.utils.DateTimeTestUtil.{localDate, localDateTime, localTime => mLocalTime}
 import org.apache.flink.table.runtime.functions.aggregate.{ListAggWithRetractAggFunction, ListAggWsWithRetractAggFunction}
 import org.apache.flink.table.runtime.typeutils.BigDecimalTypeInfo
-import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension
+import org.apache.flink.testutils.junit.extensions.parameterized.{ParameterizedTestExtension, Parameters}
 import org.apache.flink.types.Row
 
 import org.assertj.core.api.Assertions.assertThat
 import org.assertj.core.data.Percentage
-import org.junit.jupiter.api.{Disabled, TestTemplate}
+import org.junit.jupiter.api.{BeforeEach, Disabled, TestTemplate}
 import org.junit.jupiter.api.extension.ExtendWith
 
 import java.lang.{Integer => JInt, Long => JLong}
 import java.math.{BigDecimal => JBigDecimal}
 import java.time.Duration
+import java.util
 
 import scala.collection.{mutable, Seq}
+import scala.collection.JavaConversions._
 import scala.math.BigDecimal.double2bigDecimal
 import scala.util.Random
 
 @ExtendWith(Array(classOf[ParameterizedTestExtension]))
-class AggregateITCase(aggMode: AggMode, miniBatch: MiniBatchMode, backend: StateBackendMode)
+class AggregateITCase(
+    aggMode: AggMode,
+    miniBatch: MiniBatchMode,
+    backend: StateBackendMode,
+    enableAsyncState: Boolean)
   extends StreamingWithAggTestBase(aggMode, miniBatch, backend) {
 
   val data = List(
@@ -70,6 +77,15 @@ class AggregateITCase(aggMode: AggMode, miniBatch: MiniBatchMode, backend: State
     (20000L, 20, "Hello World")
   )
 
+  @BeforeEach
+  override def before(): Unit = {
+    super.before()
+
+    tEnv.getConfig.set(
+      ExecutionConfigOptions.TABLE_EXEC_ASYNC_STATE_ENABLED,
+      Boolean.box(enableAsyncState))
+  }
+
   @TestTemplate
   def testEmptyInputAggregation(): Unit = {
     val data = new mutable.MutableList[(Int, Int)]
@@ -2077,3 +2093,19 @@ class AggregateITCase(aggMode: AggMode, miniBatch: MiniBatchMode, backend: State
     tEnv.dropTemporarySystemFunction("PERCENTILE")
   }
 }
+
+object AggregateITCase {
+
+  @Parameters(name = "LocalGlobal={0}, {1}, StateBackend={2}, EnableAsyncState={3}")
+  def parameters(): util.Collection[Array[java.lang.Object]] = {
+    Seq[Array[AnyRef]](
+      Array(LocalGlobalOff, MiniBatchOff, HEAP_BACKEND, Boolean.box(false)),
+      Array(LocalGlobalOff, MiniBatchOff, HEAP_BACKEND, Boolean.box(true)),
+      Array(LocalGlobalOff, MiniBatchOn, HEAP_BACKEND, Boolean.box(false)),
+      Array(LocalGlobalOn, MiniBatchOn, HEAP_BACKEND, Boolean.box(false)),
+      Array(LocalGlobalOff, MiniBatchOff, ROCKSDB_BACKEND, Boolean.box(false)),
+      Array(LocalGlobalOff, MiniBatchOn, ROCKSDB_BACKEND, Boolean.box(false)),
+      Array(LocalGlobalOn, MiniBatchOn, ROCKSDB_BACKEND, Boolean.box(false))
+    )
+  }
+}
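A note on the test matrix: only the LocalGlobalOff/MiniBatchOff/HEAP_BACKEND row runs with
async state enabled, which is consistent with the planner change above — the mini-batch and
local-global paths have no async-state variant in this commit, so only the plain per-record
operator is exercised with the flag on.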
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunction.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunction.java
index b1c21c559b2..fc384f470ab 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunction.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunction.java
@@ -19,61 +19,26 @@ package org.apache.flink.table.runtime.operators.aggregate;
 
 import org.apache.flink.api.common.functions.OpenContext;
-import org.apache.flink.api.common.state.StateTtlConfig;
 import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
-import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
 import org.apache.flink.table.data.RowData;
-import org.apache.flink.table.data.utils.JoinedRowData;
-import org.apache.flink.table.runtime.dataview.PerKeyStateDataViewStore;
-import org.apache.flink.table.runtime.generated.AggsHandleFunction;
 import org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction;
 import org.apache.flink.table.runtime.generated.GeneratedRecordEqualiser;
-import org.apache.flink.table.runtime.generated.RecordEqualiser;
+import org.apache.flink.table.runtime.operators.aggregate.utils.GroupAggHelper;
 import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
 import org.apache.flink.table.types.logical.LogicalType;
-import org.apache.flink.types.RowKind;
 import org.apache.flink.util.Collector;
 
-import static org.apache.flink.table.data.util.RowDataUtil.isAccumulateMsg;
-import static org.apache.flink.table.data.util.RowDataUtil.isRetractMsg;
-import static org.apache.flink.table.runtime.util.StateConfigUtil.createTtlConfig;
-
 /** Aggregate Function used for the groupby (without window) aggregate. */
-public class GroupAggFunction extends KeyedProcessFunction<RowData, RowData, RowData> {
+public class GroupAggFunction extends GroupAggFunctionBase {
 
     private static final long serialVersionUID = -4767158666069797704L;
 
-    /** The code generated function used to handle aggregates. */
-    private final GeneratedAggsHandleFunction genAggsHandler;
-
-    /** The code generated equaliser used to equal RowData. */
-    private final GeneratedRecordEqualiser genRecordEqualiser;
-
-    /** The accumulator types. */
-    private final LogicalType[] accTypes;
-
-    /** Used to count the number of added and retracted input records. */
-    private final RecordCounter recordCounter;
-
-    /** Whether this operator will generate UPDATE_BEFORE messages. */
-    private final boolean generateUpdateBefore;
-
-    /** State idle retention time which unit is MILLISECONDS. */
-    private final long stateRetentionTime;
-
-    /** Reused output row. */
-    private transient JoinedRowData resultRow = null;
-
-    // function used to handle all aggregates
-    private transient AggsHandleFunction function = null;
-
-    // function used to equal RowData
-    private transient RecordEqualiser equaliser = null;
-
     // stores the accumulators
     private transient ValueState<RowData> accState = null;
 
+    private transient SyncStateGroupAggHelper aggHelper = null;
+
     /**
      * Creates a {@link GroupAggFunction}.
     *
@@ -93,23 +58,18 @@ public class GroupAggFunction extends KeyedProcessFunction<RowData, RowData, Row
             int indexOfCountStar,
             boolean generateUpdateBefore,
             long stateRetentionTime) {
-        this.genAggsHandler = genAggsHandler;
-        this.genRecordEqualiser = genRecordEqualiser;
-        this.accTypes = accTypes;
-        this.recordCounter = RecordCounter.of(indexOfCountStar);
-        this.generateUpdateBefore = generateUpdateBefore;
-        this.stateRetentionTime = stateRetentionTime;
+        super(
+                genAggsHandler,
+                genRecordEqualiser,
+                accTypes,
+                indexOfCountStar,
+                generateUpdateBefore,
+                stateRetentionTime);
     }
 
     @Override
     public void open(OpenContext openContext) throws Exception {
         super.open(openContext);
-        // instantiate function
-        StateTtlConfig ttlConfig = createTtlConfig(stateRetentionTime);
-        function = genAggsHandler.newInstance(getRuntimeContext().getUserCodeClassLoader());
-        function.open(new PerKeyStateDataViewStore(getRuntimeContext(), ttlConfig));
-        // instantiate equaliser
-        equaliser = genRecordEqualiser.newInstance(getRuntimeContext().getUserCodeClassLoader());
 
         InternalTypeInfo<RowData> accTypeInfo = InternalTypeInfo.ofFields(accTypes);
         ValueStateDescriptor<RowData> accDesc = new ValueStateDescriptor<>("accState", accTypeInfo);
@@ -118,100 +78,29 @@ public class GroupAggFunction extends KeyedProcessFunction<RowData, RowData, Row
         }
         accState = getRuntimeContext().getState(accDesc);
 
-        resultRow = new JoinedRowData();
+        aggHelper = new SyncStateGroupAggHelper();
     }
 
     @Override
     public void processElement(RowData input, Context ctx, Collector<RowData> out)
             throws Exception {
         RowData currentKey = ctx.getCurrentKey();
-        boolean firstRow;
-        RowData accumulators = accState.value();
-        if (null == accumulators) {
-            // Don't create a new accumulator for a retraction message. This
-            // might happen if the retraction message is the first message for the
-            // key or after a state clean up.
-            if (isRetractMsg(input)) {
-                return;
-            }
-            firstRow = true;
-            accumulators = function.createAccumulators();
-        } else {
-            firstRow = false;
-        }
-
-        // set accumulators to handler first
-        function.setAccumulators(accumulators);
-        // get previous aggregate result
-        RowData prevAggValue = function.getValue();
+        aggHelper.processElement(input, currentKey, accState.value(), out);
+    }
 
-        // update aggregate result and set to the newRow
-        if (isAccumulateMsg(input)) {
-            // accumulate input
-            function.accumulate(input);
-        } else {
-            // retract input
-            function.retract(input);
+    private class SyncStateGroupAggHelper extends GroupAggHelper {
+        public SyncStateGroupAggHelper() {
+            super(recordCounter, generateUpdateBefore, ttlConfig, function, equaliser);
         }
-        // get current aggregate result
-        RowData newAggValue = function.getValue();
-
-        // get accumulator
-        accumulators = function.getAccumulators();
 
-        if (!recordCounter.recordCountIsZero(accumulators)) {
-            // we aggregated at least one record for this key
-
-            // update the state
+        @Override
+        protected void updateAccumulatorsState(RowData accumulators) throws Exception {
             accState.update(accumulators);
-
-            // if this was not the first row and we have to emit retractions
-            if (!firstRow) {
-                if (stateRetentionTime <= 0 && equaliser.equals(prevAggValue, newAggValue)) {
-                    // newRow is the same as before and state cleaning is not enabled.
-                    // We do not emit retraction and acc message.
-                    // If state cleaning is enabled, we have to emit messages to prevent too early
-                    // state eviction of downstream operators.
-                    return;
-                } else {
-                    // retract previous result
-                    if (generateUpdateBefore) {
-                        // prepare UPDATE_BEFORE message for previous row
-                        resultRow
-                                .replace(currentKey, prevAggValue)
-                                .setRowKind(RowKind.UPDATE_BEFORE);
-                        out.collect(resultRow);
-                    }
-                    // prepare UPDATE_AFTER message for new row
-                    resultRow.replace(currentKey, newAggValue).setRowKind(RowKind.UPDATE_AFTER);
-                }
-            } else {
-                // this is the first, output new result
-                // prepare INSERT message for new row
-                resultRow.replace(currentKey, newAggValue).setRowKind(RowKind.INSERT);
-            }
-
-            out.collect(resultRow);
-
-        } else {
-            // we retracted the last record for this key
-            // sent out a delete message
-            if (!firstRow) {
-                // prepare delete message for previous row
-                resultRow.replace(currentKey, prevAggValue).setRowKind(RowKind.DELETE);
-                out.collect(resultRow);
-            }
-            // and clear all state
-            accState.clear();
-            // cleanup dataview under current key
-            function.cleanup();
         }
-    }
 
-    @Override
-    public void close() throws Exception {
-        if (function != null) {
-            function.close();
+        @Override
+        protected void clearAccumulatorsState() throws Exception {
+            accState.clear();
         }
     }
 }
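The refactoring above is a template-method split: the per-record control flow (first-row
handling, accumulate/retract, equality check, UPDATE_BEFORE/UPDATE_AFTER vs. INSERT/DELETE
emission) moves into the shared GroupAggHelper, and each operator variant supplies only the
state accessors. A minimal self-contained sketch of the shape, with hypothetical names (not
the Flink classes):

    import java.util.HashMap;
    import java.util.Map;

    // Shared control flow; subclasses plug in how the accumulator is persisted.
    abstract class AggTemplate {
        final void process(String key, long delta, Long prevAcc) throws Exception {
            long newAcc = (prevAcc == null ? 0L : prevAcc) + delta;
            if (newAcc != 0L) {
                update(key, newAcc); // sync impl: state.update(); async impl: state.asyncUpdate()
            } else {
                clear(key);          // sync impl: state.clear();  async impl: state.asyncClear()
            }
        }

        protected abstract void update(String key, long acc) throws Exception;

        protected abstract void clear(String key) throws Exception;
    }

    // A trivial "sync state" variant backed by an in-memory map.
    class MapBackedAgg extends AggTemplate {
        private final Map<String, Long> state = new HashMap<>();

        @Override
        protected void update(String key, long acc) {
            state.put(key, acc);
        }

        @Override
        protected void clear(String key) {
            state.remove(key);
        }
    }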
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunctionBase.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunctionBase.java
new file mode 100644
index 00000000000..b6a65df3903
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunctionBase.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.aggregate;
+
+import org.apache.flink.api.common.functions.OpenContext;
+import org.apache.flink.api.common.state.StateTtlConfig;
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.runtime.dataview.PerKeyStateDataViewStore;
+import org.apache.flink.table.runtime.generated.AggsHandleFunction;
+import org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction;
+import org.apache.flink.table.runtime.generated.GeneratedRecordEqualiser;
+import org.apache.flink.table.runtime.generated.RecordEqualiser;
+import org.apache.flink.table.runtime.operators.aggregate.async.AsyncStateGroupAggFunction;
+import org.apache.flink.table.types.logical.LogicalType;
+
+import static org.apache.flink.table.runtime.util.StateConfigUtil.createTtlConfig;
+
+/** Base class for {@link GroupAggFunction} and {@link AsyncStateGroupAggFunction}. */
+public abstract class GroupAggFunctionBase extends KeyedProcessFunction<RowData, RowData, RowData> {
+
+    /** The code generated function used to handle aggregates. */
+    protected final GeneratedAggsHandleFunction genAggsHandler;
+
+    /** The code generated equaliser used to equal RowData. */
+    protected final GeneratedRecordEqualiser genRecordEqualiser;
+
+    /** The accumulator types. */
+    protected final LogicalType[] accTypes;
+
+    /** Used to count the number of added and retracted input records. */
+    protected final RecordCounter recordCounter;
+
+    /** Whether this operator will generate UPDATE_BEFORE messages. */
+    protected final boolean generateUpdateBefore;
+
+    /** State idle retention config. */
+    protected final StateTtlConfig ttlConfig;
+
+    // function used to handle all aggregates
+    protected transient AggsHandleFunction function = null;
+
+    // function used to equal RowData
+    protected transient RecordEqualiser equaliser = null;
+
+    public GroupAggFunctionBase(
+            GeneratedAggsHandleFunction genAggsHandler,
+            GeneratedRecordEqualiser genRecordEqualiser,
+            LogicalType[] accTypes,
+            int indexOfCountStar,
+            boolean generateUpdateBefore,
+            long stateRetentionTime) {
+        this.genAggsHandler = genAggsHandler;
+        this.genRecordEqualiser = genRecordEqualiser;
+        this.accTypes = accTypes;
+        this.recordCounter = RecordCounter.of(indexOfCountStar);
+        this.generateUpdateBefore = generateUpdateBefore;
+        this.ttlConfig = createTtlConfig(stateRetentionTime);
+    }
+
+    @Override
+    public void open(OpenContext openContext) throws Exception {
+        super.open(openContext);
+
+        // instantiate function
+        function = genAggsHandler.newInstance(getRuntimeContext().getUserCodeClassLoader());
+        function.open(new PerKeyStateDataViewStore(getRuntimeContext(), ttlConfig));
+        // instantiate equaliser
+        equaliser = genRecordEqualiser.newInstance(getRuntimeContext().getUserCodeClassLoader());
+    }
+
+    @Override
+    public void close() throws Exception {
+        super.close();
+
+        if (function != null) {
+            function.close();
+        }
+    }
+}
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/AsyncStateGroupAggFunction.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/AsyncStateGroupAggFunction.java
new file mode 100644
index 00000000000..506c36cc899
--- /dev/null
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/async/AsyncStateGroupAggFunction.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.aggregate.async;
+
+import org.apache.flink.api.common.functions.OpenContext;
+import org.apache.flink.api.common.state.v2.ValueState;
+import org.apache.flink.runtime.state.v2.ValueStateDescriptor;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction;
+import org.apache.flink.table.runtime.generated.GeneratedRecordEqualiser;
+import org.apache.flink.table.runtime.operators.aggregate.GroupAggFunctionBase;
+import org.apache.flink.table.runtime.operators.aggregate.utils.GroupAggHelper;
+import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.util.Collector;
+
+/** Aggregate Function used for the groupby (without window) aggregate with async state api. */
+public class AsyncStateGroupAggFunction extends GroupAggFunctionBase {
+
+    private static final long serialVersionUID = 1L;
+
+    // stores the accumulators
+    private transient ValueState<RowData> accState = null;
+
+    private transient AsyncStateGroupAggHelper aggHelper = null;
+
+    /**
+     * Creates a {@link AsyncStateGroupAggFunction}.
+     *
+     * @param genAggsHandler The code generated function used to handle aggregates.
+     * @param genRecordEqualiser The code generated equaliser used to equal RowData.
+     * @param accTypes The accumulator types.
+     * @param indexOfCountStar The index of COUNT(*) in the aggregates. -1 when the input doesn't
+     *     contain COUNT(*), i.e. doesn't contain retraction messages. We make sure there is a
+     *     COUNT(*) if input stream contains retraction.
+     * @param generateUpdateBefore Whether this operator will generate UPDATE_BEFORE messages.
+     * @param stateRetentionTime state idle retention time which unit is MILLISECONDS.
+     */
+    public AsyncStateGroupAggFunction(
+            GeneratedAggsHandleFunction genAggsHandler,
+            GeneratedRecordEqualiser genRecordEqualiser,
+            LogicalType[] accTypes,
+            int indexOfCountStar,
+            boolean generateUpdateBefore,
+            long stateRetentionTime) {
+        super(
+                genAggsHandler,
+                genRecordEqualiser,
+                accTypes,
+                indexOfCountStar,
+                generateUpdateBefore,
+                stateRetentionTime);
+    }
+
+    @Override
+    public void open(OpenContext openContext) throws Exception {
+        super.open(openContext);
+
+        InternalTypeInfo<RowData> accTypeInfo = InternalTypeInfo.ofFields(accTypes);
+        ValueStateDescriptor<RowData> accDesc = new ValueStateDescriptor<>("accState", accTypeInfo);
+        if (ttlConfig.isEnabled()) {
+            accDesc.enableTimeToLive(ttlConfig);
+        }
+
+        accState = ((StreamingRuntimeContext) getRuntimeContext()).getValueState(accDesc);
+        aggHelper = new AsyncStateGroupAggHelper();
+    }
+
+    @Override
+    public void processElement(RowData input, Context ctx, Collector<RowData> out)
+            throws Exception {
+        RowData currentKey = ctx.getCurrentKey();
+        accState.asyncValue()
+                .thenAccept(acc -> aggHelper.processElement(input, currentKey, acc, out));
+    }
+
+    private class AsyncStateGroupAggHelper extends GroupAggHelper {
+
+        public AsyncStateGroupAggHelper() {
+            super(recordCounter, generateUpdateBefore, ttlConfig, function, equaliser);
+        }
+
+        @Override
+        protected void updateAccumulatorsState(RowData accumulators) throws Exception {
+            accState.asyncUpdate(accumulators);
+        }
+
+        @Override
+        protected void clearAccumulatorsState() throws Exception {
+            accState.asyncClear();
+        }
+    }
+}
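Note the control-flow inversion in processElement above: with the State V2 API,
accState.asyncValue() returns a future, and the rest of the per-record work runs in a
thenAccept callback once the state read completes; asyncUpdate/asyncClear likewise return
futures instead of blocking. A rough stand-alone analogy of that callback shape using
CompletableFuture (the operator itself uses Flink's StateFuture with mailbox ordering
guarantees, not this class):

    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.ConcurrentHashMap;
    import java.util.concurrent.ConcurrentMap;

    public class AsyncValueStateDemo {
        // Stand-in for async keyed value state: reads and writes complete via futures.
        private final ConcurrentMap<String, Long> backend = new ConcurrentHashMap<>();

        CompletableFuture<Long> asyncValue(String key) {
            return CompletableFuture.supplyAsync(() -> backend.get(key));
        }

        CompletableFuture<Void> asyncUpdate(String key, long value) {
            return CompletableFuture.runAsync(() -> backend.put(key, value));
        }

        void processElement(String key, long delta) {
            // Read state asynchronously, then continue the per-record logic in a
            // callback, mirroring accState.asyncValue().thenAccept(acc -> ...).
            asyncValue(key).thenAccept(prev -> {
                long acc = (prev == null ? 0L : prev) + delta;
                asyncUpdate(key, acc);
            });
        }

        public static void main(String[] args) throws Exception {
            AsyncValueStateDemo demo = new AsyncValueStateDemo();
            demo.processElement("aaa", 1L);
            Thread.sleep(100); // demo only: wait for the async chain to finish
            System.out.println(demo.backend); // {aaa=1}
        }
    }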
diff --git a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunction.java b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/utils/GroupAggHelper.java
similarity index 52%
copy from flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunction.java
copy to flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/utils/GroupAggHelper.java
index b1c21c559b2..bc42304d080 100644
--- a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/GroupAggFunction.java
+++ b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/aggregate/utils/GroupAggHelper.java
@@ -16,42 +16,22 @@
  * limitations under the License.
  */
 
-package org.apache.flink.table.runtime.operators.aggregate;
+package org.apache.flink.table.runtime.operators.aggregate.utils;
 
-import org.apache.flink.api.common.functions.OpenContext;
 import org.apache.flink.api.common.state.StateTtlConfig;
-import org.apache.flink.api.common.state.ValueState;
-import org.apache.flink.api.common.state.ValueStateDescriptor;
-import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
 import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.data.utils.JoinedRowData;
-import org.apache.flink.table.runtime.dataview.PerKeyStateDataViewStore;
 import org.apache.flink.table.runtime.generated.AggsHandleFunction;
-import org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction;
-import org.apache.flink.table.runtime.generated.GeneratedRecordEqualiser;
 import org.apache.flink.table.runtime.generated.RecordEqualiser;
-import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
-import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.runtime.operators.aggregate.RecordCounter;
 import org.apache.flink.types.RowKind;
 import org.apache.flink.util.Collector;
 
 import static org.apache.flink.table.data.util.RowDataUtil.isAccumulateMsg;
 import static org.apache.flink.table.data.util.RowDataUtil.isRetractMsg;
-import static org.apache.flink.table.runtime.util.StateConfigUtil.createTtlConfig;
 
-/** Aggregate Function used for the groupby (without window) aggregate. */
-public class GroupAggFunction extends KeyedProcessFunction<RowData, RowData, RowData> {
-
-    private static final long serialVersionUID = -4767158666069797704L;
-
-    /** The code generated function used to handle aggregates. */
-    private final GeneratedAggsHandleFunction genAggsHandler;
-
-    /** The code generated equaliser used to equal RowData. */
-    private final GeneratedRecordEqualiser genRecordEqualiser;
-
-    /** The accumulator types. */
-    private final LogicalType[] accTypes;
+/** A helper to do the logic of group agg. */
+public abstract class GroupAggHelper {
 
     /** Used to count the number of added and retracted input records. */
     private final RecordCounter recordCounter;
@@ -59,74 +39,36 @@ public class GroupAggFunction extends KeyedProcessFunction<RowData, RowData, Row
     /** Whether this operator will generate UPDATE_BEFORE messages. */
     private final boolean generateUpdateBefore;
 
-    /** State idle retention time which unit is MILLISECONDS. */
-    private final long stateRetentionTime;
+    /** State idle retention config. */
+    private final StateTtlConfig ttlConfig;
+
+    /** function used to handle all aggregates. */
+    private final AggsHandleFunction function;
+
+    /** function used to equal RowData. */
+    private final RecordEqualiser equaliser;
 
     /** Reused output row. */
-    private transient JoinedRowData resultRow = null;
-
-    // function used to handle all aggregates
-    private transient AggsHandleFunction function = null;
-
-    // function used to equal RowData
-    private transient RecordEqualiser equaliser = null;
-
-    // stores the accumulators
-    private transient ValueState<RowData> accState = null;
-
-    /**
-     * Creates a {@link GroupAggFunction}.
-     *
-     * @param genAggsHandler The code generated function used to handle aggregates.
-     * @param genRecordEqualiser The code generated equaliser used to equal RowData.
-     * @param accTypes The accumulator types.
-     * @param indexOfCountStar The index of COUNT(*) in the aggregates. -1 when the input doesn't
-     *     contain COUNT(*), i.e. doesn't contain retraction messages. We make sure there is a
-     *     COUNT(*) if input stream contains retraction.
-     * @param generateUpdateBefore Whether this operator will generate UPDATE_BEFORE messages.
-     * @param stateRetentionTime state idle retention time which unit is MILLISECONDS.
-     */
-    public GroupAggFunction(
-            GeneratedAggsHandleFunction genAggsHandler,
-            GeneratedRecordEqualiser genRecordEqualiser,
-            LogicalType[] accTypes,
-            int indexOfCountStar,
+    private final JoinedRowData resultRow;
+
+    public GroupAggHelper(
+            RecordCounter recordCounter,
             boolean generateUpdateBefore,
-            long stateRetentionTime) {
-        this.genAggsHandler = genAggsHandler;
-        this.genRecordEqualiser = genRecordEqualiser;
-        this.accTypes = accTypes;
-        this.recordCounter = RecordCounter.of(indexOfCountStar);
+            StateTtlConfig ttlConfig,
+            AggsHandleFunction function,
+            RecordEqualiser equaliser) {
+        this.recordCounter = recordCounter;
         this.generateUpdateBefore = generateUpdateBefore;
-        this.stateRetentionTime = stateRetentionTime;
+        this.ttlConfig = ttlConfig;
+        this.function = function;
+        this.equaliser = equaliser;
+        this.resultRow = new JoinedRowData();
    }
 
-    @Override
-    public void open(OpenContext openContext) throws Exception {
-        super.open(openContext);
-        // instantiate function
-        StateTtlConfig ttlConfig = createTtlConfig(stateRetentionTime);
-        function = genAggsHandler.newInstance(getRuntimeContext().getUserCodeClassLoader());
-        function.open(new PerKeyStateDataViewStore(getRuntimeContext(), ttlConfig));
-        // instantiate equaliser
-        equaliser = genRecordEqualiser.newInstance(getRuntimeContext().getUserCodeClassLoader());
-
-        InternalTypeInfo<RowData> accTypeInfo = InternalTypeInfo.ofFields(accTypes);
-        ValueStateDescriptor<RowData> accDesc = new ValueStateDescriptor<>("accState", accTypeInfo);
-        if (ttlConfig.isEnabled()) {
-            accDesc.enableTimeToLive(ttlConfig);
-        }
-        accState = getRuntimeContext().getState(accDesc);
-
-        resultRow = new JoinedRowData();
-    }
-
-    @Override
-    public void processElement(RowData input, Context ctx, Collector<RowData> out)
+    public void processElement(
+            RowData input, RowData currentKey, RowData accumulators, Collector<RowData> out)
             throws Exception {
-        RowData currentKey = ctx.getCurrentKey();
         boolean firstRow;
-        RowData accumulators = accState.value();
         if (null == accumulators) {
             // Don't create a new accumulator for a retraction message. This
             // might happen if the retraction message is the first message for the
@@ -163,11 +105,11 @@ public class GroupAggFunction extends KeyedProcessFunction<RowData, RowData, Row
             // we aggregated at least one record for this key
 
             // update the state
-            accState.update(accumulators);
+            updateAccumulatorsState(accumulators);
 
             // if this was not the first row and we have to emit retractions
             if (!firstRow) {
-                if (stateRetentionTime <= 0 && equaliser.equals(prevAggValue, newAggValue)) {
+                if (!ttlConfig.isEnabled() && equaliser.equals(prevAggValue, newAggValue)) {
                     // newRow is the same as before and state cleaning is not enabled.
                     // We do not emit retraction and acc message.
                     // If state cleaning is enabled, we have to emit messages to prevent too early
@@ -202,16 +144,13 @@ public class GroupAggFunction extends KeyedProcessFunction<RowData, RowData, Row
                 out.collect(resultRow);
             }
             // and clear all state
-            accState.clear();
+            clearAccumulatorsState();
             // cleanup dataview under current key
             function.cleanup();
         }
     }
 
-    @Override
-    public void close() throws Exception {
-        if (function != null) {
-            function.close();
-        }
-    }
+    protected abstract void updateAccumulatorsState(RowData accumulators) throws Exception;
+
+    protected abstract void clearAccumulatorsState() throws Exception;
 }
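One behavioral detail worth noting in the helper: the old guard stateRetentionTime <= 0
becomes !ttlConfig.isEnabled(). These should be equivalent as long as
StateConfigUtil.createTtlConfig only enables TTL for a positive retention time, which is how
GroupAggFunctionBase builds ttlConfig from stateRetentionTime in this patch.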